diff --git a/models/docs/introduction/overview.rst b/models/docs/introduction/overview.rst index 63cf1251c..d0c98cc0f 100644 --- a/models/docs/introduction/overview.rst +++ b/models/docs/introduction/overview.rst @@ -99,9 +99,10 @@ Processors ========== Additionally, the layers implement `Processors` which are used to -process the data on the hidden grid. The `Processors` use a chunking -strategy with `Chunks` that pass a subset of layers to `Blocks` to allow -for more efficient processing of the data. +process the data on the hidden grid. The `Processors` use a series of +`Blocks` to process the data. These `Blocks` can be partitioned into +checkpointed chunks via `num_chunks` to reduce memory usage during +training. ************** Data Indices diff --git a/models/src/anemoi/models/layers/block.py b/models/src/anemoi/models/layers/block.py index c2a9ad15f..505021303 100644 --- a/models/src/anemoi/models/layers/block.py +++ b/models/src/anemoi/models/layers/block.py @@ -90,8 +90,8 @@ def forward( batch_size: int, model_comm_group: Optional[ProcessGroup] = None, **layer_kwargs, - ) -> Tensor: - return self.mlp(x) + ) -> tuple[Tensor]: + return (self.mlp(x),) class TransformerProcessorBlock(BaseBlock): @@ -146,7 +146,7 @@ def forward( model_comm_group: Optional[ProcessGroup] = None, cond: Optional[Tensor] = None, **layer_kwargs, - ) -> Tensor: + ) -> tuple[Tensor]: # In case we have conditionings we pass these to the layer norm cond_kwargs = {"cond": cond} if cond is not None else {} @@ -160,7 +160,7 @@ def forward( **cond_kwargs, ) ) - return x + return (x,) class TransformerMapperBlock(TransformerProcessorBlock): @@ -222,7 +222,7 @@ def forward( shapes: list, batch_size: int, model_comm_group: Optional[ProcessGroup] = None, - ) -> Tensor: + ) -> tuple[Tensor, Tensor]: x_src = self.layer_norm_attention_src(x[0]) x_dst = self.layer_norm_attention_dst(x[1]) x_dst = x_dst + self.attention((x_src, x_dst), shapes, batch_size, model_comm_group=model_comm_group) @@ -242,6 +242,7 @@ def __init__( mlp_extra_layers: int = 0, update_src_nodes: bool = True, layer_kernels: DotDict, + edge_dim: Optional[int] = None, **kwargs, ) -> None: """Initialize GNNBlock. @@ -264,6 +265,17 @@ def __init__( """ super().__init__(**kwargs) + if edge_dim: + self.emb_edges = MLP( + in_features=edge_dim, + hidden_dim=out_channels, + out_features=out_channels, + layer_kernels=layer_kernels, + n_extra_layers=mlp_extra_layers, + ) + else: + self.emb_edges = None + self.update_src_nodes = update_src_nodes self.num_chunks = num_chunks @@ -306,6 +318,8 @@ def forward( size: Optional[Size] = None, **layer_kwargs, ) -> tuple[Tensor, Tensor]: + if self.emb_edges is not None: + edge_attr = self.emb_edges(edge_attr) x_in = sync_tensor(x, 0, shapes[1], model_comm_group) @@ -424,7 +438,6 @@ def __init__( hidden_dim: int, out_channels: int, num_heads: int, - num_chunks: int, edge_dim: int, bias: bool = True, qk_norm: bool = False, @@ -442,8 +455,6 @@ def __init__( Number of output channels. 
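# Illustrative sketch (not part of this patch): the optional edge embedding that moves
# from the deleted GNNProcessorChunk into GraphConvProcessorBlock above. Only the first
# processor block is built with edge_dim (see the GNNProcessor change further down in
# this diff), so raw edge attributes are embedded exactly once and later blocks receive
# the already-embedded tensor. Anemoi's MLP and layer_kernels are replaced here with a
# plain torch MLP, and the message passing is a toy stand-in.
from typing import Optional

import torch
from torch import nn


class ToyGraphConvBlock(nn.Module):
    def __init__(self, channels: int, edge_dim: Optional[int] = None) -> None:
        super().__init__()
        self.emb_edges = (
            nn.Sequential(nn.Linear(edge_dim, channels), nn.GELU(), nn.Linear(channels, channels))
            if edge_dim
            else None
        )
        self.node_update = nn.Linear(2 * channels, channels)

    def forward(self, x: torch.Tensor, edge_attr: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
        if self.emb_edges is not None:
            edge_attr = self.emb_edges(edge_attr)  # embed raw edge features once, in the first block
        # toy "aggregation": mix each node state with the mean edge state
        x = self.node_update(torch.cat([x, edge_attr.mean(0, keepdim=True).expand_as(x)], dim=-1))
        return x, edge_attr


blocks = [ToyGraphConvBlock(8, edge_dim=3)] + [ToyGraphConvBlock(8) for _ in range(3)]
x, edge_attr = torch.randn(5, 8), torch.randn(12, 3)
for block in blocks:
    x, edge_attr = block(x, edge_attr)  # edge_attr has 8 channels after the first block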
num_heads : int, Number of heads - num_chunks : int, - Number of chunks edge_dim : int, Edge dimension bias : bool, by default True, @@ -463,7 +474,6 @@ def __init__( self.out_channels_conv = out_channels // num_heads self.num_heads = num_heads self.qk_norm = qk_norm - self.num_chunks = num_chunks Linear = layer_kernels.Linear LayerNorm = layer_kernels.LayerNorm @@ -662,7 +672,6 @@ def __init__( layer_kernels=layer_kernels, num_heads=num_heads, bias=bias, - num_chunks=1, qk_norm=qk_norm, update_src_nodes=update_src_nodes, **kwargs, @@ -777,7 +786,6 @@ def __init__( hidden_dim: int, out_channels: int, num_heads: int, - num_chunks: int, edge_dim: int, bias: bool = True, qk_norm: bool = False, @@ -795,8 +803,6 @@ def __init__( Number of output channels. num_heads : int, Number of heads - num_chunks : int, - Number of chunks edge_dim : int, Edge dimension bias : bool @@ -819,7 +825,6 @@ def __init__( num_heads=num_heads, bias=bias, qk_norm=qk_norm, - num_chunks=num_chunks, update_src_nodes=update_src_nodes, **kwargs, ) @@ -851,7 +856,8 @@ def forward( query = self.q_norm(query) key = self.k_norm(key) - num_chunks = self.num_chunks if self.training else NUM_CHUNKS_INFERENCE_PROCESSOR + # "inner" chunking for memory reductions in inference, controlled via env variable: + num_chunks = 1 if self.training else NUM_CHUNKS_INFERENCE_PROCESSOR out = self.attention_block(query, key, value, edges, edge_index, size, num_chunks) diff --git a/models/src/anemoi/models/layers/chunk.py b/models/src/anemoi/models/layers/chunk.py deleted file mode 100644 index 0f5a59ff7..000000000 --- a/models/src/anemoi/models/layers/chunk.py +++ /dev/null @@ -1,318 +0,0 @@ -# (C) Copyright 2024 Anemoi contributors. -# -# This software is licensed under the terms of the Apache Licence Version 2.0 -# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. -# -# In applying this licence, ECMWF does not waive the privileges and immunities -# granted to it by virtue of its status as an intergovernmental organisation -# nor does it submit to any jurisdiction. - - -import logging -from abc import ABC -from abc import abstractmethod -from typing import Optional - -from torch import Tensor -from torch import nn -from torch.distributed.distributed_c10d import ProcessGroup -from torch_geometric.typing import Adj -from torch_geometric.typing import OptPairTensor -from torch_geometric.typing import Size - -from anemoi.models.layers.block import GraphConvProcessorBlock -from anemoi.models.layers.block import GraphTransformerProcessorBlock -from anemoi.models.layers.block import PointWiseMLPProcessorBlock -from anemoi.models.layers.block import TransformerProcessorBlock -from anemoi.models.layers.mlp import MLP -from anemoi.utils.config import DotDict - -LOGGER = logging.getLogger(__name__) - - -class BaseProcessorChunk(nn.Module, ABC): - """Base Processor Chunk.""" - - def __init__( - self, - num_channels: int, - num_layers: int, - *args, - **kwargs, - ) -> None: - """Initialize BaseProcessorChunk.""" - super().__init__() - - self.num_channels = num_channels - self.num_layers = num_layers - - def build_blocks(self, block: nn.Module, *args, **kwargs) -> None: - """Build Layers.""" - self.blocks = nn.ModuleList( - [ - block( - *args, - **kwargs, - ) - for _ in range(self.num_layers) - ], - ) - - @abstractmethod - def forward( - self, - x: Tensor, - shapes: list, - batch_size: int, - model_comm_group: Optional[ProcessGroup] = None, - **kwargs, - ) -> Tensor: ... 
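# Illustrative sketch (not part of this patch): the chunked gradient checkpointing that
# replaces the *ProcessorChunk wrappers being deleted below. BaseProcessor (further down
# in this diff) now holds a flat ModuleList of num_layers blocks and wraps groups of
# chunk_size = num_layers // num_chunks consecutive blocks in a single activation
# checkpoint. Blocks return tuples so the generic loop can thread (x,) and (x, edge_attr)
# uniformly. This is separate from the block-internal inference chunking controlled by
# NUM_CHUNKS_INFERENCE_PROCESSOR mentioned above. Simplified: no model_comm_group,
# sharding or layer_kernels; toy block and names chosen for illustration only.
import torch
from torch import nn
from torch.utils.checkpoint import checkpoint


class ToyBlock(nn.Module):
    def __init__(self, channels: int) -> None:
        super().__init__()
        self.mlp = nn.Sequential(nn.Linear(channels, channels), nn.GELU())

    def forward(self, x: torch.Tensor) -> tuple[torch.Tensor]:
        return (x + self.mlp(x),)  # tuple return, as in the refactored blocks


class ToyChunkedProcessor(nn.Module):
    def __init__(self, num_layers: int, num_channels: int, num_chunks: int) -> None:
        super().__init__()
        self.num_layers = num_layers
        self.chunk_size = num_layers // num_chunks
        self.proc = nn.ModuleList([ToyBlock(num_channels) for _ in range(num_layers)])

    def _run_chunk(self, chunk_start: int, data: tuple) -> tuple:
        for layer_id in range(chunk_start, chunk_start + self.chunk_size):
            data = self.proc[layer_id](*data)
        return data

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        data = (x,)
        # one checkpoint per chunk: activations inside a chunk are recomputed during
        # backward instead of being stored, which reduces peak memory during training
        for chunk_start in range(0, self.num_layers, self.chunk_size):
            data = checkpoint(self._run_chunk, chunk_start, data, use_reentrant=False)
        return data[0]


if __name__ == "__main__":
    proc = ToyChunkedProcessor(num_layers=8, num_channels=16, num_chunks=2)
    out = proc(torch.randn(4, 16, requires_grad=True))
    out.sum().backward()  # gradients flow through the checkpointed chunks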
- - -class PointWiseMLPProcessorChunk(BaseProcessorChunk): - """Wraps point-wise MLP blocks for checkpointing in Processor.""" - - def __init__( - self, - num_channels: int, - num_layers: int, - layer_kernels: DotDict, - mlp_hidden_ratio: int = 4, - dropout_p: float = 0.0, - ) -> None: - super().__init__(num_channels=num_channels, num_layers=num_layers) - - self.build_blocks( - PointWiseMLPProcessorBlock, - num_channels=num_channels, - hidden_dim=(mlp_hidden_ratio * num_channels), - layer_kernels=layer_kernels, - dropout_p=dropout_p, - ) - - def forward( - self, - x: Tensor, - shapes: list, - batch_size: int, - model_comm_group: Optional[ProcessGroup] = None, - **kwargs, - ) -> Tensor: - for i in range(self.num_layers): - x = self.blocks[i](x, shapes, batch_size, model_comm_group=model_comm_group, **kwargs) - - return (x,) - - -class TransformerProcessorChunk(BaseProcessorChunk): - """Wraps transformer blocks for checkpointing in Processor.""" - - def __init__( - self, - num_channels: int, - num_layers: int, - layer_kernels: DotDict, - window_size: int, - num_heads: int = 16, - mlp_hidden_ratio: int = 4, - qk_norm: bool = False, - dropout_p: float = 0.0, - attention_implementation: str = "flash_attention", - softcap: float = None, - use_alibi_slopes: bool = None, - ) -> None: - """Initialize TransformerProcessor. - - Parameters - ---------- - num_channels : int - Number of channels - num_layers : int - Number of layers - layer_kernels : DotDict - A dict of layer implementations e.g. layer_kernels.Linear = "torch.nn.Linear" - Defined in config/models/.yaml - window_size: int, - 1/2 size of shifted window for attention computation - num_heads: int - Number of heads to use, default 16 - mlp_hidden_ratio: int - ratio of mlp hidden dimension to embedding dimension, default 4 - qk_norm: bool, optional - Normalize query and key, by default False - dropout_p: float - Dropout probability used for multi-head self attention, default 0.0 - attention_implementation: str - A predefined string which selects which underlying attention - implementation, by default "flash_attention" - softcap : float, optional - Anything > 0 activates softcapping flash attention, by default None - use_alibi_slopes : bool, optional - Use aLiBI option, only used for flash attention, by default None - """ - super().__init__(num_channels=num_channels, num_layers=num_layers) - - self.build_blocks( - TransformerProcessorBlock, - num_channels=num_channels, - hidden_dim=(mlp_hidden_ratio * num_channels), - num_heads=num_heads, - qk_norm=qk_norm, - window_size=window_size, - layer_kernels=layer_kernels, - dropout_p=dropout_p, - attention_implementation=attention_implementation, - softcap=softcap, - use_alibi_slopes=use_alibi_slopes, - ) - - def forward( - self, - x: Tensor, - shapes: list, - batch_size: int, - model_comm_group: Optional[ProcessGroup] = None, - **kwargs, - ) -> Tensor: - for i in range(self.num_layers): - x = self.blocks[i](x, shapes, batch_size, model_comm_group=model_comm_group, **kwargs) - - return (x,) # return tuple for consistency with other processors - - -class GNNProcessorChunk(BaseProcessorChunk): - """Wraps edge embedding message passing blocks for checkpointing in Processor.""" - - def __init__( - self, - num_channels: int, - num_layers: int, - layer_kernels: DotDict, - mlp_extra_layers: int = 0, - edge_dim: Optional[int] = None, - ) -> None: - """Initialize GNNProcessorChunk. - - Parameters - ---------- - num_channels : int - Channels of the message passing blocks. 
- num_layers : int - Number of message passing blocks. - layer_kernels : DotDict - A dict of layer implementations e.g. layer_kernels.Linear = "torch.nn.Linear" - Defined in config/models/.yaml - mlp_extra_layers : int, optional - Extra num_layers in MLP, by default 0 - edge_dim: int, by default None - Embed edges with input dimension edge_dim, - if None: assume embedding is not required - """ - super().__init__(num_channels=num_channels, num_layers=num_layers) - - if edge_dim: - self.emb_edges = MLP( - in_features=edge_dim, - hidden_dim=num_channels, - out_features=num_channels, - layer_kernels=layer_kernels, - n_extra_layers=mlp_extra_layers, - ) - else: - self.emb_edges = None - - self.build_blocks( - GraphConvProcessorBlock, - in_channels=num_channels, - out_channels=num_channels, - num_chunks=1, - layer_kernels=layer_kernels, - mlp_extra_layers=mlp_extra_layers, - ) - - def forward( - self, - x: OptPairTensor, - edge_attr: Tensor, - edge_index: Adj, - shapes: tuple, - model_comm_group: Optional[ProcessGroup] = None, - size: Optional[Size] = None, - **kwargs, - ) -> OptPairTensor: - x_out = x * 1.0 # required for pytorch >= 2.1 - if self.emb_edges: - edge_attr = self.emb_edges(edge_attr) - - for i in range(self.num_layers): - x_out, edge_attr = self.blocks[i]( - x_out, edge_attr, edge_index, shapes, model_comm_group=model_comm_group, size=size, **kwargs - ) - - return x_out, edge_attr - - -class GraphTransformerProcessorChunk(BaseProcessorChunk): - """Wraps graph transformer blocks for checkpointing in Processor.""" - - def __init__( - self, - num_channels: int, - num_layers: int, - layer_kernels: DotDict, - num_heads: int = 16, - mlp_hidden_ratio: int = 4, - qk_norm: bool = False, - edge_dim: Optional[int] = None, - ) -> None: - """Initialize GraphTransformerProcessorChunk. - - Parameters - ---------- - num_channels : int - Number of channels. - num_layers : int - Number of layers. - layer_kernels : DotDict - A dict of layer implementations e.g. 
layer_kernels.Linear = "torch.nn.Linear" - Defined in config/models/.yaml - num_heads: int - Number of heads to use, default 16 - mlp_hidden_ratio: int - ratio of mlp hidden dimension to embedding dimension, default 4 - qk_norm: bool, optional - Normalize query and key, by default False - edge_dim: int, by default None - Embed edges with input dimension edge_dim - """ - super().__init__(num_channels=num_channels, num_layers=num_layers) - - self.build_blocks( - GraphTransformerProcessorBlock, - in_channels=num_channels, - hidden_dim=mlp_hidden_ratio * num_channels, - out_channels=num_channels, - num_heads=num_heads, - num_chunks=1, - edge_dim=edge_dim, - layer_kernels=layer_kernels, - qk_norm=qk_norm, - ) - - def forward( - self, - x: OptPairTensor, - edge_attr: Tensor, - edge_index: Adj, - shapes: tuple, - batch_size: int, - model_comm_group: Optional[ProcessGroup] = None, - size: Optional[Size] = None, - **kwargs, - ) -> OptPairTensor: - for i in range(self.num_layers): - x, edge_attr = self.blocks[i]( - x, edge_attr, edge_index, shapes, batch_size, size, model_comm_group=model_comm_group, **kwargs - ) - - return x, edge_attr diff --git a/models/src/anemoi/models/layers/processor.py b/models/src/anemoi/models/layers/processor.py index 5511eddf2..8cff193be 100644 --- a/models/src/anemoi/models/layers/processor.py +++ b/models/src/anemoi/models/layers/processor.py @@ -22,10 +22,10 @@ from anemoi.models.distributed.khop_edges import sort_edges_1hop_sharding from anemoi.models.distributed.shapes import change_channels_in_shape from anemoi.models.distributed.shapes import get_shard_shapes -from anemoi.models.layers.chunk import GNNProcessorChunk -from anemoi.models.layers.chunk import GraphTransformerProcessorChunk -from anemoi.models.layers.chunk import PointWiseMLPProcessorChunk -from anemoi.models.layers.chunk import TransformerProcessorChunk +from anemoi.models.layers.block import GraphConvProcessorBlock +from anemoi.models.layers.block import GraphTransformerProcessorBlock +from anemoi.models.layers.block import PointWiseMLPProcessorBlock +from anemoi.models.layers.block import TransformerProcessorBlock from anemoi.models.layers.graph import TrainableTensor from anemoi.models.layers.mapper import GraphEdgeMixin from anemoi.models.layers.utils import load_layer_kernels @@ -45,13 +45,30 @@ def __init__( layer_kernels: DotDict, **kwargs, ) -> None: - """Initialize BaseProcessor.""" + """Initialize BaseProcessor. + + Parameters + ---------- + num_layers : int + Number of processor layers. + num_channels : int + Number of channels, i.e. feature dimension of the processor state. + num_chunks: int + Number of chunks of the processor. The num_chunks and num_layers, defines how many layers are grouped together for checkpointing, i.e. chunk_size = num_layers/ num_chunks. + cpu_offload : bool + Whether to offload processing to CPU, by default False + layer_kernels : DotDict + A dict of layer implementations e.g. 
layer_kernels.Linear = "torch.nn.Linear" + Defined in config/models/.yaml + **kwargs : dict + Additional keyword arguments + """ super().__init__() - # Each Processor divides the layers into chunks that get assigned to each ProcessorChunk + self.num_layers = num_layers self.num_chunks = num_chunks - self.num_channels = num_channels self.chunk_size = num_layers // num_chunks + self.num_channels = num_channels self.layer_factory = load_layer_kernels(layer_kernels) @@ -65,22 +82,29 @@ def offload_layers(self, cpu_offload): if cpu_offload: self.proc = nn.ModuleList([offload_wrapper(x) for x in self.proc]) - def build_layers(self, processor_chunk_class, *args, **kwargs) -> None: + def build_layers(self, layer_class, *layer_args, **layer_kwargs) -> None: """Build Layers.""" self.proc = nn.ModuleList( [ - processor_chunk_class( - *args, - **kwargs, + layer_class( + *layer_args, + **layer_kwargs, ) - for _ in range(self.num_chunks) + for _ in range(self.num_layers) ], ) - def run_layers(self, data: tuple, *args, **kwargs) -> Tensor: - """Run Layers with checkpoint.""" - for layer in self.proc: - data = checkpoint(layer, *data, *args, **kwargs, use_reentrant=False) + def run_layer_chunk(self, chunk_start: int, data: tuple, *args, **kwargs) -> tuple: + for layer_id in range(chunk_start, chunk_start + self.chunk_size): + data = self.proc[layer_id](*data, *args, **kwargs) + + return data + + def run_layers(self, data: tuple, *args, **kwargs) -> tuple: + """Run Layers with checkpoints around chunks.""" + for chunk_start in range(0, self.num_layers, self.chunk_size): + data = checkpoint(self.run_layer_chunk, chunk_start, data, *args, **kwargs, use_reentrant=False) + return data def forward(self, x: Tensor, *args, **kwargs) -> Tensor: @@ -120,11 +144,10 @@ def __init__( ) self.build_layers( - PointWiseMLPProcessorChunk, + PointWiseMLPProcessorBlock, num_channels=num_channels, - num_layers=self.chunk_size, + hidden_dim=(mlp_hidden_ratio * num_channels), layer_kernels=self.layer_factory, - mlp_hidden_ratio=mlp_hidden_ratio, dropout_p=dropout_p, ) @@ -217,14 +240,13 @@ def __init__( ) self.build_layers( - TransformerProcessorChunk, + TransformerProcessorBlock, num_channels=num_channels, - num_layers=self.chunk_size, - layer_kernels=self.layer_factory, - mlp_hidden_ratio=mlp_hidden_ratio, + hidden_dim=(mlp_hidden_ratio * num_channels), num_heads=num_heads, - window_size=window_size, qk_norm=qk_norm, + window_size=window_size, + layer_kernels=self.layer_factory, dropout_p=dropout_p, attention_implementation=attention_implementation, softcap=softcap, @@ -248,7 +270,7 @@ def forward( model_comm_group.size() == 1 or batch_size == 1 ), "Only batch size of 1 is supported when model is sharded accross GPUs" - (x,) = self.run_layers((x,), shape_nodes, batch_size, model_comm_group, **kwargs) + (x,) = self.run_layers((x,), shape_nodes, batch_size, model_comm_group=model_comm_group, **kwargs) return x @@ -320,10 +342,21 @@ def __init__( "edge_dim": None, } - self.build_layers(GNNProcessorChunk, num_channels, self.chunk_size, **kwargs) + self.build_layers( + GraphConvProcessorBlock, + in_channels=num_channels, + out_channels=num_channels, + num_chunks=1, + **kwargs, + ) kwargs["edge_dim"] = self.edge_dim # Edge dim for first layer - self.proc[0] = GNNProcessorChunk(num_channels, self.chunk_size, **kwargs) + self.proc[0] = GraphConvProcessorBlock( + in_channels=num_channels, + out_channels=num_channels, + num_chunks=1, + **kwargs, + ) self.offload_layers(cpu_offload) @@ -424,14 +457,14 @@ def __init__( self.trainable = 
TrainableTensor(trainable_size=trainable_size, tensor_size=self.edge_attr.shape[0]) self.build_layers( - GraphTransformerProcessorChunk, - num_channels=num_channels, - num_layers=self.chunk_size, - layer_kernels=self.layer_factory, + GraphTransformerProcessorBlock, + in_channels=num_channels, + hidden_dim=(mlp_hidden_ratio * num_channels), + out_channels=num_channels, num_heads=num_heads, - mlp_hidden_ratio=mlp_hidden_ratio, - qk_norm=qk_norm, edge_dim=self.edge_dim, + layer_kernels=self.layer_factory, + qk_norm=qk_norm, ) self.offload_layers(cpu_offload) diff --git a/models/src/anemoi/models/migrations/scripts/1762857428_chunking_fix.py b/models/src/anemoi/models/migrations/scripts/1762857428_chunking_fix.py new file mode 100644 index 000000000..770bc5e8d --- /dev/null +++ b/models/src/anemoi/models/migrations/scripts/1762857428_chunking_fix.py @@ -0,0 +1,76 @@ +# (C) Copyright 2025 Anemoi contributors. +# +# This software is licensed under the terms of the Apache Licence Version 2.0 +# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. +# +# In applying this licence, ECMWF does not waive the privileges and immunities +# granted to it by virtue of its status as an intergovernmental organisation +# nor does it submit to any jurisdiction. + +from anemoi.models.migrations import CkptType +from anemoi.models.migrations import MigrationMetadata + +# DO NOT CHANGE --> +metadata = MigrationMetadata( + versions={ + "migration": "1.0.0", + "anemoi-models": "%NEXT_ANEMOI_MODELS_VERSION%", + }, +) +# <-- END DO NOT CHANGE + + +def migrate(ckpt: CkptType) -> CkptType: + """Migrate the checkpoint. + + Parameters + ---------- + ckpt : CkptType + The checkpoint dict. + + Returns + ------- + CkptType + The migrated checkpoint dict. + """ + num_layers = ckpt["hyper_parameters"]["config"].model.processor.num_layers + num_chunks = ckpt["hyper_parameters"]["config"].model.processor.num_chunks + state_dict = ckpt["state_dict"] + + blocks_per_chunk = num_layers // num_chunks + updates = {} + + for key in [k for k in list(state_dict.keys()) if "processor.proc" in k]: + parts = key.split(".") + if not parts[5] == "blocks": # expecting format model.model.processor.proc.i.blocks.j.... + continue + + chunk_idx = int(parts[4]) + block_idx = int(parts[6]) + + flat_idx = chunk_idx * blocks_per_chunk + block_idx + rest = [""] + parts[7:] + # reconstruct new key: model.model.processor.proc.. + new_key = "model.model.processor.proc." + str(flat_idx) + ".".join(rest) + + updates[new_key] = state_dict[key] + del state_dict[key] + + ckpt["state_dict"].update(updates) + return ckpt + + +def rollback(ckpt: CkptType) -> CkptType: + """Rollback the checkpoint. + + Parameters + ---------- + ckpt : CkptType + The checkpoint dict. + + Returns + ------- + CkptType + The rollbacked checkpoint dict. 
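# Illustrative sketch (not part of this patch): the key flattening performed by migrate()
# above, shown on a toy state dict. Old checkpoints index processor parameters as
# model.model.processor.proc.<chunk>.blocks.<block>.<rest>; the refactored processor
# stores a flat block list, so the new index is chunk * blocks_per_chunk + block.
def flatten_processor_keys(state_dict: dict, num_layers: int, num_chunks: int) -> dict:
    blocks_per_chunk = num_layers // num_chunks
    out = {}
    for key, value in state_dict.items():
        parts = key.split(".")
        if len(parts) > 6 and parts[2:4] == ["processor", "proc"] and parts[5] == "blocks":
            flat_idx = int(parts[4]) * blocks_per_chunk + int(parts[6])
            key = ".".join(parts[:4] + [str(flat_idx)] + parts[7:])
        out[key] = value
    return out


# e.g. with num_layers=4 and num_chunks=2 (two chunks of two blocks each):
old = {
    "model.model.processor.proc.0.blocks.0.mlp.weight": 0,
    "model.model.processor.proc.0.blocks.1.mlp.weight": 1,
    "model.model.processor.proc.1.blocks.0.mlp.weight": 2,
    "model.model.processor.proc.1.blocks.1.mlp.weight": 3,
}
new = flatten_processor_keys(old, num_layers=4, num_chunks=2)
assert list(new) == [
    "model.model.processor.proc.0.mlp.weight",
    "model.model.processor.proc.1.mlp.weight",
    "model.model.processor.proc.2.mlp.weight",
    "model.model.processor.proc.3.mlp.weight",
]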
+ """ + return ckpt diff --git a/models/tests/layers/block/test_block_graphtransformer.py b/models/tests/layers/block/test_block_graphtransformer.py index d6b4f2d45..6e5d7167c 100644 --- a/models/tests/layers/block/test_block_graphtransformer.py +++ b/models/tests/layers/block/test_block_graphtransformer.py @@ -29,7 +29,6 @@ def init_proc(): edge_dim = 11 bias = True num_heads = 8 - num_chunks = 2 layer_kernels = load_layer_kernels() qk_norm = True return ( @@ -40,7 +39,6 @@ def init_proc(): layer_kernels, bias, num_heads, - num_chunks, qk_norm, ) @@ -55,7 +53,6 @@ def block(init_proc): layer_kernels, bias, num_heads, - num_chunks, qk_norm, ) = init_proc return GraphTransformerProcessorBlock( @@ -67,7 +64,6 @@ def block(init_proc): num_heads=num_heads, bias=bias, update_src_nodes=False, - num_chunks=num_chunks, qk_norm=qk_norm, ) @@ -81,7 +77,6 @@ def test_GraphTransformerProcessorBlock_init(init_proc, block): _layer_kernels, _bias, num_heads, - num_chunks, _qk_norm, ) = init_proc assert isinstance( @@ -91,7 +86,6 @@ def test_GraphTransformerProcessorBlock_init(init_proc, block): block.out_channels_conv == out_channels // num_heads ), f"block.out_channels_conv ({block.out_channels_conv}) != out_channels // num_heads ({out_channels // num_heads})" assert block.num_heads == num_heads, f"block.num_heads ({block.num_heads}) != num_heads ({num_heads})" - assert block.num_chunks == num_chunks, f"block.num_chunks ({block.num_chunks}) != num_chunks ({num_chunks})" assert isinstance(block.lin_key, torch.nn.Linear), "block.lin_key is not an instance of torch.nn.Linear" assert isinstance(block.lin_query, torch.nn.Linear), "block.lin_query is not an instance of torch.nn.Linear" assert isinstance(block.lin_value, torch.nn.Linear), "block.lin_value is not an instance of torch.nn.Linear" @@ -115,7 +109,6 @@ def test_GraphTransformerProcessorBlock_shard_qkve_heads(init_proc, block): _layer_kernels, _bias, num_heads, - _num_chunks, _qk_norm, ) = init_proc query = torch.randn(in_channels, num_heads * block.out_channels_conv) @@ -140,7 +133,6 @@ def test_GraphTransformerProcessorBlock_shard_output_seq(init_proc, block): _layer_kernels, _bias, num_heads, - _num_chunks, _qk_norm, ) = init_proc out = torch.randn(in_channels, num_heads, block.out_channels_conv) @@ -160,7 +152,6 @@ def test_GraphTransformerProcessorBlock_forward_backward(init_proc, block): _layer_kernels, _bias, _num_heads, - _num_chunks, _qk_norm, ) = init_proc @@ -206,7 +197,6 @@ def test_GraphTransformerProcessorBlock_chunking(init_proc, block, monkeypatch): _bias, _activation, _num_heads, - _num_chunks, ) = init_proc # Initialize GraphTransformerProcessorBlock block = block diff --git a/models/tests/layers/block/test_block_pointwise.py b/models/tests/layers/block/test_block_pointwise.py index 4b9c049e9..8bdb94ae8 100644 --- a/models/tests/layers/block/test_block_pointwise.py +++ b/models/tests/layers/block/test_block_pointwise.py @@ -89,5 +89,5 @@ def test_forward_output( x = torch.randn((batch_size, num_channels)) # .to(torch.float16, non_blocking=True) output = block.forward(x, shapes, batch_size) - assert isinstance(output, torch.Tensor) - assert output.shape == (batch_size, num_channels) + assert isinstance(output[0], torch.Tensor) + assert output[0].shape == (batch_size, num_channels) diff --git a/models/tests/layers/block/test_block_transformer.py b/models/tests/layers/block/test_block_transformer.py index 376c2eb45..af15a2e3c 100644 --- a/models/tests/layers/block/test_block_transformer.py +++ 
b/models/tests/layers/block/test_block_transformer.py @@ -123,8 +123,8 @@ def test_forward_output( x = torch.randn((batch_size, num_channels)) # .to(torch.float16, non_blocking=True) output = block.forward(x, shapes, batch_size) - assert isinstance(output, torch.Tensor) - assert output.shape == (batch_size, num_channels) + assert isinstance(output[0], torch.Tensor) + assert output[0].shape == (batch_size, num_channels) class TestGraphConvProcessorBlock: diff --git a/models/tests/layers/chunk/test_chunk_gnn.py b/models/tests/layers/chunk/test_chunk_gnn.py deleted file mode 100644 index 5e1cdb41e..000000000 --- a/models/tests/layers/chunk/test_chunk_gnn.py +++ /dev/null @@ -1,48 +0,0 @@ -# (C) Copyright 2024 Anemoi contributors. -# -# This software is licensed under the terms of the Apache Licence Version 2.0 -# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. -# -# In applying this licence, ECMWF does not waive the privileges and immunities -# granted to it by virtue of its status as an intergovernmental organisation -# nor does it submit to any jurisdiction. - - -import pytest - -from anemoi.models.layers.block import GraphConvProcessorBlock -from anemoi.models.layers.chunk import GNNProcessorChunk -from anemoi.models.layers.mlp import MLP -from anemoi.models.layers.utils import load_layer_kernels - - -class TestGNNProcessorChunk: - @pytest.fixture - def init(self): - num_channels = 10 - num_layers = 3 - mlp_extra_layers = 3 - edge_dim = None - layer_kernels = load_layer_kernels() - return num_channels, num_layers, layer_kernels, mlp_extra_layers, edge_dim - - @pytest.fixture - def processor_chunk(self, init): - num_channels, num_layers, layer_kernels, mlp_extra_layers, edge_dim = init - return GNNProcessorChunk( - num_channels=num_channels, - num_layers=num_layers, - layer_kernels=layer_kernels, - mlp_extra_layers=mlp_extra_layers, - edge_dim=edge_dim, - ) - - def test_embed_edges(self, init, processor_chunk): - _num_channels, _num_layers, _layer_kernels, _mlp_extra_layers, edge_dim = init - if edge_dim: - assert isinstance(processor_chunk.emb_edges, MLP) - else: - assert processor_chunk.emb_edges is None - - def test_all_blocks(self, processor_chunk): - assert all(isinstance(block, GraphConvProcessorBlock) for block in processor_chunk.blocks) diff --git a/models/tests/layers/chunk/test_chunk_graphtransformer.py b/models/tests/layers/chunk/test_chunk_graphtransformer.py deleted file mode 100644 index 58d6ec435..000000000 --- a/models/tests/layers/chunk/test_chunk_graphtransformer.py +++ /dev/null @@ -1,52 +0,0 @@ -# (C) Copyright 2024 Anemoi contributors. -# -# This software is licensed under the terms of the Apache Licence Version 2.0 -# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. -# -# In applying this licence, ECMWF does not waive the privileges and immunities -# granted to it by virtue of its status as an intergovernmental organisation -# nor does it submit to any jurisdiction. 
- - -import pytest - -from anemoi.models.layers.block import GraphTransformerProcessorBlock -from anemoi.models.layers.chunk import GraphTransformerProcessorChunk -from anemoi.models.layers.utils import load_layer_kernels - - -class TestGraphTransformerProcessorChunk: - @pytest.fixture - def init(self): - num_channels: int = 10 - num_layers: int = 3 - num_heads: int = 16 - mlp_hidden_ratio: int = 4 - qk_norm = True - edge_dim: int = 32 - layer_kernels = load_layer_kernels() - return ( - num_channels, - num_layers, - layer_kernels, - num_heads, - mlp_hidden_ratio, - qk_norm, - edge_dim, - ) - - @pytest.fixture - def processor_chunk(self, init): - num_channels, num_layers, layer_kernels, num_heads, mlp_hidden_ratio, qk_norm, edge_dim = init - return GraphTransformerProcessorChunk( - num_channels=num_channels, - num_layers=num_layers, - layer_kernels=layer_kernels, - num_heads=num_heads, - mlp_hidden_ratio=mlp_hidden_ratio, - qk_norm=qk_norm, - edge_dim=edge_dim, - ) - - def test_all_blocks(self, processor_chunk): - assert all(isinstance(block, GraphTransformerProcessorBlock) for block in processor_chunk.blocks) diff --git a/models/tests/layers/chunk/test_chunk_pointwise.py b/models/tests/layers/chunk/test_chunk_pointwise.py deleted file mode 100644 index fb7dea3b0..000000000 --- a/models/tests/layers/chunk/test_chunk_pointwise.py +++ /dev/null @@ -1,53 +0,0 @@ -# (C) Copyright 2025 Anemoi contributors. -# -# This software is licensed under the terms of the Apache Licence Version 2.0 -# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. -# -# In applying this licence, ECMWF does not waive the privileges and immunities -# granted to it by virtue of its status as an intergovernmental organisation -# nor does it submit to any jurisdiction. - - -import pytest - -from anemoi.models.layers.block import PointWiseMLPProcessorBlock -from anemoi.models.layers.chunk import PointWiseMLPProcessorChunk -from anemoi.models.layers.utils import load_layer_kernels - - -class TestPointWiseMLPProcessorChunk: - @pytest.fixture - def init(self): - num_channels = 512 - num_layers = 3 - mlp_hidden_ratio: int = 4 - dropout_p: float = 0.1 - layer_kernels = load_layer_kernels() - - return ( - num_channels, - num_layers, - layer_kernels, - mlp_hidden_ratio, - dropout_p, - ) - - @pytest.fixture - def processor_chunk(self, init): - ( - num_channels, - num_layers, - layer_kernels, - mlp_hidden_ratio, - dropout_p, - ) = init - return PointWiseMLPProcessorChunk( - num_channels=num_channels, - num_layers=num_layers, - layer_kernels=layer_kernels, - mlp_hidden_ratio=mlp_hidden_ratio, - dropout_p=dropout_p, - ) - - def test_all_blocks(self, processor_chunk): - assert all(isinstance(block, PointWiseMLPProcessorBlock) for block in processor_chunk.blocks) diff --git a/models/tests/layers/chunk/test_chunk_transformer.py b/models/tests/layers/chunk/test_chunk_transformer.py deleted file mode 100644 index e43272a54..000000000 --- a/models/tests/layers/chunk/test_chunk_transformer.py +++ /dev/null @@ -1,70 +0,0 @@ -# (C) Copyright 2024 Anemoi contributors. -# -# This software is licensed under the terms of the Apache Licence Version 2.0 -# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. -# -# In applying this licence, ECMWF does not waive the privileges and immunities -# granted to it by virtue of its status as an intergovernmental organisation -# nor does it submit to any jurisdiction. 
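# Illustrative sketch (not part of this patch): with the *ProcessorChunk wrappers and
# their test modules removed, `processor.proc` is a flat nn.ModuleList of blocks, so the
# block-type checks above collapse into one assertion per processor fixture -- the
# test_all_blocks pattern that the processor tests gain further down in this diff.
# The transformer_processor fixture here is the existing one in test_transformer_processor.py.
from anemoi.models.layers.block import TransformerProcessorBlock


def test_all_blocks(transformer_processor):
    assert all(isinstance(block, TransformerProcessorBlock) for block in transformer_processor.proc)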
- - -import pytest - -from anemoi.models.layers.block import TransformerProcessorBlock -from anemoi.models.layers.chunk import TransformerProcessorChunk -from anemoi.models.layers.utils import load_layer_kernels - - -class TestTransformerProcessorChunk: - @pytest.fixture - def init(self): - num_channels = 512 - num_layers = 3 - num_heads: int = 16 - mlp_hidden_ratio: int = 4 - window_size: int = 13 - dropout_p: float = 0.1 - layer_kernels = load_layer_kernels() - attention_implementation = "scaled_dot_product_attention" - qk_norm = True - - # num_heads must be evenly divisible by num_channels for MHSA - return ( - num_channels, - num_layers, - layer_kernels, - num_heads, - mlp_hidden_ratio, - window_size, - dropout_p, - attention_implementation, - qk_norm, - ) - - @pytest.fixture - def processor_chunk(self, init): - ( - num_channels, - num_layers, - layer_kernels, - num_heads, - mlp_hidden_ratio, - window_size, - dropout_p, - attention_implementation, - qk_norm, - ) = init - return TransformerProcessorChunk( - num_channels=num_channels, - num_layers=num_layers, - layer_kernels=layer_kernels, - num_heads=num_heads, - mlp_hidden_ratio=mlp_hidden_ratio, - qk_norm=qk_norm, - window_size=window_size, - dropout_p=dropout_p, - attention_implementation=attention_implementation, - ) - - def test_all_blocks(self, processor_chunk): - assert all(isinstance(block, TransformerProcessorBlock) for block in processor_chunk.blocks) diff --git a/models/tests/layers/processor/test_graphconv_processor.py b/models/tests/layers/processor/test_graphconv_processor.py index 39e51c7ba..df879fdb1 100644 --- a/models/tests/layers/processor/test_graphconv_processor.py +++ b/models/tests/layers/processor/test_graphconv_processor.py @@ -15,6 +15,7 @@ import torch from torch_geometric.data import HeteroData +from anemoi.models.layers.block import GraphConvProcessorBlock from anemoi.models.layers.graph import TrainableTensor from anemoi.models.layers.processor import GNNProcessor from anemoi.models.layers.utils import load_layer_kernels @@ -70,6 +71,9 @@ def test_graphconv_processor_init(self, graphconv_processor, graphconv_init): assert graphconv_processor.chunk_size == graphconv_init.num_layers // graphconv_init.num_chunks assert isinstance(graphconv_processor.trainable, TrainableTensor) + def test_all_blocks(self, graphconv_processor): + assert all(isinstance(block, GraphConvProcessorBlock) for block in graphconv_processor.proc) + def test_forward(self, graphconv_processor, graphconv_init): batch_size = 1 x = torch.rand((self.NUM_EDGES, graphconv_init.num_channels)) diff --git a/models/tests/layers/processor/test_graphtransformer_processor.py b/models/tests/layers/processor/test_graphtransformer_processor.py index eb2adab33..49bc27948 100644 --- a/models/tests/layers/processor/test_graphtransformer_processor.py +++ b/models/tests/layers/processor/test_graphtransformer_processor.py @@ -15,6 +15,7 @@ import torch from torch_geometric.data import HeteroData +from anemoi.models.layers.block import GraphTransformerProcessorBlock from anemoi.models.layers.graph import TrainableTensor from anemoi.models.layers.processor import GraphTransformerProcessor from anemoi.models.layers.utils import load_layer_kernels @@ -31,8 +32,8 @@ class GraphTransformerProcessorConfig: trainable_size: int = 6 src_grid_size: int = 0 dst_grid_size: int = 0 - qk_norm: bool = (True,) - cpu_offload: bool = (False,) + qk_norm: bool = True + cpu_offload: bool = False layer_kernels: field(default_factory=DotDict) = None def __post_init__(self): @@ -75,6 
+76,9 @@ def test_graphtransformer_processor_init(self, graphtransformer_processor, graph ) assert isinstance(graphtransformer_processor.trainable, TrainableTensor) + def test_all_blocks(self, graphtransformer_processor): + assert all(isinstance(block, GraphTransformerProcessorBlock) for block in graphtransformer_processor.proc) + def test_forward(self, graphtransformer_processor, graphtransformer_init): batch_size = 1 diff --git a/models/tests/layers/processor/test_pointwise_processor.py b/models/tests/layers/processor/test_pointwise_processor.py index 0c105c999..bf4938344 100644 --- a/models/tests/layers/processor/test_pointwise_processor.py +++ b/models/tests/layers/processor/test_pointwise_processor.py @@ -15,6 +15,7 @@ import pytest import torch +from anemoi.models.layers.block import PointWiseMLPProcessorBlock from anemoi.models.layers.processor import PointWiseMLPProcessor from anemoi.models.layers.utils import load_layer_kernels from anemoi.utils.config import DotDict @@ -53,6 +54,9 @@ def test_pointwisemlp_processor_init(pointwisemlp_processor, pointwisemlp_proces == pointwisemlp_processor_init.num_layers // pointwisemlp_processor_init.num_chunks ) + def test_all_blocks(self, pointwisemlp_processor): + assert all(isinstance(block, PointWiseMLPProcessorBlock) for block in pointwisemlp_processor.proc) + @pytest.fixture(params=[0.1, None]) def test_pointwisemlp_processor_with_sharding_dropout_forward(pointwisemlp_processor, pointwisemlp_processor_init): diff --git a/models/tests/layers/processor/test_transformer_processor.py b/models/tests/layers/processor/test_transformer_processor.py index 1294f402d..a6c4af2b4 100644 --- a/models/tests/layers/processor/test_transformer_processor.py +++ b/models/tests/layers/processor/test_transformer_processor.py @@ -14,6 +14,7 @@ import pytest import torch +from anemoi.models.layers.block import TransformerProcessorBlock from anemoi.models.layers.processor import TransformerProcessor from anemoi.models.layers.utils import load_layer_kernels from anemoi.utils.config import DotDict @@ -61,6 +62,10 @@ def test_transformer_processor_init(transformer_processor, transformer_processor ) +def test_all_blocks(transformer_processor): + assert all(isinstance(block, TransformerProcessorBlock) for block in transformer_processor.proc) + + def test_transformer_processor_forward(transformer_processor, transformer_processor_init): gridsize = 100 batch_size = 1 diff --git a/training/src/anemoi/training/config/model/point_wise.yaml b/training/src/anemoi/training/config/model/point_wise.yaml index fc1c66f5b..4dc846b70 100644 --- a/training/src/anemoi/training/config/model/point_wise.yaml +++ b/training/src/anemoi/training/config/model/point_wise.yaml @@ -33,16 +33,28 @@ processor: encoder: _target_: anemoi.models.layers.mapper.GraphTransformerForwardMapper + trainable_size: ${model.trainable_parameters.data2hidden} sub_graph_edge_attributes: ${model.attributes.edges} + num_chunks: 4 + mlp_hidden_ratio: 4 # GraphTransformer or Transformer only + num_heads: 16 # GraphTransformer or Transformer only + qk_norm: False cpu_offload: ${model.cpu_offload} layer_kernels: ${model.layer_kernels} + shard_strategy: "edges" decoder: _target_: anemoi.models.layers.mapper.GraphTransformerBackwardMapper + trainable_size: ${model.trainable_parameters.hidden2data} sub_graph_edge_attributes: ${model.attributes.edges} + num_chunks: 4 + mlp_hidden_ratio: 4 # GraphTransformer or Transformer only + num_heads: 16 # GraphTransformer or Transformer only initialise_data_extractor_zero: False + 
qk_norm: False cpu_offload: ${model.cpu_offload} layer_kernels: ${model.layer_kernels} + shard_strategy: "edges" output_mask: _target_: anemoi.training.utils.masks.NoOutputMask diff --git a/training/src/anemoi/training/utils/checkpoint.py b/training/src/anemoi/training/utils/checkpoint.py index e2daad886..9c92ba4cd 100644 --- a/training/src/anemoi/training/utils/checkpoint.py +++ b/training/src/anemoi/training/utils/checkpoint.py @@ -8,6 +8,7 @@ # nor does it submit to any jurisdiction. +import importlib import io import logging import pickle @@ -24,6 +25,8 @@ from anemoi.training.train.tasks.base import BaseGraphModule from anemoi.utils.checkpoints import save_metadata +chunking_fix_migration = importlib.import_module("anemoi.models.migrations.scripts.1762857428_chunking_fix").migrate + LOGGER = logging.getLogger(__name__) @@ -80,6 +83,10 @@ def transfer_learning_loading(model: torch.nn.Module, ckpt_path: Path | str) -> # Load the checkpoint checkpoint = torch.load(ckpt_path, weights_only=False, map_location=model.device) + # apply chunking migration (fails silently otherwise leading to hard to debug issues) + # this is due to loading with strict=False, planning to make this more robust in the future + checkpoint = chunking_fix_migration(checkpoint) + # Filter out layers with size mismatch state_dict = checkpoint["state_dict"]
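# Illustrative sketch (not part of this patch): why the chunking migration must run before
# the strict=False load in transfer_learning_loading. With strict=False, checkpoint keys
# that no longer match the flat layout ("...proc.<chunk>.blocks.<block>...") are dropped
# silently and those blocks keep their random initialization; remapping the keys first
# avoids that. (The digits-first module name 1762857428_chunking_fix is also why
# importlib.import_module is used above instead of a plain import statement.)
import torch
from torch import nn

# new-style model: a flat list of two blocks
new_model = nn.ModuleDict({"proc": nn.ModuleList([nn.Linear(4, 4) for _ in range(2)])})

# old-style checkpoint: one chunk wrapping two blocks
old_state = {
    "proc.0.blocks.0.weight": torch.ones(4, 4), "proc.0.blocks.0.bias": torch.ones(4),
    "proc.0.blocks.1.weight": torch.ones(4, 4), "proc.0.blocks.1.bias": torch.ones(4),
}

result = new_model.load_state_dict(old_state, strict=False)
assert len(result.missing_keys) == 4  # nothing loaded: the old keys were silently ignored

# single-chunk toy remap (flat_idx = 0 * blocks_per_chunk + block reduces to the block index)
remapped = {k.replace("proc.0.blocks.", "proc."): v for k, v in old_state.items()}
result = new_model.load_state_dict(remapped, strict=False)
assert not result.missing_keys  # all weights land in the flat block list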