diff --git a/crates/interpreter/Cargo.toml b/crates/interpreter/Cargo.toml index 7305c3d2d..07e4874ad 100644 --- a/crates/interpreter/Cargo.toml +++ b/crates/interpreter/Cargo.toml @@ -33,6 +33,8 @@ float-types = ["dep:libm"] vector-types = [] # Enable caching for execution. cache = ["dep:lru"] +# Enable interrupting execution. +interrupt = [] [lints] clippy.unit-arg = "allow" diff --git a/crates/interpreter/src/exec.rs b/crates/interpreter/src/exec.rs index e0c0a4fec..00e6177ac 100644 --- a/crates/interpreter/src/exec.rs +++ b/crates/interpreter/src/exec.rs @@ -15,6 +15,11 @@ // TODO: Some toctou could be used instead of panic. use alloc::vec; use alloc::vec::Vec; +#[cfg(feature = "interrupt")] +use core::sync::atomic::Ordering::Relaxed; + +#[cfg(feature = "interrupt")] +use portable_atomic::AtomicBool; use crate::error::*; use crate::module::*; @@ -58,6 +63,8 @@ pub struct Store<'m> { // functions in `funcs` is stored to limit normal linking to that part. func_default: Option<(&'m str, usize)>, threads: Vec>, + #[cfg(feature = "interrupt")] + interrupt: Option<&'m AtomicBool>, } #[derive(Debug, PartialEq, Eq, PartialOrd, Ord)] @@ -99,6 +106,8 @@ impl Default for Store<'_> { funcs: vec![], func_default: None, threads: vec![], + #[cfg(feature = "interrupt")] + interrupt: None, } } } @@ -195,13 +204,18 @@ impl<'m> Store<'m> { let mut parser = self.insts[inst_id].module.func(ptr.index()); let mut locals = Vec::new(); append_locals(&mut parser, &mut locals); - let thread = Thread::new(parser, Frame::new(inst_id, 0, &[], locals)); + let thread = Thread::new( + parser, + Frame::new(inst_id, 0, &[], locals), + #[cfg(feature = "interrupt")] + None, + ); + let result = thread.run(self)?; - assert!(matches!(result, RunResult::Done(x) if x.is_empty())); + assert!(matches!(result, RunResult::Done(x) if x.is_empty())) } Ok(InstId { store_id: self.id, inst_id }) } - /// Invokes a function in an instance provided its name. /// /// If a function was already running, it will resume once the function being called terminates. @@ -225,7 +239,13 @@ impl<'m> Store<'m> { let mut locals = args; append_locals(&mut parser, &mut locals); let frame = Frame::new(inst_id, t.results.len(), &[], locals); - Thread::new(parser, frame).run(self) + Thread::new( + parser, + frame, + #[cfg(feature = "interrupt")] + self.interrupt, + ) + .run(self) } /// Returns the value of a global of an instance. @@ -303,6 +323,11 @@ impl<'m> Store<'m> { Some(Call { store: self }) } } + + #[cfg(feature = "interrupt")] + pub fn set_interrupt(&mut self, interrupt: Option<&'m AtomicBool>) { + self.interrupt = interrupt; + } } impl<'a, 'm> Call<'a, 'm> { @@ -339,6 +364,12 @@ impl<'a, 'm> Call<'a, 'm> { thread.run(self.store) } + // Returns if this call is due to an interrupt. + #[cfg(feature = "interrupt")] + pub fn is_interrupt(&self) -> bool { + self.cont().interrupted + } + fn cont(&self) -> &Continuation { self.store.threads.last().unwrap() } @@ -460,6 +491,8 @@ struct Instance<'m> { struct Thread<'m> { parser: Parser<'m>, frames: Vec>, + #[cfg(feature = "interrupt")] + interrupt: Option<&'m AtomicBool>, } /// Runtime result. @@ -470,6 +503,10 @@ pub enum RunResult<'a, 'm> { /// Execution is calling into the host. Host(Call<'a, 'm>), + + /// Execution was interrupted by the host. + #[cfg(feature = "interrupt")] + Interrupt(Call<'a, 'm>), } /// Runtime result without host call information. @@ -484,6 +521,8 @@ impl RunResult<'_, '_> { match self { RunResult::Done(result) => RunAnswer::Done(result), RunResult::Host(_) => RunAnswer::Host, + #[cfg(feature = "interrupt")] + RunResult::Interrupt(_) => RunAnswer::Host, } } } @@ -494,6 +533,8 @@ struct Continuation<'m> { index: usize, args: Vec, arity: usize, + #[cfg(feature = "interrupt")] + interrupted: bool, } impl<'m> Store<'m> { @@ -724,21 +765,40 @@ enum ThreadResult<'m> { Continue(Thread<'m>), Done(Vec), Host, + #[cfg(feature = "interrupt")] + Interrupt, } impl<'m> Thread<'m> { - fn new(parser: Parser<'m>, frame: Frame<'m>) -> Thread<'m> { - Thread { parser, frames: vec![frame] } + fn new( + parser: Parser<'m>, frame: Frame<'m>, + #[cfg(feature = "interrupt")] interrupt: Option<&'m AtomicBool>, + ) -> Thread<'m> { + Thread { + parser, + frames: vec![frame], + + #[cfg(feature = "interrupt")] + interrupt, + } } fn const_expr(store: &mut Store<'m>, inst_id: usize, mut_parser: &mut Parser<'m>) -> Val { let parser = mut_parser.clone(); - let mut thread = Thread::new(parser, Frame::new(inst_id, 1, &[], Vec::new())); + let mut thread = Thread::new( + parser, + Frame::new(inst_id, 1, &[], Vec::new()), + #[cfg(feature = "interrupt")] + None, + ); + let (parser, results) = loop { let p = thread.parser.save(); match thread.step(store).unwrap() { ThreadResult::Continue(x) => thread = x, ThreadResult::Done(x) => break (p, x), + #[cfg(feature = "interrupt")] + ThreadResult::Interrupt => unreachable!(), ThreadResult::Host => unreachable!(), } }; @@ -757,6 +817,8 @@ impl<'m> Thread<'m> { ThreadResult::Continue(x) => self = x, ThreadResult::Done(x) => return Ok(RunResult::Done(x)), ThreadResult::Host => return Ok(RunResult::Host(Call { store })), + #[cfg(feature = "interrupt")] + ThreadResult::Interrupt => return Ok(RunResult::Interrupt(Call { store })), } } } @@ -765,7 +827,7 @@ impl<'m> Thread<'m> { use Instr::*; let saved = self.parser.save(); let inst_id = self.frame().inst_id; - let inst = &mut store.insts[inst_id]; + let inst: &mut Instance<'m> = &mut store.insts[inst_id]; match self.parser.parse_instr().into_ok() { Unreachable => return Err(trap()), Nop => (), @@ -783,15 +845,15 @@ impl<'m> Thread<'m> { return Ok(self.exit_label()); } End => return Ok(self.exit_label()), - Br(l) => return Ok(self.pop_label(inst, l)), + Br(l) => return self.pop_label(inst, l, &mut store.threads), BrIf(l) => { if self.pop_value().unwrap_i32() != 0 { - return Ok(self.pop_label(inst, l)); + return self.pop_label(inst, l, &mut store.threads); } } BrTable(ls, ln) => { let i = self.pop_value().unwrap_i32() as usize; - return Ok(self.pop_label(inst, ls.get(i).cloned().unwrap_or(ln))); + return self.pop_label(inst, ls.get(i).cloned().unwrap_or(ln), &mut store.threads); } Return => return Ok(self.exit_frame()), Call(x) => return self.invoke(store, store.func_ptr(inst_id, x)), @@ -1035,20 +1097,48 @@ impl<'m> Thread<'m> { self.labels().push(label); } - fn pop_label(mut self, inst: &mut Instance<'m>, l: LabelIdx) -> ThreadResult<'m> { + #[allow(clippy::ptr_arg)] + fn unbounded_continue(self, _threads: &mut Vec>) -> ThreadResult<'m> { + #[cfg(feature = "interrupt")] + if self.interrupt.is_some_and(|interrupt| interrupt.swap(false, Relaxed)) { + _threads.push(Continuation { + thread: self, + index: 0, + args: vec![], + arity: 0, + #[cfg(feature = "interrupt")] + interrupted: true, + }); + return ThreadResult::Interrupt; + } + + ThreadResult::Continue(self) + } + + fn pop_label( + mut self, inst: &mut Instance<'m>, l: LabelIdx, threads: &mut Vec>, + ) -> Result, Error> { let i = self.labels().len() - l as usize - 1; if i == 0 { - return self.exit_frame(); + return Ok(self.exit_frame()); } let values = core::mem::take(self.values()); let frame = self.frame(); let Label { arity, kind, .. } = frame.labels.drain(i ..).next().unwrap(); self.values().extend_from_slice(&values[values.len() - arity ..]); + match kind { - LabelKind::Loop(pos) => unsafe { self.parser.restore(pos) }, - LabelKind::Block | LabelKind::If => self.skip_to_end(inst, l), + LabelKind::Loop(pos) => { + unsafe { + self.parser.restore(pos); + } + Ok(self.unbounded_continue(threads)) + } + LabelKind::Block | LabelKind::If => { + self.skip_to_end(inst, l); + Ok(ThreadResult::Continue(self)) + } } - ThreadResult::Continue(self) } fn exit_label(mut self) -> ThreadResult<'m> { @@ -1355,7 +1445,14 @@ impl<'m> Thread<'m> { let t = store.funcs[index].1; let arity = t.results.len(); let args = self.pop_values(t.params.len()); - store.threads.push(Continuation { thread: self, arity, index, args }); + store.threads.push(Continuation { + thread: self, + arity, + index, + args, + #[cfg(feature = "interrupt")] + interrupted: false, + }); return Ok(ThreadResult::Host); } Side::Wasm(x) => x, @@ -1366,7 +1463,7 @@ impl<'m> Thread<'m> { let ret = self.parser.save(); self.parser = parser; self.frames.push(Frame::new(inst_id, t.results.len(), ret, locals)); - Ok(ThreadResult::Continue(self)) + Ok(self.unbounded_continue(&mut store.threads)) } } diff --git a/crates/interpreter/test.sh b/crates/interpreter/test.sh index de8007957..17da2a0ac 100755 --- a/crates/interpreter/test.sh +++ b/crates/interpreter/test.sh @@ -40,3 +40,5 @@ RUSTFLAGS=--cfg=portable_atomic_unsafe_assume_single_core \ cargo check --example=hello # Run with `-- --test-threads=1 --nocapture` to see unsupported tests. cargo test --test=spec --features=debug,toctou,float-types,vector-types +cargo test --test=spec --features=debug,toctou,float-types,vector-types,interrupt +cargo test --test=interrupt --all-features \ No newline at end of file diff --git a/crates/interpreter/tests/interrupt.rs b/crates/interpreter/tests/interrupt.rs new file mode 100644 index 000000000..89e33c7c0 --- /dev/null +++ b/crates/interpreter/tests/interrupt.rs @@ -0,0 +1,87 @@ +#![allow(unused_crate_dependencies)] +use core::time; +use std::sync::atomic::Ordering::Relaxed; +use std::thread; + +use portable_atomic::AtomicBool; +use wasefire_interpreter::*; + +#[test] +fn test_interrupt() { + let mut n_interrupts = 0; + let mut n_loops = 0; + let interrupt = AtomicBool::new(false); + + std::thread::scope(|s: &std::thread::Scope<'_, '_>| { + // Create an empty store. + let mut store = Store::default(); + + store.link_func("env", "count", 0, 1).unwrap(); + + // ;; Use `wat2wasm infinite_loop.wat` to regenerate `.wasm`. + // (module + // (import "env" "count" (func $count (result i32))) + + // (memory (export "memory") 1) + // (func (export "loopforever") + // (local i32 i32) + // (loop + // (local.set 0 (call $count)) + // (local.set 1 (i32.const 1)) + // (block + // (loop + // (br_if 1 (i32.gt_u (local.get 1) (local.get 0))) + // (local.set 1 (i32.add (local.get 1) (i32.const 1))) + // (br 0) + // ) + // ) + // (br 0) + // ) + // ) + // ) + + const WASM: &[u8] = &[ + 0x00, 0x61, 0x73, 0x6d, 0x01, 0x00, 0x00, 0x00, 0x01, 0x08, 0x02, 0x60, 0x00, 0x01, + 0x7f, 0x60, 0x00, 0x00, 0x02, 0x0d, 0x01, 0x03, 0x65, 0x6e, 0x76, 0x05, 0x63, 0x6f, + 0x75, 0x6e, 0x74, 0x00, 0x00, 0x03, 0x02, 0x01, 0x01, 0x05, 0x03, 0x01, 0x00, 0x01, + 0x07, 0x18, 0x02, 0x06, 0x6d, 0x65, 0x6d, 0x6f, 0x72, 0x79, 0x02, 0x00, 0x0b, 0x6c, + 0x6f, 0x6f, 0x70, 0x66, 0x6f, 0x72, 0x65, 0x76, 0x65, 0x72, 0x00, 0x01, 0x0a, 0x29, + 0x01, 0x27, 0x01, 0x02, 0x7f, 0x03, 0x40, 0x10, 0x00, 0x21, 0x00, 0x41, 0x01, 0x21, + 0x01, 0x02, 0x40, 0x03, 0x40, 0x20, 0x01, 0x20, 0x00, 0x4b, 0x0d, 0x01, 0x20, 0x01, + 0x41, 0x01, 0x6a, 0x21, 0x01, 0x0c, 0x00, 0x0b, 0x0b, 0x0c, 0x00, 0x0b, 0x0b, + ]; + let module = Module::new(WASM).unwrap(); + let mut memory = [0; 16]; + + // Instantiate the module in the store. + let inst = store.instantiate(module, &mut memory).unwrap(); + + store.set_interrupt(Some(&interrupt)); + let mut result = store.invoke(inst, "loopforever", vec![]).unwrap(); + + // Let the outer infinite loop do 10 iterations. + while n_loops <= 10 { + let call = match result { + RunResult::Host(call) => call, + RunResult::Interrupt(call) => call, + RunResult::Done(_) => unreachable!(), + }; + + if call.is_interrupt() { + n_interrupts += 1; + result = call.resume(&[]).unwrap(); + } else { + // This is the count() function called in the loop header. + assert!(call.index() == 0); + n_loops += 1; + // Interrupt. + s.spawn(|| { + thread::sleep(time::Duration::from_millis(1)); + interrupt.store(true, Relaxed); + }); + result = call.resume(&[Val::I32(1000)]).unwrap(); + } + } + }); + assert!(n_interrupts > 9); +} diff --git a/crates/interpreter/tests/spec.rs b/crates/interpreter/tests/spec.rs index b5ba42510..2dd26fde6 100644 --- a/crates/interpreter/tests/spec.rs +++ b/crates/interpreter/tests/spec.rs @@ -183,6 +183,8 @@ impl<'m> Env<'m> { Ok(match self.store.invoke(inst_id, name, args)? { RunResult::Done(x) => x, RunResult::Host { .. } => unreachable!(), + #[cfg(feature = "interrupt")] + RunResult::Interrupt { .. } => unreachable!(), }) } diff --git a/scripts/log.sh b/scripts/log.sh index 25c5f7ab4..d2376669f 100644 --- a/scripts/log.sh +++ b/scripts/log.sh @@ -20,6 +20,8 @@ t() { _log '1;33' Todo "$*"; } d() { _log '1;32' Done "$*"; exit 0; } e() { _log '1;31' Error "$*"; exit 1; } +export LC_COLLATE=C + # We put the escape character in a variable because bash doesn't interpret escaped characters and # some scripts use bash instead of sh. _LOG=$(printf '\e')