    StreamEmbedLinear,
    StreamEmbedTransformer,
)
-from weathergen.model.layers import MLP
+from weathergen.model.layers import MLP, MoEMLP
from weathergen.model.utils import ActivationFactory
from weathergen.utils.utils import get_dtype

+import logging
+logger = logging.getLogger(__name__)

class EmbeddingEngine:
    name: "EmbeddingEngine"
@@ -249,17 +251,50 @@ def create(self) -> torch.nn.ModuleList:
                    )
                )
            # MLP block
-            self.ae_global_blocks.append(
-                MLP(
-                    self.cf.ae_global_dim_embed,
-                    self.cf.ae_global_dim_embed,
-                    with_residual=True,
-                    dropout_rate=self.cf.ae_global_dropout_rate,
-                    hidden_factor=self.cf.ae_global_mlp_hidden_factor,
-                    norm_type=self.cf.norm_type,
-                    norm_eps=self.cf.mlp_norm_eps,
-                )
+            # Optionally replace the dense MLP with a Mixture-of-Experts (MoE) MLP
+            use_moe = getattr(self.cf, "ae_global_mlp_type", "dense") == "moe"
+            mlp_common_kwargs = dict(
+                dim_in=self.cf.ae_global_dim_embed,
+                dim_out=self.cf.ae_global_dim_embed,
+                with_residual=True,
+                dropout_rate=self.cf.ae_global_dropout_rate,
+                norm_type=self.cf.norm_type,
+                norm_eps=self.cf.mlp_norm_eps,
            )
+            if use_moe:
+                self.ae_global_blocks.append(
+                    MoEMLP(
+                        **mlp_common_kwargs,
+                        num_experts=getattr(self.cf, "ae_global_moe_num_experts", 2),
+                        top_k=getattr(self.cf, "ae_global_moe_top_k", 1),
+                        router_noisy_std=getattr(self.cf, "ae_global_moe_router_noisy_std", 0.0),
+                        hidden_factor=getattr(self.cf, "ae_global_moe_hidden_factor", 2),
+                    )
+                )
+            else:
+                self.ae_global_blocks.append(
+                    MLP(
+                        self.cf.ae_global_dim_embed,
+                        self.cf.ae_global_dim_embed,
+                        with_residual=True,
+                        dropout_rate=self.cf.ae_global_dropout_rate,
+                        hidden_factor=self.cf.ae_global_mlp_hidden_factor,
+                        norm_type=self.cf.norm_type,
+                        norm_eps=self.cf.mlp_norm_eps,
+                    )
+                )
+        # Log how many MoE blocks were created and the MoE config in effect
+        num_moe = sum(1 for m in self.ae_global_blocks if isinstance(m, MoEMLP))
+        logger.info(
+            "[MoE] GlobalAssimilationEngine: %d MoEMLP blocks "
+            "(ae_global_mlp_type=%s, experts=%s, top_k=%s, hidden_factor=%s)",
+            num_moe,
+            getattr(self.cf, "ae_global_mlp_type", "dense"),
+            getattr(self.cf, "ae_global_moe_num_experts", None),
+            getattr(self.cf, "ae_global_moe_top_k", None),
+            getattr(self.cf, "ae_global_moe_hidden_factor", None),
+        )
+
        return self.ae_global_blocks

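The MoEMLP block imported from weathergen.model.layers is not part of this diff. For orientation, below is a minimal, hypothetical sketch of a token-wise top-k MoE MLP with the same constructor arguments used above (dim_in, dim_out, num_experts, top_k, hidden_factor, router_noisy_std, with_residual, dropout_rate, norm_eps). The class name MoEMLPSketch and all internals are illustrative assumptions, not the repository implementation; norm_type is ignored here and a LayerNorm is used.

# Hypothetical sketch only -- NOT the weathergen.model.layers implementation.
import torch


class MoEMLPSketch(torch.nn.Module):
    """Token-wise top-k mixture-of-experts MLP (illustrative)."""

    def __init__(
        self,
        dim_in,
        dim_out,
        num_experts=2,
        top_k=1,
        hidden_factor=2,
        with_residual=True,
        dropout_rate=0.0,
        router_noisy_std=0.0,
        norm_eps=1e-5,
    ):
        super().__init__()
        self.top_k = top_k
        self.with_residual = with_residual
        self.router_noisy_std = router_noisy_std
        self.norm = torch.nn.LayerNorm(dim_in, eps=norm_eps)
        self.router = torch.nn.Linear(dim_in, num_experts, bias=False)
        hidden = hidden_factor * dim_in
        self.experts = torch.nn.ModuleList(
            [
                torch.nn.Sequential(
                    torch.nn.Linear(dim_in, hidden),
                    torch.nn.GELU(),
                    torch.nn.Dropout(dropout_rate),
                    torch.nn.Linear(hidden, dim_out),
                )
                for _ in range(num_experts)
            ]
        )

    def forward(self, x):
        # Route every token independently over the channel dimension.
        y = self.norm(x)
        logits = self.router(y)
        if self.training and self.router_noisy_std > 0:
            # Noisy routing encourages exploration across experts during training.
            logits = logits + torch.randn_like(logits) * self.router_noisy_std
        gate, idx = torch.topk(torch.softmax(logits, dim=-1), self.top_k, dim=-1)
        gate = gate / gate.sum(dim=-1, keepdim=True)
        # For clarity all experts are evaluated densely and the top-k outputs are mixed;
        # a production implementation would dispatch tokens sparsely per expert.
        expert_out = torch.stack([e(y) for e in self.experts], dim=-2)  # (..., E, dim_out)
        picked = torch.gather(
            expert_out, -2, idx.unsqueeze(-1).expand(*idx.shape, expert_out.shape[-1])
        )  # (..., top_k, dim_out)
        out = (gate.unsqueeze(-1) * picked).sum(dim=-2)
        # Residual connection assumes dim_in == dim_out, as in the calls above.
        return x + out if self.with_residual else out


# Usage sketch: one block over (batch, tokens, channels) activations.
block = MoEMLPSketch(256, 256, num_experts=2, top_k=1)
y = block(torch.randn(8, 1024, 256))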
@@ -343,8 +378,8 @@ def create(self) -> torch.nn.ModuleList:
                    self.fe_blocks.append(
                        MoEMLP(
                            **mlp_common_kwargs,
-                            num_experts=getattr(self.cf, "fe_moe_num_experts", 8),
-                            top_k=getattr(self.cf, "fe_moe_top_k", 4),
+                            num_experts=getattr(self.cf, "fe_moe_num_experts", 2),
+                            top_k=getattr(self.cf, "fe_moe_top_k", 2),
                            router_noisy_std=getattr(self.cf, "fe_moe_router_noisy_std", 0.0),
                            hidden_factor=getattr(self.cf, "fe_moe_hidden_factor", 2),
                        )
@@ -362,15 +397,24 @@ def create(self) -> torch.nn.ModuleList:
                        )
                    )
                # ------------------------------------------------------------------
-        def init_weights_final(m):
-            if isinstance(m, torch.nn.Linear):
-                torch.nn.init.normal_(m.weight, mean=0, std=0.001)
-                if m.bias is not None:
-                    torch.nn.init.normal_(m.bias, mean=0, std=0.001)
-
-        for block in self.fe_blocks:
-            block.apply(init_weights_final)
-
+        # def init_weights_final(m):
+        #     if isinstance(m, torch.nn.Linear) and not getattr(m, "is_moe_router", False):
+        #         torch.nn.init.normal_(m.weight, mean=0, std=0.001)
+        #         if m.bias is not None:
+        #             torch.nn.init.normal_(m.bias, mean=0, std=0.001)
+
+        # for block in self.fe_blocks:
+        #     block.apply(init_weights_final)
+        num_moe = sum(1 for m in self.fe_blocks if isinstance(m, MoEMLP))
+        logger.info(
+            "[MoE] ForecastingEngine: %d MoEMLP blocks "
+            "(fe_mlp_type=%s, experts=%s, top_k=%s, hidden_factor=%s)",
+            num_moe,
+            getattr(self.cf, "fe_mlp_type", "dense"),
+            getattr(self.cf, "fe_moe_num_experts", None),
+            getattr(self.cf, "fe_moe_top_k", None),
+            getattr(self.cf, "fe_moe_hidden_factor", None),
+        )
        return self.fe_blocks

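The disabled re-initialisation above differs from the removed version only in the `not getattr(m, "is_moe_router", False)` guard. If it is ever re-enabled, the router projection would need to carry that attribute so its weights are not flattened to near zero. A minimal sketch follows; the attribute name is taken from the commented-out code, everything else (dimensions, the Sequential wrapper) is an illustrative assumption.

# Illustrative only: tagging a router Linear so a near-zero final re-initialisation skips it.
import torch

router = torch.nn.Linear(256, 4, bias=False)
router.is_moe_router = True  # attribute checked by the guard in the commented-out code above


def init_weights_final(m):
    # Shrink all Linear weights except tagged MoE routers.
    if isinstance(m, torch.nn.Linear) and not getattr(m, "is_moe_router", False):
        torch.nn.init.normal_(m.weight, mean=0, std=0.001)
        if m.bias is not None:
            torch.nn.init.normal_(m.bias, mean=0, std=0.001)


moe_block = torch.nn.Sequential(router, torch.nn.Linear(4, 256))
moe_block.apply(init_weights_final)  # router weights are left untouched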
@@ -619,6 +663,14 @@ def __init__(
                        with_adanorm=False,
                        with_mlp=False,
                        attention_kwargs=attention_kwargs,
+                        ffn_mlp_type=getattr(self.cf, "decoder_ffn_mlp_type", "dense"),
+                        ffn_hidden_factor=getattr(self.cf, "decoder_ffn_hidden_factor", 4),
+                        moe_kwargs=dict(
+                            num_experts=getattr(self.cf, "decoder_moe_num_experts", 2),
+                            top_k=getattr(self.cf, "decoder_moe_top_k", 2),
+                            router_noisy_std=getattr(self.cf, "decoder_moe_router_noisy_std", 0.0),
+                            use_checkpoint=getattr(self.cf, "decoder_moe_use_checkpoint", False),
+                        )
                    )
                )
            elif self.cf.decoder_type == "AdaLayerNormConditioning":
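The decoder block classes that receive ffn_mlp_type, ffn_hidden_factor and moe_kwargs are defined outside this diff. As an illustration of how such a constructor might consume them, here is a hedged sketch; the function name build_ffn, its dispatch logic, and passing moe_kwargs straight through to MoEMLP (including use_checkpoint) are assumptions, not the repository code.

# Hypothetical consumer of the new kwargs; the actual decoder blocks live elsewhere in weathergen.
from weathergen.model.layers import MLP, MoEMLP


def build_ffn(dim_embed, ffn_mlp_type="dense", ffn_hidden_factor=4, moe_kwargs=None):
    # Feed-forward sub-layer of a decoder block: dense MLP or MoE MLP, chosen per config.
    if ffn_mlp_type == "moe":
        # Assumes MoEMLP accepts the keys packed into moe_kwargs above (num_experts, top_k, ...).
        return MoEMLP(
            dim_in=dim_embed,
            dim_out=dim_embed,
            hidden_factor=ffn_hidden_factor,
            **(moe_kwargs or {}),
        )
    return MLP(dim_embed, dim_embed, hidden_factor=ffn_hidden_factor)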
@@ -674,6 +726,14 @@ def __init__(
                        tr_mlp_hidden_factor=tr_mlp_hidden_factor,
                        tro_type=tro_type,
                        mlp_norm_eps=self.cf.mlp_norm_eps,
+                        ffn_mlp_type=getattr(self.cf, "decoder_ffn_mlp_type", "dense"),
+                        ffn_hidden_factor=getattr(self.cf, "decoder_ffn_hidden_factor", 4),
+                        moe_kwargs=dict(
+                            num_experts=getattr(self.cf, "decoder_moe_num_experts", 2),
+                            top_k=getattr(self.cf, "decoder_moe_top_k", 2),
+                            router_noisy_std=getattr(self.cf, "decoder_moe_router_noisy_std", 0.0),
+                            use_checkpoint=getattr(self.cf, "decoder_moe_use_checkpoint", False),
+                        )
                    )
                )
            else:
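Every new option is read through getattr with a dense fall-back, so existing configs run unchanged. As an illustration only (the real self.cf is built by the WeatherGen config system, not shown in this diff), these are the attributes the commit reads, here packed into a plain attribute container:

# Illustrative attribute container; the actual config object is created elsewhere.
from types import SimpleNamespace

cf = SimpleNamespace(
    # Global assimilation engine
    ae_global_mlp_type="moe",              # "dense" (default) keeps the old behaviour
    ae_global_moe_num_experts=2,
    ae_global_moe_top_k=1,
    ae_global_moe_router_noisy_std=0.0,
    ae_global_moe_hidden_factor=2,
    # Forecasting engine
    fe_mlp_type="moe",
    fe_moe_num_experts=2,
    fe_moe_top_k=2,
    fe_moe_router_noisy_std=0.0,
    fe_moe_hidden_factor=2,
    # Decoder feed-forward layers
    decoder_ffn_mlp_type="moe",
    decoder_ffn_hidden_factor=4,
    decoder_moe_num_experts=2,
    decoder_moe_top_k=2,
    decoder_moe_router_noisy_std=0.0,
    decoder_moe_use_checkpoint=False,
)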