Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 15 additions & 0 deletions torchrec/modules/tests/test_itep_embedding_modules.py
Original file line number Diff line number Diff line change
Expand Up @@ -190,6 +190,11 @@ def generate_expected_address_lookup_buffer(

return torch.tensor(address_lookup, dtype=torch.int64)

# pyre-ignore[56]: Pyre was not able to infer the type of argument
@unittest.skipIf(
torch.cuda.device_count() <= 1,
"Not enough GPUs, this test requires at least two GPUs",
)
def test_init_itep_module(self) -> None:
itep_module = GenericITEPModule(
table_name_to_unpruned_hash_sizes=self._table_name_to_unpruned_hash_sizes,
Expand Down Expand Up @@ -222,6 +227,11 @@ def test_init_itep_module(self) -> None:
equal_nan=True,
)

# pyre-ignore[56]: Pyre was not able to infer the type of argument
@unittest.skipIf(
torch.cuda.device_count() <= 1,
"Not enough GPUs, this test requires at least two GPUs",
)
def test_init_itep_module_without_pruned_table(self) -> None:
itep_module = GenericITEPModule(
table_name_to_unpruned_hash_sizes={},
Expand Down Expand Up @@ -353,6 +363,11 @@ def test_eval_forward(
# Check that reset_weight_momentum is not called
self.assertEqual(mock_reset_weight_momentum.call_count, 0)

# pyre-ignore[56]: Pyre was not able to infer the type of argument
@unittest.skipIf(
torch.cuda.device_count() <= 1,
"Not enough GPUs, this test requires at least two GPUs",
)
def test_iter_increment_per_forward(self) -> None:
"""Test that the iteration counter increments correctly with each forward pass."""
itep_module = GenericITEPModule(
Expand Down
Loading