Skip to content

Commit a580e35

Browse files
jd7-trfacebook-github-bot
authored andcommitted
Replace int(..) with torch.sym_int(...) for IR export compatibility (#3133)
Summary: Pull Request resolved: #3133 int(..) is not PT2 IR compatible Reviewed By: TroyGarden Differential Revision: D77195403 fbshipit-source-id: 7427561fc495815c0276220ffc85da1799c32e46
1 parent 4091d7d commit a580e35

File tree

1 file changed

+6
-1
lines changed

1 file changed

+6
-1
lines changed

torchrec/sparse/jagged_tensor.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1102,7 +1102,12 @@ def _maybe_compute_stride_kjt(
11021102
elif (
11031103
stride_per_key_per_rank is not None and stride_per_key_per_rank.numel() > 0
11041104
):
1105-
stride = int(stride_per_key_per_rank.sum(dim=1).max().item())
1105+
s = stride_per_key_per_rank.sum(dim=1).max().item()
1106+
if not torch.jit.is_scripting() and is_non_strict_exporting():
1107+
stride = torch.sym_int(s)
1108+
else:
1109+
stride = int(s)
1110+
11061111
elif offsets is not None and offsets.numel() > 0:
11071112
stride = (offsets.numel() - 1) // len(keys)
11081113
elif lengths is not None:

0 commit comments

Comments
 (0)