Commit 0d4cba5

GQA Attention (#59)
1 parent 31dcd8f commit 0d4cba5

2 files changed: +185 -4 lines changed


protein_lm/modeling/models/apt/model_pytorch.py

Lines changed: 96 additions & 4 deletions
@@ -8,6 +8,7 @@
 from transformers.pytorch_utils import Conv1D
 from transformers.activations import ACT2FN
 from transformers.utils import logging
+
 from protein_lm.modeling.utils.rotary_embedding import RotaryEmbedding
 from protein_lm.modeling.utils.rerope_embedding import RectifiedRotaryEmbedding
 from protein_lm.modeling.utils.alibi_embedding import create_alibi_tensor
@@ -34,6 +35,7 @@ def __init__(self, config, is_cross_attention=False, layer_idx=None):
         self.max_sequence_length = config.max_sequence_length
         self.embed_dim = config.hidden_size
         self.num_heads = config.num_attention_heads
+        self.attn_type = config.attn_type
         self.head_dim = self.embed_dim // self.num_heads
         self.split_size = self.embed_dim
         if self.head_dim * self.num_heads != self.embed_dim:
@@ -48,7 +50,15 @@ def __init__(self, config, is_cross_attention=False, layer_idx=None):
         # Layer-wise attention scaling, reordering, and upcasting
         self.scale_attn_by_inverse_layer_idx = config.scale_attn_by_inverse_layer_idx
         self.layer_idx = layer_idx
-        self.reorder_and_upcast_attn = config.reorder_and_upcast_attn
+
+        # config.attn_type now selects the attention implementation
+        # (previously: self.reorder_and_upcast_attn = config.reorder_and_upcast_attn).
+        # Set every flag so forward() can check them unconditionally.
+        self.gqa_attn = self.attn_type == "gqa"
+        self.reorder_and_upcast_attn = self.attn_type == "reorder_and_upcast_attn"
+        self.standard_attn = self.attn_type == "standard"
+        if self.attn_type not in ("gqa", "reorder_and_upcast_attn", "standard"):
+            raise ValueError(f"Unknown attn_type: {self.attn_type}")
 
         if self.is_cross_attention:
             self.c_attn = Conv1D(2 * self.embed_dim, self.embed_dim)
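Since config.attn_type now drives which attention path forward() takes, here is a minimal illustration of the intended mapping. The config stub below is hypothetical (not the project's real config class); it only shows that exactly one dispatch flag ends up True for a given attn_type.

import types

# Hypothetical stand-in for the model config; only attn_type matters here.
cfg = types.SimpleNamespace(attn_type="gqa")

gqa_attn = cfg.attn_type == "gqa"
reorder_and_upcast_attn = cfg.attn_type == "reorder_and_upcast_attn"
standard_attn = cfg.attn_type == "standard"

# Exactly one flag is True, matching the elif chain in forward().
print(gqa_attn, reorder_and_upcast_attn, standard_attn)
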
@@ -116,6 +126,87 @@ def _attn(self, query, key, value, attention_mask=None, head_mask=None,alibi_bia
 
         return attn_output, attn_weights
 
+    def _gqa_attn(self, query, key, value, attention_mask=None,
+                  alibi_bias=None, dropout=0.0):
+        """Grouped Query Attention implementation."""
+
+        # Check for potential issues before moving on
+        if not query.ndim == key.ndim == value.ndim == 4:
+            raise ValueError(f"Expected query, key, and value to be 4-dimensional, but got shapes "
+                             f"{query.shape}, {key.shape}, and {value.shape}.")
+
+        # Expected shapes: (batch_size, num_heads, query_len, query_dim),
+        # as in _upcast_and_reordered_attn.
+        batch_size, num_heads, query_len, query_dim = query.shape
+
+        scale_factor = 1.0
+        if self.scale_attn_weights:
+            scale_factor /= float(value.size(-1)) ** 0.5
+        query = query * scale_factor
+
+        # Determine the number of groups. For example, with 4 query heads and
+        # 2 key heads there are 2 groups. With 2 groups of 2 heads each, the
+        # query tensor is reshaped to
+        # (batch_size, num_groups, heads_per_group, query_len, query_dim);
+        # the grouped attention weights have shape
+        # (batch_size, num_groups, heads_per_group, query_len, key_len), and the
+        # final attention weights have shape (batch_size, num_heads, query_len, key_len).
+        n_groups = query.size(1) // key.size(1)
+
+        if n_groups > 1:
+            query_shape = query.shape
+            grouped_shape = (query_shape[0], n_groups, query_shape[1] // n_groups, query_shape[2], query_shape[3])
+            query_grouped = query.reshape(grouped_shape)
+            attn_weights_grouped = torch.matmul(query_grouped, key.transpose(-2, -1))
+            attn_weights = attn_weights_grouped.sum(dim=1)
+        else:
+            # If the number of groups is 1, fall back to ordinary attention.
+            attn_weights = torch.matmul(query, key.transpose(-2, -1))
+
+        if self.scale_attn_by_inverse_layer_idx:
+            attn_weights = attn_weights / float(self.layer_idx + 1)
+
+        if attention_mask is not None:
+            # Apply the attention mask; input attention_mask shape is
+            # (batch_size, query_len, key_len).
+            attn_weights += attention_mask.unsqueeze(1)  # unsqueeze to add the head dimension
+
+        # Causal masking ensures that the attention mechanism does not attend to
+        # "future" tokens; adapted to work with groups and to match vanilla attention.
+        if not self.is_cross_attention:
+            query_length, key_length = query.size(-2), key.size(-2)
+            causal_mask = self.bias[:, :, key_length - query_length : key_length, :key_length]
+            mask_value = torch.finfo(attn_weights.dtype).min
+            mask_value = torch.full([], mask_value, dtype=attn_weights.dtype).to(attn_weights.device)
+            attn_weights = torch.where(causal_mask, attn_weights, mask_value)
+
+        # ALiBi biases the pre-softmax scores
+        if alibi_bias is not None:
+            attn_weights = attn_weights + alibi_bias[:, :, :attn_weights.size(-1)]
+
+        # Softmax normalization to get the attention scores
+        attn_weights = nn.functional.softmax(attn_weights, dim=-1)
+
+        # Apply dropout if specified
+        attn_weights = attn_weights.type(value.dtype)
+        attn_weights = self.attn_dropout(attn_weights)
+
+        # Compute the output by multiplying the attention scores with the value tensor.
+        attn_output = torch.matmul(attn_weights, value)
+
+        return attn_output, attn_weights
+
     def _upcast_and_reordered_attn(self, query, key, value, attention_mask=None, head_mask=None,alibi_bias=None):
         # Use `torch.baddbmm` (a bit more efficient w/ alpha param for scaling -- from Megatron-LM)
         bsz, num_heads, q_seq_len, dk = query.size()
@@ -233,9 +324,10 @@ def forward(
 
         if self.reorder_and_upcast_attn:
             attn_output, attn_weights = self._upcast_and_reordered_attn(query, key, value, attention_mask, head_mask,alibi_bias=alibi_bias)
-        else:
+        elif self.standard_attn:
             attn_output, attn_weights = self._attn(query, key, value, attention_mask, head_mask,alibi_bias=alibi_bias)
-
+        elif self.gqa_attn:
+            attn_output, attn_weights = self._gqa_attn(query, key, value, attention_mask,alibi_bias=alibi_bias)
         attn_output = self._merge_heads(attn_output, self.num_heads, self.head_dim)
         attn_output = self.c_proj(attn_output)
         attn_output = self.resid_dropout(attn_output)
@@ -244,7 +336,7 @@ def forward(
         if output_attentions:
             outputs += (attn_weights,)
 
-        return outputs  # a, present, (attentions)
+        return outputs  # a, present, (attentions)
 
 
 class APTMLP(nn.Module):
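
The grouping comment in _gqa_attn describes query heads sharing a smaller set of key/value heads. As a point of reference, here is a minimal standalone sketch of the textbook grouped-query computation (repeating the key/value heads so each group of query heads attends with a shared key/value head). All sizes are made up for illustration; this is not the repository's _gqa_attn code path.

import torch

batch_size, num_q_heads, num_kv_heads, seq_len, head_dim = 2, 4, 2, 5, 8
n_groups = num_q_heads // num_kv_heads  # 2 query heads share each key/value head

query = torch.randn(batch_size, num_q_heads, seq_len, head_dim)
key = torch.randn(batch_size, num_kv_heads, seq_len, head_dim)
value = torch.randn(batch_size, num_kv_heads, seq_len, head_dim)

# Expand the key/value heads so every query head has a matching key/value head,
# then run ordinary scaled dot-product attention.
key_exp = key.repeat_interleave(n_groups, dim=1)      # (B, num_q_heads, S, D)
value_exp = value.repeat_interleave(n_groups, dim=1)  # (B, num_q_heads, S, D)

scores = torch.matmul(query, key_exp.transpose(-2, -1)) / head_dim ** 0.5
weights = torch.softmax(scores, dim=-1)               # (B, num_q_heads, S, S)
output = torch.matmul(weights, value_exp)             # (B, num_q_heads, S, D)

print(weights.shape, output.shape)  # torch.Size([2, 4, 5, 5]) torch.Size([2, 4, 5, 8])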

protein_lm/tests/test_attention.py

Lines changed: 89 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,89 @@
+import pytest
+import torch
+from torch.nn import functional as F
+
+from protein_lm.modeling.models.apt.model_pytorch import APTAttention
+
+
+class ParameterConfig:
+    def __init__(self):
+        self.max_position_embeddings = 512
+        self.position_embedding = 'rope'
+        self.max_sequence_length = 512
+        self.hidden_size = 768
+        self.num_attention_heads = 12
+        self.scale_attn_weights = True
+        self.scale_attn_by_inverse_layer_idx = True
+        self.reorder_and_upcast_attn = True
+        self.attn_pdrop = 0.1
+        self.resid_pdrop = 0.1
+        self.rope_scaling_factor = 1
+        self.rope_theta = 1
+        self.attn_type = 'gqa'
+
+
+def test_vanilla_attn():
+    # Initialize with a mock config
+    config = ParameterConfig()
+    attention = APTAttention(config, is_cross_attention=False, layer_idx=0)
+
+    # Generate random input tensors
+    batch_size = 4
+    seq_length = 100
+    num_heads = config.num_attention_heads
+    query_dim = config.hidden_size // num_heads
+    query = torch.randn(batch_size, num_heads, seq_length, query_dim)
+    key = torch.randn(batch_size, num_heads, seq_length, query_dim)
+    value = torch.randn(batch_size, num_heads, seq_length, query_dim)
+
+    # Create a random attention mask for testing
+    attention_mask = torch.ones(batch_size, seq_length, seq_length)
+    padding_positions = 10
+    attention_mask[:, -padding_positions:, :] = float('-inf')
+    attention_mask[:, :, -padding_positions:] = float('-inf')
+    attention_mask = attention_mask.unsqueeze(1)
+
+    # Pass the tensors through the _attn method
+    attn_output, attn_weights = attention._attn(query, key, value, attention_mask=attention_mask)
+
+    # Check the shapes and types of the output
+    assert isinstance(attn_output, torch.Tensor)
+    assert attn_output.shape == (batch_size, num_heads, seq_length, query_dim)
+    assert isinstance(attn_weights, torch.Tensor)
+    assert attn_weights.shape == (batch_size, num_heads, seq_length, seq_length)
+    print("Test passed!")
+
+
+def test_gqa_attn():
+    # Initialize with a mock config
+    config = ParameterConfig()
+    attention = APTAttention(config, is_cross_attention=False, layer_idx=0)
+
+    # Generate random input tensors
+    batch_size = 4
+    seq_length = 100
+    num_heads = config.num_attention_heads
+    query_dim = config.hidden_size // num_heads
+    query = torch.randn(batch_size, num_heads, seq_length, query_dim)
+    key = torch.randn(batch_size, num_heads, seq_length, query_dim)
+    value = torch.randn(batch_size, num_heads, seq_length, query_dim)
+
+    # Create a random attention mask for testing; _gqa_attn adds the head
+    # dimension itself, so the mask stays (batch_size, seq_length, seq_length).
+    attention_mask = torch.ones(batch_size, seq_length, seq_length)
+    padding_positions = 10
+    attention_mask[:, -padding_positions:, :] = float('-inf')
+    attention_mask[:, :, -padding_positions:] = float('-inf')
+
+    # Pass the tensors through the _gqa_attn method
+    attn_output, attn_weights = attention._gqa_attn(query, key, value, attention_mask=attention_mask)
+
+    # Check the shapes and types of the output
+    assert isinstance(attn_output, torch.Tensor)
+    assert attn_output.shape == (batch_size, num_heads, seq_length, query_dim)
+    assert isinstance(attn_weights, torch.Tensor)
+    assert attn_weights.shape == (batch_size, num_heads, seq_length, seq_length)
+    print("Test passed!")
+
+
+if __name__ == "__main__":
+    test_gqa_attn()
+    test_vanilla_attn()
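
A natural follow-up (not part of this commit) would assert that _gqa_attn reduces to the vanilla path when the query and key/value head counts match, i.e. when n_groups == 1. A sketch, assuming both methods apply the same scaling and causal masking and that eval() disables attn_dropout:

def test_gqa_matches_vanilla_when_single_group():
    config = ParameterConfig()
    attention = APTAttention(config, is_cross_attention=False, layer_idx=0)
    attention.eval()  # disable attn_dropout so the comparison is deterministic

    batch_size, seq_length = 2, 16
    num_heads = config.num_attention_heads
    head_dim = config.hidden_size // num_heads
    query = torch.randn(batch_size, num_heads, seq_length, head_dim)
    key = torch.randn(batch_size, num_heads, seq_length, head_dim)
    value = torch.randn(batch_size, num_heads, seq_length, head_dim)

    out_vanilla, _ = attention._attn(query, key, value)
    out_gqa, _ = attention._gqa_attn(query, key, value)

    assert torch.allclose(out_vanilla, out_gqa, atol=1e-5)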
