From 6e8e6c001fde8aa1a477ca7e73843f484bbcba71 Mon Sep 17 00:00:00 2001 From: Laith Sakka Date: Tue, 7 Oct 2025 20:56:15 -0700 Subject: [PATCH] remove guard_size_oblivious from torchrec jagged tensors. (#3431) Summary: keep intended semantics but use guard_or invariants. guard_size_oblivious will be deprecated soon. Reviewed By: TroyGarden Differential Revision: D83885644 --- torchrec/pt2/checks.py | 13 ++++++++++--- torchrec/sparse/jagged_tensor.py | 9 +++++---- 2 files changed, 15 insertions(+), 7 deletions(-) diff --git a/torchrec/pt2/checks.py b/torchrec/pt2/checks.py index 76626a9f8..e997d39d2 100644 --- a/torchrec/pt2/checks.py +++ b/torchrec/pt2/checks.py @@ -11,7 +11,7 @@ import torch -from torch.fx.experimental.symbolic_shapes import guard_size_oblivious +from torch.fx.experimental.symbolic_shapes import guard_or_false, guard_or_true USE_TORCHDYNAMO_COMPILING_PATH: bool = False @@ -91,8 +91,15 @@ def pt2_check_size_nonzero(x: torch.Tensor) -> torch.Tensor: return x -def pt2_guard_size_oblivious(x: bool) -> bool: +def pt2_guard_or_false(x: bool) -> bool: if torch.jit.is_scripting() or not is_pt2_compiling(): return x - return guard_size_oblivious(x) + return guard_or_false(x) + + +def pt2_guard_or_true(x: bool) -> bool: + if torch.jit.is_scripting() or not is_pt2_compiling(): + return x + + return guard_or_true(x) diff --git a/torchrec/sparse/jagged_tensor.py b/torchrec/sparse/jagged_tensor.py index db1a26aba..0c1905387 100644 --- a/torchrec/sparse/jagged_tensor.py +++ b/torchrec/sparse/jagged_tensor.py @@ -25,7 +25,8 @@ pt2_check_size_nonzero, pt2_checks_all_is_size, pt2_checks_tensor_slice, - pt2_guard_size_oblivious, + pt2_guard_or_false, + pt2_guard_or_true, ) from torchrec.streamable import Pipelineable @@ -1071,7 +1072,7 @@ def _assert_tensor_has_no_elements_or_has_integers( # TODO(ivankobzarev): Use guard_size_oblivious to pass tensor.numel() == 0 once it is torch scriptable. return - assert pt2_guard_size_oblivious(tensor.numel() == 0) or tensor.dtype in [ + assert pt2_guard_or_false(tensor.numel() == 0) or tensor.dtype in [ torch.long, torch.int, torch.short, @@ -1206,7 +1207,7 @@ def _maybe_compute_length_per_key( torch.sum( pt2_check_size_nonzero(lengths.view(len(keys), stride)), dim=1 ).tolist() - if pt2_guard_size_oblivious(lengths.numel() != 0) + if pt2_guard_or_true(lengths.numel() != 0) else [0] * len(keys) ) ) @@ -1425,7 +1426,7 @@ def _maybe_compute_kjt_to_jt_dict( torch.ops.fbgemm.asynchronous_complete_cumsum(lengths) for lengths in split_lengths ] - elif pt2_guard_size_oblivious(lengths.numel() > 0): + elif pt2_guard_or_true(lengths.numel() > 0): strided_lengths = lengths.view(len(keys), stride) if not torch.jit.is_scripting() and is_torchdynamo_compiling(): torch._check(strided_lengths.size(0) > 0)