From 3159e37c6bac345d8da80af5b5cb0071f3a03eb4 Mon Sep 17 00:00:00 2001 From: Azam Soleimanian <49027816+Soleimani193@users.noreply.github.com> Date: Wed, 12 Feb 2025 15:16:35 +0100 Subject: [PATCH 1/4] added more functionalities to the discoverer (#681) --- .../compiler/permutation/permutation_test.go | 4 +- .../compiler/permutation/settings.go | 4 +- prover/protocol/distributed/distributed.go | 2 +- prover/protocol/distributed/lpp/lpp.go | 6 +- .../period_separating_module_discoverer.go | 9 +- .../query_based_module_discoverer.go | 136 ++++++++++++++++++ prover/utils/collection/mapping.go | 24 ++++ 7 files changed, 170 insertions(+), 15 deletions(-) create mode 100644 prover/protocol/distributed/namebaseddiscoverer/query_based_module_discoverer.go diff --git a/prover/protocol/distributed/compiler/permutation/permutation_test.go b/prover/protocol/distributed/compiler/permutation/permutation_test.go index bcac9a88d..bd4735750 100644 --- a/prover/protocol/distributed/compiler/permutation/permutation_test.go +++ b/prover/protocol/distributed/compiler/permutation/permutation_test.go @@ -8,7 +8,7 @@ import ( "github.com/consensys/linea-monorepo/prover/protocol/compiler/grandproduct" "github.com/consensys/linea-monorepo/prover/protocol/distributed" dist_permutation "github.com/consensys/linea-monorepo/prover/protocol/distributed/compiler/permutation" - "github.com/consensys/linea-monorepo/prover/protocol/distributed/namebaseddiscoverer" + discoverer "github.com/consensys/linea-monorepo/prover/protocol/distributed/namebaseddiscoverer" "github.com/consensys/linea-monorepo/prover/protocol/ifaces" "github.com/consensys/linea-monorepo/prover/protocol/wizard" "github.com/stretchr/testify/require" @@ -125,7 +125,7 @@ func TestPermutation(t *testing.T) { // test-case. initialComp := wizard.Compile(tc.DefineFunc) - disc := namebaseddiscoverer.PeriodSeperatingModuleDiscoverer{} + disc := discoverer.PeriodSeperatingModuleDiscoverer{} disc.Analyze(initialComp) // This declares a compiled IOP with only the columns of the module A diff --git a/prover/protocol/distributed/compiler/permutation/settings.go b/prover/protocol/distributed/compiler/permutation/settings.go index 59de8d1f3..c1b153cee 100644 --- a/prover/protocol/distributed/compiler/permutation/settings.go +++ b/prover/protocol/distributed/compiler/permutation/settings.go @@ -1,8 +1,8 @@ package dist_permutation -import "github.com/consensys/linea-monorepo/prover/protocol/distributed/namebaseddiscoverer" +import discoverer "github.com/consensys/linea-monorepo/prover/protocol/distributed/namebaseddiscoverer" type Settings struct { // Name of the target module - TargetModuleName namebaseddiscoverer.ModuleName + TargetModuleName discoverer.ModuleName } diff --git a/prover/protocol/distributed/distributed.go b/prover/protocol/distributed/distributed.go index 65798fea9..56cfde798 100644 --- a/prover/protocol/distributed/distributed.go +++ b/prover/protocol/distributed/distributed.go @@ -30,12 +30,12 @@ type ModuleDiscoverer interface { // Analyze is responsible for letting the module discoverer compute how to // group best the columns into modules. Analyze(comp *wizard.CompiledIOP) - NbModules() int ModuleList() []ModuleName FindModule(col ifaces.Column) ModuleName // given a query and a module name it checks if the query is inside the module ExpressionIsInModule(*symbolic.Expression, ModuleName) bool QueryIsInModule(ifaces.Query, ModuleName) bool + // it checks if the given column is in the given module ColumnIsInModule(col ifaces.Column, name ModuleName) bool } diff --git a/prover/protocol/distributed/lpp/lpp.go b/prover/protocol/distributed/lpp/lpp.go index 55ebf510c..c4feba827 100644 --- a/prover/protocol/distributed/lpp/lpp.go +++ b/prover/protocol/distributed/lpp/lpp.go @@ -18,7 +18,7 @@ func CompileLPPAndGetSeed(comp *wizard.CompiledIOP, lppCompilers ...func(*wizard ) // get the LPP columns from comp. - lppCols = append(lppCols, getLPPColumns(comp)...) + lppCols = append(lppCols, GetLPPColumns(comp)...) for _, col := range comp.Columns.AllHandlesAtRound(0) { oldColumns = append(oldColumns, col) @@ -107,7 +107,7 @@ func GetLPPComp(oldComp *wizard.CompiledIOP, newLPPCols []ifaces.Column) *wizard ) // get the LPP columns - lppCols = append(lppCols, getLPPColumns(oldComp)...) + lppCols = append(lppCols, GetLPPColumns(oldComp)...) lppCols = append(lppCols, newLPPCols...) for _, col := range lppCols { @@ -117,7 +117,7 @@ func GetLPPComp(oldComp *wizard.CompiledIOP, newLPPCols []ifaces.Column) *wizard } // it extract LPP columns from the context of each LPP query. -func getLPPColumns(c *wizard.CompiledIOP) []ifaces.Column { +func GetLPPColumns(c *wizard.CompiledIOP) []ifaces.Column { var ( lppColumns = []ifaces.Column{} diff --git a/prover/protocol/distributed/namebaseddiscoverer/period_separating_module_discoverer.go b/prover/protocol/distributed/namebaseddiscoverer/period_separating_module_discoverer.go index c2f04fe3f..54dfc1c7d 100644 --- a/prover/protocol/distributed/namebaseddiscoverer/period_separating_module_discoverer.go +++ b/prover/protocol/distributed/namebaseddiscoverer/period_separating_module_discoverer.go @@ -1,4 +1,4 @@ -package namebaseddiscoverer +package discoverer import ( "strings" @@ -55,11 +55,6 @@ func periodSeparator(name string) string { return name[:index] } -// NbModules returns the number of modules -func (p *PeriodSeperatingModuleDiscoverer) NbModules() int { - return len(p.modules) -} - // ModuleList returns the list of module names func (p *PeriodSeperatingModuleDiscoverer) ModuleList() []ModuleName { moduleNames := make([]ModuleName, 0, len(p.modules)) @@ -131,7 +126,7 @@ func (p *PeriodSeperatingModuleDiscoverer) ExpressionIsInModule(expr *symbolic.E } if nCols == 0 { - panic("could not find any column in the expression") + panic("unsupported, could not find any column in the expression") } else { return b } diff --git a/prover/protocol/distributed/namebaseddiscoverer/query_based_module_discoverer.go b/prover/protocol/distributed/namebaseddiscoverer/query_based_module_discoverer.go new file mode 100644 index 000000000..2b6aa445d --- /dev/null +++ b/prover/protocol/distributed/namebaseddiscoverer/query_based_module_discoverer.go @@ -0,0 +1,136 @@ +package discoverer + +import ( + "github.com/consensys/linea-monorepo/prover/protocol/column" + "github.com/consensys/linea-monorepo/prover/protocol/distributed" + "github.com/consensys/linea-monorepo/prover/protocol/distributed/lpp" + "github.com/consensys/linea-monorepo/prover/protocol/ifaces" + "github.com/consensys/linea-monorepo/prover/protocol/query" + "github.com/consensys/linea-monorepo/prover/protocol/variables" + "github.com/consensys/linea-monorepo/prover/protocol/wizard" + "github.com/consensys/linea-monorepo/prover/symbolic" + "github.com/consensys/linea-monorepo/prover/utils/collection" +) + +// struct implementing more complex analysis for [distributed.ModuleDiscoverer] +type QueryBasedDiscoverer struct { + // simple module discoverer; it does simple analysis + // it can neither capture the specific columns (like verifier columns or periodicSampling variables), + // nor categorizes the columns GL and LPP. + SimpleDiscoverer distributed.ModuleDiscoverer + // all the columns involved in the LPP queries, including the verifier columns + LPPColumns collection.Mapping[ModuleName, []ifaces.Column] + // all the columns involved in the GL queries, including the verifier columns + GLColumns collection.Mapping[ModuleName, []ifaces.Column] + // all the periodicSamples involved in the GL queries + PeriodicSamplingGL collection.Mapping[ModuleName, []variables.PeriodicSample] +} + +// Analyze first analyzes the simpleDiscoverer and then adds extra analyses based on the queries +// , to get the the verifier, GL , LPP columns and also PeriodicSamplings +func (d QueryBasedDiscoverer) Analyze(comp *wizard.CompiledIOP) { + + d.SimpleDiscoverer.Analyze(comp) + + // sanity check + if len(comp.QueriesParams.AllKeysAt(0)) != 0 { + panic("At this step, we do not expect a query with parameters") + } + + // capture and analysis GL Queries + for _, moduleName := range d.SimpleDiscoverer.ModuleList() { + for _, q := range comp.QueriesNoParams.AllKeys() { + + if global, ok := comp.QueriesNoParams.Data(q).(query.GlobalConstraint); ok { + + if d.SimpleDiscoverer.ExpressionIsInModule(global.Expression, moduleName) { + + // it analyzes the expression and update d.GLColumns and d.PeriodicSamplingGL + d.analyzeExprGL(global.Expression, moduleName) + + } + } + + // the same for local queries. + if local, ok := comp.QueriesNoParams.Data(q).(query.LocalConstraint); ok { + + if d.SimpleDiscoverer.ExpressionIsInModule(local.Expression, moduleName) { + + // it analyzes de expression and update d.GLColumns and d.PeriodicSamplingGL + d.analyzeExprGL(local.Expression, moduleName) + + } + } + + // get the LPP columns from all the LPP queries, and check if they are in the module + // this does not contain the new columns from preparation phase like the multiplicity columns. + lppCols := lpp.GetLPPColumns(comp) + for _, col := range lppCols { + if d.SimpleDiscoverer.ColumnIsInModule(col, moduleName) { + // update d content + d.LPPColumns.AppendNew(moduleName, []ifaces.Column{col}) + } + } + + } + } + +} + +// ModuleList returns the list of module names +func (d *QueryBasedDiscoverer) ModuleList() []ModuleName { + + return d.SimpleDiscoverer.ModuleList() +} + +// FindModule finds the module name for a given column +func (d *QueryBasedDiscoverer) FindModule(col ifaces.Column) ModuleName { + return d.SimpleDiscoverer.FindModule(col) +} + +// QueryIsInModule checks if the given query is inside the given module +func (d *QueryBasedDiscoverer) QueryIsInModule(q ifaces.Query, moduleName ModuleName) bool { + return d.SimpleDiscoverer.QueryIsInModule(q, moduleName) +} + +// ColumnIsInModule checks that the given column is inside the given module. +func (d *QueryBasedDiscoverer) ColumnIsInModule(col ifaces.Column, name ModuleName) bool { + return d.SimpleDiscoverer.ColumnIsInModule(col, name) +} + +// ExpressionIsInModule checks that all the columns (except verifiercol) in the expression are from the given module. +// +// It does not check the presence of the coins and other metadata in the module. +// the restriction over verifier column comes from the fact that the discoverer Analyses compiledIOP and the verifier columns are not accessible there. +func (p *QueryBasedDiscoverer) ExpressionIsInModule(expr *symbolic.Expression, name ModuleName) bool { + return p.SimpleDiscoverer.ExpressionIsInModule(expr, name) +} + +// analyzeExpr analyzes the expression and update the content of d. +func (d *QueryBasedDiscoverer) analyzeExprGL(expr *symbolic.Expression, moduleName ModuleName) { + + var ( + board = expr.Board() + metadata = board.ListVariableMetadata() + ) + + for _, m := range metadata { + + switch t := m.(type) { + case ifaces.Column: + + if shifted, ok := t.(column.Shifted); ok { + + d.GLColumns.AppendNew(moduleName, []ifaces.Column{shifted.Parent}) + + } else { + d.GLColumns.AppendNew(moduleName, []ifaces.Column{t}) + } + + case variables.PeriodicSample: + + d.PeriodicSamplingGL.AppendNew(moduleName, []variables.PeriodicSample{t}) + + } + } +} diff --git a/prover/utils/collection/mapping.go b/prover/utils/collection/mapping.go index 54d89ae63..462c36c84 100644 --- a/prover/utils/collection/mapping.go +++ b/prover/utils/collection/mapping.go @@ -2,6 +2,7 @@ package collection import ( "fmt" + "reflect" "github.com/consensys/linea-monorepo/prover/utils" ) @@ -131,3 +132,26 @@ func (kv *Mapping[K, V]) TryDel(k K) bool { } return found } + +// AppendNew appends a new value v to V if V is a slice +func (kv *Mapping[K, V]) AppendNew(k K, v V) { + // Get the current value from the map + currentValue, exists := kv.innerMap[k] + if !exists { + // If the key does not exist, initialize it with an empty slice of type V + kv.innerMap[k] = v + return + } + + // Use reflection to check if the current value is a slice + currentValueValue := reflect.ValueOf(currentValue) + if currentValueValue.Kind() == reflect.Slice { + // Append the new value to the slice + newSlice := reflect.Append(currentValueValue, reflect.ValueOf(v)) + // Set the updated slice back to the map + kv.innerMap[k] = newSlice.Interface().(V) + } else { + // If V is not a slice, handle the error + fmt.Println("Error: V is not a slice") + } +} From a95353377a69e55cd4da5eb5a4807253799d14ef Mon Sep 17 00:00:00 2001 From: Arijit Dutta <37040536+arijitdutta67@users.noreply.github.com> Date: Mon, 17 Feb 2025 15:13:44 +0530 Subject: [PATCH 2/4] Prover/distribute projection query (#608) * added projection query * compiler for projection added * compiler added to arcane * duplicate name fix * removing cptHolter to math poly * shifted the no lint command * added bin file in gitignore * removed gitignore change * remove bin file * code simplification as Alex suggested * fix in lpp * distributed projection query added * more test cases added * adding slice to accomodate multiple query per module * param changed to support additive structure * added horizontal split code * test added * check test added * incorporate Alex suggestion on PR 585 * Added documentation --------- Signed-off-by: Arijit Dutta <37040536+arijitdutta67@users.noreply.github.com> --- prover/maths/common/poly/poly.go | 5 +- .../distributedprojection/compiler.go | 57 ++++ .../compiler/distributedprojection/prover.go | 289 ++++++++++++++++++ .../distributedprojection/verifier.go | 91 ++++++ prover/protocol/compiler/projection/prover.go | 4 +- .../compiler/permutation/permutation.go | 2 +- .../compiler/projection/projection.go | 287 ++++++++++++++++- .../compiler/projection/projection_test.go | 209 +++++++++++++ .../distributed/conglomeration/translator.go | 10 + prover/protocol/distributed/lpp/lpp_test.go | 2 +- .../protocol/query/distributed_projection.go | 123 ++++++++ .../query/distributed_projection_test.go | 95 ++++++ prover/protocol/query/gnark_params.go | 5 + prover/protocol/query/projection.go | 4 +- prover/protocol/query/projection_test.go | 11 +- prover/protocol/wizard/compiled.go | 8 + prover/protocol/wizard/gnark_verifier.go | 14 + prover/protocol/wizard/prover.go | 21 ++ prover/protocol/wizard/verifier.go | 6 + 19 files changed, 1221 insertions(+), 22 deletions(-) create mode 100644 prover/protocol/compiler/distributedprojection/compiler.go create mode 100644 prover/protocol/compiler/distributedprojection/prover.go create mode 100644 prover/protocol/compiler/distributedprojection/verifier.go create mode 100644 prover/protocol/distributed/compiler/projection/projection_test.go create mode 100644 prover/protocol/query/distributed_projection.go create mode 100644 prover/protocol/query/distributed_projection_test.go diff --git a/prover/maths/common/poly/poly.go b/prover/maths/common/poly/poly.go index dfa685144..ca376b636 100644 --- a/prover/maths/common/poly/poly.go +++ b/prover/maths/common/poly/poly.go @@ -123,11 +123,10 @@ func EvaluateLagrangesAnyDomain(domain []field.Element, x field.Element) []field return lagrange } -// CmptHorner computes a random Horner accumulation of the filtered elements +// GetHornerTrace computes a random Horner accumulation of the filtered elements // starting from the last entry down to the first entry. The final value is // stored in the last entry of the returned slice. -// Todo: send it to a common utility package -func CmptHorner(c, fC []field.Element, x field.Element) []field.Element { +func GetHornerTrace(c, fC []field.Element, x field.Element) []field.Element { var ( horner = make([]field.Element, len(c)) diff --git a/prover/protocol/compiler/distributedprojection/compiler.go b/prover/protocol/compiler/distributedprojection/compiler.go new file mode 100644 index 000000000..ab83b7e21 --- /dev/null +++ b/prover/protocol/compiler/distributedprojection/compiler.go @@ -0,0 +1,57 @@ +package distributedprojection + +import ( + "github.com/consensys/linea-monorepo/prover/protocol/coin" + "github.com/consensys/linea-monorepo/prover/protocol/ifaces" + "github.com/consensys/linea-monorepo/prover/protocol/query" + "github.com/consensys/linea-monorepo/prover/protocol/wizard" + "github.com/consensys/linea-monorepo/prover/symbolic" +) + +func CompileDistributedProjection(comp *wizard.CompiledIOP) { + + for _, qName := range comp.QueriesParams.AllUnignoredKeys() { + // Filter out non distributed projection queries + distributedprojection, ok := comp.QueriesParams.Data(qName).(query.DistributedProjection) + if !ok { + continue + } + + // This ensures that the distributed projection query is not used again in the + // compilation process. We know that the query was not already ignored at the beginning + // because we are iterating over the unignored keys. + comp.QueriesParams.MarkAsIgnored(qName) + round := comp.QueriesParams.Round(qName) + compile(comp, round, distributedprojection) + } +} + +func compile(comp *wizard.CompiledIOP, round int, distributedprojection query.DistributedProjection) { + var ( + pa = &distribuedProjectionProverAction{ + Name: distributedprojection.ID, + FilterA: make([]*symbolic.Expression, len(distributedprojection.Inp)), + FilterB: make([]*symbolic.Expression, len(distributedprojection.Inp)), + ColumnA: make([]*symbolic.Expression, len(distributedprojection.Inp)), + ColumnB: make([]*symbolic.Expression, len(distributedprojection.Inp)), + HornerA: make([]ifaces.Column, len(distributedprojection.Inp)), + HornerB: make([]ifaces.Column, len(distributedprojection.Inp)), + HornerA0: make([]query.LocalOpening, len(distributedprojection.Inp)), + HornerB0: make([]query.LocalOpening, len(distributedprojection.Inp)), + EvalCoin: make([]coin.Info, len(distributedprojection.Inp)), + IsA: make([]bool, len(distributedprojection.Inp)), + IsB: make([]bool, len(distributedprojection.Inp)), + } + ) + pa.Push(comp, distributedprojection) + pa.RegisterQueries(comp, round, distributedprojection) + comp.RegisterProverAction(round, pa) + comp.RegisterVerifierAction(round, &distributedProjectionVerifierAction{ + Name: pa.Name, + HornerA0: pa.HornerA0, + HornerB0: pa.HornerB0, + isA: pa.IsA, + isB: pa.IsB, + }) + +} diff --git a/prover/protocol/compiler/distributedprojection/prover.go b/prover/protocol/compiler/distributedprojection/prover.go new file mode 100644 index 000000000..ee7e49138 --- /dev/null +++ b/prover/protocol/compiler/distributedprojection/prover.go @@ -0,0 +1,289 @@ +package distributedprojection + +import ( + "github.com/consensys/linea-monorepo/prover/maths/common/poly" + "github.com/consensys/linea-monorepo/prover/maths/common/smartvectors" + "github.com/consensys/linea-monorepo/prover/protocol/coin" + "github.com/consensys/linea-monorepo/prover/protocol/column" + "github.com/consensys/linea-monorepo/prover/protocol/ifaces" + "github.com/consensys/linea-monorepo/prover/protocol/query" + "github.com/consensys/linea-monorepo/prover/protocol/wizard" + sym "github.com/consensys/linea-monorepo/prover/symbolic" + "github.com/consensys/linea-monorepo/prover/utils" + "github.com/consensys/linea-monorepo/prover/utils/collection" + "github.com/sirupsen/logrus" +) + +type distribuedProjectionProverAction struct { + Name ifaces.QueryID + FilterA, FilterB []*sym.Expression + ColumnA, ColumnB []*sym.Expression + HornerA, HornerB []ifaces.Column + HornerA0, HornerB0 []query.LocalOpening + EvalCoin []coin.Info + IsA, IsB []bool +} + +// Run executes the distributed projection prover action. +// It iterates over the input filters, columns, and evaluation coins. +// Depending on the values of IsA and IsB, it computes the Horner traces for columns A and B, +// and assigns them to the corresponding columns and local points in the prover runtime. +// If both IsA and IsB are true, it computes the Horner traces for both columns A and B. +// If IsA is true and IsB is false, it computes the Horner trace for column A only. +// If IsA is false and IsB is true, it computes the Horner trace for column B only. +// If neither IsA nor IsB is true, it panics with an error message indicating an invalid prover assignment. +func (pa *distribuedProjectionProverAction) Run(run *wizard.ProverRuntime) { + for index := range pa.FilterA { + if pa.IsA[index] && pa.IsB[index] { + var ( + colA = column.EvalExprColumn(run, pa.ColumnA[index].Board()).IntoRegVecSaveAlloc() + fA = column.EvalExprColumn(run, pa.FilterA[index].Board()).IntoRegVecSaveAlloc() + colB = column.EvalExprColumn(run, pa.ColumnB[index].Board()).IntoRegVecSaveAlloc() + fB = column.EvalExprColumn(run, pa.FilterB[index].Board()).IntoRegVecSaveAlloc() + x = run.GetRandomCoinField(pa.EvalCoin[index].Name) + hornerA = poly.GetHornerTrace(colA, fA, x) + hornerB = poly.GetHornerTrace(colB, fB, x) + ) + run.AssignColumn(pa.HornerA[index].GetColID(), smartvectors.NewRegular(hornerA)) + run.AssignLocalPoint(pa.HornerA0[index].ID, hornerA[0]) + run.AssignColumn(pa.HornerB[index].GetColID(), smartvectors.NewRegular(hornerB)) + run.AssignLocalPoint(pa.HornerB0[index].ID, hornerB[0]) + } else if pa.IsA[index] && !pa.IsB[index] { + var ( + colA = column.EvalExprColumn(run, pa.ColumnA[index].Board()).IntoRegVecSaveAlloc() + fA = column.EvalExprColumn(run, pa.FilterA[index].Board()).IntoRegVecSaveAlloc() + x = run.GetRandomCoinField(pa.EvalCoin[index].Name) + hornerA = poly.GetHornerTrace(colA, fA, x) + ) + run.AssignColumn(pa.HornerA[index].GetColID(), smartvectors.NewRegular(hornerA)) + run.AssignLocalPoint(pa.HornerA0[index].ID, hornerA[0]) + } else if !pa.IsA[index] && pa.IsB[index] { + var ( + colB = column.EvalExprColumn(run, pa.ColumnB[index].Board()).IntoRegVecSaveAlloc() + fB = column.EvalExprColumn(run, pa.FilterB[index].Board()).IntoRegVecSaveAlloc() + x = run.GetRandomCoinField(pa.EvalCoin[index].Name) + hornerB = poly.GetHornerTrace(colB, fB, x) + ) + run.AssignColumn(pa.HornerB[index].GetColID(), smartvectors.NewRegular(hornerB)) + run.AssignLocalPoint(pa.HornerB0[index].ID, hornerB[0]) + } else { + utils.Panic("Invalid prover assignment in distributed projection id: %v", pa.Name) + } + } +} + +// Push populates the distribuedProjectionProverAction with data from the provided DistributedProjection query. +// It processes each input in the query and assigns the corresponding values to the prover action's fields +// based on whether the input is in module A, module B, or both. +// +// Parameters: +// - comp: A pointer to the CompiledIOP, used to access coin data. +// - distributedprojection: The DistributedProjection query containing the inputs to be processed. +// +// The function does not return any value, but updates the fields of the distribuedProjectionProverAction in-place. +func (pa *distribuedProjectionProverAction) Push(comp *wizard.CompiledIOP, distributedprojection query.DistributedProjection) { + for index, input := range distributedprojection.Inp { + if input.IsAInModule && input.IsBInModule { + pa.FilterA[index] = input.FilterA + pa.FilterB[index] = input.FilterB + pa.ColumnA[index] = input.ColumnA + pa.ColumnB[index] = input.ColumnB + pa.EvalCoin[index] = comp.Coins.Data(input.EvalCoin) + pa.IsA[index] = true + pa.IsB[index] = true + + } else if input.IsAInModule && !input.IsBInModule { + pa.FilterA[index] = input.FilterA + pa.ColumnA[index] = input.ColumnA + pa.EvalCoin[index] = comp.Coins.Data(input.EvalCoin) + pa.IsA[index] = true + pa.IsB[index] = false + + } else if !input.IsAInModule && input.IsBInModule { + pa.FilterB[index] = input.FilterB + pa.ColumnB[index] = input.ColumnB + pa.EvalCoin[index] = comp.Coins.Data(input.EvalCoin) + pa.IsA[index] = false + pa.IsB[index] = true + + } else { + logrus.Errorf("Invalid distributed projection query while pushing prover action entries: %v", distributedprojection.ID) + } + } +} + +// RegisterQueries registers the necessary queries for the distributed projection prover action. +// It processes each input in the distributed projection, shifting expressions and registering +// queries for columns A and B based on their presence in the respective modules. +// +// Parameters: +// - comp: A pointer to the CompiledIOP, used for inserting commits and queries. +// - round: An integer representing the current round of the protocol. +// - distributedprojection: A DistributedProjection query containing the inputs to be processed. +// +// The function does not return any value, but updates the internal state of the prover action +// by registering the required queries for each input. +func (pa *distribuedProjectionProverAction) RegisterQueries(comp *wizard.CompiledIOP, round int, distributedprojection query.DistributedProjection) { + for index, input := range distributedprojection.Inp { + if input.IsAInModule && input.IsBInModule { + var ( + fA = pa.FilterA[index] + fAShifted = shiftExpression(comp, fA, -1) + colA = pa.ColumnA[index] + colAShifted = shiftExpression(comp, colA, -1) + fB = pa.FilterB[index] + fBShifted = shiftExpression(comp, fB, -1) + colB = pa.ColumnB[index] + colBShifted = shiftExpression(comp, colB, -1) + ) + pa.registerForCol(comp, fAShifted, colAShifted, input, "A", round, index) + pa.registerForCol(comp, fBShifted, colBShifted, input, "B", round, index) + } else if input.IsAInModule && !input.IsBInModule { + var ( + fA = pa.FilterA[index] + fAShifted = shiftExpression(comp, fA, -1) + colA = pa.ColumnA[index] + colAShifted = shiftExpression(comp, colA, -1) + ) + pa.registerForCol(comp, fAShifted, colAShifted, input, "A", round, index) + } else if !input.IsAInModule && input.IsBInModule { + var ( + fB = pa.FilterB[index] + fBShifted = shiftExpression(comp, fB, -1) + colB = pa.ColumnB[index] + colBShifted = shiftExpression(comp, colB, -1) + ) + pa.registerForCol(comp, fBShifted, colBShifted, input, "B", round, index) + } else { + utils.Panic("Invalid prover action case for the distributed projection query %v", pa.Name) + } + } +} + +// registerForCol registers queries for a specific column (A or B) in the distributed projection prover action. +// It inserts commits, global queries, local queries, and local openings for the Horner polynomial evaluation. +// +// Parameters: +// - comp: A pointer to the CompiledIOP, used for inserting commits and queries. +// - fShifted: A shifted filter expression. +// - colShifted: A shifted column expression. +// - input: A pointer to the DistributedProjectionInput containing size information. +// - colName: A string indicating which column to register ("A" or "B"). +// - round: An integer representing the current round of the protocol. +// - index: An integer used to uniquely identify the registered queries. +// +// The function doesn't return any value but updates the internal state of the prover action +// by registering the required queries and commits for the specified column. +func (pa *distribuedProjectionProverAction) registerForCol( + comp *wizard.CompiledIOP, + fShifted, colShifted *sym.Expression, + input *query.DistributedProjectionInput, + colName string, + round int, + index int, +) { + switch colName { + case "A": + { + pa.HornerA[index] = comp.InsertCommit(round, ifaces.ColIDf("%v_HORNER_A_%v", pa.Name, index), input.SizeA) + comp.InsertGlobal( + round, + ifaces.QueryIDf("%v_HORNER_A_%v_GLOBAL", pa.Name, index), + sym.Sub( + pa.HornerA[index], + sym.Mul( + sym.Sub(1, pa.FilterA[index]), + column.Shift(pa.HornerA[index], 1), + ), + sym.Mul( + pa.FilterA[index], + sym.Add( + pa.ColumnA[index], + sym.Mul( + pa.EvalCoin[index], + column.Shift(pa.HornerA[index], 1), + ), + ), + ), + ), + ) + comp.InsertLocal( + round, + ifaces.QueryIDf("%v_HORNER_A_LOCAL_END_%v", pa.Name, index), + sym.Sub( + column.Shift(pa.HornerA[index], -1), + sym.Mul(fShifted, colShifted), + ), + ) + pa.HornerA0[index] = comp.InsertLocalOpening(round, ifaces.QueryIDf("%v_HORNER_A0_%v", pa.Name, index), pa.HornerA[index]) + } + case "B": + { + pa.HornerB[index] = comp.InsertCommit(round, ifaces.ColIDf("%v_HORNER_B_%v", pa.Name, index), input.SizeB) + + comp.InsertGlobal( + round, + ifaces.QueryIDf("%v_HORNER_B_%v_GLOBAL", pa.Name, index), + sym.Sub( + pa.HornerB[index], + sym.Mul( + sym.Sub(1, pa.FilterB[index]), + column.Shift(pa.HornerB[index], 1), + ), + sym.Mul( + pa.FilterB[index], + sym.Add(pa.ColumnB[index], sym.Mul(pa.EvalCoin[index], column.Shift(pa.HornerB[index], 1))), + ), + ), + ) + + comp.InsertLocal( + round, + ifaces.QueryIDf("%v_HORNER_B_LOCAL_END_%v", pa.Name, index), + sym.Sub( + column.Shift(pa.HornerB[index], -1), + sym.Mul(fShifted, colShifted), + ), + ) + + pa.HornerB0[index] = comp.InsertLocalOpening(round, ifaces.QueryIDf("%v_HORNER_B0_%v", pa.Name, index), pa.HornerB[index]) + } + default: + utils.Panic("Invalid column name %v, should be either A or B", colName) + } + +} + +// shiftExpression shifts a column which is a symbolic expression by a specified number of positions. +// It creates a new expression with the shifted column while maintaining the structure of the original expression. +// +// Parameters: +// - comp: A pointer to the CompiledIOP, used to check for the existence of coins. +// - expr: The original symbolic expression to be shifted. +// - nbShift: The number of positions to shift the column. Positive values shift forward, negative values shift backward. +// +// Returns: +// +// A new *sym.Expression with the column shifted according to the specified nbShift. +func shiftExpression(comp *wizard.CompiledIOP, expr *sym.Expression, nbShift int) *sym.Expression { + var ( + board = expr.Board() + metadata = board.ListVariableMetadata() + translationMap = collection.NewMapping[string, *sym.Expression]() + ) + + for _, m := range metadata { + switch t := m.(type) { + case ifaces.Column: + translationMap.InsertNew(string(t.GetColID()), ifaces.ColumnAsVariable(column.Shift(t, nbShift))) + case coin.Info: + if !comp.Coins.Exists(t.Name) { + utils.Panic("Coin %v does not exist in the InitialComp", t.Name) + } + translationMap.InsertNew(t.String(), sym.NewVariable(t)) + default: + utils.Panic("Unsupported type for shift expression operation") + } + } + return expr.Replay(translationMap) +} diff --git a/prover/protocol/compiler/distributedprojection/verifier.go b/prover/protocol/compiler/distributedprojection/verifier.go new file mode 100644 index 000000000..608339449 --- /dev/null +++ b/prover/protocol/compiler/distributedprojection/verifier.go @@ -0,0 +1,91 @@ +package distributedprojection + +import ( + "github.com/consensys/gnark/frontend" + "github.com/consensys/linea-monorepo/prover/maths/field" + "github.com/consensys/linea-monorepo/prover/protocol/ifaces" + "github.com/consensys/linea-monorepo/prover/protocol/query" + "github.com/consensys/linea-monorepo/prover/protocol/wizard" + "github.com/consensys/linea-monorepo/prover/utils" +) + +type distributedProjectionVerifierAction struct { + Name ifaces.QueryID + HornerA0, HornerB0 []query.LocalOpening + isA, isB []bool + skipped bool +} + +// Run implements the [wizard.VerifierAction] +func (va *distributedProjectionVerifierAction) Run(run wizard.Runtime) error { + var ( + actualParam = field.Zero() + ) + for index := range va.HornerA0 { + var ( + elemParam = field.Zero() + ) + if va.isA[index] && va.isB[index] { + elemParam = run.GetLocalPointEvalParams(va.HornerB0[index].ID).Y + elemParam.Neg(&elemParam) + temp := run.GetLocalPointEvalParams(va.HornerA0[index].ID).Y + elemParam.Add(&elemParam, &temp) + } else if va.isA[index] && !va.isB[index] { + elemParam = run.GetLocalPointEvalParams(va.HornerA0[index].ID).Y + } else if !va.isA[index] && va.isB[index] { + elemParam = run.GetLocalPointEvalParams(va.HornerB0[index].ID).Y + elemParam.Neg(&elemParam) + } else { + utils.Panic("Unsupported verifier action registered for %v", va.Name) + } + actualParam.Add(&actualParam, &elemParam) + } + queryParam := run.GetDistributedProjectionParams(va.Name).HornerVal + if actualParam != queryParam { + utils.Panic("The distributed projection query %v did not pass, query param %v and actual param %v", va.Name, queryParam, actualParam) + } + return nil +} + +// RunGnark implements the [wizard.VerifierAction] interface. +func (va *distributedProjectionVerifierAction) RunGnark(api frontend.API, run wizard.GnarkRuntime) { + + var ( + actualParam = frontend.Variable(0) + ) + for index := range va.HornerA0 { + var ( + elemParam = frontend.Variable(0) + ) + if va.isA[index] && va.isB[index] { + var ( + a, b frontend.Variable + ) + a = run.GetLocalPointEvalParams(va.HornerA0[index].ID).Y + b = run.GetLocalPointEvalParams(va.HornerB0[index].ID).Y + elemParam = api.Sub(a, b) + } else if va.isA[index] && !va.isB[index] { + a := run.GetLocalPointEvalParams(va.HornerA0[index].ID).Y + elemParam = api.Add(elemParam, a) + } else if !va.isA[index] && va.isB[index] { + b := run.GetLocalPointEvalParams(va.HornerB0[index].ID).Y + elemParam = api.Sub(elemParam, b) + } else { + utils.Panic("Unsupported verifier action registered for %v", va.Name) + } + actualParam = api.Add(actualParam, elemParam) + } + queryParam := run.GetDistributedProjectionParams(va.Name).Sum + + api.AssertIsEqual(actualParam, queryParam) +} + +// Skip implements the [wizard.VerifierAction] +func (va *distributedProjectionVerifierAction) Skip() { + va.skipped = true +} + +// IsSkipped implements the [wizard.VerifierAction] +func (va *distributedProjectionVerifierAction) IsSkipped() bool { + return va.skipped +} diff --git a/prover/protocol/compiler/projection/prover.go b/prover/protocol/compiler/projection/prover.go index 7726a0d3e..fb96d5c77 100644 --- a/prover/protocol/compiler/projection/prover.go +++ b/prover/protocol/compiler/projection/prover.go @@ -39,8 +39,8 @@ func (pa projectionProverAction) Run(run *wizard.ProverRuntime) { fA = pa.FilterA.GetColAssignment(run).IntoRegVecSaveAlloc() fB = pa.FilterB.GetColAssignment(run).IntoRegVecSaveAlloc() x = run.GetRandomCoinField(pa.EvalCoin.Name) - hornerA = poly.CmptHorner(a, fA, x) - hornerB = poly.CmptHorner(b, fB, x) + hornerA = poly.GetHornerTrace(a, fA, x) + hornerB = poly.GetHornerTrace(b, fB, x) ) run.AssignColumn(pa.HornerA.GetColID(), smartvectors.NewRegular(hornerA)) diff --git a/prover/protocol/distributed/compiler/permutation/permutation.go b/prover/protocol/distributed/compiler/permutation/permutation.go index 794aed774..d9806640c 100644 --- a/prover/protocol/distributed/compiler/permutation/permutation.go +++ b/prover/protocol/distributed/compiler/permutation/permutation.go @@ -71,7 +71,7 @@ func NewPermutationIntoGrandProductCtx( } /* - Handles the lookups and permutations checks + Handles the permutations checks */ for round := 0; round < numRounds; round++ { queries := initialComp.QueriesNoParams.AllKeysAt(round) diff --git a/prover/protocol/distributed/compiler/projection/projection.go b/prover/protocol/distributed/compiler/projection/projection.go index 20f705d66..ddaf3d4b7 100644 --- a/prover/protocol/distributed/compiler/projection/projection.go +++ b/prover/protocol/distributed/compiler/projection/projection.go @@ -1,17 +1,286 @@ -package projection +package dist_projection import ( + "github.com/consensys/linea-monorepo/prover/maths/common/poly" + "github.com/consensys/linea-monorepo/prover/maths/field" "github.com/consensys/linea-monorepo/prover/protocol/coin" + "github.com/consensys/linea-monorepo/prover/protocol/column" + "github.com/consensys/linea-monorepo/prover/protocol/distributed" + "github.com/consensys/linea-monorepo/prover/protocol/distributed/namebaseddiscoverer" + "github.com/consensys/linea-monorepo/prover/protocol/ifaces" + "github.com/consensys/linea-monorepo/prover/protocol/query" "github.com/consensys/linea-monorepo/prover/protocol/wizard" + "github.com/consensys/linea-monorepo/prover/protocol/wizardutils" + "github.com/consensys/linea-monorepo/prover/symbolic" + "github.com/consensys/linea-monorepo/prover/utils" ) -// CompileDist compiles the projection queries distributedly. -// It receives a compiledIOP object relevant to a segment. -// The seed is a random coin from randomness beacon (FS of all LPP commitments). -// All the compilation steps are similar to the permutation compilation apart from: -// - random coins \alpha and \gamma are generated from the seed (and the tableName). -// - no verifierAction is needed over the ZOpening. -// - ZOpenings are declared as public input. -func CompileDist(comp *wizard.CompiledIOP, seed coin.Info) { +// Used for deriving names of queries and coins +const ( + distProjectionStr = "DISTRIBUTED_PROJECTION" + MaxNumOfQueriesPerModule = 10 +) + +type DistributeProjectionCtx struct { + DistProjectionInput []*query.DistributedProjectionInput + EvalCoins []coin.Info + // The module name for which we are processing the distributed projection query + TargetModuleName string + // Query stores the [query.DistributedProjection] generated by the compilation + Query query.DistributedProjection + // LastRoundPerm indicates the highest round at which a compiled projection + // occurs. + LastRoundProjection int +} + +// NewDistributeProjectionCtx processes all the projection queries from the initialComp +// and registers DistributedProjection queries to the target module using the module +// discoverer +func NewDistributeProjectionCtx( + targetModuleName namebaseddiscoverer.ModuleName, + initialComp, moduleComp *wizard.CompiledIOP, + disc distributed.ModuleDiscoverer, +) *DistributeProjectionCtx { + var ( + p = &DistributeProjectionCtx{ + DistProjectionInput: make([]*query.DistributedProjectionInput, 0, MaxNumOfQueriesPerModule), + EvalCoins: make([]coin.Info, 0, MaxNumOfQueriesPerModule), + TargetModuleName: targetModuleName, + LastRoundProjection: getLastRoundPerm(initialComp), + } + numRounds = initialComp.NumRounds() + qId = p.QueryID() + ) + if p.LastRoundProjection < 0 { + return p + } + + /* + Handles the projection checks + */ + for round := 0; round < numRounds; round++ { + queries := initialComp.QueriesNoParams.AllKeysAt(round) + for queryInRound, qName := range queries { + + // Skip if it was already compiled + if initialComp.QueriesNoParams.IsIgnored(qName) { + continue + } + + q_, ok := initialComp.QueriesNoParams.Data(qName).(query.Projection) + if !ok { + continue + } + var ( + onlyA = (disc.FindModule(q_.Inp.ColumnA[0]) == targetModuleName) && (disc.FindModule(q_.Inp.ColumnB[0]) != targetModuleName) + onlyB = (disc.FindModule(q_.Inp.ColumnA[0]) != targetModuleName) && (disc.FindModule(q_.Inp.ColumnB[0]) == targetModuleName) + bothAAndB = (disc.FindModule(q_.Inp.ColumnA[0]) == targetModuleName) && (disc.FindModule(q_.Inp.ColumnB[0]) == targetModuleName) + ) + if bothAAndB { + check(q_.Inp.ColumnA, disc, targetModuleName) + check(q_.Inp.ColumnB, disc, targetModuleName) + p.push(moduleComp, q_, round, queryInRound, true, true) + initialComp.QueriesNoParams.MarkAsIgnored(qName) + // Todo: Add panic if other cols are from other modules + } else if onlyA { + check(q_.Inp.ColumnA, disc, targetModuleName) + p.push(moduleComp, q_, round, queryInRound, true, false) + initialComp.QueriesNoParams.MarkAsIgnored(qName) + } else if onlyB { + check(q_.Inp.ColumnB, disc, targetModuleName) + p.push(moduleComp, q_, round, queryInRound, false, true) + initialComp.QueriesNoParams.MarkAsIgnored(qName) + } else { + continue + } + } + } + // We register the grand product query in round one because + // alphas, betas, and the query param are assigned in round one + p.Query = moduleComp.InsertDistributedProjection(p.LastRoundProjection+1, qId, p.DistProjectionInput) + + moduleComp.RegisterProverAction(p.LastRoundProjection+1, p) + return p + +} + +// Check verifies if all columns of the projection query belongs to the same module or not +func check(cols []ifaces.Column, + disc distributed.ModuleDiscoverer, + targetModuleName namebaseddiscoverer.ModuleName, +) error { + for _, col := range cols { + if disc.FindModule(col) != targetModuleName { + utils.Panic("unsupported projection query, colName: %v, target: %v", col.GetColID(), targetModuleName) + } + } + return nil +} + +// push appends a new DistributedProjectionInput to the DistProjectionInput slice +func (p *DistributeProjectionCtx) push(comp *wizard.CompiledIOP, q query.Projection, round, queryInRound int, isA, isB bool) { + var ( + isMultiColumn = len(q.Inp.ColumnA) > 1 + alphaName = p.getCoinName("MERGING_COIN", round, queryInRound) + betaName = p.getCoinName("EVAL_COIN", round, queryInRound) + alpha coin.Info + beta coin.Info + ) + // Register alpha and beta + if isMultiColumn { + if comp.Coins.Exists(alphaName) { + alpha = comp.Coins.Data(alphaName) + } else { + alpha = comp.InsertCoin(p.LastRoundProjection+1, alphaName, coin.Field) + } + } + + if comp.Coins.Exists(betaName) { + beta = comp.Coins.Data(betaName) + } else { + beta = comp.InsertCoin(p.LastRoundProjection+1, betaName, coin.Field) + } + p.EvalCoins = append(p.EvalCoins, beta) + if isA && isB { + fA, _, _ := wizardutils.AsExpr(q.Inp.FilterA) + fB, _, _ := wizardutils.AsExpr(q.Inp.FilterB) + p.DistProjectionInput = append(p.DistProjectionInput, &query.DistributedProjectionInput{ + ColumnA: wizardutils.RandLinCombColSymbolic(alpha, q.Inp.ColumnA), + ColumnB: wizardutils.RandLinCombColSymbolic(alpha, q.Inp.ColumnB), + FilterA: fA, + FilterB: fB, + SizeA: q.Inp.FilterA.Size(), + SizeB: q.Inp.FilterB.Size(), + EvalCoin: beta.Name, + IsAInModule: true, + IsBInModule: true, + }) + } else if isA { + fA, _, _ := wizardutils.AsExpr(q.Inp.FilterA) + p.DistProjectionInput = append(p.DistProjectionInput, &query.DistributedProjectionInput{ + ColumnA: wizardutils.RandLinCombColSymbolic(alpha, q.Inp.ColumnA), + ColumnB: symbolic.NewConstant(1), + FilterA: fA, + FilterB: symbolic.NewConstant(1), + SizeA: q.Inp.FilterA.Size(), + EvalCoin: beta.Name, + IsAInModule: true, + IsBInModule: false, + }) + } else if isB { + fB, _, _ := wizardutils.AsExpr(q.Inp.FilterB) + p.DistProjectionInput = append(p.DistProjectionInput, &query.DistributedProjectionInput{ + ColumnA: symbolic.NewConstant(1), + ColumnB: wizardutils.RandLinCombColSymbolic(alpha, q.Inp.ColumnB), + FilterA: symbolic.NewConstant(1), + FilterB: fB, + SizeB: q.Inp.FilterB.Size(), + EvalCoin: beta.Name, + IsAInModule: false, + IsBInModule: true, + }) + } else { + panic("Invalid distributed projection query while initial pushing") + } +} + +// computeQueryParam computes the parameter of the DistributedProjection query +func (p *DistributeProjectionCtx) computeQueryParam(run *wizard.ProverRuntime) field.Element { + var ( + queryParam = field.Zero() + elemParam = field.Zero() + ) + for elemIndex, inp := range p.DistProjectionInput { + if inp.IsAInModule && inp.IsBInModule { + var ( + colABoard = inp.ColumnA.Board() + colBBoard = inp.ColumnB.Board() + filterABorad = inp.FilterA.Board() + filterBBoard = inp.FilterB.Board() + colA = column.EvalExprColumn(run, colABoard).IntoRegVecSaveAlloc() + colB = column.EvalExprColumn(run, colBBoard).IntoRegVecSaveAlloc() + filterA = column.EvalExprColumn(run, filterABorad).IntoRegVecSaveAlloc() + filterB = column.EvalExprColumn(run, filterBBoard).IntoRegVecSaveAlloc() + ) + hornerA := poly.GetHornerTrace(colA, filterA, run.GetRandomCoinField(p.EvalCoins[elemIndex].Name)) + hornerB := poly.GetHornerTrace(colB, filterB, run.GetRandomCoinField(p.EvalCoins[elemIndex].Name)) + elemParam = hornerB[0] + elemParam.Neg(&elemParam) + elemParam.Add(&elemParam, &hornerA[0]) + } else if inp.IsAInModule && !inp.IsBInModule { + var ( + colABoard = inp.ColumnA.Board() + filterABorad = inp.FilterA.Board() + colA = column.EvalExprColumn(run, colABoard).IntoRegVecSaveAlloc() + filterA = column.EvalExprColumn(run, filterABorad).IntoRegVecSaveAlloc() + ) + hornerA := poly.GetHornerTrace(colA, filterA, run.GetRandomCoinField(p.EvalCoins[elemIndex].Name)) + elemParam = hornerA[0] + } else if !inp.IsAInModule && inp.IsBInModule { + var ( + colBBoard = inp.ColumnB.Board() + filterBBorad = inp.FilterB.Board() + colB = column.EvalExprColumn(run, colBBoard).IntoRegVecSaveAlloc() + filterB = column.EvalExprColumn(run, filterBBorad).IntoRegVecSaveAlloc() + ) + hornerB := poly.GetHornerTrace(colB, filterB, run.GetRandomCoinField(p.EvalCoins[elemIndex].Name)) + elemParam = hornerB[0] + elemParam.Neg(&elemParam) + } else { + panic("Invalid distributed projection query encountered during param evaluation") + } + queryParam.Add(&queryParam, &elemParam) + } + return queryParam +} + +// Run implements [wizard.ProverAction] interface +func (p *DistributeProjectionCtx) Run(run *wizard.ProverRuntime) { + run.AssignDistributedProjection(p.Query.ID, query.DistributedProjectionParams{ + HornerVal: p.computeQueryParam(run), + }) +} + +// deriveName constructs a name for the DistributeProjectionCtx context +func deriveName[R ~string](q ifaces.QueryID, ss ...any) R { + ss = append([]any{distProjectionStr, q}, ss...) + return wizardutils.DeriveName[R](ss...) +} + +// QueryID formats and returns a name of the [query.DistributedProjection] generated by the current context +func (p *DistributeProjectionCtx) QueryID() ifaces.QueryID { + return deriveName[ifaces.QueryID](ifaces.QueryID(p.TargetModuleName)) +} + +func (p *DistributeProjectionCtx) getCoinName(name string, round, queryInRound int) coin.Name { + return deriveName[coin.Name](p.QueryID(), name, round, queryInRound) +} + +// getLastRoundPerm scans the initialComp and looks for uncompiled projection queries. It returns +// the highest round found for a matched projection query. It returns -1 if no queries are found. +func getLastRoundPerm(initialComp *wizard.CompiledIOP) int { + + var ( + lastRound = -1 + numRounds = initialComp.NumRounds() + ) + + for round := 0; round < numRounds; round++ { + queries := initialComp.QueriesNoParams.AllKeysAt(round) + for _, qName := range queries { + + if initialComp.QueriesNoParams.IsIgnored(qName) { + continue + } + + _, ok := initialComp.QueriesNoParams.Data(qName).(query.Projection) + if !ok { + continue + } + + lastRound = max(lastRound, round) + } + } + return lastRound } diff --git a/prover/protocol/distributed/compiler/projection/projection_test.go b/prover/protocol/distributed/compiler/projection/projection_test.go new file mode 100644 index 000000000..ae6de7bcd --- /dev/null +++ b/prover/protocol/distributed/compiler/projection/projection_test.go @@ -0,0 +1,209 @@ +package dist_projection_test + +import ( + "testing" + + "github.com/consensys/linea-monorepo/prover/maths/common/smartvectors" + "github.com/consensys/linea-monorepo/prover/maths/field" + "github.com/consensys/linea-monorepo/prover/protocol/compiler/distributedprojection" + "github.com/consensys/linea-monorepo/prover/protocol/compiler/dummy" + "github.com/consensys/linea-monorepo/prover/protocol/distributed" + dist_projection "github.com/consensys/linea-monorepo/prover/protocol/distributed/compiler/projection" + "github.com/consensys/linea-monorepo/prover/protocol/distributed/namebaseddiscoverer" + "github.com/consensys/linea-monorepo/prover/protocol/ifaces" + "github.com/consensys/linea-monorepo/prover/protocol/query" + "github.com/consensys/linea-monorepo/prover/protocol/wizard" + "github.com/stretchr/testify/require" +) + +func TestDistributeProjection(t *testing.T) { + var ( + moduleAName = "moduleA" + flagSizeA = 512 + flagSizeB = 256 + flagA, flagB, columnA, columnB, flagC, columnC ifaces.Column + colA0, colA1, colA2, colB0, colB1, colB2, colC0, colC1, colC2 ifaces.Column + ) + testcases := []struct { + Name string + DefineFunc func(builder *wizard.Builder) + InitialProverFunc func(run *wizard.ProverRuntime) + }{ + { + Name: "distribute-projection-both-A-and-B", + DefineFunc: func(builder *wizard.Builder) { + flagA = builder.RegisterCommit(ifaces.ColID("moduleA.FilterA"), flagSizeA) + flagB = builder.RegisterCommit(ifaces.ColID("moduleA.FliterB"), flagSizeB) + columnA = builder.RegisterCommit(ifaces.ColID("moduleA.ColumnA"), flagSizeA) + columnB = builder.RegisterCommit(ifaces.ColID("moduleA.ColumnB"), flagSizeB) + _ = builder.InsertProjection("ProjectionTest-both-A-and-B", + query.ProjectionInput{ColumnA: []ifaces.Column{columnA}, ColumnB: []ifaces.Column{columnB}, FilterA: flagA, FilterB: flagB}) + + }, + InitialProverFunc: func(run *wizard.ProverRuntime) { + // assign filters and columns + var ( + flagAWit = make([]field.Element, flagSizeA) + columnAWit = make([]field.Element, flagSizeA) + flagBWit = make([]field.Element, flagSizeB) + columnBWit = make([]field.Element, flagSizeB) + ) + for i := 0; i < 10; i++ { + flagAWit[i] = field.One() + columnAWit[i] = field.NewElement(uint64(i)) + } + for i := flagSizeB - 10; i < flagSizeB; i++ { + flagBWit[i] = field.One() + columnBWit[i] = field.NewElement(uint64(i - (flagSizeB - 10))) + } + run.AssignColumn(flagA.GetColID(), smartvectors.RightZeroPadded(flagAWit, flagSizeA)) + run.AssignColumn(flagB.GetColID(), smartvectors.RightZeroPadded(flagBWit, flagSizeB)) + run.AssignColumn(columnB.GetColID(), smartvectors.RightZeroPadded(columnBWit, flagSizeB)) + run.AssignColumn(columnA.GetColID(), smartvectors.RightZeroPadded(columnAWit, flagSizeA)) + }, + }, + { + Name: "distribute-projection-multiple_projections", + DefineFunc: func(builder *wizard.Builder) { + flagA = builder.RegisterCommit(ifaces.ColID("moduleA.FilterA"), flagSizeA) + flagB = builder.RegisterCommit(ifaces.ColID("moduleB.FliterB"), flagSizeB) + flagC = builder.RegisterCommit(ifaces.ColID("moduleC.FliterB"), flagSizeB) + columnA = builder.RegisterCommit(ifaces.ColID("moduleA.ColumnA"), flagSizeA) + columnB = builder.RegisterCommit(ifaces.ColID("moduleB.ColumnB"), flagSizeB) + columnC = builder.RegisterCommit(ifaces.ColID("moduleC.ColumnC"), flagSizeB) + _ = builder.InsertProjection("ProjectionTest-A-B", + query.ProjectionInput{ColumnA: []ifaces.Column{columnA}, ColumnB: []ifaces.Column{columnB}, FilterA: flagA, FilterB: flagB}) + _ = builder.InsertProjection("ProjectionTest-C-A", + query.ProjectionInput{ColumnA: []ifaces.Column{columnC}, ColumnB: []ifaces.Column{columnA}, FilterA: flagC, FilterB: flagA}) + + }, + InitialProverFunc: func(run *wizard.ProverRuntime) { + // assign filters and columns + var ( + flagAWit = make([]field.Element, flagSizeA) + columnAWit = make([]field.Element, flagSizeA) + flagBWit = make([]field.Element, flagSizeB) + columnBWit = make([]field.Element, flagSizeB) + flagCWit = make([]field.Element, flagSizeB) + columnCWit = make([]field.Element, flagSizeB) + ) + for i := 0; i < 10; i++ { + flagAWit[i] = field.One() + columnAWit[i] = field.NewElement(uint64(i)) + } + for i := flagSizeB - 10; i < flagSizeB; i++ { + flagBWit[i] = field.One() + flagCWit[i] = field.One() + columnBWit[i] = field.NewElement(uint64(i - (flagSizeB - 10))) + columnCWit[i] = field.NewElement(uint64(i - (flagSizeB - 10))) + } + run.AssignColumn(flagA.GetColID(), smartvectors.RightZeroPadded(flagAWit, flagSizeA)) + run.AssignColumn(flagB.GetColID(), smartvectors.RightZeroPadded(flagBWit, flagSizeB)) + run.AssignColumn(flagC.GetColID(), smartvectors.RightZeroPadded(flagCWit, flagSizeB)) + run.AssignColumn(columnB.GetColID(), smartvectors.RightZeroPadded(columnBWit, flagSizeB)) + run.AssignColumn(columnA.GetColID(), smartvectors.RightZeroPadded(columnAWit, flagSizeA)) + run.AssignColumn(columnC.GetColID(), smartvectors.RightZeroPadded(columnCWit, flagSizeB)) + }, + }, + { + Name: "distribute-projection-multiple_projections-multi-columns", + DefineFunc: func(builder *wizard.Builder) { + flagA = builder.RegisterCommit(ifaces.ColID("moduleA.FilterA"), flagSizeA) + flagB = builder.RegisterCommit(ifaces.ColID("moduleB.FliterB"), flagSizeB) + flagC = builder.RegisterCommit(ifaces.ColID("moduleC.FliterB"), flagSizeB) + colA0 = builder.RegisterCommit(ifaces.ColID("moduleA.ColumnA0"), flagSizeA) + colA1 = builder.RegisterCommit(ifaces.ColID("moduleA.ColumnA1"), flagSizeA) + colA2 = builder.RegisterCommit(ifaces.ColID("moduleA.ColumnA2"), flagSizeA) + colB0 = builder.RegisterCommit(ifaces.ColID("moduleB.ColumnB0"), flagSizeB) + colB1 = builder.RegisterCommit(ifaces.ColID("moduleB.ColumnB1"), flagSizeB) + colB2 = builder.RegisterCommit(ifaces.ColID("moduleB.ColumnB2"), flagSizeB) + colC0 = builder.RegisterCommit(ifaces.ColID("moduleC.ColumnC0"), flagSizeB) + colC1 = builder.RegisterCommit(ifaces.ColID("moduleC.ColumnC1"), flagSizeB) + colC2 = builder.RegisterCommit(ifaces.ColID("moduleC.ColumnC2"), flagSizeB) + _ = builder.InsertProjection("ProjectionTest-A-B-Multicolum", + query.ProjectionInput{ColumnA: []ifaces.Column{colA0, colA1, colA2}, ColumnB: []ifaces.Column{colB0, colB1, colB2}, FilterA: flagA, FilterB: flagB}) + _ = builder.InsertProjection("ProjectionTest-C-A-Multicolumn", + query.ProjectionInput{ColumnA: []ifaces.Column{colC0, colC1, colC2}, ColumnB: []ifaces.Column{colA0, colA1, colA2}, FilterA: flagC, FilterB: flagA}) + + }, + InitialProverFunc: func(run *wizard.ProverRuntime) { + // assign filters and columns + var ( + flagAWit = make([]field.Element, flagSizeA) + flagBWit = make([]field.Element, flagSizeB) + flagCWit = make([]field.Element, flagSizeB) + colA0Wit = make([]field.Element, flagSizeA) + colA1Wit = make([]field.Element, flagSizeA) + colA2Wit = make([]field.Element, flagSizeA) + colB0Wit = make([]field.Element, flagSizeB) + colB1Wit = make([]field.Element, flagSizeB) + colB2Wit = make([]field.Element, flagSizeB) + colC0Wit = make([]field.Element, flagSizeB) + colC1Wit = make([]field.Element, flagSizeB) + colC2Wit = make([]field.Element, flagSizeB) + ) + for i := 0; i < 10; i++ { + flagAWit[i] = field.One() + colA0Wit[i] = field.NewElement(uint64(i)) + colA1Wit[i] = field.NewElement(uint64(i + 1)) + colA2Wit[i] = field.NewElement(uint64(i + 2)) + } + for i := flagSizeB - 10; i < flagSizeB; i++ { + flagBWit[i] = field.One() + flagCWit[i] = field.One() + colB0Wit[i] = field.NewElement(uint64(i - (flagSizeB - 10))) + colC0Wit[i] = field.NewElement(uint64(i - (flagSizeB - 10))) + colB1Wit[i] = field.NewElement(uint64(i + 1 - (flagSizeB - 10))) + colC1Wit[i] = field.NewElement(uint64(i + 1 - (flagSizeB - 10))) + colB2Wit[i] = field.NewElement(uint64(i + 2 - (flagSizeB - 10))) + colC2Wit[i] = field.NewElement(uint64(i + 2 - (flagSizeB - 10))) + } + run.AssignColumn(flagA.GetColID(), smartvectors.RightZeroPadded(flagAWit, flagSizeA)) + run.AssignColumn(flagB.GetColID(), smartvectors.RightZeroPadded(flagBWit, flagSizeB)) + run.AssignColumn(flagC.GetColID(), smartvectors.RightZeroPadded(flagCWit, flagSizeB)) + run.AssignColumn(colA0.GetColID(), smartvectors.RightZeroPadded(colA0Wit, flagSizeA)) + run.AssignColumn(colA1.GetColID(), smartvectors.RightZeroPadded(colA1Wit, flagSizeA)) + run.AssignColumn(colA2.GetColID(), smartvectors.RightZeroPadded(colA2Wit, flagSizeA)) + run.AssignColumn(colB0.GetColID(), smartvectors.RightZeroPadded(colB0Wit, flagSizeB)) + run.AssignColumn(colB1.GetColID(), smartvectors.RightZeroPadded(colB1Wit, flagSizeB)) + run.AssignColumn(colB2.GetColID(), smartvectors.RightZeroPadded(colB2Wit, flagSizeB)) + run.AssignColumn(colC0.GetColID(), smartvectors.RightZeroPadded(colC0Wit, flagSizeB)) + run.AssignColumn(colC1.GetColID(), smartvectors.RightZeroPadded(colC1Wit, flagSizeB)) + run.AssignColumn(colC2.GetColID(), smartvectors.RightZeroPadded(colC2Wit, flagSizeB)) + }, + }, + } + for _, tc := range testcases { + + t.Run(tc.Name, func(t *testing.T) { + // This function assigns the initial module and is aimed at working + // for all test-case. + initialProve := tc.InitialProverFunc + + // initialComp is defined according to the define function provided by the + // test-case. + initialComp := wizard.Compile(tc.DefineFunc) + + disc := namebaseddiscoverer.PeriodSeperatingModuleDiscoverer{} + disc.Analyze(initialComp) + + // This declares a compiled IOP with only the columns of the module A + moduleAComp := distributed.GetFreshModuleComp(initialComp, &disc, moduleAName) + dist_projection.NewDistributeProjectionCtx(moduleAName, initialComp, moduleAComp, &disc) + + wizard.ContinueCompilation(moduleAComp, distributedprojection.CompileDistributedProjection, dummy.CompileAtProverLvl) + + // This runs the initial prover + initialRuntime := wizard.RunProver(initialComp, initialProve) + + proof := wizard.Prove(moduleAComp, func(run *wizard.ProverRuntime) { + run.ParentRuntime = initialRuntime + }) + valid := wizard.Verify(moduleAComp, proof) + require.NoError(t, valid) + + }) + + } + +} diff --git a/prover/protocol/distributed/conglomeration/translator.go b/prover/protocol/distributed/conglomeration/translator.go index 4123cf764..e7e3bb960 100644 --- a/prover/protocol/distributed/conglomeration/translator.go +++ b/prover/protocol/distributed/conglomeration/translator.go @@ -264,6 +264,11 @@ func (run *runtimeTranslator) GetGrandProductParams(name ifaces.QueryID) query.G return run.Rt.GetGrandProductParams(name) } +func (run *runtimeTranslator) GetDistributedProjectionParams(name ifaces.QueryID) query.DistributedProjectionParams { + name = ifaces.QueryID(run.Prefix) + "." + name + return run.Rt.GetDistributedProjectionParams(name) +} + func (run *runtimeTranslator) GetLogDerivSumParams(name ifaces.QueryID) query.LogDerivSumParams { name = ifaces.QueryID(run.Prefix) + "." + name return run.Rt.GetLogDerivSumParams(name) @@ -361,6 +366,11 @@ func (run *gnarkRuntimeTranslator) GetLogDerivSumParams(name ifaces.QueryID) que return run.Rt.GetLogDerivSumParams(name) } +func (run *gnarkRuntimeTranslator) GetDistributedProjectionParams(name ifaces.QueryID) query.GnarkDistributedProjectionParams { + name = ifaces.QueryID(run.Prefix) + "." + name + return run.Rt.GetDistributedProjectionParams(name) +} + func (run *gnarkRuntimeTranslator) GetLocalPointEvalParams(name ifaces.QueryID) query.GnarkLocalOpeningParams { name = ifaces.QueryID(run.Prefix) + "." + name return run.Rt.GetLocalPointEvalParams(name) diff --git a/prover/protocol/distributed/lpp/lpp_test.go b/prover/protocol/distributed/lpp/lpp_test.go index 49edf4fc5..5946c0b2d 100644 --- a/prover/protocol/distributed/lpp/lpp_test.go +++ b/prover/protocol/distributed/lpp/lpp_test.go @@ -186,7 +186,7 @@ func TestSeedGeneration(t *testing.T) { run.ProverID = proverID }) - // get and compar the coins with the other segments/modules + // get and compare the coins with the other segments/modules coin2Gamma := runtime2.Coins.MustGet("TABLE_module1.col5,module1.col1,module1.col4_LOGDERIVATIVE_GAMMA_FieldFromSeed").(field.Element) coin2Alpha := runtime2.Coins.MustGet("TABLE_module1.col5,module1.col1,module1.col4_LOGDERIVATIVE_ALPHA_FieldFromSeed").(field.Element) diff --git a/prover/protocol/query/distributed_projection.go b/prover/protocol/query/distributed_projection.go new file mode 100644 index 000000000..334ce572e --- /dev/null +++ b/prover/protocol/query/distributed_projection.go @@ -0,0 +1,123 @@ +package query + +import ( + "fmt" + + "github.com/consensys/gnark/frontend" + "github.com/consensys/linea-monorepo/prover/crypto/fiatshamir" + "github.com/consensys/linea-monorepo/prover/maths/common/poly" + "github.com/consensys/linea-monorepo/prover/maths/field" + "github.com/consensys/linea-monorepo/prover/protocol/coin" + "github.com/consensys/linea-monorepo/prover/protocol/column" + "github.com/consensys/linea-monorepo/prover/protocol/ifaces" + "github.com/consensys/linea-monorepo/prover/symbolic" + "github.com/consensys/linea-monorepo/prover/utils" +) + +type DistributedProjectionInput struct { + ColumnA, ColumnB *symbolic.Expression + FilterA, FilterB *symbolic.Expression + SizeA, SizeB int + EvalCoin coin.Name + IsAInModule, IsBInModule bool +} + +type DistributedProjection struct { + Round int + ID ifaces.QueryID + Inp []*DistributedProjectionInput +} + +type DistributedProjectionParams struct { + HornerVal field.Element +} + +func NewDistributedProjection(round int, id ifaces.QueryID, inp []*DistributedProjectionInput) DistributedProjection { + for _, in := range inp { + if err := in.ColumnA.Validate(); err != nil { + utils.Panic("ColumnA for the distributed projection query %v is not a valid expression", id) + } + if err := in.ColumnB.Validate(); err != nil { + utils.Panic("ColumnB for the distributed projection query %v is not a valid expression", id) + } + if err := in.FilterA.Validate(); err != nil { + utils.Panic("FilterA for the distributed projection query %v is not a valid expression", id) + } + if err := in.FilterB.Validate(); err != nil { + utils.Panic("FilterB for the distributed projection query %v is not a valid expression", id) + } + if !in.IsAInModule && !in.IsBInModule { + utils.Panic("Invalid distributed projection query %v, both A and B are not in the module", id) + } + } + return DistributedProjection{Round: round, ID: id, Inp: inp} +} + +// Constructor for distributed projection query parameters +func NewDistributedProjectionParams(hornerVal field.Element) DistributedProjectionParams { + return DistributedProjectionParams{HornerVal: hornerVal} +} + +// Name returns the unique identifier of the GrandProduct query. +func (dp DistributedProjection) Name() ifaces.QueryID { + return dp.ID +} + +// Updates a Fiat-Shamir state +func (dpp DistributedProjectionParams) UpdateFS(fs *fiatshamir.State) { + fs.Update(dpp.HornerVal) +} + +func (dp DistributedProjection) Check(run ifaces.Runtime) error { + var ( + actualParam = field.Zero() + params = run.GetParams(dp.ID).(DistributedProjectionParams) + evalRand field.Element + ) + _, errBeta := evalRand.SetRandom() + if errBeta != nil { + // Cannot happen unless the entropy was exhausted + panic(errBeta) + } + for _, inp := range dp.Inp { + var ( + colABoard = inp.ColumnA.Board() + colBBoard = inp.ColumnB.Board() + filterABorad = inp.FilterA.Board() + filterBBoard = inp.FilterB.Board() + colA = column.EvalExprColumn(run, colABoard).IntoRegVecSaveAlloc() + colB = column.EvalExprColumn(run, colBBoard).IntoRegVecSaveAlloc() + filterA = column.EvalExprColumn(run, filterABorad).IntoRegVecSaveAlloc() + filterB = column.EvalExprColumn(run, filterBBoard).IntoRegVecSaveAlloc() + elemParam = field.One() + ) + if inp.IsAInModule && !inp.IsBInModule { + hornerA := poly.GetHornerTrace(colA, filterA, evalRand) + elemParam = hornerA[0] + } else if !inp.IsAInModule && inp.IsBInModule { + hornerB := poly.GetHornerTrace(colB, filterB, evalRand) + elemParam = hornerB[0] + elemParam.Neg(&elemParam) + } else if inp.IsAInModule && inp.IsBInModule { + hornerA := poly.GetHornerTrace(colA, filterA, evalRand) + hornerB := poly.GetHornerTrace(colB, filterB, evalRand) + elemParam = hornerB[0] + elemParam.Neg(&elemParam) + elemParam.Add(&elemParam, &hornerA[0]) + } else { + utils.Panic("Invalid distributed projection query %v", dp.ID) + } + actualParam.Add(&actualParam, &elemParam) + + } + + if actualParam != params.HornerVal { + return fmt.Errorf("the distributed projection query %v is not satisfied, actualParam = %v, param.HornerVal = %v", dp.ID, actualParam, params.HornerVal) + } + + return nil +} + +func (dp DistributedProjection) CheckGnark(api frontend.API, run ifaces.GnarkRuntime) { + panic("UNSUPPORTED : can't check a Projection query directly into the circuit") +} diff --git a/prover/protocol/query/distributed_projection_test.go b/prover/protocol/query/distributed_projection_test.go new file mode 100644 index 000000000..eb4b7c67a --- /dev/null +++ b/prover/protocol/query/distributed_projection_test.go @@ -0,0 +1,95 @@ +package query_test + +import ( + "testing" + + "github.com/consensys/linea-monorepo/prover/maths/common/smartvectors" + "github.com/consensys/linea-monorepo/prover/maths/field" + "github.com/consensys/linea-monorepo/prover/protocol/ifaces" + "github.com/consensys/linea-monorepo/prover/protocol/query" + "github.com/consensys/linea-monorepo/prover/protocol/wizard" + "github.com/consensys/linea-monorepo/prover/protocol/wizardutils" +) + +func TestDistributedProjectionQuery(t *testing.T) { + var ( + runS *wizard.ProverRuntime + DP ifaces.Query + round = 0 + flagSizeA = 512 + flagSizeB = 256 + flagA, flagB, columnA, columnB ifaces.Column + flagAWit = make([]field.Element, flagSizeA) + columnAWit = make([]field.Element, flagSizeA) + flagBWit = make([]field.Element, flagSizeB) + columnBWit = make([]field.Element, flagSizeB) + queryNameBothAAndB = ifaces.QueryID("DistributedProjectionTestBothAAndB") + ) + // Computing common test data + + // assign filters and columns + for i := 0; i < 10; i++ { + flagAWit[i] = field.One() + columnAWit[i] = field.NewElement(uint64(i)) + } + for i := flagSizeB - 10; i < flagSizeB; i++ { + flagBWit[i] = field.One() + columnBWit[i] = field.NewElement(uint64(i - (flagSizeB - 10))) + } + + testcases := []struct { + Name string + HornerParam field.Element + QueryName ifaces.QueryID + DefineFunc func(builder *wizard.Builder) + ProverFunc func(run *wizard.ProverRuntime) + }{ + { + Name: "distributed-projection-both-A-and-B", + QueryName: queryNameBothAAndB, + DefineFunc: func(builder *wizard.Builder) { + flagA = builder.RegisterCommit(ifaces.ColID("FilterA"), flagSizeA) + flagB = builder.RegisterCommit(ifaces.ColID("FliterB"), flagSizeB) + columnA = builder.RegisterCommit(ifaces.ColID("ColumnA"), flagSizeA) + columnB = builder.RegisterCommit(ifaces.ColID("ColumnB"), flagSizeB) + var ( + colA, _, _ = wizardutils.AsExpr(columnA) + colB, _, _ = wizardutils.AsExpr(columnB) + fA, _, _ = wizardutils.AsExpr(flagA) + fB, _, _ = wizardutils.AsExpr(flagB) + ) + DP = builder.CompiledIOP.InsertDistributedProjection(round, queryNameBothAAndB, + []*query.DistributedProjectionInput{ + {ColumnA: colA, ColumnB: colB, FilterA: fA, FilterB: fB, IsAInModule: true, IsBInModule: true}, + }) + }, + ProverFunc: func(run *wizard.ProverRuntime) { + runS = run + run.AssignColumn(flagA.GetColID(), smartvectors.RightZeroPadded(flagAWit, flagSizeA)) + run.AssignColumn(flagB.GetColID(), smartvectors.RightZeroPadded(flagBWit, flagSizeB)) + run.AssignColumn(columnB.GetColID(), smartvectors.RightZeroPadded(columnBWit, flagSizeB)) + run.AssignColumn(columnA.GetColID(), smartvectors.RightZeroPadded(columnAWit, flagSizeA)) + + runS.AssignDistributedProjection(queryNameBothAAndB, query.DistributedProjectionParams{HornerVal: field.Zero()}) + }, + }, + } + + for _, tc := range testcases { + + t.Run(tc.Name, func(t *testing.T) { + prover := tc.ProverFunc + var ( + comp = wizard.Compile(tc.DefineFunc) + _ = wizard.Prove(comp, prover) + errDP = DP.Check(runS) + ) + + if errDP != nil { + t.Fatalf("error verifying the distributed projection query: %v", errDP.Error()) + } + + }) + } + +} diff --git a/prover/protocol/query/gnark_params.go b/prover/protocol/query/gnark_params.go index 448d011a2..c0da87154 100644 --- a/prover/protocol/query/gnark_params.go +++ b/prover/protocol/query/gnark_params.go @@ -25,6 +25,11 @@ type GnarkGrandProductParams struct { Prod frontend.Variable } +// A gnark circuit version of DistributedProjectionParams +type GnarkDistributedProjectionParams struct { + Sum frontend.Variable +} + func (p LogDerivSumParams) GnarkAssign() GnarkLogDerivSumParams { return GnarkLogDerivSumParams{Sum: p.Sum} } diff --git a/prover/protocol/query/projection.go b/prover/protocol/query/projection.go index 45f88eb0f..2b1df8068 100644 --- a/prover/protocol/query/projection.go +++ b/prover/protocol/query/projection.go @@ -90,8 +90,8 @@ func (p Projection) Check(run ifaces.Runtime) error { bLinComb[row] = rowLinComb(linCombRand, row, b) } var ( - hornerA = poly.CmptHorner(aLinComb, fA, evalRand) - hornerB = poly.CmptHorner(bLinComb, fB, evalRand) + hornerA = poly.GetHornerTrace(aLinComb, fA, evalRand) + hornerB = poly.GetHornerTrace(bLinComb, fB, evalRand) ) if hornerA[0] != hornerB[0] { return fmt.Errorf("the projection query %v check is not satisfied", p.ID) diff --git a/prover/protocol/query/projection_test.go b/prover/protocol/query/projection_test.go index bd1a307fa..2fcc52935 100644 --- a/prover/protocol/query/projection_test.go +++ b/prover/protocol/query/projection_test.go @@ -33,10 +33,13 @@ func TestProjection(t *testing.T) { prover := func(run *wizard.ProverRuntime) { runS = run // assign filters and columns - flagAWit := make([]field.Element, flagSizeA) - columnAWit := make([]field.Element, flagSizeA) - flagBWit := make([]field.Element, flagSizeB) - columnBWit := make([]field.Element, flagSizeB) + var ( + flagAWit = make([]field.Element, flagSizeA) + columnAWit = make([]field.Element, flagSizeA) + flagBWit = make([]field.Element, flagSizeB) + columnBWit = make([]field.Element, flagSizeB) + ) + for i := 0; i < 10; i++ { flagAWit[i] = field.One() columnAWit[i] = field.NewElement(uint64(i)) diff --git a/prover/protocol/wizard/compiled.go b/prover/protocol/wizard/compiled.go index 65f11771a..bc98746e2 100644 --- a/prover/protocol/wizard/compiled.go +++ b/prover/protocol/wizard/compiled.go @@ -646,6 +646,14 @@ func (c *CompiledIOP) InsertProjection(id ifaces.QueryID, in query.ProjectionInp return q } +// Register a distributed projection query +func (c *CompiledIOP) InsertDistributedProjection(round int, id ifaces.QueryID, in []*query.DistributedProjectionInput) query.DistributedProjection { + q := query.NewDistributedProjection(round, id, in) + // Finally registers the query + c.QueriesParams.AddToRound(round, q.Name(), q) + return q +} + // AddPublicInput inserts a public-input in the compiled-IOP func (c *CompiledIOP) InsertPublicInput(name string, acc ifaces.Accessor) PublicInput { diff --git a/prover/protocol/wizard/gnark_verifier.go b/prover/protocol/wizard/gnark_verifier.go index fc2c1292f..778de2360 100644 --- a/prover/protocol/wizard/gnark_verifier.go +++ b/prover/protocol/wizard/gnark_verifier.go @@ -23,6 +23,7 @@ type GnarkRuntime interface { GetSpec() *CompiledIOP GetPublicInput(api frontend.API, name string) frontend.Variable GetGrandProductParams(name ifaces.QueryID) query.GnarkGrandProductParams + GetDistributedProjectionParams(name ifaces.QueryID) query.GnarkDistributedProjectionParams GetLogDerivSumParams(name ifaces.QueryID) query.GnarkLogDerivSumParams GetLocalPointEvalParams(name ifaces.QueryID) query.GnarkLocalOpeningParams GetInnerProductParams(name ifaces.QueryID) query.GnarkInnerProductParams @@ -73,6 +74,8 @@ type WizardVerifierCircuit struct { logDerivSumIDs collection.Mapping[ifaces.QueryID, int] `gnark:"-"` // Same for grand-product query grandProductIDs collection.Mapping[ifaces.QueryID, int] `gnark:"-"` + // Same for distributed projection query + distributedProjectionIDs collection.Mapping[ifaces.QueryID, int] `gnark:"-"` // Columns stores the gnark witness part corresponding to the columns // provided in the proof and in the VerifyingKey. @@ -95,6 +98,9 @@ type WizardVerifierCircuit struct { // GrandProductParams stores an assignment for each [query.GrandProductParams] // from the proof. It is part of the witness of the gnark circuit. GrandProductParams []query.GnarkGrandProductParams `gnark:",secret"` + // DistributedProjectionParams stores an assignment for each [query.DistributedProjectionParams] + // from the proof. It is part of the witness of the gnark circuit. + DistributedProjectionParams []query.GnarkDistributedProjectionParams `gnark:",secret"` // FS is the Fiat-Shamir state, mirroring [VerifierRuntime.FS]. The same // cautionnary rules apply to it; e.g. don't use it externally when @@ -369,6 +375,14 @@ func (c *WizardVerifierCircuit) GetGrandProductParams(name ifaces.QueryID) query return c.GrandProductParams[qID] } +// GetDistributedProjectionParams returns the parameters for the requested +// [query.DistributedProjection] query. Its work mirrors the function +// [VerifierRuntime.GetDistributedProjectionParams] +func (c *WizardVerifierCircuit) GetDistributedProjectionParams(name ifaces.QueryID) query.GnarkDistributedProjectionParams { + qID := c.distributedProjectionIDs.MustGet(name) + return c.DistributedProjectionParams[qID] +} + // GetColumns returns the gnark assignment of a column in a gnark circuit. It // mirrors the function [VerifierRuntime.GetColumn] func (c *WizardVerifierCircuit) GetColumn(name ifaces.ColID) []frontend.Variable { diff --git a/prover/protocol/wizard/prover.go b/prover/protocol/wizard/prover.go index 5b1aa151e..9084b2e5b 100644 --- a/prover/protocol/wizard/prover.go +++ b/prover/protocol/wizard/prover.go @@ -831,3 +831,24 @@ func (run *ProverRuntime) AssignGrandProduct(name ifaces.QueryID, y field.Elemen params := query.NewGrandProductParams(y) run.QueriesParams.InsertNew(name, params) } + +// AssignDistributedProjection assigns the value horner(A, fA) if A +// is in the module, horner(B, fB)^{-1} if B is in the module, and +// horner(A, fA) * horner(B, fB)^{-1} if both are in the module. +// The function will panic if: +// - the parameters were already assigned +// - the specified query is not registered +// - the assignment round is incorrect +func (run *ProverRuntime) AssignDistributedProjection(name ifaces.QueryID, distributedProjectionParam query.DistributedProjectionParams) { + + // Global prover locks for accessing the maps + run.lock.Lock() + defer run.lock.Unlock() + + // Make sure, it is done at the right round + run.Spec.QueriesParams.MustBeInRound(run.currRound, name) + + // Adds it to the assignments + params := query.NewDistributedProjectionParams(distributedProjectionParam.HornerVal) + run.QueriesParams.InsertNew(name, params) +} diff --git a/prover/protocol/wizard/verifier.go b/prover/protocol/wizard/verifier.go index aaacfe165..e598b9f4f 100644 --- a/prover/protocol/wizard/verifier.go +++ b/prover/protocol/wizard/verifier.go @@ -41,6 +41,7 @@ type Runtime interface { GetSpec() *CompiledIOP GetPublicInput(name string) field.Element GetGrandProductParams(name ifaces.QueryID) query.GrandProductParams + GetDistributedProjectionParams(name ifaces.QueryID) query.DistributedProjectionParams GetLogDerivSumParams(name ifaces.QueryID) query.LogDerivSumParams GetLocalPointEvalParams(name ifaces.QueryID) query.LocalOpeningParams GetInnerProductParams(name ifaces.QueryID) query.InnerProductParams @@ -433,6 +434,11 @@ func (run *VerifierRuntime) GetGrandProductParams(name ifaces.QueryID) query.Gra return run.QueriesParams.MustGet(name).(query.GrandProductParams) } +// GetGrandProductParams returns the parameters of a [query.DistributedProjection] +func (run *VerifierRuntime) GetDistributedProjectionParams(name ifaces.QueryID) query.DistributedProjectionParams { + return run.QueriesParams.MustGet(name).(query.DistributedProjectionParams) +} + /* CopyColumnInto implements `column.GetWitness` Copies the witness into a slice From b4093d479a89d314efd843f76a0825403dbda500 Mon Sep 17 00:00:00 2001 From: Azam Soleimanian <49027816+Soleimani193@users.noreply.github.com> Date: Mon, 17 Feb 2025 11:08:36 +0100 Subject: [PATCH 3/4] Prover/distribution of global queries (#647) * distribution of global queries * wip: added the boundary checks * testing boundaries for simple cases * test for complex queries * fixed provider assignment --- .../stitch_split/splitter/constraints.go | 2 +- prover/protocol/distributed/comp_splitting.go | 225 +++++++++++++ .../distributed/compiler/global/global.go | 312 +++++++++++++++++- .../compiler/global/global_test.go | 95 ++++++ .../distributed/compiler/global/prover.go | 28 ++ .../cross_segment_consistency.go | 2 +- .../period_separating_module_discoverer.go | 8 +- prover/protocol/wizard/verifier.go | 2 +- 8 files changed, 666 insertions(+), 8 deletions(-) create mode 100644 prover/protocol/distributed/compiler/global/global_test.go create mode 100644 prover/protocol/distributed/compiler/global/prover.go diff --git a/prover/protocol/compiler/stitch_split/splitter/constraints.go b/prover/protocol/compiler/stitch_split/splitter/constraints.go index 7eb3787eb..63c8a700b 100644 --- a/prover/protocol/compiler/stitch_split/splitter/constraints.go +++ b/prover/protocol/compiler/stitch_split/splitter/constraints.go @@ -81,7 +81,7 @@ func (ctx splitterContext) LocalGlobalConstraints() { continue } - // if the associated expression is eligible to the stitching, mark the query, over the sub columns, as ignored. + // if the associated expression is eligible to the splitting, mark the query as ignored. ctx.comp.QueriesNoParams.MarkAsIgnored(qName) // adjust the query over the sub columns diff --git a/prover/protocol/distributed/comp_splitting.go b/prover/protocol/distributed/comp_splitting.go index f42b9e4cd..9100dda9d 100644 --- a/prover/protocol/distributed/comp_splitting.go +++ b/prover/protocol/distributed/comp_splitting.go @@ -1,8 +1,13 @@ package distributed import ( + "github.com/consensys/linea-monorepo/prover/maths/common/smartvectors" + "github.com/consensys/linea-monorepo/prover/maths/field" + "github.com/consensys/linea-monorepo/prover/protocol/column" "github.com/consensys/linea-monorepo/prover/protocol/ifaces" + "github.com/consensys/linea-monorepo/prover/protocol/query" "github.com/consensys/linea-monorepo/prover/protocol/wizard" + "github.com/consensys/linea-monorepo/prover/symbolic" "github.com/consensys/linea-monorepo/prover/utils" ) @@ -90,3 +95,223 @@ func getSegmentFromWitness(wit ifaces.ColAssignment, numSegs, segID int) ifaces. segSize := wit.Len() / numSegs return wit.SubVector(segSize*segID, segSize*segID+segSize) } + +func GetFreshCompGL(in SegmentModuleInputs) *wizard.CompiledIOP { + + var ( + // initialize the segment CompiledIOP + segComp = wizard.NewCompiledIOP() + initialComp = in.InitialComp + glColumns = extractGLColumns(initialComp) + glColumnsInModule = []ifaces.Column{} + ) + + // get the GL columns + for _, colName := range initialComp.Columns.AllKeysAt(0) { + + col := initialComp.Columns.GetHandle(colName) + if !in.Disc.ColumnIsInModule(col, in.ModuleName) { + continue + } + + if isGLColumn(col, glColumns) { + // TBD: register at round 1, to create a separate commitment over GL. + segComp.InsertCommit(0, col.GetColID(), col.Size()/in.NumSegmentsInModule) + glColumnsInModule = append(glColumnsInModule, col) + } + + } + + // register provider and receiver + provider := segComp.InsertCommit(0, "PROVIDER", getSizeForProviderReceiver(initialComp)) + receiver := segComp.InsertCommit(0, "RECEIVER", getSizeForProviderReceiver(initialComp)) + + // create a new moduleProver + glProver := glProver{ + glCols: glColumnsInModule, + numSegments: in.NumSegmentsInModule, + provider: provider, + receiver: receiver, + } + + // register Prover action for the segment-module to assign columns per round + segComp.RegisterProverAction(0, glProver) + return segComp + +} + +type glProver struct { + glCols []ifaces.Column + numSegments int + provider ifaces.Column + receiver ifaces.Column +} + +func (p glProver) Run(run *wizard.ProverRuntime) { + if run.ParentRuntime == nil { + utils.Panic("invalid call: the runtime does not have a [ParentRuntime]") + } + if run.ProverID > p.numSegments { + panic("proverID can not be larger than number of segments") + } + + for _, col := range p.glCols { + // get the witness from the initialProver + colWitness := run.ParentRuntime.GetColumn(col.GetColID()) + colSegWitness := getSegmentFromWitness(colWitness, p.numSegments, run.ProverID) + // assign it in the module in the round col was declared + run.AssignColumn(col.GetColID(), colSegWitness, col.Round()) + } + // assign Provider and Receiver + assignProvider(run, run.ProverID, p.numSegments, p.provider) + // for the current segment, the receiver is the provider of the previous segment. + assignProvider(run, utils.PositiveMod(run.ProverID-1, p.numSegments), p.numSegments, p.receiver) + +} + +func extractGLColumns(comp *wizard.CompiledIOP) []ifaces.Column { + + glColumns := []ifaces.Column{} + // extract global queries + for _, queryID := range comp.QueriesNoParams.AllKeysAt(0) { + + if glob, ok := comp.QueriesNoParams.Data(queryID).(query.GlobalConstraint); ok { + glColumns = append(glColumns, ListColumnsFromExpr(glob.Expression, true)...) + } + + if local, ok := comp.QueriesNoParams.Data(queryID).(query.LocalConstraint); ok { + glColumns = append(glColumns, ListColumnsFromExpr(local.Expression, true)...) + } + } + + // extract localOpenings + return glColumns +} + +// ListColumnsFromExpr returns the natural version of all the columns in the expression. +// if natural is true, it return the natural version of the columns, +// otherwise it return the original columns. +func ListColumnsFromExpr(expr *symbolic.Expression, natural bool) []ifaces.Column { + + var ( + board = expr.Board() + metadata = board.ListVariableMetadata() + colList = []ifaces.Column{} + ) + + for _, m := range metadata { + switch t := m.(type) { + case ifaces.Column: + + if shifted, ok := t.(column.Shifted); ok && natural { + colList = append(colList, shifted.Parent) + } else { + colList = append(colList, t) + } + + } + } + return colList + +} + +func isGLColumn(col ifaces.Column, glColumns []ifaces.Column) bool { + + for _, glCol := range glColumns { + if col.GetColID() == glCol.GetColID() { + return true + } + } + return false + +} + +func getSizeForProviderReceiver(comp *wizard.CompiledIOP) int { + + numBoundaries := 0 + + for _, queryID := range comp.QueriesNoParams.AllKeysAt(0) { + + if global, ok := comp.QueriesNoParams.Data(queryID).(query.GlobalConstraint); ok { + + var ( + board = global.Board() + metadata = board.ListVariableMetadata() + maxShift = global.MinMaxOffset().Max + ) + + for _, m := range metadata { + switch t := m.(type) { + case ifaces.Column: + + if shifted, ok := t.(column.Shifted); ok { + // number of boundaries from the current column + numBoundaries += maxShift - column.StackOffsets(shifted) + + } else { + numBoundaries += maxShift + } + + } + } + } + } + return utils.NextPowerOfTwo(numBoundaries) +} + +// assignProvider mainly assigns the provider +// it also can be used for the receiver assignment, +// since the receiver of segment i equal to the provider of segment (i-1). +func assignProvider(run *wizard.ProverRuntime, segID, numSegments int, col ifaces.Column) { + + var ( + parentRuntime = run.ParentRuntime + initialComp = parentRuntime.Spec + allBoundaries = []field.Element{} + ) + + for _, q := range initialComp.QueriesNoParams.AllKeysAt(0) { + if global, ok := initialComp.QueriesNoParams.Data(q).(query.GlobalConstraint); ok { + + var ( + board = global.Board() + metadata = board.ListVariableMetadata() + maxShift = global.MinMaxOffset().Max + ) + + for _, m := range metadata { + switch t := m.(type) { + case ifaces.Column: + + var ( + segmentSize = t.Size() / numSegments + lastRow = (segID+1)*segmentSize - 1 + colWit []field.Element + // number of boundaries from the current column + numBoundaries = 0 + ) + + if shifted, ok := t.(column.Shifted); ok { + numBoundaries = maxShift - column.StackOffsets(shifted) + colWit = shifted.Parent.GetColAssignment(parentRuntime).IntoRegVecSaveAlloc() + + } else { + numBoundaries = maxShift + colWit = t.GetColAssignment(parentRuntime).IntoRegVecSaveAlloc() + } + + for i := lastRow - numBoundaries + 1; i <= lastRow; i++ { + + allBoundaries = append(allBoundaries, + colWit[i]) + + } + + } + + } + + } + } + run.AssignColumn(col.GetColID(), smartvectors.RightZeroPadded(allBoundaries, col.Size())) +} diff --git a/prover/protocol/distributed/compiler/global/global.go b/prover/protocol/distributed/compiler/global/global.go index 37f1d6070..35f7177a9 100644 --- a/prover/protocol/distributed/compiler/global/global.go +++ b/prover/protocol/distributed/compiler/global/global.go @@ -1,8 +1,312 @@ package global -import "github.com/consensys/linea-monorepo/prover/protocol/wizard" +import ( + "github.com/consensys/linea-monorepo/prover/protocol/accessors" + "github.com/consensys/linea-monorepo/prover/protocol/column" + "github.com/consensys/linea-monorepo/prover/protocol/column/verifiercol" + "github.com/consensys/linea-monorepo/prover/protocol/distributed" + "github.com/consensys/linea-monorepo/prover/protocol/ifaces" + "github.com/consensys/linea-monorepo/prover/protocol/query" + "github.com/consensys/linea-monorepo/prover/protocol/variables" + "github.com/consensys/linea-monorepo/prover/protocol/wizard" + "github.com/consensys/linea-monorepo/prover/symbolic" + "github.com/consensys/linea-monorepo/prover/utils" + "github.com/consensys/linea-monorepo/prover/utils/collection" +) -// CompileDist compiles the global queries distributedly. -func CompileDist(comp *wizard.CompiledIOP) { - panic("unimplemented") +type segmentID int + +type DistributionInputs struct { + ModuleComp *wizard.CompiledIOP + InitialComp *wizard.CompiledIOP + // module Discoverer used to detect the relevant part of the query to the module + Disc distributed.ModuleDiscoverer + // Name of the module + ModuleName distributed.ModuleName + // number of segments for the module + NumSegments int +} + +func DistributeGlobal(in DistributionInputs) { + + var ( + bInputs = boundaryInputs{ + moduleComp: in.ModuleComp, + numSegments: in.NumSegments, + provider: in.ModuleComp.Columns.GetHandle("PROVIDER"), + receiver: in.ModuleComp.Columns.GetHandle("RECEIVER"), + providerOpenings: []query.LocalOpening{}, + receiverOpenings: []query.LocalOpening{}, + } + ) + + for _, qName := range in.InitialComp.QueriesNoParams.AllUnignoredKeys() { + + q, ok := in.InitialComp.QueriesNoParams.Data(qName).(query.GlobalConstraint) + if !ok { + continue + } + + if in.Disc.ExpressionIsInModule(q.Expression, in.ModuleName) { + + // apply global constraint over the segment. + in.ModuleComp.InsertGlobal(0, + q.ID, + AdjustExpressionForGlobal(in.ModuleComp, q.Expression, in.NumSegments), + ) + + // collect the boundaries for provider and receiver + + BoundariesForProvider(&bInputs, q) + BoundariesForReceiver(&bInputs, q) + + } + + // @Azam the risk is that some global constraints may be skipped here. + // we can prevent this by tagging the query as ignored from the initialComp, + // and at the end make sure that no query has remained in initial CompiledIOP. + } + + in.ModuleComp.RegisterProverAction(0, ¶mAssignments{ + provider: bInputs.provider, + receiver: bInputs.receiver, + providerOpenings: bInputs.providerOpenings, + receiverOpenings: bInputs.receiverOpenings, + }) + +} + +type boundaryInputs struct { + moduleComp *wizard.CompiledIOP + numSegments int + provider ifaces.Column + receiver ifaces.Column + providerOpenings, receiverOpenings []query.LocalOpening +} + +func AdjustExpressionForGlobal( + comp *wizard.CompiledIOP, + expr *symbolic.Expression, + numSegments int, +) *symbolic.Expression { + + var ( + board = expr.Board() + metadatas = board.ListVariableMetadata() + translationMap = collection.NewMapping[string, *symbolic.Expression]() + colTranslation ifaces.Column + size = column.ExprIsOnSameLengthHandles(&board) + ) + + for _, metadata := range metadatas { + + // For each slot, get the expression obtained by replacing the commitment + // by the appropriated column. + + switch m := metadata.(type) { + case ifaces.Column: + + switch col := m.(type) { + case column.Natural: + colTranslation = comp.Columns.GetHandle(m.GetColID()) + + case verifiercol.VerifierCol: + // panic happens specially for the case of FromAccessors + panic("unsupported for now, unless module discoverer can capture such columns") + + case column.Shifted: + colTranslation = column.Shift(comp.Columns.GetHandle(col.Parent.GetColID()), col.Offset) + + } + + translationMap.InsertNew(m.String(), ifaces.ColumnAsVariable(colTranslation)) + case variables.X: + utils.Panic("unsupported, the value of `x` in the unsplit query and the split would be different") + case variables.PeriodicSample: + // Check that the period is not larger than the domain size. If + // the period is smaller this is a no-op because the period does + // not change. + segSize := size / numSegments + + if m.T > segSize { + + panic("unsupported, since this depends on the segment ID, unless the module discoverer can detect such cases") + } + translationMap.InsertNew(m.String(), symbolic.NewVariable(metadata)) + default: + // Repass the same variable (for coins or other types of single-valued variable) + translationMap.InsertNew(m.String(), symbolic.NewVariable(metadata)) + } + + } + return expr.Replay(translationMap) +} + +func BoundariesForProvider(in *boundaryInputs, q query.GlobalConstraint) { + + var ( + board = q.Board() + offsetRange = q.MinMaxOffset() + provider = in.provider + maxShift = offsetRange.Max + colsInExpr = distributed.ListColumnsFromExpr(q.Expression, false) + colsOnProvider = onBoundaries(colsInExpr, maxShift) + numBoundaries = offsetRange.Max - offsetRange.Min + size = column.ExprIsOnSameLengthHandles(&board) + segSize = size / in.numSegments + ) + for _, col := range colsInExpr { + for i := 0; i < numBoundaries; i++ { + if colsOnProvider.Exists(col.GetColID()) { + + pos := colsOnProvider.MustGet(col.GetColID()) + + if i < maxShift-column.StackOffsets(col) { + // take from provider, since the size of the provider is different from size of the expression + // take it via accessor. + var ( + index = pos[0] + i + name = ifaces.QueryIDf("%v_%v", "FROM_PROVIDER_AT", index) + loProvider = in.moduleComp.InsertLocalOpening(0, name, column.Shift(provider, index)) + accessorProvider = accessors.NewLocalOpeningAccessor(loProvider, 0) + indexOnCol = segSize - (maxShift - column.StackOffsets(col) - i) + nameExpr = ifaces.QueryIDf("%v_%v_%v", "CONSISTENCY_AGAINST_PROVIDER", col.GetColID(), i) + colInModule ifaces.Column + ) + + // replace col with its replacement in the module. + if shifted, ok := col.(column.Shifted); ok { + colInModule = in.moduleComp.Columns.GetHandle(shifted.Parent.GetColID()) + } else { + colInModule = in.moduleComp.Columns.GetHandle(col.GetColID()) + } + + // add the localOpening to the list + in.providerOpenings = append(in.providerOpenings, loProvider) + // impose that loProvider = loCol + in.moduleComp.InsertLocal(0, nameExpr, + symbolic.Sub(accessorProvider, column.Shift(colInModule, indexOnCol)), + ) + + } + } + } + } + +} + +func BoundariesForReceiver(in *boundaryInputs, q query.GlobalConstraint) { + + var ( + offsetRange = q.MinMaxOffset() + receiver = in.receiver + maxShift = offsetRange.Max + colsInExpr = distributed.ListColumnsFromExpr(q.Expression, false) + colsOnReceiver = onBoundaries(colsInExpr, maxShift) + numBoundaries = offsetRange.Max - offsetRange.Min + comp = in.moduleComp + colInModule ifaces.Column + // list of local openings by the boundary index + allLists = make([][]query.LocalOpening, numBoundaries) + ) + + for i := 0; i < numBoundaries; i++ { + + translationMap := collection.NewMapping[string, *symbolic.Expression]() + + for _, col := range colsInExpr { + + // replace col with its replacement in the module. + if shifted, ok := col.(column.Shifted); ok { + colInModule = in.moduleComp.Columns.GetHandle(shifted.Parent.GetColID()) + } else { + colInModule = in.moduleComp.Columns.GetHandle(col.GetColID()) + } + + if colsOnReceiver.Exists(col.GetColID()) { + pos := colsOnReceiver.MustGet(col.GetColID()) + + if i < maxShift-column.StackOffsets(col) { + // take from receiver, since the size of the receiver is different from size of the expression + // take it via accessor. + var ( + index = pos[0] + i + name = ifaces.QueryIDf("%v_%v", "FROM_RECEIVER_AT", index) + lo = comp.InsertLocalOpening(0, name, column.Shift(receiver, index)) + accessor = accessors.NewLocalOpeningAccessor(lo, 0) + ) + // add the localOpening to the list + allLists[i] = append(allLists[i], lo) + // in.receiverOpenings = append(in.receiverOpenings, lo) + // translate the column + translationMap.InsertNew(string(col.GetColID()), accessor.AsVariable()) + } else { + // take the rest from the column + tookFromReceiver := maxShift - column.StackOffsets(col) + translationMap.InsertNew(string(col.GetColID()), ifaces.ColumnAsVariable(column.Shift(colInModule, i-tookFromReceiver))) + } + + } else { + translationMap.InsertNew(string(col.GetColID()), ifaces.ColumnAsVariable((column.Shift(colInModule, i)))) + } + + } + + expr := q.Expression.Replay(translationMap) + name := ifaces.QueryIDf("%v_%v_%v", "CONSISTENCY_AGAINST_RECEIVER", q.ID, i) + comp.InsertLocal(0, name, expr) + + } + + // order receiverOpenings column by column + for i := 0; i < numBoundaries; i++ { + for _, list := range allLists { + if len(list) > i { + in.receiverOpenings = append(in.receiverOpenings, list[i]) + } + } + } + +} + +// it indicates the column list having the provider cells (i.e., +// some cells of the columns are needed to be provided to the next segment) +func onBoundaries(colsInExpr []ifaces.Column, maxShift int) collection.Mapping[ifaces.ColID, [2]int] { + + var ( + ctr = 0 + colsOnReceiver = collection.NewMapping[ifaces.ColID, [2]int]() + ) + for _, col := range colsInExpr { + // number of boundaries from the column (that falls on the receiver) is + // maxShift - column.StackOffsets(col) + newCtr := ctr + maxShift - column.StackOffsets(col) + + // it does not have any cell on the receiver. + if newCtr == ctr { + continue + } + + colsOnReceiver.InsertNew(col.GetColID(), [2]int{ctr, newCtr}) + ctr = newCtr + + } + + return colsOnReceiver + +} + +// it generates natural verifier columns, from a given verifier column +func createVerifierColForModule(col ifaces.Column, numSegments int) ifaces.Column { + + if vcol, ok := col.(verifiercol.VerifierCol); ok { + + switch v := vcol.(type) { + case verifiercol.ConstCol: + return verifiercol.NewConstantCol(v.F, v.Size()/numSegments) + default: + panic("unsupported") + } + } + return nil } diff --git a/prover/protocol/distributed/compiler/global/global_test.go b/prover/protocol/distributed/compiler/global/global_test.go new file mode 100644 index 000000000..20b4e44b6 --- /dev/null +++ b/prover/protocol/distributed/compiler/global/global_test.go @@ -0,0 +1,95 @@ +package global_test + +import ( + "testing" + + "github.com/consensys/linea-monorepo/prover/maths/common/smartvectors" + "github.com/consensys/linea-monorepo/prover/protocol/column" + "github.com/consensys/linea-monorepo/prover/protocol/compiler/dummy" + "github.com/consensys/linea-monorepo/prover/protocol/distributed" + "github.com/consensys/linea-monorepo/prover/protocol/distributed/compiler/global" + md "github.com/consensys/linea-monorepo/prover/protocol/distributed/namebaseddiscoverer" + "github.com/consensys/linea-monorepo/prover/protocol/wizard" + "github.com/consensys/linea-monorepo/prover/symbolic" + "github.com/stretchr/testify/require" +) + +// It tests DistributedLogDerivSum. +func TestDistributedGlobal(t *testing.T) { + const ( + numSegModule = 2 + ) + + //initialComp + define := func(b *wizard.Builder) { + + var ( + col0 = b.CompiledIOP.InsertCommit(0, "module.col0", 8) + col1 = b.CompiledIOP.InsertCommit(0, "module.col1", 8) + col2 = b.CompiledIOP.InsertCommit(0, "module.col2", 8) + col3 = b.CompiledIOP.InsertCommit(0, "module.col3", 8) + ) + + b.CompiledIOP.InsertGlobal(0, "global0", + symbolic.Sub( + col1, column.Shift(col0, 3), + symbolic.Mul(2, column.Shift(col2, 2)), + symbolic.Neg(column.Shift(col3, 3)), + ), + ) + + } + + // initialProver + prover := func(run *wizard.ProverRuntime) { + run.AssignColumn("module.col0", smartvectors.ForTest(3, 0, 2, 1, 4, 1, 13, 0)) + run.AssignColumn("module.col1", smartvectors.ForTest(1, 7, 1, 11, 2, 1, 0, 2)) + run.AssignColumn("module.col2", smartvectors.ForTest(7, 0, 1, 3, 0, 4, 1, 0)) + run.AssignColumn("module.col3", smartvectors.ForTest(2, 14, 0, 2, 3, 0, 10, 0)) + + } + + // initial compiledIOP is the parent to all the SegmentModuleComp objects. + initialComp := wizard.Compile(define) + + // Initialize the period separating module discoverer + disc := &md.PeriodSeperatingModuleDiscoverer{} + disc.Analyze(initialComp) + + // distribute the columns among modules and segments. + moduleComp := distributed.GetFreshCompGL( + distributed.SegmentModuleInputs{ + InitialComp: initialComp, + Disc: disc, + ModuleName: "module", + NumSegmentsInModule: numSegModule, + }, + ) + + // distribute the query among segments. + global.DistributeGlobal(global.DistributionInputs{ + ModuleComp: moduleComp, + InitialComp: initialComp, + Disc: disc, + ModuleName: "module", + NumSegments: numSegModule, + }) + + // This dummy compiles the global/local queries of the segment. + wizard.ContinueCompilation(moduleComp, dummy.Compile) + + // run the initial runtime + initialRuntime := wizard.ProverOnlyFirstRound(initialComp, prover) + + // Compile and prove for module + for proverID := 0; proverID < numSegModule; proverID++ { + proof := wizard.Prove(moduleComp, func(run *wizard.ProverRuntime) { + run.ParentRuntime = initialRuntime + // inputs for vertical splitting of the witness + run.ProverID = proverID + }) + valid := wizard.Verify(moduleComp, proof) + require.NoError(t, valid) + } + +} diff --git a/prover/protocol/distributed/compiler/global/prover.go b/prover/protocol/distributed/compiler/global/prover.go new file mode 100644 index 000000000..4cc1ce0dc --- /dev/null +++ b/prover/protocol/distributed/compiler/global/prover.go @@ -0,0 +1,28 @@ +package global + +import ( + "github.com/consensys/linea-monorepo/prover/protocol/ifaces" + "github.com/consensys/linea-monorepo/prover/protocol/query" + "github.com/consensys/linea-monorepo/prover/protocol/wizard" +) + +type paramAssignments struct { + provider ifaces.Column + receiver ifaces.Column + providerOpenings []query.LocalOpening + receiverOpenings []query.LocalOpening +} + +// it assigns the LocalOpening for the segment. +func (pa paramAssignments) Run(run *wizard.ProverRuntime) { + var ( + providerWit = run.GetColumn(pa.provider.GetColID()).IntoRegVecSaveAlloc() + receiverWit = run.GetColumn(pa.receiver.GetColID()).IntoRegVecSaveAlloc() + ) + + for i := range pa.providerOpenings { + + run.AssignLocalPoint(pa.providerOpenings[i].ID, providerWit[i]) + run.AssignLocalPoint(pa.receiverOpenings[i].ID, receiverWit[i]) + } +} diff --git a/prover/protocol/distributed/conglomeration/cross_segment_consistency.go b/prover/protocol/distributed/conglomeration/cross_segment_consistency.go index 431aa9787..911c2ef06 100644 --- a/prover/protocol/distributed/conglomeration/cross_segment_consistency.go +++ b/prover/protocol/distributed/conglomeration/cross_segment_consistency.go @@ -10,7 +10,7 @@ import ( "github.com/consensys/linea-monorepo/prover/protocol/wizard" ) -// crossSegmentCheclk is a verifier action that performs cross-segment checks: +// crossSegmentCheck is a verifier action that performs cross-segment checks: // for instance, it checks that the log-derivative sums all sums to 0 and that // the grand product is 1. The goal is to ensure that the lookups, permutations // in the original protocol are satisfied. diff --git a/prover/protocol/distributed/namebaseddiscoverer/period_separating_module_discoverer.go b/prover/protocol/distributed/namebaseddiscoverer/period_separating_module_discoverer.go index 54dfc1c7d..2008db369 100644 --- a/prover/protocol/distributed/namebaseddiscoverer/period_separating_module_discoverer.go +++ b/prover/protocol/distributed/namebaseddiscoverer/period_separating_module_discoverer.go @@ -4,6 +4,7 @@ import ( "strings" "github.com/consensys/linea-monorepo/prover/protocol/coin" + "github.com/consensys/linea-monorepo/prover/protocol/column" "github.com/consensys/linea-monorepo/prover/protocol/column/verifiercol" "github.com/consensys/linea-monorepo/prover/protocol/distributed" "github.com/consensys/linea-monorepo/prover/protocol/ifaces" @@ -84,8 +85,12 @@ func (p *PeriodSeperatingModuleDiscoverer) QueryIsInModule(ifaces.Query, ModuleN // ColumnIsInModule checks that the given column is inside the given module. func (p *PeriodSeperatingModuleDiscoverer) ColumnIsInModule(col ifaces.Column, name ModuleName) bool { + colID := col.GetColID() + if shifted, ok := col.(column.Shifted); ok { + colID = shifted.Parent.GetColID() + } for _, c := range p.modules[name] { - if c.GetColID() == col.GetColID() { + if c.GetColID() == colID { return true } } @@ -95,6 +100,7 @@ func (p *PeriodSeperatingModuleDiscoverer) ColumnIsInModule(col ifaces.Column, n // ExpressionIsInModule checks that all the columns (except verifiercol) in the expression are from the given module. // // It does not check the presence of the coins and other metadata in the module. +// the restriction over verifier column comes from the fact that the discoverer Analyses compiledIOP and the verifier columns are not accessible there. func (p *PeriodSeperatingModuleDiscoverer) ExpressionIsInModule(expr *symbolic.Expression, name ModuleName) bool { var ( board = expr.Board() diff --git a/prover/protocol/wizard/verifier.go b/prover/protocol/wizard/verifier.go index e598b9f4f..4d91e1369 100644 --- a/prover/protocol/wizard/verifier.go +++ b/prover/protocol/wizard/verifier.go @@ -475,7 +475,7 @@ func (run VerifierRuntime) GetColumnAt(name ifaces.ColID, pos int) field.Element wit := run.Columns.MustGet(name) if pos >= wit.Len() || pos < 0 { - utils.Panic("asked pos %v for vector of size %v", pos, wit) + utils.Panic("asked pos %v for vector of size %v", pos, wit.Len()) } return wit.Get(pos) From f97cd001962e476f87d60dc731e9b3236e343e4f Mon Sep 17 00:00:00 2001 From: Arijit Dutta <37040536+arijitdutta67@users.noreply.github.com> Date: Fri, 21 Feb 2025 19:04:10 +0530 Subject: [PATCH 4/4] Prover/Complete Projection Query (#667) * added projection query * compiler for projection added * compiler added to arcane * minor fix added * fix lint error * duplicate name fix * removing cptHolter to math poly * shifted the no lint command * added bin file in gitignore * removed gitignore change * remove bin file * code simplification as Alex suggested * fix in lpp * distributed projection query added * more test cases added * adding slice to accomodate multiple query per module * param changed to support additive structure * compiler code added (wip) * check test added * incorporate Alex suggestion on PR 585 * Added documentation * vertical splitting wip * feat: cumNumOnes added * feat: vertical split verifier action added * feat : added filter counting and hash computing functionality * feat: vertical consistency check and cleanup --------- Signed-off-by: Arijit Dutta <37040536+arijitdutta67@users.noreply.github.com> --- .../accessors/from_distributedprojection.go | 71 +++ .../distributedprojection/compiler.go | 47 +- .../compiler/distributedprojection/prover.go | 24 +- .../distributedprojection/verifier.go | 212 +++++++-- prover/protocol/distributed/comp_splitting.go | 10 +- .../compiler/inclusion/inclusion.go | 4 +- .../compiler/inclusion/inclusion_test.go | 10 +- .../compiler/projection/projection.go | 413 ++++++++++++++---- .../compiler/projection/projection_test.go | 221 ++++++++-- .../distributed/conglomeration/translator.go | 2 +- .../distributed/constants/constant.go | 7 +- prover/protocol/distributed/lpp/lpp.go | 71 ++- prover/protocol/distributed/lpp/lpp_test.go | 2 +- .../protocol/query/distributed_projection.go | 79 +--- .../query/distributed_projection_test.go | 95 ---- prover/protocol/query/gnark_params.go | 5 + prover/protocol/wizard/prover.go | 6 +- prover/protocol/wizard/verifier.go | 23 +- 18 files changed, 921 insertions(+), 381 deletions(-) create mode 100644 prover/protocol/accessors/from_distributedprojection.go delete mode 100644 prover/protocol/query/distributed_projection_test.go diff --git a/prover/protocol/accessors/from_distributedprojection.go b/prover/protocol/accessors/from_distributedprojection.go new file mode 100644 index 000000000..0677cdd84 --- /dev/null +++ b/prover/protocol/accessors/from_distributedprojection.go @@ -0,0 +1,71 @@ +package accessors + +import ( + "fmt" + + "github.com/consensys/gnark/frontend" + "github.com/consensys/linea-monorepo/prover/maths/field" + "github.com/consensys/linea-monorepo/prover/protocol/ifaces" + "github.com/consensys/linea-monorepo/prover/protocol/query" + "github.com/consensys/linea-monorepo/prover/symbolic" +) + +const ( + DISTRIBUTED_PROJECTION_ACCESSOR = "DISTRIBUTED_PROJECTION_ACCESSOR" +) + +// FromDistributedProjectionAccessor implements [ifaces.Accessor] and accesses the result of +// a [query.DISTRIBUTED_PROJECTION]. +type FromDistributedProjectionAccessor struct { + // Q is the underlying query whose parameters are accessed by the current + // [ifaces.Accessor]. + Q query.DistributedProjection +} + +// NewDistributedProjectionAccessor creates an [ifaces.Accessor] returning the opening +// point of a [query.DISTRIBUTED_PROJECTION]. +func NewDistributedProjectionAccessor(q query.DistributedProjection) ifaces.Accessor { + return &FromDistributedProjectionAccessor{Q: q} +} + +// Name implements [ifaces.Accessor] +func (l *FromDistributedProjectionAccessor) Name() string { + return fmt.Sprintf("%v_%v", DISTRIBUTED_PROJECTION_ACCESSOR, l.Q.ID) +} + +// String implements [github.com/consensys/linea-monorepo/prover/symbolic.Metadata] +func (l *FromDistributedProjectionAccessor) String() string { + return l.Name() +} + +// GetVal implements [ifaces.Accessor] +func (l *FromDistributedProjectionAccessor) GetVal(run ifaces.Runtime) field.Element { + params := run.GetParams(l.Q.ID).(query.DistributedProjectionParams) + return params.ScaledHorner +} + +func (l *FromDistributedProjectionAccessor) GetValCumSumCurr(run ifaces.Runtime) field.Element { + params := run.GetParams(l.Q.ID).(query.DistributedProjectionParams) + return params.HashCumSumOneCurr +} + +func (l *FromDistributedProjectionAccessor) GetValCumSumPrev(run ifaces.Runtime) field.Element { + params := run.GetParams(l.Q.ID).(query.DistributedProjectionParams) + return params.HashCumSumOnePrev +} + +// GetFrontendVariable implements [ifaces.Accessor] +func (l *FromDistributedProjectionAccessor) GetFrontendVariable(_ frontend.API, circ ifaces.GnarkRuntime) frontend.Variable { + params := circ.GetParams(l.Q.ID).(query.GnarkDistributedProjectionParams) + return params.Sum +} + +// AsVariable implements the [ifaces.Accessor] interface +func (l *FromDistributedProjectionAccessor) AsVariable() *symbolic.Expression { + return symbolic.NewVariable(l) +} + +// Round implements the [ifaces.Accessor] interface +func (l *FromDistributedProjectionAccessor) Round() int { + return l.Q.Round +} diff --git a/prover/protocol/compiler/distributedprojection/compiler.go b/prover/protocol/compiler/distributedprojection/compiler.go index ab83b7e21..8c1c1a283 100644 --- a/prover/protocol/compiler/distributedprojection/compiler.go +++ b/prover/protocol/compiler/distributedprojection/compiler.go @@ -1,6 +1,9 @@ package distributedprojection import ( + "math/big" + + "github.com/consensys/linea-monorepo/prover/maths/field" "github.com/consensys/linea-monorepo/prover/protocol/coin" "github.com/consensys/linea-monorepo/prover/protocol/ifaces" "github.com/consensys/linea-monorepo/prover/protocol/query" @@ -20,7 +23,6 @@ func CompileDistributedProjection(comp *wizard.CompiledIOP) { // This ensures that the distributed projection query is not used again in the // compilation process. We know that the query was not already ignored at the beginning // because we are iterating over the unignored keys. - comp.QueriesParams.MarkAsIgnored(qName) round := comp.QueriesParams.Round(qName) compile(comp, round, distributedprojection) } @@ -29,29 +31,38 @@ func CompileDistributedProjection(comp *wizard.CompiledIOP) { func compile(comp *wizard.CompiledIOP, round int, distributedprojection query.DistributedProjection) { var ( pa = &distribuedProjectionProverAction{ - Name: distributedprojection.ID, - FilterA: make([]*symbolic.Expression, len(distributedprojection.Inp)), - FilterB: make([]*symbolic.Expression, len(distributedprojection.Inp)), - ColumnA: make([]*symbolic.Expression, len(distributedprojection.Inp)), - ColumnB: make([]*symbolic.Expression, len(distributedprojection.Inp)), - HornerA: make([]ifaces.Column, len(distributedprojection.Inp)), - HornerB: make([]ifaces.Column, len(distributedprojection.Inp)), - HornerA0: make([]query.LocalOpening, len(distributedprojection.Inp)), - HornerB0: make([]query.LocalOpening, len(distributedprojection.Inp)), - EvalCoin: make([]coin.Info, len(distributedprojection.Inp)), - IsA: make([]bool, len(distributedprojection.Inp)), - IsB: make([]bool, len(distributedprojection.Inp)), + Name: distributedprojection.ID, + Query: distributedprojection, + FilterA: make([]*symbolic.Expression, len(distributedprojection.Inp)), + FilterB: make([]*symbolic.Expression, len(distributedprojection.Inp)), + ColumnA: make([]*symbolic.Expression, len(distributedprojection.Inp)), + ColumnB: make([]*symbolic.Expression, len(distributedprojection.Inp)), + HornerA: make([]ifaces.Column, len(distributedprojection.Inp)), + HornerB: make([]ifaces.Column, len(distributedprojection.Inp)), + HornerA0: make([]query.LocalOpening, len(distributedprojection.Inp)), + HornerB0: make([]query.LocalOpening, len(distributedprojection.Inp)), + EvalCoins: make([]coin.Info, len(distributedprojection.Inp)), + IsA: make([]bool, len(distributedprojection.Inp)), + IsB: make([]bool, len(distributedprojection.Inp)), } ) pa.Push(comp, distributedprojection) pa.RegisterQueries(comp, round, distributedprojection) comp.RegisterProverAction(round, pa) comp.RegisterVerifierAction(round, &distributedProjectionVerifierAction{ - Name: pa.Name, - HornerA0: pa.HornerA0, - HornerB0: pa.HornerB0, - isA: pa.IsA, - isB: pa.IsB, + Name: pa.Name, + Query: pa.Query, + HornerA0: pa.HornerA0, + HornerB0: pa.HornerB0, + IsA: pa.IsA, + IsB: pa.IsB, + EvalCoins: pa.EvalCoins, + FilterA: pa.FilterA, + FilterB: pa.FilterB, + CumNumOnesPrevSegmentsA: make([]big.Int, len(distributedprojection.Inp)), + CumNumOnesPrevSegmentsB: make([]big.Int, len(distributedprojection.Inp)), + NumOnesCurrSegmentA: make([]field.Element, len(distributedprojection.Inp)), + NumOnesCurrSegmentB: make([]field.Element, len(distributedprojection.Inp)), }) } diff --git a/prover/protocol/compiler/distributedprojection/prover.go b/prover/protocol/compiler/distributedprojection/prover.go index ee7e49138..7406318b4 100644 --- a/prover/protocol/compiler/distributedprojection/prover.go +++ b/prover/protocol/compiler/distributedprojection/prover.go @@ -16,11 +16,12 @@ import ( type distribuedProjectionProverAction struct { Name ifaces.QueryID + Query query.DistributedProjection FilterA, FilterB []*sym.Expression ColumnA, ColumnB []*sym.Expression HornerA, HornerB []ifaces.Column HornerA0, HornerB0 []query.LocalOpening - EvalCoin []coin.Info + EvalCoins []coin.Info IsA, IsB []bool } @@ -33,14 +34,14 @@ type distribuedProjectionProverAction struct { // If IsA is false and IsB is true, it computes the Horner trace for column B only. // If neither IsA nor IsB is true, it panics with an error message indicating an invalid prover assignment. func (pa *distribuedProjectionProverAction) Run(run *wizard.ProverRuntime) { - for index := range pa.FilterA { + for index := range pa.Query.Inp { if pa.IsA[index] && pa.IsB[index] { var ( colA = column.EvalExprColumn(run, pa.ColumnA[index].Board()).IntoRegVecSaveAlloc() fA = column.EvalExprColumn(run, pa.FilterA[index].Board()).IntoRegVecSaveAlloc() colB = column.EvalExprColumn(run, pa.ColumnB[index].Board()).IntoRegVecSaveAlloc() fB = column.EvalExprColumn(run, pa.FilterB[index].Board()).IntoRegVecSaveAlloc() - x = run.GetRandomCoinField(pa.EvalCoin[index].Name) + x = run.GetRandomCoinField(pa.EvalCoins[index].Name) hornerA = poly.GetHornerTrace(colA, fA, x) hornerB = poly.GetHornerTrace(colB, fB, x) ) @@ -52,7 +53,7 @@ func (pa *distribuedProjectionProverAction) Run(run *wizard.ProverRuntime) { var ( colA = column.EvalExprColumn(run, pa.ColumnA[index].Board()).IntoRegVecSaveAlloc() fA = column.EvalExprColumn(run, pa.FilterA[index].Board()).IntoRegVecSaveAlloc() - x = run.GetRandomCoinField(pa.EvalCoin[index].Name) + x = run.GetRandomCoinField(pa.EvalCoins[index].Name) hornerA = poly.GetHornerTrace(colA, fA, x) ) run.AssignColumn(pa.HornerA[index].GetColID(), smartvectors.NewRegular(hornerA)) @@ -61,7 +62,7 @@ func (pa *distribuedProjectionProverAction) Run(run *wizard.ProverRuntime) { var ( colB = column.EvalExprColumn(run, pa.ColumnB[index].Board()).IntoRegVecSaveAlloc() fB = column.EvalExprColumn(run, pa.FilterB[index].Board()).IntoRegVecSaveAlloc() - x = run.GetRandomCoinField(pa.EvalCoin[index].Name) + x = run.GetRandomCoinField(pa.EvalCoins[index].Name) hornerB = poly.GetHornerTrace(colB, fB, x) ) run.AssignColumn(pa.HornerB[index].GetColID(), smartvectors.NewRegular(hornerB)) @@ -88,24 +89,21 @@ func (pa *distribuedProjectionProverAction) Push(comp *wizard.CompiledIOP, distr pa.FilterB[index] = input.FilterB pa.ColumnA[index] = input.ColumnA pa.ColumnB[index] = input.ColumnB - pa.EvalCoin[index] = comp.Coins.Data(input.EvalCoin) + pa.EvalCoins[index] = comp.Coins.Data(input.EvalCoin) pa.IsA[index] = true pa.IsB[index] = true - } else if input.IsAInModule && !input.IsBInModule { pa.FilterA[index] = input.FilterA pa.ColumnA[index] = input.ColumnA - pa.EvalCoin[index] = comp.Coins.Data(input.EvalCoin) + pa.EvalCoins[index] = comp.Coins.Data(input.EvalCoin) pa.IsA[index] = true pa.IsB[index] = false - } else if !input.IsAInModule && input.IsBInModule { pa.FilterB[index] = input.FilterB pa.ColumnB[index] = input.ColumnB - pa.EvalCoin[index] = comp.Coins.Data(input.EvalCoin) + pa.EvalCoins[index] = comp.Coins.Data(input.EvalCoin) pa.IsA[index] = false pa.IsB[index] = true - } else { logrus.Errorf("Invalid distributed projection query while pushing prover action entries: %v", distributedprojection.ID) } @@ -200,7 +198,7 @@ func (pa *distribuedProjectionProverAction) registerForCol( sym.Add( pa.ColumnA[index], sym.Mul( - pa.EvalCoin[index], + pa.EvalCoins[index], column.Shift(pa.HornerA[index], 1), ), ), @@ -232,7 +230,7 @@ func (pa *distribuedProjectionProverAction) registerForCol( ), sym.Mul( pa.FilterB[index], - sym.Add(pa.ColumnB[index], sym.Mul(pa.EvalCoin[index], column.Shift(pa.HornerB[index], 1))), + sym.Add(pa.ColumnB[index], sym.Mul(pa.EvalCoins[index], column.Shift(pa.HornerB[index], 1))), ), ), ) diff --git a/prover/protocol/compiler/distributedprojection/verifier.go b/prover/protocol/compiler/distributedprojection/verifier.go index 608339449..1a06e4172 100644 --- a/prover/protocol/compiler/distributedprojection/verifier.go +++ b/prover/protocol/compiler/distributedprojection/verifier.go @@ -1,23 +1,70 @@ package distributedprojection import ( + "fmt" + "math/big" + "github.com/consensys/gnark/frontend" + "github.com/consensys/linea-monorepo/prover/crypto/mimc" "github.com/consensys/linea-monorepo/prover/maths/field" + "github.com/consensys/linea-monorepo/prover/protocol/coin" + "github.com/consensys/linea-monorepo/prover/protocol/column" "github.com/consensys/linea-monorepo/prover/protocol/ifaces" "github.com/consensys/linea-monorepo/prover/protocol/query" "github.com/consensys/linea-monorepo/prover/protocol/wizard" + sym "github.com/consensys/linea-monorepo/prover/symbolic" "github.com/consensys/linea-monorepo/prover/utils" ) type distributedProjectionVerifierAction struct { - Name ifaces.QueryID - HornerA0, HornerB0 []query.LocalOpening - isA, isB []bool - skipped bool + Name ifaces.QueryID + Query query.DistributedProjection + HornerA0, HornerB0 []query.LocalOpening + IsA, IsB []bool + skipped bool + EvalCoins []coin.Info + CumNumOnesPrevSegmentsA []big.Int + CumNumOnesPrevSegmentsB []big.Int + NumOnesCurrSegmentA []field.Element + NumOnesCurrSegmentB []field.Element + FilterA, FilterB []*sym.Expression } // Run implements the [wizard.VerifierAction] func (va *distributedProjectionVerifierAction) Run(run wizard.Runtime) error { + for index, inp := range va.Query.Inp { + if va.IsA[index] && va.IsB[index] { + va.CumNumOnesPrevSegmentsA[index] = inp.CumulativeNumOnesPrevSegmentsA + va.NumOnesCurrSegmentA[index] = inp.CurrNumOnesA + va.CumNumOnesPrevSegmentsB[index] = inp.CumulativeNumOnesPrevSegmentsB + va.NumOnesCurrSegmentB[index] = inp.CurrNumOnesB + } else if va.IsA[index] && !va.IsB[index] { + va.CumNumOnesPrevSegmentsA[index] = inp.CumulativeNumOnesPrevSegmentsA + va.NumOnesCurrSegmentA[index] = inp.CurrNumOnesA + } else if !va.IsA[index] && va.IsB[index] { + va.CumNumOnesPrevSegmentsB[index] = inp.CumulativeNumOnesPrevSegmentsB + va.NumOnesCurrSegmentB[index] = inp.CurrNumOnesB + } + } + errorCheckHorner := va.scaledHornerCheck(run) + if errorCheckHorner != nil { + return errorCheckHorner + } + errorCheckCurrNumOnes := va.currSumOneCheck(run) + if errorCheckCurrNumOnes != nil { + return errorCheckCurrNumOnes + } + + errorCheckHash := va.hashCheck(run) + if errorCheckHash != nil { + return errorCheckHash + } + + return nil +} + +// method to check consistancy of the scaled horner +func (va *distributedProjectionVerifierAction) scaledHornerCheck(run wizard.Runtime) error { var ( actualParam = field.Zero() ) @@ -25,59 +72,152 @@ func (va *distributedProjectionVerifierAction) Run(run wizard.Runtime) error { var ( elemParam = field.Zero() ) - if va.isA[index] && va.isB[index] { - elemParam = run.GetLocalPointEvalParams(va.HornerB0[index].ID).Y - elemParam.Neg(&elemParam) - temp := run.GetLocalPointEvalParams(va.HornerA0[index].ID).Y - elemParam.Add(&elemParam, &temp) - } else if va.isA[index] && !va.isB[index] { - elemParam = run.GetLocalPointEvalParams(va.HornerA0[index].ID).Y - } else if !va.isA[index] && va.isB[index] { - elemParam = run.GetLocalPointEvalParams(va.HornerB0[index].ID).Y - elemParam.Neg(&elemParam) + if va.IsA[index] && va.IsB[index] { + var ( + multA, multB field.Element + hornerA, hornerB field.Element + ) + hornerA = run.GetLocalPointEvalParams(va.HornerA0[index].ID).Y + multA = run.GetRandomCoinField(va.EvalCoins[index].Name) + multA.Exp(multA, &va.CumNumOnesPrevSegmentsA[index]) + hornerA.Mul(&hornerA, &multA) + elemParam.Add(&elemParam, &hornerA) + + hornerB = run.GetLocalPointEvalParams(va.HornerB0[index].ID).Y + multB = run.GetRandomCoinField(va.EvalCoins[index].Name) + multB.Exp(multB, &va.CumNumOnesPrevSegmentsB[index]) + hornerB.Mul(&elemParam, &multB) + elemParam.Sub(&elemParam, &hornerB) + } else if va.IsA[index] && !va.IsB[index] { + var ( + multA field.Element + hornerA field.Element + ) + hornerA = run.GetLocalPointEvalParams(va.HornerA0[index].ID).Y + multA = run.GetRandomCoinField(va.EvalCoins[index].Name) + multA.Exp(multA, &va.CumNumOnesPrevSegmentsA[index]) + hornerA.Mul(&hornerA, &multA) + elemParam.Add(&elemParam, &hornerA) + } else if !va.IsA[index] && va.IsB[index] { + var ( + multB field.Element + hornerB field.Element + ) + hornerB = run.GetLocalPointEvalParams(va.HornerB0[index].ID).Y + multB = run.GetRandomCoinField(va.EvalCoins[index].Name) + multB.Exp(multB, &va.CumNumOnesPrevSegmentsB[index]) + hornerB.Mul(&hornerB, &multB) + elemParam.Sub(&elemParam, &hornerB) } else { utils.Panic("Unsupported verifier action registered for %v", va.Name) } actualParam.Add(&actualParam, &elemParam) } - queryParam := run.GetDistributedProjectionParams(va.Name).HornerVal + queryParam := run.GetDistributedProjectionParams(va.Name).ScaledHorner if actualParam != queryParam { utils.Panic("The distributed projection query %v did not pass, query param %v and actual param %v", va.Name, queryParam, actualParam) } return nil } -// RunGnark implements the [wizard.VerifierAction] interface. -func (va *distributedProjectionVerifierAction) RunGnark(api frontend.API, run wizard.GnarkRuntime) { +func (va *distributedProjectionVerifierAction) currSumOneCheck(run wizard.Runtime) error { + for index := range va.HornerA0 { + if va.IsA[index] && va.IsB[index] { + var ( + numOnesA = field.Zero() + numOnesB = field.Zero() + one = field.One() + ) + fA := column.EvalExprColumn(run, va.FilterA[index].Board()).IntoRegVecSaveAlloc() + for i := 0; i < len(fA); i++ { + if fA[i] == field.One() { + numOnesA.Add(&numOnesA, &one) + } + } + if numOnesA != va.NumOnesCurrSegmentA[index] { + return fmt.Errorf("number of one for filterA does not match, actual = %v, assigned = %v", numOnesA, va.NumOnesCurrSegmentA[index]) + } + fB := column.EvalExprColumn(run, va.FilterB[index].Board()).IntoRegVecSaveAlloc() + for i := 0; i < len(fB); i++ { + if fB[i] == field.One() { + numOnesB.Add(&numOnesB, &one) + } + } + if numOnesB != va.NumOnesCurrSegmentB[index] { + return fmt.Errorf("number of one for filterB does not match, actual = %v, assigned = %v", numOnesB, va.NumOnesCurrSegmentB[index]) + } + } + if va.IsA[index] && !va.IsB[index] { + var ( + numOnesA = field.Zero() + one = field.One() + ) + fA := column.EvalExprColumn(run, va.FilterA[index].Board()).IntoRegVecSaveAlloc() + for i := 0; i < len(fA); i++ { + if fA[i] == field.One() { + numOnesA.Add(&numOnesA, &one) + } + } + if numOnesA != va.NumOnesCurrSegmentA[index] { + return fmt.Errorf("number of one for filterA does not match, actual = %v, assigned = %v", numOnesA, va.NumOnesCurrSegmentA[index]) + } + } + if !va.IsA[index] && va.IsB[index] { + var ( + numOnesB = field.Zero() + one = field.One() + ) + fB := column.EvalExprColumn(run, va.FilterB[index].Board()).IntoRegVecSaveAlloc() + for i := 0; i < len(fB); i++ { + if fB[i] == field.One() { + numOnesB.Add(&numOnesB, &one) + } + } + if numOnesB != va.NumOnesCurrSegmentB[index] { + return fmt.Errorf("number of one for filterB does not match, actual = %v, assigned = %v", numOnesB, va.NumOnesCurrSegmentB[index]) + } + } + } + return nil +} +func (va *distributedProjectionVerifierAction) hashCheck(run wizard.Runtime) error { var ( - actualParam = frontend.Variable(0) + oldState = field.Zero() ) for index := range va.HornerA0 { - var ( - elemParam = frontend.Variable(0) - ) - if va.isA[index] && va.isB[index] { + if va.IsA[index] && va.IsB[index] { var ( - a, b frontend.Variable + sumA, sumB field.Element ) - a = run.GetLocalPointEvalParams(va.HornerA0[index].ID).Y - b = run.GetLocalPointEvalParams(va.HornerB0[index].ID).Y - elemParam = api.Sub(a, b) - } else if va.isA[index] && !va.isB[index] { - a := run.GetLocalPointEvalParams(va.HornerA0[index].ID).Y - elemParam = api.Add(elemParam, a) - } else if !va.isA[index] && va.isB[index] { - b := run.GetLocalPointEvalParams(va.HornerB0[index].ID).Y - elemParam = api.Sub(elemParam, b) + sumA = field.NewElement(va.CumNumOnesPrevSegmentsA[index].Uint64()) + sumA.Add(&sumA, &va.NumOnesCurrSegmentA[index]) + sumB = field.NewElement(va.CumNumOnesPrevSegmentsB[index].Uint64()) + sumB.Add(&sumB, &va.NumOnesCurrSegmentB[index]) + oldState = mimc.BlockCompression(oldState, sumA) + oldState = mimc.BlockCompression(oldState, sumB) + } else if va.IsA[index] && !va.IsB[index] { + sumA := field.NewElement(va.CumNumOnesPrevSegmentsA[index].Uint64()) + sumA.Add(&sumA, &va.NumOnesCurrSegmentA[index]) + oldState = mimc.BlockCompression(oldState, sumA) + } else if !va.IsA[index] && va.IsB[index] { + sumB := field.NewElement(va.CumNumOnesPrevSegmentsB[index].Uint64()) + sumB.Add(&sumB, &va.NumOnesCurrSegmentB[index]) + oldState = mimc.BlockCompression(oldState, sumB) } else { - utils.Panic("Unsupported verifier action registered for %v", va.Name) + panic("Invalid distributed projection query encountered during current hash verification") } - actualParam = api.Add(actualParam, elemParam) } - queryParam := run.GetDistributedProjectionParams(va.Name).Sum + if oldState != run.GetDistributedProjectionParams(va.Name).HashCumSumOneCurr { + return fmt.Errorf("HashCumSumOneCurr does not match, actual = %v, assigned = %v", oldState, run.GetDistributedProjectionParams(va.Name).HashCumSumOneCurr) + } + return nil +} + +// RunGnark implements the [wizard.VerifierAction] interface. +func (va *distributedProjectionVerifierAction) RunGnark(api frontend.API, run wizard.GnarkRuntime) { - api.AssertIsEqual(actualParam, queryParam) + panic("unimplemented") } // Skip implements the [wizard.VerifierAction] diff --git a/prover/protocol/distributed/comp_splitting.go b/prover/protocol/distributed/comp_splitting.go index 9100dda9d..0ca1402c3 100644 --- a/prover/protocol/distributed/comp_splitting.go +++ b/prover/protocol/distributed/comp_splitting.go @@ -45,8 +45,8 @@ func GetFreshSegmentModuleComp(in SegmentModuleInputs) *wizard.CompiledIOP { if !in.Disc.ColumnIsInModule(col, in.ModuleName) { continue } - - segModComp.InsertCommit(col.Round(), col.GetColID(), col.Size()/in.NumSegmentsInModule) + // Make colSize a power of two + segModComp.InsertCommit(col.Round(), col.GetColID(), utils.NextPowerOfTwo(col.Size()/in.NumSegmentsInModule)) columnsInRound = append(columnsInRound, col) } @@ -79,7 +79,7 @@ func (p segmentModuleProver) Run(run *wizard.ProverRuntime) { utils.Panic("invalid call: the runtime does not have a [ParentRuntime]") } if run.ProverID > p.numSegments { - panic("proverID can not be larger than number of segments") + panic("proverID cannot be larger than number of segments") } for _, col := range p.cols { @@ -92,7 +92,7 @@ func (p segmentModuleProver) Run(run *wizard.ProverRuntime) { } func getSegmentFromWitness(wit ifaces.ColAssignment, numSegs, segID int) ifaces.ColAssignment { - segSize := wit.Len() / numSegs + segSize := utils.NextPowerOfTwo(wit.Len() / numSegs) return wit.SubVector(segSize*segID, segSize*segID+segSize) } @@ -288,7 +288,7 @@ func assignProvider(run *wizard.ProverRuntime, segID, numSegments int, col iface lastRow = (segID+1)*segmentSize - 1 colWit []field.Element // number of boundaries from the current column - numBoundaries = 0 + numBoundaries int ) if shifted, ok := t.(column.Shifted); ok { diff --git a/prover/protocol/distributed/compiler/inclusion/inclusion.go b/prover/protocol/distributed/compiler/inclusion/inclusion.go index 855e25b27..3de8d9a2e 100644 --- a/prover/protocol/distributed/compiler/inclusion/inclusion.go +++ b/prover/protocol/distributed/compiler/inclusion/inclusion.go @@ -52,9 +52,9 @@ func DistributeLogDerivativeSum( continue } - // panic if there is more than a LogDerivativeSum query in the initialComp. + // panic if there is more than one LogDerivativeSum queries in the initialComp. if string(queryID) != "" { - utils.Panic("found more than a LogDerivativeSum query in the initialComp") + utils.Panic("found more than one LogDerivativeSum queries in the initialComp") } queryID = qName diff --git a/prover/protocol/distributed/compiler/inclusion/inclusion_test.go b/prover/protocol/distributed/compiler/inclusion/inclusion_test.go index e62939f58..c5c0ba8b7 100644 --- a/prover/protocol/distributed/compiler/inclusion/inclusion_test.go +++ b/prover/protocol/distributed/compiler/inclusion/inclusion_test.go @@ -183,9 +183,15 @@ func TestSeedGeneration(t *testing.T) { func checkConsistency(runs []wizard.Runtime) error { - var res field.Element + var ( + res field.Element + ) for _, run := range runs { - logderiv := run.GetPublicInput(constants.LogDerivativeSumPublicInput) + logderiv_ := run.GetPublicInput(constants.LogDerivativeSumPublicInput) + logderiv, ok := logderiv_.(field.Element) + if !ok { + return errors.New("the logderiv is not a field element") + } res.Add(&res, &logderiv) } diff --git a/prover/protocol/distributed/compiler/projection/projection.go b/prover/protocol/distributed/compiler/projection/projection.go index ddaf3d4b7..ffac65bfb 100644 --- a/prover/protocol/distributed/compiler/projection/projection.go +++ b/prover/protocol/distributed/compiler/projection/projection.go @@ -1,12 +1,17 @@ package dist_projection import ( + "math/big" + + "github.com/consensys/linea-monorepo/prover/crypto/mimc" "github.com/consensys/linea-monorepo/prover/maths/common/poly" "github.com/consensys/linea-monorepo/prover/maths/field" + "github.com/consensys/linea-monorepo/prover/protocol/accessors" "github.com/consensys/linea-monorepo/prover/protocol/coin" "github.com/consensys/linea-monorepo/prover/protocol/column" "github.com/consensys/linea-monorepo/prover/protocol/distributed" - "github.com/consensys/linea-monorepo/prover/protocol/distributed/namebaseddiscoverer" + "github.com/consensys/linea-monorepo/prover/protocol/distributed/constants" + discoverer "github.com/consensys/linea-monorepo/prover/protocol/distributed/namebaseddiscoverer" "github.com/consensys/linea-monorepo/prover/protocol/ifaces" "github.com/consensys/linea-monorepo/prover/protocol/query" "github.com/consensys/linea-monorepo/prover/protocol/wizard" @@ -22,8 +27,10 @@ const ( ) type DistributeProjectionCtx struct { + // List of all projection queries alloted to the segments DistProjectionInput []*query.DistributedProjectionInput - EvalCoins []coin.Info + // List of Evaluation Coins per query + EvalCoins []coin.Info // The module name for which we are processing the distributed projection query TargetModuleName string // Query stores the [query.DistributedProjection] generated by the compilation @@ -31,22 +38,29 @@ type DistributeProjectionCtx struct { // LastRoundPerm indicates the highest round at which a compiled projection // occurs. LastRoundProjection int + // number of segments per module + NumSegmentsPerModule int + // original projection queries to extract filters + QIds []query.Projection } // NewDistributeProjectionCtx processes all the projection queries from the initialComp // and registers DistributedProjection queries to the target module using the module // discoverer func NewDistributeProjectionCtx( - targetModuleName namebaseddiscoverer.ModuleName, + targetModuleName discoverer.ModuleName, initialComp, moduleComp *wizard.CompiledIOP, disc distributed.ModuleDiscoverer, + numSegmentPerModule int, ) *DistributeProjectionCtx { var ( p = &DistributeProjectionCtx{ - DistProjectionInput: make([]*query.DistributedProjectionInput, 0, MaxNumOfQueriesPerModule), - EvalCoins: make([]coin.Info, 0, MaxNumOfQueriesPerModule), - TargetModuleName: targetModuleName, - LastRoundProjection: getLastRoundPerm(initialComp), + DistProjectionInput: make([]*query.DistributedProjectionInput, 0, MaxNumOfQueriesPerModule), + EvalCoins: make([]coin.Info, 0, MaxNumOfQueriesPerModule), + TargetModuleName: targetModuleName, + LastRoundProjection: getLastRoundProjection(initialComp), + NumSegmentsPerModule: numSegmentPerModule, + QIds: make([]query.Projection, 0, MaxNumOfQueriesPerModule), } numRounds = initialComp.NumRounds() qId = p.QueryID() @@ -60,7 +74,7 @@ func NewDistributeProjectionCtx( */ for round := 0; round < numRounds; round++ { queries := initialComp.QueriesNoParams.AllKeysAt(round) - for queryInRound, qName := range queries { + for _, qName := range queries { // Skip if it was already compiled if initialComp.QueriesNoParams.IsIgnored(qName) { @@ -79,17 +93,13 @@ func NewDistributeProjectionCtx( if bothAAndB { check(q_.Inp.ColumnA, disc, targetModuleName) check(q_.Inp.ColumnB, disc, targetModuleName) - p.push(moduleComp, q_, round, queryInRound, true, true) - initialComp.QueriesNoParams.MarkAsIgnored(qName) - // Todo: Add panic if other cols are from other modules + p.push(moduleComp, q_, true, true) } else if onlyA { check(q_.Inp.ColumnA, disc, targetModuleName) - p.push(moduleComp, q_, round, queryInRound, true, false) - initialComp.QueriesNoParams.MarkAsIgnored(qName) + p.push(moduleComp, q_, true, false) } else if onlyB { check(q_.Inp.ColumnB, disc, targetModuleName) - p.push(moduleComp, q_, round, queryInRound, false, true) - initialComp.QueriesNoParams.MarkAsIgnored(qName) + p.push(moduleComp, q_, false, true) } else { continue } @@ -100,6 +110,13 @@ func NewDistributeProjectionCtx( p.Query = moduleComp.InsertDistributedProjection(p.LastRoundProjection+1, qId, p.DistProjectionInput) moduleComp.RegisterProverAction(p.LastRoundProjection+1, p) + + // declare [query.LogDerivSumParams] as [wizard.PublicInput] + moduleComp.PublicInputs = append(moduleComp.PublicInputs, + wizard.PublicInput{ + Name: constants.DistributedProjectionPublicInput, + Acc: accessors.NewDistributedProjectionAccessor(p.Query), + }) return p } @@ -107,7 +124,7 @@ func NewDistributeProjectionCtx( // Check verifies if all columns of the projection query belongs to the same module or not func check(cols []ifaces.Column, disc distributed.ModuleDiscoverer, - targetModuleName namebaseddiscoverer.ModuleName, + targetModuleName discoverer.ModuleName, ) error { for _, col := range cols { if disc.FindModule(col) != targetModuleName { @@ -118,11 +135,11 @@ func check(cols []ifaces.Column, } // push appends a new DistributedProjectionInput to the DistProjectionInput slice -func (p *DistributeProjectionCtx) push(comp *wizard.CompiledIOP, q query.Projection, round, queryInRound int, isA, isB bool) { +func (p *DistributeProjectionCtx) push(comp *wizard.CompiledIOP, q query.Projection, isA, isB bool) { var ( isMultiColumn = len(q.Inp.ColumnA) > 1 - alphaName = p.getCoinName("MERGING_COIN", round, queryInRound) - betaName = p.getCoinName("EVAL_COIN", round, queryInRound) + alphaName = coin.Namef("%v_%v", "MERGING_COIN"+string(q.ID), "FieldFromSeed") + betaName = coin.Namef("%v_%v", "EVAL_COIN"+string(q.ID), "FieldFromSeed") alpha coin.Info beta coin.Info ) @@ -131,50 +148,73 @@ func (p *DistributeProjectionCtx) push(comp *wizard.CompiledIOP, q query.Project if comp.Coins.Exists(alphaName) { alpha = comp.Coins.Data(alphaName) } else { - alpha = comp.InsertCoin(p.LastRoundProjection+1, alphaName, coin.Field) + alpha = comp.InsertCoin(p.LastRoundProjection+1, alphaName, coin.FieldFromSeed) } } - if comp.Coins.Exists(betaName) { beta = comp.Coins.Data(betaName) } else { - beta = comp.InsertCoin(p.LastRoundProjection+1, betaName, coin.Field) + beta = comp.InsertCoin(p.LastRoundProjection+1, betaName, coin.FieldFromSeed) } p.EvalCoins = append(p.EvalCoins, beta) + p.QIds = append(p.QIds, q) + // Push the DistributedProjectionInput if isA && isB { - fA, _, _ := wizardutils.AsExpr(q.Inp.FilterA) - fB, _, _ := wizardutils.AsExpr(q.Inp.FilterB) + fA, _, _ := wizardutils.AsExpr(comp.Columns.GetHandle(q.Inp.FilterA.GetColID())) + fB, _, _ := wizardutils.AsExpr(comp.Columns.GetHandle(q.Inp.FilterB.GetColID())) + var ( + colA = make([]ifaces.Column, 0, len(q.Inp.ColumnA)) + colB = make([]ifaces.Column, 0, len(q.Inp.ColumnB)) + ) + for i := 0; i < len(q.Inp.ColumnA); i++ { + colA = append(colA, comp.Columns.GetHandle(q.Inp.ColumnA[i].GetColID())) + } + for i := 0; i < len(q.Inp.ColumnB); i++ { + colB = append(colB, comp.Columns.GetHandle(q.Inp.ColumnB[i].GetColID())) + } p.DistProjectionInput = append(p.DistProjectionInput, &query.DistributedProjectionInput{ - ColumnA: wizardutils.RandLinCombColSymbolic(alpha, q.Inp.ColumnA), - ColumnB: wizardutils.RandLinCombColSymbolic(alpha, q.Inp.ColumnB), + ColumnA: wizardutils.RandLinCombColSymbolic(alpha, colA), + ColumnB: wizardutils.RandLinCombColSymbolic(alpha, colB), FilterA: fA, FilterB: fB, - SizeA: q.Inp.FilterA.Size(), - SizeB: q.Inp.FilterB.Size(), + SizeA: comp.Columns.GetSize(q.Inp.FilterA.GetColID()), + SizeB: comp.Columns.GetSize(q.Inp.FilterB.GetColID()), EvalCoin: beta.Name, IsAInModule: true, IsBInModule: true, }) } else if isA { - fA, _, _ := wizardutils.AsExpr(q.Inp.FilterA) + fA, _, _ := wizardutils.AsExpr(comp.Columns.GetHandle(q.Inp.FilterA.GetColID())) + var ( + colA = make([]ifaces.Column, 0, len(q.Inp.ColumnA)) + ) + for i := 0; i < len(q.Inp.ColumnA); i++ { + colA = append(colA, comp.Columns.GetHandle(q.Inp.ColumnA[i].GetColID())) + } p.DistProjectionInput = append(p.DistProjectionInput, &query.DistributedProjectionInput{ - ColumnA: wizardutils.RandLinCombColSymbolic(alpha, q.Inp.ColumnA), + ColumnA: wizardutils.RandLinCombColSymbolic(alpha, colA), ColumnB: symbolic.NewConstant(1), FilterA: fA, FilterB: symbolic.NewConstant(1), - SizeA: q.Inp.FilterA.Size(), + SizeA: comp.Columns.GetSize(q.Inp.FilterA.GetColID()), EvalCoin: beta.Name, IsAInModule: true, IsBInModule: false, }) } else if isB { - fB, _, _ := wizardutils.AsExpr(q.Inp.FilterB) + fB, _, _ := wizardutils.AsExpr(comp.Columns.GetHandle(q.Inp.FilterB.GetColID())) + var ( + colB = make([]ifaces.Column, 0, len(q.Inp.ColumnB)) + ) + for i := 0; i < len(q.Inp.ColumnB); i++ { + colB = append(colB, comp.Columns.GetHandle(q.Inp.ColumnB[i].GetColID())) + } p.DistProjectionInput = append(p.DistProjectionInput, &query.DistributedProjectionInput{ ColumnA: symbolic.NewConstant(1), - ColumnB: wizardutils.RandLinCombColSymbolic(alpha, q.Inp.ColumnB), + ColumnB: wizardutils.RandLinCombColSymbolic(alpha, colB), FilterA: symbolic.NewConstant(1), FilterB: fB, - SizeB: q.Inp.FilterB.Size(), + SizeB: comp.Columns.GetSize(q.Inp.FilterB.GetColID()), EvalCoin: beta.Name, IsAInModule: false, IsBInModule: true, @@ -184,48 +224,232 @@ func (p *DistributeProjectionCtx) push(comp *wizard.CompiledIOP, q query.Project } } -// computeQueryParam computes the parameter of the DistributedProjection query -func (p *DistributeProjectionCtx) computeQueryParam(run *wizard.ProverRuntime) field.Element { +// Run implements [wizard.ProverAction] interface +func (p *DistributeProjectionCtx) Run(run *wizard.ProverRuntime) { + p.assignSumNumOnes(run) + run.AssignDistributedProjection(p.Query.ID, query.DistributedProjectionParams{ + ScaledHorner: p.computeScaledHorner(run), + HashCumSumOnePrev: p.computeHashPrev(), + HashCumSumOneCurr: p.computeHashCurr(), + }) +} + +// assignSumNumOnes assigns the cumulative number of ones in the previous segments as well as the current segment +func (p *DistributeProjectionCtx) assignSumNumOnes(run *wizard.ProverRuntime) { + var ( + initialRuntime = run.ParentRuntime + segId = run.ProverID + one = field.One() + bigOne = big.NewInt(1) + ) + for elemIndex, inp := range p.DistProjectionInput { + if inp.IsAInModule && inp.IsBInModule { + var ( + fA = initialRuntime.GetColumn(p.QIds[elemIndex].Inp.FilterA.GetColID()) + fB = initialRuntime.GetColumn(p.QIds[elemIndex].Inp.FilterB.GetColID()) + segSizeA = utils.NextPowerOfTwo(fA.Len() / p.NumSegmentsPerModule) + segSizeB = utils.NextPowerOfTwo(fB.Len() / p.NumSegmentsPerModule) + numOnesCurrA = field.Zero() + numOnesCurrB = field.Zero() + cumSumOnesPrevA = *big.NewInt(0) + cumSumOnesPrevB = *big.NewInt(0) + ) + if segId == 0 { + var ( + fACurr = fA.SubVector(segId*segSizeA, (segId+1)*segSizeA).IntoRegVecSaveAlloc() + fBCurr = fB.SubVector(segId*segSizeB, (segId+1)*segSizeB).IntoRegVecSaveAlloc() + ) + for i := 0; i < len(fACurr); i++ { + if fACurr[i] == one { + numOnesCurrA.Add(&numOnesCurrA, &one) + } + } + for i := 0; i < len(fBCurr); i++ { + if fBCurr[i] == one { + numOnesCurrB.Add(&numOnesCurrB, &one) + } + } + p.DistProjectionInput[elemIndex].CumulativeNumOnesPrevSegmentsA = cumSumOnesPrevA + p.DistProjectionInput[elemIndex].CumulativeNumOnesPrevSegmentsB = cumSumOnesPrevB + p.DistProjectionInput[elemIndex].CurrNumOnesA = numOnesCurrA + p.DistProjectionInput[elemIndex].CurrNumOnesB = numOnesCurrB + } else { + var ( + fAPrev = fA.SubVector(0, segId*segSizeA).IntoRegVecSaveAlloc() + fBPrev = fB.SubVector(0, segId*segSizeB).IntoRegVecSaveAlloc() + fACurr = fA.SubVector(segId*segSizeA, (segId+1)*segSizeA).IntoRegVecSaveAlloc() + fBCurr = fB.SubVector(segId*segSizeB, (segId+1)*segSizeB).IntoRegVecSaveAlloc() + ) + for i := 0; i < len(fAPrev); i++ { + if fAPrev[i] == one { + cumSumOnesPrevA.Add(&cumSumOnesPrevA, bigOne) + } + } + for i := 0; i < len(fBPrev); i++ { + if fBPrev[i] == one { + cumSumOnesPrevB.Add(&cumSumOnesPrevB, bigOne) + } + } + for i := 0; i < len(fACurr); i++ { + if fACurr[i] == one { + numOnesCurrA.Add(&numOnesCurrA, &one) + } + } + for i := 0; i < len(fBCurr); i++ { + if fBCurr[i] == one { + numOnesCurrB.Add(&numOnesCurrB, &one) + } + } + p.DistProjectionInput[elemIndex].CumulativeNumOnesPrevSegmentsA = cumSumOnesPrevA + p.DistProjectionInput[elemIndex].CumulativeNumOnesPrevSegmentsB = cumSumOnesPrevB + p.DistProjectionInput[elemIndex].CurrNumOnesA = numOnesCurrA + p.DistProjectionInput[elemIndex].CurrNumOnesB = numOnesCurrB + } + } + if inp.IsAInModule && !inp.IsBInModule { + var ( + fA = initialRuntime.GetColumn(p.QIds[elemIndex].Inp.FilterA.GetColID()) + segSizeA = utils.NextPowerOfTwo(fA.Len() / p.NumSegmentsPerModule) + numOnesCurrA = field.Zero() + cumSumOnesPrevA = *big.NewInt(0) + ) + if segId == 0 { + var ( + fACurr = fA.SubVector(segId*segSizeA, (segId+1)*segSizeA).IntoRegVecSaveAlloc() + ) + + for i := 0; i < len(fACurr); i++ { + if fACurr[i] == one { + numOnesCurrA.Add(&numOnesCurrA, &one) + } + } + p.DistProjectionInput[elemIndex].CumulativeNumOnesPrevSegmentsA = cumSumOnesPrevA + p.DistProjectionInput[elemIndex].CurrNumOnesA = numOnesCurrA + } else { + var ( + fAPrev = fA.SubVector(0, segId*segSizeA).IntoRegVecSaveAlloc() + fACurr = fA.SubVector(segId*segSizeA, (segId+1)*segSizeA).IntoRegVecSaveAlloc() + ) + for i := 0; i < len(fAPrev); i++ { + if fAPrev[i] == one { + cumSumOnesPrevA.Add(&cumSumOnesPrevA, bigOne) + } + } + for i := 0; i < len(fACurr); i++ { + if fACurr[i] == one { + numOnesCurrA.Add(&numOnesCurrA, &one) + } + } + p.DistProjectionInput[elemIndex].CumulativeNumOnesPrevSegmentsA = cumSumOnesPrevA + p.DistProjectionInput[elemIndex].CurrNumOnesA = numOnesCurrA + } + } + if !inp.IsAInModule && inp.IsBInModule { + var ( + fB = initialRuntime.GetColumn(p.QIds[elemIndex].Inp.FilterB.GetColID()) + segSizeB = utils.NextPowerOfTwo(fB.Len() / p.NumSegmentsPerModule) + numOnesCurrB field.Element + cumSumOnesPrevB = *big.NewInt(0) + ) + if segId == 0 { + var ( + fBCurr = fB.SubVector(segId*segSizeB, (segId+1)*segSizeB).IntoRegVecSaveAlloc() + ) + + for i := 0; i < len(fBCurr); i++ { + if fBCurr[i] == one { + numOnesCurrB.Add(&numOnesCurrB, &one) + } + } + p.DistProjectionInput[elemIndex].CumulativeNumOnesPrevSegmentsB = cumSumOnesPrevB + p.DistProjectionInput[elemIndex].CurrNumOnesB = numOnesCurrB + } else { + var ( + fBPrev = fB.SubVector(0, segId*segSizeB).IntoRegVecSaveAlloc() + fBCurr = fB.SubVector(segId*segSizeB, (segId+1)*segSizeB).IntoRegVecSaveAlloc() + ) + for i := 0; i < len(fBPrev); i++ { + if fBPrev[i] == one { + cumSumOnesPrevB.Add(&cumSumOnesPrevB, bigOne) + } + } + for i := 0; i < len(fBCurr); i++ { + if fBCurr[i] == one { + numOnesCurrB.Add(&numOnesCurrB, &one) + } + } + p.DistProjectionInput[elemIndex].CumulativeNumOnesPrevSegmentsB = cumSumOnesPrevB + p.DistProjectionInput[elemIndex].CurrNumOnesB = numOnesCurrB + } + } + } +} + +// computeScaledHorner computes the parameter of the DistributedProjection query +func (p *DistributeProjectionCtx) computeScaledHorner(run *wizard.ProverRuntime) field.Element { var ( queryParam = field.Zero() - elemParam = field.Zero() ) for elemIndex, inp := range p.DistProjectionInput { + var ( + elemParam = field.Zero() + ) if inp.IsAInModule && inp.IsBInModule { var ( - colABoard = inp.ColumnA.Board() - colBBoard = inp.ColumnB.Board() - filterABorad = inp.FilterA.Board() - filterBBoard = inp.FilterB.Board() - colA = column.EvalExprColumn(run, colABoard).IntoRegVecSaveAlloc() - colB = column.EvalExprColumn(run, colBBoard).IntoRegVecSaveAlloc() - filterA = column.EvalExprColumn(run, filterABorad).IntoRegVecSaveAlloc() - filterB = column.EvalExprColumn(run, filterBBoard).IntoRegVecSaveAlloc() + colABoard = inp.ColumnA.Board() + colBBoard = inp.ColumnB.Board() + filterABorad = inp.FilterA.Board() + filterBBoard = inp.FilterB.Board() + colA = column.EvalExprColumn(run, colABoard).IntoRegVecSaveAlloc() + colB = column.EvalExprColumn(run, colBBoard).IntoRegVecSaveAlloc() + filterA = column.EvalExprColumn(run, filterABorad).IntoRegVecSaveAlloc() + filterB = column.EvalExprColumn(run, filterBBoard).IntoRegVecSaveAlloc() + multA, multB, hornerA, hornerB field.Element ) - hornerA := poly.GetHornerTrace(colA, filterA, run.GetRandomCoinField(p.EvalCoins[elemIndex].Name)) - hornerB := poly.GetHornerTrace(colB, filterB, run.GetRandomCoinField(p.EvalCoins[elemIndex].Name)) - elemParam = hornerB[0] - elemParam.Neg(&elemParam) - elemParam.Add(&elemParam, &hornerA[0]) + // Add hornerA after scaling + hornerATrace := poly.GetHornerTrace(colA, filterA, run.GetRandomCoinField(p.EvalCoins[elemIndex].Name)) + multA = run.GetRandomCoinField(p.EvalCoins[elemIndex].Name) + multA.Exp(multA, &inp.CumulativeNumOnesPrevSegmentsA) + hornerA = hornerATrace[0] + hornerA.Mul(&hornerA, &multA) + elemParam.Add(&elemParam, &hornerA) + // Subtract hornerB after scaling + hornerBTrace := poly.GetHornerTrace(colB, filterB, run.GetRandomCoinField(p.EvalCoins[elemIndex].Name)) + hornerB = hornerBTrace[0] + multB = run.GetRandomCoinField(p.EvalCoins[elemIndex].Name) + multB.Exp(multB, &inp.CumulativeNumOnesPrevSegmentsB) + hornerB.Mul(&hornerB, &multB) + elemParam.Sub(&elemParam, &hornerB) } else if inp.IsAInModule && !inp.IsBInModule { var ( - colABoard = inp.ColumnA.Board() - filterABorad = inp.FilterA.Board() - colA = column.EvalExprColumn(run, colABoard).IntoRegVecSaveAlloc() - filterA = column.EvalExprColumn(run, filterABorad).IntoRegVecSaveAlloc() + colABoard = inp.ColumnA.Board() + filterABorad = inp.FilterA.Board() + colA = column.EvalExprColumn(run, colABoard).IntoRegVecSaveAlloc() + filterA = column.EvalExprColumn(run, filterABorad).IntoRegVecSaveAlloc() + multA, hornerA field.Element ) - hornerA := poly.GetHornerTrace(colA, filterA, run.GetRandomCoinField(p.EvalCoins[elemIndex].Name)) - elemParam = hornerA[0] + // Add hornerA after scaling + hornerATrace := poly.GetHornerTrace(colA, filterA, run.GetRandomCoinField(p.EvalCoins[elemIndex].Name)) + multA = run.GetRandomCoinField(p.EvalCoins[elemIndex].Name) + multA.Exp(multA, &inp.CumulativeNumOnesPrevSegmentsA) + hornerA = hornerATrace[0] + hornerA.Mul(&hornerA, &multA) + elemParam.Add(&elemParam, &hornerA) } else if !inp.IsAInModule && inp.IsBInModule { var ( - colBBoard = inp.ColumnB.Board() - filterBBorad = inp.FilterB.Board() - colB = column.EvalExprColumn(run, colBBoard).IntoRegVecSaveAlloc() - filterB = column.EvalExprColumn(run, filterBBorad).IntoRegVecSaveAlloc() + colBBoard = inp.ColumnB.Board() + filterBBorad = inp.FilterB.Board() + colB = column.EvalExprColumn(run, colBBoard).IntoRegVecSaveAlloc() + filterB = column.EvalExprColumn(run, filterBBorad).IntoRegVecSaveAlloc() + multB, hornerB field.Element ) - hornerB := poly.GetHornerTrace(colB, filterB, run.GetRandomCoinField(p.EvalCoins[elemIndex].Name)) - elemParam = hornerB[0] - elemParam.Neg(&elemParam) + // Subtract hornerB after scaling + hornerBTrace := poly.GetHornerTrace(colB, filterB, run.GetRandomCoinField(p.EvalCoins[elemIndex].Name)) + hornerB = hornerBTrace[0] + multB = run.GetRandomCoinField(p.EvalCoins[elemIndex].Name) + multB.Exp(multB, &inp.CumulativeNumOnesPrevSegmentsB) + hornerB.Mul(&hornerB, &multB) + elemParam.Sub(&elemParam, &hornerB) } else { panic("Invalid distributed projection query encountered during param evaluation") } @@ -234,11 +458,56 @@ func (p *DistributeProjectionCtx) computeQueryParam(run *wizard.ProverRuntime) f return queryParam } -// Run implements [wizard.ProverAction] interface -func (p *DistributeProjectionCtx) Run(run *wizard.ProverRuntime) { - run.AssignDistributedProjection(p.Query.ID, query.DistributedProjectionParams{ - HornerVal: p.computeQueryParam(run), - }) +// computeHashPrev computes the hash of the cumulative number of ones in the previous segments +func (p *DistributeProjectionCtx) computeHashPrev() field.Element { + var ( + oldState = field.Zero() + ) + for _, inp := range p.DistProjectionInput { + if inp.IsAInModule && inp.IsBInModule { + sumA := field.NewElement(inp.CumulativeNumOnesPrevSegmentsA.Uint64()) + oldState = mimc.BlockCompression(oldState, sumA) + sumB := field.NewElement(inp.CumulativeNumOnesPrevSegmentsB.Uint64()) + oldState = mimc.BlockCompression(oldState, sumB) + } else if inp.IsAInModule && !inp.IsBInModule { + sumA := field.NewElement(inp.CumulativeNumOnesPrevSegmentsA.Uint64()) + oldState = mimc.BlockCompression(oldState, sumA) + } else if !inp.IsAInModule && inp.IsBInModule { + sumB := field.NewElement(inp.CumulativeNumOnesPrevSegmentsB.Uint64()) + oldState = mimc.BlockCompression(oldState, sumB) + } else { + panic("Invalid distributed projection query encountered during previous hash computation") + } + } + return oldState +} + +// computeHashCurr computes the hash of the cumulative number of ones in the current segment +func (p *DistributeProjectionCtx) computeHashCurr() field.Element { + var ( + oldState = field.Zero() + ) + for _, inp := range p.DistProjectionInput { + if inp.IsAInModule && inp.IsBInModule { + sumA := field.NewElement(inp.CumulativeNumOnesPrevSegmentsA.Uint64()) + sumA.Add(&sumA, &inp.CurrNumOnesA) + oldState = mimc.BlockCompression(oldState, sumA) + sumB := field.NewElement(inp.CumulativeNumOnesPrevSegmentsB.Uint64()) + sumB.Add(&sumB, &inp.CurrNumOnesB) + oldState = mimc.BlockCompression(oldState, sumB) + } else if inp.IsAInModule && !inp.IsBInModule { + sumA := field.NewElement(inp.CumulativeNumOnesPrevSegmentsA.Uint64()) + sumA.Add(&sumA, &inp.CurrNumOnesA) + oldState = mimc.BlockCompression(oldState, sumA) + } else if !inp.IsAInModule && inp.IsBInModule { + sumB := field.NewElement(inp.CumulativeNumOnesPrevSegmentsB.Uint64()) + sumB.Add(&sumB, &inp.CurrNumOnesB) + oldState = mimc.BlockCompression(oldState, sumB) + } else { + panic("Invalid distributed projection query encountered during current hash computation") + } + } + return oldState } // deriveName constructs a name for the DistributeProjectionCtx context @@ -252,13 +521,9 @@ func (p *DistributeProjectionCtx) QueryID() ifaces.QueryID { return deriveName[ifaces.QueryID](ifaces.QueryID(p.TargetModuleName)) } -func (p *DistributeProjectionCtx) getCoinName(name string, round, queryInRound int) coin.Name { - return deriveName[coin.Name](p.QueryID(), name, round, queryInRound) -} - -// getLastRoundPerm scans the initialComp and looks for uncompiled projection queries. It returns +// getLastRoundProjection scans the initialComp and looks for uncompiled projection queries. It returns // the highest round found for a matched projection query. It returns -1 if no queries are found. -func getLastRoundPerm(initialComp *wizard.CompiledIOP) int { +func getLastRoundProjection(initialComp *wizard.CompiledIOP) int { var ( lastRound = -1 diff --git a/prover/protocol/distributed/compiler/projection/projection_test.go b/prover/protocol/distributed/compiler/projection/projection_test.go index ae6de7bcd..c9da25f4e 100644 --- a/prover/protocol/distributed/compiler/projection/projection_test.go +++ b/prover/protocol/distributed/compiler/projection/projection_test.go @@ -1,6 +1,7 @@ package dist_projection_test import ( + "errors" "testing" "github.com/consensys/linea-monorepo/prover/maths/common/smartvectors" @@ -9,16 +10,32 @@ import ( "github.com/consensys/linea-monorepo/prover/protocol/compiler/dummy" "github.com/consensys/linea-monorepo/prover/protocol/distributed" dist_projection "github.com/consensys/linea-monorepo/prover/protocol/distributed/compiler/projection" - "github.com/consensys/linea-monorepo/prover/protocol/distributed/namebaseddiscoverer" + "github.com/consensys/linea-monorepo/prover/protocol/distributed/constants" + "github.com/consensys/linea-monorepo/prover/protocol/distributed/lpp" + md "github.com/consensys/linea-monorepo/prover/protocol/distributed/namebaseddiscoverer" "github.com/consensys/linea-monorepo/prover/protocol/ifaces" "github.com/consensys/linea-monorepo/prover/protocol/query" "github.com/consensys/linea-monorepo/prover/protocol/wizard" "github.com/stretchr/testify/require" ) +type AllVerifierRuntimes struct { + RuntimesA []wizard.Runtime + RuntimesB []wizard.Runtime + RuntimesC []wizard.Runtime +} + func TestDistributeProjection(t *testing.T) { + const ( + numSegModuleA = 4 + numSegModuleB = 4 + numSegModuleC = 4 + ) var ( + allVerfiers = AllVerifierRuntimes{} moduleAName = "moduleA" + moduleBName = "moduleB" + moduleCName = "moduleC" flagSizeA = 512 flagSizeB = 256 flagA, flagB, columnA, columnB, flagC, columnC ifaces.Column @@ -30,44 +47,10 @@ func TestDistributeProjection(t *testing.T) { InitialProverFunc func(run *wizard.ProverRuntime) }{ { - Name: "distribute-projection-both-A-and-B", - DefineFunc: func(builder *wizard.Builder) { - flagA = builder.RegisterCommit(ifaces.ColID("moduleA.FilterA"), flagSizeA) - flagB = builder.RegisterCommit(ifaces.ColID("moduleA.FliterB"), flagSizeB) - columnA = builder.RegisterCommit(ifaces.ColID("moduleA.ColumnA"), flagSizeA) - columnB = builder.RegisterCommit(ifaces.ColID("moduleA.ColumnB"), flagSizeB) - _ = builder.InsertProjection("ProjectionTest-both-A-and-B", - query.ProjectionInput{ColumnA: []ifaces.Column{columnA}, ColumnB: []ifaces.Column{columnB}, FilterA: flagA, FilterB: flagB}) - - }, - InitialProverFunc: func(run *wizard.ProverRuntime) { - // assign filters and columns - var ( - flagAWit = make([]field.Element, flagSizeA) - columnAWit = make([]field.Element, flagSizeA) - flagBWit = make([]field.Element, flagSizeB) - columnBWit = make([]field.Element, flagSizeB) - ) - for i := 0; i < 10; i++ { - flagAWit[i] = field.One() - columnAWit[i] = field.NewElement(uint64(i)) - } - for i := flagSizeB - 10; i < flagSizeB; i++ { - flagBWit[i] = field.One() - columnBWit[i] = field.NewElement(uint64(i - (flagSizeB - 10))) - } - run.AssignColumn(flagA.GetColID(), smartvectors.RightZeroPadded(flagAWit, flagSizeA)) - run.AssignColumn(flagB.GetColID(), smartvectors.RightZeroPadded(flagBWit, flagSizeB)) - run.AssignColumn(columnB.GetColID(), smartvectors.RightZeroPadded(columnBWit, flagSizeB)) - run.AssignColumn(columnA.GetColID(), smartvectors.RightZeroPadded(columnAWit, flagSizeA)) - }, - }, - { - Name: "distribute-projection-multiple_projections", DefineFunc: func(builder *wizard.Builder) { flagA = builder.RegisterCommit(ifaces.ColID("moduleA.FilterA"), flagSizeA) flagB = builder.RegisterCommit(ifaces.ColID("moduleB.FliterB"), flagSizeB) - flagC = builder.RegisterCommit(ifaces.ColID("moduleC.FliterB"), flagSizeB) + flagC = builder.RegisterCommit(ifaces.ColID("moduleC.FliterC"), flagSizeB) columnA = builder.RegisterCommit(ifaces.ColID("moduleA.ColumnA"), flagSizeA) columnB = builder.RegisterCommit(ifaces.ColID("moduleB.ColumnB"), flagSizeB) columnC = builder.RegisterCommit(ifaces.ColID("moduleC.ColumnC"), flagSizeB) @@ -176,34 +159,176 @@ func TestDistributeProjection(t *testing.T) { for _, tc := range testcases { t.Run(tc.Name, func(t *testing.T) { - // This function assigns the initial module and is aimed at working - // for all test-case. - initialProve := tc.InitialProverFunc // initialComp is defined according to the define function provided by the // test-case. initialComp := wizard.Compile(tc.DefineFunc) - disc := namebaseddiscoverer.PeriodSeperatingModuleDiscoverer{} + // apply the LPP relevant compilers and generate the seed for initialComp + lppComp := lpp.CompileLPPAndGetSeed(initialComp) + + // Initialize the period separating module discoverer + disc := &md.PeriodSeperatingModuleDiscoverer{} disc.Analyze(initialComp) - // This declares a compiled IOP with only the columns of the module A - moduleAComp := distributed.GetFreshModuleComp(initialComp, &disc, moduleAName) - dist_projection.NewDistributeProjectionCtx(moduleAName, initialComp, moduleAComp, &disc) + // distribute the columns among modules and segments; this includes also multiplicity columns + // for all the segments from the same module, compiledIOP object is the same. + moduleCompA := distributed.GetFreshSegmentModuleComp( + distributed.SegmentModuleInputs{ + InitialComp: initialComp, + Disc: disc, + ModuleName: moduleAName, + NumSegmentsInModule: numSegModuleA, + }, + ) + + moduleCompB := distributed.GetFreshSegmentModuleComp(distributed.SegmentModuleInputs{ + InitialComp: initialComp, + Disc: disc, + ModuleName: moduleBName, + NumSegmentsInModule: numSegModuleB, + }) + + moduleCompC := distributed.GetFreshSegmentModuleComp(distributed.SegmentModuleInputs{ + InitialComp: initialComp, + Disc: disc, + ModuleName: moduleCName, + NumSegmentsInModule: numSegModuleC, + }) + + // distribute the query LogDerivativeSum among modules. + // The seed is used to generate randomness for each moduleComp. + dist_projection.NewDistributeProjectionCtx(moduleAName, initialComp, moduleCompA, disc, numSegModuleA) + dist_projection.NewDistributeProjectionCtx(moduleBName, initialComp, moduleCompB, disc, numSegModuleB) + dist_projection.NewDistributeProjectionCtx(moduleCName, initialComp, moduleCompC, disc, numSegModuleC) - wizard.ContinueCompilation(moduleAComp, distributedprojection.CompileDistributedProjection, dummy.CompileAtProverLvl) + // This compiles the log-derivative queries into global/local queries. + wizard.ContinueCompilation(moduleCompA, distributedprojection.CompileDistributedProjection, dummy.Compile) + wizard.ContinueCompilation(moduleCompB, distributedprojection.CompileDistributedProjection, dummy.Compile) + wizard.ContinueCompilation(moduleCompC, distributedprojection.CompileDistributedProjection, dummy.Compile) - // This runs the initial prover - initialRuntime := wizard.RunProver(initialComp, initialProve) + // run the initial runtime + initialRuntime := wizard.ProverOnlyFirstRound(initialComp, tc.InitialProverFunc) - proof := wizard.Prove(moduleAComp, func(run *wizard.ProverRuntime) { + // compile and verify for lpp-Prover + lppProof := wizard.Prove(lppComp, func(run *wizard.ProverRuntime) { run.ParentRuntime = initialRuntime }) - valid := wizard.Verify(moduleAComp, proof) + lppVerifierRuntime, valid := wizard.VerifyWithRuntime(lppComp, lppProof) require.NoError(t, valid) + // Compile and prove for moduleA + allVerfiers.RuntimesA = make([]wizard.Runtime, 0, numSegModuleA) + for proverID := 0; proverID < numSegModuleA; proverID++ { + proofA := wizard.Prove(moduleCompA, func(run *wizard.ProverRuntime) { + run.ParentRuntime = initialRuntime + // inputs for vertical splitting of the witness + run.ProverID = proverID + }) + runtimeA, validA := wizard.VerifyWithRuntime(moduleCompA, proofA, lppVerifierRuntime) + require.NoError(t, validA) + + allVerfiers.RuntimesA = append(allVerfiers.RuntimesA, runtimeA) + } + + // Compile and prove for moduleB + allVerfiers.RuntimesB = make([]wizard.Runtime, 0, numSegModuleB) + for proverID := 0; proverID < numSegModuleB; proverID++ { + proofB := wizard.Prove(moduleCompB, func(run *wizard.ProverRuntime) { + run.ParentRuntime = initialRuntime + // inputs for vertical splitting of the witness + run.ProverID = proverID + }) + runtimeB, validB := wizard.VerifyWithRuntime(moduleCompB, proofB, lppVerifierRuntime) + require.NoError(t, validB) + + allVerfiers.RuntimesB = append(allVerfiers.RuntimesB, runtimeB) + + } + + // Compile and prove for moduleC + allVerfiers.RuntimesC = make([]wizard.Runtime, 0, numSegModuleC) + for proverID := 0; proverID < numSegModuleC; proverID++ { + proofC := wizard.Prove(moduleCompC, func(run *wizard.ProverRuntime) { + run.ParentRuntime = initialRuntime + // inputs for vertical splitting of the witness + run.ProverID = proverID + }) + runtimeC, validC := wizard.VerifyWithRuntime(moduleCompC, proofC, lppVerifierRuntime) + require.NoError(t, validC) + + allVerfiers.RuntimesC = append(allVerfiers.RuntimesC, runtimeC) + } + + // apply the crosse checks over the public inputs. + require.NoError(t, checkConsistency(allVerfiers)) + }) } } + +func checkConsistency(allVerRuns AllVerifierRuntimes) error { + + var ( + res = field.Zero() + currCumSumArrayA = make([]field.Element, 0, len(allVerRuns.RuntimesA)) + prevCumSumArrayA = make([]field.Element, 0, len(allVerRuns.RuntimesA)) + currCumSumArrayB = make([]field.Element, 0, len(allVerRuns.RuntimesB)) + prevCumSumArrayB = make([]field.Element, 0, len(allVerRuns.RuntimesB)) + currCumSumArrayC = make([]field.Element, 0, len(allVerRuns.RuntimesC)) + prevCumSumArrayC = make([]field.Element, 0, len(allVerRuns.RuntimesC)) + ) + for _, run := range allVerRuns.RuntimesA { + distributedPubInputs_ := run.GetPublicInput(constants.DistributedProjectionPublicInput) + distributedPubInputs, ok := distributedPubInputs_.(wizard.DistributedProjectionPublicInput) + if !ok { + return errors.New("the distributed projection public input is not valid for module A") + } + res.Add(&res, &distributedPubInputs.ScaledHorner) + currCumSumArrayA = append(currCumSumArrayA, distributedPubInputs.CumSumCurr) + prevCumSumArrayA = append(prevCumSumArrayA, distributedPubInputs.CumSumPrev) + } + for _, run := range allVerRuns.RuntimesB { + distributedPubInputs_ := run.GetPublicInput(constants.DistributedProjectionPublicInput) + distributedPubInputs, ok := distributedPubInputs_.(wizard.DistributedProjectionPublicInput) + if !ok { + return errors.New("the distributed projection public input is not valid for module B") + } + res.Add(&res, &distributedPubInputs.ScaledHorner) + currCumSumArrayB = append(currCumSumArrayB, distributedPubInputs.CumSumCurr) + prevCumSumArrayB = append(prevCumSumArrayB, distributedPubInputs.CumSumPrev) + } + for _, run := range allVerRuns.RuntimesC { + distributedPubInputs_ := run.GetPublicInput(constants.DistributedProjectionPublicInput) + distributedPubInputs, ok := distributedPubInputs_.(wizard.DistributedProjectionPublicInput) + if !ok { + return errors.New("the distributed projection public input is not valid for module C") + } + res.Add(&res, &distributedPubInputs.ScaledHorner) + currCumSumArrayC = append(currCumSumArrayC, distributedPubInputs.CumSumCurr) + prevCumSumArrayC = append(prevCumSumArrayC, distributedPubInputs.CumSumPrev) + } + + if !res.IsZero() { + return errors.New("the distributed projection sums do not cancel each others") + } + for i := 1; i < len(currCumSumArrayA); i++ { + if currCumSumArrayA[i-1] != prevCumSumArrayA[i] { + return errors.New("the vertical splitting for the distributed projection is not consistent for module A") + } + } + for i := 1; i < len(currCumSumArrayB); i++ { + if currCumSumArrayB[i-1] != prevCumSumArrayB[i] { + return errors.New("the vertical splitting for the distributed projection is not consistent for module B") + } + } + for i := 1; i < len(currCumSumArrayC); i++ { + if currCumSumArrayC[i-1] != prevCumSumArrayC[i] { + return errors.New("the vertical splitting for the distributed projection is not consistent for module C") + } + } + + return nil +} diff --git a/prover/protocol/distributed/conglomeration/translator.go b/prover/protocol/distributed/conglomeration/translator.go index e7e3bb960..11b22ad8e 100644 --- a/prover/protocol/distributed/conglomeration/translator.go +++ b/prover/protocol/distributed/conglomeration/translator.go @@ -254,7 +254,7 @@ func (run *runtimeTranslator) GetSpec() *wizard.CompiledIOP { return run.Rt.GetSpec() } -func (run *runtimeTranslator) GetPublicInput(name string) field.Element { +func (run *runtimeTranslator) GetPublicInput(name string) any { name = run.Prefix + "." + name return run.Rt.GetPublicInput(name) } diff --git a/prover/protocol/distributed/constants/constant.go b/prover/protocol/distributed/constants/constant.go index 0a35a13d5..f66ec3b65 100644 --- a/prover/protocol/distributed/constants/constant.go +++ b/prover/protocol/distributed/constants/constant.go @@ -1,7 +1,8 @@ package constants const ( - LogDerivativeSumPublicInput = "LOG_DERIVATE_SUM_PUBLIC_INPUT" - GrandProductPublicInput = "GRAND_PRODUCT_PUBLIC_INPUT" - GrandSumPublicInput = "GRAND_SUM_PUBLIC_INPUT" + LogDerivativeSumPublicInput = "LOG_DERIVATE_SUM_PUBLIC_INPUT" + GrandProductPublicInput = "GRAND_PRODUCT_PUBLIC_INPUT" + GrandSumPublicInput = "GRAND_SUM_PUBLIC_INPUT" + DistributedProjectionPublicInput = "DISTRIBUTED_PROJECTION_PUBLIC_INPUT" ) diff --git a/prover/protocol/distributed/lpp/lpp.go b/prover/protocol/distributed/lpp/lpp.go index c4feba827..32b70704c 100644 --- a/prover/protocol/distributed/lpp/lpp.go +++ b/prover/protocol/distributed/lpp/lpp.go @@ -26,15 +26,17 @@ func CompileLPPAndGetSeed(comp *wizard.CompiledIOP, lppCompilers ...func(*wizard // applies lppCompiler; this would add a new round and probably new columns to the current round // but no new column to the new round. - for _, lppCompiler := range lppCompilers { - lppCompiler(comp) + if len(lppCompilers) > 0 { + for _, lppCompiler := range lppCompilers { + lppCompiler(comp) - if comp.NumRounds() != 2 || comp.Columns.NumRounds() != 1 { - panic("we expect to have new round while no column is yet registered for the new round") - } + if comp.NumRounds() != 2 || comp.Columns.NumRounds() != 1 { + panic("we expect to have new round while no column is yet registered for the new round") + } - numRounds := comp.NumRounds() - comp.EqualizeRounds(numRounds) + numRounds := comp.NumRounds() + comp.EqualizeRounds(numRounds) + } } // filter the new lpp columns. @@ -60,7 +62,7 @@ func CompileLPPAndGetSeed(comp *wizard.CompiledIOP, lppCompilers ...func(*wizard // register the seed, generated from LPP, in comp // for the sake of the assignment it also should be registered in lppComp lppComp.InsertCoin(1, "SEED", coin.Field) - comp.InsertCoin(1, "SEED", coin.Field) + comp.InsertCoin(0, "SEED", coin.Field) // prepare and register prover actions. lppProver := &lppProver{ @@ -118,7 +120,8 @@ func GetLPPComp(oldComp *wizard.CompiledIOP, newLPPCols []ifaces.Column) *wizard // it extract LPP columns from the context of each LPP query. func GetLPPColumns(c *wizard.CompiledIOP) []ifaces.Column { - + // Todo(arijit): should prevent inserting duplicate columns + // Use checkAlreadyExists() var ( lppColumns = []ifaces.Column{} ) @@ -127,30 +130,46 @@ func GetLPPColumns(c *wizard.CompiledIOP) []ifaces.Column { q := c.QueriesNoParams.Data(qName) switch v := q.(type) { case query.Inclusion: - - for i := range v.Including { - lppColumns = append(lppColumns, v.Including[i]...) + if !checkAlreadyExists(lppColumns, v.Including[0][0]) { + for i := range v.Including { + lppColumns = append(lppColumns, v.Including[i]...) + } + } + if !checkAlreadyExists(lppColumns, v.Included[0]) { + lppColumns = append(lppColumns, v.Included...) } - - lppColumns = append(lppColumns, v.Included...) if v.IncludingFilter != nil { - lppColumns = append(lppColumns, v.IncludingFilter...) + if !checkAlreadyExists(lppColumns, v.IncludingFilter[0]) { + lppColumns = append(lppColumns, v.IncludingFilter...) + } } if v.IncludedFilter != nil { - lppColumns = append(lppColumns, v.IncludedFilter) + if !checkAlreadyExists(lppColumns, v.IncludedFilter) { + lppColumns = append(lppColumns, v.IncludedFilter) + } } case query.Permutation: for i := range v.A { - lppColumns = append(lppColumns, v.A[i]...) - lppColumns = append(lppColumns, v.B[i]...) + if !checkAlreadyExists(lppColumns, v.A[0][0]) { + lppColumns = append(lppColumns, v.A[i]...) + } + if !checkAlreadyExists(lppColumns, v.B[0][0]) { + lppColumns = append(lppColumns, v.B[i]...) + } } case query.Projection: - lppColumns = append(lppColumns, v.Inp.ColumnA...) - lppColumns = append(lppColumns, v.Inp.ColumnB...) - lppColumns = append(lppColumns, v.Inp.FilterA, v.Inp.FilterB) + // If ColumnA exists then FilterA also exists + if !checkAlreadyExists(lppColumns, v.Inp.ColumnA[0]) { + lppColumns = append(lppColumns, v.Inp.ColumnA...) + lppColumns = append(lppColumns, v.Inp.FilterA) + } + if !checkAlreadyExists(lppColumns, v.Inp.ColumnB[0]) { + lppColumns = append(lppColumns, v.Inp.ColumnB...) + lppColumns = append(lppColumns, v.Inp.FilterB) + } default: //do noting @@ -160,3 +179,13 @@ func GetLPPColumns(c *wizard.CompiledIOP) []ifaces.Column { return lppColumns } + +// it checks if the column is already inserted in the list +func checkAlreadyExists(lppColumns []ifaces.Column, sampleCol ifaces.Column) bool { + for _, col := range lppColumns { + if col.GetColID() == sampleCol.GetColID() { + return true + } + } + return false +} diff --git a/prover/protocol/distributed/lpp/lpp_test.go b/prover/protocol/distributed/lpp/lpp_test.go index 5946c0b2d..88f609cc2 100644 --- a/prover/protocol/distributed/lpp/lpp_test.go +++ b/prover/protocol/distributed/lpp/lpp_test.go @@ -140,7 +140,7 @@ func TestSeedGeneration(t *testing.T) { run.ProverID = proverID }) - // get and compar the coins with the other segments/modules + // get and compare the coins with the other segments/modules coin1 := runtime0.Coins.MustGet("TABLE_module0.col2_LOGDERIVATIVE_GAMMA_FieldFromSeed").(field.Element) coin0 := runtime0.Coins.MustGet("TABLE_module1.col0_LOGDERIVATIVE_GAMMA_FieldFromSeed").(field.Element) if coinLookup1Gamma.IsZero() { diff --git a/prover/protocol/query/distributed_projection.go b/prover/protocol/query/distributed_projection.go index 334ce572e..3ad815a5d 100644 --- a/prover/protocol/query/distributed_projection.go +++ b/prover/protocol/query/distributed_projection.go @@ -1,27 +1,31 @@ package query import ( - "fmt" + "math/big" "github.com/consensys/gnark/frontend" "github.com/consensys/linea-monorepo/prover/crypto/fiatshamir" - "github.com/consensys/linea-monorepo/prover/maths/common/poly" "github.com/consensys/linea-monorepo/prover/maths/field" "github.com/consensys/linea-monorepo/prover/protocol/coin" - "github.com/consensys/linea-monorepo/prover/protocol/column" "github.com/consensys/linea-monorepo/prover/protocol/ifaces" "github.com/consensys/linea-monorepo/prover/symbolic" "github.com/consensys/linea-monorepo/prover/utils" ) type DistributedProjectionInput struct { - ColumnA, ColumnB *symbolic.Expression - FilterA, FilterB *symbolic.Expression - SizeA, SizeB int - EvalCoin coin.Name - IsAInModule, IsBInModule bool + ColumnA, ColumnB *symbolic.Expression + FilterA, FilterB *symbolic.Expression + SizeA, SizeB int + EvalCoin coin.Name + IsAInModule, IsBInModule bool + CumulativeNumOnesPrevSegmentsA, CumulativeNumOnesPrevSegmentsB big.Int + CurrNumOnesA, CurrNumOnesB field.Element } +// func (dpInp *DistributedProjectionInput) completeAssign(run *ifaces.Runtime) { +// dpInp.CumulativeNumOnesPrevSegments.Run(run) +// } + type DistributedProjection struct { Round int ID ifaces.QueryID @@ -29,7 +33,8 @@ type DistributedProjection struct { } type DistributedProjectionParams struct { - HornerVal field.Element + ScaledHorner field.Element + HashCumSumOnePrev, HashCumSumOneCurr field.Element } func NewDistributedProjection(round int, id ifaces.QueryID, inp []*DistributedProjectionInput) DistributedProjection { @@ -54,8 +59,11 @@ func NewDistributedProjection(round int, id ifaces.QueryID, inp []*DistributedPr } // Constructor for distributed projection query parameters -func NewDistributedProjectionParams(hornerVal field.Element) DistributedProjectionParams { - return DistributedProjectionParams{HornerVal: hornerVal} +func NewDistributedProjectionParams(scaledHorner, hashCumSumOnePrev, hashCumSumOneCurr field.Element) DistributedProjectionParams { + return DistributedProjectionParams{ + ScaledHorner: scaledHorner, + HashCumSumOnePrev: hashCumSumOnePrev, + HashCumSumOneCurr: hashCumSumOneCurr} } // Name returns the unique identifier of the GrandProduct query. @@ -65,56 +73,11 @@ func (dp DistributedProjection) Name() ifaces.QueryID { // Updates a Fiat-Shamir state func (dpp DistributedProjectionParams) UpdateFS(fs *fiatshamir.State) { - fs.Update(dpp.HornerVal) + fs.Update(dpp.ScaledHorner) } +// Unimplemented func (dp DistributedProjection) Check(run ifaces.Runtime) error { - var ( - actualParam = field.Zero() - params = run.GetParams(dp.ID).(DistributedProjectionParams) - evalRand field.Element - ) - _, errBeta := evalRand.SetRandom() - if errBeta != nil { - // Cannot happen unless the entropy was exhausted - panic(errBeta) - } - for _, inp := range dp.Inp { - var ( - colABoard = inp.ColumnA.Board() - colBBoard = inp.ColumnB.Board() - filterABorad = inp.FilterA.Board() - filterBBoard = inp.FilterB.Board() - colA = column.EvalExprColumn(run, colABoard).IntoRegVecSaveAlloc() - colB = column.EvalExprColumn(run, colBBoard).IntoRegVecSaveAlloc() - filterA = column.EvalExprColumn(run, filterABorad).IntoRegVecSaveAlloc() - filterB = column.EvalExprColumn(run, filterBBoard).IntoRegVecSaveAlloc() - elemParam = field.One() - ) - if inp.IsAInModule && !inp.IsBInModule { - hornerA := poly.GetHornerTrace(colA, filterA, evalRand) - elemParam = hornerA[0] - } else if !inp.IsAInModule && inp.IsBInModule { - hornerB := poly.GetHornerTrace(colB, filterB, evalRand) - elemParam = hornerB[0] - elemParam.Neg(&elemParam) - } else if inp.IsAInModule && inp.IsBInModule { - hornerA := poly.GetHornerTrace(colA, filterA, evalRand) - hornerB := poly.GetHornerTrace(colB, filterB, evalRand) - elemParam = hornerB[0] - elemParam.Neg(&elemParam) - elemParam.Add(&elemParam, &hornerA[0]) - } else { - utils.Panic("Invalid distributed projection query %v", dp.ID) - } - actualParam.Add(&actualParam, &elemParam) - - } - - if actualParam != params.HornerVal { - return fmt.Errorf("the distributed projection query %v is not satisfied, actualParam = %v, param.HornerVal = %v", dp.ID, actualParam, params.HornerVal) - } - return nil } diff --git a/prover/protocol/query/distributed_projection_test.go b/prover/protocol/query/distributed_projection_test.go deleted file mode 100644 index eb4b7c67a..000000000 --- a/prover/protocol/query/distributed_projection_test.go +++ /dev/null @@ -1,95 +0,0 @@ -package query_test - -import ( - "testing" - - "github.com/consensys/linea-monorepo/prover/maths/common/smartvectors" - "github.com/consensys/linea-monorepo/prover/maths/field" - "github.com/consensys/linea-monorepo/prover/protocol/ifaces" - "github.com/consensys/linea-monorepo/prover/protocol/query" - "github.com/consensys/linea-monorepo/prover/protocol/wizard" - "github.com/consensys/linea-monorepo/prover/protocol/wizardutils" -) - -func TestDistributedProjectionQuery(t *testing.T) { - var ( - runS *wizard.ProverRuntime - DP ifaces.Query - round = 0 - flagSizeA = 512 - flagSizeB = 256 - flagA, flagB, columnA, columnB ifaces.Column - flagAWit = make([]field.Element, flagSizeA) - columnAWit = make([]field.Element, flagSizeA) - flagBWit = make([]field.Element, flagSizeB) - columnBWit = make([]field.Element, flagSizeB) - queryNameBothAAndB = ifaces.QueryID("DistributedProjectionTestBothAAndB") - ) - // Computing common test data - - // assign filters and columns - for i := 0; i < 10; i++ { - flagAWit[i] = field.One() - columnAWit[i] = field.NewElement(uint64(i)) - } - for i := flagSizeB - 10; i < flagSizeB; i++ { - flagBWit[i] = field.One() - columnBWit[i] = field.NewElement(uint64(i - (flagSizeB - 10))) - } - - testcases := []struct { - Name string - HornerParam field.Element - QueryName ifaces.QueryID - DefineFunc func(builder *wizard.Builder) - ProverFunc func(run *wizard.ProverRuntime) - }{ - { - Name: "distributed-projection-both-A-and-B", - QueryName: queryNameBothAAndB, - DefineFunc: func(builder *wizard.Builder) { - flagA = builder.RegisterCommit(ifaces.ColID("FilterA"), flagSizeA) - flagB = builder.RegisterCommit(ifaces.ColID("FliterB"), flagSizeB) - columnA = builder.RegisterCommit(ifaces.ColID("ColumnA"), flagSizeA) - columnB = builder.RegisterCommit(ifaces.ColID("ColumnB"), flagSizeB) - var ( - colA, _, _ = wizardutils.AsExpr(columnA) - colB, _, _ = wizardutils.AsExpr(columnB) - fA, _, _ = wizardutils.AsExpr(flagA) - fB, _, _ = wizardutils.AsExpr(flagB) - ) - DP = builder.CompiledIOP.InsertDistributedProjection(round, queryNameBothAAndB, - []*query.DistributedProjectionInput{ - {ColumnA: colA, ColumnB: colB, FilterA: fA, FilterB: fB, IsAInModule: true, IsBInModule: true}, - }) - }, - ProverFunc: func(run *wizard.ProverRuntime) { - runS = run - run.AssignColumn(flagA.GetColID(), smartvectors.RightZeroPadded(flagAWit, flagSizeA)) - run.AssignColumn(flagB.GetColID(), smartvectors.RightZeroPadded(flagBWit, flagSizeB)) - run.AssignColumn(columnB.GetColID(), smartvectors.RightZeroPadded(columnBWit, flagSizeB)) - run.AssignColumn(columnA.GetColID(), smartvectors.RightZeroPadded(columnAWit, flagSizeA)) - - runS.AssignDistributedProjection(queryNameBothAAndB, query.DistributedProjectionParams{HornerVal: field.Zero()}) - }, - }, - } - - for _, tc := range testcases { - - t.Run(tc.Name, func(t *testing.T) { - prover := tc.ProverFunc - var ( - comp = wizard.Compile(tc.DefineFunc) - _ = wizard.Prove(comp, prover) - errDP = DP.Check(runS) - ) - - if errDP != nil { - t.Fatalf("error verifying the distributed projection query: %v", errDP.Error()) - } - - }) - } - -} diff --git a/prover/protocol/query/gnark_params.go b/prover/protocol/query/gnark_params.go index c0da87154..5b7f98c75 100644 --- a/prover/protocol/query/gnark_params.go +++ b/prover/protocol/query/gnark_params.go @@ -83,6 +83,11 @@ func (p GnarkGrandProductParams) UpdateFS(fs *fiatshamir.GnarkFiatShamir) { fs.Update(p.Prod) } +// Update the fiat-shamir state with the the present parameters +func (p GnarkDistributedProjectionParams) UpdateFS(fs *fiatshamir.GnarkFiatShamir) { + fs.Update(p.Sum) +} + // Update the fiat-shamir state with the the present parameters func (p GnarkUnivariateEvalParams) UpdateFS(fs *fiatshamir.GnarkFiatShamir) { fs.Update(p.Ys...) diff --git a/prover/protocol/wizard/prover.go b/prover/protocol/wizard/prover.go index 9084b2e5b..d555801f3 100644 --- a/prover/protocol/wizard/prover.go +++ b/prover/protocol/wizard/prover.go @@ -212,7 +212,7 @@ func ProverOnlyFirstRound(c *CompiledIOP, highLevelprover ProverStep) *ProverRun // highLevelprover(&runtime) - // Then, run the compiled prover steps. This will only run thoses of the + // Then, run the compiled prover steps. This will only run those of the // first round. // runtime.runProverSteps() @@ -849,6 +849,8 @@ func (run *ProverRuntime) AssignDistributedProjection(name ifaces.QueryID, distr run.Spec.QueriesParams.MustBeInRound(run.currRound, name) // Adds it to the assignments - params := query.NewDistributedProjectionParams(distributedProjectionParam.HornerVal) + params := query.NewDistributedProjectionParams(distributedProjectionParam.ScaledHorner, + distributedProjectionParam.HashCumSumOnePrev, + distributedProjectionParam.HashCumSumOneCurr) run.QueriesParams.InsertNew(name, params) } diff --git a/prover/protocol/wizard/verifier.go b/prover/protocol/wizard/verifier.go index 4d91e1369..9767609ae 100644 --- a/prover/protocol/wizard/verifier.go +++ b/prover/protocol/wizard/verifier.go @@ -4,7 +4,9 @@ import ( "github.com/consensys/linea-monorepo/prover/crypto/fiatshamir" "github.com/consensys/linea-monorepo/prover/maths/common/smartvectors" "github.com/consensys/linea-monorepo/prover/maths/field" + "github.com/consensys/linea-monorepo/prover/protocol/accessors" "github.com/consensys/linea-monorepo/prover/protocol/coin" + "github.com/consensys/linea-monorepo/prover/protocol/distributed/constants" "github.com/consensys/linea-monorepo/prover/protocol/ifaces" "github.com/consensys/linea-monorepo/prover/protocol/query" "github.com/consensys/linea-monorepo/prover/utils" @@ -33,13 +35,21 @@ type Proof struct { QueriesParams collection.Mapping[ifaces.QueryID, ifaces.QueryParams] } +// DistributedProjectionPublicInput is a struct that holds the public inputs +// for the distributed projection protocol. +type DistributedProjectionPublicInput struct { + ScaledHorner field.Element + CumSumPrev field.Element + CumSumCurr field.Element +} + // Runtime is a generic interface extending the [ifaces.Runtime] interface // with all methods of [wizard.VerifierRuntime]. This is used to allow the // writing of adapters for the verifier runtime. type Runtime interface { ifaces.Runtime GetSpec() *CompiledIOP - GetPublicInput(name string) field.Element + GetPublicInput(name string) any GetGrandProductParams(name ifaces.QueryID) query.GrandProductParams GetDistributedProjectionParams(name ifaces.QueryID) query.DistributedProjectionParams GetLogDerivSumParams(name ifaces.QueryID) query.LogDerivSumParams @@ -491,9 +501,18 @@ func (run *VerifierRuntime) GetParams(name ifaces.QueryID) ifaces.QueryParams { } // GetPublicInput returns a public input from its name -func (run *VerifierRuntime) GetPublicInput(name string) field.Element { +func (run *VerifierRuntime) GetPublicInput(name string) any { allPubs := run.Spec.PublicInputs for i := range allPubs { + if allPubs[i].Name == name && name == constants.DistributedProjectionPublicInput { + if s, ok := allPubs[i].Acc.(*accessors.FromDistributedProjectionAccessor); ok { + return DistributedProjectionPublicInput{ + ScaledHorner: s.GetVal(run), + CumSumCurr: s.GetValCumSumCurr(run), + CumSumPrev: s.GetValCumSumPrev(run), + } + } + } if allPubs[i].Name == name { return allPubs[i].Acc.GetVal(run) }