Skip to content

Customize Function and BasicBlock to carry both a SPIR-V ID and an index, for O(1) access. #341

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 2 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
107 changes: 46 additions & 61 deletions crates/rustc_codegen_spirv/src/builder/builder_methods.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,9 @@ use crate::maybe_pqp_cg_ssa as rustc_codegen_ssa;

use super::Builder;
use crate::abi::ConvSpirvType;
use crate::builder_spirv::{BuilderCursor, SpirvConst, SpirvValue, SpirvValueExt, SpirvValueKind};
use crate::builder_spirv::{
SpirvBlockCursor, SpirvConst, SpirvValue, SpirvValueExt, SpirvValueKind,
};
use crate::codegen_cx::CodegenCx;
use crate::custom_insts::{CustomInst, CustomOp};
use crate::spirv_type::SpirvType;
Expand Down Expand Up @@ -127,7 +129,7 @@ fn memset_fill_u64(b: u8) -> u64 {
}

fn memset_dynamic_scalar(
builder: &Builder<'_, '_>,
builder: &mut Builder<'_, '_>,
fill_var: Word,
byte_width: usize,
is_float: bool,
Expand All @@ -154,7 +156,7 @@ fn memset_dynamic_scalar(

impl<'a, 'tcx> Builder<'a, 'tcx> {
#[instrument(level = "trace", skip(self))]
fn ordering_to_semantics_def(&self, ordering: AtomicOrdering) -> SpirvValue {
fn ordering_to_semantics_def(&mut self, ordering: AtomicOrdering) -> SpirvValue {
let mut invalid_seq_cst = false;
let semantics = match ordering {
AtomicOrdering::Relaxed => MemorySemantics::NONE,
Expand All @@ -166,8 +168,8 @@ impl<'a, 'tcx> Builder<'a, 'tcx> {
| MemorySemantics::ACQUIRE_RELEASE
}
AtomicOrdering::SeqCst => {
let emit = self.emit();
let memory_model = emit.module_ref().memory_model.as_ref().unwrap();
let builder = self.emit();
let memory_model = builder.module_ref().memory_model.as_ref().unwrap();
if memory_model.operands[1].unwrap_memory_model() == MemoryModel::Vulkan {
invalid_seq_cst = true;
}
Expand Down Expand Up @@ -263,7 +265,7 @@ impl<'a, 'tcx> Builder<'a, 'tcx> {
}

#[instrument(level = "trace", skip(self))]
fn memset_dynamic_pattern(&self, ty: &SpirvType<'tcx>, fill_var: Word) -> Word {
fn memset_dynamic_pattern(&mut self, ty: &SpirvType<'tcx>, fill_var: Word) -> Word {
match *ty {
SpirvType::Void => self.fatal("memset invalid on void pattern"),
SpirvType::Bool => self.fatal("memset invalid on bool pattern"),
Expand Down Expand Up @@ -750,9 +752,9 @@ impl<'a, 'tcx> Builder<'a, 'tcx> {
let ptr_id = ptr.def(self);

let maybe_original_access_chain = {
let emit = self.emit();
let module = emit.module_ref();
let current_func_blocks = emit
let builder = self.emit();
let module = builder.module_ref();
let current_func_blocks = builder
.selected_function()
.and_then(|func_idx| Some(&module.functions.get(func_idx)?.blocks[..]))
.unwrap_or_default();
Expand Down Expand Up @@ -823,28 +825,28 @@ impl<'a, 'tcx> Builder<'a, 'tcx> {
)
)]
fn emit_access_chain(
&self,
&mut self,
result_type: <Self as BackendTypes>::Type,
pointer: Word,
ptr_base_index: Option<SpirvValue>,
indices: Vec<Word>,
is_inbounds: bool,
) -> SpirvValue {
let mut emit = self.emit();
let mut builder = self.emit();

let non_zero_ptr_base_index =
ptr_base_index.filter(|&idx| self.builder.lookup_const_scalar(idx) != Some(0));
if let Some(ptr_base_index) = non_zero_ptr_base_index {
let result = if is_inbounds {
emit.in_bounds_ptr_access_chain(
builder.in_bounds_ptr_access_chain(
result_type,
None,
pointer,
ptr_base_index.def(self),
indices,
)
} else {
emit.ptr_access_chain(
builder.ptr_access_chain(
result_type,
None,
pointer,
Expand All @@ -857,9 +859,9 @@ impl<'a, 'tcx> Builder<'a, 'tcx> {
result
} else {
if is_inbounds {
emit.in_bounds_access_chain(result_type, None, pointer, indices)
builder.in_bounds_access_chain(result_type, None, pointer, indices)
} else {
emit.access_chain(result_type, None, pointer, indices)
builder.access_chain(result_type, None, pointer, indices)
}
.unwrap()
}
Expand Down Expand Up @@ -1091,21 +1093,9 @@ impl<'a, 'tcx> BuilderMethods<'a, 'tcx> for Builder<'a, 'tcx> {

#[instrument(level = "trace", skip(cx))]
fn build(cx: &'a Self::CodegenCx, llbb: Self::BasicBlock) -> Self {
let cursor = cx.builder.select_block_by_id(llbb);
// FIXME(eddyb) change `Self::Function` to be more like a function index.
let current_fn = {
let emit = cx.emit_with_cursor(cursor);
let selected_function = emit.selected_function().unwrap();
let selected_function = &emit.module_ref().functions[selected_function];
let def_inst = selected_function.def.as_ref().unwrap();
let def = def_inst.result_id.unwrap();
let ty = def_inst.operands[1].unwrap_id_ref();
def.with_type(ty)
};
Self {
cx,
cursor,
current_fn,
current_block: llbb,
current_span: Default::default(),
}
}
Expand Down Expand Up @@ -1201,17 +1191,18 @@ impl<'a, 'tcx> BuilderMethods<'a, 'tcx> for Builder<'a, 'tcx> {
llfn: Self::Function,
_name: &str,
) -> Self::BasicBlock {
let cursor_fn = cx.builder.select_function_by_id(llfn.def_cx(cx));
cx.emit_with_cursor(cursor_fn).begin_block(None).unwrap()
let mut builder = cx.builder.builder_for_fn(llfn);
let id = builder.begin_block(None).unwrap();
let index_in_builder = builder.selected_block().unwrap();
SpirvBlockCursor {
parent_fn: llfn,
id,
index_in_builder,
}
}

fn append_sibling_block(&mut self, _name: &str) -> Self::BasicBlock {
self.emit_with_cursor(BuilderCursor {
function: self.cursor.function,
block: None,
})
.begin_block(None)
.unwrap()
fn append_sibling_block(&mut self, name: &str) -> Self::BasicBlock {
Self::append_block(self.cx, self.current_block.parent_fn, name)
}

fn switch_to_block(&mut self, llbb: Self::BasicBlock) {
Expand Down Expand Up @@ -1239,7 +1230,7 @@ impl<'a, 'tcx> BuilderMethods<'a, 'tcx> for Builder<'a, 'tcx> {
}

fn br(&mut self, dest: Self::BasicBlock) {
self.emit().branch(dest).unwrap();
self.emit().branch(dest.id).unwrap();
}

fn cond_br(
Expand All @@ -1257,7 +1248,7 @@ impl<'a, 'tcx> BuilderMethods<'a, 'tcx> for Builder<'a, 'tcx> {
Some(SpirvConst::Scalar(0)) => self.br(else_llbb),
_ => {
self.emit()
.branch_conditional(cond, then_llbb, else_llbb, empty())
.branch_conditional(cond, then_llbb.id, else_llbb.id, empty())
.unwrap();
}
}
Expand Down Expand Up @@ -1330,9 +1321,11 @@ impl<'a, 'tcx> BuilderMethods<'a, 'tcx> for Builder<'a, 'tcx> {
)),
};
let cases = cases
.map(|(i, b)| (construct_case(self, signed, i), b))
.map(|(i, b)| (construct_case(self, signed, i), b.id))
.collect::<Vec<_>>();
self.emit().switch(v.def(self), else_llbb, cases).unwrap();
self.emit()
.switch(v.def(self), else_llbb.id, cases)
.unwrap();
}

fn invoke(
Expand All @@ -1349,7 +1342,7 @@ impl<'a, 'tcx> BuilderMethods<'a, 'tcx> for Builder<'a, 'tcx> {
) -> Self::Value {
// Exceptions don't exist, jump directly to then block
let result = self.call(llty, fn_attrs, fn_abi, llfn, args, funclet, instance);
self.emit().branch(then).unwrap();
self.emit().branch(then.id).unwrap();
result
}

Expand Down Expand Up @@ -2741,16 +2734,16 @@ impl<'a, 'tcx> BuilderMethods<'a, 'tcx> for Builder<'a, 'tcx> {
.with_type(agg_val.ty)
}

fn set_personality_fn(&mut self, _personality: Self::Value) {
fn set_personality_fn(&mut self, _personality: Self::Function) {
todo!()
}

// These are used by everyone except msvc
fn cleanup_landing_pad(&mut self, _pers_fn: Self::Value) -> (Self::Value, Self::Value) {
fn cleanup_landing_pad(&mut self, _pers_fn: Self::Function) -> (Self::Value, Self::Value) {
todo!()
}

fn filter_landing_pad(&mut self, _pers_fn: Self::Value) -> (Self::Value, Self::Value) {
fn filter_landing_pad(&mut self, _pers_fn: Self::Function) -> (Self::Value, Self::Value) {
todo!()
}

Expand Down Expand Up @@ -2979,16 +2972,6 @@ impl<'a, 'tcx> BuilderMethods<'a, 'tcx> for Builder<'a, 'tcx> {
// be fixed upstream, so we never see any "function pointer" values being
// created just to perform direct calls.
let (callee_val, result_type, argument_types) = match self.lookup_type(callee.ty) {
// HACK(eddyb) this seems to be needed, but it's not what `get_fn_addr`
// produces, are these coming from inside `rustc_codegen_spirv`?
SpirvType::Function {
return_type,
arguments,
} => {
assert_ty_eq!(self, callee_ty, callee.ty);
(callee.def(self), return_type, arguments)
}

Comment on lines 2974 to -2991
Copy link
Collaborator Author

@eddyb eddyb Jul 20, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This special-case could be removed because it could only be reached by the custom entry-point logic passing a function as the callee, while rustc_codegen_ssa always expects a function pointer (this matches up with the entry_func -> self.get_fn_addr(entry_instance) change in entry.rs).

While direct calls should IMO be supported as first-class, that would require call to be able to take Self::Function as the callee, not Self::Value, as the only rustc_codegen_ssa interface to get a Self::Value for a function is get_fn_addr (which always creates a pointer).

(This also makes sense if you consider that Self::Value is really meant to be an IR value, ideally something vaguely-register-like, that can partake in runtime dataflow, and "a function" doesn't fit that without being referred to indirectly, i.e. as a function pointer)

SpirvType::Pointer { pointee } => match self.lookup_type(pointee) {
SpirvType::Function {
return_type,
Expand All @@ -3014,7 +2997,7 @@ impl<'a, 'tcx> BuilderMethods<'a, 'tcx> for Builder<'a, 'tcx> {
},

_ => bug!(
"call expected function or `fn` pointer type, got `{}`",
"call expected `fn` pointer type, got `{}`",
self.debug_type(callee.ty)
),
};
Expand Down Expand Up @@ -3091,18 +3074,20 @@ impl<'a, 'tcx> BuilderMethods<'a, 'tcx> for Builder<'a, 'tcx> {
struct FormatArgsNotRecognized(String);

// HACK(eddyb) this is basically a `try` block.
let try_decode_and_remove_format_args = || {
let mut try_decode_and_remove_format_args = || {
let mut decoded_format_args = DecodedFormatArgs::default();

let const_u32_as_usize = |ct_id| match self.builder.lookup_const_by_id(ct_id)? {
// HACK(eddyb) work around mutable borrowing conflicts.
let cx = self.cx;

let const_u32_as_usize = |ct_id| match cx.builder.lookup_const_by_id(ct_id)? {
SpirvConst::Scalar(x) => Some(u32::try_from(x).ok()? as usize),
_ => None,
};
let const_slice_as_elem_ids = |ptr_id: Word, len: usize| {
if let SpirvConst::PtrTo { pointee } =
self.builder.lookup_const_by_id(ptr_id)?
if let SpirvConst::PtrTo { pointee } = cx.builder.lookup_const_by_id(ptr_id)?
&& let SpirvConst::Composite(elems) =
self.builder.lookup_const_by_id(pointee)?
cx.builder.lookup_const_by_id(pointee)?
&& elems.len() == len
{
return Some(elems);
Expand Down
9 changes: 7 additions & 2 deletions crates/rustc_codegen_spirv/src/builder/ext_inst.rs
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ impl ExtInst {

impl<'a, 'tcx> Builder<'a, 'tcx> {
pub fn custom_inst(
&self,
&mut self,
result_type: Word,
inst: custom_insts::CustomInst<Operand>,
) -> SpirvValue {
Expand All @@ -58,7 +58,12 @@ impl<'a, 'tcx> Builder<'a, 'tcx> {
.with_type(result_type)
}

pub fn gl_op(&self, op: GLOp, result_type: Word, args: impl AsRef<[SpirvValue]>) -> SpirvValue {
pub fn gl_op(
&mut self,
op: GLOp,
result_type: Word,
args: impl AsRef<[SpirvValue]>,
) -> SpirvValue {
let args = args.as_ref();
let glsl = self.ext_inst.borrow_mut().import_glsl(self);
self.emit()
Expand Down
34 changes: 17 additions & 17 deletions crates/rustc_codegen_spirv/src/builder/intrinsics.rs
Original file line number Diff line number Diff line change
Expand Up @@ -381,7 +381,7 @@ impl<'a, 'tcx> IntrinsicCallBuilderMethods<'tcx> for Builder<'a, 'tcx> {
}

impl Builder<'_, '_> {
pub fn count_ones(&self, arg: SpirvValue) -> SpirvValue {
pub fn count_ones(&mut self, arg: SpirvValue) -> SpirvValue {
let ty = arg.ty;
match self.cx.lookup_type(ty) {
SpirvType::Integer(bits, false) => {
Expand Down Expand Up @@ -426,7 +426,7 @@ impl Builder<'_, '_> {
}
}

pub fn bit_reverse(&self, arg: SpirvValue) -> SpirvValue {
pub fn bit_reverse(&mut self, arg: SpirvValue) -> SpirvValue {
let ty = arg.ty;
match self.cx.lookup_type(ty) {
SpirvType::Integer(bits, false) => {
Expand Down Expand Up @@ -489,7 +489,7 @@ impl Builder<'_, '_> {
}

pub fn count_leading_trailing_zeros(
&self,
&mut self,
arg: SpirvValue,
trailing: bool,
non_zero: bool,
Expand All @@ -501,9 +501,9 @@ impl Builder<'_, '_> {
let u32 = SpirvType::Integer(32, false).def(self.span(), self);

let glsl = self.ext_inst.borrow_mut().import_glsl(self);
let find_xsb = |arg, offset: i32| {
let find_xsb = |this: &mut Self, arg, offset: i32| {
if trailing {
let lsb = self
let lsb = this
.emit()
.ext_inst(
u32,
Expand All @@ -516,12 +516,12 @@ impl Builder<'_, '_> {
if offset == 0 {
lsb
} else {
let const_offset = self.constant_i32(self.span(), offset).def(self);
self.emit().i_add(u32, None, const_offset, lsb).unwrap()
let const_offset = this.constant_i32(this.span(), offset).def(this);
this.emit().i_add(u32, None, const_offset, lsb).unwrap()
}
} else {
// rust is always unsigned, so FindUMsb
let msb_bit = self
let msb_bit = this
.emit()
.ext_inst(
u32,
Expand All @@ -533,21 +533,21 @@ impl Builder<'_, '_> {
.unwrap();
// the glsl op returns the Msb bit, not the amount of leading zeros of this u32
// leading zeros = 31 - Msb bit
let const_offset = self.constant_i32(self.span(), 31 - offset).def(self);
self.emit().i_sub(u32, None, const_offset, msb_bit).unwrap()
let const_offset = this.constant_i32(this.span(), 31 - offset).def(this);
this.emit().i_sub(u32, None, const_offset, msb_bit).unwrap()
}
};

let converted = match bits {
8 | 16 => {
let arg = self.emit().u_convert(u32, None, arg.def(self)).unwrap();
if trailing {
find_xsb(arg, 0)
find_xsb(self, arg, 0)
} else {
find_xsb(arg, bits as i32 - 32)
find_xsb(self, arg, bits as i32 - 32)
}
}
32 => find_xsb(arg.def(self), 0),
32 => find_xsb(self, arg.def(self), 0),
64 => {
let u32_0 = self.constant_int(u32, 0).def(self);
let u32_32 = self.constant_u32(self.span(), 32).def(self);
Expand All @@ -562,16 +562,16 @@ impl Builder<'_, '_> {

if trailing {
let use_lower = self.emit().i_equal(bool, None, lower, u32_0).unwrap();
let lower_bits = find_xsb(lower, 32);
let higher_bits = find_xsb(higher, 0);
let lower_bits = find_xsb(self, lower, 32);
let higher_bits = find_xsb(self, higher, 0);
self.emit()
.select(u32, None, use_lower, higher_bits, lower_bits)
.unwrap()
} else {
let use_higher =
self.emit().i_equal(bool, None, higher, u32_0).unwrap();
let lower_bits = find_xsb(lower, 0);
let higher_bits = find_xsb(higher, 32);
let lower_bits = find_xsb(self, lower, 0);
let higher_bits = find_xsb(self, higher, 32);
self.emit()
.select(u32, None, use_higher, lower_bits, higher_bits)
.unwrap()
Expand Down
Loading
Loading