[Backend Tester] Add quantized test flows for XNNPACK and Core ML #12733
Changes from 24 commits
```diff
@@ -4,23 +4,64 @@
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.
 
-from typing import Any, List, Optional, Tuple
+from typing import Any, List, Optional, Sequence, Tuple
 
+import coremltools as ct
 import executorch
 import executorch.backends.test.harness.stages as BaseStages
 
+import functools
 import torch
 
+from executorch.backends.apple.coreml.compiler import CoreMLBackend
 from executorch.backends.apple.coreml.partition import CoreMLPartitioner
+from executorch.backends.apple.coreml.quantizer import CoreMLQuantizer
 from executorch.backends.test.harness import Tester as TesterBase
 from executorch.backends.test.harness.stages import StageType
 from executorch.exir import EdgeCompileConfig
 from executorch.exir.backend.partitioner import Partitioner
 
 
+def _get_static_int8_qconfig():
+    return ct.optimize.torch.quantization.LinearQuantizerConfig(
+        global_config=ct.optimize.torch.quantization.ModuleLinearQuantizerConfig(
+            quantization_scheme="symmetric",
+            activation_dtype=torch.quint8,
+            weight_dtype=torch.qint8,
+            weight_per_channel=True,
+        )
+    )
```

> Review comment on `quantization_scheme="symmetric"`: Is this the main int8 schema we should be testing for Linear? @metascroy
>
> Reply: FYI, this is pulled directly from our docs at https://docs.pytorch.org/executorch/main/backends-coreml.html#bit-quantization-using-the-pt2e-flow. Would be good to sanity check with Scott, though.
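For reference, the reply above points to the Core ML PT2E quantization docs. Below is a minimal sketch of how a config like this feeds that flow; the model, inputs, and calibration loop are illustrative, and the exact export/prepare entry points vary across PyTorch/ExecuTorch versions:

```python
import torch
from torch.ao.quantization.quantize_pt2e import convert_pt2e, prepare_pt2e

from executorch.backends.apple.coreml.quantizer import CoreMLQuantizer


class SmallModel(torch.nn.Module):  # illustrative model, not part of this PR
    def __init__(self):
        super().__init__()
        self.linear = torch.nn.Linear(16, 16)

    def forward(self, x):
        return self.linear(x)


model = SmallModel().eval()
example_inputs = (torch.randn(1, 16),)

# Export to an ATen graph, annotate it with the Core ML quantizer, calibrate, convert.
exported = torch.export.export_for_training(model, example_inputs).module()
quantizer = CoreMLQuantizer(_get_static_int8_qconfig())  # reuses the helper defined above
prepared = prepare_pt2e(exported, quantizer)
prepared(*example_inputs)  # one or more calibration passes
quantized = convert_pt2e(prepared)
```

The resulting `quantized` module is what a Quantize stage would hand off to the rest of the lowering pipeline.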
```diff
+class Quantize(BaseStages.Quantize):
+    def __init__(
+        self,
+        quantizer: Optional[CoreMLQuantizer] = None,
+        quantization_config: Optional[Any] = None,
+        calibrate: bool = True,
+        calibration_samples: Optional[Sequence[Any]] = None,
+        is_qat: Optional[bool] = False,
+    ):
+        super().__init__(
+            quantizer=quantizer or CoreMLQuantizer(quantization_config or _get_static_int8_qconfig()),
+            calibrate=calibrate,
+            calibration_samples=calibration_samples,
+            is_qat=is_qat,
+        )
+
+
 class Partition(BaseStages.Partition):
-    def __init__(self, partitioner: Optional[Partitioner] = None):
+    def __init__(
+        self,
+        partitioner: Optional[Partitioner] = None,
+        minimum_deployment_target: Optional[Any] = ct.target.iOS15,
+    ):
         super().__init__(
-            partitioner=partitioner or CoreMLPartitioner,
+            partitioner=partitioner or CoreMLPartitioner(
+                compile_specs=CoreMLBackend.generate_compile_specs(
+                    minimum_deployment_target=minimum_deployment_target
+                )
+            ),
         )
```
```diff
@@ -29,9 +70,14 @@ def __init__(
         self,
         partitioners: Optional[List[Partitioner]] = None,
         edge_compile_config: Optional[EdgeCompileConfig] = None,
+        minimum_deployment_target: Optional[Any] = ct.target.iOS15,
     ):
         super().__init__(
-            default_partitioner_cls=CoreMLPartitioner,
+            default_partitioner_cls=lambda: CoreMLPartitioner(
+                compile_specs=CoreMLBackend.generate_compile_specs(
+                    minimum_deployment_target=minimum_deployment_target
+                )
+            ),
             partitioners=partitioners,
             edge_compile_config=edge_compile_config,
         )
```
```diff
@@ -43,13 +89,15 @@ def __init__(
         module: torch.nn.Module,
         example_inputs: Tuple[torch.Tensor],
         dynamic_shapes: Optional[Tuple[Any]] = None,
+        minimum_deployment_target: Optional[Any] = ct.target.iOS15,
     ):
         # Specialize for XNNPACK
         stage_classes = (
             executorch.backends.test.harness.Tester.default_stage_classes()
             | {
-                StageType.PARTITION: Partition,
-                StageType.TO_EDGE_TRANSFORM_AND_LOWER: ToEdgeTransformAndLower,
+                StageType.QUANTIZE: Quantize,
+                StageType.PARTITION: functools.partial(Partition, minimum_deployment_target=minimum_deployment_target),
+                StageType.TO_EDGE_TRANSFORM_AND_LOWER: functools.partial(ToEdgeTransformAndLower, minimum_deployment_target=minimum_deployment_target),
             }
         )
```
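To make the new tester surface concrete, here is a rough usage sketch. It assumes the stage-chaining API of the base test harness; the module, inputs, and exact chain are illustrative rather than taken from this PR:

```python
import coremltools as ct
import torch

from executorch.backends.apple.coreml.test.tester import CoreMLTester

module = torch.nn.Linear(16, 16).eval()  # illustrative module under test
example_inputs = (torch.randn(1, 16),)

# iOS15 stays the default; the quantized flow pins a newer deployment target.
tester = CoreMLTester(
    module,
    example_inputs,
    minimum_deployment_target=ct.target.iOS17,
)

# Stage names come from the base harness; with no argument, the Quantize stage
# falls back to the static int8 config defined above.
(
    tester.quantize()
    .export()
    .to_edge_transform_and_lower()
    .to_executorch()
    .run_method_and_compare_outputs()
)
```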
```diff
@@ -1,9 +1,10 @@
 import logging
 
-from dataclasses import dataclass
+from dataclasses import dataclass, field
 from typing import Callable
 
 from executorch.backends.test.harness import Tester
+from executorch.backends.test.harness.stages import Quantize
 
 logger = logging.getLogger(__name__)
 logger.setLevel(logging.INFO)
@@ -21,42 +22,35 @@ class TestFlow:
 
     backend: str
     """ The name of the target backend. """
 
-    tester_factory: Callable[[], Tester]
+    tester_factory: Callable[..., Tester]
     """ A factory function that returns a Tester instance for this lowering flow. """
 
+    quantize: bool = field(default=False)
+    """ Whether the tester should run the quantize stage on the model. """
+
+    quantize_stage_factory: Callable[..., Quantize] | None = None
+    """ A factory function which instantiates a Quantize stage. Can be None to use the tester's default. """
```

> Review comment on `quantize_stage_factory`: Why an extra flag?
>
> Reply: The specific reason is that if `quantize_stage_factory` isn't provided, it will use the default Quantize stage from the tester. I could maybe just always require the caller to provide `quantize_stage_factory`.
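To illustrate the fallback described in that reply, a hypothetical runner could pass an explicit stage only when the flow defines a factory, and otherwise let the tester construct its backend-default Quantize stage. The `run_flow` helper below is a sketch, not the suite's actual runner:

```python
import torch

from executorch.backends.test.suite.flow import TestFlow


def run_flow(flow: TestFlow, module: torch.nn.Module, example_inputs) -> None:
    tester = flow.tester_factory(module, example_inputs)
    if flow.quantize:
        if flow.quantize_stage_factory is not None:
            # The flow supplies its own Quantize stage (e.g. XNNPACK symmetric int8).
            tester.quantize(flow.quantize_stage_factory())
        else:
            # Fall back to the tester's default Quantize stage for this backend.
            tester.quantize()
    tester.export().to_edge_transform_and_lower().to_executorch()
```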
```diff
-def create_xnnpack_flow() -> TestFlow | None:
+def all_flows() -> dict[str, TestFlow]:
+    flows = []
+
     try:
-        from executorch.backends.xnnpack.test.tester import Tester as XnnpackTester
-
-        return TestFlow(
-            name="xnnpack",
-            backend="xnnpack",
-            tester_factory=XnnpackTester,
-        )
-    except Exception:
-        logger.info("Skipping XNNPACK flow registration due to import failure.")
-        return None
-
+        from executorch.backends.test.suite.flows.xnnpack import XNNPACK_TEST_FLOW, XNNPACK_STATIC_INT8_TEST_FLOW
+
+        flows += [
+            XNNPACK_TEST_FLOW,
+            XNNPACK_STATIC_INT8_TEST_FLOW,
+        ]
+    except Exception as e:
+        logger.info(f"Skipping XNNPACK flow registration: {e}")
 
-def create_coreml_flow() -> TestFlow | None:
     try:
-        from executorch.backends.apple.coreml.test.tester import CoreMLTester
+        from executorch.backends.test.suite.flows.coreml import COREML_TEST_FLOW, COREML_STATIC_INT8_TEST_FLOW
+        flows += [
+            COREML_TEST_FLOW,
+            COREML_STATIC_INT8_TEST_FLOW,
+        ]
+    except Exception as e:
+        logger.info(f"Skipping Core ML flow registration: {e}")
 
-        return TestFlow(
-            name="coreml",
-            backend="coreml",
-            tester_factory=CoreMLTester,
-        )
-    except Exception:
-        logger.info("Skipping Core ML flow registration due to import failure.")
-        return None
-
-
-def all_flows() -> dict[str, TestFlow]:
-    flows = [
-        create_xnnpack_flow(),
-        create_coreml_flow(),
-    ]
     return {f.name: f for f in flows if f is not None}
```
@@ -0,0 +1,7 @@ (new file)

```python
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# pyre-unsafe
```
@@ -0,0 +1,24 @@ (new file)

```python
import coremltools
import functools

from executorch.backends.apple.coreml.test.tester import CoreMLTester
from executorch.backends.test.suite.flow import TestFlow
from typing import Any

def _create_coreml_flow(
    name: str,
    quantize: bool = False,
    minimum_deployment_target: Any = coremltools.target.iOS15
) -> TestFlow:
    return TestFlow(
        name,
        backend="coreml",
        tester_factory=functools.partial(CoreMLTester, minimum_deployment_target=minimum_deployment_target),
        quantize=quantize,
    )

COREML_TEST_FLOW = _create_coreml_flow("coreml")
COREML_STATIC_INT8_TEST_FLOW = _create_coreml_flow(
    "coreml_static_int8",
    quantize=True,
    minimum_deployment_target=coremltools.target.iOS17)
```
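Assuming both backends import cleanly, these instances end up keyed by name in `all_flows()`. The lookup below is a sketch of how a runner might pick one up; the module and inputs are placeholders:

```python
import torch

from executorch.backends.test.suite.flow import all_flows

flows = all_flows()
# Expected keys when registration succeeds:
# "xnnpack", "xnnpack_static_int8", "coreml", "coreml_static_int8"

flow = flows["coreml_static_int8"]
module = torch.nn.Linear(8, 8).eval()  # placeholder module
example_inputs = (torch.randn(1, 8),)

# tester_factory is partially applied with minimum_deployment_target=iOS17.
tester = flow.tester_factory(module, example_inputs)
```

Pinning the quantized flow to iOS17 appears to reflect that Core ML only supports int8 activation quantization on newer deployment targets.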
@@ -0,0 +1,36 @@ (new file)

```python
from executorch.backends.test.harness.stages import Quantize
from executorch.backends.test.suite.flow import TestFlow
from executorch.backends.xnnpack.quantizer.xnnpack_quantizer import get_symmetric_quantization_config
from executorch.backends.xnnpack.test.tester import (
    Quantize as XnnpackQuantize,
    Tester as XnnpackTester
)
from typing import Callable

import logging

logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)

def _create_xnnpack_flow_base(name: str, quantize_stage_factory: Callable[..., Quantize] | None = None) -> TestFlow:
    return TestFlow(
        name,
        backend="xnnpack",
        tester_factory=XnnpackTester,
        quantize=quantize_stage_factory is not None,
        quantize_stage_factory=quantize_stage_factory,
    )

def _create_xnnpack_flow() -> TestFlow:
    return _create_xnnpack_flow_base("xnnpack")

def _create_xnnpack_static_int8_flow() -> TestFlow:
    def create_quantize_stage() -> Quantize:
        qparams = get_symmetric_quantization_config(is_per_channel=True)
        return XnnpackQuantize(
            quantization_config=qparams,
        )
    return _create_xnnpack_flow_base("xnnpack_static_int8", create_quantize_stage)

XNNPACK_TEST_FLOW = _create_xnnpack_flow()
XNNPACK_STATIC_INT8_TEST_FLOW = _create_xnnpack_static_int8_flow()
```
> Review comment: come on :p
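For completeness, a rough sketch of how the static int8 flow's stage factory would be consumed, following the XNNPACK tester's usual chaining; the module and the exact chain are illustrative:

```python
import torch

from executorch.backends.test.suite.flows.xnnpack import XNNPACK_STATIC_INT8_TEST_FLOW

flow = XNNPACK_STATIC_INT8_TEST_FLOW
module = torch.nn.Linear(16, 16).eval()  # illustrative module
example_inputs = (torch.randn(1, 16),)

tester = flow.tester_factory(module, example_inputs)

# The flow provides its own Quantize stage (symmetric, per-channel int8), so a
# runner passes it explicitly instead of relying on the tester's default.
(
    tester.quantize(flow.quantize_stage_factory())
    .export()
    .to_edge_transform_and_lower()
    .to_executorch()
    .run_method_and_compare_outputs()
)
```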