7 changes: 4 additions & 3 deletions models/docs/introduction/overview.rst
@@ -99,9 +99,10 @@ Processors
 ==========
 
 Additionally, the layers implement `Processors` which are used to
-process the data on the hidden grid. The `Processors` use a chunking
-strategy with `Chunks` that pass a subset of layers to `Blocks` to allow
-for more efficient processing of the data.
+process the data on the hidden grid. The `Processors` use a series of
+`Blocks` to process the data. These `Blocks` can be partitioned into
+checkpointed chunks via `num_chunks` to reduce memory usage during
+training.
 
 **************
 Data Indices
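To make the documented behaviour concrete, here is a minimal, self-contained sketch of the block-chunking pattern the added documentation describes, assuming PyTorch's torch.utils.checkpoint; ChunkedProcessor and its Linear blocks are illustrative stand-ins, not the anemoi implementation.

import torch
from torch import nn
from torch.utils.checkpoint import checkpoint


class ChunkedProcessor(nn.Module):
    """Toy processor whose blocks are partitioned into checkpointed chunks."""

    def __init__(self, num_blocks: int, dim: int, num_chunks: int) -> None:
        super().__init__()
        self.blocks = nn.ModuleList(nn.Linear(dim, dim) for _ in range(num_blocks))
        self.num_chunks = num_chunks

    def _run_chunk(self, x: torch.Tensor, start: int, end: int) -> torch.Tensor:
        for block in self.blocks[start:end]:
            x = block(x)
        return x

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        chunk_size = -(-len(self.blocks) // self.num_chunks)  # ceil division
        for start in range(0, len(self.blocks), chunk_size):
            end = min(start + chunk_size, len(self.blocks))
            # Activations inside the chunk are recomputed in the backward
            # pass instead of being stored, trading compute for memory.
            x = checkpoint(self._run_chunk, x, start, end, use_reentrant=False)
        return x


proc = ChunkedProcessor(num_blocks=8, dim=16, num_chunks=4)
out = proc(torch.randn(2, 16, requires_grad=True))

Only the inputs at chunk boundaries are kept alive during the forward pass, which is why a larger num_chunks lowers peak training memory at the cost of recomputation.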
36 changes: 21 additions & 15 deletions models/src/anemoi/models/layers/block.py
@@ -90,8 +90,8 @@ def forward(
         batch_size: int,
         model_comm_group: Optional[ProcessGroup] = None,
         **layer_kwargs,
-    ) -> Tensor:
-        return self.mlp(x)
+    ) -> tuple[Tensor]:
+        return (self.mlp(x),)
 
 
 class TransformerProcessorBlock(BaseBlock):
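The return type changes from a bare Tensor to a one-element tuple here and in the following hunks, so every block variant shares the tuple convention already used by the mapper blocks (which return pairs). A toy illustration of the convention, with all names invented for the example:

import torch
from torch import nn


class TupleBlock(nn.Module):
    """Toy block following the new convention: forward returns a tuple."""

    def __init__(self, dim: int) -> None:
        super().__init__()
        self.mlp = nn.Linear(dim, dim)

    def forward(self, x: torch.Tensor) -> tuple[torch.Tensor]:
        return (self.mlp(x),)


(x,) = TupleBlock(8)(torch.randn(4, 8))  # uniform unpacking, even for single outputs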
@@ -146,7 +146,7 @@ def forward(
         model_comm_group: Optional[ProcessGroup] = None,
         cond: Optional[Tensor] = None,
         **layer_kwargs,
-    ) -> Tensor:
+    ) -> tuple[Tensor]:
 
         # In case we have conditionings we pass these to the layer norm
         cond_kwargs = {"cond": cond} if cond is not None else {}
@@ -160,7 +160,7 @@
                 **cond_kwargs,
             )
         )
-        return x
+        return (x,)
 
 
 class TransformerMapperBlock(TransformerProcessorBlock):
@@ -222,7 +222,7 @@ def forward(
         shapes: list,
         batch_size: int,
         model_comm_group: Optional[ProcessGroup] = None,
-    ) -> Tensor:
+    ) -> tuple[Tensor, Tensor]:
         x_src = self.layer_norm_attention_src(x[0])
         x_dst = self.layer_norm_attention_dst(x[1])
         x_dst = x_dst + self.attention((x_src, x_dst), shapes, batch_size, model_comm_group=model_comm_group)
@@ -242,6 +242,7 @@ def __init__(
         mlp_extra_layers: int = 0,
         update_src_nodes: bool = True,
         layer_kernels: DotDict,
+        edge_dim: Optional[int] = None,
         **kwargs,
     ) -> None:
         """Initialize GNNBlock.
@@ -264,6 +265,17 @@
"""
super().__init__(**kwargs)

if edge_dim:
self.emb_edges = MLP(
in_features=edge_dim,
hidden_dim=out_channels,
out_features=out_channels,
layer_kernels=layer_kernels,
n_extra_layers=mlp_extra_layers,
)
else:
self.emb_edges = None

self.update_src_nodes = update_src_nodes
self.num_chunks = num_chunks

@@ -306,6 +318,8 @@ def forward(
         size: Optional[Size] = None,
         **layer_kwargs,
     ) -> tuple[Tensor, Tensor]:
+        if self.emb_edges is not None:
+            edge_attr = self.emb_edges(edge_attr)
 
         x_in = sync_tensor(x, 0, shapes[1], model_comm_group)

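The new optional edge_dim argument lets a block embed raw edge attributes itself instead of requiring them pre-embedded. A hedged sketch of the pattern; make_edge_embedder and its Sequential stack are illustrative, only the conditional construction and the forward-time check mirror the diff:

from typing import Optional

import torch
from torch import nn


def make_edge_embedder(edge_dim: Optional[int], out_channels: int) -> Optional[nn.Module]:
    """Stand-in for the conditional emb_edges MLP in the diff."""
    if edge_dim:
        return nn.Sequential(
            nn.Linear(edge_dim, out_channels),
            nn.GELU(),
            nn.Linear(out_channels, out_channels),
        )
    return None


# In forward(): project raw edge attributes once, before message passing;
# when no embedder was built, edge_attr is expected to be pre-embedded.
emb_edges = make_edge_embedder(edge_dim=3, out_channels=64)
edge_attr = torch.randn(100, 3)
if emb_edges is not None:
    edge_attr = emb_edges(edge_attr)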
@@ -424,7 +438,6 @@ def __init__(
         hidden_dim: int,
         out_channels: int,
         num_heads: int,
-        num_chunks: int,
         edge_dim: int,
         bias: bool = True,
         qk_norm: bool = False,
@@ -442,8 +455,6 @@
             Number of output channels.
         num_heads : int,
             Number of heads
-        num_chunks : int,
-            Number of chunks
         edge_dim : int,
             Edge dimension
         bias : bool, by default True,
@@ -463,7 +474,6 @@
         self.out_channels_conv = out_channels // num_heads
         self.num_heads = num_heads
         self.qk_norm = qk_norm
-        self.num_chunks = num_chunks
 
         Linear = layer_kernels.Linear
         LayerNorm = layer_kernels.LayerNorm
@@ -662,7 +672,6 @@ def __init__(
             layer_kernels=layer_kernels,
             num_heads=num_heads,
             bias=bias,
-            num_chunks=1,
             qk_norm=qk_norm,
             update_src_nodes=update_src_nodes,
             **kwargs,
@@ -777,7 +786,6 @@ def __init__(
         hidden_dim: int,
         out_channels: int,
         num_heads: int,
-        num_chunks: int,
         edge_dim: int,
         bias: bool = True,
         qk_norm: bool = False,
@@ -795,8 +803,6 @@
             Number of output channels.
         num_heads : int,
             Number of heads
-        num_chunks : int,
-            Number of chunks
         edge_dim : int,
             Edge dimension
         bias : bool
@@ -819,7 +825,6 @@
             num_heads=num_heads,
             bias=bias,
             qk_norm=qk_norm,
-            num_chunks=num_chunks,
             update_src_nodes=update_src_nodes,
             **kwargs,
         )
@@ -851,7 +856,8 @@ def forward(
         query = self.q_norm(query)
         key = self.k_norm(key)
 
-        num_chunks = self.num_chunks if self.training else NUM_CHUNKS_INFERENCE_PROCESSOR
+        # "inner" chunking for memory reductions in inference, controlled via env variable:
+        num_chunks = 1 if self.training else NUM_CHUNKS_INFERENCE_PROCESSOR
 
         out = self.attention_block(query, key, value, edges, edge_index, size, num_chunks)
 
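This last hunk fixes the inner chunk count to 1 during training (the new outer chunk checkpointing covers that case), while inference-time chunking stays controlled by the module-level NUM_CHUNKS_INFERENCE_PROCESSOR constant. A sketch of how such an env-controlled constant and row-chunked computation typically fit together; the variable name ANEMOI_INFERENCE_NUM_CHUNKS and the apply_in_chunks helper are assumptions, not confirmed by this diff:

import os
from typing import Callable

import torch

# Assumption: the constant is read from an environment variable at import
# time; the name ANEMOI_INFERENCE_NUM_CHUNKS is illustrative.
NUM_CHUNKS_INFERENCE_PROCESSOR = int(os.environ.get("ANEMOI_INFERENCE_NUM_CHUNKS", "1"))


def apply_in_chunks(
    fn: Callable[[torch.Tensor], torch.Tensor],
    query: torch.Tensor,
    num_chunks: int,
) -> torch.Tensor:
    """Run fn over row-chunks of query sequentially to bound peak memory."""
    if num_chunks <= 1:
        return fn(query)
    return torch.cat([fn(q) for q in torch.chunk(query, num_chunks, dim=0)], dim=0)


# Training keeps num_chunks = 1; inference can raise it via the environment.
out = apply_in_chunks(torch.nn.functional.relu, torch.randn(1024, 64),
                      NUM_CHUNKS_INFERENCE_PROCESSOR)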