diff --git a/Cargo.lock b/Cargo.lock index 0f2d0226e0..f26220c089 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -4354,6 +4354,7 @@ dependencies = [ "openvm-ecc-guest", "openvm-ecc-sw-macros", "openvm-ecc-transpiler", + "openvm-memcpy-transpiler", "openvm-rv32im-transpiler", "openvm-sha256-circuit", "openvm-sha256-transpiler", @@ -5107,6 +5108,7 @@ dependencies = [ "openvm-cuda-builder", "openvm-cuda-common", "openvm-instructions", + "openvm-memcpy-circuit", "openvm-mod-circuit-builder", "openvm-pairing-guest", "openvm-rv32-adapters", @@ -5166,6 +5168,7 @@ dependencies = [ "openvm-circuit", "openvm-ecc-circuit", "openvm-instructions", + "openvm-memcpy-transpiler", "openvm-rv32im-transpiler", "openvm-stark-sdk", "openvm-toolchain-tests", @@ -5205,6 +5208,8 @@ dependencies = [ "openvm-ecc-transpiler", "openvm-keccak256-circuit", "openvm-keccak256-transpiler", + "openvm-memcpy-circuit", + "openvm-memcpy-transpiler", "openvm-native-circuit", "openvm-native-recursion", "openvm-pairing-circuit", @@ -5288,6 +5293,7 @@ dependencies = [ "openvm-cuda-builder", "openvm-cuda-common", "openvm-instructions", + "openvm-memcpy-circuit", "openvm-rv32-adapters", "openvm-rv32im-circuit", "openvm-rv32im-transpiler", @@ -5551,6 +5557,7 @@ dependencies = [ "openvm-circuit", "openvm-ecc-circuit", "openvm-ecc-transpiler", + "openvm-memcpy-transpiler", "openvm-rv32im-transpiler", "openvm-sdk", "openvm-stark-sdk", @@ -5597,6 +5604,7 @@ dependencies = [ "openvm-algebra-transpiler", "openvm-circuit", "openvm-instructions", + "openvm-memcpy-transpiler", "openvm-rv32im-transpiler", "openvm-stark-sdk", "openvm-toolchain-tests", @@ -5647,6 +5655,7 @@ dependencies = [ "openvm-keccak256-circuit", "openvm-keccak256-guest", "openvm-keccak256-transpiler", + "openvm-memcpy-transpiler", "openvm-rv32im-transpiler", "openvm-stark-sdk", "openvm-toolchain-tests", @@ -5672,6 +5681,7 @@ dependencies = [ "openvm-cuda-common", "openvm-instructions", "openvm-keccak256-transpiler", + "openvm-memcpy-circuit", "openvm-rv32im-circuit", "openvm-stark-backend", "openvm-stark-sdk", @@ -5709,6 +5719,57 @@ dependencies = [ "syn 2.0.106", ] +[[package]] +name = "openvm-memcpy-circuit" +version = "1.4.0-rc.8" +dependencies = [ + "derive-new 0.6.0", + "derive_more 1.0.0", + "openvm-circuit", + "openvm-circuit-derive", + "openvm-circuit-primitives", + "openvm-circuit-primitives-derive", + "openvm-instructions", + "openvm-memcpy-transpiler", + "openvm-rv32im-transpiler", + "openvm-stark-backend", + "serde", + "strum", + "tracing", +] + +[[package]] +name = "openvm-memcpy-integration-tests" +version = "1.4.0-rc.8" +dependencies = [ + "eyre", + "openvm", + "openvm-circuit", + "openvm-circuit-primitives", + "openvm-instructions", + "openvm-memcpy-circuit", + "openvm-memcpy-transpiler", + "openvm-stark-backend", + "openvm-stark-sdk", + "rand 0.8.5", + "serde", + "strum", + "test-case", + "tracing", +] + +[[package]] +name = "openvm-memcpy-transpiler" +version = "1.4.0-rc.8" +dependencies = [ + "openvm-instructions", + "openvm-instructions-derive", + "openvm-stark-backend", + "openvm-transpiler", + "rrs-lib", + "strum", +] + [[package]] name = "openvm-mod-circuit-builder" version = "1.4.0-rc.8" @@ -5860,6 +5921,8 @@ dependencies = [ "openvm-ecc-sw-macros", "openvm-ecc-transpiler", "openvm-instructions", + "openvm-memcpy-circuit", + "openvm-memcpy-transpiler", "openvm-pairing", "openvm-pairing-circuit", "openvm-pairing-guest", @@ -5893,6 +5956,7 @@ dependencies = [ "openvm-ecc-circuit", "openvm-ecc-guest", "openvm-instructions", + "openvm-memcpy-circuit", "openvm-mod-circuit-builder", "openvm-pairing-guest", "openvm-pairing-transpiler", @@ -6014,6 +6078,7 @@ dependencies = [ "openvm-cuda-builder", "openvm-cuda-common", "openvm-instructions", + "openvm-memcpy-circuit", "openvm-rv32im-transpiler", "openvm-stark-backend", "openvm-stark-sdk", @@ -6040,6 +6105,7 @@ dependencies = [ "openvm", "openvm-circuit", "openvm-instructions", + "openvm-memcpy-transpiler", "openvm-rv32im-circuit", "openvm-rv32im-guest", "openvm-rv32im-transpiler", @@ -6107,6 +6173,8 @@ dependencies = [ "openvm-ecc-transpiler", "openvm-keccak256-circuit", "openvm-keccak256-transpiler", + "openvm-memcpy-circuit", + "openvm-memcpy-transpiler", "openvm-native-circuit", "openvm-native-compiler", "openvm-native-recursion", @@ -6141,6 +6209,7 @@ dependencies = [ "eyre", "openvm-circuit", "openvm-instructions", + "openvm-memcpy-transpiler", "openvm-rv32im-transpiler", "openvm-sha256-circuit", "openvm-sha256-guest", @@ -6178,6 +6247,7 @@ dependencies = [ "openvm-cuda-builder", "openvm-cuda-common", "openvm-instructions", + "openvm-memcpy-circuit", "openvm-rv32im-circuit", "openvm-sha256-air", "openvm-sha256-transpiler", @@ -6386,6 +6456,7 @@ dependencies = [ "openvm-ecc-guest", "openvm-ecc-sw-macros", "openvm-ecc-transpiler", + "openvm-memcpy-transpiler", "openvm-rv32im-transpiler", "openvm-sha256-circuit", "openvm-sha256-transpiler", @@ -7914,6 +7985,7 @@ dependencies = [ "openvm-bigint-transpiler", "openvm-circuit", "openvm-instructions", + "openvm-memcpy-transpiler", "openvm-rv32im-transpiler", "openvm-stark-sdk", "openvm-toolchain-tests", diff --git a/Cargo.toml b/Cargo.toml index 9226a5b0f4..4fd7d73c04 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -61,6 +61,9 @@ members = [ "extensions/ecc/tests", "extensions/pairing/circuit", "extensions/pairing/guest", + "extensions/memcpy/circuit", + "extensions/memcpy/transpiler", + "extensions/memcpy/tests", "guest-libs/ff_derive/", "guest-libs/k256/", "guest-libs/p256/", @@ -171,6 +174,8 @@ openvm-ecc-sw-macros = { path = "extensions/ecc/sw-macros", default-features = f openvm-pairing-circuit = { path = "extensions/pairing/circuit", default-features = false } openvm-pairing-transpiler = { path = "extensions/pairing/transpiler", default-features = false } openvm-pairing-guest = { path = "extensions/pairing/guest", default-features = false } +openvm-memcpy-circuit = { path = "extensions/memcpy/circuit", default-features = false } +openvm-memcpy-transpiler = { path = "extensions/memcpy/transpiler", default-features = false } openvm-verify-stark = { path = "guest-libs/verify_stark", default-features = false } # Benchmarking diff --git a/benchmarks/execute/Cargo.toml b/benchmarks/execute/Cargo.toml index 5fcf58b1de..0abecbbf53 100644 --- a/benchmarks/execute/Cargo.toml +++ b/benchmarks/execute/Cargo.toml @@ -24,6 +24,8 @@ openvm-pairing-guest.workspace = true openvm-pairing-transpiler.workspace = true openvm-keccak256-circuit.workspace = true openvm-keccak256-transpiler.workspace = true +openvm-memcpy-circuit.workspace = true +openvm-memcpy-transpiler.workspace = true openvm-rv32im-circuit.workspace = true openvm-rv32im-transpiler.workspace = true openvm-sha256-circuit.workspace = true diff --git a/benchmarks/execute/benches/execute.rs b/benchmarks/execute/benches/execute.rs index 291140ceea..dbbccf9861 100644 --- a/benchmarks/execute/benches/execute.rs +++ b/benchmarks/execute/benches/execute.rs @@ -32,6 +32,8 @@ use openvm_ecc_circuit::{EccCpuProverExt, WeierstrassExtension, WeierstrassExten use openvm_ecc_transpiler::EccTranspilerExtension; use openvm_keccak256_circuit::{Keccak256, Keccak256CpuProverExt, Keccak256Executor}; use openvm_keccak256_transpiler::Keccak256TranspilerExtension; +use openvm_memcpy_circuit::{Memcpy, MemcpyCpuProverExt, MemcpyExecutor}; +use openvm_memcpy_transpiler::MemcpyTranspilerExtension; use openvm_native_circuit::{NativeCpuBuilder, NATIVE_MAX_TRACE_HEIGHTS}; use openvm_native_recursion::hints::Hintable; use openvm_pairing_circuit::{ @@ -107,6 +109,8 @@ pub struct ExecuteConfig { #[extension] pub keccak: Keccak256, #[extension] + pub memcpy: Memcpy, + #[extension] pub sha256: Sha256, #[extension] pub modular: ModularExtension, @@ -128,6 +132,7 @@ impl Default for ExecuteConfig { io: Rv32Io, bigint: Int256::default(), keccak: Keccak256, + memcpy: Memcpy, sha256: Sha256, modular: ModularExtension::new(vec![ bn_config.modulus.clone(), @@ -188,6 +193,11 @@ where &config.keccak, inventory, )?; + VmProverExtension::::extend_prover( + &MemcpyCpuProverExt, + &config.memcpy, + inventory, + )?; VmProverExtension::::extend_prover(&Sha2CpuProverExt, &config.sha256, inventory)?; VmProverExtension::::extend_prover( &AlgebraCpuProverExt, @@ -216,6 +226,7 @@ fn create_default_transpiler() -> Transpiler { .with_extension(Rv32MTranspilerExtension) .with_extension(Int256TranspilerExtension) .with_extension(Keccak256TranspilerExtension) + .with_extension(MemcpyTranspilerExtension) .with_extension(Sha256TranspilerExtension) .with_extension(ModularTranspilerExtension) .with_extension(Fp2TranspilerExtension) diff --git a/benchmarks/guest/fibonacci/openvm.toml b/benchmarks/guest/fibonacci/openvm.toml index 19a1e670e5..64b1f884d4 100644 --- a/benchmarks/guest/fibonacci/openvm.toml +++ b/benchmarks/guest/fibonacci/openvm.toml @@ -1,3 +1,4 @@ [app_vm_config.rv32i] [app_vm_config.rv32m] [app_vm_config.io] +[app_vm_config.memcpy] \ No newline at end of file diff --git a/benchmarks/guest/fibonacci/src/main.rs b/benchmarks/guest/fibonacci/src/main.rs index 158a3fa0ec..6b798ac752 100644 --- a/benchmarks/guest/fibonacci/src/main.rs +++ b/benchmarks/guest/fibonacci/src/main.rs @@ -1,14 +1,44 @@ -use openvm::io::{read, reveal_u32}; +use core::ptr; -pub fn main() { - let n: u64 = read(); - let mut a: u64 = 0; - let mut b: u64 = 1; - for _ in 0..n { - let c: u64 = a.wrapping_add(b); - a = b; - b = c; +openvm::entry!(main); + +/// Moves all the elements of `src` into `dst`, leaving `src` empty. +#[no_mangle] +pub fn append(dst: &mut [T], src: &mut [T], shift: usize) { + let src_len = src.len(); + let dst_len = dst.len(); + + unsafe { + // The call to add is always safe because `Vec` will never + // allocate more than `isize::MAX` bytes. + let dst_ptr = dst.as_mut_ptr().wrapping_add(shift); + let src_ptr = src.as_ptr(); + println!("dst_ptr: {:?}", dst_ptr); + println!("src_ptr: {:?}", src_ptr); + println!("src_len: {:?}", src_len); + + // The two regions cannot overlap because mutable references do + // not alias, and two different vectors cannot own the same + // memory. + ptr::copy_nonoverlapping(src_ptr, dst_ptr, src_len); } - reveal_u32(a as u32, 0); - reveal_u32((a >> 32) as u32, 1); } + +pub fn main() { + let mut a: [u8; 1000] = [1; 1000]; + let mut b: [u8; 500] = [2; 500]; + + let shift: usize = 0; + append(&mut a, &mut b, shift); + + for i in 0..1000 { + if i < shift || i >= shift + b.len() { + assert_eq!(a[i], 1); + } else { + assert_eq!(a[i], 2); + } + } + + println!("a: {:?}", a); + println!("b: {:?}", b); +} \ No newline at end of file diff --git a/crates/circuits/primitives/src/assert_less_than/mod.rs b/crates/circuits/primitives/src/assert_less_than/mod.rs index 53054c713a..1ec12a8fd5 100644 --- a/crates/circuits/primitives/src/assert_less_than/mod.rs +++ b/crates/circuits/primitives/src/assert_less_than/mod.rs @@ -3,7 +3,7 @@ use openvm_circuit_primitives_derive::AlignedBorrow; use openvm_stark_backend::{ interaction::InteractionBuilder, p3_air::AirBuilder, - p3_field::{Field, FieldAlgebra}, + p3_field::{Field, FieldAlgebra, PrimeField32}, }; use crate::{ @@ -58,6 +58,14 @@ pub struct LessThanAuxCols { pub lower_decomp: [T; AUX_LEN], } +impl Default for LessThanAuxCols { + fn default() -> Self { + Self { + lower_decomp: [F::ZERO; AUX_LEN], + } + } +} + /// This is intended for use as a **SubAir**, not as a standalone Air. /// /// This SubAir constrains that `x < y` when `count != 0`, assuming diff --git a/crates/sdk/Cargo.toml b/crates/sdk/Cargo.toml index a9bbc4108e..f5239aca0e 100644 --- a/crates/sdk/Cargo.toml +++ b/crates/sdk/Cargo.toml @@ -18,6 +18,8 @@ openvm-ecc-circuit = { workspace = true } openvm-ecc-transpiler = { workspace = true } openvm-keccak256-circuit = { workspace = true } openvm-keccak256-transpiler = { workspace = true } +openvm-memcpy-circuit = { workspace = true } +openvm-memcpy-transpiler = { workspace = true } openvm-sha256-circuit = { workspace = true } openvm-sha256-transpiler = { workspace = true } openvm-pairing-circuit = { workspace = true } diff --git a/crates/sdk/src/config/global.rs b/crates/sdk/src/config/global.rs index a54dc85be4..6c673849e3 100644 --- a/crates/sdk/src/config/global.rs +++ b/crates/sdk/src/config/global.rs @@ -18,6 +18,8 @@ use openvm_ecc_circuit::{ use openvm_ecc_transpiler::EccTranspilerExtension; use openvm_keccak256_circuit::{Keccak256, Keccak256CpuProverExt, Keccak256Executor}; use openvm_keccak256_transpiler::Keccak256TranspilerExtension; +use openvm_memcpy_circuit::{Memcpy, MemcpyCpuProverExt, MemcpyExecutor}; +use openvm_memcpy_transpiler::MemcpyTranspilerExtension; use openvm_native_circuit::{ CastFExtension, CastFExtensionExecutor, Native, NativeCpuProverExt, NativeExecutor, }; @@ -81,6 +83,7 @@ pub struct SdkVmConfig { pub rv32i: Option, pub io: Option, pub keccak: Option, + pub memcpy: Option, pub sha256: Option, pub native: Option, pub castf: Option, @@ -118,6 +121,7 @@ impl SdkVmConfig { .rv32m(Default::default()) .io(Default::default()) .keccak(Default::default()) + .memcpy(Default::default()) .sha256(Default::default()) .bigint(Default::default()) .modular(ModularExtension::new(vec![ @@ -199,6 +203,9 @@ impl TranspilerConfig for SdkVmConfig { if self.keccak.is_some() { transpiler = transpiler.with_extension(Keccak256TranspilerExtension); } + if self.memcpy.is_some() { + transpiler = transpiler.with_extension(MemcpyTranspilerExtension); + } if self.sha256.is_some() { transpiler = transpiler.with_extension(Sha256TranspilerExtension); } @@ -269,6 +276,7 @@ impl SdkVmConfig { let rv32i = config.rv32i.map(|_| Rv32I); let io = config.io.map(|_| Rv32Io); let keccak = config.keccak.map(|_| Keccak256); + let memcpy = config.memcpy.map(|_| Memcpy); let sha256 = config.sha256.map(|_| Sha256); let native = config.native.map(|_| Native); let castf = config.castf.map(|_| CastFExtension); @@ -284,6 +292,7 @@ impl SdkVmConfig { rv32i, io, keccak, + memcpy, sha256, native, castf, @@ -315,6 +324,8 @@ pub struct SdkVmConfigInner { pub io: Option, #[extension(executor = "Keccak256Executor")] pub keccak: Option, + #[extension(executor = "MemcpyExecutor")] + pub memcpy: Option, #[extension(executor = "Sha256Executor")] pub sha256: Option, #[extension(executor = "NativeExecutor")] @@ -392,6 +403,9 @@ where if let Some(keccak) = &config.keccak { VmProverExtension::::extend_prover(&Keccak256CpuProverExt, keccak, inventory)?; } + if let Some(memcpy) = &config.memcpy { + VmProverExtension::::extend_prover(&MemcpyCpuProverExt, memcpy, inventory)?; + } if let Some(sha256) = &config.sha256 { VmProverExtension::::extend_prover(&Sha2CpuProverExt, sha256, inventory)?; } @@ -566,6 +580,12 @@ impl From for UnitStruct { } } +impl From for UnitStruct { + fn from(_: Memcpy) -> Self { + UnitStruct {} + } +} + impl From for UnitStruct { fn from(_: Sha256) -> Self { UnitStruct {} @@ -592,6 +612,7 @@ struct SdkVmConfigWithDefaultDeser { pub rv32i: Option, pub io: Option, pub keccak: Option, + pub memcpy: Option, pub sha256: Option, pub native: Option, pub castf: Option, @@ -611,6 +632,7 @@ impl From for SdkVmConfig { rv32i: config.rv32i, io: config.io, keccak: config.keccak, + memcpy: config.memcpy, sha256: config.sha256, native: config.native, castf: config.castf, diff --git a/crates/sdk/src/config/openvm_standard.toml b/crates/sdk/src/config/openvm_standard.toml index f1f9267191..23c6a3d68a 100644 --- a/crates/sdk/src/config/openvm_standard.toml +++ b/crates/sdk/src/config/openvm_standard.toml @@ -3,6 +3,7 @@ [app_vm_config.io] [app_vm_config.keccak] +[app_vm_config.memcpy] [app_vm_config.sha256] [app_vm_config.bigint] diff --git a/crates/toolchain/openvm/src/memcpy.s b/crates/toolchain/openvm/src/memcpy.s index e0043ec220..a45a82d16a 100644 --- a/crates/toolchain/openvm/src/memcpy.s +++ b/crates/toolchain/openvm/src/memcpy.s @@ -205,6 +205,11 @@ .attribute 4, 16 .attribute 5, "rv32im" .file "musl_memcpy.c" + + # Define memcpy_loop macro for custom instruction (U-type) + .macro memcpy_loop shift + .word 0x00000072 | (\shift << 12) # opcode 0x72 + shift in immediate field (bits 12-31) + .endm .globl memcpy .p2align 2 .type memcpy,@function @@ -250,32 +255,8 @@ memcpy: sb a6, 2(a3) addi a2, a2, -3 addi a3, a4, 16 - li a4, 16 .LBBmemcpy0_9: - lw a6, -12(a3) - srli a5, a5, 24 - slli a7, a6, 8 - lw t0, -8(a3) - or a5, a7, a5 - sw a5, 0(a1) - srli a5, a6, 24 - slli a6, t0, 8 - lw a7, -4(a3) - or a5, a6, a5 - sw a5, 4(a1) - srli a6, t0, 24 - slli t0, a7, 8 - lw a5, 0(a3) - or a6, t0, a6 - sw a6, 8(a1) - srli a6, a7, 24 - slli a7, a5, 8 - or a6, a7, a6 - sw a6, 12(a1) - addi a1, a1, 16 - addi a2, a2, -16 - addi a3, a3, 16 - bltu a4, a2, .LBBmemcpy0_9 + memcpy_loop 1 addi a4, a3, -13 j .LBBmemcpy0_25 .LBBmemcpy0_11: @@ -286,20 +267,8 @@ memcpy: .LBBmemcpy0_12: li a1, 16 bltu a2, a1, .LBBmemcpy0_15 - li a1, 15 .LBBmemcpy0_14: - lw a5, 0(a4) - lw a6, 4(a4) - lw a7, 8(a4) - lw t0, 12(a4) - sw a5, 0(a3) - sw a6, 4(a3) - sw a7, 8(a3) - sw t0, 12(a3) - addi a4, a4, 16 - addi a2, a2, -16 - addi a3, a3, 16 - bltu a1, a2, .LBBmemcpy0_14 + memcpy_loop 0 .LBBmemcpy0_15: andi a1, a2, 8 beqz a1, .LBBmemcpy0_17 @@ -323,32 +292,8 @@ memcpy: sb a5, 0(a3) addi a2, a2, -1 addi a3, a4, 16 - li a4, 18 .LBBmemcpy0_20: - lw a6, -12(a3) - srli a5, a5, 8 - slli a7, a6, 24 - lw t0, -8(a3) - or a5, a7, a5 - sw a5, 0(a1) - srli a5, a6, 8 - slli a6, t0, 24 - lw a7, -4(a3) - or a5, a6, a5 - sw a5, 4(a1) - srli a6, t0, 8 - slli t0, a7, 24 - lw a5, 0(a3) - or a6, t0, a6 - sw a6, 8(a1) - srli a6, a7, 8 - slli a7, a5, 24 - or a6, a7, a6 - sw a6, 12(a1) - addi a1, a1, 16 - addi a2, a2, -16 - addi a3, a3, 16 - bltu a4, a2, .LBBmemcpy0_20 + memcpy_loop 3 addi a4, a3, -15 j .LBBmemcpy0_25 .LBBmemcpy0_22: @@ -359,32 +304,8 @@ memcpy: sb a6, 1(a3) addi a2, a2, -2 addi a3, a4, 16 - li a4, 17 .LBBmemcpy0_23: - lw a6, -12(a3) - srli a5, a5, 16 - slli a7, a6, 16 - lw t0, -8(a3) - or a5, a7, a5 - sw a5, 0(a1) - srli a5, a6, 16 - slli a6, t0, 16 - lw a7, -4(a3) - or a5, a6, a5 - sw a5, 4(a1) - srli a6, t0, 16 - slli t0, a7, 16 - lw a5, 0(a3) - or a6, t0, a6 - sw a6, 8(a1) - srli a6, a7, 16 - slli a7, a5, 16 - or a6, a7, a6 - sw a6, 12(a1) - addi a1, a1, 16 - addi a2, a2, -16 - addi a3, a3, 16 - bltu a4, a2, .LBBmemcpy0_23 + memcpy_loop 2 addi a4, a3, -14 .LBBmemcpy0_25: mv a3, a1 diff --git a/crates/vm/src/arch/testing/mod.rs b/crates/vm/src/arch/testing/mod.rs index 5293a0275a..91595c404e 100644 --- a/crates/vm/src/arch/testing/mod.rs +++ b/crates/vm/src/arch/testing/mod.rs @@ -29,6 +29,7 @@ pub const BITWISE_OP_LOOKUP_BUS: BusIndex = 9; pub const BYTE_XOR_BUS: BusIndex = 10; pub const RANGE_TUPLE_CHECKER_BUS: BusIndex = 11; pub const MEMORY_MERKLE_BUS: BusIndex = 12; +pub const MEMCPY_BUS: BusIndex = 13; pub const RANGE_CHECKER_BUS: BusIndex = 4; diff --git a/crates/vm/src/system/memory/offline_checker/columns.rs b/crates/vm/src/system/memory/offline_checker/columns.rs index ef9821f859..9225b813b2 100644 --- a/crates/vm/src/system/memory/offline_checker/columns.rs +++ b/crates/vm/src/system/memory/offline_checker/columns.rs @@ -26,6 +26,15 @@ impl MemoryBaseAuxCols { } } +impl Default for MemoryBaseAuxCols { + fn default() -> Self { + Self { + prev_timestamp: F::ZERO, + timestamp_lt_aux: LessThanAuxCols::default(), + } + } +} + #[repr(C)] #[derive(Clone, Copy, Debug, AlignedBorrow)] pub struct MemoryWriteAuxCols { @@ -43,6 +52,11 @@ impl MemoryWriteAuxCols { self.base } + #[inline(always)] + pub fn set_base(&mut self, base: MemoryBaseAuxCols) { + self.base = base; + } + #[inline(always)] pub fn prev_data(&self) -> &[T; N] { &self.prev_data @@ -80,6 +94,11 @@ impl MemoryReadAuxCols { self.base } + #[inline(always)] + pub fn set_base(&mut self, base: MemoryBaseAuxCols) { + self.base = base; + } + /// Sets the previous timestamp **without** updating the less than auxiliary columns. #[inline(always)] pub fn set_prev(&mut self, timestamp: F) { diff --git a/crates/vm/src/system/memory/offline_checker/mod.rs b/crates/vm/src/system/memory/offline_checker/mod.rs index 8b15328185..5ef5745d11 100644 --- a/crates/vm/src/system/memory/offline_checker/mod.rs +++ b/crates/vm/src/system/memory/offline_checker/mod.rs @@ -5,13 +5,42 @@ mod columns; pub use bridge::*; pub use bus::*; pub use columns::*; +use openvm_circuit_primitives::is_less_than::LessThanAuxCols; +use openvm_stark_backend::p3_field::PrimeField32; #[repr(C)] #[derive(Debug, Clone)] -pub struct MemoryReadAuxRecord { +pub struct MemoryBaseAuxRecord { pub prev_timestamp: u32, } +#[repr(C)] +#[derive(Debug, Clone, Default)] +pub struct MemoryExtendedAuxRecord { + pub prev_timestamp: u32, + pub timestamp_lt_aux: [u32; AUX_LEN], +} + +impl MemoryExtendedAuxRecord { + pub fn from_aux_cols(aux_cols: MemoryBaseAuxCols) -> Self { + Self { + prev_timestamp: aux_cols.prev_timestamp.as_canonical_u32(), + timestamp_lt_aux: aux_cols + .timestamp_lt_aux + .lower_decomp + .map(|x| x.as_canonical_u32()), + } + } + + pub fn to_aux_cols(&self, aux_cols: &mut MemoryBaseAuxCols) { + aux_cols.prev_timestamp = F::from_canonical_u32(self.prev_timestamp); + aux_cols.timestamp_lt_aux.lower_decomp = + self.timestamp_lt_aux.map(|x| F::from_canonical_u32(x)); + } +} + +pub type MemoryReadAuxRecord = MemoryBaseAuxRecord; + #[repr(C)] #[derive(Debug, Clone)] pub struct MemoryWriteAuxRecord { diff --git a/extensions/algebra/circuit/Cargo.toml b/extensions/algebra/circuit/Cargo.toml index 5f00cb2ea7..309e76b3bb 100644 --- a/extensions/algebra/circuit/Cargo.toml +++ b/extensions/algebra/circuit/Cargo.toml @@ -17,6 +17,7 @@ openvm-stark-backend = { workspace = true } openvm-mod-circuit-builder = { workspace = true } openvm-stark-sdk = { workspace = true } openvm-rv32im-circuit = { workspace = true } +openvm-memcpy-circuit = { workspace = true } openvm-rv32-adapters = { workspace = true } openvm-algebra-transpiler = { workspace = true } openvm-cuda-backend = { workspace = true, optional = true } diff --git a/extensions/algebra/circuit/src/extension/mod.rs b/extensions/algebra/circuit/src/extension/mod.rs index 07384e56d2..fe16df067b 100644 --- a/extensions/algebra/circuit/src/extension/mod.rs +++ b/extensions/algebra/circuit/src/extension/mod.rs @@ -9,6 +9,7 @@ use openvm_circuit::{ system::{SystemChipInventory, SystemCpuBuilder, SystemExecutor}, }; use openvm_circuit_derive::VmConfig; +use openvm_memcpy_circuit::{Memcpy, MemcpyCpuProverExt, MemcpyExecutor}; use openvm_rv32im_circuit::{ Rv32I, Rv32IExecutor, Rv32ImCpuProverExt, Rv32Io, Rv32IoExecutor, Rv32M, Rv32MExecutor, }; @@ -57,6 +58,8 @@ pub struct Rv32ModularConfig { pub io: Rv32Io, #[extension] pub modular: ModularExtension, + #[extension] + pub memcpy: Memcpy, } impl InitFileGenerator for Rv32ModularConfig { @@ -76,6 +79,7 @@ impl Rv32ModularConfig { mul: Default::default(), io: Default::default(), modular: ModularExtension::new(moduli), + memcpy: Memcpy, } } } @@ -143,6 +147,11 @@ where &config.modular, inventory, )?; + VmProverExtension::::extend_prover( + &MemcpyCpuProverExt, + &config.memcpy, + inventory, + )?; Ok(chip_complex) } } diff --git a/extensions/algebra/tests/Cargo.toml b/extensions/algebra/tests/Cargo.toml index 0d748f3d88..e73112735d 100644 --- a/extensions/algebra/tests/Cargo.toml +++ b/extensions/algebra/tests/Cargo.toml @@ -15,6 +15,7 @@ openvm-transpiler.workspace = true openvm-algebra-transpiler.workspace = true openvm-algebra-circuit.workspace = true openvm-rv32im-transpiler.workspace = true +openvm-memcpy-transpiler.workspace = true openvm-toolchain-tests = { path = "../../../crates/toolchain/tests" } openvm-ecc-circuit.workspace = true eyre.workspace = true diff --git a/extensions/algebra/tests/src/lib.rs b/extensions/algebra/tests/src/lib.rs index c931dce496..1a56370953 100644 --- a/extensions/algebra/tests/src/lib.rs +++ b/extensions/algebra/tests/src/lib.rs @@ -12,6 +12,7 @@ mod tests { use openvm_circuit::utils::{air_test, test_system_config}; use openvm_ecc_circuit::SECP256K1_CONFIG; use openvm_instructions::exe::VmExe; + use openvm_memcpy_transpiler::MemcpyTranspilerExtension; use openvm_rv32im_transpiler::{ Rv32ITranspilerExtension, Rv32IoTranspilerExtension, Rv32MTranspilerExtension, }; @@ -49,7 +50,8 @@ mod tests { .with_extension(Rv32ITranspilerExtension) .with_extension(Rv32MTranspilerExtension) .with_extension(Rv32IoTranspilerExtension) - .with_extension(ModularTranspilerExtension), + .with_extension(ModularTranspilerExtension) + .with_extension(MemcpyTranspilerExtension), )?; air_test(Rv32ModularBuilder, config, openvm_exe); @@ -66,7 +68,8 @@ mod tests { .with_extension(Rv32ITranspilerExtension) .with_extension(Rv32MTranspilerExtension) .with_extension(Rv32IoTranspilerExtension) - .with_extension(ModularTranspilerExtension), + .with_extension(ModularTranspilerExtension) + .with_extension(MemcpyTranspilerExtension), )?; air_test(Rv32ModularBuilder, config, openvm_exe); Ok(()) @@ -93,7 +96,8 @@ mod tests { .with_extension(Rv32MTranspilerExtension) .with_extension(Rv32IoTranspilerExtension) .with_extension(Fp2TranspilerExtension) - .with_extension(ModularTranspilerExtension), + .with_extension(ModularTranspilerExtension) + .with_extension(MemcpyTranspilerExtension), )?; air_test(Rv32ModularWithFp2Builder, config, openvm_exe); Ok(()) @@ -125,7 +129,8 @@ mod tests { .with_extension(Rv32MTranspilerExtension) .with_extension(Rv32IoTranspilerExtension) .with_extension(Fp2TranspilerExtension) - .with_extension(ModularTranspilerExtension), + .with_extension(ModularTranspilerExtension) + .with_extension(MemcpyTranspilerExtension), )?; air_test(Rv32ModularWithFp2Builder, config, openvm_exe); Ok(()) @@ -145,7 +150,8 @@ mod tests { .with_extension(Rv32MTranspilerExtension) .with_extension(Rv32IoTranspilerExtension) .with_extension(Fp2TranspilerExtension) - .with_extension(ModularTranspilerExtension), + .with_extension(ModularTranspilerExtension) + .with_extension(MemcpyTranspilerExtension), )?; air_test(Rv32ModularWithFp2Builder, config, openvm_exe); Ok(()) @@ -173,7 +179,8 @@ mod tests { .with_extension(Rv32MTranspilerExtension) .with_extension(Rv32IoTranspilerExtension) .with_extension(Fp2TranspilerExtension) - .with_extension(ModularTranspilerExtension), + .with_extension(ModularTranspilerExtension) + .with_extension(MemcpyTranspilerExtension), ) .unwrap(); air_test(Rv32ModularBuilder, config, openvm_exe); @@ -189,7 +196,8 @@ mod tests { .with_extension(Rv32ITranspilerExtension) .with_extension(Rv32MTranspilerExtension) .with_extension(Rv32IoTranspilerExtension) - .with_extension(ModularTranspilerExtension), + .with_extension(ModularTranspilerExtension) + .with_extension(MemcpyTranspilerExtension), )?; air_test(Rv32ModularBuilder, config, openvm_exe); Ok(()) diff --git a/extensions/bigint/circuit/Cargo.toml b/extensions/bigint/circuit/Cargo.toml index c431fae6e5..1398a1fa10 100644 --- a/extensions/bigint/circuit/Cargo.toml +++ b/extensions/bigint/circuit/Cargo.toml @@ -21,6 +21,7 @@ openvm-rv32im-circuit = { workspace = true } openvm-rv32-adapters = { workspace = true } openvm-bigint-transpiler = { workspace = true } openvm-rv32im-transpiler = { workspace = true } +openvm-memcpy-circuit = { workspace = true } derive-new.workspace = true derive_more = { workspace = true, features = ["from"] } diff --git a/extensions/bigint/circuit/src/extension/mod.rs b/extensions/bigint/circuit/src/extension/mod.rs index 1725a4860d..cff327da98 100644 --- a/extensions/bigint/circuit/src/extension/mod.rs +++ b/extensions/bigint/circuit/src/extension/mod.rs @@ -25,6 +25,7 @@ use openvm_circuit_primitives::{ }, }; use openvm_instructions::{program::DEFAULT_PC_STEP, LocalOpcode}; +use openvm_memcpy_circuit::MemcpyCpuProverExt; use openvm_rv32im_circuit::Rv32ImCpuProverExt; use openvm_stark_backend::{ config::{StarkGenericConfig, Val}, @@ -373,6 +374,11 @@ where &config.bigint, inventory, )?; + VmProverExtension::::extend_prover( + &MemcpyCpuProverExt, + &config.memcpy, + inventory, + )?; Ok(chip_complex) } } diff --git a/extensions/bigint/circuit/src/lib.rs b/extensions/bigint/circuit/src/lib.rs index 69e8ae9b58..c9f6eb38a6 100644 --- a/extensions/bigint/circuit/src/lib.rs +++ b/extensions/bigint/circuit/src/lib.rs @@ -6,6 +6,7 @@ use openvm_circuit::{ system::SystemExecutor, }; use openvm_circuit_derive::{PreflightExecutor, VmConfig}; +use openvm_memcpy_circuit::{Memcpy, MemcpyExecutor}; use openvm_rv32_adapters::{ Rv32HeapAdapterAir, Rv32HeapAdapterExecutor, Rv32HeapAdapterFiller, Rv32HeapBranchAdapterAir, Rv32HeapBranchAdapterExecutor, Rv32HeapBranchAdapterFiller, @@ -175,6 +176,8 @@ pub struct Int256Rv32Config { pub io: Rv32Io, #[extension] pub bigint: Int256, + #[extension] + pub memcpy: Memcpy, } // Default implementation uses no init file @@ -188,6 +191,7 @@ impl Default for Int256Rv32Config { rv32m: Rv32M::default(), io: Rv32Io, bigint: Int256::default(), + memcpy: Memcpy, } } } diff --git a/extensions/ecc/tests/Cargo.toml b/extensions/ecc/tests/Cargo.toml index 42a0bee912..73adc0cb24 100644 --- a/extensions/ecc/tests/Cargo.toml +++ b/extensions/ecc/tests/Cargo.toml @@ -15,6 +15,7 @@ openvm-algebra-transpiler.workspace = true openvm-ecc-transpiler.workspace = true openvm-ecc-circuit.workspace = true openvm-rv32im-transpiler.workspace = true +openvm-memcpy-transpiler.workspace = true openvm-toolchain-tests = { path = "../../../crates/toolchain/tests" } openvm-sdk.workspace = true serde.workspace = true diff --git a/extensions/ecc/tests/src/lib.rs b/extensions/ecc/tests/src/lib.rs index 2a10837dd5..a2593aa53d 100644 --- a/extensions/ecc/tests/src/lib.rs +++ b/extensions/ecc/tests/src/lib.rs @@ -17,6 +17,7 @@ mod tests { CurveConfig, Rv32WeierstrassBuilder, Rv32WeierstrassConfig, P256_CONFIG, SECP256K1_CONFIG, }; use openvm_ecc_transpiler::EccTranspilerExtension; + use openvm_memcpy_transpiler::MemcpyTranspilerExtension; use openvm_rv32im_transpiler::{ Rv32ITranspilerExtension, Rv32IoTranspilerExtension, Rv32MTranspilerExtension, }; @@ -60,7 +61,8 @@ mod tests { .with_extension(Rv32MTranspilerExtension) .with_extension(Rv32IoTranspilerExtension) .with_extension(EccTranspilerExtension) - .with_extension(ModularTranspilerExtension), + .with_extension(ModularTranspilerExtension) + .with_extension(MemcpyTranspilerExtension), )?; air_test(Rv32WeierstrassBuilder, config, openvm_exe); Ok(()) @@ -82,7 +84,8 @@ mod tests { .with_extension(Rv32MTranspilerExtension) .with_extension(Rv32IoTranspilerExtension) .with_extension(EccTranspilerExtension) - .with_extension(ModularTranspilerExtension), + .with_extension(ModularTranspilerExtension) + .with_extension(MemcpyTranspilerExtension), )?; air_test(Rv32WeierstrassBuilder, config, openvm_exe); Ok(()) @@ -105,7 +108,8 @@ mod tests { .with_extension(Rv32MTranspilerExtension) .with_extension(Rv32IoTranspilerExtension) .with_extension(EccTranspilerExtension) - .with_extension(ModularTranspilerExtension), + .with_extension(ModularTranspilerExtension) + .with_extension(MemcpyTranspilerExtension), )?; air_test(Rv32WeierstrassBuilder, config, openvm_exe); Ok(()) @@ -152,7 +156,8 @@ mod tests { .with_extension(Rv32MTranspilerExtension) .with_extension(Rv32IoTranspilerExtension) .with_extension(EccTranspilerExtension) - .with_extension(ModularTranspilerExtension), + .with_extension(ModularTranspilerExtension) + .with_extension(MemcpyTranspilerExtension), )?; let p = Secp256k1Affine::generator(); @@ -265,7 +270,8 @@ mod tests { .with_extension(Rv32MTranspilerExtension) .with_extension(Rv32IoTranspilerExtension) .with_extension(EccTranspilerExtension) - .with_extension(ModularTranspilerExtension), + .with_extension(ModularTranspilerExtension) + .with_extension(MemcpyTranspilerExtension), ) .unwrap(); let config = diff --git a/extensions/keccak256/circuit/Cargo.toml b/extensions/keccak256/circuit/Cargo.toml index 24e10deb10..28998f2f59 100644 --- a/extensions/keccak256/circuit/Cargo.toml +++ b/extensions/keccak256/circuit/Cargo.toml @@ -18,6 +18,7 @@ openvm-circuit = { workspace = true } openvm-circuit-derive = { workspace = true } openvm-instructions = { workspace = true } openvm-rv32im-circuit = { workspace = true } +openvm-memcpy-circuit = { workspace = true } openvm-keccak256-transpiler = { workspace = true } p3-keccak-air = { workspace = true } diff --git a/extensions/keccak256/circuit/src/extension/mod.rs b/extensions/keccak256/circuit/src/extension/mod.rs index 9f6e55a540..7a07080f0f 100644 --- a/extensions/keccak256/circuit/src/extension/mod.rs +++ b/extensions/keccak256/circuit/src/extension/mod.rs @@ -20,6 +20,7 @@ use openvm_circuit_primitives::bitwise_op_lookup::{ }; use openvm_instructions::*; use openvm_keccak256_transpiler::Rv32KeccakOpcode; +use openvm_memcpy_circuit::{Memcpy, MemcpyCpuProverExt, MemcpyExecutor}; use openvm_rv32im_circuit::{ Rv32I, Rv32IExecutor, Rv32ImCpuProverExt, Rv32Io, Rv32IoExecutor, Rv32M, Rv32MExecutor, }; @@ -62,6 +63,8 @@ pub struct Keccak256Rv32Config { pub io: Rv32Io, #[extension] pub keccak: Keccak256, + #[extension] + pub memcpy: Memcpy, } impl Default for Keccak256Rv32Config { @@ -72,6 +75,7 @@ impl Default for Keccak256Rv32Config { rv32m: Rv32M::default(), io: Rv32Io, keccak: Keccak256, + memcpy: Memcpy, } } } @@ -111,6 +115,11 @@ where &config.keccak, inventory, )?; + VmProverExtension::::extend_prover( + &MemcpyCpuProverExt, + &config.memcpy, + inventory, + )?; Ok(chip_complex) } } diff --git a/extensions/memcpy/README.md b/extensions/memcpy/README.md new file mode 100644 index 0000000000..b85fe256b5 --- /dev/null +++ b/extensions/memcpy/README.md @@ -0,0 +1,83 @@ +# OpenVM Memcpy Extension + +This extension provides a custom RISC-V instruction `memcpy_loop` that optimizes memory copy operations by handling different alignment shifts efficiently. + +## Custom Instruction: `memcpy_loop shift` + +### Format +``` +memcpy_loop shift +``` + +Where `shift` is an immediate value (0, 1, 2, or 3) representing the byte alignment shift. + + +### Usage +The `memcpy_loop` instruction is designed to replace repetitive shift-handling code in memcpy implementations. Instead of having separate code blocks for each shift value, you can use a single instruction: + +```assembly +memcpy_loop 1 # Handles shift=1 case +``` + +Note that you must define `memcpy_loop` before using it. For example, in [memcpy.s](../../crates/toolchain/openvm/src/memcpy.s) it is defined at the beginning of the assembly code as follows: +```assembly +.macro memcpy_loop shift + .word 0x00000072 | (\shift << 12) # opcode 0x72 + shift in immediate field (bits 12-31) +``` + +### Benefits +1. **Code Size Reduction**: Eliminates repetitive shift-handling code +2. **Performance**: Optimized implementation in the circuit layer +3. **Maintainability**: Single instruction handles all shift cases +4. **Verification**: Zero-knowledge proof ensures correct execution + +## Implementation Details + +### Circuit Layer +The instruction is implemented in the `MemcpyIterationAir` circuit which: +- Reads 4 words (16 bytes) from memory +- Applies the specified shift to combine words +- Writes the result to the destination +- Handles all shift values (0, 1, 2, 3) efficiently + +### Transpiler Extension +The `MemcpyTranspilerExtension` translates the RISC-V instruction into OpenVM's internal format: +- Parses I-type instruction format +- Validates shift value (0-3) +- Converts to OpenVM instruction with shift as operand + +### Example Usage +See `example_memcpy_optimized.s` for a complete example showing how to use the custom instruction to optimize a memcpy implementation. + +## Building and Testing + +### Compilation +```bash +# Build the extension +cargo build --package openvm-memcpy-circuit --package openvm-memcpy-transpiler + +# Check for compilation errors +cargo check --package openvm-memcpy-circuit --package openvm-memcpy-transpiler +``` + +### Integration +To use this extension in your OpenVM project: + +1. Add the transpiler extension to your OpenVM configuration +2. Use the `memcpy_loop` instruction in your RISC-V assembly +3. The circuit will handle the execution and verification + +## Architecture + +``` +RISC-V Assembly → Transpiler Extension → OpenVM Instruction → MemcpyIterationAir → Execution +``` + +The extension provides: +- **Transpiler**: `extensions/memcpy/transpiler/` - Translates RISC-V to OpenVM +- **Circuit**: `extensions/memcpy/circuit/` - Implements the instruction logic + + +# References + +- Official Keccak [spec summary](https://keccak.team/keccak_specs_summary.html) diff --git a/extensions/memcpy/circuit/Cargo.toml b/extensions/memcpy/circuit/Cargo.toml new file mode 100644 index 0000000000..c5c4034ff8 --- /dev/null +++ b/extensions/memcpy/circuit/Cargo.toml @@ -0,0 +1,24 @@ +[package] +name = "openvm-memcpy-circuit" +description = "OpenVM circuit extension for memcpy" +version.workspace = true +authors.workspace = true +edition.workspace = true +homepage.workspace = true +repository.workspace = true + +[dependencies] +openvm-circuit = { workspace = true } +openvm-circuit-primitives = { workspace = true } +openvm-circuit-primitives-derive = { workspace = true } +openvm-circuit-derive = { workspace = true } +openvm-instructions = { workspace = true } +openvm-stark-backend = { workspace = true } +openvm-memcpy-transpiler = { path = "../transpiler" } +openvm-rv32im-transpiler = { workspace = true } + +derive-new.workspace = true +derive_more = { workspace = true, features = ["from"] } +serde.workspace = true +strum = { workspace = true } +tracing.workspace = true diff --git a/extensions/memcpy/circuit/src/bus.rs b/extensions/memcpy/circuit/src/bus.rs new file mode 100644 index 0000000000..83393f1aa4 --- /dev/null +++ b/extensions/memcpy/circuit/src/bus.rs @@ -0,0 +1,100 @@ +use std::iter; + +use openvm_stark_backend::{ + interaction::{BusIndex, InteractionBuilder, PermutationCheckBus}, + p3_field::FieldAlgebra, +}; + +#[derive(Clone, Copy, Debug, PartialEq, Eq)] +pub struct MemcpyBus { + pub inner: PermutationCheckBus, +} + +impl MemcpyBus { + pub const fn new(index: BusIndex) -> Self { + Self { + inner: PermutationCheckBus::new(index), + } + } +} + +impl MemcpyBus { + #[inline(always)] + pub fn index(&self) -> BusIndex { + self.inner.index + } + + pub fn send( + &self, + timestamp: impl Into, + dest: impl Into, + source: impl Into, + len: impl Into, + shift: impl Into, + ) -> MemcpyBusInteraction { + self.push(true, timestamp, dest, source, len, shift) + } + + pub fn receive( + &self, + timestamp: impl Into, + dest: impl Into, + source: impl Into, + len: impl Into, + shift: impl Into, + ) -> MemcpyBusInteraction { + self.push(false, timestamp, dest, source, len, shift) + } + + fn push( + &self, + is_send: bool, + timestamp: impl Into, + dest: impl Into, + source: impl Into, + len: impl Into, + shift: impl Into, + ) -> MemcpyBusInteraction { + MemcpyBusInteraction { + bus: self.inner, + is_send, + timestamp: timestamp.into(), + dest: dest.into(), + source: source.into(), + len: len.into(), + shift: shift.into(), + } + } +} + +#[derive(Clone, Debug)] +pub struct MemcpyBusInteraction { + pub bus: PermutationCheckBus, + pub is_send: bool, + pub timestamp: T, + pub dest: T, + pub source: T, + pub len: T, + pub shift: T, +} + +impl MemcpyBusInteraction { + pub fn eval(self, builder: &mut AB, direction: impl Into) + where + AB: InteractionBuilder, + { + let fields = iter::empty() + .chain(iter::once(self.timestamp)) + .chain(iter::once(self.dest)) + .chain(iter::once(self.source)) + .chain(iter::once(self.len)) + .chain(iter::once(self.shift)); + + if self.is_send { + self.bus.interact(builder, fields, direction); + } else { + self.bus + .interact(builder, fields, AB::Expr::NEG_ONE * direction.into()); + } + } +} diff --git a/extensions/memcpy/circuit/src/extension.rs b/extensions/memcpy/circuit/src/extension.rs new file mode 100644 index 0000000000..5bd3db9b8b --- /dev/null +++ b/extensions/memcpy/circuit/src/extension.rs @@ -0,0 +1,140 @@ +use std::{result::Result, sync::Arc}; + +use bus::MemcpyBus; +use derive_more::derive::From; +use openvm_circuit::{ + arch::{ + AirInventory, AirInventoryError, ChipInventory, ChipInventoryError, ExecutionBridge, + ExecutorInventoryBuilder, ExecutorInventoryError, RowMajorMatrixArena, VmCircuitExtension, + VmExecutionExtension, VmProverExtension, + }, + system::{memory::SharedMemoryHelper, SystemPort}, +}; +use openvm_circuit_derive::{AnyEnum, Executor, MeteredExecutor, PreflightExecutor}; +use openvm_instructions::*; +use openvm_memcpy_transpiler::Rv32MemcpyOpcode; +use openvm_stark_backend::{ + config::{StarkGenericConfig, Val}, + engine::StarkEngine, + p3_field::PrimeField32, + prover::cpu::{CpuBackend, CpuDevice}, +}; +use serde::{Deserialize, Serialize}; +use strum::IntoEnumIterator; + +use crate::*; + +// =================================== VM Extension Implementation ================================= +#[derive(Clone, Copy, Debug, Default, Serialize, Deserialize)] +pub struct Memcpy; + +#[derive(Clone, From, AnyEnum, Executor, MeteredExecutor, PreflightExecutor)] +pub enum MemcpyExecutor { + MemcpyLoop(MemcpyIterExecutor), +} + +impl VmExecutionExtension for Memcpy { + type Executor = MemcpyExecutor; + + fn extend_execution( + &self, + inventory: &mut ExecutorInventoryBuilder, + ) -> Result<(), ExecutorInventoryError> { + let memcpy_iter = MemcpyIterExecutor::new(Rv32MemcpyOpcode::CLASS_OFFSET); + + inventory.add_executor( + memcpy_iter, + Rv32MemcpyOpcode::iter().map(|x| x.global_opcode()), + )?; + + Ok(()) + } +} + +impl VmCircuitExtension for Memcpy { + fn extend_circuit(&self, inventory: &mut AirInventory) -> Result<(), AirInventoryError> { + let SystemPort { + execution_bus, + program_bus, + memory_bridge, + } = inventory.system().port(); + + let execution_bridge = ExecutionBridge::new(execution_bus, program_bus); + let range_bus = inventory.range_checker().bus; + let pointer_max_bits = inventory.pointer_max_bits(); + + let memcpy_bus = MemcpyBus::new(inventory.new_bus_idx()); + + let memcpy_loop = MemcpyLoopAir::new( + memory_bridge, + execution_bridge, + range_bus, + memcpy_bus, + pointer_max_bits, + Rv32MemcpyOpcode::CLASS_OFFSET, + ); + inventory.add_air(memcpy_loop); + + let memcpy_iter = + MemcpyIterAir::new(memory_bridge, range_bus, memcpy_bus, pointer_max_bits); + inventory.add_air(memcpy_iter); + + Ok(()) + } +} + +pub struct MemcpyCpuProverExt; +// This implementation is specific to CpuBackend because the lookup chips (VariableRangeChecker) +// are specific to CpuBackend. +impl VmProverExtension for MemcpyCpuProverExt +where + SC: StarkGenericConfig, + E: StarkEngine, PD = CpuDevice>, + RA: RowMajorMatrixArena>, + Val: PrimeField32, +{ + fn extend_prover( + &self, + _: &Memcpy, + inventory: &mut ChipInventory>, + ) -> Result<(), ChipInventoryError> { + let range_checker = inventory.range_checker()?.clone(); + let timestamp_max_bits = inventory.timestamp_max_bits(); + let pointer_max_bits = inventory.airs().pointer_max_bits(); + let mem_helper = SharedMemoryHelper::new(range_checker.clone(), timestamp_max_bits); + let range_bus = inventory.airs().range_checker().bus; + let memcpy_bus = inventory + .airs() + .find_air::() + .next() + .unwrap() + .memcpy_bus; + + let memcpy_loop_chip = Arc::new(MemcpyLoopChip::new( + inventory.airs().system().port(), + range_bus, + memcpy_bus, + Rv32MemcpyOpcode::CLASS_OFFSET, + pointer_max_bits, + range_checker.clone(), + )); + + let memcpy_iter_chip = MemcpyIterChip::new( + MemcpyIterFiller::new( + pointer_max_bits, + range_checker.clone(), + memcpy_loop_chip.clone(), + ), + mem_helper.clone(), + ); + // Add MemcpyLoop chip + inventory.next_air::()?; + inventory.add_periphery_chip(memcpy_loop_chip); + + // Add MemcpyIter chip + inventory.next_air::()?; + inventory.add_executor_chip(memcpy_iter_chip); + + Ok(()) + } +} diff --git a/extensions/memcpy/circuit/src/iteration.rs b/extensions/memcpy/circuit/src/iteration.rs new file mode 100644 index 0000000000..38a31e362d --- /dev/null +++ b/extensions/memcpy/circuit/src/iteration.rs @@ -0,0 +1,1040 @@ +use std::{ + array, + borrow::{Borrow, BorrowMut}, + mem::size_of, + sync::Arc, +}; + +use openvm_circuit::{ + arch::{ + get_record_from_slice, CustomBorrow, E2PreCompute, ExecuteFunc, ExecutionCtxTrait, + ExecutionError, Executor, MeteredExecutionCtxTrait, MeteredExecutor, MultiRowLayout, + MultiRowMetadata, PreflightExecutor, RecordArena, SizedRecord, StaticProgramError, + TraceFiller, VmChipWrapper, VmExecState, VmStateMut, + }, + system::memory::{ + offline_checker::{ + MemoryBaseAuxRecord, MemoryBridge, MemoryReadAuxCols, MemoryReadAuxRecord, + MemoryWriteAuxCols, MemoryWriteBytesAuxRecord, + }, + online::{GuestMemory, TracingMemory}, + MemoryAddress, MemoryAuxColsFactory, POINTER_MAX_BITS, + }, +}; +use openvm_circuit_primitives::{ + utils::{and, not, or, select}, + var_range::{SharedVariableRangeCheckerChip, VariableRangeCheckerBus}, + AlignedBytesBorrow, +}; +use openvm_circuit_primitives_derive::AlignedBorrow; +use openvm_instructions::{ + instruction::Instruction, + program::DEFAULT_PC_STEP, + riscv::{RV32_MEMORY_AS, RV32_REGISTER_AS}, + LocalOpcode, +}; +use openvm_memcpy_transpiler::Rv32MemcpyOpcode; +use openvm_stark_backend::{ + interaction::InteractionBuilder, + p3_air::{Air, AirBuilder, BaseAir}, + p3_field::{Field, FieldAlgebra, PrimeField32}, + p3_matrix::{dense::RowMajorMatrix, Matrix}, + p3_maybe_rayon::prelude::*, + rap::{BaseAirWithPublicValues, PartitionedBaseAir}, +}; + +use crate::{ + bus::MemcpyBus, read_rv32_register, tracing_read, tracing_write, MemcpyLoopChip, A1_REGISTER_PTR, A2_REGISTER_PTR, A3_REGISTER_PTR, A4_REGISTER_PTR +}; +// Import constants from lib.rs +use crate::{MEMCPY_LOOP_LIMB_BITS, MEMCPY_LOOP_NUM_LIMBS}; + +#[repr(C)] +#[derive(AlignedBorrow, Debug)] +pub struct MemcpyIterCols { + pub timestamp: T, + pub dest: T, + pub source: T, + pub len: [T; 2], + // 0: [0, 0, 0], 1: [1, 0, 0], 2: [0, 1, 0], 3: [0, 0, 1] + pub shift: [T; 3], + pub is_valid: T, + pub is_valid_not_start: T, + // This should be 0 if is_valid = 0. We use this to determine whether we need ro read data_4. + pub is_shift_non_zero_or_not_start: T, + // -1 for the first iteration, 1 for the last iteration, 0 for the middle iterations + pub is_boundary: T, + pub data_1: [T; MEMCPY_LOOP_NUM_LIMBS], + pub data_2: [T; MEMCPY_LOOP_NUM_LIMBS], + pub data_3: [T; MEMCPY_LOOP_NUM_LIMBS], + pub data_4: [T; MEMCPY_LOOP_NUM_LIMBS], + pub read_aux: [MemoryReadAuxCols; 4], + pub write_aux: [MemoryWriteAuxCols; 4], +} + +pub const NUM_MEMCPY_ITER_COLS: usize = size_of::>(); + +#[derive(Copy, Clone, Debug, derive_new::new)] +pub struct MemcpyIterAir { + pub memory_bridge: MemoryBridge, + pub range_bus: VariableRangeCheckerBus, + pub memcpy_bus: MemcpyBus, + pub pointer_max_bits: usize, +} + +impl BaseAir for MemcpyIterAir { + fn width(&self) -> usize { + MemcpyIterCols::::width() + } +} + +impl BaseAirWithPublicValues for MemcpyIterAir {} +impl PartitionedBaseAir for MemcpyIterAir {} + +impl Air for MemcpyIterAir { + fn eval(&self, builder: &mut AB) { + let main = builder.main(); + let (prev, local) = (main.row_slice(0), main.row_slice(1)); + let prev: &MemcpyIterCols = (*prev).borrow(); + let local: &MemcpyIterCols = (*local).borrow(); + + let timestamp: AB::Var = local.timestamp; + let mut timestamp_delta: AB::Expr = AB::Expr::ZERO; + let mut timestamp_pp = |timestamp_increase_value: AB::Var| { + timestamp_delta += timestamp_increase_value.into(); + timestamp + timestamp_delta.clone() - timestamp_increase_value.clone() + }; + + let shift = local.shift.iter().enumerate().fold(AB::Expr::ZERO, |acc, (i, x)| { + acc + (*x) * AB::Expr::from_canonical_u32(i as u32 + 1) + }); + let is_shift_non_zero = local.shift.iter().fold(AB::Expr::ZERO, |acc, x| { + acc + (*x) + }); + let is_shift_zero = not::(is_shift_non_zero.clone()); + let is_shift_one = local.shift[0]; + let is_shift_two = local.shift[1]; + let is_shift_three = local.shift[2]; + + let is_end = + (local.is_boundary + AB::Expr::ONE) * local.is_boundary * (AB::F::TWO).inverse(); + let is_not_start = (local.is_boundary + AB::Expr::ONE) + * (AB::Expr::TWO - local.is_boundary) + * (AB::F::TWO).inverse(); + let prev_is_not_end = not::( + (prev.is_boundary + AB::Expr::ONE) * prev.is_boundary * (AB::F::TWO).inverse(), + ); + + let len = local.len[0] + + local.len[1] * AB::Expr::from_canonical_u32(1 << (2 * MEMCPY_LOOP_LIMB_BITS)); + let prev_len = prev.len[0] + + prev.len[1] * AB::Expr::from_canonical_u32(1 << (2 * MEMCPY_LOOP_LIMB_BITS)); + + // write_data = + // (local.data_1[shift..4], prev.data_4[0..shift]), + // (local.data_2[shift..4], local.data_1[0..shift]), + // (local.data_3[shift..4], local.data_2[0..shift]), + // (local.data_4[shift..4], local.data_3[0..shift]) + let write_data_pairs = [ + (prev.data_4, local.data_1), + (local.data_1, local.data_2), + (local.data_2, local.data_3), + (local.data_3, local.data_4), + ]; + + let write_data = write_data_pairs + .iter() + .map(|(prev_data, next_data)| { + array::from_fn::<_, MEMCPY_LOOP_NUM_LIMBS, _>(|i| { + is_shift_zero.clone() * (next_data[i]) + + is_shift_one.clone() + * (if i < 3 { + next_data[i + 1] + } else { + prev_data[i - 3] + }) + + is_shift_two.clone() + * (if i < 2 { + next_data[i + 2] + } else { + prev_data[i - 2] + }) + + is_shift_three.clone() + * (if i < 1 { + next_data[i + 3] + } else { + prev_data[i - 1] + }) + }) + }) + .collect::>(); + + builder.assert_bool(local.is_valid); + local.shift.iter().for_each(|x| builder.assert_bool(*x)); + builder.assert_bool(is_shift_non_zero.clone()); + builder.assert_bool(local.is_valid_not_start); + builder.assert_bool(local.is_shift_non_zero_or_not_start); + // is_boundary is either -1, 0 or 1 + builder.assert_tern(local.is_boundary + AB::Expr::ONE); + + // is_valid_not_start = is_valid and is_not_start: + builder.assert_eq( + local.is_valid_not_start, + and::(local.is_valid, is_not_start), + ); + + // is_shift_non_zero_or_not_start is correct + builder.assert_eq(local.is_shift_non_zero_or_not_start, or::(is_shift_non_zero.clone(), local.is_valid_not_start)); + + // if !is_valid, then is_boundary = 0, shift = 0 (we will use this assumption later) + let mut is_not_valid_when = builder.when(not::(local.is_valid)); + is_not_valid_when.assert_zero(local.is_boundary); + is_not_valid_when.assert_zero(shift.clone()); + + // if is_valid_not_start, then len = prev_len - 16, source = prev_source + 16, + // and dest = prev_dest + 16, shift = prev_shift + let mut is_valid_not_start_when = builder.when(local.is_valid_not_start); + is_valid_not_start_when + .assert_eq(len.clone(), prev_len - AB::Expr::from_canonical_u32(16)); + is_valid_not_start_when + .assert_eq(local.source, prev.source + AB::Expr::from_canonical_u32(16)); + is_valid_not_start_when.assert_eq(local.dest, prev.dest + AB::Expr::from_canonical_u32(16)); + local.shift.iter().zip(prev.shift.iter()).for_each(|(local_shift, prev_shift)| { + is_valid_not_start_when.assert_eq(*local_shift, *prev_shift); + }); + + // make sure if previous row is valid and not end, then local.is_valid = 1 + builder + .when(prev_is_not_end - not::(prev.is_valid)) + .assert_one(local.is_valid); + + // if prev.is_valid_start, then timestamp = prev_timestamp + is_shift_non_zero + // since is_shift_non_zero degree is 2, we need to keep the degree of the condition to 1 + builder + .when(not::(prev.is_valid_not_start) - not::(prev.is_valid)) + .assert_eq(local.timestamp, prev.timestamp + is_shift_non_zero); + + // if prev.is_valid_not_start and local.is_valid_not_start, then timestamp=prev_timestamp+8 + // prev.is_valid_not_start is the opposite of previous condition + builder + .when( + local.is_valid_not_start + - (not::(prev.is_valid_not_start) - not::(prev.is_valid)), + ) + .assert_eq( + local.timestamp, + prev.timestamp + AB::Expr::from_canonical_usize(8), + ); + + // Receive message from memcpy bus or send message to it + // The last data is shift if is_boundary = -1, and 4 if is_boundary = 1 + // This actually receives when is_boundary = -1 + self.memcpy_bus + .send( + local.timestamp + + (local.is_boundary + AB::Expr::ONE) * AB::Expr::from_canonical_usize(4), + local.dest, + local.source, + len.clone(), + (AB::Expr::ONE - local.is_boundary) * shift.clone() * (AB::F::TWO).inverse() + + (local.is_boundary + AB::Expr::ONE) * AB::Expr::TWO, + ) + .eval(builder, local.is_boundary); + + // Read data from memory + let read_data = [ + local.data_1, + local.data_2, + local.data_3, + local.data_4, + ]; + + read_data + .iter() + .enumerate() + .for_each(|(idx, data)| { + let is_valid_read = if idx == 3 { + local.is_shift_non_zero_or_not_start + } else { + local.is_valid_not_start + }; + + self.memory_bridge + .read( + MemoryAddress::new( + AB::Expr::from_canonical_u32(RV32_MEMORY_AS), + local.source - AB::Expr::from_canonical_usize(16 - idx * 4), + ), + *data, + timestamp_pp(is_valid_read.clone()), + &local.read_aux[idx], + ) + .eval(builder, is_valid_read.clone()); + }); + + // Write final data to registers + write_data.iter().enumerate().for_each(|(idx, data)| { + self.memory_bridge + .write( + MemoryAddress::new( + AB::Expr::from_canonical_u32(RV32_MEMORY_AS), + local.dest - AB::Expr::from_canonical_usize(16 - idx * 4), + ), + data.clone(), + timestamp_pp(local.is_valid_not_start), + &local.write_aux[idx], + ) + .eval(builder, local.is_valid_not_start); + }); + + // Range check len + let len_bits_limit = [ + select::( + is_end.clone(), + AB::Expr::from_canonical_usize(4), + AB::Expr::from_canonical_usize(MEMCPY_LOOP_LIMB_BITS * 2), + ), + select::( + is_end.clone(), + AB::Expr::ZERO, + AB::Expr::from_canonical_usize(self.pointer_max_bits - MEMCPY_LOOP_LIMB_BITS * 2), + ), + ]; + self.range_bus + .push(local.len[0], len_bits_limit[0].clone(), true) + .eval(builder, local.is_valid); + self.range_bus + .push(local.len[1], len_bits_limit[1].clone(), true) + .eval(builder, local.is_valid); + } +} + +#[derive(derive_new::new, Clone, Copy)] +pub struct MemcpyIterExecutor { + pub offset: usize, +} + +#[derive(Copy, Clone, Debug)] +pub struct MemcpyIterMetadata { + num_rows: usize, +} + +impl MultiRowMetadata for MemcpyIterMetadata { + #[inline(always)] + fn get_num_rows(&self) -> usize { + self.num_rows + } +} + +pub type MemcpyIterLayout = MultiRowLayout; + +#[repr(C)] +#[derive(AlignedBytesBorrow, Debug)] +pub struct MemcpyIterRecordHeader { + pub shift: u8, + pub dest: u32, + pub source: u32, + pub len: u32, + pub from_pc: u32, + pub from_timestamp: u32, + pub register_aux: [MemoryBaseAuxRecord; 3], +} + +// This is the part of the record that we keep `(len & !15) + (shift != 0)` times per instruction +#[repr(C)] +#[derive(AlignedBytesBorrow, Debug)] +pub struct MemcpyIterRecordVar { + pub data: [[u8; MEMCPY_LOOP_NUM_LIMBS]; 4], + pub read_aux: [MemoryReadAuxRecord; 4], + pub write_aux: [MemoryWriteBytesAuxRecord<4>; 4], +} + +/// **SAFETY**: the order of the fields in `MemcpyLoopRecordMut` and `MemcpyLoopRecordVar` +/// is important. +#[derive(Debug)] +pub struct MemcpyIterRecordMut<'a> { + pub inner: &'a mut MemcpyIterRecordHeader, + pub var: &'a mut [MemcpyIterRecordVar], +} + +/// Custom borrowing that splits the buffer into a fixed `MemcpyLoopRecordHeader` header +/// followed by a slice of `MemcpyLoopRecordVar`'s of length `num_words` provided at runtime. +/// Uses `align_to_mut()` to make sure the slice is properly aligned to `MemcpyLoopRecordVar`. +/// Has debug assertions to make sure the above works as expected. +impl<'a> CustomBorrow<'a, MemcpyIterRecordMut<'a>, MemcpyIterLayout> for [u8] { + fn custom_borrow(&'a mut self, layout: MemcpyIterLayout) -> MemcpyIterRecordMut<'a> { + let (header_buf, rest) = + unsafe { self.split_at_mut_unchecked(size_of::()) }; + + let (_, vars, _) = unsafe { rest.align_to_mut::() }; + MemcpyIterRecordMut { + inner: header_buf.borrow_mut(), + var: &mut vars[..layout.metadata.num_rows], + } + } + + unsafe fn extract_layout(&self) -> MemcpyIterLayout { + let header: &MemcpyIterRecordHeader = self.borrow(); + MultiRowLayout::new(MemcpyIterMetadata { + num_rows: ((header.len - header.shift as u32) >> 4) as usize + 1, + }) + } +} + +impl SizedRecord for MemcpyIterRecordMut<'_> { + fn size(layout: &MemcpyIterLayout) -> usize { + let mut total_len = size_of::(); + // Align the pointer to the alignment of `Rv32HintStoreVar` + total_len = total_len.next_multiple_of(align_of::()); + total_len += size_of::() * layout.metadata.num_rows; + total_len + } + + fn alignment(_layout: &MemcpyIterLayout) -> usize { + align_of::() + } +} + +#[derive(derive_new::new)] +pub struct MemcpyIterFiller { + pub pointer_max_bits: usize, + pub range_checker_chip: SharedVariableRangeCheckerChip, + pub memcpy_loop_chip: Arc, +} + +pub type MemcpyIterChip = VmChipWrapper; + +impl PreflightExecutor for MemcpyIterExecutor +where + F: PrimeField32, + for<'buf> RA: RecordArena<'buf, MultiRowLayout, MemcpyIterRecordMut<'buf>>, +{ + fn get_opcode_name(&self, _: usize) -> String { + format!("{:?}", Rv32MemcpyOpcode::MEMCPY_LOOP) + } + + fn execute( + &self, + state: VmStateMut, + instruction: &Instruction, + ) -> Result<(), ExecutionError> { + let Instruction { opcode, c, .. } = instruction; + debug_assert_eq!(*opcode, Rv32MemcpyOpcode::MEMCPY_LOOP.global_opcode()); + let shift = c.as_canonical_u32() as u8; + debug_assert!([0, 1, 2, 3].contains(&shift)); + + let mut dest = read_rv32_register( + state.memory.data(), + if shift == 0 { + A3_REGISTER_PTR + } else { + A1_REGISTER_PTR + } as u32, + ); + let mut source = read_rv32_register( + state.memory.data(), + if shift == 0 { + A4_REGISTER_PTR + } else { + A3_REGISTER_PTR + } as u32, + ); + let mut len = read_rv32_register(state.memory.data(), A2_REGISTER_PTR as u32); + + // Create a record with var_size = ((len - shift) >> 4) + 1 which is the number of rows in iteration trace + let record = state.ctx.alloc(MultiRowLayout::new(MemcpyIterMetadata { + num_rows: ((len - shift as u32) >> 4) as usize + 1, + })); + + // Store the original values in the record + record.inner.shift = shift; + record.inner.from_pc = *state.pc; + record.inner.from_timestamp = state.memory.timestamp; + record.inner.dest = dest; + record.inner.source = source; + record.inner.len = len; + + // Fill record.var for the first row of iteration trace + if shift != 0 { + source -= 12; + record.var[0].data[3] = tracing_read( + state.memory, + RV32_MEMORY_AS, + source - 4, + &mut record.var[0].read_aux[3].prev_timestamp, + ); + }; + + // Fill record.var for the rest of the rows of iteration trace + let mut idx = 1; + while len - shift as u32 > 15 { + let writes_data: [[u8; MEMCPY_LOOP_NUM_LIMBS]; 4] = array::from_fn(|i| { + record.var[idx].data[i] = tracing_read( + state.memory, + RV32_MEMORY_AS, + source + 4 * i as u32, + &mut record.var[idx].read_aux[i].prev_timestamp, + ); + let write_data: [u8; MEMCPY_LOOP_NUM_LIMBS] = array::from_fn(|j| { + if j < 4 - shift as usize { + record.var[idx].data[i][j + shift as usize] + } else if i > 0 { + record.var[idx].data[i - 1][j - (4 - shift as usize)] + } else { + record.var[idx - 1].data[3][j - (4 - shift as usize)] + } + }); + write_data + }); + writes_data.iter().enumerate().for_each(|(i, write_data)| { + tracing_write( + state.memory, + RV32_MEMORY_AS, + dest + 4 * i as u32, + *write_data, + &mut record.var[idx].write_aux[i].prev_timestamp, + &mut record.var[idx].write_aux[i].prev_data, + ); + }); + len -= 16; + source += 16; + dest += 16; + idx += 1; + } + + // Handle the core loop + if shift != 0 { + source += 12; + } + + let mut dest_data = [0; 4]; + let mut source_data = [0; 4]; + let mut len_data = [0; 4]; + + tracing_write( + state.memory, + RV32_REGISTER_AS, + if shift == 0 { + A3_REGISTER_PTR + } else { + A1_REGISTER_PTR + } as u32, + dest.to_le_bytes(), + &mut record.inner.register_aux[0].prev_timestamp, + &mut dest_data, + ); + + tracing_write( + state.memory, + RV32_REGISTER_AS, + if shift == 0 { + A4_REGISTER_PTR + } else { + A3_REGISTER_PTR + } as u32, + source.to_le_bytes(), + &mut record.inner.register_aux[1].prev_timestamp, + &mut source_data, + ); + + tracing_write( + state.memory, + RV32_REGISTER_AS, + A2_REGISTER_PTR as u32, + len.to_le_bytes(), + &mut record.inner.register_aux[2].prev_timestamp, + &mut len_data, + ); + + debug_assert_eq!(record.inner.dest, u32::from_le_bytes(dest_data)); + debug_assert_eq!(record.inner.source, u32::from_le_bytes(source_data)); + debug_assert_eq!(record.inner.len, u32::from_le_bytes(len_data)); + + *state.pc = state.pc.wrapping_add(DEFAULT_PC_STEP); + + Ok(()) + } +} + +impl TraceFiller for MemcpyIterFiller { + fn fill_trace( + &self, + mem_helper: &MemoryAuxColsFactory, + trace: &mut RowMajorMatrix, + rows_used: usize, + ) { + if rows_used == 0 { + return; + } + + let width = trace.width; + debug_assert_eq!(width, NUM_MEMCPY_ITER_COLS); + let mut trace = &mut trace.values[..width * rows_used]; + let mut sizes = Vec::with_capacity(rows_used >> 1); + let mut chunks = Vec::with_capacity(rows_used >> 1); + + let mut num_loops: usize = 0; + let mut num_iters: usize = 0; + + while !trace.is_empty() { + let record: &MemcpyIterRecordHeader = unsafe { get_record_from_slice(&mut trace, ()) }; + let num_rows = ((record.len - record.shift as u32) >> 4) as usize + 1; + let (chunk, rest) = trace.split_at_mut(width * num_rows as usize); + sizes.push(num_rows); + chunks.push(chunk); + trace = rest; + num_loops += 1; + num_iters += num_rows; + } + tracing::info!("num_loops: {:?}, num_iters: {:?}, sizes: {:?}", num_loops, num_iters, sizes); + + chunks + .par_iter_mut() + .zip(sizes.par_iter()) + .enumerate() + .for_each(|(row_idx, (chunk, &num_rows))| { + let record: MemcpyIterRecordMut = unsafe { + get_record_from_slice( + chunk, + MultiRowLayout::new(MemcpyIterMetadata { num_rows }), + ) + }; + + tracing::info!("shift: {:?}", record.inner.shift); + // Fill memcpy loop record + self.memcpy_loop_chip.add_new_loop( + mem_helper, + record.inner.from_pc, + record.inner.from_timestamp, + record.inner.dest, + record.inner.source, + record.inner.len, + record.inner.shift, + record.inner.register_aux.clone(), + ); + + // Calculate the timestamp for the last memory access + // 4 reads + 4 writes per iteration + (shift != 0) read for the loop header + let timestamp = record.inner.from_timestamp + + ((num_rows - 1) << 3) as u32 + + (record.inner.shift != 0) as u32; + let mut timestamp_delta: u32 = 0; + let mut get_timestamp = |is_access: bool| { + if is_access { + timestamp_delta += 1; + } + timestamp - timestamp_delta + }; + + let mut dest = record.inner.dest + ((num_rows - 1) << 4) as u32; + let mut source = record.inner.source + ((num_rows - 1) << 4) as u32 + - 12 * (record.inner.shift != 0) as u32; + let mut len = + record.inner.len - ((num_rows - 1) << 4) as u32 - record.inner.shift as u32; + + // We are going to fill row in the reverse order + chunk + .rchunks_exact_mut(width) + .zip(record.var.iter().enumerate().rev()) + .for_each(|(row, (idx, var))| { + let cols: &mut MemcpyIterCols = row.borrow_mut(); + + let is_start = idx == 0; + let is_end = idx == num_rows - 1; + + // Range check len + let len_u16_limbs = [len & 0xffff, len >> 16]; + if is_end { + self.range_checker_chip.add_count(len_u16_limbs[0], 4); + self.range_checker_chip.add_count(len_u16_limbs[1], 0); + } else { + self.range_checker_chip + .add_count(len_u16_limbs[0], 2 * MEMCPY_LOOP_LIMB_BITS); + self.range_checker_chip.add_count( + len_u16_limbs[1], + self.pointer_max_bits - 2 * MEMCPY_LOOP_LIMB_BITS, + ); + } + + // Fill memory read/write auxiliary columns + if is_start { + cols.write_aux.iter_mut().rev().for_each(|aux_col| { + mem_helper.fill_zero(aux_col.as_mut()); + }); + + if record.inner.shift == 0 { + mem_helper.fill_zero(cols.read_aux[3].as_mut()); + } else { + mem_helper.fill( + var.read_aux[3].prev_timestamp, + get_timestamp(true), + cols.read_aux[3].as_mut(), + ); + } + cols.read_aux[..2].iter_mut().rev().for_each(|aux_col| { + mem_helper.fill_zero(aux_col.as_mut()); + }); + + debug_assert_eq!(get_timestamp(false), record.inner.from_timestamp); + } else { + var.write_aux + .iter() + .zip(cols.write_aux.iter_mut()) + .rev() + .for_each(|(aux_record, aux_col)| { + mem_helper.fill( + aux_record.prev_timestamp, + get_timestamp(true), + aux_col.as_mut(), + ); + aux_col.set_prev_data( + aux_record.prev_data.map(F::from_canonical_u8), + ); + }); + + var.read_aux + .iter() + .zip(cols.read_aux.iter_mut()) + .rev() + .for_each(|(aux_record, aux_col)| { + mem_helper.fill( + aux_record.prev_timestamp, + get_timestamp(true), + aux_col.as_mut(), + ); + }); + } + + cols.data_4 = var.data[3].map(F::from_canonical_u8); + cols.data_3 = var.data[2].map(F::from_canonical_u8); + cols.data_2 = var.data[1].map(F::from_canonical_u8); + cols.data_1 = var.data[0].map(F::from_canonical_u8); + cols.is_boundary = if is_end { + F::ONE + } else if is_start { + F::NEG_ONE + } else { + F::ZERO + }; + cols.is_shift_non_zero_or_not_start = F::from_bool(record.inner.shift != 0 || !is_start); + cols.is_valid_not_start = F::from_bool(!is_start); + cols.is_valid = F::ONE; + cols.shift = [ + record.inner.shift == 1, + record.inner.shift == 2, + record.inner.shift == 3, + ] + .map(F::from_bool); + cols.len = [len & 0xffff, len >> 16].map(F::from_canonical_u32); + cols.source = F::from_canonical_u32(source); + cols.dest = F::from_canonical_u32(dest); + cols.timestamp = F::from_canonical_u32(get_timestamp(false)); + + dest -= 16; + source -= 16; + len += 16; + + // if row_idx == 0 && is_start { + // tracing::info!("first_roooooow, timestamp: {:?}, dest: {:?}, source: {:?}, len_0: {:?}, len_1: {:?}, shift: {:?}, is_valid: {:?}, is_valid_not_start: {:?}, is_shift_non_zero: {:?}, is_boundary: {:?}, data_1: {:?}, data_2: {:?}, data_3: {:?}, data_4: {:?}, read_aux: {:?}, read_aux_lt: {:?}", + // cols.timestamp.as_canonical_u32(), + // cols.dest.as_canonical_u32(), + // cols.source.as_canonical_u32(), + // cols.len[0].as_canonical_u32(), + // cols.len[1].as_canonical_u32(), + // cols.shift[1].as_canonical_u32() * 2 + cols.shift[0].as_canonical_u32(), + // cols.is_valid.as_canonical_u32(), + // cols.is_valid_not_start.as_canonical_u32(), + // cols.is_shift_non_zero.as_canonical_u32(), + // cols.is_boundary.as_canonical_u32(), + // cols.data_1.map(|x| x.as_canonical_u32()).to_vec(), + // cols.data_2.map(|x| x.as_canonical_u32()).to_vec(), + // cols.data_3.map(|x| x.as_canonical_u32()).to_vec(), + // cols.data_4.map(|x| x.as_canonical_u32()).to_vec(), + // cols.read_aux.map(|x| x.get_base().prev_timestamp.as_canonical_u32()).to_vec(), + // cols.read_aux.map(|x| x.get_base().timestamp_lt_aux.lower_decomp.iter().map(|x| x.as_canonical_u32()).collect::>()).to_vec()); + // } + }); + }); + + // chunks.iter().enumerate().for_each(|(row_idx, chunk)| { + // let mut prv_data = [0; 4]; + // tracing::info!("row_idx: {:?}", row_idx); + + // chunk.chunks_exact(width) + // .enumerate() + // .for_each(|(idx, row)| { + // let cols: &MemcpyIterCols = row.borrow(); + // let is_valid_not_start = cols.is_valid_not_start.as_canonical_u32() != 0; + // let is_shift_non_zero = cols.is_shift_non_zero.as_canonical_u32() != 0; + // let source = cols.source.as_canonical_u32(); + // let dest = cols.dest.as_canonical_u32(); + // let mut bad_col = false; + // tracing::info!("source: {:?}, dest: {:?}", source, dest); + // cols.read_aux.iter().enumerate().for_each(|(idx, aux)| { + // if is_valid_not_start || (is_shift_non_zero && idx == 3) { + // let prev_t = aux.get_base().prev_timestamp.as_canonical_u32(); + // let curr_t = cols.timestamp.as_canonical_u32(); + // let ts_lt = aux.get_base().timestamp_lt_aux.lower_decomp.iter() + // .enumerate() + // .fold(F::ZERO, |acc, (i, &val)| { + // acc + val * F::from_canonical_usize(1 << (i * 17)) + // }).as_canonical_u32(); + // if curr_t + idx as u32 != ts_lt + prev_t + 1 { + // bad_col = true; + // } + // } + // if dest + 4 * idx as u32 == 2097216 || dest - 4 * (idx + 1) as u32 == 2097216 || dest + 4 * idx as u32 == 2097280 || dest - 4 * (idx + 1) as u32 == 2097280 { + // bad_col = true; + // } + // }); + // if bad_col { + // let write_data_pairs = [ + // (prv_data, cols.data_1.map(|x| x.as_canonical_u32())), + // (cols.data_1.map(|x| x.as_canonical_u32()), cols.data_2.map(|x| x.as_canonical_u32())), + // (cols.data_2.map(|x| x.as_canonical_u32()), cols.data_3.map(|x| x.as_canonical_u32())), + // (cols.data_3.map(|x| x.as_canonical_u32()), cols.data_4.map(|x| x.as_canonical_u32())), + // ]; + + // let shift = cols.shift[1].as_canonical_u32() * 2 + cols.shift[0].as_canonical_u32(); + // let write_data = write_data_pairs + // .iter() + // .map(|(prev_data, next_data)| { + // array::from_fn::<_, MEMCPY_LOOP_NUM_LIMBS, _>(|i| { + // (shift == 0) as u32 * (next_data[i]) + // + (shift == 1) as u32 + // * (if i < 3 { + // next_data[i + 1] + // } else { + // prev_data[i - 3] + // }) + // + (shift == 2) as u32 + // * (if i < 2 { + // next_data[i + 2] + // } else { + // prev_data[i - 2] + // }) + // + (shift == 3) as u32 + // * (if i < 1 { + // next_data[i + 3] + // } else { + // prev_data[i - 1] + // }) + // }) + // }) + // .collect::>(); + + + + + // tracing::info!("row_idx: {:?}, idx: {:?}, timestamp: {:?}, dest: {:?}, source: {:?}, len_0: {:?}, len_1: {:?}, shift_0: {:?}, shift_1: {:?}, is_valid: {:?}, is_valid_not_start: {:?}, is_shift_non_zero: {:?}, is_boundary: {:?}, write_data: {:?}, prv_data: {:?}, data_1: {:?}, data_2: {:?}, data_3: {:?}, data_4: {:?}, read_aux: {:?}, read_aux_lt: {:?}, write_aux: {:?}, write_aux_lt: {:?}, write_aux_prev_data: {:?}", + // row_idx, + // idx, + // cols.timestamp.as_canonical_u32(), + // cols.dest.as_canonical_u32(), + // cols.source.as_canonical_u32(), + // cols.len[0].as_canonical_u32(), + // cols.len[1].as_canonical_u32(), + // cols.shift[0].as_canonical_u32(), + // cols.shift[1].as_canonical_u32(), + // cols.is_valid.as_canonical_u32(), + // cols.is_valid_not_start.as_canonical_u32(), + // cols.is_shift_non_zero.as_canonical_u32(), + // cols.is_boundary.as_canonical_u32(), + // write_data, + // prv_data, + // cols.data_1.map(|x| x.as_canonical_u32()).to_vec(), + // cols.data_2.map(|x| x.as_canonical_u32()).to_vec(), + // cols.data_3.map(|x| x.as_canonical_u32()).to_vec(), + // cols.data_4.map(|x| x.as_canonical_u32()).to_vec(), + // cols.read_aux.map(|x| x.get_base().prev_timestamp.as_canonical_u32()).to_vec(), + // cols.read_aux.map(|x| x.get_base().timestamp_lt_aux.lower_decomp.iter().map(|x| x.as_canonical_u32()).collect::>()).to_vec(), + // cols.write_aux.map(|x| x.get_base().prev_timestamp.as_canonical_u32()).to_vec(), + // cols.write_aux.map(|x| x.get_base().timestamp_lt_aux.lower_decomp.iter().map(|x| x.as_canonical_u32()).collect::>()).to_vec(), + // cols.write_aux.map(|x| x.prev_data.map(|x| x.as_canonical_u32()).to_vec()).to_vec()); + // } + // prv_data = cols.data_4.map(|x| x.as_canonical_u32()); + // }); + // }); + // assert!(false); + } +} + +#[derive(AlignedBytesBorrow, Clone)] +#[repr(C)] +struct MemcpyIterPreCompute { + c: u8, +} + +impl Executor for MemcpyIterExecutor { + fn pre_compute_size(&self) -> usize { + size_of::() + } + + fn pre_compute( + &self, + pc: u32, + inst: &Instruction, + data: &mut [u8], + ) -> Result, StaticProgramError> + where + Ctx: ExecutionCtxTrait, + { + let data: &mut MemcpyIterPreCompute = data.borrow_mut(); + self.pre_compute_impl(pc, inst, data)?; + Ok(execute_e1_impl::<_, _>) + } +} + +impl MeteredExecutor for MemcpyIterExecutor { + fn metered_pre_compute_size(&self) -> usize { + size_of::>() + } + + fn metered_pre_compute( + &self, + chip_idx: usize, + pc: u32, + inst: &Instruction, + data: &mut [u8], + ) -> Result, StaticProgramError> + where + Ctx: MeteredExecutionCtxTrait, + { + let data: &mut E2PreCompute = data.borrow_mut(); + data.chip_idx = chip_idx as u32; + self.pre_compute_impl(pc, inst, &mut data.data)?; + Ok(execute_e2_impl::<_, _>) + } +} + +#[inline(always)] +unsafe fn execute_e12_impl( + pre_compute: &MemcpyIterPreCompute, + vm_state: &mut VmExecState, +) -> u32 { + let shift = pre_compute.c; + let mut height = 1; + // Read dest and source from registers + let (dest, source) = if shift == 0 { + ( + vm_state.vm_read::(RV32_REGISTER_AS, A3_REGISTER_PTR as u32), + vm_state.vm_read::(RV32_REGISTER_AS, A4_REGISTER_PTR as u32), + ) + } else { + ( + vm_state.vm_read::(RV32_REGISTER_AS, A1_REGISTER_PTR as u32), + vm_state.vm_read::(RV32_REGISTER_AS, A3_REGISTER_PTR as u32), + ) + }; + // Read length from a2 register + let len = vm_state.vm_read::(RV32_REGISTER_AS, A2_REGISTER_PTR as u32); + + let mut dest = u32::from_le_bytes(dest); + let mut source = u32::from_le_bytes(source) - 12 * (shift != 0) as u32; + let mut len = u32::from_le_bytes(len); + + // Check address ranges are valid + debug_assert!(dest < (1 << POINTER_MAX_BITS)); + debug_assert!((source - 4 * (shift != 0) as u32) < (1 << POINTER_MAX_BITS)); + let to_dest = dest + ((len - shift as u32) & !15); + let to_source = source + ((len - shift as u32) & !15); + debug_assert!(to_dest <= (1 << POINTER_MAX_BITS)); + debug_assert!(to_source <= (1 << POINTER_MAX_BITS)); + // Make sure the destination and source are not overlapping + debug_assert!(to_dest <= source || to_source <= dest); + + // Read the previous data from memory if shift != 0 + let mut prev_data = if shift == 0 { + [0; 4] + } else { + vm_state.vm_read::(RV32_MEMORY_AS, source - 4) + }; + + // Run iterations + while len - shift as u32 > 15 { + for i in 0..4 { + let data = vm_state.vm_read::(RV32_MEMORY_AS, source + 4 * i); + let write_data: [u8; 4] = array::from_fn(|i| { + if i < 4 - shift as usize { + data[i + shift as usize] + } else { + prev_data[i - (4 - shift as usize)] + } + }); + vm_state.vm_write(RV32_MEMORY_AS, dest + 4 * i, &write_data); + prev_data = data; + } + len -= 16; + source += 16; + dest += 16; + height += 1; + } + + // Write the result back to memory + if shift == 0 { + vm_state.vm_write( + RV32_REGISTER_AS, + A3_REGISTER_PTR as u32, + &dest.to_le_bytes(), + ); + vm_state.vm_write( + RV32_REGISTER_AS, + A4_REGISTER_PTR as u32, + &source.to_le_bytes(), + ); + } else { + source += 12; + vm_state.vm_write( + RV32_REGISTER_AS, + A1_REGISTER_PTR as u32, + &dest.to_le_bytes(), + ); + vm_state.vm_write( + RV32_REGISTER_AS, + A3_REGISTER_PTR as u32, + &source.to_le_bytes(), + ); + }; + vm_state.vm_write(RV32_REGISTER_AS, A2_REGISTER_PTR as u32, &len.to_le_bytes()); + + vm_state.pc = vm_state.pc.wrapping_add(DEFAULT_PC_STEP); + vm_state.instret += 1; + height +} + +unsafe fn execute_e1_impl( + pre_compute: &[u8], + vm_state: &mut VmExecState, +) { + let pre_compute: &MemcpyIterPreCompute = pre_compute.borrow(); + execute_e12_impl::(pre_compute, vm_state); +} + +unsafe fn execute_e2_impl( + pre_compute: &[u8], + vm_state: &mut VmExecState, +) { + let pre_compute: &E2PreCompute = pre_compute.borrow(); + let height = execute_e12_impl::(&pre_compute.data, vm_state); + vm_state + .ctx + .on_height_change(pre_compute.chip_idx as usize, height); +} + +impl MemcpyIterExecutor { + fn pre_compute_impl( + &self, + pc: u32, + inst: &Instruction, + data: &mut MemcpyIterPreCompute, + ) -> Result<(), StaticProgramError> { + let Instruction { opcode, c, .. } = inst; + let c_u32 = c.as_canonical_u32(); + if ![0, 1, 2, 3].contains(&c_u32) { + return Err(StaticProgramError::InvalidInstruction(pc)); + } + *data = MemcpyIterPreCompute { c: c_u32 as u8 }; + assert_eq!(*opcode, Rv32MemcpyOpcode::MEMCPY_LOOP.global_opcode()); + Ok(()) + } +} diff --git a/extensions/memcpy/circuit/src/lib.rs b/extensions/memcpy/circuit/src/lib.rs new file mode 100644 index 0000000000..28b63d6a65 --- /dev/null +++ b/extensions/memcpy/circuit/src/lib.rs @@ -0,0 +1,113 @@ +mod bus; +mod extension; +mod iteration; +mod loops; + +pub use bus::*; +pub use extension::*; +pub use iteration::*; +pub use loops::*; +use openvm_circuit::system::memory::{merkle::public_values::PUBLIC_VALUES_AS, online::{GuestMemory, TracingMemory}}; +use openvm_instructions::riscv::{RV32_MEMORY_AS, RV32_REGISTER_AS}; + +// ==== Do not change these constants! ==== +pub const MEMCPY_LOOP_NUM_LIMBS: usize = 4; +pub const MEMCPY_LOOP_LIMB_BITS: usize = 8; + +pub const A1_REGISTER_PTR: usize = 11 * 4; +pub const A2_REGISTER_PTR: usize = 12 * 4; +pub const A3_REGISTER_PTR: usize = 13 * 4; +pub const A4_REGISTER_PTR: usize = 14 * 4; + + +// TODO: These are duplicated from extensions/rv32im/circuit/src/adapters/mod.rs +// to prevent cyclic dependencies. Fix this. + +#[inline(always)] +pub fn memory_read(memory: &GuestMemory, address_space: u32, ptr: u32) -> [u8; N] { + debug_assert!( + address_space == RV32_REGISTER_AS + || address_space == RV32_MEMORY_AS + || address_space == PUBLIC_VALUES_AS, + ); + + // SAFETY: + // - address space `RV32_REGISTER_AS` and `RV32_MEMORY_AS` will always have cell type `u8` and + // minimum alignment of `RV32_REGISTER_NUM_LIMBS` + unsafe { memory.read::(address_space, ptr) } +} + +/// Atomic read operation which increments the timestamp by 1. +/// Returns `(t_prev, [ptr:4]_{address_space})` where `t_prev` is the timestamp of the last memory +/// access. +#[inline(always)] +pub fn timed_read( + memory: &mut TracingMemory, + address_space: u32, + ptr: u32, +) -> (u32, [u8; N]) { + debug_assert!( + address_space == RV32_REGISTER_AS + || address_space == RV32_MEMORY_AS + || address_space == PUBLIC_VALUES_AS + ); + + // SAFETY: + // - address space `RV32_REGISTER_AS` and `RV32_MEMORY_AS` will always have cell type `u8` and + // minimum alignment of `MEMCPY_LOOP_NUM_LIMBS` + unsafe { memory.read::(address_space, ptr) } +} + +#[inline(always)] +pub fn timed_write( + memory: &mut TracingMemory, + address_space: u32, + ptr: u32, + data: [u8; N], +) -> (u32, [u8; N]) { + debug_assert!( + address_space == RV32_REGISTER_AS + || address_space == RV32_MEMORY_AS + || address_space == PUBLIC_VALUES_AS + ); + + // SAFETY: + // - address space `RV32_REGISTER_AS` and `RV32_MEMORY_AS` will always have cell type `u8` and + // minimum alignment of `MEMCPY_LOOP_NUM_LIMBS` + unsafe { memory.write::(address_space, ptr, data) } +} + +/// Reads register value at `reg_ptr` from memory and records the memory access in mutable buffer. +/// Trace generation relevant to this memory access can be done fully from the recorded buffer. +#[inline(always)] +pub fn tracing_read( + memory: &mut TracingMemory, + address_space: u32, + ptr: u32, + prev_timestamp: &mut u32, +) -> [u8; N] { + let (t_prev, data) = timed_read(memory, address_space, ptr); + *prev_timestamp = t_prev; + data +} + +/// Writes `reg_ptr, reg_val` into memory and records the memory access in mutable buffer. +/// Trace generation relevant to this memory access can be done fully from the recorded buffer. +#[inline(always)] +pub fn tracing_write( + memory: &mut TracingMemory, + address_space: u32, + ptr: u32, + data: [u8; N], + prev_timestamp: &mut u32, + prev_data: &mut [u8; N], +) { + let (t_prev, data_prev) = timed_write(memory, address_space, ptr, data); + *prev_timestamp = t_prev; + *prev_data = data_prev; +} + +#[inline(always)] +pub fn read_rv32_register(memory: &GuestMemory, ptr: u32) -> u32 { + u32::from_le_bytes(memory_read(memory, RV32_REGISTER_AS, ptr)) +} diff --git a/extensions/memcpy/circuit/src/loops.rs b/extensions/memcpy/circuit/src/loops.rs new file mode 100644 index 0000000000..16204967b6 --- /dev/null +++ b/extensions/memcpy/circuit/src/loops.rs @@ -0,0 +1,480 @@ +use std::{ + borrow::{Borrow, BorrowMut}, + mem::size_of, + sync::{Arc, Mutex}, +}; + +use openvm_circuit::{ + arch::{ExecutionBridge, ExecutionState}, + system::{ + memory::{ + offline_checker::{ + MemoryBaseAuxCols, MemoryBaseAuxRecord, MemoryBridge, MemoryExtendedAuxRecord, + MemoryWriteAuxCols, + }, + MemoryAddress, MemoryAuxColsFactory, + }, + SystemPort, + }, + utils::next_power_of_two_or_zero, +}; +use openvm_circuit_primitives::{ + utils::{not, or, select}, + var_range::{SharedVariableRangeCheckerChip, VariableRangeCheckerBus}, + AlignedBytesBorrow, +}; +use openvm_circuit_primitives_derive::AlignedBorrow; +use openvm_instructions::riscv::RV32_REGISTER_AS; +use openvm_memcpy_transpiler::Rv32MemcpyOpcode; +use openvm_stark_backend::{ + config::{StarkGenericConfig, Val}, + interaction::InteractionBuilder, + p3_air::{Air, AirBuilder, BaseAir}, + p3_field::{Field, FieldAlgebra, PrimeField32}, + p3_matrix::{dense::RowMajorMatrix, Matrix}, + prover::{cpu::CpuBackend, types::AirProvingContext}, + rap::{get_air_name, BaseAirWithPublicValues, PartitionedBaseAir}, + Chip, ChipUsageGetter, +}; + +use crate::bus::MemcpyBus; +// Import constants from lib.rs +use crate::{ + A1_REGISTER_PTR, A2_REGISTER_PTR, A3_REGISTER_PTR, A4_REGISTER_PTR, MEMCPY_LOOP_LIMB_BITS, + MEMCPY_LOOP_NUM_LIMBS, +}; + +#[repr(C)] +#[derive(AlignedBorrow, Debug)] +pub struct MemcpyLoopCols { + pub from_state: ExecutionState, + pub dest: [T; MEMCPY_LOOP_NUM_LIMBS], + pub source: [T; MEMCPY_LOOP_NUM_LIMBS], + pub len: [T; MEMCPY_LOOP_NUM_LIMBS], + pub shift: [T; 2], + pub is_valid: T, + pub to_timestamp: T, + pub to_dest: [T; MEMCPY_LOOP_NUM_LIMBS], + pub to_source: [T; MEMCPY_LOOP_NUM_LIMBS], + pub to_len: T, + pub write_aux: [MemoryBaseAuxCols; 3], + pub source_minus_twelve_carry: T, + pub to_source_minus_twelve_carry: T, +} + +pub const NUM_MEMCPY_LOOP_COLS: usize = size_of::>(); +pub const MEMCPY_LOOP_NUM_WRITES: u32 = 3; + +#[derive(Copy, Clone, Debug, derive_new::new)] +pub struct MemcpyLoopAir { + pub memory_bridge: MemoryBridge, + pub execution_bridge: ExecutionBridge, + pub range_bus: VariableRangeCheckerBus, + pub memcpy_bus: MemcpyBus, + pub pointer_max_bits: usize, + pub offset: usize, +} + +impl BaseAir for MemcpyLoopAir { + fn width(&self) -> usize { + MemcpyLoopCols::::width() + } +} + +impl BaseAirWithPublicValues for MemcpyLoopAir {} +impl PartitionedBaseAir for MemcpyLoopAir {} + +impl Air for MemcpyLoopAir { + fn eval(&self, builder: &mut AB) { + let main = builder.main(); + let local = main.row_slice(0); + let local: &MemcpyLoopCols = (*local).borrow(); + + let mut timestamp_delta: u32 = 0; + let mut timestamp_pp = || { + timestamp_delta += 1; + local.to_timestamp + - AB::Expr::from_canonical_u32(MEMCPY_LOOP_NUM_WRITES - (timestamp_delta - 1)) + }; + + let from_le_bytes = |data: [AB::Var; 4]| { + data.iter().rev().fold(AB::Expr::ZERO, |acc, x| { + acc * AB::Expr::from_canonical_u32(1 << MEMCPY_LOOP_LIMB_BITS) + *x + }) + }; + + let u8_word_to_u16 = |data: [AB::Var; 4]| { + [ + data[0] + data[1] * AB::F::from_canonical_u32(1 << MEMCPY_LOOP_LIMB_BITS), + data[2] + data[3] * AB::F::from_canonical_u32(1 << MEMCPY_LOOP_LIMB_BITS), + ] + }; + + let shift = local.shift[1] * AB::Expr::TWO + local.shift[0]; + let is_shift_non_zero = or::(local.shift[0], local.shift[1]); + let is_shift_zero = not::(is_shift_non_zero.clone()); + let dest = from_le_bytes(local.dest); + let source = from_le_bytes(local.source); + let len = from_le_bytes(local.len); + let to_dest = from_le_bytes(local.to_dest); + let to_source = from_le_bytes(local.to_source); + let to_len = local.to_len; + + builder.assert_bool(local.is_valid); + local.shift.iter().for_each(|x| builder.assert_bool(*x)); + builder.assert_bool(local.source_minus_twelve_carry); + builder.assert_bool(local.to_source_minus_twelve_carry); + + let mut shift_zero_when = builder.when(is_shift_zero.clone()); + shift_zero_when.assert_zero(local.source_minus_twelve_carry); + shift_zero_when.assert_zero(local.to_source_minus_twelve_carry); + + // Write source and destination to registers + let write_data = [ + (local.dest, local.to_dest, A3_REGISTER_PTR, A1_REGISTER_PTR), + ( + local.source, + local.to_source, + A4_REGISTER_PTR, + A3_REGISTER_PTR, + ), + ]; + + write_data.iter().enumerate().for_each( + |(idx, (prev_data, new_data, zero_shift_ptr, non_zero_shift_ptr))| { + let write_ptr = select::( + is_shift_zero.clone(), + AB::Expr::from_canonical_usize(*zero_shift_ptr), + AB::Expr::from_canonical_usize(*non_zero_shift_ptr), + ); + + self.memory_bridge + .write( + MemoryAddress::new( + AB::Expr::from_canonical_u32(RV32_REGISTER_AS), + write_ptr, + ), + *new_data, + timestamp_pp(), + &MemoryWriteAuxCols::from_base(local.write_aux[idx], *prev_data), + ) + .eval(builder, local.is_valid); + }, + ); + + // Write length to a2 register + self.memory_bridge + .write( + MemoryAddress::new( + AB::Expr::from_canonical_u32(RV32_REGISTER_AS), + AB::Expr::from_canonical_usize(A2_REGISTER_PTR), + ), + [ + to_len.into(), + AB::Expr::ZERO, + AB::Expr::ZERO, + AB::Expr::ZERO, + ], + timestamp_pp(), + &MemoryWriteAuxCols::from_base(local.write_aux[2], local.len), + ) + .eval(builder, local.is_valid); + + // Generate 16-bit limbs for range checking + // dest, to_dest, source - 12 * is_shift_non_zero, to_source - 12 * is_shift_non_zero + let dest_u16_limbs = u8_word_to_u16(local.dest); + let to_dest_u16_limbs = u8_word_to_u16(local.to_dest); + let source_u16_limbs = [ + local.source[0] + + local.source[1] * AB::F::from_canonical_u32(1 << MEMCPY_LOOP_LIMB_BITS) + - AB::Expr::from_canonical_u32(12) * is_shift_non_zero.clone() + + local.source_minus_twelve_carry + * AB::F::from_canonical_u32(1 << (2 * MEMCPY_LOOP_LIMB_BITS)), + local.source[2] + + local.source[3] * AB::F::from_canonical_u32(1 << MEMCPY_LOOP_LIMB_BITS) + - local.source_minus_twelve_carry, + ]; + let to_source_u16_limbs = [ + local.to_source[0] + + local.to_source[1] * AB::F::from_canonical_u32(1 << MEMCPY_LOOP_LIMB_BITS) + - AB::Expr::from_canonical_u32(12) * is_shift_non_zero.clone() + + local.to_source_minus_twelve_carry + * AB::F::from_canonical_u32(1 << (2 * MEMCPY_LOOP_LIMB_BITS)), + local.to_source[2] + + local.to_source[3] * AB::F::from_canonical_u32(1 << MEMCPY_LOOP_LIMB_BITS) + - local.to_source_minus_twelve_carry, + ]; + + // Range check addresses + let range_check_data = [ + dest_u16_limbs, + source_u16_limbs, + to_dest_u16_limbs, + to_source_u16_limbs, + ]; + + range_check_data.iter().for_each(|data| { + // Check the low 16 bits of dest and source, make sure they are multiple of 4 + self.range_bus + .range_check( + data[0].clone() * AB::F::from_canonical_u32(4).inverse(), + MEMCPY_LOOP_LIMB_BITS * 2 - 2, + ) + .eval(builder, local.is_valid); + // Check the high 16 bits of dest and source, make sure they are in the range [0, 2^pointer_max_bits - 2^MEMCPY_LOOP_LIMB_BITS) + self.range_bus + .range_check( + data[1].clone(), + self.pointer_max_bits - MEMCPY_LOOP_LIMB_BITS * 2, + ) + .eval(builder, local.is_valid); + }); + + // Send message to memcpy call bus + self.memcpy_bus + .send( + local.from_state.timestamp, + dest, + source - AB::Expr::from_canonical_u32(12) * is_shift_non_zero.clone(), + len.clone() - shift.clone(), + shift.clone(), + ) + .eval(builder, local.is_valid); + + // Receive message from memcpy return bus + self.memcpy_bus + .receive( + local.to_timestamp - AB::Expr::from_canonical_u32(timestamp_delta), + to_dest, + to_source - AB::Expr::from_canonical_u32(12) * is_shift_non_zero.clone(), + to_len - shift.clone(), + AB::Expr::from_canonical_u32(4), + ) + .eval(builder, local.is_valid); + + // Make sure the request and response match, this should work because the + // from_timestamp and len are valid and to_len is in [0, 16 + shift) + builder.when(local.is_valid).assert_eq( + AB::Expr::TWO * (local.to_timestamp - local.from_state.timestamp), + (len.clone() - to_len) + + AB::Expr::TWO + * (is_shift_non_zero.clone() + AB::Expr::from_canonical_u32(timestamp_delta)), + ); + + // Execution bus + program bus + self.execution_bridge + .execute_and_increment_pc( + AB::Expr::from_canonical_usize( + Rv32MemcpyOpcode::MEMCPY_LOOP as usize + self.offset, + ), + [AB::Expr::ZERO, AB::Expr::ZERO, shift.clone()], + local.from_state, + local.to_timestamp - local.from_state.timestamp, + ) + .eval(builder, local.is_valid); + } +} + +#[repr(C)] +#[derive(AlignedBytesBorrow, Debug)] +pub struct MemcpyLoopRecord { + pub from_pc: u32, + pub from_timestamp: u32, + pub dest: u32, + pub source: u32, + pub len: u32, + pub shift: u8, + pub write_aux: [MemoryExtendedAuxRecord; 3], +} + +pub struct MemcpyLoopChip { + pub air: MemcpyLoopAir, + pub records: Arc>>, + pub pointer_max_bits: usize, + pub range_checker_chip: SharedVariableRangeCheckerChip, +} + +impl MemcpyLoopChip { + pub fn new( + system_port: SystemPort, + range_bus: VariableRangeCheckerBus, + memcpy_bus: MemcpyBus, + offset: usize, + pointer_max_bits: usize, + range_checker_chip: SharedVariableRangeCheckerChip, + ) -> Self { + Self { + air: MemcpyLoopAir::new( + system_port.memory_bridge, + ExecutionBridge::new(system_port.execution_bus, system_port.program_bus), + range_bus, + memcpy_bus, + pointer_max_bits, + offset, + ), + records: Arc::new(Mutex::new(Vec::new())), + pointer_max_bits, + range_checker_chip, + } + } + + pub fn bus(&self) -> MemcpyBus { + self.air.memcpy_bus + } + + pub fn clear(&self) { + self.records.lock().unwrap().clear(); + } + + pub fn add_new_loop<'a, F: PrimeField32>( + &self, + mem_helper: &MemoryAuxColsFactory, + from_pc: u32, + from_timestamp: u32, + dest: u32, + source: u32, + len: u32, + shift: u8, + register_aux: [MemoryBaseAuxRecord; 3], + ) { + let mut timestamp = + from_timestamp + (((len - shift as u32) & !0x0f) >> 1) + (shift != 0) as u32; + let write_aux = register_aux + .iter() + .map(|aux_record| { + let mut aux_col = MemoryBaseAuxCols::default(); + mem_helper.fill(aux_record.prev_timestamp, timestamp, &mut aux_col); + timestamp += 1; + MemoryExtendedAuxRecord::from_aux_cols(aux_col) + }) + .collect::>() + .try_into() + .unwrap(); + + let num_copies = (len - shift as u32) & !0x0f; + let to_dest = dest + num_copies; + let to_source = source + num_copies; + + let word_to_u16 = |data: u32| [data & 0x0ffff, data >> 16]; + debug_assert!(source >= 12 * (shift != 0) as u32); + debug_assert!(to_source >= 12 * (shift != 0) as u32); + debug_assert!(dest % 4 == 0); + debug_assert!(to_dest % 4 == 0); + debug_assert!(source % 4 == 0); + debug_assert!(to_source % 4 == 0); + let range_check_data = [ + word_to_u16(dest), + word_to_u16(source - 12 * (shift != 0) as u32), + word_to_u16(to_dest), + word_to_u16(to_source - 12 * (shift != 0) as u32), + ]; + + range_check_data.iter().for_each(|data| { + self.range_checker_chip + .add_count(data[0] >> 2, 2 * MEMCPY_LOOP_LIMB_BITS - 2); + self.range_checker_chip + .add_count(data[1], self.pointer_max_bits - 2 * MEMCPY_LOOP_LIMB_BITS); + }); + + // Create record + let row = MemcpyLoopRecord { + from_pc, + from_timestamp, + dest, + source, + len, + shift, + write_aux, + }; + + // Thread-safe push to rows vector + if let Ok(mut rows_guard) = self.records.lock() { + rows_guard.push(row); + } + } + + /// Generates trace + pub fn generate_trace(&self) -> RowMajorMatrix { + let height = next_power_of_two_or_zero(self.records.lock().unwrap().len()); + let mut rows = F::zero_vec(height * NUM_MEMCPY_LOOP_COLS); + + // TODO: run in parallel + for (i, record) in self.records.lock().unwrap().iter().enumerate() { + let row = &mut rows[i * NUM_MEMCPY_LOOP_COLS..(i + 1) * NUM_MEMCPY_LOOP_COLS]; + let cols: &mut MemcpyLoopCols = row.borrow_mut(); + + let shift = record.shift; + let num_copies = (record.len - shift as u32) & !0x0f; + let to_source = record.source + num_copies; + + cols.from_state.pc = F::from_canonical_u32(record.from_pc); + cols.from_state.timestamp = F::from_canonical_u32(record.from_timestamp); + cols.dest = record.dest.to_le_bytes().map(F::from_canonical_u8); + cols.source = record.source.to_le_bytes().map(F::from_canonical_u8); + cols.len = record.len.to_le_bytes().map(F::from_canonical_u8); + cols.shift = [shift & 1, shift >> 1].map(F::from_canonical_u8); + cols.is_valid = F::ONE; + // We have MEMCPY_LOOP_NUM_WRITES writes in the loop, (num_copies / 4) writes + // and (num_copies / 4 + shift != 0) reads in iterations + cols.to_timestamp = F::from_canonical_u32( + record.from_timestamp + + MEMCPY_LOOP_NUM_WRITES + + (num_copies >> 1) + + (shift != 0) as u32, + ); + cols.to_dest = (record.dest + num_copies) + .to_le_bytes() + .map(F::from_canonical_u8); + cols.to_source = to_source.to_le_bytes().map(F::from_canonical_u8); + cols.to_len = F::from_canonical_u32(record.len - num_copies); + record + .write_aux + .iter() + .zip(cols.write_aux.iter_mut()) + .for_each(|(record_aux, col_aux)| { + record_aux.to_aux_cols(col_aux); + }); + cols.source_minus_twelve_carry = F::from_bool((record.source & 0x0ffff) < 12); + cols.to_source_minus_twelve_carry = F::from_bool((to_source & 0x0ffff) < 12); + + // tracing::info!("timestamp: {:?}, pc: {:?}, dest: {:?}, source: {:?}, len: {:?}, shift: {:?}, is_valid: {:?}, to_timestamp: {:?}, to_dest: {:?}, to_source: {:?}, to_len: {:?}, write_aux: {:?}", + // cols.from_state.timestamp.as_canonical_u32(), + // cols.from_state.pc.as_canonical_u32(), + // u32::from_le_bytes(cols.dest.map(|x| x.as_canonical_u32() as u8)), + // u32::from_le_bytes(cols.source.map(|x| x.as_canonical_u32() as u8)), + // u32::from_le_bytes(cols.len.map(|x| x.as_canonical_u32() as u8)), + // cols.shift[1].as_canonical_u32() * 2 + cols.shift[0].as_canonical_u32(), + // cols.is_valid.as_canonical_u32(), + // cols.to_timestamp.as_canonical_u32(), + // u32::from_le_bytes(cols.to_dest.map(|x| x.as_canonical_u32() as u8)), + // u32::from_le_bytes(cols.to_source.map(|x| x.as_canonical_u32() as u8)), + // cols.to_len.as_canonical_u32(), + // cols.write_aux.map(|x| x.prev_timestamp.as_canonical_u32()).to_vec()); + } + RowMajorMatrix::new(rows, NUM_MEMCPY_LOOP_COLS) + } +} + +// We allow any `R` type so this can work with arbitrary record arenas. +impl Chip> for MemcpyLoopChip +where + Val: PrimeField32, +{ + /// Generates trace and resets the internal counters all to 0. + fn generate_proving_ctx(&self, _: R) -> AirProvingContext> { + let trace = self.generate_trace::>(); + AirProvingContext::simple_no_pis(Arc::new(trace)) + } +} + +impl ChipUsageGetter for MemcpyLoopChip { + fn air_name(&self) -> String { + get_air_name(&self.air) + } + fn current_trace_height(&self) -> usize { + self.records.lock().unwrap().len() + } + fn trace_width(&self) -> usize { + NUM_MEMCPY_LOOP_COLS + } +} diff --git a/extensions/memcpy/tests/Cargo.toml b/extensions/memcpy/tests/Cargo.toml new file mode 100644 index 0000000000..bedaf7d8dd --- /dev/null +++ b/extensions/memcpy/tests/Cargo.toml @@ -0,0 +1,28 @@ +[package] +name = "openvm-memcpy-integration-tests" +description = "Integration tests for the OpenVM memcpy extension" +version.workspace = true +authors.workspace = true +edition.workspace = true +homepage.workspace = true +repository.workspace = true + +[dependencies] +openvm-instructions = { workspace = true } +openvm-stark-sdk.workspace = true +openvm-memcpy-circuit.workspace = true +openvm-memcpy-transpiler.workspace = true +openvm = { workspace = true } +eyre.workspace = true +serde = { workspace = true, features = ["alloc"] } +strum.workspace = true +rand.workspace = true +openvm-circuit = { workspace = true, features = ["test-utils"] } +openvm-circuit-primitives.workspace = true +openvm-stark-backend.workspace = true +test-case.workspace = true +tracing.workspace = true + +[features] +default = ["parallel"] +parallel = ["openvm-circuit/parallel"] diff --git a/extensions/memcpy/tests/src/lib.rs b/extensions/memcpy/tests/src/lib.rs new file mode 100644 index 0000000000..49d5ae7d3d --- /dev/null +++ b/extensions/memcpy/tests/src/lib.rs @@ -0,0 +1,260 @@ +#[cfg(test)] +mod tests { + use std::sync::Arc; + + use openvm_circuit::{ + arch::{ + testing::{TestBuilder, TestChipHarness, VmChipTestBuilder, MEMCPY_BUS, RANGE_CHECKER_BUS}, + Arena, PreflightExecutor, + }, + system::{memory::SharedMemoryHelper, SystemPort}, + }; + use openvm_circuit_primitives::var_range::{ + SharedVariableRangeCheckerChip, VariableRangeCheckerAir, VariableRangeCheckerBus, + VariableRangeCheckerChip, + }; + use openvm_instructions::{instruction::Instruction, riscv::{RV32_MEMORY_AS, RV32_REGISTER_AS}, LocalOpcode, VmOpcode}; + use openvm_memcpy_circuit::{ + MemcpyBus, MemcpyIterAir, MemcpyIterChip, MemcpyIterExecutor, MemcpyIterFiller, + MemcpyLoopAir, MemcpyLoopChip, A1_REGISTER_PTR, A2_REGISTER_PTR, A3_REGISTER_PTR, + A4_REGISTER_PTR, + }; + use openvm_memcpy_transpiler::Rv32MemcpyOpcode; + use openvm_stark_backend::p3_field::FieldAlgebra; + use openvm_stark_sdk::{p3_baby_bear::BabyBear, utils::create_seeded_rng}; + use rand::Rng; + use test_case::test_case; + + const MAX_INS_CAPACITY: usize = 128; + type F = BabyBear; + type Harness = TestChipHarness>; + + fn create_harness_fields( + address_bits: usize, + system_port: SystemPort, + range_chip: Arc, + memory_helper: SharedMemoryHelper, + ) -> (MemcpyIterAir, MemcpyIterExecutor, MemcpyIterChip, Arc) { + let range_bus = range_chip.bus(); + let memcpy_bus = MemcpyBus::new(MEMCPY_BUS); + + let air = MemcpyIterAir::new( + system_port.memory_bridge, + range_bus, + memcpy_bus, + address_bits, + ); + let executor = MemcpyIterExecutor::new(Rv32MemcpyOpcode::CLASS_OFFSET); + let loop_chip = Arc::new(MemcpyLoopChip::new( + system_port, + range_bus, + memcpy_bus, + Rv32MemcpyOpcode::CLASS_OFFSET, + address_bits, + range_chip.clone(), + )); + let chip = MemcpyIterChip::new( + MemcpyIterFiller::new(address_bits, range_chip, loop_chip.clone()), + memory_helper, + ); + (air, executor, chip, loop_chip) + } + + fn create_harness( + tester: &VmChipTestBuilder, + ) -> ( + Harness, + (VariableRangeCheckerAir, SharedVariableRangeCheckerChip), + (MemcpyLoopAir, Arc), + ) { + let range_bus = VariableRangeCheckerBus::new(RANGE_CHECKER_BUS, tester.address_bits()); + let range_chip = Arc::new(VariableRangeCheckerChip::new(range_bus)); + + let (air, executor, chip, loop_chip) = create_harness_fields( + tester.address_bits(), + tester.system_port(), + range_chip.clone(), + tester.memory_helper(), + ); + let harness = Harness::with_capacity(executor, air, chip, MAX_INS_CAPACITY); + + ( + harness, + (range_chip.air, range_chip), + (loop_chip.air, loop_chip), + ) + } + + fn set_and_execute_memcpy>( + tester: &mut impl TestBuilder, + executor: &mut E, + arena: &mut RA, + shift: u32, + source_data: &[u8], + dest_offset: u32, + source_offset: u32, + len: u32, + ) { + // Write source data to memory by words (4 bytes) + let source_words = source_data.len().div_ceil(4); + for word_idx in 0..source_words { + let word_start = word_idx * 4; + let word_end = (word_idx + 1) * 4; + let mut word_data = [F::ZERO; 4]; + + for i in word_start..word_end { + if i < source_data.len() { + word_data[i - word_start] = F::from_canonical_u8(source_data[i]); + } + } + + tester.write(RV32_MEMORY_AS as usize, (source_offset + word_idx as u32 * 4) as usize, word_data); + } + + // Set up registers that the memcpy instruction will read from + // destination address + tester.write::<4>( + RV32_REGISTER_AS as usize, + if shift == 0 { + A3_REGISTER_PTR + } else { + A1_REGISTER_PTR + }, + dest_offset.to_le_bytes().map(F::from_canonical_u8), + ); + // length + tester.write::<4>( + RV32_REGISTER_AS as usize, + A2_REGISTER_PTR, + len.to_le_bytes().map(F::from_canonical_u8), + ); + // source address + tester.write::<4>( + RV32_REGISTER_AS as usize, + if shift == 0 { + A4_REGISTER_PTR + } else { + A3_REGISTER_PTR + }, + source_offset.to_le_bytes().map(F::from_canonical_u8), + ); + + // Create instruction for memcpy_iter (uses same opcode as memcpy_loop) + let instruction = Instruction { + opcode: VmOpcode::from_usize(Rv32MemcpyOpcode::MEMCPY_LOOP.global_opcode().as_usize()), + a: F::ZERO, + b: F::ZERO, + c: F::from_canonical_u32(shift), + d: F::ZERO, + e: F::ZERO, + f: F::ZERO, + g: F::ZERO, + }; + + tester.execute(executor, arena, &instruction); + + // Verify the copy operation by reading words + // let dest_words = (len as usize + 3) / 4; // Round up to nearest word + // for word_idx in 0..dest_words { + // let word_data = tester.read::<4>(2, (dest_offset + word_idx as u32 * 4) as usize); + // let word_start = word_idx * 4; + + // for i in 0..4 { + // let byte_idx = word_start + i; + // if byte_idx < len as usize && byte_idx < source_data.len() { + // let expected = source_data[byte_idx]; + // let actual = word_data[i].as_canonical_u32() as u8; + // assert_eq!(expected, actual, "Mismatch at offset {}", byte_idx); + // } + // } + // } + } + + ////////////////////////////////////////////////////////////////////////////////////// + // POSITIVE TESTS + // + // Randomly generate memcpy operations and execute, ensuring that the generated trace + // passes all constraints. + ////////////////////////////////////////////////////////////////////////////////////// + + #[test_case(0, 1, 20)] + #[test_case(1, 100, 20)] + #[test_case(2, 100, 20)] + #[test_case(3, 100, 20)] + fn rand_memcpy_iter_test(shift: u32, num_ops: usize, len: u32) { + let mut rng = create_seeded_rng(); + + let mut tester = VmChipTestBuilder::default(); + let (mut harness, range_checker, memcpy_loop) = create_harness(&tester); + + for _ in 0..num_ops { + let source_offset = rng.gen_range(0..250) * 4; // Ensure word alignment + let dest_offset = rng.gen_range(500..750) * 4; // Ensure word alignment + let source_data: Vec = (0..len.div_ceil(4) * 4) + .map(|_| rng.gen_range(0..=u8::MAX)) + .collect(); + + set_and_execute_memcpy( + &mut tester, + &mut harness.executor, + &mut harness.arena, + shift, + &source_data, + dest_offset, + source_offset, + len, + ); + tracing::info!( + "source_offset: {}, dest_offset: {}, len: {}", + source_offset, + dest_offset, + len + ); + } + + let tester = tester + .build() + .load(harness) + .load_periphery(range_checker) + .load_periphery(memcpy_loop) + .finalize(); + tester.simple_test().expect("Verification failed"); + } + + #[test_case(0, 100, 20)] + #[test_case(1, 100, 20)] + #[test_case(2, 100, 20)] + #[test_case(3, 100, 20)] + fn rand_memcpy_iter_test_persistent(shift: u32, num_ops: usize, len: u32) { + let mut rng = create_seeded_rng(); + + let mut tester = VmChipTestBuilder::default_persistent(); + let (mut harness, range_checker, _iter_air) = create_harness(&tester); + + for _ in 0..num_ops { + let source_offset = rng.gen_range(0..250) * 4; // Ensure word alignment + let dest_offset = rng.gen_range(500..750) * 4; // Ensure word alignment + let source_data: Vec = (0..len.div_ceil(4) * 4) + .map(|_| rng.gen_range(0..=u8::MAX)) + .collect(); + + set_and_execute_memcpy( + &mut tester, + &mut harness.executor, + &mut harness.arena, + shift, + &source_data, + dest_offset, + source_offset, + len, + ); + } + + let tester = tester + .build() + .load(harness) + .load_periphery(range_checker) + .finalize(); + tester.simple_test().expect("Verification failed"); + } +} \ No newline at end of file diff --git a/extensions/memcpy/transpiler/Cargo.toml b/extensions/memcpy/transpiler/Cargo.toml new file mode 100644 index 0000000000..03e39a78d7 --- /dev/null +++ b/extensions/memcpy/transpiler/Cargo.toml @@ -0,0 +1,17 @@ +[package] +name = "openvm-memcpy-transpiler" +description = "OpenVM transpiler extension for memcpy" +version.workspace = true +authors.workspace = true +edition.workspace = true +homepage.workspace = true +repository.workspace = true + +[dependencies] +openvm-stark-backend = { workspace = true } +openvm-instructions = { workspace = true } +openvm-transpiler = { workspace = true } +rrs-lib = { workspace = true } + +openvm-instructions-derive = { workspace = true } +strum = { workspace = true } diff --git a/extensions/memcpy/transpiler/src/lib.rs b/extensions/memcpy/transpiler/src/lib.rs new file mode 100644 index 0000000000..4dc13d3ebd --- /dev/null +++ b/extensions/memcpy/transpiler/src/lib.rs @@ -0,0 +1,58 @@ +use openvm_instructions::LocalOpcode; +use openvm_instructions_derive::LocalOpcode; +use openvm_stark_backend::p3_field::PrimeField32; +use openvm_transpiler::{util::from_u_type, TranspilerExtension, TranspilerOutput}; +use rrs_lib::instruction_formats::UType; +use strum::{EnumCount, EnumIter, FromRepr}; + +#[derive( + Copy, Clone, Debug, PartialEq, Eq, PartialOrd, Ord, EnumCount, EnumIter, FromRepr, LocalOpcode, +)] +#[opcode_offset = 0x330] +#[repr(usize)] +#[allow(non_camel_case_types)] +pub enum Rv32MemcpyOpcode { + MEMCPY_LOOP, +} + +// Custom opcode for memcpy_loop instruction +pub const MEMCPY_LOOP_OPCODE: u8 = 0x72; // Custom opcode + +#[derive(Default)] +pub struct MemcpyTranspilerExtension; + +impl TranspilerExtension for MemcpyTranspilerExtension { + fn process_custom(&self, instruction_stream: &[u32]) -> Option> { + if instruction_stream.is_empty() { + return None; + } + + let instruction_u32 = instruction_stream[0]; + let opcode = (instruction_u32 & 0x7f) as u8; + + // Check if this is our custom memcpy_loop instruction + if opcode != MEMCPY_LOOP_OPCODE { + return None; + } + + // Parse U-type instruction format + let mut dec_insn = UType::new(instruction_u32); + let shift = dec_insn.imm >> 12; + dec_insn.rd = 1; // avoid using x0, otherwise we get nop() + + // Validate shift value (0, 1, 2, or 3) + if ![0, 1, 2, 3].contains(&shift) { + return None; + } + + // Convert to OpenVM instruction format + let mut instruction = from_u_type( + Rv32MemcpyOpcode::MEMCPY_LOOP.global_opcode().as_usize(), + &dec_insn, + ); + instruction.a = F::ZERO; + instruction.d = F::ZERO; + + Some(TranspilerOutput::one_to_one(instruction)) + } +} diff --git a/extensions/pairing/circuit/Cargo.toml b/extensions/pairing/circuit/Cargo.toml index 4f4710fd2d..ef24579bf0 100644 --- a/extensions/pairing/circuit/Cargo.toml +++ b/extensions/pairing/circuit/Cargo.toml @@ -23,6 +23,7 @@ openvm-stark-backend = { workspace = true } openvm-rv32im-circuit = { workspace = true } openvm-algebra-circuit = { workspace = true } openvm-ecc-circuit = { workspace = true } +openvm-memcpy-circuit = { workspace = true } openvm-pairing-transpiler = { workspace = true } openvm-cuda-backend = { workspace = true, optional = true } openvm-stark-sdk = { workspace = true, optional = true } diff --git a/extensions/pairing/circuit/src/config.rs b/extensions/pairing/circuit/src/config.rs index 20ea07186a..ce6756651c 100644 --- a/extensions/pairing/circuit/src/config.rs +++ b/extensions/pairing/circuit/src/config.rs @@ -13,6 +13,7 @@ use openvm_circuit::{ }; use openvm_circuit_derive::VmConfig; use openvm_ecc_circuit::{EccCpuProverExt, WeierstrassExtension, WeierstrassExtensionExecutor}; +use openvm_memcpy_circuit::{Memcpy, MemcpyCpuProverExt, MemcpyExecutor}; use openvm_stark_backend::{ config::{StarkGenericConfig, Val}, engine::StarkEngine, @@ -33,6 +34,8 @@ pub struct Rv32PairingConfig { pub weierstrass: WeierstrassExtension, #[extension(generics = true)] pub pairing: PairingExtension, + #[extension] + pub memcpy: Memcpy, } impl Rv32PairingConfig { @@ -55,6 +58,7 @@ impl Rv32PairingConfig { curves.iter().map(|c| c.curve_config()).collect(), ), pairing: PairingExtension::new(curves), + memcpy: Memcpy, } } } @@ -101,6 +105,11 @@ where inventory, )?; VmProverExtension::::extend_prover(&PairingProverExt, &config.pairing, inventory)?; + VmProverExtension::::extend_prover( + &MemcpyCpuProverExt, + &config.memcpy, + inventory, + )?; Ok(chip_complex) } } diff --git a/extensions/rv32im/circuit/Cargo.toml b/extensions/rv32im/circuit/Cargo.toml index 78ed781c31..b7d01c6f3a 100644 --- a/extensions/rv32im/circuit/Cargo.toml +++ b/extensions/rv32im/circuit/Cargo.toml @@ -18,6 +18,7 @@ openvm-circuit = { workspace = true } openvm-circuit-derive = { workspace = true } openvm-instructions = { workspace = true } openvm-rv32im-transpiler = { workspace = true } +openvm-memcpy-circuit = { workspace = true } strum.workspace = true derive-new.workspace = true derive_more = { workspace = true, features = ["from"] } diff --git a/extensions/rv32im/circuit/src/lib.rs b/extensions/rv32im/circuit/src/lib.rs index 14f2783c47..6eeae4daa0 100644 --- a/extensions/rv32im/circuit/src/lib.rs +++ b/extensions/rv32im/circuit/src/lib.rs @@ -8,6 +8,7 @@ use openvm_circuit::{ system::{SystemChipInventory, SystemCpuBuilder, SystemExecutor}, }; use openvm_circuit_derive::{Executor, PreflightExecutor, VmConfig}; +use openvm_memcpy_circuit::{Memcpy, MemcpyCpuProverExt, MemcpyExecutor}; use openvm_stark_backend::{ config::{StarkGenericConfig, Val}, engine::StarkEngine, @@ -81,6 +82,8 @@ pub struct Rv32IConfig { pub base: Rv32I, #[extension] pub io: Rv32Io, + #[extension] + pub memcpy: Memcpy, } // Default implementation uses no init file @@ -105,6 +108,7 @@ impl Default for Rv32IConfig { system, base: Default::default(), io: Default::default(), + memcpy: Memcpy, } } } @@ -116,6 +120,7 @@ impl Rv32IConfig { system, base: Default::default(), io: Default::default(), + memcpy: Memcpy, } } @@ -127,6 +132,7 @@ impl Rv32IConfig { system, base: Default::default(), io: Default::default(), + memcpy: Memcpy, } } } @@ -173,6 +179,7 @@ where let inventory = &mut chip_complex.inventory; VmProverExtension::::extend_prover(&Rv32ImCpuProverExt, &config.base, inventory)?; VmProverExtension::::extend_prover(&Rv32ImCpuProverExt, &config.io, inventory)?; + VmProverExtension::::extend_prover(&MemcpyCpuProverExt, &config.memcpy, inventory)?; Ok(chip_complex) } } diff --git a/extensions/rv32im/tests/Cargo.toml b/extensions/rv32im/tests/Cargo.toml index 2de37c69c8..ee69a25904 100644 --- a/extensions/rv32im/tests/Cargo.toml +++ b/extensions/rv32im/tests/Cargo.toml @@ -15,6 +15,7 @@ openvm-transpiler.workspace = true openvm-rv32im-circuit.workspace = true openvm-rv32im-guest.workspace = true openvm-rv32im-transpiler.workspace = true +openvm-memcpy-transpiler.workspace = true openvm = { workspace = true } openvm-toolchain-tests = { path = "../../../crates/toolchain/tests" } eyre.workspace = true diff --git a/extensions/rv32im/tests/src/lib.rs b/extensions/rv32im/tests/src/lib.rs index a28a2da3f8..79461e910b 100644 --- a/extensions/rv32im/tests/src/lib.rs +++ b/extensions/rv32im/tests/src/lib.rs @@ -9,6 +9,7 @@ mod tests { utils::{air_test, air_test_with_min_segments, test_system_config}, }; use openvm_instructions::{exe::VmExe, instruction::Instruction, LocalOpcode, SystemOpcode}; + use openvm_memcpy_transpiler::MemcpyTranspilerExtension; use openvm_rv32im_circuit::{Rv32IBuilder, Rv32IConfig, Rv32ImBuilder, Rv32ImConfig}; use openvm_rv32im_guest::hint_load_by_key_encode; use openvm_rv32im_transpiler::{ @@ -46,7 +47,8 @@ mod tests { Transpiler::::default() .with_extension(Rv32ITranspilerExtension) .with_extension(Rv32MTranspilerExtension) - .with_extension(Rv32IoTranspilerExtension), + .with_extension(Rv32IoTranspilerExtension) + .with_extension(MemcpyTranspilerExtension), )?; change_rv32m_insn_to_nop(&mut exe); air_test_with_min_segments(Rv32IBuilder, config, exe, vec![], min_segments); @@ -63,7 +65,8 @@ mod tests { Transpiler::::default() .with_extension(Rv32ITranspilerExtension) .with_extension(Rv32IoTranspilerExtension) - .with_extension(Rv32MTranspilerExtension), + .with_extension(Rv32MTranspilerExtension) + .with_extension(MemcpyTranspilerExtension), )?; air_test_with_min_segments(Rv32ImBuilder, config, exe, vec![], min_segments); Ok(()) @@ -84,7 +87,8 @@ mod tests { Transpiler::::default() .with_extension(Rv32ITranspilerExtension) .with_extension(Rv32IoTranspilerExtension) - .with_extension(Rv32MTranspilerExtension), + .with_extension(Rv32MTranspilerExtension) + .with_extension(MemcpyTranspilerExtension), )?; air_test_with_min_segments(Rv32ImBuilder, config, exe, vec![], min_segments); Ok(()) @@ -99,7 +103,8 @@ mod tests { Transpiler::::default() .with_extension(Rv32ITranspilerExtension) .with_extension(Rv32MTranspilerExtension) - .with_extension(Rv32IoTranspilerExtension), + .with_extension(Rv32IoTranspilerExtension) + .with_extension(MemcpyTranspilerExtension), )?; let input = vec![[0, 1, 2, 3].map(F::from_canonical_u8).to_vec()]; air_test_with_min_segments(Rv32ImBuilder, config, exe, input, 1); @@ -115,7 +120,8 @@ mod tests { Transpiler::::default() .with_extension(Rv32ITranspilerExtension) .with_extension(Rv32MTranspilerExtension) - .with_extension(Rv32IoTranspilerExtension), + .with_extension(Rv32IoTranspilerExtension) + .with_extension(MemcpyTranspilerExtension), )?; // stdin will be read after reading kv_store let stdin = vec![[0, 1, 2].map(F::from_canonical_u8).to_vec()]; @@ -138,7 +144,8 @@ mod tests { Transpiler::::default() .with_extension(Rv32ITranspilerExtension) .with_extension(Rv32MTranspilerExtension) - .with_extension(Rv32IoTranspilerExtension), + .with_extension(Rv32IoTranspilerExtension) + .with_extension(MemcpyTranspilerExtension), )?; #[derive(serde::Serialize)] @@ -169,7 +176,8 @@ mod tests { Transpiler::::default() .with_extension(Rv32ITranspilerExtension) .with_extension(Rv32MTranspilerExtension) - .with_extension(Rv32IoTranspilerExtension), + .with_extension(Rv32IoTranspilerExtension) + .with_extension(MemcpyTranspilerExtension), )?; let executor = VmExecutor::new(config.clone())?; @@ -211,7 +219,8 @@ mod tests { Transpiler::::default() .with_extension(Rv32ITranspilerExtension) .with_extension(Rv32MTranspilerExtension) - .with_extension(Rv32IoTranspilerExtension), + .with_extension(Rv32IoTranspilerExtension) + .with_extension(MemcpyTranspilerExtension), )?; air_test(Rv32ImBuilder, config, exe); Ok(()) @@ -226,7 +235,8 @@ mod tests { Transpiler::::default() .with_extension(Rv32ITranspilerExtension) .with_extension(Rv32MTranspilerExtension) - .with_extension(Rv32IoTranspilerExtension), + .with_extension(Rv32IoTranspilerExtension) + .with_extension(MemcpyTranspilerExtension), )?; let executor = VmExecutor::new(config)?; @@ -253,7 +263,8 @@ mod tests { Transpiler::::default() .with_extension(Rv32ITranspilerExtension) .with_extension(Rv32MTranspilerExtension) - .with_extension(Rv32IoTranspilerExtension), + .with_extension(Rv32IoTranspilerExtension) + .with_extension(MemcpyTranspilerExtension), )?; air_test(Rv32ImBuilder, config, exe); Ok(()) @@ -273,7 +284,8 @@ mod tests { Transpiler::::default() .with_extension(Rv32ITranspilerExtension) .with_extension(Rv32MTranspilerExtension) - .with_extension(Rv32IoTranspilerExtension), + .with_extension(Rv32IoTranspilerExtension) + .with_extension(MemcpyTranspilerExtension), )?; air_test(Rv32ImBuilder, config, exe); Ok(()) @@ -289,7 +301,8 @@ mod tests { Transpiler::::default() .with_extension(Rv32ITranspilerExtension) .with_extension(Rv32MTranspilerExtension) - .with_extension(Rv32IoTranspilerExtension), + .with_extension(Rv32IoTranspilerExtension) + .with_extension(MemcpyTranspilerExtension), ) .unwrap(); let executor = VmExecutor::new(config).unwrap(); @@ -315,7 +328,8 @@ mod tests { Transpiler::::default() .with_extension(Rv32ITranspilerExtension) .with_extension(Rv32MTranspilerExtension) - .with_extension(Rv32IoTranspilerExtension), + .with_extension(Rv32IoTranspilerExtension) + .with_extension(MemcpyTranspilerExtension), ) .unwrap(); air_test(Rv32ImBuilder, config, exe); diff --git a/extensions/sha256/circuit/Cargo.toml b/extensions/sha256/circuit/Cargo.toml index 740a4302f5..f581346734 100644 --- a/extensions/sha256/circuit/Cargo.toml +++ b/extensions/sha256/circuit/Cargo.toml @@ -16,6 +16,7 @@ openvm-circuit = { workspace = true } openvm-instructions = { workspace = true } openvm-sha256-transpiler = { workspace = true } openvm-rv32im-circuit = { workspace = true } +openvm-memcpy-circuit = { workspace = true } openvm-sha256-air = { workspace = true } derive-new.workspace = true diff --git a/extensions/sha256/circuit/src/lib.rs b/extensions/sha256/circuit/src/lib.rs index 2fcc8c8856..5d14ad299f 100644 --- a/extensions/sha256/circuit/src/lib.rs +++ b/extensions/sha256/circuit/src/lib.rs @@ -11,6 +11,7 @@ use openvm_circuit::{ system::{SystemChipInventory, SystemCpuBuilder, SystemExecutor}, }; use openvm_circuit_derive::VmConfig; +use openvm_memcpy_circuit::{Memcpy, MemcpyCpuProverExt, MemcpyExecutor}; use openvm_rv32im_circuit::{ Rv32I, Rv32IExecutor, Rv32ImCpuProverExt, Rv32Io, Rv32IoExecutor, Rv32M, Rv32MExecutor, }; @@ -54,6 +55,8 @@ pub struct Sha256Rv32Config { pub io: Rv32Io, #[extension] pub sha256: Sha256, + #[extension] + pub memcpy: Memcpy, } impl Default for Sha256Rv32Config { @@ -64,6 +67,7 @@ impl Default for Sha256Rv32Config { rv32m: Rv32M::default(), io: Rv32Io, sha256: Sha256, + memcpy: Memcpy, } } } @@ -99,6 +103,7 @@ where VmProverExtension::::extend_prover(&Rv32ImCpuProverExt, &config.rv32m, inventory)?; VmProverExtension::::extend_prover(&Rv32ImCpuProverExt, &config.io, inventory)?; VmProverExtension::::extend_prover(&Sha2CpuProverExt, &config.sha256, inventory)?; + VmProverExtension::::extend_prover(&MemcpyCpuProverExt, &config.memcpy, inventory)?; Ok(chip_complex) } } diff --git a/guest-libs/ff_derive/Cargo.toml b/guest-libs/ff_derive/Cargo.toml index acb017123f..f2a39da8be 100644 --- a/guest-libs/ff_derive/Cargo.toml +++ b/guest-libs/ff_derive/Cargo.toml @@ -35,6 +35,7 @@ openvm-transpiler = { workspace = true } openvm-algebra-transpiler = { workspace = true } openvm-algebra-circuit = { workspace = true } openvm-rv32im-transpiler = { workspace = true } +openvm-memcpy-transpiler = { workspace = true } openvm-toolchain-tests = { workspace = true } eyre = { workspace = true } diff --git a/guest-libs/ff_derive/tests/lib.rs b/guest-libs/ff_derive/tests/lib.rs index db7a1a09de..48de004639 100644 --- a/guest-libs/ff_derive/tests/lib.rs +++ b/guest-libs/ff_derive/tests/lib.rs @@ -8,6 +8,7 @@ mod tests { use openvm_algebra_transpiler::ModularTranspilerExtension; use openvm_circuit::utils::{air_test, test_system_config}; use openvm_instructions::exe::VmExe; + use openvm_memcpy_transpiler::MemcpyTranspilerExtension; use openvm_rv32im_transpiler::{ Rv32ITranspilerExtension, Rv32IoTranspilerExtension, Rv32MTranspilerExtension, }; @@ -43,7 +44,8 @@ mod tests { .with_extension(Rv32ITranspilerExtension) .with_extension(Rv32MTranspilerExtension) .with_extension(Rv32IoTranspilerExtension) - .with_extension(ModularTranspilerExtension), + .with_extension(ModularTranspilerExtension) + .with_extension(MemcpyTranspilerExtension), )?; air_test(Rv32ModularBuilder, config, openvm_exe); @@ -62,7 +64,8 @@ mod tests { .with_extension(Rv32ITranspilerExtension) .with_extension(Rv32MTranspilerExtension) .with_extension(Rv32IoTranspilerExtension) - .with_extension(ModularTranspilerExtension), + .with_extension(ModularTranspilerExtension) + .with_extension(MemcpyTranspilerExtension), )?; air_test(Rv32ModularBuilder, config, openvm_exe); @@ -81,7 +84,8 @@ mod tests { .with_extension(Rv32ITranspilerExtension) .with_extension(Rv32MTranspilerExtension) .with_extension(Rv32IoTranspilerExtension) - .with_extension(ModularTranspilerExtension), + .with_extension(ModularTranspilerExtension) + .with_extension(MemcpyTranspilerExtension), )?; air_test(Rv32ModularBuilder, config, openvm_exe); @@ -105,7 +109,8 @@ mod tests { .with_extension(Rv32ITranspilerExtension) .with_extension(Rv32MTranspilerExtension) .with_extension(Rv32IoTranspilerExtension) - .with_extension(ModularTranspilerExtension), + .with_extension(ModularTranspilerExtension) + .with_extension(MemcpyTranspilerExtension), )?; air_test(Rv32ModularBuilder, config, openvm_exe); @@ -129,7 +134,8 @@ mod tests { .with_extension(Rv32ITranspilerExtension) .with_extension(Rv32MTranspilerExtension) .with_extension(Rv32IoTranspilerExtension) - .with_extension(ModularTranspilerExtension), + .with_extension(ModularTranspilerExtension) + .with_extension(MemcpyTranspilerExtension), )?; air_test(Rv32ModularBuilder, config, openvm_exe); @@ -154,7 +160,8 @@ mod tests { .with_extension(Rv32ITranspilerExtension) .with_extension(Rv32MTranspilerExtension) .with_extension(Rv32IoTranspilerExtension) - .with_extension(ModularTranspilerExtension), + .with_extension(ModularTranspilerExtension) + .with_extension(MemcpyTranspilerExtension), )?; air_test(Rv32ModularBuilder, config, openvm_exe); @@ -178,7 +185,8 @@ mod tests { .with_extension(Rv32ITranspilerExtension) .with_extension(Rv32MTranspilerExtension) .with_extension(Rv32IoTranspilerExtension) - .with_extension(ModularTranspilerExtension), + .with_extension(ModularTranspilerExtension) + .with_extension(MemcpyTranspilerExtension), )?; air_test(Rv32ModularBuilder, config, openvm_exe); diff --git a/guest-libs/k256/Cargo.toml b/guest-libs/k256/Cargo.toml index 40e09214df..f5d4c90986 100644 --- a/guest-libs/k256/Cargo.toml +++ b/guest-libs/k256/Cargo.toml @@ -38,6 +38,7 @@ openvm-ecc-circuit.workspace = true openvm-sha256-circuit.workspace = true openvm-sha256-transpiler.workspace = true openvm-rv32im-transpiler.workspace = true +openvm-memcpy-transpiler = { workspace = true } openvm-toolchain-tests.workspace = true openvm-stark-backend.workspace = true diff --git a/guest-libs/k256/tests/lib.rs b/guest-libs/k256/tests/lib.rs index 59eb42dde3..b71ef9dad0 100644 --- a/guest-libs/k256/tests/lib.rs +++ b/guest-libs/k256/tests/lib.rs @@ -10,6 +10,7 @@ mod guest_tests { CurveConfig, Rv32WeierstrassBuilder, Rv32WeierstrassConfig, SECP256K1_CONFIG, }; use openvm_ecc_transpiler::EccTranspilerExtension; + use openvm_memcpy_transpiler::MemcpyTranspilerExtension; use openvm_rv32im_transpiler::{ Rv32ITranspilerExtension, Rv32IoTranspilerExtension, Rv32MTranspilerExtension, }; @@ -41,7 +42,8 @@ mod guest_tests { .with_extension(Rv32MTranspilerExtension) .with_extension(Rv32IoTranspilerExtension) .with_extension(EccTranspilerExtension) - .with_extension(ModularTranspilerExtension), + .with_extension(ModularTranspilerExtension) + .with_extension(MemcpyTranspilerExtension), )?; air_test(Rv32WeierstrassBuilder, config, openvm_exe); Ok(()) @@ -59,7 +61,8 @@ mod guest_tests { .with_extension(Rv32MTranspilerExtension) .with_extension(Rv32IoTranspilerExtension) .with_extension(EccTranspilerExtension) - .with_extension(ModularTranspilerExtension), + .with_extension(ModularTranspilerExtension) + .with_extension(MemcpyTranspilerExtension), )?; air_test(Rv32WeierstrassBuilder, config, openvm_exe); Ok(()) @@ -80,7 +83,8 @@ mod guest_tests { .with_extension(Rv32MTranspilerExtension) .with_extension(Rv32IoTranspilerExtension) .with_extension(EccTranspilerExtension) - .with_extension(ModularTranspilerExtension), + .with_extension(ModularTranspilerExtension) + .with_extension(MemcpyTranspilerExtension), )?; air_test(Rv32WeierstrassBuilder, config, openvm_exe); Ok(()) @@ -237,7 +241,8 @@ mod guest_tests { .with_extension(Rv32IoTranspilerExtension) .with_extension(EccTranspilerExtension) .with_extension(ModularTranspilerExtension) - .with_extension(Sha256TranspilerExtension), + .with_extension(Sha256TranspilerExtension) + .with_extension(MemcpyTranspilerExtension), )?; air_test(EcdsaBuilder, config, openvm_exe); Ok(()) @@ -258,7 +263,8 @@ mod guest_tests { .with_extension(Rv32MTranspilerExtension) .with_extension(Rv32IoTranspilerExtension) .with_extension(EccTranspilerExtension) - .with_extension(ModularTranspilerExtension), + .with_extension(ModularTranspilerExtension) + .with_extension(MemcpyTranspilerExtension), )?; air_test(Rv32WeierstrassBuilder, config, openvm_exe); Ok(()) diff --git a/guest-libs/keccak256/Cargo.toml b/guest-libs/keccak256/Cargo.toml index 276ec2679f..4712554154 100644 --- a/guest-libs/keccak256/Cargo.toml +++ b/guest-libs/keccak256/Cargo.toml @@ -20,6 +20,7 @@ openvm-transpiler = { workspace = true } openvm-keccak256-transpiler = { workspace = true } openvm-keccak256-circuit = { workspace = true } openvm-rv32im-transpiler = { workspace = true } +openvm-memcpy-transpiler = { workspace = true } openvm-toolchain-tests = { workspace = true } eyre = { workspace = true } diff --git a/guest-libs/keccak256/tests/lib.rs b/guest-libs/keccak256/tests/lib.rs index fbab4ef8c8..ae2c8c953d 100644 --- a/guest-libs/keccak256/tests/lib.rs +++ b/guest-libs/keccak256/tests/lib.rs @@ -5,6 +5,7 @@ mod tests { use openvm_instructions::exe::VmExe; use openvm_keccak256_circuit::{Keccak256Rv32Builder, Keccak256Rv32Config}; use openvm_keccak256_transpiler::Keccak256TranspilerExtension; + use openvm_memcpy_transpiler::MemcpyTranspilerExtension; use openvm_rv32im_transpiler::{ Rv32ITranspilerExtension, Rv32IoTranspilerExtension, Rv32MTranspilerExtension, }; @@ -25,7 +26,8 @@ mod tests { .with_extension(Keccak256TranspilerExtension) .with_extension(Rv32ITranspilerExtension) .with_extension(Rv32MTranspilerExtension) - .with_extension(Rv32IoTranspilerExtension), + .with_extension(Rv32IoTranspilerExtension) + .with_extension(MemcpyTranspilerExtension), )?; air_test(Keccak256Rv32Builder, config, openvm_exe); Ok(()) diff --git a/guest-libs/p256/Cargo.toml b/guest-libs/p256/Cargo.toml index 896e609e85..0616a3059a 100644 --- a/guest-libs/p256/Cargo.toml +++ b/guest-libs/p256/Cargo.toml @@ -35,6 +35,7 @@ openvm-ecc-circuit.workspace = true openvm-sha256-circuit.workspace = true openvm-sha256-transpiler.workspace = true openvm-rv32im-transpiler.workspace = true +openvm-memcpy-transpiler.workspace = true openvm-toolchain-tests.workspace = true openvm-stark-backend.workspace = true diff --git a/guest-libs/p256/tests/lib.rs b/guest-libs/p256/tests/lib.rs index 9eaf2b2c74..5e81e44e10 100644 --- a/guest-libs/p256/tests/lib.rs +++ b/guest-libs/p256/tests/lib.rs @@ -10,6 +10,7 @@ mod guest_tests { CurveConfig, Rv32WeierstrassBuilder, Rv32WeierstrassConfig, P256_CONFIG, }; use openvm_ecc_transpiler::EccTranspilerExtension; + use openvm_memcpy_transpiler::MemcpyTranspilerExtension; use openvm_rv32im_transpiler::{ Rv32ITranspilerExtension, Rv32IoTranspilerExtension, Rv32MTranspilerExtension, }; @@ -41,7 +42,8 @@ mod guest_tests { .with_extension(Rv32MTranspilerExtension) .with_extension(Rv32IoTranspilerExtension) .with_extension(EccTranspilerExtension) - .with_extension(ModularTranspilerExtension), + .with_extension(ModularTranspilerExtension) + .with_extension(MemcpyTranspilerExtension), )?; air_test(Rv32WeierstrassBuilder, config, openvm_exe); Ok(()) @@ -59,7 +61,8 @@ mod guest_tests { .with_extension(Rv32MTranspilerExtension) .with_extension(Rv32IoTranspilerExtension) .with_extension(EccTranspilerExtension) - .with_extension(ModularTranspilerExtension), + .with_extension(ModularTranspilerExtension) + .with_extension(MemcpyTranspilerExtension), )?; air_test(Rv32WeierstrassBuilder, config, openvm_exe); Ok(()) @@ -80,7 +83,8 @@ mod guest_tests { .with_extension(Rv32MTranspilerExtension) .with_extension(Rv32IoTranspilerExtension) .with_extension(EccTranspilerExtension) - .with_extension(ModularTranspilerExtension), + .with_extension(ModularTranspilerExtension) + .with_extension(MemcpyTranspilerExtension), )?; air_test(Rv32WeierstrassBuilder, config, openvm_exe); Ok(()) @@ -237,7 +241,8 @@ mod guest_tests { .with_extension(Rv32IoTranspilerExtension) .with_extension(EccTranspilerExtension) .with_extension(ModularTranspilerExtension) - .with_extension(Sha256TranspilerExtension), + .with_extension(Sha256TranspilerExtension) + .with_extension(MemcpyTranspilerExtension), )?; air_test(EcdsaBuilder, config, openvm_exe); Ok(()) @@ -258,7 +263,8 @@ mod guest_tests { .with_extension(Rv32MTranspilerExtension) .with_extension(Rv32IoTranspilerExtension) .with_extension(EccTranspilerExtension) - .with_extension(ModularTranspilerExtension), + .with_extension(ModularTranspilerExtension) + .with_extension(MemcpyTranspilerExtension), )?; air_test(Rv32WeierstrassBuilder, config, openvm_exe); Ok(()) diff --git a/guest-libs/pairing/Cargo.toml b/guest-libs/pairing/Cargo.toml index 649f73f864..36ebeaa953 100644 --- a/guest-libs/pairing/Cargo.toml +++ b/guest-libs/pairing/Cargo.toml @@ -46,6 +46,8 @@ openvm-ecc-circuit.workspace = true openvm-ecc-guest.workspace = true openvm-ecc-transpiler.workspace = true openvm-rv32im-transpiler.workspace = true +openvm-memcpy-transpiler.workspace = true +openvm-memcpy-circuit.workspace = true openvm = { workspace = true } openvm-toolchain-tests = { workspace = true } eyre.workspace = true diff --git a/guest-libs/pairing/tests/lib.rs b/guest-libs/pairing/tests/lib.rs index 68150d536a..f8e49c4ead 100644 --- a/guest-libs/pairing/tests/lib.rs +++ b/guest-libs/pairing/tests/lib.rs @@ -24,6 +24,8 @@ mod bn254 { }; use openvm_ecc_transpiler::EccTranspilerExtension; use openvm_instructions::exe::VmExe; + use openvm_memcpy_circuit::Memcpy; + use openvm_memcpy_transpiler::MemcpyTranspilerExtension; use openvm_pairing_circuit::{ PairingCurve, PairingExtension, Rv32PairingBuilder, Rv32PairingConfig, }; @@ -58,6 +60,7 @@ mod bn254 { fp2: Fp2Extension::new(primes_with_names), weierstrass: WeierstrassExtension::new(vec![]), pairing: PairingExtension::new(vec![PairingCurve::Bn254]), + memcpy: Memcpy, } } @@ -85,7 +88,8 @@ mod bn254 { .with_extension(Rv32MTranspilerExtension) .with_extension(Rv32IoTranspilerExtension) .with_extension(EccTranspilerExtension) - .with_extension(ModularTranspilerExtension), + .with_extension(ModularTranspilerExtension) + .with_extension(MemcpyTranspilerExtension), )?; air_test(Rv32WeierstrassBuilder, config, openvm_exe); Ok(()) @@ -108,7 +112,8 @@ mod bn254 { .with_extension(Rv32IoTranspilerExtension) .with_extension(PairingTranspilerExtension) .with_extension(ModularTranspilerExtension) - .with_extension(Fp2TranspilerExtension), + .with_extension(Fp2TranspilerExtension) + .with_extension(MemcpyTranspilerExtension), )?; let mut rng = rand::rngs::StdRng::seed_from_u64(2); @@ -144,7 +149,8 @@ mod bn254 { .with_extension(Rv32IoTranspilerExtension) .with_extension(PairingTranspilerExtension) .with_extension(ModularTranspilerExtension) - .with_extension(Fp2TranspilerExtension), + .with_extension(Fp2TranspilerExtension) + .with_extension(MemcpyTranspilerExtension), )?; let mut rng = rand::rngs::StdRng::seed_from_u64(2); @@ -202,7 +208,8 @@ mod bn254 { .with_extension(Rv32IoTranspilerExtension) .with_extension(PairingTranspilerExtension) .with_extension(ModularTranspilerExtension) - .with_extension(Fp2TranspilerExtension), + .with_extension(Fp2TranspilerExtension) + .with_extension(MemcpyTranspilerExtension), )?; let mut rng = rand::rngs::StdRng::seed_from_u64(20); @@ -251,7 +258,8 @@ mod bn254 { .with_extension(Rv32IoTranspilerExtension) .with_extension(PairingTranspilerExtension) .with_extension(ModularTranspilerExtension) - .with_extension(Fp2TranspilerExtension), + .with_extension(Fp2TranspilerExtension) + .with_extension(MemcpyTranspilerExtension), )?; let S = G1Affine::generator(); @@ -304,7 +312,8 @@ mod bn254 { .with_extension(Rv32IoTranspilerExtension) .with_extension(PairingTranspilerExtension) .with_extension(ModularTranspilerExtension) - .with_extension(Fp2TranspilerExtension), + .with_extension(Fp2TranspilerExtension) + .with_extension(MemcpyTranspilerExtension), )?; let S = G1Affine::generator(); @@ -361,7 +370,8 @@ mod bn254 { .with_extension(Rv32IoTranspilerExtension) .with_extension(PairingTranspilerExtension) .with_extension(ModularTranspilerExtension) - .with_extension(Fp2TranspilerExtension), + .with_extension(Fp2TranspilerExtension) + .with_extension(MemcpyTranspilerExtension), )?; let S = G1Affine::generator(); @@ -426,7 +436,8 @@ mod bn254 { .with_extension(Rv32IoTranspilerExtension) .with_extension(PairingTranspilerExtension) .with_extension(ModularTranspilerExtension) - .with_extension(Fp2TranspilerExtension), + .with_extension(Fp2TranspilerExtension) + .with_extension(MemcpyTranspilerExtension), )?; let P = G1Affine::generator(); @@ -490,6 +501,8 @@ mod bls12_381 { AffinePoint, }; use openvm_ecc_transpiler::EccTranspilerExtension; + use openvm_memcpy_circuit::Memcpy; + use openvm_memcpy_transpiler::MemcpyTranspilerExtension; use openvm_pairing_circuit::{ PairingCurve, PairingExtension, Rv32PairingBuilder, Rv32PairingConfig, }; @@ -527,6 +540,7 @@ mod bls12_381 { fp2: Fp2Extension::new(primes_with_names), weierstrass: WeierstrassExtension::new(vec![]), pairing: PairingExtension::new(vec![PairingCurve::Bls12_381]), + memcpy: Memcpy, } } @@ -560,7 +574,8 @@ mod bls12_381 { .with_extension(Rv32MTranspilerExtension) .with_extension(Rv32IoTranspilerExtension) .with_extension(EccTranspilerExtension) - .with_extension(ModularTranspilerExtension), + .with_extension(ModularTranspilerExtension) + .with_extension(MemcpyTranspilerExtension), )?; air_test(Rv32WeierstrassBuilder, config, openvm_exe); Ok(()) @@ -583,7 +598,8 @@ mod bls12_381 { .with_extension(Rv32IoTranspilerExtension) .with_extension(PairingTranspilerExtension) .with_extension(ModularTranspilerExtension) - .with_extension(Fp2TranspilerExtension), + .with_extension(Fp2TranspilerExtension) + .with_extension(MemcpyTranspilerExtension), )?; let mut rng = rand::rngs::StdRng::seed_from_u64(50); @@ -619,7 +635,8 @@ mod bls12_381 { .with_extension(Rv32IoTranspilerExtension) .with_extension(PairingTranspilerExtension) .with_extension(ModularTranspilerExtension) - .with_extension(Fp2TranspilerExtension), + .with_extension(Fp2TranspilerExtension) + .with_extension(MemcpyTranspilerExtension), )?; let mut rng = rand::rngs::StdRng::seed_from_u64(5); @@ -678,7 +695,8 @@ mod bls12_381 { .with_extension(Rv32IoTranspilerExtension) .with_extension(PairingTranspilerExtension) .with_extension(ModularTranspilerExtension) - .with_extension(Fp2TranspilerExtension), + .with_extension(Fp2TranspilerExtension) + .with_extension(MemcpyTranspilerExtension), )?; let mut rng = rand::rngs::StdRng::seed_from_u64(88); @@ -727,7 +745,8 @@ mod bls12_381 { .with_extension(Rv32IoTranspilerExtension) .with_extension(PairingTranspilerExtension) .with_extension(ModularTranspilerExtension) - .with_extension(Fp2TranspilerExtension), + .with_extension(Fp2TranspilerExtension) + .with_extension(MemcpyTranspilerExtension), )?; let S = G1Affine::generator(); @@ -786,7 +805,8 @@ mod bls12_381 { .with_extension(Rv32IoTranspilerExtension) .with_extension(PairingTranspilerExtension) .with_extension(ModularTranspilerExtension) - .with_extension(Fp2TranspilerExtension), + .with_extension(Fp2TranspilerExtension) + .with_extension(MemcpyTranspilerExtension), )?; let S = G1Affine::generator(); @@ -843,7 +863,8 @@ mod bls12_381 { .with_extension(Rv32IoTranspilerExtension) .with_extension(PairingTranspilerExtension) .with_extension(ModularTranspilerExtension) - .with_extension(Fp2TranspilerExtension), + .with_extension(Fp2TranspilerExtension) + .with_extension(MemcpyTranspilerExtension), )?; let S = G1Affine::generator(); @@ -907,7 +928,8 @@ mod bls12_381 { .with_extension(Rv32IoTranspilerExtension) .with_extension(PairingTranspilerExtension) .with_extension(ModularTranspilerExtension) - .with_extension(Fp2TranspilerExtension), + .with_extension(Fp2TranspilerExtension) + .with_extension(MemcpyTranspilerExtension), )?; let P = G1Affine::generator(); diff --git a/guest-libs/ruint/Cargo.toml b/guest-libs/ruint/Cargo.toml index 214f315e4f..d4ffe6137b 100644 --- a/guest-libs/ruint/Cargo.toml +++ b/guest-libs/ruint/Cargo.toml @@ -114,6 +114,7 @@ openvm-transpiler.workspace = true openvm-bigint-transpiler.workspace = true openvm-bigint-circuit.workspace = true openvm-rv32im-transpiler.workspace = true +openvm-memcpy-transpiler.workspace = true openvm-toolchain-tests = { workspace = true } eyre.workspace = true diff --git a/guest-libs/ruint/tests/lib.rs b/guest-libs/ruint/tests/lib.rs index 7339336674..df9044290d 100644 --- a/guest-libs/ruint/tests/lib.rs +++ b/guest-libs/ruint/tests/lib.rs @@ -5,6 +5,7 @@ mod tests { use openvm_bigint_transpiler::Int256TranspilerExtension; use openvm_circuit::utils::air_test; use openvm_instructions::exe::VmExe; + use openvm_memcpy_transpiler::MemcpyTranspilerExtension; use openvm_rv32im_transpiler::{ Rv32ITranspilerExtension, Rv32IoTranspilerExtension, Rv32MTranspilerExtension, }; @@ -28,7 +29,8 @@ mod tests { .with_extension(Rv32ITranspilerExtension) .with_extension(Rv32MTranspilerExtension) .with_extension(Rv32IoTranspilerExtension) - .with_extension(Int256TranspilerExtension), + .with_extension(Int256TranspilerExtension) + .with_extension(MemcpyTranspilerExtension), )?; air_test(Int256Rv32Builder, config, openvm_exe); Ok(()) diff --git a/guest-libs/sha2/Cargo.toml b/guest-libs/sha2/Cargo.toml index 573930affb..e56b9eac1f 100644 --- a/guest-libs/sha2/Cargo.toml +++ b/guest-libs/sha2/Cargo.toml @@ -20,6 +20,7 @@ openvm-transpiler = { workspace = true } openvm-sha256-transpiler = { workspace = true } openvm-sha256-circuit = { workspace = true } openvm-rv32im-transpiler = { workspace = true } +openvm-memcpy-transpiler = { workspace = true } openvm-toolchain-tests = { workspace = true } eyre = { workspace = true } diff --git a/guest-libs/sha2/tests/lib.rs b/guest-libs/sha2/tests/lib.rs index adfae8e764..0708e84af6 100644 --- a/guest-libs/sha2/tests/lib.rs +++ b/guest-libs/sha2/tests/lib.rs @@ -3,6 +3,7 @@ mod tests { use eyre::Result; use openvm_circuit::utils::air_test; use openvm_instructions::exe::VmExe; + use openvm_memcpy_transpiler::MemcpyTranspilerExtension; use openvm_rv32im_transpiler::{ Rv32ITranspilerExtension, Rv32IoTranspilerExtension, Rv32MTranspilerExtension, }; @@ -25,7 +26,8 @@ mod tests { .with_extension(Rv32ITranspilerExtension) .with_extension(Rv32MTranspilerExtension) .with_extension(Rv32IoTranspilerExtension) - .with_extension(Sha256TranspilerExtension), + .with_extension(Sha256TranspilerExtension) + .with_extension(MemcpyTranspilerExtension), )?; air_test(Sha256Rv32Builder, config, openvm_exe); Ok(()) diff --git a/guest-libs/verify_stark/tests/integration_test.rs b/guest-libs/verify_stark/tests/integration_test.rs index ea1150976a..1e1caf03c7 100644 --- a/guest-libs/verify_stark/tests/integration_test.rs +++ b/guest-libs/verify_stark/tests/integration_test.rs @@ -33,6 +33,7 @@ fn test_verify_openvm_stark_e2e() -> Result<()> { .rv32m(Default::default()) .io(Default::default()) .native(Default::default()) + .memcpy(Default::default()) .build(); let fri_params = FriParameters::new_for_testing(LEAF_LOG_BLOWUP); let app_config = AppConfig::new_with_leaf_fri_params(fri_params, vm_config.clone(), fri_params);