Skip to content

Commit

Permalink
Avoid passing large arrays on heap, also use less memory (#63)
Browse files Browse the repository at this point in the history
* all: reduce memory consumption and large arrays on heap

* reuse and recycle

* Update internal/kzg/kzg_prove.go

my bad

* Update internal/kzg/kzg_prove.go

* benchmark for blob deserialization

* benchmark for compute challenge

* undo certain api changes

* remove unused hashToBLSField

* lint: remove new line

---------

Co-authored-by: kevaundray <kevtheappdev@gmail.com>
  • Loading branch information
holiman and kevaundray authored Mar 11, 2024
1 parent 755f7aa commit 51e065e
Show file tree
Hide file tree
Showing 10 changed files with 104 additions and 57 deletions.
4 changes: 2 additions & 2 deletions api_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ func TestNonCanonicalSmoke(t *testing.T) {
blobGood := GetRandBlob(123456789)
blobBad := GetRandBlob(123456789)
unreducedScalar := nonCanonicalScalar(123445)
modifyBlob(&blobBad, unreducedScalar, 0)
modifyBlob(blobBad, unreducedScalar, 0)

commitment, err := ctx.BlobToKZGCommitment(blobGood, NumGoRoutines)
require.NoError(t, err)
Expand Down Expand Up @@ -74,7 +74,7 @@ func TestNonCanonicalSmoke(t *testing.T) {
err = ctx.VerifyBlobKZGProof(blobBad, commitment, blobProof)
require.Error(t, err, "expected an error since blob was not canonical")

err = ctx.VerifyBlobKZGProofBatch([]gokzg4844.Blob{blobBad}, []gokzg4844.KZGCommitment{commitment}, []gokzg4844.KZGProof{blobProof})
err = ctx.VerifyBlobKZGProofBatch([]gokzg4844.Blob{*blobBad}, []gokzg4844.KZGCommitment{commitment}, []gokzg4844.KZGProof{blobProof})
require.Error(t, err, "expected an error since blob was not canonical")
}

Expand Down
44 changes: 37 additions & 7 deletions bench_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ import (

"github.com/consensys/gnark-crypto/ecc/bls12-381/fr"
gokzg4844 "github.com/crate-crypto/go-kzg-4844"
"github.com/crate-crypto/go-kzg-4844/internal/kzg"
"github.com/stretchr/testify/require"
)

Expand All @@ -34,14 +35,14 @@ func GetRandFieldElement(seed int64) [32]byte {
return gokzg4844.SerializeScalar(r)
}

func GetRandBlob(seed int64) gokzg4844.Blob {
func GetRandBlob(seed int64) *gokzg4844.Blob {
var blob gokzg4844.Blob
bytesPerBlob := gokzg4844.ScalarsPerBlob * gokzg4844.SerializedScalarSize
for i := 0; i < bytesPerBlob; i += gokzg4844.SerializedScalarSize {
fieldElementBytes := GetRandFieldElement(seed + int64(i))
copy(blob[i:i+gokzg4844.SerializedScalarSize], fieldElementBytes[:])
}
return blob
return &blob
}

func Benchmark(b *testing.B) {
Expand All @@ -58,7 +59,7 @@ func Benchmark(b *testing.B) {
proof, err := ctx.ComputeBlobKZGProof(blob, commitment, NumGoRoutines)
require.NoError(b, err)

blobs[i] = blob
blobs[i] = *blob
commitments[i] = commitment
proofs[i] = proof
fields[i] = GetRandFieldElement(int64(i))
Expand All @@ -69,37 +70,43 @@ func Benchmark(b *testing.B) {
///////////////////////////////////////////////////////////////////////////

b.Run("BlobToKZGCommitment", func(b *testing.B) {
b.ReportAllocs()
for n := 0; n < b.N; n++ {
_, _ = ctx.BlobToKZGCommitment(blobs[0], NumGoRoutines)
_, _ = ctx.BlobToKZGCommitment(&blobs[0], NumGoRoutines)
}
})

b.Run("ComputeKZGProof", func(b *testing.B) {
b.ReportAllocs()
for n := 0; n < b.N; n++ {
_, _, _ = ctx.ComputeKZGProof(blobs[0], fields[0], NumGoRoutines)
_, _, _ = ctx.ComputeKZGProof(&blobs[0], fields[0], NumGoRoutines)
}
})

b.Run("ComputeBlobKZGProof", func(b *testing.B) {
b.ReportAllocs()
for n := 0; n < b.N; n++ {
_, _ = ctx.ComputeBlobKZGProof(blobs[0], commitments[0], NumGoRoutines)
_, _ = ctx.ComputeBlobKZGProof(&blobs[0], commitments[0], NumGoRoutines)
}
})

b.Run("VerifyKZGProof", func(b *testing.B) {
b.ReportAllocs()
for n := 0; n < b.N; n++ {
_ = ctx.VerifyKZGProof(commitments[0], fields[0], fields[1], proofs[0])
}
})

b.Run("VerifyBlobKZGProof", func(b *testing.B) {
b.ReportAllocs()
for n := 0; n < b.N; n++ {
_ = ctx.VerifyBlobKZGProof(blobs[0], commitments[0], proofs[0])
_ = ctx.VerifyBlobKZGProof(&blobs[0], commitments[0], proofs[0])
}
})

for i := 1; i <= len(blobs); i *= 2 {
b.Run(fmt.Sprintf("VerifyBlobKZGProofBatch(count=%v)", i), func(b *testing.B) {
b.ReportAllocs()
for n := 0; n < b.N; n++ {
_ = ctx.VerifyBlobKZGProofBatch(blobs[:i], commitments[:i], proofs[:i])
}
Expand All @@ -108,9 +115,32 @@ func Benchmark(b *testing.B) {

for i := 1; i <= len(blobs); i *= 2 {
b.Run(fmt.Sprintf("VerifyBlobKZGProofBatchPar(count=%v)", i), func(b *testing.B) {
b.ReportAllocs()
for n := 0; n < b.N; n++ {
_ = ctx.VerifyBlobKZGProofBatchPar(blobs[:i], commitments[:i], proofs[:i])
}
})
}
}

func BenchmarkDeserializeBlob(b *testing.B) {
var (
blob = GetRandBlob(int64(13))
first, err = gokzg4844.DeserializeBlob(blob)
second kzg.Polynomial
)
if err != nil {
b.Fatal(err)
}
b.ResetTimer()
b.ReportAllocs()
for n := 0; n < b.N; n++ {
second, err = gokzg4844.DeserializeBlob(blob)
if err != nil {
b.Fatal(err)
}
}
if have, want := fmt.Sprintf("%x", second), fmt.Sprintf("%x", first); have != want {
b.Fatalf("have %s want %s", have, want)
}
}
10 changes: 5 additions & 5 deletions consensus_specs_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -313,7 +313,7 @@ func TestVerifyBlobKZGProofBatch(t *testing.T) {
require.False(t, testCaseValid)
return
}
blobs = append(blobs, blob)
blobs = append(blobs, *blob)
}

var commitments []gokzg4844.KZGCommitment
Expand Down Expand Up @@ -355,18 +355,18 @@ func TestVerifyBlobKZGProofBatch(t *testing.T) {
}
}

func hexStrToBlob(hexStr string) (gokzg4844.Blob, error) {
func hexStrToBlob(hexStr string) (*gokzg4844.Blob, error) {
var blob gokzg4844.Blob
byts, err := hexStrToBytes(hexStr)
if err != nil {
return blob, err
return nil, err
}

if len(blob) != len(byts) {
return blob, fmt.Errorf("blob does not have the correct length, %d ", len(byts))
return nil, fmt.Errorf("blob does not have the correct length, %d ", len(byts))
}
copy(blob[:], byts)
return blob, nil
return &blob, nil
}

func hexStrToScalar(hexStr string) (gokzg4844.Scalar, error) {
Expand Down
2 changes: 1 addition & 1 deletion examples_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ func TestBlobProveVerifyBatchIntegration(t *testing.T) {
proof, err := ctx.ComputeBlobKZGProof(blob, commitment, NumGoRoutines)
require.NoError(t, err)

blobs[i] = blob
blobs[i] = *blob
commitments[i] = commitment
proofs[i] = proof
}
Expand Down
23 changes: 8 additions & 15 deletions fiatshamir.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,25 +17,18 @@ const DomSepProtocol = "FSBLOBVERIFY_V1_"
// computeChallenge is provided to match the spec at [compute_challenge].
//
// [compute_challenge]: https://github.com/ethereum/consensus-specs/blob/017a8495f7671f5fff2075a9bfc9238c1a0982f8/specs/deneb/polynomial-commitments.md#compute_challenge
func computeChallenge(blob Blob, commitment KZGCommitment) fr.Element {
polyDegreeBytes := u64ToByteArray16(ScalarsPerBlob)
data := append([]byte(DomSepProtocol), polyDegreeBytes...)
data = append(data, blob[:]...)
data = append(data, commitment[:]...)

return hashToBLSField(data)
}

// hashToBLSField hashed the given binary data to a field element according to [hash_to_bls_field].
//
// [hash_to_bls_field]: https://github.com/ethereum/consensus-specs/blob/017a8495f7671f5fff2075a9bfc9238c1a0982f8/specs/deneb/polynomial-commitments.md#hash_to_bls_field
func hashToBLSField(data []byte) fr.Element {
digest := sha256.Sum256(data)

// Now interpret those bytes as a field element
func computeChallenge(blob *Blob, commitment KZGCommitment) fr.Element {
h := sha256.New()
h.Write([]byte(DomSepProtocol))
h.Write(u64ToByteArray16(ScalarsPerBlob))
h.Write(blob[:])
h.Write(commitment[:])

digest := h.Sum(nil)
var challenge fr.Element
challenge.SetBytes(digest[:])

return challenge
}

Expand Down
24 changes: 23 additions & 1 deletion fiatshamir_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,14 +4,15 @@ import (
"testing"

bls12381 "github.com/consensys/gnark-crypto/ecc/bls12-381"
"github.com/consensys/gnark-crypto/ecc/bls12-381/fr"
"github.com/stretchr/testify/require"
)

// This is both an interop test and a regression check
// If the way computeChallenge is computed is updated
// then this test will fail
func TestComputeChallengeInterop(t *testing.T) {
blob := Blob{}
blob := &Blob{}
commitment := SerializeG1Point(bls12381.G1Affine{})
challenge := computeChallenge(blob, KZGCommitment(commitment))
expected := []byte{
Expand All @@ -34,3 +35,24 @@ func TestTo16Bytes(t *testing.T) {
got := u64ToByteArray16(number)
require.Equal(t, expected, got)
}

func BenchmarkComputeChallenge(b *testing.B) {
var (
blob = &Blob{}
commitment = SerializeG1Point(bls12381.G1Affine{})
challenge fr.Element
want = []byte{
0x04, 0xb7, 0xb2, 0x2a, 0xf6, 0x3d, 0x2b, 0x2f,
0x1c, 0xed, 0x8d, 0x55, 0x05, 0x60, 0xe5, 0xd1,
0xe4, 0xb0, 0x1e, 0x35, 0x59, 0x03, 0xde, 0xe2,
0x27, 0x81, 0xe8, 0x78, 0x26, 0x85, 0x60, 0x96,
}
)
b.ResetTimer()
b.ReportAllocs()
for n := 0; n < b.N; n++ {
challenge = computeChallenge(blob, KZGCommitment(commitment))
}
have := SerializeScalar(challenge)
require.Equal(b, want, have[:])
}
29 changes: 17 additions & 12 deletions internal/kzg/kzg_prove.go
Original file line number Diff line number Diff line change
Expand Up @@ -78,27 +78,28 @@ func (domain *Domain) computeQuotientPoly(f Polynomial, indexInDomain int64, fz,
// This is the implementation of computeQuotientPoly for the case where z is not in the domain.
// Since both input and output polynomials are given in evaluation form, this method just performs the desired operation pointwise.
func (domain *Domain) computeQuotientPolyOutsideDomain(f Polynomial, fz, z fr.Element) (Polynomial, error) {
// Compute the lagrange form the of the numerator f(X) - f(z)
// Since f(X) is already in lagrange form, we can compute f(X) - f(z)
// by shifting all elements in f(X) by f(z)
numerator := make(Polynomial, len(f))
for i := 0; i < len(f); i++ {
numerator[i].Sub(&f[i], &fz)
}

// Compute the lagrange form of the denominator X - z.
// This means that we need to compute w - z for all points w in the domain.
denominator := make(Polynomial, len(f))
tmpDenom := make(Polynomial, len(f))
for i := 0; i < len(f); i++ {
denominator[i].Sub(&domain.Roots[i], &z)
tmpDenom[i].Sub(&domain.Roots[i], &z)
}

// To invert the denominator polynomial at each point of the domain, we perform a batch-inversion.
// Since `z` is not in the domain, we are sure that there are no zeroes in this inversion.
//
// Note: if there was a zero, the gnark-crypto library would skip
// it and not panic.
denominator = fr.BatchInvert(denominator)
// Note: the returned slice is a new slice, thus we are free to use tmpDenom.
denominator := fr.BatchInvert(tmpDenom)

// Compute the lagrange form of the numerator f(X) - f(z)
// Since f(X) is already in lagrange form, we can compute f(X) - f(z)
// by shifting all elements in f(X) by f(z)
numerator := tmpDenom
for i := 0; i < len(f); i++ {
numerator[i].Sub(&f[i], &fz)
}

// Compute the quotient q(X)
for i := 0; i < len(f); i++ {
Expand Down Expand Up @@ -134,7 +135,11 @@ func (domain *Domain) computeQuotientPolyOnDomain(f Polynomial, index uint64) (P
// Evaluation of 1/(X-z) at every point of the domain, except for index.
invRootsMinusZ := fr.BatchInvert(rootsMinusZ)

quotientPoly := make(Polynomial, domain.Cardinality)
// The rootsMinusZ is now free to reuse, since BatchInvert returned
// a fresh slice. But we need to ensure to set the value for 'index' to zero
quotientPoly := rootsMinusZ
quotientPoly[index] = fr.Element{}

for j := 0; j < int(domain.Cardinality); j++ {
// Check if we are on the current root of unity
// Note: For notations below, we use `m` to denote `index`
Expand Down
6 changes: 3 additions & 3 deletions prove.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ import (
// value to a negative number or 0 will make it default to the number of CPUs.
//
// [blob_to_kzg_commitment]: https://github.com/ethereum/consensus-specs/blob/017a8495f7671f5fff2075a9bfc9238c1a0982f8/specs/deneb/polynomial-commitments.md#blob_to_kzg_commitment
func (c *Context) BlobToKZGCommitment(blob Blob, numGoRoutines int) (KZGCommitment, error) {
func (c *Context) BlobToKZGCommitment(blob *Blob, numGoRoutines int) (KZGCommitment, error) {
// 1. Deserialization
//
// Deserialize blob into polynomial
Expand Down Expand Up @@ -43,7 +43,7 @@ func (c *Context) BlobToKZGCommitment(blob Blob, numGoRoutines int) (KZGCommitme
// value to a negative number or 0 will make it default to the number of CPUs.
//
// [compute_blob_kzg_proof]: https://github.com/ethereum/consensus-specs/blob/017a8495f7671f5fff2075a9bfc9238c1a0982f8/specs/deneb/polynomial-commitments.md#compute_blob_kzg_proof
func (c *Context) ComputeBlobKZGProof(blob Blob, blobCommitment KZGCommitment, numGoRoutines int) (KZGProof, error) {
func (c *Context) ComputeBlobKZGProof(blob *Blob, blobCommitment KZGCommitment, numGoRoutines int) (KZGProof, error) {
// 1. Deserialization
//
polynomial, err := DeserializeBlob(blob)
Expand Down Expand Up @@ -82,7 +82,7 @@ func (c *Context) ComputeBlobKZGProof(blob Blob, blobCommitment KZGCommitment, n
// value to a negative number or 0 will make it default to the number of CPUs.
//
// [compute_kzg_proof]: https://github.com/ethereum/consensus-specs/blob/017a8495f7671f5fff2075a9bfc9238c1a0982f8/specs/deneb/polynomial-commitments.md#compute_kzg_proof
func (c *Context) ComputeKZGProof(blob Blob, inputPointBytes Scalar, numGoRoutines int) (KZGProof, Scalar, error) {
func (c *Context) ComputeKZGProof(blob *Blob, inputPointBytes Scalar, numGoRoutines int) (KZGProof, Scalar, error) {
// 1. Deserialization
//
polynomial, err := DeserializeBlob(blob)
Expand Down
13 changes: 5 additions & 8 deletions serialization.go
Original file line number Diff line number Diff line change
Expand Up @@ -107,16 +107,13 @@ func DeserializeKZGProof(proof KZGProof) (bls12381.G1Affine, error) {
// DeserializeBlob implements [blob_to_polynomial].
//
// [blob_to_polynomial]: https://github.com/ethereum/consensus-specs/blob/017a8495f7671f5fff2075a9bfc9238c1a0982f8/specs/deneb/polynomial-commitments.md#blob_to_polynomial
func DeserializeBlob(blob Blob) (kzg.Polynomial, error) {
func DeserializeBlob(blob *Blob) (kzg.Polynomial, error) {
poly := make(kzg.Polynomial, ScalarsPerBlob)
for i := 0; i < ScalarsPerBlob; i++ {
chunk := blob[i*SerializedScalarSize : (i+1)*SerializedScalarSize]
serScalar := (*Scalar)(chunk)
scalar, err := DeserializeScalar(*serScalar)
if err != nil {
return nil, err
if err := poly[i].SetBytesCanonical(chunk); err != nil {
return nil, ErrNonCanonicalScalar
}
poly[i] = scalar
}
return poly, nil
}
Expand All @@ -143,12 +140,12 @@ func SerializeScalar(element fr.Element) Scalar {
//
// Note: This method is never used in the API because we always expect a byte array and will never receive deserialized
// field elements. We include it so that upstream fuzzers do not need to reimplement it.
func SerializePoly(poly kzg.Polynomial) Blob {
func SerializePoly(poly kzg.Polynomial) *Blob {
var blob Blob
for i := 0; i < ScalarsPerBlob; i++ {
chunk := blob[i*SerializedScalarSize : (i+1)*SerializedScalarSize]
serScalar := SerializeScalar(poly[i])
copy(chunk, serScalar[:])
}
return blob
return &blob
}
6 changes: 3 additions & 3 deletions verify.go
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ func (c *Context) VerifyKZGProof(blobCommitment KZGCommitment, inputPointBytes,
// VerifyBlobKZGProof implements [verify_blob_kzg_proof].
//
// [verify_blob_kzg_proof]: https://github.com/ethereum/consensus-specs/blob/017a8495f7671f5fff2075a9bfc9238c1a0982f8/specs/deneb/polynomial-commitments.md#verify_blob_kzg_proof
func (c *Context) VerifyBlobKZGProof(blob Blob, blobCommitment KZGCommitment, kzgProof KZGProof) error {
func (c *Context) VerifyBlobKZGProof(blob *Blob, blobCommitment KZGCommitment, kzgProof KZGProof) error {
// 1. Deserialize
//
polynomial, err := DeserializeBlob(blob)
Expand Down Expand Up @@ -114,7 +114,7 @@ func (c *Context) VerifyBlobKZGProofBatch(blobs []Blob, polynomialCommitments []
return err
}

blob := blobs[i]
blob := &blobs[i]
polynomial, err := DeserializeBlob(blob)
if err != nil {
return err
Expand Down Expand Up @@ -160,7 +160,7 @@ func (c *Context) VerifyBlobKZGProofBatchPar(blobs []Blob, commitments []KZGComm
for i := range blobs {
j := i // Capture the value of the loop variable
errG.Go(func() error {
return c.VerifyBlobKZGProof(blobs[j], commitments[j], proofs[j])
return c.VerifyBlobKZGProof(&blobs[j], commitments[j], proofs[j])
})
}

Expand Down

0 comments on commit 51e065e

Please sign in to comment.