Skip to content

Commit 338402c

Browse files
superbobrycopybara-github
authored andcommitted
Pass torch.Tensor to jax.dlpack.from_dlpack directly
The code path accepting a DLPack capsule is deprecated and will be removed soon. PiperOrigin-RevId: 797799893 Change-Id: I54a2bdad6eb9ee8efaa230a8cb416b7179f67fe2
1 parent 75cddee commit 338402c

File tree

1 file changed

+1
-3
lines changed

1 file changed

+1
-3
lines changed

brax/io/torch.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -52,9 +52,7 @@ def torch_to_jax(value: Any) -> Any:
5252
@torch_to_jax.register(torch.Tensor)
5353
def _tensor_to_jax(value: torch.Tensor) -> jax.Array:
5454
"""Converts a PyTorch Tensor into a jax.Array."""
55-
tensor = torch_dlpack.to_dlpack(value)
56-
tensor = jax_dlpack.from_dlpack(tensor)
57-
return tensor
55+
return jax_dlpack.from_dlpack(value)
5856

5957

6058
@torch_to_jax.register(abc.Mapping)

0 commit comments

Comments
 (0)