@@ -16,57 +16,53 @@ def get_config(self) -> dict:
1616
1717 @staticmethod
1818 def from_config (config : dict ) -> 'LossFunction' :
19- if config ['name' ] == 'MeanSquaredError' :
20- return MeanSquaredError ()
21- elif config ['name' ] == 'BinaryCrossentropy' :
22- return BinaryCrossentropy ()
23- elif config ['name' ] == 'CategoricalCrossentropy' :
24- return CategoricalCrossentropy ()
25- elif config ['name' ] == 'MeanAbsoluteError' :
26- return MeanAbsoluteError ()
27- elif config ['name' ] == 'HuberLoss' :
28- return HuberLoss (config ['delta' ])
29- elif config ['name' ] == 'KullbackLeiblerDivergence' :
30- return KullbackLeiblerDivergence ()
31- elif config ['name' ] == 'CrossEntropyWithLabelSmoothing' :
32- return CrossEntropyWithLabelSmoothing (config ['label_smoothing' ])
33- elif config ['name' ] == 'Wasserstein' :
34- return Wasserstein ()
35- elif config ['name' ] == 'FocalLoss' :
36- return FocalLoss (config ['gamma' ], config ['alpha' ])
37- else :
38- raise ValueError (f'Unknown loss function: { config ["name" ]} ' )
19+ loss_name = config ['name' ]
20+
21+ for loss_class in LossFunction .__subclasses__ ():
22+ if loss_class .__name__ == loss_name :
23+ constructor_params = {k : v for k , v in config .items () if k != 'name' }
24+ return loss_class (** constructor_params )
3925
4026 @staticmethod
4127 def from_name (name : str ) -> "LossFunction" :
28+ aliases = {
29+ "mse" : "MeanSquaredError" ,
30+ "bce" : "BinaryCrossentropy" ,
31+ "cce" : "CategoricalCrossentropy" ,
32+ "scce" : "SparseCategoricalCrossentropy" ,
33+ "mae" : "MeanAbsoluteError" ,
34+ "kld" : "KullbackLeiblerDivergence" ,
35+ "cels" : "CrossEntropyWithLabelSmoothing" ,
36+ "wass" : "Wasserstein" ,
37+ "focal" : "FocalLoss" ,
38+ "fl" : "FocalLoss"
39+ }
40+
41+ original_name = name
4242 name = name .lower ().replace ("_" , "" )
43- if name == "mse" or name == "meansquarederror" :
44- return MeanSquaredError ()
45- elif name == "bce" or name == "binarycrossentropy" :
46- return BinaryCrossentropy ()
47- elif name == "cce" or name == "categorycrossentropy" :
48- return CategoricalCrossentropy ()
49- elif name == "scce" or name == "sparsecategoricalcrossentropy" :
50- return SparseCategoricalCrossentropy ()
51- elif name == "mae" or name == "meanabsoluteerror" :
52- return MeanAbsoluteError ()
53- elif name == "kld" or name == "kullbackleiblerdivergence" :
54- return KullbackLeiblerDivergence ()
55- elif name == "crossentropywithlabelsmoothing" or name == "cels" :
56- return CrossEntropyWithLabelSmoothing ()
57- elif name == "Wasserstein" or name == "wasserstein" or name == "wass" :
58- return Wasserstein ()
59- elif name == "focalloss" or name == "focal" or name == "fl" :
60- return FocalLoss ()
61- elif name .startswith ("huber" ) and len (name .split ("_" )) == 2 :
62- delta = float (name .split ("_" )[- 1 ])
63- return HuberLoss (delta )
64- else :
65- for subclass in LossFunction .__subclasses__ ():
66- if subclass .__name__ .lower () == name :
67- return subclass ()
68-
69- raise ValueError (f"No loss function found for the name: { name } " )
43+
44+ if name .startswith ("huber" ) and len (original_name .split ("_" )) == 2 :
45+ try :
46+ delta = float (original_name .split ("_" )[- 1 ])
47+ return Huber (delta = delta )
48+ except ValueError :
49+ pass
50+
51+ if name in aliases :
52+ name = aliases [name ]
53+
54+ for loss_class in LossFunction .__subclasses__ ():
55+ if loss_class .__name__ .lower () == name or loss_class .__name__ == name :
56+ if loss_class .__name__ == "Huber" :
57+ return loss_class (delta = 1.0 )
58+ elif loss_class .__name__ == "CrossEntropyWithLabelSmoothing" :
59+ return loss_class (label_smoothing = 0.1 )
60+ elif loss_class .__name__ == "FocalLoss" :
61+ return loss_class (gamma = 2.0 , alpha = 0.25 )
62+ else :
63+ return loss_class ()
64+
65+ raise ValueError (f"No loss function found for the name: { original_name } " )
7066
7167
7268class MeanSquaredError (LossFunction ):
@@ -142,7 +138,7 @@ def __str__(self):
142138 return "MeanAbsoluteError"
143139
144140
145- class HuberLoss (LossFunction ):
141+ class Huber (LossFunction ):
146142 def __init__ (self , delta : float = 1.0 ):
147143 super ().__init__ ()
148144 self .delta = delta
0 commit comments