
Commit 83a92c9

merge dev into parthood-loss
1 parent da2f9f7 commit 83a92c9


43 files changed (+3842 / -56 lines)

chebai/loss/bce_weighted.py

Lines changed: 12 additions & 15 deletions
@@ -50,32 +50,29 @@ def set_pos_weight(self, input: torch.Tensor) -> None:
             and self.data_extractor is not None
             and all(
                 os.path.exists(
-                    os.path.join(self.data_extractor.processed_dir_main, file_name)
+                    os.path.join(self.data_extractor.processed_dir, file_name)
                 )
-                for file_name in self.data_extractor.processed_main_file_names
+                for file_name in self.data_extractor.processed_file_names
             )
             and self.pos_weight is None
         ):
             print(
                 f"Computing loss-weights based on v{self.data_extractor.chebi_version} dataset (beta={self.beta})"
             )
-            complete_data = pd.concat(
+            complete_labels = torch.concat(
                 [
-                    pd.read_pickle(
-                        open(
-                            os.path.join(
-                                self.data_extractor.processed_dir_main,
-                                file_name,
-                            ),
-                            "rb",
-                        )
+                    torch.stack(
+                        [
+                            torch.Tensor(row["labels"])
+                            for row in self.data_extractor.load_processed_data(
+                                filename=file_name
+                            )
+                        ]
                     )
-                    for file_name in self.data_extractor.processed_main_file_names
+                    for file_name in self.data_extractor.processed_file_names
                 ]
             )
-            value_counts = []
-            for c in complete_data.columns[3:]:
-                value_counts.append(len([v for v in complete_data[c] if v]))
+            value_counts = complete_labels.sum(dim=0)
             weights = [
                 (1 - self.beta) / (1 - pow(self.beta, value)) for value in value_counts
            ]
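
A minimal sketch (not part of the commit) of what the new weighting logic computes, using a toy multi-label tensor and an assumed beta value: each row is one sample's binary label vector, the per-class positive counts come from a single sum over the batch dimension (replacing the old per-column pandas loop), and the weights follow the class-balanced "effective number of samples" formula (1 - beta) / (1 - beta^n).

import torch

beta = 0.99  # assumed illustrative value; in the class it comes from self.beta

# Toy stand-in for complete_labels: 4 samples, 3 classes
complete_labels = torch.tensor(
    [[1, 0, 1], [1, 1, 0], [0, 0, 1], [1, 0, 0]], dtype=torch.float
)

# Per-class positive counts, as in the new code
value_counts = complete_labels.sum(dim=0)  # tensor([3., 1., 2.])

# Class-balanced weights, matching the unchanged weights expression
weights = [(1 - beta) / (1 - pow(beta, value)) for value in value_counts]
print([w.item() for w in weights])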

chebai/models/electra.py

Lines changed: 7 additions & 5 deletions
@@ -332,7 +332,7 @@ def forward(self, data: Dict[str, Tensor], **kwargs: Any) -> Dict[str, Any]:
         except RuntimeError as e:
             print(f"RuntimeError at forward: {e}")
             print(f'data[features]: {data["features"]}')
-            raise Exception
+            raise e
         inp = self.word_dropout(inp)
         electra = self.electra(inputs_embeds=inp, **kwargs)
         d = electra.last_hidden_state[:, 0, :]
@@ -344,14 +344,14 @@ def forward(self, data: Dict[str, Tensor], **kwargs: Any) -> Dict[str, Any]:

 class ElectraAugmented(Electra):
     """Electra model that takes global properties (i.e., which substructures are part of a molecule) into account.
-    The smiles embedding of size [batch_size, smiles_len, config.embedding_size - """
+    The smiles embedding of size [batch_size, smiles_len, config.embedding_size -"""

     NAME = "ElectraAugmented"

     def __init__(
         self,
-        add_embedding_size = 16,
-        n_global_properties = 2841,
+        add_embedding_size=16,
+        n_global_properties=2841,
         **kwargs: Any,
     ):

@@ -361,7 +361,9 @@ def __init__(
         smiles_config.embedding_size = self.config.embedding_size - add_embedding_size
         self.smiles_embedding = ElectraEmbeddings(config=smiles_config)

-        self.add_embedding = nn.Linear(n_global_properties, add_embedding_size, device=self.device)
+        self.add_embedding = nn.Linear(
+            n_global_properties, add_embedding_size, device=self.device
+        )
         self.add_norm = nn.LayerNorm(add_embedding_size, device=self.device)

     def _process_batch(self, batch: XYData, batch_idx: int) -> Dict[str, Any]:
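
For orientation, a small sketch of how the reformatted global-property branch could be wired: the nn.Linear plus LayerNorm from the diff project the property vector to add_embedding_size dimensions, while the SMILES embedding is reduced to config.embedding_size - add_embedding_size so both fit into one embedding of the full size. The tensor names, the assumed embedding_size of 256, and the final concatenation step are illustrative assumptions, not taken from the commit.

import torch
from torch import nn

embedding_size = 256        # assumed stand-in for config.embedding_size
add_embedding_size = 16     # default from the diff
n_global_properties = 2841  # default from the diff

# Global-property branch as configured in ElectraAugmented.__init__
add_embedding = nn.Linear(n_global_properties, add_embedding_size)
add_norm = nn.LayerNorm(add_embedding_size)

batch_size, smiles_len = 8, 100
global_props = torch.rand(batch_size, n_global_properties)
# SMILES embedding already reduced to embedding_size - add_embedding_size
smiles_emb = torch.rand(batch_size, smiles_len, embedding_size - add_embedding_size)

props_emb = add_norm(add_embedding(global_props))                # [8, 16]
props_emb = props_emb.unsqueeze(1).expand(-1, smiles_len, -1)    # [8, 100, 16]
combined = torch.cat([smiles_emb, props_emb], dim=-1)            # [8, 100, 256]
print(combined.shape)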

chebai/models/ffn.py

Lines changed: 153 additions & 0 deletions
@@ -0,0 +1,153 @@ (new file)

from typing import Any, Dict, List, Optional, Tuple

import torch
from torch import Tensor, nn

from chebai.models import ChebaiBaseNet


class FFN(ChebaiBaseNet):
    # Reference: https://github.com/bio-ontology-research-group/deepgo2/blob/main/deepgo/models.py#L121-L139

    NAME = "FFN"

    def __init__(
        self,
        input_size: int,
        hidden_layers: List[int] = [
            1024,
        ],
        **kwargs
    ):
        super().__init__(**kwargs)

        layers = []
        current_layer_input_size = input_size
        for hidden_dim in hidden_layers:
            layers.append(MLPBlock(current_layer_input_size, hidden_dim))
            layers.append(Residual(MLPBlock(hidden_dim, hidden_dim)))
            current_layer_input_size = hidden_dim

        layers.append(torch.nn.Linear(current_layer_input_size, self.out_dim))
        layers.append(nn.Sigmoid())
        self.model = nn.Sequential(*layers)

    def _get_prediction_and_labels(self, data, labels, model_output):
        d = model_output["logits"]
        loss_kwargs = data.get("loss_kwargs", dict())
        if "non_null_labels" in loss_kwargs:
            n = loss_kwargs["non_null_labels"]
            d = d[n]
        return torch.sigmoid(d), labels.int() if labels is not None else None

    def _process_for_loss(
        self,
        model_output: Dict[str, Tensor],
        labels: Tensor,
        loss_kwargs: Dict[str, Any],
    ) -> Tuple[Tensor, Tensor, Dict[str, Any]]:
        """
        Process the model output for calculating the loss.

        Args:
            model_output (Dict[str, Tensor]): The output of the model.
            labels (Tensor): The target labels.
            loss_kwargs (Dict[str, Any]): Additional loss arguments.

        Returns:
            tuple: A tuple containing the processed model output, labels, and loss arguments.
        """
        kwargs_copy = dict(loss_kwargs)
        if labels is not None:
            labels = labels.float()
        return model_output["logits"], labels, kwargs_copy

    def forward(self, data, **kwargs):
        x = data["features"]
        return {"logits": self.model(x)}


class Residual(nn.Module):
    """
    A residual layer that adds the output of a function to its input.

    Args:
        fn (nn.Module): The function to be applied to the input.

    References:
        https://github.com/bio-ontology-research-group/deepgo2/blob/main/deepgo/base.py#L6-L35
    """

    def __init__(self, fn):
        """
        Initialize the Residual layer with a given function.

        Args:
            fn (nn.Module): The function to be applied to the input.
        """
        super().__init__()
        self.fn = fn

    def forward(self, x):
        """
        Forward pass of the Residual layer.

        Args:
            x: Input tensor.

        Returns:
            torch.Tensor: The input tensor added to the result of applying the function `fn` to it.
        """
        return x + self.fn(x)


class MLPBlock(nn.Module):
    """
    A basic Multi-Layer Perceptron (MLP) block with one fully connected layer.

    Args:
        in_features (int): The number of input features.
        out_features (int): The number of output features.
        bias (bool): Add bias to the linear layer.
        layer_norm (bool): Apply layer normalization.
        dropout (float): The dropout value.
        activation (nn.Module): The activation function applied after the fully connected layer.

    References:
        https://github.com/bio-ontology-research-group/deepgo2/blob/main/deepgo/base.py#L38-L73

    Example:
        ```python
        # Create an MLP block with ReLU activation
        mlp_block = MLPBlock(in_features=64, out_features=10, activation=nn.ReLU)

        # Apply the MLP block to an input tensor
        input_tensor = torch.randn(32, 64)
        output = mlp_block(input_tensor)
        ```
    """

    def __init__(
        self,
        in_features,
        out_features,
        bias=True,
        layer_norm=True,
        dropout=0.1,
        activation=nn.ReLU,
    ):
        super().__init__()
        self.linear = nn.Linear(in_features, out_features, bias)
        self.activation = activation()
        self.layer_norm: Optional[nn.LayerNorm] = (
            nn.LayerNorm(out_features) if layer_norm else None
        )
        self.dropout: Optional[nn.Dropout] = nn.Dropout(dropout) if dropout else None

    def forward(self, x):
        x = self.activation(self.linear(x))
        if self.layer_norm:
            x = self.layer_norm(x)
        if self.dropout:
            x = self.dropout(x)
        return x
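
A standalone usage sketch for the new building blocks defined above; instantiating the full FFN also requires the ChebaiBaseNet keyword arguments (which supply self.out_dim), and those are not shown in this diff. The sketch assembles one hidden stage the way FFN.__init__ does: an MLPBlock followed by a residual-wrapped MLPBlock of the same width.

import torch
from torch import nn

# One hidden stage of the FFN for input_size=64 and hidden_layers=[1024];
# MLPBlock and Residual are the classes from the listing above.
hidden_dim = 1024
stage = nn.Sequential(
    MLPBlock(64, hidden_dim),
    Residual(MLPBlock(hidden_dim, hidden_dim)),
)

x = torch.randn(32, 64)   # batch of 32 feature vectors
print(stage(x).shape)     # torch.Size([32, 1024])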

chebai/preprocessing/bin/protein_token/tokens.txt

Lines changed: 1 addition & 0 deletions
@@ -18,3 +18,4 @@ W
 E
 V
 H
+X
