diff --git a/Cargo.lock b/Cargo.lock index 048ab85..fb0e039 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1098,6 +1098,7 @@ dependencies = [ "clap", "derivative", "goldilocks", + "itertools 0.13.0", "lazy_static", "nimue", "nimue-pow", diff --git a/Cargo.toml b/Cargo.toml index ab8c041..39ca403 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -29,6 +29,7 @@ rayon = { version = "1.10.0", optional = true } goldilocks = { git = "https://github.com/scroll-tech/ceno-Goldilocks" } thiserror = "1" +itertools = { version = "0.13", default-features = false } [profile.release] debug = true diff --git a/src/bin/benchmark.rs b/src/bin/benchmark.rs index 9c6da26..f88b444 100644 --- a/src/bin/benchmark.rs +++ b/src/bin/benchmark.rs @@ -230,7 +230,7 @@ fn run_whir( let num_coeffs = 1 << num_variables; - let mv_params = MultivariateParameters::::new(num_variables); + let mv_params = MultivariateParameters::::new(num_variables, 1); let whir_params = WhirParameters:: { initial_statement: true, diff --git a/src/bin/main.rs b/src/bin/main.rs index cdba3d1..4b87f6a 100644 --- a/src/bin/main.rs +++ b/src/bin/main.rs @@ -229,7 +229,7 @@ fn run_whir_as_ldt( let num_coeffs = 1 << num_variables; - let mv_params = MultivariateParameters::::new(num_variables); + let mv_params = MultivariateParameters::::new(num_variables, 1); let whir_params = WhirParameters:: { initial_statement: false, @@ -337,7 +337,7 @@ fn run_whir_pcs( let num_coeffs = 1 << num_variables; - let mv_params = MultivariateParameters::::new(num_variables); + let mv_params = MultivariateParameters::::new(num_variables, 1); let whir_params = WhirParameters:: { initial_statement: true, diff --git a/src/ceno_binding/mod.rs b/src/ceno_binding/mod.rs index 5d7e5f4..53d35f3 100644 --- a/src/ceno_binding/mod.rs +++ b/src/ceno_binding/mod.rs @@ -8,31 +8,34 @@ use std::fmt::Debug; pub enum Error { #[error(transparent)] ProofError(#[from] nimue::ProofError), + #[error("InvalidPcsParams")] + InvalidPcsParam, } pub trait PolynomialCommitmentScheme: Clone { type Param: Clone; - type CommitmentWithData; + type CommitmentWithWitness; type Proof: Clone + CanonicalSerialize + CanonicalDeserialize; type Poly: Clone; type Transcript; - fn setup(poly_size: usize) -> Self::Param; + fn setup(poly_size: usize, num_polys: usize) -> Self::Param; fn commit_and_write( pp: &Self::Param, poly: &Self::Poly, transcript: &mut Self::Transcript, - ) -> Result; + ) -> Result; - fn batch_commit( + fn batch_commit_and_write( pp: &Self::Param, polys: &[Self::Poly], - ) -> Result; + transcript: &mut Self::Transcript, + ) -> Result; fn open( pp: &Self::Param, - comm: Self::CommitmentWithData, + comm: Self::CommitmentWithWitness, point: &[E], eval: &E, transcript: &mut Self::Transcript, @@ -42,10 +45,9 @@ pub trait PolynomialCommitmentScheme: Clone { /// 1. Open at one point /// 2. All the polynomials share the same commitment. /// 3. The point is already a random point generated by a sum-check. - fn batch_open( + fn simple_batch_open( pp: &Self::Param, - polys: &[Self::Poly], - comm: Self::CommitmentWithData, + comm: Self::CommitmentWithWitness, point: &[E], evals: &[E], transcript: &mut Self::Transcript, @@ -59,11 +61,11 @@ pub trait PolynomialCommitmentScheme: Clone { transcript: &Self::Transcript, ) -> Result<(), Error>; - fn batch_verify( + fn simple_batch_verify( vp: &Self::Param, point: &[E], evals: &[E], proof: &Self::Proof, - transcript: &mut Self::Transcript, + transcript: &Self::Transcript, ) -> Result<(), Error>; } diff --git a/src/ceno_binding/pcs.rs b/src/ceno_binding/pcs.rs index 12d1496..cf4e7fa 100644 --- a/src/ceno_binding/pcs.rs +++ b/src/ceno_binding/pcs.rs @@ -5,7 +5,7 @@ use crate::parameters::{ }; use crate::poly_utils::{coeffs::CoefficientList, MultilinearPoint}; use crate::whir::{ - committer::{Committer, Witness}, + committer::{Committer, Witnesses}, iopattern::WhirIOPattern, parameters::WhirConfig, prover::Prover, @@ -34,14 +34,14 @@ where E: FftField + CanonicalSerialize + CanonicalDeserialize, { type Param = WhirPCSConfig; - type CommitmentWithData = Witness>; + type CommitmentWithWitness = Witnesses>; type Proof = WhirProof, E>; // TODO: support both base and extension fields type Poly = CoefficientList; type Transcript = Merlin; - fn setup(poly_size: usize) -> Self::Param { - let mv_params = MultivariateParameters::::new(poly_size); + fn setup(poly_size: usize, num_polys: usize) -> Self::Param { + let mv_params = MultivariateParameters::::new(poly_size, num_polys); let starting_rate = 1; let pow_bits = default_max_pow(poly_size, starting_rate); let mut rng = ChaCha8Rng::from_seed([0u8; 32]); @@ -67,22 +67,41 @@ where pp: &Self::Param, poly: &Self::Poly, transcript: &mut Self::Transcript, - ) -> Result { + ) -> Result { let committer = Committer::new(pp.clone()); let witness = committer.commit(transcript, poly.clone())?; - Ok(witness) + Ok(witness.into()) } - fn batch_commit( - _pp: &Self::Param, - _polys: &[Self::Poly], - ) -> Result { - todo!() + // Assumption: + // 1. there must be at least one polynomial + // 2. all polynomials are in base field + // (TODO: this assumption is from the whir implementation, + // if we are going to support extension field, need modify whir's implementation) + // 3. all polynomials must have the same number of variables + fn batch_commit_and_write( + pp: &Self::Param, + polys: &[Self::Poly], + transcript: &mut Self::Transcript, + ) -> Result { + if polys.is_empty() { + return Err(Error::InvalidPcsParam); + } + + for i in 1..polys.len() { + if polys[i].num_variables() != polys[0].num_variables() { + return Err(Error::InvalidPcsParam); + } + } + + let committer = Committer::new(pp.clone()); + let witness = committer.batch_commit(transcript, polys)?; + Ok(witness) } fn open( pp: &Self::Param, - witness: Self::CommitmentWithData, + witness: Self::CommitmentWithWitness, point: &[E], eval: &E, transcript: &mut Self::Transcript, @@ -93,19 +112,21 @@ where evaluations: vec![eval.clone()], }; - let proof = prover.prove(transcript, statement, witness)?; + let proof = prover.prove(transcript, statement, witness.into())?; Ok(proof) } - fn batch_open( - _pp: &Self::Param, - _polys: &[Self::Poly], - _comm: Self::CommitmentWithData, - _point: &[E], - _evals: &[E], - _transcript: &mut Self::Transcript, + fn simple_batch_open( + pp: &Self::Param, + witnesses: Self::CommitmentWithWitness, + point: &[E], + evals: &[E], + transcript: &mut Self::Transcript, ) -> Result { - todo!() + assert_eq!(witnesses.polys.len(), evals.len()); + let prover = Prover(pp.clone()); + let proof = prover.simple_batch_prove(transcript, point, evals, witnesses)?; + Ok(proof) } fn verify( @@ -134,14 +155,24 @@ where Ok(()) } - fn batch_verify( - _vp: &Self::Param, - _point: &[E], - _evals: &[E], - _proof: &Self::Proof, - _transcript: &mut Self::Transcript, + fn simple_batch_verify( + vp: &Self::Param, + point: &[E], + evals: &[E], + proof: &Self::Proof, + transcript: &Self::Transcript, ) -> Result<(), Error> { - todo!() + let reps = 1000; + let verifier = Verifier::new(vp.clone()); + let io = IOPattern::::new("🌪️") + .commit_statement(&vp) + .add_whir_proof(&vp); + + for _ in 0..reps { + let mut arthur = io.to_arthur(transcript.transcript()); + verifier.simple_batch_verify(&mut arthur, point, evals, proof)?; + } + Ok(()) } } @@ -154,10 +185,10 @@ mod tests { use crate::crypto::fields::Field64_2 as F; #[test] - fn single_point_verify() { + fn single_poly_verify() { let poly_size = 10; let num_coeffs = 1 << poly_size; - let pp = Whir::::setup(poly_size); + let pp = Whir::::setup(poly_size, 1); let poly = CoefficientList::new( (0..num_coeffs) @@ -179,4 +210,40 @@ mod tests { let proof = Whir::::open(&pp, witness, &point, &eval, &mut merlin).unwrap(); Whir::::verify(&pp, &point, &eval, &proof, &merlin).unwrap(); } + + #[test] + fn simple_batch_polys_verify() { + let poly_size = 10; + let num_coeffs = 1 << poly_size; + let num_polys = 1 << 3; + let pp = Whir::::setup(poly_size, num_polys); + + let mut polys = Vec::new(); + for _ in 0..num_polys { + let poly = CoefficientList::new( + (0..num_coeffs) + .map(::BasePrimeField::from) + .collect(), + ); + polys.push(poly); + } + + let io = IOPattern::::new("🌪️") + .commit_statement(&pp) + .add_whir_proof(&pp); + let mut merlin = io.to_merlin(); + + let witness = Whir::::batch_commit_and_write(&pp, &polys, &mut merlin).unwrap(); + + let mut rng = rand::thread_rng(); + let point: Vec = (0..poly_size).map(|_| F::from(rng.gen::())).collect(); + let evals = polys + .iter() + .map(|poly| poly.evaluate_at_extension(&MultilinearPoint(point.clone()))) + .collect::>(); + + let proof = + Whir::::simple_batch_open(&pp, witness, &point, &evals, &mut merlin).unwrap(); + Whir::::simple_batch_verify(&pp, &point, &evals, &proof, &merlin).unwrap(); + } } diff --git a/src/fs_utils.rs b/src/fs_utils.rs index 4a665dc..57bb3cc 100644 --- a/src/fs_utils.rs +++ b/src/fs_utils.rs @@ -2,7 +2,7 @@ use ark_ff::Field; use nimue::plugins::ark::FieldIOPattern; use nimue_pow::PoWIOPattern; pub trait OODIOPattern { - fn add_ood(self, num_samples: usize) -> Self; + fn add_ood(self, num_samples: usize, num_answers: usize) -> Self; } impl OODIOPattern for IOPattern @@ -10,10 +10,10 @@ where F: Field, IOPattern: FieldIOPattern, { - fn add_ood(self, num_samples: usize) -> Self { + fn add_ood(self, num_samples: usize, num_answers: usize) -> Self { if num_samples > 0 { self.challenge_scalars(num_samples, "ood_query") - .add_scalars(num_samples, "ood_ans") + .add_scalars(num_answers, "ood_ans") } else { self } @@ -24,7 +24,7 @@ pub trait WhirPoWIOPattern { fn pow(self, bits: f64) -> Self; } -impl WhirPoWIOPattern for IOPattern +impl WhirPoWIOPattern for IOPattern where IOPattern: PoWIOPattern, { diff --git a/src/parameters.rs b/src/parameters.rs index 633670d..3877863 100644 --- a/src/parameters.rs +++ b/src/parameters.rs @@ -46,13 +46,15 @@ impl FromStr for SoundnessType { #[derive(Debug, Clone, Copy)] pub struct MultivariateParameters { pub(crate) num_variables: usize, + pub(crate) num_polys: usize, _field: PhantomData, } impl MultivariateParameters { - pub fn new(num_variables: usize) -> Self { + pub fn new(num_variables: usize, num_polys: usize) -> Self { Self { num_variables, + num_polys, _field: PhantomData, } } @@ -60,7 +62,11 @@ impl MultivariateParameters { impl Display for MultivariateParameters { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - write!(f, "Number of variables: {}", self.num_variables) + write!( + f, + "Number of polynomials: {}, Number of variables: {}", + self.num_polys, self.num_variables + ) } } diff --git a/src/poly_utils/coeffs.rs b/src/poly_utils/coeffs.rs index 5fbbea5..369a272 100644 --- a/src/poly_utils/coeffs.rs +++ b/src/poly_utils/coeffs.rs @@ -29,6 +29,21 @@ impl CoefficientList where F: Field, { + fn coeff_at(&self, index: usize) -> F { + self.coeffs[index] + } + + pub fn combine(polys: Vec, coeffs: Vec) -> Self { + let mut combined_coeffs = vec![F::ZERO; polys[0].coeffs.len()]; + polys.iter().enumerate().for_each(|(poly_index, poly)| { + for i in 0..combined_coeffs.len() { + combined_coeffs[i] += poly.coeff_at(i) * coeffs[poly_index]; + } + }); + + Self::new(combined_coeffs) + } + /// Evaluate the given polynomial at `point` from {0,1}^n pub fn evaluate_hypercube(&self, point: BinaryHypercubePoint) -> F { assert_eq!(self.coeffs.len(), 1 << self.num_variables); diff --git a/src/utils.rs b/src/utils.rs index fc0104a..bf31ae0 100644 --- a/src/utils.rs +++ b/src/utils.rs @@ -1,5 +1,16 @@ use crate::ntt::transpose; +use crate::whir::fs_utils::{DigestReader, DigestWriter}; + +use ark_crypto_primitives::merkle_tree::Config; use ark_ff::Field; +use nimue::{ + plugins::ark::{FieldChallenges, FieldReader, FieldWriter}, + ByteChallenges, ByteReader, ByteWriter, ProofResult, +}; +use nimue_pow::PoWChallenge; + +#[cfg(feature = "parallel")] +use rayon::prelude::*; use std::collections::BTreeSet; // checks whether the given number n is a power of two. @@ -83,11 +94,99 @@ pub fn stack_evaluations(mut evals: Vec, folding_factor: usize) -> evals } +/// Takes the vector of evaluations (assume that evals[i] = f(omega^i)) +/// and folds them into a vector of such that folded_evals[i] = [f(omega^(i + k * j)) for j in 0..folding_factor] +/// This function will mutate the function without return +pub fn stack_evaluations_mut(evals: &mut [F], folding_factor: usize) { + let folding_factor_exp = 1 << folding_factor; + assert!(evals.len() % folding_factor_exp == 0); + let size_of_new_domain = evals.len() / folding_factor_exp; + + // interpret evals as (folding_factor_exp x size_of_new_domain)-matrix and transpose in-place + transpose(evals, folding_factor_exp, size_of_new_domain); +} + +/// Takes a vector of matrix and stacking them horizontally +/// Use in-place matrix transposes to avoid data copy +/// each matrix has domain_size elements +/// each matrix has shape (*, 1<( + evals: Vec, + domain_size: usize, + folding_factor: usize, +) -> Vec { + let fold_size = 1 << folding_factor; + let num_polys: usize = evals.len() / domain_size; + let num_polys_log2: usize = num_polys.ilog2() as usize; + + let mut evals = stack_evaluations(evals, num_polys_log2); + #[cfg(not(feature = "parallel"))] + let stacked_evals = evals.chunks_exact_mut(fold_size * num_polys); + #[cfg(feature = "parallel")] + let stacked_evals = evals.par_chunks_exact_mut(fold_size * num_polys); + stacked_evals.for_each(|eval| stack_evaluations_mut(eval, folding_factor)); + evals +} + +// generate a random vector for batching open +pub fn generate_random_vector_batch_open( + merlin: &mut Merlin, + size: usize, +) -> ProofResult> +where + F: Field, + MerkleConfig: Config, + Merlin: FieldChallenges + FieldWriter + ByteWriter + DigestWriter, +{ + let [gamma] = merlin.challenge_scalars()?; + let res = expand_randomness(gamma, size); + Ok(res) +} + +// generate a random vector for batching verify +pub fn generate_random_vector_batch_verify( + arthur: &mut Arthur, + size: usize, +) -> ProofResult> +where + F: Field, + MerkleConfig: Config, + Arthur: FieldChallenges + FieldReader + ByteReader + DigestReader, +{ + let [gamma] = arthur.challenge_scalars()?; + let res = expand_randomness(gamma, size); + Ok(res) +} + #[cfg(test)] mod tests { use crate::utils::base_decomposition; - use super::{is_power_of_two, stack_evaluations, to_binary}; + use super::{horizontal_stacking, is_power_of_two, stack_evaluations, to_binary}; + + #[test] + fn test_horizontal_stacking() { + use crate::crypto::fields::Field64 as F; + + let num = 256; + let domain_size = 128; + let folding_factor = 2; + let fold_size = 1 << folding_factor; + assert_eq!(domain_size % fold_size, 0); + let evals: Vec<_> = (0..num as u64).map(F::from).collect(); + + let stacked = horizontal_stacking(evals, domain_size, folding_factor); + assert_eq!(stacked.len(), num); + + for (i, fold) in stacked.chunks_exact(fold_size).enumerate() { + assert_eq!(fold.len(), fold_size); + let offset = if i % 2 == 0 { 0 } else { domain_size }; + let row_id = i / 2; + for j in 0..fold_size { + assert_eq!(fold[j], F::from((offset + row_id * fold_size + j) as u64)); + } + } + } #[test] fn test_evaluations_stack() { diff --git a/src/whir/committer.rs b/src/whir/committer.rs index 30d3484..73b64c6 100644 --- a/src/whir/committer.rs +++ b/src/whir/committer.rs @@ -27,6 +27,41 @@ where pub(crate) ood_answers: Vec, } +pub struct Witnesses +where + MerkleConfig: Config, +{ + pub(crate) polys: Vec>, + pub(crate) merkle_tree: MerkleTree, + pub(crate) merkle_leaves: Vec, + pub(crate) ood_points: Vec, + pub(crate) ood_answers: Vec, +} + +impl From> for Witnesses { + fn from(witness: Witness) -> Self { + Self { + polys: vec![witness.polynomial], + merkle_tree: witness.merkle_tree, + merkle_leaves: witness.merkle_leaves, + ood_points: witness.ood_points, + ood_answers: witness.ood_answers, + } + } +} + +impl From> for Witness { + fn from(witness: Witnesses) -> Self { + Self { + polynomial: witness.polys[0].clone(), + merkle_tree: witness.merkle_tree, + merkle_leaves: witness.merkle_leaves, + ood_points: witness.ood_points, + ood_answers: witness.ood_answers, + } + } +} + pub struct Committer(WhirConfig) where F: FftField, @@ -35,7 +70,7 @@ where impl Committer where F: FftField, - MerkleConfig: Config + MerkleConfig: Config, { pub fn new(config: WhirConfig) -> Self { Self(config) @@ -111,4 +146,96 @@ where ood_answers, }) } + + pub fn batch_commit( + &self, + merlin: &mut Merlin, + polys: &[CoefficientList], + ) -> ProofResult> + where + Merlin: FieldWriter + FieldChallenges + ByteWriter + DigestWriter, + { + let base_domain = self.0.starting_domain.base_domain.unwrap(); + let expansion = base_domain.size() / polys[0].num_coeffs(); + let evals = polys + .iter() + .map(|poly| expand_from_coeff(poly.coeffs(), expansion)) + .collect::>>(); + + assert_eq!(base_domain.size(), evals[0].len()); + + let folded_evals = evals + .into_iter() + .map(|evals| utils::stack_evaluations(evals, self.0.folding_factor)) + .map(|evals| { + restructure_evaluations( + evals, + self.0.fold_optimisation, + base_domain.group_gen(), + base_domain.group_gen_inv(), + self.0.folding_factor, + ) + }) + .flat_map(|evals| { + evals + .into_iter() + .map(F::from_base_prime_field) + .collect::>() + }) + .collect::>(); + let folded_evals = + utils::horizontal_stacking(folded_evals, base_domain.size(), self.0.folding_factor); + + // Group folds together as a leaf. + let fold_size = 1 << self.0.folding_factor; + #[cfg(not(feature = "parallel"))] + let leafs_iter = folded_evals.chunks_exact(fold_size * polys.len()); + #[cfg(feature = "parallel")] + let leafs_iter = folded_evals.par_chunks_exact(fold_size * polys.len()); + + let merkle_tree = MerkleTree::::new( + &self.0.leaf_hash_params, + &self.0.two_to_one_params, + leafs_iter, + ) + .unwrap(); + + let root = merkle_tree.root(); + + merlin.add_digest(root)?; + + let mut ood_points = vec![F::ZERO; self.0.committment_ood_samples]; + let mut ood_answers = vec![F::ZERO; polys.len() * self.0.committment_ood_samples]; + if self.0.committment_ood_samples > 0 { + merlin.fill_challenge_scalars(&mut ood_points)?; + ood_points + .iter() + .enumerate() + .for_each(|(point_index, ood_point)| { + for j in 0..polys.len() { + let eval = polys[j].evaluate_at_extension( + &MultilinearPoint::expand_from_univariate( + *ood_point, + self.0.mv_parameters.num_variables, + ), + ); + ood_answers[point_index * polys.len() + j] = eval; + } + }); + merlin.add_scalars(&ood_answers)?; + } + + let polys = polys + .into_iter() + .map(|poly| poly.clone().to_extension()) + .collect::>(); + + Ok(Witnesses { + polys, + merkle_tree, + merkle_leaves: folded_evals, + ood_points, + ood_answers, + }) + } } diff --git a/src/whir/iopattern.rs b/src/whir/iopattern.rs index 649dda8..70fd665 100644 --- a/src/whir/iopattern.rs +++ b/src/whir/iopattern.rs @@ -41,7 +41,8 @@ where let mut this = self.add_digest("merkle_digest"); if params.committment_ood_samples > 0 { assert!(params.initial_statement); - this = this.add_ood(params.committment_ood_samples); + let num_answers = params.committment_ood_samples * params.mv_parameters.num_polys; + this = this.add_ood(params.committment_ood_samples, num_answers); } this } @@ -50,6 +51,9 @@ where mut self, params: &WhirConfig, ) -> Self { + if params.mv_parameters.num_polys > 1 { + self = self.challenge_scalars(1, "batch_poly_combination_randomness"); + } // TODO: Add statement if params.initial_statement { self = self @@ -67,7 +71,7 @@ where let domain_size_bytes = ((folded_domain_size * 2 - 1).ilog2() as usize + 7) / 8; self = self .add_digest("merkle_digest") - .add_ood(r.ood_samples) + .add_ood(r.ood_samples, r.ood_samples) .challenge_bytes(r.num_queries * domain_size_bytes, "stir_queries") .pow(r.pow_bits) .challenge_scalars(1, "combination_randomness") diff --git a/src/whir/mod.rs b/src/whir/mod.rs index 7db951b..26aac36 100644 --- a/src/whir/mod.rs +++ b/src/whir/mod.rs @@ -4,11 +4,11 @@ use ark_serialize::{CanonicalDeserialize, CanonicalSerialize}; use crate::poly_utils::MultilinearPoint; pub mod committer; +pub mod fs_utils; pub mod iopattern; pub mod parameters; pub mod prover; pub mod verifier; -pub mod fs_utils; #[derive(Debug, Clone, Default)] pub struct Statement { @@ -67,7 +67,7 @@ mod tests { let mut rng = ark_std::test_rng(); let (leaf_hash_params, two_to_one_params) = merkle_tree::default_config::(&mut rng); - let mv_params = MultivariateParameters::::new(num_variables); + let mv_params = MultivariateParameters::::new(num_variables, 1); let whir_params = WhirParameters:: { initial_statement: true, diff --git a/src/whir/prover.rs b/src/whir/prover.rs index 0661b31..c3b7e81 100644 --- a/src/whir/prover.rs +++ b/src/whir/prover.rs @@ -1,4 +1,8 @@ -use super::{committer::Witness, parameters::WhirConfig, Statement, WhirProof}; +use super::{ + committer::{Witness, Witnesses}, + parameters::WhirConfig, + Statement, WhirProof, +}; use crate::{ domain::Domain, ntt::expand_from_coeff, @@ -11,9 +15,11 @@ use crate::{ sumcheck::prover_not_skipping::SumcheckProverNotSkipping, utils::{self, expand_randomness}, }; + use ark_crypto_primitives::merkle_tree::{Config, MerkleTree, MultiPath}; use ark_ff::FftField; use ark_poly::EvaluationDomain; +use itertools::zip_eq; use nimue::{ plugins::ark::{FieldChallenges, FieldWriter}, ByteChallenges, ByteWriter, ProofResult, @@ -65,6 +71,126 @@ where witness.polynomial.num_variables() == self.0.mv_parameters.num_variables } + fn validate_witnesses(&self, witness: &Witnesses) -> bool { + assert_eq!( + witness.polys.len(), + self.0.mv_parameters.num_polys, + "number of polynomials not match" + ); + assert_eq!( + witness.ood_points.len() * witness.polys.len(), + witness.ood_answers.len() + ); + if !self.0.initial_statement { + assert!(witness.ood_points.is_empty()); + } + assert!(!witness.polys.is_empty(), "Input polys cannot be empty"); + witness.polys.iter().skip(1).for_each(|poly| { + assert_eq!( + poly.num_variables(), + witness.polys[0].num_variables(), + "All polys must have the same number of variables" + ); + }); + witness.polys[0].num_variables() == self.0.mv_parameters.num_variables + } + + /// batch open a single point for multiple polys + pub fn simple_batch_prove( + &self, + merlin: &mut Merlin, + point: &[F], + evals: &[F], + witness: Witnesses, + ) -> ProofResult> + where + Merlin: FieldChallenges + + FieldWriter + + ByteChallenges + + ByteWriter + + PoWChallenge + + DigestWriter, + { + assert!(self.0.initial_statement, "must be true for pcs"); + assert!(self.validate_parameters()); + assert!(self.validate_witnesses(&witness)); + assert_eq!( + point.len(), + self.0.mv_parameters.num_variables, + "number of variables mismatch" + ); + assert_eq!( + evals.len(), + self.0.mv_parameters.num_polys, + "number of polynomials not equal number of evaluations" + ); + + let compute_dot_product = + |evals: &[F], coeff: &[F]| -> F { zip_eq(evals, coeff).map(|(a, b)| *a * *b).sum() }; + + let random_coeff = utils::generate_random_vector_batch_open(merlin, witness.polys.len())?; + + let initial_claims: Vec<_> = witness + .ood_points + .into_iter() + .map(|ood_point| { + MultilinearPoint::expand_from_univariate( + ood_point, + self.0.mv_parameters.num_variables, + ) + }) + .chain(std::iter::once(MultilinearPoint(point.to_vec()))) + .collect(); + + let ood_answers = witness + .ood_answers + .chunks_exact(witness.polys.len()) + .map(|answer| compute_dot_product(answer, &random_coeff)) + .collect::>(); + let eval = compute_dot_product(evals, &random_coeff); + + let initial_answers: Vec<_> = ood_answers + .into_iter() + .chain(std::iter::once(eval)) + .collect(); + + let polynomial = CoefficientList::combine(witness.polys, random_coeff.clone()); + + let [combination_randomness_gen] = merlin.challenge_scalars()?; + let combination_randomness = + expand_randomness(combination_randomness_gen, initial_claims.len()); + + let mut sumcheck_prover = Some(SumcheckProverNotSkipping::new( + polynomial.clone(), + &initial_claims, + &combination_randomness, + &initial_answers, + )); + + let folding_randomness = sumcheck_prover + .as_mut() + .unwrap() + .compute_sumcheck_polynomials::( + merlin, + self.0.folding_factor, + self.0.starting_folding_pow_bits, + )?; + + let round_state = RoundState { + domain: self.0.starting_domain.clone(), + round: 0, + sumcheck_prover, + folding_randomness, + coefficients: polynomial, + prev_merkle: witness.merkle_tree, + prev_merkle_answers: witness.merkle_leaves, + merkle_proofs: vec![], + batching_randomness: Some(random_coeff), + }; + + self.round(merlin, round_state) + } + pub fn prove( &self, merlin: &mut Merlin, @@ -72,7 +198,12 @@ where witness: Witness, ) -> ProofResult> where - Merlin: FieldChallenges + FieldWriter + ByteChallenges + ByteWriter + PoWChallenge + DigestWriter, + Merlin: FieldChallenges + + FieldWriter + + ByteChallenges + + ByteWriter + + PoWChallenge + + DigestWriter, { assert!(self.validate_parameters()); assert!(self.validate_statement(&statement)); @@ -142,6 +273,7 @@ where prev_merkle: witness.merkle_tree, prev_merkle_answers: witness.merkle_leaves, merkle_proofs: vec![], + batching_randomness: None, }; self.round(merlin, round_state) @@ -153,7 +285,12 @@ where mut round_state: RoundState, ) -> ProofResult> where - Merlin: FieldChallenges + ByteChallenges + FieldWriter + ByteWriter + PoWChallenge + DigestWriter, + Merlin: FieldChallenges + + ByteChallenges + + FieldWriter + + ByteWriter + + PoWChallenge + + DigestWriter, { // Fold the coefficients let folded_coefficients = round_state @@ -175,7 +312,7 @@ where self.0.final_queries, merlin, )?; - + let merkle_proof = round_state .prev_merkle .generate_multi_proof(final_challenge_indexes.clone()) @@ -281,11 +418,34 @@ where .prev_merkle .generate_multi_proof(stir_challenges_indexes.clone()) .unwrap(); - let fold_size = 1 << self.0.folding_factor; - let answers: Vec<_> = stir_challenges_indexes + // leaves of first round is not combined yet + let fold_size = if round_state.batching_randomness.is_some() { + (1 << self.0.folding_factor) * self.0.mv_parameters.num_polys + } else { + 1 << self.0.folding_factor + }; + + let raw_answers: Vec<_> = stir_challenges_indexes .iter() .map(|i| round_state.prev_merkle_answers[i * fold_size..(i + 1) * fold_size].to_vec()) .collect(); + let answers: Vec<_> = raw_answers + .iter() + .map(|raw_answer| match &round_state.batching_randomness { + Some(random_coeff) => { + let chunk_size = 1 << self.0.folding_factor; + let num_polys = self.0.mv_parameters.num_polys; + let mut res = vec![F::ZERO; chunk_size]; + for i in 0..chunk_size { + for j in 0..num_polys { + res[i] += raw_answer[i + j * chunk_size] * random_coeff[j]; + } + } + res + } + _ => raw_answer.clone(), + }) + .collect(); // Evaluate answers in the folding randomness. let mut stir_evaluations = ood_answers.clone(); match self.0.fold_optimisation { @@ -318,7 +478,7 @@ where CoefficientList::new(answers.to_vec()).evaluate(&round_state.folding_randomness) })), } - round_state.merkle_proofs.push((merkle_proof, answers)); + round_state.merkle_proofs.push((merkle_proof, raw_answers)); // PoW if round_params.pow_bits > 0. { @@ -350,11 +510,12 @@ where ) }); - let folding_randomness = sumcheck_prover.compute_sumcheck_polynomials::( - merlin, - self.0.folding_factor, - round_params.folding_pow_bits, - )?; + let folding_randomness = sumcheck_prover + .compute_sumcheck_polynomials::( + merlin, + self.0.folding_factor, + round_params.folding_pow_bits, + )?; let round_state = RoundState { round: round_state.round + 1, @@ -365,6 +526,7 @@ where prev_merkle: merkle_tree, prev_merkle_answers: folded_evals, merkle_proofs: round_state.merkle_proofs, + batching_randomness: None, }; self.round(merlin, round_state) @@ -384,4 +546,5 @@ where prev_merkle: MerkleTree, prev_merkle_answers: Vec, merkle_proofs: Vec<(MultiPath, Vec>)>, + batching_randomness: Option>, } diff --git a/src/whir/verifier.rs b/src/whir/verifier.rs index c92e1fe..84d3ecc 100644 --- a/src/whir/verifier.rs +++ b/src/whir/verifier.rs @@ -1,22 +1,26 @@ +use itertools::zip_eq; use std::iter; use ark_crypto_primitives::merkle_tree::Config; use ark_ff::FftField; use ark_poly::EvaluationDomain; use nimue::{ - plugins::ark::{FieldChallenges, FieldReader} - , ByteChallenges, ByteReader, ProofError, ProofResult, + plugins::ark::{FieldChallenges, FieldReader}, + ByteChallenges, ByteReader, ProofError, ProofResult, }; use nimue_pow::{self, PoWChallenge}; use super::{parameters::WhirConfig, Statement, WhirProof}; -use crate::whir::fs_utils::{get_challenge_stir_queries, DigestReader}; use crate::{ parameters::FoldType, poly_utils::{coeffs::CoefficientList, eq_poly_outside, fold::compute_fold, MultilinearPoint}, sumcheck::proof::SumcheckPolynomial, utils::expand_randomness, }; +use crate::{ + utils, + whir::fs_utils::{get_challenge_stir_queries, DigestReader}, +}; pub struct Verifier where @@ -84,8 +88,11 @@ where { let root = arthur.read_digest()?; + let num_polys = self.params.mv_parameters.num_polys; + let size = self.params.committment_ood_samples * num_polys; + let mut ood_points = vec![F::ZERO; self.params.committment_ood_samples]; - let mut ood_answers = vec![F::ZERO; self.params.committment_ood_samples]; + let mut ood_answers = vec![F::ZERO; size]; if self.params.committment_ood_samples > 0 { arthur.fill_challenge_scalars(&mut ood_points)?; arthur.fill_next_scalars(&mut ood_answers)?; @@ -104,9 +111,15 @@ where parsed_commitment: &ParsedCommitment, statement: &Statement, // Will be needed later whir_proof: &WhirProof, + batched_randomness: Vec, // used in first round ) -> ProofResult> where - Arthur: FieldReader + FieldChallenges + PoWChallenge + ByteReader + ByteChallenges + DigestReader, + Arthur: FieldReader + + FieldChallenges + + PoWChallenge + + ByteReader + + ByteChallenges + + DigestReader, { let mut sumcheck_rounds = Vec::new(); let mut folding_randomness: MultilinearPoint; @@ -195,6 +208,25 @@ where return Err(ProofError::InvalidProof); } + let combined_answers: Vec<_> = answers + .into_iter() + .map(|raw_answer| { + if batched_randomness.len() > 0 { + let chunk_size = 1 << self.params.folding_factor; + let num_polys = self.params.mv_parameters.num_polys; + let mut res = vec![F::ZERO; chunk_size]; + for i in 0..chunk_size { + for j in 0..num_polys { + res[i] += raw_answer[i + j * chunk_size] * batched_randomness[j]; + } + } + res + } else { + raw_answer.clone() + } + }) + .collect(); + if round_params.pow_bits > 0. { arthur.challenge_pow::(round_params.pow_bits)?; } @@ -227,7 +259,7 @@ where ood_answers, stir_challenges_indexes, stir_challenges_points, - stir_challenges_answers: answers.to_vec(), + stir_challenges_answers: combined_answers, combination_randomness, sumcheck_rounds, domain_gen_inv, @@ -309,6 +341,57 @@ where }) } + /// this is copied and modified from `fn compute_v_poly` + /// to avoid modify the original function for compatibility + fn compute_v_poly_for_batched(&self, statement: &Statement, proof: &ParsedProof) -> F { + let mut num_variables = self.params.mv_parameters.num_variables; + + let mut folding_randomness = MultilinearPoint( + iter::once(&proof.final_sumcheck_randomness.0) + .chain(iter::once(&proof.final_folding_randomness.0)) + .chain(proof.rounds.iter().rev().map(|r| &r.folding_randomness.0)) + .flatten() + .copied() + .collect(), + ); + + let mut value = statement + .points + .iter() + .zip(&proof.initial_combination_randomness) + .map(|(point, randomness)| *randomness * eq_poly_outside(&point, &folding_randomness)) + .sum(); + + for round_proof in &proof.rounds { + num_variables -= self.params.folding_factor; + folding_randomness = MultilinearPoint(folding_randomness.0[..num_variables].to_vec()); + + let ood_points = &round_proof.ood_points; + let stir_challenges_points = &round_proof.stir_challenges_points; + let stir_challenges: Vec<_> = ood_points + .iter() + .chain(stir_challenges_points) + .cloned() + .map(|univariate| { + MultilinearPoint::expand_from_univariate(univariate, num_variables) + // TODO: + // Maybe refactor outside + }) + .collect(); + + let sum_of_claims: F = stir_challenges + .into_iter() + .map(|point| eq_poly_outside(&point, &folding_randomness)) + .zip(&round_proof.combination_randomness) + .map(|(point, rand)| point * rand) + .sum(); + + value += sum_of_claims; + } + + value + } + fn compute_v_poly( &self, parsed_commitment: &ParsedCommitment, @@ -463,6 +546,198 @@ where result } + pub fn simple_batch_verify( + &self, + arthur: &mut Arthur, + point: &[F], + evals: &[F], + whir_proof: &WhirProof, + ) -> ProofResult<()> + where + Arthur: FieldChallenges + + FieldReader + + ByteChallenges + + ByteReader + + PoWChallenge + + DigestReader, + { + // We first do a pass in which we rederive all the FS challenges + // Then we will check the algebraic part (so to optimise inversions) + let parsed_commitment = self.parse_commitment(arthur)?; + + // parse proof + let num_polys = self.params.mv_parameters.num_polys; + let compute_dot_product = + |evals: &[F], coeff: &[F]| -> F { zip_eq(evals, coeff).map(|(a, b)| *a * *b).sum() }; + + let random_coeff = utils::generate_random_vector_batch_verify(arthur, num_polys)?; + + let initial_claims: Vec<_> = parsed_commitment + .ood_points + .clone() + .into_iter() + .map(|ood_point| { + MultilinearPoint::expand_from_univariate( + ood_point, + self.params.mv_parameters.num_variables, + ) + }) + .chain(std::iter::once(MultilinearPoint(point.to_vec()))) + .collect(); + + let ood_answers = parsed_commitment + .ood_answers + .clone() + .chunks_exact(num_polys) + .map(|answer| compute_dot_product(answer, &random_coeff)) + .collect::>(); + let eval = compute_dot_product(evals, &random_coeff); + + let initial_answers: Vec<_> = ood_answers + .into_iter() + .chain(std::iter::once(eval)) + .collect(); + + let statement = Statement { + points: initial_claims, + evaluations: initial_answers, + }; + let parsed = self.parse_proof( + arthur, + &parsed_commitment, + &statement, + whir_proof, + random_coeff, + )?; + + let computed_folds = self.compute_folds(&parsed); + + let mut prev: Option<(SumcheckPolynomial, F)> = None; + if let Some(round) = parsed.initial_sumcheck_rounds.first() { + // Check the first polynomial + let (mut prev_poly, mut randomness) = round.clone(); + if prev_poly.sum_over_hypercube() + != statement + .evaluations + .clone() + .into_iter() + .zip(&parsed.initial_combination_randomness) + .map(|(ans, rand)| ans * rand) + .sum() + { + return Err(ProofError::InvalidProof); + } + + // Check the rest of the rounds + for (sumcheck_poly, new_randomness) in &parsed.initial_sumcheck_rounds[1..] { + if sumcheck_poly.sum_over_hypercube() + != prev_poly.evaluate_at_point(&randomness.into()) + { + return Err(ProofError::InvalidProof); + } + prev_poly = sumcheck_poly.clone(); + randomness = *new_randomness; + } + + prev = Some((prev_poly, randomness)); + } + + for (round, folds) in parsed.rounds.iter().zip(&computed_folds) { + let (sumcheck_poly, new_randomness) = &round.sumcheck_rounds[0].clone(); + + let values = round.ood_answers.iter().copied().chain(folds.clone()); + + let prev_eval = if let Some((prev_poly, randomness)) = prev { + prev_poly.evaluate_at_point(&randomness.into()) + } else { + F::ZERO + }; + let claimed_sum = prev_eval + + values + .zip(&round.combination_randomness) + .map(|(val, rand)| val * rand) + .sum::(); + + if sumcheck_poly.sum_over_hypercube() != claimed_sum { + return Err(ProofError::InvalidProof); + } + + prev = Some((sumcheck_poly.clone(), *new_randomness)); + + // Check the rest of the round + for (sumcheck_poly, new_randomness) in &round.sumcheck_rounds[1..] { + let (prev_poly, randomness) = prev.unwrap(); + if sumcheck_poly.sum_over_hypercube() + != prev_poly.evaluate_at_point(&randomness.into()) + { + return Err(ProofError::InvalidProof); + } + prev = Some((sumcheck_poly.clone(), *new_randomness)); + } + } + + // Check the foldings computed from the proof match the evaluations of the polynomial + let final_folds = &computed_folds[computed_folds.len() - 1]; + let final_evaluations = parsed + .final_coefficients + .evaluate_at_univariate(&parsed.final_randomness_points); + if !final_folds + .iter() + .zip(final_evaluations) + .all(|(&fold, eval)| fold == eval) + { + return Err(ProofError::InvalidProof); + } + + // Check the final sumchecks + if self.params.final_sumcheck_rounds > 0 { + let prev_sumcheck_poly_eval = if let Some((prev_poly, randomness)) = prev { + prev_poly.evaluate_at_point(&randomness.into()) + } else { + F::ZERO + }; + let (sumcheck_poly, new_randomness) = &parsed.final_sumcheck_rounds[0].clone(); + let claimed_sum = prev_sumcheck_poly_eval; + + if sumcheck_poly.sum_over_hypercube() != claimed_sum { + return Err(ProofError::InvalidProof); + } + + prev = Some((sumcheck_poly.clone(), *new_randomness)); + + // Check the rest of the round + for (sumcheck_poly, new_randomness) in &parsed.final_sumcheck_rounds[1..] { + let (prev_poly, randomness) = prev.unwrap(); + if sumcheck_poly.sum_over_hypercube() + != prev_poly.evaluate_at_point(&randomness.into()) + { + return Err(ProofError::InvalidProof); + } + prev = Some((sumcheck_poly.clone(), *new_randomness)); + } + } + + let prev_sumcheck_poly_eval = if let Some((prev_poly, randomness)) = prev { + prev_poly.evaluate_at_point(&randomness.into()) + } else { + F::ZERO + }; + + // Check the final sumcheck evaluation + let evaluation_of_v_poly = self.compute_v_poly_for_batched(&statement, &parsed); + + if prev_sumcheck_poly_eval + != evaluation_of_v_poly + * parsed + .final_coefficients + .evaluate(&parsed.final_sumcheck_randomness) + { + return Err(ProofError::InvalidProof); + } + + Ok(()) + } + pub fn verify( &self, arthur: &mut Arthur, @@ -470,12 +745,17 @@ where whir_proof: &WhirProof, ) -> ProofResult<()> where - Arthur: FieldChallenges + FieldReader + ByteChallenges + ByteReader + PoWChallenge + DigestReader, + Arthur: FieldChallenges + + FieldReader + + ByteChallenges + + ByteReader + + PoWChallenge + + DigestReader, { // We first do a pass in which we rederive all the FS challenges // Then we will check the algebraic part (so to optimise inversions) let parsed_commitment = self.parse_commitment(arthur)?; - let parsed = self.parse_proof(arthur, &parsed_commitment, statement, whir_proof)?; + let parsed = self.parse_proof(arthur, &parsed_commitment, statement, whir_proof, vec![])?; let computed_folds = self.compute_folds(&parsed);