Skip to content

Commit 288811e

Browse files
authored
Merge pull request #874 from marek-hradil/main
Improve the grammar error handling
2 parents 5ee1d5b + 41c1883 commit 288811e

File tree

2 files changed

+68
-23
lines changed

2 files changed

+68
-23
lines changed

llama-cpp-2/src/lib.rs

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -156,6 +156,23 @@ pub enum EmbeddingsError {
156156
NonePoolType,
157157
}
158158

159+
/// Errors that can occur when initializing a grammar sampler
160+
#[derive(Debug, Eq, PartialEq, thiserror::Error)]
161+
pub enum GrammarError {
162+
/// The grammar root was not found in the grammar string
163+
#[error("Grammar root not found in grammar string")]
164+
RootNotFound,
165+
/// The trigger word contains null bytes
166+
#[error("Trigger word contains null bytes")]
167+
TriggerWordNullBytes,
168+
/// The grammar string or root contains null bytes
169+
#[error("Grammar string or root contains null bytes")]
170+
GrammarNullBytes,
171+
/// The grammar call returned null
172+
#[error("Grammar call returned null")]
173+
NullGrammar,
174+
}
175+
159176
/// Decode an error from llama.cpp into a [`DecodeError`].
160177
impl From<NonZeroI32> for DecodeError {
161178
fn from(value: NonZeroI32) -> Self {

llama-cpp-2/src/sampling.rs

Lines changed: 51 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@ use crate::model::LlamaModel;
99
use crate::token::data_array::LlamaTokenDataArray;
1010
use crate::token::logit_bias::LlamaLogitBias;
1111
use crate::token::LlamaToken;
12+
use crate::GrammarError;
1213

1314
/// A safe wrapper around `llama_sampler`.
1415
pub struct LlamaSampler {
@@ -274,13 +275,14 @@ impl LlamaSampler {
274275
}
275276

276277
/// Grammar sampler
277-
///
278-
/// # Panics
279-
/// If either of ``grammar_str`` or ``grammar_root`` contain null bytes.
280278
#[must_use]
281-
pub fn grammar(model: &LlamaModel, grammar_str: &str, grammar_root: &str) -> Option<Self> {
282-
let grammar_str = CString::new(grammar_str).unwrap();
283-
let grammar_root = CString::new(grammar_root).unwrap();
279+
pub fn grammar(
280+
model: &LlamaModel,
281+
grammar_str: &str,
282+
grammar_root: &str,
283+
) -> Result<Self, GrammarError> {
284+
let (grammar_str, grammar_root) =
285+
Self::sanitize_grammar_strings(grammar_str, grammar_root)?;
284286

285287
let sampler = unsafe {
286288
llama_cpp_sys_2::llama_sampler_init_grammar(
@@ -291,37 +293,29 @@ impl LlamaSampler {
291293
};
292294

293295
if sampler.is_null() {
294-
None
296+
Err(GrammarError::NullGrammar)
295297
} else {
296-
Some(Self { sampler })
298+
Ok(Self { sampler })
297299
}
298300
}
299301

300302
/// Lazy grammar sampler, introduced in <https://github.com/ggerganov/llama.cpp/pull/9639>
301303
///
302304
/// This sampler enforces grammar rules only when specific trigger words or tokens are encountered.
303-
///
304-
/// # Panics
305-
/// - If `grammar_str` or `grammar_root` contain null bytes
306-
/// - If any trigger word contains null bytes
307305
#[must_use]
308306
pub fn grammar_lazy(
309307
model: &LlamaModel,
310308
grammar_str: &str,
311309
grammar_root: &str,
312310
trigger_words: impl IntoIterator<Item = impl AsRef<[u8]>>,
313311
trigger_tokens: &[LlamaToken],
314-
) -> Option<Self> {
315-
let grammar_str = CString::new(grammar_str).unwrap();
316-
let grammar_root = CString::new(grammar_root).unwrap();
317-
318-
let trigger_word_cstrings: Vec<CString> = trigger_words
319-
.into_iter()
320-
.map(|word| CString::new(word.as_ref()).unwrap())
321-
.collect();
312+
) -> Result<Self, GrammarError> {
313+
let (grammar_str, grammar_root) =
314+
Self::sanitize_grammar_strings(grammar_str, grammar_root)?;
315+
let trigger_words = Self::sanitize_trigger_words(trigger_words)?;
322316

323317
let mut trigger_word_ptrs: Vec<*const c_char> =
324-
trigger_word_cstrings.iter().map(|cs| cs.as_ptr()).collect();
318+
trigger_words.iter().map(|cs| cs.as_ptr()).collect();
325319

326320
let sampler = unsafe {
327321
llama_cpp_sys_2::llama_sampler_init_grammar_lazy(
@@ -336,12 +330,46 @@ impl LlamaSampler {
336330
};
337331

338332
if sampler.is_null() {
339-
None
333+
Err(GrammarError::NullGrammar)
340334
} else {
341-
Some(Self { sampler })
335+
Ok(Self { sampler })
342336
}
343337
}
344338

339+
fn sanitize_grammar_strings(
340+
grammar_str: &str,
341+
grammar_root: &str,
342+
) -> Result<(CString, CString), GrammarError> {
343+
if !grammar_str.contains(grammar_root) {
344+
return Err(GrammarError::RootNotFound);
345+
}
346+
347+
if grammar_str.contains('\0') || grammar_root.contains('\0') {
348+
return Err(GrammarError::GrammarNullBytes);
349+
}
350+
351+
Ok((
352+
CString::new(grammar_str).unwrap(),
353+
CString::new(grammar_root).unwrap(),
354+
))
355+
}
356+
357+
fn sanitize_trigger_words(
358+
trigger_words: impl IntoIterator<Item = impl AsRef<[u8]>>,
359+
) -> Result<Vec<CString>, GrammarError> {
360+
let trigger_words: Vec<_> = trigger_words.into_iter().collect();
361+
if trigger_words
362+
.iter()
363+
.any(|word| word.as_ref().contains(&b'\0'))
364+
{
365+
return Err(GrammarError::TriggerWordNullBytes);
366+
}
367+
Ok(trigger_words
368+
.into_iter()
369+
.map(|word| CString::new(word.as_ref()).unwrap())
370+
.collect())
371+
}
372+
345373
/// DRY sampler, designed by p-e-w, as described in:
346374
/// <https://github.com/oobabooga/text-generation-webui/pull/5677>, porting Koboldcpp
347375
/// implementation authored by pi6am: <https://github.com/LostRuins/koboldcpp/pull/982>

0 commit comments

Comments
 (0)