Skip to content
Merged

EGNO #602

Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions docs/source/_rst/_code.rst
Original file line number Diff line number Diff line change
Expand Up @@ -105,6 +105,7 @@ Models
GraphNeuralOperator <model/graph_neural_operator.rst>
GraphNeuralKernel <model/graph_neural_operator_integral_kernel.rst>
PirateNet <model/pirate_network.rst>
EquivariantGraphNeuralOperator <model/equivariant_graph_neural_operator.rst>

Blocks
-------------
Expand Down Expand Up @@ -134,6 +135,7 @@ Message Passing
E(n) Equivariant Network Block <model/block/message_passing/en_equivariant_network_block.rst>
Interaction Network Block <model/block/message_passing/interaction_network_block.rst>
Radial Field Network Block <model/block/message_passing/radial_field_network_block.rst>
EquivariantGraphNeuralOperatorBlock <model/block/message_passing/equivariant_graph_neural_operator_block.rst>


Reduction and Embeddings
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
EquivariantGraphNeuralOperatorBlock
=====================================
.. currentmodule:: pina.model.block.message_passing.equivariant_graph_neural_operator_block

.. autoclass:: EquivariantGraphNeuralOperatorBlock
:members:
:show-inheritance:
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
EquivariantGraphNeuralOperator
=================================
.. currentmodule:: pina.model.equivariant_graph_neural_operator

.. autoclass:: EquivariantGraphNeuralOperator
:members:
:show-inheritance:
2 changes: 2 additions & 0 deletions pina/model/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
"Spline",
"GraphNeuralOperator",
"PirateNet",
"EquivariantGraphNeuralOperator",
]

from .feed_forward import FeedForward, ResidualFeedForward
Expand All @@ -26,3 +27,4 @@
from .spline import Spline
from .graph_neural_operator import GraphNeuralOperator
from .pirate_network import PirateNet
from .equivariant_graph_neural_operator import EquivariantGraphNeuralOperator
4 changes: 4 additions & 0 deletions pina/model/block/message_passing/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,13 @@
"DeepTensorNetworkBlock",
"EnEquivariantNetworkBlock",
"RadialFieldNetworkBlock",
"EquivariantGraphNeuralOperatorBlock",
]

from .interaction_network_block import InteractionNetworkBlock
from .deep_tensor_network_block import DeepTensorNetworkBlock
from .en_equivariant_network_block import EnEquivariantNetworkBlock
from .radial_field_network_block import RadialFieldNetworkBlock
from .equivariant_graph_neural_operator_block import (
EquivariantGraphNeuralOperatorBlock,
)
62 changes: 51 additions & 11 deletions pina/model/block/message_passing/en_equivariant_network_block.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import torch
from torch_geometric.nn import MessagePassing
from torch_geometric.utils import degree
from ....utils import check_positive_integer
from ....utils import check_positive_integer, check_consistency
from ....model import FeedForward


Expand All @@ -27,6 +27,12 @@ class EnEquivariantNetworkBlock(MessagePassing):
positions are updated by adding the incoming messages divided by the
degree of the recipient node.

When velocity features are used, node velocities are passed through a small
MLP to compute updates, which are then combined with the aggregated position
messages. The node positions are updated both by the normalized position
messages and by the updated velocities, ensuring equivariance while
incorporating dynamic information.

.. seealso::

**Original reference** Satorras, V. G., Hoogeboom, E., Welling, M.
Expand All @@ -40,6 +46,7 @@ def __init__(
node_feature_dim,
edge_feature_dim,
pos_dim,
use_velocity=False,
hidden_dim=64,
n_message_layers=2,
n_update_layers=2,
Expand All @@ -54,6 +61,8 @@ def __init__(
:param int node_feature_dim: The dimension of the node features.
:param int edge_feature_dim: The dimension of the edge features.
:param int pos_dim: The dimension of the position features.
:param bool use_velocity: Whether to use velocity features in the
message passing. Default is False.
:param int hidden_dim: The dimension of the hidden features.
Default is 64.
:param int n_message_layers: The number of layers in the message
Expand All @@ -80,6 +89,7 @@ def __init__(
:raises AssertionError: If `hidden_dim` is not a positive integer.
:raises AssertionError: If `n_message_layers` is not a positive integer.
:raises AssertionError: If `n_update_layers` is not a positive integer.
:raises AssertionError: If `use_velocity` is not a boolean.
"""
super().__init__(aggr=aggr, node_dim=node_dim, flow=flow)

Expand All @@ -90,6 +100,10 @@ def __init__(
check_positive_integer(hidden_dim, strict=True)
check_positive_integer(n_message_layers, strict=True)
check_positive_integer(n_update_layers, strict=True)
check_consistency(use_velocity, bool)

# Initialization
self.use_velocity = use_velocity

# Layer for computing the message
self.message_net = FeedForward(
Expand Down Expand Up @@ -119,7 +133,17 @@ def __init__(
func=activation,
)

def forward(self, x, pos, edge_index, edge_attr=None):
# If velocity is used, instantiate layer for velocity updates
if self.use_velocity:
self.update_vel_net = FeedForward(
input_dimensions=node_feature_dim,
output_dimensions=1,
inner_size=hidden_dim,
n_layers=n_update_layers,
func=activation,
)

def forward(self, x, pos, edge_index, edge_attr=None, vel=None):
"""
Forward pass of the block, triggering the message-passing routine.

Expand All @@ -130,11 +154,19 @@ def forward(self, x, pos, edge_index, edge_attr=None):
:param torch.Tensor edge_index: The edge indices.
:param edge_attr: The edge attributes. Default is None.
:type edge_attr: torch.Tensor | LabelTensor
:param vel: The velocity of the nodes. Default is None.
:type vel: torch.Tensor | LabelTensor
:return: The updated node features and node positions.
:rtype: tuple(torch.Tensor, torch.Tensor)
:raises: ValueError: If ``use_velocity`` is True and ``vel`` is None.
"""
if self.use_velocity and vel is None:
raise ValueError(
"Velocity features are enabled, but no velocity is passed."
)

return self.propagate(
edge_index=edge_index, x=x, pos=pos, edge_attr=edge_attr
edge_index=edge_index, x=x, pos=pos, edge_attr=edge_attr, vel=vel
)

def message(self, x_i, x_j, pos_i, pos_j, edge_attr):
Expand Down Expand Up @@ -202,28 +234,36 @@ def aggregate(self, inputs, index, ptr=None, dim_size=None):

return agg_message, agg_m_ij

def update(self, aggregated_inputs, x, pos, edge_index):
def update(self, aggregated_inputs, x, pos, edge_index, vel):
"""
Update the node features and the node coordinates with the received
messages.
Update node features, positions, and optionally velocities.

:param tuple(torch.Tensor) aggregated_inputs: The messages to be passed.
:param x: The node features.
:type x: torch.Tensor | LabelTensor
:param pos: The euclidean coordinates of the nodes.
:type pos: torch.Tensor | LabelTensor
:param torch.Tensor edge_index: The edge indices.
:param vel: The velocity of the nodes.
:type vel: torch.Tensor | LabelTensor
:return: The updated node features and node positions.
:rtype: tuple(torch.Tensor, torch.Tensor)
:rtype: tuple(torch.Tensor, torch.Tensor) |
tuple(torch.Tensor, torch.Tensor, torch.Tensor)
"""
# aggregated_inputs is tuple (agg_message, agg_m_ij)
agg_message, agg_m_ij = aggregated_inputs

# Degree for normalization of position updates
c = degree(edge_index[1], pos.shape[0]).unsqueeze(-1).clamp(min=1)

# If velocity is used, update it and use it to update positions
if self.use_velocity:
vel = self.update_vel_net(x) * vel

# Update node features with aggregated m_ij
x = self.update_feat_net(torch.cat((x, agg_m_ij), dim=-1))

# Degree for normalization of position updates
c = degree(edge_index[1], pos.shape[0]).unsqueeze(-1).clamp(min=1)
pos = pos + agg_message / c
# Update positions with aggregated messages m_ij and velocities
pos = pos + agg_message / c + (vel if self.use_velocity else 0)

return x, pos
return (x, pos, vel) if self.use_velocity else (x, pos)
Original file line number Diff line number Diff line change
@@ -0,0 +1,188 @@
"""Module for the Equivariant Graph Neural Operator block."""

import torch
from ....utils import check_positive_integer
from .en_equivariant_network_block import EnEquivariantNetworkBlock


class EquivariantGraphNeuralOperatorBlock(torch.nn.Module):
"""
A single block of the Equivariant Graph Neural Operator (EGNO).

This block combines a temporal convolution with an equivariant graph neural
network (EGNN) layer. It preserves equivariance while modeling complex
interactions between nodes in a graph over time.

.. seealso::

**Original reference**
Xu, M., Han, J., Lou, A., Kossaifi, J., Ramanathan, A., Azizzadenesheli,
K., Leskovec, J., Ermon, S., Anandkumar, A. (2024).
*Equivariant Graph Neural Operator for Modeling 3D Dynamics*
DOI: `arXiv preprint arXiv:2401.11037.
<https://arxiv.org/abs/2401.11037>`_
"""

def __init__(
self,
node_feature_dim,
edge_feature_dim,
pos_dim,
modes,
hidden_dim=64,
n_message_layers=2,
n_update_layers=2,
activation=torch.nn.SiLU,
aggr="add",
node_dim=-2,
flow="source_to_target",
):
"""
Initialization of the :class:`EquivariantGraphNeuralOperatorBlock`
class.

:param int node_feature_dim: The dimension of the node features.
:param int edge_feature_dim: The dimension of the edge features.
:param int pos_dim: The dimension of the position features.
:param int modes: The number of Fourier modes to use in the temporal
convolution.
:param int hidden_dim: The dimension of the hidden features.
Default is 64.
:param int n_message_layers: The number of layers in the message
network. Default is 2.
:param int n_update_layers: The number of layers in the update network.
Default is 2.
:param torch.nn.Module activation: The activation function.
Default is :class:`torch.nn.SiLU`.
:param str aggr: The aggregation scheme to use for message passing.
Available options are "add", "mean", "min", "max", "mul".
See :class:`torch_geometric.nn.MessagePassing` for more details.
Default is "add".
:param int node_dim: The axis along which to propagate. Default is -2.
:param str flow: The direction of message passing. Available options
are "source_to_target" and "target_to_source".
The "source_to_target" flow means that messages are sent from
the source node to the target node, while the "target_to_source"
flow means that messages are sent from the target node to the
source node. See :class:`torch_geometric.nn.MessagePassing` for more
details. Default is "source_to_target".
:raises AssertionError: If ``modes`` is not a positive integer.
"""
super().__init__()

# Check consistency
check_positive_integer(modes, strict=True)

# Initialization
self.modes = modes

# Temporal convolution weights - real and imaginary parts
self.weight_scalar_r = torch.nn.Parameter(
torch.rand(node_feature_dim, node_feature_dim, modes)
)
self.weight_scalar_i = torch.nn.Parameter(
torch.rand(node_feature_dim, node_feature_dim, modes)
)
self.weight_vector_r = torch.nn.Parameter(torch.rand(2, 2, modes) * 0.1)
self.weight_vector_i = torch.nn.Parameter(torch.rand(2, 2, modes) * 0.1)

# EGNN block
self.egnn = EnEquivariantNetworkBlock(
node_feature_dim=node_feature_dim,
edge_feature_dim=edge_feature_dim,
pos_dim=pos_dim,
use_velocity=True,
hidden_dim=hidden_dim,
n_message_layers=n_message_layers,
n_update_layers=n_update_layers,
activation=activation,
aggr=aggr,
node_dim=node_dim,
flow=flow,
)

def forward(self, x, pos, vel, edge_index, edge_attr=None):
"""
Forward pass of the Equivariant Graph Neural Operator block.

:param x: The node feature tensor of shape
``[time_steps, num_nodes, node_feature_dim]``.
:type x: torch.Tensor | LabelTensor
:param pos: The node position tensor (Euclidean coordinates) of shape
``[time_steps, num_nodes, pos_dim]``.
:type pos: torch.Tensor | LabelTensor
:param vel: The node velocity tensor of shape
``[time_steps, num_nodes, pos_dim]``.
:type vel: torch.Tensor | LabelTensor
:param edge_index: The edge connectivity of shape ``[2, num_edges]``.
:type edge_index: torch.Tensor
:param edge_attr: The edge feature tensor of shape
``[time_steps, num_edges, edge_feature_dim]``. Default is None.
:type edge_attr: torch.Tensor | LabelTensor, optional
:return: The updated node features, positions, and velocities, each with
the same shape as the inputs.
:rtype: tuple[torch.Tensor, torch.Tensor, torch.Tensor]
"""
# Prepare features
center = pos.mean(dim=1, keepdim=True)
vector = torch.stack((pos - center, vel), dim=-1)

# Compute temporal convolution
x = x + self._convolution(
x, "mni, iom -> mno", self.weight_scalar_r, self.weight_scalar_i
)
vector = vector + self._convolution(
vector,
"mndi, iom -> mndo",
self.weight_vector_r,
self.weight_vector_i,
)

# Split position and velocity
pos, vel = vector.unbind(dim=-1)
pos = pos + center

# Reshape to (time * nodes, feature) for egnn
x = x.reshape(-1, x.shape[-1])
pos = pos.reshape(-1, pos.shape[-1])
vel = vel.reshape(-1, vel.shape[-1])
if edge_attr is not None:
edge_attr = edge_attr.reshape(-1, edge_attr.shape[-1])

x, pos, vel = self.egnn(
x=x,
pos=pos,
edge_index=edge_index,
edge_attr=edge_attr,
vel=vel,
)

# Reshape back to (time, nodes, feature)
x = x.reshape(center.shape[0], -1, x.shape[-1])
pos = pos.reshape(center.shape[0], -1, pos.shape[-1])
vel = vel.reshape(center.shape[0], -1, vel.shape[-1])

return x, pos, vel

def _convolution(self, x, einsum_idx, real, img):
"""
Compute the temporal convolution.

:param torch.Tensor x: The input features.
:param str einsum_idx: The indices for the einsum operation.
:param torch.Tensor real: The real part of the convolution weights.
:param torch.Tensor img: The imaginary part of the convolution weights.
:return: The convolved features.
:rtype: torch.Tensor
"""
# Number of modes to use
modes = min(self.modes, (x.shape[0] // 2) + 1)

# Build complex weights
weights = torch.complex(real[..., :modes], img[..., :modes])

# Convolution in Fourier space
fourier = torch.fft.rfftn(x, dim=[0])[:modes]
out = torch.einsum(einsum_idx, fourier, weights)

return torch.fft.irfftn(out, s=x.shape[0], dim=0)
Loading