|
| 1 | +"""Module for the Equivariant Graph Neural Operator block.""" |
| 2 | + |
| 3 | +import torch |
| 4 | +from ....utils import check_positive_integer |
| 5 | +from .en_equivariant_network_block import EnEquivariantNetworkBlock |
| 6 | + |
| 7 | + |
| 8 | +class EquivariantGraphNeuralOperatorBlock(torch.nn.Module): |
| 9 | + """ |
| 10 | + A single block of the Equivariant Graph Neural Operator (EGNO). |
| 11 | +
|
| 12 | + This block combines a temporal convolution with an equivariant graph neural |
| 13 | + network (EGNN) layer. It preserves equivariance while modeling complex |
| 14 | + interactions between nodes in a graph over time. |
| 15 | +
|
| 16 | + .. seealso:: |
| 17 | +
|
| 18 | + **Original reference** |
| 19 | + Xu, M., Han, J., Lou, A., Kossaifi, J., Ramanathan, A., Azizzadenesheli, |
| 20 | + K., Leskovec, J., Ermon, S., Anandkumar, A. (2024). |
| 21 | + *Equivariant Graph Neural Operator for Modeling 3D Dynamics* |
| 22 | + DOI: `arXiv preprint arXiv:2401.11037. |
| 23 | + <https://arxiv.org/abs/2401.11037>`_ |
| 24 | + """ |
| 25 | + |
| 26 | + def __init__( |
| 27 | + self, |
| 28 | + node_feature_dim, |
| 29 | + edge_feature_dim, |
| 30 | + pos_dim, |
| 31 | + modes, |
| 32 | + hidden_dim=64, |
| 33 | + n_message_layers=2, |
| 34 | + n_update_layers=2, |
| 35 | + activation=torch.nn.SiLU, |
| 36 | + aggr="add", |
| 37 | + node_dim=-2, |
| 38 | + flow="source_to_target", |
| 39 | + ): |
| 40 | + """ |
| 41 | + Initialization of the :class:`EquivariantGraphNeuralOperatorBlock` |
| 42 | + class. |
| 43 | +
|
| 44 | + :param int node_feature_dim: The dimension of the node features. |
| 45 | + :param int edge_feature_dim: The dimension of the edge features. |
| 46 | + :param int pos_dim: The dimension of the position features. |
| 47 | + :param int modes: The number of Fourier modes to use in the temporal |
| 48 | + convolution. |
| 49 | + :param int hidden_dim: The dimension of the hidden features. |
| 50 | + Default is 64. |
| 51 | + :param int n_message_layers: The number of layers in the message |
| 52 | + network. Default is 2. |
| 53 | + :param int n_update_layers: The number of layers in the update network. |
| 54 | + Default is 2. |
| 55 | + :param torch.nn.Module activation: The activation function. |
| 56 | + Default is :class:`torch.nn.SiLU`. |
| 57 | + :param str aggr: The aggregation scheme to use for message passing. |
| 58 | + Available options are "add", "mean", "min", "max", "mul". |
| 59 | + See :class:`torch_geometric.nn.MessagePassing` for more details. |
| 60 | + Default is "add". |
| 61 | + :param int node_dim: The axis along which to propagate. Default is -2. |
| 62 | + :param str flow: The direction of message passing. Available options |
| 63 | + are "source_to_target" and "target_to_source". |
| 64 | + The "source_to_target" flow means that messages are sent from |
| 65 | + the source node to the target node, while the "target_to_source" |
| 66 | + flow means that messages are sent from the target node to the |
| 67 | + source node. See :class:`torch_geometric.nn.MessagePassing` for more |
| 68 | + details. Default is "source_to_target". |
| 69 | + :raises AssertionError: If ``modes`` is not a positive integer. |
| 70 | + """ |
| 71 | + super().__init__() |
| 72 | + |
| 73 | + # Check consistency |
| 74 | + check_positive_integer(modes, strict=True) |
| 75 | + |
| 76 | + # Initialization |
| 77 | + self.modes = modes |
| 78 | + |
| 79 | + # Temporal convolution weights - real and imaginary parts |
| 80 | + self.weight_scalar_r = torch.nn.Parameter( |
| 81 | + torch.rand(node_feature_dim, node_feature_dim, modes) |
| 82 | + ) |
| 83 | + self.weight_scalar_i = torch.nn.Parameter( |
| 84 | + torch.rand(node_feature_dim, node_feature_dim, modes) |
| 85 | + ) |
| 86 | + self.weight_vector_r = torch.nn.Parameter(torch.rand(2, 2, modes) * 0.1) |
| 87 | + self.weight_vector_i = torch.nn.Parameter(torch.rand(2, 2, modes) * 0.1) |
| 88 | + |
| 89 | + # EGNN block |
| 90 | + self.egnn = EnEquivariantNetworkBlock( |
| 91 | + node_feature_dim=node_feature_dim, |
| 92 | + edge_feature_dim=edge_feature_dim, |
| 93 | + pos_dim=pos_dim, |
| 94 | + use_velocity=True, |
| 95 | + hidden_dim=hidden_dim, |
| 96 | + n_message_layers=n_message_layers, |
| 97 | + n_update_layers=n_update_layers, |
| 98 | + activation=activation, |
| 99 | + aggr=aggr, |
| 100 | + node_dim=node_dim, |
| 101 | + flow=flow, |
| 102 | + ) |
| 103 | + |
| 104 | + def forward(self, x, pos, vel, edge_index, edge_attr=None): |
| 105 | + """ |
| 106 | + Forward pass of the Equivariant Graph Neural Operator block. |
| 107 | +
|
| 108 | + :param x: The node features. |
| 109 | + :type x: torch.Tensor | LabelTensor |
| 110 | + :param pos: The euclidean coordinates of the nodes. |
| 111 | + :type pos: torch.Tensor | LabelTensor |
| 112 | + :param vel: The velocity of the nodes. |
| 113 | + :type vel: torch.Tensor | LabelTensor |
| 114 | + :param torch.Tensor edge_index: The edge indices. |
| 115 | + :param edge_attr: The edge attributes. Default is None. |
| 116 | + :type edge_attr: torch.Tensor | LabelTensor |
| 117 | + :return: The updated node features, positions, and velocities. |
| 118 | + :rtype: tuple(torch.Tensor, torch.Tensor, torch.Tensor) |
| 119 | + """ |
| 120 | + # Prepare features |
| 121 | + center = pos.mean(dim=1, keepdim=True) |
| 122 | + vector = torch.stack((pos - center, vel), dim=-1) |
| 123 | + |
| 124 | + # Compute temporal convolution |
| 125 | + x = x + self._convolution( |
| 126 | + x, "mni, iom -> mno", self.weight_scalar_r, self.weight_scalar_i |
| 127 | + ) |
| 128 | + vector = vector + self._convolution( |
| 129 | + vector, |
| 130 | + "mndi, iom -> mndo", |
| 131 | + self.weight_vector_r, |
| 132 | + self.weight_vector_i, |
| 133 | + ) |
| 134 | + |
| 135 | + # Split position and velocity |
| 136 | + pos, vel = vector.unbind(dim=-1) |
| 137 | + pos = pos + center |
| 138 | + |
| 139 | + # Reshape to (time * nodes, feature) for egnn |
| 140 | + x = x.reshape(-1, x.shape[-1]) |
| 141 | + pos = pos.reshape(-1, pos.shape[-1]) |
| 142 | + vel = vel.reshape(-1, vel.shape[-1]) |
| 143 | + if edge_attr is not None: |
| 144 | + edge_attr = edge_attr.reshape(-1, edge_attr.shape[-1]) |
| 145 | + |
| 146 | + x, pos, vel = self.egnn( |
| 147 | + x=x, |
| 148 | + pos=pos, |
| 149 | + edge_index=edge_index, |
| 150 | + edge_attr=edge_attr, |
| 151 | + vel=vel, |
| 152 | + ) |
| 153 | + |
| 154 | + # Reshape back to (time, nodes, feature) |
| 155 | + x = x.reshape(center.shape[0], -1, x.shape[-1]) |
| 156 | + pos = pos.reshape(center.shape[0], -1, pos.shape[-1]) |
| 157 | + vel = vel.reshape(center.shape[0], -1, vel.shape[-1]) |
| 158 | + |
| 159 | + return x, pos, vel |
| 160 | + |
| 161 | + def _convolution(self, x, einsum_idx, real, img): |
| 162 | + """ |
| 163 | + Compute the temporal convolution. |
| 164 | +
|
| 165 | + :param torch.Tensor x: The input features. |
| 166 | + :param str einsum_idx: The indices for the einsum operation. |
| 167 | + :param torch.Tensor real: The real part of the convolution weights. |
| 168 | + :param torch.Tensor img: The imaginary part of the convolution weights. |
| 169 | + :return: The convolved features. |
| 170 | + :rtype: torch.Tensor |
| 171 | + """ |
| 172 | + # Number of modes to use |
| 173 | + modes = min(self.modes, (x.shape[0] // 2) + 1) |
| 174 | + |
| 175 | + # Build complex weights |
| 176 | + weights = torch.complex(real[..., :modes], img[..., :modes]) |
| 177 | + |
| 178 | + # Convolution in Fourier space |
| 179 | + fourier = torch.fft.rfftn(x, dim=[0])[:modes] |
| 180 | + out = torch.einsum(einsum_idx, fourier, weights) |
| 181 | + |
| 182 | + return torch.fft.irfftn(out, s=x.shape[0], dim=0) |
0 commit comments