
Commit 90a9194

fix(optimizers-losses-activations): from_config
1 parent d3dfa1d

File tree

4 files changed: +61 -80 lines changed

    neuralnetlib/activations.py
    neuralnetlib/losses.py
    neuralnetlib/optimizers.py
    tests/test_losses.py

neuralnetlib/activations.py

Lines changed: 7 additions & 20 deletions

@@ -17,26 +17,13 @@ def get_config(self) -> dict:
     @staticmethod
     def from_config(config: dict):
         name = config['name']
-        if name == 'Sigmoid':
-            return Sigmoid()
-        elif name == 'ReLU':
-            return ReLU()
-        elif name == 'Tanh':
-            return Tanh()
-        elif name == 'Softmax':
-            return Softmax()
-        elif name == 'Linear':
-            return Linear()
-        elif name == 'LeakyReLU':
-            return LeakyReLU(alpha=config['alpha'])
-        elif name == 'ELU':
-            return ELU()
-        elif name == 'SELU':
-            return SELU(alpha=config['alpha'], scale=config['scale'])
-        elif name == 'GELU':
-            return GELU()
-        else:
-            raise ValueError(f'Unknown activation function: {name}')
+
+        for activation_class in ActivationFunction.__subclasses__():
+            if activation_class.__name__ == name:
+                constructor_params = {k: v for k, v in config.items() if k != 'name'}
+                return activation_class(**constructor_params)
+
+        raise ValueError(f'Unknown activation function: {name}')
 
 
 class Sigmoid(ActivationFunction):
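
With this change, adding a new activation no longer requires editing from_config: any direct subclass of ActivationFunction whose constructor accepts the remaining config keys deserializes automatically. Below is a minimal, self-contained sketch of the round trip; the get_config body and the LeakyReLU scaffolding are illustrative assumptions, not the library's exact code.

class ActivationFunction:
    def get_config(self) -> dict:
        # Serialize the class name plus the instance attributes, which double
        # as constructor parameters on the way back in.
        return {'name': self.__class__.__name__, **vars(self)}

    @staticmethod
    def from_config(config: dict):
        name = config['name']
        for activation_class in ActivationFunction.__subclasses__():
            if activation_class.__name__ == name:
                constructor_params = {k: v for k, v in config.items() if k != 'name'}
                return activation_class(**constructor_params)
        raise ValueError(f'Unknown activation function: {name}')

class LeakyReLU(ActivationFunction):
    def __init__(self, alpha: float = 0.01):
        self.alpha = alpha

config = LeakyReLU(alpha=0.2).get_config()       # {'name': 'LeakyReLU', 'alpha': 0.2}
restored = ActivationFunction.from_config(config)
assert isinstance(restored, LeakyReLU) and restored.alpha == 0.2

One caveat of the lookup: __subclasses__() returns direct subclasses only, so an activation derived from another activation rather than from ActivationFunction itself would not be found.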

neuralnetlib/losses.py

Lines changed: 44 additions & 48 deletions

@@ -16,57 +16,53 @@ def get_config(self) -> dict:
 
     @staticmethod
     def from_config(config: dict) -> 'LossFunction':
-        if config['name'] == 'MeanSquaredError':
-            return MeanSquaredError()
-        elif config['name'] == 'BinaryCrossentropy':
-            return BinaryCrossentropy()
-        elif config['name'] == 'CategoricalCrossentropy':
-            return CategoricalCrossentropy()
-        elif config['name'] == 'MeanAbsoluteError':
-            return MeanAbsoluteError()
-        elif config['name'] == 'HuberLoss':
-            return HuberLoss(config['delta'])
-        elif config['name'] == 'KullbackLeiblerDivergence':
-            return KullbackLeiblerDivergence()
-        elif config['name'] == 'CrossEntropyWithLabelSmoothing':
-            return CrossEntropyWithLabelSmoothing(config['label_smoothing'])
-        elif config['name'] == 'Wasserstein':
-            return Wasserstein()
-        elif config['name'] == 'FocalLoss':
-            return FocalLoss(config['gamma'], config['alpha'])
-        else:
-            raise ValueError(f'Unknown loss function: {config["name"]}')
+        loss_name = config['name']
+
+        for loss_class in LossFunction.__subclasses__():
+            if loss_class.__name__ == loss_name:
+                constructor_params = {k: v for k, v in config.items() if k != 'name'}
+                return loss_class(**constructor_params)
 
     @staticmethod
     def from_name(name: str) -> "LossFunction":
+        aliases = {
+            "mse": "MeanSquaredError",
+            "bce": "BinaryCrossentropy",
+            "cce": "CategoricalCrossentropy",
+            "scce": "SparseCategoricalCrossentropy",
+            "mae": "MeanAbsoluteError",
+            "kld": "KullbackLeiblerDivergence",
+            "cels": "CrossEntropyWithLabelSmoothing",
+            "wass": "Wasserstein",
+            "focal": "FocalLoss",
+            "fl": "FocalLoss"
+        }
+
+        original_name = name
         name = name.lower().replace("_", "")
-        if name == "mse" or name == "meansquarederror":
-            return MeanSquaredError()
-        elif name == "bce" or name == "binarycrossentropy":
-            return BinaryCrossentropy()
-        elif name == "cce" or name == "categorycrossentropy":
-            return CategoricalCrossentropy()
-        elif name == "scce" or name == "sparsecategoricalcrossentropy":
-            return SparseCategoricalCrossentropy()
-        elif name == "mae" or name == "meanabsoluteerror":
-            return MeanAbsoluteError()
-        elif name == "kld" or name == "kullbackleiblerdivergence":
-            return KullbackLeiblerDivergence()
-        elif name == "crossentropywithlabelsmoothing" or name == "cels":
-            return CrossEntropyWithLabelSmoothing()
-        elif name == "Wasserstein" or name == "wasserstein" or name == "wass":
-            return Wasserstein()
-        elif name == "focalloss" or name == "focal" or name == "fl":
-            return FocalLoss()
-        elif name.startswith("huber") and len(name.split("_")) == 2:
-            delta = float(name.split("_")[-1])
-            return HuberLoss(delta)
-        else:
-            for subclass in LossFunction.__subclasses__():
-                if subclass.__name__.lower() == name:
-                    return subclass()
-
-            raise ValueError(f"No loss function found for the name: {name}")
+
+        if name.startswith("huber") and len(original_name.split("_")) == 2:
+            try:
+                delta = float(original_name.split("_")[-1])
+                return Huber(delta=delta)
+            except ValueError:
+                pass
+
+        if name in aliases:
+            name = aliases[name]
+
+        for loss_class in LossFunction.__subclasses__():
+            if loss_class.__name__.lower() == name or loss_class.__name__ == name:
+                if loss_class.__name__ == "Huber":
+                    return loss_class(delta=1.0)
+                elif loss_class.__name__ == "CrossEntropyWithLabelSmoothing":
+                    return loss_class(label_smoothing=0.1)
+                elif loss_class.__name__ == "FocalLoss":
+                    return loss_class(gamma=2.0, alpha=0.25)
+                else:
+                    return loss_class()
+
+        raise ValueError(f"No loss function found for the name: {original_name}")
 
 
 class MeanSquaredError(LossFunction):
@@ -142,7 +138,7 @@ def __str__(self):
         return "MeanAbsoluteError"
 
 
-class HuberLoss(LossFunction):
+class Huber(LossFunction):
     def __init__(self, delta: float = 1.0):
         super().__init__()
        self.delta = delta
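
from_name now funnels every spelling through one path: parse a numeric huber_<delta> suffix first, then map short aliases to canonical class names, then fall back to a case-insensitive subclass scan. A usage sketch against the names in this diff, assuming neuralnetlib.losses exports them as the updated test imports suggest:

from neuralnetlib.losses import LossFunction, Huber, FocalLoss

assert isinstance(LossFunction.from_name("mse"), LossFunction)     # alias for MeanSquaredError
assert isinstance(LossFunction.from_name("fl"), FocalLoss)         # short alias, default gamma/alpha
assert isinstance(LossFunction.from_name("FocalLoss"), FocalLoss)  # exact class name still resolves

huber = LossFunction.from_name("huber_2.5")  # numeric suffix parsed as delta
assert isinstance(huber, Huber) and huber.delta == 2.5

Note that the rewritten from_config, unlike the if/elif chain it replaces, falls through and returns None rather than raising when no subclass matches the configured name.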

neuralnetlib/optimizers.py

Lines changed: 8 additions & 10 deletions

@@ -16,16 +16,14 @@ def get_config(self) -> dict:
 
     @staticmethod
     def from_config(config: dict):
-        if config['name'] == 'SGD':
-            return SGD.from_config(config)
-        elif config['name'] == 'Momentum':
-            return Momentum.from_config(config)
-        elif config['name'] == 'RMSprop':
-            return RMSprop.from_config(config)
-        elif config['name'] == 'Adam':
-            return Adam.from_config(config)
-        else:
-            raise ValueError(f"Unknown optimizer name: {config['name']}")
+        optimizer_name = config['name']
+
+        for optimizer_class in Optimizer.__subclasses__():
+            if optimizer_class.__name__ == optimizer_name:
+                constructor_params = {k: v for k, v in config.items() if k != 'name'}
+                return optimizer_class(**constructor_params)
+
+        raise ValueError(f"No optimizer found for the name: {optimizer_name}")
 
     @staticmethod
     def from_name(name: str) -> "Optimizer":
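
Because the remaining config keys are splatted straight into the constructor, each optimizer's get_config must emit exactly its __init__ keywords; any extra key, such as serialized internal state, raises a TypeError. A hypothetical illustration of that contract (not neuralnetlib's actual Adam signature):

class Optimizer:
    @staticmethod
    def from_config(config: dict):
        optimizer_name = config['name']
        for optimizer_class in Optimizer.__subclasses__():
            if optimizer_class.__name__ == optimizer_name:
                constructor_params = {k: v for k, v in config.items() if k != 'name'}
                return optimizer_class(**constructor_params)
        raise ValueError(f"No optimizer found for the name: {optimizer_name}")

class Adam(Optimizer):
    def __init__(self, learning_rate: float = 0.001):
        self.learning_rate = learning_rate

Optimizer.from_config({'name': 'Adam', 'learning_rate': 0.01})  # reconstructs the optimizer
try:
    Optimizer.from_config({'name': 'Adam', 't': 42})  # stray state key
except TypeError as err:
    print(err)  # unexpected keyword argument 't'

It also means the per-class hooks the old dispatch delegated to (SGD.from_config, Adam.from_config, and so on) are no longer consulted.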

tests/test_losses.py

Lines changed: 2 additions & 2 deletions

@@ -3,7 +3,7 @@
 import numpy as np
 
 from neuralnetlib.losses import MeanSquaredError, BinaryCrossentropy, CategoricalCrossentropy, MeanAbsoluteError, \
-    HuberLoss
+    Huber
 
 
 class TestLossFunctions(unittest.TestCase):
@@ -41,7 +41,7 @@ def test_mean_absolute_error(self):
         self.assertAlmostEqual(calculated_loss, expected_loss)
 
     def test_huber_loss(self):
-        huber = HuberLoss(delta=1.0)
+        huber = Huber(delta=1.0)
         y_true = np.array([1, 2, 3])
         y_pred = np.array([1, 2, 4])
         error = y_true - y_pred
