Skip to content

Commit 2ccf4ed

Browse files
committed
migrate jal/jalr
1 parent 8151d72 commit 2ccf4ed

File tree

9 files changed

+337
-34
lines changed

9 files changed

+337
-34
lines changed

ceno_zkvm/src/instructions/riscv/insn_base.rs

Lines changed: 42 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@ use ff_ext::{ExtensionField, FieldInto, SmallField};
33
use itertools::Itertools;
44
use p3::field::{Field, FieldAlgebra};
55

6-
use super::constants::{PC_STEP_SIZE, UINT_LIMBS, UInt};
6+
use super::constants::{BIT_WIDTH, PC_STEP_SIZE, UINT_LIMBS, UInt};
77
use crate::{
88
chip_handler::{
99
AddressExpr, GlobalStateRegisterMachineChipOperations, MemoryChipOperations, MemoryExpr,
@@ -368,6 +368,7 @@ impl WriteMEM {
368368
pub struct MemAddr<E: ExtensionField> {
369369
addr: UInt<E>,
370370
low_bits: Vec<WitIn>,
371+
max_bits: usize,
371372
}
372373

373374
impl<E: ExtensionField> MemAddr<E> {
@@ -393,6 +394,17 @@ impl<E: ExtensionField> MemAddr<E> {
393394
self.addr.address_expr()
394395
}
395396

397+
pub fn uint_unaligned(&self) -> UInt<E> {
398+
UInt::from_exprs_unchecked(self.addr.expr())
399+
}
400+
401+
pub fn uint_align2(&self) -> UInt<E> {
402+
UInt::from_exprs_unchecked(vec![
403+
self.addr.limbs[0].expr() - &self.low_bit_exprs()[0],
404+
self.addr.limbs[1].expr(),
405+
])
406+
}
407+
396408
/// Represent the address aligned to 2 bytes.
397409
pub fn expr_align2(&self) -> AddressExpr<E> {
398410
self.addr.address_expr() - &self.low_bit_exprs()[0]
@@ -404,6 +416,14 @@ impl<E: ExtensionField> MemAddr<E> {
404416
self.addr.address_expr() - &low_bits[1] * 2 - &low_bits[0]
405417
}
406418

419+
pub fn uint_align4(&self) -> UInt<E> {
420+
let low_bits = self.low_bit_exprs();
421+
UInt::from_exprs_unchecked(vec![
422+
self.addr.limbs[0].expr() - &low_bits[1] * 2 - &low_bits[0],
423+
self.addr.limbs[1].expr(),
424+
])
425+
}
426+
407427
/// Expressions of the low bits of the address, LSB-first: [bit_0, bit_1].
408428
pub fn low_bit_exprs(&self) -> Vec<Expression<E>> {
409429
iter::repeat_n(Expression::ZERO, self.n_zeros())
@@ -412,6 +432,14 @@ impl<E: ExtensionField> MemAddr<E> {
412432
}
413433

414434
fn construct(cb: &mut CircuitBuilder<E>, n_zeros: usize) -> Result<Self, ZKVMError> {
435+
Self::construct_with_max_bits(cb, n_zeros, BIT_WIDTH)
436+
}
437+
438+
pub fn construct_with_max_bits(
439+
cb: &mut CircuitBuilder<E>,
440+
n_zeros: usize,
441+
max_bits: usize,
442+
) -> Result<Self, ZKVMError> {
415443
assert!(n_zeros <= Self::N_LOW_BITS);
416444

417445
// The address as two u16 limbs.
@@ -442,11 +470,19 @@ impl<E: ExtensionField> MemAddr<E> {
442470
cb.assert_ux::<_, _, 14>(|| "mid_u14", mid_u14)?;
443471

444472
// Range check the high limb.
445-
for high_u16 in limbs.iter().skip(1) {
446-
cb.assert_ux::<_, _, 16>(|| "high_u16", high_u16.clone())?;
473+
for (i, high_limb) in limbs.iter().enumerate().skip(1) {
474+
cb.assert_ux_v2(
475+
|| "high_limb",
476+
high_limb.clone(),
477+
(max_bits - i * 16).min(16),
478+
)?;
447479
}
448480

449-
Ok(MemAddr { addr, low_bits })
481+
Ok(MemAddr {
482+
addr,
483+
low_bits,
484+
max_bits,
485+
})
450486
}
451487

452488
pub fn assign_instance(
@@ -470,7 +506,8 @@ impl<E: ExtensionField> MemAddr<E> {
470506
// Range check the high limb.
471507
for i in 1..UINT_LIMBS {
472508
let high_u16 = (addr >> (i * 16)) & 0xffff;
473-
lkm.assert_ux::<16>(high_u16 as u64);
509+
println!("assignment max bit {}", (self.max_bits - i * 16).min(16));
510+
lkm.assert_ux_v2(high_u16 as u64, (self.max_bits - i * 16).min(16));
474511
}
475512

476513
Ok(())

ceno_zkvm/src/instructions/riscv/jump.rs

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,14 +2,21 @@
22
mod jal;
33
#[cfg(feature = "u16limb_circuit")]
44
mod jal_v2;
5+
6+
#[cfg(not(feature = "u16limb_circuit"))]
57
mod jalr;
8+
#[cfg(feature = "u16limb_circuit")]
9+
mod jalr_v2;
610

711
#[cfg(not(feature = "u16limb_circuit"))]
812
pub use jal::JalInstruction;
913
#[cfg(feature = "u16limb_circuit")]
1014
pub use jal_v2::JalInstruction;
1115

16+
#[cfg(not(feature = "u16limb_circuit"))]
1217
pub use jalr::JalrInstruction;
18+
#[cfg(feature = "u16limb_circuit")]
19+
pub use jalr_v2::JalrInstruction;
1320

1421
#[cfg(test)]
1522
mod test;

ceno_zkvm/src/instructions/riscv/jump/jal_v2.rs

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,13 +3,12 @@ use std::marker::PhantomData;
33
use ff_ext::ExtensionField;
44

55
use crate::{
6-
Value,
76
circuit_builder::CircuitBuilder,
87
error::ZKVMError,
98
instructions::{
109
Instruction,
1110
riscv::{
12-
constants::{BIT_WIDTH, PC_BITS, UINT_BYTE_LIMBS, UInt8},
11+
constants::{PC_BITS, UINT_BYTE_LIMBS, UInt8},
1312
j_insn::JInstructionConfig,
1413
},
1514
},

ceno_zkvm/src/instructions/riscv/jump/jalr.rs

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -53,8 +53,6 @@ impl<E: ExtensionField> Instruction<E> for JalrInstruction<E> {
5353
circuit_builder,
5454
InsnKind::JALR,
5555
imm.expr(),
56-
#[cfg(feature = "u16limb_circuit")]
57-
0.into(),
5856
rs1_read.register_expr(),
5957
rd_written.register_expr(),
6058
true,
Lines changed: 184 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,184 @@
1+
use ff_ext::ExtensionField;
2+
use std::marker::PhantomData;
3+
4+
use crate::{
5+
Value,
6+
chip_handler::general::InstFetch,
7+
circuit_builder::CircuitBuilder,
8+
error::ZKVMError,
9+
instructions::{
10+
Instruction,
11+
riscv::{
12+
constants::{PC_BITS, UINT_LIMBS, UInt},
13+
i_insn::IInstructionConfig,
14+
insn_base::{MemAddr, ReadRS1, StateInOut, WriteRD},
15+
},
16+
},
17+
structs::ProgramParams,
18+
tables::InsnRecord,
19+
utils::imm_sign_extend,
20+
witness::{LkMultiplicity, set_val},
21+
};
22+
use ceno_emul::{InsnKind, PC_STEP_SIZE};
23+
use ff_ext::FieldInto;
24+
use multilinear_extensions::{Expression, ToExpr, WitIn};
25+
use p3::field::{Field, FieldAlgebra};
26+
27+
pub struct JalrConfig<E: ExtensionField> {
28+
pub i_insn: IInstructionConfig<E>,
29+
pub rs1_read: UInt<E>,
30+
pub imm: WitIn,
31+
pub imm_sign: WitIn,
32+
pub jump_pc_addr: MemAddr<E>,
33+
pub rd_high: WitIn,
34+
}
35+
36+
pub struct JalrInstruction<E>(PhantomData<E>);
37+
38+
/// JALR instruction circuit
39+
/// NOTE: does not validate that next_pc is aligned by 4-byte increments, which
40+
/// should be verified by lookup argument of the next execution step against
41+
/// the program table
42+
impl<E: ExtensionField> Instruction<E> for JalrInstruction<E> {
43+
type InstructionConfig = JalrConfig<E>;
44+
45+
fn name() -> String {
46+
format!("{:?}", InsnKind::JALR)
47+
}
48+
49+
fn construct_circuit(
50+
circuit_builder: &mut CircuitBuilder<E>,
51+
_params: &ProgramParams,
52+
) -> Result<JalrConfig<E>, ZKVMError> {
53+
assert_eq!(UINT_LIMBS, 2);
54+
let rs1_read = UInt::new_unchecked(|| "rs1_read", circuit_builder)?; // unsigned 32-bit value
55+
let imm = circuit_builder.create_witin(|| "imm"); // signed 12-bit value
56+
let imm_sign = circuit_builder.create_witin(|| "imm_sign");
57+
// State in and out
58+
let vm_state = StateInOut::construct_circuit(circuit_builder, true)?;
59+
let rd_high = circuit_builder.create_witin(|| "rd_high");
60+
let rd_low: Expression<_> = vm_state.pc.expr()
61+
+ E::BaseField::from_canonical_usize(PC_STEP_SIZE).expr()
62+
- rd_high.expr() * E::BaseField::from_canonical_u32(1 << UInt::<E>::LIMB_BITS).expr();
63+
// rd range check
64+
// rd_low
65+
circuit_builder.assert_ux_v2(|| "rd_low_u16", rd_low.expr(), UInt::<E>::LIMB_BITS)?;
66+
// rd_high
67+
circuit_builder.assert_ux_v2(
68+
|| "rd_high_range",
69+
rd_high.expr(),
70+
PC_BITS - UInt::<E>::LIMB_BITS,
71+
)?;
72+
let rd_uint = UInt::from_exprs_unchecked(vec![rd_low.expr(), rd_high.expr()]);
73+
74+
let jump_pc_addr = MemAddr::construct_with_max_bits(circuit_builder, 0, PC_BITS)?;
75+
76+
// Registers
77+
let rs1 =
78+
ReadRS1::construct_circuit(circuit_builder, rs1_read.register_expr(), vm_state.ts)?;
79+
let rd = WriteRD::construct_circuit(circuit_builder, rd_uint.register_expr(), vm_state.ts)?;
80+
81+
// Fetch the instruction.
82+
circuit_builder.lk_fetch(&InsnRecord::new(
83+
vm_state.pc.expr(),
84+
InsnKind::JALR.into(),
85+
Some(rd.id.expr()),
86+
rs1.id.expr(),
87+
0.into(),
88+
imm.expr(),
89+
imm_sign.expr(),
90+
))?;
91+
92+
let i_insn = IInstructionConfig { vm_state, rs1, rd };
93+
94+
// Next pc is obtained by rounding rs1+imm down to an even value.
95+
// To implement this, check three conditions:
96+
// 1. rs1 + imm = jump_pc_addr + overflow*2^32
97+
// 3. next_pc = jump_pc_addr aligned to even value (round down)
98+
99+
let inv = E::BaseField::from_canonical_u32(1 << UInt::<E>::LIMB_BITS).inverse();
100+
101+
let carry = (rs1_read.expr()[0].expr() + imm.expr()
102+
- jump_pc_addr.uint_unaligned().expr()[0].expr())
103+
* inv.expr();
104+
circuit_builder.assert_bit(|| "carry_lo_bit", carry.expr())?;
105+
106+
let imm_extend_limb = imm_sign.expr()
107+
* E::BaseField::from_canonical_u32((1 << UInt::<E>::LIMB_BITS) - 1).expr();
108+
let carry = (rs1_read.expr()[1].expr() + imm_extend_limb.expr() + carry
109+
- jump_pc_addr.uint_unaligned().expr()[1].expr())
110+
* inv.expr();
111+
circuit_builder.assert_bit(|| "overflow_bit", carry)?;
112+
113+
circuit_builder.require_equal(
114+
|| "jump_pc_addr = next_pc",
115+
jump_pc_addr.expr_align2(),
116+
i_insn.vm_state.next_pc.unwrap().expr(),
117+
)?;
118+
119+
// write pc+4 to rd
120+
circuit_builder.require_equal(
121+
|| "rd_written = pc+4",
122+
rd_uint.value(), // this operation is safe
123+
i_insn.vm_state.pc.expr() + PC_STEP_SIZE,
124+
)?;
125+
126+
Ok(JalrConfig {
127+
i_insn,
128+
rs1_read,
129+
imm,
130+
imm_sign,
131+
jump_pc_addr,
132+
rd_high,
133+
})
134+
}
135+
136+
fn assign_instance(
137+
config: &Self::InstructionConfig,
138+
instance: &mut [E::BaseField],
139+
lk_multiplicity: &mut LkMultiplicity,
140+
step: &ceno_emul::StepRecord,
141+
) -> Result<(), ZKVMError> {
142+
let insn = step.insn();
143+
144+
let rs1 = step.rs1().unwrap().value;
145+
let imm = InsnRecord::<E::BaseField>::imm_internal(&insn);
146+
set_val!(instance, config.imm, imm.1);
147+
// according to riscvim32 spec, imm always do signed extension
148+
let imm_sign_extend = imm_sign_extend(true, step.insn().imm as i16);
149+
set_val!(
150+
instance,
151+
config.imm_sign,
152+
E::BaseField::from_bool(imm_sign_extend[1] > 0)
153+
);
154+
let rd = Value::new_unchecked(step.rd().unwrap().value.after);
155+
let rd_limb = rd.as_u16_limbs();
156+
lk_multiplicity.assert_ux_v2(rd_limb[0] as u64, 16);
157+
lk_multiplicity.assert_ux_v2(rd_limb[1] as u64, PC_BITS - 16);
158+
159+
config
160+
.rs1_read
161+
.assign_value(instance, Value::new_unchecked(rs1));
162+
set_val!(
163+
instance,
164+
config.rd_high,
165+
E::BaseField::from_canonical_u16(rd_limb[1])
166+
);
167+
168+
let (sum, _) = rs1.overflowing_add_signed(i32::from_ne_bytes([
169+
imm_sign_extend[0] as u8,
170+
(imm_sign_extend[0] >> 8) as u8,
171+
imm_sign_extend[1] as u8,
172+
(imm_sign_extend[1] >> 8) as u8,
173+
]));
174+
config
175+
.jump_pc_addr
176+
.assign_instance(instance, lk_multiplicity, sum)?;
177+
178+
config
179+
.i_insn
180+
.assign_instance(instance, lk_multiplicity, step)?;
181+
182+
Ok(())
183+
}
184+
}

0 commit comments

Comments
 (0)