@@ -46,7 +46,12 @@ class aten_where_as_nonzero(torch.nn.Module):
4646 def forward (self , cond ):
4747 return torch .where (cond )
4848
class aten_where_as_nonzero_getitem(torch.nn.Module):
    """Single-argument torch.where followed by indexing the result tuple.

    torch.where(cond) (no x/y operands) returns a tuple of index tensors,
    one per dimension of ``cond``; this module returns only the first one.
    """

    def forward(self, cond: torch.Tensor):
        # torch.where with one argument yields a tuple of coordinate tensors.
        index_tuple = torch.where(cond)
        return index_tuple[0]
4952
53+ if as_non_zero == 'scripted' :
54+ return aten_where_as_nonzero_getitem (), "aten::where"
5055 if as_non_zero :
5156 return aten_where_as_nonzero (), "aten::where"
5257 return aten_where (torch_dtypes ), "aten::where"
@@ -91,6 +96,24 @@ def test_where_as_nonzero(self, mask_fill, mask_dtype, x_dtype, ie_device, preci
9196 },
9297 trace_model = True )
9398
99+ @pytest .mark .parametrize (
100+ "mask_fill" , ['zeros' , 'ones' , 'random' ])
101+ @pytest .mark .parametrize ("cond_dtype" , [np .float32 , np .int32 ])
102+ @pytest .mark .nightly
103+ @pytest .mark .precommit
104+ def test_where_as_nonzero_nonbool_cond (self , mask_fill , cond_dtype , ie_device , precision , ir_version ):
105+ # aten::where(cond) must accept non-boolean condition tensors (float32, int32).
106+ # NonZero treats zero as False and any nonzero value as True, matching PyTorch semantics.
107+ self ._test (* self .create_model (True ),
108+ ie_device , precision , ir_version ,
109+ kwargs_to_prepare_input = {
110+ 'mask_fill' : mask_fill ,
111+ 'mask_dtype' : cond_dtype ,
112+ 'return_x_y' : False ,
113+ 'x_dtype' : "float32" ,
114+ },
115+ trace_model = True )
116+
94117 @pytest .mark .parametrize (
95118 "mask_fill" , ['zeros' , 'ones' , 'random' ])
96119 @pytest .mark .parametrize ("mask_dtype" , [bool ])
@@ -108,3 +131,24 @@ def test_where_as_nonzero_export(self, mask_fill, mask_dtype, x_dtype, ie_device
108131 'return_x_y' : False ,
109132 "x_dtype" : x_dtype ,
110133 })
134+
135+ @pytest .mark .parametrize ("mask_fill" , ['zeros' , 'ones' , 'random' ])
136+ @pytest .mark .parametrize ("x_dtype" , ["float32" , "int32" ])
137+ @pytest .mark .nightly
138+ @pytest .mark .precommit
139+ def test_where_as_nonzero_scripted (self , mask_fill , x_dtype , ie_device , precision , ir_version ):
140+ # Tests translate_where's input_is_none(1) branch via torch.jit.script.
141+ # The model returns torch.where(cond)[0], producing aten::where → aten::__getitem__
142+ self ._test (* self .create_model ('scripted' ),
143+ ie_device , precision , ir_version ,
144+ kwargs_to_prepare_input = {
145+ 'mask_fill' : mask_fill ,
146+ 'mask_dtype' : bool ,
147+ 'return_x_y' : False ,
148+ 'x_dtype' : x_dtype ,
149+ },
150+ trace_model = False ,
151+ dynamic_shapes = False ,
152+ use_convert_model = True ,
153+ freeze_model = False )
154+
0 commit comments