11//! utilities for working with the kv cache
22
3- use std:: num:: NonZeroU8 ;
43use crate :: context:: LlamaContext ;
4+ use std:: ffi:: c_int;
5+ use std:: num:: NonZeroU8 ;
56
67impl LlamaContext < ' _ > {
78 /// Copy the cache from one sequence to another.
@@ -24,14 +25,10 @@ impl LlamaContext<'_> {
2425 /// * `p0` - The start position of the cache to clear. If `None`, the entire cache is copied up to [p1].
2526 /// * `p1` - The end position of the cache to clear. If `None`, the entire cache is copied starting from [p0].
2627 pub fn copy_kv_cache_seq ( & mut self , src : i32 , dest : i32 , p0 : Option < u16 > , p1 : Option < u16 > ) {
28+ let p0 = p0. map_or ( -1 , i32:: from) ;
29+ let p1 = p1. map_or ( -1 , i32:: from) ;
2730 unsafe {
28- llama_cpp_sys_2:: llama_kv_cache_seq_cp (
29- self . context . as_ptr ( ) ,
30- src,
31- dest,
32- p0. map_or ( -1 , i32:: from) ,
33- p1. map_or ( -1 , i32:: from) ,
34- )
31+ llama_cpp_sys_2:: llama_kv_cache_seq_cp ( self . context . as_ptr ( ) , src, dest, p0, p1) ;
3532 }
3633 }
3734
@@ -43,17 +40,15 @@ impl LlamaContext<'_> {
4340 /// * `p0` - The start position of the cache to clear. If `None`, the entire cache is cleared up to [p1].
4441 /// * `p1` - The end position of the cache to clear. If `None`, the entire cache is cleared from [p0].
4542 pub fn clear_kv_cache_seq ( & mut self , src : i32 , p0 : Option < u16 > , p1 : Option < u16 > ) {
43+ let p0 = p0. map_or ( -1 , i32:: from) ;
44+ let p1 = p1. map_or ( -1 , i32:: from) ;
4645 unsafe {
47- llama_cpp_sys_2:: llama_kv_cache_seq_rm (
48- self . context . as_ptr ( ) ,
49- src,
50- p0. map_or ( -1 , i32:: from) ,
51- p1. map_or ( -1 , i32:: from) ,
52- ) ;
46+ llama_cpp_sys_2:: llama_kv_cache_seq_rm ( self . context . as_ptr ( ) , src, p0, p1) ;
5347 }
5448 }
5549
5650 /// Returns the number of used KV cells (i.e. have at least one sequence assigned to them)
51+ #[ must_use]
5752 pub fn get_kv_cache_used_cells ( & self ) -> i32 {
5853 unsafe { llama_cpp_sys_2:: llama_get_kv_cache_used_cells ( self . context . as_ptr ( ) ) }
5954 }
@@ -74,8 +69,8 @@ impl LlamaContext<'_> {
7469
7570 /// Adds relative position "delta" to all tokens that belong to the specified sequence and have positions in [p0, p1)
7671 /// If the KV cache is RoPEd, the KV data is updated accordingly:
77- /// - lazily on next llama_decode()
78- /// - explicitly with llama_kv_cache_update()
72+ /// - lazily on next [`LlamaContext::decode`]
73+ /// - explicitly with [`Self::kv_cache_update`]
7974 ///
8075 /// # Parameters
8176 ///
@@ -84,53 +79,51 @@ impl LlamaContext<'_> {
8479 /// * `p1` - The end position of the cache to update. If `None`, the entire cache is updated starting from [p0].
8580 /// * `delta` - The relative position to add to the tokens
8681 pub fn kv_cache_seq_add ( & mut self , seq_id : i32 , p0 : Option < u16 > , p1 : Option < u16 > , delta : i32 ) {
82+ let p0 = p0. map_or ( -1 , i32:: from) ;
83+ let p1 = p1. map_or ( -1 , i32:: from) ;
8784 unsafe {
88- llama_cpp_sys_2:: llama_kv_cache_seq_add (
89- self . context . as_ptr ( ) ,
90- seq_id,
91- p0. map_or ( -1 , i32:: from) ,
92- p1. map_or ( -1 , i32:: from) ,
93- delta,
94- )
85+ llama_cpp_sys_2:: llama_kv_cache_seq_add ( self . context . as_ptr ( ) , seq_id, p0, p1, delta) ;
9586 }
9687 }
9788
9889 /// Integer division of the positions by factor of `d > 1`
99- /// If the KV cache is RoPEd, the KV data is updated accordingly:
100- /// - lazily on next llama_decode()
101- /// - explicitly with llama_kv_cache_update()
90+ /// If the KV cache is ` RoPEd` , the KV data is updated accordingly:
91+ /// - lazily on next [`LlamaContext::decode`]
92+ /// - explicitly with [`Self::kv_cache_update`]
10293 ///
10394 /// # Parameters
10495 ///
10596 /// * `seq_id` - The sequence id to update
10697 /// * `p0` - The start position of the cache to update. If `None`, the entire cache is updated up to [p1].
10798 /// * `p1` - The end position of the cache to update. If `None`, the entire cache is updated starting from [p0].
10899 /// * `d` - The factor to divide the positions by
109- pub fn kv_cache_seq_div ( & mut self , seq_id : i32 , p0 : Option < u16 > , p1 : Option < u16 > , d : NonZeroU8 ) {
110- unsafe {
111- llama_cpp_sys_2:: llama_kv_cache_seq_div (
112- self . context . as_ptr ( ) ,
113- seq_id,
114- p0. map_or ( -1 , i32:: from) ,
115- p1. map_or ( -1 , i32:: from) ,
116- d. get ( ) . try_into ( ) . expect ( "d does not fit into a i32" ) ,
117- )
118- }
100+ pub fn kv_cache_seq_div (
101+ & mut self ,
102+ seq_id : i32 ,
103+ p0 : Option < u16 > ,
104+ p1 : Option < u16 > ,
105+ d : NonZeroU8 ,
106+ ) {
107+ let p0 = p0. map_or ( -1 , i32:: from) ;
108+ let p1 = p1. map_or ( -1 , i32:: from) ;
109+ let d = c_int:: from ( d. get ( ) ) ;
110+ unsafe { llama_cpp_sys_2:: llama_kv_cache_seq_div ( self . context . as_ptr ( ) , seq_id, p0, p1, d) }
119111 }
120112
121113 /// Returns the largest position present in the KV cache for the specified sequence
122114 ///
123115 /// # Parameters
124116 ///
125117 /// * `seq_id` - The sequence id to get the max position for
118+ #[ must_use]
126119 pub fn kv_cache_seq_pos_max ( & self , seq_id : i32 ) -> i32 {
127120 unsafe { llama_cpp_sys_2:: llama_kv_cache_seq_pos_max ( self . context . as_ptr ( ) , seq_id) }
128121 }
129122
130123 /// Defragment the KV cache
131124 /// This will be applied:
132- /// - lazily on next llama_decode()
133- /// - explicitly with llama_kv_cache_update()
125+ /// - lazily on next [`LlamaContext::decode`]
126+ /// - explicitly with [`Self::kv_cache_update`]
134127 pub fn kv_cache_defrag ( & mut self ) {
135128 unsafe { llama_cpp_sys_2:: llama_kv_cache_defrag ( self . context . as_ptr ( ) ) }
136129 }
@@ -142,6 +135,7 @@ impl LlamaContext<'_> {
142135
143136 /// Returns the number of tokens in the KV cache (slow, use only for debug)
144137 /// If a KV cell has multiple sequences assigned to it, it will be counted multiple times
138+ #[ must_use]
145139 pub fn get_kv_cache_token_count ( & self ) -> i32 {
146140 unsafe { llama_cpp_sys_2:: llama_get_kv_cache_token_count ( self . context . as_ptr ( ) ) }
147141 }
@@ -152,14 +146,15 @@ impl LlamaContext<'_> {
152146 ///
153147 /// * `n_max_seq` - Maximum number of sequences that can exist in a cell. It's not an error
154148 /// if there are more sequences in a cell than this value, however they will
155- /// not be visible in the view cells_sequences.
149+ /// not be visible in the view `cells_sequences`.
150+ #[ must_use]
156151 pub fn new_kv_cache_view ( & self , n_max_seq : i32 ) -> KVCacheView {
157- let view = unsafe { llama_cpp_sys_2:: llama_kv_cache_view_init ( self . context . as_ptr ( ) , n_max_seq) } ;
152+ let view =
153+ unsafe { llama_cpp_sys_2:: llama_kv_cache_view_init ( self . context . as_ptr ( ) , n_max_seq) } ;
158154 KVCacheView { view, ctx : self }
159155 }
160156}
161157
162-
163158/// Information associated with an individual cell in the KV cache view.
164159#[ derive( Debug ) ]
165160pub struct KVCacheViewCell {
@@ -178,48 +173,75 @@ pub struct KVCacheView<'a> {
178173impl < ' a > KVCacheView < ' a > {
179174 /// Update the KV cache view structure with the current state of the KV cache. (use only for debugging purposes)
180175 pub fn update ( & mut self ) {
181- unsafe { llama_cpp_sys_2:: llama_kv_cache_view_update ( self . ctx . context . as_ptr ( ) , & mut self . view ) }
176+ unsafe {
177+ llama_cpp_sys_2:: llama_kv_cache_view_update ( self . ctx . context . as_ptr ( ) , & mut self . view ) ;
178+ }
182179 }
183180
184181 /// Number of KV cache cells. This will be the same as the context size.
182+ #[ must_use]
185183 pub fn n_cells ( & self ) -> i32 {
186184 self . view . n_cells
187185 }
188186
189187 /// Number of tokens in the cache. For example, if there are two populated
190188 /// cells, the first with 1 sequence id in it and the second with 2 sequence
191189 /// ids then you'll have 3 tokens.
190+ #[ must_use]
192191 pub fn token_count ( & self ) -> i32 {
193192 self . view . token_count
194193 }
195194
196195 /// Number of populated cache cells.
196+ #[ must_use]
197197 pub fn used_cells ( & self ) -> i32 {
198198 self . view . used_cells
199199 }
200200
201201 /// Maximum contiguous empty slots in the cache.
202+ #[ must_use]
202203 pub fn max_contiguous ( & self ) -> i32 {
203204 self . view . max_contiguous
204205 }
205206
206- /// Index to the start of the max_contiguous slot range. Can be negative
207+ /// Index to the start of the ` max_contiguous` slot range. Can be negative
207208 /// when cache is full.
209+ #[ must_use]
208210 pub fn max_contiguous_idx ( & self ) -> i32 {
209211 self . view . max_contiguous_idx
210212 }
211213
212214 /// Information for individual cells.
213- pub fn cells ( & self ) -> impl Iterator < Item =KVCacheViewCell > {
214- unsafe { std:: slice:: from_raw_parts ( self . view . cells , self . view . n_cells . try_into ( ) . unwrap ( ) ) }
215- . iter ( )
216- . map ( |& cell| KVCacheViewCell { pos : cell. pos } )
215+ ///
216+ /// # Panics
217+ ///
218+ /// - if `n_cells` does not fit into usize.
219+ pub fn cells ( & self ) -> impl Iterator < Item = KVCacheViewCell > {
220+ unsafe {
221+ std:: slice:: from_raw_parts (
222+ self . view . cells ,
223+ usize:: try_from ( self . view . n_cells ) . expect ( "failed to fit n_cells into usize" ) ,
224+ )
225+ }
226+ . iter ( )
227+ . map ( |& cell| KVCacheViewCell { pos : cell. pos } )
217228 }
218229
219- /// The sequences for each cell. There will be n_max_seq items per cell.
220- pub fn cells_sequences ( & self ) -> impl Iterator < Item =& [ llama_cpp_sys_2:: llama_seq_id ] > {
221- unsafe { std:: slice:: from_raw_parts ( self . view . cells_sequences , ( self . view . n_cells * self . view . n_max_seq ) . try_into ( ) . unwrap ( ) ) }
222- . chunks ( self . view . n_max_seq . try_into ( ) . unwrap ( ) )
230+ /// The sequences for each cell. There will be `n_max_seq` items per cell.
231+ ///
232+ /// # Panics
233+ ///
234+ /// - if `n_cells * n_max_seq` does not fit into usize.
235+ /// - if `n_max_seq` does not fit into usize.
236+ pub fn cells_sequences ( & self ) -> impl Iterator < Item = & [ llama_cpp_sys_2:: llama_seq_id ] > {
237+ unsafe {
238+ std:: slice:: from_raw_parts (
239+ self . view . cells_sequences ,
240+ usize:: try_from ( self . view . n_cells * self . view . n_max_seq )
241+ . expect ( "failed to fit n_cells * n_max_seq into usize" ) ,
242+ )
243+ }
244+ . chunks ( usize:: try_from ( self . view . n_max_seq ) . expect ( "failed to fit n_max_seq into usize" ) )
223245 }
224246}
225247
@@ -229,4 +251,4 @@ impl<'a> Drop for KVCacheView<'a> {
229251 llama_cpp_sys_2:: llama_kv_cache_view_free ( & mut self . view ) ;
230252 }
231253 }
232- }
254+ }
0 commit comments