diff --git a/chebai/loggers/custom.py b/chebai/loggers/custom.py index d1b4282d..04c48849 100644 --- a/chebai/loggers/custom.py +++ b/chebai/loggers/custom.py @@ -2,7 +2,6 @@ from datetime import datetime from typing import List, Literal, Optional, Union -import wandb from lightning.fabric.utilities.types import _PATH from lightning.pytorch.callbacks import ModelCheckpoint from lightning.pytorch.loggers import WandbLogger @@ -105,6 +104,8 @@ def set_fold(self, fold: int) -> None: Args: fold (int): Cross-validation fold number. """ + import wandb + if fold != self._fold: self._fold = fold # Start new experiment diff --git a/chebai/loss/bce_weighted.py b/chebai/loss/bce_weighted.py index 1d5ea763..993d535e 100644 --- a/chebai/loss/bce_weighted.py +++ b/chebai/loss/bce_weighted.py @@ -5,7 +5,6 @@ from chebai.preprocessing.datasets.base import XYBaseDataModule from chebai.preprocessing.datasets.chebi import _ChEBIDataExtractor -from chebai.preprocessing.datasets.pubchem import LabeledUnlabeledMixed class BCEWeighted(torch.nn.BCEWithLogitsLoss): @@ -27,6 +26,8 @@ def __init__( data_extractor: Optional[XYBaseDataModule] = None, **kwargs, ): + from chebai.preprocessing.datasets.pubchem import LabeledUnlabeledMixed + self.beta = beta if isinstance(data_extractor, LabeledUnlabeledMixed): data_extractor = data_extractor.labeled diff --git a/chebai/loss/semantic.py b/chebai/loss/semantic.py index 18485269..89abb175 100644 --- a/chebai/loss/semantic.py +++ b/chebai/loss/semantic.py @@ -2,14 +2,16 @@ import math import os import pickle -from typing import List, Literal, Union +from typing import TYPE_CHECKING, List, Literal, Union import torch from chebai.loss.bce_weighted import BCEWeighted from chebai.preprocessing.datasets.base import XYBaseDataModule from chebai.preprocessing.datasets.chebi import ChEBIOver100, _ChEBIDataExtractor -from chebai.preprocessing.datasets.pubchem import LabeledUnlabeledMixed + +if TYPE_CHECKING: + from chebai.preprocessing.datasets.pubchem import LabeledUnlabeledMixed class ImplicationLoss(torch.nn.Module): @@ -68,6 +70,8 @@ def __init__( multiply_with_base_loss: bool = True, no_grads: bool = False, ): + from chebai.preprocessing.datasets.pubchem import LabeledUnlabeledMixed + super().__init__() # automatically choose labeled subset for implication filter in case of mixed dataset if isinstance(data_extractor, LabeledUnlabeledMixed): @@ -338,7 +342,7 @@ class DisjointLoss(ImplicationLoss): def __init__( self, path_to_disjointness: str, - data_extractor: Union[_ChEBIDataExtractor, LabeledUnlabeledMixed], + data_extractor: Union[_ChEBIDataExtractor, "LabeledUnlabeledMixed"], base_loss: torch.nn.Module = None, disjoint_loss_weight: float = 100, **kwargs, diff --git a/chebai/preprocessing/datasets/base.py b/chebai/preprocessing/datasets/base.py index 4a1898bc..a229e7af 100644 --- a/chebai/preprocessing/datasets/base.py +++ b/chebai/preprocessing/datasets/base.py @@ -1,24 +1,21 @@ import os import random from abc import ABC, abstractmethod -from typing import Any, Dict, Generator, List, Optional, Tuple, Union +from typing import TYPE_CHECKING, Any, Dict, Generator, List, Optional, Tuple, Union import lightning as pl -import networkx as nx import pandas as pd import torch import tqdm -from iterstrat.ml_stratifiers import ( - MultilabelStratifiedKFold, - MultilabelStratifiedShuffleSplit, -) from lightning.pytorch.core.datamodule import LightningDataModule from lightning_utilities.core.rank_zero import rank_zero_info -from sklearn.model_selection import StratifiedShuffleSplit from torch.utils.data import DataLoader from chebai.preprocessing import reader as dr +if TYPE_CHECKING: + import networkx as nx + class XYBaseDataModule(LightningDataModule): """ @@ -818,7 +815,7 @@ def _download_required_data(self) -> str: pass @abstractmethod - def _extract_class_hierarchy(self, data_path: str) -> nx.DiGraph: + def _extract_class_hierarchy(self, data_path: str) -> "nx.DiGraph": """ Extracts the class hierarchy from the data. Constructs a directed graph (DiGraph) using NetworkX, where nodes are annotated with fields/terms from @@ -833,7 +830,7 @@ def _extract_class_hierarchy(self, data_path: str) -> nx.DiGraph: pass @abstractmethod - def _graph_to_raw_dataset(self, graph: nx.DiGraph) -> pd.DataFrame: + def _graph_to_raw_dataset(self, graph: "nx.DiGraph") -> pd.DataFrame: """ Converts the graph to a raw dataset. Uses the graph created by `_extract_class_hierarchy` method to extract the @@ -848,7 +845,7 @@ def _graph_to_raw_dataset(self, graph: nx.DiGraph) -> pd.DataFrame: pass @abstractmethod - def select_classes(self, g: nx.DiGraph, *args, **kwargs) -> List: + def select_classes(self, g: "nx.DiGraph", *args, **kwargs) -> List: """ Selects classes from the dataset based on a specified criteria. @@ -1023,6 +1020,9 @@ def get_test_split( Raises: ValueError: If the DataFrame does not contain a column named "labels". """ + from iterstrat.ml_stratifiers import MultilabelStratifiedShuffleSplit + from sklearn.model_selection import StratifiedShuffleSplit + print("Get test data split") labels_list = df["labels"].tolist() @@ -1060,6 +1060,12 @@ def get_train_val_splits_given_test( and validation DataFrames. The keys are the names of the train and validation sets, and the values are the corresponding DataFrames. """ + from iterstrat.ml_stratifiers import ( + MultilabelStratifiedKFold, + MultilabelStratifiedShuffleSplit, + ) + from sklearn.model_selection import StratifiedShuffleSplit + print("Split dataset into train / val with given test set") test_ids = test_df["ident"].tolist() diff --git a/chebai/preprocessing/datasets/chebi.py b/chebai/preprocessing/datasets/chebi.py index 9fa1c1c7..cbd04895 100644 --- a/chebai/preprocessing/datasets/chebi.py +++ b/chebai/preprocessing/datasets/chebi.py @@ -13,17 +13,18 @@ import pickle from abc import ABC from collections import OrderedDict -from typing import Any, Dict, Generator, List, Optional, Tuple, Union +from typing import TYPE_CHECKING, Any, Dict, Generator, List, Optional, Tuple, Union -import fastobo -import networkx as nx import pandas as pd -import requests import torch from chebai.preprocessing import reader as dr from chebai.preprocessing.datasets.base import XYBaseDataModule, _DynamicDataset +if TYPE_CHECKING: + import fastobo + import networkx as nx + # exclude some entities from the dataset because the violate disjointness axioms CHEBI_BLACKLIST = [ 194026, @@ -212,6 +213,8 @@ def _load_chebi(self, version: int) -> str: Returns: str: The file path of the loaded ChEBI ontology. """ + import requests + chebi_name = self.raw_file_names_dict["chebi"] chebi_path = os.path.join(self.raw_dir, chebi_name) if not os.path.isfile(chebi_path): @@ -223,7 +226,7 @@ def _load_chebi(self, version: int) -> str: open(chebi_path, "wb").write(r.content) return chebi_path - def _extract_class_hierarchy(self, data_path: str) -> nx.DiGraph: + def _extract_class_hierarchy(self, data_path: str) -> "nx.DiGraph": """ Extracts the class hierarchy from the ChEBI ontology. Constructs a directed graph (DiGraph) using NetworkX, where nodes are annotated with fields/terms from @@ -235,6 +238,9 @@ def _extract_class_hierarchy(self, data_path: str) -> nx.DiGraph: Returns: nx.DiGraph: The class hierarchy. """ + import fastobo + import networkx as nx + with open(data_path, encoding="utf-8") as chebi: chebi = "\n".join(line for line in chebi if not line.startswith("xref:")) @@ -262,7 +268,7 @@ def _extract_class_hierarchy(self, data_path: str) -> nx.DiGraph: print("Compute transitive closure") return nx.transitive_closure_dag(g) - def _graph_to_raw_dataset(self, g: nx.DiGraph) -> pd.DataFrame: + def _graph_to_raw_dataset(self, g: "nx.DiGraph") -> pd.DataFrame: """ Converts the graph to a raw dataset. Uses the graph created by `_extract_class_hierarchy` method to extract the @@ -274,6 +280,8 @@ def _graph_to_raw_dataset(self, g: nx.DiGraph) -> pd.DataFrame: Returns: pd.DataFrame: The raw dataset created from the graph. """ + import networkx as nx + smiles = nx.get_node_attributes(g, "smiles") names = nx.get_node_attributes(g, "name") @@ -572,7 +580,7 @@ def _name(self) -> str: """ return f"ChEBI{self.THRESHOLD}" - def select_classes(self, g: nx.DiGraph, *args, **kwargs) -> List: + def select_classes(self, g: "nx.DiGraph", *args, **kwargs) -> List: """ Selects classes from the ChEBI dataset based on the number of successors meeting a specified threshold. @@ -597,6 +605,8 @@ def select_classes(self, g: nx.DiGraph, *args, **kwargs) -> List: - The `THRESHOLD` attribute should be defined in the subclass of this class. - Nodes without a 'smiles' attribute are ignored in the successor count. """ + import networkx as nx + smiles = nx.get_node_attributes(g, "smiles") nodes = list( sorted( @@ -735,7 +745,7 @@ def processed_dir_main(self) -> str: "processed", ) - def _extract_class_hierarchy(self, chebi_path: str) -> nx.DiGraph: + def _extract_class_hierarchy(self, chebi_path: str) -> "nx.DiGraph": """ Extracts a subset of ChEBI based on subclasses of the top class ID. @@ -773,8 +783,10 @@ def _extract_class_hierarchy(self, chebi_path: str) -> nx.DiGraph: ) return g - def select_classes(self, g: nx.DiGraph, *args, **kwargs) -> List: + def select_classes(self, g: "nx.DiGraph", *args, **kwargs) -> List: """Only selects classes that meet the threshold AND are subclasses of the top class ID (including itself).""" + import networkx as nx + smiles = nx.get_node_attributes(g, "smiles") nodes = list( sorted( @@ -834,7 +846,7 @@ def chebi_to_int(s: str) -> int: return int(s[s.index(":") + 1 :]) -def term_callback(doc: fastobo.term.TermFrame) -> Union[Dict, bool]: +def term_callback(doc: "fastobo.term.TermFrame") -> Union[Dict, bool]: """ Extracts information from a ChEBI term document. This function takes a ChEBI term document as input and extracts relevant information such as the term ID, parents, @@ -851,6 +863,8 @@ def term_callback(doc: fastobo.term.TermFrame) -> Union[Dict, bool]: - "name": The name of the ChEBI term. - "smiles": The SMILES string associated with the ChEBI term, if available. """ + import fastobo + parts = set() parents = [] name = None diff --git a/chebai/preprocessing/reader.py b/chebai/preprocessing/reader.py index 4b1b0353..e308a2e1 100644 --- a/chebai/preprocessing/reader.py +++ b/chebai/preprocessing/reader.py @@ -5,11 +5,8 @@ from itertools import islice from typing import Any, Dict, List, Optional -import deepsmiles -import selfies as sf from pysmiles.read_smiles import _tokenize from rdkit import Chem -from transformers import RobertaTokenizerFast from chebai.preprocessing.collate import DefaultCollator, RaggedCollator @@ -220,6 +217,8 @@ class DeepChemDataReader(ChemDataReader): """ def __init__(self, *args, **kwargs): + import deepsmiles + super().__init__(*args, **kwargs) self.converter = deepsmiles.Converter(rings=True, branches=True) self.error_count = 0 @@ -294,6 +293,8 @@ def __init__( vsize: int = 4000, **kwargs, ): + from transformers import RobertaTokenizerFast + super().__init__(*args, **kwargs) self.tokenizer = RobertaTokenizerFast.from_pretrained( data_path, max_len=max_len @@ -327,6 +328,8 @@ def __init__( vsize: int = 4000, **kwargs, ): + import selfies as sf + super().__init__(*args, **kwargs) self.error_count = 0 sf.set_semantic_constraints("hypervalent") @@ -338,6 +341,8 @@ def name(cls) -> str: def _read_data(self, raw_data: str) -> Optional[List[int]]: """Read and tokenize raw data using SELFIES.""" + import selfies as sf + try: tokenized = sf.split_selfies(sf.encoder(raw_data.strip(), strict=True)) tokenized = [self._get_token_index(v) for v in tokenized] diff --git a/chebai/preprocessing/structures.py b/chebai/preprocessing/structures.py index 5cfe7966..4a69ea4f 100644 --- a/chebai/preprocessing/structures.py +++ b/chebai/preprocessing/structures.py @@ -1,8 +1,10 @@ -from typing import Any, Tuple, Union +from typing import TYPE_CHECKING, Any, Tuple, Union -import networkx as nx import torch +if TYPE_CHECKING: + import networkx as nx + class XYData(torch.utils.data.Dataset): """ @@ -119,7 +121,7 @@ class XYMolData(XYData): kwargs: Additional fields to store in the dataset. """ - def to_x(self, device: torch.device) -> Tuple[nx.Graph, ...]: + def to_x(self, device: torch.device) -> Tuple["nx.Graph", ...]: """ Moves the node attributes of the molecular graphs to the specified device. @@ -129,6 +131,8 @@ def to_x(self, device: torch.device) -> Tuple[nx.Graph, ...]: Returns: A tuple of molecular graphs with node attributes on the specified device. """ + import networkx as nx + l_ = [] for g in self.x: graph = g.copy() diff --git a/pyproject.toml b/pyproject.toml index 5723d78a..99728a13 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -41,6 +41,16 @@ dev = ["black", "isort", "pre-commit"] plot = ["matplotlib", "seaborn"] wandb = ["wandb"] +inference = [ + "numpy", + "pandas", + "torch", + "transformers", + "pysmiles==1.1.2", + "rdkit", + "lightning>=2.5", +] + [tool.setuptools] include-package-data = true license-files = ["LICEN[CS]E*"]