@@ -141,6 +141,8 @@ def get_chebi_graph(data_module, label_names):
141
141
chebi_graph = data_module ._extract_class_hierarchy (
142
142
os .path .join (data_module .raw_dir , "chebi.obo" )
143
143
)
144
+ if label_names is None :
145
+ return chebi_graph
144
146
return chebi_graph .subgraph ([int (n ) for n in label_names ])
145
147
print (
146
148
f"Failed to retrieve ChEBI graph, { os .path .join (data_module .raw_dir , 'chebi.obo' )} not found"
@@ -196,39 +198,38 @@ class PredictionSmoother:
196
198
"""Removes implication and disjointness violations from predictions"""
197
199
198
200
def __init__(self, dataset, label_names=None, disjoint_files=None):
    """Build a smoother for the given dataset.

    :param dataset: data module providing the raw ChEBI ontology
        (passed through to ``get_chebi_graph``).
    :param label_names: optional iterable of ChEBI class identifiers; if
        None, labels must be supplied later via ``set_label_names``.
    :param disjoint_files: optional files describing disjoint class groups.
    """
    self.disjoint_groups = get_disjoint_groups(disjoint_files)
    # Keep the full hierarchy; restriction to the label subset is done
    # lazily in set_label_names so labels can be swapped later.
    self.chebi_graph = get_chebi_graph(dataset, None)
    self.set_label_names(label_names)
205
+ def set_label_names (self , label_names ):
206
+ if label_names is not None :
207
+ self .label_names = [int (label ) for label in label_names ]
208
+ chebi_subgraph = self .chebi_graph .subgraph (self .label_names )
209
+ self .label_successors = torch .zeros (
210
+ (len (self .label_names ), len (self .label_names )), dtype = torch .bool
211
+ )
212
+ for i , label in enumerate (self .label_names ):
213
+ self .label_successors [i , i ] = 1
214
+ for p in chebi_subgraph .successors (label ):
215
+ if p in self .label_names :
216
+ self .label_successors [i , self .label_names .index (p )] = 1
217
+ self .label_successors = self .label_successors .unsqueeze (0 )
218
+
206
219
def __call__ (self , preds ):
207
220
preds_sum_orig = torch .sum (preds )
208
- for i , label in enumerate (self .label_names ):
209
- succs = [
210
- self .label_names .index (str (p ))
211
- for p in self .chebi_graph .successors (int (label ))
212
- ] + [i ]
213
- if len (succs ) > 0 :
214
- if torch .max (preds [:, succs ], dim = 1 ).values > 0.5 and preds [:, i ] < 0.5 :
215
- print (
216
- f"Correcting prediction for { label } to max of subclasses { list (self .chebi_graph .successors (int (label )))} "
217
- )
218
- print (
219
- f"Original pred: { preds [:, i ]} , successors: { preds [:, succs ]} "
220
- )
221
- preds [:, i ] = torch .max (preds [:, succs ], dim = 1 ).values
221
+ # step 1: apply implications: for each class, set prediction to max of itself and all successors
222
+ preds = preds .unsqueeze (1 )
223
+ preds_masked_succ = torch .where (self .label_successors , preds , 0 )
224
+ preds = preds_masked_succ .max (dim = 2 ).values
222
225
if torch .sum (preds ) != preds_sum_orig :
223
226
print (f"Preds change (step 1): { torch .sum (preds ) - preds_sum_orig } " )
224
227
preds_sum_orig = torch .sum (preds )
225
228
# step 2: eliminate disjointness violations: for group of disjoint classes, set all except max to 0.49 (if it is not already lower)
226
229
preds_bounded = torch .min (preds , torch .ones_like (preds ) * 0.49 )
227
230
for disj_group in self .disjoint_groups :
228
231
disj_group = [
229
- self .label_names .index (str (g ))
230
- for g in disj_group
231
- if g in self .label_names
232
+ self .label_names .index (g ) for g in disj_group if g in self .label_names
232
233
]
233
234
if len (disj_group ) > 1 :
234
235
old_preds = preds [:, disj_group ]
@@ -245,26 +246,17 @@ def __call__(self, preds):
245
246
print (
246
247
f"disjointness group { [self .label_names [d ] for d in disj_group ]} changed { samples_changed } samples"
247
248
)
249
+ if torch .sum (preds ) != preds_sum_orig :
250
+ print (f"Preds change (step 2): { torch .sum (preds ) - preds_sum_orig } " )
248
251
preds_sum_orig = torch .sum (preds )
249
252
# step 3: disjointness violation removal may have caused new implication inconsistencies -> set each prediction to min of predecessors
250
- for i , label in enumerate (self .label_names ):
251
- predecessors = [i ] + [
252
- self .label_names .index (str (p ))
253
- for p in self .chebi_graph .predecessors (int (label ))
254
- ]
255
- lowest_predecessors = torch .min (preds [:, predecessors ], dim = 1 )
256
- preds [:, i ] = lowest_predecessors .values
257
- for idx_idx , idx in enumerate (lowest_predecessors .indices ):
258
- if idx > 0 :
259
- print (
260
- f"class { label } : changed prediction of sample { idx_idx } to value of class "
261
- f"{ self .label_names [predecessors [idx ]]} ({ preds [idx_idx , i ].item ():.2f} )"
262
- )
263
- if torch .sum (preds ) != preds_sum_orig :
264
- print (
265
- f"Preds change (step 3) for { label } : { torch .sum (preds ) - preds_sum_orig } "
266
- )
267
- preds_sum_orig = torch .sum (preds )
253
+ preds = preds .unsqueeze (1 )
254
+ preds_masked_predec = torch .where (
255
+ torch .transpose (self .label_successors , 1 , 2 ), preds , 1
256
+ )
257
+ preds = preds_masked_predec .min (dim = 2 ).values
258
+ if torch .sum (preds ) != preds_sum_orig :
259
+ print (f"Preds change (step 3): { torch .sum (preds ) - preds_sum_orig } " )
268
260
return preds
269
261
270
262
0 commit comments