From 749b21fac7933fa5f881b6f318edca29ca240da5 Mon Sep 17 00:00:00 2001 From: galletti Date: Tue, 21 Apr 2026 14:55:52 +0200 Subject: [PATCH] to latest jax ; equinox models ; hydra fix --- .gitignore | 3 +- README.md | 17 +- configs/WaterDrop_2d/gns.yaml | 7 +- configs/dam_2d/base.yaml | 5 +- configs/dam_2d/gns.yaml | 7 +- configs/dam_2d/segnn.yaml | 7 +- configs/ldc_2d/base.yaml | 7 +- configs/ldc_2d/gns.yaml | 7 +- configs/ldc_2d/segnn.yaml | 7 +- configs/ldc_3d/base.yaml | 7 +- configs/ldc_3d/gns.yaml | 7 +- configs/ldc_3d/segnn.yaml | 7 +- configs/rpf_2d/base.yaml | 7 +- configs/rpf_2d/egnn.yaml | 7 +- configs/rpf_2d/gns.yaml | 7 +- configs/rpf_2d/painn.yaml | 7 +- configs/rpf_2d/segnn.yaml | 7 +- configs/rpf_3d/base.yaml | 7 +- configs/rpf_3d/egnn.yaml | 7 +- configs/rpf_3d/gns.yaml | 7 +- configs/rpf_3d/painn.yaml | 7 +- configs/rpf_3d/segnn.yaml | 7 +- configs/tgv_2d/base.yaml | 5 +- configs/tgv_2d/gns.yaml | 7 +- configs/tgv_2d/segnn.yaml | 7 +- configs/tgv_3d/base.yaml | 5 +- configs/tgv_3d/gns.yaml | 7 +- configs/tgv_3d/segnn.yaml | 5 +- lagrangebench/__init__.py | 3 +- lagrangebench/case_setup/features.py | 2 +- lagrangebench/evaluate/metrics.py | 22 +- lagrangebench/evaluate/rollout.py | 166 +++---- lagrangebench/models/base.py | 61 +-- lagrangebench/models/egnn.py | 534 ++++++++++++---------- lagrangebench/models/gns.py | 251 ++++++----- lagrangebench/models/linear.py | 26 +- lagrangebench/models/painn.py | 494 +++++++++++++-------- lagrangebench/models/segnn.py | 632 ++++++++++++++++----------- lagrangebench/models/utils.py | 246 ++++++++--- lagrangebench/runner.py | 181 ++++---- lagrangebench/train/strats.py | 13 +- lagrangebench/train/trainer.py | 205 ++++----- lagrangebench/utils.py | 112 ++--- main.py | 92 ++-- pyproject.toml | 26 +- tests/conftest.py | 9 + tests/models_test.py | 90 ++-- tests/rollout_test.py | 76 ++-- 48 files changed, 1939 insertions(+), 1496 deletions(-) create mode 100644 tests/conftest.py diff --git a/.gitignore b/.gitignore index 5580ed4..a91c6b7 100644 --- a/.gitignore +++ b/.gitignore @@ -20,6 +20,7 @@ rollouts profile dist .coverage +.claude # Sphinx documentation -docs/_build/ +docs/_build/ \ No newline at end of file diff --git a/README.md b/README.md index cfecab0..ce00ddd 100644 --- a/README.md +++ b/README.md @@ -86,7 +86,7 @@ Although the current [`jax-metal==0.0.5` library](https://pypi.org/project/jax-m A general tutorial is provided in the example notebook "Training GNS on the 2D Taylor Green Vortex" under `./notebooks/tutorial.ipynb` on the [LagrangeBench repository](https://github.com/tumaer/lagrangebench). The notebook covers the basics of LagrangeBench, such as loading a dataset, setting up a case, training a model from scratch and evaluating its performance. ### Running in a local clone (`main.py`) -Alternatively, experiments can also be set up with `main.py`, based on extensive YAML config files and cli arguments (check [`configs/`](configs/)). By default, the arguments have priority as 1) passed cli arguments, 2) YAML config and 3) [`defaults.py`](lagrangebench/defaults.py) (`lagrangebench` defaults). +Alternatively, experiments can also be set up with `main.py`, based on Hydra-composed YAML config files and cli arguments (check [`configs/`](configs/)). Each preset is a small YAML file whose `defaults:` list pulls in either the dataset-level `base.yaml` or the global defaults registered from [`defaults.py`](lagrangebench/defaults.py). CLI overrides use the standard Hydra `key=value` syntax and have priority over everything. When loading a saved model with `load_ckp` the config from the checkpoint is automatically loaded and training is restarted. For more details check the [`runner.py`](lagrangebench/runner.py) file. @@ -94,7 +94,7 @@ When loading a saved model with `load_ckp` the config from the checkpoint is aut For example, to start a _GNS_ run from scratch on the RPF 2D dataset use ``` -python main.py config=configs/rpf_2d/gns.yaml +python main.py --config-path configs/rpf_2d --config-name gns ``` Some model presets can be found in `./configs/`. @@ -103,16 +103,18 @@ If `mode=all` is provided, then training (`mode=train`) and subsequent inference **Restart training** -To restart training from the last checkpoint in `load_ckp` use +To restart training from the last checkpoint in `load_ckp` point Hydra at the checkpoint's saved `config.yaml`: ``` -python main.py load_ckp=ckp/gns_rpf2d_yyyymmdd-hhmmss +python main.py --config-path ckp/gns_rpf2d_yyyymmdd-hhmmss --config-name config load_ckp=ckp/gns_rpf2d_yyyymmdd-hhmmss ``` **Inference** -To evaluate a trained model from `load_ckp` on the test split (`test=True`) use +To evaluate a trained model from `load_ckp` on the test split (`eval.test=True`) use ``` -python main.py load_ckp=ckp/gns_rpf2d_yyyymmdd-hhmmss/best rollout_dir=rollout/gns_rpf2d_yyyymmdd-hhmmss/best mode=infer test=True +python main.py --config-path ckp/gns_rpf2d_yyyymmdd-hhmmss/best --config-name config \ + load_ckp=ckp/gns_rpf2d_yyyymmdd-hhmmss/best \ + mode=infer eval.test=True eval.rollout_dir=rollout/gns_rpf2d_yyyymmdd-hhmmss/best ``` If the default `eval.infer.out_type=pkl` is active, then the generated trajectories and a `metricsYYYY_MM_DD_HH_MM_SS.pkl` file will be written to `eval.rollout_dir`. The metrics file contains all `eval.infer.metrics` properties for each generated rollout. @@ -162,7 +164,8 @@ gdown 19TO4PaFGcryXOFFKs93IniuPZKEcaJ37 # unzip the downloaded file `gns_tgv2d.zip` python -c "import shutil; shutil.unpack_archive('gns_tgv2d.zip', 'gns_tgv2d')" # evaluate the model on the test split -python main.py gpu=$GPU_ID mode=infer eval.test=True load_ckp=gns_tgv2d/best +python main.py --config-path gns_tgv2d/best --config-name config \ + gpu=$GPU_ID mode=infer eval.test=True load_ckp=gns_tgv2d/best ``` ## Directory structure diff --git a/configs/WaterDrop_2d/gns.yaml b/configs/WaterDrop_2d/gns.yaml index a1f841c..8b30239 100644 --- a/configs/WaterDrop_2d/gns.yaml +++ b/configs/WaterDrop_2d/gns.yaml @@ -1,9 +1,12 @@ -extends: LAGRANGEBENCH_DEFAULTS +# @package _global_ +defaults: + - /lagrangebench_defaults + - _self_ dataset: src: /tmp/datasets/WaterDrop -model: +model: name: gns num_mp_steps: 10 latent_dim: 128 diff --git a/configs/dam_2d/base.yaml b/configs/dam_2d/base.yaml index 23f2466..60ffca0 100644 --- a/configs/dam_2d/base.yaml +++ b/configs/dam_2d/base.yaml @@ -1,4 +1,7 @@ -extends: LAGRANGEBENCH_DEFAULTS +# @package _global_ +defaults: + - /lagrangebench_defaults + - _self_ dataset: src: datasets/2D_DAM_5740_20kevery100 diff --git a/configs/dam_2d/gns.yaml b/configs/dam_2d/gns.yaml index 4cabc0e..dd94461 100644 --- a/configs/dam_2d/gns.yaml +++ b/configs/dam_2d/gns.yaml @@ -1,6 +1,9 @@ -extends: configs/dam_2d/base.yaml +# @package _global_ +defaults: + - base + - _self_ -model: +model: name: gns num_mp_steps: 10 latent_dim: 128 diff --git a/configs/dam_2d/segnn.yaml b/configs/dam_2d/segnn.yaml index c50ce85..97c42d6 100644 --- a/configs/dam_2d/segnn.yaml +++ b/configs/dam_2d/segnn.yaml @@ -1,6 +1,9 @@ -extends: configs/dam_2d/base.yaml +# @package _global_ +defaults: + - base + - _self_ -model: +model: name: segnn num_mp_steps: 10 latent_dim: 64 diff --git a/configs/ldc_2d/base.yaml b/configs/ldc_2d/base.yaml index 7232c2a..86a3237 100644 --- a/configs/ldc_2d/base.yaml +++ b/configs/ldc_2d/base.yaml @@ -1,4 +1,7 @@ -extends: LAGRANGEBENCH_DEFAULTS +# @package _global_ +defaults: + - /lagrangebench_defaults + - _self_ dataset: src: datasets/2D_LDC_2708_10kevery100 @@ -7,4 +10,4 @@ logging: wandb_project: ldc_2d neighbors: - multiplier: 2.0 \ No newline at end of file + multiplier: 2.0 diff --git a/configs/ldc_2d/gns.yaml b/configs/ldc_2d/gns.yaml index 39eeb31..dd94461 100644 --- a/configs/ldc_2d/gns.yaml +++ b/configs/ldc_2d/gns.yaml @@ -1,6 +1,9 @@ -extends: configs/ldc_2d/base.yaml +# @package _global_ +defaults: + - base + - _self_ -model: +model: name: gns num_mp_steps: 10 latent_dim: 128 diff --git a/configs/ldc_2d/segnn.yaml b/configs/ldc_2d/segnn.yaml index 59230f0..97c42d6 100644 --- a/configs/ldc_2d/segnn.yaml +++ b/configs/ldc_2d/segnn.yaml @@ -1,6 +1,9 @@ -extends: configs/ldc_2d/base.yaml +# @package _global_ +defaults: + - base + - _self_ -model: +model: name: segnn num_mp_steps: 10 latent_dim: 64 diff --git a/configs/ldc_3d/base.yaml b/configs/ldc_3d/base.yaml index b927239..4831ac4 100644 --- a/configs/ldc_3d/base.yaml +++ b/configs/ldc_3d/base.yaml @@ -1,4 +1,7 @@ -extends: LAGRANGEBENCH_DEFAULTS +# @package _global_ +defaults: + - /lagrangebench_defaults + - _self_ dataset: src: datasets/3D_LDC_8160_10kevery100 @@ -7,4 +10,4 @@ logging: wandb_project: ldc_3d neighbors: - multiplier: 2.0 \ No newline at end of file + multiplier: 2.0 diff --git a/configs/ldc_3d/gns.yaml b/configs/ldc_3d/gns.yaml index cacf6bb..3e8e26c 100644 --- a/configs/ldc_3d/gns.yaml +++ b/configs/ldc_3d/gns.yaml @@ -1,6 +1,9 @@ -extends: configs/ldc_3d/base.yaml +# @package _global_ +defaults: + - base + - _self_ -model: +model: name: gns num_mp_steps: 10 latent_dim: 128 diff --git a/configs/ldc_3d/segnn.yaml b/configs/ldc_3d/segnn.yaml index 88adf11..d3d7682 100644 --- a/configs/ldc_3d/segnn.yaml +++ b/configs/ldc_3d/segnn.yaml @@ -1,6 +1,9 @@ -extends: configs/ldc_3d/base.yaml +# @package _global_ +defaults: + - base + - _self_ -model: +model: name: segnn num_mp_steps: 10 latent_dim: 64 diff --git a/configs/rpf_2d/base.yaml b/configs/rpf_2d/base.yaml index dc7a669..0d16f2d 100644 --- a/configs/rpf_2d/base.yaml +++ b/configs/rpf_2d/base.yaml @@ -1,7 +1,10 @@ -extends: LAGRANGEBENCH_DEFAULTS +# @package _global_ +defaults: + - /lagrangebench_defaults + - _self_ dataset: src: datasets/2D_RPF_3200_20kevery100 logging: - wandb_project: rpf_2d \ No newline at end of file + wandb_project: rpf_2d diff --git a/configs/rpf_2d/egnn.yaml b/configs/rpf_2d/egnn.yaml index 21e4ef9..57900bf 100644 --- a/configs/rpf_2d/egnn.yaml +++ b/configs/rpf_2d/egnn.yaml @@ -1,6 +1,9 @@ -extends: configs/rpf_2d/base.yaml +# @package _global_ +defaults: + - base + - _self_ -model: +model: name: egnn num_mp_steps: 5 latent_dim: 128 diff --git a/configs/rpf_2d/gns.yaml b/configs/rpf_2d/gns.yaml index 6313033..3e8e26c 100644 --- a/configs/rpf_2d/gns.yaml +++ b/configs/rpf_2d/gns.yaml @@ -1,6 +1,9 @@ -extends: configs/rpf_2d/base.yaml +# @package _global_ +defaults: + - base + - _self_ -model: +model: name: gns num_mp_steps: 10 latent_dim: 128 diff --git a/configs/rpf_2d/painn.yaml b/configs/rpf_2d/painn.yaml index 82907f9..93fa945 100644 --- a/configs/rpf_2d/painn.yaml +++ b/configs/rpf_2d/painn.yaml @@ -1,6 +1,9 @@ -extends: configs/rpf_2d/base.yaml +# @package _global_ +defaults: + - base + - _self_ -model: +model: name: painn num_mp_steps: 5 latent_dim: 128 diff --git a/configs/rpf_2d/segnn.yaml b/configs/rpf_2d/segnn.yaml index f447336..d91566b 100644 --- a/configs/rpf_2d/segnn.yaml +++ b/configs/rpf_2d/segnn.yaml @@ -1,6 +1,9 @@ -extends: configs/rpf_2d/base.yaml +# @package _global_ +defaults: + - base + - _self_ -model: +model: name: segnn num_mp_steps: 10 latent_dim: 64 diff --git a/configs/rpf_3d/base.yaml b/configs/rpf_3d/base.yaml index 80ebcac..bdafdf4 100644 --- a/configs/rpf_3d/base.yaml +++ b/configs/rpf_3d/base.yaml @@ -1,7 +1,10 @@ -extends: LAGRANGEBENCH_DEFAULTS +# @package _global_ +defaults: + - /lagrangebench_defaults + - _self_ dataset: src: datasets/3D_RPF_8000_10kevery100 logging: - wandb_project: rpf_3d \ No newline at end of file + wandb_project: rpf_3d diff --git a/configs/rpf_3d/egnn.yaml b/configs/rpf_3d/egnn.yaml index 8bdb928..e472a4b 100644 --- a/configs/rpf_3d/egnn.yaml +++ b/configs/rpf_3d/egnn.yaml @@ -1,6 +1,9 @@ -extends: configs/rpf_3d/base.yaml +# @package _global_ +defaults: + - base + - _self_ -model: +model: name: egnn num_mp_steps: 5 latent_dim: 128 diff --git a/configs/rpf_3d/gns.yaml b/configs/rpf_3d/gns.yaml index 4deb161..3e8e26c 100644 --- a/configs/rpf_3d/gns.yaml +++ b/configs/rpf_3d/gns.yaml @@ -1,6 +1,9 @@ -extends: configs/rpf_3d/base.yaml +# @package _global_ +defaults: + - base + - _self_ -model: +model: name: gns num_mp_steps: 10 latent_dim: 128 diff --git a/configs/rpf_3d/painn.yaml b/configs/rpf_3d/painn.yaml index e6e05d9..d8d3504 100644 --- a/configs/rpf_3d/painn.yaml +++ b/configs/rpf_3d/painn.yaml @@ -1,6 +1,9 @@ -extends: configs/rpf_3d/base.yaml +# @package _global_ +defaults: + - base + - _self_ -model: +model: name: painn num_mp_steps: 5 latent_dim: 128 diff --git a/configs/rpf_3d/segnn.yaml b/configs/rpf_3d/segnn.yaml index 813c931..d91566b 100644 --- a/configs/rpf_3d/segnn.yaml +++ b/configs/rpf_3d/segnn.yaml @@ -1,6 +1,9 @@ -extends: configs/rpf_3d/base.yaml +# @package _global_ +defaults: + - base + - _self_ -model: +model: name: segnn num_mp_steps: 10 latent_dim: 64 diff --git a/configs/tgv_2d/base.yaml b/configs/tgv_2d/base.yaml index 6c48d1a..abf794f 100644 --- a/configs/tgv_2d/base.yaml +++ b/configs/tgv_2d/base.yaml @@ -1,4 +1,7 @@ -extends: LAGRANGEBENCH_DEFAULTS +# @package _global_ +defaults: + - /lagrangebench_defaults + - _self_ dataset: src: datasets/2D_TGV_2500_10kevery100 diff --git a/configs/tgv_2d/gns.yaml b/configs/tgv_2d/gns.yaml index 17e7b64..3e8e26c 100644 --- a/configs/tgv_2d/gns.yaml +++ b/configs/tgv_2d/gns.yaml @@ -1,6 +1,9 @@ -extends: configs/tgv_2d/base.yaml +# @package _global_ +defaults: + - base + - _self_ -model: +model: name: gns num_mp_steps: 10 latent_dim: 128 diff --git a/configs/tgv_2d/segnn.yaml b/configs/tgv_2d/segnn.yaml index ba3742b..d3d7682 100644 --- a/configs/tgv_2d/segnn.yaml +++ b/configs/tgv_2d/segnn.yaml @@ -1,6 +1,9 @@ -extends: configs/tgv_2d/base.yaml +# @package _global_ +defaults: + - base + - _self_ -model: +model: name: segnn num_mp_steps: 10 latent_dim: 64 diff --git a/configs/tgv_3d/base.yaml b/configs/tgv_3d/base.yaml index 1441f78..3b246e3 100644 --- a/configs/tgv_3d/base.yaml +++ b/configs/tgv_3d/base.yaml @@ -1,4 +1,7 @@ -extends: LAGRANGEBENCH_DEFAULTS +# @package _global_ +defaults: + - /lagrangebench_defaults + - _self_ dataset: src: datasets/3D_TGV_8000_10kevery100 diff --git a/configs/tgv_3d/gns.yaml b/configs/tgv_3d/gns.yaml index dd6dd84..3e8e26c 100644 --- a/configs/tgv_3d/gns.yaml +++ b/configs/tgv_3d/gns.yaml @@ -1,6 +1,9 @@ -extends: configs/tgv_3d/base.yaml +# @package _global_ +defaults: + - base + - _self_ -model: +model: name: gns num_mp_steps: 10 latent_dim: 128 diff --git a/configs/tgv_3d/segnn.yaml b/configs/tgv_3d/segnn.yaml index fab105a..d3d7682 100644 --- a/configs/tgv_3d/segnn.yaml +++ b/configs/tgv_3d/segnn.yaml @@ -1,4 +1,7 @@ -extends: configs/tgv_3d/base.yaml +# @package _global_ +defaults: + - base + - _self_ model: name: segnn diff --git a/lagrangebench/__init__.py b/lagrangebench/__init__.py index b956594..58c4a67 100644 --- a/lagrangebench/__init__.py +++ b/lagrangebench/__init__.py @@ -1,7 +1,7 @@ from .case_setup.case import case_builder from .data import DAM2D, LDC2D, LDC3D, RPF2D, RPF3D, TGV2D, TGV3D, H5Dataset from .evaluate import infer -from .models import EGNN, GNS, SEGNN, PaiNN +from .models import EGNN, GNS, SEGNN, Linear, PaiNN from .train.trainer import Trainer __all__ = [ @@ -13,6 +13,7 @@ "EGNN", "SEGNN", "PaiNN", + "Linear", "data", "H5Dataset", "TGV2D", diff --git a/lagrangebench/case_setup/features.py b/lagrangebench/case_setup/features.py index a18320e..5b01378 100644 --- a/lagrangebench/case_setup/features.py +++ b/lagrangebench/case_setup/features.py @@ -123,6 +123,6 @@ def feature_transform( ) features["rel_dist"] = normalized_relative_distances[:, None] - return jax.tree_map(lambda f: f, features) + return jax.tree.map(lambda f: f, features) return feature_transform diff --git a/lagrangebench/evaluate/metrics.py b/lagrangebench/evaluate/metrics.py index 6d977b0..f2792aa 100644 --- a/lagrangebench/evaluate/metrics.py +++ b/lagrangebench/evaluate/metrics.py @@ -160,20 +160,30 @@ def e_kin(self, frame: jnp.ndarray) -> float: return jnp.sum(frame**2) # * dx ** 3 def _sinkhorn_ott(self, pred: jnp.ndarray, target: jnp.ndarray) -> float: - # pairwise distances as cost + # pairwise distances as cost. Newer ott-jax releases contain an internal + # ``jax.lax.cond`` whose branches disagree on dtype when the inputs are + # float64 (x64 mode); cast to float32 to dodge that and keep the metric + # usable across ott versions. + pred = pred.astype(jnp.float32) + target = target.astype(jnp.float32) loss_matrix_xy = self._distance_matrix(pred, target) loss_matrix_xx = self._distance_matrix(pred, pred) loss_matrix_yy = self._distance_matrix(target, target) - return sinkhorn_divergence( + # Newer ott-jax returns ``Tuple[divergence, SinkhornDivergenceOutput]``; + # unpack the divergence scalar directly. + div, _ = sinkhorn_divergence( Geometry, loss_matrix_xy, loss_matrix_xx, loss_matrix_yy, # uniform weights - a=jnp.ones((pred.shape[0],)) / pred.shape[0], - b=jnp.ones((target.shape[0],)) / target.shape[0], - sinkhorn_kwargs={"threshold": 1e-4}, - ).divergence + a=jnp.ones((pred.shape[0],), dtype=jnp.float32) / pred.shape[0], + b=jnp.ones((target.shape[0],), dtype=jnp.float32) / target.shape[0], + # ``sinkhorn_kwargs`` was dropped from ``Geometry.__init__`` in newer + # ott-jax releases; forward the tolerance via ``solve_kwargs`` instead. + solve_kwargs={"threshold": 1e-4}, + ) + return div def _sinkhorn_pot(self, pred: jnp.ndarray, target: jnp.ndarray): """Jax-compatible POT implementation of Sinkorn.""" diff --git a/lagrangebench/evaluate/rollout.py b/lagrangebench/evaluate/rollout.py index 0d5dd82..4b58c86 100644 --- a/lagrangebench/evaluate/rollout.py +++ b/lagrangebench/evaluate/rollout.py @@ -6,7 +6,7 @@ from functools import partial from typing import Callable, Dict, Iterable, Optional, Tuple, Union -import haiku as hk +import equinox as eqx import jax import jax.numpy as jnp import jax_sph.jax_md.partition as partition @@ -23,41 +23,44 @@ broadcast_from_batch, broadcast_to_batch, get_kinematic_mask, - load_haiku, + load_model, set_seed, ) -@partial(jit, static_argnames=["model_apply", "case_integrate"]) +@partial(eqx.filter_jit, donate="none") def _forward_eval( - params: hk.Params, - state: hk.State, + model: eqx.Module, sample: Tuple[jnp.ndarray, jnp.ndarray], current_positions: jnp.ndarray, target_positions: jnp.ndarray, - model_apply: Callable, case_integrate: Callable, -) -> jnp.ndarray: +) -> Tuple[jnp.ndarray, eqx.Module]: """Run one update of the 'current_state' using the trained model Args: - params: Haiku model parameters - state: Haiku model state + model: Trained equinox model. + sample: Tuple ``(features, particle_type)``. current_positions: Set of historic positions of shape (n_nodel, t_window, dim) target_positions: used to get the next state of kinematic particles, i.e. those who are not update using the ML model, e.g. boundary particles - model_apply: model function case_integrate: integration function from case.integrate Return: current_positions: after shifting the historic position sequence by one, i.e. by the newly computed most recent position + model: possibly updated model. Most models are stateless and are returned + unchanged, but some (e.g. models tracking a rollout step counter) + return a new version of themselves. """ _, particle_type = sample # predict acceleration and integrate - pred, state = model_apply(params, state, sample) - + pred = model(sample) + # if the model exposes an ``advance`` method it is used to update the + # model's internal state; otherwise the model is returned unchanged. + if hasattr(model, "advance"): + model = model.advance() next_position = case_integrate(pred, current_positions) # update only the positions of non-boundary particles @@ -70,17 +73,15 @@ def _forward_eval( current_positions = jnp.concatenate( [current_positions[:, 1:], next_position[:, None, :]], axis=1 - ) # as next model input - - return current_positions, state + ) + return current_positions, model def _eval_batched_rollout( forward_eval_vmap: Callable, preprocess_eval_vmap: Callable, case, - params: hk.Params, - state: hk.State, + model: eqx.Module, traj_batch_i: Tuple[jnp.ndarray, jnp.ndarray], neighbors: partition.NeighborList, metrics_computer_vmap: Callable, @@ -93,11 +94,10 @@ def _eval_batched_rollout( Args: forward_eval_vmap: Model function. case: CaseSetupFn class. - params: Haiku params. - state: Haiku state. + model: Trained equinox model. traj_batch_i: Trajectory to evaluate. neighbors: Neighbor list. - metrics_computer: Vectorized MetricsComputer with the desired metrics. + metrics_computer_vmap: Vectorized MetricsComputer with the desired metrics. n_rollout_steps: Number of rollout steps. t_window: Length of the input sequence. n_extrap_steps: Number of extrapolation steps (beyond the ground truth rollout). @@ -105,17 +105,13 @@ def _eval_batched_rollout( Returns: A tuple with (predicted rollout, metrics, neighbor list). """ - # particle type is treated as a static property defined by state at t=0 pos_input_batch, particle_type_batch = traj_batch_i - # current_batch_size might be < eval_batch_size if the last batch is not full current_batch_size, n_nodes_max, _, dim = pos_input_batch.shape - # if n_rollout_steps set to -1, use the whole trajectory if n_rollout_steps == -1: n_rollout_steps = pos_input_batch.shape[2] - t_window current_positions_batch = pos_input_batch[:, :, 0:t_window] - # (batch, n_nodes, t_window, dim) traj_len = n_rollout_steps + n_extrap_steps target_positions_batch = pos_input_batch[:, :, t_window : t_window + traj_len] @@ -123,19 +119,26 @@ def _eval_batched_rollout( neighbors_batch = broadcast_to_batch(neighbors, current_batch_size) step = 0 + retries = 0 + max_retries = 10 while step < n_rollout_steps + n_extrap_steps: sample_batch = (current_positions_batch, particle_type_batch) - # 1. preprocess features features_batch, neighbors_batch = preprocess_eval_vmap( sample_batch, neighbors_batch ) - # 2. check whether list overflowed and fix it if so if neighbors_batch.did_buffer_overflow.sum() > 0: - # check if the neighbor list is too small for any of the samples - # if so, reallocate the neighbor list - + if retries >= max_retries: + print( + f"(eval) Neighbor-list overflow persists after {retries} retries " + f"at step {step}. Aborting rollout; remaining steps will be " + f"reported with the last successful prediction." + ) + # Keep predictions_batch at its current value for the remaining + # steps (already zeros). This surfaces as a large rollout error + # while preventing an infinite loop. + break print(f"(eval) Reallocate neighbors list at step {step}") ind = jnp.argmax(neighbors_batch.did_buffer_overflow) sample = broadcast_from_batch(sample_batch, index=ind) @@ -145,60 +148,51 @@ def _eval_batched_rollout( f"(eval) From {neighbors_batch.idx[ind].shape} to {nbrs_temp.idx.shape}" ) neighbors_batch = broadcast_to_batch(nbrs_temp, current_batch_size) - - # To run the loop N times even if sometimes - # did_buffer_overflow > 0 we directly return to the beginning + retries += 1 continue + # reset retry counter once a step is successfully processed + retries = 0 - # 3. run forward model - current_positions_batch, state_batch = forward_eval_vmap( - params, - state, + current_positions_batch, model_batch = forward_eval_vmap( + model, (features_batch, particle_type_batch), current_positions_batch, target_positions_batch[:, :, step], ) - # the state is not passed out of this loop, so no not really relevant - state = broadcast_from_batch(state_batch, 0) - - # 4. write predicted next position to output array + # model_batch is a batched version of the model; take sample 0 to + # recover an un-batched model for the next iteration. All samples + # produce the same update (models have no per-sample state). + model = broadcast_from_batch(model_batch, 0) predictions_batch = predictions_batch.at[:, step].set( - current_positions_batch[:, :, -1] # most recently predicted positions + current_positions_batch[:, :, -1] ) - step += 1 - # (batch, n_nodes, time, dim) -> (batch, time, n_nodes, dim) target_positions_batch = target_positions_batch.transpose(0, 2, 1, 3) - # slice out extrapolation steps metrics_batch = metrics_computer_vmap( predictions_batch[:, :n_rollout_steps, :, :], target_positions_batch ) - return (predictions_batch, metrics_batch, broadcast_from_batch(neighbors_batch, 0)) + return predictions_batch, metrics_batch, broadcast_from_batch(neighbors_batch, 0) def eval_rollout( - model_apply: Callable, + model: eqx.Module, case, - params: hk.Params, - state: hk.State, loader_eval: Iterable, neighbors: partition.NeighborList, metrics_computer: MetricsComputer, n_rollout_steps: int, n_trajs: int, - rollout_dir: str, + rollout_dir: Optional[str], out_type: str = "none", n_extrap_steps: int = 0, ) -> MetricsDict: """Compute the rollout and evaluate the metrics. Args: - model_apply: Model function. + model: Trained equinox model. case: CaseSetupFn class. - params: Haiku params. - state: Haiku state. loader_eval: Evaluation data loader. neighbors: Neighbor list. metrics_computer: MetricsComputer with the desired metrics. @@ -218,33 +212,24 @@ def eval_rollout( if rollout_dir is not None: os.makedirs(rollout_dir, exist_ok=True) - forward_eval = partial( - _forward_eval, - model_apply=model_apply, - case_integrate=case.integrate, - ) - forward_eval_vmap = vmap(forward_eval, in_axes=(None, None, 0, 0, 0)) + forward_eval = partial(_forward_eval, case_integrate=case.integrate) + forward_eval_vmap = eqx.filter_vmap(forward_eval, in_axes=(None, 0, 0, 0)) preprocess_eval_vmap = vmap(case.preprocess_eval, in_axes=(0, 0)) metrics_computer_vmap = vmap(metrics_computer, in_axes=(0, 0)) for i, traj_batch_i in enumerate(loader_eval): - # if n_trajs is not a multiple of batch_size, we slice from the last batch n_traj_left = n_trajs - i * batch_size if n_traj_left < batch_size: - traj_batch_i = jax.tree_map(lambda x: x[:n_traj_left], traj_batch_i) + traj_batch_i = jax.tree.map(lambda x: x[:n_traj_left], traj_batch_i) - # numpy to jax - traj_batch_i = jax.tree_map(lambda x: jnp.array(x), traj_batch_i) - # (pos_input_batch, particle_type_batch) = traj_batch_i - # pos_input_batch.shape = (batch, num_particles, seq_length, dim) + traj_batch_i = jax.tree.map(lambda x: jnp.array(x), traj_batch_i) example_rollout_batch, metrics_batch, neighbors = _eval_batched_rollout( forward_eval_vmap=forward_eval_vmap, preprocess_eval_vmap=preprocess_eval_vmap, case=case, - params=params, - state=state, - traj_batch_i=traj_batch_i, # (batch, nodes, t, dim) + model=model, + traj_batch_i=traj_batch_i, neighbors=neighbors, metrics_computer_vmap=metrics_computer_vmap, n_rollout_steps=n_rollout_steps, @@ -254,37 +239,33 @@ def eval_rollout( current_batch_size = traj_batch_i[0].shape[0] for j in range(current_batch_size): - # write metrics to output dictionary ind = i * batch_size + j eval_metrics[f"rollout_{ind}"] = broadcast_from_batch(metrics_batch, j) if rollout_dir is not None: - # (batch, nodes, t, dim) -> (batch, t, nodes, dim) pos_input_batch = traj_batch_i[0].transpose(0, 2, 1, 3) - for j in range(current_batch_size): # write every trajectory to file + for j in range(current_batch_size): pos_input = pos_input_batch[j] example_rollout = example_rollout_batch[j] initial_positions = pos_input[:t_window] example_full = jnp.concatenate([initial_positions, example_rollout]) example_rollout = { - "predicted_rollout": example_full, # (t + extrap, nodes, dim) - "ground_truth_rollout": pos_input, # (t, nodes, dim), - "particle_type": traj_batch_i[1][j], # (nodes,) + "predicted_rollout": example_full, + "ground_truth_rollout": pos_input, + "particle_type": traj_batch_i[1][j], } - file_prefix = os.path.join(rollout_dir, f"rollout_{i*batch_size+j}") - if out_type == "vtk": # write vtk files for each time step + file_prefix = os.path.join(rollout_dir, f"rollout_{i * batch_size + j}") + if out_type == "vtk": for k in range(example_full.shape[0]): - # predictions state_vtk = { "r": example_rollout["predicted_rollout"][k], "tag": example_rollout["particle_type"], } write_vtk(state_vtk, f"{file_prefix}_{k}.vtk") for k in range(pos_input.shape[0]): - # ground truth reference ref_state_vtk = { "r": example_rollout["ground_truth_rollout"][k], "tag": example_rollout["particle_type"], @@ -292,7 +273,6 @@ def eval_rollout( write_vtk(ref_state_vtk, f"{file_prefix}_ref_{k}.vtk") elif out_type == "pkl": filename = f"{file_prefix}.pkl" - with open(filename, "wb") as f: pickle.dump(example_rollout, f) @@ -300,7 +280,6 @@ def eval_rollout( break if rollout_dir is not None: - # save metrics t = time.strftime("%Y_%m_%d_%H_%M_%S", time.localtime()) with open(f"{rollout_dir}/metrics{t}.pkl", "wb") as f: pickle.dump(eval_metrics, f) @@ -309,11 +288,9 @@ def eval_rollout( def infer( - model: hk.TransformedWithState, + model: eqx.Module, case, data_test: H5Dataset, - params: Optional[hk.Params] = None, - state: Optional[hk.State] = None, load_ckp: Optional[str] = None, cfg_eval_infer: Union[Dict, DictConfig] = defaults.eval.infer, rollout_dir: Optional[str] = defaults.eval.rollout_dir, @@ -324,11 +301,10 @@ def infer( Infer on a dataset, compute metrics and optionally save rollout in out_type format. Args: - model: (Transformed) Haiku model. + model: Equinox model. If ``load_ckp`` is provided, this serves as the + structural template into which checkpointed weights are loaded. case: Case setup class. data_test: Test dataset. - params: Haiku params. - state: Haiku state. load_ckp: Path to checkpoint directory. rollout_dir: Path to rollout directory. cfg_eval_infer: Evaluation configuration for inference mode. @@ -338,25 +314,16 @@ def infer( Returns: eval_metrics: Metrics per trajectory. """ - assert ( - params is not None or load_ckp is not None - ), "Either params or a load_ckp directory must be provided for inference." - if isinstance(cfg_eval_infer, Dict): cfg_eval_infer = OmegaConf.create(cfg_eval_infer) - - # if one of the cfg_* arguments has a subset of the default configs, merge them cfg_eval_infer = OmegaConf.merge(defaults.eval.infer, cfg_eval_infer) n_trajs = cfg_eval_infer.n_trajs if n_trajs == -1: n_trajs = data_test.num_samples - if params is not None: - if state is None: - state = {} - else: - params, state, _, _ = load_haiku(load_ckp) + if load_ckp is not None: + model, _, _ = load_model(load_ckp, model) key, seed_worker, generator = set_seed(seed) @@ -374,20 +341,15 @@ def infer( input_seq_length=data_test.input_seq_length, stride=cfg_eval_infer.metrics_stride, ) - # Precompile model - model_apply = jit(model.apply) - # init values pos_input_and_target, particle_type = next(iter(loader_test)) sample = (pos_input_and_target[0], particle_type[0]) key, _, _, neighbors = case.allocate(key, sample) eval_metrics = eval_rollout( - model_apply=model_apply, + model=model, case=case, metrics_computer=metrics_computer, - params=params, - state=state, neighbors=neighbors, loader_eval=loader_test, n_rollout_steps=n_rollout_steps, diff --git a/lagrangebench/models/base.py b/lagrangebench/models/base.py index 94642a6..026711a 100644 --- a/lagrangebench/models/base.py +++ b/lagrangebench/models/base.py @@ -1,41 +1,42 @@ -from abc import ABC, abstractmethod +"""Base class for LagrangeBench models (Equinox).""" + +from abc import abstractmethod from typing import Dict, Tuple -import haiku as hk +import equinox as eqx import jax.numpy as jnp -class BaseModel(hk.Module, ABC): - """Base model class. All models must inherit from this class.""" +class BaseModel(eqx.Module): + """Abstract base class for particle simulation models. + + Subclasses inherit from :class:`equinox.Module`. Parameters are stored as + fields of the module; calling ``model(sample)`` performs a forward pass. + + We specify the dimensions of the inputs and outputs using the number of nodes + ``N``, the number of edges ``E``, the number of historic velocities + ``K = input_seq_length - 1``, and the feature dimensionality ``dim``. + + Expected ``features`` dict entries (present ones depend on the dataset): + + - ``abs_pos`` ``(N, K+1, dim)`` — absolute positions + - ``vel_hist`` ``(N, K*dim)`` — historical velocity sequence + - ``vel_mag`` ``(N,)`` — velocity magnitudes + - ``bound`` ``(N, 2*dim)`` — distance to boundaries + - ``force`` ``(N, dim)`` — external force field + - ``rel_disp`` ``(E, dim)`` — relative displacement vectors + - ``rel_dist`` ``(E, 1)`` — relative distances + - ``senders``, ``receivers`` ``(E,)`` — edge indices + + Return dict with at least one of: + + - ``acc`` ``(N, dim)`` — (normalized) acceleration + - ``vel`` ``(N, dim)`` — (normalized) velocity + - ``pos`` ``(N, dim)`` — absolute next position + """ @abstractmethod def __call__( self, sample: Tuple[Dict[str, jnp.ndarray], jnp.ndarray] ) -> Dict[str, jnp.ndarray]: - """Forward pass. - - We specify the dimensions of the inputs and outputs using the number of nodes N, - the number of edges E, number of historic velocities K (=input_seq_length - 1), - and the dimensionality of the feature vectors dim. - - Args: - sample: Tuple with feature dictionary and particle type. Possible features - - - "abs_pos" (N, K+1, dim), absolute positions - - "vel_hist" (N, K*dim), historical velocity sequence - - "vel_mag" (N,), velocity magnitudes - - "bound" (N, 2*dim), distance to boundaries - - "force" (N, dim), external force field - - "rel_disp" (E, dim), relative displacement vectors - - "rel_dist" (E, 1), relative distances, i.e. magnitude of displacements - - "senders" (E), sender indices - - "receivers" (E), receiver indices - Returns: - Dict with model output. - The keys must be at least one of the following: - - - "acc" (N, dim), (normalized) acceleration - - "vel" (N, dim), (normalized) velocity - - "pos" (N, dim), (absolute) next position - """ raise NotImplementedError diff --git a/lagrangebench/models/egnn.py b/lagrangebench/models/egnn.py index 5c8b6f1..1a890f2 100644 --- a/lagrangebench/models/egnn.py +++ b/lagrangebench/models/egnn.py @@ -7,13 +7,12 @@ Standalone implementation + validation: https://github.com/gerkone/egnn-jax """ -from typing import Any, Callable, Dict, Optional, Tuple +from typing import Callable, Dict, List, Optional, Tuple -import haiku as hk +import equinox as eqx import jax import jax.numpy as jnp import jraph -from jax.tree_util import Partial from jax_sph.jax_md import space from lagrangebench.utils import NodeType @@ -22,188 +21,266 @@ from .utils import LinearXav, MLPXav -class EGNNLayer(hk.Module): +def _segment_sum(messages, indices, num_segments): + return jraph.segment_sum(messages, indices, num_segments) + + +class _PosCorrectionMLP(eqx.Module): + """Per-edge coord-correction MLP: blocks hidden Linear layers then a scalar head. + + The final layer's init variance is scaled by ``dt`` (UniformScaling(dt) in haiku). + """ + + hidden: List[LinearXav] + head: LinearXav + activation: Callable = eqx.field(static=True) + use_tanh: bool = eqx.field(static=True) + + def __init__( + self, + in_size: int, + hidden_size: int, + blocks: int, + activation: Callable, + dt: float, + use_tanh: bool, + *, + key, + ): + keys = jax.random.split(key, blocks + 1) + sizes = [in_size] + [hidden_size] * blocks + self.hidden = [ + LinearXav(sizes[i], sizes[i + 1], key=keys[i]) for i in range(blocks) + ] + self.head = LinearXav( + hidden_size, + 1, + with_bias=False, + gain=dt, + init="uniform_scaling", + key=keys[-1], + ) + self.activation = activation + self.use_tanh = use_tanh + + def __call__(self, x: jnp.ndarray) -> jnp.ndarray: + for layer in self.hidden: + x = layer(x) + x = self.activation(x) + x = self.head(x) + if self.use_tanh: + x = jax.nn.tanh(x) + return x + + +class _VelCorrectionMLP(eqx.Module): + hidden: List[LinearXav] + head: LinearXav + activation: Callable = eqx.field(static=True) + + def __init__( + self, + in_size: int, + hidden_size: int, + blocks: int, + activation: Callable, + dt: float, + *, + key, + ): + keys = jax.random.split(key, blocks + 1) + sizes = [in_size] + [hidden_size] * blocks + self.hidden = [ + LinearXav(sizes[i], sizes[i + 1], key=keys[i]) for i in range(blocks) + ] + self.head = LinearXav( + hidden_size, + 1, + with_bias=False, + gain=dt, + init="uniform_scaling", + key=keys[-1], + ) + self.activation = activation + + def __call__(self, x: jnp.ndarray) -> jnp.ndarray: + for layer in self.hidden: + x = layer(x) + x = self.activation(x) + return self.head(x) + + +class _AttentionMLP(eqx.Module): + linear: LinearXav + + def __init__(self, in_size: int, hidden_size: int, *, key): + self.linear = LinearXav(in_size, hidden_size, key=key) + + def __call__(self, x: jnp.ndarray) -> jnp.ndarray: + return jax.nn.sigmoid(self.linear(x)) + + +class EGNNLayer(eqx.Module): r"""E(n)-equivariant EGNN layer. Applies a message passing step where the positions are corrected with the velocities and a learnable correction term :math:`\psi_x(\mathbf{h}_i^{(t+1)})`: """ + edge_mlp: MLPXav + node_mlp: MLPXav + pos_mlp: _PosCorrectionMLP + vel_mlp: _VelCorrectionMLP + attention_mlp: Optional[_AttentionMLP] + + displacement_fn: Callable = eqx.field(static=True) + shift_fn: Callable = eqx.field(static=True) + residual: bool = eqx.field(static=True) + normalize: bool = eqx.field(static=True) + pos_aggregate_fn: Callable = eqx.field(static=True) + msg_aggregate_fn: Callable = eqx.field(static=True) + eps: float = eqx.field(static=True) + def __init__( self, - layer_num: int, hidden_size: int, output_size: int, - displacement_fn: space.DisplacementFn, - shift_fn: space.ShiftFn, + displacement_fn: Callable, + shift_fn: Callable, + edge_in_size: int, + node_in_size: int, + has_edge_attr: bool, + has_node_attr: bool, blocks: int = 1, act_fn: Callable = jax.nn.silu, - pos_aggregate_fn: Optional[Callable] = jraph.segment_sum, - msg_aggregate_fn: Optional[Callable] = jraph.segment_sum, residual: bool = True, attention: bool = False, normalize: bool = False, tanh: bool = False, dt: float = 0.001, eps: float = 1e-8, + pos_aggregate_fn: Callable = _segment_sum, + msg_aggregate_fn: Callable = _segment_sum, + *, + key, ): """Initialize the layer. Args: - layer_num: layer number hidden_size: hidden size output_size: output size displacement_fn: Displacement function for the acceleration computation. shift_fn: Shift function for updating positions + edge_in_size: Size of ``edge_attr`` features fed into the edge MLP. + node_in_size: Size of the scalar node features fed into the edge MLP. + has_edge_attr: Whether an ``edge_attr`` tensor is concatenated to the + edge input (``rel_dist`` in practice). + has_node_attr: Whether a ``node_attr`` tensor is concatenated to the + node input. blocks: number of blocks in the node and edge MLPs act_fn: activation function - pos_aggregate_fn: position aggregation function - msg_aggregate_fn: message aggregation function residual: whether to use residual connections attention: whether to use attention normalize: whether to normalize the coordinates tanh: whether to use tanh in the position update dt: position update step size eps: small number to avoid division by zero + pos_aggregate_fn: position aggregation function + msg_aggregate_fn: message aggregation function + key: PRNG key for initialization. """ - super().__init__(f"layer_{layer_num}") - - self._displacement_fn = displacement_fn - self._shift_fn = shift_fn + self.displacement_fn = displacement_fn + self.shift_fn = shift_fn + self.residual = residual + self.normalize = normalize self.pos_aggregate_fn = pos_aggregate_fn self.msg_aggregate_fn = msg_aggregate_fn - self._residual = residual - self._normalize = normalize - self._eps = eps + self.eps = eps - # message network - self._edge_mlp = MLPXav( - [hidden_size] * blocks + [hidden_size], + k_e, k_n, k_p, k_v, k_a = jax.random.split(key, 5) + + # edge mlp: in = 2 * nodes + radial + (edge_attr) + edge_mlp_in = 2 * node_in_size + 1 + (edge_in_size if has_edge_attr else 0) + self.edge_mlp = MLPXav( + input_size=edge_mlp_in, + output_sizes=[hidden_size] * blocks + [hidden_size], activation=act_fn, activate_final=True, + key=k_e, ) - - # update network - self._node_mlp = MLPXav( - [hidden_size] * blocks + [output_size], + # node mlp: in = node + msg + (node_attr) + node_mlp_in = node_in_size + hidden_size + (1 if has_node_attr else 0) + self.node_mlp = MLPXav( + input_size=node_mlp_in, + output_sizes=[hidden_size] * blocks + [output_size], activation=act_fn, activate_final=False, + key=k_n, + ) + self.pos_mlp = _PosCorrectionMLP( + in_size=hidden_size, + hidden_size=hidden_size, + blocks=blocks, + activation=act_fn, + dt=dt, + use_tanh=tanh, + key=k_p, + ) + self.vel_mlp = _VelCorrectionMLP( + in_size=output_size, + hidden_size=hidden_size, + blocks=blocks, + activation=act_fn, + dt=dt, + key=k_v, + ) + self.attention_mlp = ( + _AttentionMLP(hidden_size, hidden_size, key=k_a) if attention else None ) - - # position update network - net = [LinearXav(hidden_size)] * blocks - # NOTE: from https://github.com/vgsatorras/egnn/blob/main/models/gcl.py#L254 - net += [ - act_fn, - LinearXav(1, with_bias=False, w_init=hk.initializers.UniformScaling(dt)), - ] - if tanh: - net.append(jax.nn.tanh) - self._pos_correction_mlp = hk.Sequential(net) - - # velocity integrator network - net = [LinearXav(hidden_size)] * blocks - net += [ - act_fn, - LinearXav(1, with_bias=False, w_init=hk.initializers.UniformScaling(dt)), - ] - self._vel_correction_mlp = hk.Sequential(net) - - # attention - self._attention_mlp = None - if attention: - self._attention_mlp = hk.Sequential( - [LinearXav(hidden_size), jax.nn.sigmoid] - ) - - def _pos_update( - self, - pos: jnp.ndarray, - graph: jraph.GraphsTuple, - coord_diff: jnp.ndarray, - ) -> jnp.ndarray: - trans = coord_diff * self._pos_correction_mlp(graph.edges) - return self.pos_aggregate_fn(trans, graph.senders, num_segments=pos.shape[0]) - - def _message( - self, - radial: jnp.ndarray, - edge_attribute: jnp.ndarray, - edge_features: Any, - incoming: jnp.ndarray, - outgoing: jnp.ndarray, - globals_: Any, - ) -> jnp.ndarray: - _ = edge_features - _ = globals_ - msg = jnp.concatenate([incoming, outgoing, radial], axis=-1) - if edge_attribute is not None: - msg = jnp.concatenate([msg, edge_attribute], axis=-1) - msg = self._edge_mlp(msg) - if self._attention_mlp: - att = self._attention_mlp(msg) - msg = msg * att - return msg - - def _update( - self, - node_attribute: jnp.ndarray, - nodes: jnp.ndarray, - senders: Any, - msg: jnp.ndarray, - globals_: Any, - ) -> jnp.ndarray: - _ = senders - _ = globals_ - x = jnp.concatenate([nodes, msg], axis=-1) - if node_attribute is not None: - x = jnp.concatenate([x, node_attribute], axis=-1) - x = self._node_mlp(x) - if self._residual: - x = nodes + x - return x def _coord2radial( - self, graph: jraph.GraphsTuple, coord: jnp.array - ) -> Tuple[jnp.array, jnp.array]: - coord_diff = self._displacement_fn(coord[graph.senders], coord[graph.receivers]) + self, senders: jnp.ndarray, receivers: jnp.ndarray, coord: jnp.ndarray + ) -> Tuple[jnp.ndarray, jnp.ndarray]: + coord_diff = self.displacement_fn(coord[senders], coord[receivers]) radial = jnp.sum(coord_diff**2, 1)[:, jnp.newaxis] - if self._normalize: + if self.normalize: norm = jnp.sqrt(radial) - coord_diff = coord_diff / (norm + self._eps) + coord_diff = coord_diff / (norm + self.eps) return radial, coord_diff def __call__( self, - graph: jraph.GraphsTuple, + nodes: jnp.ndarray, + senders: jnp.ndarray, + receivers: jnp.ndarray, pos: jnp.ndarray, vel: jnp.ndarray, - edge_attribute: Optional[jnp.ndarray] = None, - node_attribute: Optional[jnp.ndarray] = None, - ) -> Tuple[jraph.GraphsTuple, jnp.ndarray]: - """ - Apply EGNN layer. - - Args: - graph: Graph from previous step - pos: Node position, updated separately - vel: Node velocity - edge_attribute: Edge attribute (optional) - node_attribute: Node attribute (optional) - Returns: - Updated graph, node position - """ - radial, coord_diff = self._coord2radial(graph, pos) - graph = jraph.GraphNetwork( - update_edge_fn=Partial(self._message, radial, edge_attribute), - update_node_fn=Partial(self._update, node_attribute), - aggregate_edges_for_nodes_fn=self.msg_aggregate_fn, - )(graph) - # update position - pos = self._shift_fn(pos, self._pos_update(pos, graph, coord_diff)) - # integrate velocity - pos = self._shift_fn(pos, self._vel_correction_mlp(graph.nodes) * vel) - return graph, pos + edge_attr: Optional[jnp.ndarray] = None, + node_attr: Optional[jnp.ndarray] = None, + ) -> Tuple[jnp.ndarray, jnp.ndarray]: + radial, coord_diff = self._coord2radial(senders, receivers, pos) + + msg_inputs = [nodes[senders], nodes[receivers], radial] + if edge_attr is not None: + msg_inputs.append(edge_attr) + msg = self.edge_mlp(jnp.concatenate(msg_inputs, axis=-1)) + if self.attention_mlp is not None: + msg = msg * self.attention_mlp(msg) + + aggr = self.msg_aggregate_fn(msg, receivers, nodes.shape[0]) + node_inputs = [nodes, aggr] + if node_attr is not None: + node_inputs.append(node_attr) + new_nodes = self.node_mlp(jnp.concatenate(node_inputs, axis=-1)) + if self.residual: + new_nodes = nodes + new_nodes + + trans = coord_diff * self.pos_mlp(msg) + pos_delta = self.pos_aggregate_fn(trans, senders, num_segments=pos.shape[0]) + pos = self.shift_fn(pos, pos_delta) + pos = self.shift_fn(pos, self.vel_mlp(new_nodes) * vel) + + return new_nodes, pos class EGNN(BaseModel): @@ -247,15 +324,27 @@ class EGNN(BaseModel): - we apply a simple integrator after the last layer to get the acceleration. """ + node_embed: LinearXav + layers: List[EGNNLayer] + + output_size: int = eqx.field(static=True) + num_mp_steps: int = eqx.field(static=True) + n_vels: int = eqx.field(static=True) + homogeneous_particles: bool = eqx.field(static=True) + vel_mean: jnp.ndarray + vel_std: jnp.ndarray + displacement_fn: Callable = eqx.field(static=True) + shift_fn: Callable = eqx.field(static=True) + def __init__( self, hidden_size: int, output_size: int, dt: float, n_vels: int, - displacement_fn: space.DisplacementFn, - shift_fn: space.ShiftFn, - normalization_stats: Optional[Dict[str, jnp.ndarray]] = None, + displacement_fn: Callable, + shift_fn: Callable, + normalization_stats: Optional[Dict[str, Dict[str, jnp.ndarray]]] = None, act_fn: Callable = jax.nn.silu, num_mp_steps: int = 4, homogeneous_particles: bool = True, @@ -263,6 +352,8 @@ def __init__( attention: bool = False, normalize: bool = False, tanh: bool = False, + *, + key, ): r""" Initialize the network. @@ -287,114 +378,119 @@ def __init__( tanh: Sets a tanh activation function at the output of ``\phi_x(m_{ij})``. It bounds the output of ``\phi_x(m_{ij})`` which definitely improves in stability but it may decrease in accuracy. Not used in the paper. + key: PRNG key for initialization. """ - super().__init__() - # network - self._hidden_size = hidden_size - self._output_size = output_size - self._act_fn = act_fn - self._num_mp_steps = num_mp_steps - self._residual = residual - self._attention = attention - self._normalize = normalize - self._tanh = tanh - - # integrator - self._dt = dt / self._num_mp_steps - self._displacement_fn = displacement_fn - self._shift_fn = shift_fn + self.output_size = output_size + self.num_mp_steps = num_mp_steps + self.n_vels = n_vels + self.homogeneous_particles = homogeneous_particles + self.displacement_fn = displacement_fn + self.shift_fn = shift_fn + if normalization_stats is None: normalization_stats = { - "velocity": {"mean": 0.0, "std": 1.0}, - "acceleration": {"mean": 0.0, "std": 1.0}, + "velocity": {"mean": jnp.array(0.0), "std": jnp.array(1.0)}, + "acceleration": {"mean": jnp.array(0.0), "std": jnp.array(1.0)}, } - self._vel_stats = normalization_stats["velocity"] - self._acc_stats = normalization_stats["acceleration"] - - # transform - self._n_vels = n_vels - self._homogeneous_particles = homogeneous_particles + self.vel_mean = jnp.asarray(normalization_stats["velocity"]["mean"]) + self.vel_std = jnp.asarray(normalization_stats["velocity"]["std"]) + + # node scalar features: one magnitude per historical velocity + node_scalar_size = n_vels + if not homogeneous_particles: + node_scalar_size += NodeType.SIZE + + keys = jax.random.split(key, num_mp_steps + 1) + self.node_embed = LinearXav(node_scalar_size, hidden_size, key=keys[0]) + + dt_per_step = dt / num_mp_steps + self.layers = [ + EGNNLayer( + hidden_size=hidden_size, + output_size=hidden_size, + displacement_fn=displacement_fn, + shift_fn=shift_fn, + edge_in_size=1, + node_in_size=hidden_size, + has_edge_attr=True, # rel_dist + has_node_attr=False, + act_fn=act_fn, + residual=residual, + attention=attention, + normalize=normalize, + tanh=tanh, + dt=dt_per_step, + key=keys[i + 1], + ) + for i in range(num_mp_steps) + ] def _transform( self, features: Dict[str, jnp.ndarray], particle_type: jnp.ndarray - ) -> Tuple[jraph.GraphsTuple, Dict[str, jnp.ndarray]]: - props = {} + ) -> Tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray, jnp.ndarray, jnp.ndarray, Optional[jnp.ndarray], Optional[jnp.ndarray]]: n_nodes = features["vel_hist"].shape[0] + vel = jnp.reshape(features["vel_hist"], (n_nodes, self.n_vels, -1)) + pos = features["abs_pos"][:, -1] + edge_attr = features["rel_dist"] - props["vel"] = jnp.reshape(features["vel_hist"], (n_nodes, self._n_vels, -1)) - - # most recent position - props["pos"] = features["abs_pos"][:, -1] - # relative distances between particles - props["edge_attr"] = features["rel_dist"] - # force magnitude as node attributes - props["node_attr"] = None + node_attr = None if "force" in features: - props["node_attr"] = jnp.sqrt( + node_attr = jnp.sqrt( jnp.sum(features["force"] ** 2, axis=-1, keepdims=True) ) - # velocity magnitudes as node features node_features = jnp.concatenate( [ - jnp.sqrt(jnp.sum(props["vel"][:, i, :] ** 2, axis=-1, keepdims=True)) - for i in range(self._n_vels) + jnp.sqrt(jnp.sum(vel[:, i, :] ** 2, axis=-1, keepdims=True)) + for i in range(self.n_vels) ], axis=-1, ) - if not self._homogeneous_particles: + if not self.homogeneous_particles: particles = jax.nn.one_hot(particle_type, NodeType.SIZE) node_features = jnp.concatenate([node_features, particles], axis=-1) - graph = jraph.GraphsTuple( - nodes=node_features, - edges=None, - senders=features["senders"], - receivers=features["receivers"], - n_node=jnp.array([n_nodes]), - n_edge=jnp.array([len(features["senders"])]), - globals=None, + return ( + node_features, + vel, + pos, + features["senders"], + features["receivers"], + edge_attr, + node_attr, ) - return graph, props - - def _postprocess( - self, next_pos: jnp.ndarray, props: Dict[str, jnp.ndarray] - ) -> Dict[str, jnp.ndarray]: - prev_vel = props["vel"][:, -1, :] - prev_pos = props["pos"] - # first order finite difference - next_vel = self._displacement_fn(next_pos, prev_pos) - acc = next_vel - prev_vel - return {"pos": next_pos, "vel": next_vel, "acc": acc} - def __call__( self, sample: Tuple[Dict[str, jnp.ndarray], jnp.ndarray] ) -> Dict[str, jnp.ndarray]: - graph, props = self._transform(*sample) - # input node embedding - h = LinearXav(self._hidden_size, name="scalar_emb")(graph.nodes) - graph = graph._replace(nodes=h) - prev_vel = props["vel"][:, -1, :] - # egnn works with unnormalized velocities - prev_vel = prev_vel * self._vel_stats["std"] + self._vel_stats["mean"] - # message passing - next_pos = props["pos"].copy() - for n in range(self._num_mp_steps): - graph, next_pos = EGNNLayer( - layer_num=n, - hidden_size=self._hidden_size, - output_size=self._hidden_size, - displacement_fn=self._displacement_fn, - shift_fn=self._shift_fn, - act_fn=self._act_fn, - residual=self._residual, - attention=self._attention, - normalize=self._normalize, - dt=self._dt, - tanh=self._tanh, - )(graph, next_pos, prev_vel, props["edge_attr"], props["node_attr"]) - - # position finite differencing to get acceleration - out = self._postprocess(next_pos, props) - return out + features, particle_type = sample + ( + nodes, + vel_hist, + pos, + senders, + receivers, + edge_attr, + node_attr, + ) = self._transform(features, particle_type) + + prev_vel_norm = vel_hist[:, -1, :] + # EGNN works with unnormalized velocities + prev_vel = prev_vel_norm * self.vel_std + self.vel_mean + + h = self.node_embed(nodes) + next_pos = pos + for layer in self.layers: + h, next_pos = layer( + nodes=h, + senders=senders, + receivers=receivers, + pos=next_pos, + vel=prev_vel, + edge_attr=edge_attr, + node_attr=node_attr, + ) + + next_vel = self.displacement_fn(next_pos, pos) + acc = next_vel - prev_vel + return {"pos": next_pos, "vel": next_vel, "acc": acc} diff --git a/lagrangebench/models/gns.py b/lagrangebench/models/gns.py index 9020231..ff54f7a 100644 --- a/lagrangebench/models/gns.py +++ b/lagrangebench/models/gns.py @@ -1,37 +1,77 @@ -""" -Graph Network-based Simulator. -GNS model and feature transform. +"""Graph Network-based Simulator (Equinox). + +Reference: Sanchez-Gonzalez et al. (https://arxiv.org/abs/2002.09405). """ -from typing import Dict, Tuple +from typing import Dict, List, Tuple -import haiku as hk +import equinox as eqx +import jax import jax.numpy as jnp import jraph from lagrangebench.utils import NodeType from .base import BaseModel -from .utils import build_mlp +from .utils import Embedding, MLPBlock, build_mlp + + +class _MPStep(eqx.Module): + """One Graph Network block: edge update, node update, residual.""" + + edge_mlp: MLPBlock + node_mlp: MLPBlock + + def __init__(self, latent_size: int, blocks_per_step: int, *, key): + k_e, k_n = jax.random.split(key) + self.edge_mlp = build_mlp( + input_size=3 * latent_size, + latent_size=latent_size, + output_size=latent_size, + num_hidden_layers=blocks_per_step, + is_layer_norm=True, + key=k_e, + ) + self.node_mlp = build_mlp( + input_size=2 * latent_size, + latent_size=latent_size, + output_size=latent_size, + num_hidden_layers=blocks_per_step, + is_layer_norm=True, + key=k_n, + ) + + def __call__( + self, + nodes: jnp.ndarray, + edges: jnp.ndarray, + senders: jnp.ndarray, + receivers: jnp.ndarray, + ) -> Tuple[jnp.ndarray, jnp.ndarray]: + msg_in = jnp.concatenate( + [nodes[senders], nodes[receivers], edges], axis=-1 + ) + new_edges = self.edge_mlp(msg_in) + aggr = jraph.segment_sum(new_edges, receivers, nodes.shape[0]) + node_in = jnp.concatenate([nodes, aggr], axis=-1) + new_nodes = self.node_mlp(node_in) + return nodes + new_nodes, edges + new_edges class GNS(BaseModel): r"""Graph Network-based Simulator by `Sanchez-Gonzalez et al. `_. - - GNS is the simples graph neural network applied to particle dynamics. It is built on - the usual Graph Network architecture, with an encoder, a processor, and a decoder. - - .. math:: - \begin{align} - \mathbf{m}_{ij}^{(t+1)} &= \phi \left( - \mathbf{m}_{ij}^{(t)}, \mathbf{h}_i^{(t)}, \mathbf{h}_j^{(t)} \right) \\ - \mathbf{h}_i^{(t+1)} &= \psi \left( - \mathbf{h}_i^{(t)}, \sum_{j \in \mathcal{N}(i)} \mathbf{m}_{ij}^{(t+1)} - \right) \\ - \end{align} """ + node_encoder: MLPBlock + edge_encoder: MLPBlock + mp_steps: List[_MPStep] + decoder: MLPBlock + embedding: Embedding + output_size: int = eqx.field(static=True) + num_mp_steps: int = eqx.field(static=True) + num_particle_types: int = eqx.field(static=True) + def __init__( self, particle_dimension: int, @@ -39,133 +79,92 @@ def __init__( blocks_per_step: int, num_mp_steps: int, particle_type_embedding_size: int, + node_in_size: int, + edge_in_size: int, num_particle_types: int = NodeType.SIZE, + *, + key, ): """Initialize the model. Args: - particle_dimension: Space dimensionality (e.g. 2 or 3). + particle_dimension: Space dimensionality (2 or 3). latent_size: Size of the latent representations. - blocks_per_step: Number of MLP layers per block. + blocks_per_step: Number of MLP hidden layers per GN block. num_mp_steps: Number of message passing steps. - particle_type_embedding_size: Size of the particle type embedding. - num_particle_types: Max number of particle types. + particle_type_embedding_size: Size of the particle-type embedding. + node_in_size: Size of scalar node features produced by + ``case.preprocess`` (sum of vel_hist, vel_mag, bound, force dims). + edge_in_size: Size of edge features produced by ``case.preprocess`` + (``rel_disp`` + ``rel_dist``). + num_particle_types: Max number of particle types; embedding is skipped + when ``num_particle_types <= 1``. + key: PRNG key for initialization. """ - super().__init__() - self._output_size = particle_dimension - self._latent_size = latent_size - self._blocks_per_step = blocks_per_step - self._mp_steps = num_mp_steps - self._num_particle_types = num_particle_types - - self._embedding = hk.Embed( - num_particle_types, particle_type_embedding_size - ) # (9, 16) - - def _encoder(self, graph: jraph.GraphsTuple) -> jraph.GraphsTuple: - """MLP graph encoder.""" - node_latents = build_mlp( - self._latent_size, self._latent_size, self._blocks_per_step - )(graph.nodes) - edge_latents = build_mlp( - self._latent_size, self._latent_size, self._blocks_per_step - )(graph.edges) - return jraph.GraphsTuple( - nodes=node_latents, - edges=edge_latents, - globals=graph.globals, - receivers=graph.receivers, - senders=graph.senders, - n_node=jnp.asarray([node_latents.shape[0]]), - n_edge=jnp.asarray([edge_latents.shape[0]]), + self.output_size = particle_dimension + self.num_mp_steps = num_mp_steps + self.num_particle_types = num_particle_types + + k_emb, k_nenc, k_eenc, k_dec, *k_mp = jax.random.split(key, 4 + num_mp_steps) + self.embedding = Embedding( + num_particle_types, particle_type_embedding_size, key=k_emb ) - def _processor(self, graph: jraph.GraphsTuple) -> jraph.GraphsTuple: - """Sequence of Graph Network blocks.""" - - def update_edge_features( - edge_features, - sender_node_features, - receiver_node_features, - _, # globals_ - ): - update_fn = build_mlp( - self._latent_size, self._latent_size, self._blocks_per_step - ) - # Calculate sender node features from edge features - return update_fn( - jnp.concatenate( - [sender_node_features, receiver_node_features, edge_features], - axis=-1, - ) - ) - - def update_node_features( - node_features, - _, # aggr_sender_edge_features, - aggr_receiver_edge_features, - __, # globals_, - ): - update_fn = build_mlp( - self._latent_size, self._latent_size, self._blocks_per_step - ) - features = [node_features, aggr_receiver_edge_features] - return update_fn(jnp.concatenate(features, axis=-1)) - - # Perform iterative message passing by stacking Graph Network blocks - for _ in range(self._mp_steps): - _graph = jraph.GraphNetwork( - update_edge_fn=update_edge_features, update_node_fn=update_node_features - )(graph) - graph = graph._replace( - nodes=_graph.nodes + graph.nodes, edges=_graph.edges + graph.edges - ) - - return graph - - def _decoder(self, graph: jraph.GraphsTuple): - """MLP graph node decoder.""" - return build_mlp( - self._latent_size, - self._output_size, - self._blocks_per_step, + embed_size = particle_type_embedding_size if num_particle_types > 1 else 0 + self.node_encoder = build_mlp( + input_size=node_in_size + embed_size, + latent_size=latent_size, + output_size=latent_size, + num_hidden_layers=blocks_per_step, + is_layer_norm=True, + key=k_nenc, + ) + self.edge_encoder = build_mlp( + input_size=edge_in_size, + latent_size=latent_size, + output_size=latent_size, + num_hidden_layers=blocks_per_step, + is_layer_norm=True, + key=k_eenc, + ) + self.mp_steps = [ + _MPStep(latent_size, blocks_per_step, key=k) for k in k_mp + ] + self.decoder = build_mlp( + input_size=latent_size, + latent_size=latent_size, + output_size=particle_dimension, + num_hidden_layers=blocks_per_step, is_layer_norm=False, - )(graph.nodes) + key=k_dec, + ) def _transform( - self, features: Dict[str, jnp.ndarray], particle_type: jnp.ndarray - ) -> jraph.GraphsTuple: - """Convert physical features to jraph.GraphsTuple for gns.""" - n_total_points = features["vel_hist"].shape[0] + self, features: Dict[str, jnp.ndarray] + ) -> Tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray, jnp.ndarray]: node_features = [ features[k] for k in ["vel_hist", "vel_mag", "bound", "force"] if k in features ] - edge_features = [features[k] for k in ["rel_disp", "rel_dist"] if k in features] - - graph = jraph.GraphsTuple( - nodes=jnp.concatenate(node_features, axis=-1), - edges=jnp.concatenate(edge_features, axis=-1), - receivers=features["receivers"], - senders=features["senders"], - n_node=jnp.array([n_total_points]), - n_edge=jnp.array([len(features["senders"])]), - globals=None, - ) - - return graph, particle_type + edge_features = [ + features[k] for k in ["rel_disp", "rel_dist"] if k in features + ] + nodes = jnp.concatenate(node_features, axis=-1) + edges = jnp.concatenate(edge_features, axis=-1) + return nodes, edges, features["senders"], features["receivers"] def __call__( self, sample: Tuple[Dict[str, jnp.ndarray], jnp.ndarray] ) -> Dict[str, jnp.ndarray]: - graph, particle_type = self._transform(*sample) - - if self._num_particle_types > 1: - particle_type_embeddings = self._embedding(particle_type) - new_node_features = jnp.concatenate( - [graph.nodes, particle_type_embeddings], axis=-1 - ) - graph = graph._replace(nodes=new_node_features) - acc = self._decoder(self._processor(self._encoder(graph))) - return {"acc": acc} + features, particle_type = sample + nodes, edges, senders, receivers = self._transform(features) + + if self.num_particle_types > 1: + nodes = jnp.concatenate([nodes, self.embedding(particle_type)], axis=-1) + + nodes = self.node_encoder(nodes) + edges = self.edge_encoder(edges) + for step in self.mp_steps: + nodes, edges = step(nodes, edges, senders, receivers) + return {"acc": self.decoder(nodes)} diff --git a/lagrangebench/models/linear.py b/lagrangebench/models/linear.py index d804b4a..12f8233 100644 --- a/lagrangebench/models/linear.py +++ b/lagrangebench/models/linear.py @@ -1,42 +1,42 @@ -"""Simple baseline linear model.""" +"""Simple baseline linear model (Equinox).""" from typing import Dict, Tuple -import haiku as hk import jax.numpy as jnp import numpy as np -from jax import vmap from .base import BaseModel +from .utils import LinearXav class Linear(BaseModel): - r"""Model defining linear relation between input nodes and targets. + r"""Linear relation between input scalar features and targets. - :math:`\mathbf{a}_i = \mathbf{W} \mathbf{x}_i` where :math:`\mathbf{a}_i` are the - output accelerations, :math:`\mathbf{W}` is a learnable weight matrix and - :math:`\mathbf{x}_i` are input features. + :math:`\mathbf{a}_i = \mathbf{W} \mathbf{x}_i`. """ - def __init__(self, dim_out): + mlp: LinearXav + + def __init__(self, in_features: int, dim_out: int, *, key): """Initialize the model. Args: + in_features: Total number of scalar input features per node (sum of the + dimensionalities of ``vel_hist``, ``vel_mag``, ``bound``, ``force`` + plus 1 for the particle type). dim_out: Output dimensionality. + key: PRNG key for init. """ - super().__init__() - self.mlp = hk.Linear(dim_out) + self.mlp = LinearXav(in_features, dim_out, key=key) def __call__( self, sample: Tuple[Dict[str, jnp.ndarray], np.ndarray] ) -> Dict[str, jnp.ndarray]: - # transform features, particle_type = sample x = [ features[k] for k in ["vel_hist", "vel_mag", "bound", "force"] if k in features ] + [particle_type[:, None]] - # call - acc = vmap(self.mlp)(jnp.concatenate(x, axis=-1)) + acc = self.mlp(jnp.concatenate(x, axis=-1)) return {"acc": acc} diff --git a/lagrangebench/models/painn.py b/lagrangebench/models/painn.py index de83f98..2362143 100644 --- a/lagrangebench/models/painn.py +++ b/lagrangebench/models/painn.py @@ -8,16 +8,16 @@ Standalone implementation + validation: https://github.com/gerkone/painn-jax """ -from typing import Callable, Dict, NamedTuple, Tuple +from typing import Callable, Dict, List, NamedTuple, Optional, Tuple -import haiku as hk +import equinox as eqx import jax import jax.numpy as jnp -import jax.tree_util as tree import jraph from lagrangebench.utils import NodeType +from .base import BaseModel from .utils import LinearXav @@ -32,54 +32,63 @@ class NodeFeatures(NamedTuple): ReadoutBuilderFn = Callable[..., ReadoutFn] -class GatedEquivariantBlock(hk.Module): +class GatedEquivariantBlock(eqx.Module): """Gated equivariant block (restricted to vectorial features). .. image:: https://i.imgur.com/EMlg2Qi.png """ + vector_mix_net: LinearXav + gate_lin1: LinearXav + gate_lin2: LinearXav + activation: Callable = eqx.field(static=True) + scalar_activation: Optional[Callable] = eqx.field(static=True) + scalar_out_channels: int = eqx.field(static=True) + vector_out_channels: int = eqx.field(static=True) + eps: float = eqx.field(static=True) + def __init__( self, + s_in_size: int, + v_in_channels: int, hidden_size: int, scalar_out_channels: int, vector_out_channels: int, activation: Callable = jax.nn.silu, - scalar_activation: Callable = None, + scalar_activation: Optional[Callable] = None, eps: float = 1e-8, - name: str = "gated_equivariant_block", + *, + key, ): """Initialize the layer. Args: + s_in_size: Input scalar feature size. + v_in_channels: Input vector feature channels. hidden_size: Number of hidden channels. scalar_out_channels: Number of scalar output channels. vector_out_channels: Number of vector output channels. activation: Gate activation function. scalar_activation: Activation function for the scalar output. eps: Constant added in norm to prevent derivation instabilities. - name: Name of the module. - + key: PRNG key for initialization. """ - super().__init__(name) - assert scalar_out_channels > 0 and vector_out_channels > 0 - self._scalar_out_channels = scalar_out_channels - self._vector_out_channels = vector_out_channels - self._eps = eps + self.scalar_out_channels = scalar_out_channels + self.vector_out_channels = vector_out_channels + self.eps = eps + k_v, k_g1, k_g2 = jax.random.split(key, 3) self.vector_mix_net = LinearXav( - 2 * vector_out_channels, - with_bias=False, - name="vector_mix_net", + v_in_channels, 2 * vector_out_channels, with_bias=False, key=k_v ) - self.gate_block = hk.Sequential( - [ - LinearXav(hidden_size), - activation, - LinearXav(scalar_out_channels + vector_out_channels), - ], - name="scalar_gate_net", + self.gate_lin1 = LinearXav( + s_in_size + vector_out_channels, hidden_size, key=k_g1 ) + self.gate_lin2 = LinearXav( + hidden_size, scalar_out_channels + vector_out_channels, key=k_g2 + ) + self.activation = activation self.scalar_activation = scalar_activation def __call__( @@ -87,20 +96,62 @@ def __call__( ) -> Tuple[jnp.ndarray, jnp.ndarray]: v_l, v_r = jnp.split(self.vector_mix_net(v), 2, axis=-1) - v_r_norm = jnp.sqrt(jnp.sum(v_r**2, axis=-2) + self._eps) + v_r_norm = jnp.sqrt(jnp.sum(v_r**2, axis=-2) + self.eps) gating_scalars = jnp.concatenate([s, v_r_norm], axis=-1) - s, _, v_gate = jnp.split( - self.gate_block(gating_scalars), - [self._scalar_out_channels, self._vector_out_channels], + gate_out = self.gate_lin2(self.activation(self.gate_lin1(gating_scalars))) + s_new, _, v_gate = jnp.split( + gate_out, + [self.scalar_out_channels, self.vector_out_channels], axis=-1, ) # scale the vectors by the gating scalars - v = v_l * v_gate[:, jnp.newaxis] + v_new = v_l * v_gate[:, jnp.newaxis] - if self.scalar_activation: - s = self.scalar_activation(s) + if self.scalar_activation is not None: + s_new = self.scalar_activation(s_new) + return s_new, v_new - return s, v + +class GaussianRBF(eqx.Module): + r"""Gaussian radial basis functions. + + Args: + n_rbf: total number of Gaussian functions, :math:`N_g`. + cutoff: center of last Gaussian function, :math:`\mu_{N_g}` + start: center of first Gaussian function, :math:`\mu_0`. + trainable: If True, widths and offset of Gaussian functions learnable. + """ + + widths: jnp.ndarray + offsets: jnp.ndarray + trainable: bool = eqx.field(static=True) + + def __init__( + self, + n_rbf: int, + cutoff: float, + start: float = 0.0, + centered: bool = False, + trainable: bool = False, + ): + if centered: + widths = jnp.linspace(start, cutoff, n_rbf) + offsets = jnp.zeros_like(widths) + else: + offsets = jnp.linspace(start, cutoff, n_rbf) + widths = jnp.abs(cutoff - start) / n_rbf * jnp.ones_like(offsets) + self.widths = widths + self.offsets = offsets + self.trainable = trainable + + def __call__(self, x: jnp.ndarray) -> jnp.ndarray: + widths = self.widths if self.trainable else jax.lax.stop_gradient(self.widths) + offsets = ( + self.offsets if self.trainable else jax.lax.stop_gradient(self.offsets) + ) + coeff = -0.5 / jnp.power(widths, 2) + diff = x[..., jnp.newaxis] - offsets + return jnp.exp(coeff * jnp.power(diff, 2)) def gaussian_rbf( @@ -109,7 +160,7 @@ def gaussian_rbf( start: float = 0.0, centered: bool = False, trainable: bool = False, -) -> Callable[[jnp.ndarray], Callable]: +) -> GaussianRBF: r"""Gaussian radial basis functions. Args: @@ -118,35 +169,10 @@ def gaussian_rbf( start: center of first Gaussian function, :math:`\mu_0`. trainable: If True, widths and offset of Gaussian functions learnable. """ - if centered: - widths = jnp.linspace(start, cutoff, n_rbf) - offset = jnp.zeros_like(widths) - else: - offset = jnp.linspace(start, cutoff, n_rbf) - width = jnp.abs(cutoff - start) / n_rbf * jnp.ones_like(offset) - - if trainable: - widths = hk.get_parameter( - "widths", width.shape, width.dtype, init=lambda *_: width - ) - offsets = hk.get_parameter( - "offset", offset.shape, offset.dtype, init=lambda *_: offset - ) - else: - hk.set_state("widths", jnp.array([width])) - hk.set_state("offsets", jnp.array([offset])) - widths = hk.get_state("widths") - offsets = hk.get_state("offsets") - - def _rbf(x: jnp.ndarray) -> jnp.ndarray: - coeff = -0.5 / jnp.power(widths, 2) - diff = x[..., jnp.newaxis] - offsets - return jnp.exp(coeff * jnp.power(diff, 2)) + return GaussianRBF(n_rbf, cutoff, start=start, centered=centered, trainable=trainable) - return _rbf - -def cosine_cutoff(cutoff: float) -> Callable[[jnp.ndarray], Callable]: +class CosineCutoff(eqx.Module): r"""Behler-style cosine cutoff. .. math:: @@ -159,78 +185,142 @@ def cosine_cutoff(cutoff: float) -> Callable[[jnp.ndarray], Callable]: Args: cutoff (float): cutoff radius. """ - hk.set_state("cutoff", cutoff) - cutoff = hk.get_state("cutoff") - - def _cutoff(x: jnp.ndarray) -> jnp.ndarray: - # Compute values of cutoff function - cuts = 0.5 * (jnp.cos(x * jnp.pi / cutoff) + 1.0) - # Remove contributions beyond the cutoff radius - mask = jnp.array(x < cutoff, dtype=jnp.float32) + + cutoff: float = eqx.field(static=True) + + def __init__(self, cutoff: float): + self.cutoff = cutoff + + def __call__(self, x: jnp.ndarray) -> jnp.ndarray: + cuts = 0.5 * (jnp.cos(x * jnp.pi / self.cutoff) + 1.0) + mask = jnp.array(x < self.cutoff, dtype=jnp.float32) return cuts * mask - return _cutoff + +def cosine_cutoff(cutoff: float) -> CosineCutoff: + r"""Behler-style cosine cutoff. + + Args: + cutoff (float): cutoff radius. + """ + return CosineCutoff(cutoff) + + +class _PaiNNReadout(eqx.Module): + """PaiNN readout block.""" + + blocks: List[GatedEquivariantBlock] + + def __init__( + self, + hidden_size: int, + in_channels: int, + out_channels: int, + activation: Callable, + blocks: int, + eps: float, + *, + key, + ): + keys = jax.random.split(key, blocks) + modules = [] + s_in = in_channels + v_in = in_channels + ith_hidden_size = hidden_size + for i in range(blocks - 1): + ith_hidden_size = hidden_size // 2 ** (i + 1) + modules.append( + GatedEquivariantBlock( + s_in_size=s_in, + v_in_channels=v_in, + hidden_size=ith_hidden_size * 2, + scalar_out_channels=ith_hidden_size, + vector_out_channels=ith_hidden_size, + activation=activation, + eps=eps, + key=keys[i], + ) + ) + s_in = ith_hidden_size + v_in = ith_hidden_size + modules.append( + GatedEquivariantBlock( + s_in_size=s_in, + v_in_channels=v_in, + hidden_size=ith_hidden_size, + scalar_out_channels=out_channels, + vector_out_channels=out_channels, + activation=activation, + eps=eps, + key=keys[-1], + ) + ) + self.blocks = modules + + def __call__(self, graph: jraph.GraphsTuple) -> Tuple[jnp.ndarray, jnp.ndarray]: + s, v = graph.nodes + s = jnp.squeeze(s) + for block in self.blocks: + s, v = block(s, v) + return jnp.squeeze(s), jnp.squeeze(v) def PaiNNReadout( hidden_size: int, + in_channels: int, out_channels: int = 1, activation: Callable = jax.nn.silu, blocks: int = 2, eps: float = 1e-8, -) -> ReadoutFn: + *, + key, +) -> _PaiNNReadout: """ PaiNN readout block. Args: hidden_size: Number of hidden channels. - scalar_out_channels: Number of scalar/vector output channels. + in_channels: Input scalar/vector channels from the last interaction block. + out_channels: Number of scalar/vector output channels. activation: Activation function. blocks: Number of readout blocks. + eps: Small constant for norm stability. + key: PRNG key for initialization. Returns: - Configured readout function. + Configured readout module. """ - - def _readout(graph: jraph.GraphsTuple) -> Tuple[jnp.ndarray, jnp.ndarray]: - s, v = graph.nodes - s = jnp.squeeze(s) - for i in range(blocks - 1): - ith_hidden_size = hidden_size // 2 ** (i + 1) - s, v = GatedEquivariantBlock( - hidden_size=ith_hidden_size * 2, - scalar_out_channels=ith_hidden_size, - vector_out_channels=ith_hidden_size, - activation=activation, - eps=eps, - name=f"readout_block_{i}", - )(s, v) - - s, v = GatedEquivariantBlock( - hidden_size=ith_hidden_size, - scalar_out_channels=out_channels, - vector_out_channels=out_channels, - activation=activation, - eps=eps, - name="readout_block_out", - )(s, v) - - return jnp.squeeze(s), jnp.squeeze(v) - - return _readout - - -class PaiNNLayer(hk.Module): + return _PaiNNReadout( + hidden_size=hidden_size, + in_channels=in_channels, + out_channels=out_channels, + activation=activation, + blocks=blocks, + eps=eps, + key=key, + ) + + +class PaiNNLayer(eqx.Module): """PaiNN interaction block.""" + interaction_block_lin: List[LinearXav] + mixing_block_lin: List[LinearXav] + vector_mixing_block: LinearXav + activation: Callable = eqx.field(static=True) + aggregate_fn: Callable = eqx.field(static=True) + hidden_size: int = eqx.field(static=True) + eps: float = eqx.field(static=True) + def __init__( self, hidden_size: int, - layer_num: int, activation: Callable = jax.nn.silu, blocks: int = 2, aggregate_fn: Callable = jraph.segment_sum, eps: float = 1e-8, + *, + key, ): """ Initialize the PaiNN layer, made up of an interaction block and a mixing block. @@ -238,37 +328,57 @@ def __init__( Args: hidden_size: Number of node features. activation: Activation function. - layer_num: Numbering of the layer. blocks: Number of layers in the context networks. aggregate_fn: Function to aggregate the neighbors. eps: Constant added in norm to prevent derivation instabilities. + key: PRNG key for initialization. """ - super().__init__(f"layer_{layer_num}") - self._hidden_size = hidden_size - self._eps = eps - self._aggregate_fn = aggregate_fn - - # inter-particle context net - self.interaction_block = hk.Sequential( - [LinearXav(hidden_size), activation] * (blocks - 1) - + [LinearXav(3 * hidden_size)], - name="interaction_block", - ) + self.hidden_size = hidden_size + self.eps = eps + self.aggregate_fn = aggregate_fn + self.activation = activation - # intra-particle context net - self.mixing_block = hk.Sequential( - [LinearXav(hidden_size), activation] * (blocks - 1) - + [LinearXav(3 * hidden_size)], - name="mixing_block", - ) + k_i, k_m, k_v = jax.random.split(key, 3) + k_i_layers = jax.random.split(k_i, blocks) + k_m_layers = jax.random.split(k_m, blocks) + + # inter-particle context net (interaction_block) + ib_layers = [] + for i in range(blocks - 1): + ib_layers.append(LinearXav(hidden_size, hidden_size, key=k_i_layers[i])) + ib_layers.append(LinearXav(hidden_size, 3 * hidden_size, key=k_i_layers[-1])) + self.interaction_block_lin = ib_layers + + # intra-particle context net (mixing_block) + mb_layers = [] + # mixing input: [s (hidden), v_norm (hidden)] concat = 2 * hidden + mb_layers.append(LinearXav(2 * hidden_size, hidden_size, key=k_m_layers[0])) + for i in range(1, blocks - 1): + mb_layers.append(LinearXav(hidden_size, hidden_size, key=k_m_layers[i])) + mb_layers.append(LinearXav(hidden_size, 3 * hidden_size, key=k_m_layers[-1])) + self.mixing_block_lin = mb_layers # vector channel mix self.vector_mixing_block = LinearXav( - 2 * hidden_size, - with_bias=False, - name="vector_mixing_block", + hidden_size, 2 * hidden_size, with_bias=False, key=k_v ) + def _interaction(self, s: jnp.ndarray) -> jnp.ndarray: + x = s + for i, layer in enumerate(self.interaction_block_lin): + x = layer(x) + if i < len(self.interaction_block_lin) - 1: + x = self.activation(x) + return x + + def _mixing(self, ts: jnp.ndarray) -> jnp.ndarray: + x = ts + for i, layer in enumerate(self.mixing_block_lin): + x = layer(x) + if i < len(self.mixing_block_lin) - 1: + x = self.activation(x) + return x + def _message( self, s: jnp.ndarray, @@ -291,21 +401,18 @@ def _message( Returns: Aggregated messages after interaction. """ - x = self.interaction_block(s) - + x = self._interaction(s) xj = x[receivers] vj = v[receivers] ds, dv1, dv2 = jnp.split(Wij * xj, 3, axis=-1) - n_nodes = tree.tree_leaves(s)[0].shape[0] + n_nodes = s.shape[0] dv = dv1 * dir_ij[..., jnp.newaxis] + dv2 * vj - # aggregate scalars and vectors - ds = self._aggregate_fn(ds, senders, n_nodes) - dv = self._aggregate_fn(dv, senders, n_nodes) + ds = self.aggregate_fn(ds, senders, n_nodes) + dv = self.aggregate_fn(dv, senders, n_nodes) s = s + jnp.clip(ds, -1e2, 1e2) v = v + jnp.clip(dv, -1e2, 1e2) - return s, v def _update( @@ -321,10 +428,10 @@ def _update( Node features after update. """ v_l, v_r = jnp.split(self.vector_mixing_block(v), 2, axis=-1) - v_norm = jnp.sqrt(jnp.sum(v_r**2, axis=-2, keepdims=True) + self._eps) + v_norm = jnp.sqrt(jnp.sum(v_r**2, axis=-2, keepdims=True) + self.eps) ts = jnp.concatenate([s, v_norm], axis=-1) - ds, dv, dsv = jnp.split(self.mixing_block(ts), 3, axis=-1) + ds, dv, dsv = jnp.split(self._mixing(ts), 3, axis=-1) dv = v_l * dv dsv = dsv * jnp.sum(v_r * v_l, axis=1, keepdims=True) @@ -352,7 +459,7 @@ def __call__( return graph._replace(nodes=NodeFeatures(s=s, v=v)) -class PaiNN(hk.Module): +class PaiNN(BaseModel): r"""Polarizable interaction Neural Network by `Schütt et al. `_. @@ -364,19 +471,41 @@ class PaiNN(hk.Module): """ + scalar_emb: LinearXav + vector_emb: LinearXav + filter_net: LinearXav + layers: List[PaiNNLayer] + readout: _PaiNNReadout + radial_basis_fn: eqx.Module + cutoff_fn: Optional[eqx.Module] + + hidden_size: int = eqx.field(static=True) + output_size: int = eqx.field(static=True) + num_mp_steps: int = eqx.field(static=True) + n_vels: int = eqx.field(static=True) + homogeneous_particles: bool = eqx.field(static=True) + shared_filters: bool = eqx.field(static=True) + shared_interactions: bool = eqx.field(static=True) + eps: float = eqx.field(static=True) + def __init__( self, hidden_size: int, output_size: int, num_mp_steps: int, - radial_basis_fn: Callable, - cutoff_fn: Callable, + radial_basis_fn: eqx.Module, + cutoff_fn: Optional[eqx.Module], n_vels: int, + n_rbf: int, + s_in_size: int, + v_in_channels: int, homogeneous_particles: bool = True, activation: Callable = jax.nn.silu, shared_interactions: bool = False, shared_filters: bool = False, eps: float = 1e-8, + *, + key, ): """Initialize the model. @@ -387,66 +516,71 @@ def __init__( radial_basis_fn: Expands inter-particle distances in a basis set. cutoff_fn: Cutoff function. n_vels: Number of historical velocities. + n_rbf: Number of radial basis functions (input size of ``filter_net``). + s_in_size: Size of the input scalar features (vel_mag + optional + particle-type one-hot). + v_in_channels: Number of input vector channels (trajectory + force + + bound depending on the dataset). homogeneous_particles: If all particles are of homogeneous type. activation: Activation function. shared_interactions: If True, share the weights across interaction blocks. shared_filters: If True, share the weights across filter networks. eps: Constant added in norm to prevent derivation instabilities. + key: PRNG key for initialization. """ - super().__init__("painn") - assert radial_basis_fn is not None, "A radial_basis_fn must be provided" - self._n_vels = n_vels - self._homogeneous_particles = homogeneous_particles - self._hidden_size = hidden_size - self._num_mp_steps = num_mp_steps - self._eps = eps - self._shared_filters = shared_filters - self._shared_interactions = shared_interactions - + self.hidden_size = hidden_size + self.output_size = output_size + self.num_mp_steps = num_mp_steps + self.n_vels = n_vels + self.homogeneous_particles = homogeneous_particles + self.shared_filters = shared_filters + self.shared_interactions = shared_interactions + self.eps = eps self.radial_basis_fn = radial_basis_fn self.cutoff_fn = cutoff_fn - self.scalar_emb = LinearXav(self._hidden_size, name="scalar_embedding") - # mix vector channels (only used if vector features are present in input) + k_s, k_v, k_f, k_l, k_r = jax.random.split(key, 5) + self.scalar_emb = LinearXav(s_in_size, hidden_size, key=k_s) self.vector_emb = LinearXav( - self._hidden_size, with_bias=False, name="vector_embedding" + v_in_channels, hidden_size, with_bias=False, key=k_v ) + filter_out = ( + 3 * hidden_size if shared_filters else num_mp_steps * 3 * hidden_size + ) + self.filter_net = LinearXav(n_rbf, filter_out, key=k_f) - if shared_filters: - self.filter_net = LinearXav(3 * self._hidden_size, name="filter_net") - else: - self.filter_net = LinearXav( - self._num_mp_steps * 3 * self._hidden_size, name="filter_net" - ) - - if self._shared_interactions: - self.layers = [ - PaiNNLayer(self._hidden_size, 0, activation, eps=eps) - ] * self._num_mp_steps + keys_layers = jax.random.split(k_l, num_mp_steps) + if shared_interactions: + layer0 = PaiNNLayer(hidden_size, activation=activation, eps=eps, key=keys_layers[0]) + self.layers = [layer0] * num_mp_steps else: self.layers = [ - PaiNNLayer(self._hidden_size, i, activation, eps=eps) - for i in range(self._num_mp_steps) + PaiNNLayer(hidden_size, activation=activation, eps=eps, key=keys_layers[i]) + for i in range(num_mp_steps) ] - self._readout = PaiNNReadout(self._hidden_size, out_channels=output_size) + self.readout = _PaiNNReadout( + hidden_size=hidden_size, + in_channels=hidden_size, + out_channels=output_size, + activation=activation, + blocks=2, + eps=eps, + key=k_r, + ) - def _embed(self, graph: jraph.GraphsTuple) -> Tuple[jnp.ndarray, jnp.ndarray]: + def _embed(self, graph: jraph.GraphsTuple) -> jraph.GraphsTuple: """Embed the input nodes.""" - # embeds scalar features s = jnp.asarray(graph.nodes.s, dtype=jnp.float32) if len(s.shape) == 1: s = s[:, jnp.newaxis] s = self.scalar_emb(s)[:, jnp.newaxis] - - # embeds vector features v = self.vector_emb(graph.nodes.v) - return graph._replace(nodes=NodeFeatures(s=s, v=v)) - def _get_filters(self, norm_ij: jnp.ndarray) -> jnp.ndarray: + def _get_filters(self, norm_ij: jnp.ndarray) -> List[jnp.ndarray]: r"""Compute the rotationally invariant filters :math:`W_s`. .. math:: @@ -455,13 +589,11 @@ def _get_filters(self, norm_ij: jnp.ndarray) -> jnp.ndarray: phi_ij = self.radial_basis_fn(norm_ij) if self.cutoff_fn is not None: norm_ij = self.cutoff_fn(norm_ij) - # compute filters filters = self.filter_net(phi_ij) * norm_ij[:, jnp.newaxis] - # split into layer-wise filters - if self._shared_filters: - filter_list = [filters] * self._num_mp_steps + if self.shared_filters: + filter_list = [filters] * self.num_mp_steps else: - filter_list = jnp.split(filters, self._num_mp_steps, axis=-1) + filter_list = jnp.split(filters, self.num_mp_steps, axis=-1) return filter_list def _transform( @@ -469,19 +601,17 @@ def _transform( ) -> jraph.GraphsTuple: n_nodes = particle_type.shape[0] - # node features node_scalars = [] node_vectors = [] - traj = jnp.reshape(features["vel_hist"], (n_nodes, self._n_vels, -1)) + traj = jnp.reshape(features["vel_hist"], (n_nodes, self.n_vels, -1)) node_vectors.append(traj.transpose(0, 2, 1)) if "force" in features: node_vectors.append(features["force"][..., jnp.newaxis]) if "bound" in features: bounds = jnp.reshape(features["bound"], (n_nodes, 2, -1)) node_vectors.append(bounds.transpose(0, 2, 1)) - # velocity magnitudes as node feature node_scalars.append(features["vel_mag"]) - if not self._homogeneous_particles: + if not self.homogeneous_particles: particles = jax.nn.one_hot(particle_type, NodeType.SIZE) node_scalars.append(particles) @@ -503,20 +633,16 @@ def __call__( ) -> Dict[str, jnp.ndarray]: graph = self._transform(*sample) # compute atom and pair features - norm_ij = jnp.sqrt(jnp.sum(graph.edges**2, axis=1, keepdims=True) + self._eps) + norm_ij = jnp.sqrt(jnp.sum(graph.edges**2, axis=1, keepdims=True) + self.eps) # edge directions - dir_ij = graph.edges / (norm_ij + self._eps) + dir_ij = graph.edges / (norm_ij + self.eps) graph = graph._replace(edges=dir_ij) - # compute filters (r_ij track in message block from the paper) filter_list = self._get_filters(norm_ij) - # embeds node scalar features (and vector, if present) graph = self._embed(graph) - - # message passing for n, layer in enumerate(self.layers): graph = layer(graph, filter_list[n]) - _, v = self._readout(graph) + _, v = self.readout(graph) return {"acc": v} diff --git a/lagrangebench/models/segnn.py b/lagrangebench/models/segnn.py index 5f3ee66..0f75d9f 100644 --- a/lagrangebench/models/segnn.py +++ b/lagrangebench/models/segnn.py @@ -11,15 +11,16 @@ import warnings from math import prod -from typing import Any, Callable, Dict, Optional, Tuple, Union +from typing import Any, Callable, Dict, List, Optional, Tuple, Union import e3nn_jax as e3nn -import haiku as hk +import equinox as eqx import jax import jax.numpy as jnp import jraph from e3nn_jax import Irreps, IrrepsArray -from jax.tree_util import Partial, tree_map +from e3nn_jax._src.linear import FunctionalLinear +from jax.tree_util import tree_map from lagrangebench.utils import NodeType @@ -27,21 +28,21 @@ from .utils import SteerableGraphsTuple, features_2d_to_3d -def uniform_init( - name: str, - path_shape: Tuple[int, ...], - weight_std: float, - dtype: jnp.dtype = jnp.float32, -) -> jnp.ndarray: - return hk.get_parameter( - name, - shape=path_shape, - dtype=dtype, - init=hk.initializers.RandomUniform(minval=-weight_std, maxval=weight_std), - ) +def _resolve_irreps_io( + irreps_in: Irreps, irreps_out: Irreps, simplify: bool = True +) -> Tuple[Irreps, Irreps]: + """Match ``e3nn.flax.Linear``'s irreps-in/out simplification and filtering.""" + irreps_in = Irreps(irreps_in) + irreps_out = Irreps(irreps_out) + if simplify: + irreps_in = irreps_in.simplify() + irreps_out = irreps_out.simplify() + # only keep outputs that are reachable from the input + irreps_out = irreps_out.filter(keep=irreps_in) + return irreps_in, irreps_out -class O3TensorProduct(hk.Module): +class O3TensorProduct(eqx.Module): r""" O(3) equivariant linear parametrized tensor product layer. @@ -59,60 +60,82 @@ class O3TensorProduct(hk.Module): as fast as FullyConnectedTensorProduct. """ + weights: List[jnp.ndarray] + _linear: FunctionalLinear = eqx.field(static=True) + output_irreps: Irreps = eqx.field(static=True) + def __init__( self, - output_irreps: e3nn.Irreps, + x_irreps: Irreps, + output_irreps: Irreps, + y_irreps: Optional[Irreps] = None, *, biases: bool = True, - name: Optional[str] = None, - init_fn: Callable = uniform_init, + init_scale: float = 1.0, gradient_normalization: Union[str, float] = "element", path_normalization: Union[str, float] = "element", + key, ): """Initialize the tensor product. Args: + x_irreps: Irreps of the left input ``x``. output_irreps: Output representation + y_irreps: Irreps of the right input ``y``. Defaults to the scalar + ``1x0e`` placeholder used when ``y`` is omitted at call time. biases: If set ot true will add biases - name: Name of the linear layer params - init_fn: Weight initialization function. Default is uniform. + init_scale: Multiplicative factor applied to the per-path weight std + when sampling the weights with a uniform distribution in + ``[-weight_std, weight_std]``. gradient_normalization: Gradient normalization method. Default is "element". NOTE: gradient_normalization="element" is the default in torch and haiku. path_normalization: Path normalization method. Default is "element" + key: PRNG key for initialization. """ - super().__init__(name=name) - - if not isinstance(output_irreps, e3nn.Irreps): - output_irreps = e3nn.Irreps(output_irreps) - self.output_irreps = output_irreps + x_irreps = Irreps(x_irreps) + if y_irreps is None: + y_irreps = Irreps("1x0e") + else: + y_irreps = Irreps(y_irreps) + output_irreps = Irreps(output_irreps) - self._linear = e3nn.haiku.Linear( - self.output_irreps, - get_parameter=init_fn, - biases=(biases and "0e" in self.output_irreps), - gradient_normalization=gradient_normalization, - path_normalization=path_normalization, - ) + # the actual input to the internal linear layer is tensor_product(x, y) + tp_irreps = e3nn.tensor_product(x_irreps, y_irreps) - def _check_input( - self, x: e3nn.IrrepsArray, y: Optional[e3nn.IrrepsArray] = None - ) -> Tuple[e3nn.IrrepsArray, e3nn.IrrepsArray]: - if not y: - y = e3nn.IrrepsArray("1x0e", jnp.ones((1, 1), dtype=x.dtype)) - if x.irreps.lmax == 0 and y.irreps.lmax == 0 and self.output_irreps.lmax > 0: + # check for unreachable outputs (warn as original code did) + if x_irreps.lmax == 0 and y_irreps.lmax == 0 and output_irreps.lmax > 0: warnings.warn( - f"The specified output irreps ({self.output_irreps}) are not scalars " + f"The specified output irreps ({output_irreps}) are not scalars " "but both operands are. This can have undesired behaviour (NaN). Try " "redistributing them into scalars or choose higher orders." ) - miss = self.output_irreps.filter(drop=e3nn.tensor_product(x.irreps, y.irreps)) + miss = output_irreps.filter(drop=tp_irreps) if len(miss) > 0: warnings.warn(f"Output irreps: '{miss}' are unreachable and were ignored.") - return x, y + + lin_in_irreps, lin_out_irreps = _resolve_irreps_io(tp_irreps, output_irreps) + linear = FunctionalLinear( + lin_in_irreps, + lin_out_irreps, + biases=(biases and "0e" in lin_out_irreps), + gradient_normalization=gradient_normalization, + path_normalization=path_normalization, + ) + self._linear = linear + self.output_irreps = lin_out_irreps + + keys = jax.random.split(key, max(len(linear.instructions), 1)) + weights = [] + for k, ins in zip(keys, linear.instructions): + std = ins.weight_std * init_scale + weights.append( + jax.random.uniform(k, ins.path_shape, minval=-std, maxval=std) + ) + self.weights = weights def __call__( - self, x: e3nn.IrrepsArray, y: Optional[e3nn.IrrepsArray] = None - ) -> e3nn.IrrepsArray: + self, x: IrrepsArray, y: Optional[IrrepsArray] = None + ) -> IrrepsArray: """Applies an O(3) equivariant linear parametrized tensor product layer. Args: @@ -122,21 +145,18 @@ def __call__( Returns: The output to the weighted tensor product (IrrepsArray). """ - x, y = self._check_input(x, y) - # tensor product + linear - tp = self._linear(e3nn.tensor_product(x, y)) - return tp - - -def O3TensorProductGate( - output_irreps: e3nn.Irreps, - *, - biases: bool = True, - scalar_activation: Optional[Callable] = None, - gate_activation: Optional[Callable] = None, - name: Optional[str] = None, - init_fn: Optional[Callable] = None, -) -> Callable: + if y is None: + y = e3nn.IrrepsArray("1x0e", jnp.ones((1, 1), dtype=x.dtype)) + tp = e3nn.tensor_product(x, y) + tp = tp.remove_zero_chunks().regroup() + + f = lambda arr: self._linear(self.weights, arr) + for _ in range(tp.ndim - 1): + f = e3nn.utils.vmap(f) + return f(tp) + + +class O3TensorProductGate(eqx.Module): r"""Non-linear (gated) O(3) equivariant linear tensor product layer. It applies a linear tensor product of representations to the input(s) and then @@ -145,43 +165,62 @@ def O3TensorProductGate( The input representation is lifted to have gating scalars. Args: + x_irreps: Irreps of the left input ``x``. output_irreps: Output representation + y_irreps: Irreps of the right input ``y``. biases: Add biases scalar_activation: Activation function for scalars gate_activation: Activation function for higher order - name: Name of the linear layer params - - Returns: - Function that applies the gated tensor product layer. + key: PRNG key for initialization. """ - if not isinstance(output_irreps, e3nn.Irreps): - output_irreps = e3nn.Irreps(output_irreps) - - # lift output with gating scalars - gate_irreps = e3nn.Irreps( - f"{output_irreps.num_irreps - output_irreps.count('0e')}x0e" - ) - tensor_product = O3TensorProduct( - (gate_irreps + output_irreps).regroup(), - biases=biases, - name=name, - init_fn=init_fn, - ) - if not scalar_activation: - scalar_activation = jax.nn.silu - if not gate_activation: - gate_activation = jax.nn.sigmoid - - def _gated_tensor_product( - x: e3nn.IrrepsArray, y: Optional[e3nn.IrrepsArray] = None, **kwargs - ) -> e3nn.IrrepsArray: - tp = tensor_product(x, y, **kwargs) - return e3nn.gate(tp, even_act=scalar_activation, odd_gate_act=gate_activation) - - return _gated_tensor_product - - -def O3Embedding(embed_irreps: Irreps, embed_edges: bool = True) -> Callable: + + tp: O3TensorProduct + scalar_activation: Callable = eqx.field(static=True) + gate_activation: Callable = eqx.field(static=True) + + def __init__( + self, + x_irreps: Irreps, + output_irreps: Irreps, + y_irreps: Optional[Irreps] = None, + *, + biases: bool = True, + scalar_activation: Optional[Callable] = None, + gate_activation: Optional[Callable] = None, + init_scale: float = 1.0, + key, + ): + output_irreps = Irreps(output_irreps) + gate_irreps = Irreps( + f"{output_irreps.num_irreps - output_irreps.count('0e')}x0e" + ) + self.tp = O3TensorProduct( + x_irreps, + (gate_irreps + output_irreps).regroup(), + y_irreps=y_irreps, + biases=biases, + init_scale=init_scale, + key=key, + ) + if scalar_activation is None: + scalar_activation = jax.nn.silu + if gate_activation is None: + gate_activation = jax.nn.sigmoid + self.scalar_activation = scalar_activation + self.gate_activation = gate_activation + + def __call__( + self, x: IrrepsArray, y: Optional[IrrepsArray] = None + ) -> IrrepsArray: + tp = self.tp(x, y) + return e3nn.gate( + tp, + even_act=self.scalar_activation, + odd_gate_act=self.gate_activation, + ) + + +class O3Embedding(eqx.Module): """Linear steerable embedding. Embeds the graph nodes in the representation space :param embed_irreps:. @@ -189,154 +228,207 @@ def O3Embedding(embed_irreps: Irreps, embed_edges: bool = True) -> Callable: Args: embed_irreps: Output representation embed_edges: If true also embed edges/message passing features - - Returns: - Function to embed graph nodes (and optionally edges) """ - def _embedding( - st_graph: SteerableGraphsTuple, - ) -> SteerableGraphsTuple: - graph = st_graph.graph - nodes = O3TensorProduct( - embed_irreps, - name="embedding_nodes", - )(graph.nodes, st_graph.node_attributes) - st_graph = st_graph._replace(graph=graph._replace(nodes=nodes)) + node_embed: O3TensorProduct + msg_embed: Optional[O3TensorProduct] + embed_edges: bool = eqx.field(static=True) - # NOTE edge embedding is not in the original paper but can get good results + def __init__( + self, + node_in_irreps: Irreps, + node_attr_irreps: Irreps, + edge_in_irreps: Optional[Irreps], + edge_attr_irreps: Optional[Irreps], + embed_irreps: Irreps, + embed_edges: bool = True, + *, + key, + ): + self.embed_edges = embed_edges + k_nodes, k_msg = jax.random.split(key) + self.node_embed = O3TensorProduct( + node_in_irreps, + embed_irreps, + y_irreps=node_attr_irreps, + key=k_nodes, + ) if embed_edges: - additional_message_features = O3TensorProduct( - embed_irreps, name="embedding_msg_features" - ) - (st_graph.additional_message_features, st_graph.edge_attributes) - st_graph = st_graph._replace( - additional_message_features=additional_message_features + assert edge_in_irreps is not None and edge_attr_irreps is not None + self.msg_embed = O3TensorProduct( + edge_in_irreps, + embed_irreps, + y_irreps=edge_attr_irreps, + key=k_msg, ) + else: + self.msg_embed = None + def __call__(self, st_graph: SteerableGraphsTuple) -> SteerableGraphsTuple: + graph = st_graph.graph + nodes = self.node_embed(graph.nodes, st_graph.node_attributes) + st_graph = st_graph._replace(graph=graph._replace(nodes=nodes)) + if self.embed_edges: + amf = self.msg_embed( + st_graph.additional_message_features, st_graph.edge_attributes + ) + st_graph = st_graph._replace(additional_message_features=amf) return st_graph - return _embedding - -def O3Decoder( - latent_irreps: Irreps, - output_irreps: Irreps, - n_blocks: int = 1, -): +class O3Decoder(eqx.Module): """Steerable decoder. Args: latent_irreps: Representation from the previous block output_irreps: Output representation n_blocks: Number of tensor product blocks in the decoder - - Returns: - Decoded latent feature space to output space. """ - def _decoder(st_graph: SteerableGraphsTuple): - nodes = st_graph.graph.nodes - for i in range(n_blocks): - nodes = O3TensorProductGate(latent_irreps, name=f"readout_{i}")( - nodes, st_graph.node_attributes - ) + readout_blocks: List[O3TensorProductGate] + output_block: O3TensorProduct - return O3TensorProduct(output_irreps, name="output")( - nodes, st_graph.node_attributes + def __init__( + self, + latent_irreps: Irreps, + node_attr_irreps: Irreps, + output_irreps: Irreps, + n_blocks: int = 1, + *, + key, + ): + keys = jax.random.split(key, n_blocks + 1) + self.readout_blocks = [ + O3TensorProductGate( + latent_irreps, + latent_irreps, + y_irreps=node_attr_irreps, + key=keys[i], + ) + for i in range(n_blocks) + ] + self.output_block = O3TensorProduct( + latent_irreps, + output_irreps, + y_irreps=node_attr_irreps, + key=keys[-1], ) - return _decoder + def __call__(self, st_graph: SteerableGraphsTuple) -> IrrepsArray: + nodes = st_graph.graph.nodes + for block in self.readout_blocks: + nodes = block(nodes, st_graph.node_attributes) + return self.output_block(nodes, st_graph.node_attributes) -class SEGNNLayer(hk.Module): +class SEGNNLayer(eqx.Module): """ Steerable E(3) equivariant layer. Applies a message passing step (GN) with equivariant message and update functions. """ + msg_blocks: List[O3TensorProductGate] + update_gate_blocks: List[O3TensorProductGate] + update_final: O3TensorProduct + output_irreps: Irreps = eqx.field(static=True) + n_blocks: int = eqx.field(static=True) + aggregate_fn: Callable = eqx.field(static=True) + def __init__( self, + msg_in_irreps: Irreps, + node_in_irreps: Irreps, + edge_attr_irreps: Irreps, + node_attr_irreps: Irreps, output_irreps: Irreps, - layer_idx: int, n_blocks: int = 2, - norm: Optional[str] = None, - aggregate_fn: Optional[Callable] = jraph.segment_sum, + aggregate_fn: Callable = jraph.segment_sum, + *, + key, ): """ Initialize the layer. Args: + msg_in_irreps: Irreps of the pre-message concatenation + ``[h_i, h_j, additional_message_features]`` at the layer input. + node_in_irreps: Irreps of the node features at the layer input. + edge_attr_irreps: Irreps of the steerable edge attributes. + node_attr_irreps: Irreps of the steerable node attributes. output_irreps: Layer output representation - layer_idx: Numbering of the layer n_blocks: Number of tensor product n_blocks in the layer - norm: Normalization type. Either be None, 'instance' or 'batch' aggregate_fn: Message aggregation function. Defaults to sum. + key: PRNG key for initialization. """ - super().__init__(f"layer_{layer_idx}") - assert norm in ["batch", "instance", "none", None], f"Unknown norm '{norm}'" - self._output_irreps = output_irreps - self._n_blocks = n_blocks - self._norm = norm - self._aggregate_fn = aggregate_fn + self.output_irreps = Irreps(output_irreps) + self.n_blocks = n_blocks + self.aggregate_fn = aggregate_fn + + output_irreps_clean = Irreps(output_irreps) + + key_m, key_u = jax.random.split(key, 2) + keys_m = jax.random.split(key_m, n_blocks) + self.msg_blocks = [] + cur_irreps = Irreps(msg_in_irreps) + for i in range(n_blocks): + block = O3TensorProductGate( + cur_irreps, + output_irreps_clean, + y_irreps=edge_attr_irreps, + key=keys_m[i], + ) + self.msg_blocks.append(block) + # after gate the irreps collapse back to ``output_irreps`` + cur_irreps = output_irreps_clean + + keys_u = jax.random.split(key_u, n_blocks) + self.update_gate_blocks = [] + # update takes [nodes, msg] concatenated -> node_in_irreps + output_irreps + cur_irreps = (Irreps(node_in_irreps) + output_irreps_clean).regroup() + for i in range(n_blocks - 1): + block = O3TensorProductGate( + cur_irreps, + output_irreps_clean, + y_irreps=node_attr_irreps, + key=keys_u[i], + ) + self.update_gate_blocks.append(block) + cur_irreps = output_irreps_clean + self.update_final = O3TensorProduct( + cur_irreps, + output_irreps_clean, + y_irreps=node_attr_irreps, + key=keys_u[-1], + ) def _message( self, - edge_attribute: IrrepsArray, - additional_message_features: IrrepsArray, - edge_features: Any, incoming: IrrepsArray, outgoing: IrrepsArray, - globals_: Any, + additional_message_features: Optional[IrrepsArray], + edge_attribute: IrrepsArray, ) -> IrrepsArray: """Steerable equivariant message function.""" - _ = globals_ - _ = edge_features - # create messages msg = e3nn.concatenate([incoming, outgoing], axis=-1) if additional_message_features is not None: msg = e3nn.concatenate([msg, additional_message_features], axis=-1) - # message mlp (phi_m in the paper) steered by edge attributeibutes - for i in range(self._n_blocks): - msg = O3TensorProductGate(self._output_irreps, name=f"tp_{i}")( - msg, edge_attribute - ) - # NOTE: original implementation only applied batch norm to messages - if self._norm == "batch": - msg = e3nn.haiku.BatchNorm(irreps=self._output_irreps)(msg) + for block in self.msg_blocks: + msg = block(msg, edge_attribute) return msg def _update( self, - node_attribute: IrrepsArray, nodes: IrrepsArray, - senders: Any, - msg: IrrepsArray, - globals_: Any, + msg_sum: IrrepsArray, + node_attribute: IrrepsArray, ) -> IrrepsArray: """Steerable equivariant update function.""" - _ = senders - _ = globals_ - x = e3nn.concatenate([nodes, msg], axis=-1) - # update mlp (phi_f in the paper) steered by node attributeibutes - for i in range(self._n_blocks - 1): - x = O3TensorProductGate(self._output_irreps, name=f"tp_{i}")( - x, node_attribute - ) - # last update layer without activation - update = O3TensorProduct(self._output_irreps, name=f"tp_{self._n_blocks - 1}")( - x, node_attribute - ) - # residual connection - nodes += update - # message norm - if self._norm in ["batch", "instance"]: - nodes = e3nn.haiku.BatchNorm( - irreps=self._output_irreps, - instance=(self._norm == "instance"), - )(nodes) - return nodes + x = e3nn.concatenate([nodes, msg_sum], axis=-1) + for block in self.update_gate_blocks: + x = block(x, node_attribute) + update = self.update_final(x, node_attribute) + return nodes + update def __call__(self, st_graph: SteerableGraphsTuple) -> SteerableGraphsTuple: """Perform a message passing step. @@ -347,23 +439,27 @@ def __call__(self, st_graph: SteerableGraphsTuple) -> SteerableGraphsTuple: Returns: The updated graph """ - # NOTE node_attributes, edge_attributes and additional_message_features - # are never updated within the message passing layers - return st_graph._replace( - graph=jraph.GraphNetwork( - update_node_fn=Partial(self._update, st_graph.node_attributes), - update_edge_fn=Partial( - self._message, - st_graph.edge_attributes, - st_graph.additional_message_features, - ), - aggregate_edges_for_nodes_fn=self._aggregate_fn, - )(st_graph.graph) + graph = st_graph.graph + n_nodes = graph.nodes.shape[0] + + incoming = tree_map(lambda a: a[graph.senders], graph.nodes) + outgoing = tree_map(lambda a: a[graph.receivers], graph.nodes) + msg = self._message( + incoming, + outgoing, + st_graph.additional_message_features, + st_graph.edge_attributes, + ) + # aggregate messages to receivers; apply segment_sum per chunk + msg_sum = tree_map( + lambda a: self.aggregate_fn(a, graph.receivers, n_nodes), msg ) + new_nodes = self._update(graph.nodes, msg_sum, st_graph.node_attributes) + return st_graph._replace(graph=graph._replace(nodes=new_nodes)) def weight_balanced_irreps( - scalar_units: int, irreps_right: Irreps, lmax: int = None + scalar_units: int, irreps_right: Irreps, lmax: Optional[int] = None ) -> Irreps: """ Determine left irreps so that the tensor product with irreps_right has at least @@ -441,6 +537,22 @@ class SEGNN(BaseModel): """ + embedding: O3Embedding + layers: List[SEGNNLayer] + decoder: O3Decoder + + attribute_irreps: Irreps = eqx.field(static=True) + hidden_irreps: Irreps = eqx.field(static=True) + output_irreps: Irreps = eqx.field(static=True) + node_features_irreps: Irreps = eqx.field(static=True) + edge_features_irreps: Irreps = eqx.field(static=True) + num_mp_steps: int = eqx.field(static=True) + blocks_per_step: int = eqx.field(static=True) + embed_msg_features: bool = eqx.field(static=True) + velocity_aggregate: str = eqx.field(static=True) + n_vels: int = eqx.field(static=True) + homogeneous_particles: bool = eqx.field(static=True) + def __init__( self, node_features_irreps: Irreps, @@ -453,9 +565,10 @@ def __init__( n_vels: int, velocity_aggregate: str = "avg", homogeneous_particles: bool = True, - norm: Optional[str] = None, blocks_per_step: int = 2, embed_msg_features: bool = False, + *, + key, ): """ Initialize the network. @@ -471,95 +584,116 @@ def __init__( n_vels: Number of velocities in the history. velocity_aggregate: Velocity sequence aggregation method. homogeneous_particles: If all particles are of homogeneous type. - norm: Normalization type. Either None, 'instance' or 'batch' blocks_per_step: Number of tensor product blocks in each message passing embed_msg_features: Set to true to also embed edges/message passing features + key: PRNG key for initialization. """ - super().__init__() - - # network - self._attribute_irreps = Irreps.spherical_harmonics(lmax_attributes) - self._hidden_irreps = weight_balanced_irreps( - scalar_units, self._attribute_irreps, lmax_hidden - ) - self._output_irreps = output_irreps - self._num_mp_steps = num_mp_steps - self._embed_msg_features = embed_msg_features - self._norm = norm - self._blocks_per_step = blocks_per_step - - self._embedding = O3Embedding( - self._hidden_irreps, - embed_edges=self._embed_msg_features, + self.attribute_irreps = Irreps.spherical_harmonics(lmax_attributes) + self.hidden_irreps = weight_balanced_irreps( + scalar_units, self.attribute_irreps, lmax_hidden ) + self.output_irreps = Irreps(output_irreps) + self.node_features_irreps = Irreps(node_features_irreps) + self.edge_features_irreps = Irreps(edge_features_irreps) + self.num_mp_steps = num_mp_steps + self.blocks_per_step = blocks_per_step + self.embed_msg_features = embed_msg_features - self._decoder = O3Decoder( - latent_irreps=self._hidden_irreps, - output_irreps=output_irreps, - n_blocks=self._blocks_per_step, - ) - - # transform assert velocity_aggregate in [ "avg", "last", ], "Invalid velocity aggregate. Must be one of 'avg', 'sum' or 'last'." - self._node_features_irreps = node_features_irreps - self._edge_features_irreps = edge_features_irreps - self._velocity_aggregate = velocity_aggregate - self._n_vels = n_vels - self._homogeneous_particles = homogeneous_particles + self.velocity_aggregate = velocity_aggregate + self.n_vels = n_vels + self.homogeneous_particles = homogeneous_particles + + keys = jax.random.split(key, 2 + num_mp_steps) + self.embedding = O3Embedding( + node_in_irreps=self.node_features_irreps, + node_attr_irreps=self.attribute_irreps, + edge_in_irreps=self.edge_features_irreps if embed_msg_features else None, + edge_attr_irreps=self.attribute_irreps if embed_msg_features else None, + embed_irreps=self.hidden_irreps, + embed_edges=embed_msg_features, + key=keys[0], + ) + + # message passing layers + self.layers = [] + # input to message block: [h_i, h_j, msg_features] concat + msg_feat_irreps = ( + self.hidden_irreps if embed_msg_features else self.edge_features_irreps + ) + msg_in_irreps = ( + (self.hidden_irreps + self.hidden_irreps + msg_feat_irreps) + .regroup() + ) + for i in range(num_mp_steps): + self.layers.append( + SEGNNLayer( + msg_in_irreps=msg_in_irreps, + node_in_irreps=self.hidden_irreps, + edge_attr_irreps=self.attribute_irreps, + node_attr_irreps=self.attribute_irreps, + output_irreps=self.hidden_irreps, + n_blocks=blocks_per_step, + key=keys[1 + i], + ) + ) + + self.decoder = O3Decoder( + latent_irreps=self.hidden_irreps, + node_attr_irreps=self.attribute_irreps, + output_irreps=self.output_irreps, + n_blocks=blocks_per_step, + key=keys[-1], + ) def _transform( self, features: Dict[str, jnp.ndarray], particle_type: jnp.ndarray ) -> Tuple[SteerableGraphsTuple, int]: """Convert physical features to SteerableGraphsTuple for segnn.""" - dim = features["vel_hist"].shape[1] // self._n_vels + dim = features["vel_hist"].shape[1] // self.n_vels assert ( dim == 3 or dim == 2 ), "The velocity history should be of shape (n_nodes, n_vels * 3)." n_nodes = features["vel_hist"].shape[0] - - features["vel_hist"] = features["vel_hist"].reshape(n_nodes, self._n_vels, dim) - + features["vel_hist"] = features["vel_hist"].reshape(n_nodes, self.n_vels, dim) if dim == 2: - # add zeros for z component for E(3) equivariance features = features_2d_to_3d(features) - if self._n_vels == 1: + if self.n_vels == 1: vel = jnp.squeeze(features["vel_hist"]) else: - if self._velocity_aggregate == "avg": + if self.velocity_aggregate == "avg": vel = jnp.mean(features["vel_hist"], 1) - if self._velocity_aggregate == "last": + if self.velocity_aggregate == "last": vel = features["vel_hist"][:, -1, :] rel_pos = features["rel_disp"] edge_attributes = e3nn.spherical_harmonics( - self._attribute_irreps, rel_pos, normalize=True, normalization="integral" + self.attribute_irreps, rel_pos, normalize=True, normalization="integral" ) vel_embedding = e3nn.spherical_harmonics( - self._attribute_irreps, vel, normalize=True, normalization="integral" + self.attribute_irreps, vel, normalize=True, normalization="integral" ) - # scatter edge attributes to nodes (density) scattered_edges = tree_map( lambda e: jraph.segment_mean(e, features["receivers"], n_nodes), edge_attributes, ) - # node attributes as velocities + edge "density". Scalar default to 1.0 node_attributes = e3nn.IrrepsArray( vel_embedding.irreps, (vel_embedding + scattered_edges).array.at[:, 0].set(1.0), ) - node_features = [features["vel_hist"].reshape(n_nodes, self._n_vels * 3)] + node_features = [features["vel_hist"].reshape(n_nodes, self.n_vels * 3)] node_features += [ features[k] for k in ["vel_mag", "bound", "force"] if k in features ] node_features = jnp.concatenate(node_features, axis=-1) - if not self._homogeneous_particles: + if not self.homogeneous_particles: particles = jax.nn.one_hot(particle_type, NodeType.SIZE) node_features = jnp.concatenate([node_features, particles], axis=-1) @@ -567,7 +701,7 @@ def _transform( edge_features = jnp.concatenate(edge_features, axis=-1) feature_graph = jraph.GraphsTuple( - nodes=IrrepsArray(self._node_features_irreps, node_features), + nodes=IrrepsArray(self.node_features_irreps, node_features), edges=None, senders=features["senders"], receivers=features["receivers"], @@ -580,10 +714,9 @@ def _transform( node_attributes=node_attributes, edge_attributes=edge_attributes, additional_message_features=IrrepsArray( - self._edge_features_irreps, edge_features + self.edge_features_irreps, edge_features ), ) - return st_graph, dim def _postprocess(self, nodes: IrrepsArray, dim: int) -> Dict[str, jnp.ndarray]: @@ -595,16 +728,9 @@ def _postprocess(self, nodes: IrrepsArray, dim: int) -> Dict[str, jnp.ndarray]: def __call__( self, sample: Tuple[Dict[str, jnp.ndarray], jnp.ndarray] ) -> Dict[str, jnp.ndarray]: - # feature transformation st_graph, dim = self._transform(*sample) - # node (and edge) embedding - st_graph = self._embedding(st_graph) - # message passing - for n in range(self._num_mp_steps): - st_graph = SEGNNLayer( - self._hidden_irreps, n, n_blocks=self._blocks_per_step, norm=self._norm - )(st_graph) - # readout - nodes = self._decoder(st_graph) - out = self._postprocess(nodes, dim) - return out + st_graph = self.embedding(st_graph) + for layer in self.layers: + st_graph = layer(st_graph) + nodes = self.decoder(st_graph) + return self._postprocess(nodes, dim) diff --git a/lagrangebench/models/utils.py b/lagrangebench/models/utils.py index f40f302..579cd33 100644 --- a/lagrangebench/models/utils.py +++ b/lagrangebench/models/utils.py @@ -1,56 +1,197 @@ -from typing import Callable, Dict, Iterable, NamedTuple, Optional +"""Model utilities and building blocks (Equinox-based).""" + +from typing import Callable, Dict, List, NamedTuple, Optional, Tuple import e3nn_jax as e3nn -import haiku as hk +import equinox as eqx import jax import jax.numpy as jnp import jraph +from jax import Array from lagrangebench.utils import NodeType -class LinearXav(hk.Linear): - """Linear layer with Xavier init. Avoid distracting 'w_init' everywhere.""" +def _xavier_uniform(key, shape: Tuple[int, ...], gain: float = 1.0) -> Array: + """Xavier/Glorot uniform over a (in, out) weight shape.""" + fan_in, fan_out = shape[-2], shape[-1] + limit = gain * jnp.sqrt(6.0 / (fan_in + fan_out)) + return jax.random.uniform(key, shape, minval=-limit, maxval=limit) + + +def _uniform_scaling(key, shape: Tuple[int, ...], scale: float = 1.0) -> Array: + """Uniform-scaling init: ``U(-limit, +limit)`` with + ``limit = scale * sqrt(3 / fan_in)``. Matches haiku's ``UniformScaling``. + """ + fan_in = shape[-2] + limit = scale * jnp.sqrt(3.0 / fan_in) + return jax.random.uniform(key, shape, minval=-limit, maxval=limit) + + +class LinearXav(eqx.Module): + """Linear layer with Xavier/Glorot-uniform init, applied on the last axis.""" + + weight: Array + bias: Optional[Array] def __init__( self, - output_size: int, + in_size: int, + out_size: int, with_bias: bool = True, - w_init: Optional[hk.initializers.Initializer] = None, - b_init: Optional[hk.initializers.Initializer] = None, - name: Optional[str] = None, + gain: float = 1.0, + init: str = "xavier_avg", + *, + key, ): - if w_init is None: - w_init = hk.initializers.VarianceScaling(1.0, "fan_avg", "uniform") - super().__init__(output_size, with_bias, w_init, b_init, name) + """Initialize. + + Args: + in_size: Input feature size (last-axis length). + out_size: Output feature size. + with_bias: Whether to include a bias term. + gain: Multiplicative factor on the initialization limit. + init: ``"xavier_avg"`` uses ``gain * sqrt(6/(fan_in+fan_out))``; + ``"uniform_scaling"`` uses ``gain * sqrt(3/fan_in)`` (equivalent + to haiku's ``hk.initializers.UniformScaling(gain)``). + key: PRNG key for initialization. + """ + if init == "xavier_avg": + self.weight = _xavier_uniform(key, (in_size, out_size), gain=gain) + elif init == "uniform_scaling": + self.weight = _uniform_scaling(key, (in_size, out_size), scale=gain) + else: + raise ValueError(f"Unknown init: {init}") + self.bias = jnp.zeros(out_size) if with_bias else None + + def __call__(self, x: Array) -> Array: + y = x @ self.weight + if self.bias is not None: + y = y + self.bias + return y + + +class LayerNorm(eqx.Module): + """LayerNorm on the last axis.""" + + scale: Array + offset: Array + eps: float = eqx.field(static=True) + + def __init__(self, size: int, eps: float = 1e-5): + self.scale = jnp.ones(size) + self.offset = jnp.zeros(size) + self.eps = eps + + def __call__(self, x: Array) -> Array: + mean = x.mean(-1, keepdims=True) + var = x.var(-1, keepdims=True) + return (x - mean) * jax.lax.rsqrt(var + self.eps) * self.scale + self.offset -class MLPXav(hk.nets.MLP): - """MLP layer with Xavier init. Avoid distracting 'w_init' everywhere.""" +class MLPXav(eqx.Module): + """MLP with Xavier-init dense layers. Analogue of haiku's hk.nets.MLP.""" + + layers: List[LinearXav] + activation: Callable = eqx.field(static=True) + activate_final: bool = eqx.field(static=True) def __init__( self, - output_sizes: Iterable[int], + input_size: int, + output_sizes: List[int], with_bias: bool = True, - w_init: Optional[hk.initializers.Initializer] = None, - b_init: Optional[hk.initializers.Initializer] = None, activation: Callable = jax.nn.silu, activate_final: bool = False, - name: Optional[str] = None, + *, + key, + ): + keys = jax.random.split(key, len(output_sizes)) + sizes = [input_size] + list(output_sizes) + self.layers = [ + LinearXav(sizes[i], sizes[i + 1], with_bias=with_bias, key=keys[i]) + for i in range(len(output_sizes)) + ] + self.activation = activation + self.activate_final = activate_final + + def __call__(self, x: Array) -> Array: + for i, layer in enumerate(self.layers): + x = layer(x) + if i < len(self.layers) - 1 or self.activate_final: + x = self.activation(x) + return x + + +class MLPBlock(eqx.Module): + """An MLP optionally followed by a LayerNorm. Replaces haiku's `build_mlp`.""" + + mlp: MLPXav + norm: Optional[LayerNorm] + + def __init__( + self, + input_size: int, + latent_size: int, + output_size: int, + num_hidden_layers: int, + is_layer_norm: bool = True, + activation: Callable = jax.nn.silu, + *, + key, ): - if w_init is None: - w_init = hk.initializers.VarianceScaling(1.0, "fan_avg", "uniform") - if not with_bias: - b_init = None - super().__init__( - output_sizes, - w_init, - b_init, - with_bias, - activation, - activate_final, - name, + assert num_hidden_layers >= 1 + output_sizes = [latent_size] * (num_hidden_layers - 1) + [output_size] + self.mlp = MLPXav( + input_size=input_size, + output_sizes=output_sizes, + activation=activation, + activate_final=False, + key=key, ) + self.norm = LayerNorm(output_size) if is_layer_norm else None + + def __call__(self, x: Array) -> Array: + x = self.mlp(x) + if self.norm is not None: + x = self.norm(x) + return x + + +def build_mlp( + input_size: int, + latent_size: int, + output_size: int, + num_hidden_layers: int, + is_layer_norm: bool = True, + activation: Callable = jax.nn.silu, + *, + key, +) -> MLPBlock: + """MLP factory: stacks ``num_hidden_layers`` dense layers + optional LayerNorm.""" + return MLPBlock( + input_size=input_size, + latent_size=latent_size, + output_size=output_size, + num_hidden_layers=num_hidden_layers, + is_layer_norm=is_layer_norm, + activation=activation, + key=key, + ) + + +class Embedding(eqx.Module): + """Categorical embedding table (jax-friendly replacement for hk.Embed).""" + + table: Array + + def __init__(self, vocab_size: int, embed_size: int, *, key): + self.table = jax.random.truncated_normal( + key, lower=-2.0, upper=2.0, shape=(vocab_size, embed_size) + ) * 0.02 + + def __call__(self, ids: Array) -> Array: + return self.table[ids] class SteerableGraphsTuple(NamedTuple): @@ -84,37 +225,15 @@ def node_irreps( irreps.append(f"{input_seq_length - 1}x1o") if not any(metadata["periodic_boundary_conditions"]): irreps.append("2x1o") - if has_external_force: irreps.append("1x1o") - if has_magnitudes: irreps.append(f"{input_seq_length - 1}x0e") - if not has_homogeneous_particles: irreps.append(f"{NodeType.SIZE}x0e") - return e3nn.Irreps("+".join(irreps)) -def build_mlp( - latent_size, output_size, num_hidden_layers, is_layer_norm=True, **kwds: Dict -): - """MLP generation helper using Haiku.""" - assert num_hidden_layers >= 1 - network = hk.nets.MLP( - [latent_size] * (num_hidden_layers - 1) + [output_size], - **kwds, - activate_final=False, - name="MLP", - ) - if is_layer_norm: - l_norm = hk.LayerNorm(axis=-1, create_scale=True, create_offset=True) - return hk.Sequential([network, l_norm]) - else: - return network - - def features_2d_to_3d(features): """Add zeros in the z component of 2D features.""" n_nodes = features["vel_hist"].shape[0] @@ -134,5 +253,28 @@ def features_2d_to_3d(features): features["force"] = jnp.concatenate( [features["force"], jnp.zeros((n_nodes, 1))], -1 ) - return features + + +def node_feature_size( + metadata: Dict, + input_seq_length: int, + has_external_force: bool, + has_magnitudes: bool, + has_homogeneous_particles: bool, +) -> int: + """Scalar-feature size for GNS (sum of vel_hist, vel_mag, bound, force, embed).""" + dim = metadata["dim"] + size = (input_seq_length - 1) * dim # vel_hist + if has_magnitudes: + size += input_seq_length - 1 # vel_mag + if not any(metadata["periodic_boundary_conditions"]): + size += 2 * dim # bound (distance to lower + upper) + if has_external_force: + size += dim # force + return size + + +def edge_feature_size(metadata: Dict) -> int: + """Edge feature size (rel_disp + rel_dist).""" + return metadata["dim"] + 1 diff --git a/lagrangebench/runner.py b/lagrangebench/runner.py index 26e1ecc..7b7759d 100644 --- a/lagrangebench/runner.py +++ b/lagrangebench/runner.py @@ -1,16 +1,15 @@ import os import os.path as osp from datetime import datetime -from typing import Callable, Dict, Optional, Tuple, Type, Union +from typing import Dict, Optional, Tuple, Union -import haiku as hk +import equinox as eqx import jax import jax.numpy as jnp -import jmp import numpy as np from e3nn_jax import Irreps from jax import config -from jax_sph.jax_md import space +from jax_md import space from omegaconf import DictConfig, OmegaConf from lagrangebench import Trainer, infer, models @@ -18,7 +17,11 @@ from lagrangebench.data import H5Dataset from lagrangebench.defaults import check_cfg from lagrangebench.evaluate import averaged_metrics -from lagrangebench.models.utils import node_irreps +from lagrangebench.models.utils import ( + edge_feature_size, + node_feature_size, + node_irreps, +) from lagrangebench.utils import NodeType @@ -57,19 +60,13 @@ def train_or_infer(cfg: Union[Dict, DictConfig]): _, particle_type = data_train[0] # setup model from configs - model, MODEL = setup_model( + model = setup_model( cfg, metadata=metadata, homogeneous_particles=particle_type.max() == particle_type.min(), has_external_force=data_train.external_force_fn is not None, normalization_stats=case.normalization_stats, ) - model = hk.without_apply_rng(hk.transform_with_state(model)) - - # mixed precision training based on this reference: - # https://github.com/deepmind/dm-haiku/blob/main/examples/imagenet/train.py - policy = jmp.get_policy("params=float32,compute=float32,output=float32") - hk.mixed_precision.set_policy(MODEL, policy) if mode == "train" or mode == "all": print("Start training...") @@ -102,7 +99,7 @@ def train_or_infer(cfg: Union[Dict, DictConfig]): seed=cfg.seed, ) - _, _, _ = trainer.train( + model, _ = trainer.train( step_max=cfg.train.step_max, load_ckp=load_ckp, store_ckp=store_ckp, @@ -116,7 +113,7 @@ def train_or_infer(cfg: Union[Dict, DictConfig]): model_dir = load_ckp if mode == "all": model_dir = os.path.join(store_ckp, "best") - assert osp.isfile(os.path.join(model_dir, "params_tree.pkl")) + assert osp.isfile(os.path.join(model_dir, "model.eqx")) cfg.eval.rollout_dir = model_dir.replace("ckp", "rollout") os.makedirs(cfg.eval.rollout_dir, exist_ok=True) @@ -195,27 +192,34 @@ def setup_model( homogeneous_particles: bool = False, has_external_force: bool = False, normalization_stats: Optional[Dict] = None, -) -> Tuple[Callable, Type]: +) -> eqx.Module: """Setup model based on cfg.""" model_name = cfg.model.name.lower() input_seq_length = cfg.model.input_seq_length magnitude_features = cfg.model.magnitude_features + key = jax.random.PRNGKey(cfg.seed) if model_name == "gns": - - def model_fn(x): - return models.GNS( - particle_dimension=metadata["dim"], - latent_size=cfg.model.latent_dim, - blocks_per_step=cfg.model.num_mlp_layers, - num_mp_steps=cfg.model.num_mp_steps, - num_particle_types=NodeType.SIZE, - particle_type_embedding_size=16, - )(x) - - MODEL = models.GNS + n_in = node_feature_size( + metadata, + input_seq_length, + has_external_force=has_external_force, + has_magnitudes=magnitude_features, + has_homogeneous_particles=homogeneous_particles, + ) + e_in = edge_feature_size(metadata) + model = models.GNS( + particle_dimension=metadata["dim"], + latent_size=cfg.model.latent_dim, + blocks_per_step=cfg.model.num_mlp_layers, + num_mp_steps=cfg.model.num_mp_steps, + num_particle_types=NodeType.SIZE, + particle_type_embedding_size=16, + node_in_size=n_in, + edge_in_size=e_in, + key=key, + ) elif model_name == "segnn": - # Hx1o vel, Hx0e vel, 2x1o boundary, 9x0e type node_feature_irreps = node_irreps( metadata, input_seq_length, @@ -223,70 +227,79 @@ def model_fn(x): magnitude_features, homogeneous_particles, ) - # 1o displacement, 0e distance edge_feature_irreps = Irreps("1x1o + 1x0e") - - def model_fn(x): - return models.SEGNN( - node_features_irreps=node_feature_irreps, - edge_features_irreps=edge_feature_irreps, - scalar_units=cfg.model.latent_dim, - lmax_hidden=cfg.model.lmax_hidden, - lmax_attributes=cfg.model.lmax_attributes, - output_irreps=Irreps("1x1o"), - num_mp_steps=cfg.model.num_mp_steps, - n_vels=cfg.model.input_seq_length - 1, - velocity_aggregate=cfg.model.velocity_aggregate, - homogeneous_particles=homogeneous_particles, - blocks_per_step=cfg.model.num_mlp_layers, - norm=cfg.model.segnn_norm, - )(x) - - MODEL = models.SEGNN + model = models.SEGNN( + node_features_irreps=node_feature_irreps, + edge_features_irreps=edge_feature_irreps, + scalar_units=cfg.model.latent_dim, + lmax_hidden=cfg.model.lmax_hidden, + lmax_attributes=cfg.model.lmax_attributes, + output_irreps=Irreps("1x1o"), + num_mp_steps=cfg.model.num_mp_steps, + n_vels=input_seq_length - 1, + velocity_aggregate=cfg.model.velocity_aggregate, + homogeneous_particles=homogeneous_particles, + blocks_per_step=cfg.model.num_mlp_layers, + key=key, + ) elif model_name == "egnn": box = cfg.box if jnp.array(metadata["periodic_boundary_conditions"]).any(): displacement_fn, shift_fn = space.periodic(jnp.array(box)) else: displacement_fn, shift_fn = space.free() - - displacement_fn = jax.vmap(displacement_fn, in_axes=(0, 0)) - shift_fn = jax.vmap(shift_fn, in_axes=(0, 0)) - - def model_fn(x): - return models.EGNN( - hidden_size=cfg.model.latent_dim, - output_size=1, - dt=metadata["dt"] * metadata["write_every"], - displacement_fn=displacement_fn, - shift_fn=shift_fn, - normalization_stats=normalization_stats, - num_mp_steps=cfg.model.num_mp_steps, - n_vels=input_seq_length - 1, - residual=True, - )(x) - - MODEL = models.EGNN + model = models.EGNN( + hidden_size=cfg.model.latent_dim, + output_size=1, + dt=metadata["dt"] * metadata["write_every"], + displacement_fn=displacement_fn, + shift_fn=shift_fn, + normalization_stats=normalization_stats, + num_mp_steps=cfg.model.num_mp_steps, + n_vels=input_seq_length - 1, + homogeneous_particles=homogeneous_particles, + residual=True, + key=key, + ) elif model_name == "painn": assert magnitude_features, "PaiNN requires magnitudes" radius = metadata["default_connectivity_radius"] * 1.5 - - def model_fn(x): - return models.PaiNN( - hidden_size=cfg.model.latent_dim, - output_size=1, - n_vels=input_seq_length - 1, - radial_basis_fn=models.painn.gaussian_rbf(20, radius, trainable=True), - cutoff_fn=models.painn.cosine_cutoff(radius), - num_mp_steps=cfg.model.num_mp_steps, - )(x) - - MODEL = models.PaiNN + n_rbf = 20 + # scalar input: vel_mag (input_seq_length - 1) + optional one-hot types + s_in_size = (input_seq_length - 1) + ( + 0 if homogeneous_particles else NodeType.SIZE + ) + # vector input: traj (n_vels) + optional force (1) + optional bound (2) + n_vels = input_seq_length - 1 + v_in_channels = n_vels + if has_external_force: + v_in_channels += 1 + if not any(metadata["periodic_boundary_conditions"]): + v_in_channels += 2 + model = models.PaiNN( + hidden_size=cfg.model.latent_dim, + output_size=1, + n_vels=n_vels, + radial_basis_fn=models.painn.gaussian_rbf(n_rbf, radius, trainable=True), + cutoff_fn=models.painn.cosine_cutoff(radius), + num_mp_steps=cfg.model.num_mp_steps, + n_rbf=n_rbf, + s_in_size=s_in_size, + v_in_channels=v_in_channels, + homogeneous_particles=homogeneous_particles, + key=key, + ) elif model_name == "linear": - - def model_fn(x): - return models.Linear(dim_out=metadata["dim"])(x) - - MODEL = models.Linear - - return model_fn, MODEL + # input = vel_hist + vel_mag + bound + force + particle_type[:, None] + n_in = node_feature_size( + metadata, + input_seq_length, + has_external_force=has_external_force, + has_magnitudes=magnitude_features, + has_homogeneous_particles=homogeneous_particles, + ) + 1 + model = models.Linear(in_features=n_in, dim_out=metadata["dim"], key=key) + else: + raise ValueError(f"Unknown model: {model_name}") + + return model diff --git a/lagrangebench/train/strats.py b/lagrangebench/train/strats.py index 24da4d2..0f5d3ec 100644 --- a/lagrangebench/train/strats.py +++ b/lagrangebench/train/strats.py @@ -109,7 +109,7 @@ def push_forward_sample_steps(key, step, pushforward): return key, unroll_steps -def push_forward_build(model_apply, case): +def push_forward_build(case): r"""Build the push forward function, introduced by `Brandstetter et al. `_. @@ -130,24 +130,23 @@ def push_forward_build(model_apply, case): running gradients through the last unroll step. Args: - model_apply: Model apply function case: Case setup function """ + import equinox as eqx # local import to avoid circular dep on package init - @jax.jit - def push_forward_fn(features, current_pos, particle_type, neighbors, params, state): + @eqx.filter_jit + def push_forward_fn(model, features, current_pos, particle_type, neighbors): """Push forward function. Args: + model: Equinox model. features: Input features current_pos: Current position particle_type: Particle type vector neighbors: Neighbor list - params: Model parameters - state: Model state """ # no buffer overflow check here, since push forward acts on later epochs - pred, _ = model_apply(params, state, (features, particle_type)) + pred = model((features, particle_type)) next_pos = case.integrate(pred, current_pos) current_pos = jnp.concatenate( [current_pos[:, 1:], next_pos[:, None, :]], axis=1 diff --git a/lagrangebench/train/trainer.py b/lagrangebench/train/trainer.py index de1fce5..b22c119 100644 --- a/lagrangebench/train/trainer.py +++ b/lagrangebench/train/trainer.py @@ -5,13 +5,11 @@ from functools import partial from typing import Callable, Dict, Optional, Tuple, Union -import haiku as hk +import equinox as eqx import jax import jax.numpy as jnp -import jraph import optax import wandb -from jax import vmap from omegaconf import DictConfig, OmegaConf from torch.utils.data import DataLoader @@ -24,69 +22,56 @@ broadcast_to_batch, get_kinematic_mask, get_num_params, - load_haiku, - save_haiku, + load_model, + save_model, set_seed, ) from .strats import push_forward_build, push_forward_sample_steps -@partial(jax.jit, static_argnames=["model_fn", "loss_weight"]) -def _mse( - params: hk.Params, - state: hk.State, +def _sample_loss( + model: eqx.Module, features: Dict[str, jnp.ndarray], particle_type: jnp.ndarray, - target: jnp.ndarray, - model_fn: Callable, - loss_weight: Dict[str, float], -): - pred, state = model_fn(params, state, (features, particle_type)) - # check active (non zero) output shapes + target: Dict[str, jnp.ndarray], + loss_weight, +) -> jnp.ndarray: + """Mean-squared error over active (non-kinematic) particles of a single sample.""" + pred = model((features, particle_type)) assert all(target[k].shape == pred[k].shape for k in pred) - # particle mask + non_kinematic_mask = jnp.logical_not(get_kinematic_mask(particle_type)) num_non_kinematic = non_kinematic_mask.sum() - # loss components + losses = [] for t in pred: w = getattr(loss_weight, t) losses.append((w * (pred[t] - target[t]) ** 2).sum(axis=-1)) total_loss = jnp.array(losses).sum(0) total_loss = jnp.where(non_kinematic_mask, total_loss, 0) - total_loss = total_loss.sum() / num_non_kinematic + return total_loss.sum() / num_non_kinematic - return total_loss, state +def _make_update_fn(opt_update, loss_weight): + """Build a jitted update step that does batched loss + grad + optim step.""" -@partial(jax.jit, static_argnames=["loss_fn", "opt_update"]) -def _update( - params: hk.Params, - state: hk.State, - features_batch: Tuple[jraph.GraphsTuple, ...], - target_batch: Tuple[jnp.ndarray, ...], - particle_type_batch: Tuple[jnp.ndarray, ...], - opt_state: optax.OptState, - loss_fn: Callable, - opt_update: Callable, -) -> Tuple[float, hk.Params, hk.State, optax.OptState]: - value_and_grad_vmap = vmap( - jax.value_and_grad(loss_fn, has_aux=True), in_axes=(None, None, 0, 0, 0) - ) - (loss, state), grads = value_and_grad_vmap( - params, state, features_batch, particle_type_batch, target_batch - ) - - # aggregate over the first (batch) dimension of each leave element - grads = jax.tree_map(lambda x: x.sum(axis=0), grads) - state = jax.tree_map(lambda x: x.sum(axis=0), state) - loss = jax.tree_map(lambda x: x.mean(axis=0), loss) + def batch_loss(model, features_batch, particle_type_batch, target_batch): + losses = jax.vmap(_sample_loss, in_axes=(None, 0, 0, 0, None))( + model, features_batch, particle_type_batch, target_batch, loss_weight + ) + return losses.mean() - updates, opt_state = opt_update(grads, opt_state, params) - new_params = optax.apply_updates(params, updates) + @eqx.filter_jit + def update_fn(model, features_batch, target_batch, particle_type_batch, opt_state): + loss, grads = eqx.filter_value_and_grad(batch_loss)( + model, features_batch, particle_type_batch, target_batch + ) + updates, opt_state = opt_update(grads, opt_state, model) + model = eqx.apply_updates(model, updates) + return loss, model, opt_state - return loss, new_params, state, opt_state + return update_fn class Trainer: @@ -103,7 +88,7 @@ class Trainer: def __init__( self, - model: hk.TransformedWithState, + model: eqx.Module, case, data_train: H5Dataset, data_valid: H5Dataset, @@ -116,7 +101,7 @@ def __init__( """Initializes the trainer. Args: - model: (Transformed) Haiku model. + model: Equinox model (already initialized with a PRNG key). case: Case setup class. data_train: Training dataset. data_valid: Validation dataset. @@ -126,7 +111,6 @@ def __init__( input_seq_length: Input sequence length, i.e. number of past positions. seed: Random seed for model init, training tricks and dataloading. """ - if isinstance(cfg_train, Dict): cfg_train = OmegaConf.create(cfg_train) if isinstance(cfg_eval, Dict): @@ -137,36 +121,32 @@ def __init__( self.model = model self.case = case self.input_seq_length = input_seq_length - # if one of the cfg_* arguments has a subset of the default configs, merge them self.cfg_train = OmegaConf.merge(defaults.train, cfg_train) self.cfg_eval = OmegaConf.merge(defaults.eval, cfg_eval) self.cfg_logging = OmegaConf.merge(defaults.logging, cfg_logging) - assert isinstance( - model, hk.TransformedWithState - ), "Model must be passed as an Haiku transformed function." + assert isinstance(model, eqx.Module), ( + "Model must be an equinox.Module instance." + ) available_rollout_length = data_valid.subseq_length - input_seq_length - assert cfg_eval.n_rollout_steps <= available_rollout_length, ( + assert self.cfg_eval.n_rollout_steps <= available_rollout_length, ( "The loss cannot be evaluated on longer than a ground truth trajectory " - f"({cfg_eval.n_rollout_steps} > {available_rollout_length})" + f"({self.cfg_eval.n_rollout_steps} > {available_rollout_length})" ) - assert cfg_eval.train.n_trajs <= data_valid.num_samples, ( + assert self.cfg_eval.train.n_trajs <= data_valid.num_samples, ( f"Number of requested validation trajectories exceeds the available ones " - f"({cfg_eval.train.n_trajs} > {data_valid.num_samples})" + f"({self.cfg_eval.train.n_trajs} > {data_valid.num_samples})" ) - # set the number of validation trajectories during training if self.cfg_eval.train.n_trajs == -1: self.cfg_eval.train.n_trajs = data_valid.num_samples - # make immutable for jitting loss_weight = self.cfg_train.loss_weight self.loss_weight = namedtuple("loss_weight", loss_weight)(**loss_weight) self.base_key, seed_worker, generator = set_seed(seed) - # dataloaders self.loader_train = DataLoader( dataset=data_train, batch_size=self.cfg_eval.train.batch_size, @@ -185,19 +165,16 @@ def __init__( generator=generator, ) - # exponential learning rate decays from lr_start to lr_final over lr_decay_steps lr_scheduler = optax.exponential_decay( init_value=self.cfg_train.optimizer.lr_start, transition_steps=self.cfg_train.optimizer.lr_decay_steps, decay_rate=self.cfg_train.optimizer.lr_decay_rate, end_value=self.cfg_train.optimizer.lr_final, ) - # optimizer - self.opt_init, self.opt_update = optax.adamw( + self.optimizer = optax.adamw( learning_rate=lr_scheduler, weight_decay=1e-8 ) - # metrics computer config self.metrics_computer = MetricsComputer( self.cfg_eval.train.metrics, dist_fn=self.case.displacement, @@ -209,13 +186,12 @@ def __init__( def train( self, step_max: int = defaults.train.step_max, - params: Optional[hk.Params] = None, - state: Optional[hk.State] = None, + model: Optional[eqx.Module] = None, opt_state: Optional[optax.OptState] = None, store_ckp: Optional[str] = None, load_ckp: Optional[str] = None, wandb_config: Optional[Dict] = None, - ) -> Tuple[hk.Params, hk.State, optax.OptState]: + ) -> Tuple[eqx.Module, optax.OptState]: """ Training loop. @@ -224,18 +200,15 @@ def train( Args: step_max: Maximum number of training steps. - params: Optional model parameters. If provided, training continues from it. - state: Optional model state. + model: Optional model. If provided, training continues from it. opt_state: Optional optimizer state. store_ckp: Checkpoints destination. Without it params aren't saved. load_ckp: Initial checkpoint directory. If provided resumes training. wandb_config: Optional configuration to be logged on wandb. Returns: - Tuple containing the final model parameters, state and optimizer state. + Tuple containing the final trained model and optimizer state. """ - - model = self.model case = self.case cfg_train = self.cfg_train cfg_eval = self.cfg_eval @@ -245,44 +218,29 @@ def train( noise_std = cfg_train.noise_std pushforward = cfg_train.pushforward - # Precompile model for evaluation - model_apply = jax.jit(model.apply) + update_fn = _make_update_fn(self.optimizer.update, self.loss_weight) - # loss and update functions - loss_fn = partial(_mse, model_fn=model_apply, loss_weight=self.loss_weight) - update_fn = partial(_update, loss_fn=loss_fn, opt_update=self.opt_update) + if model is None: + model = self.model - # init values raw_batch = next(iter(loader_train)) - raw_batch = jax.tree_map(lambda x: jnp.array(x), raw_batch) # numpy to jax + raw_batch = jax.tree.map(lambda x: jnp.array(x), raw_batch) pos_input_and_target, particle_type = raw_batch raw_sample = (pos_input_and_target[0], particle_type[0]) key, features, _, neighbors = case.allocate(self.base_key, raw_sample) step = 0 - if params is not None: - # continue training from params - if state is None: - state = {} - elif load_ckp: - # continue training from checkpoint - params, state, opt_state, step = load_haiku(load_ckp) - else: - # initialize new model - key, subkey = jax.random.split(key, 2) - params, state = model.init(subkey, (features, particle_type[0])) - - # start logging + if load_ckp is not None: + model, opt_state, step = load_model(load_ckp, model) + if cfg_logging.wandb: if wandb_config is None: - # minimal config reconstruction without model details wandb_config = { "train": OmegaConf.to_container(cfg_train), "eval": OmegaConf.to_container(cfg_eval), "logging": OmegaConf.to_container(cfg_logging), "dataset_path": loader_train.dataset.dataset_path, } - else: wandb_config["eval"]["train"]["n_trajs"] = cfg_eval.train.n_trajs @@ -290,7 +248,7 @@ def train( "dataset_name": loader_train.dataset.name, "len_train": len(loader_train.dataset), "len_eval": len(loader_valid.dataset), - "num_params": get_num_params(params).item(), + "num_params": int(get_num_params(model)), "step_start": step, } @@ -302,31 +260,27 @@ def train( save_code=True, ) - # initialize optimizer state if opt_state is None: - opt_state = self.opt_init(params) + opt_state = self.optimizer.init(eqx.filter(model, eqx.is_inexact_array)) - # create new checkpoint directory if store_ckp is not None: os.makedirs(store_ckp, exist_ok=True) os.makedirs(os.path.join(store_ckp, "best"), exist_ok=True) preprocess_vmap = jax.vmap(case.preprocess, in_axes=(0, 0, None, 0, None)) - push_forward = push_forward_build(model_apply, case) - push_forward_vmap = jax.vmap(push_forward, in_axes=(0, 0, 0, 0, None, None)) + push_forward = push_forward_build(case) + push_forward_vmap = eqx.filter_vmap( + push_forward, in_axes=(None, 0, 0, 0, 0) + ) - # prepare for batch training. keys = jax.random.split(key, loader_train.batch_size) neighbors_batch = broadcast_to_batch(neighbors, loader_train.batch_size) - # start training while step < step_max + 1: for raw_batch in loader_train: - # numpy to jax - raw_batch = jax.tree_map(lambda x: jnp.array(x), raw_batch) + raw_batch = jax.tree.map(lambda x: jnp.array(x), raw_batch) key, unroll_steps = push_forward_sample_steps(key, step, pushforward) - # target computation incorporates the sampled number pushforward steps _keys, features_batch, target_batch, neighbors_batch = preprocess_vmap( keys, raw_batch, @@ -334,44 +288,34 @@ def train( neighbors_batch, unroll_steps, ) - # unroll for push-forward steps _current_pos = raw_batch[0][:, :, : self.input_seq_length] for _ in range(unroll_steps): if neighbors_batch.did_buffer_overflow.sum() > 0: break _current_pos, neighbors_batch, features_batch = push_forward_vmap( + model, features_batch, _current_pos, raw_batch[1], neighbors_batch, - params, - state, ) if neighbors_batch.did_buffer_overflow.sum() > 0: - # check if the neighbor list is too small for any of the samples - # if so, reallocate the neighbor list - print(f"Reallocate neighbors list at step {step}") ind = jnp.argmax(neighbors_batch.did_buffer_overflow) sample = broadcast_from_batch(raw_batch, index=ind) - _, _, _, nbrs = case.allocate(keys[ind], sample, noise_std) print(f"From {neighbors_batch.idx[ind].shape} to {nbrs.idx.shape}") neighbors_batch = broadcast_to_batch(nbrs, loader_train.batch_size) - - # To run the loop N times even if sometimes - # did_buffer_overflow > 0 we directly return to the beginning continue keys = _keys - loss, params, state, opt_state = update_fn( - params=params, - state=state, - features_batch=features_batch, - target_batch=target_batch, - particle_type_batch=raw_batch[1], - opt_state=opt_state, + loss, model, opt_state = update_fn( + model, + features_batch, + target_batch, + raw_batch[1], + opt_state, ) if step % cfg_logging.log_steps == 0: @@ -385,11 +329,9 @@ def train( if step % cfg_logging.eval_steps == 0 and step > 0: nbrs = broadcast_from_batch(neighbors_batch, index=0) eval_metrics = eval_rollout( + model=model, case=case, metrics_computer=self.metrics_computer, - model_apply=model_apply, - params=params, - state=state, neighbors=nbrs, loader_eval=loader_valid, n_rollout_steps=cfg_eval.n_rollout_steps, @@ -397,14 +339,23 @@ def train( rollout_dir=cfg_eval.rollout_dir, out_type=cfg_eval.train.out_type, ) - metrics = averaged_metrics(eval_metrics) + # pick a rollout metric as the "loss" used for best-ckpt + # selection (val/loss isn't emitted by MetricsComputer) + best_key = next( + ( + k + for k in ("val/mse20", "val/mse10", "val/mse5", "val/mse") + if k in metrics + ), + None, + ) metadata_ckp = { "step": step, - "loss": metrics.get("val/loss", None), + "loss": float(metrics[best_key]) if best_key else None, } if store_ckp is not None: - save_haiku(store_ckp, params, state, opt_state, metadata_ckp) + save_model(store_ckp, model, opt_state, metadata_ckp) if cfg_logging.wandb: wandb_run.log(metrics, step) @@ -418,4 +369,4 @@ def train( if cfg_logging.wandb: wandb_run.finish() - return params, state, opt_state + return model, opt_state diff --git a/lagrangebench/utils.py b/lagrangebench/utils.py index 9255fd8..b8ca599 100644 --- a/lagrangebench/utils.py +++ b/lagrangebench/utils.py @@ -3,11 +3,11 @@ import enum import json import os -import pickle import random from typing import Callable, Tuple import cloudpickle +import equinox as eqx import jax import jax.numpy as jnp import numpy as np @@ -30,65 +30,55 @@ def get_kinematic_mask(particle_type): res = jnp.logical_or( particle_type == NodeType.SOLID_WALL, particle_type == NodeType.MOVING_WALL ) - # In datasets with variable number of particles we treat padding as kinematic nodes res = jnp.logical_or(res, particle_type == NodeType.PAD_VALUE) return res def broadcast_to_batch(sample, batch_size: int): - """Broadcast a pytree to a batched one with first dimension batch_size.""" + """Broadcast a pytree to a batched one with first dimension ``batch_size``.""" assert batch_size > 0 - return jax.tree_map(lambda x: jnp.repeat(x[None, ...], batch_size, axis=0), sample) + return jax.tree.map( + lambda x: jnp.repeat(x[None, ...], batch_size, axis=0), sample + ) def broadcast_from_batch(batch, index: int): - """Broadcast a batched pytree to the sample `index` out of the batch.""" + """Select ``index``th sample from a batched pytree.""" assert index >= 0 - return jax.tree_map(lambda x: x[index], batch) - - -def save_pytree(ckp_dir: str, pytree_obj, name) -> None: - """Save a pytree to a directory.""" - with open(os.path.join(ckp_dir, f"{name}_array.npy"), "wb") as f: - for x in jax.tree_leaves(pytree_obj): - np.save(f, x, allow_pickle=False) + return jax.tree.map(lambda x: x[index], batch) - tree_struct = jax.tree_map(lambda t: 0, pytree_obj) - with open(os.path.join(ckp_dir, f"{name}_tree.pkl"), "wb") as f: - pickle.dump(tree_struct, f) - -def save_haiku(ckp_dir: str, params, state, opt_state, metadata_ckp) -> None: - """Save params, state and optimizer state to ckp_dir. +def save_model(ckp_dir: str, model: eqx.Module, opt_state, metadata_ckp) -> None: + """Save model, optimizer state and metadata to ckp_dir. Additionally it tracks and saves the best model to ckp_dir/best. - See: https://github.com/deepmind/dm-haiku/issues/18 + Serializes the model's leaf arrays with :func:`equinox.tree_serialise_leaves`. + The *structure* of the model is not stored; loading requires an instance of + the same ``eqx.Module`` class with identical hyperparameters. """ - save_pytree(ckp_dir, params, "params") - save_pytree(ckp_dir, state, "state") - + os.makedirs(ckp_dir, exist_ok=True) + eqx.tree_serialise_leaves(os.path.join(ckp_dir, "model.eqx"), model) with open(os.path.join(ckp_dir, "opt_state.pkl"), "wb") as f: cloudpickle.dump(opt_state, f) with open(os.path.join(ckp_dir, "metadata_ckp.json"), "w") as f: json.dump(metadata_ckp, f) - # only run for the main checkpoint directory (not best) if "best" not in ckp_dir: ckp_dir_best = os.path.join(ckp_dir, "best") metadata_best_path = os.path.join(ckp_dir, "best", "metadata_ckp.json") tag = "" - - if os.path.exists(metadata_best_path): # all except first step + if os.path.exists(metadata_best_path): with open(metadata_best_path, "r") as fp: metadata_ckp_best = json.loads(fp.read()) - - # if loss is better than best previous loss, save to best model directory - if metadata_ckp["loss"] < metadata_ckp_best["loss"]: - save_haiku(ckp_dir_best, params, state, opt_state, metadata_ckp) + if metadata_ckp["loss"] is not None and ( + metadata_ckp_best["loss"] is None + or metadata_ckp["loss"] < metadata_ckp_best["loss"] + ): + save_model(ckp_dir_best, model, opt_state, metadata_ckp) tag = " (best so far)" - else: # first step - save_haiku(ckp_dir_best, params, state, opt_state, metadata_ckp) + else: + save_model(ckp_dir_best, model, opt_state, metadata_ckp) print( f"saved model to {ckp_dir} at step {metadata_ckp['step']}" @@ -96,60 +86,43 @@ def save_haiku(ckp_dir: str, params, state, opt_state, metadata_ckp) -> None: ) -def load_pytree(model_dir: str, name): - """Load a pytree from a directory.""" - with open(os.path.join(model_dir, f"{name}_tree.pkl"), "rb") as f: - tree_struct = pickle.load(f) - - leaves, treedef = jax.tree_flatten(tree_struct) - - with open(os.path.join(model_dir, f"{name}_array.npy"), "rb") as f: - flat_state = [np.load(f) for _ in leaves] - - return jax.tree_unflatten(treedef, flat_state) - +def load_model(model_dir: str, template: eqx.Module): + """Load model, optimizer state and last training step from model_dir. -def load_haiku(model_dir: str): - """Load params, state, optimizer state and last training step from model_dir. + The ``template`` argument must be an ``eqx.Module`` instance with the same + architecture as the saved one; its arrays will be replaced with the + deserialized leaves. - See: https://github.com/deepmind/dm-haiku/issues/18 + Returns ``(model, opt_state, step)``. """ - params = load_pytree(model_dir, "params") - state = load_pytree(model_dir, "state") - - with open(os.path.join(model_dir, "opt_state.pkl"), "rb") as f: - opt_state = cloudpickle.load(f) - + model = eqx.tree_deserialise_leaves(os.path.join(model_dir, "model.eqx"), template) + opt_state_path = os.path.join(model_dir, "opt_state.pkl") + if os.path.exists(opt_state_path): + with open(opt_state_path, "rb") as f: + opt_state = cloudpickle.load(f) + else: + opt_state = None with open(os.path.join(model_dir, "metadata_ckp.json"), "r") as fp: metadata_ckp = json.loads(fp.read()) - print(f"Loaded model from {model_dir} at step {metadata_ckp['step']}") - - return params, state, opt_state, metadata_ckp["step"] + return model, opt_state, metadata_ckp["step"] -def get_num_params(params): - """Get the number of parameters in a Haiku model.""" - return sum(np.prod(p.shape) for p in jax.tree_leaves(params)) - - -def print_params_shapes(params, prefix=""): - if not isinstance(params, dict): - print(f"{prefix: <40}, shape = {params.shape}") - else: - for k, v in params.items(): - print_params_shapes(v, prefix=prefix + k) +def get_num_params(model_or_tree): + """Get the number of parameters in an Equinox model (or any pytree).""" + return sum( + np.prod(p.shape) + for p in jax.tree.leaves(eqx.filter(model_or_tree, eqx.is_inexact_array)) + ) def set_seed(seed: int) -> Tuple[jax.Array, Callable, torch.Generator]: """Set seeds for jax, random and torch.""" - # first PRNG key key = jax.random.PRNGKey(seed) np.random.seed(seed) random.seed(seed) torch.manual_seed(seed) - # dataloader-related seeds def seed_worker(_): worker_seed = torch.initial_seed() % 2**32 np.random.seed(worker_seed) @@ -157,5 +130,4 @@ def seed_worker(_): generator = torch.Generator() generator.manual_seed(seed) - return key, seed_worker, generator diff --git a/main.py b/main.py index 644d995..dedaae0 100644 --- a/main.py +++ b/main.py @@ -1,72 +1,46 @@ -import os - -from omegaconf import DictConfig, OmegaConf +"""Hydra entry point for LagrangeBench training and inference. +Invocation:: -def check_subset(superset, subset, full_key=""): - """Check that the keys of 'subset' are a subset of 'superset'.""" - for k, v in subset.items(): - key = full_key + k - if isinstance(v, dict): - check_subset(superset[k], v, key + ".") - else: - msg = f"cli_args must be a subset of the defaults. Wrong cli key: '{key}'" - assert k in superset, msg + python main.py --config-path configs/tgv_2d --config-name gns gpu=6 +Any key defined in :mod:`lagrangebench.defaults` can be overridden from the +CLI with the standard ``key=value`` syntax. Example:: -def load_embedded_configs(config_path: str, cli_args: DictConfig) -> DictConfig: - """Loads all 'extends' embedded configs and merge them with the cli overwrites.""" + python main.py -cp configs/tgv_2d -cn gns gpu=6 train.step_max=20000 - cfgs = [OmegaConf.load(config_path)] - while "extends" in cfgs[0]: - extends_path = cfgs[0]["extends"] - del cfgs[0]["extends"] +Resuming from a checkpoint is a plain override: - # go to parents configs until the defaults are reached - if extends_path != "LAGRANGEBENCH_DEFAULTS": - cfgs = [OmegaConf.load(extends_path)] + cfgs - else: - from lagrangebench.defaults import defaults + python main.py -cp configs/tgv_2d -cn gns mode=infer load_ckp=runs/ckp/gns_tgv2d_xyz +""" - cfgs = [defaults] + cfgs +import os - # assert that the cli_args are a subset of the defaults if inheritance from - # defaults is used. - check_subset(cfgs[0], cli_args) +import hydra +from hydra.core.config_store import ConfigStore +from omegaconf import DictConfig, OmegaConf - break +from lagrangebench.defaults import defaults as _lb_defaults - # merge all embedded configs and give highest priority to cli_args - cfg = OmegaConf.merge(*cfgs, cli_args) - return cfg +# Register the Python defaults under ``lagrangebench_defaults`` so that config +# files can pull them in via ``defaults: [/lagrangebench_defaults, _self_]``. +_cs = ConfigStore.instance() +_cs.store(name="lagrangebench_defaults", node=_lb_defaults, package="_global_") -if __name__ == "__main__": - cli_args = OmegaConf.from_cli() - assert ("config" in cli_args) != ( - "load_ckp" in cli_args - ), "You must specify one of 'config' or 'load_ckp'." - - if "config" in cli_args: # start from config.yaml - config_path = cli_args.config - elif "load_ckp" in cli_args: # start from a checkpoint - config_path = os.path.join(cli_args.load_ckp, "config.yaml") - - # values that need to be specified before importing jax - cli_args.gpu = cli_args.get("gpu", -1) - cli_args.xla_mem_fraction = cli_args.get("xla_mem_fraction", 0.75) - - # specify cuda device - os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID" # see issue #152 from TensorFlow - os.environ["CUDA_VISIBLE_DEVICES"] = str(cli_args.gpu) - if cli_args.gpu == -1: +@hydra.main(version_base=None, config_path=None, config_name=None) +def main(cfg: DictConfig) -> int: + # ``gpu`` and ``xla_mem_fraction`` must be set before importing jax. + gpu = cfg.get("gpu", -1) + gpu = -1 if gpu is None else int(gpu) + os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID" # see TensorFlow issue #152 + os.environ["CUDA_VISIBLE_DEVICES"] = str(gpu) + if gpu == -1: os.environ["JAX_PLATFORMS"] = "cpu" - os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"] = str(cli_args.xla_mem_fraction) - - # The following line makes the code deterministic on GPUs, but also extremely slow. - # os.environ["XLA_FLAGS"] = "--xla_gpu_deterministic_ops=true" - - cfg = load_embedded_configs(config_path, cli_args) + mem_fraction = cfg.get("xla_mem_fraction", 0.75) + if mem_fraction is None: + mem_fraction = 0.75 + os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"] = str(mem_fraction) print("#" * 79, "\nStarting a LagrangeBench run with the following configs:") print(OmegaConf.to_yaml(cfg)) @@ -74,4 +48,8 @@ def load_embedded_configs(config_path: str, cli_args: DictConfig) -> DictConfig: from lagrangebench.runner import train_or_infer - train_or_infer(cfg) + return train_or_infer(cfg) + + +if __name__ == "__main__": + main() diff --git a/pyproject.toml b/pyproject.toml index 10cac0e..73872f7 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -24,9 +24,9 @@ classifiers = [ "Operating System :: MacOS", "Operating System :: POSIX :: Linux", "Programming Language :: Python :: 3", - "Programming Language :: Python :: 3.9", "Programming Language :: Python :: 3.10", "Programming Language :: Python :: 3.11", + "Programming Language :: Python :: 3.12", "Topic :: Scientific/Engineering :: Artificial Intelligence", "Topic :: Scientific/Engineering :: Physics", "Topic :: Scientific/Engineering :: Hydrology", @@ -35,26 +35,25 @@ classifiers = [ ] [tool.poetry.dependencies] -python = ">=3.9,<=3.11" +python = ">=3.10" cloudpickle = ">=2.2.1" h5py = ">=3.9.0" PyYAML = ">=6.0" numpy = ">=1.24.4" wandb = ">=0.15.11" pyvista = ">=0.42.2" -jax = {version = "0.4.29", extras = ["cpu"]} -jaxlib = "0.4.29" -dm-haiku = ">=0.0.10" -e3nn-jax = "0.20.3" -jmp = ">=0.0.4" -jraph = "0.0.6.dev0" -optax = "0.1.7" +jax = ">=0.6.0" +jaxlib = ">=0.6.0" +equinox = ">=0.11.0" +e3nn-jax = ">=0.21.0" +jraph = ">=0.0.6.dev0" +optax = ">=0.2.0" ott-jax = ">=0.4.2" matscipy = ">=0.8.0" -torch = {version = "2.3.1+cpu", source = "torchcpu"} +torch = ">=2.3.0" wget = ">=3.2" omegaconf = ">=2.3.0" -jax-sph = ">=0.0.3" +jax-sph = {git = "https://github.com/tumaer/jax-sph.git"} [tool.poetry.group.dev.dependencies] # mypy = ">=1.8.0" - consider in the future @@ -69,11 +68,6 @@ sphinx = "7.2.6" sphinx-rtd-theme = "1.3.0" toml = ">=0.10.2" -[[tool.poetry.source]] -name = "torchcpu" -url = "https://download.pytorch.org/whl/cpu" -priority = "explicit" - [tool.ruff] exclude = [ ".git", diff --git a/tests/conftest.py b/tests/conftest.py new file mode 100644 index 0000000..d717f15 --- /dev/null +++ b/tests/conftest.py @@ -0,0 +1,9 @@ +"""Pytest fixtures and global test configuration.""" + +import jax + +# Enable 64-bit precision for equivariance tests to pass within the +# strict tolerance used by ``ModelTest.assert_equivariant`` (1e-5). With +# float32 the accumulated numerical error from many message-passing +# aggregations can exceed this tolerance. +jax.config.update("jax_enable_x64", True) diff --git a/tests/models_test.py b/tests/models_test.py index 702280c..5ab3b65 100644 --- a/tests/models_test.py +++ b/tests/models_test.py @@ -1,7 +1,6 @@ import unittest import e3nn_jax as e3nn -import haiku as hk import jax import jax.numpy as jnp import numpy as np @@ -37,7 +36,7 @@ def dummy_sample(self, vel=None, pos=None): def key(self): return jax.random.PRNGKey(0) - def assert_equivariant(self, f, params, state): + def assert_equivariant(self, model): key = self.key() vel = e3nn.normal("5x1o", key, (100,)) @@ -51,7 +50,7 @@ def wrapper(v, p): "abs_pos": p.array.reshape((100, 1, 3)), } ) - y, _ = f.apply(params, state, (sample, particle_type)) + y = model((sample, particle_type)) return e3nn.IrrepsArray("1x1o", y["acc"]) # random rotation matrix @@ -68,58 +67,47 @@ def assert_(x, y): jax.tree_util.tree_map(assert_, out1, out2) def test_segnn(self): - def segnn(x): - return models.SEGNN( - node_features_irreps="5x1o + 5x0e", - edge_features_irreps="1x1o + 1x0e", - scalar_units=8, - lmax_hidden=1, - lmax_attributes=1, - n_vels=5, - num_mp_steps=1, - output_irreps="1x1o", - )(x) - - segnn = hk.without_apply_rng(hk.transform_with_state(segnn)) - x, particle_type = self.dummy_sample() - params, segnn_state = segnn.init(self.key(), (x, particle_type)) - - self.assert_equivariant(segnn, params, segnn_state) + segnn = models.SEGNN( + node_features_irreps="5x1o + 5x0e", + edge_features_irreps="1x1o + 1x0e", + scalar_units=8, + lmax_hidden=1, + lmax_attributes=1, + n_vels=5, + num_mp_steps=1, + output_irreps="1x1o", + key=self.key(), + ) + self.assert_equivariant(segnn) def test_egnn(self): - def egnn(x): - return models.EGNN( - hidden_size=8, - output_size=1, - num_mp_steps=1, - dt=0.01, - n_vels=5, - displacement_fn=lambda x, y: x - y, - shift_fn=lambda x, y: x + y, - )(x) - - egnn = hk.without_apply_rng(hk.transform_with_state(egnn)) - x, particle_type = self.dummy_sample() - params, egnn_state = egnn.init(self.key(), (x, particle_type)) - - self.assert_equivariant(egnn, params, egnn_state) + egnn = models.EGNN( + hidden_size=8, + output_size=1, + num_mp_steps=1, + dt=0.01, + n_vels=5, + displacement_fn=lambda x, y: x - y, + shift_fn=lambda x, y: x + y, + key=self.key(), + ) + self.assert_equivariant(egnn) def test_painn(self): - def painn(x): - return models.PaiNN( - hidden_size=8, - output_size=1, - num_mp_steps=1, - radial_basis_fn=models.painn.gaussian_rbf(20, 10, trainable=True), - cutoff_fn=models.painn.cosine_cutoff(10), - n_vels=5, - )(x) - - painn = hk.without_apply_rng(hk.transform_with_state(painn)) - x, particle_type = self.dummy_sample() - params, painn_state = painn.init(self.key(), (x, particle_type)) - - self.assert_equivariant(painn, params, painn_state) + # With homogeneous particles: s_in = n_vels (5); no force/bound → v_in = 5 + painn = models.PaiNN( + hidden_size=8, + output_size=1, + num_mp_steps=1, + radial_basis_fn=models.painn.gaussian_rbf(20, 10, trainable=True), + cutoff_fn=models.painn.cosine_cutoff(10), + n_vels=5, + n_rbf=20, + s_in_size=5, + v_in_channels=5, + key=self.key(), + ) + self.assert_equivariant(painn) if __name__ == "__main__": diff --git a/tests/rollout_test.py b/tests/rollout_test.py index 1a9a05e..2d54464 100644 --- a/tests/rollout_test.py +++ b/tests/rollout_test.py @@ -1,18 +1,16 @@ import unittest from functools import partial +from typing import Dict, Tuple -import haiku as hk +import equinox as eqx import jax import jax.numpy as jnp import numpy as np -from jax import config as jax_config -from jax import jit, vmap +from jax import vmap from jax_sph.jax_md import space from omegaconf import OmegaConf from torch.utils.data import DataLoader -jax_config.update("jax_enable_x64", True) - from lagrangebench.case_setup import case_builder from lagrangebench.data import H5Dataset from lagrangebench.data.utils import get_dataset_stats, numpy_collate @@ -21,6 +19,32 @@ from lagrangebench.utils import broadcast_from_batch +class _CheatingModel(eqx.Module): + """Test-only model that replays pre-computed acceleration targets. + + Tracks the rollout step with a ``jnp.ndarray`` counter. The rollout loop + calls ``.advance()`` between steps to increment it. + """ + + target: jnp.ndarray + counter: jnp.ndarray + + def __init__(self, target: jnp.ndarray, start: int): + self.target = target + self.counter = jnp.asarray(start, dtype=jnp.int32) + + def __call__( + self, sample: Tuple[Dict[str, jnp.ndarray], jnp.ndarray] + ) -> Dict[str, jnp.ndarray]: + acc = jax.lax.dynamic_index_in_dim( + self.target, self.counter, axis=1, keepdims=False + ) + return {"acc": acc} + + def advance(self) -> "_CheatingModel": + return eqx.tree_at(lambda m: m.counter, self, self.counter + 1) + + class TestInferBuilder(unittest.TestCase): """Class for unit testing the evaluate_single_rollout function.""" @@ -76,7 +100,7 @@ def test_rollout(self): # get one validation trajectory from the debug dataset traj_batch_i = next(iter(self.loader_valid)) - traj_batch_i = jax.tree_map(lambda x: jnp.array(x), traj_batch_i) + traj_batch_i = jax.tree.map(lambda x: jnp.array(x), traj_batch_i) # remove batch dimension self.assertTrue(traj_batch_i[0].shape[0] == 1, "We test only batch size 1") traj_i = broadcast_from_batch(traj_batch_i, index=0) @@ -89,44 +113,15 @@ def test_rollout(self): stats = self.normalization_stats["acceleration"] accs = (accs - stats["mean"]) / stats["std"] - class CheatingModel(hk.Module): - def __init__(self, target, start): - super().__init__() - self.target = target - self.start = start - - def __call__(self, x): - i = hk.get_state( - "counter", - shape=[], - dtype=jnp.int32, - init=hk.initializers.Constant(self.start), - ) - hk.set_state("counter", i + 1) - return {"acc": self.target[:, i]} - - def setup_model(target, start): - def model(x): - return CheatingModel(target, start)(x) - - model = hk.without_apply_rng(hk.transform_with_state(model)) - params, state = model.init(None, None) - model_apply = model.apply - model_apply = jit(model_apply) - return params, state, model_apply - - params, state, model_apply = setup_model(accs, 0) - # proof that the above "model" works - out, state = model_apply(params, state, None) + model0 = _CheatingModel(accs, start=0) + out = model0((None, None)) pred_acc = stats["mean"] + out["acc"] * stats["std"] pred_pos = self.shift_fn(positions[:, isl - 1], vels[:, isl - 2] + pred_acc) pred_pos = jnp.asarray(pred_pos, dtype=jnp.float32) target_pos = positions[:, isl] - assert jnp.isclose(pred_pos, target_pos, atol=1e-7).all(), "Wrong setup" - params, state, model_apply = setup_model(accs, isl - 2) _, neighbors = self.case.allocate_eval((positions[:, :isl], traj_i[1])) metrics_computer = MetricsComputer( @@ -138,21 +133,20 @@ def model(x): forward_eval = partial( _forward_eval, - model_apply=model_apply, case_integrate=self.case.integrate, ) - forward_eval_vmap = vmap(forward_eval, in_axes=(None, None, 0, 0, 0)) + forward_eval_vmap = eqx.filter_vmap(forward_eval, in_axes=(None, 0, 0, 0)) preprocess_eval_vmap = vmap(self.case.preprocess_eval, in_axes=(0, 0)) metrics_computer_vmap = vmap(metrics_computer, in_axes=(0, 0)) for n_extrap_steps in [0, 5, 10]: with self.subTest(n_extrap_steps): + model = _CheatingModel(accs, start=isl - 2) example_rollout_batch, metrics_batch, neighbors = _eval_batched_rollout( forward_eval_vmap=forward_eval_vmap, preprocess_eval_vmap=preprocess_eval_vmap, case=self.case, - params=params, - state=state, + model=model, traj_batch_i=traj_batch_i, neighbors=neighbors, metrics_computer_vmap=metrics_computer_vmap,