Skip to content

Commit

Permalink
Feat/multi dict prover (#553)
Browse files Browse the repository at this point in the history
* refactor: disallow explicit dictionaries for anything other than v0 decompression setup

* refactor: pass around dictionary stores instead of individual dictionaries

* feat: add new dictionary

---------

Co-authored-by: Arya Tabaie <15056835+Tabaie@users.noreply.github.com>
  • Loading branch information
Tabaie and Tabaie authored Jan 21, 2025
1 parent 39679c7 commit 543aedd
Show file tree
Hide file tree
Showing 24 changed files with 111 additions and 189 deletions.
2 changes: 1 addition & 1 deletion docker/config/prover/v3/prover-config.toml
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ requests_root_dir = "/data/prover/v3/execution"
[blob_decompression]
prover_mode = "dev"
requests_root_dir = "/data/prover/v3/compression"
dict_path = "/opt/linea/prover/lib/compressor/compressor_dict.bin"
dict_paths = ["/opt/linea/prover/lib/compressor/compressor_dict.bin"]

[aggregation]
prover_mode = "dev"
Expand Down
3 changes: 1 addition & 2 deletions prover/backend/aggregation/prove.go
Original file line number Diff line number Diff line change
Expand Up @@ -105,11 +105,10 @@ func makePiProof(cfg *config.Config, cf *CollectedFields) (plonk.Proof, witness.
}

assignment, err := c.Assign(pi_interconnection.Request{
DictPath: cfg.BlobDecompression.DictPath,
Decompressions: cf.DecompressionPI,
Executions: cf.ExecutionPI,
Aggregation: cf.AggregationPublicInput(cfg),
})
}, cfg.BlobDecompressionDictStore(string(circuits.BlobDecompressionV1CircuitID))) // TODO @Tabaie: when there is a version 2, input the compressor version to use here
if err != nil {
return nil, nil, fmt.Errorf("could not assign the public input circuit: %w", err)
}
Expand Down
13 changes: 3 additions & 10 deletions prover/backend/blobdecompression/prove.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,6 @@ import (
"bytes"
"encoding/base64"
"fmt"
"os"

blob_v0 "github.com/consensys/linea-monorepo/prover/lib/compressor/blob/v0"
blob_v1 "github.com/consensys/linea-monorepo/prover/lib/compressor/blob/v1"

Expand Down Expand Up @@ -68,14 +66,9 @@ func Prove(cfg *config.Config, req *Request) (*Response, error) {
return nil, fmt.Errorf("unsupported blob version: %v", version)
}

dictPath := cfg.BlobDecompressionDictPath(string(circuitID))

logrus.Infof("reading the dictionary at %v", dictPath)
logrus.Info("reading dictionaries")

dict, err := os.ReadFile(dictPath)
if err != nil {
return nil, fmt.Errorf("error reading the dictionary: %w", err)
}
dictStore := cfg.BlobDecompressionDictStore(string(circuitID))

// This computes the assignment

Expand All @@ -88,7 +81,7 @@ func Prove(cfg *config.Config, req *Request) (*Response, error) {

assignment, pubInput, _snarkHash, err := blobdecompression.Assign(
utils.RightPad(blobBytes, expectedMaxUsableBytes),
dict,
dictStore,
req.Eip4844Enabled,
xBytes,
y,
Expand Down
7 changes: 4 additions & 3 deletions prover/circuits/blobdecompression/circuit.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package blobdecompression

import (
"errors"
"github.com/consensys/linea-monorepo/prover/lib/compressor/blob/dictionary"

"github.com/consensys/gnark-crypto/ecc/bls12-377/fr"
fr381 "github.com/consensys/gnark-crypto/ecc/bls12-381/fr"
Expand All @@ -19,12 +20,12 @@ func Compile(dictionaryNbBytes int) constraint.ConstraintSystem {

// Assign the circuit with concrete data. Returns the assigned circuit and the
// public input computed during the assignment.
func Assign(blobData []byte, dict []byte, eip4844Enabled bool, x [32]byte, y fr381.Element) (circuit frontend.Circuit, publicInput fr.Element, snarkHash []byte, err error) {
func Assign(blobData []byte, dictStore dictionary.Store, eip4844Enabled bool, x [32]byte, y fr381.Element) (circuit frontend.Circuit, publicInput fr.Element, snarkHash []byte, err error) {
switch blob.GetVersion(blobData) {
case 1:
return v1.Assign(blobData, dict, eip4844Enabled, x, y)
return v1.Assign(blobData, dictStore, eip4844Enabled, x, y)
case 0:
return v0.Assign(blobData, dict, eip4844Enabled, x, y)
return v0.Assign(blobData, dictStore, eip4844Enabled, x, y)
}
err = errors.New("decompression circuit assignment : unsupported blob version")
return
Expand Down
14 changes: 7 additions & 7 deletions prover/circuits/blobdecompression/v0/assign_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,11 @@ import (
)

func TestBlobV0(t *testing.T) {
resp, blobBytes, dict := mustGetTestCompressedData(t)
dict := lzss.AugmentDict(test_utils.GetDict(t))
dictStore, err := dictionary.SingletonStore(dict, 0)
assert.NoError(t, err)

resp, blobBytes := mustGetTestCompressedData(t, dictStore)
circ := v0.Allocate(dict)

logrus.Infof("Building the constraint system")
Expand All @@ -46,7 +50,7 @@ func TestBlobV0(t *testing.T) {
givenSnarkHash, err := utils.HexDecodeString(resp.SnarkHash)
assert.NoError(t, err)

a, _, snarkHash, err := blobdecompression.Assign(blobBytes, dict, true, x, y)
a, _, snarkHash, err := blobdecompression.Assign(blobBytes, dictStore, true, x, y)
assert.NoError(t, err)
_, ok := a.(*v0.Circuit)
assert.True(t, ok)
Expand All @@ -64,9 +68,7 @@ func TestBlobV0(t *testing.T) {

// mustGetTestCompressedData is a test utility function that we use to get
// actual compressed data from the
func mustGetTestCompressedData(t *testing.T) (resp blobsubmission.Response, blobBytes []byte, dict []byte) {
dict = lzss.AugmentDict(test_utils.GetDict(t))

func mustGetTestCompressedData(t *testing.T, dictStore dictionary.Store) (resp blobsubmission.Response, blobBytes []byte) {
respJson, err := os.ReadFile("sample-blob.json")
assert.NoError(t, err)

Expand All @@ -75,8 +77,6 @@ func mustGetTestCompressedData(t *testing.T) (resp blobsubmission.Response, blob
blobBytes, err = base64.StdEncoding.DecodeString(resp.CompressedData)
assert.NoError(t, err)

dictStore, err := dictionary.SingletonStore(dict, 0)
assert.NoError(t, err)
_, _, _, err = blob.DecompressBlob(blobBytes, dictStore)
assert.NoError(t, err)

Expand Down
7 changes: 1 addition & 6 deletions prover/circuits/blobdecompression/v0/prelude.go
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ func MakeCS(dict []byte) constraint.ConstraintSystem {
// Assign the circuit with concrete data. Returns the assigned circuit and the
// public input computed during the assignment.
// @alexandre.belling should we instead compute snarkHash independently here? Seems like it doesn't need to be included in the req received by Prove
func Assign(blobData, dict []byte, eip4844Enabled bool, x [32]byte, y fr381.Element) (assignment frontend.Circuit, publicInput fr.Element, snarkHash []byte, err error) {
func Assign(blobData []byte, dictStore dictionary.Store, eip4844Enabled bool, x [32]byte, y fr381.Element) (assignment frontend.Circuit, publicInput fr.Element, snarkHash []byte, err error) {
const maxCLen = blob.MaxUsableBytes
const maxDLen = blob.MaxUncompressedBytes

Expand All @@ -56,11 +56,6 @@ func Assign(blobData, dict []byte, eip4844Enabled bool, x [32]byte, y fr381.Elem
return
}

dictStore, err := dictionary.SingletonStore(dict, 0)
if err != nil {
err = fmt.Errorf("failed to create dictionary store %w", err)
return
}
header, uncompressedData, _, err := blob.DecompressBlob(blobData, dictStore)
if err != nil {
err = fmt.Errorf("decompression circuit assignment : could not decompress the data : %w", err)
Expand Down
5 changes: 2 additions & 3 deletions prover/circuits/blobdecompression/v1/assign_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ func prepare(t require.TestingT, blobBytes []byte) (c *v1.Circuit, a frontend.Ci

dictStore, err := dictionary.SingletonStore(blobtestutils.GetDict(t), 1)
assert.NoError(t, err)
_, payload, _, err := blobcompressorv1.DecompressBlob(blobBytes, dictStore)
_, payload, _, dict, err := blobcompressorv1.DecompressBlob(blobBytes, dictStore)
assert.NoError(t, err)

resp, err := blobsubmission.CraftResponse(&blobsubmission.Request{
Expand All @@ -50,8 +50,7 @@ func prepare(t require.TestingT, blobBytes []byte) (c *v1.Circuit, a frontend.Ci
y.SetBytes(b)

blobBytes = append(blobBytes, make([]byte, blobcompressorv1.MaxUsableBytes-len(blobBytes))...)
dict := blobtestutils.GetDict(t)
a, _, snarkHash, err := blobdecompression.Assign(blobBytes, dict, true, x, y)
a, _, snarkHash, err := blobdecompression.Assign(blobBytes, dictStore, true, x, y)
assert.NoError(t, err)

_, ok := a.(*v1.Circuit)
Expand Down
13 changes: 4 additions & 9 deletions prover/circuits/blobdecompression/v1/circuit.go
Original file line number Diff line number Diff line change
Expand Up @@ -247,18 +247,13 @@ func Compile(dictionaryLength int) constraint.ConstraintSystem {
}
}

func AssignFPI(blobBytes, dict []byte, eip4844Enabled bool, x [32]byte, y fr381.Element) (fpi FunctionalPublicInput, err error) {
func AssignFPI(blobBytes []byte, dictStore dictionary.Store, eip4844Enabled bool, x [32]byte, y fr381.Element) (fpi FunctionalPublicInput, dict []byte, err error) {
if len(blobBytes) != blob.MaxUsableBytes {
err = fmt.Errorf("decompression circuit assignment : invalid blob length : %d. expected %d", len(blobBytes), blob.MaxUsableBytes)
return
}

dictStore, err := dictionary.SingletonStore(dict, 1)
if err != nil {
err = fmt.Errorf("failed to create dictionary store %w", err)
return
}
header, payload, _, err := blob.DecompressBlob(blobBytes, dictStore)
header, payload, _, dict, err := blob.DecompressBlob(blobBytes, dictStore)
if err != nil {
return
}
Expand Down Expand Up @@ -294,9 +289,9 @@ func AssignFPI(blobBytes, dict []byte, eip4844Enabled bool, x [32]byte, y fr381.
return
}

func Assign(blobBytes, dict []byte, eip4844Enabled bool, x [32]byte, y fr381.Element) (assignment frontend.Circuit, publicInput fr377.Element, snarkHash []byte, err error) {
func Assign(blobBytes []byte, dictStore dictionary.Store, eip4844Enabled bool, x [32]byte, y fr381.Element) (assignment frontend.Circuit, publicInput fr377.Element, snarkHash []byte, err error) {

fpi, err := AssignFPI(blobBytes, dict, eip4844Enabled, x, y)
fpi, dict, err := AssignFPI(blobBytes, dictStore, eip4844Enabled, x, y)
if err != nil {
return
}
Expand Down
4 changes: 2 additions & 2 deletions prover/circuits/blobdecompression/v1/snark_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ func TestParseHeader(t *testing.T) {

for _, blobData := range blobs {

header, _, blocks, err := blob.DecompressBlob(blobData, dictStore)
header, _, blocks, _, err := blob.DecompressBlob(blobData, dictStore)
assert.NoError(t, err)

assert.LessOrEqual(t, len(blocks), MaxNbBatches, "too many batches")
Expand Down Expand Up @@ -347,7 +347,7 @@ func TestDictHash(t *testing.T) {
dict := blobtestutils.GetDict(t)
dictStore, err := dictionary.SingletonStore(blobtestutils.GetDict(t), 1)
assert.NoError(t, err)
header, _, _, err := blob.DecompressBlob(blobBytes, dictStore) // a bit roundabout, but the header field is not public
header, _, _, _, err := blob.DecompressBlob(blobBytes, dictStore) // a bit roundabout, but the header field is not public
assert.NoError(t, err)

circuit := testDataDictHashCircuit{
Expand Down
16 changes: 3 additions & 13 deletions prover/circuits/pi-interconnection/assign.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import (
"bytes"
"encoding/base64"
"fmt"
"github.com/consensys/linea-monorepo/prover/lib/compressor/blob/dictionary"
"hash"

"github.com/consensys/linea-monorepo/prover/crypto/mimc"
Expand All @@ -13,7 +14,6 @@ import (
decompression "github.com/consensys/linea-monorepo/prover/circuits/blobdecompression/v1"
"github.com/consensys/linea-monorepo/prover/circuits/internal"
"github.com/consensys/linea-monorepo/prover/circuits/pi-interconnection/keccak"
"github.com/consensys/linea-monorepo/prover/lib/compressor/blob"
public_input "github.com/consensys/linea-monorepo/prover/public-input"
"github.com/consensys/linea-monorepo/prover/utils"
"github.com/sirupsen/logrus"
Expand All @@ -24,12 +24,9 @@ type Request struct {
Decompressions []blobsubmission.Response
Executions []public_input.Execution
Aggregation public_input.Aggregation
// Path to the compression dictionary. Used to extract the execution data
// for each execution.
DictPath string
}

func (c *Compiled) Assign(r Request) (a Circuit, err error) {
func (c *Compiled) Assign(r Request, dictStore dictionary.Store) (a Circuit, err error) {
internal.RegisterHints()
keccak.RegisterHints()
utils.RegisterHints()
Expand All @@ -56,13 +53,6 @@ func (c *Compiled) Assign(r Request) (a Circuit, err error) {
return
}

// @alex: We should pass that as a parameter. And also (@arya) pass a list
// of dictionnary because this function.
dict, err := blob.GetDict(r.DictPath)
if err != nil {
return Circuit{}, fmt.Errorf("could not find the dictionnary: path=%v err=%v", r.DictPath, err)
}

// For Shnarfs and Merkle Roots
hshK := c.Keccak.GetHasher()

Expand Down Expand Up @@ -111,7 +101,7 @@ func (c *Compiled) Assign(r Request) (a Circuit, err error) {
fpi decompression.FunctionalPublicInput
sfpi decompression.FunctionalPublicInputSnark
)
if fpi, err = decompression.AssignFPI(blobData[:], dict, p.Eip4844Enabled, x, y); err != nil {
if fpi, _, err = decompression.AssignFPI(blobData[:], dictStore, p.Eip4844Enabled, x, y); err != nil {
return
}
execDataChecksums = append(execDataChecksums, fpi.BatchSums...) // len(execDataChecksums) = index of the first execution associated with the next blob
Expand Down
57 changes: 0 additions & 57 deletions prover/circuits/pi-interconnection/bench/main.go

This file was deleted.

13 changes: 9 additions & 4 deletions prover/circuits/pi-interconnection/e2e_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ package pi_interconnection_test
import (
"encoding/base64"
"fmt"
"github.com/consensys/linea-monorepo/prover/lib/compressor/blob/dictionary"
"slices"
"testing"

Expand Down Expand Up @@ -47,7 +48,10 @@ func TestSingleBlockBlobE2E(t *testing.T) {
compiled, err := pi_interconnection.Compile(cfg, dummy.Compile)
assert.NoError(t, err)

a, err := compiled.Assign(req)
dictStore, err := dictionary.SingletonStore(blobtesting.GetDict(t), 1)
assert.NoError(t, err)

a, err := compiled.Assign(req, dictStore)
assert.NoError(t, err)

cs, err := frontend.Compile(ecc.BLS12_377.ScalarField(), scs.NewBuilder, compiled.Circuit, frontend.WithCapacity(3_000_000))
Expand Down Expand Up @@ -112,7 +116,6 @@ func TestTinyTwoBatchBlob(t *testing.T) {
req := pi_interconnection.Request{
Decompressions: []blobsubmission.Response{*blobResp},
Executions: execReq,
DictPath: "../../lib/compressor/compressor_dict.bin",
Aggregation: public_input.Aggregation{
FinalShnarf: blobResp.ExpectedShnarf,
ParentAggregationFinalShnarf: blobReq.PrevShnarf,
Expand Down Expand Up @@ -209,7 +212,6 @@ func TestTwoTwoBatchBlobs(t *testing.T) {
req := pi_interconnection.Request{
Decompressions: []blobsubmission.Response{*blobResp0, *blobResp1},
Executions: execReq,
DictPath: "../../lib/compressor/compressor_dict.bin",
Aggregation: public_input.Aggregation{
FinalShnarf: blobResp1.ExpectedShnarf,
ParentAggregationFinalShnarf: blobReq0.PrevShnarf,
Expand Down Expand Up @@ -255,6 +257,9 @@ func testPI(t *testing.T, req pi_interconnection.Request, options ...testPIOptio
slackIterationNum := len(cfg.slack) * len(cfg.slack)
slackIterationNum *= slackIterationNum

dictStore, err := dictionary.SingletonStore(blobtesting.GetDict(t), 1)
assert.NoError(t, err)

var slack [4]int

for i := 0; i < slackIterationNum; i++ {
Expand All @@ -277,7 +282,7 @@ func testPI(t *testing.T, req pi_interconnection.Request, options ...testPIOptio
compiled, err := pi_interconnection.Compile(cfg, dummy.Compile)
assert.NoError(t, err)

a, err := compiled.Assign(req)
a, err := compiled.Assign(req, dictStore)
assert.NoError(t, err)

assert.NoError(t, test.IsSolved(compiled.Circuit, &a, ecc.BLS12_377.ScalarField()))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,6 @@ func AssignSingleBlockBlob(t require.TestingT) pi_interconnection.Request {
merkleRoots := aggregation.PackInMiniTrees(test_utils.BlocksToHex(execReq.L2MessageHashes))

return pi_interconnection.Request{
DictPath: "../../lib/compressor/compressor_dict.bin",
Decompressions: []blobsubmission.Response{*blobResp},
Executions: []public_input.Execution{execReq},
Aggregation: public_input.Aggregation{
Expand Down
Loading

0 comments on commit 543aedd

Please sign in to comment.