Skip to content

Commit 75bc82b

Browse files
committed
Fixed situations with local_invocation_id and/or local_invocation_index being written incorrectly
1 parent 12b9d49 commit 75bc82b

19 files changed

+235
-35
lines changed

CHANGELOG.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -175,6 +175,7 @@ depth_stencil: Some(wgpu::DepthStencilState::stencil(
175175
- Renamed `EXPERIMENTAL_PASSTHROUGH_SHADERS` to `PASSTHROUGH_SHADERS` and made this no longer an experimental feature. By @inner-daemons in [#9054](https://github.com/gfx-rs/wgpu/pull/9054).
176176
- BREAKING: End offsets in trace and `player` commands are now represented using `offset` + `size` instead. By @ErichDonGubler in [#9073](https://github.com/gfx-rs/wgpu/pull/9073).
177177
- Validate some uncaught cases where buffer transfer operations could overflow when computing an end offset. By @ErichDonGubler in [#9073](https://github.com/gfx-rs/wgpu/pull/9073).
178+
- Fix various issues relating to `local_invocation_index` and `local_invocation_id` in HLSL and MSL. By @inner-daemons in [#9099](https://github.com/gfx-rs/wgpu/pull/9099).
178179
- Added internal labels to validation GPU objects and timestamp normalization code to improve clarity in graphics debuggers. By @szostid in [#9094](https://github.com/gfx-rs/wgpu/pull/9094).
179180
- Fix multi-planar texture copying. By @noituri in [#9069](https://github.com/gfx-rs/wgpu/pull/9069).
180181

naga/src/back/hlsl/writer.rs

Lines changed: 43 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -74,6 +74,7 @@ struct EntryPointBinding {
7474
ty_name: String,
7575
/// Members of generated structure
7676
members: Vec<EpStructMember>,
77+
local_invocation_index_name: Option<String>,
7778
}
7879

7980
pub(super) struct EntryPointInterface {
@@ -619,11 +620,23 @@ impl<'a, W: fmt::Write> super::Writer<'a, W> {
619620

620621
write!(self.out, "struct {struct_name}")?;
621622
writeln!(self.out, " {{")?;
623+
let mut local_invocation_index_name = None;
624+
let mut subgroup_id_used = false;
622625
for m in members.iter() {
623626
// Sanity check that each IO member is a built-in or is assigned a
624627
// location. Also see note about nesting in `write_ep_input_struct`.
625628
debug_assert!(m.binding.is_some());
626629

630+
match m.binding {
631+
Some(crate::Binding::BuiltIn(crate::BuiltIn::SubgroupId)) => {
632+
subgroup_id_used = true;
633+
}
634+
Some(crate::Binding::BuiltIn(crate::BuiltIn::LocalInvocationIndex)) => {
635+
local_invocation_index_name = Some(m.name.clone());
636+
}
637+
_ => (),
638+
}
639+
627640
if is_subgroup_builtin_binding(&m.binding) {
628641
continue;
629642
}
@@ -636,17 +649,10 @@ impl<'a, W: fmt::Write> super::Writer<'a, W> {
636649
self.write_semantic(&m.binding, Some(shader_stage))?;
637650
writeln!(self.out, ";")?;
638651
}
639-
if members.iter().any(|arg| {
640-
matches!(
641-
arg.binding,
642-
Some(crate::Binding::BuiltIn(crate::BuiltIn::SubgroupId))
643-
)
644-
}) {
645-
writeln!(
646-
self.out,
647-
"{}uint __local_invocation_index : SV_GroupIndex;",
648-
back::INDENT
649-
)?;
652+
if subgroup_id_used && local_invocation_index_name.is_none() {
653+
let name = self.namer.call("local_invocation_index");
654+
writeln!(self.out, "{}uint {name} : SV_GroupIndex;", back::INDENT)?;
655+
local_invocation_index_name = Some(name);
650656
}
651657
writeln!(self.out, "}};")?;
652658
writeln!(self.out)?;
@@ -666,6 +672,7 @@ impl<'a, W: fmt::Write> super::Writer<'a, W> {
666672
arg_name: self.namer.call(struct_name.to_lowercase().as_str()),
667673
ty_name: struct_name,
668674
members,
675+
local_invocation_index_name,
669676
})
670677
}
671678

@@ -845,8 +852,10 @@ impl<'a, W: fmt::Write> super::Writer<'a, W> {
845852
Some(crate::Binding::BuiltIn(crate::BuiltIn::SubgroupId)) => {
846853
write!(
847854
self.out,
848-
"{}.__local_invocation_index / WaveGetLaneCount()",
849-
ep_input.arg_name
855+
"{}.{} / WaveGetLaneCount()",
856+
ep_input.arg_name,
857+
// When writing SubgroupId, we always guarantee that local_invocation_index_name is written
858+
ep_input.local_invocation_index_name.as_ref().unwrap()
850859
)?;
851860
}
852861
_ => {
@@ -1587,6 +1596,8 @@ impl<'a, W: fmt::Write> super::Writer<'a, W> {
15871596
let need_workgroup_variables_initialization =
15881597
self.need_workgroup_variables_initialization(func_ctx, module);
15891598

1599+
let needs_local_invocation_id_name = need_workgroup_variables_initialization;
1600+
let mut local_invocation_id_name = None;
15901601
// Write function arguments for non entry point functions
15911602
match func_ctx.ty {
15921603
back::FunctionType::Function(handle) => {
@@ -1614,6 +1625,12 @@ impl<'a, W: fmt::Write> super::Writer<'a, W> {
16141625
let argument_name =
16151626
&self.names[&NameKey::EntryPointArgument(ep_index, index as u32)];
16161627

1628+
if arg.binding
1629+
== Some(crate::Binding::BuiltIn(crate::BuiltIn::LocalInvocationId))
1630+
{
1631+
local_invocation_id_name = Some(argument_name.clone());
1632+
}
1633+
16171634
write!(self.out, " {argument_name}")?;
16181635
if let TypeInner::Array { base, size, .. } = module.types[arg.ty].inner {
16191636
self.write_array_size(module, base, size)?;
@@ -1622,7 +1639,7 @@ impl<'a, W: fmt::Write> super::Writer<'a, W> {
16221639
self.write_semantic(&arg.binding, Some((stage, Io::Input)))?;
16231640
}
16241641
}
1625-
if need_workgroup_variables_initialization {
1642+
if needs_local_invocation_id_name && local_invocation_id_name.is_none() {
16261643
if self
16271644
.entry_point_io
16281645
.get(&(ep_index as usize))
@@ -1633,7 +1650,9 @@ impl<'a, W: fmt::Write> super::Writer<'a, W> {
16331650
{
16341651
write!(self.out, ", ")?;
16351652
}
1636-
write!(self.out, "uint3 __local_invocation_id : SV_GroupThreadID")?;
1653+
let var_name = self.namer.call("local_invocation_id");
1654+
write!(self.out, "uint3 {var_name} : SV_GroupThreadID")?;
1655+
local_invocation_id_name = Some(var_name);
16371656
}
16381657
}
16391658
}
@@ -1653,7 +1672,13 @@ impl<'a, W: fmt::Write> super::Writer<'a, W> {
16531672
writeln!(self.out, "{{")?;
16541673

16551674
if need_workgroup_variables_initialization {
1656-
self.write_workgroup_variables_initialization(func_ctx, module)?;
1675+
self.write_workgroup_variables_initialization(
1676+
func_ctx,
1677+
module,
1678+
// need_workgroup_variables_initialization forces this to be written
1679+
// if the user doesn't specify it (so this must be Some())
1680+
local_invocation_id_name.unwrap(),
1681+
)?;
16571682
}
16581683

16591684
if let back::FunctionType::EntryPoint(index) = func_ctx.ty {
@@ -1804,12 +1829,13 @@ impl<'a, W: fmt::Write> super::Writer<'a, W> {
18041829
&mut self,
18051830
func_ctx: &back::FunctionCtx,
18061831
module: &Module,
1832+
local_invocation_id_name: String,
18071833
) -> BackendResult {
18081834
let level = back::Level(1);
18091835

18101836
writeln!(
18111837
self.out,
1812-
"{level}if (all(__local_invocation_id == uint3(0u, 0u, 0u))) {{"
1838+
"{level}if (all({local_invocation_id_name} == uint3(0u, 0u, 0u))) {{"
18131839
)?;
18141840

18151841
let vars = module.global_variables.iter().filter(|&(handle, var)| {

naga/src/back/msl/keywords.rs

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -343,6 +343,8 @@ const RESERVED: &[&str] = &[
343343
"M_SQRT1_2",
344344
// Naga utilities
345345
"DefaultConstructible",
346+
// Naga builtin names
347+
"__local_invocation_id",
346348
super::writer::FREXP_FUNCTION,
347349
super::writer::MODF_FUNCTION,
348350
super::writer::ABS_FUNCTION,
Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,15 @@
1+
var<workgroup> wg_var: u32;
2+
3+
struct Input {
4+
@builtin(local_invocation_id)
5+
local_invocation_id: vec3<u32>,
6+
@builtin(local_invocation_index)
7+
local_invocation_index: u32,
8+
}
9+
10+
@compute
11+
@workgroup_size(1)
12+
fn compute1(input: Input) {
13+
wg_var = input.local_invocation_index * 2;
14+
wg_var += input.local_invocation_id.x;
15+
}
Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,27 @@
1+
#version 310 es
2+
3+
precision highp float;
4+
precision highp int;
5+
6+
layout(local_size_x = 1, local_size_y = 1, local_size_z = 1) in;
7+
8+
struct Input {
9+
uvec3 local_invocation_id;
10+
uint local_invocation_index;
11+
};
12+
shared uint wg_var;
13+
14+
15+
void main() {
16+
if (gl_LocalInvocationID == uvec3(0u)) {
17+
wg_var = 0u;
18+
}
19+
memoryBarrierShared();
20+
barrier();
21+
Input input_ = Input(gl_LocalInvocationID, gl_LocalInvocationIndex);
22+
wg_var = (input_.local_invocation_index * 2u);
23+
uint _e8 = wg_var;
24+
wg_var = (_e8 + input_.local_invocation_id.x);
25+
return;
26+
}
27+

naga/tests/out/hlsl/spv-subgroup-operations-s.hlsl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@ static uint global_2 = (uint)0;
44
static uint global_3 = (uint)0;
55

66
struct ComputeInput_main {
7-
uint __local_invocation_index : SV_GroupIndex;
7+
uint local_invocation_index : SV_GroupIndex;
88
};
99

1010
void function()
@@ -39,7 +39,7 @@ void function()
3939
void main(ComputeInput_main computeinput_main)
4040
{
4141
uint param = (1u + WaveGetLaneCount() - 1u) / WaveGetLaneCount();
42-
uint param_1 = computeinput_main.__local_invocation_index / WaveGetLaneCount();
42+
uint param_1 = computeinput_main.local_invocation_index / WaveGetLaneCount();
4343
uint param_2 = WaveGetLaneCount();
4444
uint param_3 = WaveGetLaneIndex();
4545
global = param;
Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,19 @@
1+
struct Input {
2+
uint3 local_invocation_id : SV_GroupThreadID;
3+
uint local_invocation_index : SV_GroupIndex;
4+
};
5+
6+
groupshared uint wg_var;
7+
8+
[numthreads(1, 1, 1)]
9+
void compute1_(Input input, uint3 local_invocation_id : SV_GroupThreadID)
10+
{
11+
if (all(local_invocation_id == uint3(0u, 0u, 0u))) {
12+
wg_var = (uint)0;
13+
}
14+
GroupMemoryBarrierWithGroupSync();
15+
wg_var = (input.local_invocation_index * 2u);
16+
uint _e8 = wg_var;
17+
wg_var = (_e8 + input.local_invocation_id.x);
18+
return;
19+
}
Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,12 @@
1+
(
2+
vertex:[
3+
],
4+
fragment:[
5+
],
6+
compute:[
7+
(
8+
entry_point:"compute1_",
9+
target_profile:"cs_5_1",
10+
),
11+
],
12+
)

naga/tests/out/hlsl/wgsl-atomicOps-int64.hlsl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -30,9 +30,9 @@ groupshared int64_t workgroup_atomic_arr[2];
3030
groupshared Struct workgroup_struct;
3131

3232
[numthreads(2, 1, 1)]
33-
void cs_main(uint3 id : SV_GroupThreadID, uint3 __local_invocation_id : SV_GroupThreadID)
33+
void cs_main(uint3 id : SV_GroupThreadID)
3434
{
35-
if (all(__local_invocation_id == uint3(0u, 0u, 0u))) {
35+
if (all(id == uint3(0u, 0u, 0u))) {
3636
workgroup_atomic_scalar = (uint64_t)0;
3737
workgroup_atomic_arr = (int64_t[2])0;
3838
workgroup_struct = (Struct)0;

naga/tests/out/hlsl/wgsl-atomicOps.hlsl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -21,9 +21,9 @@ groupshared int workgroup_atomic_arr[2];
2121
groupshared Struct workgroup_struct;
2222

2323
[numthreads(2, 1, 1)]
24-
void cs_main(uint3 id : SV_GroupThreadID, uint3 __local_invocation_id : SV_GroupThreadID)
24+
void cs_main(uint3 id : SV_GroupThreadID)
2525
{
26-
if (all(__local_invocation_id == uint3(0u, 0u, 0u))) {
26+
if (all(id == uint3(0u, 0u, 0u))) {
2727
workgroup_atomic_scalar = (uint)0;
2828
workgroup_atomic_arr = (int[2])0;
2929
workgroup_struct = (Struct)0;

0 commit comments

Comments
 (0)