Commit d659cce

aws-bowencc authored and aws-yishanm committed
fix duplicate q when q/tp!=1
GitOrigin-RevId: bc5163b8e710ac3076650a3e391d7cd4f6e4feca
1 parent 66d4231 commit d659cce

File tree

1 file changed: +14 -8 lines changed

  • src/transformers_neuronx/llama/hlo.py


src/transformers_neuronx/llama/hlo.py

Lines changed: 14 additions & 8 deletions
@@ -143,7 +143,7 @@ def pre_layer(self, hidden, cache_ids, start_ids, last_token_id, block_tables, c
         # - cache_ids are used as position_ids of each token
         # - start_ids are used as slot_mapping
         # - last_token_id is used as new token length for each sequence
-        context_lens_2d = hlo.unsqueeze(context_lens, 1)
+        context_lens_2d = hlo.unsqueeze(context_lens, 1)
         seq_lens = hlo.add(context_lens, last_token_id)
         block_size = self.neuron_config.continuous_batching.block_size
         if self.neuron_config.shard_over_sequence:
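
For intuition, the context lines above describe how chunked prefill reuses the decoder inputs. A minimal NumPy sketch of the length arithmetic (illustrative names, not the hlo.py code):

    import numpy as np

    # Two sequences mid-prefill: tokens already in the KV cache, plus the
    # number of new tokens arriving in this chunk (the role of last_token_id).
    context_lens = np.array([10, 3])    # cached tokens per sequence
    new_token_lens = np.array([4, 1])   # "last_token_id is used as new token length"

    seq_lens = context_lens + new_token_lens   # hlo.add(context_lens, last_token_id) -> [14, 4]
    context_lens_2d = context_lens[:, None]    # hlo.unsqueeze(context_lens, 1), shape (2, 1)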
@@ -159,7 +159,7 @@ def pre_layer(self, hidden, cache_ids, start_ids, last_token_id, block_tables, c
             max_num_keys = (self.num_active_blocks + 1) * sharded_block_size
             _, n_active_tokens = cache_ids.sizes
             cached_to_contexted, cached_to_contexted_idx, active_to_contexted, sharded_seq_lens = attention_utils.sharded_kv_indexing(
-                seq_lens, last_token_id, cache_ids, max_num_keys, n_active_tokens, block_size, block_tables, core_sos_rank, active_token_mask,
+                seq_lens, last_token_id, cache_ids, max_num_keys, n_active_tokens, block_size, block_tables, core_sos_rank, active_token_mask,
                 sos_degree=cores_per_kv_head
             )
         else:
@@ -180,7 +180,7 @@ def pre_layer(self, hidden, cache_ids, start_ids, last_token_id, block_tables, c
                 rope_scaling=self.config.rope_scaling
             )

-        # flash decoding
+        # flash decoding
         if self.neuron_config.shard_over_sequence and not self.neuron_config.enable_chunked_prefill:
             cache_ids, mask, active_mask = flash_decoding.convert_attn_mask_and_cache_id(cache_ids, start_ids,
                                                                                          core_id, self.n_positions,
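
The flash-decoding branch above remaps cache ids and the attention mask so each core attends only over its shard of the sequence. A toy sketch of one plausible interleaved layout (an assumption for illustration; flash_decoding.convert_attn_mask_and_cache_id performs the real mapping):

    import numpy as np

    n_positions, n_cores = 8, 2    # toy sizes; real values come from the config
    core_id = 1                    # this core's rank within the KV-shard group

    positions = np.arange(n_positions)
    # Assumed layout: core c owns the global positions p with p % n_cores == c.
    owned = positions[positions % n_cores == core_id]   # -> [1, 3, 5, 7]
    local_cache_ids = owned // n_cores                  # per-core cache slots -> [0, 1, 2, 3]
    # Causality is preserved per shard: a query at global position q may only
    # attend to owned key positions <= q.
    mask = owned[None, :] <= positions[:, None]         # shape (n_positions, owned.size)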
@@ -559,9 +559,15 @@ def attention(

         if (active_mask is None and not self.neuron_config.enable_chunked_prefill) and self.neuron_config.shard_over_sequence and self.neuron_config.duplicate_q_weight_sos:
             # slice on computed query when sos and duplicate Q weights is on
-            kv_replication_constant = core_id.dtype.Constant(constant_value=self.neuron_config.kv_replication)
-            slice_start = hlo.remainder(hlo.reshape(core_id,[]), kv_replication_constant)
+
+            # q / kv -> number of q per core after replication
+            # core_id % tp/kv -> kv replication degree on cores
+            # q / tp -> actual q per core before replication
+            slice_start = hlo.remainder(hlo.reshape(core_id,[]), core_id.dtype.Constant(constant_value=self.neuron_config.kv_replication))
             slice_size = self.neuron_config.n_head_padded // tp_degree
+
+            slice_start = hlo.multiply(slice_start, slice_start.dtype.Constant(constant_value=slice_size))
+
             query = hlo.dynamic_slice_along(query, 2, start=slice_start, size=slice_size)

             # Q = Rotate(Q)
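
The hunk above is the actual fix. With duplicate_q_weight_sos, every core in a KV-replication group computes all kv_replication * slice_size query heads of its group, then slices out its own slice_size heads. The old code used the core's rank within the group directly as the head offset, which is only correct when slice_size == 1, i.e. q/tp == 1. A minimal NumPy sketch of the arithmetic (example sizes are assumptions):

    import numpy as np

    n_head_padded, tp_degree, kv_replication = 64, 16, 4
    slice_size = n_head_padded // tp_degree        # 4 Q heads per core

    # Stand-in for the head axis (axis 2) of the duplicated query tensor.
    heads = np.arange(kv_replication * slice_size)

    for core_id in range(kv_replication):          # ranks within one KV group
        rank = core_id % kv_replication
        old = heads[rank:rank + slice_size]                       # pre-fix start
        new = heads[rank * slice_size:(rank + 1) * slice_size]    # post-fix start
        print(rank, old, new)

For rank 1 the old slice is [1 2 3 4], overlapping rank 0's heads, while the fixed slice is [4 5 6 7]. The two starts coincide exactly when slice_size == 1, which is why the bug only surfaced when q/tp != 1.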
@@ -681,7 +687,7 @@ def attention(
             query = flash_decoding.gather_query_group(query, self.cores_per_kv_head, n_head, tp_degree)
             # S = Q @ K (This matmul wastes some computation)
             contexted_keys = attention_utils.gather_sharded_kv(cached_keys, active_idx=cached_to_contexted, active_tokens=key, active_token_idx=active_to_contexted)
-            score = attention.score(query, contexted_keys, n_kv_heads=self.config.num_key_value_heads,
+            score = attention.score(query, contexted_keys, n_kv_heads=self.config.num_key_value_heads,
                                     tp_degree=tp_degree, neuron_config=self.neuron_config)
             score = attention.mask(score, mask, tp_degree=tp_degree)
             # FlashAttention-Style Communication
@@ -698,11 +704,11 @@ def attention(
             context = attention.context_combined(score, contexted_values, n_kv_heads=self.config.num_key_value_heads, dtype=score.scribe.f32,
                                                  tp_degree=tp_degree, neuron_config=self.neuron_config, skip_softmax=True)
             # Communication 2: softmax correction
-            context = attention_utils.sharded_softmax_correction(context, max_score_local, l_sum_score_local, core_id, tp_degree=tp_degree,
+            context = attention_utils.sharded_softmax_correction(context, max_score_local, l_sum_score_local, core_id, tp_degree=tp_degree,
                                                                  sos_degree=self.cores_per_kv_head)
             # Communication 3: reduce-scatter partial context
             num_groups = tp_degree // self.cores_per_kv_head
-            replica_groups = utils.build_replica_groups(num_groups=num_groups, group_size=self.cores_per_kv_head,
+            replica_groups = utils.build_replica_groups(num_groups=num_groups, group_size=self.cores_per_kv_head,
                                                         interleave=False)
             context = hlo.reduce_scatter_sum(context, tp_degree=self.cores_per_kv_head, dim=2, replica_groups=replica_groups)
             context = hlo.cast(context, hidden.dtype)
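
The replica groups above partition the tp cores into contiguous, non-interleaved groups of cores_per_kv_head, so the reduce-scatter sums partial contexts within each KV group. A small sketch of that grouping (illustrative; utils.build_replica_groups is the real helper):

    def contiguous_replica_groups(num_groups, group_size):
        # Non-interleaved: [[0..g-1], [g..2g-1], ...]
        return [[g * group_size + i for i in range(group_size)]
                for g in range(num_groups)]

    tp_degree, cores_per_kv_head = 8, 2   # assumed example values
    print(contiguous_replica_groups(tp_degree // cores_per_kv_head, cores_per_kv_head))
    # [[0, 1], [2, 3], [4, 5], [6, 7]]
    # reduce_scatter_sum then sums within each group and leaves each core with
    # a 1/cores_per_kv_head slice of the context along dim=2.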
