Skip to content

Commit 1bac01c

Browse files
committed
fix: reduce memcpyIterAir degree from 4 to 3
1 parent 1d801a5 commit 1bac01c

File tree

11 files changed

+163
-27
lines changed

11 files changed

+163
-27
lines changed

Cargo.lock

Lines changed: 3 additions & 1 deletion
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

extensions/bigint/circuit/Cargo.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@ openvm-rv32im-circuit = { workspace = true }
2121
openvm-rv32-adapters = { workspace = true }
2222
openvm-bigint-transpiler = { workspace = true }
2323
openvm-rv32im-transpiler = { workspace = true }
24+
openvm-memcpy-circuit = { workspace = true }
2425

2526
derive-new.workspace = true
2627
derive_more = { workspace = true, features = ["from"] }

extensions/bigint/circuit/src/extension/mod.rs

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@ use openvm_circuit_primitives::{
2525
},
2626
};
2727
use openvm_instructions::{program::DEFAULT_PC_STEP, LocalOpcode};
28+
use openvm_memcpy_circuit::MemcpyCpuProverExt;
2829
use openvm_rv32im_circuit::Rv32ImCpuProverExt;
2930
use openvm_stark_backend::{
3031
config::{StarkGenericConfig, Val},
@@ -373,6 +374,11 @@ where
373374
&config.bigint,
374375
inventory,
375376
)?;
377+
VmProverExtension::<E, _, _>::extend_prover(
378+
&MemcpyCpuProverExt,
379+
&config.memcpy,
380+
inventory,
381+
)?;
376382
Ok(chip_complex)
377383
}
378384
}

extensions/bigint/circuit/src/lib.rs

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@ use openvm_circuit::{
66
system::SystemExecutor,
77
};
88
use openvm_circuit_derive::{PreflightExecutor, VmConfig};
9+
use openvm_memcpy_circuit::{Memcpy, MemcpyExecutor};
910
use openvm_rv32_adapters::{
1011
Rv32HeapAdapterAir, Rv32HeapAdapterExecutor, Rv32HeapAdapterFiller, Rv32HeapBranchAdapterAir,
1112
Rv32HeapBranchAdapterExecutor, Rv32HeapBranchAdapterFiller,
@@ -175,6 +176,8 @@ pub struct Int256Rv32Config {
175176
pub io: Rv32Io,
176177
#[extension]
177178
pub bigint: Int256,
179+
#[extension]
180+
pub memcpy: Memcpy,
178181
}
179182

180183
// Default implementation uses no init file
@@ -188,6 +191,7 @@ impl Default for Int256Rv32Config {
188191
rv32m: Rv32M::default(),
189192
io: Rv32Io,
190193
bigint: Int256::default(),
194+
memcpy: Memcpy,
191195
}
192196
}
193197
}

extensions/keccak256/circuit/Cargo.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@ openvm-circuit = { workspace = true }
1818
openvm-circuit-derive = { workspace = true }
1919
openvm-instructions = { workspace = true }
2020
openvm-rv32im-circuit = { workspace = true }
21+
openvm-memcpy-circuit = { workspace = true }
2122
openvm-keccak256-transpiler = { workspace = true }
2223

2324
p3-keccak-air = { workspace = true }

extensions/keccak256/circuit/src/extension/mod.rs

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@ use openvm_circuit_primitives::bitwise_op_lookup::{
2020
};
2121
use openvm_instructions::*;
2222
use openvm_keccak256_transpiler::Rv32KeccakOpcode;
23+
use openvm_memcpy_circuit::{Memcpy, MemcpyCpuProverExt, MemcpyExecutor};
2324
use openvm_rv32im_circuit::{
2425
Rv32I, Rv32IExecutor, Rv32ImCpuProverExt, Rv32Io, Rv32IoExecutor, Rv32M, Rv32MExecutor,
2526
};
@@ -62,6 +63,8 @@ pub struct Keccak256Rv32Config {
6263
pub io: Rv32Io,
6364
#[extension]
6465
pub keccak: Keccak256,
66+
#[extension]
67+
pub memcpy: Memcpy,
6568
}
6669

6770
impl Default for Keccak256Rv32Config {
@@ -72,6 +75,7 @@ impl Default for Keccak256Rv32Config {
7275
rv32m: Rv32M::default(),
7376
io: Rv32Io,
7477
keccak: Keccak256,
78+
memcpy: Memcpy,
7579
}
7680
}
7781
}
@@ -111,6 +115,11 @@ where
111115
&config.keccak,
112116
inventory,
113117
)?;
118+
VmProverExtension::<E, _, _>::extend_prover(
119+
&MemcpyCpuProverExt,
120+
&config.memcpy,
121+
inventory,
122+
)?;
114123
Ok(chip_complex)
115124
}
116125
}

extensions/memcpy/circuit/Cargo.toml

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,6 @@ openvm-instructions = { workspace = true }
1616
openvm-stark-backend = { workspace = true }
1717
openvm-memcpy-transpiler = { path = "../transpiler" }
1818
openvm-rv32im-transpiler = { workspace = true }
19-
openvm-rv32im-circuit = { workspace = true }
2019

2120
derive-new.workspace = true
2221
derive_more = { workspace = true, features = ["from"] }

extensions/memcpy/circuit/src/iteration.rs

Lines changed: 36 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,6 @@ use openvm_instructions::{
3434
LocalOpcode,
3535
};
3636
use openvm_memcpy_transpiler::Rv32MemcpyOpcode;
37-
use openvm_rv32im_circuit::adapters::{read_rv32_register, tracing_read, tracing_write};
3837
use openvm_stark_backend::{
3938
interaction::InteractionBuilder,
4039
p3_air::{Air, AirBuilder, BaseAir},
@@ -45,8 +44,7 @@ use openvm_stark_backend::{
4544
};
4645

4746
use crate::{
48-
bus::MemcpyBus, MemcpyLoopChip, A1_REGISTER_PTR, A2_REGISTER_PTR, A3_REGISTER_PTR,
49-
A4_REGISTER_PTR,
47+
bus::MemcpyBus, read_rv32_register, tracing_read, tracing_write, MemcpyLoopChip, A1_REGISTER_PTR, A2_REGISTER_PTR, A3_REGISTER_PTR, A4_REGISTER_PTR
5048
};
5149
// Import constants from lib.rs
5250
use crate::{MEMCPY_LOOP_LIMB_BITS, MEMCPY_LOOP_NUM_LIMBS};
@@ -58,10 +56,12 @@ pub struct MemcpyIterCols<T> {
5856
pub dest: T,
5957
pub source: T,
6058
pub len: [T; 2],
61-
pub shift: [T; 2],
59+
// One-hot encoding of the shift amount -- 0: [0, 0, 0], 1: [1, 0, 0], 2: [0, 1, 0], 3: [0, 0, 1]
60+
pub shift: [T; 3],
6261
pub is_valid: T,
6362
pub is_valid_not_start: T,
64-
pub is_shift_non_zero: T,
63+
// This should be 0 if is_valid = 0. We use this to determine whether we need to read data_4.
64+
pub is_shift_non_zero_or_not_start: T,
6565
// -1 for the first iteration, 1 for the last iteration, 0 for the middle iterations
6666
pub is_boundary: T,
6767
pub data_1: [T; MEMCPY_LOOP_NUM_LIMBS],
@@ -100,16 +100,21 @@ impl<AB: InteractionBuilder> Air<AB> for MemcpyIterAir {
100100

101101
let timestamp: AB::Var = local.timestamp;
102102
let mut timestamp_delta: AB::Expr = AB::Expr::ZERO;
103-
let mut timestamp_pp = |timestamp_increase_value: AB::Expr| {
104-
timestamp_delta += timestamp_increase_value.clone();
103+
let mut timestamp_pp = |timestamp_increase_value: AB::Var| {
104+
timestamp_delta += timestamp_increase_value.into();
105105
timestamp + timestamp_delta.clone() - timestamp_increase_value.clone()
106106
};
107107

108-
let shift = local.shift[1] * AB::Expr::TWO + local.shift[0];
109-
let is_shift_zero = not::<AB::Expr>(local.is_shift_non_zero);
110-
let is_shift_one = and::<AB::Expr>(local.shift[0], not::<AB::Expr>(local.shift[1]));
111-
let is_shift_two = and::<AB::Expr>(not::<AB::Expr>(local.shift[0]), local.shift[1]);
112-
let is_shift_three = and::<AB::Expr>(local.shift[0], local.shift[1]);
108+
let shift = local.shift.iter().enumerate().fold(AB::Expr::ZERO, |acc, (i, x)| {
109+
acc + (*x) * AB::Expr::from_canonical_u32(i as u32 + 1)
110+
});
111+
let is_shift_non_zero = local.shift.iter().fold(AB::Expr::ZERO, |acc, x| {
112+
acc + (*x)
113+
});
114+
let is_shift_zero = not::<AB::Expr>(is_shift_non_zero.clone());
115+
let is_shift_one = local.shift[0];
116+
let is_shift_two = local.shift[1];
117+
let is_shift_three = local.shift[2];
113118

114119
let is_end =
115120
(local.is_boundary + AB::Expr::ONE) * local.is_boundary * (AB::F::TWO).inverse();
@@ -166,8 +171,9 @@ impl<AB: InteractionBuilder> Air<AB> for MemcpyIterAir {
166171

167172
builder.assert_bool(local.is_valid);
168173
local.shift.iter().for_each(|x| builder.assert_bool(*x));
174+
builder.assert_bool(is_shift_non_zero.clone());
169175
builder.assert_bool(local.is_valid_not_start);
170-
builder.assert_bool(local.is_shift_non_zero);
176+
builder.assert_bool(local.is_shift_non_zero_or_not_start);
171177
// is_boundary is either -1, 0 or 1
172178
builder.assert_tern(local.is_boundary + AB::Expr::ONE);
173179

@@ -177,8 +183,8 @@ impl<AB: InteractionBuilder> Air<AB> for MemcpyIterAir {
177183
and::<AB::Expr>(local.is_valid, is_not_start),
178184
);
179185

180-
// is_shift_non_zero is correct
181-
builder.assert_eq(local.is_shift_non_zero, or::<AB::Expr>(local.shift[0], local.shift[1]));
186+
// is_shift_non_zero_or_not_start is correct
187+
builder.assert_eq(local.is_shift_non_zero_or_not_start, or::<AB::Expr>(is_shift_non_zero.clone(), local.is_valid_not_start));
182188

183189
// if !is_valid, then is_boundary = 0, shift = 0 (we will use this assumption later)
184190
let mut is_not_valid_when = builder.when(not::<AB::Expr>(local.is_valid));
@@ -193,8 +199,9 @@ impl<AB: InteractionBuilder> Air<AB> for MemcpyIterAir {
193199
is_valid_not_start_when
194200
.assert_eq(local.source, prev.source + AB::Expr::from_canonical_u32(16));
195201
is_valid_not_start_when.assert_eq(local.dest, prev.dest + AB::Expr::from_canonical_u32(16));
196-
is_valid_not_start_when.assert_eq(local.shift[0], prev.shift[0]);
197-
is_valid_not_start_when.assert_eq(local.shift[1], prev.shift[1]);
202+
local.shift.iter().zip(prev.shift.iter()).for_each(|(local_shift, prev_shift)| {
203+
is_valid_not_start_when.assert_eq(*local_shift, *prev_shift);
204+
});
198205

199206
// make sure if previous row is valid and not end, then local.is_valid = 1
200207
builder
@@ -205,7 +212,7 @@ impl<AB: InteractionBuilder> Air<AB> for MemcpyIterAir {
205212
// is_shift_non_zero is a degree-1 sum of the one-hot shift columns; the condition below is also kept at degree 1
206213
builder
207214
.when(not::<AB::Expr>(prev.is_valid_not_start) - not::<AB::Expr>(prev.is_valid))
208-
.assert_eq(local.timestamp, prev.timestamp + local.is_shift_non_zero);
215+
.assert_eq(local.timestamp, prev.timestamp + is_shift_non_zero);
209216

210217
// if prev.is_valid_not_start and local.is_valid_not_start, then timestamp=prev_timestamp+8
211218
// prev.is_valid_not_start is the opposite of previous condition
@@ -247,9 +254,9 @@ impl<AB: InteractionBuilder> Air<AB> for MemcpyIterAir {
247254
.enumerate()
248255
.for_each(|(idx, data)| {
249256
let is_valid_read = if idx == 3 {
250-
or::<AB::Expr>(local.is_shift_non_zero, local.is_valid_not_start)
257+
local.is_shift_non_zero_or_not_start
251258
} else {
252-
local.is_valid_not_start.into()
259+
local.is_valid_not_start
253260
};
254261

255262
self.memory_bridge
@@ -274,7 +281,7 @@ impl<AB: InteractionBuilder> Air<AB> for MemcpyIterAir {
274281
local.dest - AB::Expr::from_canonical_usize(16 - idx * 4),
275282
),
276283
data.clone(),
277-
timestamp_pp(local.is_valid_not_start.into()),
284+
timestamp_pp(local.is_valid_not_start),
278285
&local.write_aux[idx],
279286
)
280287
.eval(builder, local.is_valid_not_start);
@@ -701,11 +708,15 @@ impl<F: PrimeField32> TraceFiller<F> for MemcpyIterFiller {
701708
} else {
702709
F::ZERO
703710
};
704-
cols.is_shift_non_zero = F::from_canonical_u8((record.inner.shift != 0) as u8);
705-
cols.is_valid_not_start = F::from_canonical_u8(1 - is_start as u8);
711+
cols.is_shift_non_zero_or_not_start = F::from_bool(record.inner.shift != 0 || !is_start);
712+
cols.is_valid_not_start = F::from_bool(!is_start);
706713
cols.is_valid = F::ONE;
707-
cols.shift = [record.inner.shift & 1, record.inner.shift >> 1]
708-
.map(F::from_canonical_u8);
714+
cols.shift = [
715+
record.inner.shift == 1,
716+
record.inner.shift == 2,
717+
record.inner.shift == 3,
718+
]
719+
.map(F::from_bool);
709720
cols.len = [len & 0xffff, len >> 16].map(F::from_canonical_u32);
710721
cols.source = F::from_canonical_u32(source);
711722
cols.dest = F::from_canonical_u32(dest);

extensions/memcpy/circuit/src/lib.rs

Lines changed: 95 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,8 @@ pub use bus::*;
77
pub use extension::*;
88
pub use iteration::*;
99
pub use loops::*;
10+
use openvm_circuit::system::memory::{merkle::public_values::PUBLIC_VALUES_AS, online::{GuestMemory, TracingMemory}};
11+
use openvm_instructions::riscv::{RV32_MEMORY_AS, RV32_REGISTER_AS};
1012

1113
// ==== Do not change these constants! ====
1214
pub const MEMCPY_LOOP_NUM_LIMBS: usize = 4;
@@ -16,3 +18,96 @@ pub const A1_REGISTER_PTR: usize = 11 * 4;
1618
pub const A2_REGISTER_PTR: usize = 12 * 4;
1719
pub const A3_REGISTER_PTR: usize = 13 * 4;
1820
pub const A4_REGISTER_PTR: usize = 14 * 4;
21+
22+
23+
// TODO: These are duplicated from extensions/rv32im/circuit/src/adapters/mod.rs
24+
// to prevent cyclic dependencies. Fix this.
25+
26+
#[inline(always)]
27+
pub fn memory_read<const N: usize>(memory: &GuestMemory, address_space: u32, ptr: u32) -> [u8; N] {
28+
debug_assert!(
29+
address_space == RV32_REGISTER_AS
30+
|| address_space == RV32_MEMORY_AS
31+
|| address_space == PUBLIC_VALUES_AS,
32+
);
33+
34+
// SAFETY:
35+
// - address space `RV32_REGISTER_AS` and `RV32_MEMORY_AS` will always have cell type `u8` and
36+
// minimum alignment of `RV32_REGISTER_NUM_LIMBS`
37+
unsafe { memory.read::<u8, N>(address_space, ptr) }
38+
}
39+
40+
/// Atomic read operation which increments the timestamp by 1.
41+
/// Returns `(t_prev, [ptr:N]_{address_space})` where `t_prev` is the timestamp of the last memory
42+
/// access.
43+
#[inline(always)]
44+
pub fn timed_read<const N: usize>(
45+
memory: &mut TracingMemory,
46+
address_space: u32,
47+
ptr: u32,
48+
) -> (u32, [u8; N]) {
49+
debug_assert!(
50+
address_space == RV32_REGISTER_AS
51+
|| address_space == RV32_MEMORY_AS
52+
|| address_space == PUBLIC_VALUES_AS
53+
);
54+
55+
// SAFETY:
56+
// - address space `RV32_REGISTER_AS` and `RV32_MEMORY_AS` will always have cell type `u8` and
57+
// minimum alignment of `MEMCPY_LOOP_NUM_LIMBS`
58+
unsafe { memory.read::<u8, N, MEMCPY_LOOP_NUM_LIMBS>(address_space, ptr) }
59+
}
60+
61+
#[inline(always)]
62+
pub fn timed_write<const N: usize>(
63+
memory: &mut TracingMemory,
64+
address_space: u32,
65+
ptr: u32,
66+
data: [u8; N],
67+
) -> (u32, [u8; N]) {
68+
debug_assert!(
69+
address_space == RV32_REGISTER_AS
70+
|| address_space == RV32_MEMORY_AS
71+
|| address_space == PUBLIC_VALUES_AS
72+
);
73+
74+
// SAFETY:
75+
// - address space `RV32_REGISTER_AS` and `RV32_MEMORY_AS` will always have cell type `u8` and
76+
// minimum alignment of `MEMCPY_LOOP_NUM_LIMBS`
77+
unsafe { memory.write::<u8, N, MEMCPY_LOOP_NUM_LIMBS>(address_space, ptr, data) }
78+
}
79+
80+
/// Reads the value at `ptr` in `address_space` from memory and records the previous access timestamp in the `prev_timestamp` buffer.
81+
/// Trace generation relevant to this memory access can be done fully from the recorded buffer.
82+
#[inline(always)]
83+
pub fn tracing_read<const N: usize>(
84+
memory: &mut TracingMemory,
85+
address_space: u32,
86+
ptr: u32,
87+
prev_timestamp: &mut u32,
88+
) -> [u8; N] {
89+
let (t_prev, data) = timed_read(memory, address_space, ptr);
90+
*prev_timestamp = t_prev;
91+
data
92+
}
93+
94+
/// Writes `data` to `ptr` in `address_space` and records the previous access timestamp and previous data in the `prev_timestamp` and `prev_data` buffers.
95+
/// Trace generation relevant to this memory access can be done fully from the recorded buffer.
96+
#[inline(always)]
97+
pub fn tracing_write<const N: usize>(
98+
memory: &mut TracingMemory,
99+
address_space: u32,
100+
ptr: u32,
101+
data: [u8; N],
102+
prev_timestamp: &mut u32,
103+
prev_data: &mut [u8; N],
104+
) {
105+
let (t_prev, data_prev) = timed_write(memory, address_space, ptr, data);
106+
*prev_timestamp = t_prev;
107+
*prev_data = data_prev;
108+
}
109+
110+
#[inline(always)]
111+
pub fn read_rv32_register(memory: &GuestMemory, ptr: u32) -> u32 {
112+
u32::from_le_bytes(memory_read(memory, RV32_REGISTER_AS, ptr))
113+
}

extensions/rv32im/circuit/Cargo.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@ openvm-circuit = { workspace = true }
1818
openvm-circuit-derive = { workspace = true }
1919
openvm-instructions = { workspace = true }
2020
openvm-rv32im-transpiler = { workspace = true }
21+
openvm-memcpy-circuit = { workspace = true }
2122
strum.workspace = true
2223
derive-new.workspace = true
2324
derive_more = { workspace = true, features = ["from"] }

0 commit comments

Comments
 (0)