Skip to content

Commit 50f6016

Browse files
jan-wassenberg authored and copybara-github committed
Change (old) attention behavior to disallow wraparound, enforced via assertion.
Shared kU64PerLine constant.

PiperOrigin-RevId: 826042078
1 parent 3a63a12 commit 50f6016

File tree

6 files changed

+48
-63
lines changed

6 files changed

+48
-63
lines changed

gemma/activations.h

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -159,6 +159,10 @@ struct AttentionActivationsPtrs {
159159
// `inv_timescale*` are not batched.
160160
}
161161

162+
size_t SeqLen() const {
163+
return static_cast<size_t>(div_seq_len.GetDivisor());
164+
}
165+
162166
const ModelConfig& config;
163167
MatPtrT<float> q;
164168
MatPtrT<BF16> q_bf;

gemma/attention.cc

Lines changed: 26 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -66,20 +66,12 @@ static HWY_INLINE void QDotK(const size_t start_pos, const size_t last_pos,
6666
CompressTraits<BF16>::Compress(df, q, qkv_dim, tls, MakeSpan(q_bf, qkv_dim),
6767
0);
6868

69-
if (HWY_LIKELY(last_pos < static_cast<size_t>(div_seq_len.GetDivisor()))) {
70-
// Slightly faster: no wraparound.
71-
for (size_t pos = start_pos; pos <= last_pos; ++pos) {
72-
const float score =
73-
Dot(dbf, MakeConstSpan(q_bf, qkv_dim), 0, k.Row(pos), qkv_dim);
74-
att[pos] = score;
75-
}
76-
} else {
77-
for (size_t pos = start_pos; pos <= last_pos; ++pos) {
78-
const size_t pos_modulo = div_seq_len.Remainder(pos);
79-
const float score =
80-
Dot(dbf, MakeConstSpan(q_bf, qkv_dim), 0, k.Row(pos_modulo), qkv_dim);
81-
att[pos_modulo] = score;
82-
}
69+
// --seq_len must be large enough to avoid wraparound.
70+
HWY_DASSERT(last_pos < static_cast<size_t>(div_seq_len.GetDivisor()));
71+
for (size_t pos = start_pos; pos <= last_pos; ++pos) {
72+
const float score =
73+
Dot(dbf, MakeConstSpan(q_bf, qkv_dim), 0, k.Row(pos), qkv_dim);
74+
att[pos] = score;
8375
}
8476
}
8577

@@ -114,25 +106,13 @@ static HWY_INLINE void WeightedSumV(
114106
const hwy::Divisor& div_seq_len, const float* HWY_RESTRICT att,
115107
const MatPtrT<KV_t>& v, float* HWY_RESTRICT att_out, ThreadingContext& ctx,
116108
const size_t worker) {
117-
if (HWY_LIKELY(last_pos < static_cast<size_t>(div_seq_len.GetDivisor()))) {
118-
// Slightly faster: no wraparound. Could be replaced with MatMul(att, v) if
119-
// we supported non-transposed B.
120-
// TODO: 2..4x unroll
121-
MulByConstTo(att[start_pos], v.Row(start_pos), att_out, v.Cols(), ctx,
122-
worker);
123-
for (size_t pos = start_pos + 1; pos <= last_pos; ++pos) {
124-
MulByConstAndAdd(att[pos], v.Row(pos), att_out, v.Cols());
125-
}
126-
} else {
127-
{
128-
const size_t pos_mod = div_seq_len.Remainder(start_pos);
129-
MulByConstTo(att[pos_mod], v.Row(pos_mod), att_out, v.Cols(), ctx,
130-
worker);
131-
}
132-
for (size_t pos = start_pos + 1; pos <= last_pos; ++pos) {
133-
const size_t pos_mod = div_seq_len.Remainder(pos);
134-
MulByConstAndAdd(att[pos_mod], v.Row(pos_mod), att_out, v.Cols());
135-
}
109+
// --seq_len must be large enough to avoid wraparound.
110+
HWY_DASSERT(last_pos < static_cast<size_t>(div_seq_len.GetDivisor()));
111+
// TODO: replace with MatMul(att, v) after it supports non-transposed B.
112+
MulByConstTo(att[start_pos], v.Row(start_pos), att_out, v.Cols(), ctx,
113+
worker);
114+
for (size_t pos = start_pos + 1; pos <= last_pos; ++pos) {
115+
MulByConstAndAdd(att[pos], v.Row(pos), att_out, v.Cols());
136116
}
137117
}
138118

@@ -146,9 +126,10 @@ void SingleDotSoftmaxWeightedSum(
146126
float* HWY_RESTRICT att_out, ThreadingContext& ctx, const size_t worker) {
147127
const float att_cap = activations.config.att_cap;
148128
const float query_scale = activations.query_scale;
149-
const size_t seq_len =
150-
static_cast<size_t>(activations.div_seq_len.GetDivisor());
129+
// --seq_len must be large enough to avoid wraparound.
130+
HWY_DASSERT(last_pos < activations.SeqLen());
151131
const LayerConfig& layer_config = activations.config.layer_configs[layer_idx];
132+
152133
// Apply rope and scaling to Q.
153134
if (query_norm_scale.HasPtr()) {
154135
CallUpcasted(&query_norm_scale, [&](const auto* weights_t) {
@@ -163,8 +144,7 @@ void SingleDotSoftmaxWeightedSum(
163144
QDotK(start_pos, last_pos, activations.div_seq_len, q, k, att, ctx, worker);
164145

165146
// SoftMax with optional SoftCap yields "probabilities" in att.
166-
const size_t att_len = HWY_MIN(last_pos + 1, seq_len);
167-
const Logits logits(att, att_len);
147+
const Logits logits(att, last_pos + 1);
168148
MaybeLogitsSoftCap(att_cap, logits, ctx, worker);
169149
Softmax(logits, ctx, worker, /*temperature=*/1.0f);
170150

@@ -194,8 +174,7 @@ void DotSoftmaxWeightedSum(const size_t num_tokens, const size_t layer_idx,
194174
const size_t kHeadGroups = layer_config.heads / layer_config.kv_heads;
195175

196176
const size_t cache_layer_size = layer_config.CacheLayerSize();
197-
const size_t seq_len =
198-
static_cast<size_t>(activations.div_seq_len.GetDivisor());
177+
const size_t seq_len = activations.SeqLen();
199178
// All layers should have the same number of heads.
200179
HWY_DASSERT(activations.div_heads.GetDivisor() == layer_config.heads);
201180

@@ -284,8 +263,10 @@ static HWY_INLINE void ComputeQKV(size_t num_tokens, const size_t layer_idx,
284263
++interleaved_idx) {
285264
const size_t qi = div_qbatch.Remainder(interleaved_idx);
286265
const size_t batch_idx = div_qbatch.Divide(interleaved_idx);
287-
const size_t cache_pos =
288-
activations.div_seq_len.Remainder(qbatch.Pos(qi) + batch_idx);
266+
const size_t cache_pos = qbatch.Pos(qi) + batch_idx;
267+
// --seq_len must be large enough to avoid wraparound.
268+
HWY_DASSERT(cache_pos < activations.SeqLen());
269+
289270
env.row_ptrs[0][interleaved_idx] = reinterpret_cast<uint8_t*>(
290271
qbatch.KV(qi).kv_cache.Row(cache_pos) + layer_idx * cache_layer_size);
291272
}
@@ -304,8 +285,9 @@ static HWY_INLINE void ComputeQKV(size_t num_tokens, const size_t layer_idx,
304285
const size_t interleaved_idx = task / kv_heads;
305286
const size_t qi = div_qbatch.Remainder(interleaved_idx);
306287
const size_t batch_idx = div_qbatch.Divide(interleaved_idx);
307-
const size_t pos = qbatch.Pos(qi) + batch_idx;
308-
const size_t cache_pos = activations.div_seq_len.Remainder(pos);
288+
const size_t cache_pos = qbatch.Pos(qi) + batch_idx;
289+
// --seq_len must be large enough to avoid wraparound.
290+
HWY_DASSERT(cache_pos < activations.SeqLen());
309291
auto& kv_cache = qbatch.KV(qi).kv_cache;
310292
KV_t* HWY_RESTRICT kv = kv_cache.Row(cache_pos) +
311293
layer_idx * cache_layer_size +
@@ -325,7 +307,7 @@ static HWY_INLINE void ComputeQKV(size_t num_tokens, const size_t layer_idx,
325307
}
326308

327309
PositionalEncodingQK(kv_f32, layer_idx, activations, env.ctx, worker,
328-
pos, /*mul=*/1.0f);
310+
cache_pos, /*mul=*/1.0f);
329311
CompressPerThread tls;
330312
Compress(kv_f32, 2 * qkv_dim, tls, MakeSpan(kv, 2 * qkv_dim), 0);
331313
});

gemma/flash_attention.cc

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -716,7 +716,7 @@ void FlashAttention(const size_t num_tokens, const size_t target_parallelism,
716716
size_t last = pos;
717717
const size_t prefix_end = qbatch.PrefixEnd(qi);
718718
if (prefix_end > 0 && prefix_end - 1 > last) {
719-
// last_pos in QDotK and WeightedSumV is inclusive.
719+
// last_pos in `TileFlashAttention` is inclusive.
720720
last = prefix_end - 1;
721721
}
722722
last_pos[offset] = last;

util/basics.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,9 @@ namespace gcpp {
3333
// For hwy::BitSet4096. Note that KVs are extremely large for such batches.
3434
HWY_INLINE_VAR constexpr size_t kMaxBatchSize = 4096;
3535

36+
// Multiplier so a u64 occupies an entire cache line; avoids false sharing.
37+
HWY_INLINE_VAR constexpr size_t kU64PerLine = HWY_ALIGNMENT / sizeof(uint64_t);
38+
3639
enum class Tristate : int32_t { kFalse = 0, kTrue = 1, kDefault = -1 };
3740

3841
static inline const char* ToString(Tristate t) {

util/threading_context.cc

Lines changed: 6 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -43,12 +43,9 @@ static void TunePool(hwy::PoolWaitMode wait_mode, hwy::ThreadPool& pool) {
4343
const size_t num_tasks[4] = {HWY_MAX(1, num_workers / 2), num_workers * 1,
4444
num_workers * 5, num_workers * 20};
4545

46-
// Count tasks executed to ensure workers aren't optimized out. One per
47-
// cache line to avoid false sharing.
48-
const size_t kSizePerLine = HWY_ALIGNMENT / sizeof(size_t);
49-
50-
std::vector<size_t> counters(num_workers * kSizePerLine);
51-
size_t prev_total = 0; // avoids having to reset counters.
46+
// Count tasks executed to ensure workers aren't optimized out.
47+
std::vector<uint64_t> counters(num_workers * kU64PerLine);
48+
uint64_t prev_total = 0; // avoids having to reset counters.
5249

5350
hwy::RandomState rng;
5451
for (size_t rep = 0; rep < 500; ++rep) {
@@ -63,13 +60,13 @@ static void TunePool(hwy::PoolWaitMode wait_mode, hwy::ThreadPool& pool) {
6360
pool.Run(begin, end, [&](uint64_t task, size_t thread) {
6461
HWY_ASSERT(begin <= task && task < end);
6562
HWY_ASSERT(thread < num_workers);
66-
counters[thread * kSizePerLine]++;
63+
counters[thread * kU64PerLine]++;
6764
});
6865

6966
// Reduce count and ensure it matches the expected number of tasks.
70-
size_t total = 0;
67+
uint64_t total = 0;
7168
for (size_t i = 0; i < num_workers; ++i) {
72-
total += counters[i * kSizePerLine];
69+
total += counters[i * kU64PerLine];
7370
}
7471
const size_t expected = end - begin;
7572
HWY_ASSERT(total == prev_total + expected);

util/threading_test.cc

Lines changed: 8 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -202,8 +202,7 @@ TEST(ThreadingTest, TestStaticPartition) {
202202
}
203203
}
204204

205-
static constexpr size_t kU64PerThread = HWY_ALIGNMENT / sizeof(size_t);
206-
static uint64_t outputs[hwy::kMaxLogicalProcessors * kU64PerThread];
205+
static uint64_t outputs[hwy::kMaxLogicalProcessors * kU64PerLine];
207206

208207
std::vector<uint64_t> MeasureForkJoin(hwy::ThreadPool& pool) {
209208
// Governs duration of test; avoid timeout in debug builds.
@@ -217,7 +216,7 @@ std::vector<uint64_t> MeasureForkJoin(hwy::ThreadPool& pool) {
217216
const double t0 = hwy::platform::Now();
218217
for (size_t reps = 0; reps < 1200; ++reps) {
219218
pool.Run(0, pool.NumWorkers(), kCaller, [&](uint64_t task, size_t thread) {
220-
outputs[thread * kU64PerThread] = base + thread;
219+
outputs[thread * kU64PerLine] = base + thread;
221220
});
222221
hwy::PreventElision(outputs[base]);
223222
if (pool.AutoTuneComplete()) break;
@@ -258,7 +257,7 @@ std::vector<uint64_t> MeasureForkJoin(hwy::ThreadPool& pool) {
258257
const uint64_t t0 = hwy::timer::Start();
259258
pool.Run(0, pool.NumWorkers(), kCaller,
260259
[&](uint64_t task, size_t thread) {
261-
outputs[thread * kU64PerThread] = base + thread;
260+
outputs[thread * kU64PerLine] = base + thread;
262261
});
263262
const uint64_t t1 = hwy::timer::Stop();
264263
times.push_back(t1 - t0);
@@ -268,7 +267,7 @@ std::vector<uint64_t> MeasureForkJoin(hwy::ThreadPool& pool) {
268267
const uint64_t t0 = hwy::timer::Start();
269268
pool.Run(0, pool.NumWorkers(), kCaller,
270269
[&](uint64_t task, size_t thread) {
271-
outputs[thread * kU64PerThread] = base + thread;
270+
outputs[thread * kU64PerLine] = base + thread;
272271
});
273272
const uint64_t t1 = hwy::timer::Start();
274273
times.push_back(t1 - t0);
@@ -315,10 +314,10 @@ TEST(ThreadingTest, BenchJoin) {
315314

316315
// Verify outputs to ensure the measured code is not a no-op.
317316
for (size_t lp = 0; lp < pool.NumWorkers(); ++lp) {
318-
HWY_ASSERT(outputs[lp * kU64PerThread] >= 1);
319-
HWY_ASSERT(outputs[lp * kU64PerThread] <= 1 + pool.NumWorkers());
320-
for (size_t i = 1; i < kU64PerThread; ++i) {
321-
HWY_ASSERT(outputs[lp * kU64PerThread + i] == 0);
317+
HWY_ASSERT(outputs[lp * kU64PerLine] >= 1);
318+
HWY_ASSERT(outputs[lp * kU64PerLine] <= 1 + pool.NumWorkers());
319+
for (size_t i = 1; i < kU64PerLine; ++i) {
320+
HWY_ASSERT(outputs[lp * kU64PerLine + i] == 0);
322321
}
323322
}
324323
};

0 commit comments

Comments (0)