@@ -132,6 +132,50 @@ impl LlamaContextParams {
132132 NonZeroU32 :: new ( self . context_params . n_ctx )
133133 }
134134
135+ /// Set the n_batch
136+ ///
137+ /// # Examples
138+ ///
139+ /// ```rust
140+ /// # use std::num::NonZeroU32;
141+ /// use llama_cpp_2::context::params::LlamaContextParams;
142+ /// let params = LlamaContextParams::default()
143+ /// .with_n_batch(2048);
144+ /// assert_eq!(params.n_batch(), 2048);
145+ /// ```
146+ pub fn with_n_batch ( mut self , n_batch : u32 ) -> Self {
147+ self . context_params . n_batch = n_batch;
148+ self
149+ }
150+
151+ /// Get the n_batch
152+ ///
153+ /// # Examples
154+ ///
155+ /// ```rust
156+ /// use llama_cpp_2::context::params::LlamaContextParams;
157+ /// let params = LlamaContextParams::default();
158+ /// assert_eq!(params.n_batch(), 512);
159+ /// ```
160+ pub fn n_batch ( & self ) -> u32 {
161+ self . context_params . n_batch
162+ }
163+
164+ /// Set the type of rope scaling.
165+ ///
166+ /// # Examples
167+ ///
168+ /// ```rust
169+ /// use llama_cpp_2::context::params::{LlamaContextParams, RopeScalingType};
170+ /// let params = LlamaContextParams::default()
171+ /// .with_rope_scaling_type(RopeScalingType::Linear);
172+ /// assert_eq!(params.rope_scaling_type(), RopeScalingType::Linear);
173+ /// ```
174+ pub fn with_rope_scaling_type ( mut self , rope_scaling_type : RopeScalingType ) -> Self {
175+ self . context_params . rope_scaling_type = i8:: from ( rope_scaling_type) ;
176+ self
177+ }
178+
135179 /// Get the type of rope scaling.
136180 ///
137181 /// # Examples
@@ -143,6 +187,60 @@ impl LlamaContextParams {
143187 pub fn rope_scaling_type ( & self ) -> RopeScalingType {
144188 RopeScalingType :: from ( self . context_params . rope_scaling_type )
145189 }
190+
191+ /// Set the rope frequency base.
192+ ///
193+ /// # Examples
194+ ///
195+ /// ```rust
196+ /// use llama_cpp_2::context::params::LlamaContextParams;
197+ /// let params = LlamaContextParams::default()
198+ /// .with_rope_freq_base(0.5);
199+ /// assert_eq!(params.rope_freq_base(), 0.5);
200+ /// ```
201+ pub fn with_rope_freq_base ( mut self , rope_freq_base : f32 ) -> Self {
202+ self . context_params . rope_freq_base = rope_freq_base;
203+ self
204+ }
205+
206+ /// Get the rope frequency base.
207+ ///
208+ /// # Examples
209+ ///
210+ /// ```rust
211+ /// let params = llama_cpp_2::context::params::LlamaContextParams::default();
212+ /// assert_eq!(params.rope_freq_base(), 0.0);
213+ /// ```
214+ pub fn rope_freq_base ( & self ) -> f32 {
215+ self . context_params . rope_freq_base
216+ }
217+
218+ /// Set the rope frequency scale.
219+ ///
220+ /// # Examples
221+ ///
222+ /// ```rust
223+ /// use llama_cpp_2::context::params::LlamaContextParams;
224+ /// let params = LlamaContextParams::default()
225+ /// .with_rope_freq_scale(0.5);
226+ /// assert_eq!(params.rope_freq_scale(), 0.5);
227+ /// ```
228+ pub fn with_rope_freq_scale ( mut self , rope_freq_scale : f32 ) -> Self {
229+ self . context_params . rope_freq_scale = rope_freq_scale;
230+ self
231+ }
232+
233+ /// Get the rope frequency scale.
234+ ///
235+ /// # Examples
236+ ///
237+ /// ```rust
238+ /// let params = llama_cpp_2::context::params::LlamaContextParams::default();
239+ /// assert_eq!(params.rope_freq_scale(), 0.0);
240+ /// ```
241+ pub fn rope_freq_scale ( & self ) -> f32 {
242+ self . context_params . rope_freq_scale
243+ }
146244}
147245
148246/// Default parameters for `LlamaContext`. (as defined in llama.cpp by `llama_context_default_params`)
@@ -156,6 +254,6 @@ impl LlamaContextParams {
156254impl Default for LlamaContextParams {
157255 fn default ( ) -> Self {
158256 let context_params = unsafe { llama_cpp_sys_2:: llama_context_default_params ( ) } ;
159- Self { context_params, }
257+ Self { context_params }
160258 }
161259}
0 commit comments