Skip to content

Commit ac542f6

Browse files
add egno
Co-authored-by: avisquid <[email protected]>
1 parent 87c5c6a commit ac542f6

File tree

11 files changed

+868
-39
lines changed

11 files changed

+868
-39
lines changed

docs/source/_rst/_code.rst

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -105,6 +105,7 @@ Models
105105
GraphNeuralOperator <model/graph_neural_operator.rst>
106106
GraphNeuralKernel <model/graph_neural_operator_integral_kernel.rst>
107107
PirateNet <model/pirate_network.rst>
108+
EquivariantGraphNeuralOperator <model/equivariant_graph_neural_operator.rst>
108109

109110
Blocks
110111
-------------
@@ -134,6 +135,7 @@ Message Passing
134135
E(n) Equivariant Network Block <model/block/message_passing/en_equivariant_network_block.rst>
135136
Interaction Network Block <model/block/message_passing/interaction_network_block.rst>
136137
Radial Field Network Block <model/block/message_passing/radial_field_network_block.rst>
138+
EquivariantGraphNeuralOperatorBlock <model/block/message_passing/equivariant_graph_neural_operator_block.rst>
137139

138140

139141
Reduction and Embeddings
Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,7 @@
1+
EquivariantGraphNeuralOperatorBlock
2+
=====================================
3+
.. currentmodule:: pina.model.block.message_passing.equivariant_graph_neural_operator_block
4+
5+
.. autoclass:: EquivariantGraphNeuralOperatorBlock
6+
:members:
7+
:show-inheritance:
Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,7 @@
1+
EquivariantGraphNeuralOperator
2+
=================================
3+
.. currentmodule:: pina.model.equivariant_graph_neural_operator
4+
5+
.. autoclass:: EquivariantGraphNeuralOperator
6+
:members:
7+
:show-inheritance:

pina/model/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
"Spline",
1515
"GraphNeuralOperator",
1616
"PirateNet",
17+
"EquivariantGraphNeuralOperator",
1718
]
1819

1920
from .feed_forward import FeedForward, ResidualFeedForward
@@ -26,3 +27,4 @@
2627
from .spline import Spline
2728
from .graph_neural_operator import GraphNeuralOperator
2829
from .pirate_network import PirateNet
30+
from .equivariant_graph_neural_operator import EquivariantGraphNeuralOperator

pina/model/block/message_passing/__init__.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,9 +5,13 @@
55
"DeepTensorNetworkBlock",
66
"EnEquivariantNetworkBlock",
77
"RadialFieldNetworkBlock",
8+
"EquivariantGraphNeuralOperatorBlock",
89
]
910

1011
from .interaction_network_block import InteractionNetworkBlock
1112
from .deep_tensor_network_block import DeepTensorNetworkBlock
1213
from .en_equivariant_network_block import EnEquivariantNetworkBlock
1314
from .radial_field_network_block import RadialFieldNetworkBlock
15+
from .equivariant_graph_neural_operator_block import (
16+
EquivariantGraphNeuralOperatorBlock,
17+
)

pina/model/block/message_passing/en_equivariant_network_block.py

Lines changed: 51 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
import torch
44
from torch_geometric.nn import MessagePassing
55
from torch_geometric.utils import degree
6-
from ....utils import check_positive_integer
6+
from ....utils import check_positive_integer, check_consistency
77
from ....model import FeedForward
88

99

@@ -27,6 +27,12 @@ class EnEquivariantNetworkBlock(MessagePassing):
2727
positions are updated by adding the incoming messages divided by the
2828
degree of the recipient node.
2929
30+
When velocity features are used, node velocities are passed through a small
31+
MLP to compute updates, which are then combined with the aggregated position
32+
messages. The node positions are updated both by the normalized position
33+
messages and by the updated velocities, ensuring equivariance while
34+
incorporating dynamic information.
35+
3036
.. seealso::
3137
3238
**Original reference** Satorras, V. G., Hoogeboom, E., Welling, M.
@@ -40,6 +46,7 @@ def __init__(
4046
node_feature_dim,
4147
edge_feature_dim,
4248
pos_dim,
49+
use_velocity=False,
4350
hidden_dim=64,
4451
n_message_layers=2,
4552
n_update_layers=2,
@@ -54,6 +61,8 @@ def __init__(
5461
:param int node_feature_dim: The dimension of the node features.
5562
:param int edge_feature_dim: The dimension of the edge features.
5663
:param int pos_dim: The dimension of the position features.
64+
:param bool use_velocity: Whether to use velocity features in the
65+
message passing. Default is False.
5766
:param int hidden_dim: The dimension of the hidden features.
5867
Default is 64.
5968
:param int n_message_layers: The number of layers in the message
@@ -80,6 +89,7 @@ def __init__(
8089
:raises AssertionError: If `hidden_dim` is not a positive integer.
8190
:raises AssertionError: If `n_message_layers` is not a positive integer.
8291
:raises AssertionError: If `n_update_layers` is not a positive integer.
92+
:raises AssertionError: If `use_velocity` is not a boolean.
8393
"""
8494
super().__init__(aggr=aggr, node_dim=node_dim, flow=flow)
8595

@@ -90,6 +100,10 @@ def __init__(
90100
check_positive_integer(hidden_dim, strict=True)
91101
check_positive_integer(n_message_layers, strict=True)
92102
check_positive_integer(n_update_layers, strict=True)
103+
check_consistency(use_velocity, bool)
104+
105+
# Initialization
106+
self.use_velocity = use_velocity
93107

94108
# Layer for computing the message
95109
self.message_net = FeedForward(
@@ -119,7 +133,17 @@ def __init__(
119133
func=activation,
120134
)
121135

122-
def forward(self, x, pos, edge_index, edge_attr=None):
136+
# If velocity is used, instantiate layer for velocity updates
137+
if self.use_velocity:
138+
self.update_vel_net = FeedForward(
139+
input_dimensions=node_feature_dim,
140+
output_dimensions=1,
141+
inner_size=hidden_dim,
142+
n_layers=n_update_layers,
143+
func=activation,
144+
)
145+
146+
def forward(self, x, pos, edge_index, edge_attr=None, vel=None):
123147
"""
124148
Forward pass of the block, triggering the message-passing routine.
125149
@@ -130,11 +154,19 @@ def forward(self, x, pos, edge_index, edge_attr=None):
130154
:param torch.Tensor edge_index: The edge indices.
131155
:param edge_attr: The edge attributes. Default is None.
132156
:type edge_attr: torch.Tensor | LabelTensor
157+
:param vel: The velocity of the nodes. Default is None.
158+
:type vel: torch.Tensor | LabelTensor
133159
:return: The updated node features and node positions.
134160
:rtype: tuple(torch.Tensor, torch.Tensor)
161+
:raises: ValueError: If ``use_velocity`` is True and ``vel`` is None.
135162
"""
163+
if self.use_velocity and vel is None:
164+
raise ValueError(
165+
"Velocity features are enabled, but no velocity is passed."
166+
)
167+
136168
return self.propagate(
137-
edge_index=edge_index, x=x, pos=pos, edge_attr=edge_attr
169+
edge_index=edge_index, x=x, pos=pos, edge_attr=edge_attr, vel=vel
138170
)
139171

140172
def message(self, x_i, x_j, pos_i, pos_j, edge_attr):
@@ -202,28 +234,36 @@ def aggregate(self, inputs, index, ptr=None, dim_size=None):
202234

203235
return agg_message, agg_m_ij
204236

205-
def update(self, aggregated_inputs, x, pos, edge_index):
237+
def update(self, aggregated_inputs, x, pos, edge_index, vel):
206238
"""
207-
Update the node features and the node coordinates with the received
208-
messages.
239+
Update node features, positions, and optionally velocities.
209240
210241
:param tuple(torch.Tensor) aggregated_inputs: The messages to be passed.
211242
:param x: The node features.
212243
:type x: torch.Tensor | LabelTensor
213244
:param pos: The euclidean coordinates of the nodes.
214245
:type pos: torch.Tensor | LabelTensor
215246
:param torch.Tensor edge_index: The edge indices.
247+
:param vel: The velocity of the nodes.
248+
:type vel: torch.Tensor | LabelTensor
216249
:return: The updated node features and node positions.
217-
:rtype: tuple(torch.Tensor, torch.Tensor)
250+
:rtype: tuple(torch.Tensor, torch.Tensor) |
251+
tuple(torch.Tensor, torch.Tensor, torch.Tensor)
218252
"""
219253
# aggregated_inputs is tuple (agg_message, agg_m_ij)
220254
agg_message, agg_m_ij = aggregated_inputs
221255

256+
# Degree for normalization of position updates
257+
c = degree(edge_index[1], pos.shape[0]).unsqueeze(-1).clamp(min=1)
258+
259+
# If velocity is used, update it and use it to update positions
260+
if self.use_velocity:
261+
vel = self.update_vel_net(x) * vel + agg_message / c
262+
222263
# Update node features with aggregated m_ij
223264
x = self.update_feat_net(torch.cat((x, agg_m_ij), dim=-1))
224265

225-
# Degree for normalization of position updates
226-
c = degree(edge_index[1], pos.shape[0]).unsqueeze(-1).clamp(min=1)
227-
pos = pos + agg_message / c
266+
# Update positions with aggregated messages m_ij and velocities
267+
pos = pos + agg_message / c + (vel if self.use_velocity else 0)
228268

229-
return x, pos
269+
return (x, pos, vel) if self.use_velocity else (x, pos)
Lines changed: 182 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,182 @@
1+
"""Module for the Equivariant Graph Neural Operator block."""
2+
3+
import torch
4+
from ....utils import check_positive_integer
5+
from .en_equivariant_network_block import EnEquivariantNetworkBlock
6+
7+
8+
class EquivariantGraphNeuralOperatorBlock(torch.nn.Module):
9+
"""
10+
A single block of the Equivariant Graph Neural Operator (EGNO).
11+
12+
This block combines a temporal convolution with an equivariant graph neural
13+
network (EGNN) layer. It preserves equivariance while modeling complex
14+
interactions between nodes in a graph over time.
15+
16+
.. seealso::
17+
18+
**Original reference**
19+
Xu, M., Han, J., Lou, A., Kossaifi, J., Ramanathan, A., Azizzadenesheli,
20+
K., Leskovec, J., Ermon, S., Anandkumar, A. (2024).
21+
*Equivariant Graph Neural Operator for Modeling 3D Dynamics*
22+
DOI: `arXiv preprint arXiv:2401.11037.
23+
<https://arxiv.org/abs/2401.11037>`_
24+
"""
25+
26+
def __init__(
27+
self,
28+
node_feature_dim,
29+
edge_feature_dim,
30+
pos_dim,
31+
modes,
32+
hidden_dim=64,
33+
n_message_layers=2,
34+
n_update_layers=2,
35+
activation=torch.nn.SiLU,
36+
aggr="add",
37+
node_dim=-2,
38+
flow="source_to_target",
39+
):
40+
"""
41+
Initialization of the :class:`EquivariantGraphNeuralOperatorBlock`
42+
class.
43+
44+
:param int node_feature_dim: The dimension of the node features.
45+
:param int edge_feature_dim: The dimension of the edge features.
46+
:param int pos_dim: The dimension of the position features.
47+
:param int modes: The number of Fourier modes to use in the temporal
48+
convolution.
49+
:param int hidden_dim: The dimension of the hidden features.
50+
Default is 64.
51+
:param int n_message_layers: The number of layers in the message
52+
network. Default is 2.
53+
:param int n_update_layers: The number of layers in the update network.
54+
Default is 2.
55+
:param torch.nn.Module activation: The activation function.
56+
Default is :class:`torch.nn.SiLU`.
57+
:param str aggr: The aggregation scheme to use for message passing.
58+
Available options are "add", "mean", "min", "max", "mul".
59+
See :class:`torch_geometric.nn.MessagePassing` for more details.
60+
Default is "add".
61+
:param int node_dim: The axis along which to propagate. Default is -2.
62+
:param str flow: The direction of message passing. Available options
63+
are "source_to_target" and "target_to_source".
64+
The "source_to_target" flow means that messages are sent from
65+
the source node to the target node, while the "target_to_source"
66+
flow means that messages are sent from the target node to the
67+
source node. See :class:`torch_geometric.nn.MessagePassing` for more
68+
details. Default is "source_to_target".
69+
:raises AssertionError: If ``modes`` is not a positive integer.
70+
"""
71+
super().__init__()
72+
73+
# Check consistency
74+
check_positive_integer(modes, strict=True)
75+
76+
# Initialization
77+
self.modes = modes
78+
79+
# Temporal convolution weights - real and imaginary parts
80+
self.weight_scalar_r = torch.nn.Parameter(
81+
torch.rand(node_feature_dim, node_feature_dim, modes)
82+
)
83+
self.weight_scalar_i = torch.nn.Parameter(
84+
torch.rand(node_feature_dim, node_feature_dim, modes)
85+
)
86+
self.weight_vector_r = torch.nn.Parameter(torch.rand(2, 2, modes) * 0.1)
87+
self.weight_vector_i = torch.nn.Parameter(torch.rand(2, 2, modes) * 0.1)
88+
89+
# EGNN block
90+
self.egnn = EnEquivariantNetworkBlock(
91+
node_feature_dim=node_feature_dim,
92+
edge_feature_dim=edge_feature_dim,
93+
pos_dim=pos_dim,
94+
use_velocity=True,
95+
hidden_dim=hidden_dim,
96+
n_message_layers=n_message_layers,
97+
n_update_layers=n_update_layers,
98+
activation=activation,
99+
aggr=aggr,
100+
node_dim=node_dim,
101+
flow=flow,
102+
)
103+
104+
def forward(self, x, pos, vel, edge_index, edge_attr=None):
105+
"""
106+
Forward pass of the Equivariant Graph Neural Operator block.
107+
108+
:param x: The node features.
109+
:type x: torch.Tensor | LabelTensor
110+
:param pos: The euclidean coordinates of the nodes.
111+
:type pos: torch.Tensor | LabelTensor
112+
:param vel: The velocity of the nodes.
113+
:type vel: torch.Tensor | LabelTensor
114+
:param torch.Tensor edge_index: The edge indices.
115+
:param edge_attr: The edge attributes. Default is None.
116+
:type edge_attr: torch.Tensor | LabelTensor
117+
:return: The updated node features, positions, and velocities.
118+
:rtype: tuple(torch.Tensor, torch.Tensor, torch.Tensor)
119+
"""
120+
# Prepare features
121+
center = pos.mean(dim=1, keepdim=True)
122+
vector = torch.stack((pos - center, vel), dim=-1)
123+
124+
# Compute temporal convolution
125+
x = x + self._convolution(
126+
x, "mni, iom -> mno", self.weight_scalar_r, self.weight_scalar_i
127+
)
128+
vector = vector + self._convolution(
129+
vector,
130+
"mndi, iom -> mndo",
131+
self.weight_vector_r,
132+
self.weight_vector_i,
133+
)
134+
135+
# Split position and velocity
136+
pos, vel = vector.unbind(dim=-1)
137+
pos = pos + center
138+
139+
# Reshape to (time * nodes, feature) for egnn
140+
x = x.reshape(-1, x.shape[-1])
141+
pos = pos.reshape(-1, pos.shape[-1])
142+
vel = vel.reshape(-1, vel.shape[-1])
143+
if edge_attr is not None:
144+
edge_attr = edge_attr.reshape(-1, edge_attr.shape[-1])
145+
146+
x, pos, vel = self.egnn(
147+
x=x,
148+
pos=pos,
149+
edge_index=edge_index,
150+
edge_attr=edge_attr,
151+
vel=vel,
152+
)
153+
154+
# Reshape back to (time, nodes, feature)
155+
x = x.reshape(center.shape[0], -1, x.shape[-1])
156+
pos = pos.reshape(center.shape[0], -1, pos.shape[-1])
157+
vel = vel.reshape(center.shape[0], -1, vel.shape[-1])
158+
159+
return x, pos, vel
160+
161+
def _convolution(self, x, einsum_idx, real, img):
162+
"""
163+
Compute the temporal convolution.
164+
165+
:param torch.Tensor x: The input features.
166+
:param str einsum_idx: The indices for the einsum operation.
167+
:param torch.Tensor real: The real part of the convolution weights.
168+
:param torch.Tensor img: The imaginary part of the convolution weights.
169+
:return: The convolved features.
170+
:rtype: torch.Tensor
171+
"""
172+
# Number of modes to use
173+
modes = min(self.modes, (x.shape[0] // 2) + 1)
174+
175+
# Build complex weights
176+
weights = torch.complex(real[..., :modes], img[..., :modes])
177+
178+
# Convolution in Fourier space
179+
fourier = torch.fft.rfftn(x, dim=[0])[:modes]
180+
out = torch.einsum(einsum_idx, fourier, weights)
181+
182+
return torch.fft.irfftn(out, s=x.shape[0], dim=0)

0 commit comments

Comments
 (0)