@@ -25,7 +25,7 @@ def gather_query_group(query, cores_per_kv_head, n_heads, tp_degree):
2525 # Communication 1: all-gather query from cores
2626 # Notice that this is not necessary for context encoding because we don't read from the KV cache
2727 cores_per_q_head = tp_degree // n_heads
28- group_size = cores_per_kv_head // cores_per_q_head if cores_per_q_head else cores_per_kv_head
28+ group_size = cores_per_kv_head # NOTE: cores_per_kv_head has already been divided by cores_per_q_head
2929 num_groups = tp_degree // group_size
3030 interleave = False
3131 n_kv_heads = tp_degree // cores_per_kv_head
@@ -61,9 +61,11 @@ def context(past_scores, active_score, past_values, active_values,
6161 # How many cores should compute each head collectively
6262 # All cores that hold the KV cache for the same head should communicate here
6363 cores_per_kv_head = tp_degree // n_kv_heads
64+ cores_per_q_head = tp_degree // n_heads
65+ cores_per_kv_head = cores_per_kv_head // cores_per_q_head if cores_per_q_head else cores_per_kv_head
6466 if cores_per_kv_head > 1 :
6567 group_size = cores_per_kv_head
66- num_groups = n_kv_heads
68+ num_groups = tp_degree // group_size
6769 else :
6870 # MHA case, assume all cores will have all heads in cache and kv sharded by seq
6971 num_groups = 1
@@ -164,7 +166,7 @@ def context(past_scores, active_score, past_values, active_values,
164166
165167 # Communication 3: send the results of other Q heads back to their corresponding cores
166168 # Also gather the results of the current Q head from other cores
167- assert output .sizes [1 ] == group_size * size , f"n_heads { n_heads } after gather not matching kv_replication x n_heads_tp { group_size } x { size } "
169+ assert output .sizes [1 ] == group_size * size , f"n_heads { output . sizes [ 1 ] } after gather not matching kv_replication x n_heads_tp { group_size } x { size } "
168170 apply_fn = hlo .gen_add_func (output .dtype )
169171 output = hlo .reduce_scatter (output , dim = 1 , replica_groups = replica_groups , to_apply = apply_fn )
170172 assert output .sizes [1 ] == size , f"n_heads post scatter size mismatch, check replica_groups { replica_groups } "
@@ -174,8 +176,8 @@ def context(past_scores, active_score, past_values, active_values,
174176 # multiplied with its corresponding weights, and then an all-reduce is used to sum
175177 # results for all heads together.
176178 # We need a scaling here because multiple cores hold the same result
177- if cores_per_q_head :
178- output = hlo .divide (output , cores_per_q_head )
179+ # if cores_per_q_head: # we now zero-pad instead of replicating; re-enable this scaling once replication is implemented
180+ # output = hlo.divide(output, cores_per_q_head)
179181 return output
180182
181183
@@ -291,4 +293,4 @@ def select_values_within_bound(cache_ids, values, keys, cores_per_kv_head, core_
291293 keys = hlo .slice_along (keys , dim = dim ,limit = slice_size ,stride = stride )
292294 cache_ids = hlo .slice_along (cache_ids , dim = cache_dim ,limit = cache_slice_size , stride = stride )
293295
294- return cache_ids , values , keys
296+ return cache_ids , values , keys
0 commit comments