Skip to content

Commit 0dac69d

Browse files
committed
change fibonacci test to memcpy test
1 parent 1bac01c commit 0dac69d

File tree

7 files changed

+88
-39
lines changed

7 files changed

+88
-39
lines changed

Cargo.lock

Lines changed: 3 additions & 3 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.
Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
11
[app_vm_config.rv32i]
22
[app_vm_config.rv32m]
33
[app_vm_config.io]
4+
[app_vm_config.memcpy]
Lines changed: 41 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,44 @@
1-
use openvm::io::{read, reveal_u32};
1+
use core::ptr;
22

3-
pub fn main() {
4-
let n: u64 = read();
5-
let mut a: u64 = 0;
6-
let mut b: u64 = 1;
7-
for _ in 0..n {
8-
let c: u64 = a.wrapping_add(b);
9-
a = b;
10-
b = c;
3+
openvm::entry!(main);
4+
5+
/// Copies all the elements of `src` into `dst`, starting at element offset `shift` in `dst`; `src` itself is left unchanged.
6+
#[no_mangle]
7+
pub fn append<T>(dst: &mut [T], src: &mut [T], shift: usize) {
8+
let src_len = src.len();
9+
let dst_len = dst.len();
10+
11+
unsafe {
12+
// `wrapping_add` never causes UB by itself, but the caller must ensure
13+
// `shift + src.len() <= dst.len()` so the copy below stays inside `dst`.
14+
let dst_ptr = dst.as_mut_ptr().wrapping_add(shift);
15+
let src_ptr = src.as_ptr();
16+
println!("dst_ptr: {:?}", dst_ptr);
17+
println!("src_ptr: {:?}", src_ptr);
18+
println!("src_len: {:?}", src_len);
19+
20+
// The two regions cannot overlap because `dst` and `src` are both
21+
// mutable slice references, and mutable references are guaranteed
22+
// not to alias.
23+
ptr::copy_nonoverlapping(src_ptr, dst_ptr, src_len);
1124
}
12-
reveal_u32(a as u32, 0);
13-
reveal_u32((a >> 32) as u32, 1);
1425
}
26+
27+
pub fn main() {
28+
let mut a: [u8; 1000] = [1; 1000];
29+
let mut b: [u8; 500] = [2; 500];
30+
31+
let shift: usize = 0;
32+
append(&mut a, &mut b, shift);
33+
34+
for i in 0..1000 {
35+
if i < shift || i >= shift + b.len() {
36+
assert_eq!(a[i], 1);
37+
} else {
38+
assert_eq!(a[i], 2);
39+
}
40+
}
41+
42+
println!("a: {:?}", a);
43+
println!("b: {:?}", b);
44+
}

crates/toolchain/openvm/src/memcpy.s

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -255,7 +255,6 @@ memcpy:
255255
sb a6, 2(a3)
256256
addi a2, a2, -3
257257
addi a3, a4, 16
258-
li a4, 16
259258
.LBBmemcpy0_9:
260259
memcpy_loop 1
261260
addi a4, a3, -13
@@ -268,7 +267,6 @@ memcpy:
268267
.LBBmemcpy0_12:
269268
li a1, 16
270269
bltu a2, a1, .LBBmemcpy0_15
271-
li a1, 15
272270
.LBBmemcpy0_14:
273271
memcpy_loop 0
274272
.LBBmemcpy0_15:
@@ -294,7 +292,6 @@ memcpy:
294292
sb a5, 0(a3)
295293
addi a2, a2, -1
296294
addi a3, a4, 16
297-
li a4, 18
298295
.LBBmemcpy0_20:
299296
memcpy_loop 3
300297
addi a4, a3, -15
@@ -307,7 +304,6 @@ memcpy:
307304
sb a6, 1(a3)
308305
addi a2, a2, -2
309306
addi a3, a4, 16
310-
li a4, 17
311307
.LBBmemcpy0_23:
312308
memcpy_loop 2
313309
addi a4, a3, -14

extensions/memcpy/README.md

Lines changed: 11 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -11,28 +11,20 @@ memcpy_loop shift
1111

1212
Where `shift` is an immediate value (0, 1, 2, or 3) representing the byte alignment shift.
1313

14-
### RISC-V Encoding
15-
- **Opcode**: `0x73` (custom opcode)
16-
- **Funct3**: `0x0` (custom funct3)
17-
- **Immediate**: 12-bit signed immediate for shift value
18-
- **Format**: I-type instruction
1914

2015
### Usage
2116
The `memcpy_loop` instruction is designed to replace repetitive shift-handling code in memcpy implementations. Instead of having separate code blocks for each shift value, you can use a single instruction:
2217

2318
```assembly
24-
# Instead of this repetitive code:
25-
.Lshift_1:
26-
lw a5, 0(a4)
27-
sb a5, 0(a3)
28-
srli a1, a5, 8
29-
sb a1, 1(a3)
30-
# ... more shift handling code
31-
32-
# You can use:
3319
memcpy_loop 1 # Handles shift=1 case
3420
```
3521

22+
Note that you must define `memcpy_loop` before using it. For example, in [memcpy.s](../../crates/toolchain/openvm/src/memcpy.s) it is defined at the beginning of the assembly code as follows:
23+
```assembly
24+
.macro memcpy_loop shift
25+
.word 0x00000072 | (\shift << 12) # opcode 0x72 + shift in immediate field (bits 12-31)
.endm
26+
```
27+
3628
### Benefits
3729
1. **Code Size Reduction**: Eliminates repetitive shift-handling code
3830
2. **Performance**: Optimized implementation in the circuit layer
@@ -84,3 +76,8 @@ RISC-V Assembly → Transpiler Extension → OpenVM Instruction → MemcpyIterat
8476
The extension provides:
8577
- **Transpiler**: `extensions/memcpy/transpiler/` - Translates RISC-V to OpenVM
8678
- **Circuit**: `extensions/memcpy/circuit/` - Implements the instruction logic
79+
80+
81+
# References
82+
83+
- Official Keccak [spec summary](https://keccak.team/keccak_specs_summary.html)

extensions/memcpy/circuit/src/iteration.rs

Lines changed: 28 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@ use openvm_circuit::{
1818
MemoryWriteAuxCols, MemoryWriteBytesAuxRecord,
1919
},
2020
online::{GuestMemory, TracingMemory},
21-
MemoryAddress, MemoryAuxColsFactory,
21+
MemoryAddress, MemoryAuxColsFactory, POINTER_MAX_BITS,
2222
},
2323
};
2424
use openvm_circuit_primitives::{
@@ -441,6 +441,7 @@ where
441441
);
442442
let mut len = read_rv32_register(state.memory.data(), A2_REGISTER_PTR as u32);
443443

444+
// Create a record with var_size = ((len - shift) >> 4) + 1 which is the number of rows in iteration trace
444445
let record = state.ctx.alloc(MultiRowLayout::new(MemcpyIterMetadata {
445446
num_rows: ((len - shift as u32) >> 4) as usize + 1,
446447
}));
@@ -449,7 +450,11 @@ where
449450
record.inner.shift = shift;
450451
record.inner.from_pc = *state.pc;
451452
record.inner.from_timestamp = state.memory.timestamp;
453+
record.inner.dest = dest;
454+
record.inner.source = source;
455+
record.inner.len = len;
452456

457+
// Fill record.var for the first row of iteration trace
453458
if shift != 0 {
454459
source -= 12;
455460
record.var[0].data[3] = tracing_read(
@@ -460,6 +465,7 @@ where
460465
);
461466
};
462467

468+
// Fill record.var for the rest of the rows of iteration trace
463469
let mut idx = 1;
464470
while len - shift as u32 > 15 {
465471
let writes_data: [[u8; MEMCPY_LOOP_NUM_LIMBS]; 4] = array::from_fn(|i| {
@@ -540,9 +546,9 @@ where
540546
&mut len_data,
541547
);
542548

543-
record.inner.dest = u32::from_le_bytes(dest_data);
544-
record.inner.source = u32::from_le_bytes(source_data);
545-
record.inner.len = u32::from_le_bytes(len_data);
549+
debug_assert_eq!(record.inner.dest, u32::from_le_bytes(dest_data));
550+
debug_assert_eq!(record.inner.source, u32::from_le_bytes(source_data));
551+
debug_assert_eq!(record.inner.len, u32::from_le_bytes(len_data));
546552

547553
*state.pc = state.pc.wrapping_add(DEFAULT_PC_STEP);
548554

@@ -580,7 +586,7 @@ impl<F: PrimeField32> TraceFiller<F> for MemcpyIterFiller {
580586
num_loops += 1;
581587
num_iters += num_rows;
582588
}
583-
// tracing::info!("num_loops: {:?}, num_iters: {:?}", num_loops, num_iters);
589+
tracing::info!("num_loops: {:?}, num_iters: {:?}, sizes: {:?}", num_loops, num_iters, sizes);
584590

585591
chunks
586592
.par_iter_mut()
@@ -594,6 +600,7 @@ impl<F: PrimeField32> TraceFiller<F> for MemcpyIterFiller {
594600
)
595601
};
596602

603+
tracing::info!("shift: {:?}", record.inner.shift);
597604
// Fill memcpy loop record
598605
self.memcpy_loop_chip.add_new_loop(
599606
mem_helper,
@@ -606,6 +613,7 @@ impl<F: PrimeField32> TraceFiller<F> for MemcpyIterFiller {
606613
record.inner.register_aux.clone(),
607614
);
608615

616+
// Calculate the timestamp for the last memory access
609617
// 4 reads + 4 writes per iteration + (shift != 0) read for the loop header
610618
let timestamp = record.inner.from_timestamp
611619
+ ((num_rows - 1) << 3) as u32
@@ -906,6 +914,7 @@ unsafe fn execute_e12_impl<F: PrimeField32, CTX: ExecutionCtxTrait>(
906914
) -> u32 {
907915
let shift = pre_compute.c;
908916
let mut height = 1;
917+
// Read dest and source from registers
909918
let (dest, source) = if shift == 0 {
910919
(
911920
vm_state.vm_read::<u8, 4>(RV32_REGISTER_AS, A3_REGISTER_PTR as u32),
@@ -917,19 +926,31 @@ unsafe fn execute_e12_impl<F: PrimeField32, CTX: ExecutionCtxTrait>(
917926
vm_state.vm_read::<u8, 4>(RV32_REGISTER_AS, A3_REGISTER_PTR as u32),
918927
)
919928
};
929+
// Read length from a2 register
920930
let len = vm_state.vm_read::<u8, 4>(RV32_REGISTER_AS, A2_REGISTER_PTR as u32);
921931

922932
let mut dest = u32::from_le_bytes(dest);
923-
let mut source = u32::from_le_bytes(source);
933+
let mut source = u32::from_le_bytes(source) - 12 * (shift != 0) as u32;
924934
let mut len = u32::from_le_bytes(len);
925935

936+
// Check address ranges are valid
937+
debug_assert!(dest < (1 << POINTER_MAX_BITS));
938+
debug_assert!((source - 4 * (shift != 0) as u32) < (1 << POINTER_MAX_BITS));
939+
let to_dest = dest + ((len - shift as u32) & !15);
940+
let to_source = source + ((len - shift as u32) & !15);
941+
debug_assert!(to_dest <= (1 << POINTER_MAX_BITS));
942+
debug_assert!(to_source <= (1 << POINTER_MAX_BITS));
943+
// Make sure the destination and source are not overlapping
944+
debug_assert!(to_dest <= source || to_source <= dest);
945+
946+
// Read the previous data from memory if shift != 0
926947
let mut prev_data = if shift == 0 {
927948
[0; 4]
928949
} else {
929-
source -= 12;
930950
vm_state.vm_read::<u8, 4>(RV32_MEMORY_AS, source - 4)
931951
};
932952

953+
// Run iterations
933954
while len - shift as u32 > 15 {
934955
for i in 0..4 {
935956
let data = vm_state.vm_read::<u8, 4>(RV32_MEMORY_AS, source + 4 * i);

extensions/memcpy/circuit/src/loops.rs

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -181,6 +181,7 @@ impl<AB: InteractionBuilder> Air<AB> for MemcpyLoopAir {
181181
.eval(builder, local.is_valid);
182182

183183
// Generate 16-bit limbs for range checking
184+
// dest, to_dest, source - 12 * is_shift_non_zero, to_source - 12 * is_shift_non_zero
184185
let dest_u16_limbs = u8_word_to_u16(local.dest);
185186
let to_dest_u16_limbs = u8_word_to_u16(local.to_dest);
186187
let source_u16_limbs = [
@@ -213,12 +214,14 @@ impl<AB: InteractionBuilder> Air<AB> for MemcpyLoopAir {
213214
];
214215

215216
range_check_data.iter().for_each(|data| {
217+
// Check the low 16 bits of dest and source, and make sure they are a multiple of 4
216218
self.range_bus
217219
.range_check(
218220
data[0].clone() * AB::F::from_canonical_u32(4).inverse(),
219221
MEMCPY_LOOP_LIMB_BITS * 2 - 2,
220222
)
221223
.eval(builder, local.is_valid);
224+
// Check the high 16 bits of dest and source, make sure they are in the range [0, 2^pointer_max_bits - 2^MEMCPY_LOOP_LIMB_BITS)
222225
self.range_bus
223226
.range_check(
224227
data[1].clone(),
@@ -395,6 +398,7 @@ impl MemcpyLoopChip {
395398
let height = next_power_of_two_or_zero(self.records.lock().unwrap().len());
396399
let mut rows = F::zero_vec(height * NUM_MEMCPY_LOOP_COLS);
397400

401+
// TODO: run in parallel
398402
for (i, record) in self.records.lock().unwrap().iter().enumerate() {
399403
let row = &mut rows[i * NUM_MEMCPY_LOOP_COLS..(i + 1) * NUM_MEMCPY_LOOP_COLS];
400404
let cols: &mut MemcpyLoopCols<F> = row.borrow_mut();

0 commit comments

Comments
 (0)