File tree Expand file tree Collapse file tree 1 file changed +7
-4
lines changed Expand file tree Collapse file tree 1 file changed +7
-4
lines changed Original file line number Diff line number Diff line change @@ -12,16 +12,19 @@ def derivative(self, x: np.ndarray) -> np.ndarray:
1212 raise NotImplementedError
1313
1414 def get_config (self ) -> dict :
15- return {}
15+ return {"name" : self . __class__ . __name__ }
1616
1717 @staticmethod
1818 def from_config (config : dict ):
19- name = config ['name' ]
19+ name = config .get ('name' )
20+ if not name :
21+ raise ValueError ('Config must contain "name" field' )
22+
23+ constructor_params = {k : v for k , v in config .items ()
24+ if k not in ['name' , 'config' ]}
2025
2126 for activation_class in ActivationFunction .__subclasses__ ():
2227 if activation_class .__name__ == name :
23- constructor_params = {k : v for k ,
24- v in config .items () if k != 'name' }
2528 return activation_class (** constructor_params )
2629
2730 raise ValueError (f'Unknown activation function: { name } ' )
You can’t perform that action at this time.
0 commit comments