 from transformers.pytorch_utils import Conv1D
 from transformers.activations import ACT2FN
 from transformers.utils import logging
+
 from protein_lm.modeling.utils.rotary_embedding import RotaryEmbedding
 from protein_lm.modeling.utils.rerope_embedding import RectifiedRotaryEmbedding
 from protein_lm.modeling.utils.alibi_embedding import create_alibi_tensor
@@ -34,6 +35,7 @@ def __init__(self, config, is_cross_attention=False, layer_idx=None):
         self.max_sequence_length = config.max_sequence_length
         self.embed_dim = config.hidden_size
         self.num_heads = config.num_attention_heads
+        self.attn_type = config.attn_type
         self.head_dim = self.embed_dim // self.num_heads
         self.split_size = self.embed_dim
         if self.head_dim * self.num_heads != self.embed_dim:
@@ -48,7 +50,15 @@ def __init__(self, config, is_cross_attention=False, layer_idx=None):
         # Layer-wise attention scaling, reordering, and upcasting
         self.scale_attn_by_inverse_layer_idx = config.scale_attn_by_inverse_layer_idx
         self.layer_idx = layer_idx
-        self.reorder_and_upcast_attn = config.reorder_and_upcast_attn
+
+        # config.attn_type now selects the attention variant; default every flag to
+        # False so the attribute checked in forward() always exists.
+        self.gqa_attn = self.attn_type == "gqa"
+        self.reorder_and_upcast_attn = self.attn_type == "reorder_and_upcast_attn"
+        self.standard_attn = self.attn_type == "standard"
+        if not (self.gqa_attn or self.reorder_and_upcast_attn or self.standard_attn):
+            raise ValueError(f"Unknown attn_type: {self.attn_type!r}")
+
 
         if self.is_cross_attention:
             self.c_attn = Conv1D(2 * self.embed_dim, self.embed_dim)
@@ -116,6 +126,87 @@ def _attn(self, query, key, value, attention_mask=None, head_mask=None,alibi_bia
 
         return attn_output, attn_weights
 
+    def _gqa_attn(self, query, key, value, attention_mask=None,
+                  alibi_bias=None, dropout=0.0):
+        """Grouped-query attention.
+
+        Expected shapes: (batch_size, num_heads, query_len, query_dim),
+        as in _upcast_and_reordered_attn.
+        """
+        # Check for potential issues before moving on.
+        if not query.ndim == key.ndim == value.ndim == 4:
+            raise ValueError(f"Expected query, key, and value to be 4-dimensional, but got shapes "
+                             f"{query.shape}, {key.shape}, and {value.shape}.")
+
+        batch_size, num_heads, query_len, query_dim = query.shape
+
+        # Scale the query by 1/sqrt(head_dim) before computing the dot products.
+        scale_factor = 1.0
+        if self.scale_attn_weights:
+            scale_factor /= float(value.size(-1)) ** 0.5
+        query = query * scale_factor
+
+        # Determine the number of groups of query heads per key/value head.
+        # For example, with 4 query heads and 2 key heads there are 2 groups, and
+        # the query tensor is reshaped to
+        #   (batch_size, num_groups, num_kv_heads, query_len, query_dim)
+        # so that the key tensor broadcasts across the group dimension.
+        # Grouped attention weights:
+        #   (batch_size, num_groups, num_kv_heads, query_len, key_len);
+        # summing over the group dimension gives attention weights of shape
+        #   (batch_size, num_kv_heads, query_len, key_len).
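+        # Illustrative (hypothetical) sizes: query (2, 4, 128, 64) with key/value
+        # (2, 2, 128, 64) gives n_groups = 2, query_grouped (2, 2, 2, 128, 64),
+        # grouped weights (2, 2, 2, 128, 128) and summed weights (2, 2, 128, 128).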
+        n_groups = query.size(1) // key.size(1)
+
+        if n_groups > 1:
+            query_shape = query.shape
+            grouped_shape = (query_shape[0], n_groups, query_shape[1] // n_groups, query_shape[2], query_shape[3])
+            query_grouped = query.reshape(grouped_shape)
+            attn_weights_grouped = torch.matmul(query_grouped, key.transpose(-2, -1))
+            attn_weights = attn_weights_grouped.sum(dim=1)
+        else:
+            # With a single group this reduces to the normal attention computation.
+            attn_weights = torch.matmul(query, key.transpose(-2, -1))
+
+        if self.scale_attn_by_inverse_layer_idx:
+            attn_weights = attn_weights / float(self.layer_idx + 1)
+
+        if attention_mask is not None:
+            # Apply the attention mask; expected shape (batch_size, query_len, key_len).
+            attn_weights += attention_mask.unsqueeze(1)  # unsqueeze to add the head dimension
+
+        # Causal masking ensures that the attention mechanism doesn't attend to
+        # "future" tokens; adapted to work with groups and to match vanilla attention.
+        if not self.is_cross_attention:
+            query_length, key_length = query.size(-2), key.size(-2)
+            causal_mask = self.bias[:, :, key_length - query_length : key_length, :key_length]
+            mask_value = torch.finfo(attn_weights.dtype).min
+            mask_value = torch.full([], mask_value, dtype=attn_weights.dtype).to(attn_weights.device)
+            attn_weights = torch.where(causal_mask, attn_weights.to(attn_weights.dtype), mask_value)
+
+        # Add the ALiBi bias to the attention logits before the softmax.
+        if alibi_bias is not None:
+            attn_weights = attn_weights + alibi_bias[:, :, :attn_weights.size(-1)]
+
+        # Softmax normalization to get the attention scores.
+        attn_weights = nn.functional.softmax(attn_weights, dim=-1)
+
+        # Downcast (if necessary) and apply attention dropout.
+        attn_weights = attn_weights.type(value.dtype)
+        attn_weights = self.attn_dropout(attn_weights)
+
+        # Compute the output by multiplying the attention scores with the value tensor.
+        attn_output = torch.matmul(attn_weights, value)
+
+        return attn_output, attn_weights
+
     def _upcast_and_reordered_attn(self, query, key, value, attention_mask=None, head_mask=None,alibi_bias=None):
         # Use `torch.baddbmm` (a bit more efficient w/ alpha param for scaling -- from Megatron-LM)
         bsz, num_heads, q_seq_len, dk = query.size()
@@ -233,9 +324,10 @@ def forward(
 
         if self.reorder_and_upcast_attn:
             attn_output, attn_weights = self._upcast_and_reordered_attn(query, key, value, attention_mask, head_mask,alibi_bias=alibi_bias)
-        else:
+        elif self.standard_attn:
             attn_output, attn_weights = self._attn(query, key, value, attention_mask, head_mask,alibi_bias=alibi_bias)
-
+        elif self.gqa_attn:
+            attn_output, attn_weights = self._gqa_attn(query, key, value, attention_mask,alibi_bias=alibi_bias)
         attn_output = self._merge_heads(attn_output, self.num_heads, self.head_dim)
         attn_output = self.c_proj(attn_output)
         attn_output = self.resid_dropout(attn_output)
@@ -244,7 +336,7 @@ def forward(
         if output_attentions:
             outputs += (attn_weights,)
 
-        return outputs  # a, present, (attentions)
+        return outputs  # a, present, (attentions)
 
 
 class APTMLP(nn.Module):
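As a quick sanity check on the grouped matmul in the new `_gqa_attn`, here is a minimal standalone sketch (toy, hypothetical tensor sizes; not part of this commit) showing how the reshaped query heads broadcast against the shared key heads and what shape the summed attention weights take:

```python
import torch

# Toy, hypothetical sizes: 4 query heads sharing 2 key/value heads.
batch, n_q_heads, n_kv_heads, seq_len, head_dim = 2, 4, 2, 5, 8
query = torch.randn(batch, n_q_heads, seq_len, head_dim)
key = torch.randn(batch, n_kv_heads, seq_len, head_dim)

n_groups = n_q_heads // n_kv_heads  # 2 groups of query heads per key/value head
query_grouped = query.reshape(batch, n_groups, n_kv_heads, seq_len, head_dim)

# key.transpose(-2, -1) has shape (batch, n_kv_heads, head_dim, seq_len) and
# broadcasts across the group dimension of query_grouped.
attn_weights_grouped = torch.matmul(query_grouped, key.transpose(-2, -1))
assert attn_weights_grouped.shape == (batch, n_groups, n_kv_heads, seq_len, seq_len)

# Summing over the group dimension, as _gqa_attn does, leaves one set of
# attention logits per key/value head.
attn_weights = attn_weights_grouped.sum(dim=1)
assert attn_weights.shape == (batch, n_kv_heads, seq_len, seq_len)
```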