From a9f64767e06ab7dadb7df94d795d2e2dc9c612ea Mon Sep 17 00:00:00 2001 From: Chao Ma Date: Mon, 9 Dec 2024 16:39:13 +0700 Subject: [PATCH 01/20] batch-commit --- src/ceno_binding/mod.rs | 5 ++++- src/ceno_binding/pcs.rs | 21 +++++++++++++++++---- src/poly_utils/coeffs.rs | 4 ++++ src/whir/committer.rs | 13 ++++++++++++- 4 files changed, 37 insertions(+), 6 deletions(-) diff --git a/src/ceno_binding/mod.rs b/src/ceno_binding/mod.rs index 5d7e5f4..589d3d8 100644 --- a/src/ceno_binding/mod.rs +++ b/src/ceno_binding/mod.rs @@ -8,6 +8,8 @@ use std::fmt::Debug; pub enum Error { #[error(transparent)] ProofError(#[from] nimue::ProofError), + #[error("InvalidPcsParams")] + InvalidPcsParam, } pub trait PolynomialCommitmentScheme: Clone { @@ -25,9 +27,10 @@ pub trait PolynomialCommitmentScheme: Clone { transcript: &mut Self::Transcript, ) -> Result; - fn batch_commit( + fn batch_commit_and_write( pp: &Self::Param, polys: &[Self::Poly], + transcript: &mut Self::Transcript, ) -> Result; fn open( diff --git a/src/ceno_binding/pcs.rs b/src/ceno_binding/pcs.rs index 12d1496..09a0809 100644 --- a/src/ceno_binding/pcs.rs +++ b/src/ceno_binding/pcs.rs @@ -73,11 +73,24 @@ where Ok(witness) } - fn batch_commit( - _pp: &Self::Param, - _polys: &[Self::Poly], + fn batch_commit_and_write( + pp: &Self::Param, + polys: &[Self::Poly], + transcript: &mut Self::Transcript, ) -> Result { - todo!() + if polys.is_empty() { + return Err(Error::InvalidPcsParam); + } + + for i in 1..polys.len() { + if polys[i].num_vars() != polys[0].num_vars() { + return Err(Error::InvalidPcsParam); + } + } + + let committer = Committer::new(pp.clone()); + let witness = committer.batch_commit(transcript, polys)?; + Ok(witness) } fn open( diff --git a/src/poly_utils/coeffs.rs b/src/poly_utils/coeffs.rs index 5fbbea5..c42a114 100644 --- a/src/poly_utils/coeffs.rs +++ b/src/poly_utils/coeffs.rs @@ -29,6 +29,10 @@ impl CoefficientList where F: Field, { + pub fn num_vars(&self) -> usize { + self.num_variables + } + /// Evaluate the given polynomial at `point` from {0,1}^n pub fn evaluate_hypercube(&self, point: BinaryHypercubePoint) -> F { assert_eq!(self.coeffs.len(), 1 << self.num_variables); diff --git a/src/whir/committer.rs b/src/whir/committer.rs index 30d3484..c096e9c 100644 --- a/src/whir/committer.rs +++ b/src/whir/committer.rs @@ -35,7 +35,7 @@ where impl Committer where F: FftField, - MerkleConfig: Config + MerkleConfig: Config, { pub fn new(config: WhirConfig) -> Self { Self(config) @@ -111,4 +111,15 @@ where ood_answers, }) } + + pub fn batch_commit( + &self, + _merlin: &mut Merlin, + _polys: &[CoefficientList], + ) -> ProofResult> + where + Merlin: FieldWriter + FieldChallenges + ByteWriter + DigestWriter, + { + todo!() + } } From 93e54428402229c2777055d8b33bda544763182c Mon Sep 17 00:00:00 2001 From: Chao Ma Date: Wed, 11 Dec 2024 20:50:43 +0700 Subject: [PATCH 02/20] (wip) add batch_commit --- src/ceno_binding/pcs.rs | 8 +-- src/utils.rs | 6 +++ src/whir/committer.rs | 114 ++++++++++++++++++++++++++++++++++++++-- 3 files changed, 120 insertions(+), 8 deletions(-) diff --git a/src/ceno_binding/pcs.rs b/src/ceno_binding/pcs.rs index 09a0809..665399d 100644 --- a/src/ceno_binding/pcs.rs +++ b/src/ceno_binding/pcs.rs @@ -5,7 +5,7 @@ use crate::parameters::{ }; use crate::poly_utils::{coeffs::CoefficientList, MultilinearPoint}; use crate::whir::{ - committer::{Committer, Witness}, + committer::{Committer, Witnesses}, iopattern::WhirIOPattern, parameters::WhirConfig, prover::Prover, @@ -34,7 +34,7 @@ where E: FftField + CanonicalSerialize + CanonicalDeserialize, { type Param = WhirPCSConfig; - type CommitmentWithData = Witness>; + type CommitmentWithData = Witnesses>; type Proof = WhirProof, E>; // TODO: support both base and extension fields type Poly = CoefficientList; @@ -70,7 +70,7 @@ where ) -> Result { let committer = Committer::new(pp.clone()); let witness = committer.commit(transcript, poly.clone())?; - Ok(witness) + Ok(witness.into()) } fn batch_commit_and_write( @@ -106,7 +106,7 @@ where evaluations: vec![eval.clone()], }; - let proof = prover.prove(transcript, statement, witness)?; + let proof = prover.prove(transcript, statement, witness.into())?; Ok(proof) } diff --git a/src/utils.rs b/src/utils.rs index fc0104a..79b995f 100644 --- a/src/utils.rs +++ b/src/utils.rs @@ -83,6 +83,12 @@ pub fn stack_evaluations(mut evals: Vec, folding_factor: usize) -> evals } +/// Takes a vector of matrix and stacking them horizontally +/// Use in-place matrix transposes to avoid data copy +pub fn horizontal_stacking(mut _evals: Vec>, _folding_factor: usize) -> Vec { + todo!() +} + #[cfg(test)] mod tests { use crate::utils::base_decomposition; diff --git a/src/whir/committer.rs b/src/whir/committer.rs index c096e9c..c169750 100644 --- a/src/whir/committer.rs +++ b/src/whir/committer.rs @@ -27,6 +27,41 @@ where pub(crate) ood_answers: Vec, } +pub struct Witnesses +where + MerkleConfig: Config, +{ + pub(crate) polys: Vec>, + pub(crate) merkle_tree: MerkleTree, + pub(crate) merkle_leaves: Vec, + pub(crate) ood_points: Vec, + pub(crate) ood_answers: Vec>, +} + +impl From> for Witnesses { + fn from(witness: Witness) -> Self { + Self { + polys: vec![witness.polynomial], + merkle_tree: witness.merkle_tree, + merkle_leaves: witness.merkle_leaves, + ood_points: witness.ood_points, + ood_answers: vec![witness.ood_answers], + } + } +} + +impl From> for Witness { + fn from(witness: Witnesses) -> Self { + Self { + polynomial: witness.polys[0].clone(), + merkle_tree: witness.merkle_tree, + merkle_leaves: witness.merkle_leaves, + ood_points: witness.ood_points, + ood_answers: witness.ood_answers[0].clone(), + } + } +} + pub struct Committer(WhirConfig) where F: FftField, @@ -114,12 +149,83 @@ where pub fn batch_commit( &self, - _merlin: &mut Merlin, - _polys: &[CoefficientList], - ) -> ProofResult> + merlin: &mut Merlin, + polys: &[CoefficientList], + ) -> ProofResult> where Merlin: FieldWriter + FieldChallenges + ByteWriter + DigestWriter, { - todo!() + let base_domain = self.0.starting_domain.base_domain.unwrap(); + let expansion = base_domain.size() / polys[0].num_coeffs(); + let evals = polys + .iter() + .map(|poly| expand_from_coeff(poly.coeffs(), expansion)) + .collect::>>(); + + let folded_evals = evals + .into_iter() + .map(|evals| utils::stack_evaluations(evals, self.0.folding_factor)) + .map(|evals| { + restructure_evaluations( + evals, + self.0.fold_optimisation, + base_domain.group_gen(), + base_domain.group_gen_inv(), + self.0.folding_factor, + ) + }) + .map(|evals| { + evals + .into_iter() + .map(F::from_base_prime_field) + .collect::>() + }) + .collect::>>(); + let folded_evals = utils::horizontal_stacking(folded_evals, self.0.folding_factor); + + // Group folds together as a leaf. + let fold_size = 1 << self.0.folding_factor; + #[cfg(not(feature = "parallel"))] + let leafs_iter = folded_evals.chunks_exact(fold_size * base_domain.size()); + #[cfg(feature = "parallel")] + let leafs_iter = folded_evals.par_chunks_exact(fold_size * base_domain.size()); + + let merkle_tree = MerkleTree::::new( + &self.0.leaf_hash_params, + &self.0.two_to_one_params, + leafs_iter, + ) + .unwrap(); + + let root = merkle_tree.root(); + + merlin.add_digest(root)?; + + let mut ood_points = vec![F::ZERO; self.0.committment_ood_samples]; + let mut ood_answers = vec![Vec::with_capacity(self.0.committment_ood_samples); polys.len()]; + if self.0.committment_ood_samples > 0 { + for i in 0..polys.len() { + merlin.fill_challenge_scalars(&mut ood_points)?; + ood_answers[i].extend(ood_points.iter().map(|ood_point| { + polys[i].evaluate_at_extension(&MultilinearPoint::expand_from_univariate( + *ood_point, + self.0.mv_parameters.num_variables, + )) + })); + merlin.add_scalars(&ood_answers[i])?; + } + } + let polys = polys + .into_iter() + .map(|poly| poly.clone().to_extension()) + .collect::>(); + + Ok(Witnesses { + polys, + merkle_tree, + merkle_leaves: folded_evals, + ood_points, + ood_answers, + }) } } From b02ed3f2d10a041f675f43044d0340791a4bd29f Mon Sep 17 00:00:00 2001 From: Chao Ma Date: Thu, 12 Dec 2024 13:28:41 +0700 Subject: [PATCH 03/20] add util horizontal_stack --- src/utils.rs | 25 +++++++++++++++++++++++-- src/whir/committer.rs | 9 ++++++--- 2 files changed, 29 insertions(+), 5 deletions(-) diff --git a/src/utils.rs b/src/utils.rs index 79b995f..b175386 100644 --- a/src/utils.rs +++ b/src/utils.rs @@ -83,10 +83,31 @@ pub fn stack_evaluations(mut evals: Vec, folding_factor: usize) -> evals } +/// Takes the vector of evaluations (assume that evals[i] = f(omega^i)) +/// and folds them into a vector of such that folded_evals[i] = [f(omega^(i + k * j)) for j in 0..folding_factor] +/// This function will mutate the function without return +pub fn stack_evaluations_mut(evals: &mut [F], fold_size: usize) { + let size_of_new_domain = evals.len() / fold_size; + + // interpret evals as (folding_factor_exp x size_of_new_domain)-matrix and transpose in-place + transpose(evals, fold_size, size_of_new_domain); +} + /// Takes a vector of matrix and stacking them horizontally /// Use in-place matrix transposes to avoid data copy -pub fn horizontal_stacking(mut _evals: Vec>, _folding_factor: usize) -> Vec { - todo!() +pub fn horizontal_stacking( + evals: Vec, + domain_size: usize, + folding_factor: usize, +) -> Vec { + let fold_size = 1 << folding_factor; + let domain_size_log2 = domain_size.ilog2() as usize; + let num_polys: usize = evals.len() / domain_size; + + let mut evals = stack_evaluations(evals, domain_size_log2); + let stacked_evals = evals.chunks_exact_mut(fold_size * num_polys); + stacked_evals.for_each(|eval| stack_evaluations_mut(eval, num_polys)); + evals } #[cfg(test)] diff --git a/src/whir/committer.rs b/src/whir/committer.rs index c169750..2418205 100644 --- a/src/whir/committer.rs +++ b/src/whir/committer.rs @@ -162,6 +162,8 @@ where .map(|poly| expand_from_coeff(poly.coeffs(), expansion)) .collect::>>(); + assert_eq!(base_domain.size(), evals[0].len()); + let folded_evals = evals .into_iter() .map(|evals| utils::stack_evaluations(evals, self.0.folding_factor)) @@ -174,14 +176,15 @@ where self.0.folding_factor, ) }) - .map(|evals| { + .flat_map(|evals| { evals .into_iter() .map(F::from_base_prime_field) .collect::>() }) - .collect::>>(); - let folded_evals = utils::horizontal_stacking(folded_evals, self.0.folding_factor); + .collect::>(); + let folded_evals = + utils::horizontal_stacking(folded_evals, base_domain.size(), self.0.folding_factor); // Group folds together as a leaf. let fold_size = 1 << self.0.folding_factor; From f73718aec37a3f97998bf877f4d2cc0100a0c057 Mon Sep 17 00:00:00 2001 From: Chao Ma Date: Thu, 12 Dec 2024 14:34:03 +0700 Subject: [PATCH 04/20] add unit test; minor fixes --- src/utils.rs | 45 ++++++++++++++++++++++++++++++++++++------- src/whir/committer.rs | 6 +++--- 2 files changed, 41 insertions(+), 10 deletions(-) diff --git a/src/utils.rs b/src/utils.rs index b175386..8b1d873 100644 --- a/src/utils.rs +++ b/src/utils.rs @@ -1,5 +1,7 @@ use crate::ntt::transpose; use ark_ff::Field; +#[cfg(feature = "parallel")] +use rayon::prelude::*; use std::collections::BTreeSet; // checks whether the given number n is a power of two. @@ -86,11 +88,13 @@ pub fn stack_evaluations(mut evals: Vec, folding_factor: usize) -> /// Takes the vector of evaluations (assume that evals[i] = f(omega^i)) /// and folds them into a vector of such that folded_evals[i] = [f(omega^(i + k * j)) for j in 0..folding_factor] /// This function will mutate the function without return -pub fn stack_evaluations_mut(evals: &mut [F], fold_size: usize) { - let size_of_new_domain = evals.len() / fold_size; +pub fn stack_evaluations_mut(evals: &mut [F], folding_factor: usize) { + let folding_factor_exp = 1 << folding_factor; + assert!(evals.len() % folding_factor_exp == 0); + let size_of_new_domain = evals.len() / folding_factor_exp; // interpret evals as (folding_factor_exp x size_of_new_domain)-matrix and transpose in-place - transpose(evals, fold_size, size_of_new_domain); + transpose(evals, folding_factor_exp, size_of_new_domain); } /// Takes a vector of matrix and stacking them horizontally @@ -101,12 +105,15 @@ pub fn horizontal_stacking( folding_factor: usize, ) -> Vec { let fold_size = 1 << folding_factor; - let domain_size_log2 = domain_size.ilog2() as usize; let num_polys: usize = evals.len() / domain_size; + let num_polys_log2: usize = num_polys.ilog2() as usize; - let mut evals = stack_evaluations(evals, domain_size_log2); + let mut evals = stack_evaluations(evals, num_polys_log2); + #[cfg(not(feature = "parallel"))] let stacked_evals = evals.chunks_exact_mut(fold_size * num_polys); - stacked_evals.for_each(|eval| stack_evaluations_mut(eval, num_polys)); + #[cfg(feature = "parallel")] + let stacked_evals = evals.par_chunks_exact_mut(fold_size * num_polys); + stacked_evals.for_each(|eval| stack_evaluations_mut(eval, folding_factor)); evals } @@ -114,7 +121,31 @@ pub fn horizontal_stacking( mod tests { use crate::utils::base_decomposition; - use super::{is_power_of_two, stack_evaluations, to_binary}; + use super::{horizontal_stacking, is_power_of_two, stack_evaluations, to_binary}; + + #[test] + fn test_horizontal_stacking() { + use crate::crypto::fields::Field64 as F; + + let num = 256; + let domain_size = 128; + let folding_factor = 3; + let fold_size = 1 << folding_factor; + assert_eq!(num % fold_size, 0); + let evals: Vec<_> = (0..num as u64).map(F::from).collect(); + + let stacked = horizontal_stacking(evals, domain_size, folding_factor); + assert_eq!(stacked.len(), num); + + for (i, fold) in stacked.chunks_exact(fold_size).enumerate() { + assert_eq!(fold.len(), fold_size); + let offset = if i % 2 == 0 { 0 } else { domain_size }; + let row_id = i / 2; + for j in 0..fold_size { + assert_eq!(fold[j], F::from((offset + row_id * fold_size + j) as u64)); + } + } + } #[test] fn test_evaluations_stack() { diff --git a/src/whir/committer.rs b/src/whir/committer.rs index 2418205..2033dc4 100644 --- a/src/whir/committer.rs +++ b/src/whir/committer.rs @@ -189,9 +189,9 @@ where // Group folds together as a leaf. let fold_size = 1 << self.0.folding_factor; #[cfg(not(feature = "parallel"))] - let leafs_iter = folded_evals.chunks_exact(fold_size * base_domain.size()); + let leafs_iter = folded_evals.chunks_exact(fold_size * polys.len()); #[cfg(feature = "parallel")] - let leafs_iter = folded_evals.par_chunks_exact(fold_size * base_domain.size()); + let leafs_iter = folded_evals.par_chunks_exact(fold_size * polys.len()); let merkle_tree = MerkleTree::::new( &self.0.leaf_hash_params, @@ -207,8 +207,8 @@ where let mut ood_points = vec![F::ZERO; self.0.committment_ood_samples]; let mut ood_answers = vec![Vec::with_capacity(self.0.committment_ood_samples); polys.len()]; if self.0.committment_ood_samples > 0 { + merlin.fill_challenge_scalars(&mut ood_points)?; for i in 0..polys.len() { - merlin.fill_challenge_scalars(&mut ood_points)?; ood_answers[i].extend(ood_points.iter().map(|ood_point| { polys[i].evaluate_at_extension(&MultilinearPoint::expand_from_univariate( *ood_point, From 656bfd107707f80a276fcee3c5d00fa879eea41c Mon Sep 17 00:00:00 2001 From: Chao Ma Date: Thu, 12 Dec 2024 14:38:45 +0700 Subject: [PATCH 05/20] add comments --- src/ceno_binding/pcs.rs | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/src/ceno_binding/pcs.rs b/src/ceno_binding/pcs.rs index 665399d..464e4b7 100644 --- a/src/ceno_binding/pcs.rs +++ b/src/ceno_binding/pcs.rs @@ -73,6 +73,12 @@ where Ok(witness.into()) } + // Assumption: + // 1. there must be at least one polynomial + // 2. all polynomials are in base field + // (TODO: this assumption is from the whir implementation, + // if we are going to support extension field, need modify whir's implementation) + // 3. all polynomials must have the same number of variables fn batch_commit_and_write( pp: &Self::Param, polys: &[Self::Poly], From 5478a470853299e273bf863392757f232814e0af Mon Sep 17 00:00:00 2001 From: Chao Ma Date: Sat, 14 Dec 2024 13:53:05 +0700 Subject: [PATCH 06/20] rename commitment type --- src/ceno_binding/mod.rs | 10 +++++----- src/ceno_binding/pcs.rs | 10 +++++----- src/whir/committer.rs | 4 ++-- 3 files changed, 12 insertions(+), 12 deletions(-) diff --git a/src/ceno_binding/mod.rs b/src/ceno_binding/mod.rs index 589d3d8..008c395 100644 --- a/src/ceno_binding/mod.rs +++ b/src/ceno_binding/mod.rs @@ -14,7 +14,7 @@ pub enum Error { pub trait PolynomialCommitmentScheme: Clone { type Param: Clone; - type CommitmentWithData; + type CommitmentWithWitness; type Proof: Clone + CanonicalSerialize + CanonicalDeserialize; type Poly: Clone; type Transcript; @@ -25,17 +25,17 @@ pub trait PolynomialCommitmentScheme: Clone { pp: &Self::Param, poly: &Self::Poly, transcript: &mut Self::Transcript, - ) -> Result; + ) -> Result; fn batch_commit_and_write( pp: &Self::Param, polys: &[Self::Poly], transcript: &mut Self::Transcript, - ) -> Result; + ) -> Result; fn open( pp: &Self::Param, - comm: Self::CommitmentWithData, + comm: Self::CommitmentWithWitness, point: &[E], eval: &E, transcript: &mut Self::Transcript, @@ -48,7 +48,7 @@ pub trait PolynomialCommitmentScheme: Clone { fn batch_open( pp: &Self::Param, polys: &[Self::Poly], - comm: Self::CommitmentWithData, + comm: Self::CommitmentWithWitness, point: &[E], evals: &[E], transcript: &mut Self::Transcript, diff --git a/src/ceno_binding/pcs.rs b/src/ceno_binding/pcs.rs index 464e4b7..f146f5a 100644 --- a/src/ceno_binding/pcs.rs +++ b/src/ceno_binding/pcs.rs @@ -34,7 +34,7 @@ where E: FftField + CanonicalSerialize + CanonicalDeserialize, { type Param = WhirPCSConfig; - type CommitmentWithData = Witnesses>; + type CommitmentWithWitness = Witnesses>; type Proof = WhirProof, E>; // TODO: support both base and extension fields type Poly = CoefficientList; @@ -67,7 +67,7 @@ where pp: &Self::Param, poly: &Self::Poly, transcript: &mut Self::Transcript, - ) -> Result { + ) -> Result { let committer = Committer::new(pp.clone()); let witness = committer.commit(transcript, poly.clone())?; Ok(witness.into()) @@ -83,7 +83,7 @@ where pp: &Self::Param, polys: &[Self::Poly], transcript: &mut Self::Transcript, - ) -> Result { + ) -> Result { if polys.is_empty() { return Err(Error::InvalidPcsParam); } @@ -101,7 +101,7 @@ where fn open( pp: &Self::Param, - witness: Self::CommitmentWithData, + witness: Self::CommitmentWithWitness, point: &[E], eval: &E, transcript: &mut Self::Transcript, @@ -119,7 +119,7 @@ where fn batch_open( _pp: &Self::Param, _polys: &[Self::Poly], - _comm: Self::CommitmentWithData, + _comm: Self::CommitmentWithWitness, _point: &[E], _evals: &[E], _transcript: &mut Self::Transcript, diff --git a/src/whir/committer.rs b/src/whir/committer.rs index 2033dc4..f6ad8a6 100644 --- a/src/whir/committer.rs +++ b/src/whir/committer.rs @@ -183,8 +183,8 @@ where .collect::>() }) .collect::>(); - let folded_evals = - utils::horizontal_stacking(folded_evals, base_domain.size(), self.0.folding_factor); + //let folded_evals = + // utils::horizontal_stacking(folded_evals, base_domain.size(), self.0.folding_factor); // Group folds together as a leaf. let fold_size = 1 << self.0.folding_factor; From 59bc2d328f9a24330d1d1a84a492d914590ef05b Mon Sep 17 00:00:00 2001 From: Chao Ma Date: Sat, 14 Dec 2024 18:19:10 +0700 Subject: [PATCH 07/20] wip --- Cargo.toml | 1 + src/ceno_binding/mod.rs | 3 +- src/ceno_binding/pcs.rs | 18 ++--- src/utils.rs | 25 +++++++ src/whir/committer.rs | 33 +++++---- src/whir/prover.rs | 149 +++++++++++++++++++++++++++++++++++++--- 6 files changed, 197 insertions(+), 32 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index ab8c041..1690e71 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -29,6 +29,7 @@ rayon = { version = "1.10.0", optional = true } goldilocks = { git = "https://github.com/scroll-tech/ceno-Goldilocks" } thiserror = "1" +# itertools = { version = "0.13", default-features = false } [profile.release] debug = true diff --git a/src/ceno_binding/mod.rs b/src/ceno_binding/mod.rs index 008c395..62cbe15 100644 --- a/src/ceno_binding/mod.rs +++ b/src/ceno_binding/mod.rs @@ -45,9 +45,8 @@ pub trait PolynomialCommitmentScheme: Clone { /// 1. Open at one point /// 2. All the polynomials share the same commitment. /// 3. The point is already a random point generated by a sum-check. - fn batch_open( + fn simple_batch_open( pp: &Self::Param, - polys: &[Self::Poly], comm: Self::CommitmentWithWitness, point: &[E], evals: &[E], diff --git a/src/ceno_binding/pcs.rs b/src/ceno_binding/pcs.rs index f146f5a..6a31ecc 100644 --- a/src/ceno_binding/pcs.rs +++ b/src/ceno_binding/pcs.rs @@ -116,15 +116,17 @@ where Ok(proof) } - fn batch_open( - _pp: &Self::Param, - _polys: &[Self::Poly], - _comm: Self::CommitmentWithWitness, - _point: &[E], - _evals: &[E], - _transcript: &mut Self::Transcript, + fn simple_batch_open( + pp: &Self::Param, + witnesses: Self::CommitmentWithWitness, + point: &[E], + evals: &[E], + transcript: &mut Self::Transcript, ) -> Result { - todo!() + assert_eq!(witnesses.polys.len(), evals.len()); + let prover = Prover(pp.clone()); + let proof = prover.simple_batch_prove(transcript, point, evals, witnesses)?; + Ok(proof) } fn verify( diff --git a/src/utils.rs b/src/utils.rs index 8b1d873..fbfc542 100644 --- a/src/utils.rs +++ b/src/utils.rs @@ -1,5 +1,15 @@ use crate::ntt::transpose; +use crate::whir::fs_utils::DigestWriter; + +use ark_crypto_primitives::merkle_tree::Config; +use ark_ff::FftField; use ark_ff::Field; +use nimue::{ + plugins::ark::{FieldChallenges, FieldWriter}, + ByteChallenges, ByteWriter, +}; +use nimue_pow::PoWChallenge; + #[cfg(feature = "parallel")] use rayon::prelude::*; use std::collections::BTreeSet; @@ -117,6 +127,21 @@ pub fn horizontal_stacking( evals } +// generate a random vector for batching open/verify +pub fn generate_random_vector(_merlin: &mut Merlin, _size: usize) -> Vec +where + F: FftField, + MerkleConfig: Config, + Merlin: FieldChallenges + + FieldWriter + + ByteChallenges + + ByteWriter + + PoWChallenge + + DigestWriter, +{ + todo!() +} + #[cfg(test)] mod tests { use crate::utils::base_decomposition; diff --git a/src/whir/committer.rs b/src/whir/committer.rs index f6ad8a6..fd1364f 100644 --- a/src/whir/committer.rs +++ b/src/whir/committer.rs @@ -35,7 +35,7 @@ where pub(crate) merkle_tree: MerkleTree, pub(crate) merkle_leaves: Vec, pub(crate) ood_points: Vec, - pub(crate) ood_answers: Vec>, + pub(crate) ood_answers: Vec, } impl From> for Witnesses { @@ -45,7 +45,7 @@ impl From> for Witnesses From> for Witnes merkle_tree: witness.merkle_tree, merkle_leaves: witness.merkle_leaves, ood_points: witness.ood_points, - ood_answers: witness.ood_answers[0].clone(), + ood_answers: witness.ood_answers, } } } @@ -205,19 +205,26 @@ where merlin.add_digest(root)?; let mut ood_points = vec![F::ZERO; self.0.committment_ood_samples]; - let mut ood_answers = vec![Vec::with_capacity(self.0.committment_ood_samples); polys.len()]; + let mut ood_answers = vec![F::ZERO; polys.len() * self.0.committment_ood_samples]; if self.0.committment_ood_samples > 0 { merlin.fill_challenge_scalars(&mut ood_points)?; - for i in 0..polys.len() { - ood_answers[i].extend(ood_points.iter().map(|ood_point| { - polys[i].evaluate_at_extension(&MultilinearPoint::expand_from_univariate( - *ood_point, - self.0.mv_parameters.num_variables, - )) - })); - merlin.add_scalars(&ood_answers[i])?; - } + ood_points + .iter() + .enumerate() + .for_each(|(point_index, ood_point)| { + for j in 0..polys.len() { + let eval = polys[j].evaluate_at_extension( + &MultilinearPoint::expand_from_univariate( + *ood_point, + self.0.mv_parameters.num_variables, + ), + ); + ood_answers[point_index * polys.len() + j] = eval; + } + }); + merlin.add_scalars(&ood_points); } + let polys = polys .into_iter() .map(|poly| poly.clone().to_extension()) diff --git a/src/whir/prover.rs b/src/whir/prover.rs index 0661b31..c90baf0 100644 --- a/src/whir/prover.rs +++ b/src/whir/prover.rs @@ -1,4 +1,8 @@ -use super::{committer::Witness, parameters::WhirConfig, Statement, WhirProof}; +use super::{ + committer::{Witness, Witnesses}, + parameters::WhirConfig, + Statement, WhirProof, +}; use crate::{ domain::Domain, ntt::expand_from_coeff, @@ -11,9 +15,12 @@ use crate::{ sumcheck::prover_not_skipping::SumcheckProverNotSkipping, utils::{self, expand_randomness}, }; + use ark_crypto_primitives::merkle_tree::{Config, MerkleTree, MultiPath}; use ark_ff::FftField; use ark_poly::EvaluationDomain; +use ark_std::iterable::Iterable; +use itertools::zip_eq; use nimue::{ plugins::ark::{FieldChallenges, FieldWriter}, ByteChallenges, ByteWriter, ProofResult, @@ -65,6 +72,119 @@ where witness.polynomial.num_variables() == self.0.mv_parameters.num_variables } + fn validate_witnesses(&self, _witnesses: &Witnesses) -> bool { + todo!() + } + + /// batch open a single point for multiple polys + pub fn simple_batch_prove( + &self, + merlin: &mut Merlin, + point: &[F], + evals: &[F], + witness: Witnesses, + ) -> ProofResult> + where + Merlin: FieldChallenges + + FieldWriter + + ByteChallenges + + ByteWriter + + PoWChallenge + + DigestWriter, + { + assert!(self.validate_parameters()); + assert!(self.validate_witnesses(&witness)); + + let compute_dot_product = |evals: &[F], coeff: &[F]| -> F { + // Ensure lengths match and compute the dot product + zip_eq(evals, coeff) + .map(|(a, b)| a * b) // Element-wise multiplication + .sum() // Sum the products + }; + + let random_coeff = utils::generate_random_vector(merlin, witness.ood_answers.len()); + + let initial_claims: Vec<_> = witness + .ood_points + .into_iter() + .map(|ood_point| { + MultilinearPoint::expand_from_univariate( + ood_point, + self.0.mv_parameters.num_variables, + ) + }) + // point might not be of the form (z, z^2, z^4, ...) + .chain(std::iter::once(MultilinearPoint(point.to_vec()))) + .collect(); + + let ood_answers = witness + .ood_answers + .chunks_exact(witness.polys.len()) + .map(|answer| compute_dot_product(answer, &random_coeff)) + .collect::>(); + let eval = compute_dot_product(evals, &random_coeff); + + let initial_answers: Vec<_> = ood_answers + .into_iter() + .chain(std::iter::once(eval)) + .collect(); + + if !self.0.initial_statement { + assert!( + initial_answers.is_empty(), + "Can not have initial answers without initial statement" + ); + } + + // let combined_coeff = witness.polys.iter().map(|poly|{ + // compute_dot_product(poly) + // }) + + let mut sumcheck_prover = None; + let folding_randomness = if self.0.initial_statement { + let [combination_randomness_gen] = merlin.challenge_scalars()?; + let combination_randomness = + expand_randomness(combination_randomness_gen, initial_claims.len()); + + sumcheck_prover = Some(SumcheckProverNotSkipping::new( + witness.polynomial.clone(), + &initial_claims, + &combination_randomness, + &initial_answers, + )); + + sumcheck_prover + .as_mut() + .unwrap() + .compute_sumcheck_polynomials::( + merlin, + self.0.folding_factor, + self.0.starting_folding_pow_bits, + )? + } else { + let mut folding_randomness = vec![F::ZERO; self.0.folding_factor]; + merlin.fill_challenge_scalars(&mut folding_randomness)?; + + if self.0.starting_folding_pow_bits > 0. { + merlin.challenge_pow::(self.0.starting_folding_pow_bits)?; + } + MultilinearPoint(folding_randomness) + }; + + let round_state = RoundState { + domain: self.0.starting_domain.clone(), + round: 0, + sumcheck_prover, + folding_randomness, + coefficients: witness.polynomial, + prev_merkle: witness.merkle_tree, + prev_merkle_answers: witness.merkle_leaves, + merkle_proofs: vec![], + }; + + self.round(merlin, round_state) + } + pub fn prove( &self, merlin: &mut Merlin, @@ -72,7 +192,12 @@ where witness: Witness, ) -> ProofResult> where - Merlin: FieldChallenges + FieldWriter + ByteChallenges + ByteWriter + PoWChallenge + DigestWriter, + Merlin: FieldChallenges + + FieldWriter + + ByteChallenges + + ByteWriter + + PoWChallenge + + DigestWriter, { assert!(self.validate_parameters()); assert!(self.validate_statement(&statement)); @@ -153,7 +278,12 @@ where mut round_state: RoundState, ) -> ProofResult> where - Merlin: FieldChallenges + ByteChallenges + FieldWriter + ByteWriter + PoWChallenge + DigestWriter, + Merlin: FieldChallenges + + ByteChallenges + + FieldWriter + + ByteWriter + + PoWChallenge + + DigestWriter, { // Fold the coefficients let folded_coefficients = round_state @@ -175,7 +305,7 @@ where self.0.final_queries, merlin, )?; - + let merkle_proof = round_state .prev_merkle .generate_multi_proof(final_challenge_indexes.clone()) @@ -350,11 +480,12 @@ where ) }); - let folding_randomness = sumcheck_prover.compute_sumcheck_polynomials::( - merlin, - self.0.folding_factor, - round_params.folding_pow_bits, - )?; + let folding_randomness = sumcheck_prover + .compute_sumcheck_polynomials::( + merlin, + self.0.folding_factor, + round_params.folding_pow_bits, + )?; let round_state = RoundState { round: round_state.round + 1, From 8792e97fe852620d85efd8f30d22c1b01063e459 Mon Sep 17 00:00:00 2001 From: Chao Ma Date: Sat, 14 Dec 2024 18:43:17 +0700 Subject: [PATCH 08/20] continued --- Cargo.lock | 1 + Cargo.toml | 2 +- src/poly_utils/coeffs.rs | 33 +++++++++++++++++++++++++++++++++ src/whir/committer.rs | 2 +- src/whir/prover.rs | 11 ++++------- 5 files changed, 40 insertions(+), 9 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 048ab85..fb0e039 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1098,6 +1098,7 @@ dependencies = [ "clap", "derivative", "goldilocks", + "itertools 0.13.0", "lazy_static", "nimue", "nimue-pow", diff --git a/Cargo.toml b/Cargo.toml index 1690e71..39ca403 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -29,7 +29,7 @@ rayon = { version = "1.10.0", optional = true } goldilocks = { git = "https://github.com/scroll-tech/ceno-Goldilocks" } thiserror = "1" -# itertools = { version = "0.13", default-features = false } +itertools = { version = "0.13", default-features = false } [profile.release] debug = true diff --git a/src/poly_utils/coeffs.rs b/src/poly_utils/coeffs.rs index c42a114..04f7ed1 100644 --- a/src/poly_utils/coeffs.rs +++ b/src/poly_utils/coeffs.rs @@ -33,6 +33,39 @@ where self.num_variables } + fn coeff_at(&self, index: usize) -> F { + self.coeffs[index] + } + + pub fn combine(polys: Vec, coeffs: Vec) -> Self { + assert!(!polys.is_empty(), "Input polys cannot be empty"); + assert_eq!( + polys.len(), + coeffs.len(), + "Mismatch between polys and coeffs length" + ); + polys.iter().skip(1).for_each(|poly| { + assert_eq!( + poly.num_vars(), + polys[0].num_vars(), + "All polys must have the same number of variables" + ); + }); + + let mut combined_coeffs = vec![F::ZERO; polys[0].coeffs.len()]; + polys.iter().for_each(|poly| { + for i in 0..combined_coeffs.len() { + combined_coeffs[i] += poly.coeff_at(i); + } + }); + + // Create the new combined Poly + Self { + coeffs: combined_coeffs, + num_variables: polys[0].num_vars(), // Assuming all polys have the same number of variables + } + } + /// Evaluate the given polynomial at `point` from {0,1}^n pub fn evaluate_hypercube(&self, point: BinaryHypercubePoint) -> F { assert_eq!(self.coeffs.len(), 1 << self.num_variables); diff --git a/src/whir/committer.rs b/src/whir/committer.rs index fd1364f..8b3da64 100644 --- a/src/whir/committer.rs +++ b/src/whir/committer.rs @@ -222,7 +222,7 @@ where ood_answers[point_index * polys.len() + j] = eval; } }); - merlin.add_scalars(&ood_points); + merlin.add_scalars(&ood_points)?; } let polys = polys diff --git a/src/whir/prover.rs b/src/whir/prover.rs index c90baf0..335ce72 100644 --- a/src/whir/prover.rs +++ b/src/whir/prover.rs @@ -19,7 +19,6 @@ use crate::{ use ark_crypto_primitives::merkle_tree::{Config, MerkleTree, MultiPath}; use ark_ff::FftField; use ark_poly::EvaluationDomain; -use ark_std::iterable::Iterable; use itertools::zip_eq; use nimue::{ plugins::ark::{FieldChallenges, FieldWriter}, @@ -98,7 +97,7 @@ where let compute_dot_product = |evals: &[F], coeff: &[F]| -> F { // Ensure lengths match and compute the dot product zip_eq(evals, coeff) - .map(|(a, b)| a * b) // Element-wise multiplication + .map(|(a, b)| *a * *b) // Element-wise multiplication .sum() // Sum the products }; @@ -136,9 +135,7 @@ where ); } - // let combined_coeff = witness.polys.iter().map(|poly|{ - // compute_dot_product(poly) - // }) + let polynomial = CoefficientList::combine(witness.polys, random_coeff); let mut sumcheck_prover = None; let folding_randomness = if self.0.initial_statement { @@ -147,7 +144,7 @@ where expand_randomness(combination_randomness_gen, initial_claims.len()); sumcheck_prover = Some(SumcheckProverNotSkipping::new( - witness.polynomial.clone(), + polynomial.clone(), &initial_claims, &combination_randomness, &initial_answers, @@ -176,7 +173,7 @@ where round: 0, sumcheck_prover, folding_randomness, - coefficients: witness.polynomial, + coefficients: polynomial, prev_merkle: witness.merkle_tree, prev_merkle_answers: witness.merkle_leaves, merkle_proofs: vec![], From 59bb65a4798db983ce1a425f1e4bf9789be0b47d Mon Sep 17 00:00:00 2001 From: Chao Ma Date: Sat, 14 Dec 2024 19:28:48 +0700 Subject: [PATCH 09/20] add util gen_rand --- src/utils.rs | 13 +++++++++---- src/whir/prover.rs | 9 +-------- 2 files changed, 10 insertions(+), 12 deletions(-) diff --git a/src/utils.rs b/src/utils.rs index fbfc542..9338030 100644 --- a/src/utils.rs +++ b/src/utils.rs @@ -6,7 +6,7 @@ use ark_ff::FftField; use ark_ff::Field; use nimue::{ plugins::ark::{FieldChallenges, FieldWriter}, - ByteChallenges, ByteWriter, + ByteChallenges, ByteWriter, ProofResult, }; use nimue_pow::PoWChallenge; @@ -128,9 +128,12 @@ pub fn horizontal_stacking( } // generate a random vector for batching open/verify -pub fn generate_random_vector(_merlin: &mut Merlin, _size: usize) -> Vec +pub fn generate_random_vector( + merlin: &mut Merlin, + size: usize, +) -> ProofResult> where - F: FftField, + F: Field, MerkleConfig: Config, Merlin: FieldChallenges + FieldWriter @@ -139,7 +142,9 @@ where + PoWChallenge + DigestWriter, { - todo!() + let [gamma] = merlin.challenge_scalars()?; + let res = expand_randomness(gamma, size); + Ok(res) } #[cfg(test)] diff --git a/src/whir/prover.rs b/src/whir/prover.rs index 335ce72..9097b13 100644 --- a/src/whir/prover.rs +++ b/src/whir/prover.rs @@ -101,7 +101,7 @@ where .sum() // Sum the products }; - let random_coeff = utils::generate_random_vector(merlin, witness.ood_answers.len()); + let random_coeff = utils::generate_random_vector(merlin, witness.ood_answers.len())?; let initial_claims: Vec<_> = witness .ood_points @@ -128,13 +128,6 @@ where .chain(std::iter::once(eval)) .collect(); - if !self.0.initial_statement { - assert!( - initial_answers.is_empty(), - "Can not have initial answers without initial statement" - ); - } - let polynomial = CoefficientList::combine(witness.polys, random_coeff); let mut sumcheck_prover = None; From ed7b216bb39931db75339de1b819171f7eda880a Mon Sep 17 00:00:00 2001 From: Chao Ma Date: Sat, 14 Dec 2024 19:48:13 +0700 Subject: [PATCH 10/20] add validate_witnesses --- src/ceno_binding/pcs.rs | 2 +- src/poly_utils/coeffs.rs | 28 +++------------------------- src/whir/prover.rs | 20 ++++++++++++++++++-- 3 files changed, 22 insertions(+), 28 deletions(-) diff --git a/src/ceno_binding/pcs.rs b/src/ceno_binding/pcs.rs index 6a31ecc..54e0055 100644 --- a/src/ceno_binding/pcs.rs +++ b/src/ceno_binding/pcs.rs @@ -89,7 +89,7 @@ where } for i in 1..polys.len() { - if polys[i].num_vars() != polys[0].num_vars() { + if polys[i].num_variables() != polys[0].num_variables() { return Err(Error::InvalidPcsParam); } } diff --git a/src/poly_utils/coeffs.rs b/src/poly_utils/coeffs.rs index 04f7ed1..369a272 100644 --- a/src/poly_utils/coeffs.rs +++ b/src/poly_utils/coeffs.rs @@ -29,41 +29,19 @@ impl CoefficientList where F: Field, { - pub fn num_vars(&self) -> usize { - self.num_variables - } - fn coeff_at(&self, index: usize) -> F { self.coeffs[index] } pub fn combine(polys: Vec, coeffs: Vec) -> Self { - assert!(!polys.is_empty(), "Input polys cannot be empty"); - assert_eq!( - polys.len(), - coeffs.len(), - "Mismatch between polys and coeffs length" - ); - polys.iter().skip(1).for_each(|poly| { - assert_eq!( - poly.num_vars(), - polys[0].num_vars(), - "All polys must have the same number of variables" - ); - }); - let mut combined_coeffs = vec![F::ZERO; polys[0].coeffs.len()]; - polys.iter().for_each(|poly| { + polys.iter().enumerate().for_each(|(poly_index, poly)| { for i in 0..combined_coeffs.len() { - combined_coeffs[i] += poly.coeff_at(i); + combined_coeffs[i] += poly.coeff_at(i) * coeffs[poly_index]; } }); - // Create the new combined Poly - Self { - coeffs: combined_coeffs, - num_variables: polys[0].num_vars(), // Assuming all polys have the same number of variables - } + Self::new(combined_coeffs) } /// Evaluate the given polynomial at `point` from {0,1}^n diff --git a/src/whir/prover.rs b/src/whir/prover.rs index 9097b13..3c22ffd 100644 --- a/src/whir/prover.rs +++ b/src/whir/prover.rs @@ -19,6 +19,7 @@ use crate::{ use ark_crypto_primitives::merkle_tree::{Config, MerkleTree, MultiPath}; use ark_ff::FftField; use ark_poly::EvaluationDomain; +use ark_std::iterable::Iterable; use itertools::zip_eq; use nimue::{ plugins::ark::{FieldChallenges, FieldWriter}, @@ -71,8 +72,23 @@ where witness.polynomial.num_variables() == self.0.mv_parameters.num_variables } - fn validate_witnesses(&self, _witnesses: &Witnesses) -> bool { - todo!() + fn validate_witnesses(&self, witness: &Witnesses) -> bool { + assert_eq!( + witness.ood_points.len() * witness.polys.len(), + witness.ood_answers.len() + ); + if !self.0.initial_statement { + assert!(witness.ood_points.is_empty()); + } + assert!(!witness.polys.is_empty(), "Input polys cannot be empty"); + witness.polys.iter().skip(1).for_each(|poly| { + assert_eq!( + poly.num_variables(), + witness.polys[0].num_variables(), + "All polys must have the same number of variables" + ); + }); + witness.polys[0].num_variables() == self.0.mv_parameters.num_variables } /// batch open a single point for multiple polys From 1309776bbe0723bdb408509896e315745e7d0579 Mon Sep 17 00:00:00 2001 From: Chao Ma Date: Sun, 15 Dec 2024 16:51:39 +0700 Subject: [PATCH 11/20] batch verify (wip) --- src/bin/benchmark.rs | 2 +- src/bin/main.rs | 4 +- src/ceno_binding/mod.rs | 4 +- src/ceno_binding/pcs.rs | 30 ++++--- src/parameters.rs | 10 ++- src/whir/committer.rs | 2 +- src/whir/mod.rs | 4 +- src/whir/verifier.rs | 173 ++++++++++++++++++++++++++++++++++++++-- 8 files changed, 204 insertions(+), 25 deletions(-) diff --git a/src/bin/benchmark.rs b/src/bin/benchmark.rs index 9c6da26..f88b444 100644 --- a/src/bin/benchmark.rs +++ b/src/bin/benchmark.rs @@ -230,7 +230,7 @@ fn run_whir( let num_coeffs = 1 << num_variables; - let mv_params = MultivariateParameters::::new(num_variables); + let mv_params = MultivariateParameters::::new(num_variables, 1); let whir_params = WhirParameters:: { initial_statement: true, diff --git a/src/bin/main.rs b/src/bin/main.rs index cdba3d1..4b87f6a 100644 --- a/src/bin/main.rs +++ b/src/bin/main.rs @@ -229,7 +229,7 @@ fn run_whir_as_ldt( let num_coeffs = 1 << num_variables; - let mv_params = MultivariateParameters::::new(num_variables); + let mv_params = MultivariateParameters::::new(num_variables, 1); let whir_params = WhirParameters:: { initial_statement: false, @@ -337,7 +337,7 @@ fn run_whir_pcs( let num_coeffs = 1 << num_variables; - let mv_params = MultivariateParameters::::new(num_variables); + let mv_params = MultivariateParameters::::new(num_variables, 1); let whir_params = WhirParameters:: { initial_statement: true, diff --git a/src/ceno_binding/mod.rs b/src/ceno_binding/mod.rs index 62cbe15..abef8f2 100644 --- a/src/ceno_binding/mod.rs +++ b/src/ceno_binding/mod.rs @@ -19,7 +19,7 @@ pub trait PolynomialCommitmentScheme: Clone { type Poly: Clone; type Transcript; - fn setup(poly_size: usize) -> Self::Param; + fn setup(poly_size: usize, num_polys: usize) -> Self::Param; fn commit_and_write( pp: &Self::Param, @@ -61,7 +61,7 @@ pub trait PolynomialCommitmentScheme: Clone { transcript: &Self::Transcript, ) -> Result<(), Error>; - fn batch_verify( + fn simple_batch_verify( vp: &Self::Param, point: &[E], evals: &[E], diff --git a/src/ceno_binding/pcs.rs b/src/ceno_binding/pcs.rs index 54e0055..f1929ae 100644 --- a/src/ceno_binding/pcs.rs +++ b/src/ceno_binding/pcs.rs @@ -40,8 +40,8 @@ where type Poly = CoefficientList; type Transcript = Merlin; - fn setup(poly_size: usize) -> Self::Param { - let mv_params = MultivariateParameters::::new(poly_size); + fn setup(poly_size: usize, num_polys: usize) -> Self::Param { + let mv_params = MultivariateParameters::::new(poly_size, num_polys); let starting_rate = 1; let pow_bits = default_max_pow(poly_size, starting_rate); let mut rng = ChaCha8Rng::from_seed([0u8; 32]); @@ -155,14 +155,24 @@ where Ok(()) } - fn batch_verify( - _vp: &Self::Param, - _point: &[E], - _evals: &[E], - _proof: &Self::Proof, - _transcript: &mut Self::Transcript, + fn simple_batch_verify( + vp: &Self::Param, + point: &[E], + evals: &[E], + proof: &Self::Proof, + transcript: &mut Self::Transcript, ) -> Result<(), Error> { - todo!() + let reps = 1000; + let verifier = Verifier::new(vp.clone()); + let io = IOPattern::::new("🌪️") + .commit_statement(&vp) + .add_whir_proof(&vp); + + for _ in 0..reps { + let mut arthur = io.to_arthur(transcript.transcript()); + verifier.simple_batch_verify(&mut arthur, point, evals, proof)?; + } + Ok(()) } } @@ -178,7 +188,7 @@ mod tests { fn single_point_verify() { let poly_size = 10; let num_coeffs = 1 << poly_size; - let pp = Whir::::setup(poly_size); + let pp = Whir::::setup(poly_size, 1); let poly = CoefficientList::new( (0..num_coeffs) diff --git a/src/parameters.rs b/src/parameters.rs index 633670d..3877863 100644 --- a/src/parameters.rs +++ b/src/parameters.rs @@ -46,13 +46,15 @@ impl FromStr for SoundnessType { #[derive(Debug, Clone, Copy)] pub struct MultivariateParameters { pub(crate) num_variables: usize, + pub(crate) num_polys: usize, _field: PhantomData, } impl MultivariateParameters { - pub fn new(num_variables: usize) -> Self { + pub fn new(num_variables: usize, num_polys: usize) -> Self { Self { num_variables, + num_polys, _field: PhantomData, } } @@ -60,7 +62,11 @@ impl MultivariateParameters { impl Display for MultivariateParameters { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - write!(f, "Number of variables: {}", self.num_variables) + write!( + f, + "Number of polynomials: {}, Number of variables: {}", + self.num_polys, self.num_variables + ) } } diff --git a/src/whir/committer.rs b/src/whir/committer.rs index 8b3da64..7830a2d 100644 --- a/src/whir/committer.rs +++ b/src/whir/committer.rs @@ -222,7 +222,7 @@ where ood_answers[point_index * polys.len() + j] = eval; } }); - merlin.add_scalars(&ood_points)?; + merlin.add_scalars(&ood_answers)?; } let polys = polys diff --git a/src/whir/mod.rs b/src/whir/mod.rs index 7db951b..26aac36 100644 --- a/src/whir/mod.rs +++ b/src/whir/mod.rs @@ -4,11 +4,11 @@ use ark_serialize::{CanonicalDeserialize, CanonicalSerialize}; use crate::poly_utils::MultilinearPoint; pub mod committer; +pub mod fs_utils; pub mod iopattern; pub mod parameters; pub mod prover; pub mod verifier; -pub mod fs_utils; #[derive(Debug, Clone, Default)] pub struct Statement { @@ -67,7 +67,7 @@ mod tests { let mut rng = ark_std::test_rng(); let (leaf_hash_params, two_to_one_params) = merkle_tree::default_config::(&mut rng); - let mv_params = MultivariateParameters::::new(num_variables); + let mv_params = MultivariateParameters::::new(num_variables, 1); let whir_params = WhirParameters:: { initial_statement: true, diff --git a/src/whir/verifier.rs b/src/whir/verifier.rs index c92e1fe..31ade24 100644 --- a/src/whir/verifier.rs +++ b/src/whir/verifier.rs @@ -4,8 +4,8 @@ use ark_crypto_primitives::merkle_tree::Config; use ark_ff::FftField; use ark_poly::EvaluationDomain; use nimue::{ - plugins::ark::{FieldChallenges, FieldReader} - , ByteChallenges, ByteReader, ProofError, ProofResult, + plugins::ark::{FieldChallenges, FieldReader}, + ByteChallenges, ByteReader, ProofError, ProofResult, }; use nimue_pow::{self, PoWChallenge}; @@ -82,10 +82,11 @@ where where Arthur: ByteReader + FieldReader + FieldChallenges + DigestReader, { + let num_polys = self.params.mv_parameters.num_polys; let root = arthur.read_digest()?; let mut ood_points = vec![F::ZERO; self.params.committment_ood_samples]; - let mut ood_answers = vec![F::ZERO; self.params.committment_ood_samples]; + let mut ood_answers = vec![F::ZERO; self.params.committment_ood_samples * num_polys]; if self.params.committment_ood_samples > 0 { arthur.fill_challenge_scalars(&mut ood_points)?; arthur.fill_next_scalars(&mut ood_answers)?; @@ -106,7 +107,12 @@ where whir_proof: &WhirProof, ) -> ProofResult> where - Arthur: FieldReader + FieldChallenges + PoWChallenge + ByteReader + ByteChallenges + DigestReader, + Arthur: FieldReader + + FieldChallenges + + PoWChallenge + + ByteReader + + ByteChallenges + + DigestReader, { let mut sumcheck_rounds = Vec::new(); let mut folding_randomness: MultilinearPoint; @@ -463,6 +469,158 @@ where result } + pub fn simple_batch_verify( + &self, + arthur: &mut Arthur, + point: &[F], + evals: &[F], + whir_proof: &WhirProof, + ) -> ProofResult<()> + where + Arthur: FieldChallenges + + FieldReader + + ByteChallenges + + ByteReader + + PoWChallenge + + DigestReader, + { + /* + // We first do a pass in which we rederive all the FS challenges + // Then we will check the algebraic part (so to optimise inversions) + let parsed_commitment = self.parse_commitment(arthur)?; + let parsed = self.parse_proof(arthur, &parsed_commitment, statement, whir_proof)?; + + let computed_folds = self.compute_folds(&parsed); + + let mut prev: Option<(SumcheckPolynomial, F)> = None; + if let Some(round) = parsed.initial_sumcheck_rounds.first() { + // Check the first polynomial + let (mut prev_poly, mut randomness) = round.clone(); + if prev_poly.sum_over_hypercube() + != parsed_commitment + .ood_answers + .iter() + .copied() + .chain(statement.evaluations.clone()) + .zip(&parsed.initial_combination_randomness) + .map(|(ans, rand)| ans * rand) + .sum() + { + return Err(ProofError::InvalidProof); + } + + // Check the rest of the rounds + for (sumcheck_poly, new_randomness) in &parsed.initial_sumcheck_rounds[1..] { + if sumcheck_poly.sum_over_hypercube() + != prev_poly.evaluate_at_point(&randomness.into()) + { + return Err(ProofError::InvalidProof); + } + prev_poly = sumcheck_poly.clone(); + randomness = *new_randomness; + } + + prev = Some((prev_poly, randomness)); + } + + for (round, folds) in parsed.rounds.iter().zip(&computed_folds) { + let (sumcheck_poly, new_randomness) = &round.sumcheck_rounds[0].clone(); + + let values = round.ood_answers.iter().copied().chain(folds.clone()); + + let prev_eval = if let Some((prev_poly, randomness)) = prev { + prev_poly.evaluate_at_point(&randomness.into()) + } else { + F::ZERO + }; + let claimed_sum = prev_eval + + values + .zip(&round.combination_randomness) + .map(|(val, rand)| val * rand) + .sum::(); + + if sumcheck_poly.sum_over_hypercube() != claimed_sum { + return Err(ProofError::InvalidProof); + } + + prev = Some((sumcheck_poly.clone(), *new_randomness)); + + // Check the rest of the round + for (sumcheck_poly, new_randomness) in &round.sumcheck_rounds[1..] { + let (prev_poly, randomness) = prev.unwrap(); + if sumcheck_poly.sum_over_hypercube() + != prev_poly.evaluate_at_point(&randomness.into()) + { + return Err(ProofError::InvalidProof); + } + prev = Some((sumcheck_poly.clone(), *new_randomness)); + } + } + + // Check the foldings computed from the proof match the evaluations of the polynomial + let final_folds = &computed_folds[computed_folds.len() - 1]; + let final_evaluations = parsed + .final_coefficients + .evaluate_at_univariate(&parsed.final_randomness_points); + if !final_folds + .iter() + .zip(final_evaluations) + .all(|(&fold, eval)| fold == eval) + { + return Err(ProofError::InvalidProof); + } + + // Check the final sumchecks + if self.params.final_sumcheck_rounds > 0 { + let prev_sumcheck_poly_eval = if let Some((prev_poly, randomness)) = prev { + prev_poly.evaluate_at_point(&randomness.into()) + } else { + F::ZERO + }; + let (sumcheck_poly, new_randomness) = &parsed.final_sumcheck_rounds[0].clone(); + let claimed_sum = prev_sumcheck_poly_eval; + + if sumcheck_poly.sum_over_hypercube() != claimed_sum { + return Err(ProofError::InvalidProof); + } + + prev = Some((sumcheck_poly.clone(), *new_randomness)); + + // Check the rest of the round + for (sumcheck_poly, new_randomness) in &parsed.final_sumcheck_rounds[1..] { + let (prev_poly, randomness) = prev.unwrap(); + if sumcheck_poly.sum_over_hypercube() + != prev_poly.evaluate_at_point(&randomness.into()) + { + return Err(ProofError::InvalidProof); + } + prev = Some((sumcheck_poly.clone(), *new_randomness)); + } + } + + let prev_sumcheck_poly_eval = if let Some((prev_poly, randomness)) = prev { + prev_poly.evaluate_at_point(&randomness.into()) + } else { + F::ZERO + }; + + // Check the final sumcheck evaluation + let evaluation_of_v_poly = self.compute_v_poly(&parsed_commitment, statement, &parsed); + + if prev_sumcheck_poly_eval + != evaluation_of_v_poly + * parsed + .final_coefficients + .evaluate(&parsed.final_sumcheck_randomness) + { + return Err(ProofError::InvalidProof); + } + + Ok(()) + */ + todo!() + } + pub fn verify( &self, arthur: &mut Arthur, @@ -470,7 +628,12 @@ where whir_proof: &WhirProof, ) -> ProofResult<()> where - Arthur: FieldChallenges + FieldReader + ByteChallenges + ByteReader + PoWChallenge + DigestReader, + Arthur: FieldChallenges + + FieldReader + + ByteChallenges + + ByteReader + + PoWChallenge + + DigestReader, { // We first do a pass in which we rederive all the FS challenges // Then we will check the algebraic part (so to optimise inversions) From 49ade78644e2e2145097b6dcec8251738295d26c Mon Sep 17 00:00:00 2001 From: Chao Ma Date: Sun, 15 Dec 2024 17:31:44 +0700 Subject: [PATCH 12/20] modify parsed_commitment --- src/utils.rs | 33 +++++++++++++++++++++------------ src/whir/prover.rs | 4 ++-- src/whir/verifier.rs | 30 ++++++++++++++++++++++++++---- 3 files changed, 49 insertions(+), 18 deletions(-) diff --git a/src/utils.rs b/src/utils.rs index 9338030..a4e2c19 100644 --- a/src/utils.rs +++ b/src/utils.rs @@ -1,12 +1,11 @@ use crate::ntt::transpose; -use crate::whir::fs_utils::DigestWriter; +use crate::whir::fs_utils::{DigestReader, DigestWriter}; use ark_crypto_primitives::merkle_tree::Config; -use ark_ff::FftField; use ark_ff::Field; use nimue::{ - plugins::ark::{FieldChallenges, FieldWriter}, - ByteChallenges, ByteWriter, ProofResult, + plugins::ark::{FieldChallenges, FieldReader, FieldWriter}, + ByteChallenges, ByteReader, ByteWriter, ProofResult, }; use nimue_pow::PoWChallenge; @@ -127,26 +126,36 @@ pub fn horizontal_stacking( evals } -// generate a random vector for batching open/verify -pub fn generate_random_vector( +// generate a random vector for batching open +pub fn generate_random_vector_batch_open( merlin: &mut Merlin, size: usize, ) -> ProofResult> where F: Field, MerkleConfig: Config, - Merlin: FieldChallenges - + FieldWriter - + ByteChallenges - + ByteWriter - + PoWChallenge - + DigestWriter, + Merlin: FieldChallenges + FieldWriter + ByteWriter + DigestWriter, { let [gamma] = merlin.challenge_scalars()?; let res = expand_randomness(gamma, size); Ok(res) } +// generate a random vector for batching verify +pub fn generate_random_vector_batch_verify( + arthur: &mut Arthur, + size: usize, +) -> ProofResult> +where + F: Field, + MerkleConfig: Config, + Arthur: FieldChallenges + FieldReader + ByteReader + DigestReader, +{ + let [gamma] = arthur.challenge_scalars()?; + let res = expand_randomness(gamma, size); + Ok(res) +} + #[cfg(test)] mod tests { use crate::utils::base_decomposition; diff --git a/src/whir/prover.rs b/src/whir/prover.rs index 3c22ffd..0f21770 100644 --- a/src/whir/prover.rs +++ b/src/whir/prover.rs @@ -19,7 +19,6 @@ use crate::{ use ark_crypto_primitives::merkle_tree::{Config, MerkleTree, MultiPath}; use ark_ff::FftField; use ark_poly::EvaluationDomain; -use ark_std::iterable::Iterable; use itertools::zip_eq; use nimue::{ plugins::ark::{FieldChallenges, FieldWriter}, @@ -117,7 +116,8 @@ where .sum() // Sum the products }; - let random_coeff = utils::generate_random_vector(merlin, witness.ood_answers.len())?; + let random_coeff = + utils::generate_random_vector_batch_open(merlin, witness.ood_answers.len())?; let initial_claims: Vec<_> = witness .ood_points diff --git a/src/whir/verifier.rs b/src/whir/verifier.rs index 31ade24..40ac6af 100644 --- a/src/whir/verifier.rs +++ b/src/whir/verifier.rs @@ -1,3 +1,4 @@ +use itertools::zip_eq; use std::iter; use ark_crypto_primitives::merkle_tree::Config; @@ -10,13 +11,16 @@ use nimue::{ use nimue_pow::{self, PoWChallenge}; use super::{parameters::WhirConfig, Statement, WhirProof}; -use crate::whir::fs_utils::{get_challenge_stir_queries, DigestReader}; use crate::{ parameters::FoldType, poly_utils::{coeffs::CoefficientList, eq_poly_outside, fold::compute_fold, MultilinearPoint}, sumcheck::proof::SumcheckPolynomial, utils::expand_randomness, }; +use crate::{ + utils, + whir::fs_utils::{get_challenge_stir_queries, DigestReader}, +}; pub struct Verifier where @@ -82,16 +86,34 @@ where where Arthur: ByteReader + FieldReader + FieldChallenges + DigestReader, { - let num_polys = self.params.mv_parameters.num_polys; let root = arthur.read_digest()?; + let num_polys = self.params.mv_parameters.num_polys; + let size = self.params.committment_ood_samples * num_polys; + let mut ood_points = vec![F::ZERO; self.params.committment_ood_samples]; - let mut ood_answers = vec![F::ZERO; self.params.committment_ood_samples * num_polys]; + let mut raw_ood_answers = vec![F::ZERO; size]; if self.params.committment_ood_samples > 0 { arthur.fill_challenge_scalars(&mut ood_points)?; - arthur.fill_next_scalars(&mut ood_answers)?; + arthur.fill_next_scalars(&mut raw_ood_answers)?; } + let ood_answers = if num_polys > 1 { + let compute_dot_product = |evals: &[F], coeff: &[F]| -> F { + // Ensure lengths match and compute the dot product + zip_eq(evals, coeff) + .map(|(a, b)| *a * *b) // Element-wise multiplication + .sum() // Sum the products + }; + let random_coeff = utils::generate_random_vector_batch_verify(arthur, size)?; + raw_ood_answers + .chunks_exact(num_polys) + .map(|answer| compute_dot_product(answer, &random_coeff)) + .collect::>() + } else { + raw_ood_answers + }; + Ok(ParsedCommitment { root, ood_points, From b103403037e2d7de5dd38037468e72cbfd772a25 Mon Sep 17 00:00:00 2001 From: Chao Ma Date: Mon, 16 Dec 2024 15:19:49 +0700 Subject: [PATCH 13/20] simplify simple_batch_prove --- src/whir/prover.rs | 62 ++++++++++++++++++++++++---------------------- 1 file changed, 33 insertions(+), 29 deletions(-) diff --git a/src/whir/prover.rs b/src/whir/prover.rs index 0f21770..0f275d0 100644 --- a/src/whir/prover.rs +++ b/src/whir/prover.rs @@ -72,6 +72,11 @@ where } fn validate_witnesses(&self, witness: &Witnesses) -> bool { + assert_eq!( + witness.polys.len(), + self.0.mv_parameters.num_polys, + "number of polynomials not match" + ); assert_eq!( witness.ood_points.len() * witness.polys.len(), witness.ood_answers.len() @@ -106,8 +111,19 @@ where + PoWChallenge + DigestWriter, { + assert!(self.0.initial_statement, "must be true for pcs"); assert!(self.validate_parameters()); assert!(self.validate_witnesses(&witness)); + assert_eq!( + point.len(), + self.0.mv_parameters.num_variables, + "number of variables mismatch" + ); + assert_eq!( + evals.len(), + self.0.mv_parameters.num_polys, + "number of polynomials not equal number of evaluations" + ); let compute_dot_product = |evals: &[F], coeff: &[F]| -> F { // Ensure lengths match and compute the dot product @@ -128,7 +144,6 @@ where self.0.mv_parameters.num_variables, ) }) - // point might not be of the form (z, z^2, z^4, ...) .chain(std::iter::once(MultilinearPoint(point.to_vec()))) .collect(); @@ -146,36 +161,25 @@ where let polynomial = CoefficientList::combine(witness.polys, random_coeff); - let mut sumcheck_prover = None; - let folding_randomness = if self.0.initial_statement { - let [combination_randomness_gen] = merlin.challenge_scalars()?; - let combination_randomness = - expand_randomness(combination_randomness_gen, initial_claims.len()); - - sumcheck_prover = Some(SumcheckProverNotSkipping::new( - polynomial.clone(), - &initial_claims, - &combination_randomness, - &initial_answers, - )); + let [combination_randomness_gen] = merlin.challenge_scalars()?; + let combination_randomness = + expand_randomness(combination_randomness_gen, initial_claims.len()); - sumcheck_prover - .as_mut() - .unwrap() - .compute_sumcheck_polynomials::( - merlin, - self.0.folding_factor, - self.0.starting_folding_pow_bits, - )? - } else { - let mut folding_randomness = vec![F::ZERO; self.0.folding_factor]; - merlin.fill_challenge_scalars(&mut folding_randomness)?; + let mut sumcheck_prover = Some(SumcheckProverNotSkipping::new( + polynomial.clone(), + &initial_claims, + &combination_randomness, + &initial_answers, + )); - if self.0.starting_folding_pow_bits > 0. { - merlin.challenge_pow::(self.0.starting_folding_pow_bits)?; - } - MultilinearPoint(folding_randomness) - }; + let folding_randomness = sumcheck_prover + .as_mut() + .unwrap() + .compute_sumcheck_polynomials::( + merlin, + self.0.folding_factor, + self.0.starting_folding_pow_bits, + )?; let round_state = RoundState { domain: self.0.starting_domain.clone(), From 7d8342112f12b7b7c7095b0a0f519c9be52b74ad Mon Sep 17 00:00:00 2001 From: Chao Ma Date: Mon, 16 Dec 2024 16:02:42 +0700 Subject: [PATCH 14/20] add batch-verify --- src/whir/prover.rs | 11 ++------ src/whir/verifier.rs | 65 ++++++++++++++++++++++++++++---------------- 2 files changed, 45 insertions(+), 31 deletions(-) diff --git a/src/whir/prover.rs b/src/whir/prover.rs index 0f275d0..d09a726 100644 --- a/src/whir/prover.rs +++ b/src/whir/prover.rs @@ -125,15 +125,10 @@ where "number of polynomials not equal number of evaluations" ); - let compute_dot_product = |evals: &[F], coeff: &[F]| -> F { - // Ensure lengths match and compute the dot product - zip_eq(evals, coeff) - .map(|(a, b)| *a * *b) // Element-wise multiplication - .sum() // Sum the products - }; + let compute_dot_product = + |evals: &[F], coeff: &[F]| -> F { zip_eq(evals, coeff).map(|(a, b)| *a * *b).sum() }; - let random_coeff = - utils::generate_random_vector_batch_open(merlin, witness.ood_answers.len())?; + let random_coeff = utils::generate_random_vector_batch_open(merlin, witness.polys.len())?; let initial_claims: Vec<_> = witness .ood_points diff --git a/src/whir/verifier.rs b/src/whir/verifier.rs index 40ac6af..640ef68 100644 --- a/src/whir/verifier.rs +++ b/src/whir/verifier.rs @@ -92,28 +92,12 @@ where let size = self.params.committment_ood_samples * num_polys; let mut ood_points = vec![F::ZERO; self.params.committment_ood_samples]; - let mut raw_ood_answers = vec![F::ZERO; size]; + let mut ood_answers = vec![F::ZERO; size]; if self.params.committment_ood_samples > 0 { arthur.fill_challenge_scalars(&mut ood_points)?; - arthur.fill_next_scalars(&mut raw_ood_answers)?; + arthur.fill_next_scalars(&mut ood_answers)?; } - let ood_answers = if num_polys > 1 { - let compute_dot_product = |evals: &[F], coeff: &[F]| -> F { - // Ensure lengths match and compute the dot product - zip_eq(evals, coeff) - .map(|(a, b)| *a * *b) // Element-wise multiplication - .sum() // Sum the products - }; - let random_coeff = utils::generate_random_vector_batch_verify(arthur, size)?; - raw_ood_answers - .chunks_exact(num_polys) - .map(|answer| compute_dot_product(answer, &random_coeff)) - .collect::>() - } else { - raw_ood_answers - }; - Ok(ParsedCommitment { root, ood_points, @@ -506,11 +490,48 @@ where + PoWChallenge + DigestReader, { - /* // We first do a pass in which we rederive all the FS challenges // Then we will check the algebraic part (so to optimise inversions) let parsed_commitment = self.parse_commitment(arthur)?; - let parsed = self.parse_proof(arthur, &parsed_commitment, statement, whir_proof)?; + + // parse proof + let num_polys = self.params.mv_parameters.num_polys; + let compute_dot_product = + |evals: &[F], coeff: &[F]| -> F { zip_eq(evals, coeff).map(|(a, b)| *a * *b).sum() }; + + let random_coeff = utils::generate_random_vector_batch_verify(arthur, num_polys)?; + + let initial_claims: Vec<_> = parsed_commitment + .ood_points + .clone() + .into_iter() + .map(|ood_point| { + MultilinearPoint::expand_from_univariate( + ood_point, + self.params.mv_parameters.num_variables, + ) + }) + .chain(std::iter::once(MultilinearPoint(point.to_vec()))) + .collect(); + + let ood_answers = parsed_commitment + .ood_answers + .clone() + .chunks_exact(num_polys) + .map(|answer| compute_dot_product(answer, &random_coeff)) + .collect::>(); + let eval = compute_dot_product(evals, &random_coeff); + + let initial_answers: Vec<_> = ood_answers + .into_iter() + .chain(std::iter::once(eval)) + .collect(); + + let statement = Statement { + points: initial_claims, + evaluations: initial_answers, + }; + let parsed = self.parse_proof(arthur, &parsed_commitment, &statement, whir_proof)?; let computed_folds = self.compute_folds(&parsed); @@ -627,7 +648,7 @@ where }; // Check the final sumcheck evaluation - let evaluation_of_v_poly = self.compute_v_poly(&parsed_commitment, statement, &parsed); + let evaluation_of_v_poly = self.compute_v_poly(&parsed_commitment, &statement, &parsed); if prev_sumcheck_poly_eval != evaluation_of_v_poly @@ -639,8 +660,6 @@ where } Ok(()) - */ - todo!() } pub fn verify( From 6157e84b7262eab17eb477bd23cf318733a2b170 Mon Sep 17 00:00:00 2001 From: Chao Ma Date: Tue, 17 Dec 2024 09:28:34 +0700 Subject: [PATCH 15/20] add unit test; debug wip --- src/ceno_binding/mod.rs | 2 +- src/ceno_binding/pcs.rs | 40 ++++++++++++++++++++++++++++++++++++++-- 2 files changed, 39 insertions(+), 3 deletions(-) diff --git a/src/ceno_binding/mod.rs b/src/ceno_binding/mod.rs index abef8f2..53d35f3 100644 --- a/src/ceno_binding/mod.rs +++ b/src/ceno_binding/mod.rs @@ -66,6 +66,6 @@ pub trait PolynomialCommitmentScheme: Clone { point: &[E], evals: &[E], proof: &Self::Proof, - transcript: &mut Self::Transcript, + transcript: &Self::Transcript, ) -> Result<(), Error>; } diff --git a/src/ceno_binding/pcs.rs b/src/ceno_binding/pcs.rs index f1929ae..cf4e7fa 100644 --- a/src/ceno_binding/pcs.rs +++ b/src/ceno_binding/pcs.rs @@ -160,7 +160,7 @@ where point: &[E], evals: &[E], proof: &Self::Proof, - transcript: &mut Self::Transcript, + transcript: &Self::Transcript, ) -> Result<(), Error> { let reps = 1000; let verifier = Verifier::new(vp.clone()); @@ -185,7 +185,7 @@ mod tests { use crate::crypto::fields::Field64_2 as F; #[test] - fn single_point_verify() { + fn single_poly_verify() { let poly_size = 10; let num_coeffs = 1 << poly_size; let pp = Whir::::setup(poly_size, 1); @@ -210,4 +210,40 @@ mod tests { let proof = Whir::::open(&pp, witness, &point, &eval, &mut merlin).unwrap(); Whir::::verify(&pp, &point, &eval, &proof, &merlin).unwrap(); } + + #[test] + fn simple_batch_polys_verify() { + let poly_size = 10; + let num_coeffs = 1 << poly_size; + let num_polys = 1 << 3; + let pp = Whir::::setup(poly_size, num_polys); + + let mut polys = Vec::new(); + for _ in 0..num_polys { + let poly = CoefficientList::new( + (0..num_coeffs) + .map(::BasePrimeField::from) + .collect(), + ); + polys.push(poly); + } + + let io = IOPattern::::new("🌪️") + .commit_statement(&pp) + .add_whir_proof(&pp); + let mut merlin = io.to_merlin(); + + let witness = Whir::::batch_commit_and_write(&pp, &polys, &mut merlin).unwrap(); + + let mut rng = rand::thread_rng(); + let point: Vec = (0..poly_size).map(|_| F::from(rng.gen::())).collect(); + let evals = polys + .iter() + .map(|poly| poly.evaluate_at_extension(&MultilinearPoint(point.clone()))) + .collect::>(); + + let proof = + Whir::::simple_batch_open(&pp, witness, &point, &evals, &mut merlin).unwrap(); + Whir::::simple_batch_verify(&pp, &point, &evals, &proof, &merlin).unwrap(); + } } From 302122934e50493b3d1585f8e164671115021149 Mon Sep 17 00:00:00 2001 From: Chao Ma Date: Wed, 18 Dec 2024 10:23:14 +0700 Subject: [PATCH 16/20] fix transcript setup --- src/fs_utils.rs | 8 ++++---- src/whir/iopattern.rs | 8 ++++++-- 2 files changed, 10 insertions(+), 6 deletions(-) diff --git a/src/fs_utils.rs b/src/fs_utils.rs index 4a665dc..57bb3cc 100644 --- a/src/fs_utils.rs +++ b/src/fs_utils.rs @@ -2,7 +2,7 @@ use ark_ff::Field; use nimue::plugins::ark::FieldIOPattern; use nimue_pow::PoWIOPattern; pub trait OODIOPattern { - fn add_ood(self, num_samples: usize) -> Self; + fn add_ood(self, num_samples: usize, num_answers: usize) -> Self; } impl OODIOPattern for IOPattern @@ -10,10 +10,10 @@ where F: Field, IOPattern: FieldIOPattern, { - fn add_ood(self, num_samples: usize) -> Self { + fn add_ood(self, num_samples: usize, num_answers: usize) -> Self { if num_samples > 0 { self.challenge_scalars(num_samples, "ood_query") - .add_scalars(num_samples, "ood_ans") + .add_scalars(num_answers, "ood_ans") } else { self } @@ -24,7 +24,7 @@ pub trait WhirPoWIOPattern { fn pow(self, bits: f64) -> Self; } -impl WhirPoWIOPattern for IOPattern +impl WhirPoWIOPattern for IOPattern where IOPattern: PoWIOPattern, { diff --git a/src/whir/iopattern.rs b/src/whir/iopattern.rs index 649dda8..70fd665 100644 --- a/src/whir/iopattern.rs +++ b/src/whir/iopattern.rs @@ -41,7 +41,8 @@ where let mut this = self.add_digest("merkle_digest"); if params.committment_ood_samples > 0 { assert!(params.initial_statement); - this = this.add_ood(params.committment_ood_samples); + let num_answers = params.committment_ood_samples * params.mv_parameters.num_polys; + this = this.add_ood(params.committment_ood_samples, num_answers); } this } @@ -50,6 +51,9 @@ where mut self, params: &WhirConfig, ) -> Self { + if params.mv_parameters.num_polys > 1 { + self = self.challenge_scalars(1, "batch_poly_combination_randomness"); + } // TODO: Add statement if params.initial_statement { self = self @@ -67,7 +71,7 @@ where let domain_size_bytes = ((folded_domain_size * 2 - 1).ilog2() as usize + 7) / 8; self = self .add_digest("merkle_digest") - .add_ood(r.ood_samples) + .add_ood(r.ood_samples, r.ood_samples) .challenge_bytes(r.num_queries * domain_size_bytes, "stir_queries") .pow(r.pow_bits) .challenge_scalars(1, "combination_randomness") From 4fb4b2bd13c117587585dfcf867efe3f881e6897 Mon Sep 17 00:00:00 2001 From: Chao Ma Date: Wed, 18 Dec 2024 13:44:23 +0700 Subject: [PATCH 17/20] debugging --- src/utils.rs | 6 ++++-- src/whir/committer.rs | 4 ++-- src/whir/verifier.rs | 15 ++++++++++++--- 3 files changed, 18 insertions(+), 7 deletions(-) diff --git a/src/utils.rs b/src/utils.rs index a4e2c19..bf31ae0 100644 --- a/src/utils.rs +++ b/src/utils.rs @@ -108,6 +108,8 @@ pub fn stack_evaluations_mut(evals: &mut [F], folding_factor: usize) { /// Takes a vector of matrix and stacking them horizontally /// Use in-place matrix transposes to avoid data copy +/// each matrix has domain_size elements +/// each matrix has shape (*, 1<( evals: Vec, domain_size: usize, @@ -168,9 +170,9 @@ mod tests { let num = 256; let domain_size = 128; - let folding_factor = 3; + let folding_factor = 2; let fold_size = 1 << folding_factor; - assert_eq!(num % fold_size, 0); + assert_eq!(domain_size % fold_size, 0); let evals: Vec<_> = (0..num as u64).map(F::from).collect(); let stacked = horizontal_stacking(evals, domain_size, folding_factor); diff --git a/src/whir/committer.rs b/src/whir/committer.rs index 7830a2d..73b64c6 100644 --- a/src/whir/committer.rs +++ b/src/whir/committer.rs @@ -183,8 +183,8 @@ where .collect::>() }) .collect::>(); - //let folded_evals = - // utils::horizontal_stacking(folded_evals, base_domain.size(), self.0.folding_factor); + let folded_evals = + utils::horizontal_stacking(folded_evals, base_domain.size(), self.0.folding_factor); // Group folds together as a leaf. let fold_size = 1 << self.0.folding_factor; diff --git a/src/whir/verifier.rs b/src/whir/verifier.rs index 640ef68..bb0bb66 100644 --- a/src/whir/verifier.rs +++ b/src/whir/verifier.rs @@ -204,7 +204,10 @@ where .unwrap() || merkle_proof.leaf_indexes != stir_challenges_indexes { - return Err(ProofError::InvalidProof); + if r != 0 { + println!("hehe0, leafs={:?}", &answers.len(),); + return Err(ProofError::InvalidProof); + } } if round_params.pow_bits > 0. { @@ -549,7 +552,10 @@ where .map(|(ans, rand)| ans * rand) .sum() { - return Err(ProofError::InvalidProof); + if false { + println!("hehe1"); + return Err(ProofError::InvalidProof); + } } // Check the rest of the rounds @@ -656,7 +662,10 @@ where .final_coefficients .evaluate(&parsed.final_sumcheck_randomness) { - return Err(ProofError::InvalidProof); + if false { + println!("hehe2"); + return Err(ProofError::InvalidProof); + } } Ok(()) From 6bd396fa76b72e67219f57b8cf292bbe18da7802 Mon Sep 17 00:00:00 2001 From: Chao Ma Date: Wed, 18 Dec 2024 14:56:54 +0700 Subject: [PATCH 18/20] correct initial stir evaluation in prover --- src/whir/prover.rs | 27 +++++++++++++++++++++++++-- 1 file changed, 25 insertions(+), 2 deletions(-) diff --git a/src/whir/prover.rs b/src/whir/prover.rs index d09a726..6f5c2d8 100644 --- a/src/whir/prover.rs +++ b/src/whir/prover.rs @@ -154,7 +154,7 @@ where .chain(std::iter::once(eval)) .collect(); - let polynomial = CoefficientList::combine(witness.polys, random_coeff); + let polynomial = CoefficientList::combine(witness.polys, random_coeff.clone()); let [combination_randomness_gen] = merlin.challenge_scalars()?; let combination_randomness = @@ -185,6 +185,7 @@ where prev_merkle: witness.merkle_tree, prev_merkle_answers: witness.merkle_leaves, merkle_proofs: vec![], + batching_randomness: Some(random_coeff), }; self.round(merlin, round_state) @@ -272,6 +273,7 @@ where prev_merkle: witness.merkle_tree, prev_merkle_answers: witness.merkle_leaves, merkle_proofs: vec![], + batching_randomness: None, }; self.round(merlin, round_state) @@ -416,10 +418,29 @@ where .prev_merkle .generate_multi_proof(stir_challenges_indexes.clone()) .unwrap(); - let fold_size = 1 << self.0.folding_factor; + // leaves of first round is not combined yet + let fold_size = if round_state.batching_randomness.is_some() { + (1 << self.0.folding_factor) * self.0.mv_parameters.num_polys + } else { + 1 << self.0.folding_factor + }; let answers: Vec<_> = stir_challenges_indexes .iter() .map(|i| round_state.prev_merkle_answers[i * fold_size..(i + 1) * fold_size].to_vec()) + .map(|raw_answer| match &round_state.batching_randomness { + Some(random_coeff) => { + let chunk_size = 1 << self.0.folding_factor; + let num_polys = self.0.mv_parameters.num_polys; + let mut res = vec![F::ZERO; chunk_size]; + for i in 0..chunk_size { + for j in 0..num_polys { + res[i] += raw_answer[i + j * chunk_size] * random_coeff[j]; + } + } + res + } + _ => raw_answer, + }) .collect(); // Evaluate answers in the folding randomness. let mut stir_evaluations = ood_answers.clone(); @@ -501,6 +522,7 @@ where prev_merkle: merkle_tree, prev_merkle_answers: folded_evals, merkle_proofs: round_state.merkle_proofs, + batching_randomness: None, }; self.round(merlin, round_state) @@ -520,4 +542,5 @@ where prev_merkle: MerkleTree, prev_merkle_answers: Vec, merkle_proofs: Vec<(MultiPath, Vec>)>, + batching_randomness: Option>, } From 7f6ccb9e51f9c955d6e6c3bfb2773f07d7715ea2 Mon Sep 17 00:00:00 2001 From: Chao Ma Date: Wed, 18 Dec 2024 17:09:42 +0700 Subject: [PATCH 19/20] bug fixes; last remaining --- src/whir/prover.rs | 10 ++++++--- src/whir/verifier.rs | 53 ++++++++++++++++++++++++++++++-------------- 2 files changed, 43 insertions(+), 20 deletions(-) diff --git a/src/whir/prover.rs b/src/whir/prover.rs index 6f5c2d8..c3b7e81 100644 --- a/src/whir/prover.rs +++ b/src/whir/prover.rs @@ -424,9 +424,13 @@ where } else { 1 << self.0.folding_factor }; - let answers: Vec<_> = stir_challenges_indexes + + let raw_answers: Vec<_> = stir_challenges_indexes .iter() .map(|i| round_state.prev_merkle_answers[i * fold_size..(i + 1) * fold_size].to_vec()) + .collect(); + let answers: Vec<_> = raw_answers + .iter() .map(|raw_answer| match &round_state.batching_randomness { Some(random_coeff) => { let chunk_size = 1 << self.0.folding_factor; @@ -439,7 +443,7 @@ where } res } - _ => raw_answer, + _ => raw_answer.clone(), }) .collect(); // Evaluate answers in the folding randomness. @@ -474,7 +478,7 @@ where CoefficientList::new(answers.to_vec()).evaluate(&round_state.folding_randomness) })), } - round_state.merkle_proofs.push((merkle_proof, answers)); + round_state.merkle_proofs.push((merkle_proof, raw_answers)); // PoW if round_params.pow_bits > 0. { diff --git a/src/whir/verifier.rs b/src/whir/verifier.rs index bb0bb66..892b3f7 100644 --- a/src/whir/verifier.rs +++ b/src/whir/verifier.rs @@ -111,6 +111,7 @@ where parsed_commitment: &ParsedCommitment, statement: &Statement, // Will be needed later whir_proof: &WhirProof, + batched_randomness: Vec, // used in first round ) -> ProofResult> where Arthur: FieldReader @@ -204,12 +205,28 @@ where .unwrap() || merkle_proof.leaf_indexes != stir_challenges_indexes { - if r != 0 { - println!("hehe0, leafs={:?}", &answers.len(),); - return Err(ProofError::InvalidProof); - } + return Err(ProofError::InvalidProof); } + let combined_answers: Vec<_> = answers + .into_iter() + .map(|raw_answer| { + if batched_randomness.len() > 0 { + let chunk_size = 1 << self.params.folding_factor; + let num_polys = self.params.mv_parameters.num_polys; + let mut res = vec![F::ZERO; chunk_size]; + for i in 0..chunk_size { + for j in 0..num_polys { + res[i] += raw_answer[i + j * chunk_size] * batched_randomness[j]; + } + } + res + } else { + raw_answer.clone() + } + }) + .collect(); + if round_params.pow_bits > 0. { arthur.challenge_pow::(round_params.pow_bits)?; } @@ -242,7 +259,7 @@ where ood_answers, stir_challenges_indexes, stir_challenges_points, - stir_challenges_answers: answers.to_vec(), + stir_challenges_answers: combined_answers, combination_randomness, sumcheck_rounds, domain_gen_inv, @@ -534,7 +551,13 @@ where points: initial_claims, evaluations: initial_answers, }; - let parsed = self.parse_proof(arthur, &parsed_commitment, &statement, whir_proof)?; + let parsed = self.parse_proof( + arthur, + &parsed_commitment, + &statement, + whir_proof, + random_coeff, + )?; let computed_folds = self.compute_folds(&parsed); @@ -543,19 +566,15 @@ where // Check the first polynomial let (mut prev_poly, mut randomness) = round.clone(); if prev_poly.sum_over_hypercube() - != parsed_commitment - .ood_answers - .iter() - .copied() - .chain(statement.evaluations.clone()) + != statement + .evaluations + .clone() + .into_iter() .zip(&parsed.initial_combination_randomness) .map(|(ans, rand)| ans * rand) .sum() { - if false { - println!("hehe1"); - return Err(ProofError::InvalidProof); - } + return Err(ProofError::InvalidProof); } // Check the rest of the rounds @@ -663,7 +682,7 @@ where .evaluate(&parsed.final_sumcheck_randomness) { if false { - println!("hehe2"); + println!("hehe"); return Err(ProofError::InvalidProof); } } @@ -688,7 +707,7 @@ where // We first do a pass in which we rederive all the FS challenges // Then we will check the algebraic part (so to optimise inversions) let parsed_commitment = self.parse_commitment(arthur)?; - let parsed = self.parse_proof(arthur, &parsed_commitment, statement, whir_proof)?; + let parsed = self.parse_proof(arthur, &parsed_commitment, statement, whir_proof, vec![])?; let computed_folds = self.compute_folds(&parsed); From 84b4b5f0d6fcd46955e791a5c6fdc9859d1d590e Mon Sep 17 00:00:00 2001 From: Chao Ma Date: Wed, 18 Dec 2024 17:34:58 +0700 Subject: [PATCH 20/20] bug fixed --- src/whir/verifier.rs | 58 ++++++++++++++++++++++++++++++++++++++++---- 1 file changed, 53 insertions(+), 5 deletions(-) diff --git a/src/whir/verifier.rs b/src/whir/verifier.rs index 892b3f7..84d3ecc 100644 --- a/src/whir/verifier.rs +++ b/src/whir/verifier.rs @@ -341,6 +341,57 @@ where }) } + /// this is copied and modified from `fn compute_v_poly` + /// to avoid modify the original function for compatibility + fn compute_v_poly_for_batched(&self, statement: &Statement, proof: &ParsedProof) -> F { + let mut num_variables = self.params.mv_parameters.num_variables; + + let mut folding_randomness = MultilinearPoint( + iter::once(&proof.final_sumcheck_randomness.0) + .chain(iter::once(&proof.final_folding_randomness.0)) + .chain(proof.rounds.iter().rev().map(|r| &r.folding_randomness.0)) + .flatten() + .copied() + .collect(), + ); + + let mut value = statement + .points + .iter() + .zip(&proof.initial_combination_randomness) + .map(|(point, randomness)| *randomness * eq_poly_outside(&point, &folding_randomness)) + .sum(); + + for round_proof in &proof.rounds { + num_variables -= self.params.folding_factor; + folding_randomness = MultilinearPoint(folding_randomness.0[..num_variables].to_vec()); + + let ood_points = &round_proof.ood_points; + let stir_challenges_points = &round_proof.stir_challenges_points; + let stir_challenges: Vec<_> = ood_points + .iter() + .chain(stir_challenges_points) + .cloned() + .map(|univariate| { + MultilinearPoint::expand_from_univariate(univariate, num_variables) + // TODO: + // Maybe refactor outside + }) + .collect(); + + let sum_of_claims: F = stir_challenges + .into_iter() + .map(|point| eq_poly_outside(&point, &folding_randomness)) + .zip(&round_proof.combination_randomness) + .map(|(point, rand)| point * rand) + .sum(); + + value += sum_of_claims; + } + + value + } + fn compute_v_poly( &self, parsed_commitment: &ParsedCommitment, @@ -673,7 +724,7 @@ where }; // Check the final sumcheck evaluation - let evaluation_of_v_poly = self.compute_v_poly(&parsed_commitment, &statement, &parsed); + let evaluation_of_v_poly = self.compute_v_poly_for_batched(&statement, &parsed); if prev_sumcheck_poly_eval != evaluation_of_v_poly @@ -681,10 +732,7 @@ where .final_coefficients .evaluate(&parsed.final_sumcheck_randomness) { - if false { - println!("hehe"); - return Err(ProofError::InvalidProof); - } + return Err(ProofError::InvalidProof); } Ok(())