-
Notifications
You must be signed in to change notification settings - Fork 84
EGNO #602
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Merged
Merged
EGNO #602
Changes from all commits
Commits
File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Some comments aren't visible on the classic Files Changed page.
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
7 changes: 7 additions & 0 deletions
7
...ce/_rst/model/block/message_passing/equivariant_graph_neural_operator_block.rst
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,7 @@ | ||
| EquivariantGraphNeuralOperatorBlock | ||
| ===================================== | ||
| .. currentmodule:: pina.model.block.message_passing.equivariant_graph_neural_operator_block | ||
|
|
||
| .. autoclass:: EquivariantGraphNeuralOperatorBlock | ||
| :members: | ||
| :show-inheritance: |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,7 @@ | ||
| EquivariantGraphNeuralOperator | ||
| ================================= | ||
| .. currentmodule:: pina.model.equivariant_graph_neural_operator | ||
|
|
||
| .. autoclass:: EquivariantGraphNeuralOperator | ||
| :members: | ||
| :show-inheritance: |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
188 changes: 188 additions & 0 deletions
188
pina/model/block/message_passing/equivariant_graph_neural_operator_block.py
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,188 @@ | ||
| """Module for the Equivariant Graph Neural Operator block.""" | ||
|
|
||
| import torch | ||
| from ....utils import check_positive_integer | ||
| from .en_equivariant_network_block import EnEquivariantNetworkBlock | ||
|
|
||
|
|
||
| class EquivariantGraphNeuralOperatorBlock(torch.nn.Module): | ||
| """ | ||
| A single block of the Equivariant Graph Neural Operator (EGNO). | ||
|
|
||
| This block combines a temporal convolution with an equivariant graph neural | ||
| network (EGNN) layer. It preserves equivariance while modeling complex | ||
| interactions between nodes in a graph over time. | ||
|
|
||
| .. seealso:: | ||
|
|
||
| **Original reference** | ||
| Xu, M., Han, J., Lou, A., Kossaifi, J., Ramanathan, A., Azizzadenesheli, | ||
| K., Leskovec, J., Ermon, S., Anandkumar, A. (2024). | ||
| *Equivariant Graph Neural Operator for Modeling 3D Dynamics* | ||
| DOI: `arXiv preprint arXiv:2401.11037. | ||
| <https://arxiv.org/abs/2401.11037>`_ | ||
| """ | ||
|
|
||
| def __init__( | ||
| self, | ||
| node_feature_dim, | ||
| edge_feature_dim, | ||
| pos_dim, | ||
| modes, | ||
| hidden_dim=64, | ||
| n_message_layers=2, | ||
| n_update_layers=2, | ||
| activation=torch.nn.SiLU, | ||
| aggr="add", | ||
| node_dim=-2, | ||
| flow="source_to_target", | ||
| ): | ||
| """ | ||
| Initialization of the :class:`EquivariantGraphNeuralOperatorBlock` | ||
| class. | ||
|
|
||
| :param int node_feature_dim: The dimension of the node features. | ||
| :param int edge_feature_dim: The dimension of the edge features. | ||
| :param int pos_dim: The dimension of the position features. | ||
| :param int modes: The number of Fourier modes to use in the temporal | ||
| convolution. | ||
| :param int hidden_dim: The dimension of the hidden features. | ||
| Default is 64. | ||
| :param int n_message_layers: The number of layers in the message | ||
| network. Default is 2. | ||
| :param int n_update_layers: The number of layers in the update network. | ||
| Default is 2. | ||
| :param torch.nn.Module activation: The activation function. | ||
| Default is :class:`torch.nn.SiLU`. | ||
| :param str aggr: The aggregation scheme to use for message passing. | ||
| Available options are "add", "mean", "min", "max", "mul". | ||
| See :class:`torch_geometric.nn.MessagePassing` for more details. | ||
| Default is "add". | ||
| :param int node_dim: The axis along which to propagate. Default is -2. | ||
| :param str flow: The direction of message passing. Available options | ||
| are "source_to_target" and "target_to_source". | ||
| The "source_to_target" flow means that messages are sent from | ||
| the source node to the target node, while the "target_to_source" | ||
| flow means that messages are sent from the target node to the | ||
| source node. See :class:`torch_geometric.nn.MessagePassing` for more | ||
| details. Default is "source_to_target". | ||
| :raises AssertionError: If ``modes`` is not a positive integer. | ||
| """ | ||
| super().__init__() | ||
|
|
||
| # Check consistency | ||
| check_positive_integer(modes, strict=True) | ||
|
|
||
| # Initialization | ||
| self.modes = modes | ||
|
|
||
| # Temporal convolution weights - real and imaginary parts | ||
| self.weight_scalar_r = torch.nn.Parameter( | ||
| torch.rand(node_feature_dim, node_feature_dim, modes) | ||
| ) | ||
| self.weight_scalar_i = torch.nn.Parameter( | ||
| torch.rand(node_feature_dim, node_feature_dim, modes) | ||
| ) | ||
| self.weight_vector_r = torch.nn.Parameter(torch.rand(2, 2, modes) * 0.1) | ||
| self.weight_vector_i = torch.nn.Parameter(torch.rand(2, 2, modes) * 0.1) | ||
|
|
||
| # EGNN block | ||
| self.egnn = EnEquivariantNetworkBlock( | ||
| node_feature_dim=node_feature_dim, | ||
| edge_feature_dim=edge_feature_dim, | ||
| pos_dim=pos_dim, | ||
| use_velocity=True, | ||
| hidden_dim=hidden_dim, | ||
| n_message_layers=n_message_layers, | ||
| n_update_layers=n_update_layers, | ||
| activation=activation, | ||
| aggr=aggr, | ||
| node_dim=node_dim, | ||
| flow=flow, | ||
| ) | ||
|
|
||
| def forward(self, x, pos, vel, edge_index, edge_attr=None): | ||
| """ | ||
| Forward pass of the Equivariant Graph Neural Operator block. | ||
|
|
||
| :param x: The node feature tensor of shape | ||
| ``[time_steps, num_nodes, node_feature_dim]``. | ||
| :type x: torch.Tensor | LabelTensor | ||
| :param pos: The node position tensor (Euclidean coordinates) of shape | ||
| ``[time_steps, num_nodes, pos_dim]``. | ||
| :type pos: torch.Tensor | LabelTensor | ||
| :param vel: The node velocity tensor of shape | ||
| ``[time_steps, num_nodes, pos_dim]``. | ||
| :type vel: torch.Tensor | LabelTensor | ||
| :param edge_index: The edge connectivity of shape ``[2, num_edges]``. | ||
| :type edge_index: torch.Tensor | ||
| :param edge_attr: The edge feature tensor of shape | ||
| ``[time_steps, num_edges, edge_feature_dim]``. Default is None. | ||
| :type edge_attr: torch.Tensor | LabelTensor, optional | ||
| :return: The updated node features, positions, and velocities, each with | ||
| the same shape as the inputs. | ||
| :rtype: tuple[torch.Tensor, torch.Tensor, torch.Tensor] | ||
| """ | ||
| # Prepare features | ||
| center = pos.mean(dim=1, keepdim=True) | ||
| vector = torch.stack((pos - center, vel), dim=-1) | ||
|
|
||
| # Compute temporal convolution | ||
| x = x + self._convolution( | ||
| x, "mni, iom -> mno", self.weight_scalar_r, self.weight_scalar_i | ||
| ) | ||
| vector = vector + self._convolution( | ||
| vector, | ||
| "mndi, iom -> mndo", | ||
| self.weight_vector_r, | ||
| self.weight_vector_i, | ||
| ) | ||
|
|
||
| # Split position and velocity | ||
| pos, vel = vector.unbind(dim=-1) | ||
| pos = pos + center | ||
|
|
||
| # Reshape to (time * nodes, feature) for egnn | ||
| x = x.reshape(-1, x.shape[-1]) | ||
| pos = pos.reshape(-1, pos.shape[-1]) | ||
| vel = vel.reshape(-1, vel.shape[-1]) | ||
| if edge_attr is not None: | ||
| edge_attr = edge_attr.reshape(-1, edge_attr.shape[-1]) | ||
|
|
||
| x, pos, vel = self.egnn( | ||
| x=x, | ||
| pos=pos, | ||
| edge_index=edge_index, | ||
| edge_attr=edge_attr, | ||
| vel=vel, | ||
| ) | ||
|
|
||
| # Reshape back to (time, nodes, feature) | ||
| x = x.reshape(center.shape[0], -1, x.shape[-1]) | ||
| pos = pos.reshape(center.shape[0], -1, pos.shape[-1]) | ||
| vel = vel.reshape(center.shape[0], -1, vel.shape[-1]) | ||
|
|
||
| return x, pos, vel | ||
|
|
||
| def _convolution(self, x, einsum_idx, real, img): | ||
GiovanniCanali marked this conversation as resolved.
Show resolved
Hide resolved
|
||
| """ | ||
| Compute the temporal convolution. | ||
|
|
||
| :param torch.Tensor x: The input features. | ||
| :param str einsum_idx: The indices for the einsum operation. | ||
| :param torch.Tensor real: The real part of the convolution weights. | ||
| :param torch.Tensor img: The imaginary part of the convolution weights. | ||
| :return: The convolved features. | ||
| :rtype: torch.Tensor | ||
| """ | ||
| # Number of modes to use | ||
| modes = min(self.modes, (x.shape[0] // 2) + 1) | ||
|
|
||
| # Build complex weights | ||
| weights = torch.complex(real[..., :modes], img[..., :modes]) | ||
|
|
||
| # Convolution in Fourier space | ||
| fourier = torch.fft.rfftn(x, dim=[0])[:modes] | ||
| out = torch.einsum(einsum_idx, fourier, weights) | ||
|
|
||
| return torch.fft.irfftn(out, s=x.shape[0], dim=0) | ||
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Uh oh!
There was an error while loading. Please reload this page.