Skip to content

Commit 7d38c01

Browse files
committed
add
1 parent 3d6fd01 commit 7d38c01

File tree

2 files changed

+15
-1
lines changed

2 files changed

+15
-1
lines changed

coremltools/converters/mil/frontend/torch/ops.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -979,7 +979,7 @@ def gt(context, node):
979979
context.add(greater)
980980

981981

982-
@register_torch_op(torch_alias=["t", "numpy_t"])
982+
@register_torch_op(torch_alias=["t", "numpy_t", "transpose_copy"])
983983
def transpose(context, node):
984984
assert len(node.outputs) == 1
985985
inputs = _get_inputs(context, node)

coremltools/converters/mil/frontend/torch/test/test_torch_ops.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7166,6 +7166,20 @@ def test(self, compute_unit, backend, frontend, shape, dims):
71667166
)
71677167

71687168

7169+
class TestTransposeCopy(TorchBaseTest):
7170+
@pytest.mark.parametrize(
7171+
"compute_unit, backend, frontend, shape, dims",
7172+
itertools.product(
7173+
compute_units, backends, frontends, COMMON_SHAPES, [(0, 1), (-2, -1), (1, 0), (-1, -2)]
7174+
),
7175+
)
7176+
def test(self, compute_unit, backend, frontend, shape, dims):
7177+
model = ModuleWrapper(function=torch.transpose_copy, kwargs={"dim0": dims[0], "dim1": dims[1]})
7178+
self.run_compare_torch(
7179+
shape, model, compute_unit=compute_unit, backend=backend, frontend=frontend
7180+
)
7181+
7182+
71697183
class TestTo(TorchBaseTest):
71707184
@pytest.mark.parametrize(
71717185
"compute_unit, backend, frontend",

0 commit comments

Comments
 (0)