@@ -9,6 +9,7 @@ use crate::model::LlamaModel;
99use crate :: token:: data_array:: LlamaTokenDataArray ;
1010use crate :: token:: logit_bias:: LlamaLogitBias ;
1111use crate :: token:: LlamaToken ;
12+ use crate :: GrammarError ;
1213
1314/// A safe wrapper around `llama_sampler`.
1415pub struct LlamaSampler {
@@ -274,13 +275,14 @@ impl LlamaSampler {
274275 }
275276
276277 /// Grammar sampler
277- ///
278- /// # Panics
279- /// If either of ``grammar_str`` or ``grammar_root`` contain null bytes.
280278 #[ must_use]
281- pub fn grammar ( model : & LlamaModel , grammar_str : & str , grammar_root : & str ) -> Option < Self > {
282- let grammar_str = CString :: new ( grammar_str) . unwrap ( ) ;
283- let grammar_root = CString :: new ( grammar_root) . unwrap ( ) ;
279+ pub fn grammar (
280+ model : & LlamaModel ,
281+ grammar_str : & str ,
282+ grammar_root : & str ,
283+ ) -> Result < Self , GrammarError > {
284+ let ( grammar_str, grammar_root) =
285+ Self :: sanitize_grammar_strings ( grammar_str, grammar_root) ?;
284286
285287 let sampler = unsafe {
286288 llama_cpp_sys_2:: llama_sampler_init_grammar (
@@ -291,37 +293,29 @@ impl LlamaSampler {
291293 } ;
292294
293295 if sampler. is_null ( ) {
294- None
296+ Err ( GrammarError :: NullGrammar )
295297 } else {
296- Some ( Self { sampler } )
298+ Ok ( Self { sampler } )
297299 }
298300 }
299301
300302 /// Lazy grammar sampler, introduced in <https://github.com/ggerganov/llama.cpp/pull/9639>
301303 ///
302304 /// This sampler enforces grammar rules only when specific trigger words or tokens are encountered.
303- ///
304- /// # Panics
305- /// - If `grammar_str` or `grammar_root` contain null bytes
306- /// - If any trigger word contains null bytes
307305 #[ must_use]
308306 pub fn grammar_lazy (
309307 model : & LlamaModel ,
310308 grammar_str : & str ,
311309 grammar_root : & str ,
312310 trigger_words : impl IntoIterator < Item = impl AsRef < [ u8 ] > > ,
313311 trigger_tokens : & [ LlamaToken ] ,
314- ) -> Option < Self > {
315- let grammar_str = CString :: new ( grammar_str) . unwrap ( ) ;
316- let grammar_root = CString :: new ( grammar_root) . unwrap ( ) ;
317-
318- let trigger_word_cstrings: Vec < CString > = trigger_words
319- . into_iter ( )
320- . map ( |word| CString :: new ( word. as_ref ( ) ) . unwrap ( ) )
321- . collect ( ) ;
312+ ) -> Result < Self , GrammarError > {
313+ let ( grammar_str, grammar_root) =
314+ Self :: sanitize_grammar_strings ( grammar_str, grammar_root) ?;
315+ let trigger_words = Self :: sanitize_trigger_words ( trigger_words) ?;
322316
323317 let mut trigger_word_ptrs: Vec < * const c_char > =
324- trigger_word_cstrings . iter ( ) . map ( |cs| cs. as_ptr ( ) ) . collect ( ) ;
318+ trigger_words . iter ( ) . map ( |cs| cs. as_ptr ( ) ) . collect ( ) ;
325319
326320 let sampler = unsafe {
327321 llama_cpp_sys_2:: llama_sampler_init_grammar_lazy (
@@ -336,12 +330,46 @@ impl LlamaSampler {
336330 } ;
337331
338332 if sampler. is_null ( ) {
339- None
333+ Err ( GrammarError :: NullGrammar )
340334 } else {
341- Some ( Self { sampler } )
335+ Ok ( Self { sampler } )
342336 }
343337 }
344338
339+ fn sanitize_grammar_strings (
340+ grammar_str : & str ,
341+ grammar_root : & str ,
342+ ) -> Result < ( CString , CString ) , GrammarError > {
343+ if !grammar_str. contains ( grammar_root) {
344+ return Err ( GrammarError :: RootNotFound ) ;
345+ }
346+
347+ if grammar_str. contains ( '\0' ) || grammar_root. contains ( '\0' ) {
348+ return Err ( GrammarError :: GrammarNullBytes ) ;
349+ }
350+
351+ Ok ( (
352+ CString :: new ( grammar_str) . unwrap ( ) ,
353+ CString :: new ( grammar_root) . unwrap ( ) ,
354+ ) )
355+ }
356+
357+ fn sanitize_trigger_words (
358+ trigger_words : impl IntoIterator < Item = impl AsRef < [ u8 ] > > ,
359+ ) -> Result < Vec < CString > , GrammarError > {
360+ let trigger_words: Vec < _ > = trigger_words. into_iter ( ) . collect ( ) ;
361+ if trigger_words
362+ . iter ( )
363+ . any ( |word| word. as_ref ( ) . contains ( & b'\0' ) )
364+ {
365+ return Err ( GrammarError :: TriggerWordNullBytes ) ;
366+ }
367+ Ok ( trigger_words
368+ . into_iter ( )
369+ . map ( |word| CString :: new ( word. as_ref ( ) ) . unwrap ( ) )
370+ . collect ( ) )
371+ }
372+
345373 /// DRY sampler, designed by p-e-w, as described in:
346374 /// <https://github.com/oobabooga/text-generation-webui/pull/5677>, porting Koboldcpp
347375 /// implementation authored by pi6am: <https://github.com/LostRuins/koboldcpp/pull/982>
0 commit comments