Skip to content

Commit d5a5645

Browse files
authored
Merge pull request #672 from nhaghighat/feature/add-samplers-and-lifecycle-methods
Add top_n_sigma and grammar_lazy samplers; Add reset and get_seed methods to LlamaSampler
2 parents 8624b3f + 914bce3 commit d5a5645

File tree

1 file changed

+94
-0
lines changed

1 file changed

+94
-0
lines changed

llama-cpp-2/src/sampling.rs

Lines changed: 94 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -62,6 +62,26 @@ impl LlamaSampler {
6262
self
6363
}
6464

65+
/// Resets the internal state of the sampler.
66+
///
67+
/// This can be useful when you want to start fresh with a sampler without creating a new instance.
68+
pub fn reset(&mut self) {
69+
unsafe {
70+
llama_cpp_sys_2::llama_sampler_reset(self.sampler);
71+
}
72+
}
73+
74+
/// Gets the random seed used by this sampler.
75+
///
76+
/// Returns:
77+
/// - For random samplers (dist, mirostat, mirostat_v2): returns their current seed
78+
/// - For sampler chains: returns the first non-default seed found in reverse order
79+
/// - For all other samplers: returns 0xFFFFFFFF
80+
#[must_use]
81+
pub fn get_seed(&self) -> u32 {
82+
unsafe { llama_cpp_sys_2::llama_sampler_get_seed(self.sampler) }
83+
}
84+
6585
/// Combines a list of samplers into a single sampler that applies each component sampler one
6686
/// after another.
6787
///
@@ -191,6 +211,37 @@ impl LlamaSampler {
191211
Self { sampler }
192212
}
193213

214+
/// Top-nσ sampling as described in academic paper "Top-nσ: Not All Logits Are You Need"
215+
/// <https://arxiv.org/pdf/2411.07641>
216+
///
217+
/// This method filters logits by selecting only those within *n* standard deviations of the mean.
218+
///
219+
/// # Parameters
220+
/// - `n`: Number of standard deviations from the mean to include in sampling
221+
///
222+
/// # Example
223+
/// ```rust
224+
/// use llama_cpp_2::sampling::LlamaSampler;
225+
/// use llama_cpp_2::token::{
226+
/// LlamaToken,
227+
/// data::LlamaTokenData,
228+
/// data_array::LlamaTokenDataArray
229+
/// };
230+
///
231+
/// let mut data_array = LlamaTokenDataArray::new(vec![
232+
/// LlamaTokenData::new(LlamaToken(0), 0.0, 0.0),
233+
/// LlamaTokenData::new(LlamaToken(1), 1.0, 0.0),
234+
/// LlamaTokenData::new(LlamaToken(2), 2.0, 0.0),
235+
/// ], false);
236+
///
237+
/// data_array.apply_sampler(&mut LlamaSampler::top_n_sigma(2.0));
238+
/// ```
239+
#[must_use]
240+
pub fn top_n_sigma(n: f32) -> Self {
241+
let sampler = unsafe { llama_cpp_sys_2::llama_sampler_init_top_n_sigma(n) };
242+
Self { sampler }
243+
}
244+
194245
/// Locally Typical Sampling implementation described in the paper <https://arxiv.org/abs/2202.00666>.
195246
#[must_use]
196247
pub fn typical(p: f32, min_keep: usize) -> Self {
@@ -239,6 +290,49 @@ impl LlamaSampler {
239290
Self { sampler }
240291
}
241292

293+
/// Lazy grammar sampler, introduced in <https://github.com/ggerganov/llama.cpp/pull/9639>
294+
///
295+
/// This sampler enforces grammar rules only when specific trigger words or tokens are encountered.
296+
///
297+
/// # Panics
298+
/// - If `grammar_str` or `grammar_root` contain null bytes
299+
/// - If any trigger word contains null bytes
300+
#[must_use]
301+
pub fn grammar_lazy(
302+
model: &LlamaModel,
303+
grammar_str: &str,
304+
grammar_root: &str,
305+
trigger_words: impl IntoIterator<Item = impl AsRef<[u8]>>,
306+
trigger_tokens: &[LlamaToken],
307+
) -> Self {
308+
let grammar_str = CString::new(grammar_str).unwrap();
309+
let grammar_root = CString::new(grammar_root).unwrap();
310+
311+
let trigger_word_cstrings: Vec<CString> = trigger_words
312+
.into_iter()
313+
.map(|word| CString::new(word.as_ref()).unwrap())
314+
.collect();
315+
316+
let mut trigger_word_ptrs: Vec<*const c_char> = trigger_word_cstrings
317+
.iter()
318+
.map(|cs| cs.as_ptr())
319+
.collect();
320+
321+
let sampler = unsafe {
322+
llama_cpp_sys_2::llama_sampler_init_grammar_lazy(
323+
model.vocab_ptr(),
324+
grammar_str.as_ptr(),
325+
grammar_root.as_ptr(),
326+
trigger_word_ptrs.as_mut_ptr(),
327+
trigger_word_ptrs.len(),
328+
trigger_tokens.as_ptr().cast(),
329+
trigger_tokens.len(),
330+
)
331+
};
332+
333+
Self { sampler }
334+
}
335+
242336
/// DRY sampler, designed by p-e-w, as described in:
243337
/// <https://github.com/oobabooga/text-generation-webui/pull/5677>, porting Koboldcpp
244338
/// implementation authored by pi6am: <https://github.com/LostRuins/koboldcpp/pull/982>

0 commit comments

Comments
 (0)