Skip to content

Revert D77709565 #3173

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
36 changes: 1 addition & 35 deletions torchrec/sparse/jagged_tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,15 +98,6 @@ def _maybe_compute_lengths(
return lengths


@torch.fx.wrap
def _maybe_compute_max_length(lengths: torch.Tensor, max_length: Optional[int]) -> int:
if max_length is None:
if lengths.numel() == 0:
return 0
max_length = int(lengths.max().item())
return max_length


def _maybe_compute_offsets(
lengths: Optional[torch.Tensor], offsets: Optional[torch.Tensor]
) -> torch.Tensor:
Expand Down Expand Up @@ -590,15 +581,14 @@ class JaggedTensor(Pipelineable, metaclass=JaggedTensorMeta):
offsets.
"""

_fields = ["_values", "_weights", "_lengths", "_offsets", "_max_length"]
_fields = ["_values", "_weights", "_lengths", "_offsets"]

def __init__(
self,
values: torch.Tensor,
weights: Optional[torch.Tensor] = None,
lengths: Optional[torch.Tensor] = None,
offsets: Optional[torch.Tensor] = None,
max_length: Optional[int] = None,
) -> None:

self._values: torch.Tensor = values
Expand All @@ -610,7 +600,6 @@ def __init__(
_assert_tensor_has_no_elements_or_has_integers(lengths, "lengths")
self._lengths: Optional[torch.Tensor] = lengths
self._offsets: Optional[torch.Tensor] = offsets
self._max_length: Optional[int] = max_length

@staticmethod
def empty(
Expand Down Expand Up @@ -641,7 +630,6 @@ def empty(
offsets=torch.empty(0, dtype=lengths_dtype, device=device),
lengths=torch.empty(0, dtype=lengths_dtype, device=device),
weights=weights,
max_length=0,
)

@staticmethod
Expand Down Expand Up @@ -924,26 +912,6 @@ def lengths_or_none(self) -> Optional[torch.Tensor]:
"""
return self._lengths

def max_length(self) -> int:
"""
Get the maximum length of the JaggedTensor.

Returns:
int: the maximum length of the JaggedTensor.
"""
_max_length = _maybe_compute_max_length(self.lengths(), self._max_length)
self._max_length = _max_length
return _max_length

def max_length_or_none(self) -> Optional[int]:
"""
Get the maximum length of the JaggedTensor. If not computed, return None.

Returns:
Optional[int]: the maximum length of the JaggedTensor.
"""
return self._max_length

def offsets(self) -> torch.Tensor:
"""
Get JaggedTensor offsets. If not computed, compute it from lengths.
Expand Down Expand Up @@ -1005,7 +973,6 @@ def to(self, device: torch.device, non_blocking: bool = False) -> "JaggedTensor"
weights = self._weights
lengths = self._lengths
offsets = self._offsets
max_length = self._max_length
return JaggedTensor(
values=self._values.to(device, non_blocking=non_blocking),
weights=(
Expand All @@ -1023,7 +990,6 @@ def to(self, device: torch.device, non_blocking: bool = False) -> "JaggedTensor"
if offsets is not None
else None
),
max_length=max_length,
)

@torch.jit.unused
Expand Down
8 changes: 0 additions & 8 deletions torchrec/sparse/tests/test_jagged_tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -568,14 +568,6 @@ def test_length_vs_offset(self) -> None:
self.assertTrue(torch.equal(j_offset.lengths(), j_lens.lengths()))
self.assertTrue(torch.equal(j_offset.offsets(), j_lens.offsets().int()))

def test_max_length(self) -> None:
values = torch.Tensor([1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0])
offsets = torch.IntTensor([0, 2, 2, 3, 4, 5, 8])
jt = JaggedTensor(values=values, offsets=offsets)
self.assertIsNone(jt.max_length_or_none())
self.assertEqual(jt.max_length(), 3)
self.assertEqual(jt.max_length_or_none(), 3)

def test_empty(self) -> None:
jt = JaggedTensor.empty(values_dtype=torch.int64)

Expand Down
Loading