diff --git a/crates/wizer/src/component.rs b/crates/wizer/src/component.rs
index f87a4245bfc6..ec0385538d70 100644
--- a/crates/wizer/src/component.rs
+++ b/crates/wizer/src/component.rs
@@ -58,7 +58,8 @@ pub trait ComponentInstanceState: Send {
         &mut self,
         instance: &str,
         func: &str,
-    ) -> impl Future<Output = Vec<u8>> + Send;
+        contents: impl FnOnce(&[u8]) + Send,
+    ) -> impl Future<Output = ()> + Send;
 
     /// Same as [`Self::call_func_ret_list_u8`], but for the `s32` WIT type.
     fn call_func_ret_s32(&mut self, instance: &str, func: &str)
diff --git a/crates/wizer/src/component/snapshot.rs b/crates/wizer/src/component/snapshot.rs
index 00f401eaf56f..a78e86dc7a6d 100644
--- a/crates/wizer/src/component/snapshot.rs
+++ b/crates/wizer/src/component/snapshot.rs
@@ -115,7 +115,7 @@ where
         }
     }
 
-    async fn memory_contents(&mut self, name: &str) -> Vec<u8> {
+    async fn memory_contents(&mut self, name: &str, contents: impl FnOnce(&[u8]) + Send) {
         let Accessor::Memory {
             accessor_export_name,
             ..
@@ -124,7 +124,7 @@ where
             panic!("expected memory accessor for {name}");
         };
         self.ctx
-            .call_func_ret_list_u8(WIZER_INSTANCE, accessor_export_name)
+            .call_func_ret_list_u8(WIZER_INSTANCE, accessor_export_name, contents)
             .await
     }
 }
diff --git a/crates/wizer/src/component/wasmtime.rs b/crates/wizer/src/component/wasmtime.rs
index 83d8370d523e..bf9f7d87e6ec 100644
--- a/crates/wizer/src/component/wasmtime.rs
+++ b/crates/wizer/src/component/wasmtime.rs
@@ -1,7 +1,9 @@
 use crate::Wizer;
 use crate::component::ComponentInstanceState;
 use anyhow::{Context, anyhow};
-use wasmtime::component::{Component, ComponentExportIndex, Instance, Lift, types::ComponentItem};
+use wasmtime::component::{
+    Component, ComponentExportIndex, Instance, Lift, WasmList, types::ComponentItem,
+};
 use wasmtime::{Result, Store};
 
 impl Wizer {
@@ -84,7 +86,12 @@ pub struct WasmtimeWizerComponent<'a, T: 'static> {
 }
 
 impl<T> WasmtimeWizerComponent<'_, T> {
-    async fn call_func<R>(&mut self, instance: &str, func: &str) -> R
+    async fn call_func<R, R2>(
+        &mut self,
+        instance: &str,
+        func: &str,
+        use_ret: impl FnOnce(&mut Store<T>, R) -> R2,
+    ) -> R2
     where
         R: Lift + 'static,
     {
@@ -102,29 +109,40 @@ impl<T> WasmtimeWizerComponent<'_, T> {
             .get_typed_func::<(), (R,)>(&mut *self.store, func_export)
             .unwrap();
         let ret = func.call_async(&mut *self.store, ()).await.unwrap().0;
+        let ret = use_ret(&mut *self.store, ret);
         func.post_return_async(&mut *self.store).await.unwrap();
         ret
     }
 }
 
 impl<T: Send> ComponentInstanceState for WasmtimeWizerComponent<'_, T> {
-    async fn call_func_ret_list_u8(&mut self, instance: &str, func: &str) -> Vec<u8> {
-        self.call_func(instance, func).await
+    async fn call_func_ret_list_u8(
+        &mut self,
+        instance: &str,
+        func: &str,
+        contents: impl FnOnce(&[u8]) + Send,
+    ) {
+        self.call_func(instance, func, |store, list: WasmList<u8>| {
+            contents(list.as_le_slice(&store));
+        })
+        .await
    }
 
     async fn call_func_ret_s32(&mut self, instance: &str, func: &str) -> i32 {
-        self.call_func(instance, func).await
+        self.call_func(instance, func, |_, r| r).await
     }
 
     async fn call_func_ret_s64(&mut self, instance: &str, func: &str) -> i64 {
-        self.call_func(instance, func).await
+        self.call_func(instance, func, |_, r| r).await
     }
 
     async fn call_func_ret_f32(&mut self, instance: &str, func: &str) -> u32 {
-        self.call_func::<f32>(instance, func).await.to_bits()
+        self.call_func(instance, func, |_, r: f32| r.to_bits())
+            .await
     }
 
     async fn call_func_ret_f64(&mut self, instance: &str, func: &str) -> u64 {
-        self.call_func::<f64>(instance, func).await.to_bits()
+        self.call_func(instance, func, |_, r: f64| r.to_bits())
+            .await
     }
 }
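The signature change above replaces the returned `Vec<u8>` with a borrowing `FnOnce(&[u8])` callback, so a snapshot no longer copies each exported memory into a fresh allocation before scanning it. A minimal sketch of what an implementation of the new shape looks like; the `MyInstance` type and its `memory` field are hypothetical, not part of this diff:

```rust
struct MyInstance {
    memory: Vec<u8>,
}

impl MyInstance {
    // Old shape: `async fn memory_contents(&mut self, name: &str) -> Vec<u8>`,
    // which forced a full copy per call. New shape: the caller borrows the
    // bytes in place for however long the callback runs.
    async fn memory_contents(&mut self, _name: &str, contents: impl FnOnce(&[u8]) + Send) {
        contents(&self.memory);
    }
}

fn main() {
    let mut inst = MyInstance { memory: vec![1, 2, 3] };
    // Constructing the future is enough to type-check the callback shape;
    // a real embedder would `.await` it on its executor of choice.
    let _fut = inst.memory_contents("memory", |bytes| assert_eq!(bytes.len(), 3));
}
```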
diff --git a/crates/wizer/src/lib.rs b/crates/wizer/src/lib.rs
index 3975f8dcb381..f5f04189c6a9 100644
--- a/crates/wizer/src/lib.rs
+++ b/crates/wizer/src/lib.rs
@@ -386,5 +386,9 @@ pub trait InstanceState {
     /// # Panics
     ///
     /// This function panics if `name` isn't an exported memory.
-    fn memory_contents(&mut self, name: &str) -> impl Future<Output = Vec<u8>> + Send;
+    fn memory_contents(
+        &mut self,
+        name: &str,
+        contents: impl FnOnce(&[u8]) + Send,
+    ) -> impl Future<Output = ()> + Send;
 }
diff --git a/crates/wizer/src/rewrite.rs b/crates/wizer/src/rewrite.rs
index df37e6b6451c..5f854af48589 100644
--- a/crates/wizer/src/rewrite.rs
+++ b/crates/wizer/src/rewrite.rs
@@ -25,11 +25,12 @@ impl Wizer {
         // than the original, uninitialized data segments.
         let add_data_segments = |data_section: &mut wasm_encoder::DataSection| {
             for seg in &snapshot.data_segments {
-                data_section.active(
-                    seg.memory_index,
-                    &ConstExpr::i32_const(seg.offset as i32),
-                    seg.data().iter().copied(),
-                );
+                let offset = if seg.is64 {
+                    ConstExpr::i64_const(seg.offset.cast_signed())
+                } else {
+                    ConstExpr::i32_const(u32::try_from(seg.offset).unwrap().cast_signed())
+                };
+                data_section.active(seg.memory_index, &offset, seg.data.iter().copied());
             }
         };
 
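A note on the offset encoding above: `seg.offset` is now a `u64`, and a 64-bit memory gets an `i64.const` init expression while a 32-bit memory keeps `i32.const` with a checked narrowing. `cast_signed` is the right conversion because const-expression immediates are signed integers while memory64 addresses are unsigned, so the bit pattern must survive the round trip. A standalone sketch of that property:

```rust
// Requires Rust 1.87+, where `cast_signed`/`cast_unsigned` are stable.
fn main() {
    // An offset with the high bit set, i.e. at 2^63.
    let offset: u64 = 0x8000_0000_0000_0000;

    // The same bits reinterpreted as the (negative) signed immediate that
    // the `i64.const` encoding carries.
    let imm: i64 = offset.cast_signed();
    assert_eq!(imm, i64::MIN);

    // The engine reads the immediate back as an unsigned address losslessly.
    assert_eq!(imm.cast_unsigned(), offset);
}
```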
diff --git a/crates/wizer/src/snapshot.rs b/crates/wizer/src/snapshot.rs
index 387e856ab35e..cbd246770d54 100644
--- a/crates/wizer/src/snapshot.rs
+++ b/crates/wizer/src/snapshot.rs
@@ -2,7 +2,7 @@ use crate::InstanceState;
 use crate::info::ModuleContext;
 use rayon::iter::{IntoParallelIterator, ParallelExtend, ParallelIterator};
 use std::convert::TryFrom;
-use std::sync::Arc;
+use std::ops::Range;
 
 /// The maximum number of data segments that we will emit. Most
 /// engines support more than this, but we want to leave some
@@ -41,47 +41,13 @@ pub struct DataSegment {
     pub memory_index: u32,
 
     /// This data segment's initialized memory that it originated from.
-    pub memory: Arc<Vec<u8>>,
+    pub data: Vec<u8>,
 
     /// The offset within the memory that `data` should be copied to.
-    pub offset: u32,
+    pub offset: u64,
 
-    /// This segment's length.
-    pub len: u32,
-}
-
-impl DataSegment {
-    pub fn data(&self) -> &[u8] {
-        let start = usize::try_from(self.offset).unwrap();
-        let end = start + usize::try_from(self.len).unwrap();
-        &self.memory[start..end]
-    }
-}
-
-impl DataSegment {
-    /// What is the gap between two consecutive data segments?
-    ///
-    /// `self` must be in front of `other` and they must not overlap with each
-    /// other.
-    fn gap(&self, other: &Self) -> u32 {
-        debug_assert_eq!(self.memory_index, other.memory_index);
-        debug_assert!(self.offset + self.len <= other.offset);
-        other.offset - (self.offset + self.len)
-    }
-
-    /// Merge two consecutive data segments.
-    ///
-    /// `self` must be in front of `other` and they must not overlap with each
-    /// other.
-    fn merge(&self, other: &Self) -> DataSegment {
-        let gap = self.gap(other);
-
-        DataSegment {
-            offset: self.offset,
-            len: self.len + gap + other.len,
-            ..self.clone()
-        }
-    }
+    /// Whether or not the memory at `memory_index` is a 64-bit memory.
+    pub is64: bool,
 }
 
 /// Snapshot the given instance's globals, memories, and instances from the Wasm
@@ -116,6 +82,34 @@ async fn snapshot_globals(
     ret
 }
 
+#[derive(Clone)]
+struct DataSegmentRange {
+    memory_index: u32,
+    range: Range<usize>,
+}
+
+impl DataSegmentRange {
+    /// What is the gap between two consecutive data segments?
+    ///
+    /// `self` must be in front of `other` and they must not overlap with each
+    /// other.
+    fn gap(&self, other: &Self) -> usize {
+        debug_assert_eq!(self.memory_index, other.memory_index);
+        debug_assert!(self.range.end <= other.range.start);
+        other.range.start - self.range.end
+    }
+
+    /// Merge two consecutive data segments.
+    ///
+    /// `self` must be in front of `other` and they must not overlap with each
+    /// other.
+    fn merge(&mut self, other: &Self) {
+        debug_assert_eq!(self.memory_index, other.memory_index);
+        debug_assert!(self.range.end <= other.range.start);
+        self.range.end = other.range.end;
+    }
+}
+
 /// Find the initialized minimum page size of each memory, as well as all
 /// regions of non-zero memory.
 async fn snapshot_memories(
@@ -131,51 +125,52 @@
         .defined_memories()
         .zip(module.defined_memory_exports.as_ref().unwrap());
     for ((memory_index, ty), name) in iter {
-        let memory = Arc::new(instance.memory_contents(&name).await);
-        let page_size = 1 << ty.page_size_log2.unwrap_or(16);
-        let num_wasm_pages = memory.len() / page_size;
-        memory_mins.push(num_wasm_pages as u64);
-
-        let memory_data = &memory[..];
-
-        // Consider each Wasm page in parallel. Create data segments for each
-        // region of non-zero memory.
-        data_segments.par_extend((0..num_wasm_pages).into_par_iter().flat_map(|i| {
-            let page_end = (i + 1) * page_size;
-            let mut start = i * page_size;
-            let mut segments = vec![];
-            while start < page_end {
-                let nonzero = match memory_data[start..page_end]
-                    .iter()
-                    .position(|byte| *byte != 0)
-                {
-                    None => break,
-                    Some(i) => i,
-                };
-                start += nonzero;
-                let end = memory_data[start..page_end]
-                    .iter()
-                    .position(|byte| *byte == 0)
-                    .map_or(page_end, |zero| start + zero);
-                segments.push(DataSegment {
-                    memory_index,
-                    memory: memory.clone(),
-                    offset: u32::try_from(start).unwrap(),
-                    len: u32::try_from(end - start).unwrap(),
-                });
-                start = end;
-            }
-            segments
-        }));
+        instance
+            .memory_contents(&name, |memory| {
+                let page_size = 1 << ty.page_size_log2.unwrap_or(16);
+                let num_wasm_pages = memory.len() / page_size;
+                memory_mins.push(num_wasm_pages as u64);
+
+                let memory_data = &memory[..];
+
+                // Consider each Wasm page in parallel. Create data segments for each
+                // region of non-zero memory.
+                data_segments.par_extend((0..num_wasm_pages).into_par_iter().flat_map(|i| {
+                    let page_end = (i + 1) * page_size;
+                    let mut start = i * page_size;
+                    let mut segments = vec![];
+                    while start < page_end {
+                        let nonzero = match memory_data[start..page_end]
+                            .iter()
+                            .position(|byte| *byte != 0)
+                        {
+                            None => break,
+                            Some(i) => i,
+                        };
+                        start += nonzero;
+                        let end = memory_data[start..page_end]
+                            .iter()
+                            .position(|byte| *byte == 0)
+                            .map_or(page_end, |zero| start + zero);
+                        segments.push(DataSegmentRange {
+                            memory_index,
+                            range: start..end,
+                        });
+                        start = end;
+                    }
+                    segments
+                }));
+            })
+            .await;
     }
 
     if data_segments.is_empty() {
-        return (memory_mins, data_segments);
+        return (memory_mins, Vec::new());
     }
 
     // Sort data segments to enforce determinism in the face of the
    // parallelism above.
-    data_segments.sort_by_key(|s| (s.memory_index, s.offset));
+    data_segments.sort_by_key(|s| (s.memory_index, s.range.start));
 
     // Merge any contiguous segments (caused by spanning a Wasm page boundary,
     // and therefore created in separate logical threads above) or pages that
@@ -184,7 +179,7 @@
     // LEB, two for the memory offset init expression (one for the `i32.const`
     // opcode and another for the constant immediate LEB), and finally one for
     // the data length LEB).
-    const MIN_ACTIVE_SEGMENT_OVERHEAD: u32 = 4;
+    const MIN_ACTIVE_SEGMENT_OVERHEAD: usize = 4;
 
     let mut merged_data_segments = Vec::with_capacity(data_segments.len());
     merged_data_segments.push(data_segments[0].clone());
     for b in &data_segments[1..] {
@@ -206,19 +201,47 @@
 
         // Okay, merge them together into `a` (so that the next iteration can
         // merge it with its predecessor) and then omit `b`!
-        let merged = a.merge(b);
-        *a = merged;
+        a.merge(b);
     }
 
     remove_excess_segments(&mut merged_data_segments);
 
-    (memory_mins, merged_data_segments)
+    // With the final set of data segments in hand, extract the actual data of
+    // each memory, copying it into a `DataSegment`, to return the final list
+    // of segments.
+    //
+    // Here the memories are iterated over again and, in tandem, the
+    // `merged_data_segments` list is traversed to extract a `DataSegment` for
+    // each range that `merged_data_segments` indicates. This relies on
+    // `merged_data_segments` being a sorted list by `memory_index` at least.
+    let mut final_data_segments = Vec::with_capacity(merged_data_segments.len());
+    let mut merged = merged_data_segments.iter().peekable();
+    let iter = module
+        .defined_memories()
+        .zip(module.defined_memory_exports.as_ref().unwrap());
+    for ((memory_index, ty), name) in iter {
+        instance
+            .memory_contents(&name, |memory| {
+                while let Some(segment) = merged.next_if(|s| s.memory_index == memory_index) {
+                    final_data_segments.push(DataSegment {
+                        memory_index,
+                        data: memory[segment.range.clone()].to_vec(),
+                        offset: segment.range.start.try_into().unwrap(),
+                        is64: ty.memory64,
+                    });
+                }
+            })
+            .await;
+    }
+    assert!(merged.next().is_none());
+
+    (memory_mins, final_data_segments)
 }
 
 /// Engines apply a limit on how many segments a module may contain, and Wizer
 /// can run afoul of it. When that happens, we need to merge data segments
 /// together until our number of data segments fits within the limit.
-fn remove_excess_segments(merged_data_segments: &mut Vec<DataSegment>) {
+fn remove_excess_segments(merged_data_segments: &mut Vec<DataSegmentRange>) {
     if merged_data_segments.len() < MAX_DATA_SEGMENTS {
         return;
     }
@@ -243,7 +266,14 @@ fn remove_excess_segments(merged_data_segments: &mut Vec<DataSegmentRange>) {
         if w[0].memory_index != w[1].memory_index {
             continue;
         }
-        let gap = w[0].gap(&w[1]);
+        let gap = match u32::try_from(w[0].gap(&w[1])) {
+            Ok(gap) => gap,
+            // If the gap is larger than 4 GiB then don't consider these two
+            // data segments for merging and assume there are enough other
+            // data segments close enough together to still consider for
+            // merging to get under the limit.
+            Err(_) => continue,
+        };
         let index = u32::try_from(index).unwrap();
         smallest_gaps.push(GapIndex { gap, index });
     }
@@ -256,8 +286,10 @@ fn remove_excess_segments(merged_data_segments: &mut Vec<DataSegmentRange>) {
     smallest_gaps.sort_unstable_by(|a, b| a.index.cmp(&b.index).reverse());
     for GapIndex { index, .. } in smallest_gaps {
         let index = usize::try_from(index).unwrap();
-        let merged = merged_data_segments[index].merge(&merged_data_segments[index + 1]);
-        merged_data_segments[index] = merged;
+        let [a, b] = merged_data_segments
+            .get_disjoint_mut([index, index + 1])
+            .unwrap();
+        a.merge(b);
 
         // Okay to use `swap_remove` here because, even though it makes
         // `merged_data_segments` unsorted, the segments are still sorted within
@@ -269,5 +301,5 @@
 
     // Finally, sort the data segments again so that our output is
     // deterministic.
-    merged_data_segments.sort_by_key(|s| (s.memory_index, s.offset));
+    merged_data_segments.sort_by_key(|s| (s.memory_index, s.range.start));
 }
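The merge policy in `snapshot.rs` now operates on cheap `DataSegmentRange` values and only copies bytes once the final set of ranges is known. A standalone sketch of that policy; the exact threshold comparison lives in unchanged context lines not shown in this diff, so the `<=` below is illustrative:

```rust
use std::ops::Range;

// Mirrors MIN_ACTIVE_SEGMENT_OVERHEAD above: a separate active segment costs
// at least ~4 bytes of section overhead.
const MIN_ACTIVE_SEGMENT_OVERHEAD: usize = 4;

// Merge `b` into `a` when emitting the gap's zero bytes inline is cheaper
// than paying a second segment's encoding overhead. Ranges must be sorted
// and non-overlapping, matching `DataSegmentRange`'s invariants.
fn maybe_merge(a: &mut Range<usize>, b: &Range<usize>) -> bool {
    assert!(a.end <= b.start);
    if b.start - a.end <= MIN_ACTIVE_SEGMENT_OVERHEAD {
        a.end = b.end;
        true
    } else {
        false
    }
}

fn main() {
    let mut a = 0..10;
    assert!(maybe_merge(&mut a, &(12..20))); // gap of 2 zero bytes: merge
    assert_eq!(a, 0..20);
    assert!(!maybe_merge(&mut a, &(100..110))); // gap of 80: keep separate
}
```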
diff --git a/crates/wizer/src/wasmtime.rs b/crates/wizer/src/wasmtime.rs
index 9d1f29fac282..289df6230469 100644
--- a/crates/wizer/src/wasmtime.rs
+++ b/crates/wizer/src/wasmtime.rs
@@ -105,8 +105,8 @@ impl<T: Send> InstanceState for WasmtimeWizer<'_, T> {
         }
     }
 
-    async fn memory_contents(&mut self, name: &str) -> Vec<u8> {
+    async fn memory_contents(&mut self, name: &str, contents: impl FnOnce(&[u8]) + Send) {
         let memory = self.instance.get_memory(&mut *self.store, name).unwrap();
-        memory.data(&self.store).to_vec()
+        contents(memory.data(&self.store))
     }
 }
diff --git a/crates/wizer/tests/all/tests.rs b/crates/wizer/tests/all/tests.rs
index 895b9e694516..46fed8a0f3d8 100644
--- a/crates/wizer/tests/all/tests.rs
+++ b/crates/wizer/tests/all/tests.rs
@@ -992,3 +992,28 @@ async fn memory_init_and_data_segments() -> Result<()> {
     let wizer = get_wizer();
     wizen_and_run_wasm(&[], 0x02010403 + 0x06050201, &wasm, wizer).await
 }
+
+#[tokio::test]
+async fn memory64() -> Result<()> {
+    let _ = env_logger::try_init();
+    let wasm = wat_to_wasm(
+        r#"
+(module
+  (memory i64 1)
+
+  (func (export "wizer-initialize")
+    i64.const 0
+    i32.const 10
+    i32.store
+  )
+
+  (func (export "run") (result i32)
+    i64.const 0
+    i32.load
+  )
+)
+        "#,
+    )?;
+    let wizer = get_wizer();
+    wizen_and_run_wasm(&[], 10, &wasm, wizer).await
+}
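The new `memory64` test exercises a small in-bounds store; offsets that only fit in an `i64` (and the `u32::try_from` gap bail-out above) would require a memory larger than 4 GiB. A hypothetical follow-up test in the same style, not part of this PR and expensive to run since it instantiates a >4 GiB memory; `wat_to_wasm`, `get_wizer`, and `wizen_and_run_wasm` are the existing test helpers:

```rust
#[tokio::test]
async fn memory64_above_4gib() -> Result<()> {
    let _ = env_logger::try_init();
    // 65537 64-KiB pages = 4 GiB + 64 KiB, so address 2^32 is in bounds and
    // the resulting data segment offset must be encoded as an `i64.const`.
    let wasm = wat_to_wasm(
        r#"
(module
  (memory i64 65537)

  (func (export "wizer-initialize")
    i64.const 0x1_0000_0000
    i32.const 10
    i32.store
  )

  (func (export "run") (result i32)
    i64.const 0x1_0000_0000
    i32.load
  )
)
        "#,
    )?;
    let wizer = get_wizer();
    wizen_and_run_wasm(&[], 10, &wasm, wizer).await
}
```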