@@ -143,7 +143,7 @@ def pre_layer(self, hidden, cache_ids, start_ids, last_token_id, block_tables, c
         # - cache_ids are used as position_ids of each token
         # - start_ids are used as slot_mapping
         # - last_token_id is used as new token length for each sequence
-        context_lens_2d = hlo.unsqueeze(context_lens, 1)
+        context_lens_2d = hlo.unsqueeze(context_lens, 1)
         seq_lens = hlo.add(context_lens, last_token_id)
         block_size = self.neuron_config.continuous_batching.block_size
         if self.neuron_config.shard_over_sequence:
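Note: the comment block above repurposes the decoder's regular inputs for chunked prefill. A toy sketch of that bookkeeping in plain Python (values are illustrative only, not taken from any real kernel):

    import numpy as np

    # Two sequences being chunk-prefilled together (made-up sizes).
    context_lens  = np.array([5, 3])        # tokens already sitting in the paged KV cache
    last_token_id = np.array([4, 2])        # reused here as the number of new tokens per sequence
    cache_ids     = np.array([[5, 6, 7, 8, 3, 4]])        # reused as position ids of the new tokens
    start_ids     = np.array([[21, 22, 23, 24, 40, 41]])  # reused as slot_mapping into the KV cache

    context_lens_2d = context_lens[:, None]     # mirrors hlo.unsqueeze(context_lens, 1)
    seq_lens = context_lens + last_token_id     # total length per sequence -> [9, 5]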
@@ -159,7 +159,7 @@ def pre_layer(self, hidden, cache_ids, start_ids, last_token_id, block_tables, c
             max_num_keys = (self.num_active_blocks + 1) * sharded_block_size
             _, n_active_tokens = cache_ids.sizes
             cached_to_contexted, cached_to_contexted_idx, active_to_contexted, sharded_seq_lens = attention_utils.sharded_kv_indexing(
-                seq_lens, last_token_id, cache_ids, max_num_keys, n_active_tokens, block_size, block_tables, core_sos_rank, active_token_mask,
+                seq_lens, last_token_id, cache_ids, max_num_keys, n_active_tokens, block_size, block_tables, core_sos_rank, active_token_mask,
                 sos_degree=cores_per_kv_head
             )
         else:
@@ -180,7 +180,7 @@ def pre_layer(self, hidden, cache_ids, start_ids, last_token_id, block_tables, c
             rope_scaling=self.config.rope_scaling
         )
 
-        # flash decoding
+        # flash decoding
         if self.neuron_config.shard_over_sequence and not self.neuron_config.enable_chunked_prefill:
             cache_ids, mask, active_mask = flash_decoding.convert_attn_mask_and_cache_id(cache_ids, start_ids,
                                                                                          core_id, self.n_positions,
@@ -559,9 +559,15 @@ def attention(
 
         if (active_mask is None and not self.neuron_config.enable_chunked_prefill) and self.neuron_config.shard_over_sequence and self.neuron_config.duplicate_q_weight_sos:
             # slice on the computed query when SOS and duplicate Q weights are on
-            kv_replication_constant = core_id.dtype.Constant(constant_value=self.neuron_config.kv_replication)
-            slice_start = hlo.remainder(hlo.reshape(core_id, []), kv_replication_constant)
+
+            # q / kv -> number of q per core after replication
+            # core_id % tp/kv -> kv replication degree on cores
+            # q / tp -> actual q per core before replication
+            slice_start = hlo.remainder(hlo.reshape(core_id, []), core_id.dtype.Constant(constant_value=self.neuron_config.kv_replication))
             slice_size = self.neuron_config.n_head_padded // tp_degree
+
+            slice_start = hlo.multiply(slice_start, slice_start.dtype.Constant(constant_value=slice_size))
+
             query = hlo.dynamic_slice_along(query, 2, start=slice_start, size=slice_size)
 
         # Q = Rotate(Q)
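The substantive change in the hunk above: with duplicate_q_weight_sos, each core materializes the full replicated set of query heads, so the dynamic-slice start must be the core's rank within its KV-replication group multiplied by the per-core head count; the removed code only took the remainder. A minimal sketch of the index arithmetic in plain Python, mirroring the HLO names (the example values are made up):

    def query_slice(core_id: int, kv_replication: int, n_head_padded: int, tp_degree: int):
        """Return (start_head, num_heads) of the query slice a core keeps after duplication."""
        slice_size = n_head_padded // tp_degree     # query heads a core owns before replication
        rank_in_group = core_id % kv_replication    # rank within the KV-replication group
        slice_start = rank_in_group * slice_size    # the fix: scale the rank into a head offset
        return slice_start, slice_size

    # e.g. 64 padded query heads, tp_degree=32, KV replicated across 4 cores:
    # core 5 has rank 1 in its group and keeps heads [2, 4)
    print(query_slice(core_id=5, kv_replication=4, n_head_padded=64, tp_degree=32))  # (2, 2)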
@@ -681,7 +687,7 @@ def attention(
             query = flash_decoding.gather_query_group(query, self.cores_per_kv_head, n_head, tp_degree)
             # S = Q @ K (This matmul wastes some computation)
             contexted_keys = attention_utils.gather_sharded_kv(cached_keys, active_idx=cached_to_contexted, active_tokens=key, active_token_idx=active_to_contexted)
-            score = attention.score(query, contexted_keys, n_kv_heads=self.config.num_key_value_heads,
+            score = attention.score(query, contexted_keys, n_kv_heads=self.config.num_key_value_heads,
                                     tp_degree=tp_degree, neuron_config=self.neuron_config)
             score = attention.mask(score, mask, tp_degree=tp_degree)
             # FlashAttention-Style Communication
@@ -698,11 +704,11 @@ def attention(
             context = attention.context_combined(score, contexted_values, n_kv_heads=self.config.num_key_value_heads, dtype=score.scribe.f32,
                                                  tp_degree=tp_degree, neuron_config=self.neuron_config, skip_softmax=True)
             # Communication 2: softmax correction
-            context = attention_utils.sharded_softmax_correction(context, max_score_local, l_sum_score_local, core_id, tp_degree=tp_degree,
+            context = attention_utils.sharded_softmax_correction(context, max_score_local, l_sum_score_local, core_id, tp_degree=tp_degree,
                                                                  sos_degree=self.cores_per_kv_head)
             # Communication 3: reduce-scatter partial context
             num_groups = tp_degree // self.cores_per_kv_head
-            replica_groups = utils.build_replica_groups(num_groups=num_groups, group_size=self.cores_per_kv_head,
+            replica_groups = utils.build_replica_groups(num_groups=num_groups, group_size=self.cores_per_kv_head,
                                                         interleave=False)
             context = hlo.reduce_scatter_sum(context, tp_degree=self.cores_per_kv_head, dim=2, replica_groups=replica_groups)
             context = hlo.cast(context, hidden.dtype)
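For context on the communication steps above: the combine appears to follow the standard flash-attention recipe, where each sequence shard produces an unnormalized partial context plus its local running max and exp-sum, and the partials are rescaled against the global max before being summed (here via a reduce-scatter within each group of cores sharing a KV head). A numpy sketch of that correction, detached from the HLO and collective machinery (this is an assumption about what sharded_softmax_correction computes, not its actual implementation):

    import numpy as np

    def combine_sharded_attention(partial_ctx, local_max, local_sum):
        # partial_ctx: (shards, heads, d_head) unnormalized contexts (skip_softmax=True analogue)
        # local_max:   (shards, heads) per-shard running max of the scores
        # local_sum:   (shards, heads) per-shard sum of exp(score - local_max)
        global_max = local_max.max(axis=0)                  # (heads,)
        scale = np.exp(local_max - global_max)              # per-shard correction factor
        normalizer = (local_sum * scale).sum(axis=0)        # corrected global softmax denominator
        corrected = partial_ctx * scale[..., None]          # rescale each shard's numerator
        return corrected.sum(axis=0) / normalizer[..., None]  # what the reduce-scatter sum realizes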