Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -103,3 +103,7 @@ harness = false
[[bench]]
name = "sumcheck"
harness = false

[[bench]]
name = "sumcheck_svo"
harness = false
104 changes: 104 additions & 0 deletions benches/sumcheck_svo.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,104 @@
use criterion::{BenchmarkId, Criterion, criterion_group, criterion_main};
use p3_challenger::DuplexChallenger;
use p3_field::extension::BinomialExtensionField;
use p3_koala_bear::{KoalaBear, Poseidon2KoalaBear};
use rand::{Rng, SeedableRng, rngs::StdRng};
use std::hint::black_box;
use std::time::Duration;
use whir::{
fiat_shamir::{domain_separator::DomainSeparator, prover::ProverState},
poly::{evals::EvaluationsList, multilinear::MultilinearPoint},
sumcheck::sumcheck_single::SumcheckSingle,
whir::statement::{Statement, point::ConstraintPoint},
};
use whir_p3 as whir;
type F = KoalaBear;
type EF = BinomialExtensionField<F, 8>;
type Poseidon16 = Poseidon2KoalaBear<16>;
type MyChallenger = DuplexChallenger<F, Poseidon16, 16, 8>;

const NUM_CONSTRAINTS: usize = 1;
const FOLDING_FACTOR: usize = 5;
const POW_BITS: usize = 0;

fn setup_prover() -> ProverState<F, EF, MyChallenger> {
let mut rng = StdRng::seed_from_u64(0);
let poseidon = Poseidon16::new_from_rng_128(&mut rng);
let challenger = MyChallenger::new(poseidon);
DomainSeparator::new(vec![]).to_prover_state(challenger)
}

fn generate_poly(num_vars: usize) -> EvaluationsList<F> {
let mut rng = StdRng::seed_from_u64(1 + num_vars as u64);
EvaluationsList::new((0..1 << num_vars).map(|_| rng.random()).collect())
}

fn generate_statement(
num_vars: usize,
poly: &EvaluationsList<F>,
num_constraints: usize,
) -> Statement<EF> {
let mut rng = StdRng::seed_from_u64(42 + num_vars as u64);
let mut statement = Statement::new(num_vars);
for _ in 0..num_constraints {
let point = MultilinearPoint::rand(&mut rng, num_vars);
let eval = poly.evaluate(&point);
statement.add_constraint(ConstraintPoint::new(point), eval);
}
statement
}

fn bench_sumcheck_prover_svo(c: &mut Criterion) {
let mut group = c.benchmark_group("SumcheckProver");
group.sample_size(100);
group.warm_up_time(Duration::from_secs(10));
for &num_vars in &[16, 18, 20] {
let poly = generate_poly(num_vars);
let statement = generate_statement(num_vars, &poly, NUM_CONSTRAINTS);

group.bench_with_input(
BenchmarkId::new("Classic", num_vars),
&num_vars,
|b, &_num_vars| {
b.iter(|| {
let mut prover = setup_prover();
let combination_randomness: EF = prover.sample();
let result = SumcheckSingle::from_base_evals(
&poly,
&statement,
combination_randomness,
&mut prover,
FOLDING_FACTOR,
POW_BITS,
);
black_box(result);
});
},
);

group.bench_with_input(
BenchmarkId::new("SVO", num_vars),
&num_vars,
|b, &_num_vars| {
b.iter(|| {
let mut prover = setup_prover();
let combination_randomness: EF = prover.sample();
let result = SumcheckSingle::from_base_evals_svo(
&poly,
&statement,
combination_randomness,
&mut prover,
FOLDING_FACTOR,
POW_BITS,
);
black_box(result);
});
},
);
}

group.finish();
}

criterion_group!(benches, bench_sumcheck_prover_svo);
criterion_main!(benches);
21 changes: 17 additions & 4 deletions src/sumcheck/small_value_utils.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
use p3_field::Field;
use std::ops::Add;

pub const NUM_OF_ROUNDS: usize = 3;

Expand Down Expand Up @@ -49,6 +50,21 @@ where
&self.accumulators[round]
}
}

impl<F: Field> Add for Accumulators<F> {
type Output = Self;

fn add(mut self, other: Self) -> Self {
for i in 0..NUM_OF_ROUNDS {
// NUM_OF_ROUNDS is 3
for j in 0..self.accumulators[i].len() {
self.accumulators[i][j] += other.accumulators[i][j];
}
}
self
}
}

// For round i, RoundAccumulators has all the accumulators of the form A_i(u, v).
#[derive(Debug, Clone, Eq, PartialEq)]
pub struct RoundAccumlators<F: Field> {
Expand Down Expand Up @@ -129,8 +145,7 @@ pub fn idx4_v2(index_beta: usize) -> [Option<usize>; 3] {
// Implement Procedure 6 (Page 34).
// Fijado x'' en {0, 1}^{l-3}, dadas las evaluaciones del multilineal q(x1, x2, x3) = p(x1, x2, x3, x'') en el booleano devuelve las
// evaluaciones de q en beta para todo beta in {0, 1, inf}^3.
pub fn compute_p_beta<F: Field>(current_evals: Vec<F>) -> Vec<F> {
let mut next_evals = vec![F::ZERO; 27];
pub fn compute_p_beta<F: Field>(current_evals: &[F; 8], next_evals: &mut [F; 27]) {

next_evals[0] = current_evals[0]; // 000
next_evals[1] = current_evals[1]; // 001
Expand Down Expand Up @@ -165,6 +180,4 @@ pub fn compute_p_beta<F: Field>(current_evals: Vec<F>) -> Vec<F> {
next_evals[20] = next_evals[19] - next_evals[18]; // 202
next_evals[23] = next_evals[22] - next_evals[21]; // 212
next_evals[26] = next_evals[25] - next_evals[24]; // 222

next_evals
}
23 changes: 16 additions & 7 deletions src/sumcheck/sumcheck_small_value.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,25 +19,31 @@ pub fn compute_accumulators<F: Field>(
let mut round_2_accumulator = RoundAccumlators::<F>::new_empty(2);
let mut round_3_accumulator = RoundAccumlators::<F>::new_empty(3);

let mut evals_1_buffer = [F::ZERO; 27];
let mut evals_2_buffer = [F::ZERO; 27];
// For x'' in {0 .. 2^{l - 3}}:
for x in 0..1 << (l - NUM_OF_ROUNDS) {
// We compute p_1(beta, x'') for all beta in {0, 1, inf}^3
let current_evals_1: Vec<F> = poly_1
let current_evals_1_array: [F; 8] = poly_1
.iter()
.skip(x)
.step_by(1 << (l - NUM_OF_ROUNDS))
.cloned()
.collect();
let evals_1 = compute_p_beta(current_evals_1);
.collect::<Vec<F>>()
.try_into()
.unwrap();
compute_p_beta(&current_evals_1_array, &mut evals_1_buffer);

// We compute p_2(beta, x'') for all beta in {0, 1, inf}^3
let current_evals_2: Vec<F> = poly_2
let current_evals_2_array: [F; 8] = poly_2
.iter()
.skip(x)
.step_by(1 << (l - NUM_OF_ROUNDS))
.cloned()
.collect();
let evals_2 = compute_p_beta(current_evals_2);
.collect::<Vec<F>>()
.try_into()
.unwrap();
compute_p_beta(&current_evals_2_array, &mut evals_2_buffer);

// For each beta in {0, 1, inf}^3:
// (We have 27 = 3 ^ NUM_OF_ROUNDS number of betas)
Expand All @@ -54,7 +60,10 @@ pub fn compute_accumulators<F: Field>(
(index_accumulator_3, &mut round_3_accumulator),
] {
if let Some(index) = index_opt {
acc.accumulate_eval(evals_1[beta_index] * evals_2[beta_index], index);
acc.accumulate_eval(
evals_1_buffer[beta_index] * evals_2_buffer[beta_index],
index,
);
}
}
}
Expand Down
Loading