diff --git a/tensorframes/lframes/classical_lframes.py b/tensorframes/lframes/classical_lframes.py index 6f198e0..0ca99c2 100644 --- a/tensorframes/lframes/classical_lframes.py +++ b/tensorframes/lframes/classical_lframes.py @@ -3,7 +3,8 @@ import torch from e3nn.o3 import rand_matrix from torch import Tensor -from torch_geometric.nn import knn +from torch_geometric.nn import knn, radius +from torch_geometric.utils import scatter from tensorframes.lframes.gram_schmidt import gram_schmidt from tensorframes.lframes.lframes import LFrames @@ -17,6 +18,86 @@ def forward(self, *args, **kwargs) -> LFrames: assert NotImplementedError, "Subclasses must implement this method." +class PCALFrames(LFramesPredictionModule): + """Computes local frames using PCA.""" + + def __init__( + self, r: float, max_num_neighbors: int = 64, exceptional_choice: str = "random" + ) -> None: + """Initializes an instance of the PCALFrames class. + + Args: + radius (float): The radius for the PCA computation. + max_neighbors (int, optional): The maximum number of neighbors to consider. Defaults to 10. + exceptional_choice (str, optional): The choice for exceptional case (with zero neighbors). Defaults to "random". + """ + super().__init__() + self.r = r + self.max_num_neighbors = max_num_neighbors + self.exceptional_choice = exceptional_choice + + def forward( + self, pos: Tensor, idx: Union[Tensor, None] = None, batch: Union[Tensor, None] = None + ) -> LFrames: + """Forward pass of the LFrames module. + + Args: + pos (Tensor): The input tensor of shape (N, D) representing the positions of N points in D-dimensional space. + idx (Tensor, optional): The indices of the points to consider. If None, all points are considered. Defaults to None. + batch (Tensor, optional): The batch indices of the points. If None, a batch of zeros is used. Defaults to None. + + Returns: + LFrames: The computed local frames as an instance of the LFrames class. + """ + if idx is None: + idx = torch.ones(pos.shape[0], dtype=torch.bool, device=pos.device) + + if batch is None: + batch = torch.zeros(pos.shape[0], dtype=torch.int64, device=pos.device) + + row, col = radius( + pos, pos[idx], self.r, batch, batch[idx], max_num_neighbors=self.max_num_neighbors + ) + # print("average number of neighbors: ", len(row) / len(idx), "max_num_neighbors", self.max_num_neighbors) + edge_index = torch.stack([col, row], dim=0) + edge_vec = pos[edge_index[0]] - pos[edge_index[1]] # (N_edges, dim) + + cov_matrices = scatter( + edge_vec.unsqueeze(-1) * edge_vec.unsqueeze(-2), + edge_index[1], + dim=0, + ) + + # compute the PCA: + _, eigenvectors = torch.linalg.eigh(cov_matrices) + + # choose the directions to be o3 equivariant: + eigenvectors = eigenvectors.transpose(-1, -2) # (N, n_vec, dim_vec) + + # for each eigenvector compute average dot product with edge vectors: + dots = torch.einsum("ijk,ik->ij", eigenvectors[edge_index[1]], edge_vec) + summed_dots = scatter(dots, edge_index[1], dim=0) # (N, n_vec) + sign_mask = (summed_dots > 0).float() * 2 - 1 + eigenvectors = eigenvectors * sign_mask.unsqueeze(-1) + + # check how many neighbors each point has: + num_neighbors = scatter( + torch.ones_like(edge_index[0]), edge_index[1], dim=0, reduce="sum" + ).float() + no_neighbors_mask = num_neighbors <= 1 + if self.exceptional_choice == "random": + random_lframes = RandomLFrames()(pos[no_neighbors_mask]).matrices + eigenvectors[no_neighbors_mask] = random_lframes + elif self.exceptional_choice == "zero": + eigenvectors[no_neighbors_mask] = 0.0 + else: + assert ( + NotImplementedError + ), f"exceptional_choice {self.exceptional_choice} not implemented" + + return LFrames(eigenvectors) + + class ThreeNNLFrames(LFramesPredictionModule): """Computes local frames using the 3-nearest neighbors. diff --git a/tensorframes/nn/tfmessage_passing.py b/tensorframes/nn/tfmessage_passing.py index be89f49..28df707 100644 --- a/tensorframes/nn/tfmessage_passing.py +++ b/tensorframes/nn/tfmessage_passing.py @@ -12,7 +12,9 @@ class TFMessagePassing(MessagePassing): https://arxiv.org/abs/2405.15389v1 """ - def __init__(self, params_dict: Dict[str, Dict[str, Any]], aggr="add") -> None: + def __init__( + self, params_dict: Dict[str, Dict[str, Any]], aggr="add", *args, **kwargs + ) -> None: """Initializes a new instance of the TFMessagePassing class. Args: @@ -28,7 +30,7 @@ def __init__(self, params_dict: Dict[str, Dict[str, Any]], aggr="add") -> None: }, } """ - super().__init__(aggr=aggr) + super().__init__(aggr=aggr, *args, **kwargs) self.params_dict = params_dict diff --git a/tests/lframes/test_pca_lframes.py b/tests/lframes/test_pca_lframes.py new file mode 100644 index 0000000..813d59d --- /dev/null +++ b/tests/lframes/test_pca_lframes.py @@ -0,0 +1,74 @@ +import numpy as np +import torch +from sklearn.decomposition import PCA + +from tensorframes.lframes.classical_lframes import PCALFrames, RandomGlobalLFrames +from tensorframes.reps.tensorreps import TensorReps + + +def test_pca_lframes(): + """Tests pca based lframes.""" + num_points = 100 + pos = torch.rand(num_points, 3) + batch = torch.tensor([1, 2]).repeat_interleave(num_points // 2) + + pca_lframes = PCALFrames(r=1, max_num_neighbors=16) + lframes = pca_lframes(pos=pos, idx=None, batch=batch) + + dets = torch.linalg.det(lframes.matrices) + print("dets: ", dets.mean(), dets.max(), dets.min()) + print("dets counts: ", torch.unique(torch.round(dets, decimals=2), return_counts=True)) + assert torch.allclose(torch.linalg.det(lframes.matrices).abs(), torch.ones(num_points)) + idents = torch.bmm(lframes.matrices, lframes.matrices.transpose(1, 2)) + print("diff from identity: ", (idents - torch.eye(3).expand(num_points, -1, -1)).abs().max()) + assert torch.allclose( + torch.bmm(lframes.matrices, lframes.matrices.transpose(1, 2)), + torch.eye(3).expand(num_points, -1, -1), + atol=1e-5, + ) + + # create a test case to check against PCA: + num_points = 5000 + A = np.random.randn(3, 3) + pos = torch.from_numpy( + np.random.multivariate_normal(mean=np.zeros(3), cov=A @ A.T, size=num_points) + ).float() + pos[0] = 0.0 # the mean + + pca_lframes = PCALFrames(r=100, max_num_neighbors=num_points) + lframes = pca_lframes(pos=pos, idx=None, batch=None) + + print("lframes: ", lframes.matrices[0]) + + # check that the first lframe is the PCA frame: + pca = PCA(n_components=3) + pca.fit(pos.numpy()) + print("pca components: ", pca.components_) + + # divide them and see if they are the same up to a sign: + ratio = torch.from_numpy(pca.components_).flip(dims=(0,)) / lframes.matrices[0] + print("ratio: ", ratio) + assert torch.allclose(ratio.abs(), torch.ones(3), atol=1e-3) + + # ckeck that lframes are equivariant: + num_points = 100 + pos = torch.rand(num_points, 3) + batch = torch.tensor([1, 2]).repeat_interleave(num_points // 2) + + lframes_learner = PCALFrames(r=1, max_num_neighbors=num_points) + lframes1 = lframes_learner(pos=pos, batch=batch) + + # check that x is invariant and lframes are equivariant: + random_trafo = RandomGlobalLFrames()(pos=pos) + pos_rot = TensorReps("1x1").get_transform_class()(coeffs=pos, basis_change=random_trafo) + lframes2 = lframes_learner(pos=pos_rot, batch=batch) + + # check that lframes are equivariant: + lframes_matrices1 = torch.bmm(lframes1.matrices, random_trafo.matrices.transpose(1, 2)) + diff = (lframes2.matrices - lframes_matrices1).abs() + print("frames max diff", diff.max(), "mean diff", diff.mean()) + assert torch.allclose(lframes2.matrices, lframes_matrices1, atol=1e-3) + + +if __name__ == "__main__": + test_pca_lframes()