# --- modules/data/load/loaders.py -------------------------------------------
class PointCloudLoader(AbstractLoader):
    r"""Loader for point-cloud datasets.

    Parameters
    ----------
    parameters : DictConfig
        Configuration parameters. ``load`` reads the ``num_classes`` and
        ``data_dir`` entries.
    """

    def __init__(self, parameters: DictConfig):
        super().__init__(parameters)
        self.parameters = parameters

    def load(self) -> torch_geometric.data.Dataset:
        r"""Load the toy point-cloud dataset.

        Returns
        -------
        torch_geometric.data.Dataset
            Dataset wrapping the single generated point cloud, rooted at
            the configured ``data_dir``.
        """
        # Read the configuration through ``self.parameters``, the attribute
        # this class assigns in ``__init__``, instead of relying on an
        # attribute that only the base class may define.
        data = load_random_points(num_classes=self.parameters["num_classes"])
        return CustomDataset([data], self.parameters["data_dir"])


# --- modules/data/utils/utils.py --------------------------------------------
def load_random_points(num_classes: int = 2, num_points: int = 8, seed: int = 128):
    """Create a reproducible toy point-cloud dataset.

    Parameters
    ----------
    num_classes : int, optional
        Number of distinct labels; each point gets an integer label drawn
        from ``[0, num_classes)``.
    num_points : int, optional
        Number of points to generate.
    seed : int, optional
        Seed for the NumPy random generator, so repeated calls with the
        same arguments return the same data.

    Returns
    -------
    torch_geometric.data.Data
        Data object with ``pos`` (float tensor, ``num_points x 2``, values
        in ``[0, 1)``), ``y`` (long tensor of labels, length ``num_points``)
        and ``x`` (float tensor, ``num_points x 1``, integer-valued
        features in ``{0, 1, 2}``).
    """
    rng = np.random.default_rng(seed)

    # NOTE: the draw order (positions, labels, features) fixes the values
    # generated for a given seed; reordering these lines changes the data.
    points = torch.tensor(rng.random((num_points, 2)), dtype=torch.float)
    classes = torch.tensor(
        rng.integers(num_classes, size=num_points), dtype=torch.long
    )
    features = torch.tensor(
        rng.integers(3, size=(num_points, 1)), dtype=torch.float
    )

    return torch_geometric.data.Data(x=features, y=classes, pos=points)
# --- modules/transforms/liftings/pointcloud2simplicial/vietoris_rips_lifting.py
from itertools import combinations

import torch
import torch_geometric
from toponetx.classes import Simplex, SimplicialComplex

from modules.data.utils.utils import get_complex_connectivity
from modules.transforms.liftings.pointcloud2simplicial.base import (
    PointCloud2SimplicialLifting,
)


class VietorisRipsLifting(PointCloud2SimplicialLifting):
    """Lift point-cloud data to a Vietoris-Rips complex.

    A 1-simplex is created between any two points whose distance is less
    than or equal to ``epsilon``. An n-simplex is then created whenever
    every pair of its n+1 vertices is connected by a 1-simplex.

    Parameters
    ----------
    epsilon : float
        Distance threshold; must be strictly positive.
    **kwargs : optional
        Forwarded unchanged to the base lifting class.

    Raises
    ------
    ValueError
        If ``epsilon`` is not strictly positive.
    """

    def __init__(self, epsilon: float, **kwargs):
        # Explicit check instead of ``assert``: assertions are stripped when
        # Python runs with -O, which would let an invalid epsilon through.
        if not epsilon > 0:
            raise ValueError(f"epsilon must be strictly positive, got {epsilon!r}.")

        self.epsilon = epsilon
        super().__init__(**kwargs)

    def _get_lifted_topology(self, simplicial_complex: SimplicialComplex) -> dict:
        r"""Return the lifted topology.

        Parameters
        ----------
        simplicial_complex : SimplicialComplex
            The simplicial complex.

        Returns
        -------
        dict
            Connectivity of the complex up to its maximal dimension, plus
            ``"x_0"``: the stacked ``"features"`` attributes of the
            0-simplices.
        """
        lifted_topology = get_complex_connectivity(
            simplicial_complex, simplicial_complex.maxdim
        )

        vertex_features = simplicial_complex.get_simplex_attributes("features", 0)
        lifted_topology["x_0"] = torch.stack(list(vertex_features.values()))

        return lifted_topology

    def lift_topology(self, data: torch_geometric.data.Data) -> dict:
        """Apply the Vietoris-Rips lifting strategy to a point cloud.

        Parameters
        ----------
        data : torch_geometric.data.Data
            The input data to be lifted. ``data.pos`` holds the point
            coordinates and ``data.x`` the per-point features.

        Returns
        -------
        dict
            The lifted topology.
        """
        # NOTE(review): no dimension cap is applied in this method; simplices
        # of every dimension the epsilon-graph supports are generated. If the
        # configured ``complex_dim`` is meant to bound the complex, confirm
        # and enforce it here.
        points = data.pos
        n = len(points)

        # Pairwise distances, thresholded once and converted to plain Python
        # booleans so the loops below do not index tensors repeatedly.
        distance_matrix = torch.cdist(points, points)
        within_epsilon = (distance_matrix <= self.epsilon).tolist()

        # 0-simplices (vertices) with their associated features.
        simplices = [Simplex([i], features=data.x[i]) for i in range(n)]

        # 1-simplices (edges) wherever the pairwise distance is less than or
        # equal to epsilon. Each edge is stored as (i, j) with i < j.
        edges = [
            (i, j) for i, j in combinations(range(n), 2) if within_epsilon[i][j]
        ]
        simplices.extend(Simplex(list(edge)) for edge in edges)

        # Set for O(1) membership tests. ``combinations`` yields vertices in
        # increasing order, so every pair drawn from a candidate is already
        # in the (i, j), i < j orientation used above.
        edge_set = set(edges)

        # Higher-dimensional simplices: for k = 2, 3, ... keep every set of
        # k+1 vertices that is pairwise connected; stop at the first
        # dimension that yields none.
        k = 2
        while True:
            higher_dim_simplices = [
                Simplex(list(vertices))
                for vertices in combinations(range(n), k + 1)
                if all(pair in edge_set for pair in combinations(vertices, 2))
            ]

            if not higher_dim_simplices:
                break

            simplices.extend(higher_dim_simplices)
            k += 1

        return self._get_lifted_topology(SimplicialComplex(simplices))


# --- test/transforms/liftings/pointcloud2simplicial/test_viteoris_rips_lifting.py
import unittest

import torch
from torch_geometric.data import Data

from modules.transforms.liftings.pointcloud2simplicial.vietoris_rips_lifting import (
    VietorisRipsLifting,
)


class TestVietorisRipsLifting(unittest.TestCase):
    """Unit tests for the Vietoris-Rips lifting on the unit square."""

    def setUp(self):
        # Four corners of the unit square: sides have length 1, diagonals
        # have length sqrt(2) (about 1.414).
        self.points = torch.tensor(
            [[0.0, 0.0], [1.0, 0.0], [0.0, 1.0], [1.0, 1.0]], dtype=torch.float
        )
        self.features = torch.tensor([[1], [2], [3], [4]], dtype=torch.float)
        self.data = Data(pos=self.points, x=self.features)
        self.epsilon = 1.5  # Large enough to connect every pair of corners

    def test_initialization(self):
        # The constructor stores epsilon unchanged.
        lifting = VietorisRipsLifting(epsilon=self.epsilon)
        self.assertEqual(lifting.epsilon, self.epsilon)

    def test_invalid_epsilon(self):
        # Non-positive thresholds are rejected at construction time.
        for bad_epsilon in (0, -1.0):
            with self.assertRaises(ValueError):
                VietorisRipsLifting(epsilon=bad_epsilon)

    def test_lift_topology(self):
        lifting = VietorisRipsLifting(epsilon=self.epsilon)
        lifted_topology = lifting.lift_topology(self.data)

        # The lifted topology exposes the vertex features.
        self.assertIn("x_0", lifted_topology)

        # The number of vertices matches the number of input points.
        self.assertEqual(lifted_topology["shape"][0], len(self.points))

        # Features are assigned to the right vertices.
        for i, feature in enumerate(self.features):
            self.assertTrue(torch.equal(lifted_topology["x_0"][i], feature))

    def test_lifted_topology_structure(self):
        # epsilon below the side length: isolated vertices only.
        lifting_tiny_epsilon = VietorisRipsLifting(epsilon=0.5)
        lifted_topology_tiny = lifting_tiny_epsilon.lift_topology(self.data)

        self.assertIsInstance(lifted_topology_tiny, dict)
        self.assertEqual(lifted_topology_tiny["shape"], [4])

        # epsilon equal to the side length: the four sides, no diagonals.
        lifting_small_epsilon = VietorisRipsLifting(epsilon=1)
        lifted_topology_small = lifting_small_epsilon.lift_topology(self.data)
        self.assertEqual(lifted_topology_small["shape"], [4, 4])

        # epsilon above the diagonal length: the full 3-simplex.
        lifting_large_epsilon = VietorisRipsLifting(epsilon=1.5)
        lifted_topology_large = lifting_large_epsilon.lift_topology(self.data)
        self.assertEqual(lifted_topology_large["shape"], [4, 6, 4, 1])

    def test_epsilon_effect(self):
        lifting_small_epsilon = VietorisRipsLifting(epsilon=1)
        lifted_topology_small = lifting_small_epsilon.lift_topology(self.data)
        simplices_count_small = sum(lifted_topology_small["shape"])

        lifting_large_epsilon = VietorisRipsLifting(epsilon=1.5)
        lifted_topology_large = lifting_large_epsilon.lift_topology(self.data)
        simplices_count_large = sum(lifted_topology_large["shape"])

        # With a smaller epsilon, the total number of simplices is smaller.
        self.assertGreater(simplices_count_large, simplices_count_small)


if __name__ == "__main__":
    unittest.main()
}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "The autoreload extension is already loaded. To reload it, use:\n", + " %reload_ext autoreload\n" + ] + } + ], + "source": [ + "%load_ext autoreload\n", + "%autoreload 2\n", + "from modules.data.load.loaders import PointCloudLoader\n", + "from modules.data.preprocess.preprocessor import PreProcessor\n", + "from modules.utils.utils import (\n", + " describe_data,\n", + " load_dataset_config,\n", + " load_model_config,\n", + " load_transform_config,\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": 10, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\n", + "Dataset configuration for random_dataset:\n", + "\n", + "{'data_domain': 'pointcloud',\n", + " 'data_type': 'toy_dataset',\n", + " 'data_name': 'random_dataset',\n", + " 'data_dir': 'datasets/pointcloud/toy_dataset',\n", + " 'num_features': 1,\n", + " 'num_classes': 2,\n", + " 'task': 'classification',\n", + " 'loss_type': 'cross_entropy'}\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Processing...\n", + "Done!\n" + ] + } + ], + "source": [ + "# loader = PointCloudLoader(\n", + "# {\n", + "# \"num_classes\": 3,\n", + "# \"data_dir\": \"modules/transforms/liftings/pointcloud2simplicial/\",\n", + "# }\n", + "# )\n", + "\n", + "dataset_name = \"random_dataset\"\n", + "dataset_config = load_dataset_config(dataset_name)\n", + "loader = PointCloudLoader(dataset_config)\n", + "\n", + "dataset = loader.load()" + ], + "metadata": { + "collapsed": false, + "ExecuteTime": { + "start_time": "2024-06-23T16:44:22.905086Z", + "end_time": "2024-06-23T16:44:22.940047Z" + } + } + }, + { + "cell_type": "code", + "execution_count": 11, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\n", + "Transform configuration for pointcloud2simplicial/vietoris_rips_lifting:\n", + "\n", + "{'transform_type': 'lifting',\n", + " 'transform_name': 
'VietorisRipsLifting',\n", + " 'complex_dim': 2,\n", + " 'feature_lifting': 'ProjectionSum',\n", + " 'epsilon': 0.5}\n" + ] + } + ], + "source": [ + "transform_type = \"liftings\"\n", + "# If the transform is a topological lifting, it should include both the type of the lifting and the identifier\n", + "transform_id = \"pointcloud2simplicial/vietoris_rips_lifting\"\n", + "\n", + "# Read yaml file\n", + "transform_config = {\n", + " \"lifting\": load_transform_config(transform_type, transform_id)\n", + " # other transforms (e.g. data manipulations, feature liftings) can be added here\n", + "}" + ], + "metadata": { + "collapsed": false, + "ExecuteTime": { + "start_time": "2024-06-23T16:44:34.781062Z", + "end_time": "2024-06-23T16:44:34.811609Z" + } + } + }, + { + "cell_type": "code", + "execution_count": 12, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Processing...\n", + "/Users/elphicm/mambaforge/envs/topox/lib/python3.11/site-packages/scipy/sparse/_index.py:143: SparseEfficiencyWarning: Changing the sparsity structure of a csr_matrix is expensive. lil_matrix is more efficient.\n", + " self._set_arrayXarray(i, j, x)\n", + "Done!\n" + ] + } + ], + "source": [ + "lifted_dataset = PreProcessor(dataset, transform_config, loader.parameters[\"data_dir\"])" + ], + "metadata": { + "collapsed": false, + "ExecuteTime": { + "start_time": "2024-06-23T16:44:35.715747Z", + "end_time": "2024-06-23T16:44:35.762659Z" + } + } + }, + { + "cell_type": "code", + "execution_count": 13, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\n", + "Dataset only contains 1 sample:\n" + ] + }, + { + "data": { + "text/plain": "
", + "image/png": "" + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + " - The complex has 8 0-cells.\n", + " - The 0-cells have features dimension 1\n", + " - The complex has 18 1-cells.\n", + " - The 1-cells have features dimension 1\n", + " - The complex has 17 2-cells.\n", + " - The 2-cells have features dimension 1\n", + " - The complex has 7 3-cells.\n", + " - The 3-cells have features dimension 1\n", + " - The complex has 1 4-cells.\n", + " - The 4-cells have features dimension 1\n", + "\n" + ] + } + ], + "source": [ + "describe_data(lifted_dataset)" + ], + "metadata": { + "collapsed": false, + "ExecuteTime": { + "start_time": "2024-06-23T16:44:36.383385Z", + "end_time": "2024-06-23T16:44:36.778717Z" + } + } + }, + { + "cell_type": "code", + "execution_count": 14, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\n", + "Model configuration for simplicial SAN:\n", + "\n", + "{'in_channels': None,\n", + " 'hidden_channels': 32,\n", + " 'out_channels': None,\n", + " 'n_layers': 2,\n", + " 'n_filters': 2,\n", + " 'order_harmonic': 5,\n", + " 'epsilon_harmonic': 0.1}\n" + ] + } + ], + "source": [ + "from modules.models.simplicial.san import SANModel\n", + "\n", + "model_type = \"simplicial\"\n", + "model_id = \"san\"\n", + "model_config = load_model_config(model_type, model_id)\n", + "\n", + "model = SANModel(model_config, dataset_config)" + ], + "metadata": { + "collapsed": false, + "ExecuteTime": { + "start_time": "2024-06-23T16:44:37.431847Z", + "end_time": "2024-06-23T16:44:37.464032Z" + } + } + }, + { + "cell_type": "code", + "execution_count": 17, + "outputs": [ + { + "data": { + "text/plain": "tensor([[0.5710, 0.5974],\n [0.5751, 0.5552],\n [0.5592, 0.6042],\n [0.5718, 0.6019],\n [0.5619, 0.6089],\n [0.5743, 0.5460],\n [0.5607, 0.6206],\n [0.5710, 0.5974],\n [0.5678, 0.5981],\n [0.5716, 0.5490],\n [0.5647, 0.5582],\n [0.5831, 0.5089],\n [0.5751, 
0.5552],\n [0.5797, 0.5526],\n [0.5592, 0.6042],\n [0.5548, 0.6062],\n [0.5681, 0.5536],\n [0.5619, 0.6089]], grad_fn=)" + }, + "execution_count": 17, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "y_hat = model(lifted_dataset.get(0))\n", + "\n", + "y_hat" + ], + "metadata": { + "collapsed": false, + "ExecuteTime": { + "start_time": "2024-06-23T16:45:16.116100Z", + "end_time": "2024-06-23T16:45:16.145987Z" + } + } + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 2 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython2", + "version": "2.7.6" + } + }, + "nbformat": 4, + "nbformat_minor": 0 +}