Skip to content

Commit a41cad5

Browse files
Fei Yufacebook-github-bot
authored andcommitted
add back skip CPU containter test when using cuda device (#3385)
Summary: as title, missed few tests without skip decorator causing CI CPU unit test to fail, somehow previous diff didn't reveal these ones that actually failed on CI, so adding these back separately. (https://www.internalfb.com/diff/D67302872?dst_version_fbid=815631520895451&transaction_fbid=1481902306176330) Reviewed By: spmex Differential Revision: D82771497
1 parent aeb3ffb commit a41cad5

File tree

1 file changed

+15
-0
lines changed

1 file changed

+15
-0
lines changed

torchrec/modules/tests/test_itep_embedding_modules.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -190,6 +190,11 @@ def generate_expected_address_lookup_buffer(
190190

191191
return torch.tensor(address_lookup, dtype=torch.int64)
192192

193+
# pyre-ignore[56]: Pyre was not able to infer the type of argument
194+
@unittest.skipIf(
195+
torch.cuda.device_count() <= 1,
196+
"Not enough GPUs, this test requires at least two GPUs",
197+
)
193198
def test_init_itep_module(self) -> None:
194199
itep_module = GenericITEPModule(
195200
table_name_to_unpruned_hash_sizes=self._table_name_to_unpruned_hash_sizes,
@@ -222,6 +227,11 @@ def test_init_itep_module(self) -> None:
222227
equal_nan=True,
223228
)
224229

230+
# pyre-ignore[56]: Pyre was not able to infer the type of argument
231+
@unittest.skipIf(
232+
torch.cuda.device_count() <= 1,
233+
"Not enough GPUs, this test requires at least two GPUs",
234+
)
225235
def test_init_itep_module_without_pruned_table(self) -> None:
226236
itep_module = GenericITEPModule(
227237
table_name_to_unpruned_hash_sizes={},
@@ -353,6 +363,11 @@ def test_eval_forward(
353363
# Check that reset_weight_momentum is not called
354364
self.assertEqual(mock_reset_weight_momentum.call_count, 0)
355365

366+
# pyre-ignore[56]: Pyre was not able to infer the type of argument
367+
@unittest.skipIf(
368+
torch.cuda.device_count() <= 1,
369+
"Not enough GPUs, this test requires at least two GPUs",
370+
)
356371
def test_iter_increment_per_forward(self) -> None:
357372
"""Test that the iteration counter increments correctly with each forward pass."""
358373
itep_module = GenericITEPModule(

0 commit comments

Comments
 (0)