1+ //! utilities for working with session files
2+
3+ use std:: ffi:: { CString , NulError } ;
4+ use std:: path:: { Path , PathBuf } ;
5+ use crate :: context:: LlamaContext ;
6+ use crate :: token:: LlamaToken ;
7+
8+ #[ derive( Debug , Eq , PartialEq , thiserror:: Error ) ]
9+ pub enum SaveSessionError {
10+ #[ error( "Failed to save session file" ) ]
11+ FailedToSave ,
12+
13+ #[ error( "null byte in string {0}" ) ]
14+ NullError ( #[ from] NulError ) ,
15+
16+ #[ error( "failed to convert path {0} to str" ) ]
17+ PathToStrError ( PathBuf ) ,
18+ }
19+
20+ #[ derive( Debug , Eq , PartialEq , thiserror:: Error ) ]
21+ pub enum LoadSessionError {
22+ #[ error( "Failed to load session file" ) ]
23+ FailedToLoad ,
24+
25+ #[ error( "null byte in string {0}" ) ]
26+ NullError ( #[ from] NulError ) ,
27+
28+ #[ error( "failed to convert path {0} to str" ) ]
29+ PathToStrError ( PathBuf ) ,
30+ }
31+
32+ impl LlamaContext < ' _ > {
33+ pub fn save_session_file ( & self , path_session : impl AsRef < Path > , tokens : & [ LlamaToken ] ) -> Result < ( ) , SaveSessionError > {
34+ let path = path_session. as_ref ( ) ;
35+ let path = path
36+ . to_str ( )
37+ . ok_or ( SaveSessionError :: PathToStrError ( path. to_path_buf ( ) ) ) ?;
38+
39+ let cstr = CString :: new ( path) ?;
40+
41+ if unsafe {
42+ llama_cpp_sys_2:: llama_save_session_file (
43+ self . context . as_ptr ( ) ,
44+ cstr. as_ptr ( ) ,
45+ tokens. as_ptr ( ) as * const i32 ,
46+ tokens. len ( ) )
47+ } {
48+ Ok ( ( ) )
49+ } else {
50+ Err ( SaveSessionError :: FailedToSave )
51+ }
52+ }
53+ pub fn load_session_file ( & mut self , path_session : impl AsRef < Path > , max_tokens : usize ) -> Result < Vec < LlamaToken > , LoadSessionError > {
54+ let path = path_session. as_ref ( ) ;
55+ let path = path
56+ . to_str ( )
57+ . ok_or ( LoadSessionError :: PathToStrError ( path. to_path_buf ( ) ) ) ?;
58+
59+ let cstr = CString :: new ( path) ?;
60+ let mut tokens = Vec :: with_capacity ( max_tokens) ;
61+ let mut n_out = 0 ;
62+
63+ unsafe {
64+ if llama_cpp_sys_2:: llama_load_session_file (
65+ self . context . as_ptr ( ) ,
66+ cstr. as_ptr ( ) ,
67+ tokens. as_mut_ptr ( ) as * mut i32 ,
68+ max_tokens,
69+ & mut n_out) {
70+ tokens. set_len ( n_out) ;
71+ Ok ( tokens)
72+ } else {
73+ Err ( LoadSessionError :: FailedToLoad )
74+ }
75+ }
76+ }
77+ }
0 commit comments