@@ -318,20 +318,23 @@ def get_config(self) -> dict:
318318 'name' : self .__class__ .__name__ ,
319319 'rate' : self .rate ,
320320 'adaptive' : self .adaptive ,
321+ 'min_rate' : self .dropout_impl .min_rate if self .adaptive else 0.1 ,
322+ 'max_rate' : self .dropout_impl .max_rate if self .adaptive else 0.9 ,
323+ 'temperature' : self .dropout_impl .temperature if self .adaptive else 1.0 ,
321324 'random_state' : self .random_state
322325 }
323-
324- if self .adaptive :
325- config .update (self .dropout_impl .get_config ())
326-
327326 return config
328327
329328 @staticmethod
330329 def from_config (config : dict ):
331- adaptive = config .pop ('adaptive' , False )
332- if adaptive :
333- return Dropout (adaptive = True , ** config )
334- return Dropout (** config )
330+ return Dropout (
331+ rate = config ['rate' ],
332+ adaptive = config ['adaptive' ],
333+ min_rate = config ['min_rate' ],
334+ max_rate = config ['max_rate' ],
335+ temperature = config ['temperature' ],
336+ random_state = config ['random_state' ]
337+ )
335338
336339
337340class Conv2D (Layer ):
@@ -1275,15 +1278,26 @@ def get_config(self) -> dict:
12751278 'gamma' : self .gamma .tolist () if self .gamma is not None else None ,
12761279 'beta' : self .beta .tolist () if self .beta is not None else None ,
12771280 'momentum' : self .momentum ,
1278- 'epsilon' : self .epsilon
1281+ 'epsilon' : self .epsilon ,
1282+ 'running_mean' : self .running_mean .tolist () if self .running_mean is not None else None ,
1283+ 'running_var' : self .running_var .tolist () if self .running_var is not None else None ,
1284+ 'input_shape' : self .gamma .shape if self .gamma is not None else None
12791285 }
12801286
12811287 @staticmethod
12821288 def from_config (config : dict ):
12831289 layer = BatchNormalization (config ['momentum' ], config ['epsilon' ])
1290+
12841291 if config ['gamma' ] is not None :
12851292 layer .gamma = np .array (config ['gamma' ])
12861293 layer .beta = np .array (config ['beta' ])
1294+
1295+ layer .running_mean = np .array (config ['running_mean' ])
1296+ layer .running_var = np .array (config ['running_var' ])
1297+
1298+ layer .d_gamma = np .zeros_like (layer .gamma )
1299+ layer .d_beta = np .zeros_like (layer .beta )
1300+
12871301 return layer
12881302
12891303
0 commit comments