
Commit 06e5533

japols, ssmmnn11, gabrieloks, pre-commit-ci[bot], and anaprietonem authored
fix(models): processor chunking (#629)
## Description

Remove the separate ProcessorChunk class and flatten all layers directly into the BaseProcessor. Chunking is now handled dynamically at runtime by grouping layers into checkpointed segments.

## What problem does this change solve?

Previously, the Processor class held a list of ProcessorChunks, each holding its own ModuleList of layers, so the checkpointed layer groupings were tied to the chunking configuration saved in the model checkpoint. When resuming training with a different num_chunks, the restored module structure no longer matched the saved one, causing checkpoint mismatches. Now the Processor class holds a single flat list of all layers (Blocks), and chunking is handled dynamically.

## What issue or task does this change relate to?

## Additional notes

Tested with all models, i.e. GT, Transformer, GNN, PointWiseMLP.

***As a contributor to the Anemoi framework, please ensure that your changes include unit tests, updates to any affected dependencies and documentation, and have been tested in a parallel setting (i.e., with multiple GPUs). As a reviewer, you are also responsible for verifying these aspects and requesting changes if they are not adequately addressed. For guidelines about those please refer to https://anemoi.readthedocs.io/en/latest/***

By opening this pull request, I affirm that all authors agree to the [Contributor License Agreement.](https://github.com/ecmwf/codex/blob/main/Legal/contributor_license_agreement.md)

Co-authored-by: Simon Lang <[email protected]>
Co-authored-by: gabrieloks <[email protected]>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Ana Prieto Nemesio <[email protected]>
Co-authored-by: Jakob Schloer <[email protected]>
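The runtime-chunking idea described above can be sketched roughly as follows. This is an illustrative toy, not the anemoi-models implementation: the class name `FlatChunkedProcessor`, the plain `nn.Linear` blocks, and the `num_chunks` constructor argument are stand-ins chosen for the example.

```python
# Illustrative sketch only: one flat ModuleList of blocks, grouped into
# checkpointed segments at runtime instead of being nested in chunk modules.
import torch
from torch import nn
from torch.utils.checkpoint import checkpoint


class FlatChunkedProcessor(nn.Module):
    def __init__(self, num_layers: int, hidden_dim: int, num_chunks: int) -> None:
        super().__init__()
        # One flat list of layers: the state_dict layout does not depend on num_chunks.
        self.blocks = nn.ModuleList(nn.Linear(hidden_dim, hidden_dim) for _ in range(num_layers))
        self.num_chunks = num_chunks

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Group layer indices into num_chunks segments at runtime and run each
        # segment under activation checkpointing to reduce training memory.
        for idx in torch.tensor_split(torch.arange(len(self.blocks)), self.num_chunks):
            segment = [self.blocks[i] for i in idx.tolist()]

            def run_segment(inp: torch.Tensor, segment=segment) -> torch.Tensor:
                for block in segment:
                    inp = block(inp)
                return inp

            x = checkpoint(run_segment, x, use_reentrant=False)
        return x
```

Because the grouping is recomputed on every forward pass, changing `num_chunks` between runs changes neither the module tree nor the names of the saved weights.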
1 parent 6819be1 commit 06e5533

File tree

18 files changed: +209 / -608 lines changed


models/docs/introduction/overview.rst

Lines changed: 4 additions & 3 deletions
@@ -99,9 +99,10 @@ Processors
 ==========
 
 Additionally, the layers implement `Processors` which are used to
-process the data on the hidden grid. The `Processors` use a chunking
-strategy with `Chunks` that pass a subset of layers to `Blocks` to allow
-for more efficient processing of the data.
+process the data on the hidden grid. The `Processors` use a series of
+`Blocks` to process the data. These `Blocks` can be partitioned into
+checkpointed chunks via `num_chunks` to reduce memory usage during
+training.
 
 **************
 Data Indices
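To illustrate the checkpoint-compatibility point made in the PR description (again using the toy `FlatChunkedProcessor` sketched earlier, not the real `Processor` class): the parameter layout is identical for any `num_chunks`, so a state dict saved with one setting loads under another.

```python
# Toy demonstration with the illustrative FlatChunkedProcessor from the sketch above:
# parameter names do not encode the chunking, so resuming with a different
# num_chunks no longer causes a module-structure mismatch.
proc_a = FlatChunkedProcessor(num_layers=16, hidden_dim=64, num_chunks=1)
proc_b = FlatChunkedProcessor(num_layers=16, hidden_dim=64, num_chunks=4)

assert proc_a.state_dict().keys() == proc_b.state_dict().keys()
proc_b.load_state_dict(proc_a.state_dict())  # works despite the different num_chunks
```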

models/src/anemoi/models/layers/block.py

Lines changed: 21 additions & 15 deletions
@@ -90,8 +90,8 @@ def forward(
         batch_size: int,
         model_comm_group: Optional[ProcessGroup] = None,
         **layer_kwargs,
-    ) -> Tensor:
-        return self.mlp(x)
+    ) -> tuple[Tensor]:
+        return (self.mlp(x),)
 
 
 class TransformerProcessorBlock(BaseBlock):
@@ -146,7 +146,7 @@ def forward(
         model_comm_group: Optional[ProcessGroup] = None,
         cond: Optional[Tensor] = None,
         **layer_kwargs,
-    ) -> Tensor:
+    ) -> tuple[Tensor]:
 
         # In case we have conditionings we pass these to the layer norm
         cond_kwargs = {"cond": cond} if cond is not None else {}
@@ -160,7 +160,7 @@ def forward(
                 **cond_kwargs,
             )
         )
-        return x
+        return (x,)
 
 
 class TransformerMapperBlock(TransformerProcessorBlock):
@@ -222,7 +222,7 @@ def forward(
         shapes: list,
         batch_size: int,
         model_comm_group: Optional[ProcessGroup] = None,
-    ) -> Tensor:
+    ) -> tuple[Tensor, Tensor]:
         x_src = self.layer_norm_attention_src(x[0])
         x_dst = self.layer_norm_attention_dst(x[1])
         x_dst = x_dst + self.attention((x_src, x_dst), shapes, batch_size, model_comm_group=model_comm_group)
@@ -242,6 +242,7 @@ def __init__(
         mlp_extra_layers: int = 0,
         update_src_nodes: bool = True,
         layer_kernels: DotDict,
+        edge_dim: Optional[int] = None,
         **kwargs,
     ) -> None:
         """Initialize GNNBlock.
@@ -264,6 +265,17 @@ def __init__(
         """
         super().__init__(**kwargs)
 
+        if edge_dim:
+            self.emb_edges = MLP(
+                in_features=edge_dim,
+                hidden_dim=out_channels,
+                out_features=out_channels,
+                layer_kernels=layer_kernels,
+                n_extra_layers=mlp_extra_layers,
+            )
+        else:
+            self.emb_edges = None
+
         self.update_src_nodes = update_src_nodes
         self.num_chunks = num_chunks
 
@@ -306,6 +318,8 @@ def forward(
         size: Optional[Size] = None,
         **layer_kwargs,
     ) -> tuple[Tensor, Tensor]:
+        if self.emb_edges is not None:
+            edge_attr = self.emb_edges(edge_attr)
 
         x_in = sync_tensor(x, 0, shapes[1], model_comm_group)
 
@@ -424,7 +438,6 @@ def __init__(
         hidden_dim: int,
         out_channels: int,
         num_heads: int,
-        num_chunks: int,
         edge_dim: int,
         bias: bool = True,
         qk_norm: bool = False,
@@ -442,8 +455,6 @@ def __init__(
             Number of output channels.
         num_heads : int,
             Number of heads
-        num_chunks : int,
-            Number of chunks
         edge_dim : int,
             Edge dimension
         bias : bool, by default True,
@@ -463,7 +474,6 @@ def __init__(
         self.out_channels_conv = out_channels // num_heads
         self.num_heads = num_heads
         self.qk_norm = qk_norm
-        self.num_chunks = num_chunks
 
         Linear = layer_kernels.Linear
         LayerNorm = layer_kernels.LayerNorm
@@ -662,7 +672,6 @@ def __init__(
             layer_kernels=layer_kernels,
             num_heads=num_heads,
             bias=bias,
-            num_chunks=1,
             qk_norm=qk_norm,
             update_src_nodes=update_src_nodes,
             **kwargs,
@@ -777,7 +786,6 @@ def __init__(
         hidden_dim: int,
         out_channels: int,
         num_heads: int,
-        num_chunks: int,
         edge_dim: int,
         bias: bool = True,
         qk_norm: bool = False,
@@ -795,8 +803,6 @@ def __init__(
             Number of output channels.
         num_heads : int,
             Number of heads
-        num_chunks : int,
-            Number of chunks
         edge_dim : int,
             Edge dimension
         bias : bool
@@ -819,7 +825,6 @@ def __init__(
             num_heads=num_heads,
             bias=bias,
             qk_norm=qk_norm,
-            num_chunks=num_chunks,
             update_src_nodes=update_src_nodes,
             **kwargs,
         )
@@ -851,7 +856,8 @@ def forward(
         query = self.q_norm(query)
         key = self.k_norm(key)
 
-        num_chunks = self.num_chunks if self.training else NUM_CHUNKS_INFERENCE_PROCESSOR
+        # "inner" chunking for memory reductions in inference, controlled via env variable:
+        num_chunks = 1 if self.training else NUM_CHUNKS_INFERENCE_PROCESSOR
 
         out = self.attention_block(query, key, value, edges, edge_index, size, num_chunks)
 
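The last hunk above shows that per-block "inner" chunking now applies only at inference time and is driven by an environment variable. A minimal sketch of that pattern follows; the environment-variable name is a placeholder, since the diff only shows the `NUM_CHUNKS_INFERENCE_PROCESSOR` constant, not how it is populated.

```python
# Sketch of env-variable-controlled inference chunking. The variable name is a
# placeholder; only the NUM_CHUNKS_INFERENCE_PROCESSOR constant appears in the diff.
import os

NUM_CHUNKS_INFERENCE_PROCESSOR = int(os.environ.get("EXAMPLE_NUM_CHUNKS_INFERENCE_PROCESSOR", "1"))


def resolve_num_chunks(training: bool) -> int:
    # Training always uses a single chunk; inference may split the edge-attention
    # computation into several chunks to lower peak memory.
    return 1 if training else NUM_CHUNKS_INFERENCE_PROCESSOR
```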