@@ -32,6 +32,8 @@ def from_config(config: dict) -> 'LossFunction':
3232 return CrossEntropyWithLabelSmoothing (config ['label_smoothing' ])
3333 elif config ['name' ] == 'Wasserstein' :
3434 return Wasserstein ()
35+ elif config ['name' ] == 'FocalLoss' :
36+ return FocalLoss (config ['gamma' ], config ['alpha' ])
3537 else :
3638 raise ValueError (f'Unknown loss function: { config ["name" ]} ' )
3739
@@ -54,6 +56,8 @@ def from_name(name: str) -> "LossFunction":
5456 return CrossEntropyWithLabelSmoothing ()
5557 elif name == "Wasserstein" or name == "wasserstein" or name == "wass" :
5658 return Wasserstein ()
59+ elif name == "focalloss" or name == "focal" or name == "fl" :
60+ return FocalLoss ()
5761 elif name .startswith ("huber" ) and len (name .split ("_" )) == 2 :
5862 delta = float (name .split ("_" )[- 1 ])
5963 return HuberLoss (delta )
@@ -228,3 +232,53 @@ def derivative(self, y_true: np.ndarray, y_pred: np.ndarray) -> np.ndarray:
228232
229233 def __str__ (self ):
230234 return "Wasserstein"
235+
236+
237+ class FocalLoss (LossFunction ):
238+ def __init__ (self , gamma : float = 2.0 , alpha : float = 0.25 ):
239+ super ().__init__ ()
240+ self .gamma = gamma
241+ self .alpha = alpha
242+
243+ def __call__ (self , y_true : np .ndarray , y_pred : np .ndarray ) -> float :
244+ y_pred = np .clip (y_pred , self .EPSILON , 1 - self .EPSILON )
245+
246+ ce_loss = - y_true * np .log (y_pred ) - (1 - y_true ) * np .log (1 - y_pred )
247+
248+ p_t = y_true * y_pred + (1 - y_true ) * (1 - y_pred )
249+ modulating_factor = np .power (1 - p_t , self .gamma )
250+
251+ alpha_factor = y_true * self .alpha + (1 - y_true ) * (1 - self .alpha )
252+
253+ focal_loss = alpha_factor * modulating_factor * ce_loss
254+
255+ return np .mean (focal_loss )
256+
257+ def derivative (self , y_true : np .ndarray , y_pred : np .ndarray ) -> np .ndarray :
258+ y_pred = np .clip (y_pred , self .EPSILON , 1 - self .EPSILON )
259+
260+ p_t = y_true * y_pred + (1 - y_true ) * (1 - y_pred )
261+
262+ alpha_factor = y_true * self .alpha + (1 - y_true ) * (1 - self .alpha )
263+
264+ modulating_factor = np .power (1 - p_t , self .gamma )
265+ d_modulating_factor = - self .gamma * np .power (1 - p_t , self .gamma - 1 )
266+
267+ d_ce = y_true / y_pred - (1 - y_true ) / (1 - y_pred )
268+
269+ derivative = alpha_factor * (
270+ modulating_factor * d_ce +
271+ d_modulating_factor * (- y_true * np .log (y_pred ) - (1 - y_true ) * np .log (1 - y_pred ))
272+ )
273+
274+ return derivative / y_true .shape [0 ]
275+
276+ def __str__ (self ):
277+ return f"FocalLoss(gamma={ self .gamma } , alpha={ self .alpha } )"
278+
279+ def get_config (self ) -> dict :
280+ return {
281+ "name" : self .__class__ .__name__ ,
282+ "gamma" : self .gamma ,
283+ "alpha" : self .alpha
284+ }
0 commit comments