Skip to content

Commit a5f04ca

Browse files
committed
Merge branch 'dev' into feature/augment_smiles
2 parents 388603f + 8e51a61 commit a5f04ca

File tree

9 files changed

+95
-46
lines changed

9 files changed

+95
-46
lines changed

.github/workflows/test.yml

Lines changed: 1 addition & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -24,8 +24,7 @@ jobs:
2424
python -m pip install --upgrade pip
2525
python -m pip install --upgrade pip setuptools wheel
2626
python -m pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cpu
27-
# wandb included to use custom trainer for cli test which needs wandb logger
28-
python -m pip install -e .[wandb]
27+
python -m pip install -e .[dev]
2928
3029
- name: Display Python & Installed Packages
3130
run: |

chebai/loggers/custom.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,6 @@
22
from datetime import datetime
33
from typing import List, Literal, Optional, Union
44

5-
import wandb
65
from lightning.fabric.utilities.types import _PATH
76
from lightning.pytorch.callbacks import ModelCheckpoint
87
from lightning.pytorch.loggers import WandbLogger
@@ -105,6 +104,8 @@ def set_fold(self, fold: int) -> None:
105104
Args:
106105
fold (int): Cross-validation fold number.
107106
"""
107+
import wandb
108+
108109
if fold != self._fold:
109110
self._fold = fold
110111
# Start new experiment

chebai/loss/bce_weighted.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,6 @@
55

66
from chebai.preprocessing.datasets.base import XYBaseDataModule
77
from chebai.preprocessing.datasets.chebi import _ChEBIDataExtractor
8-
from chebai.preprocessing.datasets.pubchem import LabeledUnlabeledMixed
98

109

1110
class BCEWeighted(torch.nn.BCEWithLogitsLoss):
@@ -27,6 +26,8 @@ def __init__(
2726
data_extractor: Optional[XYBaseDataModule] = None,
2827
**kwargs,
2928
):
29+
from chebai.preprocessing.datasets.pubchem import LabeledUnlabeledMixed
30+
3031
self.beta = beta
3132
if isinstance(data_extractor, LabeledUnlabeledMixed):
3233
data_extractor = data_extractor.labeled

chebai/loss/semantic.py

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2,14 +2,16 @@
22
import math
33
import os
44
import pickle
5-
from typing import List, Literal, Union
5+
from typing import TYPE_CHECKING, List, Literal, Union
66

77
import torch
88

99
from chebai.loss.bce_weighted import BCEWeighted
1010
from chebai.preprocessing.datasets.base import XYBaseDataModule
1111
from chebai.preprocessing.datasets.chebi import ChEBIOver100, _ChEBIDataExtractor
12-
from chebai.preprocessing.datasets.pubchem import LabeledUnlabeledMixed
12+
13+
if TYPE_CHECKING:
14+
from chebai.preprocessing.datasets.pubchem import LabeledUnlabeledMixed
1315

1416

1517
class ImplicationLoss(torch.nn.Module):
@@ -68,6 +70,8 @@ def __init__(
6870
multiply_with_base_loss: bool = True,
6971
no_grads: bool = False,
7072
):
73+
from chebai.preprocessing.datasets.pubchem import LabeledUnlabeledMixed
74+
7175
super().__init__()
7276
# automatically choose labeled subset for implication filter in case of mixed dataset
7377
if isinstance(data_extractor, LabeledUnlabeledMixed):
@@ -338,7 +342,7 @@ class DisjointLoss(ImplicationLoss):
338342
def __init__(
339343
self,
340344
path_to_disjointness: str,
341-
data_extractor: Union[_ChEBIDataExtractor, LabeledUnlabeledMixed],
345+
data_extractor: Union[_ChEBIDataExtractor, "LabeledUnlabeledMixed"],
342346
base_loss: torch.nn.Module = None,
343347
disjoint_loss_weight: float = 100,
344348
**kwargs,

chebai/preprocessing/datasets/base.py

Lines changed: 16 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -2,24 +2,21 @@
22
import random
33
from abc import ABC, abstractmethod
44
from pathlib import Path
5-
from typing import Any, Dict, Generator, List, Optional, Tuple, Union
5+
from typing import TYPE_CHECKING, Any, Dict, Generator, List, Optional, Tuple, Union
66

77
import lightning as pl
8-
import networkx as nx
98
import pandas as pd
109
import torch
1110
import tqdm
12-
from iterstrat.ml_stratifiers import (
13-
MultilabelStratifiedKFold,
14-
MultilabelStratifiedShuffleSplit,
15-
)
1611
from lightning.pytorch.core.datamodule import LightningDataModule
1712
from lightning_utilities.core.rank_zero import rank_zero_info
18-
from sklearn.model_selection import StratifiedShuffleSplit
1913
from torch.utils.data import DataLoader
2014

2115
from chebai.preprocessing import reader as dr
2216

17+
if TYPE_CHECKING:
18+
import networkx as nx
19+
2320

2421
class XYBaseDataModule(LightningDataModule):
2522
"""
@@ -830,7 +827,7 @@ def _download_required_data(self) -> str:
830827
pass
831828

832829
@abstractmethod
833-
def _extract_class_hierarchy(self, data_path: str) -> nx.DiGraph:
830+
def _extract_class_hierarchy(self, data_path: str) -> "nx.DiGraph":
834831
"""
835832
Extracts the class hierarchy from the data.
836833
Constructs a directed graph (DiGraph) using NetworkX, where nodes are annotated with fields/terms from
@@ -845,7 +842,7 @@ def _extract_class_hierarchy(self, data_path: str) -> nx.DiGraph:
845842
pass
846843

847844
@abstractmethod
848-
def _graph_to_raw_dataset(self, graph: nx.DiGraph) -> pd.DataFrame:
845+
def _graph_to_raw_dataset(self, graph: "nx.DiGraph") -> pd.DataFrame:
849846
"""
850847
Converts the graph to a raw dataset.
851848
Uses the graph created by `_extract_class_hierarchy` method to extract the
@@ -860,7 +857,7 @@ def _graph_to_raw_dataset(self, graph: nx.DiGraph) -> pd.DataFrame:
860857
pass
861858

862859
@abstractmethod
863-
def select_classes(self, g: nx.DiGraph, *args, **kwargs) -> list:
860+
def select_classes(self, g: "nx.DiGraph", *args, **kwargs) -> List:
864861
"""
865862
Selects classes from the dataset based on specified criteria.
866863
@@ -1052,6 +1049,9 @@ def get_test_split(
10521049
Raises:
10531050
ValueError: If the DataFrame does not contain a column named "labels".
10541051
"""
1052+
from iterstrat.ml_stratifiers import MultilabelStratifiedShuffleSplit
1053+
from sklearn.model_selection import StratifiedShuffleSplit
1054+
10551055
print("Get test data split")
10561056

10571057
labels_list = df["labels"].tolist()
@@ -1089,6 +1089,12 @@ def get_train_val_splits_given_test(
10891089
and validation DataFrames. The keys are the names of the train and validation sets, and the values
10901090
are the corresponding DataFrames.
10911091
"""
1092+
from iterstrat.ml_stratifiers import (
1093+
MultilabelStratifiedKFold,
1094+
MultilabelStratifiedShuffleSplit,
1095+
)
1096+
from sklearn.model_selection import StratifiedShuffleSplit
1097+
10921098
print("Split dataset into train / val with given test set")
10931099

10941100
test_ids = test_df["ident"].tolist()

chebai/preprocessing/datasets/chebi.py

Lines changed: 24 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -15,18 +15,19 @@
1515
from abc import ABC
1616
from collections import OrderedDict
1717
from itertools import cycle, permutations, product
18-
from typing import Any, Generator, Optional, Union
18+
from typing import TYPE_CHECKING, Any, Dict, Generator, List, Optional, Union
1919

20-
import fastobo
21-
import networkx as nx
2220
import pandas as pd
23-
import requests
2421
import torch
2522
from rdkit import Chem
2623

2724
from chebai.preprocessing import reader as dr
2825
from chebai.preprocessing.datasets.base import XYBaseDataModule, _DynamicDataset
2926

27+
if TYPE_CHECKING:
28+
import fastobo
29+
import networkx as nx
30+
3031
# exclude some entities from the dataset because they violate disjointness axioms
3132
CHEBI_BLACKLIST = [
3233
194026,
@@ -236,6 +237,8 @@ def _load_chebi(self, version: int) -> str:
236237
Returns:
237238
str: The file path of the loaded ChEBI ontology.
238239
"""
240+
import requests
241+
239242
chebi_name = self.raw_file_names_dict["chebi"]
240243
chebi_path = os.path.join(self.raw_dir, chebi_name)
241244
if not os.path.isfile(chebi_path):
@@ -247,7 +250,7 @@ def _load_chebi(self, version: int) -> str:
247250
open(chebi_path, "wb").write(r.content)
248251
return chebi_path
249252

250-
def _extract_class_hierarchy(self, data_path: str) -> nx.DiGraph:
253+
def _extract_class_hierarchy(self, data_path: str) -> "nx.DiGraph":
251254
"""
252255
Extracts the class hierarchy from the ChEBI ontology.
253256
Constructs a directed graph (DiGraph) using NetworkX, where nodes are annotated with fields/terms from
@@ -259,6 +262,9 @@ def _extract_class_hierarchy(self, data_path: str) -> nx.DiGraph:
259262
Returns:
260263
nx.DiGraph: The class hierarchy.
261264
"""
265+
import fastobo
266+
import networkx as nx
267+
262268
with open(data_path, encoding="utf-8") as chebi:
263269
chebi = "\n".join(line for line in chebi if not line.startswith("xref:"))
264270

@@ -286,7 +292,7 @@ def _extract_class_hierarchy(self, data_path: str) -> nx.DiGraph:
286292
print("Compute transitive closure")
287293
return nx.transitive_closure_dag(g)
288294

289-
def _graph_to_raw_dataset(self, g: nx.DiGraph) -> pd.DataFrame:
295+
def _graph_to_raw_dataset(self, g: "nx.DiGraph") -> pd.DataFrame:
290296
"""
291297
Converts the graph to a raw dataset.
292298
Uses the graph created by `_extract_class_hierarchy` method to extract the
@@ -298,6 +304,8 @@ def _graph_to_raw_dataset(self, g: nx.DiGraph) -> pd.DataFrame:
298304
Returns:
299305
pd.DataFrame: The raw dataset created from the graph.
300306
"""
307+
import networkx as nx
308+
301309
smiles = nx.get_node_attributes(g, "smiles")
302310
names = nx.get_node_attributes(g, "name")
303311

@@ -696,7 +704,7 @@ def _name(self) -> str:
696704
"""
697705
return f"ChEBI{self.THRESHOLD}"
698706

699-
def select_classes(self, g: nx.DiGraph, *args, **kwargs) -> list:
707+
def select_classes(self, g: "nx.DiGraph", *args, **kwargs) -> List:
700708
"""
701709
Selects classes from the ChEBI dataset based on the number of successors meeting a specified threshold.
702710
@@ -721,6 +729,8 @@ def select_classes(self, g: nx.DiGraph, *args, **kwargs) -> list:
721729
- The `THRESHOLD` attribute should be defined in the subclass of this class.
722730
- Nodes without a 'smiles' attribute are ignored in the successor count.
723731
"""
732+
import networkx as nx
733+
724734
smiles = nx.get_node_attributes(g, "smiles")
725735
nodes = list(
726736
sorted(
@@ -859,7 +869,7 @@ def processed_dir_main(self) -> str:
859869
"processed",
860870
)
861871

862-
def _extract_class_hierarchy(self, chebi_path: str) -> nx.DiGraph:
872+
def _extract_class_hierarchy(self, chebi_path: str) -> "nx.DiGraph":
863873
"""
864874
Extracts a subset of ChEBI based on subclasses of the top class ID.
865875
@@ -897,8 +907,10 @@ def _extract_class_hierarchy(self, chebi_path: str) -> nx.DiGraph:
897907
)
898908
return g
899909

900-
def select_classes(self, g: nx.DiGraph, *args, **kwargs) -> list:
910+
def select_classes(self, g: "nx.DiGraph", *args, **kwargs) -> List:
901911
"""Only selects classes that meet the threshold AND are subclasses of the top class ID (including itself)."""
912+
import networkx as nx
913+
902914
smiles = nx.get_node_attributes(g, "smiles")
903915
nodes = list(
904916
sorted(
@@ -958,7 +970,7 @@ def chebi_to_int(s: str) -> int:
958970
return int(s[s.index(":") + 1 :])
959971

960972

961-
def term_callback(doc: fastobo.term.TermFrame) -> Union[dict, bool]:
973+
def term_callback(doc: "fastobo.term.TermFrame") -> Union[Dict, bool]:
962974
"""
963975
Extracts information from a ChEBI term document.
964976
This function takes a ChEBI term document as input and extracts relevant information such as the term ID, parents,
@@ -975,6 +987,8 @@ def term_callback(doc: fastobo.term.TermFrame) -> Union[dict, bool]:
975987
- "name": The name of the ChEBI term.
976988
- "smiles": The SMILES string associated with the ChEBI term, if available.
977989
"""
990+
import fastobo
991+
978992
parts = set()
979993
parents = []
980994
name = None

chebai/preprocessing/reader.py

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -5,11 +5,8 @@
55
from itertools import islice
66
from typing import Any, Dict, List, Optional
77

8-
import deepsmiles
9-
import selfies as sf
108
from pysmiles.read_smiles import _tokenize
119
from rdkit import Chem
12-
from transformers import RobertaTokenizerFast
1310

1411
from chebai.preprocessing.collate import DefaultCollator, RaggedCollator
1512

@@ -220,6 +217,8 @@ class DeepChemDataReader(ChemDataReader):
220217
"""
221218

222219
def __init__(self, *args, **kwargs):
220+
import deepsmiles
221+
223222
super().__init__(*args, **kwargs)
224223
self.converter = deepsmiles.Converter(rings=True, branches=True)
225224
self.error_count = 0
@@ -294,6 +293,8 @@ def __init__(
294293
vsize: int = 4000,
295294
**kwargs,
296295
):
296+
from transformers import RobertaTokenizerFast
297+
297298
super().__init__(*args, **kwargs)
298299
self.tokenizer = RobertaTokenizerFast.from_pretrained(
299300
data_path, max_len=max_len
@@ -327,6 +328,8 @@ def __init__(
327328
vsize: int = 4000,
328329
**kwargs,
329330
):
331+
import selfies as sf
332+
330333
super().__init__(*args, **kwargs)
331334
self.error_count = 0
332335
sf.set_semantic_constraints("hypervalent")
@@ -338,6 +341,8 @@ def name(cls) -> str:
338341

339342
def _read_data(self, raw_data: str) -> Optional[List[int]]:
340343
"""Read and tokenize raw data using SELFIES."""
344+
import selfies as sf
345+
341346
try:
342347
tokenized = sf.split_selfies(sf.encoder(raw_data.strip(), strict=True))
343348
tokenized = [self._get_token_index(v) for v in tokenized]

chebai/preprocessing/structures.py

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,10 @@
1-
from typing import Any, Tuple, Union
1+
from typing import TYPE_CHECKING, Any, Tuple, Union
22

3-
import networkx as nx
43
import torch
54

5+
if TYPE_CHECKING:
6+
import networkx as nx
7+
68

79
class XYData(torch.utils.data.Dataset):
810
"""
@@ -119,7 +121,7 @@ class XYMolData(XYData):
119121
kwargs: Additional fields to store in the dataset.
120122
"""
121123

122-
def to_x(self, device: torch.device) -> Tuple[nx.Graph, ...]:
124+
def to_x(self, device: torch.device) -> Tuple["nx.Graph", ...]:
123125
"""
124126
Moves the node attributes of the molecular graphs to the specified device.
125127
@@ -129,6 +131,8 @@ def to_x(self, device: torch.device) -> Tuple[nx.Graph, ...]:
129131
Returns:
130132
A tuple of molecular graphs with node attributes on the specified device.
131133
"""
134+
import networkx as nx
135+
132136
l_ = []
133137
for g in self.x:
134138
graph = g.copy()

0 commit comments

Comments
 (0)