Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion candle-examples/examples/quantized-gemma/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ extern crate intel_mkl_src;
#[cfg(feature = "accelerate")]
extern crate accelerate_src;

use candle_nn::kv_cache::DefaultKvCache;
use clap::{Parser, ValueEnum};
use std::io::Write;
use tokenizers::Tokenizer;
Expand Down Expand Up @@ -175,7 +176,7 @@ fn main() -> anyhow::Result<()> {
let start = std::time::Instant::now();
let device = candle_examples::device(args.cpu)?;

let mut model = {
let mut model: ModelWeights<DefaultKvCache> = {
let model = gguf_file::Content::read(&mut file).map_err(|e| e.with_path(&model_path))?;
let mut total_size_in_bytes = 0;
for (_, tensor) in model.tensor_infos.iter() {
Expand Down
3 changes: 2 additions & 1 deletion candle-examples/examples/quantized-qwen3/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ extern crate intel_mkl_src;
#[cfg(feature = "accelerate")]
extern crate accelerate_src;

use candle_nn::kv_cache::DefaultKvCache;
use clap::{Parser, ValueEnum};
use std::io::Write;
use tokenizers::Tokenizer;
Expand Down Expand Up @@ -189,7 +190,7 @@ fn main() -> anyhow::Result<()> {
let start = std::time::Instant::now();
let device = candle_examples::device(args.cpu)?;

let mut model = {
let mut model: Qwen3<DefaultKvCache> = {
let model = gguf_file::Content::read(&mut file).map_err(|e| e.with_path(model_path))?;
let mut total_size_in_bytes = 0;
for (_, tensor) in model.tensor_infos.iter() {
Expand Down
223 changes: 170 additions & 53 deletions candle-nn/src/kv_cache.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,26 +2,43 @@
//!
use candle::{DType, Device, Result, Tensor};

pub type DefaultKvCache = ConcatKvCache;

pub trait KvCache {
type Mask;
fn new(dim: usize, max_seq_len: usize) -> Self;
fn append(&mut self, k: &Tensor, v: &Tensor) -> Result<(Tensor, Tensor)>;
fn append_with_mask(
&mut self,
k: &Tensor,
v: &Tensor,
_mask: Option<&Self::Mask>,
) -> Result<(Tensor, Tensor)> {
self.append(k, v)
}
fn reset(&mut self);
}

#[derive(Debug, Clone)]
pub struct Cache {
pub struct InnerCache {
// all_data is an option on a Tensor, this makes it possible to only create the actual tensor
// on the first call where the batch size is easily known.
// Also this makes it safe to clone a KvCache that has been reset (as in it will not share
// its internal state with the cloned instance).
all_data: Option<Tensor>,
dim: usize,
current_seq_len: usize,
grow_by: usize,
increment: usize,
max_seq_len: usize,
}

impl Cache {
impl InnerCache {
pub fn new(dim: usize, max_seq_len: usize) -> Self {
Self {
all_data: None,
dim,
current_seq_len: 0,
grow_by: max_seq_len,
increment: max_seq_len,
max_seq_len,
}
}
Expand Down Expand Up @@ -68,10 +85,10 @@ impl Cache {
let ad = self.all_data.as_mut().unwrap();
while self.current_seq_len + seq_len > self.max_seq_len {
let mut shape = src.dims().to_vec();
shape[self.dim] = self.grow_by;
shape[self.dim] = self.increment;
let next_ad = Tensor::zeros(shape, src.dtype(), src.device())?;
*ad = Tensor::cat(&[&*ad, &next_ad], self.dim)?;
self.max_seq_len += self.grow_by;
self.max_seq_len += self.increment;
}
ad.slice_set(src, self.dim, self.current_seq_len)?;
self.current_seq_len += seq_len;
Expand All @@ -80,31 +97,46 @@ impl Cache {
}

#[derive(Debug, Clone)]
pub struct KvCache {
k: Cache,
v: Cache,
pub struct IncrementalKvCache {
k: InnerCache,
v: InnerCache,
}

impl KvCache {
impl KvCache for IncrementalKvCache {
type Mask = ();
fn new(dim: usize, max_seq_len: usize) -> Self {
IncrementalKvCache::new(dim, max_seq_len)
}

fn append(&mut self, k: &Tensor, v: &Tensor) -> Result<(Tensor, Tensor)> {
self.append(k, v)
}

fn reset(&mut self) {
self.reset()
}
}

impl IncrementalKvCache {
pub fn new(dim: usize, max_seq_len: usize) -> Self {
let k = Cache::new(dim, max_seq_len);
let v = Cache::new(dim, max_seq_len);
let k = InnerCache::new(dim, max_seq_len);
let v = InnerCache::new(dim, max_seq_len);
Self { k, v }
}

pub fn k_cache(&self) -> &Cache {
pub fn k_cache(&self) -> &InnerCache {
&self.k
}

pub fn v_cache(&self) -> &Cache {
pub fn v_cache(&self) -> &InnerCache {
&self.v
}

pub fn k_cache_mut(&mut self) -> &mut Cache {
pub fn k_cache_mut(&mut self) -> &mut InnerCache {
&mut self.k
}

pub fn v_cache_mut(&mut self) -> &mut Cache {
pub fn v_cache_mut(&mut self) -> &mut InnerCache {
&mut self.v
}

Expand All @@ -117,8 +149,8 @@ impl KvCache {
}

pub fn append(&mut self, k: &Tensor, v: &Tensor) -> Result<(Tensor, Tensor)> {
self.k.append(k)?;
self.v.append(v)?;
self.k.append(&k.contiguous()?)?;
self.v.append(&v.contiguous()?)?;
let out_k = self.k.current_data()?;
let out_v = self.v.current_data()?;
let k = match out_k {
Expand Down Expand Up @@ -338,6 +370,21 @@ pub struct RotatingKvCache {
v: RotatingCache,
}

impl KvCache for RotatingKvCache {
type Mask = ();
fn new(dim: usize, max_seq_len: usize) -> Self {
RotatingKvCache::new(dim, max_seq_len)
}

fn append(&mut self, k: &Tensor, v: &Tensor) -> Result<(Tensor, Tensor)> {
self.append(k, v)
}

fn reset(&mut self) {
self.reset()
}
}

impl RotatingKvCache {
pub fn new(dim: usize, max_seq_len: usize) -> Self {
let k = RotatingCache::new(dim, max_seq_len);
Expand Down Expand Up @@ -414,34 +461,90 @@ impl IndicesAndMask {

#[derive(Debug, Clone)]
pub struct ScatteredKvCache {
k: Tensor,
v: Tensor,
k: Option<Tensor>,
v: Option<Tensor>,
dim: usize,
context: usize,
}
impl KvCache for ScatteredKvCache {
type Mask = IndicesAndMask;

fn new(dim: usize, max_seq_len: usize) -> Self {
ScatteredKvCache::new(dim, max_seq_len)
}

fn append(&mut self, k: &Tensor, v: &Tensor) -> Result<(Tensor, Tensor)> {
self.append_with_mask(k, v, None)
}

fn append_with_mask(
&mut self,
k: &Tensor,
v: &Tensor,
mask: Option<&Self::Mask>,
) -> Result<(Tensor, Tensor)> {
if let Some(mask) = mask {
self.scattered_append(k, v, mask)
} else {
candle::bail!("ScatteredKvCache requires InidicesAndMask")
}
}

fn reset(&mut self) {
self.reset()
}
}

impl ScatteredKvCache {
pub fn append(
pub fn new(dim: usize, context: usize) -> Self {
Self {
k: None,
v: None,
dim,
context,
}
}

pub fn scattered_append(
&mut self,
k: &Tensor,
v: &Tensor,
iam: &IndicesAndMask,
) -> Result<(Tensor, Tensor)> {
if self.context <= k.dim(2)? {
if self.context <= k.dim(self.dim)? {
return Ok((k.clone(), v.clone()));
}
let indices = iam.indices.unsqueeze(2)?.unsqueeze(1)?;
if self.k.is_none() {
let mut k_shape = k.dims().to_vec();
k_shape[self.dim] = self.context;
self.k = Some(Tensor::zeros(k_shape.clone(), k.dtype(), k.device())?);
}
if self.v.is_none() {
let mut v_shape = v.dims().to_vec();
v_shape[self.dim] = self.context;
self.v = Some(Tensor::zeros(v_shape.clone(), v.dtype(), v.device())?);
}

let indices = iam.indices.unsqueeze(self.dim)?.unsqueeze(1)?;
let indices = indices.broadcast_as(k.shape())?.contiguous()?;
self.k.scatter_set(&indices, k, 2)?;
self.v.scatter_set(&indices, v, 2)?;
Ok((self.k.clone(), self.v.clone()))
let new_k = self.k.as_mut().unwrap();
let new_v = self.v.as_mut().unwrap();
new_k.scatter_set(&indices, k, self.dim)?;
new_v.scatter_set(&indices, v, self.dim)?;
Ok((new_k.clone(), new_v.clone()))
}

pub fn k(&self) -> &Tensor {
&self.k
pub fn k(&self) -> Option<&Tensor> {
self.k.as_ref()
}

pub fn v(&self) -> &Tensor {
&self.v
pub fn v(&self) -> Option<&Tensor> {
self.v.as_ref()
}

pub fn reset(&mut self) {
self.k = None;
self.v = None;
}
}

Expand Down Expand Up @@ -469,16 +572,8 @@ impl ScatteredCacheBuilder {
})
}

pub fn make_cache(&self, num_heads: usize, head_dim: usize) -> Result<ScatteredKvCache> {
let batch_size = self.batch_size();
let shape = (batch_size, num_heads, self.context, head_dim);
let k = Tensor::zeros(shape, self.dtype, self.device())?;
let v = Tensor::zeros(shape, self.dtype, self.device())?;
Ok(ScatteredKvCache {
k,
v,
context: self.context,
})
pub fn make_cache(&self, head_dim: usize) -> ScatteredKvCache {
ScatteredKvCache::new(head_dim, self.context)
}

pub fn positions(&self) -> &[usize] {
Expand All @@ -499,7 +594,6 @@ impl ScatteredCacheBuilder {
self.indices[batch_index] = 0;
}

#[allow(clippy::needless_range_loop)]
pub fn indices_and_mask(
&mut self,
seq_len: usize,
Expand All @@ -525,18 +619,26 @@ impl ScatteredCacheBuilder {
let mut indices = Vec::with_capacity(seq_len);
let mut all_pos = vec![usize::MAX; context];
if start_pos < context {
for i in 0..start_pos {
all_pos[i] = i;
}
all_pos
.iter_mut()
.enumerate()
.take(start_pos)
.for_each(|(i, p)| {
*p = i;
});
} else {
let offset = start_pos - start_index;
for i in 0..context {
all_pos[i] = if i < start_index {
i + offset
} else {
i + offset - context
};
}
all_pos
.iter_mut()
.enumerate()
.take(context)
.for_each(|(i, p)| {
*p = if i < start_index {
i + offset
} else {
i + offset - context
};
});
}
for seq_i in 0..seq_len {
let index = self.indices[batch_i];
Expand Down Expand Up @@ -584,7 +686,6 @@ impl ScatteredCacheBuilder {
&self.device
}

#[allow(clippy::needless_range_loop)]
fn indices_and_mask_abs(
&mut self,
seq_len: usize,
Expand Down Expand Up @@ -642,7 +743,7 @@ impl ScatteredCacheBuilder {
/// - GPU inference (CUDA, Metal)
/// - Autoregressive generation (token-by-token decoding)
///
/// **Use `KvCache` instead for:**
/// **Use `IncrementalKvCache` instead for:**
/// - CPU-only inference
/// - When you need fixed memory allocation upfront
///
Expand Down Expand Up @@ -670,6 +771,22 @@ pub struct ConcatKvCache {
dim: usize,
}

impl KvCache for ConcatKvCache {
type Mask = ();

fn new(dim: usize, _: usize) -> Self {
ConcatKvCache::new(dim)
}

fn append(&mut self, k: &Tensor, v: &Tensor) -> Result<(Tensor, Tensor)> {
self.append(k, v)
}

fn reset(&mut self) {
self.reset()
}
}

impl ConcatKvCache {
/// Create a new empty concatenation-based KV-cache
///
Expand Down
Loading