@@ -66,20 +66,12 @@ static HWY_INLINE void QDotK(const size_t start_pos, const size_t last_pos,
   CompressTraits<BF16>::Compress(df, q, qkv_dim, tls, MakeSpan(q_bf, qkv_dim),
                                  0);
 
-  if (HWY_LIKELY(last_pos < static_cast<size_t>(div_seq_len.GetDivisor()))) {
-    // Slightly faster: no wraparound.
-    for (size_t pos = start_pos; pos <= last_pos; ++pos) {
-      const float score =
-          Dot(dbf, MakeConstSpan(q_bf, qkv_dim), 0, k.Row(pos), qkv_dim);
-      att[pos] = score;
-    }
-  } else {
-    for (size_t pos = start_pos; pos <= last_pos; ++pos) {
-      const size_t pos_modulo = div_seq_len.Remainder(pos);
-      const float score =
-          Dot(dbf, MakeConstSpan(q_bf, qkv_dim), 0, k.Row(pos_modulo), qkv_dim);
-      att[pos_modulo] = score;
-    }
+  // --seq_len must be large enough to avoid wraparound.
+  HWY_DASSERT(last_pos < static_cast<size_t>(div_seq_len.GetDivisor()));
+  for (size_t pos = start_pos; pos <= last_pos; ++pos) {
+    const float score =
+        Dot(dbf, MakeConstSpan(q_bf, qkv_dim), 0, k.Row(pos), qkv_dim);
+    att[pos] = score;
   }
 }
 
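Context for this hunk: `div_seq_len` is a `hwy::Divisor`, so `Remainder(pos)` was a fast `pos % seq_len` that let the KV cache behave as a ring buffer. The new code asserts instead of wrapping. Below is a minimal standalone sketch of the resulting QDotK loop, with plain arithmetic in place of `hwy::Divisor` and the vectorized `Dot`; names and types are illustrative, not the real API.

#include <cassert>
#include <cstddef>
#include <vector>

// Dot product of the query against one cached K row.
static float DotRow(const std::vector<float>& q,
                    const std::vector<std::vector<float>>& k, size_t row) {
  float sum = 0.0f;
  for (size_t i = 0; i < q.size(); ++i) sum += q[i] * k[row][i];
  return sum;
}

// After this change, positions index K rows directly; an out-of-range
// position is a caller error rather than a silent wraparound.
void QDotKSketch(size_t start_pos, size_t last_pos, size_t seq_len,
                 const std::vector<float>& q,
                 const std::vector<std::vector<float>>& k, float* att) {
  assert(last_pos < seq_len);  // previously: pos_modulo = pos % seq_len
  for (size_t pos = start_pos; pos <= last_pos; ++pos) {
    att[pos] = DotRow(q, k, pos);
  }
}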
@@ -114,25 +106,13 @@ static HWY_INLINE void WeightedSumV(
     const hwy::Divisor& div_seq_len, const float* HWY_RESTRICT att,
     const MatPtrT<KV_t>& v, float* HWY_RESTRICT att_out, ThreadingContext& ctx,
     const size_t worker) {
-  if (HWY_LIKELY(last_pos < static_cast<size_t>(div_seq_len.GetDivisor()))) {
-    // Slightly faster: no wraparound. Could be replaced with MatMul(att, v) if
-    // we supported non-transposed B.
-    // TODO: 2..4x unroll
-    MulByConstTo(att[start_pos], v.Row(start_pos), att_out, v.Cols(), ctx,
-                 worker);
-    for (size_t pos = start_pos + 1; pos <= last_pos; ++pos) {
-      MulByConstAndAdd(att[pos], v.Row(pos), att_out, v.Cols());
-    }
-  } else {
-    {
-      const size_t pos_mod = div_seq_len.Remainder(start_pos);
-      MulByConstTo(att[pos_mod], v.Row(pos_mod), att_out, v.Cols(), ctx,
-                   worker);
-    }
-    for (size_t pos = start_pos + 1; pos <= last_pos; ++pos) {
-      const size_t pos_mod = div_seq_len.Remainder(pos);
-      MulByConstAndAdd(att[pos_mod], v.Row(pos_mod), att_out, v.Cols());
-    }
+  // --seq_len must be large enough to avoid wraparound.
+  HWY_DASSERT(last_pos < static_cast<size_t>(div_seq_len.GetDivisor()));
+  // TODO: replace with MatMul(att, v) after it supports non-transposed B.
+  MulByConstTo(att[start_pos], v.Row(start_pos), att_out, v.Cols(), ctx,
+               worker);
+  for (size_t pos = start_pos + 1; pos <= last_pos; ++pos) {
+    MulByConstAndAdd(att[pos], v.Row(pos), att_out, v.Cols());
   }
 }
 
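The WeightedSumV loop computes att_out = sum over pos of att[pos] * V[pos] across rows of V, seeding the output with `MulByConstTo` and accumulating with `MulByConstAndAdd`. A scalar reference sketch of that pattern follows; the real functions are vectorized, and the signatures here are simplified assumptions.

#include <cstddef>

// att_out = sum over pos in [start_pos, last_pos] of att[pos] * v_rows[pos].
void WeightedSumVSketch(size_t start_pos, size_t last_pos, const float* att,
                        const float* const* v_rows, size_t cols,
                        float* att_out) {
  // MulByConstTo: write the first weighted row, initializing att_out.
  for (size_t c = 0; c < cols; ++c) {
    att_out[c] = att[start_pos] * v_rows[start_pos][c];
  }
  // MulByConstAndAdd: accumulate the remaining weighted rows.
  for (size_t pos = start_pos + 1; pos <= last_pos; ++pos) {
    for (size_t c = 0; c < cols; ++c) {
      att_out[c] += att[pos] * v_rows[pos][c];
    }
  }
}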
@@ -146,9 +126,10 @@ void SingleDotSoftmaxWeightedSum(
     float* HWY_RESTRICT att_out, ThreadingContext& ctx, const size_t worker) {
   const float att_cap = activations.config.att_cap;
   const float query_scale = activations.query_scale;
-  const size_t seq_len =
-      static_cast<size_t>(activations.div_seq_len.GetDivisor());
+  // --seq_len must be large enough to avoid wraparound.
+  HWY_DASSERT(last_pos < activations.SeqLen());
   const LayerConfig& layer_config = activations.config.layer_configs[layer_idx];
+
   // Apply rope and scaling to Q.
   if (query_norm_scale.HasPtr()) {
     CallUpcasted(&query_norm_scale, [&](const auto* weights_t) {
@@ -163,8 +144,7 @@ void SingleDotSoftmaxWeightedSum(
   QDotK(start_pos, last_pos, activations.div_seq_len, q, k, att, ctx, worker);
 
   // SoftMax with optional SoftCap yields "probabilities" in att.
-  const size_t att_len = HWY_MIN(last_pos + 1, seq_len);
-  const Logits logits(att, att_len);
+  const Logits logits(att, last_pos + 1);
   MaybeLogitsSoftCap(att_cap, logits, ctx, worker);
   Softmax(logits, ctx, worker, /*temperature=*/1.0f);
 
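With wraparound gone, the logits span is simply `last_pos + 1` entries, so the old `HWY_MIN(last_pos + 1, seq_len)` clamp is unnecessary. For reference, here is a scalar sketch of the cap-then-softmax step the two calls perform, assuming the usual Gemma-style soft-cap of cap * tanh(x / cap); the real `MaybeLogitsSoftCap`/`Softmax` are vectorized and this is only an illustration.

#include <algorithm>
#include <cmath>
#include <cstddef>

// Applies optional logit soft-capping, then a numerically stable softmax,
// in place over att[0, att_len).
void SoftCapSoftmaxSketch(float att_cap, float* att, size_t att_len) {
  if (att_cap > 0.0f) {
    for (size_t i = 0; i < att_len; ++i) {
      att[i] = att_cap * std::tanh(att[i] / att_cap);
    }
  }
  float max = att[0];
  for (size_t i = 1; i < att_len; ++i) max = std::max(max, att[i]);
  float sum = 0.0f;
  for (size_t i = 0; i < att_len; ++i) {
    att[i] = std::exp(att[i] - max);  // subtract max for stability
    sum += att[i];
  }
  for (size_t i = 0; i < att_len; ++i) att[i] /= sum;  // temperature = 1.0
}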
@@ -194,8 +174,7 @@ void DotSoftmaxWeightedSum(const size_t num_tokens, const size_t layer_idx,
   const size_t kHeadGroups = layer_config.heads / layer_config.kv_heads;
 
   const size_t cache_layer_size = layer_config.CacheLayerSize();
-  const size_t seq_len =
-      static_cast<size_t>(activations.div_seq_len.GetDivisor());
+  const size_t seq_len = activations.SeqLen();
   // All layers should have the same number of heads.
   HWY_DASSERT(activations.div_heads.GetDivisor() == layer_config.heads);
 
@@ -284,8 +263,10 @@ static HWY_INLINE void ComputeQKV(size_t num_tokens, const size_t layer_idx,
        ++interleaved_idx) {
     const size_t qi = div_qbatch.Remainder(interleaved_idx);
     const size_t batch_idx = div_qbatch.Divide(interleaved_idx);
-    const size_t cache_pos =
-        activations.div_seq_len.Remainder(qbatch.Pos(qi) + batch_idx);
+    const size_t cache_pos = qbatch.Pos(qi) + batch_idx;
+    // --seq_len must be large enough to avoid wraparound.
+    HWY_DASSERT(cache_pos < activations.SeqLen());
+
     env.row_ptrs[0][interleaved_idx] = reinterpret_cast<uint8_t*>(
         qbatch.KV(qi).kv_cache.Row(cache_pos) + layer_idx * cache_layer_size);
   }
@@ -304,8 +285,9 @@ static HWY_INLINE void ComputeQKV(size_t num_tokens, const size_t layer_idx,
     const size_t interleaved_idx = task / kv_heads;
     const size_t qi = div_qbatch.Remainder(interleaved_idx);
     const size_t batch_idx = div_qbatch.Divide(interleaved_idx);
-    const size_t pos = qbatch.Pos(qi) + batch_idx;
-    const size_t cache_pos = activations.div_seq_len.Remainder(pos);
+    const size_t cache_pos = qbatch.Pos(qi) + batch_idx;
+    // --seq_len must be large enough to avoid wraparound.
+    HWY_DASSERT(cache_pos < activations.SeqLen());
     auto& kv_cache = qbatch.KV(qi).kv_cache;
     KV_t* HWY_RESTRICT kv = kv_cache.Row(cache_pos) +
                             layer_idx * cache_layer_size +
@@ -325,7 +307,7 @@ static HWY_INLINE void ComputeQKV(size_t num_tokens, const size_t layer_idx,
     }
 
     PositionalEncodingQK(kv_f32, layer_idx, activations, env.ctx, worker,
-                         pos, /*mul=*/1.0f);
+                         cache_pos, /*mul=*/1.0f);
     CompressPerThread tls;
     Compress(kv_f32, 2 * qkv_dim, tls, MakeSpan(kv, 2 * qkv_dim), 0);
   });
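Net effect of the hunks above: the KV cache is no longer treated as a ring buffer. Positions at or beyond `--seq_len` now trip `HWY_DASSERT` in debug builds (and would presumably index past the cache in release builds) instead of silently overwriting the oldest cache rows, so `--seq_len` must be sized to cover the full prompt plus generated tokens. The `pos` to `cache_pos` rename in the final hunk is consistent with this: the cache row index and the rope position are now the same value.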