Commit b588300

add initial implementation of mup [to be checked]
1 parent 0e48cff commit b588300

2 files changed: 144 additions, 8 deletions


protein_lm/modeling/models/apt/config.py
21 additions, 0 deletions
@@ -11,11 +11,32 @@ def __init__(
         position_embedding="learned",
         tokenizer=None,
         max_sequence_length = 1024,
+        use_mup = False,
+        query_zero_init = True,
+        n_layer = None,
+        initializer_range = 0.02,
+        mup_init_scale = 1.0,
+        mup_output_temp = 1.0,
+        mup_attn_mult = 1.0,
+        mup_embedding_mult = 1.0,
+        mup_rp_embedding_mult = 1.0,
+        mup_width_scale = 2.0,
         **kwargs
     ):
         super().__init__(**kwargs)
         self.nn_model_type = "APT"
         self.position_embedding = position_embedding
         self.tokenizer = tokenizer
         self.max_sequence_length = max_sequence_length
+
+        self.use_mup = use_mup
+        self.query_zero_init = query_zero_init
+        self.n_layer = n_layer
+        self.initializer_range = initializer_range
+        self.mup_init_scale = mup_init_scale
+        self.mup_output_temp = mup_output_temp
+        self.mup_attn_mult = mup_attn_mult
+        self.mup_embedding_mult = mup_embedding_mult
+        self.mup_rp_embedding_mult = mup_rp_embedding_mult
+        self.mup_width_scale = mup_width_scale
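For reference, a minimal usage sketch of the new muP fields, assuming the class defined in config.py is named APTConfig (the class name is not shown in this diff) and ignoring any other arguments the constructor may require:

    from protein_lm.modeling.models.apt.config import APTConfig

    # Hypothetical example: enable muP for a model presumably mup_width_scale times
    # wider than the base (proxy) model. n_layer must be set explicitly because the
    # depth-dependent init divides by sqrt(2 * n_layer).
    config = APTConfig(
        position_embedding="learned",
        max_sequence_length=1024,
        use_mup=True,
        n_layer=12,             # assumption for illustration; the default is None
        initializer_range=0.02,
        mup_width_scale=2.0,    # presumably the width multiplier relative to the base model
    )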

protein_lm/modeling/models/apt/model_pytorch.py
123 additions, 8 deletions
@@ -1,4 +1,5 @@
 from typing import Optional, Tuple, Union
+import math
 import torch
 from torch import nn
 from torch.nn import CrossEntropyLoss
@@ -41,6 +42,9 @@ def __init__(self, config, is_cross_attention=False, layer_idx=None):
                 f"`embed_dim` must be divisible by num_heads (got `embed_dim`: {self.embed_dim} and `num_heads`:"
                 f" {self.num_heads})."
             )
+
+        # muP
+        self.use_mup = config.use_mup
 
         self.scale_attn_weights = config.scale_attn_weights
         self.is_cross_attention = is_cross_attention
@@ -53,15 +57,41 @@ def __init__(self, config, is_cross_attention=False, layer_idx=None):
         if self.is_cross_attention:
             self.c_attn = Conv1D(2 * self.embed_dim, self.embed_dim)
             self.q_attn = Conv1D(self.embed_dim, self.embed_dim)
+
+            #muP -- q_attn
+            if self.use_mup:
+                self.q_attn.weight.data.normal_(mean=0.0, std=math.sqrt((config.initializer_range ** 2) / config.mup_width_scale))
+                self.q_attn.bias.data.zero_()
+
         else:
             self.c_attn = Conv1D(3 * self.embed_dim, self.embed_dim)
-        self.c_proj = Conv1D(self.embed_dim, self.embed_dim)
 
-        self.attn_dropout = nn.Dropout(config.attn_pdrop)
-        self.resid_dropout = nn.Dropout(config.resid_pdrop)
+        #muP -- c_attn specific, see https://github.com/cofe-ai/Mu-scaling/blob/a3d13c3eaa02e28cc6ef95116358b128424f9fe4/modeling/modeling_gpt2_mup.py#L487
+        if self.use_mup:
+            if config.query_zero_init:
+                _, fanout = self.c_attn.weight.shape
+                self.c_attn.weight.data[:, :fanout//3] = 0
+            self.c_attn.bias.data.zero_()
+
+        self.c_proj = Conv1D(self.embed_dim, self.embed_dim)
+
+        #muP -- c_proj specific, see https://github.com/cofe-ai/Mu-scaling/blob/a3d13c3eaa02e28cc6ef95116358b128424f9fe4/modeling/modeling_gpt2_mup.py#L494
+        if self.use_mup:
+            depth_std = config.initializer_range / math.sqrt(2 * config.n_layer)
+            self.c_proj.weight.data.normal_(mean=0.0, std=math.sqrt(depth_std ** 2 / config.mup_width_scale))
+            self.c_proj.bias.data.zero_()
+
+        if self.use_mup:
+            self.attn_dropout = nn.Identity()
+            self.resid_dropout = nn.Identity()
+        else:
+            self.attn_dropout = nn.Dropout(config.attn_pdrop)
+            self.resid_dropout = nn.Dropout(config.resid_pdrop)
 
         self.pruned_heads = set()
 
         self.rot_emb=None
         if self.position_embedding == "rope":
             self.rot_emb=RotaryEmbedding(dim=self.head_dim)
@@ -76,8 +106,13 @@ def __init__(self, config, is_cross_attention=False, layer_idx=None):
 
     def _attn(self, query, key, value, attention_mask=None, head_mask=None,alibi_bias=None):
         attn_weights = torch.matmul(query, key.transpose(-1, -2))
-
-        if self.scale_attn_weights:
+
+        #muP
+        if self.use_mup:
+            attn_weights = attn_weights / torch.full(
+                [], value.size(-1), dtype=attn_weights.dtype, device=attn_weights.device
+            )
+        elif self.scale_attn_weights:
             attn_weights = attn_weights / torch.full(
                 [], value.size(-1) ** 0.5, dtype=attn_weights.dtype, device=attn_weights.device
             )
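A note on the branch added above: with use_mup the attention logits are divided by the full head dimension value.size(-1) rather than its square root, which is the 1/d attention scaling that muP prescribes in place of the usual 1/sqrt(d). A self-contained sketch of the two scalings (shapes and values are illustrative only, not taken from this repository):

    import torch

    d = 64                                   # head dimension (illustrative)
    q = torch.randn(2, 4, 10, d)             # (batch, heads, seq_len, head_dim)
    k = torch.randn(2, 4, 10, d)

    scores = torch.matmul(q, k.transpose(-1, -2))
    standard = scores / d ** 0.5             # default GPT-2 scaling, 1/sqrt(d)
    mup = scores / d                         # muP scaling, 1/d, as in the diff above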
@@ -251,10 +286,31 @@ class APTMLP(nn.Module):
     def __init__(self, intermediate_size, config):
         super().__init__()
         embed_dim = config.hidden_size
+
+        #muP
+        use_mup = config.use_mup
+
         self.c_fc = Conv1D(intermediate_size, embed_dim)
+
+        #muP -- matrix-like
+        if use_mup:
+            self.c_fc.weight.data.normal_(mean=0.0, std=math.sqrt((config.initializer_range ** 2) / config.mup_width_scale))
+            self.c_fc.bias.data.zero_()
+
         self.c_proj = Conv1D(embed_dim, intermediate_size)
+
+        #muP -- matrix-like, c_proj-specific, see https://github.com/cofe-ai/Mu-scaling/blob/a3d13c3eaa02e28cc6ef95116358b128424f9fe4/modeling/modeling_gpt2_mup.py#L494
+        if use_mup:
+            depth_std = config.initializer_range / math.sqrt(2 * config.n_layer)
+            self.c_proj.weight.data.normal_(mean=0.0, std=math.sqrt(depth_std ** 2 / config.mup_width_scale))
+            self.c_proj.bias.data.zero_()
+
         self.act = ACT2FN[config.activation_function]
-        self.dropout = nn.Dropout(config.resid_pdrop)
+
+        if use_mup:
+            self.dropout = nn.Identity()
+        else:
+            self.dropout = nn.Dropout(config.resid_pdrop)
 
     def forward(self, hidden_states: Optional[Tuple[torch.FloatTensor]]) -> torch.FloatTensor:
         hidden_states = self.c_fc(hidden_states)
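The matrix-like initializations above all follow std = sqrt(initializer_range**2 / mup_width_scale), with c_proj additionally shrunk by the depth factor 1/sqrt(2 * n_layer). A small sketch evaluating these with the config defaults from the first file (n_layer is assumed to be 12 here, since its default is None):

    import math

    initializer_range = 0.02
    mup_width_scale = 2.0
    n_layer = 12  # assumption for illustration

    fc_std = math.sqrt(initializer_range ** 2 / mup_width_scale)   # ~0.0141
    depth_std = initializer_range / math.sqrt(2 * n_layer)         # ~0.0041
    proj_std = math.sqrt(depth_std ** 2 / mup_width_scale)         # ~0.0029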
@@ -270,14 +326,34 @@ def __init__(self, config, layer_idx=None):
         hidden_size = config.hidden_size
         inner_dim = config.n_inner if config.n_inner is not None else 4 * hidden_size
 
+        #muP
+        use_mup = config.use_mup
+
         self.ln_1 = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
+
+        #muP -- vector-like
+        if use_mup:
+            self.ln_1.weight.data.fill_(1.0)
+            self.ln_1.bias.data.zero_()
+
         self.attn = APTAttention(config, layer_idx=layer_idx)
         self.ln_2 = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
 
+        #muP -- vector-like
+        if use_mup:
+            self.ln_2.weight.data.fill_(1.0)
+            self.ln_2.bias.data.zero_()
+
         if config.add_cross_attention:
+            #muP TO DO: check proper behavior in case of cross-attention
             self.crossattention = APTAttention(config, is_cross_attention=True, layer_idx=layer_idx)
             self.ln_cross_attn = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
 
+            #muP -- vector-like
+            if use_mup:
+                self.ln_cross_attn.weight.data.fill_(1.0)
+                self.ln_cross_attn.bias.data.zero_()
+
         self.mlp = APTMLP(inner_dim, config)
 
     def forward(
@@ -353,26 +429,60 @@ class APTModel(GPT2PreTrainedModel):
     def __init__(self, config):
         super().__init__(config)
 
         self.embed_dim = config.hidden_size
+        use_mup = config.use_mup
 
         self.wte = nn.Embedding(config.vocab_size, self.embed_dim)
+
+        #muP -- vector-like, zero if zero init, otherwise maintained regardless of width
+        if use_mup:
+            if config.wte_zero_init:
+                self.wte.weight.data.zero_()
+            else:
+                self.wte.weight.data.normal_(mean=0.0, std=config.initializer_range)
+
+            if self.wte.padding_idx is not None:
+                self.wte.weight.data[self.wte.padding_idx].zero_()
+
         self.position_embedding = config.position_embedding if hasattr(config, "position_embedding") else "learned"
 
         if self.position_embedding=="learned" or self.position_embedding == 'rope' or self.position_embedding == 'rerope' or self.position_embedding=="linear_rope_scaling" or self.position_embedding =="dynamic_rope_scaling":
             self.wpe = nn.Embedding(config.max_position_embeddings, self.embed_dim)
+
+            #muP -- vector-like, constant regardless of width
+            #muP TO DO: check proper behavior in rope & rerope case
+            if use_mup:
+                self.wpe.weight.data.normal_(mean=0.0, std=config.initializer_range)
+
+                if self.wpe.padding_idx is not None:
+                    self.wpe.weight.data[self.wpe.padding_idx].zero_()
+
             self.alibi = None
         elif self.position_embedding=="alibi":
+            #muP TO DO: check proper behavior in alibi case
+
             maxpos = config.n_positions
             attn_heads = config.n_head
             alibi = create_alibi_tensor(attn_heads,maxpos)
             self.register_buffer('alibi',alibi)
         else:
             raise Exception(f'position_embedding {self.position_embedding} not supported. Please select one of learned, rope, rerope, linear rope, dynamic rope or alibi')
 
-        self.drop = nn.Dropout(config.embd_pdrop)
+        #muP
+        if use_mup:
+            self.drop = nn.Identity()
+        else:
+            self.drop = nn.Dropout(config.embd_pdrop)
+
         self.h = nn.ModuleList([APTBlock(config, layer_idx=i) for i in range(config.num_hidden_layers)])
+
         self.ln_f = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_epsilon)
 
+        #muP -- vector-like
+        if use_mup:
+            self.ln_f.weight.data.fill_(1.0)
+            self.ln_f.bias.data.zero_()
+
         # Model parallel
         self.model_parallel = False
         self.device_map = None
@@ -474,6 +584,7 @@ def forward(
 
         if self.position_embedding=="learned" or self.position_embedding == 'rope' or self.position_embedding == 'rerope' or self.position_embedding=="linear_rope_scaling" or self.position_embedding =="dynamic_rope_scaling":
             position_embeds = self.wpe(position_ids)
+            position_embeds.mul_(self.config.mup_rp_embedding_mult)
             hidden_states = inputs_embeds + position_embeds
         else:
             hidden_states = inputs_embeds
@@ -593,6 +704,10 @@ class APTLMHeadModel(GPT2PreTrainedModel):
     def __init__(self, config):
         super().__init__(config)
         self.transformer = APTModel(config)
+
+        #muP TO DO: check proper behavior for LM head, nothing should be done (?)
+        #see https://github.com/cofe-ai/Mu-scaling/blob/a3d13c3eaa02e28cc6ef95116358b128424f9fe4/modeling/modeling_gpt2_mup.py#L472
+        #see also table 8's caption in https://arxiv.org/pdf/2203.03466.pdf
         self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False)
 
         # Model parallel
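On the LM-head TODO above: under muP the readout is typically left as-is at initialization and its logits are instead tempered by a width-dependent factor, which may be why this diff does not touch lm_head. A purely hypothetical sketch of how the currently unused mup_output_temp and mup_width_scale fields could later be applied inside APTLMHeadModel.forward; this is not something the commit implements:

    # Hypothetical, not part of this commit:
    lm_logits = self.lm_head(hidden_states)
    if self.config.use_mup:
        # assumed convention: divide by the width ratio, scaled by the output temperature
        lm_logits = lm_logits * self.config.mup_output_temp / self.config.mup_width_scale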
