diff --git a/Cargo.toml b/Cargo.toml index 10a965b8..9a3b63b0 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -103,3 +103,7 @@ harness = false [[bench]] name = "sumcheck" harness = false + +[[bench]] +name = "sumcheck_svo" +harness = false diff --git a/benches/sumcheck_svo.rs b/benches/sumcheck_svo.rs new file mode 100644 index 00000000..695a4fee --- /dev/null +++ b/benches/sumcheck_svo.rs @@ -0,0 +1,104 @@ +use criterion::{BenchmarkId, Criterion, criterion_group, criterion_main}; +use p3_challenger::DuplexChallenger; +use p3_field::extension::BinomialExtensionField; +use p3_koala_bear::{KoalaBear, Poseidon2KoalaBear}; +use rand::{Rng, SeedableRng, rngs::StdRng}; +use std::hint::black_box; +use std::time::Duration; +use whir::{ + fiat_shamir::{domain_separator::DomainSeparator, prover::ProverState}, + poly::{evals::EvaluationsList, multilinear::MultilinearPoint}, + sumcheck::sumcheck_single::SumcheckSingle, + whir::statement::{Statement, point::ConstraintPoint}, +}; +use whir_p3 as whir; +type F = KoalaBear; +type EF = BinomialExtensionField; +type Poseidon16 = Poseidon2KoalaBear<16>; +type MyChallenger = DuplexChallenger; + +const NUM_CONSTRAINTS: usize = 1; +const FOLDING_FACTOR: usize = 5; +const POW_BITS: usize = 0; + +fn setup_prover() -> ProverState { + let mut rng = StdRng::seed_from_u64(0); + let poseidon = Poseidon16::new_from_rng_128(&mut rng); + let challenger = MyChallenger::new(poseidon); + DomainSeparator::new(vec![]).to_prover_state(challenger) +} + +fn generate_poly(num_vars: usize) -> EvaluationsList { + let mut rng = StdRng::seed_from_u64(1 + num_vars as u64); + EvaluationsList::new((0..1 << num_vars).map(|_| rng.random()).collect()) +} + +fn generate_statement( + num_vars: usize, + poly: &EvaluationsList, + num_constraints: usize, +) -> Statement { + let mut rng = StdRng::seed_from_u64(42 + num_vars as u64); + let mut statement = Statement::new(num_vars); + for _ in 0..num_constraints { + let point = MultilinearPoint::rand(&mut rng, num_vars); + let eval = poly.evaluate(&point); + statement.add_constraint(ConstraintPoint::new(point), eval); + } + statement +} + +fn bench_sumcheck_prover_svo(c: &mut Criterion) { + let mut group = c.benchmark_group("SumcheckProver"); + group.sample_size(100); + group.warm_up_time(Duration::from_secs(10)); + for &num_vars in &[16, 18, 20] { + let poly = generate_poly(num_vars); + let statement = generate_statement(num_vars, &poly, NUM_CONSTRAINTS); + + group.bench_with_input( + BenchmarkId::new("Classic", num_vars), + &num_vars, + |b, &_num_vars| { + b.iter(|| { + let mut prover = setup_prover(); + let combination_randomness: EF = prover.sample(); + let result = SumcheckSingle::from_base_evals( + &poly, + &statement, + combination_randomness, + &mut prover, + FOLDING_FACTOR, + POW_BITS, + ); + black_box(result); + }); + }, + ); + + group.bench_with_input( + BenchmarkId::new("SVO", num_vars), + &num_vars, + |b, &_num_vars| { + b.iter(|| { + let mut prover = setup_prover(); + let combination_randomness: EF = prover.sample(); + let result = SumcheckSingle::from_base_evals_svo( + &poly, + &statement, + combination_randomness, + &mut prover, + FOLDING_FACTOR, + POW_BITS, + ); + black_box(result); + }); + }, + ); + } + + group.finish(); +} + +criterion_group!(benches, bench_sumcheck_prover_svo); +criterion_main!(benches); diff --git a/src/sumcheck/small_value_utils.rs b/src/sumcheck/small_value_utils.rs index f2392a52..f5faa57b 100644 --- a/src/sumcheck/small_value_utils.rs +++ b/src/sumcheck/small_value_utils.rs @@ -1,4 +1,5 @@ use p3_field::Field; +use std::ops::Add; pub const NUM_OF_ROUNDS: usize = 3; @@ -49,6 +50,21 @@ where &self.accumulators[round] } } + +impl Add for Accumulators { + type Output = Self; + + fn add(mut self, other: Self) -> Self { + for i in 0..NUM_OF_ROUNDS { + // NUM_OF_ROUNDS is 3 + for j in 0..self.accumulators[i].len() { + self.accumulators[i][j] += other.accumulators[i][j]; + } + } + self + } +} + // For round i, RoundAccumulators has all the accumulators of the form A_i(u, v). #[derive(Debug, Clone, Eq, PartialEq)] pub struct RoundAccumlators { @@ -129,8 +145,7 @@ pub fn idx4_v2(index_beta: usize) -> [Option; 3] { // Implement Procedure 6 (Page 34). // Fijado x'' en {0, 1}^{l-3}, dadas las evaluaciones del multilineal q(x1, x2, x3) = p(x1, x2, x3, x'') en el booleano devuelve las // evaluaciones de q en beta para todo beta in {0, 1, inf}^3. -pub fn compute_p_beta(current_evals: Vec) -> Vec { - let mut next_evals = vec![F::ZERO; 27]; +pub fn compute_p_beta(current_evals: &[F; 8], next_evals: &mut [F; 27]) { next_evals[0] = current_evals[0]; // 000 next_evals[1] = current_evals[1]; // 001 @@ -165,6 +180,4 @@ pub fn compute_p_beta(current_evals: Vec) -> Vec { next_evals[20] = next_evals[19] - next_evals[18]; // 202 next_evals[23] = next_evals[22] - next_evals[21]; // 212 next_evals[26] = next_evals[25] - next_evals[24]; // 222 - - next_evals } diff --git a/src/sumcheck/sumcheck_small_value.rs b/src/sumcheck/sumcheck_small_value.rs index 2ef1f019..2052622e 100644 --- a/src/sumcheck/sumcheck_small_value.rs +++ b/src/sumcheck/sumcheck_small_value.rs @@ -19,25 +19,31 @@ pub fn compute_accumulators( let mut round_2_accumulator = RoundAccumlators::::new_empty(2); let mut round_3_accumulator = RoundAccumlators::::new_empty(3); + let mut evals_1_buffer = [F::ZERO; 27]; + let mut evals_2_buffer = [F::ZERO; 27]; // For x'' in {0 .. 2^{l - 3}}: for x in 0..1 << (l - NUM_OF_ROUNDS) { // We compute p_1(beta, x'') for all beta in {0, 1, inf}^3 - let current_evals_1: Vec = poly_1 + let current_evals_1_array: [F; 8] = poly_1 .iter() .skip(x) .step_by(1 << (l - NUM_OF_ROUNDS)) .cloned() - .collect(); - let evals_1 = compute_p_beta(current_evals_1); + .collect::>() + .try_into() + .unwrap(); + compute_p_beta(¤t_evals_1_array, &mut evals_1_buffer); // We compute p_2(beta, x'') for all beta in {0, 1, inf}^3 - let current_evals_2: Vec = poly_2 + let current_evals_2_array: [F; 8] = poly_2 .iter() .skip(x) .step_by(1 << (l - NUM_OF_ROUNDS)) .cloned() - .collect(); - let evals_2 = compute_p_beta(current_evals_2); + .collect::>() + .try_into() + .unwrap(); + compute_p_beta(¤t_evals_2_array, &mut evals_2_buffer); // For each beta in {0, 1, inf}^3: // (We have 27 = 3 ^ NUM_OF_ROUNDS number of betas) @@ -54,7 +60,10 @@ pub fn compute_accumulators( (index_accumulator_3, &mut round_3_accumulator), ] { if let Some(index) = index_opt { - acc.accumulate_eval(evals_1[beta_index] * evals_2[beta_index], index); + acc.accumulate_eval( + evals_1_buffer[beta_index] * evals_2_buffer[beta_index], + index, + ); } } } diff --git a/src/sumcheck/sumcheck_small_value_eq.rs b/src/sumcheck/sumcheck_small_value_eq.rs index 79082c8d..36c6b7c3 100644 --- a/src/sumcheck/sumcheck_small_value_eq.rs +++ b/src/sumcheck/sumcheck_small_value_eq.rs @@ -10,6 +10,7 @@ use p3_challenger::{FieldChallenger, GrindingChallenger}; use p3_field::{ExtensionField, Field}; use super::sumcheck_polynomial::SumcheckPolynomial; +use p3_maybe_rayon::prelude::*; use p3_multilinear_util::eq::eval_eq; // WE ASSUME THE NUMBER OF ROUNDS WE ARE DOING WITH SMALL VALUES IS 3 @@ -33,6 +34,46 @@ fn precompute_e_out(w: &MultilinearPoint) -> [Vec; NUM_OF_ROUNDS }) } +/// Reorders the polynomial evaluations to improve cache locality. +/// +/// Instead of the original layout, this function groups the 8 values +/// needed for each `compute_p_beta` call into contiguous blocks. +fn transpose_poly_for_svo( + poly: &EvaluationsList, + num_variables: usize, + x_out_num_vars: usize, + half_l: usize, +) -> Vec { + let num_x_in = 1 << half_l; + let num_x_out = 1 << x_out_num_vars; + let step_size = 1 << (num_variables - NUM_OF_ROUNDS); + let block_size = 8; + + // Pre-allocate the full memory for the transposed data. + let mut transposed_poly = vec![F::ZERO; 1 << num_variables]; + let x_out_block_size = num_x_in * block_size; + + // Parallelize the transposition work. + transposed_poly + .par_chunks_mut(x_out_block_size) + .enumerate() + .for_each(|(x_out, chunk)| { + // Each thread works on a separate `x_out` chunk. + for x_in in 0..num_x_in { + let start_index = (x_in << x_out_num_vars) | x_out; + + // The destination index is relative to the start of the current chunk. + let dest_base_index = x_in * block_size; + + let mut iter = poly.iter().skip(start_index).step_by(step_size); + for i in 0..block_size { + chunk[dest_base_index + i] = *iter.next().unwrap(); + } + } + }); + + transposed_poly +} // Procedure 9. Page 37. fn compute_accumulators_eq>( poly: &EvaluationsList, @@ -42,63 +83,180 @@ fn compute_accumulators_eq>( let l = poly.num_variables(); let half_l = l / 2; - let mut accumulators = Accumulators::::new_empty(); - let x_out_num_variables = half_l - NUM_OF_ROUNDS + (l % 2); debug_assert_eq!(half_l + x_out_num_variables, l - NUM_OF_ROUNDS); - for x_out in 0..1 << (x_out_num_variables) { - let mut temp_accumulators: Vec = vec![EF::ZERO; 27]; - - for x_in in 0..1 << half_l { - // We collect the evaluations of p(X_0, X_1, X_2, x_in, x_out) where - // x_in and x_out are fixed and X_0, X_1, X_2 are variables. - let start_index = (x_in << x_out_num_variables) | x_out; - let step_size = 1 << (l - NUM_OF_ROUNDS); + // Optimization number 3: Transpose the polynomial to improve cache locality. + // 1 . Transpose the polynomial befoere entering the parallel loop. + let transposed_poly = transpose_poly_for_svo(poly, l, x_out_num_variables, half_l); + + // Parallelize the outer loop over `x_out` + (0..1 << x_out_num_variables) + .into_par_iter() + .map(|x_out| { + // Each thread will compute its own set of local accumulators. + // This avoids mutable state sharing and the need for locks. + let mut local_accumulators = Accumulators::::new_empty(); + + // This inner part remains the same, but operates on local variables. + let mut temp_accumulators: Vec = vec![EF::ZERO; 27]; + let mut p_evals_buffer = [F::ZERO; 27]; + let num_x_in = 1 << half_l; + + for x_in in 0..num_x_in { + // 2. Read a contiguous block instead of jumping through memory. + let block_start = (x_out * num_x_in + x_in) * 8; + let current_evals_arr: [F; 8] = transposed_poly[block_start..block_start + 8] + .try_into() + .unwrap(); + + compute_p_beta(¤t_evals_arr, &mut p_evals_buffer); + let e_in_value = e_in[x_in]; + + for (accumulator, &p_eval) in + temp_accumulators.iter_mut().zip(p_evals_buffer.iter()) + { + *accumulator += e_in_value * p_eval; + } + } - let current_evals: Vec = poly - .iter() - .skip(start_index) - .step_by(step_size) - .copied() - .collect(); + // hardcoded accumulator distribution + // This now populates the `local_accumulators` for this specific `x_out`. + let temp_acc = &temp_accumulators; + let e_out_2 = e_out[2][x_out]; - // We compute p(beta, x_in, x_out) for all beta in {0, 1, inf}^3 - let p_evals = compute_p_beta(current_evals); - let e_in_value = e_in[x_in]; + // Pre-fetch e_out values to avoid repeated indexing and allocations. + let e0_0 = e_out[0][(0 << x_out_num_variables) | x_out]; + let e0_1 = e_out[0][(1 << x_out_num_variables) | x_out]; + let e0_2 = e_out[0][(2 << x_out_num_variables) | x_out]; + let e0_3 = e_out[0][(3 << x_out_num_variables) | x_out]; + let e1_0 = e_out[1][(0 << x_out_num_variables) | x_out]; + let e1_1 = e_out[1][(1 << x_out_num_variables) | x_out]; - for (accumulator, &p_eval) in temp_accumulators.iter_mut().zip(&p_evals) { - *accumulator += e_in_value * p_eval; - } - } + // Now we do not use the idx4 function since we are directly computing the indices. - // TODO: This can be hardcoded for better performance. - for beta_index in 0..27 { - let [index_1, index_2, index_3] = idx4_v2(beta_index); - let [_, beta_2, beta_3] = to_base_three_coeff(beta_index); - let temp_acc = temp_accumulators[beta_index]; - - // Accumulator 1: uses y = beta_2 || beta_3 - if let Some(index) = index_1 { - let y = beta_2 << 1 | beta_3; - let e_out_value = e_out[0][(y << x_out_num_variables) | x_out]; - accumulators.accumulate(0, index, e_out_value * temp_acc); - } + // beta_index = 0; b=(0,0,0); + local_accumulators.accumulate(0, 0, e0_0 * temp_acc[0]); // y=0<<1|0=0 + local_accumulators.accumulate(1, 0, e1_0 * temp_acc[0]); // y=0 + local_accumulators.accumulate(2, 0, e_out_2 * temp_acc[0]); - // Accumulator 2: uses y = beta_3 - if let Some(index) = index_2 { - let y = beta_3; - let e_out_value = e_out[1][(y << x_out_num_variables) | x_out]; - accumulators.accumulate(1, index, e_out_value * temp_acc); - } + // beta_index = 1; b=(0,0,1); + local_accumulators.accumulate(0, 0, e0_1 * temp_acc[1]); // y=0<<1|1=1 + local_accumulators.accumulate(1, 0, e1_1 * temp_acc[1]); // y=1 + local_accumulators.accumulate(2, 1, e_out_2 * temp_acc[1]); - // Accumulator 3: uses x_out directly - if let Some(index) = index_3 { - accumulators.accumulate(2, index, e_out[2][x_out] * temp_acc); - } - } - } - accumulators + // beta_index = 2; b=(0,0,2); + local_accumulators.accumulate(2, 2, e_out_2 * temp_acc[2]); + + // beta_index = 3; b=(0,1,0); + local_accumulators.accumulate(0, 0, e0_2 * temp_acc[3]); // y=1<<1|0=2 + local_accumulators.accumulate(1, 1, e1_0 * temp_acc[3]); // y=0 + local_accumulators.accumulate(2, 3, e_out_2 * temp_acc[3]); + + // beta_index = 4; b=(0,1,1); + local_accumulators.accumulate(0, 0, e0_3 * temp_acc[4]); // y=1<<1|1=3 + local_accumulators.accumulate(1, 1, e1_1 * temp_acc[4]); // y=1 + local_accumulators.accumulate(2, 4, e_out_2 * temp_acc[4]); + + // beta_index = 5; b=(0,1,2); + local_accumulators.accumulate(2, 5, e_out_2 * temp_acc[5]); + + // beta_index = 6; b=(0,2,0); + local_accumulators.accumulate(1, 2, e1_0 * temp_acc[6]); // y=0 + local_accumulators.accumulate(2, 6, e_out_2 * temp_acc[6]); + + // beta_index = 7; b=(0,2,1); + local_accumulators.accumulate(1, 2, e1_1 * temp_acc[7]); // y=1 + local_accumulators.accumulate(2, 7, e_out_2 * temp_acc[7]); + + // beta_index = 8; b=(0,2,2); + local_accumulators.accumulate(2, 8, e_out_2 * temp_acc[8]); + + // beta_index = 9; b=(1,0,0); + local_accumulators.accumulate(0, 1, e0_0 * temp_acc[9]); // y=0<<1|0=0 + local_accumulators.accumulate(1, 3, e1_0 * temp_acc[9]); // y=0 + local_accumulators.accumulate(2, 9, e_out_2 * temp_acc[9]); + + // beta_index = 10; b=(1,0,1); + local_accumulators.accumulate(0, 1, e0_1 * temp_acc[10]); // y=0<<1|1=1 + local_accumulators.accumulate(1, 3, e1_1 * temp_acc[10]); // y=1 + local_accumulators.accumulate(2, 10, e_out_2 * temp_acc[10]); + + // beta_index = 11; b=(1,0,2); + local_accumulators.accumulate(2, 11, e_out_2 * temp_acc[11]); + + // beta_index = 12; b=(1,1,0); + local_accumulators.accumulate(0, 1, e0_2 * temp_acc[12]); // y=1<<1|0=2 + local_accumulators.accumulate(1, 4, e1_0 * temp_acc[12]); // y=0 + local_accumulators.accumulate(2, 12, e_out_2 * temp_acc[12]); + + // beta_index = 13; b=(1,1,1); + local_accumulators.accumulate(0, 1, e0_3 * temp_acc[13]); // y=1<<1|1=3 + local_accumulators.accumulate(1, 4, e1_1 * temp_acc[13]); // y=1 + local_accumulators.accumulate(2, 13, e_out_2 * temp_acc[13]); + + // beta_index = 14; b=(1,1,2); + local_accumulators.accumulate(2, 14, e_out_2 * temp_acc[14]); + + // beta_index = 15; b=(1,2,0); + local_accumulators.accumulate(1, 5, e1_0 * temp_acc[15]); // y=0 + local_accumulators.accumulate(2, 15, e_out_2 * temp_acc[15]); + + // beta_index = 16; b=(1,2,1); + local_accumulators.accumulate(1, 5, e1_1 * temp_acc[16]); // y=1 + local_accumulators.accumulate(2, 16, e_out_2 * temp_acc[16]); + + // beta_index = 17; b=(1,2,2); + local_accumulators.accumulate(2, 17, e_out_2 * temp_acc[17]); + + // beta_index = 18; b=(2,0,0); + local_accumulators.accumulate(1, 6, e1_0 * temp_acc[18]); // y=0 + local_accumulators.accumulate(2, 18, e_out_2 * temp_acc[18]); + + // beta_index = 19; b=(2,0,1); + local_accumulators.accumulate(1, 6, e1_1 * temp_acc[19]); // y=1 + local_accumulators.accumulate(2, 19, e_out_2 * temp_acc[19]); + + // beta_index = 20; b=(2,0,2); + local_accumulators.accumulate(2, 20, e_out_2 * temp_acc[20]); + + // beta_index = 21; b=(2,1,0); + local_accumulators.accumulate(1, 7, e1_0 * temp_acc[21]); // y=0 + local_accumulators.accumulate(2, 21, e_out_2 * temp_acc[21]); + + // beta_index = 22; b=(2,1,1); + local_accumulators.accumulate(1, 7, e1_1 * temp_acc[22]); // y=1 + local_accumulators.accumulate(2, 22, e_out_2 * temp_acc[22]); + + // beta_index = 23; b=(2,1,2); + local_accumulators.accumulate(2, 23, e_out_2 * temp_acc[23]); + + // beta_index = 24; b=(2,2,0); + local_accumulators.accumulate(1, 8, e1_0 * temp_acc[24]); // y=0 + local_accumulators.accumulate(2, 24, e_out_2 * temp_acc[24]); + + // beta_index = 25; b=(2,2,1); + local_accumulators.accumulate(1, 8, e1_1 * temp_acc[25]); // y=1 + local_accumulators.accumulate(2, 25, e_out_2 * temp_acc[25]); + + // beta_index = 26; b=(2,2,2); + local_accumulators.accumulate(2, 26, e_out_2 * temp_acc[26]); + + // Return the computed local accumulators for this thread. + local_accumulators + }) + // Reduce the results from all threads into a single Accumulators struct. + .reduce( + || Accumulators::::new_empty(), + |mut a, b| { + for (round_a, round_b) in a.accumulators.iter_mut().zip(b.accumulators.iter()) { + for (acc_a, acc_b) in round_a.iter_mut().zip(round_b.iter()) { + *acc_a += *acc_b; + } + } + a + }, + ) } pub fn eval_eq_in_hypercube(point: &Vec) -> Vec { @@ -133,7 +291,7 @@ pub fn compute_linear_function(w: &[F], r: &[F]) -> [F; 2] { fn get_evals_from_l_and_t(l: &[F; 2], t: &[F]) -> [F; 2] { [ - t[0] * l[0], // s(0) + t[0] * l[0], // s(0) (t[1] - t[0]) * (l[1] - l[0]), //s(inf) -> l(inf) = l(1) - l(0) ] } @@ -176,8 +334,10 @@ where // 4. Receive the challenge r_1 from the verifier. let r_1: EF = prover_state.sample(); - let eval_1 = *sum - round_poly_evals[0] ; - *sum = round_poly_evals[1] * r_1.square() + (eval_1 - round_poly_evals[0] - round_poly_evals[1]) * r_1 + round_poly_evals[0]; + let eval_1 = *sum - round_poly_evals[0]; + *sum = round_poly_evals[1] * r_1.square() + + (eval_1 - round_poly_evals[0] - round_poly_evals[1]) * r_1 + + round_poly_evals[0]; // 5. Compte R_2 = [L_0(r_1), L_1(r_1), L_inf(r_1)] // L_0 (x) = 1 - x @@ -231,7 +391,9 @@ where ]; let eval_1 = *sum - round_poly_evals[0]; - *sum = round_poly_evals[1] * r_2.square() + (eval_1 - round_poly_evals[0] - round_poly_evals[1]) * r_2 + round_poly_evals[0]; + *sum = round_poly_evals[1] * r_2.square() + + (eval_1 - round_poly_evals[0] - round_poly_evals[1]) * r_2 + + round_poly_evals[0]; // Round 3 @@ -265,11 +427,12 @@ where // TODO: En realidad no hace falta mandar S_3(1) porque se dedecue usando S_3(0). prover_state.add_extension_scalars(&round_poly_evals); - let r_3: EF = prover_state.sample(); let eval_1 = *sum - round_poly_evals[0]; - *sum = round_poly_evals[1] * r_3.square() + (eval_1 - round_poly_evals[0] - round_poly_evals[1]) * r_3 + round_poly_evals[0]; + *sum = round_poly_evals[1] * r_3.square() + + (eval_1 - round_poly_evals[0] - round_poly_evals[1]) * r_3 + + round_poly_evals[0]; (r_1, r_2, r_3) } @@ -1055,11 +1218,9 @@ mod tests { // We compute l_2(0) and l_2(inf) let linear_2_evals = compute_linear_function(&w.0[..2], &[r_1]); - // We compute S_2(0) and S_2(inf) let round_poly_evals = get_evals_from_l_and_t(&linear_2_evals, &t_2_evals); - println!("ROUND 2 EQ: {:?}", round_poly_evals); // 5. Compute R_3 = [L_00(r_1, r_2), L_01(r_1, r_2), ..., L_{inf inf}(r_1, r_2)]