
Commit b3de51e

adding MoE layers to the global engine and decoder layers, adding the router loss to the trainer
1 parent c73578e commit b3de51e

File tree

5 files changed: +477 -189 lines changed


config/default_config.yml

Lines changed: 28 additions & 8 deletions
@@ -24,7 +24,7 @@ ae_adapter_with_residual: True
 ae_adapter_dropout_rate: 0.1
 
 ae_global_dim_embed: 2048
-ae_global_num_blocks: 8
+ae_global_num_blocks: 2
 ae_global_num_heads: 32
 ae_global_dropout_rate: 0.1
 ae_global_with_qk_lnorm: True
@@ -42,12 +42,12 @@ pred_mlp_adaln: True
 
 # number of steps offset applied to first target window; if set to zero and forecast_steps=0 then
 # one is training an auto-encoder
-forecast_offset : 0
+forecast_offset : 1
 forecast_delta_hrs: 0
-forecast_steps: 0
-forecast_policy: null
+forecast_steps: 1
+forecast_policy: "fixed"
 forecast_att_dense_rate: 1.0
-fe_num_blocks: 0
+fe_num_blocks: 2
 fe_num_heads: 16
 fe_dropout_rate: 0.1
 fe_with_qk_lnorm: True
@@ -93,7 +93,7 @@ ema_halflife_in_thousands: 1e-3
 
 # training mode: "forecast" or "masking" (masked token modeling)
 # for "masking" to train with auto-encoder mode, forecast_offset should be 0
-training_mode: "masking"
+training_mode: "forecast"
 # masking rate when training mode is "masking"; ignored in foreacast mode
 masking_rate: 0.6
 # sample the masking rate (with normal distribution centered at masking_rate)
@@ -168,8 +168,28 @@ train_log:
 
 # Forecast MLP type: "dense" (default) or "moe"
 fe_mlp_type: "dense" # set to "moe" to enable MoE
+ae_global_mlp_type: "dense" # set to "moe" to enable MoE
+ffn_mlp_type: "dense" # set to "moe" to enable MoE in the feed-forward network of the decoder blocks
+decoder_mlp_type: "dense" # set to "moe" to enable MoE in the decoder prediction MLP
+moe_lambda: 0.02 # coefficient for the MoE load balancing loss
 
 # MoE-only params (ignored when fe_mlp_type != "moe")
-fe_moe_num_experts: 8
-fe_moe_top_k: 2
+fe_moe_num_experts: 2
+fe_moe_top_k: 1
 fe_moe_hidden_factor: 0.5 # = HF_dense / 4
+
+# MoE-only params (ignored when ae_global_mlp_type != "moe")
+ae_global_moe_num_experts: 4
+ae_global_moe_top_k: 2
+ae_global_moe_hidden_factor: 0.5 # = HF_dense / 4
+
+# MoE-only params (ignored when ffn_mlp_type != "moe")
+ffn_moe_num_experts: 2
+ffn_moe_top_k: 1
+ffn_moe_hidden_factor: 0.5 # = HF_dense / 4
+
+# MoE-only params (ignored when decoder_mlp_type != "moe")
+decoder_moe_num_experts: 2
+decoder_moe_top_k: 1
+decoder_moe_hidden_factor: 0.5 # = HF_dense / 4
+tr_mlp_hidden_factor: 2
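
The new moe_lambda key weights the MoE load-balancing (router) loss that, per the commit message, is added to the trainer; the trainer change itself is not among the files shown here. Below is a minimal sketch of how such a weighting is typically applied, assuming the MoE blocks expose their auxiliary term as a router_loss attribute (a hypothetical name used only for illustration):

import torch
from torch import nn


def loss_with_router_penalty(task_loss: torch.Tensor, model: nn.Module, moe_lambda: float = 0.02) -> torch.Tensor:
    # Collect the auxiliary load-balancing terms from all MoE blocks.
    # The attribute name "router_loss" is an assumption for this sketch;
    # the MoEMLP blocks in this repo may expose it under a different name.
    aux = task_loss.new_zeros(())
    for module in model.modules():
        router_loss = getattr(module, "router_loss", None)
        if router_loss is not None:
            aux = aux + router_loss
    # moe_lambda mirrors the new config key above (default 0.02).
    return task_loss + moe_lambda * aux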

src/weathergen/model/blocks.py

Lines changed: 96 additions & 29 deletions
@@ -14,10 +14,12 @@
     MultiCrossAttentionHeadVarlen,
     MultiSelfAttentionHeadVarlen,
 )
-from weathergen.model.layers import MLP
+from weathergen.model.layers import MLP, MoEMLP
 from weathergen.model.norms import AdaLayerNormLayer
 from weathergen.utils.utils import get_dtype
 
+import logging
+logger = logging.getLogger(__name__)
 
 class SelfAttentionBlock(nn.Module):
     """
@@ -43,14 +45,32 @@ def __init__(self, dim, dim_aux, with_adanorm, num_heads, dropout_rate, **kwargs
         self.mhsa_block = lambda x, _, **kwargs: self.mhsa(self.ln_sa(x), **kwargs) + x
 
         approx_gelu = lambda: nn.GELU(approximate="tanh")
-        self.mlp = MLP(
-            dim_in=dim,
-            dim_out=dim,
-            hidden_factor=4,
-            dropout_rate=0.1,
-            nonlin=approx_gelu,
-            with_residual=False,
-        )
+        use_moe_ffn = (kwargs.get("ffn_mlp_type", "dense") == "moe")
+        ffn_hidden_factor = kwargs.get("ffn_hidden_factor", 4)
+        moe_kwargs = kwargs.get("moe_kwargs", {})  # e.g. num_experts, top_k, router_noisy_std
+
+        if use_moe_ffn:
+            self.mlp = MoEMLP(
+                dim_in=dim,
+                dim_out=dim,
+                hidden_factor=ffn_hidden_factor,
+                dropout_rate=0.1,
+                nonlin=nn.GELU,  # internal block constructs nonlin()
+                with_residual=False,
+                norm_type=kwargs["attention_kwargs"]["norm_type"],
+                dim_aux=(dim_aux if self.with_adanorm else None),
+                norm_eps=kwargs["attention_kwargs"]["norm_eps"],
+                **moe_kwargs,  # <- e.g. num_experts=8, top_k=2, router_noisy_std=0.01
+            )
+        else:
+            self.mlp = MLP(
+                dim_in=dim,
+                dim_out=dim,
+                hidden_factor=4,
+                dropout_rate=0.1,
+                nonlin=approx_gelu,
+                with_residual=False,
+            )
         if self.with_adanorm:
             self.mlp_fn = lambda x, **kwargs: self.mlp(x)
             self.mlp_block = AdaLayerNormLayer(dim, dim_aux, self.mlp_fn, dropout_rate)
@@ -104,7 +124,7 @@ def __init__(
 
         self.with_adanorm = with_adanorm
         self.with_self_attn = with_self_attn
-        self.with_mlp = with_self_attn
+        self.with_mlp = with_mlp
 
         if with_self_attn:
             self.mhsa = MultiSelfAttentionHeadVarlen(
@@ -136,18 +156,37 @@
 
         if self.with_mlp:
             approx_gelu = lambda: nn.GELU(approximate="tanh")
-            self.mlp = MLP(
-                dim_in=dim_q,
-                dim_out=dim_q,
-                hidden_factor=4,
-                nonlin=approx_gelu,
-                with_residual=False,
-            )
+
+            use_moe_ffn = (kwargs.get("ffn_mlp_type", "dense") == "moe")
+            ffn_hidden_factor = kwargs.get("ffn_hidden_factor", 4)
+            moe_kwargs = kwargs.get("moe_kwargs", {})
+
+            if use_moe_ffn:
+                self.mlp = MoEMLP(
+                    dim_in=dim_q,
+                    dim_out=dim_q,
+                    hidden_factor=ffn_hidden_factor,
+                    dropout_rate=0.1,
+                    nonlin=nn.GELU,  # internal block constructs nonlin()
+                    with_residual=False,
+                    norm_type=kwargs["attention_kwargs"]["norm_type"],
+                    dim_aux=(dim_aux if self.with_adanorm else None),
+                    norm_eps=kwargs["attention_kwargs"]["norm_eps"],
+                    **moe_kwargs,  # <- e.g. num_experts=8, top_k=2, router_noisy_std=0.01
+                )
+            else:
+                self.mlp = MLP(
+                    dim_in=dim_q,
+                    dim_out=dim_q,
+                    hidden_factor=4,
+                    nonlin=approx_gelu,
+                    with_residual=False,
+                )
             if self.with_adanorm:
                 self.mlp_fn = lambda x, **kwargs: self.mlp(x)
                 self.mlp_block = AdaLayerNormLayer(dim_q, dim_aux, self.mlp_fn, dropout_rate)
             else:
-                self.ln_mlp = nn.LayerNorm(dim_q, eps=kwargs["attention_kwargs"]["norm_eps"])
+                self.ln_mlp = nn.LayerNorm(eps=kwargs["attention_kwargs"]["norm_eps"])
                 self.mlp_block = lambda x, _, **kwargs: self.mlp(self.ln_mlp(x)) + x
         else:
             self.mlp_block = lambda x, _, **kwargs: x
@@ -191,6 +230,7 @@ def __init__(
         tr_mlp_hidden_factor,
         tro_type,
         mlp_norm_eps=1e-6,
+        **kwargs,
     ):
         super().__init__()
 
@@ -237,19 +277,46 @@
             )
         )
         # MLP Block
-        self.block.append(
-            MLP(
-                dim_in,
-                dim_out,
-                with_residual=True,
-                hidden_factor=self.tr_mlp_hidden_factor,
-                dropout_rate=0.1,  # Assuming dropout_rate is 0.1
-                norm_type=self.cf.norm_type,
-                dim_aux=(dim_aux if self.cf.pred_mlp_adaln else None),
-                norm_eps=self.cf.mlp_norm_eps,
-            )
+        # Add MoE option
+        use_moe = getattr(self.cf, "decoder_mlp_type", "dense") == "moe"
+        logger.info(
+            "[MoE] Decoder head: type=%s%s",
+            "moe" if use_moe else "dense",
+            ("" if not use_moe else
+             f" (experts={getattr(self.cf,'moe_num_experts',None)}, top_k={getattr(self.cf,'moe_top_k',None)})"),
         )
 
+        if use_moe:
+            self.block.append(
+                MoEMLP(
+                    dim_in,
+                    dim_out,
+                    hidden_factor=self.tr_mlp_hidden_factor,
+                    dropout_rate=0.1,
+                    with_residual=True,  # mirror dense
+                    norm_type=self.cf.norm_type,
+                    dim_aux=(dim_aux if self.cf.pred_mlp_adaln else None),
+                    norm_eps=self.cf.mlp_norm_eps,
+                    num_experts=getattr(self.cf, "moe_num_experts", 8),
+                    top_k=getattr(self.cf, "moe_top_k", 2),
+                    router_noisy_std=getattr(self.cf, "moe_router_noisy_std", 0.0),
+                    use_checkpoint=getattr(self.cf, "moe_use_checkpoint", False),
+                )
+            )
+        else:
+            self.block.append(
+                MLP(
+                    dim_in,
+                    dim_out,
+                    with_residual=True,
+                    hidden_factor=self.tr_mlp_hidden_factor,
+                    dropout_rate=0.1,  # Assuming dropout_rate is 0.1
+                    norm_type=self.cf.norm_type,
+                    dim_aux=(dim_aux if self.cf.pred_mlp_adaln else None),
+                    norm_eps=self.cf.mlp_norm_eps,
+                )
+            )
+
     def forward(self, latent, output, coords, latent_lens, output_lens):
         for layer in self.block:
             if isinstance(layer, MultiCrossAttentionHeadVarlen):
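
The blocks above now switch between the dense MLP and MoEMLP based on ffn_mlp_type, but MoEMLP itself lives in weathergen.model.layers and is not part of this diff. The following is a minimal sketch of a top-k routed feed-forward layer that illustrates what num_experts, top_k and router_noisy_std control; it is a generic illustration of the technique, not the repo's MoEMLP implementation:

import torch
from torch import nn


class TopKMoEMLP(nn.Module):
    # Illustrative top-k routed feed-forward layer (hypothetical, not the repo's MoEMLP).
    def __init__(self, dim_in, dim_out, hidden_factor=4, num_experts=2, top_k=1,
                 router_noisy_std=0.0, dropout_rate=0.1):
        super().__init__()
        dim_hidden = int(dim_in * hidden_factor)
        self.experts = nn.ModuleList([
            nn.Sequential(
                nn.Linear(dim_in, dim_hidden),
                nn.GELU(approximate="tanh"),
                nn.Dropout(dropout_rate),
                nn.Linear(dim_hidden, dim_out),
            )
            for _ in range(num_experts)
        ])
        self.router = nn.Linear(dim_in, num_experts)
        self.top_k = top_k
        self.router_noisy_std = router_noisy_std
        self.dim_out = dim_out

    def forward(self, x):  # x: (num_tokens, dim_in)
        logits = self.router(x)
        if self.training and self.router_noisy_std > 0:
            # Noisy gating: perturb router logits during training only.
            logits = logits + torch.randn_like(logits) * self.router_noisy_std
        weights = logits.softmax(dim=-1)
        top_w, top_idx = weights.topk(self.top_k, dim=-1)   # (num_tokens, top_k)
        top_w = top_w / top_w.sum(dim=-1, keepdim=True)     # renormalize over selected experts
        out = torch.zeros(x.shape[0], self.dim_out, device=x.device, dtype=x.dtype)
        for e, expert in enumerate(self.experts):
            for k in range(self.top_k):
                mask = top_idx[:, k] == e
                if mask.any():
                    out[mask] += top_w[mask, k:k + 1] * expert(x[mask])
        # A load-balancing ("router") loss would typically be derived from
        # `weights` here and accumulated for the trainer (see moe_lambda above).
        return out

With top_k=1 each token is processed by a single expert, so per-token compute stays close to a dense MLP with the reduced hidden_factor used in the config above.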

src/weathergen/model/engines.py

Lines changed: 82 additions & 22 deletions
@@ -24,10 +24,12 @@
     StreamEmbedLinear,
     StreamEmbedTransformer,
 )
-from weathergen.model.layers import MLP
+from weathergen.model.layers import MLP, MoEMLP
 from weathergen.model.utils import ActivationFactory
 from weathergen.utils.utils import get_dtype
 
+import logging
+logger = logging.getLogger(__name__)
 
 class EmbeddingEngine:
     name: "EmbeddingEngine"
@@ -249,17 +251,50 @@ def create(self) -> torch.nn.ModuleList:
                 )
             )
             # MLP block
-            self.ae_global_blocks.append(
-                MLP(
-                    self.cf.ae_global_dim_embed,
-                    self.cf.ae_global_dim_embed,
-                    with_residual=True,
-                    dropout_rate=self.cf.ae_global_dropout_rate,
-                    hidden_factor=self.cf.ae_global_mlp_hidden_factor,
-                    norm_type=self.cf.norm_type,
-                    norm_eps=self.cf.mlp_norm_eps,
-                )
+            # Add MoE option
+            use_moe = getattr(self.cf, "ae_global_mlp_type", "dense") == "moe"
+            mlp_common_kwargs = dict(
+                dim_in=self.cf.ae_global_dim_embed,
+                dim_out=self.cf.ae_global_dim_embed,
+                with_residual=True,
+                dropout_rate=self.cf.ae_global_dropout_rate,
+                norm_type=self.cf.norm_type,
+                norm_eps=self.cf.mlp_norm_eps,
             )
+            if use_moe:
+                self.ae_global_blocks.append(
+                    MoEMLP(
+                        **mlp_common_kwargs,
+                        num_experts=getattr(self.cf, "ae_global_moe_num_experts", 2),
+                        top_k=getattr(self.cf, "ae_global_moe_top_k", 1),
+                        router_noisy_std=getattr(self.cf, "ae_global_moe_router_noisy_std", 0.0),
+                        hidden_factor=getattr(self.cf, "ae_global_moe_hidden_factor", 2),
+                    )
+                )
+            else:
+                self.ae_global_blocks.append(
+                    MLP(
+                        self.cf.ae_global_dim_embed,
+                        self.cf.ae_global_dim_embed,
+                        with_residual=True,
+                        dropout_rate=self.cf.ae_global_dropout_rate,
+                        hidden_factor=self.cf.ae_global_mlp_hidden_factor,
+                        norm_type=self.cf.norm_type,
+                        norm_eps=self.cf.mlp_norm_eps,
+                    )
+                )
+        # Count MoE blocks
+        num_moe = sum(1 for m in self.ae_global_blocks if isinstance(m, MoEMLP))
+        logger.info(
+            "[MoE] GlobalAssimilationEngine: %d MoEMLP blocks "
+            "(ae_global_mlp_type=%s, experts=%s, top_k=%s, hidden_factor=%s)",
+            num_moe,
+            getattr(self.cf, "ae_global_mlp_type", "dense"),
+            getattr(self.cf, "ae_global_moe_num_experts", None),
+            getattr(self.cf, "ae_global_moe_top_k", None),
+            getattr(self.cf, "ae_global_moe_hidden_factor", None),
+        )
+
         return self.ae_global_blocks
 
 
@@ -343,8 +378,8 @@ def create(self) -> torch.nn.ModuleList:
                 self.fe_blocks.append(
                     MoEMLP(
                         **mlp_common_kwargs,
-                        num_experts=getattr(self.cf, "fe_moe_num_experts", 8),
-                        top_k=getattr(self.cf, "fe_moe_top_k", 4),
+                        num_experts=getattr(self.cf, "fe_moe_num_experts", 2),
+                        top_k=getattr(self.cf, "fe_moe_top_k", 2),
                         router_noisy_std=getattr(self.cf, "fe_moe_router_noisy_std", 0.0),
                         hidden_factor=getattr(self.cf, "fe_moe_hidden_factor", 2),
                     )
@@ -362,15 +397,24 @@ def create(self) -> torch.nn.ModuleList:
                     )
                 )
         # ------------------------------------------------------------------
-        def init_weights_final(m):
-            if isinstance(m, torch.nn.Linear):
-                torch.nn.init.normal_(m.weight, mean=0, std=0.001)
-                if m.bias is not None:
-                    torch.nn.init.normal_(m.bias, mean=0, std=0.001)
-
-        for block in self.fe_blocks:
-            block.apply(init_weights_final)
-
+        # def init_weights_final(m):
+        #     if isinstance(m, torch.nn.Linear) and not getattr(m, "is_moe_router", False):
+        #         torch.nn.init.normal_(m.weight, mean=0, std=0.001)
+        #         if m.bias is not None:
+        #             torch.nn.init.normal_(m.bias, mean=0, std=0.001)
+
+        # for block in self.fe_blocks:
+        #     block.apply(init_weights_final)
+        num_moe = sum(1 for m in self.fe_blocks if isinstance(m, MoEMLP))
+        logger.info(
+            "[MoE] ForecastingEngine: %d MoEMLP blocks "
+            "(fe_mlp_type=%s, experts=%s, top_k=%s, hidden_factor=%s)",
+            num_moe,
+            getattr(self.cf, "fe_mlp_type", "dense"),
+            getattr(self.cf, "fe_moe_num_experts", None),
+            getattr(self.cf, "fe_moe_top_k", None),
+            getattr(self.cf, "fe_moe_hidden_factor", None),
+        )
         return self.fe_blocks
 
 
@@ -619,6 +663,14 @@ def __init__(
                     with_adanorm=False,
                     with_mlp=False,
                     attention_kwargs=attention_kwargs,
+                    ffn_mlp_type=getattr(self.cf, "decoder_ffn_mlp_type", "dense"),
+                    ffn_hidden_factor=getattr(self.cf, "decoder_ffn_hidden_factor", 4),
+                    moe_kwargs=dict(
+                        num_experts=getattr(self.cf, "decoder_moe_num_experts", 2),
+                        top_k=getattr(self.cf, "decoder_moe_top_k", 2),
+                        router_noisy_std=getattr(self.cf, "decoder_moe_router_noisy_std", 0.0),
+                        use_checkpoint=getattr(self.cf, "decoder_moe_use_checkpoint", False),
+                    )
                 )
             )
         elif self.cf.decoder_type == "AdaLayerNormConditioning":
@@ -674,6 +726,14 @@ def __init__(
                     tr_mlp_hidden_factor=tr_mlp_hidden_factor,
                     tro_type=tro_type,
                     mlp_norm_eps=self.cf.mlp_norm_eps,
+                    ffn_mlp_type=getattr(self.cf, "decoder_ffn_mlp_type", "dense"),
+                    ffn_hidden_factor=getattr(self.cf, "decoder_ffn_hidden_factor", 4),
+                    moe_kwargs=dict(
+                        num_experts=getattr(self.cf, "decoder_moe_num_experts", 2),
+                        top_k=getattr(self.cf, "decoder_moe_top_k", 2),
+                        router_noisy_std=getattr(self.cf, "decoder_moe_router_noisy_std", 0.0),
+                        use_checkpoint=getattr(self.cf, "decoder_moe_use_checkpoint", False),
+                    )
                 )
             )
         else: