Commit b0b3113

Merge branch 'dev' of https://github.com/ChEB-AI/python-chebai into dev
2 parents: 426f1b0 + 97839d9

20 files changed: +4156, -103 lines

chebai/callbacks/epoch_metrics.py

Lines changed: 2 additions & 1 deletion
@@ -62,7 +62,8 @@ def update(self, preds: torch.Tensor, labels: torch.Tensor) -> None:
             labels (torch.Tensor): Ground truth labels.
         """
         tps = torch.sum(
-            torch.logical_and(preds > self.threshold, labels.to(torch.bool)), dim=0
+            torch.logical_and(preds > self.threshold, labels.to(torch.bool)),
+            dim=0,
         )
         self.true_positives += tps
         self.positive_predictions += torch.sum(preds > self.threshold, dim=0)

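The patched `update` only reflows the `torch.sum` call, so behaviour is unchanged: predictions above the threshold are intersected with the boolean labels and summed per class. A minimal standalone sketch of that accumulation, assuming a threshold of 0.5 and a toy batch (not the repo's actual callback):

import torch

threshold = 0.5  # assumed value; the callback reads self.threshold
preds = torch.tensor([[0.9, 0.2], [0.7, 0.6]])   # toy predicted probabilities
labels = torch.tensor([[1, 0], [0, 1]])          # toy multilabel targets

# per-class true positives, as in the reformatted call above
tps = torch.sum(
    torch.logical_and(preds > threshold, labels.to(torch.bool)),
    dim=0,
)  # -> tensor([1, 1])

# per-class count of positive predictions
positive_predictions = torch.sum(preds > threshold, dim=0)  # -> tensor([2, 1])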
chebai/models/base.py

Lines changed: 12 additions & 1 deletion
@@ -4,6 +4,7 @@
 
 import torch
 from lightning.pytorch.core.module import LightningModule
+from lightning.pytorch.utilities.rank_zero import rank_zero_info
 
 from chebai.preprocessing.structures import XYData

@@ -107,7 +108,8 @@ def _get_prediction_and_labels(
         Returns:
             Tuple[torch.Tensor, torch.Tensor]: Predictions and labels.
         """
-        return output, labels
+        # cast labels to int
+        return output, labels.to(torch.int) if labels is not None else labels
 
     def _process_labels_in_batch(self, batch: XYData) -> torch.Tensor:
         """
@@ -159,6 +161,13 @@ def _process_for_loss(
         """
         return model_output, labels, loss_kwargs
 
+    def on_train_epoch_start(self) -> None:
+        # pass current epoch to datamodule if it has the attribute curr_epoch (for PubChemBatched dataset)
+        rank_zero_info(f"Starting epoch {self.current_epoch}")
+        if hasattr(self.trainer.datamodule, "curr_epoch"):
+            rank_zero_info(f"Setting datamodule.curr_epoch to {self.current_epoch}")
+            self.trainer.datamodule.curr_epoch = self.current_epoch
+
     def training_step(
         self, batch: XYData, batch_idx: int
     ) -> Dict[str, Union[torch.Tensor, Any]]:
@@ -310,6 +319,8 @@ def _execute(
             for metric_name, metric in metrics.items():
                 metric.update(pr, tar)
             self._log_metrics(prefix, metrics, len(batch))
+        if isinstance(d, dict) and "loss" not in d:
+            print(f"d has keys {d.keys()}, log={log}, criterion={self.criterion}")
         return d
 
     def _log_metrics(self, prefix: str, metrics: torch.nn.Module, batch_size: int):

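The new `on_train_epoch_start` hook only forwards `self.current_epoch` to the datamodule when that datamodule exposes a `curr_epoch` attribute. A hedged sketch of a datamodule that would opt into this (the class name and dataloader body are illustrative, not from the repo):

from lightning.pytorch import LightningDataModule


class EpochAwareDataModule(LightningDataModule):  # hypothetical example class
    def __init__(self):
        super().__init__()
        self.curr_epoch = 0  # attribute the hook checks for via hasattr()

    def train_dataloader(self):
        # a datamodule like PubChemBatched could select a different data shard
        # per epoch here, based on self.curr_epoch
        ...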
chebai/models/classic_ml.py

Lines changed: 97 additions & 0 deletions
@@ -0,0 +1,97 @@
+import os
+import pickle as pkl
+from typing import Any, Dict, List, Optional
+
+import numpy as np
+import torch
+import tqdm
+from sklearn.exceptions import NotFittedError
+from sklearn.linear_model import LogisticRegression as SklearnLogisticRegression
+
+from chebai.models.base import ChebaiBaseNet
+
+LR_MODEL_PATH = os.path.join("models", "LR")
+
+
+class LogisticRegression(ChebaiBaseNet):
+    """
+    Logistic Regression model using scikit-learn, wrapped to fit the ChebaiBaseNet interface.
+    """
+
+    def __init__(
+        self,
+        out_dim: int,
+        input_dim: int,
+        only_predict_classes: Optional[List] = None,
+        n_classes=1528,
+        **kwargs,
+    ):
+        super().__init__(out_dim=out_dim, input_dim=input_dim, **kwargs)
+        self.models = [
+            SklearnLogisticRegression(solver="liblinear") for _ in range(n_classes)
+        ]
+        # indices of classes (in the dataset used for training) where a model should be trained
+        self.only_predict_classes = only_predict_classes
+
+    def forward(self, x: Dict[str, Any], **kwargs) -> torch.Tensor:
+        print(
+            f"forward called with x[features].shape {x['features'].shape}, self.training {self.training}"
+        )
+        if self.training:
+            self.fit_sklearn(x["features"], x["labels"])
+        preds = []
+        for model in self.models:
+            try:
+                p = torch.from_numpy(model.predict(x["features"])).float()
+                p = p.to(x["features"].device)
+                preds.append(p)
+            except NotFittedError:
+                preds.append(
+                    torch.zeros((x["features"].shape[0]), device=(x["features"].device))
+                )
+            except AttributeError:
+                preds.append(
+                    torch.zeros((x["features"].shape[0]), device=(x["features"].device))
+                )
+        preds = torch.stack(preds, dim=1)
+        print(f"preds shape {preds.shape}")
+        return preds.squeeze(-1)
+
+    def fit_sklearn(self, X, y):
+        """
+        Fit the underlying sklearn model. X and y should be numpy arrays.
+        """
+        for i, model in tqdm.tqdm(enumerate(self.models), desc="Fitting models"):
+            import os
+
+            if os.path.exists(os.path.join(LR_MODEL_PATH, f"LR_model_{i}.pkl")):
+                print(f"Loading model {i} from file")
+                self.models[i] = pkl.load(
+                    open(os.path.join(LR_MODEL_PATH, f"LR_model_{i}.pkl"), "rb")
+                )
+            else:
+                if (
+                    self.only_predict_classes and i not in self.only_predict_classes
+                ):  # only try these classes
+                    continue
+                try:
+                    model.fit(X, y[:, i])
+                except ValueError:
+                    self.models[i] = PlaceholderModel()
+                # dump
+                pkl.dump(
+                    model, open(os.path.join(LR_MODEL_PATH, f"LR_model_{i}.pkl"), "wb")
+                )
+
+    def configure_optimizers(self, **kwargs):
+        pass
+
+
+class PlaceholderModel:
+    """Acts like a trained model, but isn't. Use this if training fails and you need a placeholder."""
+
+    def __init__(self, default_prediction=1):
+        self.default_prediction = default_prediction
+
+    def predict(self, preds):
+        return np.ones(preds.shape[0]) * self.default_prediction

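The wrapper trains one independent liblinear classifier per label column and falls back to `PlaceholderModel` when a column cannot be fitted. A small self-contained sketch of that per-class scheme using plain scikit-learn and NumPy (toy data; a constant prediction stands in for the actual `PlaceholderModel`):

import numpy as np
from sklearn.linear_model import LogisticRegression as SklearnLogisticRegression

X = np.random.rand(8, 16)                 # toy feature matrix
y = np.array([[1, 0, 1], [0, 1, 1], [1, 1, 1], [0, 0, 1],
              [1, 0, 1], [0, 1, 1], [1, 1, 1], [0, 0, 1]])  # column 2 has only one class

models = []
for i in range(y.shape[1]):
    m = SklearnLogisticRegression(solver="liblinear")
    try:
        m.fit(X, y[:, i])                 # one binary model per label column
    except ValueError:                    # single-class column: fit fails
        m = None                          # stands in for PlaceholderModel
    models.append(m)

preds = np.stack(
    [m.predict(X) if m is not None else np.ones(X.shape[0]) for m in models],
    axis=1,
)  # shape (8, 3); the degenerate column falls back to a constant prediction of 1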
chebai/models/electra.py

Lines changed: 4 additions & 3 deletions
@@ -227,6 +227,7 @@ def __init__(
         pretrained_checkpoint: Optional[str] = None,
         load_prefix: Optional[str] = None,
         model_type="classification",
+        freeze_electra: bool = False,
         **kwargs: Any,
     ):
         # Remove this property in order to prevent it from being stored as a
@@ -267,9 +268,9 @@ def __init__(
         else:
             self.electra = ElectraModel(config=self.config)
 
-        # freeze parameters
-        # for param in self.electra.parameters():
-        #     param.requires_grad = False
+        if freeze_electra:
+            for param in self.electra.parameters():
+                param.requires_grad = False
 
     def _process_for_loss(
         self,

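With the new `freeze_electra` flag, the previously commented-out freezing code becomes opt-in. A hedged sketch of the effect, using stand-in modules rather than the actual Electra model:

import torch.nn as nn

encoder = nn.Linear(8, 8)   # stand-in for self.electra
head = nn.Linear(8, 2)      # stand-in for the classification head

# what freeze_electra=True triggers: encoder weights stop receiving gradients
for param in encoder.parameters():
    param.requires_grad = False

trainable = [
    p
    for p in list(encoder.parameters()) + list(head.parameters())
    if p.requires_grad
]
# only the head's weight and bias remain trainable -> len(trainable) == 2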
chebai/models/ffn.py

Lines changed: 19 additions & 1 deletion
@@ -14,10 +14,14 @@ def __init__(
         hidden_layers: List[int] = [
             1024,
         ],
+        use_adam_optimizer: bool = False,
         **kwargs,
     ):
         super().__init__(**kwargs)
 
+        self.use_adam_optimizer: bool = bool(use_adam_optimizer)
+        print(f"Using Adam optimizer: {self.use_adam_optimizer}")
+
         layers = []
         current_layer_input_size = self.input_dim
         for hidden_dim in hidden_layers:
@@ -26,7 +30,6 @@ def __init__(
             current_layer_input_size = hidden_dim
 
         layers.append(torch.nn.Linear(current_layer_input_size, self.out_dim))
-        layers.append(nn.Sigmoid())
         self.model = nn.Sequential(*layers)
 
     def _get_prediction_and_labels(self, data, labels, model_output):
@@ -63,6 +66,21 @@ def forward(self, data, **kwargs):
         x = data["features"]
         return {"logits": self.model(x)}
 
+    def configure_optimizers(self, **kwargs) -> torch.optim.Optimizer:
+        """
+        Configures the optimizers.
+
+        Args:
+            **kwargs: Additional keyword arguments.
+
+        Returns:
+            torch.optim.Optimizer: The optimizer.
+        """
+        if self.use_adam_optimizer:
+            return torch.optim.Adam(self.parameters(), **self.optimizer_kwargs)
+
+        return torch.optim.Adamax(self.parameters(), **self.optimizer_kwargs)
+
 
 class Residual(nn.Module):
     """

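Two things change in ffn.py: the trailing `nn.Sigmoid()` is dropped, so the model now returns raw logits, and `configure_optimizers` picks Adam only when `use_adam_optimizer` is set, keeping Adamax as the default. A hedged sketch of that switch on a stand-in module (the `optimizer_kwargs` value is illustrative):

import torch

model = torch.nn.Linear(16, 4)        # stand-in for the FFN
optimizer_kwargs = {"lr": 1e-3}       # mirrors self.optimizer_kwargs
use_adam_optimizer = False            # the new flag's default

if use_adam_optimizer:
    optimizer = torch.optim.Adam(model.parameters(), **optimizer_kwargs)
else:
    optimizer = torch.optim.Adamax(model.parameters(), **optimizer_kwargs)  # default path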
chebai/models/lstm.py

Lines changed: 37 additions & 13 deletions
@@ -1,31 +1,55 @@
 import logging
 
 from torch import nn
-from torch.nn.utils.rnn import pack_padded_sequence
+from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence
 
 from chebai.models.base import ChebaiBaseNet
 
 logging.getLogger("pysmiles").setLevel(logging.CRITICAL)
 
 
 class ChemLSTM(ChebaiBaseNet):
-    def __init__(self, in_d, out_d, num_classes, **kwargs):
-        super().__init__(num_classes, **kwargs)
-        self.lstm = nn.LSTM(in_d, out_d, batch_first=True)
-        self.embedding = nn.Embedding(800, 100)
+    def __init__(
+        self,
+        out_d,
+        in_d,
+        num_classes,
+        criterion: nn.Module = None,
+        num_layers=6,
+        dropout=0.2,
+        **kwargs,
+    ):
+        super().__init__(
+            out_dim=out_d,
+            input_dim=in_d,
+            criterion=criterion,
+            num_classes=num_classes,
+            **kwargs,
+        )
+        self.lstm = nn.LSTM(
+            in_d,
+            out_d,
+            batch_first=True,
+            dropout=dropout,
+            bidirectional=True,
+            num_layers=num_layers,
+        )
+        self.embedding = nn.Embedding(1400, in_d)
         self.output = nn.Sequential(
-            nn.Linear(out_d, in_d),
+            nn.Linear(out_d * 2, out_d),
             nn.ReLU(),
             nn.Dropout(0.2),
-            nn.Linear(in_d, num_classes),
+            nn.Linear(out_d, num_classes),
         )
 
-    def forward(self, data):
-        x = data.x
-        x_lens = data.lens
+    def forward(self, data, *args, **kwargs):
+        x = data["features"]
+        x_lens = data["model_kwargs"]["lens"]
         x = self.embedding(x)
         x = pack_padded_sequence(x, x_lens, batch_first=True, enforce_sorted=False)
-        x = self.lstm(x)[1][0]
-        # = pad_packed_sequence(x, batch_first=True)[0]
+        x = self.lstm(x)[0]
+        x = pad_packed_sequence(x, batch_first=True)[0][
+            :, 0
+        ]  # reduce sequence dimension to first element
         x = self.output(x)
-        return x.squeeze(0)
+        return x

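The reworked forward pass packs the embedded sequences, runs a bidirectional multi-layer LSTM, unpads the output, and keeps only the first time step, which is why the output head now starts with `nn.Linear(out_d * 2, out_d)`. A hedged standalone sketch of the shape flow (toy dimensions and batch, not the repo's data pipeline):

import torch
from torch import nn
from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence

in_d, out_d, num_layers = 32, 64, 6
embedding = nn.Embedding(1400, in_d)
lstm = nn.LSTM(
    in_d, out_d, batch_first=True, dropout=0.2,
    bidirectional=True, num_layers=num_layers,
)

tokens = torch.randint(0, 1400, (4, 10))   # toy batch of token indices
lens = torch.tensor([10, 8, 5, 3])         # true (unpadded) sequence lengths

x = embedding(tokens)
x = pack_padded_sequence(x, lens, batch_first=True, enforce_sorted=False)
x = lstm(x)[0]                             # packed bidirectional outputs
x = pad_packed_sequence(x, batch_first=True)[0][:, 0]  # first time step only
# x.shape == (4, 2 * out_d) == (4, 128), matching nn.Linear(out_d * 2, out_d)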