Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 10 additions & 0 deletions configs/datasets/random_dataset.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
data_domain: pointcloud
data_type: toy_dataset
data_name: random_dataset
data_dir: datasets/${data_domain}/${data_type}

# Dataset parameters
num_features: 1
num_classes: 2
task: classification
loss_type: cross_entropy
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
transform_type: 'lifting'
transform_name: "VietorisRipsLifting"
complex_dim: 2
feature_lifting: ProjectionSum
epsilon: 0.5
30 changes: 30 additions & 0 deletions modules/data/load/loaders.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
load_cell_complex_dataset,
load_hypergraph_pickle_dataset,
load_manual_graph,
load_random_points,
load_simplicial_dataset,
)

Expand Down Expand Up @@ -204,3 +205,32 @@ def load(
torch_geometric.data.Dataset object containing the loaded data.
"""
return load_hypergraph_pickle_dataset(self.parameters)


class PointCloudLoader(AbstractLoader):
r"""Loader for point-cloud dataset.

Parameters
----------
parameters: DictConfig
Configuration parameters
"""

def __init__(self, parameters: DictConfig):
super().__init__(parameters)
self.parameters = parameters

def load(self) -> torch_geometric.data.Dataset:
r"""Load point-cloud dataset.

Parameters
----------
None

Returns
-------
torch_geometric.data.Dataset
torch_geometric.data.Dataset object containing the loaded data.
"""
data = load_random_points(num_classes=self.cfg["num_classes"])
return CustomDataset([data], self.cfg["data_dir"])
27 changes: 19 additions & 8 deletions modules/data/utils/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,16 +50,16 @@ def get_complex_connectivity(complex, max_rank, signed=False):
)
except ValueError: # noqa: PERF203
if connectivity_info == "incidence":
connectivity[f"{connectivity_info}_{rank_idx}"] = (
generate_zero_sparse_connectivity(
m=practical_shape[rank_idx - 1], n=practical_shape[rank_idx]
)
connectivity[
f"{connectivity_info}_{rank_idx}"
] = generate_zero_sparse_connectivity(
m=practical_shape[rank_idx - 1], n=practical_shape[rank_idx]
)
else:
connectivity[f"{connectivity_info}_{rank_idx}"] = (
generate_zero_sparse_connectivity(
m=practical_shape[rank_idx], n=practical_shape[rank_idx]
)
connectivity[
f"{connectivity_info}_{rank_idx}"
] = generate_zero_sparse_connectivity(
m=practical_shape[rank_idx], n=practical_shape[rank_idx]
)
connectivity["shape"] = practical_shape
return connectivity
Expand Down Expand Up @@ -283,6 +283,17 @@ def load_hypergraph_pickle_dataset(cfg):
return data


def load_random_points(num_classes: int = 2, num_points: int = 8, seed: int = 128):
"""Create a toy point cloud dataset"""
rng = np.random.default_rng(seed)

points = torch.tensor(rng.random((num_points, 2)), dtype=torch.float)
classes = torch.tensor(rng.integers(num_classes, size=num_points), dtype=torch.long)
features = torch.tensor(rng.integers(3, size=(num_points, 1)), dtype=torch.float)

return torch_geometric.data.Data(x=features, y=classes, pos=points)


def load_manual_graph():
"""Create a manual graph for testing purposes."""
# Define the vertices (just 8 vertices)
Expand Down
5 changes: 5 additions & 0 deletions modules/transforms/data_transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,9 @@
from modules.transforms.liftings.graph2simplicial.clique_lifting import (
SimplicialCliqueLifting,
)
from modules.transforms.liftings.pointcloud2simplicial.vietoris_rips_lifting import (
VietorisRipsLifting,
)

TRANSFORMS = {
# Graph -> Hypergraph
Expand All @@ -23,6 +26,8 @@
"SimplicialCliqueLifting": SimplicialCliqueLifting,
# Graph -> Cell Complex
"CellCycleLifting": CellCycleLifting,
# Point Cloud -> Simplicial Complex
"VietorisRipsLifting": VietorisRipsLifting,
# Feature Liftings
"ProjectionSum": ProjectionSum,
# Data Manipulations
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,101 @@
from itertools import combinations

import torch
import torch_geometric
from toponetx.classes import Simplex, SimplicialComplex

from modules.data.utils.utils import get_complex_connectivity
from modules.transforms.liftings.pointcloud2simplicial.base import (
PointCloud2SimplicialLifting,
)


class VietorisRipsLifting(PointCloud2SimplicialLifting):
"""Lifts point cloud data to a Vietoris-Rips Complex. It works
by creating a 1-simplex between any two points if their distance
is less than or equal to epsilon. It then creates an n-simplex if
every pair of its n+1 vertices is connected by a 1-simplex.

"""

def __init__(self, epsilon: float, **kwargs):
assert epsilon > 0

self.epsilon = epsilon
super().__init__(**kwargs)

def _get_lifted_topology(self, simplicial_complex: SimplicialComplex) -> dict:
r"""Returns the lifted topology.

Parameters
----------
simplicial_complex : SimplicialComplex
The simplicial complex.

Returns
-------
dict
The lifted topology.
"""
lifted_topology = get_complex_connectivity(
simplicial_complex, simplicial_complex.maxdim
)

lifted_topology["x_0"] = torch.stack(
list(simplicial_complex.get_simplex_attributes("features", 0).values())
)

return lifted_topology

def lift_topology(self, data: torch_geometric.data.Data) -> dict:
"""
Applies Vietoris-Rips lifting strategy to point cloud.

Parameters
----------
data : torch_geometric.data.Data
The input data to be lifted.

Returns
-------
dict
The lifted topology.
"""
points = data.pos

# Calculate pairwise distance matrix between points.
distance_matrix = torch.cdist(points, points)

n = len(points)

# Add 0-simplices (vertices) with their associated features.
simplices = [Simplex([i], features=data.x[i]) for i in range(n)]

# Add 1-simplices (edges) where the pairwise distance between
# points are less than epsilon
edges = [[i, j] for i in range(n) for j in range(i + 1, n) if distance_matrix[i, j] <= self.epsilon]
simplices.extend(Simplex(edge) for edge in edges)

# Step 3: Construct higher-dimensional simplices
# Iteratively finds all k-dimensional simplices (starting from k = 2) that can be formed in the graph.
k = 2
while True:
higher_dim_simplices = [
Simplex(list(simplex))
for simplex in combinations(range(n), k + 1)
if all(
([simplex[i], simplex[j]] in edges or [simplex[j], simplex[i]] in edges)
for i in range(k)
for j in range(i + 1, k + 1)
)
]

if not higher_dim_simplices:
break

simplices.extend(higher_dim_simplices)
k += 1

SC = SimplicialComplex(simplices)

return self._get_lifted_topology(SC)
Original file line number Diff line number Diff line change
@@ -0,0 +1,74 @@
import unittest

import torch
from torch_geometric.data import Data

from modules.transforms.liftings.pointcloud2simplicial.vietoris_rips_lifting import (
VietorisRipsLifting,
)


class TestVietorisRipsLifting(unittest.TestCase):
def setUp(self):
# Set up some basic point cloud data for testing
self.points = torch.tensor(
[[0.0, 0.0], [1.0, 0.0], [0.0, 1.0], [1.0, 1.0]], dtype=torch.float
)
self.features = torch.tensor([[1], [2], [3], [4]], dtype=torch.float)
self.data = Data(pos=self.points, x=self.features)
self.epsilon = 1.5 # Set epsilon distance

def test_initialization(self):
# Test initialization
lifting = VietorisRipsLifting(epsilon=self.epsilon)
self.assertEqual(lifting.epsilon, self.epsilon)

def test_lift_topology(self):
# Test lift_topology method
lifting = VietorisRipsLifting(epsilon=self.epsilon)
lifted_topology = lifting.lift_topology(self.data)

# Check if the lifted topology contains expected keys
self.assertIn("x_0", lifted_topology)

# Check if the number of vertices matches the input points
self.assertEqual(lifted_topology["shape"][0], len(self.points))

# Check if the features are correctly assigned
for i, feature in enumerate(self.features):
self.assertTrue(torch.equal(lifted_topology["x_0"][i], feature))

def test_lifted_topology_structure(self):
# Check the structure of the lifted topology
lifting_tiny_epsilon = VietorisRipsLifting(epsilon=0.5)
lifted_topology_tiny = lifting_tiny_epsilon.lift_topology(self.data)

# Ensure the output is a dictionary
self.assertIsInstance(lifted_topology_tiny, dict)

self.assertEqual(lifted_topology_tiny["shape"], [4])

lifting_small_epsilon = VietorisRipsLifting(epsilon=1)
lifted_topology_small = lifting_small_epsilon.lift_topology(self.data)
self.assertEqual(lifted_topology_small["shape"], [4, 4])

lifting_large_epsilon = VietorisRipsLifting(epsilon=1.5)
lifted_topology_large = lifting_large_epsilon.lift_topology(self.data)
self.assertEqual(lifted_topology_large["shape"], [4, 6, 4, 1])

def test_epsilon_effect(self):
# Test the effect of different epsilon values
lifting_small_epsilon = VietorisRipsLifting(epsilon=1)
lifted_topology_small = lifting_small_epsilon.lift_topology(self.data)
simplices_count_small = sum(lifted_topology_small["shape"])

lifting_large_epsilon = VietorisRipsLifting(epsilon=1.5)
lifted_topology_large = lifting_large_epsilon.lift_topology(self.data)
simplices_count_large = sum(lifted_topology_large["shape"])

# With a smaller epsilon, the total number of simplic
self.assertGreater(simplices_count_large, simplices_count_small)


if __name__ == "__main__":
unittest.main()
Loading