diff --git a/torchrec/modules/tests/test_itep_embedding_modules.py b/torchrec/modules/tests/test_itep_embedding_modules.py index fc089f631..1f1a4ee26 100644 --- a/torchrec/modules/tests/test_itep_embedding_modules.py +++ b/torchrec/modules/tests/test_itep_embedding_modules.py @@ -190,6 +190,11 @@ def generate_expected_address_lookup_buffer( return torch.tensor(address_lookup, dtype=torch.int64) + # pyre-ignore[56]: Pyre was not able to infer the type of argument + @unittest.skipIf( + torch.cuda.device_count() <= 1, + "Not enough GPUs, this test requires at least two GPUs", + ) def test_init_itep_module(self) -> None: itep_module = GenericITEPModule( table_name_to_unpruned_hash_sizes=self._table_name_to_unpruned_hash_sizes, @@ -222,6 +227,11 @@ def test_init_itep_module(self) -> None: equal_nan=True, ) + # pyre-ignore[56]: Pyre was not able to infer the type of argument + @unittest.skipIf( + torch.cuda.device_count() <= 1, + "Not enough GPUs, this test requires at least two GPUs", + ) def test_init_itep_module_without_pruned_table(self) -> None: itep_module = GenericITEPModule( table_name_to_unpruned_hash_sizes={}, @@ -353,6 +363,11 @@ def test_eval_forward( # Check that reset_weight_momentum is not called self.assertEqual(mock_reset_weight_momentum.call_count, 0) + # pyre-ignore[56]: Pyre was not able to infer the type of argument + @unittest.skipIf( + torch.cuda.device_count() <= 1, + "Not enough GPUs, this test requires at least two GPUs", + ) def test_iter_increment_per_forward(self) -> None: """Test that the iteration counter increments correctly with each forward pass.""" itep_module = GenericITEPModule(