Skip to content

Commit 9557f95

Browse files
authored
Merge pull request #104 from ChEB-AI/fix/ensemble-compatibility
Ensemble compatibility
2 parents f2d62fd + ff6b52a commit 9557f95

File tree

2 files changed

+33
-41
lines changed

2 files changed

+33
-41
lines changed

chebai/loss/semantic.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
import torch
88

99
from chebai.loss.bce_weighted import BCEWeighted
10-
from chebai.preprocessing.datasets import XYBaseDataModule
10+
from chebai.preprocessing.datasets.base import XYBaseDataModule
1111
from chebai.preprocessing.datasets.chebi import ChEBIOver100, _ChEBIDataExtractor
1212
from chebai.preprocessing.datasets.pubchem import LabeledUnlabeledMixed
1313

chebai/result/analyse_sem.py

Lines changed: 32 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -141,6 +141,8 @@ def get_chebi_graph(data_module, label_names):
141141
chebi_graph = data_module._extract_class_hierarchy(
142142
os.path.join(data_module.raw_dir, "chebi.obo")
143143
)
144+
if label_names is None:
145+
return chebi_graph
144146
return chebi_graph.subgraph([int(n) for n in label_names])
145147
print(
146148
f"Failed to retrieve ChEBI graph, {os.path.join(data_module.raw_dir, 'chebi.obo')} not found"
@@ -196,39 +198,38 @@ class PredictionSmoother:
196198
"""Removes implication and disjointness violations from predictions"""
197199

198200
def __init__(self, dataset, label_names=None, disjoint_files=None):
199-
if label_names:
200-
self.label_names = label_names
201-
else:
202-
self.label_names = get_label_names(dataset)
203-
self.chebi_graph = get_chebi_graph(dataset, self.label_names)
201+
self.chebi_graph = get_chebi_graph(dataset, None)
202+
self.set_label_names(label_names)
204203
self.disjoint_groups = get_disjoint_groups(disjoint_files)
205204

205+
def set_label_names(self, label_names):
206+
if label_names is not None:
207+
self.label_names = [int(label) for label in label_names]
208+
chebi_subgraph = self.chebi_graph.subgraph(self.label_names)
209+
self.label_successors = torch.zeros(
210+
(len(self.label_names), len(self.label_names)), dtype=torch.bool
211+
)
212+
for i, label in enumerate(self.label_names):
213+
self.label_successors[i, i] = 1
214+
for p in chebi_subgraph.successors(label):
215+
if p in self.label_names:
216+
self.label_successors[i, self.label_names.index(p)] = 1
217+
self.label_successors = self.label_successors.unsqueeze(0)
218+
206219
def __call__(self, preds):
207220
preds_sum_orig = torch.sum(preds)
208-
for i, label in enumerate(self.label_names):
209-
succs = [
210-
self.label_names.index(str(p))
211-
for p in self.chebi_graph.successors(int(label))
212-
] + [i]
213-
if len(succs) > 0:
214-
if torch.max(preds[:, succs], dim=1).values > 0.5 and preds[:, i] < 0.5:
215-
print(
216-
f"Correcting prediction for {label} to max of subclasses {list(self.chebi_graph.successors(int(label)))}"
217-
)
218-
print(
219-
f"Original pred: {preds[:, i]}, successors: {preds[:, succs]}"
220-
)
221-
preds[:, i] = torch.max(preds[:, succs], dim=1).values
221+
# step 1: apply implications: for each class, set prediction to max of itself and all successors
222+
preds = preds.unsqueeze(1)
223+
preds_masked_succ = torch.where(self.label_successors, preds, 0)
224+
preds = preds_masked_succ.max(dim=2).values
222225
if torch.sum(preds) != preds_sum_orig:
223226
print(f"Preds change (step 1): {torch.sum(preds) - preds_sum_orig}")
224227
preds_sum_orig = torch.sum(preds)
225228
# step 2: eliminate disjointness violations: for group of disjoint classes, set all except max to 0.49 (if it is not already lower)
226229
preds_bounded = torch.min(preds, torch.ones_like(preds) * 0.49)
227230
for disj_group in self.disjoint_groups:
228231
disj_group = [
229-
self.label_names.index(str(g))
230-
for g in disj_group
231-
if g in self.label_names
232+
self.label_names.index(g) for g in disj_group if g in self.label_names
232233
]
233234
if len(disj_group) > 1:
234235
old_preds = preds[:, disj_group]
@@ -245,26 +246,17 @@ def __call__(self, preds):
245246
print(
246247
f"disjointness group {[self.label_names[d] for d in disj_group]} changed {samples_changed} samples"
247248
)
249+
if torch.sum(preds) != preds_sum_orig:
250+
print(f"Preds change (step 2): {torch.sum(preds) - preds_sum_orig}")
248251
preds_sum_orig = torch.sum(preds)
249252
# step 3: disjointness violation removal may have caused new implication inconsistencies -> set each prediction to min of predecessors
250-
for i, label in enumerate(self.label_names):
251-
predecessors = [i] + [
252-
self.label_names.index(str(p))
253-
for p in self.chebi_graph.predecessors(int(label))
254-
]
255-
lowest_predecessors = torch.min(preds[:, predecessors], dim=1)
256-
preds[:, i] = lowest_predecessors.values
257-
for idx_idx, idx in enumerate(lowest_predecessors.indices):
258-
if idx > 0:
259-
print(
260-
f"class {label}: changed prediction of sample {idx_idx} to value of class "
261-
f"{self.label_names[predecessors[idx]]} ({preds[idx_idx, i].item():.2f})"
262-
)
263-
if torch.sum(preds) != preds_sum_orig:
264-
print(
265-
f"Preds change (step 3) for {label}: {torch.sum(preds) - preds_sum_orig}"
266-
)
267-
preds_sum_orig = torch.sum(preds)
253+
preds = preds.unsqueeze(1)
254+
preds_masked_predec = torch.where(
255+
torch.transpose(self.label_successors, 1, 2), preds, 1
256+
)
257+
preds = preds_masked_predec.min(dim=2).values
258+
if torch.sum(preds) != preds_sum_orig:
259+
print(f"Preds change (step 3): {torch.sum(preds) - preds_sum_orig}")
268260
return preds
269261

270262

0 commit comments

Comments
 (0)