Skip to content

Commit 52687ca

Browse files
authored
Merge pull request #116 from zh217/main
Expose the complete API for dealing with KV cache and states
2 parents e85ed8c + 2967bd6 commit 52687ca

File tree

2 files changed

+224
-1
lines changed

2 files changed

+224
-1
lines changed

llama-cpp-2/src/context/kv_cache.rs

Lines changed: 198 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
//! utilities for working with the kv cache
22
3+
use std::num::NonZeroU8;
34
use crate::context::LlamaContext;
45

56
impl LlamaContext<'_> {
@@ -14,6 +15,26 @@ impl LlamaContext<'_> {
1415
unsafe { llama_cpp_sys_2::llama_kv_cache_seq_cp(self.context.as_ptr(), src, dest, 0, size) }
1516
}
1617

18+
/// Copy the cache from one sequence to another.
19+
///
20+
/// # Parameters
21+
///
22+
/// * `src` - The sequence id to copy the cache from.
23+
/// * `dest` - The sequence id to copy the cache to.
24+
/// * `p0` - The start position of the cache to clear. If `None`, the entire cache is copied up to [p1].
25+
/// * `p1` - The end position of the cache to clear. If `None`, the entire cache is copied starting from [p0].
26+
pub fn copy_kv_cache_seq(&mut self, src: i32, dest: i32, p0: Option<u16>, p1: Option<u16>) {
27+
unsafe {
28+
llama_cpp_sys_2::llama_kv_cache_seq_cp(
29+
self.context.as_ptr(),
30+
src,
31+
dest,
32+
p0.map_or(-1, i32::from),
33+
p1.map_or(-1, i32::from),
34+
)
35+
}
36+
}
37+
1738
/// Clear the kv cache for the given sequence.
1839
///
1940
/// # Parameters
@@ -31,4 +52,181 @@ impl LlamaContext<'_> {
3152
);
3253
}
3354
}
55+
56+
/// Returns the number of used KV cells (i.e. have at least one sequence assigned to them)
57+
pub fn get_kv_cache_used_cells(&self) -> i32 {
58+
unsafe { llama_cpp_sys_2::llama_get_kv_cache_used_cells(self.context.as_ptr()) }
59+
}
60+
61+
/// Clear the KV cache
62+
pub fn clear_kv_cache(&mut self) {
63+
unsafe { llama_cpp_sys_2::llama_kv_cache_clear(self.context.as_ptr()) }
64+
}
65+
66+
/// Removes all tokens that do not belong to the specified sequence
67+
///
68+
/// # Parameters
69+
///
70+
/// * `seq_id` - The sequence id to keep
71+
pub fn llama_kv_cache_seq_keep(&mut self, seq_id: i32) {
72+
unsafe { llama_cpp_sys_2::llama_kv_cache_seq_keep(self.context.as_ptr(), seq_id) }
73+
}
74+
75+
/// Adds relative position "delta" to all tokens that belong to the specified sequence and have positions in [p0, p1)
76+
/// If the KV cache is RoPEd, the KV data is updated accordingly:
77+
/// - lazily on next llama_decode()
78+
/// - explicitly with llama_kv_cache_update()
79+
///
80+
/// # Parameters
81+
///
82+
/// * `seq_id` - The sequence id to update
83+
/// * `p0` - The start position of the cache to update. If `None`, the entire cache is updated up to [p1].
84+
/// * `p1` - The end position of the cache to update. If `None`, the entire cache is updated starting from [p0].
85+
/// * `delta` - The relative position to add to the tokens
86+
pub fn kv_cache_seq_add(&mut self, seq_id: i32, p0: Option<u16>, p1: Option<u16>, delta: i32) {
87+
unsafe {
88+
llama_cpp_sys_2::llama_kv_cache_seq_add(
89+
self.context.as_ptr(),
90+
seq_id,
91+
p0.map_or(-1, i32::from),
92+
p1.map_or(-1, i32::from),
93+
delta,
94+
)
95+
}
96+
}
97+
98+
/// Integer division of the positions by factor of `d > 1`
99+
/// If the KV cache is RoPEd, the KV data is updated accordingly:
100+
/// - lazily on next llama_decode()
101+
/// - explicitly with llama_kv_cache_update()
102+
///
103+
/// # Parameters
104+
///
105+
/// * `seq_id` - The sequence id to update
106+
/// * `p0` - The start position of the cache to update. If `None`, the entire cache is updated up to [p1].
107+
/// * `p1` - The end position of the cache to update. If `None`, the entire cache is updated starting from [p0].
108+
/// * `d` - The factor to divide the positions by
109+
pub fn kv_cache_seq_div(&mut self, seq_id: i32, p0: Option<u16>, p1: Option<u16>, d: NonZeroU8) {
110+
unsafe {
111+
llama_cpp_sys_2::llama_kv_cache_seq_div(
112+
self.context.as_ptr(),
113+
seq_id,
114+
p0.map_or(-1, i32::from),
115+
p1.map_or(-1, i32::from),
116+
d.get().try_into().expect("d does not fit into a i32"),
117+
)
118+
}
119+
}
120+
121+
/// Returns the largest position present in the KV cache for the specified sequence
122+
///
123+
/// # Parameters
124+
///
125+
/// * `seq_id` - The sequence id to get the max position for
126+
pub fn kv_cache_seq_pos_max(&self, seq_id: i32) -> i32 {
127+
unsafe { llama_cpp_sys_2::llama_kv_cache_seq_pos_max(self.context.as_ptr(), seq_id) }
128+
}
129+
130+
/// Defragment the KV cache
131+
/// This will be applied:
132+
/// - lazily on next llama_decode()
133+
/// - explicitly with llama_kv_cache_update()
134+
pub fn kv_cache_defrag(&mut self) {
135+
unsafe { llama_cpp_sys_2::llama_kv_cache_defrag(self.context.as_ptr()) }
136+
}
137+
138+
/// Apply the KV cache updates (such as K-shifts, defragmentation, etc.)
139+
pub fn kv_cache_update(&mut self) {
140+
unsafe { llama_cpp_sys_2::llama_kv_cache_update(self.context.as_ptr()) }
141+
}
142+
143+
/// Returns the number of tokens in the KV cache (slow, use only for debug)
144+
/// If a KV cell has multiple sequences assigned to it, it will be counted multiple times
145+
pub fn get_kv_cache_token_count(&self) -> i32 {
146+
unsafe { llama_cpp_sys_2::llama_get_kv_cache_token_count(self.context.as_ptr()) }
147+
}
148+
149+
/// Create an empty KV cache view. (use only for debugging purposes)
150+
///
151+
/// # Parameters
152+
///
153+
/// * `n_max_seq` - Maximum number of sequences that can exist in a cell. It's not an error
154+
/// if there are more sequences in a cell than this value, however they will
155+
/// not be visible in the view cells_sequences.
156+
pub fn new_kv_cache_view(&self, n_max_seq: i32) -> KVCacheView {
157+
let view = unsafe { llama_cpp_sys_2::llama_kv_cache_view_init(self.context.as_ptr(), n_max_seq) };
158+
KVCacheView { view, ctx: self }
159+
}
160+
}
161+
162+
163+
/// Information associated with an individual cell in the KV cache view.
164+
#[derive(Debug)]
165+
pub struct KVCacheViewCell {
166+
/// The position for this cell. Takes KV cache shifts into account.
167+
/// May be negative if the cell is not populated.
168+
pub pos: llama_cpp_sys_2::llama_pos,
169+
}
170+
171+
/// An updateable view of the KV cache. (use only for debugging purposes)
172+
#[derive(Debug)]
173+
pub struct KVCacheView<'a> {
174+
ctx: &'a LlamaContext<'a>,
175+
view: llama_cpp_sys_2::llama_kv_cache_view,
34176
}
177+
178+
impl<'a> KVCacheView<'a> {
179+
/// Update the KV cache view structure with the current state of the KV cache. (use only for debugging purposes)
180+
pub fn update(&mut self) {
181+
unsafe { llama_cpp_sys_2::llama_kv_cache_view_update(self.ctx.context.as_ptr(), &mut self.view) }
182+
}
183+
184+
/// Number of KV cache cells. This will be the same as the context size.
185+
pub fn n_cells(&self) -> i32 {
186+
self.view.n_cells
187+
}
188+
189+
/// Number of tokens in the cache. For example, if there are two populated
190+
/// cells, the first with 1 sequence id in it and the second with 2 sequence
191+
/// ids then you'll have 3 tokens.
192+
pub fn token_count(&self) -> i32 {
193+
self.view.token_count
194+
}
195+
196+
/// Number of populated cache cells.
197+
pub fn used_cells(&self) -> i32 {
198+
self.view.used_cells
199+
}
200+
201+
/// Maximum contiguous empty slots in the cache.
202+
pub fn max_contiguous(&self) -> i32 {
203+
self.view.max_contiguous
204+
}
205+
206+
/// Index to the start of the max_contiguous slot range. Can be negative
207+
/// when cache is full.
208+
pub fn max_contiguous_idx(&self) -> i32 {
209+
self.view.max_contiguous_idx
210+
}
211+
212+
/// Information for individual cells.
213+
pub fn cells(&self) -> impl Iterator<Item=KVCacheViewCell> {
214+
unsafe { std::slice::from_raw_parts(self.view.cells, self.view.n_cells.try_into().unwrap()) }
215+
.iter()
216+
.map(|&cell| KVCacheViewCell { pos: cell.pos })
217+
}
218+
219+
/// The sequences for each cell. There will be n_max_seq items per cell.
220+
pub fn cells_sequences(&self) -> impl Iterator<Item=&[llama_cpp_sys_2::llama_seq_id]> {
221+
unsafe { std::slice::from_raw_parts(self.view.cells_sequences, (self.view.n_cells * self.view.n_max_seq).try_into().unwrap()) }
222+
.chunks(self.view.n_max_seq.try_into().unwrap())
223+
}
224+
}
225+
226+
impl<'a> Drop for KVCacheView<'a> {
227+
fn drop(&mut self) {
228+
unsafe {
229+
llama_cpp_sys_2::llama_kv_cache_view_free(&mut self.view);
230+
}
231+
}
232+
}

llama-cpp-2/src/context/session.rs

Lines changed: 26 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,7 @@ pub enum LoadSessionError {
3535
/// failed to convert path to str
3636
#[error("failed to convert path {0} to str")]
3737
PathToStrError(PathBuf),
38-
38+
3939
/// Insufficient max length
4040
#[error("max_length is not large enough to hold {n_out} (was {max_tokens})")]
4141
InsufficientMaxLength {
@@ -130,4 +130,29 @@ impl LlamaContext<'_> {
130130
}
131131
}
132132
}
133+
134+
/// Returns the maximum size in bytes of the state (rng, logits, embedding
135+
/// and kv_cache) - will often be smaller after compacting tokens
136+
pub fn get_state_size(&self) -> usize {
137+
unsafe { llama_cpp_sys_2::llama_get_state_size(self.context.as_ptr()) }
138+
}
139+
140+
/// Copies the state to the specified destination address.
141+
/// Destination needs to have allocated enough memory.
142+
/// Returns the number of bytes copied
143+
pub unsafe fn copy_state_data(&self, dest: *mut u8) -> usize {
144+
unsafe {
145+
llama_cpp_sys_2::llama_copy_state_data(self.context.as_ptr(), dest)
146+
}
147+
}
148+
149+
/// Set the state reading from the specified address
150+
/// Returns the number of bytes read
151+
pub unsafe fn set_state_data(&mut self, src: &[u8]) -> usize {
152+
unsafe {
153+
// we don't really need a mutable pointer for `src` -- this is a llama-cpp lapse,
154+
// so we cast away the constness
155+
llama_cpp_sys_2::llama_set_state_data(self.context.as_ptr(), src.as_ptr() as *mut u8)
156+
}
157+
}
133158
}

0 commit comments

Comments
 (0)