From 23f09daaf56727d0d9f37bc407fef84620a9eafe Mon Sep 17 00:00:00 2001 From: Chen Lai Date: Thu, 17 Jul 2025 17:15:32 -0700 Subject: [PATCH] Add support for aten.min.dim (#12623) Summary: Rollback Plan: Differential Revision: D78526966 --- backends/qualcomm/builders/op_min.py | 2 +- backends/qualcomm/tests/models.py | 9 +++++++++ backends/qualcomm/tests/test_qnn_delegate.py | 5 +++++ 3 files changed, 15 insertions(+), 1 deletion(-) diff --git a/backends/qualcomm/builders/op_min.py b/backends/qualcomm/builders/op_min.py index 28c766cffb5..64e41a99464 100644 --- a/backends/qualcomm/builders/op_min.py +++ b/backends/qualcomm/builders/op_min.py @@ -16,7 +16,7 @@ @register_node_visitor class Min(NodeVisitor): - target = ["aten.minimum.default"] + target = ["aten.minimum.default", "aten.min.dim"] def __init__(self, *args) -> None: super().__init__(*args) diff --git a/backends/qualcomm/tests/models.py b/backends/qualcomm/tests/models.py index 6e396c69c54..4843e1abd37 100644 --- a/backends/qualcomm/tests/models.py +++ b/backends/qualcomm/tests/models.py @@ -1162,6 +1162,15 @@ def forward(self, x, y): return torch.minimum(x, y) +class MinDim(torch.nn.Module): + def __init__(self): + super().__init__() + + def forward(self, logits): + min_logits, min_indices = torch.min(logits, dim=1) + return min_logits, min_indices + + class Mul(torch.nn.Module): def __init__(self): super().__init__() diff --git a/backends/qualcomm/tests/test_qnn_delegate.py b/backends/qualcomm/tests/test_qnn_delegate.py index 444bb10c74f..f8d8eaa5964 100644 --- a/backends/qualcomm/tests/test_qnn_delegate.py +++ b/backends/qualcomm/tests/test_qnn_delegate.py @@ -884,6 +884,11 @@ def test_qnn_backend_minimum(self): sample_input = (torch.randn(1, 2, 3, 4), torch.randn(2, 3, 4)) self.lower_module_and_test_output(module, sample_input) + def test_qnn_backend_min_dim(self): + module = MinDim() # noqa: F405 + sample_input = (torch.randn(4, 10), ) + self.lower_module_and_test_output(module, sample_input) + def test_qnn_backend_neg(self): module = Neg() # noqa: F405 sample_input = (torch.randn(1, 4, 16, 16),)