Description
I find our handling of the initial positional embeddings applied before the APT blocks (self.wpe, or its absence, in the definition of APTModel) a bit weird.
They are initialized here:
protein-lm-scaling/protein_lm/modeling/models/apt/model_pytorch.py
Lines 453 to 460 in 86ca8f5
if self.position_embedding=="learned" or self.position_embedding == 'rope' or self.position_embedding == 'rerope' or self.position_embedding=="linear_rope_scaling" or self.position_embedding =="dynamic_rope_scaling":
    self.wpe = nn.Embedding(config.max_position_embeddings, self.embed_dim)
    self.alibi = None
elif self.position_embedding=="alibi":
    maxpos = config.n_positions
    attn_heads = config.n_head
    alibi = create_alibi_tensor(attn_heads,maxpos)
    self.register_buffer('alibi',alibi)
and used here:
protein-lm-scaling/protein_lm/modeling/models/apt/model_pytorch.py
Lines 567 to 571 in 86ca8f5
if self.position_embedding=="learned" or self.position_embedding == 'rope' or self.position_embedding == 'rerope' or self.position_embedding=="linear_rope_scaling" or self.position_embedding =="dynamic_rope_scaling":
    position_embeds = self.wpe(position_ids)
    hidden_states = inputs_embeds + position_embeds
else:
    hidden_states = inputs_embeds
It seems that for the learned embedding as well as for the rope variants, a learned positional embedding is added before the hidden states are passed on to the blocks; only for alibi is this initial embedding omitted. (The APT blocks still apply rope/alibi as specified, so omitting this initial positional embedding does not mean those position encodings are never used.)
This seems weird to me because I don't see why rope should be grouped with the learned embedding. It would make more sense to me for the rope variants to also omit the initial positional embedding (i.e., no self.wpe). I would also be okay with all of them having an initial positional embedding, but that doesn't seem to be the standard way language models are implemented, e.g. in llama.
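For concreteness, here is a rough, untested sketch of the change I have in mind. The attribute names and condition strings follow the current APTModel code quoted above; the ROPE_VARIANTS tuple is just an illustrative helper, not something that exists in the repo:

```python
# Sketch only: give the initial learned positional embedding to the "learned"
# setting alone, and let the rope variants (like alibi) rely entirely on the
# per-block position handling.
ROPE_VARIANTS = ("rope", "rerope", "linear_rope_scaling", "dynamic_rope_scaling")  # illustrative name

# in APTModel.__init__
if self.position_embedding == "learned":
    self.wpe = nn.Embedding(config.max_position_embeddings, self.embed_dim)
    self.alibi = None
elif self.position_embedding in ROPE_VARIANTS:
    # no initial positional embedding; rope is applied inside the APT blocks
    self.wpe = None
    self.alibi = None
elif self.position_embedding == "alibi":
    maxpos = config.n_positions
    attn_heads = config.n_head
    alibi = create_alibi_tensor(attn_heads, maxpos)
    self.register_buffer('alibi', alibi)

# in APTModel.forward
if self.position_embedding == "learned":
    position_embeds = self.wpe(position_ids)
    hidden_states = inputs_embeds + position_embeds
else:
    hidden_states = inputs_embeds
```

This keeps the alibi path as it is today and only changes which settings get self.wpe added up front.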
Tagging @talkhanz, who I think was the original author of this logic, and @jamaliki @jeffreyruffolo @NZ99 @pascalnotin for their thoughts.