11//! utilities for working with the kv cache
22
3+ use std:: num:: NonZeroU8 ;
34use crate :: context:: LlamaContext ;
45
56impl LlamaContext < ' _ > {
@@ -14,6 +15,26 @@ impl LlamaContext<'_> {
1415 unsafe { llama_cpp_sys_2:: llama_kv_cache_seq_cp ( self . context . as_ptr ( ) , src, dest, 0 , size) }
1516 }
1617
18+ /// Copy the cache from one sequence to another.
19+ ///
20+ /// # Parameters
21+ ///
22+ /// * `src` - The sequence id to copy the cache from.
23+ /// * `dest` - The sequence id to copy the cache to.
24+ /// * `p0` - The start position of the cache to clear. If `None`, the entire cache is copied up to [p1].
25+ /// * `p1` - The end position of the cache to clear. If `None`, the entire cache is copied starting from [p0].
26+ pub fn copy_kv_cache_seq ( & mut self , src : i32 , dest : i32 , p0 : Option < u16 > , p1 : Option < u16 > ) {
27+ unsafe {
28+ llama_cpp_sys_2:: llama_kv_cache_seq_cp (
29+ self . context . as_ptr ( ) ,
30+ src,
31+ dest,
32+ p0. map_or ( -1 , i32:: from) ,
33+ p1. map_or ( -1 , i32:: from) ,
34+ )
35+ }
36+ }
37+
1738 /// Clear the kv cache for the given sequence.
1839 ///
1940 /// # Parameters
@@ -31,4 +52,181 @@ impl LlamaContext<'_> {
3152 ) ;
3253 }
3354 }
55+
56+ /// Returns the number of used KV cells (i.e. have at least one sequence assigned to them)
57+ pub fn get_kv_cache_used_cells ( & self ) -> i32 {
58+ unsafe { llama_cpp_sys_2:: llama_get_kv_cache_used_cells ( self . context . as_ptr ( ) ) }
59+ }
60+
61+ /// Clear the KV cache
62+ pub fn clear_kv_cache ( & mut self ) {
63+ unsafe { llama_cpp_sys_2:: llama_kv_cache_clear ( self . context . as_ptr ( ) ) }
64+ }
65+
66+ /// Removes all tokens that do not belong to the specified sequence
67+ ///
68+ /// # Parameters
69+ ///
70+ /// * `seq_id` - The sequence id to keep
71+ pub fn llama_kv_cache_seq_keep ( & mut self , seq_id : i32 ) {
72+ unsafe { llama_cpp_sys_2:: llama_kv_cache_seq_keep ( self . context . as_ptr ( ) , seq_id) }
73+ }
74+
75+ /// Adds relative position "delta" to all tokens that belong to the specified sequence and have positions in [p0, p1)
76+ /// If the KV cache is RoPEd, the KV data is updated accordingly:
77+ /// - lazily on next llama_decode()
78+ /// - explicitly with llama_kv_cache_update()
79+ ///
80+ /// # Parameters
81+ ///
82+ /// * `seq_id` - The sequence id to update
83+ /// * `p0` - The start position of the cache to update. If `None`, the entire cache is updated up to [p1].
84+ /// * `p1` - The end position of the cache to update. If `None`, the entire cache is updated starting from [p0].
85+ /// * `delta` - The relative position to add to the tokens
86+ pub fn kv_cache_seq_add ( & mut self , seq_id : i32 , p0 : Option < u16 > , p1 : Option < u16 > , delta : i32 ) {
87+ unsafe {
88+ llama_cpp_sys_2:: llama_kv_cache_seq_add (
89+ self . context . as_ptr ( ) ,
90+ seq_id,
91+ p0. map_or ( -1 , i32:: from) ,
92+ p1. map_or ( -1 , i32:: from) ,
93+ delta,
94+ )
95+ }
96+ }
97+
98+ /// Integer division of the positions by factor of `d > 1`
99+ /// If the KV cache is RoPEd, the KV data is updated accordingly:
100+ /// - lazily on next llama_decode()
101+ /// - explicitly with llama_kv_cache_update()
102+ ///
103+ /// # Parameters
104+ ///
105+ /// * `seq_id` - The sequence id to update
106+ /// * `p0` - The start position of the cache to update. If `None`, the entire cache is updated up to [p1].
107+ /// * `p1` - The end position of the cache to update. If `None`, the entire cache is updated starting from [p0].
108+ /// * `d` - The factor to divide the positions by
109+ pub fn kv_cache_seq_div ( & mut self , seq_id : i32 , p0 : Option < u16 > , p1 : Option < u16 > , d : NonZeroU8 ) {
110+ unsafe {
111+ llama_cpp_sys_2:: llama_kv_cache_seq_div (
112+ self . context . as_ptr ( ) ,
113+ seq_id,
114+ p0. map_or ( -1 , i32:: from) ,
115+ p1. map_or ( -1 , i32:: from) ,
116+ d. get ( ) . try_into ( ) . expect ( "d does not fit into a i32" ) ,
117+ )
118+ }
119+ }
120+
121+ /// Returns the largest position present in the KV cache for the specified sequence
122+ ///
123+ /// # Parameters
124+ ///
125+ /// * `seq_id` - The sequence id to get the max position for
126+ pub fn kv_cache_seq_pos_max ( & self , seq_id : i32 ) -> i32 {
127+ unsafe { llama_cpp_sys_2:: llama_kv_cache_seq_pos_max ( self . context . as_ptr ( ) , seq_id) }
128+ }
129+
130+ /// Defragment the KV cache
131+ /// This will be applied:
132+ /// - lazily on next llama_decode()
133+ /// - explicitly with llama_kv_cache_update()
134+ pub fn kv_cache_defrag ( & mut self ) {
135+ unsafe { llama_cpp_sys_2:: llama_kv_cache_defrag ( self . context . as_ptr ( ) ) }
136+ }
137+
138+ /// Apply the KV cache updates (such as K-shifts, defragmentation, etc.)
139+ pub fn kv_cache_update ( & mut self ) {
140+ unsafe { llama_cpp_sys_2:: llama_kv_cache_update ( self . context . as_ptr ( ) ) }
141+ }
142+
143+ /// Returns the number of tokens in the KV cache (slow, use only for debug)
144+ /// If a KV cell has multiple sequences assigned to it, it will be counted multiple times
145+ pub fn get_kv_cache_token_count ( & self ) -> i32 {
146+ unsafe { llama_cpp_sys_2:: llama_get_kv_cache_token_count ( self . context . as_ptr ( ) ) }
147+ }
148+
149+ /// Create an empty KV cache view. (use only for debugging purposes)
150+ ///
151+ /// # Parameters
152+ ///
153+ /// * `n_max_seq` - Maximum number of sequences that can exist in a cell. It's not an error
154+ /// if there are more sequences in a cell than this value, however they will
155+ /// not be visible in the view cells_sequences.
156+ pub fn new_kv_cache_view ( & self , n_max_seq : i32 ) -> KVCacheView {
157+ let view = unsafe { llama_cpp_sys_2:: llama_kv_cache_view_init ( self . context . as_ptr ( ) , n_max_seq) } ;
158+ KVCacheView { view, ctx : self }
159+ }
160+ }
161+
162+
163+ /// Information associated with an individual cell in the KV cache view.
164+ #[ derive( Debug ) ]
165+ pub struct KVCacheViewCell {
166+ /// The position for this cell. Takes KV cache shifts into account.
167+ /// May be negative if the cell is not populated.
168+ pub pos : llama_cpp_sys_2:: llama_pos ,
169+ }
170+
171+ /// An updateable view of the KV cache. (use only for debugging purposes)
172+ #[ derive( Debug ) ]
173+ pub struct KVCacheView < ' a > {
174+ ctx : & ' a LlamaContext < ' a > ,
175+ view : llama_cpp_sys_2:: llama_kv_cache_view ,
34176}
177+
178+ impl < ' a > KVCacheView < ' a > {
179+ /// Update the KV cache view structure with the current state of the KV cache. (use only for debugging purposes)
180+ pub fn update ( & mut self ) {
181+ unsafe { llama_cpp_sys_2:: llama_kv_cache_view_update ( self . ctx . context . as_ptr ( ) , & mut self . view ) }
182+ }
183+
184+ /// Number of KV cache cells. This will be the same as the context size.
185+ pub fn n_cells ( & self ) -> i32 {
186+ self . view . n_cells
187+ }
188+
189+ /// Number of tokens in the cache. For example, if there are two populated
190+ /// cells, the first with 1 sequence id in it and the second with 2 sequence
191+ /// ids then you'll have 3 tokens.
192+ pub fn token_count ( & self ) -> i32 {
193+ self . view . token_count
194+ }
195+
196+ /// Number of populated cache cells.
197+ pub fn used_cells ( & self ) -> i32 {
198+ self . view . used_cells
199+ }
200+
201+ /// Maximum contiguous empty slots in the cache.
202+ pub fn max_contiguous ( & self ) -> i32 {
203+ self . view . max_contiguous
204+ }
205+
206+ /// Index to the start of the max_contiguous slot range. Can be negative
207+ /// when cache is full.
208+ pub fn max_contiguous_idx ( & self ) -> i32 {
209+ self . view . max_contiguous_idx
210+ }
211+
212+ /// Information for individual cells.
213+ pub fn cells ( & self ) -> impl Iterator < Item =KVCacheViewCell > {
214+ unsafe { std:: slice:: from_raw_parts ( self . view . cells , self . view . n_cells . try_into ( ) . unwrap ( ) ) }
215+ . iter ( )
216+ . map ( |& cell| KVCacheViewCell { pos : cell. pos } )
217+ }
218+
219+ /// The sequences for each cell. There will be n_max_seq items per cell.
220+ pub fn cells_sequences ( & self ) -> impl Iterator < Item =& [ llama_cpp_sys_2:: llama_seq_id ] > {
221+ unsafe { std:: slice:: from_raw_parts ( self . view . cells_sequences , ( self . view . n_cells * self . view . n_max_seq ) . try_into ( ) . unwrap ( ) ) }
222+ . chunks ( self . view . n_max_seq . try_into ( ) . unwrap ( ) )
223+ }
224+ }
225+
226+ impl < ' a > Drop for KVCacheView < ' a > {
227+ fn drop ( & mut self ) {
228+ unsafe {
229+ llama_cpp_sys_2:: llama_kv_cache_view_free ( & mut self . view ) ;
230+ }
231+ }
232+ }
0 commit comments