Skip to content

Commit a53ffff

Browse files
Mesh shader HLSL writer (#8752)
1 parent 1148bac commit a53ffff

File tree

121 files changed

+1686
-361
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

121 files changed

+1686
-361
lines changed

CHANGELOG.md

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -52,6 +52,10 @@ Bottom level categories:
5252

5353
- Unconditionally enable `Features::CLIP_DISTANCES`. By @ErichDonGubler in [#9270](https://github.com/gfx-rs/wgpu/pull/9270).
5454

55+
#### DX12
56+
57+
- Added support for mesh shaders in naga's HLSL writer, completing DX12 support for mesh shaders. By @inner-daemons in #8752.
58+
5559
### Changes
5660

5761
#### General

docs/api-specs/mesh_shading.md

Lines changed: 13 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -90,9 +90,16 @@ An example of using mesh shaders to render a single triangle can be seen [here](
9090

9191
> **NOTE**: More limits will be added when support is added to `naga`.
9292
93-
* `Limits::max_task_workgroup_total_count` - the maximum total number of workgroups from a `draw_mesh_tasks` command or similar. The dimensions passed must be less than or equal to this limit when multiplied together.
94-
* `Limits::max_task_workgroups_per_dimension` - the maximum for each of the 3 workgroup dimensions in a `draw_mesh_tasks` command. Each dimension passed must be less than or equal to this limit.
95-
* `max_mesh_multiview_count` - The maximum number of views used when multiview rendering with a mesh shader pipeline.
93+
* `Limits::max_task_mesh_workgroup_total_count` - the maximum total number of workgroups from a `draw_mesh_tasks` command or similar. The dimensions passed must be less than or equal to this limit when multiplied together.
94+
* `Limits::max_task_mesh_workgroups_per_dimension` - the maximum for each of the 3 workgroup dimensions in a `draw_mesh_tasks` command. Each dimension passed must be less than or equal to this limit.
95+
* `max_task_invocations_per_workgroup` - The maximum total number of threads in a task shader workgroup, given by `workgroupSize.x * workgroupSize.y * workgroupSize.z`.
96+
* `max_task_invocations_per_dimension` the maximum value for each of `workgroupSize.x`, `workgroupSize.y` and `workgroupSize.z` in task shader workgroups.
97+
* `max_mesh_invocations_per_workgroup` - The maximum total number of threads in a mesh shader workgroup, given by `workgroupSize.x * workgroupSize.y * workgroupSize.z`.
98+
* `max_mesh_invocations_per_dimension` the maximum value for each of `workgroupSize.x`, `workgroupSize.y` and `workgroupSize.z` in mesh shader workgroups.
99+
* `max_task_payload_size` - the size of an `var<task_payload>` variable, in bytes.
100+
* `max_mesh_output_vertices` - the maximum number of vertices that a single mesh shader workgroup may output.
101+
* `max_mesh_output_primitives` - the maximum number of primitives that a single mesh shader workgroup may output.
102+
* `max_mesh_multiview_count` - the maximum number of views used when multiview rendering with a mesh shader pipeline.
96103
* `max_mesh_output_layers` - the maximum number of output layers for a mesh shader pipeline.
97104

98105
## Naga implementation
@@ -132,8 +139,10 @@ A task shader entry point must also have a `@payload(G)` property, where `G` is
132139

133140
A task shader entry point must return a `vec3<u32>` value decorated with `@builtin(mesh_task_size)`. The return value of each workgroup's first invocation (that is, the one whose `local_invocation_index` is `0`) is taken as the size of a **mesh shader grid** to dispatch, measured in workgroups. (If the task shader entry point returns `vec3(0, 0, 0)`, then no mesh shaders are dispatched.) Mesh shader grids are described in the next section.
134141

142+
The output of a task shader is set to zero if it violates either of the limits `max_task_mesh_workgroup_total_count` or `max_task_mesh_workgroups_per_dimension`.
143+
135144
Each task shader workgroup dispatches an independent mesh shader grid: in mesh shader invocations, `@builtin` values like `workgroup_id` and `global_invocation_id` describe the position of the workgroup and invocation within that grid;
136-
and `@builtin(num_workgroups)` matches the task shader workgroup's return value. Mesh shaders dispatched for other task shader workgroups are not included in the count. If it is necessary for a mesh shader to know which task shader workgroup dispatched it, the task shader can include its own workgroup id in the task payload.
145+
and `@builtin(num_workgroups)` matches the task shader workgroup's return value. If this output violates any limits, it may be zeroed or cause undefined behavior, depending on the compilation options. Mesh shaders dispatched for other task shader workgroups are not included in the count. If it is necessary for a mesh shader to know which task shader workgroup dispatched it, the task shader can include its own workgroup id in the task payload.
137146

138147
Task shaders can use compute and subgroup builtin inputs, in addition to `view_index` and `draw_id`.
139148

examples/features/src/mesh_shader/mod.rs

Lines changed: 1 addition & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -11,37 +11,6 @@ fn compile_wgsl(device: &wgpu::Device) -> wgpu::ShaderModule {
1111
)
1212
}
1313
}
14-
fn compile_hlsl(device: &wgpu::Device, entry: &str, stage_str: &str) -> wgpu::ShaderModule {
15-
let out_path = format!(
16-
"{}/src/mesh_shader/shader.{stage_str}.cso",
17-
env!("CARGO_MANIFEST_DIR")
18-
);
19-
let cmd = std::process::Command::new("dxc")
20-
.args([
21-
"-T",
22-
&format!("{stage_str}_6_5"),
23-
"-E",
24-
entry,
25-
&format!("{}/src/mesh_shader/shader.hlsl", env!("CARGO_MANIFEST_DIR")),
26-
"-Fo",
27-
&out_path,
28-
])
29-
.output()
30-
.unwrap();
31-
if !cmd.status.success() {
32-
panic!("DXC failed:\n{}", String::from_utf8(cmd.stderr).unwrap());
33-
}
34-
let file = std::fs::read(&out_path).unwrap();
35-
std::fs::remove_file(out_path).unwrap();
36-
unsafe {
37-
device.create_shader_module_passthrough(wgpu::ShaderModuleDescriptorPassthrough {
38-
label: None,
39-
num_workgroups: (1, 1, 1),
40-
dxil: Some(std::borrow::Cow::Owned(file)),
41-
..Default::default()
42-
})
43-
}
44-
}
4514

4615
fn compile_msl(device: &wgpu::Device) -> wgpu::ShaderModule {
4716
unsafe {
@@ -67,7 +36,7 @@ fn get_shaders(device: &wgpu::Device, backend: wgpu::Backend) -> Shaders {
6736
// In the case that the platform does support mesh shaders, the dummy
6837
// shader is used to avoid requiring PASSTHROUGH_SHADERS.
6938
match backend {
70-
wgpu::Backend::Vulkan => {
39+
wgpu::Backend::Vulkan | wgpu::Backend::Dx12 => {
7140
let compiled = compile_wgsl(device);
7241
Shaders {
7342
ts: compiled.clone(),
@@ -78,14 +47,6 @@ fn get_shaders(device: &wgpu::Device, backend: wgpu::Backend) -> Shaders {
7847
fs_name: "fs_main",
7948
}
8049
}
81-
wgpu::Backend::Dx12 => Shaders {
82-
ts: compile_hlsl(device, "Task", "as"),
83-
ms: compile_hlsl(device, "Mesh", "ms"),
84-
fs: compile_hlsl(device, "Frag", "ps"),
85-
ts_name: "main",
86-
ms_name: "main",
87-
fs_name: "main",
88-
},
8950
wgpu::Backend::Metal => {
9051
let compiled = compile_msl(device);
9152
Shaders {

examples/features/src/mesh_shader/shader.hlsl

Lines changed: 0 additions & 53 deletions
This file was deleted.

naga-cli/src/bin/naga.rs

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -575,6 +575,8 @@ fn run() -> anyhow::Result<()> {
575575

576576
params.spv_out.mesh_shader_primitive_indices_clamp = args.validate_mesh_output;
577577
params.spv_out.task_dispatch_limits = args.task_limits.0;
578+
params.hlsl.mesh_shader_primitive_indices_clamp = args.validate_mesh_output;
579+
params.hlsl.task_dispatch_limits = args.task_limits.0;
578580

579581
if args.bulk_validate {
580582
return bulk_validate(&args, &params);

naga/hlsl-snapshots/src/lib.rs

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,8 @@ pub struct Config {
5050
pub vertex: Vec<ConfigItem>,
5151
pub fragment: Vec<ConfigItem>,
5252
pub compute: Vec<ConfigItem>,
53+
pub task: Vec<ConfigItem>,
54+
pub mesh: Vec<ConfigItem>,
5355
}
5456

5557
impl Config {
@@ -59,6 +61,8 @@ impl Config {
5961
vertex: Default::default(),
6062
fragment: Default::default(),
6163
compute: Default::default(),
64+
task: Default::default(),
65+
mesh: Default::default(),
6266
}
6367
}
6468

@@ -85,8 +89,14 @@ impl Config {
8589
vertex,
8690
fragment,
8791
compute,
92+
task,
93+
mesh,
8894
} = self;
89-
vertex.is_empty() && fragment.is_empty() && compute.is_empty()
95+
vertex.is_empty()
96+
&& fragment.is_empty()
97+
&& compute.is_empty()
98+
&& task.is_empty()
99+
&& mesh.is_empty()
90100
}
91101
}
92102

naga/src/back/hlsl/conv.rs

Lines changed: 9 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -149,8 +149,10 @@ impl crate::StorageFormat {
149149
}
150150

151151
impl crate::BuiltIn {
152-
pub(super) fn to_hlsl_str(self) -> Result<&'static str, Error> {
153-
Ok(match self {
152+
/// Returns `None` for "virtual" builtins, i.e. mesh shader builtins that are
153+
/// used by naga but not recognized by HLSL.
154+
pub(super) fn to_hlsl_str(self) -> Result<Option<&'static str>, Error> {
155+
Ok(Some(match self {
154156
Self::Position { .. } => "SV_Position",
155157
// vertex
156158
Self::ClipDistances => "SV_ClipDistance",
@@ -186,12 +188,14 @@ impl crate::BuiltIn {
186188
return Err(Error::Custom(format!("Unsupported builtin {self:?}")))
187189
}
188190
Self::CullPrimitive => "SV_CullPrimitive",
189-
Self::PointIndex | Self::LineIndices | Self::TriangleIndices => unimplemented!(),
190191
Self::MeshTaskSize
191192
| Self::VertexCount
192193
| Self::PrimitiveCount
193194
| Self::Vertices
194-
| Self::Primitives => unreachable!(),
195+
| Self::Primitives
196+
| Self::PointIndex
197+
| Self::LineIndices
198+
| Self::TriangleIndices => return Ok(None),
195199
Self::RayInvocationId
196200
| Self::NumRayInvocations
197201
| Self::InstanceCustomData
@@ -205,7 +209,7 @@ impl crate::BuiltIn {
205209
| Self::ObjectToWorld
206210
| Self::WorldToObject
207211
| Self::HitKind => unreachable!(),
208-
})
212+
}))
209213
}
210214
}
211215

0 commit comments

Comments
 (0)