Skip to content

Commit dcae6ae

Browse files
Krzysztof Rymskicopybara-github
authored andcommitted
Internal changes
PiperOrigin-RevId: 835159997
1 parent 5a50087 commit dcae6ae

File tree

5 files changed

+16
-6
lines changed

5 files changed

+16
-6
lines changed

gemma/flash_attention.cc

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,9 @@
1717
#include <stdint.h>
1818

1919
#include <algorithm>
20+
#include <array>
2021
#include <cmath>
22+
#include <cstdlib>
2123
#include <limits>
2224

2325
#include "compression/types.h" // GEMMA_DISABLED_TARGETS

gemma/flash_attention.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -60,6 +60,7 @@ namespace gcpp {
6060
size_t layer_idx, const MatPtr& query_norm_scale, \
6161
AttentionActivationsPtrs& activations, QBatch& qbatch, \
6262
ThreadingContext& ctx); \
63+
\
6364
/* NOLINTNEXTLINE(google-readability-namespace-comments) */ \
6465
} // namespace NAMESPACE
6566

gemma/kv_cache.cc

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -51,15 +51,16 @@ KVCache KVCache::Copy() {
5151
KVCache copy(kv_cache.Extents(), allocator_);
5252

5353
CopyMat(kv_cache, copy.kv_cache);
54-
5554
return copy;
5655
}
5756

5857
std::vector<KVCachePtr> ToKVCachePtrs(const hwy::Span<KVCache>& kv_caches) {
5958
std::vector<KVCachePtr> ptrs;
6059
ptrs.reserve(kv_caches.size());
6160
for (size_t i = 0; i < kv_caches.size(); ++i) {
62-
ptrs.push_back(KVCachePtr{.kv_cache = kv_caches[i].kv_cache});
61+
ptrs.push_back(KVCachePtr{
62+
.kv_cache = kv_caches[i].kv_cache,
63+
});
6364
}
6465
return ptrs;
6566
}

gemma/kv_cache.h

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -17,9 +17,11 @@
1717
#define THIRD_PARTY_GEMMA_CPP_GEMMA_KV_CACHE_H_
1818

1919
#include <stddef.h>
20+
21+
#include <optional>
2022
#include <vector>
2123

22-
#include "gemma/configs.h" // ModelConfig
24+
#include "gemma/configs.h" // ModelConfig
2325
#include "gemma/gemma_args.h" // InferenceArgs
2426
#include "util/basics.h" // BF16
2527
#include "util/mat.h"
@@ -31,12 +33,13 @@ using KV_t = float;
3133
struct KVCache {
3234
KVCache(const ModelConfig& config, const InferenceArgs& inference_args,
3335
const Allocator& allocator);
34-
3536
// Returns a deep copy of the KVCache. Use explicit function instead of
3637
// copy ctor to make the cost explicit.
3738
KVCache Copy();
3839

39-
size_t SeqLen() const { return kv_cache.Rows(); }
40+
size_t SeqLen() const {
41+
return kv_cache.Rows();
42+
}
4043

4144
MatStorageT<KV_t> kv_cache; // [seq_len, layers * kv_heads * qkv_dim * 2]
4245

@@ -49,7 +52,9 @@ struct KVCache {
4952

5053
// A non-owning view of a KVCache.
5154
struct KVCachePtr {
52-
size_t SeqLen() const { return kv_cache.Rows(); }
55+
size_t SeqLen() const {
56+
return kv_cache.Rows();
57+
}
5358
MatPtrT<KV_t> kv_cache;
5459
};
5560

ops/ops-inl.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@
2525
#include <cstdint>
2626
#include <random>
2727
#include <type_traits> // std::enable_if_t
28+
#include <utility>
2829
#include <vector>
2930

3031
#include "ops/matmul.h"

0 commit comments

Comments
 (0)