From 9f956c73dec2a8662cfc10bee1c66c8d11187cac Mon Sep 17 00:00:00 2001 From: Peyman Jabbarzade Date: Tue, 19 Aug 2025 11:41:59 -0400 Subject: [PATCH 01/14] feat: initial code for memcpy chip --- Cargo.lock | 31 + Cargo.toml | 2 + .../primitives/src/assert_less_than/mod.rs | 10 +- crates/toolchain/openvm/src/memcpy.s | 88 +-- .../system/memory/offline_checker/columns.rs | 19 + .../src/system/memory/offline_checker/mod.rs | 34 +- extensions/memcpy/README.md | 86 +++ extensions/memcpy/circuit/Cargo.toml | 24 + extensions/memcpy/circuit/src/bus.rs | 100 +++ extensions/memcpy/circuit/src/core.rs | 712 ++++++++++++++++++ extensions/memcpy/circuit/src/extension.rs | 140 ++++ extensions/memcpy/circuit/src/iteration.rs | 499 ++++++++++++ extensions/memcpy/circuit/src/lib.rs | 17 + extensions/memcpy/tests.rs | 444 +++++++++++ extensions/memcpy/transpiler/Cargo.toml | 17 + extensions/memcpy/transpiler/src/lib.rs | 60 ++ 16 files changed, 2197 insertions(+), 86 deletions(-) create mode 100644 extensions/memcpy/README.md create mode 100644 extensions/memcpy/circuit/Cargo.toml create mode 100644 extensions/memcpy/circuit/src/bus.rs create mode 100644 extensions/memcpy/circuit/src/core.rs create mode 100644 extensions/memcpy/circuit/src/extension.rs create mode 100644 extensions/memcpy/circuit/src/iteration.rs create mode 100644 extensions/memcpy/circuit/src/lib.rs create mode 100644 extensions/memcpy/tests.rs create mode 100644 extensions/memcpy/transpiler/Cargo.toml create mode 100644 extensions/memcpy/transpiler/src/lib.rs diff --git a/Cargo.lock b/Cargo.lock index dcbbace47c..b0b6e8e389 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -5740,6 +5740,37 @@ dependencies = [ "syn 2.0.106", ] +[[package]] +name = "openvm-memcpy-circuit" +version = "1.4.0-rc.2" +dependencies = [ + "derive-new 0.6.0", + "derive_more 1.0.0", + "openvm-circuit", + "openvm-circuit-derive", + "openvm-circuit-primitives", + "openvm-circuit-primitives-derive", + "openvm-instructions", + "openvm-memcpy-transpiler", + "openvm-rv32im-circuit", + "openvm-rv32im-transpiler", + "openvm-stark-backend", + "serde", + "strum", +] + +[[package]] +name = "openvm-memcpy-transpiler" +version = "1.4.0-rc.2" +dependencies = [ + "openvm-instructions", + "openvm-instructions-derive", + "openvm-stark-backend", + "openvm-transpiler", + "rrs-lib", + "strum", +] + [[package]] name = "openvm-mod-circuit-builder" version = "1.4.1-rc.0" diff --git a/Cargo.toml b/Cargo.toml index 2c5d1bbade..b83495e41e 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -61,6 +61,8 @@ members = [ "extensions/ecc/tests", "extensions/pairing/circuit", "extensions/pairing/guest", + "extensions/memcpy/circuit", + "extensions/memcpy/transpiler", "guest-libs/ff_derive/", "guest-libs/k256/", "guest-libs/p256/", diff --git a/crates/circuits/primitives/src/assert_less_than/mod.rs b/crates/circuits/primitives/src/assert_less_than/mod.rs index 53054c713a..1ec12a8fd5 100644 --- a/crates/circuits/primitives/src/assert_less_than/mod.rs +++ b/crates/circuits/primitives/src/assert_less_than/mod.rs @@ -3,7 +3,7 @@ use openvm_circuit_primitives_derive::AlignedBorrow; use openvm_stark_backend::{ interaction::InteractionBuilder, p3_air::AirBuilder, - p3_field::{Field, FieldAlgebra}, + p3_field::{Field, FieldAlgebra, PrimeField32}, }; use crate::{ @@ -58,6 +58,14 @@ pub struct LessThanAuxCols { pub lower_decomp: [T; AUX_LEN], } +impl Default for LessThanAuxCols { + fn default() -> Self { + Self { + lower_decomp: [F::ZERO; AUX_LEN], + } + } +} + /// This is intended for use as a **SubAir**, not as a standalone Air. /// /// This SubAir constrains that `x < y` when `count != 0`, assuming diff --git a/crates/toolchain/openvm/src/memcpy.s b/crates/toolchain/openvm/src/memcpy.s index e0043ec220..3f787e5d52 100644 --- a/crates/toolchain/openvm/src/memcpy.s +++ b/crates/toolchain/openvm/src/memcpy.s @@ -252,30 +252,7 @@ memcpy: addi a3, a4, 16 li a4, 16 .LBBmemcpy0_9: - lw a6, -12(a3) - srli a5, a5, 24 - slli a7, a6, 8 - lw t0, -8(a3) - or a5, a7, a5 - sw a5, 0(a1) - srli a5, a6, 24 - slli a6, t0, 8 - lw a7, -4(a3) - or a5, a6, a5 - sw a5, 4(a1) - srli a6, t0, 24 - slli t0, a7, 8 - lw a5, 0(a3) - or a6, t0, a6 - sw a6, 8(a1) - srli a6, a7, 24 - slli a7, a5, 8 - or a6, a7, a6 - sw a6, 12(a1) - addi a1, a1, 16 - addi a2, a2, -16 - addi a3, a3, 16 - bltu a4, a2, .LBBmemcpy0_9 + memcpy_loop 1 addi a4, a3, -13 j .LBBmemcpy0_25 .LBBmemcpy0_11: @@ -288,18 +265,7 @@ memcpy: bltu a2, a1, .LBBmemcpy0_15 li a1, 15 .LBBmemcpy0_14: - lw a5, 0(a4) - lw a6, 4(a4) - lw a7, 8(a4) - lw t0, 12(a4) - sw a5, 0(a3) - sw a6, 4(a3) - sw a7, 8(a3) - sw t0, 12(a3) - addi a4, a4, 16 - addi a2, a2, -16 - addi a3, a3, 16 - bltu a1, a2, .LBBmemcpy0_14 + memcpy_loop 0 .LBBmemcpy0_15: andi a1, a2, 8 beqz a1, .LBBmemcpy0_17 @@ -325,30 +291,7 @@ memcpy: addi a3, a4, 16 li a4, 18 .LBBmemcpy0_20: - lw a6, -12(a3) - srli a5, a5, 8 - slli a7, a6, 24 - lw t0, -8(a3) - or a5, a7, a5 - sw a5, 0(a1) - srli a5, a6, 8 - slli a6, t0, 24 - lw a7, -4(a3) - or a5, a6, a5 - sw a5, 4(a1) - srli a6, t0, 8 - slli t0, a7, 24 - lw a5, 0(a3) - or a6, t0, a6 - sw a6, 8(a1) - srli a6, a7, 8 - slli a7, a5, 24 - or a6, a7, a6 - sw a6, 12(a1) - addi a1, a1, 16 - addi a2, a2, -16 - addi a3, a3, 16 - bltu a4, a2, .LBBmemcpy0_20 + memcpy_loop 3 addi a4, a3, -15 j .LBBmemcpy0_25 .LBBmemcpy0_22: @@ -361,30 +304,7 @@ memcpy: addi a3, a4, 16 li a4, 17 .LBBmemcpy0_23: - lw a6, -12(a3) - srli a5, a5, 16 - slli a7, a6, 16 - lw t0, -8(a3) - or a5, a7, a5 - sw a5, 0(a1) - srli a5, a6, 16 - slli a6, t0, 16 - lw a7, -4(a3) - or a5, a6, a5 - sw a5, 4(a1) - srli a6, t0, 16 - slli t0, a7, 16 - lw a5, 0(a3) - or a6, t0, a6 - sw a6, 8(a1) - srli a6, a7, 16 - slli a7, a5, 16 - or a6, a7, a6 - sw a6, 12(a1) - addi a1, a1, 16 - addi a2, a2, -16 - addi a3, a3, 16 - bltu a4, a2, .LBBmemcpy0_23 + memcpy_loop 2 addi a4, a3, -14 .LBBmemcpy0_25: mv a3, a1 diff --git a/crates/vm/src/system/memory/offline_checker/columns.rs b/crates/vm/src/system/memory/offline_checker/columns.rs index ef9821f859..9225b813b2 100644 --- a/crates/vm/src/system/memory/offline_checker/columns.rs +++ b/crates/vm/src/system/memory/offline_checker/columns.rs @@ -26,6 +26,15 @@ impl MemoryBaseAuxCols { } } +impl Default for MemoryBaseAuxCols { + fn default() -> Self { + Self { + prev_timestamp: F::ZERO, + timestamp_lt_aux: LessThanAuxCols::default(), + } + } +} + #[repr(C)] #[derive(Clone, Copy, Debug, AlignedBorrow)] pub struct MemoryWriteAuxCols { @@ -43,6 +52,11 @@ impl MemoryWriteAuxCols { self.base } + #[inline(always)] + pub fn set_base(&mut self, base: MemoryBaseAuxCols) { + self.base = base; + } + #[inline(always)] pub fn prev_data(&self) -> &[T; N] { &self.prev_data @@ -80,6 +94,11 @@ impl MemoryReadAuxCols { self.base } + #[inline(always)] + pub fn set_base(&mut self, base: MemoryBaseAuxCols) { + self.base = base; + } + /// Sets the previous timestamp **without** updating the less than auxiliary columns. #[inline(always)] pub fn set_prev(&mut self, timestamp: F) { diff --git a/crates/vm/src/system/memory/offline_checker/mod.rs b/crates/vm/src/system/memory/offline_checker/mod.rs index 8b15328185..c903319a23 100644 --- a/crates/vm/src/system/memory/offline_checker/mod.rs +++ b/crates/vm/src/system/memory/offline_checker/mod.rs @@ -5,13 +5,45 @@ mod columns; pub use bridge::*; pub use bus::*; pub use columns::*; +use openvm_circuit_primitives::is_less_than::LessThanAuxCols; +use openvm_stark_backend::p3_field::PrimeField32; #[repr(C)] #[derive(Debug, Clone)] -pub struct MemoryReadAuxRecord { +pub struct MemoryBaseAuxRecord { pub prev_timestamp: u32, } +#[repr(C)] +#[derive(Debug, Clone)] +pub struct MemoryExtendedAuxRecord { + pub prev_timestamp: u32, + pub timestamp_lt_aux: [u32; AUX_LEN], +} + +impl MemoryExtendedAuxRecord { + pub fn from_aux_cols(aux_cols: MemoryBaseAuxCols) -> Self { + Self { + prev_timestamp: aux_cols.prev_timestamp.as_canonical_u32(), + timestamp_lt_aux: aux_cols + .timestamp_lt_aux + .lower_decomp + .map(|x| x.as_canonical_u32()), + } + } + + pub fn to_aux_cols(&self) -> MemoryBaseAuxCols { + MemoryBaseAuxCols { + prev_timestamp: F::from_canonical_u32(self.prev_timestamp), + timestamp_lt_aux: LessThanAuxCols { + lower_decomp: self.timestamp_lt_aux.map(|x| F::from_canonical_u32(x)), + }, + } + } +} + +pub type MemoryReadAuxRecord = MemoryBaseAuxRecord; + #[repr(C)] #[derive(Debug, Clone)] pub struct MemoryWriteAuxRecord { diff --git a/extensions/memcpy/README.md b/extensions/memcpy/README.md new file mode 100644 index 0000000000..da381e2e3c --- /dev/null +++ b/extensions/memcpy/README.md @@ -0,0 +1,86 @@ +# OpenVM Memcpy Extension + +This extension provides a custom RISC-V instruction `memcpy_loop` that optimizes memory copy operations by handling different alignment shifts efficiently. + +## Custom Instruction: `memcpy_loop shift` + +### Format +``` +memcpy_loop shift +``` + +Where `shift` is an immediate value (0, 1, 2, or 3) representing the byte alignment shift. + +### RISC-V Encoding +- **Opcode**: `0x73` (custom opcode) +- **Funct3**: `0x0` (custom funct3) +- **Immediate**: 12-bit signed immediate for shift value +- **Format**: I-type instruction + +### Usage +The `memcpy_loop` instruction is designed to replace repetitive shift-handling code in memcpy implementations. Instead of having separate code blocks for each shift value, you can use a single instruction: + +```assembly +# Instead of this repetitive code: +.Lshift_1: + lw a5, 0(a4) + sb a5, 0(a3) + srli a1, a5, 8 + sb a1, 1(a3) + # ... more shift handling code + +# You can use: +memcpy_loop 1 # Handles shift=1 case +``` + +### Benefits +1. **Code Size Reduction**: Eliminates repetitive shift-handling code +2. **Performance**: Optimized implementation in the circuit layer +3. **Maintainability**: Single instruction handles all shift cases +4. **Verification**: Zero-knowledge proof ensures correct execution + +## Implementation Details + +### Circuit Layer +The instruction is implemented in the `MemcpyIterationAir` circuit which: +- Reads 4 words (16 bytes) from memory +- Applies the specified shift to combine words +- Writes the result to the destination +- Handles all shift values (0, 1, 2, 3) efficiently + +### Transpiler Extension +The `MemcpyTranspilerExtension` translates the RISC-V instruction into OpenVM's internal format: +- Parses I-type instruction format +- Validates shift value (0-3) +- Converts to OpenVM instruction with shift as operand + +### Example Usage +See `example_memcpy_optimized.s` for a complete example showing how to use the custom instruction to optimize a memcpy implementation. + +## Building and Testing + +### Compilation +```bash +# Build the extension +cargo build --package openvm-memcpy-circuit --package openvm-memcpy-transpiler + +# Check for compilation errors +cargo check --package openvm-memcpy-circuit --package openvm-memcpy-transpiler +``` + +### Integration +To use this extension in your OpenVM project: + +1. Add the transpiler extension to your OpenVM configuration +2. Use the `memcpy_loop` instruction in your RISC-V assembly +3. The circuit will handle the execution and verification + +## Architecture + +``` +RISC-V Assembly → Transpiler Extension → OpenVM Instruction → MemcpyIterationAir → Execution +``` + +The extension provides: +- **Transpiler**: `extensions/memcpy/transpiler/` - Translates RISC-V to OpenVM +- **Circuit**: `extensions/memcpy/circuit/` - Implements the instruction logic diff --git a/extensions/memcpy/circuit/Cargo.toml b/extensions/memcpy/circuit/Cargo.toml new file mode 100644 index 0000000000..5bb1a9d2c0 --- /dev/null +++ b/extensions/memcpy/circuit/Cargo.toml @@ -0,0 +1,24 @@ +[package] +name = "openvm-memcpy-circuit" +description = "OpenVM circuit extension for memcpy" +version.workspace = true +authors.workspace = true +edition.workspace = true +homepage.workspace = true +repository.workspace = true + +[dependencies] +openvm-circuit = { workspace = true } +openvm-circuit-primitives = { workspace = true } +openvm-circuit-primitives-derive = { workspace = true } +openvm-circuit-derive = { workspace = true } +openvm-instructions = { workspace = true } +openvm-stark-backend = { workspace = true } +openvm-memcpy-transpiler = { path = "../transpiler" } +openvm-rv32im-transpiler = { workspace = true } +openvm-rv32im-circuit = { workspace = true } + +derive-new.workspace = true +derive_more = { workspace = true, features = ["from"] } +serde.workspace = true +strum = { workspace = true } diff --git a/extensions/memcpy/circuit/src/bus.rs b/extensions/memcpy/circuit/src/bus.rs new file mode 100644 index 0000000000..fcb77932c2 --- /dev/null +++ b/extensions/memcpy/circuit/src/bus.rs @@ -0,0 +1,100 @@ +use std::iter; + +use openvm_stark_backend::{ + interaction::{BusIndex, InteractionBuilder, PermutationCheckBus}, + p3_field::FieldAlgebra, +}; + +#[derive(Clone, Copy, Debug, PartialEq, Eq)] +pub struct MemcpyBus { + pub inner: PermutationCheckBus, +} + +impl MemcpyBus { + pub const fn new(index: BusIndex) -> Self { + Self { + inner: PermutationCheckBus::new(index), + } + } +} + +impl MemcpyBus { + #[inline(always)] + pub fn index(&self) -> BusIndex { + self.inner.index + } + + pub fn send( + &self, + timestamp: impl Into, + dest: impl Into, + source: impl Into, + n: impl Into, + shift: impl Into, + ) -> MemcpyBusInteraction { + self.push(true, timestamp, dest, source, n, shift) + } + + pub fn receive( + &self, + timestamp: impl Into, + dest: impl Into, + source: impl Into, + n: impl Into, + shift: impl Into, + ) -> MemcpyBusInteraction { + self.push(false, timestamp, dest, source, n, shift) + } + + fn push( + &self, + is_send: bool, + timestamp: impl Into, + dest: impl Into, + source: impl Into, + n: impl Into, + shift: impl Into, + ) -> MemcpyBusInteraction { + MemcpyBusInteraction { + bus: self.inner, + is_send, + timestamp: timestamp.into(), + dest: dest.into(), + source: source.into(), + n: n.into(), + shift: shift.into(), + } + } +} + +#[derive(Clone, Debug)] +pub struct MemcpyBusInteraction { + pub bus: PermutationCheckBus, + pub is_send: bool, + pub timestamp: T, + pub dest: T, + pub source: T, + pub n: T, + pub shift: T, +} + +impl MemcpyBusInteraction { + pub fn eval(self, builder: &mut AB, direction: impl Into) + where + AB: InteractionBuilder, + { + let fields = iter::empty() + .chain(iter::once(self.timestamp)) + .chain(iter::once(self.dest)) + .chain(iter::once(self.source)) + .chain(iter::once(self.n)) + .chain(iter::once(self.shift)); + + if self.is_send { + self.bus.interact(builder, fields, direction); + } else { + self.bus + .interact(builder, fields, AB::Expr::NEG_ONE * direction.into()); + } + } +} diff --git a/extensions/memcpy/circuit/src/core.rs b/extensions/memcpy/circuit/src/core.rs new file mode 100644 index 0000000000..43b465bfb1 --- /dev/null +++ b/extensions/memcpy/circuit/src/core.rs @@ -0,0 +1,712 @@ +use std::{ + array, + borrow::{Borrow, BorrowMut}, + mem::size_of, + sync::Arc, +}; + +use openvm_circuit::{ + arch::*, + system::memory::{ + offline_checker::{ + MemoryBaseAuxCols, MemoryBaseAuxRecord, MemoryBridge, MemoryReadAuxRecord, + MemoryWriteAuxCols, MemoryWriteBytesAuxRecord, + }, + online::{GuestMemory, TracingMemory}, + MemoryAddress, MemoryAuxColsFactory, + }, +}; +use openvm_circuit_primitives::{ + utils::{not, or, select}, + var_range::{SharedVariableRangeCheckerChip, VariableRangeCheckerBus}, + AlignedBytesBorrow, +}; +use openvm_circuit_primitives_derive::AlignedBorrow; +use openvm_instructions::{ + instruction::Instruction, + program::DEFAULT_PC_STEP, + riscv::{RV32_MEMORY_AS, RV32_REGISTER_AS}, + LocalOpcode, +}; +use openvm_rv32im_circuit::adapters::{read_rv32_register, tracing_read, tracing_write}; +use openvm_stark_backend::{ + interaction::InteractionBuilder, + p3_air::{Air, AirBuilder, BaseAir}, + p3_field::{Field, FieldAlgebra, PrimeField32}, + p3_matrix::Matrix, + rap::{BaseAirWithPublicValues, PartitionedBaseAir}, +}; + +use crate::{bus::MemcpyBus, MemcpyIterChip}; +use openvm_circuit::arch::{ + execution_mode::{E1ExecutionCtx, E2ExecutionCtx}, + get_record_from_slice, ExecuteFunc, ExecutionError, Executor, + MeteredExecutor, RecordArena, StaticProgramError, TraceFiller, VmExecState, +}; +use openvm_memcpy_transpiler::Rv32MemcpyOpcode; + +// Import constants from lib.rs +use crate::{ + A1_REGISTER_PTR, A2_REGISTER_PTR, A3_REGISTER_PTR, A4_REGISTER_PTR, MEMCPY_LOOP_LIMB_BITS, + MEMCPY_LOOP_NUM_LIMBS, +}; + +#[repr(C)] +#[derive(AlignedBorrow, Debug)] +pub struct MemcpyLoopCols { + pub from_state: ExecutionState, + pub dest: [T; MEMCPY_LOOP_NUM_LIMBS], + pub source: [T; MEMCPY_LOOP_NUM_LIMBS], + pub len: [T; MEMCPY_LOOP_NUM_LIMBS], + pub shift: [T; 2], + pub is_valid: T, + pub to_timestamp: T, + pub to_dest: [T; MEMCPY_LOOP_NUM_LIMBS], + pub to_source: [T; MEMCPY_LOOP_NUM_LIMBS], + pub to_len: T, + pub write_aux: [MemoryBaseAuxCols; 3], + pub source_minus_twelve_carry: T, + pub to_source_minus_twelve_carry: T, +} + +#[derive(Copy, Clone, Debug, derive_new::new)] +pub struct MemcpyLoopAir { + pub memory_bridge: MemoryBridge, + pub execution_bridge: ExecutionBridge, + pub range_bus: VariableRangeCheckerBus, + pub memcpy_bus: MemcpyBus, + pub pointer_max_bits: usize, +} + +impl BaseAir for MemcpyLoopAir { + fn width(&self) -> usize { + MemcpyLoopCols::::width() + } +} + +impl BaseAirWithPublicValues for MemcpyLoopAir {} +impl PartitionedBaseAir for MemcpyLoopAir {} + +impl Air for MemcpyLoopAir { + fn eval(&self, builder: &mut AB) { + let main = builder.main(); + let local = main.row_slice(0); + let local: &MemcpyLoopCols = (*local).borrow(); + + let timestamp: AB::Var = local.from_state.timestamp; + let mut timestamp_delta: usize = 0; + let mut timestamp_pp = || { + timestamp_delta += 1; + timestamp + AB::Expr::from_canonical_usize(timestamp_delta - 1) + }; + + let from_le_bytes = |data: [AB::Var; 4]| { + data.iter().fold(AB::Expr::ZERO, |acc, x| { + acc * AB::Expr::from_canonical_u32(256) + *x + }) + }; + + let u8_word_to_u16 = |data: [AB::Var; 4]| { + [ + data[0] + data[1] * AB::F::from_canonical_u32(1 << MEMCPY_LOOP_LIMB_BITS), + data[2] + data[3] * AB::F::from_canonical_u32(1 << MEMCPY_LOOP_LIMB_BITS), + ] + }; + + let shift = local.shift[1] * AB::Expr::from_canonical_u32(2) + local.shift[0]; + let is_shift_non_zero = or::(local.shift[0], local.shift[1]); + let dest = from_le_bytes(local.dest); + let source = from_le_bytes(local.source); + let len = from_le_bytes(local.len); + let to_dest = from_le_bytes(local.to_dest); + let to_source = from_le_bytes(local.to_source); + let to_len = local.to_len; + + builder.assert_bool(local.is_valid); + for i in 0..2 { + builder.assert_bool(local.shift[i]); + } + builder.assert_bool(local.source_minus_twelve_carry); + builder.assert_bool(local.to_source_minus_twelve_carry); + + let mut shift_zero_when = builder.when(not::(is_shift_non_zero.clone())); + shift_zero_when.assert_zero(local.source_minus_twelve_carry); + shift_zero_when.assert_zero(local.to_source_minus_twelve_carry); + + // Write source and destination to registers + let write_data = [ + (local.dest, local.to_dest, A1_REGISTER_PTR, A3_REGISTER_PTR), + ( + local.source, + local.to_source, + A2_REGISTER_PTR, + A4_REGISTER_PTR, + ), + ]; + + write_data + .iter() + .enumerate() + .for_each(|(idx, (dest, to_dest, ptr, zero_shift_ptr))| { + let write_ptr = select::( + is_shift_non_zero.clone(), + AB::Expr::from_canonical_usize(*ptr), + AB::Expr::from_canonical_usize(*zero_shift_ptr), + ); + + self.memory_bridge + .write( + MemoryAddress::new(AB::Expr::from_canonical_u32(RV32_MEMORY_AS), write_ptr), + *to_dest, + timestamp_pp(), + &MemoryWriteAuxCols::from_base(local.write_aux[idx], *dest), + ) + .eval(builder, local.is_valid); + }); + + // Write length to a2 register + self.memory_bridge + .write( + MemoryAddress::new( + AB::Expr::from_canonical_u32(RV32_MEMORY_AS), + AB::Expr::from_canonical_usize(A2_REGISTER_PTR), + ), + [ + to_len.into(), + AB::Expr::ZERO, + AB::Expr::ZERO, + AB::Expr::ZERO, + ], + timestamp_pp(), + &MemoryWriteAuxCols::from_base(local.write_aux[2], local.len), + ) + .eval(builder, local.is_valid); + + // Generate 16-bit limbs for range checking + let len_u16_limbs = u8_word_to_u16(local.len); + let dest_u16_limbs = u8_word_to_u16(local.dest); + let to_dest_u16_limbs = u8_word_to_u16(local.to_dest); + let source_u16_limbs = [ + local.source[0] + + local.source[1] * AB::F::from_canonical_u32(1 << MEMCPY_LOOP_LIMB_BITS) + - AB::Expr::from_canonical_u32(12) * is_shift_non_zero.clone() + + local.source_minus_twelve_carry + * AB::F::from_canonical_u32(1 << (2 * MEMCPY_LOOP_LIMB_BITS)), + local.source[2] + + local.source[3] * AB::F::from_canonical_u32(1 << MEMCPY_LOOP_LIMB_BITS) + - local.source_minus_twelve_carry, + ]; + let to_source_u16_limbs = [ + local.to_source[0] + + local.to_source[1] * AB::F::from_canonical_u32(1 << MEMCPY_LOOP_LIMB_BITS) + - AB::Expr::from_canonical_u32(12) * is_shift_non_zero.clone() + + local.to_source_minus_twelve_carry + * AB::F::from_canonical_u32(1 << (2 * MEMCPY_LOOP_LIMB_BITS)), + local.to_source[2] + + local.to_source[3] * AB::F::from_canonical_u32(1 << MEMCPY_LOOP_LIMB_BITS) + - local.to_source_minus_twelve_carry, + ]; + + // Range check addresses and n + let range_check_data = [ + (len_u16_limbs, false), + (dest_u16_limbs, true), + (source_u16_limbs, true), + (to_dest_u16_limbs, true), + (to_source_u16_limbs, true), + ]; + + range_check_data.iter().for_each(|(data, is_address)| { + let (data_0, num_bits) = if *is_address { + ( + data[0].clone() * AB::F::from_canonical_u32(4).inverse(), + MEMCPY_LOOP_LIMB_BITS * 2 - 2, + ) + } else { + (data[0].clone(), MEMCPY_LOOP_LIMB_BITS * 2) + }; + self.range_bus + .range_check(data_0, num_bits) + .eval(builder, local.is_valid); + self.range_bus + .range_check( + data[1].clone(), + self.pointer_max_bits - MEMCPY_LOOP_LIMB_BITS * 2, + ) + .eval(builder, local.is_valid); + }); + + // Send message to memcpy call bus + self.memcpy_bus + .send( + timestamp + AB::Expr::from_canonical_usize(timestamp_delta), + dest - AB::Expr::from_canonical_u32(16), + source + - select::( + is_shift_non_zero.clone(), + AB::Expr::from_canonical_u32(28), + AB::Expr::from_canonical_u32(16), + ), + len.clone() - shift.clone(), + shift.clone(), + ) + .eval(builder, local.is_valid); + + // Receive message from memcpy return bus + self.memcpy_bus + .receive( + local.to_timestamp, + to_dest, + to_source, + to_len - shift.clone(), + AB::Expr::from_canonical_u32(4), + ) + .eval(builder, local.is_valid); + + // Make sure the request and response match + builder.assert_eq( + local.to_timestamp - (timestamp + AB::Expr::from_canonical_usize(timestamp_delta)), + AB::Expr::TWO * (len.clone() - to_len) + is_shift_non_zero.clone(), + ); + + // Execution bus + program bus + self.execution_bridge + .execute_and_increment_pc( + AB::Expr::from_canonical_usize(Rv32MemcpyOpcode::MEMCPY_LOOP as usize), + [shift.clone()], + local.from_state, + local.to_timestamp, + ) + .eval(builder, local.is_valid); + } +} + +#[derive(derive_new::new, Clone, Copy)] +pub struct MemcpyLoopExecutor {} + +#[repr(C)] +#[derive(AlignedBytesBorrow, Debug)] +pub struct MemcpyLoopRecord { + pub shift: [u8; 2], + pub dest: [u8; MEMCPY_LOOP_NUM_LIMBS], + pub source: [u8; MEMCPY_LOOP_NUM_LIMBS], + pub len: [u8; MEMCPY_LOOP_NUM_LIMBS], + pub from_pc: u32, + pub from_timestamp: u32, + pub register_aux: [MemoryBaseAuxRecord; 3], + pub memory_read_data: Vec<[u8; MEMCPY_LOOP_NUM_LIMBS]>, + pub read_aux: Vec, + pub write_aux: Vec>, +} + +#[derive(derive_new::new)] +pub struct MemcpyLoopFiller { + pub pointer_max_bits: usize, + pub range_checker_chip: SharedVariableRangeCheckerChip, + pub memcpy_iter_chip: Arc, +} + +pub type MemcpyLoopChip = VmChipWrapper; + +#[derive(AlignedBytesBorrow, Clone)] +#[repr(C)] +struct MemcpyLoopPreCompute { + a: u8, +} + +impl PreflightExecutor for MemcpyLoopExecutor +where + F: PrimeField32, + for<'buf> RA: RecordArena<'buf, EmptyMultiRowLayout, &'buf mut MemcpyLoopRecord>, +{ + fn get_opcode_name(&self, _: usize) -> String { + format!("{:?}", Rv32MemcpyOpcode::MEMCPY_LOOP) + } + + fn execute( + &mut self, + state: VmStateMut, + instruction: &Instruction, + ) -> Result<(), ExecutionError> { + let Instruction { opcode, a, .. } = instruction; + debug_assert_eq!(*opcode, Rv32MemcpyOpcode::MEMCPY_LOOP.global_opcode()); + let shift = a.as_canonical_u32() as u8; + debug_assert!([0, 1, 2, 3].contains(&shift)); + let mut record = state.ctx.alloc(EmptyMultiRowLayout::default()); + + let mut dest = read_rv32_register( + state.memory.data(), + if shift == 0 { + A3_REGISTER_PTR + } else { + A1_REGISTER_PTR + } as u32, + ); + let mut source = read_rv32_register( + state.memory.data(), + if shift == 0 { + A4_REGISTER_PTR + } else { + A3_REGISTER_PTR + } as u32, + ); + let mut len = read_rv32_register(state.memory.data(), A2_REGISTER_PTR as u32); + + // Store the original values in the record + record.shift = [shift % 2, shift / 2]; + record.from_pc = *state.pc; + record.from_timestamp = state.memory.timestamp; + + let num_iterations = (len - shift as u32) & !15; + let to_dest = dest + num_iterations; + let to_source = source + num_iterations; + let to_len = len - num_iterations; + + tracing_write( + state.memory, + RV32_REGISTER_AS, + if shift == 0 { + A3_REGISTER_PTR + } else { + A1_REGISTER_PTR + } as u32, + to_dest.to_le_bytes(), + &mut record.register_aux[0].prev_timestamp, + &mut record.dest, + ); + + tracing_write( + state.memory, + RV32_REGISTER_AS, + if shift == 0 { + A4_REGISTER_PTR + } else { + A3_REGISTER_PTR + } as u32, + to_source.to_le_bytes(), + &mut record.register_aux[1].prev_timestamp, + &mut record.source, + ); + + tracing_write( + state.memory, + RV32_REGISTER_AS, + A2_REGISTER_PTR as u32, + to_len.to_le_bytes(), + &mut record.register_aux[2].prev_timestamp, + &mut record.len, + ); + + let mut prev_data = if shift == 0 { + [0; 4] + } else { + source -= 12; + record + .read_aux + .push(MemoryReadAuxRecord { prev_timestamp: 0 }); + let data = tracing_read( + state.memory, + RV32_MEMORY_AS, + source - 4, + &mut record.read_aux.last_mut().unwrap().prev_timestamp, + ); + record.memory_read_data.push(data); + data + }; + + while len - shift as u32 > 15 { + let writes_data: [[u8; MEMCPY_LOOP_NUM_LIMBS]; 4] = array::from_fn(|i| { + record + .read_aux + .push(MemoryReadAuxRecord { prev_timestamp: 0 }); + let data = tracing_read( + state.memory, + RV32_MEMORY_AS, + source + 4 * i as u32, + &mut record.read_aux.last_mut().unwrap().prev_timestamp, + ); + record.memory_read_data.push(data); + let write_data: [u8; MEMCPY_LOOP_NUM_LIMBS] = array::from_fn(|i| { + if i < 4 - shift as usize { + data[i + shift as usize] + } else { + prev_data[i - (4 - shift as usize)] + } + }); + prev_data = data; + write_data + }); + writes_data.iter().enumerate().for_each(|(i, write_data)| { + record.write_aux.push(MemoryWriteBytesAuxRecord { + prev_timestamp: 0, + prev_data: [0; MEMCPY_LOOP_NUM_LIMBS], + }); + tracing_write( + state.memory, + RV32_MEMORY_AS, + dest + 4 * i as u32, + *write_data, + &mut record.write_aux.clone().last_mut().unwrap().prev_timestamp, + &mut record.write_aux.clone().last_mut().unwrap().prev_data, + ); + }); + len -= 16; + source += 16; + dest += 16; + } + + *state.pc = state.pc.wrapping_add(DEFAULT_PC_STEP); + + Ok(()) + } +} + +impl TraceFiller for MemcpyLoopFiller { + fn fill_trace_row(&self, mem_helper: &MemoryAuxColsFactory, mut row: &mut [F]) { + let record: &MemcpyLoopRecord = unsafe { get_record_from_slice(&mut row, ()) }; + let row: &mut MemcpyLoopCols = row.borrow_mut(); + + const NUM_WRITES: u32 = 3; + + let shift = record.shift[0] + record.shift[1] * 2; + let dest = u32::from_le_bytes(record.dest); + let source = u32::from_le_bytes(record.source); + let len = u32::from_le_bytes(record.len); + let num_copies = (len - shift as u32) & !15; + let to_dest = dest + num_copies; + let to_source = source + num_copies; + let to_len = len - num_copies; + let timestamp = record.from_timestamp; + + let source_minus_twelve_carry = if shift == 0 { + F::ZERO + } else { + F::from_canonical_u8((source % (1 << 8) < 12) as u8) + }; + let to_source_minus_twelve_carry = if shift == 0 { + F::ZERO + } else { + F::from_canonical_u8((to_source % (1 << 8) < 12) as u8) + }; + + for ((i, cols), register_aux_record) in row + .write_aux + .iter_mut() + .enumerate() + .zip(record.register_aux.iter()) + { + mem_helper.fill( + register_aux_record.prev_timestamp, + timestamp + i as u32, + cols, + ); + } + + row.source_minus_twelve_carry = source_minus_twelve_carry; + row.to_source_minus_twelve_carry = to_source_minus_twelve_carry; + row.to_dest = to_dest.to_le_bytes().map(F::from_canonical_u8); + row.to_source = to_source.to_le_bytes().map(F::from_canonical_u8); + row.to_len = F::from_canonical_u32(to_len); + row.to_timestamp = + F::from_canonical_u32(timestamp + NUM_WRITES + 2 * num_copies + (shift != 0) as u32); + row.is_valid = F::ONE; + row.dest = record.dest.map(F::from_canonical_u8); + row.source = record.source.map(F::from_canonical_u8); + row.len = record.len.map(F::from_canonical_u8); + row.shift = record.shift.map(F::from_canonical_u8); + row.from_state.timestamp = F::from_canonical_u32(record.from_timestamp); + row.from_state.pc = F::from_canonical_u32(record.from_pc); + + let word_to_u16 = |data: u32| [data & 0xffff, data >> 16]; + let range_check_data = [ + (word_to_u16(len), false), + (word_to_u16(dest), true), + (word_to_u16(source - 12 * (shift != 0) as u32), true), + (word_to_u16(to_dest), true), + (word_to_u16(to_source - 12 * (shift != 0) as u32), true), + ]; + + range_check_data.iter().for_each(|(data, is_address)| { + if *is_address { + self.range_checker_chip + .add_count(data[0] >> 2, 2 * MEMCPY_LOOP_LIMB_BITS - 2) + } else { + self.range_checker_chip + .add_count(data[0], 2 * MEMCPY_LOOP_LIMB_BITS) + }; + self.range_checker_chip + .add_count(data[1], self.pointer_max_bits - 2 * MEMCPY_LOOP_LIMB_BITS); + }); + + // Handle MemcpyIter + self.memcpy_iter_chip.add_new_loop( + mem_helper, + timestamp + NUM_WRITES, + dest - 16, + source - 16 - 12 * (shift != 0) as u32, + len - shift as u32, + shift, + record.memory_read_data.clone(), + record.read_aux.clone(), + record.write_aux.clone(), + ); + } +} + +impl Executor for MemcpyLoopExecutor { + fn pre_compute_size(&self) -> usize { + size_of::() + } + + fn pre_compute( + &self, + pc: u32, + inst: &Instruction, + data: &mut [u8], + ) -> Result, StaticProgramError> + where + Ctx: E1ExecutionCtx, + { + let data: &mut MemcpyLoopPreCompute = data.borrow_mut(); + self.pre_compute_impl(pc, inst, data)?; + Ok(execute_e1_impl::<_, _>) + } +} + +impl MeteredExecutor for MemcpyLoopExecutor { + fn metered_pre_compute_size(&self) -> usize { + size_of::>() + } + + fn metered_pre_compute( + &self, + chip_idx: usize, + pc: u32, + inst: &Instruction, + data: &mut [u8], + ) -> Result, StaticProgramError> + where + Ctx: E2ExecutionCtx, + { + let data: &mut E2PreCompute = data.borrow_mut(); + data.chip_idx = chip_idx as u32; + self.pre_compute_impl(pc, inst, &mut data.data)?; + Ok(execute_e2_impl::<_, _>) + } +} + +#[inline(always)] +unsafe fn execute_e12_impl( + pre_compute: &MemcpyLoopPreCompute, + vm_state: &mut VmExecState, +) { + let shift = pre_compute.a; + let (dest, source) = if shift == 0 { + ( + vm_state.vm_read::(RV32_REGISTER_AS, A3_REGISTER_PTR as u32), + vm_state.vm_read::(RV32_REGISTER_AS, A4_REGISTER_PTR as u32), + ) + } else { + ( + vm_state.vm_read::(RV32_REGISTER_AS, A1_REGISTER_PTR as u32), + vm_state.vm_read::(RV32_REGISTER_AS, A3_REGISTER_PTR as u32), + ) + }; + let len = vm_state.vm_read::(RV32_REGISTER_AS, A2_REGISTER_PTR as u32); + + let mut dest = u32::from_le_bytes(dest); + let mut source = u32::from_le_bytes(source); + let mut len = u32::from_le_bytes(len); + + let mut prev_data = if shift == 0 { + [0; 4] + } else { + source -= 12; + vm_state.vm_read::(RV32_MEMORY_AS, source - 4) + }; + + while len - shift as u32 > 15 { + for i in 0..4 { + let data = vm_state.vm_read::(RV32_MEMORY_AS, source + 4 * i); + let write_data: [u8; 4] = array::from_fn(|i| { + if i < 4 - shift as usize { + data[i + shift as usize] + } else { + prev_data[i - (4 - shift as usize)] + } + }); + vm_state.vm_write(RV32_MEMORY_AS, dest + 4 * i, &write_data); + prev_data = data; + } + len -= 16; + source += 16; + dest += 16; + } + + // Write the result back to memory + if shift == 0 { + vm_state.vm_write( + RV32_REGISTER_AS, + A3_REGISTER_PTR as u32, + &dest.to_le_bytes(), + ); + vm_state.vm_write( + RV32_REGISTER_AS, + A4_REGISTER_PTR as u32, + &source.to_le_bytes(), + ); + } else { + source += 12; + vm_state.vm_write( + RV32_REGISTER_AS, + A1_REGISTER_PTR as u32, + &dest.to_le_bytes(), + ); + vm_state.vm_write( + RV32_REGISTER_AS, + A3_REGISTER_PTR as u32, + &source.to_le_bytes(), + ); + }; + vm_state.vm_write(RV32_REGISTER_AS, A2_REGISTER_PTR as u32, &len.to_le_bytes()); + + vm_state.pc = vm_state.pc.wrapping_add(DEFAULT_PC_STEP); + vm_state.instret += 1; +} + +unsafe fn execute_e1_impl( + pre_compute: &[u8], + vm_state: &mut VmExecState, +) { + let pre_compute: &MemcpyLoopPreCompute = pre_compute.borrow(); + execute_e12_impl::(pre_compute, vm_state); +} + +unsafe fn execute_e2_impl( + pre_compute: &[u8], + vm_state: &mut VmExecState, +) { + let pre_compute: &E2PreCompute = pre_compute.borrow(); + vm_state + .ctx + .on_height_change(pre_compute.chip_idx as usize, 1); + execute_e12_impl::(&pre_compute.data, vm_state); +} + +impl MemcpyLoopExecutor { + fn pre_compute_impl( + &self, + pc: u32, + inst: &Instruction, + data: &mut MemcpyLoopPreCompute, + ) -> Result<(), StaticProgramError> { + let Instruction { opcode, a, .. } = inst; + let a_u32 = a.as_canonical_u32(); + if ![0, 1, 2, 3].contains(&a_u32) { + return Err(StaticProgramError::InvalidInstruction(pc)); + } + *data = MemcpyLoopPreCompute { a: a_u32 as u8 }; + assert_eq!(*opcode, Rv32MemcpyOpcode::MEMCPY_LOOP.global_opcode()); + Ok(()) + } +} diff --git a/extensions/memcpy/circuit/src/extension.rs b/extensions/memcpy/circuit/src/extension.rs new file mode 100644 index 0000000000..d485618a50 --- /dev/null +++ b/extensions/memcpy/circuit/src/extension.rs @@ -0,0 +1,140 @@ +use std::{result::Result, sync::Arc}; + +use bus::MemcpyBus; +use derive_more::derive::From; +use openvm_circuit::{ + arch::{ + AirInventory, AirInventoryError, ChipInventory, ChipInventoryError, ExecutionBridge, + ExecutorInventoryBuilder, ExecutorInventoryError, RowMajorMatrixArena, VmCircuitExtension, + VmExecutionExtension, VmProverExtension, + }, + system::{memory::SharedMemoryHelper, SystemPort}, +}; +use openvm_circuit_derive::{AnyEnum, Executor, MeteredExecutor, PreflightExecutor}; +use openvm_instructions::*; +use openvm_stark_backend::{ + config::{StarkGenericConfig, Val}, + engine::StarkEngine, + p3_field::PrimeField32, + prover::cpu::{CpuBackend, CpuDevice}, +}; +use serde::{Deserialize, Serialize}; +use strum::IntoEnumIterator; + +use crate::*; + +use openvm_memcpy_transpiler::Rv32MemcpyOpcode; + +// =================================== VM Extension Implementation ================================= +#[derive(Clone, Copy, Debug, Default, Serialize, Deserialize)] +pub struct Memcpy; + +#[derive(Clone, From, AnyEnum, Executor, MeteredExecutor, PreflightExecutor)] +pub enum MemcpyExecutor { + MemcpyLoop(MemcpyLoopExecutor), +} + +impl VmExecutionExtension for Memcpy { + type Executor = MemcpyExecutor; + + fn extend_execution( + &self, + inventory: &mut ExecutorInventoryBuilder, + ) -> Result<(), ExecutorInventoryError> { + let memcpy_loop = MemcpyLoopExecutor::new(); + + inventory.add_executor( + memcpy_loop, + Rv32MemcpyOpcode::iter().map(|x| x.global_opcode()), + )?; + + Ok(()) + } +} + +impl VmCircuitExtension for Memcpy { + fn extend_circuit(&self, inventory: &mut AirInventory) -> Result<(), AirInventoryError> { + let SystemPort { + execution_bus, + program_bus, + memory_bridge, + } = inventory.system().port(); + + let execution_bridge = ExecutionBridge::new(execution_bus, program_bus); + let range_bus = inventory.range_checker().bus; + let pointer_max_bits = inventory.pointer_max_bits(); + + let memcpy_bus = MemcpyBus::new(inventory.new_bus_idx()); + + let memcpy_loop = MemcpyLoopAir::new( + memory_bridge, + execution_bridge, + range_bus, + memcpy_bus, + pointer_max_bits, + ); + inventory.add_air(memcpy_loop); + + let memcpy_iter = + MemcpyIterAir::new(memory_bridge, range_bus, memcpy_bus, pointer_max_bits); + inventory.add_air(memcpy_iter); + + Ok(()) + } +} + +pub struct MemcpyCpuProverExt; +// This implementation is specific to CpuBackend because the lookup chips (VariableRangeChecker) +// are specific to CpuBackend. +impl VmProverExtension for MemcpyCpuProverExt +where + SC: StarkGenericConfig, + E: StarkEngine, PD = CpuDevice>, + RA: RowMajorMatrixArena>, + Val: PrimeField32, +{ + fn extend_prover( + &self, + _: &Memcpy, + inventory: &mut ChipInventory>, + ) -> Result<(), ChipInventoryError> { + let range_checker = inventory.range_checker()?.clone(); + let timestamp_max_bits = inventory.timestamp_max_bits(); + let pointer_max_bits = inventory.airs().pointer_max_bits(); + let mem_helper = SharedMemoryHelper::new(range_checker.clone(), timestamp_max_bits); + let range_bus = inventory.airs().range_checker().bus; + let memcpy_bus = inventory + .airs() + .find_air::() + .next() + .unwrap() + .memcpy_bus; + + let memcpy_iter_chip = Arc::new(MemcpyIterChip::new( + inventory.airs().system().port().memory_bridge, + range_bus, + memcpy_bus, + pointer_max_bits, + range_checker.clone(), + )); + + let memcpy_loop_chip = MemcpyLoopChip::new( + MemcpyLoopFiller::new( + pointer_max_bits, + range_checker.clone(), + memcpy_iter_chip.clone(), + ), + mem_helper.clone(), + ); + + // Add MemcpyLoop chip + inventory.next_air::()?; + inventory.add_executor_chip(memcpy_loop_chip); + + // Add MemcpyIter chip + inventory.next_air::()?; + inventory.add_periphery_chip(memcpy_iter_chip); + + Ok(()) + } +} diff --git a/extensions/memcpy/circuit/src/iteration.rs b/extensions/memcpy/circuit/src/iteration.rs new file mode 100644 index 0000000000..6191b061bd --- /dev/null +++ b/extensions/memcpy/circuit/src/iteration.rs @@ -0,0 +1,499 @@ +use std::{ + array, + borrow::{Borrow, BorrowMut}, + mem::size_of, + sync::{atomic::AtomicU32, Arc, Mutex}, +}; + +use openvm_circuit::system::memory::{ + offline_checker::{ + MemoryBaseAuxCols, MemoryBridge, MemoryExtendedAuxRecord, MemoryReadAuxCols, + MemoryReadAuxRecord, MemoryWriteAuxCols, MemoryWriteBytesAuxRecord, + }, + MemoryAddress, MemoryAuxColsFactory, +}; +use openvm_circuit_primitives::{ + utils::{and, not, or, select}, + var_range::{SharedVariableRangeCheckerChip, VariableRangeCheckerBus}, + AlignedBytesBorrow, +}; +use openvm_circuit_primitives_derive::AlignedBorrow; +use openvm_instructions::riscv::RV32_MEMORY_AS; +use openvm_stark_backend::{ + config::{StarkGenericConfig, Val}, + interaction::InteractionBuilder, + p3_air::{Air, AirBuilder, BaseAir}, + p3_field::{Field, FieldAlgebra, PrimeField32}, + p3_matrix::{dense::RowMajorMatrix, Matrix}, + prover::{cpu::CpuBackend, types::AirProvingContext}, + rap::{get_air_name, BaseAirWithPublicValues, PartitionedBaseAir}, + Chip, ChipUsageGetter, +}; + +use crate::bus::MemcpyBus; + +// Import constants from lib.rs +use crate::{MEMCPY_LOOP_LIMB_BITS, MEMCPY_LOOP_NUM_LIMBS}; + +#[repr(C)] +#[derive(AlignedBorrow, Debug)] +pub struct MemcpyIterCols { + pub timestamp: T, + pub dest: T, + pub source: T, + pub len: [T; 2], + pub shift: [T; 2], + pub is_valid: T, + pub is_valid_not_start: T, + // -1 for the first iteration, 1 for the last iteration, 0 for the middle iterations + pub is_boundary: T, + pub data_1: [T; MEMCPY_LOOP_NUM_LIMBS], + pub data_2: [T; MEMCPY_LOOP_NUM_LIMBS], + pub data_3: [T; MEMCPY_LOOP_NUM_LIMBS], + pub data_4: [T; MEMCPY_LOOP_NUM_LIMBS], + pub read_aux: [MemoryReadAuxCols; 4], + pub write_aux: [MemoryWriteAuxCols; 4], +} + +pub const NUM_MEMCPY_ITER_COLS: usize = size_of::>(); + +#[derive(Copy, Clone, Debug, derive_new::new)] +pub struct MemcpyIterAir { + pub memory_bridge: MemoryBridge, + pub range_bus: VariableRangeCheckerBus, + pub memcpy_bus: MemcpyBus, + pub pointer_max_bits: usize, +} + +impl BaseAir for MemcpyIterAir { + fn width(&self) -> usize { + MemcpyIterCols::::width() + } +} + +impl BaseAirWithPublicValues for MemcpyIterAir {} +impl PartitionedBaseAir for MemcpyIterAir {} + +impl Air for MemcpyIterAir { + fn eval(&self, builder: &mut AB) { + let main = builder.main(); + let (prev, local) = (main.row_slice(0), main.row_slice(1)); + let prev: &MemcpyIterCols = (*prev).borrow(); + let local: &MemcpyIterCols = (*local).borrow(); + + let timestamp: AB::Var = local.timestamp; + let mut timestamp_delta: usize = 0; + let mut timestamp_pp = || { + timestamp_delta += 1; + timestamp + AB::Expr::from_canonical_usize(timestamp_delta - 1) + }; + + let shift = local.shift[0] * AB::Expr::TWO + local.shift[1]; + let is_shift_non_zero = or::(local.shift[0], local.shift[1]); + let is_shift_zero = not::(is_shift_non_zero.clone()); + let is_shift_one = and::(local.shift[0], not::(local.shift[1])); + let is_shift_two = and::(not::(local.shift[0]), local.shift[1]); + let is_shift_three = and::(local.shift[0], local.shift[1]); + + // TODO:since if is_valid = 0, then is_boundary = 0, we can reduce the degree of the following expressions by removing the is_valid term + let is_end = + (local.is_boundary + AB::Expr::ONE) * local.is_boundary * (AB::F::TWO).inverse(); + let is_not_start = (local.is_boundary + AB::Expr::ONE) + * (AB::Expr::TWO - local.is_boundary) + * (AB::F::TWO).inverse(); + + let len = local.len[0] + + local.len[1] * AB::F::from_canonical_u32(1 << (2 * MEMCPY_LOOP_LIMB_BITS)); + + // write_data = + // (local.data_1[shift..4], prev.data_4[0..shift]), + // (local.data_2[shift..4], local.data_1[0..shift]), + // (local.data_3[shift..4], local.data_2[0..shift]), + // (local.data_4[shift..4], local.data_3[0..shift]) + let write_data_pairs = [ + (prev.data_4, local.data_1), + (local.data_1, local.data_2), + (local.data_2, local.data_3), + (local.data_3, local.data_4), + ]; + + let write_data = write_data_pairs + .iter() + .map(|(prev_data, next_data)| { + array::from_fn(|i| { + is_shift_zero.clone() * (next_data[i]) + + is_shift_one.clone() + * (if i < 3 { + next_data[i + 1] + } else { + prev_data[i - 3] + }) + + is_shift_two.clone() + * (if i < 2 { + next_data[i + 2] + } else { + prev_data[i - 2] + }) + + is_shift_three.clone() + * (if i < 1 { + next_data[i + 3] + } else { + prev_data[i - 1] + }) + }) + }) + .collect::>(); + + builder.assert_bool(local.is_valid); + for i in 0..2 { + builder.assert_bool(local.shift[i]); + } + builder.assert_bool(local.is_valid_not_start); + // is_boundary is either -1, 0 or 1 + builder.assert_tern(local.is_boundary + AB::Expr::ONE); + + // is_valid_not_start = is_valid and is_not_start: + builder.assert_eq(local.is_valid_not_start, local.is_valid * is_not_start); + + // if is_valid = 0, then is_boundary = 0, shift = 0 + let mut is_not_valid_when = builder.when(not::(local.is_valid)); + is_not_valid_when.assert_zero(local.is_boundary); + is_not_valid_when.assert_zero(shift.clone()); + + // if is_valid_not_start = 1, then len = prev_len - 16, source = prev_source + 16, dest = prev_dest + 16 + let mut is_valid_not_start_when = builder.when(local.is_valid_not_start); + is_valid_not_start_when + .assert_eq(local.len[0], prev.len[0] - AB::Expr::from_canonical_u32(16)); + is_valid_not_start_when + .assert_eq(local.source, prev.source + AB::Expr::from_canonical_u32(16)); + is_valid_not_start_when.assert_eq(local.dest, prev.dest + AB::Expr::from_canonical_u32(16)); + + // if prev.is_valid_start, then timestamp = prev_timestamp + is_shift_non_zero + // since is_shift_non_zero degree is 2, we need to keep the degree of the condition to 1 + builder + .when(not::(prev.is_valid_not_start) - not::(prev.is_valid)) + .assert_eq(local.timestamp, prev.timestamp + is_shift_non_zero.clone()); + // if prev.is_valid_not_start and local.is_valid_not_start, then timestamp = prev_timestamp + 8 + // prev.is_valid_not_start is the opposite of previous condition + builder + .when( + local.is_valid_not_start + - (not::(prev.is_valid_not_start) - not::(prev.is_valid)), + ) + .assert_eq( + local.timestamp, + prev.timestamp + AB::Expr::from_canonical_usize(8), + ); + + // Receive message from memcpy bus or send message to it + // The last data is shift if is_boundary = -1, and 4 if is_boundary = 1 + // This actually receives when is_boundary = -1 + self.memcpy_bus + .send( + local.timestamp, + local.dest, + local.source, + len, + (AB::Expr::ONE - local.is_boundary) * shift.clone() * (AB::F::TWO).inverse() + + (local.is_boundary + AB::Expr::ONE) * AB::Expr::TWO, + ) + .eval(builder, local.is_boundary); + + // Read data from memory + let read_data = [ + (local.data_1, local.read_aux[0]), + (local.data_2, local.read_aux[1]), + (local.data_3, local.read_aux[2]), + (local.data_4, local.read_aux[3]), + ]; + + read_data + .iter() + .enumerate() + .for_each(|(idx, (data, read_aux))| { + let is_valid_read = if idx == 3 { + or::(is_shift_non_zero.clone(), local.is_valid_not_start) + } else { + local.is_valid_not_start.into() + }; + + self.memory_bridge + .read( + MemoryAddress::new( + AB::Expr::from_canonical_u32(RV32_MEMORY_AS), + local.source + AB::Expr::from_canonical_usize(idx * 4), + ), + *data, + timestamp_pp(), + read_aux, + ) + .eval(builder, is_valid_read); + }); + + // Write final data to registers + write_data.iter().enumerate().for_each(|(idx, data)| { + self.memory_bridge + .write( + MemoryAddress::new( + AB::Expr::from_canonical_u32(RV32_MEMORY_AS), + local.dest + AB::Expr::from_canonical_usize(idx * 4), + ), + data.clone(), + timestamp_pp(), + &local.write_aux[idx], + ) + .eval(builder, local.is_valid_not_start); + }); + + // Range check len + let len_bits_limit = [ + select::( + is_end.clone(), + AB::Expr::from_canonical_usize(4), + AB::Expr::from_canonical_usize(MEMCPY_LOOP_LIMB_BITS * 2), + ), + select::( + is_end.clone(), + AB::Expr::ZERO, + AB::Expr::from_canonical_usize(self.pointer_max_bits - MEMCPY_LOOP_LIMB_BITS * 2), + ), + ]; + self.range_bus + .push(local.len[0], len_bits_limit[0].clone(), true) + .eval(builder, local.is_valid); + self.range_bus + .push(local.len[1], len_bits_limit[1].clone(), true) + .eval(builder, local.is_valid); + } +} + +#[repr(C)] +#[derive(AlignedBytesBorrow, Debug)] +pub struct MemcpyIterRecord { + pub timestamp: u32, + pub dest: u32, + pub source: u32, + pub len: u32, + pub shift: u8, + pub memory_read_data: Vec<[u8; MEMCPY_LOOP_NUM_LIMBS]>, + pub read_aux: Vec, + pub write_aux: Vec, +} + +pub struct MemcpyIterChip { + pub air: MemcpyIterAir, + pub records: Arc>>, + pub num_rows: AtomicU32, + pub pointer_max_bits: usize, + pub range_checker_chip: SharedVariableRangeCheckerChip, +} + +impl MemcpyIterChip { + pub fn new( + memory_bridge: MemoryBridge, + range_bus: VariableRangeCheckerBus, + memcpy_bus: MemcpyBus, + pointer_max_bits: usize, + range_checker_chip: SharedVariableRangeCheckerChip, + ) -> Self { + Self { + air: MemcpyIterAir::new(memory_bridge, range_bus, memcpy_bus, pointer_max_bits), + records: Arc::new(Mutex::new(Vec::new())), + num_rows: AtomicU32::new(0), + pointer_max_bits, + range_checker_chip, + } + } + + pub fn bus(&self) -> MemcpyBus { + self.air.memcpy_bus + } + + pub fn clear(&self) { + self.records.lock().unwrap().clear(); + self.num_rows.store(0, std::sync::atomic::Ordering::Relaxed); + } + + pub fn add_new_loop( + &self, + mem_helper: &MemoryAuxColsFactory, + timestamp: u32, + dest: u32, + source: u32, + len: u32, + shift: u8, + memory_read_data: Vec<[u8; MEMCPY_LOOP_NUM_LIMBS]>, + read_aux: Vec, + write_aux: Vec>, + ) { + let mut len = len; + // Update number of rows + self.num_rows + .fetch_add(len / 16 + 1, std::sync::atomic::Ordering::Relaxed); + + let word_to_u16 = |data: u32| [data & 0xffff, data >> 16]; + let has_shift = (shift != 0) as usize; + + // Range check len + loop { + let len_u16_limbs = word_to_u16(len); + if len > 15 { + self.range_checker_chip + .add_count(len_u16_limbs[0], 2 * MEMCPY_LOOP_LIMB_BITS); + self.range_checker_chip.add_count( + len_u16_limbs[1], + self.pointer_max_bits - 2 * MEMCPY_LOOP_LIMB_BITS, + ); + } else { + self.range_checker_chip.add_count(len_u16_limbs[0], 4); + self.range_checker_chip.add_count(len_u16_limbs[1], 0); + } + if len < 16 { + break; + } + len -= 16; + } + + // Read data from memory + let mut row_read_aux = Vec::new(); + read_aux.iter().enumerate().for_each(|(i, aux)| { + let mut aux_cols = MemoryBaseAuxCols::::default(); + let read_timestamp = timestamp + + if i == 0 { + 0 + } else { + (i + (i - has_shift) / 4 * 4) as u32 + }; + mem_helper.fill(aux.prev_timestamp, read_timestamp, &mut aux_cols); + row_read_aux.push(MemoryExtendedAuxRecord::from_aux_cols(aux_cols)); + }); + + // Write data to memory + let mut row_write_aux = Vec::new(); + write_aux.iter().enumerate().for_each(|(i, aux)| { + let mut aux_cols = MemoryBaseAuxCols::::default(); + mem_helper.fill( + aux.prev_timestamp, + (timestamp as usize + i + has_shift + (i / 4 + 1) * 4) as u32, + &mut aux_cols, + ); + row_write_aux.push(MemoryExtendedAuxRecord::from_aux_cols(aux_cols)); + }); + + // Create record + let row = MemcpyIterRecord { + timestamp, + dest, + source, + len, + shift, + memory_read_data, + read_aux: row_read_aux, + write_aux: row_write_aux, + }; + + // Thread-safe push to rows vector + if let Ok(mut rows_guard) = self.records.lock() { + rows_guard.push(row); + } + } + + /// Generates trace + pub fn generate_trace(&self) -> RowMajorMatrix { + let mut rows = F::zero_vec( + (self.num_rows.load(std::sync::atomic::Ordering::Relaxed) as usize) + * NUM_MEMCPY_ITER_COLS, + ); + let mut current_row = 0; + let word_to_u16 = |data: u32| [data & 0xffff, data >> 16].map(F::from_canonical_u32); + + for record in self.records.lock().unwrap().iter() { + let mut timestamp = record.timestamp; + let shift = [record.shift % 2, record.shift / 2].map(F::from_canonical_u8); + let has_shift = (record.shift != 0) as usize; + let mut prev_data = [F::ZERO; MEMCPY_LOOP_NUM_LIMBS]; + + for n in 0..(record.len / 16 + 1) as usize { + let row_start = current_row + n * NUM_MEMCPY_ITER_COLS; + let row = &mut rows[row_start..row_start + NUM_MEMCPY_ITER_COLS]; + let cols: &mut MemcpyIterCols = row.borrow_mut(); + cols.timestamp = F::from_canonical_u32(timestamp); + cols.dest = F::from_canonical_u32(record.dest + (n << 2) as u32); + cols.source = F::from_canonical_u32(record.source + (n << 2) as u32); + cols.len = word_to_u16(record.len - (n << 2) as u32); + cols.shift = shift; + cols.is_valid = F::ONE; + cols.is_valid_not_start = F::ONE; + if n == 0 { + cols.is_boundary = F::NEG_ONE; + if has_shift != 0 { + cols.data_4 = record.memory_read_data[0].map(F::from_canonical_u8); + prev_data = cols.data_4; + cols.read_aux[3].set_base(record.read_aux[0].to_aux_cols()); + } + } else { + cols.is_boundary = if n as u32 == record.len / 16 { + F::ONE + } else { + F::ZERO + }; + let mut data = [[F::ZERO; MEMCPY_LOOP_NUM_LIMBS]; 4]; + for i in 0..4 { + data[i] = record.memory_read_data[(n - 1) * 4 + i + has_shift] + .map(F::from_canonical_u8); + cols.read_aux[i] + .set_base(record.read_aux[(n - 1) * 4 + i + has_shift].to_aux_cols()); + cols.write_aux[i].set_base(record.write_aux[(n - 1) * 4 + i].to_aux_cols()); + let write_data: [F; MEMCPY_LOOP_NUM_LIMBS] = std::array::from_fn(|j| { + if j < 4 - record.shift as usize { + data[i][record.shift as usize + j] + } else { + prev_data[j - (4 - record.shift as usize)] + } + }); + cols.write_aux[i].set_prev_data(write_data); + prev_data = data[i]; + } + cols.data_1 = data[0]; + cols.data_2 = data[1]; + cols.data_3 = data[2]; + cols.data_4 = data[3]; + } + if n == 0 { + timestamp += (record.shift != 0) as u32; + } else { + timestamp += 8; + } + } + current_row += (record.len / 16 + 1) as usize * NUM_MEMCPY_ITER_COLS; + } + RowMajorMatrix::new(rows, NUM_MEMCPY_ITER_COLS) + } +} + +// We allow any `R` type so this can work with arbitrary record arenas. +impl Chip> for MemcpyIterChip +where + Val: PrimeField32, +{ + /// Generates trace and resets the internal counters all to 0. + fn generate_proving_ctx(&self, _: R) -> AirProvingContext> { + let trace = self.generate_trace::>(); + AirProvingContext::simple_no_pis(Arc::new(trace)) + } +} + +impl ChipUsageGetter for MemcpyIterChip { + fn air_name(&self) -> String { + get_air_name(&self.air) + } + fn constant_trace_height(&self) -> Option { + Some(self.num_rows.load(std::sync::atomic::Ordering::Relaxed) as usize) + } + fn current_trace_height(&self) -> usize { + self.num_rows.load(std::sync::atomic::Ordering::Relaxed) as usize + } + fn trace_width(&self) -> usize { + NUM_MEMCPY_ITER_COLS + } +} diff --git a/extensions/memcpy/circuit/src/lib.rs b/extensions/memcpy/circuit/src/lib.rs new file mode 100644 index 0000000000..28f660af34 --- /dev/null +++ b/extensions/memcpy/circuit/src/lib.rs @@ -0,0 +1,17 @@ +mod bus; +mod core; +mod extension; +mod iteration; + +pub use core::*; +pub use extension::*; +pub use iteration::*; + +// ==== Do not change these constants! ==== +pub const MEMCPY_LOOP_NUM_LIMBS: usize = 4; +pub const MEMCPY_LOOP_LIMB_BITS: usize = 8; + +pub const A1_REGISTER_PTR: usize = 11 * 4; +pub const A2_REGISTER_PTR: usize = 12 * 4; +pub const A3_REGISTER_PTR: usize = 13 * 4; +pub const A4_REGISTER_PTR: usize = 14 * 4; diff --git a/extensions/memcpy/tests.rs b/extensions/memcpy/tests.rs new file mode 100644 index 0000000000..bc1880953c --- /dev/null +++ b/extensions/memcpy/tests.rs @@ -0,0 +1,444 @@ +use std::{array, borrow::BorrowMut, sync::Arc}; + +use openvm_circuit::arch::testing::{TestChipHarness, VmChipTestBuilder, BITWISE_OP_LOOKUP_BUS}; +use openvm_circuit_primitives::bitwise_op_lookup::{ + BitwiseOperationLookupAir, BitwiseOperationLookupBus, BitwiseOperationLookupChip, + SharedBitwiseOperationLookupChip, +}; +use openvm_instructions::LocalOpcode; +use openvm_rv32im_transpiler::BaseAluOpcode::{self, *}; +use openvm_stark_backend::{ + p3_air::BaseAir, + p3_field::{FieldAlgebra, PrimeField32}, + p3_matrix::{ + dense::{DenseMatrix, RowMajorMatrix}, + Matrix, + }, + utils::disable_debug_builder, +}; +use openvm_stark_sdk::{p3_baby_bear::BabyBear, utils::create_seeded_rng}; +use rand::{rngs::StdRng, Rng}; +use test_case::test_case; + +use super::{core::run_alu, BaseAluCoreAir, Rv32BaseAluChip, Rv32BaseAluExecutor}; +use crate::{ + adapters::{ + Rv32BaseAluAdapterAir, Rv32BaseAluAdapterExecutor, Rv32BaseAluAdapterFiller, + RV32_CELL_BITS, RV32_REGISTER_NUM_LIMBS, + }, + base_alu::BaseAluCoreCols, + test_utils::{ + generate_rv32_is_type_immediate, get_verification_error, rv32_rand_write_register_or_imm, + }, + BaseAluFiller, Rv32BaseAluAir, +}; + +const MAX_INS_CAPACITY: usize = 128; +type F = BabyBear; +type Harness = TestChipHarness>; + +fn create_test_chip( + tester: &VmChipTestBuilder, +) -> ( + Harness, + ( + BitwiseOperationLookupAir, + SharedBitwiseOperationLookupChip, + ), +) { + let bitwise_bus = BitwiseOperationLookupBus::new(BITWISE_OP_LOOKUP_BUS); + let bitwise_chip = Arc::new(BitwiseOperationLookupChip::::new( + bitwise_bus, + )); + + let air = Rv32BaseAluAir::new( + Rv32BaseAluAdapterAir::new( + tester.execution_bridge(), + tester.memory_bridge(), + bitwise_bus, + ), + BaseAluCoreAir::new(bitwise_bus, BaseAluOpcode::CLASS_OFFSET), + ); + let executor = Rv32BaseAluExecutor::new( + Rv32BaseAluAdapterExecutor::new(), + BaseAluOpcode::CLASS_OFFSET, + ); + let chip = Rv32BaseAluChip::new( + BaseAluFiller::new( + Rv32BaseAluAdapterFiller::new(bitwise_chip.clone()), + bitwise_chip.clone(), + BaseAluOpcode::CLASS_OFFSET, + ), + tester.memory_helper(), + ); + let harness = Harness::with_capacity(executor, air, chip, MAX_INS_CAPACITY); + + (harness, (bitwise_chip.air, bitwise_chip)) +} + +fn set_and_execute( + tester: &mut VmChipTestBuilder, + harness: &mut Harness, + rng: &mut StdRng, + opcode: BaseAluOpcode, + b: Option<[u8; RV32_REGISTER_NUM_LIMBS]>, + is_imm: Option, + c: Option<[u8; RV32_REGISTER_NUM_LIMBS]>, +) { + let b = b.unwrap_or(array::from_fn(|_| rng.gen_range(0..=u8::MAX))); + let (c_imm, c) = if is_imm.unwrap_or(rng.gen_bool(0.5)) { + let (imm, c) = if let Some(c) = c { + ((u32::from_le_bytes(c) & 0xFFFFFF) as usize, c) + } else { + generate_rv32_is_type_immediate(rng) + }; + (Some(imm), c) + } else { + ( + None, + c.unwrap_or(array::from_fn(|_| rng.gen_range(0..=u8::MAX))), + ) + }; + + let (instruction, rd) = rv32_rand_write_register_or_imm( + tester, + b, + c, + c_imm, + opcode.global_opcode().as_usize(), + rng, + ); + tester.execute(harness, &instruction); + + let a = run_alu::(opcode, &b, &c) + .map(F::from_canonical_u8); + assert_eq!(a, tester.read::(1, rd)) +} + +////////////////////////////////////////////////////////////////////////////////////// +// POSITIVE TESTS +// +// Randomly generate computations and execute, ensuring that the generated trace +// passes all constraints. +////////////////////////////////////////////////////////////////////////////////////// + +#[test_case(ADD, 100)] +#[test_case(SUB, 100)] +#[test_case(XOR, 100)] +#[test_case(OR, 100)] +#[test_case(AND, 100)] +fn rand_rv32_alu_test(opcode: BaseAluOpcode, num_ops: usize) { + let mut rng = create_seeded_rng(); + + let mut tester = VmChipTestBuilder::default(); + let (mut harness, bitwise) = create_test_chip(&tester); + + // TODO(AG): make a more meaningful test for memory accesses + tester.write(2, 1024, [F::ONE; 4]); + tester.write(2, 1028, [F::ONE; 4]); + let sm = tester.read(2, 1024); + assert_eq!(sm, [F::ONE; 8]); + + for _ in 0..num_ops { + set_and_execute( + &mut tester, + &mut harness, + &mut rng, + opcode, + None, + None, + None, + ); + } + + let tester = tester + .build() + .load(harness) + .load_periphery(bitwise) + .finalize(); + tester.simple_test().expect("Verification failed"); +} + +#[test_case(ADD, 100)] +#[test_case(SUB, 100)] +#[test_case(XOR, 100)] +#[test_case(OR, 100)] +#[test_case(AND, 100)] +fn rand_rv32_alu_test_persistent(opcode: BaseAluOpcode, num_ops: usize) { + let mut rng = create_seeded_rng(); + + let mut tester = VmChipTestBuilder::default_persistent(); + let (mut harness, bitwise) = create_test_chip(&tester); + + // TODO(AG): make a more meaningful test for memory accesses + tester.write(2, 1024, [F::ONE; 4]); + tester.write(2, 1028, [F::ONE; 4]); + let sm = tester.read(2, 1024); + assert_eq!(sm, [F::ONE; 8]); + + for _ in 0..num_ops { + set_and_execute( + &mut tester, + &mut harness, + &mut rng, + opcode, + None, + None, + None, + ); + } + + let tester = tester + .build() + .load(harness) + .load_periphery(bitwise) + .finalize(); + tester.simple_test().expect("Verification failed"); +} + +////////////////////////////////////////////////////////////////////////////////////// +// NEGATIVE TESTS +// +// Given a fake trace of a single operation, setup a chip and run the test. We replace +// part of the trace and check that the chip throws the expected error. +////////////////////////////////////////////////////////////////////////////////////// + +#[allow(clippy::too_many_arguments)] +fn run_negative_alu_test( + opcode: BaseAluOpcode, + prank_a: [u32; RV32_REGISTER_NUM_LIMBS], + b: [u8; RV32_REGISTER_NUM_LIMBS], + c: [u8; RV32_REGISTER_NUM_LIMBS], + prank_c: Option<[u32; RV32_REGISTER_NUM_LIMBS]>, + prank_opcode_flags: Option<[bool; 5]>, + is_imm: Option, + interaction_error: bool, +) { + let mut rng = create_seeded_rng(); + let mut tester: VmChipTestBuilder = VmChipTestBuilder::default(); + let (mut chip, bitwise) = create_test_chip(&tester); + + set_and_execute( + &mut tester, + &mut chip, + &mut rng, + opcode, + Some(b), + is_imm, + Some(c), + ); + + let adapter_width = BaseAir::::width(&chip.air.adapter); + let modify_trace = |trace: &mut DenseMatrix| { + let mut values = trace.row_slice(0).to_vec(); + let cols: &mut BaseAluCoreCols = + values.split_at_mut(adapter_width).1.borrow_mut(); + cols.a = prank_a.map(F::from_canonical_u32); + if let Some(prank_c) = prank_c { + cols.c = prank_c.map(F::from_canonical_u32); + } + if let Some(prank_opcode_flags) = prank_opcode_flags { + cols.opcode_add_flag = F::from_bool(prank_opcode_flags[0]); + cols.opcode_and_flag = F::from_bool(prank_opcode_flags[1]); + cols.opcode_or_flag = F::from_bool(prank_opcode_flags[2]); + cols.opcode_sub_flag = F::from_bool(prank_opcode_flags[3]); + cols.opcode_xor_flag = F::from_bool(prank_opcode_flags[4]); + } + *trace = RowMajorMatrix::new(values, trace.width()); + }; + + disable_debug_builder(); + let tester = tester + .build() + .load_and_prank_trace(chip, modify_trace) + .load_periphery(bitwise) + .finalize(); + tester.simple_test_with_expected_error(get_verification_error(interaction_error)); +} + +#[test] +fn rv32_alu_add_wrong_negative_test() { + run_negative_alu_test( + ADD, + [246, 0, 0, 0], + [250, 0, 0, 0], + [250, 0, 0, 0], + None, + None, + None, + false, + ); +} + +#[test] +fn rv32_alu_add_out_of_range_negative_test() { + run_negative_alu_test( + ADD, + [500, 0, 0, 0], + [250, 0, 0, 0], + [250, 0, 0, 0], + None, + None, + None, + true, + ); +} + +#[test] +fn rv32_alu_sub_wrong_negative_test() { + run_negative_alu_test( + SUB, + [255, 0, 0, 0], + [1, 0, 0, 0], + [2, 0, 0, 0], + None, + None, + None, + false, + ); +} + +#[test] +fn rv32_alu_sub_out_of_range_negative_test() { + run_negative_alu_test( + SUB, + [F::NEG_ONE.as_canonical_u32(), 0, 0, 0], + [1, 0, 0, 0], + [2, 0, 0, 0], + None, + None, + None, + true, + ); +} + +#[test] +fn rv32_alu_xor_wrong_negative_test() { + run_negative_alu_test( + XOR, + [255, 255, 255, 255], + [0, 0, 1, 0], + [255, 255, 255, 255], + None, + None, + None, + true, + ); +} + +#[test] +fn rv32_alu_or_wrong_negative_test() { + run_negative_alu_test( + OR, + [255, 255, 255, 255], + [255, 255, 255, 254], + [0, 0, 0, 0], + None, + None, + None, + true, + ); +} + +#[test] +fn rv32_alu_and_wrong_negative_test() { + run_negative_alu_test( + AND, + [255, 255, 255, 255], + [0, 0, 1, 0], + [0, 0, 0, 0], + None, + None, + None, + true, + ); +} + +#[test] +fn rv32_alu_adapter_unconstrained_imm_limb_test() { + run_negative_alu_test( + ADD, + [255, 7, 0, 0], + [0, 0, 0, 0], + [255, 7, 0, 0], + Some([511, 6, 0, 0]), + None, + Some(true), + true, + ); +} + +#[test] +fn rv32_alu_adapter_unconstrained_rs2_read_test() { + run_negative_alu_test( + ADD, + [2, 2, 2, 2], + [1, 1, 1, 1], + [1, 1, 1, 1], + None, + Some([false, false, false, false, false]), + Some(false), + false, + ); +} + +/////////////////////////////////////////////////////////////////////////////////////// +/// SANITY TESTS +/// +/// Ensure that solve functions produce the correct results. +/////////////////////////////////////////////////////////////////////////////////////// + +#[test] +fn run_add_sanity_test() { + let x: [u8; RV32_REGISTER_NUM_LIMBS] = [229, 33, 29, 111]; + let y: [u8; RV32_REGISTER_NUM_LIMBS] = [50, 171, 44, 194]; + let z: [u8; RV32_REGISTER_NUM_LIMBS] = [23, 205, 73, 49]; + let result = run_alu::(ADD, &x, &y); + for i in 0..RV32_REGISTER_NUM_LIMBS { + assert_eq!(z[i], result[i]) + } +} + +#[test] +fn run_sub_sanity_test() { + let x: [u8; RV32_REGISTER_NUM_LIMBS] = [229, 33, 29, 111]; + let y: [u8; RV32_REGISTER_NUM_LIMBS] = [50, 171, 44, 194]; + let z: [u8; RV32_REGISTER_NUM_LIMBS] = [179, 118, 240, 172]; + let result = run_alu::(SUB, &x, &y); + for i in 0..RV32_REGISTER_NUM_LIMBS { + assert_eq!(z[i], result[i]) + } +} + +#[test] +fn run_xor_sanity_test() { + let x: [u8; RV32_REGISTER_NUM_LIMBS] = [229, 33, 29, 111]; + let y: [u8; RV32_REGISTER_NUM_LIMBS] = [50, 171, 44, 194]; + let z: [u8; RV32_REGISTER_NUM_LIMBS] = [215, 138, 49, 173]; + let result = run_alu::(XOR, &x, &y); + for i in 0..RV32_REGISTER_NUM_LIMBS { + assert_eq!(z[i], result[i]) + } +} + +#[test] +fn run_or_sanity_test() { + let x: [u8; RV32_REGISTER_NUM_LIMBS] = [229, 33, 29, 111]; + let y: [u8; RV32_REGISTER_NUM_LIMBS] = [50, 171, 44, 194]; + let z: [u8; RV32_REGISTER_NUM_LIMBS] = [247, 171, 61, 239]; + let result = run_alu::(OR, &x, &y); + for i in 0..RV32_REGISTER_NUM_LIMBS { + assert_eq!(z[i], result[i]) + } +} + +#[test] +fn run_and_sanity_test() { + let x: [u8; RV32_REGISTER_NUM_LIMBS] = [229, 33, 29, 111]; + let y: [u8; RV32_REGISTER_NUM_LIMBS] = [50, 171, 44, 194]; + let z: [u8; RV32_REGISTER_NUM_LIMBS] = [32, 33, 12, 66]; + let result = run_alu::(AND, &x, &y); + for i in 0..RV32_REGISTER_NUM_LIMBS { + assert_eq!(z[i], result[i]) + } +} diff --git a/extensions/memcpy/transpiler/Cargo.toml b/extensions/memcpy/transpiler/Cargo.toml new file mode 100644 index 0000000000..03e39a78d7 --- /dev/null +++ b/extensions/memcpy/transpiler/Cargo.toml @@ -0,0 +1,17 @@ +[package] +name = "openvm-memcpy-transpiler" +description = "OpenVM transpiler extension for memcpy" +version.workspace = true +authors.workspace = true +edition.workspace = true +homepage.workspace = true +repository.workspace = true + +[dependencies] +openvm-stark-backend = { workspace = true } +openvm-instructions = { workspace = true } +openvm-transpiler = { workspace = true } +rrs-lib = { workspace = true } + +openvm-instructions-derive = { workspace = true } +strum = { workspace = true } diff --git a/extensions/memcpy/transpiler/src/lib.rs b/extensions/memcpy/transpiler/src/lib.rs new file mode 100644 index 0000000000..76968ade99 --- /dev/null +++ b/extensions/memcpy/transpiler/src/lib.rs @@ -0,0 +1,60 @@ +use openvm_instructions::LocalOpcode; +use openvm_instructions_derive::LocalOpcode; +use openvm_stark_backend::p3_field::PrimeField32; +use openvm_transpiler::{util::from_u_type, TranspilerExtension, TranspilerOutput}; +use rrs_lib::instruction_formats::UType; +use strum::{EnumCount, EnumIter, FromRepr}; + +#[derive( + Copy, Clone, Debug, PartialEq, Eq, PartialOrd, Ord, EnumCount, EnumIter, FromRepr, LocalOpcode, +)] +#[opcode_offset = 0x330] +#[repr(usize)] +#[allow(non_camel_case_types)] +pub enum Rv32MemcpyOpcode { + MEMCPY_LOOP, +} + +// pub const OPCODE: u8 = 0x0b; +// pub const KECCAK256_FUNCT3: u8 = 0b100; +// pub const KECCAK256_FUNCT7: u8 = 0; +// Custom opcode for memcpy_loop instruction +pub const MEMCPY_LOOP_OPCODE: u8 = 0x73; // Custom opcode +pub const MEMCPY_LOOP_FUNCT3: u8 = 0x0; // Custom funct3 + +#[derive(Default)] +pub struct MemcpyTranspilerExtension; + +impl TranspilerExtension for MemcpyTranspilerExtension { + fn process_custom(&self, instruction_stream: &[u32]) -> Option> { + if instruction_stream.is_empty() { + return None; + } + + let instruction_u32 = instruction_stream[0]; + let opcode = (instruction_u32 & 0x7f) as u8; + let funct3 = ((instruction_u32 >> 12) & 0b111) as u8; + + // Check if this is our custom memcpy_loop instruction + if (opcode, funct3) != (MEMCPY_LOOP_OPCODE, MEMCPY_LOOP_FUNCT3) { + return None; + } + + // Parse I-type instruction format + let dec_insn = UType::new(instruction_u32); + let shift = dec_insn.imm as u8; + + // Validate shift value (0, 1, 2, or 3) + if shift > 3u8 { + return None; + } + + // Convert to OpenVM instruction format + let instruction = from_u_type( + Rv32MemcpyOpcode::MEMCPY_LOOP.global_opcode().as_usize(), + &dec_insn, + ); + + Some(TranspilerOutput::one_to_one(instruction)) + } +} From 83777cc3cb6529826fc26764603c1e412c567c4b Mon Sep 17 00:00:00 2001 From: Peyman Jabbarzade Date: Tue, 19 Aug 2025 13:27:16 -0400 Subject: [PATCH 02/14] fix: add memcpy to toolchain --- Cargo.lock | 8 +- Cargo.toml | 2 + benchmarks/execute/Cargo.toml | 2 + benchmarks/execute/benches/execute.rs | 7 + crates/sdk/Cargo.toml | 2 + crates/sdk/src/config/global.rs | 22 + crates/toolchain/openvm/src/memcpy.s | 5 + extensions/memcpy/circuit/src/core.rs | 16 +- extensions/memcpy/tests.rs | 888 ++++++++++++------------ extensions/memcpy/tests/Cargo.toml | 27 + extensions/memcpy/tests/src/lib.rs | 567 +++++++++++++++ extensions/memcpy/transpiler/src/lib.rs | 15 +- 12 files changed, 1097 insertions(+), 464 deletions(-) create mode 100644 extensions/memcpy/tests/Cargo.toml create mode 100644 extensions/memcpy/tests/src/lib.rs diff --git a/Cargo.lock b/Cargo.lock index b0b6e8e389..687d5ceae2 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -5235,6 +5235,8 @@ dependencies = [ "openvm-ecc-transpiler", "openvm-keccak256-circuit", "openvm-keccak256-transpiler", + "openvm-memcpy-circuit", + "openvm-memcpy-transpiler", "openvm-native-circuit", "openvm-native-recursion", "openvm-pairing-circuit", @@ -5742,7 +5744,7 @@ dependencies = [ [[package]] name = "openvm-memcpy-circuit" -version = "1.4.0-rc.2" +version = "1.4.0-rc.6" dependencies = [ "derive-new 0.6.0", "derive_more 1.0.0", @@ -5761,7 +5763,7 @@ dependencies = [ [[package]] name = "openvm-memcpy-transpiler" -version = "1.4.0-rc.2" +version = "1.4.0-rc.6" dependencies = [ "openvm-instructions", "openvm-instructions-derive", @@ -6171,6 +6173,8 @@ dependencies = [ "openvm-ecc-transpiler", "openvm-keccak256-circuit", "openvm-keccak256-transpiler", + "openvm-memcpy-circuit", + "openvm-memcpy-transpiler", "openvm-native-circuit", "openvm-native-compiler", "openvm-native-recursion", diff --git a/Cargo.toml b/Cargo.toml index b83495e41e..b4419de60d 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -173,6 +173,8 @@ openvm-ecc-sw-macros = { path = "extensions/ecc/sw-macros", default-features = f openvm-pairing-circuit = { path = "extensions/pairing/circuit", default-features = false } openvm-pairing-transpiler = { path = "extensions/pairing/transpiler", default-features = false } openvm-pairing-guest = { path = "extensions/pairing/guest", default-features = false } +openvm-memcpy-circuit = { path = "extensions/memcpy/circuit", default-features = false } +openvm-memcpy-transpiler = { path = "extensions/memcpy/transpiler", default-features = false } openvm-verify-stark = { path = "guest-libs/verify_stark", default-features = false } # Benchmarking diff --git a/benchmarks/execute/Cargo.toml b/benchmarks/execute/Cargo.toml index 5fcf58b1de..0abecbbf53 100644 --- a/benchmarks/execute/Cargo.toml +++ b/benchmarks/execute/Cargo.toml @@ -24,6 +24,8 @@ openvm-pairing-guest.workspace = true openvm-pairing-transpiler.workspace = true openvm-keccak256-circuit.workspace = true openvm-keccak256-transpiler.workspace = true +openvm-memcpy-circuit.workspace = true +openvm-memcpy-transpiler.workspace = true openvm-rv32im-circuit.workspace = true openvm-rv32im-transpiler.workspace = true openvm-sha256-circuit.workspace = true diff --git a/benchmarks/execute/benches/execute.rs b/benchmarks/execute/benches/execute.rs index 291140ceea..a3ba0c857c 100644 --- a/benchmarks/execute/benches/execute.rs +++ b/benchmarks/execute/benches/execute.rs @@ -32,6 +32,8 @@ use openvm_ecc_circuit::{EccCpuProverExt, WeierstrassExtension, WeierstrassExten use openvm_ecc_transpiler::EccTranspilerExtension; use openvm_keccak256_circuit::{Keccak256, Keccak256CpuProverExt, Keccak256Executor}; use openvm_keccak256_transpiler::Keccak256TranspilerExtension; +use openvm_memcpy_circuit::{Memcpy, MemcpyCpuProverExt, MemcpyExecutor}; +use openvm_memcpy_transpiler::MemcpyTranspilerExtension; use openvm_native_circuit::{NativeCpuBuilder, NATIVE_MAX_TRACE_HEIGHTS}; use openvm_native_recursion::hints::Hintable; use openvm_pairing_circuit::{ @@ -107,6 +109,8 @@ pub struct ExecuteConfig { #[extension] pub keccak: Keccak256, #[extension] + pub memcpy: Memcpy, + #[extension] pub sha256: Sha256, #[extension] pub modular: ModularExtension, @@ -128,6 +132,7 @@ impl Default for ExecuteConfig { io: Rv32Io, bigint: Int256::default(), keccak: Keccak256, + memcpy: Memcpy, sha256: Sha256, modular: ModularExtension::new(vec![ bn_config.modulus.clone(), @@ -188,6 +193,7 @@ where &config.keccak, inventory, )?; + VmProverExtension::::extend_prover(&MemcpyCpuProverExt, &config.memcpy, inventory)?; VmProverExtension::::extend_prover(&Sha2CpuProverExt, &config.sha256, inventory)?; VmProverExtension::::extend_prover( &AlgebraCpuProverExt, @@ -216,6 +222,7 @@ fn create_default_transpiler() -> Transpiler { .with_extension(Rv32MTranspilerExtension) .with_extension(Int256TranspilerExtension) .with_extension(Keccak256TranspilerExtension) + .with_extension(MemcpyTranspilerExtension) .with_extension(Sha256TranspilerExtension) .with_extension(ModularTranspilerExtension) .with_extension(Fp2TranspilerExtension) diff --git a/crates/sdk/Cargo.toml b/crates/sdk/Cargo.toml index 44197f1eab..bb30819107 100644 --- a/crates/sdk/Cargo.toml +++ b/crates/sdk/Cargo.toml @@ -18,6 +18,8 @@ openvm-ecc-circuit = { workspace = true } openvm-ecc-transpiler = { workspace = true } openvm-keccak256-circuit = { workspace = true } openvm-keccak256-transpiler = { workspace = true } +openvm-memcpy-circuit = { workspace = true } +openvm-memcpy-transpiler = { workspace = true } openvm-sha256-circuit = { workspace = true } openvm-sha256-transpiler = { workspace = true } openvm-pairing-circuit = { workspace = true } diff --git a/crates/sdk/src/config/global.rs b/crates/sdk/src/config/global.rs index 9699b1ed34..81987285b9 100644 --- a/crates/sdk/src/config/global.rs +++ b/crates/sdk/src/config/global.rs @@ -33,6 +33,8 @@ use openvm_rv32im_circuit::{ use openvm_rv32im_transpiler::{ Rv32ITranspilerExtension, Rv32IoTranspilerExtension, Rv32MTranspilerExtension, }; +use openvm_memcpy_circuit::{Memcpy, MemcpyCpuProverExt, MemcpyExecutor}; +use openvm_memcpy_transpiler::MemcpyTranspilerExtension; use openvm_sha256_circuit::{Sha256, Sha256Executor, Sha2CpuProverExt}; use openvm_sha256_transpiler::Sha256TranspilerExtension; use openvm_stark_backend::{ @@ -81,6 +83,7 @@ pub struct SdkVmConfig { pub rv32i: Option, pub io: Option, pub keccak: Option, + pub memcpy: Option, pub sha256: Option, pub native: Option, pub castf: Option, @@ -118,6 +121,7 @@ impl SdkVmConfig { .rv32m(Default::default()) .io(Default::default()) .keccak(Default::default()) + .memcpy(Default::default()) .sha256(Default::default()) .bigint(Default::default()) .modular(ModularExtension::new(vec![ @@ -199,6 +203,9 @@ impl TranspilerConfig for SdkVmConfig { if self.keccak.is_some() { transpiler = transpiler.with_extension(Keccak256TranspilerExtension); } + if self.memcpy.is_some() { + transpiler = transpiler.with_extension(MemcpyTranspilerExtension); + } if self.sha256.is_some() { transpiler = transpiler.with_extension(Sha256TranspilerExtension); } @@ -269,6 +276,7 @@ impl SdkVmConfig { let rv32i = config.rv32i.map(|_| Rv32I); let io = config.io.map(|_| Rv32Io); let keccak = config.keccak.map(|_| Keccak256); + let memcpy = config.memcpy.map(|_| Memcpy); let sha256 = config.sha256.map(|_| Sha256); let native = config.native.map(|_| Native); let castf = config.castf.map(|_| CastFExtension); @@ -284,6 +292,7 @@ impl SdkVmConfig { rv32i, io, keccak, + memcpy, sha256, native, castf, @@ -315,6 +324,8 @@ pub struct SdkVmConfigInner { pub io: Option, #[extension(executor = "Keccak256Executor")] pub keccak: Option, + #[extension(executor = "MemcpyExecutor")] + pub memcpy: Option, #[extension(executor = "Sha256Executor")] pub sha256: Option, #[extension(executor = "NativeExecutor")] @@ -392,6 +403,9 @@ where if let Some(keccak) = &config.keccak { VmProverExtension::::extend_prover(&Keccak256CpuProverExt, keccak, inventory)?; } + if let Some(memcpy) = &config.memcpy { + VmProverExtension::::extend_prover(&MemcpyCpuProverExt, memcpy, inventory)?; + } if let Some(sha256) = &config.sha256 { VmProverExtension::::extend_prover(&Sha2CpuProverExt, sha256, inventory)?; } @@ -566,6 +580,12 @@ impl From for UnitStruct { } } +impl From for UnitStruct { + fn from(_: Memcpy) -> Self { + UnitStruct {} + } +} + impl From for UnitStruct { fn from(_: Sha256) -> Self { UnitStruct {} @@ -592,6 +612,7 @@ struct SdkVmConfigWithDefaultDeser { pub rv32i: Option, pub io: Option, pub keccak: Option, + pub memcpy: Option, pub sha256: Option, pub native: Option, pub castf: Option, @@ -611,6 +632,7 @@ impl From for SdkVmConfig { rv32i: config.rv32i, io: config.io, keccak: config.keccak, + memcpy: config.memcpy, sha256: config.sha256, native: config.native, castf: config.castf, diff --git a/crates/toolchain/openvm/src/memcpy.s b/crates/toolchain/openvm/src/memcpy.s index 3f787e5d52..0e63bfdcab 100644 --- a/crates/toolchain/openvm/src/memcpy.s +++ b/crates/toolchain/openvm/src/memcpy.s @@ -205,6 +205,11 @@ .attribute 4, 16 .attribute 5, "rv32im" .file "musl_memcpy.c" + + # Define memcpy_loop macro for custom instruction (U-type) + .macro memcpy_loop shift + .word 0x72000000 | (\shift << 12) # opcode 0x72 + shift in immediate field (bits 12-31) + .endm .globl memcpy .p2align 2 .type memcpy,@function diff --git a/extensions/memcpy/circuit/src/core.rs b/extensions/memcpy/circuit/src/core.rs index 43b465bfb1..80fdf1a7d1 100644 --- a/extensions/memcpy/circuit/src/core.rs +++ b/extensions/memcpy/circuit/src/core.rs @@ -311,7 +311,7 @@ pub type MemcpyLoopChip = VmChipWrapper; #[derive(AlignedBytesBorrow, Clone)] #[repr(C)] struct MemcpyLoopPreCompute { - a: u8, + c: u8, } impl PreflightExecutor for MemcpyLoopExecutor @@ -328,9 +328,9 @@ where state: VmStateMut, instruction: &Instruction, ) -> Result<(), ExecutionError> { - let Instruction { opcode, a, .. } = instruction; + let Instruction { opcode, c, .. } = instruction; debug_assert_eq!(*opcode, Rv32MemcpyOpcode::MEMCPY_LOOP.global_opcode()); - let shift = a.as_canonical_u32() as u8; + let shift = c.as_canonical_u32() as u8; debug_assert!([0, 1, 2, 3].contains(&shift)); let mut record = state.ctx.alloc(EmptyMultiRowLayout::default()); @@ -600,7 +600,7 @@ unsafe fn execute_e12_impl( pre_compute: &MemcpyLoopPreCompute, vm_state: &mut VmExecState, ) { - let shift = pre_compute.a; + let shift = pre_compute.c; let (dest, source) = if shift == 0 { ( vm_state.vm_read::(RV32_REGISTER_AS, A3_REGISTER_PTR as u32), @@ -700,12 +700,12 @@ impl MemcpyLoopExecutor { inst: &Instruction, data: &mut MemcpyLoopPreCompute, ) -> Result<(), StaticProgramError> { - let Instruction { opcode, a, .. } = inst; - let a_u32 = a.as_canonical_u32(); - if ![0, 1, 2, 3].contains(&a_u32) { + let Instruction { opcode, c, .. } = inst; + let c_u32 = c.as_canonical_u32(); + if ![0, 1, 2, 3].contains(&c_u32) { return Err(StaticProgramError::InvalidInstruction(pc)); } - *data = MemcpyLoopPreCompute { a: a_u32 as u8 }; + *data = MemcpyLoopPreCompute { c: c_u32 as u8 }; assert_eq!(*opcode, Rv32MemcpyOpcode::MEMCPY_LOOP.global_opcode()); Ok(()) } diff --git a/extensions/memcpy/tests.rs b/extensions/memcpy/tests.rs index bc1880953c..743199194b 100644 --- a/extensions/memcpy/tests.rs +++ b/extensions/memcpy/tests.rs @@ -1,444 +1,444 @@ -use std::{array, borrow::BorrowMut, sync::Arc}; - -use openvm_circuit::arch::testing::{TestChipHarness, VmChipTestBuilder, BITWISE_OP_LOOKUP_BUS}; -use openvm_circuit_primitives::bitwise_op_lookup::{ - BitwiseOperationLookupAir, BitwiseOperationLookupBus, BitwiseOperationLookupChip, - SharedBitwiseOperationLookupChip, -}; -use openvm_instructions::LocalOpcode; -use openvm_rv32im_transpiler::BaseAluOpcode::{self, *}; -use openvm_stark_backend::{ - p3_air::BaseAir, - p3_field::{FieldAlgebra, PrimeField32}, - p3_matrix::{ - dense::{DenseMatrix, RowMajorMatrix}, - Matrix, - }, - utils::disable_debug_builder, -}; -use openvm_stark_sdk::{p3_baby_bear::BabyBear, utils::create_seeded_rng}; -use rand::{rngs::StdRng, Rng}; -use test_case::test_case; - -use super::{core::run_alu, BaseAluCoreAir, Rv32BaseAluChip, Rv32BaseAluExecutor}; -use crate::{ - adapters::{ - Rv32BaseAluAdapterAir, Rv32BaseAluAdapterExecutor, Rv32BaseAluAdapterFiller, - RV32_CELL_BITS, RV32_REGISTER_NUM_LIMBS, - }, - base_alu::BaseAluCoreCols, - test_utils::{ - generate_rv32_is_type_immediate, get_verification_error, rv32_rand_write_register_or_imm, - }, - BaseAluFiller, Rv32BaseAluAir, -}; - -const MAX_INS_CAPACITY: usize = 128; -type F = BabyBear; -type Harness = TestChipHarness>; - -fn create_test_chip( - tester: &VmChipTestBuilder, -) -> ( - Harness, - ( - BitwiseOperationLookupAir, - SharedBitwiseOperationLookupChip, - ), -) { - let bitwise_bus = BitwiseOperationLookupBus::new(BITWISE_OP_LOOKUP_BUS); - let bitwise_chip = Arc::new(BitwiseOperationLookupChip::::new( - bitwise_bus, - )); - - let air = Rv32BaseAluAir::new( - Rv32BaseAluAdapterAir::new( - tester.execution_bridge(), - tester.memory_bridge(), - bitwise_bus, - ), - BaseAluCoreAir::new(bitwise_bus, BaseAluOpcode::CLASS_OFFSET), - ); - let executor = Rv32BaseAluExecutor::new( - Rv32BaseAluAdapterExecutor::new(), - BaseAluOpcode::CLASS_OFFSET, - ); - let chip = Rv32BaseAluChip::new( - BaseAluFiller::new( - Rv32BaseAluAdapterFiller::new(bitwise_chip.clone()), - bitwise_chip.clone(), - BaseAluOpcode::CLASS_OFFSET, - ), - tester.memory_helper(), - ); - let harness = Harness::with_capacity(executor, air, chip, MAX_INS_CAPACITY); - - (harness, (bitwise_chip.air, bitwise_chip)) -} - -fn set_and_execute( - tester: &mut VmChipTestBuilder, - harness: &mut Harness, - rng: &mut StdRng, - opcode: BaseAluOpcode, - b: Option<[u8; RV32_REGISTER_NUM_LIMBS]>, - is_imm: Option, - c: Option<[u8; RV32_REGISTER_NUM_LIMBS]>, -) { - let b = b.unwrap_or(array::from_fn(|_| rng.gen_range(0..=u8::MAX))); - let (c_imm, c) = if is_imm.unwrap_or(rng.gen_bool(0.5)) { - let (imm, c) = if let Some(c) = c { - ((u32::from_le_bytes(c) & 0xFFFFFF) as usize, c) - } else { - generate_rv32_is_type_immediate(rng) - }; - (Some(imm), c) - } else { - ( - None, - c.unwrap_or(array::from_fn(|_| rng.gen_range(0..=u8::MAX))), - ) - }; - - let (instruction, rd) = rv32_rand_write_register_or_imm( - tester, - b, - c, - c_imm, - opcode.global_opcode().as_usize(), - rng, - ); - tester.execute(harness, &instruction); - - let a = run_alu::(opcode, &b, &c) - .map(F::from_canonical_u8); - assert_eq!(a, tester.read::(1, rd)) -} - -////////////////////////////////////////////////////////////////////////////////////// -// POSITIVE TESTS -// -// Randomly generate computations and execute, ensuring that the generated trace -// passes all constraints. -////////////////////////////////////////////////////////////////////////////////////// - -#[test_case(ADD, 100)] -#[test_case(SUB, 100)] -#[test_case(XOR, 100)] -#[test_case(OR, 100)] -#[test_case(AND, 100)] -fn rand_rv32_alu_test(opcode: BaseAluOpcode, num_ops: usize) { - let mut rng = create_seeded_rng(); - - let mut tester = VmChipTestBuilder::default(); - let (mut harness, bitwise) = create_test_chip(&tester); - - // TODO(AG): make a more meaningful test for memory accesses - tester.write(2, 1024, [F::ONE; 4]); - tester.write(2, 1028, [F::ONE; 4]); - let sm = tester.read(2, 1024); - assert_eq!(sm, [F::ONE; 8]); - - for _ in 0..num_ops { - set_and_execute( - &mut tester, - &mut harness, - &mut rng, - opcode, - None, - None, - None, - ); - } - - let tester = tester - .build() - .load(harness) - .load_periphery(bitwise) - .finalize(); - tester.simple_test().expect("Verification failed"); -} - -#[test_case(ADD, 100)] -#[test_case(SUB, 100)] -#[test_case(XOR, 100)] -#[test_case(OR, 100)] -#[test_case(AND, 100)] -fn rand_rv32_alu_test_persistent(opcode: BaseAluOpcode, num_ops: usize) { - let mut rng = create_seeded_rng(); - - let mut tester = VmChipTestBuilder::default_persistent(); - let (mut harness, bitwise) = create_test_chip(&tester); - - // TODO(AG): make a more meaningful test for memory accesses - tester.write(2, 1024, [F::ONE; 4]); - tester.write(2, 1028, [F::ONE; 4]); - let sm = tester.read(2, 1024); - assert_eq!(sm, [F::ONE; 8]); - - for _ in 0..num_ops { - set_and_execute( - &mut tester, - &mut harness, - &mut rng, - opcode, - None, - None, - None, - ); - } - - let tester = tester - .build() - .load(harness) - .load_periphery(bitwise) - .finalize(); - tester.simple_test().expect("Verification failed"); -} - -////////////////////////////////////////////////////////////////////////////////////// -// NEGATIVE TESTS -// -// Given a fake trace of a single operation, setup a chip and run the test. We replace -// part of the trace and check that the chip throws the expected error. -////////////////////////////////////////////////////////////////////////////////////// - -#[allow(clippy::too_many_arguments)] -fn run_negative_alu_test( - opcode: BaseAluOpcode, - prank_a: [u32; RV32_REGISTER_NUM_LIMBS], - b: [u8; RV32_REGISTER_NUM_LIMBS], - c: [u8; RV32_REGISTER_NUM_LIMBS], - prank_c: Option<[u32; RV32_REGISTER_NUM_LIMBS]>, - prank_opcode_flags: Option<[bool; 5]>, - is_imm: Option, - interaction_error: bool, -) { - let mut rng = create_seeded_rng(); - let mut tester: VmChipTestBuilder = VmChipTestBuilder::default(); - let (mut chip, bitwise) = create_test_chip(&tester); - - set_and_execute( - &mut tester, - &mut chip, - &mut rng, - opcode, - Some(b), - is_imm, - Some(c), - ); - - let adapter_width = BaseAir::::width(&chip.air.adapter); - let modify_trace = |trace: &mut DenseMatrix| { - let mut values = trace.row_slice(0).to_vec(); - let cols: &mut BaseAluCoreCols = - values.split_at_mut(adapter_width).1.borrow_mut(); - cols.a = prank_a.map(F::from_canonical_u32); - if let Some(prank_c) = prank_c { - cols.c = prank_c.map(F::from_canonical_u32); - } - if let Some(prank_opcode_flags) = prank_opcode_flags { - cols.opcode_add_flag = F::from_bool(prank_opcode_flags[0]); - cols.opcode_and_flag = F::from_bool(prank_opcode_flags[1]); - cols.opcode_or_flag = F::from_bool(prank_opcode_flags[2]); - cols.opcode_sub_flag = F::from_bool(prank_opcode_flags[3]); - cols.opcode_xor_flag = F::from_bool(prank_opcode_flags[4]); - } - *trace = RowMajorMatrix::new(values, trace.width()); - }; - - disable_debug_builder(); - let tester = tester - .build() - .load_and_prank_trace(chip, modify_trace) - .load_periphery(bitwise) - .finalize(); - tester.simple_test_with_expected_error(get_verification_error(interaction_error)); -} - -#[test] -fn rv32_alu_add_wrong_negative_test() { - run_negative_alu_test( - ADD, - [246, 0, 0, 0], - [250, 0, 0, 0], - [250, 0, 0, 0], - None, - None, - None, - false, - ); -} - -#[test] -fn rv32_alu_add_out_of_range_negative_test() { - run_negative_alu_test( - ADD, - [500, 0, 0, 0], - [250, 0, 0, 0], - [250, 0, 0, 0], - None, - None, - None, - true, - ); -} - -#[test] -fn rv32_alu_sub_wrong_negative_test() { - run_negative_alu_test( - SUB, - [255, 0, 0, 0], - [1, 0, 0, 0], - [2, 0, 0, 0], - None, - None, - None, - false, - ); -} - -#[test] -fn rv32_alu_sub_out_of_range_negative_test() { - run_negative_alu_test( - SUB, - [F::NEG_ONE.as_canonical_u32(), 0, 0, 0], - [1, 0, 0, 0], - [2, 0, 0, 0], - None, - None, - None, - true, - ); -} - -#[test] -fn rv32_alu_xor_wrong_negative_test() { - run_negative_alu_test( - XOR, - [255, 255, 255, 255], - [0, 0, 1, 0], - [255, 255, 255, 255], - None, - None, - None, - true, - ); -} - -#[test] -fn rv32_alu_or_wrong_negative_test() { - run_negative_alu_test( - OR, - [255, 255, 255, 255], - [255, 255, 255, 254], - [0, 0, 0, 0], - None, - None, - None, - true, - ); -} - -#[test] -fn rv32_alu_and_wrong_negative_test() { - run_negative_alu_test( - AND, - [255, 255, 255, 255], - [0, 0, 1, 0], - [0, 0, 0, 0], - None, - None, - None, - true, - ); -} - -#[test] -fn rv32_alu_adapter_unconstrained_imm_limb_test() { - run_negative_alu_test( - ADD, - [255, 7, 0, 0], - [0, 0, 0, 0], - [255, 7, 0, 0], - Some([511, 6, 0, 0]), - None, - Some(true), - true, - ); -} - -#[test] -fn rv32_alu_adapter_unconstrained_rs2_read_test() { - run_negative_alu_test( - ADD, - [2, 2, 2, 2], - [1, 1, 1, 1], - [1, 1, 1, 1], - None, - Some([false, false, false, false, false]), - Some(false), - false, - ); -} - -/////////////////////////////////////////////////////////////////////////////////////// -/// SANITY TESTS -/// -/// Ensure that solve functions produce the correct results. -/////////////////////////////////////////////////////////////////////////////////////// - -#[test] -fn run_add_sanity_test() { - let x: [u8; RV32_REGISTER_NUM_LIMBS] = [229, 33, 29, 111]; - let y: [u8; RV32_REGISTER_NUM_LIMBS] = [50, 171, 44, 194]; - let z: [u8; RV32_REGISTER_NUM_LIMBS] = [23, 205, 73, 49]; - let result = run_alu::(ADD, &x, &y); - for i in 0..RV32_REGISTER_NUM_LIMBS { - assert_eq!(z[i], result[i]) - } -} - -#[test] -fn run_sub_sanity_test() { - let x: [u8; RV32_REGISTER_NUM_LIMBS] = [229, 33, 29, 111]; - let y: [u8; RV32_REGISTER_NUM_LIMBS] = [50, 171, 44, 194]; - let z: [u8; RV32_REGISTER_NUM_LIMBS] = [179, 118, 240, 172]; - let result = run_alu::(SUB, &x, &y); - for i in 0..RV32_REGISTER_NUM_LIMBS { - assert_eq!(z[i], result[i]) - } -} - -#[test] -fn run_xor_sanity_test() { - let x: [u8; RV32_REGISTER_NUM_LIMBS] = [229, 33, 29, 111]; - let y: [u8; RV32_REGISTER_NUM_LIMBS] = [50, 171, 44, 194]; - let z: [u8; RV32_REGISTER_NUM_LIMBS] = [215, 138, 49, 173]; - let result = run_alu::(XOR, &x, &y); - for i in 0..RV32_REGISTER_NUM_LIMBS { - assert_eq!(z[i], result[i]) - } -} - -#[test] -fn run_or_sanity_test() { - let x: [u8; RV32_REGISTER_NUM_LIMBS] = [229, 33, 29, 111]; - let y: [u8; RV32_REGISTER_NUM_LIMBS] = [50, 171, 44, 194]; - let z: [u8; RV32_REGISTER_NUM_LIMBS] = [247, 171, 61, 239]; - let result = run_alu::(OR, &x, &y); - for i in 0..RV32_REGISTER_NUM_LIMBS { - assert_eq!(z[i], result[i]) - } -} - -#[test] -fn run_and_sanity_test() { - let x: [u8; RV32_REGISTER_NUM_LIMBS] = [229, 33, 29, 111]; - let y: [u8; RV32_REGISTER_NUM_LIMBS] = [50, 171, 44, 194]; - let z: [u8; RV32_REGISTER_NUM_LIMBS] = [32, 33, 12, 66]; - let result = run_alu::(AND, &x, &y); - for i in 0..RV32_REGISTER_NUM_LIMBS { - assert_eq!(z[i], result[i]) - } -} +// use std::{array, borrow::BorrowMut, sync::Arc}; + +// use openvm_circuit::arch::testing::{TestChipHarness, VmChipTestBuilder, BITWISE_OP_LOOKUP_BUS}; +// use openvm_circuit_primitives::bitwise_op_lookup::{ +// BitwiseOperationLookupAir, BitwiseOperationLookupBus, BitwiseOperationLookupChip, +// SharedBitwiseOperationLookupChip, +// }; +// use openvm_instructions::LocalOpcode; +// use openvm_rv32im_transpiler::BaseAluOpcode::{self, *}; +// use openvm_stark_backend::{ +// p3_air::BaseAir, +// p3_field::{FieldAlgebra, PrimeField32}, +// p3_matrix::{ +// dense::{DenseMatrix, RowMajorMatrix}, +// Matrix, +// }, +// utils::disable_debug_builder, +// }; +// use openvm_stark_sdk::{p3_baby_bear::BabyBear, utils::create_seeded_rng}; +// use rand::{rngs::StdRng, Rng}; +// use test_case::test_case; + +// use super::{core::run_alu, BaseAluCoreAir, Rv32BaseAluChip, Rv32BaseAluExecutor}; +// use crate::{ +// adapters::{ +// Rv32BaseAluAdapterAir, Rv32BaseAluAdapterExecutor, Rv32BaseAluAdapterFiller, +// RV32_CELL_BITS, RV32_REGISTER_NUM_LIMBS, +// }, +// base_alu::BaseAluCoreCols, +// test_utils::{ +// generate_rv32_is_type_immediate, get_verification_error, rv32_rand_write_register_or_imm, +// }, +// BaseAluFiller, Rv32BaseAluAir, +// }; + +// const MAX_INS_CAPACITY: usize = 128; +// type F = BabyBear; +// type Harness = TestChipHarness>; + +// fn create_test_chip( +// tester: &VmChipTestBuilder, +// ) -> ( +// Harness, +// ( +// BitwiseOperationLookupAir, +// SharedBitwiseOperationLookupChip, +// ), +// ) { +// let bitwise_bus = BitwiseOperationLookupBus::new(BITWISE_OP_LOOKUP_BUS); +// let bitwise_chip = Arc::new(BitwiseOperationLookupChip::::new( +// bitwise_bus, +// )); + +// let air = Rv32BaseAluAir::new( +// Rv32BaseAluAdapterAir::new( +// tester.execution_bridge(), +// tester.memory_bridge(), +// bitwise_bus, +// ), +// BaseAluCoreAir::new(bitwise_bus, BaseAluOpcode::CLASS_OFFSET), +// ); +// let executor = Rv32BaseAluExecutor::new( +// Rv32BaseAluAdapterExecutor::new(), +// BaseAluOpcode::CLASS_OFFSET, +// ); +// let chip = Rv32BaseAluChip::new( +// BaseAluFiller::new( +// Rv32BaseAluAdapterFiller::new(bitwise_chip.clone()), +// bitwise_chip.clone(), +// BaseAluOpcode::CLASS_OFFSET, +// ), +// tester.memory_helper(), +// ); +// let harness = Harness::with_capacity(executor, air, chip, MAX_INS_CAPACITY); + +// (harness, (bitwise_chip.air, bitwise_chip)) +// } + +// fn set_and_execute( +// tester: &mut VmChipTestBuilder, +// harness: &mut Harness, +// rng: &mut StdRng, +// opcode: BaseAluOpcode, +// b: Option<[u8; RV32_REGISTER_NUM_LIMBS]>, +// is_imm: Option, +// c: Option<[u8; RV32_REGISTER_NUM_LIMBS]>, +// ) { +// let b = b.unwrap_or(array::from_fn(|_| rng.gen_range(0..=u8::MAX))); +// let (c_imm, c) = if is_imm.unwrap_or(rng.gen_bool(0.5)) { +// let (imm, c) = if let Some(c) = c { +// ((u32::from_le_bytes(c) & 0xFFFFFF) as usize, c) +// } else { +// generate_rv32_is_type_immediate(rng) +// }; +// (Some(imm), c) +// } else { +// ( +// None, +// c.unwrap_or(array::from_fn(|_| rng.gen_range(0..=u8::MAX))), +// ) +// }; + +// let (instruction, rd) = rv32_rand_write_register_or_imm( +// tester, +// b, +// c, +// c_imm, +// opcode.global_opcode().as_usize(), +// rng, +// ); +// tester.execute(harness, &instruction); + +// let a = run_alu::(opcode, &b, &c) +// .map(F::from_canonical_u8); +// assert_eq!(a, tester.read::(1, rd)) +// } + +// ////////////////////////////////////////////////////////////////////////////////////// +// // POSITIVE TESTS +// // +// // Randomly generate computations and execute, ensuring that the generated trace +// // passes all constraints. +// ////////////////////////////////////////////////////////////////////////////////////// + +// #[test_case(ADD, 100)] +// #[test_case(SUB, 100)] +// #[test_case(XOR, 100)] +// #[test_case(OR, 100)] +// #[test_case(AND, 100)] +// fn rand_rv32_alu_test(opcode: BaseAluOpcode, num_ops: usize) { +// let mut rng = create_seeded_rng(); + +// let mut tester = VmChipTestBuilder::default(); +// let (mut harness, bitwise) = create_test_chip(&tester); + +// // TODO(AG): make a more meaningful test for memory accesses +// tester.write(2, 1024, [F::ONE; 4]); +// tester.write(2, 1028, [F::ONE; 4]); +// let sm = tester.read(2, 1024); +// assert_eq!(sm, [F::ONE; 8]); + +// for _ in 0..num_ops { +// set_and_execute( +// &mut tester, +// &mut harness, +// &mut rng, +// opcode, +// None, +// None, +// None, +// ); +// } + +// let tester = tester +// .build() +// .load(harness) +// .load_periphery(bitwise) +// .finalize(); +// tester.simple_test().expect("Verification failed"); +// } + +// #[test_case(ADD, 100)] +// #[test_case(SUB, 100)] +// #[test_case(XOR, 100)] +// #[test_case(OR, 100)] +// #[test_case(AND, 100)] +// fn rand_rv32_alu_test_persistent(opcode: BaseAluOpcode, num_ops: usize) { +// let mut rng = create_seeded_rng(); + +// let mut tester = VmChipTestBuilder::default_persistent(); +// let (mut harness, bitwise) = create_test_chip(&tester); + +// // TODO(AG): make a more meaningful test for memory accesses +// tester.write(2, 1024, [F::ONE; 4]); +// tester.write(2, 1028, [F::ONE; 4]); +// let sm = tester.read(2, 1024); +// assert_eq!(sm, [F::ONE; 8]); + +// for _ in 0..num_ops { +// set_and_execute( +// &mut tester, +// &mut harness, +// &mut rng, +// opcode, +// None, +// None, +// None, +// ); +// } + +// let tester = tester +// .build() +// .load(harness) +// .load_periphery(bitwise) +// .finalize(); +// tester.simple_test().expect("Verification failed"); +// } + +// ////////////////////////////////////////////////////////////////////////////////////// +// // NEGATIVE TESTS +// // +// // Given a fake trace of a single operation, setup a chip and run the test. We replace +// // part of the trace and check that the chip throws the expected error. +// ////////////////////////////////////////////////////////////////////////////////////// + +// #[allow(clippy::too_many_arguments)] +// fn run_negative_alu_test( +// opcode: BaseAluOpcode, +// prank_a: [u32; RV32_REGISTER_NUM_LIMBS], +// b: [u8; RV32_REGISTER_NUM_LIMBS], +// c: [u8; RV32_REGISTER_NUM_LIMBS], +// prank_c: Option<[u32; RV32_REGISTER_NUM_LIMBS]>, +// prank_opcode_flags: Option<[bool; 5]>, +// is_imm: Option, +// interaction_error: bool, +// ) { +// let mut rng = create_seeded_rng(); +// let mut tester: VmChipTestBuilder = VmChipTestBuilder::default(); +// let (mut chip, bitwise) = create_test_chip(&tester); + +// set_and_execute( +// &mut tester, +// &mut chip, +// &mut rng, +// opcode, +// Some(b), +// is_imm, +// Some(c), +// ); + +// let adapter_width = BaseAir::::width(&chip.air.adapter); +// let modify_trace = |trace: &mut DenseMatrix| { +// let mut values = trace.row_slice(0).to_vec(); +// let cols: &mut BaseAluCoreCols = +// values.split_at_mut(adapter_width).1.borrow_mut(); +// cols.a = prank_a.map(F::from_canonical_u32); +// if let Some(prank_c) = prank_c { +// cols.c = prank_c.map(F::from_canonical_u32); +// } +// if let Some(prank_opcode_flags) = prank_opcode_flags { +// cols.opcode_add_flag = F::from_bool(prank_opcode_flags[0]); +// cols.opcode_and_flag = F::from_bool(prank_opcode_flags[1]); +// cols.opcode_or_flag = F::from_bool(prank_opcode_flags[2]); +// cols.opcode_sub_flag = F::from_bool(prank_opcode_flags[3]); +// cols.opcode_xor_flag = F::from_bool(prank_opcode_flags[4]); +// } +// *trace = RowMajorMatrix::new(values, trace.width()); +// }; + +// disable_debug_builder(); +// let tester = tester +// .build() +// .load_and_prank_trace(chip, modify_trace) +// .load_periphery(bitwise) +// .finalize(); +// tester.simple_test_with_expected_error(get_verification_error(interaction_error)); +// } + +// #[test] +// fn rv32_alu_add_wrong_negative_test() { +// run_negative_alu_test( +// ADD, +// [246, 0, 0, 0], +// [250, 0, 0, 0], +// [250, 0, 0, 0], +// None, +// None, +// None, +// false, +// ); +// } + +// #[test] +// fn rv32_alu_add_out_of_range_negative_test() { +// run_negative_alu_test( +// ADD, +// [500, 0, 0, 0], +// [250, 0, 0, 0], +// [250, 0, 0, 0], +// None, +// None, +// None, +// true, +// ); +// } + +// #[test] +// fn rv32_alu_sub_wrong_negative_test() { +// run_negative_alu_test( +// SUB, +// [255, 0, 0, 0], +// [1, 0, 0, 0], +// [2, 0, 0, 0], +// None, +// None, +// None, +// false, +// ); +// } + +// #[test] +// fn rv32_alu_sub_out_of_range_negative_test() { +// run_negative_alu_test( +// SUB, +// [F::NEG_ONE.as_canonical_u32(), 0, 0, 0], +// [1, 0, 0, 0], +// [2, 0, 0, 0], +// None, +// None, +// None, +// true, +// ); +// } + +// #[test] +// fn rv32_alu_xor_wrong_negative_test() { +// run_negative_alu_test( +// XOR, +// [255, 255, 255, 255], +// [0, 0, 1, 0], +// [255, 255, 255, 255], +// None, +// None, +// None, +// true, +// ); +// } + +// #[test] +// fn rv32_alu_or_wrong_negative_test() { +// run_negative_alu_test( +// OR, +// [255, 255, 255, 255], +// [255, 255, 255, 254], +// [0, 0, 0, 0], +// None, +// None, +// None, +// true, +// ); +// } + +// #[test] +// fn rv32_alu_and_wrong_negative_test() { +// run_negative_alu_test( +// AND, +// [255, 255, 255, 255], +// [0, 0, 1, 0], +// [0, 0, 0, 0], +// None, +// None, +// None, +// true, +// ); +// } + +// #[test] +// fn rv32_alu_adapter_unconstrained_imm_limb_test() { +// run_negative_alu_test( +// ADD, +// [255, 7, 0, 0], +// [0, 0, 0, 0], +// [255, 7, 0, 0], +// Some([511, 6, 0, 0]), +// None, +// Some(true), +// true, +// ); +// } + +// #[test] +// fn rv32_alu_adapter_unconstrained_rs2_read_test() { +// run_negative_alu_test( +// ADD, +// [2, 2, 2, 2], +// [1, 1, 1, 1], +// [1, 1, 1, 1], +// None, +// Some([false, false, false, false, false]), +// Some(false), +// false, +// ); +// } + +// /////////////////////////////////////////////////////////////////////////////////////// +// /// SANITY TESTS +// /// +// /// Ensure that solve functions produce the correct results. +// /////////////////////////////////////////////////////////////////////////////////////// + +// #[test] +// fn run_add_sanity_test() { +// let x: [u8; RV32_REGISTER_NUM_LIMBS] = [229, 33, 29, 111]; +// let y: [u8; RV32_REGISTER_NUM_LIMBS] = [50, 171, 44, 194]; +// let z: [u8; RV32_REGISTER_NUM_LIMBS] = [23, 205, 73, 49]; +// let result = run_alu::(ADD, &x, &y); +// for i in 0..RV32_REGISTER_NUM_LIMBS { +// assert_eq!(z[i], result[i]) +// } +// } + +// #[test] +// fn run_sub_sanity_test() { +// let x: [u8; RV32_REGISTER_NUM_LIMBS] = [229, 33, 29, 111]; +// let y: [u8; RV32_REGISTER_NUM_LIMBS] = [50, 171, 44, 194]; +// let z: [u8; RV32_REGISTER_NUM_LIMBS] = [179, 118, 240, 172]; +// let result = run_alu::(SUB, &x, &y); +// for i in 0..RV32_REGISTER_NUM_LIMBS { +// assert_eq!(z[i], result[i]) +// } +// } + +// #[test] +// fn run_xor_sanity_test() { +// let x: [u8; RV32_REGISTER_NUM_LIMBS] = [229, 33, 29, 111]; +// let y: [u8; RV32_REGISTER_NUM_LIMBS] = [50, 171, 44, 194]; +// let z: [u8; RV32_REGISTER_NUM_LIMBS] = [215, 138, 49, 173]; +// let result = run_alu::(XOR, &x, &y); +// for i in 0..RV32_REGISTER_NUM_LIMBS { +// assert_eq!(z[i], result[i]) +// } +// } + +// #[test] +// fn run_or_sanity_test() { +// let x: [u8; RV32_REGISTER_NUM_LIMBS] = [229, 33, 29, 111]; +// let y: [u8; RV32_REGISTER_NUM_LIMBS] = [50, 171, 44, 194]; +// let z: [u8; RV32_REGISTER_NUM_LIMBS] = [247, 171, 61, 239]; +// let result = run_alu::(OR, &x, &y); +// for i in 0..RV32_REGISTER_NUM_LIMBS { +// assert_eq!(z[i], result[i]) +// } +// } + +// #[test] +// fn run_and_sanity_test() { +// let x: [u8; RV32_REGISTER_NUM_LIMBS] = [229, 33, 29, 111]; +// let y: [u8; RV32_REGISTER_NUM_LIMBS] = [50, 171, 44, 194]; +// let z: [u8; RV32_REGISTER_NUM_LIMBS] = [32, 33, 12, 66]; +// let result = run_alu::(AND, &x, &y); +// for i in 0..RV32_REGISTER_NUM_LIMBS { +// assert_eq!(z[i], result[i]) +// } +// } diff --git a/extensions/memcpy/tests/Cargo.toml b/extensions/memcpy/tests/Cargo.toml new file mode 100644 index 0000000000..2dbf144891 --- /dev/null +++ b/extensions/memcpy/tests/Cargo.toml @@ -0,0 +1,27 @@ +[package] +name = "openvm-memcpy-integration-tests" +description = "Integration tests for the OpenVM memcpy extension" +version.workspace = true +authors.workspace = true +edition.workspace = true +homepage.workspace = true +repository.workspace = true + +[dependencies] +openvm-instructions = { workspace = true } +openvm-stark-sdk.workspace = true +openvm-transpiler.workspace = true +openvm-memcpy-circuit.workspace = true +openvm-memcpy-transpiler.workspace = true +openvm = { workspace = true } +openvm-toolchain-tests = { path = "../../../crates/toolchain/tests" } +eyre.workspace = true +serde = { workspace = true, features = ["alloc"] } +strum.workspace = true +rand.workspace = true +openvm-circuit = { workspace = true, features = ["test-utils"] } +test-case.workspace = true + +[features] +default = ["parallel"] +parallel = ["openvm-circuit/parallel"] diff --git a/extensions/memcpy/tests/src/lib.rs b/extensions/memcpy/tests/src/lib.rs new file mode 100644 index 0000000000..964aa39cb3 --- /dev/null +++ b/extensions/memcpy/tests/src/lib.rs @@ -0,0 +1,567 @@ +#[cfg(test)] +mod tests { + use std::{array, borrow::BorrowMut, sync::Arc}; + + use openvm_circuit::arch::testing::{TestChipHarness, VmChipTestBuilder}; + use openvm_circuit_primitives::var_range::{ + SharedVariableRangeCheckerChip, VariableRangeCheckerAir, VariableRangeCheckerBus, + VariableRangeCheckerChip, + }; + use openvm_instructions::LocalOpcode; + use openvm_memcpy_circuit::{ + bus::MemcpyBus, + extension::{Memcpy, MemcpyCpuProverExt}, + MemcpyIterAir, MemcpyIterCols, MemcpyIterExecutor, MemcpyIterFiller, MemcpyLoopAir, + MemcpyLoopChip, MemcpyLoopExecutor, MEMCPY_LOOP_LIMB_BITS, MEMCPY_LOOP_NUM_LIMBS, + }; + use openvm_memcpy_transpiler::Rv32MemcpyOpcode; + use openvm_stark_backend::{ + p3_air::BaseAir, + p3_field::{FieldAlgebra, PrimeField32}, + p3_matrix::{ + dense::{DenseMatrix, RowMajorMatrix}, + Matrix, + }, + utils::disable_debug_builder, + }; + use openvm_stark_sdk::{p3_baby_bear::BabyBear, utils::create_seeded_rng}; + use rand::{rngs::StdRng, Rng}; + use test_case::test_case; + + const MAX_INS_CAPACITY: usize = 128; + type F = BabyBear; + type Harness = TestChipHarness>; + + fn create_test_chip( + tester: &VmChipTestBuilder, + ) -> ( + Harness, + ( + VariableRangeCheckerAir, + SharedVariableRangeCheckerChip, + ), + ) { + let range_bus = VariableRangeCheckerBus::new(tester.new_bus_idx()); + let range_chip = Arc::new(VariableRangeCheckerChip::::new( + range_bus, + )); + + let memcpy_bus = MemcpyBus::new(tester.new_bus_idx()); + + let air = MemcpyIterAir::new( + tester.memory_bridge(), + range_bus, + memcpy_bus, + tester.pointer_max_bits(), + ); + let executor = MemcpyIterExecutor::new(Rv32MemcpyOpcode::CLASS_OFFSET); + let chip = MemcpyLoopChip::new( + tester.system_port(), + range_bus, + memcpy_bus, + Rv32MemcpyOpcode::CLASS_OFFSET, + tester.pointer_max_bits(), + range_chip.clone(), + ); + let harness = Harness::with_capacity(executor, air, chip, MAX_INS_CAPACITY); + + (harness, (range_chip.air, range_chip)) + } + + fn set_and_execute_memcpy( + tester: &mut VmChipTestBuilder, + harness: &mut Harness, + rng: &mut StdRng, + shift: u32, + source_data: &[u8], + dest_offset: u32, + source_offset: u32, + len: u32, + ) { + // Write source data to memory + for (i, &byte) in source_data.iter().enumerate() { + tester.write(2, source_offset + i as u32, [F::from_canonical_u8(byte)]); + } + + // Create instruction for memcpy_loop + let instruction = openvm_instructions::instruction::Instruction { + opcode: openvm_instructions::VmOpcode::from_usize( + Rv32MemcpyOpcode::MEMCPY_LOOP.global_opcode().as_usize(), + ), + a: F::ZERO, + b: F::ZERO, + c: F::from_canonical_u32(shift), + d: F::ZERO, + e: F::ZERO, + f: F::ZERO, + g: F::ZERO, + }; + + tester.execute(harness, &instruction); + + // Verify the copy operation + for i in 0..len.min(source_data.len() as u32) { + let expected = source_data[i as usize]; + let actual = tester.read(2, dest_offset + i)[0].as_canonical_u8(); + assert_eq!(expected, actual, "Mismatch at offset {}", i); + } + } + + ////////////////////////////////////////////////////////////////////////////////////// + // POSITIVE TESTS + // + // Randomly generate memcpy operations and execute, ensuring that the generated trace + // passes all constraints. + ////////////////////////////////////////////////////////////////////////////////////// + + #[test_case(0, 100)] + #[test_case(1, 100)] + #[test_case(2, 100)] + #[test_case(3, 100)] + fn rand_memcpy_loop_test(shift: u32, num_ops: usize) { + let mut rng = create_seeded_rng(); + + let mut tester = VmChipTestBuilder::default(); + let (mut harness, range_checker) = create_test_chip(&tester); + + for _ in 0..num_ops { + let source_data: Vec = (0..16).map(|_| rng.gen_range(0..=u8::MAX)).collect(); + let source_offset = rng.gen_range(0..1000); + let dest_offset = rng.gen_range(2000..3000); + let len = rng.gen_range(1..=16); + + set_and_execute_memcpy( + &mut tester, + &mut harness, + &mut rng, + shift, + &source_data, + dest_offset, + source_offset, + len, + ); + } + + let tester = tester + .build() + .load(harness) + .load_periphery(range_checker) + .finalize(); + tester.simple_test().expect("Verification failed"); + } + + #[test_case(0, 100)] + #[test_case(1, 100)] + #[test_case(2, 100)] + #[test_case(3, 100)] + fn rand_memcpy_loop_test_persistent(shift: u32, num_ops: usize) { + let mut rng = create_seeded_rng(); + + let mut tester = VmChipTestBuilder::default_persistent(); + let (mut harness, range_checker) = create_test_chip(&tester); + + for _ in 0..num_ops { + let source_data: Vec = (0..16).map(|_| rng.gen_range(0..=u8::MAX)).collect(); + let dest_offset = rng.gen_range(0..1000); + let source_offset = rng.gen_range(0..1000); + let len = rng.gen_range(1..=16); + + set_and_execute_memcpy( + &mut tester, + &mut harness, + &mut rng, + shift, + &source_data, + dest_offset, + source_offset, + len, + ); + } + + let tester = tester + .build() + .load(harness) + .load_periphery(range_checker) + .finalize(); + tester.simple_test().expect("Verification failed"); + } + + ////////////////////////////////////////////////////////////////////////////////////// + // NEGATIVE TESTS + // + // Given a fake trace of a single operation, setup a chip and run the test. We replace + // part of the trace and check that the chip throws the expected error. + ////////////////////////////////////////////////////////////////////////////////////// + + // #[allow(clippy::too_many_arguments)] + // fn run_negative_memcpy_test( + // shift: u32, + // prank_shift: u32, + // source_data: &[u8], + // dest_offset: u32, + // source_offset: u32, + // len: u32, + // prank_dest: Option, + // prank_source: Option, + // prank_len: Option, + // interaction_error: bool, + // ) { + // let mut rng = create_seeded_rng(); + // let mut tester: VmChipTestBuilder = VmChipTestBuilder::default(); + // let (mut chip, range_checker) = create_test_chip(&tester); + + // set_and_execute_memcpy( + // &mut tester, + // &mut chip, + // &mut rng, + // shift, + // source_data, + // dest_offset, + // source_offset, + // len, + // ); + + // let adapter_width = BaseAir::::width(&chip.air); + // let modify_trace = |trace: &mut DenseMatrix| { + // let mut values = trace.row_slice(0).to_vec(); + // let cols: &mut MemcpyIterCols = values.split_at_mut(adapter_width).1.borrow_mut(); + // cols.shift = [F::from_canonical_u32(prank_shift), F::ZERO]; + // if let Some(prank_dest) = prank_dest { + // cols.dest = F::from_canonical_u32(prank_dest); + // } + // if let Some(prank_source) = prank_source { + // cols.source = F::from_canonical_u32(prank_source); + // } + // if let Some(prank_len) = prank_len { + // cols.len = [F::from_canonical_u32(prank_len), F::ZERO]; + // } + // *trace = RowMajorMatrix::new(values, trace.width()); + // }; + + // disable_debug_builder(); + // let tester = tester + // .build() + // .load_and_prank_trace(chip, modify_trace) + // .load_periphery(range_checker) + // .finalize(); + + // if interaction_error { + // tester + // .simple_test() + // .expect_err("Expected verification to fail"); + // } else { + // tester + // .simple_test() + // .expect_err("Expected verification to fail"); + // } + // } + + // #[test] + // fn memcpy_wrong_shift_negative_test() { + // let source_data = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16]; + // run_negative_memcpy_test( + // 0, // original shift + // 1, // prank shift + // &source_data, + // 100, // dest_offset + // 200, // source_offset + // 8, // len + // None, + // None, + // None, + // true, + // ); + // } + + // #[test] + // fn memcpy_wrong_dest_negative_test() { + // let source_data = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16]; + // run_negative_memcpy_test( + // 0, // shift + // 0, // prank shift (same) + // &source_data, + // 100, // dest_offset + // 200, // source_offset + // 8, // len + // Some(150), // prank dest + // None, + // None, + // true, + // ); + // } + + // #[test] + // fn memcpy_wrong_source_negative_test() { + // let source_data = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16]; + // run_negative_memcpy_test( + // 0, // shift + // 0, // prank shift (same) + // &source_data, + // 100, // dest_offset + // 200, // source_offset + // 8, // len + // None, + // Some(250), // prank source + // None, + // true, + // ); + // } + + // #[test] + // fn memcpy_wrong_len_negative_test() { + // let source_data = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16]; + // run_negative_memcpy_test( + // 0, // shift + // 0, // prank shift (same) + // &source_data, + // 100, // dest_offset + // 200, // source_offset + // 8, // len + // None, + // None, + // Some(12), // prank len + // true, + // ); + // } + + // ////////////////////////////////////////////////////////////////////////////////////// + // // SANITY TESTS + // // + // // Ensure that memcpy operations produce the correct results. + // ////////////////////////////////////////////////////////////////////////////////////// + + // #[test] + // fn memcpy_shift_0_sanity_test() { + // let mut rng = create_seeded_rng(); + // let mut tester = VmChipTestBuilder::default(); + // let (mut harness, range_checker) = create_test_chip(&tester); + + // let source_data = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16]; + + // set_and_execute_memcpy( + // &mut tester, + // &mut harness, + // &mut rng, + // 0, // shift + // &source_data, + // 100, // dest_offset + // 200, // source_offset + // 8, // len + // ); + + // // Verify the copy operation + // for i in 0..8 { + // let expected = source_data[i]; + // let actual = tester.read(2, 100 + i)[0].as_canonical_u8(); + // assert_eq!(expected, actual, "Mismatch at offset {}", i); + // } + + // let tester = tester + // .build() + // .load(harness) + // .load_periphery(range_checker) + // .finalize(); + // tester.simple_test().expect("Verification failed"); + // } + + // #[test] + // fn memcpy_shift_1_sanity_test() { + // let mut rng = create_seeded_rng(); + // let mut tester = VmChipTestBuilder::default(); + // let (mut harness, range_checker) = create_test_chip(&tester); + + // let source_data = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16]; + + // set_and_execute_memcpy( + // &mut tester, + // &mut harness, + // &mut rng, + // 1, // shift + // &source_data, + // 100, // dest_offset + // 200, // source_offset + // 8, // len + // ); + + // // Verify the copy operation with shift=1 + // for i in 0..8 { + // let expected = source_data[i]; + // let actual = tester.read(2, 100 + i)[0].as_canonical_u8(); + // assert_eq!(expected, actual, "Mismatch at offset {}", i); + // } + + // let tester = tester + // .build() + // .load(harness) + // .load_periphery(range_checker) + // .finalize(); + // tester.simple_test().expect("Verification failed"); + // } + + // #[test] + // fn memcpy_shift_2_sanity_test() { + // let mut rng = create_seeded_rng(); + // let mut tester = VmChipTestBuilder::default(); + // let (mut harness, range_checker) = create_test_chip(&tester); + + // let source_data = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16]; + + // set_and_execute_memcpy( + // &mut tester, + // &mut harness, + // &mut rng, + // 2, // shift + // &source_data, + // 100, // dest_offset + // 200, // source_offset + // 8, // len + // ); + + // // Verify the copy operation with shift=2 + // for i in 0..8 { + // let expected = source_data[i]; + // let actual = tester.read(2, 100 + i)[0].as_canonical_u8(); + // assert_eq!(expected, actual, "Mismatch at offset {}", i); + // } + + // let tester = tester + // .build() + // .load(harness) + // .load_periphery(range_checker) + // .finalize(); + // tester.simple_test().expect("Verification failed"); + // } + + // #[test] + // fn memcpy_shift_3_sanity_test() { + // let mut rng = create_seeded_rng(); + // let mut tester = VmChipTestBuilder::default(); + // let (mut harness, range_checker) = create_test_chip(&tester); + + // let source_data = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16]; + + // set_and_execute_memcpy( + // &mut tester, + // &mut harness, + // &mut rng, + // 3, // shift + // &source_data, + // 100, // dest_offset + // 200, // source_offset + // 8, // len + // ); + + // // Verify the copy operation with shift=3 + // for i in 0..8 { + // let expected = source_data[i]; + // let actual = tester.read(2, 100 + i)[0].as_canonical_u8(); + // assert_eq!(expected, actual, "Mismatch at offset {}", i); + // } + + // let tester = tester + // .build() + // .load(harness) + // .load_periphery(range_checker) + // .finalize(); + // tester.simple_test().expect("Verification failed"); + // } + + // ////////////////////////////////////////////////////////////////////////////////////// + // // EDGE CASE TESTS + // // + // // Test edge cases and boundary conditions. + // ////////////////////////////////////////////////////////////////////////////////////// + + // #[test] + // fn memcpy_zero_length_test() { + // let mut rng = create_seeded_rng(); + // let mut tester = VmChipTestBuilder::default(); + // let (mut harness, range_checker) = create_test_chip(&tester); + + // let source_data = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16]; + + // set_and_execute_memcpy( + // &mut tester, + // &mut harness, + // &mut rng, + // 0, // shift + // &source_data, + // 100, // dest_offset + // 200, // source_offset + // 0, // zero length + // ); + + // let tester = tester + // .build() + // .load(harness) + // .load_periphery(range_checker) + // .finalize(); + // tester.simple_test().expect("Verification failed"); + // } + + // #[test] + // fn memcpy_max_length_test() { + // let mut rng = create_seeded_rng(); + // let mut tester = VmChipTestBuilder::default(); + // let (mut harness, range_checker) = create_test_chip(&tester); + + // let source_data = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16]; + + // set_and_execute_memcpy( + // &mut tester, + // &mut harness, + // &mut rng, + // 0, // shift + // &source_data, + // 100, // dest_offset + // 200, // source_offset + // 16, // max length + // ); + + // // Verify the copy operation + // for i in 0..16 { + // let expected = source_data[i]; + // let actual = tester.read(2, 100 + i)[0].as_canonical_u8(); + // assert_eq!(expected, actual, "Mismatch at offset {}", i); + // } + + // let tester = tester + // .build() + // .load(harness) + // .load_periphery(range_checker) + // .finalize(); + // tester.simple_test().expect("Verification failed"); + // } + + // #[test] + // fn memcpy_overlapping_regions_test() { + // let mut rng = create_seeded_rng(); + // let mut tester = VmChipTestBuilder::default(); + // let (mut harness, range_checker) = create_test_chip(&tester); + + // let source_data = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16]; + + // // Write initial data to destination + // for (i, &byte) in source_data.iter().enumerate() { + // tester.write(2, 100 + i as u32, [F::from_canonical_u8(byte)]); + // } + + // set_and_execute_memcpy( + // &mut tester, + // &mut harness, + // &mut rng, + // 0, // shift + // &source_data, + // 102, // dest_offset (overlapping with source) + // 100, // source_offset + // 8, // len + // ); + + // let tester = tester + // .build() + // .load(harness) + // .load_periphery(range_checker) + // .finalize(); + // tester.simple_test().expect("Verification failed"); + // } +} diff --git a/extensions/memcpy/transpiler/src/lib.rs b/extensions/memcpy/transpiler/src/lib.rs index 76968ade99..dd60126222 100644 --- a/extensions/memcpy/transpiler/src/lib.rs +++ b/extensions/memcpy/transpiler/src/lib.rs @@ -15,12 +15,8 @@ pub enum Rv32MemcpyOpcode { MEMCPY_LOOP, } -// pub const OPCODE: u8 = 0x0b; -// pub const KECCAK256_FUNCT3: u8 = 0b100; -// pub const KECCAK256_FUNCT7: u8 = 0; // Custom opcode for memcpy_loop instruction -pub const MEMCPY_LOOP_OPCODE: u8 = 0x73; // Custom opcode -pub const MEMCPY_LOOP_FUNCT3: u8 = 0x0; // Custom funct3 +pub const MEMCPY_LOOP_OPCODE: u8 = 0x72; // Custom opcode #[derive(Default)] pub struct MemcpyTranspilerExtension; @@ -33,19 +29,18 @@ impl TranspilerExtension for MemcpyTranspilerExtension { let instruction_u32 = instruction_stream[0]; let opcode = (instruction_u32 & 0x7f) as u8; - let funct3 = ((instruction_u32 >> 12) & 0b111) as u8; // Check if this is our custom memcpy_loop instruction - if (opcode, funct3) != (MEMCPY_LOOP_OPCODE, MEMCPY_LOOP_FUNCT3) { + if opcode != MEMCPY_LOOP_OPCODE { return None; } - // Parse I-type instruction format + // Parse U-type instruction format let dec_insn = UType::new(instruction_u32); - let shift = dec_insn.imm as u8; + let shift = dec_insn.imm >> 12; // Validate shift value (0, 1, 2, or 3) - if shift > 3u8 { + if ![0, 1, 2, 3].contains(&shift) { return None; } From 42435bf83eec6b76fd1e4fa2512f2a31726516e6 Mon Sep 17 00:00:00 2001 From: Peyman Jabbarzade Date: Tue, 19 Aug 2025 14:16:58 -0400 Subject: [PATCH 03/14] fix memcpy_loop opcode in memcpy.s --- crates/toolchain/openvm/src/memcpy.s | 2 +- extensions/memcpy/circuit/src/core.rs | 14 +++++++------- 2 files changed, 8 insertions(+), 8 deletions(-) diff --git a/crates/toolchain/openvm/src/memcpy.s b/crates/toolchain/openvm/src/memcpy.s index 0e63bfdcab..1606d576dc 100644 --- a/crates/toolchain/openvm/src/memcpy.s +++ b/crates/toolchain/openvm/src/memcpy.s @@ -208,7 +208,7 @@ # Define memcpy_loop macro for custom instruction (U-type) .macro memcpy_loop shift - .word 0x72000000 | (\shift << 12) # opcode 0x72 + shift in immediate field (bits 12-31) + .word 0x00000072 | (\shift << 12) # opcode 0x72 + shift in immediate field (bits 12-31) .endm .globl memcpy .p2align 2 diff --git a/extensions/memcpy/circuit/src/core.rs b/extensions/memcpy/circuit/src/core.rs index 80fdf1a7d1..042996e95b 100644 --- a/extensions/memcpy/circuit/src/core.rs +++ b/extensions/memcpy/circuit/src/core.rs @@ -39,7 +39,7 @@ use openvm_stark_backend::{ use crate::{bus::MemcpyBus, MemcpyIterChip}; use openvm_circuit::arch::{ - execution_mode::{E1ExecutionCtx, E2ExecutionCtx}, + execution_mode::{ExecutionCtxTrait, MeteredExecutionCtxTrait}, get_record_from_slice, ExecuteFunc, ExecutionError, Executor, MeteredExecutor, RecordArena, StaticProgramError, TraceFiller, VmExecState, }; @@ -324,7 +324,7 @@ where } fn execute( - &mut self, + &self, state: VmStateMut, instruction: &Instruction, ) -> Result<(), ExecutionError> { @@ -565,7 +565,7 @@ impl Executor for MemcpyLoopExecutor { data: &mut [u8], ) -> Result, StaticProgramError> where - Ctx: E1ExecutionCtx, + Ctx: ExecutionCtxTrait, { let data: &mut MemcpyLoopPreCompute = data.borrow_mut(); self.pre_compute_impl(pc, inst, data)?; @@ -586,7 +586,7 @@ impl MeteredExecutor for MemcpyLoopExecutor { data: &mut [u8], ) -> Result, StaticProgramError> where - Ctx: E2ExecutionCtx, + Ctx: MeteredExecutionCtxTrait, { let data: &mut E2PreCompute = data.borrow_mut(); data.chip_idx = chip_idx as u32; @@ -596,7 +596,7 @@ impl MeteredExecutor for MemcpyLoopExecutor { } #[inline(always)] -unsafe fn execute_e12_impl( +unsafe fn execute_e12_impl( pre_compute: &MemcpyLoopPreCompute, vm_state: &mut VmExecState, ) { @@ -674,7 +674,7 @@ unsafe fn execute_e12_impl( vm_state.instret += 1; } -unsafe fn execute_e1_impl( +unsafe fn execute_e1_impl( pre_compute: &[u8], vm_state: &mut VmExecState, ) { @@ -682,7 +682,7 @@ unsafe fn execute_e1_impl( execute_e12_impl::(pre_compute, vm_state); } -unsafe fn execute_e2_impl( +unsafe fn execute_e2_impl( pre_compute: &[u8], vm_state: &mut VmExecState, ) { From 371b57786a53ecba236b3849f6fc2149fee29ada Mon Sep 17 00:00:00 2001 From: Peyman Jabbarzade Date: Tue, 19 Aug 2025 15:49:44 -0400 Subject: [PATCH 04/14] fix: add memcpy transpiler to tests --- Cargo.lock | 10 ++++ crates/sdk/src/config/openvm_standard.toml | 1 + extensions/algebra/tests/Cargo.toml | 1 + extensions/algebra/tests/src/lib.rs | 22 +++++--- extensions/ecc/tests/Cargo.toml | 1 + extensions/ecc/tests/src/lib.rs | 16 ++++-- extensions/memcpy/transpiler/src/lib.rs | 3 +- extensions/rv32im/tests/Cargo.toml | 1 + extensions/rv32im/tests/src/lib.rs | 40 ++++++++++----- guest-libs/ff_derive/Cargo.toml | 1 + guest-libs/ff_derive/tests/lib.rs | 22 +++++--- guest-libs/k256/Cargo.toml | 1 + guest-libs/k256/tests/lib.rs | 16 ++++-- guest-libs/keccak256/Cargo.toml | 1 + guest-libs/keccak256/tests/lib.rs | 4 +- guest-libs/p256/Cargo.toml | 1 + guest-libs/p256/tests/lib.rs | 16 ++++-- guest-libs/pairing/Cargo.toml | 1 + guest-libs/pairing/tests/lib.rs | 50 +++++++++++++------ guest-libs/ruint/Cargo.toml | 1 + guest-libs/ruint/tests/lib.rs | 4 +- guest-libs/sha2/Cargo.toml | 1 + guest-libs/sha2/tests/lib.rs | 4 +- .../verify_stark/tests/integration_test.rs | 1 + 24 files changed, 157 insertions(+), 62 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 687d5ceae2..f6cacedf82 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -4385,6 +4385,7 @@ dependencies = [ "openvm-ecc-guest", "openvm-ecc-sw-macros", "openvm-ecc-transpiler", + "openvm-memcpy-transpiler", "openvm-rv32im-transpiler", "openvm-sha256-circuit", "openvm-sha256-transpiler", @@ -5196,6 +5197,7 @@ dependencies = [ "openvm-circuit", "openvm-ecc-circuit", "openvm-instructions", + "openvm-memcpy-transpiler", "openvm-rv32im-transpiler", "openvm-stark-sdk", "openvm-toolchain-tests", @@ -5584,6 +5586,7 @@ dependencies = [ "openvm-circuit", "openvm-ecc-circuit", "openvm-ecc-transpiler", + "openvm-memcpy-transpiler", "openvm-rv32im-transpiler", "openvm-sdk", "openvm-stark-sdk", @@ -5630,6 +5633,7 @@ dependencies = [ "openvm-algebra-transpiler", "openvm-circuit", "openvm-instructions", + "openvm-memcpy-transpiler", "openvm-rv32im-transpiler", "openvm-stark-sdk", "openvm-toolchain-tests", @@ -5680,6 +5684,7 @@ dependencies = [ "openvm-keccak256-circuit", "openvm-keccak256-guest", "openvm-keccak256-transpiler", + "openvm-memcpy-transpiler", "openvm-rv32im-transpiler", "openvm-stark-sdk", "openvm-toolchain-tests", @@ -5924,6 +5929,7 @@ dependencies = [ "openvm-ecc-sw-macros", "openvm-ecc-transpiler", "openvm-instructions", + "openvm-memcpy-transpiler", "openvm-pairing", "openvm-pairing-circuit", "openvm-pairing-guest", @@ -6104,6 +6110,7 @@ dependencies = [ "openvm", "openvm-circuit", "openvm-instructions", + "openvm-memcpy-transpiler", "openvm-rv32im-circuit", "openvm-rv32im-guest", "openvm-rv32im-transpiler", @@ -6209,6 +6216,7 @@ dependencies = [ "eyre", "openvm-circuit", "openvm-instructions", + "openvm-memcpy-transpiler", "openvm-rv32im-transpiler", "openvm-sha256-circuit", "openvm-sha256-guest", @@ -6450,6 +6458,7 @@ dependencies = [ "openvm-ecc-guest", "openvm-ecc-sw-macros", "openvm-ecc-transpiler", + "openvm-memcpy-transpiler", "openvm-rv32im-transpiler", "openvm-sha256-circuit", "openvm-sha256-transpiler", @@ -7963,6 +7972,7 @@ dependencies = [ "openvm-bigint-transpiler", "openvm-circuit", "openvm-instructions", + "openvm-memcpy-transpiler", "openvm-rv32im-transpiler", "openvm-stark-sdk", "openvm-toolchain-tests", diff --git a/crates/sdk/src/config/openvm_standard.toml b/crates/sdk/src/config/openvm_standard.toml index f1f9267191..23c6a3d68a 100644 --- a/crates/sdk/src/config/openvm_standard.toml +++ b/crates/sdk/src/config/openvm_standard.toml @@ -3,6 +3,7 @@ [app_vm_config.io] [app_vm_config.keccak] +[app_vm_config.memcpy] [app_vm_config.sha256] [app_vm_config.bigint] diff --git a/extensions/algebra/tests/Cargo.toml b/extensions/algebra/tests/Cargo.toml index 0d748f3d88..e73112735d 100644 --- a/extensions/algebra/tests/Cargo.toml +++ b/extensions/algebra/tests/Cargo.toml @@ -15,6 +15,7 @@ openvm-transpiler.workspace = true openvm-algebra-transpiler.workspace = true openvm-algebra-circuit.workspace = true openvm-rv32im-transpiler.workspace = true +openvm-memcpy-transpiler.workspace = true openvm-toolchain-tests = { path = "../../../crates/toolchain/tests" } openvm-ecc-circuit.workspace = true eyre.workspace = true diff --git a/extensions/algebra/tests/src/lib.rs b/extensions/algebra/tests/src/lib.rs index c931dce496..1a56370953 100644 --- a/extensions/algebra/tests/src/lib.rs +++ b/extensions/algebra/tests/src/lib.rs @@ -12,6 +12,7 @@ mod tests { use openvm_circuit::utils::{air_test, test_system_config}; use openvm_ecc_circuit::SECP256K1_CONFIG; use openvm_instructions::exe::VmExe; + use openvm_memcpy_transpiler::MemcpyTranspilerExtension; use openvm_rv32im_transpiler::{ Rv32ITranspilerExtension, Rv32IoTranspilerExtension, Rv32MTranspilerExtension, }; @@ -49,7 +50,8 @@ mod tests { .with_extension(Rv32ITranspilerExtension) .with_extension(Rv32MTranspilerExtension) .with_extension(Rv32IoTranspilerExtension) - .with_extension(ModularTranspilerExtension), + .with_extension(ModularTranspilerExtension) + .with_extension(MemcpyTranspilerExtension), )?; air_test(Rv32ModularBuilder, config, openvm_exe); @@ -66,7 +68,8 @@ mod tests { .with_extension(Rv32ITranspilerExtension) .with_extension(Rv32MTranspilerExtension) .with_extension(Rv32IoTranspilerExtension) - .with_extension(ModularTranspilerExtension), + .with_extension(ModularTranspilerExtension) + .with_extension(MemcpyTranspilerExtension), )?; air_test(Rv32ModularBuilder, config, openvm_exe); Ok(()) @@ -93,7 +96,8 @@ mod tests { .with_extension(Rv32MTranspilerExtension) .with_extension(Rv32IoTranspilerExtension) .with_extension(Fp2TranspilerExtension) - .with_extension(ModularTranspilerExtension), + .with_extension(ModularTranspilerExtension) + .with_extension(MemcpyTranspilerExtension), )?; air_test(Rv32ModularWithFp2Builder, config, openvm_exe); Ok(()) @@ -125,7 +129,8 @@ mod tests { .with_extension(Rv32MTranspilerExtension) .with_extension(Rv32IoTranspilerExtension) .with_extension(Fp2TranspilerExtension) - .with_extension(ModularTranspilerExtension), + .with_extension(ModularTranspilerExtension) + .with_extension(MemcpyTranspilerExtension), )?; air_test(Rv32ModularWithFp2Builder, config, openvm_exe); Ok(()) @@ -145,7 +150,8 @@ mod tests { .with_extension(Rv32MTranspilerExtension) .with_extension(Rv32IoTranspilerExtension) .with_extension(Fp2TranspilerExtension) - .with_extension(ModularTranspilerExtension), + .with_extension(ModularTranspilerExtension) + .with_extension(MemcpyTranspilerExtension), )?; air_test(Rv32ModularWithFp2Builder, config, openvm_exe); Ok(()) @@ -173,7 +179,8 @@ mod tests { .with_extension(Rv32MTranspilerExtension) .with_extension(Rv32IoTranspilerExtension) .with_extension(Fp2TranspilerExtension) - .with_extension(ModularTranspilerExtension), + .with_extension(ModularTranspilerExtension) + .with_extension(MemcpyTranspilerExtension), ) .unwrap(); air_test(Rv32ModularBuilder, config, openvm_exe); @@ -189,7 +196,8 @@ mod tests { .with_extension(Rv32ITranspilerExtension) .with_extension(Rv32MTranspilerExtension) .with_extension(Rv32IoTranspilerExtension) - .with_extension(ModularTranspilerExtension), + .with_extension(ModularTranspilerExtension) + .with_extension(MemcpyTranspilerExtension), )?; air_test(Rv32ModularBuilder, config, openvm_exe); Ok(()) diff --git a/extensions/ecc/tests/Cargo.toml b/extensions/ecc/tests/Cargo.toml index 42a0bee912..73adc0cb24 100644 --- a/extensions/ecc/tests/Cargo.toml +++ b/extensions/ecc/tests/Cargo.toml @@ -15,6 +15,7 @@ openvm-algebra-transpiler.workspace = true openvm-ecc-transpiler.workspace = true openvm-ecc-circuit.workspace = true openvm-rv32im-transpiler.workspace = true +openvm-memcpy-transpiler.workspace = true openvm-toolchain-tests = { path = "../../../crates/toolchain/tests" } openvm-sdk.workspace = true serde.workspace = true diff --git a/extensions/ecc/tests/src/lib.rs b/extensions/ecc/tests/src/lib.rs index 2a10837dd5..a2593aa53d 100644 --- a/extensions/ecc/tests/src/lib.rs +++ b/extensions/ecc/tests/src/lib.rs @@ -17,6 +17,7 @@ mod tests { CurveConfig, Rv32WeierstrassBuilder, Rv32WeierstrassConfig, P256_CONFIG, SECP256K1_CONFIG, }; use openvm_ecc_transpiler::EccTranspilerExtension; + use openvm_memcpy_transpiler::MemcpyTranspilerExtension; use openvm_rv32im_transpiler::{ Rv32ITranspilerExtension, Rv32IoTranspilerExtension, Rv32MTranspilerExtension, }; @@ -60,7 +61,8 @@ mod tests { .with_extension(Rv32MTranspilerExtension) .with_extension(Rv32IoTranspilerExtension) .with_extension(EccTranspilerExtension) - .with_extension(ModularTranspilerExtension), + .with_extension(ModularTranspilerExtension) + .with_extension(MemcpyTranspilerExtension), )?; air_test(Rv32WeierstrassBuilder, config, openvm_exe); Ok(()) @@ -82,7 +84,8 @@ mod tests { .with_extension(Rv32MTranspilerExtension) .with_extension(Rv32IoTranspilerExtension) .with_extension(EccTranspilerExtension) - .with_extension(ModularTranspilerExtension), + .with_extension(ModularTranspilerExtension) + .with_extension(MemcpyTranspilerExtension), )?; air_test(Rv32WeierstrassBuilder, config, openvm_exe); Ok(()) @@ -105,7 +108,8 @@ mod tests { .with_extension(Rv32MTranspilerExtension) .with_extension(Rv32IoTranspilerExtension) .with_extension(EccTranspilerExtension) - .with_extension(ModularTranspilerExtension), + .with_extension(ModularTranspilerExtension) + .with_extension(MemcpyTranspilerExtension), )?; air_test(Rv32WeierstrassBuilder, config, openvm_exe); Ok(()) @@ -152,7 +156,8 @@ mod tests { .with_extension(Rv32MTranspilerExtension) .with_extension(Rv32IoTranspilerExtension) .with_extension(EccTranspilerExtension) - .with_extension(ModularTranspilerExtension), + .with_extension(ModularTranspilerExtension) + .with_extension(MemcpyTranspilerExtension), )?; let p = Secp256k1Affine::generator(); @@ -265,7 +270,8 @@ mod tests { .with_extension(Rv32MTranspilerExtension) .with_extension(Rv32IoTranspilerExtension) .with_extension(EccTranspilerExtension) - .with_extension(ModularTranspilerExtension), + .with_extension(ModularTranspilerExtension) + .with_extension(MemcpyTranspilerExtension), ) .unwrap(); let config = diff --git a/extensions/memcpy/transpiler/src/lib.rs b/extensions/memcpy/transpiler/src/lib.rs index dd60126222..bd39fd5d86 100644 --- a/extensions/memcpy/transpiler/src/lib.rs +++ b/extensions/memcpy/transpiler/src/lib.rs @@ -36,8 +36,9 @@ impl TranspilerExtension for MemcpyTranspilerExtension { } // Parse U-type instruction format - let dec_insn = UType::new(instruction_u32); + let mut dec_insn = UType::new(instruction_u32); let shift = dec_insn.imm >> 12; + dec_insn.rd = 1; // avoid using x0, otherwise we get nop() // Validate shift value (0, 1, 2, or 3) if ![0, 1, 2, 3].contains(&shift) { diff --git a/extensions/rv32im/tests/Cargo.toml b/extensions/rv32im/tests/Cargo.toml index 2de37c69c8..ee69a25904 100644 --- a/extensions/rv32im/tests/Cargo.toml +++ b/extensions/rv32im/tests/Cargo.toml @@ -15,6 +15,7 @@ openvm-transpiler.workspace = true openvm-rv32im-circuit.workspace = true openvm-rv32im-guest.workspace = true openvm-rv32im-transpiler.workspace = true +openvm-memcpy-transpiler.workspace = true openvm = { workspace = true } openvm-toolchain-tests = { path = "../../../crates/toolchain/tests" } eyre.workspace = true diff --git a/extensions/rv32im/tests/src/lib.rs b/extensions/rv32im/tests/src/lib.rs index a28a2da3f8..79461e910b 100644 --- a/extensions/rv32im/tests/src/lib.rs +++ b/extensions/rv32im/tests/src/lib.rs @@ -9,6 +9,7 @@ mod tests { utils::{air_test, air_test_with_min_segments, test_system_config}, }; use openvm_instructions::{exe::VmExe, instruction::Instruction, LocalOpcode, SystemOpcode}; + use openvm_memcpy_transpiler::MemcpyTranspilerExtension; use openvm_rv32im_circuit::{Rv32IBuilder, Rv32IConfig, Rv32ImBuilder, Rv32ImConfig}; use openvm_rv32im_guest::hint_load_by_key_encode; use openvm_rv32im_transpiler::{ @@ -46,7 +47,8 @@ mod tests { Transpiler::::default() .with_extension(Rv32ITranspilerExtension) .with_extension(Rv32MTranspilerExtension) - .with_extension(Rv32IoTranspilerExtension), + .with_extension(Rv32IoTranspilerExtension) + .with_extension(MemcpyTranspilerExtension), )?; change_rv32m_insn_to_nop(&mut exe); air_test_with_min_segments(Rv32IBuilder, config, exe, vec![], min_segments); @@ -63,7 +65,8 @@ mod tests { Transpiler::::default() .with_extension(Rv32ITranspilerExtension) .with_extension(Rv32IoTranspilerExtension) - .with_extension(Rv32MTranspilerExtension), + .with_extension(Rv32MTranspilerExtension) + .with_extension(MemcpyTranspilerExtension), )?; air_test_with_min_segments(Rv32ImBuilder, config, exe, vec![], min_segments); Ok(()) @@ -84,7 +87,8 @@ mod tests { Transpiler::::default() .with_extension(Rv32ITranspilerExtension) .with_extension(Rv32IoTranspilerExtension) - .with_extension(Rv32MTranspilerExtension), + .with_extension(Rv32MTranspilerExtension) + .with_extension(MemcpyTranspilerExtension), )?; air_test_with_min_segments(Rv32ImBuilder, config, exe, vec![], min_segments); Ok(()) @@ -99,7 +103,8 @@ mod tests { Transpiler::::default() .with_extension(Rv32ITranspilerExtension) .with_extension(Rv32MTranspilerExtension) - .with_extension(Rv32IoTranspilerExtension), + .with_extension(Rv32IoTranspilerExtension) + .with_extension(MemcpyTranspilerExtension), )?; let input = vec![[0, 1, 2, 3].map(F::from_canonical_u8).to_vec()]; air_test_with_min_segments(Rv32ImBuilder, config, exe, input, 1); @@ -115,7 +120,8 @@ mod tests { Transpiler::::default() .with_extension(Rv32ITranspilerExtension) .with_extension(Rv32MTranspilerExtension) - .with_extension(Rv32IoTranspilerExtension), + .with_extension(Rv32IoTranspilerExtension) + .with_extension(MemcpyTranspilerExtension), )?; // stdin will be read after reading kv_store let stdin = vec![[0, 1, 2].map(F::from_canonical_u8).to_vec()]; @@ -138,7 +144,8 @@ mod tests { Transpiler::::default() .with_extension(Rv32ITranspilerExtension) .with_extension(Rv32MTranspilerExtension) - .with_extension(Rv32IoTranspilerExtension), + .with_extension(Rv32IoTranspilerExtension) + .with_extension(MemcpyTranspilerExtension), )?; #[derive(serde::Serialize)] @@ -169,7 +176,8 @@ mod tests { Transpiler::::default() .with_extension(Rv32ITranspilerExtension) .with_extension(Rv32MTranspilerExtension) - .with_extension(Rv32IoTranspilerExtension), + .with_extension(Rv32IoTranspilerExtension) + .with_extension(MemcpyTranspilerExtension), )?; let executor = VmExecutor::new(config.clone())?; @@ -211,7 +219,8 @@ mod tests { Transpiler::::default() .with_extension(Rv32ITranspilerExtension) .with_extension(Rv32MTranspilerExtension) - .with_extension(Rv32IoTranspilerExtension), + .with_extension(Rv32IoTranspilerExtension) + .with_extension(MemcpyTranspilerExtension), )?; air_test(Rv32ImBuilder, config, exe); Ok(()) @@ -226,7 +235,8 @@ mod tests { Transpiler::::default() .with_extension(Rv32ITranspilerExtension) .with_extension(Rv32MTranspilerExtension) - .with_extension(Rv32IoTranspilerExtension), + .with_extension(Rv32IoTranspilerExtension) + .with_extension(MemcpyTranspilerExtension), )?; let executor = VmExecutor::new(config)?; @@ -253,7 +263,8 @@ mod tests { Transpiler::::default() .with_extension(Rv32ITranspilerExtension) .with_extension(Rv32MTranspilerExtension) - .with_extension(Rv32IoTranspilerExtension), + .with_extension(Rv32IoTranspilerExtension) + .with_extension(MemcpyTranspilerExtension), )?; air_test(Rv32ImBuilder, config, exe); Ok(()) @@ -273,7 +284,8 @@ mod tests { Transpiler::::default() .with_extension(Rv32ITranspilerExtension) .with_extension(Rv32MTranspilerExtension) - .with_extension(Rv32IoTranspilerExtension), + .with_extension(Rv32IoTranspilerExtension) + .with_extension(MemcpyTranspilerExtension), )?; air_test(Rv32ImBuilder, config, exe); Ok(()) @@ -289,7 +301,8 @@ mod tests { Transpiler::::default() .with_extension(Rv32ITranspilerExtension) .with_extension(Rv32MTranspilerExtension) - .with_extension(Rv32IoTranspilerExtension), + .with_extension(Rv32IoTranspilerExtension) + .with_extension(MemcpyTranspilerExtension), ) .unwrap(); let executor = VmExecutor::new(config).unwrap(); @@ -315,7 +328,8 @@ mod tests { Transpiler::::default() .with_extension(Rv32ITranspilerExtension) .with_extension(Rv32MTranspilerExtension) - .with_extension(Rv32IoTranspilerExtension), + .with_extension(Rv32IoTranspilerExtension) + .with_extension(MemcpyTranspilerExtension), ) .unwrap(); air_test(Rv32ImBuilder, config, exe); diff --git a/guest-libs/ff_derive/Cargo.toml b/guest-libs/ff_derive/Cargo.toml index acb017123f..f2a39da8be 100644 --- a/guest-libs/ff_derive/Cargo.toml +++ b/guest-libs/ff_derive/Cargo.toml @@ -35,6 +35,7 @@ openvm-transpiler = { workspace = true } openvm-algebra-transpiler = { workspace = true } openvm-algebra-circuit = { workspace = true } openvm-rv32im-transpiler = { workspace = true } +openvm-memcpy-transpiler = { workspace = true } openvm-toolchain-tests = { workspace = true } eyre = { workspace = true } diff --git a/guest-libs/ff_derive/tests/lib.rs b/guest-libs/ff_derive/tests/lib.rs index db7a1a09de..48de004639 100644 --- a/guest-libs/ff_derive/tests/lib.rs +++ b/guest-libs/ff_derive/tests/lib.rs @@ -8,6 +8,7 @@ mod tests { use openvm_algebra_transpiler::ModularTranspilerExtension; use openvm_circuit::utils::{air_test, test_system_config}; use openvm_instructions::exe::VmExe; + use openvm_memcpy_transpiler::MemcpyTranspilerExtension; use openvm_rv32im_transpiler::{ Rv32ITranspilerExtension, Rv32IoTranspilerExtension, Rv32MTranspilerExtension, }; @@ -43,7 +44,8 @@ mod tests { .with_extension(Rv32ITranspilerExtension) .with_extension(Rv32MTranspilerExtension) .with_extension(Rv32IoTranspilerExtension) - .with_extension(ModularTranspilerExtension), + .with_extension(ModularTranspilerExtension) + .with_extension(MemcpyTranspilerExtension), )?; air_test(Rv32ModularBuilder, config, openvm_exe); @@ -62,7 +64,8 @@ mod tests { .with_extension(Rv32ITranspilerExtension) .with_extension(Rv32MTranspilerExtension) .with_extension(Rv32IoTranspilerExtension) - .with_extension(ModularTranspilerExtension), + .with_extension(ModularTranspilerExtension) + .with_extension(MemcpyTranspilerExtension), )?; air_test(Rv32ModularBuilder, config, openvm_exe); @@ -81,7 +84,8 @@ mod tests { .with_extension(Rv32ITranspilerExtension) .with_extension(Rv32MTranspilerExtension) .with_extension(Rv32IoTranspilerExtension) - .with_extension(ModularTranspilerExtension), + .with_extension(ModularTranspilerExtension) + .with_extension(MemcpyTranspilerExtension), )?; air_test(Rv32ModularBuilder, config, openvm_exe); @@ -105,7 +109,8 @@ mod tests { .with_extension(Rv32ITranspilerExtension) .with_extension(Rv32MTranspilerExtension) .with_extension(Rv32IoTranspilerExtension) - .with_extension(ModularTranspilerExtension), + .with_extension(ModularTranspilerExtension) + .with_extension(MemcpyTranspilerExtension), )?; air_test(Rv32ModularBuilder, config, openvm_exe); @@ -129,7 +134,8 @@ mod tests { .with_extension(Rv32ITranspilerExtension) .with_extension(Rv32MTranspilerExtension) .with_extension(Rv32IoTranspilerExtension) - .with_extension(ModularTranspilerExtension), + .with_extension(ModularTranspilerExtension) + .with_extension(MemcpyTranspilerExtension), )?; air_test(Rv32ModularBuilder, config, openvm_exe); @@ -154,7 +160,8 @@ mod tests { .with_extension(Rv32ITranspilerExtension) .with_extension(Rv32MTranspilerExtension) .with_extension(Rv32IoTranspilerExtension) - .with_extension(ModularTranspilerExtension), + .with_extension(ModularTranspilerExtension) + .with_extension(MemcpyTranspilerExtension), )?; air_test(Rv32ModularBuilder, config, openvm_exe); @@ -178,7 +185,8 @@ mod tests { .with_extension(Rv32ITranspilerExtension) .with_extension(Rv32MTranspilerExtension) .with_extension(Rv32IoTranspilerExtension) - .with_extension(ModularTranspilerExtension), + .with_extension(ModularTranspilerExtension) + .with_extension(MemcpyTranspilerExtension), )?; air_test(Rv32ModularBuilder, config, openvm_exe); diff --git a/guest-libs/k256/Cargo.toml b/guest-libs/k256/Cargo.toml index 6e42d7e52c..c0931ec8fa 100644 --- a/guest-libs/k256/Cargo.toml +++ b/guest-libs/k256/Cargo.toml @@ -38,6 +38,7 @@ openvm-ecc-circuit.workspace = true openvm-sha256-circuit.workspace = true openvm-sha256-transpiler.workspace = true openvm-rv32im-transpiler.workspace = true +openvm-memcpy-transpiler = { workspace = true } openvm-toolchain-tests.workspace = true openvm-stark-backend.workspace = true diff --git a/guest-libs/k256/tests/lib.rs b/guest-libs/k256/tests/lib.rs index 59eb42dde3..b71ef9dad0 100644 --- a/guest-libs/k256/tests/lib.rs +++ b/guest-libs/k256/tests/lib.rs @@ -10,6 +10,7 @@ mod guest_tests { CurveConfig, Rv32WeierstrassBuilder, Rv32WeierstrassConfig, SECP256K1_CONFIG, }; use openvm_ecc_transpiler::EccTranspilerExtension; + use openvm_memcpy_transpiler::MemcpyTranspilerExtension; use openvm_rv32im_transpiler::{ Rv32ITranspilerExtension, Rv32IoTranspilerExtension, Rv32MTranspilerExtension, }; @@ -41,7 +42,8 @@ mod guest_tests { .with_extension(Rv32MTranspilerExtension) .with_extension(Rv32IoTranspilerExtension) .with_extension(EccTranspilerExtension) - .with_extension(ModularTranspilerExtension), + .with_extension(ModularTranspilerExtension) + .with_extension(MemcpyTranspilerExtension), )?; air_test(Rv32WeierstrassBuilder, config, openvm_exe); Ok(()) @@ -59,7 +61,8 @@ mod guest_tests { .with_extension(Rv32MTranspilerExtension) .with_extension(Rv32IoTranspilerExtension) .with_extension(EccTranspilerExtension) - .with_extension(ModularTranspilerExtension), + .with_extension(ModularTranspilerExtension) + .with_extension(MemcpyTranspilerExtension), )?; air_test(Rv32WeierstrassBuilder, config, openvm_exe); Ok(()) @@ -80,7 +83,8 @@ mod guest_tests { .with_extension(Rv32MTranspilerExtension) .with_extension(Rv32IoTranspilerExtension) .with_extension(EccTranspilerExtension) - .with_extension(ModularTranspilerExtension), + .with_extension(ModularTranspilerExtension) + .with_extension(MemcpyTranspilerExtension), )?; air_test(Rv32WeierstrassBuilder, config, openvm_exe); Ok(()) @@ -237,7 +241,8 @@ mod guest_tests { .with_extension(Rv32IoTranspilerExtension) .with_extension(EccTranspilerExtension) .with_extension(ModularTranspilerExtension) - .with_extension(Sha256TranspilerExtension), + .with_extension(Sha256TranspilerExtension) + .with_extension(MemcpyTranspilerExtension), )?; air_test(EcdsaBuilder, config, openvm_exe); Ok(()) @@ -258,7 +263,8 @@ mod guest_tests { .with_extension(Rv32MTranspilerExtension) .with_extension(Rv32IoTranspilerExtension) .with_extension(EccTranspilerExtension) - .with_extension(ModularTranspilerExtension), + .with_extension(ModularTranspilerExtension) + .with_extension(MemcpyTranspilerExtension), )?; air_test(Rv32WeierstrassBuilder, config, openvm_exe); Ok(()) diff --git a/guest-libs/keccak256/Cargo.toml b/guest-libs/keccak256/Cargo.toml index 276ec2679f..4712554154 100644 --- a/guest-libs/keccak256/Cargo.toml +++ b/guest-libs/keccak256/Cargo.toml @@ -20,6 +20,7 @@ openvm-transpiler = { workspace = true } openvm-keccak256-transpiler = { workspace = true } openvm-keccak256-circuit = { workspace = true } openvm-rv32im-transpiler = { workspace = true } +openvm-memcpy-transpiler = { workspace = true } openvm-toolchain-tests = { workspace = true } eyre = { workspace = true } diff --git a/guest-libs/keccak256/tests/lib.rs b/guest-libs/keccak256/tests/lib.rs index fbab4ef8c8..ae2c8c953d 100644 --- a/guest-libs/keccak256/tests/lib.rs +++ b/guest-libs/keccak256/tests/lib.rs @@ -5,6 +5,7 @@ mod tests { use openvm_instructions::exe::VmExe; use openvm_keccak256_circuit::{Keccak256Rv32Builder, Keccak256Rv32Config}; use openvm_keccak256_transpiler::Keccak256TranspilerExtension; + use openvm_memcpy_transpiler::MemcpyTranspilerExtension; use openvm_rv32im_transpiler::{ Rv32ITranspilerExtension, Rv32IoTranspilerExtension, Rv32MTranspilerExtension, }; @@ -25,7 +26,8 @@ mod tests { .with_extension(Keccak256TranspilerExtension) .with_extension(Rv32ITranspilerExtension) .with_extension(Rv32MTranspilerExtension) - .with_extension(Rv32IoTranspilerExtension), + .with_extension(Rv32IoTranspilerExtension) + .with_extension(MemcpyTranspilerExtension), )?; air_test(Keccak256Rv32Builder, config, openvm_exe); Ok(()) diff --git a/guest-libs/p256/Cargo.toml b/guest-libs/p256/Cargo.toml index 852fa7af95..d9ca0c405a 100644 --- a/guest-libs/p256/Cargo.toml +++ b/guest-libs/p256/Cargo.toml @@ -35,6 +35,7 @@ openvm-ecc-circuit.workspace = true openvm-sha256-circuit.workspace = true openvm-sha256-transpiler.workspace = true openvm-rv32im-transpiler.workspace = true +openvm-memcpy-transpiler.workspace = true openvm-toolchain-tests.workspace = true openvm-stark-backend.workspace = true diff --git a/guest-libs/p256/tests/lib.rs b/guest-libs/p256/tests/lib.rs index 9eaf2b2c74..5e81e44e10 100644 --- a/guest-libs/p256/tests/lib.rs +++ b/guest-libs/p256/tests/lib.rs @@ -10,6 +10,7 @@ mod guest_tests { CurveConfig, Rv32WeierstrassBuilder, Rv32WeierstrassConfig, P256_CONFIG, }; use openvm_ecc_transpiler::EccTranspilerExtension; + use openvm_memcpy_transpiler::MemcpyTranspilerExtension; use openvm_rv32im_transpiler::{ Rv32ITranspilerExtension, Rv32IoTranspilerExtension, Rv32MTranspilerExtension, }; @@ -41,7 +42,8 @@ mod guest_tests { .with_extension(Rv32MTranspilerExtension) .with_extension(Rv32IoTranspilerExtension) .with_extension(EccTranspilerExtension) - .with_extension(ModularTranspilerExtension), + .with_extension(ModularTranspilerExtension) + .with_extension(MemcpyTranspilerExtension), )?; air_test(Rv32WeierstrassBuilder, config, openvm_exe); Ok(()) @@ -59,7 +61,8 @@ mod guest_tests { .with_extension(Rv32MTranspilerExtension) .with_extension(Rv32IoTranspilerExtension) .with_extension(EccTranspilerExtension) - .with_extension(ModularTranspilerExtension), + .with_extension(ModularTranspilerExtension) + .with_extension(MemcpyTranspilerExtension), )?; air_test(Rv32WeierstrassBuilder, config, openvm_exe); Ok(()) @@ -80,7 +83,8 @@ mod guest_tests { .with_extension(Rv32MTranspilerExtension) .with_extension(Rv32IoTranspilerExtension) .with_extension(EccTranspilerExtension) - .with_extension(ModularTranspilerExtension), + .with_extension(ModularTranspilerExtension) + .with_extension(MemcpyTranspilerExtension), )?; air_test(Rv32WeierstrassBuilder, config, openvm_exe); Ok(()) @@ -237,7 +241,8 @@ mod guest_tests { .with_extension(Rv32IoTranspilerExtension) .with_extension(EccTranspilerExtension) .with_extension(ModularTranspilerExtension) - .with_extension(Sha256TranspilerExtension), + .with_extension(Sha256TranspilerExtension) + .with_extension(MemcpyTranspilerExtension), )?; air_test(EcdsaBuilder, config, openvm_exe); Ok(()) @@ -258,7 +263,8 @@ mod guest_tests { .with_extension(Rv32MTranspilerExtension) .with_extension(Rv32IoTranspilerExtension) .with_extension(EccTranspilerExtension) - .with_extension(ModularTranspilerExtension), + .with_extension(ModularTranspilerExtension) + .with_extension(MemcpyTranspilerExtension), )?; air_test(Rv32WeierstrassBuilder, config, openvm_exe); Ok(()) diff --git a/guest-libs/pairing/Cargo.toml b/guest-libs/pairing/Cargo.toml index 649f73f864..8ac04ebef4 100644 --- a/guest-libs/pairing/Cargo.toml +++ b/guest-libs/pairing/Cargo.toml @@ -46,6 +46,7 @@ openvm-ecc-circuit.workspace = true openvm-ecc-guest.workspace = true openvm-ecc-transpiler.workspace = true openvm-rv32im-transpiler.workspace = true +openvm-memcpy-transpiler.workspace = true openvm = { workspace = true } openvm-toolchain-tests = { workspace = true } eyre.workspace = true diff --git a/guest-libs/pairing/tests/lib.rs b/guest-libs/pairing/tests/lib.rs index 68150d536a..cd79b97545 100644 --- a/guest-libs/pairing/tests/lib.rs +++ b/guest-libs/pairing/tests/lib.rs @@ -36,6 +36,7 @@ mod bn254 { use openvm_rv32im_transpiler::{ Rv32ITranspilerExtension, Rv32IoTranspilerExtension, Rv32MTranspilerExtension, }; + use openvm_memcpy_transpiler::MemcpyTranspilerExtension; use openvm_stark_sdk::{ config::FriParameters, openvm_stark_backend::p3_field::FieldAlgebra, p3_baby_bear::BabyBear, }; @@ -85,7 +86,8 @@ mod bn254 { .with_extension(Rv32MTranspilerExtension) .with_extension(Rv32IoTranspilerExtension) .with_extension(EccTranspilerExtension) - .with_extension(ModularTranspilerExtension), + .with_extension(ModularTranspilerExtension) + .with_extension(MemcpyTranspilerExtension), )?; air_test(Rv32WeierstrassBuilder, config, openvm_exe); Ok(()) @@ -108,7 +110,8 @@ mod bn254 { .with_extension(Rv32IoTranspilerExtension) .with_extension(PairingTranspilerExtension) .with_extension(ModularTranspilerExtension) - .with_extension(Fp2TranspilerExtension), + .with_extension(Fp2TranspilerExtension) + .with_extension(MemcpyTranspilerExtension), )?; let mut rng = rand::rngs::StdRng::seed_from_u64(2); @@ -144,7 +147,8 @@ mod bn254 { .with_extension(Rv32IoTranspilerExtension) .with_extension(PairingTranspilerExtension) .with_extension(ModularTranspilerExtension) - .with_extension(Fp2TranspilerExtension), + .with_extension(Fp2TranspilerExtension) + .with_extension(MemcpyTranspilerExtension), )?; let mut rng = rand::rngs::StdRng::seed_from_u64(2); @@ -202,7 +206,8 @@ mod bn254 { .with_extension(Rv32IoTranspilerExtension) .with_extension(PairingTranspilerExtension) .with_extension(ModularTranspilerExtension) - .with_extension(Fp2TranspilerExtension), + .with_extension(Fp2TranspilerExtension) + .with_extension(MemcpyTranspilerExtension), )?; let mut rng = rand::rngs::StdRng::seed_from_u64(20); @@ -251,7 +256,8 @@ mod bn254 { .with_extension(Rv32IoTranspilerExtension) .with_extension(PairingTranspilerExtension) .with_extension(ModularTranspilerExtension) - .with_extension(Fp2TranspilerExtension), + .with_extension(Fp2TranspilerExtension) + .with_extension(MemcpyTranspilerExtension), )?; let S = G1Affine::generator(); @@ -304,7 +310,8 @@ mod bn254 { .with_extension(Rv32IoTranspilerExtension) .with_extension(PairingTranspilerExtension) .with_extension(ModularTranspilerExtension) - .with_extension(Fp2TranspilerExtension), + .with_extension(Fp2TranspilerExtension) + .with_extension(MemcpyTranspilerExtension), )?; let S = G1Affine::generator(); @@ -361,7 +368,8 @@ mod bn254 { .with_extension(Rv32IoTranspilerExtension) .with_extension(PairingTranspilerExtension) .with_extension(ModularTranspilerExtension) - .with_extension(Fp2TranspilerExtension), + .with_extension(Fp2TranspilerExtension) + .with_extension(MemcpyTranspilerExtension), )?; let S = G1Affine::generator(); @@ -426,7 +434,8 @@ mod bn254 { .with_extension(Rv32IoTranspilerExtension) .with_extension(PairingTranspilerExtension) .with_extension(ModularTranspilerExtension) - .with_extension(Fp2TranspilerExtension), + .with_extension(Fp2TranspilerExtension) + .with_extension(MemcpyTranspilerExtension), )?; let P = G1Affine::generator(); @@ -490,6 +499,7 @@ mod bls12_381 { AffinePoint, }; use openvm_ecc_transpiler::EccTranspilerExtension; + use openvm_memcpy_transpiler::MemcpyTranspilerExtension; use openvm_pairing_circuit::{ PairingCurve, PairingExtension, Rv32PairingBuilder, Rv32PairingConfig, }; @@ -560,7 +570,8 @@ mod bls12_381 { .with_extension(Rv32MTranspilerExtension) .with_extension(Rv32IoTranspilerExtension) .with_extension(EccTranspilerExtension) - .with_extension(ModularTranspilerExtension), + .with_extension(ModularTranspilerExtension) + .with_extension(MemcpyTranspilerExtension), )?; air_test(Rv32WeierstrassBuilder, config, openvm_exe); Ok(()) @@ -583,7 +594,8 @@ mod bls12_381 { .with_extension(Rv32IoTranspilerExtension) .with_extension(PairingTranspilerExtension) .with_extension(ModularTranspilerExtension) - .with_extension(Fp2TranspilerExtension), + .with_extension(Fp2TranspilerExtension) + .with_extension(MemcpyTranspilerExtension), )?; let mut rng = rand::rngs::StdRng::seed_from_u64(50); @@ -619,7 +631,8 @@ mod bls12_381 { .with_extension(Rv32IoTranspilerExtension) .with_extension(PairingTranspilerExtension) .with_extension(ModularTranspilerExtension) - .with_extension(Fp2TranspilerExtension), + .with_extension(Fp2TranspilerExtension) + .with_extension(MemcpyTranspilerExtension), )?; let mut rng = rand::rngs::StdRng::seed_from_u64(5); @@ -678,7 +691,8 @@ mod bls12_381 { .with_extension(Rv32IoTranspilerExtension) .with_extension(PairingTranspilerExtension) .with_extension(ModularTranspilerExtension) - .with_extension(Fp2TranspilerExtension), + .with_extension(Fp2TranspilerExtension) + .with_extension(MemcpyTranspilerExtension), )?; let mut rng = rand::rngs::StdRng::seed_from_u64(88); @@ -727,7 +741,8 @@ mod bls12_381 { .with_extension(Rv32IoTranspilerExtension) .with_extension(PairingTranspilerExtension) .with_extension(ModularTranspilerExtension) - .with_extension(Fp2TranspilerExtension), + .with_extension(Fp2TranspilerExtension) + .with_extension(MemcpyTranspilerExtension), )?; let S = G1Affine::generator(); @@ -786,7 +801,8 @@ mod bls12_381 { .with_extension(Rv32IoTranspilerExtension) .with_extension(PairingTranspilerExtension) .with_extension(ModularTranspilerExtension) - .with_extension(Fp2TranspilerExtension), + .with_extension(Fp2TranspilerExtension) + .with_extension(MemcpyTranspilerExtension), )?; let S = G1Affine::generator(); @@ -843,7 +859,8 @@ mod bls12_381 { .with_extension(Rv32IoTranspilerExtension) .with_extension(PairingTranspilerExtension) .with_extension(ModularTranspilerExtension) - .with_extension(Fp2TranspilerExtension), + .with_extension(Fp2TranspilerExtension) + .with_extension(MemcpyTranspilerExtension), )?; let S = G1Affine::generator(); @@ -907,7 +924,8 @@ mod bls12_381 { .with_extension(Rv32IoTranspilerExtension) .with_extension(PairingTranspilerExtension) .with_extension(ModularTranspilerExtension) - .with_extension(Fp2TranspilerExtension), + .with_extension(Fp2TranspilerExtension) + .with_extension(MemcpyTranspilerExtension), )?; let P = G1Affine::generator(); diff --git a/guest-libs/ruint/Cargo.toml b/guest-libs/ruint/Cargo.toml index 214f315e4f..d4ffe6137b 100644 --- a/guest-libs/ruint/Cargo.toml +++ b/guest-libs/ruint/Cargo.toml @@ -114,6 +114,7 @@ openvm-transpiler.workspace = true openvm-bigint-transpiler.workspace = true openvm-bigint-circuit.workspace = true openvm-rv32im-transpiler.workspace = true +openvm-memcpy-transpiler.workspace = true openvm-toolchain-tests = { workspace = true } eyre.workspace = true diff --git a/guest-libs/ruint/tests/lib.rs b/guest-libs/ruint/tests/lib.rs index 1ebef69bad..a6f42327c3 100644 --- a/guest-libs/ruint/tests/lib.rs +++ b/guest-libs/ruint/tests/lib.rs @@ -5,6 +5,7 @@ mod tests { use openvm_bigint_transpiler::Int256TranspilerExtension; use openvm_circuit::utils::air_test; use openvm_instructions::exe::VmExe; + use openvm_memcpy_transpiler::MemcpyTranspilerExtension; use openvm_rv32im_transpiler::{ Rv32ITranspilerExtension, Rv32IoTranspilerExtension, Rv32MTranspilerExtension, }; @@ -28,7 +29,8 @@ mod tests { .with_extension(Rv32ITranspilerExtension) .with_extension(Rv32MTranspilerExtension) .with_extension(Rv32IoTranspilerExtension) - .with_extension(Int256TranspilerExtension), + .with_extension(Int256TranspilerExtension) + .with_extension(MemcpyTranspilerExtension), )?; air_test(Int256Rv32Builder, config, openvm_exe); Ok(()) diff --git a/guest-libs/sha2/Cargo.toml b/guest-libs/sha2/Cargo.toml index 573930affb..e56b9eac1f 100644 --- a/guest-libs/sha2/Cargo.toml +++ b/guest-libs/sha2/Cargo.toml @@ -20,6 +20,7 @@ openvm-transpiler = { workspace = true } openvm-sha256-transpiler = { workspace = true } openvm-sha256-circuit = { workspace = true } openvm-rv32im-transpiler = { workspace = true } +openvm-memcpy-transpiler = { workspace = true } openvm-toolchain-tests = { workspace = true } eyre = { workspace = true } diff --git a/guest-libs/sha2/tests/lib.rs b/guest-libs/sha2/tests/lib.rs index adfae8e764..9d449627dd 100644 --- a/guest-libs/sha2/tests/lib.rs +++ b/guest-libs/sha2/tests/lib.rs @@ -9,6 +9,7 @@ mod tests { use openvm_sha256_circuit::{Sha256Rv32Builder, Sha256Rv32Config}; use openvm_sha256_transpiler::Sha256TranspilerExtension; use openvm_stark_sdk::p3_baby_bear::BabyBear; + use openvm_memcpy_transpiler::MemcpyTranspilerExtension; use openvm_toolchain_tests::{build_example_program_at_path, get_programs_dir}; use openvm_transpiler::{transpiler::Transpiler, FromElf}; @@ -25,7 +26,8 @@ mod tests { .with_extension(Rv32ITranspilerExtension) .with_extension(Rv32MTranspilerExtension) .with_extension(Rv32IoTranspilerExtension) - .with_extension(Sha256TranspilerExtension), + .with_extension(Sha256TranspilerExtension) + .with_extension(MemcpyTranspilerExtension), )?; air_test(Sha256Rv32Builder, config, openvm_exe); Ok(()) diff --git a/guest-libs/verify_stark/tests/integration_test.rs b/guest-libs/verify_stark/tests/integration_test.rs index ea1150976a..1e1caf03c7 100644 --- a/guest-libs/verify_stark/tests/integration_test.rs +++ b/guest-libs/verify_stark/tests/integration_test.rs @@ -33,6 +33,7 @@ fn test_verify_openvm_stark_e2e() -> Result<()> { .rv32m(Default::default()) .io(Default::default()) .native(Default::default()) + .memcpy(Default::default()) .build(); let fri_params = FriParameters::new_for_testing(LEAF_LOG_BLOWUP); let app_config = AppConfig::new_with_leaf_fri_params(fri_params, vm_config.clone(), fri_params); From 24e3d154df9db60a9de0098c031b249f77481234 Mon Sep 17 00:00:00 2001 From: Peyman Jabbarzade Date: Thu, 21 Aug 2025 12:16:44 -0400 Subject: [PATCH 05/14] fix: make memcpy_iter as executer and memcpy_loop as periphery (reverse the chips) --- .../src/system/memory/offline_checker/mod.rs | 2 +- extensions/memcpy/circuit/src/core.rs | 533 +++--------- extensions/memcpy/circuit/src/extension.rs | 17 +- extensions/memcpy/circuit/src/iteration.rs | 757 +++++++++++++----- 4 files changed, 693 insertions(+), 616 deletions(-) diff --git a/crates/vm/src/system/memory/offline_checker/mod.rs b/crates/vm/src/system/memory/offline_checker/mod.rs index c903319a23..010a74bb31 100644 --- a/crates/vm/src/system/memory/offline_checker/mod.rs +++ b/crates/vm/src/system/memory/offline_checker/mod.rs @@ -15,7 +15,7 @@ pub struct MemoryBaseAuxRecord { } #[repr(C)] -#[derive(Debug, Clone)] +#[derive(Debug, Clone, Default)] pub struct MemoryExtendedAuxRecord { pub prev_timestamp: u32, pub timestamp_lt_aux: [u32; AUX_LEN], diff --git a/extensions/memcpy/circuit/src/core.rs b/extensions/memcpy/circuit/src/core.rs index 042996e95b..bf4180b7c7 100644 --- a/extensions/memcpy/circuit/src/core.rs +++ b/extensions/memcpy/circuit/src/core.rs @@ -1,20 +1,19 @@ use std::{ array, borrow::{Borrow, BorrowMut}, - mem::size_of, - sync::Arc, + mem::{align_of, size_of}, + sync::{Arc, Mutex}, }; use openvm_circuit::{ arch::*, - system::memory::{ + system::{memory::{ offline_checker::{ - MemoryBaseAuxCols, MemoryBaseAuxRecord, MemoryBridge, MemoryReadAuxRecord, - MemoryWriteAuxCols, MemoryWriteBytesAuxRecord, + MemoryBaseAuxCols, MemoryBaseAuxRecord, MemoryBridge, MemoryExtendedAuxRecord, MemoryReadAuxRecord, MemoryWriteAuxCols, MemoryWriteBytesAuxRecord }, online::{GuestMemory, TracingMemory}, MemoryAddress, MemoryAuxColsFactory, - }, + }, SystemPort}, }; use openvm_circuit_primitives::{ utils::{not, or, select}, @@ -30,18 +29,14 @@ use openvm_instructions::{ }; use openvm_rv32im_circuit::adapters::{read_rv32_register, tracing_read, tracing_write}; use openvm_stark_backend::{ - interaction::InteractionBuilder, - p3_air::{Air, AirBuilder, BaseAir}, - p3_field::{Field, FieldAlgebra, PrimeField32}, - p3_matrix::Matrix, - rap::{BaseAirWithPublicValues, PartitionedBaseAir}, + config::{StarkGenericConfig, Val}, interaction::InteractionBuilder, p3_air::{Air, AirBuilder, BaseAir}, p3_field::{Field, FieldAlgebra, PrimeField32}, p3_matrix::{dense::RowMajorMatrix, Matrix}, prover::{cpu::CpuBackend, types::AirProvingContext}, rap::{get_air_name, BaseAirWithPublicValues, PartitionedBaseAir}, Chip, ChipUsageGetter }; use crate::{bus::MemcpyBus, MemcpyIterChip}; use openvm_circuit::arch::{ execution_mode::{ExecutionCtxTrait, MeteredExecutionCtxTrait}, - get_record_from_slice, ExecuteFunc, ExecutionError, Executor, - MeteredExecutor, RecordArena, StaticProgramError, TraceFiller, VmExecState, + get_record_from_slice, ExecuteFunc, ExecutionError, Executor, MeteredExecutor, RecordArena, + StaticProgramError, TraceFiller, VmExecState, }; use openvm_memcpy_transpiler::Rv32MemcpyOpcode; @@ -69,6 +64,9 @@ pub struct MemcpyLoopCols { pub to_source_minus_twelve_carry: T, } +pub const NUM_MEMCPY_LOOP_COLS: usize = size_of::>(); +pub const MEMCPY_LOOP_NUM_WRITES: u32 = 3; + #[derive(Copy, Clone, Debug, derive_new::new)] pub struct MemcpyLoopAir { pub memory_bridge: MemoryBridge, @@ -93,11 +91,10 @@ impl Air for MemcpyLoopAir { let local = main.row_slice(0); let local: &MemcpyLoopCols = (*local).borrow(); - let timestamp: AB::Var = local.from_state.timestamp; - let mut timestamp_delta: usize = 0; + let mut timestamp_delta: u32 = 0; let mut timestamp_pp = || { timestamp_delta += 1; - timestamp + AB::Expr::from_canonical_usize(timestamp_delta - 1) + local.to_timestamp - AB::Expr::from_canonical_u32(MEMCPY_LOOP_NUM_WRITES + timestamp_delta - 1) }; let from_le_bytes = |data: [AB::Var; 4]| { @@ -239,14 +236,9 @@ impl Air for MemcpyLoopAir { // Send message to memcpy call bus self.memcpy_bus .send( - timestamp + AB::Expr::from_canonical_usize(timestamp_delta), - dest - AB::Expr::from_canonical_u32(16), - source - - select::( - is_shift_non_zero.clone(), - AB::Expr::from_canonical_u32(28), - AB::Expr::from_canonical_u32(16), - ), + local.from_state.timestamp, + dest, + source - AB::Expr::from_canonical_u32(12) * is_shift_non_zero.clone(), len.clone() - shift.clone(), shift.clone(), ) @@ -255,9 +247,9 @@ impl Air for MemcpyLoopAir { // Receive message from memcpy return bus self.memcpy_bus .receive( - local.to_timestamp, + local.to_timestamp - AB::Expr::from_canonical_u32(timestamp_delta), to_dest, - to_source, + to_source - AB::Expr::from_canonical_u32(12) * is_shift_non_zero.clone(), to_len - shift.clone(), AB::Expr::from_canonical_u32(4), ) @@ -265,7 +257,7 @@ impl Air for MemcpyLoopAir { // Make sure the request and response match builder.assert_eq( - local.to_timestamp - (timestamp + AB::Expr::from_canonical_usize(timestamp_delta)), + local.to_timestamp - (local.from_state.timestamp + AB::Expr::from_canonical_u32(timestamp_delta)), AB::Expr::TWO * (len.clone() - to_len) + is_shift_non_zero.clone(), ); @@ -281,242 +273,72 @@ impl Air for MemcpyLoopAir { } } -#[derive(derive_new::new, Clone, Copy)] -pub struct MemcpyLoopExecutor {} - #[repr(C)] #[derive(AlignedBytesBorrow, Debug)] pub struct MemcpyLoopRecord { - pub shift: [u8; 2], - pub dest: [u8; MEMCPY_LOOP_NUM_LIMBS], - pub source: [u8; MEMCPY_LOOP_NUM_LIMBS], - pub len: [u8; MEMCPY_LOOP_NUM_LIMBS], pub from_pc: u32, pub from_timestamp: u32, - pub register_aux: [MemoryBaseAuxRecord; 3], - pub memory_read_data: Vec<[u8; MEMCPY_LOOP_NUM_LIMBS]>, - pub read_aux: Vec, - pub write_aux: Vec>, + pub dest: u32, + pub source: u32, + pub len: u32, + pub shift: u8, + pub write_aux: [MemoryExtendedAuxRecord; 3], } -#[derive(derive_new::new)] -pub struct MemcpyLoopFiller { +pub struct MemcpyLoopChip { + pub air: MemcpyLoopAir, + pub records: Arc>>, pub pointer_max_bits: usize, pub range_checker_chip: SharedVariableRangeCheckerChip, - pub memcpy_iter_chip: Arc, -} - -pub type MemcpyLoopChip = VmChipWrapper; - -#[derive(AlignedBytesBorrow, Clone)] -#[repr(C)] -struct MemcpyLoopPreCompute { - c: u8, } -impl PreflightExecutor for MemcpyLoopExecutor -where - F: PrimeField32, - for<'buf> RA: RecordArena<'buf, EmptyMultiRowLayout, &'buf mut MemcpyLoopRecord>, -{ - fn get_opcode_name(&self, _: usize) -> String { - format!("{:?}", Rv32MemcpyOpcode::MEMCPY_LOOP) - } - - fn execute( - &self, - state: VmStateMut, - instruction: &Instruction, - ) -> Result<(), ExecutionError> { - let Instruction { opcode, c, .. } = instruction; - debug_assert_eq!(*opcode, Rv32MemcpyOpcode::MEMCPY_LOOP.global_opcode()); - let shift = c.as_canonical_u32() as u8; - debug_assert!([0, 1, 2, 3].contains(&shift)); - let mut record = state.ctx.alloc(EmptyMultiRowLayout::default()); - - let mut dest = read_rv32_register( - state.memory.data(), - if shift == 0 { - A3_REGISTER_PTR - } else { - A1_REGISTER_PTR - } as u32, - ); - let mut source = read_rv32_register( - state.memory.data(), - if shift == 0 { - A4_REGISTER_PTR - } else { - A3_REGISTER_PTR - } as u32, - ); - let mut len = read_rv32_register(state.memory.data(), A2_REGISTER_PTR as u32); - - // Store the original values in the record - record.shift = [shift % 2, shift / 2]; - record.from_pc = *state.pc; - record.from_timestamp = state.memory.timestamp; - - let num_iterations = (len - shift as u32) & !15; - let to_dest = dest + num_iterations; - let to_source = source + num_iterations; - let to_len = len - num_iterations; - - tracing_write( - state.memory, - RV32_REGISTER_AS, - if shift == 0 { - A3_REGISTER_PTR - } else { - A1_REGISTER_PTR - } as u32, - to_dest.to_le_bytes(), - &mut record.register_aux[0].prev_timestamp, - &mut record.dest, - ); - - tracing_write( - state.memory, - RV32_REGISTER_AS, - if shift == 0 { - A4_REGISTER_PTR - } else { - A3_REGISTER_PTR - } as u32, - to_source.to_le_bytes(), - &mut record.register_aux[1].prev_timestamp, - &mut record.source, - ); - - tracing_write( - state.memory, - RV32_REGISTER_AS, - A2_REGISTER_PTR as u32, - to_len.to_le_bytes(), - &mut record.register_aux[2].prev_timestamp, - &mut record.len, - ); - - let mut prev_data = if shift == 0 { - [0; 4] - } else { - source -= 12; - record - .read_aux - .push(MemoryReadAuxRecord { prev_timestamp: 0 }); - let data = tracing_read( - state.memory, - RV32_MEMORY_AS, - source - 4, - &mut record.read_aux.last_mut().unwrap().prev_timestamp, - ); - record.memory_read_data.push(data); - data - }; - - while len - shift as u32 > 15 { - let writes_data: [[u8; MEMCPY_LOOP_NUM_LIMBS]; 4] = array::from_fn(|i| { - record - .read_aux - .push(MemoryReadAuxRecord { prev_timestamp: 0 }); - let data = tracing_read( - state.memory, - RV32_MEMORY_AS, - source + 4 * i as u32, - &mut record.read_aux.last_mut().unwrap().prev_timestamp, - ); - record.memory_read_data.push(data); - let write_data: [u8; MEMCPY_LOOP_NUM_LIMBS] = array::from_fn(|i| { - if i < 4 - shift as usize { - data[i + shift as usize] - } else { - prev_data[i - (4 - shift as usize)] - } - }); - prev_data = data; - write_data - }); - writes_data.iter().enumerate().for_each(|(i, write_data)| { - record.write_aux.push(MemoryWriteBytesAuxRecord { - prev_timestamp: 0, - prev_data: [0; MEMCPY_LOOP_NUM_LIMBS], - }); - tracing_write( - state.memory, - RV32_MEMORY_AS, - dest + 4 * i as u32, - *write_data, - &mut record.write_aux.clone().last_mut().unwrap().prev_timestamp, - &mut record.write_aux.clone().last_mut().unwrap().prev_data, - ); - }); - len -= 16; - source += 16; - dest += 16; +impl MemcpyLoopChip { + pub fn new( + system_port: SystemPort, + range_bus: VariableRangeCheckerBus, + memcpy_bus: MemcpyBus, + pointer_max_bits: usize, + range_checker_chip: SharedVariableRangeCheckerChip, + ) -> Self { + Self { + air: MemcpyLoopAir::new(system_port.memory_bridge, ExecutionBridge::new(system_port.execution_bus, system_port.program_bus), range_bus, memcpy_bus, pointer_max_bits), + records: Arc::new(Mutex::new(Vec::new())), + pointer_max_bits, + range_checker_chip, } - - *state.pc = state.pc.wrapping_add(DEFAULT_PC_STEP); - - Ok(()) } -} -impl TraceFiller for MemcpyLoopFiller { - fn fill_trace_row(&self, mem_helper: &MemoryAuxColsFactory, mut row: &mut [F]) { - let record: &MemcpyLoopRecord = unsafe { get_record_from_slice(&mut row, ()) }; - let row: &mut MemcpyLoopCols = row.borrow_mut(); + pub fn bus(&self) -> MemcpyBus { + self.air.memcpy_bus + } - const NUM_WRITES: u32 = 3; + pub fn clear(&self) { + self.records.lock().unwrap().clear(); + } - let shift = record.shift[0] + record.shift[1] * 2; - let dest = u32::from_le_bytes(record.dest); - let source = u32::from_le_bytes(record.source); - let len = u32::from_le_bytes(record.len); - let num_copies = (len - shift as u32) & !15; + pub fn add_new_loop<'a, F: PrimeField32>( + &self, + mem_helper: &MemoryAuxColsFactory, + from_pc: u32, + from_timestamp: u32, + dest: u32, + source: u32, + len: u32, + shift: u8, + register_aux: [MemoryBaseAuxRecord; 3], + ) { + let mut timestamp = from_timestamp + (((len - shift as u32) & !0x0f) >> 1) + (shift != 0) as u32; + let write_aux = register_aux.iter().map(|aux_record| { + let mut aux_col = MemoryBaseAuxCols::default(); + mem_helper.fill(aux_record.prev_timestamp, timestamp, &mut aux_col); + timestamp += 1; + MemoryExtendedAuxRecord::from_aux_cols(aux_col) + }).collect::>().try_into().unwrap(); + + let num_copies = (len - shift as u32) & !0x0f; let to_dest = dest + num_copies; let to_source = source + num_copies; - let to_len = len - num_copies; - let timestamp = record.from_timestamp; - - let source_minus_twelve_carry = if shift == 0 { - F::ZERO - } else { - F::from_canonical_u8((source % (1 << 8) < 12) as u8) - }; - let to_source_minus_twelve_carry = if shift == 0 { - F::ZERO - } else { - F::from_canonical_u8((to_source % (1 << 8) < 12) as u8) - }; - - for ((i, cols), register_aux_record) in row - .write_aux - .iter_mut() - .enumerate() - .zip(record.register_aux.iter()) - { - mem_helper.fill( - register_aux_record.prev_timestamp, - timestamp + i as u32, - cols, - ); - } - - row.source_minus_twelve_carry = source_minus_twelve_carry; - row.to_source_minus_twelve_carry = to_source_minus_twelve_carry; - row.to_dest = to_dest.to_le_bytes().map(F::from_canonical_u8); - row.to_source = to_source.to_le_bytes().map(F::from_canonical_u8); - row.to_len = F::from_canonical_u32(to_len); - row.to_timestamp = - F::from_canonical_u32(timestamp + NUM_WRITES + 2 * num_copies + (shift != 0) as u32); - row.is_valid = F::ONE; - row.dest = record.dest.map(F::from_canonical_u8); - row.source = record.source.map(F::from_canonical_u8); - row.len = record.len.map(F::from_canonical_u8); - row.shift = record.shift.map(F::from_canonical_u8); - row.from_state.timestamp = F::from_canonical_u32(record.from_timestamp); - row.from_state.pc = F::from_canonical_u32(record.from_pc); - + let word_to_u16 = |data: u32| [data & 0xffff, data >> 16]; let range_check_data = [ (word_to_u16(len), false), @@ -538,175 +360,78 @@ impl TraceFiller for MemcpyLoopFiller { .add_count(data[1], self.pointer_max_bits - 2 * MEMCPY_LOOP_LIMB_BITS); }); - // Handle MemcpyIter - self.memcpy_iter_chip.add_new_loop( - mem_helper, - timestamp + NUM_WRITES, - dest - 16, - source - 16 - 12 * (shift != 0) as u32, - len - shift as u32, + // Create record + let row = MemcpyLoopRecord { + from_pc, + from_timestamp, + dest, + source, + len, shift, - record.memory_read_data.clone(), - record.read_aux.clone(), - record.write_aux.clone(), - ); - } -} + write_aux, + }; -impl Executor for MemcpyLoopExecutor { - fn pre_compute_size(&self) -> usize { - size_of::() + // Thread-safe push to rows vector + if let Ok(mut rows_guard) = self.records.lock() { + rows_guard.push(row); + } } - fn pre_compute( - &self, - pc: u32, - inst: &Instruction, - data: &mut [u8], - ) -> Result, StaticProgramError> - where - Ctx: ExecutionCtxTrait, - { - let data: &mut MemcpyLoopPreCompute = data.borrow_mut(); - self.pre_compute_impl(pc, inst, data)?; - Ok(execute_e1_impl::<_, _>) + /// Generates trace + pub fn generate_trace(&self) -> RowMajorMatrix { + let mut rows = F::zero_vec((self.records.lock().unwrap().len() as usize) * NUM_MEMCPY_LOOP_COLS); + + for (i, record) in self.records.lock().unwrap().iter().enumerate() { + let row = &mut rows[i * NUM_MEMCPY_LOOP_COLS..(i + 1) * NUM_MEMCPY_LOOP_COLS]; + let cols: &mut MemcpyLoopCols = row.borrow_mut(); + + let shift = record.shift; + let num_copies = (record.len - shift as u32) & !0x0f; + let to_source = record.source + num_copies; + + cols.from_state.pc = F::from_canonical_u32(record.from_pc); + cols.from_state.timestamp = F::from_canonical_u32(record.from_timestamp); + cols.dest = record.dest.to_le_bytes().map(F::from_canonical_u8); + cols.source = record.source.to_le_bytes().map(F::from_canonical_u8); + cols.len = record.len.to_le_bytes().map(F::from_canonical_u8); + cols.shift = [F::from_canonical_u8(shift % 2), F::from_canonical_u8(shift / 2)]; + cols.is_valid = F::ONE; + // We have MEMCPY_LOOP_NUM_WRITES writes in the loop, (num_copies / 4 + shift != 0) reads and (num_copies / 4) writes in iterations + cols.to_timestamp = F::from_canonical_u32(record.from_timestamp + MEMCPY_LOOP_NUM_WRITES + (num_copies >> 1) + (shift != 0) as u32); + cols.to_dest = (record.dest + num_copies).to_le_bytes().map(F::from_canonical_u8); + cols.to_source = to_source.to_le_bytes().map(F::from_canonical_u8); + cols.to_len = F::from_canonical_u32(record.len - num_copies); + cols.write_aux = record.write_aux.clone().map(|aux| aux.to_aux_cols()); + cols.source_minus_twelve_carry = F::from_bool((record.source & 0x0ff) < 12); + cols.to_source_minus_twelve_carry = F::from_bool((to_source & 0x0ff) < 12); + } + RowMajorMatrix::new(rows, NUM_MEMCPY_LOOP_COLS) } } -impl MeteredExecutor for MemcpyLoopExecutor { - fn metered_pre_compute_size(&self) -> usize { - size_of::>() - } - - fn metered_pre_compute( - &self, - chip_idx: usize, - pc: u32, - inst: &Instruction, - data: &mut [u8], - ) -> Result, StaticProgramError> - where - Ctx: MeteredExecutionCtxTrait, - { - let data: &mut E2PreCompute = data.borrow_mut(); - data.chip_idx = chip_idx as u32; - self.pre_compute_impl(pc, inst, &mut data.data)?; - Ok(execute_e2_impl::<_, _>) +// We allow any `R` type so this can work with arbitrary record arenas. +impl Chip> for MemcpyLoopChip +where + Val: PrimeField32, +{ + /// Generates trace and resets the internal counters all to 0. + fn generate_proving_ctx(&self, _: R) -> AirProvingContext> { + let trace = self.generate_trace::>(); + AirProvingContext::simple_no_pis(Arc::new(trace)) } } -#[inline(always)] -unsafe fn execute_e12_impl( - pre_compute: &MemcpyLoopPreCompute, - vm_state: &mut VmExecState, -) { - let shift = pre_compute.c; - let (dest, source) = if shift == 0 { - ( - vm_state.vm_read::(RV32_REGISTER_AS, A3_REGISTER_PTR as u32), - vm_state.vm_read::(RV32_REGISTER_AS, A4_REGISTER_PTR as u32), - ) - } else { - ( - vm_state.vm_read::(RV32_REGISTER_AS, A1_REGISTER_PTR as u32), - vm_state.vm_read::(RV32_REGISTER_AS, A3_REGISTER_PTR as u32), - ) - }; - let len = vm_state.vm_read::(RV32_REGISTER_AS, A2_REGISTER_PTR as u32); - - let mut dest = u32::from_le_bytes(dest); - let mut source = u32::from_le_bytes(source); - let mut len = u32::from_le_bytes(len); - - let mut prev_data = if shift == 0 { - [0; 4] - } else { - source -= 12; - vm_state.vm_read::(RV32_MEMORY_AS, source - 4) - }; - - while len - shift as u32 > 15 { - for i in 0..4 { - let data = vm_state.vm_read::(RV32_MEMORY_AS, source + 4 * i); - let write_data: [u8; 4] = array::from_fn(|i| { - if i < 4 - shift as usize { - data[i + shift as usize] - } else { - prev_data[i - (4 - shift as usize)] - } - }); - vm_state.vm_write(RV32_MEMORY_AS, dest + 4 * i, &write_data); - prev_data = data; - } - len -= 16; - source += 16; - dest += 16; +impl ChipUsageGetter for MemcpyLoopChip { + fn air_name(&self) -> String { + get_air_name(&self.air) } - - // Write the result back to memory - if shift == 0 { - vm_state.vm_write( - RV32_REGISTER_AS, - A3_REGISTER_PTR as u32, - &dest.to_le_bytes(), - ); - vm_state.vm_write( - RV32_REGISTER_AS, - A4_REGISTER_PTR as u32, - &source.to_le_bytes(), - ); - } else { - source += 12; - vm_state.vm_write( - RV32_REGISTER_AS, - A1_REGISTER_PTR as u32, - &dest.to_le_bytes(), - ); - vm_state.vm_write( - RV32_REGISTER_AS, - A3_REGISTER_PTR as u32, - &source.to_le_bytes(), - ); - }; - vm_state.vm_write(RV32_REGISTER_AS, A2_REGISTER_PTR as u32, &len.to_le_bytes()); - - vm_state.pc = vm_state.pc.wrapping_add(DEFAULT_PC_STEP); - vm_state.instret += 1; -} - -unsafe fn execute_e1_impl( - pre_compute: &[u8], - vm_state: &mut VmExecState, -) { - let pre_compute: &MemcpyLoopPreCompute = pre_compute.borrow(); - execute_e12_impl::(pre_compute, vm_state); -} - -unsafe fn execute_e2_impl( - pre_compute: &[u8], - vm_state: &mut VmExecState, -) { - let pre_compute: &E2PreCompute = pre_compute.borrow(); - vm_state - .ctx - .on_height_change(pre_compute.chip_idx as usize, 1); - execute_e12_impl::(&pre_compute.data, vm_state); -} - -impl MemcpyLoopExecutor { - fn pre_compute_impl( - &self, - pc: u32, - inst: &Instruction, - data: &mut MemcpyLoopPreCompute, - ) -> Result<(), StaticProgramError> { - let Instruction { opcode, c, .. } = inst; - let c_u32 = c.as_canonical_u32(); - if ![0, 1, 2, 3].contains(&c_u32) { - return Err(StaticProgramError::InvalidInstruction(pc)); - } - *data = MemcpyLoopPreCompute { c: c_u32 as u8 }; - assert_eq!(*opcode, Rv32MemcpyOpcode::MEMCPY_LOOP.global_opcode()); - Ok(()) + fn constant_trace_height(&self) -> Option { + Some(self.records.lock().unwrap().len() as usize) + } + fn current_trace_height(&self) -> usize { + self.records.lock().unwrap().len() as usize + } + fn trace_width(&self) -> usize { + NUM_MEMCPY_LOOP_COLS } } diff --git a/extensions/memcpy/circuit/src/extension.rs b/extensions/memcpy/circuit/src/extension.rs index d485618a50..184bdfc4bf 100644 --- a/extensions/memcpy/circuit/src/extension.rs +++ b/extensions/memcpy/circuit/src/extension.rs @@ -31,7 +31,7 @@ pub struct Memcpy; #[derive(Clone, From, AnyEnum, Executor, MeteredExecutor, PreflightExecutor)] pub enum MemcpyExecutor { - MemcpyLoop(MemcpyLoopExecutor), + MemcpyLoop(MemcpyIterExecutor), } impl VmExecutionExtension for Memcpy { @@ -41,10 +41,10 @@ impl VmExecutionExtension for Memcpy { &self, inventory: &mut ExecutorInventoryBuilder, ) -> Result<(), ExecutorInventoryError> { - let memcpy_loop = MemcpyLoopExecutor::new(); + let memcpy_iter = MemcpyIterExecutor::new(); inventory.add_executor( - memcpy_loop, + memcpy_iter, Rv32MemcpyOpcode::iter().map(|x| x.global_opcode()), )?; @@ -110,23 +110,22 @@ where .unwrap() .memcpy_bus; - let memcpy_iter_chip = Arc::new(MemcpyIterChip::new( - inventory.airs().system().port().memory_bridge, + let memcpy_loop_chip = Arc::new(MemcpyLoopChip::new( + inventory.airs().system().port(), range_bus, memcpy_bus, pointer_max_bits, range_checker.clone(), )); - let memcpy_loop_chip = MemcpyLoopChip::new( - MemcpyLoopFiller::new( + let memcpy_iter_chip = MemcpyIterChip::new( + MemcpyIterFiller::new( pointer_max_bits, range_checker.clone(), - memcpy_iter_chip.clone(), + memcpy_loop_chip.clone(), ), mem_helper.clone(), ); - // Add MemcpyLoop chip inventory.next_air::()?; inventory.add_executor_chip(memcpy_loop_chip); diff --git a/extensions/memcpy/circuit/src/iteration.rs b/extensions/memcpy/circuit/src/iteration.rs index 6191b061bd..1c9aca0458 100644 --- a/extensions/memcpy/circuit/src/iteration.rs +++ b/extensions/memcpy/circuit/src/iteration.rs @@ -2,15 +2,24 @@ use std::{ array, borrow::{Borrow, BorrowMut}, mem::size_of, - sync::{atomic::AtomicU32, Arc, Mutex}, + sync::Arc, }; -use openvm_circuit::system::memory::{ - offline_checker::{ - MemoryBaseAuxCols, MemoryBridge, MemoryExtendedAuxRecord, MemoryReadAuxCols, - MemoryReadAuxRecord, MemoryWriteAuxCols, MemoryWriteBytesAuxRecord, +use openvm_circuit::{ + arch::{ + get_record_from_slice, CustomBorrow, E2PreCompute, ExecuteFunc, ExecutionCtxTrait, + ExecutionError, Executor, MeteredExecutionCtxTrait, MeteredExecutor, MultiRowLayout, + MultiRowMetadata, PreflightExecutor, RecordArena, SizedRecord, StaticProgramError, + TraceFiller, VmChipWrapper, VmExecState, VmStateMut, + }, + system::memory::{ + offline_checker::{ + MemoryBaseAuxRecord, MemoryBridge, MemoryReadAuxCols, MemoryReadAuxRecord, + MemoryWriteAuxCols, MemoryWriteBytesAuxRecord, + }, + online::{GuestMemory, TracingMemory}, + MemoryAddress, MemoryAuxColsFactory, }, - MemoryAddress, MemoryAuxColsFactory, }; use openvm_circuit_primitives::{ utils::{and, not, or, select}, @@ -18,19 +27,27 @@ use openvm_circuit_primitives::{ AlignedBytesBorrow, }; use openvm_circuit_primitives_derive::AlignedBorrow; -use openvm_instructions::riscv::RV32_MEMORY_AS; +use openvm_instructions::{ + instruction::Instruction, + program::DEFAULT_PC_STEP, + riscv::{RV32_MEMORY_AS, RV32_REGISTER_AS}, + LocalOpcode, +}; +use openvm_memcpy_transpiler::Rv32MemcpyOpcode; +use openvm_rv32im_circuit::adapters::{read_rv32_register, tracing_read, tracing_write}; use openvm_stark_backend::{ - config::{StarkGenericConfig, Val}, interaction::InteractionBuilder, p3_air::{Air, AirBuilder, BaseAir}, p3_field::{Field, FieldAlgebra, PrimeField32}, p3_matrix::{dense::RowMajorMatrix, Matrix}, - prover::{cpu::CpuBackend, types::AirProvingContext}, - rap::{get_air_name, BaseAirWithPublicValues, PartitionedBaseAir}, - Chip, ChipUsageGetter, + p3_maybe_rayon::prelude::*, + rap::{BaseAirWithPublicValues, PartitionedBaseAir}, }; -use crate::bus::MemcpyBus; +use crate::{ + bus::MemcpyBus, MemcpyLoopChip, A1_REGISTER_PTR, A2_REGISTER_PTR, A3_REGISTER_PTR, + A4_REGISTER_PTR, +}; // Import constants from lib.rs use crate::{MEMCPY_LOOP_LIMB_BITS, MEMCPY_LOOP_NUM_LIMBS}; @@ -190,7 +207,8 @@ impl Air for MemcpyIterAir { // This actually receives when is_boundary = -1 self.memcpy_bus .send( - local.timestamp, + local.timestamp + + (local.is_boundary + AB::Expr::ONE) * AB::Expr::from_canonical_usize(4), local.dest, local.source, len, @@ -267,233 +285,568 @@ impl Air for MemcpyIterAir { } } +#[derive(derive_new::new, Clone, Copy)] +pub struct MemcpyIterExecutor {} + +#[derive(Copy, Clone, Debug)] +pub struct MemcpyIterMetadata { + num_rows: usize, +} + +impl MultiRowMetadata for MemcpyIterMetadata { + #[inline(always)] + fn get_num_rows(&self) -> usize { + self.num_rows + } +} + +pub type MemcpyIterLayout = MultiRowLayout; + #[repr(C)] #[derive(AlignedBytesBorrow, Debug)] -pub struct MemcpyIterRecord { - pub timestamp: u32, +pub struct MemcpyIterRecordHeader { + pub shift: u8, pub dest: u32, pub source: u32, pub len: u32, - pub shift: u8, - pub memory_read_data: Vec<[u8; MEMCPY_LOOP_NUM_LIMBS]>, - pub read_aux: Vec, - pub write_aux: Vec, + pub from_pc: u32, + pub from_timestamp: u32, + pub register_aux: [MemoryBaseAuxRecord; 3], } -pub struct MemcpyIterChip { - pub air: MemcpyIterAir, - pub records: Arc>>, - pub num_rows: AtomicU32, - pub pointer_max_bits: usize, - pub range_checker_chip: SharedVariableRangeCheckerChip, +// This is the part of the record that we keep `(len & !15) + (shift != 0)` times per instruction +#[repr(C)] +#[derive(AlignedBytesBorrow, Debug)] +pub struct MemcpyIterRecordVar { + pub data: [[u8; MEMCPY_LOOP_NUM_LIMBS]; 4], + pub read_aux: [MemoryReadAuxRecord; 4], + pub write_aux: [MemoryWriteBytesAuxRecord<4>; 4], +} + +/// **SAFETY**: the order of the fields in `MemcpyLoopRecordMut` and `MemcpyLoopRecordVar` +/// is important. +#[derive(Debug)] +pub struct MemcpyIterRecordMut<'a> { + pub inner: &'a mut MemcpyIterRecordHeader, + pub var: &'a mut [MemcpyIterRecordVar], } -impl MemcpyIterChip { - pub fn new( - memory_bridge: MemoryBridge, - range_bus: VariableRangeCheckerBus, - memcpy_bus: MemcpyBus, - pointer_max_bits: usize, - range_checker_chip: SharedVariableRangeCheckerChip, - ) -> Self { - Self { - air: MemcpyIterAir::new(memory_bridge, range_bus, memcpy_bus, pointer_max_bits), - records: Arc::new(Mutex::new(Vec::new())), - num_rows: AtomicU32::new(0), - pointer_max_bits, - range_checker_chip, +/// Custom borrowing that splits the buffer into a fixed `MemcpyLoopRecordHeader` header +/// followed by a slice of `MemcpyLoopRecordVar`'s of length `num_words` provided at runtime. +/// Uses `align_to_mut()` to make sure the slice is properly aligned to `MemcpyLoopRecordVar`. +/// Has debug assertions to make sure the above works as expected. +impl<'a> CustomBorrow<'a, MemcpyIterRecordMut<'a>, MemcpyIterLayout> for [u8] { + fn custom_borrow(&'a mut self, layout: MemcpyIterLayout) -> MemcpyIterRecordMut<'a> { + let (header_buf, rest) = + unsafe { self.split_at_mut_unchecked(size_of::()) }; + + let (_, vars, _) = unsafe { rest.align_to_mut::() }; + MemcpyIterRecordMut { + inner: header_buf.borrow_mut(), + var: &mut vars[..layout.metadata.num_rows], } } - pub fn bus(&self) -> MemcpyBus { - self.air.memcpy_bus + unsafe fn extract_layout(&self) -> MemcpyIterLayout { + let header: &MemcpyIterRecordHeader = self.borrow(); + MultiRowLayout::new(MemcpyIterMetadata { + num_rows: ((header.len - header.shift as u32) >> 4) as usize + 1, + }) } +} - pub fn clear(&self) { - self.records.lock().unwrap().clear(); - self.num_rows.store(0, std::sync::atomic::Ordering::Relaxed); +impl SizedRecord for MemcpyIterRecordMut<'_> { + fn size(layout: &MemcpyIterLayout) -> usize { + let mut total_len = size_of::(); + // Align the pointer to the alignment of `Rv32HintStoreVar` + total_len = total_len.next_multiple_of(align_of::()); + total_len += size_of::() * layout.metadata.num_rows; + total_len } - pub fn add_new_loop( - &self, - mem_helper: &MemoryAuxColsFactory, - timestamp: u32, - dest: u32, - source: u32, - len: u32, - shift: u8, - memory_read_data: Vec<[u8; MEMCPY_LOOP_NUM_LIMBS]>, - read_aux: Vec, - write_aux: Vec>, - ) { - let mut len = len; - // Update number of rows - self.num_rows - .fetch_add(len / 16 + 1, std::sync::atomic::Ordering::Relaxed); + fn alignment(_layout: &MemcpyIterLayout) -> usize { + align_of::() + } +} - let word_to_u16 = |data: u32| [data & 0xffff, data >> 16]; - let has_shift = (shift != 0) as usize; +#[derive(derive_new::new)] +pub struct MemcpyIterFiller { + pub pointer_max_bits: usize, + pub range_checker_chip: SharedVariableRangeCheckerChip, + pub memcpy_loop_chip: Arc, +} - // Range check len - loop { - let len_u16_limbs = word_to_u16(len); - if len > 15 { - self.range_checker_chip - .add_count(len_u16_limbs[0], 2 * MEMCPY_LOOP_LIMB_BITS); - self.range_checker_chip.add_count( - len_u16_limbs[1], - self.pointer_max_bits - 2 * MEMCPY_LOOP_LIMB_BITS, - ); +pub type MemcpyIterChip = VmChipWrapper; + +#[derive(AlignedBytesBorrow, Clone)] +#[repr(C)] +struct MemcpyIterPreCompute { + c: u8, +} + +impl PreflightExecutor for MemcpyIterExecutor +where + F: PrimeField32, + for<'buf> RA: RecordArena<'buf, MultiRowLayout, MemcpyIterRecordMut<'buf>>, +{ + fn get_opcode_name(&self, _: usize) -> String { + format!("{:?}", Rv32MemcpyOpcode::MEMCPY_LOOP) + } + + fn execute( + &self, + state: VmStateMut, + instruction: &Instruction, + ) -> Result<(), ExecutionError> { + let Instruction { opcode, c, .. } = instruction; + debug_assert_eq!(*opcode, Rv32MemcpyOpcode::MEMCPY_LOOP.global_opcode()); + let shift = c.as_canonical_u32() as u8; + debug_assert!([0, 1, 2, 3].contains(&shift)); + + let mut dest = read_rv32_register( + state.memory.data(), + if shift == 0 { + A3_REGISTER_PTR + } else { + A1_REGISTER_PTR + } as u32, + ); + let mut source = read_rv32_register( + state.memory.data(), + if shift == 0 { + A4_REGISTER_PTR } else { - self.range_checker_chip.add_count(len_u16_limbs[0], 4); - self.range_checker_chip.add_count(len_u16_limbs[1], 0); - } - if len < 16 { - break; - } + A3_REGISTER_PTR + } as u32, + ); + let mut len = read_rv32_register(state.memory.data(), A2_REGISTER_PTR as u32); + + let record = state.ctx.alloc(MultiRowLayout::new(MemcpyIterMetadata { + num_rows: ((len - shift as u32) >> 4) as usize + 1, + })); + + // Store the original values in the record + record.inner.shift = shift; + record.inner.from_pc = *state.pc; + record.inner.from_timestamp = state.memory.timestamp; + + if shift != 0 { + source -= 12; + record.var[0].data[3] = tracing_read( + state.memory, + RV32_MEMORY_AS, + source - 4, + &mut record.var[0].read_aux[3].prev_timestamp, + ); + }; + + let mut idx = 1; + while len - shift as u32 > 15 { + let writes_data: [[u8; MEMCPY_LOOP_NUM_LIMBS]; 4] = array::from_fn(|i| { + record.var[idx].data[i] = tracing_read( + state.memory, + RV32_MEMORY_AS, + source + 4 * i as u32, + &mut record.var[idx].read_aux[i].prev_timestamp, + ); + let write_data: [u8; MEMCPY_LOOP_NUM_LIMBS] = array::from_fn(|j| { + if j < 4 - shift as usize { + record.var[idx].data[i][j + shift as usize] + } else { + record.var[idx - 1].data[i][j - (4 - shift as usize)] + } + }); + write_data + }); + writes_data.iter().enumerate().for_each(|(i, write_data)| { + tracing_write( + state.memory, + RV32_MEMORY_AS, + dest + 4 * i as u32, + *write_data, + &mut record.var[idx].write_aux[i].prev_timestamp, + &mut record.var[idx].write_aux[i].prev_data, + ); + }); len -= 16; + source += 16; + dest += 16; + idx += 1; } - // Read data from memory - let mut row_read_aux = Vec::new(); - read_aux.iter().enumerate().for_each(|(i, aux)| { - let mut aux_cols = MemoryBaseAuxCols::::default(); - let read_timestamp = timestamp - + if i == 0 { - 0 - } else { - (i + (i - has_shift) / 4 * 4) as u32 - }; - mem_helper.fill(aux.prev_timestamp, read_timestamp, &mut aux_cols); - row_read_aux.push(MemoryExtendedAuxRecord::from_aux_cols(aux_cols)); - }); + // Handle the core loop + if shift != 0 { + source += 12; + } - // Write data to memory - let mut row_write_aux = Vec::new(); - write_aux.iter().enumerate().for_each(|(i, aux)| { - let mut aux_cols = MemoryBaseAuxCols::::default(); - mem_helper.fill( - aux.prev_timestamp, - (timestamp as usize + i + has_shift + (i / 4 + 1) * 4) as u32, - &mut aux_cols, - ); - row_write_aux.push(MemoryExtendedAuxRecord::from_aux_cols(aux_cols)); - }); + let mut dest_data = [0; 4]; + let mut source_data = [0; 4]; + let mut len_data = [0; 4]; - // Create record - let row = MemcpyIterRecord { - timestamp, - dest, - source, - len, - shift, - memory_read_data, - read_aux: row_read_aux, - write_aux: row_write_aux, - }; + tracing_write( + state.memory, + RV32_REGISTER_AS, + if shift == 0 { + A3_REGISTER_PTR + } else { + A1_REGISTER_PTR + } as u32, + dest.to_le_bytes(), + &mut record.inner.register_aux[0].prev_timestamp, + &mut dest_data, + ); - // Thread-safe push to rows vector - if let Ok(mut rows_guard) = self.records.lock() { - rows_guard.push(row); - } - } + tracing_write( + state.memory, + RV32_REGISTER_AS, + if shift == 0 { + A4_REGISTER_PTR + } else { + A3_REGISTER_PTR + } as u32, + source.to_le_bytes(), + &mut record.inner.register_aux[1].prev_timestamp, + &mut source_data, + ); - /// Generates trace - pub fn generate_trace(&self) -> RowMajorMatrix { - let mut rows = F::zero_vec( - (self.num_rows.load(std::sync::atomic::Ordering::Relaxed) as usize) - * NUM_MEMCPY_ITER_COLS, + tracing_write( + state.memory, + RV32_REGISTER_AS, + A2_REGISTER_PTR as u32, + len.to_le_bytes(), + &mut record.inner.register_aux[2].prev_timestamp, + &mut len_data, ); - let mut current_row = 0; - let word_to_u16 = |data: u32| [data & 0xffff, data >> 16].map(F::from_canonical_u32); - - for record in self.records.lock().unwrap().iter() { - let mut timestamp = record.timestamp; - let shift = [record.shift % 2, record.shift / 2].map(F::from_canonical_u8); - let has_shift = (record.shift != 0) as usize; - let mut prev_data = [F::ZERO; MEMCPY_LOOP_NUM_LIMBS]; - - for n in 0..(record.len / 16 + 1) as usize { - let row_start = current_row + n * NUM_MEMCPY_ITER_COLS; - let row = &mut rows[row_start..row_start + NUM_MEMCPY_ITER_COLS]; - let cols: &mut MemcpyIterCols = row.borrow_mut(); - cols.timestamp = F::from_canonical_u32(timestamp); - cols.dest = F::from_canonical_u32(record.dest + (n << 2) as u32); - cols.source = F::from_canonical_u32(record.source + (n << 2) as u32); - cols.len = word_to_u16(record.len - (n << 2) as u32); - cols.shift = shift; - cols.is_valid = F::ONE; - cols.is_valid_not_start = F::ONE; - if n == 0 { - cols.is_boundary = F::NEG_ONE; - if has_shift != 0 { - cols.data_4 = record.memory_read_data[0].map(F::from_canonical_u8); - prev_data = cols.data_4; - cols.read_aux[3].set_base(record.read_aux[0].to_aux_cols()); + + record.inner.dest = u32::from_le_bytes(dest_data); + record.inner.source = u32::from_le_bytes(source_data); + record.inner.len = u32::from_le_bytes(len_data); + + *state.pc = state.pc.wrapping_add(DEFAULT_PC_STEP); + + Ok(()) + } +} + +impl TraceFiller for MemcpyIterFiller { + fn fill_trace( + &self, + mem_helper: &MemoryAuxColsFactory, + trace: &mut RowMajorMatrix, + rows_used: usize, + ) { + if rows_used == 0 { + return; + } + + let width = trace.width; + debug_assert_eq!(width, NUM_MEMCPY_ITER_COLS); + let mut trace = &mut trace.values[..width * rows_used]; + let mut sizes = Vec::with_capacity(rows_used >> 1); + let mut chunks = Vec::with_capacity(rows_used >> 1); + + while !trace.is_empty() { + let record: &MemcpyIterRecordHeader = unsafe { get_record_from_slice(&mut trace, ()) }; + let num_rows = ((record.len - record.shift as u32) >> 4) as usize + 1; + let (chunk, rest) = trace.split_at_mut(width * num_rows as usize); + sizes.push(num_rows); + chunks.push(chunk); + trace = rest; + } + + chunks + .par_iter_mut() + .zip(sizes.par_iter()) + .for_each(|(chunk, &num_rows)| { + let record: MemcpyIterRecordMut = unsafe { + get_record_from_slice( + chunk, + MultiRowLayout::new(MemcpyIterMetadata { num_rows }), + ) + }; + + // 4 reads + 4 writes per iteration + (shift != 0) read for the loop header + let timestamp = record.inner.from_timestamp + + ((num_rows - 1) << 3) as u32 + + (record.inner.shift != 0) as u32; + let mut timestamp_delta: u32 = 0; + let mut get_timestamp = |is_access: bool| { + if is_access { + timestamp_delta += 1; } - } else { - cols.is_boundary = if n as u32 == record.len / 16 { - F::ONE - } else { - F::ZERO - }; - let mut data = [[F::ZERO; MEMCPY_LOOP_NUM_LIMBS]; 4]; - for i in 0..4 { - data[i] = record.memory_read_data[(n - 1) * 4 + i + has_shift] - .map(F::from_canonical_u8); - cols.read_aux[i] - .set_base(record.read_aux[(n - 1) * 4 + i + has_shift].to_aux_cols()); - cols.write_aux[i].set_base(record.write_aux[(n - 1) * 4 + i].to_aux_cols()); - let write_data: [F; MEMCPY_LOOP_NUM_LIMBS] = std::array::from_fn(|j| { - if j < 4 - record.shift as usize { - data[i][record.shift as usize + j] + timestamp - timestamp_delta + }; + + let mut dest = record.inner.dest + ((num_rows - 1) << 4) as u32; + let mut source = record.inner.source + ((num_rows - 1) << 4) as u32 + - 12 * (record.inner.shift != 0) as u32; + let mut len = + record.inner.len - ((num_rows - 1) << 4) as u32 - record.inner.shift as u32; + + // Fill memcpy loop record + self.memcpy_loop_chip.add_new_loop( + mem_helper, + record.inner.from_pc, + record.inner.from_timestamp, + record.inner.dest, + record.inner.source, + record.inner.len, + record.inner.shift, + record.inner.register_aux.clone(), + ); + + // We are going to fill row in the reverse order + chunk + .rchunks_exact_mut(width) + .zip(record.var.iter().enumerate().rev()) + .for_each(|(row, (idx, var))| { + let cols: &mut MemcpyIterCols = row.borrow_mut(); + + let is_end = (idx == 0); + let is_start = (idx == num_rows - 1); + + // Range check len + let len_u16_limbs = [len & 0xffff, len >> 16]; + if is_end { + self.range_checker_chip.add_count(len_u16_limbs[0], 4); + self.range_checker_chip.add_count(len_u16_limbs[1], 0); + } else { + self.range_checker_chip + .add_count(len_u16_limbs[0], 2 * MEMCPY_LOOP_LIMB_BITS); + self.range_checker_chip.add_count( + len_u16_limbs[1], + self.pointer_max_bits - 2 * MEMCPY_LOOP_LIMB_BITS, + ); + } + + // Fill memory read/write auxiliary columns + if is_start { + debug_assert_eq!(get_timestamp(false), record.inner.from_timestamp); + + cols.write_aux.iter_mut().rev().for_each(|aux_col| { + mem_helper.fill_zero(aux_col.as_mut()); + }); + + if record.inner.shift == 0 { + mem_helper.fill_zero(cols.read_aux[3].as_mut()); } else { - prev_data[j - (4 - record.shift as usize)] + mem_helper.fill( + var.read_aux[3].prev_timestamp, + timestamp, + cols.read_aux[3].as_mut(), + ); } - }); - cols.write_aux[i].set_prev_data(write_data); - prev_data = data[i]; - } - cols.data_1 = data[0]; - cols.data_2 = data[1]; - cols.data_3 = data[2]; - cols.data_4 = data[3]; - } - if n == 0 { - timestamp += (record.shift != 0) as u32; - } else { - timestamp += 8; - } - } - current_row += (record.len / 16 + 1) as usize * NUM_MEMCPY_ITER_COLS; - } - RowMajorMatrix::new(rows, NUM_MEMCPY_ITER_COLS) + cols.read_aux[..2].iter_mut().rev().for_each(|aux_col| { + mem_helper.fill_zero(aux_col.as_mut()); + }); + } else { + var.write_aux + .iter() + .rev() + .zip(cols.write_aux.iter_mut().rev()) + .for_each(|(aux_record, aux_col)| { + mem_helper.fill( + aux_record.prev_timestamp, + get_timestamp(true), + aux_col.as_mut(), + ); + aux_col.set_prev_data( + aux_record.prev_data.map(F::from_canonical_u8), + ); + }); + + var.read_aux + .iter() + .rev() + .zip(cols.read_aux.iter_mut().rev()) + .for_each(|(aux_record, aux_col)| { + mem_helper.fill( + aux_record.prev_timestamp, + get_timestamp(true), + aux_col.as_mut(), + ); + }); + } + + cols.data_4 = var.data[3].map(F::from_canonical_u8); + cols.data_3 = var.data[2].map(F::from_canonical_u8); + cols.data_2 = var.data[1].map(F::from_canonical_u8); + cols.data_1 = var.data[0].map(F::from_canonical_u8); + cols.is_boundary = F::from_canonical_u8(is_end as u8 - is_start as u8); + cols.is_valid_not_start = F::from_canonical_u8(1 - is_start as u8); + cols.is_valid = F::ONE; + cols.shift = [record.inner.shift % 2, record.inner.shift / 2] + .map(F::from_canonical_u8); + cols.len = [len & 0xffff, len >> 16].map(F::from_canonical_u32); + cols.source = F::from_canonical_u32(source); + cols.dest = F::from_canonical_u32(dest); + cols.timestamp = F::from_canonical_u32(get_timestamp(false)); + + dest -= 16; + source -= 16; + len += 16; + }); + }); } } -// We allow any `R` type so this can work with arbitrary record arenas. -impl Chip> for MemcpyIterChip -where - Val: PrimeField32, -{ - /// Generates trace and resets the internal counters all to 0. - fn generate_proving_ctx(&self, _: R) -> AirProvingContext> { - let trace = self.generate_trace::>(); - AirProvingContext::simple_no_pis(Arc::new(trace)) +impl Executor for MemcpyIterExecutor { + fn pre_compute_size(&self) -> usize { + size_of::() + } + + fn pre_compute( + &self, + pc: u32, + inst: &Instruction, + data: &mut [u8], + ) -> Result, StaticProgramError> + where + Ctx: ExecutionCtxTrait, + { + let data: &mut MemcpyIterPreCompute = data.borrow_mut(); + self.pre_compute_impl(pc, inst, data)?; + Ok(execute_e1_impl::<_, _>) } } -impl ChipUsageGetter for MemcpyIterChip { - fn air_name(&self) -> String { - get_air_name(&self.air) +impl MeteredExecutor for MemcpyIterExecutor { + fn metered_pre_compute_size(&self) -> usize { + size_of::>() } - fn constant_trace_height(&self) -> Option { - Some(self.num_rows.load(std::sync::atomic::Ordering::Relaxed) as usize) + + fn metered_pre_compute( + &self, + chip_idx: usize, + pc: u32, + inst: &Instruction, + data: &mut [u8], + ) -> Result, StaticProgramError> + where + Ctx: MeteredExecutionCtxTrait, + { + let data: &mut E2PreCompute = data.borrow_mut(); + data.chip_idx = chip_idx as u32; + self.pre_compute_impl(pc, inst, &mut data.data)?; + Ok(execute_e2_impl::<_, _>) } - fn current_trace_height(&self) -> usize { - self.num_rows.load(std::sync::atomic::Ordering::Relaxed) as usize +} + +#[inline(always)] +unsafe fn execute_e12_impl( + pre_compute: &MemcpyIterPreCompute, + vm_state: &mut VmExecState, +) -> u32 { + let shift = pre_compute.c; + let mut height = 1; + let (dest, source) = if shift == 0 { + ( + vm_state.vm_read::(RV32_REGISTER_AS, A3_REGISTER_PTR as u32), + vm_state.vm_read::(RV32_REGISTER_AS, A4_REGISTER_PTR as u32), + ) + } else { + ( + vm_state.vm_read::(RV32_REGISTER_AS, A1_REGISTER_PTR as u32), + vm_state.vm_read::(RV32_REGISTER_AS, A3_REGISTER_PTR as u32), + ) + }; + let len = vm_state.vm_read::(RV32_REGISTER_AS, A2_REGISTER_PTR as u32); + + let mut dest = u32::from_le_bytes(dest); + let mut source = u32::from_le_bytes(source); + let mut len = u32::from_le_bytes(len); + + let mut prev_data = if shift == 0 { + [0; 4] + } else { + source -= 12; + vm_state.vm_read::(RV32_MEMORY_AS, source - 4) + }; + + while len - shift as u32 > 15 { + for i in 0..4 { + let data = vm_state.vm_read::(RV32_MEMORY_AS, source + 4 * i); + let write_data: [u8; 4] = array::from_fn(|i| { + if i < 4 - shift as usize { + data[i + shift as usize] + } else { + prev_data[i - (4 - shift as usize)] + } + }); + vm_state.vm_write(RV32_MEMORY_AS, dest + 4 * i, &write_data); + prev_data = data; + } + len -= 16; + source += 16; + dest += 16; + height += 1; } - fn trace_width(&self) -> usize { - NUM_MEMCPY_ITER_COLS + + // Write the result back to memory + if shift == 0 { + vm_state.vm_write( + RV32_REGISTER_AS, + A3_REGISTER_PTR as u32, + &dest.to_le_bytes(), + ); + vm_state.vm_write( + RV32_REGISTER_AS, + A4_REGISTER_PTR as u32, + &source.to_le_bytes(), + ); + } else { + source += 12; + vm_state.vm_write( + RV32_REGISTER_AS, + A1_REGISTER_PTR as u32, + &dest.to_le_bytes(), + ); + vm_state.vm_write( + RV32_REGISTER_AS, + A3_REGISTER_PTR as u32, + &source.to_le_bytes(), + ); + }; + vm_state.vm_write(RV32_REGISTER_AS, A2_REGISTER_PTR as u32, &len.to_le_bytes()); + + vm_state.pc = vm_state.pc.wrapping_add(DEFAULT_PC_STEP); + vm_state.instret += 1; + height +} + +unsafe fn execute_e1_impl( + pre_compute: &[u8], + vm_state: &mut VmExecState, +) { + let pre_compute: &MemcpyIterPreCompute = pre_compute.borrow(); + execute_e12_impl::(pre_compute, vm_state); +} + +unsafe fn execute_e2_impl( + pre_compute: &[u8], + vm_state: &mut VmExecState, +) { + let pre_compute: &E2PreCompute = pre_compute.borrow(); + let height = execute_e12_impl::(&pre_compute.data, vm_state); + vm_state + .ctx + .on_height_change(pre_compute.chip_idx as usize, height); +} + +impl MemcpyIterExecutor { + fn pre_compute_impl( + &self, + pc: u32, + inst: &Instruction, + data: &mut MemcpyIterPreCompute, + ) -> Result<(), StaticProgramError> { + let Instruction { opcode, c, .. } = inst; + let c_u32 = c.as_canonical_u32(); + if ![0, 1, 2, 3].contains(&c_u32) { + return Err(StaticProgramError::InvalidInstruction(pc)); + } + *data = MemcpyIterPreCompute { c: c_u32 as u8 }; + assert_eq!(*opcode, Rv32MemcpyOpcode::MEMCPY_LOOP.global_opcode()); + Ok(()) } } From e3ba4f0851ad36fcca8c977c813b1c8c6bba0c5e Mon Sep 17 00:00:00 2001 From: Peyman Jabbarzade Date: Thu, 21 Aug 2025 12:56:06 -0400 Subject: [PATCH 06/14] fix: add memcpy executer to tests + lint --- Cargo.lock | 3 + benchmarks/execute/benches/execute.rs | 6 +- crates/sdk/src/config/global.rs | 4 +- extensions/algebra/circuit/Cargo.toml | 1 + .../algebra/circuit/src/extension/mod.rs | 9 ++ extensions/memcpy/circuit/src/extension.rs | 3 +- extensions/memcpy/circuit/src/iteration.rs | 10 +- extensions/memcpy/circuit/src/lib.rs | 4 +- .../memcpy/circuit/src/{core.rs => loops.rs} | 108 +++++++++++------- extensions/pairing/circuit/Cargo.toml | 1 + extensions/pairing/circuit/src/config.rs | 9 ++ guest-libs/pairing/Cargo.toml | 1 + guest-libs/pairing/tests/lib.rs | 6 +- guest-libs/sha2/tests/lib.rs | 2 +- 14 files changed, 110 insertions(+), 57 deletions(-) rename extensions/memcpy/circuit/src/{core.rs => loops.rs} (83%) diff --git a/Cargo.lock b/Cargo.lock index f6cacedf82..f4ffdb359a 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -5138,6 +5138,7 @@ dependencies = [ "openvm-cuda-builder", "openvm-cuda-common", "openvm-instructions", + "openvm-memcpy-circuit", "openvm-mod-circuit-builder", "openvm-pairing-guest", "openvm-rv32-adapters", @@ -5929,6 +5930,7 @@ dependencies = [ "openvm-ecc-sw-macros", "openvm-ecc-transpiler", "openvm-instructions", + "openvm-memcpy-circuit", "openvm-memcpy-transpiler", "openvm-pairing", "openvm-pairing-circuit", @@ -5963,6 +5965,7 @@ dependencies = [ "openvm-ecc-circuit", "openvm-ecc-guest", "openvm-instructions", + "openvm-memcpy-circuit", "openvm-mod-circuit-builder", "openvm-pairing-guest", "openvm-pairing-transpiler", diff --git a/benchmarks/execute/benches/execute.rs b/benchmarks/execute/benches/execute.rs index a3ba0c857c..dbbccf9861 100644 --- a/benchmarks/execute/benches/execute.rs +++ b/benchmarks/execute/benches/execute.rs @@ -193,7 +193,11 @@ where &config.keccak, inventory, )?; - VmProverExtension::::extend_prover(&MemcpyCpuProverExt, &config.memcpy, inventory)?; + VmProverExtension::::extend_prover( + &MemcpyCpuProverExt, + &config.memcpy, + inventory, + )?; VmProverExtension::::extend_prover(&Sha2CpuProverExt, &config.sha256, inventory)?; VmProverExtension::::extend_prover( &AlgebraCpuProverExt, diff --git a/crates/sdk/src/config/global.rs b/crates/sdk/src/config/global.rs index 81987285b9..17ca24302b 100644 --- a/crates/sdk/src/config/global.rs +++ b/crates/sdk/src/config/global.rs @@ -18,6 +18,8 @@ use openvm_ecc_circuit::{ use openvm_ecc_transpiler::EccTranspilerExtension; use openvm_keccak256_circuit::{Keccak256, Keccak256CpuProverExt, Keccak256Executor}; use openvm_keccak256_transpiler::Keccak256TranspilerExtension; +use openvm_memcpy_circuit::{Memcpy, MemcpyCpuProverExt, MemcpyExecutor}; +use openvm_memcpy_transpiler::MemcpyTranspilerExtension; use openvm_native_circuit::{ CastFExtension, CastFExtensionExecutor, Native, NativeCpuProverExt, NativeExecutor, }; @@ -33,8 +35,6 @@ use openvm_rv32im_circuit::{ use openvm_rv32im_transpiler::{ Rv32ITranspilerExtension, Rv32IoTranspilerExtension, Rv32MTranspilerExtension, }; -use openvm_memcpy_circuit::{Memcpy, MemcpyCpuProverExt, MemcpyExecutor}; -use openvm_memcpy_transpiler::MemcpyTranspilerExtension; use openvm_sha256_circuit::{Sha256, Sha256Executor, Sha2CpuProverExt}; use openvm_sha256_transpiler::Sha256TranspilerExtension; use openvm_stark_backend::{ diff --git a/extensions/algebra/circuit/Cargo.toml b/extensions/algebra/circuit/Cargo.toml index 5f00cb2ea7..309e76b3bb 100644 --- a/extensions/algebra/circuit/Cargo.toml +++ b/extensions/algebra/circuit/Cargo.toml @@ -17,6 +17,7 @@ openvm-stark-backend = { workspace = true } openvm-mod-circuit-builder = { workspace = true } openvm-stark-sdk = { workspace = true } openvm-rv32im-circuit = { workspace = true } +openvm-memcpy-circuit = { workspace = true } openvm-rv32-adapters = { workspace = true } openvm-algebra-transpiler = { workspace = true } openvm-cuda-backend = { workspace = true, optional = true } diff --git a/extensions/algebra/circuit/src/extension/mod.rs b/extensions/algebra/circuit/src/extension/mod.rs index 35e0aebfc4..2e0eb4ba8f 100644 --- a/extensions/algebra/circuit/src/extension/mod.rs +++ b/extensions/algebra/circuit/src/extension/mod.rs @@ -9,6 +9,7 @@ use openvm_circuit::{ system::{SystemChipInventory, SystemCpuBuilder, SystemExecutor}, }; use openvm_circuit_derive::VmConfig; +use openvm_memcpy_circuit::{Memcpy, MemcpyCpuProverExt, MemcpyExecutor}; use openvm_rv32im_circuit::{ Rv32I, Rv32IExecutor, Rv32ImCpuProverExt, Rv32Io, Rv32IoExecutor, Rv32M, Rv32MExecutor, }; @@ -59,6 +60,8 @@ pub struct Rv32ModularConfig { pub io: Rv32Io, #[extension] pub modular: ModularExtension, + #[extension] + pub memcpy: Memcpy, } impl InitFileGenerator for Rv32ModularConfig { @@ -78,6 +81,7 @@ impl Rv32ModularConfig { mul: Default::default(), io: Default::default(), modular: ModularExtension::new(moduli), + memcpy: Memcpy, } } } @@ -145,6 +149,11 @@ where &config.modular, inventory, )?; + VmProverExtension::::extend_prover( + &MemcpyCpuProverExt, + &config.memcpy, + inventory, + )?; Ok(chip_complex) } } diff --git a/extensions/memcpy/circuit/src/extension.rs b/extensions/memcpy/circuit/src/extension.rs index 184bdfc4bf..e4ae5b3444 100644 --- a/extensions/memcpy/circuit/src/extension.rs +++ b/extensions/memcpy/circuit/src/extension.rs @@ -12,6 +12,7 @@ use openvm_circuit::{ }; use openvm_circuit_derive::{AnyEnum, Executor, MeteredExecutor, PreflightExecutor}; use openvm_instructions::*; +use openvm_memcpy_transpiler::Rv32MemcpyOpcode; use openvm_stark_backend::{ config::{StarkGenericConfig, Val}, engine::StarkEngine, @@ -23,8 +24,6 @@ use strum::IntoEnumIterator; use crate::*; -use openvm_memcpy_transpiler::Rv32MemcpyOpcode; - // =================================== VM Extension Implementation ================================= #[derive(Clone, Copy, Debug, Default, Serialize, Deserialize)] pub struct Memcpy; diff --git a/extensions/memcpy/circuit/src/iteration.rs b/extensions/memcpy/circuit/src/iteration.rs index 1c9aca0458..bb0952b783 100644 --- a/extensions/memcpy/circuit/src/iteration.rs +++ b/extensions/memcpy/circuit/src/iteration.rs @@ -48,7 +48,6 @@ use crate::{ bus::MemcpyBus, MemcpyLoopChip, A1_REGISTER_PTR, A2_REGISTER_PTR, A3_REGISTER_PTR, A4_REGISTER_PTR, }; - // Import constants from lib.rs use crate::{MEMCPY_LOOP_LIMB_BITS, MEMCPY_LOOP_NUM_LIMBS}; @@ -177,7 +176,8 @@ impl Air for MemcpyIterAir { is_not_valid_when.assert_zero(local.is_boundary); is_not_valid_when.assert_zero(shift.clone()); - // if is_valid_not_start = 1, then len = prev_len - 16, source = prev_source + 16, dest = prev_dest + 16 + // if is_valid_not_start = 1, then len = prev_len - 16, source = prev_source + 16, + // and dest = prev_dest + 16 let mut is_valid_not_start_when = builder.when(local.is_valid_not_start); is_valid_not_start_when .assert_eq(local.len[0], prev.len[0] - AB::Expr::from_canonical_u32(16)); @@ -190,7 +190,7 @@ impl Air for MemcpyIterAir { builder .when(not::(prev.is_valid_not_start) - not::(prev.is_valid)) .assert_eq(local.timestamp, prev.timestamp + is_shift_non_zero.clone()); - // if prev.is_valid_not_start and local.is_valid_not_start, then timestamp = prev_timestamp + 8 + // if prev.is_valid_not_start and local.is_valid_not_start, then timestamp=prev_timestamp+8 // prev.is_valid_not_start is the opposite of previous condition builder .when( @@ -602,8 +602,8 @@ impl TraceFiller for MemcpyIterFiller { .for_each(|(row, (idx, var))| { let cols: &mut MemcpyIterCols = row.borrow_mut(); - let is_end = (idx == 0); - let is_start = (idx == num_rows - 1); + let is_end = idx == 0; + let is_start = idx == num_rows - 1; // Range check len let len_u16_limbs = [len & 0xffff, len >> 16]; diff --git a/extensions/memcpy/circuit/src/lib.rs b/extensions/memcpy/circuit/src/lib.rs index 28f660af34..205e10a1c1 100644 --- a/extensions/memcpy/circuit/src/lib.rs +++ b/extensions/memcpy/circuit/src/lib.rs @@ -1,11 +1,11 @@ mod bus; -mod core; mod extension; mod iteration; +mod loops; -pub use core::*; pub use extension::*; pub use iteration::*; +pub use loops::*; // ==== Do not change these constants! ==== pub const MEMCPY_LOOP_NUM_LIMBS: usize = 4; diff --git a/extensions/memcpy/circuit/src/core.rs b/extensions/memcpy/circuit/src/loops.rs similarity index 83% rename from extensions/memcpy/circuit/src/core.rs rename to extensions/memcpy/circuit/src/loops.rs index bf4180b7c7..5a6566955b 100644 --- a/extensions/memcpy/circuit/src/core.rs +++ b/extensions/memcpy/circuit/src/loops.rs @@ -1,19 +1,21 @@ use std::{ - array, borrow::{Borrow, BorrowMut}, - mem::{align_of, size_of}, + mem::size_of, sync::{Arc, Mutex}, }; use openvm_circuit::{ - arch::*, - system::{memory::{ - offline_checker::{ - MemoryBaseAuxCols, MemoryBaseAuxRecord, MemoryBridge, MemoryExtendedAuxRecord, MemoryReadAuxRecord, MemoryWriteAuxCols, MemoryWriteBytesAuxRecord + arch::{ExecutionBridge, ExecutionState}, + system::{ + memory::{ + offline_checker::{ + MemoryBaseAuxCols, MemoryBaseAuxRecord, MemoryBridge, MemoryExtendedAuxRecord, + MemoryWriteAuxCols, + }, + MemoryAddress, MemoryAuxColsFactory, }, - online::{GuestMemory, TracingMemory}, - MemoryAddress, MemoryAuxColsFactory, - }, SystemPort}, + SystemPort, + }, }; use openvm_circuit_primitives::{ utils::{not, or, select}, @@ -21,25 +23,20 @@ use openvm_circuit_primitives::{ AlignedBytesBorrow, }; use openvm_circuit_primitives_derive::AlignedBorrow; -use openvm_instructions::{ - instruction::Instruction, - program::DEFAULT_PC_STEP, - riscv::{RV32_MEMORY_AS, RV32_REGISTER_AS}, - LocalOpcode, -}; -use openvm_rv32im_circuit::adapters::{read_rv32_register, tracing_read, tracing_write}; +use openvm_instructions::riscv::RV32_MEMORY_AS; +use openvm_memcpy_transpiler::Rv32MemcpyOpcode; use openvm_stark_backend::{ - config::{StarkGenericConfig, Val}, interaction::InteractionBuilder, p3_air::{Air, AirBuilder, BaseAir}, p3_field::{Field, FieldAlgebra, PrimeField32}, p3_matrix::{dense::RowMajorMatrix, Matrix}, prover::{cpu::CpuBackend, types::AirProvingContext}, rap::{get_air_name, BaseAirWithPublicValues, PartitionedBaseAir}, Chip, ChipUsageGetter + config::{StarkGenericConfig, Val}, + interaction::InteractionBuilder, + p3_air::{Air, AirBuilder, BaseAir}, + p3_field::{Field, FieldAlgebra, PrimeField32}, + p3_matrix::{dense::RowMajorMatrix, Matrix}, + prover::{cpu::CpuBackend, types::AirProvingContext}, + rap::{get_air_name, BaseAirWithPublicValues, PartitionedBaseAir}, + Chip, ChipUsageGetter, }; -use crate::{bus::MemcpyBus, MemcpyIterChip}; -use openvm_circuit::arch::{ - execution_mode::{ExecutionCtxTrait, MeteredExecutionCtxTrait}, - get_record_from_slice, ExecuteFunc, ExecutionError, Executor, MeteredExecutor, RecordArena, - StaticProgramError, TraceFiller, VmExecState, -}; -use openvm_memcpy_transpiler::Rv32MemcpyOpcode; - +use crate::bus::MemcpyBus; // Import constants from lib.rs use crate::{ A1_REGISTER_PTR, A2_REGISTER_PTR, A3_REGISTER_PTR, A4_REGISTER_PTR, MEMCPY_LOOP_LIMB_BITS, @@ -94,7 +91,8 @@ impl Air for MemcpyLoopAir { let mut timestamp_delta: u32 = 0; let mut timestamp_pp = || { timestamp_delta += 1; - local.to_timestamp - AB::Expr::from_canonical_u32(MEMCPY_LOOP_NUM_WRITES + timestamp_delta - 1) + local.to_timestamp + - AB::Expr::from_canonical_u32(MEMCPY_LOOP_NUM_WRITES + timestamp_delta - 1) }; let from_le_bytes = |data: [AB::Var; 4]| { @@ -257,7 +255,8 @@ impl Air for MemcpyLoopAir { // Make sure the request and response match builder.assert_eq( - local.to_timestamp - (local.from_state.timestamp + AB::Expr::from_canonical_u32(timestamp_delta)), + local.to_timestamp + - (local.from_state.timestamp + AB::Expr::from_canonical_u32(timestamp_delta)), AB::Expr::TWO * (len.clone() - to_len) + is_shift_non_zero.clone(), ); @@ -301,7 +300,13 @@ impl MemcpyLoopChip { range_checker_chip: SharedVariableRangeCheckerChip, ) -> Self { Self { - air: MemcpyLoopAir::new(system_port.memory_bridge, ExecutionBridge::new(system_port.execution_bus, system_port.program_bus), range_bus, memcpy_bus, pointer_max_bits), + air: MemcpyLoopAir::new( + system_port.memory_bridge, + ExecutionBridge::new(system_port.execution_bus, system_port.program_bus), + range_bus, + memcpy_bus, + pointer_max_bits, + ), records: Arc::new(Mutex::new(Vec::new())), pointer_max_bits, range_checker_chip, @@ -327,18 +332,24 @@ impl MemcpyLoopChip { shift: u8, register_aux: [MemoryBaseAuxRecord; 3], ) { - let mut timestamp = from_timestamp + (((len - shift as u32) & !0x0f) >> 1) + (shift != 0) as u32; - let write_aux = register_aux.iter().map(|aux_record| { - let mut aux_col = MemoryBaseAuxCols::default(); - mem_helper.fill(aux_record.prev_timestamp, timestamp, &mut aux_col); - timestamp += 1; - MemoryExtendedAuxRecord::from_aux_cols(aux_col) - }).collect::>().try_into().unwrap(); + let mut timestamp = + from_timestamp + (((len - shift as u32) & !0x0f) >> 1) + (shift != 0) as u32; + let write_aux = register_aux + .iter() + .map(|aux_record| { + let mut aux_col = MemoryBaseAuxCols::default(); + mem_helper.fill(aux_record.prev_timestamp, timestamp, &mut aux_col); + timestamp += 1; + MemoryExtendedAuxRecord::from_aux_cols(aux_col) + }) + .collect::>() + .try_into() + .unwrap(); let num_copies = (len - shift as u32) & !0x0f; let to_dest = dest + num_copies; let to_source = source + num_copies; - + let word_to_u16 = |data: u32| [data & 0xffff, data >> 16]; let range_check_data = [ (word_to_u16(len), false), @@ -379,7 +390,7 @@ impl MemcpyLoopChip { /// Generates trace pub fn generate_trace(&self) -> RowMajorMatrix { - let mut rows = F::zero_vec((self.records.lock().unwrap().len() as usize) * NUM_MEMCPY_LOOP_COLS); + let mut rows = F::zero_vec(self.records.lock().unwrap().len() * NUM_MEMCPY_LOOP_COLS); for (i, record) in self.records.lock().unwrap().iter().enumerate() { let row = &mut rows[i * NUM_MEMCPY_LOOP_COLS..(i + 1) * NUM_MEMCPY_LOOP_COLS]; @@ -394,11 +405,22 @@ impl MemcpyLoopChip { cols.dest = record.dest.to_le_bytes().map(F::from_canonical_u8); cols.source = record.source.to_le_bytes().map(F::from_canonical_u8); cols.len = record.len.to_le_bytes().map(F::from_canonical_u8); - cols.shift = [F::from_canonical_u8(shift % 2), F::from_canonical_u8(shift / 2)]; + cols.shift = [ + F::from_canonical_u8(shift % 2), + F::from_canonical_u8(shift / 2), + ]; cols.is_valid = F::ONE; - // We have MEMCPY_LOOP_NUM_WRITES writes in the loop, (num_copies / 4 + shift != 0) reads and (num_copies / 4) writes in iterations - cols.to_timestamp = F::from_canonical_u32(record.from_timestamp + MEMCPY_LOOP_NUM_WRITES + (num_copies >> 1) + (shift != 0) as u32); - cols.to_dest = (record.dest + num_copies).to_le_bytes().map(F::from_canonical_u8); + // We have MEMCPY_LOOP_NUM_WRITES writes in the loop, (num_copies / 4) writes + // and (num_copies / 4 + shift != 0) reads in iterations + cols.to_timestamp = F::from_canonical_u32( + record.from_timestamp + + MEMCPY_LOOP_NUM_WRITES + + (num_copies >> 1) + + (shift != 0) as u32, + ); + cols.to_dest = (record.dest + num_copies) + .to_le_bytes() + .map(F::from_canonical_u8); cols.to_source = to_source.to_le_bytes().map(F::from_canonical_u8); cols.to_len = F::from_canonical_u32(record.len - num_copies); cols.write_aux = record.write_aux.clone().map(|aux| aux.to_aux_cols()); @@ -426,10 +448,10 @@ impl ChipUsageGetter for MemcpyLoopChip { get_air_name(&self.air) } fn constant_trace_height(&self) -> Option { - Some(self.records.lock().unwrap().len() as usize) + Some(self.records.lock().unwrap().len()) } fn current_trace_height(&self) -> usize { - self.records.lock().unwrap().len() as usize + self.records.lock().unwrap().len() } fn trace_width(&self) -> usize { NUM_MEMCPY_LOOP_COLS diff --git a/extensions/pairing/circuit/Cargo.toml b/extensions/pairing/circuit/Cargo.toml index 4f4710fd2d..ef24579bf0 100644 --- a/extensions/pairing/circuit/Cargo.toml +++ b/extensions/pairing/circuit/Cargo.toml @@ -23,6 +23,7 @@ openvm-stark-backend = { workspace = true } openvm-rv32im-circuit = { workspace = true } openvm-algebra-circuit = { workspace = true } openvm-ecc-circuit = { workspace = true } +openvm-memcpy-circuit = { workspace = true } openvm-pairing-transpiler = { workspace = true } openvm-cuda-backend = { workspace = true, optional = true } openvm-stark-sdk = { workspace = true, optional = true } diff --git a/extensions/pairing/circuit/src/config.rs b/extensions/pairing/circuit/src/config.rs index 20ea07186a..ce6756651c 100644 --- a/extensions/pairing/circuit/src/config.rs +++ b/extensions/pairing/circuit/src/config.rs @@ -13,6 +13,7 @@ use openvm_circuit::{ }; use openvm_circuit_derive::VmConfig; use openvm_ecc_circuit::{EccCpuProverExt, WeierstrassExtension, WeierstrassExtensionExecutor}; +use openvm_memcpy_circuit::{Memcpy, MemcpyCpuProverExt, MemcpyExecutor}; use openvm_stark_backend::{ config::{StarkGenericConfig, Val}, engine::StarkEngine, @@ -33,6 +34,8 @@ pub struct Rv32PairingConfig { pub weierstrass: WeierstrassExtension, #[extension(generics = true)] pub pairing: PairingExtension, + #[extension] + pub memcpy: Memcpy, } impl Rv32PairingConfig { @@ -55,6 +58,7 @@ impl Rv32PairingConfig { curves.iter().map(|c| c.curve_config()).collect(), ), pairing: PairingExtension::new(curves), + memcpy: Memcpy, } } } @@ -101,6 +105,11 @@ where inventory, )?; VmProverExtension::::extend_prover(&PairingProverExt, &config.pairing, inventory)?; + VmProverExtension::::extend_prover( + &MemcpyCpuProverExt, + &config.memcpy, + inventory, + )?; Ok(chip_complex) } } diff --git a/guest-libs/pairing/Cargo.toml b/guest-libs/pairing/Cargo.toml index 8ac04ebef4..36ebeaa953 100644 --- a/guest-libs/pairing/Cargo.toml +++ b/guest-libs/pairing/Cargo.toml @@ -47,6 +47,7 @@ openvm-ecc-guest.workspace = true openvm-ecc-transpiler.workspace = true openvm-rv32im-transpiler.workspace = true openvm-memcpy-transpiler.workspace = true +openvm-memcpy-circuit.workspace = true openvm = { workspace = true } openvm-toolchain-tests = { workspace = true } eyre.workspace = true diff --git a/guest-libs/pairing/tests/lib.rs b/guest-libs/pairing/tests/lib.rs index cd79b97545..f8e49c4ead 100644 --- a/guest-libs/pairing/tests/lib.rs +++ b/guest-libs/pairing/tests/lib.rs @@ -24,6 +24,8 @@ mod bn254 { }; use openvm_ecc_transpiler::EccTranspilerExtension; use openvm_instructions::exe::VmExe; + use openvm_memcpy_circuit::Memcpy; + use openvm_memcpy_transpiler::MemcpyTranspilerExtension; use openvm_pairing_circuit::{ PairingCurve, PairingExtension, Rv32PairingBuilder, Rv32PairingConfig, }; @@ -36,7 +38,6 @@ mod bn254 { use openvm_rv32im_transpiler::{ Rv32ITranspilerExtension, Rv32IoTranspilerExtension, Rv32MTranspilerExtension, }; - use openvm_memcpy_transpiler::MemcpyTranspilerExtension; use openvm_stark_sdk::{ config::FriParameters, openvm_stark_backend::p3_field::FieldAlgebra, p3_baby_bear::BabyBear, }; @@ -59,6 +60,7 @@ mod bn254 { fp2: Fp2Extension::new(primes_with_names), weierstrass: WeierstrassExtension::new(vec![]), pairing: PairingExtension::new(vec![PairingCurve::Bn254]), + memcpy: Memcpy, } } @@ -499,6 +501,7 @@ mod bls12_381 { AffinePoint, }; use openvm_ecc_transpiler::EccTranspilerExtension; + use openvm_memcpy_circuit::Memcpy; use openvm_memcpy_transpiler::MemcpyTranspilerExtension; use openvm_pairing_circuit::{ PairingCurve, PairingExtension, Rv32PairingBuilder, Rv32PairingConfig, @@ -537,6 +540,7 @@ mod bls12_381 { fp2: Fp2Extension::new(primes_with_names), weierstrass: WeierstrassExtension::new(vec![]), pairing: PairingExtension::new(vec![PairingCurve::Bls12_381]), + memcpy: Memcpy, } } diff --git a/guest-libs/sha2/tests/lib.rs b/guest-libs/sha2/tests/lib.rs index 9d449627dd..0708e84af6 100644 --- a/guest-libs/sha2/tests/lib.rs +++ b/guest-libs/sha2/tests/lib.rs @@ -3,13 +3,13 @@ mod tests { use eyre::Result; use openvm_circuit::utils::air_test; use openvm_instructions::exe::VmExe; + use openvm_memcpy_transpiler::MemcpyTranspilerExtension; use openvm_rv32im_transpiler::{ Rv32ITranspilerExtension, Rv32IoTranspilerExtension, Rv32MTranspilerExtension, }; use openvm_sha256_circuit::{Sha256Rv32Builder, Sha256Rv32Config}; use openvm_sha256_transpiler::Sha256TranspilerExtension; use openvm_stark_sdk::p3_baby_bear::BabyBear; - use openvm_memcpy_transpiler::MemcpyTranspilerExtension; use openvm_toolchain_tests::{build_example_program_at_path, get_programs_dir}; use openvm_transpiler::{transpiler::Transpiler, FromElf}; From f2a29595b0d930a20967060cc0a904605a8ed756 Mon Sep 17 00:00:00 2001 From: Peyman Jabbarzade Date: Fri, 22 Aug 2025 13:51:48 -0400 Subject: [PATCH 07/14] fix some bugs --- .../src/system/memory/offline_checker/mod.rs | 11 +- extensions/memcpy/circuit/src/bus.rs | 16 +- extensions/memcpy/circuit/src/extension.rs | 8 +- extensions/memcpy/circuit/src/iteration.rs | 101 ++++++++----- extensions/memcpy/circuit/src/loops.rs | 140 +++++++++--------- extensions/memcpy/transpiler/src/lib.rs | 4 +- 6 files changed, 154 insertions(+), 126 deletions(-) diff --git a/crates/vm/src/system/memory/offline_checker/mod.rs b/crates/vm/src/system/memory/offline_checker/mod.rs index 010a74bb31..5ef5745d11 100644 --- a/crates/vm/src/system/memory/offline_checker/mod.rs +++ b/crates/vm/src/system/memory/offline_checker/mod.rs @@ -32,13 +32,10 @@ impl MemoryExtendedAuxRecord { } } - pub fn to_aux_cols(&self) -> MemoryBaseAuxCols { - MemoryBaseAuxCols { - prev_timestamp: F::from_canonical_u32(self.prev_timestamp), - timestamp_lt_aux: LessThanAuxCols { - lower_decomp: self.timestamp_lt_aux.map(|x| F::from_canonical_u32(x)), - }, - } + pub fn to_aux_cols(&self, aux_cols: &mut MemoryBaseAuxCols) { + aux_cols.prev_timestamp = F::from_canonical_u32(self.prev_timestamp); + aux_cols.timestamp_lt_aux.lower_decomp = + self.timestamp_lt_aux.map(|x| F::from_canonical_u32(x)); } } diff --git a/extensions/memcpy/circuit/src/bus.rs b/extensions/memcpy/circuit/src/bus.rs index fcb77932c2..83393f1aa4 100644 --- a/extensions/memcpy/circuit/src/bus.rs +++ b/extensions/memcpy/circuit/src/bus.rs @@ -29,10 +29,10 @@ impl MemcpyBus { timestamp: impl Into, dest: impl Into, source: impl Into, - n: impl Into, + len: impl Into, shift: impl Into, ) -> MemcpyBusInteraction { - self.push(true, timestamp, dest, source, n, shift) + self.push(true, timestamp, dest, source, len, shift) } pub fn receive( @@ -40,10 +40,10 @@ impl MemcpyBus { timestamp: impl Into, dest: impl Into, source: impl Into, - n: impl Into, + len: impl Into, shift: impl Into, ) -> MemcpyBusInteraction { - self.push(false, timestamp, dest, source, n, shift) + self.push(false, timestamp, dest, source, len, shift) } fn push( @@ -52,7 +52,7 @@ impl MemcpyBus { timestamp: impl Into, dest: impl Into, source: impl Into, - n: impl Into, + len: impl Into, shift: impl Into, ) -> MemcpyBusInteraction { MemcpyBusInteraction { @@ -61,7 +61,7 @@ impl MemcpyBus { timestamp: timestamp.into(), dest: dest.into(), source: source.into(), - n: n.into(), + len: len.into(), shift: shift.into(), } } @@ -74,7 +74,7 @@ pub struct MemcpyBusInteraction { pub timestamp: T, pub dest: T, pub source: T, - pub n: T, + pub len: T, pub shift: T, } @@ -87,7 +87,7 @@ impl MemcpyBusInteraction { .chain(iter::once(self.timestamp)) .chain(iter::once(self.dest)) .chain(iter::once(self.source)) - .chain(iter::once(self.n)) + .chain(iter::once(self.len)) .chain(iter::once(self.shift)); if self.is_send { diff --git a/extensions/memcpy/circuit/src/extension.rs b/extensions/memcpy/circuit/src/extension.rs index e4ae5b3444..5bd3db9b8b 100644 --- a/extensions/memcpy/circuit/src/extension.rs +++ b/extensions/memcpy/circuit/src/extension.rs @@ -40,7 +40,7 @@ impl VmExecutionExtension for Memcpy { &self, inventory: &mut ExecutorInventoryBuilder, ) -> Result<(), ExecutorInventoryError> { - let memcpy_iter = MemcpyIterExecutor::new(); + let memcpy_iter = MemcpyIterExecutor::new(Rv32MemcpyOpcode::CLASS_OFFSET); inventory.add_executor( memcpy_iter, @@ -71,6 +71,7 @@ impl VmCircuitExtension for Memcpy { range_bus, memcpy_bus, pointer_max_bits, + Rv32MemcpyOpcode::CLASS_OFFSET, ); inventory.add_air(memcpy_loop); @@ -113,6 +114,7 @@ where inventory.airs().system().port(), range_bus, memcpy_bus, + Rv32MemcpyOpcode::CLASS_OFFSET, pointer_max_bits, range_checker.clone(), )); @@ -127,11 +129,11 @@ where ); // Add MemcpyLoop chip inventory.next_air::()?; - inventory.add_executor_chip(memcpy_loop_chip); + inventory.add_periphery_chip(memcpy_loop_chip); // Add MemcpyIter chip inventory.next_air::()?; - inventory.add_periphery_chip(memcpy_iter_chip); + inventory.add_executor_chip(memcpy_iter_chip); Ok(()) } diff --git a/extensions/memcpy/circuit/src/iteration.rs b/extensions/memcpy/circuit/src/iteration.rs index bb0952b783..44630b8ff4 100644 --- a/extensions/memcpy/circuit/src/iteration.rs +++ b/extensions/memcpy/circuit/src/iteration.rs @@ -111,15 +111,17 @@ impl Air for MemcpyIterAir { let is_shift_two = and::(not::(local.shift[0]), local.shift[1]); let is_shift_three = and::(local.shift[0], local.shift[1]); - // TODO:since if is_valid = 0, then is_boundary = 0, we can reduce the degree of the following expressions by removing the is_valid term let is_end = (local.is_boundary + AB::Expr::ONE) * local.is_boundary * (AB::F::TWO).inverse(); let is_not_start = (local.is_boundary + AB::Expr::ONE) * (AB::Expr::TWO - local.is_boundary) * (AB::F::TWO).inverse(); + let prev_is_not_end = not::( + (prev.is_boundary + AB::Expr::ONE) * prev.is_boundary * (AB::F::TWO).inverse(), + ); let len = local.len[0] - + local.len[1] * AB::F::from_canonical_u32(1 << (2 * MEMCPY_LOOP_LIMB_BITS)); + + local.len[1] * AB::Expr::from_canonical_u32(1 << (2 * MEMCPY_LOOP_LIMB_BITS)); // write_data = // (local.data_1[shift..4], prev.data_4[0..shift]), @@ -136,7 +138,7 @@ impl Air for MemcpyIterAir { let write_data = write_data_pairs .iter() .map(|(prev_data, next_data)| { - array::from_fn(|i| { + array::from_fn::<_, MEMCPY_LOOP_NUM_LIMBS, _>(|i| { is_shift_zero.clone() * (next_data[i]) + is_shift_one.clone() * (if i < 3 { @@ -161,35 +163,44 @@ impl Air for MemcpyIterAir { .collect::>(); builder.assert_bool(local.is_valid); - for i in 0..2 { - builder.assert_bool(local.shift[i]); - } + local.shift.iter().for_each(|x| builder.assert_bool(*x)); builder.assert_bool(local.is_valid_not_start); // is_boundary is either -1, 0 or 1 builder.assert_tern(local.is_boundary + AB::Expr::ONE); // is_valid_not_start = is_valid and is_not_start: - builder.assert_eq(local.is_valid_not_start, local.is_valid * is_not_start); + builder.assert_eq( + local.is_valid_not_start, + and::(local.is_valid, is_not_start), + ); - // if is_valid = 0, then is_boundary = 0, shift = 0 + // if !is_valid, then is_boundary = 0, shift = 0 (we will use this assumption later) let mut is_not_valid_when = builder.when(not::(local.is_valid)); is_not_valid_when.assert_zero(local.is_boundary); is_not_valid_when.assert_zero(shift.clone()); - // if is_valid_not_start = 1, then len = prev_len - 16, source = prev_source + 16, - // and dest = prev_dest + 16 + // if is_valid_not_start, then len = prev_len - 16, source = prev_source + 16, + // and dest = prev_dest + 16, shift = prev_shift let mut is_valid_not_start_when = builder.when(local.is_valid_not_start); is_valid_not_start_when .assert_eq(local.len[0], prev.len[0] - AB::Expr::from_canonical_u32(16)); is_valid_not_start_when .assert_eq(local.source, prev.source + AB::Expr::from_canonical_u32(16)); is_valid_not_start_when.assert_eq(local.dest, prev.dest + AB::Expr::from_canonical_u32(16)); + is_valid_not_start_when.assert_eq(local.shift[0], prev.shift[0]); + is_valid_not_start_when.assert_eq(local.shift[1], prev.shift[1]); + + // make sure if previous row is valid and not end, then local.is_valid = 1 + builder + .when(prev_is_not_end - not::(prev.is_valid)) + .assert_one(local.is_valid); // if prev.is_valid_start, then timestamp = prev_timestamp + is_shift_non_zero // since is_shift_non_zero degree is 2, we need to keep the degree of the condition to 1 builder .when(not::(prev.is_valid_not_start) - not::(prev.is_valid)) .assert_eq(local.timestamp, prev.timestamp + is_shift_non_zero.clone()); + // if prev.is_valid_not_start and local.is_valid_not_start, then timestamp=prev_timestamp+8 // prev.is_valid_not_start is the opposite of previous condition builder @@ -239,7 +250,7 @@ impl Air for MemcpyIterAir { .read( MemoryAddress::new( AB::Expr::from_canonical_u32(RV32_MEMORY_AS), - local.source + AB::Expr::from_canonical_usize(idx * 4), + local.source - AB::Expr::from_canonical_usize(16 - idx * 4), ), *data, timestamp_pp(), @@ -254,7 +265,7 @@ impl Air for MemcpyIterAir { .write( MemoryAddress::new( AB::Expr::from_canonical_u32(RV32_MEMORY_AS), - local.dest + AB::Expr::from_canonical_usize(idx * 4), + local.dest - AB::Expr::from_canonical_usize(16 - idx * 4), ), data.clone(), timestamp_pp(), @@ -286,7 +297,9 @@ impl Air for MemcpyIterAir { } #[derive(derive_new::new, Clone, Copy)] -pub struct MemcpyIterExecutor {} +pub struct MemcpyIterExecutor { + pub offset: usize, +} #[derive(Copy, Clone, Debug)] pub struct MemcpyIterMetadata { @@ -378,12 +391,6 @@ pub struct MemcpyIterFiller { pub type MemcpyIterChip = VmChipWrapper; -#[derive(AlignedBytesBorrow, Clone)] -#[repr(C)] -struct MemcpyIterPreCompute { - c: u8, -} - impl PreflightExecutor for MemcpyIterExecutor where F: PrimeField32, @@ -452,6 +459,8 @@ where let write_data: [u8; MEMCPY_LOOP_NUM_LIMBS] = array::from_fn(|j| { if j < 4 - shift as usize { record.var[idx].data[i][j + shift as usize] + } else if i > 0 { + record.var[idx].data[i - 1][j - (4 - shift as usize)] } else { record.var[idx - 1].data[i][j - (4 - shift as usize)] } @@ -565,6 +574,18 @@ impl TraceFiller for MemcpyIterFiller { ) }; + // Fill memcpy loop record + self.memcpy_loop_chip.add_new_loop( + mem_helper, + record.inner.from_pc, + record.inner.from_timestamp, + record.inner.dest, + record.inner.source, + record.inner.len, + record.inner.shift, + record.inner.register_aux.clone(), + ); + // 4 reads + 4 writes per iteration + (shift != 0) read for the loop header let timestamp = record.inner.from_timestamp + ((num_rows - 1) << 3) as u32 @@ -583,18 +604,6 @@ impl TraceFiller for MemcpyIterFiller { let mut len = record.inner.len - ((num_rows - 1) << 4) as u32 - record.inner.shift as u32; - // Fill memcpy loop record - self.memcpy_loop_chip.add_new_loop( - mem_helper, - record.inner.from_pc, - record.inner.from_timestamp, - record.inner.dest, - record.inner.source, - record.inner.len, - record.inner.shift, - record.inner.register_aux.clone(), - ); - // We are going to fill row in the reverse order chunk .rchunks_exact_mut(width) @@ -602,8 +611,8 @@ impl TraceFiller for MemcpyIterFiller { .for_each(|(row, (idx, var))| { let cols: &mut MemcpyIterCols = row.borrow_mut(); - let is_end = idx == 0; - let is_start = idx == num_rows - 1; + let is_start = idx == 0; + let is_end = idx == num_rows - 1; // Range check len let len_u16_limbs = [len & 0xffff, len >> 16]; @@ -621,8 +630,6 @@ impl TraceFiller for MemcpyIterFiller { // Fill memory read/write auxiliary columns if is_start { - debug_assert_eq!(get_timestamp(false), record.inner.from_timestamp); - cols.write_aux.iter_mut().rev().for_each(|aux_col| { mem_helper.fill_zero(aux_col.as_mut()); }); @@ -632,18 +639,20 @@ impl TraceFiller for MemcpyIterFiller { } else { mem_helper.fill( var.read_aux[3].prev_timestamp, - timestamp, + get_timestamp(true), cols.read_aux[3].as_mut(), ); } cols.read_aux[..2].iter_mut().rev().for_each(|aux_col| { mem_helper.fill_zero(aux_col.as_mut()); }); + + debug_assert_eq!(get_timestamp(false), record.inner.from_timestamp); } else { var.write_aux .iter() + .zip(cols.write_aux.iter_mut()) .rev() - .zip(cols.write_aux.iter_mut().rev()) .for_each(|(aux_record, aux_col)| { mem_helper.fill( aux_record.prev_timestamp, @@ -657,8 +666,8 @@ impl TraceFiller for MemcpyIterFiller { var.read_aux .iter() + .zip(cols.read_aux.iter_mut()) .rev() - .zip(cols.read_aux.iter_mut().rev()) .for_each(|(aux_record, aux_col)| { mem_helper.fill( aux_record.prev_timestamp, @@ -672,10 +681,16 @@ impl TraceFiller for MemcpyIterFiller { cols.data_3 = var.data[2].map(F::from_canonical_u8); cols.data_2 = var.data[1].map(F::from_canonical_u8); cols.data_1 = var.data[0].map(F::from_canonical_u8); - cols.is_boundary = F::from_canonical_u8(is_end as u8 - is_start as u8); + cols.is_boundary = if is_end { + F::ONE + } else if is_start { + F::NEG_ONE + } else { + F::ZERO + }; cols.is_valid_not_start = F::from_canonical_u8(1 - is_start as u8); cols.is_valid = F::ONE; - cols.shift = [record.inner.shift % 2, record.inner.shift / 2] + cols.shift = [record.inner.shift & 1, record.inner.shift >> 1] .map(F::from_canonical_u8); cols.len = [len & 0xffff, len >> 16].map(F::from_canonical_u32); cols.source = F::from_canonical_u32(source); @@ -690,6 +705,12 @@ impl TraceFiller for MemcpyIterFiller { } } +#[derive(AlignedBytesBorrow, Clone)] +#[repr(C)] +struct MemcpyIterPreCompute { + c: u8, +} + impl Executor for MemcpyIterExecutor { fn pre_compute_size(&self) -> usize { size_of::() diff --git a/extensions/memcpy/circuit/src/loops.rs b/extensions/memcpy/circuit/src/loops.rs index 5a6566955b..368abc237b 100644 --- a/extensions/memcpy/circuit/src/loops.rs +++ b/extensions/memcpy/circuit/src/loops.rs @@ -16,6 +16,7 @@ use openvm_circuit::{ }, SystemPort, }, + utils::next_power_of_two_or_zero, }; use openvm_circuit_primitives::{ utils::{not, or, select}, @@ -23,7 +24,7 @@ use openvm_circuit_primitives::{ AlignedBytesBorrow, }; use openvm_circuit_primitives_derive::AlignedBorrow; -use openvm_instructions::riscv::RV32_MEMORY_AS; +use openvm_instructions::riscv::RV32_REGISTER_AS; use openvm_memcpy_transpiler::Rv32MemcpyOpcode; use openvm_stark_backend::{ config::{StarkGenericConfig, Val}, @@ -71,6 +72,7 @@ pub struct MemcpyLoopAir { pub range_bus: VariableRangeCheckerBus, pub memcpy_bus: MemcpyBus, pub pointer_max_bits: usize, + pub offset: usize, } impl BaseAir for MemcpyLoopAir { @@ -92,12 +94,12 @@ impl Air for MemcpyLoopAir { let mut timestamp_pp = || { timestamp_delta += 1; local.to_timestamp - - AB::Expr::from_canonical_u32(MEMCPY_LOOP_NUM_WRITES + timestamp_delta - 1) + - AB::Expr::from_canonical_u32(MEMCPY_LOOP_NUM_WRITES - (timestamp_delta - 1)) }; let from_le_bytes = |data: [AB::Var; 4]| { - data.iter().fold(AB::Expr::ZERO, |acc, x| { - acc * AB::Expr::from_canonical_u32(256) + *x + data.iter().rev().fold(AB::Expr::ZERO, |acc, x| { + acc * AB::Expr::from_canonical_u32(1 << MEMCPY_LOOP_LIMB_BITS) + *x }) }; @@ -108,8 +110,9 @@ impl Air for MemcpyLoopAir { ] }; - let shift = local.shift[1] * AB::Expr::from_canonical_u32(2) + local.shift[0]; + let shift = local.shift[1] * AB::Expr::TWO + local.shift[0]; let is_shift_non_zero = or::(local.shift[0], local.shift[1]); + let is_shift_zero = not::(is_shift_non_zero.clone()); let dest = from_le_bytes(local.dest); let source = from_le_bytes(local.source); let len = from_le_bytes(local.len); @@ -118,52 +121,52 @@ impl Air for MemcpyLoopAir { let to_len = local.to_len; builder.assert_bool(local.is_valid); - for i in 0..2 { - builder.assert_bool(local.shift[i]); - } + local.shift.iter().for_each(|x| builder.assert_bool(*x)); builder.assert_bool(local.source_minus_twelve_carry); builder.assert_bool(local.to_source_minus_twelve_carry); - let mut shift_zero_when = builder.when(not::(is_shift_non_zero.clone())); + let mut shift_zero_when = builder.when(is_shift_zero.clone()); shift_zero_when.assert_zero(local.source_minus_twelve_carry); shift_zero_when.assert_zero(local.to_source_minus_twelve_carry); // Write source and destination to registers let write_data = [ - (local.dest, local.to_dest, A1_REGISTER_PTR, A3_REGISTER_PTR), + (local.dest, local.to_dest, A3_REGISTER_PTR, A1_REGISTER_PTR), ( local.source, local.to_source, - A2_REGISTER_PTR, A4_REGISTER_PTR, + A3_REGISTER_PTR, ), ]; - write_data - .iter() - .enumerate() - .for_each(|(idx, (dest, to_dest, ptr, zero_shift_ptr))| { + write_data.iter().enumerate().for_each( + |(idx, (prev_data, new_data, zero_shift_ptr, non_zero_shift_ptr))| { let write_ptr = select::( - is_shift_non_zero.clone(), - AB::Expr::from_canonical_usize(*ptr), + is_shift_zero.clone(), AB::Expr::from_canonical_usize(*zero_shift_ptr), + AB::Expr::from_canonical_usize(*non_zero_shift_ptr), ); self.memory_bridge .write( - MemoryAddress::new(AB::Expr::from_canonical_u32(RV32_MEMORY_AS), write_ptr), - *to_dest, + MemoryAddress::new( + AB::Expr::from_canonical_u32(RV32_REGISTER_AS), + write_ptr, + ), + *new_data, timestamp_pp(), - &MemoryWriteAuxCols::from_base(local.write_aux[idx], *dest), + &MemoryWriteAuxCols::from_base(local.write_aux[idx], *prev_data), ) .eval(builder, local.is_valid); - }); + }, + ); // Write length to a2 register self.memory_bridge .write( MemoryAddress::new( - AB::Expr::from_canonical_u32(RV32_MEMORY_AS), + AB::Expr::from_canonical_u32(RV32_REGISTER_AS), AB::Expr::from_canonical_usize(A2_REGISTER_PTR), ), [ @@ -178,7 +181,6 @@ impl Air for MemcpyLoopAir { .eval(builder, local.is_valid); // Generate 16-bit limbs for range checking - let len_u16_limbs = u8_word_to_u16(local.len); let dest_u16_limbs = u8_word_to_u16(local.dest); let to_dest_u16_limbs = u8_word_to_u16(local.to_dest); let source_u16_limbs = [ @@ -202,26 +204,20 @@ impl Air for MemcpyLoopAir { - local.to_source_minus_twelve_carry, ]; - // Range check addresses and n + // Range check addresses let range_check_data = [ - (len_u16_limbs, false), - (dest_u16_limbs, true), - (source_u16_limbs, true), - (to_dest_u16_limbs, true), - (to_source_u16_limbs, true), + dest_u16_limbs, + source_u16_limbs, + to_dest_u16_limbs, + to_source_u16_limbs, ]; - range_check_data.iter().for_each(|(data, is_address)| { - let (data_0, num_bits) = if *is_address { - ( + range_check_data.iter().for_each(|data| { + self.range_bus + .range_check( data[0].clone() * AB::F::from_canonical_u32(4).inverse(), MEMCPY_LOOP_LIMB_BITS * 2 - 2, ) - } else { - (data[0].clone(), MEMCPY_LOOP_LIMB_BITS * 2) - }; - self.range_bus - .range_check(data_0, num_bits) .eval(builder, local.is_valid); self.range_bus .range_check( @@ -253,20 +249,24 @@ impl Air for MemcpyLoopAir { ) .eval(builder, local.is_valid); - // Make sure the request and response match - builder.assert_eq( - local.to_timestamp - - (local.from_state.timestamp + AB::Expr::from_canonical_u32(timestamp_delta)), - AB::Expr::TWO * (len.clone() - to_len) + is_shift_non_zero.clone(), + // Make sure the request and response match, this should work because the + // from_timestamp and len are valid and to_len is in [0, 16 + shift) + builder.when(local.is_valid).assert_eq( + AB::Expr::TWO * (local.to_timestamp - local.from_state.timestamp), + (len.clone() - to_len) + + AB::Expr::TWO + * (is_shift_non_zero.clone() + AB::Expr::from_canonical_u32(timestamp_delta)), ); // Execution bus + program bus self.execution_bridge .execute_and_increment_pc( - AB::Expr::from_canonical_usize(Rv32MemcpyOpcode::MEMCPY_LOOP as usize), - [shift.clone()], + AB::Expr::from_canonical_usize( + Rv32MemcpyOpcode::MEMCPY_LOOP as usize + self.offset, + ), + [AB::Expr::ZERO, AB::Expr::ZERO, shift.clone()], local.from_state, - local.to_timestamp, + local.to_timestamp - local.from_state.timestamp, ) .eval(builder, local.is_valid); } @@ -296,6 +296,7 @@ impl MemcpyLoopChip { system_port: SystemPort, range_bus: VariableRangeCheckerBus, memcpy_bus: MemcpyBus, + offset: usize, pointer_max_bits: usize, range_checker_chip: SharedVariableRangeCheckerChip, ) -> Self { @@ -306,6 +307,7 @@ impl MemcpyLoopChip { range_bus, memcpy_bus, pointer_max_bits, + offset, ), records: Arc::new(Mutex::new(Vec::new())), pointer_max_bits, @@ -350,23 +352,23 @@ impl MemcpyLoopChip { let to_dest = dest + num_copies; let to_source = source + num_copies; - let word_to_u16 = |data: u32| [data & 0xffff, data >> 16]; + let word_to_u16 = |data: u32| [data & 0x0ffff, data >> 16]; + debug_assert!(source >= 12 * (shift != 0) as u32); + debug_assert!(to_source >= 12 * (shift != 0) as u32); + debug_assert!(dest % 4 == 0); + debug_assert!(to_dest % 4 == 0); + debug_assert!(source % 4 == 0); + debug_assert!(to_source % 4 == 0); let range_check_data = [ - (word_to_u16(len), false), - (word_to_u16(dest), true), - (word_to_u16(source - 12 * (shift != 0) as u32), true), - (word_to_u16(to_dest), true), - (word_to_u16(to_source - 12 * (shift != 0) as u32), true), + word_to_u16(dest), + word_to_u16(source - 12 * (shift != 0) as u32), + word_to_u16(to_dest), + word_to_u16(to_source - 12 * (shift != 0) as u32), ]; - range_check_data.iter().for_each(|(data, is_address)| { - if *is_address { - self.range_checker_chip - .add_count(data[0] >> 2, 2 * MEMCPY_LOOP_LIMB_BITS - 2) - } else { - self.range_checker_chip - .add_count(data[0], 2 * MEMCPY_LOOP_LIMB_BITS) - }; + range_check_data.iter().for_each(|data| { + self.range_checker_chip + .add_count(data[0] >> 2, 2 * MEMCPY_LOOP_LIMB_BITS - 2); self.range_checker_chip .add_count(data[1], self.pointer_max_bits - 2 * MEMCPY_LOOP_LIMB_BITS); }); @@ -390,7 +392,8 @@ impl MemcpyLoopChip { /// Generates trace pub fn generate_trace(&self) -> RowMajorMatrix { - let mut rows = F::zero_vec(self.records.lock().unwrap().len() * NUM_MEMCPY_LOOP_COLS); + let height = next_power_of_two_or_zero(self.records.lock().unwrap().len()); + let mut rows = F::zero_vec(height * NUM_MEMCPY_LOOP_COLS); for (i, record) in self.records.lock().unwrap().iter().enumerate() { let row = &mut rows[i * NUM_MEMCPY_LOOP_COLS..(i + 1) * NUM_MEMCPY_LOOP_COLS]; @@ -405,10 +408,7 @@ impl MemcpyLoopChip { cols.dest = record.dest.to_le_bytes().map(F::from_canonical_u8); cols.source = record.source.to_le_bytes().map(F::from_canonical_u8); cols.len = record.len.to_le_bytes().map(F::from_canonical_u8); - cols.shift = [ - F::from_canonical_u8(shift % 2), - F::from_canonical_u8(shift / 2), - ]; + cols.shift = [shift & 1, shift >> 1].map(F::from_canonical_u8); cols.is_valid = F::ONE; // We have MEMCPY_LOOP_NUM_WRITES writes in the loop, (num_copies / 4) writes // and (num_copies / 4 + shift != 0) reads in iterations @@ -423,9 +423,15 @@ impl MemcpyLoopChip { .map(F::from_canonical_u8); cols.to_source = to_source.to_le_bytes().map(F::from_canonical_u8); cols.to_len = F::from_canonical_u32(record.len - num_copies); - cols.write_aux = record.write_aux.clone().map(|aux| aux.to_aux_cols()); - cols.source_minus_twelve_carry = F::from_bool((record.source & 0x0ff) < 12); - cols.to_source_minus_twelve_carry = F::from_bool((to_source & 0x0ff) < 12); + record + .write_aux + .iter() + .zip(cols.write_aux.iter_mut()) + .for_each(|(record_aux, col_aux)| { + record_aux.to_aux_cols(col_aux); + }); + cols.source_minus_twelve_carry = F::from_bool((record.source & 0x0ffff) < 12); + cols.to_source_minus_twelve_carry = F::from_bool((to_source & 0x0ffff) < 12); } RowMajorMatrix::new(rows, NUM_MEMCPY_LOOP_COLS) } diff --git a/extensions/memcpy/transpiler/src/lib.rs b/extensions/memcpy/transpiler/src/lib.rs index bd39fd5d86..4dc13d3ebd 100644 --- a/extensions/memcpy/transpiler/src/lib.rs +++ b/extensions/memcpy/transpiler/src/lib.rs @@ -46,10 +46,12 @@ impl TranspilerExtension for MemcpyTranspilerExtension { } // Convert to OpenVM instruction format - let instruction = from_u_type( + let mut instruction = from_u_type( Rv32MemcpyOpcode::MEMCPY_LOOP.global_opcode().as_usize(), &dec_insn, ); + instruction.a = F::ZERO; + instruction.d = F::ZERO; Some(TranspilerOutput::one_to_one(instruction)) } From 90dee892d1addf883530c489e65024e5d965be52 Mon Sep 17 00:00:00 2001 From: Peyman Jabbarzade Date: Mon, 25 Aug 2025 10:38:20 -0400 Subject: [PATCH 08/14] fix: reduce memory_bridge enabled degree from 3 to 2 --- extensions/memcpy/circuit/src/iteration.rs | 23 ++++++++++++++-------- extensions/memcpy/circuit/src/loops.rs | 3 --- extensions/sha256/circuit/Cargo.toml | 1 + extensions/sha256/circuit/src/lib.rs | 5 +++++ 4 files changed, 21 insertions(+), 11 deletions(-) diff --git a/extensions/memcpy/circuit/src/iteration.rs b/extensions/memcpy/circuit/src/iteration.rs index 44630b8ff4..2b40563c45 100644 --- a/extensions/memcpy/circuit/src/iteration.rs +++ b/extensions/memcpy/circuit/src/iteration.rs @@ -61,6 +61,7 @@ pub struct MemcpyIterCols { pub shift: [T; 2], pub is_valid: T, pub is_valid_not_start: T, + pub is_shift_non_zero: T, // -1 for the first iteration, 1 for the last iteration, 0 for the middle iterations pub is_boundary: T, pub data_1: [T; MEMCPY_LOOP_NUM_LIMBS], @@ -105,8 +106,7 @@ impl Air for MemcpyIterAir { }; let shift = local.shift[0] * AB::Expr::TWO + local.shift[1]; - let is_shift_non_zero = or::(local.shift[0], local.shift[1]); - let is_shift_zero = not::(is_shift_non_zero.clone()); + let is_shift_zero = not::(local.is_shift_non_zero.clone()); let is_shift_one = and::(local.shift[0], not::(local.shift[1])); let is_shift_two = and::(not::(local.shift[0]), local.shift[1]); let is_shift_three = and::(local.shift[0], local.shift[1]); @@ -122,6 +122,8 @@ impl Air for MemcpyIterAir { let len = local.len[0] + local.len[1] * AB::Expr::from_canonical_u32(1 << (2 * MEMCPY_LOOP_LIMB_BITS)); + let prev_len = prev.len[0] + + prev.len[1] * AB::Expr::from_canonical_u32(1 << (2 * MEMCPY_LOOP_LIMB_BITS)); // write_data = // (local.data_1[shift..4], prev.data_4[0..shift]), @@ -165,6 +167,7 @@ impl Air for MemcpyIterAir { builder.assert_bool(local.is_valid); local.shift.iter().for_each(|x| builder.assert_bool(*x)); builder.assert_bool(local.is_valid_not_start); + builder.assert_bool(local.is_shift_non_zero); // is_boundary is either -1, 0 or 1 builder.assert_tern(local.is_boundary + AB::Expr::ONE); @@ -174,6 +177,9 @@ impl Air for MemcpyIterAir { and::(local.is_valid, is_not_start), ); + // is_shift_non_zero is correct + builder.assert_eq(local.is_shift_non_zero, or::(local.shift[0], local.shift[1])); + // if !is_valid, then is_boundary = 0, shift = 0 (we will use this assumption later) let mut is_not_valid_when = builder.when(not::(local.is_valid)); is_not_valid_when.assert_zero(local.is_boundary); @@ -183,7 +189,7 @@ impl Air for MemcpyIterAir { // and dest = prev_dest + 16, shift = prev_shift let mut is_valid_not_start_when = builder.when(local.is_valid_not_start); is_valid_not_start_when - .assert_eq(local.len[0], prev.len[0] - AB::Expr::from_canonical_u32(16)); + .assert_eq(len.clone(), prev_len - AB::Expr::from_canonical_u32(16)); is_valid_not_start_when .assert_eq(local.source, prev.source + AB::Expr::from_canonical_u32(16)); is_valid_not_start_when.assert_eq(local.dest, prev.dest + AB::Expr::from_canonical_u32(16)); @@ -199,7 +205,7 @@ impl Air for MemcpyIterAir { // since is_shift_non_zero degree is 2, we need to keep the degree of the condition to 1 builder .when(not::(prev.is_valid_not_start) - not::(prev.is_valid)) - .assert_eq(local.timestamp, prev.timestamp + is_shift_non_zero.clone()); + .assert_eq(local.timestamp, prev.timestamp + local.is_shift_non_zero.clone()); // if prev.is_valid_not_start and local.is_valid_not_start, then timestamp=prev_timestamp+8 // prev.is_valid_not_start is the opposite of previous condition @@ -222,7 +228,7 @@ impl Air for MemcpyIterAir { + (local.is_boundary + AB::Expr::ONE) * AB::Expr::from_canonical_usize(4), local.dest, local.source, - len, + len.clone(), (AB::Expr::ONE - local.is_boundary) * shift.clone() * (AB::F::TWO).inverse() + (local.is_boundary + AB::Expr::ONE) * AB::Expr::TWO, ) @@ -241,7 +247,7 @@ impl Air for MemcpyIterAir { .enumerate() .for_each(|(idx, (data, read_aux))| { let is_valid_read = if idx == 3 { - or::(is_shift_non_zero.clone(), local.is_valid_not_start) + or::(local.is_shift_non_zero.clone(), local.is_valid_not_start) } else { local.is_valid_not_start.into() }; @@ -288,10 +294,10 @@ impl Air for MemcpyIterAir { ), ]; self.range_bus - .push(local.len[0], len_bits_limit[0].clone(), true) + .push(local.len[0].clone(), len_bits_limit[0].clone(), true) .eval(builder, local.is_valid); self.range_bus - .push(local.len[1], len_bits_limit[1].clone(), true) + .push(local.len[1].clone(), len_bits_limit[1].clone(), true) .eval(builder, local.is_valid); } } @@ -688,6 +694,7 @@ impl TraceFiller for MemcpyIterFiller { } else { F::ZERO }; + cols.is_shift_non_zero = F::from_canonical_u8((record.inner.shift != 0) as u8); cols.is_valid_not_start = F::from_canonical_u8(1 - is_start as u8); cols.is_valid = F::ONE; cols.shift = [record.inner.shift & 1, record.inner.shift >> 1] diff --git a/extensions/memcpy/circuit/src/loops.rs b/extensions/memcpy/circuit/src/loops.rs index 368abc237b..6751899153 100644 --- a/extensions/memcpy/circuit/src/loops.rs +++ b/extensions/memcpy/circuit/src/loops.rs @@ -453,9 +453,6 @@ impl ChipUsageGetter for MemcpyLoopChip { fn air_name(&self) -> String { get_air_name(&self.air) } - fn constant_trace_height(&self) -> Option { - Some(self.records.lock().unwrap().len()) - } fn current_trace_height(&self) -> usize { self.records.lock().unwrap().len() } diff --git a/extensions/sha256/circuit/Cargo.toml b/extensions/sha256/circuit/Cargo.toml index 740a4302f5..f581346734 100644 --- a/extensions/sha256/circuit/Cargo.toml +++ b/extensions/sha256/circuit/Cargo.toml @@ -16,6 +16,7 @@ openvm-circuit = { workspace = true } openvm-instructions = { workspace = true } openvm-sha256-transpiler = { workspace = true } openvm-rv32im-circuit = { workspace = true } +openvm-memcpy-circuit = { workspace = true } openvm-sha256-air = { workspace = true } derive-new.workspace = true diff --git a/extensions/sha256/circuit/src/lib.rs b/extensions/sha256/circuit/src/lib.rs index 2847e51636..8afde9b5cd 100644 --- a/extensions/sha256/circuit/src/lib.rs +++ b/extensions/sha256/circuit/src/lib.rs @@ -12,6 +12,7 @@ use openvm_circuit::{ system::{SystemChipInventory, SystemCpuBuilder, SystemExecutor}, }; use openvm_circuit_derive::VmConfig; +use openvm_memcpy_circuit::{Memcpy, MemcpyCpuProverExt, MemcpyExecutor}; use openvm_rv32im_circuit::{ Rv32I, Rv32IExecutor, Rv32ImCpuProverExt, Rv32Io, Rv32IoExecutor, Rv32M, Rv32MExecutor, }; @@ -55,6 +56,8 @@ pub struct Sha256Rv32Config { pub io: Rv32Io, #[extension] pub sha256: Sha256, + #[extension] + pub memcpy: Memcpy, } impl Default for Sha256Rv32Config { @@ -65,6 +68,7 @@ impl Default for Sha256Rv32Config { rv32m: Rv32M::default(), io: Rv32Io, sha256: Sha256, + memcpy: Memcpy, } } } @@ -100,6 +104,7 @@ where VmProverExtension::::extend_prover(&Rv32ImCpuProverExt, &config.rv32m, inventory)?; VmProverExtension::::extend_prover(&Rv32ImCpuProverExt, &config.io, inventory)?; VmProverExtension::::extend_prover(&Sha2CpuProverExt, &config.sha256, inventory)?; + VmProverExtension::::extend_prover(&MemcpyCpuProverExt, &config.memcpy, inventory)?; Ok(chip_complex) } } From baecf3607ad5e59c47c89ec3ad63053c151e2ccc Mon Sep 17 00:00:00 2001 From: Peyman Jabbarzade Date: Mon, 25 Aug 2025 18:50:00 -0400 Subject: [PATCH 09/14] feat: add memcpy tests --- Cargo.lock | 21 + Cargo.toml | 1 + crates/vm/src/arch/testing/mod.rs | 1 + extensions/memcpy/circuit/src/iteration.rs | 16 +- extensions/memcpy/circuit/src/lib.rs | 1 + extensions/memcpy/tests.rs | 444 --------------- extensions/memcpy/tests/Cargo.toml | 5 +- extensions/memcpy/tests/src/lib.rs | 607 +++++---------------- 8 files changed, 186 insertions(+), 910 deletions(-) delete mode 100644 extensions/memcpy/tests.rs diff --git a/Cargo.lock b/Cargo.lock index f4ffdb359a..ad4ec19151 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -5767,6 +5767,26 @@ dependencies = [ "strum", ] +[[package]] +name = "openvm-memcpy-integration-tests" +version = "1.4.0-rc.6" +dependencies = [ + "eyre", + "openvm", + "openvm-circuit", + "openvm-circuit-primitives", + "openvm-instructions", + "openvm-memcpy-circuit", + "openvm-memcpy-transpiler", + "openvm-stark-backend", + "openvm-stark-sdk", + "rand 0.8.5", + "serde", + "strum", + "test-case", + "tracing", +] + [[package]] name = "openvm-memcpy-transpiler" version = "1.4.0-rc.6" @@ -6257,6 +6277,7 @@ dependencies = [ "openvm-cuda-builder", "openvm-cuda-common", "openvm-instructions", + "openvm-memcpy-circuit", "openvm-rv32im-circuit", "openvm-sha256-air", "openvm-sha256-transpiler", diff --git a/Cargo.toml b/Cargo.toml index b4419de60d..561f338548 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -63,6 +63,7 @@ members = [ "extensions/pairing/guest", "extensions/memcpy/circuit", "extensions/memcpy/transpiler", + "extensions/memcpy/tests", "guest-libs/ff_derive/", "guest-libs/k256/", "guest-libs/p256/", diff --git a/crates/vm/src/arch/testing/mod.rs b/crates/vm/src/arch/testing/mod.rs index 5293a0275a..91595c404e 100644 --- a/crates/vm/src/arch/testing/mod.rs +++ b/crates/vm/src/arch/testing/mod.rs @@ -29,6 +29,7 @@ pub const BITWISE_OP_LOOKUP_BUS: BusIndex = 9; pub const BYTE_XOR_BUS: BusIndex = 10; pub const RANGE_TUPLE_CHECKER_BUS: BusIndex = 11; pub const MEMORY_MERKLE_BUS: BusIndex = 12; +pub const MEMCPY_BUS: BusIndex = 13; pub const RANGE_CHECKER_BUS: BusIndex = 4; diff --git a/extensions/memcpy/circuit/src/iteration.rs b/extensions/memcpy/circuit/src/iteration.rs index 2b40563c45..7b2b71373e 100644 --- a/extensions/memcpy/circuit/src/iteration.rs +++ b/extensions/memcpy/circuit/src/iteration.rs @@ -61,7 +61,7 @@ pub struct MemcpyIterCols { pub shift: [T; 2], pub is_valid: T, pub is_valid_not_start: T, - pub is_shift_non_zero: T, + pub is_shift_zero: T, // -1 for the first iteration, 1 for the last iteration, 0 for the middle iterations pub is_boundary: T, pub data_1: [T; MEMCPY_LOOP_NUM_LIMBS], @@ -106,7 +106,7 @@ impl Air for MemcpyIterAir { }; let shift = local.shift[0] * AB::Expr::TWO + local.shift[1]; - let is_shift_zero = not::(local.is_shift_non_zero.clone()); + let is_shift_non_zero = not::(local.is_shift_zero); let is_shift_one = and::(local.shift[0], not::(local.shift[1])); let is_shift_two = and::(not::(local.shift[0]), local.shift[1]); let is_shift_three = and::(local.shift[0], local.shift[1]); @@ -141,7 +141,7 @@ impl Air for MemcpyIterAir { .iter() .map(|(prev_data, next_data)| { array::from_fn::<_, MEMCPY_LOOP_NUM_LIMBS, _>(|i| { - is_shift_zero.clone() * (next_data[i]) + local.is_shift_zero.clone() * (next_data[i]) + is_shift_one.clone() * (if i < 3 { next_data[i + 1] @@ -167,7 +167,7 @@ impl Air for MemcpyIterAir { builder.assert_bool(local.is_valid); local.shift.iter().for_each(|x| builder.assert_bool(*x)); builder.assert_bool(local.is_valid_not_start); - builder.assert_bool(local.is_shift_non_zero); + builder.assert_bool(local.is_shift_zero); // is_boundary is either -1, 0 or 1 builder.assert_tern(local.is_boundary + AB::Expr::ONE); @@ -178,7 +178,7 @@ impl Air for MemcpyIterAir { ); // is_shift_non_zero is correct - builder.assert_eq(local.is_shift_non_zero, or::(local.shift[0], local.shift[1])); + builder.assert_eq(local.is_shift_zero, not::(or::(local.shift[0], local.shift[1]))); // if !is_valid, then is_boundary = 0, shift = 0 (we will use this assumption later) let mut is_not_valid_when = builder.when(not::(local.is_valid)); @@ -205,7 +205,7 @@ impl Air for MemcpyIterAir { // since is_shift_non_zero degree is 2, we need to keep the degree of the condition to 1 builder .when(not::(prev.is_valid_not_start) - not::(prev.is_valid)) - .assert_eq(local.timestamp, prev.timestamp + local.is_shift_non_zero.clone()); + .assert_eq(local.timestamp, prev.timestamp + is_shift_non_zero.clone()); // if prev.is_valid_not_start and local.is_valid_not_start, then timestamp=prev_timestamp+8 // prev.is_valid_not_start is the opposite of previous condition @@ -247,7 +247,7 @@ impl Air for MemcpyIterAir { .enumerate() .for_each(|(idx, (data, read_aux))| { let is_valid_read = if idx == 3 { - or::(local.is_shift_non_zero.clone(), local.is_valid_not_start) + or::(is_shift_non_zero.clone(), local.is_valid_not_start) } else { local.is_valid_not_start.into() }; @@ -694,7 +694,7 @@ impl TraceFiller for MemcpyIterFiller { } else { F::ZERO }; - cols.is_shift_non_zero = F::from_canonical_u8((record.inner.shift != 0) as u8); + cols.is_shift_zero = F::from_canonical_u8((record.inner.shift == 0) as u8); cols.is_valid_not_start = F::from_canonical_u8(1 - is_start as u8); cols.is_valid = F::ONE; cols.shift = [record.inner.shift & 1, record.inner.shift >> 1] diff --git a/extensions/memcpy/circuit/src/lib.rs b/extensions/memcpy/circuit/src/lib.rs index 205e10a1c1..e81deb9e37 100644 --- a/extensions/memcpy/circuit/src/lib.rs +++ b/extensions/memcpy/circuit/src/lib.rs @@ -3,6 +3,7 @@ mod extension; mod iteration; mod loops; +pub use bus::*; pub use extension::*; pub use iteration::*; pub use loops::*; diff --git a/extensions/memcpy/tests.rs b/extensions/memcpy/tests.rs deleted file mode 100644 index 743199194b..0000000000 --- a/extensions/memcpy/tests.rs +++ /dev/null @@ -1,444 +0,0 @@ -// use std::{array, borrow::BorrowMut, sync::Arc}; - -// use openvm_circuit::arch::testing::{TestChipHarness, VmChipTestBuilder, BITWISE_OP_LOOKUP_BUS}; -// use openvm_circuit_primitives::bitwise_op_lookup::{ -// BitwiseOperationLookupAir, BitwiseOperationLookupBus, BitwiseOperationLookupChip, -// SharedBitwiseOperationLookupChip, -// }; -// use openvm_instructions::LocalOpcode; -// use openvm_rv32im_transpiler::BaseAluOpcode::{self, *}; -// use openvm_stark_backend::{ -// p3_air::BaseAir, -// p3_field::{FieldAlgebra, PrimeField32}, -// p3_matrix::{ -// dense::{DenseMatrix, RowMajorMatrix}, -// Matrix, -// }, -// utils::disable_debug_builder, -// }; -// use openvm_stark_sdk::{p3_baby_bear::BabyBear, utils::create_seeded_rng}; -// use rand::{rngs::StdRng, Rng}; -// use test_case::test_case; - -// use super::{core::run_alu, BaseAluCoreAir, Rv32BaseAluChip, Rv32BaseAluExecutor}; -// use crate::{ -// adapters::{ -// Rv32BaseAluAdapterAir, Rv32BaseAluAdapterExecutor, Rv32BaseAluAdapterFiller, -// RV32_CELL_BITS, RV32_REGISTER_NUM_LIMBS, -// }, -// base_alu::BaseAluCoreCols, -// test_utils::{ -// generate_rv32_is_type_immediate, get_verification_error, rv32_rand_write_register_or_imm, -// }, -// BaseAluFiller, Rv32BaseAluAir, -// }; - -// const MAX_INS_CAPACITY: usize = 128; -// type F = BabyBear; -// type Harness = TestChipHarness>; - -// fn create_test_chip( -// tester: &VmChipTestBuilder, -// ) -> ( -// Harness, -// ( -// BitwiseOperationLookupAir, -// SharedBitwiseOperationLookupChip, -// ), -// ) { -// let bitwise_bus = BitwiseOperationLookupBus::new(BITWISE_OP_LOOKUP_BUS); -// let bitwise_chip = Arc::new(BitwiseOperationLookupChip::::new( -// bitwise_bus, -// )); - -// let air = Rv32BaseAluAir::new( -// Rv32BaseAluAdapterAir::new( -// tester.execution_bridge(), -// tester.memory_bridge(), -// bitwise_bus, -// ), -// BaseAluCoreAir::new(bitwise_bus, BaseAluOpcode::CLASS_OFFSET), -// ); -// let executor = Rv32BaseAluExecutor::new( -// Rv32BaseAluAdapterExecutor::new(), -// BaseAluOpcode::CLASS_OFFSET, -// ); -// let chip = Rv32BaseAluChip::new( -// BaseAluFiller::new( -// Rv32BaseAluAdapterFiller::new(bitwise_chip.clone()), -// bitwise_chip.clone(), -// BaseAluOpcode::CLASS_OFFSET, -// ), -// tester.memory_helper(), -// ); -// let harness = Harness::with_capacity(executor, air, chip, MAX_INS_CAPACITY); - -// (harness, (bitwise_chip.air, bitwise_chip)) -// } - -// fn set_and_execute( -// tester: &mut VmChipTestBuilder, -// harness: &mut Harness, -// rng: &mut StdRng, -// opcode: BaseAluOpcode, -// b: Option<[u8; RV32_REGISTER_NUM_LIMBS]>, -// is_imm: Option, -// c: Option<[u8; RV32_REGISTER_NUM_LIMBS]>, -// ) { -// let b = b.unwrap_or(array::from_fn(|_| rng.gen_range(0..=u8::MAX))); -// let (c_imm, c) = if is_imm.unwrap_or(rng.gen_bool(0.5)) { -// let (imm, c) = if let Some(c) = c { -// ((u32::from_le_bytes(c) & 0xFFFFFF) as usize, c) -// } else { -// generate_rv32_is_type_immediate(rng) -// }; -// (Some(imm), c) -// } else { -// ( -// None, -// c.unwrap_or(array::from_fn(|_| rng.gen_range(0..=u8::MAX))), -// ) -// }; - -// let (instruction, rd) = rv32_rand_write_register_or_imm( -// tester, -// b, -// c, -// c_imm, -// opcode.global_opcode().as_usize(), -// rng, -// ); -// tester.execute(harness, &instruction); - -// let a = run_alu::(opcode, &b, &c) -// .map(F::from_canonical_u8); -// assert_eq!(a, tester.read::(1, rd)) -// } - -// ////////////////////////////////////////////////////////////////////////////////////// -// // POSITIVE TESTS -// // -// // Randomly generate computations and execute, ensuring that the generated trace -// // passes all constraints. -// ////////////////////////////////////////////////////////////////////////////////////// - -// #[test_case(ADD, 100)] -// #[test_case(SUB, 100)] -// #[test_case(XOR, 100)] -// #[test_case(OR, 100)] -// #[test_case(AND, 100)] -// fn rand_rv32_alu_test(opcode: BaseAluOpcode, num_ops: usize) { -// let mut rng = create_seeded_rng(); - -// let mut tester = VmChipTestBuilder::default(); -// let (mut harness, bitwise) = create_test_chip(&tester); - -// // TODO(AG): make a more meaningful test for memory accesses -// tester.write(2, 1024, [F::ONE; 4]); -// tester.write(2, 1028, [F::ONE; 4]); -// let sm = tester.read(2, 1024); -// assert_eq!(sm, [F::ONE; 8]); - -// for _ in 0..num_ops { -// set_and_execute( -// &mut tester, -// &mut harness, -// &mut rng, -// opcode, -// None, -// None, -// None, -// ); -// } - -// let tester = tester -// .build() -// .load(harness) -// .load_periphery(bitwise) -// .finalize(); -// tester.simple_test().expect("Verification failed"); -// } - -// #[test_case(ADD, 100)] -// #[test_case(SUB, 100)] -// #[test_case(XOR, 100)] -// #[test_case(OR, 100)] -// #[test_case(AND, 100)] -// fn rand_rv32_alu_test_persistent(opcode: BaseAluOpcode, num_ops: usize) { -// let mut rng = create_seeded_rng(); - -// let mut tester = VmChipTestBuilder::default_persistent(); -// let (mut harness, bitwise) = create_test_chip(&tester); - -// // TODO(AG): make a more meaningful test for memory accesses -// tester.write(2, 1024, [F::ONE; 4]); -// tester.write(2, 1028, [F::ONE; 4]); -// let sm = tester.read(2, 1024); -// assert_eq!(sm, [F::ONE; 8]); - -// for _ in 0..num_ops { -// set_and_execute( -// &mut tester, -// &mut harness, -// &mut rng, -// opcode, -// None, -// None, -// None, -// ); -// } - -// let tester = tester -// .build() -// .load(harness) -// .load_periphery(bitwise) -// .finalize(); -// tester.simple_test().expect("Verification failed"); -// } - -// ////////////////////////////////////////////////////////////////////////////////////// -// // NEGATIVE TESTS -// // -// // Given a fake trace of a single operation, setup a chip and run the test. We replace -// // part of the trace and check that the chip throws the expected error. -// ////////////////////////////////////////////////////////////////////////////////////// - -// #[allow(clippy::too_many_arguments)] -// fn run_negative_alu_test( -// opcode: BaseAluOpcode, -// prank_a: [u32; RV32_REGISTER_NUM_LIMBS], -// b: [u8; RV32_REGISTER_NUM_LIMBS], -// c: [u8; RV32_REGISTER_NUM_LIMBS], -// prank_c: Option<[u32; RV32_REGISTER_NUM_LIMBS]>, -// prank_opcode_flags: Option<[bool; 5]>, -// is_imm: Option, -// interaction_error: bool, -// ) { -// let mut rng = create_seeded_rng(); -// let mut tester: VmChipTestBuilder = VmChipTestBuilder::default(); -// let (mut chip, bitwise) = create_test_chip(&tester); - -// set_and_execute( -// &mut tester, -// &mut chip, -// &mut rng, -// opcode, -// Some(b), -// is_imm, -// Some(c), -// ); - -// let adapter_width = BaseAir::::width(&chip.air.adapter); -// let modify_trace = |trace: &mut DenseMatrix| { -// let mut values = trace.row_slice(0).to_vec(); -// let cols: &mut BaseAluCoreCols = -// values.split_at_mut(adapter_width).1.borrow_mut(); -// cols.a = prank_a.map(F::from_canonical_u32); -// if let Some(prank_c) = prank_c { -// cols.c = prank_c.map(F::from_canonical_u32); -// } -// if let Some(prank_opcode_flags) = prank_opcode_flags { -// cols.opcode_add_flag = F::from_bool(prank_opcode_flags[0]); -// cols.opcode_and_flag = F::from_bool(prank_opcode_flags[1]); -// cols.opcode_or_flag = F::from_bool(prank_opcode_flags[2]); -// cols.opcode_sub_flag = F::from_bool(prank_opcode_flags[3]); -// cols.opcode_xor_flag = F::from_bool(prank_opcode_flags[4]); -// } -// *trace = RowMajorMatrix::new(values, trace.width()); -// }; - -// disable_debug_builder(); -// let tester = tester -// .build() -// .load_and_prank_trace(chip, modify_trace) -// .load_periphery(bitwise) -// .finalize(); -// tester.simple_test_with_expected_error(get_verification_error(interaction_error)); -// } - -// #[test] -// fn rv32_alu_add_wrong_negative_test() { -// run_negative_alu_test( -// ADD, -// [246, 0, 0, 0], -// [250, 0, 0, 0], -// [250, 0, 0, 0], -// None, -// None, -// None, -// false, -// ); -// } - -// #[test] -// fn rv32_alu_add_out_of_range_negative_test() { -// run_negative_alu_test( -// ADD, -// [500, 0, 0, 0], -// [250, 0, 0, 0], -// [250, 0, 0, 0], -// None, -// None, -// None, -// true, -// ); -// } - -// #[test] -// fn rv32_alu_sub_wrong_negative_test() { -// run_negative_alu_test( -// SUB, -// [255, 0, 0, 0], -// [1, 0, 0, 0], -// [2, 0, 0, 0], -// None, -// None, -// None, -// false, -// ); -// } - -// #[test] -// fn rv32_alu_sub_out_of_range_negative_test() { -// run_negative_alu_test( -// SUB, -// [F::NEG_ONE.as_canonical_u32(), 0, 0, 0], -// [1, 0, 0, 0], -// [2, 0, 0, 0], -// None, -// None, -// None, -// true, -// ); -// } - -// #[test] -// fn rv32_alu_xor_wrong_negative_test() { -// run_negative_alu_test( -// XOR, -// [255, 255, 255, 255], -// [0, 0, 1, 0], -// [255, 255, 255, 255], -// None, -// None, -// None, -// true, -// ); -// } - -// #[test] -// fn rv32_alu_or_wrong_negative_test() { -// run_negative_alu_test( -// OR, -// [255, 255, 255, 255], -// [255, 255, 255, 254], -// [0, 0, 0, 0], -// None, -// None, -// None, -// true, -// ); -// } - -// #[test] -// fn rv32_alu_and_wrong_negative_test() { -// run_negative_alu_test( -// AND, -// [255, 255, 255, 255], -// [0, 0, 1, 0], -// [0, 0, 0, 0], -// None, -// None, -// None, -// true, -// ); -// } - -// #[test] -// fn rv32_alu_adapter_unconstrained_imm_limb_test() { -// run_negative_alu_test( -// ADD, -// [255, 7, 0, 0], -// [0, 0, 0, 0], -// [255, 7, 0, 0], -// Some([511, 6, 0, 0]), -// None, -// Some(true), -// true, -// ); -// } - -// #[test] -// fn rv32_alu_adapter_unconstrained_rs2_read_test() { -// run_negative_alu_test( -// ADD, -// [2, 2, 2, 2], -// [1, 1, 1, 1], -// [1, 1, 1, 1], -// None, -// Some([false, false, false, false, false]), -// Some(false), -// false, -// ); -// } - -// /////////////////////////////////////////////////////////////////////////////////////// -// /// SANITY TESTS -// /// -// /// Ensure that solve functions produce the correct results. -// /////////////////////////////////////////////////////////////////////////////////////// - -// #[test] -// fn run_add_sanity_test() { -// let x: [u8; RV32_REGISTER_NUM_LIMBS] = [229, 33, 29, 111]; -// let y: [u8; RV32_REGISTER_NUM_LIMBS] = [50, 171, 44, 194]; -// let z: [u8; RV32_REGISTER_NUM_LIMBS] = [23, 205, 73, 49]; -// let result = run_alu::(ADD, &x, &y); -// for i in 0..RV32_REGISTER_NUM_LIMBS { -// assert_eq!(z[i], result[i]) -// } -// } - -// #[test] -// fn run_sub_sanity_test() { -// let x: [u8; RV32_REGISTER_NUM_LIMBS] = [229, 33, 29, 111]; -// let y: [u8; RV32_REGISTER_NUM_LIMBS] = [50, 171, 44, 194]; -// let z: [u8; RV32_REGISTER_NUM_LIMBS] = [179, 118, 240, 172]; -// let result = run_alu::(SUB, &x, &y); -// for i in 0..RV32_REGISTER_NUM_LIMBS { -// assert_eq!(z[i], result[i]) -// } -// } - -// #[test] -// fn run_xor_sanity_test() { -// let x: [u8; RV32_REGISTER_NUM_LIMBS] = [229, 33, 29, 111]; -// let y: [u8; RV32_REGISTER_NUM_LIMBS] = [50, 171, 44, 194]; -// let z: [u8; RV32_REGISTER_NUM_LIMBS] = [215, 138, 49, 173]; -// let result = run_alu::(XOR, &x, &y); -// for i in 0..RV32_REGISTER_NUM_LIMBS { -// assert_eq!(z[i], result[i]) -// } -// } - -// #[test] -// fn run_or_sanity_test() { -// let x: [u8; RV32_REGISTER_NUM_LIMBS] = [229, 33, 29, 111]; -// let y: [u8; RV32_REGISTER_NUM_LIMBS] = [50, 171, 44, 194]; -// let z: [u8; RV32_REGISTER_NUM_LIMBS] = [247, 171, 61, 239]; -// let result = run_alu::(OR, &x, &y); -// for i in 0..RV32_REGISTER_NUM_LIMBS { -// assert_eq!(z[i], result[i]) -// } -// } - -// #[test] -// fn run_and_sanity_test() { -// let x: [u8; RV32_REGISTER_NUM_LIMBS] = [229, 33, 29, 111]; -// let y: [u8; RV32_REGISTER_NUM_LIMBS] = [50, 171, 44, 194]; -// let z: [u8; RV32_REGISTER_NUM_LIMBS] = [32, 33, 12, 66]; -// let result = run_alu::(AND, &x, &y); -// for i in 0..RV32_REGISTER_NUM_LIMBS { -// assert_eq!(z[i], result[i]) -// } -// } diff --git a/extensions/memcpy/tests/Cargo.toml b/extensions/memcpy/tests/Cargo.toml index 2dbf144891..bedaf7d8dd 100644 --- a/extensions/memcpy/tests/Cargo.toml +++ b/extensions/memcpy/tests/Cargo.toml @@ -10,17 +10,18 @@ repository.workspace = true [dependencies] openvm-instructions = { workspace = true } openvm-stark-sdk.workspace = true -openvm-transpiler.workspace = true openvm-memcpy-circuit.workspace = true openvm-memcpy-transpiler.workspace = true openvm = { workspace = true } -openvm-toolchain-tests = { path = "../../../crates/toolchain/tests" } eyre.workspace = true serde = { workspace = true, features = ["alloc"] } strum.workspace = true rand.workspace = true openvm-circuit = { workspace = true, features = ["test-utils"] } +openvm-circuit-primitives.workspace = true +openvm-stark-backend.workspace = true test-case.workspace = true +tracing.workspace = true [features] default = ["parallel"] diff --git a/extensions/memcpy/tests/src/lib.rs b/extensions/memcpy/tests/src/lib.rs index 964aa39cb3..b8434f4831 100644 --- a/extensions/memcpy/tests/src/lib.rs +++ b/extensions/memcpy/tests/src/lib.rs @@ -1,93 +1,147 @@ #[cfg(test)] mod tests { - use std::{array, borrow::BorrowMut, sync::Arc}; + use std::sync::Arc; - use openvm_circuit::arch::testing::{TestChipHarness, VmChipTestBuilder}; + use openvm_circuit::{ + arch::{ + testing::{TestBuilder, TestChipHarness, VmChipTestBuilder, MEMCPY_BUS, RANGE_CHECKER_BUS}, + Arena, PreflightExecutor, + }, + system::{memory::SharedMemoryHelper, SystemPort}, + }; use openvm_circuit_primitives::var_range::{ SharedVariableRangeCheckerChip, VariableRangeCheckerAir, VariableRangeCheckerBus, VariableRangeCheckerChip, }; - use openvm_instructions::LocalOpcode; + use openvm_instructions::{instruction::Instruction, LocalOpcode, VmOpcode}; use openvm_memcpy_circuit::{ - bus::MemcpyBus, - extension::{Memcpy, MemcpyCpuProverExt}, - MemcpyIterAir, MemcpyIterCols, MemcpyIterExecutor, MemcpyIterFiller, MemcpyLoopAir, - MemcpyLoopChip, MemcpyLoopExecutor, MEMCPY_LOOP_LIMB_BITS, MEMCPY_LOOP_NUM_LIMBS, + MemcpyBus, MemcpyIterAir, MemcpyIterChip, MemcpyIterExecutor, MemcpyIterFiller, + MemcpyLoopAir, MemcpyLoopChip, A1_REGISTER_PTR, A2_REGISTER_PTR, A3_REGISTER_PTR, + A4_REGISTER_PTR, }; use openvm_memcpy_transpiler::Rv32MemcpyOpcode; - use openvm_stark_backend::{ - p3_air::BaseAir, - p3_field::{FieldAlgebra, PrimeField32}, - p3_matrix::{ - dense::{DenseMatrix, RowMajorMatrix}, - Matrix, - }, - utils::disable_debug_builder, - }; + use openvm_stark_backend::p3_field::FieldAlgebra; use openvm_stark_sdk::{p3_baby_bear::BabyBear, utils::create_seeded_rng}; - use rand::{rngs::StdRng, Rng}; + use rand::Rng; use test_case::test_case; const MAX_INS_CAPACITY: usize = 128; type F = BabyBear; - type Harness = TestChipHarness>; + type Harness = TestChipHarness>; - fn create_test_chip( - tester: &VmChipTestBuilder, - ) -> ( - Harness, - ( - VariableRangeCheckerAir, - SharedVariableRangeCheckerChip, - ), - ) { - let range_bus = VariableRangeCheckerBus::new(tester.new_bus_idx()); - let range_chip = Arc::new(VariableRangeCheckerChip::::new( - range_bus, - )); - - let memcpy_bus = MemcpyBus::new(tester.new_bus_idx()); + fn create_harness_fields( + address_bits: usize, + system_port: SystemPort, + range_chip: Arc, + memory_helper: SharedMemoryHelper, + ) -> (MemcpyIterAir, MemcpyIterExecutor, MemcpyIterChip, Arc) { + let range_bus = range_chip.bus(); + let memcpy_bus = MemcpyBus::new(MEMCPY_BUS); let air = MemcpyIterAir::new( - tester.memory_bridge(), + system_port.memory_bridge, range_bus, memcpy_bus, - tester.pointer_max_bits(), + address_bits, ); let executor = MemcpyIterExecutor::new(Rv32MemcpyOpcode::CLASS_OFFSET); - let chip = MemcpyLoopChip::new( - tester.system_port(), + let loop_chip = Arc::new(MemcpyLoopChip::new( + system_port, range_bus, memcpy_bus, Rv32MemcpyOpcode::CLASS_OFFSET, - tester.pointer_max_bits(), + address_bits, + range_chip.clone(), + )); + let chip = MemcpyIterChip::new( + MemcpyIterFiller::new(address_bits, range_chip, loop_chip.clone()), + memory_helper, + ); + (air, executor, chip, loop_chip) + } + + fn create_harness( + tester: &VmChipTestBuilder, + ) -> ( + Harness, + (VariableRangeCheckerAir, SharedVariableRangeCheckerChip), + (MemcpyLoopAir, Arc), + ) { + let range_bus = VariableRangeCheckerBus::new(RANGE_CHECKER_BUS, tester.address_bits()); + let range_chip = Arc::new(VariableRangeCheckerChip::new(range_bus)); + + let (air, executor, chip, loop_chip) = create_harness_fields( + tester.address_bits(), + tester.system_port(), range_chip.clone(), + tester.memory_helper(), ); let harness = Harness::with_capacity(executor, air, chip, MAX_INS_CAPACITY); - (harness, (range_chip.air, range_chip)) + ( + harness, + (range_chip.air, range_chip), + (loop_chip.air, loop_chip), + ) } - fn set_and_execute_memcpy( - tester: &mut VmChipTestBuilder, - harness: &mut Harness, - rng: &mut StdRng, + fn set_and_execute_memcpy>( + tester: &mut impl TestBuilder, + executor: &mut E, + arena: &mut RA, shift: u32, source_data: &[u8], dest_offset: u32, source_offset: u32, len: u32, ) { - // Write source data to memory - for (i, &byte) in source_data.iter().enumerate() { - tester.write(2, source_offset + i as u32, [F::from_canonical_u8(byte)]); + // Write source data to memory by words (4 bytes) + let source_words = source_data.len().div_ceil(4); + for word_idx in 0..source_words { + let word_start = word_idx * 4; + let word_end = (word_idx + 1) * 4; + let mut word_data = [F::ZERO; 4]; + + for i in word_start..word_end { + if i < source_data.len() { + word_data[i - word_start] = F::from_canonical_u8(source_data[i]); + } + } + + tester.write(2, (source_offset + word_idx as u32 * 4) as usize, word_data); } - // Create instruction for memcpy_loop - let instruction = openvm_instructions::instruction::Instruction { - opcode: openvm_instructions::VmOpcode::from_usize( - Rv32MemcpyOpcode::MEMCPY_LOOP.global_opcode().as_usize(), - ), + // Set up registers that the memcpy instruction will read from + // destination address + tester.write::<4>( + 1, + if shift == 0 { + A3_REGISTER_PTR + } else { + A1_REGISTER_PTR + }, + dest_offset.to_le_bytes().map(F::from_canonical_u8), + ); + // length + tester.write::<4>( + 1, + A2_REGISTER_PTR, + len.to_le_bytes().map(F::from_canonical_u8), + ); + // source address + tester.write::<4>( + 1, + if shift == 0 { + A4_REGISTER_PTR + } else { + A3_REGISTER_PTR + }, + source_offset.to_le_bytes().map(F::from_canonical_u8), + ); + + // Create instruction for memcpy_iter (uses same opcode as memcpy_loop) + let instruction = Instruction { + opcode: VmOpcode::from_usize(Rv32MemcpyOpcode::MEMCPY_LOOP.global_opcode().as_usize()), a: F::ZERO, b: F::ZERO, c: F::from_canonical_u32(shift), @@ -97,14 +151,23 @@ mod tests { g: F::ZERO, }; - tester.execute(harness, &instruction); - - // Verify the copy operation - for i in 0..len.min(source_data.len() as u32) { - let expected = source_data[i as usize]; - let actual = tester.read(2, dest_offset + i)[0].as_canonical_u8(); - assert_eq!(expected, actual, "Mismatch at offset {}", i); - } + tester.execute(executor, arena, &instruction); + + // Verify the copy operation by reading words + // let dest_words = (len as usize + 3) / 4; // Round up to nearest word + // for word_idx in 0..dest_words { + // let word_data = tester.read::<4>(2, (dest_offset + word_idx as u32 * 4) as usize); + // let word_start = word_idx * 4; + + // for i in 0..4 { + // let byte_idx = word_start + i; + // if byte_idx < len as usize && byte_idx < source_data.len() { + // let expected = source_data[byte_idx]; + // let actual = word_data[i].as_canonical_u32() as u8; + // assert_eq!(expected, actual, "Mismatch at offset {}", byte_idx); + // } + // } + // } } ////////////////////////////////////////////////////////////////////////////////////// @@ -114,38 +177,47 @@ mod tests { // passes all constraints. ////////////////////////////////////////////////////////////////////////////////////// - #[test_case(0, 100)] + #[test_case(0, 1)] #[test_case(1, 100)] #[test_case(2, 100)] #[test_case(3, 100)] - fn rand_memcpy_loop_test(shift: u32, num_ops: usize) { + fn rand_memcpy_iter_test(shift: u32, num_ops: usize) { let mut rng = create_seeded_rng(); let mut tester = VmChipTestBuilder::default(); - let (mut harness, range_checker) = create_test_chip(&tester); + let (mut harness, range_checker, memcpy_loop) = create_harness(&tester); for _ in 0..num_ops { - let source_data: Vec = (0..16).map(|_| rng.gen_range(0..=u8::MAX)).collect(); - let source_offset = rng.gen_range(0..1000); - let dest_offset = rng.gen_range(2000..3000); - let len = rng.gen_range(1..=16); + let source_offset = rng.gen_range(0..250) * 4; // Ensure word alignment + let dest_offset = rng.gen_range(500..750) * 4; // Ensure word alignment + let len: u32 = rng.gen_range(100..=200); + let source_data: Vec = (0..len.div_ceil(4) * 4) + .map(|_| rng.gen_range(0..=u8::MAX)) + .collect(); set_and_execute_memcpy( &mut tester, - &mut harness, - &mut rng, + &mut harness.executor, + &mut harness.arena, shift, &source_data, dest_offset, source_offset, len, ); + tracing::info!( + "source_offset: {}, dest_offset: {}, len: {}", + source_offset, + dest_offset, + len + ); } let tester = tester .build() .load(harness) .load_periphery(range_checker) + .load_periphery(memcpy_loop) .finalize(); tester.simple_test().expect("Verification failed"); } @@ -154,22 +226,24 @@ mod tests { #[test_case(1, 100)] #[test_case(2, 100)] #[test_case(3, 100)] - fn rand_memcpy_loop_test_persistent(shift: u32, num_ops: usize) { + fn rand_memcpy_iter_test_persistent(shift: u32, num_ops: usize) { let mut rng = create_seeded_rng(); let mut tester = VmChipTestBuilder::default_persistent(); - let (mut harness, range_checker) = create_test_chip(&tester); + let (mut harness, range_checker, _iter_air) = create_harness(&tester); for _ in 0..num_ops { - let source_data: Vec = (0..16).map(|_| rng.gen_range(0..=u8::MAX)).collect(); - let dest_offset = rng.gen_range(0..1000); - let source_offset = rng.gen_range(0..1000); - let len = rng.gen_range(1..=16); + let source_offset = rng.gen_range(0..250) * 4; // Ensure word alignment + let dest_offset = rng.gen_range(500..750) * 4; // Ensure word alignment + let len: u32 = rng.gen_range(100..=200); + let source_data: Vec = (0..len.div_ceil(4) * 4) + .map(|_| rng.gen_range(0..=u8::MAX)) + .collect(); set_and_execute_memcpy( &mut tester, - &mut harness, - &mut rng, + &mut harness.executor, + &mut harness.arena, shift, &source_data, dest_offset, @@ -185,383 +259,4 @@ mod tests { .finalize(); tester.simple_test().expect("Verification failed"); } - - ////////////////////////////////////////////////////////////////////////////////////// - // NEGATIVE TESTS - // - // Given a fake trace of a single operation, setup a chip and run the test. We replace - // part of the trace and check that the chip throws the expected error. - ////////////////////////////////////////////////////////////////////////////////////// - - // #[allow(clippy::too_many_arguments)] - // fn run_negative_memcpy_test( - // shift: u32, - // prank_shift: u32, - // source_data: &[u8], - // dest_offset: u32, - // source_offset: u32, - // len: u32, - // prank_dest: Option, - // prank_source: Option, - // prank_len: Option, - // interaction_error: bool, - // ) { - // let mut rng = create_seeded_rng(); - // let mut tester: VmChipTestBuilder = VmChipTestBuilder::default(); - // let (mut chip, range_checker) = create_test_chip(&tester); - - // set_and_execute_memcpy( - // &mut tester, - // &mut chip, - // &mut rng, - // shift, - // source_data, - // dest_offset, - // source_offset, - // len, - // ); - - // let adapter_width = BaseAir::::width(&chip.air); - // let modify_trace = |trace: &mut DenseMatrix| { - // let mut values = trace.row_slice(0).to_vec(); - // let cols: &mut MemcpyIterCols = values.split_at_mut(adapter_width).1.borrow_mut(); - // cols.shift = [F::from_canonical_u32(prank_shift), F::ZERO]; - // if let Some(prank_dest) = prank_dest { - // cols.dest = F::from_canonical_u32(prank_dest); - // } - // if let Some(prank_source) = prank_source { - // cols.source = F::from_canonical_u32(prank_source); - // } - // if let Some(prank_len) = prank_len { - // cols.len = [F::from_canonical_u32(prank_len), F::ZERO]; - // } - // *trace = RowMajorMatrix::new(values, trace.width()); - // }; - - // disable_debug_builder(); - // let tester = tester - // .build() - // .load_and_prank_trace(chip, modify_trace) - // .load_periphery(range_checker) - // .finalize(); - - // if interaction_error { - // tester - // .simple_test() - // .expect_err("Expected verification to fail"); - // } else { - // tester - // .simple_test() - // .expect_err("Expected verification to fail"); - // } - // } - - // #[test] - // fn memcpy_wrong_shift_negative_test() { - // let source_data = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16]; - // run_negative_memcpy_test( - // 0, // original shift - // 1, // prank shift - // &source_data, - // 100, // dest_offset - // 200, // source_offset - // 8, // len - // None, - // None, - // None, - // true, - // ); - // } - - // #[test] - // fn memcpy_wrong_dest_negative_test() { - // let source_data = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16]; - // run_negative_memcpy_test( - // 0, // shift - // 0, // prank shift (same) - // &source_data, - // 100, // dest_offset - // 200, // source_offset - // 8, // len - // Some(150), // prank dest - // None, - // None, - // true, - // ); - // } - - // #[test] - // fn memcpy_wrong_source_negative_test() { - // let source_data = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16]; - // run_negative_memcpy_test( - // 0, // shift - // 0, // prank shift (same) - // &source_data, - // 100, // dest_offset - // 200, // source_offset - // 8, // len - // None, - // Some(250), // prank source - // None, - // true, - // ); - // } - - // #[test] - // fn memcpy_wrong_len_negative_test() { - // let source_data = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16]; - // run_negative_memcpy_test( - // 0, // shift - // 0, // prank shift (same) - // &source_data, - // 100, // dest_offset - // 200, // source_offset - // 8, // len - // None, - // None, - // Some(12), // prank len - // true, - // ); - // } - - // ////////////////////////////////////////////////////////////////////////////////////// - // // SANITY TESTS - // // - // // Ensure that memcpy operations produce the correct results. - // ////////////////////////////////////////////////////////////////////////////////////// - - // #[test] - // fn memcpy_shift_0_sanity_test() { - // let mut rng = create_seeded_rng(); - // let mut tester = VmChipTestBuilder::default(); - // let (mut harness, range_checker) = create_test_chip(&tester); - - // let source_data = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16]; - - // set_and_execute_memcpy( - // &mut tester, - // &mut harness, - // &mut rng, - // 0, // shift - // &source_data, - // 100, // dest_offset - // 200, // source_offset - // 8, // len - // ); - - // // Verify the copy operation - // for i in 0..8 { - // let expected = source_data[i]; - // let actual = tester.read(2, 100 + i)[0].as_canonical_u8(); - // assert_eq!(expected, actual, "Mismatch at offset {}", i); - // } - - // let tester = tester - // .build() - // .load(harness) - // .load_periphery(range_checker) - // .finalize(); - // tester.simple_test().expect("Verification failed"); - // } - - // #[test] - // fn memcpy_shift_1_sanity_test() { - // let mut rng = create_seeded_rng(); - // let mut tester = VmChipTestBuilder::default(); - // let (mut harness, range_checker) = create_test_chip(&tester); - - // let source_data = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16]; - - // set_and_execute_memcpy( - // &mut tester, - // &mut harness, - // &mut rng, - // 1, // shift - // &source_data, - // 100, // dest_offset - // 200, // source_offset - // 8, // len - // ); - - // // Verify the copy operation with shift=1 - // for i in 0..8 { - // let expected = source_data[i]; - // let actual = tester.read(2, 100 + i)[0].as_canonical_u8(); - // assert_eq!(expected, actual, "Mismatch at offset {}", i); - // } - - // let tester = tester - // .build() - // .load(harness) - // .load_periphery(range_checker) - // .finalize(); - // tester.simple_test().expect("Verification failed"); - // } - - // #[test] - // fn memcpy_shift_2_sanity_test() { - // let mut rng = create_seeded_rng(); - // let mut tester = VmChipTestBuilder::default(); - // let (mut harness, range_checker) = create_test_chip(&tester); - - // let source_data = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16]; - - // set_and_execute_memcpy( - // &mut tester, - // &mut harness, - // &mut rng, - // 2, // shift - // &source_data, - // 100, // dest_offset - // 200, // source_offset - // 8, // len - // ); - - // // Verify the copy operation with shift=2 - // for i in 0..8 { - // let expected = source_data[i]; - // let actual = tester.read(2, 100 + i)[0].as_canonical_u8(); - // assert_eq!(expected, actual, "Mismatch at offset {}", i); - // } - - // let tester = tester - // .build() - // .load(harness) - // .load_periphery(range_checker) - // .finalize(); - // tester.simple_test().expect("Verification failed"); - // } - - // #[test] - // fn memcpy_shift_3_sanity_test() { - // let mut rng = create_seeded_rng(); - // let mut tester = VmChipTestBuilder::default(); - // let (mut harness, range_checker) = create_test_chip(&tester); - - // let source_data = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16]; - - // set_and_execute_memcpy( - // &mut tester, - // &mut harness, - // &mut rng, - // 3, // shift - // &source_data, - // 100, // dest_offset - // 200, // source_offset - // 8, // len - // ); - - // // Verify the copy operation with shift=3 - // for i in 0..8 { - // let expected = source_data[i]; - // let actual = tester.read(2, 100 + i)[0].as_canonical_u8(); - // assert_eq!(expected, actual, "Mismatch at offset {}", i); - // } - - // let tester = tester - // .build() - // .load(harness) - // .load_periphery(range_checker) - // .finalize(); - // tester.simple_test().expect("Verification failed"); - // } - - // ////////////////////////////////////////////////////////////////////////////////////// - // // EDGE CASE TESTS - // // - // // Test edge cases and boundary conditions. - // ////////////////////////////////////////////////////////////////////////////////////// - - // #[test] - // fn memcpy_zero_length_test() { - // let mut rng = create_seeded_rng(); - // let mut tester = VmChipTestBuilder::default(); - // let (mut harness, range_checker) = create_test_chip(&tester); - - // let source_data = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16]; - - // set_and_execute_memcpy( - // &mut tester, - // &mut harness, - // &mut rng, - // 0, // shift - // &source_data, - // 100, // dest_offset - // 200, // source_offset - // 0, // zero length - // ); - - // let tester = tester - // .build() - // .load(harness) - // .load_periphery(range_checker) - // .finalize(); - // tester.simple_test().expect("Verification failed"); - // } - - // #[test] - // fn memcpy_max_length_test() { - // let mut rng = create_seeded_rng(); - // let mut tester = VmChipTestBuilder::default(); - // let (mut harness, range_checker) = create_test_chip(&tester); - - // let source_data = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16]; - - // set_and_execute_memcpy( - // &mut tester, - // &mut harness, - // &mut rng, - // 0, // shift - // &source_data, - // 100, // dest_offset - // 200, // source_offset - // 16, // max length - // ); - - // // Verify the copy operation - // for i in 0..16 { - // let expected = source_data[i]; - // let actual = tester.read(2, 100 + i)[0].as_canonical_u8(); - // assert_eq!(expected, actual, "Mismatch at offset {}", i); - // } - - // let tester = tester - // .build() - // .load(harness) - // .load_periphery(range_checker) - // .finalize(); - // tester.simple_test().expect("Verification failed"); - // } - - // #[test] - // fn memcpy_overlapping_regions_test() { - // let mut rng = create_seeded_rng(); - // let mut tester = VmChipTestBuilder::default(); - // let (mut harness, range_checker) = create_test_chip(&tester); - - // let source_data = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16]; - - // // Write initial data to destination - // for (i, &byte) in source_data.iter().enumerate() { - // tester.write(2, 100 + i as u32, [F::from_canonical_u8(byte)]); - // } - - // set_and_execute_memcpy( - // &mut tester, - // &mut harness, - // &mut rng, - // 0, // shift - // &source_data, - // 102, // dest_offset (overlapping with source) - // 100, // source_offset - // 8, // len - // ); - - // let tester = tester - // .build() - // .load(harness) - // .load_periphery(range_checker) - // .finalize(); - // tester.simple_test().expect("Verification failed"); - // } -} +} \ No newline at end of file From b37f81faea5ca6b107761b84a54947190046b8a7 Mon Sep 17 00:00:00 2001 From: Peyman Jabbarzade Date: Tue, 26 Aug 2025 12:55:10 -0400 Subject: [PATCH 10/14] only increase timestamp on actual read --- Cargo.lock | 1 + extensions/memcpy/circuit/Cargo.toml | 1 + extensions/memcpy/circuit/src/iteration.rs | 114 +++++++++++++++++---- extensions/memcpy/circuit/src/loops.rs | 14 +++ extensions/memcpy/tests/src/lib.rs | 32 +++--- 5 files changed, 126 insertions(+), 36 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index ad4ec19151..200387d1c3 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -5765,6 +5765,7 @@ dependencies = [ "openvm-stark-backend", "serde", "strum", + "tracing", ] [[package]] diff --git a/extensions/memcpy/circuit/Cargo.toml b/extensions/memcpy/circuit/Cargo.toml index 5bb1a9d2c0..f9ddef80f5 100644 --- a/extensions/memcpy/circuit/Cargo.toml +++ b/extensions/memcpy/circuit/Cargo.toml @@ -22,3 +22,4 @@ derive-new.workspace = true derive_more = { workspace = true, features = ["from"] } serde.workspace = true strum = { workspace = true } +tracing.workspace = true diff --git a/extensions/memcpy/circuit/src/iteration.rs b/extensions/memcpy/circuit/src/iteration.rs index 7b2b71373e..64f9ec28a2 100644 --- a/extensions/memcpy/circuit/src/iteration.rs +++ b/extensions/memcpy/circuit/src/iteration.rs @@ -61,7 +61,7 @@ pub struct MemcpyIterCols { pub shift: [T; 2], pub is_valid: T, pub is_valid_not_start: T, - pub is_shift_zero: T, + pub is_shift_non_zero: T, // -1 for the first iteration, 1 for the last iteration, 0 for the middle iterations pub is_boundary: T, pub data_1: [T; MEMCPY_LOOP_NUM_LIMBS], @@ -99,14 +99,14 @@ impl Air for MemcpyIterAir { let local: &MemcpyIterCols = (*local).borrow(); let timestamp: AB::Var = local.timestamp; - let mut timestamp_delta: usize = 0; - let mut timestamp_pp = || { - timestamp_delta += 1; - timestamp + AB::Expr::from_canonical_usize(timestamp_delta - 1) + let mut timestamp_delta: AB::Expr = AB::Expr::ZERO; + let mut timestamp_pp = |timestamp_increase_value: AB::Expr| { + timestamp_delta += timestamp_increase_value.clone(); + timestamp + timestamp_delta.clone() - timestamp_increase_value.clone() }; let shift = local.shift[0] * AB::Expr::TWO + local.shift[1]; - let is_shift_non_zero = not::(local.is_shift_zero); + let is_shift_zero = not::(local.is_shift_non_zero); let is_shift_one = and::(local.shift[0], not::(local.shift[1])); let is_shift_two = and::(not::(local.shift[0]), local.shift[1]); let is_shift_three = and::(local.shift[0], local.shift[1]); @@ -141,7 +141,7 @@ impl Air for MemcpyIterAir { .iter() .map(|(prev_data, next_data)| { array::from_fn::<_, MEMCPY_LOOP_NUM_LIMBS, _>(|i| { - local.is_shift_zero.clone() * (next_data[i]) + is_shift_zero.clone() * (next_data[i]) + is_shift_one.clone() * (if i < 3 { next_data[i + 1] @@ -167,7 +167,7 @@ impl Air for MemcpyIterAir { builder.assert_bool(local.is_valid); local.shift.iter().for_each(|x| builder.assert_bool(*x)); builder.assert_bool(local.is_valid_not_start); - builder.assert_bool(local.is_shift_zero); + builder.assert_bool(local.is_shift_non_zero); // is_boundary is either -1, 0 or 1 builder.assert_tern(local.is_boundary + AB::Expr::ONE); @@ -178,7 +178,7 @@ impl Air for MemcpyIterAir { ); // is_shift_non_zero is correct - builder.assert_eq(local.is_shift_zero, not::(or::(local.shift[0], local.shift[1]))); + builder.assert_eq(local.is_shift_non_zero, or::(local.shift[0], local.shift[1])); // if !is_valid, then is_boundary = 0, shift = 0 (we will use this assumption later) let mut is_not_valid_when = builder.when(not::(local.is_valid)); @@ -205,7 +205,7 @@ impl Air for MemcpyIterAir { // since is_shift_non_zero degree is 2, we need to keep the degree of the condition to 1 builder .when(not::(prev.is_valid_not_start) - not::(prev.is_valid)) - .assert_eq(local.timestamp, prev.timestamp + is_shift_non_zero.clone()); + .assert_eq(local.timestamp, prev.timestamp + local.is_shift_non_zero); // if prev.is_valid_not_start and local.is_valid_not_start, then timestamp=prev_timestamp+8 // prev.is_valid_not_start is the opposite of previous condition @@ -247,7 +247,7 @@ impl Air for MemcpyIterAir { .enumerate() .for_each(|(idx, (data, read_aux))| { let is_valid_read = if idx == 3 { - or::(is_shift_non_zero.clone(), local.is_valid_not_start) + or::(local.is_shift_non_zero, local.is_valid_not_start) } else { local.is_valid_not_start.into() }; @@ -259,10 +259,10 @@ impl Air for MemcpyIterAir { local.source - AB::Expr::from_canonical_usize(16 - idx * 4), ), *data, - timestamp_pp(), + timestamp_pp(is_valid_read.clone()), read_aux, ) - .eval(builder, is_valid_read); + .eval(builder, is_valid_read.clone()); }); // Write final data to registers @@ -274,7 +274,7 @@ impl Air for MemcpyIterAir { local.dest - AB::Expr::from_canonical_usize(16 - idx * 4), ), data.clone(), - timestamp_pp(), + timestamp_pp(local.is_valid_not_start.into()), &local.write_aux[idx], ) .eval(builder, local.is_valid_not_start); @@ -294,10 +294,10 @@ impl Air for MemcpyIterAir { ), ]; self.range_bus - .push(local.len[0].clone(), len_bits_limit[0].clone(), true) + .push(local.len[0], len_bits_limit[0].clone(), true) .eval(builder, local.is_valid); self.range_bus - .push(local.len[1].clone(), len_bits_limit[1].clone(), true) + .push(local.len[1], len_bits_limit[1].clone(), true) .eval(builder, local.is_valid); } } @@ -560,6 +560,9 @@ impl TraceFiller for MemcpyIterFiller { let mut sizes = Vec::with_capacity(rows_used >> 1); let mut chunks = Vec::with_capacity(rows_used >> 1); + let mut num_loops: usize = 0; + let mut num_iters: usize = 0; + while !trace.is_empty() { let record: &MemcpyIterRecordHeader = unsafe { get_record_from_slice(&mut trace, ()) }; let num_rows = ((record.len - record.shift as u32) >> 4) as usize + 1; @@ -567,12 +570,16 @@ impl TraceFiller for MemcpyIterFiller { sizes.push(num_rows); chunks.push(chunk); trace = rest; + num_loops += 1; + num_iters += num_rows; } - + tracing::info!("num_loops: {:?}, num_iters: {:?}", num_loops, num_iters); + chunks .par_iter_mut() .zip(sizes.par_iter()) - .for_each(|(chunk, &num_rows)| { + .enumerate() + .for_each(|(row_idx, (chunk, &num_rows))| { let record: MemcpyIterRecordMut = unsafe { get_record_from_slice( chunk, @@ -694,7 +701,7 @@ impl TraceFiller for MemcpyIterFiller { } else { F::ZERO }; - cols.is_shift_zero = F::from_canonical_u8((record.inner.shift == 0) as u8); + cols.is_shift_non_zero = F::from_canonical_u8((record.inner.shift != 0) as u8); cols.is_valid_not_start = F::from_canonical_u8(1 - is_start as u8); cols.is_valid = F::ONE; cols.shift = [record.inner.shift & 1, record.inner.shift >> 1] @@ -707,8 +714,77 @@ impl TraceFiller for MemcpyIterFiller { dest -= 16; source -= 16; len += 16; + + if row_idx == 0 && is_start { + tracing::info!("first_roooooow, timestamp: {:?}, dest: {:?}, source: {:?}, len_0: {:?}, len_1: {:?}, shift: {:?}, is_valid: {:?}, is_valid_not_start: {:?}, is_shift_non_zero: {:?}, is_boundary: {:?}, data_1: {:?}, data_2: {:?}, data_3: {:?}, data_4: {:?}, read_aux: {:?}, read_aux_lt: {:?}", + cols.timestamp.as_canonical_u32(), + cols.dest.as_canonical_u32(), + cols.source.as_canonical_u32(), + cols.len[0].as_canonical_u32(), + cols.len[1].as_canonical_u32(), + cols.shift[1].as_canonical_u32() * 2 + cols.shift[0].as_canonical_u32(), + cols.is_valid.as_canonical_u32(), + cols.is_valid_not_start.as_canonical_u32(), + cols.is_shift_non_zero.as_canonical_u32(), + cols.is_boundary.as_canonical_u32(), + cols.data_1.map(|x| x.as_canonical_u32()).to_vec(), + cols.data_2.map(|x| x.as_canonical_u32()).to_vec(), + cols.data_3.map(|x| x.as_canonical_u32()).to_vec(), + cols.data_4.map(|x| x.as_canonical_u32()).to_vec(), + cols.read_aux.map(|x| x.get_base().prev_timestamp.as_canonical_u32()).to_vec(), + cols.read_aux.map(|x| x.get_base().timestamp_lt_aux.lower_decomp.iter().map(|x| x.as_canonical_u32()).collect::>()).to_vec()); + } }); }); + chunks.iter().enumerate().map(|(row_idx, chunk)| { + chunk.chunks_exact(width) + .enumerate() + .for_each(|(idx, row)| { + let cols: &MemcpyIterCols = row.borrow(); + let is_valid_not_start = cols.is_valid_not_start.as_canonical_u32() != 0; + let is_shift_non_zero = cols.is_shift_non_zero.as_canonical_u32() != 0; + let mut bad_col = false; + cols.read_aux.iter().enumerate().for_each(|(idx, aux)| { + if is_valid_not_start || (is_shift_non_zero && idx == 3) { + let prev_t = aux.get_base().prev_timestamp.as_canonical_u32(); + let curr_t = cols.timestamp.as_canonical_u32(); + let ts_lt = aux.get_base().timestamp_lt_aux.lower_decomp.iter() + .enumerate() + .fold(F::ZERO, |acc, (i, &val)| { + acc + val * F::from_canonical_usize(1 << (i * 17)) + }).as_canonical_u32(); + if curr_t + idx as u32 != ts_lt + prev_t + 1 { + bad_col = true; + } + } + }); + if bad_col { + tracing::info!("row_idx: {:?}, idx: {:?}, timestamp: {:?}, dest: {:?}, source: {:?}, len_0: {:?}, len_1: {:?}, shift_0: {:?}, shift_1: {:?}, is_valid: {:?}, is_valid_not_start: {:?}, is_shift_non_zero: {:?}, is_boundary: {:?}, data_1: {:?}, data_2: {:?}, data_3: {:?}, data_4: {:?}, read_aux: {:?}, read_aux_lt: {:?}", + row_idx, + idx, + cols.timestamp.as_canonical_u32(), + cols.dest.as_canonical_u32(), + cols.source.as_canonical_u32(), + cols.len[0].as_canonical_u32(), + cols.len[1].as_canonical_u32(), + cols.shift[0].as_canonical_u32(), + cols.shift[1].as_canonical_u32(), + cols.is_valid.as_canonical_u32(), + cols.is_valid_not_start.as_canonical_u32(), + cols.is_shift_non_zero.as_canonical_u32(), + cols.is_boundary.as_canonical_u32(), + cols.data_1.map(|x| x.as_canonical_u32()).to_vec(), + cols.data_2.map(|x| x.as_canonical_u32()).to_vec(), + cols.data_3.map(|x| x.as_canonical_u32()).to_vec(), + cols.data_4.map(|x| x.as_canonical_u32()).to_vec(), + cols.read_aux.map(|x| x.get_base().prev_timestamp.as_canonical_u32()).to_vec(), + cols.read_aux.map(|x| x.get_base().timestamp_lt_aux.lower_decomp.iter().map(|x| x.as_canonical_u32()).collect::>()).to_vec()); + // cols.write_aux.map(|x| x.get_base().prev_timestamp.as_canonical_u32()).to_vec(), + // cols.write_aux.map(|x| x.prev_data.map(|x| x.as_canonical_u32()).to_vec()).to_vec()); + } + }); + }); + // assert!(false); } } diff --git a/extensions/memcpy/circuit/src/loops.rs b/extensions/memcpy/circuit/src/loops.rs index 6751899153..d08fa3aa22 100644 --- a/extensions/memcpy/circuit/src/loops.rs +++ b/extensions/memcpy/circuit/src/loops.rs @@ -432,6 +432,20 @@ impl MemcpyLoopChip { }); cols.source_minus_twelve_carry = F::from_bool((record.source & 0x0ffff) < 12); cols.to_source_minus_twelve_carry = F::from_bool((to_source & 0x0ffff) < 12); + + // tracing::info!("timestamp: {:?}, pc: {:?}, dest: {:?}, source: {:?}, len: {:?}, shift: {:?}, is_valid: {:?}, to_timestamp: {:?}, to_dest: {:?}, to_source: {:?}, to_len: {:?}, write_aux: {:?}", + // cols.from_state.timestamp.as_canonical_u32(), + // cols.from_state.pc.as_canonical_u32(), + // u32::from_le_bytes(cols.dest.map(|x| x.as_canonical_u32() as u8)), + // u32::from_le_bytes(cols.source.map(|x| x.as_canonical_u32() as u8)), + // u32::from_le_bytes(cols.len.map(|x| x.as_canonical_u32() as u8)), + // cols.shift[1].as_canonical_u32() * 2 + cols.shift[0].as_canonical_u32(), + // cols.is_valid.as_canonical_u32(), + // cols.to_timestamp.as_canonical_u32(), + // u32::from_le_bytes(cols.to_dest.map(|x| x.as_canonical_u32() as u8)), + // u32::from_le_bytes(cols.to_source.map(|x| x.as_canonical_u32() as u8)), + // cols.to_len.as_canonical_u32(), + // cols.write_aux.map(|x| x.prev_timestamp.as_canonical_u32()).to_vec()); } RowMajorMatrix::new(rows, NUM_MEMCPY_LOOP_COLS) } diff --git a/extensions/memcpy/tests/src/lib.rs b/extensions/memcpy/tests/src/lib.rs index b8434f4831..49d5ae7d3d 100644 --- a/extensions/memcpy/tests/src/lib.rs +++ b/extensions/memcpy/tests/src/lib.rs @@ -13,7 +13,7 @@ mod tests { SharedVariableRangeCheckerChip, VariableRangeCheckerAir, VariableRangeCheckerBus, VariableRangeCheckerChip, }; - use openvm_instructions::{instruction::Instruction, LocalOpcode, VmOpcode}; + use openvm_instructions::{instruction::Instruction, riscv::{RV32_MEMORY_AS, RV32_REGISTER_AS}, LocalOpcode, VmOpcode}; use openvm_memcpy_circuit::{ MemcpyBus, MemcpyIterAir, MemcpyIterChip, MemcpyIterExecutor, MemcpyIterFiller, MemcpyLoopAir, MemcpyLoopChip, A1_REGISTER_PTR, A2_REGISTER_PTR, A3_REGISTER_PTR, @@ -108,13 +108,13 @@ mod tests { } } - tester.write(2, (source_offset + word_idx as u32 * 4) as usize, word_data); + tester.write(RV32_MEMORY_AS as usize, (source_offset + word_idx as u32 * 4) as usize, word_data); } // Set up registers that the memcpy instruction will read from // destination address tester.write::<4>( - 1, + RV32_REGISTER_AS as usize, if shift == 0 { A3_REGISTER_PTR } else { @@ -124,13 +124,13 @@ mod tests { ); // length tester.write::<4>( - 1, + RV32_REGISTER_AS as usize, A2_REGISTER_PTR, len.to_le_bytes().map(F::from_canonical_u8), ); // source address tester.write::<4>( - 1, + RV32_REGISTER_AS as usize, if shift == 0 { A4_REGISTER_PTR } else { @@ -177,11 +177,11 @@ mod tests { // passes all constraints. ////////////////////////////////////////////////////////////////////////////////////// - #[test_case(0, 1)] - #[test_case(1, 100)] - #[test_case(2, 100)] - #[test_case(3, 100)] - fn rand_memcpy_iter_test(shift: u32, num_ops: usize) { + #[test_case(0, 1, 20)] + #[test_case(1, 100, 20)] + #[test_case(2, 100, 20)] + #[test_case(3, 100, 20)] + fn rand_memcpy_iter_test(shift: u32, num_ops: usize, len: u32) { let mut rng = create_seeded_rng(); let mut tester = VmChipTestBuilder::default(); @@ -190,7 +190,6 @@ mod tests { for _ in 0..num_ops { let source_offset = rng.gen_range(0..250) * 4; // Ensure word alignment let dest_offset = rng.gen_range(500..750) * 4; // Ensure word alignment - let len: u32 = rng.gen_range(100..=200); let source_data: Vec = (0..len.div_ceil(4) * 4) .map(|_| rng.gen_range(0..=u8::MAX)) .collect(); @@ -222,11 +221,11 @@ mod tests { tester.simple_test().expect("Verification failed"); } - #[test_case(0, 100)] - #[test_case(1, 100)] - #[test_case(2, 100)] - #[test_case(3, 100)] - fn rand_memcpy_iter_test_persistent(shift: u32, num_ops: usize) { + #[test_case(0, 100, 20)] + #[test_case(1, 100, 20)] + #[test_case(2, 100, 20)] + #[test_case(3, 100, 20)] + fn rand_memcpy_iter_test_persistent(shift: u32, num_ops: usize, len: u32) { let mut rng = create_seeded_rng(); let mut tester = VmChipTestBuilder::default_persistent(); @@ -235,7 +234,6 @@ mod tests { for _ in 0..num_ops { let source_offset = rng.gen_range(0..250) * 4; // Ensure word alignment let dest_offset = rng.gen_range(500..750) * 4; // Ensure word alignment - let len: u32 = rng.gen_range(100..=200); let source_data: Vec = (0..len.div_ceil(4) * 4) .map(|_| rng.gen_range(0..=u8::MAX)) .collect(); From 15b7057e5e652b9846e86c66e6eff8d38247181a Mon Sep 17 00:00:00 2001 From: Peyman Jabbarzade Date: Tue, 26 Aug 2025 15:11:42 -0400 Subject: [PATCH 11/14] fix small bugs --- extensions/memcpy/circuit/src/iteration.rs | 204 +++++++++++++-------- 1 file changed, 128 insertions(+), 76 deletions(-) diff --git a/extensions/memcpy/circuit/src/iteration.rs b/extensions/memcpy/circuit/src/iteration.rs index 64f9ec28a2..5f3d62c484 100644 --- a/extensions/memcpy/circuit/src/iteration.rs +++ b/extensions/memcpy/circuit/src/iteration.rs @@ -105,7 +105,7 @@ impl Air for MemcpyIterAir { timestamp + timestamp_delta.clone() - timestamp_increase_value.clone() }; - let shift = local.shift[0] * AB::Expr::TWO + local.shift[1]; + let shift = local.shift[1] * AB::Expr::TWO + local.shift[0]; let is_shift_zero = not::(local.is_shift_non_zero); let is_shift_one = and::(local.shift[0], not::(local.shift[1])); let is_shift_two = and::(not::(local.shift[0]), local.shift[1]); @@ -236,16 +236,16 @@ impl Air for MemcpyIterAir { // Read data from memory let read_data = [ - (local.data_1, local.read_aux[0]), - (local.data_2, local.read_aux[1]), - (local.data_3, local.read_aux[2]), - (local.data_4, local.read_aux[3]), + local.data_1, + local.data_2, + local.data_3, + local.data_4, ]; read_data .iter() .enumerate() - .for_each(|(idx, (data, read_aux))| { + .for_each(|(idx, data)| { let is_valid_read = if idx == 3 { or::(local.is_shift_non_zero, local.is_valid_not_start) } else { @@ -260,7 +260,7 @@ impl Air for MemcpyIterAir { ), *data, timestamp_pp(is_valid_read.clone()), - read_aux, + &local.read_aux[idx], ) .eval(builder, is_valid_read.clone()); }); @@ -468,7 +468,7 @@ where } else if i > 0 { record.var[idx].data[i - 1][j - (4 - shift as usize)] } else { - record.var[idx - 1].data[i][j - (4 - shift as usize)] + record.var[idx - 1].data[3][j - (4 - shift as usize)] } }); write_data @@ -573,7 +573,7 @@ impl TraceFiller for MemcpyIterFiller { num_loops += 1; num_iters += num_rows; } - tracing::info!("num_loops: {:?}, num_iters: {:?}", num_loops, num_iters); + // tracing::info!("num_loops: {:?}, num_iters: {:?}", num_loops, num_iters); chunks .par_iter_mut() @@ -715,75 +715,127 @@ impl TraceFiller for MemcpyIterFiller { source -= 16; len += 16; - if row_idx == 0 && is_start { - tracing::info!("first_roooooow, timestamp: {:?}, dest: {:?}, source: {:?}, len_0: {:?}, len_1: {:?}, shift: {:?}, is_valid: {:?}, is_valid_not_start: {:?}, is_shift_non_zero: {:?}, is_boundary: {:?}, data_1: {:?}, data_2: {:?}, data_3: {:?}, data_4: {:?}, read_aux: {:?}, read_aux_lt: {:?}", - cols.timestamp.as_canonical_u32(), - cols.dest.as_canonical_u32(), - cols.source.as_canonical_u32(), - cols.len[0].as_canonical_u32(), - cols.len[1].as_canonical_u32(), - cols.shift[1].as_canonical_u32() * 2 + cols.shift[0].as_canonical_u32(), - cols.is_valid.as_canonical_u32(), - cols.is_valid_not_start.as_canonical_u32(), - cols.is_shift_non_zero.as_canonical_u32(), - cols.is_boundary.as_canonical_u32(), - cols.data_1.map(|x| x.as_canonical_u32()).to_vec(), - cols.data_2.map(|x| x.as_canonical_u32()).to_vec(), - cols.data_3.map(|x| x.as_canonical_u32()).to_vec(), - cols.data_4.map(|x| x.as_canonical_u32()).to_vec(), - cols.read_aux.map(|x| x.get_base().prev_timestamp.as_canonical_u32()).to_vec(), - cols.read_aux.map(|x| x.get_base().timestamp_lt_aux.lower_decomp.iter().map(|x| x.as_canonical_u32()).collect::>()).to_vec()); - } + // if row_idx == 0 && is_start { + // tracing::info!("first_roooooow, timestamp: {:?}, dest: {:?}, source: {:?}, len_0: {:?}, len_1: {:?}, shift: {:?}, is_valid: {:?}, is_valid_not_start: {:?}, is_shift_non_zero: {:?}, is_boundary: {:?}, data_1: {:?}, data_2: {:?}, data_3: {:?}, data_4: {:?}, read_aux: {:?}, read_aux_lt: {:?}", + // cols.timestamp.as_canonical_u32(), + // cols.dest.as_canonical_u32(), + // cols.source.as_canonical_u32(), + // cols.len[0].as_canonical_u32(), + // cols.len[1].as_canonical_u32(), + // cols.shift[1].as_canonical_u32() * 2 + cols.shift[0].as_canonical_u32(), + // cols.is_valid.as_canonical_u32(), + // cols.is_valid_not_start.as_canonical_u32(), + // cols.is_shift_non_zero.as_canonical_u32(), + // cols.is_boundary.as_canonical_u32(), + // cols.data_1.map(|x| x.as_canonical_u32()).to_vec(), + // cols.data_2.map(|x| x.as_canonical_u32()).to_vec(), + // cols.data_3.map(|x| x.as_canonical_u32()).to_vec(), + // cols.data_4.map(|x| x.as_canonical_u32()).to_vec(), + // cols.read_aux.map(|x| x.get_base().prev_timestamp.as_canonical_u32()).to_vec(), + // cols.read_aux.map(|x| x.get_base().timestamp_lt_aux.lower_decomp.iter().map(|x| x.as_canonical_u32()).collect::>()).to_vec()); + // } }); }); - chunks.iter().enumerate().map(|(row_idx, chunk)| { - chunk.chunks_exact(width) - .enumerate() - .for_each(|(idx, row)| { - let cols: &MemcpyIterCols = row.borrow(); - let is_valid_not_start = cols.is_valid_not_start.as_canonical_u32() != 0; - let is_shift_non_zero = cols.is_shift_non_zero.as_canonical_u32() != 0; - let mut bad_col = false; - cols.read_aux.iter().enumerate().for_each(|(idx, aux)| { - if is_valid_not_start || (is_shift_non_zero && idx == 3) { - let prev_t = aux.get_base().prev_timestamp.as_canonical_u32(); - let curr_t = cols.timestamp.as_canonical_u32(); - let ts_lt = aux.get_base().timestamp_lt_aux.lower_decomp.iter() - .enumerate() - .fold(F::ZERO, |acc, (i, &val)| { - acc + val * F::from_canonical_usize(1 << (i * 17)) - }).as_canonical_u32(); - if curr_t + idx as u32 != ts_lt + prev_t + 1 { - bad_col = true; - } - } - }); - if bad_col { - tracing::info!("row_idx: {:?}, idx: {:?}, timestamp: {:?}, dest: {:?}, source: {:?}, len_0: {:?}, len_1: {:?}, shift_0: {:?}, shift_1: {:?}, is_valid: {:?}, is_valid_not_start: {:?}, is_shift_non_zero: {:?}, is_boundary: {:?}, data_1: {:?}, data_2: {:?}, data_3: {:?}, data_4: {:?}, read_aux: {:?}, read_aux_lt: {:?}", - row_idx, - idx, - cols.timestamp.as_canonical_u32(), - cols.dest.as_canonical_u32(), - cols.source.as_canonical_u32(), - cols.len[0].as_canonical_u32(), - cols.len[1].as_canonical_u32(), - cols.shift[0].as_canonical_u32(), - cols.shift[1].as_canonical_u32(), - cols.is_valid.as_canonical_u32(), - cols.is_valid_not_start.as_canonical_u32(), - cols.is_shift_non_zero.as_canonical_u32(), - cols.is_boundary.as_canonical_u32(), - cols.data_1.map(|x| x.as_canonical_u32()).to_vec(), - cols.data_2.map(|x| x.as_canonical_u32()).to_vec(), - cols.data_3.map(|x| x.as_canonical_u32()).to_vec(), - cols.data_4.map(|x| x.as_canonical_u32()).to_vec(), - cols.read_aux.map(|x| x.get_base().prev_timestamp.as_canonical_u32()).to_vec(), - cols.read_aux.map(|x| x.get_base().timestamp_lt_aux.lower_decomp.iter().map(|x| x.as_canonical_u32()).collect::>()).to_vec()); - // cols.write_aux.map(|x| x.get_base().prev_timestamp.as_canonical_u32()).to_vec(), - // cols.write_aux.map(|x| x.prev_data.map(|x| x.as_canonical_u32()).to_vec()).to_vec()); - } - }); - }); + + // chunks.iter().enumerate().for_each(|(row_idx, chunk)| { + // let mut prv_data = [0; 4]; + // tracing::info!("row_idx: {:?}", row_idx); + + // chunk.chunks_exact(width) + // .enumerate() + // .for_each(|(idx, row)| { + // let cols: &MemcpyIterCols = row.borrow(); + // let is_valid_not_start = cols.is_valid_not_start.as_canonical_u32() != 0; + // let is_shift_non_zero = cols.is_shift_non_zero.as_canonical_u32() != 0; + // let source = cols.source.as_canonical_u32(); + // let dest = cols.dest.as_canonical_u32(); + // let mut bad_col = false; + // tracing::info!("source: {:?}, dest: {:?}", source, dest); + // cols.read_aux.iter().enumerate().for_each(|(idx, aux)| { + // if is_valid_not_start || (is_shift_non_zero && idx == 3) { + // let prev_t = aux.get_base().prev_timestamp.as_canonical_u32(); + // let curr_t = cols.timestamp.as_canonical_u32(); + // let ts_lt = aux.get_base().timestamp_lt_aux.lower_decomp.iter() + // .enumerate() + // .fold(F::ZERO, |acc, (i, &val)| { + // acc + val * F::from_canonical_usize(1 << (i * 17)) + // }).as_canonical_u32(); + // if curr_t + idx as u32 != ts_lt + prev_t + 1 { + // bad_col = true; + // } + // } + // if dest + 4 * idx as u32 == 2097216 || dest - 4 * (idx + 1) as u32 == 2097216 || dest + 4 * idx as u32 == 2097280 || dest - 4 * (idx + 1) as u32 == 2097280 { + // bad_col = true; + // } + // }); + // if bad_col { + // let write_data_pairs = [ + // (prv_data, cols.data_1.map(|x| x.as_canonical_u32())), + // (cols.data_1.map(|x| x.as_canonical_u32()), cols.data_2.map(|x| x.as_canonical_u32())), + // (cols.data_2.map(|x| x.as_canonical_u32()), cols.data_3.map(|x| x.as_canonical_u32())), + // (cols.data_3.map(|x| x.as_canonical_u32()), cols.data_4.map(|x| x.as_canonical_u32())), + // ]; + + // let shift = cols.shift[1].as_canonical_u32() * 2 + cols.shift[0].as_canonical_u32(); + // let write_data = write_data_pairs + // .iter() + // .map(|(prev_data, next_data)| { + // array::from_fn::<_, MEMCPY_LOOP_NUM_LIMBS, _>(|i| { + // (shift == 0) as u32 * (next_data[i]) + // + (shift == 1) as u32 + // * (if i < 3 { + // next_data[i + 1] + // } else { + // prev_data[i - 3] + // }) + // + (shift == 2) as u32 + // * (if i < 2 { + // next_data[i + 2] + // } else { + // prev_data[i - 2] + // }) + // + (shift == 3) as u32 + // * (if i < 1 { + // next_data[i + 3] + // } else { + // prev_data[i - 1] + // }) + // }) + // }) + // .collect::>(); + + + + + // tracing::info!("row_idx: {:?}, idx: {:?}, timestamp: {:?}, dest: {:?}, source: {:?}, len_0: {:?}, len_1: {:?}, shift_0: {:?}, shift_1: {:?}, is_valid: {:?}, is_valid_not_start: {:?}, is_shift_non_zero: {:?}, is_boundary: {:?}, write_data: {:?}, prv_data: {:?}, data_1: {:?}, data_2: {:?}, data_3: {:?}, data_4: {:?}, read_aux: {:?}, read_aux_lt: {:?}, write_aux: {:?}, write_aux_lt: {:?}, write_aux_prev_data: {:?}", + // row_idx, + // idx, + // cols.timestamp.as_canonical_u32(), + // cols.dest.as_canonical_u32(), + // cols.source.as_canonical_u32(), + // cols.len[0].as_canonical_u32(), + // cols.len[1].as_canonical_u32(), + // cols.shift[0].as_canonical_u32(), + // cols.shift[1].as_canonical_u32(), + // cols.is_valid.as_canonical_u32(), + // cols.is_valid_not_start.as_canonical_u32(), + // cols.is_shift_non_zero.as_canonical_u32(), + // cols.is_boundary.as_canonical_u32(), + // write_data, + // prv_data, + // cols.data_1.map(|x| x.as_canonical_u32()).to_vec(), + // cols.data_2.map(|x| x.as_canonical_u32()).to_vec(), + // cols.data_3.map(|x| x.as_canonical_u32()).to_vec(), + // cols.data_4.map(|x| x.as_canonical_u32()).to_vec(), + // cols.read_aux.map(|x| x.get_base().prev_timestamp.as_canonical_u32()).to_vec(), + // cols.read_aux.map(|x| x.get_base().timestamp_lt_aux.lower_decomp.iter().map(|x| x.as_canonical_u32()).collect::>()).to_vec(), + // cols.write_aux.map(|x| x.get_base().prev_timestamp.as_canonical_u32()).to_vec(), + // cols.write_aux.map(|x| x.get_base().timestamp_lt_aux.lower_decomp.iter().map(|x| x.as_canonical_u32()).collect::>()).to_vec(), + // cols.write_aux.map(|x| x.prev_data.map(|x| x.as_canonical_u32()).to_vec()).to_vec()); + // } + // prv_data = cols.data_4.map(|x| x.as_canonical_u32()); + // }); + // }); // assert!(false); } } From 7d256b301266a41579ea704a235845b293070d93 Mon Sep 17 00:00:00 2001 From: Peyman Jabbarzade Date: Tue, 26 Aug 2025 18:15:55 -0400 Subject: [PATCH 12/14] fix: reduce memcpyIterAir degree from 4 to 3 --- Cargo.lock | 4 +- extensions/bigint/circuit/Cargo.toml | 1 + .../bigint/circuit/src/extension/mod.rs | 6 ++ extensions/bigint/circuit/src/lib.rs | 4 + extensions/keccak256/circuit/Cargo.toml | 1 + .../keccak256/circuit/src/extension/mod.rs | 9 ++ extensions/memcpy/circuit/Cargo.toml | 1 - extensions/memcpy/circuit/src/iteration.rs | 61 +++++++----- extensions/memcpy/circuit/src/lib.rs | 95 +++++++++++++++++++ extensions/rv32im/circuit/Cargo.toml | 1 + extensions/rv32im/circuit/src/lib.rs | 7 ++ 11 files changed, 163 insertions(+), 27 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 200387d1c3..e539f1dfae 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -5323,6 +5323,7 @@ dependencies = [ "openvm-cuda-builder", "openvm-cuda-common", "openvm-instructions", + "openvm-memcpy-circuit", "openvm-rv32-adapters", "openvm-rv32im-circuit", "openvm-rv32im-transpiler", @@ -5711,6 +5712,7 @@ dependencies = [ "openvm-cuda-common", "openvm-instructions", "openvm-keccak256-transpiler", + "openvm-memcpy-circuit", "openvm-rv32im-circuit", "openvm-stark-backend", "openvm-stark-sdk", @@ -5760,7 +5762,6 @@ dependencies = [ "openvm-circuit-primitives-derive", "openvm-instructions", "openvm-memcpy-transpiler", - "openvm-rv32im-circuit", "openvm-rv32im-transpiler", "openvm-stark-backend", "serde", @@ -6108,6 +6109,7 @@ dependencies = [ "openvm-cuda-builder", "openvm-cuda-common", "openvm-instructions", + "openvm-memcpy-circuit", "openvm-rv32im-transpiler", "openvm-stark-backend", "openvm-stark-sdk", diff --git a/extensions/bigint/circuit/Cargo.toml b/extensions/bigint/circuit/Cargo.toml index c431fae6e5..1398a1fa10 100644 --- a/extensions/bigint/circuit/Cargo.toml +++ b/extensions/bigint/circuit/Cargo.toml @@ -21,6 +21,7 @@ openvm-rv32im-circuit = { workspace = true } openvm-rv32-adapters = { workspace = true } openvm-bigint-transpiler = { workspace = true } openvm-rv32im-transpiler = { workspace = true } +openvm-memcpy-circuit = { workspace = true } derive-new.workspace = true derive_more = { workspace = true, features = ["from"] } diff --git a/extensions/bigint/circuit/src/extension/mod.rs b/extensions/bigint/circuit/src/extension/mod.rs index 1725a4860d..cff327da98 100644 --- a/extensions/bigint/circuit/src/extension/mod.rs +++ b/extensions/bigint/circuit/src/extension/mod.rs @@ -25,6 +25,7 @@ use openvm_circuit_primitives::{ }, }; use openvm_instructions::{program::DEFAULT_PC_STEP, LocalOpcode}; +use openvm_memcpy_circuit::MemcpyCpuProverExt; use openvm_rv32im_circuit::Rv32ImCpuProverExt; use openvm_stark_backend::{ config::{StarkGenericConfig, Val}, @@ -373,6 +374,11 @@ where &config.bigint, inventory, )?; + VmProverExtension::::extend_prover( + &MemcpyCpuProverExt, + &config.memcpy, + inventory, + )?; Ok(chip_complex) } } diff --git a/extensions/bigint/circuit/src/lib.rs b/extensions/bigint/circuit/src/lib.rs index 2c2b24b655..44f1ac66bf 100644 --- a/extensions/bigint/circuit/src/lib.rs +++ b/extensions/bigint/circuit/src/lib.rs @@ -7,6 +7,7 @@ use openvm_circuit::{ system::SystemExecutor, }; use openvm_circuit_derive::{PreflightExecutor, VmConfig}; +use openvm_memcpy_circuit::{Memcpy, MemcpyExecutor}; use openvm_rv32_adapters::{ Rv32HeapAdapterAir, Rv32HeapAdapterExecutor, Rv32HeapAdapterFiller, Rv32HeapBranchAdapterAir, Rv32HeapBranchAdapterExecutor, Rv32HeapBranchAdapterFiller, @@ -176,6 +177,8 @@ pub struct Int256Rv32Config { pub io: Rv32Io, #[extension] pub bigint: Int256, + #[extension] + pub memcpy: Memcpy, } // Default implementation uses no init file @@ -189,6 +192,7 @@ impl Default for Int256Rv32Config { rv32m: Rv32M::default(), io: Rv32Io, bigint: Int256::default(), + memcpy: Memcpy, } } } diff --git a/extensions/keccak256/circuit/Cargo.toml b/extensions/keccak256/circuit/Cargo.toml index 24e10deb10..28998f2f59 100644 --- a/extensions/keccak256/circuit/Cargo.toml +++ b/extensions/keccak256/circuit/Cargo.toml @@ -18,6 +18,7 @@ openvm-circuit = { workspace = true } openvm-circuit-derive = { workspace = true } openvm-instructions = { workspace = true } openvm-rv32im-circuit = { workspace = true } +openvm-memcpy-circuit = { workspace = true } openvm-keccak256-transpiler = { workspace = true } p3-keccak-air = { workspace = true } diff --git a/extensions/keccak256/circuit/src/extension/mod.rs b/extensions/keccak256/circuit/src/extension/mod.rs index 9f6e55a540..7a07080f0f 100644 --- a/extensions/keccak256/circuit/src/extension/mod.rs +++ b/extensions/keccak256/circuit/src/extension/mod.rs @@ -20,6 +20,7 @@ use openvm_circuit_primitives::bitwise_op_lookup::{ }; use openvm_instructions::*; use openvm_keccak256_transpiler::Rv32KeccakOpcode; +use openvm_memcpy_circuit::{Memcpy, MemcpyCpuProverExt, MemcpyExecutor}; use openvm_rv32im_circuit::{ Rv32I, Rv32IExecutor, Rv32ImCpuProverExt, Rv32Io, Rv32IoExecutor, Rv32M, Rv32MExecutor, }; @@ -62,6 +63,8 @@ pub struct Keccak256Rv32Config { pub io: Rv32Io, #[extension] pub keccak: Keccak256, + #[extension] + pub memcpy: Memcpy, } impl Default for Keccak256Rv32Config { @@ -72,6 +75,7 @@ impl Default for Keccak256Rv32Config { rv32m: Rv32M::default(), io: Rv32Io, keccak: Keccak256, + memcpy: Memcpy, } } } @@ -111,6 +115,11 @@ where &config.keccak, inventory, )?; + VmProverExtension::::extend_prover( + &MemcpyCpuProverExt, + &config.memcpy, + inventory, + )?; Ok(chip_complex) } } diff --git a/extensions/memcpy/circuit/Cargo.toml b/extensions/memcpy/circuit/Cargo.toml index f9ddef80f5..c5c4034ff8 100644 --- a/extensions/memcpy/circuit/Cargo.toml +++ b/extensions/memcpy/circuit/Cargo.toml @@ -16,7 +16,6 @@ openvm-instructions = { workspace = true } openvm-stark-backend = { workspace = true } openvm-memcpy-transpiler = { path = "../transpiler" } openvm-rv32im-transpiler = { workspace = true } -openvm-rv32im-circuit = { workspace = true } derive-new.workspace = true derive_more = { workspace = true, features = ["from"] } diff --git a/extensions/memcpy/circuit/src/iteration.rs b/extensions/memcpy/circuit/src/iteration.rs index 5f3d62c484..c8872bbe45 100644 --- a/extensions/memcpy/circuit/src/iteration.rs +++ b/extensions/memcpy/circuit/src/iteration.rs @@ -34,7 +34,6 @@ use openvm_instructions::{ LocalOpcode, }; use openvm_memcpy_transpiler::Rv32MemcpyOpcode; -use openvm_rv32im_circuit::adapters::{read_rv32_register, tracing_read, tracing_write}; use openvm_stark_backend::{ interaction::InteractionBuilder, p3_air::{Air, AirBuilder, BaseAir}, @@ -45,8 +44,7 @@ use openvm_stark_backend::{ }; use crate::{ - bus::MemcpyBus, MemcpyLoopChip, A1_REGISTER_PTR, A2_REGISTER_PTR, A3_REGISTER_PTR, - A4_REGISTER_PTR, + bus::MemcpyBus, read_rv32_register, tracing_read, tracing_write, MemcpyLoopChip, A1_REGISTER_PTR, A2_REGISTER_PTR, A3_REGISTER_PTR, A4_REGISTER_PTR }; // Import constants from lib.rs use crate::{MEMCPY_LOOP_LIMB_BITS, MEMCPY_LOOP_NUM_LIMBS}; @@ -58,10 +56,12 @@ pub struct MemcpyIterCols { pub dest: T, pub source: T, pub len: [T; 2], - pub shift: [T; 2], + // 0: [0, 0, 0], 1: [1, 0, 0], 2: [0, 1, 0], 3: [0, 0, 1] + pub shift: [T; 3], pub is_valid: T, pub is_valid_not_start: T, - pub is_shift_non_zero: T, + // This should be 0 if is_valid = 0. We use this to determine whether we need ro read data_4. + pub is_shift_non_zero_or_not_start: T, // -1 for the first iteration, 1 for the last iteration, 0 for the middle iterations pub is_boundary: T, pub data_1: [T; MEMCPY_LOOP_NUM_LIMBS], @@ -100,16 +100,21 @@ impl Air for MemcpyIterAir { let timestamp: AB::Var = local.timestamp; let mut timestamp_delta: AB::Expr = AB::Expr::ZERO; - let mut timestamp_pp = |timestamp_increase_value: AB::Expr| { - timestamp_delta += timestamp_increase_value.clone(); + let mut timestamp_pp = |timestamp_increase_value: AB::Var| { + timestamp_delta += timestamp_increase_value.into(); timestamp + timestamp_delta.clone() - timestamp_increase_value.clone() }; - let shift = local.shift[1] * AB::Expr::TWO + local.shift[0]; - let is_shift_zero = not::(local.is_shift_non_zero); - let is_shift_one = and::(local.shift[0], not::(local.shift[1])); - let is_shift_two = and::(not::(local.shift[0]), local.shift[1]); - let is_shift_three = and::(local.shift[0], local.shift[1]); + let shift = local.shift.iter().enumerate().fold(AB::Expr::ZERO, |acc, (i, x)| { + acc + (*x) * AB::Expr::from_canonical_u32(i as u32 + 1) + }); + let is_shift_non_zero = local.shift.iter().fold(AB::Expr::ZERO, |acc, x| { + acc + (*x) + }); + let is_shift_zero = not::(is_shift_non_zero.clone()); + let is_shift_one = local.shift[0]; + let is_shift_two = local.shift[1]; + let is_shift_three = local.shift[2]; let is_end = (local.is_boundary + AB::Expr::ONE) * local.is_boundary * (AB::F::TWO).inverse(); @@ -166,8 +171,9 @@ impl Air for MemcpyIterAir { builder.assert_bool(local.is_valid); local.shift.iter().for_each(|x| builder.assert_bool(*x)); + builder.assert_bool(is_shift_non_zero.clone()); builder.assert_bool(local.is_valid_not_start); - builder.assert_bool(local.is_shift_non_zero); + builder.assert_bool(local.is_shift_non_zero_or_not_start); // is_boundary is either -1, 0 or 1 builder.assert_tern(local.is_boundary + AB::Expr::ONE); @@ -177,8 +183,8 @@ impl Air for MemcpyIterAir { and::(local.is_valid, is_not_start), ); - // is_shift_non_zero is correct - builder.assert_eq(local.is_shift_non_zero, or::(local.shift[0], local.shift[1])); + // is_shift_non_zero_or_not_start is correct + builder.assert_eq(local.is_shift_non_zero_or_not_start, or::(is_shift_non_zero.clone(), local.is_valid_not_start)); // if !is_valid, then is_boundary = 0, shift = 0 (we will use this assumption later) let mut is_not_valid_when = builder.when(not::(local.is_valid)); @@ -193,8 +199,9 @@ impl Air for MemcpyIterAir { is_valid_not_start_when .assert_eq(local.source, prev.source + AB::Expr::from_canonical_u32(16)); is_valid_not_start_when.assert_eq(local.dest, prev.dest + AB::Expr::from_canonical_u32(16)); - is_valid_not_start_when.assert_eq(local.shift[0], prev.shift[0]); - is_valid_not_start_when.assert_eq(local.shift[1], prev.shift[1]); + local.shift.iter().zip(prev.shift.iter()).for_each(|(local_shift, prev_shift)| { + is_valid_not_start_when.assert_eq(*local_shift, *prev_shift); + }); // make sure if previous row is valid and not end, then local.is_valid = 1 builder @@ -205,7 +212,7 @@ impl Air for MemcpyIterAir { // since is_shift_non_zero degree is 2, we need to keep the degree of the condition to 1 builder .when(not::(prev.is_valid_not_start) - not::(prev.is_valid)) - .assert_eq(local.timestamp, prev.timestamp + local.is_shift_non_zero); + .assert_eq(local.timestamp, prev.timestamp + is_shift_non_zero); // if prev.is_valid_not_start and local.is_valid_not_start, then timestamp=prev_timestamp+8 // prev.is_valid_not_start is the opposite of previous condition @@ -247,9 +254,9 @@ impl Air for MemcpyIterAir { .enumerate() .for_each(|(idx, data)| { let is_valid_read = if idx == 3 { - or::(local.is_shift_non_zero, local.is_valid_not_start) + local.is_shift_non_zero_or_not_start } else { - local.is_valid_not_start.into() + local.is_valid_not_start }; self.memory_bridge @@ -274,7 +281,7 @@ impl Air for MemcpyIterAir { local.dest - AB::Expr::from_canonical_usize(16 - idx * 4), ), data.clone(), - timestamp_pp(local.is_valid_not_start.into()), + timestamp_pp(local.is_valid_not_start), &local.write_aux[idx], ) .eval(builder, local.is_valid_not_start); @@ -701,11 +708,15 @@ impl TraceFiller for MemcpyIterFiller { } else { F::ZERO }; - cols.is_shift_non_zero = F::from_canonical_u8((record.inner.shift != 0) as u8); - cols.is_valid_not_start = F::from_canonical_u8(1 - is_start as u8); + cols.is_shift_non_zero_or_not_start = F::from_bool(record.inner.shift != 0 || !is_start); + cols.is_valid_not_start = F::from_bool(!is_start); cols.is_valid = F::ONE; - cols.shift = [record.inner.shift & 1, record.inner.shift >> 1] - .map(F::from_canonical_u8); + cols.shift = [ + record.inner.shift == 1, + record.inner.shift == 2, + record.inner.shift == 3, + ] + .map(F::from_bool); cols.len = [len & 0xffff, len >> 16].map(F::from_canonical_u32); cols.source = F::from_canonical_u32(source); cols.dest = F::from_canonical_u32(dest); diff --git a/extensions/memcpy/circuit/src/lib.rs b/extensions/memcpy/circuit/src/lib.rs index e81deb9e37..28b63d6a65 100644 --- a/extensions/memcpy/circuit/src/lib.rs +++ b/extensions/memcpy/circuit/src/lib.rs @@ -7,6 +7,8 @@ pub use bus::*; pub use extension::*; pub use iteration::*; pub use loops::*; +use openvm_circuit::system::memory::{merkle::public_values::PUBLIC_VALUES_AS, online::{GuestMemory, TracingMemory}}; +use openvm_instructions::riscv::{RV32_MEMORY_AS, RV32_REGISTER_AS}; // ==== Do not change these constants! ==== pub const MEMCPY_LOOP_NUM_LIMBS: usize = 4; @@ -16,3 +18,96 @@ pub const A1_REGISTER_PTR: usize = 11 * 4; pub const A2_REGISTER_PTR: usize = 12 * 4; pub const A3_REGISTER_PTR: usize = 13 * 4; pub const A4_REGISTER_PTR: usize = 14 * 4; + + +// TODO: These are duplicated from extensions/rv32im/circuit/src/adapters/mod.rs +// to prevent cyclic dependencies. Fix this. + +#[inline(always)] +pub fn memory_read(memory: &GuestMemory, address_space: u32, ptr: u32) -> [u8; N] { + debug_assert!( + address_space == RV32_REGISTER_AS + || address_space == RV32_MEMORY_AS + || address_space == PUBLIC_VALUES_AS, + ); + + // SAFETY: + // - address space `RV32_REGISTER_AS` and `RV32_MEMORY_AS` will always have cell type `u8` and + // minimum alignment of `RV32_REGISTER_NUM_LIMBS` + unsafe { memory.read::(address_space, ptr) } +} + +/// Atomic read operation which increments the timestamp by 1. +/// Returns `(t_prev, [ptr:4]_{address_space})` where `t_prev` is the timestamp of the last memory +/// access. +#[inline(always)] +pub fn timed_read( + memory: &mut TracingMemory, + address_space: u32, + ptr: u32, +) -> (u32, [u8; N]) { + debug_assert!( + address_space == RV32_REGISTER_AS + || address_space == RV32_MEMORY_AS + || address_space == PUBLIC_VALUES_AS + ); + + // SAFETY: + // - address space `RV32_REGISTER_AS` and `RV32_MEMORY_AS` will always have cell type `u8` and + // minimum alignment of `MEMCPY_LOOP_NUM_LIMBS` + unsafe { memory.read::(address_space, ptr) } +} + +#[inline(always)] +pub fn timed_write( + memory: &mut TracingMemory, + address_space: u32, + ptr: u32, + data: [u8; N], +) -> (u32, [u8; N]) { + debug_assert!( + address_space == RV32_REGISTER_AS + || address_space == RV32_MEMORY_AS + || address_space == PUBLIC_VALUES_AS + ); + + // SAFETY: + // - address space `RV32_REGISTER_AS` and `RV32_MEMORY_AS` will always have cell type `u8` and + // minimum alignment of `MEMCPY_LOOP_NUM_LIMBS` + unsafe { memory.write::(address_space, ptr, data) } +} + +/// Reads register value at `reg_ptr` from memory and records the memory access in mutable buffer. +/// Trace generation relevant to this memory access can be done fully from the recorded buffer. +#[inline(always)] +pub fn tracing_read( + memory: &mut TracingMemory, + address_space: u32, + ptr: u32, + prev_timestamp: &mut u32, +) -> [u8; N] { + let (t_prev, data) = timed_read(memory, address_space, ptr); + *prev_timestamp = t_prev; + data +} + +/// Writes `reg_ptr, reg_val` into memory and records the memory access in mutable buffer. +/// Trace generation relevant to this memory access can be done fully from the recorded buffer. +#[inline(always)] +pub fn tracing_write( + memory: &mut TracingMemory, + address_space: u32, + ptr: u32, + data: [u8; N], + prev_timestamp: &mut u32, + prev_data: &mut [u8; N], +) { + let (t_prev, data_prev) = timed_write(memory, address_space, ptr, data); + *prev_timestamp = t_prev; + *prev_data = data_prev; +} + +#[inline(always)] +pub fn read_rv32_register(memory: &GuestMemory, ptr: u32) -> u32 { + u32::from_le_bytes(memory_read(memory, RV32_REGISTER_AS, ptr)) +} diff --git a/extensions/rv32im/circuit/Cargo.toml b/extensions/rv32im/circuit/Cargo.toml index 20659b693a..acb352cc45 100644 --- a/extensions/rv32im/circuit/Cargo.toml +++ b/extensions/rv32im/circuit/Cargo.toml @@ -18,6 +18,7 @@ openvm-circuit = { workspace = true } openvm-circuit-derive = { workspace = true } openvm-instructions = { workspace = true } openvm-rv32im-transpiler = { workspace = true } +openvm-memcpy-circuit = { workspace = true } strum.workspace = true derive-new.workspace = true derive_more = { workspace = true, features = ["from"] } diff --git a/extensions/rv32im/circuit/src/lib.rs b/extensions/rv32im/circuit/src/lib.rs index 02f26c0306..64598927d7 100644 --- a/extensions/rv32im/circuit/src/lib.rs +++ b/extensions/rv32im/circuit/src/lib.rs @@ -9,6 +9,7 @@ use openvm_circuit::{ system::{SystemChipInventory, SystemCpuBuilder, SystemExecutor}, }; use openvm_circuit_derive::{Executor, PreflightExecutor, VmConfig}; +use openvm_memcpy_circuit::{Memcpy, MemcpyCpuProverExt, MemcpyExecutor}; use openvm_stark_backend::{ config::{StarkGenericConfig, Val}, engine::StarkEngine, @@ -82,6 +83,8 @@ pub struct Rv32IConfig { pub base: Rv32I, #[extension] pub io: Rv32Io, + #[extension] + pub memcpy: Memcpy, } // Default implementation uses no init file @@ -106,6 +109,7 @@ impl Default for Rv32IConfig { system, base: Default::default(), io: Default::default(), + memcpy: Memcpy, } } } @@ -117,6 +121,7 @@ impl Rv32IConfig { system, base: Default::default(), io: Default::default(), + memcpy: Memcpy, } } @@ -128,6 +133,7 @@ impl Rv32IConfig { system, base: Default::default(), io: Default::default(), + memcpy: Memcpy, } } } @@ -174,6 +180,7 @@ where let inventory = &mut chip_complex.inventory; VmProverExtension::::extend_prover(&Rv32ImCpuProverExt, &config.base, inventory)?; VmProverExtension::::extend_prover(&Rv32ImCpuProverExt, &config.io, inventory)?; + VmProverExtension::::extend_prover(&MemcpyCpuProverExt, &config.memcpy, inventory)?; Ok(chip_complex) } } From f1fb3bc1c292a403b20f3d1f21bbe841e6c64654 Mon Sep 17 00:00:00 2001 From: Peyman Jabbarzade Date: Wed, 27 Aug 2025 18:45:01 -0400 Subject: [PATCH 13/14] change fibonacci test to memcpy test --- Cargo.lock | 6 +-- benchmarks/guest/fibonacci/openvm.toml | 1 + benchmarks/guest/fibonacci/src/main.rs | 52 +++++++++++++++++----- crates/toolchain/openvm/src/memcpy.s | 4 -- extensions/memcpy/README.md | 25 +++++------ extensions/memcpy/circuit/src/iteration.rs | 35 ++++++++++++--- extensions/memcpy/circuit/src/loops.rs | 4 ++ 7 files changed, 88 insertions(+), 39 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index e539f1dfae..dfb52f8b06 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -5752,7 +5752,7 @@ dependencies = [ [[package]] name = "openvm-memcpy-circuit" -version = "1.4.0-rc.6" +version = "1.4.0-rc.8" dependencies = [ "derive-new 0.6.0", "derive_more 1.0.0", @@ -5771,7 +5771,7 @@ dependencies = [ [[package]] name = "openvm-memcpy-integration-tests" -version = "1.4.0-rc.6" +version = "1.4.0-rc.8" dependencies = [ "eyre", "openvm", @@ -5791,7 +5791,7 @@ dependencies = [ [[package]] name = "openvm-memcpy-transpiler" -version = "1.4.0-rc.6" +version = "1.4.0-rc.8" dependencies = [ "openvm-instructions", "openvm-instructions-derive", diff --git a/benchmarks/guest/fibonacci/openvm.toml b/benchmarks/guest/fibonacci/openvm.toml index 19a1e670e5..64b1f884d4 100644 --- a/benchmarks/guest/fibonacci/openvm.toml +++ b/benchmarks/guest/fibonacci/openvm.toml @@ -1,3 +1,4 @@ [app_vm_config.rv32i] [app_vm_config.rv32m] [app_vm_config.io] +[app_vm_config.memcpy] \ No newline at end of file diff --git a/benchmarks/guest/fibonacci/src/main.rs b/benchmarks/guest/fibonacci/src/main.rs index 158a3fa0ec..6b798ac752 100644 --- a/benchmarks/guest/fibonacci/src/main.rs +++ b/benchmarks/guest/fibonacci/src/main.rs @@ -1,14 +1,44 @@ -use openvm::io::{read, reveal_u32}; +use core::ptr; -pub fn main() { - let n: u64 = read(); - let mut a: u64 = 0; - let mut b: u64 = 1; - for _ in 0..n { - let c: u64 = a.wrapping_add(b); - a = b; - b = c; +openvm::entry!(main); + +/// Moves all the elements of `src` into `dst`, leaving `src` empty. +#[no_mangle] +pub fn append(dst: &mut [T], src: &mut [T], shift: usize) { + let src_len = src.len(); + let dst_len = dst.len(); + + unsafe { + // The call to add is always safe because `Vec` will never + // allocate more than `isize::MAX` bytes. + let dst_ptr = dst.as_mut_ptr().wrapping_add(shift); + let src_ptr = src.as_ptr(); + println!("dst_ptr: {:?}", dst_ptr); + println!("src_ptr: {:?}", src_ptr); + println!("src_len: {:?}", src_len); + + // The two regions cannot overlap because mutable references do + // not alias, and two different vectors cannot own the same + // memory. + ptr::copy_nonoverlapping(src_ptr, dst_ptr, src_len); } - reveal_u32(a as u32, 0); - reveal_u32((a >> 32) as u32, 1); } + +pub fn main() { + let mut a: [u8; 1000] = [1; 1000]; + let mut b: [u8; 500] = [2; 500]; + + let shift: usize = 0; + append(&mut a, &mut b, shift); + + for i in 0..1000 { + if i < shift || i >= shift + b.len() { + assert_eq!(a[i], 1); + } else { + assert_eq!(a[i], 2); + } + } + + println!("a: {:?}", a); + println!("b: {:?}", b); +} \ No newline at end of file diff --git a/crates/toolchain/openvm/src/memcpy.s b/crates/toolchain/openvm/src/memcpy.s index 1606d576dc..a45a82d16a 100644 --- a/crates/toolchain/openvm/src/memcpy.s +++ b/crates/toolchain/openvm/src/memcpy.s @@ -255,7 +255,6 @@ memcpy: sb a6, 2(a3) addi a2, a2, -3 addi a3, a4, 16 - li a4, 16 .LBBmemcpy0_9: memcpy_loop 1 addi a4, a3, -13 @@ -268,7 +267,6 @@ memcpy: .LBBmemcpy0_12: li a1, 16 bltu a2, a1, .LBBmemcpy0_15 - li a1, 15 .LBBmemcpy0_14: memcpy_loop 0 .LBBmemcpy0_15: @@ -294,7 +292,6 @@ memcpy: sb a5, 0(a3) addi a2, a2, -1 addi a3, a4, 16 - li a4, 18 .LBBmemcpy0_20: memcpy_loop 3 addi a4, a3, -15 @@ -307,7 +304,6 @@ memcpy: sb a6, 1(a3) addi a2, a2, -2 addi a3, a4, 16 - li a4, 17 .LBBmemcpy0_23: memcpy_loop 2 addi a4, a3, -14 diff --git a/extensions/memcpy/README.md b/extensions/memcpy/README.md index da381e2e3c..b85fe256b5 100644 --- a/extensions/memcpy/README.md +++ b/extensions/memcpy/README.md @@ -11,28 +11,20 @@ memcpy_loop shift Where `shift` is an immediate value (0, 1, 2, or 3) representing the byte alignment shift. -### RISC-V Encoding -- **Opcode**: `0x73` (custom opcode) -- **Funct3**: `0x0` (custom funct3) -- **Immediate**: 12-bit signed immediate for shift value -- **Format**: I-type instruction ### Usage The `memcpy_loop` instruction is designed to replace repetitive shift-handling code in memcpy implementations. Instead of having separate code blocks for each shift value, you can use a single instruction: ```assembly -# Instead of this repetitive code: -.Lshift_1: - lw a5, 0(a4) - sb a5, 0(a3) - srli a1, a5, 8 - sb a1, 1(a3) - # ... more shift handling code - -# You can use: memcpy_loop 1 # Handles shift=1 case ``` +Note that you must define `memcpy_loop` before using it. For example, in [memcpy.s](../../crates/toolchain/openvm/src/memcpy.s) it is defined at the beginning of the assembly code as follows: +```assembly +.macro memcpy_loop shift + .word 0x00000072 | (\shift << 12) # opcode 0x72 + shift in immediate field (bits 12-31) +``` + ### Benefits 1. **Code Size Reduction**: Eliminates repetitive shift-handling code 2. **Performance**: Optimized implementation in the circuit layer @@ -84,3 +76,8 @@ RISC-V Assembly → Transpiler Extension → OpenVM Instruction → MemcpyIterat The extension provides: - **Transpiler**: `extensions/memcpy/transpiler/` - Translates RISC-V to OpenVM - **Circuit**: `extensions/memcpy/circuit/` - Implements the instruction logic + + +# References + +- Official Keccak [spec summary](https://keccak.team/keccak_specs_summary.html) diff --git a/extensions/memcpy/circuit/src/iteration.rs b/extensions/memcpy/circuit/src/iteration.rs index c8872bbe45..38a31e362d 100644 --- a/extensions/memcpy/circuit/src/iteration.rs +++ b/extensions/memcpy/circuit/src/iteration.rs @@ -18,7 +18,7 @@ use openvm_circuit::{ MemoryWriteAuxCols, MemoryWriteBytesAuxRecord, }, online::{GuestMemory, TracingMemory}, - MemoryAddress, MemoryAuxColsFactory, + MemoryAddress, MemoryAuxColsFactory, POINTER_MAX_BITS, }, }; use openvm_circuit_primitives::{ @@ -441,6 +441,7 @@ where ); let mut len = read_rv32_register(state.memory.data(), A2_REGISTER_PTR as u32); + // Create a record with var_size = ((len - shift) >> 4) + 1 which is the number of rows in iteration trace let record = state.ctx.alloc(MultiRowLayout::new(MemcpyIterMetadata { num_rows: ((len - shift as u32) >> 4) as usize + 1, })); @@ -449,7 +450,11 @@ where record.inner.shift = shift; record.inner.from_pc = *state.pc; record.inner.from_timestamp = state.memory.timestamp; + record.inner.dest = dest; + record.inner.source = source; + record.inner.len = len; + // Fill record.var for the first row of iteration trace if shift != 0 { source -= 12; record.var[0].data[3] = tracing_read( @@ -460,6 +465,7 @@ where ); }; + // Fill record.var for the rest of the rows of iteration trace let mut idx = 1; while len - shift as u32 > 15 { let writes_data: [[u8; MEMCPY_LOOP_NUM_LIMBS]; 4] = array::from_fn(|i| { @@ -540,9 +546,9 @@ where &mut len_data, ); - record.inner.dest = u32::from_le_bytes(dest_data); - record.inner.source = u32::from_le_bytes(source_data); - record.inner.len = u32::from_le_bytes(len_data); + debug_assert_eq!(record.inner.dest, u32::from_le_bytes(dest_data)); + debug_assert_eq!(record.inner.source, u32::from_le_bytes(source_data)); + debug_assert_eq!(record.inner.len, u32::from_le_bytes(len_data)); *state.pc = state.pc.wrapping_add(DEFAULT_PC_STEP); @@ -580,7 +586,7 @@ impl TraceFiller for MemcpyIterFiller { num_loops += 1; num_iters += num_rows; } - // tracing::info!("num_loops: {:?}, num_iters: {:?}", num_loops, num_iters); + tracing::info!("num_loops: {:?}, num_iters: {:?}, sizes: {:?}", num_loops, num_iters, sizes); chunks .par_iter_mut() @@ -594,6 +600,7 @@ impl TraceFiller for MemcpyIterFiller { ) }; + tracing::info!("shift: {:?}", record.inner.shift); // Fill memcpy loop record self.memcpy_loop_chip.add_new_loop( mem_helper, @@ -606,6 +613,7 @@ impl TraceFiller for MemcpyIterFiller { record.inner.register_aux.clone(), ); + // Calculate the timestamp for the last memory access // 4 reads + 4 writes per iteration + (shift != 0) read for the loop header let timestamp = record.inner.from_timestamp + ((num_rows - 1) << 3) as u32 @@ -906,6 +914,7 @@ unsafe fn execute_e12_impl( ) -> u32 { let shift = pre_compute.c; let mut height = 1; + // Read dest and source from registers let (dest, source) = if shift == 0 { ( vm_state.vm_read::(RV32_REGISTER_AS, A3_REGISTER_PTR as u32), @@ -917,19 +926,31 @@ unsafe fn execute_e12_impl( vm_state.vm_read::(RV32_REGISTER_AS, A3_REGISTER_PTR as u32), ) }; + // Read length from a2 register let len = vm_state.vm_read::(RV32_REGISTER_AS, A2_REGISTER_PTR as u32); let mut dest = u32::from_le_bytes(dest); - let mut source = u32::from_le_bytes(source); + let mut source = u32::from_le_bytes(source) - 12 * (shift != 0) as u32; let mut len = u32::from_le_bytes(len); + // Check address ranges are valid + debug_assert!(dest < (1 << POINTER_MAX_BITS)); + debug_assert!((source - 4 * (shift != 0) as u32) < (1 << POINTER_MAX_BITS)); + let to_dest = dest + ((len - shift as u32) & !15); + let to_source = source + ((len - shift as u32) & !15); + debug_assert!(to_dest <= (1 << POINTER_MAX_BITS)); + debug_assert!(to_source <= (1 << POINTER_MAX_BITS)); + // Make sure the destination and source are not overlapping + debug_assert!(to_dest <= source || to_source <= dest); + + // Read the previous data from memory if shift != 0 let mut prev_data = if shift == 0 { [0; 4] } else { - source -= 12; vm_state.vm_read::(RV32_MEMORY_AS, source - 4) }; + // Run iterations while len - shift as u32 > 15 { for i in 0..4 { let data = vm_state.vm_read::(RV32_MEMORY_AS, source + 4 * i); diff --git a/extensions/memcpy/circuit/src/loops.rs b/extensions/memcpy/circuit/src/loops.rs index d08fa3aa22..16204967b6 100644 --- a/extensions/memcpy/circuit/src/loops.rs +++ b/extensions/memcpy/circuit/src/loops.rs @@ -181,6 +181,7 @@ impl Air for MemcpyLoopAir { .eval(builder, local.is_valid); // Generate 16-bit limbs for range checking + // dest, to_dest, source - 12 * is_shift_non_zero, to_source - 12 * is_shift_non_zero let dest_u16_limbs = u8_word_to_u16(local.dest); let to_dest_u16_limbs = u8_word_to_u16(local.to_dest); let source_u16_limbs = [ @@ -213,12 +214,14 @@ impl Air for MemcpyLoopAir { ]; range_check_data.iter().for_each(|data| { + // Check the low 16 bits of dest and source, make sure they are multiple of 4 self.range_bus .range_check( data[0].clone() * AB::F::from_canonical_u32(4).inverse(), MEMCPY_LOOP_LIMB_BITS * 2 - 2, ) .eval(builder, local.is_valid); + // Check the high 16 bits of dest and source, make sure they are in the range [0, 2^pointer_max_bits - 2^MEMCPY_LOOP_LIMB_BITS) self.range_bus .range_check( data[1].clone(), @@ -395,6 +398,7 @@ impl MemcpyLoopChip { let height = next_power_of_two_or_zero(self.records.lock().unwrap().len()); let mut rows = F::zero_vec(height * NUM_MEMCPY_LOOP_COLS); + // TODO: run in parallel for (i, record) in self.records.lock().unwrap().iter().enumerate() { let row = &mut rows[i * NUM_MEMCPY_LOOP_COLS..(i + 1) * NUM_MEMCPY_LOOP_COLS]; let cols: &mut MemcpyLoopCols = row.borrow_mut(); From f1b14c7ddd0a23336f4d9e5346763d0b8e49849b Mon Sep 17 00:00:00 2001 From: Maillew Date: Fri, 12 Sep 2025 11:41:08 -0400 Subject: [PATCH 14/14] fix: rebase changes --- Cargo.lock | 6 +- extensions/memcpy/circuit/src/iteration.rs | 366 +++++++++++---------- 2 files changed, 192 insertions(+), 180 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index dfb52f8b06..5da13dd71f 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -5752,7 +5752,7 @@ dependencies = [ [[package]] name = "openvm-memcpy-circuit" -version = "1.4.0-rc.8" +version = "1.4.1-rc.0" dependencies = [ "derive-new 0.6.0", "derive_more 1.0.0", @@ -5771,7 +5771,7 @@ dependencies = [ [[package]] name = "openvm-memcpy-integration-tests" -version = "1.4.0-rc.8" +version = "1.4.1-rc.0" dependencies = [ "eyre", "openvm", @@ -5791,7 +5791,7 @@ dependencies = [ [[package]] name = "openvm-memcpy-transpiler" -version = "1.4.0-rc.8" +version = "1.4.1-rc.0" dependencies = [ "openvm-instructions", "openvm-instructions-derive", diff --git a/extensions/memcpy/circuit/src/iteration.rs b/extensions/memcpy/circuit/src/iteration.rs index 38a31e362d..1d7ebd33df 100644 --- a/extensions/memcpy/circuit/src/iteration.rs +++ b/extensions/memcpy/circuit/src/iteration.rs @@ -44,7 +44,8 @@ use openvm_stark_backend::{ }; use crate::{ - bus::MemcpyBus, read_rv32_register, tracing_read, tracing_write, MemcpyLoopChip, A1_REGISTER_PTR, A2_REGISTER_PTR, A3_REGISTER_PTR, A4_REGISTER_PTR + bus::MemcpyBus, read_rv32_register, tracing_read, tracing_write, MemcpyLoopChip, + A1_REGISTER_PTR, A2_REGISTER_PTR, A3_REGISTER_PTR, A4_REGISTER_PTR, }; // Import constants from lib.rs use crate::{MEMCPY_LOOP_LIMB_BITS, MEMCPY_LOOP_NUM_LIMBS}; @@ -105,12 +106,14 @@ impl Air for MemcpyIterAir { timestamp + timestamp_delta.clone() - timestamp_increase_value.clone() }; - let shift = local.shift.iter().enumerate().fold(AB::Expr::ZERO, |acc, (i, x)| { - acc + (*x) * AB::Expr::from_canonical_u32(i as u32 + 1) - }); - let is_shift_non_zero = local.shift.iter().fold(AB::Expr::ZERO, |acc, x| { - acc + (*x) - }); + let shift = local + .shift + .iter() + .enumerate() + .fold(AB::Expr::ZERO, |acc, (i, x)| { + acc + (*x) * AB::Expr::from_canonical_u32(i as u32 + 1) + }); + let is_shift_non_zero = local.shift.iter().fold(AB::Expr::ZERO, |acc, x| acc + (*x)); let is_shift_zero = not::(is_shift_non_zero.clone()); let is_shift_one = local.shift[0]; let is_shift_two = local.shift[1]; @@ -184,7 +187,10 @@ impl Air for MemcpyIterAir { ); // is_shift_non_zero_or_not_start is correct - builder.assert_eq(local.is_shift_non_zero_or_not_start, or::(is_shift_non_zero.clone(), local.is_valid_not_start)); + builder.assert_eq( + local.is_shift_non_zero_or_not_start, + or::(is_shift_non_zero.clone(), local.is_valid_not_start), + ); // if !is_valid, then is_boundary = 0, shift = 0 (we will use this assumption later) let mut is_not_valid_when = builder.when(not::(local.is_valid)); @@ -194,14 +200,17 @@ impl Air for MemcpyIterAir { // if is_valid_not_start, then len = prev_len - 16, source = prev_source + 16, // and dest = prev_dest + 16, shift = prev_shift let mut is_valid_not_start_when = builder.when(local.is_valid_not_start); - is_valid_not_start_when - .assert_eq(len.clone(), prev_len - AB::Expr::from_canonical_u32(16)); + is_valid_not_start_when.assert_eq(len.clone(), prev_len - AB::Expr::from_canonical_u32(16)); is_valid_not_start_when .assert_eq(local.source, prev.source + AB::Expr::from_canonical_u32(16)); is_valid_not_start_when.assert_eq(local.dest, prev.dest + AB::Expr::from_canonical_u32(16)); - local.shift.iter().zip(prev.shift.iter()).for_each(|(local_shift, prev_shift)| { - is_valid_not_start_when.assert_eq(*local_shift, *prev_shift); - }); + local + .shift + .iter() + .zip(prev.shift.iter()) + .for_each(|(local_shift, prev_shift)| { + is_valid_not_start_when.assert_eq(*local_shift, *prev_shift); + }); // make sure if previous row is valid and not end, then local.is_valid = 1 builder @@ -242,35 +251,27 @@ impl Air for MemcpyIterAir { .eval(builder, local.is_boundary); // Read data from memory - let read_data = [ - local.data_1, - local.data_2, - local.data_3, - local.data_4, - ]; + let read_data = [local.data_1, local.data_2, local.data_3, local.data_4]; - read_data - .iter() - .enumerate() - .for_each(|(idx, data)| { - let is_valid_read = if idx == 3 { - local.is_shift_non_zero_or_not_start - } else { - local.is_valid_not_start - }; + read_data.iter().enumerate().for_each(|(idx, data)| { + let is_valid_read = if idx == 3 { + local.is_shift_non_zero_or_not_start + } else { + local.is_valid_not_start + }; - self.memory_bridge - .read( - MemoryAddress::new( - AB::Expr::from_canonical_u32(RV32_MEMORY_AS), - local.source - AB::Expr::from_canonical_usize(16 - idx * 4), - ), - *data, - timestamp_pp(is_valid_read.clone()), - &local.read_aux[idx], - ) - .eval(builder, is_valid_read.clone()); - }); + self.memory_bridge + .read( + MemoryAddress::new( + AB::Expr::from_canonical_u32(RV32_MEMORY_AS), + local.source - AB::Expr::from_canonical_usize(16 - idx * 4), + ), + *data, + timestamp_pp(is_valid_read.clone()), + &local.read_aux[idx], + ) + .eval(builder, is_valid_read.clone()); + }); // Write final data to registers write_data.iter().enumerate().for_each(|(idx, data)| { @@ -586,8 +587,13 @@ impl TraceFiller for MemcpyIterFiller { num_loops += 1; num_iters += num_rows; } - tracing::info!("num_loops: {:?}, num_iters: {:?}, sizes: {:?}", num_loops, num_iters, sizes); - + tracing::info!( + "num_loops: {:?}, num_iters: {:?}, sizes: {:?}", + num_loops, + num_iters, + sizes + ); + chunks .par_iter_mut() .zip(sizes.par_iter()) @@ -716,7 +722,8 @@ impl TraceFiller for MemcpyIterFiller { } else { F::ZERO }; - cols.is_shift_non_zero_or_not_start = F::from_bool(record.inner.shift != 0 || !is_start); + cols.is_shift_non_zero_or_not_start = + F::from_bool(record.inner.shift != 0 || !is_start); cols.is_valid_not_start = F::from_bool(!is_start); cols.is_valid = F::ONE; cols.shift = [ @@ -735,125 +742,122 @@ impl TraceFiller for MemcpyIterFiller { len += 16; // if row_idx == 0 && is_start { - // tracing::info!("first_roooooow, timestamp: {:?}, dest: {:?}, source: {:?}, len_0: {:?}, len_1: {:?}, shift: {:?}, is_valid: {:?}, is_valid_not_start: {:?}, is_shift_non_zero: {:?}, is_boundary: {:?}, data_1: {:?}, data_2: {:?}, data_3: {:?}, data_4: {:?}, read_aux: {:?}, read_aux_lt: {:?}", - // cols.timestamp.as_canonical_u32(), - // cols.dest.as_canonical_u32(), - // cols.source.as_canonical_u32(), - // cols.len[0].as_canonical_u32(), - // cols.len[1].as_canonical_u32(), - // cols.shift[1].as_canonical_u32() * 2 + cols.shift[0].as_canonical_u32(), - // cols.is_valid.as_canonical_u32(), - // cols.is_valid_not_start.as_canonical_u32(), - // cols.is_shift_non_zero.as_canonical_u32(), - // cols.is_boundary.as_canonical_u32(), - // cols.data_1.map(|x| x.as_canonical_u32()).to_vec(), - // cols.data_2.map(|x| x.as_canonical_u32()).to_vec(), - // cols.data_3.map(|x| x.as_canonical_u32()).to_vec(), - // cols.data_4.map(|x| x.as_canonical_u32()).to_vec(), - // cols.read_aux.map(|x| x.get_base().prev_timestamp.as_canonical_u32()).to_vec(), + // tracing::info!("first_roooooow, timestamp: {:?}, dest: {:?}, source: {:?}, len_0: {:?}, len_1: {:?}, shift: {:?}, is_valid: {:?}, is_valid_not_start: {:?}, is_shift_non_zero: {:?}, is_boundary: {:?}, data_1: {:?}, data_2: {:?}, data_3: {:?}, data_4: {:?}, read_aux: {:?}, read_aux_lt: {:?}", + // cols.timestamp.as_canonical_u32(), + // cols.dest.as_canonical_u32(), + // cols.source.as_canonical_u32(), + // cols.len[0].as_canonical_u32(), + // cols.len[1].as_canonical_u32(), + // cols.shift[1].as_canonical_u32() * 2 + cols.shift[0].as_canonical_u32(), + // cols.is_valid.as_canonical_u32(), + // cols.is_valid_not_start.as_canonical_u32(), + // cols.is_shift_non_zero.as_canonical_u32(), + // cols.is_boundary.as_canonical_u32(), + // cols.data_1.map(|x| x.as_canonical_u32()).to_vec(), + // cols.data_2.map(|x| x.as_canonical_u32()).to_vec(), + // cols.data_3.map(|x| x.as_canonical_u32()).to_vec(), + // cols.data_4.map(|x| x.as_canonical_u32()).to_vec(), + // cols.read_aux.map(|x| x.get_base().prev_timestamp.as_canonical_u32()).to_vec(), // cols.read_aux.map(|x| x.get_base().timestamp_lt_aux.lower_decomp.iter().map(|x| x.as_canonical_u32()).collect::>()).to_vec()); // } }); }); // chunks.iter().enumerate().for_each(|(row_idx, chunk)| { - // let mut prv_data = [0; 4]; - // tracing::info!("row_idx: {:?}", row_idx); - - // chunk.chunks_exact(width) - // .enumerate() - // .for_each(|(idx, row)| { - // let cols: &MemcpyIterCols = row.borrow(); - // let is_valid_not_start = cols.is_valid_not_start.as_canonical_u32() != 0; - // let is_shift_non_zero = cols.is_shift_non_zero.as_canonical_u32() != 0; - // let source = cols.source.as_canonical_u32(); - // let dest = cols.dest.as_canonical_u32(); - // let mut bad_col = false; - // tracing::info!("source: {:?}, dest: {:?}", source, dest); - // cols.read_aux.iter().enumerate().for_each(|(idx, aux)| { - // if is_valid_not_start || (is_shift_non_zero && idx == 3) { - // let prev_t = aux.get_base().prev_timestamp.as_canonical_u32(); - // let curr_t = cols.timestamp.as_canonical_u32(); - // let ts_lt = aux.get_base().timestamp_lt_aux.lower_decomp.iter() - // .enumerate() - // .fold(F::ZERO, |acc, (i, &val)| { - // acc + val * F::from_canonical_usize(1 << (i * 17)) - // }).as_canonical_u32(); - // if curr_t + idx as u32 != ts_lt + prev_t + 1 { - // bad_col = true; - // } - // } - // if dest + 4 * idx as u32 == 2097216 || dest - 4 * (idx + 1) as u32 == 2097216 || dest + 4 * idx as u32 == 2097280 || dest - 4 * (idx + 1) as u32 == 2097280 { - // bad_col = true; - // } - // }); - // if bad_col { - // let write_data_pairs = [ - // (prv_data, cols.data_1.map(|x| x.as_canonical_u32())), - // (cols.data_1.map(|x| x.as_canonical_u32()), cols.data_2.map(|x| x.as_canonical_u32())), - // (cols.data_2.map(|x| x.as_canonical_u32()), cols.data_3.map(|x| x.as_canonical_u32())), - // (cols.data_3.map(|x| x.as_canonical_u32()), cols.data_4.map(|x| x.as_canonical_u32())), - // ]; - - // let shift = cols.shift[1].as_canonical_u32() * 2 + cols.shift[0].as_canonical_u32(); - // let write_data = write_data_pairs - // .iter() - // .map(|(prev_data, next_data)| { - // array::from_fn::<_, MEMCPY_LOOP_NUM_LIMBS, _>(|i| { - // (shift == 0) as u32 * (next_data[i]) - // + (shift == 1) as u32 - // * (if i < 3 { - // next_data[i + 1] - // } else { - // prev_data[i - 3] - // }) - // + (shift == 2) as u32 - // * (if i < 2 { - // next_data[i + 2] - // } else { - // prev_data[i - 2] - // }) - // + (shift == 3) as u32 - // * (if i < 1 { - // next_data[i + 3] - // } else { - // prev_data[i - 1] - // }) - // }) - // }) - // .collect::>(); - - - - - // tracing::info!("row_idx: {:?}, idx: {:?}, timestamp: {:?}, dest: {:?}, source: {:?}, len_0: {:?}, len_1: {:?}, shift_0: {:?}, shift_1: {:?}, is_valid: {:?}, is_valid_not_start: {:?}, is_shift_non_zero: {:?}, is_boundary: {:?}, write_data: {:?}, prv_data: {:?}, data_1: {:?}, data_2: {:?}, data_3: {:?}, data_4: {:?}, read_aux: {:?}, read_aux_lt: {:?}, write_aux: {:?}, write_aux_lt: {:?}, write_aux_prev_data: {:?}", - // row_idx, - // idx, - // cols.timestamp.as_canonical_u32(), - // cols.dest.as_canonical_u32(), - // cols.source.as_canonical_u32(), - // cols.len[0].as_canonical_u32(), - // cols.len[1].as_canonical_u32(), - // cols.shift[0].as_canonical_u32(), - // cols.shift[1].as_canonical_u32(), - // cols.is_valid.as_canonical_u32(), - // cols.is_valid_not_start.as_canonical_u32(), - // cols.is_shift_non_zero.as_canonical_u32(), - // cols.is_boundary.as_canonical_u32(), - // write_data, - // prv_data, - // cols.data_1.map(|x| x.as_canonical_u32()).to_vec(), - // cols.data_2.map(|x| x.as_canonical_u32()).to_vec(), - // cols.data_3.map(|x| x.as_canonical_u32()).to_vec(), - // cols.data_4.map(|x| x.as_canonical_u32()).to_vec(), - // cols.read_aux.map(|x| x.get_base().prev_timestamp.as_canonical_u32()).to_vec(), - // cols.read_aux.map(|x| x.get_base().timestamp_lt_aux.lower_decomp.iter().map(|x| x.as_canonical_u32()).collect::>()).to_vec(), - // cols.write_aux.map(|x| x.get_base().prev_timestamp.as_canonical_u32()).to_vec(), - // cols.write_aux.map(|x| x.get_base().timestamp_lt_aux.lower_decomp.iter().map(|x| x.as_canonical_u32()).collect::>()).to_vec(), - // cols.write_aux.map(|x| x.prev_data.map(|x| x.as_canonical_u32()).to_vec()).to_vec()); - // } - // prv_data = cols.data_4.map(|x| x.as_canonical_u32()); - // }); + // let mut prv_data = [0; 4]; + // tracing::info!("row_idx: {:?}", row_idx); + + // chunk.chunks_exact(width) + // .enumerate() + // .for_each(|(idx, row)| { + // let cols: &MemcpyIterCols = row.borrow(); + // let is_valid_not_start = cols.is_valid_not_start.as_canonical_u32() != 0; + // let is_shift_non_zero = cols.is_shift_non_zero.as_canonical_u32() != 0; + // let source = cols.source.as_canonical_u32(); + // let dest = cols.dest.as_canonical_u32(); + // let mut bad_col = false; + // tracing::info!("source: {:?}, dest: {:?}", source, dest); + // cols.read_aux.iter().enumerate().for_each(|(idx, aux)| { + // if is_valid_not_start || (is_shift_non_zero && idx == 3) { + // let prev_t = aux.get_base().prev_timestamp.as_canonical_u32(); + // let curr_t = cols.timestamp.as_canonical_u32(); + // let ts_lt = aux.get_base().timestamp_lt_aux.lower_decomp.iter() + // .enumerate() + // .fold(F::ZERO, |acc, (i, &val)| { + // acc + val * F::from_canonical_usize(1 << (i * 17)) + // }).as_canonical_u32(); + // if curr_t + idx as u32 != ts_lt + prev_t + 1 { + // bad_col = true; + // } + // } + // if dest + 4 * idx as u32 == 2097216 || dest - 4 * (idx + 1) as u32 == 2097216 || dest + 4 * idx as u32 == 2097280 || dest - 4 * (idx + 1) as u32 == 2097280 { + // bad_col = true; + // } + // }); + // if bad_col { + // let write_data_pairs = [ + // (prv_data, cols.data_1.map(|x| x.as_canonical_u32())), + // (cols.data_1.map(|x| x.as_canonical_u32()), cols.data_2.map(|x| x.as_canonical_u32())), + // (cols.data_2.map(|x| x.as_canonical_u32()), cols.data_3.map(|x| x.as_canonical_u32())), + // (cols.data_3.map(|x| x.as_canonical_u32()), cols.data_4.map(|x| x.as_canonical_u32())), + // ]; + + // let shift = cols.shift[1].as_canonical_u32() * 2 + cols.shift[0].as_canonical_u32(); + // let write_data = write_data_pairs + // .iter() + // .map(|(prev_data, next_data)| { + // array::from_fn::<_, MEMCPY_LOOP_NUM_LIMBS, _>(|i| { + // (shift == 0) as u32 * (next_data[i]) + // + (shift == 1) as u32 + // * (if i < 3 { + // next_data[i + 1] + // } else { + // prev_data[i - 3] + // }) + // + (shift == 2) as u32 + // * (if i < 2 { + // next_data[i + 2] + // } else { + // prev_data[i - 2] + // }) + // + (shift == 3) as u32 + // * (if i < 1 { + // next_data[i + 3] + // } else { + // prev_data[i - 1] + // }) + // }) + // }) + // .collect::>(); + + // tracing::info!("row_idx: {:?}, idx: {:?}, timestamp: {:?}, dest: {:?}, source: {:?}, len_0: {:?}, len_1: {:?}, shift_0: {:?}, shift_1: {:?}, is_valid: {:?}, is_valid_not_start: {:?}, is_shift_non_zero: {:?}, is_boundary: {:?}, write_data: {:?}, prv_data: {:?}, data_1: {:?}, data_2: {:?}, data_3: {:?}, data_4: {:?}, read_aux: {:?}, read_aux_lt: {:?}, write_aux: {:?}, write_aux_lt: {:?}, write_aux_prev_data: {:?}", + // row_idx, + // idx, + // cols.timestamp.as_canonical_u32(), + // cols.dest.as_canonical_u32(), + // cols.source.as_canonical_u32(), + // cols.len[0].as_canonical_u32(), + // cols.len[1].as_canonical_u32(), + // cols.shift[0].as_canonical_u32(), + // cols.shift[1].as_canonical_u32(), + // cols.is_valid.as_canonical_u32(), + // cols.is_valid_not_start.as_canonical_u32(), + // cols.is_shift_non_zero.as_canonical_u32(), + // cols.is_boundary.as_canonical_u32(), + // write_data, + // prv_data, + // cols.data_1.map(|x| x.as_canonical_u32()).to_vec(), + // cols.data_2.map(|x| x.as_canonical_u32()).to_vec(), + // cols.data_3.map(|x| x.as_canonical_u32()).to_vec(), + // cols.data_4.map(|x| x.as_canonical_u32()).to_vec(), + // cols.read_aux.map(|x| x.get_base().prev_timestamp.as_canonical_u32()).to_vec(), + // cols.read_aux.map(|x| x.get_base().timestamp_lt_aux.lower_decomp.iter().map(|x| x.as_canonical_u32()).collect::>()).to_vec(), + // cols.write_aux.map(|x| x.get_base().prev_timestamp.as_canonical_u32()).to_vec(), + // cols.write_aux.map(|x| x.get_base().timestamp_lt_aux.lower_decomp.iter().map(|x| x.as_canonical_u32()).collect::>()).to_vec(), + // cols.write_aux.map(|x| x.prev_data.map(|x| x.as_canonical_u32()).to_vec()).to_vec()); + // } + // prv_data = cols.data_4.map(|x| x.as_canonical_u32()); + // }); // }); // assert!(false); } @@ -910,24 +914,26 @@ impl MeteredExecutor for MemcpyIterExecutor { #[inline(always)] unsafe fn execute_e12_impl( pre_compute: &MemcpyIterPreCompute, - vm_state: &mut VmExecState, + instret: &mut u64, + pc: &mut u32, + exec_state: &mut VmExecState, ) -> u32 { let shift = pre_compute.c; let mut height = 1; // Read dest and source from registers let (dest, source) = if shift == 0 { ( - vm_state.vm_read::(RV32_REGISTER_AS, A3_REGISTER_PTR as u32), - vm_state.vm_read::(RV32_REGISTER_AS, A4_REGISTER_PTR as u32), + exec_state.vm_read::(RV32_REGISTER_AS, A3_REGISTER_PTR as u32), + exec_state.vm_read::(RV32_REGISTER_AS, A4_REGISTER_PTR as u32), ) } else { ( - vm_state.vm_read::(RV32_REGISTER_AS, A1_REGISTER_PTR as u32), - vm_state.vm_read::(RV32_REGISTER_AS, A3_REGISTER_PTR as u32), + exec_state.vm_read::(RV32_REGISTER_AS, A1_REGISTER_PTR as u32), + exec_state.vm_read::(RV32_REGISTER_AS, A3_REGISTER_PTR as u32), ) }; // Read length from a2 register - let len = vm_state.vm_read::(RV32_REGISTER_AS, A2_REGISTER_PTR as u32); + let len = exec_state.vm_read::(RV32_REGISTER_AS, A2_REGISTER_PTR as u32); let mut dest = u32::from_le_bytes(dest); let mut source = u32::from_le_bytes(source) - 12 * (shift != 0) as u32; @@ -947,13 +953,13 @@ unsafe fn execute_e12_impl( let mut prev_data = if shift == 0 { [0; 4] } else { - vm_state.vm_read::(RV32_MEMORY_AS, source - 4) + exec_state.vm_read::(RV32_MEMORY_AS, source - 4) }; // Run iterations while len - shift as u32 > 15 { for i in 0..4 { - let data = vm_state.vm_read::(RV32_MEMORY_AS, source + 4 * i); + let data = exec_state.vm_read::(RV32_MEMORY_AS, source + 4 * i); let write_data: [u8; 4] = array::from_fn(|i| { if i < 4 - shift as usize { data[i + shift as usize] @@ -961,7 +967,7 @@ unsafe fn execute_e12_impl( prev_data[i - (4 - shift as usize)] } }); - vm_state.vm_write(RV32_MEMORY_AS, dest + 4 * i, &write_data); + exec_state.vm_write(RV32_MEMORY_AS, dest + 4 * i, &write_data); prev_data = data; } len -= 16; @@ -972,51 +978,57 @@ unsafe fn execute_e12_impl( // Write the result back to memory if shift == 0 { - vm_state.vm_write( + exec_state.vm_write( RV32_REGISTER_AS, A3_REGISTER_PTR as u32, &dest.to_le_bytes(), ); - vm_state.vm_write( + exec_state.vm_write( RV32_REGISTER_AS, A4_REGISTER_PTR as u32, &source.to_le_bytes(), ); } else { source += 12; - vm_state.vm_write( + exec_state.vm_write( RV32_REGISTER_AS, A1_REGISTER_PTR as u32, &dest.to_le_bytes(), ); - vm_state.vm_write( + exec_state.vm_write( RV32_REGISTER_AS, A3_REGISTER_PTR as u32, &source.to_le_bytes(), ); }; - vm_state.vm_write(RV32_REGISTER_AS, A2_REGISTER_PTR as u32, &len.to_le_bytes()); + exec_state.vm_write(RV32_REGISTER_AS, A2_REGISTER_PTR as u32, &len.to_le_bytes()); - vm_state.pc = vm_state.pc.wrapping_add(DEFAULT_PC_STEP); - vm_state.instret += 1; + *pc = pc.wrapping_add(DEFAULT_PC_STEP); + *instret += 1; height } unsafe fn execute_e1_impl( pre_compute: &[u8], - vm_state: &mut VmExecState, + instret: &mut u64, + pc: &mut u32, + _instret_end: u64, + exec_state: &mut VmExecState, ) { let pre_compute: &MemcpyIterPreCompute = pre_compute.borrow(); - execute_e12_impl::(pre_compute, vm_state); + execute_e12_impl::(pre_compute, instret, pc, exec_state); } unsafe fn execute_e2_impl( pre_compute: &[u8], - vm_state: &mut VmExecState, + instret: &mut u64, + pc: &mut u32, + _arg: u64, + exec_state: &mut VmExecState, ) { let pre_compute: &E2PreCompute = pre_compute.borrow(); - let height = execute_e12_impl::(&pre_compute.data, vm_state); - vm_state + let height = execute_e12_impl::(&pre_compute.data, instret, pc, exec_state); + exec_state .ctx .on_height_change(pre_compute.chip_idx as usize, height); }