From cb6577a4c4e10ac549a58f431ec6bdefe7ae3373 Mon Sep 17 00:00:00 2001 From: Mattia Ragni Date: Thu, 21 May 2026 13:33:00 +0200 Subject: [PATCH] fix(mlflow): skip non-scalar metrics like atomic_numbers to prevent JAX conversion crash --- src/reax/loggers/mlflow.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/src/reax/loggers/mlflow.py b/src/reax/loggers/mlflow.py index 7fca6ce..f46a6f6 100644 --- a/src/reax/loggers/mlflow.py +++ b/src/reax/loggers/mlflow.py @@ -307,10 +307,15 @@ def log_metrics(self, metrics: Mapping[str, float], step: int | None = None) -> if isinstance(value, str): self._warning_cache.warn(f"Discarding metric with string value {name}={value}.") continue + if isinstance(value, list): self._warning_cache.warn(f"Discarding metric with list value {name}={value}.") continue + if hasattr(value, "ndim") and value.ndim > 0: + self._warning_cache.warn(f"Discarding non scalar metric {name}={value}.") + continue + new_k = re.sub("[^a-zA-Z0-9_/. -]+", "", name) if name != new_k: self._warning_cache.warn(