Skip to content

Commit 5ee1d5b

Browse files
authored
Merge pull request #877 from ysimonson/fix-batch-ub
Added lifetime to LlamaBatch
2 parents: 7b6cdf4 + 760ade2 — commit 5ee1d5b

File tree

2 files changed

+11
-7
lines changed

2 files changed

+11
-7
lines changed

examples/mtmd/src/mtmd.rs

Lines changed: 3 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -72,11 +72,11 @@ pub struct MtmdCliParams {
7272

7373
/// State of the MTMD CLI application.
7474
#[allow(missing_debug_implementations)]
75-
pub struct MtmdCliContext {
75+
pub struct MtmdCliContext<'a> {
7676
/// The MTMD context for multimodal processing.
7777
pub mtmd_ctx: MtmdContext,
7878
/// The batch used for processing tokens.
79-
pub batch: LlamaBatch,
79+
pub batch: LlamaBatch<'a>,
8080
/// The list of loaded bitmaps (images/audio).
8181
pub bitmaps: Vec<MtmdBitmap>,
8282
/// The number of past tokens processed.
@@ -87,7 +87,7 @@ pub struct MtmdCliContext {
8787
pub chat: Vec<LlamaChatMessage>,
8888
}
8989

90-
impl MtmdCliContext {
90+
impl<'a> MtmdCliContext<'a> {
9191
/// Creates a new MTMD CLI context
9292
///
9393
/// # Errors

llama-cpp-2/src/llama_batch.rs

Lines changed: 8 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -2,17 +2,19 @@
22
33
use crate::token::LlamaToken;
44
use llama_cpp_sys_2::{llama_batch, llama_batch_free, llama_batch_init, llama_pos, llama_seq_id};
5+
use std::marker::PhantomData;
56

67
/// A safe wrapper around `llama_batch`.
78
#[derive(Debug)]
8-
pub struct LlamaBatch {
9+
pub struct LlamaBatch<'a> {
910
/// The number of tokens the batch was allocated with. they are safe to write to - but not necessarily read from as they are not necessarily initialized
1011
allocated: usize,
1112
/// The logits that are initialized. Used by [`LlamaContext`] to ensure that only initialized logits are accessed.
1213
pub(crate) initialized_logits: Vec<i32>,
1314
#[allow(clippy::doc_markdown)]
1415
/// The llama_cpp batch. always initialize by `llama_cpp_sys_2::llama_batch_init(allocated, <unknown>, <unknown>)`
1516
pub(crate) llama_batch: llama_batch,
17+
phantom: PhantomData<&'a [LlamaToken]>,
1618
}
1719

1820
/// Errors that can occur when adding a token to a batch.
@@ -26,7 +28,7 @@ pub enum BatchAddError {
2628
EmptyBuffer,
2729
}
2830

29-
impl LlamaBatch {
31+
impl<'a> LlamaBatch<'a> {
3032
/// Clear the batch. This does not free the memory associated with the batch, but it does reset
3133
/// the number of tokens to 0.
3234
pub fn clear(&mut self) {
@@ -150,6 +152,7 @@ impl LlamaBatch {
150152
allocated: n_tokens,
151153
initialized_logits: vec![],
152154
llama_batch: batch,
155+
phantom: PhantomData,
153156
}
154157
}
155158

@@ -163,7 +166,7 @@ impl LlamaBatch {
163166
///
164167
/// # Panics
165168
/// If the number of tokens in ``tokens`` exceeds [`i32::MAX`].
166-
pub fn get_one(tokens: &[LlamaToken]) -> Result<Self, BatchAddError> {
169+
pub fn get_one(tokens: &'a [LlamaToken]) -> Result<Self, BatchAddError> {
167170
if tokens.is_empty() {
168171
return Err(BatchAddError::EmptyBuffer);
169172
}
@@ -183,6 +186,7 @@ impl LlamaBatch {
183186
.try_into()
184187
.expect("number of tokens exceeds i32::MAX + 1")],
185188
llama_batch: batch,
189+
phantom: PhantomData,
186190
};
187191
Ok(batch)
188192
}
@@ -194,7 +198,7 @@ impl LlamaBatch {
194198
}
195199
}
196200

197-
impl Drop for LlamaBatch {
201+
impl<'a> Drop for LlamaBatch<'a> {
198202
/// Drops the `LlamaBatch`.
199203
///
200204
/// ```

0 commit comments

Comments
 (0)