From ab9d50c3dfec120d0798c3cfdfd20a93a1c17b8e Mon Sep 17 00:00:00 2001 From: jd7-tr <205998557+jd7-tr@users.noreply.github.com> Date: Thu, 28 Aug 2025 10:13:38 -0700 Subject: [PATCH] Implement __torch_function__ for KeyedTensor (#3329) Summary: Pull Request resolved: https://github.com/pytorch/torchrec/pull/3329 X-link: https://github.com/pytorch/pytorch/pull/161683 1. There are a bunch of `torch.ops.aten` operations that can't handle `KeyedTensor`: The error was occurring because these ops expects a regular `Tensor` but was receiving a `KeyedTensor` object. 2. Implement `__torch_function__` for `KeyedTensor`, so when these incompatible operations are called with a `KeyedTensor`, the `__torch_function__` method automatically delegates the op to the underlying values tensor from the `KeyedTensor` and returns a new `KeyedTensor` with updated values. Reviewed By: malaybag Differential Revision: D81047278 --- torchrec/sparse/jagged_tensor.py | 48 +++++++++++++++++++++++++++++++- 1 file changed, 47 insertions(+), 1 deletion(-) diff --git a/torchrec/sparse/jagged_tensor.py b/torchrec/sparse/jagged_tensor.py index 78cdd2e1f..042f17332 100644 --- a/torchrec/sparse/jagged_tensor.py +++ b/torchrec/sparse/jagged_tensor.py @@ -12,7 +12,7 @@ import operator -from typing import Any, Callable, Dict, List, Optional, Tuple, Union +from typing import Any, Callable, Dict, List, Optional, Tuple, Type, Union import torch from torch.autograd.profiler import record_function @@ -3191,6 +3191,52 @@ class KeyedTensor(Pipelineable, metaclass=JaggedTensorMeta): # torch.Tensor([[2, 1, 2], [2, 1, 2], [2, 1, 2]]) """ + @classmethod + # pyre-ignore + def __torch_function__(cls: Type["KeyedTensor"], func, types, args=(), kwargs=None): + """ + Enable KeyedTensor compatibility with PyTorch operations by delegating + operations to the underlying values tensor and reconstructing KeyedTensor + when appropriate. + + This method allows KeyedTensor to work with various PyTorch operations + including compilation operations like torch.compile. + """ + if kwargs is None: + kwargs = {} + + # Handle operations that expect regular tensors but should return KeyedTensor + tensor_ops = [ + torch.ops.aten._assert_tensor_metadata.default, + torch.ops.aten.to.dtype, + torch.ops.aten.to.device, + torch.ops.aten.detach.default, + torch.ops.aten.clone.default, + ] + + if func in tensor_ops: + if len(args) > 0 and isinstance(args[0], cls): + keyed_tensor = args[0] + values_tensor = keyed_tensor.values() + new_args = (values_tensor,) + args[1:] + result = func(*new_args, **kwargs) + + # For operations that return tensors, create new KeyedTensor with updated values + if isinstance(result, torch.Tensor): + return cls( + keys=keyed_tensor.keys(), + length_per_key=keyed_tensor.length_per_key(), + values=result, + key_dim=keyed_tensor.key_dim(), + offset_per_key=keyed_tensor._offset_per_key, + index_per_key=keyed_tensor._index_per_key, + ) + + return result + + # For all other operations, return NotImplemented to allow normal handling + return NotImplementedError(f"{func} cannot be applied to KeyedTensor.") + def __init__( self, keys: List[str],