File tree Expand file tree Collapse file tree 3 files changed +10
-1
lines changed Expand file tree Collapse file tree 3 files changed +10
-1
lines changed Original file line number Diff line number Diff line change @@ -186,7 +186,7 @@ def setUp(self):
186
186
self .env = torchax .default_env ()
187
187
torchax .enable_accuracy_mode ()
188
188
#self.env.config.debug_accuracy_for_each_op = True
189
- self .env .config .debug_print_each_op = True
189
+ self .env .config .debug_print_each_op = False
190
190
torch .manual_seed (0 )
191
191
self .old_var = self .env .config .use_torch_native_for_cpu_tensor
192
192
self .env .config .use_torch_native_for_cpu_tensor = False
Original file line number Diff line number Diff line change @@ -10,6 +10,11 @@ class Configuration:
10
10
11
11
use_int32_for_index : bool = False
12
12
13
+ # normally, math between CPU torch.Tensor with torchax.Tensor is not
14
+ # allowed. However, if that torch.Tensor happens to be scalar, then we
15
+ # can use scalar * tensor math to handle it
16
+ allow_mixed_math_with_scalar_tensor : bool = True
17
+
13
18
# If true, we will convert Views into torchax.Tensors eagerly
14
19
force_materialize_views : bool = False
15
20
Original file line number Diff line number Diff line change @@ -639,6 +639,10 @@ def t2j_iso(self, torchtensors):
639
639
"""
640
640
641
641
def to_jax (x ):
642
+ if self .config .allow_mixed_math_with_scalar_tensor and not isinstance (
643
+ x , Tensor ):
644
+ if x .squeeze ().ndim == 0 :
645
+ return x .item ()
642
646
if isinstance (
643
647
x , torch .distributed ._functional_collectives .AsyncCollectiveTensor ):
644
648
x = x .wait ()
You can’t perform that action at this time.
0 commit comments