@@ -532,39 +532,23 @@ def __init__(self, config, layer_number,
             config.num_attention_heads, world_size)

         # Per GQA head and per partition values
-        if self.use_gqa:
-            kv_projection_size = config.kv_channels * config.num_key_value_heads
-            self.num_key_value_heads_per_partition = core.utils.divide(
-                config.num_key_value_heads, world_size)
-            self.num_key_value_groups = core.utils.divide(
-                config.num_attention_heads, config.num_key_value_heads)
-            assert self.hidden_size_per_attention_head == core.utils.divide(
-                kv_projection_size, config.num_key_value_heads)
+        self.num_key_value_heads_per_partition = core.utils.divide(
+            config.num_key_value_heads, world_size)
+        self.num_key_value_groups = core.utils.divide(
+            config.num_attention_heads, config.num_key_value_heads)
+        kv_projection_size = config.kv_channels * config.num_key_value_heads
+        assert self.hidden_size_per_attention_head == core.utils.divide(
+            kv_projection_size, config.num_key_value_heads)

         # Strided linear layer.
-        if attention_type == AttnType.self_attn and not self.use_gqa:
+        if attention_type == AttnType.self_attn:
             self.query_key_value = tensor_parallel.ColumnParallelLinear(
                 config.hidden_size,
-                3 * projection_size,
+                projection_size + 2 * kv_projection_size,
                 config=config,
                 init_method=config.init_method,
                 bias=args.add_bias_linear,
                 gather_output=False)
-        elif attention_type == AttnType.self_attn and self.use_gqa:
-            self.query = tensor_parallel.ColumnParallelLinear(
-                config.hidden_size,
-                projection_size,
-                config=config,
-                init_method=config.init_method,
-                bias=config.add_bias_linear,
-                gather_output=False)
-            self.key_value = tensor_parallel.ColumnParallelLinear(
-                config.hidden_size,
-                2 * kv_projection_size,
-                config=config,
-                init_method=config.init_method,
-                bias=config.add_bias_linear,
-                gather_output=False)
         else:
             assert attention_type == AttnType.cross_attn
             self.query = tensor_parallel.ColumnParallelLinear(
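As a shape sanity check on the fused projection above: a minimal single-rank sketch with a plain torch.nn.Linear standing in for tensor_parallel.ColumnParallelLinear, and illustrative head counts (32 query heads, 8 KV heads, 128-dim heads) that are assumptions, not values taken from the commit.

import torch

# Illustrative values only (not from the commit); single rank, so world_size == 1.
hidden_size = 4096
num_attention_heads = 32        # nq
num_key_value_heads = 8         # nkv
kv_channels = 128               # per-head dim, hn

projection_size = kv_channels * num_attention_heads      # 4096
kv_projection_size = kv_channels * num_key_value_heads   # 1024

# Plain Linear standing in for tensor_parallel.ColumnParallelLinear.
query_key_value = torch.nn.Linear(
    hidden_size, projection_size + 2 * kv_projection_size, bias=False)

hidden_states = torch.randn(7, 2, hidden_size)   # [sq, b, h]
mixed_x_layer = query_key_value(hidden_states)
print(mixed_x_layer.shape)   # torch.Size([7, 2, 6144]) == [sq, b, (nq + 2 * nkv) * hn]

With GQA folded into this single ColumnParallelLinear, one matmul produces Q, K, and V together, replacing the separate query and key_value projections the removed GQA branch used.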
@@ -657,6 +641,13 @@ def repeat_kv(self, hidden_states, n_rep):
         return hidden_states.reshape(slen, batch,
                                      num_key_value_heads_per_partition * n_rep,
                                      head_dim)
+
+    def split_tensor(self, mixed_x_layer):
+        query_layer = mixed_x_layer[:, :, :, :-2, :].reshape(mixed_x_layer.shape[:2] + (-1, self.hidden_size_per_attention_head))
+        key_layer = mixed_x_layer[:, :, :, -2, :]
+        value_layer = mixed_x_layer[:, :, :, -1, :]
+
+        return query_layer, key_layer, value_layer

     def forward(self, hidden_states, attention_mask,
                 encoder_output=None, inference_params=None,
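The slicing in split_tensor can be checked in isolation against a dummy tensor; the head counts below are the same assumed values as in the sketch above, not values from the commit.

import torch

sq, b = 7, 2
nq, nkv, hn = 32, 8, 128                 # assumed head counts and per-head dim
groups = nq // nkv                       # num_key_value_groups

# [sq, b, nkv, groups + 2, hn]: per KV head, `groups` query slots plus one K and one V slot.
mixed_x_layer = torch.randn(sq, b, nkv, groups + 2, hn)

query_layer = mixed_x_layer[:, :, :, :-2, :].reshape(
    mixed_x_layer.shape[:2] + (-1, hn))      # [sq, b, nq, hn]
key_layer = mixed_x_layer[:, :, :, -2, :]    # [sq, b, nkv, hn]
value_layer = mixed_x_layer[:, :, :, -1, :]  # [sq, b, nkv, hn]

print(query_layer.shape, key_layer.shape, value_layer.shape)
# torch.Size([7, 2, 32, 128]) torch.Size([7, 2, 8, 128]) torch.Size([7, 2, 8, 128])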
@@ -686,45 +677,26 @@ def forward(self, hidden_states, attention_mask,
         # Query, Key, and Value
         # =====================

-        if self.attention_type == AttnType.self_attn and not self.use_gqa:
-            # Attention heads [sq, b, h] --> [sq, b, (np * 3 * hn)]
+        if self.attention_type == AttnType.self_attn:
+            # Attention heads [sq, b, h] --> [sq, b, ((nq + 2 * nkv) * hn)]
             mixed_x_layer, _ = self.query_key_value(hidden_states)

-            # [sq, b, (np * 3 * hn)] --> [sq, b, np, 3 * hn]
+            # [sq, b, ((nq + 2 * nkv) * hn)] --> [sq, b, nkv, (nq // nkv + 2), hn]
             new_tensor_shape = mixed_x_layer.size()[:-1] + \
-                (self.num_attention_heads_per_partition,
-                 3 * self.hidden_size_per_attention_head)
+                (-1, (self.num_key_value_groups + 2),
+                 self.hidden_size_per_attention_head)
             mixed_x_layer = mixed_x_layer.view(*new_tensor_shape)

-            # [sq, b, np, 3 * hn] --> 3 [sq, b, np, hn]
-            (query_layer,
+            # [sq, b, nkv, (nq // nkv + 2), hn] --> 3 [sq, b, np, hn]
+            (query_layer,
              key_layer,
-             value_layer) = tensor_parallel.split_tensor_along_last_dim(mixed_x_layer, 3)
-        elif self.attention_type == AttnType.self_attn and self.use_gqa:
-            # Attention head [sq, b, h] --> [sq, b, hp]
-            query_layer, _ = self.query(hidden_states)
-            # [sq, b, hp] --> [sq, b, np, hn]
-            new_tensor_shape = query_layer.size()[:-1] + \
-                (self.num_attention_heads_per_partition,
-                 self.hidden_size_per_attention_head)
-            query_layer = query_layer.view(*new_tensor_shape)
-
-            # Attention heads [sq, b, h] --> [sq, b, (np * 2 * hn)]
-            mixed_kv_layer, _ = self.key_value(hidden_states)
-            # [sq, b, (np * 2 * hn)] --> [sq, b, np, 2 * hn]
-            new_tensor_shape = mixed_kv_layer.size()[:-1] + \
-                (self.num_key_value_heads_per_partition,
-                 2 * self.hidden_size_per_attention_head)
-            mixed_kv_layer = mixed_kv_layer.view(*new_tensor_shape)
-            # [sq, b, np, 2 * hn] --> 2 [sq, b, np, hn]
-            (key_layer,
-             value_layer) = tensor_parallel.split_tensor_along_last_dim(
-                 mixed_kv_layer, 2)
+             value_layer) = self.split_tensor(mixed_x_layer)

             # Repeat kv
-            key_layer = self.repeat_kv(key_layer, self.num_key_value_groups)
-            value_layer = self.repeat_kv(value_layer,
-                                         self.num_key_value_groups)
+            if self.use_gqa:
+                key_layer = self.repeat_kv(key_layer, self.num_key_value_groups)
+                value_layer = self.repeat_kv(value_layer,
+                                             self.num_key_value_groups)
         else:
             assert not self.use_gqa, 'GQA + cross-attn not tested yet'
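The if self.use_gqa: branch relies on repeat_kv to tile each K/V head across its query group; judging from the reshape tail visible in the second hunk, the effect matches the expand-and-reshape pattern sketched below (same assumed head counts as above, not the class's own method).

import torch

sq, b, nkv, hn = 7, 2, 8, 128
n_rep = 4                                # num_key_value_groups = nq // nkv

key_layer = torch.randn(sq, b, nkv, hn)

# [sq, b, nkv, hn] -> [sq, b, nkv, n_rep, hn] -> [sq, b, nkv * n_rep, hn]
repeated = key_layer[:, :, :, None, :] \
    .expand(sq, b, nkv, n_rep, hn) \
    .reshape(sq, b, nkv * n_rep, hn)
print(repeated.shape)                    # torch.Size([7, 2, 32, 128])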