Skip to content

Commit 492a50e

Browse files
committed
[Backend Tester] Add quantized test flows for XNNPACK and Core ML
ghstack-source-id: 3bd0358 ghstack-comment-id: 3105090683 Pull-Request: #12733
1 parent 4b33596 commit 492a50e

31 files changed

+475
-366
lines changed

backends/apple/coreml/test/tester.py

Lines changed: 61 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -4,23 +4,65 @@
44
# This source code is licensed under the BSD-style license found in the
55
# LICENSE file in the root directory of this source tree.
66

7-
from typing import Any, List, Optional, Tuple
7+
import functools
8+
from typing import Any, List, Optional, Sequence, Tuple
89

10+
import coremltools as ct
911
import executorch
1012
import executorch.backends.test.harness.stages as BaseStages
11-
1213
import torch
14+
15+
from executorch.backends.apple.coreml.compiler import CoreMLBackend
1316
from executorch.backends.apple.coreml.partition import CoreMLPartitioner
17+
from executorch.backends.apple.coreml.quantizer import CoreMLQuantizer
1418
from executorch.backends.test.harness import Tester as TesterBase
1519
from executorch.backends.test.harness.stages import StageType
1620
from executorch.exir import EdgeCompileConfig
1721
from executorch.exir.backend.partitioner import Partitioner
1822

1923

24+
def _get_static_int8_linear_qconfig():
25+
return ct.optimize.torch.quantization.LinearQuantizerConfig(
26+
global_config=ct.optimize.torch.quantization.ModuleLinearQuantizerConfig(
27+
quantization_scheme="symmetric",
28+
activation_dtype=torch.quint8,
29+
weight_dtype=torch.qint8,
30+
weight_per_channel=True,
31+
)
32+
)
33+
34+
35+
class Quantize(BaseStages.Quantize):
36+
def __init__(
37+
self,
38+
quantizer: Optional[CoreMLQuantizer] = None,
39+
quantization_config: Optional[Any] = None,
40+
calibrate: bool = True,
41+
calibration_samples: Optional[Sequence[Any]] = None,
42+
is_qat: Optional[bool] = False,
43+
):
44+
super().__init__(
45+
quantizer=quantizer
46+
or CoreMLQuantizer(quantization_config or _get_static_int8_linear_qconfig()),
47+
calibrate=calibrate,
48+
calibration_samples=calibration_samples,
49+
is_qat=is_qat,
50+
)
51+
52+
2053
class Partition(BaseStages.Partition):
21-
def __init__(self, partitioner: Optional[Partitioner] = None):
54+
def __init__(
55+
self,
56+
partitioner: Optional[Partitioner] = None,
57+
minimum_deployment_target: Optional[Any] = ct.target.iOS15,
58+
):
2259
super().__init__(
23-
partitioner=partitioner or CoreMLPartitioner,
60+
partitioner=partitioner
61+
or CoreMLPartitioner(
62+
compile_specs=CoreMLBackend.generate_compile_specs(
63+
minimum_deployment_target=minimum_deployment_target
64+
)
65+
),
2466
)
2567

2668

@@ -29,9 +71,14 @@ def __init__(
2971
self,
3072
partitioners: Optional[List[Partitioner]] = None,
3173
edge_compile_config: Optional[EdgeCompileConfig] = None,
74+
minimum_deployment_target: Optional[Any] = ct.target.iOS15,
3275
):
3376
super().__init__(
34-
default_partitioner_cls=CoreMLPartitioner,
77+
default_partitioner_cls=lambda: CoreMLPartitioner(
78+
compile_specs=CoreMLBackend.generate_compile_specs(
79+
minimum_deployment_target=minimum_deployment_target
80+
)
81+
),
3582
partitioners=partitioners,
3683
edge_compile_config=edge_compile_config,
3784
)
@@ -43,13 +90,20 @@ def __init__(
4390
module: torch.nn.Module,
4491
example_inputs: Tuple[torch.Tensor],
4592
dynamic_shapes: Optional[Tuple[Any]] = None,
93+
minimum_deployment_target: Optional[Any] = ct.target.iOS15,
4694
):
4795
# Specialize for XNNPACK
4896
stage_classes = (
4997
executorch.backends.test.harness.Tester.default_stage_classes()
5098
| {
51-
StageType.PARTITION: Partition,
52-
StageType.TO_EDGE_TRANSFORM_AND_LOWER: ToEdgeTransformAndLower,
99+
StageType.QUANTIZE: Quantize,
100+
StageType.PARTITION: functools.partial(
101+
Partition, minimum_deployment_target=minimum_deployment_target
102+
),
103+
StageType.TO_EDGE_TRANSFORM_AND_LOWER: functools.partial(
104+
ToEdgeTransformAndLower,
105+
minimum_deployment_target=minimum_deployment_target,
106+
),
53107
}
54108
)
55109

backends/test/harness/stages/quantize.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,13 +25,15 @@ def __init__(
2525
calibrate: bool = True,
2626
calibration_samples: Optional[Sequence[Any]] = None,
2727
is_qat: Optional[bool] = False,
28+
set_global: bool = True,
2829
):
2930
self.quantizer = quantizer
3031
self.quantization_config = quantization_config
3132
self.calibrate = calibrate
3233
self.calibration_samples = calibration_samples
3334

34-
self.quantizer.set_global(self.quantization_config)
35+
if self.quantization_config is not None and set_global:
36+
self.quantizer.set_global(self.quantization_config)
3537

3638
self.converted_graph = None
3739
self.is_qat = is_qat

backends/test/harness/tester.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
import random
22
from collections import Counter, OrderedDict
3-
from typing import Any, Dict, List, Optional, Tuple, Type
3+
from typing import Any, Callable, Dict, List, Optional, Tuple
44

55
import torch
66

@@ -33,7 +33,7 @@ def __init__(
3333
self,
3434
module: torch.nn.Module,
3535
example_inputs: Tuple[torch.Tensor],
36-
stage_classes: Dict[StageType, Type],
36+
stage_classes: Dict[StageType, Callable],
3737
dynamic_shapes: Optional[Tuple[Any]] = None,
3838
):
3939
module.eval()
@@ -81,7 +81,7 @@ def __init__(
8181
self.stage_output = None
8282

8383
@staticmethod
84-
def default_stage_classes() -> Dict[StageType, Type]:
84+
def default_stage_classes() -> Dict[StageType, Callable]:
8585
"""
8686
Returns a map of StageType to default Stage implementation.
8787
"""

backends/test/suite/__init__.py

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -129,7 +129,7 @@ def _make_wrapped_test(
129129
def wrapped_test(self):
130130
with TestContext(test_name, flow.name, params):
131131
test_kwargs = params or {}
132-
test_kwargs["tester_factory"] = flow.tester_factory
132+
test_kwargs["flow"] = flow
133133

134134
test_func(self, **test_kwargs)
135135

@@ -175,7 +175,7 @@ def load_tests(loader, suite, pattern):
175175

176176

177177
class OperatorTest(unittest.TestCase):
178-
def _test_op(self, model, inputs, tester_factory):
178+
def _test_op(self, model, inputs, flow: TestFlow):
179179
context = get_active_test_context()
180180

181181
# This should be set in the wrapped test. See _make_wrapped_test above.
@@ -184,9 +184,8 @@ def _test_op(self, model, inputs, tester_factory):
184184
run_summary = run_test(
185185
model,
186186
inputs,
187-
tester_factory,
187+
flow,
188188
context.test_name,
189-
context.flow_name,
190189
context.params,
191190
)
192191

backends/test/suite/flow.py

Lines changed: 30 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,10 @@
11
import logging
22

3-
from dataclasses import dataclass
3+
from dataclasses import dataclass, field
44
from typing import Callable
55

66
from executorch.backends.test.harness import Tester
7+
from executorch.backends.test.harness.stages import Quantize
78

89
logger = logging.getLogger(__name__)
910
logger.setLevel(logging.INFO)
@@ -22,41 +23,43 @@ class TestFlow:
2223
backend: str
2324
""" The name of the target backend. """
2425

25-
tester_factory: Callable[[], Tester]
26+
tester_factory: Callable[..., Tester]
2627
""" A factory function that returns a Tester instance for this lowering flow. """
2728

29+
quantize: bool = field(default=False)
30+
""" Whether to tester should run the quantize stage on the model. """
2831

29-
def create_xnnpack_flow() -> TestFlow | None:
30-
try:
31-
from executorch.backends.xnnpack.test.tester import Tester as XnnpackTester
32+
quantize_stage_factory: Callable[..., Quantize] | None = None
33+
""" A factory function which instantiates a Quantize stage. Can be None to use the tester's default. """
3234

33-
return TestFlow(
34-
name="xnnpack",
35-
backend="xnnpack",
36-
tester_factory=XnnpackTester,
37-
)
38-
except Exception:
39-
logger.info("Skipping XNNPACK flow registration due to import failure.")
40-
return None
4135

36+
def all_flows() -> dict[str, TestFlow]:
37+
flows = []
4238

43-
def create_coreml_flow() -> TestFlow | None:
4439
try:
45-
from executorch.backends.apple.coreml.test.tester import CoreMLTester
40+
from executorch.backends.test.suite.flows.xnnpack import (
41+
XNNPACK_STATIC_INT8_PER_CHANNEL_TEST_FLOW,
42+
XNNPACK_TEST_FLOW,
43+
)
4644

47-
return TestFlow(
48-
name="coreml",
49-
backend="coreml",
50-
tester_factory=CoreMLTester,
45+
flows += [
46+
XNNPACK_TEST_FLOW,
47+
XNNPACK_STATIC_INT8_PER_CHANNEL_TEST_FLOW,
48+
]
49+
except Exception as e:
50+
logger.info(f"Skipping XNNPACK flow registration: {e}")
51+
52+
try:
53+
from executorch.backends.test.suite.flows.coreml import (
54+
COREML_STATIC_INT8_TEST_FLOW,
55+
COREML_TEST_FLOW,
5156
)
52-
except Exception:
53-
logger.info("Skipping Core ML flow registration due to import failure.")
54-
return None
5557

58+
flows += [
59+
COREML_TEST_FLOW,
60+
COREML_STATIC_INT8_TEST_FLOW,
61+
]
62+
except Exception as e:
63+
logger.info(f"Skipping Core ML flow registration: {e}")
5664

57-
def all_flows() -> dict[str, TestFlow]:
58-
flows = [
59-
create_xnnpack_flow(),
60-
create_coreml_flow(),
61-
]
6265
return {f.name: f for f in flows if f is not None}

backends/test/suite/flows/__init__.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,7 @@
1+
# Copyright (c) Meta Platforms, Inc. and affiliates.
2+
# All rights reserved.
3+
#
4+
# This source code is licensed under the BSD-style license found in the
5+
# LICENSE file in the root directory of this source tree.
6+
7+
# pyre-unsafe

backends/test/suite/flows/coreml.py

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,30 @@
1+
import functools
2+
from typing import Any
3+
4+
import coremltools
5+
6+
from executorch.backends.apple.coreml.test.tester import CoreMLTester
7+
from executorch.backends.test.suite.flow import TestFlow
8+
9+
10+
def _create_coreml_flow(
11+
name: str,
12+
quantize: bool = False,
13+
minimum_deployment_target: Any = coremltools.target.iOS15,
14+
) -> TestFlow:
15+
return TestFlow(
16+
name,
17+
backend="coreml",
18+
tester_factory=functools.partial(
19+
CoreMLTester, minimum_deployment_target=minimum_deployment_target
20+
),
21+
quantize=quantize,
22+
)
23+
24+
25+
COREML_TEST_FLOW = _create_coreml_flow("coreml")
26+
COREML_STATIC_INT8_TEST_FLOW = _create_coreml_flow(
27+
"coreml_static_int8",
28+
quantize=True,
29+
minimum_deployment_target=coremltools.target.iOS17,
30+
)

backends/test/suite/flows/xnnpack.py

Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,45 @@
1+
import logging
2+
from typing import Callable
3+
4+
from executorch.backends.test.harness.stages import Quantize
5+
from executorch.backends.test.suite.flow import TestFlow
6+
from executorch.backends.xnnpack.quantizer.xnnpack_quantizer import (
7+
get_symmetric_quantization_config,
8+
)
9+
from executorch.backends.xnnpack.test.tester import (
10+
Quantize as XnnpackQuantize,
11+
Tester as XnnpackTester,
12+
)
13+
14+
logger = logging.getLogger(__name__)
15+
logger.setLevel(logging.INFO)
16+
17+
18+
def _create_xnnpack_flow_base(
19+
name: str, quantize_stage_factory: Callable[..., Quantize] | None = None
20+
) -> TestFlow:
21+
return TestFlow(
22+
name,
23+
backend="xnnpack",
24+
tester_factory=XnnpackTester,
25+
quantize=quantize_stage_factory is not None,
26+
quantize_stage_factory=quantize_stage_factory,
27+
)
28+
29+
30+
def _create_xnnpack_flow() -> TestFlow:
31+
return _create_xnnpack_flow_base("xnnpack")
32+
33+
34+
def _create_xnnpack_static_int8_per_channel_flow() -> TestFlow:
35+
def create_quantize_stage() -> Quantize:
36+
qparams = get_symmetric_quantization_config(is_per_channel=True)
37+
return XnnpackQuantize(
38+
quantization_config=qparams,
39+
)
40+
41+
return _create_xnnpack_flow_base("xnnpack_static_int8_per_channel", create_quantize_stage)
42+
43+
44+
XNNPACK_TEST_FLOW = _create_xnnpack_flow()
45+
XNNPACK_STATIC_INT8_PER_CHANNEL_TEST_FLOW = _create_xnnpack_static_int8_per_channel_flow()

backends/test/suite/models/__init__.py

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,6 @@
1212
from typing import Any, Callable
1313

1414
import torch
15-
from executorch.backends.test.harness import Tester
1615
from executorch.backends.test.suite import get_test_flows
1716
from executorch.backends.test.suite.context import get_active_test_context, TestContext
1817
from executorch.backends.test.suite.flow import TestFlow
@@ -49,7 +48,7 @@ def wrapped_test(self):
4948
"use_dynamic_shapes": use_dynamic_shapes,
5049
}
5150
with TestContext(test_name, flow.name, params):
52-
test_func(self, dtype, use_dynamic_shapes, flow.tester_factory)
51+
test_func(self, flow, dtype, use_dynamic_shapes)
5352

5453
dtype_name = str(dtype)[6:] # strip "torch."
5554
test_name = f"{test_func.__name__}_{flow.name}_{dtype_name}"
@@ -104,9 +103,9 @@ def inner_decorator(func: Callable) -> Callable:
104103
def run_model_test(
105104
model: torch.nn.Module,
106105
inputs: tuple[Any],
106+
flow: TestFlow,
107107
dtype: torch.dtype,
108108
dynamic_shapes: Any | None,
109-
tester_factory: Callable[[], Tester],
110109
):
111110
model = model.to(dtype)
112111
context = get_active_test_context()
@@ -117,9 +116,8 @@ def run_model_test(
117116
run_summary = run_test(
118117
model,
119118
inputs,
120-
tester_factory,
119+
flow,
121120
context.test_name,
122-
context.flow_name,
123121
context.params,
124122
dynamic_shapes=dynamic_shapes,
125123
)

0 commit comments

Comments
 (0)