diff --git a/backends/qualcomm/builders/op_min.py b/backends/qualcomm/builders/op_min.py index 28c766cffb5..64e41a99464 100644 --- a/backends/qualcomm/builders/op_min.py +++ b/backends/qualcomm/builders/op_min.py @@ -16,7 +16,7 @@ @register_node_visitor class Min(NodeVisitor): - target = ["aten.minimum.default"] + target = ["aten.minimum.default", "aten.min.dim"] def __init__(self, *args) -> None: super().__init__(*args) diff --git a/backends/qualcomm/tests/models.py b/backends/qualcomm/tests/models.py index 6e396c69c54..4843e1abd37 100644 --- a/backends/qualcomm/tests/models.py +++ b/backends/qualcomm/tests/models.py @@ -1162,6 +1162,15 @@ def forward(self, x, y): return torch.minimum(x, y) +class MinDim(torch.nn.Module): + def __init__(self): + super().__init__() + + def forward(self, logits): + min_logits, min_indices = torch.min(logits, dim=1) + return min_logits, min_indices + + class Mul(torch.nn.Module): def __init__(self): super().__init__() diff --git a/backends/qualcomm/tests/test_qnn_delegate.py b/backends/qualcomm/tests/test_qnn_delegate.py index 444bb10c74f..f8d8eaa5964 100644 --- a/backends/qualcomm/tests/test_qnn_delegate.py +++ b/backends/qualcomm/tests/test_qnn_delegate.py @@ -884,6 +884,11 @@ def test_qnn_backend_minimum(self): sample_input = (torch.randn(1, 2, 3, 4), torch.randn(2, 3, 4)) self.lower_module_and_test_output(module, sample_input) + def test_qnn_backend_min_dim(self): + module = MinDim() # noqa: F405 + sample_input = (torch.randn(4, 10), ) + self.lower_module_and_test_output(module, sample_input) + def test_qnn_backend_neg(self): module = Neg() # noqa: F405 sample_input = (torch.randn(1, 4, 16, 16),)