Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
47 changes: 42 additions & 5 deletions ceno_zkvm/src/instructions/riscv/insn_base.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ use ff_ext::{ExtensionField, FieldInto, SmallField};
use itertools::Itertools;
use p3::field::{Field, FieldAlgebra};

use super::constants::{PC_STEP_SIZE, UINT_LIMBS, UInt};
use super::constants::{BIT_WIDTH, PC_STEP_SIZE, UINT_LIMBS, UInt};
use crate::{
chip_handler::{
AddressExpr, GlobalStateRegisterMachineChipOperations, MemoryChipOperations, MemoryExpr,
Expand Down Expand Up @@ -368,6 +368,7 @@ impl WriteMEM {
pub struct MemAddr<E: ExtensionField> {
addr: UInt<E>,
low_bits: Vec<WitIn>,
max_bits: usize,
}

impl<E: ExtensionField> MemAddr<E> {
Expand All @@ -393,6 +394,17 @@ impl<E: ExtensionField> MemAddr<E> {
self.addr.address_expr()
}

pub fn uint_unaligned(&self) -> UInt<E> {
UInt::from_exprs_unchecked(self.addr.expr())
}

pub fn uint_align2(&self) -> UInt<E> {
UInt::from_exprs_unchecked(vec![
self.addr.limbs[0].expr() - &self.low_bit_exprs()[0],
self.addr.limbs[1].expr(),
])
}

/// Represent the address aligned to 2 bytes.
pub fn expr_align2(&self) -> AddressExpr<E> {
self.addr.address_expr() - &self.low_bit_exprs()[0]
Expand All @@ -404,6 +416,14 @@ impl<E: ExtensionField> MemAddr<E> {
self.addr.address_expr() - &low_bits[1] * 2 - &low_bits[0]
}

pub fn uint_align4(&self) -> UInt<E> {
let low_bits = self.low_bit_exprs();
UInt::from_exprs_unchecked(vec![
self.addr.limbs[0].expr() - &low_bits[1] * 2 - &low_bits[0],
self.addr.limbs[1].expr(),
])
}

/// Expressions of the low bits of the address, LSB-first: [bit_0, bit_1].
pub fn low_bit_exprs(&self) -> Vec<Expression<E>> {
iter::repeat_n(Expression::ZERO, self.n_zeros())
Expand All @@ -412,6 +432,14 @@ impl<E: ExtensionField> MemAddr<E> {
}

fn construct(cb: &mut CircuitBuilder<E>, n_zeros: usize) -> Result<Self, ZKVMError> {
Self::construct_with_max_bits(cb, n_zeros, BIT_WIDTH)
}

pub fn construct_with_max_bits(
cb: &mut CircuitBuilder<E>,
n_zeros: usize,
max_bits: usize,
) -> Result<Self, ZKVMError> {
assert!(n_zeros <= Self::N_LOW_BITS);

// The address as two u16 limbs.
Expand Down Expand Up @@ -442,11 +470,19 @@ impl<E: ExtensionField> MemAddr<E> {
cb.assert_ux::<_, _, 14>(|| "mid_u14", mid_u14)?;

// Range check the high limb.
for high_u16 in limbs.iter().skip(1) {
cb.assert_ux::<_, _, 16>(|| "high_u16", high_u16.clone())?;
for (i, high_limb) in limbs.iter().enumerate().skip(1) {
cb.assert_ux_v2(
|| "high_limb",
high_limb.clone(),
(max_bits - i * 16).min(16),
)?;
}

Ok(MemAddr { addr, low_bits })
Ok(MemAddr {
addr,
low_bits,
max_bits,
})
}

pub fn assign_instance(
Expand All @@ -470,7 +506,8 @@ impl<E: ExtensionField> MemAddr<E> {
// Range check the high limb.
for i in 1..UINT_LIMBS {
let high_u16 = (addr >> (i * 16)) & 0xffff;
lkm.assert_ux::<16>(high_u16 as u64);
println!("assignment max bit {}", (self.max_bits - i * 16).min(16));
lkm.assert_ux_v2(high_u16 as u64, (self.max_bits - i * 16).min(16));
}

Ok(())
Expand Down
14 changes: 14 additions & 0 deletions ceno_zkvm/src/instructions/riscv/jump.rs
Original file line number Diff line number Diff line change
@@ -1,8 +1,22 @@
#[cfg(not(feature = "u16limb_circuit"))]
mod jal;
#[cfg(feature = "u16limb_circuit")]
mod jal_v2;

#[cfg(not(feature = "u16limb_circuit"))]
mod jalr;
#[cfg(feature = "u16limb_circuit")]
mod jalr_v2;

#[cfg(not(feature = "u16limb_circuit"))]
pub use jal::JalInstruction;
#[cfg(feature = "u16limb_circuit")]
pub use jal_v2::JalInstruction;

#[cfg(not(feature = "u16limb_circuit"))]
pub use jalr::JalrInstruction;
#[cfg(feature = "u16limb_circuit")]
pub use jalr_v2::JalrInstruction;

#[cfg(test)]
mod test;
113 changes: 113 additions & 0 deletions ceno_zkvm/src/instructions/riscv/jump/jal_v2.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,113 @@
use std::marker::PhantomData;

use ff_ext::ExtensionField;

use crate::{
circuit_builder::CircuitBuilder,
error::ZKVMError,
instructions::{
Instruction,
riscv::{
constants::{PC_BITS, UINT_BYTE_LIMBS, UInt8},
j_insn::JInstructionConfig,
},
},
structs::ProgramParams,
utils::split_to_u8,
witness::LkMultiplicity,
};
use ceno_emul::{InsnKind, PC_STEP_SIZE};
use gkr_iop::tables::{LookupTable, ops::XorTable};
use multilinear_extensions::{Expression, ToExpr};
use p3::field::FieldAlgebra;

pub struct JalConfig<E: ExtensionField> {
pub j_insn: JInstructionConfig<E>,
pub rd_written: UInt8<E>,
}

pub struct JalInstruction<E>(PhantomData<E>);

/// JAL instruction circuit
///
/// Note: does not validate that next_pc is aligned by 4-byte increments, which
/// should be verified by lookup argument of the next execution step against
/// the program table
///
/// Assumption: values for valid initial program counter must lie between
/// 2^20 and 2^32 - 2^20 + 2 inclusive, probably enforced by the static
/// program lookup table. If this assumption does not hold, then resulting
/// value for next_pc may not correctly wrap mod 2^32 because of the use
/// of native WitIn values for address space arithmetic.
impl<E: ExtensionField> Instruction<E> for JalInstruction<E> {
type InstructionConfig = JalConfig<E>;

fn name() -> String {
format!("{:?}", InsnKind::JAL)
}

fn construct_circuit(
circuit_builder: &mut CircuitBuilder<E>,
_params: &ProgramParams,
) -> Result<JalConfig<E>, ZKVMError> {
let rd_written = UInt8::new(|| "rd_written", circuit_builder)?;
let rd_exprs = rd_written.expr();

let j_insn = JInstructionConfig::construct_circuit(
circuit_builder,
InsnKind::JAL,
rd_written.register_expr(),
)?;

// constrain rd_exprs [PC_BITS .. u32::BITS] are all 0 via xor
let last_limb_bits = PC_BITS - UInt8::<E>::LIMB_BITS * (UInt8::<E>::NUM_LIMBS - 1);
let additional_bits =
(last_limb_bits..UInt8::<E>::LIMB_BITS).fold(0, |acc, x| acc + (1 << x));
let additional_bits = E::BaseField::from_canonical_u32(additional_bits);
circuit_builder.logic_u8(
LookupTable::Xor,
rd_exprs[3].expr(),
additional_bits.expr(),
rd_exprs[3].expr() + additional_bits.expr(),
)?;

circuit_builder.require_equal(
|| "jal rd_written",
rd_exprs
.iter()
.enumerate()
.fold(Expression::ZERO, |acc, (i, val)| {
acc + val.expr()
* E::BaseField::from_canonical_u32(1 << (i * UInt8::<E>::LIMB_BITS)).expr()
}),
j_insn.vm_state.pc.expr() + PC_STEP_SIZE,
)?;

Ok(JalConfig { j_insn, rd_written })
}

fn assign_instance(
config: &Self::InstructionConfig,
instance: &mut [E::BaseField],
lk_multiplicity: &mut LkMultiplicity,
step: &ceno_emul::StepRecord,
) -> Result<(), ZKVMError> {
config
.j_insn
.assign_instance(instance, lk_multiplicity, step)?;

let rd_written = split_to_u8(step.rd().unwrap().value.after);
config.rd_written.assign_limbs(instance, &rd_written);
for val in &rd_written {
lk_multiplicity.assert_ux::<8>(*val as u64);
}

// constrain pc msb limb range via xor
let last_limb_bits = PC_BITS - UInt8::<E>::LIMB_BITS * (UINT_BYTE_LIMBS - 1);
let additional_bits =
(last_limb_bits..UInt8::<E>::LIMB_BITS).fold(0, |acc, x| acc + (1 << x));
lk_multiplicity.logic_u8::<XorTable>(rd_written[3] as u64, additional_bits as u64);

Ok(())
}
}
2 changes: 0 additions & 2 deletions ceno_zkvm/src/instructions/riscv/jump/jalr.rs
Original file line number Diff line number Diff line change
Expand Up @@ -53,8 +53,6 @@ impl<E: ExtensionField> Instruction<E> for JalrInstruction<E> {
circuit_builder,
InsnKind::JALR,
imm.expr(),
#[cfg(feature = "u16limb_circuit")]
0.into(),
rs1_read.register_expr(),
rd_written.register_expr(),
true,
Expand Down
Loading