diff --git a/crates/rustc_codegen_spirv/src/builder/builder_methods.rs b/crates/rustc_codegen_spirv/src/builder/builder_methods.rs index 1fb35929c5..da763b2cf7 100644 --- a/crates/rustc_codegen_spirv/src/builder/builder_methods.rs +++ b/crates/rustc_codegen_spirv/src/builder/builder_methods.rs @@ -3,7 +3,9 @@ use crate::maybe_pqp_cg_ssa as rustc_codegen_ssa; use super::Builder; use crate::abi::ConvSpirvType; -use crate::builder_spirv::{BuilderCursor, SpirvConst, SpirvValue, SpirvValueExt, SpirvValueKind}; +use crate::builder_spirv::{ + SpirvBlockCursor, SpirvConst, SpirvValue, SpirvValueExt, SpirvValueKind, +}; use crate::codegen_cx::CodegenCx; use crate::custom_insts::{CustomInst, CustomOp}; use crate::spirv_type::SpirvType; @@ -127,7 +129,7 @@ fn memset_fill_u64(b: u8) -> u64 { } fn memset_dynamic_scalar( - builder: &Builder<'_, '_>, + builder: &mut Builder<'_, '_>, fill_var: Word, byte_width: usize, is_float: bool, @@ -154,7 +156,7 @@ fn memset_dynamic_scalar( impl<'a, 'tcx> Builder<'a, 'tcx> { #[instrument(level = "trace", skip(self))] - fn ordering_to_semantics_def(&self, ordering: AtomicOrdering) -> SpirvValue { + fn ordering_to_semantics_def(&mut self, ordering: AtomicOrdering) -> SpirvValue { let mut invalid_seq_cst = false; let semantics = match ordering { AtomicOrdering::Relaxed => MemorySemantics::NONE, @@ -166,8 +168,8 @@ impl<'a, 'tcx> Builder<'a, 'tcx> { | MemorySemantics::ACQUIRE_RELEASE } AtomicOrdering::SeqCst => { - let emit = self.emit(); - let memory_model = emit.module_ref().memory_model.as_ref().unwrap(); + let builder = self.emit(); + let memory_model = builder.module_ref().memory_model.as_ref().unwrap(); if memory_model.operands[1].unwrap_memory_model() == MemoryModel::Vulkan { invalid_seq_cst = true; } @@ -263,7 +265,7 @@ impl<'a, 'tcx> Builder<'a, 'tcx> { } #[instrument(level = "trace", skip(self))] - fn memset_dynamic_pattern(&self, ty: &SpirvType<'tcx>, fill_var: Word) -> Word { + fn memset_dynamic_pattern(&mut self, ty: &SpirvType<'tcx>, fill_var: Word) -> Word { match *ty { SpirvType::Void => self.fatal("memset invalid on void pattern"), SpirvType::Bool => self.fatal("memset invalid on bool pattern"), @@ -750,9 +752,9 @@ impl<'a, 'tcx> Builder<'a, 'tcx> { let ptr_id = ptr.def(self); let maybe_original_access_chain = { - let emit = self.emit(); - let module = emit.module_ref(); - let current_func_blocks = emit + let builder = self.emit(); + let module = builder.module_ref(); + let current_func_blocks = builder .selected_function() .and_then(|func_idx| Some(&module.functions.get(func_idx)?.blocks[..])) .unwrap_or_default(); @@ -823,20 +825,20 @@ impl<'a, 'tcx> Builder<'a, 'tcx> { ) )] fn emit_access_chain( - &self, + &mut self, result_type: ::Type, pointer: Word, ptr_base_index: Option, indices: Vec, is_inbounds: bool, ) -> SpirvValue { - let mut emit = self.emit(); + let mut builder = self.emit(); let non_zero_ptr_base_index = ptr_base_index.filter(|&idx| self.builder.lookup_const_scalar(idx) != Some(0)); if let Some(ptr_base_index) = non_zero_ptr_base_index { let result = if is_inbounds { - emit.in_bounds_ptr_access_chain( + builder.in_bounds_ptr_access_chain( result_type, None, pointer, @@ -844,7 +846,7 @@ impl<'a, 'tcx> Builder<'a, 'tcx> { indices, ) } else { - emit.ptr_access_chain( + builder.ptr_access_chain( result_type, None, pointer, @@ -857,9 +859,9 @@ impl<'a, 'tcx> Builder<'a, 'tcx> { result } else { if is_inbounds { - emit.in_bounds_access_chain(result_type, None, pointer, indices) + builder.in_bounds_access_chain(result_type, None, pointer, indices) } else { - emit.access_chain(result_type, None, pointer, indices) + builder.access_chain(result_type, None, pointer, indices) } .unwrap() } @@ -1091,21 +1093,9 @@ impl<'a, 'tcx> BuilderMethods<'a, 'tcx> for Builder<'a, 'tcx> { #[instrument(level = "trace", skip(cx))] fn build(cx: &'a Self::CodegenCx, llbb: Self::BasicBlock) -> Self { - let cursor = cx.builder.select_block_by_id(llbb); - // FIXME(eddyb) change `Self::Function` to be more like a function index. - let current_fn = { - let emit = cx.emit_with_cursor(cursor); - let selected_function = emit.selected_function().unwrap(); - let selected_function = &emit.module_ref().functions[selected_function]; - let def_inst = selected_function.def.as_ref().unwrap(); - let def = def_inst.result_id.unwrap(); - let ty = def_inst.operands[1].unwrap_id_ref(); - def.with_type(ty) - }; Self { cx, - cursor, - current_fn, + current_block: llbb, current_span: Default::default(), } } @@ -1201,17 +1191,18 @@ impl<'a, 'tcx> BuilderMethods<'a, 'tcx> for Builder<'a, 'tcx> { llfn: Self::Function, _name: &str, ) -> Self::BasicBlock { - let cursor_fn = cx.builder.select_function_by_id(llfn.def_cx(cx)); - cx.emit_with_cursor(cursor_fn).begin_block(None).unwrap() + let mut builder = cx.builder.builder_for_fn(llfn); + let id = builder.begin_block(None).unwrap(); + let index_in_builder = builder.selected_block().unwrap(); + SpirvBlockCursor { + parent_fn: llfn, + id, + index_in_builder, + } } - fn append_sibling_block(&mut self, _name: &str) -> Self::BasicBlock { - self.emit_with_cursor(BuilderCursor { - function: self.cursor.function, - block: None, - }) - .begin_block(None) - .unwrap() + fn append_sibling_block(&mut self, name: &str) -> Self::BasicBlock { + Self::append_block(self.cx, self.current_block.parent_fn, name) } fn switch_to_block(&mut self, llbb: Self::BasicBlock) { @@ -1239,7 +1230,7 @@ impl<'a, 'tcx> BuilderMethods<'a, 'tcx> for Builder<'a, 'tcx> { } fn br(&mut self, dest: Self::BasicBlock) { - self.emit().branch(dest).unwrap(); + self.emit().branch(dest.id).unwrap(); } fn cond_br( @@ -1257,7 +1248,7 @@ impl<'a, 'tcx> BuilderMethods<'a, 'tcx> for Builder<'a, 'tcx> { Some(SpirvConst::Scalar(0)) => self.br(else_llbb), _ => { self.emit() - .branch_conditional(cond, then_llbb, else_llbb, empty()) + .branch_conditional(cond, then_llbb.id, else_llbb.id, empty()) .unwrap(); } } @@ -1330,9 +1321,11 @@ impl<'a, 'tcx> BuilderMethods<'a, 'tcx> for Builder<'a, 'tcx> { )), }; let cases = cases - .map(|(i, b)| (construct_case(self, signed, i), b)) + .map(|(i, b)| (construct_case(self, signed, i), b.id)) .collect::>(); - self.emit().switch(v.def(self), else_llbb, cases).unwrap(); + self.emit() + .switch(v.def(self), else_llbb.id, cases) + .unwrap(); } fn invoke( @@ -1349,7 +1342,7 @@ impl<'a, 'tcx> BuilderMethods<'a, 'tcx> for Builder<'a, 'tcx> { ) -> Self::Value { // Exceptions don't exist, jump directly to then block let result = self.call(llty, fn_attrs, fn_abi, llfn, args, funclet, instance); - self.emit().branch(then).unwrap(); + self.emit().branch(then.id).unwrap(); result } @@ -2741,16 +2734,16 @@ impl<'a, 'tcx> BuilderMethods<'a, 'tcx> for Builder<'a, 'tcx> { .with_type(agg_val.ty) } - fn set_personality_fn(&mut self, _personality: Self::Value) { + fn set_personality_fn(&mut self, _personality: Self::Function) { todo!() } // These are used by everyone except msvc - fn cleanup_landing_pad(&mut self, _pers_fn: Self::Value) -> (Self::Value, Self::Value) { + fn cleanup_landing_pad(&mut self, _pers_fn: Self::Function) -> (Self::Value, Self::Value) { todo!() } - fn filter_landing_pad(&mut self, _pers_fn: Self::Value) -> (Self::Value, Self::Value) { + fn filter_landing_pad(&mut self, _pers_fn: Self::Function) -> (Self::Value, Self::Value) { todo!() } @@ -2979,16 +2972,6 @@ impl<'a, 'tcx> BuilderMethods<'a, 'tcx> for Builder<'a, 'tcx> { // be fixed upstream, so we never see any "function pointer" values being // created just to perform direct calls. let (callee_val, result_type, argument_types) = match self.lookup_type(callee.ty) { - // HACK(eddyb) this seems to be needed, but it's not what `get_fn_addr` - // produces, are these coming from inside `rustc_codegen_spirv`? - SpirvType::Function { - return_type, - arguments, - } => { - assert_ty_eq!(self, callee_ty, callee.ty); - (callee.def(self), return_type, arguments) - } - SpirvType::Pointer { pointee } => match self.lookup_type(pointee) { SpirvType::Function { return_type, @@ -3014,7 +2997,7 @@ impl<'a, 'tcx> BuilderMethods<'a, 'tcx> for Builder<'a, 'tcx> { }, _ => bug!( - "call expected function or `fn` pointer type, got `{}`", + "call expected `fn` pointer type, got `{}`", self.debug_type(callee.ty) ), }; @@ -3091,18 +3074,20 @@ impl<'a, 'tcx> BuilderMethods<'a, 'tcx> for Builder<'a, 'tcx> { struct FormatArgsNotRecognized(String); // HACK(eddyb) this is basically a `try` block. - let try_decode_and_remove_format_args = || { + let mut try_decode_and_remove_format_args = || { let mut decoded_format_args = DecodedFormatArgs::default(); - let const_u32_as_usize = |ct_id| match self.builder.lookup_const_by_id(ct_id)? { + // HACK(eddyb) work around mutable borrowing conflicts. + let cx = self.cx; + + let const_u32_as_usize = |ct_id| match cx.builder.lookup_const_by_id(ct_id)? { SpirvConst::Scalar(x) => Some(u32::try_from(x).ok()? as usize), _ => None, }; let const_slice_as_elem_ids = |ptr_id: Word, len: usize| { - if let SpirvConst::PtrTo { pointee } = - self.builder.lookup_const_by_id(ptr_id)? + if let SpirvConst::PtrTo { pointee } = cx.builder.lookup_const_by_id(ptr_id)? && let SpirvConst::Composite(elems) = - self.builder.lookup_const_by_id(pointee)? + cx.builder.lookup_const_by_id(pointee)? && elems.len() == len { return Some(elems); diff --git a/crates/rustc_codegen_spirv/src/builder/ext_inst.rs b/crates/rustc_codegen_spirv/src/builder/ext_inst.rs index 7c4997de76..3b3c129759 100644 --- a/crates/rustc_codegen_spirv/src/builder/ext_inst.rs +++ b/crates/rustc_codegen_spirv/src/builder/ext_inst.rs @@ -41,7 +41,7 @@ impl ExtInst { impl<'a, 'tcx> Builder<'a, 'tcx> { pub fn custom_inst( - &self, + &mut self, result_type: Word, inst: custom_insts::CustomInst, ) -> SpirvValue { @@ -58,7 +58,12 @@ impl<'a, 'tcx> Builder<'a, 'tcx> { .with_type(result_type) } - pub fn gl_op(&self, op: GLOp, result_type: Word, args: impl AsRef<[SpirvValue]>) -> SpirvValue { + pub fn gl_op( + &mut self, + op: GLOp, + result_type: Word, + args: impl AsRef<[SpirvValue]>, + ) -> SpirvValue { let args = args.as_ref(); let glsl = self.ext_inst.borrow_mut().import_glsl(self); self.emit() diff --git a/crates/rustc_codegen_spirv/src/builder/intrinsics.rs b/crates/rustc_codegen_spirv/src/builder/intrinsics.rs index 084f14b649..deab2b087c 100644 --- a/crates/rustc_codegen_spirv/src/builder/intrinsics.rs +++ b/crates/rustc_codegen_spirv/src/builder/intrinsics.rs @@ -381,7 +381,7 @@ impl<'a, 'tcx> IntrinsicCallBuilderMethods<'tcx> for Builder<'a, 'tcx> { } impl Builder<'_, '_> { - pub fn count_ones(&self, arg: SpirvValue) -> SpirvValue { + pub fn count_ones(&mut self, arg: SpirvValue) -> SpirvValue { let ty = arg.ty; match self.cx.lookup_type(ty) { SpirvType::Integer(bits, false) => { @@ -426,7 +426,7 @@ impl Builder<'_, '_> { } } - pub fn bit_reverse(&self, arg: SpirvValue) -> SpirvValue { + pub fn bit_reverse(&mut self, arg: SpirvValue) -> SpirvValue { let ty = arg.ty; match self.cx.lookup_type(ty) { SpirvType::Integer(bits, false) => { @@ -489,7 +489,7 @@ impl Builder<'_, '_> { } pub fn count_leading_trailing_zeros( - &self, + &mut self, arg: SpirvValue, trailing: bool, non_zero: bool, @@ -501,9 +501,9 @@ impl Builder<'_, '_> { let u32 = SpirvType::Integer(32, false).def(self.span(), self); let glsl = self.ext_inst.borrow_mut().import_glsl(self); - let find_xsb = |arg, offset: i32| { + let find_xsb = |this: &mut Self, arg, offset: i32| { if trailing { - let lsb = self + let lsb = this .emit() .ext_inst( u32, @@ -516,12 +516,12 @@ impl Builder<'_, '_> { if offset == 0 { lsb } else { - let const_offset = self.constant_i32(self.span(), offset).def(self); - self.emit().i_add(u32, None, const_offset, lsb).unwrap() + let const_offset = this.constant_i32(this.span(), offset).def(this); + this.emit().i_add(u32, None, const_offset, lsb).unwrap() } } else { // rust is always unsigned, so FindUMsb - let msb_bit = self + let msb_bit = this .emit() .ext_inst( u32, @@ -533,8 +533,8 @@ impl Builder<'_, '_> { .unwrap(); // the glsl op returns the Msb bit, not the amount of leading zeros of this u32 // leading zeros = 31 - Msb bit - let const_offset = self.constant_i32(self.span(), 31 - offset).def(self); - self.emit().i_sub(u32, None, const_offset, msb_bit).unwrap() + let const_offset = this.constant_i32(this.span(), 31 - offset).def(this); + this.emit().i_sub(u32, None, const_offset, msb_bit).unwrap() } }; @@ -542,12 +542,12 @@ impl Builder<'_, '_> { 8 | 16 => { let arg = self.emit().u_convert(u32, None, arg.def(self)).unwrap(); if trailing { - find_xsb(arg, 0) + find_xsb(self, arg, 0) } else { - find_xsb(arg, bits as i32 - 32) + find_xsb(self, arg, bits as i32 - 32) } } - 32 => find_xsb(arg.def(self), 0), + 32 => find_xsb(self, arg.def(self), 0), 64 => { let u32_0 = self.constant_int(u32, 0).def(self); let u32_32 = self.constant_u32(self.span(), 32).def(self); @@ -562,16 +562,16 @@ impl Builder<'_, '_> { if trailing { let use_lower = self.emit().i_equal(bool, None, lower, u32_0).unwrap(); - let lower_bits = find_xsb(lower, 32); - let higher_bits = find_xsb(higher, 0); + let lower_bits = find_xsb(self, lower, 32); + let higher_bits = find_xsb(self, higher, 0); self.emit() .select(u32, None, use_lower, higher_bits, lower_bits) .unwrap() } else { let use_higher = self.emit().i_equal(bool, None, higher, u32_0).unwrap(); - let lower_bits = find_xsb(lower, 0); - let higher_bits = find_xsb(higher, 32); + let lower_bits = find_xsb(self, lower, 0); + let higher_bits = find_xsb(self, higher, 32); self.emit() .select(u32, None, use_higher, lower_bits, higher_bits) .unwrap() diff --git a/crates/rustc_codegen_spirv/src/builder/mod.rs b/crates/rustc_codegen_spirv/src/builder/mod.rs index e1dcd1577c..fb418f1018 100644 --- a/crates/rustc_codegen_spirv/src/builder/mod.rs +++ b/crates/rustc_codegen_spirv/src/builder/mod.rs @@ -12,7 +12,7 @@ pub use spirv_asm::InstructionTable; // HACK(eddyb) avoids rewriting all of the imports (see `lib.rs` and `build.rs`). use crate::maybe_pqp_cg_ssa as rustc_codegen_ssa; -use crate::builder_spirv::{BuilderCursor, SpirvValue, SpirvValueExt}; +use crate::builder_spirv::{SpirvValue, SpirvValueExt}; use crate::codegen_cx::CodegenCx; use crate::spirv_type::SpirvType; use rspirv::spirv::Word; @@ -40,15 +40,18 @@ use std::ops::{Deref, Range}; pub struct Builder<'a, 'tcx> { cx: &'a CodegenCx<'tcx>, - cursor: BuilderCursor, - current_fn: ::Function, + current_block: ::BasicBlock, current_span: Option, } impl<'a, 'tcx> Builder<'a, 'tcx> { /// See comment on `BuilderCursor` - pub fn emit(&self) -> std::cell::RefMut<'_, rspirv::dr::Builder> { - self.emit_with_cursor(self.cursor) + // + // FIXME(eddyb) take advantage of `&mut self` to avoid `RefCell` entirely + // (sadly it requires making `&CodegeCx`'s types/consts more like SPIR-T, + // and completely disjoint from mutably building functions). + pub fn emit(&mut self) -> std::cell::RefMut<'a, rspirv::dr::Builder> { + self.cx.builder.builder_for_block(self.current_block) } pub fn zombie(&self, word: Word, reason: &str) { @@ -208,15 +211,16 @@ impl<'a, 'tcx> ArgAbiBuilderMethods<'tcx> for Builder<'a, 'tcx> { idx: &mut usize, dst: PlaceRef<'tcx, Self::Value>, ) { - fn next(bx: &Builder<'_, '_>, idx: &mut usize) -> SpirvValue { - let val = bx.function_parameter_values.borrow()[&bx.current_fn.def(bx)][*idx]; + fn next(bx: &mut Builder<'_, '_>, idx: &mut usize) -> SpirvValue { + let val = bx.get_param(*idx); *idx += 1; val } match arg_abi.mode { PassMode::Ignore => {} PassMode::Direct(_) => { - self.store_arg(arg_abi, next(self, idx), dst); + let arg = next(self, idx); + self.store_arg(arg_abi, arg, dst); } PassMode::Pair(..) => { OperandValue::Pair(next(self, idx), next(self, idx)).store(self, dst); @@ -253,7 +257,13 @@ impl<'a, 'tcx> ArgAbiBuilderMethods<'tcx> for Builder<'a, 'tcx> { impl AbiBuilderMethods for Builder<'_, '_> { fn get_param(&mut self, index: usize) -> Self::Value { - self.function_parameter_values.borrow()[&self.current_fn.def(self)][index] + let builder = self.emit(); + let param = + &builder.module_ref().functions[builder.selected_function().unwrap()].parameters[index]; + param + .result_id + .unwrap() + .with_type(param.result_type.unwrap()) } } diff --git a/crates/rustc_codegen_spirv/src/builder/spirv_asm.rs b/crates/rustc_codegen_spirv/src/builder/spirv_asm.rs index 663ebbeaf4..96a92fb0a6 100644 --- a/crates/rustc_codegen_spirv/src/builder/spirv_asm.rs +++ b/crates/rustc_codegen_spirv/src/builder/spirv_asm.rs @@ -3,7 +3,7 @@ use crate::maybe_pqp_cg_ssa as rustc_codegen_ssa; use super::Builder; use crate::abi::ConvSpirvType; -use crate::builder_spirv::{BuilderCursor, SpirvValue}; +use crate::builder_spirv::SpirvValue; use crate::codegen_cx::CodegenCx; use crate::spirv_type::SpirvType; use rspirv::dr; @@ -418,12 +418,11 @@ impl<'cx, 'tcx> Builder<'cx, 'tcx> { // OpVariable with Function storage class should be emitted inside the function, // however, all other OpVariables should appear in the global scope instead. if inst.operands[0].unwrap_storage_class() == StorageClass::Function { - self.emit_with_cursor(BuilderCursor { - block: Some(0), - ..self.cursor - }) - .insert_into_block(dr::InsertPoint::Begin, inst) - .unwrap(); + let mut builder = self.emit(); + builder.select_block(Some(0)).unwrap(); + builder + .insert_into_block(dr::InsertPoint::Begin, inst) + .unwrap(); } else { self.emit_global() .insert_types_global_values(dr::InsertPoint::End, inst); @@ -538,7 +537,7 @@ impl<'cx, 'tcx> Builder<'cx, 'tcx> { }; let inst_class = inst_name .strip_prefix("Op") - .and_then(|n| self.instruction_table.table.get(n)); + .and_then(|n| self.cx.instruction_table.table.get(n)); let inst_class = if let Some(inst) = inst_class { inst } else { diff --git a/crates/rustc_codegen_spirv/src/builder_spirv.rs b/crates/rustc_codegen_spirv/src/builder_spirv.rs index 936edf883a..e05b433057 100644 --- a/crates/rustc_codegen_spirv/src/builder_spirv.rs +++ b/crates/rustc_codegen_spirv/src/builder_spirv.rs @@ -7,7 +7,7 @@ use crate::spirv_type::SpirvType; use crate::symbols::Symbols; use crate::target::SpirvTarget; use crate::target_feature::TargetFeature; -use rspirv::dr::{Block, Builder, Instruction, Module, Operand}; +use rspirv::dr::{Builder, Instruction, Module, Operand}; use rspirv::spirv::{ AddressingModel, Capability, MemoryModel, Op, SourceLanguage, StorageClass, Word, }; @@ -380,6 +380,22 @@ pub struct DebugFileSpirv<'tcx> { pub file_name_op_string_id: Word, } +// HACK(eddyb) unlike a raw SPIR-V ID (or `SpirvValue`), this allows random-access. +#[derive(Copy, Clone, Debug)] +pub struct SpirvFunctionCursor { + pub ty: Word, + pub id: Word, + pub index_in_builder: usize, +} + +// HACK(eddyb) unlike a raw SPIR-V ID, this allows random-access. +#[derive(Copy, Clone, Debug)] +pub struct SpirvBlockCursor { + pub parent_fn: SpirvFunctionCursor, + pub id: Word, + pub index_in_builder: usize, +} + /// Cursor system: /// /// The LLVM module builder model (and therefore `codegen_ssa`) assumes that there is a central @@ -402,11 +418,15 @@ pub struct DebugFileSpirv<'tcx> { /// then `self.emit_global()` will use the generic "global cursor" and return a mutable reference /// to the rspirv builder with no basic block nor function selected, i.e. any instructions emitted /// will be in the global section. +// +// FIXME(eddyb) try updating documentation like the above. +// FIXME(eddyb) figure out how to replace `BuilderCursor` with something like +// `Option`, but that can't handle "in function outside BB". #[derive(Debug, Default, Copy, Clone)] #[must_use = "BuilderCursor should usually be assigned to the Builder.cursor field"] -pub struct BuilderCursor { - pub function: Option, - pub block: Option, +struct BuilderCursor { + fn_id_and_idx: Option<(Word, usize)>, + block_id_and_idx: Option<(Word, usize)>, } pub struct BuilderSpirv<'tcx> { @@ -515,37 +535,59 @@ impl<'tcx> BuilderSpirv<'tcx> { .unwrap(); } + pub fn has_capability(&self, capability: Capability) -> bool { + self.enabled_capabilities.contains(&capability) + } + /// See comment on `BuilderCursor` - pub fn builder(&self, cursor: BuilderCursor) -> RefMut<'_, Builder> { + fn builder(&self, cursor: BuilderCursor) -> RefMut<'_, Builder> { let mut builder = self.builder.borrow_mut(); - // select_function does bounds checks and other relatively expensive things, so don't just call it - // unconditionally. - if builder.selected_function() != cursor.function { - builder.select_function(cursor.function).unwrap(); + + let [maybe_fn_idx, maybe_block_idx] = [cursor.fn_id_and_idx, cursor.block_id_and_idx] + .map(|id_and_idx| id_and_idx.map(|(_, idx)| idx)); + + let fn_changed = builder.selected_function() != maybe_fn_idx; + if fn_changed { + builder.select_function(maybe_fn_idx).unwrap(); } - if cursor.function.is_some() && builder.selected_block() != cursor.block { - builder.select_block(cursor.block).unwrap(); + + // Only check the function/block IDs if either of their indices changed. + if let Some((fn_id, fn_idx)) = cursor.fn_id_and_idx + && (fn_changed || builder.selected_block() != maybe_block_idx) + { + builder.select_block(maybe_block_idx).unwrap(); + + let function = &builder.module_ref().functions[fn_idx]; + if fn_changed { + assert_eq!(function.def_id(), Some(fn_id)); + } + if let Some((block_id, block_idx)) = cursor.block_id_and_idx { + assert_eq!(function.blocks[block_idx].label_id(), Some(block_id)); + } } + builder } - pub fn has_capability(&self, capability: Capability) -> bool { - self.enabled_capabilities.contains(&capability) + /// See comment on `BuilderCursor` + pub fn global_builder(&self) -> RefMut<'_, Builder> { + self.builder(BuilderCursor::default()) } - pub fn select_function_by_id(&self, id: Word) -> BuilderCursor { - let mut builder = self.builder.borrow_mut(); - for (index, func) in builder.module_ref().functions.iter().enumerate() { - if func.def.as_ref().and_then(|i| i.result_id) == Some(id) { - builder.select_function(Some(index)).unwrap(); - return BuilderCursor { - function: Some(index), - block: None, - }; - } - } + /// See comment on `BuilderCursor` + pub fn builder_for_fn(&self, func: SpirvFunctionCursor) -> RefMut<'_, Builder> { + self.builder(BuilderCursor { + fn_id_and_idx: Some((func.id, func.index_in_builder)), + block_id_and_idx: None, + }) + } - bug!("Function not found: {}", id); + /// See comment on `BuilderCursor` + pub fn builder_for_block(&self, block: SpirvBlockCursor) -> RefMut<'_, Builder> { + self.builder(BuilderCursor { + fn_id_and_idx: Some((block.parent_fn.id, block.parent_fn.index_in_builder)), + block_id_and_idx: Some((block.id, block.index_in_builder)), + }) } pub(crate) fn def_constant_cx( @@ -619,7 +661,7 @@ impl<'tcx> BuilderSpirv<'tcx> { id }; - let mut builder = self.builder(BuilderCursor::default()); + let mut builder = self.global_builder(); let id = match val { SpirvConst::Scalar(v) => match scalar_ty.unwrap() { SpirvType::Integer(..=32, _) | SpirvType::Float(..=32) => { @@ -636,7 +678,7 @@ impl<'tcx> BuilderSpirv<'tcx> { let [lo_id, hi_id] = [v as u64, (v >> 64) as u64].map(|half| cx.const_u64(half).def_cx(cx)); - builder = self.builder(BuilderCursor::default()); + builder = self.global_builder(); let mut const_op = |op, lhs, maybe_rhs| const_op(&mut builder, op, lhs, maybe_rhs); let [lo_u128_id, hi_shifted_u128_id] = @@ -661,7 +703,7 @@ impl<'tcx> BuilderSpirv<'tcx> { let v_u128_id = cx.const_u128(v).def_cx(cx); - builder = self.builder(BuilderCursor::default()); + builder = self.global_builder(); const_op(&mut builder, Op::Bitcast, v_u128_id, None) } SpirvType::Bool => match v { @@ -835,7 +877,7 @@ impl<'tcx> BuilderSpirv<'tcx> { .borrow_mut() .entry(DebugFileKey(sf)) .or_insert_with_key(|DebugFileKey(sf)| { - let mut builder = self.builder(Default::default()); + let mut builder = self.global_builder(); // FIXME(eddyb) it would be nicer if we could just rely on // `RealFileName::to_string_lossy` returning `Cow<'_, str>`, @@ -951,57 +993,4 @@ impl<'tcx> BuilderSpirv<'tcx> { inst.operands.push(Operand::IdRef(initializer)); module.types_global_values.push(inst); } - - pub fn select_block_by_id(&self, id: Word) -> BuilderCursor { - fn block_matches(block: &Block, id: Word) -> bool { - block.label.as_ref().and_then(|b| b.result_id) == Some(id) - } - - let mut builder = self.builder.borrow_mut(); - let module = builder.module_ref(); - - // The user is probably selecting a block in the current function, so search that first. - if let Some(selected_function) = builder.selected_function() { - // make no-ops really fast - if let Some(selected_block) = builder.selected_block() { - let block = &module.functions[selected_function].blocks[selected_block]; - if block_matches(block, id) { - return BuilderCursor { - function: Some(selected_function), - block: Some(selected_block), - }; - } - } - - for (index, block) in module.functions[selected_function] - .blocks - .iter() - .enumerate() - { - if block_matches(block, id) { - builder.select_block(Some(index)).unwrap(); - return BuilderCursor { - function: Some(selected_function), - block: Some(index), - }; - } - } - } - - // Search the whole module. - for (function_index, function) in module.functions.iter().enumerate() { - for (block_index, block) in function.blocks.iter().enumerate() { - if block_matches(block, id) { - builder.select_function(Some(function_index)).unwrap(); - builder.select_block(Some(block_index)).unwrap(); - return BuilderCursor { - function: Some(function_index), - block: Some(block_index), - }; - } - } - } - - bug!("Block not found: {}", id); - } } diff --git a/crates/rustc_codegen_spirv/src/codegen_cx/declare.rs b/crates/rustc_codegen_spirv/src/codegen_cx/declare.rs index 73b434a8bf..5ca2edbd0c 100644 --- a/crates/rustc_codegen_spirv/src/codegen_cx/declare.rs +++ b/crates/rustc_codegen_spirv/src/codegen_cx/declare.rs @@ -4,7 +4,7 @@ use crate::maybe_pqp_cg_ssa as rustc_codegen_ssa; use super::CodegenCx; use crate::abi::ConvSpirvType; use crate::attr::AggregatedSpirvAttributes; -use crate::builder_spirv::{SpirvConst, SpirvValue, SpirvValueExt}; +use crate::builder_spirv::{SpirvConst, SpirvFunctionCursor, SpirvValue, SpirvValueExt}; use crate::custom_decorations::{CustomDecoration, SrcLocDecoration}; use crate::spirv_type::SpirvType; use itertools::Itertools; @@ -40,11 +40,11 @@ fn attrs_to_spirv(attrs: &CodegenFnAttrs) -> FunctionControl { impl<'tcx> CodegenCx<'tcx> { /// Returns a function if it already exists, or declares a header if it doesn't. - pub fn get_fn_ext(&self, instance: Instance<'tcx>) -> SpirvValue { + pub fn get_fn_ext(&self, instance: Instance<'tcx>) -> SpirvFunctionCursor { assert!(!instance.args.has_infer()); assert!(!instance.args.has_escaping_bound_vars()); - if let Some(&func) = self.instances.borrow().get(&instance) { + if let Some(&func) = self.fn_instances.borrow().get(&instance) { return func; } @@ -53,7 +53,7 @@ impl<'tcx> CodegenCx<'tcx> { let linkage = Some(LinkageType::Import); let llfn = self.declare_fn_ext(instance, linkage); - self.instances.borrow_mut().insert(instance, llfn); + self.fn_instances.borrow_mut().insert(instance, llfn); llfn } @@ -62,7 +62,11 @@ impl<'tcx> CodegenCx<'tcx> { // MiscCodegenMethods::get_fn -> get_fn_ext -> declare_fn_ext // MiscCodegenMethods::get_fn_addr -> get_fn_ext -> declare_fn_ext // PreDefineCodegenMethods::predefine_fn -> declare_fn_ext - fn declare_fn_ext(&self, instance: Instance<'tcx>, linkage: Option) -> SpirvValue { + fn declare_fn_ext( + &self, + instance: Instance<'tcx>, + linkage: Option, + ) -> SpirvFunctionCursor { let def_id = instance.def_id(); let control = attrs_to_spirv(self.tcx.codegen_fn_attrs(def_id)); @@ -77,23 +81,28 @@ impl<'tcx> CodegenCx<'tcx> { other => bug!("fn_abi type {}", other.debug(function_type, self)), }; - let fn_id = { + let declared = { let mut emit = self.emit_global(); - let fn_id = emit + let id = emit .begin_function(return_type, None, control, function_type) .unwrap(); + let index_in_builder = emit.selected_function().unwrap(); + + // FIXME(eddyb) omitting `OpFunctionParameter` on imports might be + // illegal, this probably shouldn't be conditional at all. if linkage != Some(LinkageType::Import) { - let parameter_values = argument_types - .iter() - .map(|&ty| emit.function_parameter(ty).unwrap().with_type(ty)) - .collect::>(); - self.function_parameter_values - .borrow_mut() - .insert(fn_id, parameter_values); + for &ty in argument_types { + emit.function_parameter(ty).unwrap(); + } } emit.end_function().unwrap(); - fn_id + SpirvFunctionCursor { + ty: function_type, + id, + index_in_builder, + } }; + let fn_id = declared.id; // HACK(eddyb) this is a temporary workaround due to our use of `rspirv`, // which prevents us from attaching `OpLine`s to `OpFunction` definitions, @@ -124,15 +133,17 @@ impl<'tcx> CodegenCx<'tcx> { self.set_linkage(fn_id, symbol_name.to_owned(), linkage); } - let declared = fn_id.with_type(function_type); - let attrs = AggregatedSpirvAttributes::parse(self, self.tcx.get_attrs_unchecked(def_id)); if let Some(entry) = attrs.entry.map(|attr| attr.value) { + // HACK(eddyb) early insert to let `shader_entry_stub` call this + // very function via `get_fn_addr`. + self.fn_instances.borrow_mut().insert(instance, declared); + let entry_name = entry .name .as_ref() .map_or_else(|| instance.to_string(), ToString::to_string); - self.entry_stub(&instance, fn_abi, declared, entry_name, entry); + self.entry_stub(instance, fn_abi, entry_name, entry); } // FIXME(eddyb) should the maps exist at all, now that the `DefId` is known @@ -264,8 +275,7 @@ impl<'tcx> CodegenCx<'tcx> { } pub fn get_static(&self, def_id: DefId) -> SpirvValue { - let instance = Instance::mono(self.tcx, def_id); - if let Some(&g) = self.instances.borrow().get(&instance) { + if let Some(&g) = self.statics.borrow().get(&def_id) { return g; } @@ -278,11 +288,12 @@ impl<'tcx> CodegenCx<'tcx> { "get_static() should always hit the cache for statics defined in the same CGU, but did not for `{def_id:?}`" ); + let instance = Instance::mono(self.tcx, def_id); let ty = instance.ty(self.tcx, TypingEnv::fully_monomorphized()); let sym = self.tcx.symbol_name(instance).name; let span = self.tcx.def_span(def_id); let g = self.declare_global(span, self.layout_of(ty).spirv_type(span, self)); - self.instances.borrow_mut().insert(instance, g); + self.statics.borrow_mut().insert(def_id, g); self.set_linkage(g.def_cx(self), sym.to_string(), LinkageType::Import); g } @@ -326,7 +337,7 @@ impl<'tcx> PreDefineCodegenMethods<'tcx> for CodegenCx<'tcx> { let g = self.declare_global(span, spvty); - self.instances.borrow_mut().insert(instance, g); + self.statics.borrow_mut().insert(def_id, g); if let Some(linkage) = linkage { self.set_linkage(g.def_cx(self), symbol_name.to_string(), linkage); } @@ -354,7 +365,7 @@ impl<'tcx> PreDefineCodegenMethods<'tcx> for CodegenCx<'tcx> { }; let declared = self.declare_fn_ext(instance, linkage2); - self.instances.borrow_mut().insert(instance, declared); + self.fn_instances.borrow_mut().insert(instance, declared); } } diff --git a/crates/rustc_codegen_spirv/src/codegen_cx/entry.rs b/crates/rustc_codegen_spirv/src/codegen_cx/entry.rs index 7435781173..5de823d261 100644 --- a/crates/rustc_codegen_spirv/src/codegen_cx/entry.rs +++ b/crates/rustc_codegen_spirv/src/codegen_cx/entry.rs @@ -5,13 +5,13 @@ use super::CodegenCx; use crate::abi::ConvSpirvType; use crate::attr::{AggregatedSpirvAttributes, Entry, Spanned, SpecConstant}; use crate::builder::Builder; -use crate::builder_spirv::{SpirvValue, SpirvValueExt}; +use crate::builder_spirv::{SpirvFunctionCursor, SpirvValue, SpirvValueExt}; use crate::spirv_type::SpirvType; use rspirv::dr::Operand; use rspirv::spirv::{ Capability, Decoration, Dim, ExecutionModel, FunctionControl, StorageClass, Word, }; -use rustc_codegen_ssa::traits::{BaseTypeCodegenMethods, BuilderMethods}; +use rustc_codegen_ssa::traits::{BaseTypeCodegenMethods, BuilderMethods, MiscCodegenMethods as _}; use rustc_data_structures::fx::FxHashMap; use rustc_errors::MultiSpan; use rustc_hir as hir; @@ -63,18 +63,18 @@ impl<'tcx> CodegenCx<'tcx> { // function. pub fn entry_stub( &self, - instance: &Instance<'_>, + entry_instance: Instance<'tcx>, fn_abi: &FnAbi<'tcx, Ty<'tcx>>, - entry_func: SpirvValue, name: String, entry: Entry, ) { + let entry_def_id = entry_instance.def_id(); let span = self .tcx - .def_ident_span(instance.def_id()) - .unwrap_or_else(|| self.tcx.def_span(instance.def_id())); + .def_ident_span(entry_def_id) + .unwrap_or_else(|| self.tcx.def_span(entry_def_id)); let hir_params = { - let fn_local_def_id = if let Some(id) = instance.def_id().as_local() { + let fn_local_def_id = if let Some(id) = entry_def_id.as_local() { id } else { self.tcx @@ -132,9 +132,9 @@ impl<'tcx> CodegenCx<'tcx> { } // let execution_model = entry.execution_model; - let fn_id = self.shader_entry_stub( + let stub = self.shader_entry_stub( span, - entry_func, + entry_instance, fn_abi, hir_params, name, @@ -145,19 +145,19 @@ impl<'tcx> CodegenCx<'tcx> { .execution_modes .iter() .for_each(|(execution_mode, execution_mode_extra)| { - emit.execution_mode(fn_id, *execution_mode, execution_mode_extra); + emit.execution_mode(stub.id, *execution_mode, execution_mode_extra); }); } fn shader_entry_stub( &self, span: Span, - entry_func: SpirvValue, + entry_instance: Instance<'tcx>, entry_fn_abi: &FnAbi<'tcx, Ty<'tcx>>, hir_params: &[hir::Param<'tcx>], name: String, execution_model: ExecutionModel, - ) -> Word { + ) -> SpirvFunctionCursor { let stub_fn = { let void = SpirvType::Void.def(span, self); let fn_void_void = SpirvType::Function { @@ -169,8 +169,13 @@ impl<'tcx> CodegenCx<'tcx> { let id = emit .begin_function(void, None, FunctionControl::NONE, fn_void_void) .unwrap(); + let index_in_builder = emit.selected_function().unwrap(); emit.end_function().unwrap(); - id.with_type(fn_void_void) + SpirvFunctionCursor { + ty: fn_void_void, + id, + index_in_builder, + } }; let mut op_entry_point_interface_operands = vec![]; @@ -192,24 +197,23 @@ impl<'tcx> CodegenCx<'tcx> { } bx.set_span(span); bx.call( - entry_func.ty, + self.get_fn(entry_instance).ty, None, Some(entry_fn_abi), - entry_func, + self.get_fn_addr(entry_instance), &call_args, None, None, ); bx.ret_void(); - let stub_fn_id = stub_fn.def_cx(self); self.emit_global().entry_point( execution_model, - stub_fn_id, + stub_fn.id, name, op_entry_point_interface_operands, ); - stub_fn_id + stub_fn } /// Attempt to compute `EntryParamDeducedFromRustRefOrValue` (see its docs) diff --git a/crates/rustc_codegen_spirv/src/codegen_cx/mod.rs b/crates/rustc_codegen_spirv/src/codegen_cx/mod.rs index 072042d8f3..3b8f7651de 100644 --- a/crates/rustc_codegen_spirv/src/codegen_cx/mod.rs +++ b/crates/rustc_codegen_spirv/src/codegen_cx/mod.rs @@ -4,7 +4,9 @@ mod entry; mod type_; use crate::builder::{ExtInst, InstructionTable}; -use crate::builder_spirv::{BuilderCursor, BuilderSpirv, SpirvConst, SpirvValue, SpirvValueKind}; +use crate::builder_spirv::{ + BuilderSpirv, SpirvBlockCursor, SpirvConst, SpirvFunctionCursor, SpirvValue, SpirvValueKind, +}; use crate::custom_decorations::{CustomDecoration, SrcLocDecoration, ZombieDecoration}; use crate::spirv_type::{SpirvType, SpirvTypePrinter, TypeCache}; use crate::symbols::Symbols; @@ -44,12 +46,10 @@ use std::str::FromStr; pub struct CodegenCx<'tcx> { pub tcx: TyCtxt<'tcx>, pub codegen_unit: &'tcx CodegenUnit<'tcx>, - /// Spir-v module builder + /// SPIR-V module builder pub builder: BuilderSpirv<'tcx>, - /// Map from MIR function to spir-v function ID - pub instances: RefCell, SpirvValue>>, - /// Map from function ID to parameter list - pub function_parameter_values: RefCell>>, + pub fn_instances: RefCell, SpirvFunctionCursor>>, + pub statics: RefCell>, pub type_cache: TypeCache<'tcx>, /// Cache generated vtables pub vtables: RefCell, Option>), SpirvValue>>, @@ -193,8 +193,8 @@ impl<'tcx> CodegenCx<'tcx> { tcx, codegen_unit, builder: BuilderSpirv::new(tcx, &sym, &target, &features), - instances: Default::default(), - function_parameter_values: Default::default(), + fn_instances: Default::default(), + statics: Default::default(), type_cache: Default::default(), vtables: Default::default(), ext_inst: Default::default(), @@ -215,18 +215,7 @@ impl<'tcx> CodegenCx<'tcx> { /// See comment on `BuilderCursor` pub fn emit_global(&self) -> std::cell::RefMut<'_, rspirv::dr::Builder> { - self.builder.builder(BuilderCursor { - function: None, - block: None, - }) - } - - /// See comment on `BuilderCursor` - pub fn emit_with_cursor( - &self, - cursor: BuilderCursor, - ) -> std::cell::RefMut<'_, rspirv::dr::Builder> { - self.builder.builder(cursor) + self.builder.global_builder() } #[track_caller] @@ -794,12 +783,14 @@ impl FromStr for ModuleOutputType { impl<'tcx> BackendTypes for CodegenCx<'tcx> { type Value = SpirvValue; type Metadata = (); - type Function = SpirvValue; + type Function = SpirvFunctionCursor; - type BasicBlock = Word; + type BasicBlock = SpirvBlockCursor; type Type = Word; // Funclet: A structure representing an active landing pad for the duration of a basic block. (??) // https://doc.rust-lang.org/nightly/nightly-rustc/rustc_codegen_llvm/common/struct.Funclet.html + // + // FIXME(eddyb) replace with `!` or similar. type Funclet = (); type DIScope = (); @@ -861,13 +852,13 @@ impl<'tcx> MiscCodegenMethods<'tcx> for CodegenCx<'tcx> { SpirvValue { kind: SpirvValueKind::FnAddr { - function: function.def_cx(self), + function: function.id, }, ty, } } - fn eh_personality(&self) -> Self::Value { + fn eh_personality(&self) -> Self::Function { todo!() }