Skip to content

Commit 95ba754

Browse files
authored
Allow mixed tensor type math if one of them is a scalar (#9453)
1 parent f3c7907 commit 95ba754

File tree

3 files changed

+10
-1
lines changed

3 files changed

+10
-1
lines changed

torchax/test/test_ops.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -186,7 +186,7 @@ def setUp(self):
186186
self.env = torchax.default_env()
187187
torchax.enable_accuracy_mode()
188188
#self.env.config.debug_accuracy_for_each_op = True
189-
self.env.config.debug_print_each_op = True
189+
self.env.config.debug_print_each_op = False
190190
torch.manual_seed(0)
191191
self.old_var = self.env.config.use_torch_native_for_cpu_tensor
192192
self.env.config.use_torch_native_for_cpu_tensor = False

torchax/torchax/config.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,11 @@ class Configuration:
1010

1111
use_int32_for_index: bool = False
1212

13+
# normally, math between CPU torch.Tensor with torchax.Tensor is not
14+
# allowed. However, if that torch.Tensor happens to be scalar, then we
15+
# can use scalar * tensor math to handle it
16+
allow_mixed_math_with_scalar_tensor: bool = True
17+
1318
# If true, we will convert Views into torchax.Tensors eagerly
1419
force_materialize_views: bool = False
1520

torchax/torchax/tensor.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -639,6 +639,10 @@ def t2j_iso(self, torchtensors):
639639
"""
640640

641641
def to_jax(x):
642+
if self.config.allow_mixed_math_with_scalar_tensor and not isinstance(
643+
x, Tensor):
644+
if x.squeeze().ndim == 0:
645+
return x.item()
642646
if isinstance(
643647
x, torch.distributed._functional_collectives.AsyncCollectiveTensor):
644648
x = x.wait()

0 commit comments

Comments
 (0)