Skip to content

Commit 9c03039

Browse files
committed
fix(activation): config loading
1 parent c7be214 commit 9c03039

File tree

1 file changed

+7
-4
lines changed

1 file changed

+7
-4
lines changed

neuralnetlib/activations.py

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -12,16 +12,19 @@ def derivative(self, x: np.ndarray) -> np.ndarray:
1212
raise NotImplementedError
1313

1414
def get_config(self) -> dict:
15-
return {}
15+
return {"name": self.__class__.__name__}
1616

1717
@staticmethod
1818
def from_config(config: dict):
19-
name = config['name']
19+
name = config.get('name')
20+
if not name:
21+
raise ValueError('Config must contain "name" field')
22+
23+
constructor_params = {k: v for k, v in config.items()
24+
if k not in ['name', 'config']}
2025

2126
for activation_class in ActivationFunction.__subclasses__():
2227
if activation_class.__name__ == name:
23-
constructor_params = {k: v for k,
24-
v in config.items() if k != 'name'}
2528
return activation_class(**constructor_params)
2629

2730
raise ValueError(f'Unknown activation function: {name}')

0 commit comments

Comments
 (0)