
Commit 66be639

fix(configs): models states saving and loading
1 parent 1fcafb5 commit 66be639

File tree: 1 file changed (+109, -11 lines)

1 file changed

+109
-11
lines changed

neuralnetlib/models.py

Lines changed: 109 additions & 11 deletions
@@ -1205,7 +1205,8 @@ def save(self, filename: str):
             'random_state': self.random_state,
             'skip_connections': self.skip_connections,
             'l1_reg': self.l1_reg,
-            'l2_reg': self.l2_reg
+            'l2_reg': self.l2_reg,
+            'variational': self.variational
         }
 
         for layer in self.encoder_layers:
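
This hunk fixes the autoencoder's save(): the serialized config previously dropped the variational flag, so a variational model came back as a plain one after a reload. A minimal sketch of the round trip this enables, assuming a class named Autoencoder owning this save()/load() pair (the class itself is not shown in the hunk; only the config keys are from the diff):

```python
# Hypothetical usage: class name, import path and most constructor
# arguments are assumptions; the keyword names mirror the config keys
# written by save() in the hunk above.
from neuralnetlib.models import Autoencoder

model = Autoencoder(random_state=42, skip_connections=False,
                    l1_reg=0.0, l2_reg=0.01, variational=True)
model.save('vae.json')

restored = Autoencoder.load('vae.json')
assert restored.variational  # before this commit the flag was lost on reload
```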
@@ -2265,7 +2266,8 @@ def evaluate(self, x_test: list[np.ndarray], y_test: np.ndarray, batch_size: int
         return avg_loss, all_predictions
 
     def get_config(self) -> dict:
-        return {
+        config = {
+            'type': 'Transformer',
             'src_vocab_size': self.src_vocab_size,
             'tgt_vocab_size': self.tgt_vocab_size,
             'd_model': self.d_model,
@@ -2278,15 +2280,24 @@ def get_config(self) -> dict:
             'gradient_clip_threshold': self.gradient_clip_threshold,
             'enable_padding': self.enable_padding,
             'padding_size': self.padding_size,
-            'random_state': self.random_state
+            'random_state': self.random_state,
+
+            'src_embedding': self.src_embedding.get_config(),
+            'tgt_embedding': self.tgt_embedding.get_config(),
+            'positional_encoding': self.positional_encoding.get_config(),
+
+            'encoder_layers': [layer.get_config() for layer in self.encoder_layers],
+            'decoder_layers': [layer.get_config() for layer in self.decoder_layers],
+
+            'encoder_dropout': self.encoder_dropout.get_config(),
+            'decoder_dropout': self.decoder_dropout.get_config(),
+
+            'output_layer': self.output_layer.get_config(),
+
+            'loss_function': self.loss_function.get_config() if self.loss_function is not None else None,
+            'optimizer': self.optimizer.get_config() if self.optimizer is not None else None
         }
-
-    def save(self, filename: str) -> None:
-        config = self.get_config()
-        config['type'] = 'Transformer'
-
-        with open(filename, 'w') as f:
-            json.dump(config, f, indent=4)
+        return config
 
     @classmethod
     def load(cls, filename: str) -> 'Transformer':
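
With these two hunks, get_config() becomes self-describing: it tags the dict with 'type' (previously bolted on by save()) and embeds the nested get_config() of every sub-component, embeddings, positional encoding, each encoder/decoder layer, the dropouts, the output layer, and the optimizer/loss when present, instead of only scalar hyperparameters. A sketch of the resulting shape; the keyword names mirror what load() passes to cls() below, while the values are illustrative:

```python
import json
from neuralnetlib.models import Transformer

# Illustrative hyperparameters; keyword names are taken from the diff.
model = Transformer(src_vocab_size=1000, tgt_vocab_size=1000, d_model=64,
                    n_heads=4, n_encoder_layers=2, n_decoder_layers=2,
                    d_ff=128, dropout_rate=0.1, max_sequence_length=64,
                    gradient_clip_threshold=5.0, enable_padding=True,
                    padding_size=64, random_state=42)

config = model.get_config()
assert config['type'] == 'Transformer'     # tag now set inside get_config()
assert len(config['encoder_layers']) == 2  # one nested config per layer
text = json.dumps(config, indent=4)        # serializable end to end, provided
                                           # every nested get_config() returns
                                           # JSON-friendly values (save() relies
                                           # on this)
```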
@@ -2296,7 +2307,94 @@ def load(cls, filename: str) -> 'Transformer':
         if config['type'] != 'Transformer':
             raise ValueError(f"Invalid model type {config['type']}")
 
-        return cls(**{k: v for k, v in config.items() if k != 'type'})
+        model = cls(
+            src_vocab_size=config['src_vocab_size'],
+            tgt_vocab_size=config['tgt_vocab_size'],
+            d_model=config['d_model'],
+            n_heads=config['n_heads'],
+            n_encoder_layers=config['n_encoder_layers'],
+            n_decoder_layers=config['n_decoder_layers'],
+            d_ff=config['d_ff'],
+            dropout_rate=config['dropout_rate'],
+            max_sequence_length=config['max_sequence_length'],
+            gradient_clip_threshold=config['gradient_clip_threshold'],
+            enable_padding=config['enable_padding'],
+            padding_size=config['padding_size'],
+            random_state=config['random_state']
+        )
+
+        model.src_embedding = Embedding.from_config(config['src_embedding'])
+        model.tgt_embedding = Embedding.from_config(config['tgt_embedding'])
+        model.positional_encoding = PositionalEncoding.from_config(config['positional_encoding'])
+
+        model.encoder_dropout = Dropout.from_config(config['encoder_dropout'])
+        model.decoder_dropout = Dropout.from_config(config['decoder_dropout'])
+
+        model.encoder_layers = [TransformerEncoderLayer.from_config(layer_config)
+                                for layer_config in config['encoder_layers']]
+        model.decoder_layers = [TransformerDecoderLayer.from_config(layer_config)
+                                for layer_config in config['decoder_layers']]
+
+        model.output_layer = Dense.from_config(config['output_layer'])
+
+        if config['loss_function']:
+            model.loss_function = LossFunction.from_config(config['loss_function'])
+        if config['optimizer']:
+            model.optimizer = Optimizer.from_config(config['optimizer'])
+
+        return model
+
+    def save(self, filename: str) -> None:
+        base, ext = os.path.splitext(filename)
+
+        config = self.get_config()
+
+        if self.src_embedding is not None:
+            src_emb_file = f"{base}_src_embedding{ext}"
+            config['src_embedding_file'] = src_emb_file
+            with open(src_emb_file, 'w') as f:
+                json.dump(self.src_embedding.get_config(), f, indent=4)
+
+        if self.tgt_embedding is not None:
+            tgt_emb_file = f"{base}_tgt_embedding{ext}"
+            config['tgt_embedding_file'] = tgt_emb_file
+            with open(tgt_emb_file, 'w') as f:
+                json.dump(self.tgt_embedding.get_config(), f, indent=4)
+
+        config['encoder_layers_files'] = []
+        for i, layer in enumerate(self.encoder_layers):
+            encoder_file = f"{base}_encoder_layer_{i}{ext}"
+            config['encoder_layers_files'].append(encoder_file)
+            with open(encoder_file, 'w') as f:
+                json.dump(layer.get_config(), f, indent=4)
+
+        config['decoder_layers_files'] = []
+        for i, layer in enumerate(self.decoder_layers):
+            decoder_file = f"{base}_decoder_layer_{i}{ext}"
+            config['decoder_layers_files'].append(decoder_file)
+            with open(decoder_file, 'w') as f:
+                json.dump(layer.get_config(), f, indent=4)
+
+        if self.output_layer is not None:
+            output_file = f"{base}_output_layer{ext}"
+            config['output_layer_file'] = output_file
+            with open(output_file, 'w') as f:
+                json.dump(self.output_layer.get_config(), f, indent=4)
+
+        if self.optimizer is not None:
+            optimizer_file = f"{base}_optimizer{ext}"
+            config['optimizer_file'] = optimizer_file
+            with open(optimizer_file, 'w') as f:
+                json.dump(self.optimizer.get_config(), f, indent=4)
+
+        if self.loss_function is not None:
+            loss_file = f"{base}_loss{ext}"
+            config['loss_file'] = loss_file
+            with open(loss_file, 'w') as f:
+                json.dump(self.loss_function.get_config(), f, indent=4)
+
+        with open(filename, 'w') as f:
+            json.dump(config, f, indent=4)
 
     def __str__(self) -> str:
         return (f"Transformer(\n"
