From 689c5ddc56e56c3275c2ea1300cc94d2e6af2153 Mon Sep 17 00:00:00 2001 From: aditya0by0 Date: Wed, 30 Jul 2025 17:58:16 +0200 Subject: [PATCH 1/8] dynamic imports for readers --- chebai/preprocessing/reader.py | 11 ++++++++--- 1 file changed, 8 insertions(+), 3 deletions(-) diff --git a/chebai/preprocessing/reader.py b/chebai/preprocessing/reader.py index aa9960f9..c737df75 100644 --- a/chebai/preprocessing/reader.py +++ b/chebai/preprocessing/reader.py @@ -5,10 +5,7 @@ from itertools import islice from typing import Any, Dict, List, Optional -import deepsmiles -import selfies as sf from pysmiles.read_smiles import _tokenize -from transformers import RobertaTokenizerFast from chebai.preprocessing.collate import DefaultCollator, RaggedCollator @@ -205,6 +202,8 @@ class DeepChemDataReader(ChemDataReader): """ def __init__(self, *args, **kwargs): + import deepsmiles + super().__init__(*args, **kwargs) self.converter = deepsmiles.Converter(rings=True, branches=True) self.error_count = 0 @@ -279,6 +278,8 @@ def __init__( vsize: int = 4000, **kwargs, ): + from transformers import RobertaTokenizerFast + super().__init__(*args, **kwargs) self.tokenizer = RobertaTokenizerFast.from_pretrained( data_path, max_len=max_len @@ -312,6 +313,8 @@ def __init__( vsize: int = 4000, **kwargs, ): + import selfies as sf + super().__init__(*args, **kwargs) self.error_count = 0 sf.set_semantic_constraints("hypervalent") @@ -323,6 +326,8 @@ def name(cls) -> str: def _read_data(self, raw_data: str) -> Optional[List[int]]: """Read and tokenize raw data using SELFIES.""" + import selfies as sf + try: tokenized = sf.split_selfies(sf.encoder(raw_data.strip(), strict=True)) tokenized = [self._get_token_index(v) for v in tokenized] From bbcf6352c626bb74f1a89e63c3826969be24b4db Mon Sep 17 00:00:00 2001 From: aditya0by0 Date: Wed, 30 Jul 2025 18:16:59 +0200 Subject: [PATCH 2/8] dyamic import for base dm --- chebai/preprocessing/datasets/base.py | 26 ++++++++++++++++---------- 1 file changed, 16 insertions(+), 10 deletions(-) diff --git a/chebai/preprocessing/datasets/base.py b/chebai/preprocessing/datasets/base.py index 4a1898bc..a229e7af 100644 --- a/chebai/preprocessing/datasets/base.py +++ b/chebai/preprocessing/datasets/base.py @@ -1,24 +1,21 @@ import os import random from abc import ABC, abstractmethod -from typing import Any, Dict, Generator, List, Optional, Tuple, Union +from typing import TYPE_CHECKING, Any, Dict, Generator, List, Optional, Tuple, Union import lightning as pl -import networkx as nx import pandas as pd import torch import tqdm -from iterstrat.ml_stratifiers import ( - MultilabelStratifiedKFold, - MultilabelStratifiedShuffleSplit, -) from lightning.pytorch.core.datamodule import LightningDataModule from lightning_utilities.core.rank_zero import rank_zero_info -from sklearn.model_selection import StratifiedShuffleSplit from torch.utils.data import DataLoader from chebai.preprocessing import reader as dr +if TYPE_CHECKING: + import networkx as nx + class XYBaseDataModule(LightningDataModule): """ @@ -818,7 +815,7 @@ def _download_required_data(self) -> str: pass @abstractmethod - def _extract_class_hierarchy(self, data_path: str) -> nx.DiGraph: + def _extract_class_hierarchy(self, data_path: str) -> "nx.DiGraph": """ Extracts the class hierarchy from the data. Constructs a directed graph (DiGraph) using NetworkX, where nodes are annotated with fields/terms from @@ -833,7 +830,7 @@ def _extract_class_hierarchy(self, data_path: str) -> nx.DiGraph: pass @abstractmethod - def _graph_to_raw_dataset(self, graph: nx.DiGraph) -> pd.DataFrame: + def _graph_to_raw_dataset(self, graph: "nx.DiGraph") -> pd.DataFrame: """ Converts the graph to a raw dataset. Uses the graph created by `_extract_class_hierarchy` method to extract the @@ -848,7 +845,7 @@ def _graph_to_raw_dataset(self, graph: nx.DiGraph) -> pd.DataFrame: pass @abstractmethod - def select_classes(self, g: nx.DiGraph, *args, **kwargs) -> List: + def select_classes(self, g: "nx.DiGraph", *args, **kwargs) -> List: """ Selects classes from the dataset based on a specified criteria. @@ -1023,6 +1020,9 @@ def get_test_split( Raises: ValueError: If the DataFrame does not contain a column named "labels". """ + from iterstrat.ml_stratifiers import MultilabelStratifiedShuffleSplit + from sklearn.model_selection import StratifiedShuffleSplit + print("Get test data split") labels_list = df["labels"].tolist() @@ -1060,6 +1060,12 @@ def get_train_val_splits_given_test( and validation DataFrames. The keys are the names of the train and validation sets, and the values are the corresponding DataFrames. """ + from iterstrat.ml_stratifiers import ( + MultilabelStratifiedKFold, + MultilabelStratifiedShuffleSplit, + ) + from sklearn.model_selection import StratifiedShuffleSplit + print("Split dataset into train / val with given test set") test_ids = test_df["ident"].tolist() From 0bfd79d246be0279d082c42415aea2e73672548f Mon Sep 17 00:00:00 2001 From: aditya0by0 Date: Wed, 30 Jul 2025 18:17:19 +0200 Subject: [PATCH 3/8] dynamic import for chebi dm --- chebai/preprocessing/datasets/chebi.py | 30 ++++++++++++++++++-------- 1 file changed, 21 insertions(+), 9 deletions(-) diff --git a/chebai/preprocessing/datasets/chebi.py b/chebai/preprocessing/datasets/chebi.py index 1df144d9..06aa1e70 100644 --- a/chebai/preprocessing/datasets/chebi.py +++ b/chebai/preprocessing/datasets/chebi.py @@ -13,17 +13,18 @@ import pickle from abc import ABC from collections import OrderedDict -from typing import Any, Dict, Generator, List, Optional, Tuple, Union +from typing import TYPE_CHECKING, Any, Dict, Generator, List, Optional, Tuple, Union -import fastobo -import networkx as nx import pandas as pd -import requests import torch from chebai.preprocessing import reader as dr from chebai.preprocessing.datasets.base import XYBaseDataModule, _DynamicDataset +if TYPE_CHECKING: + import fastobo + import networkx as nx + # exclude some entities from the dataset because the violate disjointness axioms CHEBI_BLACKLIST = [ 194026, @@ -212,6 +213,8 @@ def _load_chebi(self, version: int) -> str: Returns: str: The file path of the loaded ChEBI ontology. """ + import requests + chebi_name = self.raw_file_names_dict["chebi"] chebi_path = os.path.join(self.raw_dir, chebi_name) if not os.path.isfile(chebi_path): @@ -223,7 +226,7 @@ def _load_chebi(self, version: int) -> str: open(chebi_path, "wb").write(r.content) return chebi_path - def _extract_class_hierarchy(self, data_path: str) -> nx.DiGraph: + def _extract_class_hierarchy(self, data_path: str) -> "nx.DiGraph": """ Extracts the class hierarchy from the ChEBI ontology. Constructs a directed graph (DiGraph) using NetworkX, where nodes are annotated with fields/terms from @@ -235,6 +238,9 @@ def _extract_class_hierarchy(self, data_path: str) -> nx.DiGraph: Returns: nx.DiGraph: The class hierarchy. """ + import fastobo + import networkx as nx + with open(data_path, encoding="utf-8") as chebi: chebi = "\n".join(line for line in chebi if not line.startswith("xref:")) @@ -262,7 +268,7 @@ def _extract_class_hierarchy(self, data_path: str) -> nx.DiGraph: print("Compute transitive closure") return nx.transitive_closure_dag(g) - def _graph_to_raw_dataset(self, g: nx.DiGraph) -> pd.DataFrame: + def _graph_to_raw_dataset(self, g: "nx.DiGraph") -> pd.DataFrame: """ Converts the graph to a raw dataset. Uses the graph created by `_extract_class_hierarchy` method to extract the @@ -274,6 +280,8 @@ def _graph_to_raw_dataset(self, g: nx.DiGraph) -> pd.DataFrame: Returns: pd.DataFrame: The raw dataset created from the graph. """ + import networkx as nx + smiles = nx.get_node_attributes(g, "smiles") names = nx.get_node_attributes(g, "name") @@ -574,7 +582,7 @@ def _name(self) -> str: """ return f"ChEBI{self.THRESHOLD}" - def select_classes(self, g: nx.DiGraph, *args, **kwargs) -> List: + def select_classes(self, g: "nx.DiGraph", *args, **kwargs) -> List: """ Selects classes from the ChEBI dataset based on the number of successors meeting a specified threshold. @@ -599,6 +607,8 @@ def select_classes(self, g: nx.DiGraph, *args, **kwargs) -> List: - The `THRESHOLD` attribute should be defined in the subclass of this class. - Nodes without a 'smiles' attribute are ignored in the successor count. """ + import networkx as nx + smiles = nx.get_node_attributes(g, "smiles") nodes = list( sorted( @@ -731,7 +741,7 @@ def processed_dir_main(self) -> str: "processed", ) - def _extract_class_hierarchy(self, chebi_path: str) -> nx.DiGraph: + def _extract_class_hierarchy(self, chebi_path: str) -> "nx.DiGraph": """ Extracts a subset of ChEBI based on subclasses of the top class ID. @@ -786,7 +796,7 @@ def chebi_to_int(s: str) -> int: return int(s[s.index(":") + 1 :]) -def term_callback(doc: fastobo.term.TermFrame) -> Union[Dict, bool]: +def term_callback(doc: "fastobo.term.TermFrame") -> Union[Dict, bool]: """ Extracts information from a ChEBI term document. This function takes a ChEBI term document as input and extracts relevant information such as the term ID, parents, @@ -803,6 +813,8 @@ def term_callback(doc: fastobo.term.TermFrame) -> Union[Dict, bool]: - "name": The name of the ChEBI term. - "smiles": The SMILES string associated with the ChEBI term, if available. """ + import fastobo + parts = set() parents = [] name = None From 109723cf1c3931f393fa64748f27c3c13a6ece60 Mon Sep 17 00:00:00 2001 From: aditya0by0 Date: Wed, 30 Jul 2025 18:28:12 +0200 Subject: [PATCH 4/8] dynamic imports for log and struc --- chebai/loggers/custom.py | 3 ++- chebai/preprocessing/structures.py | 8 ++++++-- 2 files changed, 8 insertions(+), 3 deletions(-) diff --git a/chebai/loggers/custom.py b/chebai/loggers/custom.py index d1b4282d..04c48849 100644 --- a/chebai/loggers/custom.py +++ b/chebai/loggers/custom.py @@ -2,7 +2,6 @@ from datetime import datetime from typing import List, Literal, Optional, Union -import wandb from lightning.fabric.utilities.types import _PATH from lightning.pytorch.callbacks import ModelCheckpoint from lightning.pytorch.loggers import WandbLogger @@ -105,6 +104,8 @@ def set_fold(self, fold: int) -> None: Args: fold (int): Cross-validation fold number. """ + import wandb + if fold != self._fold: self._fold = fold # Start new experiment diff --git a/chebai/preprocessing/structures.py b/chebai/preprocessing/structures.py index 5cfe7966..2ab5de5d 100644 --- a/chebai/preprocessing/structures.py +++ b/chebai/preprocessing/structures.py @@ -1,8 +1,10 @@ -from typing import Any, Tuple, Union +from typing import TYPE_CHECKING, Any, Tuple, Union -import networkx as nx import torch +if TYPE_CHECKING: + import networkx as nx + class XYData(torch.utils.data.Dataset): """ @@ -129,6 +131,8 @@ def to_x(self, device: torch.device) -> Tuple[nx.Graph, ...]: Returns: A tuple of molecular graphs with node attributes on the specified device. """ + import networkx as nx + l_ = [] for g in self.x: graph = g.copy() From b87129da8fdd5b2655dceb24f4295519af09baa4 Mon Sep 17 00:00:00 2001 From: aditya0by0 Date: Wed, 30 Jul 2025 18:28:41 +0200 Subject: [PATCH 5/8] to avoid access to pubchem file: dynamic import --- chebai/loss/bce_weighted.py | 3 ++- chebai/loss/semantic.py | 8 ++++++-- 2 files changed, 8 insertions(+), 3 deletions(-) diff --git a/chebai/loss/bce_weighted.py b/chebai/loss/bce_weighted.py index 1d5ea763..993d535e 100644 --- a/chebai/loss/bce_weighted.py +++ b/chebai/loss/bce_weighted.py @@ -5,7 +5,6 @@ from chebai.preprocessing.datasets.base import XYBaseDataModule from chebai.preprocessing.datasets.chebi import _ChEBIDataExtractor -from chebai.preprocessing.datasets.pubchem import LabeledUnlabeledMixed class BCEWeighted(torch.nn.BCEWithLogitsLoss): @@ -27,6 +26,8 @@ def __init__( data_extractor: Optional[XYBaseDataModule] = None, **kwargs, ): + from chebai.preprocessing.datasets.pubchem import LabeledUnlabeledMixed + self.beta = beta if isinstance(data_extractor, LabeledUnlabeledMixed): data_extractor = data_extractor.labeled diff --git a/chebai/loss/semantic.py b/chebai/loss/semantic.py index 18485269..877e0060 100644 --- a/chebai/loss/semantic.py +++ b/chebai/loss/semantic.py @@ -2,14 +2,16 @@ import math import os import pickle -from typing import List, Literal, Union +from typing import TYPE_CHECKING, List, Literal, Union import torch from chebai.loss.bce_weighted import BCEWeighted from chebai.preprocessing.datasets.base import XYBaseDataModule from chebai.preprocessing.datasets.chebi import ChEBIOver100, _ChEBIDataExtractor -from chebai.preprocessing.datasets.pubchem import LabeledUnlabeledMixed + +if TYPE_CHECKING: + from chebai.preprocessing.datasets.pubchem import LabeledUnlabeledMixed class ImplicationLoss(torch.nn.Module): @@ -68,6 +70,8 @@ def __init__( multiply_with_base_loss: bool = True, no_grads: bool = False, ): + from chebai.preprocessing.datasets.pubchem import LabeledUnlabeledMixed + super().__init__() # automatically choose labeled subset for implication filter in case of mixed dataset if isinstance(data_extractor, LabeledUnlabeledMixed): From 610d9d4b2b50cfb362ceda62057121aa0367346c Mon Sep 17 00:00:00 2001 From: aditya0by0 Date: Wed, 30 Jul 2025 18:36:56 +0200 Subject: [PATCH 6/8] fix action error: add string literals --- chebai/loss/semantic.py | 2 +- chebai/preprocessing/structures.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/chebai/loss/semantic.py b/chebai/loss/semantic.py index 877e0060..89abb175 100644 --- a/chebai/loss/semantic.py +++ b/chebai/loss/semantic.py @@ -342,7 +342,7 @@ class DisjointLoss(ImplicationLoss): def __init__( self, path_to_disjointness: str, - data_extractor: Union[_ChEBIDataExtractor, LabeledUnlabeledMixed], + data_extractor: Union[_ChEBIDataExtractor, "LabeledUnlabeledMixed"], base_loss: torch.nn.Module = None, disjoint_loss_weight: float = 100, **kwargs, diff --git a/chebai/preprocessing/structures.py b/chebai/preprocessing/structures.py index 2ab5de5d..4a69ea4f 100644 --- a/chebai/preprocessing/structures.py +++ b/chebai/preprocessing/structures.py @@ -121,7 +121,7 @@ class XYMolData(XYData): kwargs: Additional fields to store in the dataset. """ - def to_x(self, device: torch.device) -> Tuple[nx.Graph, ...]: + def to_x(self, device: torch.device) -> Tuple["nx.Graph", ...]: """ Moves the node attributes of the molecular graphs to the specified device. From 925eea556439e8f2f2f82040a34c4e03d7bfa067 Mon Sep 17 00:00:00 2001 From: aditya0by0 Date: Wed, 30 Jul 2025 18:42:26 +0200 Subject: [PATCH 7/8] add inference dependencies --- pyproject.toml | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/pyproject.toml b/pyproject.toml index 5723d78a..99728a13 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -41,6 +41,16 @@ dev = ["black", "isort", "pre-commit"] plot = ["matplotlib", "seaborn"] wandb = ["wandb"] +inference = [ + "numpy", + "pandas", + "torch", + "transformers", + "pysmiles==1.1.2", + "rdkit", + "lightning>=2.5", +] + [tool.setuptools] include-package-data = true license-files = ["LICEN[CS]E*"] From 2a039b6425d6ea4170a8b5d9e39085276f819db2 Mon Sep 17 00:00:00 2001 From: aditya0by0 Date: Sun, 10 Aug 2025 00:01:01 +0200 Subject: [PATCH 8/8] fix nx dynamic import --- chebai/preprocessing/datasets/chebi.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/chebai/preprocessing/datasets/chebi.py b/chebai/preprocessing/datasets/chebi.py index e9afd854..cbd04895 100644 --- a/chebai/preprocessing/datasets/chebi.py +++ b/chebai/preprocessing/datasets/chebi.py @@ -783,8 +783,10 @@ def _extract_class_hierarchy(self, chebi_path: str) -> "nx.DiGraph": ) return g - def select_classes(self, g: nx.DiGraph, *args, **kwargs) -> List: + def select_classes(self, g: "nx.DiGraph", *args, **kwargs) -> List: """Only selects classes that meet the threshold AND are subclasses of the top class ID (including itself).""" + import networkx as nx + smiles = nx.get_node_attributes(g, "smiles") nodes = list( sorted(