Skip to content

Commit b8fad09

Browse files
committed
Expose functions llama_load_session_file and llama_save_session_file
1 parent 1c2306f commit b8fad09

File tree

2 files changed

+78
-0
lines changed

2 files changed

+78
-0
lines changed

llama-cpp-2/src/context.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@ use std::slice;
1515
pub mod kv_cache;
1616
pub mod params;
1717
pub mod sample;
18+
pub mod session;
1819

1920
/// Safe wrapper around `llama_context`.
2021
#[allow(clippy::module_name_repetitions)]

llama-cpp-2/src/context/session.rs

Lines changed: 77 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,77 @@
1+
//! utilities for working with session files
2+
3+
use std::ffi::{CString, NulError};
4+
use std::path::{Path, PathBuf};
5+
use crate::context::LlamaContext;
6+
use crate::token::LlamaToken;
7+
8+
#[derive(Debug, Eq, PartialEq, thiserror::Error)]
9+
pub enum SaveSessionError {
10+
#[error("Failed to save session file")]
11+
FailedToSave,
12+
13+
#[error("null byte in string {0}")]
14+
NullError(#[from] NulError),
15+
16+
#[error("failed to convert path {0} to str")]
17+
PathToStrError(PathBuf),
18+
}
19+
20+
#[derive(Debug, Eq, PartialEq, thiserror::Error)]
21+
pub enum LoadSessionError {
22+
#[error("Failed to load session file")]
23+
FailedToLoad,
24+
25+
#[error("null byte in string {0}")]
26+
NullError(#[from] NulError),
27+
28+
#[error("failed to convert path {0} to str")]
29+
PathToStrError(PathBuf),
30+
}
31+
32+
impl LlamaContext<'_> {
33+
pub fn save_session_file(&self, path_session: impl AsRef<Path>, tokens: &[LlamaToken]) -> Result<(), SaveSessionError> {
34+
let path = path_session.as_ref();
35+
let path = path
36+
.to_str()
37+
.ok_or(SaveSessionError::PathToStrError(path.to_path_buf()))?;
38+
39+
let cstr = CString::new(path)?;
40+
41+
if unsafe {
42+
llama_cpp_sys_2::llama_save_session_file(
43+
self.context.as_ptr(),
44+
cstr.as_ptr(),
45+
tokens.as_ptr() as *const i32,
46+
tokens.len())
47+
} {
48+
Ok(())
49+
} else {
50+
Err(SaveSessionError::FailedToSave)
51+
}
52+
}
53+
pub fn load_session_file(&mut self, path_session: impl AsRef<Path>, max_tokens: usize) -> Result<Vec<LlamaToken>, LoadSessionError> {
54+
let path = path_session.as_ref();
55+
let path = path
56+
.to_str()
57+
.ok_or(LoadSessionError::PathToStrError(path.to_path_buf()))?;
58+
59+
let cstr = CString::new(path)?;
60+
let mut tokens = Vec::with_capacity(max_tokens);
61+
let mut n_out = 0;
62+
63+
unsafe {
64+
if llama_cpp_sys_2::llama_load_session_file(
65+
self.context.as_ptr(),
66+
cstr.as_ptr(),
67+
tokens.as_mut_ptr() as *mut i32,
68+
max_tokens,
69+
&mut n_out) {
70+
tokens.set_len(n_out);
71+
Ok(tokens)
72+
} else {
73+
Err(LoadSessionError::FailedToLoad)
74+
}
75+
}
76+
}
77+
}

0 commit comments

Comments
 (0)