Skip to content

Commit d3dfa1d

Browse files
committed
feat(loss): add focal loss
1 parent 48a71b2 commit d3dfa1d

File tree

1 file changed

+54
-0
lines changed

1 file changed

+54
-0
lines changed

neuralnetlib/losses.py

Lines changed: 54 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,8 @@ def from_config(config: dict) -> 'LossFunction':
3232
return CrossEntropyWithLabelSmoothing(config['label_smoothing'])
3333
elif config['name'] == 'Wasserstein':
3434
return Wasserstein()
35+
elif config['name'] == 'FocalLoss':
36+
return FocalLoss(config['gamma'], config['alpha'])
3537
else:
3638
raise ValueError(f'Unknown loss function: {config["name"]}')
3739

@@ -54,6 +56,8 @@ def from_name(name: str) -> "LossFunction":
5456
return CrossEntropyWithLabelSmoothing()
5557
elif name == "Wasserstein" or name == "wasserstein" or name == "wass":
5658
return Wasserstein()
59+
elif name == "focalloss" or name == "focal" or name == "fl":
60+
return FocalLoss()
5761
elif name.startswith("huber") and len(name.split("_")) == 2:
5862
delta = float(name.split("_")[-1])
5963
return HuberLoss(delta)
@@ -228,3 +232,53 @@ def derivative(self, y_true: np.ndarray, y_pred: np.ndarray) -> np.ndarray:
228232

229233
def __str__(self):
230234
return "Wasserstein"
235+
236+
237+
class FocalLoss(LossFunction):
238+
def __init__(self, gamma: float = 2.0, alpha: float = 0.25):
239+
super().__init__()
240+
self.gamma = gamma
241+
self.alpha = alpha
242+
243+
def __call__(self, y_true: np.ndarray, y_pred: np.ndarray) -> float:
244+
y_pred = np.clip(y_pred, self.EPSILON, 1 - self.EPSILON)
245+
246+
ce_loss = -y_true * np.log(y_pred) - (1 - y_true) * np.log(1 - y_pred)
247+
248+
p_t = y_true * y_pred + (1 - y_true) * (1 - y_pred)
249+
modulating_factor = np.power(1 - p_t, self.gamma)
250+
251+
alpha_factor = y_true * self.alpha + (1 - y_true) * (1 - self.alpha)
252+
253+
focal_loss = alpha_factor * modulating_factor * ce_loss
254+
255+
return np.mean(focal_loss)
256+
257+
def derivative(self, y_true: np.ndarray, y_pred: np.ndarray) -> np.ndarray:
258+
y_pred = np.clip(y_pred, self.EPSILON, 1 - self.EPSILON)
259+
260+
p_t = y_true * y_pred + (1 - y_true) * (1 - y_pred)
261+
262+
alpha_factor = y_true * self.alpha + (1 - y_true) * (1 - self.alpha)
263+
264+
modulating_factor = np.power(1 - p_t, self.gamma)
265+
d_modulating_factor = -self.gamma * np.power(1 - p_t, self.gamma - 1)
266+
267+
d_ce = y_true / y_pred - (1 - y_true) / (1 - y_pred)
268+
269+
derivative = alpha_factor * (
270+
modulating_factor * d_ce +
271+
d_modulating_factor * (-y_true * np.log(y_pred) - (1 - y_true) * np.log(1 - y_pred))
272+
)
273+
274+
return derivative / y_true.shape[0]
275+
276+
def __str__(self):
277+
return f"FocalLoss(gamma={self.gamma}, alpha={self.alpha})"
278+
279+
def get_config(self) -> dict:
280+
return {
281+
"name": self.__class__.__name__,
282+
"gamma": self.gamma,
283+
"alpha": self.alpha
284+
}

0 commit comments

Comments
 (0)