Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

chore: Add new domain abstraction #86

Merged
merged 4 commits into from
Aug 20, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 7 additions & 6 deletions api.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package goethkzg
import (
"encoding/json"

"github.com/crate-crypto/go-eth-kzg/internal/domain"
"github.com/crate-crypto/go-eth-kzg/internal/kzg"
kzgmulti "github.com/crate-crypto/go-eth-kzg/internal/kzg_multi"
"github.com/crate-crypto/go-eth-kzg/internal/kzg_multi/fk20"
Expand All @@ -13,8 +14,8 @@ import (
// Note: We could marshall this object so that clients won't need to process the SRS each time. The time to process is
// about 2-5 seconds.
type Context struct {
domain *kzg.Domain
domainExtended *kzg.Domain
domain *domain.Domain
domainExtended *domain.Domain
commitKeyLagrange *kzg.CommitKey
commitKeyMonomial *kzg.CommitKey
openKey *kzg.OpeningKey
Expand Down Expand Up @@ -121,20 +122,20 @@ func NewContext4096(trustedSetup *JSONTrustedSetup) (*Context, error) {
G2: setupG2Points,
}

domain := kzg.NewDomain(ScalarsPerBlob)
domainBlobLen := domain.NewDomain(ScalarsPerBlob)
// Bit-Reverse the roots and the trusted setup according to the specs
// The bit reversal is not needed for simple KZG however it was
// implemented to make the step for full dank-sharding easier.
commitKeyLagrange.ReversePoints()
domain.ReverseRoots()
domainBlobLen.ReverseRoots()

domainExtended := kzg.NewDomain(scalarsPerExtBlob)
domainExtended := domain.NewDomain(scalarsPerExtBlob)
domainExtended.ReverseRoots()

fk20 := fk20.NewFK20(commitKeyMonomial.G1, scalarsPerExtBlob, scalarsPerCell)

return &Context{
domain: domain,
domain: domainBlobLen,
domainExtended: domainExtended,
commitKeyLagrange: &commitKeyLagrange,
commitKeyMonomial: &commitKeyMonomial,
Expand Down
8 changes: 4 additions & 4 deletions api_eip7594.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ import (
"slices"

"github.com/consensys/gnark-crypto/ecc/bls12-381/fr"
"github.com/crate-crypto/go-eth-kzg/internal/kzg"
"github.com/crate-crypto/go-eth-kzg/internal/domain"
kzgmulti "github.com/crate-crypto/go-eth-kzg/internal/kzg_multi"
)

Expand All @@ -16,7 +16,7 @@ func (ctx *Context) ComputeCellsAndKZGProofs(blob *Blob, numGoRoutines int) ([Ce
}

// Bit reverse the polynomial representing the Blob so that it is in normal order
kzg.BitReverse(polynomial)
domain.BitReverse(polynomial)

// Convert the polynomial in lagrange form to a polynomial in monomial form
polyCoeff := ctx.domain.IfftFr(polynomial)
Expand Down Expand Up @@ -100,7 +100,7 @@ func (ctx *Context) RecoverCellsAndComputeKZGProofs(cellIDs []uint64, cells []*C
missingCellIds := make([]uint64, 0, CellsPerExtBlob)
for cellID := uint64(0); cellID < CellsPerExtBlob; cellID++ {
if !slices.Contains(cellIDs, cellID) {
missingCellIds = append(missingCellIds, (kzg.BitReverseInt(cellID, CellsPerExtBlob)))
missingCellIds = append(missingCellIds, (domain.BitReverseInt(cellID, CellsPerExtBlob)))
}
}

Expand All @@ -119,7 +119,7 @@ func (ctx *Context) RecoverCellsAndComputeKZGProofs(cellIDs []uint64, cells []*C
copy(extendedBlob[cellID*scalarsPerCell:], cellEvals)
}
// Bit reverse the extendedBlob so that it is in normal order
kzg.BitReverse(extendedBlob)
domain.BitReverse(extendedBlob)

polyCoeff, err := ctx.dataRecovery.RecoverPolynomialCoefficients(extendedBlob, missingCellIds)
if err != nil {
Expand Down
67 changes: 67 additions & 0 deletions internal/domain/coset_fft.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,67 @@
package domain

import (
"github.com/consensys/gnark-crypto/ecc/bls12-381/fr"
)

// FFTCoset represents a coset for Fast Fourier Transform operations.
// It contains the generator of the coset and its inverse.
type FFTCoset struct {
// CosetGen is the generator element of the coset.
// It's used to shift the domain for coset FFT operations.
CosetGen fr.Element

// InvCosetGen is the inverse of the coset generator.
// It's used in inverse coset FFT operations to shift back to the original domain.
InvCosetGen fr.Element
}

// CosetDomain represents a domain for performing FFT operations over a coset.
// It combines a standard FFT domain with coset information for efficient coset FFT computations.
type CosetDomain struct {
// domain is the underlying FFT domain.
domain *Domain

// coset contains the coset generator and its inverse for this domain.
coset FFTCoset
}

// NewCosetDomain creates a new CosetDomain with the given Domain and FFTCoset.
func NewCosetDomain(domain *Domain, fft_coset FFTCoset) *CosetDomain {
return &CosetDomain{
domain: domain,
coset: fft_coset,
}
}

// CosetFFtFr performs a forward coset FFT on the input values.
//
// It first scales the input values by powers of the coset generator,
// then performs a standard FFT on the scaled values.
func (d *CosetDomain) CosetFFtFr(values []fr.Element) []fr.Element {
result := make([]fr.Element, len(values))

cosetScale := fr.One()
for i := 0; i < len(values); i++ {
result[i].Mul(&values[i], &cosetScale)
cosetScale.Mul(&cosetScale, &d.coset.CosetGen)
}

return d.domain.FftFr(result)
}

// CosetIFFtFr performs an inverse coset FFT on the input values.
//
// It first performs a standard inverse FFT, then scales the results
// by powers of the inverse coset generator to shift back to the original domain.
func (d *CosetDomain) CosetIFFtFr(values []fr.Element) []fr.Element {
result := d.domain.IfftFr(values)

cosetScale := fr.One()
for i := 0; i < len(result); i++ {
result[i].Mul(&result[i], &cosetScale)
cosetScale.Mul(&cosetScale, &d.coset.InvCosetGen)
}

return result
}
25 changes: 8 additions & 17 deletions internal/kzg/domain.go → internal/domain/domain.go
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
package kzg
package domain

import (
"fmt"
Expand Down Expand Up @@ -39,12 +39,6 @@ type Domain struct {
// f(x)/g(x) where g(x) is a linear polynomial
// which vanishes on a point on the domain
PreComputedInverses []fr.Element

// CosetGenerator is the generator for the coset domain.
CosetGenerator fr.Element

// CosetGeneratorInv is the inverse of the generator for the coset domain.
CosetGeneratorInv fr.Element
}

// NewDomain returns a new domain with the desired number of points x.
Expand Down Expand Up @@ -100,9 +94,6 @@ func NewDomain(x uint64) *Domain {
// We use BatchInvert instead of the above for clarity.
domain.PreComputedInverses = fr.BatchInvert(domain.Roots)

domain.CosetGenerator = fr.NewElement(7)
domain.CosetGeneratorInv.Inverse(&domain.CosetGenerator)

return domain
}

Expand Down Expand Up @@ -165,11 +156,11 @@ func (domain *Domain) ReverseRoots() {
BitReverse(domain.PreComputedInverses)
}

// findRootIndex returns the index of the element in the domain or -1 if not found.
// FindRootIndex returns the index of the element in the domain or -1 if not found.
//
// - If point is in the domain (meaning that point is a domain.Cardinality'th root of unity), returns the index of the point in the domain.
// - If point is not in the domain, returns -1.
func (domain *Domain) findRootIndex(point fr.Element) int64 {
func (domain *Domain) FindRootIndex(point fr.Element) int64 {
for i := int64(0); i < int64(domain.Cardinality); i++ {
if point.Equal(&domain.Roots[i]) {
return i
Expand All @@ -185,21 +176,21 @@ func (domain *Domain) findRootIndex(point fr.Element) int64 {
// If len(poly) != domain.Cardinality, returns an error.
//
// [evaluate_polynomial_in_evaluation_form]: https://github.com/ethereum/consensus-specs/blob/017a8495f7671f5fff2075a9bfc9238c1a0982f8/specs/deneb/polynomial-commitments.md#evaluate_polynomial_in_evaluation_form
func (domain *Domain) EvaluateLagrangePolynomial(poly Polynomial, evalPoint fr.Element) (*fr.Element, error) {
outputPoint, _, err := domain.evaluateLagrangePolynomial(poly, evalPoint)
func (domain *Domain) EvaluateLagrangePolynomial(poly []fr.Element, evalPoint fr.Element) (*fr.Element, error) {
outputPoint, _, err := domain.EvaluateLagrangePolynomialWithIndex(poly, evalPoint)

return outputPoint, err
}

// evaluateLagrangePolynomial is the implementation for [EvaluateLagrangePolynomial].
// EvaluateLagrangePolynomialWithIndex is the implementation for [EvaluateLagrangePolynomial].
//
// It evaluates a Lagrange polynomial at the given point of evaluation and reports whether the given point was among the points of the domain:
// - The input polynomial is given in evaluation form, that is, a list of evaluations at the points in the domain.
// - The evaluationResult is the result of evaluation at evalPoint.
// - indexInDomain is the index inside domain.Roots, if evalPoint is among them, -1 otherwise
//
// This semantics was copied from the go library, see: https://cs.opensource.google/go/x/exp/+/522b1b58:slices/slices.go;l=117
func (domain *Domain) evaluateLagrangePolynomial(poly Polynomial, evalPoint fr.Element) (*fr.Element, int64, error) {
func (domain *Domain) EvaluateLagrangePolynomialWithIndex(poly []fr.Element, evalPoint fr.Element) (*fr.Element, int64, error) {
var indexInDomain int64 = -1

if domain.Cardinality != uint64(len(poly)) {
Expand All @@ -210,7 +201,7 @@ func (domain *Domain) evaluateLagrangePolynomial(poly Polynomial, evalPoint fr.E
// then evaluation of the polynomial in lagrange form
// is the same as indexing it with the position
// that the evaluation point is in, in the domain
indexInDomain = domain.findRootIndex(evalPoint)
indexInDomain = domain.FindRootIndex(evalPoint)
if indexInDomain != -1 {
return &poly[indexInDomain], indexInDomain, nil
}
Expand Down
10 changes: 5 additions & 5 deletions internal/kzg/domain_test.go → internal/domain/domain_test.go
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
package kzg
package domain

import (
"crypto/rand"
Expand Down Expand Up @@ -102,7 +102,7 @@ func TestEvalPolynomialSmoke(t *testing.T) {

// lagrangePoly are the evaluations of the coefficient polynomial over
// `domain`
lagrangePoly := make(Polynomial, domain.Cardinality)
lagrangePoly := make([]fr.Element, domain.Cardinality)
for i := 0; i < int(domain.Cardinality); i++ {
x := domain.Roots[i]
lagrangePoly[i] = f(x)
Expand All @@ -113,7 +113,7 @@ func TestEvalPolynomialSmoke(t *testing.T) {
for i := int64(0); i < int64(domain.Cardinality); i++ {
inputPoint := domain.Roots[i]

gotOutputPoint, indexInDomain, err := domain.evaluateLagrangePolynomial(lagrangePoly, inputPoint)
gotOutputPoint, indexInDomain, err := domain.EvaluateLagrangePolynomialWithIndex(lagrangePoly, inputPoint)
if err != nil {
t.Error(err)
}
Expand All @@ -137,7 +137,7 @@ func TestEvalPolynomialSmoke(t *testing.T) {
// Sample some random point
inputPoint := samplePointOutsideDomain(*domain)

gotOutputPoint, indexInDomain, err := domain.evaluateLagrangePolynomial(lagrangePoly, *inputPoint)
gotOutputPoint, indexInDomain, err := domain.EvaluateLagrangePolynomialWithIndex(lagrangePoly, *inputPoint)
if err != nil {
t.Errorf(err.Error(), inputPoint.Bytes())
}
Expand All @@ -161,7 +161,7 @@ func samplePointOutsideDomain(domain Domain) *fr.Element {

for {
randElement.SetUint64(randUint64())
if domain.findRootIndex(randElement) == -1 {
if domain.FindRootIndex(randElement) == -1 {
break
}
}
Expand Down
5 changes: 5 additions & 0 deletions internal/domain/errors.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
package domain

import "errors"

var ErrPolynomialMismatchedSizeDomain = errors.New("domain size does not equal the number of evaluations in the polynomial")
26 changes: 1 addition & 25 deletions internal/kzg/fft.go → internal/domain/fft.go
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
package kzg
package domain

import (
"math/big"
Expand Down Expand Up @@ -92,30 +92,6 @@ func fftG1(values []bls12381.G1Affine, nthRootOfUnity fr.Element) []bls12381.G1A
return evaluations
}

func (d *Domain) CosetFFtFr(values []fr.Element) []fr.Element {
result := make([]fr.Element, len(values))

cosetScale := fr.One()
for i := 0; i < len(values); i++ {
result[i].Mul(&values[i], &cosetScale)
cosetScale.Mul(&cosetScale, &d.CosetGenerator)
}

return d.FftFr(result)
}

func (d *Domain) CosetIFFtFr(values []fr.Element) []fr.Element {
result := d.IfftFr(values)

cosetScale := fr.One()
for i := 0; i < len(result); i++ {
result[i].Mul(&result[i], &cosetScale)
cosetScale.Mul(&cosetScale, &d.CosetGeneratorInv)
}

return result
}

func (d *Domain) FftFr(values []fr.Element) []fr.Element {
return fftFr(values, d.Generator)
}
Expand Down
34 changes: 8 additions & 26 deletions internal/kzg/fft_test.go → internal/domain/fft_test.go
Original file line number Diff line number Diff line change
@@ -1,34 +1,11 @@
package kzg
package domain

import (
"math/big"
"testing"

"github.com/consensys/gnark-crypto/ecc/bls12-381/fr"
)

func TestSRSConversion(t *testing.T) {
n := uint64(4096)
domain := NewDomain(n)
secret := big.NewInt(100)
srsMonomial, err := newMonomialSRSInsecureUint64(n, secret)
if err != nil {
t.Error(err)
}
srsLagrange, err := newLagrangeSRSInsecure(*domain, secret)
if err != nil {
t.Error(err)
}

lagrangeSRS := domain.IfftG1(srsMonomial.CommitKey.G1)

for i := uint64(0); i < n; i++ {
if !lagrangeSRS[i].Equal(&srsLagrange.CommitKey.G1[i]) {
t.Fatalf("conversion incorrect")
}
}
}

func TestFFt(t *testing.T) {
n := uint64(8)
polyMonomial := []fr.Element{
Expand All @@ -53,8 +30,13 @@ func TestFFt(t *testing.T) {
}
}

polyLagrangeCoset := d.CosetFFtFr(polyMonomial)
gotPolyMonomial = d.CosetIFFtFr(polyLagrangeCoset)
fftCoset := FFTCoset{}
fftCoset.CosetGen = fr.NewElement(7)
fftCoset.InvCosetGen.Inverse(&fftCoset.CosetGen)
cosetDomain := NewCosetDomain(d, fftCoset)

polyLagrangeCoset := cosetDomain.CosetFFtFr(polyMonomial)
gotPolyMonomial = cosetDomain.CosetIFFtFr(polyLagrangeCoset)

for i := uint64(0); i < n; i++ {
if !polyMonomial[i].Equal(&gotPolyMonomial[i]) {
Expand Down
Loading
Loading