Draft: changes from all commits (20 commits)
deepspeed/runtime/domino/transformer.py (99 changes: 62 additions & 37 deletions)
@@ -109,19 +109,25 @@ def __init__(self, config, layer_number, mpu, attn_mask_type=AttnMaskType.causal
self.layer_number = max(1, layer_number)
self.att_dropout_p = config.attention_dropout
self.is_causal = True
projection_size = config.kv_channels * config.num_attention_heads
self.projection_size = config.kv_channels * config.num_attention_heads
world_size = mpu.get_tensor_model_parallel_world_size()
self.hidden_size_per_partition = projection_size // world_size
self.hidden_size_per_partition = self.projection_size // world_size

def forward(self, query_layer, key_layer, value_layer, attention_mask):
# attn_mask is None when is_causal=True
context_layer = torch.nn.functional.scaled_dot_product_attention(query_layer,
key_layer,
value_layer,
attn_mask=None,
dropout_p=self.att_dropout_p,
is_causal=True,
scale=None)
query_layer = query_layer.permute(1, 2, 0, 3).contiguous()
key_layer = key_layer.permute(1, 2, 0, 3).contiguous()
value_layer = value_layer.permute(1, 2, 0, 3).contiguous()
context_layer = torch.nn.functional.scaled_dot_product_attention(
query_layer,
key_layer,
value_layer,
attn_mask=None,
dropout_p=self.att_dropout_p,
is_causal=True,
scale=None,
enable_gqa=True,
)

# [b, np, sq, hn] --> [sq, b, np, hn]
context_layer = context_layer.permute(2, 0, 1, 3).contiguous()
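
The rewritten forward first reorders Q/K/V from [sq, b, np, hn] to [b, np, sq, hn] and then leans on SDPA's enable_gqa flag, so K/V can keep only num_query_groups heads. A minimal sketch of that call pattern with illustrative shapes (enable_gqa needs PyTorch 2.5 or newer):

import torch
import torch.nn.functional as F

b, sq, hn = 2, 16, 64          # batch, sequence length, head dim (illustrative)
np_heads, ng = 8, 2            # query heads vs. KV groups; np_heads % ng == 0

q = torch.randn(b, np_heads, sq, hn)   # [b, np, sq, hn]
k = torch.randn(b, ng, sq, hn)         # [b, ng, sq, hn] -- fewer heads than q
v = torch.randn(b, ng, sq, hn)

# enable_gqa broadcasts each KV group across its np/ng query heads internally,
# so no explicit repeat_interleave of K/V is needed.
out = F.scaled_dot_product_attention(q, k, v, is_causal=True, enable_gqa=True)
print(out.shape)  # torch.Size([2, 8, 16, 64])
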
@@ -149,19 +155,30 @@ def __init__(self,
attention_type=AttnType.self_attn,
attn_mask_type=AttnMaskType.causal):
super(ShardedAttention, self).__init__()
self.config = config
self.layer_number = max(1, layer_number)
self.attention_type = attention_type
self.attn_mask_type = attn_mask_type
self.params_dtype = config.params_dtype
self.apply_rotary_pos_emb = apply_rotary_pos_emb
self.group_query_attention = config.group_query_attention
self.num_query_groups = config.num_query_groups

query_projection_size = config.kv_channels * config.num_attention_heads
kv_projection_size = config.kv_channels * config.num_attention_heads
print(f'self.group_query_attention: {self.group_query_attention}')
if self.group_query_attention:
kv_projection_size = config.kv_channels * config.num_query_groups
else:
kv_projection_size = config.kv_channels * config.num_attention_heads

# Per attention head and per partition values.
world_size = mpu.get_tensor_model_parallel_world_size()
self.hidden_size_per_attention_head = query_projection_size // config.num_attention_heads
self.num_attention_heads_per_partition = config.num_attention_heads // world_size
if self.group_query_attention:
self.num_query_groups_per_partition = config.num_query_groups // world_size
else:
self.num_query_groups_per_partition = self.num_attention_heads_per_partition

self.query_key_value = ColumnParallelLinear(config.hidden_size,
query_projection_size + 2 * kv_projection_size,
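
With group_query_attention enabled, only the query projection keeps one output per attention head, while K and V shrink to one head per query group, so the fused column-parallel QKV projection gets narrower. A quick worked example with made-up sizes (not taken from the PR):

kv_channels, num_attention_heads, num_query_groups = 128, 32, 8  # illustrative values

query_projection_size = kv_channels * num_attention_heads   # 4096
kv_projection_size = kv_channels * num_query_groups          # 1024 (vs. 4096 without GQA)

fused_qkv_width = query_projection_size + 2 * kv_projection_size
print(fused_qkv_width)  # 6144 instead of 12288 for plain multi-head attention
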
Expand Down Expand Up @@ -189,23 +206,36 @@ def forward(self, hidden_states, attention_mask, rotary_pos_emb=None):

# [s, b, np * 3 * hn] --> [s, b, np, 3 * hn]
new_tensor_shape = mixed_x_layer.size()[:-1] + (
self.num_attention_heads_per_partition,
3 * self.hidden_size_per_attention_head,
self.num_query_groups_per_partition,
((self.num_attention_heads_per_partition // self.num_query_groups_per_partition + 2) *
self.hidden_size_per_attention_head),
)
mixed_x_layer = mixed_x_layer.view(*new_tensor_shape)

# [s, b, np, 3 * hn] -> [b, np, s, 3*hn]
mixed_x_layer = mixed_x_layer.permute(1, 2, 0, 3).contiguous()

# [s, b, np, 3 * hn] --> [s, b, np, hn], [s, b, np, hn], [s, b, np, hn]
(query_layer, key_layer, value_layer) = torch.split(mixed_x_layer, [
self.hidden_size_per_attention_head, self.hidden_size_per_attention_head,
self.hidden_size_per_attention_head
],
dim=3)
# [s, b, np, np * hn] -> [s, b, np, hn]
query_layer = query_layer.view(query_layer.size(0), query_layer.size(1), -1,
self.hidden_size_per_attention_head)
# mixed_x_layer = mixed_x_layer.permute(1, 2, 0, 3).contiguous()

# [sq, b, ng, (np/ng + 2) * hn] --> [sq, b, ng, np/ng * hn], [sq, b, ng, hn], [sq, b, ng, hn]
(query_layer, key_layer,
value_layer) = torch.split(mixed_x_layer,
[(self.num_attention_heads_per_partition // self.num_query_groups_per_partition *
self.hidden_size_per_attention_head), self.hidden_size_per_attention_head,
self.hidden_size_per_attention_head],
dim=3)

# [sq, b, ng, np/ng * hn] -> [sq, b, np, hn]
query_layer = query_layer.reshape(query_layer.size(0), query_layer.size(1), -1,
self.hidden_size_per_attention_head)

# if self.num_attention_heads_per_partition // self.num_query_groups_per_partition > 1:
# key_layer = key_layer.repeat_interleave(
# self.num_attention_heads_per_partition // self.num_query_groups_per_partition,
# dim = 2
# )
# value_layer = value_layer.repeat_interleave(
# self.num_attention_heads_per_partition // self.num_query_groups_per_partition,
# dim = 2
# )

# apply rotary embedding
if rotary_pos_emb is not None:
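
The split above peels the fused per-group chunk apart: each of the ng groups carries np/ng query sub-heads plus one key head and one value head, and the commented-out repeat_interleave is kept as a manual fallback for the KV expansion that enable_gqa now performs inside SDPA. A small shape walk-through with illustrative numbers:

import torch

sq, b, hn = 16, 2, 64
np_heads, ng = 8, 2                          # heads and query groups per partition (illustrative)
per_group_width = (np_heads // ng + 2) * hn  # (4 + 2) * 64 = 384

mixed = torch.randn(sq, b, ng, per_group_width)        # [sq, b, ng, (np/ng + 2) * hn]
q, k, v = torch.split(mixed, [np_heads // ng * hn, hn, hn], dim=3)

q = q.reshape(sq, b, -1, hn)                           # [sq, b, np, hn]
print(q.shape, k.shape, v.shape)
# torch.Size([16, 2, 8, 64]) torch.Size([16, 2, 2, 64]) torch.Size([16, 2, 2, 64])
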
@@ -214,8 +244,8 @@ def forward(self, hidden_states, attention_mask, rotary_pos_emb=None):
else:
rotary_pos_emb = ((rotary_pos_emb, ) * 2)
q_pos_emb, k_pos_emb = rotary_pos_emb
query_layer = self.apply_rotary_pos_emb(query_layer, q_pos_emb)
key_layer = self.apply_rotary_pos_emb(key_layer, k_pos_emb)
query_layer = self.apply_rotary_pos_emb(query_layer, q_pos_emb, self.config)
key_layer = self.apply_rotary_pos_emb(key_layer, k_pos_emb, self.config)
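
Both rotary calls now also receive self.config. As a refresher on what the rotation itself does, here is a generic, self-contained sketch; the helper below is hypothetical and only illustrates the math, not this repository's apply_rotary_pos_emb:

import torch

def rotate_half(x):
    x1, x2 = x.chunk(2, dim=-1)
    return torch.cat((-x2, x1), dim=-1)

def apply_rope(t, freqs):
    # t: [sq, b, nh, hn]; freqs: [sq, 1, 1, hn] rotation angles per position
    return t * freqs.cos() + rotate_half(t) * freqs.sin()

sq, b, nh, hn = 16, 2, 8, 64
inv_freq = 1.0 / (10000 ** (torch.arange(0, hn, 2).float() / hn))   # [hn/2]
angles = torch.outer(torch.arange(sq).float(), inv_freq)            # [sq, hn/2]
freqs = torch.cat((angles, angles), dim=-1)[:, None, None, :]       # [sq, 1, 1, hn]

q_rot = apply_rope(torch.randn(sq, b, nh, hn), freqs)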

context_layer = self.core_attention(query_layer, key_layer, value_layer, attention_mask)

@@ -234,7 +264,7 @@ def __init__(self,
config,
layer_number,
mpu,
fused_layer_norm,
get_norm,
_initialize_affine_weight_gpu,
ColumnParallelLinear,
RowParallelLinearNoComm,
@@ -252,6 +282,7 @@ def __init__(self,
dist.init_distributed()
assert dist.is_initialized(), "deepspeed.comm is not initialized!"

self.config = config
self.llama_model = config.llama_model
self.layer_number = layer_number
self.layer_type = layer_type
@@ -262,9 +293,7 @@ def __init__(self,
self.output_bias = output_bias

# Layernorm on the input data.
self.input_layernorm = fused_layer_norm(config.hidden_size,
eps=config.layernorm_epsilon,
no_persist_layer_norm=config.no_persist_layer_norm)
self.input_layernorm = get_norm(config)
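
The hard-wired fused_layer_norm constructor is replaced throughout the layer by a get_norm(config) factory supplied by the caller. A rough sketch of what such a factory typically looks like; the normalization field name and the plain torch norms here are assumptions, and the real helper may return fused or persistent variants:

import torch

def get_norm(config):
    # Hypothetical sketch: pick the norm type from the config (field name assumed).
    # torch.nn.RMSNorm requires PyTorch >= 2.4; fused/apex implementations may be used instead.
    if getattr(config, "normalization", "LayerNorm") == "RMSNorm":
        return torch.nn.RMSNorm(config.hidden_size, eps=config.layernorm_epsilon)
    return torch.nn.LayerNorm(config.hidden_size, eps=config.layernorm_epsilon)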

# Self attention.
self.self_attention = ShardedAttention(config,
@@ -279,9 +308,7 @@ def __init__(self,
self.hidden_dropout = config.hidden_dropout

# Layernorm on the attention output
self.post_attention_layernorm = fused_layer_norm(config.hidden_size,
eps=config.layernorm_epsilon,
no_persist_layer_norm=config.no_persist_layer_norm)
self.post_attention_layernorm = get_norm(config)

# ------------ init mlp start ------------
init_method = config.init_method
@@ -457,7 +484,7 @@ def __init__(self,
config,
model_type,
mpu,
fused_layer_norm,
get_norm,
_initialize_affine_weight_gpu,
ColumnParallelLinear,
RowParallelLinearNoComm,
@@ -484,7 +511,7 @@ def build_layer(layer_number):
return DominoTransformerLayer(config,
layer_number,
mpu,
fused_layer_norm,
get_norm,
_initialize_affine_weight_gpu,
ColumnParallelLinear,
RowParallelLinearNoComm,
@@ -498,9 +525,7 @@ def build_layer(layer_number):
self.layers = torch.nn.ModuleList([build_layer(i + 1) for i in range(self.num_layers)])

if self.post_process and self.post_layer_norm:
self.final_layernorm = fused_layer_norm(config.hidden_size,
eps=config.layernorm_epsilon,
no_persist_layer_norm=config.no_persist_layer_norm)
self.final_layernorm = get_norm(config)

def forward(self, hidden_states, attention_mask, rotary_pos_emb=None):
# hidden_states: [s, b, h]