Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 1 addition & 2 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,6 @@ tracing-subscriber = { version = "0.3.17", default-features = false, features =
"alloc",
"env-filter",
], optional = true }
bincode = { version = "1.3", optional = true }

[dev-dependencies]
criterion = "0.8"
Expand All @@ -86,7 +85,7 @@ proptest = { version = "1.0", default-features = true }
default = ["parallel"]
parallel = ["dep:rayon", "p3-maybe-rayon/parallel", "p3-util/parallel"]
rayon = ["dep:rayon"]
cli = ["dep:clap", "dep:tracing-subscriber", "dep:tracing-forest", "dep:bincode", "rand/default"]
cli = ["dep:clap", "dep:tracing-subscriber", "dep:tracing-forest", "rand/default"]

[[bin]]
name = "main"
Expand Down
52 changes: 30 additions & 22 deletions benches/stir_queries.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,8 @@ use p3_challenger::DuplexChallenger;
use p3_field::extension::BinomialExtensionField;
use rand::{SeedableRng, rngs::SmallRng};
use whir_p3::{
fiat_shamir::domain_separator::DomainSeparator, whir::utils::get_challenge_stir_queries,
fiat_shamir::{domain_separator::DomainSeparator, transcript::FiatShamirWriter},
whir::utils::get_challenge_stir_queries,
};

type F = BabyBear;
Expand All @@ -29,51 +30,55 @@ fn bench_stir_queries(c: &mut Criterion) {
// Benchmarks from the main file use case
group.bench_function("benchmark main round 1", |b| {
b.iter(|| {
let mut challenger = setup_challenger();
get_challenge_stir_queries::<_, F, EF>(
let challenger = setup_challenger();
let mut transcript = FiatShamirWriter::<F, _>::init(challenger);
get_challenge_stir_queries::<F, _>(
&mut transcript,
black_box(67_108_864),
black_box(5),
black_box(80),
black_box(&mut challenger),
)
.unwrap()
});
});

group.bench_function("benchmark main round 2", |b| {
b.iter(|| {
let mut challenger = setup_challenger();
get_challenge_stir_queries::<_, F, EF>(
let challenger = setup_challenger();
let mut transcript = FiatShamirWriter::<F, _>::init(challenger);
get_challenge_stir_queries::<F, _>(
&mut transcript,
black_box(8_388_608),
black_box(5),
black_box(26),
black_box(&mut challenger),
)
.unwrap()
});
});

group.bench_function("benchmark main round 3", |b| {
b.iter(|| {
let mut challenger = setup_challenger();
get_challenge_stir_queries::<_, F, EF>(
let challenger = setup_challenger();
let mut transcript = FiatShamirWriter::<F, _>::init(challenger);
get_challenge_stir_queries::<F, _>(
&mut transcript,
black_box(4_194_304),
black_box(5),
black_box(11),
black_box(&mut challenger),
)
.unwrap()
});
});

group.bench_function("benchmark main round 4", |b| {
b.iter(|| {
let mut challenger = setup_challenger();
get_challenge_stir_queries::<_, F, EF>(
let challenger = setup_challenger();
let mut transcript = FiatShamirWriter::<F, _>::init(challenger);
get_challenge_stir_queries::<F, _>(
&mut transcript,
black_box(2_097_152),
black_box(5),
black_box(7),
black_box(&mut challenger),
)
.unwrap()
});
Expand All @@ -82,12 +87,13 @@ fn bench_stir_queries(c: &mut Criterion) {
// Large case: Many queries, large domain
group.bench_function("large_64_queries_64k_domain", |b| {
b.iter(|| {
let mut challenger = setup_challenger();
get_challenge_stir_queries::<_, F, EF>(
let challenger = setup_challenger();
let mut transcript = FiatShamirWriter::<F, _>::init(challenger);
get_challenge_stir_queries::<F, _>(
&mut transcript,
black_box(65536),
black_box(6),
black_box(64),
black_box(&mut challenger),
)
.unwrap()
});
Expand All @@ -96,12 +102,13 @@ fn bench_stir_queries(c: &mut Criterion) {
// Very large case: Extreme scenario
group.bench_function("very_large_256_queries_1m_domain", |b| {
b.iter(|| {
let mut challenger = setup_challenger();
get_challenge_stir_queries::<_, F, EF>(
let challenger = setup_challenger();
let mut transcript = FiatShamirWriter::<F, _>::init(challenger);
get_challenge_stir_queries::<F, _>(
&mut transcript,
black_box(1_048_576),
black_box(10),
black_box(256),
black_box(&mut challenger),
)
.unwrap()
});
Expand All @@ -110,12 +117,13 @@ fn bench_stir_queries(c: &mut Criterion) {
// Edge case: Many queries, tiny bits per query
group.bench_function("edge_100_queries_tiny_bits", |b| {
b.iter(|| {
let mut challenger = setup_challenger();
get_challenge_stir_queries::<_, F, EF>(
let challenger = setup_challenger();
let mut transcript = FiatShamirWriter::<F, _>::init(challenger);
get_challenge_stir_queries::<F, _>(
&mut transcript,
black_box(64), // domain_size
black_box(2), // folding_factor (domain becomes 16, needs 4 bits)
black_box(100), // num_queries (lots of queries, few bits each)
black_box(&mut challenger),
)
.unwrap()
});
Expand Down
42 changes: 15 additions & 27 deletions benches/sumcheck.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,14 +5,13 @@ use p3_field::extension::BinomialExtensionField;
use p3_symmetric::{PaddingFreeSponge, TruncatedPermutation};
use rand::{Rng, SeedableRng, rngs::SmallRng};
use whir_p3::{
fiat_shamir::domain_separator::DomainSeparator,
fiat_shamir::{domain_separator::DomainSeparator, transcript::FiatShamirWriter},
parameters::{FoldingFactor, ProtocolParameters, errors::SecurityAssumption},
poly::{evals::EvaluationsList, multilinear::MultilinearPoint},
sumcheck::sumcheck_single::SumcheckSingle,
whir::{
constraints::{Constraint, statement::EqStatement},
parameters::InitialPhaseConfig,
proof::{InitialPhase, SumcheckData, WhirProof},
parameters::InitialPhase,
},
};

Expand All @@ -31,7 +30,7 @@ fn create_test_protocol_params(
let perm = Perm::new_from_rng_128(&mut rng);

ProtocolParameters {
initial_phase_config: InitialPhaseConfig::WithStatementClassic,
initial_phase: InitialPhase::WithStatementClassic,
security_level: 32,
pow_bits: 0,
rs_domain_initial_reduction_factor: 1,
Expand Down Expand Up @@ -90,9 +89,6 @@ fn bench_sumcheck_prover(c: &mut Criterion) {
// Benchmark for the classic, round-by-round sumcheck
let classic_folding_schedule = [*num_vars / 2, num_vars - (*num_vars / 2)];

// Create parameters with a dummy folding factor (we'll use manual schedule)
let params = create_test_protocol_params(FoldingFactor::Constant(2));

// Setup domain separator
let domsep: DomainSeparator<EF, F> = DomainSeparator::new(vec![]);

Expand All @@ -101,41 +97,33 @@ fn bench_sumcheck_prover(c: &mut Criterion) {
// Setup fresh challenger for each iteration
let mut challenger = setup_challenger();
domsep.observe_domain_separator(&mut challenger);

// Initialize proof
let mut proof = WhirProof::<F, EF, 8>::from_protocol_parameters(&params, *num_vars);
let mut transcript = FiatShamirWriter::init(challenger.clone());

// Create constraint using challenger directly
let statement = generate_statement(&mut challenger, *num_vars, poly, 3);
let alpha: EF = challenger.sample_algebra_element();
let constraint = Constraint::new_eq_only(alpha, statement);

// Extract sumcheck data from the initial phase
let InitialPhase::WithStatement { ref mut sumcheck } = proof.initial_phase else {
panic!("Expected WithStatement variant");
};

// First round - fold first half of variables
let (mut sumcheck_prover, _) = SumcheckSingle::from_base_evals(
&mut transcript,
poly,
sumcheck,
&mut challenger,
classic_folding_schedule[0],
0,
&constraint,
);
)
.unwrap();

// Second round - fold remaining variables
if classic_folding_schedule.len() > 1 && classic_folding_schedule[1] > 0 {
let mut sumcheck_data: SumcheckData<EF, F> = SumcheckData::default();
sumcheck_prover.compute_sumcheck_polynomials(
&mut sumcheck_data,
&mut challenger,
classic_folding_schedule[1],
0,
None,
);
proof.set_final_sumcheck_data(sumcheck_data);
sumcheck_prover
.compute_sumcheck_polynomials(
&mut transcript,
classic_folding_schedule[1],
0,
None,
)
.unwrap();
}
});
});
Expand Down
42 changes: 10 additions & 32 deletions benches/sumcheck_svo.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,13 +11,9 @@ use whir::{
};
use whir_p3::{
self as whir,
fiat_shamir::domain_separator::DomainSeparator,
fiat_shamir::{domain_separator::DomainSeparator, transcript::FiatShamirWriter},
parameters::{FoldingFactor, ProtocolParameters, errors::SecurityAssumption},
whir::{
constraints::Constraint,
parameters::InitialPhaseConfig,
proof::{InitialPhase, SumcheckData, WhirProof},
},
whir::{constraints::Constraint, parameters::InitialPhase},
};

type F = KoalaBear;
Expand All @@ -39,7 +35,7 @@ fn create_test_protocol_params_classic(
let perm = Poseidon16::new_from_rng_128(&mut rng);

ProtocolParameters {
initial_phase_config: InitialPhaseConfig::WithStatementClassic,
initial_phase: InitialPhase::WithStatementClassic,
security_level: 32,
pow_bits: 0,
rs_domain_initial_reduction_factor: 1,
Expand Down Expand Up @@ -95,37 +91,21 @@ fn bench_sumcheck_prover_svo(c: &mut Criterion) {
let poly = generate_poly(*num_vars);

// Classic benchmark - folding all variables in one round
let params_classic =
create_test_protocol_params_classic(FoldingFactor::Constant(*num_vars));
group.bench_with_input(BenchmarkId::new("Classic", *num_vars), &poly, |b, poly| {
b.iter(|| {
// Setup fresh challenger for each iteration
let mut challenger = setup_challenger();
domsep.observe_domain_separator(&mut challenger);

// Initialize proof
let mut proof =
WhirProof::<F, EF, 8>::from_protocol_parameters(&params_classic, *num_vars);
let mut transcript = FiatShamirWriter::init(challenger.clone());

// Create constraint using challenger directly
let statement = generate_statement(&mut challenger, *num_vars, poly, 3);
let alpha: EF = challenger.sample_algebra_element();
let constraint = Constraint::new_eq_only(alpha, statement);

// Extract sumcheck data from the initial phase
let InitialPhase::WithStatement { ref mut sumcheck } = proof.initial_phase else {
panic!("Expected WithStatement variant");
};

// Fold all variables in one round
SumcheckSingle::from_base_evals(
poly,
sumcheck,
&mut challenger,
*num_vars,
0,
&constraint,
);
SumcheckSingle::from_base_evals(&mut transcript, poly, *num_vars, 0, &constraint)
.unwrap();
});
});

Expand All @@ -135,24 +115,22 @@ fn bench_sumcheck_prover_svo(c: &mut Criterion) {
// Setup fresh challenger for each iteration
let mut challenger = setup_challenger();
domsep.observe_domain_separator(&mut challenger);
let mut transcript = FiatShamirWriter::init(challenger.clone());

// Create constraint using challenger directly
let statement = generate_statement(&mut challenger, *num_vars, poly, 1);
let alpha: EF = challenger.sample_algebra_element();
let constraint = Constraint::new_eq_only(alpha, statement);

// Create sumcheck data
let mut sumcheck_data: SumcheckData<EF, F> = SumcheckData::default();

// Fold all variables using SVO optimization
SumcheckSingle::from_base_evals_svo(
&mut transcript,
poly,
&mut sumcheck_data,
&mut challenger,
*num_vars,
0,
&constraint,
);
)
.unwrap();
});
});
}
Expand Down
Loading