Skip to content

Commit 42f383f

Browse files
committed
[Backend Tester] Add CSV report generation
ghstack-source-id: 17e6147 ghstack-comment-id: 3105325555 Pull-Request: #12741
1 parent 5c33ef9 commit 42f383f

File tree

8 files changed

+185
-16
lines changed

8 files changed

+185
-16
lines changed

backends/test/suite/context.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,9 @@
11
# Test run context management. This is used to determine the test context for reporting
22
# purposes.
33
class TestContext:
4-
def __init__(self, test_name: str, flow_name: str, params: dict | None):
4+
def __init__(self, test_name: str, test_base_name: str, flow_name: str, params: dict | None):
55
self.test_name = test_name
6+
self.test_base_name = test_base_name
67
self.flow_name = flow_name
78
self.params = params
89

backends/test/suite/models/__init__.py

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -42,19 +42,19 @@ def _create_test(
4242
dtype: torch.dtype,
4343
use_dynamic_shapes: bool,
4444
):
45+
dtype_name = str(dtype)[6:] # strip "torch."
46+
test_name = f"{test_func.__name__}_{flow.name}_{dtype_name}"
47+
if use_dynamic_shapes:
48+
test_name += "_dynamic_shape"
49+
4550
def wrapped_test(self):
4651
params = {
4752
"dtype": dtype,
4853
"use_dynamic_shapes": use_dynamic_shapes,
4954
}
50-
with TestContext(test_name, flow.name, params):
55+
with TestContext(test_name, test_func.__name__, flow.name, params):
5156
test_func(self, flow, dtype, use_dynamic_shapes)
5257

53-
dtype_name = str(dtype)[6:] # strip "torch."
54-
test_name = f"{test_func.__name__}_{flow.name}_{dtype_name}"
55-
if use_dynamic_shapes:
56-
test_name += "_dynamic_shape"
57-
5858
wrapped_test._name = test_func.__name__ # type: ignore
5959
wrapped_test._flow = flow # type: ignore
6060

@@ -118,6 +118,7 @@ def run_model_test(
118118
inputs,
119119
flow,
120120
context.test_name,
121+
context.test_base_name,
121122
context.params,
122123
dynamic_shapes=dynamic_shapes,
123124
)

backends/test/suite/operators/__init__.py

Lines changed: 9 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66

77
# pyre-unsafe
88

9+
import copy
910
import os
1011
import unittest
1112

@@ -90,12 +91,13 @@ def _expand_test(cls, test_name: str):
9091
def _make_wrapped_test(
9192
test_func: Callable,
9293
test_name: str,
94+
test_base_name: str,
9395
flow: TestFlow,
9496
params: dict | None = None,
9597
):
9698
def wrapped_test(self):
97-
with TestContext(test_name, flow.name, params):
98-
test_kwargs = params or {}
99+
with TestContext(test_name, test_base_name, flow.name, params):
100+
test_kwargs = copy.copy(params) or {}
99101
test_kwargs["flow"] = flow
100102

101103
test_func(self, **test_kwargs)
@@ -114,19 +116,20 @@ def _create_test_for_backend(
114116
test_type = getattr(test_func, "test_type", TestType.STANDARD)
115117

116118
if test_type == TestType.STANDARD:
117-
wrapped_test = _make_wrapped_test(test_func, test_func.__name__, flow)
118119
test_name = f"{test_func.__name__}_{flow.name}"
120+
wrapped_test = _make_wrapped_test(test_func, test_name, test_func.__name__, flow)
119121
setattr(cls, test_name, wrapped_test)
120122
elif test_type == TestType.DTYPE:
121123
for dtype in DTYPES:
124+
dtype_name = str(dtype)[6:] # strip "torch."
125+
test_name = f"{test_func.__name__}_{dtype_name}_{flow.name}"
122126
wrapped_test = _make_wrapped_test(
123127
test_func,
128+
test_name,
124129
test_func.__name__,
125130
flow,
126131
{"dtype": dtype},
127132
)
128-
dtype_name = str(dtype)[6:] # strip "torch."
129-
test_name = f"{test_func.__name__}_{dtype_name}_{flow.name}"
130133
setattr(cls, test_name, wrapped_test)
131134
else:
132135
raise NotImplementedError(f"Unknown test type {test_type}.")
@@ -144,6 +147,7 @@ def _test_op(self, model, inputs, flow: TestFlow):
144147
inputs,
145148
flow,
146149
context.test_name,
150+
context.test_base_name,
147151
context.params,
148152
)
149153

backends/test/suite/reporting.py

Lines changed: 50 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,11 @@
11
from collections import Counter
22
from dataclasses import dataclass
33
from enum import IntEnum
4+
from functools import reduce
5+
from re import A
6+
from typing import TextIO
47

8+
import csv
59

610
class TestResult(IntEnum):
711
"""Represents the result of a test case run, indicating success or a specific failure reason."""
@@ -75,13 +79,19 @@ class TestCaseSummary:
7579
"""
7680
Contains summary results for the execution of a single test case.
7781
"""
82+
83+
backend: str
84+
""" The name of the target backend. """
7885

79-
name: str
80-
""" The qualified name of the test, not including the flow suffix. """
81-
86+
base_name: str
87+
""" The base name of the test, not including flow or parameter suffixes. """
88+
8289
flow: str
8390
""" The backend-specific flow name. Corresponds to flows registered in backends/test/suite/__init__.py. """
8491

92+
name: str
93+
""" The full name of test, including flow and parameter suffixes. """
94+
8595
params: dict | None
8696
""" Test-specific parameters, such as dtype. """
8797

@@ -162,3 +172,40 @@ def complete_test_session() -> RunSummary:
162172
_active_session = None
163173

164174
return summary
175+
176+
def generate_csv_report(summary: RunSummary, output: TextIO):
177+
""" Write a run summary report to a file in CSV format. """
178+
179+
field_names = [
180+
"Test ID",
181+
"Test Case",
182+
"Backend",
183+
"Flow",
184+
"Result",
185+
]
186+
187+
# Tests can have custom parameters. We'll want to report them here, so we need
188+
# a list of all unique parameter names.
189+
param_names = reduce(
190+
lambda a, b: a.union(b),
191+
(set(s.params.keys()) for s in summary.test_case_summaries if s.params is not None),
192+
set()
193+
)
194+
field_names += (s.capitalize() for s in param_names)
195+
196+
writer = csv.DictWriter(output, field_names)
197+
writer.writeheader()
198+
199+
for record in summary.test_case_summaries:
200+
row = {
201+
"Test ID": record.name,
202+
"Test Case": record.base_name,
203+
"Backend": record.backend,
204+
"Flow": record.flow,
205+
"Result": record.result.display_name(),
206+
}
207+
if record.params is not None:
208+
row.update({
209+
k.capitalize(): v for k, v in record.params.items()
210+
})
211+
writer.writerow(row)

backends/test/suite/runner.py

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
from executorch.backends.test.suite.reporting import (
1414
begin_test_session,
1515
complete_test_session,
16+
generate_csv_report,
1617
RunSummary,
1718
TestCaseSummary,
1819
TestResult,
@@ -31,6 +32,7 @@ def run_test( # noqa: C901
3132
inputs: Any,
3233
flow: TestFlow,
3334
test_name: str,
35+
test_base_name: str,
3436
params: dict | None,
3537
dynamic_shapes: Any | None = None,
3638
) -> TestCaseSummary:
@@ -44,8 +46,10 @@ def build_result(
4446
result: TestResult, error: Exception | None = None
4547
) -> TestCaseSummary:
4648
return TestCaseSummary(
47-
name=test_name,
49+
backend=flow.backend,
50+
base_name=test_base_name,
4851
flow=flow.name,
52+
name=test_name,
4953
params=params,
5054
result=result,
5155
error=error,
@@ -168,6 +172,9 @@ def parse_args():
168172
parser.add_argument(
169173
"-f", "--filter", nargs="?", help="A regular expression filter for test names."
170174
)
175+
parser.add_argument(
176+
"-r", "--report", nargs="?", help="A file to write the test report to, in CSV format."
177+
)
171178
return parser.parse_args()
172179

173180

@@ -195,6 +202,11 @@ def runner_main():
195202

196203
summary = complete_test_session()
197204
print_summary(summary)
205+
206+
if args.report is not None:
207+
with open(args.report, "w") as f:
208+
print(f"Writing CSV report to {args.report}.")
209+
generate_csv_report(summary, f)
198210

199211

200212
if __name__ == "__main__":

backends/test/suite/tests/README.md

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
# Tests
2+
3+
This directory contains meta-tests for the backend test suite. As the test suite contains a non-neglible amount of logic, these tests are useful to ensure that the test suite itself is working correctly.

backends/test/suite/tests/__init__.py

Whitespace-only changes.
Lines changed: 101 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,101 @@
1+
import torch
2+
import unittest
3+
4+
from csv import DictReader
5+
from ..reporting import TestResult, TestCaseSummary, RunSummary, TestSessionState, generate_csv_report
6+
from io import StringIO
7+
8+
# Test data for simulated test results.
9+
TEST_CASE_SUMMARIES = [
10+
TestCaseSummary(
11+
backend="backend1",
12+
base_name="test1",
13+
flow="flow1",
14+
name="test1_backend1_flow1",
15+
params=None,
16+
result=TestResult.SUCCESS,
17+
error=None,
18+
),
19+
TestCaseSummary(
20+
backend="backend2",
21+
base_name="test1",
22+
flow="flow1",
23+
name="test1_backend2_flow1",
24+
params=None,
25+
result=TestResult.LOWER_FAIL,
26+
error=None,
27+
),
28+
TestCaseSummary(
29+
backend="backend1",
30+
base_name="test2",
31+
flow="flow1",
32+
name="test2_backend1_flow1",
33+
params={"dtype": torch.float32},
34+
result=TestResult.SUCCESS_UNDELEGATED,
35+
error=None,
36+
),
37+
TestCaseSummary(
38+
backend="backend2",
39+
base_name="test2",
40+
flow="flow1",
41+
name="test2_backend2_flow1",
42+
params={"use_dynamic_shapes": True},
43+
result=TestResult.EXPORT_FAIL,
44+
error=None,
45+
),
46+
]
47+
48+
class Reporting(unittest.TestCase):
49+
def test_csv_report_simple(self):
50+
# Verify the format of a simple CSV run report.
51+
session_state = TestSessionState()
52+
session_state.test_case_summaries.extend(TEST_CASE_SUMMARIES)
53+
run_summary = RunSummary.from_session(session_state)
54+
55+
strio = StringIO()
56+
generate_csv_report(run_summary, strio)
57+
58+
# Attempt to deserialize and validate the CSV report.
59+
report = DictReader(StringIO(strio.getvalue()))
60+
records = list(report)
61+
self.assertEqual(len(records), 4)
62+
63+
# Validate first record: test1, backend1, SUCCESS
64+
self.assertEqual(records[0]["Test ID"], "test1_backend1_flow1")
65+
self.assertEqual(records[0]["Test Case"], "test1")
66+
self.assertEqual(records[0]["Backend"], "backend1")
67+
self.assertEqual(records[0]["Flow"], "flow1")
68+
self.assertEqual(records[0]["Result"], "Success (Delegated)")
69+
self.assertEqual(records[0]["Dtype"], "")
70+
self.assertEqual(records[0]["Use_dynamic_shapes"], "")
71+
72+
# Validate second record: test1, backend2, LOWER_FAIL
73+
self.assertEqual(records[1]["Test ID"], "test1_backend2_flow1")
74+
self.assertEqual(records[1]["Test Case"], "test1")
75+
self.assertEqual(records[1]["Backend"], "backend2")
76+
self.assertEqual(records[1]["Flow"], "flow1")
77+
self.assertEqual(records[1]["Result"], "Fail (Lowering)")
78+
self.assertEqual(records[1]["Dtype"], "")
79+
self.assertEqual(records[1]["Use_dynamic_shapes"], "")
80+
81+
# Validate third record: test2, backend1, SUCCESS_UNDELEGATED with dtype param
82+
self.assertEqual(records[2]["Test ID"], "test2_backend1_flow1")
83+
self.assertEqual(records[2]["Test Case"], "test2")
84+
self.assertEqual(records[2]["Backend"], "backend1")
85+
self.assertEqual(records[2]["Flow"], "flow1")
86+
self.assertEqual(records[2]["Result"], "Success (Undelegated)")
87+
self.assertEqual(records[2]["Dtype"], str(torch.float32))
88+
self.assertEqual(records[2]["Use_dynamic_shapes"], "")
89+
90+
# Validate fourth record: test2, backend2, EXPORT_FAIL with use_dynamic_shapes param
91+
self.assertEqual(records[3]["Test ID"], "test2_backend2_flow1")
92+
self.assertEqual(records[3]["Test Case"], "test2")
93+
self.assertEqual(records[3]["Backend"], "backend2")
94+
self.assertEqual(records[3]["Flow"], "flow1")
95+
self.assertEqual(records[3]["Result"], "Fail (Export)")
96+
self.assertEqual(records[3]["Dtype"], "")
97+
self.assertEqual(records[3]["Use_dynamic_shapes"], "True")
98+
99+
100+
101+

0 commit comments

Comments
 (0)