@@ -190,6 +190,11 @@ def generate_expected_address_lookup_buffer(
190
190
191
191
return torch .tensor (address_lookup , dtype = torch .int64 )
192
192
193
+ # pyre-ignore[56]: Pyre was not able to infer the type of argument
194
+ @unittest .skipIf (
195
+ torch .cuda .device_count () <= 1 ,
196
+ "Not enough GPUs, this test requires at least two GPUs" ,
197
+ )
193
198
def test_init_itep_module (self ) -> None :
194
199
itep_module = GenericITEPModule (
195
200
table_name_to_unpruned_hash_sizes = self ._table_name_to_unpruned_hash_sizes ,
@@ -222,6 +227,11 @@ def test_init_itep_module(self) -> None:
222
227
equal_nan = True ,
223
228
)
224
229
230
+ # pyre-ignore[56]: Pyre was not able to infer the type of argument
231
+ @unittest .skipIf (
232
+ torch .cuda .device_count () <= 1 ,
233
+ "Not enough GPUs, this test requires at least two GPUs" ,
234
+ )
225
235
def test_init_itep_module_without_pruned_table (self ) -> None :
226
236
itep_module = GenericITEPModule (
227
237
table_name_to_unpruned_hash_sizes = {},
@@ -353,6 +363,11 @@ def test_eval_forward(
353
363
# Check that reset_weight_momentum is not called
354
364
self .assertEqual (mock_reset_weight_momentum .call_count , 0 )
355
365
366
+ # pyre-ignore[56]: Pyre was not able to infer the type of argument
367
+ @unittest .skipIf (
368
+ torch .cuda .device_count () <= 1 ,
369
+ "Not enough GPUs, this test requires at least two GPUs" ,
370
+ )
356
371
def test_iter_increment_per_forward (self ) -> None :
357
372
"""Test that the iteration counter increments correctly with each forward pass."""
358
373
itep_module = GenericITEPModule (
0 commit comments