Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ rollouts
profile
dist
.coverage
.claude

# Sphinx documentation
docs/_build/
docs/_build/
17 changes: 10 additions & 7 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -86,15 +86,15 @@ Although the current [`jax-metal==0.0.5` library](https://pypi.org/project/jax-m
A general tutorial is provided in the example notebook "Training GNS on the 2D Taylor Green Vortex" under `./notebooks/tutorial.ipynb` on the [LagrangeBench repository](https://github.com/tumaer/lagrangebench). The notebook covers the basics of LagrangeBench, such as loading a dataset, setting up a case, training a model from scratch and evaluating its performance.

### Running in a local clone (`main.py`)
Alternatively, experiments can also be set up with `main.py`, based on extensive YAML config files and cli arguments (check [`configs/`](configs/)). By default, the arguments have priority as 1) passed cli arguments, 2) YAML config and 3) [`defaults.py`](lagrangebench/defaults.py) (`lagrangebench` defaults).
Alternatively, experiments can also be set up with `main.py`, based on Hydra-composed YAML config files and cli arguments (check [`configs/`](configs/)). Each preset is a small YAML file whose `defaults:` list pulls in either the dataset-level `base.yaml` or the global defaults registered from [`defaults.py`](lagrangebench/defaults.py). CLI overrides use the standard Hydra `key=value` syntax and have priority over everything.

When loading a saved model with `load_ckp` the config from the checkpoint is automatically loaded and training is restarted. For more details check the [`runner.py`](lagrangebench/runner.py) file.

**Train**

For example, to start a _GNS_ run from scratch on the RPF 2D dataset use
```
python main.py config=configs/rpf_2d/gns.yaml
python main.py --config-path configs/rpf_2d --config-name gns
```
Some model presets can be found in `./configs/`.

Expand All @@ -103,16 +103,18 @@ If `mode=all` is provided, then training (`mode=train`) and subsequent inference

**Restart training**

To restart training from the last checkpoint in `load_ckp` use
To restart training from the last checkpoint in `load_ckp` point Hydra at the checkpoint's saved `config.yaml`:
```
python main.py load_ckp=ckp/gns_rpf2d_yyyymmdd-hhmmss
python main.py --config-path ckp/gns_rpf2d_yyyymmdd-hhmmss --config-name config load_ckp=ckp/gns_rpf2d_yyyymmdd-hhmmss
```

**Inference**

To evaluate a trained model from `load_ckp` on the test split (`test=True`) use
To evaluate a trained model from `load_ckp` on the test split (`eval.test=True`) use
```
python main.py load_ckp=ckp/gns_rpf2d_yyyymmdd-hhmmss/best rollout_dir=rollout/gns_rpf2d_yyyymmdd-hhmmss/best mode=infer test=True
python main.py --config-path ckp/gns_rpf2d_yyyymmdd-hhmmss/best --config-name config \
load_ckp=ckp/gns_rpf2d_yyyymmdd-hhmmss/best \
mode=infer eval.test=True eval.rollout_dir=rollout/gns_rpf2d_yyyymmdd-hhmmss/best
```

If the default `eval.infer.out_type=pkl` is active, then the generated trajectories and a `metricsYYYY_MM_DD_HH_MM_SS.pkl` file will be written to `eval.rollout_dir`. The metrics file contains all `eval.infer.metrics` properties for each generated rollout.
Expand Down Expand Up @@ -162,7 +164,8 @@ gdown 19TO4PaFGcryXOFFKs93IniuPZKEcaJ37
# unzip the downloaded file `gns_tgv2d.zip`
python -c "import shutil; shutil.unpack_archive('gns_tgv2d.zip', 'gns_tgv2d')"
# evaluate the model on the test split
python main.py gpu=$GPU_ID mode=infer eval.test=True load_ckp=gns_tgv2d/best
python main.py --config-path gns_tgv2d/best --config-name config \
gpu=$GPU_ID mode=infer eval.test=True load_ckp=gns_tgv2d/best
```

## Directory structure
Expand Down
7 changes: 5 additions & 2 deletions configs/WaterDrop_2d/gns.yaml
Original file line number Diff line number Diff line change
@@ -1,9 +1,12 @@
extends: LAGRANGEBENCH_DEFAULTS
# @package _global_
defaults:
- /lagrangebench_defaults
- _self_

dataset:
src: /tmp/datasets/WaterDrop

model:
model:
name: gns
num_mp_steps: 10
latent_dim: 128
Expand Down
5 changes: 4 additions & 1 deletion configs/dam_2d/base.yaml
Original file line number Diff line number Diff line change
@@ -1,4 +1,7 @@
extends: LAGRANGEBENCH_DEFAULTS
# @package _global_
defaults:
- /lagrangebench_defaults
- _self_

dataset:
src: datasets/2D_DAM_5740_20kevery100
Expand Down
7 changes: 5 additions & 2 deletions configs/dam_2d/gns.yaml
Original file line number Diff line number Diff line change
@@ -1,6 +1,9 @@
extends: configs/dam_2d/base.yaml
# @package _global_
defaults:
- base
- _self_

model:
model:
name: gns
num_mp_steps: 10
latent_dim: 128
Expand Down
7 changes: 5 additions & 2 deletions configs/dam_2d/segnn.yaml
Original file line number Diff line number Diff line change
@@ -1,6 +1,9 @@
extends: configs/dam_2d/base.yaml
# @package _global_
defaults:
- base
- _self_

model:
model:
name: segnn
num_mp_steps: 10
latent_dim: 64
Expand Down
7 changes: 5 additions & 2 deletions configs/ldc_2d/base.yaml
Original file line number Diff line number Diff line change
@@ -1,4 +1,7 @@
extends: LAGRANGEBENCH_DEFAULTS
# @package _global_
defaults:
- /lagrangebench_defaults
- _self_

dataset:
src: datasets/2D_LDC_2708_10kevery100
Expand All @@ -7,4 +10,4 @@ logging:
wandb_project: ldc_2d

neighbors:
multiplier: 2.0
multiplier: 2.0
7 changes: 5 additions & 2 deletions configs/ldc_2d/gns.yaml
Original file line number Diff line number Diff line change
@@ -1,6 +1,9 @@
extends: configs/ldc_2d/base.yaml
# @package _global_
defaults:
- base
- _self_

model:
model:
name: gns
num_mp_steps: 10
latent_dim: 128
Expand Down
7 changes: 5 additions & 2 deletions configs/ldc_2d/segnn.yaml
Original file line number Diff line number Diff line change
@@ -1,6 +1,9 @@
extends: configs/ldc_2d/base.yaml
# @package _global_
defaults:
- base
- _self_

model:
model:
name: segnn
num_mp_steps: 10
latent_dim: 64
Expand Down
7 changes: 5 additions & 2 deletions configs/ldc_3d/base.yaml
Original file line number Diff line number Diff line change
@@ -1,4 +1,7 @@
extends: LAGRANGEBENCH_DEFAULTS
# @package _global_
defaults:
- /lagrangebench_defaults
- _self_

dataset:
src: datasets/3D_LDC_8160_10kevery100
Expand All @@ -7,4 +10,4 @@ logging:
wandb_project: ldc_3d

neighbors:
multiplier: 2.0
multiplier: 2.0
7 changes: 5 additions & 2 deletions configs/ldc_3d/gns.yaml
Original file line number Diff line number Diff line change
@@ -1,6 +1,9 @@
extends: configs/ldc_3d/base.yaml
# @package _global_
defaults:
- base
- _self_

model:
model:
name: gns
num_mp_steps: 10
latent_dim: 128
Expand Down
7 changes: 5 additions & 2 deletions configs/ldc_3d/segnn.yaml
Original file line number Diff line number Diff line change
@@ -1,6 +1,9 @@
extends: configs/ldc_3d/base.yaml
# @package _global_
defaults:
- base
- _self_

model:
model:
name: segnn
num_mp_steps: 10
latent_dim: 64
Expand Down
7 changes: 5 additions & 2 deletions configs/rpf_2d/base.yaml
Original file line number Diff line number Diff line change
@@ -1,7 +1,10 @@
extends: LAGRANGEBENCH_DEFAULTS
# @package _global_
defaults:
- /lagrangebench_defaults
- _self_

dataset:
src: datasets/2D_RPF_3200_20kevery100

logging:
wandb_project: rpf_2d
wandb_project: rpf_2d
7 changes: 5 additions & 2 deletions configs/rpf_2d/egnn.yaml
Original file line number Diff line number Diff line change
@@ -1,6 +1,9 @@
extends: configs/rpf_2d/base.yaml
# @package _global_
defaults:
- base
- _self_

model:
model:
name: egnn
num_mp_steps: 5
latent_dim: 128
Expand Down
7 changes: 5 additions & 2 deletions configs/rpf_2d/gns.yaml
Original file line number Diff line number Diff line change
@@ -1,6 +1,9 @@
extends: configs/rpf_2d/base.yaml
# @package _global_
defaults:
- base
- _self_

model:
model:
name: gns
num_mp_steps: 10
latent_dim: 128
Expand Down
7 changes: 5 additions & 2 deletions configs/rpf_2d/painn.yaml
Original file line number Diff line number Diff line change
@@ -1,6 +1,9 @@
extends: configs/rpf_2d/base.yaml
# @package _global_
defaults:
- base
- _self_

model:
model:
name: painn
num_mp_steps: 5
latent_dim: 128
Expand Down
7 changes: 5 additions & 2 deletions configs/rpf_2d/segnn.yaml
Original file line number Diff line number Diff line change
@@ -1,6 +1,9 @@
extends: configs/rpf_2d/base.yaml
# @package _global_
defaults:
- base
- _self_

model:
model:
name: segnn
num_mp_steps: 10
latent_dim: 64
Expand Down
7 changes: 5 additions & 2 deletions configs/rpf_3d/base.yaml
Original file line number Diff line number Diff line change
@@ -1,7 +1,10 @@
extends: LAGRANGEBENCH_DEFAULTS
# @package _global_
defaults:
- /lagrangebench_defaults
- _self_

dataset:
src: datasets/3D_RPF_8000_10kevery100

logging:
wandb_project: rpf_3d
wandb_project: rpf_3d
7 changes: 5 additions & 2 deletions configs/rpf_3d/egnn.yaml
Original file line number Diff line number Diff line change
@@ -1,6 +1,9 @@
extends: configs/rpf_3d/base.yaml
# @package _global_
defaults:
- base
- _self_

model:
model:
name: egnn
num_mp_steps: 5
latent_dim: 128
Expand Down
7 changes: 5 additions & 2 deletions configs/rpf_3d/gns.yaml
Original file line number Diff line number Diff line change
@@ -1,6 +1,9 @@
extends: configs/rpf_3d/base.yaml
# @package _global_
defaults:
- base
- _self_

model:
model:
name: gns
num_mp_steps: 10
latent_dim: 128
Expand Down
7 changes: 5 additions & 2 deletions configs/rpf_3d/painn.yaml
Original file line number Diff line number Diff line change
@@ -1,6 +1,9 @@
extends: configs/rpf_3d/base.yaml
# @package _global_
defaults:
- base
- _self_

model:
model:
name: painn
num_mp_steps: 5
latent_dim: 128
Expand Down
7 changes: 5 additions & 2 deletions configs/rpf_3d/segnn.yaml
Original file line number Diff line number Diff line change
@@ -1,6 +1,9 @@
extends: configs/rpf_3d/base.yaml
# @package _global_
defaults:
- base
- _self_

model:
model:
name: segnn
num_mp_steps: 10
latent_dim: 64
Expand Down
5 changes: 4 additions & 1 deletion configs/tgv_2d/base.yaml
Original file line number Diff line number Diff line change
@@ -1,4 +1,7 @@
extends: LAGRANGEBENCH_DEFAULTS
# @package _global_
defaults:
- /lagrangebench_defaults
- _self_

dataset:
src: datasets/2D_TGV_2500_10kevery100
Expand Down
7 changes: 5 additions & 2 deletions configs/tgv_2d/gns.yaml
Original file line number Diff line number Diff line change
@@ -1,6 +1,9 @@
extends: configs/tgv_2d/base.yaml
# @package _global_
defaults:
- base
- _self_

model:
model:
name: gns
num_mp_steps: 10
latent_dim: 128
Expand Down
7 changes: 5 additions & 2 deletions configs/tgv_2d/segnn.yaml
Original file line number Diff line number Diff line change
@@ -1,6 +1,9 @@
extends: configs/tgv_2d/base.yaml
# @package _global_
defaults:
- base
- _self_

model:
model:
name: segnn
num_mp_steps: 10
latent_dim: 64
Expand Down
5 changes: 4 additions & 1 deletion configs/tgv_3d/base.yaml
Original file line number Diff line number Diff line change
@@ -1,4 +1,7 @@
extends: LAGRANGEBENCH_DEFAULTS
# @package _global_
defaults:
- /lagrangebench_defaults
- _self_

dataset:
src: datasets/3D_TGV_8000_10kevery100
Expand Down
7 changes: 5 additions & 2 deletions configs/tgv_3d/gns.yaml
Original file line number Diff line number Diff line change
@@ -1,6 +1,9 @@
extends: configs/tgv_3d/base.yaml
# @package _global_
defaults:
- base
- _self_

model:
model:
name: gns
num_mp_steps: 10
latent_dim: 128
Expand Down
5 changes: 4 additions & 1 deletion configs/tgv_3d/segnn.yaml
Original file line number Diff line number Diff line change
@@ -1,4 +1,7 @@
extends: configs/tgv_3d/base.yaml
# @package _global_
defaults:
- base
- _self_

model:
name: segnn
Expand Down
3 changes: 2 additions & 1 deletion lagrangebench/__init__.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from .case_setup.case import case_builder
from .data import DAM2D, LDC2D, LDC3D, RPF2D, RPF3D, TGV2D, TGV3D, H5Dataset
from .evaluate import infer
from .models import EGNN, GNS, SEGNN, PaiNN
from .models import EGNN, GNS, SEGNN, Linear, PaiNN
from .train.trainer import Trainer

__all__ = [
Expand All @@ -13,6 +13,7 @@
"EGNN",
"SEGNN",
"PaiNN",
"Linear",
"data",
"H5Dataset",
"TGV2D",
Expand Down
Loading
Loading