@@ -154,25 +154,20 @@ void HWY_INLINE SingleFlashAttentionStep(float x, float cap, float& old_max,
 
 // Calculates the complete attention outputs for a single row of q.
 void SingleFlashAttention(const size_t start_pos, const size_t last_pos,
-                          const float* HWY_RESTRICT q, const MatPtrT<KV_t>& k,
+                          const BF16* HWY_RESTRICT q, const MatPtrT<KV_t>& k,
                           const MatPtrT<KV_t>& v, const size_t layer_idx,
                           const AttentionActivationsPtrs& activations,
                           float* HWY_RESTRICT att_out, ThreadingContext& ctx,
                           const size_t worker) {
   GCPP_ZONE(ctx, worker, Zones::kFlashAttentionSingleFlashAttention);
   const hn::ScalableTag<BF16> dbf;
   const size_t qkv_dim = k.Cols();
-  HWY_ALIGN BF16 q_bf[kMaxQKVDim];
 
-  CompressPerThread tls;
-  const hn::ScalableTag<float> df;
-  CompressTraits<BF16>::Compress(df, q, qkv_dim, tls, MakeSpan(q_bf, qkv_dim),
-                                 0);
   const size_t pos_mod = activations.div_seq_len.Remainder(start_pos);
   // TODO: Mixed-mode can be further improved for Turin: we can demote right
   // before we do the dot product instruction, rather than promote both to f32.
   // But some potential accuracy loss there, needs evaluation first.
-  float m = Dot(dbf, MakeConstSpan(q_bf, qkv_dim), 0, k.Row(pos_mod), qkv_dim);
+  float m = Dot(dbf, MakeConstSpan(q, qkv_dim), 0, k.Row(pos_mod), qkv_dim);
   if (float cap = activations.config.att_cap; cap > 0.0f) {
     // Compute tanh(x / cap) * cap, being LogitsSoftCap on the scalar x.
     m = cap * std::tanh(m / cap);
@@ -182,8 +177,7 @@ void SingleFlashAttention(const size_t start_pos, const size_t last_pos,
   MulByConstTo(d, v.Row(pos_mod), att_out, v.Cols(), ctx, worker);
   for (size_t pos = start_pos + 1; pos <= last_pos; ++pos) {
     const size_t pos_mod = activations.div_seq_len.Remainder(pos);
-    float x =
-        Dot(dbf, MakeConstSpan(q_bf, qkv_dim), 0, k.Row(pos_mod), qkv_dim);
+    float x = Dot(dbf, MakeConstSpan(q, qkv_dim), 0, k.Row(pos_mod), qkv_dim);
     SingleFlashAttentionStep(x, activations.config.att_cap, m, d,
                              v.Row(pos_mod), v.Cols(), att_out);
   }
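
For readers following the math: SingleFlashAttentionStep (named in the hunk header above but not itself part of this diff) is the usual streaming-softmax accumulation over one K/V timestep. Below is a minimal standalone sketch of that kind of update; the name OnlineSoftmaxStep, the plain float value row, and the placement of soft-capping are assumptions for illustration, not the repository's actual implementation.

#include <algorithm>
#include <cmath>
#include <cstddef>

// Illustrative only: fold one timestep with score x (already soft-capped)
// and value row v into the running max m, running denominator d, and the
// unnormalized accumulator att_out. Dividing att_out by d after the last
// timestep yields the softmax-weighted sum of the value rows.
void OnlineSoftmaxStep(float x, float& m, float& d, const float* v,
                       size_t v_cols, float* att_out) {
  const float new_max = std::max(m, x);
  const float scale = std::exp(m - new_max);  // rescales previous contributions
  const float p = std::exp(x - new_max);      // weight of the new timestep
  for (size_t i = 0; i < v_cols; ++i) {
    att_out[i] = att_out[i] * scale + p * v[i];
  }
  d = d * scale + p;
  m = new_max;
}
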
@@ -193,19 +187,15 @@ void SingleFlashAttention(const size_t start_pos, const size_t last_pos,
 // the dot products of NF rows of Q for a single K timestep.
 template <class DF, class VF = hn::Vec<DF>>
 VF QDotKVector(DF df, const uint32_t* HWY_RESTRICT q_offsets,
-               const size_t k_pos, const MatPtrT<float>& q,
+               const size_t k_pos, const MatPtrT<BF16>& q,
                const MatPtrT<KV_t>& k) {
   const hn::ScalableTag<BF16> dbf;
   const size_t qkv_dim = k.Cols();
-  HWY_ALIGN BF16 q_bf[kMaxQKVDim];
-  CompressPerThread tls;
 
   hn::TFromD<DF> results[hn::MaxLanes(df)];
   for (size_t i = 0; i < hn::Lanes(df); ++i) {
-    CompressTraits<BF16>::Compress(df, q.Row(0) + q_offsets[i], qkv_dim, tls,
-                                   MakeSpan(q_bf, qkv_dim), 0);
-    results[i] =
-        Dot(dbf, MakeConstSpan(q_bf, qkv_dim), 0, k.Row(k_pos), qkv_dim);
+    results[i] = Dot(dbf, MakeConstSpan(q.Row(0) + q_offsets[i], qkv_dim), 0,
+                     k.Row(k_pos), qkv_dim);
   }
   return hn::LoadU(df, results);
 }
@@ -290,7 +280,7 @@ VF HWY_INLINE ElementwiseSumOf8(DF df, const VF& x0, const VF& x1, const VF& x2,
 // min_last_pos, then sweeps the remaining timesteps in the range (min_last_pos,
 // max_last_pos].
 void TileFlashAttention(
-    const MatPtrT<float>& q, const uint32_t* HWY_RESTRICT q_offsets,
+    const MatPtrT<BF16>& q, const uint32_t* HWY_RESTRICT q_offsets,
     const StridedView<BF16>& qT, const MatPtrT<KV_t>& k, const size_t start_pos,
     const uint32_t* HWY_RESTRICT last_pos, const size_t min_last_pos,
     const size_t max_last_pos, const MatPtrT<KV_t>& v, const size_t layer_idx,
@@ -396,7 +386,7 @@ void TileFlashAttention(
 // This is the result of 4 rows of Q against NF K timesteps, with positions
 // given by k_offsets[0..NF].
 template <class DF, class VF = hn::Vec<DF>>
-void QDotKTilex4(DF df, const float* HWY_RESTRICT q,
+void QDotKTilex4(DF df, const BF16* HWY_RESTRICT q,
                  const uint32_t* HWY_RESTRICT q_offsets, const MatPtrT<KV_t>& k,
                  const int32_t* HWY_RESTRICT k_offsets, VF& sum0, VF& sum1,
                  VF& sum2, VF& sum3) {
@@ -411,17 +401,13 @@ void QDotKTilex4(DF df, const float* HWY_RESTRICT q,
   VI k_offsets_vec = hn::LoadU(di, k_offsets);
   for (size_t i = 0; i < k.Cols(); ++i) {
     VF k_vec = hn::GatherIndex(df, k_base + i, k_offsets_vec);
-    VF q_0 = hn::Set(df, hwy::ConvertScalarTo<float>(
-                             hwy::ConvertScalarTo<BF16>(q[q_offsets[0] + i])));
+    VF q_0 = hn::Set(df, hwy::ConvertScalarTo<float>(q[q_offsets[0] + i]));
     sum0 = hn::MulAdd(q_0, k_vec, sum0);
-    VF q_1 = hn::Set(df, hwy::ConvertScalarTo<float>(
-                             hwy::ConvertScalarTo<BF16>(q[q_offsets[1] + i])));
+    VF q_1 = hn::Set(df, hwy::ConvertScalarTo<float>(q[q_offsets[1] + i]));
     sum1 = hn::MulAdd(q_1, k_vec, sum1);
-    VF q_2 = hn::Set(df, hwy::ConvertScalarTo<float>(
-                             hwy::ConvertScalarTo<BF16>(q[q_offsets[2] + i])));
+    VF q_2 = hn::Set(df, hwy::ConvertScalarTo<float>(q[q_offsets[2] + i]));
     sum2 = hn::MulAdd(q_2, k_vec, sum2);
-    VF q_3 = hn::Set(df, hwy::ConvertScalarTo<float>(
-                             hwy::ConvertScalarTo<BF16>(q[q_offsets[3] + i])));
+    VF q_3 = hn::Set(df, hwy::ConvertScalarTo<float>(q[q_offsets[3] + i]));
     sum3 = hn::MulAdd(q_3, k_vec, sum3);
   }
 }
@@ -446,7 +432,7 @@ float HWY_INLINE SingleFlashAttentionRowVector(DF df, VF& x, float& old_max,
 // min_last_pos, then sweeps the remaining timesteps in the range (min_last_pos,
 // max_last_pos].
 Tile4FlashState TileFlashAttention4(
-    const MatPtrT<float>& q, const uint32_t* HWY_RESTRICT q_offsets,
+    const MatPtrT<BF16>& q, const uint32_t* HWY_RESTRICT q_offsets,
     const MatPtrT<KV_t>& k, const size_t start_pos,
     const uint32_t* HWY_RESTRICT last_pos, const size_t min_last_pos,
     const size_t max_last_pos, const MatPtrT<KV_t>& v, const size_t layer_idx,
@@ -500,51 +486,40 @@ Tile4FlashState TileFlashAttention4(
   }
   const hn::ScalableTag<BF16> dbf;
   const size_t qkv_dim = k.Cols();
-  HWY_ALIGN BF16 q_bf[kMaxQKVDim];
-  CompressPerThread tls;
-  const hn::ScalableTag<float> df_compress;
 
   while (position <= max_last_pos) {
     size_t k_pos = activations.div_seq_len.Remainder(position);
     if (position <= last_pos[0]) {
       // Past the last position, x0 doesn't count.
-      CompressTraits<BF16>::Compress(df_compress, q.Row(0) + q_offsets[0],
-                                     qkv_dim, tls, MakeSpan(q_bf, qkv_dim), 0);
-      float x0 =
-          Dot(dbf, MakeConstSpan(q_bf, qkv_dim), 0, k.Row(k_pos), qkv_dim);
+      float x0 = Dot(dbf, MakeConstSpan(q.Row(0) + q_offsets[0], qkv_dim), 0,
+                     k.Row(k_pos), qkv_dim);
       SingleFlashAttentionStep(x0, activations.config.att_cap,
                                state.row_states[0].max, state.row_states[0].d,
                                v.Row(k_pos), v.Cols(),
                                att_out.Row(0) + out_offsets[0]);
     }
     if (position <= last_pos[1]) {
       // Past the last position, x1 doesn't count.
-      CompressTraits<BF16>::Compress(df_compress, q.Row(0) + q_offsets[1],
-                                     qkv_dim, tls, MakeSpan(q_bf, qkv_dim), 0);
-      float x1 =
-          Dot(dbf, MakeConstSpan(q_bf, qkv_dim), 0, k.Row(k_pos), qkv_dim);
+      float x1 = Dot(dbf, MakeConstSpan(q.Row(0) + q_offsets[1], qkv_dim), 0,
+                     k.Row(k_pos), qkv_dim);
       SingleFlashAttentionStep(x1, activations.config.att_cap,
                                state.row_states[1].max, state.row_states[1].d,
                                v.Row(k_pos), v.Cols(),
                                att_out.Row(0) + out_offsets[1]);
     }
     if (position <= last_pos[2]) {
       // Past the last position, x2 doesn't count.
-      CompressTraits<BF16>::Compress(df_compress, q.Row(0) + q_offsets[2],
-                                     qkv_dim, tls, MakeSpan(q_bf, qkv_dim), 0);
-      float x2 =
-          Dot(dbf, MakeConstSpan(q_bf, qkv_dim), 0, k.Row(k_pos), qkv_dim);
+      float x2 = Dot(dbf, MakeConstSpan(q.Row(0) + q_offsets[2], qkv_dim), 0,
+                     k.Row(k_pos), qkv_dim);
       SingleFlashAttentionStep(x2, activations.config.att_cap,
                                state.row_states[2].max, state.row_states[2].d,
                                v.Row(k_pos), v.Cols(),
                                att_out.Row(0) + out_offsets[2]);
     }
     if (position <= last_pos[3]) {
       // Past the last position, x3 doesn't count.
-      CompressTraits<BF16>::Compress(df_compress, q.Row(0) + q_offsets[3],
-                                     qkv_dim, tls, MakeSpan(q_bf, qkv_dim), 0);
-      float x3 =
-          Dot(dbf, MakeConstSpan(q_bf, qkv_dim), 0, k.Row(k_pos), qkv_dim);
+      float x3 = Dot(dbf, MakeConstSpan(q.Row(0) + q_offsets[3], qkv_dim), 0,
+                     k.Row(k_pos), qkv_dim);
       SingleFlashAttentionStep(x3, activations.config.att_cap,
                                state.row_states[3].max, state.row_states[3].d,
                                v.Row(k_pos), v.Cols(),
@@ -642,6 +617,17 @@ void FlashAttention(const size_t num_tokens, const size_t target_parallelism,
   RMSNormAndPositionalEncoding(num_tokens, qbatch, activations.q,
                                query_norm_scale, layer_idx, activations, ctx);
   const hwy::Divisor div_qbatch(qbatch.Size());
+  // Compress q to q_bf.
+  ParallelFor(
+      ParallelismStrategy::kWithinCluster, activations.q.Rows(), ctx,
+      /*cluster_idx=*/0, Callers::kFlashAttention,
+      [&](size_t row, size_t worker) {
+        CompressPerThread tls;
+        const hn::ScalableTag<float> df;
+        CompressTraits<BF16>::Compress(
+            df, activations.q.Row(row), activations.q.Cols(), tls,
+            MakeSpan(activations.q_bf.Row(row), activations.q_bf.Cols()), 0);
+      });
   const LayerConfig& layer_config = activations.config.layer_configs[layer_idx];
   const size_t qkv_dim = layer_config.qkv_dim;
 
@@ -736,8 +722,8 @@ void FlashAttention(const size_t num_tokens, const size_t target_parallelism,
       last_pos[offset] = last;
       min_last_pos = HWY_MIN(min_last_pos, last);
       max_last_pos = HWY_MAX(max_last_pos, last);
-      q_offsets[offset] =
-          activations.q.Row(tq_idx) + head * qkv_dim - activations.q.Row(0);
+      q_offsets[offset] = activations.q_bf.Row(tq_idx) + head * qkv_dim -
+                          activations.q_bf.Row(0);
       out_offsets[offset] = activations.att_out.Row(tq_idx) + head * qkv_dim -
                             activations.att_out.Row(0);
       const size_t kv_index = head / kHeadGroups;
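
Taken together with the new ParallelFor hunk above, the pattern is: compress the query activations from float to BF16 once per layer into activations.q_bf, then let every q_offset index into that BF16 buffer instead of re-running CompressTraits<BF16>::Compress inside each dot-product loop. A minimal sketch of that hoisting is below, with stand-in types and a serial loop in place of ParallelFor; BF16Stub, FloatToBF16 and CompressAllQueryRows are illustrative names, not the repository's API, and the conversion here truncates where the real Compress rounds.

#include <cstddef>
#include <cstdint>
#include <cstring>
#include <vector>

using BF16Stub = uint16_t;  // stand-in for the real BF16 type

// Truncating float -> bf16 conversion: keep the top 16 bits of the float's
// bit pattern. Just enough to show the data flow.
inline BF16Stub FloatToBF16(float f) {
  uint32_t bits;
  std::memcpy(&bits, &f, sizeof(bits));
  return static_cast<BF16Stub>(bits >> 16);
}

// Hoisted compression: convert every query row once, before the attention
// loops run. In the diff this is a ParallelFor over rows within a cluster.
void CompressAllQueryRows(const float* q, size_t rows, size_t cols,
                          std::vector<BF16Stub>& q_bf) {
  q_bf.resize(rows * cols);
  for (size_t row = 0; row < rows; ++row) {
    for (size_t col = 0; col < cols; ++col) {
      q_bf[row * cols + col] = FloatToBF16(q[row * cols + col]);
    }
  }
}

The payoff visible in the hunks above is that SingleFlashAttention, QDotKVector, QDotKTilex4 and the tile kernels no longer carry a per-call q_bf scratch buffer or CompressPerThread state, and each query row is converted exactly once rather than once per K timestep or per vector lane.
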
@@ -776,12 +762,12 @@ void FlashAttention(const size_t num_tokens, const size_t target_parallelism,
           // kNFx8HTileSize. In this case, qT is never used. Some tasks might
           // use qT and some might not, which is why the more general condition
           // is used above to catch all cases where qT will be used.
-          TileFlashAttention(activations.q, q_offsets, qT, k,
+          TileFlashAttention(activations.q_bf, q_offsets, qT, k,
                              start_positions[offset], last_pos, min_last_pos,
                              max_last_pos, v, layer_idx, activations,
                              activations.att_out, out_offsets, ctx, worker);
         } else if (kVTileSize == 4) {
-          TileFlashAttention4(activations.q, q_offsets, k,
+          TileFlashAttention4(activations.q_bf, q_offsets, k,
                               start_positions[offset], last_pos, min_last_pos,
                               max_last_pos, v, layer_idx, activations,
                               activations.att_out, out_offsets, ctx, worker);
@@ -791,7 +777,7 @@ void FlashAttention(const size_t num_tokens, const size_t target_parallelism,
            break;
          } else {
            SingleFlashAttention(start_positions[offset], last_pos[offset],
-                                activations.q.Row(0) + q_offsets[offset], k, v,
+                                activations.q_bf.Row(0) + q_offsets[offset], k, v,
                                 layer_idx, activations,
                                 activations.att_out.Row(0) + out_offsets[offset],
                                 ctx, worker);