Skip to content

Commit 1c83f61

Browse files
cccclaiDannyYuyang-quic
authored andcommitted
Add support for aten.min.dim (pytorch#12623)
Summary: Rollback Plan: Differential Revision: D78526966
1 parent b7eff08 commit 1c83f61

File tree

3 files changed

+15
-1
lines changed

3 files changed

+15
-1
lines changed

backends/qualcomm/builders/op_min.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616

1717
@register_node_visitor
1818
class Min(NodeVisitor):
19-
target = ["aten.minimum.default"]
19+
target = ["aten.minimum.default", "aten.min.dim"]
2020

2121
def __init__(self, *args) -> None:
2222
super().__init__(*args)

backends/qualcomm/tests/models.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1162,6 +1162,15 @@ def forward(self, x, y):
11621162
return torch.minimum(x, y)
11631163

11641164

1165+
class MinDim(torch.nn.Module):
1166+
def __init__(self):
1167+
super().__init__()
1168+
1169+
def forward(self, logits):
1170+
min_logits, min_indices = torch.min(logits, dim=1)
1171+
return min_logits, min_indices
1172+
1173+
11651174
class Mul(torch.nn.Module):
11661175
def __init__(self):
11671176
super().__init__()

backends/qualcomm/tests/test_qnn_delegate.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -884,6 +884,11 @@ def test_qnn_backend_minimum(self):
884884
sample_input = (torch.randn(1, 2, 3, 4), torch.randn(2, 3, 4))
885885
self.lower_module_and_test_output(module, sample_input)
886886

887+
def test_qnn_backend_min_dim(self):
888+
module = MinDim() # noqa: F405
889+
sample_input = (torch.randn(4, 10), )
890+
self.lower_module_and_test_output(module, sample_input)
891+
887892
def test_qnn_backend_neg(self):
888893
module = Neg() # noqa: F405
889894
sample_input = (torch.randn(1, 4, 16, 16),)

0 commit comments

Comments
 (0)