Skip to content

Commit face4f1

Browse files
authored
Change flattened dependency graph from DiGraph<NodeId> to DiGraph<SystemKey> (#20172)
# Objective - Part of #20115 Semantically, the fully flattened dependency graph never contains `SystemSetKey`s, so lets encode that into its type. ## Solution - Added `GraphNodeId` trait. - Generalized `DiGraph` and `UnGraph` with a new `GraphNodeId` `N` type parameter. - Generalized most functions involving `DiGraph`/`UnGraph` to take a `GraphNodeId` type parameter. - Added `Graph::try_into` function to help us convert from `DiGraph<NodeId>` to `DiGraph<SystemKey>`. Does it look a bit gnarly? Yea. ## Testing Re-using current tests.
1 parent 877d278 commit face4f1

File tree

10 files changed

+454
-337
lines changed

10 files changed

+454
-337
lines changed

crates/bevy_ecs/src/schedule/auto_insert_apply_deferred.rs

Lines changed: 21 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@ use super::{
2626
pub struct AutoInsertApplyDeferredPass {
2727
/// Dependency edges that will **not** automatically insert an instance of `ApplyDeferred` on the edge.
2828
no_sync_edges: BTreeSet<(NodeId, NodeId)>,
29-
auto_sync_node_ids: HashMap<u32, NodeId>,
29+
auto_sync_node_ids: HashMap<u32, SystemKey>,
3030
}
3131

3232
/// If added to a dependency edge, the edge will not be considered for auto sync point insertions.
@@ -35,14 +35,14 @@ pub struct IgnoreDeferred;
3535
impl AutoInsertApplyDeferredPass {
3636
/// Returns the `NodeId` of the cached auto sync point. Will create
3737
/// a new one if needed.
38-
fn get_sync_point(&mut self, graph: &mut ScheduleGraph, distance: u32) -> NodeId {
38+
fn get_sync_point(&mut self, graph: &mut ScheduleGraph, distance: u32) -> SystemKey {
3939
self.auto_sync_node_ids
4040
.get(&distance)
4141
.copied()
4242
.unwrap_or_else(|| {
43-
let node_id = NodeId::System(self.add_auto_sync(graph));
44-
self.auto_sync_node_ids.insert(distance, node_id);
45-
node_id
43+
let key = self.add_auto_sync(graph);
44+
self.auto_sync_node_ids.insert(distance, key);
45+
key
4646
})
4747
}
4848
/// add an [`ApplyDeferred`] system with no config
@@ -72,7 +72,7 @@ impl ScheduleBuildPass for AutoInsertApplyDeferredPass {
7272
&mut self,
7373
_world: &mut World,
7474
graph: &mut ScheduleGraph,
75-
dependency_flattened: &mut DiGraph,
75+
dependency_flattened: &mut DiGraph<SystemKey>,
7676
) -> Result<(), ScheduleBuildError> {
7777
let mut sync_point_graph = dependency_flattened.clone();
7878
let topo = graph.topsort_graph(dependency_flattened, ReportCycles::Dependency)?;
@@ -119,14 +119,10 @@ impl ScheduleBuildPass for AutoInsertApplyDeferredPass {
119119
HashMap::with_capacity_and_hasher(topo.len(), Default::default());
120120

121121
// Keep track of any explicit sync nodes for a specific distance.
122-
let mut distance_to_explicit_sync_node: HashMap<u32, NodeId> = HashMap::default();
122+
let mut distance_to_explicit_sync_node: HashMap<u32, SystemKey> = HashMap::default();
123123

124124
// Determine the distance for every node and collect the explicit sync points.
125-
for node in &topo {
126-
let &NodeId::System(key) = node else {
127-
panic!("Encountered a non-system node in the flattened dependency graph: {node:?}");
128-
};
129-
125+
for &key in &topo {
130126
let (node_distance, mut node_needs_sync) = distances_and_pending_sync
131127
.get(&key)
132128
.copied()
@@ -137,7 +133,7 @@ impl ScheduleBuildPass for AutoInsertApplyDeferredPass {
137133
// makes sure that this node is no unvisited target of another node.
138134
// Because of this, the sync point can be stored for this distance to be reused as
139135
// automatically added sync points later.
140-
distance_to_explicit_sync_node.insert(node_distance, NodeId::System(key));
136+
distance_to_explicit_sync_node.insert(node_distance, key);
141137

142138
// This node just did a sync, so the only reason to do another sync is if one was
143139
// explicitly scheduled afterwards.
@@ -148,10 +144,7 @@ impl ScheduleBuildPass for AutoInsertApplyDeferredPass {
148144
node_needs_sync = graph.systems[key].has_deferred();
149145
}
150146

151-
for target in dependency_flattened.neighbors_directed(*node, Direction::Outgoing) {
152-
let NodeId::System(target) = target else {
153-
panic!("Encountered a non-system node in the flattened dependency graph: {target:?}");
154-
};
147+
for target in dependency_flattened.neighbors_directed(key, Direction::Outgoing) {
155148
let (target_distance, target_pending_sync) =
156149
distances_and_pending_sync.entry(target).or_default();
157150

@@ -160,7 +153,7 @@ impl ScheduleBuildPass for AutoInsertApplyDeferredPass {
160153
&& !graph.systems[target].is_exclusive()
161154
&& self
162155
.no_sync_edges
163-
.contains(&(*node, NodeId::System(target)))
156+
.contains(&(NodeId::System(key), NodeId::System(target)))
164157
{
165158
// The node has deferred params to apply, but this edge is ignoring sync points.
166159
// Mark the target as 'delaying' those commands to a future edge and the current
@@ -184,19 +177,13 @@ impl ScheduleBuildPass for AutoInsertApplyDeferredPass {
184177

185178
// Find any edges which have a different number of sync points between them and make sure
186179
// there is a sync point between them.
187-
for node in &topo {
188-
let &NodeId::System(key) = node else {
189-
panic!("Encountered a non-system node in the flattened dependency graph: {node:?}");
190-
};
180+
for &key in &topo {
191181
let (node_distance, _) = distances_and_pending_sync
192182
.get(&key)
193183
.copied()
194184
.unwrap_or_default();
195185

196-
for target in dependency_flattened.neighbors_directed(*node, Direction::Outgoing) {
197-
let NodeId::System(target) = target else {
198-
panic!("Encountered a non-system node in the flattened dependency graph: {target:?}");
199-
};
186+
for target in dependency_flattened.neighbors_directed(key, Direction::Outgoing) {
200187
let (target_distance, _) = distances_and_pending_sync
201188
.get(&target)
202189
.copied()
@@ -218,11 +205,11 @@ impl ScheduleBuildPass for AutoInsertApplyDeferredPass {
218205
.copied()
219206
.unwrap_or_else(|| self.get_sync_point(graph, target_distance));
220207

221-
sync_point_graph.add_edge(*node, sync_point);
222-
sync_point_graph.add_edge(sync_point, NodeId::System(target));
208+
sync_point_graph.add_edge(key, sync_point);
209+
sync_point_graph.add_edge(sync_point, target);
223210

224211
// The edge without the sync point is now redundant.
225-
sync_point_graph.remove_edge(*node, NodeId::System(target));
212+
sync_point_graph.remove_edge(key, target);
226213
}
227214
}
228215

@@ -234,14 +221,14 @@ impl ScheduleBuildPass for AutoInsertApplyDeferredPass {
234221
&mut self,
235222
set: SystemSetKey,
236223
systems: &[SystemKey],
237-
dependency_flattened: &DiGraph,
224+
dependency_flattening: &DiGraph<NodeId>,
238225
) -> impl Iterator<Item = (NodeId, NodeId)> {
239226
if systems.is_empty() {
240227
// collapse dependencies for empty sets
241-
for a in dependency_flattened.neighbors_directed(NodeId::Set(set), Direction::Incoming)
228+
for a in dependency_flattening.neighbors_directed(NodeId::Set(set), Direction::Incoming)
242229
{
243230
for b in
244-
dependency_flattened.neighbors_directed(NodeId::Set(set), Direction::Outgoing)
231+
dependency_flattening.neighbors_directed(NodeId::Set(set), Direction::Outgoing)
245232
{
246233
if self.no_sync_edges.contains(&(a, NodeId::Set(set)))
247234
&& self.no_sync_edges.contains(&(NodeId::Set(set), b))
@@ -251,7 +238,7 @@ impl ScheduleBuildPass for AutoInsertApplyDeferredPass {
251238
}
252239
}
253240
} else {
254-
for a in dependency_flattened.neighbors_directed(NodeId::Set(set), Direction::Incoming)
241+
for a in dependency_flattening.neighbors_directed(NodeId::Set(set), Direction::Incoming)
255242
{
256243
for &sys in systems {
257244
if self.no_sync_edges.contains(&(a, NodeId::Set(set))) {
@@ -260,7 +247,7 @@ impl ScheduleBuildPass for AutoInsertApplyDeferredPass {
260247
}
261248
}
262249

263-
for b in dependency_flattened.neighbors_directed(NodeId::Set(set), Direction::Outgoing)
250+
for b in dependency_flattening.neighbors_directed(NodeId::Set(set), Direction::Outgoing)
264251
{
265252
for &sys in systems {
266253
if self.no_sync_edges.contains(&(NodeId::Set(set), b)) {

0 commit comments

Comments
 (0)