diff --git a/backends/apple/coreml/compiler/coreml_preprocess.py b/backends/apple/coreml/compiler/coreml_preprocess.py index e9afd819d94..bf390698705 100644 --- a/backends/apple/coreml/compiler/coreml_preprocess.py +++ b/backends/apple/coreml/compiler/coreml_preprocess.py @@ -28,6 +28,8 @@ logger = logging.getLogger(__name__) logger.setLevel(logging.WARNING) +from executorch.backends.apple.coreml.compiler.torch_ops import * # noqa: F401, F403 + class COMPILE_SPEC_KEYS(Enum): COMPUTE_UNITS = "compute_units" diff --git a/backends/apple/coreml/compiler/torch_ops.py b/backends/apple/coreml/compiler/torch_ops.py new file mode 100644 index 00000000000..32facd7cd61 --- /dev/null +++ b/backends/apple/coreml/compiler/torch_ops.py @@ -0,0 +1,26 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +# This file registers torch ops that are not yet in coremltools, or are in a more recent version of +# coremltools than is used by ExecuTorch. Each op registered here should have a link to a PR in coremltools that adds +# the op to the coremltools library. + +from coremltools.converters.mil.frontend.torch.ops import transpose, unbind +from coremltools.converters.mil.frontend.torch.torch_op_registry import ( + register_torch_op, +) + + +# https://github.com/apple/coremltools/pull/2556 +@register_torch_op(override=False) +def transpose_copy(context, node): + transpose(context, node) + + +# https://github.com/apple/coremltools/pull/2557 +@register_torch_op(override=False) +def unbind_copy(context, node): + unbind(context, node)