Skip to content

Commit af471be

Browse files
committed
[Backend Tester] Add CSV report generation
ghstack-source-id: 058e3ca ghstack-comment-id: 3105325555 Pull-Request: #12741
1 parent 8f470f8 commit af471be

File tree

8 files changed

+185
-17
lines changed

8 files changed

+185
-17
lines changed

backends/test/suite/context.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,9 @@
11
# Test run context management. This is used to determine the test context for reporting
22
# purposes.
33
class TestContext:
4-
def __init__(self, test_name: str, flow_name: str, params: dict | None):
4+
def __init__(self, test_name: str, test_base_name: str, flow_name: str, params: dict | None):
55
self.test_name = test_name
6+
self.test_base_name = test_base_name
67
self.flow_name = flow_name
78
self.params = params
89

backends/test/suite/models/__init__.py

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -43,19 +43,19 @@ def _create_test(
4343
dtype: torch.dtype,
4444
use_dynamic_shapes: bool,
4545
):
46+
dtype_name = str(dtype)[6:] # strip "torch."
47+
test_name = f"{test_func.__name__}_{flow.name}_{dtype_name}"
48+
if use_dynamic_shapes:
49+
test_name += "_dynamic_shape"
50+
4651
def wrapped_test(self):
4752
params = {
4853
"dtype": dtype,
4954
"use_dynamic_shapes": use_dynamic_shapes,
5055
}
51-
with TestContext(test_name, flow.name, params):
56+
with TestContext(test_name, test_func.__name__, flow.name, params):
5257
test_func(self, flow, dtype, use_dynamic_shapes)
5358

54-
dtype_name = str(dtype)[6:] # strip "torch."
55-
test_name = f"{test_func.__name__}_{flow.name}_{dtype_name}"
56-
if use_dynamic_shapes:
57-
test_name += "_dynamic_shape"
58-
5959
wrapped_test._name = test_func.__name__ # type: ignore
6060
wrapped_test._flow = flow # type: ignore
6161

@@ -119,6 +119,7 @@ def run_model_test(
119119
inputs,
120120
flow,
121121
context.test_name,
122+
context.test_base_name,
122123
context.params,
123124
dynamic_shapes=dynamic_shapes,
124125
)

backends/test/suite/operators/__init__.py

Lines changed: 9 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66

77
# pyre-unsafe
88

9+
import copy
910
import os
1011
import unittest
1112

@@ -89,12 +90,13 @@ def _expand_test(cls, test_name: str):
8990
def _make_wrapped_test(
9091
test_func: Callable,
9192
test_name: str,
93+
test_base_name: str,
9294
flow: TestFlow,
9395
params: dict | None = None,
9496
):
9597
def wrapped_test(self):
96-
with TestContext(test_name, flow.name, params):
97-
test_kwargs = params or {}
98+
with TestContext(test_name, test_base_name, flow.name, params):
99+
test_kwargs = copy.copy(params) or {}
98100
test_kwargs["flow"] = flow
99101

100102
test_func(self, **test_kwargs)
@@ -113,19 +115,20 @@ def _create_test_for_backend(
113115
test_type = getattr(test_func, "test_type", TestType.STANDARD)
114116

115117
if test_type == TestType.STANDARD:
116-
wrapped_test = _make_wrapped_test(test_func, test_func.__name__, flow)
117118
test_name = f"{test_func.__name__}_{flow.name}"
119+
wrapped_test = _make_wrapped_test(test_func, test_name, test_func.__name__, flow)
118120
setattr(cls, test_name, wrapped_test)
119121
elif test_type == TestType.DTYPE:
120122
for dtype in DTYPES:
123+
dtype_name = str(dtype)[6:] # strip "torch."
124+
test_name = f"{test_func.__name__}_{dtype_name}_{flow.name}"
121125
wrapped_test = _make_wrapped_test(
122126
test_func,
127+
test_name,
123128
test_func.__name__,
124129
flow,
125130
{"dtype": dtype},
126131
)
127-
dtype_name = str(dtype)[6:] # strip "torch."
128-
test_name = f"{test_func.__name__}_{dtype_name}_{flow.name}"
129132
setattr(cls, test_name, wrapped_test)
130133
else:
131134
raise NotImplementedError(f"Unknown test type {test_type}.")
@@ -143,6 +146,7 @@ def _test_op(self, model, inputs, flow: TestFlow):
143146
inputs,
144147
flow,
145148
context.test_name,
149+
context.test_base_name,
146150
context.params,
147151
)
148152

backends/test/suite/reporting.py

Lines changed: 50 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,11 @@
11
from collections import Counter
22
from dataclasses import dataclass
33
from enum import IntEnum
4+
from functools import reduce
5+
from re import A
6+
from typing import TextIO
47

8+
import csv
59

610
class TestResult(IntEnum):
711
"""Represents the result of a test case run, indicating success or a specific failure reason."""
@@ -75,13 +79,19 @@ class TestCaseSummary:
7579
"""
7680
Contains summary results for the execution of a single test case.
7781
"""
82+
83+
backend: str
84+
""" The name of the target backend. """
7885

79-
name: str
80-
""" The qualified name of the test, not including the flow suffix. """
81-
86+
base_name: str
87+
""" The base name of the test, not including flow or parameter suffixes. """
88+
8289
flow: str
8390
""" The backend-specific flow name. Corresponds to flows registered in backends/test/suite/__init__.py. """
8491

92+
name: str
93+
""" The full name of test, including flow and parameter suffixes. """
94+
8595
params: dict | None
8696
""" Test-specific parameters, such as dtype. """
8797

@@ -162,3 +172,40 @@ def complete_test_session() -> RunSummary:
162172
_active_session = None
163173

164174
return summary
175+
176+
def generate_csv_report(summary: "RunSummary", output: TextIO):
    """Write a run summary report to a text stream in CSV format.

    The report has one row per entry in ``summary.test_case_summaries``. The
    fixed columns ("Test ID", "Test Case", "Backend", "Flow", "Result") come
    first, followed by one column per unique test parameter name found across
    all test cases. Parameter columns are emitted in sorted order so that the
    layout is identical from run to run. Records that lack a given parameter
    get an empty cell for it.

    Args:
        summary: The run summary whose test case summaries are reported.
        output: A writable text stream that receives the CSV data.
    """
    field_names = [
        "Test ID",
        "Test Case",
        "Backend",
        "Flow",
        "Result",
    ]

    # Tests can have custom parameters. We'll want to report them here, so we
    # need a list of all unique parameter names. Sort them: iteration order of
    # a set of strings varies between interpreter runs (hash randomization),
    # which would otherwise shuffle the report's columns.
    param_names = sorted(
        {
            key
            for case in summary.test_case_summaries
            if case.params is not None
            for key in case.params
        }
    )
    field_names += [name.capitalize() for name in param_names]

    writer = csv.DictWriter(output, field_names)
    writer.writeheader()

    for record in summary.test_case_summaries:
        row = {
            "Test ID": record.name,
            "Test Case": record.base_name,
            "Backend": record.backend,
            "Flow": record.flow,
            "Result": record.result.display_name(),
        }
        if record.params is not None:
            # Column headers use the capitalized parameter name, so the row
            # keys must be capitalized the same way.
            row.update({k.capitalize(): v for k, v in record.params.items()})
        writer.writerow(row)

backends/test/suite/runner.py

Lines changed: 13 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,13 +7,13 @@
77

88
import torch
99

10-
from executorch.backends.test.harness import Tester
1110
from executorch.backends.test.harness.stages import StageType
1211
from executorch.backends.test.suite.discovery import discover_tests, TestFilter
1312
from executorch.backends.test.suite.flow import TestFlow
1413
from executorch.backends.test.suite.reporting import (
1514
begin_test_session,
1615
complete_test_session,
16+
generate_csv_report,
1717
RunSummary,
1818
TestCaseSummary,
1919
TestResult,
@@ -32,6 +32,7 @@ def run_test( # noqa: C901
3232
inputs: Any,
3333
flow: TestFlow,
3434
test_name: str,
35+
test_base_name: str,
3536
params: dict | None,
3637
dynamic_shapes: Any | None = None,
3738
) -> TestCaseSummary:
@@ -45,8 +46,10 @@ def build_result(
4546
result: TestResult, error: Exception | None = None
4647
) -> TestCaseSummary:
4748
return TestCaseSummary(
48-
name=test_name,
49+
backend=flow.backend,
50+
base_name=test_base_name,
4951
flow=flow.name,
52+
name=test_name,
5053
params=params,
5154
result=result,
5255
error=error,
@@ -167,6 +170,9 @@ def parse_args():
167170
parser.add_argument(
168171
"-f", "--filter", nargs="?", help="A regular expression filter for test names."
169172
)
173+
parser.add_argument(
174+
"-r", "--report", nargs="?", help="A file to write the test report to, in CSV format."
175+
)
170176
return parser.parse_args()
171177

172178

@@ -194,6 +200,11 @@ def runner_main():
194200

195201
summary = complete_test_session()
196202
print_summary(summary)
203+
204+
if args.report is not None:
205+
with open(args.report, "w") as f:
206+
print(f"Writing CSV report to {args.report}.")
207+
generate_csv_report(summary, f)
197208

198209

199210
if __name__ == "__main__":

backends/test/suite/tests/README.md

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
# Tests
2+
3+
This directory contains meta-tests for the backend test suite. As the test suite contains a non-negligible amount of logic, these tests are useful to ensure that the test suite itself is working correctly.

backends/test/suite/tests/__init__.py

Whitespace-only changes.
Lines changed: 101 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,101 @@
1+
import torch
2+
import unittest
3+
4+
from csv import DictReader
5+
from ..reporting import TestResult, TestCaseSummary, RunSummary, TestSessionState, generate_csv_report
6+
from io import StringIO
7+
8+
# Test data for simulated test results.
9+
TEST_CASE_SUMMARIES = [
10+
TestCaseSummary(
11+
backend="backend1",
12+
base_name="test1",
13+
flow="flow1",
14+
name="test1_backend1_flow1",
15+
params=None,
16+
result=TestResult.SUCCESS,
17+
error=None,
18+
),
19+
TestCaseSummary(
20+
backend="backend2",
21+
base_name="test1",
22+
flow="flow1",
23+
name="test1_backend2_flow1",
24+
params=None,
25+
result=TestResult.LOWER_FAIL,
26+
error=None,
27+
),
28+
TestCaseSummary(
29+
backend="backend1",
30+
base_name="test2",
31+
flow="flow1",
32+
name="test2_backend1_flow1",
33+
params={"dtype": torch.float32},
34+
result=TestResult.SUCCESS_UNDELEGATED,
35+
error=None,
36+
),
37+
TestCaseSummary(
38+
backend="backend2",
39+
base_name="test2",
40+
flow="flow1",
41+
name="test2_backend2_flow1",
42+
params={"use_dynamic_shapes": True},
43+
result=TestResult.EXPORT_FAIL,
44+
error=None,
45+
),
46+
]
47+
48+
class Reporting(unittest.TestCase):
49+
def test_csv_report_simple(self):
50+
# Verify the format of a simple CSV run report.
51+
session_state = TestSessionState()
52+
session_state.test_case_summaries.extend(TEST_CASE_SUMMARIES)
53+
run_summary = RunSummary.from_session(session_state)
54+
55+
strio = StringIO()
56+
generate_csv_report(run_summary, strio)
57+
58+
# Attempt to deserialize and validate the CSV report.
59+
report = DictReader(StringIO(strio.getvalue()))
60+
records = list(report)
61+
self.assertEqual(len(records), 4)
62+
63+
# Validate first record: test1, backend1, SUCCESS
64+
self.assertEqual(records[0]["Test ID"], "test1_backend1_flow1")
65+
self.assertEqual(records[0]["Test Case"], "test1")
66+
self.assertEqual(records[0]["Backend"], "backend1")
67+
self.assertEqual(records[0]["Flow"], "flow1")
68+
self.assertEqual(records[0]["Result"], "Success (Delegated)")
69+
self.assertEqual(records[0]["Dtype"], "")
70+
self.assertEqual(records[0]["Use_dynamic_shapes"], "")
71+
72+
# Validate second record: test1, backend2, LOWER_FAIL
73+
self.assertEqual(records[1]["Test ID"], "test1_backend2_flow1")
74+
self.assertEqual(records[1]["Test Case"], "test1")
75+
self.assertEqual(records[1]["Backend"], "backend2")
76+
self.assertEqual(records[1]["Flow"], "flow1")
77+
self.assertEqual(records[1]["Result"], "Fail (Lowering)")
78+
self.assertEqual(records[1]["Dtype"], "")
79+
self.assertEqual(records[1]["Use_dynamic_shapes"], "")
80+
81+
# Validate third record: test2, backend1, SUCCESS_UNDELEGATED with dtype param
82+
self.assertEqual(records[2]["Test ID"], "test2_backend1_flow1")
83+
self.assertEqual(records[2]["Test Case"], "test2")
84+
self.assertEqual(records[2]["Backend"], "backend1")
85+
self.assertEqual(records[2]["Flow"], "flow1")
86+
self.assertEqual(records[2]["Result"], "Success (Undelegated)")
87+
self.assertEqual(records[2]["Dtype"], str(torch.float32))
88+
self.assertEqual(records[2]["Use_dynamic_shapes"], "")
89+
90+
# Validate fourth record: test2, backend2, EXPORT_FAIL with use_dynamic_shapes param
91+
self.assertEqual(records[3]["Test ID"], "test2_backend2_flow1")
92+
self.assertEqual(records[3]["Test Case"], "test2")
93+
self.assertEqual(records[3]["Backend"], "backend2")
94+
self.assertEqual(records[3]["Flow"], "flow1")
95+
self.assertEqual(records[3]["Result"], "Fail (Export)")
96+
self.assertEqual(records[3]["Dtype"], "")
97+
self.assertEqual(records[3]["Use_dynamic_shapes"], "True")
98+
99+
100+
101+

0 commit comments

Comments
 (0)