Skip to content
Open
69 changes: 65 additions & 4 deletions crates/compiler/codegen/src/builder.rs
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ pub struct CasmBuilder {
/// Labels that need to be resolved
pub(super) labels: Vec<Label>,
/// Current function layout for offset lookups
pub(crate) layout: FunctionLayout,
pub layout: FunctionLayout,
/// Counter for generating unique labels
pub(super) label_counter: usize,
/// Highest fp+ offset that has been written to (for optimization tracking)
Expand Down Expand Up @@ -96,7 +96,7 @@ impl CasmBuilder {
/// Generate type-aware assignment instruction
///
/// Handles assignments for all types including aggregates (structs, tuples)
pub(crate) fn assign(
pub fn assign(
&mut self,
dest: ValueId,
source: Value,
Expand Down Expand Up @@ -194,7 +194,7 @@ impl CasmBuilder {
}

/// Generate unary operation instruction
pub(crate) fn unary_op(
pub fn unary_op(
&mut self,
op: UnaryOp,
dest: ValueId,
Expand Down Expand Up @@ -228,11 +228,71 @@ impl CasmBuilder {
Ok(())
}

/// Load a value from memory through a pointer
///
/// Loads slots from [[ptr_base] + 0..size] into dest. Size is inferred from dest's layout.
pub fn load_from_memory(
&mut self,
dest: ValueId,
ptr_base: ValueId,
size: usize,
) -> CodegenResult<()> {
let base_off = self.layout.get_offset(ptr_base)?;
let dest_off = self.layout.allocate_local(dest, size)?;

for slot in 0..size {
self.store_from_double_deref_fp_imm(
base_off,
slot as i32,
dest_off + slot as i32,
format!(
"[fp + {}] = [[fp + {}] + {}] (load limb {}/{})",
dest_off + slot as i32,
base_off,
slot,
slot + 1,
size
),
);
}

Ok(())
}

/// Store a value to memory through a pointer
///
/// Stores slots to [[ptr_base] + 0..size] from value. Size is inferred from value's layout.
pub fn store_to_memory(&mut self, ptr_base: ValueId, value: ValueId) -> CodegenResult<()> {
let base_off = self.layout.get_offset(ptr_base)?;
let value_off = self.layout.get_offset(value)?;

// Get size from value's layout
let size = self.layout.get_value_size(value);

for slot in 0..size {
self.store_to_double_deref_fp_imm(
value_off + slot as i32,
base_off,
slot as i32,
format!(
"[[fp + {}] + {}] = [fp + {}] (store limb {}/{})",
base_off,
slot,
value_off + slot as i32,
slot + 1,
size
),
);
}

Ok(())
}

/// Generate a binary operation instruction
///
/// If target_offset is provided, writes directly to that location.
/// Otherwise, allocates a new local variable.
pub(crate) fn binary_op(
pub fn binary_op(
&mut self,
op: BinaryOp,
dest: ValueId,
Expand Down Expand Up @@ -332,6 +392,7 @@ impl CasmBuilder {
&mut self.instructions
}

#[cfg(test)]
/// Get the labels
pub(crate) fn labels(&self) -> &[Label] {
&self.labels
Expand Down
2 changes: 1 addition & 1 deletion crates/compiler/codegen/src/builder/calls.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ use stwo_prover::core::fields::m31::M31;

impl super::CasmBuilder {
/// Shared lowering for all call flavors (void, single, multiple).
pub(crate) fn lower_call(
pub fn lower_call(
&mut self,
callee_name: &str,
args: &[Value],
Expand Down
24 changes: 7 additions & 17 deletions crates/compiler/codegen/src/builder/ctrlflow.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ use stwo_prover::core::fields::m31::M31;
impl super::CasmBuilder {
/// Generates an unconditional jump to a label.
/// Uses a relative jump (offset resolved in a later pass).
pub(crate) fn jump(&mut self, target_label: &str) {
pub fn jump(&mut self, target_label: &str) {
let instr = InstructionBuilder::new(
CasmInstr::JmpRelImm {
offset: M31::from(0),
Expand All @@ -22,7 +22,7 @@ impl super::CasmBuilder {

/// Generates a conditional jump instruction that triggers if the value at `cond_off` is non-zero.
/// The `target_label` is a placeholder that will be resolved to a relative offset later.
pub(crate) fn jnz(&mut self, condition: Value, target_label: &str) -> CodegenResult<()> {
pub fn jnz(&mut self, condition: Value, target_label: &str) -> CodegenResult<()> {
// Get the condition value offset
let cond_off = match condition {
Value::Operand(cond_id) => self.layout.get_offset(cond_id)?,
Expand All @@ -44,7 +44,7 @@ impl super::CasmBuilder {
}

/// Generates a conditional jump based on a direct fp-relative offset.
pub(crate) fn jnz_offset(&mut self, cond_off: i32, target_label: &str) {
pub fn jnz_offset(&mut self, cond_off: i32, target_label: &str) {
let instr = InstructionBuilder::new(
CasmInstr::JnzFpImm {
cond_off: M31::from(cond_off),
Expand All @@ -58,12 +58,7 @@ impl super::CasmBuilder {
}

/// Short-circuit OR: dest = (left != 0) || (right != 0)
pub(super) fn sc_or(
&mut self,
dest_off: i32,
left: &Value,
right: &Value,
) -> CodegenResult<()> {
pub fn sc_or(&mut self, dest_off: i32, left: &Value, right: &Value) -> CodegenResult<()> {
// Initialize result to 0
self.store_immediate(0, dest_off, "Initialize OR result to 0".to_string());
let set_true = self.emit_new_label_name("or_true");
Expand All @@ -82,12 +77,7 @@ impl super::CasmBuilder {
}

/// Short-circuit AND: dest = (left != 0) && (right != 0)
pub(super) fn sc_and(
&mut self,
dest_off: i32,
left: &Value,
right: &Value,
) -> CodegenResult<()> {
pub fn sc_and(&mut self, dest_off: i32, left: &Value, right: &Value) -> CodegenResult<()> {
// Default to 0
self.store_immediate(0, dest_off, format!("[fp + {dest_off}] = 0"));
let check_right = self.emit_new_label_name("and_check_right");
Expand Down Expand Up @@ -132,7 +122,7 @@ impl super::CasmBuilder {
}

/// NOT: dest = ([source] == 0)
pub(super) fn sc_not(&mut self, dest_off: i32, source: &Value) -> CodegenResult<()> {
pub fn sc_not(&mut self, dest_off: i32, source: &Value) -> CodegenResult<()> {
let set_zero = self.emit_new_label_name("not_zero");
let end = self.emit_new_label_name("not_end");
match source {
Expand Down Expand Up @@ -170,7 +160,7 @@ impl super::CasmBuilder {
/// - When `value` is an operand, emits a JNZ.
/// - When `value` is a literal and truthy, emits an unconditional JMP if `emit_jmp_if_const_true` is true.
/// - Returns Some(true/false) when `value` is a constant; None when dynamic.
fn branch_if_nonzero_to(
pub fn branch_if_nonzero_to(
&mut self,
value: &Value,
label: &str,
Expand Down
4 changes: 2 additions & 2 deletions crates/compiler/codegen/src/builder/emit.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,14 +14,14 @@ impl super::CasmBuilder {
}

/// Generate a fresh label name using the builder's counter.
pub(crate) fn emit_new_label_name(&mut self, prefix: &str) -> String {
pub fn emit_new_label_name(&mut self, prefix: &str) -> String {
let label_id = self.label_counter;
self.label_counter += 1;
format!("{}_{}", prefix, label_id)
}

/// Add a label at the current instruction address.
pub(crate) fn emit_add_label(&mut self, mut label: Label) {
pub fn emit_add_label(&mut self, mut label: Label) {
label.address = Some(self.instructions.len());
self.labels.push(label);
}
Expand Down
80 changes: 50 additions & 30 deletions crates/compiler/codegen/src/generator.rs
Original file line number Diff line number Diff line change
Expand Up @@ -209,7 +209,7 @@ impl CodeGenerator {
}

/// Compile the generated code into a CompiledProgram.
pub(crate) fn compile(self) -> CodegenResult<Program> {
pub fn compile(self) -> CodegenResult<Program> {
let instructions: Vec<cairo_m_common::Instruction> = self
.instructions
.iter()
Expand Down Expand Up @@ -256,7 +256,7 @@ impl CodeGenerator {
}

/// Calculate memory layout for variable-sized instructions
fn calculate_memory_layout(&mut self) -> CodegenResult<()> {
pub fn calculate_memory_layout(&mut self) -> CodegenResult<()> {
self.memory_layout.clear();
let mut current_mem_pc = 0u32;

Expand Down Expand Up @@ -344,6 +344,51 @@ impl CodeGenerator {
Ok(out)
}

/// Add a pre-built function from a CasmBuilder to the program
///
/// This allows external code to build functions using CasmBuilder and add them
/// to the program without exposing internal fields. Used by WASM lowering.
pub fn add_function_from_builder(
&mut self,
mut builder: CasmBuilder,
params: Vec<AbiSlot>,
returns: Vec<AbiSlot>,
) -> CodegenResult<()> {
let name = builder.layout.name.clone();
// Store layout
self.function_layouts
.insert(name.clone(), builder.layout.clone());

// Update entrypoint with current instruction offset
let info = EntrypointInfo {
pc: self.instructions.len(),
params,
returns,
};
self.function_entrypoints.insert(name, info);

// Update label counter to avoid collisions
self.label_counter += builder.label_counter();

// Run post-builder passes
passes::run_all(&mut builder)?;

// Fix label addresses to be relative to global instruction stream
let instruction_offset = self.instructions.len();
let mut corrected_labels = builder.labels;
for label in &mut corrected_labels {
if let Some(local_addr) = label.address {
label.address = Some(local_addr + instruction_offset);
}
}

// Append generated instructions and corrected labels
self.instructions.extend(builder.instructions);
self.labels.extend(corrected_labels);

Ok(())
}

/// Generate code for all functions
fn generate_all_functions(&mut self, module: &MirModule) -> CodegenResult<()> {
for (_, function) in module.functions() {
Expand Down Expand Up @@ -399,36 +444,11 @@ impl CodeGenerator {
})
.collect::<CodegenResult<_>>()?;

let entrypoint_info = EntrypointInfo {
pc: self.instructions.len(),
params,
returns,
};
self.function_entrypoints
.insert(function.name.clone(), entrypoint_info);

builder.emit_add_label(func_label);

self.generate_basic_blocks(function, module, &mut builder)?;

self.label_counter += builder.label_counter();

// Run post-builder passes (deduplication, peephole opts, etc.)
passes::run_all(&mut builder)?;

// Fix label addresses to be relative to the global instruction stream
let instruction_offset = self.instructions.len();
let mut corrected_labels = builder.labels().to_vec();
for label in &mut corrected_labels {
if let Some(local_addr) = label.address {
label.address = Some(local_addr + instruction_offset);
}
}

// Append generated instructions and corrected labels
self.instructions
.extend(builder.instructions().iter().cloned());
self.labels.extend(corrected_labels);
// Use common logic to append function
self.add_function_from_builder(builder, params, returns)?;

Ok(())
}
Expand Down Expand Up @@ -1557,7 +1577,7 @@ impl CodeGenerator {
}

/// Resolve all label references (second pass)
fn resolve_labels(&mut self) -> CodegenResult<()> {
pub fn resolve_labels(&mut self) -> CodegenResult<()> {
// Build a map of label names to their physical addresses
let mut label_map = HashMap::new();

Expand Down
10 changes: 5 additions & 5 deletions crates/compiler/codegen/src/layout.rs
Original file line number Diff line number Diff line change
Expand Up @@ -50,17 +50,17 @@ pub enum ValueLayout {
/// Maps every ValueId in a function to its fp-relative memory offset.
#[derive(Debug, Clone)]
pub struct FunctionLayout {
name: String,
pub name: String,
/// Maps ValueId to its memory layout.
pub value_layouts: FxHashMap<ValueId, ValueLayout>,
/// The total frame size needed for this function.
pub frame_size: usize,
/// Number of parameters this function takes.
pub num_parameters: usize,
/// Number of values this function returns.
num_return_values: usize,
pub num_return_values: usize,
/// Total number of slots required for return values (accounting for multi-slot types).
num_return_slots: usize,
pub num_return_slots: usize,
}

impl FunctionLayout {
Expand Down Expand Up @@ -235,7 +235,7 @@ impl FunctionLayout {
}

/// Allocates a new local variable at the next available positive offset from `fp`.
pub(crate) fn allocate_local(&mut self, value_id: ValueId, size: usize) -> CodegenResult<i32> {
pub fn allocate_local(&mut self, value_id: ValueId, size: usize) -> CodegenResult<i32> {
// If this value is already allocated, return its offset.
if let Some(layout) = self.value_layouts.get(&value_id) {
return match layout {
Expand Down Expand Up @@ -289,7 +289,7 @@ impl FunctionLayout {
}

/// Gets the fp-relative offset for a `ValueId`.
pub(crate) fn get_offset(&self, value_id: ValueId) -> CodegenResult<i32> {
pub fn get_offset(&self, value_id: ValueId) -> CodegenResult<i32> {
match self.value_layouts.get(&value_id) {
Some(ValueLayout::Slot { offset }) | Some(ValueLayout::MultiSlot { offset, .. }) => {
Ok(*offset)
Expand Down
2 changes: 1 addition & 1 deletion crates/wasm/src/loader.rs
Original file line number Diff line number Diff line change
Expand Up @@ -115,7 +115,7 @@ mod tests {
// Test basic loading functionality
let wasm_bytes = parse_file("tests/test_cases/i32_arithmetic.wat").unwrap();
let result = BlocklessDagModule::from_bytes(&wasm_bytes);
assert!(result.is_ok(), "Should load add.wasm successfully");
assert!(result.is_ok(), "Should load wat file successfully");

let module = result.unwrap();
assert!(!module.0.functions.is_empty());
Expand Down
Loading
Loading