Describe the bug
Setting `scale_cost="std"` on a `PointCloud` makes accessing `pc.cost_matrix` fail with a `TypeError`.
To Reproduce
import jax
from ott.geometry import pointcloud
n = 5
d = 2
epsilon = None
relative_epsilon = None
# scale_cost = 1.0
# scale_cost = "mean"
scale_cost = "std" # doesn't work
# scale_cost = "median"
# scale_cost = "max_cost"
rng = jax.random.key(0)
x = jax.random.normal(rng, (n, d))
pc = pointcloud.PointCloud(x, x, epsilon=epsilon, relative_epsilon=relative_epsilon, scale_cost=scale_cost)
cost = pc.cost_matrix
Expected behavior
`pc.cost_matrix` should return the cost matrix, as it does for the other `scale_cost` options (e.g. `1.0`, `"mean"`, `"median"`, `"max_cost"`).
Screenshots
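Error (root cause, as the traceback below shows):
- No earlier branch of `PointCloud.inv_scale_cost` handles `"std"`, so execution falls through to `utils.is_scalar(self._scale_cost)`.
- Because `np.isscalar` returns `True` for strings, `is_scalar` then calls `jnp.asarray("std")`.
- That call raises a `TypeError`. So instead of `is_scalar` returning `False` and `inv_scale_cost` raising its intended `ValueError("Scaling std not implemented.")`, the user gets an opaque JAX dtype error: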
---------------------------------------------------------------------------
TypeError Traceback (most recent call last)
File ~/Projects/trajectory_inference/venv/lib/python3.12/site-packages/jax/_src/dtypes.py:701, in dtype(x, canonicalize)
700 try:
--> 701 dt = np.result_type(x)
702 except TypeError as err:
TypeError: data type 'std' not understood
The above exception was the direct cause of the following exception:
TypeError Traceback (most recent call last)
File ~/Projects/trajectory_inference/venv/lib/python3.12/site-packages/jax/_src/numpy/lax_numpy.py:4054, in array(object, dtype, copy, order, ndmin, device)
4053 try:
-> 4054 dtype = dtypes._lattice_result_type(*leaves)[0] if leaves else dtypes.float_
4055 except TypeError:
4056 # This happens if, e.g. one of the entries is a memoryview object.
4057 # This is rare, so we only handle it if the normal path fails.
File ~/Projects/trajectory_inference/venv/lib/python3.12/site-packages/jax/_src/dtypes.py:711, in _lattice_result_type(*args)
710 def _lattice_result_type(*args: Any) -> tuple[DType, bool]:
--> 711 dtypes, weak_types = zip(*(_dtype_and_weaktype(arg) for arg in args))
712 if len(dtypes) == 1:
File ~/Projects/trajectory_inference/venv/lib/python3.12/site-packages/jax/_src/dtypes.py:711, in <genexpr>(.0)
710 def _lattice_result_type(*args: Any) -> tuple[DType, bool]:
--> 711 dtypes, weak_types = zip(*(_dtype_and_weaktype(arg) for arg in args))
712 if len(dtypes) == 1:
File ~/Projects/trajectory_inference/venv/lib/python3.12/site-packages/jax/_src/dtypes.py:514, in _dtype_and_weaktype(value)
513 """Return a (dtype, weak_type) tuple for the given input."""
--> 514 return dtype(value), any(value is typ for typ in _weak_types) or is_weakly_typed(value)
File ~/Projects/trajectory_inference/venv/lib/python3.12/site-packages/jax/_src/dtypes.py:703, in dtype(x, canonicalize)
702 except TypeError as err:
--> 703 raise TypeError(f"Cannot determine dtype of {x}") from err
704 if dt not in _jax_dtype_set and not issubdtype(dt, extended):
TypeError: Cannot determine dtype of std
During handling of the above exception, another exception occurred:
TypeError Traceback (most recent call last)
Cell In[92], line 19
17 x = jax.random.normal(rng, (n, d))
18 pc = pointcloud.PointCloud(x, x, epsilon=epsilon, relative_epsilon=relative_epsilon, scale_cost=scale_cost)
---> 19 cost = pc.cost_matrix
File /z1-mheitz/mheitz/Projects/trajectory_inference/ott/src/ott/geometry/pointcloud.py:318, in PointCloud.cost_matrix(self)
316 @property
317 def cost_matrix(self) -> Optional[jnp.ndarray]: # noqa: D102
--> 318 return self.inv_scale_cost * self._unscaled_cost_matrix
File /z1-mheitz/mheitz/Projects/trajectory_inference/ott/src/ott/geometry/pointcloud.py:361, in PointCloud.inv_scale_cost(self)
355 return 1.0 / max_bound
356 raise NotImplementedError(
357 "Using max_bound as scaling factor for "
358 "the cost matrix when the cost is not squared euclidean "
359 "is not implemented."
360 )
--> 361 if utils.is_scalar(self._scale_cost):
362 return 1.0 / self._scale_cost
363 raise ValueError(f"Scaling {self._scale_cost} not implemented.")
File /z1-mheitz/mheitz/Projects/trajectory_inference/ott/src/ott/utils.py:434, in is_scalar(x)
429 def is_scalar(x: Any) -> bool: # noqa: D103
430 if (
431 isinstance(x, (np.ndarray, jax.Array)) or hasattr(x, "__jax_array__") or
432 np.isscalar(x)
433 ):
--> 434 return jnp.asarray(x).ndim == 0
435 return False
File ~/Projects/trajectory_inference/venv/lib/python3.12/site-packages/jax/_src/numpy/lax_numpy.py:4252, in asarray(a, dtype, order, copy, device)
4250 if dtype is not None:
4251 dtype = dtypes.canonicalize_dtype(dtype, allow_extended_dtype=True) # type: ignore[assignment]
-> 4252 return array(a, dtype=dtype, copy=bool(copy), order=order, device=device)
File ~/Projects/trajectory_inference/venv/lib/python3.12/site-packages/jax/_src/numpy/lax_numpy.py:4059, in array(object, dtype, copy, order, ndmin, device)
4055 except TypeError:
4056 # This happens if, e.g. one of the entries is a memoryview object.
4057 # This is rare, so we only handle it if the normal path fails.
4058 leaves = [_convert_to_array_if_dtype_fails(leaf) for leaf in leaves]
-> 4059 dtype = dtypes._lattice_result_type(*leaves)[0]
4061 if not weak_type:
4062 dtype = dtypes.canonicalize_dtype(dtype, allow_extended_dtype=True) # type: ignore[assignment]
File ~/Projects/trajectory_inference/venv/lib/python3.12/site-packages/jax/_src/dtypes.py:711, in _lattice_result_type(*args)
710 def _lattice_result_type(*args: Any) -> tuple[DType, bool]:
--> 711 dtypes, weak_types = zip(*(_dtype_and_weaktype(arg) for arg in args))
712 if len(dtypes) == 1:
713 out_dtype = dtypes[0]
File ~/Projects/trajectory_inference/venv/lib/python3.12/site-packages/jax/_src/dtypes.py:711, in <genexpr>(.0)
710 def _lattice_result_type(*args: Any) -> tuple[DType, bool]:
--> 711 dtypes, weak_types = zip(*(_dtype_and_weaktype(arg) for arg in args))
712 if len(dtypes) == 1:
713 out_dtype = dtypes[0]
File ~/Projects/trajectory_inference/venv/lib/python3.12/site-packages/jax/_src/dtypes.py:514, in _dtype_and_weaktype(value)
512 def _dtype_and_weaktype(value: Any) -> tuple[DType, bool]:
513 """Return a (dtype, weak_type) tuple for the given input."""
--> 514 return dtype(value), any(value is typ for typ in _weak_types) or is_weakly_typed(value)
File ~/Projects/trajectory_inference/venv/lib/python3.12/site-packages/jax/_src/dtypes.py:705, in dtype(x, canonicalize)
703 raise TypeError(f"Cannot determine dtype of {x}") from err
704 if dt not in _jax_dtype_set and not issubdtype(dt, extended):
--> 705 raise TypeError(f"Value '{x}' with dtype {dt} is not a valid JAX array "
706 "type. Only arrays of numeric types are supported by JAX.")
707 # TODO(jakevdp): fix return type annotation and remove this ignore.
708 return canonicalize_dtype(dt, allow_extended_dtype=True) if canonicalize else dt
TypeError: Value 'std' with dtype <U3 is not a valid JAX array type. Only arrays of numeric types are supported by JAX.
Desktop:
Python 3.12
OTT-JAX master (e6fff13)
Describe the bug
Setting `scale_cost="std"` on a `PointCloud` makes accessing `pc.cost_matrix` fail with a `TypeError`.
To Reproduce
Expected behavior
`pc.cost_matrix` should return the cost matrix.
Screenshots
Error:
Desktop:
Python 3.12
OTT-JAX master (e6fff13)