Skip to content

Commit

Permalink
chore: Add initial RecoverCellsAndComputeKZGProofs method (#79)
Browse files Browse the repository at this point in the history
* chore: add bitReverseInt method

* chore: add rough recovery code

* chore: add impl for RecoverCellsAndComputeKZGProofs

* chore: add consensus-spec tests

* chore: linter

* chore: remove constants from recovery.go and use a structure

* chore: remove mention of the word `Cell`

* chore: update API

* chore: use new method to determine where we can do reconstruction

* chore: use BlockErasureIndex for missingIndices and misc cleanup

* chore: typo fix
  • Loading branch information
kevaundray authored Jun 21, 2024
1 parent e381910 commit ddc17c3
Show file tree
Hide file tree
Showing 5 changed files with 297 additions and 10 deletions.
8 changes: 8 additions & 0 deletions api.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import (
"encoding/json"

"github.com/crate-crypto/go-eth-kzg/internal/kzg"
kzgmulti "github.com/crate-crypto/go-eth-kzg/internal/kzg_multi"
)

// Context holds the necessary configuration needed to create and verify proofs.
Expand All @@ -16,6 +17,8 @@ type Context struct {
commitKeyLagrange *kzg.CommitKey
commitKeyMonomial *kzg.CommitKey
openKey *kzg.OpeningKey

dataRecovery *kzgmulti.DataRecovery
}

// BlsModulus is the bytes representation of the bls12-381 scalar field modulus.
Expand Down Expand Up @@ -131,5 +134,10 @@ func NewContext4096(trustedSetup *JSONTrustedSetup) (*Context, error) {
commitKeyLagrange: &commitKeyLagrange,
commitKeyMonomial: &commitKeyMonomial,
openKey: &openingKey,
// TODO: We compute the extendedDomain again in here.
// TODO: We could pass it in, but it breaks the API.
// TODO: And although its not an issue now because fft uses just the primitiveGenerator, the extended domain
// TODO: that recovery takes is not bit reversed.
dataRecovery: kzgmulti.NewDataRecovery(scalarsPerCell, ScalarsPerBlob, expansionFactor),
}, nil
}
96 changes: 95 additions & 1 deletion api_eip7594.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
package goethkzg

import (
"errors"

"github.com/consensys/gnark-crypto/ecc/bls12-381/fr"
"github.com/crate-crypto/go-eth-kzg/internal/kzg"
kzgmulti "github.com/crate-crypto/go-eth-kzg/internal/kzg_multi"
Expand Down Expand Up @@ -64,7 +66,72 @@ func (ctx *Context) computeCellsAndKZGProofsFromPolyCoeff(polyCoeff []fr.Element

//lint:ignore U1000 still fleshing out the API

// RecoverCellsAndComputeKZGProofs reconstructs the full set of cells for an
// extended blob from a partial set, and computes the KZG proof for every cell.
//
// cellIDs[i] is the position of cells[i] in the extended blob; _proofs[i] is
// the proof supplied alongside cells[i] (currently only checked for
// deserializability). At least NumBlocksNeededToReconstruct cells are required.
func (ctx *Context) RecoverCellsAndComputeKZGProofs(cellIDs []uint64, cells []*Cell, _proofs []KZGProof, numGoRoutines int) ([CellsPerExtBlob]*Cell, [CellsPerExtBlob]KZGProof, error) {
	// Zero values returned on every error path.
	var (
		noCells  [CellsPerExtBlob]*Cell
		noProofs [CellsPerExtBlob]KZGProof
	)

	// Cheap structural checks first: the three input lists must line up
	// before we do any per-element work.
	if len(cellIDs) != len(cells) {
		return noCells, noProofs, errors.New("number of cell IDs should be equal to the number of cells")
	}
	if len(cellIDs) != len(_proofs) {
		return noCells, noProofs, errors.New("number of cell IDs should be equal to the number of proofs")
	}

	// Check each proof can be deserialized
	// TODO: This seems somewhat useless as we should not be calling this method with proofs
	// TODO: that are not valid.
	for _, proof := range _proofs {
		if _, err := DeserializeKZGProof(proof); err != nil {
			return noCells, noProofs, err
		}
	}

	// Check that the cell IDs are unique and in range, building a presence
	// set as we go so that the missing-ID scan below is O(1) per lookup
	// instead of a linear search per candidate ID.
	presentCellIds := make(map[uint64]struct{}, len(cellIDs))
	for _, cellID := range cellIDs {
		if cellID >= CellsPerExtBlob {
			return noCells, noProofs, errors.New("cell ID should be less than CellsPerExtBlob")
		}
		if _, ok := presentCellIds[cellID]; ok {
			return noCells, noProofs, errors.New("cell IDs should be unique")
		}
		presentCellIds[cellID] = struct{}{}
	}

	// Check that we have enough cells to perform reconstruction
	if len(cellIDs) < ctx.dataRecovery.NumBlocksNeededToReconstruct() {
		return noCells, noProofs, errors.New("not enough cells to perform reconstruction")
	}

	// Find the missing cell IDs and bit reverse them
	// so that they are in normal order, as expected by the recovery code.
	missingCellIds := make([]uint64, 0, CellsPerExtBlob)
	for cellID := uint64(0); cellID < CellsPerExtBlob; cellID++ {
		if _, ok := presentCellIds[cellID]; !ok {
			missingCellIds = append(missingCellIds, kzg.BitReverseInt(cellID, CellsPerExtBlob))
		}
	}

	// Convert cells to field elements.
	// For each cellID, we get the corresponding cell in cells,
	// then use the cellID to place the cell in the correct position in the data(extendedBlob) array.
	extendedBlob := make([]fr.Element, scalarsPerExtBlob)
	for i, cellID := range cellIDs {
		cellEvals, err := deserializeCell(cells[i])
		if err != nil {
			return noCells, noProofs, err
		}
		// Place the cell in the correct position in the data array
		copy(extendedBlob[cellID*scalarsPerCell:], cellEvals)
	}
	// Bit reverse the extendedBlob so that it is in normal order
	kzg.BitReverse(extendedBlob)

	polyCoeff, err := ctx.dataRecovery.RecoverPolynomialCoefficients(extendedBlob, missingCellIds)
	if err != nil {
		return noCells, noProofs, err
	}

	return ctx.computeCellsAndKZGProofsFromPolyCoeff(polyCoeff, numGoRoutines)
}

//lint:ignore U1000 still fleshing out the API
Expand Down Expand Up @@ -147,3 +214,30 @@ func partition(slice []fr.Element, k int) [][]fr.Element {

return result
}

// containsUint64 reports whether element occurs in u64Slice.
//
// TODO: in go 1.21, we can use slice.Contains and remove this method
func containsUint64(u64Slice []uint64, element uint64) bool {
	for i := range u64Slice {
		if u64Slice[i] == element {
			return true
		}
	}
	return false
}

// isUniqueUint64 reports whether slice contains no duplicate values.
func isUniqueUint64(slice []uint64) bool {
	seen := make(map[uint64]struct{}, len(slice))
	for _, v := range slice {
		// A second sighting of any value means the slice is not unique.
		if _, dup := seen[v]; dup {
			return false
		}
		seen[v] = struct{}{}
	}
	return true
}
60 changes: 60 additions & 0 deletions consensus_specs_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ var (
computeCellsAndKZGProofsTests = filepath.Join(testDir, "compute_cells_and_kzg_proofs/*/*/*")
verifyCellKZGProofTests = filepath.Join(testDir, "verify_cell_kzg_proof/*/*/*")
verifyCellKZGProofBatchTests = filepath.Join(testDir, "verify_cell_kzg_proof_batch/*/*/*")
recoverCellsAndKZGProofsTests = filepath.Join(testDir, "recover_cells_and_kzg_proofs/*/*/*")
)

func TestBlobToKZGCommitment(t *testing.T) {
Expand Down Expand Up @@ -526,6 +527,65 @@ func TestVerifyCellKZGProofBatch(t *testing.T) {
}
}

// TestRecoverCellsAndKZGProofs runs the consensus-spec fixture tests for
// RecoverCellsAndComputeKZGProofs. Each fixture supplies a partial set of
// cells (with their IDs and proofs); valid fixtures also carry the expected
// recovered cells and proofs, while invalid fixtures have a nil output.
func TestRecoverCellsAndKZGProofs(t *testing.T) {
	type Test struct {
		Input struct {
			CellIds []uint64 `yaml:"cell_ids"`
			Cells   []string `yaml:"cells"`
			Proofs  []string `yaml:"proofs"`
		}
		// Output is nil for invalid test cases. For valid cases it holds two
		// string lists: [0] the expected cells, [1] the expected proofs.
		Output *[][]string `yaml:"output"`
	}

	tests, err := filepath.Glob(recoverCellsAndKZGProofsTests)
	require.NoError(t, err)
	require.True(t, len(tests) > 0)

	for _, testPath := range tests {
		t.Run(testPath, func(t *testing.T) {
			testFile, err := os.Open(testPath)
			require.NoError(t, err)
			test := Test{}
			err = yaml.NewDecoder(testFile).Decode(&test)
			// Close before checking the decode error so the file handle is
			// released even when decoding failed.
			require.NoError(t, testFile.Close())
			require.NoError(t, err)
			// A test case is expected to succeed iff it has an output.
			testCaseValid := test.Output != nil

			cellIds := test.Input.CellIds

			cells, err := hexStrArrToCells(test.Input.Cells)
			if err != nil {
				// Malformed cell hex is only acceptable for invalid cases.
				require.False(t, testCaseValid)
				return
			}

			proofs, err := HexStrArrToProofs(test.Input.Proofs)
			if err != nil {
				// Malformed proof hex is only acceptable for invalid cases.
				require.False(t, testCaseValid)
				return
			}

			recoveredCells, recoveredProofs, err := ctx.RecoverCellsAndComputeKZGProofs(cellIds, cells, proofs, 0)

			if err == nil {
				// Recovery succeeded: the fixture must be valid and the
				// recovered cells/proofs must match the expected output.
				require.NotNil(t, test.Output)
				expectedCellStrs := (*test.Output)[0]
				expectedCells, err := hexStrArrToCells(expectedCellStrs)
				require.NoError(t, err)

				require.Equal(t, expectedCells, recoveredCells[:])

				expectedProofStrs := (*test.Output)[1]
				expectedProofs, err := HexStrArrToProofs(expectedProofStrs)
				require.NoError(t, err)
				require.Equal(t, expectedProofs, recoveredProofs[:])
			} else {
				// Recovery failed: the fixture must be an invalid case.
				require.Nil(t, test.Output)
			}
		})
	}
}

func hexStrToCell(hexStr string) (*goethkzg.Cell, error) {
var cell goethkzg.Cell
byts, err := hexStrToBytes(hexStr)
Expand Down
22 changes: 13 additions & 9 deletions internal/kzg/domain.go
Original file line number Diff line number Diff line change
Expand Up @@ -135,24 +135,28 @@ to think about all these when you add DAS.
// [reverse_bits]: https://github.com/ethereum/consensus-specs/blob/017a8495f7671f5fff2075a9bfc9238c1a0982f8/specs/deneb/polynomial-commitments.md#reverse_bits
func BitReverse[K interface{}](list []K) {
n := uint64(len(list))
if !utils.IsPowerOfTwo(n) {
panic("size of list given to bitReverse must be a power of two")
}

// The standard library's bits.Reverse64 inverts its input as a 64-bit unsigned integer.
// However, we need to invert it as a log2(len(list))-bit integer, so we need to correct this by
// shifting appropriately.
shiftCorrection := uint64(64 - bits.TrailingZeros64(n))

for i := uint64(0); i < n; i++ {
// Find index irev, such that i and irev get swapped
irev := bits.Reverse64(i) >> shiftCorrection
irev := BitReverseInt(i, n)
if irev > i {
list[i], list[irev] = list[irev], list[i]
}
}
}

// BitReverseInt reverses the bits of k, interpreted as a log2(size)-bit
// integer. size is the length of the list/domain that k indexes into and
// must be a power of two; otherwise the function panics.
func BitReverseInt(k, size uint64) uint64 {
	if !utils.IsPowerOfTwo(size) {
		panic("size given to BitReverseInt must be a power of two")
	}

	// The standard library's bits.Reverse64 inverts its input as a 64-bit unsigned integer.
	// However, we need to invert it as a log2(size)-bit integer, so we need to correct this by
	// shifting appropriately.
	shiftCorrection := uint64(64 - bits.TrailingZeros64(size))
	return bits.Reverse64(k) >> shiftCorrection
}

// ReverseRoots applies the bit-reversal permutation to the list of precomputed roots of unity and their inverses in the domain.
//
// [bit_reversal_permutation]: https://github.com/ethereum/consensus-specs/blob/017a8495f7671f5fff2075a9bfc9238c1a0982f8/specs/deneb/polynomial-commitments.md#bit_reversal_permutation
Expand Down
121 changes: 121 additions & 0 deletions internal/kzg_multi/recovery.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,121 @@
package kzgmulti

import (
"errors"

"github.com/consensys/gnark-crypto/ecc/bls12-381/fr"
"github.com/crate-crypto/go-eth-kzg/internal/kzg"
)

// BlockErasureIndex is used to indicate the index of the block erasure that is missing
// from the codeword.
type BlockErasureIndex = uint64

// DataRecovery implements a unique decoding algorithm.
//
// The algorithm is not generic and is specific to the use-case where:
// - We have block erasures. ie we do not lose data in random locations, but in contiguous chunks.
// The chunks themselves are predetermined.
type DataRecovery struct {
	// rootsOfUnityBlockErasureIndex is a domain that corresponds to the number of blocks
	// that we can have in the codeword.
	rootsOfUnityBlockErasureIndex *kzg.Domain
	// domainExtended is the domain whose size matches the full codeword; it is
	// used for the FFT/IFFT (and coset variants) performed during recovery.
	domainExtended *kzg.Domain
	// blockErasureSize indicates the size of `blocks of evaluations` that
	// can be missing. For example, if blockErasureSize is 4, then 4 evaluations
	// can be missing, or 8 or 16.
	//
	// This is contrary to a general unique decoding algorithm where the number of
	// missing elements can be any number and the evaluations do not need to
	// be in blocks.
	blockErasureSize int
	// numScalarsInCodeword is the number of scalars in the codeword
	// ie the number of scalars we get when we get encode the data.
	numScalarsInCodeword int
	// numScalarsInDataWord is the number of scalars in the message
	// that we encode.
	numScalarsInDataWord int
	// expansionFactor is the factor by which the data word or message is expanded
	expansionFactor int
	// totalNumBlocks is the total number of blocks(groups of evaluations) in the codeword
	totalNumBlocks int
}

// NewDataRecovery creates a DataRecovery instance for a code whose data word
// holds numScalarsInDataWord scalars, is expanded by expansionFactor, and
// whose erasures occur in contiguous blocks of blockErasureSize evaluations.
func NewDataRecovery(blockErasureSize, numScalarsInDataWord, expansionFactor int) *DataRecovery {
	// The codeword is the data word after expansion.
	codewordSize := numScalarsInDataWord * expansionFactor

	// Number of erasure blocks that make up a single codeword.
	numBlocks := codewordSize / blockErasureSize

	return &DataRecovery{
		rootsOfUnityBlockErasureIndex: kzg.NewDomain(uint64(numBlocks)),
		domainExtended:                kzg.NewDomain(uint64(codewordSize)),
		blockErasureSize:              blockErasureSize,
		numScalarsInCodeword:          codewordSize,
		totalNumBlocks:                numBlocks,
		numScalarsInDataWord:          numScalarsInDataWord,
		expansionFactor:               expansionFactor,
	}
}

// constructVanishingPolyOnIndices returns the coefficients (over the extended
// domain) of a polynomial that vanishes on every evaluation belonging to the
// given missing block erasure indices.
//
// It builds the vanishing polynomial Z(x) whose roots are the block-index
// roots of unity for the missing blocks, then substitutes
// x -> x^blockErasureSize by spreading the coefficients, so the result
// vanishes on whole blocks of evaluations at once.
//
// Note: These blockErasure indices should not be in bit reversed order
func (dr *DataRecovery) constructVanishingPolyOnIndices(missingBlockErasureIndices []BlockErasureIndex) []fr.Element {
	// Collect all of the roots that are associated with the missing block erasure indices
	missingBlockErasureIndexRoots := make([]fr.Element, len(missingBlockErasureIndices))
	for i, index := range missingBlockErasureIndices {
		missingBlockErasureIndexRoots[i] = dr.rootsOfUnityBlockErasureIndex.Roots[index]
	}

	shortZeroPoly := vanishingPolyCoeff(missingBlockErasureIndexRoots)

	// Placing coefficient i at position i*blockErasureSize computes
	// Z(x^blockErasureSize): each root of Z then accounts for
	// blockErasureSize roots of the spread polynomial, covering the whole
	// missing block of evaluations.
	zeroPolyCoeff := make([]fr.Element, dr.numScalarsInCodeword)
	for i, coeff := range shortZeroPoly {
		zeroPolyCoeff[i*dr.blockErasureSize] = coeff
	}

	return zeroPolyCoeff
}

// NumBlocksNeededToReconstruct returns how many erasure blocks must be
// available in order to reconstruct the original data word.
func (dr *DataRecovery) NumBlocksNeededToReconstruct() int {
	// Reconstruction needs as many evaluations as the data word holds
	// scalars, and evaluations arrive in blocks of blockErasureSize.
	return dr.numScalarsInDataWord / dr.blockErasureSize
}

// RecoverPolynomialCoefficients recovers the coefficients of the original
// polynomial from a codeword with missing blocks of evaluations.
//
// data is the codeword in normal (non bit-reversed) order; the evaluations in
// the missing blocks may hold arbitrary values. missingIndices lists the
// erasure-block indices that are missing (also in normal order).
//
// It uses the standard vanishing-polynomial technique: since Z(x) vanishes
// exactly on the missing evaluations, (D*Z) agrees with (P*Z) on the whole
// domain, so P can be recovered as (D*Z)/Z, with the division carried out on
// a coset where Z has no roots.
func (dr *DataRecovery) RecoverPolynomialCoefficients(data []fr.Element, missingIndices []BlockErasureIndex) ([]fr.Element, error) {
	// Z(x): vanishes on every evaluation belonging to a missing block.
	zX := dr.constructVanishingPolyOnIndices(missingIndices)

	// Evaluations of Z(x) over the extended domain.
	zXEval := dr.domainExtended.FftFr(zX)

	if len(zXEval) != len(data) {
		return nil, errors.New("length of data and zXEval should be equal")
	}

	// Pointwise product (D*Z): the unknown positions of `data` are multiplied
	// by a zero of Z, so the product matches (P*Z) everywhere.
	eZEval := make([]fr.Element, len(data))
	for i := 0; i < len(data); i++ {
		eZEval[i].Mul(&data[i], &zXEval[i])
	}

	// Interpolate to get the coefficients of (D*Z)(x).
	dzPoly := dr.domainExtended.IfftFr(eZEval)

	// Re-evaluate both polynomials on a coset, which is disjoint from the
	// domain and therefore contains no roots of Z, making pointwise division safe.
	cosetZxEval := dr.domainExtended.CosetFFtFr(zX)
	cosetDzEVal := dr.domainExtended.CosetFFtFr(dzPoly)

	// Pointwise division (D*Z)/Z on the coset; BatchInvert amortizes the
	// cost of the field inversions into a single inversion plus multiplications.
	cosetQuotientEval := make([]fr.Element, len(cosetZxEval))
	cosetZxEval = fr.BatchInvert(cosetZxEval)

	for i := 0; i < len(cosetZxEval); i++ {
		cosetQuotientEval[i].Mul(&cosetDzEVal[i], &cosetZxEval[i])
	}

	// Back to coefficient form: this is the recovered polynomial P(x).
	polyCoeff := dr.domainExtended.CosetIFFtFr(cosetQuotientEval)

	// Truncate the polynomial coefficients to the number of scalars in the data word
	polyCoeff = polyCoeff[:dr.numScalarsInDataWord]
	return polyCoeff, nil
}

0 comments on commit ddc17c3

Please sign in to comment.