 from typing import Optional, Tuple, Union
+import math
 import torch
 from torch import nn
 from torch.nn import CrossEntropyLoss
@@ -41,6 +42,9 @@ def __init__(self, config, is_cross_attention=False, layer_idx=None):
4142 f"`embed_dim` must be divisible by num_heads (got `embed_dim`: { self .embed_dim } and `num_heads`:"
4243 f" { self .num_heads } )."
4344 )
45+
46+ # muP
47+ self .use_mup = config .use_mup
4448
4549 self .scale_attn_weights = config .scale_attn_weights
4650 self .is_cross_attention = is_cross_attention
@@ -53,15 +57,41 @@ def __init__(self, config, is_cross_attention=False, layer_idx=None):
         if self.is_cross_attention:
             self.c_attn = Conv1D(2 * self.embed_dim, self.embed_dim)
             self.q_attn = Conv1D(self.embed_dim, self.embed_dim)
+
+            #muP -- q_attn
+            if self.use_mup:
+                self.q_attn.weight.data.normal_(mean=0.0, std=math.sqrt((config.initializer_range ** 2) / config.mup_width_scale))
+                self.q_attn.bias.data.zero_()
+
         else:
             self.c_attn = Conv1D(3 * self.embed_dim, self.embed_dim)
-        self.c_proj = Conv1D(self.embed_dim, self.embed_dim)
 
-        self.attn_dropout = nn.Dropout(config.attn_pdrop)
-        self.resid_dropout = nn.Dropout(config.resid_pdrop)
+            #muP -- c_attn specific, see https://github.com/cofe-ai/Mu-scaling/blob/a3d13c3eaa02e28cc6ef95116358b128424f9fe4/modeling/modeling_gpt2_mup.py#L487
+            if self.use_mup:
+                if config.query_zero_init:
+                    _, fanout = self.c_attn.weight.shape
+                    self.c_attn.weight.data[:, :fanout // 3] = 0
+                self.c_attn.bias.data.zero_()
+
+        self.c_proj = Conv1D(self.embed_dim, self.embed_dim)
+
+        #muP -- c_proj specific, see https://github.com/cofe-ai/Mu-scaling/blob/a3d13c3eaa02e28cc6ef95116358b128424f9fe4/modeling/modeling_gpt2_mup.py#L494
+        if self.use_mup:
+            depth_std = config.initializer_range / math.sqrt(2 * config.n_layer)
+            self.c_proj.weight.data.normal_(mean=0.0, std=math.sqrt(depth_std ** 2 / config.mup_width_scale))
+            self.c_proj.bias.data.zero_()
+
+        if self.use_mup:
+            self.attn_dropout = nn.Identity()
+            self.resid_dropout = nn.Identity()
+        else:
+            self.attn_dropout = nn.Dropout(config.attn_pdrop)
+            self.resid_dropout = nn.Dropout(config.resid_pdrop)
 
         self.pruned_heads = set()
 
         self.rot_emb = None
         if self.position_embedding == "rope":
             self.rot_emb = RotaryEmbedding(dim=self.head_dim)
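The init changes above assume several muP-specific config fields that the stock GPT-2 config does not have (use_mup, mup_width_scale, query_zero_init, and, further down in this diff, wte_zero_init). A minimal sketch of a config that carries them; the class name and default values are assumptions, only the field names come from this diff:

from transformers import GPT2Config

class APTMuPConfig(GPT2Config):
    """Hypothetical config extension with the muP fields referenced in this diff."""

    def __init__(self, use_mup=False, mup_width_scale=1.0, query_zero_init=True, wte_zero_init=False, **kwargs):
        super().__init__(**kwargs)
        self.use_mup = use_mup
        self.mup_width_scale = mup_width_scale  # width multiplier m = d_model / d_model_base
        self.query_zero_init = query_zero_init  # zero the query block of c_attn at init
        self.wte_zero_init = wte_zero_init      # zero-init the token embeddings instead of a normal init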
@@ -76,8 +106,13 @@ def __init__(self, config, is_cross_attention=False, layer_idx=None):
 
     def _attn(self, query, key, value, attention_mask=None, head_mask=None, alibi_bias=None):
         attn_weights = torch.matmul(query, key.transpose(-1, -2))
-
-        if self.scale_attn_weights:
+
+        #muP
+        if self.use_mup:
+            attn_weights = attn_weights / torch.full(
+                [], value.size(-1), dtype=attn_weights.dtype, device=attn_weights.device
+            )
+        elif self.scale_attn_weights:
             attn_weights = attn_weights / torch.full(
                 [], value.size(-1) ** 0.5, dtype=attn_weights.dtype, device=attn_weights.device
             )
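The branch above swaps the standard 1/sqrt(d_head) attention scaling for muP's 1/d_head scaling, which keeps the attention logits roughly O(1) as head width grows. A standalone sketch of the two variants (the tensors and shapes here are toy examples, not the module's actual inputs):

import torch

q = torch.randn(2, 4, 8, 16)   # (batch, heads, seq, head_dim)
k = torch.randn(2, 4, 8, 16)
scores = torch.matmul(q, k.transpose(-1, -2))

sdpa_scores = scores / (q.size(-1) ** 0.5)  # standard transformer scaling: 1/sqrt(d_head)
mup_scores = scores / q.size(-1)            # muP scaling: 1/d_head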
@@ -251,10 +286,31 @@ class APTMLP(nn.Module):
     def __init__(self, intermediate_size, config):
         super().__init__()
         embed_dim = config.hidden_size
+
+        #muP
+        use_mup = config.use_mup
+
         self.c_fc = Conv1D(intermediate_size, embed_dim)
+
+        #muP -- matrix-like
+        if use_mup:
+            self.c_fc.weight.data.normal_(mean=0.0, std=math.sqrt((config.initializer_range ** 2) / config.mup_width_scale))
+            self.c_fc.bias.data.zero_()
+
         self.c_proj = Conv1D(embed_dim, intermediate_size)
+
+        #muP -- matrix-like, c_proj-specific, see https://github.com/cofe-ai/Mu-scaling/blob/a3d13c3eaa02e28cc6ef95116358b128424f9fe4/modeling/modeling_gpt2_mup.py#L494
+        if use_mup:
+            depth_std = config.initializer_range / math.sqrt(2 * config.n_layer)
+            self.c_proj.weight.data.normal_(mean=0.0, std=math.sqrt(depth_std ** 2 / config.mup_width_scale))
+            self.c_proj.bias.data.zero_()
+
         self.act = ACT2FN[config.activation_function]
-        self.dropout = nn.Dropout(config.resid_pdrop)
+
+        if use_mup:
+            self.dropout = nn.Identity()
+        else:
+            self.dropout = nn.Dropout(config.resid_pdrop)
 
     def forward(self, hidden_states: Optional[Tuple[torch.FloatTensor]]) -> torch.FloatTensor:
         hidden_states = self.c_fc(hidden_states)
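The two muP init branches above reduce to simple closed forms: c_fc gets std = initializer_range / sqrt(mup_width_scale), and c_proj additionally divides the base std by sqrt(2 * n_layer). A quick numeric sketch with assumed values (the initializer range, width scale and depth below are illustrative, not taken from any real config):

import math

initializer_range = 0.02   # assumed base std
mup_width_scale = 4.0      # assumed width multiplier (e.g. 4x wider than the base model)
n_layer = 12               # assumed depth

c_fc_std = math.sqrt(initializer_range ** 2 / mup_width_scale)    # == 0.02 / sqrt(4) = 0.01
depth_std = initializer_range / math.sqrt(2 * n_layer)            # depth-scaled base std
c_proj_std = math.sqrt(depth_std ** 2 / mup_width_scale)          # depth- and width-scaled
print(c_fc_std, c_proj_std)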
@@ -270,14 +326,34 @@ def __init__(self, config, layer_idx=None):
         hidden_size = config.hidden_size
         inner_dim = config.n_inner if config.n_inner is not None else 4 * hidden_size
 
+        #muP
+        use_mup = config.use_mup
+
         self.ln_1 = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
+
+        #muP -- vector-like
+        if use_mup:
+            self.ln_1.weight.data.fill_(1.0)
+            self.ln_1.bias.data.zero_()
+
         self.attn = APTAttention(config, layer_idx=layer_idx)
         self.ln_2 = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
 
+        #muP -- vector-like
+        if use_mup:
+            self.ln_2.weight.data.fill_(1.0)
+            self.ln_2.bias.data.zero_()
+
         if config.add_cross_attention:
+            #muP TODO: check proper behavior in the cross-attention case
             self.crossattention = APTAttention(config, is_cross_attention=True, layer_idx=layer_idx)
             self.ln_cross_attn = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
 
+            #muP -- vector-like
+            if use_mup:
+                self.ln_cross_attn.weight.data.fill_(1.0)
+                self.ln_cross_attn.bias.data.zero_()
+
         self.mlp = APTMLP(inner_dim, config)
 
     def forward(
@@ -353,26 +429,60 @@ class APTModel(GPT2PreTrainedModel):
     def __init__(self, config):
         super().__init__(config)
 
-        self.embed_dim = config.hidden_size
+        self.embed_dim = config.hidden_size
+        use_mup = config.use_mup
 
         self.wte = nn.Embedding(config.vocab_size, self.embed_dim)
+
+        #muP -- vector-like, zeroed if wte_zero_init, otherwise kept at the base std regardless of width
+        if use_mup:
+            if config.wte_zero_init:
+                self.wte.weight.data.zero_()
+            else:
+                self.wte.weight.data.normal_(mean=0.0, std=config.initializer_range)
+
+            if self.wte.padding_idx is not None:
+                self.wte.weight.data[self.wte.padding_idx].zero_()
+
         self.position_embedding = config.position_embedding if hasattr(config, "position_embedding") else "learned"
 
         if self.position_embedding == "learned" or self.position_embedding == 'rope' or self.position_embedding == 'rerope' or self.position_embedding == "linear_rope_scaling" or self.position_embedding == "dynamic_rope_scaling":
             self.wpe = nn.Embedding(config.max_position_embeddings, self.embed_dim)
+
+            #muP -- vector-like, constant regardless of width
+            #muP TODO: check proper behavior in the rope & rerope cases
+            if use_mup:
+                self.wpe.weight.data.normal_(mean=0.0, std=config.initializer_range)
+
+                if self.wpe.padding_idx is not None:
+                    self.wpe.weight.data[self.wpe.padding_idx].zero_()
+
             self.alibi = None
         elif self.position_embedding == "alibi":
+            #muP TODO: check proper behavior in the alibi case
+
             maxpos = config.n_positions
             attn_heads = config.n_head
             alibi = create_alibi_tensor(attn_heads, maxpos)
             self.register_buffer('alibi', alibi)
         else:
             raise Exception(f'position_embedding {self.position_embedding} not supported. Please select one of learned, rope, rerope, linear rope, dynamic rope or alibi')
 
-        self.drop = nn.Dropout(config.embd_pdrop)
+        #muP
+        if use_mup:
+            self.drop = nn.Identity()
+        else:
+            self.drop = nn.Dropout(config.embd_pdrop)
+
         self.h = nn.ModuleList([APTBlock(config, layer_idx=i) for i in range(config.num_hidden_layers)])
+
         self.ln_f = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_epsilon)
 
+        #muP -- vector-like
+        if use_mup:
+            self.ln_f.weight.data.fill_(1.0)
+            self.ln_f.bias.data.zero_()
+
         # Model parallel
         self.model_parallel = False
         self.device_map = None
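Taken together, the __init__ changes in this file apply the usual muP recipe by hand: matrix-like weights get their init variance divided by the width multiplier (with an extra depth factor for the residual projections), while vector-like parameters (embeddings, LayerNorm weights and biases) keep a width-independent init. A condensed, illustrative restatement of that rule; this helper does not exist in the codebase, and the Conv1D layers above play the role of nn.Linear here:

import math
import torch.nn as nn

def mup_init_(module, base_std=0.02, width_scale=1.0):
    # matrix-like weights: shrink the init std by sqrt(width multiplier)
    if isinstance(module, nn.Linear):
        module.weight.data.normal_(mean=0.0, std=base_std / math.sqrt(width_scale))
        if module.bias is not None:
            module.bias.data.zero_()
    # vector-like parameters: width-independent init
    elif isinstance(module, nn.LayerNorm):
        module.weight.data.fill_(1.0)
        module.bias.data.zero_()
    elif isinstance(module, nn.Embedding):
        module.weight.data.normal_(mean=0.0, std=base_std)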
@@ -474,6 +584,7 @@ def forward(
 
         if self.position_embedding == "learned" or self.position_embedding == 'rope' or self.position_embedding == 'rerope' or self.position_embedding == "linear_rope_scaling" or self.position_embedding == "dynamic_rope_scaling":
             position_embeds = self.wpe(position_ids)
+            position_embeds.mul_(self.mup_rp_embedding_mult)
             hidden_states = inputs_embeds + position_embeds
         else:
             hidden_states = inputs_embeds
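Note that self.mup_rp_embedding_mult is not defined anywhere in this diff, so the added line above will raise an AttributeError unless the attribute is set in a part of __init__ that is not shown here. If it still needs to be added, one hedged option is sketched below; mup_embedding_mult is a hypothetical config field, and the 1.0 default keeps the in-place mul_ a no-op when muP is off:

# sketch for APTModel.__init__, not part of this diff
self.mup_rp_embedding_mult = getattr(config, "mup_embedding_mult", 1.0) if config.use_mup else 1.0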
@@ -593,6 +704,10 @@ class APTLMHeadModel(GPT2PreTrainedModel):
     def __init__(self, config):
         super().__init__(config)
         self.transformer = APTModel(config)
+
+        #muP TODO: check the proper treatment of the LM head; possibly nothing needs to change here (?)
+        #see https://github.com/cofe-ai/Mu-scaling/blob/a3d13c3eaa02e28cc6ef95116358b128424f9fe4/modeling/modeling_gpt2_mup.py#L472
+        #see also the caption of Table 8 in https://arxiv.org/pdf/2203.03466.pdf
         self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False)
 
         # Model parallel
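If output scaling does turn out to be needed, one common muP treatment (discussed around Table 8 of the paper cited in the comment above) is to leave the head's init alone and divide the logits by the width multiplier at forward time. A hedged sketch of that option; the class and the width_mult argument are illustrative and not part of this PR:

import torch.nn as nn

class MuPReadout(nn.Linear):
    """Illustrative muP readout: logits scaled by 1/width_mult (not part of this codebase)."""

    def __init__(self, in_features, out_features, width_mult=1.0, bias=False):
        super().__init__(in_features, out_features, bias=bias)
        self.width_mult = width_mult

    def forward(self, x):
        return super().forward(x) / self.width_mult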