Skip to content

Commit b802132

Browse files
authored
Merge pull request #12 from ChEB-AI/fix/model-specific-package
Dynamic imports - for model-specific packages
2 parents 0c9e492 + 92aba63 commit b802132

File tree

8 files changed

+89
-43
lines changed

8 files changed

+89
-43
lines changed
Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,9 @@
11
from .base_predictor import BasePredictor
2-
from .chemlog_predictor import ChemlogPeptidesPredictor, ChemlogExtraPredictor
2+
from .c3p_predictor import C3PPredictor
3+
from .chebi_lookup import ChEBILookupPredictor
4+
from .chemlog_predictor import ChemlogExtraPredictor, ChemlogPeptidesPredictor
35
from .electra_predictor import ElectraPredictor
46
from .gnn_predictor import ResGatedPredictor
5-
from .chebi_lookup import ChEBILookupPredictor
67

78
__all__ = [
89
"BasePredictor",
@@ -11,4 +12,5 @@
1112
"ResGatedPredictor",
1213
"ChEBILookupPredictor",
1314
"ChemlogExtraPredictor",
15+
"C3PPredictor",
1416
]

chebifier/prediction_models/c3p_predictor.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,6 @@
11
from pathlib import Path
22
from typing import List, Optional
33

4-
from c3p import classifier as c3p_classifier
5-
64
from chebifier import modelwise_smiles_lru_cache
75
from chebifier.prediction_models import BasePredictor
86

@@ -26,6 +24,8 @@ def __init__(
2624

2725
@modelwise_smiles_lru_cache.batch_decorator
2826
def predict_smiles_list(self, smiles_list: list[str]) -> list:
27+
from c3p import classifier as c3p_classifier
28+
2929
result_list = c3p_classifier.classify(
3030
list(smiles_list),
3131
self.program_directory,
@@ -50,6 +50,8 @@ def explain_smiles(self, smiles):
5050
C3P provides natural language explanations for each prediction (positive or negative). Since there are more
5151
than 300 classes, only take the positive ones.
5252
"""
53+
from c3p import classifier as c3p_classifier
54+
5355
highlights = []
5456
result_list = c3p_classifier.classify(
5557
[smiles], self.program_directory, self.chemical_classes, strict=False

chebifier/prediction_models/chebi_lookup.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,6 @@
22
import os
33
from typing import Optional
44

5-
import networkx as nx
65
from rdkit import Chem
76

87
from chebifier import modelwise_smiles_lru_cache
@@ -18,6 +17,7 @@ def __init__(
1817
chebi_version: int = 241,
1918
**kwargs,
2019
):
20+
2121
super().__init__(model_name, **kwargs)
2222
self._description = (
2323
description
@@ -42,6 +42,8 @@ def get_smiles_lookup(self):
4242
return smiles_lookup
4343

4444
def build_smiles_lookup(self):
45+
import networkx as nx
46+
4547
smiles_lookup = dict()
4648
for chebi_id, smiles in nx.get_node_attributes(
4749
self.chebi_graph, "smiles"

chebifier/prediction_models/chemlog_predictor.py

Lines changed: 35 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -1,24 +1,9 @@
11
from typing import Optional
22

33
import tqdm
4-
from chemlog.alg_classification.charge_classifier import get_charge_category
5-
from chemlog.alg_classification.peptide_size_classifier import get_n_amino_acid_residues
6-
from chemlog.alg_classification.proteinogenics_classifier import (
7-
get_proteinogenic_amino_acids,
8-
)
9-
from chemlog.alg_classification.substructure_classifier import (
10-
is_diketopiperazine,
11-
is_emericellamide,
12-
)
13-
from chemlog.cli import CLASSIFIERS, _smiles_to_mol, strategy_call
14-
from chemlog_extra.alg_classification.by_element_classification import (
15-
OrganoXCompoundClassifier,
16-
XMolecularEntityClassifier,
17-
)
18-
19-
from chebifier import modelwise_smiles_lru_cache
204

215
from .base_predictor import BasePredictor
6+
from .. import modelwise_smiles_lru_cache
227

238
AA_DICT = {
249
"A": "L-alanine",
@@ -48,15 +33,16 @@
4833

4934

5035
class ChemlogExtraPredictor(BasePredictor):
51-
CHEMLOG_CLASSIFIER = None
5236

5337
def __init__(self, model_name: str, **kwargs):
5438
super().__init__(model_name, **kwargs)
5539
self.chebi_graph = kwargs.get("chebi_graph", None)
56-
self.classifier = self.CHEMLOG_CLASSIFIER()
40+
self.classifier = None
5741

5842
@modelwise_smiles_lru_cache.batch_decorator
5943
def predict_smiles_list(self, smiles_list: list[str]) -> list:
44+
from chemlog.cli import _smiles_to_mol
45+
6046
mol_list = [_smiles_to_mol(smiles) for smiles in smiles_list]
6147
res = self.classifier.classify(mol_list)
6248
if self.chebi_graph is not None:
@@ -73,15 +59,29 @@ def predict_smiles_list(self, smiles_list: list[str]) -> list:
7359

7460

7561
class ChemlogXMolecularEntityPredictor(ChemlogExtraPredictor):
76-
CHEMLOG_CLASSIFIER = XMolecularEntityClassifier
62+
def __init__(self, model_name: str, **kwargs):
63+
from chemlog_extra.alg_classification.by_element_classification import (
64+
XMolecularEntityClassifier,
65+
)
66+
67+
super().__init__(model_name, **kwargs)
68+
self.classifier = XMolecularEntityClassifier()
7769

7870

7971
class ChemlogOrganoXCompoundPredictor(ChemlogExtraPredictor):
80-
CHEMLOG_CLASSIFIER = OrganoXCompoundClassifier
72+
def __init__(self, model_name: str, **kwargs):
73+
from chemlog_extra.alg_classification.by_element_classification import (
74+
OrganoXCompoundClassifier,
75+
)
76+
77+
super().__init__(model_name, **kwargs)
78+
self.classifier = OrganoXCompoundClassifier()
8179

8280

8381
class ChemlogPeptidesPredictor(BasePredictor):
8482
def __init__(self, model_name: str, **kwargs):
83+
from chemlog.cli import CLASSIFIERS
84+
8585
super().__init__(model_name, **kwargs)
8686
self.strategy = "algo"
8787
self.chebi_graph = kwargs.get("chebi_graph", None)
@@ -97,6 +97,8 @@ def __init__(self, model_name: str, **kwargs):
9797
print(f"Initialised ChemLog model {self.model_name}")
9898

9999
def predict_smiles(self, smiles: str) -> Optional[dict]:
100+
from chemlog.cli import _smiles_to_mol, strategy_call
101+
100102
mol = _smiles_to_mol(smiles)
101103
if mol is None:
102104
return None
@@ -133,6 +135,19 @@ def predict_smiles_list(self, smiles_list: list[str]) -> list:
133135

134136
def get_chemlog_result_info(self, smiles):
135137
"""Get classification for single molecule with additional information."""
138+
from chemlog.alg_classification.charge_classifier import get_charge_category
139+
from chemlog.alg_classification.peptide_size_classifier import (
140+
get_n_amino_acid_residues,
141+
)
142+
from chemlog.alg_classification.proteinogenics_classifier import (
143+
get_proteinogenic_amino_acids,
144+
)
145+
from chemlog.alg_classification.substructure_classifier import (
146+
is_diketopiperazine,
147+
is_emericellamide,
148+
)
149+
from chemlog.cli import _smiles_to_mol
150+
136151
mol = _smiles_to_mol(smiles)
137152
if mol is None or not smiles:
138153
return {"error": "Failed to parse SMILES"}

chebifier/prediction_models/electra_predictor.py

Lines changed: 12 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,12 @@
1+
from typing import TYPE_CHECKING
2+
13
import numpy as np
2-
from chebai.models.electra import Electra
3-
from chebai.preprocessing.reader import EMBEDDING_OFFSET, ChemDataReader
44

55
from .nn_predictor import NNPredictor
66

7+
if TYPE_CHECKING:
8+
from chebai.models.electra import Electra
9+
710

811
def build_graph_from_attention(att, node_labels, token_labels, threshold=0.0):
912
n_nodes = len(node_labels)
@@ -37,10 +40,14 @@ def build_graph_from_attention(att, node_labels, token_labels, threshold=0.0):
3740

3841
class ElectraPredictor(NNPredictor):
3942
def __init__(self, model_name: str, ckpt_path: str, **kwargs):
43+
from chebai.preprocessing.reader import ChemDataReader
44+
4045
super().__init__(model_name, ckpt_path, reader_cls=ChemDataReader, **kwargs)
4146
print(f"Initialised Electra model {self.model_name} (device: {self.device})")
4247

43-
def init_model(self, ckpt_path: str, **kwargs) -> Electra:
48+
def init_model(self, ckpt_path: str, **kwargs) -> "Electra":
49+
from chebai.models.electra import Electra
50+
4451
model = Electra.load_from_checkpoint(
4552
ckpt_path,
4653
map_location=self.device,
@@ -53,6 +60,8 @@ def init_model(self, ckpt_path: str, **kwargs) -> Electra:
5360
return model
5461

5562
def explain_smiles(self, smiles) -> dict:
63+
from chebai.preprocessing.reader import EMBEDDING_OFFSET
64+
5665
reader = self.reader_cls()
5766
token_dict = reader.to_data(dict(features=smiles, labels=None))
5867
tokens = np.array(token_dict["features"]).astype(int).tolist()

chebifier/prediction_models/gnn_predictor.py

Lines changed: 22 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,16 @@
1-
import chebai_graph.preprocessing.properties as p
2-
import torch
3-
from chebai_graph.models.graph import ResGatedGraphConvNetGraphPred
4-
from chebai_graph.preprocessing.property_encoder import IndexEncoder, OneHotEncoder
5-
from chebai_graph.preprocessing.reader import GraphPropertyReader
6-
from torch_geometric.data.data import Data as GeomData
1+
from typing import TYPE_CHECKING
72

83
from .nn_predictor import NNPredictor
94

5+
if TYPE_CHECKING:
6+
from chebai_graph.models.graph import ResGatedGraphConvNetGraphPred
7+
108

119
class ResGatedPredictor(NNPredictor):
1210
def __init__(self, model_name: str, ckpt_path: str, molecular_properties, **kwargs):
11+
from chebai_graph.preprocessing.properties import MolecularProperty
12+
from chebai_graph.preprocessing.reader import GraphPropertyReader
13+
1314
super().__init__(
1415
model_name, ckpt_path, reader_cls=GraphPropertyReader, **kwargs
1516
)
@@ -23,7 +24,7 @@ def __init__(self, model_name: str, ckpt_path: str, molecular_properties, **kwar
2324
properties = []
2425
self.molecular_properties = properties
2526
assert isinstance(self.molecular_properties, list) and all(
26-
isinstance(prop, p.MolecularProperty) for prop in self.molecular_properties
27+
isinstance(prop, MolecularProperty) for prop in self.molecular_properties
2728
)
2829
print(f"Initialised GNN model {self.model_name} (device: {self.device})")
2930

@@ -32,7 +33,10 @@ def load_class(self, class_path: str):
3233
module = __import__(module_path, fromlist=[class_name])
3334
return getattr(module, class_name)
3435

35-
def init_model(self, ckpt_path: str, **kwargs) -> ResGatedGraphConvNetGraphPred:
36+
def init_model(self, ckpt_path: str, **kwargs) -> "ResGatedGraphConvNetGraphPred":
37+
import torch
38+
from chebai_graph.models.graph import ResGatedGraphConvNetGraphPred
39+
3640
model = ResGatedGraphConvNetGraphPred.load_from_checkpoint(
3741
ckpt_path,
3842
map_location=torch.device(self.device),
@@ -45,6 +49,14 @@ def init_model(self, ckpt_path: str, **kwargs) -> ResGatedGraphConvNetGraphPred:
4549
return model
4650

4751
def read_smiles(self, smiles):
52+
import torch
53+
from chebai_graph.preprocessing.properties import AtomProperty, BondProperty
54+
from chebai_graph.preprocessing.property_encoder import (
55+
IndexEncoder,
56+
OneHotEncoder,
57+
)
58+
from torch_geometric.data.data import Data as GeomData
59+
4860
reader = self.reader_cls()
4961
d = reader.to_data(dict(features=smiles, labels=None))
5062
geom_data = d["features"]
@@ -87,9 +99,9 @@ def read_smiles(self, smiles):
8799
encoded_values = encoded_values.unsqueeze(1)
88100
else:
89101
encoded_values = torch.zeros((0, prop.encoder.get_encoding_length()))
90-
if isinstance(prop, p.AtomProperty):
102+
if isinstance(prop, AtomProperty):
91103
x = torch.cat([x, encoded_values], dim=1)
92-
elif isinstance(prop, p.BondProperty):
104+
elif isinstance(prop, BondProperty):
93105
edge_attr = torch.cat([edge_attr, encoded_values], dim=1)
94106
else:
95107
molecule_attr = torch.cat([molecule_attr, encoded_values[0]], dim=1)

chebifier/prediction_models/nn_predictor.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,4 @@
11
import numpy as np
2-
import torch
32
import tqdm
43
from rdkit import Chem
54

@@ -17,6 +16,8 @@ def __init__(
1716
target_labels_path: str,
1817
**kwargs,
1918
):
19+
import torch
20+
2021
super().__init__(model_name, **kwargs)
2122
self.reader_cls = reader_cls
2223

@@ -56,6 +57,8 @@ def read_smiles(self, smiles):
5657
def predict_smiles_list(self, smiles_list: list[str]) -> list:
5758
"""Returns a list with the length of smiles_list, each element is either None (=failure) or a dictionary
5859
Of classes and predicted values."""
60+
import torch
61+
5962
token_dicts = []
6063
could_not_parse = []
6164
index_map = dict()

pyproject.toml

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -20,14 +20,15 @@ classifiers = [
2020
dependencies = [
2121
"click",
2222
"pyyaml",
23-
"torch",
2423
"tqdm",
2524
"rdkit",
26-
"chebai>=1.0.1",
27-
"chemlog>=1.0.4",
25+
# Package to install manually if required
26+
#"chebai>=1.0.1",
27+
#"chemlog>=1.0.4",
28+
2829
# pypi does not support git dependencies
2930
#"chemlog_extra @ git+https://github.com/ChEB-AI/chemlog-extra.git",
30-
"c3p"
31+
3132
# forked version of c3p is windows-compatible
3233
#"c3p @ git+https://github.com/sfluegel05/c3p.git"
3334
]

0 commit comments

Comments
 (0)