From 255c3a5547a1836d9600126ccffb3c3a9526161b Mon Sep 17 00:00:00 2001 From: xrwang8 Date: Fri, 11 Jul 2025 11:29:44 +0800 Subject: [PATCH] fix: prevent integer overflow in candle backend sequence length calculation This commit fixes a critical integer overflow bug in the Candle backend that causes CUDA driver crashes and massive memory allocation requests. The issue occurred when calculating sequence lengths using unsigned integer subtraction without overflow protection: (batch.cumulative_seq_lengths[i + 1] - batch.cumulative_seq_lengths[i]) as usize When cumulative_seq_lengths[i] > cumulative_seq_lengths[i + 1], the subtraction underflows, producing a very large u32 value that gets cast to usize, resulting in memory allocation requests of ~18.4 EB. Signed-off-by: xrwang8 --- backends/candle/src/lib.rs | 47 +++++++++++++++++++++++++++++++++++++- 1 file changed, 46 insertions(+), 1 deletion(-) diff --git a/backends/candle/src/lib.rs b/backends/candle/src/lib.rs index 882cdb8a..ca3bff5f 100644 --- a/backends/candle/src/lib.rs +++ b/backends/candle/src/lib.rs @@ -500,7 +500,10 @@ impl Backend for CandleBackend { // Used for indexing in the raw_embeddings tensor let input_lengths: Vec = (0..batch.len()) .map(|i| { - (batch.cumulative_seq_lengths[i + 1] - batch.cumulative_seq_lengths[i]) as usize + batch.cumulative_seq_lengths[i + 1] + .checked_sub(batch.cumulative_seq_lengths[i]) + .expect("Invalid cumulative sequence lengths: sequence lengths must be non-decreasing") + as usize }) .collect(); @@ -565,3 +568,45 @@ impl WrapErr for Result { self.map_err(|e| BackendError::Inference(e.to_string())) } } + +#[cfg(test)] +mod tests { + use super::*; + use text_embeddings_backend_core::Batch; + + #[test] + #[should_panic(expected = "Invalid cumulative sequence lengths")] + fn test_invalid_cumulative_seq_lengths() { + // Create a mock backend for testing + let device = Device::Cpu; + let model = Box::new(MockModel); + let backend = CandleBackend { device, model }; + + // Create a batch with invalid cumulative sequence lengths (decreasing) + let batch = Batch { + input_ids: vec![1, 2, 3, 4], + token_type_ids: vec![0, 0, 0, 0], + position_ids: vec![0, 1, 2, 3], + cumulative_seq_lengths: vec![0, 3, 2], // Invalid: 3 > 2 + max_length: 4, + pooled_indices: vec![0], + raw_indices: vec![], + }; + + // This should panic due to invalid cumulative sequence lengths + let _ = backend.embed(batch); + } + + // Mock model for testing + struct MockModel; + + impl crate::models::Model for MockModel { + fn is_padded(&self) -> bool { + false + } + + fn embed(&self, _batch: Batch) -> candle::Result<(Option, Option)> { + Ok((None, None)) + } + } +}