diff --git a/prover/maths/common/poly/poly.go b/prover/maths/common/poly/poly.go index dfa685144..ca376b636 100644 --- a/prover/maths/common/poly/poly.go +++ b/prover/maths/common/poly/poly.go @@ -123,11 +123,10 @@ func EvaluateLagrangesAnyDomain(domain []field.Element, x field.Element) []field return lagrange } -// CmptHorner computes a random Horner accumulation of the filtered elements +// GetHornerTrace computes a random Horner accumulation of the filtered elements // starting from the last entry down to the first entry. The final value is // stored in the last entry of the returned slice. -// Todo: send it to a common utility package -func CmptHorner(c, fC []field.Element, x field.Element) []field.Element { +func GetHornerTrace(c, fC []field.Element, x field.Element) []field.Element { var ( horner = make([]field.Element, len(c)) diff --git a/prover/protocol/accessors/from_distributedprojection.go b/prover/protocol/accessors/from_distributedprojection.go new file mode 100644 index 000000000..0677cdd84 --- /dev/null +++ b/prover/protocol/accessors/from_distributedprojection.go @@ -0,0 +1,71 @@ +package accessors + +import ( + "fmt" + + "github.com/consensys/gnark/frontend" + "github.com/consensys/linea-monorepo/prover/maths/field" + "github.com/consensys/linea-monorepo/prover/protocol/ifaces" + "github.com/consensys/linea-monorepo/prover/protocol/query" + "github.com/consensys/linea-monorepo/prover/symbolic" +) + +const ( + DISTRIBUTED_PROJECTION_ACCESSOR = "DISTRIBUTED_PROJECTION_ACCESSOR" +) + +// FromDistributedProjectionAccessor implements [ifaces.Accessor] and accesses the result of +// a [query.DISTRIBUTED_PROJECTION]. +type FromDistributedProjectionAccessor struct { + // Q is the underlying query whose parameters are accessed by the current + // [ifaces.Accessor]. + Q query.DistributedProjection +} + +// NewDistributedProjectionAccessor creates an [ifaces.Accessor] returning the opening +// point of a [query.DISTRIBUTED_PROJECTION]. 
+func NewDistributedProjectionAccessor(q query.DistributedProjection) ifaces.Accessor { + return &FromDistributedProjectionAccessor{Q: q} +} + +// Name implements [ifaces.Accessor] +func (l *FromDistributedProjectionAccessor) Name() string { + return fmt.Sprintf("%v_%v", DISTRIBUTED_PROJECTION_ACCESSOR, l.Q.ID) +} + +// String implements [github.com/consensys/linea-monorepo/prover/symbolic.Metadata] +func (l *FromDistributedProjectionAccessor) String() string { + return l.Name() +} + +// GetVal implements [ifaces.Accessor] +func (l *FromDistributedProjectionAccessor) GetVal(run ifaces.Runtime) field.Element { + params := run.GetParams(l.Q.ID).(query.DistributedProjectionParams) + return params.ScaledHorner +} + +func (l *FromDistributedProjectionAccessor) GetValCumSumCurr(run ifaces.Runtime) field.Element { + params := run.GetParams(l.Q.ID).(query.DistributedProjectionParams) + return params.HashCumSumOneCurr +} + +func (l *FromDistributedProjectionAccessor) GetValCumSumPrev(run ifaces.Runtime) field.Element { + params := run.GetParams(l.Q.ID).(query.DistributedProjectionParams) + return params.HashCumSumOnePrev +} + +// GetFrontendVariable implements [ifaces.Accessor] +func (l *FromDistributedProjectionAccessor) GetFrontendVariable(_ frontend.API, circ ifaces.GnarkRuntime) frontend.Variable { + params := circ.GetParams(l.Q.ID).(query.GnarkDistributedProjectionParams) + return params.Sum +} + +// AsVariable implements the [ifaces.Accessor] interface +func (l *FromDistributedProjectionAccessor) AsVariable() *symbolic.Expression { + return symbolic.NewVariable(l) +} + +// Round implements the [ifaces.Accessor] interface +func (l *FromDistributedProjectionAccessor) Round() int { + return l.Q.Round +} diff --git a/prover/protocol/compiler/distributedprojection/compiler.go b/prover/protocol/compiler/distributedprojection/compiler.go new file mode 100644 index 000000000..8c1c1a283 --- /dev/null +++ b/prover/protocol/compiler/distributedprojection/compiler.go @@ -0,0 +1,68 @@ +package distributedprojection + +import ( + "math/big" + + "github.com/consensys/linea-monorepo/prover/maths/field" + "github.com/consensys/linea-monorepo/prover/protocol/coin" + "github.com/consensys/linea-monorepo/prover/protocol/ifaces" + "github.com/consensys/linea-monorepo/prover/protocol/query" + "github.com/consensys/linea-monorepo/prover/protocol/wizard" + "github.com/consensys/linea-monorepo/prover/symbolic" +) + +func CompileDistributedProjection(comp *wizard.CompiledIOP) { + + for _, qName := range comp.QueriesParams.AllUnignoredKeys() { + // Filter out non-distributed-projection queries + distributedprojection, ok := comp.QueriesParams.Data(qName).(query.DistributedProjection) + if !ok { + continue + } + + // This ensures that the distributed projection query is not used again in the + // compilation process. We know that the query was not already ignored at the beginning + // because we are iterating over the unignored keys. + comp.QueriesParams.MarkAsIgnored(qName)
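+ // Fetch the round at which the query was declared; the Horner helper columns and
+ // queries produced by compile below are registered at that same round.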
+ round := comp.QueriesParams.Round(qName) + compile(comp, round, distributedprojection) + } +} + +func compile(comp *wizard.CompiledIOP, round int, distributedprojection query.DistributedProjection) { + var ( + pa = &distribuedProjectionProverAction{ + Name: distributedprojection.ID, + Query: distributedprojection, + FilterA: make([]*symbolic.Expression, len(distributedprojection.Inp)), + FilterB: make([]*symbolic.Expression, len(distributedprojection.Inp)), + ColumnA: make([]*symbolic.Expression, len(distributedprojection.Inp)), + ColumnB: make([]*symbolic.Expression, len(distributedprojection.Inp)), + HornerA: make([]ifaces.Column, len(distributedprojection.Inp)), + HornerB: make([]ifaces.Column, len(distributedprojection.Inp)), + HornerA0: make([]query.LocalOpening, len(distributedprojection.Inp)), + HornerB0: make([]query.LocalOpening, len(distributedprojection.Inp)), + EvalCoins: make([]coin.Info, len(distributedprojection.Inp)), + IsA: make([]bool, len(distributedprojection.Inp)), + IsB: make([]bool, len(distributedprojection.Inp)), + } + ) + pa.Push(comp, distributedprojection) + pa.RegisterQueries(comp, round, distributedprojection) + comp.RegisterProverAction(round, pa) + comp.RegisterVerifierAction(round, &distributedProjectionVerifierAction{ + Name: pa.Name, + Query: pa.Query, + HornerA0: pa.HornerA0, + HornerB0: pa.HornerB0, + IsA: pa.IsA, + IsB: pa.IsB, + EvalCoins: pa.EvalCoins, + FilterA: pa.FilterA, + FilterB: pa.FilterB, + CumNumOnesPrevSegmentsA: make([]big.Int, len(distributedprojection.Inp)), + CumNumOnesPrevSegmentsB: make([]big.Int, len(distributedprojection.Inp)), + NumOnesCurrSegmentA: make([]field.Element, len(distributedprojection.Inp)), + NumOnesCurrSegmentB: make([]field.Element, len(distributedprojection.Inp)), + }) + +} diff --git a/prover/protocol/compiler/distributedprojection/prover.go b/prover/protocol/compiler/distributedprojection/prover.go new file mode 100644 index 000000000..7406318b4 --- /dev/null +++ b/prover/protocol/compiler/distributedprojection/prover.go @@ -0,0 +1,287 @@ +package distributedprojection + +import ( + "github.com/consensys/linea-monorepo/prover/maths/common/poly" + "github.com/consensys/linea-monorepo/prover/maths/common/smartvectors" + "github.com/consensys/linea-monorepo/prover/protocol/coin" + "github.com/consensys/linea-monorepo/prover/protocol/column" + "github.com/consensys/linea-monorepo/prover/protocol/ifaces" + "github.com/consensys/linea-monorepo/prover/protocol/query" + "github.com/consensys/linea-monorepo/prover/protocol/wizard" + sym "github.com/consensys/linea-monorepo/prover/symbolic" + "github.com/consensys/linea-monorepo/prover/utils" + "github.com/consensys/linea-monorepo/prover/utils/collection" + "github.com/sirupsen/logrus" +) + +type distribuedProjectionProverAction struct { + Name ifaces.QueryID + Query query.DistributedProjection + FilterA, FilterB []*sym.Expression + ColumnA, ColumnB []*sym.Expression + HornerA, HornerB []ifaces.Column + HornerA0, HornerB0 []query.LocalOpening + EvalCoins []coin.Info + IsA, IsB []bool +} + +// Run executes the distributed projection prover action. +// It iterates over the input filters, columns, and evaluation coins. +// Depending on the values of IsA and IsB, it computes the Horner traces for columns A and B, +// and assigns them to the corresponding columns and local points in the prover runtime. +// If both IsA and IsB are true, it computes the Horner traces for both columns A and B. +// If IsA is true and IsB is false, it computes the Horner trace for column A only. 
+// If IsA is false and IsB is true, it computes the Horner trace for column B only. +// If neither IsA nor IsB is true, it panics with an error message indicating an invalid prover assignment. +func (pa *distribuedProjectionProverAction) Run(run *wizard.ProverRuntime) { + for index := range pa.Query.Inp { + if pa.IsA[index] && pa.IsB[index] { + var ( + colA = column.EvalExprColumn(run, pa.ColumnA[index].Board()).IntoRegVecSaveAlloc() + fA = column.EvalExprColumn(run, pa.FilterA[index].Board()).IntoRegVecSaveAlloc() + colB = column.EvalExprColumn(run, pa.ColumnB[index].Board()).IntoRegVecSaveAlloc() + fB = column.EvalExprColumn(run, pa.FilterB[index].Board()).IntoRegVecSaveAlloc() + x = run.GetRandomCoinField(pa.EvalCoins[index].Name) + hornerA = poly.GetHornerTrace(colA, fA, x) + hornerB = poly.GetHornerTrace(colB, fB, x) + ) + run.AssignColumn(pa.HornerA[index].GetColID(), smartvectors.NewRegular(hornerA)) + run.AssignLocalPoint(pa.HornerA0[index].ID, hornerA[0]) + run.AssignColumn(pa.HornerB[index].GetColID(), smartvectors.NewRegular(hornerB)) + run.AssignLocalPoint(pa.HornerB0[index].ID, hornerB[0]) + } else if pa.IsA[index] && !pa.IsB[index] { + var ( + colA = column.EvalExprColumn(run, pa.ColumnA[index].Board()).IntoRegVecSaveAlloc() + fA = column.EvalExprColumn(run, pa.FilterA[index].Board()).IntoRegVecSaveAlloc() + x = run.GetRandomCoinField(pa.EvalCoins[index].Name) + hornerA = poly.GetHornerTrace(colA, fA, x) + ) + run.AssignColumn(pa.HornerA[index].GetColID(), smartvectors.NewRegular(hornerA)) + run.AssignLocalPoint(pa.HornerA0[index].ID, hornerA[0]) + } else if !pa.IsA[index] && pa.IsB[index] { + var ( + colB = column.EvalExprColumn(run, pa.ColumnB[index].Board()).IntoRegVecSaveAlloc() + fB = column.EvalExprColumn(run, pa.FilterB[index].Board()).IntoRegVecSaveAlloc() + x = run.GetRandomCoinField(pa.EvalCoins[index].Name) + hornerB = poly.GetHornerTrace(colB, fB, x) + ) + run.AssignColumn(pa.HornerB[index].GetColID(), smartvectors.NewRegular(hornerB)) + run.AssignLocalPoint(pa.HornerB0[index].ID, hornerB[0]) + } else { + utils.Panic("Invalid prover assignment in distributed projection id: %v", pa.Name) + } + } +} + +// Push populates the distribuedProjectionProverAction with data from the provided DistributedProjection query. +// It processes each input in the query and assigns the corresponding values to the prover action's fields +// based on whether the input is in module A, module B, or both. +// +// Parameters: +// - comp: A pointer to the CompiledIOP, used to access coin data. +// - distributedprojection: The DistributedProjection query containing the inputs to be processed. +// +// The function does not return any value, but updates the fields of the distribuedProjectionProverAction in-place. 
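+// For an input whose columns live entirely in the other module, the corresponding filter and
+// column slots are left nil and only the IsA/IsB flags record which side is present locally.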
+func (pa *distribuedProjectionProverAction) Push(comp *wizard.CompiledIOP, distributedprojection query.DistributedProjection) { + for index, input := range distributedprojection.Inp { + if input.IsAInModule && input.IsBInModule { + pa.FilterA[index] = input.FilterA + pa.FilterB[index] = input.FilterB + pa.ColumnA[index] = input.ColumnA + pa.ColumnB[index] = input.ColumnB + pa.EvalCoins[index] = comp.Coins.Data(input.EvalCoin) + pa.IsA[index] = true + pa.IsB[index] = true + } else if input.IsAInModule && !input.IsBInModule { + pa.FilterA[index] = input.FilterA + pa.ColumnA[index] = input.ColumnA + pa.EvalCoins[index] = comp.Coins.Data(input.EvalCoin) + pa.IsA[index] = true + pa.IsB[index] = false + } else if !input.IsAInModule && input.IsBInModule { + pa.FilterB[index] = input.FilterB + pa.ColumnB[index] = input.ColumnB + pa.EvalCoins[index] = comp.Coins.Data(input.EvalCoin) + pa.IsA[index] = false + pa.IsB[index] = true + } else { + logrus.Errorf("Invalid distributed projection query while pushing prover action entries: %v", distributedprojection.ID) + } + } +} + +// RegisterQueries registers the necessary queries for the distributed projection prover action. +// It processes each input in the distributed projection, shifting expressions and registering +// queries for columns A and B based on their presence in the respective modules. +// +// Parameters: +// - comp: A pointer to the CompiledIOP, used for inserting commits and queries. +// - round: An integer representing the current round of the protocol. +// - distributedprojection: A DistributedProjection query containing the inputs to be processed. +// +// The function does not return any value, but updates the internal state of the prover action +// by registering the required queries for each input. +func (pa *distribuedProjectionProverAction) RegisterQueries(comp *wizard.CompiledIOP, round int, distributedprojection query.DistributedProjection) { + for index, input := range distributedprojection.Inp { + if input.IsAInModule && input.IsBInModule { + var ( + fA = pa.FilterA[index] + fAShifted = shiftExpression(comp, fA, -1) + colA = pa.ColumnA[index] + colAShifted = shiftExpression(comp, colA, -1) + fB = pa.FilterB[index] + fBShifted = shiftExpression(comp, fB, -1) + colB = pa.ColumnB[index] + colBShifted = shiftExpression(comp, colB, -1) + ) + pa.registerForCol(comp, fAShifted, colAShifted, input, "A", round, index) + pa.registerForCol(comp, fBShifted, colBShifted, input, "B", round, index) + } else if input.IsAInModule && !input.IsBInModule { + var ( + fA = pa.FilterA[index] + fAShifted = shiftExpression(comp, fA, -1) + colA = pa.ColumnA[index] + colAShifted = shiftExpression(comp, colA, -1) + ) + pa.registerForCol(comp, fAShifted, colAShifted, input, "A", round, index) + } else if !input.IsAInModule && input.IsBInModule { + var ( + fB = pa.FilterB[index] + fBShifted = shiftExpression(comp, fB, -1) + colB = pa.ColumnB[index] + colBShifted = shiftExpression(comp, colB, -1) + ) + pa.registerForCol(comp, fBShifted, colBShifted, input, "B", round, index) + } else { + utils.Panic("Invalid prover action case for the distributed projection query %v", pa.Name) + } + } +} + +// registerForCol registers queries for a specific column (A or B) in the distributed projection prover action. +// It inserts commits, global queries, local queries, and local openings for the Horner polynomial evaluation. +// +// Parameters: +// - comp: A pointer to the CompiledIOP, used for inserting commits and queries. +// - fShifted: A shifted filter expression. 
+// - colShifted: A shifted column expression. +// - input: A pointer to the DistributedProjectionInput containing size information. +// - colName: A string indicating which column to register ("A" or "B"). +// - round: An integer representing the current round of the protocol. +// - index: An integer used to uniquely identify the registered queries. +// +// The function doesn't return any value but updates the internal state of the prover action +// by registering the required queries and commits for the specified column. +func (pa *distribuedProjectionProverAction) registerForCol( + comp *wizard.CompiledIOP, + fShifted, colShifted *sym.Expression, + input *query.DistributedProjectionInput, + colName string, + round int, + index int, +) { + switch colName { + case "A": + { + pa.HornerA[index] = comp.InsertCommit(round, ifaces.ColIDf("%v_HORNER_A_%v", pa.Name, index), input.SizeA) + comp.InsertGlobal( + round, + ifaces.QueryIDf("%v_HORNER_A_%v_GLOBAL", pa.Name, index), + sym.Sub( + pa.HornerA[index], + sym.Mul( + sym.Sub(1, pa.FilterA[index]), + column.Shift(pa.HornerA[index], 1), + ), + sym.Mul( + pa.FilterA[index], + sym.Add( + pa.ColumnA[index], + sym.Mul( + pa.EvalCoins[index], + column.Shift(pa.HornerA[index], 1), + ), + ), + ), + ), + ) + comp.InsertLocal( + round, + ifaces.QueryIDf("%v_HORNER_A_LOCAL_END_%v", pa.Name, index), + sym.Sub( + column.Shift(pa.HornerA[index], -1), + sym.Mul(fShifted, colShifted), + ), + ) + pa.HornerA0[index] = comp.InsertLocalOpening(round, ifaces.QueryIDf("%v_HORNER_A0_%v", pa.Name, index), pa.HornerA[index]) + } + case "B": + { + pa.HornerB[index] = comp.InsertCommit(round, ifaces.ColIDf("%v_HORNER_B_%v", pa.Name, index), input.SizeB) + + comp.InsertGlobal( + round, + ifaces.QueryIDf("%v_HORNER_B_%v_GLOBAL", pa.Name, index), + sym.Sub( + pa.HornerB[index], + sym.Mul( + sym.Sub(1, pa.FilterB[index]), + column.Shift(pa.HornerB[index], 1), + ), + sym.Mul( + pa.FilterB[index], + sym.Add(pa.ColumnB[index], sym.Mul(pa.EvalCoins[index], column.Shift(pa.HornerB[index], 1))), + ), + ), + ) + + comp.InsertLocal( + round, + ifaces.QueryIDf("%v_HORNER_B_LOCAL_END_%v", pa.Name, index), + sym.Sub( + column.Shift(pa.HornerB[index], -1), + sym.Mul(fShifted, colShifted), + ), + ) + + pa.HornerB0[index] = comp.InsertLocalOpening(round, ifaces.QueryIDf("%v_HORNER_B0_%v", pa.Name, index), pa.HornerB[index]) + } + default: + utils.Panic("Invalid column name %v, should be either A or B", colName) + } + +} + +// shiftExpression shifts a column which is a symbolic expression by a specified number of positions. +// It creates a new expression with the shifted column while maintaining the structure of the original expression. +// +// Parameters: +// - comp: A pointer to the CompiledIOP, used to check for the existence of coins. +// - expr: The original symbolic expression to be shifted. +// - nbShift: The number of positions to shift the column. Positive values shift forward, negative values shift backward. +// +// Returns: +// +// A new *sym.Expression with the column shifted according to the specified nbShift. 
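+// As an illustration (with hypothetical columns), shifting the expression fA[i]*colA[i] by -1
+// rewrites it into fA[i-1]*colA[i-1]: every column leaf is replaced by its shifted version,
+// while coins and other scalar metadata are replayed unchanged.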
+func shiftExpression(comp *wizard.CompiledIOP, expr *sym.Expression, nbShift int) *sym.Expression { + var ( + board = expr.Board() + metadata = board.ListVariableMetadata() + translationMap = collection.NewMapping[string, *sym.Expression]() + ) + + for _, m := range metadata { + switch t := m.(type) { + case ifaces.Column: + translationMap.InsertNew(string(t.GetColID()), ifaces.ColumnAsVariable(column.Shift(t, nbShift))) + case coin.Info: + if !comp.Coins.Exists(t.Name) { + utils.Panic("Coin %v does not exist in the InitialComp", t.Name) + } + translationMap.InsertNew(t.String(), sym.NewVariable(t)) + default: + utils.Panic("Unsupported type for shift expression operation") + } + } + return expr.Replay(translationMap) +} diff --git a/prover/protocol/compiler/distributedprojection/verifier.go b/prover/protocol/compiler/distributedprojection/verifier.go new file mode 100644 index 000000000..1a06e4172 --- /dev/null +++ b/prover/protocol/compiler/distributedprojection/verifier.go @@ -0,0 +1,231 @@ +package distributedprojection + +import ( + "fmt" + "math/big" + + "github.com/consensys/gnark/frontend" + "github.com/consensys/linea-monorepo/prover/crypto/mimc" + "github.com/consensys/linea-monorepo/prover/maths/field" + "github.com/consensys/linea-monorepo/prover/protocol/coin" + "github.com/consensys/linea-monorepo/prover/protocol/column" + "github.com/consensys/linea-monorepo/prover/protocol/ifaces" + "github.com/consensys/linea-monorepo/prover/protocol/query" + "github.com/consensys/linea-monorepo/prover/protocol/wizard" + sym "github.com/consensys/linea-monorepo/prover/symbolic" + "github.com/consensys/linea-monorepo/prover/utils" +) + +type distributedProjectionVerifierAction struct { + Name ifaces.QueryID + Query query.DistributedProjection + HornerA0, HornerB0 []query.LocalOpening + IsA, IsB []bool + skipped bool + EvalCoins []coin.Info + CumNumOnesPrevSegmentsA []big.Int + CumNumOnesPrevSegmentsB []big.Int + NumOnesCurrSegmentA []field.Element + NumOnesCurrSegmentB []field.Element + FilterA, FilterB []*sym.Expression +} + +// Run implements the [wizard.VerifierAction] +func (va *distributedProjectionVerifierAction) Run(run wizard.Runtime) error { + for index, inp := range va.Query.Inp { + if va.IsA[index] && va.IsB[index] { + va.CumNumOnesPrevSegmentsA[index] = inp.CumulativeNumOnesPrevSegmentsA + va.NumOnesCurrSegmentA[index] = inp.CurrNumOnesA + va.CumNumOnesPrevSegmentsB[index] = inp.CumulativeNumOnesPrevSegmentsB + va.NumOnesCurrSegmentB[index] = inp.CurrNumOnesB + } else if va.IsA[index] && !va.IsB[index] { + va.CumNumOnesPrevSegmentsA[index] = inp.CumulativeNumOnesPrevSegmentsA + va.NumOnesCurrSegmentA[index] = inp.CurrNumOnesA + } else if !va.IsA[index] && va.IsB[index] { + va.CumNumOnesPrevSegmentsB[index] = inp.CumulativeNumOnesPrevSegmentsB + va.NumOnesCurrSegmentB[index] = inp.CurrNumOnesB + } + } + errorCheckHorner := va.scaledHornerCheck(run) + if errorCheckHorner != nil { + return errorCheckHorner + } + errorCheckCurrNumOnes := va.currSumOneCheck(run) + if errorCheckCurrNumOnes != nil { + return errorCheckCurrNumOnes + } + + errorCheckHash := va.hashCheck(run) + if errorCheckHash != nil { + return errorCheckHash + } + + return nil +} + +// method to check consistancy of the scaled horner +func (va *distributedProjectionVerifierAction) scaledHornerCheck(run wizard.Runtime) error { + var ( + actualParam = field.Zero() + ) + for index := range va.HornerA0 { + var ( + elemParam = field.Zero() + ) + if va.IsA[index] && va.IsB[index] { + var ( + multA, multB field.Element + 
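+ // hornerA/hornerB receive the opened values of the Horner accumulator columns; multA/multB
+ // rescale them by x^(cumulative number of ones in the previous segments) so that the
+ // per-segment accumulations can be combined into a single cross-segment value.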
hornerA, hornerB field.Element + ) + hornerA = run.GetLocalPointEvalParams(va.HornerA0[index].ID).Y + multA = run.GetRandomCoinField(va.EvalCoins[index].Name) + multA.Exp(multA, &va.CumNumOnesPrevSegmentsA[index]) + hornerA.Mul(&hornerA, &multA) + elemParam.Add(&elemParam, &hornerA) + + hornerB = run.GetLocalPointEvalParams(va.HornerB0[index].ID).Y + multB = run.GetRandomCoinField(va.EvalCoins[index].Name) + multB.Exp(multB, &va.CumNumOnesPrevSegmentsB[index]) + hornerB.Mul(&hornerB, &multB) + elemParam.Sub(&elemParam, &hornerB) + } else if va.IsA[index] && !va.IsB[index] { + var ( + multA field.Element + hornerA field.Element + ) + hornerA = run.GetLocalPointEvalParams(va.HornerA0[index].ID).Y + multA = run.GetRandomCoinField(va.EvalCoins[index].Name) + multA.Exp(multA, &va.CumNumOnesPrevSegmentsA[index]) + hornerA.Mul(&hornerA, &multA) + elemParam.Add(&elemParam, &hornerA) + } else if !va.IsA[index] && va.IsB[index] { + var ( + multB field.Element + hornerB field.Element + ) + hornerB = run.GetLocalPointEvalParams(va.HornerB0[index].ID).Y + multB = run.GetRandomCoinField(va.EvalCoins[index].Name) + multB.Exp(multB, &va.CumNumOnesPrevSegmentsB[index]) + hornerB.Mul(&hornerB, &multB) + elemParam.Sub(&elemParam, &hornerB) + } else { + utils.Panic("Unsupported verifier action registered for %v", va.Name) + } + actualParam.Add(&actualParam, &elemParam) + } + queryParam := run.GetDistributedProjectionParams(va.Name).ScaledHorner + if actualParam != queryParam { + utils.Panic("The distributed projection query %v did not pass, query param %v and actual param %v", va.Name, queryParam, actualParam) + } + return nil +} + +func (va *distributedProjectionVerifierAction) currSumOneCheck(run wizard.Runtime) error { + for index := range va.HornerA0 { + if va.IsA[index] && va.IsB[index] { + var ( + numOnesA = field.Zero() + numOnesB = field.Zero() + one = field.One() + ) + fA := column.EvalExprColumn(run, va.FilterA[index].Board()).IntoRegVecSaveAlloc() + for i := 0; i < len(fA); i++ { + if fA[i] == field.One() { + numOnesA.Add(&numOnesA, &one) + } + } + if numOnesA != va.NumOnesCurrSegmentA[index] { + return fmt.Errorf("number of ones for filterA does not match, actual = %v, assigned = %v", numOnesA, va.NumOnesCurrSegmentA[index]) + } + fB := column.EvalExprColumn(run, va.FilterB[index].Board()).IntoRegVecSaveAlloc() + for i := 0; i < len(fB); i++ { + if fB[i] == field.One() { + numOnesB.Add(&numOnesB, &one) + } + } + if numOnesB != va.NumOnesCurrSegmentB[index] { + return fmt.Errorf("number of ones for filterB does not match, actual = %v, assigned = %v", numOnesB, va.NumOnesCurrSegmentB[index]) + } + } + if va.IsA[index] && !va.IsB[index] { + var ( + numOnesA = field.Zero() + one = field.One() + ) + fA := column.EvalExprColumn(run, va.FilterA[index].Board()).IntoRegVecSaveAlloc() + for i := 0; i < len(fA); i++ { + if fA[i] == field.One() { + numOnesA.Add(&numOnesA, &one) + } + } + if numOnesA != va.NumOnesCurrSegmentA[index] { + return fmt.Errorf("number of ones for filterA does not match, actual = %v, assigned = %v", numOnesA, va.NumOnesCurrSegmentA[index]) + } + } + if !va.IsA[index] && va.IsB[index] { + var ( + numOnesB = field.Zero() + one = field.One() + ) + fB := column.EvalExprColumn(run, va.FilterB[index].Board()).IntoRegVecSaveAlloc() + for i := 0; i < len(fB); i++ { + if fB[i] == field.One() { + numOnesB.Add(&numOnesB, &one) + } + } + if numOnesB != va.NumOnesCurrSegmentB[index] { + return fmt.Errorf("number of ones for filterB does not match, actual = %v, assigned = %v", numOnesB,
va.NumOnesCurrSegmentB[index]) + } + } + } + return nil +} + +func (va *distributedProjectionVerifierAction) hashCheck(run wizard.Runtime) error { + var ( + oldState = field.Zero() + ) + for index := range va.HornerA0 { + if va.IsA[index] && va.IsB[index] { + var ( + sumA, sumB field.Element + ) + sumA = field.NewElement(va.CumNumOnesPrevSegmentsA[index].Uint64()) + sumA.Add(&sumA, &va.NumOnesCurrSegmentA[index]) + sumB = field.NewElement(va.CumNumOnesPrevSegmentsB[index].Uint64()) + sumB.Add(&sumB, &va.NumOnesCurrSegmentB[index]) + oldState = mimc.BlockCompression(oldState, sumA) + oldState = mimc.BlockCompression(oldState, sumB) + } else if va.IsA[index] && !va.IsB[index] { + sumA := field.NewElement(va.CumNumOnesPrevSegmentsA[index].Uint64()) + sumA.Add(&sumA, &va.NumOnesCurrSegmentA[index]) + oldState = mimc.BlockCompression(oldState, sumA) + } else if !va.IsA[index] && va.IsB[index] { + sumB := field.NewElement(va.CumNumOnesPrevSegmentsB[index].Uint64()) + sumB.Add(&sumB, &va.NumOnesCurrSegmentB[index]) + oldState = mimc.BlockCompression(oldState, sumB) + } else { + panic("Invalid distributed projection query encountered during current hash verification") + } + } + if oldState != run.GetDistributedProjectionParams(va.Name).HashCumSumOneCurr { + return fmt.Errorf("HashCumSumOneCurr does not match, actual = %v, assigned = %v", oldState, run.GetDistributedProjectionParams(va.Name).HashCumSumOneCurr) + } + return nil +} + +// RunGnark implements the [wizard.VerifierAction] interface. +func (va *distributedProjectionVerifierAction) RunGnark(api frontend.API, run wizard.GnarkRuntime) { + + panic("unimplemented") +} + +// Skip implements the [wizard.VerifierAction] +func (va *distributedProjectionVerifierAction) Skip() { + va.skipped = true +} + +// IsSkipped implements the [wizard.VerifierAction] +func (va *distributedProjectionVerifierAction) IsSkipped() bool { + return va.skipped +} diff --git a/prover/protocol/compiler/projection/prover.go b/prover/protocol/compiler/projection/prover.go index 7726a0d3e..fb96d5c77 100644 --- a/prover/protocol/compiler/projection/prover.go +++ b/prover/protocol/compiler/projection/prover.go @@ -39,8 +39,8 @@ func (pa projectionProverAction) Run(run *wizard.ProverRuntime) { fA = pa.FilterA.GetColAssignment(run).IntoRegVecSaveAlloc() fB = pa.FilterB.GetColAssignment(run).IntoRegVecSaveAlloc() x = run.GetRandomCoinField(pa.EvalCoin.Name) - hornerA = poly.CmptHorner(a, fA, x) - hornerB = poly.CmptHorner(b, fB, x) + hornerA = poly.GetHornerTrace(a, fA, x) + hornerB = poly.GetHornerTrace(b, fB, x) ) run.AssignColumn(pa.HornerA.GetColID(), smartvectors.NewRegular(hornerA)) diff --git a/prover/protocol/compiler/stitch_split/splitter/constraints.go b/prover/protocol/compiler/stitch_split/splitter/constraints.go index 7eb3787eb..63c8a700b 100644 --- a/prover/protocol/compiler/stitch_split/splitter/constraints.go +++ b/prover/protocol/compiler/stitch_split/splitter/constraints.go @@ -81,7 +81,7 @@ func (ctx splitterContext) LocalGlobalConstraints() { continue } - // if the associated expression is eligible to the stitching, mark the query, over the sub columns, as ignored. + // if the associated expression is eligible to the splitting, mark the query as ignored. 
ctx.comp.QueriesNoParams.MarkAsIgnored(qName) // adjust the query over the sub columns diff --git a/prover/protocol/distributed/comp_splitting.go b/prover/protocol/distributed/comp_splitting.go index f42b9e4cd..0ca1402c3 100644 --- a/prover/protocol/distributed/comp_splitting.go +++ b/prover/protocol/distributed/comp_splitting.go @@ -1,8 +1,13 @@ package distributed import ( + "github.com/consensys/linea-monorepo/prover/maths/common/smartvectors" + "github.com/consensys/linea-monorepo/prover/maths/field" + "github.com/consensys/linea-monorepo/prover/protocol/column" "github.com/consensys/linea-monorepo/prover/protocol/ifaces" + "github.com/consensys/linea-monorepo/prover/protocol/query" "github.com/consensys/linea-monorepo/prover/protocol/wizard" + "github.com/consensys/linea-monorepo/prover/symbolic" "github.com/consensys/linea-monorepo/prover/utils" ) @@ -40,8 +45,8 @@ func GetFreshSegmentModuleComp(in SegmentModuleInputs) *wizard.CompiledIOP { if !in.Disc.ColumnIsInModule(col, in.ModuleName) { continue } - - segModComp.InsertCommit(col.Round(), col.GetColID(), col.Size()/in.NumSegmentsInModule) + // Make colSize a power of two + segModComp.InsertCommit(col.Round(), col.GetColID(), utils.NextPowerOfTwo(col.Size()/in.NumSegmentsInModule)) columnsInRound = append(columnsInRound, col) } @@ -74,7 +79,7 @@ func (p segmentModuleProver) Run(run *wizard.ProverRuntime) { utils.Panic("invalid call: the runtime does not have a [ParentRuntime]") } if run.ProverID > p.numSegments { - panic("proverID can not be larger than number of segments") + panic("proverID cannot be larger than number of segments") } for _, col := range p.cols { @@ -87,6 +92,226 @@ func (p segmentModuleProver) Run(run *wizard.ProverRuntime) { } func getSegmentFromWitness(wit ifaces.ColAssignment, numSegs, segID int) ifaces.ColAssignment { - segSize := wit.Len() / numSegs + segSize := utils.NextPowerOfTwo(wit.Len() / numSegs) return wit.SubVector(segSize*segID, segSize*segID+segSize) } + +func GetFreshCompGL(in SegmentModuleInputs) *wizard.CompiledIOP { + + var ( + // initialize the segment CompiledIOP + segComp = wizard.NewCompiledIOP() + initialComp = in.InitialComp + glColumns = extractGLColumns(initialComp) + glColumnsInModule = []ifaces.Column{} + ) + + // get the GL columns + for _, colName := range initialComp.Columns.AllKeysAt(0) { + + col := initialComp.Columns.GetHandle(colName) + if !in.Disc.ColumnIsInModule(col, in.ModuleName) { + continue + } + + if isGLColumn(col, glColumns) { + // TBD: register at round 1, to create a separate commitment over GL. 
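+ // The GL (global/local) columns are re-declared in the segment CompiledIOP with a
+ // per-segment size; their witnesses are sliced out of the parent runtime in glProver.Run.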
+ segComp.InsertCommit(0, col.GetColID(), col.Size()/in.NumSegmentsInModule) + glColumnsInModule = append(glColumnsInModule, col) + } + + } + + // register provider and receiver + provider := segComp.InsertCommit(0, "PROVIDER", getSizeForProviderReceiver(initialComp)) + receiver := segComp.InsertCommit(0, "RECEIVER", getSizeForProviderReceiver(initialComp)) + + // create a new moduleProver + glProver := glProver{ + glCols: glColumnsInModule, + numSegments: in.NumSegmentsInModule, + provider: provider, + receiver: receiver, + } + + // register Prover action for the segment-module to assign columns per round + segComp.RegisterProverAction(0, glProver) + return segComp + +} + +type glProver struct { + glCols []ifaces.Column + numSegments int + provider ifaces.Column + receiver ifaces.Column +} + +func (p glProver) Run(run *wizard.ProverRuntime) { + if run.ParentRuntime == nil { + utils.Panic("invalid call: the runtime does not have a [ParentRuntime]") + } + if run.ProverID > p.numSegments { + panic("proverID can not be larger than number of segments") + } + + for _, col := range p.glCols { + // get the witness from the initialProver + colWitness := run.ParentRuntime.GetColumn(col.GetColID()) + colSegWitness := getSegmentFromWitness(colWitness, p.numSegments, run.ProverID) + // assign it in the module in the round col was declared + run.AssignColumn(col.GetColID(), colSegWitness, col.Round()) + } + // assign Provider and Receiver + assignProvider(run, run.ProverID, p.numSegments, p.provider) + // for the current segment, the receiver is the provider of the previous segment. + assignProvider(run, utils.PositiveMod(run.ProverID-1, p.numSegments), p.numSegments, p.receiver) + +} + +func extractGLColumns(comp *wizard.CompiledIOP) []ifaces.Column { + + glColumns := []ifaces.Column{} + // extract global queries + for _, queryID := range comp.QueriesNoParams.AllKeysAt(0) { + + if glob, ok := comp.QueriesNoParams.Data(queryID).(query.GlobalConstraint); ok { + glColumns = append(glColumns, ListColumnsFromExpr(glob.Expression, true)...) + } + + if local, ok := comp.QueriesNoParams.Data(queryID).(query.LocalConstraint); ok { + glColumns = append(glColumns, ListColumnsFromExpr(local.Expression, true)...) + } + } + + // extract localOpenings + return glColumns +} + +// ListColumnsFromExpr returns the natural version of all the columns in the expression. +// if natural is true, it return the natural version of the columns, +// otherwise it return the original columns. 
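+// For example (hypothetical columns), for the expression fA[i]*colA[i+2] it returns {fA, colA}
+// when natural is true, and {fA, Shift(colA, 2)} when natural is false.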
+func ListColumnsFromExpr(expr *symbolic.Expression, natural bool) []ifaces.Column { + + var ( + board = expr.Board() + metadata = board.ListVariableMetadata() + colList = []ifaces.Column{} + ) + + for _, m := range metadata { + switch t := m.(type) { + case ifaces.Column: + + if shifted, ok := t.(column.Shifted); ok && natural { + colList = append(colList, shifted.Parent) + } else { + colList = append(colList, t) + } + + } + } + return colList + +} + +func isGLColumn(col ifaces.Column, glColumns []ifaces.Column) bool { + + for _, glCol := range glColumns { + if col.GetColID() == glCol.GetColID() { + return true + } + } + return false + +} + +func getSizeForProviderReceiver(comp *wizard.CompiledIOP) int { + + numBoundaries := 0 + + for _, queryID := range comp.QueriesNoParams.AllKeysAt(0) { + + if global, ok := comp.QueriesNoParams.Data(queryID).(query.GlobalConstraint); ok { + + var ( + board = global.Board() + metadata = board.ListVariableMetadata() + maxShift = global.MinMaxOffset().Max + ) + + for _, m := range metadata { + switch t := m.(type) { + case ifaces.Column: + + if shifted, ok := t.(column.Shifted); ok { + // number of boundaries from the current column + numBoundaries += maxShift - column.StackOffsets(shifted) + + } else { + numBoundaries += maxShift + } + + } + } + } + } + return utils.NextPowerOfTwo(numBoundaries) +} + +// assignProvider mainly assigns the provider +// it also can be used for the receiver assignment, +// since the receiver of segment i equal to the provider of segment (i-1). +func assignProvider(run *wizard.ProverRuntime, segID, numSegments int, col ifaces.Column) { + + var ( + parentRuntime = run.ParentRuntime + initialComp = parentRuntime.Spec + allBoundaries = []field.Element{} + ) + + for _, q := range initialComp.QueriesNoParams.AllKeysAt(0) { + if global, ok := initialComp.QueriesNoParams.Data(q).(query.GlobalConstraint); ok { + + var ( + board = global.Board() + metadata = board.ListVariableMetadata() + maxShift = global.MinMaxOffset().Max + ) + + for _, m := range metadata { + switch t := m.(type) { + case ifaces.Column: + + var ( + segmentSize = t.Size() / numSegments + lastRow = (segID+1)*segmentSize - 1 + colWit []field.Element + // number of boundaries from the current column + numBoundaries int + ) + + if shifted, ok := t.(column.Shifted); ok { + numBoundaries = maxShift - column.StackOffsets(shifted) + colWit = shifted.Parent.GetColAssignment(parentRuntime).IntoRegVecSaveAlloc() + + } else { + numBoundaries = maxShift + colWit = t.GetColAssignment(parentRuntime).IntoRegVecSaveAlloc() + } + + for i := lastRow - numBoundaries + 1; i <= lastRow; i++ { + + allBoundaries = append(allBoundaries, + colWit[i]) + + } + + } + + } + + } + } + run.AssignColumn(col.GetColID(), smartvectors.RightZeroPadded(allBoundaries, col.Size())) +} diff --git a/prover/protocol/distributed/compiler/global/global.go b/prover/protocol/distributed/compiler/global/global.go index 37f1d6070..35f7177a9 100644 --- a/prover/protocol/distributed/compiler/global/global.go +++ b/prover/protocol/distributed/compiler/global/global.go @@ -1,8 +1,312 @@ package global -import "github.com/consensys/linea-monorepo/prover/protocol/wizard" +import ( + "github.com/consensys/linea-monorepo/prover/protocol/accessors" + "github.com/consensys/linea-monorepo/prover/protocol/column" + "github.com/consensys/linea-monorepo/prover/protocol/column/verifiercol" + "github.com/consensys/linea-monorepo/prover/protocol/distributed" + "github.com/consensys/linea-monorepo/prover/protocol/ifaces" + 
"github.com/consensys/linea-monorepo/prover/protocol/query" + "github.com/consensys/linea-monorepo/prover/protocol/variables" + "github.com/consensys/linea-monorepo/prover/protocol/wizard" + "github.com/consensys/linea-monorepo/prover/symbolic" + "github.com/consensys/linea-monorepo/prover/utils" + "github.com/consensys/linea-monorepo/prover/utils/collection" +) -// CompileDist compiles the global queries distributedly. -func CompileDist(comp *wizard.CompiledIOP) { - panic("unimplemented") +type segmentID int + +type DistributionInputs struct { + ModuleComp *wizard.CompiledIOP + InitialComp *wizard.CompiledIOP + // module Discoverer used to detect the relevant part of the query to the module + Disc distributed.ModuleDiscoverer + // Name of the module + ModuleName distributed.ModuleName + // number of segments for the module + NumSegments int +} + +func DistributeGlobal(in DistributionInputs) { + + var ( + bInputs = boundaryInputs{ + moduleComp: in.ModuleComp, + numSegments: in.NumSegments, + provider: in.ModuleComp.Columns.GetHandle("PROVIDER"), + receiver: in.ModuleComp.Columns.GetHandle("RECEIVER"), + providerOpenings: []query.LocalOpening{}, + receiverOpenings: []query.LocalOpening{}, + } + ) + + for _, qName := range in.InitialComp.QueriesNoParams.AllUnignoredKeys() { + + q, ok := in.InitialComp.QueriesNoParams.Data(qName).(query.GlobalConstraint) + if !ok { + continue + } + + if in.Disc.ExpressionIsInModule(q.Expression, in.ModuleName) { + + // apply global constraint over the segment. + in.ModuleComp.InsertGlobal(0, + q.ID, + AdjustExpressionForGlobal(in.ModuleComp, q.Expression, in.NumSegments), + ) + + // collect the boundaries for provider and receiver + + BoundariesForProvider(&bInputs, q) + BoundariesForReceiver(&bInputs, q) + + } + + // @Azam the risk is that some global constraints may be skipped here. + // we can prevent this by tagging the query as ignored from the initialComp, + // and at the end make sure that no query has remained in initial CompiledIOP. + } + + in.ModuleComp.RegisterProverAction(0, ¶mAssignments{ + provider: bInputs.provider, + receiver: bInputs.receiver, + providerOpenings: bInputs.providerOpenings, + receiverOpenings: bInputs.receiverOpenings, + }) + +} + +type boundaryInputs struct { + moduleComp *wizard.CompiledIOP + numSegments int + provider ifaces.Column + receiver ifaces.Column + providerOpenings, receiverOpenings []query.LocalOpening +} + +func AdjustExpressionForGlobal( + comp *wizard.CompiledIOP, + expr *symbolic.Expression, + numSegments int, +) *symbolic.Expression { + + var ( + board = expr.Board() + metadatas = board.ListVariableMetadata() + translationMap = collection.NewMapping[string, *symbolic.Expression]() + colTranslation ifaces.Column + size = column.ExprIsOnSameLengthHandles(&board) + ) + + for _, metadata := range metadatas { + + // For each slot, get the expression obtained by replacing the commitment + // by the appropriated column. 
+ + switch m := metadata.(type) { + case ifaces.Column: + + switch col := m.(type) { + case column.Natural: + colTranslation = comp.Columns.GetHandle(m.GetColID()) + + case verifiercol.VerifierCol: + // panic happens specially for the case of FromAccessors + panic("unsupported for now, unless module discoverer can capture such columns") + + case column.Shifted: + colTranslation = column.Shift(comp.Columns.GetHandle(col.Parent.GetColID()), col.Offset) + + } + + translationMap.InsertNew(m.String(), ifaces.ColumnAsVariable(colTranslation)) + case variables.X: + utils.Panic("unsupported, the value of `x` in the unsplit query and the split would be different") + case variables.PeriodicSample: + // Check that the period is not larger than the domain size. If + // the period is smaller this is a no-op because the period does + // not change. + segSize := size / numSegments + + if m.T > segSize { + + panic("unsupported, since this depends on the segment ID, unless the module discoverer can detect such cases") + } + translationMap.InsertNew(m.String(), symbolic.NewVariable(metadata)) + default: + // Repass the same variable (for coins or other types of single-valued variable) + translationMap.InsertNew(m.String(), symbolic.NewVariable(metadata)) + } + + } + return expr.Replay(translationMap) +} + +func BoundariesForProvider(in *boundaryInputs, q query.GlobalConstraint) { + + var ( + board = q.Board() + offsetRange = q.MinMaxOffset() + provider = in.provider + maxShift = offsetRange.Max + colsInExpr = distributed.ListColumnsFromExpr(q.Expression, false) + colsOnProvider = onBoundaries(colsInExpr, maxShift) + numBoundaries = offsetRange.Max - offsetRange.Min + size = column.ExprIsOnSameLengthHandles(&board) + segSize = size / in.numSegments + ) + for _, col := range colsInExpr { + for i := 0; i < numBoundaries; i++ { + if colsOnProvider.Exists(col.GetColID()) { + + pos := colsOnProvider.MustGet(col.GetColID()) + + if i < maxShift-column.StackOffsets(col) { + // take from provider, since the size of the provider is different from size of the expression + // take it via accessor. + var ( + index = pos[0] + i + name = ifaces.QueryIDf("%v_%v", "FROM_PROVIDER_AT", index) + loProvider = in.moduleComp.InsertLocalOpening(0, name, column.Shift(provider, index)) + accessorProvider = accessors.NewLocalOpeningAccessor(loProvider, 0) + indexOnCol = segSize - (maxShift - column.StackOffsets(col) - i) + nameExpr = ifaces.QueryIDf("%v_%v_%v", "CONSISTENCY_AGAINST_PROVIDER", col.GetColID(), i) + colInModule ifaces.Column + ) + + // replace col with its replacement in the module. 
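+ // and then tie the matching PROVIDER cell, exposed through a local opening, to the
+ // corresponding cell near the end of this segment's column so that the next segment can
+ // consume it as its receiver.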
+ if shifted, ok := col.(column.Shifted); ok { + colInModule = in.moduleComp.Columns.GetHandle(shifted.Parent.GetColID()) + } else { + colInModule = in.moduleComp.Columns.GetHandle(col.GetColID()) + } + + // add the localOpening to the list + in.providerOpenings = append(in.providerOpenings, loProvider) + // impose that loProvider = loCol + in.moduleComp.InsertLocal(0, nameExpr, + symbolic.Sub(accessorProvider, column.Shift(colInModule, indexOnCol)), + ) + + } + } + } + } + +} + +func BoundariesForReceiver(in *boundaryInputs, q query.GlobalConstraint) { + + var ( + offsetRange = q.MinMaxOffset() + receiver = in.receiver + maxShift = offsetRange.Max + colsInExpr = distributed.ListColumnsFromExpr(q.Expression, false) + colsOnReceiver = onBoundaries(colsInExpr, maxShift) + numBoundaries = offsetRange.Max - offsetRange.Min + comp = in.moduleComp + colInModule ifaces.Column + // list of local openings by the boundary index + allLists = make([][]query.LocalOpening, numBoundaries) + ) + + for i := 0; i < numBoundaries; i++ { + + translationMap := collection.NewMapping[string, *symbolic.Expression]() + + for _, col := range colsInExpr { + + // replace col with its replacement in the module. + if shifted, ok := col.(column.Shifted); ok { + colInModule = in.moduleComp.Columns.GetHandle(shifted.Parent.GetColID()) + } else { + colInModule = in.moduleComp.Columns.GetHandle(col.GetColID()) + } + + if colsOnReceiver.Exists(col.GetColID()) { + pos := colsOnReceiver.MustGet(col.GetColID()) + + if i < maxShift-column.StackOffsets(col) { + // take from receiver, since the size of the receiver is different from size of the expression + // take it via accessor. + var ( + index = pos[0] + i + name = ifaces.QueryIDf("%v_%v", "FROM_RECEIVER_AT", index) + lo = comp.InsertLocalOpening(0, name, column.Shift(receiver, index)) + accessor = accessors.NewLocalOpeningAccessor(lo, 0) + ) + // add the localOpening to the list + allLists[i] = append(allLists[i], lo) + // in.receiverOpenings = append(in.receiverOpenings, lo) + // translate the column + translationMap.InsertNew(string(col.GetColID()), accessor.AsVariable()) + } else { + // take the rest from the column + tookFromReceiver := maxShift - column.StackOffsets(col) + translationMap.InsertNew(string(col.GetColID()), ifaces.ColumnAsVariable(column.Shift(colInModule, i-tookFromReceiver))) + } + + } else { + translationMap.InsertNew(string(col.GetColID()), ifaces.ColumnAsVariable((column.Shift(colInModule, i)))) + } + + } + + expr := q.Expression.Replay(translationMap) + name := ifaces.QueryIDf("%v_%v_%v", "CONSISTENCY_AGAINST_RECEIVER", q.ID, i) + comp.InsertLocal(0, name, expr) + + } + + // order receiverOpenings column by column + for i := 0; i < numBoundaries; i++ { + for _, list := range allLists { + if len(list) > i { + in.receiverOpenings = append(in.receiverOpenings, list[i]) + } + } + } + +} + +// it indicates the column list having the provider cells (i.e., +// some cells of the columns are needed to be provided to the next segment) +func onBoundaries(colsInExpr []ifaces.Column, maxShift int) collection.Mapping[ifaces.ColID, [2]int] { + + var ( + ctr = 0 + colsOnReceiver = collection.NewMapping[ifaces.ColID, [2]int]() + ) + for _, col := range colsInExpr { + // number of boundaries from the column (that falls on the receiver) is + // maxShift - column.StackOffsets(col) + newCtr := ctr + maxShift - column.StackOffsets(col) + + // it does not have any cell on the receiver. 
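+ // a column that contributes no boundary cells is skipped; otherwise the half-open slot
+ // range [ctr, newCtr) that its cells occupy in the provider/receiver columns is recorded.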
+ if newCtr == ctr { + continue + } + + colsOnReceiver.InsertNew(col.GetColID(), [2]int{ctr, newCtr}) + ctr = newCtr + + } + + return colsOnReceiver + +} + +// it generates natural verifier columns, from a given verifier column +func createVerifierColForModule(col ifaces.Column, numSegments int) ifaces.Column { + + if vcol, ok := col.(verifiercol.VerifierCol); ok { + + switch v := vcol.(type) { + case verifiercol.ConstCol: + return verifiercol.NewConstantCol(v.F, v.Size()/numSegments) + default: + panic("unsupported") + } + } + return nil } diff --git a/prover/protocol/distributed/compiler/global/global_test.go b/prover/protocol/distributed/compiler/global/global_test.go new file mode 100644 index 000000000..20b4e44b6 --- /dev/null +++ b/prover/protocol/distributed/compiler/global/global_test.go @@ -0,0 +1,95 @@ +package global_test + +import ( + "testing" + + "github.com/consensys/linea-monorepo/prover/maths/common/smartvectors" + "github.com/consensys/linea-monorepo/prover/protocol/column" + "github.com/consensys/linea-monorepo/prover/protocol/compiler/dummy" + "github.com/consensys/linea-monorepo/prover/protocol/distributed" + "github.com/consensys/linea-monorepo/prover/protocol/distributed/compiler/global" + md "github.com/consensys/linea-monorepo/prover/protocol/distributed/namebaseddiscoverer" + "github.com/consensys/linea-monorepo/prover/protocol/wizard" + "github.com/consensys/linea-monorepo/prover/symbolic" + "github.com/stretchr/testify/require" +) + +// It tests DistributedLogDerivSum. +func TestDistributedGlobal(t *testing.T) { + const ( + numSegModule = 2 + ) + + //initialComp + define := func(b *wizard.Builder) { + + var ( + col0 = b.CompiledIOP.InsertCommit(0, "module.col0", 8) + col1 = b.CompiledIOP.InsertCommit(0, "module.col1", 8) + col2 = b.CompiledIOP.InsertCommit(0, "module.col2", 8) + col3 = b.CompiledIOP.InsertCommit(0, "module.col3", 8) + ) + + b.CompiledIOP.InsertGlobal(0, "global0", + symbolic.Sub( + col1, column.Shift(col0, 3), + symbolic.Mul(2, column.Shift(col2, 2)), + symbolic.Neg(column.Shift(col3, 3)), + ), + ) + + } + + // initialProver + prover := func(run *wizard.ProverRuntime) { + run.AssignColumn("module.col0", smartvectors.ForTest(3, 0, 2, 1, 4, 1, 13, 0)) + run.AssignColumn("module.col1", smartvectors.ForTest(1, 7, 1, 11, 2, 1, 0, 2)) + run.AssignColumn("module.col2", smartvectors.ForTest(7, 0, 1, 3, 0, 4, 1, 0)) + run.AssignColumn("module.col3", smartvectors.ForTest(2, 14, 0, 2, 3, 0, 10, 0)) + + } + + // initial compiledIOP is the parent to all the SegmentModuleComp objects. + initialComp := wizard.Compile(define) + + // Initialize the period separating module discoverer + disc := &md.PeriodSeperatingModuleDiscoverer{} + disc.Analyze(initialComp) + + // distribute the columns among modules and segments. + moduleComp := distributed.GetFreshCompGL( + distributed.SegmentModuleInputs{ + InitialComp: initialComp, + Disc: disc, + ModuleName: "module", + NumSegmentsInModule: numSegModule, + }, + ) + + // distribute the query among segments. + global.DistributeGlobal(global.DistributionInputs{ + ModuleComp: moduleComp, + InitialComp: initialComp, + Disc: disc, + ModuleName: "module", + NumSegments: numSegModule, + }) + + // This dummy compiles the global/local queries of the segment. 
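+ // dummy.Compile is only used to close the compilation so the test can run end to end;
+ // a real deployment would continue with the full compilation chain here instead.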
+ wizard.ContinueCompilation(moduleComp, dummy.Compile) + + // run the initial runtime + initialRuntime := wizard.ProverOnlyFirstRound(initialComp, prover) + + // Compile and prove for module + for proverID := 0; proverID < numSegModule; proverID++ { + proof := wizard.Prove(moduleComp, func(run *wizard.ProverRuntime) { + run.ParentRuntime = initialRuntime + // inputs for vertical splitting of the witness + run.ProverID = proverID + }) + valid := wizard.Verify(moduleComp, proof) + require.NoError(t, valid) + } + +} diff --git a/prover/protocol/distributed/compiler/global/prover.go b/prover/protocol/distributed/compiler/global/prover.go new file mode 100644 index 000000000..4cc1ce0dc --- /dev/null +++ b/prover/protocol/distributed/compiler/global/prover.go @@ -0,0 +1,28 @@ +package global + +import ( + "github.com/consensys/linea-monorepo/prover/protocol/ifaces" + "github.com/consensys/linea-monorepo/prover/protocol/query" + "github.com/consensys/linea-monorepo/prover/protocol/wizard" +) + +type paramAssignments struct { + provider ifaces.Column + receiver ifaces.Column + providerOpenings []query.LocalOpening + receiverOpenings []query.LocalOpening +} + +// it assigns the LocalOpening for the segment. +func (pa paramAssignments) Run(run *wizard.ProverRuntime) { + var ( + providerWit = run.GetColumn(pa.provider.GetColID()).IntoRegVecSaveAlloc() + receiverWit = run.GetColumn(pa.receiver.GetColID()).IntoRegVecSaveAlloc() + ) + + for i := range pa.providerOpenings { + + run.AssignLocalPoint(pa.providerOpenings[i].ID, providerWit[i]) + run.AssignLocalPoint(pa.receiverOpenings[i].ID, receiverWit[i]) + } +} diff --git a/prover/protocol/distributed/compiler/inclusion/inclusion.go b/prover/protocol/distributed/compiler/inclusion/inclusion.go index 855e25b27..3de8d9a2e 100644 --- a/prover/protocol/distributed/compiler/inclusion/inclusion.go +++ b/prover/protocol/distributed/compiler/inclusion/inclusion.go @@ -52,9 +52,9 @@ func DistributeLogDerivativeSum( continue } - // panic if there is more than a LogDerivativeSum query in the initialComp. + // panic if there is more than one LogDerivativeSum queries in the initialComp. 
if string(queryID) != "" { - utils.Panic("found more than a LogDerivativeSum query in the initialComp") + utils.Panic("found more than one LogDerivativeSum query in the initialComp") } queryID = qName diff --git a/prover/protocol/distributed/compiler/inclusion/inclusion_test.go b/prover/protocol/distributed/compiler/inclusion/inclusion_test.go index e62939f58..c5c0ba8b7 100644 --- a/prover/protocol/distributed/compiler/inclusion/inclusion_test.go +++ b/prover/protocol/distributed/compiler/inclusion/inclusion_test.go @@ -183,9 +183,15 @@ func TestSeedGeneration(t *testing.T) { func checkConsistency(runs []wizard.Runtime) error { - var res field.Element + var ( + res field.Element + ) for _, run := range runs { - logderiv := run.GetPublicInput(constants.LogDerivativeSumPublicInput) + logderiv_ := run.GetPublicInput(constants.LogDerivativeSumPublicInput) + logderiv, ok := logderiv_.(field.Element) + if !ok { + return errors.New("the logderiv is not a field element") + } res.Add(&res, &logderiv) } diff --git a/prover/protocol/distributed/compiler/permutation/permutation.go b/prover/protocol/distributed/compiler/permutation/permutation.go index 794aed774..d9806640c 100644 --- a/prover/protocol/distributed/compiler/permutation/permutation.go +++ b/prover/protocol/distributed/compiler/permutation/permutation.go @@ -71,7 +71,7 @@ func NewPermutationIntoGrandProductCtx( } /* - Handles the lookups and permutations checks + Handles the permutations checks */ for round := 0; round < numRounds; round++ { queries := initialComp.QueriesNoParams.AllKeysAt(round) diff --git a/prover/protocol/distributed/compiler/permutation/permutation_test.go b/prover/protocol/distributed/compiler/permutation/permutation_test.go index bcac9a88d..bd4735750 100644 --- a/prover/protocol/distributed/compiler/permutation/permutation_test.go +++ b/prover/protocol/distributed/compiler/permutation/permutation_test.go @@ -8,7 +8,7 @@ import ( "github.com/consensys/linea-monorepo/prover/protocol/compiler/grandproduct" "github.com/consensys/linea-monorepo/prover/protocol/distributed" dist_permutation "github.com/consensys/linea-monorepo/prover/protocol/distributed/compiler/permutation" - "github.com/consensys/linea-monorepo/prover/protocol/distributed/namebaseddiscoverer" + discoverer "github.com/consensys/linea-monorepo/prover/protocol/distributed/namebaseddiscoverer" "github.com/consensys/linea-monorepo/prover/protocol/ifaces" "github.com/consensys/linea-monorepo/prover/protocol/wizard" "github.com/stretchr/testify/require" @@ -125,7 +125,7 @@ func TestPermutation(t *testing.T) { // test-case.
initialComp := wizard.Compile(tc.DefineFunc) - disc := namebaseddiscoverer.PeriodSeperatingModuleDiscoverer{} + disc := discoverer.PeriodSeperatingModuleDiscoverer{} disc.Analyze(initialComp) // This declares a compiled IOP with only the columns of the module A diff --git a/prover/protocol/distributed/compiler/permutation/settings.go b/prover/protocol/distributed/compiler/permutation/settings.go index 59de8d1f3..c1b153cee 100644 --- a/prover/protocol/distributed/compiler/permutation/settings.go +++ b/prover/protocol/distributed/compiler/permutation/settings.go @@ -1,8 +1,8 @@ package dist_permutation -import "github.com/consensys/linea-monorepo/prover/protocol/distributed/namebaseddiscoverer" +import discoverer "github.com/consensys/linea-monorepo/prover/protocol/distributed/namebaseddiscoverer" type Settings struct { // Name of the target module - TargetModuleName namebaseddiscoverer.ModuleName + TargetModuleName discoverer.ModuleName } diff --git a/prover/protocol/distributed/compiler/projection/projection.go b/prover/protocol/distributed/compiler/projection/projection.go index 20f705d66..ffac65bfb 100644 --- a/prover/protocol/distributed/compiler/projection/projection.go +++ b/prover/protocol/distributed/compiler/projection/projection.go @@ -1,17 +1,551 @@ -package projection +package dist_projection import ( + "math/big" + + "github.com/consensys/linea-monorepo/prover/crypto/mimc" + "github.com/consensys/linea-monorepo/prover/maths/common/poly" + "github.com/consensys/linea-monorepo/prover/maths/field" + "github.com/consensys/linea-monorepo/prover/protocol/accessors" "github.com/consensys/linea-monorepo/prover/protocol/coin" + "github.com/consensys/linea-monorepo/prover/protocol/column" + "github.com/consensys/linea-monorepo/prover/protocol/distributed" + "github.com/consensys/linea-monorepo/prover/protocol/distributed/constants" + discoverer "github.com/consensys/linea-monorepo/prover/protocol/distributed/namebaseddiscoverer" + "github.com/consensys/linea-monorepo/prover/protocol/ifaces" + "github.com/consensys/linea-monorepo/prover/protocol/query" "github.com/consensys/linea-monorepo/prover/protocol/wizard" + "github.com/consensys/linea-monorepo/prover/protocol/wizardutils" + "github.com/consensys/linea-monorepo/prover/symbolic" + "github.com/consensys/linea-monorepo/prover/utils" +) + +// Used for deriving names of queries and coins +const ( + distProjectionStr = "DISTRIBUTED_PROJECTION" + MaxNumOfQueriesPerModule = 10 ) -// CompileDist compiles the projection queries distributedly. -// It receives a compiledIOP object relevant to a segment. -// The seed is a random coin from randomness beacon (FS of all LPP commitments). -// All the compilation steps are similar to the permutation compilation apart from: -// - random coins \alpha and \gamma are generated from the seed (and the tableName). -// - no verifierAction is needed over the ZOpening. -// - ZOpenings are declared as public input. -func CompileDist(comp *wizard.CompiledIOP, seed coin.Info) { +type DistributeProjectionCtx struct { + // List of all projection queries alloted to the segments + DistProjectionInput []*query.DistributedProjectionInput + // List of Evaluation Coins per query + EvalCoins []coin.Info + // The module name for which we are processing the distributed projection query + TargetModuleName string + // Query stores the [query.DistributedProjection] generated by the compilation + Query query.DistributedProjection + // LastRoundPerm indicates the highest round at which a compiled projection + // occurs. 
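+ // The distributed projection query itself is registered one round later, at
+ // LastRoundProjection+1, together with its evaluation coins.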
+	LastRoundProjection int
+	// number of segments per module
+	NumSegmentsPerModule int
+	// original projection queries to extract filters
+	QIds []query.Projection
+}
+
+// NewDistributeProjectionCtx processes all the projection queries from the initialComp
+// and registers a DistributedProjection query for the target module using the module
+// discoverer
+func NewDistributeProjectionCtx(
+	targetModuleName discoverer.ModuleName,
+	initialComp, moduleComp *wizard.CompiledIOP,
+	disc distributed.ModuleDiscoverer,
+	numSegmentPerModule int,
+) *DistributeProjectionCtx {
+	var (
+		p = &DistributeProjectionCtx{
+			DistProjectionInput:  make([]*query.DistributedProjectionInput, 0, MaxNumOfQueriesPerModule),
+			EvalCoins:            make([]coin.Info, 0, MaxNumOfQueriesPerModule),
+			TargetModuleName:     targetModuleName,
+			LastRoundProjection:  getLastRoundProjection(initialComp),
+			NumSegmentsPerModule: numSegmentPerModule,
+			QIds:                 make([]query.Projection, 0, MaxNumOfQueriesPerModule),
+		}
+		numRounds = initialComp.NumRounds()
+		qId       = p.QueryID()
+	)
+	if p.LastRoundProjection < 0 {
+		return p
+	}
+
+	/*
+		Handles the projection checks
+	*/
+	for round := 0; round < numRounds; round++ {
+		queries := initialComp.QueriesNoParams.AllKeysAt(round)
+		for _, qName := range queries {
+
+			// Skip if it was already compiled
+			if initialComp.QueriesNoParams.IsIgnored(qName) {
+				continue
+			}
+
+			q_, ok := initialComp.QueriesNoParams.Data(qName).(query.Projection)
+			if !ok {
+				continue
+			}
+			var (
+				onlyA     = (disc.FindModule(q_.Inp.ColumnA[0]) == targetModuleName) && (disc.FindModule(q_.Inp.ColumnB[0]) != targetModuleName)
+				onlyB     = (disc.FindModule(q_.Inp.ColumnA[0]) != targetModuleName) && (disc.FindModule(q_.Inp.ColumnB[0]) == targetModuleName)
+				bothAAndB = (disc.FindModule(q_.Inp.ColumnA[0]) == targetModuleName) && (disc.FindModule(q_.Inp.ColumnB[0]) == targetModuleName)
+			)
+			if bothAAndB {
+				check(q_.Inp.ColumnA, disc, targetModuleName)
+				check(q_.Inp.ColumnB, disc, targetModuleName)
+				p.push(moduleComp, q_, true, true)
+			} else if onlyA {
+				check(q_.Inp.ColumnA, disc, targetModuleName)
+				p.push(moduleComp, q_, true, false)
+			} else if onlyB {
+				check(q_.Inp.ColumnB, disc, targetModuleName)
+				p.push(moduleComp, q_, false, true)
+			} else {
+				continue
+			}
+		}
+	}
+	// We register the distributed projection query at round LastRoundProjection+1 because
+	// the alpha and beta coins as well as the query parameters are assigned at that round
+	p.Query = moduleComp.InsertDistributedProjection(p.LastRoundProjection+1, qId, p.DistProjectionInput)
+
+	moduleComp.RegisterProverAction(p.LastRoundProjection+1, p)
+
+	// declare [query.DistributedProjectionParams] as [wizard.PublicInput]
+	moduleComp.PublicInputs = append(moduleComp.PublicInputs,
+		wizard.PublicInput{
+			Name: constants.DistributedProjectionPublicInput,
+			Acc:  accessors.NewDistributedProjectionAccessor(p.Query),
+		})
+	return p
+
+}
+
+// check verifies that all the columns of the projection query belong to the target module
+func check(cols []ifaces.Column,
+	disc distributed.ModuleDiscoverer,
+	targetModuleName discoverer.ModuleName,
+) error {
+	for _, col := range cols {
+		if disc.FindModule(col) != targetModuleName {
+			utils.Panic("unsupported projection query, colName: %v, target: %v", col.GetColID(), targetModuleName)
+		}
+	}
+	return nil
+}
+
+// push appends a new DistributedProjectionInput to the DistProjectionInput slice
+func (p *DistributeProjectionCtx) push(comp *wizard.CompiledIOP, q query.Projection, isA, isB bool) {
+	var (
+		isMultiColumn = len(q.Inp.ColumnA) > 1
+		alphaName     = coin.Namef("%v_%v",
"MERGING_COIN"+string(q.ID), "FieldFromSeed") + betaName = coin.Namef("%v_%v", "EVAL_COIN"+string(q.ID), "FieldFromSeed") + alpha coin.Info + beta coin.Info + ) + // Register alpha and beta + if isMultiColumn { + if comp.Coins.Exists(alphaName) { + alpha = comp.Coins.Data(alphaName) + } else { + alpha = comp.InsertCoin(p.LastRoundProjection+1, alphaName, coin.FieldFromSeed) + } + } + if comp.Coins.Exists(betaName) { + beta = comp.Coins.Data(betaName) + } else { + beta = comp.InsertCoin(p.LastRoundProjection+1, betaName, coin.FieldFromSeed) + } + p.EvalCoins = append(p.EvalCoins, beta) + p.QIds = append(p.QIds, q) + // Push the DistributedProjectionInput + if isA && isB { + fA, _, _ := wizardutils.AsExpr(comp.Columns.GetHandle(q.Inp.FilterA.GetColID())) + fB, _, _ := wizardutils.AsExpr(comp.Columns.GetHandle(q.Inp.FilterB.GetColID())) + var ( + colA = make([]ifaces.Column, 0, len(q.Inp.ColumnA)) + colB = make([]ifaces.Column, 0, len(q.Inp.ColumnB)) + ) + for i := 0; i < len(q.Inp.ColumnA); i++ { + colA = append(colA, comp.Columns.GetHandle(q.Inp.ColumnA[i].GetColID())) + } + for i := 0; i < len(q.Inp.ColumnB); i++ { + colB = append(colB, comp.Columns.GetHandle(q.Inp.ColumnB[i].GetColID())) + } + p.DistProjectionInput = append(p.DistProjectionInput, &query.DistributedProjectionInput{ + ColumnA: wizardutils.RandLinCombColSymbolic(alpha, colA), + ColumnB: wizardutils.RandLinCombColSymbolic(alpha, colB), + FilterA: fA, + FilterB: fB, + SizeA: comp.Columns.GetSize(q.Inp.FilterA.GetColID()), + SizeB: comp.Columns.GetSize(q.Inp.FilterB.GetColID()), + EvalCoin: beta.Name, + IsAInModule: true, + IsBInModule: true, + }) + } else if isA { + fA, _, _ := wizardutils.AsExpr(comp.Columns.GetHandle(q.Inp.FilterA.GetColID())) + var ( + colA = make([]ifaces.Column, 0, len(q.Inp.ColumnA)) + ) + for i := 0; i < len(q.Inp.ColumnA); i++ { + colA = append(colA, comp.Columns.GetHandle(q.Inp.ColumnA[i].GetColID())) + } + p.DistProjectionInput = append(p.DistProjectionInput, &query.DistributedProjectionInput{ + ColumnA: wizardutils.RandLinCombColSymbolic(alpha, colA), + ColumnB: symbolic.NewConstant(1), + FilterA: fA, + FilterB: symbolic.NewConstant(1), + SizeA: comp.Columns.GetSize(q.Inp.FilterA.GetColID()), + EvalCoin: beta.Name, + IsAInModule: true, + IsBInModule: false, + }) + } else if isB { + fB, _, _ := wizardutils.AsExpr(comp.Columns.GetHandle(q.Inp.FilterB.GetColID())) + var ( + colB = make([]ifaces.Column, 0, len(q.Inp.ColumnB)) + ) + for i := 0; i < len(q.Inp.ColumnB); i++ { + colB = append(colB, comp.Columns.GetHandle(q.Inp.ColumnB[i].GetColID())) + } + p.DistProjectionInput = append(p.DistProjectionInput, &query.DistributedProjectionInput{ + ColumnA: symbolic.NewConstant(1), + ColumnB: wizardutils.RandLinCombColSymbolic(alpha, colB), + FilterA: symbolic.NewConstant(1), + FilterB: fB, + SizeB: comp.Columns.GetSize(q.Inp.FilterB.GetColID()), + EvalCoin: beta.Name, + IsAInModule: false, + IsBInModule: true, + }) + } else { + panic("Invalid distributed projection query while initial pushing") + } +} + +// Run implements [wizard.ProverAction] interface +func (p *DistributeProjectionCtx) Run(run *wizard.ProverRuntime) { + p.assignSumNumOnes(run) + run.AssignDistributedProjection(p.Query.ID, query.DistributedProjectionParams{ + ScaledHorner: p.computeScaledHorner(run), + HashCumSumOnePrev: p.computeHashPrev(), + HashCumSumOneCurr: p.computeHashCurr(), + }) +} + +// assignSumNumOnes assigns the cumulative number of ones in the previous segments as well as the current segment +func (p *DistributeProjectionCtx) 
assignSumNumOnes(run *wizard.ProverRuntime) { + var ( + initialRuntime = run.ParentRuntime + segId = run.ProverID + one = field.One() + bigOne = big.NewInt(1) + ) + for elemIndex, inp := range p.DistProjectionInput { + if inp.IsAInModule && inp.IsBInModule { + var ( + fA = initialRuntime.GetColumn(p.QIds[elemIndex].Inp.FilterA.GetColID()) + fB = initialRuntime.GetColumn(p.QIds[elemIndex].Inp.FilterB.GetColID()) + segSizeA = utils.NextPowerOfTwo(fA.Len() / p.NumSegmentsPerModule) + segSizeB = utils.NextPowerOfTwo(fB.Len() / p.NumSegmentsPerModule) + numOnesCurrA = field.Zero() + numOnesCurrB = field.Zero() + cumSumOnesPrevA = *big.NewInt(0) + cumSumOnesPrevB = *big.NewInt(0) + ) + if segId == 0 { + var ( + fACurr = fA.SubVector(segId*segSizeA, (segId+1)*segSizeA).IntoRegVecSaveAlloc() + fBCurr = fB.SubVector(segId*segSizeB, (segId+1)*segSizeB).IntoRegVecSaveAlloc() + ) + for i := 0; i < len(fACurr); i++ { + if fACurr[i] == one { + numOnesCurrA.Add(&numOnesCurrA, &one) + } + } + for i := 0; i < len(fBCurr); i++ { + if fBCurr[i] == one { + numOnesCurrB.Add(&numOnesCurrB, &one) + } + } + p.DistProjectionInput[elemIndex].CumulativeNumOnesPrevSegmentsA = cumSumOnesPrevA + p.DistProjectionInput[elemIndex].CumulativeNumOnesPrevSegmentsB = cumSumOnesPrevB + p.DistProjectionInput[elemIndex].CurrNumOnesA = numOnesCurrA + p.DistProjectionInput[elemIndex].CurrNumOnesB = numOnesCurrB + } else { + var ( + fAPrev = fA.SubVector(0, segId*segSizeA).IntoRegVecSaveAlloc() + fBPrev = fB.SubVector(0, segId*segSizeB).IntoRegVecSaveAlloc() + fACurr = fA.SubVector(segId*segSizeA, (segId+1)*segSizeA).IntoRegVecSaveAlloc() + fBCurr = fB.SubVector(segId*segSizeB, (segId+1)*segSizeB).IntoRegVecSaveAlloc() + ) + for i := 0; i < len(fAPrev); i++ { + if fAPrev[i] == one { + cumSumOnesPrevA.Add(&cumSumOnesPrevA, bigOne) + } + } + for i := 0; i < len(fBPrev); i++ { + if fBPrev[i] == one { + cumSumOnesPrevB.Add(&cumSumOnesPrevB, bigOne) + } + } + for i := 0; i < len(fACurr); i++ { + if fACurr[i] == one { + numOnesCurrA.Add(&numOnesCurrA, &one) + } + } + for i := 0; i < len(fBCurr); i++ { + if fBCurr[i] == one { + numOnesCurrB.Add(&numOnesCurrB, &one) + } + } + p.DistProjectionInput[elemIndex].CumulativeNumOnesPrevSegmentsA = cumSumOnesPrevA + p.DistProjectionInput[elemIndex].CumulativeNumOnesPrevSegmentsB = cumSumOnesPrevB + p.DistProjectionInput[elemIndex].CurrNumOnesA = numOnesCurrA + p.DistProjectionInput[elemIndex].CurrNumOnesB = numOnesCurrB + } + } + if inp.IsAInModule && !inp.IsBInModule { + var ( + fA = initialRuntime.GetColumn(p.QIds[elemIndex].Inp.FilterA.GetColID()) + segSizeA = utils.NextPowerOfTwo(fA.Len() / p.NumSegmentsPerModule) + numOnesCurrA = field.Zero() + cumSumOnesPrevA = *big.NewInt(0) + ) + if segId == 0 { + var ( + fACurr = fA.SubVector(segId*segSizeA, (segId+1)*segSizeA).IntoRegVecSaveAlloc() + ) + + for i := 0; i < len(fACurr); i++ { + if fACurr[i] == one { + numOnesCurrA.Add(&numOnesCurrA, &one) + } + } + p.DistProjectionInput[elemIndex].CumulativeNumOnesPrevSegmentsA = cumSumOnesPrevA + p.DistProjectionInput[elemIndex].CurrNumOnesA = numOnesCurrA + } else { + var ( + fAPrev = fA.SubVector(0, segId*segSizeA).IntoRegVecSaveAlloc() + fACurr = fA.SubVector(segId*segSizeA, (segId+1)*segSizeA).IntoRegVecSaveAlloc() + ) + for i := 0; i < len(fAPrev); i++ { + if fAPrev[i] == one { + cumSumOnesPrevA.Add(&cumSumOnesPrevA, bigOne) + } + } + for i := 0; i < len(fACurr); i++ { + if fACurr[i] == one { + numOnesCurrA.Add(&numOnesCurrA, &one) + } + } + 
p.DistProjectionInput[elemIndex].CumulativeNumOnesPrevSegmentsA = cumSumOnesPrevA + p.DistProjectionInput[elemIndex].CurrNumOnesA = numOnesCurrA + } + } + if !inp.IsAInModule && inp.IsBInModule { + var ( + fB = initialRuntime.GetColumn(p.QIds[elemIndex].Inp.FilterB.GetColID()) + segSizeB = utils.NextPowerOfTwo(fB.Len() / p.NumSegmentsPerModule) + numOnesCurrB field.Element + cumSumOnesPrevB = *big.NewInt(0) + ) + if segId == 0 { + var ( + fBCurr = fB.SubVector(segId*segSizeB, (segId+1)*segSizeB).IntoRegVecSaveAlloc() + ) + + for i := 0; i < len(fBCurr); i++ { + if fBCurr[i] == one { + numOnesCurrB.Add(&numOnesCurrB, &one) + } + } + p.DistProjectionInput[elemIndex].CumulativeNumOnesPrevSegmentsB = cumSumOnesPrevB + p.DistProjectionInput[elemIndex].CurrNumOnesB = numOnesCurrB + } else { + var ( + fBPrev = fB.SubVector(0, segId*segSizeB).IntoRegVecSaveAlloc() + fBCurr = fB.SubVector(segId*segSizeB, (segId+1)*segSizeB).IntoRegVecSaveAlloc() + ) + for i := 0; i < len(fBPrev); i++ { + if fBPrev[i] == one { + cumSumOnesPrevB.Add(&cumSumOnesPrevB, bigOne) + } + } + for i := 0; i < len(fBCurr); i++ { + if fBCurr[i] == one { + numOnesCurrB.Add(&numOnesCurrB, &one) + } + } + p.DistProjectionInput[elemIndex].CumulativeNumOnesPrevSegmentsB = cumSumOnesPrevB + p.DistProjectionInput[elemIndex].CurrNumOnesB = numOnesCurrB + } + } + } +} + +// computeScaledHorner computes the parameter of the DistributedProjection query +func (p *DistributeProjectionCtx) computeScaledHorner(run *wizard.ProverRuntime) field.Element { + var ( + queryParam = field.Zero() + ) + for elemIndex, inp := range p.DistProjectionInput { + var ( + elemParam = field.Zero() + ) + if inp.IsAInModule && inp.IsBInModule { + var ( + colABoard = inp.ColumnA.Board() + colBBoard = inp.ColumnB.Board() + filterABorad = inp.FilterA.Board() + filterBBoard = inp.FilterB.Board() + colA = column.EvalExprColumn(run, colABoard).IntoRegVecSaveAlloc() + colB = column.EvalExprColumn(run, colBBoard).IntoRegVecSaveAlloc() + filterA = column.EvalExprColumn(run, filterABorad).IntoRegVecSaveAlloc() + filterB = column.EvalExprColumn(run, filterBBoard).IntoRegVecSaveAlloc() + multA, multB, hornerA, hornerB field.Element + ) + // Add hornerA after scaling + hornerATrace := poly.GetHornerTrace(colA, filterA, run.GetRandomCoinField(p.EvalCoins[elemIndex].Name)) + multA = run.GetRandomCoinField(p.EvalCoins[elemIndex].Name) + multA.Exp(multA, &inp.CumulativeNumOnesPrevSegmentsA) + hornerA = hornerATrace[0] + hornerA.Mul(&hornerA, &multA) + elemParam.Add(&elemParam, &hornerA) + // Subtract hornerB after scaling + hornerBTrace := poly.GetHornerTrace(colB, filterB, run.GetRandomCoinField(p.EvalCoins[elemIndex].Name)) + hornerB = hornerBTrace[0] + multB = run.GetRandomCoinField(p.EvalCoins[elemIndex].Name) + multB.Exp(multB, &inp.CumulativeNumOnesPrevSegmentsB) + hornerB.Mul(&hornerB, &multB) + elemParam.Sub(&elemParam, &hornerB) + } else if inp.IsAInModule && !inp.IsBInModule { + var ( + colABoard = inp.ColumnA.Board() + filterABorad = inp.FilterA.Board() + colA = column.EvalExprColumn(run, colABoard).IntoRegVecSaveAlloc() + filterA = column.EvalExprColumn(run, filterABorad).IntoRegVecSaveAlloc() + multA, hornerA field.Element + ) + // Add hornerA after scaling + hornerATrace := poly.GetHornerTrace(colA, filterA, run.GetRandomCoinField(p.EvalCoins[elemIndex].Name)) + multA = run.GetRandomCoinField(p.EvalCoins[elemIndex].Name) + multA.Exp(multA, &inp.CumulativeNumOnesPrevSegmentsA) + hornerA = hornerATrace[0] + hornerA.Mul(&hornerA, &multA) + elemParam.Add(&elemParam, 
&hornerA) + } else if !inp.IsAInModule && inp.IsBInModule { + var ( + colBBoard = inp.ColumnB.Board() + filterBBorad = inp.FilterB.Board() + colB = column.EvalExprColumn(run, colBBoard).IntoRegVecSaveAlloc() + filterB = column.EvalExprColumn(run, filterBBorad).IntoRegVecSaveAlloc() + multB, hornerB field.Element + ) + // Subtract hornerB after scaling + hornerBTrace := poly.GetHornerTrace(colB, filterB, run.GetRandomCoinField(p.EvalCoins[elemIndex].Name)) + hornerB = hornerBTrace[0] + multB = run.GetRandomCoinField(p.EvalCoins[elemIndex].Name) + multB.Exp(multB, &inp.CumulativeNumOnesPrevSegmentsB) + hornerB.Mul(&hornerB, &multB) + elemParam.Sub(&elemParam, &hornerB) + } else { + panic("Invalid distributed projection query encountered during param evaluation") + } + queryParam.Add(&queryParam, &elemParam) + } + return queryParam +} + +// computeHashPrev computes the hash of the cumulative number of ones in the previous segments +func (p *DistributeProjectionCtx) computeHashPrev() field.Element { + var ( + oldState = field.Zero() + ) + for _, inp := range p.DistProjectionInput { + if inp.IsAInModule && inp.IsBInModule { + sumA := field.NewElement(inp.CumulativeNumOnesPrevSegmentsA.Uint64()) + oldState = mimc.BlockCompression(oldState, sumA) + sumB := field.NewElement(inp.CumulativeNumOnesPrevSegmentsB.Uint64()) + oldState = mimc.BlockCompression(oldState, sumB) + } else if inp.IsAInModule && !inp.IsBInModule { + sumA := field.NewElement(inp.CumulativeNumOnesPrevSegmentsA.Uint64()) + oldState = mimc.BlockCompression(oldState, sumA) + } else if !inp.IsAInModule && inp.IsBInModule { + sumB := field.NewElement(inp.CumulativeNumOnesPrevSegmentsB.Uint64()) + oldState = mimc.BlockCompression(oldState, sumB) + } else { + panic("Invalid distributed projection query encountered during previous hash computation") + } + } + return oldState +} + +// computeHashCurr computes the hash of the cumulative number of ones in the current segment +func (p *DistributeProjectionCtx) computeHashCurr() field.Element { + var ( + oldState = field.Zero() + ) + for _, inp := range p.DistProjectionInput { + if inp.IsAInModule && inp.IsBInModule { + sumA := field.NewElement(inp.CumulativeNumOnesPrevSegmentsA.Uint64()) + sumA.Add(&sumA, &inp.CurrNumOnesA) + oldState = mimc.BlockCompression(oldState, sumA) + sumB := field.NewElement(inp.CumulativeNumOnesPrevSegmentsB.Uint64()) + sumB.Add(&sumB, &inp.CurrNumOnesB) + oldState = mimc.BlockCompression(oldState, sumB) + } else if inp.IsAInModule && !inp.IsBInModule { + sumA := field.NewElement(inp.CumulativeNumOnesPrevSegmentsA.Uint64()) + sumA.Add(&sumA, &inp.CurrNumOnesA) + oldState = mimc.BlockCompression(oldState, sumA) + } else if !inp.IsAInModule && inp.IsBInModule { + sumB := field.NewElement(inp.CumulativeNumOnesPrevSegmentsB.Uint64()) + sumB.Add(&sumB, &inp.CurrNumOnesB) + oldState = mimc.BlockCompression(oldState, sumB) + } else { + panic("Invalid distributed projection query encountered during current hash computation") + } + } + return oldState +} + +// deriveName constructs a name for the DistributeProjectionCtx context +func deriveName[R ~string](q ifaces.QueryID, ss ...any) R { + ss = append([]any{distProjectionStr, q}, ss...) + return wizardutils.DeriveName[R](ss...) 
+} + +// QueryID formats and returns a name of the [query.DistributedProjection] generated by the current context +func (p *DistributeProjectionCtx) QueryID() ifaces.QueryID { + return deriveName[ifaces.QueryID](ifaces.QueryID(p.TargetModuleName)) +} + +// getLastRoundProjection scans the initialComp and looks for uncompiled projection queries. It returns +// the highest round found for a matched projection query. It returns -1 if no queries are found. +func getLastRoundProjection(initialComp *wizard.CompiledIOP) int { + + var ( + lastRound = -1 + numRounds = initialComp.NumRounds() + ) + + for round := 0; round < numRounds; round++ { + queries := initialComp.QueriesNoParams.AllKeysAt(round) + for _, qName := range queries { + + if initialComp.QueriesNoParams.IsIgnored(qName) { + continue + } + + _, ok := initialComp.QueriesNoParams.Data(qName).(query.Projection) + if !ok { + continue + } + + lastRound = max(lastRound, round) + } + } + return lastRound } diff --git a/prover/protocol/distributed/compiler/projection/projection_test.go b/prover/protocol/distributed/compiler/projection/projection_test.go new file mode 100644 index 000000000..c9da25f4e --- /dev/null +++ b/prover/protocol/distributed/compiler/projection/projection_test.go @@ -0,0 +1,334 @@ +package dist_projection_test + +import ( + "errors" + "testing" + + "github.com/consensys/linea-monorepo/prover/maths/common/smartvectors" + "github.com/consensys/linea-monorepo/prover/maths/field" + "github.com/consensys/linea-monorepo/prover/protocol/compiler/distributedprojection" + "github.com/consensys/linea-monorepo/prover/protocol/compiler/dummy" + "github.com/consensys/linea-monorepo/prover/protocol/distributed" + dist_projection "github.com/consensys/linea-monorepo/prover/protocol/distributed/compiler/projection" + "github.com/consensys/linea-monorepo/prover/protocol/distributed/constants" + "github.com/consensys/linea-monorepo/prover/protocol/distributed/lpp" + md "github.com/consensys/linea-monorepo/prover/protocol/distributed/namebaseddiscoverer" + "github.com/consensys/linea-monorepo/prover/protocol/ifaces" + "github.com/consensys/linea-monorepo/prover/protocol/query" + "github.com/consensys/linea-monorepo/prover/protocol/wizard" + "github.com/stretchr/testify/require" +) + +type AllVerifierRuntimes struct { + RuntimesA []wizard.Runtime + RuntimesB []wizard.Runtime + RuntimesC []wizard.Runtime +} + +func TestDistributeProjection(t *testing.T) { + const ( + numSegModuleA = 4 + numSegModuleB = 4 + numSegModuleC = 4 + ) + var ( + allVerfiers = AllVerifierRuntimes{} + moduleAName = "moduleA" + moduleBName = "moduleB" + moduleCName = "moduleC" + flagSizeA = 512 + flagSizeB = 256 + flagA, flagB, columnA, columnB, flagC, columnC ifaces.Column + colA0, colA1, colA2, colB0, colB1, colB2, colC0, colC1, colC2 ifaces.Column + ) + testcases := []struct { + Name string + DefineFunc func(builder *wizard.Builder) + InitialProverFunc func(run *wizard.ProverRuntime) + }{ + { + DefineFunc: func(builder *wizard.Builder) { + flagA = builder.RegisterCommit(ifaces.ColID("moduleA.FilterA"), flagSizeA) + flagB = builder.RegisterCommit(ifaces.ColID("moduleB.FliterB"), flagSizeB) + flagC = builder.RegisterCommit(ifaces.ColID("moduleC.FliterC"), flagSizeB) + columnA = builder.RegisterCommit(ifaces.ColID("moduleA.ColumnA"), flagSizeA) + columnB = builder.RegisterCommit(ifaces.ColID("moduleB.ColumnB"), flagSizeB) + columnC = builder.RegisterCommit(ifaces.ColID("moduleC.ColumnC"), flagSizeB) + _ = builder.InsertProjection("ProjectionTest-A-B", + 
query.ProjectionInput{ColumnA: []ifaces.Column{columnA}, ColumnB: []ifaces.Column{columnB}, FilterA: flagA, FilterB: flagB}) + _ = builder.InsertProjection("ProjectionTest-C-A", + query.ProjectionInput{ColumnA: []ifaces.Column{columnC}, ColumnB: []ifaces.Column{columnA}, FilterA: flagC, FilterB: flagA}) + + }, + InitialProverFunc: func(run *wizard.ProverRuntime) { + // assign filters and columns + var ( + flagAWit = make([]field.Element, flagSizeA) + columnAWit = make([]field.Element, flagSizeA) + flagBWit = make([]field.Element, flagSizeB) + columnBWit = make([]field.Element, flagSizeB) + flagCWit = make([]field.Element, flagSizeB) + columnCWit = make([]field.Element, flagSizeB) + ) + for i := 0; i < 10; i++ { + flagAWit[i] = field.One() + columnAWit[i] = field.NewElement(uint64(i)) + } + for i := flagSizeB - 10; i < flagSizeB; i++ { + flagBWit[i] = field.One() + flagCWit[i] = field.One() + columnBWit[i] = field.NewElement(uint64(i - (flagSizeB - 10))) + columnCWit[i] = field.NewElement(uint64(i - (flagSizeB - 10))) + } + run.AssignColumn(flagA.GetColID(), smartvectors.RightZeroPadded(flagAWit, flagSizeA)) + run.AssignColumn(flagB.GetColID(), smartvectors.RightZeroPadded(flagBWit, flagSizeB)) + run.AssignColumn(flagC.GetColID(), smartvectors.RightZeroPadded(flagCWit, flagSizeB)) + run.AssignColumn(columnB.GetColID(), smartvectors.RightZeroPadded(columnBWit, flagSizeB)) + run.AssignColumn(columnA.GetColID(), smartvectors.RightZeroPadded(columnAWit, flagSizeA)) + run.AssignColumn(columnC.GetColID(), smartvectors.RightZeroPadded(columnCWit, flagSizeB)) + }, + }, + { + Name: "distribute-projection-multiple_projections-multi-columns", + DefineFunc: func(builder *wizard.Builder) { + flagA = builder.RegisterCommit(ifaces.ColID("moduleA.FilterA"), flagSizeA) + flagB = builder.RegisterCommit(ifaces.ColID("moduleB.FliterB"), flagSizeB) + flagC = builder.RegisterCommit(ifaces.ColID("moduleC.FliterB"), flagSizeB) + colA0 = builder.RegisterCommit(ifaces.ColID("moduleA.ColumnA0"), flagSizeA) + colA1 = builder.RegisterCommit(ifaces.ColID("moduleA.ColumnA1"), flagSizeA) + colA2 = builder.RegisterCommit(ifaces.ColID("moduleA.ColumnA2"), flagSizeA) + colB0 = builder.RegisterCommit(ifaces.ColID("moduleB.ColumnB0"), flagSizeB) + colB1 = builder.RegisterCommit(ifaces.ColID("moduleB.ColumnB1"), flagSizeB) + colB2 = builder.RegisterCommit(ifaces.ColID("moduleB.ColumnB2"), flagSizeB) + colC0 = builder.RegisterCommit(ifaces.ColID("moduleC.ColumnC0"), flagSizeB) + colC1 = builder.RegisterCommit(ifaces.ColID("moduleC.ColumnC1"), flagSizeB) + colC2 = builder.RegisterCommit(ifaces.ColID("moduleC.ColumnC2"), flagSizeB) + _ = builder.InsertProjection("ProjectionTest-A-B-Multicolum", + query.ProjectionInput{ColumnA: []ifaces.Column{colA0, colA1, colA2}, ColumnB: []ifaces.Column{colB0, colB1, colB2}, FilterA: flagA, FilterB: flagB}) + _ = builder.InsertProjection("ProjectionTest-C-A-Multicolumn", + query.ProjectionInput{ColumnA: []ifaces.Column{colC0, colC1, colC2}, ColumnB: []ifaces.Column{colA0, colA1, colA2}, FilterA: flagC, FilterB: flagA}) + + }, + InitialProverFunc: func(run *wizard.ProverRuntime) { + // assign filters and columns + var ( + flagAWit = make([]field.Element, flagSizeA) + flagBWit = make([]field.Element, flagSizeB) + flagCWit = make([]field.Element, flagSizeB) + colA0Wit = make([]field.Element, flagSizeA) + colA1Wit = make([]field.Element, flagSizeA) + colA2Wit = make([]field.Element, flagSizeA) + colB0Wit = make([]field.Element, flagSizeB) + colB1Wit = make([]field.Element, flagSizeB) + colB2Wit = 
make([]field.Element, flagSizeB) + colC0Wit = make([]field.Element, flagSizeB) + colC1Wit = make([]field.Element, flagSizeB) + colC2Wit = make([]field.Element, flagSizeB) + ) + for i := 0; i < 10; i++ { + flagAWit[i] = field.One() + colA0Wit[i] = field.NewElement(uint64(i)) + colA1Wit[i] = field.NewElement(uint64(i + 1)) + colA2Wit[i] = field.NewElement(uint64(i + 2)) + } + for i := flagSizeB - 10; i < flagSizeB; i++ { + flagBWit[i] = field.One() + flagCWit[i] = field.One() + colB0Wit[i] = field.NewElement(uint64(i - (flagSizeB - 10))) + colC0Wit[i] = field.NewElement(uint64(i - (flagSizeB - 10))) + colB1Wit[i] = field.NewElement(uint64(i + 1 - (flagSizeB - 10))) + colC1Wit[i] = field.NewElement(uint64(i + 1 - (flagSizeB - 10))) + colB2Wit[i] = field.NewElement(uint64(i + 2 - (flagSizeB - 10))) + colC2Wit[i] = field.NewElement(uint64(i + 2 - (flagSizeB - 10))) + } + run.AssignColumn(flagA.GetColID(), smartvectors.RightZeroPadded(flagAWit, flagSizeA)) + run.AssignColumn(flagB.GetColID(), smartvectors.RightZeroPadded(flagBWit, flagSizeB)) + run.AssignColumn(flagC.GetColID(), smartvectors.RightZeroPadded(flagCWit, flagSizeB)) + run.AssignColumn(colA0.GetColID(), smartvectors.RightZeroPadded(colA0Wit, flagSizeA)) + run.AssignColumn(colA1.GetColID(), smartvectors.RightZeroPadded(colA1Wit, flagSizeA)) + run.AssignColumn(colA2.GetColID(), smartvectors.RightZeroPadded(colA2Wit, flagSizeA)) + run.AssignColumn(colB0.GetColID(), smartvectors.RightZeroPadded(colB0Wit, flagSizeB)) + run.AssignColumn(colB1.GetColID(), smartvectors.RightZeroPadded(colB1Wit, flagSizeB)) + run.AssignColumn(colB2.GetColID(), smartvectors.RightZeroPadded(colB2Wit, flagSizeB)) + run.AssignColumn(colC0.GetColID(), smartvectors.RightZeroPadded(colC0Wit, flagSizeB)) + run.AssignColumn(colC1.GetColID(), smartvectors.RightZeroPadded(colC1Wit, flagSizeB)) + run.AssignColumn(colC2.GetColID(), smartvectors.RightZeroPadded(colC2Wit, flagSizeB)) + }, + }, + } + for _, tc := range testcases { + + t.Run(tc.Name, func(t *testing.T) { + + // initialComp is defined according to the define function provided by the + // test-case. + initialComp := wizard.Compile(tc.DefineFunc) + + // apply the LPP relevant compilers and generate the seed for initialComp + lppComp := lpp.CompileLPPAndGetSeed(initialComp) + + // Initialize the period separating module discoverer + disc := &md.PeriodSeperatingModuleDiscoverer{} + disc.Analyze(initialComp) + + // distribute the columns among modules and segments; this includes also multiplicity columns + // for all the segments from the same module, compiledIOP object is the same. + moduleCompA := distributed.GetFreshSegmentModuleComp( + distributed.SegmentModuleInputs{ + InitialComp: initialComp, + Disc: disc, + ModuleName: moduleAName, + NumSegmentsInModule: numSegModuleA, + }, + ) + + moduleCompB := distributed.GetFreshSegmentModuleComp(distributed.SegmentModuleInputs{ + InitialComp: initialComp, + Disc: disc, + ModuleName: moduleBName, + NumSegmentsInModule: numSegModuleB, + }) + + moduleCompC := distributed.GetFreshSegmentModuleComp(distributed.SegmentModuleInputs{ + InitialComp: initialComp, + Disc: disc, + ModuleName: moduleCName, + NumSegmentsInModule: numSegModuleC, + }) + + // distribute the query LogDerivativeSum among modules. + // The seed is used to generate randomness for each moduleComp. 
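+			// Each call below scans initialComp for the projection queries whose columns
+			// belong to the given module, registers a DistributedProjection query on that
+			// module's compiledIOP, and exposes its parameters through the
+			// constants.DistributedProjectionPublicInput consumed by checkConsistency below.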
+ dist_projection.NewDistributeProjectionCtx(moduleAName, initialComp, moduleCompA, disc, numSegModuleA) + dist_projection.NewDistributeProjectionCtx(moduleBName, initialComp, moduleCompB, disc, numSegModuleB) + dist_projection.NewDistributeProjectionCtx(moduleCName, initialComp, moduleCompC, disc, numSegModuleC) + + // This compiles the log-derivative queries into global/local queries. + wizard.ContinueCompilation(moduleCompA, distributedprojection.CompileDistributedProjection, dummy.Compile) + wizard.ContinueCompilation(moduleCompB, distributedprojection.CompileDistributedProjection, dummy.Compile) + wizard.ContinueCompilation(moduleCompC, distributedprojection.CompileDistributedProjection, dummy.Compile) + + // run the initial runtime + initialRuntime := wizard.ProverOnlyFirstRound(initialComp, tc.InitialProverFunc) + + // compile and verify for lpp-Prover + lppProof := wizard.Prove(lppComp, func(run *wizard.ProverRuntime) { + run.ParentRuntime = initialRuntime + }) + lppVerifierRuntime, valid := wizard.VerifyWithRuntime(lppComp, lppProof) + require.NoError(t, valid) + + // Compile and prove for moduleA + allVerfiers.RuntimesA = make([]wizard.Runtime, 0, numSegModuleA) + for proverID := 0; proverID < numSegModuleA; proverID++ { + proofA := wizard.Prove(moduleCompA, func(run *wizard.ProverRuntime) { + run.ParentRuntime = initialRuntime + // inputs for vertical splitting of the witness + run.ProverID = proverID + }) + runtimeA, validA := wizard.VerifyWithRuntime(moduleCompA, proofA, lppVerifierRuntime) + require.NoError(t, validA) + + allVerfiers.RuntimesA = append(allVerfiers.RuntimesA, runtimeA) + } + + // Compile and prove for moduleB + allVerfiers.RuntimesB = make([]wizard.Runtime, 0, numSegModuleB) + for proverID := 0; proverID < numSegModuleB; proverID++ { + proofB := wizard.Prove(moduleCompB, func(run *wizard.ProverRuntime) { + run.ParentRuntime = initialRuntime + // inputs for vertical splitting of the witness + run.ProverID = proverID + }) + runtimeB, validB := wizard.VerifyWithRuntime(moduleCompB, proofB, lppVerifierRuntime) + require.NoError(t, validB) + + allVerfiers.RuntimesB = append(allVerfiers.RuntimesB, runtimeB) + + } + + // Compile and prove for moduleC + allVerfiers.RuntimesC = make([]wizard.Runtime, 0, numSegModuleC) + for proverID := 0; proverID < numSegModuleC; proverID++ { + proofC := wizard.Prove(moduleCompC, func(run *wizard.ProverRuntime) { + run.ParentRuntime = initialRuntime + // inputs for vertical splitting of the witness + run.ProverID = proverID + }) + runtimeC, validC := wizard.VerifyWithRuntime(moduleCompC, proofC, lppVerifierRuntime) + require.NoError(t, validC) + + allVerfiers.RuntimesC = append(allVerfiers.RuntimesC, runtimeC) + } + + // apply the crosse checks over the public inputs. 
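+			// The cross checks assert that (i) the ScaledHorner public inputs of all
+			// segments of all modules sum to zero and (ii) within each module, the
+			// CumSumCurr of segment i-1 equals the CumSumPrev of segment i, i.e. the
+			// vertical splitting of the filters is consistent across segments.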
+ require.NoError(t, checkConsistency(allVerfiers)) + + }) + + } + +} + +func checkConsistency(allVerRuns AllVerifierRuntimes) error { + + var ( + res = field.Zero() + currCumSumArrayA = make([]field.Element, 0, len(allVerRuns.RuntimesA)) + prevCumSumArrayA = make([]field.Element, 0, len(allVerRuns.RuntimesA)) + currCumSumArrayB = make([]field.Element, 0, len(allVerRuns.RuntimesB)) + prevCumSumArrayB = make([]field.Element, 0, len(allVerRuns.RuntimesB)) + currCumSumArrayC = make([]field.Element, 0, len(allVerRuns.RuntimesC)) + prevCumSumArrayC = make([]field.Element, 0, len(allVerRuns.RuntimesC)) + ) + for _, run := range allVerRuns.RuntimesA { + distributedPubInputs_ := run.GetPublicInput(constants.DistributedProjectionPublicInput) + distributedPubInputs, ok := distributedPubInputs_.(wizard.DistributedProjectionPublicInput) + if !ok { + return errors.New("the distributed projection public input is not valid for module A") + } + res.Add(&res, &distributedPubInputs.ScaledHorner) + currCumSumArrayA = append(currCumSumArrayA, distributedPubInputs.CumSumCurr) + prevCumSumArrayA = append(prevCumSumArrayA, distributedPubInputs.CumSumPrev) + } + for _, run := range allVerRuns.RuntimesB { + distributedPubInputs_ := run.GetPublicInput(constants.DistributedProjectionPublicInput) + distributedPubInputs, ok := distributedPubInputs_.(wizard.DistributedProjectionPublicInput) + if !ok { + return errors.New("the distributed projection public input is not valid for module B") + } + res.Add(&res, &distributedPubInputs.ScaledHorner) + currCumSumArrayB = append(currCumSumArrayB, distributedPubInputs.CumSumCurr) + prevCumSumArrayB = append(prevCumSumArrayB, distributedPubInputs.CumSumPrev) + } + for _, run := range allVerRuns.RuntimesC { + distributedPubInputs_ := run.GetPublicInput(constants.DistributedProjectionPublicInput) + distributedPubInputs, ok := distributedPubInputs_.(wizard.DistributedProjectionPublicInput) + if !ok { + return errors.New("the distributed projection public input is not valid for module C") + } + res.Add(&res, &distributedPubInputs.ScaledHorner) + currCumSumArrayC = append(currCumSumArrayC, distributedPubInputs.CumSumCurr) + prevCumSumArrayC = append(prevCumSumArrayC, distributedPubInputs.CumSumPrev) + } + + if !res.IsZero() { + return errors.New("the distributed projection sums do not cancel each others") + } + for i := 1; i < len(currCumSumArrayA); i++ { + if currCumSumArrayA[i-1] != prevCumSumArrayA[i] { + return errors.New("the vertical splitting for the distributed projection is not consistent for module A") + } + } + for i := 1; i < len(currCumSumArrayB); i++ { + if currCumSumArrayB[i-1] != prevCumSumArrayB[i] { + return errors.New("the vertical splitting for the distributed projection is not consistent for module B") + } + } + for i := 1; i < len(currCumSumArrayC); i++ { + if currCumSumArrayC[i-1] != prevCumSumArrayC[i] { + return errors.New("the vertical splitting for the distributed projection is not consistent for module C") + } + } + + return nil +} diff --git a/prover/protocol/distributed/conglomeration/cross_segment_consistency.go b/prover/protocol/distributed/conglomeration/cross_segment_consistency.go index 431aa9787..911c2ef06 100644 --- a/prover/protocol/distributed/conglomeration/cross_segment_consistency.go +++ b/prover/protocol/distributed/conglomeration/cross_segment_consistency.go @@ -10,7 +10,7 @@ import ( "github.com/consensys/linea-monorepo/prover/protocol/wizard" ) -// crossSegmentCheclk is a verifier action that performs cross-segment checks: +// 
crossSegmentCheck is a verifier action that performs cross-segment checks: // for instance, it checks that the log-derivative sums all sums to 0 and that // the grand product is 1. The goal is to ensure that the lookups, permutations // in the original protocol are satisfied. diff --git a/prover/protocol/distributed/conglomeration/translator.go b/prover/protocol/distributed/conglomeration/translator.go index 4123cf764..11b22ad8e 100644 --- a/prover/protocol/distributed/conglomeration/translator.go +++ b/prover/protocol/distributed/conglomeration/translator.go @@ -254,7 +254,7 @@ func (run *runtimeTranslator) GetSpec() *wizard.CompiledIOP { return run.Rt.GetSpec() } -func (run *runtimeTranslator) GetPublicInput(name string) field.Element { +func (run *runtimeTranslator) GetPublicInput(name string) any { name = run.Prefix + "." + name return run.Rt.GetPublicInput(name) } @@ -264,6 +264,11 @@ func (run *runtimeTranslator) GetGrandProductParams(name ifaces.QueryID) query.G return run.Rt.GetGrandProductParams(name) } +func (run *runtimeTranslator) GetDistributedProjectionParams(name ifaces.QueryID) query.DistributedProjectionParams { + name = ifaces.QueryID(run.Prefix) + "." + name + return run.Rt.GetDistributedProjectionParams(name) +} + func (run *runtimeTranslator) GetLogDerivSumParams(name ifaces.QueryID) query.LogDerivSumParams { name = ifaces.QueryID(run.Prefix) + "." + name return run.Rt.GetLogDerivSumParams(name) @@ -361,6 +366,11 @@ func (run *gnarkRuntimeTranslator) GetLogDerivSumParams(name ifaces.QueryID) que return run.Rt.GetLogDerivSumParams(name) } +func (run *gnarkRuntimeTranslator) GetDistributedProjectionParams(name ifaces.QueryID) query.GnarkDistributedProjectionParams { + name = ifaces.QueryID(run.Prefix) + "." + name + return run.Rt.GetDistributedProjectionParams(name) +} + func (run *gnarkRuntimeTranslator) GetLocalPointEvalParams(name ifaces.QueryID) query.GnarkLocalOpeningParams { name = ifaces.QueryID(run.Prefix) + "." + name return run.Rt.GetLocalPointEvalParams(name) diff --git a/prover/protocol/distributed/constants/constant.go b/prover/protocol/distributed/constants/constant.go index 0a35a13d5..f66ec3b65 100644 --- a/prover/protocol/distributed/constants/constant.go +++ b/prover/protocol/distributed/constants/constant.go @@ -1,7 +1,8 @@ package constants const ( - LogDerivativeSumPublicInput = "LOG_DERIVATE_SUM_PUBLIC_INPUT" - GrandProductPublicInput = "GRAND_PRODUCT_PUBLIC_INPUT" - GrandSumPublicInput = "GRAND_SUM_PUBLIC_INPUT" + LogDerivativeSumPublicInput = "LOG_DERIVATE_SUM_PUBLIC_INPUT" + GrandProductPublicInput = "GRAND_PRODUCT_PUBLIC_INPUT" + GrandSumPublicInput = "GRAND_SUM_PUBLIC_INPUT" + DistributedProjectionPublicInput = "DISTRIBUTED_PROJECTION_PUBLIC_INPUT" ) diff --git a/prover/protocol/distributed/distributed.go b/prover/protocol/distributed/distributed.go index 65798fea9..56cfde798 100644 --- a/prover/protocol/distributed/distributed.go +++ b/prover/protocol/distributed/distributed.go @@ -30,12 +30,12 @@ type ModuleDiscoverer interface { // Analyze is responsible for letting the module discoverer compute how to // group best the columns into modules. 
Analyze(comp *wizard.CompiledIOP) - NbModules() int ModuleList() []ModuleName FindModule(col ifaces.Column) ModuleName // given a query and a module name it checks if the query is inside the module ExpressionIsInModule(*symbolic.Expression, ModuleName) bool QueryIsInModule(ifaces.Query, ModuleName) bool + // it checks if the given column is in the given module ColumnIsInModule(col ifaces.Column, name ModuleName) bool } diff --git a/prover/protocol/distributed/lpp/lpp.go b/prover/protocol/distributed/lpp/lpp.go index 55ebf510c..32b70704c 100644 --- a/prover/protocol/distributed/lpp/lpp.go +++ b/prover/protocol/distributed/lpp/lpp.go @@ -18,7 +18,7 @@ func CompileLPPAndGetSeed(comp *wizard.CompiledIOP, lppCompilers ...func(*wizard ) // get the LPP columns from comp. - lppCols = append(lppCols, getLPPColumns(comp)...) + lppCols = append(lppCols, GetLPPColumns(comp)...) for _, col := range comp.Columns.AllHandlesAtRound(0) { oldColumns = append(oldColumns, col) @@ -26,15 +26,17 @@ func CompileLPPAndGetSeed(comp *wizard.CompiledIOP, lppCompilers ...func(*wizard // applies lppCompiler; this would add a new round and probably new columns to the current round // but no new column to the new round. - for _, lppCompiler := range lppCompilers { - lppCompiler(comp) + if len(lppCompilers) > 0 { + for _, lppCompiler := range lppCompilers { + lppCompiler(comp) - if comp.NumRounds() != 2 || comp.Columns.NumRounds() != 1 { - panic("we expect to have new round while no column is yet registered for the new round") - } + if comp.NumRounds() != 2 || comp.Columns.NumRounds() != 1 { + panic("we expect to have new round while no column is yet registered for the new round") + } - numRounds := comp.NumRounds() - comp.EqualizeRounds(numRounds) + numRounds := comp.NumRounds() + comp.EqualizeRounds(numRounds) + } } // filter the new lpp columns. @@ -60,7 +62,7 @@ func CompileLPPAndGetSeed(comp *wizard.CompiledIOP, lppCompilers ...func(*wizard // register the seed, generated from LPP, in comp // for the sake of the assignment it also should be registered in lppComp lppComp.InsertCoin(1, "SEED", coin.Field) - comp.InsertCoin(1, "SEED", coin.Field) + comp.InsertCoin(0, "SEED", coin.Field) // prepare and register prover actions. lppProver := &lppProver{ @@ -107,7 +109,7 @@ func GetLPPComp(oldComp *wizard.CompiledIOP, newLPPCols []ifaces.Column) *wizard ) // get the LPP columns - lppCols = append(lppCols, getLPPColumns(oldComp)...) + lppCols = append(lppCols, GetLPPColumns(oldComp)...) lppCols = append(lppCols, newLPPCols...) for _, col := range lppCols { @@ -117,8 +119,9 @@ func GetLPPComp(oldComp *wizard.CompiledIOP, newLPPCols []ifaces.Column) *wizard } // it extract LPP columns from the context of each LPP query. -func getLPPColumns(c *wizard.CompiledIOP) []ifaces.Column { - +func GetLPPColumns(c *wizard.CompiledIOP) []ifaces.Column { + // Todo(arijit): should prevent inserting duplicate columns + // Use checkAlreadyExists() var ( lppColumns = []ifaces.Column{} ) @@ -127,30 +130,46 @@ func getLPPColumns(c *wizard.CompiledIOP) []ifaces.Column { q := c.QueriesNoParams.Data(qName) switch v := q.(type) { case query.Inclusion: - - for i := range v.Including { - lppColumns = append(lppColumns, v.Including[i]...) + if !checkAlreadyExists(lppColumns, v.Including[0][0]) { + for i := range v.Including { + lppColumns = append(lppColumns, v.Including[i]...) + } + } + if !checkAlreadyExists(lppColumns, v.Included[0]) { + lppColumns = append(lppColumns, v.Included...) } - - lppColumns = append(lppColumns, v.Included...) 
if v.IncludingFilter != nil { - lppColumns = append(lppColumns, v.IncludingFilter...) + if !checkAlreadyExists(lppColumns, v.IncludingFilter[0]) { + lppColumns = append(lppColumns, v.IncludingFilter...) + } } if v.IncludedFilter != nil { - lppColumns = append(lppColumns, v.IncludedFilter) + if !checkAlreadyExists(lppColumns, v.IncludedFilter) { + lppColumns = append(lppColumns, v.IncludedFilter) + } } case query.Permutation: for i := range v.A { - lppColumns = append(lppColumns, v.A[i]...) - lppColumns = append(lppColumns, v.B[i]...) + if !checkAlreadyExists(lppColumns, v.A[0][0]) { + lppColumns = append(lppColumns, v.A[i]...) + } + if !checkAlreadyExists(lppColumns, v.B[0][0]) { + lppColumns = append(lppColumns, v.B[i]...) + } } case query.Projection: - lppColumns = append(lppColumns, v.Inp.ColumnA...) - lppColumns = append(lppColumns, v.Inp.ColumnB...) - lppColumns = append(lppColumns, v.Inp.FilterA, v.Inp.FilterB) + // If ColumnA exists then FilterA also exists + if !checkAlreadyExists(lppColumns, v.Inp.ColumnA[0]) { + lppColumns = append(lppColumns, v.Inp.ColumnA...) + lppColumns = append(lppColumns, v.Inp.FilterA) + } + if !checkAlreadyExists(lppColumns, v.Inp.ColumnB[0]) { + lppColumns = append(lppColumns, v.Inp.ColumnB...) + lppColumns = append(lppColumns, v.Inp.FilterB) + } default: //do noting @@ -160,3 +179,13 @@ func getLPPColumns(c *wizard.CompiledIOP) []ifaces.Column { return lppColumns } + +// it checks if the column is already inserted in the list +func checkAlreadyExists(lppColumns []ifaces.Column, sampleCol ifaces.Column) bool { + for _, col := range lppColumns { + if col.GetColID() == sampleCol.GetColID() { + return true + } + } + return false +} diff --git a/prover/protocol/distributed/lpp/lpp_test.go b/prover/protocol/distributed/lpp/lpp_test.go index 49edf4fc5..88f609cc2 100644 --- a/prover/protocol/distributed/lpp/lpp_test.go +++ b/prover/protocol/distributed/lpp/lpp_test.go @@ -140,7 +140,7 @@ func TestSeedGeneration(t *testing.T) { run.ProverID = proverID }) - // get and compar the coins with the other segments/modules + // get and compare the coins with the other segments/modules coin1 := runtime0.Coins.MustGet("TABLE_module0.col2_LOGDERIVATIVE_GAMMA_FieldFromSeed").(field.Element) coin0 := runtime0.Coins.MustGet("TABLE_module1.col0_LOGDERIVATIVE_GAMMA_FieldFromSeed").(field.Element) if coinLookup1Gamma.IsZero() { @@ -186,7 +186,7 @@ func TestSeedGeneration(t *testing.T) { run.ProverID = proverID }) - // get and compar the coins with the other segments/modules + // get and compare the coins with the other segments/modules coin2Gamma := runtime2.Coins.MustGet("TABLE_module1.col5,module1.col1,module1.col4_LOGDERIVATIVE_GAMMA_FieldFromSeed").(field.Element) coin2Alpha := runtime2.Coins.MustGet("TABLE_module1.col5,module1.col1,module1.col4_LOGDERIVATIVE_ALPHA_FieldFromSeed").(field.Element) diff --git a/prover/protocol/distributed/namebaseddiscoverer/period_separating_module_discoverer.go b/prover/protocol/distributed/namebaseddiscoverer/period_separating_module_discoverer.go index c2f04fe3f..2008db369 100644 --- a/prover/protocol/distributed/namebaseddiscoverer/period_separating_module_discoverer.go +++ b/prover/protocol/distributed/namebaseddiscoverer/period_separating_module_discoverer.go @@ -1,9 +1,10 @@ -package namebaseddiscoverer +package discoverer import ( "strings" "github.com/consensys/linea-monorepo/prover/protocol/coin" + "github.com/consensys/linea-monorepo/prover/protocol/column" "github.com/consensys/linea-monorepo/prover/protocol/column/verifiercol" 
"github.com/consensys/linea-monorepo/prover/protocol/distributed" "github.com/consensys/linea-monorepo/prover/protocol/ifaces" @@ -55,11 +56,6 @@ func periodSeparator(name string) string { return name[:index] } -// NbModules returns the number of modules -func (p *PeriodSeperatingModuleDiscoverer) NbModules() int { - return len(p.modules) -} - // ModuleList returns the list of module names func (p *PeriodSeperatingModuleDiscoverer) ModuleList() []ModuleName { moduleNames := make([]ModuleName, 0, len(p.modules)) @@ -89,8 +85,12 @@ func (p *PeriodSeperatingModuleDiscoverer) QueryIsInModule(ifaces.Query, ModuleN // ColumnIsInModule checks that the given column is inside the given module. func (p *PeriodSeperatingModuleDiscoverer) ColumnIsInModule(col ifaces.Column, name ModuleName) bool { + colID := col.GetColID() + if shifted, ok := col.(column.Shifted); ok { + colID = shifted.Parent.GetColID() + } for _, c := range p.modules[name] { - if c.GetColID() == col.GetColID() { + if c.GetColID() == colID { return true } } @@ -100,6 +100,7 @@ func (p *PeriodSeperatingModuleDiscoverer) ColumnIsInModule(col ifaces.Column, n // ExpressionIsInModule checks that all the columns (except verifiercol) in the expression are from the given module. // // It does not check the presence of the coins and other metadata in the module. +// the restriction over verifier column comes from the fact that the discoverer Analyses compiledIOP and the verifier columns are not accessible there. func (p *PeriodSeperatingModuleDiscoverer) ExpressionIsInModule(expr *symbolic.Expression, name ModuleName) bool { var ( board = expr.Board() @@ -131,7 +132,7 @@ func (p *PeriodSeperatingModuleDiscoverer) ExpressionIsInModule(expr *symbolic.E } if nCols == 0 { - panic("could not find any column in the expression") + panic("unsupported, could not find any column in the expression") } else { return b } diff --git a/prover/protocol/distributed/namebaseddiscoverer/query_based_module_discoverer.go b/prover/protocol/distributed/namebaseddiscoverer/query_based_module_discoverer.go new file mode 100644 index 000000000..2b6aa445d --- /dev/null +++ b/prover/protocol/distributed/namebaseddiscoverer/query_based_module_discoverer.go @@ -0,0 +1,136 @@ +package discoverer + +import ( + "github.com/consensys/linea-monorepo/prover/protocol/column" + "github.com/consensys/linea-monorepo/prover/protocol/distributed" + "github.com/consensys/linea-monorepo/prover/protocol/distributed/lpp" + "github.com/consensys/linea-monorepo/prover/protocol/ifaces" + "github.com/consensys/linea-monorepo/prover/protocol/query" + "github.com/consensys/linea-monorepo/prover/protocol/variables" + "github.com/consensys/linea-monorepo/prover/protocol/wizard" + "github.com/consensys/linea-monorepo/prover/symbolic" + "github.com/consensys/linea-monorepo/prover/utils/collection" +) + +// struct implementing more complex analysis for [distributed.ModuleDiscoverer] +type QueryBasedDiscoverer struct { + // simple module discoverer; it does simple analysis + // it can neither capture the specific columns (like verifier columns or periodicSampling variables), + // nor categorizes the columns GL and LPP. 
+	SimpleDiscoverer distributed.ModuleDiscoverer
+	// all the columns involved in the LPP queries, including the verifier columns
+	LPPColumns collection.Mapping[ModuleName, []ifaces.Column]
+	// all the columns involved in the GL queries, including the verifier columns
+	GLColumns collection.Mapping[ModuleName, []ifaces.Column]
+	// all the periodicSamples involved in the GL queries
+	PeriodicSamplingGL collection.Mapping[ModuleName, []variables.PeriodicSample]
+}
+
+// Analyze first runs the simple discoverer and then adds extra query-based analyses
+// to collect the verifier, GL, and LPP columns as well as the PeriodicSamples
+func (d QueryBasedDiscoverer) Analyze(comp *wizard.CompiledIOP) {
+
+	d.SimpleDiscoverer.Analyze(comp)
+
+	// sanity check
+	if len(comp.QueriesParams.AllKeysAt(0)) != 0 {
+		panic("At this step, we do not expect a query with parameters")
+	}
+
+	// capture and analyze the GL queries
+	for _, moduleName := range d.SimpleDiscoverer.ModuleList() {
+		for _, q := range comp.QueriesNoParams.AllKeys() {
+
+			if global, ok := comp.QueriesNoParams.Data(q).(query.GlobalConstraint); ok {
+
+				if d.SimpleDiscoverer.ExpressionIsInModule(global.Expression, moduleName) {
+
+					// analyzes the expression and updates d.GLColumns and d.PeriodicSamplingGL
+					d.analyzeExprGL(global.Expression, moduleName)
+
+				}
+			}
+
+			// the same for local queries.
+			if local, ok := comp.QueriesNoParams.Data(q).(query.LocalConstraint); ok {
+
+				if d.SimpleDiscoverer.ExpressionIsInModule(local.Expression, moduleName) {
+
+					// analyzes the expression and updates d.GLColumns and d.PeriodicSamplingGL
+					d.analyzeExprGL(local.Expression, moduleName)
+
+				}
+			}
+
+			// get the LPP columns from all the LPP queries and check if they are in the module;
+			// this does not include the new columns from the preparation phase, such as the multiplicity columns.
+			lppCols := lpp.GetLPPColumns(comp)
+			for _, col := range lppCols {
+				if d.SimpleDiscoverer.ColumnIsInModule(col, moduleName) {
+					// update the content of d
+					d.LPPColumns.AppendNew(moduleName, []ifaces.Column{col})
+				}
+			}
+
+		}
+	}
+
+}
+
+// ModuleList returns the list of module names
+func (d *QueryBasedDiscoverer) ModuleList() []ModuleName {
+
+	return d.SimpleDiscoverer.ModuleList()
+}
+
+// FindModule finds the module name for a given column
+func (d *QueryBasedDiscoverer) FindModule(col ifaces.Column) ModuleName {
+	return d.SimpleDiscoverer.FindModule(col)
+}
+
+// QueryIsInModule checks if the given query is inside the given module
+func (d *QueryBasedDiscoverer) QueryIsInModule(q ifaces.Query, moduleName ModuleName) bool {
+	return d.SimpleDiscoverer.QueryIsInModule(q, moduleName)
+}
+
+// ColumnIsInModule checks that the given column is inside the given module.
+func (d *QueryBasedDiscoverer) ColumnIsInModule(col ifaces.Column, name ModuleName) bool {
+	return d.SimpleDiscoverer.ColumnIsInModule(col, name)
+}
+
+// ExpressionIsInModule checks that all the columns (except verifiercol) in the expression are from the given module.
+//
+// It does not check the presence of the coins and other metadata in the module.
+// The restriction on verifier columns comes from the fact that the discoverer analyzes the
+// CompiledIOP, where the verifier columns are not accessible.
+func (p *QueryBasedDiscoverer) ExpressionIsInModule(expr *symbolic.Expression, name ModuleName) bool {
+	return p.SimpleDiscoverer.ExpressionIsInModule(expr, name)
+}
+
+// analyzeExprGL analyzes the expression and updates the content of d.
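+// Shifted columns are unwrapped to their parent column before being recorded in
+// GLColumns, and PeriodicSample variables are collected separately in PeriodicSamplingGL.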
+func (d *QueryBasedDiscoverer) analyzeExprGL(expr *symbolic.Expression, moduleName ModuleName) { + + var ( + board = expr.Board() + metadata = board.ListVariableMetadata() + ) + + for _, m := range metadata { + + switch t := m.(type) { + case ifaces.Column: + + if shifted, ok := t.(column.Shifted); ok { + + d.GLColumns.AppendNew(moduleName, []ifaces.Column{shifted.Parent}) + + } else { + d.GLColumns.AppendNew(moduleName, []ifaces.Column{t}) + } + + case variables.PeriodicSample: + + d.PeriodicSamplingGL.AppendNew(moduleName, []variables.PeriodicSample{t}) + + } + } +} diff --git a/prover/protocol/query/distributed_projection.go b/prover/protocol/query/distributed_projection.go new file mode 100644 index 000000000..3ad815a5d --- /dev/null +++ b/prover/protocol/query/distributed_projection.go @@ -0,0 +1,86 @@ +package query + +import ( + "math/big" + + "github.com/consensys/gnark/frontend" + "github.com/consensys/linea-monorepo/prover/crypto/fiatshamir" + "github.com/consensys/linea-monorepo/prover/maths/field" + "github.com/consensys/linea-monorepo/prover/protocol/coin" + "github.com/consensys/linea-monorepo/prover/protocol/ifaces" + "github.com/consensys/linea-monorepo/prover/symbolic" + "github.com/consensys/linea-monorepo/prover/utils" +) + +type DistributedProjectionInput struct { + ColumnA, ColumnB *symbolic.Expression + FilterA, FilterB *symbolic.Expression + SizeA, SizeB int + EvalCoin coin.Name + IsAInModule, IsBInModule bool + CumulativeNumOnesPrevSegmentsA, CumulativeNumOnesPrevSegmentsB big.Int + CurrNumOnesA, CurrNumOnesB field.Element +} + +// func (dpInp *DistributedProjectionInput) completeAssign(run *ifaces.Runtime) { +// dpInp.CumulativeNumOnesPrevSegments.Run(run) +// } + +type DistributedProjection struct { + Round int + ID ifaces.QueryID + Inp []*DistributedProjectionInput +} + +type DistributedProjectionParams struct { + ScaledHorner field.Element + HashCumSumOnePrev, HashCumSumOneCurr field.Element +} + +func NewDistributedProjection(round int, id ifaces.QueryID, inp []*DistributedProjectionInput) DistributedProjection { + for _, in := range inp { + if err := in.ColumnA.Validate(); err != nil { + utils.Panic("ColumnA for the distributed projection query %v is not a valid expression", id) + } + if err := in.ColumnB.Validate(); err != nil { + utils.Panic("ColumnB for the distributed projection query %v is not a valid expression", id) + } + if err := in.FilterA.Validate(); err != nil { + utils.Panic("FilterA for the distributed projection query %v is not a valid expression", id) + } + if err := in.FilterB.Validate(); err != nil { + utils.Panic("FilterB for the distributed projection query %v is not a valid expression", id) + } + if !in.IsAInModule && !in.IsBInModule { + utils.Panic("Invalid distributed projection query %v, both A and B are not in the module", id) + } + } + return DistributedProjection{Round: round, ID: id, Inp: inp} +} + +// Constructor for distributed projection query parameters +func NewDistributedProjectionParams(scaledHorner, hashCumSumOnePrev, hashCumSumOneCurr field.Element) DistributedProjectionParams { + return DistributedProjectionParams{ + ScaledHorner: scaledHorner, + HashCumSumOnePrev: hashCumSumOnePrev, + HashCumSumOneCurr: hashCumSumOneCurr} +} + +// Name returns the unique identifier of the GrandProduct query. 
+func (dp DistributedProjection) Name() ifaces.QueryID { + return dp.ID +} + +// Updates a Fiat-Shamir state +func (dpp DistributedProjectionParams) UpdateFS(fs *fiatshamir.State) { + fs.Update(dpp.ScaledHorner) +} + +// Unimplemented +func (dp DistributedProjection) Check(run ifaces.Runtime) error { + return nil +} + +func (dp DistributedProjection) CheckGnark(api frontend.API, run ifaces.GnarkRuntime) { + panic("UNSUPPORTED : can't check a Projection query directly into the circuit") +} diff --git a/prover/protocol/query/gnark_params.go b/prover/protocol/query/gnark_params.go index 448d011a2..5b7f98c75 100644 --- a/prover/protocol/query/gnark_params.go +++ b/prover/protocol/query/gnark_params.go @@ -25,6 +25,11 @@ type GnarkGrandProductParams struct { Prod frontend.Variable } +// A gnark circuit version of DistributedProjectionParams +type GnarkDistributedProjectionParams struct { + Sum frontend.Variable +} + func (p LogDerivSumParams) GnarkAssign() GnarkLogDerivSumParams { return GnarkLogDerivSumParams{Sum: p.Sum} } @@ -78,6 +83,11 @@ func (p GnarkGrandProductParams) UpdateFS(fs *fiatshamir.GnarkFiatShamir) { fs.Update(p.Prod) } +// Update the fiat-shamir state with the the present parameters +func (p GnarkDistributedProjectionParams) UpdateFS(fs *fiatshamir.GnarkFiatShamir) { + fs.Update(p.Sum) +} + // Update the fiat-shamir state with the the present parameters func (p GnarkUnivariateEvalParams) UpdateFS(fs *fiatshamir.GnarkFiatShamir) { fs.Update(p.Ys...) diff --git a/prover/protocol/query/projection.go b/prover/protocol/query/projection.go index 45f88eb0f..2b1df8068 100644 --- a/prover/protocol/query/projection.go +++ b/prover/protocol/query/projection.go @@ -90,8 +90,8 @@ func (p Projection) Check(run ifaces.Runtime) error { bLinComb[row] = rowLinComb(linCombRand, row, b) } var ( - hornerA = poly.CmptHorner(aLinComb, fA, evalRand) - hornerB = poly.CmptHorner(bLinComb, fB, evalRand) + hornerA = poly.GetHornerTrace(aLinComb, fA, evalRand) + hornerB = poly.GetHornerTrace(bLinComb, fB, evalRand) ) if hornerA[0] != hornerB[0] { return fmt.Errorf("the projection query %v check is not satisfied", p.ID) diff --git a/prover/protocol/query/projection_test.go b/prover/protocol/query/projection_test.go index bd1a307fa..2fcc52935 100644 --- a/prover/protocol/query/projection_test.go +++ b/prover/protocol/query/projection_test.go @@ -33,10 +33,13 @@ func TestProjection(t *testing.T) { prover := func(run *wizard.ProverRuntime) { runS = run // assign filters and columns - flagAWit := make([]field.Element, flagSizeA) - columnAWit := make([]field.Element, flagSizeA) - flagBWit := make([]field.Element, flagSizeB) - columnBWit := make([]field.Element, flagSizeB) + var ( + flagAWit = make([]field.Element, flagSizeA) + columnAWit = make([]field.Element, flagSizeA) + flagBWit = make([]field.Element, flagSizeB) + columnBWit = make([]field.Element, flagSizeB) + ) + for i := 0; i < 10; i++ { flagAWit[i] = field.One() columnAWit[i] = field.NewElement(uint64(i)) diff --git a/prover/protocol/wizard/compiled.go b/prover/protocol/wizard/compiled.go index cfc0a6452..1065d2167 100644 --- a/prover/protocol/wizard/compiled.go +++ b/prover/protocol/wizard/compiled.go @@ -647,6 +647,14 @@ func (c *CompiledIOP) InsertProjection(id ifaces.QueryID, in query.ProjectionInp return q } +// Register a distributed projection query +func (c *CompiledIOP) InsertDistributedProjection(round int, id ifaces.QueryID, in []*query.DistributedProjectionInput) query.DistributedProjection { + q := query.NewDistributedProjection(round, id, 
diff --git a/prover/protocol/wizard/compiled.go b/prover/protocol/wizard/compiled.go
index cfc0a6452..1065d2167 100644
--- a/prover/protocol/wizard/compiled.go
+++ b/prover/protocol/wizard/compiled.go
@@ -647,6 +647,14 @@ func (c *CompiledIOP) InsertProjection(id ifaces.QueryID, in query.ProjectionInp
 	return q
 }
 
+// Register a distributed projection query
+func (c *CompiledIOP) InsertDistributedProjection(round int, id ifaces.QueryID, in []*query.DistributedProjectionInput) query.DistributedProjection {
+    q := query.NewDistributedProjection(round, id, in)
+    // Finally registers the query
+    c.QueriesParams.AddToRound(round, q.Name(), q)
+    return q
+}
+
 // AddPublicInput inserts a public-input in the compiled-IOP
 func (c *CompiledIOP) InsertPublicInput(name string, acc ifaces.Accessor) PublicInput {
diff --git a/prover/protocol/wizard/gnark_verifier.go b/prover/protocol/wizard/gnark_verifier.go
index fc2c1292f..778de2360 100644
--- a/prover/protocol/wizard/gnark_verifier.go
+++ b/prover/protocol/wizard/gnark_verifier.go
@@ -23,6 +23,7 @@ type GnarkRuntime interface {
 	GetSpec() *CompiledIOP
 	GetPublicInput(api frontend.API, name string) frontend.Variable
 	GetGrandProductParams(name ifaces.QueryID) query.GnarkGrandProductParams
+	GetDistributedProjectionParams(name ifaces.QueryID) query.GnarkDistributedProjectionParams
 	GetLogDerivSumParams(name ifaces.QueryID) query.GnarkLogDerivSumParams
 	GetLocalPointEvalParams(name ifaces.QueryID) query.GnarkLocalOpeningParams
 	GetInnerProductParams(name ifaces.QueryID) query.GnarkInnerProductParams
@@ -73,6 +74,8 @@ type WizardVerifierCircuit struct {
 	logDerivSumIDs collection.Mapping[ifaces.QueryID, int] `gnark:"-"`
 	// Same for grand-product query
 	grandProductIDs collection.Mapping[ifaces.QueryID, int] `gnark:"-"`
+	// Same for distributed projection query
+	distributedProjectionIDs collection.Mapping[ifaces.QueryID, int] `gnark:"-"`
 
 	// Columns stores the gnark witness part corresponding to the columns
 	// provided in the proof and in the VerifyingKey.
@@ -95,6 +98,9 @@ type WizardVerifierCircuit struct {
 	// GrandProductParams stores an assignment for each [query.GrandProductParams]
 	// from the proof. It is part of the witness of the gnark circuit.
 	GrandProductParams []query.GnarkGrandProductParams `gnark:",secret"`
+	// DistributedProjectionParams stores an assignment for each [query.DistributedProjectionParams]
+	// from the proof. It is part of the witness of the gnark circuit.
+	DistributedProjectionParams []query.GnarkDistributedProjectionParams `gnark:",secret"`
 
 	// FS is the Fiat-Shamir state, mirroring [VerifierRuntime.FS]. The same
 	// cautionnary rules apply to it; e.g. don't use it externally when
@@ -369,6 +375,14 @@ func (c *WizardVerifierCircuit) GetGrandProductParams(name ifaces.QueryID) query
 	return c.GrandProductParams[qID]
 }
 
+// GetDistributedProjectionParams returns the parameters for the requested
+// [query.DistributedProjection] query. Its work mirrors the function
+// [VerifierRuntime.GetDistributedProjectionParams]
+func (c *WizardVerifierCircuit) GetDistributedProjectionParams(name ifaces.QueryID) query.GnarkDistributedProjectionParams {
+    qID := c.distributedProjectionIDs.MustGet(name)
+    return c.DistributedProjectionParams[qID]
+}
+
 // GetColumns returns the gnark assignment of a column in a gnark circuit. It
 // mirrors the function [VerifierRuntime.GetColumn]
 func (c *WizardVerifierCircuit) GetColumn(name ifaces.ColID) []frontend.Variable {
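For orientation, here is a hedged sketch of how the new entry points might be used together: a compilation step registers the query with InsertDistributedProjection, and a gnark verifier action later fetches the assigned parameters through the GnarkRuntime interface extended above. The query name, the round and the constraint placed on params.Sum are placeholders; the shape of query.DistributedProjectionInput is left opaque since it is not defined in this diff.

```go
package example

import (
	"github.com/consensys/gnark/frontend"
	"github.com/consensys/linea-monorepo/prover/protocol/query"
	"github.com/consensys/linea-monorepo/prover/protocol/wizard"
)

// registerExample registers a distributed projection query at the given round.
// The query name is a placeholder.
func registerExample(comp *wizard.CompiledIOP, round int, in []*query.DistributedProjectionInput) query.DistributedProjection {
	return comp.InsertDistributedProjection(round, "EXAMPLE_DISTRIBUTED_PROJECTION", in)
}

// verifyExample reads the circuit assignment of the query back by name and
// constrains it against an expected value; the constraint itself is illustrative.
func verifyExample(api frontend.API, run wizard.GnarkRuntime, q query.DistributedProjection, expected frontend.Variable) {
	params := run.GetDistributedProjectionParams(q.Name())
	api.AssertIsEqual(params.Sum, expected)
}
```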
diff --git a/prover/protocol/wizard/prover.go b/prover/protocol/wizard/prover.go
index 5b1aa151e..d555801f3 100644
--- a/prover/protocol/wizard/prover.go
+++ b/prover/protocol/wizard/prover.go
@@ -212,7 +212,7 @@ func ProverOnlyFirstRound(c *CompiledIOP, highLevelprover ProverStep) *ProverRun
 	// highLevelprover(&runtime)
 
-	// Then, run the compiled prover steps. This will only run thoses of the
+	// Then, run the compiled prover steps. This will only run those of the
 	// first round.
 	// runtime.runProverSteps()
@@ -831,3 +831,26 @@ func (run *ProverRuntime) AssignGrandProduct(name ifaces.QueryID, y field.Elemen
 	params := query.NewGrandProductParams(y)
 	run.QueriesParams.InsertNew(name, params)
 }
+
+// AssignDistributedProjection assigns the value horner(A, fA) if A
+// is in the module, horner(B, fB)^{-1} if B is in the module, and
+// horner(A, fA) * horner(B, fB)^{-1} if both are in the module.
+// The function will panic if:
+// - the parameters were already assigned
+// - the specified query is not registered
+// - the assignment round is incorrect
+func (run *ProverRuntime) AssignDistributedProjection(name ifaces.QueryID, distributedProjectionParam query.DistributedProjectionParams) {
+
+    // Global prover locks for accessing the maps
+    run.lock.Lock()
+    defer run.lock.Unlock()
+
+    // Make sure it is done at the right round
+    run.Spec.QueriesParams.MustBeInRound(run.currRound, name)
+
+    // Adds it to the assignments
+    params := query.NewDistributedProjectionParams(distributedProjectionParam.ScaledHorner,
+        distributedProjectionParam.HashCumSumOnePrev,
+        distributedProjectionParam.HashCumSumOneCurr)
+    run.QueriesParams.InsertNew(name, params)
+}
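A hedged sketch of the prover side feeding AssignDistributedProjection, for the case where a module holds both sides of the projection: the assigned value is horner(A, fA) * horner(B, fB)^{-1}, taking the final Horner value at index 0 as Projection.Check does. The column slices, the evaluation point and the cumulative-hash values are placeholders supplied by the calling prover step.

```go
package example

import (
	"github.com/consensys/linea-monorepo/prover/maths/common/poly"
	"github.com/consensys/linea-monorepo/prover/maths/field"
	"github.com/consensys/linea-monorepo/prover/protocol/ifaces"
	"github.com/consensys/linea-monorepo/prover/protocol/query"
	"github.com/consensys/linea-monorepo/prover/protocol/wizard"
)

// assignExample computes horner(A, fA) * horner(B, fB)^{-1} and assigns it to
// the distributed projection query. All inputs are placeholders.
func assignExample(
	run *wizard.ProverRuntime,
	name ifaces.QueryID,
	a, fA, b, fB []field.Element,
	x, prevHash, currHash field.Element,
) {
	var (
		// The accumulated value sits at index 0 of the Horner trace.
		hA = poly.GetHornerTrace(a, fA, x)[0]
		hB = poly.GetHornerTrace(b, fB, x)[0]

		hBInv  field.Element
		scaled field.Element
	)
	hBInv.Inverse(&hB)
	scaled.Mul(&hA, &hBInv)

	run.AssignDistributedProjection(name, query.NewDistributedProjectionParams(scaled, prevHash, currHash))
}
```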
diff --git a/prover/protocol/wizard/verifier.go b/prover/protocol/wizard/verifier.go
index aaacfe165..9767609ae 100644
--- a/prover/protocol/wizard/verifier.go
+++ b/prover/protocol/wizard/verifier.go
@@ -4,7 +4,9 @@ import (
 	"github.com/consensys/linea-monorepo/prover/crypto/fiatshamir"
 	"github.com/consensys/linea-monorepo/prover/maths/common/smartvectors"
 	"github.com/consensys/linea-monorepo/prover/maths/field"
+	"github.com/consensys/linea-monorepo/prover/protocol/accessors"
 	"github.com/consensys/linea-monorepo/prover/protocol/coin"
+	"github.com/consensys/linea-monorepo/prover/protocol/distributed/constants"
 	"github.com/consensys/linea-monorepo/prover/protocol/ifaces"
 	"github.com/consensys/linea-monorepo/prover/protocol/query"
 	"github.com/consensys/linea-monorepo/prover/utils"
@@ -33,14 +35,23 @@ type Proof struct {
 	QueriesParams collection.Mapping[ifaces.QueryID, ifaces.QueryParams]
 }
 
+// DistributedProjectionPublicInput holds the public inputs for the
+// distributed projection protocol.
+type DistributedProjectionPublicInput struct {
+    ScaledHorner field.Element
+    CumSumPrev   field.Element
+    CumSumCurr   field.Element
+}
+
 // Runtime is a generic interface extending the [ifaces.Runtime] interface
 // with all methods of [wizard.VerifierRuntime]. This is used to allow the
 // writing of adapters for the verifier runtime.
 type Runtime interface {
 	ifaces.Runtime
 	GetSpec() *CompiledIOP
-	GetPublicInput(name string) field.Element
+	GetPublicInput(name string) any
 	GetGrandProductParams(name ifaces.QueryID) query.GrandProductParams
+	GetDistributedProjectionParams(name ifaces.QueryID) query.DistributedProjectionParams
 	GetLogDerivSumParams(name ifaces.QueryID) query.LogDerivSumParams
 	GetLocalPointEvalParams(name ifaces.QueryID) query.LocalOpeningParams
 	GetInnerProductParams(name ifaces.QueryID) query.InnerProductParams
@@ -433,6 +444,11 @@ func (run *VerifierRuntime) GetGrandProductParams(name ifaces.QueryID) query.Gra
 	return run.QueriesParams.MustGet(name).(query.GrandProductParams)
 }
 
+// GetDistributedProjectionParams returns the parameters of a [query.DistributedProjection]
+func (run *VerifierRuntime) GetDistributedProjectionParams(name ifaces.QueryID) query.DistributedProjectionParams {
+    return run.QueriesParams.MustGet(name).(query.DistributedProjectionParams)
+}
+
 /*
 CopyColumnInto implements `column.GetWitness`
 Copies the witness into a slice
@@ -469,7 +485,7 @@ func (run VerifierRuntime) GetColumnAt(name ifaces.ColID, pos int) field.Element
 	wit := run.Columns.MustGet(name)
 
 	if pos >= wit.Len() || pos < 0 {
-		utils.Panic("asked pos %v for vector of size %v", pos, wit)
+		utils.Panic("asked pos %v for vector of size %v", pos, wit.Len())
 	}
 
 	return wit.Get(pos)
@@ -485,9 +501,18 @@ func (run *VerifierRuntime) GetParams(name ifaces.QueryID) ifaces.QueryParams {
 }
 
 // GetPublicInput returns a public input from its name
-func (run *VerifierRuntime) GetPublicInput(name string) field.Element {
+func (run *VerifierRuntime) GetPublicInput(name string) any {
 	allPubs := run.Spec.PublicInputs
 	for i := range allPubs {
+		if allPubs[i].Name == name && name == constants.DistributedProjectionPublicInput {
+			if s, ok := allPubs[i].Acc.(*accessors.FromDistributedProjectionAccessor); ok {
+				return DistributedProjectionPublicInput{
+					ScaledHorner: s.GetVal(run),
+					CumSumCurr:   s.GetValCumSumCurr(run),
+					CumSumPrev:   s.GetValCumSumPrev(run),
+				}
+			}
+		}
 		if allPubs[i].Name == name {
 			return allPubs[i].Acc.GetVal(run)
 		}
diff --git a/prover/utils/collection/mapping.go b/prover/utils/collection/mapping.go
index 54d89ae63..462c36c84 100644
--- a/prover/utils/collection/mapping.go
+++ b/prover/utils/collection/mapping.go
@@ -2,6 +2,7 @@ package collection
 
 import (
 	"fmt"
+	"reflect"
 
 	"github.com/consensys/linea-monorepo/prover/utils"
 )
@@ -131,3 +132,26 @@ func (kv *Mapping[K, V]) TryDel(k K) bool {
 	}
 	return found
 }
+
+// AppendNew appends the slice v onto the value stored at key k, assuming V is a slice type
+func (kv *Mapping[K, V]) AppendNew(k K, v V) {
+    // Get the current value from the map
+    currentValue, exists := kv.innerMap[k]
+    if !exists {
+        // If the key does not exist yet, store v as-is
+        kv.innerMap[k] = v
+        return
+    }
+
+    // Use reflection to check that the stored value is a slice
+    currentVal := reflect.ValueOf(currentValue)
+    if currentVal.Kind() == reflect.Slice {
+        // Concatenate the new slice onto the stored one
+        newSlice := reflect.AppendSlice(currentVal, reflect.ValueOf(v))
+        // Write the updated slice back into the map
+        kv.innerMap[k] = newSlice.Interface().(V)
+    } else {
+        // V is not a slice: this is a misuse of the helper
+        utils.Panic("AppendNew: value of type %T is not a slice", v)
+    }
+}
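Finally, a usage sketch for the new AppendNew helper. The constructor collection.NewMapping is assumed to be the package's existing way of building a Mapping; with the slice-concatenation behaviour above, a second call on an existing key extends the stored slice.

```go
package example

import (
	"fmt"

	"github.com/consensys/linea-monorepo/prover/utils/collection"
)

func appendNewExample() {
	// NewMapping is assumed to be the package's existing constructor.
	m := collection.NewMapping[string, []int]()

	m.AppendNew("evens", []int{2})    // key absent: the slice is stored as-is
	m.AppendNew("evens", []int{4, 6}) // key present: concatenated onto the stored slice

	fmt.Println(m.MustGet("evens")) // expected: [2 4 6]
}
```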