diff --git a/Cargo.toml b/Cargo.toml index c2925184..86b97ba7 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -76,7 +76,6 @@ tracing-subscriber = { version = "0.3.17", default-features = false, features = "alloc", "env-filter", ], optional = true } -bincode = { version = "1.3", optional = true } [dev-dependencies] criterion = "0.8" @@ -86,7 +85,7 @@ proptest = { version = "1.0", default-features = true } default = ["parallel"] parallel = ["dep:rayon", "p3-maybe-rayon/parallel", "p3-util/parallel"] rayon = ["dep:rayon"] -cli = ["dep:clap", "dep:tracing-subscriber", "dep:tracing-forest", "dep:bincode", "rand/default"] +cli = ["dep:clap", "dep:tracing-subscriber", "dep:tracing-forest", "rand/default"] [[bin]] name = "main" diff --git a/benches/stir_queries.rs b/benches/stir_queries.rs index 4dfcaacd..f16a9c63 100644 --- a/benches/stir_queries.rs +++ b/benches/stir_queries.rs @@ -6,7 +6,8 @@ use p3_challenger::DuplexChallenger; use p3_field::extension::BinomialExtensionField; use rand::{SeedableRng, rngs::SmallRng}; use whir_p3::{ - fiat_shamir::domain_separator::DomainSeparator, whir::utils::get_challenge_stir_queries, + fiat_shamir::{domain_separator::DomainSeparator, transcript::FiatShamirWriter}, + whir::utils::get_challenge_stir_queries, }; type F = BabyBear; @@ -29,12 +30,13 @@ fn bench_stir_queries(c: &mut Criterion) { // Benchmarks from the main file use case group.bench_function("benchmark main round 1", |b| { b.iter(|| { - let mut challenger = setup_challenger(); - get_challenge_stir_queries::<_, F, EF>( + let challenger = setup_challenger(); + let mut transcript = FiatShamirWriter::::init(challenger); + get_challenge_stir_queries::( + &mut transcript, black_box(67_108_864), black_box(5), black_box(80), - black_box(&mut challenger), ) .unwrap() }); @@ -42,12 +44,13 @@ fn bench_stir_queries(c: &mut Criterion) { group.bench_function("benchmark main round 2", |b| { b.iter(|| { - let mut challenger = setup_challenger(); - get_challenge_stir_queries::<_, F, EF>( + let challenger = setup_challenger(); + let mut transcript = FiatShamirWriter::::init(challenger); + get_challenge_stir_queries::( + &mut transcript, black_box(8_388_608), black_box(5), black_box(26), - black_box(&mut challenger), ) .unwrap() }); @@ -55,12 +58,13 @@ fn bench_stir_queries(c: &mut Criterion) { group.bench_function("benchmark main round 3", |b| { b.iter(|| { - let mut challenger = setup_challenger(); - get_challenge_stir_queries::<_, F, EF>( + let challenger = setup_challenger(); + let mut transcript = FiatShamirWriter::::init(challenger); + get_challenge_stir_queries::( + &mut transcript, black_box(4_194_304), black_box(5), black_box(11), - black_box(&mut challenger), ) .unwrap() }); @@ -68,12 +72,13 @@ fn bench_stir_queries(c: &mut Criterion) { group.bench_function("benchmark main round 4", |b| { b.iter(|| { - let mut challenger = setup_challenger(); - get_challenge_stir_queries::<_, F, EF>( + let challenger = setup_challenger(); + let mut transcript = FiatShamirWriter::::init(challenger); + get_challenge_stir_queries::( + &mut transcript, black_box(2_097_152), black_box(5), black_box(7), - black_box(&mut challenger), ) .unwrap() }); @@ -82,12 +87,13 @@ fn bench_stir_queries(c: &mut Criterion) { // Large case: Many queries, large domain group.bench_function("large_64_queries_64k_domain", |b| { b.iter(|| { - let mut challenger = setup_challenger(); - get_challenge_stir_queries::<_, F, EF>( + let challenger = setup_challenger(); + let mut transcript = FiatShamirWriter::::init(challenger); + get_challenge_stir_queries::( + &mut transcript, black_box(65536), black_box(6), black_box(64), - black_box(&mut challenger), ) .unwrap() }); @@ -96,12 +102,13 @@ fn bench_stir_queries(c: &mut Criterion) { // Very large case: Extreme scenario group.bench_function("very_large_256_queries_1m_domain", |b| { b.iter(|| { - let mut challenger = setup_challenger(); - get_challenge_stir_queries::<_, F, EF>( + let challenger = setup_challenger(); + let mut transcript = FiatShamirWriter::::init(challenger); + get_challenge_stir_queries::( + &mut transcript, black_box(1_048_576), black_box(10), black_box(256), - black_box(&mut challenger), ) .unwrap() }); @@ -110,12 +117,13 @@ fn bench_stir_queries(c: &mut Criterion) { // Edge case: Many queries, tiny bits per query group.bench_function("edge_100_queries_tiny_bits", |b| { b.iter(|| { - let mut challenger = setup_challenger(); - get_challenge_stir_queries::<_, F, EF>( + let challenger = setup_challenger(); + let mut transcript = FiatShamirWriter::::init(challenger); + get_challenge_stir_queries::( + &mut transcript, black_box(64), // domain_size black_box(2), // folding_factor (domain becomes 16, needs 4 bits) black_box(100), // num_queries (lots of queries, few bits each) - black_box(&mut challenger), ) .unwrap() }); diff --git a/benches/sumcheck.rs b/benches/sumcheck.rs index 1dabfb90..6e5441be 100644 --- a/benches/sumcheck.rs +++ b/benches/sumcheck.rs @@ -5,14 +5,13 @@ use p3_field::extension::BinomialExtensionField; use p3_symmetric::{PaddingFreeSponge, TruncatedPermutation}; use rand::{Rng, SeedableRng, rngs::SmallRng}; use whir_p3::{ - fiat_shamir::domain_separator::DomainSeparator, + fiat_shamir::{domain_separator::DomainSeparator, transcript::FiatShamirWriter}, parameters::{FoldingFactor, ProtocolParameters, errors::SecurityAssumption}, poly::{evals::EvaluationsList, multilinear::MultilinearPoint}, sumcheck::sumcheck_single::SumcheckSingle, whir::{ constraints::{Constraint, statement::EqStatement}, - parameters::InitialPhaseConfig, - proof::{InitialPhase, SumcheckData, WhirProof}, + parameters::InitialPhase, }, }; @@ -31,7 +30,7 @@ fn create_test_protocol_params( let perm = Perm::new_from_rng_128(&mut rng); ProtocolParameters { - initial_phase_config: InitialPhaseConfig::WithStatementClassic, + initial_phase: InitialPhase::WithStatementClassic, security_level: 32, pow_bits: 0, rs_domain_initial_reduction_factor: 1, @@ -90,9 +89,6 @@ fn bench_sumcheck_prover(c: &mut Criterion) { // Benchmark for the classic, round-by-round sumcheck let classic_folding_schedule = [*num_vars / 2, num_vars - (*num_vars / 2)]; - // Create parameters with a dummy folding factor (we'll use manual schedule) - let params = create_test_protocol_params(FoldingFactor::Constant(2)); - // Setup domain separator let domsep: DomainSeparator = DomainSeparator::new(vec![]); @@ -101,41 +97,33 @@ fn bench_sumcheck_prover(c: &mut Criterion) { // Setup fresh challenger for each iteration let mut challenger = setup_challenger(); domsep.observe_domain_separator(&mut challenger); - - // Initialize proof - let mut proof = WhirProof::::from_protocol_parameters(¶ms, *num_vars); + let mut transcript = FiatShamirWriter::init(challenger.clone()); // Create constraint using challenger directly let statement = generate_statement(&mut challenger, *num_vars, poly, 3); let alpha: EF = challenger.sample_algebra_element(); let constraint = Constraint::new_eq_only(alpha, statement); - // Extract sumcheck data from the initial phase - let InitialPhase::WithStatement { ref mut sumcheck } = proof.initial_phase else { - panic!("Expected WithStatement variant"); - }; - // First round - fold first half of variables let (mut sumcheck_prover, _) = SumcheckSingle::from_base_evals( + &mut transcript, poly, - sumcheck, - &mut challenger, classic_folding_schedule[0], 0, &constraint, - ); + ) + .unwrap(); // Second round - fold remaining variables if classic_folding_schedule.len() > 1 && classic_folding_schedule[1] > 0 { - let mut sumcheck_data: SumcheckData = SumcheckData::default(); - sumcheck_prover.compute_sumcheck_polynomials( - &mut sumcheck_data, - &mut challenger, - classic_folding_schedule[1], - 0, - None, - ); - proof.set_final_sumcheck_data(sumcheck_data); + sumcheck_prover + .compute_sumcheck_polynomials( + &mut transcript, + classic_folding_schedule[1], + 0, + None, + ) + .unwrap(); } }); }); diff --git a/benches/sumcheck_svo.rs b/benches/sumcheck_svo.rs index cd61b8e9..106d4e64 100644 --- a/benches/sumcheck_svo.rs +++ b/benches/sumcheck_svo.rs @@ -11,13 +11,9 @@ use whir::{ }; use whir_p3::{ self as whir, - fiat_shamir::domain_separator::DomainSeparator, + fiat_shamir::{domain_separator::DomainSeparator, transcript::FiatShamirWriter}, parameters::{FoldingFactor, ProtocolParameters, errors::SecurityAssumption}, - whir::{ - constraints::Constraint, - parameters::InitialPhaseConfig, - proof::{InitialPhase, SumcheckData, WhirProof}, - }, + whir::{constraints::Constraint, parameters::InitialPhase}, }; type F = KoalaBear; @@ -39,7 +35,7 @@ fn create_test_protocol_params_classic( let perm = Poseidon16::new_from_rng_128(&mut rng); ProtocolParameters { - initial_phase_config: InitialPhaseConfig::WithStatementClassic, + initial_phase: InitialPhase::WithStatementClassic, security_level: 32, pow_bits: 0, rs_domain_initial_reduction_factor: 1, @@ -95,37 +91,21 @@ fn bench_sumcheck_prover_svo(c: &mut Criterion) { let poly = generate_poly(*num_vars); // Classic benchmark - folding all variables in one round - let params_classic = - create_test_protocol_params_classic(FoldingFactor::Constant(*num_vars)); group.bench_with_input(BenchmarkId::new("Classic", *num_vars), &poly, |b, poly| { b.iter(|| { // Setup fresh challenger for each iteration let mut challenger = setup_challenger(); domsep.observe_domain_separator(&mut challenger); - - // Initialize proof - let mut proof = - WhirProof::::from_protocol_parameters(¶ms_classic, *num_vars); + let mut transcript = FiatShamirWriter::init(challenger.clone()); // Create constraint using challenger directly let statement = generate_statement(&mut challenger, *num_vars, poly, 3); let alpha: EF = challenger.sample_algebra_element(); let constraint = Constraint::new_eq_only(alpha, statement); - // Extract sumcheck data from the initial phase - let InitialPhase::WithStatement { ref mut sumcheck } = proof.initial_phase else { - panic!("Expected WithStatement variant"); - }; - // Fold all variables in one round - SumcheckSingle::from_base_evals( - poly, - sumcheck, - &mut challenger, - *num_vars, - 0, - &constraint, - ); + SumcheckSingle::from_base_evals(&mut transcript, poly, *num_vars, 0, &constraint) + .unwrap(); }); }); @@ -135,24 +115,22 @@ fn bench_sumcheck_prover_svo(c: &mut Criterion) { // Setup fresh challenger for each iteration let mut challenger = setup_challenger(); domsep.observe_domain_separator(&mut challenger); + let mut transcript = FiatShamirWriter::init(challenger.clone()); // Create constraint using challenger directly let statement = generate_statement(&mut challenger, *num_vars, poly, 1); let alpha: EF = challenger.sample_algebra_element(); let constraint = Constraint::new_eq_only(alpha, statement); - // Create sumcheck data - let mut sumcheck_data: SumcheckData = SumcheckData::default(); - // Fold all variables using SVO optimization SumcheckSingle::from_base_evals_svo( + &mut transcript, poly, - &mut sumcheck_data, - &mut challenger, *num_vars, 0, &constraint, - ); + ) + .unwrap(); }); }); } diff --git a/benches/whir.rs b/benches/whir.rs index f69d58e7..ce989c86 100644 --- a/benches/whir.rs +++ b/benches/whir.rs @@ -6,14 +6,13 @@ use p3_koala_bear::{KoalaBear, Poseidon2KoalaBear}; use p3_symmetric::{PaddingFreeSponge, TruncatedPermutation}; use rand::{Rng, SeedableRng, rngs::SmallRng}; use whir_p3::{ - fiat_shamir::domain_separator::DomainSeparator, + fiat_shamir::{domain_separator::DomainSeparator, transcript::FiatShamirWriter}, parameters::{DEFAULT_MAX_POW, FoldingFactor, ProtocolParameters, errors::SecurityAssumption}, poly::{evals::EvaluationsList, multilinear::MultilinearPoint}, whir::{ committer::writer::CommitmentWriter, constraints::statement::EqStatement, - parameters::{InitialPhaseConfig, WhirConfig}, - proof::WhirProof, + parameters::{InitialPhase, WhirConfig}, prover::Prover, }, }; @@ -30,9 +29,7 @@ type MyChallenger = DuplexChallenger; #[allow(clippy::type_complexity)] fn prepare_inputs() -> ( - WhirConfig, - ProtocolParameters, - usize, + WhirConfig, Radix2DFTSmallBatch, EvaluationsList, EqStatement, @@ -74,7 +71,7 @@ fn prepare_inputs() -> ( // Assemble the protocol-level parameters. let whir_params = ProtocolParameters { - initial_phase_config: InitialPhaseConfig::WithStatementClassic, + initial_phase: InitialPhase::WithStatementClassic, security_level, pow_bits, folding_factor, @@ -86,7 +83,7 @@ fn prepare_inputs() -> ( }; // Combine multivariate and protocol parameters into a unified WHIR config. - let params = WhirConfig::new(num_variables, whir_params.clone()); + let params = WhirConfig::new(num_variables, whir_params); // Sample random multilinear polynomial @@ -114,8 +111,8 @@ fn prepare_inputs() -> ( let mut domainsep = DomainSeparator::new(vec![]); // Commit protocol parameters and proof type to the domain separator. - domainsep.commit_statement::<_, _, _, 32>(¶ms); - domainsep.add_whir_proof::<_, _, _, 32>(¶ms); + domainsep.commit_statement::<_, _, 32>(¶ms); + domainsep.add_whir_proof::<_, _, 32>(¶ms); // Instantiate the Fiat-Shamir challenger from an empty seed and Keccak. let challenger = MyChallenger::new(poseidon16); @@ -126,31 +123,20 @@ fn prepare_inputs() -> ( let dft = Radix2DFTSmallBatch::::new(1 << params.max_fft_size()); // Return all preprocessed components needed to run commit/prove/verify benchmarks. - ( - params, - whir_params, - num_variables, - dft, - polynomial, - statement, - challenger, - domainsep, - ) + (params, dft, polynomial, statement, challenger, domainsep) } fn benchmark_commit_and_prove(c: &mut Criterion) { - let (params, whir_params, num_variables, dft, polynomial, statement, challenger, domainsep) = - prepare_inputs(); + let (params, dft, polynomial, statement, challenger, domainsep) = prepare_inputs(); c.bench_function("commit", |b| { b.iter(|| { let mut challenger_clone = challenger.clone(); domainsep.observe_domain_separator(&mut challenger_clone); - let mut proof = - WhirProof::::from_protocol_parameters(&whir_params, num_variables); + let mut transcript = FiatShamirWriter::init(challenger_clone); let committer = CommitmentWriter::new(¶ms); let _witness = committer - .commit(&dft, &mut proof, &mut challenger_clone, polynomial.clone()) + .commit(&dft, &mut transcript, polynomial.clone()) .unwrap(); }); }); @@ -159,22 +145,15 @@ fn benchmark_commit_and_prove(c: &mut Criterion) { b.iter(|| { let mut challenger_clone = challenger.clone(); domainsep.observe_domain_separator(&mut challenger_clone); - let mut proof = - WhirProof::::from_protocol_parameters(&whir_params, num_variables); + let mut transcript = FiatShamirWriter::init(challenger_clone); let committer = CommitmentWriter::new(¶ms); let witness = committer - .commit(&dft, &mut proof, &mut challenger_clone, polynomial.clone()) + .commit(&dft, &mut transcript, polynomial.clone()) .unwrap(); let prover = Prover(¶ms); prover - .prove( - &dft, - &mut proof, - &mut challenger_clone, - statement.clone(), - witness, - ) + .prove(&dft, &mut transcript, statement.clone(), witness) .unwrap(); }); }); diff --git a/src/bin/main.rs b/src/bin/main.rs index 81c2d590..98803e56 100644 --- a/src/bin/main.rs +++ b/src/bin/main.rs @@ -1,11 +1,9 @@ use std::time::Instant; use clap::Parser; -use p3_baby_bear::BabyBear; use p3_challenger::DuplexChallenger; use p3_dft::Radix2DFTSmallBatch; use p3_field::extension::BinomialExtensionField; -use p3_goldilocks::Goldilocks; use p3_koala_bear::{KoalaBear, Poseidon2KoalaBear}; use p3_symmetric::{PaddingFreeSponge, TruncatedPermutation}; use rand::{ @@ -15,14 +13,16 @@ use rand::{ use tracing_forest::{ForestLayer, util::LevelFilter}; use tracing_subscriber::{EnvFilter, Registry, layer::SubscriberExt, util::SubscriberInitExt}; use whir_p3::{ - fiat_shamir::domain_separator::DomainSeparator, + fiat_shamir::{ + domain_separator::DomainSeparator, + transcript::{FiatShamirReader, FiatShamirWriter}, + }, parameters::{DEFAULT_MAX_POW, FoldingFactor, ProtocolParameters, errors::SecurityAssumption}, poly::{evals::EvaluationsList, multilinear::MultilinearPoint}, whir::{ committer::{reader::CommitmentReader, writer::CommitmentWriter}, constraints::statement::EqStatement, - parameters::{InitialPhaseConfig, WhirConfig}, - proof::WhirProof, + parameters::{InitialPhase, WhirConfig}, prover::Prover, verifier::Verifier, }, @@ -30,11 +30,6 @@ use whir_p3::{ type F = KoalaBear; type EF = BinomialExtensionField; -type _F = BabyBear; -type _EF = BinomialExtensionField<_F, 5>; -type __F = Goldilocks; -type __EF = BinomialExtensionField<__F, 2>; - type Poseidon16 = Poseidon2KoalaBear<16>; type Poseidon24 = Poseidon2KoalaBear<24>; @@ -114,7 +109,7 @@ fn main() { // Construct WHIR protocol parameters let whir_params = ProtocolParameters { - initial_phase_config: InitialPhaseConfig::WithStatementClassic, + initial_phase: InitialPhase::WithStatementClassic, security_level, pow_bits, folding_factor, @@ -125,10 +120,7 @@ fn main() { rs_domain_initial_reduction_factor, }; - let params = WhirConfig::::new( - num_variables, - whir_params.clone(), - ); + let params = WhirConfig::::new(num_variables, whir_params); let mut rng = StdRng::seed_from_u64(0); let polynomial = EvaluationsList::::new((0..num_coeffs).map(|_| rng.random()).collect()); @@ -148,8 +140,12 @@ fn main() { // Define the Fiat-Shamir domain separator pattern for committing and proving let mut domainsep = DomainSeparator::new(vec![]); - domainsep.commit_statement::<_, _, _, 32>(¶ms); - domainsep.add_whir_proof::<_, _, _, 32>(¶ms); + domainsep.commit_statement::<_, _, 32>(¶ms); + domainsep.add_whir_proof::<_, _, 32>(¶ms); + + let mut challenger = MyChallenger::new(poseidon16); + // Initialize the challenger with domain separator + domainsep.observe_domain_separator(&mut challenger); println!("========================================="); println!("Whir (PCS) 🌪️"); @@ -157,23 +153,13 @@ fn main() { println!("WARN: more PoW bits required than what specified."); } - let challenger = MyChallenger::new(poseidon16); - - // Initialize the prover's challenger with domain separator - let mut prover_challenger = challenger.clone(); - domainsep.observe_domain_separator(&mut prover_challenger); - + let mut transcript: FiatShamirWriter = FiatShamirWriter::init(challenger.clone()); // Commit to the polynomial and produce a witness let committer = CommitmentWriter::new(¶ms); let dft = Radix2DFTSmallBatch::::new(1 << params.max_fft_size()); - - let mut proof = WhirProof::::from_protocol_parameters(&whir_params, num_variables); - let time = Instant::now(); - let witness = committer - .commit(&dft, &mut proof, &mut prover_challenger, polynomial) - .unwrap(); + let witness = committer.commit(&dft, &mut transcript, polynomial).unwrap(); let commit_time = time.elapsed(); // Generate a proof using the prover @@ -182,39 +168,32 @@ fn main() { // Generate a proof for the given statement and witness let time = Instant::now(); prover - .prove( - &dft, - &mut proof, - &mut prover_challenger, - statement.clone(), - witness, - ) + .prove(&dft, &mut transcript, statement.clone(), witness) .unwrap(); - let opening_time = time.elapsed(); + let proof = transcript.finalize(); + + println!( + "Proof size: {} bytes ({:.2} KB)", + proof.len(), + proof.len() as f64 / 1024.0 + ); // Create a commitment reader let commitment_reader = CommitmentReader::new(¶ms); // Create a verifier with matching parameters let verifier = Verifier::new(¶ms); - - // Initialize the verifier's challenger with domain separator - let mut verifier_challenger = challenger; - domainsep.observe_domain_separator(&mut verifier_challenger); + let mut transcript = FiatShamirReader::::init(proof, challenger); // Parse the commitment - let parsed_commitment = - commitment_reader.parse_commitment::<8>(&proof, &mut verifier_challenger); + let parsed_commitment = commitment_reader + .parse_commitment::<_, 8>(&mut transcript) + .unwrap(); let verif_time = Instant::now(); verifier - .verify( - &proof, - &mut verifier_challenger, - &parsed_commitment, - statement, - ) + .verify(&mut transcript, &parsed_commitment, statement) .unwrap(); let verify_time = verif_time.elapsed(); @@ -224,11 +203,6 @@ fn main() { commit_time.as_millis(), opening_time.as_millis() ); - let proof_bytes = bincode::serialize(&proof).expect("Failed to serialize proof"); - println!( - "Proof size: {} bytes ({:.2} KB)", - proof_bytes.len(), - proof_bytes.len() as f64 / 1024.0 - ); + println!("Verification time: {} μs", verify_time.as_micros()); } diff --git a/src/fiat_shamir/domain_separator.rs b/src/fiat_shamir/domain_separator.rs index dd716902..217622a9 100644 --- a/src/fiat_shamir/domain_separator.rs +++ b/src/fiat_shamir/domain_separator.rs @@ -2,12 +2,12 @@ use alloc::vec::Vec; use core::marker::PhantomData; use p3_challenger::{FieldChallenger, GrindingChallenger}; -use p3_field::{ExtensionField, Field, TwoAdicField}; +use p3_field::{ExtensionField, TwoAdicField}; use crate::{ constant::K_SKIP_SUMCHECK, fiat_shamir::pattern::{Hint, Observe, Pattern, Sample}, - whir::parameters::{InitialPhaseConfig, WhirConfig}, + whir::parameters::{InitialPhase, WhirConfig}, }; /// Configuration parameters for a sumcheck phase in the protocol. @@ -49,7 +49,7 @@ pub struct DomainSeparator { impl DomainSeparator where EF: ExtensionField, - F: Field, + F: TwoAdicField, { /// Create a new DomainSeparator with the domain separator. #[must_use] @@ -99,36 +99,30 @@ where } } - pub fn commit_statement( + pub fn commit_statement( &mut self, - params: &WhirConfig, - ) where - Challenger: FieldChallenger + GrindingChallenger, - { + params: &WhirConfig, + ) { // TODO: Add params self.observe(DIGEST_ELEMS, Observe::MerkleDigest); if params.commitment_ood_samples > 0 { - assert!(params.initial_phase_config.has_initial_statement()); + assert!(params.initial_phase.has_initial_statement()); self.add_ood(params.commitment_ood_samples); } } - pub fn add_whir_proof( + pub fn add_whir_proof( &mut self, - params: &WhirConfig, - ) where - Challenger: FieldChallenger + GrindingChallenger, - EF: TwoAdicField, - F: TwoAdicField, - { + params: &WhirConfig, + ) { // TODO: Add statement - if params.initial_phase_config.has_initial_statement() { + if params.initial_phase.has_initial_statement() { self.sample(1, Sample::InitialCombinationRandomness); self.add_sumcheck(&SumcheckParams { rounds: params.folding_factor.at_round(0), pow_bits: params.starting_folding_pow_bits, - univariate_skip: match params.initial_phase_config { - InitialPhaseConfig::WithStatementUnivariateSkip => Some(K_SKIP_SUMCHECK), + univariate_skip: match params.initial_phase { + InitialPhase::WithStatementSkip => Some(K_SKIP_SUMCHECK), _ => None, }, }); diff --git a/src/fiat_shamir/errors.rs b/src/fiat_shamir/errors.rs index 6c286f02..b9e11195 100644 --- a/src/fiat_shamir/errors.rs +++ b/src/fiat_shamir/errors.rs @@ -10,4 +10,6 @@ pub enum FiatShamirError { /// Proof-of-work witness fails difficulty requirement. #[error("Invalid grinding witness: proof-of-work verification failed")] InvalidGrindingWitness, + #[error("Invalid element")] + ElementIO, } diff --git a/src/fiat_shamir/mod.rs b/src/fiat_shamir/mod.rs index 18d0f2d0..02313cfc 100644 --- a/src/fiat_shamir/mod.rs +++ b/src/fiat_shamir/mod.rs @@ -3,3 +3,4 @@ pub mod errors; pub mod pattern; #[cfg(test)] mod tests; +pub mod transcript; diff --git a/src/fiat_shamir/transcript.rs b/src/fiat_shamir/transcript.rs new file mode 100644 index 00000000..5e1b8d91 --- /dev/null +++ b/src/fiat_shamir/transcript.rs @@ -0,0 +1,703 @@ +use alloc::{vec, vec::Vec}; +use core::marker::PhantomData; + +use p3_baby_bear::BabyBear; +use p3_challenger::{FieldChallenger, GrindingChallenger}; +use p3_field::{ + BasedVectorSpace, ExtensionField, Field, PrimeField32, PrimeField64, RawDataSerializable, + extension::BinomialExtensionField, integers::QuotientMap, +}; +use p3_goldilocks::Goldilocks; +use p3_koala_bear::KoalaBear; +use p3_symmetric::Hash; + +use crate::fiat_shamir::errors::FiatShamirError as Error; + +pub trait SerializedField: Field { + fn from_bytes(bytes: &[u8]) -> Result; + fn to_bytes(&self) -> Vec; +} + +macro_rules! impl_field_ser { + (32, $field:ty) => { + impl SerializedField for $field { + fn from_bytes(bytes: &[u8]) -> Result { + let bytes = bytes.try_into().map_err(|_| Error::ElementIO)?; + let inner = u32::from_le_bytes(bytes); + <$field>::from_canonical_checked(inner).ok_or(Error::ElementIO) + } + + fn to_bytes(&self) -> Vec { + <$field>::as_canonical_u32(self).to_le_bytes().to_vec() + } + } + }; + (64, $field:ty) => { + impl SerializedField for $field { + fn from_bytes(bytes: &[u8]) -> Result { + let bytes = bytes.try_into().map_err(|_| Error::ElementIO)?; + let inner = u64::from_le_bytes(bytes); + <$field>::from_canonical_checked(inner).ok_or(Error::ElementIO) + } + + fn to_bytes(&self) -> Vec { + <$field>::as_canonical_u64(self).to_le_bytes().to_vec() + } + } + }; +} + +macro_rules! impl_ext_field_ser { + ($base:ty, $d:expr) => { + impl SerializedField for BinomialExtensionField<$base, $d> { + fn from_bytes(bytes: &[u8]) -> Result { + (bytes.len() == Self::NUM_BYTES) + .then_some(()) + .ok_or(Error::ElementIO)?; + bytes + .chunks_exact(<$base>::NUM_BYTES) + .map(<$base>::from_bytes) + .collect::, Error>>() + .map(|coeffs| Self::from_basis_coefficients_slice(&coeffs).unwrap()) + } + + fn to_bytes(&self) -> Vec { + self.as_basis_coefficients_slice() + .iter() + .flat_map(|e: &$base| e.to_bytes()) + .collect() + } + } + }; +} + +impl_field_ser!(64, Goldilocks); +impl_field_ser!(32, BabyBear); +impl_field_ser!(32, KoalaBear); + +impl_ext_field_ser!(Goldilocks, 2); +impl_ext_field_ser!(BabyBear, 4); +impl_ext_field_ser!(KoalaBear, 4); +impl_ext_field_ser!(KoalaBear, 8); + +pub trait ChallengeBits { + fn sample(&mut self, bits: usize) -> usize; + fn sample_many(&mut self, n: usize, bits: usize) -> Vec { + (0..n).map(|_| self.sample(bits)).collect() + } +} + +pub trait Challenge { + fn sample(&mut self) -> F; + fn sample_many(&mut self, n: usize) -> Vec { + (0..n).map(|_| self.sample()).collect() + } +} + +pub trait Pow { + fn pow(&mut self, bits: usize) -> Result<(), Error>; +} + +pub trait Writer { + // TOOD: rename as send? + fn write(&mut self, el: T) -> Result<(), Error>; + fn write_hint(&mut self, el: T) -> Result<(), Error>; + fn write_many(&mut self, el: &[T]) -> Result<(), Error> + where + T: Copy, + { + el.iter().try_for_each(|&e| self.write(e)) + } + fn write_hint_many(&mut self, el: &[T]) -> Result<(), Error> + where + T: Copy, + { + el.iter().try_for_each(|&e| self.write_hint(e)) + } +} + +pub trait ProverTranscript: + Writer + Writer + Writer<[F; D]> + Challenge + Pow + ChallengeBits +{ +} + +pub trait Reader { + fn read(&mut self) -> Result; + fn read_hint(&mut self) -> Result; + fn read_many(&mut self, n: usize) -> Result, Error> { + (0..n).map(|_| self.read()).collect::, _>>() + } + fn read_hint_many(&mut self, n: usize) -> Result, Error> { + (0..n) + .map(|_| self.read_hint()) + .collect::, _>>() + } +} + +pub trait VerifierTranscript: + Reader + Reader + Reader<[F; D]> + Challenge + ChallengeBits + Pow +{ +} + +#[derive(Debug, Clone)] +pub struct FiatShamirWriter { + challenger: Challenger, + data: Vec, + _marker: PhantomData, +} + +impl FiatShamirWriter { + pub const fn init(challenger: Challenger) -> Self { + Self { + data: vec![], + challenger, + _marker: PhantomData, + } + } + + pub fn finalize(self) -> Vec { + self.data + } +} + +impl ProverTranscript + for FiatShamirWriter +where + F: SerializedField, + EF: ExtensionField + SerializedField, + Challenger: FieldChallenger + GrindingChallenger, +{ +} + +impl + SerializedField, Challenger> Writer + for FiatShamirWriter +where + Challenger: FieldChallenger, +{ + fn write(&mut self, e: EF) -> Result<(), Error> { + self.write_hint(e)?; + self.challenger.observe_algebra_element(e); + Ok(()) + } + + fn write_hint(&mut self, e: EF) -> Result<(), Error> { + let bytes = e.to_bytes(); + self.data.extend_from_slice(&bytes); + Ok(()) + } +} + +impl, Challenger> Challenge + for FiatShamirWriter +where + Challenger: FieldChallenger, +{ + fn sample(&mut self) -> EF { + self.challenger.sample_algebra_element::() + } +} + +impl ChallengeBits for FiatShamirWriter +where + Challenger: GrindingChallenger, +{ + fn sample(&mut self, bits: usize) -> usize { + self.challenger.sample_bits(bits) + } +} + +impl Writer<[F; D]> + for FiatShamirWriter +where + Challenger: FieldChallenger, +{ + fn write(&mut self, e: [F; D]) -> Result<(), Error> { + self.write_hint(e)?; + self.challenger.observe_slice(&e); + Ok(()) + } + + fn write_hint(&mut self, e: [F; D]) -> Result<(), Error> { + e.into_iter() + .try_for_each(|e| >::write_hint(self, e)) + } +} + +impl Writer> + for FiatShamirWriter +where + Challenger: FieldChallenger, +{ + fn write(&mut self, e: Hash) -> Result<(), Error> { + self.write_hint(e)?; + self.challenger.observe_slice(e.as_ref()); + Ok(()) + } + + fn write_hint(&mut self, e: Hash) -> Result<(), Error> { + e.into_iter() + .try_for_each(|e| >::write_hint(self, e)) + } +} + +impl Pow for FiatShamirWriter +where + Challenger: GrindingChallenger + FieldChallenger, +{ + fn pow(&mut self, bits: usize) -> Result<(), Error> { + if bits > 0 { + let witness = self.challenger.grind(bits); + self.write_hint(witness)?; + } + Ok(()) + } +} + +#[derive(Debug, Clone)] +struct Buffer { + pos: usize, + data: Vec, +} + +impl Buffer { + const fn new(data: Vec) -> Self { + Self { pos: 0, data } + } + + fn read_exact(&mut self, dst: &mut [u8]) -> Result<(), Error> { + (self.pos + dst.len() <= self.data.len()) + .then_some(()) + .ok_or(Error::ElementIO)?; + dst.copy_from_slice(&self.data[self.pos..self.pos + dst.len()]); + self.pos += dst.len(); + Ok(()) + } +} + +#[derive(Debug, Clone)] +pub struct FiatShamirReader { + challenger: Challenger, + buffer: Buffer, + _marker: PhantomData, +} + +impl FiatShamirReader { + pub const fn init(proof: Vec, challenger: Challenger) -> Self { + let buffer = Buffer::new(proof); + Self { + buffer, + challenger, + _marker: PhantomData, + } + } +} + +impl VerifierTranscript + for FiatShamirReader +where + F: SerializedField, + EF: ExtensionField + SerializedField, + Challenger: FieldChallenger + GrindingChallenger, +{ +} + +impl + SerializedField, Challenger> Reader + for FiatShamirReader +where + Challenger: FieldChallenger + Clone, +{ + fn read(&mut self) -> Result { + let e: EF = self.read_hint()?; + self.challenger.observe_algebra_element(e); + Ok(e) + } + + fn read_hint(&mut self) -> Result { + let mut bytes = vec![0u8; EF::NUM_BYTES]; + self.buffer.read_exact(bytes.as_mut())?; + EF::from_bytes(&bytes) + } +} + +impl Reader<[F; D]> + for FiatShamirReader +where + Challenger: FieldChallenger + Clone, +{ + fn read(&mut self) -> Result<[F; D], Error> { + let result: [F; D] = self.read_hint()?; + self.challenger.observe_slice(&result); + Ok(result) + } + + fn read_hint(&mut self) -> Result<[F; D], Error> { + Ok((0..D) + .map(|_| self.read_hint()) + .collect::, _>>()? + .try_into() + .unwrap()) + } +} + +impl, Challenger> Challenge + for FiatShamirReader +where + Challenger: FieldChallenger, +{ + fn sample(&mut self) -> EF { + self.challenger.sample_algebra_element::() + } +} + +impl ChallengeBits for FiatShamirReader +where + Challenger: GrindingChallenger, +{ + fn sample(&mut self, bits: usize) -> usize { + self.challenger.sample_bits(bits) + } +} + +impl Pow for FiatShamirReader +where + Challenger: GrindingChallenger + FieldChallenger, +{ + fn pow(&mut self, bits: usize) -> Result<(), Error> { + if bits > 0 { + let witness: F = self.read_hint()?; + self.challenger + .check_witness(bits, witness) + .then_some(()) + .ok_or(Error::InvalidGrindingWitness)?; + } + Ok(()) + } +} + +#[cfg(test)] +mod test { + + use p3_baby_bear::{BabyBear, Poseidon2BabyBear}; + use p3_challenger::{DuplexChallenger, FieldChallenger, GrindingChallenger}; + use p3_field::{ExtensionField, Field, extension::BinomialExtensionField}; + use p3_goldilocks::{Goldilocks, Poseidon2Goldilocks}; + use p3_koala_bear::{KoalaBear, Poseidon2KoalaBear}; + use rand::{ + Rng, SeedableRng, + distr::{Distribution, StandardUniform}, + rngs::SmallRng, + }; + + use super::*; + + #[test] + fn test_serialization() { + fn run_test(rng: &mut impl Rng, n: usize) + where + StandardUniform: Distribution, + { + let mut bytes = vec![0u8; F::NUM_BYTES]; + let zero = F::from_bytes(&bytes).unwrap(); + assert_eq!(zero, F::ZERO); + bytes[0] = 1; + let one = F::from_bytes(&bytes).unwrap(); + assert_eq!(one, F::ONE); + for _ in 0..n { + let a0: F = rng.random(); + let bytes = a0.to_bytes(); + let a1 = F::from_bytes(&bytes).unwrap(); + assert_eq!(a0, a1); + } + } + + let n = 1000; + let mut rng = SmallRng::seed_from_u64(0); + run_test::(&mut rng, n); + run_test::>(&mut rng, n); + run_test::(&mut rng, n); + run_test::>(&mut rng, n); + run_test::(&mut rng, n); + run_test::>(&mut rng, n); + } + + enum TestFs { + Writer(FiatShamirWriter), + Reader(FiatShamirReader), + } + + impl, Challenger: FieldChallenger> Challenge + for TestFs + { + fn sample(&mut self) -> EF { + match self { + Self::Writer(w) => w.sample(), + Self::Reader(r) => r.sample(), + } + } + } + + impl ChallengeBits for TestFs { + fn sample(&mut self, bits: usize) -> usize { + match self { + Self::Writer(w) => w.sample(bits), + Self::Reader(r) => r.sample(bits), + } + } + } + + impl + FieldChallenger> + Pow for TestFs + { + fn pow(&mut self, bits: usize) -> Result<(), Error> { + match self { + Self::Writer(w) => w.pow(bits), + Self::Reader(r) => r.pow(bits), + } + } + } + + impl + FieldChallenger> + TestFs + { + fn finalize(self) -> Vec { + match self { + Self::Writer(w) => w.finalize(), + Self::Reader(_) => unreachable!(), + } + } + } + + fn test_fs< + F: Field + SerializedField, + Ext: ExtensionField + SerializedField, + Challenger: FieldChallenger + GrindingChallenger, + >( + seed: u64, + fs: &mut TestFs, + ) -> (Vec, Vec, Vec, Vec) + where + StandardUniform: Distribution + Distribution, + { + let mut rng_range = SmallRng::seed_from_u64(seed); + let mut rng_number = SmallRng::seed_from_u64(seed + 1); + + let mut els_w = vec![]; + let mut els_ext_w = vec![]; + let mut chs_w = vec![]; + let mut idx_w = vec![]; + + (0..100).for_each(|_| { + let n_el_draw = rng_range.random_range(0..10); + let chs: Vec = + as Challenge>::sample_many(fs, n_el_draw); + let n_idx_draw = rng_range.random_range(0..10); + let idx_bits = rng_range.random_range(0..4); + let idx = + as ChallengeBits>::sample_many(fs, n_idx_draw, idx_bits); + chs_w.extend(chs); + idx_w.extend(idx); + + let n_base = rng_range.random_range(0..10); + let n_ext = rng_range.random_range(0..10); + match fs { + TestFs::Writer(w) => { + let els: Vec = (0..n_base).map(|_| rng_number.random()).collect(); + let els_ext: Vec = (0..n_ext).map(|_| rng_number.random()).collect(); + w.write_many(&els).unwrap(); + w.write_many(&els_ext).unwrap(); + els_w.extend(els); + els_ext_w.extend(els_ext); + } + TestFs::Reader(r) => { + let els: Vec = r.read_many(n_base).unwrap(); + let els_ext: Vec = r.read_many(n_ext).unwrap(); + els_w.extend(els); + els_ext_w.extend(els_ext); + } + } + let pow_bits = rng_range.random_range(0..4); + fs.pow(pow_bits).unwrap(); + + let n_base = rng_range.random_range(0..10); + let n_ext = rng_range.random_range(0..10); + match fs { + TestFs::Writer(w) => { + let els: Vec = (0..n_base).map(|_| rng_number.random()).collect(); + let els_ext: Vec = (0..n_ext).map(|_| rng_number.random()).collect(); + w.write_hint_many(&els).unwrap(); + w.write_hint_many(&els_ext).unwrap(); + els_w.extend(els); + els_ext_w.extend(els_ext); + } + TestFs::Reader(r) => { + let els: Vec = r.read_hint_many(n_base).unwrap(); + let els_ext: Vec = r.read_hint_many(n_ext).unwrap(); + els_w.extend(els); + els_ext_w.extend(els_ext); + } + } + }); + (els_w, els_ext_w, chs_w, idx_w) + } + + #[test] + fn test_transcript() { + fn run_test< + F: SerializedField, + EF: ExtensionField + SerializedField, + Challenger: FieldChallenger + GrindingChallenger, + >( + challenger: &Challenger, + ) where + StandardUniform: Distribution + Distribution, + { + for seed in 0..10 { + let w = FiatShamirWriter::::init(challenger.clone()); + let mut w = TestFs::Writer(w); + let (els0, els_ext0, chs0, idx0) = test_fs::(seed, &mut w); + let checkpoint_writer: EF = + as Challenge>::sample(&mut w); + let proof = w.finalize(); + + let r = FiatShamirReader::::init(proof, challenger.clone()); + let mut r = TestFs::Reader(r); + let (els1, els_ext1, chs1, idx1) = test_fs::(seed, &mut r); + let checkpoint_reader: EF = + as Challenge>::sample(&mut r); + assert_eq!(checkpoint_writer, checkpoint_reader); + + match &mut r { + TestFs::Reader(r) => { + assert_eq!(r.buffer.pos, r.buffer.data.len()); + assert!( as Reader>::read(r).is_err()); + assert!( as Reader>::read_many(r, 2).is_err()); + } + TestFs::Writer(_) => unreachable!(), + } + + assert_eq!(els0, els1); + assert_eq!(els_ext0, els_ext1); + assert_eq!(chs0, chs1); + assert_eq!(idx0, idx1); + } + } + + { + type F = Goldilocks; + type Ext = BinomialExtensionField; + type Perm = Poseidon2Goldilocks<16>; + type Challenger = DuplexChallenger; + let challenger = + Challenger::new(Perm::new_from_rng_128(&mut SmallRng::seed_from_u64(1000))); + run_test::(&challenger); + } + + { + type F = BabyBear; + type Ext = BinomialExtensionField; + type Perm = Poseidon2BabyBear<16>; + type Challenger = DuplexChallenger; + let challenger = + Challenger::new(Perm::new_from_rng_128(&mut SmallRng::seed_from_u64(1000))); + run_test::(&challenger); + } + + { + type F = KoalaBear; + type Ext = BinomialExtensionField; + type Perm = Poseidon2KoalaBear<16>; + type Challenger = DuplexChallenger; + let challenger = + Challenger::new(Perm::new_from_rng_128(&mut SmallRng::seed_from_u64(1000))); + run_test::(&challenger); + } + } + + #[test] + fn test_algebra_element() { + fn run_test< + F: SerializedField, + EF: ExtensionField + SerializedField, + Challenger: FieldChallenger + GrindingChallenger, + >( + challenger: &Challenger, + ) where + StandardUniform: Distribution + Distribution, + { + let mut rng = SmallRng::seed_from_u64(1); + + let proof = { + let mut w1 = FiatShamirWriter::::init(challenger.clone()); + let mut w0 = FiatShamirWriter::::init(challenger.clone()); + + let e0 = as Challenge>::sample(&mut w0); + let e1 = (0..EF::DIMENSION) + .map(|_| as Challenge>::sample(&mut w1)); + let e1 = EF::from_basis_coefficients_iter(e1).unwrap(); + assert_eq!(e0, e1); + + let e0: EF = rng.random(); + w0.write(e0).unwrap(); + let e0 = e0.as_basis_coefficients_slice(); + w1.write_many(e0).unwrap(); + let checkpoint0 = as Challenge>::sample(&mut w0); + let checkpoint1 = as Challenge>::sample(&mut w1); + assert_eq!(checkpoint0, checkpoint1); + let proof0 = w0.finalize(); + let proof1 = w1.finalize(); + assert_eq!(proof0, proof1); + proof0 + }; + + { + let mut r1 = FiatShamirReader::::init(proof.clone(), challenger.clone()); + let mut r0 = FiatShamirReader::::init(proof, challenger.clone()); + + let e0 = as Challenge>::sample(&mut r0); + let e1 = (0..EF::DIMENSION) + .map(|_| as Challenge>::sample(&mut r1)); + let e1 = EF::from_basis_coefficients_iter(e1).unwrap(); + assert_eq!(e0, e1); + + let e0: EF = r0.read().unwrap(); + let e1 = r1.read_many(EF::DIMENSION).unwrap(); + let e1 = EF::from_basis_coefficients_slice(&e1).unwrap(); + assert_eq!(e0, e1); + + let checkpoint0 = as Challenge>::sample(&mut r0); + let checkpoint1 = as Challenge>::sample(&mut r1); + assert_eq!(checkpoint0, checkpoint1); + } + } + + { + type F = Goldilocks; + type Ext = BinomialExtensionField; + type Perm = Poseidon2Goldilocks<16>; + type Challenger = DuplexChallenger; + let challenger = + Challenger::new(Perm::new_from_rng_128(&mut SmallRng::seed_from_u64(1000))); + run_test::(&challenger); + } + + { + type F = BabyBear; + type Ext = BinomialExtensionField; + type Perm = Poseidon2BabyBear<16>; + type Challenger = DuplexChallenger; + let challenger = + Challenger::new(Perm::new_from_rng_128(&mut SmallRng::seed_from_u64(1000))); + run_test::(&challenger); + } + + { + type F = KoalaBear; + type Ext = BinomialExtensionField; + type Perm = Poseidon2KoalaBear<16>; + type Challenger = DuplexChallenger; + let challenger = + Challenger::new(Perm::new_from_rng_128(&mut SmallRng::seed_from_u64(1000))); + run_test::(&challenger); + } + } +} diff --git a/src/parameters/mod.rs b/src/parameters/mod.rs index ba377725..5b5e5ab2 100644 --- a/src/parameters/mod.rs +++ b/src/parameters/mod.rs @@ -3,7 +3,7 @@ use core::fmt::Display; use errors::SecurityAssumption; use thiserror::Error; -use crate::whir::parameters::InitialPhaseConfig; +use crate::whir::parameters::InitialPhase; pub mod errors; @@ -151,7 +151,7 @@ pub struct ProtocolParameters { /// /// This determines whether an initial statement is included and which optimization /// strategy to use for the sumcheck protocol. - pub initial_phase_config: InitialPhaseConfig, + pub initial_phase: InitialPhase, /// The logarithmic inverse rate for sampling. pub starting_log_inv_rate: usize, /// The value v such that that the size of the Reed Solomon domain on which diff --git a/src/sumcheck/product_polynomial.rs b/src/sumcheck/product_polynomial.rs index 516758c0..8554a525 100644 --- a/src/sumcheck/product_polynomial.rs +++ b/src/sumcheck/product_polynomial.rs @@ -19,14 +19,18 @@ //! At each round, we compute a univariate polynomial `h(X)` that represents the partial sum //! over remaining variables. For quadratic sumcheck, `h(X)` is degree-2. -use p3_challenger::{FieldChallenger, GrindingChallenger}; use p3_field::{ExtensionField, Field, PackedFieldExtension, PackedValue, dot_product}; use p3_util::log2_strict_usize; use tracing::instrument; use crate::{ + fiat_shamir::{ + errors::FiatShamirError, + transcript::{Challenge, Pow, Writer}, + }, poly::{evals::EvaluationsList, multilinear::MultilinearPoint}, - whir::{constraints::Constraint, proof::SumcheckData}, + sumcheck::sumcheck_single::observe_and_sample, + whir::constraints::Constraint, }; /// A paired representation of evaluation and weight polynomials for quadratic sumcheck. @@ -371,15 +375,14 @@ impl> ProductPolynomial { /// /// The verifier's challenge `r \in EF` for this round. #[instrument(skip_all)] - pub(crate) fn round( + pub(crate) fn round( &mut self, - sumcheck_data: &mut SumcheckData, - challenger: &mut Challenger, + transcript: &mut Transcript, sum: &mut EF, pow_bits: usize, - ) -> EF + ) -> Result where - Challenger: FieldChallenger + GrindingChallenger, + Transcript: Writer + Pow + Challenge, { // Step 1: Compute sumcheck polynomial coefficients. // @@ -406,7 +409,7 @@ impl> ProductPolynomial { }; // Step 2-4: Commit to transcript, do PoW, and receive challenge. - let r = sumcheck_data.observe_and_sample(challenger, c0, c2, pow_bits); + let r = observe_and_sample(transcript, c0, c2, pow_bits)?; // Step 5: Fold both polynomials using the challenge. self.compress(r); @@ -436,7 +439,7 @@ impl> ProductPolynomial { // are more efficient than packed operations. self.transition(); - r + Ok(r) } /// Extracts the evaluation polynomial as a scalar [`EvaluationsList`]. diff --git a/src/sumcheck/sumcheck_single.rs b/src/sumcheck/sumcheck_single.rs index 91f4abd0..7bca6c18 100644 --- a/src/sumcheck/sumcheck_single.rs +++ b/src/sumcheck/sumcheck_single.rs @@ -1,23 +1,62 @@ //! Sumcheck protocol implementation. -use p3_challenger::{FieldChallenger, GrindingChallenger}; +use alloc::vec::Vec; + use p3_field::{ExtensionField, Field, PackedFieldExtension, PackedValue, TwoAdicField}; use p3_interpolation::interpolate_subgroup; use p3_util::log2_strict_usize; use tracing::instrument; use crate::{ + fiat_shamir::{ + errors::FiatShamirError, + transcript::{Challenge, Pow, Writer}, + }, poly::{evals::EvaluationsList, multilinear::MultilinearPoint}, sumcheck::{ product_polynomial::ProductPolynomial, sumcheck_single_skip::compute_skipping_sumcheck_polynomial, }, - whir::{ - constraints::{Constraint, statement::EqStatement}, - proof::{SumcheckData, SumcheckSkipData}, - }, + whir::constraints::{Constraint, statement::EqStatement}, }; +/// Commits polynomial coefficients to the transcript and returns a challenge. +/// +/// This helper function handles the Fiat-Shamir interaction for a sumcheck round. +/// +/// # Arguments +/// +/// * `transcript` - Fiat-Shamir transcript. +/// * `c0` - Constant coefficient `h(0)`. +/// * `c2` - Quadratic coefficient. +/// * `pow_bits` - PoW difficulty (0 to skip grinding). +/// +/// # Returns +/// +/// The sampled challenge `r \in EF`. +pub(crate) fn observe_and_sample, Transcript>( + transcript: &mut Transcript, + c0: EF, + c2: EF, + pow_bits: usize, +) -> Result +where + Transcript: Challenge + Pow + Writer, +{ + // Absorb coefficients into the transcript. + // + // Note: We only send (c_0, c_2). The verifier derives c_1 from the sum constraint. + transcript.write_many(&[c0, c2])?; + + // Optional proof-of-work to increase prover cost. + // + // This makes it expensive for a malicious prover to "mine" favorable challenges. + transcript.pow(pow_bits)?; + + // Sample the verifier's challenge for this round. + Ok(transcript.sample()) +} + /// Implements the single-round sumcheck protocol for verifying a multilinear polynomial evaluation. /// /// This struct is responsible for: @@ -104,18 +143,17 @@ where /// - Initializes internal sumcheck state with weights and expected sum. /// - Applies first set of sumcheck rounds #[instrument(skip_all)] - pub fn from_base_evals( + pub fn from_base_evals( + transcript: &mut Transcript, evals: &EvaluationsList, - sumcheck: &mut SumcheckData, - challenger: &mut Challenger, folding_factor: usize, pow_bits: usize, constraint: &Constraint, - ) -> (Self, MultilinearPoint) + ) -> Result<(Self, MultilinearPoint), FiatShamirError> where F: TwoAdicField, EF: TwoAdicField, - Challenger: FieldChallenger + GrindingChallenger, + Transcript: Challenge + Pow + Writer, { assert_ne!(folding_factor, 0); @@ -129,16 +167,13 @@ where "number of rounds must be less than or equal to instance size" ); - let (mut poly, r, mut sum) = - initial_round(evals, sumcheck, challenger, constraint, pow_bits); + let (mut poly, r, mut sum) = initial_round(evals, transcript, constraint, pow_bits)?; - let rs = core::iter::once(r) - .chain( - (1..folding_factor).map(|_| poly.round(sumcheck, challenger, &mut sum, pow_bits)), - ) - .collect(); + let rs = core::iter::once(Ok(r)) + .chain((1..folding_factor).map(|_| poly.round(transcript, &mut sum, pow_bits))) + .collect::, _>>()?; - (Self { poly, sum }, MultilinearPoint::new(rs)) + Ok((Self { poly, sum }, MultilinearPoint::new(rs))) } /// Constructs a new `SumcheckSingle` instance from evaluations in the base field. @@ -150,19 +185,18 @@ where /// - Applies first set of sumcheck rounds with univariate skip optimization. #[instrument(skip_all)] #[allow(clippy::too_many_arguments)] - pub fn with_skip( + pub fn with_skip( evals: &EvaluationsList, - skip_data: &mut SumcheckSkipData, - challenger: &mut Challenger, + transcript: &mut Transcript, folding_factor: usize, pow_bits: usize, k_skip: usize, constraint: &Constraint, - ) -> (Self, MultilinearPoint) + ) -> Result<(Self, MultilinearPoint), FiatShamirError> where F: TwoAdicField, EF: TwoAdicField, - Challenger: FieldChallenger + GrindingChallenger, + Transcript: Challenge + Pow + Writer, { assert_ne!(folding_factor, 0); assert!(k_skip > 1); @@ -188,18 +222,13 @@ where debug_assert_eq!(sumcheck_poly.iter().step_by(2).copied().sum::(), sum); // Fiat–Shamir: commit to h by absorbing its M evaluations into the transcript. - challenger.observe_algebra_slice(&sumcheck_poly); - - // Store skip evaluations - skip_data.evaluations.extend_from_slice(&sumcheck_poly); + transcript.write_many(&sumcheck_poly)?; // Proof-of-work challenge to delay prover (only if pow_bits > 0). - if pow_bits > 0 { - skip_data.pow = challenger.grind(pow_bits); - } + transcript.pow(pow_bits)?; // Receive the verifier challenge for this entire collapsed round. - let r: EF = challenger.sample_algebra_element(); + let r: EF = transcript.sample(); // Interpolate the LDE matrices at the folding randomness to get the new "folded" polynomial state. let new_p = EvaluationsList::new(interpolate_subgroup(&f_mat, r)); @@ -211,14 +240,11 @@ where let mut sum = poly.dot_product(); // Apply rest of sumcheck rounds - let rs = core::iter::once(r) - .chain( - (k_skip..folding_factor) - .map(|_| poly.round(&mut skip_data.sumcheck, challenger, &mut sum, pow_bits)), - ) - .collect(); - - (Self { poly, sum }, MultilinearPoint::new(rs)) + let rs = core::iter::once(Ok(r)) + .chain((k_skip..folding_factor).map(|_| poly.round(transcript, &mut sum, pow_bits))) + .collect::, _>>()?; + + Ok((Self { poly, sum }, MultilinearPoint::new(rs))) } /// Returns the number of variables in the polynomial. @@ -274,18 +300,17 @@ where /// - If `folding_factor > num_variables()` /// - If univariate skip is attempted with evaluations in the extension field. #[instrument(skip_all)] - pub fn compute_sumcheck_polynomials( + pub fn compute_sumcheck_polynomials( &mut self, - sumcheck_data: &mut SumcheckData, - challenger: &mut Challenger, + transcript: &mut Transcript, folding_factor: usize, pow_bits: usize, constraint: Option>, - ) -> MultilinearPoint + ) -> Result, FiatShamirError> where F: TwoAdicField, EF: TwoAdicField, - Challenger: FieldChallenger + GrindingChallenger, + Transcript: Challenge + Pow + Writer, { if let Some(constraint) = constraint { self.poly.combine(&mut self.sum, &constraint); @@ -294,14 +319,11 @@ where // Standard round-by-round folding // Proceed with one-variable-per-round folding for remaining variables. let res = (0..folding_factor) - .map(|_| { - self.poly - .round(sumcheck_data, challenger, &mut self.sum, pow_bits) - }) - .collect(); + .map(|_| self.poly.round(transcript, &mut self.sum, pow_bits)) + .collect::, _>>()?; // Return the full vector of verifier challenges as a multilinear point. - MultilinearPoint::new(res) + Ok(MultilinearPoint::new(res)) } } @@ -326,15 +348,14 @@ where /// * The verifier's challenge `r` as an `EF` element. /// * [`ProductPolynomial`] with new compressed polynomial evaluations and weights in the extension field. /// * Updated sum. -fn initial_round, Challenger>( +fn initial_round, Transcript>( evals: &EvaluationsList, - sumcheck_data: &mut SumcheckData, - challenger: &mut Challenger, + transcript: &mut Transcript, constraint: &Constraint, pow_bits: usize, -) -> (ProductPolynomial, EF, EF) +) -> Result<(ProductPolynomial, EF, EF), FiatShamirError> where - Challenger: FieldChallenger + GrindingChallenger, + Transcript: Challenge + Pow + Writer, { let num_vars = evals.num_variables(); @@ -363,7 +384,7 @@ where let c2 = EF::ExtensionPacking::to_ext_iter([c2]).sum(); // Commit to transcript, perform PoW, and receive challenge. - let r = sumcheck_data.observe_and_sample(challenger, c0, c2, pow_bits); + let r = observe_and_sample(transcript, c0, c2, pow_bits)?; // Fold both polynomials and update the sum. weights.compress(r); @@ -372,7 +393,7 @@ where let poly = ProductPolynomial::::new_packed(evals, weights); debug_assert_eq!(poly.dot_product(), sum); - (poly, r, sum) + Ok((poly, r, sum)) } else { // Scalar path: Direct computation for small polynomials. let (mut weights, mut sum) = constraint.combine_new(); @@ -381,7 +402,7 @@ where let (c0, c2) = evals.sumcheck_coefficients(&weights); // Commit to transcript, perform PoW, and receive challenge. - let r = sumcheck_data.observe_and_sample(challenger, c0, c2, pow_bits); + let r = observe_and_sample(transcript, c0, c2, pow_bits)?; // Fold both polynomials and update the sum. weights.compress(r); @@ -390,6 +411,6 @@ where let poly = ProductPolynomial::::new_small(evals, weights); debug_assert_eq!(poly.dot_product(), sum); - (poly, r, sum) + Ok((poly, r, sum)) } } diff --git a/src/sumcheck/sumcheck_single_svo.rs b/src/sumcheck/sumcheck_single_svo.rs index de01e9f8..e3d613aa 100644 --- a/src/sumcheck/sumcheck_single_svo.rs +++ b/src/sumcheck/sumcheck_single_svo.rs @@ -1,9 +1,12 @@ use alloc::{vec, vec::Vec}; -use p3_challenger::{FieldChallenger, GrindingChallenger}; use p3_field::{ExtensionField, Field, TwoAdicField}; use crate::{ + fiat_shamir::{ + errors::FiatShamirError, + transcript::{Challenge, Pow, Writer}, + }, poly::{evals::EvaluationsList, multilinear::MultilinearPoint}, sumcheck::{ eq_state::SumcheckEqState, @@ -11,7 +14,7 @@ use crate::{ sumcheck_single::SumcheckSingle, sumcheck_small_value::{algorithm_5, svo_first_rounds}, }, - whir::{constraints::Constraint, proof::SumcheckData}, + whir::constraints::Constraint, }; /// Number of SVO rounds (first 3 rounds use special optimized algorithm). @@ -27,18 +30,17 @@ where /// Compute a Sumcheck using the Small Value Optimization (SVO) for the first three rounds and /// Algorithm 5 (page 18) for the remaining rounds. /// See Algorithm 6 (page 19) in . - pub fn from_base_evals_svo( + pub fn from_base_evals_svo( + transcript: &mut Transcript, evals: &EvaluationsList, - sumcheck_data: &mut SumcheckData, - challenger: &mut Challenger, folding_factor: usize, pow_bits: usize, constraint: &Constraint, - ) -> (Self, MultilinearPoint) + ) -> Result<(Self, MultilinearPoint), FiatShamirError> where F: TwoAdicField, EF: TwoAdicField, - Challenger: FieldChallenger + GrindingChallenger, + Transcript: Writer + Challenge + Pow, { assert_ne!(folding_factor, 0); let mut challenges = Vec::with_capacity(folding_factor); @@ -52,40 +54,38 @@ where let mut eq_poly = SumcheckEqState::<_, NUM_SVO_ROUNDS>::new(w); svo_first_rounds( - sumcheck_data, - challenger, + transcript, evals, w, &mut eq_poly, &mut challenges, &mut sum, pow_bits, - ); + )?; // We fold to obtain p(r1, r2, r3, x). let mut folded_evals = evals.fold_batch(&challenges); algorithm_5( - sumcheck_data, - challenger, + transcript, &mut folded_evals, &mut eq_poly, &mut challenges, &mut sum, pow_bits, - ); + )?; let challenge_point = MultilinearPoint::new(challenges); // Final weight: eq(w, r) let weights = EvaluationsList::new(vec![w.eq_poly(&challenge_point)]); - ( + Ok(( Self { poly: ProductPolynomial::new(folded_evals, weights), sum, }, challenge_point, - ) + )) } } diff --git a/src/sumcheck/sumcheck_small_value.rs b/src/sumcheck/sumcheck_small_value.rs index fc2a4162..bc53a01b 100644 --- a/src/sumcheck/sumcheck_small_value.rs +++ b/src/sumcheck/sumcheck_small_value.rs @@ -1,14 +1,16 @@ use alloc::vec::Vec; use core::ops::{Add, AddAssign}; -use p3_challenger::{FieldChallenger, GrindingChallenger}; use p3_field::{ExtensionField, Field}; use p3_maybe_rayon::prelude::*; use crate::{ + fiat_shamir::{ + errors::FiatShamirError, + transcript::{Challenge, Pow, Writer}, + }, poly::{evals::EvaluationsList, multilinear::MultilinearPoint}, sumcheck::{eq_state::SumcheckEqState, sumcheck_single_svo::NUM_SVO_ROUNDS}, - whir::proof::SumcheckData, }; /// A container for the SVO accumulators for a specific number of rounds `N`. @@ -94,17 +96,17 @@ impl AddAssign for SvoAccumulators { /// Algorithm 6. Page 19. /// Compute three sumcheck rounds using the small value optimization and split-eq accumulators. #[allow(clippy::too_many_arguments)] -pub fn svo_first_rounds>( - sumcheck_data: &mut SumcheckData, - challenger: &mut Challenger, +pub fn svo_first_rounds>( + transcript: &mut Transcript, poly: &EvaluationsList, w: &MultilinearPoint, eq_poly: &mut SumcheckEqState<'_, EF, NUM_SVO_ROUNDS>, challenges: &mut Vec, sum: &mut EF, pow_bits: usize, -) where - Challenger: FieldChallenger + GrindingChallenger, +) -> Result<(), FiatShamirError> +where + Transcript: Writer + Challenge + Pow, { let (e_in, e_out) = join( || w.svo_e_in_table::(), @@ -133,15 +135,11 @@ pub fn svo_first_rounds>( let s_inf = (t_1_evals[1] - t_1_evals[0]) * (linear_1_evals[1] - linear_1_evals[0]); // 3. Send S_1(u) to the verifier. - sumcheck_data.polynomial_evaluations.push([s_0, s_inf]); - challenger.observe_algebra_slice(&[s_0, s_inf]); - - if pow_bits > 0 { - sumcheck_data.push_pow_witness(challenger.grind(pow_bits)); - } + transcript.write_many(&[s_0, s_inf])?; + transcript.pow(pow_bits)?; // 4. Receive the challenge r_1 from the verifier. - let r_1: EF = challenger.sample_algebra_element(); + let r_1: EF = transcript.sample(); challenges.push(r_1); eq_poly.bind(r_1); @@ -179,15 +177,11 @@ pub fn svo_first_rounds>( let s_inf = (t_2_evals[1] - t_2_evals[0]) * (linear_2_evals[1] - linear_2_evals[0]); // 3. Send S_2(u) to the verifier. - sumcheck_data.polynomial_evaluations.push([s_0, s_inf]); - challenger.observe_algebra_slice(&[s_0, s_inf]); - - if pow_bits > 0 { - sumcheck_data.push_pow_witness(challenger.grind(pow_bits)); - } + transcript.write_many(&[s_0, s_inf])?; + transcript.pow(pow_bits)?; // 4. Receive the challenge r_2 from the verifier. - let r_2: EF = challenger.sample_algebra_element(); + let r_2: EF = transcript.sample(); challenges.push(r_2); eq_poly.bind(r_2); @@ -246,15 +240,11 @@ pub fn svo_first_rounds>( ]; // 3. Send S_3(u) to the verifier. - sumcheck_data.polynomial_evaluations.push(round_poly_evals); - challenger.observe_algebra_slice(&round_poly_evals); - - if pow_bits > 0 { - sumcheck_data.push_pow_witness(challenger.grind(pow_bits)); - } + transcript.write_many(&round_poly_evals)?; + transcript.pow(pow_bits)?; // 4. Receive the challenge r_3 from the verifier. - let r_3: EF = challenger.sample_algebra_element(); + let r_3: EF = transcript.sample(); challenges.push(r_3); eq_poly.bind(r_3); @@ -262,6 +252,8 @@ pub fn svo_first_rounds>( *sum = round_poly_evals[1] * r_3.square() + (eval_1 - round_poly_evals[0] - round_poly_evals[1]) * r_3 + round_poly_evals[0]; + + Ok(()) } /// Computes the round polynomial evaluations `t_i(u)` for a single standard sumcheck round. @@ -371,18 +363,18 @@ where /// Algorithm 5. Page 18. /// Compute the remaining sumcheck rounds, from round l0 + 1 to round l. -pub fn algorithm_5( - sumcheck_data: &mut SumcheckData, - challenger: &mut Challenger, +pub fn algorithm_5( + transcript: &mut Transcript, poly: &mut EvaluationsList, eq_poly: &mut SumcheckEqState<'_, EF, NUM_SVO_ROUNDS>, challenges: &mut Vec, sum: &mut EF, pow_bits: usize, -) where +) -> Result<(), FiatShamirError> +where F: Field, EF: ExtensionField + Send + Sync, - Challenger: FieldChallenger + GrindingChallenger, + Transcript: Writer + Challenge + Pow, { let num_vars = eq_poly.num_variables(); // Current position in the sumcheck @@ -410,14 +402,11 @@ pub fn algorithm_5( let s_inf = (t_evals[1] - t_evals[0]) * (linear_evals[1] - linear_evals[0]); // Send S_i(u) to the verifier - sumcheck_data.polynomial_evaluations.push([s_0, s_inf]); - challenger.observe_algebra_slice(&[s_0, s_inf]); - if pow_bits > 0 { - sumcheck_data.push_pow_witness(challenger.grind(pow_bits)); - } + transcript.write_many(&[s_0, s_inf])?; + transcript.pow(pow_bits)?; // Receive the challenge r_i from the verifier - let r_i: EF = challenger.sample_algebra_element(); + let r_i: EF = transcript.sample(); challenges.push(r_i); // Update state for next round: binding updates scalar AND pops used table @@ -428,6 +417,8 @@ pub fn algorithm_5( let eval_1 = *sum - s_0; *sum = s_inf * r_i.square() + (eval_1 - s_0 - s_inf) * r_i + s_0; } + + Ok(()) } #[cfg(test)] diff --git a/src/sumcheck/tests.rs b/src/sumcheck/tests.rs index 36753a8f..389e20ab 100644 --- a/src/sumcheck/tests.rs +++ b/src/sumcheck/tests.rs @@ -1,7 +1,7 @@ use alloc::{vec, vec::Vec}; use p3_baby_bear::{BabyBear, Poseidon2BabyBear}; -use p3_challenger::{DuplexChallenger, FieldChallenger, GrindingChallenger}; +use p3_challenger::DuplexChallenger; use p3_field::{PrimeCharacteristicRing, TwoAdicField, extension::BinomialExtensionField}; use p3_symmetric::{PaddingFreeSponge, TruncatedPermutation}; use rand::{Rng, SeedableRng, rngs::SmallRng}; @@ -9,7 +9,12 @@ use rand::{Rng, SeedableRng, rngs::SmallRng}; use super::sumcheck_single::SumcheckSingle; use crate::{ constant::K_SKIP_SUMCHECK, - fiat_shamir::domain_separator::DomainSeparator, + fiat_shamir::{ + domain_separator::DomainSeparator, + transcript::{ + Challenge, ChallengeBits, FiatShamirReader, FiatShamirWriter, Reader, Writer, + }, + }, parameters::{FoldingFactor, ProtocolParameters, errors::SecurityAssumption}, poly::{evals::EvaluationsList, multilinear::MultilinearPoint}, whir::{ @@ -18,11 +23,8 @@ use crate::{ evaluator::ConstraintPolyEvaluator, statement::{EqStatement, SelectStatement}, }, - parameters::InitialPhaseConfig, - proof::{InitialPhase, SumcheckData, WhirProof}, - verifier::sumcheck::{ - verify_final_sumcheck_rounds, verify_initial_sumcheck_rounds, verify_sumcheck_rounds, - }, + parameters::InitialPhase, + verifier::sumcheck::{verify_initial_sumcheck_rounds, verify_standard_sumcheck_rounds}, }, }; @@ -51,13 +53,13 @@ fn domainsep_and_challenger() -> (DomainSeparator, MyChallenger) { fn create_test_protocol_params( folding_factor: FoldingFactor, - initial_phase_config: InitialPhaseConfig, + initial_phase_config: InitialPhase, ) -> ProtocolParameters { let mut rng = SmallRng::seed_from_u64(1); let perm = Perm::new_from_rng_128(&mut rng); ProtocolParameters { - initial_phase_config, + initial_phase: initial_phase_config, security_level: 32, pow_bits: 0, rs_domain_initial_reduction_factor: 1, @@ -69,16 +71,15 @@ fn create_test_protocol_params( } } -fn make_constraint( - challenger: &mut Challenger, - constraint_evals: &mut Vec, +fn make_constraint( + transcript: &mut Transcript, num_vars: usize, num_eqs: usize, num_sels: usize, poly: &EvaluationsList, ) -> Constraint where - Challenger: FieldChallenger + GrindingChallenger, + Transcript: Writer + Challenge + ChallengeBits, { // To simulate stir point derivation derive domain generator let omega = F::two_adic_generator(num_vars); @@ -92,7 +93,7 @@ where // - Collect (point, eval) pairs for use in the statement and constraint aggregation. (0..num_eqs).for_each(|_| { // Sample a univariate field element from the prover's challenger. - let point: EF = challenger.sample_algebra_element(); + let point = >::sample(transcript); // Expand it into a `num_vars`-dimensional multilinear point. let point = MultilinearPoint::expand_from_univariate(point, num_vars); @@ -100,11 +101,8 @@ where // Evaluate the current sumcheck polynomial at the sampled point. let eval = poly.evaluate_hypercube_base(&point); - // Store evaluation for verifier to read later. - constraint_evals.push(eval); - // Add the evaluation result to the transcript for Fiat-Shamir soundness. - challenger.observe_algebra_element(eval); + transcript.write(eval).unwrap(); // Add the evaluation constraint: poly(point) == eval. eq_statement.add_evaluated_constraint(point, eval); @@ -115,7 +113,7 @@ where // - Collect (var, eval) pairs for use in the statement and constraint aggregation. (0..num_sels).for_each(|_| { // Simulate stir point derivation - let index: usize = challenger.sample_bits(num_vars); + let index: usize = ::sample(transcript, num_vars); let var = omega.exp_u64(index as u64); // Evaluate the current sumcheck polynomial as univariate at the sampled variable. @@ -123,32 +121,28 @@ where .iter() .rfold(EF::ZERO, |result, &coeff| result * var + coeff); - // Store evaluation for verifier to read later. - constraint_evals.push(eval); - // Add the evaluation result to the transcript for Fiat-Shamir soundness. - challenger.observe_algebra_element(eval); + transcript.write(eval).unwrap(); // Add the evaluation constraint: poly(point) == eval. sel_statement.add_constraint(var, eval); }); // Return the constructed constraint with the alpha used for linear combination. - let alpha: EF = challenger.sample_algebra_element(); + let alpha: EF = >::sample(transcript); Constraint::new(alpha, eq_statement, sel_statement) } -fn make_constraint_ext( - challenger: &mut Challenger, - constraint_evals: &mut Vec, +fn make_constraint_ext( + transcript: &mut Transcript, num_vars: usize, num_eqs: usize, num_sels: usize, poly: &EvaluationsList, ) -> Constraint where - Challenger: FieldChallenger + GrindingChallenger, + Transcript: Writer + Challenge + ChallengeBits, { // To simulate stir point derivation derive domain generator let omega = F::two_adic_generator(num_vars); @@ -162,7 +156,7 @@ where // - Collect (point, eval) pairs for use in the statement and constraint aggregation. (0..num_eqs).for_each(|_| { // Sample a univariate field element from the prover's challenger. - let point: EF = challenger.sample_algebra_element(); + let point = >::sample(transcript); // Expand it into a `num_vars`-dimensional multilinear point. let point = MultilinearPoint::expand_from_univariate(point, num_vars); @@ -171,10 +165,10 @@ where let eval = poly.evaluate_hypercube_ext::(&point); // Store evaluation for verifier to read later. - constraint_evals.push(eval); + // constraint_evals.push(eval); // Add the evaluation result to the transcript for Fiat-Shamir soundness. - challenger.observe_algebra_element(eval); + transcript.write(eval).unwrap(); // Add the evaluation constraint: poly(point) == eval. eq_statement.add_evaluated_constraint(point, eval); @@ -185,7 +179,7 @@ where // - Collect (var, eval) pairs for use in the statement and constraint aggregation. (0..num_sels).for_each(|_| { // Simulate stir point derivation - let index: usize = challenger.sample_bits(num_vars); + let index = ::sample(transcript, num_vars); let var = omega.exp_u64(index as u64); @@ -194,43 +188,40 @@ where .iter() .rfold(EF::ZERO, |result, &coeff| result * var + coeff); - // Store evaluation for verifier to read later. - constraint_evals.push(eval); - - // Add the evaluation result to the transcript for Fiat-Shamir soundness. - challenger.observe_algebra_element(eval); + // Add the evaluation result to the transcript for Fiat-Shamir soundness.e + transcript.write(eval).unwrap(); // Add the evaluation constraint: poly(point) == eval. sel_statement.add_constraint(var, eval); }); - // Return the constructed constraint with the alpha used for linear combination. - let alpha: EF = challenger.sample_algebra_element(); - - Constraint::new(alpha, eq_statement, sel_statement) + Constraint::new( + >::sample(transcript), + eq_statement, + sel_statement, + ) } -fn read_constraint( - challenger: &mut Challenger, - constraint_evals: &[EF], +fn read_constraint( + transcript: &mut Transcript, num_vars: usize, num_eqs: usize, num_sels: usize, ) -> Constraint where - Challenger: FieldChallenger + GrindingChallenger, + Transcript: Reader + Challenge + ChallengeBits, { // Create a new statement that will hold all reconstructed constraints. let mut eq_statement = EqStatement::initialize(num_vars); // For each point, sample a challenge and read its corresponding evaluation from the proof. - for &eval in constraint_evals.iter().take(num_eqs) { + for _ in 0..num_eqs { + let point = >::sample(transcript); // Sample a univariate challenge and expand to a multilinear point. - let point = - MultilinearPoint::expand_from_univariate(challenger.sample_algebra_element(), num_vars); + let point = MultilinearPoint::expand_from_univariate(point, num_vars); // Observe the evaluation to keep the challenger synchronized (must match prover) - challenger.observe_algebra_element(eval); + let eval = transcript.read().unwrap(); // Add the constraint: poly(point) == eval. eq_statement.add_evaluated_constraint(point, eval); @@ -243,24 +234,21 @@ where let omega = F::two_adic_generator(num_vars); // For each point, sample a challenge and read its corresponding evaluation from the proof. - for i in 0..num_sels { + for _ in 0..num_sels { // Simulate stir point derivation - let index: usize = challenger.sample_bits(num_vars); + let index = ::sample(transcript, num_vars); let var = omega.exp_u64(index as u64); // Read the committed evaluation corresponding to this point from constraint_evals. // Sel evaluations are stored after eq evaluations. - let eval = constraint_evals[num_eqs + i]; - - // Observe the evaluation to keep the challenger synchronized (must match prover) - challenger.observe_algebra_element(eval); + let eval = transcript.read().unwrap(); // Add the constraint: poly(point) == eval. sel_statement.add_constraint(var, eval); } Constraint::new( - challenger.sample_algebra_element(), + >::sample(transcript), eq_statement, sel_statement, ) @@ -304,45 +292,17 @@ fn run_sumcheck_test( let mut rng = SmallRng::seed_from_u64(1); let poly = EvaluationsList::new((0..1 << num_vars).map(|_| rng.random()).collect()); - // PROVER - let (domsep, challenger) = domainsep_and_challenger(); - let mut prover_challenger = challenger.clone(); - - // Initialize proof and challenger - let params = - create_test_protocol_params(folding_factor, InitialPhaseConfig::WithStatementClassic); - let mut proof = WhirProof::::from_protocol_parameters(¶ms, num_vars); - domsep.observe_domain_separator(&mut prover_challenger); - - // Store constraint evaluations for each round (prover writes, verifier reads) - let mut all_constraint_evals: Vec> = Vec::new(); + let (domsep, mut challenger) = domainsep_and_challenger(); + domsep.observe_domain_separator(&mut challenger); - // Create the initial constraint statement - let mut constraint_evals: Vec = Vec::new(); - let constraint = make_constraint( - &mut prover_challenger, - &mut constraint_evals, - num_vars, - num_eqs[0], - num_sels[0], - &poly, - ); - all_constraint_evals.push(constraint_evals); + // PROVER + let mut transcript = FiatShamirWriter::init(challenger.clone()); + let constraint = make_constraint(&mut transcript, num_vars, num_eqs[0], num_sels[0], &poly); // ROUND 0 let folding0 = folding_factor.at_round(0); - // Extract sumcheck data from the initial phase - let InitialPhase::WithStatement { ref mut sumcheck } = proof.initial_phase else { - panic!("Expected WithStatement variant"); - }; - let (mut sumcheck, mut prover_randomness) = SumcheckSingle::from_base_evals( - &poly, - sumcheck, - &mut prover_challenger, - folding0, - 0, - &constraint, - ); + let (mut sumcheck, mut prover_randomness) = + SumcheckSingle::from_base_evals(&mut transcript, &poly, folding0, 0, &constraint).unwrap(); // Track how many variables remain to fold let mut num_vars_inter = num_vars - folding0; @@ -353,27 +313,20 @@ fn run_sumcheck_test( { let folding = folding_factor.at_round(round); // Sample new evaluation constraints and combine them into the sumcheck state - let mut constraint_evals: Vec = Vec::new(); let constraint = make_constraint_ext( - &mut prover_challenger, - &mut constraint_evals, + &mut transcript, num_vars_inter, num_eq_points, num_sel_points, &sumcheck.evals(), ); - all_constraint_evals.push(constraint_evals); // Compute and apply the next folding round - let mut sumcheck_data: SumcheckData = SumcheckData::default(); - prover_randomness.extend(&sumcheck.compute_sumcheck_polynomials( - &mut sumcheck_data, - &mut prover_challenger, - folding, - 0, - Some(constraint), - )); - proof.rounds[round - 1].sumcheck = sumcheck_data; + prover_randomness.extend( + &sumcheck + .compute_sumcheck_polynomials(&mut transcript, folding, 0, Some(constraint)) + .unwrap(), + ); num_vars_inter -= folding; @@ -386,15 +339,11 @@ fn run_sumcheck_test( assert_eq!(num_vars_inter, final_rounds); // FINAL ROUND - let mut sumcheck_data: SumcheckData = SumcheckData::default(); - prover_randomness.extend(&sumcheck.compute_sumcheck_polynomials( - &mut sumcheck_data, - &mut prover_challenger, - final_rounds, - 0, - None, - )); - proof.set_final_sumcheck_data(sumcheck_data); + prover_randomness.extend( + &sumcheck + .compute_sumcheck_polynomials(&mut transcript, final_rounds, 0, None) + .unwrap(), + ); let final_folded_value = sumcheck.evals().as_constant().unwrap(); assert_eq!(sumcheck.num_variables(), 0); @@ -405,11 +354,11 @@ fn run_sumcheck_test( poly.evaluate_hypercube_base(&prover_randomness), final_folded_value ); - // Commit final result to Fiat-Shamir transcript - prover_challenger.observe_algebra_element(final_folded_value); + // Commit final result to Fiat-Shamir transcripte + transcript.write(final_folded_value).unwrap(); + let proof = transcript.finalize(); // VERIFIER - let mut verifer_challenger = challenger; // Running total for the verifier’s sum of constraint combinations let mut sum = EF::ZERO; @@ -423,19 +372,12 @@ fn run_sumcheck_test( // Recompute the same variable count as prover had let mut num_vars_inter = num_vars; - // Apply domain separator to verifier challenger - domsep.observe_domain_separator(&mut verifer_challenger); + let mut transcript = FiatShamirReader::init(proof, challenger); // VERIFY INITIAL ROUND (round 0) { // Reconstruct round constraint from transcript - let constraint = read_constraint( - &mut verifer_challenger, - &all_constraint_evals[0], - num_vars_inter, - num_eqs[0], - num_sels[0], - ); + let constraint = read_constraint(&mut transcript, num_vars_inter, num_eqs[0], num_sels[0]); // Accumulate the weighted sum of constraint values constraint.combine_evals(&mut sum); // Save constraints for later equality check @@ -445,8 +387,8 @@ fn run_sumcheck_test( let folding = folding_factor.at_round(0); verifier_randomness.extend( &verify_initial_sumcheck_rounds( - &proof.initial_phase, - &mut verifer_challenger, + &mut transcript, + InitialPhase::WithStatementClassic, &mut sum, folding, 0, @@ -463,8 +405,7 @@ fn run_sumcheck_test( { // Reconstruct round constraint from transcript let constraint = read_constraint( - &mut verifer_challenger, - &all_constraint_evals[round_idx], + &mut transcript, num_vars_inter, num_eq_points, num_sel_points, @@ -478,14 +419,7 @@ fn run_sumcheck_test( // Note: proof.rounds[round_idx - 1] because rounds are 0-indexed but we start at round 1 let folding = folding_factor.at_round(round_idx); verifier_randomness.extend( - &verify_sumcheck_rounds( - &proof.rounds[round_idx - 1], - &mut verifer_challenger, - &mut sum, - folding, - 0, - ) - .unwrap(), + &verify_standard_sumcheck_rounds(&mut transcript, folding, &mut sum, 0).unwrap(), ); num_vars_inter -= folding; @@ -493,14 +427,7 @@ fn run_sumcheck_test( // Final round check verifier_randomness.extend( - &verify_final_sumcheck_rounds( - proof.final_sumcheck.as_ref(), - &mut verifer_challenger, - &mut sum, - final_rounds, - 0, - ) - .unwrap(), + &verify_standard_sumcheck_rounds(&mut transcript, final_rounds, &mut sum, 0).unwrap(), ); // Check that the randomness vectors are the same @@ -552,50 +479,34 @@ fn run_sumcheck_test_skips( let mut rng = SmallRng::seed_from_u64(1); let poly = EvaluationsList::new((0..1 << num_vars).map(|_| rng.random()).collect()); + let (domsep, mut challenger) = domainsep_and_challenger(); + domsep.observe_domain_separator(&mut challenger); // PROVER SIDE - let (domsep, challenger) = domainsep_and_challenger(); - let mut prover_challenger = challenger.clone(); - - // Initialize proof and challenger - let params = create_test_protocol_params( - folding_factor, - InitialPhaseConfig::WithStatementUnivariateSkip, - ); - let mut proof = WhirProof::::from_protocol_parameters(¶ms, num_vars); - domsep.observe_domain_separator(&mut prover_challenger); - - // Store constraint evaluations for each round (prover writes, verifier reads) - let mut all_constraint_evals: Vec> = Vec::new(); + let mut transcript = FiatShamirWriter::init(challenger.clone()); // Sample and commit initial evaluation constraints - let mut constraint_evals: Vec = Vec::new(); let constraint = make_constraint( - &mut prover_challenger, - &mut constraint_evals, + &mut transcript, num_vars, num_eq_points[0], num_sel_points[0], &poly, ); - all_constraint_evals.push(constraint_evals); constraint.validate_for_skip_case(); // ROUND 0 // Initialize sumcheck with univariate skip (skips K_SKIP_SUMCHECK) let folding0 = folding_factor.at_round(0); - // Extract skip data from the initial phase - let InitialPhase::WithStatementSkip(ref mut skip_data) = proof.initial_phase else { - panic!("Expected WithStatementSkip variant"); - }; + let (mut sumcheck, mut prover_randomness) = SumcheckSingle::with_skip( &poly, - skip_data, - &mut prover_challenger, + &mut transcript, folding0, 0, K_SKIP_SUMCHECK, &constraint, - ); + ) + .unwrap(); // Track how many variables remain after folding let mut num_vars_inter = num_vars - folding0; @@ -609,27 +520,20 @@ fn run_sumcheck_test_skips( { let folding = folding_factor.at_round(round); // Sample new evaluation constraints and combine them into the sumcheck state - let mut constraint_evals: Vec = Vec::new(); let constraint = make_constraint_ext( - &mut prover_challenger, - &mut constraint_evals, + &mut transcript, num_vars_inter, num_eq_points, num_sel_points, &sumcheck.evals(), ); - all_constraint_evals.push(constraint_evals); // Fold the sumcheck polynomial again and extend randomness vector - let mut sumcheck_data: SumcheckData = SumcheckData::default(); - prover_randomness.extend(&sumcheck.compute_sumcheck_polynomials( - &mut sumcheck_data, - &mut prover_challenger, - folding, - 0, - Some(constraint), - )); - proof.rounds[round - 1].sumcheck = sumcheck_data; + prover_randomness.extend( + &sumcheck + .compute_sumcheck_polynomials(&mut transcript, folding, 0, Some(constraint)) + .unwrap(), + ); num_vars_inter -= folding; @@ -642,15 +546,11 @@ fn run_sumcheck_test_skips( assert_eq!(num_vars_inter, final_rounds); // FINAL ROUND - let mut sumcheck_data: SumcheckData = SumcheckData::default(); - prover_randomness.extend(&sumcheck.compute_sumcheck_polynomials( - &mut sumcheck_data, - &mut prover_challenger, - final_rounds, - 0, - None, - )); - proof.set_final_sumcheck_data(sumcheck_data); + prover_randomness.extend( + &sumcheck + .compute_sumcheck_polynomials(&mut transcript, final_rounds, 0, None) + .unwrap(), + ); // After final round, polynomial must collapse to a constant assert_eq!(sumcheck.num_variables(), 0); @@ -671,10 +571,11 @@ fn run_sumcheck_test_skips( final_folded_value ); // Commit final result to Fiat-Shamir transcript - prover_challenger.observe_algebra_element(final_folded_value); + transcript.write(final_folded_value).unwrap(); + let proof = transcript.finalize(); // VERIFIER SIDE - let mut verifier_challenger = challenger; + let mut transcript = FiatShamirReader::init(proof, challenger); // Running total for the verifier's sum of constraint combinations let mut sum = EF::ZERO; @@ -688,15 +589,11 @@ fn run_sumcheck_test_skips( // Recompute the same variable count as prover had let mut num_vars_inter = num_vars; - // Apply domain separator to verifier challenger - domsep.observe_domain_separator(&mut verifier_challenger); - // VERIFY INITIAL ROUND (round 0) { // Reconstruct round constraint from transcript let constraint = read_constraint( - &mut verifier_challenger, - &all_constraint_evals[0], + &mut transcript, num_vars_inter, num_eq_points[0], num_sel_points[0], @@ -710,8 +607,8 @@ fn run_sumcheck_test_skips( let folding = folding_factor.at_round(0); verifier_randomness.extend( &verify_initial_sumcheck_rounds( - &proof.initial_phase, - &mut verifier_challenger, + &mut transcript, + InitialPhase::WithStatementSkip, &mut sum, folding, 0, @@ -730,13 +627,7 @@ fn run_sumcheck_test_skips( .skip(1) { // Reconstruct round constraint from transcript - let constraint = read_constraint( - &mut verifier_challenger, - &all_constraint_evals[round_idx], - num_vars_inter, - num_eq_pts, - num_sel_pts, - ); + let constraint = read_constraint(&mut transcript, num_vars_inter, num_eq_pts, num_sel_pts); // Accumulate the weighted sum of constraint values constraint.combine_evals(&mut sum); // Save constraints for later equality check @@ -745,14 +636,7 @@ fn run_sumcheck_test_skips( // Extend r with verifier's folding challenges let folding = folding_factor.at_round(round_idx); verifier_randomness.extend( - &verify_sumcheck_rounds( - &proof.rounds[round_idx - 1], - &mut verifier_challenger, - &mut sum, - folding, - 0, - ) - .unwrap(), + &verify_standard_sumcheck_rounds(&mut transcript, folding, &mut sum, 0).unwrap(), ); num_vars_inter -= folding; @@ -760,14 +644,7 @@ fn run_sumcheck_test_skips( // FINAL FOLDING verifier_randomness.extend( - &verify_final_sumcheck_rounds( - proof.final_sumcheck.as_ref(), - &mut verifier_challenger, - &mut sum, - final_rounds, - 0, - ) - .unwrap(), + &verify_standard_sumcheck_rounds(&mut transcript, final_rounds, &mut sum, 0).unwrap(), ); // Check that the randomness vectors are the same @@ -803,43 +680,19 @@ fn run_sumcheck_test_svo( let poly = EvaluationsList::new((0..1 << num_vars).map(|_| rng.random()).collect()); // PROVER - let (domsep, challenger) = domainsep_and_challenger(); - let mut prover_challenger = challenger.clone(); + let (domsep, mut challenger) = domainsep_and_challenger(); + domsep.observe_domain_separator(&mut challenger); + let mut transcript = FiatShamirWriter::init(challenger.clone()); // Initialize proof and challenger - let params = create_test_protocol_params(folding_factor, InitialPhaseConfig::WithStatementSvo); - let mut proof = WhirProof::::from_protocol_parameters(¶ms, num_vars); - domsep.observe_domain_separator(&mut prover_challenger); - - // Store constraint evaluations for each round (prover writes, verifier reads) - let mut all_constraint_evals: Vec> = Vec::new(); - - // Create the initial constraint statement - let mut constraint_evals: Vec = Vec::new(); - let constraint = make_constraint( - &mut prover_challenger, - &mut constraint_evals, - num_vars, - num_eqs[0], - num_sels[0], - &poly, - ); - all_constraint_evals.push(constraint_evals); + let constraint = make_constraint(&mut transcript, num_vars, num_eqs[0], num_sels[0], &poly); // ROUND 0 let folding0 = folding_factor.at_round(0); // Extract sumcheck data from the initial phase - let InitialPhase::WithStatementSvo { ref mut sumcheck } = proof.initial_phase else { - panic!("Expected WithStatementSvo variant"); - }; - let (mut sumcheck, mut prover_randomness) = SumcheckSingle::from_base_evals_svo( - &poly, - sumcheck, - &mut prover_challenger, - folding0, - 0, - &constraint, - ); + let (mut sumcheck, mut prover_randomness) = + SumcheckSingle::from_base_evals_svo(&mut transcript, &poly, folding0, 0, &constraint) + .unwrap(); // Track how many variables remain to fold let mut num_vars_inter = num_vars - folding0; @@ -850,27 +703,20 @@ fn run_sumcheck_test_svo( { let folding = folding_factor.at_round(round); // Sample new evaluation constraints and combine them into the sumcheck state - let mut constraint_evals: Vec = Vec::new(); let constraint = make_constraint_ext( - &mut prover_challenger, - &mut constraint_evals, + &mut transcript, num_vars_inter, num_eq_points, num_sel_points, &sumcheck.evals(), ); - all_constraint_evals.push(constraint_evals); // Compute and apply the next folding round - let mut sumcheck_data: SumcheckData = SumcheckData::default(); - prover_randomness.extend(&sumcheck.compute_sumcheck_polynomials( - &mut sumcheck_data, - &mut prover_challenger, - folding, - 0, - Some(constraint), - )); - proof.rounds[round - 1].sumcheck = sumcheck_data; + prover_randomness.extend( + &sumcheck + .compute_sumcheck_polynomials(&mut transcript, folding, 0, Some(constraint)) + .unwrap(), + ); num_vars_inter -= folding; @@ -883,16 +729,11 @@ fn run_sumcheck_test_svo( assert_eq!(num_vars_inter, final_rounds); // FINAL ROUND - let mut sumcheck_data: SumcheckData = SumcheckData::default(); - prover_randomness.extend(&sumcheck.compute_sumcheck_polynomials( - &mut sumcheck_data, - &mut prover_challenger, - final_rounds, - 0, - None, - )); - proof.set_final_sumcheck_data(sumcheck_data); - + prover_randomness.extend( + &sumcheck + .compute_sumcheck_polynomials(&mut transcript, final_rounds, 0, None) + .unwrap(), + ); assert_eq!(sumcheck.num_variables(), 0); assert_eq!(sumcheck.num_evals(), 1); @@ -903,10 +744,11 @@ fn run_sumcheck_test_svo( final_folded_value ); // Commit final result to Fiat-Shamir transcript - prover_challenger.observe_algebra_element(final_folded_value); + transcript.write(final_folded_value).unwrap(); + let proof = transcript.finalize(); // VERIFIER - let mut verifier_challenger = challenger; + let mut transcript = FiatShamirReader::init(proof, challenger); // Running total for the verifier's sum of constraint combinations let mut sum = EF::ZERO; @@ -920,19 +762,10 @@ fn run_sumcheck_test_svo( // Recompute the same variable count as prover had let mut num_vars_inter = num_vars; - // Apply domain separator to verifier challenger - domsep.observe_domain_separator(&mut verifier_challenger); - // VERIFY INITIAL ROUND (round 0) { // Reconstruct round constraint from transcript - let constraint = read_constraint( - &mut verifier_challenger, - &all_constraint_evals[0], - num_vars_inter, - num_eqs[0], - num_sels[0], - ); + let constraint = read_constraint(&mut transcript, num_vars_inter, num_eqs[0], num_sels[0]); // Accumulate the weighted sum of constraint values constraint.combine_evals(&mut sum); // Save constraints for later equality check @@ -942,8 +775,8 @@ fn run_sumcheck_test_svo( let folding = folding_factor.at_round(0); verifier_randomness.extend( &verify_initial_sumcheck_rounds( - &proof.initial_phase, - &mut verifier_challenger, + &mut transcript, + InitialPhase::WithStatementSvo, &mut sum, folding, 0, @@ -959,13 +792,7 @@ fn run_sumcheck_test_svo( num_eqs.iter().zip(num_sels.iter()).enumerate().skip(1) { // Reconstruct round constraint from transcript - let constraint = read_constraint( - &mut verifier_challenger, - &all_constraint_evals[round_idx], - num_vars_inter, - num_eq_pts, - num_sel_pts, - ); + let constraint = read_constraint(&mut transcript, num_vars_inter, num_eq_pts, num_sel_pts); // Accumulate the weighted sum of constraint values constraint.combine_evals(&mut sum); // Save constraints for later equality check @@ -974,14 +801,7 @@ fn run_sumcheck_test_svo( // Extend r with verifier's folding challenges let folding = folding_factor.at_round(round_idx); verifier_randomness.extend( - &verify_sumcheck_rounds( - &proof.rounds[round_idx - 1], - &mut verifier_challenger, - &mut sum, - folding, - 0, - ) - .unwrap(), + &verify_standard_sumcheck_rounds(&mut transcript, folding, &mut sum, 0).unwrap(), ); num_vars_inter -= folding; @@ -989,14 +809,7 @@ fn run_sumcheck_test_svo( // FINAL FOLDING verifier_randomness.extend( - &verify_final_sumcheck_rounds( - proof.final_sumcheck.as_ref(), - &mut verifier_challenger, - &mut sum, - final_rounds, - 0, - ) - .unwrap(), + &verify_standard_sumcheck_rounds(&mut transcript, final_rounds, &mut sum, 0).unwrap(), ); // Check that the randomness vectors are the same diff --git a/src/whir/committer/reader.rs b/src/whir/committer/reader.rs index 393b0c8a..aae40a9e 100644 --- a/src/whir/committer/reader.rs +++ b/src/whir/committer/reader.rs @@ -1,12 +1,15 @@ use core::{fmt::Debug, ops::Deref}; -use p3_challenger::{FieldChallenger, GrindingChallenger}; -use p3_field::{ExtensionField, Field, TwoAdicField}; +use p3_field::{ExtensionField, Field}; use p3_symmetric::Hash; use crate::{ + fiat_shamir::{ + errors::FiatShamirError, + transcript::{Challenge, Reader}, + }, poly::multilinear::MultilinearPoint, - whir::{constraints::statement::EqStatement, parameters::WhirConfig, proof::WhirProof}, + whir::{constraints::statement::EqStatement, parameters::WhirConfig}, }; /// Represents a parsed commitment from the prover in the WHIR protocol. @@ -15,9 +18,6 @@ use crate::{ /// query points and their corresponding answers, which are required for verifier checks. #[derive(Debug, Clone)] pub struct ParsedCommitment { - /// Number of variables in the committed polynomial. - pub num_variables: usize, - /// Merkle root of the committed evaluation table. /// /// This hash is used by the verifier to check Merkle proofs of queried evaluations. @@ -41,9 +41,7 @@ where /// /// # Arguments /// - /// - `verifier_state`: The verifier's Fiat-Shamir state from which data is read. - /// - `proof`: The proof data the verifier reads (currently unused, reserved for RF flow). - /// - `challenger`: The verifier's challenger (currently unused, reserved for RF flow). + /// - `transcript`: The verifier's Fiat-Shamir state from which data is read. /// - `num_variables`: Number of variables in the committed multilinear polynomial. /// - `ood_samples`: Number of out-of-domain points the verifier expects to query. /// @@ -56,65 +54,35 @@ where /// - The prover's claimed answers at those points. /// /// This is used to verify consistency of polynomial commitments in WHIR. - pub fn parse( - proof: &WhirProof, - challenger: &mut Challenger, + pub fn parse( + transcript: &mut Transcript, num_variables: usize, ood_samples: usize, - ) -> ParsedCommitment> + ) -> Result>, FiatShamirError> where - F: TwoAdicField, - EF: ExtensionField + TwoAdicField, - Challenger: FieldChallenger + GrindingChallenger, + EF: ExtensionField, + Transcript: Reader<[F; DIGEST_ELEMS]> + Reader + Challenge, { - Self::parse_with_round(proof, challenger, num_variables, ood_samples, None) - } - - pub fn parse_with_round( - proof: &WhirProof, - challenger: &mut Challenger, - num_variables: usize, - ood_samples: usize, - round_index: Option, - ) -> ParsedCommitment> - where - F: TwoAdicField, - EF: ExtensionField + TwoAdicField, - Challenger: FieldChallenger + GrindingChallenger, - { - let (root_array, ood_answers) = round_index.map_or_else( - || (proof.initial_commitment, proof.initial_ood_answers.clone()), - |idx| { - let round_proof = &proof.rounds[idx]; - (round_proof.commitment, round_proof.ood_answers.clone()) - }, - ); - - // Convert to Hash type - let root: Hash = root_array.into(); - - // Observe the root in the challenger to match prover's transcript - challenger.observe_slice(&root_array); - + // Read the Merkle root + let root: [F; DIGEST_ELEMS] = transcript.read()?; // Construct equality constraints for all out-of-domain (OOD) samples. // Each constraint enforces that the committed polynomial evaluates to the // claimed `ood_answer` at the corresponding `ood_point`, using a univariate // equality weight over `num_variables` inputs. let mut ood_statement = EqStatement::initialize(num_variables); - (0..ood_samples).for_each(|i| { - let point = challenger.sample_algebra_element(); + (0..ood_samples).try_for_each(|_| { + let point = transcript.sample(); let point = MultilinearPoint::expand_from_univariate(point, num_variables); - let eval = ood_answers[i]; - challenger.observe_algebra_element(eval); + let eval = transcript.read()?; ood_statement.add_evaluated_constraint(point, eval); - }); + Ok(()) + })?; // Return a structured representation of the commitment. - ParsedCommitment { - num_variables, - root, + Ok(ParsedCommitment { + root: root.into(), ood_statement, - } + }) } } @@ -123,27 +91,21 @@ where /// The `CommitmentReader` wraps the WHIR configuration and provides a convenient /// method to extract a `ParsedCommitment` by reading values from the Fiat-Shamir transcript. #[derive(Debug)] -pub struct CommitmentReader<'a, EF, F, H, C, Challenger>( +pub struct CommitmentReader<'a, F, EF, Hasher, Compress>( /// Reference to the verifier’s configuration object. /// /// This contains all parameters needed to parse the commitment, /// including how many out-of-domain samples are expected. - &'a WhirConfig, -) -where - F: Field, - EF: ExtensionField; + &'a WhirConfig, +); -impl<'a, EF, F, H, C, Challenger> CommitmentReader<'a, EF, F, H, C, Challenger> -where - F: TwoAdicField, - EF: ExtensionField + TwoAdicField, - Challenger: FieldChallenger + GrindingChallenger, +impl<'a, F: Field, EF: ExtensionField, Hasher, Compress> + CommitmentReader<'a, F, EF, Hasher, Compress> { /// Create a new commitment reader from a WHIR configuration. /// /// This allows the verifier to parse a commitment from the Fiat-Shamir transcript. - pub const fn new(params: &'a WhirConfig) -> Self { + pub const fn new(params: &'a WhirConfig) -> Self { Self(params) } @@ -151,26 +113,27 @@ where /// /// Reads the Merkle root and out-of-domain (OOD) challenge points and answers /// expected for verifying the committed polynomial. - pub fn parse_commitment( + pub fn parse_commitment( &self, - proof: &WhirProof, - challenger: &mut Challenger, - ) -> ParsedCommitment> { - ParsedCommitment::<_, Hash>::parse( - proof, - challenger, + transcript: &mut Transcript, + ) -> Result>, FiatShamirError> + where + Transcript: Reader<[F; DIGEST_ELEMS]> + Reader + Challenge, + { + ParsedCommitment::<_, [F; DIGEST_ELEMS]>::parse( + transcript, self.num_variables, self.commitment_ood_samples, ) } } -impl Deref for CommitmentReader<'_, EF, F, H, C, Challenger> +impl Deref for CommitmentReader<'_, F, EF, Hasher, Compress> where F: Field, EF: ExtensionField, { - type Target = WhirConfig; + type Target = WhirConfig; fn deref(&self) -> &Self::Target { self.0 @@ -190,12 +153,10 @@ mod tests { use super::*; use crate::{ + fiat_shamir::transcript::{FiatShamirReader, FiatShamirWriter}, parameters::{FoldingFactor, ProtocolParameters, errors::SecurityAssumption}, poly::evals::EvaluationsList, - whir::{ - DomainSeparator, committer::writer::CommitmentWriter, parameters::InitialPhaseConfig, - proof::WhirProof, - }, + whir::{DomainSeparator, committer::writer::CommitmentWriter, parameters::InitialPhase}, }; type F = BabyBear; @@ -210,14 +171,10 @@ mod tests { /// This sets up the protocol parameters and multivariate polynomial settings, /// with control over number of variables and OOD samples. #[allow(clippy::type_complexity)] - fn make_test_params( + fn make_test_params( num_variables: usize, ood_samples: usize, - ) -> ( - WhirConfig, - SmallRng, - WhirProof, - ) { + ) -> (WhirConfig, SmallRng) { let mut rng = SmallRng::seed_from_u64(1); let perm = Perm::new_from_rng_128(&mut rng); @@ -229,7 +186,7 @@ mod tests { // Define core protocol parameters for WHIR. let whir_params = ProtocolParameters { - initial_phase_config: InitialPhaseConfig::WithStatementClassic, + initial_phase: InitialPhase::WithStatementClassic, security_level: 100, pow_bits: 10, rs_domain_initial_reduction_factor: 1, @@ -241,23 +198,19 @@ mod tests { }; // Construct full WHIR configuration with MV polynomial shape and protocol rules. - let mut config = WhirConfig::new(num_variables, whir_params.clone()); + let mut config = WhirConfig::new(num_variables, whir_params); // Set the number of OOD samples for commitment testing. config.commitment_ood_samples = ood_samples; // Return the config and a thread-local random number generator. - ( - config, - SmallRng::seed_from_u64(1), - WhirProof::from_protocol_parameters(&whir_params, num_variables), - ) + (config, SmallRng::seed_from_u64(1)) } #[test] fn test_commitment_roundtrip_with_ood() { // Create WHIR config with 5 variables and 3 OOD samples, plus a random number generator. - let (params, mut rng, mut proof) = make_test_params(5, 3); + let (params, mut rng) = make_test_params(5, 3); // Create a random degree-5 multilinear polynomial (32 coefficients). let polynomial = EvaluationsList::new((0..32).map(|_| rng.random()).collect()); @@ -270,26 +223,24 @@ mod tests { // Set up Fiat-Shamir transcript and commit the protocol parameters. let mut ds = DomainSeparator::new(vec![]); - ds.commit_statement::<_, _, _, 8>(¶ms); + ds.commit_statement::<_, _, 8>(¶ms); // Create the prover state from the transcript. let mut rng = SmallRng::seed_from_u64(1); - let challenger = MyChallenger::new(Perm::new_from_rng_128(&mut rng)); - let mut prover_challenger = challenger.clone(); - ds.observe_domain_separator(&mut prover_challenger); + let mut challenger = MyChallenger::new(Perm::new_from_rng_128(&mut rng)); + ds.observe_domain_separator(&mut challenger); + let mut trancript = FiatShamirWriter::init(challenger.clone()); // Commit the polynomial and obtain a witness (root, Merkle proof, OOD evaluations). - let witness = committer - .commit(&dft, &mut proof, &mut prover_challenger, polynomial) - .unwrap(); + let witness = committer.commit(&dft, &mut trancript, polynomial).unwrap(); // Simulate verifier state using transcript view of prover's nonce string. - let mut verifier_challenger = challenger; - ds.observe_domain_separator(&mut verifier_challenger); + let proof = trancript.finalize(); + let mut trancript = FiatShamirReader::init(proof, challenger); // Create a commitment reader and parse the commitment from verifier state. let reader = CommitmentReader::new(¶ms); - let parsed = reader.parse_commitment::<8>(&proof, &mut verifier_challenger); + let parsed = reader.parse_commitment::<_, 8>(&mut trancript).unwrap(); // Ensure the Merkle root matches between prover and parsed result. assert_eq!(parsed.root, witness.prover_data.root()); @@ -301,7 +252,7 @@ mod tests { #[test] fn test_commitment_roundtrip_no_ood() { // Create WHIR config with 4 variables and *no* OOD samples. - let (params, mut rng, mut proof) = make_test_params(4, 0); + let (params, mut rng) = make_test_params(4, 0); // Generate a polynomial with 16 random coefficients. let polynomial = EvaluationsList::new((0..16).map(|_| rng.random()).collect()); @@ -312,26 +263,24 @@ mod tests { // Begin the transcript and commit to the statement parameters. let mut ds = DomainSeparator::new(vec![]); - ds.commit_statement::<_, _, _, 8>(¶ms); + ds.commit_statement::<_, _, 8>(¶ms); // Create the prover state from the transcript. let mut rng = SmallRng::seed_from_u64(1); - let challenger = MyChallenger::new(Perm::new_from_rng_128(&mut rng)); - let mut prover_challenger = challenger.clone(); - ds.observe_domain_separator(&mut prover_challenger); + let mut challenger = MyChallenger::new(Perm::new_from_rng_128(&mut rng)); + ds.observe_domain_separator(&mut challenger); + let mut transcript = FiatShamirWriter::init(challenger.clone()); // Commit the polynomial to obtain the witness. - let witness = committer - .commit(&dft, &mut proof, &mut prover_challenger, polynomial) - .unwrap(); + let witness = committer.commit(&dft, &mut transcript, polynomial).unwrap(); + let proof = transcript.finalize(); // Initialize the verifier view of the transcript. - let mut verifier_challenger = challenger; - ds.observe_domain_separator(&mut verifier_challenger); + let mut transcript = FiatShamirReader::init(proof, challenger); // Parse the commitment from verifier transcript. let reader = CommitmentReader::new(¶ms); - let parsed = reader.parse_commitment::<8>(&proof, &mut verifier_challenger); + let parsed = reader.parse_commitment::<_, 8>(&mut transcript).unwrap(); // Validate the Merkle root matches. assert_eq!(parsed.root, witness.prover_data.root()); @@ -343,7 +292,7 @@ mod tests { #[test] fn test_commitment_roundtrip_large_polynomial() { // Create config with 10 variables and 5 OOD samples. - let (params, mut rng, mut proof) = make_test_params(10, 5); + let (params, mut rng) = make_test_params(10, 5); // Generate a large polynomial with 1024 random coefficients. let polynomial = EvaluationsList::new((0..1024).map(|_| rng.random()).collect()); @@ -354,26 +303,25 @@ mod tests { // Start a new transcript and commit to the public parameters. let mut ds = DomainSeparator::new(vec![]); - ds.commit_statement::<_, _, _, 8>(¶ms); + ds.commit_statement::<_, _, 8>(¶ms); // Create prover state from the transcript. let mut rng = SmallRng::seed_from_u64(1); - let challenger = MyChallenger::new(Perm::new_from_rng_128(&mut rng)); - let mut prover_challenger = challenger.clone(); - ds.observe_domain_separator(&mut prover_challenger); + let mut challenger = MyChallenger::new(Perm::new_from_rng_128(&mut rng)); + ds.observe_domain_separator(&mut challenger); + + let mut transcript = FiatShamirWriter::init(challenger.clone()); // Commit the polynomial and obtain the witness. - let witness = committer - .commit(&dft, &mut proof, &mut prover_challenger, polynomial) - .unwrap(); + let witness = committer.commit(&dft, &mut transcript, polynomial).unwrap(); + let proof = transcript.finalize(); // Initialize verifier view from prover's transcript string. - let mut verifier_challenger = challenger; - ds.observe_domain_separator(&mut verifier_challenger); + let mut transcript = FiatShamirReader::init(proof, challenger); // Parse the commitment from verifier's transcript. let reader = CommitmentReader::new(¶ms); - let parsed = reader.parse_commitment::<8>(&proof, &mut verifier_challenger); + let parsed = reader.parse_commitment::<_, 8>(&mut transcript).unwrap(); // Check Merkle root and OOD answers match. assert_eq!(parsed.root, witness.prover_data.root()); @@ -383,7 +331,7 @@ mod tests { #[test] fn test_oods_constraints_correctness() { // Create WHIR config with 4 variables and 2 OOD samples. - let (params, mut rng, mut proof) = make_test_params(4, 2); + let (params, mut rng) = make_test_params(4, 2); // Generate a multilinear polynomial with 16 coefficients. let polynomial = EvaluationsList::new((0..16).map(|_| rng.random()).collect()); @@ -394,25 +342,23 @@ mod tests { // Set up Fiat-Shamir transcript and commit to the public parameters. let mut ds = DomainSeparator::new(vec![]); - ds.commit_statement::<_, _, _, 8>(¶ms); + ds.commit_statement::<_, _, 8>(¶ms); // Create the prover state from the transcript. let mut rng = SmallRng::seed_from_u64(1); - let challenger = MyChallenger::new(Perm::new_from_rng_128(&mut rng)); - let mut prover_challenger = challenger.clone(); - ds.observe_domain_separator(&mut prover_challenger); + let mut challenger = MyChallenger::new(Perm::new_from_rng_128(&mut rng)); + ds.observe_domain_separator(&mut challenger); - let witness = committer - .commit(&dft, &mut proof, &mut prover_challenger, polynomial) - .unwrap(); + let mut transcript = FiatShamirWriter::init(challenger.clone()); + let witness = committer.commit(&dft, &mut transcript, polynomial).unwrap(); + let proof = transcript.finalize(); // Initialize the verifier view of the transcript. - let mut verifier_challenger = challenger; - ds.observe_domain_separator(&mut verifier_challenger); + let mut transcript = FiatShamirReader::init(proof, challenger); // Parse the commitment from the verifier's state. let reader = CommitmentReader::new(¶ms); - let parsed = reader.parse_commitment::<8>(&proof, &mut verifier_challenger); + let parsed = reader.parse_commitment::<_, 8>(&mut transcript).unwrap(); // Each constraint should have correct univariate weight, sum, and flag. for (i, (point, &eval)) in parsed.ood_statement.iter().enumerate() { diff --git a/src/whir/committer/writer.rs b/src/whir/committer/writer.rs index ff1f12a2..25911944 100644 --- a/src/whir/committer/writer.rs +++ b/src/whir/committer/writer.rs @@ -1,7 +1,6 @@ use alloc::sync::Arc; use core::ops::Deref; -use p3_challenger::{FieldChallenger, GrindingChallenger}; use p3_commit::Mmcs; use p3_dft::TwoAdicSubgroupDft; use p3_field::{ExtensionField, Field, TwoAdicField}; @@ -13,12 +12,12 @@ use tracing::{info_span, instrument}; use super::Witness; use crate::{ - fiat_shamir::errors::FiatShamirError, - poly::{evals::EvaluationsList, multilinear::MultilinearPoint}, - whir::{ - committer::DenseMatrix, constraints::statement::EqStatement, parameters::WhirConfig, - proof::WhirProof, + fiat_shamir::{ + errors::FiatShamirError, + transcript::{Challenge, Writer}, }, + poly::{evals::EvaluationsList, multilinear::MultilinearPoint}, + whir::{committer::DenseMatrix, constraints::statement::EqStatement, parameters::WhirConfig}, }; /// Responsible for committing polynomials using a Merkle-based scheme. @@ -28,22 +27,16 @@ use crate::{ /// /// It provides a commitment that can be used for proof generation and verification. #[derive(Debug)] -pub struct CommitmentWriter<'a, EF, F, H, C, Challenger>( +pub struct CommitmentWriter<'a, F, EF, Hash, Compress>( /// Reference to the WHIR protocol configuration. - &'a WhirConfig, -) -where - F: Field, - EF: ExtensionField; + &'a WhirConfig, +); -impl<'a, EF, F, H, C, Challenger> CommitmentWriter<'a, EF, F, H, C, Challenger> -where - F: TwoAdicField, - EF: ExtensionField + TwoAdicField, - Challenger: FieldChallenger + GrindingChallenger, +impl<'a, F: TwoAdicField, EF: ExtensionField, Hash, Compress> + CommitmentWriter<'a, F, EF, Hash, Compress> { /// Create a new writer that borrows the WHIR protocol configuration. - pub const fn new(params: &'a WhirConfig) -> Self { + pub const fn new(params: &'a WhirConfig) -> Self { Self(params) } @@ -57,21 +50,21 @@ where /// - Computes out-of-domain (OOD) challenge points and their evaluations. /// - Returns a `Witness` containing the commitment data. #[instrument(skip_all)] - pub fn commit, const DIGEST_ELEMS: usize>( + pub fn commit, Transcript, const DIGEST_ELEMS: usize>( &self, dft: &Dft, - proof: &mut WhirProof, - challenger: &mut Challenger, + transcript: &mut Transcript, polynomial: EvaluationsList, ) -> Result, DIGEST_ELEMS>, FiatShamirError> where - H: CryptographicHasher + Hash: CryptographicHasher + CryptographicHasher + Sync, - C: PseudoCompressionFunction<[F; DIGEST_ELEMS], 2> + Compress: PseudoCompressionFunction<[F; DIGEST_ELEMS], 2> + PseudoCompressionFunction<[F::Packing; DIGEST_ELEMS], 2> + Sync, [F; DIGEST_ELEMS]: Serialize + for<'de> Deserialize<'de>, + Transcript: Writer<[F; DIGEST_ELEMS]> + Writer + Challenge, { // Transpose for reverse variable order // And then pad with zeros @@ -95,29 +88,26 @@ where .in_scope(|| dft.dft_batch(padded).to_row_major_matrix()); // Commit to the Merkle tree - let merkle_tree = MerkleTreeMmcs::::new( - self.merkle_hash.clone(), - self.merkle_compress.clone(), - ); + let merkle_tree = + MerkleTreeMmcs::::new( + self.merkle_hash.clone(), + self.merkle_compress.clone(), + ); let (root, prover_data) = info_span!("commit_matrix").in_scope(|| merkle_tree.commit_matrix(folded_matrix)); - - proof.initial_commitment = *root.as_ref(); - challenger.observe_slice(root.as_ref()); + Writer::<[F; DIGEST_ELEMS]>::write(transcript, *root.as_ref())?; let mut ood_statement = EqStatement::initialize(self.num_variables); - (0..self.0.commitment_ood_samples).for_each(|_| { + (0..self.0.commitment_ood_samples).try_for_each(|_| { // Generate OOD points from ProverState randomness - let point = MultilinearPoint::expand_from_univariate( - challenger.sample_algebra_element(), - self.num_variables, - ); + let var: EF = transcript.sample(); + let point = MultilinearPoint::expand_from_univariate(var, self.num_variables); let eval = info_span!("ood evaluation") .in_scope(|| polynomial.evaluate_hypercube_base(&point)); - proof.initial_ood_answers.push(eval); - challenger.observe_algebra_element(eval); + transcript.write(eval)?; ood_statement.add_evaluated_constraint(point, eval); - }); + Ok(()) + })?; // Return the witness containing the polynomial, Merkle tree, and OOD results. Ok(Witness { @@ -128,13 +118,11 @@ where } } -impl Deref for CommitmentWriter<'_, EF, F, H, C, Challenger> +impl Deref for CommitmentWriter<'_, F, EF, Hash, Compress> where F: Field, - EF: ExtensionField, { - type Target = WhirConfig; - + type Target = WhirConfig; fn deref(&self) -> &Self::Target { self.0 } @@ -152,9 +140,9 @@ mod tests { use super::*; use crate::{ - fiat_shamir::domain_separator::DomainSeparator, + fiat_shamir::{domain_separator::DomainSeparator, transcript::FiatShamirWriter}, parameters::{FoldingFactor, ProtocolParameters, errors::SecurityAssumption}, - whir::parameters::InitialPhaseConfig, + whir::parameters::InitialPhase, }; type F = BabyBear; @@ -180,7 +168,7 @@ mod tests { let merkle_compress = MyCompress::new(perm); let whir_params = ProtocolParameters { - initial_phase_config: InitialPhaseConfig::WithStatementClassic, + initial_phase: InitialPhase::WithStatementClassic, security_level, pow_bits, rs_domain_initial_reduction_factor: 1, @@ -195,31 +183,27 @@ mod tests { }; // Define multivariate parameters for the polynomial. - let params = WhirConfig::::new( - num_variables, - whir_params.clone(), - ); + let params = WhirConfig::::new(num_variables, whir_params); // Generate a random polynomial with 32 coefficients. let mut rng = SmallRng::seed_from_u64(1); let polynomial = EvaluationsList::::new(vec![rng.random(); 32]); - let mut proof = WhirProof::::from_protocol_parameters(&whir_params, num_variables); - // Set up the DomainSeparator and initialize a ProverState narg_string. - let mut domainsep: DomainSeparator = DomainSeparator::new(vec![]); - domainsep.commit_statement::<_, _, _, 8>(¶ms); - domainsep.add_whir_proof::<_, _, _, 8>(¶ms); + let mut domainsep = DomainSeparator::new(vec![]); + domainsep.commit_statement::<_, _, 8>(¶ms); + domainsep.add_whir_proof::<_, _, 8>(¶ms); let mut rng = SmallRng::seed_from_u64(1); let mut challenger = MyChallenger::new(Perm::new_from_rng_128(&mut rng)); domainsep.observe_domain_separator(&mut challenger); + let mut transcript: FiatShamirWriter = FiatShamirWriter::init(challenger); // Run the Commitment Phase let committer = CommitmentWriter::new(¶ms); let dft = Radix2DFTSmallBatch::::default(); let witness = committer - .commit(&dft, &mut proof, &mut challenger, polynomial.clone()) + .commit(&dft, &mut transcript, polynomial.clone()) .unwrap(); // Ensure OOD (out-of-domain) points are generated. @@ -237,7 +221,7 @@ mod tests { // Check that OOD answers match expected evaluations for (i, (ood_point, ood_eval)) in witness.ood_statement.iter().enumerate() { - let expected_eval = polynomial.evaluate_hypercube_base(ood_point); + let expected_eval = polynomial.evaluate_hypercube_base::(ood_point); assert_eq!( *ood_eval, expected_eval, "OOD answer at index {i} should match expected evaluation" @@ -261,7 +245,7 @@ mod tests { let merkle_compress = MyCompress::new(perm); let whir_params = ProtocolParameters { - initial_phase_config: InitialPhaseConfig::WithStatementClassic, + initial_phase: InitialPhase::WithStatementClassic, security_level, pow_bits, rs_domain_initial_reduction_factor: 1, @@ -275,28 +259,22 @@ mod tests { starting_log_inv_rate: starting_rate, }; - let params = WhirConfig::::new( - num_variables, - whir_params.clone(), - ); + let params = WhirConfig::::new(num_variables, whir_params); let mut rng = SmallRng::seed_from_u64(1); let polynomial = EvaluationsList::::new(vec![rng.random(); 1024]); - let mut proof = WhirProof::::from_protocol_parameters(&whir_params, num_variables); - let mut domainsep = DomainSeparator::new(vec![]); - domainsep.commit_statement::<_, _, _, 8>(¶ms); + domainsep.commit_statement::<_, _, 8>(¶ms); let mut rng = SmallRng::seed_from_u64(1); let mut challenger = MyChallenger::new(Perm::new_from_rng_128(&mut rng)); domainsep.observe_domain_separator(&mut challenger); + let mut transcript = FiatShamirWriter::init(challenger); let dft = Radix2DFTSmallBatch::::default(); let committer = CommitmentWriter::new(¶ms); - let _ = committer - .commit(&dft, &mut proof, &mut challenger, polynomial) - .unwrap(); + let _ = committer.commit(&dft, &mut transcript, polynomial).unwrap(); } #[test] @@ -315,7 +293,7 @@ mod tests { let merkle_compress = MyCompress::new(perm); let whir_params = ProtocolParameters { - initial_phase_config: InitialPhaseConfig::WithStatementClassic, + initial_phase: InitialPhase::WithStatementClassic, security_level, pow_bits, rs_domain_initial_reduction_factor: 1, @@ -329,10 +307,7 @@ mod tests { starting_log_inv_rate: starting_rate, }; - let mut params = WhirConfig::::new( - num_variables, - whir_params.clone(), - ); + let mut params = WhirConfig::::new(num_variables, whir_params); // Explicitly set OOD samples to 0 params.commitment_ood_samples = 0; @@ -340,21 +315,17 @@ mod tests { let mut rng = SmallRng::seed_from_u64(1); let polynomial = EvaluationsList::::new(vec![rng.random(); 32]); - let mut proof = WhirProof::::from_protocol_parameters(&whir_params, num_variables); - let mut domainsep = DomainSeparator::new(vec![]); - domainsep.commit_statement::<_, _, _, 8>(¶ms); + domainsep.commit_statement::<_, _, 8>(¶ms); let mut rng = SmallRng::seed_from_u64(1); let mut challenger = MyChallenger::new(Perm::new_from_rng_128(&mut rng)); - domainsep.observe_domain_separator(&mut challenger); + let mut transcript = FiatShamirWriter::init(challenger); let dft = Radix2DFTSmallBatch::::default(); let committer = CommitmentWriter::new(¶ms); - let witness = committer - .commit(&dft, &mut proof, &mut challenger, polynomial) - .unwrap(); + let witness = committer.commit(&dft, &mut transcript, polynomial).unwrap(); assert!( witness.ood_statement.is_empty(), diff --git a/src/whir/mod.rs b/src/whir/mod.rs index 98754b82..5b62e2a0 100644 --- a/src/whir/mod.rs +++ b/src/whir/mod.rs @@ -3,26 +3,27 @@ use alloc::{vec, vec::Vec}; use committer::{reader::CommitmentReader, writer::CommitmentWriter}; use constraints::statement::EqStatement; use p3_baby_bear::{BabyBear, Poseidon2BabyBear}; -use p3_challenger::{DuplexChallenger, FieldChallenger}; +use p3_challenger::DuplexChallenger; use p3_dft::Radix2DFTSmallBatch; use p3_field::{PrimeCharacteristicRing, extension::BinomialExtensionField}; use p3_symmetric::{PaddingFreeSponge, TruncatedPermutation}; -use parameters::{InitialPhaseConfig, WhirConfig}; +use parameters::{InitialPhase, WhirConfig}; use prover::Prover; use rand::{SeedableRng, rngs::SmallRng}; use verifier::Verifier; use crate::{ - fiat_shamir::domain_separator::DomainSeparator, + fiat_shamir::{ + domain_separator::DomainSeparator, + transcript::{Challenge, FiatShamirReader, FiatShamirWriter}, + }, parameters::{FoldingFactor, ProtocolParameters, errors::SecurityAssumption}, poly::{evals::EvaluationsList, multilinear::MultilinearPoint}, - whir::proof::WhirProof, }; pub mod committer; pub mod constraints; pub mod parameters; -pub mod proof; pub mod prover; pub mod utils; pub mod verifier; @@ -44,7 +45,7 @@ pub fn make_whir_things( soundness_type: SecurityAssumption, pow_bits: usize, rs_domain_initial_reduction_factor: usize, - initial_phase_config: InitialPhaseConfig, + initial_phase: InitialPhase, ) { // Calculate polynomial size: 2^num_variables coefficients for multilinear polynomial let num_coeffs = 1 << num_variables; @@ -62,7 +63,7 @@ pub fn make_whir_things( // Configure WHIR protocol with all security and performance parameters let whir_params = ProtocolParameters { - initial_phase_config, + initial_phase, security_level: 32, pow_bits, rs_domain_initial_reduction_factor, @@ -74,10 +75,7 @@ pub fn make_whir_things( }; // Create unified configuration combining protocol and polynomial parameters - let params = WhirConfig::::new( - num_variables, - whir_params.clone(), - ); + let params = WhirConfig::::new(num_variables, whir_params); // Define test polynomial: all coefficients = 1 for simple verification // @@ -100,46 +98,38 @@ pub fn make_whir_things( } // Setup Fiat-Shamir transcript structure for non-interactive proof generation - let mut domainsep = DomainSeparator::new(vec![]); + let mut domainsep: DomainSeparator = DomainSeparator::new(vec![]); // Add statement commitment to transcript - domainsep.commit_statement::<_, _, _, 32>(¶ms); + domainsep.commit_statement::<_, _, 32>(¶ms); // Add proof structure to transcript - domainsep.add_whir_proof::<_, _, _, 32>(¶ms); + domainsep.add_whir_proof::<_, _, 32>(¶ms); // Create fresh RNG and challenger for transcript randomness // Initialize prover's view of the Fiat-Shamir transcript let mut rng = SmallRng::seed_from_u64(1); - let mut prover_challenger = MyChallenger::new(Perm::new_from_rng_128(&mut rng)); - domainsep.observe_domain_separator(&mut prover_challenger); + let mut challenger = MyChallenger::new(Perm::new_from_rng_128(&mut rng)); + domainsep.observe_domain_separator(&mut challenger); + let mut transcript = FiatShamirWriter::init(challenger.clone()); // Create polynomial commitment using Merkle tree over evaluation domain let committer = CommitmentWriter::new(¶ms); // DFT evaluator for polynomial let dft = Radix2DFTSmallBatch::::default(); - let mut proof = WhirProof::::from_protocol_parameters(&whir_params, num_variables); - // Commit to polynomial evaluations and generate cryptographic witness - let witness = committer - .commit(&dft, &mut proof, &mut prover_challenger, polynomial) - .unwrap(); + let witness = committer.commit(&dft, &mut transcript, polynomial).unwrap(); // Initialize WHIR prover with the configured parameters let prover = Prover(¶ms); // Generate WHIR proof prover - .prove( - &dft, - &mut proof, - &mut prover_challenger, - statement.clone(), - witness, - ) + .prove(&dft, &mut transcript, statement.clone(), witness) .unwrap(); + let checkpoint_prover: F = transcript.sample(); // Sample final challenge to ensure transcript consistency between prover/verifier - let checkpoint_prover: EF = prover_challenger.sample_algebra_element(); + let proof = transcript.finalize(); // Initialize commitment parser for verifier-side operations let commitment_reader = CommitmentReader::new(¶ms); @@ -148,31 +138,26 @@ pub fn make_whir_things( let verifier = Verifier::new(¶ms); // Reconstruct verifier's transcript from proof data and domain separator - let mut rng = SmallRng::seed_from_u64(1); - let mut verifier_challenger = MyChallenger::new(Perm::new_from_rng_128(&mut rng)); - domainsep.observe_domain_separator(&mut verifier_challenger); + let mut transcript = FiatShamirReader::init(proof, challenger); // Parse and validate the polynomial commitment from proof data - let parsed_commitment = - commitment_reader.parse_commitment::<8>(&proof, &mut verifier_challenger); + let parsed_commitment = commitment_reader + .parse_commitment::<_, 8>(&mut transcript) + .unwrap(); // Execute WHIR verification verifier - .verify( - &proof, - &mut verifier_challenger, - &parsed_commitment, - statement, - ) + .verify(&mut transcript, &parsed_commitment, statement) .unwrap(); - let checkpoint_verifier: EF = verifier_challenger.sample_algebra_element(); + let checkpoint_verifier: F = transcript.sample(); assert_eq!(checkpoint_prover, checkpoint_verifier); } #[cfg(test)] mod tests { use super::*; + use crate::constant::K_SKIP_SUMCHECK; #[test] fn test_whir_end_to_end_without_univariate_skip() { @@ -212,7 +197,7 @@ mod tests { soundness_type, pow_bits, rs_domain_initial_reduction_factor, - InitialPhaseConfig::WithStatementClassic, + InitialPhase::WithStatementClassic, ); } } @@ -225,13 +210,8 @@ mod tests { #[test] fn test_whir_end_to_end_with_univariate_skip() { let folding_factors = [ - FoldingFactor::Constant(1), - FoldingFactor::Constant(2), - FoldingFactor::Constant(3), - FoldingFactor::Constant(4), - FoldingFactor::ConstantFromSecondRound(2, 1), - FoldingFactor::ConstantFromSecondRound(3, 1), - FoldingFactor::ConstantFromSecondRound(3, 2), + FoldingFactor::Constant(5), + FoldingFactor::ConstantFromSecondRound(5, 1), FoldingFactor::ConstantFromSecondRound(5, 2), ]; let soundness_type = [ @@ -248,6 +228,7 @@ mod tests { if folding_factor.at_round(0) < rs_domain_initial_reduction_factor { continue; } + assert!(folding_factor.at_round(0) >= K_SKIP_SUMCHECK); let num_variables = folding_factor.at_round(0)..=3 * folding_factor.at_round(0); for num_variable in num_variables { for num_points in num_points { @@ -260,7 +241,7 @@ mod tests { soundness_type, pow_bits, rs_domain_initial_reduction_factor, - InitialPhaseConfig::WithStatementUnivariateSkip, + InitialPhase::WithStatementSkip, ); } } @@ -308,7 +289,7 @@ mod tests { soundness_type, pow_bits, rs_domain_initial_reduction_factor, - InitialPhaseConfig::WithStatementSvo, + InitialPhase::WithStatementSvo, ); } } @@ -347,7 +328,7 @@ mod tests { soundness_type, pow_bits, rs_reduction_factor, - InitialPhaseConfig::WithoutStatement, + InitialPhase::WithoutStatement, ); } } diff --git a/src/whir/parameters.rs b/src/whir/parameters.rs index 0208bd5d..f7391f12 100644 --- a/src/whir/parameters.rs +++ b/src/whir/parameters.rs @@ -1,14 +1,13 @@ use alloc::vec::Vec; use core::{f64::consts::LOG2_10, marker::PhantomData}; -use p3_challenger::{FieldChallenger, GrindingChallenger}; -use p3_field::{ExtensionField, Field, TwoAdicField}; +use p3_field::{ExtensionField, TwoAdicField}; use crate::parameters::{FoldingFactor, ProtocolParameters, errors::SecurityAssumption}; /// Configuration for the initial phase of the WHIR protocol. #[derive(Default, Debug, Clone, Copy, PartialEq, Eq)] -pub enum InitialPhaseConfig { +pub enum InitialPhase { /// Protocol with statement using classic sumcheck (no optimization). /// /// This is the standard baseline implementation where the prover proves @@ -20,7 +19,7 @@ pub enum InitialPhaseConfig { /// /// Uses the univariate skip optimization from /// to skip the first k variables in the sumcheck by using a univariate representation. - WithStatementUnivariateSkip, + WithStatementSkip, /// Protocol with statement using Small Value Optimization (SVO). /// @@ -37,7 +36,7 @@ pub enum InitialPhaseConfig { WithoutStatement, } -impl InitialPhaseConfig { +impl InitialPhase { /// Returns `true` if this configuration includes an initial statement. #[must_use] pub const fn has_initial_statement(&self) -> bool { @@ -47,7 +46,7 @@ impl InitialPhaseConfig { /// Returns `true` if univariate skip optimization is enabled. #[must_use] pub const fn is_univariate_skip(&self) -> bool { - matches!(self, Self::WithStatementUnivariateSkip) + matches!(self, Self::WithStatementSkip) } /// Returns `true` if SVO optimization is enabled. @@ -70,12 +69,8 @@ pub struct RoundConfig { pub folded_domain_gen: F, } -#[derive(Debug, Clone)] -pub struct WhirConfig -where - F: Field, - EF: ExtensionField, -{ +#[derive(Debug)] +pub struct WhirConfig { pub num_variables: usize, pub soundness_type: SecurityAssumption, pub security_level: usize, @@ -88,7 +83,7 @@ where /// 1. The commitment is a valid low degree polynomial (WithoutStatement). /// 2. The commitment is a valid folded polynomial, and an additional polynomial evaluation /// statement (any of the WithStatement* variants). - pub initial_phase_config: InitialPhaseConfig, + pub initial_phase: InitialPhase, pub starting_log_inv_rate: usize, pub starting_folding_pow_bits: usize, @@ -104,21 +99,13 @@ where // Merkle tree parameters pub merkle_hash: Hash, - pub merkle_compress: C, - - pub _base_field: PhantomData, + pub merkle_compress: Compress, pub _extension_field: PhantomData, - pub _challenger: PhantomData, } -impl WhirConfig -where - F: TwoAdicField, - EF: ExtensionField + TwoAdicField, - Challenger: FieldChallenger + GrindingChallenger, -{ +impl, Hash, Compress> WhirConfig { #[allow(clippy::too_many_lines)] - pub fn new(num_variables: usize, whir_parameters: ProtocolParameters) -> Self { + pub fn new(num_variables: usize, whir_parameters: ProtocolParameters) -> Self { // We need to store the initial number of variables for the final composition. let initial_num_variables = num_variables; whir_parameters @@ -157,7 +144,7 @@ where .folding_factor .compute_number_of_rounds(num_variables); - let has_initial_statement = whir_parameters.initial_phase_config.has_initial_statement(); + let has_initial_statement = whir_parameters.initial_phase.has_initial_statement(); let commitment_ood_samples = if has_initial_statement { whir_parameters.soundness_type.determine_ood_samples( @@ -275,7 +262,7 @@ where Self { security_level: whir_parameters.security_level, max_pow_bits: whir_parameters.pow_bits, - initial_phase_config: whir_parameters.initial_phase_config, + initial_phase: whir_parameters.initial_phase, commitment_ood_samples, num_variables: initial_num_variables, soundness_type: whir_parameters.soundness_type, @@ -291,9 +278,7 @@ where final_log_inv_rate: log_inv_rate, merkle_hash: whir_parameters.merkle_hash, merkle_compress: whir_parameters.merkle_compress, - _base_field: PhantomData, _extension_field: PhantomData, - _challenger: PhantomData, } } @@ -513,7 +498,7 @@ mod tests { const fn default_whir_params() -> ProtocolParameters, Poseidon2Compression> { ProtocolParameters { - initial_phase_config: InitialPhaseConfig::WithStatementClassic, + initial_phase: InitialPhase::WithStatementClassic, security_level: 100, pow_bits: 20, rs_domain_initial_reduction_factor: 1, @@ -530,23 +515,19 @@ mod tests { let params = default_whir_params(); let config = - WhirConfig::, Poseidon2Compression, MyChallenger>::new( - 10, params, - ); + WhirConfig::, Poseidon2Compression>::new(10, params); assert_eq!(config.security_level, 100); assert_eq!(config.max_pow_bits, 20); assert_eq!(config.soundness_type, SecurityAssumption::CapacityBound); - assert!(config.initial_phase_config.has_initial_statement()); + assert!(config.initial_phase.has_initial_statement()); } #[test] fn test_n_rounds() { let params = default_whir_params(); let config = - WhirConfig::, Poseidon2Compression, MyChallenger>::new( - 10, params, - ); + WhirConfig::, Poseidon2Compression>::new(10, params); assert_eq!(config.n_rounds(), config.round_parameters.len()); } @@ -556,7 +537,7 @@ mod tests { let field_size_bits = 64; let soundness = SecurityAssumption::CapacityBound; - let pow_bits = WhirConfig::::folding_pow_bits( + let pow_bits = WhirConfig::::folding_pow_bits( 100, // Security level soundness, field_size_bits, @@ -572,9 +553,7 @@ mod tests { fn test_check_pow_bits_within_limits() { let params = default_whir_params(); let mut config = - WhirConfig::, Poseidon2Compression, MyChallenger>::new( - 10, params, - ); + WhirConfig::, Poseidon2Compression>::new(10, params); // Set all values within limits config.max_pow_bits = 20; @@ -618,9 +597,7 @@ mod tests { fn test_check_pow_bits_starting_folding_exceeds() { let params = default_whir_params(); let mut config = - WhirConfig::, Poseidon2Compression, MyChallenger>::new( - 10, params, - ); + WhirConfig::, Poseidon2Compression>::new(10, params); config.max_pow_bits = 20; config.starting_folding_pow_bits = 21; // Exceeds max_pow_bits @@ -637,9 +614,7 @@ mod tests { fn test_check_pow_bits_final_pow_exceeds() { let params = default_whir_params(); let mut config = - WhirConfig::, Poseidon2Compression, MyChallenger>::new( - 10, params, - ); + WhirConfig::, Poseidon2Compression>::new(10, params); config.max_pow_bits = 20; config.starting_folding_pow_bits = 15; @@ -656,9 +631,7 @@ mod tests { fn test_check_pow_bits_round_pow_exceeds() { let params = default_whir_params(); let mut config = - WhirConfig::, Poseidon2Compression, MyChallenger>::new( - 10, params, - ); + WhirConfig::, Poseidon2Compression>::new(10, params); config.max_pow_bits = 20; config.starting_folding_pow_bits = 15; @@ -688,9 +661,7 @@ mod tests { fn test_check_pow_bits_round_folding_pow_exceeds() { let params = default_whir_params(); let mut config = - WhirConfig::, Poseidon2Compression, MyChallenger>::new( - 10, params, - ); + WhirConfig::, Poseidon2Compression>::new(10, params); config.max_pow_bits = 20; config.starting_folding_pow_bits = 15; @@ -720,9 +691,7 @@ mod tests { fn test_check_pow_bits_exactly_at_limit() { let params = default_whir_params(); let mut config = - WhirConfig::, Poseidon2Compression, MyChallenger>::new( - 10, params, - ); + WhirConfig::, Poseidon2Compression>::new(10, params); config.max_pow_bits = 20; config.starting_folding_pow_bits = 20; @@ -751,9 +720,7 @@ mod tests { fn test_check_pow_bits_all_exceed() { let params = default_whir_params(); let mut config = - WhirConfig::, Poseidon2Compression, MyChallenger>::new( - 10, params, - ); + WhirConfig::, Poseidon2Compression>::new(10, params); config.max_pow_bits = 20; config.starting_folding_pow_bits = 22; diff --git a/src/whir/proof.rs b/src/whir/proof.rs deleted file mode 100644 index 59f10a62..00000000 --- a/src/whir/proof.rs +++ /dev/null @@ -1,1110 +0,0 @@ -use alloc::vec::Vec; -use core::array; - -use p3_challenger::{FieldChallenger, GrindingChallenger}; -use p3_field::{ExtensionField, Field}; -use serde::{Deserialize, Serialize}; - -use crate::{ - constant::K_SKIP_SUMCHECK, parameters::ProtocolParameters, poly::evals::EvaluationsList, - whir::parameters::InitialPhaseConfig, -}; - -/// Complete WHIR proof -#[derive(Serialize, Deserialize, Clone, Debug)] -#[serde(bound( - serialize = "F: Serialize, EF: Serialize, [F; DIGEST_ELEMS]: Serialize", - deserialize = "F: Deserialize<'de>, EF: Deserialize<'de>, [F; DIGEST_ELEMS]: Deserialize<'de>" -))] -pub struct WhirProof { - /// Initial polynomial commitment (Merkle root) - pub initial_commitment: [F; DIGEST_ELEMS], - - /// Initial OOD evaluations - pub initial_ood_answers: Vec, - - /// Initial phase data - captures the protocol variant - pub initial_phase: InitialPhase, - - /// One proof per WHIR round - pub rounds: Vec>, - - /// Final polynomial evaluations - pub final_poly: Option>, - - /// Final round PoW witness - pub final_pow_witness: F, - - /// Final round query openings - pub final_queries: Vec>, - - /// Final sumcheck (if final_sumcheck_rounds > 0) - pub final_sumcheck: Option>, -} - -impl Default - for WhirProof -{ - fn default() -> Self { - Self { - initial_commitment: array::from_fn(|_| F::default()), - initial_ood_answers: Vec::new(), - initial_phase: InitialPhase::default(), - rounds: Vec::new(), - final_poly: None, - final_pow_witness: F::default(), - final_queries: Vec::new(), - final_sumcheck: None, - } - } -} - -/// Initial phase of WHIR protocol -#[derive(Serialize, Deserialize, Clone, Debug)] -#[serde(tag = "type")] -pub enum InitialPhase { - /// Protocol with statement and univariate skip optimization - #[serde(rename = "with_statement_skip")] - WithStatementSkip(SumcheckSkipData), - - /// Protocol with statement and svo optimization. - /// First `l` rounds of svo optimization, the remaining rounds from algorithm 5 of the paper - /// (which have the same structure) stored in the subsequents `WhirRoundProof` elements. - #[serde(rename = "with_statement_svo")] - WithStatementSvo { - /// Svo sumcheck data - sumcheck: SumcheckData, - }, - - /// Protocol with statement (standard sumcheck, no skip) - #[serde(rename = "with_statement")] - WithStatement { - /// Standard sumcheck data - sumcheck: SumcheckData, - }, - - /// Protocol without statement (direct folding) - #[serde(rename = "without_statement")] - WithoutStatement { pow_witness: F }, -} - -impl Default for InitialPhase { - fn default() -> Self { - Self::WithStatement { - sumcheck: SumcheckData::default(), - } - } -} - -/// Data for a single WHIR round -#[derive(Serialize, Deserialize, Clone, Debug)] -#[serde(bound( - serialize = "F: Serialize, EF: Serialize, [F; DIGEST_ELEMS]: Serialize", - deserialize = "F: Deserialize<'de>, EF: Deserialize<'de>, [F; DIGEST_ELEMS]: Deserialize<'de>" -))] -pub struct WhirRoundProof { - /// Round commitment (Merkle root) - pub commitment: [F; DIGEST_ELEMS], - - /// OOD evaluations for this round - pub ood_answers: Vec, - - /// PoW witness after commitment - pub pow_witness: F, - - /// STIR query openings - pub queries: Vec>, - - /// Sumcheck data for this round - pub sumcheck: SumcheckData, -} - -impl Default - for WhirRoundProof -{ - fn default() -> Self { - Self { - commitment: array::from_fn(|_| F::default()), - ood_answers: Vec::new(), - pow_witness: F::default(), - queries: Vec::new(), - sumcheck: SumcheckData::default(), - } - } -} - -/// Query opening -#[derive(Serialize, Deserialize, Clone, Debug)] -#[serde( - bound( - serialize = "F: Serialize, EF: Serialize, [F; DIGEST_ELEMS]: Serialize", - deserialize = "F: Deserialize<'de>, EF: Deserialize<'de>, [F; DIGEST_ELEMS]: Deserialize<'de>" - ), - tag = "type" -)] -pub enum QueryOpening { - /// Base field query (round_index == 0) - #[serde(rename = "base")] - Base { - /// Merkle leaf values in F - values: Vec, - /// Merkle authentication path - proof: Vec<[F; DIGEST_ELEMS]>, - }, - /// Extension field query (round_index > 0) - #[serde(rename = "extension")] - Extension { - /// Merkle leaf values in EF - values: Vec, - /// Merkle authentication path - proof: Vec<[F; DIGEST_ELEMS]>, - }, -} - -/// Sumcheck polynomial data -/// -/// Stores the polynomial evaluations for sumcheck rounds in a compact format. -/// Each round stores `[h(0), h(2)]` where `h(1)` is derived as `claimed_sum - h(0)`. -#[derive(Default, Serialize, Deserialize, Clone, Debug)] -pub struct SumcheckData { - /// Polynomial evaluations for each sumcheck round - /// - /// Each entry is `[h(0), h(2)]` - the evaluations at 0 and 2 - /// - /// `h(1)` is derived as `claimed_sum - h(0)` by the verifier - /// - /// Length: folding_factor - pub polynomial_evaluations: Vec<[EF; 2]>, - - /// PoW witnesses for each sumcheck round - /// Length: folding_factor - pub pow_witnesses: Vec, -} - -impl SumcheckData { - /// Appends a proof-of-work witness. - pub fn push_pow_witness(&mut self, witness: F) { - self.pow_witnesses.push(witness); - } - - /// Commits polynomial coefficients to the transcript and returns a challenge. - /// - /// This helper function handles the Fiat-Shamir interaction for a sumcheck round. - /// - /// # Arguments - /// - /// * `challenger` - Fiat-Shamir transcript. - /// * `c0` - Constant coefficient `h(0)`. - /// * `c2` - Quadratic coefficient. - /// * `pow_bits` - PoW difficulty (0 to skip grinding). - /// - /// # Returns - /// - /// The sampled challenge `r \in EF`. - pub fn observe_and_sample( - &mut self, - challenger: &mut Challenger, - c0: EF, - c2: EF, - pow_bits: usize, - ) -> EF - where - BF: Field, - EF: ExtensionField, - F: Clone, - Challenger: FieldChallenger + GrindingChallenger, - { - // Record the polynomial coefficients in the proof. - self.polynomial_evaluations.push([c0, c2]); - - // Absorb coefficients into the transcript. - // - // Note: We only send (c_0, c_2). The verifier derives c_1 from the sum constraint. - challenger.observe_algebra_slice(&[c0, c2]); - - // Optional proof-of-work to increase prover cost. - // - // This makes it expensive for a malicious prover to "mine" favorable challenges. - if pow_bits > 0 { - self.push_pow_witness(challenger.grind(pow_bits)); - } - - // Sample the verifier's challenge for this round. - challenger.sample_algebra_element() - } -} - -#[derive(Default, Serialize, Deserialize, Clone, Debug)] -pub struct SumcheckSkipData { - pub evaluations: Vec, - pub pow: F, - pub sumcheck: SumcheckData, -} - -impl WhirProof { - /// Create a new WhirProof from protocol parameters and configuration - /// - /// This initializes an empty proof structure with appropriate capacity allocations - /// based on the protocol parameters. The actual proof data will be populated during - /// the proving process. - /// - /// # Parameters - /// - `params`: The protocol parameters containing security settings and folding configuration - /// - `num_variables`: The number of variables in the multilinear polynomial - /// - /// # Returns - /// A new `WhirProof` with pre-allocated vectors sized according to the protocol parameters - pub fn from_protocol_parameters( - params: &ProtocolParameters, - num_variables: usize, - ) -> Self { - // Determine which initial phase variant based on protocol configuration - let initial_phase = match params.initial_phase_config { - // No initial statement: direct folding path - InitialPhaseConfig::WithoutStatement => InitialPhase::without_statement(), - - // With statement + UnivariateSkip optimization - InitialPhaseConfig::WithStatementUnivariateSkip - if K_SKIP_SUMCHECK <= params.folding_factor.at_round(0) => - { - InitialPhase::with_statement_skip(SumcheckSkipData::default()) - } - - // With statement + SVO optimization - // TODO: SVO optimization is not yet fully implemented - // Fall back to classic sumcheck for now - InitialPhaseConfig::WithStatementSvo => { - InitialPhase::with_statement_svo(SumcheckData::default()) - } - - // With statement + Classic (or UnivariateSkip with insufficient folding factor) - _ => InitialPhase::with_statement(SumcheckData::default()), - }; - - // Use the actual FoldingFactor method to calculate rounds correctly - let (num_rounds, _final_sumcheck_rounds) = params - .folding_factor - .compute_number_of_rounds(num_variables); - - // Calculate protocol security level (after subtracting PoW bits) - let protocol_security_level = params.security_level.saturating_sub(params.pow_bits); - - // Compute the number of queries - let num_queries = params - .soundness_type - .queries(protocol_security_level, params.starting_log_inv_rate); - - Self { - initial_commitment: array::from_fn(|_| F::default()), - initial_ood_answers: Vec::new(), - initial_phase, - rounds: (0..num_rounds).map(|_| WhirRoundProof::default()).collect(), - final_poly: None, - final_pow_witness: F::default(), - final_queries: Vec::with_capacity(num_queries), - final_sumcheck: None, - } - } -} - -impl WhirProof { - /// Extract the PoW witness after the commitment at the given round index - /// - /// Returns the PoW witness from the round at the given index. - /// The PoW witness is stored in proof.rounds[round_index].pow_witness. - pub fn get_pow_after_commitment(&self, round_index: usize) -> Option { - self.rounds - .get(round_index) - .map(|round| round.pow_witness.clone()) - } - - /// Stores sumcheck data at a specific round index. - /// - /// # Parameters - /// - `data`: The sumcheck data to store - /// - `round_index`: The round index to store the data at - /// - /// # Panics - /// Panics if `round_index` is out of bounds. - pub fn set_sumcheck_data_at(&mut self, data: SumcheckData, round_index: usize) { - self.rounds[round_index].sumcheck = data; - } - - /// Stores sumcheck data in the final sumcheck field. - /// - /// # Parameters - /// - `data`: The sumcheck data to store - pub fn set_final_sumcheck_data(&mut self, data: SumcheckData) { - self.final_sumcheck = Some(data); - } -} - -impl InitialPhase { - /// Create initial phase with statement and skip optimization - pub const fn with_statement_skip(skip_data: SumcheckSkipData) -> Self { - Self::WithStatementSkip(skip_data) - } - - /// Create initial phase with statement and SVO optimization - #[must_use] - pub const fn with_statement_svo(sumcheck: SumcheckData) -> Self { - Self::WithStatementSvo { sumcheck } - } - - /// Create initial phase with statement (no skip) - #[must_use] - pub const fn with_statement(sumcheck: SumcheckData) -> Self { - Self::WithStatement { sumcheck } - } - - /// Create initial phase without statement - #[must_use] - pub fn without_statement() -> Self - where - F: Default, - { - Self::WithoutStatement { - pow_witness: F::default(), - } - } -} - -#[cfg(test)] -mod tests { - use alloc::vec; - - use p3_baby_bear::{BabyBear, Poseidon2BabyBear}; - use p3_challenger::DuplexChallenger; - use p3_field::{PrimeCharacteristicRing, extension::BinomialExtensionField}; - use p3_symmetric::{PaddingFreeSponge, TruncatedPermutation}; - use rand::SeedableRng; - - use super::*; - use crate::parameters::{FoldingFactor, errors::SecurityAssumption}; - - /// Type alias for the base field used in tests - type F = BabyBear; - - /// Type alias for the extension field used in tests - type EF = BinomialExtensionField; - - /// Type alias for the permutation used in Merkle tree - type Perm = Poseidon2BabyBear<16>; - - /// Type alias for the hash function - type MyHash = PaddingFreeSponge; - - /// Type alias for the compression function - type MyCompress = TruncatedPermutation; - - /// Type alias for the challenger used in observe_and_sample tests. - type TestChallenger = DuplexChallenger; - - /// Digest size for Merkle tree commitments - const DIGEST_ELEMS: usize = 8; - - /// Helper function to create minimal protocol parameters for testing - /// - /// This creates a `ProtocolParameters` instance with specified configuration - /// for testing different proof initialization scenarios. - /// - /// # Parameters - /// - `initial_phase_config`: Configuration for the initial phase - /// - `folding_factor`: The folding strategy for the protocol - /// - /// # Returns - /// A `ProtocolParameters` instance configured for testing - fn create_test_params( - initial_phase_config: InitialPhaseConfig, - folding_factor: FoldingFactor, - ) -> ProtocolParameters { - // Create the permutation for hash and compress - let perm = Perm::new_from_rng_128(&mut rand::rngs::SmallRng::seed_from_u64(42)); - - ProtocolParameters { - initial_phase_config, - starting_log_inv_rate: 2, - rs_domain_initial_reduction_factor: 1, - folding_factor, - soundness_type: SecurityAssumption::UniqueDecoding, - security_level: 100, - pow_bits: 10, - merkle_hash: PaddingFreeSponge::new(perm.clone()), - merkle_compress: TruncatedPermutation::new(perm), - } - } - - #[test] - fn test_whir_proof_from_params_with_univariate_skip() { - // Declare test parameters explicitly - - // Set folding factor to 6, which is >= K_SKIP_SUMCHECK (5) - // This ensures univariate skip optimization is enabled - let folding_factor_value = 6; - let folding_factor = FoldingFactor::Constant(folding_factor_value); - - // Enable univariate skip optimization - let initial_phase_config = InitialPhaseConfig::WithStatementUnivariateSkip; - - // Use 20 variables for testing - let num_variables = 20; - - // Create protocol parameters with univariate skip enabled - let params = create_test_params(initial_phase_config, folding_factor); - - // Create proof structure from parameters - let proof: WhirProof = - WhirProof::from_protocol_parameters(¶ms, num_variables); - - // Verify that initial_phase is WithStatementSkip variant - // This should be true because: - // - initial_phase_config = WithStatementUnivariateSkip - // - folding_factor (6) >= K_SKIP_SUMCHECK (5) - match proof.initial_phase { - InitialPhase::WithStatementSkip(skip_data) => { - // evaluations should be empty (not populated yet) - assert_eq!(skip_data.evaluations.len(), 0); - // pow should be default (not populated yet) - assert_eq!(skip_data.pow, F::default()); - // sumcheck should have empty polynomial_evaluations - assert_eq!(skip_data.sumcheck.polynomial_evaluations.len(), 0); - // sumcheck should have empty PoW witnesses - assert!(skip_data.sumcheck.pow_witnesses.is_empty()); - } - _ => panic!("Expected WithStatementSkip variant"), - } - - // Verify rounds length - // Formula: ((num_variables - MAX_NUM_VARIABLES_TO_SEND_COEFFS) / folding_factor) - 1 - // MAX_NUM_VARIABLES_TO_SEND_COEFFS = 6 (threshold for sending coefficients directly) - // For 20 variables with folding_factor 6: - // (20 - 6).div_ceil(6) - 1 = 14.div_ceil(6) - 1 = 3 - 1 = 2 - let expected_rounds = 2; - assert_eq!(proof.rounds.len(), expected_rounds); - } - - #[test] - fn test_whir_proof_from_params_without_univariate_skip() { - // Declare test parameters explicitly - - // Set folding factor to 4, which is < K_SKIP_SUMCHECK (5) - // This ensures univariate skip optimization is NOT enabled - let folding_factor_value = 4; - let folding_factor = FoldingFactor::Constant(folding_factor_value); - - // Even if we request UnivariateSkip, it won't be used - // because folding_factor < K_SKIP_SUMCHECK - let initial_phase_config = InitialPhaseConfig::WithStatementUnivariateSkip; - - // Use 16 variables for testing - let num_variables = 16; - - // Create protocol parameters (skip won't be enabled due to folding_factor < 5) - let params = create_test_params(initial_phase_config, folding_factor); - - // Create proof structure from parameters - let proof: WhirProof = - WhirProof::from_protocol_parameters(¶ms, num_variables); - - // Verify that initial_phase is WithStatement variant (NOT WithStatementSkip) - // This is because folding_factor (4) < K_SKIP_SUMCHECK (5) - match proof.initial_phase { - InitialPhase::WithStatement { sumcheck } => { - // sumcheck should have empty polynomial_evaluations - assert_eq!(sumcheck.polynomial_evaluations.len(), 0); - // sumcheck should have empty PoW witnesses - assert!(sumcheck.pow_witnesses.is_empty()); - } - _ => panic!("Expected WithStatement variant, not WithStatementSkip"), - } - - // Verify rounds length - // Formula: ((num_variables - MAX_NUM_VARIABLES_TO_SEND_COEFFS) / folding_factor) - 1 - // MAX_NUM_VARIABLES_TO_SEND_COEFFS = 6 - // For 16 variables with folding_factor 4: - // (16 - 6).div_ceil(4) - 1 = 10.div_ceil(4) - 1 = 3 - 1 = 2 - let expected_rounds = 2; - assert_eq!(proof.rounds.len(), expected_rounds); - } - - #[test] - fn test_whir_proof_from_params_without_initial_statement() { - // Declare test parameters explicitly - - // Folding factor doesn't matter for initial_phase when WithoutStatement - let folding_factor_value = 6; - let folding_factor = FoldingFactor::Constant(folding_factor_value); - - // Configure without initial statement - let initial_phase_config = InitialPhaseConfig::WithoutStatement; - - // Use 18 variables for testing - let num_variables = 18; - - // Create protocol parameters without initial statement - let params = create_test_params(initial_phase_config, folding_factor); - - // Create proof structure from parameters - let proof: WhirProof = - WhirProof::from_protocol_parameters(¶ms, num_variables); - - // Verify that initial_phase is WithoutStatement variant - // This is because initial_phase_config = WithoutStatement - match proof.initial_phase { - InitialPhase::WithoutStatement { pow_witness } => { - // pow_witness should be default (not populated yet) - assert_eq!(pow_witness, F::default()); - } - _ => panic!("Expected WithoutStatement variant"), - } - - // Verify rounds length - // Formula: ((num_variables - MAX_NUM_VARIABLES_TO_SEND_COEFFS) / folding_factor) - 1 - // MAX_NUM_VARIABLES_TO_SEND_COEFFS = 6 - // For 18 variables with folding_factor 6: - // (18 - 6).div_ceil(6) - 1 = 12.div_ceil(6) - 1 = 2 - 1 = 1 - let expected_rounds = 1; - assert_eq!(proof.rounds.len(), expected_rounds); - } - - #[test] - fn test_get_pow_after_commitment_with_witness() { - // Create an explicit PoW witness value for testing - let pow_witness_value = F::from_u64(42); - - // Create a proof with one round containing a PoW witness - let proof: WhirProof = WhirProof { - initial_commitment: array::from_fn(|_| F::default()), - initial_ood_answers: Vec::new(), - initial_phase: InitialPhase::WithoutStatement { - pow_witness: F::default(), - }, - rounds: vec![WhirRoundProof { - commitment: array::from_fn(|_| F::default()), - ood_answers: Vec::new(), - pow_witness: pow_witness_value, - queries: Vec::new(), - sumcheck: SumcheckData::default(), - }], - final_poly: None, - final_pow_witness: F::default(), - final_queries: Vec::new(), - final_sumcheck: None, - }; - - // Query round index 0, which exists and has a PoW witness - let round_index = 0; - - // Get the PoW witness after commitment at round 0 - let result = proof.get_pow_after_commitment(round_index); - - // Verify that we get Some(pow_witness_value) - assert_eq!(result, Some(pow_witness_value)); - } - - #[test] - fn test_get_pow_after_commitment_invalid_round() { - // Create a proof with one round - let proof: WhirProof = WhirProof { - initial_commitment: array::from_fn(|_| F::default()), - initial_ood_answers: Vec::new(), - initial_phase: InitialPhase::WithoutStatement { - pow_witness: F::default(), - }, - rounds: vec![WhirRoundProof { - commitment: array::from_fn(|_| F::default()), - ood_answers: Vec::new(), - pow_witness: F::from_u64(42), - queries: Vec::new(), - sumcheck: SumcheckData::default(), - }], - final_poly: None, - final_pow_witness: F::default(), - final_queries: Vec::new(), - final_sumcheck: None, - }; - - // Query round index 1, which doesn't exist (only round 0 exists) - let invalid_round_index = 1; - - // Get the PoW witness after commitment at invalid round - let result = proof.get_pow_after_commitment(invalid_round_index); - - // Verify that we get None because the round doesn't exist - assert_eq!(result, None); - } - - #[test] - fn test_initial_phase_constructors() { - // Test with_statement_skip constructor - - // Create skip PoW witness - let skip_pow_value = F::from_u64(123); - - // Create SumcheckSkipData - let skip_data: SumcheckSkipData = SumcheckSkipData { - evaluations: Vec::new(), - pow: skip_pow_value, - sumcheck: SumcheckData::default(), - }; - - // Construct WithStatementSkip variant - let phase_skip = InitialPhase::with_statement_skip(skip_data); - - // Verify it's the correct variant - match phase_skip { - InitialPhase::WithStatementSkip(skip_data) => { - assert_eq!(skip_data.evaluations.len(), 0); - assert_eq!(skip_data.pow, skip_pow_value); - assert_eq!(skip_data.sumcheck.polynomial_evaluations.len(), 0); - } - _ => panic!("Expected WithStatementSkip variant"), - } - - // Test with_statement constructor - - // Create empty sumcheck data - let sumcheck: SumcheckData = SumcheckData::default(); - - // Construct WithStatement variant - let phase_statement = InitialPhase::with_statement(sumcheck); - - // Verify it's the correct variant - match phase_statement { - InitialPhase::WithStatement { sumcheck } => { - assert_eq!(sumcheck.polynomial_evaluations.len(), 0); - } - _ => panic!("Expected WithStatement variant"), - } - - // Test without_statement constructor - - // Construct WithoutStatement variant - let phase_without = InitialPhase::::without_statement(); - - // Verify it's the correct variant - match phase_without { - InitialPhase::WithoutStatement { pow_witness } => { - // pow_witness should be default from constructor - assert_eq!(pow_witness, F::default()); - } - _ => panic!("Expected WithoutStatement variant"), - } - } - - #[test] - fn test_whir_round_proof_default() { - // Create a default WhirRoundProof - let round: WhirRoundProof = WhirRoundProof::default(); - - // Verify commitment is array of default F values - assert_eq!(round.commitment.len(), DIGEST_ELEMS); - for elem in round.commitment { - assert_eq!(elem, F::default()); - } - - // Verify ood_answers is empty - assert_eq!(round.ood_answers.len(), 0); - - // Verify pow_witness is default - assert_eq!(round.pow_witness, F::default()); - - // Verify queries is empty - assert_eq!(round.queries.len(), 0); - - // Verify sumcheck has default values - assert_eq!(round.sumcheck.polynomial_evaluations.len(), 0); - assert!(round.sumcheck.pow_witnesses.is_empty()); - } - - #[test] - fn test_sumcheck_data_default() { - // Create a default SumcheckData - let sumcheck: SumcheckData = SumcheckData::default(); - - // Verify polynomial_evaluations is empty - assert_eq!(sumcheck.polynomial_evaluations.len(), 0); - - // Verify pow_witnesses is empty - assert!(sumcheck.pow_witnesses.is_empty()); - } - - #[test] - fn test_query_opening_variants() { - // Test Base variant - - // Create base field values - let base_val_0 = F::from_u64(1); - let base_val_1 = F::from_u64(2); - let values = vec![base_val_0, base_val_1]; - - // Create Merkle proof (authentication path) - let proof_node = array::from_fn(|i| F::from_u64(i as u64)); - let proof = vec![proof_node]; - - // Construct Base variant - let base_opening: QueryOpening = QueryOpening::Base { - values, - proof: proof.clone(), - }; - - // Verify it's the correct variant - match base_opening { - QueryOpening::Base { - values: v, - proof: p, - } => { - assert_eq!(v.len(), 2); - assert_eq!(v[0], base_val_0); - assert_eq!(v[1], base_val_1); - assert_eq!(p.len(), 1); - } - QueryOpening::Extension { .. } => panic!("Expected Base variant"), - } - - // Test Extension variant - - // Create extension field values - // Extension field values are created from base field using From trait - let ext_val_0 = EF::from_u64(3); - let ext_val_1 = EF::from_u64(4); - let ext_values = vec![ext_val_0, ext_val_1]; - - // Construct Extension variant - let ext_opening: QueryOpening = QueryOpening::Extension { - values: ext_values, - proof, - }; - - // Verify it's the correct variant - match ext_opening { - QueryOpening::Extension { - values: v, - proof: p, - } => { - assert_eq!(v.len(), 2); - assert_eq!(v[0], ext_val_0); - assert_eq!(v[1], ext_val_1); - assert_eq!(p.len(), 1); - } - QueryOpening::Base { .. } => panic!("Expected Extension variant"), - } - } - - #[test] - fn test_push_pow_witness() { - let mut sumcheck: SumcheckData = SumcheckData::default(); - - // First push - let witness1 = F::from_u64(42); - sumcheck.push_pow_witness(witness1); - - assert_eq!(sumcheck.pow_witnesses.len(), 1); - assert_eq!(sumcheck.pow_witnesses[0], witness1); - - // Second push should append to existing vector - let witness2 = F::from_u64(123); - sumcheck.push_pow_witness(witness2); - - assert_eq!(sumcheck.pow_witnesses.len(), 2); - assert_eq!(sumcheck.pow_witnesses[1], witness2); - } - - #[test] - fn test_set_final_sumcheck_data() { - // Create a proof with no rounds - let mut proof: WhirProof = WhirProof { - initial_commitment: array::from_fn(|_| F::default()), - initial_ood_answers: Vec::new(), - initial_phase: InitialPhase::WithoutStatement { - pow_witness: F::default(), - }, - rounds: Vec::new(), - final_poly: None, - final_pow_witness: F::default(), - final_queries: Vec::new(), - final_sumcheck: None, - }; - - // Verify final_sumcheck is None initially - assert!(proof.final_sumcheck.is_none()); - - // Create sumcheck data with a distinguishable value - let mut data: SumcheckData = SumcheckData::default(); - data.push_pow_witness(F::from_u64(999)); - - // Set as final - proof.set_final_sumcheck_data(data); - - // Verify it was stored in final_sumcheck - assert!(proof.final_sumcheck.is_some()); - let stored = proof.final_sumcheck.as_ref().unwrap(); - assert_eq!(stored.pow_witnesses[0], F::from_u64(999)); - } - - #[test] - fn test_set_sumcheck_data_at_round() { - // Create a proof with two rounds - let mut proof: WhirProof = WhirProof { - initial_commitment: array::from_fn(|_| F::default()), - initial_ood_answers: Vec::new(), - initial_phase: InitialPhase::WithoutStatement { - pow_witness: F::default(), - }, - rounds: vec![WhirRoundProof::default(), WhirRoundProof::default()], - final_poly: None, - final_pow_witness: F::default(), - final_queries: Vec::new(), - final_sumcheck: None, - }; - - // Verify rounds' sumcheck is empty initially - assert!(proof.rounds[0].sumcheck.pow_witnesses.is_empty()); - assert!(proof.rounds[1].sumcheck.pow_witnesses.is_empty()); - - // Create sumcheck data with a distinguishable value for round 0 - let mut data0: SumcheckData = SumcheckData::default(); - data0.push_pow_witness(F::from_u64(777)); - proof.set_sumcheck_data_at(data0, 0); - - // Create sumcheck data with a distinguishable value for round 1 - let mut data1: SumcheckData = SumcheckData::default(); - data1.push_pow_witness(F::from_u64(888)); - proof.set_sumcheck_data_at(data1, 1); - - // Verify it was stored in the correct rounds - assert_eq!(proof.rounds[0].sumcheck.pow_witnesses[0], F::from_u64(777)); - assert_eq!(proof.rounds[1].sumcheck.pow_witnesses[0], F::from_u64(888)); - - // Verify final_sumcheck is still None - assert!(proof.final_sumcheck.is_none()); - } - - #[test] - #[should_panic(expected = "index out of bounds")] - fn test_set_sumcheck_data_at_no_rounds_panics() { - // Create a proof with no rounds - let mut proof: WhirProof = WhirProof { - initial_commitment: array::from_fn(|_| F::default()), - initial_ood_answers: Vec::new(), - initial_phase: InitialPhase::WithoutStatement { - pow_witness: F::default(), - }, - rounds: Vec::new(), - final_poly: None, - final_pow_witness: F::default(), - final_queries: Vec::new(), - final_sumcheck: None, - }; - - // Try to set sumcheck data at index 0 with no rounds - should panic - proof.set_sumcheck_data_at(SumcheckData::default(), 0); - } - - /// Creates a fresh challenger for testing. - /// - /// The challenger is seeded deterministically so tests are reproducible. - fn create_test_challenger() -> TestChallenger { - let perm = Perm::new_from_rng_128(&mut rand::rngs::SmallRng::seed_from_u64(42)); - DuplexChallenger::new(perm) - } - - #[test] - fn test_observe_and_sample_records_coefficients() { - // The method should push [c0, c2] to polynomial_evaluations. - // - // polynomial_evaluations stores the sumcheck polynomial coefficients - // for each round: [h(0), h(2)] where h(1) is derived by the verifier. - let c0 = EF::from_u64(5); - let c2 = EF::from_u64(7); - - let mut sumcheck: SumcheckData = SumcheckData::default(); - let mut challenger = create_test_challenger(); - - // polynomial_evaluations should be empty initially - assert!(sumcheck.polynomial_evaluations.is_empty()); - - // Call observe_and_sample with pow_bits = 0 (no grinding) - let pow_bits = 0; - let _r = sumcheck.observe_and_sample::<_, F>(&mut challenger, c0, c2, pow_bits); - - // polynomial_evaluations should now have one entry: [c0, c2] - assert_eq!(sumcheck.polynomial_evaluations.len(), 1); - assert_eq!(sumcheck.polynomial_evaluations[0][0], c0); - assert_eq!(sumcheck.polynomial_evaluations[0][1], c2); - } - - #[test] - fn test_observe_and_sample_multiple_rounds() { - // Multiple calls should accumulate coefficients in order. - // - // Round 0: push [c0_0, c2_0] - // Round 1: push [c0_1, c2_1] - // Round 2: push [c0_2, c2_2] - let c0_0 = EF::from_u64(1); - let c2_0 = EF::from_u64(2); - let c0_1 = EF::from_u64(3); - let c2_1 = EF::from_u64(4); - let c0_2 = EF::from_u64(5); - let c2_2 = EF::from_u64(6); - - let mut sumcheck: SumcheckData = SumcheckData::default(); - let mut challenger = create_test_challenger(); - let pow_bits = 0; - - // Round 0 - let _r0 = sumcheck.observe_and_sample::<_, F>(&mut challenger, c0_0, c2_0, pow_bits); - assert_eq!(sumcheck.polynomial_evaluations.len(), 1); - - // Round 1 - let _r1 = sumcheck.observe_and_sample::<_, F>(&mut challenger, c0_1, c2_1, pow_bits); - assert_eq!(sumcheck.polynomial_evaluations.len(), 2); - - // Round 2 - let _r2 = sumcheck.observe_and_sample::<_, F>(&mut challenger, c0_2, c2_2, pow_bits); - assert_eq!(sumcheck.polynomial_evaluations.len(), 3); - - // Verify all stored coefficients match input order - assert_eq!(sumcheck.polynomial_evaluations[0], [c0_0, c2_0]); - assert_eq!(sumcheck.polynomial_evaluations[1], [c0_1, c2_1]); - assert_eq!(sumcheck.polynomial_evaluations[2], [c0_2, c2_2]); - } - - #[test] - fn test_observe_and_sample_without_pow() { - // When pow_bits = 0, no PoW witness should be recorded. - // - // The method skips the grinding step when pow_bits is zero, - // so pow_witnesses should remain empty. - let c0 = EF::from_u64(10); - let c2 = EF::from_u64(20); - - let mut sumcheck: SumcheckData = SumcheckData::default(); - let mut challenger = create_test_challenger(); - - // pow_witnesses should be empty initially - assert!(sumcheck.pow_witnesses.is_empty()); - - // Call with pow_bits = 0 - let pow_bits = 0; - let _r = sumcheck.observe_and_sample::<_, F>(&mut challenger, c0, c2, pow_bits); - - // pow_witnesses should still be empty (no grinding performed) - assert!(sumcheck.pow_witnesses.is_empty()); - } - - #[test] - fn test_observe_and_sample_with_pow() { - // When pow_bits > 0, a PoW witness should be recorded. - // - // The method calls challenger.grind(pow_bits) and pushes - // the resulting witness to pow_witnesses. - let c0 = EF::from_u64(10); - let c2 = EF::from_u64(20); - - let mut sumcheck: SumcheckData = SumcheckData::default(); - let mut challenger = create_test_challenger(); - - // pow_witnesses should be empty initially - assert!(sumcheck.pow_witnesses.is_empty()); - - // Call with pow_bits = 1 (minimal PoW) - let pow_bits = 1; - let _r = sumcheck.observe_and_sample::<_, F>(&mut challenger, c0, c2, pow_bits); - - // pow_witnesses should now have one entry - assert_eq!(sumcheck.pow_witnesses.len(), 1); - } - - #[test] - fn test_observe_and_sample_pow_accumulates() { - // Multiple rounds with PoW should accumulate witnesses. - // - // Each call with pow_bits > 0 should add one witness. - let c0 = EF::from_u64(1); - let c2 = EF::from_u64(2); - - let mut sumcheck: SumcheckData = SumcheckData::default(); - let mut challenger = create_test_challenger(); - let pow_bits = 1; - - // Three rounds with PoW - let _r0 = sumcheck.observe_and_sample::<_, F>(&mut challenger, c0, c2, pow_bits); - let _r1 = sumcheck.observe_and_sample::<_, F>(&mut challenger, c0, c2, pow_bits); - let _r2 = sumcheck.observe_and_sample::<_, F>(&mut challenger, c0, c2, pow_bits); - - // Should have 3 witnesses - assert_eq!(sumcheck.pow_witnesses.len(), 3); - // And 3 polynomial evaluations - assert_eq!(sumcheck.polynomial_evaluations.len(), 3); - } - - #[test] - fn test_observe_and_sample_deterministic_challenge() { - // Fiat-Shamir property: same inputs produce same challenge. - // - // Two challengers with the same initial state, observing the same - // coefficients, should sample the same challenge. - let c0 = EF::from_u64(42); - let c2 = EF::from_u64(99); - let pow_bits = 0; - - // First run - let mut sumcheck1: SumcheckData = SumcheckData::default(); - let mut challenger1 = create_test_challenger(); - let r1 = sumcheck1.observe_and_sample::<_, F>(&mut challenger1, c0, c2, pow_bits); - - // Second run with fresh but identically-seeded challenger - let mut sumcheck2: SumcheckData = SumcheckData::default(); - let mut challenger2 = create_test_challenger(); - let r2 = sumcheck2.observe_and_sample::<_, F>(&mut challenger2, c0, c2, pow_bits); - - // Challenges should be identical - assert_eq!(r1, r2); - } - - #[test] - fn test_observe_and_sample_challenge_depends_on_history() { - // The challenge at round i depends on all previous observations. - // - // Two sequences with different history should produce different - // challenges even if the final round has the same coefficients. - let c0 = EF::from_u64(100); - let c2 = EF::from_u64(200); - let pow_bits = 0; - - // Sequence A: observe once, then observe (c0, c2) - let mut sumcheck_a: SumcheckData = SumcheckData::default(); - let mut challenger_a = create_test_challenger(); - let _r0_a = - sumcheck_a.observe_and_sample::<_, F>(&mut challenger_a, EF::ONE, EF::TWO, pow_bits); - let r1_a = sumcheck_a.observe_and_sample::<_, F>(&mut challenger_a, c0, c2, pow_bits); - - // Sequence B: directly observe (c0, c2) without prior round - let mut sumcheck_b: SumcheckData = SumcheckData::default(); - let mut challenger_b = create_test_challenger(); - let r_b = sumcheck_b.observe_and_sample::<_, F>(&mut challenger_b, c0, c2, pow_bits); - - // Challenges should differ due to different transcript history - assert_ne!(r1_a, r_b); - } - - #[test] - fn test_observe_and_sample_returns_extension_field_element() { - // The returned challenge should be a valid extension field element. - // - // This is verified implicitly by the type system, but we can also - // check that it's not trivially zero (with high probability). - let c0 = EF::from_u64(7); - let c2 = EF::from_u64(11); - let pow_bits = 0; - - let mut sumcheck: SumcheckData = SumcheckData::default(); - let mut challenger = create_test_challenger(); - - let r: EF = sumcheck.observe_and_sample::<_, F>(&mut challenger, c0, c2, pow_bits); - - // The challenge should (with overwhelming probability) be non-zero - assert_ne!(r, EF::ZERO); - } -} diff --git a/src/whir/prover/mod.rs b/src/whir/prover/mod.rs index 826465ad..4120a121 100644 --- a/src/whir/prover/mod.rs +++ b/src/whir/prover/mod.rs @@ -1,7 +1,6 @@ use alloc::vec::Vec; use core::ops::Deref; -use p3_challenger::{FieldChallenger, GrindingChallenger}; use p3_commit::{ExtensionMmcs, Mmcs}; use p3_dft::TwoAdicSubgroupDft; use p3_field::{ExtensionField, Field, TwoAdicField}; @@ -19,15 +18,17 @@ use tracing::{info_span, instrument}; use super::{ committer::Witness, constraints::statement::EqStatement, - parameters::{InitialPhaseConfig, WhirConfig}, + parameters::{InitialPhase, WhirConfig}, }; use crate::{ constant::K_SKIP_SUMCHECK, - fiat_shamir::errors::FiatShamirError, + fiat_shamir::{ + errors::FiatShamirError, + transcript::{Challenge, ProverTranscript}, + }, poly::{evals::EvaluationsList, multilinear::MultilinearPoint}, whir::{ constraints::{Constraint, statement::SelectStatement}, - proof::{QueryOpening, SumcheckData, WhirProof}, utils::get_challenge_stir_queries, }, }; @@ -38,31 +39,30 @@ pub type Proof = Vec>; pub type Leafs = Vec>; #[derive(Debug)] -pub struct Prover<'a, EF, F, H, C, Challenger>( +pub struct Prover<'a, F, EF, Hash, Compress>( /// Reference to the protocol configuration shared across prover components. - pub &'a WhirConfig, + pub &'a WhirConfig, ) where F: Field, EF: ExtensionField; -impl Deref for Prover<'_, EF, F, H, C, Challenger> +impl Deref for Prover<'_, F, EF, Hash, Compress> where F: Field, EF: ExtensionField, { - type Target = WhirConfig; + type Target = WhirConfig; fn deref(&self) -> &Self::Target { self.0 } } -impl Prover<'_, EF, F, H, C, Challenger> +impl Prover<'_, F, EF, Hash, Compress> where F: TwoAdicField + Ord, EF: ExtensionField + TwoAdicField, - Challenger: FieldChallenger + GrindingChallenger, { /// Validates that the total number of variables expected by the prover configuration /// matches the number implied by the folding schedule and the final rounds. @@ -90,7 +90,7 @@ where /// `true` if the statement structure is valid for this protocol instance. const fn validate_statement(&self, statement: &EqStatement) -> bool { statement.num_variables() == self.0.num_variables - && (self.0.initial_phase_config.has_initial_statement() || statement.is_empty()) + && (self.0.initial_phase.has_initial_statement() || statement.is_empty()) } /// Validates that the witness satisfies the structural requirements of the WHIR prover. @@ -113,7 +113,7 @@ where &self, witness: &Witness, DIGEST_ELEMS>, ) -> bool { - if !self.0.initial_phase_config.has_initial_statement() { + if !self.0.initial_phase.has_initial_statement() { assert!(witness.ood_statement.is_empty()); } witness.polynomial.num_variables() == self.0.num_variables @@ -131,8 +131,7 @@ where /// /// # Parameters /// - `dft`: A DFT backend used for evaluations - /// - `proof`: Mutable proof structure to store the generated proof data - /// - `challenger`: Mutable Fiat-Shamir challenger for transcript management + /// - `transcript`: Mutable Fiat-Shamir challenger for transcript management /// - `statement`: The public input, consisting of linear or nonlinear constraints /// - `witness`: The private witness satisfying the constraints, including committed values /// @@ -140,21 +139,21 @@ where /// # Errors /// Returns an error if the witness or statement are invalid, or if a round fails. #[instrument(skip_all)] - pub fn prove, const DIGEST_ELEMS: usize>( + pub fn prove, Transcript, const DIGEST_ELEMS: usize>( &self, dft: &Dft, - proof: &mut WhirProof, - challenger: &mut Challenger, + transcript: &mut Transcript, statement: EqStatement, witness: Witness, DIGEST_ELEMS>, ) -> Result<(), FiatShamirError> where - H: CryptographicHasher + Hash: CryptographicHasher + CryptographicHasher + Sync, - C: PseudoCompressionFunction<[F; DIGEST_ELEMS], 2> + Compress: PseudoCompressionFunction<[F; DIGEST_ELEMS], 2> + PseudoCompressionFunction<[F::Packing; DIGEST_ELEMS], 2> + Sync, + Transcript: ProverTranscript, [F; DIGEST_ELEMS]: Serialize + for<'de> Deserialize<'de>, { // Validate parameters @@ -167,11 +166,11 @@ where // Initialize the round state with inputs and initial polynomial data let mut round_state = - RoundState::initialize_first_round_state(self, proof, challenger, statement, witness)?; + RoundState::initialize_first_round_state(self, transcript, statement, witness)?; // Run the WHIR protocol round-by-round for round in 0..=self.n_rounds() { - self.round(dft, round, proof, challenger, &mut round_state)?; + self.round(dft, transcript, round, &mut round_state)?; } Ok(()) @@ -179,21 +178,21 @@ where #[instrument(skip_all, fields(round_number = round_index, log_size = self.num_variables - self.folding_factor.total_number(round_index)))] #[allow(clippy::too_many_lines)] - fn round>( + fn round, Transcript>( &self, dft: &Dft, + transcript: &mut Transcript, round_index: usize, - proof: &mut WhirProof, - challenger: &mut Challenger, round_state: &mut RoundState, DIGEST_ELEMS>, ) -> Result<(), FiatShamirError> where - H: CryptographicHasher + Hash: CryptographicHasher + CryptographicHasher + Sync, - C: PseudoCompressionFunction<[F; DIGEST_ELEMS], 2> + Compress: PseudoCompressionFunction<[F; DIGEST_ELEMS], 2> + PseudoCompressionFunction<[F::Packing; DIGEST_ELEMS], 2> + Sync, + Transcript: ProverTranscript, [F; DIGEST_ELEMS]: Serialize + for<'de> Deserialize<'de>, { let folded_evaluations = &round_state.sumcheck_prover.evals(); @@ -202,7 +201,7 @@ where // Base case: final round reached if round_index == self.n_rounds() { - return self.final_round(round_index, proof, challenger, round_state); + return self.final_round(transcript, round_index, round_state); } let round_params = &self.round_parameters[round_index]; @@ -233,7 +232,7 @@ where let folded_matrix = info_span!("dft", height = padded.height(), width = padded.width()) .in_scope(|| dft.dft_algebra_batch(padded).to_row_major_matrix()); - let mmcs = MerkleTreeMmcs::::new( + let mmcs = MerkleTreeMmcs::::new( self.merkle_hash.clone(), self.merkle_compress.clone(), ); @@ -242,28 +241,19 @@ where info_span!("commit matrix").in_scope(|| extension_mmcs.commit_matrix(folded_matrix)); // Observe the round merkle tree commitment - challenger.observe_slice(root.as_ref()); - - // Store commitment in proof - proof.rounds[round_index].commitment = root.into(); + transcript.write(*root.as_ref())?; // Handle OOD (Out-Of-Domain) samples let mut ood_statement = EqStatement::initialize(num_variables); - let mut ood_answers = Vec::with_capacity(round_params.ood_samples); - (0..round_params.ood_samples).for_each(|_| { - let point = MultilinearPoint::expand_from_univariate( - challenger.sample_algebra_element(), - num_variables, - ); - let eval = round_state.sumcheck_prover.eval(&point); - challenger.observe_algebra_element(eval); - ood_answers.push(eval); + (0..round_params.ood_samples).try_for_each(|_| { + let point = >::sample(transcript); + let point = MultilinearPoint::expand_from_univariate(point, num_variables); + let eval = round_state.sumcheck_prover.eval(&point); + transcript.write(eval)?; ood_statement.add_evaluated_constraint(point, eval); - }); - - // Store OOD answers in proof - proof.rounds[round_index].ood_answers = ood_answers; + Ok(()) + })?; // CRITICAL: Perform proof-of-work grinding to finalize the transcript before querying. // @@ -278,59 +268,35 @@ where // By forcing the prover to perform this expensive proof-of-work *after* committing but // *before* receiving the queries, we make it computationally infeasible to "shop" for // favorable challenges. The grinding effectively "locks in" the prover's commitment. - if round_params.pow_bits > 0 { - proof.rounds[round_index].pow_witness = challenger.grind(round_params.pow_bits); - } - - challenger.sample(); + transcript.pow(round_params.pow_bits)?; + let _ = >::sample(transcript); // STIR Queries - let stir_challenges_indexes = get_challenge_stir_queries::( + let stir_challenges_indexes = get_challenge_stir_queries::( + transcript, round_state.domain_size, self.folding_factor.at_round(round_index), round_params.num_queries, - challenger, )?; - let stir_vars = stir_challenges_indexes - .iter() - .map(|&i| round_state.next_domain_gen.exp_u64(i as u64)) - .collect::>(); - let mut stir_statement = SelectStatement::initialize(num_variables); - // Initialize vector of queries - let mut queries = Vec::with_capacity(stir_challenges_indexes.len()); - // Collect Merkle proofs for stir queries match &round_state.merkle_prover_data { None => { - let mut answers = Vec::with_capacity(stir_challenges_indexes.len()); - for challenge in &stir_challenges_indexes { + for idx in stir_challenges_indexes { let commitment = - mmcs.open_batch(*challenge, &round_state.commitment_merkle_prover_data); + mmcs.open_batch(idx, &round_state.commitment_merkle_prover_data); let answer = commitment.opened_values[0].clone(); - answers.push(answer.clone()); + transcript.write_hint_many(&answer)?; + transcript.write_hint_many(&commitment.opening_proof)?; - queries.push(QueryOpening::Base { - values: answer.clone(), - proof: commitment.opening_proof, - }); - } + // Determine if this is the special first round where the univariate skip is applied. + let is_skip_round = round_index == 0 + && matches!(self.initial_phase, InitialPhase::WithStatementSkip) + && self.folding_factor.at_round(0) >= K_SKIP_SUMCHECK; - // Determine if this is the special first round where the univariate skip is applied. - let is_skip_round = round_index == 0 - && matches!( - self.initial_phase_config, - InitialPhaseConfig::WithStatementUnivariateSkip - ) - && self.folding_factor.at_round(0) >= K_SKIP_SUMCHECK; - - // Process each set of evaluations retrieved from the Merkle tree openings. - for (answer, var) in answers.iter().zip(stir_vars.into_iter()) { - let evals = EvaluationsList::new(answer.clone()); - // Fold the polynomial represented by the `answer` evaluations using the verifier's challenge. - // The evaluation method depends on whether this is a "skip round" or a "standard round". + let evals = EvaluationsList::new(answer); if is_skip_round { // Case 1: Univariate Skip Round Evaluation // @@ -363,7 +329,8 @@ where // Evaluate the resulting smaller polynomial at the remaining challenges `r_rest`. let eval = EvaluationsList::new(folded_row).evaluate_hypercube_ext::(&r_rest); - stir_statement.add_constraint(var, eval); + stir_statement + .add_constraint(round_state.next_domain_gen.exp_u64(idx as u64), eval); } else { // Case 2: Standard Sumcheck Round // @@ -371,51 +338,40 @@ where // Perform a standard multilinear evaluation at the full challenge point `r`. let eval = evals.evaluate_hypercube_base(&round_state.folding_randomness); - stir_statement.add_constraint(var, eval); + stir_statement + .add_constraint(round_state.next_domain_gen.exp_u64(idx as u64), eval); } } } Some(data) => { - let mut answers = Vec::with_capacity(stir_challenges_indexes.len()); - for challenge in &stir_challenges_indexes { - let commitment = extension_mmcs.open_batch(*challenge, data); + for idx in stir_challenges_indexes { + let commitment = extension_mmcs.open_batch(idx, data); let answer = commitment.opened_values[0].clone(); - answers.push(answer.clone()); - queries.push(QueryOpening::Extension { - values: answer.clone(), - proof: commitment.opening_proof, - }); - } + transcript.write_hint_many(&answer)?; + transcript.write_hint_many(&commitment.opening_proof)?; - // Process each set of evaluations retrieved from the Merkle tree openings. - for (answer, var) in answers.iter().zip(stir_vars.into_iter()) { // Wrap the evaluations to represent the polynomial. - let evals = EvaluationsList::new(answer.clone()); + let evals = EvaluationsList::new(answer); // Perform a standard multilinear evaluation at the full challenge point `r`. let eval = evals.evaluate_hypercube_ext::(&round_state.folding_randomness); - stir_statement.add_constraint(var, eval); + stir_statement + .add_constraint(round_state.next_domain_gen.exp_u64(idx as u64), eval); } } } - // Store queries in proof - proof.rounds[round_index].queries = queries; - let constraint = Constraint::new( - challenger.sample_algebra_element(), + >::sample(transcript), ood_statement, stir_statement, ); - let mut sumcheck_data: SumcheckData = SumcheckData::default(); let folding_randomness = round_state.sumcheck_prover.compute_sumcheck_polynomials( - &mut sumcheck_data, - challenger, + transcript, folding_factor_next, round_params.folding_pow_bits, Some(constraint), - ); - proof.set_sumcheck_data_at(sumcheck_data, round_index); + )?; // Update round state round_state.domain_size = new_domain_size; @@ -428,27 +384,24 @@ where } #[instrument(skip_all)] - fn final_round( + fn final_round( &self, + transcript: &mut Transcript, round_index: usize, - proof: &mut WhirProof, - challenger: &mut Challenger, round_state: &mut RoundState, DIGEST_ELEMS>, ) -> Result<(), FiatShamirError> where - H: CryptographicHasher + Hash: CryptographicHasher + CryptographicHasher + Sync, - C: PseudoCompressionFunction<[F; DIGEST_ELEMS], 2> + Compress: PseudoCompressionFunction<[F; DIGEST_ELEMS], 2> + PseudoCompressionFunction<[F::Packing; DIGEST_ELEMS], 2> + Sync, + Transcript: ProverTranscript, [F; DIGEST_ELEMS]: Serialize + for<'de> Deserialize<'de>, { // Directly send coefficients of the polynomial to the verifier. - challenger.observe_algebra_slice(round_state.sumcheck_prover.evals().as_slice()); - - // Store the final polynomial in the proof - proof.final_poly = Some(round_state.sumcheck_prover.evals()); + transcript.write_many(round_state.sumcheck_prover.evals().as_slice())?; // CRITICAL: Perform proof-of-work grinding to finalize the transcript before querying. // @@ -463,23 +416,21 @@ where // By forcing the prover to perform this expensive proof-of-work *after* committing but // *before* receiving the queries, we make it computationally infeasible to "shop" for // favorable challenges. The grinding effectively "locks in" the prover's commitment. - if self.final_pow_bits > 0 { - proof.final_pow_witness = challenger.grind(self.final_pow_bits); - } + transcript.pow(self.final_pow_bits)?; // Final verifier queries and answers. The indices are over the folded domain. - let final_challenge_indexes = get_challenge_stir_queries::( + let final_challenge_indexes = get_challenge_stir_queries::( + transcript, // The size of the original domain before folding round_state.domain_size, // The folding factor we used to fold the previous polynomial self.folding_factor.at_round(round_index), // Number of final verification queries self.final_queries, - challenger, )?; // Every query requires opening these many in the previous Merkle tree - let mmcs = MerkleTreeMmcs::::new( + let mmcs = MerkleTreeMmcs::::new( self.merkle_hash.clone(), self.merkle_compress.clone(), ); @@ -487,40 +438,30 @@ where match &round_state.merkle_prover_data { None => { - for challenge in final_challenge_indexes { + for idx in final_challenge_indexes { let commitment = - mmcs.open_batch(challenge, &round_state.commitment_merkle_prover_data); - - proof.final_queries.push(QueryOpening::Base { - values: commitment.opened_values[0].clone(), - proof: commitment.opening_proof, - }); + mmcs.open_batch(idx, &round_state.commitment_merkle_prover_data); + transcript.write_hint_many(&commitment.opened_values[0])?; + transcript.write_hint_many(&commitment.opening_proof)?; } } Some(data) => { for challenge in final_challenge_indexes { let commitment = extension_mmcs.open_batch(challenge, data); - proof.final_queries.push(QueryOpening::Extension { - values: commitment.opened_values[0].clone(), - proof: commitment.opening_proof, - }); + transcript.write_hint_many(&commitment.opened_values[0])?; + transcript.write_hint_many(&commitment.opening_proof)?; } } } // Run final sumcheck if required - if self.final_sumcheck_rounds > 0 { - let mut sumcheck_data: SumcheckData = SumcheckData::default(); - round_state.sumcheck_prover.compute_sumcheck_polynomials( - &mut sumcheck_data, - challenger, - self.final_sumcheck_rounds, - self.final_folding_pow_bits, - None, - ); - proof.set_final_sumcheck_data(sumcheck_data); - } + round_state.sumcheck_prover.compute_sumcheck_polynomials( + transcript, + self.final_sumcheck_rounds, + self.final_folding_pow_bits, + None, + )?; Ok(()) } diff --git a/src/whir/prover/round_state/state.rs b/src/whir/prover/round_state/state.rs index 6d86c82a..da3b5a76 100644 --- a/src/whir/prover/round_state/state.rs +++ b/src/whir/prover/round_state/state.rs @@ -2,9 +2,8 @@ //! //! This module implements the core round state management for the WHIR protocol. -use alloc::{sync::Arc, vec::Vec}; +use alloc::sync::Arc; -use p3_challenger::{FieldChallenger, GrindingChallenger}; use p3_field::{ExtensionField, TwoAdicField}; use p3_matrix::dense::DenseMatrix; use p3_merkle_tree::MerkleTree; @@ -12,13 +11,16 @@ use tracing::instrument; use crate::{ constant::K_SKIP_SUMCHECK, - fiat_shamir::errors::FiatShamirError, + fiat_shamir::{ + errors::FiatShamirError, + transcript::{Challenge, Pow, Writer}, + }, poly::multilinear::MultilinearPoint, sumcheck::sumcheck_single::SumcheckSingle, whir::{ committer::{RoundMerkleTree, Witness}, constraints::{Constraint, statement::EqStatement}, - proof::{InitialPhase, SumcheckSkipData, WhirProof}, + parameters::InitialPhase, prover::Prover, }, }; @@ -135,36 +137,31 @@ where /// /// Returns the complete `RoundState` ready for the first WHIR folding round. #[instrument(skip_all)] - pub fn initialize_first_round_state( - prover: &Prover<'_, EF, F, MyChallenger, C, Challenger>, - proof: &mut WhirProof, - challenger: &mut Challenger, + pub fn initialize_first_round_state( + prover: &Prover<'_, F, EF, Hash, Compress>, + transcript: &mut Transcript, mut statement: EqStatement, witness: Witness, DIGEST_ELEMS>, ) -> Result where - Challenger: FieldChallenger + GrindingChallenger, - MyChallenger: Clone, - C: Clone, + Transcript: Writer<[F; DIGEST_ELEMS]> + Writer + Challenge + Pow, { // Append OOD constraints to statement for Reed-Solomon proximity testing statement.concatenate(&witness.ood_statement); // Protocol branching based on initial phase variant in proof - let (sumcheck_prover, folding_randomness) = match &mut proof.initial_phase { + let (sumcheck_prover, folding_randomness) = match prover.initial_phase { // Branch: WithStatementSkip - use univariate skip optimization - InitialPhase::WithStatementSkip(skip_data) + InitialPhase::WithStatementSkip if K_SKIP_SUMCHECK <= prover.folding_factor.at_round(0) => { // Build constraint with random linear combination - let constraint = - Constraint::new_eq_only(challenger.sample_algebra_element(), statement.clone()); + let constraint = Constraint::new_eq_only(transcript.sample(), statement.clone()); // Use univariate skip by skipping k variables SumcheckSingle::with_skip( &witness.polynomial, - skip_data, - challenger, + transcript, prover.folding_factor.at_round(0), prover.starting_folding_pow_bits, K_SKIP_SUMCHECK, @@ -173,7 +170,7 @@ where } // Branch: WithStatementSvo - SVO optimization - InitialPhase::WithStatementSvo { sumcheck } => { + InitialPhase::WithStatementSvo => { // SVO optimization requirements (see Procedure 9 in https://eprint.iacr.org/2025/1117): // 1. At least 2 * NUM_SVO_ROUNDS variables - The SVO algorithm partitions // variables into Prefix (k), Inner (l/2), and Outer segments. For these @@ -186,8 +183,7 @@ where const MIN_SVO_FOLDING_FACTOR: usize = 6; // Build constraint with random linear combination - let constraint = - Constraint::new_eq_only(challenger.sample_algebra_element(), statement.clone()); + let constraint = Constraint::new_eq_only(transcript.sample(), statement.clone()); let folding_factor = prover.folding_factor.at_round(0); let has_single_constraint = constraint.eq_statement.len() == 1; @@ -196,9 +192,8 @@ where // Use SVO optimization: first 3 rounds use specialized algorithm, // remaining rounds use standard Algorithm 5 SumcheckSingle::from_base_evals_svo( + transcript, &witness.polynomial, - sumcheck, - challenger, folding_factor, prover.starting_folding_pow_bits, &constraint, @@ -208,9 +203,8 @@ where // - Input is too small (folding_factor < MIN_SVO_FOLDING_FACTOR) // - Multiple constraints exist (SVO only handles single constraint, see TODO above) SumcheckSingle::from_base_evals( + transcript, &witness.polynomial, - sumcheck, - challenger, folding_factor, prover.starting_folding_pow_bits, &constraint, @@ -219,17 +213,14 @@ where } // Branch: WithStatement or WithStatementSkip (fallback when folding_factor < K_SKIP) - InitialPhase::WithStatement { sumcheck } - | InitialPhase::WithStatementSkip(SumcheckSkipData { sumcheck, .. }) => { + InitialPhase::WithStatementClassic | InitialPhase::WithStatementSkip => { // Build constraint with random linear combination - let constraint = - Constraint::new_eq_only(challenger.sample_algebra_element(), statement.clone()); + let constraint = Constraint::new_eq_only(transcript.sample(), statement.clone()); // Standard sumcheck protocol without optimization SumcheckSingle::from_base_evals( + transcript, &witness.polynomial, - sumcheck, - challenger, prover.folding_factor.at_round(0), prover.starting_folding_pow_bits, &constraint, @@ -237,12 +228,10 @@ where } // Branch: WithoutStatement - direct polynomial folding path - InitialPhase::WithoutStatement { pow_witness } => { + InitialPhase::WithoutStatement => { // Sample folding challenges α_1, ..., α_k let folding_randomness = MultilinearPoint::new( - (0..prover.folding_factor.at_round(0)) - .map(|_| challenger.sample_algebra_element()) - .collect::>(), + transcript.sample_many(prover.folding_factor.at_round(0)), ); // Apply folding transformation: f(X_0, ..., X_{n-1}) → f'(X_k, ..., X_{n-1}) @@ -257,13 +246,11 @@ where ); // Apply proof-of-work grinding and store witness (only if pow_bits > 0) - if prover.starting_folding_pow_bits > 0 { - *pow_witness = challenger.grind(prover.starting_folding_pow_bits); - } + transcript.pow(prover.starting_folding_pow_bits)?; - (sumcheck, folding_randomness) + Ok((sumcheck, folding_randomness)) } - }; + }?; // Initialize complete round state for first WHIR protocol round Ok(Self { diff --git a/src/whir/prover/round_state/tests.rs b/src/whir/prover/round_state/tests.rs index a81f93cc..c9b434de 100644 --- a/src/whir/prover/round_state/tests.rs +++ b/src/whir/prover/round_state/tests.rs @@ -9,21 +9,20 @@ use p3_symmetric::{PaddingFreeSponge, TruncatedPermutation}; use rand::{SeedableRng, rngs::SmallRng}; use crate::{ - fiat_shamir::domain_separator::DomainSeparator, + fiat_shamir::{domain_separator::DomainSeparator, transcript::FiatShamirWriter}, parameters::{FoldingFactor, ProtocolParameters, errors::SecurityAssumption}, poly::{evals::EvaluationsList, multilinear::MultilinearPoint}, whir::{ WhirConfig, committer::{Witness, writer::CommitmentWriter}, constraints::statement::EqStatement, - parameters::InitialPhaseConfig, - proof::WhirProof, + parameters::InitialPhase, prover::{Prover, round_state::RoundState}, }, }; type F = BabyBear; -type EF4 = BinomialExtensionField; +type EF = BinomialExtensionField; type Perm = Poseidon2BabyBear<16>; type MyHash = PaddingFreeSponge; @@ -44,10 +43,10 @@ const DIGEST_ELEMS: usize = 8; /// for round state construction in WHIR tests. fn make_test_config( num_variables: usize, - initial_phase_config: InitialPhaseConfig, + initial_phase_config: InitialPhase, folding_factor: usize, pow_bits: usize, -) -> WhirConfig { +) -> WhirConfig { let mut rng = SmallRng::seed_from_u64(1); let perm = Perm::new_from_rng_128(&mut rng); @@ -57,7 +56,7 @@ fn make_test_config( // Define the core protocol parameters for WHIR, customizing behavior based // on whether to start with an initial sumcheck and how to fold the polynomial. let protocol_params = ProtocolParameters { - initial_phase_config, + initial_phase: initial_phase_config, security_level: 80, pow_bits, rs_domain_initial_reduction_factor: 1, @@ -83,61 +82,38 @@ fn make_test_config( /// This is used as a boilerplate step before running the first WHIR round. #[allow(clippy::type_complexity)] fn setup_domain_and_commitment( - params: &WhirConfig, + params: &WhirConfig, poly: EvaluationsList, ) -> ( - WhirProof, - MyChallenger, - Witness, DIGEST_ELEMS>, + FiatShamirWriter, + Witness, DIGEST_ELEMS>, ) { - // Build ProtocolParameters from WhirConfig fields - let protocol_params = ProtocolParameters { - initial_phase_config: params.initial_phase_config, - security_level: params.security_level, - pow_bits: params.starting_folding_pow_bits, - folding_factor: params.folding_factor, - merkle_hash: params.merkle_hash.clone(), - merkle_compress: params.merkle_compress.clone(), - soundness_type: params.soundness_type, - starting_log_inv_rate: params.starting_log_inv_rate, - rs_domain_initial_reduction_factor: 1, - }; - - // Create WhirProof structure from protocol parameters - let whir_proof = WhirProof::from_protocol_parameters(&protocol_params, poly.num_variables()); - // Create a new Fiat-Shamir domain separator. let mut domsep = DomainSeparator::new(vec![]); // Observe the public statement into the transcript for binding. - domsep.commit_statement::<_, _, _, 8>(params); + domsep.commit_statement::<_, _, 8>(params); // Reserve transcript space for WHIR proof messages. - domsep.add_whir_proof::<_, _, _, 8>(params); + domsep.add_whir_proof::<_, _, 8>(params); // Convert the domain separator into a mutable prover-side transcript. let mut rng = SmallRng::seed_from_u64(1); - let mut prover_challenger = MyChallenger::new(Perm::new_from_rng_128(&mut rng)); - domsep.observe_domain_separator(&mut prover_challenger); + let mut challenger = MyChallenger::new(Perm::new_from_rng_128(&mut rng)); + domsep.observe_domain_separator(&mut challenger); + let mut transcript = FiatShamirWriter::init(challenger); // Create a committer using the protocol configuration (Merkle parameters, hashers, etc.). let committer = CommitmentWriter::new(params); - let mut proof = WhirProof::from_protocol_parameters(&protocol_params, poly.num_variables()); - // Perform DFT-based commitment to the polynomial, producing a witness // which includes the Merkle tree and polynomial values. let witness = committer - .commit( - &Radix2DFTSmallBatch::::default(), - &mut proof, - &mut prover_challenger, - poly, - ) + .commit(&Radix2DFTSmallBatch::::default(), &mut transcript, poly) .unwrap(); // Return all initialized components needed for round state setup. - (whir_proof, prover_challenger, witness) + (transcript, witness) } #[test] @@ -149,7 +125,7 @@ fn test_no_initial_statement_no_sumcheck() { // - no initial sumcheck, // - folding factor 2, // - no PoW grinding. - let config = make_test_config(num_variables, InitialPhaseConfig::WithoutStatement, 2, 0); + let config = make_test_config(num_variables, InitialPhase::WithoutStatement, 2, 0); // Define a polynomial let poly = EvaluationsList::new(vec![F::from_u64(3); 1 << num_variables]); @@ -158,16 +134,15 @@ fn test_no_initial_statement_no_sumcheck() { // - domain separator for Fiat-Shamir transcript, // - prover state, // - witness containing Merkle tree for `poly`. - let (mut proof, mut challenger, witness) = setup_domain_and_commitment(&config, poly); + let (mut transcript, witness) = setup_domain_and_commitment(&config, poly); // Create an empty public statement (no constraints) - let statement = EqStatement::::initialize(num_variables); + let statement = EqStatement::::initialize(num_variables); // Initialize the round state using the setup configuration and witness let state = RoundState::initialize_first_round_state( &Prover(&config), - &mut proof, - &mut challenger, + &mut transcript, statement, witness, ) @@ -189,12 +164,7 @@ fn test_initial_statement_with_folding_factor_3() { // - initial statement enabled (sumcheck will run), // - folding factor = 3 (fold all variables in the first round), // - PoW disabled. - let config = make_test_config( - num_variables, - InitialPhaseConfig::WithStatementClassic, - 3, - 0, - ); + let config = make_test_config(num_variables, InitialPhase::WithStatementClassic, 3, 0); // Define the multilinear polynomial: // f(X0, X1, X2) = 1 + 2*X2 + 3*X1 + 4*X1*X2 @@ -220,7 +190,7 @@ fn test_initial_statement_with_folding_factor_3() { ]); // Manual redefinition of the same polynomial as a function for evaluation - let f = |x0: EF4, x1: EF4, x2: EF4| { + let f = |x0: EF, x1: EF, x2: EF| { x2 * e2 + x1 * e3 + x1 * x2 * e4 @@ -232,20 +202,19 @@ fn test_initial_statement_with_folding_factor_3() { }; // Add a single equality constraint to the statement: f(1,1,1) = expected value - let mut statement = EqStatement::::initialize(num_variables); + let mut statement = EqStatement::::initialize(num_variables); statement.add_evaluated_constraint( - MultilinearPoint::new(vec![EF4::ONE, EF4::ONE, EF4::ONE]), - f(EF4::ONE, EF4::ONE, EF4::ONE), + MultilinearPoint::new(vec![EF::ONE, EF::ONE, EF::ONE]), + f(EF::ONE, EF::ONE, EF::ONE), ); // Set up the domain separator, prover state, and witness for this configuration - let (mut proof, mut challenger_rf, witness) = setup_domain_and_commitment(&config, poly); + let (mut transcript, witness) = setup_domain_and_commitment(&config, poly); // Run the first round state initialization (this will trigger sumcheck) let state = RoundState::initialize_first_round_state( &Prover(&config), - &mut proof, - &mut challenger_rf, + &mut transcript, statement, witness, ) @@ -268,7 +237,7 @@ fn test_initial_statement_with_folding_factor_3() { assert_eq!(eval_at_point, expected); // Check that dot product of evaluations and weights matches the final sum - let dot_product: EF4 = sumcheck.poly.dot_product(); + let dot_product: EF = sumcheck.poly.dot_product(); assert_eq!(dot_product, sumcheck.sum); // The `folding_randomness` should store values in forward order (X0, X1, X2) @@ -291,35 +260,29 @@ fn test_zero_poly_multiple_constraints() { let num_variables = 3; // Build a WHIR config with an initial statement, folding factor 1, and no PoW - let config = make_test_config( - num_variables, - InitialPhaseConfig::WithStatementClassic, - 1, - 0, - ); + let config = make_test_config(num_variables, InitialPhase::WithStatementClassic, 1, 0); // Define a zero polynomial: f(X) = 0 for all X let poly = EvaluationsList::new(vec![F::ZERO; 1 << num_variables]); // Generate domain separator, prover state, and Merkle commitment witness for the poly - let (mut proof, mut challenger_rf, witness) = setup_domain_and_commitment(&config, poly); + let (mut transcript, witness) = setup_domain_and_commitment(&config, poly); // Create a new statement with multiple constraints - let mut statement = EqStatement::::initialize(num_variables); + let mut statement = EqStatement::::initialize(num_variables); // Add one equality constraint per Boolean input: f(x) = 0 for all x ∈ {0,1}³ for i in 0..1 << num_variables { let point = (0..num_variables) - .map(|b| EF4::from_u64(((i >> b) & 1) as u64)) + .map(|b| EF::from_u64(((i >> b) & 1) as u64)) .collect(); - statement.add_evaluated_constraint(MultilinearPoint::new(point), EF4::ZERO); + statement.add_evaluated_constraint(MultilinearPoint::new(point), EF::ZERO); } // Initialize the first round of the WHIR protocol with the zero polynomial and constraints let state = RoundState::initialize_first_round_state( &Prover(&config), - &mut proof, - &mut challenger_rf, + &mut transcript, statement, witness, ) @@ -331,12 +294,12 @@ fn test_zero_poly_multiple_constraints() { for (f, w) in sumcheck.evals().iter().zip(&sumcheck.weights()) { // Each evaluation should be 0 - assert_eq!(*f, EF4::ZERO); + assert_eq!(*f, EF::ZERO); // Their contribution to the weighted sum should also be 0 - assert_eq!(*f * *w, EF4::ZERO); + assert_eq!(*f * *w, EF::ZERO); } // Final claimed sum is 0 - assert_eq!(sumcheck.sum, EF4::ZERO); + assert_eq!(sumcheck.sum, EF::ZERO); // Folding randomness should have length equal to the folding factor (1) assert_eq!(sumcheck_randomness.num_variables(), 1); @@ -365,7 +328,7 @@ fn test_initialize_round_state_with_initial_statement() { // - PoW bits enabled. let config = make_test_config( num_variables, - InitialPhaseConfig::WithStatementClassic, + InitialPhase::WithStatementClassic, 1, pow_bits, ); @@ -393,7 +356,7 @@ fn test_initialize_round_state_with_initial_statement() { ]); // Equivalent function for evaluating the polynomial manually - let f = |x0: EF4, x1: EF4, x2: EF4| { + let f = |x0: EF, x1: EF, x2: EF| { x2 * e2 + x1 * e3 + x1 * x2 * e4 @@ -405,21 +368,20 @@ fn test_initialize_round_state_with_initial_statement() { }; // Construct a statement with one evaluation constraint at the point (1, 0, 1) - let mut statement = EqStatement::::initialize(num_variables); + let mut statement = EqStatement::::initialize(num_variables); statement.add_evaluated_constraint( - MultilinearPoint::new(vec![EF4::ONE, EF4::ZERO, EF4::ONE]), - f(EF4::ONE, EF4::ZERO, EF4::ONE), + MultilinearPoint::new(vec![EF::ONE, EF::ZERO, EF::ONE]), + f(EF::ONE, EF::ZERO, EF::ONE), ); // Set up Fiat-Shamir domain and produce commitment + witness // Generate domain separator, prover state, and Merkle commitment witness for the poly - let (mut proof, mut challenger_rf, witness) = setup_domain_and_commitment(&config, poly); + let (mut transcript, witness) = setup_domain_and_commitment(&config, poly); // Run the first round initialization let state = RoundState::initialize_first_round_state( &Prover(&config), - &mut proof, - &mut challenger_rf, + &mut transcript, statement, witness, ) @@ -433,13 +395,13 @@ fn test_initialize_round_state_with_initial_statement() { let evals_f = &sumcheck.evals(); assert_eq!( evals_f.evaluate_hypercube_ext::(&MultilinearPoint::new(vec![ - EF4::from_u64(32636), - EF4::from_u64(9876) + EF::from_u64(32636), + EF::from_u64(9876) ])), f( sumcheck_randomness[0], - EF4::from_u64(32636), - EF4::from_u64(9876), + EF::from_u64(32636), + EF::from_u64(9876), ) ); diff --git a/src/whir/utils.rs b/src/whir/utils.rs index 1dfc643b..9255100d 100644 --- a/src/whir/utils.rs +++ b/src/whir/utils.rs @@ -1,10 +1,9 @@ use alloc::vec::Vec; -use p3_challenger::FieldChallenger; -use p3_field::{ExtensionField, Field}; +use p3_field::Field; use p3_util::log2_strict_usize; -use crate::fiat_shamir::errors::FiatShamirError; +use crate::fiat_shamir::{errors::FiatShamirError, transcript::ChallengeBits}; /// Computes the optimal workload size for `T` to fit in L1 cache (32 KB). /// @@ -48,17 +47,12 @@ pub const fn workload_size() -> usize { /// /// ## Returns /// Sorted, deduplicated vector of query indices in [0, folded_domain_size) -pub fn get_challenge_stir_queries( +pub fn get_challenge_stir_queries( + transcript: &mut Transcript, domain_size: usize, folding_factor: usize, num_queries: usize, - challenger: &mut Challenger, -) -> Result, FiatShamirError> -where - Challenger: FieldChallenger, - F: Field, - EF: ExtensionField, -{ +) -> Result, FiatShamirError> { // COMPUTE DOMAIN AND BATCHING PARAMETERS // Apply folding to get the final, smaller domain size. @@ -91,7 +85,7 @@ where // safe call to the transcript, reducing N transcript operations to just 1. // Sample all the random bits needed for all queries in one go. - let mut all_bits = challenger.sample_bits(total_bits_needed); + let mut all_bits = transcript.sample(total_bits_needed); // Create a bitmask to extract `domain_size_bits` chunks from the sampled randomness. // // Example: 16 bits -> (1 << 16) - 1 -> 0b1111_1111_1111_1111 @@ -135,7 +129,7 @@ where // Sample just enough bits for the current batch. // // This is the expensive operation. - let mut all_bits = challenger.sample_bits(batch_bits); + let mut all_bits = transcript.sample(batch_bits); // Unpack the batch of bits into query indices, same as the single-batch path. for _ in 0..batch_size { @@ -154,7 +148,7 @@ where // 2 queries per call), we fall back to the naive approach of one call per query. for _ in 0..num_queries { - let value = challenger.sample_bits(domain_size_bits); + let value = transcript.sample(domain_size_bits); queries.push(value); } } diff --git a/src/whir/verifier/mod.rs b/src/whir/verifier/mod.rs index 470d7089..42411410 100644 --- a/src/whir/verifier/mod.rs +++ b/src/whir/verifier/mod.rs @@ -1,14 +1,14 @@ -use alloc::{format, vec, vec::Vec}; +use alloc::vec::Vec; use core::{fmt::Debug, ops::Deref, slice::from_ref}; use errors::VerifierError; -use p3_challenger::{FieldChallenger, GrindingChallenger}; use p3_commit::{BatchOpeningRef, ExtensionMmcs, Mmcs}; -use p3_field::{ExtensionField, Field, TwoAdicField}; +use p3_field::{ExtensionField, TwoAdicField}; use p3_interpolation::interpolate_subgroup; use p3_matrix::Dimensions; use p3_merkle_tree::MerkleTreeMmcs; use p3_symmetric::{CryptographicHasher, Hash, PseudoCompressionFunction}; +use p3_util::log2_strict_usize; use serde::{Deserialize, Serialize}; use tracing::instrument; @@ -18,15 +18,13 @@ use super::{ use crate::{ alloc::string::ToString, constant::K_SKIP_SUMCHECK, + fiat_shamir::transcript::{Challenge, VerifierTranscript}, poly::{evals::EvaluationsList, multilinear::MultilinearPoint}, whir::{ EqStatement, constraints::{Constraint, evaluator::ConstraintPolyEvaluator, statement::SelectStatement}, - parameters::{InitialPhaseConfig, WhirConfig}, - proof::{QueryOpening, WhirProof}, - verifier::sumcheck::{ - verify_final_sumcheck_rounds, verify_initial_sumcheck_rounds, verify_sumcheck_rounds, - }, + parameters::{InitialPhase, WhirConfig}, + verifier::sumcheck::{verify_initial_sumcheck_rounds, verify_standard_sumcheck_rounds}, }, }; @@ -38,37 +36,41 @@ pub mod sumcheck; /// This type provides a lightweight, ergonomic interface to verification methods /// by wrapping a reference to the `WhirConfig`. #[derive(Debug)] -pub struct Verifier<'a, EF, F, H, C, Challenger>( +pub struct Verifier<'a, F, EF, Hash, Compress>( /// Reference to the verifier’s configuration containing all round parameters. - pub(crate) &'a WhirConfig, -) -where - F: Field, - EF: ExtensionField; + pub(crate) &'a WhirConfig, +); + +impl Deref for Verifier<'_, F, EF, Hash, Compress> { + type Target = WhirConfig; + + fn deref(&self) -> &Self::Target { + self.0 + } +} -impl<'a, EF, F, H, C, Challenger> Verifier<'a, EF, F, H, C, Challenger> +impl<'a, F, EF, Hasher, Compress> Verifier<'a, F, EF, Hasher, Compress> where F: TwoAdicField, EF: ExtensionField + TwoAdicField, - Challenger: FieldChallenger + GrindingChallenger, { - pub const fn new(params: &'a WhirConfig) -> Self { + pub const fn new(params: &'a WhirConfig) -> Self { Self(params) } #[instrument(skip_all)] #[allow(clippy::too_many_lines)] - pub fn verify( + pub fn verify( &self, - proof: &WhirProof, - challenger: &mut Challenger, + transcript: &mut Transcript, parsed_commitment: &ParsedCommitment>, mut statement: EqStatement, ) -> Result, VerifierError> where - H: CryptographicHasher + Sync, - C: PseudoCompressionFunction<[F; DIGEST_ELEMS], 2> + Sync, + Hasher: CryptographicHasher + Sync, + Compress: PseudoCompressionFunction<[F; DIGEST_ELEMS], 2> + Sync, [F; DIGEST_ELEMS]: Serialize + for<'de> Deserialize<'de>, + Transcript: VerifierTranscript, { // During the rounds we collect constraints, combination randomness, folding randomness // and we update the claimed sum of constraint evaluation. @@ -78,11 +80,11 @@ where let mut prev_commitment = parsed_commitment.clone(); // Optional constraint building - only if we have a statement - if self.initial_phase_config.has_initial_statement() { + if self.initial_phase.has_initial_statement() { statement.concatenate(&prev_commitment.ood_statement); let constraint = Constraint::new( - challenger.sample_algebra_element(), + >::sample(transcript), statement, SelectStatement::initialize(self.num_variables), ); @@ -96,8 +98,8 @@ where // Verify initial sumcheck let folding_randomness = verify_initial_sumcheck_rounds( - &proof.initial_phase, - challenger, + transcript, + self.initial_phase, &mut claimed_eval, self.folding_factor.at_round(0), self.starting_folding_pow_bits, @@ -110,26 +112,23 @@ where let round_params = &self.round_parameters[round_index]; // Receive commitment to the folded polynomial (likely encoded at higher expansion) - let new_commitment = ParsedCommitment::<_, Hash>::parse_with_round( - proof, - challenger, + let new_commitment = ParsedCommitment::<_, [F; DIGEST_ELEMS]>::parse( + transcript, round_params.num_variables, round_params.ood_samples, - Some(round_index), - ); + )?; // Verify in-domain challenges on the previous commitment. let stir_statement = self.verify_stir_challenges( - proof, - challenger, + transcript, round_params, - &prev_commitment, + prev_commitment.root, round_folding_randomness.last().unwrap(), round_index, )?; let constraint = Constraint::new( - challenger.sample_algebra_element(), + >::sample(transcript), new_commitment.ood_statement.clone(), stir_statement, ); @@ -138,11 +137,10 @@ where // TODO: SVO optimization is not yet fully implemented // Falls back to classic sumcheck for all optimization modes - let folding_randomness = verify_sumcheck_rounds( - &proof.rounds[round_index], - challenger, - &mut claimed_eval, + let folding_randomness = verify_standard_sumcheck_rounds( + transcript, self.folding_factor.at_round(round_index + 1), + &mut claimed_eval, round_params.folding_pow_bits, )?; @@ -152,20 +150,15 @@ where prev_commitment = new_commitment; } - // In the final round we receive the full polynomial instead of a commitment. - let Some(final_evaluations) = proof.final_poly.clone() else { - panic!("Expected final polynomial"); - }; - // Observe the final polynomial to the challenger - challenger.observe_algebra_slice(final_evaluations.as_slice()); + let final_evaluations = + EvaluationsList::::new(transcript.read_many(1 << self.final_sumcheck_rounds)?); // Verify in-domain challenges on the previous commitment. let stir_statement = self.verify_stir_challenges( - proof, - challenger, + transcript, &self.final_round_config(), - &prev_commitment, + prev_commitment.root, round_folding_randomness.last().unwrap(), self.n_rounds(), )?; @@ -181,11 +174,10 @@ where // TODO: SVO optimization is not yet fully implemented // Falls back to classic sumcheck for all optimization modes - let final_sumcheck_randomness = verify_final_sumcheck_rounds( - proof.final_sumcheck.as_ref(), - challenger, - &mut claimed_eval, + let final_sumcheck_randomness = verify_standard_sumcheck_rounds( + transcript, self.final_sumcheck_rounds, + &mut claimed_eval, self.final_folding_pow_bits, )?; @@ -201,7 +193,7 @@ where // For skip case, don't reverse the randomness (prover stores it in forward order) // For non-skip case, reverse it to match the prover's storage - let is_skip_used = self.initial_phase_config.is_univariate_skip() + let is_skip_used = self.initial_phase.is_univariate_skip() && K_SKIP_SUMCHECK <= self.folding_factor.at_round(0); let point_for_eval = if is_skip_used { @@ -255,19 +247,19 @@ where /// # Errors /// Returns `VerifierError::MerkleProofInvalid` if Merkle proof verification fails /// or the prover’s data does not match the commitment. - pub fn verify_stir_challenges( + pub fn verify_stir_challenges( &self, - proof: &crate::whir::proof::WhirProof, - challenger: &mut Challenger, + transcript: &mut Transcript, params: &RoundConfig, - commitment: &ParsedCommitment>, + root: Hash, folding_randomness: &MultilinearPoint, round_index: usize, ) -> Result, VerifierError> where - H: CryptographicHasher + Sync, - C: PseudoCompressionFunction<[F; DIGEST_ELEMS], 2> + Sync, + Hasher: CryptographicHasher + Sync, + Compress: PseudoCompressionFunction<[F; DIGEST_ELEMS], 2> + Sync, [F; DIGEST_ELEMS]: Serialize + for<'de> Deserialize<'de>, + Transcript: VerifierTranscript, { // CRITICAL: Verify the prover's proof-of-work before generating challenges. // @@ -283,48 +275,46 @@ where // By verifying that proof-of-work *now*, we confirm that the prover "locked in" their // commitment at a significant computational cost. This gives us confidence that the // challenges we generate are unpredictable and unbiased by a cheating prover. - let pow_witness = if round_index < self.n_rounds() { - proof - .get_pow_after_commitment(round_index) - .ok_or(VerifierError::InvalidRoundIndex { index: round_index })? - } else { - // Final round uses final_pow_witness - proof.final_pow_witness - }; - if params.pow_bits > 0 && !challenger.check_witness(params.pow_bits, pow_witness) { - return Err(VerifierError::InvalidPowWitness); - } + transcript.pow(params.pow_bits)?; // Transcript checkpoint after PoW if round_index < self.n_rounds() { - challenger.sample(); + >::sample(transcript); } - let stir_challenges_indexes = get_challenge_stir_queries::( + let stir_challenges_indexes = get_challenge_stir_queries::( + transcript, params.domain_size, params.folding_factor, params.num_queries, - challenger, )?; - let dimensions = vec![Dimensions { - height: params.domain_size >> params.folding_factor, - width: 1 << params.folding_factor, - }]; - let answers = self.verify_merkle_proof( - proof, - &commitment.root, - &stir_challenges_indexes, - &dimensions, - round_index, - )?; + let answers = if round_index == 0 { + let answers = self.verify_merkle_proof( + transcript, + root, + &stir_challenges_indexes, + params.domain_size, + params.folding_factor, + )?; + + answers + .into_iter() + .map(|answer| answer.into_iter().map(EF::from).collect::>()) + .collect::>() + } else { + self.verify_merkle_proof_ext( + transcript, + root, + &stir_challenges_indexes, + params.domain_size, + params.folding_factor, + )? + }; // Determine if this is the special first round where the univariate skip is applied. let is_skip_round = round_index == 0 - && matches!( - self.initial_phase_config, - InitialPhaseConfig::WithStatementUnivariateSkip - ) + && matches!(self.initial_phase, InitialPhase::WithStatementSkip) && self.folding_factor.at_round(0) >= K_SKIP_SUMCHECK; // Compute STIR Constraints @@ -406,93 +396,88 @@ where /// /// # Errors /// Returns `VerifierError::MerkleProofInvalid` if any Merkle proof fails verification. - pub fn verify_merkle_proof( + pub fn verify_merkle_proof( &self, - proof: &WhirProof, - root: &Hash, + transcript: &mut Transcript, + root: Hash, indices: &[usize], - dimensions: &[Dimensions], - round_index: usize, - ) -> Result>, VerifierError> + domain_size: usize, + folding_factor: usize, + ) -> Result>, VerifierError> where - H: CryptographicHasher + Sync, - C: PseudoCompressionFunction<[F; DIGEST_ELEMS], 2> + Sync, + Hasher: CryptographicHasher + Sync, + Compress: PseudoCompressionFunction<[F; DIGEST_ELEMS], 2> + Sync, [F; DIGEST_ELEMS]: Serialize + for<'de> Deserialize<'de>, + Transcript: VerifierTranscript, { + let width = 1 << folding_factor; + let height = domain_size >> folding_factor; + let depth = log2_strict_usize(height); let mmcs = MerkleTreeMmcs::new(self.merkle_hash.clone(), self.merkle_compress.clone()); - let extension_mmcs = ExtensionMmcs::new(mmcs.clone()); - - // Determine which queries to use from the proof structure - let queries = if round_index == self.n_rounds() { - &proof.final_queries - } else { - &proof - .rounds - .get(round_index) - .ok_or_else(|| VerifierError::MerkleProofInvalid { - position: 0, - reason: format!("Round {round_index} not found in proof"), - })? - .queries - }; - - let mut results = Vec::with_capacity(indices.len()); + indices + .iter() + .map(|&index| { + let answer: Vec = transcript.read_hint_many(width)?; + let proof: Vec<[F; DIGEST_ELEMS]> = transcript.read_hint_many(depth)?; + mmcs.verify_batch( + &root, + &[Dimensions { height, width }], + index, + BatchOpeningRef { + opened_values: from_ref(&answer), + opening_proof: &proof, + }, + ) + .map_err(|_| VerifierError::MerkleProofInvalid { + position: index, + reason: "Base field Merkle proof verification failed".to_string(), + })?; + Ok(answer) + }) + .collect::, VerifierError>>() + } - for (&index, query) in indices.iter().zip(queries.iter()) { - let values_ef = match query { - QueryOpening::Base { values, proof } => { - mmcs.verify_batch( - root, - dimensions, + // TODO: implement a generic verifier with MMC trait + pub fn verify_merkle_proof_ext( + &self, + transcript: &mut Transcript, + root: Hash, + indices: &[usize], + domain_size: usize, + folding_factor: usize, + ) -> Result>, VerifierError> + where + Hasher: CryptographicHasher + Sync, + Compress: PseudoCompressionFunction<[F; DIGEST_ELEMS], 2> + Sync, + [F; DIGEST_ELEMS]: Serialize + for<'de> Deserialize<'de>, + Transcript: VerifierTranscript, + { + let width = 1 << folding_factor; + let height = domain_size >> folding_factor; + let depth = log2_strict_usize(height); + let mmcs = MerkleTreeMmcs::new(self.merkle_hash.clone(), self.merkle_compress.clone()); + let extension_mmcs = ExtensionMmcs::new(mmcs); + indices + .iter() + .map(|&index| { + let answer: Vec = transcript.read_hint_many(width)?; + let proof: Vec<[F; DIGEST_ELEMS]> = transcript.read_hint_many(depth)?; + extension_mmcs + .verify_batch( + &root, + &[Dimensions { height, width }], index, BatchOpeningRef { - opened_values: from_ref(values), - opening_proof: proof, + opened_values: from_ref(&answer), + opening_proof: &proof, }, ) .map_err(|_| VerifierError::MerkleProofInvalid { position: index, reason: "Base field Merkle proof verification failed".to_string(), })?; - - // Convert F -> EF - values.iter().map(|&f| f.into()).collect() - } - QueryOpening::Extension { values, proof } => { - extension_mmcs - .verify_batch( - root, - dimensions, - index, - BatchOpeningRef { - opened_values: from_ref(values), - opening_proof: proof, - }, - ) - .map_err(|_| VerifierError::MerkleProofInvalid { - position: index, - reason: "Extension field Merkle proof verification failed".to_string(), - })?; - - values.clone() - } - }; - - results.push(values_ef); - } - - Ok(results) - } -} - -impl Deref for Verifier<'_, EF, F, H, C, Challenger> -where - F: Field, - EF: ExtensionField, -{ - type Target = WhirConfig; - - fn deref(&self) -> &Self::Target { - self.0 + Ok(answer) + }) + .collect::, VerifierError>>() } } diff --git a/src/whir/verifier/sumcheck.rs b/src/whir/verifier/sumcheck.rs index c247dfd9..b4a566f4 100644 --- a/src/whir/verifier/sumcheck.rs +++ b/src/whir/verifier/sumcheck.rs @@ -1,17 +1,14 @@ -use alloc::{format, string::ToString, vec, vec::Vec}; +use alloc::{string::ToString, vec, vec::Vec}; -use p3_challenger::{FieldChallenger, GrindingChallenger}; use p3_field::{ExtensionField, TwoAdicField}; use p3_interpolation::interpolate_subgroup; use p3_matrix::dense::RowMajorMatrix; use crate::{ constant::K_SKIP_SUMCHECK, + fiat_shamir::transcript::{Challenge, Pow, Reader}, poly::multilinear::MultilinearPoint, - whir::{ - proof::{InitialPhase, SumcheckData, SumcheckSkipData, WhirRoundProof}, - verifier::VerifierError, - }, + whir::{parameters::InitialPhase, verifier::VerifierError}, }; /// Verifies standard sumcheck rounds and extracts folding randomness from the transcript. @@ -43,65 +40,50 @@ use crate::{ /// /// - A `MultilinearPoint` of folding randomness values in reverse order. /// Common helper function to verify standard sumcheck rounds -fn verify_standard_sumcheck_rounds( - polynomial_evaluations: &[[EF; 2]], - pow_witnesses: &[F], - challenger: &mut Challenger, +pub(crate) fn verify_standard_sumcheck_rounds( + transcript: &mut Transcript, + number_of_rounds: usize, claimed_sum: &mut EF, pow_bits: usize, -) -> Result, VerifierError> +) -> Result, VerifierError> where F: TwoAdicField, EF: ExtensionField + TwoAdicField, - Challenger: FieldChallenger + GrindingChallenger, + Transcript: Reader + Challenge + Pow, { - let mut randomness = Vec::with_capacity(polynomial_evaluations.len()); - - for (i, &[c0, c2]) in polynomial_evaluations.iter().enumerate() { - // Derive h(1) from the sumcheck equation: h(0) + h(1) = claimed_sum - let h_1 = *claimed_sum - c0; - - // Observe only the sent polynomial evaluations (c0 and c2) - challenger.observe_algebra_slice(&[c0, c2]); - - // Verify PoW (only if pow_bits > 0) - if pow_bits > 0 && !challenger.check_witness(pow_bits, pow_witnesses[i]) { - return Err(VerifierError::InvalidPowWitness); - } - - // Sample challenge - let r: EF = challenger.sample_algebra_element(); - - // Update claimed sum for next round using direct quadratic formula: - // h(X) = c0 + c1*X + c2*X^2 where c1 = h(1) - c0 - c2 - // h(r) = c2*r^2 + c1*r + c0 = c2*r^2 + (h(1) - c0 - c2)*r + c0 - *claimed_sum = c2 * r.square() + (h_1 - c0 - c2) * r + c0; - randomness.push(r); - } - - Ok(randomness) + let vars = (0..number_of_rounds) + .map(|_| { + let c0 = transcript.read()?; + let c2 = transcript.read()?; + let h1 = *claimed_sum - c0; + transcript.pow(pow_bits)?; + let r = transcript.sample(); + // Update claimed sum for next round using direct quadratic formula: + // h(X) = c0 + c1*X + c2*X^2 where c1 = h(1) - c0 - c2 + // h(r) = c2*r^2 + c1*r + c0 = c2*r^2 + (h(1) - c0 - c2)*r + c0 + *claimed_sum = c2 * r.square() + (h1 - c0 - c2) * r + c0; + Ok(r) + }) + .collect::, VerifierError>>()?; + Ok(MultilinearPoint::new(vars)) } -pub(crate) fn verify_initial_sumcheck_rounds( - initial_phase: &InitialPhase, - challenger: &mut Challenger, +pub(crate) fn verify_initial_sumcheck_rounds( + transcript: &mut Transcript, + initial_phase: InitialPhase, claimed_sum: &mut EF, - rounds: usize, + num_rounds: usize, pow_bits: usize, ) -> Result, VerifierError> where F: TwoAdicField, EF: ExtensionField + TwoAdicField, - Challenger: FieldChallenger + GrindingChallenger, + Transcript: Reader + Reader + Challenge + Pow, { match initial_phase { - InitialPhase::WithStatementSkip(SumcheckSkipData { - evaluations: skip_evaluations, - pow: skip_pow, - sumcheck, - }) => { + InitialPhase::WithStatementSkip => { // Handle univariate skip optimization - if rounds < K_SKIP_SUMCHECK { + if num_rounds < K_SKIP_SUMCHECK { return Err(VerifierError::SumcheckFailed { round: 0, expected: "univariate skip optimization enabled".to_string(), @@ -109,16 +91,7 @@ where }); } - // Verify skip round evaluations size - let skip_size = 1 << (K_SKIP_SUMCHECK + 1); - if skip_evaluations.len() != skip_size { - return Err(VerifierError::SumcheckFailed { - round: 0, - expected: format!("{skip_size} evaluations"), - actual: format!("{} evaluations", skip_evaluations.len()), - }); - } - + let skip_evaluations: Vec = transcript.read_many(1 << (K_SKIP_SUMCHECK + 1))?; // Verify sum over subgroup H (every other element starting from 0) let actual_sum: EF = skip_evaluations.iter().step_by(2).copied().sum(); if actual_sum != *claimed_sum { @@ -129,161 +102,41 @@ where }); } - // Observe the skip evaluations for Fiat-Shamir - challenger.observe_algebra_slice(skip_evaluations); - - if pow_bits > 0 && !challenger.check_witness(pow_bits, *skip_pow) { - return Err(VerifierError::InvalidPowWitness); - } - + // Verify pow + transcript.pow(pow_bits)?; // Sample challenge for the skip round - let r_skip: EF = challenger.sample_algebra_element(); + let r_skip = transcript.sample(); // Interpolate to get the new claimed sum after skip folding - let mat = RowMajorMatrix::new(skip_evaluations.clone(), 1); + let mat = RowMajorMatrix::new(skip_evaluations, 1); *claimed_sum = interpolate_subgroup(&mat, r_skip)[0]; // Now process the remaining standard sumcheck rounds after the skip - let remaining_rounds = rounds - K_SKIP_SUMCHECK; + let remaining_rounds = num_rounds - K_SKIP_SUMCHECK; let mut randomness = vec![r_skip]; - let standard_randomness = verify_standard_sumcheck_rounds( - &sumcheck.polynomial_evaluations[0..remaining_rounds], - &sumcheck.pow_witnesses, - challenger, + randomness.extend(verify_standard_sumcheck_rounds( + transcript, + remaining_rounds, claimed_sum, pow_bits, - )?; - - randomness.extend(standard_randomness); + )?); Ok(MultilinearPoint::new(randomness)) } - InitialPhase::WithStatement { sumcheck } => { - // Standard initial sumcheck without skip - let randomness = verify_standard_sumcheck_rounds( - &sumcheck.polynomial_evaluations, - &sumcheck.pow_witnesses, - challenger, - claimed_sum, - pow_bits, - )?; - - Ok(MultilinearPoint::new(randomness)) - } - - InitialPhase::WithoutStatement { pow_witness } => { + InitialPhase::WithoutStatement => { // No sumcheck - just sample folding randomness directly - let randomness: Vec = (0..rounds) - .map(|_| challenger.sample_algebra_element()) - .collect(); - - // Check PoW - if pow_bits > 0 && !challenger.check_witness(pow_bits, *pow_witness) { - return Err(VerifierError::InvalidPowWitness); - } - + let randomness = transcript.sample_many(num_rounds); + transcript.pow(pow_bits)?; Ok(MultilinearPoint::new(randomness)) } - InitialPhase::WithStatementSvo { sumcheck } => { - // Fallback to WithStatement behavior (WithStatementSvo not yet implemented) - let randomness = verify_standard_sumcheck_rounds( - &sumcheck.polynomial_evaluations, - &sumcheck.pow_witnesses, - challenger, - claimed_sum, - pow_bits, - )?; - - Ok(MultilinearPoint::new(randomness)) - } + InitialPhase::WithStatementClassic | InitialPhase::WithStatementSvo => Ok( + verify_standard_sumcheck_rounds(transcript, num_rounds, claimed_sum, pow_bits)?, + ), } } -/// Verify sumcheck rounds from a WhirRoundProof. -/// -/// # Returns -/// -/// - A `MultilinearPoint` of folding randomness values in reverse order. -pub(crate) fn verify_sumcheck_rounds( - round_proof: &WhirRoundProof, - challenger: &mut Challenger, - claimed_sum: &mut EF, - rounds: usize, - pow_bits: usize, -) -> Result, VerifierError> -where - F: TwoAdicField, - EF: ExtensionField + TwoAdicField, - Challenger: FieldChallenger + GrindingChallenger, -{ - let sumcheck = &round_proof.sumcheck; - - if sumcheck.polynomial_evaluations.len() != rounds { - return Err(VerifierError::SumcheckFailed { - round: 0, - expected: format!("{rounds} rounds"), - actual: format!("{} rounds in proof", sumcheck.polynomial_evaluations.len()), - }); - } - - let randomness = verify_standard_sumcheck_rounds( - &sumcheck.polynomial_evaluations, - &sumcheck.pow_witnesses, - challenger, - claimed_sum, - pow_bits, - )?; - - Ok(MultilinearPoint::new(randomness)) -} - -/// Verify the final sumcheck rounds. -/// -/// # Returns -/// -/// - A `MultilinearPoint` of folding randomness values in reverse order. -pub(crate) fn verify_final_sumcheck_rounds( - final_sumcheck: Option<&SumcheckData>, - challenger: &mut Challenger, - claimed_sum: &mut EF, - rounds: usize, - pow_bits: usize, -) -> Result, VerifierError> -where - F: TwoAdicField, - EF: ExtensionField + TwoAdicField, - Challenger: FieldChallenger + GrindingChallenger, -{ - if rounds == 0 { - // No final sumcheck expected - return Ok(MultilinearPoint::new(Vec::new())); - } - - let sumcheck = final_sumcheck.ok_or_else(|| VerifierError::SumcheckFailed { - round: 0, - expected: format!("{rounds} final sumcheck rounds"), - actual: "None".to_string(), - })?; - - if sumcheck.polynomial_evaluations.len() != rounds { - return Err(VerifierError::SumcheckFailed { - round: 0, - expected: format!("{rounds} rounds"), - actual: format!("{} rounds in proof", sumcheck.polynomial_evaluations.len()), - }); - } - - let randomness = verify_standard_sumcheck_rounds( - &sumcheck.polynomial_evaluations, - &sumcheck.pow_witnesses, - challenger, - claimed_sum, - pow_bits, - )?; - Ok(MultilinearPoint::new(randomness)) -} #[cfg(test)] mod tests { use alloc::vec; @@ -296,19 +149,20 @@ mod tests { use super::*; use crate::{ - fiat_shamir::domain_separator::{DomainSeparator, SumcheckParams}, - parameters::{FoldingFactor, ProtocolParameters, errors::SecurityAssumption}, + fiat_shamir::{ + domain_separator::{DomainSeparator, SumcheckParams}, + transcript::{FiatShamirReader, FiatShamirWriter}, + }, poly::evals::EvaluationsList, sumcheck::sumcheck_single::SumcheckSingle, whir::{ constraints::{Constraint, statement::EqStatement}, - parameters::InitialPhaseConfig, - proof::{InitialPhase, WhirProof}, + parameters::InitialPhase, }, }; type F = BabyBear; - type EF4 = BinomialExtensionField; + type EF = BinomialExtensionField; type Perm = Poseidon2BabyBear<16>; type MyHash = PaddingFreeSponge; @@ -318,36 +172,6 @@ mod tests { // Digest size matches MyCompress output size (the 3rd parameter of TruncatedPermutation) const DIGEST_ELEMS: usize = 8; - /// Constructs a default WHIR configuration for testing - fn create_proof_from_test_protocol_params( - num_variables: usize, - folding_factor: FoldingFactor, - initial_phase_config: InitialPhaseConfig, - ) -> WhirProof { - // Create hash and compression functions for the Merkle tree - let mut rng = SmallRng::seed_from_u64(1); - let perm = Perm::new_from_rng_128(&mut rng); - - let merkle_hash = MyHash::new(perm.clone()); - let merkle_compress = MyCompress::new(perm); - - // Construct WHIR protocol parameters - let whir_params = ProtocolParameters { - initial_phase_config, - security_level: 32, - pow_bits: 0, - rs_domain_initial_reduction_factor: 1, - folding_factor, - merkle_hash, - merkle_compress, - soundness_type: SecurityAssumption::UniqueDecoding, - starting_log_inv_rate: 1, - }; - - // Combine protocol and polynomial parameters into a single config - WhirProof::from_protocol_parameters(&whir_params, num_variables) - } - #[test] #[allow(clippy::too_many_lines)] fn test_read_sumcheck_rounds_variants() { @@ -375,7 +199,7 @@ mod tests { ]); // Define the actual polynomial function over EF4 - let f = |x0: EF4, x1: EF4, x2: EF4| { + let f = |x0: EF, x1: EF, x2: EF| { x2 * e2 + x1 * e3 + x1 * x2 * e4 @@ -392,17 +216,17 @@ mod tests { // Create a constraint system with evaluations of f at various points let mut statement = EqStatement::initialize(n_vars); - let x_000 = MultilinearPoint::new(vec![EF4::ZERO, EF4::ZERO, EF4::ZERO]); - let x_100 = MultilinearPoint::new(vec![EF4::ONE, EF4::ZERO, EF4::ZERO]); - let x_110 = MultilinearPoint::new(vec![EF4::ONE, EF4::ONE, EF4::ZERO]); - let x_111 = MultilinearPoint::new(vec![EF4::ONE, EF4::ONE, EF4::ONE]); - let x_011 = MultilinearPoint::new(vec![EF4::ZERO, EF4::ONE, EF4::ONE]); + let x_000 = MultilinearPoint::new(vec![EF::ZERO, EF::ZERO, EF::ZERO]); + let x_100 = MultilinearPoint::new(vec![EF::ONE, EF::ZERO, EF::ZERO]); + let x_110 = MultilinearPoint::new(vec![EF::ONE, EF::ONE, EF::ZERO]); + let x_111 = MultilinearPoint::new(vec![EF::ONE, EF::ONE, EF::ONE]); + let x_011 = MultilinearPoint::new(vec![EF::ZERO, EF::ONE, EF::ONE]); - let f_000 = f(EF4::ZERO, EF4::ZERO, EF4::ZERO); - let f_100 = f(EF4::ONE, EF4::ZERO, EF4::ZERO); - let f_110 = f(EF4::ONE, EF4::ONE, EF4::ZERO); - let f_111 = f(EF4::ONE, EF4::ONE, EF4::ONE); - let f_011 = f(EF4::ZERO, EF4::ONE, EF4::ONE); + let f_000 = f(EF::ZERO, EF::ZERO, EF::ZERO); + let f_100 = f(EF::ONE, EF::ZERO, EF::ZERO); + let f_110 = f(EF::ONE, EF::ONE, EF::ZERO); + let f_111 = f(EF::ONE, EF::ONE, EF::ONE); + let f_011 = f(EF::ZERO, EF::ONE, EF::ONE); statement.add_evaluated_constraint(x_000, f_000); statement.add_evaluated_constraint(x_100, f_100); @@ -415,7 +239,7 @@ mod tests { // Set up domain separator // - Add sumcheck - let mut domsep: DomainSeparator = DomainSeparator::new(vec![]); + let mut domsep: DomainSeparator = DomainSeparator::new(vec![]); domsep.add_sumcheck(&SumcheckParams { rounds: folding_factor, pow_bits, @@ -423,93 +247,63 @@ mod tests { }); let mut rng = SmallRng::seed_from_u64(1); - let challenger = MyChallenger::new(Perm::new_from_rng_128(&mut rng)); - let mut prover_challenger = challenger.clone(); - - let constraint = Constraint::new_eq_only(EF4::ONE, statement.clone()); - - // Initialize proof and challenger - let mut proof = create_proof_from_test_protocol_params( - n_vars, - FoldingFactor::Constant(folding_factor), - InitialPhaseConfig::WithStatementClassic, - ); - domsep.observe_domain_separator(&mut prover_challenger); - - // Extract sumcheck data from the initial phase - let InitialPhase::WithStatement { ref mut sumcheck } = proof.initial_phase else { - panic!("Expected WithStatement variant"); - }; + let mut challenger = MyChallenger::new(Perm::new_from_rng_128(&mut rng)); + domsep.observe_domain_separator(&mut challenger); + let mut transcript = FiatShamirWriter::init(challenger.clone()); + let constraint = Constraint::new_eq_only(EF::ONE, statement.clone()); // Instantiate the prover with base field coefficients - let (_, _) = SumcheckSingle::::from_base_evals( + let (_, _) = SumcheckSingle::::from_base_evals( + &mut transcript, &evals, - sumcheck, - &mut prover_challenger, folding_factor, pow_bits, &constraint, - ); - - // Reconstruct verifier state to simulate the rounds - let mut verifier_challenger = challenger; - domsep.observe_domain_separator(&mut verifier_challenger); + ) + .unwrap(); - // Save a fresh copy for verify_initial_sumcheck_rounds - let mut verifier_challenger_for_verify = verifier_challenger.clone(); + let proof = transcript.finalize(); + let mut transcript = FiatShamirReader::init(proof.clone(), challenger.clone()); let mut t = EvaluationsList::zero(statement.num_variables()); - let mut expected_initial_sum = EF4::ZERO; - statement.combine_hypercube::(&mut t, &mut expected_initial_sum, EF4::ONE); + let mut expected_initial_sum = EF::ZERO; + statement.combine_hypercube::(&mut t, &mut expected_initial_sum, EF::ONE); // Start with the claimed sum before folding let mut current_sum = expected_initial_sum; let mut expected = Vec::with_capacity(folding_factor); - // Extract and verify each sumcheck round - let InitialPhase::WithStatement { - sumcheck: initial_sumcheck_data, - } = &proof.initial_phase - else { - panic!("Expected WithStatement variant") - }; - // First round: read c_0 = h(0) and c_2 (quadratic coefficient) - let [c_0, c_2] = initial_sumcheck_data.polynomial_evaluations[0]; + let c_0: EF = transcript.read().unwrap(); + let c_2: EF = transcript.read().unwrap(); let h_1 = current_sum - c_0; - - // Observe polynomial evaluations (must match what verify_initial_sumcheck_rounds does) - verifier_challenger.observe_algebra_slice(&[c_0, c_2]); + transcript.pow(pow_bits).unwrap(); // Sample random challenge r_i ∈ EF4 and evaluate h_i(r_i) - let r: EF4 = verifier_challenger.sample_algebra_element(); + let r: EF = transcript.sample(); // h(r) = c_2 * r^2 + (h(1) - c_0 - c_2) * r + c_0 current_sum = c_2 * r.square() + (h_1 - c_0 - c_2) * r + c_0; expected.push(r); - for i in 0..folding_factor - 1 { + for _ in 0..folding_factor - 1 { // Read c_0 = h(0) and c_2 (quadratic coefficient), derive h(1) = claimed_sum - c_0 - let [c_0, c_2] = initial_sumcheck_data.polynomial_evaluations[i + 1]; + let c_0: EF = transcript.read().unwrap(); + let c_2: EF = transcript.read().unwrap(); let h_1 = current_sum - c_0; - - // Observe polynomial evaluations - verifier_challenger.observe_algebra_slice(&[c_0, c_2]); + transcript.pow(pow_bits).unwrap(); // Sample random challenge r - let r: EF4 = verifier_challenger.sample_algebra_element(); + let r: EF = transcript.sample(); // h(r) = c_2 * r^2 + (h(1) - c_0 - c_2) * r + c_0 current_sum = c_2 * r.square() + (h_1 - c_0 - c_2) * r + c_0; - if pow_bits > 0 { - // verifier_state.challenge_pow::(pow_bits).unwrap(); - } - expected.push(r); } + let mut transcript = FiatShamirReader::init(proof, challenger.clone()); let randomness = verify_initial_sumcheck_rounds( - &proof.initial_phase, - &mut verifier_challenger_for_verify, + &mut transcript, + InitialPhase::WithStatementClassic, &mut expected_initial_sum, folding_factor, pow_bits, @@ -550,13 +344,7 @@ mod tests { let mut statement = EqStatement::initialize(NUM_VARS); for i in 0..5 { let bool_point: Vec<_> = (0..NUM_VARS) - .map(|j| { - if (i >> j) & 1 == 1 { - EF4::ONE - } else { - EF4::ZERO - } - }) + .map(|j| if (i >> j) & 1 == 1 { EF::ONE } else { EF::ZERO }) .collect(); let ml_point = MultilinearPoint::new(bool_point.clone()); let expected_val = evals.evaluate_hypercube_base(&ml_point); @@ -567,7 +355,7 @@ mod tests { let pow_bits = 0; // Set up domain separator - let mut domsep: DomainSeparator = DomainSeparator::new(vec![]); + let mut domsep: DomainSeparator = DomainSeparator::new(vec![]); domsep.add_sumcheck(&SumcheckParams { rounds: folding_factor, pow_bits, @@ -575,67 +363,45 @@ mod tests { }); let mut rng = SmallRng::seed_from_u64(1); - let challenger = MyChallenger::new(Perm::new_from_rng_128(&mut rng)); - let mut prover_challenger = challenger.clone(); + let mut challenger = MyChallenger::new(Perm::new_from_rng_128(&mut rng)); + domsep.observe_domain_separator(&mut challenger); - let constraint = Constraint::new_eq_only(EF4::ONE, statement.clone()); - - // Initialize proof and challenger - let mut proof = create_proof_from_test_protocol_params( - NUM_VARS, - FoldingFactor::Constant(folding_factor), - InitialPhaseConfig::WithStatementUnivariateSkip, - ); - domsep.observe_domain_separator(&mut prover_challenger); - - // Extract skip data from the initial phase - let InitialPhase::WithStatementSkip(ref mut skip_data) = proof.initial_phase else { - panic!("Expected WithStatementSkip variant"); - }; + let mut transcript = FiatShamirWriter::init(challenger.clone()); + let constraint = Constraint::new_eq_only(EF::ONE, statement.clone()); // Instantiate the prover with base field coefficients and univariate skip - let (_, _) = SumcheckSingle::::with_skip( + let (_, _) = SumcheckSingle::::with_skip( &evals, - skip_data, - &mut prover_challenger, + &mut transcript, folding_factor, pow_bits, K_SKIP, &constraint, - ); + ) + .unwrap(); // Reconstruct verifier state to simulate the rounds - let mut verifier_challenger = challenger; - domsep.observe_domain_separator(&mut verifier_challenger); - - // Save a fresh copy for verify_initial_sumcheck_rounds - let mut verifier_challenger_for_verify = verifier_challenger.clone(); + let proof = transcript.finalize(); + let mut transcript = FiatShamirReader::init(proof.clone(), challenger.clone()); let mut t = EvaluationsList::zero(statement.num_variables()); - let mut expected_initial_sum = EF4::ZERO; - statement.combine_hypercube::(&mut t, &mut expected_initial_sum, EF4::ONE); + let mut expected_initial_sum = EF::ZERO; + statement.combine_hypercube::(&mut t, &mut expected_initial_sum, EF::ONE); // Start with the claimed sum before folding let mut current_sum = expected_initial_sum; let mut expected = Vec::new(); - // Extract skip data from the proof for verification replay - let InitialPhase::WithStatementSkip(skip_data) = &proof.initial_phase else { - panic!("Expected WithStatementSkip variant"); - }; - // First skipped round (wide DFT LDE) - let skip_evaluations = &skip_data.evaluations; + let skip_evaluations = &transcript.read_many(1 << (K_SKIP + 1)).unwrap(); // Verify sum over subgroup H (every other element starting from 0) - let actual_sum: EF4 = skip_evaluations.iter().step_by(2).copied().sum(); + let actual_sum: EF = skip_evaluations.iter().step_by(2).copied().sum(); assert_eq!(actual_sum, current_sum, "Skip round sum mismatch"); - // Observe the skip evaluations for Fiat-Shamir - verifier_challenger.observe_algebra_slice(skip_evaluations); - + transcript.pow(pow_bits).unwrap(); // Sample challenge for the skip round - let r_skip: EF4 = verifier_challenger.sample_algebra_element(); + let r_skip: EF = transcript.sample(); expected.push(r_skip); // Interpolate to get the new claimed sum after skip folding @@ -644,25 +410,25 @@ mod tests { // Remaining quadratic rounds after the skip let remaining_rounds = folding_factor - K_SKIP; - for i in 0..remaining_rounds { + for _ in 0..remaining_rounds { // Read c_0 = h(0) and c_2 (quadratic coefficient), derive h(1) = claimed_sum - c_0 - let [c_0, c_2] = skip_data.sumcheck.polynomial_evaluations[i]; + let c_0: EF = transcript.read().unwrap(); + let c_2: EF = transcript.read().unwrap(); let h_1 = current_sum - c_0; - - // Observe polynomial evaluations - verifier_challenger.observe_algebra_slice(&[c_0, c_2]); + transcript.pow(pow_bits).unwrap(); // Sample random challenge r - let r: EF4 = verifier_challenger.sample_algebra_element(); - // h(r) = c_2 * r^2 + (h(1) - c_0 - c_2) * r + c_0 + let r: EF = transcript.sample(); + // h(r) = c_2 * r^2 + (h_1 - c_0 - c_2) * r + c_0 current_sum = c_2 * r.square() + (h_1 - c_0 - c_2) * r + c_0; expected.push(r); } + let mut transcript = FiatShamirReader::init(proof, challenger.clone()); let randomness = verify_initial_sumcheck_rounds( - &proof.initial_phase, - &mut verifier_challenger_for_verify, + &mut transcript, + InitialPhase::WithStatementSkip, &mut expected_initial_sum, folding_factor, pow_bits, @@ -696,7 +462,7 @@ mod tests { // Create a constraint system with evaluations of f at a point let mut statement = EqStatement::initialize(NUM_VARS); let constraint_point: Vec<_> = (0..NUM_VARS) - .map(|j| if j % 2 == 0 { EF4::ONE } else { EF4::ZERO }) + .map(|j| if j % 2 == 0 { EF::ONE } else { EF::ZERO }) .collect(); let ml_point = MultilinearPoint::new(constraint_point); let expected_val = evals.evaluate_hypercube_base(&ml_point); @@ -706,7 +472,7 @@ mod tests { let pow_bits = 0; // Set up domain separator - let mut domsep: DomainSeparator = DomainSeparator::new(vec![]); + let mut domsep: DomainSeparator = DomainSeparator::new(vec![]); domsep.add_sumcheck(&SumcheckParams { rounds: folding_factor, pow_bits, @@ -714,40 +480,24 @@ mod tests { }); let mut rng = SmallRng::seed_from_u64(1); - let challenger = MyChallenger::new(Perm::new_from_rng_128(&mut rng)); - let mut prover_challenger = challenger.clone(); + let mut challenger = MyChallenger::new(Perm::new_from_rng_128(&mut rng)); + domsep.observe_domain_separator(&mut challenger); + let mut transcript = FiatShamirWriter::init(challenger.clone()); - let constraint = Constraint::new_eq_only(EF4::ONE, statement.clone()); - - // Initialize proof and challenger - let mut proof = create_proof_from_test_protocol_params( - NUM_VARS, - FoldingFactor::Constant(folding_factor), - InitialPhaseConfig::WithStatementSvo, - ); - domsep.observe_domain_separator(&mut prover_challenger); - - // Extract sumcheck data from the initial phase - let InitialPhase::WithStatementSvo { ref mut sumcheck } = proof.initial_phase else { - panic!("Expected WithStatementSvo variant"); - }; + let constraint = Constraint::new_eq_only(EF::ONE, statement.clone()); // Instantiate the prover with base field coefficients using SVO - let (_, _) = SumcheckSingle::::from_base_evals( + let (_, _) = SumcheckSingle::::from_base_evals( + &mut transcript, &evals, - sumcheck, - &mut prover_challenger, folding_factor, pow_bits, &constraint, - ); - - // Reconstruct verifier state to simulate the rounds - let mut verifier_challenger = challenger; - domsep.observe_domain_separator(&mut verifier_challenger); + ) + .unwrap(); - // Save a fresh copy for verify_initial_sumcheck_rounds - let mut verifier_challenger_for_verify = verifier_challenger.clone(); + let proof = transcript.finalize(); + let mut transcript = FiatShamirReader::init(proof.clone(), challenger.clone()); let (_, mut expected_initial_sum) = constraint.combine_new(); // Start with the claimed sum before folding @@ -755,33 +505,25 @@ mod tests { let mut expected = Vec::with_capacity(folding_factor); - // Extract sumcheck data from the proof for verification replay - let InitialPhase::WithStatementSvo { - sumcheck: svo_sumcheck, - } = &proof.initial_phase - else { - panic!("Expected WithStatementSvo variant") - }; - - for i in 0..folding_factor { + for _ in 0..folding_factor { // Read c_0 = h(0) and c_2 (quadratic coefficient), derive h(1) = claimed_sum - c_0 - let [c_0, c_2] = svo_sumcheck.polynomial_evaluations[i]; + let c_0: EF = transcript.read().unwrap(); + let c_2: EF = transcript.read().unwrap(); let h_1 = current_sum - c_0; - - // Observe polynomial evaluations - verifier_challenger.observe_algebra_slice(&[c_0, c_2]); + transcript.pow(pow_bits).unwrap(); // Sample random challenge r - let r: EF4 = verifier_challenger.sample_algebra_element(); + let r: EF = transcript.sample(); // h(r) = c_2 * r^2 + (h(1) - c_0 - c_2) * r + c_0 current_sum = c_2 * r.square() + (h_1 - c_0 - c_2) * r + c_0; expected.push(r); } + let mut transcript = FiatShamirReader::init(proof, challenger.clone()); let randomness = verify_initial_sumcheck_rounds( - &proof.initial_phase, - &mut verifier_challenger_for_verify, + &mut transcript, + InitialPhase::WithStatementSvo, &mut expected_initial_sum, folding_factor, pow_bits,