Skip to content

Commit 5a50087

Browse files
stollem authored and copybara-github committed
Internal change
PiperOrigin-RevId: 835115693
1 parent 49d420a commit 5a50087

File tree

2 files changed

+6
-3
lines changed

2 files changed

+6
-3
lines changed

gemma/activations.h

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -46,7 +46,8 @@ static inline float ChooseQueryScale(const ModelConfig& config) {
4646
struct AttentionActivations {
4747
AttentionActivations(
4848
const ModelConfig& config, const LayerConfig& layer_config,
49-
size_t batch_size, size_t seq_len, const Allocator& allocator,
49+
size_t batch_size, size_t seq_len, AttentionImpl attention_impl,
50+
const Allocator& allocator,
5051
std::vector<hwy::AlignedFreeUniquePtr<uint8_t*[]>>& row_ptrs)
5152
: // `vocab_size == 0` means it is for Vit part, VitAttention is still
5253
// MHA and does not use an external KV cache.
@@ -217,7 +218,8 @@ struct Activations {
217218

218219
attention_impl(runtime_config.attention_impl),
219220
attention_storage(config, layer_config, batch_size, seq_len,
220-
ctx.allocator, row_ptrs),
221+
runtime_config.attention_impl, ctx.allocator,
222+
row_ptrs),
221223
attention(config, seq_len, attention_storage) {
222224
HWY_ASSERT(batch_size != 0);
223225

gemma/flash_attention_test.cc

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -124,7 +124,8 @@ void TestFlashAttention(size_t target_parallelism) {
124124
const size_t batch_size = kOuter;
125125
std::vector<hwy::AlignedFreeUniquePtr<uint8_t*[]>> row_ptrs;
126126
AttentionActivations attention_storage(config, layer_config, batch_size,
127-
kOuter, ctx.allocator, row_ptrs);
127+
kOuter, AttentionImpl::kFlash,
128+
ctx.allocator, row_ptrs);
128129
AttentionActivationsPtrs attention(config, kOuter, attention_storage);
129130
const size_t qkv_dim = layer_config.qkv_dim;
130131
ASSERT_EQ(qkv_dim, kInner);

0 commit comments

Comments (0)