Skip to content

Commit 8151d72

Browse files
committed
jal migrated
1 parent d1a6040 commit 8151d72

File tree

2 files changed

+121
-0
lines changed

2 files changed

+121
-0
lines changed

ceno_zkvm/src/instructions/riscv/jump.rs

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,14 @@
1+
#[cfg(not(feature = "u16limb_circuit"))]
12
mod jal;
3+
#[cfg(feature = "u16limb_circuit")]
4+
mod jal_v2;
25
mod jalr;
36

7+
#[cfg(not(feature = "u16limb_circuit"))]
48
pub use jal::JalInstruction;
9+
#[cfg(feature = "u16limb_circuit")]
10+
pub use jal_v2::JalInstruction;
11+
512
pub use jalr::JalrInstruction;
613

714
#[cfg(test)]
Lines changed: 114 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,114 @@
1+
use std::marker::PhantomData;
2+
3+
use ff_ext::ExtensionField;
4+
5+
use crate::{
6+
Value,
7+
circuit_builder::CircuitBuilder,
8+
error::ZKVMError,
9+
instructions::{
10+
Instruction,
11+
riscv::{
12+
constants::{BIT_WIDTH, PC_BITS, UINT_BYTE_LIMBS, UInt8},
13+
j_insn::JInstructionConfig,
14+
},
15+
},
16+
structs::ProgramParams,
17+
utils::split_to_u8,
18+
witness::LkMultiplicity,
19+
};
20+
use ceno_emul::{InsnKind, PC_STEP_SIZE};
21+
use gkr_iop::tables::{LookupTable, ops::XorTable};
22+
use multilinear_extensions::{Expression, ToExpr};
23+
use p3::field::FieldAlgebra;
24+
25+
pub struct JalConfig<E: ExtensionField> {
26+
pub j_insn: JInstructionConfig<E>,
27+
pub rd_written: UInt8<E>,
28+
}
29+
30+
pub struct JalInstruction<E>(PhantomData<E>);
31+
32+
/// JAL instruction circuit
33+
///
34+
/// Note: does not validate that next_pc is aligned by 4-byte increments, which
35+
/// should be verified by lookup argument of the next execution step against
36+
/// the program table
37+
///
38+
/// Assumption: values for valid initial program counter must lie between
39+
/// 2^20 and 2^32 - 2^20 + 2 inclusive, probably enforced by the static
40+
/// program lookup table. If this assumption does not hold, then resulting
41+
/// value for next_pc may not correctly wrap mod 2^32 because of the use
42+
/// of native WitIn values for address space arithmetic.
43+
impl<E: ExtensionField> Instruction<E> for JalInstruction<E> {
44+
type InstructionConfig = JalConfig<E>;
45+
46+
fn name() -> String {
47+
format!("{:?}", InsnKind::JAL)
48+
}
49+
50+
fn construct_circuit(
51+
circuit_builder: &mut CircuitBuilder<E>,
52+
_params: &ProgramParams,
53+
) -> Result<JalConfig<E>, ZKVMError> {
54+
let rd_written = UInt8::new(|| "rd_written", circuit_builder)?;
55+
let rd_exprs = rd_written.expr();
56+
57+
let j_insn = JInstructionConfig::construct_circuit(
58+
circuit_builder,
59+
InsnKind::JAL,
60+
rd_written.register_expr(),
61+
)?;
62+
63+
// constrain rd_exprs [PC_BITS .. u32::BITS] are all 0 via xor
64+
let last_limb_bits = PC_BITS - UInt8::<E>::LIMB_BITS * (UInt8::<E>::NUM_LIMBS - 1);
65+
let additional_bits =
66+
(last_limb_bits..UInt8::<E>::LIMB_BITS).fold(0, |acc, x| acc + (1 << x));
67+
let additional_bits = E::BaseField::from_canonical_u32(additional_bits);
68+
circuit_builder.logic_u8(
69+
LookupTable::Xor,
70+
rd_exprs[3].expr(),
71+
additional_bits.expr(),
72+
rd_exprs[3].expr() + additional_bits.expr(),
73+
)?;
74+
75+
circuit_builder.require_equal(
76+
|| "jal rd_written",
77+
rd_exprs
78+
.iter()
79+
.enumerate()
80+
.fold(Expression::ZERO, |acc, (i, val)| {
81+
acc + val.expr()
82+
* E::BaseField::from_canonical_u32(1 << (i * UInt8::<E>::LIMB_BITS)).expr()
83+
}),
84+
j_insn.vm_state.pc.expr() + PC_STEP_SIZE,
85+
)?;
86+
87+
Ok(JalConfig { j_insn, rd_written })
88+
}
89+
90+
fn assign_instance(
91+
config: &Self::InstructionConfig,
92+
instance: &mut [E::BaseField],
93+
lk_multiplicity: &mut LkMultiplicity,
94+
step: &ceno_emul::StepRecord,
95+
) -> Result<(), ZKVMError> {
96+
config
97+
.j_insn
98+
.assign_instance(instance, lk_multiplicity, step)?;
99+
100+
let rd_written = split_to_u8(step.rd().unwrap().value.after);
101+
config.rd_written.assign_limbs(instance, &rd_written);
102+
for val in &rd_written {
103+
lk_multiplicity.assert_ux::<8>(*val as u64);
104+
}
105+
106+
// constrain pc msb limb range via xor
107+
let last_limb_bits = PC_BITS - UInt8::<E>::LIMB_BITS * (UINT_BYTE_LIMBS - 1);
108+
let additional_bits =
109+
(last_limb_bits..UInt8::<E>::LIMB_BITS).fold(0, |acc, x| acc + (1 << x));
110+
lk_multiplicity.logic_u8::<XorTable>(rd_written[3] as u64, additional_bits as u64);
111+
112+
Ok(())
113+
}
114+
}

0 commit comments

Comments
 (0)