Skip to content

Commit dfc6e2b

Browse files
committed
performance(sierra-calcs): Made topological order not require vec alloc.
SIERRA_UPDATE_PATCH_CHANGE_TAG=Just performance gain - no interface effect.
1 parent 07d0f2a commit dfc6e2b

File tree

4 files changed

+51
-46
lines changed

4 files changed

+51
-46
lines changed

crates/cairo-lang-sierra-ap-change/src/compute.rs

Lines changed: 12 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -191,25 +191,18 @@ impl<'a, TokenUsages: Fn(StatementIdx, CostTokenType) -> usize>
191191
(0..self.program.statements.len()).map(StatementIdx),
192192
self.program.statements.len(),
193193
|idx| {
194-
Ok(self.branches[idx.0]
195-
.iter()
196-
.flat_map(|(ap_change, target)| match ap_change {
197-
ApChange::Unknown => None,
198-
ApChange::FunctionCall(id) => {
199-
if self.function_ap_change.contains_key(id) {
200-
Some(*target)
201-
} else {
202-
None
203-
}
204-
}
205-
ApChange::Known(_)
206-
| ApChange::DisableApTracking
207-
| ApChange::FromMetadata
208-
| ApChange::AtLocalsFinalization(_)
209-
| ApChange::FinalizeLocals
210-
| ApChange::EnableApTracking => Some(*target),
211-
})
212-
.collect())
194+
Ok(self.branches[idx.0].iter().flat_map(|(ap_change, target)| match ap_change {
195+
ApChange::Unknown => None,
196+
ApChange::FunctionCall(id) => {
197+
self.function_ap_change.contains_key(id).then_some(*target)
198+
}
199+
ApChange::Known(_)
200+
| ApChange::DisableApTracking
201+
| ApChange::FromMetadata
202+
| ApChange::AtLocalsFinalization(_)
203+
| ApChange::FinalizeLocals
204+
| ApChange::EnableApTracking => Some(*target),
205+
}))
213206
},
214207
|_| unreachable!("Cycle isn't an error."),
215208
)

crates/cairo-lang-sierra-gas/src/compute_costs.rs

Lines changed: 21 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,6 @@ use cairo_lang_sierra::program::{BranchInfo, Invocation, Program, Statement, Sta
77
use cairo_lang_utils::casts::IntoOrPanic;
88
use cairo_lang_utils::iterators::zip_eq3;
99
use cairo_lang_utils::ordered_hash_map::OrderedHashMap;
10-
use cairo_lang_utils::ordered_hash_set::OrderedHashSet;
1110
use cairo_lang_utils::unordered_hash_map::{Entry, UnorderedHashMap};
1211
use cairo_lang_utils::unordered_hash_set::UnorderedHashSet;
1312
use itertools::zip_eq;
@@ -180,12 +179,19 @@ fn get_branch_requirements_dependencies(
180179
idx: &StatementIdx,
181180
invocation: &Invocation,
182181
libfunc_cost: &[BranchCost],
183-
) -> OrderedHashSet<StatementIdx> {
184-
let mut res: OrderedHashSet<StatementIdx> = Default::default();
182+
) -> Vec<StatementIdx> {
183+
let mut res = vec![];
184+
// Adds to the result if not already in it.
185+
// Since rather small - more efficient using a Vec than a Map.
186+
let mut add_to_res = |idx| {
187+
if !res.contains(&idx) {
188+
res.push(idx);
189+
}
190+
};
185191
for (branch_info, branch_cost) in zip_eq(&invocation.branches, libfunc_cost) {
186192
match branch_cost {
187193
BranchCost::FunctionCost { const_cost: _, function, sign: _ } => {
188-
res.insert(function.entry_point);
194+
add_to_res(function.entry_point);
189195
}
190196
BranchCost::WithdrawGas(WithdrawGasBranchInfo {
191197
success: true,
@@ -197,9 +203,8 @@ fn get_branch_requirements_dependencies(
197203
}
198204
_ => {}
199205
}
200-
res.insert(idx.next(&branch_info.target));
206+
add_to_res(idx.next(&branch_info.target));
201207
}
202-
203208
res
204209
}
205210

@@ -454,13 +459,11 @@ impl<CostType: CostTypeTrait> CostContext<'_, CostType> {
454459
// Return has no dependencies.
455460
vec![]
456461
}
457-
Statement::Invocation(invocation) => {
458-
let libfunc_cost = &self.branch_costs[current_idx.0];
459-
460-
get_branch_requirements_dependencies(current_idx, invocation, libfunc_cost)
461-
.into_iter()
462-
.collect()
463-
}
462+
Statement::Invocation(invocation) => get_branch_requirements_dependencies(
463+
current_idx,
464+
invocation,
465+
&self.branch_costs[current_idx.0],
466+
),
464467
}
465468
},
466469
)?;
@@ -674,10 +677,13 @@ impl<CostType: CostTypeTrait> CostContext<'_, CostType> {
674677
/// Generates a topological ordering of the statements according to the given dependencies_callback.
675678
///
676679
/// Each statement appears in the ordering after its dependencies.
677-
fn compute_reverse_topological_order(
680+
fn compute_reverse_topological_order<
681+
Dependencies: IntoIterator<Item = StatementIdx>,
682+
DependenciesCallback: Fn(&StatementIdx) -> Dependencies,
683+
>(
678684
n_statements: usize,
679685
detect_cycles: bool,
680-
dependencies_callback: impl Fn(&StatementIdx) -> Vec<StatementIdx>,
686+
dependencies_callback: DependenciesCallback,
681687
) -> Result<Vec<StatementIdx>, CostError> {
682688
reverse_topological_ordering(
683689
detect_cycles,

crates/cairo-lang-sierra-gas/src/generate_equations.rs

Lines changed: 6 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -110,14 +110,12 @@ fn get_reverse_topological_ordering(program: &Program) -> Result<Vec<StatementId
110110
program.statements.len(),
111111
|idx| {
112112
Ok(match program.get_statement(&idx).unwrap() {
113-
Statement::Invocation(invocation) => invocation
114-
.branches
115-
.iter()
116-
.rev()
117-
.map(|branch| idx.next(&branch.target))
118-
.collect(),
119-
Statement::Return(_) => vec![],
120-
})
113+
Statement::Invocation(invocation) => invocation.branches.as_slice(),
114+
Statement::Return(_) => &[],
115+
}
116+
.iter()
117+
.rev()
118+
.map(move |branch| idx.next(&branch.target)))
121119
},
122120
|_| unreachable!("Cycles are not detected."),
123121
)

crates/cairo-lang-sierra/src/algorithm/topological_order.rs

Lines changed: 12 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -21,11 +21,15 @@ enum TopologicalOrderStatus {
2121
/// `out_of_bounds_err` - a function that returns an error for a node is out of bounds.
2222
/// `cycle_err` - a function that returns an error for a node that is part of a cycle.
2323
/// Note: Will only work properly if the nodes are in the range [0, node_count).
24-
pub fn reverse_topological_ordering<E>(
24+
pub fn reverse_topological_ordering<
25+
E,
26+
Children: IntoIterator<Item = StatementIdx>,
27+
GetChildren: Fn(StatementIdx) -> Result<Children, E>,
28+
>(
2529
detect_cycles: bool,
2630
roots: impl Iterator<Item = StatementIdx>,
2731
node_count: usize,
28-
get_children: impl Fn(StatementIdx) -> Result<Vec<StatementIdx>, E>,
32+
get_children: GetChildren,
2933
cycle_err: impl Fn(StatementIdx) -> E,
3034
) -> Result<Vec<StatementIdx>, E> {
3135
let mut ordering = vec![];
@@ -45,12 +49,16 @@ pub fn reverse_topological_ordering<E>(
4549

4650
/// Calculates the reverse topological ordering starting from `root`. For more info see
4751
/// `reverse_topological_ordering`.
48-
fn calculate_reverse_topological_ordering<E>(
52+
fn calculate_reverse_topological_ordering<
53+
E,
54+
Children: IntoIterator<Item = StatementIdx>,
55+
GetChildren: Fn(StatementIdx) -> Result<Children, E>,
56+
>(
4957
detect_cycles: bool,
5058
ordering: &mut Vec<StatementIdx>,
5159
status: &mut [TopologicalOrderStatus],
5260
root: StatementIdx,
53-
get_children: &impl Fn(StatementIdx) -> Result<Vec<StatementIdx>, E>,
61+
get_children: &GetChildren,
5462
cycle_err: &impl Fn(StatementIdx) -> E,
5563
) -> Result<(), E> {
5664
// A stack of statements to visit.

0 commit comments

Comments
 (0)