Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 10 additions & 3 deletions torchrec/pt2/checks.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@

import torch

from torch.fx.experimental.symbolic_shapes import guard_size_oblivious
from torch.fx.experimental.symbolic_shapes import guard_or_false, guard_or_true

USE_TORCHDYNAMO_COMPILING_PATH: bool = False

Expand Down Expand Up @@ -91,8 +91,15 @@ def pt2_check_size_nonzero(x: torch.Tensor) -> torch.Tensor:
return x


def pt2_guard_size_oblivious(x: bool) -> bool:
def pt2_guard_or_false(x: bool) -> bool:
if torch.jit.is_scripting() or not is_pt2_compiling():
return x

return guard_size_oblivious(x)
return guard_or_false(x)


def pt2_guard_or_true(x: bool) -> bool:
if torch.jit.is_scripting() or not is_pt2_compiling():
return x

return guard_or_true(x)
9 changes: 5 additions & 4 deletions torchrec/sparse/jagged_tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,8 @@
pt2_check_size_nonzero,
pt2_checks_all_is_size,
pt2_checks_tensor_slice,
pt2_guard_size_oblivious,
pt2_guard_or_false,
pt2_guard_or_true,
)
from torchrec.streamable import Pipelineable

Expand Down Expand Up @@ -1071,7 +1072,7 @@ def _assert_tensor_has_no_elements_or_has_integers(
# TODO(ivankobzarev): Use guard_size_oblivious to pass tensor.numel() == 0 once it is torch scriptable.
return

assert pt2_guard_size_oblivious(tensor.numel() == 0) or tensor.dtype in [
assert pt2_guard_or_false(tensor.numel() == 0) or tensor.dtype in [
torch.long,
torch.int,
torch.short,
Expand Down Expand Up @@ -1206,7 +1207,7 @@ def _maybe_compute_length_per_key(
torch.sum(
pt2_check_size_nonzero(lengths.view(len(keys), stride)), dim=1
).tolist()
if pt2_guard_size_oblivious(lengths.numel() != 0)
if pt2_guard_or_true(lengths.numel() != 0)
else [0] * len(keys)
)
)
Expand Down Expand Up @@ -1425,7 +1426,7 @@ def _maybe_compute_kjt_to_jt_dict(
torch.ops.fbgemm.asynchronous_complete_cumsum(lengths)
for lengths in split_lengths
]
elif pt2_guard_size_oblivious(lengths.numel() > 0):
elif pt2_guard_or_true(lengths.numel() > 0):
strided_lengths = lengths.view(len(keys), stride)
if not torch.jit.is_scripting() and is_torchdynamo_compiling():
torch._check(strided_lengths.size(0) > 0)
Expand Down