@@ -1205,7 +1205,8 @@ def save(self, filename: str):
             'random_state': self.random_state,
             'skip_connections': self.skip_connections,
             'l1_reg': self.l1_reg,
-            'l2_reg': self.l2_reg
+            'l2_reg': self.l2_reg,
+            'variational': self.variational
         }

         for layer in self.encoder_layers:
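For context, a minimal round-trip check of this change. The class name `Autoencoder` and its constructor flag are assumptions inferred from the surrounding config fields, not confirmed by this diff:

```python
import json

# Hypothetical check that the new 'variational' key survives save();
# the Autoencoder class and its constructor signature are assumed.
model = Autoencoder(variational=True)
model.save('autoencoder.json')

with open('autoencoder.json') as f:
    config = json.load(f)

assert config['variational'] is True  # this key was missing before the change
```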
@@ -2265,7 +2266,8 @@ def evaluate(self, x_test: list[np.ndarray], y_test: np.ndarray, batch_size: int
         return avg_loss, all_predictions

     def get_config(self) -> dict:
-        return {
+        config = {
+            'type': 'Transformer',
             'src_vocab_size': self.src_vocab_size,
             'tgt_vocab_size': self.tgt_vocab_size,
             'd_model': self.d_model,
@@ -2278,15 +2280,24 @@ def get_config(self) -> dict:
22782280 'gradient_clip_threshold' : self .gradient_clip_threshold ,
22792281 'enable_padding' : self .enable_padding ,
22802282 'padding_size' : self .padding_size ,
2281- 'random_state' : self .random_state
2283+ 'random_state' : self .random_state ,
2284+
2285+ 'src_embedding' : self .src_embedding .get_config (),
2286+ 'tgt_embedding' : self .tgt_embedding .get_config (),
2287+ 'positional_encoding' : self .positional_encoding .get_config (),
2288+
2289+ 'encoder_layers' : [layer .get_config () for layer in self .encoder_layers ],
2290+ 'decoder_layers' : [layer .get_config () for layer in self .decoder_layers ],
2291+
2292+ 'encoder_dropout' : self .encoder_dropout .get_config (),
2293+ 'decoder_dropout' : self .decoder_dropout .get_config (),
2294+
2295+ 'output_layer' : self .output_layer .get_config (),
2296+
2297+ 'loss_function' : self .loss_function .get_config () if self .loss_function is not None else None ,
2298+ 'optimizer' : self .optimizer .get_config () if self .optimizer is not None else None
22822299 }
2283-
2284- def save (self , filename : str ) -> None :
2285- config = self .get_config ()
2286- config ['type' ] = 'Transformer'
2287-
2288- with open (filename , 'w' ) as f :
2289- json .dump (config , f , indent = 4 )
2300+ return config
22902301
22912302 @classmethod
22922303 def load (cls , filename : str ) -> 'Transformer' :
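The nested entries above assume each component pairs `get_config()` with a `from_config()` classmethod. A minimal sketch of that convention, using a stand-in `Dropout` (illustrative only; the repo's real layer classes are defined elsewhere and may differ):

```python
# Sketch of the get_config / from_config pairing the nested config relies on.
class Dropout:
    def __init__(self, rate: float = 0.1):
        self.rate = rate

    def get_config(self) -> dict:
        # Return only the JSON-serializable fields needed to rebuild the layer.
        return {'rate': self.rate}

    @classmethod
    def from_config(cls, config: dict) -> 'Dropout':
        return cls(**config)

# Round trip: serialize to a plain dict, then rebuild an equivalent layer.
layer = Dropout.from_config(Dropout(0.2).get_config())
assert layer.rate == 0.2
```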
@@ -2296,7 +2307,94 @@ def load(cls, filename: str) -> 'Transformer':
         if config['type'] != 'Transformer':
             raise ValueError(f"Invalid model type {config['type']}")

-        return cls(**{k: v for k, v in config.items() if k != 'type'})
+        model = cls(
+            src_vocab_size=config['src_vocab_size'],
+            tgt_vocab_size=config['tgt_vocab_size'],
+            d_model=config['d_model'],
+            n_heads=config['n_heads'],
+            n_encoder_layers=config['n_encoder_layers'],
+            n_decoder_layers=config['n_decoder_layers'],
+            d_ff=config['d_ff'],
+            dropout_rate=config['dropout_rate'],
+            max_sequence_length=config['max_sequence_length'],
+            gradient_clip_threshold=config['gradient_clip_threshold'],
+            enable_padding=config['enable_padding'],
+            padding_size=config['padding_size'],
+            random_state=config['random_state']
+        )
+
+        model.src_embedding = Embedding.from_config(config['src_embedding'])
+        model.tgt_embedding = Embedding.from_config(config['tgt_embedding'])
+        model.positional_encoding = PositionalEncoding.from_config(config['positional_encoding'])
+
+        model.encoder_dropout = Dropout.from_config(config['encoder_dropout'])
+        model.decoder_dropout = Dropout.from_config(config['decoder_dropout'])
+
+        model.encoder_layers = [TransformerEncoderLayer.from_config(layer_config)
+                                for layer_config in config['encoder_layers']]
+        model.decoder_layers = [TransformerDecoderLayer.from_config(layer_config)
+                                for layer_config in config['decoder_layers']]
+
+        model.output_layer = Dense.from_config(config['output_layer'])
+
+        if config['loss_function']:
+            model.loss_function = LossFunction.from_config(config['loss_function'])
+        if config['optimizer']:
+            model.optimizer = Optimizer.from_config(config['optimizer'])
+
+        return model
+
+    def save(self, filename: str) -> None:
+        base, ext = os.path.splitext(filename)
+
+        config = self.get_config()
+
+        if self.src_embedding is not None:
+            src_emb_file = f"{base}_src_embedding{ext}"
+            config['src_embedding_file'] = src_emb_file
+            with open(src_emb_file, 'w') as f:
+                json.dump(self.src_embedding.get_config(), f, indent=4)
+
+        if self.tgt_embedding is not None:
+            tgt_emb_file = f"{base}_tgt_embedding{ext}"
+            config['tgt_embedding_file'] = tgt_emb_file
+            with open(tgt_emb_file, 'w') as f:
+                json.dump(self.tgt_embedding.get_config(), f, indent=4)
+
+        config['encoder_layers_files'] = []
+        for i, layer in enumerate(self.encoder_layers):
+            encoder_file = f"{base}_encoder_layer_{i}{ext}"
+            config['encoder_layers_files'].append(encoder_file)
+            with open(encoder_file, 'w') as f:
+                json.dump(layer.get_config(), f, indent=4)
+
+        config['decoder_layers_files'] = []
+        for i, layer in enumerate(self.decoder_layers):
+            decoder_file = f"{base}_decoder_layer_{i}{ext}"
+            config['decoder_layers_files'].append(decoder_file)
+            with open(decoder_file, 'w') as f:
+                json.dump(layer.get_config(), f, indent=4)
+
+        if self.output_layer is not None:
+            output_file = f"{base}_output_layer{ext}"
+            config['output_layer_file'] = output_file
+            with open(output_file, 'w') as f:
+                json.dump(self.output_layer.get_config(), f, indent=4)
+
+        if self.optimizer is not None:
+            optimizer_file = f"{base}_optimizer{ext}"
+            config['optimizer_file'] = optimizer_file
+            with open(optimizer_file, 'w') as f:
+                json.dump(self.optimizer.get_config(), f, indent=4)
+
+        if self.loss_function is not None:
+            loss_file = f"{base}_loss{ext}"
+            config['loss_file'] = loss_file
+            with open(loss_file, 'w') as f:
+                json.dump(self.loss_function.get_config(), f, indent=4)
+
+        with open(filename, 'w') as f:
+            json.dump(config, f, indent=4)

     def __str__(self) -> str:
         return (f"Transformer(\n"
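Putting the new `save`/`load` pair together, a usage sketch (hyperparameter values are illustrative; the keyword names follow the `cls(...)` call in `load` above). Note that `save('model.json')` now also writes per-component sibling files next to the main config:

```python
model = Transformer(src_vocab_size=1000, tgt_vocab_size=1000,
                    d_model=64, n_heads=4,
                    n_encoder_layers=2, n_decoder_layers=2,
                    d_ff=256, dropout_rate=0.1, max_sequence_length=128,
                    gradient_clip_threshold=1.0, enable_padding=True,
                    padding_size=128, random_state=42)

# Writes model.json plus sidecar files following the base + suffix scheme:
# model_src_embedding.json, model_encoder_layer_0.json, model_optimizer.json, ...
model.save('model.json')

# Rebuilds the model from the nested configs embedded in model.json.
restored = Transformer.load('model.json')
assert restored.d_model == model.d_model
```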