Commit 8415d03
fuse q and kv parameters for gqa case (bigscience-workshop#291)
1 parent 2348eed commit 8415d03

1 file changed: +28 −56 lines

megatron/model/transformer.py

Lines changed: 28 additions & 56 deletions
@@ -532,39 +532,23 @@ def __init__(self, config, layer_number,
             config.num_attention_heads, world_size)
 
         # Per GQA head and per partition values
-        if self.use_gqa:
-            kv_projection_size = config.kv_channels * config.num_key_value_heads
-            self.num_key_value_heads_per_partition = core.utils.divide(
-                config.num_key_value_heads, world_size)
-            self.num_key_value_groups = core.utils.divide(
-                config.num_attention_heads, config.num_key_value_heads)
-            assert self.hidden_size_per_attention_head == core.utils.divide(
-                kv_projection_size, config.num_key_value_heads)
+        self.num_key_value_heads_per_partition = core.utils.divide(
+            config.num_key_value_heads, world_size)
+        self.num_key_value_groups = core.utils.divide(
+            config.num_attention_heads, config.num_key_value_heads)
+        kv_projection_size = config.kv_channels * config.num_key_value_heads
+        assert self.hidden_size_per_attention_head == core.utils.divide(
+            kv_projection_size, config.num_key_value_heads)
 
         # Strided linear layer.
-        if attention_type == AttnType.self_attn and not self.use_gqa:
+        if attention_type == AttnType.self_attn:
             self.query_key_value = tensor_parallel.ColumnParallelLinear(
                 config.hidden_size,
-                3 * projection_size,
+                projection_size + 2 * kv_projection_size,
                 config=config,
                 init_method=config.init_method,
                 bias=args.add_bias_linear,
                 gather_output=False)
-        elif attention_type == AttnType.self_attn and self.use_gqa:
-            self.query = tensor_parallel.ColumnParallelLinear(
-                config.hidden_size,
-                projection_size,
-                config=config,
-                init_method=config.init_method,
-                bias=config.add_bias_linear,
-                gather_output=False)
-            self.key_value = tensor_parallel.ColumnParallelLinear(
-                config.hidden_size,
-                2 * kv_projection_size,
-                config=config,
-                init_method=config.init_method,
-                bias=config.add_bias_linear,
-                gather_output=False)
         else:
             assert attention_type == AttnType.cross_attn
             self.query = tensor_parallel.ColumnParallelLinear(
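
A quick way to see what the fusion buys: with GQA the query projection still has nq heads, but keys and values share only nkv heads, so the single query_key_value linear now has output width projection_size + 2 * kv_projection_size instead of 3 * projection_size. The sketch below illustrates the sizing arithmetic with made-up configuration numbers and a plain torch.nn.Linear standing in for tensor_parallel.ColumnParallelLinear; it is an illustration under those assumptions, not the repository's API.

import torch

# Hypothetical GQA configuration (illustrative values, not from the commit).
hidden_size = 4096
num_attention_heads = 32                           # nq
num_key_value_heads = 8                            # nkv
kv_channels = hidden_size // num_attention_heads   # hn = 128

projection_size = kv_channels * num_attention_heads      # nq * hn = 4096
kv_projection_size = kv_channels * num_key_value_heads   # nkv * hn = 1024

# One fused projection replaces the separate query and key_value linears:
# output width = nq*hn (queries) + 2 * nkv*hn (keys + values).
fused = torch.nn.Linear(hidden_size, projection_size + 2 * kv_projection_size)

x = torch.randn(5, 2, hidden_size)   # [sq, b, h]
mixed_x_layer = fused(x)             # [sq, b, (nq + 2*nkv) * hn]
assert mixed_x_layer.shape[-1] == (num_attention_heads + 2 * num_key_value_heads) * kv_channels
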
@@ -657,6 +641,13 @@ def repeat_kv(self, hidden_states, n_rep):
         return hidden_states.reshape(slen, batch,
                                      num_key_value_heads_per_partition * n_rep,
                                      head_dim)
+
+    def split_tensor(self, mixed_x_layer):
+        query_layer = mixed_x_layer[:, :, :, :-2, :].reshape(mixed_x_layer.shape[:-1] + (-1, self.hidden_size_per_attention_head))
+        key_layer = mixed_x_layer[:, :, :, -2, :]
+        value_layer = mixed_x_layer[:, :, :, -1, :]
+
+        return query_layer, key_layer, value_layer
 
     def forward(self, hidden_states, attention_mask,
                 encoder_output=None, inference_params=None,
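
The new split_tensor relies on the [sq, b, nkv, nq // nkv + 2, hn] view built in forward: for each key/value head, the first nq // nkv slots along the fourth dimension are its query heads and the last two slots are its key and value. The standalone sketch below assumes illustrative shapes with tensor-parallel world size 1 and writes the query reshape out explicitly; it is a hedged restatement of the slicing, not the committed method.

import torch

# Assumed shapes: sq=5, b=2, nq=32 query heads, nkv=8 key/value heads, hn=128.
sq, b, nq, nkv, hn = 5, 2, 32, 8, 128
groups = nq // nkv                      # query heads per kv head (num_key_value_groups)

mixed_x_layer = torch.randn(sq, b, (nq + 2 * nkv) * hn)

# [sq, b, (nq + 2*nkv) * hn] --> [sq, b, nkv, (nq//nkv + 2), hn]
mixed_x_layer = mixed_x_layer.view(sq, b, nkv, groups + 2, hn)

# Per kv head: first `groups` slots are queries, the last two are key and value.
query_layer = mixed_x_layer[:, :, :, :-2, :].reshape(sq, b, nq, hn)   # [sq, b, nq, hn]
key_layer   = mixed_x_layer[:, :, :, -2, :]                           # [sq, b, nkv, hn]
value_layer = mixed_x_layer[:, :, :, -1, :]                           # [sq, b, nkv, hn]

assert query_layer.shape == (sq, b, nq, hn)
assert key_layer.shape == value_layer.shape == (sq, b, nkv, hn)
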
@@ -686,45 +677,26 @@ def forward(self, hidden_states, attention_mask,
         # Query, Key, and Value
         # =====================
 
-        if self.attention_type == AttnType.self_attn and not self.use_gqa:
-            # Attention heads [sq, b, h] --> [sq, b, (np * 3 * hn)]
+        if self.attention_type == AttnType.self_attn:
+            # Attention heads [sq, b, h] --> [sq, b, ((nq + 2 * nkv) * hn)]
             mixed_x_layer, _ = self.query_key_value(hidden_states)
 
-            # [sq, b, (np * 3 * hn)] --> [sq, b, np, 3 * hn]
+            # [sq, b, ((nq + 2 * nkv) * hn)] --> [sq, b, nkv, (nq // nkv + 2), hn]
             new_tensor_shape = mixed_x_layer.size()[:-1] + \
-                (self.num_attention_heads_per_partition,
-                 3 * self.hidden_size_per_attention_head)
+                (-1, (self.num_key_value_groups + 2),
+                 self.hidden_size_per_attention_head)
             mixed_x_layer = mixed_x_layer.view(*new_tensor_shape)
 
-            # [sq, b, np, 3 * hn] --> 3 [sq, b, np, hn]
-            (query_layer,
+            # [sq, b, nkv, (nq // nkv + 2), hn] --> 3 [sq, b, np, hn]
+            (query_layer,
              key_layer,
-             value_layer) = tensor_parallel.split_tensor_along_last_dim(mixed_x_layer, 3)
-        elif self.attention_type == AttnType.self_attn and self.use_gqa:
-            # Attention head [sq, b, h] --> [sq, b, hp]
-            query_layer, _ = self.query(hidden_states)
-            # [sq, b, hp] --> [sq, b, np, hn]
-            new_tensor_shape = query_layer.size()[:-1] + \
-                (self.num_attention_heads_per_partition,
-                 self.hidden_size_per_attention_head)
-            query_layer = query_layer.view(*new_tensor_shape)
-
-            # Attention heads [sq, b, h] --> [sq, b, (np * 2 * hn)]
-            mixed_kv_layer, _ = self.key_value(hidden_states)
-            # [sq, b, (np * 2 * hn)] --> [sq, b, np, 2 * hn]
-            new_tensor_shape = mixed_kv_layer.size()[:-1] + \
-                (self.num_key_value_heads_per_partition,
-                 2 * self.hidden_size_per_attention_head)
-            mixed_kv_layer = mixed_kv_layer.view(*new_tensor_shape)
-            # [sq, b, np, 2 * hn] --> 2 [sq, b, np, hn]
-            (key_layer,
-             value_layer) = tensor_parallel.split_tensor_along_last_dim(
-                 mixed_kv_layer, 2)
+             value_layer) = self.split_tensor(mixed_x_layer)
 
             # Repeat kv
-            key_layer = self.repeat_kv(key_layer, self.num_key_value_groups)
-            value_layer = self.repeat_kv(value_layer,
-                                         self.num_key_value_groups)
+            if self.use_gqa:
+                key_layer = self.repeat_kv(key_layer, self.num_key_value_groups)
+                value_layer = self.repeat_kv(value_layer,
+                                             self.num_key_value_groups)
         else:
             assert not self.use_gqa, 'GQA + cross-attn not tested yet'
 