Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
83 changes: 82 additions & 1 deletion tensorframes/lframes/classical_lframes.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,8 @@
import torch
from e3nn.o3 import rand_matrix
from torch import Tensor
from torch_geometric.nn import knn
from torch_geometric.nn import knn, radius
from torch_geometric.utils import scatter

from tensorframes.lframes.gram_schmidt import gram_schmidt
from tensorframes.lframes.lframes import LFrames
Expand All @@ -17,6 +18,86 @@ def forward(self, *args, **kwargs) -> LFrames:
assert NotImplementedError, "Subclasses must implement this method."


class PCALFrames(LFramesPredictionModule):
    """Computes local frames using PCA.

    For every selected point the outer products of the vectors to its
    neighbours (found with a radius search) are summed, and the eigenvectors
    of that matrix are used as the rows of the local frame.
    """

    def __init__(
        self, r: float, max_num_neighbors: int = 64, exceptional_choice: str = "random"
    ) -> None:
        """Initializes an instance of the PCALFrames class.

        Args:
            r (float): The radius of the neighbourhood used for the PCA computation.
            max_num_neighbors (int, optional): The maximum number of neighbors returned
                by the radius search for each point. Defaults to 64.
            exceptional_choice (str, optional): How to fill the frame of a point that has
                at most one neighbor (the radius search may return the point itself).
                Either "random" or "zero". Defaults to "random".
        """
        super().__init__()
        self.r = r
        self.max_num_neighbors = max_num_neighbors
        self.exceptional_choice = exceptional_choice

    def forward(
        self, pos: Tensor, idx: Union[Tensor, None] = None, batch: Union[Tensor, None] = None
    ) -> LFrames:
        """Forward pass of the LFrames module.

        Args:
            pos (Tensor): The input tensor of shape (N, D) representing the positions of N points in D-dimensional space.
            idx (Tensor, optional): Selects the points for which a frame is computed (used as
                ``pos[idx]``). If None, all points are considered. Defaults to None.
            batch (Tensor, optional): The batch indices of the points. If None, a batch of zeros is used. Defaults to None.

        Returns:
            LFrames: The computed local frames, one (D, D) matrix per selected point.

        Raises:
            NotImplementedError: If ``exceptional_choice`` is neither "random" nor "zero".
        """
        if idx is None:
            idx = torch.ones(pos.shape[0], dtype=torch.bool, device=pos.device)

        if batch is None:
            batch = torch.zeros(pos.shape[0], dtype=torch.int64, device=pos.device)

        query_pos = pos[idx]
        num_queries = query_pos.shape[0]

        # `radius(x, y, ...)` pairs every point of `y` (the selected points) with the
        # points of `x` within distance r: `row` indexes `query_pos`, `col` indexes `pos`.
        row, col = radius(
            pos, query_pos, self.r, batch, batch[idx], max_num_neighbors=self.max_num_neighbors
        )
        edge_index = torch.stack([col, row], dim=0)

        # The centre index refers to `query_pos`, not to `pos`; indexing `pos` with it
        # would be wrong whenever `idx` selects a proper subset.
        edge_vec = pos[edge_index[0]] - query_pos[edge_index[1]]  # (N_edges, dim)

        # `dim_size` keeps one output row per selected point even when trailing
        # points have no edges at all.
        cov_matrices = scatter(
            edge_vec.unsqueeze(-1) * edge_vec.unsqueeze(-2),
            edge_index[1],
            dim=0,
            dim_size=num_queries,
        )

        # compute the PCA (eigh returns the eigenvectors as columns):
        _, eigenvectors = torch.linalg.eigh(cov_matrices)

        # store the eigenvectors as rows:
        eigenvectors = eigenvectors.transpose(-1, -2)  # (N, n_vec, dim_vec)

        # Eigenvectors are only defined up to sign. Fix the sign from the data by
        # making each axis point towards the side where the summed projection of
        # the edge vectors is positive.
        dots = torch.einsum("ijk,ik->ij", eigenvectors[edge_index[1]], edge_vec)
        summed_dots = scatter(dots, edge_index[1], dim=0, dim_size=num_queries)  # (N, n_vec)
        sign_mask = (summed_dots > 0).to(eigenvectors.dtype) * 2 - 1
        eigenvectors = eigenvectors * sign_mask.unsqueeze(-1)

        # check how many neighbors each point has:
        num_neighbors = scatter(
            torch.ones_like(edge_index[0]),
            edge_index[1],
            dim=0,
            dim_size=num_queries,
            reduce="sum",
        ).float()
        no_neighbors_mask = num_neighbors <= 1

        if self.exceptional_choice == "random":
            random_lframes = RandomLFrames()(query_pos[no_neighbors_mask]).matrices
            # match dtype/device so the masked assignment cannot fail on a mismatch
            eigenvectors[no_neighbors_mask] = random_lframes.to(eigenvectors)
        elif self.exceptional_choice == "zero":
            eigenvectors[no_neighbors_mask] = 0.0
        else:
            # `assert NotImplementedError, ...` is always true and never fires; raise instead.
            raise NotImplementedError(
                f"exceptional_choice {self.exceptional_choice} not implemented"
            )

        return LFrames(eigenvectors)


class ThreeNNLFrames(LFramesPredictionModule):
"""Computes local frames using the 3-nearest neighbors.

Expand Down
6 changes: 4 additions & 2 deletions tensorframes/nn/tfmessage_passing.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,9 @@ class TFMessagePassing(MessagePassing):
https://arxiv.org/abs/2405.15389v1
"""

def __init__(self, params_dict: Dict[str, Dict[str, Any]], aggr="add") -> None:
def __init__(
self, params_dict: Dict[str, Dict[str, Any]], aggr="add", *args, **kwargs
) -> None:
"""Initializes a new instance of the TFMessagePassing class.

Args:
Expand All @@ -28,7 +30,7 @@ def __init__(self, params_dict: Dict[str, Dict[str, Any]], aggr="add") -> None:
},
}
"""
super().__init__(aggr=aggr)
super().__init__(aggr=aggr, *args, **kwargs)

self.params_dict = params_dict

Expand Down
74 changes: 74 additions & 0 deletions tests/lframes/test_pca_lframes.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,74 @@
import numpy as np
import torch
from sklearn.decomposition import PCA

from tensorframes.lframes.classical_lframes import PCALFrames, RandomGlobalLFrames
from tensorframes.reps.tensorreps import TensorReps


def test_pca_lframes():
    """Checks PCALFrames for orthogonality, agreement with sklearn PCA, and equivariance."""
    # --- Part 1: every frame must be an orthogonal matrix (|det| = 1, M M^T = I) ---
    n_pts = 100
    coords = torch.rand(n_pts, 3)
    batch_ids = torch.tensor([1, 2]).repeat_interleave(n_pts // 2)

    frame_module = PCALFrames(r=1, max_num_neighbors=16)
    frames = frame_module(pos=coords, idx=None, batch=batch_ids)

    dets = torch.linalg.det(frames.matrices)
    print("dets: ", dets.mean(), dets.max(), dets.min())
    print("dets counts: ", torch.unique(torch.round(dets, decimals=2), return_counts=True))
    assert torch.allclose(dets.abs(), torch.ones(n_pts))

    identity_batch = torch.eye(3).expand(n_pts, -1, -1)
    gram = torch.bmm(frames.matrices, frames.matrices.transpose(1, 2))
    print("diff from identity: ", (gram - identity_batch).abs().max())
    assert torch.allclose(gram, identity_batch, atol=1e-5)

    # --- Part 2: compare against sklearn's PCA on a correlated Gaussian cloud ---
    n_pts = 5000
    mixing = np.random.randn(3, 3)
    samples = np.random.multivariate_normal(
        mean=np.zeros(3), cov=mixing @ mixing.T, size=n_pts
    )
    coords = torch.from_numpy(samples).float()
    # Put point 0 at the distribution mean so its neighbourhood is the whole cloud.
    coords[0] = 0.0

    frame_module = PCALFrames(r=100, max_num_neighbors=n_pts)
    frames = frame_module(pos=coords, idx=None, batch=None)
    frame_at_mean = frames.matrices[0]
    print("lframes: ", frame_at_mean)

    reference = PCA(n_components=3)
    reference.fit(coords.numpy())
    print("pca components: ", reference.components_)

    # The component order is flipped before comparing with the frame rows;
    # the element-wise ratio must be +-1 when both agree up to a sign.
    ratio = torch.from_numpy(reference.components_).flip(dims=(0,)) / frame_at_mean
    print("ratio: ", ratio)
    assert torch.allclose(ratio.abs(), torch.ones(3), atol=1e-3)

    # --- Part 3: check that the frames are equivariant under a global transformation ---
    n_pts = 100
    coords = torch.rand(n_pts, 3)
    batch_ids = torch.tensor([1, 2]).repeat_interleave(n_pts // 2)

    lframes_learner = PCALFrames(r=1, max_num_neighbors=n_pts)
    frames_before = lframes_learner(pos=coords, batch=batch_ids)

    # Transform the positions with one random global frame and recompute the frames.
    random_trafo = RandomGlobalLFrames()(pos=coords)
    vector_transform = TensorReps("1x1").get_transform_class()
    coords_rot = vector_transform(coeffs=coords, basis_change=random_trafo)
    frames_after = lframes_learner(pos=coords_rot, batch=batch_ids)

    # Expected relation: frames_after == frames_before @ R^T.
    expected = torch.bmm(frames_before.matrices, random_trafo.matrices.transpose(1, 2))
    diff = (frames_after.matrices - expected).abs()
    print("frames max diff", diff.max(), "mean diff", diff.mean())
    assert torch.allclose(frames_after.matrices, expected, atol=1e-3)


# Allow running this test file directly as a script, without a test runner.
if __name__ == "__main__":
    test_pca_lframes()