We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
There was an error while loading. Please reload this page.
1 parent 49d420a commit 5a50087Copy full SHA for 5a50087
gemma/activations.h
@@ -46,7 +46,8 @@ static inline float ChooseQueryScale(const ModelConfig& config) {
46
struct AttentionActivations {
47
AttentionActivations(
48
const ModelConfig& config, const LayerConfig& layer_config,
49
- size_t batch_size, size_t seq_len, const Allocator& allocator,
+ size_t batch_size, size_t seq_len, AttentionImpl attention_impl,
50
+ const Allocator& allocator,
51
std::vector<hwy::AlignedFreeUniquePtr<uint8_t*[]>>& row_ptrs)
52
: // `vocab_size == 0` means it is for Vit part, VitAttention is still
53
// MHA and does not use an external KV cache.
@@ -217,7 +218,8 @@ struct Activations {
217
218
219
attention_impl(runtime_config.attention_impl),
220
attention_storage(config, layer_config, batch_size, seq_len,
- ctx.allocator, row_ptrs),
221
+ runtime_config.attention_impl, ctx.allocator,
222
+ row_ptrs),
223
attention(config, seq_len, attention_storage) {
224
HWY_ASSERT(batch_size != 0);
225
gemma/flash_attention_test.cc
@@ -124,7 +124,8 @@ void TestFlashAttention(size_t target_parallelism) {
124
const size_t batch_size = kOuter;
125
std::vector<hwy::AlignedFreeUniquePtr<uint8_t*[]>> row_ptrs;
126
AttentionActivations attention_storage(config, layer_config, batch_size,
127
- kOuter, ctx.allocator, row_ptrs);
+ kOuter, AttentionImpl::kFlash,
128
+ ctx.allocator, row_ptrs);
129
AttentionActivationsPtrs attention(config, kOuter, attention_storage);
130
const size_t qkv_dim = layer_config.qkv_dim;
131
ASSERT_EQ(qkv_dim, kInner);
0 commit comments