1
1
import csv
2
+
2
3
from collections import Counter
3
4
from dataclasses import dataclass
4
5
from datetime import timedelta
5
6
from enum import IntEnum
6
7
from functools import reduce
7
- from typing import TextIO
8
+ from typing import Any , TextIO
8
9
9
10
from executorch .backends .test .harness .error_statistics import ErrorStatistics
11
+ from torch .export import ExportedProgram
12
+
13
+
14
+ # Operators that are excluded from the counts returned by count_ops. These are used to
15
+ # exclude operatations that are not logically relevant or delegatable to backends.
16
+ OP_COUNT_IGNORED_OPS = {
17
+ "executorch_call_delegate" ,
18
+ "getitem" ,
19
+ }
10
20
11
21
12
22
class TestResult (IntEnum ):
@@ -115,6 +125,12 @@ class TestCaseSummary:
115
125
lower_time : timedelta | None = None
116
126
""" The total runtime of the to_edge_transform_and_lower stage, or none, if the test did not run the quantize stage. """
117
127
128
+ delegated_op_counts : Counter | None = None
129
+ """ The number of delegated occurances of each operator in the graph. """
130
+
131
+ undelegated_op_counts : Counter | None = None
132
+ """ The number of undelegated occurances of each operator in the graph. """
133
+
118
134
119
135
class TestSessionState :
120
136
test_case_summaries : list [TestCaseSummary ]
@@ -164,6 +180,40 @@ def from_session(cls, session: TestSessionState) -> "RunSummary":
164
180
_active_session : TestSessionState | None = None
165
181
166
182
183
+ def _get_target_name (target : Any ) -> str :
184
+ """Retrieve a string representation of a node target."""
185
+ if isinstance (target , str ):
186
+ return target
187
+ elif hasattr (target , "name" ):
188
+ return target .name () # Op overloads have this
189
+ elif hasattr (target , "__name__" ):
190
+ return target .__name__ # Some builtins have this
191
+ else :
192
+ return str (target )
193
+
194
+
195
+ def _count_ops (program : ExportedProgram ) -> Counter :
196
+ op_names = (
197
+ _get_target_name (n .target )
198
+ for n in program .graph .nodes
199
+ if n .op == "call_function"
200
+ )
201
+
202
+ return Counter (op for op in op_names if op not in OP_COUNT_IGNORED_OPS )
203
+
204
+
205
+ def count_ops (program : dict [str , ExportedProgram ] | ExportedProgram ) -> Counter :
206
+ if isinstance (program , ExportedProgram ):
207
+ return _count_ops (program )
208
+ else :
209
+ # Sum op counts for all methods in the program.
210
+ return reduce (
211
+ lambda a , b : a + b ,
212
+ (_count_ops (p ) for p in program .values ()),
213
+ Counter (),
214
+ )
215
+
216
+
167
217
def begin_test_session ():
168
218
global _active_session
169
219
@@ -188,6 +238,24 @@ def complete_test_session() -> RunSummary:
188
238
return summary
189
239
190
240
241
+ def _sum_op_counts (counter : Counter | None ) -> int | None :
242
+ """
243
+ A utility function to count the total number of nodes in an op count dict.
244
+ """
245
+ return sum (counter .values ()) if counter is not None else None
246
+
247
+
248
+ def _serialize_op_counts (counter : Counter | None ) -> str :
249
+ """
250
+ A utility function to serialize op counts to a string, for the purpose of including
251
+ in the test report.
252
+ """
253
+ if counter is not None :
254
+ return str (dict (sorted (counter .items ())))
255
+ else :
256
+ return ""
257
+
258
+
191
259
def generate_csv_report (summary : RunSummary , output : TextIO ):
192
260
"""Write a run summary report to a file in CSV format."""
193
261
@@ -228,6 +296,14 @@ def generate_csv_report(summary: RunSummary, output: TextIO):
228
296
f"Output { i } SQNR" ,
229
297
]
230
298
)
299
+ field_names .extend (
300
+ [
301
+ "Delegated Nodes" ,
302
+ "Undelegated Nodes" ,
303
+ "Delegated Ops" ,
304
+ "Undelegated Ops" ,
305
+ ]
306
+ )
231
307
232
308
writer = csv .DictWriter (output , field_names )
233
309
writer .writeheader ()
@@ -256,4 +332,9 @@ def generate_csv_report(summary: RunSummary, output: TextIO):
256
332
row [f"Output { output_idx } Error L2" ] = error_stats .error_l2_norm
257
333
row [f"Output { output_idx } SQNR" ] = error_stats .sqnr
258
334
335
+ row ["Delegated Nodes" ] = _sum_op_counts (record .delegated_op_counts )
336
+ row ["Undelegated Nodes" ] = _sum_op_counts (record .undelegated_op_counts )
337
+ row ["Delegated Ops" ] = _serialize_op_counts (record .delegated_op_counts )
338
+ row ["Undelegated Ops" ] = _serialize_op_counts (record .undelegated_op_counts )
339
+
259
340
writer .writerow (row )
0 commit comments