
Commit 1fe9efe

pculliton authored and copybara-github committed
Pre-compress query activations to BF16 before FlashAttention.
PiperOrigin-RevId: 826098088
1 parent 8a100c1 commit 1fe9efe
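The change below hoists the float-to-BF16 conversion of query activations out of the attention inner loops: queries are compressed once per row into a new q_bf buffer, and the FlashAttention kernels then read BF16 directly. As a rough, standalone illustration of the idea only (ToBF16, FromBF16, CompressRow, and DotBF16 are hypothetical stand-ins, not the repo's CompressTraits/Dot or Highway code):

// Standalone sketch; helpers here are hypothetical, not the repo's API.
// The point: convert a query row to BF16 once, then reuse it for every
// K timestep.
#include <cstddef>
#include <cstdint>
#include <cstring>
#include <vector>

// Truncating float -> bfloat16 conversion (keeps the upper 16 bits).
static uint16_t ToBF16(float f) {
  uint32_t bits;
  std::memcpy(&bits, &f, sizeof(bits));
  return static_cast<uint16_t>(bits >> 16);
}

static float FromBF16(uint16_t b) {
  const uint32_t bits = static_cast<uint32_t>(b) << 16;
  float f;
  std::memcpy(&f, &bits, sizeof(f));
  return f;
}

// Done once per query row (previously the conversion happened inside the
// per-timestep dot-product helpers).
static std::vector<uint16_t> CompressRow(const float* q, size_t qkv_dim) {
  std::vector<uint16_t> q_bf(qkv_dim);
  for (size_t i = 0; i < qkv_dim; ++i) q_bf[i] = ToBF16(q[i]);
  return q_bf;
}

// The attention loop then reuses the BF16 row for every timestep's q.k dot.
static float DotBF16(const uint16_t* q_bf, const uint16_t* k, size_t qkv_dim) {
  float sum = 0.0f;
  for (size_t i = 0; i < qkv_dim; ++i) {
    sum += FromBF16(q_bf[i]) * FromBF16(k[i]);
  }
  return sum;
}

Per query row, this turns one conversion per attended timestep into a single conversion; the diff in gemma/flash_attention.cc implements the same idea with the repo's own CompressTraits, Dot, and a ParallelFor pass that fills the new q_bf activation.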

File tree

10 files changed: 59 additions & 59 deletions

evals/benchmark_helper.cc

Lines changed: 4 additions & 1 deletion
@@ -38,7 +38,10 @@ namespace gcpp {
 
 GemmaEnv::GemmaEnv(const LoaderArgs& loader, const ThreadingArgs& threading,
                    const InferenceArgs& inference)
-    : ctx_(threading), env_(ctx_), gemma_(loader, inference, ctx_) {
+    : initializer_value_(gcpp::InternalInit()),
+      ctx_(threading),
+      env_(ctx_),
+      gemma_(loader, inference, ctx_) {
   const ModelConfig& config = gemma_.Config();
   // Only allocate one for starters because GenerateBatch might not be called.
   kv_caches_.push_back(KVCache(config, inference, ctx_.allocator));

evals/benchmark_helper.h

Lines changed: 2 additions & 0 deletions
@@ -125,6 +125,8 @@ class GemmaEnv {
   MatMulEnv& MutableEnv() { return env_; }
 
  private:
+  // This is used to ensure that InternalInit is called before anything else.
+  int initializer_value_ = 0;
   ThreadingContext ctx_;
   MatMulEnv env_;
   Gemma gemma_;
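The two diffs above rely on C++ member-initialization order: initializer_value_ is declared before ctx_, env_, and gemma_, so initializing it from gcpp::InternalInit() guarantees the init routine runs before those members are constructed. A minimal sketch of the idiom, with illustrative names rather than the repo's:

// Sketch of the member-init-order idiom (names are illustrative).
// Non-static data members are initialized in declaration order, so a dummy
// int declared first and initialized from a setup call forces that call to
// run before the heavier members are constructed.
#include <cstdio>

static int GlobalSetup() {
  std::puts("setup runs first");
  return 0;  // Value is unused; only the call's side effects matter.
}

struct Heavy {
  Heavy() { std::puts("heavy member constructed second"); }
};

class Env {
 public:
  Env() : setup_done_(GlobalSetup()), heavy_() {}

 private:
  int setup_done_;  // Declared first => initialized first.
  Heavy heavy_;
};

int main() { Env env; }

This is also why io/io.cc below changes InternalInit from void to int: the return value exists only so the call can appear in an initializer list, and is otherwise unused.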

evals/gemma_batch_bench.cc

Lines changed: 0 additions & 2 deletions
@@ -153,5 +153,3 @@ int main(int argc, char** argv) {
 
   return RUN_ALL_TESTS();
 }
-
-

evals/gemma_test.cc

Lines changed: 0 additions & 1 deletion
@@ -181,7 +181,6 @@ TEST_F(GemmaTest, CrossEntropySmall) {
 
 int main(int argc, char** argv) {
   testing::InitGoogleTest(&argc, argv);
-  gcpp::InternalInit();
   gcpp::GemmaTest::InitEnv(argc, argv);
   int ret = RUN_ALL_TESTS();
   gcpp::GemmaTest::DeleteEnv();

gemma/activations.h

Lines changed: 11 additions & 0 deletions
@@ -54,6 +54,11 @@ struct AttentionActivations {
                          ? layer_config.heads * 3 * layer_config.qkv_dim
                          : layer_config.heads * layer_config.qkv_dim,
                      allocator)),
+        q_bf(MatFactory("q_bf", batch_size,
+                        config.vocab_size == 0
+                            ? layer_config.heads * 3 * layer_config.qkv_dim
+                            : layer_config.heads * layer_config.qkv_dim,
+                        allocator)),
         q_T(MatFactory("q_T", layer_config.qkv_dim,
                        config.vocab_size == 0
                            ? batch_size * layer_config.heads * 3
@@ -88,12 +93,14 @@
     // If we forget any MatMul outputs here, debug builds print a warning but
     // fill them in each MatMul call.
     q.AllocateAndAttachRowPtrs(row_ptrs);
+    q_bf.AllocateAndAttachRowPtrs(row_ptrs);
     q_T.AllocateAndAttachRowPtrs(row_ptrs);
     att_sums.AllocateAndAttachRowPtrs(row_ptrs);
   }
 
   void SetBatchSize(size_t batch_size) {
     q.OverrideRows(batch_size);
+    q_bf.OverrideRows(batch_size);
     // q_T rows are always qkv_dim!
 
     pre_att_rms_out.OverrideRows(batch_size);
@@ -105,6 +112,7 @@
   }
 
   MatStorageT<float> q;  // query
+  MatStorageT<BF16> q_bf;
   MatStorageT<BF16> q_T;  // Transposed to maximize attention speed.
 
   MatStorageT<float> pre_att_rms_out;
@@ -130,6 +138,7 @@ struct AttentionActivationsPtrs {
                            const AttentionActivations& activations)
       : AttentionActivationsPtrs(config, seq_len) {
     q = activations.q;
+    q_bf = activations.q_bf;
     q_T = activations.q_T;
     pre_att_rms_out = activations.pre_att_rms_out;
     att = activations.att;
@@ -141,6 +150,7 @@
 
   void SetBatchSize(size_t batch_size) {
     q.OverrideRows(batch_size);
+    q_bf.OverrideRows(batch_size);
     // q_T rows are always qkv_dim!
     pre_att_rms_out.OverrideRows(batch_size);
     att.OverrideRows(batch_size);
@@ -151,6 +161,7 @@
 
   const ModelConfig& config;
   MatPtrT<float> q;
+  MatPtrT<BF16> q_bf;
   MatPtrT<BF16> q_T;
   MatPtrT<float> pre_att_rms_out;
   MatPtrT<float> att;

gemma/flash_attention.cc

Lines changed: 37 additions & 51 deletions
@@ -154,25 +154,20 @@ void HWY_INLINE SingleFlashAttentionStep(float x, float cap, float& old_max,
 
 // Calculates the complete attention outputs for a single row of q.
 void SingleFlashAttention(const size_t start_pos, const size_t last_pos,
-                          const float* HWY_RESTRICT q, const MatPtrT<KV_t>& k,
+                          const BF16* HWY_RESTRICT q, const MatPtrT<KV_t>& k,
                           const MatPtrT<KV_t>& v, const size_t layer_idx,
                           const AttentionActivationsPtrs& activations,
                           float* HWY_RESTRICT att_out, ThreadingContext& ctx,
                           const size_t worker) {
   GCPP_ZONE(ctx, worker, Zones::kFlashAttentionSingleFlashAttention);
   const hn::ScalableTag<BF16> dbf;
   const size_t qkv_dim = k.Cols();
-  HWY_ALIGN BF16 q_bf[kMaxQKVDim];
 
-  CompressPerThread tls;
-  const hn::ScalableTag<float> df;
-  CompressTraits<BF16>::Compress(df, q, qkv_dim, tls, MakeSpan(q_bf, qkv_dim),
-                                 0);
   const size_t pos_mod = activations.div_seq_len.Remainder(start_pos);
   // TODO: Mixed-mode can be further improved for Turin: we can demote right
   // before we do the dot product instruction, rather than promote both to f32.
   // But some potential accuracy loss there, needs evaluation first.
-  float m = Dot(dbf, MakeConstSpan(q_bf, qkv_dim), 0, k.Row(pos_mod), qkv_dim);
+  float m = Dot(dbf, MakeConstSpan(q, qkv_dim), 0, k.Row(pos_mod), qkv_dim);
   if (float cap = activations.config.att_cap; cap > 0.0f) {
     // Compute tanh(x / cap) * cap, being LogitsSoftCap on the scalar x.
     m = cap * std::tanh(m / cap);
@@ -182,8 +177,7 @@ void SingleFlashAttention(const size_t start_pos, const size_t last_pos,
   MulByConstTo(d, v.Row(pos_mod), att_out, v.Cols(), ctx, worker);
   for (size_t pos = start_pos + 1; pos <= last_pos; ++pos) {
     const size_t pos_mod = activations.div_seq_len.Remainder(pos);
-    float x =
-        Dot(dbf, MakeConstSpan(q_bf, qkv_dim), 0, k.Row(pos_mod), qkv_dim);
+    float x = Dot(dbf, MakeConstSpan(q, qkv_dim), 0, k.Row(pos_mod), qkv_dim);
     SingleFlashAttentionStep(x, activations.config.att_cap, m, d,
                              v.Row(pos_mod), v.Cols(), att_out);
   }
@@ -193,19 +187,15 @@
 // the dot products of NF rows of Q for a single K timestep.
 template <class DF, class VF = hn::Vec<DF>>
 VF QDotKVector(DF df, const uint32_t* HWY_RESTRICT q_offsets,
-               const size_t k_pos, const MatPtrT<float>& q,
+               const size_t k_pos, const MatPtrT<BF16>& q,
                const MatPtrT<KV_t>& k) {
   const hn::ScalableTag<BF16> dbf;
   const size_t qkv_dim = k.Cols();
-  HWY_ALIGN BF16 q_bf[kMaxQKVDim];
-  CompressPerThread tls;
 
   hn::TFromD<DF> results[hn::MaxLanes(df)];
   for (size_t i = 0; i < hn::Lanes(df); ++i) {
-    CompressTraits<BF16>::Compress(df, q.Row(0) + q_offsets[i], qkv_dim, tls,
-                                   MakeSpan(q_bf, qkv_dim), 0);
-    results[i] =
-        Dot(dbf, MakeConstSpan(q_bf, qkv_dim), 0, k.Row(k_pos), qkv_dim);
+    results[i] = Dot(dbf, MakeConstSpan(q.Row(0) + q_offsets[i], qkv_dim), 0,
+                     k.Row(k_pos), qkv_dim);
   }
   return hn::LoadU(df, results);
 }
@@ -290,7 +280,7 @@ VF HWY_INLINE ElementwiseSumOf8(DF df, const VF& x0, const VF& x1, const VF& x2,
 // min_last_pos, then sweeps the remaining timesteps in the range (min_last_pos,
 // max_last_pos].
 void TileFlashAttention(
-    const MatPtrT<float>& q, const uint32_t* HWY_RESTRICT q_offsets,
+    const MatPtrT<BF16>& q, const uint32_t* HWY_RESTRICT q_offsets,
     const StridedView<BF16>& qT, const MatPtrT<KV_t>& k, const size_t start_pos,
     const uint32_t* HWY_RESTRICT last_pos, const size_t min_last_pos,
     const size_t max_last_pos, const MatPtrT<KV_t>& v, const size_t layer_idx,
@@ -396,7 +386,7 @@ void TileFlashAttention(
 // This is the result of 4 rows of Q against NF K timesteps, with positions
 // given by k_offsets[0..NF].
 template <class DF, class VF = hn::Vec<DF>>
-void QDotKTilex4(DF df, const float* HWY_RESTRICT q,
+void QDotKTilex4(DF df, const BF16* HWY_RESTRICT q,
                  const uint32_t* HWY_RESTRICT q_offsets, const MatPtrT<KV_t>& k,
                  const int32_t* HWY_RESTRICT k_offsets, VF& sum0, VF& sum1,
                  VF& sum2, VF& sum3) {
@@ -411,17 +401,13 @@ void QDotKTilex4(DF df, const float* HWY_RESTRICT q,
   VI k_offsets_vec = hn::LoadU(di, k_offsets);
   for (size_t i = 0; i < k.Cols(); ++i) {
     VF k_vec = hn::GatherIndex(df, k_base + i, k_offsets_vec);
-    VF q_0 = hn::Set(df, hwy::ConvertScalarTo<float>(
-                             hwy::ConvertScalarTo<BF16>(q[q_offsets[0] + i])));
+    VF q_0 = hn::Set(df, hwy::ConvertScalarTo<float>(q[q_offsets[0] + i]));
     sum0 = hn::MulAdd(q_0, k_vec, sum0);
-    VF q_1 = hn::Set(df, hwy::ConvertScalarTo<float>(
-                             hwy::ConvertScalarTo<BF16>(q[q_offsets[1] + i])));
+    VF q_1 = hn::Set(df, hwy::ConvertScalarTo<float>(q[q_offsets[1] + i]));
     sum1 = hn::MulAdd(q_1, k_vec, sum1);
-    VF q_2 = hn::Set(df, hwy::ConvertScalarTo<float>(
-                             hwy::ConvertScalarTo<BF16>(q[q_offsets[2] + i])));
+    VF q_2 = hn::Set(df, hwy::ConvertScalarTo<float>(q[q_offsets[2] + i]));
     sum2 = hn::MulAdd(q_2, k_vec, sum2);
-    VF q_3 = hn::Set(df, hwy::ConvertScalarTo<float>(
-                             hwy::ConvertScalarTo<BF16>(q[q_offsets[3] + i])));
+    VF q_3 = hn::Set(df, hwy::ConvertScalarTo<float>(q[q_offsets[3] + i]));
     sum3 = hn::MulAdd(q_3, k_vec, sum3);
   }
 }
@@ -446,7 +432,7 @@ float HWY_INLINE SingleFlashAttentionRowVector(DF df, VF& x, float& old_max,
 // min_last_pos, then sweeps the remaining timesteps in the range (min_last_pos,
 // max_last_pos].
 Tile4FlashState TileFlashAttention4(
-    const MatPtrT<float>& q, const uint32_t* HWY_RESTRICT q_offsets,
+    const MatPtrT<BF16>& q, const uint32_t* HWY_RESTRICT q_offsets,
    const MatPtrT<KV_t>& k, const size_t start_pos,
    const uint32_t* HWY_RESTRICT last_pos, const size_t min_last_pos,
    const size_t max_last_pos, const MatPtrT<KV_t>& v, const size_t layer_idx,
@@ -500,51 +486,40 @@ Tile4FlashState TileFlashAttention4(
   }
   const hn::ScalableTag<BF16> dbf;
   const size_t qkv_dim = k.Cols();
-  HWY_ALIGN BF16 q_bf[kMaxQKVDim];
-  CompressPerThread tls;
-  const hn::ScalableTag<float> df_compress;
 
   while (position <= max_last_pos) {
     size_t k_pos = activations.div_seq_len.Remainder(position);
     if (position <= last_pos[0]) {
       // Past the last position, x0 doesn't count.
-      CompressTraits<BF16>::Compress(df_compress, q.Row(0) + q_offsets[0],
-                                     qkv_dim, tls, MakeSpan(q_bf, qkv_dim), 0);
-      float x0 =
-          Dot(dbf, MakeConstSpan(q_bf, qkv_dim), 0, k.Row(k_pos), qkv_dim);
+      float x0 = Dot(dbf, MakeConstSpan(q.Row(0) + q_offsets[0], qkv_dim), 0,
+                     k.Row(k_pos), qkv_dim);
       SingleFlashAttentionStep(x0, activations.config.att_cap,
                                state.row_states[0].max, state.row_states[0].d,
                                v.Row(k_pos), v.Cols(),
                                att_out.Row(0) + out_offsets[0]);
     }
     if (position <= last_pos[1]) {
       // Past the last position, x1 doesn't count.
-      CompressTraits<BF16>::Compress(df_compress, q.Row(0) + q_offsets[1],
-                                     qkv_dim, tls, MakeSpan(q_bf, qkv_dim), 0);
-      float x1 =
-          Dot(dbf, MakeConstSpan(q_bf, qkv_dim), 0, k.Row(k_pos), qkv_dim);
+      float x1 = Dot(dbf, MakeConstSpan(q.Row(0) + q_offsets[1], qkv_dim), 0,
+                     k.Row(k_pos), qkv_dim);
       SingleFlashAttentionStep(x1, activations.config.att_cap,
                                state.row_states[1].max, state.row_states[1].d,
                                v.Row(k_pos), v.Cols(),
                                att_out.Row(0) + out_offsets[1]);
     }
     if (position <= last_pos[2]) {
       // Past the last position, x2 doesn't count.
-      CompressTraits<BF16>::Compress(df_compress, q.Row(0) + q_offsets[2],
-                                     qkv_dim, tls, MakeSpan(q_bf, qkv_dim), 0);
-      float x2 =
-          Dot(dbf, MakeConstSpan(q_bf, qkv_dim), 0, k.Row(k_pos), qkv_dim);
+      float x2 = Dot(dbf, MakeConstSpan(q.Row(0) + q_offsets[2], qkv_dim), 0,
+                     k.Row(k_pos), qkv_dim);
       SingleFlashAttentionStep(x2, activations.config.att_cap,
                                state.row_states[2].max, state.row_states[2].d,
                                v.Row(k_pos), v.Cols(),
                                att_out.Row(0) + out_offsets[2]);
     }
     if (position <= last_pos[3]) {
       // Past the last position, x3 doesn't count.
-      CompressTraits<BF16>::Compress(df_compress, q.Row(0) + q_offsets[3],
-                                     qkv_dim, tls, MakeSpan(q_bf, qkv_dim), 0);
-      float x3 =
-          Dot(dbf, MakeConstSpan(q_bf, qkv_dim), 0, k.Row(k_pos), qkv_dim);
+      float x3 = Dot(dbf, MakeConstSpan(q.Row(0) + q_offsets[3], qkv_dim), 0,
+                     k.Row(k_pos), qkv_dim);
       SingleFlashAttentionStep(x3, activations.config.att_cap,
                                state.row_states[3].max, state.row_states[3].d,
                                v.Row(k_pos), v.Cols(),
@@ -642,6 +617,17 @@ void FlashAttention(const size_t num_tokens, const size_t target_parallelism,
   RMSNormAndPositionalEncoding(num_tokens, qbatch, activations.q,
                                query_norm_scale, layer_idx, activations, ctx);
   const hwy::Divisor div_qbatch(qbatch.Size());
+  // Compress q to q_bf.
+  ParallelFor(
+      ParallelismStrategy::kWithinCluster, activations.q.Rows(), ctx,
+      /*cluster_idx=*/0, Callers::kFlashAttention,
+      [&](size_t row, size_t worker) {
+        CompressPerThread tls;
+        const hn::ScalableTag<float> df;
+        CompressTraits<BF16>::Compress(
+            df, activations.q.Row(row), activations.q.Cols(), tls,
+            MakeSpan(activations.q_bf.Row(row), activations.q_bf.Cols()), 0);
+      });
   const LayerConfig& layer_config = activations.config.layer_configs[layer_idx];
   const size_t qkv_dim = layer_config.qkv_dim;
 
@@ -736,8 +722,8 @@ void FlashAttention(const size_t num_tokens, const size_t target_parallelism,
         last_pos[offset] = last;
         min_last_pos = HWY_MIN(min_last_pos, last);
         max_last_pos = HWY_MAX(max_last_pos, last);
-        q_offsets[offset] =
-            activations.q.Row(tq_idx) + head * qkv_dim - activations.q.Row(0);
+        q_offsets[offset] = activations.q_bf.Row(tq_idx) + head * qkv_dim -
+                            activations.q_bf.Row(0);
         out_offsets[offset] = activations.att_out.Row(tq_idx) + head * qkv_dim -
                               activations.att_out.Row(0);
         const size_t kv_index = head / kHeadGroups;
@@ -776,12 +762,12 @@ void FlashAttention(const size_t num_tokens, const size_t target_parallelism,
           // kNFx8HTileSize. In this case, qT is never used. Some tasks might
           // use qT and some might not, which is why the more general condition
           // is used above to catch all cases where qT will be used.
-          TileFlashAttention(activations.q, q_offsets, qT, k,
+          TileFlashAttention(activations.q_bf, q_offsets, qT, k,
                              start_positions[offset], last_pos, min_last_pos,
                              max_last_pos, v, layer_idx, activations,
                              activations.att_out, out_offsets, ctx, worker);
         } else if (kVTileSize == 4) {
-          TileFlashAttention4(activations.q, q_offsets, k,
+          TileFlashAttention4(activations.q_bf, q_offsets, k,
                               start_positions[offset], last_pos, min_last_pos,
                               max_last_pos, v, layer_idx, activations,
                               activations.att_out, out_offsets, ctx, worker);
@@ -791,7 +777,7 @@ void FlashAttention(const size_t num_tokens, const size_t target_parallelism,
           break;
         } else {
           SingleFlashAttention(start_positions[offset], last_pos[offset],
-                               activations.q.Row(0) + q_offsets[offset], k, v,
+                               activations.q_bf.Row(0) + q_offsets[offset], k, v,
                                layer_idx, activations,
                                activations.att_out.Row(0) + out_offsets[offset],
                                ctx, worker);
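Note how q_offsets above are now computed as element distances from activations.q_bf.Row(0) rather than activations.q.Row(0): the offsets are only meaningful relative to the matrix whose rows the kernels later index, and pointer differences must be taken on the same element type. A small sketch of that offset pattern, using a hypothetical Matrix stand-in rather than the repo's MatPtrT:

// Illustrative sketch (hypothetical Matrix type). Offsets are element
// distances from Row(0) of the same matrix that will later be indexed, so
// switching the kernels from q (float) to q_bf (BF16) also means computing
// the offsets on q_bf.
#include <cstddef>
#include <cstdint>
#include <vector>

template <typename T>
struct Matrix {  // Row-major, dense; stand-in for a real matrix type.
  Matrix(size_t rows, size_t cols) : cols_(cols), data_(rows * cols) {}
  T* Row(size_t r) { return data_.data() + r * cols_; }
  size_t Cols() const { return cols_; }

 private:
  size_t cols_;
  std::vector<T> data_;
};

int main() {
  const size_t rows = 4, heads = 2, qkv_dim = 8;
  Matrix<uint16_t> q_bf(rows, heads * qkv_dim);  // BF16 stored as uint16_t.

  // Element offset of (row=2, head=1) relative to Row(0): valid because both
  // pointers come from the same contiguous matrix of the same element type.
  const size_t tq_idx = 2, head = 1;
  const uint32_t offset =
      static_cast<uint32_t>(q_bf.Row(tq_idx) + head * qkv_dim - q_bf.Row(0));

  // A kernel later recovers the slice from the base pointer plus the offset.
  uint16_t* q_slice = q_bf.Row(0) + offset;
  (void)q_slice;
}

Because q_bf is allocated with the same rows and columns as q (see gemma/activations.h above), the head and row arithmetic carries over unchanged; only the base matrix changes.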

gemma/flash_attention.h

Lines changed: 1 addition & 1 deletion
@@ -45,7 +45,7 @@ namespace gcpp {
       ThreadingContext& ctx, size_t worker); \
 \
   Tile4FlashState TileFlashAttention4( \
-      const MatPtrT<float>& q, const uint32_t* HWY_RESTRICT q_offsets, \
+      const MatPtrT<BF16>& q, const uint32_t* HWY_RESTRICT q_offsets, \
       const MatPtrT<KV_t>& k, size_t start_pos, \
       const uint32_t* HWY_RESTRICT last_pos, size_t min_last_pos, \
       size_t max_last_pos, const MatPtrT<KV_t>& v, size_t layer_idx, \

io/io.cc

Lines changed: 3 additions & 1 deletion
@@ -236,7 +236,9 @@ bool IOBatch::Add(void* mem, size_t bytes) {
   return true;
 }
 
-void InternalInit() {
+int InternalInit() {
+  // currently unused, except for init list ordering in GemmaEnv.
+  return 0;
 }
 
 uint64_t IOBatch::Read(const File& file) const {

io/io.h

Lines changed: 1 addition & 1 deletion
@@ -150,7 +150,7 @@ std::string ReadFileToString(const Path& path);
 
 // No-op in open-source. Must be called at the beginning of a binary, before
 // any I/O or flag usage.
-void InternalInit();
+int InternalInit();
 
 }  // namespace gcpp
 
paligemma/paligemma_test.cc

Lines changed: 0 additions & 1 deletion
@@ -72,7 +72,6 @@ TEST_F(PaliGemmaTest, QueryObjects) {
 
 int main(int argc, char** argv) {
   testing::InitGoogleTest(&argc, argv);
-  gcpp::InternalInit();
 
   gcpp::GemmaEnv env(argc, argv);
   gcpp::s_env = &env;
