@@ -3,6 +3,45 @@ use llama_cpp_sys_2::{ggml_type, llama_context_params};
33use std:: fmt:: Debug ;
44use std:: num:: NonZeroU32 ;
55
/// Typed counterpart of the raw `rope_scaling_type` value.
///
/// The enum is `#[repr(i8)]` and every variant has an explicit discriminant.
#[derive(Debug, Clone, Copy, Eq, PartialEq)]
#[repr(i8)]
pub enum RopeScalingType {
    /// The scaling type has not been specified.
    Unspecified = -1,
    /// No scaling.
    None = 0,
    /// Linear scaling.
    Linear = 1,
    /// Yarn scaling.
    Yarn = 2,
}
19+
20+ /// Create a `RopeScalingType` from a `c_int` - returns `RopeScalingType::ScalingUnspecified` if
21+ /// the value is not recognized.
22+ impl From < i8 > for RopeScalingType {
23+ fn from ( value : i8 ) -> Self {
24+ match value {
25+ 0 => Self :: None ,
26+ 1 => Self :: Linear ,
27+ 2 => Self :: Yarn ,
28+ _ => Self :: Unspecified ,
29+ }
30+ }
31+ }
32+
33+ /// Create a `c_int` from a `RopeScalingType`.
34+ impl From < RopeScalingType > for i8 {
35+ fn from ( value : RopeScalingType ) -> Self {
36+ match value {
37+ RopeScalingType :: None => 0 ,
38+ RopeScalingType :: Linear => 1 ,
39+ RopeScalingType :: Yarn => 2 ,
40+ RopeScalingType :: Unspecified => -1 ,
41+ }
42+ }
43+ }
44+
645/// A safe wrapper around `llama_context_params`.
746#[ derive( Debug , Clone , Copy , PartialEq ) ]
847#[ allow(
@@ -18,7 +57,7 @@ pub struct LlamaContextParams {
1857 pub n_batch : u32 ,
1958 pub n_threads : u32 ,
2059 pub n_threads_batch : u32 ,
21- pub rope_scaling_type : i8 ,
60+ pub rope_scaling_type : RopeScalingType ,
2261 pub rope_freq_base : f32 ,
2362 pub rope_freq_scale : f32 ,
2463 pub yarn_ext_factor : f32 ,
@@ -83,7 +122,7 @@ impl From<llama_context_params> for LlamaContextParams {
83122 mul_mat_q,
84123 logits_all,
85124 embedding,
86- rope_scaling_type,
125+ rope_scaling_type : RopeScalingType :: from ( rope_scaling_type ) ,
87126 yarn_ext_factor,
88127 yarn_attn_factor,
89128 yarn_beta_fast,
@@ -131,7 +170,7 @@ impl From<LlamaContextParams> for llama_context_params {
131170 mul_mat_q,
132171 logits_all,
133172 embedding,
134- rope_scaling_type,
173+ rope_scaling_type : i8 :: from ( rope_scaling_type ) ,
135174 yarn_ext_factor,
136175 yarn_attn_factor,
137176 yarn_beta_fast,
0 commit comments