Skip to content

Commit 200a5e3

Browse files
committed
fix(dropout/batchnorm): save&load
1 parent 47535a8 commit 200a5e3

File tree

1 file changed

+23
-9
lines changed

1 file changed

+23
-9
lines changed

neuralnetlib/layers.py

Lines changed: 23 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -318,20 +318,23 @@ def get_config(self) -> dict:
318318
'name': self.__class__.__name__,
319319
'rate': self.rate,
320320
'adaptive': self.adaptive,
321+
'min_rate': self.dropout_impl.min_rate if self.adaptive else 0.1,
322+
'max_rate': self.dropout_impl.max_rate if self.adaptive else 0.9,
323+
'temperature': self.dropout_impl.temperature if self.adaptive else 1.0,
321324
'random_state': self.random_state
322325
}
323-
324-
if self.adaptive:
325-
config.update(self.dropout_impl.get_config())
326-
327326
return config
328327

329328
@staticmethod
330329
def from_config(config: dict):
331-
adaptive = config.pop('adaptive', False)
332-
if adaptive:
333-
return Dropout(adaptive=True, **config)
334-
return Dropout(**config)
330+
return Dropout(
331+
rate=config['rate'],
332+
adaptive=config['adaptive'],
333+
min_rate=config['min_rate'],
334+
max_rate=config['max_rate'],
335+
temperature=config['temperature'],
336+
random_state=config['random_state']
337+
)
335338

336339

337340
class Conv2D(Layer):
@@ -1275,15 +1278,26 @@ def get_config(self) -> dict:
12751278
'gamma': self.gamma.tolist() if self.gamma is not None else None,
12761279
'beta': self.beta.tolist() if self.beta is not None else None,
12771280
'momentum': self.momentum,
1278-
'epsilon': self.epsilon
1281+
'epsilon': self.epsilon,
1282+
'running_mean': self.running_mean.tolist() if self.running_mean is not None else None,
1283+
'running_var': self.running_var.tolist() if self.running_var is not None else None,
1284+
'input_shape': self.gamma.shape if self.gamma is not None else None
12791285
}
12801286

12811287
@staticmethod
12821288
def from_config(config: dict):
12831289
layer = BatchNormalization(config['momentum'], config['epsilon'])
1290+
12841291
if config['gamma'] is not None:
12851292
layer.gamma = np.array(config['gamma'])
12861293
layer.beta = np.array(config['beta'])
1294+
1295+
layer.running_mean = np.array(config['running_mean'])
1296+
layer.running_var = np.array(config['running_var'])
1297+
1298+
layer.d_gamma = np.zeros_like(layer.gamma)
1299+
layer.d_beta = np.zeros_like(layer.beta)
1300+
12871301
return layer
12881302

12891303

0 commit comments

Comments (0)