Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions configs/datasets/sphere_point_cloud.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
data_domain: pointcloud
data_type: toy_dataset
data_name: sphere_point_cloud
num_features: 1
num_classes: 2
data_dir: datasets/${data_domain}/${data_type}
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
transform_type: 'lifting'
transform_name: "BallPivotingLifting"
# complex_dim: 3
# preserve_edge_attr: False
radii: [0.1, 0.5]
feature_lifting: ProjectionSum
68 changes: 68 additions & 0 deletions modules/data/load/loaders.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import os

import matplotlib.pyplot as plt
import numpy as np
import rootutils
import torch_geometric
Expand All @@ -13,6 +14,7 @@
load_hypergraph_pickle_dataset,
load_manual_graph,
load_simplicial_dataset,
load_sphere_point_cloud,
)


Expand Down Expand Up @@ -204,3 +206,69 @@ def load(
torch_geometric.data.Dataset object containing the loaded data.
"""
return load_hypergraph_pickle_dataset(self.parameters)


class SpherePointCloudLoader(AbstractLoader):
r"""Loader for the sphere pointcloud dataset.

Parameters
----------
parameters : DictConfig
Configuration parameters.
"""

def __init__(self, parameters: DictConfig):
super().__init__(parameters)
self.parameters = parameters

def load(
self,
) -> torch_geometric.data.Dataset:
r"""Load simplicial dataset.

Parameters
----------
None

Returns
-------
torch_geometric.data.Dataset
torch_geometric.data.Dataset object containing the loaded data.
"""
# Define the path to the data directory
root_folder = rootutils.find_root()
root_data_dir = os.path.join(root_folder, self.parameters["data_dir"])
self.data_dir = os.path.join(root_data_dir, self.parameters["data_name"])

function_params = {
"num_classes": self.parameters.get("num_classes", 2),
"num_points": self.parameters.get("num_points", 1000),
"num_features": self.parameters.get("num_features", 1),
"seed": self.parameters.get("seed", 0)
}
return load_sphere_point_cloud(**function_params)

def plot_point(
self,
data: torch_geometric.data.Data
) -> None:
r"""Plot 3d point cloud dataset.

Parameters
----------
data: torch_geometric.data.Data | dict
The input data to be plotted.

Returns
-------
None
"""
x = np.asarray(data.pos)[:, 0]
y = np.asarray(data.pos)[:, 1]
z = np.asarray(data.pos)[:, 2]

fig = plt.figure()
ax = fig.add_subplot(111, projection="3d")
ax.scatter(x, y, z)
plt.show()
return
18 changes: 18 additions & 0 deletions modules/data/utils/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -421,3 +421,21 @@ def make_hash(o):
hash_as_hex = sha1.hexdigest()
# Convert the hex back to int and restrict it to the relevant int range
return int(hash_as_hex, 16) % 4294967295


def load_sphere_point_cloud(num_classes: int = 2, num_points: int = 1000, num_features: int = 1, seed: int = 0):
"""Create a point cloud dataset in the shape of a sphere"""

# Generate random points from a normal distribution
points = torch.randn(num_points, 3)
# Normalize each point to lie on the surface of a unit sphere
points = points / points.norm(dim=1, keepdim=True)
# Generate the normals
normals = points.clone()

rng = np.random.default_rng(seed)
# points = torch.tensor(rng.random((5, 3)), dtype=torch.float)
classes = torch.tensor(rng.integers(num_classes, size=num_points), dtype=torch.long)
features = torch.tensor(rng.integers(num_features, size=(num_points, 1)), dtype=torch.float)

return torch_geometric.data.Data(x=features, y=classes, pos=points, normals=normals)
4 changes: 4 additions & 0 deletions modules/transforms/data_transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,12 @@
from modules.transforms.liftings.graph2simplicial.clique_lifting import (
SimplicialCliqueLifting,
)
from modules.transforms.liftings.pointcloud2simplicial.ball_pivoting_lifting import (
BallPivotingLifting,
)

TRANSFORMS = {
"BallPivotingLifting": BallPivotingLifting,
# Graph -> Hypergraph
"HypergraphKNNLifting": HypergraphKNNLifting,
# Graph -> Simplicial Complex
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,93 @@
import warnings

import numpy as np
import open3d
import torch_geometric
from toponetx.classes import SimplicialComplex

from modules.data.utils.utils import get_complex_connectivity
from modules.transforms.liftings.pointcloud2simplicial.base import (
PointCloud2SimplicialLifting,
)


class BallPivotingLifting(PointCloud2SimplicialLifting):
"""Uses the Ball Pivoting Algorithm to lift an input point cloud to a simplical complex.
Parameters
----------
data : torch_geometric.data.Data | dict
The input data to be lifted.

radii : list[float]
The radii of the balls used in the algorithm.
Returns
-------
torch_geometric.data.Data | dict
The lifted data."""

def __init__(self, radii: list[float], **kwargs):
super().__init__(**kwargs)
self.radii = radii

def lift_topology(self, data: torch_geometric.data.Data) -> dict:
# Convert input data into an open3d point cloud
open3d_point_cloud = open3d.geometry.PointCloud(open3d.cpu.pybind.utility.Vector3dVector(data.pos.numpy()))
# Check that the input point cloud includes normals. The Ball Pivoting Algorithm requires normals.
if "normals" not in data:
warnings.warn("Normals not found in data set. The Ball Pivoting algorithm requires oriented 3D points, thus, normals will be estimated using the 'estimate_normals' method. Note, the normals are often not estimated with great success, so the performance of the algorithm might suffer heavily from this.",
stacklevel=1)

open3d_point_cloud.estimate_normals()
else:
open3d_point_cloud.normals = open3d.cpu.pybind.utility.Vector3dVector(data.normals.numpy())

rec_mesh = open3d.geometry.TriangleMesh.create_from_point_cloud_ball_pivoting(
open3d_point_cloud, open3d.utility.DoubleVector(self.radii))

# Convert output to proper format
simplices = np.asarray(rec_mesh.triangles)

simplicial_complex = SimplicialComplex(simplices)

lifted_topology = self._get_lifted_topology(simplicial_complex)
lifted_topology["x_0"] = data.x

return lifted_topology

def _get_lifted_topology(self, simplicial_complex: SimplicialComplex) -> dict:
r"""Returns the lifted topology.
Parameters
----------
simplicial_complex : SimplicialComplex
The simplicial complex.
Returns
---------
dict
The lifted topology.
"""
return get_complex_connectivity(simplicial_complex, self.complex_dim)

def plot_lifted_topology(self, data: torch_geometric.data.Data):
r"""Plots the lifted topology.
Parameters
----------
data : torch_geometric.data.Data | dict
The input data to be lifted.
Returns
---------
None
"""
# Convert input data into an open3d point cloud
open3d_point_cloud = open3d.geometry.PointCloud(open3d.cpu.pybind.utility.Vector3dVector(data.pos.numpy()))
# Check that the input point cloud includes normals. The Ball Pivoting Algorithm requires normals.
if "normals" not in data:
warnings.warn("Normals not found in data set. The Ball Pivoting algorithm requires oriented 3D points, thus, normals will be estimated using the 'estimate_normals' method. Note, the normals are often not estimated with great success, so the performance of the algorithm might suffer heavily from this.",
stacklevel=1)

open3d_point_cloud.estimate_normals()
else:
open3d_point_cloud.normals = open3d.cpu.pybind.utility.Vector3dVector(data.normals.numpy())

rec_mesh = open3d.geometry.TriangleMesh.create_from_point_cloud_ball_pivoting(
open3d_point_cloud, open3d.utility.DoubleVector(self.radii))
open3d.visualization.draw_geometries([open3d_point_cloud, rec_mesh])
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@ dependencies=[
"spharapy",
"rich",
"rootutils",
"open3d",
"pytest",
"toponetx @ git+https://github.com/pyt-team/TopoNetX.git",
"topomodelx @ git+https://github.com/pyt-team/TopoModelX.git",
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
"""Test the message passing module."""

import torch

from modules.data.utils.utils import load_sphere_point_cloud
from modules.transforms.liftings.pointcloud2simplicial.ball_pivoting_lifting import (
BallPivotingLifting,
)


class TestBallPivotingLifting:
"""Test the SimplicialCliqueLifting class."""

def setup_method(self):
# Create the point cloud
self.sphere_pc = load_sphere_point_cloud(num_classes=2, num_points=3, num_features=1, seed=0)
radii = [0.1, 2.0]

# Initialize the BallPivotingLifting class
self.lifting = BallPivotingLifting(radii=radii)

def test_lift_topology(self):
"""Test the lift_topology method."""

# Test the lift_topology method
lifted_data = self.lifting.forward(self.sphere_pc)

expected_incidence_0_indices = torch.tensor(
[
[0, 0, 0],
[0, 1, 2]
])

assert(
torch.all(abs(expected_incidence_0_indices) == lifted_data["incidence_0"].indices())
)

expected_incidence_1_indices = torch.tensor(
[
[0, 0, 1, 1, 2, 2],
[0, 1, 0, 2, 1, 2]
])

assert(
torch.all(abs(expected_incidence_1_indices) == lifted_data["incidence_1"].indices())
), "Something is wrong with the incidence_1 matrix (nodes to edges)."

expected_incidence_2_indices = torch.tensor(
[
[0, 1, 2],
[0, 0, 0]
])

assert(
torch.all(abs(expected_incidence_2_indices) == lifted_data["incidence_2"].indices())
), "Something is wrong with the incidence_2 matrix (edges to triangles)."
259 changes: 259 additions & 0 deletions tutorials/pointcloud2simplicial/ball_pivoting.ipynb

Large diffs are not rendered by default.