diff --git a/.github/workflows/black.yml b/.github/workflows/black.yml new file mode 100644 index 0000000..b04fb15 --- /dev/null +++ b/.github/workflows/black.yml @@ -0,0 +1,10 @@ +name: Lint + +on: [push, pull_request] + +jobs: + lint: + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v2 + - uses: psf/black@stable diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..3f0c0ab --- /dev/null +++ b/.gitignore @@ -0,0 +1,171 @@ +# Byte-compiled / optimized / DLL files +__pycache__/ +*.py[cod] +*$py.class + +# C extensions +*.so + +# Distribution / packaging +.Python +build/ +develop-eggs/ +dist/ +downloads/ +eggs/ +.eggs/ +lib/ +lib64/ +parts/ +sdist/ +var/ +wheels/ +share/python-wheels/ +*.egg-info/ +.installed.cfg +*.egg +MANIFEST + +# PyInstaller +# Usually these files are written by a python script from a template +# before PyInstaller builds the exe, so as to inject date/other infos into it. +*.manifest +*.spec + +# Installer logs +pip-log.txt +pip-delete-this-directory.txt + +# Unit test / coverage reports +htmlcov/ +.tox/ +.nox/ +.coverage +.coverage.* +.cache +nosetests.xml +coverage.xml +*.cover +*.py,cover +.hypothesis/ +.pytest_cache/ +cover/ + +# Translations +*.mo +*.pot + +# Django stuff: +*.log +local_settings.py +db.sqlite3 +db.sqlite3-journal + +# Flask stuff: +instance/ +.webassets-cache + +# Scrapy stuff: +.scrapy + +# Sphinx documentation +docs/_build/ +docs/build/ + +# PyBuilder +.pybuilder/ +target/ + +# Jupyter Notebook +.ipynb_checkpoints + +# IPython +profile_default/ +ipython_config.py + +# pyenv +# For a library or package, you might want to ignore these files since the code is +# intended to run in multiple environments; otherwise, check them in: +# .python-version + +# pipenv +# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. +# However, in case of collaboration, if having platform-specific dependencies or dependencies +# having no cross-platform support, pipenv may install dependencies that don't work, or not +# install all needed dependencies. +#Pipfile.lock + +# poetry +# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. +# This is especially recommended for binary packages to ensure reproducibility, and is more +# commonly ignored for libraries. +# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control +#poetry.lock + +# pdm +# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. +#pdm.lock +# pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it +# in version control. +# https://pdm.fming.dev/#use-with-ide +.pdm.toml + +# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm +__pypackages__/ + +# Celery stuff +celerybeat-schedule +celerybeat.pid + +# SageMath parsed files +*.sage.py + +# Environments +.env +.venv +env/ +venv/ +ENV/ +env.bak/ +venv.bak/ + +# Spyder project settings +.spyderproject +.spyproject + +# Rope project settings +.ropeproject + +# mkdocs documentation +/site + +# mypy +.mypy_cache/ +.dmypy.json +dmypy.json + +# Pyre type checker +.pyre/ + +# pytype static type analyzer +.pytype/ + +# Cython debug symbols +cython_debug/ + +# PyCharm +# JetBrains specific template is maintained in a separate JetBrains.gitignore that can +# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore +# and can be added to the global gitignore or merged into this file. 
For a more nuclear +# option (not recommended) you can uncomment the following to ignore the entire idea folder. +#.idea/ + +# configs/ # commented out, since new configs can be added as part of a feature + +/.idea +/data +/logs +/results_buffer +electra_pretrained.ckpt +.isort.cfg +/.vscode diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 866c153..108b91d 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -1,9 +1,25 @@ repos: -#- repo: https://github.com/PyCQA/isort -# rev: "5.12.0" -# hooks: -# - id: isort - repo: https://github.com/psf/black - rev: "22.10.0" + rev: "24.2.0" hooks: - - id: black \ No newline at end of file + - id: black + - id: black-jupyter # for formatting Jupyter notebooks + +- repo: https://github.com/pycqa/isort + rev: 5.13.2 + hooks: + - id: isort + name: isort (python) + args: ["--profile=black"] + +- repo: https://github.com/asottile/seed-isort-config + rev: v2.2.0 + hooks: + - id: seed-isort-config + +- repo: https://github.com/pre-commit/pre-commit-hooks + rev: v4.6.0 + hooks: + - id: check-yaml + - id: end-of-file-fixer + - id: trailing-whitespace diff --git a/README.md b/README.md index 6af4630..c8ce94b 100644 --- a/README.md +++ b/README.md @@ -2,13 +2,13 @@ ## Installation -Some requirements may not be installed successfully automatically. +Some requirements may not be installed successfully automatically. To install the `torch-` libraries, use `pip install torch-${lib} -f https://data.pyg.org/whl/torch-2.1.0+${CUDA}.html` where `${lib}` is either `scatter`, `geometric`, `sparse` or `cluster`, and -`${CUDA}` is either `cpu`, `cu118` or `cu121` (depending on your system, see e.g. +`${CUDA}` is either `cpu`, `cu118` or `cu121` (depending on your system, see e.g. [torch-geometric docs](https://pytorch-geometric.readthedocs.io/en/latest/install/installation.html)) @@ -31,7 +31,7 @@ We recommend the following setup: If you run the command from the `python-chebai` directory, you can use the same data for both chebai- and chebai-graph-models (e.g., Transformers and GNNs). Then you have to use `{path-to-chebai} -> .` and `{path-to-chebai-graph} -> ../python-chebai-graph`. 
- + Pretraining on an atom / bond masking task with PubChem data (feature-branch): ``` python3 -m chebai fit --model={path-to-chebai-graph}/configs/model/gnn_resgated_pretrain.yml --data={path-to-chebai-graph}/configs/data/pubchem_graph.yml --trainer={path-to-chebai}/configs/training/pretraining_trainer.yml diff --git a/chebai_graph/models/gin_net.py b/chebai_graph/models/gin_net.py index 75c2c45..6fed4c6 100644 --- a/chebai_graph/models/gin_net.py +++ b/chebai_graph/models/gin_net.py @@ -1,10 +1,11 @@ +import typing + +import torch +import torch.nn.functional as F +import torch_geometric from torch_scatter import scatter_add from chebai_graph.models.graph import GraphBaseNet -import torch_geometric -import torch.nn.functional as F -import torch -import typing class AggregateMLP(torch.nn.Module): diff --git a/chebai_graph/models/graph.py b/chebai_graph/models/graph.py index 5da9a62..7c4082a 100644 --- a/chebai_graph/models/graph.py +++ b/chebai_graph/models/graph.py @@ -1,5 +1,6 @@ import logging import typing +from abc import ABC import torch import torch.nn.functional as F @@ -15,7 +16,7 @@ logging.getLogger("pysmiles").setLevel(logging.CRITICAL) -class GraphBaseNet(ChebaiBaseNet): +class GraphBaseNet(ChebaiBaseNet, ABC): def _get_prediction_and_labels(self, data, labels, output): return torch.sigmoid(output), labels.int() @@ -84,8 +85,6 @@ class ResGatedGraphConvNetBase(GraphBaseNet): def __init__(self, config: typing.Dict, **kwargs): super().__init__(**kwargs) - - self.in_length = config["in_length"] self.hidden_length = config["hidden_length"] self.dropout_rate = config["dropout_rate"] self.n_conv_layers = config["n_conv_layers"] if "n_conv_layers" in config else 3 @@ -104,26 +103,26 @@ def __init__(self, config: typing.Dict, **kwargs): self.activation = F.elu self.dropout = nn.Dropout(self.dropout_rate) - self.convs = torch.nn.ModuleList([]) - for i in range(self.n_conv_layers): - if i == 0: - self.convs.append( - tgnn.ResGatedGraphConv( - self.n_atom_properties, - self.in_length, - # dropout=self.dropout_rate, - edge_dim=self.n_bond_properties, - ) - ) + + self.convs = torch.nn.ModuleList([]) + self.convs.append( + tgnn.ResGatedGraphConv( + self.n_atom_properties, + self.hidden_length, + # dropout=self.dropout_rate, + edge_dim=self.n_bond_properties, + ) + ) + + for _ in range(self.n_conv_layers - 1): self.convs.append( tgnn.ResGatedGraphConv( - self.in_length, self.in_length, edge_dim=self.n_bond_properties + self.hidden_length, + self.hidden_length, + # dropout=self.dropout_rate, + edge_dim=self.n_bond_properties, ) ) - self.final_conv = tgnn.ResGatedGraphConv( - self.in_length, self.hidden_length, edge_dim=self.n_bond_properties - ) def forward(self, batch): graph_data = batch["features"][0] @@ -136,11 +135,6 @@ def forward(self, batch): a = self.activation( conv(a, graph_data.edge_index.long(), edge_attr=graph_data.edge_attr) ) - a = self.activation( - self.final_conv( - a, graph_data.edge_index.long(), edge_attr=graph_data.edge_attr - ) - ) return a diff --git a/chebai_graph/preprocessing/collate.py b/chebai_graph/preprocessing/collate.py index 2c5f696..4be36cf 100644 --- a/chebai_graph/preprocessing/collate.py +++ b/chebai_graph/preprocessing/collate.py @@ -1,11 +1,11 @@ from typing import Dict import torch +from chebai.preprocessing.collate import RaggedCollator from torch_geometric.data import Data as GeomData from torch_geometric.data.collate import collate as graph_collate -from chebai_graph.preprocessing.structures import XYGraphData -from chebai.preprocessing.collate import RaggedCollator +from 
chebai_graph.preprocessing.structures import XYGraphData class GraphCollator(RaggedCollator): diff --git a/chebai_graph/preprocessing/datasets/chebi.py b/chebai_graph/preprocessing/datasets/chebi.py index f84b3a5..2721b1c 100644 --- a/chebai_graph/preprocessing/datasets/chebi.py +++ b/chebai_graph/preprocessing/datasets/chebi.py @@ -1,26 +1,27 @@ -from typing import Optional, List, Callable +import importlib +import os +from abc import ABC +from typing import Callable, List, Optional +import pandas as pd +import torch +import tqdm from chebai.preprocessing.datasets.chebi import ( ChEBIOver50, ChEBIOver100, + ChEBIOverX, ChEBIOverXPartial, ) -from chebai.preprocessing.datasets.base import XYBaseDataModule from lightning_utilities.core.rank_zero import rank_zero_info +from torch_geometric.data.data import Data as GeomData -from chebai_graph.preprocessing.reader import GraphReader, GraphPropertyReader +import chebai_graph.preprocessing.properties as graph_properties from chebai_graph.preprocessing.properties import ( AtomProperty, BondProperty, MolecularProperty, ) -import pandas as pd -from torch_geometric.data.data import Data as GeomData -import torch -import chebai_graph.preprocessing.properties as graph_properties -import importlib -import os -import tqdm +from chebai_graph.preprocessing.reader import GraphPropertyReader, GraphReader class ChEBI50GraphData(ChEBIOver50): @@ -48,7 +49,7 @@ def _resolve_property( return getattr(graph_properties, property)() -class GraphPropertiesMixIn(XYBaseDataModule): +class GraphPropertiesMixIn(ChEBIOverX, ABC): READER = GraphPropertyReader def __init__( @@ -84,9 +85,11 @@ def _setup_properties(self): for file in file_names: # processed_dir_main only exists for ChEBI datasets path = os.path.join( - self.processed_dir_main - if hasattr(self, "processed_dir_main") - else self.raw_dir, + ( + self.processed_dir_main + if hasattr(self, "processed_dir_main") + else self.raw_dir + ), file, ) raw_data += list(self._load_dict(path)) @@ -94,8 +97,8 @@ def _setup_properties(self): features = [row["features"] for row in raw_data] # use vectorized version of encode function, apply only if value is present - enc_if_not_none = ( - lambda encode, value: [encode(atom_v) for atom_v in value] + enc_if_not_none = lambda encode, value: ( + [encode(atom_v) for atom_v in value] if value is not None and len(value) > 0 else None ) @@ -104,10 +107,12 @@ def _setup_properties(self): if not os.path.isfile(self.get_property_path(property)): rank_zero_info(f"Processing property {property.name}") # read all property values first, then encode + rank_zero_info("\tReading property values...") property_values = [ self.reader.read_property(feat, property) for feat in tqdm.tqdm(features) ] + rank_zero_info("\tEncoding property values...") property.encoder.on_start(property_values=property_values) encoded_values = [ enc_if_not_none(property.encoder.encode, value) @@ -160,7 +165,11 @@ def _merge_props_into_base(self, row): if isinstance(property, AtomProperty): x = torch.cat([x, property_values], dim=1) elif isinstance(property, BondProperty): - edge_attr = torch.cat([edge_attr, property_values], dim=1) + # Concatenate/duplicate property values for the undirected graph, since `edge_index` lists all src->tgt edges first, then all tgt->src edges + edge_attr = torch.cat( + [edge_attr, torch.cat([property_values, property_values], dim=0)], + dim=1, + ) else: molecule_attr = torch.cat([molecule_attr, property_values], dim=1) return GeomData( diff --git a/chebai_graph/preprocessing/datasets/pubchem.py 
b/chebai_graph/preprocessing/datasets/pubchem.py index 210b7ab..6f5d118 100644 --- a/chebai_graph/preprocessing/datasets/pubchem.py +++ b/chebai_graph/preprocessing/datasets/pubchem.py @@ -1,6 +1,7 @@ -from chebai_graph.preprocessing.datasets.chebi import GraphPropertiesMixIn from chebai.preprocessing.datasets.pubchem import PubchemChem +from chebai_graph.preprocessing.datasets.chebi import GraphPropertiesMixIn + class PubChemGraphProperties(GraphPropertiesMixIn, PubchemChem): pass diff --git a/chebai_graph/preprocessing/properties.py b/chebai_graph/preprocessing/properties.py index 21d8342..2b3acf8 100644 --- a/chebai_graph/preprocessing/properties.py +++ b/chebai_graph/preprocessing/properties.py @@ -6,11 +6,11 @@ from descriptastorus.descriptors import rdNormalizedDescriptors from chebai_graph.preprocessing.property_encoder import ( - PropertyEncoder, - IndexEncoder, - OneHotEncoder, AsIsEncoder, BoolEncoder, + IndexEncoder, + OneHotEncoder, + PropertyEncoder, ) @@ -155,5 +155,5 @@ def get_property_value(self, mol: Chem.rdchem.Mol): features_normalized = generator_normalized.processMol( mol, Chem.MolToSmiles(mol) ) - np.nan_to_num(features_normalized, copy=False) + features_normalized = np.nan_to_num(features_normalized) return [features_normalized[1:]] diff --git a/chebai_graph/preprocessing/property_encoder.py b/chebai_graph/preprocessing/property_encoder.py index 497025c..ebfbe0c 100644 --- a/chebai_graph/preprocessing/property_encoder.py +++ b/chebai_graph/preprocessing/property_encoder.py @@ -1,8 +1,9 @@ import abc import os -import torch from typing import Optional +import torch + class PropertyEncoder(abc.ABC): def __init__(self, property, **kwargs): diff --git a/chebai_graph/preprocessing/reader.py b/chebai_graph/preprocessing/reader.py index b814d53..687b199 100644 --- a/chebai_graph/preprocessing/reader.py +++ b/chebai_graph/preprocessing/reader.py @@ -1,19 +1,17 @@ -import importlib - -from torch_geometric.utils import from_networkx -from typing import Tuple, Mapping, Optional, List +import os +from typing import List, Optional -import importlib +import chebai.preprocessing.reader as dr import networkx as nx -import os -import torch -import rdkit.Chem as Chem import pysmiles as ps -import chebai.preprocessing.reader as dr -from chebai_graph.preprocessing.collate import GraphCollator -import chebai_graph.preprocessing.properties as properties +import rdkit.Chem as Chem +import torch +from lightning_utilities.core.rank_zero import rank_zero_info, rank_zero_warn from torch_geometric.data import Data as GeomData -from lightning_utilities.core.rank_zero import rank_zero_warn, rank_zero_info +from torch_geometric.utils import from_networkx + +from chebai_graph.preprocessing import properties +from chebai_graph.preprocessing.collate import GraphCollator class GraphPropertyReader(dr.ChemDataReader): @@ -45,7 +43,7 @@ def _smiles_to_mol(self, smiles: str) -> Optional[Chem.rdchem.Mol]: try: Chem.SanitizeMol(mol) except Exception as e: - rank_zero_warn(f"Rdkit failed at sanitizing {smiles}") + rank_zero_warn(f"RDKit failed to sanitize {smiles}\nError: {e}") self.failed_counter += 1 self.mol_object_buffer[smiles] = mol return mol @@ -57,14 +55,14 @@ def _read_data(self, raw_data): x = torch.zeros((mol.GetNumAtoms(), 0)) - edge_attr = torch.zeros((mol.GetNumBonds(), 0)) + # First source to target edges, then target to source edges + src = [bond.GetBeginAtomIdx() for bond in mol.GetBonds()] + tgt = [bond.GetEndAtomIdx() for bond in mol.GetBonds()] + edge_index = torch.tensor([src + 
tgt, tgt + src], dtype=torch.long) + + # edge_index.shape == [2, num_edges]; edge_attr.shape == [num_edges, num_edge_features] + edge_attr = torch.zeros((edge_index.size(1), 0)) - edge_index = torch.tensor( - [ - [bond.GetBeginAtomIdx() for bond in mol.GetBonds()], - [bond.GetEndAtomIdx() for bond in mol.GetBonds()], - ] - ) return GeomData(x=x, edge_index=edge_index, edge_attr=edge_attr) def on_finish(self): diff --git a/chebai_graph/preprocessing/transform_unlabeled.py b/chebai_graph/preprocessing/transform_unlabeled.py index 3920659..0cc4b35 100644 --- a/chebai_graph/preprocessing/transform_unlabeled.py +++ b/chebai_graph/preprocessing/transform_unlabeled.py @@ -1,4 +1,5 @@ import random + import torch diff --git a/configs/data/chebi50_graph.yml b/configs/data/chebi50_graph.yml index 14cc489..19c8753 100644 --- a/configs/data/chebi50_graph.yml +++ b/configs/data/chebi50_graph.yml @@ -1 +1 @@ -class_path: chebai_graph.preprocessing.datasets.chebi.ChEBI50GraphData \ No newline at end of file +class_path: chebai_graph.preprocessing.datasets.chebi.ChEBI50GraphData diff --git a/configs/data/pubchem_graph.yml b/configs/data/pubchem_graph.yml index af04491..c21f188 100644 --- a/configs/data/pubchem_graph.yml +++ b/configs/data/pubchem_graph.yml @@ -16,4 +16,4 @@ init_args: - chebai_graph.preprocessing.properties.BondInRing - chebai_graph.preprocessing.properties.BondAromaticity #- chebai_graph.preprocessing.properties.MoleculeNumRings - - chebai_graph.preprocessing.properties.RDKit2DNormalized \ No newline at end of file + - chebai_graph.preprocessing.properties.RDKit2DNormalized diff --git a/configs/loss/mask_pretraining.yml b/configs/loss/mask_pretraining.yml index c677559..6d2a560 100644 --- a/configs/loss/mask_pretraining.yml +++ b/configs/loss/mask_pretraining.yml @@ -1 +1 @@ -class_path: chebai_graph.loss.pretraining.MaskPretrainingLoss \ No newline at end of file +class_path: chebai_graph.loss.pretraining.MaskPretrainingLoss diff --git a/configs/model/gnn.yml b/configs/model/gnn.yml index b0b119d..f85fa76 100644 --- a/configs/model/gnn.yml +++ b/configs/model/gnn.yml @@ -7,4 +7,4 @@ init_args: hidden_length: 512 dropout_rate: 0.1 n_conv_layers: 3 - n_linear_layers: 3 \ No newline at end of file + n_linear_layers: 3 diff --git a/configs/model/gnn_attention.yml b/configs/model/gnn_attention.yml index b1c553b..0c11ced 100644 --- a/configs/model/gnn_attention.yml +++ b/configs/model/gnn_attention.yml @@ -8,4 +8,4 @@ init_args: dropout_rate: 0.1 n_conv_layers: 5 n_linear_layers: 3 - n_heads: 5 \ No newline at end of file + n_heads: 5 diff --git a/configs/model/gnn_gine.yml b/configs/model/gnn_gine.yml index 0d0ed20..c84ea61 100644 --- a/configs/model/gnn_gine.yml +++ b/configs/model/gnn_gine.yml @@ -8,4 +8,4 @@ init_args: n_conv_layers: 5 n_linear_layers: 3 n_atom_properties: 125 - n_bond_properties: 5 \ No newline at end of file + n_bond_properties: 5 diff --git a/configs/model/gnn_res_gated.yml b/configs/model/gnn_res_gated.yml index d9ddc05..62d990d 100644 --- a/configs/model/gnn_res_gated.yml +++ b/configs/model/gnn_res_gated.yml @@ -3,11 +3,10 @@ init_args: optimizer_kwargs: lr: 1e-3 config: - in_length: 256 - hidden_length: 512 + hidden_length: 256 dropout_rate: 0.1 n_conv_layers: 3 n_linear_layers: 3 n_atom_properties: 158 n_bond_properties: 7 - n_molecule_properties: 200 \ No newline at end of file + n_molecule_properties: 200 diff --git a/configs/model/gnn_resgated_pretrain.yml b/configs/model/gnn_resgated_pretrain.yml index c26db76..fad8c27 100644 --- 
a/configs/model/gnn_resgated_pretrain.yml +++ b/configs/model/gnn_resgated_pretrain.yml @@ -13,4 +13,4 @@ init_args: n_linear_layers: 3 n_atom_properties: 151 n_bond_properties: 7 - n_molecule_properties: 200 \ No newline at end of file + n_molecule_properties: 200 diff --git a/pyproject.toml b/pyproject.toml index 64c572c..4aea1ae 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -25,4 +25,4 @@ build-backend = "flit_core.buildapi" requires = ["flit_core >=3.2,<4"] [project.entry-points.'chebai.plugins'] -models = 'chebai_graph.models' \ No newline at end of file +models = 'chebai_graph.models' diff --git a/tests/__init__.py b/tests/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/unit/__init__.py b/tests/unit/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/unit/readers/__init__.py b/tests/unit/readers/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/unit/readers/testGraphPropertyReader.py b/tests/unit/readers/testGraphPropertyReader.py new file mode 100644 index 0000000..0222fa4 --- /dev/null +++ b/tests/unit/readers/testGraphPropertyReader.py @@ -0,0 +1,86 @@ +import unittest + +import torch +from torch_geometric.data import Data as GeomData + +from chebai_graph.preprocessing.reader import GraphPropertyReader +from tests.unit.test_data import MoleculeGraph + + +class TestGraphPropertyReader(unittest.TestCase): + """Unit tests for the GraphPropertyReader class, which converts SMILES strings to torch_geometric Data objects.""" + + def setUp(self) -> None: + """Initialize the reader and the reference molecule graph.""" + self.reader: GraphPropertyReader = GraphPropertyReader() + self.molecule_graph: MoleculeGraph = MoleculeGraph() + + def test_read_data(self) -> None: + """Test that the reader correctly parses a SMILES string into a graph that matches the expected aspirin structure.""" + smiles: str = "CC(=O)OC1=CC=CC=C1C(=O)O" # Aspirin + + data: GeomData = self.reader._read_data(smiles) # noqa + + self.assertIsInstance( data, GeomData, msg="The output should be an instance of torch_geometric.data.Data.", ) + + self.assertEqual( data.edge_index.shape[0], 2, msg=f"Expected edge_index to have shape [2, num_edges], but got shape {data.edge_index.shape}", ) + + self.assertEqual( data.edge_index.shape[1], data.edge_attr.shape[0], msg=f"Mismatch between number of edges in edge_index ({data.edge_index.shape[1]}) and edge_attr ({data.edge_attr.shape[0]})", ) + + self.assertEqual( len(set(data.edge_index[0].tolist())), data.x.shape[0], msg=f"Number of unique source nodes in edge_index ({len(set(data.edge_index[0].tolist()))}) does not match number of nodes in x ({data.x.shape[0]})", ) + + # Check for duplicates by checking if the rows are the same (direction matters) + _, counts = torch.unique(data.edge_index.t(), dim=0, return_counts=True) + self.assertFalse( torch.any(counts > 1), msg="There are duplicate directed edges in edge_index", ) + + expected_data: GeomData = self.molecule_graph.get_aspirin_graph() + self.assertTrue( torch.equal(data.edge_index, expected_data.edge_index), msg=( "edge_index tensors do not match.\n" f"Differences at indices: {(data.edge_index != expected_data.edge_index).nonzero()}.\n" f"Parsed edge_index:\n{data.edge_index}\nExpected edge_index:\n{expected_data.edge_index}\n" "If this fails in the future, check whether the RDKit version changed; the expected graph was generated with RDKit 2024.9.6." ), ) + + self.assertEqual( data.x.shape[0], 
expected_data.x.shape[0], + msg=( + "The number of atoms (nodes) in the parsed graph does not match the reference graph.\n" + f"Parsed: {data.x.shape[0]}, Expected: {expected_data.x.shape[0]}" + ), + ) + + self.assertEqual( + data.edge_attr.shape[0], + expected_data.edge_attr.shape[0], + msg=( + "The number of edge attributes does not match the expected value.\n" + f"Parsed: {data.edge_attr.shape[0]}, Expected: {expected_data.edge_attr.shape[0]}" + ), + ) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/unit/test_data.py b/tests/unit/test_data.py new file mode 100644 index 0000000..4acf41d --- /dev/null +++ b/tests/unit/test_data.py @@ -0,0 +1,102 @@ +import torch +from torch_geometric.data import Data + + +class MoleculeGraph: + """Class representing molecular graph data.""" + + def get_aspirin_graph(self): + """ + Constructs and returns a PyTorch Geometric Data object representing the molecular graph of aspirin. + + Aspirin -> CC(=O)OC1=CC=CC=C1C(=O)O ; CHEBI:15365 + + Node labels (atom indices): + O2 C5———C6 + \ / \ + C1———O3———C4 C7 + / \ / + C0 C9———C8 + / + C10 + / \ + O12 O11 + + + Returns: + torch_geometric.data.Data: A Data object with attributes: + - x (FloatTensor): Node feature matrix of shape (num_nodes, 1). + - edge_index (LongTensor): Graph connectivity in COO format of shape (2, num_edges). + - edge_attr (FloatTensor): Edge feature matrix of shape (num_edges, 1). + + Reference: + For graph construction: https://pytorch-geometric.readthedocs.io/en/latest/get_started/introduction.html + """ + + # --- Node features: atomic numbers (C=6, O=8) --- + # Shape of x: num_nodes x num_node_features; row i holds the feature of atom/node i as referenced in edge_index + # fmt: off + x = torch.tensor( + [ + [6], # C0 + [6], # C1 + [8], # O2 + [8], # O3 + [6], # C4 + [6], # C5 + [6], # C6 + [6], # C7 + [6], # C8 + [6], # C9 + [6], # C10 + [8], # O11 + [8], # O12 + ], + dtype=torch.float, + ) + # fmt: on + + # --- Edge list (bidirectional) --- + # Shape of edge_index for an undirected graph: 2 x num_edges (2 x 26) + # Generated using RDKit 2024.9.6 + # fmt: off + _edge_index = torch.tensor([ + [0, 1, 1, 3, 4, 5, 6, 7, 8, 9, 10, 10, 9], # Start atoms (u) + [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 4] # End atoms (v) + ], dtype=torch.long) + # fmt: on + + # Reverse the edges + reversed_edge_index = _edge_index[[1, 0], :] + + # First all directed edges from source to target are placed, + # then all directed edges from target to source; the reader and the property merging rely on this ordering + undirected_edge_index = torch.cat([_edge_index, reversed_edge_index], dim=1) + + # --- Dummy edge features --- + # Shape of undirected_edge_attr: num_edges x num_edge_features (26 x 1); row i holds the feature of edge i (u[i] - v[i]) in _edge_index + # fmt: off + _edge_attr = torch.tensor([ + [1], # C0 - C1
+ [2], # C1 - O2 + [2], # C1 - O3 + [2], # O3 - C4 + [1], # C4 - C5 + [1], # C5 - C6 + [1], # C6 - C7 + [1], # C7 - C8 + [1], # C8 - C9 + [1], # C9 - C10 + [1], # C10 - O11 + [1], # C10 - O12 + [1], # C9 - C4 + ], dtype=torch.float) + # fmt: on + + # Edge attributes must stay aligned with the edge order in edge_index, so the rows are duplicated in the same src->tgt, then tgt->src order + undirected_edge_attr = torch.cat([_edge_attr, _edge_attr], dim=0) + + # Create graph data object + return Data( x=x, edge_index=undirected_edge_index, edge_attr=undirected_edge_attr )
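
Reviewer note on the edge-direction convention: the changes in `reader.py`, `datasets/chebi.py`, and the new tests all assume the same layout for undirected molecular graphs. `edge_index` holds all src->tgt columns first, then all tgt->src columns, and per-bond attribute rows are duplicated in that same order. Below is a minimal, self-contained sketch of that convention, assuming only `rdkit` and `torch` are installed; the SMILES input and variable names are illustrative and not part of the codebase:

```python
import torch
from rdkit import Chem

# Toy molecule (acetic acid); any SMILES accepted by RDKit works here.
mol = Chem.MolFromSmiles("CC(=O)O")

# First all source-to-target edges, then all target-to-source edges,
# mirroring GraphPropertyReader._read_data.
src = [bond.GetBeginAtomIdx() for bond in mol.GetBonds()]
tgt = [bond.GetEndAtomIdx() for bond in mol.GetBonds()]
edge_index = torch.tensor([src + tgt, tgt + src], dtype=torch.long)

# Per-bond values (here: bond orders) are duplicated along dim 0 so that row i
# of edge_attr stays aligned with column i of edge_index in both halves, the
# same duplication _merge_props_into_base applies to BondProperty values.
bond_values = torch.tensor(
    [[bond.GetBondTypeAsDouble()] for bond in mol.GetBonds()], dtype=torch.float
)
edge_attr = torch.cat([bond_values, bond_values], dim=0)

assert edge_index.shape == (2, 2 * mol.GetNumBonds())
assert edge_attr.shape[0] == edge_index.shape[1]
```

Because the two directional halves are concatenated block-wise rather than interleaved, a single `torch.cat([values, values], dim=0)` suffices to keep bond attributes aligned with the doubled `edge_index`; the duplicate-edge check in the unit test and the aspirin reference graph both depend on this ordering.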