
Commit 51e8784

kgopalsw authored and aws-yishanm committed

TP 128 draft fix

GitOrigin-RevId: b68155341b188362a4cb7137d2724fb461a9a50a
1 parent e24fd96 · commit 51e8784

File tree

2 files changed: +14 -10 lines changed

src/transformers_neuronx/layers/flash_decoding.py

Lines changed: 8 additions & 6 deletions
@@ -25,7 +25,7 @@ def gather_query_group(query, cores_per_kv_head, n_heads, tp_degree):
     # Communication 1: all-gather query from cores
     # Notice that this is not necessary for context encoding because we don't read from the KV cache
     cores_per_q_head = tp_degree // n_heads
-    group_size = cores_per_kv_head // cores_per_q_head if cores_per_q_head else cores_per_kv_head
+    group_size = cores_per_kv_head  # note: this cores_per_kv_head is already divided by cores_per_q_head
     num_groups = tp_degree // group_size
     interleave=False
     n_kv_heads = tp_degree // cores_per_kv_head
@@ -61,9 +61,11 @@ def context(past_scores, active_score, past_values, active_values,
     # How many cores should compute each head collectively
     # All cores that hold the KV cache for the same head should communicate here
     cores_per_kv_head = tp_degree // n_kv_heads
+    cores_per_q_head = tp_degree // n_heads
+    cores_per_kv_head = cores_per_kv_head // cores_per_q_head if cores_per_q_head else cores_per_kv_head
     if cores_per_kv_head > 1:
         group_size = cores_per_kv_head
-        num_groups = n_kv_heads
+        num_groups = tp_degree // group_size
     else:
         # MHA case, assume all cores will have all heads in cache and kv sharded by seq
         num_groups = 1
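
For reference, the two hunks above change how the flash-decoding replica groups are formed once cores_per_kv_head arrives already divided by the Q replication factor. A minimal sketch of the arithmetic in plain Python; the values tp_degree=128, n_heads=64, n_kv_heads=8 are illustrative assumptions, not taken from the commit:

tp_degree = 128   # illustrative TP-128 setup
n_heads = 64      # assumed number of attention (Q) heads
n_kv_heads = 8    # assumed number of KV heads

cores_per_kv_head = tp_degree // n_kv_heads   # 16 cores hold the cache for each KV head
cores_per_q_head = tp_degree // n_heads       # 2 cores per Q head when tp_degree > n_heads

# New behaviour: the Q replication factor is folded into cores_per_kv_head up front
cores_per_kv_head = cores_per_kv_head // cores_per_q_head if cores_per_q_head else cores_per_kv_head  # 8

group_size = cores_per_kv_head            # 8
num_groups_old = n_kv_heads               # 8  (previous code)
num_groups_new = tp_degree // group_size  # 16 (this commit)
print(group_size, num_groups_old, num_groups_new)  # 8 8 16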
@@ -164,7 +166,7 @@ def context(past_scores, active_score, past_values, active_values,
 
     # Communication 3: send the results of other Q heads back to their corresponding cores
     # Also gather the results of the current Q head from other cores
-    assert output.sizes[1] == group_size*size, f"n_heads {n_heads} after gather not matching kv_replication x n_heads_tp {group_size}x {size}"
+    assert output.sizes[1] == group_size*size, f"n_heads {output.sizes[1]} after gather not matching kv_replication x n_heads_tp {group_size}x {size}"
     apply_fn = hlo.gen_add_func(output.dtype)
     output = hlo.reduce_scatter(output, dim=1, replica_groups=replica_groups, to_apply=apply_fn)
     assert output.sizes[1] == size, f"n_heads post scatter size mismatch, check replica_groups {replica_groups}"
@@ -174,8 +176,8 @@ def context(past_scores, active_score, past_values, active_values,
     # multiplied with its corresponding weights, and then an all-reduce is used to sum
     # results for all heads together.
     # We need a scaling here because multiple cores hold the same result
-    if cores_per_q_head:
-        output = hlo.divide(output, cores_per_q_head)
+    #if cores_per_q_head: # we do zero padding now, so enable once replication is done
+    #    output = hlo.divide(output, cores_per_q_head)
     return output
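
The hunk above comments out the per-head rescaling. A minimal NumPy sketch of the reasoning, with made-up shapes, assuming Q heads are now zero-padded across cores instead of replicated:

import numpy as np

cores_per_q_head = 2                      # illustrative replication factor
head_out = np.array([1.0, 2.0, 3.0])      # per-head output as computed on one core

# Replication: every core holds the same full result, so an all-reduce
# over-counts by cores_per_q_head and the old divide was needed.
replicated = [head_out.copy() for _ in range(cores_per_q_head)]
assert np.allclose(sum(replicated) / cores_per_q_head, head_out)

# Zero padding: only one core contributes real values, the rest contribute zeros,
# so the plain sum is already correct and no divide is required.
padded = [head_out.copy()] + [np.zeros_like(head_out) for _ in range(cores_per_q_head - 1)]
assert np.allclose(sum(padded), head_out)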
@@ -291,4 +293,4 @@ def select_values_within_bound(cache_ids, values, keys, cores_per_kv_head, core_
     keys = hlo.slice_along(keys, dim=dim,limit=slice_size,stride=stride)
     cache_ids = hlo.slice_along(cache_ids, dim=cache_dim,limit=cache_slice_size, stride=stride)
 
-    return cache_ids, values, keys
+    return cache_ids, values, keys

src/transformers_neuronx/llama/hlo.py

Lines changed: 6 additions & 4 deletions
@@ -124,6 +124,8 @@ def pre_layer(self, hidden, cache_ids, start_ids, last_token_id, block_tables, c
         n_kv_heads = self.config.num_key_value_heads if hasattr(self.config, "num_key_value_heads") else self.config.num_attention_heads
         cores_per_kv_head = self.config.tp_degree // n_kv_heads
         self.cores_per_kv_head = cores_per_kv_head if cores_per_kv_head > 1 else self.config.tp_degree
+        cores_per_q_head = self.config.tp_degree // self.config.num_attention_heads
+        self.cores_per_kv_head = self.cores_per_kv_head // cores_per_q_head if cores_per_q_head else self.cores_per_kv_head
         if self.neuron_config.optimized_paged_attention and len(last_token_id.sizes) == 2:
             # For decoding with multiple KV cache blocks:
             # - cache_ids are used as context_lens
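
The pre_layer change mirrors the flash_decoding.py adjustment above. A small sketch with a stand-in config (SimpleNamespace and the TP-128 values are assumptions, not the real config object):

from types import SimpleNamespace

config = SimpleNamespace(tp_degree=128, num_attention_heads=64, num_key_value_heads=8)

n_kv_heads = config.num_key_value_heads
cores_per_kv_head = config.tp_degree // n_kv_heads                    # 16
cores_per_kv_head = cores_per_kv_head if cores_per_kv_head > 1 else config.tp_degree

cores_per_q_head = config.tp_degree // config.num_attention_heads     # 2
# With this commit, the stored value is also divided by the Q replication factor
cores_per_kv_head = cores_per_kv_head // cores_per_q_head if cores_per_q_head else cores_per_kv_head
print(cores_per_kv_head)  # 8 instead of 16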
@@ -532,7 +534,7 @@ def attention(
         if self.config.num_key_value_heads is not None:
             n_head = self.config.num_attention_heads
             n_kv_head = self.config.num_key_value_heads
-            n_head, n_kv_head_padded = utils.get_qkv_padding(n_head, n_kv_head, tp_degree, self.neuron_config)
+            n_head_padded, n_kv_head_padded = utils.get_qkv_padding(n_head, n_kv_head, tp_degree, self.neuron_config)
             n_kv_heads_tp = n_kv_head_padded // tp_degree
 
         # Q = (hidden @ wQ) + bQ
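
The rename above presumably keeps the padded head count from clobbering the original n_head, which stays available for later computations. A generic illustration of that pattern; pad_up is a hypothetical helper, not the real utils.get_qkv_padding:

def pad_up(n, multiple):
    # hypothetical helper: round n up to the nearest multiple
    return -(-n // multiple) * multiple

tp_degree = 128
n_head = 96                                 # illustrative unpadded Q-head count
n_head_padded = pad_up(n_head, tp_degree)   # 128, used only for per-core sharding math
heads_per_core = n_head_padded // tp_degree # 1
assert n_head == 96                         # the real head count is not overwritten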
@@ -748,9 +750,9 @@ def attention(
 
         # O = (C @ wO) + bO
         output = attention.output(context, out_weight, out_scales, out_bias, tp_degree, self.neuron_config)
-        cores_per_attn_head = tp_degree // self.config.num_attention_heads
-        if cores_per_attn_head and not self.neuron_config.shard_over_sequence:
-            output = hlo.divide(output, cores_per_attn_head)
+        # we do zero padding so disable now
+        # if cores_per_attn_head and not self.neuron_config.shard_over_sequence:
+        #     output = hlo.divide(output, cores_per_attn_head)
         return output, updated_keys, updated_values
