diff --git a/candle-examples/examples/quantized-gemma/main.rs b/candle-examples/examples/quantized-gemma/main.rs index 98ce7bd41e..4181f54f2b 100644 --- a/candle-examples/examples/quantized-gemma/main.rs +++ b/candle-examples/examples/quantized-gemma/main.rs @@ -4,6 +4,7 @@ extern crate intel_mkl_src; #[cfg(feature = "accelerate")] extern crate accelerate_src; +use candle_nn::kv_cache::DefaultKvCache; use clap::{Parser, ValueEnum}; use std::io::Write; use tokenizers::Tokenizer; @@ -175,7 +176,7 @@ fn main() -> anyhow::Result<()> { let start = std::time::Instant::now(); let device = candle_examples::device(args.cpu)?; - let mut model = { + let mut model: ModelWeights = { let model = gguf_file::Content::read(&mut file).map_err(|e| e.with_path(&model_path))?; let mut total_size_in_bytes = 0; for (_, tensor) in model.tensor_infos.iter() { diff --git a/candle-examples/examples/quantized-qwen3/main.rs b/candle-examples/examples/quantized-qwen3/main.rs index 21c79d528b..02b6830a51 100644 --- a/candle-examples/examples/quantized-qwen3/main.rs +++ b/candle-examples/examples/quantized-qwen3/main.rs @@ -4,6 +4,7 @@ extern crate intel_mkl_src; #[cfg(feature = "accelerate")] extern crate accelerate_src; +use candle_nn::kv_cache::DefaultKvCache; use clap::{Parser, ValueEnum}; use std::io::Write; use tokenizers::Tokenizer; @@ -189,7 +190,7 @@ fn main() -> anyhow::Result<()> { let start = std::time::Instant::now(); let device = candle_examples::device(args.cpu)?; - let mut model = { + let mut model: Qwen3 = { let model = gguf_file::Content::read(&mut file).map_err(|e| e.with_path(model_path))?; let mut total_size_in_bytes = 0; for (_, tensor) in model.tensor_infos.iter() { diff --git a/candle-nn/src/kv_cache.rs b/candle-nn/src/kv_cache.rs index cc445e9817..13426e36cb 100644 --- a/candle-nn/src/kv_cache.rs +++ b/candle-nn/src/kv_cache.rs @@ -2,8 +2,25 @@ //! 
use candle::{DType, Device, Result, Tensor}; +pub type DefaultKvCache = ConcatKvCache; + +pub trait KvCache { + type Mask; + fn new(dim: usize, max_seq_len: usize) -> Self; + fn append(&mut self, k: &Tensor, v: &Tensor) -> Result<(Tensor, Tensor)>; + fn append_with_mask( + &mut self, + k: &Tensor, + v: &Tensor, + _mask: Option<&Self::Mask>, + ) -> Result<(Tensor, Tensor)> { + self.append(k, v) + } + fn reset(&mut self); +} + #[derive(Debug, Clone)] -pub struct Cache { +pub struct InnerCache { // all_data is an option on a Tensor, this makes it possible to only create the actual tensor // on the first call where the batch size is easily known. // Also this makes it safe to clone a KvCache that has been reset (as in it will not share @@ -11,17 +28,17 @@ pub struct Cache { all_data: Option, dim: usize, current_seq_len: usize, - grow_by: usize, + increment: usize, max_seq_len: usize, } -impl Cache { +impl InnerCache { pub fn new(dim: usize, max_seq_len: usize) -> Self { Self { all_data: None, dim, current_seq_len: 0, - grow_by: max_seq_len, + increment: max_seq_len, max_seq_len, } } @@ -68,10 +85,10 @@ impl Cache { let ad = self.all_data.as_mut().unwrap(); while self.current_seq_len + seq_len > self.max_seq_len { let mut shape = src.dims().to_vec(); - shape[self.dim] = self.grow_by; + shape[self.dim] = self.increment; let next_ad = Tensor::zeros(shape, src.dtype(), src.device())?; *ad = Tensor::cat(&[&*ad, &next_ad], self.dim)?; - self.max_seq_len += self.grow_by; + self.max_seq_len += self.increment; } ad.slice_set(src, self.dim, self.current_seq_len)?; self.current_seq_len += seq_len; @@ -80,31 +97,46 @@ impl Cache { } #[derive(Debug, Clone)] -pub struct KvCache { - k: Cache, - v: Cache, +pub struct IncrementalKvCache { + k: InnerCache, + v: InnerCache, } -impl KvCache { +impl KvCache for IncrementalKvCache { + type Mask = (); + fn new(dim: usize, max_seq_len: usize) -> Self { + IncrementalKvCache::new(dim, max_seq_len) + } + + fn append(&mut self, k: &Tensor, v: 
&Tensor) -> Result<(Tensor, Tensor)> { + self.append(k, v) + } + + fn reset(&mut self) { + self.reset() + } +} + +impl IncrementalKvCache { pub fn new(dim: usize, max_seq_len: usize) -> Self { - let k = Cache::new(dim, max_seq_len); - let v = Cache::new(dim, max_seq_len); + let k = InnerCache::new(dim, max_seq_len); + let v = InnerCache::new(dim, max_seq_len); Self { k, v } } - pub fn k_cache(&self) -> &Cache { + pub fn k_cache(&self) -> &InnerCache { &self.k } - pub fn v_cache(&self) -> &Cache { + pub fn v_cache(&self) -> &InnerCache { &self.v } - pub fn k_cache_mut(&mut self) -> &mut Cache { + pub fn k_cache_mut(&mut self) -> &mut InnerCache { &mut self.k } - pub fn v_cache_mut(&mut self) -> &mut Cache { + pub fn v_cache_mut(&mut self) -> &mut InnerCache { &mut self.v } @@ -117,8 +149,8 @@ impl KvCache { } pub fn append(&mut self, k: &Tensor, v: &Tensor) -> Result<(Tensor, Tensor)> { - self.k.append(k)?; - self.v.append(v)?; + self.k.append(&k.contiguous()?)?; + self.v.append(&v.contiguous()?)?; let out_k = self.k.current_data()?; let out_v = self.v.current_data()?; let k = match out_k { @@ -338,6 +370,21 @@ pub struct RotatingKvCache { v: RotatingCache, } +impl KvCache for RotatingKvCache { + type Mask = (); + fn new(dim: usize, max_seq_len: usize) -> Self { + RotatingKvCache::new(dim, max_seq_len) + } + + fn append(&mut self, k: &Tensor, v: &Tensor) -> Result<(Tensor, Tensor)> { + self.append(k, v) + } + + fn reset(&mut self) { + self.reset() + } +} + impl RotatingKvCache { pub fn new(dim: usize, max_seq_len: usize) -> Self { let k = RotatingCache::new(dim, max_seq_len); @@ -414,34 +461,90 @@ impl IndicesAndMask { #[derive(Debug, Clone)] pub struct ScatteredKvCache { - k: Tensor, - v: Tensor, + k: Option, + v: Option, + dim: usize, context: usize, } +impl KvCache for ScatteredKvCache { + type Mask = IndicesAndMask; + + fn new(dim: usize, max_seq_len: usize) -> Self { + ScatteredKvCache::new(dim, max_seq_len) + } + + fn append(&mut self, k: &Tensor, v: &Tensor) 
-> Result<(Tensor, Tensor)> {
+        self.append_with_mask(k, v, None)
+    }
+
+    fn append_with_mask(
+        &mut self,
+        k: &Tensor,
+        v: &Tensor,
+        mask: Option<&Self::Mask>,
+    ) -> Result<(Tensor, Tensor)> {
+        if let Some(mask) = mask {
+            self.scattered_append(k, v, mask)
+        } else {
+            candle::bail!("ScatteredKvCache requires IndicesAndMask")
+        }
+    }
+
+    fn reset(&mut self) {
+        self.reset()
+    }
+}
 impl ScatteredKvCache {
-    pub fn append(
+    pub fn new(dim: usize, context: usize) -> Self {
+        Self {
+            k: None,
+            v: None,
+            dim,
+            context,
+        }
+    }
+
+    pub fn scattered_append(
         &mut self,
         k: &Tensor,
         v: &Tensor,
         iam: &IndicesAndMask,
     ) -> Result<(Tensor, Tensor)> {
-        if self.context <= k.dim(2)? {
+        if self.context <= k.dim(self.dim)? {
             return Ok((k.clone(), v.clone()));
         }
-        let indices = iam.indices.unsqueeze(2)?.unsqueeze(1)?;
+        if self.k.is_none() {
+            let mut k_shape = k.dims().to_vec();
+            k_shape[self.dim] = self.context;
+            self.k = Some(Tensor::zeros(k_shape.clone(), k.dtype(), k.device())?);
+        }
+        if self.v.is_none() {
+            let mut v_shape = v.dims().to_vec();
+            v_shape[self.dim] = self.context;
+            self.v = Some(Tensor::zeros(v_shape.clone(), v.dtype(), v.device())?);
+        }
+
+        let indices = iam.indices.unsqueeze(self.dim)?.unsqueeze(1)?;
         let indices = indices.broadcast_as(k.shape())?.contiguous()?;
-        self.k.scatter_set(&indices, k, 2)?;
-        self.v.scatter_set(&indices, v, 2)?;
-        Ok((self.k.clone(), self.v.clone()))
+        let new_k = self.k.as_mut().unwrap();
+        let new_v = self.v.as_mut().unwrap();
+        new_k.scatter_set(&indices, k, self.dim)?;
+        new_v.scatter_set(&indices, v, self.dim)?;
+        Ok((new_k.clone(), new_v.clone()))
     }
 
-    pub fn k(&self) -> &Tensor {
-        &self.k
+    pub fn k(&self) -> Option<&Tensor> {
+        self.k.as_ref()
     }
 
-    pub fn v(&self) -> &Tensor {
-        &self.v
+    pub fn v(&self) -> Option<&Tensor> {
+        self.v.as_ref()
+    }
+
+    pub fn reset(&mut self) {
+        self.k = None;
+        self.v = None;
     }
 }
 
@@ -469,16 +572,8 @@ impl ScatteredCacheBuilder {
         })
     }
 
-    pub fn make_cache(&self, num_heads: 
usize, head_dim: usize) -> Result { - let batch_size = self.batch_size(); - let shape = (batch_size, num_heads, self.context, head_dim); - let k = Tensor::zeros(shape, self.dtype, self.device())?; - let v = Tensor::zeros(shape, self.dtype, self.device())?; - Ok(ScatteredKvCache { - k, - v, - context: self.context, - }) + pub fn make_cache(&self, head_dim: usize) -> ScatteredKvCache { + ScatteredKvCache::new(head_dim, self.context) } pub fn positions(&self) -> &[usize] { @@ -499,7 +594,6 @@ impl ScatteredCacheBuilder { self.indices[batch_index] = 0; } - #[allow(clippy::needless_range_loop)] pub fn indices_and_mask( &mut self, seq_len: usize, @@ -525,18 +619,26 @@ impl ScatteredCacheBuilder { let mut indices = Vec::with_capacity(seq_len); let mut all_pos = vec![usize::MAX; context]; if start_pos < context { - for i in 0..start_pos { - all_pos[i] = i; - } + all_pos + .iter_mut() + .enumerate() + .take(start_pos) + .for_each(|(i, p)| { + *p = i; + }); } else { let offset = start_pos - start_index; - for i in 0..context { - all_pos[i] = if i < start_index { - i + offset - } else { - i + offset - context - }; - } + all_pos + .iter_mut() + .enumerate() + .take(context) + .for_each(|(i, p)| { + *p = if i < start_index { + i + offset + } else { + i + offset - context + }; + }); } for seq_i in 0..seq_len { let index = self.indices[batch_i]; @@ -584,7 +686,6 @@ impl ScatteredCacheBuilder { &self.device } - #[allow(clippy::needless_range_loop)] fn indices_and_mask_abs( &mut self, seq_len: usize, @@ -642,7 +743,7 @@ impl ScatteredCacheBuilder { /// - GPU inference (CUDA, Metal) /// - Autoregressive generation (token-by-token decoding) /// -/// **Use `KvCache` instead for:** +/// **Use `IncrementalKvCache` instead for:** /// - CPU-only inference /// - When you need fixed memory allocation upfront /// @@ -670,6 +771,22 @@ pub struct ConcatKvCache { dim: usize, } +impl KvCache for ConcatKvCache { + type Mask = (); + + fn new(dim: usize, _: usize) -> Self { + 
ConcatKvCache::new(dim) + } + + fn append(&mut self, k: &Tensor, v: &Tensor) -> Result<(Tensor, Tensor)> { + self.append(k, v) + } + + fn reset(&mut self) { + self.reset() + } +} + impl ConcatKvCache { /// Create a new empty concatenation-based KV-cache /// diff --git a/candle-nn/tests/kv_cache.rs b/candle-nn/tests/kv_cache.rs index c8a193a84d..086888953e 100644 --- a/candle-nn/tests/kv_cache.rs +++ b/candle-nn/tests/kv_cache.rs @@ -8,7 +8,7 @@ use candle::{Device, Result, Tensor}; #[test] fn kv_cache() -> Result<()> { - let mut cache = candle_nn::kv_cache::Cache::new(0, 16); + let mut cache = candle_nn::kv_cache::InnerCache::new(0, 16); for _ in [0, 1] { assert_eq!(cache.current_seq_len(), 0); let data = cache.current_data()?; diff --git a/candle-transformers/src/models/gemma3.rs b/candle-transformers/src/models/gemma3.rs index 08b4e5ad6e..b618515b3f 100644 --- a/candle-transformers/src/models/gemma3.rs +++ b/candle-transformers/src/models/gemma3.rs @@ -147,7 +147,7 @@ impl Module for MLP { #[derive(Debug, Clone)] enum KvCache { - Normal(candle_nn::kv_cache::KvCache), + Normal(candle_nn::kv_cache::IncrementalKvCache), Rotating(candle_nn::kv_cache::RotatingKvCache), } @@ -192,7 +192,7 @@ impl Attention { let kv_cache = if let Some(sliding_window) = sliding_window { KvCache::Rotating(candle_nn::kv_cache::RotatingKvCache::new(2, sliding_window)) } else { - KvCache::Normal(candle_nn::kv_cache::KvCache::new( + KvCache::Normal(candle_nn::kv_cache::IncrementalKvCache::new( 2, cfg.max_position_embeddings, )) diff --git a/candle-transformers/src/models/glm4_new.rs b/candle-transformers/src/models/glm4_new.rs index cb7294a43c..b5cf91327c 100644 --- a/candle-transformers/src/models/glm4_new.rs +++ b/candle-transformers/src/models/glm4_new.rs @@ -4,7 +4,7 @@ use crate::{ utils::repeat_kv, }; use candle::{DType, Device, IndexOp, Module, Result, Tensor, D}; -use candle_nn::{kv_cache::KvCache, Activation, VarBuilder}; +use candle_nn::{kv_cache::IncrementalKvCache, Activation, 
VarBuilder}; use std::sync::Arc; #[derive(Debug, Clone, serde::Deserialize)] @@ -123,7 +123,7 @@ pub(crate) struct Attention { head_dim: usize, hidden_size: usize, rotary_emb: Arc, - kv_cache: KvCache, + kv_cache: IncrementalKvCache, } impl Attention { @@ -169,7 +169,7 @@ impl Attention { // Initialize KV cache with 512 tokens capacity to reduce initial memory allocation. // The cache will grow in chunks of 512 tokens when needed. - let kv_cache = KvCache::new(2, 512); + let kv_cache = IncrementalKvCache::new(2, 512); Ok(Self { q_proj, diff --git a/candle-transformers/src/models/quantized_gemma3.rs b/candle-transformers/src/models/quantized_gemma3.rs index bc5b9e7ff0..55b40f62aa 100644 --- a/candle-transformers/src/models/quantized_gemma3.rs +++ b/candle-transformers/src/models/quantized_gemma3.rs @@ -19,7 +19,10 @@ use candle::quantized::gguf_file; use candle::quantized::QTensor; use candle::D; use candle::{DType, Device, IndexOp, Result, Tensor}; -use candle_nn::{Embedding, Module}; +use candle_nn::{ + kv_cache::{DefaultKvCache, KvCache}, + Embedding, Module, +}; pub const MAX_SEQ_LEN: usize = 131072; // Gemma 3 supports 128K context window pub const DEFAULT_SLIDING_WINDOW_TYPE: usize = 6; @@ -101,7 +104,7 @@ impl RotaryEmbedding { } #[derive(Debug, Clone)] -struct LayerWeights { +struct LayerWeights { // Attention components attention_wq: QMatMul, attention_wk: QMatMul, @@ -133,14 +136,14 @@ struct LayerWeights { neg_inf: Tensor, // Cache - kv_cache: Option<(Tensor, Tensor)>, + kv_cache: C, // Tracing span_attn: tracing::Span, span_mlp: tracing::Span, } -impl LayerWeights { +impl LayerWeights { fn mask( &self, b_sz: usize, @@ -207,19 +210,7 @@ impl LayerWeights { .rotary_embedding .apply_rotary_emb_qkv(&q, &k, index_pos)?; - let (k, v) = match &self.kv_cache { - None => (k, v), - Some((k_cache, v_cache)) => { - if index_pos == 0 { - (k, v) - } else { - let k = Tensor::cat(&[k_cache, &k], 2)?; // concat on seq dim - let v = Tensor::cat(&[v_cache, &v], 2)?; - (k, 
v) - } - } - }; - self.kv_cache = Some((k.clone(), v.clone())); // update cache + let (k, v) = self.kv_cache.append(&k, &v)?; // Repeat KV for GQA let k = crate::utils::repeat_kv(k, self.n_head / self.n_kv_head)?; @@ -247,17 +238,17 @@ impl LayerWeights { } #[derive(Debug, Clone)] -pub struct ModelWeights { +pub struct ModelWeights { tok_embeddings: Embedding, embedding_length: usize, - layers: Vec, + layers: Vec>, norm: RmsNorm, output: QMatMul, span: tracing::Span, span_output: tracing::Span, } -impl ModelWeights { +impl ModelWeights { pub fn from_gguf( ct: gguf_file::Content, reader: &mut R, @@ -402,7 +393,7 @@ impl ModelWeights { sliding_window_size, rotary_embedding, neg_inf: neg_inf.clone(), - kv_cache: None, + kv_cache: C::new(2, 512), span_attn, span_mlp, }) diff --git a/candle-transformers/src/models/quantized_phi3.rs b/candle-transformers/src/models/quantized_phi3.rs index 4a04e43418..5b5a9d59ce 100644 --- a/candle-transformers/src/models/quantized_phi3.rs +++ b/candle-transformers/src/models/quantized_phi3.rs @@ -18,7 +18,7 @@ use std::collections::HashMap; use candle::quantized::gguf_file; use candle::quantized::QTensor; use candle::{DType, Device, IndexOp, Module, Result, Tensor, D}; -use candle_nn::{kv_cache::KvCache, Embedding, RmsNorm}; +use candle_nn::{kv_cache::IncrementalKvCache, Embedding, RmsNorm}; #[derive(Debug, Clone)] struct QLinear { @@ -83,7 +83,7 @@ struct LayerWeights { cos: Tensor, sin: Tensor, neg_inf: Tensor, - kv_cache: KvCache, + kv_cache: IncrementalKvCache, use_flash_attn: bool, span_attn: tracing::Span, span_rot: tracing::Span, @@ -269,7 +269,7 @@ impl ModelWeights { )?; let span_attn = tracing::span!(tracing::Level::TRACE, "attn"); let span_rot = tracing::span!(tracing::Level::TRACE, "attn-rot"); - let kv_cache = KvCache::new(2, max_seq_len); + let kv_cache = IncrementalKvCache::new(2, max_seq_len); layers.push(LayerWeights { attn_qkv: QLinear::new(&ct, reader, &format!("{prefix}.attn_qkv"), device)?, attn_output: 
QLinear::new(&ct, reader, &format!("{prefix}.attn_output"), device)?, diff --git a/candle-transformers/src/models/quantized_qwen3.rs b/candle-transformers/src/models/quantized_qwen3.rs index 5d9f414658..f9469cd547 100644 --- a/candle-transformers/src/models/quantized_qwen3.rs +++ b/candle-transformers/src/models/quantized_qwen3.rs @@ -10,7 +10,10 @@ use super::with_tracing::QMatMul; use crate::{quantized_nn::RmsNorm, utils::repeat_kv}; use candle::quantized::{gguf_file, QTensor}; use candle::{DType, Device, Result, Tensor}; -use candle_nn::{kv_cache::ConcatKvCache, Activation, Embedding, Module}; +use candle_nn::{ + kv_cache::{DefaultKvCache, KvCache}, + Activation, Embedding, Module, +}; use std::io::{Read, Seek}; use std::sync::Arc; @@ -124,7 +127,7 @@ impl RotaryEmbedding { } #[derive(Debug, Clone)] -struct AttentionWeights { +struct AttentionWeights { q_proj: QMatMul, k_proj: QMatMul, v_proj: QMatMul, @@ -136,11 +139,11 @@ struct AttentionWeights { num_kv_groups: usize, head_dim: usize, rotary_emb: Arc, - kv_cache: ConcatKvCache, + kv_cache: C, span_attn: tracing::Span, } -impl AttentionWeights { +impl AttentionWeights { fn new( gg: &mut Gguf, num_heads: usize, @@ -160,7 +163,7 @@ impl AttentionWeights { let q_norm = gg.rms_norm(&format!("{prefix}.attn_q_norm.weight"), rms_norm_eps)?; let k_norm = gg.rms_norm(&format!("{prefix}.attn_k_norm.weight"), rms_norm_eps)?; - let kv_cache = ConcatKvCache::new(2); + let kv_cache = C::new(2, 512); let span_attn = tracing::span!(tracing::Level::TRACE, "attn"); @@ -240,14 +243,14 @@ impl AttentionWeights { } #[derive(Debug, Clone)] -struct LayerWeights { - self_attn: AttentionWeights, +struct LayerWeights { + self_attn: AttentionWeights, mlp: MlpWeights, ln1: RmsNorm, ln2: RmsNorm, } -impl LayerWeights { +impl LayerWeights { fn new( gg: &mut Gguf, num_attention_heads: usize, @@ -294,9 +297,9 @@ impl LayerWeights { } #[derive(Debug, Clone)] -pub struct ModelWeights { +pub struct ModelWeights { embed_tokens: Embedding, - 
layers: Vec, + layers: Vec>, norm: RmsNorm, lm_head: QMatMul, device: Device, @@ -305,7 +308,7 @@ pub struct ModelWeights { span_output: tracing::Span, } -impl ModelWeights { +impl ModelWeights { pub fn from_gguf( ct: gguf_file::Content, reader: &mut R, diff --git a/candle-transformers/src/models/qwen3.rs b/candle-transformers/src/models/qwen3.rs index 9f018939ae..27da3711fc 100644 --- a/candle-transformers/src/models/qwen3.rs +++ b/candle-transformers/src/models/qwen3.rs @@ -3,7 +3,10 @@ use crate::{ utils::repeat_kv, }; use candle::{DType, Device, Module, Result, Tensor}; -use candle_nn::{kv_cache::ConcatKvCache, Activation, VarBuilder}; +use candle_nn::{ + kv_cache::{DefaultKvCache, KvCache}, + Activation, VarBuilder, +}; use std::sync::Arc; #[derive(Debug, Clone, PartialEq, serde::Deserialize)] @@ -91,7 +94,7 @@ impl Module for Qwen3MLP { } #[derive(Debug, Clone)] -pub(crate) struct Qwen3Attention { +pub(crate) struct Qwen3Attention { // projections q_proj: Linear, k_proj: Linear, @@ -108,10 +111,10 @@ pub(crate) struct Qwen3Attention { hidden_size: usize, // utils rotary_emb: Arc, - kv_cache: ConcatKvCache, + kv_cache: C, } -impl Qwen3Attention { +impl Qwen3Attention { pub(crate) fn new( cfg: &Config, rotary_emb: Arc, @@ -159,7 +162,9 @@ impl Qwen3Attention { // dim=2 because we concatenate along the sequence dimension // For tensors of shape [batch, heads, seq, head_dim] - let kv_cache = ConcatKvCache::new(2); + // The KV cache is initialized with 512 tokens capacity. If the KvCache implementation uses capacity this + // leads to reduced initial memory allocation, and the cache will grow in chunks of 512 tokens when needed. 
+ let kv_cache = C::new(2, 512); Ok(Self { q_proj, @@ -241,14 +246,14 @@ impl Qwen3Attention { } #[derive(Debug, Clone)] -struct DecoderLayer { - self_attn: Qwen3Attention, +struct DecoderLayer { + self_attn: Qwen3Attention, mlp: Qwen3MLP, ln1: RmsNorm, ln2: RmsNorm, } -impl DecoderLayer { +impl DecoderLayer { fn new(cfg: &Config, rotary: Arc, vb: VarBuilder) -> Result { let self_attn = Qwen3Attention::new(cfg, rotary, vb.pp("self_attn"))?; let mlp = Qwen3MLP::new(cfg, vb.pp("mlp"))?; @@ -281,9 +286,9 @@ impl DecoderLayer { } #[derive(Debug, Clone)] -pub struct Model { +pub struct Model { embed_tokens: candle_nn::Embedding, - layers: Vec, + layers: Vec>, norm: RmsNorm, device: Device, dtype: DType, diff --git a/candle-transformers/src/models/qwen3_moe.rs b/candle-transformers/src/models/qwen3_moe.rs index b76ce92de4..b954705d50 100644 --- a/candle-transformers/src/models/qwen3_moe.rs +++ b/candle-transformers/src/models/qwen3_moe.rs @@ -3,7 +3,10 @@ use crate::models::{ with_tracing::{linear_no_bias, Linear, RmsNorm}, }; use candle::{DType, Device, Module, Result, Tensor, D}; -use candle_nn::{Activation, VarBuilder}; +use candle_nn::{ + kv_cache::{DefaultKvCache, KvCache}, + Activation, VarBuilder, +}; use std::sync::Arc; #[derive(Debug, Clone, PartialEq, serde::Deserialize)] @@ -189,14 +192,14 @@ impl Module for Qwen3FeedForward { } #[derive(Debug, Clone)] -struct DecoderLayer { - self_attn: Qwen3Attention, +struct DecoderLayer { + self_attn: Qwen3Attention, feed_forward: Qwen3FeedForward, ln1: RmsNorm, ln2: RmsNorm, } -impl DecoderLayer { +impl DecoderLayer { fn new( layer_idx: usize, cfg: &Config, @@ -243,15 +246,15 @@ impl DecoderLayer { } #[derive(Debug, Clone)] -pub struct Model { +pub struct Model { embed_tokens: candle_nn::Embedding, - layers: Vec, + layers: Vec>, norm: RmsNorm, device: Device, dtype: DType, } -impl Model { +impl Model { pub fn new(cfg: &Config, vb: VarBuilder) -> Result { let embed_tokens = candle_nn::embedding(cfg.vocab_size, 
cfg.hidden_size, vb.pp("model.embed_tokens"))?; @@ -325,12 +328,12 @@ impl Model { } #[derive(Debug, Clone)] -pub struct ModelForCausalLM { - base: Model, +pub struct ModelForCausalLM { + base: Model, lm_head: Linear, } -impl ModelForCausalLM { +impl ModelForCausalLM { pub fn new(cfg: &Config, vb: VarBuilder) -> Result { let base = Model::new(cfg, vb.clone())?; let lm_head = if cfg.tie_word_embeddings { diff --git a/candle-transformers/src/models/qwen3_vl/text.rs b/candle-transformers/src/models/qwen3_vl/text.rs index febe426879..de415b2ca0 100644 --- a/candle-transformers/src/models/qwen3_vl/text.rs +++ b/candle-transformers/src/models/qwen3_vl/text.rs @@ -2,8 +2,8 @@ use std::sync::{Arc, Mutex}; use candle::{DType, Device, IndexOp, Result, Tensor}; use candle_nn::{ - embedding, kv_cache::KvCache, linear, linear_b, rms_norm, Activation, Embedding, Linear, - Module, RmsNorm, VarBuilder, + embedding, kv_cache::IncrementalKvCache, linear, linear_b, rms_norm, Activation, Embedding, + Linear, Module, RmsNorm, VarBuilder, }; use super::config::TextConfig; @@ -104,7 +104,7 @@ struct Attention { rotary_emb: Arc, n_kv_groups: usize, softmax_scale: f64, - kv_cache: Arc>, + kv_cache: Arc>, } impl Attention { @@ -141,7 +141,10 @@ impl Attention { rotary_emb, n_kv_groups: cfg.num_attention_heads / cfg.num_key_value_heads, softmax_scale: 1.0 / (cfg.head_dim as f64).sqrt(), - kv_cache: Arc::new(Mutex::new(KvCache::new(2, cfg.max_position_embeddings))), + kv_cache: Arc::new(Mutex::new(IncrementalKvCache::new( + 2, + cfg.max_position_embeddings, + ))), }) }