@@ -229,7 +229,7 @@ def forward(self, input: torch.Tensor, target: torch.Tensor, **kwargs) -> tuple:
229
229
return total_loss .mean (), loss_components
230
230
231
231
def _calculate_implication_loss (
232
- self , l : torch .Tensor , r : torch .Tensor
232
+ self , l_ : torch .Tensor , r : torch .Tensor
233
233
) -> torch .Tensor :
234
234
"""
235
235
Calculate implication loss based on T-norm and other parameters.
@@ -241,17 +241,17 @@ def _calculate_implication_loss(
241
241
Returns:
242
242
torch.Tensor: Calculated implication loss.
243
243
"""
244
- assert not l .isnan ().any (), (
245
- f"l contains NaN values - l.shape: { l .shape } , l.isnan().sum(): { l .isnan ().sum ()} , "
246
- f"l: { l } "
244
+ assert not l_ .isnan ().any (), (
245
+ f"l contains NaN values - l.shape: { l_ .shape } , l.isnan().sum(): { l_ .isnan ().sum ()} , "
246
+ f"l: { l_ } "
247
247
)
248
248
assert not r .isnan ().any (), (
249
249
f"r contains NaN values - r.shape: { r .shape } , r.isnan().sum(): { r .isnan ().sum ()} , "
250
250
f"r: { r } "
251
251
)
252
252
if self .pos_scalar != 1 :
253
- l = (
254
- torch .pow (l + self .eps , 1 / self .pos_scalar )
253
+ l_ = (
254
+ torch .pow (l_ + self .eps , 1 / self .pos_scalar )
255
255
- math .pow (self .eps , 1 / self .pos_scalar )
256
256
) / (
257
257
math .pow (1 + self .eps , 1 / self .pos_scalar )
@@ -269,21 +269,21 @@ def _calculate_implication_loss(
269
269
# for each implication I, calculate 1 - I(l, 1-one_min_r)
270
270
# for S-implications, this is equivalent to the t-norm
271
271
if self .fuzzy_implication in ["reichenbach" , "rc" ]:
272
- individual_loss = l * one_min_r
272
+ individual_loss = l_ * one_min_r
273
273
# xu19 (from Xu et al., 2019: Semantic loss) is not a fuzzy implication, but behaves similar to the Reichenbach
274
274
# implication
275
275
elif self .fuzzy_implication == "xu19" :
276
- individual_loss = - torch .log (1 - l * one_min_r )
276
+ individual_loss = - torch .log (1 - l_ * one_min_r )
277
277
elif self .fuzzy_implication in ["lukasiewicz" , "lk" ]:
278
- individual_loss = torch .relu (l + one_min_r - 1 )
278
+ individual_loss = torch .relu (l_ + one_min_r - 1 )
279
279
elif self .fuzzy_implication in ["kleene_dienes" , "kd" ]:
280
- individual_loss = torch .min (l , 1 - r )
280
+ individual_loss = torch .min (l_ , 1 - r )
281
281
elif self .fuzzy_implication in ["goedel" , "g" ]:
282
- individual_loss = torch .where (l <= r , 0 , one_min_r )
282
+ individual_loss = torch .where (l_ <= r , 0 , one_min_r )
283
283
elif self .fuzzy_implication in ["reverse-goedel" , "rg" ]:
284
- individual_loss = torch .where (l <= r , 0 , l )
284
+ individual_loss = torch .where (l_ <= r , 0 , l_ )
285
285
elif self .fuzzy_implication in ["binary" , "b" ]:
286
- individual_loss = torch .where (l <= r , 0 , 1 ).to (dtype = l .dtype )
286
+ individual_loss = torch .where (l_ <= r , 0 , 1 ).to (dtype = l_ .dtype )
287
287
else :
288
288
raise NotImplementedError (
289
289
f"Unknown fuzzy implication { self .fuzzy_implication } "
@@ -453,8 +453,8 @@ def _build_implication_filter(label_names: List, hierarchy: dict) -> torch.Tenso
453
453
454
454
def _build_dense_filter (sparse_filter : torch .Tensor , n_labels : int ) -> torch .Tensor :
455
455
res = torch .zeros ((n_labels , n_labels ), dtype = torch .bool )
456
- for l , r in sparse_filter :
457
- res [l , r ] = True
456
+ for l_ , r in sparse_filter :
457
+ res [l_ , r ] = True
458
458
return res
459
459
460
460
@@ -511,8 +511,8 @@ def _build_disjointness_filter(
511
511
random_labels = torch .randint (0 , 2 , (10 , 997 ))
512
512
for agg in ["sum" , "max" , "mean" , "log-mean" ]:
513
513
loss .violations_per_cls_aggregator = agg
514
- l = loss (random_preds , random_labels )
515
- print (f"Loss with { agg } aggregation for random input:" , l )
514
+ l_ = loss (random_preds , random_labels )
515
+ print (f"Loss with { agg } aggregation for random input:" , l_ )
516
516
517
517
# simplified example for ontology with 4 classes, A -> B, B -> C, D -> C, B and D disjoint
518
518
loss .implication_filter_l = torch .tensor (
@@ -528,5 +528,5 @@ def _build_disjointness_filter(
528
528
labels = [[0 , 1 , 1 , 0 ], [0 , 0 , 1 , 1 ]]
529
529
for agg in ["sum" , "max" , "mean" , "log-mean" ]:
530
530
loss .violations_per_cls_aggregator = agg
531
- l = loss (preds , torch .tensor (labels ))
532
- print (f"Loss with { agg } aggregation for simple input:" , l )
531
+ l_ = loss (preds , torch .tensor (labels ))
532
+ print (f"Loss with { agg } aggregation for simple input:" , l_ )
0 commit comments