Skip to content

Commit 24d162c

Browse files
ssmmnn11pre-commit-ci[bot]HCookieicedoom888cathalobrien
authored
feat: compile transformer gnn (#181)
Compile GraphTransformerConv <!-- readthedocs-preview anemoi-training start --> ---- πŸ“š Documentation preview πŸ“š: https://anemoi-training--181.org.readthedocs.build/en/181/ <!-- readthedocs-preview anemoi-training end --> <!-- readthedocs-preview anemoi-graphs start --> ---- πŸ“š Documentation preview πŸ“š: https://anemoi-graphs--181.org.readthedocs.build/en/181/ <!-- readthedocs-preview anemoi-graphs end --> <!-- readthedocs-preview anemoi-models start --> ---- πŸ“š Documentation preview πŸ“š: https://anemoi-models--181.org.readthedocs.build/en/181/ <!-- readthedocs-preview anemoi-models end --> --------- Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Harrison Cook <[email protected]> Co-authored-by: Icedoom <[email protected]> Co-authored-by: Cathal OBrien <[email protected]> Co-authored-by: Matthew Chantry <[email protected]> Co-authored-by: Ana Prieto Nemesio <[email protected]>
1 parent cd777fb commit 24d162c

21 files changed

+572
-1
lines changed

β€Žmodels/src/anemoi/models/schemas/models.pyβ€Ž

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,9 @@
1212
import logging
1313
from enum import Enum
1414
from typing import Annotated
15+
from typing import Any
1516
from typing import Literal
17+
from typing import Optional
1618
from typing import Union
1719

1820
from pydantic import BaseModel as PydanticBaseModel
@@ -226,6 +228,8 @@ class BaseModelSchema(PydanticBaseModel):
226228
discriminator="target_",
227229
)
228230
"GNN decoder schema."
231+
compile: Optional[list[dict[str, Any]]] = Field(None)
232+
"Modules to be compiled"
229233

230234

231235
class NoiseInjectorSchema(BaseModel):
Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,8 @@
1+
# (C) Copyright 2025 Anemoi contributors.
2+
#
3+
# This software is licensed under the terms of the Apache Licence Version 2.0
4+
# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0.
5+
#
6+
# In applying this licence, ECMWF does not waive the privileges and immunities
7+
# granted to it by virtue of its status as an intergovernmental organisation
8+
# nor does it submit to any jurisdiction.
Lines changed: 102 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,102 @@
1+
# (C) Copyright 2025 Anemoi contributors.
2+
#
3+
# This software is licensed under the terms of the Apache Licence Version 2.0
4+
# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0.
5+
#
6+
# In applying this licence, ECMWF does not waive the privileges and immunities
7+
# granted to it by virtue of its status as an intergovernmental organisation
8+
# nor does it submit to any jurisdiction.
9+
10+
import logging
11+
from functools import reduce
12+
from importlib.util import find_spec
13+
14+
import torch
15+
import torch_geometric
16+
from hydra.utils import get_class
17+
from numpy import unique
18+
from omegaconf import DictConfig
19+
20+
from anemoi.training.train.tasks.base import BaseGraphModule
21+
22+
LOGGER = logging.getLogger(__name__)
23+
24+
25+
def _get_compile_entry(module: str, compile_config: DictConfig) -> DictConfig | None:
26+
"""Search the compile config for an entry c module name.
27+
28+
module: str -> full module name e.g. 'anemoi.models.layers.conv.GraphTransformerConv'
29+
compile_config : DictConfig -> The 'compile' entry within the models config
30+
31+
returns: None, if 'module' is not listed within 'compile_config'. Otherwise returns the modules entry.
32+
33+
"""
34+
for entry in compile_config:
35+
if get_class(entry["module"]) is type(module):
36+
return entry
37+
38+
return None
39+
40+
41+
def _meets_library_versions_for_compile() -> bool:
42+
"""Returns True if minimum library versions for compilation in Anemoi is met."""
43+
has_triton = True
44+
if find_spec("triton") is None:
45+
msg = "Triton not installed! Consider installing Triton to "
46+
msg += "enable compilation and improve speed and memory usage."
47+
LOGGER.warning(msg)
48+
has_triton = False
49+
50+
version_req = torch.__version__ >= "2.6" and torch_geometric.__version__ >= "2.6"
51+
52+
if not version_req:
53+
msg = "Minimum library versions for compilation not met. "
54+
msg += f"torch: v{torch.__version__}<2.6 or torch_geometric: v{torch_geometric.__version__}<2.6. "
55+
msg += "Please upgrade these libraries to enable compilation."
56+
LOGGER.warning(msg)
57+
58+
return version_req and has_triton
59+
60+
61+
def mark_for_compilation(model: BaseGraphModule, compile_config: DictConfig | None) -> BaseGraphModule:
62+
"""Marks modules within 'model' for compilation, according to 'compile_config'.
63+
64+
Modules are not compiled here. The compilation will occur
65+
automatically before the first forward iteration.
66+
67+
returns an updated model, with modules marked for compilation
68+
"""
69+
if compile_config is None:
70+
return model
71+
72+
if not _meets_library_versions_for_compile():
73+
return model
74+
75+
default_compile_options = {}
76+
compiled_modules = []
77+
78+
# Loop through all modules
79+
for name, module in model.named_modules():
80+
entry = _get_compile_entry(module, compile_config)
81+
# entry is 'None' if compilation was not requested for this module
82+
if entry is not None:
83+
options = entry.get("options", default_compile_options)
84+
85+
LOGGER.debug("%s will be compiled with the following options: %s", str(module), str(options))
86+
compiled_module = torch.compile(module, **options) # Note: the module is not compiled yet
87+
# It is just marked for JIT-compilation later
88+
# It will be compiled before its first forward pass
89+
compiled_modules.append(entry.module)
90+
91+
# Update the model with the new 'compiled' module
92+
# go from "anemoi.models.layers.conv.GraphTransformerConv"
93+
# to obj(anemoi.models.layers.conv)
94+
parts = name.split(".")
95+
parent = reduce(getattr, parts[:-1], model)
96+
# then set obj(anemoi.models.layers.conv).GrapTransformerConv = compiled_module
97+
LOGGER.debug("Replacing %s with a compiled version", str(parts[-1]))
98+
setattr(parent, parts[-1], compiled_module)
99+
100+
LOGGER.info("The following modules will be compiled: %s", str(unique(compiled_modules)))
101+
102+
return model
Lines changed: 151 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,151 @@
1+
# (C) Copyright 2024 Anemoi contributors.
2+
#
3+
# This software is licensed under the terms of the Apache Licence Version 2.0
4+
# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0.
5+
#
6+
# In applying this licence, ECMWF does not waive the privileges and immunities
7+
# granted to it by virtue of its status as an intergovernmental organisation
8+
# nor does it submit to any jurisdiction.
9+
10+
import logging
11+
12+
import torch
13+
from omegaconf import DictConfig
14+
from omegaconf import OmegaConf
15+
16+
from anemoi.models.layers.attention import MultiHeadSelfAttention
17+
from anemoi.models.layers.normalization import ConditionalLayerNorm
18+
from anemoi.models.layers.utils import load_layer_kernels
19+
from anemoi.models.utils.compile import _get_compile_entry
20+
from anemoi.models.utils.compile import _meets_library_versions_for_compile
21+
from anemoi.models.utils.compile import mark_for_compilation
22+
23+
LOGGER = logging.getLogger(__name__)
24+
25+
26+
def graphtransformer_compile_config() -> None:
27+
return OmegaConf.create(
28+
{
29+
"compile": [
30+
{
31+
"module": "anemoi.models.layers.conv.GraphTransformerConv",
32+
},
33+
],
34+
}
35+
)
36+
37+
38+
def layer_kernel_compile_config() -> None:
39+
return OmegaConf.create(
40+
{
41+
"compile": [
42+
{
43+
"module": "torch.nn.Linear",
44+
},
45+
],
46+
}
47+
)
48+
49+
50+
def graphtransformer_ens_compile_config() -> None:
51+
return OmegaConf.create(
52+
{
53+
"compile": [
54+
{
55+
"module": "anemoi.models.layers.conv.GraphTransformerConv",
56+
},
57+
{
58+
"module": "anemoi.models.layers.normalization.ConditionalLayerNorm",
59+
"options": {
60+
"dynamic": False,
61+
},
62+
},
63+
],
64+
}
65+
)
66+
67+
68+
def test_compile_config_no_match() -> None:
69+
"""Tests that _get_compile_entry() returns None when no match is found."""
70+
cfg = graphtransformer_compile_config()
71+
72+
num_channels = 64
73+
cond_shape = 16
74+
model = ConditionalLayerNorm(num_channels, condition_shape=cond_shape)
75+
result = _get_compile_entry(model, cfg.compile)
76+
77+
assert result is None
78+
79+
80+
def test_compile_config_match() -> None:
81+
"""Tests that _get_compile_entry() returns a dict when a match is found."""
82+
cfg = graphtransformer_ens_compile_config()
83+
84+
num_channels = 64
85+
cond_shape = 16
86+
model = ConditionalLayerNorm(num_channels, condition_shape=cond_shape)
87+
result = _get_compile_entry(model, cfg.compile)
88+
89+
assert type(result) is DictConfig
90+
91+
92+
def test_compile() -> None:
93+
94+
# Skip this test if library versions aren't met
95+
if not _meets_library_versions_for_compile():
96+
LOGGER.warning("triton not installed. skipping 'test_compile.py::test_compile'")
97+
return
98+
99+
num_channels = 64
100+
cond_shape = 16
101+
ln = ConditionalLayerNorm(num_channels, condition_shape=cond_shape)
102+
x_in = torch.randn(num_channels)
103+
cond = torch.randn(cond_shape)
104+
result = ln.forward(x_in, cond)
105+
106+
cfg = graphtransformer_ens_compile_config()
107+
ln_compiled = mark_for_compilation(ln, cfg.compile)
108+
109+
result_compiled = ln_compiled.forward(x_in, cond)
110+
111+
# check the function was compiled
112+
assert hasattr(ln_compiled, "_compile_kwargs")
113+
114+
# check the result of the compiled function matches the uncompiled result
115+
assert torch.allclose(result, result_compiled)
116+
117+
118+
def test_compile_layer_kernel() -> None:
119+
120+
# Skip this test if library versions aren't met
121+
if not _meets_library_versions_for_compile():
122+
LOGGER.warning("triton not installed. skipping 'test_compile.py::test_compile'")
123+
return
124+
125+
cfg = layer_kernel_compile_config()
126+
layer_kernels = load_layer_kernels(kernel_config={})
127+
128+
num_heads = 1
129+
embed_dim = 64
130+
dropout_p = 0.0
131+
batch_size = 1
132+
mhsa = MultiHeadSelfAttention(
133+
num_heads,
134+
embed_dim,
135+
layer_kernels,
136+
dropout_p=dropout_p,
137+
attention_implementation="scaled_dot_product_attention",
138+
)
139+
mhsa_compiled = mark_for_compilation(mhsa, cfg.compile)
140+
141+
x = torch.randn(batch_size * 2, embed_dim, requires_grad=True)
142+
shapes = [list(x.shape)]
143+
144+
result = mhsa.forward(x, shapes, batch_size)
145+
result_compiled = mhsa_compiled.forward(x, shapes, batch_size)
146+
147+
# check the function was compiled
148+
assert hasattr(mhsa_compiled.projection, "_compile_kwargs")
149+
150+
# check the result of the compiled function matches the uncompiled result
151+
assert torch.allclose(result, result_compiled)

β€Žtraining/docs/user-guide/models.rstβ€Ž

Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -191,3 +191,47 @@ configuration:
191191
This determines how many ensemble members are generated per device
192192
during training. Effective ensemble size is then the number of ensemble
193193
members per device times the number of GPUs per ensemble.
194+
195+
*************
196+
Compilation
197+
*************
198+
199+
PyTorch supports JIT-compiliation of code. This can speed up execution
200+
and reduce peak memory usage. For more information, consult `the
201+
introduction to torch.compile
202+
<https://docs.pytorch.org/tutorials/intermediate/torch_compile_tutorial.html>`__
203+
and `the official documentation
204+
<https://docs.pytorch.org/docs/stable/generated/torch.compile.html>`__.
205+
206+
Compilation requires Triton. Normally Triton is pulled in as a
207+
dependancy when PyTorch is installed. Otherwise, Triton can be `built
208+
from source
209+
<https://github.com/triton-lang/triton?tab=readme-ov-file#install-from-source>`__
210+
. Compilation requires torch >= 2.6 and torch_geometric >= 2.6. If these
211+
versions are not met, or if Triton is not installed, then anemoi will
212+
run without compilation.
213+
214+
Anemoi exposes 'torch.compile' at the module level through the model
215+
config. Below is an example:
216+
217+
.. code:: yaml
218+
219+
#training/config/models/transformer_ens.yaml
220+
compile:
221+
- module: anemoi.models.layers.conv.GraphTransformerConv
222+
options:
223+
dynamic: false
224+
mode: max-autotune
225+
- module: anemoi.models.layers.normalization.ConditionalLayerNorm
226+
options:
227+
dynamic: false
228+
229+
Under the 'compile' keyword, you provide a list of modules. These
230+
modules will be marked for compilation when the model is built. During
231+
their first forward pass, these modules will be compiled. No code
232+
modifications are required.
233+
234+
You can optionally pass options to torch compile via the 'options'
235+
keyword. A full list of the possible options and their meanings can be
236+
found in the `torch.compile documentation
237+
<https://docs.pytorch.org/docs/stable/generated/torch.compile.html>`__.

β€Žtraining/src/anemoi/training/config/model/graphtransformer.yamlβ€Ž

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -68,6 +68,29 @@ trainable_parameters:
6868
hidden2data: 8
6969
hidden2hidden: 8 # GNN and GraphTransformer Processor only
7070

71+
# Torch compile configuration
72+
# A list of modules present in the model, which will be compiled
73+
# You can optionally pass options to torch.compile with the 'options' key
74+
#
75+
# Below is an explanation of some common parameters to torch.compile
76+
# For a full list of possible parameters, consult the documenation for torch compile
77+
# https://docs.pytorch.org/docs/stable/generated/torch.compile.html
78+
#
79+
# dynamic (bool): When True, it will try to avoid recompilation by generating
80+
# as general a kernel as possible. But the performance of the general
81+
# kernel might be worse. When False, it will generate a specific
82+
# kernel for each input shape (until the configurable recompile
83+
# limit has been hit), leading to possibly better performance but
84+
# more recompilations
85+
# mode (str): Different compilation modes, allowing you to trade off
86+
# compilation time versus potential performance. See the
87+
# torch.compile documentation for list of possible modes
88+
# fullgraph (bool): When True, torch.compile will error when it hits a
89+
# section of code it can't compile. When False, it will fallback to
90+
# non-compiled ("eager") execution for those lines.
91+
# options (dict): a dict of further options which can be passed to torch.compile
92+
compile:
93+
- module: anemoi.models.layers.conv.GraphTransformerConv
7194

7295
attributes:
7396
edges:

β€Žtraining/src/anemoi/training/config/model/graphtransformer_diffusion.yamlβ€Ž

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -105,5 +105,35 @@ attributes:
105105
- edge_dirs
106106
nodes: []
107107

108+
# Torch compile configuration
109+
# A list of modules present in the model, which will be compiled
110+
# You can optionally pass options to torch.compile with the 'options' key
111+
#
112+
# Below is an explanation of some common parameters to torch.compile
113+
# For a full list of possible parameters, consult the documenation for torch compile
114+
# https://docs.pytorch.org/docs/stable/generated/torch.compile.html
115+
#
116+
# dynamic (bool): When True, it will try to avoid recompilation by generating
117+
# as general a kernel as possible. But the performance of the general
118+
# kernel might be worse. When False, it will generate a specific
119+
# kernel for each input shape (until the configurable recompile
120+
# limit has been hit), leading to possibly better performance but
121+
# more recompilations
122+
# mode (str): Different compilation modes, allowing you to trade off
123+
# compilation time versus potential performance. See the
124+
# torch.compile documentation for list of possible modes
125+
# fullgraph (bool): When True, torch.compile will error when it hits a
126+
# section of code it can't compile. When False, it will fallback to
127+
# non-compiled ("eager") execution for those lines.
128+
# options (dict): a dict of further options which can be passed to torch.compile
129+
compile:
130+
- module: anemoi.models.layers.conv.GraphTransformerConv
131+
#options: # An example of setting torch.compile options
132+
#dynamic: false
133+
#mode: max-autotune
134+
- module: anemoi.models.layers.normalization.ConditionalLayerNorm
135+
options:
136+
dynamic: false
137+
108138
# Bounding configuration
109139
bounding: []

0 commit comments

Comments
Β (0)