Skip to content

Commit

Permalink
test for multiple global
Browse files Browse the repository at this point in the history
  • Loading branch information
Soleimani193 committed Feb 20, 2025
1 parent fc90254 commit 096c905
Show file tree
Hide file tree
Showing 7 changed files with 144 additions and 83 deletions.
125 changes: 65 additions & 60 deletions prover/protocol/distributed/compiler/global/global.go
Original file line number Diff line number Diff line change
Expand Up @@ -36,13 +36,23 @@ func DistributeGlobal(in DistributionInputs) {

var (
bInputs = boundaryInputs{
moduleComp: in.ModuleComp,
numSegments: in.NumSegments,
provider: in.ModuleComp.Columns.GetHandle("PROVIDER"),
receiver: in.ModuleComp.Columns.GetHandle("RECEIVER"),
providerOpenings: []query.LocalOpening{},
receiverOpenings: []query.LocalOpening{},
moduleComp: in.ModuleComp,
numSegments: in.NumSegments,

provider: boundaries{
boundaryCol: in.ModuleComp.Columns.GetHandle("PROVIDER"),
lastPosOnBoundaryCol: 0,
boundaryOpenings: collection.NewMapping[query.LocalOpening, int](),
},
receiver: boundaries{
boundaryCol: in.ModuleComp.Columns.GetHandle("RECEIVER"),
boundaryOpenings: collection.NewMapping[query.LocalOpening, int](),
lastPosOnBoundaryCol: 0,
},
}

provider = bInputs.provider.boundaryCol
receiver = bInputs.receiver.boundaryCol
)

for _, qName := range in.InitialComp.QueriesNoParams.AllUnignoredKeys() {
Expand All @@ -55,7 +65,7 @@ func DistributeGlobal(in DistributionInputs) {
if in.Disc.ExpressionIsInModule(q.Expression, in.ModuleName) {

// apply global constraint over the segment.
in.ModuleComp.InsertGlobal(0,
in.ModuleComp.InsertGlobal(constants.RoundGL,
q.ID,
AdjustExpressionForGlobal(in.ModuleComp, q.Expression, in.NumSegments),
)
Expand All @@ -73,17 +83,17 @@ func DistributeGlobal(in DistributionInputs) {

// get the hash of the provider and the receiver
var (
colOnes = verifiercol.NewConstantCol(field.One(), bInputs.provider.Size())
mimcHasherProvider = edc.NewMIMCHasher(in.ModuleComp, bInputs.provider, colOnes, "MIMC_HASHER_PROVIDER")
mimicHasherReceiver = edc.NewMIMCHasher(in.ModuleComp, bInputs.receiver, colOnes, "MIMC_HASHER_RECEIVER")
colOnes = verifiercol.NewConstantCol(field.One(), provider.Size())
mimcHasherProvider = edc.NewMIMCHasher(in.ModuleComp, provider, colOnes, "MIMC_HASHER_PROVIDER")
mimcHasherReceiver = edc.NewMIMCHasher(in.ModuleComp, receiver, colOnes, "MIMC_HASHER_RECEIVER")
)

mimcHasherProvider.DefineHasher(in.ModuleComp, "DISTRIBUTED_GLOBAL_QUERY_MIMC_HASHER_PROVIDER")
mimcHasherProvider.DefineHasher(in.ModuleComp, "DISTRIBUTED_GLOBAL_QUERY_MIMC_HASHER_RECEIVER")
mimcHasherReceiver.DefineHasher(in.ModuleComp, "DISTRIBUTED_GLOBAL_QUERY_MIMC_HASHER_RECEIVER")

var (
openingHashProvider = in.ModuleComp.InsertLocalOpening(0, "ACCESSOR_FROM_HASH_PROVIDER", mimcHasherProvider.HashFinal)
openingHashReceiver = in.ModuleComp.InsertLocalOpening(0, "ACCESSOR_FROM_HASH_RECEIVER", mimicHasherReceiver.HashFinal)
openingHashProvider = in.ModuleComp.InsertLocalOpening(constants.RoundGL, "ACCESSOR_FROM_HASH_PROVIDER", mimcHasherProvider.HashFinal)
openingHashReceiver = in.ModuleComp.InsertLocalOpening(constants.RoundGL, "ACCESSOR_FROM_HASH_RECEIVER", mimcHasherReceiver.HashFinal)
)

// declare the hash of the provider/receiver as the public inputs.
Expand All @@ -96,30 +106,37 @@ func DistributeGlobal(in DistributionInputs) {
in.ModuleComp.PublicInputs = append(in.ModuleComp.PublicInputs,
wizard.PublicInput{
Name: constants.GlobalReceiverPublicInput,
Acc: accessors.NewLocalOpeningAccessor(openingHashReceiver, 0),
Acc: accessors.NewLocalOpeningAccessor(openingHashReceiver, constants.RoundGL),
})

in.ModuleComp.RegisterProverAction(0, &proverActionForBoundaries{
provider: bInputs.provider,
receiver: bInputs.receiver,
providerOpenings: bInputs.providerOpenings,
receiverOpenings: bInputs.receiverOpenings,

mimicHasherProvider: *mimcHasherProvider,
mimicHasherReceiver: *mimicHasherReceiver,
hashOpeningProvider: openingHashProvider,
hashOpeningReceiver: openingHashReceiver,
in.ModuleComp.RegisterProverAction(constants.RoundGL, &proverActionForBoundaries{
provider: boundaryAssignments{
boundaries: bInputs.provider,
hashOpening: openingHashProvider,
mimcHash: *mimcHasherProvider,
},

receiver: boundaryAssignments{
boundaries: bInputs.receiver,
hashOpening: openingHashReceiver,
mimcHash: *mimcHasherReceiver,
},
})

}

type boundaryInputs struct {
moduleComp *wizard.CompiledIOP
numSegments int
provider ifaces.Column
receiver ifaces.Column
providerOpenings, receiverOpenings []query.LocalOpening
segID int
moduleComp *wizard.CompiledIOP
provider boundaries
receiver boundaries
numSegments int
segID int
}

type boundaries struct {
boundaryCol ifaces.Column
boundaryOpenings collection.Mapping[query.LocalOpening, int]
lastPosOnBoundaryCol int
}

func AdjustExpressionForGlobal(
Expand Down Expand Up @@ -168,7 +185,7 @@ func AdjustExpressionForGlobal(

if m.T > segSize {

panic("unsupported, since this depends on the segment ID, unless the module discoverer can detect such cases")
panic("unsupported")
}
translationMap.InsertNew(m.String(), symbolic.NewVariable(metadata))
default:
Expand All @@ -185,10 +202,9 @@ func BoundariesForProvider(in *boundaryInputs, q query.GlobalConstraint) {
var (
board = q.Board()
offsetRange = q.MinMaxOffset()
provider = in.provider
maxShift = offsetRange.Max
colsInExpr = distributed.ListColumnsFromExpr(q.Expression, false)
colsOnProvider = onBoundaries(colsInExpr, maxShift)
colsOnProvider = onBoundaries(colsInExpr, maxShift, &in.provider)
numBoundaries = offsetRange.Max - offsetRange.Min
size = column.ExprIsOnSameLengthHandles(&board)
segSize = size / in.numSegments
Expand All @@ -204,9 +220,9 @@ func BoundariesForProvider(in *boundaryInputs, q query.GlobalConstraint) {
// take it via accessor.
var (
index = pos[0] + i
name = ifaces.QueryIDf("%v_%v", "FROM_PROVIDER_AT", index)
loProvider = in.moduleComp.InsertLocalOpening(0, name, column.Shift(provider, index))
accessorProvider = accessors.NewLocalOpeningAccessor(loProvider, 0)
name = ifaces.QueryIDf("%v_%v_%v", q.ID, "FROM_PROVIDER_AT", index)
loProvider = in.moduleComp.InsertLocalOpening(constants.RoundGL, name, column.Shift(in.provider.boundaryCol, index))
accessorProvider = accessors.NewLocalOpeningAccessor(loProvider, constants.RoundGL)
indexOnCol = segSize - (maxShift - column.StackOffsets(col) - i)
nameExpr = ifaces.QueryIDf("%v_%v_%v", "CONSISTENCY_AGAINST_PROVIDER", col.GetColID(), i)
colInModule ifaces.Column
Expand All @@ -219,10 +235,10 @@ func BoundariesForProvider(in *boundaryInputs, q query.GlobalConstraint) {
colInModule = in.moduleComp.Columns.GetHandle(col.GetColID())
}

// add the localOpening to the list
in.providerOpenings = append(in.providerOpenings, loProvider)
// add the localOpening to the map
in.provider.boundaryOpenings.InsertNew(loProvider, index)
// impose that loProvider = loCol
in.moduleComp.InsertLocal(0, nameExpr,
in.moduleComp.InsertLocal(constants.RoundGL, nameExpr,
symbolic.Sub(accessorProvider, column.Shift(colInModule, indexOnCol)),
)

Expand All @@ -237,15 +253,12 @@ func BoundariesForReceiver(in *boundaryInputs, q query.GlobalConstraint) {

var (
offsetRange = q.MinMaxOffset()
receiver = in.receiver
maxShift = offsetRange.Max
colsInExpr = distributed.ListColumnsFromExpr(q.Expression, false)
colsOnReceiver = onBoundaries(colsInExpr, maxShift)
colsOnReceiver = onBoundaries(colsInExpr, maxShift, &in.receiver)
numBoundaries = offsetRange.Max - offsetRange.Min
comp = in.moduleComp
colInModule ifaces.Column
// list of local openings by the boundary index
allLists = make([][]query.LocalOpening, numBoundaries)
)

for i := 0; i < numBoundaries; i++ {
Expand All @@ -269,12 +282,12 @@ func BoundariesForReceiver(in *boundaryInputs, q query.GlobalConstraint) {
// take it via accessor.
var (
index = pos[0] + i
name = ifaces.QueryIDf("%v_%v", "FROM_RECEIVER_AT", index)
lo = comp.InsertLocalOpening(0, name, column.Shift(receiver, index))
accessor = accessors.NewLocalOpeningAccessor(lo, 0)
name = ifaces.QueryIDf("%v_%v_%v", q.ID, "FROM_RECEIVER_AT", index)
lo = comp.InsertLocalOpening(constants.RoundGL, name, column.Shift(in.receiver.boundaryCol, index))
accessor = accessors.NewLocalOpeningAccessor(lo, constants.RoundGL)
)
// add the localOpening to the list
allLists[i] = append(allLists[i], lo)
// add the localOpening to the map
in.receiver.boundaryOpenings.InsertNew(lo, index)
// in.receiverOpenings = append(in.receiverOpenings, lo)
// translate the column
translationMap.InsertNew(string(col.GetColID()), accessor.AsVariable())
Expand All @@ -295,28 +308,19 @@ func BoundariesForReceiver(in *boundaryInputs, q query.GlobalConstraint) {
if in.segID != 0 || q.NoBoundCancel {
expr := q.Expression.Replay(translationMap)
name := ifaces.QueryIDf("%v_%v_%v", "CONSISTENCY_AGAINST_RECEIVER", q.ID, i)
comp.InsertLocal(0, name, expr)
comp.InsertLocal(constants.RoundGL, name, expr)
}

}

// order receiverOpenings column by column
for i := 0; i < numBoundaries; i++ {
for _, list := range allLists {
if len(list) > i {
in.receiverOpenings = append(in.receiverOpenings, list[i])
}
}
}

}

// it indicates the column list having the provider cells (i.e.,
// some cells of the columns are needed to be provided to the next segment)
func onBoundaries(colsInExpr []ifaces.Column, maxShift int) collection.Mapping[ifaces.ColID, [2]int] {
func onBoundaries(colsInExpr []ifaces.Column, maxShift int, b *boundaries) collection.Mapping[ifaces.ColID, [2]int] {

var (
ctr = 0
ctr = b.lastPosOnBoundaryCol
colsOnReceiver = collection.NewMapping[ifaces.ColID, [2]int]()
)
for _, col := range colsInExpr {
Expand All @@ -334,6 +338,7 @@ func onBoundaries(colsInExpr []ifaces.Column, maxShift int) collection.Mapping[i

}

b.lastPosOnBoundaryCol = ctr
return colsOnReceiver

}
Expand Down
11 changes: 11 additions & 0 deletions prover/protocol/distributed/compiler/global/global_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,8 @@ func TestDistributedGlobal(t *testing.T) {
col1 = b.CompiledIOP.InsertCommit(0, "module.col1", 8)
col2 = b.CompiledIOP.InsertCommit(0, "module.col2", 8)
col3 = b.CompiledIOP.InsertCommit(0, "module.col3", 8)

fibonacci = b.CompiledIOP.InsertCommit(0, "module.fibo", 16)
)

b.CompiledIOP.InsertGlobal(0, "global0",
Expand All @@ -44,6 +46,13 @@ func TestDistributedGlobal(t *testing.T) {
),
)

b.CompiledIOP.InsertGlobal(0, "fibonacci",
symbolic.Sub(
fibonacci,
column.Shift(fibonacci, -1),
column.Shift(fibonacci, -2)),
)

}

// initialProver
Expand All @@ -53,6 +62,8 @@ func TestDistributedGlobal(t *testing.T) {
run.AssignColumn("module.col2", smartvectors.ForTest(7, 0, 1, 3, 0, 4, 1, 0))
run.AssignColumn("module.col3", smartvectors.ForTest(2, 14, 0, 2, 3, 0, 10, 0))

run.AssignColumn("module.fibo", smartvectors.ForTest(1, 1, 2, 3, 5, 8, 13, 21, 34, 55, 89, 144, 233, 377, 610, 987))

}

// initial compiledIOP is the parent to all the SegmentModuleComp objects.
Expand Down
54 changes: 33 additions & 21 deletions prover/protocol/distributed/compiler/global/prover.go
Original file line number Diff line number Diff line change
@@ -1,45 +1,57 @@
package global

import (
"github.com/consensys/linea-monorepo/prover/protocol/ifaces"
"fmt"

"github.com/consensys/linea-monorepo/prover/maths/common/vector"
"github.com/consensys/linea-monorepo/prover/protocol/query"
"github.com/consensys/linea-monorepo/prover/protocol/wizard"
edc "github.com/consensys/linea-monorepo/prover/zkevm/prover/publicInput/execution_data_collector"
)

type boundaryAssignments struct {
boundaries boundaries
hashOpening query.LocalOpening
mimcHash edc.MIMCHasher
}
type proverActionForBoundaries struct {
provider ifaces.Column
receiver ifaces.Column
providerOpenings []query.LocalOpening
receiverOpenings []query.LocalOpening

hashOpeningProvider query.LocalOpening
hashOpeningReceiver query.LocalOpening
mimicHasherProvider edc.MIMCHasher
mimicHasherReceiver edc.MIMCHasher
provider boundaryAssignments
receiver boundaryAssignments
}

// it assigns all the LocalOpening covering the boundaries
func (pa proverActionForBoundaries) Run(run *wizard.ProverRuntime) {
var (
providerWit = run.GetColumn(pa.provider.GetColID()).IntoRegVecSaveAlloc()
receiverWit = run.GetColumn(pa.receiver.GetColID()).IntoRegVecSaveAlloc()
provider = pa.provider.boundaries.boundaryCol
receiver = pa.receiver.boundaries.boundaryCol
providerOpenings = pa.provider.boundaries.boundaryOpenings
receiverOpenings = pa.receiver.boundaries.boundaryOpenings

providerWit = run.GetColumn(provider.GetColID()).IntoRegVecSaveAlloc()
receiverWit = run.GetColumn(receiver.GetColID()).IntoRegVecSaveAlloc()
)

for i := range pa.providerOpenings {
fmt.Printf("provider %v\n", vector.Prettify(providerWit))
fmt.Printf("receiver %v\n", vector.Prettify(receiverWit))

for _, loProvider := range providerOpenings.ListAllKeys() {
index := providerOpenings.MustGet(loProvider)
run.AssignLocalPoint(loProvider.ID, providerWit[index])
}

run.AssignLocalPoint(pa.providerOpenings[i].ID, providerWit[i])
run.AssignLocalPoint(pa.receiverOpenings[i].ID, receiverWit[i])
for _, loReceiver := range receiverOpenings.ListAllKeys() {
index := receiverOpenings.MustGet(loReceiver)
run.AssignLocalPoint(loReceiver.ID, receiverWit[index])
}

pa.mimicHasherProvider.AssignHasher(run)
pa.mimicHasherReceiver.AssignHasher(run)
pa.provider.mimcHash.AssignHasher(run)
pa.receiver.mimcHash.AssignHasher(run)

var (
hashProvider = run.GetColumnAt(pa.mimicHasherProvider.HashFinal.GetColID(), 0)
hashReceiver = run.GetColumnAt(pa.mimicHasherReceiver.HashFinal.GetColID(), 0)
hashProvider = run.GetColumnAt(pa.provider.mimcHash.HashFinal.GetColID(), 0)
hashReceiver = run.GetColumnAt(pa.receiver.mimcHash.HashFinal.GetColID(), 0)
)

run.AssignLocalPoint(pa.hashOpeningProvider.ID, hashProvider)
run.AssignLocalPoint(pa.hashOpeningReceiver.ID, hashReceiver)
run.AssignLocalPoint(pa.provider.hashOpening.ID, hashProvider)
run.AssignLocalPoint(pa.receiver.hashOpening.ID, hashReceiver)
}
3 changes: 3 additions & 0 deletions prover/protocol/distributed/conglomeration/conglomeration.go
Original file line number Diff line number Diff line change
Expand Up @@ -129,6 +129,9 @@ func (ctx *recursionCtx) captureCompPreVortex(tmpl *wizard.CompiledIOP) {
_ = tmpl.GetPublicInputAccessor(constants.GrandProductPublicInput)
_ = tmpl.GetPublicInputAccessor(constants.GrandSumPublicInput)
_ = tmpl.GetPublicInputAccessor(constants.LogDerivativeSumPublicInput)

_ = tmpl.GetPublicInputAccessor(constants.GlobalProviderPublicInput)
_ = tmpl.GetPublicInputAccessor(constants.GlobalReceiverPublicInput)
)

ctx.LastRound = lastRound
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,9 @@ func TestConglomerationPureVortexSingleRound(t *testing.T) {
builder.InsertPublicInput(constants.GrandProductPublicInput, accessors.NewConstant(field.NewElement(1)))
builder.InsertPublicInput(constants.GrandSumPublicInput, accessors.NewConstant(field.NewElement(0)))
builder.InsertPublicInput(constants.LogDerivativeSumPublicInput, accessors.NewConstant(field.NewElement(0)))

builder.InsertPublicInput(constants.GlobalProviderPublicInput, accessors.NewConstant(field.NewElement(0)))
builder.InsertPublicInput(constants.GlobalReceiverPublicInput, accessors.NewConstant(field.NewElement(0)))
}

prover := func(k int) func(run *wizard.ProverRuntime) {
Expand Down Expand Up @@ -172,6 +175,9 @@ func TestConglomerationPureVortexMultiRound(t *testing.T) {
builder.InsertPublicInput(constants.GrandProductPublicInput, accessors.NewConstant(field.NewElement(1)))
builder.InsertPublicInput(constants.GrandSumPublicInput, accessors.NewConstant(field.NewElement(0)))
builder.InsertPublicInput(constants.LogDerivativeSumPublicInput, accessors.NewConstant(field.NewElement(0)))

builder.InsertPublicInput(constants.GlobalProviderPublicInput, accessors.NewConstant(field.NewElement(0)))
builder.InsertPublicInput(constants.GlobalReceiverPublicInput, accessors.NewConstant(field.NewElement(0)))
}

prover := func(k int) func(run *wizard.ProverRuntime) {
Expand Down Expand Up @@ -232,6 +238,9 @@ func TestConglomerationLookup(t *testing.T) {
builder.InsertPublicInput(constants.GrandProductPublicInput, accessors.NewConstant(field.NewElement(1)))
builder.InsertPublicInput(constants.GrandSumPublicInput, accessors.NewConstant(field.NewElement(0)))
builder.InsertPublicInput(constants.LogDerivativeSumPublicInput, accessors.NewConstant(field.NewElement(0)))

builder.InsertPublicInput(constants.GlobalProviderPublicInput, accessors.NewConstant(field.NewElement(0)))
builder.InsertPublicInput(constants.GlobalReceiverPublicInput, accessors.NewConstant(field.NewElement(0)))
}

prover := func(k int) func(run *wizard.ProverRuntime) {
Expand Down
Loading

0 comments on commit 096c905

Please sign in to comment.