Skip to content

Commit 9f99e94

Browse files
committed
Add u8 tests
1 parent 1d682f3 commit 9f99e94

File tree

2 files changed

+45
-1
lines changed

(summary repeated above: 2 files changed, +45 −1 lines)

src/frontends/pytorch/src/op/where.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@ OutputVector translate_where(const NodeContext& context) {
3333
OutputVector result;
3434
if (ndim > 0) {
3535
auto split = context.mark_node(std::make_shared<v1::Split>(non_zero, axis, ndim));
36-
for (size_t i = 0; i < ndim; ++i) {
36+
for (int64_t i = 0; i < ndim; ++i) {
3737
result.push_back(context.mark_node(std::make_shared<v0::Squeeze>(split->output(i), axis)));
3838
}
3939
}

tests/layer_tests/pytorch_tests/test_where.py

Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -46,7 +46,12 @@ class aten_where_as_nonzero(torch.nn.Module):
4646
def forward(self, cond):
4747
return torch.where(cond)
4848

49+
class aten_where_as_nonzero_getitem(torch.nn.Module):
    """Wraps single-argument torch.where and returns only the first index tensor."""

    def forward(self, cond: torch.Tensor):
        # Single-argument where behaves like nonzero(): it yields a tuple of
        # per-dimension index tensors; keep just the first dimension's indices.
        index_tensors = torch.where(cond)
        return index_tensors[0]
4952

53+
if as_non_zero == 'scripted':
54+
return aten_where_as_nonzero_getitem(), "aten::where"
5055
if as_non_zero:
5156
return aten_where_as_nonzero(), "aten::where"
5257
return aten_where(torch_dtypes), "aten::where"
@@ -91,6 +96,24 @@ def test_where_as_nonzero(self, mask_fill, mask_dtype, x_dtype, ie_device, preci
9196
},
9297
trace_model=True)
9398

99+
@pytest.mark.parametrize("mask_fill", ['zeros', 'ones', 'random'])
@pytest.mark.parametrize("cond_dtype", [np.float32, np.int32])
@pytest.mark.nightly
@pytest.mark.precommit
def test_where_as_nonzero_nonbool_cond(self, mask_fill, cond_dtype, ie_device, precision, ir_version):
    """aten::where(cond) must accept non-boolean condition tensors (float32, int32).

    NonZero treats zero as False and any nonzero value as True, matching
    PyTorch semantics for the single-argument form.
    """
    input_kwargs = {
        'mask_fill': mask_fill,
        'mask_dtype': cond_dtype,
        'return_x_y': False,
        'x_dtype': "float32",
    }
    self._test(*self.create_model(True),
               ie_device, precision, ir_version,
               kwargs_to_prepare_input=input_kwargs,
               trace_model=True)
116+
94117
@pytest.mark.parametrize(
95118
"mask_fill", ['zeros', 'ones', 'random'])
96119
@pytest.mark.parametrize("mask_dtype", [bool])
@@ -108,3 +131,24 @@ def test_where_as_nonzero_export(self, mask_fill, mask_dtype, x_dtype, ie_device
108131
'return_x_y': False,
109132
"x_dtype": x_dtype,
110133
})
134+
135+
@pytest.mark.parametrize("mask_fill", ['zeros', 'ones', 'random'])
136+
@pytest.mark.parametrize("x_dtype", ["float32", "int32"])
137+
@pytest.mark.nightly
138+
@pytest.mark.precommit
139+
def test_where_as_nonzero_scripted(self, mask_fill, x_dtype, ie_device, precision, ir_version):
140+
# Tests translate_where's input_is_none(1) branch via torch.jit.script.
141+
# The model returns torch.where(cond)[0], producing aten::where → aten::__getitem__
142+
self._test(*self.create_model('scripted'),
143+
ie_device, precision, ir_version,
144+
kwargs_to_prepare_input={
145+
'mask_fill': mask_fill,
146+
'mask_dtype': bool,
147+
'return_x_y': False,
148+
'x_dtype': x_dtype,
149+
},
150+
trace_model=False,
151+
dynamic_shapes=False,
152+
use_convert_model=True,
153+
freeze_model=False)
154+

0 commit comments

Comments
 (0)