Skip to content

scale_cost = "std" doesn't work for PointCloud #606

@matthieuheitz

Description

@matthieuheitz

Describe the bug
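Setting `scale_cost="std"` on a `PointCloud` raises a `TypeError` when the cost matrix is accessed (`pc.cost_matrix`): per the traceback below, `inv_scale_cost` passes the string to `utils.is_scalar`, whose `jnp.asarray(x)` call fails on it.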

To Reproduce

import jax
from ott.geometry import pointcloud

n = 5
d = 2
epsilon = None
relative_epsilon = None
# scale_cost = 1.0
# scale_cost = "mean"
scale_cost = "std"      # doesn't work
# scale_cost = "median"
# scale_cost = "max_cost"

rng = jax.random.key(0)
x = jax.random.normal(rng, (n, d))
pc = pointcloud.PointCloud(x, x, epsilon=epsilon, relative_epsilon=relative_epsilon, scale_cost=scale_cost)
cost = pc.cost_matrix

Expected behavior
`pc.cost_matrix` should return the cost matrix instead of raising.

Screenshots
Error:

---------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
File ~/Projects/trajectory_inference/venv/lib/python3.12/site-packages/jax/_src/dtypes.py:701, in dtype(x, canonicalize)
    700 try:
--> 701   dt = np.result_type(x)
    702 except TypeError as err:

TypeError: data type 'std' not understood

The above exception was the direct cause of the following exception:

TypeError                                 Traceback (most recent call last)
File ~/Projects/trajectory_inference/venv/lib/python3.12/site-packages/jax/_src/numpy/lax_numpy.py:4054, in array(object, dtype, copy, order, ndmin, device)
   4053 try:
-> 4054   dtype = dtypes._lattice_result_type(*leaves)[0] if leaves else dtypes.float_
   4055 except TypeError:
   4056   # This happens if, e.g. one of the entries is a memoryview object.
   4057   # This is rare, so we only handle it if the normal path fails.

File ~/Projects/trajectory_inference/venv/lib/python3.12/site-packages/jax/_src/dtypes.py:711, in _lattice_result_type(*args)
    710 def _lattice_result_type(*args: Any) -> tuple[DType, bool]:
--> 711   dtypes, weak_types = zip(*(_dtype_and_weaktype(arg) for arg in args))
    712   if len(dtypes) == 1:

File ~/Projects/trajectory_inference/venv/lib/python3.12/site-packages/jax/_src/dtypes.py:711, in <genexpr>(.0)
    710 def _lattice_result_type(*args: Any) -> tuple[DType, bool]:
--> 711   dtypes, weak_types = zip(*(_dtype_and_weaktype(arg) for arg in args))
    712   if len(dtypes) == 1:

File ~/Projects/trajectory_inference/venv/lib/python3.12/site-packages/jax/_src/dtypes.py:514, in _dtype_and_weaktype(value)
    513 """Return a (dtype, weak_type) tuple for the given input."""
--> 514 return dtype(value), any(value is typ for typ in _weak_types) or is_weakly_typed(value)

File ~/Projects/trajectory_inference/venv/lib/python3.12/site-packages/jax/_src/dtypes.py:703, in dtype(x, canonicalize)
    702   except TypeError as err:
--> 703     raise TypeError(f"Cannot determine dtype of {x}") from err
    704 if dt not in _jax_dtype_set and not issubdtype(dt, extended):

TypeError: Cannot determine dtype of std

During handling of the above exception, another exception occurred:

TypeError                                 Traceback (most recent call last)
Cell In[92], line 19
     17 x = jax.random.normal(rng, (n, d))
     18 pc = pointcloud.PointCloud(x, x, epsilon=epsilon, relative_epsilon=relative_epsilon, scale_cost=scale_cost)
---> 19 cost = pc.cost_matrix

File /z1-mheitz/mheitz/Projects/trajectory_inference/ott/src/ott/geometry/pointcloud.py:318, in PointCloud.cost_matrix(self)
    316 @property
    317 def cost_matrix(self) -> Optional[jnp.ndarray]:  # noqa: D102
--> 318   return self.inv_scale_cost * self._unscaled_cost_matrix

File /z1-mheitz/mheitz/Projects/trajectory_inference/ott/src/ott/geometry/pointcloud.py:361, in PointCloud.inv_scale_cost(self)
    355     return 1.0 / max_bound
    356   raise NotImplementedError(
    357       "Using max_bound as scaling factor for "
    358       "the cost matrix when the cost is not squared euclidean "
    359       "is not implemented."
    360   )
--> 361 if utils.is_scalar(self._scale_cost):
    362   return 1.0 / self._scale_cost
    363 raise ValueError(f"Scaling {self._scale_cost} not implemented.")

File /z1-mheitz/mheitz/Projects/trajectory_inference/ott/src/ott/utils.py:434, in is_scalar(x)
    429 def is_scalar(x: Any) -> bool:  # noqa: D103
    430   if (
    431       isinstance(x, (np.ndarray, jax.Array)) or hasattr(x, "__jax_array__") or
    432       np.isscalar(x)
    433   ):
--> 434     return jnp.asarray(x).ndim == 0
    435   return False

File ~/Projects/trajectory_inference/venv/lib/python3.12/site-packages/jax/_src/numpy/lax_numpy.py:4252, in asarray(a, dtype, order, copy, device)
   4250 if dtype is not None:
   4251   dtype = dtypes.canonicalize_dtype(dtype, allow_extended_dtype=True)  # type: ignore[assignment]
-> 4252 return array(a, dtype=dtype, copy=bool(copy), order=order, device=device)

File ~/Projects/trajectory_inference/venv/lib/python3.12/site-packages/jax/_src/numpy/lax_numpy.py:4059, in array(object, dtype, copy, order, ndmin, device)
   4055   except TypeError:
   4056     # This happens if, e.g. one of the entries is a memoryview object.
   4057     # This is rare, so we only handle it if the normal path fails.
   4058     leaves = [_convert_to_array_if_dtype_fails(leaf) for leaf in leaves]
-> 4059     dtype = dtypes._lattice_result_type(*leaves)[0]
   4061 if not weak_type:
   4062   dtype = dtypes.canonicalize_dtype(dtype, allow_extended_dtype=True)  # type: ignore[assignment]

File ~/Projects/trajectory_inference/venv/lib/python3.12/site-packages/jax/_src/dtypes.py:711, in _lattice_result_type(*args)
    710 def _lattice_result_type(*args: Any) -> tuple[DType, bool]:
--> 711   dtypes, weak_types = zip(*(_dtype_and_weaktype(arg) for arg in args))
    712   if len(dtypes) == 1:
    713     out_dtype = dtypes[0]

File ~/Projects/trajectory_inference/venv/lib/python3.12/site-packages/jax/_src/dtypes.py:711, in <genexpr>(.0)
    710 def _lattice_result_type(*args: Any) -> tuple[DType, bool]:
--> 711   dtypes, weak_types = zip(*(_dtype_and_weaktype(arg) for arg in args))
    712   if len(dtypes) == 1:
    713     out_dtype = dtypes[0]

File ~/Projects/trajectory_inference/venv/lib/python3.12/site-packages/jax/_src/dtypes.py:514, in _dtype_and_weaktype(value)
    512 def _dtype_and_weaktype(value: Any) -> tuple[DType, bool]:
    513   """Return a (dtype, weak_type) tuple for the given input."""
--> 514   return dtype(value), any(value is typ for typ in _weak_types) or is_weakly_typed(value)

File ~/Projects/trajectory_inference/venv/lib/python3.12/site-packages/jax/_src/dtypes.py:705, in dtype(x, canonicalize)
    703     raise TypeError(f"Cannot determine dtype of {x}") from err
    704 if dt not in _jax_dtype_set and not issubdtype(dt, extended):
--> 705   raise TypeError(f"Value '{x}' with dtype {dt} is not a valid JAX array "
    706                   "type. Only arrays of numeric types are supported by JAX.")
    707 # TODO(jakevdp): fix return type annotation and remove this ignore.
    708 return canonicalize_dtype(dt, allow_extended_dtype=True) if canonicalize else dt

TypeError: Value 'std' with dtype <U3 is not a valid JAX array type. Only arrays of numeric types are supported by JAX.

Desktop:
Python 3.12
OTT-JAX master (e6fff13)

Metadata

Metadata

Assignees

No one assigned

    Labels

    No labels
    No labels

    Type

    No type

    Projects

    No projects

    Milestone

    No milestone

    Relationships

    None yet

    Development

    No branches or pull requests

    Issue actions