
Commit 2ecc159

Merge remote-tracking branch 'origin/master'
2 parents d774567 + 8fcf54b commit 2ecc159

File tree

chebai/models/electra.py
chebai/result/molplot.py
setup.py

3 files changed: +33, -17 lines changed


chebai/models/electra.py

Lines changed: 17 additions & 11 deletions
```diff
@@ -111,6 +111,12 @@ def __call__(self, target, input):
         return gen_loss + disc_loss


+def filter_dict(d, filter_key):
+    return {str(k)[len(filter_key):]: v for k, v in
+            d.items() if
+            str(k).startswith(filter_key)}
+
+
 class Electra(JCIBaseNet):
     NAME = "Electra"

@@ -151,26 +157,26 @@ def __init__(self, **kwargs):
         self.config = ElectraConfig(**kwargs["config"], output_attentions=True)
         self.word_dropout = nn.Dropout(kwargs["config"].get("word_dropout", 0))
         model_prefix = kwargs.get("load_prefix", None)
-        if pretrained_checkpoint:
-            with open(pretrained_checkpoint, "rb") as fin:
-                model_dict = torch.load(fin,map_location=self.device)
-                if model_prefix:
-                    state_dict = {str(k)[len(model_prefix):]:v for k,v in model_dict["state_dict"].items() if str(k).startswith(model_prefix)}
-                else:
-                    state_dict = model_dict["state_dict"]
-                self.electra = ElectraModel.from_pretrained(None, state_dict=state_dict, config=self.config)
-        else:
-            self.electra = ElectraModel(config=self.config)

         in_d = self.config.hidden_size
-
         self.output = nn.Sequential(
             nn.Dropout(self.config.hidden_dropout_prob),
             nn.Linear(in_d, in_d),
             nn.GELU(),
             nn.Dropout(self.config.hidden_dropout_prob),
             nn.Linear(in_d, self.config.num_labels),
         )
+        if pretrained_checkpoint:
+            with open(pretrained_checkpoint, "rb") as fin:
+                model_dict = torch.load(fin,map_location=self.device)
+                if model_prefix:
+                    state_dict = filter_dict(model_dict["state_dict"], model_prefix)
+                else:
+                    state_dict = model_dict["state_dict"]
+                self.electra = ElectraModel.from_pretrained(None, state_dict={k:v for (k,v) in state_dict.items() if k.startswith("electra.")}, config=self.config)
+                self.output.load_state_dict(filter_dict(state_dict,"output."))
+        else:
+            self.electra = ElectraModel(config=self.config)

     def _get_data_for_loss(self, model_output, labels):
         mask = model_output.get("target_mask")
```
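The electra.py change factors the prefix-stripping of checkpoint keys into a module-level `filter_dict` helper and moves checkpoint loading after `self.output` is built, so that both the `ElectraModel` weights (keys prefixed `electra.`) and the classification head (keys prefixed `output.`) can be restored. A minimal sketch of what `filter_dict` does, using hypothetical checkpoint keys:

```python
# filter_dict as added in this commit: keep only the entries whose keys start
# with filter_key and strip that prefix, so a sub-module's state dict can be
# pulled out of a full checkpoint.
def filter_dict(d, filter_key):
    return {str(k)[len(filter_key):]: v for k, v in
            d.items() if
            str(k).startswith(filter_key)}

# Hypothetical checkpoint keys, for illustration only:
state_dict = {"output.0.weight": "w0", "output.1.bias": "b1", "electra.embeddings": "e"}
print(filter_dict(state_dict, "output."))
# -> {'0.weight': 'w0', '1.bias': 'b1'}
```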

chebai/result/molplot.py

Lines changed: 13 additions & 5 deletions
```diff
@@ -20,8 +20,9 @@
 from chebai.result.base import ResultProcessor


-class AttentionMolPlot(abc.ABC):
-    def plot_attentions(self, smiles, attention, threshold, labels):
+class AttentionMolPlot:
+
+    def draw_attention_molecule(self, smiles, attention):
         pmol = self.read_smiles_with_index(smiles)
         rdmol = Chem.MolFromSmiles(smiles)
         if not rdmol:
@@ -34,26 +35,33 @@ def plot_attentions(self, smiles, attention, threshold, labels):
         }
         d = rdMolDraw2D.MolDraw2DCairo(500, 500)
         cmap = cm.ScalarMappable(cmap=cm.Greens)
-        attention_colors = cmap.to_rgba(attention, norm=False)
+
         aggr_attention_colors = cmap.to_rgba(
             np.max(attention[2:, :], axis=0), norm=False
         )
         cols = {
             token_to_node_map[token_index]: tuple(
                 aggr_attention_colors[token_index].tolist()
             )
-            for node, token_index in nx.get_node_attributes(pmol, "token_index").items()
+            for node, token_index in
+            nx.get_node_attributes(pmol, "token_index").items()
         }
         highlight_atoms = [
             token_to_node_map[token_index]
-            for node, token_index in nx.get_node_attributes(pmol, "token_index").items()
+            for node, token_index in
+            nx.get_node_attributes(pmol, "token_index").items()
         ]
         rdMolDraw2D.PrepareAndDrawMolecule(
             d, rdmol, highlightAtoms=highlight_atoms, highlightAtomColors=cols
         )

         d.FinishDrawing()
+        return d

+    def plot_attentions(self, smiles, attention, threshold, labels):
+        d = self.draw_attention_molecule(smiles, attention)
+        cmap = cm.ScalarMappable(cmap=cm.Greens)
+        attention_colors = cmap.to_rgba(attention, norm=False)
         num_tokens = sum(1 for _ in _tokenize(smiles))

         fig = plt.figure(figsize=(15, 15), facecolor="w")
```
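In molplot.py, the former `plot_attentions` is split: the RDKit drawing moves into a new `draw_attention_molecule` method that returns the finished drawer, and `plot_attentions` now calls it before building the matplotlib figure (the class also drops the `abc.ABC` base). A hedged usage sketch, with a stand-in attention array and a hypothetical output path:

```python
import numpy as np

# Hedged sketch (not part of the commit). AttentionMolPlot relies on
# read_smiles_with_index() defined elsewhere in the class, and real attention
# matrices come from the Electra model; the random array below is a stand-in.
plotter = AttentionMolPlot()

smiles = "CC(=O)O"                     # acetic acid, as an example input
attention = np.random.rand(8, 16)      # hypothetical (layers x tokens) attention scores

# draw_attention_molecule now returns the finished RDKit drawer, so the rendered
# image can be reused outside plot_attentions as well:
d = plotter.draw_attention_molecule(smiles, attention)
with open("attention_mol.png", "wb") as fout:
    fout.write(d.GetDrawingText())     # MolDraw2DCairo returns the PNG bytes here
```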

setup.py

Lines changed: 3 additions & 1 deletion
```diff
@@ -1,7 +1,7 @@
 from setuptools import setup

 setup(
-    name="ChEBI-learn",
+    name="chebai",
     version="0.0.0",
     packages=["chebai", "chebai.models"],
     url="",
@@ -39,6 +39,8 @@
         "scikit-network",
         "svgutils",
         "matplotlib",
+        "rdkit",
+        "selfies"
     ],
     extras_require={"dev": ["black", "isort", "pre-commit"]},
 )
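```

setup.py now registers the project under the name `chebai` and declares `rdkit` and `selfies` as runtime dependencies; RDKit is what chebai/result/molplot.py uses for drawing, and selfies presumably backs a SELFIES-based molecule encoding elsewhere in the package. A hedged sanity check, not part of this commit:

```python
# After installing the renamed "chebai" package, the two newly declared
# dependencies should resolve.
import rdkit
import selfies

print(rdkit.__version__)           # RDKit exposes its version string here
print(selfies.encoder("CC(=O)O"))  # SMILES -> SELFIES round-trip, e.g. acetic acid
```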
