Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
48 changes: 47 additions & 1 deletion torchrec/sparse/jagged_tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@

import operator

from typing import Any, Callable, Dict, List, Optional, Tuple, Union
from typing import Any, Callable, Dict, List, Optional, Tuple, Type, Union

import torch
from torch.autograd.profiler import record_function
Expand Down Expand Up @@ -3191,6 +3191,52 @@ class KeyedTensor(Pipelineable, metaclass=JaggedTensorMeta):
# torch.Tensor([[2, 1, 2], [2, 1, 2], [2, 1, 2]])
"""

@classmethod
# pyre-ignore
def __torch_function__(cls: Type["KeyedTensor"], func, types, args=(), kwargs=None):
"""
Enable KeyedTensor compatibility with PyTorch operations by delegating
operations to the underlying values tensor and reconstructing KeyedTensor
when appropriate.

This method allows KeyedTensor to work with various PyTorch operations
including compilation operations like torch.compile.
"""
if kwargs is None:
kwargs = {}

# Handle operations that expect regular tensors but should return KeyedTensor
tensor_ops = [
torch.ops.aten._assert_tensor_metadata.default,
torch.ops.aten.to.dtype,
torch.ops.aten.to.device,
torch.ops.aten.detach.default,
torch.ops.aten.clone.default,
]

if func in tensor_ops:
if len(args) > 0 and isinstance(args[0], cls):
keyed_tensor = args[0]
values_tensor = keyed_tensor.values()
new_args = (values_tensor,) + args[1:]
result = func(*new_args, **kwargs)

# For operations that return tensors, create new KeyedTensor with updated values
if isinstance(result, torch.Tensor):
return cls(
keys=keyed_tensor.keys(),
length_per_key=keyed_tensor.length_per_key(),
values=result,
key_dim=keyed_tensor.key_dim(),
offset_per_key=keyed_tensor._offset_per_key,
index_per_key=keyed_tensor._index_per_key,
)

return result

# For all other operations, return NotImplemented to allow normal handling
return NotImplementedError(f"{func} cannot be applied to KeyedTensor.")

def __init__(
self,
keys: List[str],
Expand Down
Loading