From 710db292d78edf481e63e4f4b1ee5c9ac17fade8 Mon Sep 17 00:00:00 2001 From: Matthias Goergens Date: Wed, 11 Dec 2024 15:41:49 +0800 Subject: [PATCH 01/18] Remove more manual `Expression` clones --- ceno_zkvm/src/chip_handler.rs | 14 ++- ceno_zkvm/src/chip_handler/general.rs | 81 +++++++------ ceno_zkvm/src/chip_handler/global_state.rs | 15 ++- ceno_zkvm/src/chip_handler/register.rs | 12 +- ceno_zkvm/src/circuit_builder.rs | 5 +- ceno_zkvm/src/expression.rs | 114 ++++++++++-------- ceno_zkvm/src/gadgets/is_lt.rs | 91 ++++++++------ ceno_zkvm/src/gadgets/is_zero.rs | 38 ++++-- ceno_zkvm/src/gadgets/signed_ext.rs | 22 +++- ceno_zkvm/src/instructions/riscv/arith_imm.rs | 2 +- ceno_zkvm/src/instructions/riscv/b_insn.rs | 9 +- .../instructions/riscv/branch/beq_circuit.rs | 6 +- .../src/instructions/riscv/branch/blt.rs | 6 +- .../src/instructions/riscv/branch/bltu.rs | 6 +- ceno_zkvm/src/instructions/riscv/div.rs | 14 +-- .../src/instructions/riscv/ecall/halt.rs | 8 +- .../src/instructions/riscv/ecall_insn.rs | 8 +- ceno_zkvm/src/instructions/riscv/i_insn.rs | 4 +- ceno_zkvm/src/instructions/riscv/im_insn.rs | 4 +- ceno_zkvm/src/instructions/riscv/insn_base.rs | 22 ++-- ceno_zkvm/src/instructions/riscv/jump/jalr.rs | 2 +- .../riscv/logic_imm/logic_imm_circuit.rs | 2 +- .../src/instructions/riscv/memory/gadget.rs | 18 +-- .../src/instructions/riscv/memory/load.rs | 6 +- .../src/instructions/riscv/memory/store.rs | 6 +- ceno_zkvm/src/instructions/riscv/mul.rs | 4 +- ceno_zkvm/src/instructions/riscv/s_insn.rs | 4 +- ceno_zkvm/src/instructions/riscv/shift.rs | 4 +- ceno_zkvm/src/instructions/riscv/shift_imm.rs | 4 +- ceno_zkvm/src/instructions/riscv/slt.rs | 4 +- ceno_zkvm/src/instructions/riscv/slti.rs | 12 +- ceno_zkvm/src/scheme/mock_prover.rs | 2 +- ceno_zkvm/src/scheme/tests.rs | 2 +- ceno_zkvm/src/scheme/utils.rs | 10 +- ceno_zkvm/src/uint.rs | 29 ++--- ceno_zkvm/src/uint/arithmetic.rs | 43 ++++--- ceno_zkvm/src/uint/logic.rs | 6 +- 37 files changed, 363 insertions(+), 276 deletions(-) diff --git a/ceno_zkvm/src/chip_handler.rs b/ceno_zkvm/src/chip_handler.rs index 8d16f342d..2978fb70c 100644 --- a/ceno_zkvm/src/chip_handler.rs +++ b/ceno_zkvm/src/chip_handler.rs @@ -17,7 +17,11 @@ pub mod utils; pub mod test; pub trait GlobalStateRegisterMachineChipOperations { - fn state_in(&mut self, pc: Expression, ts: Expression) -> Result<(), ZKVMError>; + fn state_in( + &mut self, + pc: impl ToExpr>, + ts: impl ToExpr>, + ) -> Result<(), ZKVMError>; fn state_out(&mut self, pc: Expression, ts: Expression) -> Result<(), ZKVMError>; } @@ -30,8 +34,8 @@ pub trait RegisterChipOperations, N: FnOnce( fn register_read( &mut self, name_fn: N, - register_id: impl ToExpr>, - prev_ts: Expression, + register_id: impl ToExpr> + std::marker::Copy, + prev_ts: impl ToExpr> + std::marker::Copy, ts: Expression, value: RegisterExpr, ) -> Result<(Expression, AssertLTConfig), ZKVMError>; @@ -40,8 +44,8 @@ pub trait RegisterChipOperations, N: FnOnce( fn register_write( &mut self, name_fn: N, - register_id: impl ToExpr>, - prev_ts: Expression, + register_id: impl ToExpr> + std::marker::Copy, + prev_ts: impl ToExpr> + std::marker::Copy, ts: Expression, prev_values: RegisterExpr, value: RegisterExpr, diff --git a/ceno_zkvm/src/chip_handler/general.rs b/ceno_zkvm/src/chip_handler/general.rs index 780914e78..46715d85e 100644 --- a/ceno_zkvm/src/chip_handler/general.rs +++ b/ceno_zkvm/src/chip_handler/general.rs @@ -164,7 +164,7 @@ impl<'a, E: ExtensionField> CircuitBuilder<'a, E> { N: FnOnce() -> NR + Clone, { let byte = self.cs.create_witin(name_fn.clone()); - self.assert_ux::<_, _, 8>(name_fn, byte.expr())?; + self.assert_ux::<_, _, 8>(name_fn, byte)?; Ok(byte) } @@ -175,7 +175,7 @@ impl<'a, E: ExtensionField> CircuitBuilder<'a, E> { N: FnOnce() -> NR + Clone, { let limb = self.cs.create_witin(name_fn.clone()); - self.assert_ux::<_, _, 16>(name_fn, limb.expr())?; + self.assert_ux::<_, _, 16>(name_fn, limb)?; Ok(limb) } @@ -191,7 +191,7 @@ impl<'a, E: ExtensionField> CircuitBuilder<'a, E> { N: FnOnce() -> NR + Clone, { let wit = self.cs.create_witin(name_fn.clone()); - self.require_equal(name_fn, wit.expr(), expr)?; + self.require_equal(name_fn, wit, expr)?; Ok(wit) } @@ -199,7 +199,7 @@ impl<'a, E: ExtensionField> CircuitBuilder<'a, E> { pub fn require_zero( &mut self, name_fn: N, - assert_zero_expr: Expression, + assert_zero_expr: impl ToExpr>, ) -> Result<(), ZKVMError> where NR: Into, @@ -214,8 +214,8 @@ impl<'a, E: ExtensionField> CircuitBuilder<'a, E> { pub fn require_equal( &mut self, name_fn: N, - a: Expression, - b: Expression, + a: impl ToExpr>, + b: impl ToExpr>, ) -> Result<(), ZKVMError> where NR: Into, @@ -224,8 +224,10 @@ impl<'a, E: ExtensionField> CircuitBuilder<'a, E> { self.namespace( || "require_equal", |cb| { - cb.cs - .require_zero(name_fn, a.to_monomial_form() - b.to_monomial_form()) + cb.cs.require_zero( + name_fn, + a.expr().to_monomial_form() - b.expr().to_monomial_form(), + ) }, ) } @@ -241,21 +243,25 @@ impl<'a, E: ExtensionField> CircuitBuilder<'a, E> { pub fn condition_require_equal( &mut self, name_fn: N, - cond: Expression, - target: Expression, - true_expr: Expression, - false_expr: Expression, + cond: impl ToExpr>, + target: impl ToExpr>, + true_expr: impl ToExpr>, + false_expr: impl ToExpr>, ) -> Result<(), ZKVMError> where NR: Into, N: FnOnce() -> NR, { + let cond = cond.expr(); + let target = target.expr(); + let true_expr = true_expr.expr(); + let false_expr = false_expr.expr(); // cond * (true_expr) + (1 - cond) * false_expr // => false_expr + cond * true_expr - cond * false_expr self.namespace( || "cond_require_equal", |cb| { - let cond_target = false_expr.clone() + cond.clone() * true_expr - cond * false_expr; + let cond_target = &false_expr + &cond * true_expr - cond * false_expr; cb.cs.require_zero(name_fn, target - cond_target) }, ) @@ -263,22 +269,24 @@ impl<'a, E: ExtensionField> CircuitBuilder<'a, E> { pub fn select( &mut self, - cond: &Expression, - when_true: &Expression, - when_false: &Expression, + cond: impl ToExpr>, + when_true: impl ToExpr>, + when_false: impl ToExpr>, ) -> Expression { - cond * when_true + (1 - cond) * when_false + let cond = cond.expr(); + &cond * when_true.expr() + (1 - &cond) * when_false.expr() } pub(crate) fn assert_ux( &mut self, name_fn: N, - expr: Expression, + expr: impl ToExpr>, ) -> Result<(), ZKVMError> where NR: Into, N: FnOnce() -> NR, { + let expr = expr.expr(); match C { 16 => self.assert_u16(name_fn, expr), 14 => self.assert_u14(name_fn, expr), @@ -333,25 +341,26 @@ impl<'a, E: ExtensionField> CircuitBuilder<'a, E> { pub(crate) fn assert_byte( &mut self, name_fn: N, - expr: Expression, + expr: impl ToExpr>, ) -> Result<(), ZKVMError> where NR: Into, N: FnOnce() -> NR, { - self.lk_record(name_fn, ROMType::U8, vec![expr])?; + self.lk_record(name_fn, ROMType::U8, vec![expr.expr()])?; Ok(()) } pub(crate) fn assert_bit( &mut self, name_fn: N, - expr: Expression, + expr: impl ToExpr>, ) -> Result<(), ZKVMError> where NR: Into, N: FnOnce() -> NR, { + let expr = expr.expr(); self.namespace( || "assert_bit", |cb| cb.cs.require_zero(name_fn, &expr * (1 - &expr)), @@ -362,10 +371,13 @@ impl<'a, E: ExtensionField> CircuitBuilder<'a, E> { pub fn logic_u8( &mut self, rom_type: ROMType, - a: Expression, - b: Expression, - c: Expression, + a: impl ToExpr>, + b: impl ToExpr>, + c: impl ToExpr>, ) -> Result<(), ZKVMError> { + let a = a.expr(); + let b = b.expr(); + let c = c.expr(); self.lk_record(|| format!("lookup_{:?}", rom_type), rom_type, vec![a, b, c]) } @@ -402,31 +414,30 @@ impl<'a, E: ExtensionField> CircuitBuilder<'a, E> { /// Assert that `(a < b) == c as bool`, that `a, b` are unsigned bytes, and that `c` is 0 or 1. pub fn lookup_ltu_byte( &mut self, - a: Expression, - b: Expression, - c: Expression, + a: impl ToExpr>, + b: impl ToExpr>, + c: impl ToExpr>, ) -> Result<(), ZKVMError> { self.logic_u8(ROMType::Ltu, a, b, c) } // Assert that `2^b = c` and that `b` is a 5-bit unsigned integer. pub fn lookup_pow2(&mut self, b: Expression, c: Expression) -> Result<(), ZKVMError> { - self.logic_u8(ROMType::Pow, 2.into(), b, c) + self.logic_u8(ROMType::Pow, 2, b, c) } pub(crate) fn is_equal( &mut self, - lhs: Expression, - rhs: Expression, + lhs: impl ToExpr>, + rhs: impl ToExpr>, ) -> Result<(WitIn, WitIn), ZKVMError> { + let lhs = lhs.expr(); + let rhs = rhs.expr(); let is_eq = self.create_witin(|| "is_eq"); let diff_inverse = self.create_witin(|| "diff_inverse"); - self.require_zero(|| "is equal", is_eq.expr() * &lhs - is_eq.expr() * &rhs)?; - self.require_zero( - || "is equal", - 1 - is_eq.expr() - diff_inverse.expr() * lhs + diff_inverse.expr() * rhs, - )?; + self.require_zero(|| "is equal", is_eq * &lhs - is_eq * &rhs)?; + self.require_zero(|| "is equal", 1 + diff_inverse * (rhs + lhs) - is_eq)?; Ok((is_eq, diff_inverse)) } diff --git a/ceno_zkvm/src/chip_handler/global_state.rs b/ceno_zkvm/src/chip_handler/global_state.rs index 27c28e166..ffeff9de7 100644 --- a/ceno_zkvm/src/chip_handler/global_state.rs +++ b/ceno_zkvm/src/chip_handler/global_state.rs @@ -1,17 +1,24 @@ use ff_ext::ExtensionField; use crate::{ - circuit_builder::CircuitBuilder, error::ZKVMError, expression::Expression, structs::RAMType, + circuit_builder::CircuitBuilder, + error::ZKVMError, + expression::{Expression, ToExpr}, + structs::RAMType, }; use super::GlobalStateRegisterMachineChipOperations; impl GlobalStateRegisterMachineChipOperations for CircuitBuilder<'_, E> { - fn state_in(&mut self, pc: Expression, ts: Expression) -> Result<(), ZKVMError> { + fn state_in( + &mut self, + pc: impl ToExpr>, + ts: impl ToExpr>, + ) -> Result<(), ZKVMError> { let record: Vec> = vec![ Expression::Constant(E::BaseField::from(RAMType::GlobalState as u64)), - pc, - ts, + pc.expr(), + ts.expr(), ]; self.read_record(|| "state_in", RAMType::GlobalState, record) diff --git a/ceno_zkvm/src/chip_handler/register.rs b/ceno_zkvm/src/chip_handler/register.rs index d2f1ebf14..22bb7d37f 100644 --- a/ceno_zkvm/src/chip_handler/register.rs +++ b/ceno_zkvm/src/chip_handler/register.rs @@ -17,8 +17,8 @@ impl, N: FnOnce() -> NR> RegisterChipOperati fn register_read( &mut self, name_fn: N, - register_id: impl ToExpr>, - prev_ts: Expression, + register_id: impl ToExpr> + std::marker::Copy, + prev_ts: impl ToExpr> + std::marker::Copy, ts: Expression, value: RegisterExpr, ) -> Result<(Expression, AssertLTConfig), ZKVMError> { @@ -28,7 +28,7 @@ impl, N: FnOnce() -> NR> RegisterChipOperati vec![RAMType::Register.into()], vec![register_id.expr()], value.to_vec(), - vec![prev_ts.clone()], + vec![prev_ts.expr()], ] .concat(); // Write (a, v, t) @@ -60,8 +60,8 @@ impl, N: FnOnce() -> NR> RegisterChipOperati fn register_write( &mut self, name_fn: N, - register_id: impl ToExpr>, - prev_ts: Expression, + register_id: impl ToExpr> + std::marker::Copy, + prev_ts: impl ToExpr> + std::marker::Copy, ts: Expression, prev_values: RegisterExpr, value: RegisterExpr, @@ -73,7 +73,7 @@ impl, N: FnOnce() -> NR> RegisterChipOperati vec![RAMType::Register.into()], vec![register_id.expr()], prev_values.to_vec(), - vec![prev_ts.clone()], + vec![prev_ts.expr()], ] .concat(); // Write (a, v, t) diff --git a/ceno_zkvm/src/circuit_builder.rs b/ceno_zkvm/src/circuit_builder.rs index 709f46839..206b18c46 100644 --- a/ceno_zkvm/src/circuit_builder.rs +++ b/ceno_zkvm/src/circuit_builder.rs @@ -9,7 +9,7 @@ use crate::{ ROMType, chip_handler::utils::rlc_chip_record, error::ZKVMError, - expression::{Expression, Fixed, Instance, WitIn}, + expression::{Expression, Fixed, Instance, ToExpr, WitIn}, structs::{ProgramParams, ProvingKey, RAMType, VerifyingKey, WitnessId}, witness::RowMajorMatrix, }; @@ -440,8 +440,9 @@ impl ConstraintSystem { pub fn require_zero, N: FnOnce() -> NR>( &mut self, name_fn: N, - assert_zero_expr: Expression, + assert_zero_expr: impl ToExpr>, ) -> Result<(), ZKVMError> { + let assert_zero_expr = assert_zero_expr.expr(); assert!( assert_zero_expr.degree() > 0, "constant expression assert to zero ?" diff --git a/ceno_zkvm/src/expression.rs b/ceno_zkvm/src/expression.rs index 9d65177e0..ca22002ab 100644 --- a/ceno_zkvm/src/expression.rs +++ b/ceno_zkvm/src/expression.rs @@ -588,20 +588,29 @@ macro_rules! mixed_binop_instances { }; } -mixed_binop_instances!( - Add, - add, - (u8, u16, u32, u64, usize, i8, i16, i32, i64, i128, isize) -); -mixed_binop_instances!( - Sub, - sub, - (u8, u16, u32, u64, usize, i8, i16, i32, i64, i128, isize) -); -mixed_binop_instances!( - Mul, - mul, - (u8, u16, u32, u64, usize, i8, i16, i32, i64, i128, isize) +macro_rules! mixed_binop_instances_all { + ($($t:ty),*) => { + mixed_binop_instances!( + Add, + add, + ($($t),*) + ); + mixed_binop_instances!( + Sub, + sub, + ($($t),*) + ); + mixed_binop_instances!( + Mul, + mul, + ($($t),*) + ); + }; +} + +mixed_binop_instances_all!( + u8, u16, u32, u64, usize, i8, i16, i32, i64, isize, WitIn, Fixed, Instance, &WitIn, &Fixed, + &Instance ); impl Mul for Expression { @@ -767,48 +776,55 @@ impl WitIn { pub trait ToExpr { type Output; - fn expr(&self) -> Self::Output; + fn expr(self) -> Self::Output; +} + +impl ToExpr for Expression { + type Output = Expression; + fn expr(self) -> Expression { + self + } +} + +impl ToExpr for &Expression { + type Output = Expression; + fn expr(self) -> Expression { + self.clone() + } } impl ToExpr for WitIn { type Output = Expression; - fn expr(&self) -> Expression { + fn expr(self) -> Expression { Expression::WitIn(self.id) } } impl ToExpr for &WitIn { type Output = Expression; - fn expr(&self) -> Expression { + fn expr(self) -> Expression { Expression::WitIn(self.id) } } impl ToExpr for Fixed { type Output = Expression; - fn expr(&self) -> Expression { - Expression::Fixed(*self) + fn expr(self) -> Expression { + Expression::Fixed(self) } } impl ToExpr for &Fixed { type Output = Expression; - fn expr(&self) -> Expression { - Expression::Fixed(**self) + fn expr(self) -> Expression { + Expression::Fixed(*self) } } impl ToExpr for Instance { type Output = Expression; - fn expr(&self) -> Expression { - Expression::Instance(*self) - } -} - -impl> ToExpr for F { - type Output = Expression; - fn expr(&self) -> Expression { - Expression::Constant(*self) + fn expr(self) -> Expression { + Expression::Instance(self) } } @@ -823,45 +839,41 @@ macro_rules! impl_from_via_ToExpr { )* }; } -impl_from_via_ToExpr!(WitIn, Fixed, Instance); +impl_from_via_ToExpr!( + WitIn, Fixed, Instance, u8, u16, u32, u64, usize, RAMType, InsnKind, i8, i16, i32, i64, isize +); impl_from_via_ToExpr!(&WitIn, &Fixed, &Instance); -// Implement From trait for unsigned types of at most 64 bits -macro_rules! impl_from_unsigned { +// Implement ToExpr trait for unsigned types of at most 64 bits +macro_rules! impl_ToExpr_unsigned { ($($t:ty),*) => { $( - impl> From<$t> for Expression { - fn from(value: $t) -> Self { - Expression::Constant(F::from(value as u64)) + impl> ToExpr for $t { + type Output = Expression; + fn expr(self) -> Self::Output { + Expression::Constant(F::from(self as u64)) } } )* }; } -impl_from_unsigned!(u8, u16, u32, u64, usize, RAMType, InsnKind); +impl_ToExpr_unsigned!(u8, u16, u32, u64, usize, RAMType, InsnKind); -// Implement From trait for u128 separately since it requires explicit reduction -impl> From for Expression { - fn from(value: u128) -> Self { - let reduced = value.rem_euclid(F::MODULUS_U64 as u128) as u64; - Expression::Constant(F::from(reduced)) - } -} - -// Implement From trait for signed types -macro_rules! impl_from_signed { +// Implement ToExpr trait for signed types +macro_rules! impl_ToExpr_signed { ($($t:ty),*) => { $( - impl> From<$t> for Expression { - fn from(value: $t) -> Self { - let reduced = (value as i128).rem_euclid(F::MODULUS_U64 as i128) as u64; + impl> ToExpr for $t { + type Output = Expression; + fn expr(self) -> Self::Output { + let reduced = (self as i128).rem_euclid(F::MODULUS_U64 as i128) as u64; Expression::Constant(F::from(reduced)) } } )* }; } -impl_from_signed!(i8, i16, i32, i64, i128, isize); +impl_ToExpr_signed!(i8, i16, i32, i64, isize); impl Display for Expression { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { diff --git a/ceno_zkvm/src/gadgets/is_lt.rs b/ceno_zkvm/src/gadgets/is_lt.rs index 3885b7257..d479d0cce 100644 --- a/ceno_zkvm/src/gadgets/is_lt.rs +++ b/ceno_zkvm/src/gadgets/is_lt.rs @@ -29,22 +29,16 @@ impl AssertLTConfig { >( cb: &mut CircuitBuilder, name_fn: N, - lhs: Expression, - rhs: Expression, + lhs: impl ToExpr>, + rhs: impl ToExpr>, max_num_u16_limbs: usize, ) -> Result { cb.namespace( || "assert_lt", |cb| { let name = name_fn(); - let config = InnerLtConfig::construct_circuit( - cb, - name, - lhs, - rhs, - Expression::ONE, - max_num_u16_limbs, - )?; + let config = + InnerLtConfig::construct_circuit(cb, name, lhs, rhs, 1, max_num_u16_limbs)?; Ok(Self(config)) }, ) @@ -68,11 +62,23 @@ pub struct IsLtConfig { config: InnerLtConfig, } -impl IsLtConfig { - pub fn expr(&self) -> Expression { +impl ToExpr for IsLtConfig { + type Output = Expression; + + fn expr(self) -> Self::Output { self.is_lt.expr() } +} + +impl ToExpr for &IsLtConfig { + type Output = Expression; + fn expr(self) -> Self::Output { + (&self.is_lt).expr() + } +} + +impl IsLtConfig { pub fn construct_circuit< E: ExtensionField, NR: Into + Display + Clone, @@ -80,8 +86,8 @@ impl IsLtConfig { >( cb: &mut CircuitBuilder, name_fn: N, - lhs: Expression, - rhs: Expression, + lhs: impl ToExpr>, + rhs: impl ToExpr>, max_num_u16_limbs: usize, ) -> Result { cb.namespace( @@ -89,16 +95,10 @@ impl IsLtConfig { |cb| { let name = name_fn(); let is_lt = cb.create_witin(|| format!("{name} is_lt witin")); - cb.assert_bit(|| "is_lt_bit", is_lt.expr())?; - - let config = InnerLtConfig::construct_circuit( - cb, - name, - lhs, - rhs, - is_lt.expr(), - max_num_u16_limbs, - )?; + cb.assert_bit(|| "is_lt_bit", is_lt)?; + + let config = + InnerLtConfig::construct_circuit(cb, name, lhs, rhs, is_lt, max_num_u16_limbs)?; Ok(Self { is_lt, config }) }, ) @@ -144,11 +144,14 @@ impl InnerLtConfig { pub fn construct_circuit + Display + Clone>( cb: &mut CircuitBuilder, name: NR, - lhs: Expression, - rhs: Expression, - is_lt_expr: Expression, + lhs: impl ToExpr>, + rhs: impl ToExpr>, + is_lt_expr: impl ToExpr>, max_num_u16_limbs: usize, ) -> Result { + let lhs = lhs.expr(); + let rhs = rhs.expr(); + let is_lt_expr = is_lt_expr.expr(); assert!(max_num_u16_limbs >= 1); let mut witin_u16 = |var_name: String| -> Result { @@ -156,7 +159,7 @@ impl InnerLtConfig { || format!("var {var_name}"), |cb| { let witin = cb.create_witin(|| var_name.to_string()); - cb.assert_ux::<_, _, 16>(|| name.clone(), witin.expr())?; + cb.assert_ux::<_, _, 16>(|| name.clone(), witin)?; Ok(witin) }, ) @@ -169,7 +172,7 @@ impl InnerLtConfig { let pows = power_sequence((1 << u16::BITS).into()); let diff_expr = izip!(&diff, pows) - .map(|(record, beta)| beta * record.expr()) + .map(|(record, beta)| beta * record) .sum::>(); let range = Self::range(max_num_u16_limbs); @@ -247,8 +250,7 @@ impl AssertSignedLtConfig { || "assert_signed_lt", |cb| { let name = name_fn(); - let config = - InnerSignedLtConfig::construct_circuit(cb, name, lhs, rhs, Expression::ONE)?; + let config = InnerSignedLtConfig::construct_circuit(cb, name, lhs, rhs, 1)?; Ok(Self { config }) }, ) @@ -272,11 +274,23 @@ pub struct SignedLtConfig { config: InnerSignedLtConfig, } -impl SignedLtConfig { - pub fn expr(&self) -> Expression { +impl ToExpr for SignedLtConfig { + type Output = Expression; + + fn expr(self) -> Self::Output { self.is_lt.expr() } +} + +impl ToExpr for &SignedLtConfig { + type Output = Expression; + + fn expr(self) -> Self::Output { + (&self.is_lt).expr() + } +} +impl SignedLtConfig { pub fn construct_circuit + Display + Clone, N: FnOnce() -> NR>( cb: &mut CircuitBuilder, name_fn: N, @@ -288,9 +302,8 @@ impl SignedLtConfig { |cb| { let name = name_fn(); let is_lt = cb.create_witin(|| format!("{name} is_signed_lt witin")); - cb.assert_bit(|| "is_lt_bit", is_lt.expr())?; - let config = - InnerSignedLtConfig::construct_circuit(cb, name, lhs, rhs, is_lt.expr())?; + cb.assert_bit(|| "is_lt_bit", is_lt)?; + let config = InnerSignedLtConfig::construct_circuit(cb, name, lhs, rhs, is_lt)?; Ok(SignedLtConfig { is_lt, config }) }, @@ -324,15 +337,15 @@ impl InnerSignedLtConfig { name: NR, lhs: &UInt, rhs: &UInt, - is_lt_expr: Expression, + is_lt_expr: impl ToExpr>, ) -> Result { // Extract the sign bit. let is_lhs_neg = lhs.is_negative(cb)?; let is_rhs_neg = rhs.is_negative(cb)?; // Convert to field arithmetic. - let lhs_value = lhs.to_field_expr(is_lhs_neg.expr()); - let rhs_value = rhs.to_field_expr(is_rhs_neg.expr()); + let lhs_value = lhs.to_field_expr(&is_lhs_neg); + let rhs_value = rhs.to_field_expr(&is_rhs_neg); let config = InnerLtConfig::construct_circuit( cb, format!("{name} (lhs < rhs)"), diff --git a/ceno_zkvm/src/gadgets/is_zero.rs b/ceno_zkvm/src/gadgets/is_zero.rs index f7d749354..d8d6f2f98 100644 --- a/ceno_zkvm/src/gadgets/is_zero.rs +++ b/ceno_zkvm/src/gadgets/is_zero.rs @@ -15,11 +15,23 @@ pub struct IsZeroConfig { inverse: WitIn, } -impl IsZeroConfig { - pub fn expr(&self) -> Expression { - self.is_zero.map(|wit| wit.expr()).unwrap_or(0.into()) +impl ToExpr for IsZeroConfig { + type Output = Expression; + + fn expr(self) -> Self::Output { + self.is_zero.map(ToExpr::expr).unwrap_or(0.into()) + } +} + +impl ToExpr for &IsZeroConfig { + type Output = Expression; + + fn expr(self) -> Self::Output { + self.is_zero.map(ToExpr::expr).unwrap_or(0.into()) } +} +impl IsZeroConfig { pub fn construct_circuit, N: FnOnce() -> NR>( cb: &mut CircuitBuilder, name_fn: N, @@ -49,14 +61,14 @@ impl IsZeroConfig { let is_zero = cb.create_witin(|| "is_zero"); // x!=0 => is_zero=0 - cb.require_zero(|| "is_zero_0", is_zero.expr() * x.clone())?; + cb.require_zero(|| "is_zero_0", is_zero * &x)?; (Some(is_zero), is_zero.expr()) }; let inverse = cb.create_witin(|| "inv"); // x==0 => is_zero=1 - cb.require_one(|| "is_zero_1", is_zero_expr + x.clone() * inverse.expr())?; + cb.require_one(|| "is_zero_1", is_zero_expr + x * inverse)?; Ok(IsZeroConfig { is_zero, inverse }) }) @@ -84,11 +96,23 @@ impl IsZeroConfig { pub struct IsEqualConfig(IsZeroConfig); -impl IsEqualConfig { - pub fn expr(&self) -> Expression { +impl ToExpr for IsEqualConfig { + type Output = Expression; + + fn expr(self) -> Self::Output { self.0.expr() } +} + +impl ToExpr for &IsEqualConfig { + type Output = Expression; + + fn expr(self) -> Self::Output { + (&self.0).expr() + } +} +impl IsEqualConfig { pub fn construct_circuit, N: FnOnce() -> NR>( cb: &mut CircuitBuilder, name_fn: N, diff --git a/ceno_zkvm/src/gadgets/signed_ext.rs b/ceno_zkvm/src/gadgets/signed_ext.rs index d1dc8ed62..a446687b5 100644 --- a/ceno_zkvm/src/gadgets/signed_ext.rs +++ b/ceno_zkvm/src/gadgets/signed_ext.rs @@ -19,6 +19,22 @@ pub struct SignedExtendConfig { _marker: PhantomData, } +impl ToExpr for SignedExtendConfig { + type Output = Expression; + + fn expr(self) -> Self::Output { + self.msb.expr() + } +} + +impl ToExpr for &SignedExtendConfig { + type Output = Expression; + + fn expr(self) -> Self::Output { + self.msb.expr() + } +} + impl SignedExtendConfig { pub fn construct_limb( cb: &mut CircuitBuilder, @@ -34,10 +50,6 @@ impl SignedExtendConfig { Self::construct_circuit(cb, 8, val) } - pub fn expr(&self) -> Expression { - self.msb.expr() - } - fn construct_circuit( cb: &mut CircuitBuilder, n_bits: usize, @@ -47,7 +59,7 @@ impl SignedExtendConfig { let msb = cb.create_witin(|| "msb"); // require msb is boolean - cb.assert_bit(|| "msb is boolean", msb.expr())?; + cb.assert_bit(|| "msb is boolean", msb)?; // assert 2*val - msb*2^N_BITS is within range [0, 2^N_BITS) // - if val < 2^(N_BITS-1), then 2*val < 2^N_BITS, msb can only be zero. diff --git a/ceno_zkvm/src/instructions/riscv/arith_imm.rs b/ceno_zkvm/src/instructions/riscv/arith_imm.rs index 2508000f2..10654abe4 100644 --- a/ceno_zkvm/src/instructions/riscv/arith_imm.rs +++ b/ceno_zkvm/src/instructions/riscv/arith_imm.rs @@ -41,7 +41,7 @@ impl Instruction for AddiInstruction { let i_insn = IInstructionConfig::::construct_circuit( circuit_builder, Self::INST_KIND, - &imm.value(), + imm.value(), rs1_read.register_expr(), rd_written.register_expr(), false, diff --git a/ceno_zkvm/src/instructions/riscv/b_insn.rs b/ceno_zkvm/src/instructions/riscv/b_insn.rs index c638d314f..1c44098bf 100644 --- a/ceno_zkvm/src/instructions/riscv/b_insn.rs +++ b/ceno_zkvm/src/instructions/riscv/b_insn.rs @@ -69,14 +69,9 @@ impl BInstructionConfig { ))?; // Branch program counter - let pc_offset = - branch_taken_bit.clone() * imm.expr() - branch_taken_bit * PC_STEP_SIZE + PC_STEP_SIZE; + let pc_offset = &branch_taken_bit * imm - branch_taken_bit * PC_STEP_SIZE + PC_STEP_SIZE; let next_pc = vm_state.next_pc.unwrap(); - circuit_builder.require_equal( - || "pc_branch", - next_pc.expr(), - vm_state.pc.expr() + pc_offset, - )?; + circuit_builder.require_equal(|| "pc_branch", next_pc, vm_state.pc + pc_offset)?; Ok(BInstructionConfig { vm_state, diff --git a/ceno_zkvm/src/instructions/riscv/branch/beq_circuit.rs b/ceno_zkvm/src/instructions/riscv/branch/beq_circuit.rs index 4826c94bf..249231d5a 100644 --- a/ceno_zkvm/src/instructions/riscv/branch/beq_circuit.rs +++ b/ceno_zkvm/src/instructions/riscv/branch/beq_circuit.rs @@ -7,7 +7,7 @@ use crate::{ Value, circuit_builder::CircuitBuilder, error::ZKVMError, - expression::Expression, + expression::ToExpr, gadgets::IsEqualConfig, instructions::{ Instruction, @@ -49,8 +49,8 @@ impl Instruction for BeqCircuit { )?; let branch_taken_bit = match I::INST_KIND { - InsnKind::BEQ => equal.expr(), - InsnKind::BNE => Expression::ONE - equal.expr(), + InsnKind::BEQ => (&equal).expr(), + InsnKind::BNE => 1 - (&equal).expr(), _ => unreachable!("Unsupported instruction kind {:?}", I::INST_KIND), }; diff --git a/ceno_zkvm/src/instructions/riscv/branch/blt.rs b/ceno_zkvm/src/instructions/riscv/branch/blt.rs index c5e0798f2..f72ad04bd 100644 --- a/ceno_zkvm/src/instructions/riscv/branch/blt.rs +++ b/ceno_zkvm/src/instructions/riscv/branch/blt.rs @@ -6,7 +6,7 @@ use crate::{ Value, circuit_builder::CircuitBuilder, error::ZKVMError, - expression::Expression, + expression::ToExpr, gadgets::SignedLtConfig, instructions::{ Instruction, @@ -42,8 +42,8 @@ impl Instruction for BltCircuit { SignedLtConfig::construct_circuit(circuit_builder, || "rs1 is_lt.expr(), - InsnKind::BGE => Expression::ONE - is_lt.expr(), + InsnKind::BLT => (&is_lt).expr(), + InsnKind::BGE => 1 - (&is_lt).expr(), _ => unreachable!("Unsupported instruction kind {:?}", I::INST_KIND), }; diff --git a/ceno_zkvm/src/instructions/riscv/branch/bltu.rs b/ceno_zkvm/src/instructions/riscv/branch/bltu.rs index 896bf19da..a6ec83add 100644 --- a/ceno_zkvm/src/instructions/riscv/branch/bltu.rs +++ b/ceno_zkvm/src/instructions/riscv/branch/bltu.rs @@ -6,7 +6,7 @@ use crate::{ Value, circuit_builder::CircuitBuilder, error::ZKVMError, - expression::Expression, + expression::ToExpr, gadgets::IsLtConfig, instructions::{ Instruction, @@ -51,8 +51,8 @@ impl Instruction for BltuCircuit )?; let branch_taken_bit = match I::INST_KIND { - InsnKind::BLTU => is_lt.expr(), - InsnKind::BGEU => Expression::ONE - is_lt.expr(), + InsnKind::BLTU => (&is_lt).expr(), + InsnKind::BGEU => 1 - (&is_lt).expr(), _ => unreachable!("Unsupported instruction kind {:?}", I::INST_KIND), }; diff --git a/ceno_zkvm/src/instructions/riscv/div.rs b/ceno_zkvm/src/instructions/riscv/div.rs index e9cb2e4ca..266e5ae1f 100644 --- a/ceno_zkvm/src/instructions/riscv/div.rs +++ b/ceno_zkvm/src/instructions/riscv/div.rs @@ -10,7 +10,7 @@ use super::{ use crate::{ circuit_builder::CircuitBuilder, error::ZKVMError, - expression::Expression, + expression::{Expression, ToExpr}, gadgets::{IsLtConfig, IsZeroConfig}, instructions::Instruction, uint::Value, @@ -79,10 +79,10 @@ impl Instruction for ArithInstruction::TOTAL_BITS) - 1).into(), - outcome_value, + &is_zero, + &outcome_value, + Expression::from((1u64 << UInt::::TOTAL_BITS) - 1), + &outcome_value, )?; // remainder should be less than divisor if divisor != 0. @@ -98,8 +98,8 @@ impl Instruction for ArithInstruction::construct_circuit( diff --git a/ceno_zkvm/src/instructions/riscv/ecall/halt.rs b/ceno_zkvm/src/instructions/riscv/ecall/halt.rs index 47ed7b9b7..3d3e2ec75 100644 --- a/ceno_zkvm/src/instructions/riscv/ecall/halt.rs +++ b/ceno_zkvm/src/instructions/riscv/ecall/halt.rs @@ -2,7 +2,7 @@ use crate::{ chip_handler::RegisterChipOperations, circuit_builder::CircuitBuilder, error::ZKVMError, - expression::{ToExpr, WitIn}, + expression::{Expression, ToExpr, WitIn}, gadgets::AssertLTConfig, instructions::{ Instruction, @@ -47,11 +47,13 @@ impl Instruction for HaltInstruction { Some(EXIT_PC.into()), )?; + // let reg_: usize = ceno_emul::Platform::reg_arg0(); // read exit_code from arg0 (X10 register) let (_, lt_x10_cfg) = cb.register_read( || "read x10", - E::BaseField::from(ceno_emul::Platform::reg_arg0() as u64), - prev_x10_ts.expr(), + // TODO(Matthias): clean up. + &Expression::Constant(E::BaseField::from(ceno_emul::Platform::reg_arg0() as u64)), + prev_x10_ts, ecall_cfg.ts.expr() + Tracer::SUBCYCLE_RS2, exit_code, )?; diff --git a/ceno_zkvm/src/instructions/riscv/ecall_insn.rs b/ceno_zkvm/src/instructions/riscv/ecall_insn.rs index 3bd2faa1e..1e8571906 100644 --- a/ceno_zkvm/src/instructions/riscv/ecall_insn.rs +++ b/ceno_zkvm/src/instructions/riscv/ecall_insn.rs @@ -31,10 +31,10 @@ impl EcallInstructionConfig { let pc = cb.create_witin(|| "pc"); let ts = cb.create_witin(|| "cur_ts"); - cb.state_in(pc.expr(), ts.expr())?; + cb.state_in(pc, ts)?; cb.state_out( next_pc.map_or(pc.expr() + PC_STEP_SIZE, |next_pc| next_pc), - ts.expr() + (Tracer::SUBCYCLES_PER_INSN as usize), + ts.expr() + Tracer::SUBCYCLES_PER_INSN, )?; cb.lk_fetch(&InsnRecord::new( @@ -51,8 +51,8 @@ impl EcallInstructionConfig { // read syscall_id from x5 and write return value to x5 let (_, lt_x5_cfg) = cb.register_write( || "write x5", - E::BaseField::from(Platform::reg_ecall() as u64), - prev_x5_ts.expr(), + Platform::reg_ecall(), + prev_x5_ts, ts.expr() + Tracer::SUBCYCLE_RS1, syscall_id.clone(), syscall_ret_value.map_or(syscall_id, |v| v), diff --git a/ceno_zkvm/src/instructions/riscv/i_insn.rs b/ceno_zkvm/src/instructions/riscv/i_insn.rs index 65beb8c5f..32f748da1 100644 --- a/ceno_zkvm/src/instructions/riscv/i_insn.rs +++ b/ceno_zkvm/src/instructions/riscv/i_insn.rs @@ -28,7 +28,7 @@ impl IInstructionConfig { pub fn construct_circuit( circuit_builder: &mut CircuitBuilder, insn_kind: InsnKind, - imm: &Expression, + imm: impl ToExpr>, rs1_read: RegisterExpr, rd_written: RegisterExpr, branching: bool, @@ -49,7 +49,7 @@ impl IInstructionConfig { Some(rd.id.expr()), rs1.id.expr(), 0.into(), - imm.clone(), + imm.expr(), ))?; Ok(IInstructionConfig { vm_state, rs1, rd }) diff --git a/ceno_zkvm/src/instructions/riscv/im_insn.rs b/ceno_zkvm/src/instructions/riscv/im_insn.rs index 6727f6628..e51ffae99 100644 --- a/ceno_zkvm/src/instructions/riscv/im_insn.rs +++ b/ceno_zkvm/src/instructions/riscv/im_insn.rs @@ -26,7 +26,7 @@ impl IMInstructionConfig { pub fn construct_circuit( circuit_builder: &mut CircuitBuilder, insn_kind: InsnKind, - imm: &Expression, + imm: impl ToExpr>, rs1_read: RegisterExpr, memory_read: MemoryExpr, memory_addr: AddressExpr, @@ -49,7 +49,7 @@ impl IMInstructionConfig { Some(rd.id.expr()), rs1.id.expr(), 0.into(), - imm.clone(), + imm.expr(), ))?; Ok(IMInstructionConfig { diff --git a/ceno_zkvm/src/instructions/riscv/insn_base.rs b/ceno_zkvm/src/instructions/riscv/insn_base.rs index bcdae575f..c17eedf42 100644 --- a/ceno_zkvm/src/instructions/riscv/insn_base.rs +++ b/ceno_zkvm/src/instructions/riscv/insn_base.rs @@ -92,7 +92,7 @@ impl ReadRS1 { let (_, lt_cfg) = circuit_builder.register_read( || "read_rs1", id, - prev_ts.expr(), + prev_ts, cur_ts.expr() + Tracer::SUBCYCLE_RS1, rs1_read, )?; @@ -146,7 +146,7 @@ impl ReadRS2 { let (_, lt_cfg) = circuit_builder.register_read( || "read_rs2", id, - prev_ts.expr(), + prev_ts, cur_ts.expr() + Tracer::SUBCYCLE_RS2, rs2_read, )?; @@ -201,7 +201,7 @@ impl WriteRD { let (_, lt_cfg) = circuit_builder.register_write( || "write_rd", id, - prev_ts.expr(), + prev_ts, cur_ts.expr() + Tracer::SUBCYCLE_RD, prev_value.register_expr(), rd_written, @@ -420,10 +420,11 @@ impl MemAddr { .sum(); // Range check the middle bits, that is the low limb excluding the low bits. - let shift_right = E::BaseField::from(1 << Self::N_LOW_BITS) - .invert() - .unwrap() - .expr(); + // TODO(Matthias): division here seems very, very suspicious from a soundness perspective. + // TODO(Matthias): clean up. + let shift_right = Expression::Constant( + E::BaseField::from(1 << Self::N_LOW_BITS).invert().unwrap(), /* TODO: do something about this. */ + ); let mid_u14 = (&limbs[0] - low_sum) * shift_right; cb.assert_ux::<_, _, 14>(|| "mid_u14", mid_u14)?; @@ -477,6 +478,7 @@ mod test { ROMType, circuit_builder::{CircuitBuilder, ConstraintSystem}, error::ZKVMError, + expression::Expression, scheme::mock_prover::MockProver, witness::{LkMultiplicity, RowMajorMatrix}, }; @@ -535,9 +537,9 @@ mod test { assert_eq!(lkm[ROMType::U16 as usize].len(), 1); if is_ok { - cb.require_equal(|| "", mem_addr.expr_unaligned(), addr.into())?; - cb.require_equal(|| "", mem_addr.expr_align2(), (addr & !1).into())?; - cb.require_equal(|| "", mem_addr.expr_align4(), (addr & !3).into())?; + cb.require_equal(|| "", mem_addr.expr_unaligned(), Expression::from(addr))?; + cb.require_equal(|| "", mem_addr.expr_align2(), Expression::from(addr & !1))?; + cb.require_equal(|| "", mem_addr.expr_align4(), Expression::from(addr & !3))?; } MockProver::assert_with_expected_errors( &cb, diff --git a/ceno_zkvm/src/instructions/riscv/jump/jalr.rs b/ceno_zkvm/src/instructions/riscv/jump/jalr.rs index 0339a6b0a..b8778a7f7 100644 --- a/ceno_zkvm/src/instructions/riscv/jump/jalr.rs +++ b/ceno_zkvm/src/instructions/riscv/jump/jalr.rs @@ -51,7 +51,7 @@ impl Instruction for JalrInstruction { let i_insn = IInstructionConfig::construct_circuit( circuit_builder, InsnKind::JALR, - &imm.expr(), + imm, rs1_read.register_expr(), rd_written.register_expr(), true, diff --git a/ceno_zkvm/src/instructions/riscv/logic_imm/logic_imm_circuit.rs b/ceno_zkvm/src/instructions/riscv/logic_imm/logic_imm_circuit.rs index b6a8bb690..bcf05fd0f 100644 --- a/ceno_zkvm/src/instructions/riscv/logic_imm/logic_imm_circuit.rs +++ b/ceno_zkvm/src/instructions/riscv/logic_imm/logic_imm_circuit.rs @@ -87,7 +87,7 @@ impl LogicConfig { let i_insn = IInstructionConfig::::construct_circuit( cb, insn_kind, - &imm.value(), + imm.value(), rs1_read.register_expr(), rd_written.register_expr(), false, diff --git a/ceno_zkvm/src/instructions/riscv/memory/gadget.rs b/ceno_zkvm/src/instructions/riscv/memory/gadget.rs index 6539c1325..07aa5dac1 100644 --- a/ceno_zkvm/src/instructions/riscv/memory/gadget.rs +++ b/ceno_zkvm/src/instructions/riscv/memory/gadget.rs @@ -55,7 +55,7 @@ impl MemWordChange { .iter() .enumerate() .map(|(idx, byte)| byte.expr() << (idx * 8)) - .sum(), + .sum::>(), )?; Ok(bytes) @@ -80,27 +80,27 @@ impl MemWordChange { let u8_base_inv = E::BaseField::from(1 << 8).invert().unwrap(); cb.assert_ux::<_, _, 8>( || "rs2_limb[0].le_bytes[1]", - u8_base_inv.expr() * (&rs2_limbs[0] - rs2_limb_bytes[0].expr()), + Expression::Constant(u8_base_inv) * (&rs2_limbs[0] - rs2_limb_bytes[0].expr()), )?; // alloc a new witIn to cache degree 2 expression let expected_limb_change = cb.create_witin(|| "expected_limb_change"); cb.condition_require_equal( || "expected_limb_change = select(low_bits[0], rs2 - prev)", - low_bits[0].clone(), - expected_limb_change.expr(), - (rs2_limb_bytes[0].expr() - prev_limb_bytes[1].expr()) << 8, - rs2_limb_bytes[0].expr() - prev_limb_bytes[0].expr(), + &low_bits[0], + expected_limb_change, + (rs2_limb_bytes[0].expr() - prev_limb_bytes[1]) << 8, + rs2_limb_bytes[0].expr() - prev_limb_bytes[0], )?; // alloc a new witIn to cache degree 2 expression let expected_change = cb.create_witin(|| "expected_change"); cb.condition_require_equal( || "expected_change = select(low_bits[1], limb_change*2^16, limb_change)", - low_bits[1].clone(), - expected_change.expr(), + &low_bits[1], + expected_change, expected_limb_change.expr() << 16, - expected_limb_change.expr(), + expected_limb_change, )?; Ok(MemWordChange { diff --git a/ceno_zkvm/src/instructions/riscv/memory/load.rs b/ceno_zkvm/src/instructions/riscv/memory/load.rs index a06c17687..ccd68c9a5 100644 --- a/ceno_zkvm/src/instructions/riscv/memory/load.rs +++ b/ceno_zkvm/src/instructions/riscv/memory/load.rs @@ -97,7 +97,7 @@ impl Instruction for LoadInstruction Instruction for LoadInstruction Instruction for LoadInstruction::construct_circuit( circuit_builder, I::INST_KIND, - &imm.expr(), + imm, rs1_read.register_expr(), memory_read.memory_expr(), memory_addr.expr_align4(), diff --git a/ceno_zkvm/src/instructions/riscv/memory/store.rs b/ceno_zkvm/src/instructions/riscv/memory/store.rs index d1d941b97..2b2a39fb1 100644 --- a/ceno_zkvm/src/instructions/riscv/memory/store.rs +++ b/ceno_zkvm/src/instructions/riscv/memory/store.rs @@ -2,7 +2,7 @@ use crate::{ Value, circuit_builder::CircuitBuilder, error::ZKVMError, - expression::{ToExpr, WitIn}, + expression::WitIn, instructions::{ Instruction, riscv::{ @@ -87,7 +87,7 @@ impl Instruction circuit_builder.require_equal( || "memory_addr = rs1_read + imm", memory_addr.expr_unaligned(), - rs1_read.value() + imm.expr(), + rs1_read.value() + imm, )?; let (new_memory_value, word_change) = match I::INST_KIND { @@ -107,7 +107,7 @@ impl Instruction let s_insn = SInstructionConfig::::construct_circuit( circuit_builder, I::INST_KIND, - &imm.expr(), + imm, rs1_read.register_expr(), rs2_read.register_expr(), memory_addr.expr_align4(), diff --git a/ceno_zkvm/src/instructions/riscv/mul.rs b/ceno_zkvm/src/instructions/riscv/mul.rs index 58b410960..458de49e3 100644 --- a/ceno_zkvm/src/instructions/riscv/mul.rs +++ b/ceno_zkvm/src/instructions/riscv/mul.rs @@ -87,7 +87,7 @@ use goldilocks::SmallField; use crate::{ circuit_builder::CircuitBuilder, error::ZKVMError, - expression::Expression, + expression::{Expression, ToExpr}, gadgets::{IsEqualConfig, SignedExtendConfig}, instructions::{ Instruction, @@ -415,7 +415,7 @@ impl Signed { ) -> Result { cb.namespace(name_fn, |cb| { let is_negative = unsigned_val.is_negative(cb)?; - let val = unsigned_val.value() - (1u64 << BIT_WIDTH) * is_negative.expr(); + let val = unsigned_val.value() - (1u64 << BIT_WIDTH) * (&is_negative).expr(); Ok(Self { is_negative, val }) }) diff --git a/ceno_zkvm/src/instructions/riscv/s_insn.rs b/ceno_zkvm/src/instructions/riscv/s_insn.rs index dc133e894..4158bcd36 100644 --- a/ceno_zkvm/src/instructions/riscv/s_insn.rs +++ b/ceno_zkvm/src/instructions/riscv/s_insn.rs @@ -27,7 +27,7 @@ impl SInstructionConfig { pub fn construct_circuit( circuit_builder: &mut CircuitBuilder, insn_kind: InsnKind, - imm: &Expression, + imm: impl ToExpr>, rs1_read: RegisterExpr, rs2_read: RegisterExpr, memory_addr: AddressExpr, @@ -48,7 +48,7 @@ impl SInstructionConfig { None, rs1.id.expr(), rs2.id.expr(), - imm.clone(), + imm.expr(), ))?; // Memory diff --git a/ceno_zkvm/src/instructions/riscv/shift.rs b/ceno_zkvm/src/instructions/riscv/shift.rs index 5b8735311..32f1966f4 100644 --- a/ceno_zkvm/src/instructions/riscv/shift.rs +++ b/ceno_zkvm/src/instructions/riscv/shift.rs @@ -113,8 +113,8 @@ impl Instruction for ShiftLogicalInstru let (inflow, signed_extend_config) = match I::INST_KIND { InsnKind::SRA => { let signed_extend_config = rs1_read.is_negative(circuit_builder)?; - let msb_expr = signed_extend_config.expr(); - let ones = pow2_rs2_low5.expr() - Expression::ONE; + let msb_expr = (&signed_extend_config).expr(); + let ones = (&pow2_rs2_low5).expr() - 1; (msb_expr * ones, Some(signed_extend_config)) } InsnKind::SRL => (Expression::ZERO, None), diff --git a/ceno_zkvm/src/instructions/riscv/shift_imm.rs b/ceno_zkvm/src/instructions/riscv/shift_imm.rs index 4e2700914..9ae8ffe9d 100644 --- a/ceno_zkvm/src/instructions/riscv/shift_imm.rs +++ b/ceno_zkvm/src/instructions/riscv/shift_imm.rs @@ -107,7 +107,7 @@ impl Instruction for ShiftImmInstructio InsnKind::SRAI => { let is_rs1_neg = rs1_read.is_negative(circuit_builder)?; let ones = imm.expr() - 1; - (is_rs1_neg.expr() * ones, Some(is_rs1_neg)) + ((&is_rs1_neg).expr() * ones, Some(is_rs1_neg)) } InsnKind::SRLI => (Expression::ZERO, None), _ => unreachable!(), @@ -125,7 +125,7 @@ impl Instruction for ShiftImmInstructio let i_insn = IInstructionConfig::::construct_circuit( circuit_builder, I::INST_KIND, - &imm.expr(), + imm, rs1_read.register_expr(), rd_written.register_expr(), false, diff --git a/ceno_zkvm/src/instructions/riscv/slt.rs b/ceno_zkvm/src/instructions/riscv/slt.rs index 80fc69874..bc171bdb8 100644 --- a/ceno_zkvm/src/instructions/riscv/slt.rs +++ b/ceno_zkvm/src/instructions/riscv/slt.rs @@ -65,7 +65,7 @@ impl Instruction for SetLessThanInstruc InsnKind::SLT => { let signed_lt = SignedLtConfig::construct_circuit(cb, || "rs1 < rs2", &rs1_read, &rs2_read)?; - let rd_written = UInt::from_exprs_unchecked(vec![signed_lt.expr()]); + let rd_written = UInt::from_exprs_unchecked(vec![&signed_lt]); (SetLessThanDependencies::Slt { signed_lt }, rd_written) } InsnKind::SLTU => { @@ -76,7 +76,7 @@ impl Instruction for SetLessThanInstruc rs2_read.value(), UINT_LIMBS, )?; - let rd_written = UInt::from_exprs_unchecked(vec![is_lt.expr()]); + let rd_written = UInt::from_exprs_unchecked(vec![&is_lt]); (SetLessThanDependencies::Sltu { is_lt }, rd_written) } _ => unreachable!(), diff --git a/ceno_zkvm/src/instructions/riscv/slti.rs b/ceno_zkvm/src/instructions/riscv/slti.rs index 76894f7a0..be2315705 100644 --- a/ceno_zkvm/src/instructions/riscv/slti.rs +++ b/ceno_zkvm/src/instructions/riscv/slti.rs @@ -66,19 +66,21 @@ impl Instruction for SetLessThanImmInst InsnKind::SLTIU => (rs1_read.value(), None), InsnKind::SLTI => { let is_rs1_neg = rs1_read.is_negative(cb)?; - (rs1_read.to_field_expr(is_rs1_neg.expr()), Some(is_rs1_neg)) + ( + rs1_read.to_field_expr((&is_rs1_neg).expr()), + Some(is_rs1_neg), + ) } _ => unreachable!("Unsupported instruction kind {:?}", I::INST_KIND), }; - let lt = - IsLtConfig::construct_circuit(cb, || "rs1 < imm", value_expr, imm.expr(), UINT_LIMBS)?; - let rd_written = UInt::from_exprs_unchecked(vec![lt.expr()]); + let lt = IsLtConfig::construct_circuit(cb, || "rs1 < imm", value_expr, imm, UINT_LIMBS)?; + let rd_written = UInt::from_exprs_unchecked(vec![(<).expr()]); let i_insn = IInstructionConfig::::construct_circuit( cb, I::INST_KIND, - &imm.expr(), + imm, rs1_read.register_expr(), rd_written.register_expr(), false, diff --git a/ceno_zkvm/src/scheme/mock_prover.rs b/ceno_zkvm/src/scheme/mock_prover.rs index ccd8e0a07..a5cf738b7 100644 --- a/ceno_zkvm/src/scheme/mock_prover.rs +++ b/ceno_zkvm/src/scheme/mock_prover.rs @@ -1369,7 +1369,7 @@ mod tests { fn construct_circuit(cb: &mut CircuitBuilder) -> Result { let a = cb.create_witin(|| "a"); let b = cb.create_witin(|| "b"); - let lt_wtns = AssertLTConfig::construct_circuit(cb, || "lt", a.expr(), b.expr(), 1)?; + let lt_wtns = AssertLTConfig::construct_circuit(cb, || "lt", a, b, 1)?; Ok(Self { a, b, lt_wtns }) } diff --git a/ceno_zkvm/src/scheme/tests.rs b/ceno_zkvm/src/scheme/tests.rs index 13ef29a66..42f8630e3 100644 --- a/ceno_zkvm/src/scheme/tests.rs +++ b/ceno_zkvm/src/scheme/tests.rs @@ -65,7 +65,7 @@ impl Instruction for Test Result::<(), ZKVMError>::Ok(()) })?; (0..L).try_for_each(|_| { - cb.assert_ux::<_, _, 16>(|| "regid_in_range", reg_id.expr())?; + cb.assert_ux::<_, _, 16>(|| "regid_in_range", reg_id)?; Result::<(), ZKVMError>::Ok(()) })?; assert_eq!(cb.cs.lk_expressions.len(), L); diff --git a/ceno_zkvm/src/scheme/utils.rs b/ceno_zkvm/src/scheme/utils.rs index c8ec6453a..1e8e31276 100644 --- a/ceno_zkvm/src/scheme/utils.rs +++ b/ceno_zkvm/src/scheme/utils.rs @@ -19,7 +19,9 @@ use rayon::{ }; use crate::{ - expression::Expression, scheme::constants::MIN_PAR_SIZE, utils::next_pow2_instance_padding, + expression::{Expression, ToExpr}, + scheme::constants::MIN_PAR_SIZE, + utils::next_pow2_instance_padding, }; /// interleaving multiple mles into mles, and num_limbs indicate number of final limbs vector @@ -350,7 +352,7 @@ pub(crate) fn wit_infer_by_expr<'a, E: ExtensionField, const N: usize>( pub(crate) fn eval_by_expr( witnesses: &[E], challenges: &[E], - expr: &Expression, + expr: impl ToExpr>, ) -> E { eval_by_expr_with_fixed(&[], witnesses, challenges, expr) } @@ -359,9 +361,9 @@ pub(crate) fn eval_by_expr_with_fixed( fixed: &[E], witnesses: &[E], challenges: &[E], - expr: &Expression, + expr: impl ToExpr>, ) -> E { - expr.evaluate::( + expr.expr().evaluate::( &|f| fixed[f.0], &|witness_id| witnesses[witness_id as usize], &|scalar| scalar.into(), diff --git a/ceno_zkvm/src/uint.rs b/ceno_zkvm/src/uint.rs index 193d34f13..a3de505e4 100644 --- a/ceno_zkvm/src/uint.rs +++ b/ceno_zkvm/src/uint.rs @@ -113,7 +113,7 @@ impl UIntLimbs { .map(|i| { let w = cb.create_witin(|| format!("limb_{i}")); if is_check { - cb.assert_ux::<_, _, C>(|| format!("limb_{i}_in_{C}"), w.expr())?; + cb.assert_ux::<_, _, C>(|| format!("limb_{i}_in_{C}"), w)?; } // skip range check Ok(w) @@ -185,10 +185,10 @@ impl UIntLimbs { .map(|i| { let w = circuit_builder.create_witin(|| "wit for limb"); circuit_builder - .assert_ux::<_, _, C>(|| "range check", w.expr()) + .assert_ux::<_, _, C>(|| "range check", w) .unwrap(); circuit_builder - .require_zero(|| "create_witin_from_expr", w.expr() - &expr_limbs[i]) + .require_zero(|| "create_witin_from_expr", w - &expr_limbs[i]) .unwrap(); w }) @@ -315,9 +315,8 @@ impl UIntLimbs { chunk .iter() .zip(shift_pows.iter()) - .map(|(limb, shift)| shift * limb.expr()) - .reduce(|a, b| a + b) - .unwrap() + .map(|(&limb, shift)| shift * limb) + .sum::>() }) .collect_vec(); Ok(UIntLimbs::::from_exprs_unchecked(combined_limbs)) @@ -343,8 +342,8 @@ impl UIntLimbs { let limbs = (0..k) .map(|_| { let w = circuit_builder.create_witin(|| ""); - circuit_builder.assert_byte(|| "", w.expr()).unwrap(); - w.expr() + circuit_builder.assert_byte(|| "", w).unwrap(); + w }) .collect_vec(); let combined_limb = limbs @@ -355,19 +354,21 @@ impl UIntLimbs { .unwrap(); circuit_builder - .require_zero(|| "zero check", large_limb.expr() - combined_limb) + .require_zero(|| "zero check", large_limb - combined_limb) .unwrap(); limbs }) + .map(ToExpr::expr) .collect_vec(); UIntLimbs::::create_witin_from_exprs(circuit_builder, split_limbs) } - pub fn from_exprs_unchecked(expr_limbs: Vec>) -> Self { + pub fn from_exprs_unchecked(expr_limbs: Vec>>) -> Self { Self { limbs: UintLimb::Expression( expr_limbs .into_iter() + .map(ToExpr::expr) .chain(std::iter::repeat(Expression::ZERO)) .take(Self::NUM_LIMBS) .collect_vec(), @@ -479,10 +480,10 @@ impl UIntLimbs { )) } - pub fn to_field_expr(&self, is_neg: Expression) -> Expression { + pub fn to_field_expr(&self, is_neg: impl ToExpr>) -> Expression { // Convert two's complement representation into field arithmetic. // Example: 0xFFFF_FFFF = 2^32 - 1 --> shift --> -1 - self.value() - is_neg * (1_u64 << 32) + self.value() - (is_neg.expr() << 32) } } @@ -531,9 +532,9 @@ impl TryFrom<&[WitIn]> for UI } } -impl ToExpr for UIntLimbs { +impl ToExpr for &UIntLimbs { type Output = Vec>; - fn expr(&self) -> Vec> { + fn expr(self) -> Vec> { match &self.limbs { UintLimb::WitIn(limbs) => limbs .iter() diff --git a/ceno_zkvm/src/uint/arithmetic.rs b/ceno_zkvm/src/uint/arithmetic.rs index dfe33b076..6d0a8a9be 100644 --- a/ceno_zkvm/src/uint/arithmetic.rs +++ b/ceno_zkvm/src/uint/arithmetic.rs @@ -34,31 +34,30 @@ impl UIntLimbs { return Err(ZKVMError::CircuitError); }; carries.iter().enumerate().try_for_each(|(i, carry)| { - circuit_builder.assert_bit(|| format!("carry_{i}_in_as_bit"), carry.expr()) + circuit_builder.assert_bit(|| format!("carry_{i}_in_as_bit"), carry) })?; // perform add operation // c[i] = a[i] + b[i] + carry[i-1] - carry[i] * 2 ^ C c.limbs = UintLimb::Expression( - (self.expr()) + self.expr() .iter() .zip((*addend).iter()) .enumerate() .map(|(i, (a, b))| { let carries = c.carries.as_ref().unwrap(); - let carry = if i > 0 { carries.get(i - 1) } else { None }; + let carry = carries.get(i - 1); let next_carry = carries.get(i); - let mut limb_expr = a.clone() + b.clone(); - if carry.is_some() { - limb_expr = limb_expr.clone() + carry.unwrap().expr(); + let mut limb_expr = a + b; + if let Some(carry) = carry { + limb_expr += carry; } - if next_carry.is_some() { - limb_expr = limb_expr.clone() - next_carry.unwrap().expr() * Self::POW_OF_C; + if let Some(next_carry) = next_carry { + limb_expr -= next_carry.expr() * Self::POW_OF_C; } - circuit_builder - .assert_ux::<_, _, C>(|| format!("limb_{i}_in_{C}"), limb_expr.clone())?; + .assert_ux::<_, _, C>(|| format!("limb_{i}_in_{C}"), &limb_expr)?; Ok(limb_expr) }) .collect::>, ZKVMError>>()?, @@ -120,7 +119,7 @@ impl UIntLimbs { // with high limb, overall cell will be double let c_limbs: Vec = (0..num_limbs).try_fold(vec![], |mut c_limbs, i| { let limb = circuit_builder.create_witin(|| format!("limb_{i}")); - circuit_builder.assert_ux::<_, _, C>(|| format!("limb_{i}_in_{C}"), limb.expr())?; + circuit_builder.assert_ux::<_, _, C>(|| format!("limb_{i}_in_{C}"), limb)?; c_limbs.push(limb); Result::, ZKVMError>::Ok(c_limbs) })?; @@ -140,8 +139,8 @@ impl UIntLimbs { AssertLTConfig::construct_circuit( circuit_builder, || format!("carry_{i}_in_less_than"), - carry.expr(), - (Self::MAX_DEGREE_2_MUL_CARRY_VALUE as usize).into(), + carry, + Self::MAX_DEGREE_2_MUL_CARRY_VALUE, Self::MAX_DEGREE_2_MUL_CARRY_U16_LIMB, ) }) @@ -190,16 +189,16 @@ impl UIntLimbs { // constrain each limb with carry c_limbs.iter().enumerate().try_for_each(|(i, c_limb)| { - let carry = if i > 0 { c_carries.get(i - 1) } else { None }; + let carry = c_carries.get(i - 1); let next_carry = c_carries.get(i); - result_c[i] = result_c[i].clone() - c_limb.expr(); - if carry.is_some() { - result_c[i] = result_c[i].clone() + carry.unwrap().expr(); + result_c[i] -= c_limb; + if let Some(carry) = carry { + result_c[i] += carry; } - if next_carry.is_some() { - result_c[i] = result_c[i].clone() - next_carry.unwrap().expr() * Self::POW_OF_C; + if let Some(next_carry) = next_carry { + result_c[i] -= next_carry.expr() * Self::POW_OF_C; } - circuit_builder.require_zero(|| format!("mul_zero_{i}"), result_c[i].clone())?; + circuit_builder.require_zero(|| format!("mul_zero_{i}"), &result_c[i])?; Ok::<(), ZKVMError>(()) })?; @@ -667,8 +666,8 @@ mod tests { // overflow if overflow { - let overflow = uint_c.carries.unwrap().last().unwrap().expr(); - assert_eq!(eval_by_expr(&wit, &challenges, &overflow), E::ONE); + let &overflow = uint_c.carries.unwrap().last().unwrap(); + assert_eq!(eval_by_expr(&wit, &challenges, overflow), E::ONE); } else { // non-overflow case, the len of carries should be (NUM_CELLS - 1) assert_eq!(uint_c.carries.unwrap().len(), single_wit_size - 1) diff --git a/ceno_zkvm/src/uint/logic.rs b/ceno_zkvm/src/uint/logic.rs index 024d09d73..3867b48a1 100644 --- a/ceno_zkvm/src/uint/logic.rs +++ b/ceno_zkvm/src/uint/logic.rs @@ -3,8 +3,8 @@ use itertools::izip; use super::UIntLimbs; use crate::{ - ROMType, circuit_builder::CircuitBuilder, error::ZKVMError, expression::ToExpr, - tables::OpsTable, witness::LkMultiplicity, + ROMType, circuit_builder::CircuitBuilder, error::ZKVMError, tables::OpsTable, + witness::LkMultiplicity, }; // Only implemented for u8 limbs. @@ -19,7 +19,7 @@ impl UIntLimbs { c: &Self, ) -> Result<(), ZKVMError> { for (a_byte, b_byte, c_byte) in izip!(&a.limbs, &b.limbs, &c.limbs) { - cb.logic_u8(rom_type, a_byte.expr(), b_byte.expr(), c_byte.expr())?; + cb.logic_u8(rom_type, a_byte, b_byte, c_byte)?; } Ok(()) } From 4bd26c3a99d6ffffe2a44cdba98b71d65918ab93 Mon Sep 17 00:00:00 2001 From: Matthias Goergens Date: Thu, 12 Dec 2024 09:55:59 +0800 Subject: [PATCH 02/18] Fix --- ceno_zkvm/src/gadgets/is_zero.rs | 6 ++-- ceno_zkvm/src/instructions/riscv/mul.rs | 39 +++++++++++++++---------- 2 files changed, 27 insertions(+), 18 deletions(-) diff --git a/ceno_zkvm/src/gadgets/is_zero.rs b/ceno_zkvm/src/gadgets/is_zero.rs index d8d6f2f98..0d4900be1 100644 --- a/ceno_zkvm/src/gadgets/is_zero.rs +++ b/ceno_zkvm/src/gadgets/is_zero.rs @@ -129,9 +129,11 @@ impl IsEqualConfig { pub fn construct_non_equal, N: FnOnce() -> NR>( cb: &mut CircuitBuilder, name_fn: N, - a: Expression, - b: Expression, + a: impl ToExpr>, + b: impl ToExpr>, ) -> Result { + let a = a.expr(); + let b = b.expr(); Ok(IsEqualConfig(IsZeroConfig::construct_non_zero( cb, name_fn, diff --git a/ceno_zkvm/src/instructions/riscv/mul.rs b/ceno_zkvm/src/instructions/riscv/mul.rs index 458de49e3..8abba6072 100644 --- a/ceno_zkvm/src/instructions/riscv/mul.rs +++ b/ceno_zkvm/src/instructions/riscv/mul.rs @@ -198,9 +198,9 @@ impl Instruction for MulhInstructionBas let prod_low = UInt::new(|| "prod_low", circuit_builder)?; ( - rs1_signed.expr(), - rs2_signed.expr(), - rd_signed.expr(), + (&rs1_signed).expr(), + (&rs2_signed).expr(), + (&rd_signed).expr(), MulhSignDependencies::SS { rs1_signed, rs2_signed, @@ -213,12 +213,11 @@ impl Instruction for MulhInstructionBas InsnKind::MULHU => { let prod_low = UInt::new(|| "prod_low", circuit_builder)?; // constrain that rd does not represent 2^32 - 1 - let rd_avoid = Expression::::from(u32::MAX); let constrain_rd = IsEqualConfig::construct_non_equal( circuit_builder, || "constrain_rd", rd_written.value(), - rd_avoid, + u32::MAX, )?; ( @@ -231,13 +230,12 @@ impl Instruction for MulhInstructionBas } InsnKind::MUL => { // constrain that prod_hi does not represent 2^32 - 1 - let prod_hi_avoid = Expression::::from(u32::MAX); let prod_hi = UInt::new(|| "prod_hi", circuit_builder)?; let constrain_rd = IsEqualConfig::construct_non_equal( circuit_builder, || "constrain_prod_hi", prod_hi.value(), - prod_hi_avoid, + u32::MAX, )?; ( @@ -255,18 +253,17 @@ impl Instruction for MulhInstructionBas let prod_low = UInt::new(|| "prod_low", circuit_builder)?; // constrain that (signed) rd does not represent 2^31 - 1 - let rd_avoid = Expression::::from(i32::MAX); let constrain_rd = IsEqualConfig::construct_non_equal( circuit_builder, || "constrain_rd", - rd_signed.expr(), - rd_avoid, + &rd_signed, + i32::MAX, )?; ( - rs1_signed.expr(), + (&rs1_signed).expr(), rs2_read.value(), - rd_signed.expr(), + (&rd_signed).expr(), MulhSignDependencies::SU { rs1_signed, rd_signed, @@ -407,6 +404,20 @@ struct Signed { val: Expression, } +impl ToExpr for &Signed { + type Output = Expression; + fn expr(self) -> Expression { + self.val.clone() + } +} + +impl ToExpr for Signed { + type Output = Expression; + fn expr(self) -> Expression { + self.val + } +} + impl Signed { pub fn construct_circuit + Display + Clone, N: FnOnce() -> NR>( cb: &mut CircuitBuilder, @@ -436,10 +447,6 @@ impl Signed { Ok(signed_val) } - - pub fn expr(&self) -> Expression { - self.val.clone() - } } #[cfg(test)] From 5744546b6e32e3b47d443d618a4410ea949a43e0 Mon Sep 17 00:00:00 2001 From: Matthias Goergens Date: Thu, 12 Dec 2024 09:58:14 +0800 Subject: [PATCH 03/18] Output --- ceno_zkvm/src/expression.rs | 14 +++++++------- ceno_zkvm/src/instructions/riscv/mul.rs | 4 ++-- 2 files changed, 9 insertions(+), 9 deletions(-) diff --git a/ceno_zkvm/src/expression.rs b/ceno_zkvm/src/expression.rs index ca22002ab..30c4dcba5 100644 --- a/ceno_zkvm/src/expression.rs +++ b/ceno_zkvm/src/expression.rs @@ -781,49 +781,49 @@ pub trait ToExpr { impl ToExpr for Expression { type Output = Expression; - fn expr(self) -> Expression { + fn expr(self) -> Self::Output { self } } impl ToExpr for &Expression { type Output = Expression; - fn expr(self) -> Expression { + fn expr(self) -> Self::Output { self.clone() } } impl ToExpr for WitIn { type Output = Expression; - fn expr(self) -> Expression { + fn expr(self) -> Self::Output { Expression::WitIn(self.id) } } impl ToExpr for &WitIn { type Output = Expression; - fn expr(self) -> Expression { + fn expr(self) -> Self::Output { Expression::WitIn(self.id) } } impl ToExpr for Fixed { type Output = Expression; - fn expr(self) -> Expression { + fn expr(self) -> Self::Output { Expression::Fixed(self) } } impl ToExpr for &Fixed { type Output = Expression; - fn expr(self) -> Expression { + fn expr(self) -> Self::Output { Expression::Fixed(*self) } } impl ToExpr for Instance { type Output = Expression; - fn expr(self) -> Expression { + fn expr(self) -> Self::Output { Expression::Instance(self) } } diff --git a/ceno_zkvm/src/instructions/riscv/mul.rs b/ceno_zkvm/src/instructions/riscv/mul.rs index 8abba6072..185503741 100644 --- a/ceno_zkvm/src/instructions/riscv/mul.rs +++ b/ceno_zkvm/src/instructions/riscv/mul.rs @@ -406,14 +406,14 @@ struct Signed { impl ToExpr for &Signed { type Output = Expression; - fn expr(self) -> Expression { + fn expr(self) -> Self::Output { self.val.clone() } } impl ToExpr for Signed { type Output = Expression; - fn expr(self) -> Expression { + fn expr(self) -> Self::Output { self.val } } From fd54041c66d467b5b71f9fcafdb7577fe1082905 Mon Sep 17 00:00:00 2001 From: Matthias Goergens Date: Thu, 12 Dec 2024 10:03:46 +0800 Subject: [PATCH 04/18] Use Output everywhere --- ceno_zkvm/src/uint.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ceno_zkvm/src/uint.rs b/ceno_zkvm/src/uint.rs index a3de505e4..0220cd0f3 100644 --- a/ceno_zkvm/src/uint.rs +++ b/ceno_zkvm/src/uint.rs @@ -534,7 +534,7 @@ impl TryFrom<&[WitIn]> for UI impl ToExpr for &UIntLimbs { type Output = Vec>; - fn expr(self) -> Vec> { + fn expr(self) -> Self::Output { match &self.limbs { UintLimb::WitIn(limbs) => limbs .iter() From a9a30be48806a6bffe1ad450300c6d9289a95190 Mon Sep 17 00:00:00 2001 From: Matthias Goergens Date: Thu, 12 Dec 2024 10:27:44 +0800 Subject: [PATCH 05/18] Use sum --- ceno_zkvm/src/uint/arithmetic.rs | 9 +++------ 1 file changed, 3 insertions(+), 6 deletions(-) diff --git a/ceno_zkvm/src/uint/arithmetic.rs b/ceno_zkvm/src/uint/arithmetic.rs index 6d0a8a9be..5ea0f8fa7 100644 --- a/ceno_zkvm/src/uint/arithmetic.rs +++ b/ceno_zkvm/src/uint/arithmetic.rs @@ -268,18 +268,15 @@ impl UIntLimbs { let n_limbs = Self::NUM_LIMBS; let (is_equal_per_limb, diff_inv_per_limb): (Vec, Vec) = izip!(&self.limbs, &rhs.limbs) - .map(|(a, b)| circuit_builder.is_equal(a.expr(), b.expr())) + .map(|(a, b)| circuit_builder.is_equal(a, b)) .collect::, ZKVMError>>()? .into_iter() .unzip(); - let sum_expr = is_equal_per_limb - .iter() - .fold(Expression::ZERO, |acc, flag| acc.clone() + flag.expr()); + let sum_expr = is_equal_per_limb.iter().map(ToExpr::expr).sum(); let sum_flag = WitIn::from_expr(|| "sum_flag", circuit_builder, sum_expr, false)?; - let (is_equal, diff_inv) = - circuit_builder.is_equal(sum_flag.expr(), Expression::from(n_limbs))?; + let (is_equal, diff_inv) = circuit_builder.is_equal(sum_flag, n_limbs)?; Ok(IsEqualConfig { is_equal_per_limb, diff_inv_per_limb, From 61ce9e4eb8c4495781b384c8da88dd65e80bbd26 Mon Sep 17 00:00:00 2001 From: Matthias Goergens Date: Thu, 12 Dec 2024 10:31:28 +0800 Subject: [PATCH 06/18] Extra --- mpcs/src/sum_check/classic/coeff.rs | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/mpcs/src/sum_check/classic/coeff.rs b/mpcs/src/sum_check/classic/coeff.rs index 12f46880f..36d5390a6 100644 --- a/mpcs/src/sum_check/classic/coeff.rs +++ b/mpcs/src/sum_check/classic/coeff.rs @@ -49,9 +49,7 @@ impl ClassicSumCheckRoundMessage for Coefficients { } fn sum(&self) -> E { - self[1..] - .iter() - .fold(self[0].double(), |acc, coeff| acc + coeff) + self[..].iter().sum() } fn evaluate(&self, _: &Self::Auxiliary, challenge: &E) -> E { From 1e52d68832228a56eb6b57145da97f6f21a951d4 Mon Sep 17 00:00:00 2001 From: Matthias Goergens Date: Thu, 12 Dec 2024 10:31:28 +0800 Subject: [PATCH 07/18] Use `sum` instead of writing our own --- ceno_zkvm/src/uint/arithmetic.rs | 4 +--- mpcs/src/sum_check/classic/coeff.rs | 4 +--- 2 files changed, 2 insertions(+), 6 deletions(-) diff --git a/ceno_zkvm/src/uint/arithmetic.rs b/ceno_zkvm/src/uint/arithmetic.rs index dfe33b076..729021be9 100644 --- a/ceno_zkvm/src/uint/arithmetic.rs +++ b/ceno_zkvm/src/uint/arithmetic.rs @@ -274,9 +274,7 @@ impl UIntLimbs { .into_iter() .unzip(); - let sum_expr = is_equal_per_limb - .iter() - .fold(Expression::ZERO, |acc, flag| acc.clone() + flag.expr()); + let sum_expr = is_equal_per_limb.iter().map(ToExpr::expr).sum(); let sum_flag = WitIn::from_expr(|| "sum_flag", circuit_builder, sum_expr, false)?; let (is_equal, diff_inv) = diff --git a/mpcs/src/sum_check/classic/coeff.rs b/mpcs/src/sum_check/classic/coeff.rs index 12f46880f..36d5390a6 100644 --- a/mpcs/src/sum_check/classic/coeff.rs +++ b/mpcs/src/sum_check/classic/coeff.rs @@ -49,9 +49,7 @@ impl ClassicSumCheckRoundMessage for Coefficients { } fn sum(&self) -> E { - self[1..] - .iter() - .fold(self[0].double(), |acc, coeff| acc + coeff) + self[..].iter().sum() } fn evaluate(&self, _: &Self::Auxiliary, challenge: &E) -> E { From 78c43b5f0e39c5053384adf73d55303018e64893 Mon Sep 17 00:00:00 2001 From: Matthias Goergens Date: Thu, 12 Dec 2024 11:01:28 +0800 Subject: [PATCH 08/18] Cleaned up --- ceno_zkvm/src/instructions/riscv/ecall/halt.rs | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/ceno_zkvm/src/instructions/riscv/ecall/halt.rs b/ceno_zkvm/src/instructions/riscv/ecall/halt.rs index 3d3e2ec75..eb5c0849d 100644 --- a/ceno_zkvm/src/instructions/riscv/ecall/halt.rs +++ b/ceno_zkvm/src/instructions/riscv/ecall/halt.rs @@ -2,7 +2,7 @@ use crate::{ chip_handler::RegisterChipOperations, circuit_builder::CircuitBuilder, error::ZKVMError, - expression::{Expression, ToExpr, WitIn}, + expression::{ToExpr, WitIn}, gadgets::AssertLTConfig, instructions::{ Instruction, @@ -51,8 +51,7 @@ impl Instruction for HaltInstruction { // read exit_code from arg0 (X10 register) let (_, lt_x10_cfg) = cb.register_read( || "read x10", - // TODO(Matthias): clean up. - &Expression::Constant(E::BaseField::from(ceno_emul::Platform::reg_arg0() as u64)), + ceno_emul::Platform::reg_arg0(), prev_x10_ts, ecall_cfg.ts.expr() + Tracer::SUBCYCLE_RS2, exit_code, From 4c597d98a84e7e9b005b6d22cae04847bc751bcb Mon Sep 17 00:00:00 2001 From: Matthias Goergens Date: Thu, 12 Dec 2024 11:12:22 +0800 Subject: [PATCH 09/18] Fix --- mpcs/src/sum_check/classic/coeff.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mpcs/src/sum_check/classic/coeff.rs b/mpcs/src/sum_check/classic/coeff.rs index 36d5390a6..10d5c1c20 100644 --- a/mpcs/src/sum_check/classic/coeff.rs +++ b/mpcs/src/sum_check/classic/coeff.rs @@ -49,7 +49,7 @@ impl ClassicSumCheckRoundMessage for Coefficients { } fn sum(&self) -> E { - self[..].iter().sum() + self[0] + self[..].iter().sum::() } fn evaluate(&self, _: &Self::Auxiliary, challenge: &E) -> E { From 709c9f06b6e6532593fedcbe68ed3e5194ff918e Mon Sep 17 00:00:00 2001 From: Matthias Goergens Date: Thu, 12 Dec 2024 17:38:06 +0800 Subject: [PATCH 10/18] Tone down --- ceno_zkvm/src/instructions/riscv/insn_base.rs | 8 +++----- 1 file changed, 3 insertions(+), 5 deletions(-) diff --git a/ceno_zkvm/src/instructions/riscv/insn_base.rs b/ceno_zkvm/src/instructions/riscv/insn_base.rs index 8f1fd291a..43f1d774b 100644 --- a/ceno_zkvm/src/instructions/riscv/insn_base.rs +++ b/ceno_zkvm/src/instructions/riscv/insn_base.rs @@ -420,11 +420,9 @@ impl MemAddr { .sum(); // Range check the middle bits, that is the low limb excluding the low bits. - // TODO(Matthias): division here seems very, very suspicious from a soundness perspective. - // TODO(Matthias): clean up. - let shift_right = Expression::Constant( - E::BaseField::from(1 << Self::N_LOW_BITS).invert().unwrap(), /* TODO: do something about this. */ - ); + // TODO(Matthias): division here seems suspicious from a soundness perspective. + let shift_right = + Expression::Constant(E::BaseField::from(1 << Self::N_LOW_BITS).invert().unwrap()); let mid_u14 = (&limbs[0] - low_sum) * shift_right; cb.assert_ux::<_, _, 14>(|| "mid_u14", mid_u14)?; From 1059a0da118480127f7713be1752e1d8a57a7b0b Mon Sep 17 00:00:00 2001 From: Matthias Goergens Date: Thu, 12 Dec 2024 17:51:25 +0800 Subject: [PATCH 11/18] Copy --- ceno_zkvm/src/gadgets/signed_ext.rs | 2 +- ceno_zkvm/src/instructions/riscv/mul.rs | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/ceno_zkvm/src/gadgets/signed_ext.rs b/ceno_zkvm/src/gadgets/signed_ext.rs index a446687b5..8d642955b 100644 --- a/ceno_zkvm/src/gadgets/signed_ext.rs +++ b/ceno_zkvm/src/gadgets/signed_ext.rs @@ -9,7 +9,7 @@ use crate::{ use ff_ext::ExtensionField; use std::{marker::PhantomData, mem::MaybeUninit}; -#[derive(Debug)] +#[derive(Copy, Clone, Debug)] pub struct SignedExtendConfig { /// most significant bit msb: WitIn, diff --git a/ceno_zkvm/src/instructions/riscv/mul.rs b/ceno_zkvm/src/instructions/riscv/mul.rs index 44d54b6a5..d3ae5e8b2 100644 --- a/ceno_zkvm/src/instructions/riscv/mul.rs +++ b/ceno_zkvm/src/instructions/riscv/mul.rs @@ -426,7 +426,7 @@ impl Signed { ) -> Result { cb.namespace(name_fn, |cb| { let is_negative = unsigned_val.is_negative(cb)?; - let val = unsigned_val.value() - (1u64 << BIT_WIDTH) * (&is_negative).expr(); + let val = unsigned_val.value() - (is_negative.expr() << BIT_WIDTH); Ok(Self { is_negative, val }) }) From 0869fc333df0a113aa41a4a61e859f4bf2f34fbf Mon Sep 17 00:00:00 2001 From: Matthias Goergens Date: Thu, 12 Dec 2024 17:54:30 +0800 Subject: [PATCH 12/18] Simpler --- ceno_zkvm/src/gadgets/is_zero.rs | 2 ++ ceno_zkvm/src/instructions/riscv/branch/beq_circuit.rs | 4 ++-- 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/ceno_zkvm/src/gadgets/is_zero.rs b/ceno_zkvm/src/gadgets/is_zero.rs index 0d4900be1..676693efc 100644 --- a/ceno_zkvm/src/gadgets/is_zero.rs +++ b/ceno_zkvm/src/gadgets/is_zero.rs @@ -10,6 +10,7 @@ use crate::{ set_val, }; +#[derive(Clone, Copy, Debug)] pub struct IsZeroConfig { is_zero: Option, inverse: WitIn, @@ -94,6 +95,7 @@ impl IsZeroConfig { } } +#[derive(Clone, Copy, Debug)] pub struct IsEqualConfig(IsZeroConfig); impl ToExpr for IsEqualConfig { diff --git a/ceno_zkvm/src/instructions/riscv/branch/beq_circuit.rs b/ceno_zkvm/src/instructions/riscv/branch/beq_circuit.rs index 249231d5a..bcf9975e1 100644 --- a/ceno_zkvm/src/instructions/riscv/branch/beq_circuit.rs +++ b/ceno_zkvm/src/instructions/riscv/branch/beq_circuit.rs @@ -49,8 +49,8 @@ impl Instruction for BeqCircuit { )?; let branch_taken_bit = match I::INST_KIND { - InsnKind::BEQ => (&equal).expr(), - InsnKind::BNE => 1 - (&equal).expr(), + InsnKind::BEQ => equal.expr(), + InsnKind::BNE => 1 - equal.expr(), _ => unreachable!("Unsupported instruction kind {:?}", I::INST_KIND), }; From b0ac76d6cd60b03410183ac9e16a5ea7babab77a Mon Sep 17 00:00:00 2001 From: Matthias Goergens Date: Thu, 12 Dec 2024 17:57:07 +0800 Subject: [PATCH 13/18] Simpler --- ceno_zkvm/src/instructions/riscv/div.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ceno_zkvm/src/instructions/riscv/div.rs b/ceno_zkvm/src/instructions/riscv/div.rs index 266e5ae1f..59ab346b5 100644 --- a/ceno_zkvm/src/instructions/riscv/div.rs +++ b/ceno_zkvm/src/instructions/riscv/div.rs @@ -98,7 +98,7 @@ impl Instruction for ArithInstruction Date: Thu, 12 Dec 2024 18:16:20 +0800 Subject: [PATCH 14/18] Shifting --- ceno_zkvm/src/expression.rs | 58 ++++++++++++++++++- ceno_zkvm/src/instructions/riscv/insn_base.rs | 6 +- .../src/instructions/riscv/memory/gadget.rs | 4 +- 3 files changed, 59 insertions(+), 9 deletions(-) diff --git a/ceno_zkvm/src/expression.rs b/ceno_zkvm/src/expression.rs index 30c4dcba5..bcb22c47d 100644 --- a/ceno_zkvm/src/expression.rs +++ b/ceno_zkvm/src/expression.rs @@ -5,7 +5,7 @@ use std::{ fmt::Display, iter::{Product, Sum}, mem::MaybeUninit, - ops::{Add, AddAssign, Deref, Mul, MulAssign, Neg, Shl, ShlAssign, Sub, SubAssign}, + ops::{Add, AddAssign, Deref, Div, Mul, MulAssign, Neg, Shl, ShlAssign, Shr, ShrAssign, Sub, SubAssign}, }; use ceno_emul::InsnKind; @@ -361,6 +361,35 @@ impl ShlAssign for Expression { } } +// + +impl Shr for Expression { + type Output = Expression; + fn shr(self, rhs: usize) -> Expression { + self / (1_usize << rhs) + } +} + +impl Shr for &Expression { + type Output = Expression; + fn shr(self, rhs: usize) -> Expression { + self.clone() >> rhs + } +} + +impl Shr for &mut Expression { + type Output = Expression; + fn shr(self, rhs: usize) -> Expression { + self.clone() >> rhs + } +} + +impl ShrAssign for Expression { + fn shr_assign(&mut self, rhs: usize) { + *self = self.clone() >> rhs; + } +} + impl Sum for Expression { fn sum>>(iter: I) -> Expression { iter.fold(Expression::ZERO, |acc, x| acc + x) @@ -730,6 +759,33 @@ impl Mul for Expression { } } + +macro_rules! div_instances { + (($($t:ty),*)) => { + $( + + impl Div<$t> for Expression { + type Output = Expression; + #[allow(clippy::suspicious_arithmetic_impl)] + fn div(self, rhs: $t) -> Expression { + let reduced = (rhs as i128).rem_euclid(E::BaseField::MODULUS_U64 as i128) as u64; + self * E::BaseField::from(reduced).invert().unwrap().to_canonical_u64() + } + } + + impl Div<$t> for &Expression { + type Output = Expression; + #[allow(clippy::suspicious_arithmetic_impl)] + fn div(self, rhs: $t) -> Expression { + let reduced = (rhs as i128).rem_euclid(E::BaseField::MODULUS_U64 as i128) as u64; + self * E::BaseField::from(reduced).invert().unwrap().to_canonical_u64() + } + } + )* + }; +} +div_instances!((u8, u16, u32, u64, usize, i8, i16, i32, i64, isize)); + #[derive(Clone, Debug, Copy)] pub struct WitIn { pub id: WitnessId, diff --git a/ceno_zkvm/src/instructions/riscv/insn_base.rs b/ceno_zkvm/src/instructions/riscv/insn_base.rs index 43f1d774b..4ba3d7a9e 100644 --- a/ceno_zkvm/src/instructions/riscv/insn_base.rs +++ b/ceno_zkvm/src/instructions/riscv/insn_base.rs @@ -1,5 +1,4 @@ use ceno_emul::{StepRecord, Word}; -use ff::Field; use ff_ext::ExtensionField; use itertools::Itertools; @@ -420,10 +419,7 @@ impl MemAddr { .sum(); // Range check the middle bits, that is the low limb excluding the low bits. - // TODO(Matthias): division here seems suspicious from a soundness perspective. - let shift_right = - Expression::Constant(E::BaseField::from(1 << Self::N_LOW_BITS).invert().unwrap()); - let mid_u14 = (&limbs[0] - low_sum) * shift_right; + let mid_u14 = (&limbs[0] - low_sum) >> Self::N_LOW_BITS; cb.assert_ux::<_, _, 14>(|| "mid_u14", mid_u14)?; // Range check the high limb. diff --git a/ceno_zkvm/src/instructions/riscv/memory/gadget.rs b/ceno_zkvm/src/instructions/riscv/memory/gadget.rs index 07aa5dac1..79efcd81c 100644 --- a/ceno_zkvm/src/instructions/riscv/memory/gadget.rs +++ b/ceno_zkvm/src/instructions/riscv/memory/gadget.rs @@ -8,7 +8,6 @@ use crate::{ witness::LkMultiplicity, }; use ceno_emul::StepRecord; -use ff::Field; use ff_ext::ExtensionField; use itertools::izip; use std::mem::MaybeUninit; @@ -77,10 +76,9 @@ impl MemWordChange { // extract the least significant byte from u16 limb let rs2_limb_bytes = alloc_bytes(cb, "rs2_limb[0]", 1)?; - let u8_base_inv = E::BaseField::from(1 << 8).invert().unwrap(); cb.assert_ux::<_, _, 8>( || "rs2_limb[0].le_bytes[1]", - Expression::Constant(u8_base_inv) * (&rs2_limbs[0] - rs2_limb_bytes[0].expr()), + (&rs2_limbs[0] - rs2_limb_bytes[0].expr()) >> 8, )?; // alloc a new witIn to cache degree 2 expression From 42e4edf0a8744ff67c356cf16c96971f7104d420 Mon Sep 17 00:00:00 2001 From: Matthias Goergens Date: Thu, 12 Dec 2024 18:18:24 +0800 Subject: [PATCH 15/18] Div and Shr --- ceno_zkvm/src/expression.rs | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/ceno_zkvm/src/expression.rs b/ceno_zkvm/src/expression.rs index bcb22c47d..78ca0234b 100644 --- a/ceno_zkvm/src/expression.rs +++ b/ceno_zkvm/src/expression.rs @@ -5,7 +5,10 @@ use std::{ fmt::Display, iter::{Product, Sum}, mem::MaybeUninit, - ops::{Add, AddAssign, Deref, Div, Mul, MulAssign, Neg, Shl, ShlAssign, Shr, ShrAssign, Sub, SubAssign}, + ops::{ + Add, AddAssign, Deref, Div, Mul, MulAssign, Neg, Shl, ShlAssign, Shr, ShrAssign, Sub, + SubAssign, + }, }; use ceno_emul::InsnKind; @@ -759,7 +762,6 @@ impl Mul for Expression { } } - macro_rules! div_instances { (($($t:ty),*)) => { $( From 91e1a4649c990aaaf1c519d8af78f22ca86ea615 Mon Sep 17 00:00:00 2001 From: Matthias Goergens Date: Thu, 12 Dec 2024 18:21:28 +0800 Subject: [PATCH 16/18] Simpler --- ceno_zkvm/src/instructions/riscv/div.rs | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/ceno_zkvm/src/instructions/riscv/div.rs b/ceno_zkvm/src/instructions/riscv/div.rs index 59ab346b5..d7518b54b 100644 --- a/ceno_zkvm/src/instructions/riscv/div.rs +++ b/ceno_zkvm/src/instructions/riscv/div.rs @@ -79,9 +79,9 @@ impl Instruction for ArithInstruction::TOTAL_BITS) - 1), + (1u64 << UInt::::TOTAL_BITS) - 1, &outcome_value, )?; From 0a11f3aabadb913239c1effa914cefab8b1b4df2 Mon Sep 17 00:00:00 2001 From: Matthias Goergens Date: Thu, 12 Dec 2024 18:22:21 +0800 Subject: [PATCH 17/18] Clippy --- ceno_zkvm/src/gadgets/is_lt.rs | 4 ++-- ceno_zkvm/src/instructions/riscv/div.rs | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/ceno_zkvm/src/gadgets/is_lt.rs b/ceno_zkvm/src/gadgets/is_lt.rs index b5de99d13..0cb05d5ef 100644 --- a/ceno_zkvm/src/gadgets/is_lt.rs +++ b/ceno_zkvm/src/gadgets/is_lt.rs @@ -344,8 +344,8 @@ impl InnerSignedLtConfig { let is_rhs_neg = rhs.is_negative(cb)?; // Convert to field arithmetic. - let lhs_value = lhs.to_field_expr(&is_lhs_neg); - let rhs_value = rhs.to_field_expr(&is_rhs_neg); + let lhs_value = lhs.to_field_expr(is_lhs_neg); + let rhs_value = rhs.to_field_expr(is_rhs_neg); let config = InnerLtConfig::construct_circuit( cb, format!("{name} (lhs < rhs)"), diff --git a/ceno_zkvm/src/instructions/riscv/div.rs b/ceno_zkvm/src/instructions/riscv/div.rs index d7518b54b..404250459 100644 --- a/ceno_zkvm/src/instructions/riscv/div.rs +++ b/ceno_zkvm/src/instructions/riscv/div.rs @@ -10,7 +10,7 @@ use super::{ use crate::{ circuit_builder::CircuitBuilder, error::ZKVMError, - expression::{Expression, ToExpr}, + expression::ToExpr, gadgets::{IsLtConfig, IsZeroConfig}, instructions::Instruction, uint::Value, From 3c9e297ee15c6bfc4dacc91f6feef5c2d627ca81 Mon Sep 17 00:00:00 2001 From: Matthias Goergens Date: Thu, 12 Dec 2024 18:25:40 +0800 Subject: [PATCH 18/18] Fewer refereces --- ceno_zkvm/src/instructions/riscv/shift_imm.rs | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/ceno_zkvm/src/instructions/riscv/shift_imm.rs b/ceno_zkvm/src/instructions/riscv/shift_imm.rs index 166d18875..3f02b5603 100644 --- a/ceno_zkvm/src/instructions/riscv/shift_imm.rs +++ b/ceno_zkvm/src/instructions/riscv/shift_imm.rs @@ -106,8 +106,8 @@ impl Instruction for ShiftImmInstructio let (inflow, is_lt_config) = match I::INST_KIND { InsnKind::SRAI => { let is_rs1_neg = rs1_read.is_negative(circuit_builder)?; - let ones = imm.expr() - 1; - ((&is_rs1_neg).expr() * ones, Some(is_rs1_neg)) + let ones: Expression = imm.expr() - 1; + (is_rs1_neg.expr() * ones, Some(is_rs1_neg)) } InsnKind::SRLI => (Expression::ZERO, None), _ => unreachable!(),