diff --git a/deepspeed/runtime/domino/transformer.py b/deepspeed/runtime/domino/transformer.py
index 88c5494c8147..089b9fd13e59 100644
--- a/deepspeed/runtime/domino/transformer.py
+++ b/deepspeed/runtime/domino/transformer.py
@@ -109,19 +109,25 @@ def __init__(self, config, layer_number, mpu, attn_mask_type=AttnMaskType.causal
         self.layer_number = max(1, layer_number)
         self.att_dropout_p = config.attention_dropout
         self.is_causal = True
-        projection_size = config.kv_channels * config.num_attention_heads
+        self.projection_size = config.kv_channels * config.num_attention_heads
         world_size = mpu.get_tensor_model_parallel_world_size()
-        self.hidden_size_per_partition = projection_size // world_size
+        self.hidden_size_per_partition = self.projection_size // world_size

     def forward(self, query_layer, key_layer, value_layer, attention_mask):
         # attn_mask is None when is_causal=True
-        context_layer = torch.nn.functional.scaled_dot_product_attention(query_layer,
-                                                                          key_layer,
-                                                                          value_layer,
-                                                                          attn_mask=None,
-                                                                          dropout_p=self.att_dropout_p,
-                                                                          is_causal=True,
-                                                                          scale=None)
+        query_layer = query_layer.permute(1, 2, 0, 3).contiguous()
+        key_layer = key_layer.permute(1, 2, 0, 3).contiguous()
+        value_layer = value_layer.permute(1, 2, 0, 3).contiguous()
+        context_layer = torch.nn.functional.scaled_dot_product_attention(
+            query_layer,
+            key_layer,
+            value_layer,
+            attn_mask=None,
+            dropout_p=self.att_dropout_p,
+            is_causal=True,
+            scale=None,
+            enable_gqa=True,
+        )

         # [b, np, sq, hn] --> [sq, b, np, hn]
         context_layer = context_layer.permute(2, 0, 1, 3).contiguous()
@@ -149,19 +155,30 @@ def __init__(self,
                  attention_type=AttnType.self_attn,
                  attn_mask_type=AttnMaskType.causal):
         super(ShardedAttention, self).__init__()
+        self.config = config
         self.layer_number = max(1, layer_number)
         self.attention_type = attention_type
         self.attn_mask_type = attn_mask_type
         self.params_dtype = config.params_dtype
         self.apply_rotary_pos_emb = apply_rotary_pos_emb
+        self.group_query_attention = config.group_query_attention
+        self.num_query_groups = config.num_query_groups

         query_projection_size = config.kv_channels * config.num_attention_heads
-        kv_projection_size = config.kv_channels * config.num_attention_heads
+        print(f'self.group_query_attention: {self.group_query_attention}')
+        if self.group_query_attention:
+            kv_projection_size = config.kv_channels * config.num_query_groups
+        else:
+            kv_projection_size = config.kv_channels * config.num_attention_heads

         # Per attention head and per partition values.
         world_size = mpu.get_tensor_model_parallel_world_size()
         self.hidden_size_per_attention_head = query_projection_size // config.num_attention_heads
         self.num_attention_heads_per_partition = config.num_attention_heads // world_size
+        if self.group_query_attention:
+            self.num_query_groups_per_partition = config.num_query_groups // world_size
+        else:
+            self.num_query_groups_per_partition = self.num_attention_heads_per_partition

         self.query_key_value = ColumnParallelLinear(config.hidden_size,
                                                     query_projection_size + 2 * kv_projection_size,
@@ -189,23 +206,36 @@ def forward(self, hidden_states, attention_mask, rotary_pos_emb=None):

         # [s, b, np * 3 * hn] --> [s, b, np, 3 * hn]
         new_tensor_shape = mixed_x_layer.size()[:-1] + (
-            self.num_attention_heads_per_partition,
-            3 * self.hidden_size_per_attention_head,
+            self.num_query_groups_per_partition,
+            ((self.num_attention_heads_per_partition // self.num_query_groups_per_partition + 2) *
+             self.hidden_size_per_attention_head),
         )
         mixed_x_layer = mixed_x_layer.view(*new_tensor_shape)

         # [s, b, np, 3 * hn] -> [b, np, s, 3*hn]
-        mixed_x_layer = mixed_x_layer.permute(1, 2, 0, 3).contiguous()
-
-        # [s, b, np, 3 * hn] --> [s, b, np, hn], [s, b, np, hn], [s, b, np, hn]
-        (query_layer, key_layer, value_layer) = torch.split(mixed_x_layer, [
-            self.hidden_size_per_attention_head, self.hidden_size_per_attention_head,
-            self.hidden_size_per_attention_head
-        ],
-                                                            dim=3)
-        # [s, b, np, np * hn] -> [s, b, np, hn]
-        query_layer = query_layer.view(query_layer.size(0), query_layer.size(1), -1,
-                                       self.hidden_size_per_attention_head)
+        # mixed_x_layer = mixed_x_layer.permute(1, 2, 0, 3).contiguous()
+
+        # [sq, b, ng, (np/ng + 2) * hn] --> [sq, b, ng, np/ng * hn], [sq, b, ng, hn], [sq, b, ng, hn]
+        (query_layer, key_layer,
+         value_layer) = torch.split(mixed_x_layer,
+                                    [(self.num_attention_heads_per_partition // self.num_query_groups_per_partition *
+                                      self.hidden_size_per_attention_head), self.hidden_size_per_attention_head,
+                                     self.hidden_size_per_attention_head],
+                                    dim=3)
+
+        # [sq, b, ng, np/ng * hn] -> [sq, b, np, hn] -
+        query_layer = query_layer.reshape(query_layer.size(0), query_layer.size(1), -1,
+                                          self.hidden_size_per_attention_head)
+
+        # if self.num_attention_heads_per_partition // self.num_query_groups_per_partition > 1:
+        #     key_layer = key_layer.repeat_interleave(
+        #         self.num_attention_heads_per_partition // self.num_query_groups_per_partition,
+        #         dim = 2
+        #     )
+        #     value_layer = value_layer.repeat_interleave(
+        #         self.num_attention_heads_per_partition // self.num_query_groups_per_partition,
+        #         dim = 2
+        #     )

         # apply rotary embedding
         if rotary_pos_emb is not None:
@@ -214,8 +244,8 @@ def forward(self, hidden_states, attention_mask, rotary_pos_emb=None):
             else:
                 rotary_pos_emb = ((rotary_pos_emb, ) * 2)
             q_pos_emb, k_pos_emb = rotary_pos_emb
-            query_layer = self.apply_rotary_pos_emb(query_layer, q_pos_emb)
-            key_layer = self.apply_rotary_pos_emb(key_layer, k_pos_emb)
+            query_layer = self.apply_rotary_pos_emb(query_layer, q_pos_emb, self.config)
+            key_layer = self.apply_rotary_pos_emb(key_layer, k_pos_emb, self.config)

         context_layer = self.core_attention(query_layer, key_layer, value_layer, attention_mask)

@@ -234,7 +264,7 @@ def __init__(self,
                  config,
                  layer_number,
                  mpu,
-                 fused_layer_norm,
+                 get_norm,
                  _initialize_affine_weight_gpu,
                  ColumnParallelLinear,
                  RowParallelLinearNoComm,
@@ -252,6 +282,7 @@ def __init__(self,
         dist.init_distributed()
         assert dist.is_initialized(), "deepspeed.comm is not initialized!"

+        self.config = config
         self.llama_model = config.llama_model
         self.layer_number = layer_number
         self.layer_type = layer_type
@@ -262,9 +293,7 @@ def __init__(self,
         self.output_bias = output_bias

         # Layernorm on the input data.
-        self.input_layernorm = fused_layer_norm(config.hidden_size,
-                                                eps=config.layernorm_epsilon,
-                                                no_persist_layer_norm=config.no_persist_layer_norm)
+        self.input_layernorm = get_norm(config)

         # Self attention.
         self.self_attention = ShardedAttention(config,
@@ -279,9 +308,7 @@ def __init__(self,
         self.hidden_dropout = config.hidden_dropout

         # Layernorm on the attention output
-        self.post_attention_layernorm = fused_layer_norm(config.hidden_size,
-                                                         eps=config.layernorm_epsilon,
-                                                         no_persist_layer_norm=config.no_persist_layer_norm)
+        self.post_attention_layernorm = get_norm(config)

         # ------------ init mlp start ------------
         init_method = config.init_method
@@ -457,7 +484,7 @@ def __init__(self,
                  config,
                  model_type,
                  mpu,
-                 fused_layer_norm,
+                 get_norm,
                  _initialize_affine_weight_gpu,
                  ColumnParallelLinear,
                  RowParallelLinearNoComm,
@@ -484,7 +511,7 @@ def build_layer(layer_number):
            return DominoTransformerLayer(config,
                                          layer_number,
                                          mpu,
-                                         fused_layer_norm,
+                                         get_norm,
                                          _initialize_affine_weight_gpu,
                                          ColumnParallelLinear,
                                          RowParallelLinearNoComm,
@@ -498,9 +525,7 @@ def build_layer(layer_number):
         self.layers = torch.nn.ModuleList([build_layer(i + 1) for i in range(self.num_layers)])

         if self.post_process and self.post_layer_norm:
-            self.final_layernorm = fused_layer_norm(config.hidden_size,
-                                                    eps=config.layernorm_epsilon,
-                                                    no_persist_layer_norm=config.no_persist_layer_norm)
+            self.final_layernorm = get_norm(config)

     def forward(self, hidden_states, attention_mask, rotary_pos_emb=None):
         # hidden_states: [s, b, h]
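Context for reviewers: the core of the change is that the fused QKV projection is now viewed per query group rather than per attention head, and the key/value head expansion is delegated to `torch.nn.functional.scaled_dot_product_attention(..., enable_gqa=True)` (available in PyTorch 2.5+) rather than an explicit `repeat_interleave`. The standalone sketch below, using arbitrary toy dimensions that are not taken from any DeepSpeed config, walks through the same shape arithmetic:

```python
# Sketch only: toy shapes illustrating the QKV split and SDPA call used in the patch.
# sq/b/np/ng/hn values are arbitrary, not taken from the DeepSpeed config.
import torch
import torch.nn.functional as F

sq, b = 16, 2           # sequence length, batch
np_, ng, hn = 8, 2, 64  # query heads, query groups (kv heads), head dim per partition

# Fused QKV output: [sq, b, ng * (np/ng + 2) * hn]
mixed_x_layer = torch.randn(sq, b, ng * (np_ // ng + 2) * hn)

# [sq, b, ng, (np/ng + 2) * hn]
mixed_x_layer = mixed_x_layer.view(sq, b, ng, (np_ // ng + 2) * hn)

# [sq, b, ng, np/ng * hn], [sq, b, ng, hn], [sq, b, ng, hn]
query, key, value = torch.split(mixed_x_layer, [np_ // ng * hn, hn, hn], dim=3)

# [sq, b, ng, np/ng * hn] -> [sq, b, np, hn]
query = query.reshape(sq, b, np_, hn)

# SDPA expects [b, heads, sq, hn]; enable_gqa=True broadcasts the ng kv heads
# across the np query heads, so no repeat_interleave is required.
query = query.permute(1, 2, 0, 3).contiguous()  # [b, np, sq, hn]
key = key.permute(1, 2, 0, 3).contiguous()      # [b, ng, sq, hn]
value = value.permute(1, 2, 0, 3).contiguous()  # [b, ng, sq, hn]

context = F.scaled_dot_product_attention(query, key, value, is_causal=True, enable_gqa=True)
print(context.shape)  # torch.Size([2, 8, 16, 64])
```

With np = 8 query heads and ng = 2 query groups, the split widths are [np/ng * hn, hn, hn] = [256, 64, 64], and SDPA handles the 2-to-8 KV head broadcast internally.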
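The `fused_layer_norm` → `get_norm(config)` swap suggests the normalization type is now chosen from the config (e.g. LayerNorm vs. RMSNorm, as Megatron's `get_norm` does); the actual helper lives outside this diff. A minimal sketch of such a factory, assuming the config exposes `normalization`, `hidden_size`, and `layernorm_epsilon`:

```python
# Sketch of a get_norm-style factory; the exact fields and norm classes used by
# DeepSpeed/Megatron may differ -- this only illustrates the config-driven choice.
import torch


class RMSNorm(torch.nn.Module):
    """Minimal RMSNorm used only for this illustration."""

    def __init__(self, hidden_size: int, eps: float = 1e-5):
        super().__init__()
        self.eps = eps
        self.weight = torch.nn.Parameter(torch.ones(hidden_size))

    def forward(self, x):
        variance = x.pow(2).mean(-1, keepdim=True)
        return self.weight * x * torch.rsqrt(variance + self.eps)


def get_norm(config):
    """Return the normalization layer selected by config.normalization."""
    if getattr(config, "normalization", "LayerNorm") == "RMSNorm":
        return RMSNorm(config.hidden_size, eps=config.layernorm_epsilon)
    return torch.nn.LayerNorm(config.hidden_size, eps=config.layernorm_epsilon)
```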