From 9150bb6a91c636b76b8749f4b8c6f8f6beb5a42c Mon Sep 17 00:00:00 2001 From: aditya0by0 Date: Wed, 7 May 2025 16:18:55 +0200 Subject: [PATCH 01/23] Create .gitignore --- .gitignore | 170 +++++++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 170 insertions(+) create mode 100644 .gitignore diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..9676c5b --- /dev/null +++ b/.gitignore @@ -0,0 +1,170 @@ +# Byte-compiled / optimized / DLL files +__pycache__/ +*.py[cod] +*$py.class + +# C extensions +*.so + +# Distribution / packaging +.Python +build/ +develop-eggs/ +dist/ +downloads/ +eggs/ +.eggs/ +lib/ +lib64/ +parts/ +sdist/ +var/ +wheels/ +share/python-wheels/ +*.egg-info/ +.installed.cfg +*.egg +MANIFEST + +# PyInstaller +# Usually these files are written by a python script from a template +# before PyInstaller builds the exe, so as to inject date/other infos into it. +*.manifest +*.spec + +# Installer logs +pip-log.txt +pip-delete-this-directory.txt + +# Unit test / coverage reports +htmlcov/ +.tox/ +.nox/ +.coverage +.coverage.* +.cache +nosetests.xml +coverage.xml +*.cover +*.py,cover +.hypothesis/ +.pytest_cache/ +cover/ + +# Translations +*.mo +*.pot + +# Django stuff: +*.log +local_settings.py +db.sqlite3 +db.sqlite3-journal + +# Flask stuff: +instance/ +.webassets-cache + +# Scrapy stuff: +.scrapy + +# Sphinx documentation +docs/_build/ +docs/build/ + +# PyBuilder +.pybuilder/ +target/ + +# Jupyter Notebook +.ipynb_checkpoints + +# IPython +profile_default/ +ipython_config.py + +# pyenv +# For a library or package, you might want to ignore these files since the code is +# intended to run in multiple environments; otherwise, check them in: +# .python-version + +# pipenv +# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. +# However, in case of collaboration, if having platform-specific dependencies or dependencies +# having no cross-platform support, pipenv may install dependencies that don't work, or not +# install all needed dependencies. +#Pipfile.lock + +# poetry +# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. +# This is especially recommended for binary packages to ensure reproducibility, and is more +# commonly ignored for libraries. +# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control +#poetry.lock + +# pdm +# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. +#pdm.lock +# pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it +# in version control. +# https://pdm.fming.dev/#use-with-ide +.pdm.toml + +# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm +__pypackages__/ + +# Celery stuff +celerybeat-schedule +celerybeat.pid + +# SageMath parsed files +*.sage.py + +# Environments +.env +.venv +env/ +venv/ +ENV/ +env.bak/ +venv.bak/ + +# Spyder project settings +.spyderproject +.spyproject + +# Rope project settings +.ropeproject + +# mkdocs documentation +/site + +# mypy +.mypy_cache/ +.dmypy.json +dmypy.json + +# Pyre type checker +.pyre/ + +# pytype static type analyzer +.pytype/ + +# Cython debug symbols +cython_debug/ + +# PyCharm +# JetBrains specific template is maintained in a separate JetBrains.gitignore that can +# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore +# and can be added to the global gitignore or merged into this file. 
For a more nuclear +# option (not recommended) you can uncomment the following to ignore the entire idea folder. +#.idea/ + +# configs/ # commented as new configs can be added as a part of a feature + +/.idea +/data +/logs +/results_buffer +electra_pretrained.ckpt +.isort.cfg From 06a71a66017708bd6ed9b21b9f13a64b2785e8c0 Mon Sep 17 00:00:00 2001 From: aditya0by0 Date: Wed, 7 May 2025 16:23:11 +0200 Subject: [PATCH 02/23] update precommit + github action --- .github/workflows/black.yml | 10 ++++++++++ .pre-commit-config.yaml | 28 ++++++++++++++++++++++------ 2 files changed, 32 insertions(+), 6 deletions(-) create mode 100644 .github/workflows/black.yml diff --git a/.github/workflows/black.yml b/.github/workflows/black.yml new file mode 100644 index 0000000..b04fb15 --- /dev/null +++ b/.github/workflows/black.yml @@ -0,0 +1,10 @@ +name: Lint + +on: [push, pull_request] + +jobs: + lint: + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v2 + - uses: psf/black@stable diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 866c153..108b91d 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -1,9 +1,25 @@ repos: -#- repo: https://github.com/PyCQA/isort -# rev: "5.12.0" -# hooks: -# - id: isort - repo: https://github.com/psf/black - rev: "22.10.0" + rev: "24.2.0" hooks: - - id: black \ No newline at end of file + - id: black + - id: black-jupyter # for formatting jupyter-notebook + +- repo: https://github.com/pycqa/isort + rev: 5.13.2 + hooks: + - id: isort + name: isort (python) + args: ["--profile=black"] + +- repo: https://github.com/asottile/seed-isort-config + rev: v2.2.0 + hooks: + - id: seed-isort-config + +- repo: https://github.com/pre-commit/pre-commit-hooks + rev: v4.6.0 + hooks: + - id: check-yaml + - id: end-of-file-fixer + - id: trailing-whitespace From ff1adc9acb877c56371f2a5c3b2d74269ec0da59 Mon Sep 17 00:00:00 2001 From: aditya0by0 Date: Wed, 7 May 2025 16:27:13 +0200 Subject: [PATCH 03/23] pre-commit format files --- README.md | 6 ++-- chebai_graph/models/gin_net.py | 9 +++--- chebai_graph/preprocessing/collate.py | 4 +-- chebai_graph/preprocessing/datasets/chebi.py | 32 ++++++++++--------- .../preprocessing/datasets/pubchem.py | 3 +- chebai_graph/preprocessing/properties.py | 6 ++-- .../preprocessing/property_encoder.py | 3 +- chebai_graph/preprocessing/reader.py | 21 ++++++------ .../preprocessing/transform_unlabeled.py | 1 + configs/data/chebi50_graph.yml | 2 +- configs/data/pubchem_graph.yml | 2 +- configs/loss/mask_pretraining.yml | 2 +- configs/model/gnn.yml | 2 +- configs/model/gnn_attention.yml | 2 +- configs/model/gnn_gine.yml | 2 +- configs/model/gnn_res_gated.yml | 2 +- configs/model/gnn_resgated_pretrain.yml | 2 +- pyproject.toml | 2 +- 18 files changed, 54 insertions(+), 49 deletions(-) diff --git a/README.md b/README.md index 6af4630..c8ce94b 100644 --- a/README.md +++ b/README.md @@ -2,13 +2,13 @@ ## Installation -Some requirements may not be installed successfully automatically. +Some requirements may not be installed successfully automatically. To install the `torch-` libraries, use `pip install torch-${lib} -f https://data.pyg.org/whl/torch-2.1.0+${CUDA}.html` where `${lib}` is either `scatter`, `geometric`, `sparse` or `cluster`, and -`${CUDA}` is either `cpu`, `cu118` or `cu121` (depending on your system, see e.g. +`${CUDA}` is either `cpu`, `cu118` or `cu121` (depending on your system, see e.g. 
[torch-geometric docs](https://pytorch-geometric.readthedocs.io/en/latest/install/installation.html)) @@ -31,7 +31,7 @@ We recommend the following setup: If you run the command from the `python-chebai` directory, you can use the same data for both chebai- and chebai-graph-models (e.g., Transformers and GNNs). Then you have to use `{path-to-chebai} -> .` and `{path-to-chebai-graph} -> ../python-chebai-graph`. - + Pretraining on a atom / bond masking task with PubChem data (feature-branch): ``` python3 -m chebai fit --model={path-to-chebai-graph}/configs/model/gnn_resgated_pretrain.yml --data={path-to-chebai-graph}/configs/data/pubchem_graph.yml --trainer={path-to-chebai}/configs/training/pretraining_trainer.yml diff --git a/chebai_graph/models/gin_net.py b/chebai_graph/models/gin_net.py index 75c2c45..6fed4c6 100644 --- a/chebai_graph/models/gin_net.py +++ b/chebai_graph/models/gin_net.py @@ -1,10 +1,11 @@ +import typing + +import torch +import torch.nn.functional as F +import torch_geometric from torch_scatter import scatter_add from chebai_graph.models.graph import GraphBaseNet -import torch_geometric -import torch.nn.functional as F -import torch -import typing class AggregateMLP(torch.nn.Module): diff --git a/chebai_graph/preprocessing/collate.py b/chebai_graph/preprocessing/collate.py index 2c5f696..4be36cf 100644 --- a/chebai_graph/preprocessing/collate.py +++ b/chebai_graph/preprocessing/collate.py @@ -1,11 +1,11 @@ from typing import Dict import torch +from chebai.preprocessing.collate import RaggedCollator from torch_geometric.data import Data as GeomData from torch_geometric.data.collate import collate as graph_collate -from chebai_graph.preprocessing.structures import XYGraphData -from chebai.preprocessing.collate import RaggedCollator +from chebai_graph.preprocessing.structures import XYGraphData class GraphCollator(RaggedCollator): diff --git a/chebai_graph/preprocessing/datasets/chebi.py b/chebai_graph/preprocessing/datasets/chebi.py index 6ee8bc5..843ba35 100644 --- a/chebai_graph/preprocessing/datasets/chebi.py +++ b/chebai_graph/preprocessing/datasets/chebi.py @@ -1,26 +1,26 @@ -from typing import Optional, List, Callable +import importlib +import os +from typing import Callable, List, Optional +import pandas as pd +import torch +import tqdm +from chebai.preprocessing.datasets.base import XYBaseDataModule from chebai.preprocessing.datasets.chebi import ( ChEBIOver50, ChEBIOver100, ChEBIOverXPartial, ) -from chebai.preprocessing.datasets.base import XYBaseDataModule from lightning_utilities.core.rank_zero import rank_zero_info +from torch_geometric.data.data import Data as GeomData -from chebai_graph.preprocessing.reader import GraphReader, GraphPropertyReader +import chebai_graph.preprocessing.properties as graph_properties from chebai_graph.preprocessing.properties import ( AtomProperty, BondProperty, MolecularProperty, ) -import pandas as pd -from torch_geometric.data.data import Data as GeomData -import torch -import chebai_graph.preprocessing.properties as graph_properties -import importlib -import os -import tqdm +from chebai_graph.preprocessing.reader import GraphPropertyReader, GraphReader class ChEBI50GraphData(ChEBIOver50): @@ -84,9 +84,11 @@ def _setup_properties(self): for file in file_names: # processed_dir_main only exists for ChEBI datasets path = os.path.join( - self.processed_dir_main - if hasattr(self, "processed_dir_main") - else self.raw_dir, + ( + self.processed_dir_main + if hasattr(self, "processed_dir_main") + else self.raw_dir + ), file, ) raw_data 
+= list(self._load_dict(path)) @@ -94,8 +96,8 @@ def _setup_properties(self): features = [row["features"] for row in raw_data] # use vectorized version of encode function, apply only if value is present - enc_if_not_none = ( - lambda encode, value: [encode(atom_v) for atom_v in value] + enc_if_not_none = lambda encode, value: ( + [encode(atom_v) for atom_v in value] if value is not None and len(value) > 0 else None ) diff --git a/chebai_graph/preprocessing/datasets/pubchem.py b/chebai_graph/preprocessing/datasets/pubchem.py index 210b7ab..6f5d118 100644 --- a/chebai_graph/preprocessing/datasets/pubchem.py +++ b/chebai_graph/preprocessing/datasets/pubchem.py @@ -1,6 +1,7 @@ -from chebai_graph.preprocessing.datasets.chebi import GraphPropertiesMixIn from chebai.preprocessing.datasets.pubchem import PubchemChem +from chebai_graph.preprocessing.datasets.chebi import GraphPropertiesMixIn + class PubChemGraphProperties(GraphPropertiesMixIn, PubchemChem): pass diff --git a/chebai_graph/preprocessing/properties.py b/chebai_graph/preprocessing/properties.py index 95f85ab..9b927ed 100644 --- a/chebai_graph/preprocessing/properties.py +++ b/chebai_graph/preprocessing/properties.py @@ -6,11 +6,11 @@ from descriptastorus.descriptors import rdNormalizedDescriptors from chebai_graph.preprocessing.property_encoder import ( - PropertyEncoder, - IndexEncoder, - OneHotEncoder, AsIsEncoder, BoolEncoder, + IndexEncoder, + OneHotEncoder, + PropertyEncoder, ) diff --git a/chebai_graph/preprocessing/property_encoder.py b/chebai_graph/preprocessing/property_encoder.py index 497025c..ebfbe0c 100644 --- a/chebai_graph/preprocessing/property_encoder.py +++ b/chebai_graph/preprocessing/property_encoder.py @@ -1,8 +1,9 @@ import abc import os -import torch from typing import Optional +import torch + class PropertyEncoder(abc.ABC): def __init__(self, property, **kwargs): diff --git a/chebai_graph/preprocessing/reader.py b/chebai_graph/preprocessing/reader.py index b814d53..448f402 100644 --- a/chebai_graph/preprocessing/reader.py +++ b/chebai_graph/preprocessing/reader.py @@ -1,19 +1,18 @@ import importlib +import os +from typing import List, Mapping, Optional, Tuple -from torch_geometric.utils import from_networkx -from typing import Tuple, Mapping, Optional, List - -import importlib +import chebai.preprocessing.reader as dr import networkx as nx -import os -import torch -import rdkit.Chem as Chem import pysmiles as ps -import chebai.preprocessing.reader as dr -from chebai_graph.preprocessing.collate import GraphCollator -import chebai_graph.preprocessing.properties as properties +import rdkit.Chem as Chem +import torch +from lightning_utilities.core.rank_zero import rank_zero_info, rank_zero_warn from torch_geometric.data import Data as GeomData -from lightning_utilities.core.rank_zero import rank_zero_warn, rank_zero_info +from torch_geometric.utils import from_networkx + +import chebai_graph.preprocessing.properties as properties +from chebai_graph.preprocessing.collate import GraphCollator class GraphPropertyReader(dr.ChemDataReader): diff --git a/chebai_graph/preprocessing/transform_unlabeled.py b/chebai_graph/preprocessing/transform_unlabeled.py index 3920659..0cc4b35 100644 --- a/chebai_graph/preprocessing/transform_unlabeled.py +++ b/chebai_graph/preprocessing/transform_unlabeled.py @@ -1,4 +1,5 @@ import random + import torch diff --git a/configs/data/chebi50_graph.yml b/configs/data/chebi50_graph.yml index 14cc489..19c8753 100644 --- a/configs/data/chebi50_graph.yml +++ b/configs/data/chebi50_graph.yml @@ -1 
+1 @@ -class_path: chebai_graph.preprocessing.datasets.chebi.ChEBI50GraphData \ No newline at end of file +class_path: chebai_graph.preprocessing.datasets.chebi.ChEBI50GraphData diff --git a/configs/data/pubchem_graph.yml b/configs/data/pubchem_graph.yml index af04491..c21f188 100644 --- a/configs/data/pubchem_graph.yml +++ b/configs/data/pubchem_graph.yml @@ -16,4 +16,4 @@ init_args: - chebai_graph.preprocessing.properties.BondInRing - chebai_graph.preprocessing.properties.BondAromaticity #- chebai_graph.preprocessing.properties.MoleculeNumRings - - chebai_graph.preprocessing.properties.RDKit2DNormalized \ No newline at end of file + - chebai_graph.preprocessing.properties.RDKit2DNormalized diff --git a/configs/loss/mask_pretraining.yml b/configs/loss/mask_pretraining.yml index c677559..6d2a560 100644 --- a/configs/loss/mask_pretraining.yml +++ b/configs/loss/mask_pretraining.yml @@ -1 +1 @@ -class_path: chebai_graph.loss.pretraining.MaskPretrainingLoss \ No newline at end of file +class_path: chebai_graph.loss.pretraining.MaskPretrainingLoss diff --git a/configs/model/gnn.yml b/configs/model/gnn.yml index b0b119d..f85fa76 100644 --- a/configs/model/gnn.yml +++ b/configs/model/gnn.yml @@ -7,4 +7,4 @@ init_args: hidden_length: 512 dropout_rate: 0.1 n_conv_layers: 3 - n_linear_layers: 3 \ No newline at end of file + n_linear_layers: 3 diff --git a/configs/model/gnn_attention.yml b/configs/model/gnn_attention.yml index b1c553b..0c11ced 100644 --- a/configs/model/gnn_attention.yml +++ b/configs/model/gnn_attention.yml @@ -8,4 +8,4 @@ init_args: dropout_rate: 0.1 n_conv_layers: 5 n_linear_layers: 3 - n_heads: 5 \ No newline at end of file + n_heads: 5 diff --git a/configs/model/gnn_gine.yml b/configs/model/gnn_gine.yml index 0d0ed20..c84ea61 100644 --- a/configs/model/gnn_gine.yml +++ b/configs/model/gnn_gine.yml @@ -8,4 +8,4 @@ init_args: n_conv_layers: 5 n_linear_layers: 3 n_atom_properties: 125 - n_bond_properties: 5 \ No newline at end of file + n_bond_properties: 5 diff --git a/configs/model/gnn_res_gated.yml b/configs/model/gnn_res_gated.yml index d9ddc05..27d1e78 100644 --- a/configs/model/gnn_res_gated.yml +++ b/configs/model/gnn_res_gated.yml @@ -10,4 +10,4 @@ init_args: n_linear_layers: 3 n_atom_properties: 158 n_bond_properties: 7 - n_molecule_properties: 200 \ No newline at end of file + n_molecule_properties: 200 diff --git a/configs/model/gnn_resgated_pretrain.yml b/configs/model/gnn_resgated_pretrain.yml index c26db76..fad8c27 100644 --- a/configs/model/gnn_resgated_pretrain.yml +++ b/configs/model/gnn_resgated_pretrain.yml @@ -13,4 +13,4 @@ init_args: n_linear_layers: 3 n_atom_properties: 151 n_bond_properties: 7 - n_molecule_properties: 200 \ No newline at end of file + n_molecule_properties: 200 diff --git a/pyproject.toml b/pyproject.toml index 64c572c..4aea1ae 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -25,4 +25,4 @@ build-backend = "flit_core.buildapi" requires = ["flit_core >=3.2,<4"] [project.entry-points.'chebai.plugins'] -models = 'chebai_graph.models' \ No newline at end of file +models = 'chebai_graph.models' From d7f30d38c51235f4d21731360d11ff9ff0724344 Mon Sep 17 00:00:00 2001 From: aditya0by0 Date: Wed, 7 May 2025 16:31:50 +0200 Subject: [PATCH 04/23] change graph from directed to UNDIRECTED --- chebai_graph/preprocessing/reader.py | 9 ++++----- 1 file changed, 4 insertions(+), 5 deletions(-) diff --git a/chebai_graph/preprocessing/reader.py b/chebai_graph/preprocessing/reader.py index 448f402..ced9f31 100644 --- a/chebai_graph/preprocessing/reader.py +++ 
b/chebai_graph/preprocessing/reader.py @@ -1,6 +1,5 @@ -import importlib import os -from typing import List, Mapping, Optional, Tuple +from typing import List, Optional import chebai.preprocessing.reader as dr import networkx as nx @@ -9,7 +8,7 @@ import torch from lightning_utilities.core.rank_zero import rank_zero_info, rank_zero_warn from torch_geometric.data import Data as GeomData -from torch_geometric.utils import from_networkx +from torch_geometric.utils import from_networkx, to_undirected import chebai_graph.preprocessing.properties as properties from chebai_graph.preprocessing.collate import GraphCollator @@ -44,7 +43,7 @@ def _smiles_to_mol(self, smiles: str) -> Optional[Chem.rdchem.Mol]: try: Chem.SanitizeMol(mol) except Exception as e: - rank_zero_warn(f"Rdkit failed at sanitizing {smiles}") + rank_zero_warn(f"Rdkit failed at sanitizing {smiles} \n Error: {e}") self.failed_counter += 1 self.mol_object_buffer[smiles] = mol return mol @@ -64,7 +63,7 @@ def _read_data(self, raw_data): [bond.GetEndAtomIdx() for bond in mol.GetBonds()], ] ) - return GeomData(x=x, edge_index=edge_index, edge_attr=edge_attr) + return GeomData(x=x, edge_index=to_undirected(edge_index), edge_attr=edge_attr) def on_finish(self): rank_zero_info(f"Failed to read {self.failed_counter} SMILES in total") From b8189d183190e559f6a1288e4bae17af8a4971d0 Mon Sep 17 00:00:00 2001 From: aditya0by0 Date: Wed, 14 May 2025 13:37:46 +0200 Subject: [PATCH 05/23] add test data --- tests/__init__.py | 0 tests/unit/__init__.py | 0 tests/unit/test_data.py | 119 ++++++++++++++++++++++++++++++++++++++++ 3 files changed, 119 insertions(+) create mode 100644 tests/__init__.py create mode 100644 tests/unit/__init__.py create mode 100644 tests/unit/test_data.py diff --git a/tests/__init__.py b/tests/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/unit/__init__.py b/tests/unit/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/unit/test_data.py b/tests/unit/test_data.py new file mode 100644 index 0000000..aea97a0 --- /dev/null +++ b/tests/unit/test_data.py @@ -0,0 +1,119 @@ +import torch +from torch_geometric.data import Data + + +class MoleculeGraph: + """Dummy graph of Aspirin with node and edge features""" + + def get_aspirin_graph(self): + """ + Aspirin -> CC(=O)OC1=CC=CC=C1C(=O)O ; CHEBI:15365 + + Node labels (atom indices): + O2 C5———C6 + \ / \ + C1———O3———C4 C7 + / \ / + C0 C9———C8 + / + C10 + / \ + O12 O11 + """ + + # --- Node features: atomic numbers (C=6, O=8) --- + # Shape of x : num_nodes x num_of_node_features + x = torch.tensor( + [ + [6], # C0 - This feature belongs to atom with atom `0` in edge_index + [6], # C1 - This feature belongs to atom with atom `1` in edge_index + [8], # O2 - This feature belongs to atom with atom `2` in edge_index + [8], # O3 - This feature belongs to atom with atom `3` in edge_index + [6], # C4 - This feature belongs to atom with atom `4` in edge_index + [6], # C5 - This feature belongs to atom with atom `5` in edge_index + [6], # C6 - This feature belongs to atom with atom `6` in edge_index + [6], # C7 - This feature belongs to atom with atom `7` in edge_index + [6], # C8 - This feature belongs to atom with atom `8` in edge_index + [6], # C9 - This feature belongs to atom with atom `9` in edge_index + [6], # C10 - This feature belongs to atom with atom `10` in edge_index + [8], # O11 - This feature belongs to atom with atom `11` in edge_index + [8], # O12 - This feature belongs to atom with atom `12` in edge_index + ], + 
dtype=torch.float, + ) + + # --- Edge list (bidirectional) --- + # Shape of edge_index for undirected graph: 2 x num_of_edges + edge_index = ( + torch.tensor( + [ + [0, 1], + [1, 0], + [1, 2], + [2, 1], + [1, 3], + [3, 1], + [3, 4], + [4, 3], + [4, 5], + [5, 4], + [5, 6], + [6, 5], + [6, 7], + [7, 6], + [7, 8], + [8, 7], + [8, 9], + [9, 8], + [4, 9], + [9, 4], + [9, 10], + [10, 9], + [10, 11], + [11, 10], + [10, 12], + [12, 10], + ], + dtype=torch.long, + ) + .t() + .contiguous() + ) + + # --- Dummy edge features: bond type (single=1, double=2, ester=3) --- + # Using all single bonds for simplicity (except C=O as double bonds) + # Shape of edge_attr: num_of_edges x num_of_edges_features + edge_attr = torch.tensor( + [ + [1], + [1], # C0 - C1 # This two features to two first bond in + [2], + [2], # C1 = O2 (double bond) + [1], + [1], # C1 - O3 + [1], + [1], # O3 - C4 + [1], + [1], # C4 - C5 + [1], + [1], # C5 - C6 + [1], + [1], # C6 - C7 + [1], + [1], # C7 - C8 + [1], + [1], # C8 - C9 + [1], + [1], # C4 - C9 (ring closure) + [1], + [1], # C9 - C10 + [2], + [2], # C10 = O11 (carboxylic acid) + [1], + [1], # C10 - O12 (hydroxyl) + ], + dtype=torch.float, + ) + + # Create graph data object + return Data(x=x, edge_index=edge_index, edge_attr=edge_attr) From ad301e6dda67ed7b067937301931844ccc4ff2d6 Mon Sep 17 00:00:00 2001 From: aditya0by0 Date: Wed, 14 May 2025 13:45:34 +0200 Subject: [PATCH 06/23] edge_features should be calculated after undirected graph --- chebai_graph/preprocessing/reader.py | 20 ++++++++++++-------- 1 file changed, 12 insertions(+), 8 deletions(-) diff --git a/chebai_graph/preprocessing/reader.py b/chebai_graph/preprocessing/reader.py index ced9f31..9862fe9 100644 --- a/chebai_graph/preprocessing/reader.py +++ b/chebai_graph/preprocessing/reader.py @@ -55,15 +55,19 @@ def _read_data(self, raw_data): x = torch.zeros((mol.GetNumAtoms(), 0)) - edge_attr = torch.zeros((mol.GetNumBonds(), 0)) - - edge_index = torch.tensor( - [ - [bond.GetBeginAtomIdx() for bond in mol.GetBonds()], - [bond.GetEndAtomIdx() for bond in mol.GetBonds()], - ] + edge_index = to_undirected( + torch.tensor( + [ + [bond.GetBeginAtomIdx() for bond in mol.GetBonds()], + [bond.GetEndAtomIdx() for bond in mol.GetBonds()], + ] + ) ) - return GeomData(x=x, edge_index=to_undirected(edge_index), edge_attr=edge_attr) + + # edge_index.shape == [2, num_edges]; edge_attr.shape == [num_edges, num_edge_features] + edge_attr = torch.zeros((edge_index.size(1), 0)) + + return GeomData(x=x, edge_index=edge_index, edge_attr=edge_attr) def on_finish(self): rank_zero_info(f"Failed to read {self.failed_counter} SMILES in total") From 344d828adb73b5cb28c5ff04dca122287dd7650e Mon Sep 17 00:00:00 2001 From: aditya0by0 Date: Wed, 14 May 2025 16:51:26 +0200 Subject: [PATCH 07/23] directed edge which form an un-dir edge should be adjancent --- chebai_graph/preprocessing/reader.py | 14 ++++++-------- 1 file changed, 6 insertions(+), 8 deletions(-) diff --git a/chebai_graph/preprocessing/reader.py b/chebai_graph/preprocessing/reader.py index 9862fe9..ac6f212 100644 --- a/chebai_graph/preprocessing/reader.py +++ b/chebai_graph/preprocessing/reader.py @@ -55,14 +55,12 @@ def _read_data(self, raw_data): x = torch.zeros((mol.GetNumAtoms(), 0)) - edge_index = to_undirected( - torch.tensor( - [ - [bond.GetBeginAtomIdx() for bond in mol.GetBonds()], - [bond.GetEndAtomIdx() for bond in mol.GetBonds()], - ] - ) - ) + # We need to ensure that directed edges which form a undirected edge are adjacent to each other + edge_index_list = [[], []] + 
for bond in mol.GetBonds(): + edge_index_list[0].extend([bond.GetBeginAtomIdx(), bond.GetEndAtomIdx()]) + edge_index_list[1].extend([bond.GetEndAtomIdx(), bond.GetBeginAtomIdx()]) + edge_index = torch.tensor(edge_index_list, dtype=torch.long) # edge_index.shape == [2, num_edges]; edge_attr.shape == [num_edges, num_edge_features] edge_attr = torch.zeros((edge_index.size(1), 0)) From 8a69828818dc09046806f06c7c8a15e18676f3fa Mon Sep 17 00:00:00 2001 From: aditya0by0 Date: Wed, 14 May 2025 16:59:51 +0200 Subject: [PATCH 08/23] add test for GraphPropertyReader --- tests/unit/readers/__init__.py | 0 tests/unit/readers/testGraphPropertyReader.py | 61 +++++++++++++++++++ 2 files changed, 61 insertions(+) create mode 100644 tests/unit/readers/__init__.py create mode 100644 tests/unit/readers/testGraphPropertyReader.py diff --git a/tests/unit/readers/__init__.py b/tests/unit/readers/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/unit/readers/testGraphPropertyReader.py b/tests/unit/readers/testGraphPropertyReader.py new file mode 100644 index 0000000..c3771c3 --- /dev/null +++ b/tests/unit/readers/testGraphPropertyReader.py @@ -0,0 +1,61 @@ +import unittest + +import torch +from torch_geometric.data import Data as GeomData + +from chebai_graph.preprocessing.reader import GraphPropertyReader +from tests.unit.test_data import MoleculeGraph + + +class TestGraphPropertyReader(unittest.TestCase): + """Unit tests for the GraphPropertyReader class, which converts SMILES strings to torch_geometric Data objects.""" + + def setUp(self) -> None: + """Initialize the reader and the reference molecule graph.""" + self.reader: GraphPropertyReader = GraphPropertyReader() + self.molecule_graph: MoleculeGraph = MoleculeGraph() + + def test_read_data(self) -> None: + """Test that the reader correctly parses a SMILES string into a graph and matches expected aspirin structure.""" + smiles: str = "CC(=O)OC1=CC=CC=C1C(=O)O" # Aspirin + + data: GeomData = self.reader._read_data(smiles) + expected_data: GeomData = self.molecule_graph.get_aspirin_graph() + + self.assertIsInstance( + data, + GeomData, + msg="The output should be an instance of torch_geometric.data.Data.", + ) + + self.assertTrue( + torch.equal(data.edge_index, expected_data.edge_index), + msg=( + "edge_index tensors do not match.\n" + f"Differences at indices: {(data.edge_index != expected_data.edge_index).nonzero()}.\n" + f"Parsed edge_index:\n{data.edge_index}\nExpected edge_index:\n{expected_data.edge_index}" + f"If fails in future, check if there is change in RDKIT version, the expected graph is generated with RDKIT 2024.9.6" + ), + ) + + self.assertEqual( + data.x.shape[0], + expected_data.x.shape[0], + msg=( + "The number of atoms (nodes) in the parsed graph does not match the reference graph.\n" + f"Parsed: {data.x.shape[0]}, Expected: {expected_data.x.shape[0]}" + ), + ) + + self.assertEqual( + data.edge_attr.shape[0], + expected_data.edge_attr.shape[0], + msg=( + "The number of edge attributes does not match the expected value.\n" + f"Parsed: {data.edge_attr.shape[0]}, Expected: {expected_data.edge_attr.shape[0]}" + ), + ) + + +if __name__ == "__main__": + unittest.main() From e0064b837394b0341b34bb3969ce738c15d074fb Mon Sep 17 00:00:00 2001 From: aditya0by0 Date: Wed, 14 May 2025 17:18:14 +0200 Subject: [PATCH 09/23] add gt test data for aspirin --- tests/unit/test_data.py | 163 ++++++++++++++++++++-------------------- 1 file changed, 80 insertions(+), 83 deletions(-) diff --git a/tests/unit/test_data.py 
b/tests/unit/test_data.py index aea97a0..8aff95b 100644 --- a/tests/unit/test_data.py +++ b/tests/unit/test_data.py @@ -3,10 +3,12 @@ class MoleculeGraph: - """Dummy graph of Aspirin with node and edge features""" + """Class representing molecular graph data.""" def get_aspirin_graph(self): """ + Constructs and returns a PyTorch Geometric Data object representing the molecular graph of Aspirin. + Aspirin -> CC(=O)OC1=CC=CC=C1C(=O)O ; CHEBI:15365 Node labels (atom indices): @@ -19,101 +21,96 @@ def get_aspirin_graph(self): C10 / \ O12 O11 + + + Returns: + torch_geometric.data.Data: A Data object with attributes: + - x (FloatTensor): Node feature matrix of shape (num_nodes, 1). + - edge_index (LongTensor): Graph connectivity in COO format of shape (2, num_edges). + - edge_attr (FloatTensor): Edge feature matrix of shape (num_edges, 1). + + Refer: + For graph construction: https://pytorch-geometric.readthedocs.io/en/latest/get_started/introduction.html """ # --- Node features: atomic numbers (C=6, O=8) --- # Shape of x : num_nodes x num_of_node_features x = torch.tensor( [ - [6], # C0 - This feature belongs to atom with atom `0` in edge_index - [6], # C1 - This feature belongs to atom with atom `1` in edge_index - [8], # O2 - This feature belongs to atom with atom `2` in edge_index - [8], # O3 - This feature belongs to atom with atom `3` in edge_index - [6], # C4 - This feature belongs to atom with atom `4` in edge_index - [6], # C5 - This feature belongs to atom with atom `5` in edge_index - [6], # C6 - This feature belongs to atom with atom `6` in edge_index - [6], # C7 - This feature belongs to atom with atom `7` in edge_index - [6], # C8 - This feature belongs to atom with atom `8` in edge_index - [6], # C9 - This feature belongs to atom with atom `9` in edge_index - [6], # C10 - This feature belongs to atom with atom `10` in edge_index - [8], # O11 - This feature belongs to atom with atom `11` in edge_index - [8], # O12 - This feature belongs to atom with atom `12` in edge_index + [ + 6 + ], # C0 - This feature belongs to atom/node with `0` value in edge_index + [ + 6 + ], # C1 - This feature belongs to atom/node with `1` value in edge_index + [ + 8 + ], # O2 - This feature belongs to atom/node with `2` value in edge_index + [ + 8 + ], # O3 - This feature belongs to atom/node with `3` value in edge_index + [ + 6 + ], # C4 - This feature belongs to atom/node with `4` value in edge_index + [ + 6 + ], # C5 - This feature belongs to atom/node with `5` value in edge_index + [ + 6 + ], # C6 - This feature belongs to atom/node with `6` value in edge_index + [ + 6 + ], # C7 - This feature belongs to atom/node with `7` value in edge_index + [ + 6 + ], # C8 - This feature belongs to atom/node with `8` value in edge_index + [ + 6 + ], # C9 - This feature belongs to atom/node with `9` value in edge_index + [ + 6 + ], # C10 - This feature belongs to atom/node with `10` value in edge_index + [ + 8 + ], # O11 - This feature belongs to atom/node with `11` value in edge_index + [ + 8 + ], # O12 - This feature belongs to atom/node with `12` value in edge_index ], dtype=torch.float, ) # --- Edge list (bidirectional) --- - # Shape of edge_index for undirected graph: 2 x num_of_edges - edge_index = ( - torch.tensor( - [ - [0, 1], - [1, 0], - [1, 2], - [2, 1], - [1, 3], - [3, 1], - [3, 4], - [4, 3], - [4, 5], - [5, 4], - [5, 6], - [6, 5], - [6, 7], - [7, 6], - [7, 8], - [8, 7], - [8, 9], - [9, 8], - [4, 9], - [9, 4], - [9, 10], - [10, 9], - [10, 11], - [11, 10], - [10, 12], - [12, 10], - ], - 
dtype=torch.long, - ) - .t() - .contiguous() - ) + # Shape of edge_index for undirected graph: 2 x num_of_edges; (2x26) + # 2 directed edges of one undirected edge are adjacent to each other --- this is needed + + # fmt: off + # Generated using RDKIT 2024.9.6 + edge_index = torch.tensor([ + [0, 1, 1, 2, 1, 3, 3, 4, 4, 5, 5, 6, 6, 7, 7, 8, 8, 9, 9, 10, 10, 11, 10, 12, 9, 4], # Start atoms (u) + [1, 0, 2, 1, 3, 1, 4, 3, 5, 4, 6, 5, 7, 6, 8, 7, 9, 8, 10, 9, 11, 10, 12, 10, 4, 9] # End atoms (v) + ], dtype=torch.long) + # fmt: on - # --- Dummy edge features: bond type (single=1, double=2, ester=3) --- - # Using all single bonds for simplicity (except C=O as double bonds) + # --- Dummy edge features --- # Shape of edge_attr: num_of_edges x num_of_edges_features - edge_attr = torch.tensor( - [ - [1], - [1], # C0 - C1 # This two features to two first bond in - [2], - [2], # C1 = O2 (double bond) - [1], - [1], # C1 - O3 - [1], - [1], # O3 - C4 - [1], - [1], # C4 - C5 - [1], - [1], # C5 - C6 - [1], - [1], # C6 - C7 - [1], - [1], # C7 - C8 - [1], - [1], # C8 - C9 - [1], - [1], # C4 - C9 (ring closure) - [1], - [1], # C9 - C10 - [2], - [2], # C10 = O11 (carboxylic acid) - [1], - [1], # C10 - O12 (hydroxyl) - ], - dtype=torch.float, - ) + # fmt: off + edge_attr = torch.tensor([ + [1], [1], # C0 - C1, This two features belong to elements at index 0 and 1 in `edge_index` + [2], [2], # C1 - C2, This two features belong to elements at index 2 and 3 in `edge_index` + [2], [2], # C1 - O3, This two features belong to elements at index 4 and 5 in `edge_index` + [2], [2], # O3 - C4, This two features belong to elements at index 6 and 7 in `edge_index` + [1], [1], # C4 - C5, This two features belong to elements at index 8 and 9 in `edge_index` + [1], [1], # C5 - C6, This two features belong to elements at index 10 and 11 in `edge_index` + [1], [1], # C6 - C7, This two features belong to elements at index 12 and 13 in `edge_index` + [1], [1], # C7 - C8, This two features belong to elements at index 14 and 15 in `edge_index` + [1], [1], # C8 - C9, This two features belong to elements at index 16 and 17 in `edge_index` + [1], [1], # C9 - C10, This two features belong to elements at index 18 and 19 in `edge_index` + [1], [1], # C10 - O11, This two features belong to elements at index 20 and 21 in `edge_index` + [1], [1], # C10 - O12, This two features belong to elements at index 22 and 23 in `edge_index` + [1], [1], # C9 - C4, This two features belong to elements at index 24 and 25 in `edge_index` + ], dtype=torch.float) + # fmt: on # Create graph data object return Data(x=x, edge_index=edge_index, edge_attr=edge_attr) From a9c722888e846c10d4ec82953ce99ef2a601555a Mon Sep 17 00:00:00 2001 From: aditya0by0 Date: Wed, 14 May 2025 17:33:18 +0200 Subject: [PATCH 10/23] Update test_data.py --- tests/unit/test_data.py | 54 ++++++++++++----------------------------- 1 file changed, 15 insertions(+), 39 deletions(-) diff --git a/tests/unit/test_data.py b/tests/unit/test_data.py index 8aff95b..cd3b16f 100644 --- a/tests/unit/test_data.py +++ b/tests/unit/test_data.py @@ -35,50 +35,26 @@ def get_aspirin_graph(self): # --- Node features: atomic numbers (C=6, O=8) --- # Shape of x : num_nodes x num_of_node_features + # fmt: off x = torch.tensor( [ - [ - 6 - ], # C0 - This feature belongs to atom/node with `0` value in edge_index - [ - 6 - ], # C1 - This feature belongs to atom/node with `1` value in edge_index - [ - 8 - ], # O2 - This feature belongs to atom/node with `2` value in edge_index - [ - 8 - ], # O3 - This feature 
belongs to atom/node with `3` value in edge_index - [ - 6 - ], # C4 - This feature belongs to atom/node with `4` value in edge_index - [ - 6 - ], # C5 - This feature belongs to atom/node with `5` value in edge_index - [ - 6 - ], # C6 - This feature belongs to atom/node with `6` value in edge_index - [ - 6 - ], # C7 - This feature belongs to atom/node with `7` value in edge_index - [ - 6 - ], # C8 - This feature belongs to atom/node with `8` value in edge_index - [ - 6 - ], # C9 - This feature belongs to atom/node with `9` value in edge_index - [ - 6 - ], # C10 - This feature belongs to atom/node with `10` value in edge_index - [ - 8 - ], # O11 - This feature belongs to atom/node with `11` value in edge_index - [ - 8 - ], # O12 - This feature belongs to atom/node with `12` value in edge_index + [6], # C0 - This feature belongs to atom/node with 0 value in edge_index + [6], # C1 - This feature belongs to atom/node with 1 value in edge_index + [8], # O2 - This feature belongs to atom/node with 2 value in edge_index + [8], # O3 - This feature belongs to atom/node with 3 value in edge_index + [6], # C4 - This feature belongs to atom/node with 4 value in edge_index + [6], # C5 - This feature belongs to atom/node with 5 value in edge_index + [6], # C6 - This feature belongs to atom/node with 6 value in edge_index + [6], # C7 - This feature belongs to atom/node with 7 value in edge_index + [6], # C8 - This feature belongs to atom/node with 8 value in edge_index + [6], # C9 - This feature belongs to atom/node with 9 value in edge_index + [6], # C10 - This feature belongs to atom/node with 10 value in edge_index + [8], # O11 - This feature belongs to atom/node with 11 value in edge_index + [8], # O12 - This feature belongs to atom/node with 12 value in edge_index ], dtype=torch.float, ) + # fmt: on # --- Edge list (bidirectional) --- # Shape of edge_index for undirected graph: 2 x num_of_edges; (2x26) From 0a9760ddd3e4b4c7e3b4bd4227a91f4ff8b00d8a Mon Sep 17 00:00:00 2001 From: aditya0by0 Date: Wed, 14 May 2025 20:30:06 +0200 Subject: [PATCH 11/23] add more graph test --- tests/unit/readers/testGraphPropertyReader.py | 16 ++++++++++++++-- 1 file changed, 14 insertions(+), 2 deletions(-) diff --git a/tests/unit/readers/testGraphPropertyReader.py b/tests/unit/readers/testGraphPropertyReader.py index c3771c3..a4f18eb 100644 --- a/tests/unit/readers/testGraphPropertyReader.py +++ b/tests/unit/readers/testGraphPropertyReader.py @@ -19,8 +19,7 @@ def test_read_data(self) -> None: """Test that the reader correctly parses a SMILES string into a graph and matches expected aspirin structure.""" smiles: str = "CC(=O)OC1=CC=CC=C1C(=O)O" # Aspirin - data: GeomData = self.reader._read_data(smiles) - expected_data: GeomData = self.molecule_graph.get_aspirin_graph() + data: GeomData = self.reader._read_data(smiles) # noqa self.assertIsInstance( data, @@ -28,6 +27,19 @@ def test_read_data(self) -> None: msg="The output should be an instance of torch_geometric.data.Data.", ) + assert ( + data.edge_index.shape[0] == 2 + ), f"Expected edge_index to have shape [2, num_edges], but got shape {data.edge_index.shape}" + + assert ( + data.edge_index.shape[1] == data.edge_attr.shape[0] + ), f"Mismatch between number of edges in edge_index ({data.edge_index.shape[1]}) and edge_attr ({data.edge_attr.shape[0]})" + + assert ( + len(set(data.edge_index[0].tolist())) == data.x.shape[0] + ), f"Number of unique source nodes in edge_index ({len(set(data.edge_index[0].tolist()))}) does not match number of nodes in x ({data.x.shape[0]})" 
+ + expected_data: GeomData = self.molecule_graph.get_aspirin_graph() self.assertTrue( torch.equal(data.edge_index, expected_data.edge_index), msg=( From 1a8dcb60897bfb5027d731647cc1955cc561b849 Mon Sep 17 00:00:00 2001 From: aditya0by0 Date: Wed, 14 May 2025 23:16:21 +0200 Subject: [PATCH 12/23] first src to tgt edges then tgt to src - instead of using adjacent directed edge, this one is better approach since we can stack edge attributes generated later without any further logic to rearrange edge_attr --- chebai_graph/preprocessing/reader.py | 13 +++---- tests/unit/test_data.py | 54 ++++++++++++++++------------ 2 files changed, 37 insertions(+), 30 deletions(-) diff --git a/chebai_graph/preprocessing/reader.py b/chebai_graph/preprocessing/reader.py index ac6f212..6cd4ecd 100644 --- a/chebai_graph/preprocessing/reader.py +++ b/chebai_graph/preprocessing/reader.py @@ -8,9 +8,8 @@ import torch from lightning_utilities.core.rank_zero import rank_zero_info, rank_zero_warn from torch_geometric.data import Data as GeomData -from torch_geometric.utils import from_networkx, to_undirected -import chebai_graph.preprocessing.properties as properties +from chebai_graph.preprocessing import properties from chebai_graph.preprocessing.collate import GraphCollator @@ -55,12 +54,10 @@ def _read_data(self, raw_data): x = torch.zeros((mol.GetNumAtoms(), 0)) - # We need to ensure that directed edges which form a undirected edge are adjacent to each other - edge_index_list = [[], []] - for bond in mol.GetBonds(): - edge_index_list[0].extend([bond.GetBeginAtomIdx(), bond.GetEndAtomIdx()]) - edge_index_list[1].extend([bond.GetEndAtomIdx(), bond.GetBeginAtomIdx()]) - edge_index = torch.tensor(edge_index_list, dtype=torch.long) + # First source to target edges, then target to source edges + src = [bond.GetBeginAtomIdx() for bond in mol.GetBonds()] + tgt = [bond.GetEndAtomIdx() for bond in mol.GetBonds()] + edge_index = torch.tensor([src + tgt, tgt + src], dtype=torch.long) # edge_index.shape == [2, num_edges]; edge_attr.shape == [num_edges, num_edge_features] edge_attr = torch.zeros((edge_index.size(1), 0)) diff --git a/tests/unit/test_data.py b/tests/unit/test_data.py index cd3b16f..4acf41d 100644 --- a/tests/unit/test_data.py +++ b/tests/unit/test_data.py @@ -58,35 +58,45 @@ def get_aspirin_graph(self): # --- Edge list (bidirectional) --- # Shape of edge_index for undirected graph: 2 x num_of_edges; (2x26) - # 2 directed edges of one undirected edge are adjacent to each other --- this is needed - - # fmt: off # Generated using RDKIT 2024.9.6 - edge_index = torch.tensor([ - [0, 1, 1, 2, 1, 3, 3, 4, 4, 5, 5, 6, 6, 7, 7, 8, 8, 9, 9, 10, 10, 11, 10, 12, 9, 4], # Start atoms (u) - [1, 0, 2, 1, 3, 1, 4, 3, 5, 4, 6, 5, 7, 6, 8, 7, 9, 8, 10, 9, 11, 10, 12, 10, 4, 9] # End atoms (v) + # fmt: off + _edge_index = torch.tensor([ + [0, 1, 1, 3, 4, 5, 6, 7, 8, 9, 10, 10, 9], # Start atoms (u) + [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 4] # End atoms (v) ], dtype=torch.long) # fmt: on + # Reverse the edges + reversed_edge_index = _edge_index[[1, 0], :] + + # First all directed edges from source to target are placed, + # then all directed edges from target to source are placed --- this is needed + undirected_edge_index = torch.cat([_edge_index, reversed_edge_index], dim=1) + # --- Dummy edge features --- - # Shape of edge_attr: num_of_edges x num_of_edges_features + # Shape of undirected_edge_attr: num_of_edges x num_of_edges_features (26 x 1) # fmt: off - edge_attr = torch.tensor([ - [1], [1], # C0 - C1, This two features 
belong to elements at index 0 and 1 in `edge_index` - [2], [2], # C1 - C2, This two features belong to elements at index 2 and 3 in `edge_index` - [2], [2], # C1 - O3, This two features belong to elements at index 4 and 5 in `edge_index` - [2], [2], # O3 - C4, This two features belong to elements at index 6 and 7 in `edge_index` - [1], [1], # C4 - C5, This two features belong to elements at index 8 and 9 in `edge_index` - [1], [1], # C5 - C6, This two features belong to elements at index 10 and 11 in `edge_index` - [1], [1], # C6 - C7, This two features belong to elements at index 12 and 13 in `edge_index` - [1], [1], # C7 - C8, This two features belong to elements at index 14 and 15 in `edge_index` - [1], [1], # C8 - C9, This two features belong to elements at index 16 and 17 in `edge_index` - [1], [1], # C9 - C10, This two features belong to elements at index 18 and 19 in `edge_index` - [1], [1], # C10 - O11, This two features belong to elements at index 20 and 21 in `edge_index` - [1], [1], # C10 - O12, This two features belong to elements at index 22 and 23 in `edge_index` - [1], [1], # C9 - C4, This two features belong to elements at index 24 and 25 in `edge_index` + _edge_attr = torch.tensor([ + [1], # C0 - C1, This two features belong to elements at index 0 in `edge_index` + [2], # C1 - C2, This two features belong to elements at index 1 in `edge_index` + [2], # C1 - O3, This two features belong to elements at index 2 in `edge_index` + [2], # O3 - C4, This two features belong to elements at index 3 in `edge_index` + [1], # C4 - C5, This two features belong to elements at index 4 in `edge_index` + [1], # C5 - C6, This two features belong to elements at index 5 in `edge_index` + [1], # C6 - C7, This two features belong to elements at index 6 in `edge_index` + [1], # C7 - C8, This two features belong to elements at index 7 in `edge_index` + [1], # C8 - C9, This two features belong to elements at index 8 in `edge_index` + [1], # C9 - C10, This two features belong to elements at index 9 in `edge_index` + [1], # C10 - O11, This two features belong to elements at index 10 in `edge_index` + [1], # C10 - O12, This two features belong to elements at index 11 in `edge_index` + [1], # C9 - C4, This two features belong to elements at index 12 in `edge_index` ], dtype=torch.float) # fmt: on + # Alignement of edge attributes should in same order as of edge_index + undirected_edge_attr = torch.cat([_edge_attr, _edge_attr], dim=0) + # Create graph data object - return Data(x=x, edge_index=edge_index, edge_attr=edge_attr) + return Data( + x=x, edge_index=undirected_edge_index, edge_attr=undirected_edge_attr + ) From 5d4c174e314ebede5283ded03fd88a794dd49aa9 Mon Sep 17 00:00:00 2001 From: aditya0by0 Date: Wed, 14 May 2025 23:16:50 +0200 Subject: [PATCH 13/23] add test for duplicate directed edges --- tests/unit/readers/testGraphPropertyReader.py | 31 +++++++++++++------ 1 file changed, 22 insertions(+), 9 deletions(-) diff --git a/tests/unit/readers/testGraphPropertyReader.py b/tests/unit/readers/testGraphPropertyReader.py index a4f18eb..0222fa4 100644 --- a/tests/unit/readers/testGraphPropertyReader.py +++ b/tests/unit/readers/testGraphPropertyReader.py @@ -27,17 +27,30 @@ def test_read_data(self) -> None: msg="The output should be an instance of torch_geometric.data.Data.", ) - assert ( - data.edge_index.shape[0] == 2 - ), f"Expected edge_index to have shape [2, num_edges], but got shape {data.edge_index.shape}" + self.assertEqual( + data.edge_index.shape[0], + 2, + msg=f"Expected edge_index to have 
shape [2, num_edges], but got shape {data.edge_index.shape}", + ) + + self.assertEqual( + data.edge_index.shape[1], + data.edge_attr.shape[0], + msg=f"Mismatch between number of edges in edge_index ({data.edge_index.shape[1]}) and edge_attr ({data.edge_attr.shape[0]})", + ) - assert ( - data.edge_index.shape[1] == data.edge_attr.shape[0] - ), f"Mismatch between number of edges in edge_index ({data.edge_index.shape[1]}) and edge_attr ({data.edge_attr.shape[0]})" + self.assertEqual( + len(set(data.edge_index[0].tolist())), + data.x.shape[0], + msg=f"Number of unique source nodes in edge_index ({len(set(data.edge_index[0].tolist()))}) does not match number of nodes in x ({data.x.shape[0]})", + ) - assert ( - len(set(data.edge_index[0].tolist())) == data.x.shape[0] - ), f"Number of unique source nodes in edge_index ({len(set(data.edge_index[0].tolist()))}) does not match number of nodes in x ({data.x.shape[0]})" + # Check for duplicates by checking if the rows are the same (direction matters) + _, counts = torch.unique(data.edge_index.t(), dim=0, return_counts=True) + self.assertFalse( + torch.any(counts > 1), + msg="There are duplicates of directed edge in edge_index", + ) expected_data: GeomData = self.molecule_graph.get_aspirin_graph() self.assertTrue( From 945ef7c34e77bd1ebddf23c76c4fb0e6997b9ce0 Mon Sep 17 00:00:00 2001 From: aditya0by0 Date: Wed, 14 May 2025 23:20:53 +0200 Subject: [PATCH 14/23] restore import --- chebai_graph/preprocessing/reader.py | 1 + 1 file changed, 1 insertion(+) diff --git a/chebai_graph/preprocessing/reader.py b/chebai_graph/preprocessing/reader.py index 6cd4ecd..687b199 100644 --- a/chebai_graph/preprocessing/reader.py +++ b/chebai_graph/preprocessing/reader.py @@ -8,6 +8,7 @@ import torch from lightning_utilities.core.rank_zero import rank_zero_info, rank_zero_warn from torch_geometric.data import Data as GeomData +from torch_geometric.utils import from_networkx from chebai_graph.preprocessing import properties from chebai_graph.preprocessing.collate import GraphCollator From 53a240ad1b478bea92efa3cecce7ee2d382d74ea Mon Sep 17 00:00:00 2001 From: aditya0by0 Date: Wed, 14 May 2025 23:24:14 +0200 Subject: [PATCH 15/23] concat edge attr for undirected graph --- chebai_graph/preprocessing/datasets/chebi.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/chebai_graph/preprocessing/datasets/chebi.py b/chebai_graph/preprocessing/datasets/chebi.py index 843ba35..f36fb64 100644 --- a/chebai_graph/preprocessing/datasets/chebi.py +++ b/chebai_graph/preprocessing/datasets/chebi.py @@ -168,7 +168,7 @@ def _merge_props_into_base(self, row): return GeomData( x=x, edge_index=geom_data.edge_index, - edge_attr=edge_attr, + edge_attr=torch.cat([edge_attr, edge_attr], dim=0), molecule_attr=molecule_attr, ) From b1f2da373de060702ee67b26b1e97c6049ff70cf Mon Sep 17 00:00:00 2001 From: aditya0by0 Date: Thu, 15 May 2025 11:56:35 +0200 Subject: [PATCH 16/23] concat prop values instead of edge_attr --- chebai_graph/preprocessing/datasets/chebi.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/chebai_graph/preprocessing/datasets/chebi.py b/chebai_graph/preprocessing/datasets/chebi.py index f36fb64..da5445e 100644 --- a/chebai_graph/preprocessing/datasets/chebi.py +++ b/chebai_graph/preprocessing/datasets/chebi.py @@ -162,13 +162,17 @@ def _merge_props_into_base(self, row): if isinstance(property, AtomProperty): x = torch.cat([x, property_values], dim=1) elif isinstance(property, BondProperty): - edge_attr = torch.cat([edge_attr, 
property_values], dim=1) + # Concat/Duplicate properties values for undirected graph as `edge_index` has first src to tgt edges, then tgt to src edges + edge_attr = torch.cat( + [edge_attr, torch.cat([property_values, property_values], dim=0)], + dim=1, + ) else: molecule_attr = torch.cat([molecule_attr, property_values], dim=1) return GeomData( x=x, edge_index=geom_data.edge_index, - edge_attr=torch.cat([edge_attr, edge_attr], dim=0), + edge_attr=edge_attr, molecule_attr=molecule_attr, ) From 53ca4387a6292b651142040ffddb46d85b0cb25b Mon Sep 17 00:00:00 2001 From: aditya0by0 Date: Fri, 23 May 2025 19:21:48 +0200 Subject: [PATCH 17/23] inherit from ChebiOverX instead of Base data module, as `load_processed_data_from_file` method used in this class is available in Dynamic dataset class --- chebai_graph/preprocessing/datasets/chebi.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/chebai_graph/preprocessing/datasets/chebi.py b/chebai_graph/preprocessing/datasets/chebi.py index b132856..8e6599f 100644 --- a/chebai_graph/preprocessing/datasets/chebi.py +++ b/chebai_graph/preprocessing/datasets/chebi.py @@ -1,14 +1,15 @@ import importlib import os +from abc import ABC from typing import Callable, List, Optional import pandas as pd import torch import tqdm -from chebai.preprocessing.datasets.base import XYBaseDataModule from chebai.preprocessing.datasets.chebi import ( ChEBIOver50, ChEBIOver100, + ChEBIOverX, ChEBIOverXPartial, ) from lightning_utilities.core.rank_zero import rank_zero_info @@ -48,7 +49,7 @@ def _resolve_property( return getattr(graph_properties, property)() -class GraphPropertiesMixIn(XYBaseDataModule): +class GraphPropertiesMixIn(ChEBIOverX, ABC): READER = GraphPropertyReader def __init__( From 4319e4731a76bc800c009354601825f0dbedd340 Mon Sep 17 00:00:00 2001 From: aditya0by0 Date: Sun, 25 May 2025 11:24:40 +0200 Subject: [PATCH 18/23] `nan_to_num` numpy2.x compatibility fix for https://github.com/ChEB-AI/python-chebai-graph/issues/10 --- chebai_graph/preprocessing/properties.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/chebai_graph/preprocessing/properties.py b/chebai_graph/preprocessing/properties.py index 29808cd..2b3acf8 100644 --- a/chebai_graph/preprocessing/properties.py +++ b/chebai_graph/preprocessing/properties.py @@ -155,5 +155,5 @@ def get_property_value(self, mol: Chem.rdchem.Mol): features_normalized = generator_normalized.processMol( mol, Chem.MolToSmiles(mol) ) - np.nan_to_num(features_normalized, copy=False) + features_normalized = np.nan_to_num(features_normalized) return [features_normalized[1:]] From 7c0a484edb86a8d9297360350e913afe5d3c0ca9 Mon Sep 17 00:00:00 2001 From: aditya0by0 Date: Sun, 25 May 2025 18:59:03 +0200 Subject: [PATCH 19/23] add print statements --- chebai_graph/preprocessing/datasets/chebi.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/chebai_graph/preprocessing/datasets/chebi.py b/chebai_graph/preprocessing/datasets/chebi.py index 8e6599f..9bddad7 100644 --- a/chebai_graph/preprocessing/datasets/chebi.py +++ b/chebai_graph/preprocessing/datasets/chebi.py @@ -107,10 +107,12 @@ def _setup_properties(self): if not os.path.isfile(self.get_property_path(property)): rank_zero_info(f"Processing property {property.name}") # read all property values first, then encode + rank_zero_info(f"\tReading property valeus...") property_values = [ self.reader.read_property(feat, property) for feat in tqdm.tqdm(features) ] + rank_zero_info(f"\tEncoding property values...") 
property.encoder.on_start(property_values=property_values) encoded_values = [ enc_if_not_none(property.encoder.encode, value) @@ -151,6 +153,7 @@ def _merge_props_into_base(self, row): assert isinstance(geom_data, GeomData) for property in self.properties: property_values = row[f"{property.name}"] + rank_zero_info(f"Merging {property.name} into base dataframe...") if isinstance(property_values, torch.Tensor): if len(property_values.size()) == 0: property_values = property_values.unsqueeze(0) From ffc0b7529671561944635bbc9b95bccffcc3cb22 Mon Sep 17 00:00:00 2001 From: aditya0by0 Date: Sun, 25 May 2025 19:18:55 +0200 Subject: [PATCH 20/23] remove print for dataloader phase --- chebai_graph/preprocessing/datasets/chebi.py | 1 - 1 file changed, 1 deletion(-) diff --git a/chebai_graph/preprocessing/datasets/chebi.py b/chebai_graph/preprocessing/datasets/chebi.py index 9bddad7..2721b1c 100644 --- a/chebai_graph/preprocessing/datasets/chebi.py +++ b/chebai_graph/preprocessing/datasets/chebi.py @@ -153,7 +153,6 @@ def _merge_props_into_base(self, row): assert isinstance(geom_data, GeomData) for property in self.properties: property_values = row[f"{property.name}"] - rank_zero_info(f"Merging {property.name} into base dataframe...") if isinstance(property_values, torch.Tensor): if len(property_values.size()) == 0: property_values = property_values.unsqueeze(0) From 0a39749176106d782aaa4a4339f900806382452b Mon Sep 17 00:00:00 2001 From: aditya0by0 Date: Wed, 28 May 2025 11:18:31 +0200 Subject: [PATCH 21/23] Update .gitignore --- .gitignore | 1 + 1 file changed, 1 insertion(+) diff --git a/.gitignore b/.gitignore index 9676c5b..3f0c0ab 100644 --- a/.gitignore +++ b/.gitignore @@ -168,3 +168,4 @@ cython_debug/ /results_buffer electra_pretrained.ckpt .isort.cfg +/.vscode From 51e609ec74c51e23607ebb2045db6f68a4c04624 Mon Sep 17 00:00:00 2001 From: aditya0by0 Date: Wed, 28 May 2025 11:35:11 +0200 Subject: [PATCH 22/23] why input channels needed? 
when n_atom_properties is there --- chebai_graph/models/graph.py | 38 ++++++++++++++++-------------------- 1 file changed, 17 insertions(+), 21 deletions(-) diff --git a/chebai_graph/models/graph.py b/chebai_graph/models/graph.py index 5da9a62..8dec420 100644 --- a/chebai_graph/models/graph.py +++ b/chebai_graph/models/graph.py @@ -1,5 +1,6 @@ import logging import typing +from abc import ABC import torch import torch.nn.functional as F @@ -15,7 +16,7 @@ logging.getLogger("pysmiles").setLevel(logging.CRITICAL) -class GraphBaseNet(ChebaiBaseNet): +class GraphBaseNet(ChebaiBaseNet, ABC): def _get_prediction_and_labels(self, data, labels, output): return torch.sigmoid(output), labels.int() @@ -104,26 +105,26 @@ def __init__(self, config: typing.Dict, **kwargs): self.activation = F.elu self.dropout = nn.Dropout(self.dropout_rate) - self.convs = torch.nn.ModuleList([]) - for i in range(self.n_conv_layers): - if i == 0: - self.convs.append( - tgnn.ResGatedGraphConv( - self.n_atom_properties, - self.in_length, - # dropout=self.dropout_rate, - edge_dim=self.n_bond_properties, - ) - ) + + self.convs.append( + tgnn.ResGatedGraphConv( + self.n_atom_properties, + self.hidden_length, + # dropout=self.dropout_rate, + edge_dim=self.n_bond_properties, + ) + ) + + for _ in range(self.n_conv_layers - 1): self.convs.append( tgnn.ResGatedGraphConv( - self.in_length, self.in_length, edge_dim=self.n_bond_properties + self.hidden_length, + self.hidden_length, + # dropout=self.dropout_rate, + edge_dim=self.n_bond_properties, ) ) - self.final_conv = tgnn.ResGatedGraphConv( - self.in_length, self.hidden_length, edge_dim=self.n_bond_properties - ) def forward(self, batch): graph_data = batch["features"][0] @@ -136,11 +137,6 @@ def forward(self, batch): a = self.activation( conv(a, graph_data.edge_index.long(), edge_attr=graph_data.edge_attr) ) - a = self.activation( - self.final_conv( - a, graph_data.edge_index.long(), edge_attr=graph_data.edge_attr - ) - ) return a From dcf38a4932da32aadc786026f89be76ae6c00d06 Mon Sep 17 00:00:00 2001 From: aditya0by0 Date: Wed, 28 May 2025 23:22:17 +0200 Subject: [PATCH 23/23] remove in_length from config --- chebai_graph/models/graph.py | 2 -- configs/model/gnn_res_gated.yml | 3 +-- 2 files changed, 1 insertion(+), 4 deletions(-) diff --git a/chebai_graph/models/graph.py b/chebai_graph/models/graph.py index 8dec420..7c4082a 100644 --- a/chebai_graph/models/graph.py +++ b/chebai_graph/models/graph.py @@ -85,8 +85,6 @@ class ResGatedGraphConvNetBase(GraphBaseNet): def __init__(self, config: typing.Dict, **kwargs): super().__init__(**kwargs) - - self.in_length = config["in_length"] self.hidden_length = config["hidden_length"] self.dropout_rate = config["dropout_rate"] self.n_conv_layers = config["n_conv_layers"] if "n_conv_layers" in config else 3 diff --git a/configs/model/gnn_res_gated.yml b/configs/model/gnn_res_gated.yml index 27d1e78..62d990d 100644 --- a/configs/model/gnn_res_gated.yml +++ b/configs/model/gnn_res_gated.yml @@ -3,8 +3,7 @@ init_args: optimizer_kwargs: lr: 1e-3 config: - in_length: 256 - hidden_length: 512 + hidden_length: 256 dropout_rate: 0.1 n_conv_layers: 3 n_linear_layers: 3
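
Note (not part of the patch series above): the following is a minimal, self-contained sketch of the undirected-edge convention that PATCH 12 and PATCH 16 converge on — all source-to-target edges first, then all target-to-source edges, so per-bond features can be stacked twice and stay aligned with `edge_index` without any reordering. The helper name `mol_to_undirected_graph` and the bond-type feature are illustrative assumptions, not code from the patches; it assumes RDKit, PyTorch and PyTorch Geometric are installed.

```python
import torch
from rdkit import Chem
from torch_geometric.data import Data


def mol_to_undirected_graph(mol: Chem.rdchem.Mol) -> Data:
    # One directed edge per bond, source -> target ...
    src = [bond.GetBeginAtomIdx() for bond in mol.GetBonds()]
    tgt = [bond.GetEndAtomIdx() for bond in mol.GetBonds()]
    # ... then all reversed edges appended as a second block: shape [2, 2 * num_bonds].
    edge_index = torch.tensor([src + tgt, tgt + src], dtype=torch.long)

    # Per-bond features (here just the bond order, purely illustrative) are computed
    # once per bond and duplicated in the same order, keeping edge_attr row i aligned
    # with edge_index column i.
    bond_feats = torch.tensor(
        [[bond.GetBondTypeAsDouble()] for bond in mol.GetBonds()], dtype=torch.float
    )
    edge_attr = torch.cat([bond_feats, bond_feats], dim=0)

    # Node features are left empty, as in GraphPropertyReader._read_data;
    # atom properties are attached later by the dataset classes.
    x = torch.zeros((mol.GetNumAtoms(), 0))
    return Data(x=x, edge_index=edge_index, edge_attr=edge_attr)


if __name__ == "__main__":
    aspirin = Chem.MolFromSmiles("CC(=O)OC1=CC=CC=C1C(=O)O")  # 13 atoms, 13 bonds
    graph = mol_to_undirected_graph(aspirin)
    print(graph.edge_index.shape, graph.edge_attr.shape)  # [2, 26] and [26, 1]
```

Because the reversed edges form one contiguous block rather than being interleaved, any per-bond property tensor computed later can simply be concatenated with itself (as `_merge_props_into_base` does in PATCH 16) instead of being rearranged to match the edge order.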