Skip to content

Commit 9a2b5fd

Browse files
committed
abstract away usage of scratch space for buffer barriers and unique index iterators
1 parent 3705928 commit 9a2b5fd

File tree

3 files changed

+165
-103
lines changed

3 files changed

+165
-103
lines changed

wgpu-core/src/indirect_validation/draw.rs

Lines changed: 80 additions & 103 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,7 @@
1-
use super::CreateIndirectValidationPipelineError;
1+
use super::{
2+
utils::{BufferBarrierScratch, BufferBarriers, UniqueIndexExt as _, UniqueIndexScratch},
3+
CreateIndirectValidationPipelineError,
4+
};
25
use crate::{
36
device::{queue::TempResource, Device, DeviceError},
47
lock::{rank, Mutex},
@@ -256,47 +259,37 @@ impl Draw {
256259
batch.metadata_buffer_offset = metadata_buffer_offset;
257260
}
258261

259-
let mut buffer_barriers = Vec::new();
260-
let mut buffer_index_set = bit_set::BitSet::new();
261-
262-
for index in batches
263-
.values()
264-
.map(|batch| batch.staging_buffer_index)
265-
.filter(|index| buffer_index_set.insert(*index))
266-
{
267-
let staging_buffer = &staging_buffers[index];
268-
269-
buffer_barriers.push(hal::BufferBarrier {
270-
buffer: staging_buffer.raw(),
271-
usage: hal::StateTransition {
272-
from: wgt::BufferUses::MAP_WRITE,
273-
to: wgt::BufferUses::COPY_SRC,
274-
},
275-
});
276-
}
277-
buffer_index_set.clear();
278-
279-
for index in batches
280-
.values()
281-
.map(|batch| batch.metadata_resource_index)
282-
.filter(|index| buffer_index_set.insert(*index))
283-
{
284-
let metadata_buffer = resources.get_metadata_buffer(index);
285-
286-
buffer_barriers.push(hal::BufferBarrier {
287-
buffer: metadata_buffer,
288-
usage: hal::StateTransition {
289-
from: wgt::BufferUses::STORAGE_READ_ONLY,
290-
to: wgt::BufferUses::COPY_DST,
291-
},
292-
});
293-
}
294-
buffer_index_set.clear();
295-
296-
unsafe {
297-
encoder.transition_buffers(&buffer_barriers);
298-
}
299-
buffer_barriers.clear();
262+
let buffer_barrier_scratch = &mut BufferBarrierScratch::new();
263+
let unique_index_scratch = &mut UniqueIndexScratch::new();
264+
265+
BufferBarriers::new(buffer_barrier_scratch)
266+
.extend(
267+
batches
268+
.values()
269+
.map(|batch| batch.staging_buffer_index)
270+
.unique(unique_index_scratch)
271+
.map(|index| hal::BufferBarrier {
272+
buffer: staging_buffers[index].raw(),
273+
usage: hal::StateTransition {
274+
from: wgt::BufferUses::MAP_WRITE,
275+
to: wgt::BufferUses::COPY_SRC,
276+
},
277+
}),
278+
)
279+
.extend(
280+
batches
281+
.values()
282+
.map(|batch| batch.metadata_resource_index)
283+
.unique(unique_index_scratch)
284+
.map(|index| hal::BufferBarrier {
285+
buffer: resources.get_metadata_buffer(index),
286+
usage: hal::StateTransition {
287+
from: wgt::BufferUses::STORAGE_READ_ONLY,
288+
to: wgt::BufferUses::COPY_DST,
289+
},
290+
}),
291+
)
292+
.encode(encoder);
300293

301294
for batch in batches.values() {
302295
let data = batch.metadata();
@@ -319,44 +312,38 @@ impl Draw {
319312
}
320313
}
321314

322-
for index in batches
323-
.values()
324-
.map(|batch| batch.metadata_resource_index)
325-
.filter(|index| buffer_index_set.insert(*index))
326-
{
327-
let metadata_buffer = resources.get_metadata_buffer(index);
328-
329-
buffer_barriers.push(hal::BufferBarrier {
330-
buffer: metadata_buffer,
331-
usage: hal::StateTransition {
332-
from: wgt::BufferUses::COPY_DST,
333-
to: wgt::BufferUses::STORAGE_READ_ONLY,
334-
},
335-
});
336-
}
337-
buffer_index_set.clear();
338-
339-
for index in batches
340-
.values()
341-
.map(|batch| batch.dst_resource_index)
342-
.filter(|index| buffer_index_set.insert(*index))
343-
{
344-
let dst_buffer = resources.get_dst_buffer(index);
345-
346-
buffer_barriers.push(hal::BufferBarrier {
347-
buffer: dst_buffer,
348-
usage: hal::StateTransition {
349-
from: wgt::BufferUses::INDIRECT,
350-
to: wgt::BufferUses::STORAGE_READ_WRITE,
351-
},
352-
});
315+
for staging_buffer in staging_buffers {
316+
temp_resources.push(TempResource::StagingBuffer(staging_buffer));
353317
}
354-
buffer_index_set.clear();
355318

356-
unsafe {
357-
encoder.transition_buffers(&buffer_barriers);
358-
}
359-
buffer_barriers.clear();
319+
BufferBarriers::new(buffer_barrier_scratch)
320+
.extend(
321+
batches
322+
.values()
323+
.map(|batch| batch.metadata_resource_index)
324+
.unique(unique_index_scratch)
325+
.map(|index| hal::BufferBarrier {
326+
buffer: resources.get_metadata_buffer(index),
327+
usage: hal::StateTransition {
328+
from: wgt::BufferUses::COPY_DST,
329+
to: wgt::BufferUses::STORAGE_READ_ONLY,
330+
},
331+
}),
332+
)
333+
.extend(
334+
batches
335+
.values()
336+
.map(|batch| batch.dst_resource_index)
337+
.unique(unique_index_scratch)
338+
.map(|index| hal::BufferBarrier {
339+
buffer: resources.get_dst_buffer(index),
340+
usage: hal::StateTransition {
341+
from: wgt::BufferUses::INDIRECT,
342+
to: wgt::BufferUses::STORAGE_READ_WRITE,
343+
},
344+
}),
345+
)
346+
.encode(encoder);
360347

361348
let desc = hal::ComputePassDescriptor {
362349
label: None,
@@ -420,31 +407,21 @@ impl Draw {
420407
encoder.end_compute_pass();
421408
}
422409

423-
for index in batches
424-
.values()
425-
.map(|batch| batch.dst_resource_index)
426-
.filter(|index| buffer_index_set.insert(*index))
427-
{
428-
let dst_buffer = resources.get_dst_buffer(index);
429-
430-
buffer_barriers.push(hal::BufferBarrier {
431-
buffer: dst_buffer,
432-
usage: hal::StateTransition {
433-
from: wgt::BufferUses::STORAGE_READ_WRITE,
434-
to: wgt::BufferUses::INDIRECT,
435-
},
436-
});
437-
}
438-
buffer_index_set.clear();
439-
440-
unsafe {
441-
encoder.transition_buffers(&buffer_barriers);
442-
}
443-
buffer_barriers.clear();
444-
445-
for staging_buffer in staging_buffers {
446-
temp_resources.push(TempResource::StagingBuffer(staging_buffer));
447-
}
410+
BufferBarriers::new(buffer_barrier_scratch)
411+
.extend(
412+
batches
413+
.values()
414+
.map(|batch| batch.dst_resource_index)
415+
.unique(unique_index_scratch)
416+
.map(|index| hal::BufferBarrier {
417+
buffer: resources.get_dst_buffer(index),
418+
usage: hal::StateTransition {
419+
from: wgt::BufferUses::STORAGE_READ_WRITE,
420+
to: wgt::BufferUses::INDIRECT,
421+
},
422+
}),
423+
)
424+
.encode(encoder);
448425

449426
Ok(())
450427
}

wgpu-core/src/indirect_validation/mod.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@ use thiserror::Error;
77

88
mod dispatch;
99
mod draw;
10+
mod utils;
1011

1112
pub(crate) use dispatch::Dispatch;
1213
pub(crate) use draw::{Draw, DrawBatcher, DrawResources};
Lines changed: 84 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,84 @@
1+
use alloc::vec::Vec;
2+
3+
pub(crate) struct UniqueIndexScratch(bit_set::BitSet);
4+
5+
impl UniqueIndexScratch {
6+
pub(crate) fn new() -> Self {
7+
Self(bit_set::BitSet::new())
8+
}
9+
}
10+
11+
pub(crate) struct UniqueIndex<'a, I: Iterator<Item = usize>> {
12+
inner: I,
13+
scratch: &'a mut UniqueIndexScratch,
14+
}
15+
16+
impl<'a, I: Iterator<Item = usize>> UniqueIndex<'a, I> {
17+
fn new(inner: I, scratch: &'a mut UniqueIndexScratch) -> Self {
18+
scratch.0.clear();
19+
Self { inner, scratch }
20+
}
21+
}
22+
23+
impl<'a, I: Iterator<Item = usize>> Iterator for UniqueIndex<'a, I> {
24+
type Item = usize;
25+
26+
fn next(&mut self) -> Option<Self::Item> {
27+
self.inner.find(|&i| self.scratch.0.insert(i))
28+
}
29+
}
30+
31+
pub(crate) trait UniqueIndexExt: Iterator<Item = usize> {
32+
fn unique<'a>(self, scratch: &'a mut UniqueIndexScratch) -> UniqueIndex<'a, Self>
33+
where
34+
Self: Sized,
35+
{
36+
UniqueIndex::new(self, scratch)
37+
}
38+
}
39+
40+
impl<T: Iterator<Item = usize>> UniqueIndexExt for T {}
41+
42+
type BufferBarrier<'b> = hal::BufferBarrier<'b, dyn hal::DynBuffer>;
43+
44+
pub(crate) struct BufferBarrierScratch<'b>(Vec<BufferBarrier<'b>>);
45+
46+
impl<'b> BufferBarrierScratch<'b> {
47+
pub(crate) fn new() -> Self {
48+
Self(Vec::new())
49+
}
50+
}
51+
52+
pub(crate) struct BufferBarriers<'a, 'b> {
53+
scratch: &'a mut BufferBarrierScratch<'b>,
54+
}
55+
56+
impl<'a, 'b> BufferBarriers<'a, 'b> {
57+
pub(crate) fn new(scratch: &'a mut BufferBarrierScratch<'_>) -> Self {
58+
// change lifetime of buffer reference, this is safe since `scratch` is empty,
59+
// it was either just created or it has been cleared on `BufferBarriers::drop`
60+
let scratch = unsafe {
61+
core::mem::transmute::<&'a mut BufferBarrierScratch<'_>, &'a mut BufferBarrierScratch<'b>>(
62+
scratch,
63+
)
64+
};
65+
Self { scratch }
66+
}
67+
68+
pub(crate) fn extend(self, iter: impl Iterator<Item = BufferBarrier<'b>>) -> Self {
69+
self.scratch.0.extend(iter);
70+
self
71+
}
72+
73+
pub(crate) fn encode(self, encoder: &mut dyn hal::DynCommandEncoder) {
74+
unsafe {
75+
encoder.transition_buffers(&self.scratch.0);
76+
}
77+
}
78+
}
79+
80+
impl<'a, 'b> Drop for BufferBarriers<'a, 'b> {
81+
fn drop(&mut self) {
82+
self.scratch.0.clear();
83+
}
84+
}

0 commit comments

Comments
 (0)