From aa032aaa995c4b94ac523f6d26b421c520461ba7 Mon Sep 17 00:00:00 2001 From: Vinay Damodaran Date: Mon, 14 Jul 2025 16:33:26 -0700 Subject: [PATCH 01/11] Rename config and remove dropout --- backends/candle/src/lib.rs | 8 +- backends/candle/src/models/debertav2.rs | 1387 +++++++++++++++++++++++ backends/candle/src/models/mod.rs | 2 + 3 files changed, 1396 insertions(+), 1 deletion(-) create mode 100644 backends/candle/src/models/debertav2.rs diff --git a/backends/candle/src/lib.rs b/backends/candle/src/lib.rs index 882cdb8a..3584b294 100644 --- a/backends/candle/src/lib.rs +++ b/backends/candle/src/lib.rs @@ -11,7 +11,7 @@ use crate::compute_cap::{ compatible_compute_cap, get_compile_compute_cap, get_runtime_compute_cap, }; use crate::models::{ - BertConfig, BertModel, DistilBertConfig, DistilBertModel, GTEConfig, GTEModel, JinaBertModel, + BertConfig, BertModel, DebertaV2Config, DebertaV2Model, DistilBertConfig, DistilBertModel, GTEConfig, GTEModel, JinaBertModel, JinaCodeBertModel, MPNetConfig, MPNetModel, MistralConfig, Model, ModernBertConfig, ModernBertModel, NomicBertModel, NomicConfig, Qwen2Config, Qwen3Config, Qwen3Model, }; @@ -90,6 +90,8 @@ impl<'de> Deserialize<'de> for BertConfigWrapper { #[serde(tag = "model_type", rename_all = "kebab-case")] enum Config { Bert(BertConfigWrapper), + #[serde(rename(deserialize = "deberta-v2"))] + DebertaV2(DebertaV2Config), XlmRoberta(BertConfig), Camembert(BertConfig), Roberta(BertConfig), @@ -244,6 +246,10 @@ impl CandleBackend { Ok(Box::new(BertModel::load(vb, &config, model_type).s()?)) } }, + (Config::DebertaV2(config), Device::Cpu | Device::Metal(_)) => { + tracing::info!("Starting DebertaV2 model on {:?}", device); + Ok(Box::new(DebertaV2Model::load(vb, &config, model_type).s()?)) + }, ( Config::XlmRoberta(config) | Config::Camembert(config) | Config::Roberta(config), Device::Cpu | Device::Metal(_), diff --git a/backends/candle/src/models/debertav2.rs b/backends/candle/src/models/debertav2.rs new file mode 100644 index 00000000..9446ad0b --- /dev/null +++ b/backends/candle/src/models/debertav2.rs @@ -0,0 +1,1387 @@ +use std::collections::HashMap; + +use candle::{bail, Context, DType, Device, Module, Result, Tensor, D}; +use candle_nn::{ + conv1d, embedding, layer_norm, Conv1d, Conv1dConfig, Embedding, LayerNorm, VarBuilder, +}; +use serde::{Deserialize, Deserializer}; + +pub const DTYPE: DType = DType::F32; + +// NOTE: HiddenAct and HiddenActLayer are both direct copies from bert.rs. 
+#[derive(Debug, Clone, Copy, PartialEq, Eq, Deserialize)] +#[serde(rename_all = "lowercase")] +pub enum HiddenAct { + Gelu, + GeluApproximate, + Relu, +} + +pub struct HiddenActLayer { + act: HiddenAct, + span: tracing::Span, +} + +impl HiddenActLayer { + fn new(act: HiddenAct) -> Self { + let span = tracing::span!(tracing::Level::TRACE, "hidden-act"); + Self { act, span } + } + + fn forward(&self, xs: &Tensor) -> Result { + let _enter = self.span.enter(); + match self.act { + // https://github.com/huggingface/transformers/blob/cd4584e3c809bb9e1392ccd3fe38b40daba5519a/src/transformers/activations.py#L213 + HiddenAct::Gelu => xs.gelu_erf(), + HiddenAct::GeluApproximate => xs.gelu(), + HiddenAct::Relu => xs.relu(), + } + } +} + +#[derive(Debug, Clone, Copy, PartialEq, Eq, Deserialize, Default)] +#[serde(rename_all = "lowercase")] +enum PositionEmbeddingType { + #[default] + Absolute, +} + +pub type Id2Label = HashMap; +pub type Label2Id = HashMap; + +#[derive(Debug, Clone, PartialEq, Deserialize)] +pub struct DebertaV2Config { + pub vocab_size: usize, + pub hidden_size: usize, + pub num_hidden_layers: usize, + pub num_attention_heads: usize, + pub intermediate_size: usize, + pub hidden_act: HiddenAct, + pub hidden_dropout_prob: f64, + pub attention_probs_dropout_prob: f64, + pub max_position_embeddings: usize, + pub type_vocab_size: usize, + pub initializer_range: f64, + pub layer_norm_eps: f64, + pub relative_attention: bool, + pub max_relative_positions: isize, + pub pad_token_id: Option, + pub position_biased_input: bool, + #[serde(deserialize_with = "deserialize_pos_att_type")] + pub pos_att_type: Vec, + pub position_buckets: Option, + pub share_att_key: Option, + pub attention_head_size: Option, + pub embedding_size: Option, + pub norm_rel_ebd: Option, + pub conv_kernel_size: Option, + pub conv_groups: Option, + pub conv_act: Option, + pub id2label: Option, + pub label2id: Option, + pub pooler_dropout: Option, + pub pooler_hidden_act: Option, + pub pooler_hidden_size: Option, + pub cls_dropout: Option, +} + +fn deserialize_pos_att_type<'de, D>(deserializer: D) -> std::result::Result, D::Error> +where + D: Deserializer<'de>, +{ + #[derive(Deserialize, Debug)] + #[serde(untagged)] + enum StringOrVec { + String(String), + Vec(Vec), + } + + match StringOrVec::deserialize(deserializer)? { + StringOrVec::String(s) => Ok(s.split('|').map(String::from).collect()), + StringOrVec::Vec(v) => Ok(v), + } +} + +// https://github.com/huggingface/transformers/blob/78b2929c0554b79e0489b451ce4ece14d265ead2/src/transformers/models/deberta_v2/modeling_deberta_v2.py#L823 +pub struct DebertaV2Embeddings { + device: Device, + word_embeddings: Embedding, + position_embeddings: Option, + token_type_embeddings: Option, + layer_norm: LayerNorm, + position_ids: Tensor, + config: DebertaV2Config, + embedding_size: usize, + embed_proj: Option, +} + +impl DebertaV2Embeddings { + pub fn load(vb: VarBuilder, config: &DebertaV2Config) -> Result { + let device = vb.device().clone(); + let config = config.clone(); + + let embedding_size = config.embedding_size.unwrap_or(config.hidden_size); + + let word_embeddings = + embedding(config.vocab_size, embedding_size, vb.pp("word_embeddings"))?; + + let position_embeddings = if config.position_biased_input { + Some(embedding( + config.max_position_embeddings, + embedding_size, + vb.pp("position_embeddings"), + )?) 
+ } else { + None + }; + + let token_type_embeddings: Option = if config.type_vocab_size > 0 { + Some(candle_nn::embedding( + config.type_vocab_size, + config.hidden_size, + vb.pp("token_type_embeddings"), + )?) + } else { + None + }; + + let embed_proj: Option = if embedding_size != config.hidden_size { + Some(candle_nn::linear_no_bias( + embedding_size, + config.hidden_size, + vb.pp("embed_proj"), + )?) + } else { + None + }; + + let layer_norm = layer_norm( + config.hidden_size, + config.layer_norm_eps, + vb.pp("LayerNorm"), + )?; + + let position_ids = + Tensor::arange(0, config.max_position_embeddings as u32, &device)?.unsqueeze(0)?; + + Ok(Self { + word_embeddings, + position_embeddings, + token_type_embeddings, + layer_norm, + position_ids, + device, + config, + embedding_size, + embed_proj, + }) + } + + pub fn forward( + &self, + input_ids: Option<&Tensor>, + token_type_ids: Option<&Tensor>, + position_ids: Option<&Tensor>, + mask: Option<&Tensor>, + inputs_embeds: Option<&Tensor>, + ) -> Result { + let (input_shape, input_embeds) = match (input_ids, inputs_embeds) { + (Some(ids), None) => { + let embs = self.word_embeddings.forward(ids)?; + (ids.dims(), embs) + } + (None, Some(e)) => (e.dims(), e.clone()), + (None, None) => { + bail!("Must specify either input_ids or inputs_embeds") + } + (Some(_), Some(_)) => { + bail!("Can't specify both input_ids and inputs_embeds") + } + }; + + let seq_length = match input_shape.last() { + Some(v) => *v, + None => bail!("DebertaV2Embeddings invalid input shape"), + }; + + let position_ids = match position_ids { + Some(v) => v.clone(), + None => self.position_ids.narrow(1, 0, seq_length)?, + }; + + let token_type_ids = match token_type_ids { + Some(ids) => ids.clone(), + None => Tensor::zeros(input_shape, DType::U32, &self.device)?, + }; + + let position_embeddings = match &self.position_embeddings { + Some(emb) => emb.forward(&position_ids)?, + None => Tensor::zeros_like(&input_embeds)?, + }; + + let mut embeddings = input_embeds; + + if self.config.position_biased_input { + embeddings = embeddings.add(&position_embeddings)?; + } + + if self.config.type_vocab_size > 0 { + embeddings = self.token_type_embeddings.as_ref().map_or_else( + || bail!("token_type_embeddings must be set when type_vocab_size > 0"), + |token_type_embeddings| { + embeddings.add(&token_type_embeddings.forward(&token_type_ids)?) + }, + )?; + } + + if self.embedding_size != self.config.hidden_size { + embeddings = if let Some(embed_proj) = &self.embed_proj { + embed_proj.forward(&embeddings)? + } else { + bail!("embed_proj must exist if embedding_size != config.hidden_size"); + } + } + + embeddings = self.layer_norm.forward(&embeddings)?; + + if let Some(mask) = mask { + let mut mask = mask.clone(); + if mask.dims() != embeddings.dims() { + if mask.dims().len() == 4 { + mask = mask.squeeze(1)?.squeeze(1)?; + } + mask = mask.unsqueeze(2)?; + } + + mask = mask.to_dtype(embeddings.dtype())?; + embeddings = embeddings.broadcast_mul(&mask)?; + } + + Ok(embeddings) + } +} + +// https://github.com/huggingface/transformers/blob/78b2929c0554b79e0489b451ce4ece14d265ead2/src/transformers/models/deberta_v2/modeling_deberta_v2.py#L72 +struct XSoftmax {} + +impl XSoftmax { + pub fn apply(input: &Tensor, mask: &Tensor, dim: D, device: &Device) -> Result { + // NOTE: At the time of this writing, candle does not have a logical-not operator. + let mut rmask = mask.broadcast_as(input.shape())?.to_dtype(DType::F32)?; + + rmask = rmask + .broadcast_lt(&Tensor::new(&[1.0_f32], device)?)? 
+ .to_dtype(DType::U8)?; + + let min_value_tensor = Tensor::new(&[f32::MIN], device)?.broadcast_as(input.shape())?; + let mut output = rmask.where_cond(&min_value_tensor, input)?; + + output = candle_nn::ops::softmax(&output, dim)?; + + let t_zeroes = Tensor::new(&[0f32], device)?.broadcast_as(input.shape())?; + output = rmask.where_cond(&t_zeroes, &output)?; + + Ok(output) + } +} + +// https://github.com/huggingface/transformers/blob/78b2929c0554b79e0489b451ce4ece14d265ead2/src/transformers/models/deberta_v2/modeling_deberta_v2.py#L605 +pub struct DebertaV2DisentangledSelfAttention { + config: DebertaV2Config, + num_attention_heads: usize, + query_proj: candle_nn::Linear, + key_proj: candle_nn::Linear, + value_proj: candle_nn::Linear, + device: Device, + relative_attention: bool, + position_buckets: isize, + max_relative_positions: isize, + pos_ebd_size: isize, + share_att_key: bool, + pos_key_proj: Option, + pos_query_proj: Option, +} + +impl DebertaV2DisentangledSelfAttention { + pub fn load(vb: VarBuilder, config: &DebertaV2Config) -> Result { + let config = config.clone(); + let vb = vb.clone(); + + if config.hidden_size % config.num_attention_heads != 0 { + return Err(candle::Error::Msg(format!( + "The hidden size {} is not a multiple of the number of attention heads {}", + config.hidden_size, config.num_attention_heads + ))); + } + + let num_attention_heads = config.num_attention_heads; + + let attention_head_size = config + .attention_head_size + .unwrap_or(config.hidden_size / config.num_attention_heads); + + let all_head_size = num_attention_heads * attention_head_size; + + let query_proj = candle_nn::linear(config.hidden_size, all_head_size, vb.pp("query_proj"))?; + let key_proj = candle_nn::linear(config.hidden_size, all_head_size, vb.pp("key_proj"))?; + let value_proj = candle_nn::linear(config.hidden_size, all_head_size, vb.pp("value_proj"))?; + + let share_att_key = config.share_att_key.unwrap_or(false); + let relative_attention = config.relative_attention; + let mut max_relative_positions = config.max_relative_positions; + + let mut pos_ebd_size: isize = 0; + let position_buckets = config.position_buckets.unwrap_or(-1); + let mut pos_key_proj: Option = None; + let mut pos_query_proj: Option = None; + + if relative_attention { + if max_relative_positions < 1 { + max_relative_positions = config.max_position_embeddings as isize; + } + pos_ebd_size = max_relative_positions; + if position_buckets > 0 { + pos_ebd_size = position_buckets + } + + if !share_att_key { + if config.pos_att_type.iter().any(|s| s == "c2p") { + pos_key_proj = Some(candle_nn::linear( + config.hidden_size, + all_head_size, + vb.pp("pos_key_proj"), + )?); + } + if config.pos_att_type.iter().any(|s| s == "p2c") { + pos_query_proj = Some(candle_nn::linear( + config.hidden_size, + all_head_size, + vb.pp("pos_query_proj"), + )?); + } + } + } + + let device = vb.device().clone(); + + Ok(Self { + config, + num_attention_heads, + query_proj, + key_proj, + value_proj, + device, + relative_attention, + position_buckets, + max_relative_positions, + pos_ebd_size, + share_att_key, + pos_key_proj, + pos_query_proj, + }) + } + + pub fn forward( + &self, + hidden_states: &Tensor, + attention_mask: &Tensor, + query_states: Option<&Tensor>, + relative_pos: Option<&Tensor>, + rel_embeddings: Option<&Tensor>, + ) -> Result { + let query_states = match query_states { + Some(qs) => qs, + None => hidden_states, + }; + + let query_layer = self.transpose_for_scores(&self.query_proj.forward(query_states)?)?; + let key_layer = 
self.transpose_for_scores(&self.key_proj.forward(query_states)?)?; + let value_layer = self.transpose_for_scores(&self.value_proj.forward(query_states)?)?; + + let mut rel_att: Option = None; + + let mut scale_factor: usize = 1; + + if self.config.pos_att_type.iter().any(|s| s == "c2p") { + scale_factor += 1; + } + + if self.config.pos_att_type.iter().any(|s| s == "p2c") { + scale_factor += 1; + } + + let scale = { + let q_size = query_layer.dim(D::Minus1)?; + Tensor::new(&[(q_size * scale_factor) as f32], &self.device)?.sqrt()? + }; + + let mut attention_scores: Tensor = { + let key_layer_transposed = key_layer.t()?; + let div = key_layer_transposed + .broadcast_div(scale.to_dtype(query_layer.dtype())?.as_ref())?; + query_layer.matmul(&div)? + }; + + if self.relative_attention { + if let Some(rel_embeddings) = rel_embeddings { + let rel_att = Some(self.disentangled_attention_bias( + query_layer, + key_layer, + relative_pos, + rel_embeddings.clone(), + scale_factor, + )?); + } + } + + if let Some(rel_att) = rel_att { + attention_scores = attention_scores.broadcast_add(&rel_att)?; + } + + attention_scores = attention_scores.reshape(( + (), + self.num_attention_heads, + attention_scores.dim(D::Minus2)?, + attention_scores.dim(D::Minus1)?, + ))?; + + let mut attention_probs = + XSoftmax::apply(&attention_scores, attention_mask, D::Minus1, &self.device)?; + + let mut context_layer = attention_probs + .reshape(( + (), + attention_probs.dim(D::Minus2)?, + attention_probs.dim(D::Minus1)?, + ))? + .matmul(&value_layer)?; + + context_layer = context_layer + .reshape(( + (), + self.num_attention_heads, + context_layer.dim(D::Minus2)?, + context_layer.dim(D::Minus1)?, + ))? + .permute((0, 2, 1, 3))? + .contiguous()?; + + let dims = context_layer.dims(); + + context_layer = match dims.len() { + 2 => context_layer.reshape(())?, + 3 => context_layer.reshape((dims[0], ()))?, + 4 => context_layer.reshape((dims[0], dims[1], ()))?, + 5 => context_layer.reshape((dims[0], dims[1], dims[2], ()))?, + _ => { + bail!( + "Invalid shape for DisentangledSelfAttention context layer: {:?}", + dims + ) + } + }; + + Ok(context_layer) + } + + fn transpose_for_scores(&self, xs: &Tensor) -> Result { + let dims = xs.dims().to_vec(); + match dims.len() { + 3 => { + let reshaped = xs.reshape((dims[0], dims[1], self.num_attention_heads, ()))?; + + reshaped.transpose(1, 2)?.contiguous()?.reshape(( + (), + reshaped.dim(1)?, + reshaped.dim(D::Minus1)?, + )) + } + shape => { + bail!("Invalid shape for transpose_for_scores. Expected 3 dimensions, got {shape}") + } + } + } + + fn disentangled_attention_bias( + &self, + query_layer: Tensor, + key_layer: Tensor, + relative_pos: Option<&Tensor>, + rel_embeddings: Tensor, + scale_factor: usize, + ) -> Result { + let mut relative_pos = relative_pos.map_or( + build_relative_position( + query_layer.dim(D::Minus2)?, + key_layer.dim(D::Minus2)?, + &self.device, + Some(self.position_buckets), + Some(self.max_relative_positions), + )?, + |pos| pos.clone(), + ); + + relative_pos = match relative_pos.dims().len() { + 2 => relative_pos.unsqueeze(0)?.unsqueeze(0)?, + 3 => relative_pos.unsqueeze(1)?, + other => { + bail!("Relative position ids must be of dim 2 or 3 or 4. Got dim of size {other}") + } + }; + + let att_span = self.pos_ebd_size; + + let rel_embeddings = rel_embeddings + .narrow(0, 0, (att_span * 2) as usize)? + .unsqueeze(0)?; + + let mut pos_query_layer: Option = None; + let mut pos_key_layer: Option = None; + + let repeat_with = query_layer.dim(0)? 
/ self.num_attention_heads; + if self.share_att_key { + pos_query_layer = Some( + self.transpose_for_scores(&self.query_proj.forward(&rel_embeddings)?)? + .repeat(repeat_with)?, + ); + + pos_key_layer = Some( + self.transpose_for_scores(&self.key_proj.forward(&rel_embeddings)?)? + .repeat(repeat_with)?, + ) + } else { + if self.config.pos_att_type.iter().any(|s| s == "c2p") { + pos_key_layer = Some( + self.transpose_for_scores( + &self + .pos_key_proj + .as_ref() + .context( + "Need pos_key_proj when share_att_key is false or not specified", + )? + .forward(&rel_embeddings)?, + )? + .repeat(repeat_with)?, + ) + } + if self.config.pos_att_type.iter().any(|s| s == "p2c") { + pos_query_layer = Some(self.transpose_for_scores(&self + .pos_query_proj + .as_ref() + .context("Need a pos_query_proj when share_att_key is false or not specified")? + .forward(&rel_embeddings)?)?.repeat(repeat_with)?) + } + } + + let mut score = Tensor::new(&[0 as f32], &self.device)?; + + if self.config.pos_att_type.iter().any(|s| s == "c2p") { + let pos_key_layer = pos_key_layer.context("c2p without pos_key_layer")?; + + let scale = Tensor::new( + &[(pos_key_layer.dim(D::Minus1)? * scale_factor) as f32], + &self.device, + )? + .sqrt()?; + + let mut c2p_att = query_layer.matmul(&pos_key_layer.t()?)?; + + let c2p_pos = relative_pos + .broadcast_add(&Tensor::new(&[att_span as i64], &self.device)?)? + .clamp(0 as f32, (att_span * 2 - 1) as f32)?; + + c2p_att = c2p_att.gather( + &c2p_pos + .squeeze(0)? + .expand(&[ + query_layer.dim(0)?, + query_layer.dim(1)?, + relative_pos.dim(D::Minus1)?, + ])? + .contiguous()?, + D::Minus1, + )?; + + score = score.broadcast_add( + &c2p_att.broadcast_div(scale.to_dtype(c2p_att.dtype())?.as_ref())?, + )?; + } + + if self.config.pos_att_type.iter().any(|s| s == "p2c") { + let pos_query_layer = pos_query_layer.context("p2c without pos_key_layer")?; + + let scale = Tensor::new( + &[(pos_query_layer.dim(D::Minus1)? * scale_factor) as f32], + &self.device, + )? + .sqrt()?; + + let r_pos = { + if key_layer.dim(D::Minus2)? != query_layer.dim(D::Minus2)? { + build_relative_position( + key_layer.dim(D::Minus2)?, + key_layer.dim(D::Minus2)?, + &self.device, + Some(self.position_buckets), + Some(self.max_relative_positions), + )? + .unsqueeze(0)? + } else { + relative_pos + } + }; + + let p2c_pos = r_pos + .to_dtype(DType::F32)? + .neg()? + .broadcast_add(&Tensor::new(&[att_span as f32], &self.device)?)? + .clamp(0f32, (att_span * 2 - 1) as f32)?; + + let p2c_att = key_layer + .matmul(&pos_query_layer.t()?)? + .gather( + &p2c_pos + .squeeze(0)? + .expand(&[ + query_layer.dim(0)?, + key_layer.dim(D::Minus2)?, + key_layer.dim(D::Minus2)?, + ])? + .contiguous()? + .to_dtype(DType::U32)?, + D::Minus1, + )? 
+ .t()?; + + score = + score.broadcast_add(&p2c_att.broadcast_div(&scale.to_dtype(p2c_att.dtype())?)?)?; + } + + Ok(score) + } +} + +// https://github.com/huggingface/transformers/blob/78b2929c0554b79e0489b451ce4ece14d265ead2/src/transformers/models/deberta_v2/modeling_deberta_v2.py#L270 +pub struct DebertaV2Attention { + dsa: DebertaV2DisentangledSelfAttention, + output: DebertaV2SelfOutput, +} + +impl DebertaV2Attention { + pub fn load(vb: VarBuilder, config: &DebertaV2Config) -> Result { + let dsa = DebertaV2DisentangledSelfAttention::load(vb.pp("attention.self"), config)?; + let output = DebertaV2SelfOutput::load(vb.pp("attention.output"), config)?; + Ok(Self { dsa, output }) + } + + fn forward( + &self, + hidden_states: &Tensor, + attention_mask: &Tensor, + query_states: Option<&Tensor>, + relative_pos: Option<&Tensor>, + rel_embeddings: Option<&Tensor>, + ) -> Result { + let self_output = self.dsa.forward( + hidden_states, + attention_mask, + query_states, + relative_pos, + rel_embeddings, + )?; + + self.output + .forward(&self_output, query_states.unwrap_or(hidden_states)) + } +} + +// https://github.com/huggingface/transformers/blob/78b2929c0554b79e0489b451ce4ece14d265ead2/src/transformers/models/deberta_v2/modeling_deberta_v2.py#L255 +pub struct DebertaV2SelfOutput { + dense: candle_nn::Linear, + layer_norm: LayerNorm, +} + +impl DebertaV2SelfOutput { + pub fn load(vb: VarBuilder, config: &DebertaV2Config) -> Result { + let dense = candle_nn::linear(config.hidden_size, config.hidden_size, vb.pp("dense"))?; + let layer_norm = candle_nn::layer_norm( + config.hidden_size, + config.layer_norm_eps, + vb.pp("LayerNorm"), + )?; + Ok(Self { + dense, + layer_norm, + }) + } + + pub fn forward(&self, hidden_states: &Tensor, input_tensor: &Tensor) -> Result { + let mut hidden_states = self.dense.forward(hidden_states)?; + self.layer_norm + .forward(&hidden_states.broadcast_add(input_tensor)?) + } +} + +// https://github.com/huggingface/transformers/blob/78b2929c0554b79e0489b451ce4ece14d265ead2/src/transformers/models/deberta_v2/modeling_deberta_v2.py#L307 +pub struct DebertaV2Intermediate { + dense: candle_nn::Linear, + intermediate_act: HiddenActLayer, +} + +impl DebertaV2Intermediate { + pub fn load(vb: VarBuilder, config: &DebertaV2Config) -> Result { + let dense = candle_nn::linear( + config.hidden_size, + config.intermediate_size, + vb.pp("intermediate.dense"), + )?; + let intermediate_act = HiddenActLayer::new(config.hidden_act); + Ok(Self { + dense, + intermediate_act, + }) + } + + pub fn forward(&self, hidden_states: &Tensor) -> Result { + self.intermediate_act + .forward(&self.dense.forward(hidden_states)?) + } +} + +// https://github.com/huggingface/transformers/blob/78b2929c0554b79e0489b451ce4ece14d265ead2/src/transformers/models/deberta_v2/modeling_deberta_v2.py#L323 +pub struct DebertaV2Output { + dense: candle_nn::Linear, + layer_norm: LayerNorm, +} + +impl DebertaV2Output { + pub fn load(vb: VarBuilder, config: &DebertaV2Config) -> Result { + let dense = candle_nn::linear( + config.intermediate_size, + config.hidden_size, + vb.pp("output.dense"), + )?; + let layer_norm = candle_nn::layer_norm( + config.hidden_size, + config.layer_norm_eps, + vb.pp("output.LayerNorm"), + )?; + Ok(Self { + dense, + layer_norm, + }) + } + + pub fn forward(&self, hidden_states: &Tensor, input_tensor: &Tensor) -> Result { + let mut hidden_states = self.dense.forward(hidden_states)?; + hidden_states = { + let to_norm = hidden_states.broadcast_add(input_tensor)?; + self.layer_norm.forward(&to_norm)? 
+ }; + Ok(hidden_states) + } +} + +// https://github.com/huggingface/transformers/blob/78b2929c0554b79e0489b451ce4ece14d265ead2/src/transformers/models/deberta_v2/modeling_deberta_v2.py#L339 +pub struct DebertaV2Layer { + attention: DebertaV2Attention, + intermediate: DebertaV2Intermediate, + output: DebertaV2Output, +} + +impl DebertaV2Layer { + pub fn load(vb: VarBuilder, config: &DebertaV2Config) -> Result { + let attention = DebertaV2Attention::load(vb.clone(), config)?; + let intermediate = DebertaV2Intermediate::load(vb.clone(), config)?; + let output = DebertaV2Output::load(vb.clone(), config)?; + Ok(Self { + attention, + intermediate, + output, + }) + } + + fn forward( + &self, + hidden_states: &Tensor, + attention_mask: &Tensor, + query_states: Option<&Tensor>, + relative_pos: Option<&Tensor>, + rel_embeddings: Option<&Tensor>, + ) -> Result { + let attention_output = self.attention.forward( + hidden_states, + attention_mask, + query_states, + relative_pos, + rel_embeddings, + )?; + + let intermediate_output = self.intermediate.forward(&attention_output)?; + + let layer_output = self + .output + .forward(&intermediate_output, &attention_output)?; + + Ok(layer_output) + } +} + +// TODO: In order to fully test ConvLayer a model needs to be found has a configuration where `conv_kernel_size` exists and is > 0 +// https://github.com/huggingface/transformers/blob/78b2929c0554b79e0489b451ce4ece14d265ead2/src/transformers/models/deberta_v2/modeling_deberta_v2.py#L373 +pub struct ConvLayer { + _conv_act: String, + _conv: Conv1d, + _layer_norm: LayerNorm, + _config: DebertaV2Config, +} + +impl ConvLayer { + pub fn load(vb: VarBuilder, config: &DebertaV2Config) -> Result { + let config = config.clone(); + let kernel_size = config.conv_kernel_size.unwrap_or(3); + let groups = config.conv_groups.unwrap_or(1); + let conv_act: String = config.conv_act.clone().unwrap_or("tanh".to_string()); + + let conv_conf = Conv1dConfig { + padding: (kernel_size - 1) / 2, + groups, + ..Default::default() + }; + + let conv = conv1d( + config.hidden_size, + config.hidden_size, + kernel_size, + conv_conf, + vb.pp("conv"), + )?; + + let layer_norm = layer_norm( + config.hidden_size, + config.layer_norm_eps, + vb.pp("LayerNorm"), + )?; + + + Ok(Self { + _conv_act: conv_act, + _conv: conv, + _layer_norm: layer_norm, + _config: config, + }) + } + + pub fn forward( + &self, + _hidden_states: &Tensor, + _residual_states: &Tensor, + _input_mask: &Tensor, + ) -> Result { + todo!("Need a model that contains a conv layer to test against.") + } +} + +// https://github.com/huggingface/transformers/blob/78b2929c0554b79e0489b451ce4ece14d265ead2/src/transformers/models/deberta_v2/modeling_deberta_v2.py#L409 +pub struct DebertaV2Encoder { + layer: Vec, + relative_attention: bool, + max_relative_positions: isize, + position_buckets: isize, + rel_embeddings: Option, + norm_rel_ebd: String, + layer_norm: Option, + conv: Option, + device: Device, +} + +impl DebertaV2Encoder { + pub fn load(vb: VarBuilder, config: &DebertaV2Config) -> Result { + let layer = (0..config.num_hidden_layers) + .map(|index| DebertaV2Layer::load(vb.pp(format!("layer.{index}")), config)) + .collect::>>()?; + + let relative_attention = config.relative_attention; + let mut max_relative_positions = config.max_relative_positions; + + let position_buckets = config.position_buckets.unwrap_or(-1); + + let mut rel_embeddings: Option = None; + + if relative_attention { + if max_relative_positions < 1 { + max_relative_positions = config.max_position_embeddings as 
isize; + } + + let mut pos_ebd_size = max_relative_positions * 2; + + if position_buckets > 0 { + pos_ebd_size = position_buckets * 2; + } + + rel_embeddings = Some(embedding( + pos_ebd_size as usize, + config.hidden_size, + vb.pp("rel_embeddings"), + )?); + } + + // NOTE: The Python code assumes that the config attribute "norm_rel_ebd" is an array of some kind, but most examples have it as a string. + // So it might need to be updated at some point. + let norm_rel_ebd = match config.norm_rel_ebd.as_ref() { + Some(nre) => nre.trim().to_string(), + None => "none".to_string(), + }; + + let layer_norm: Option = if norm_rel_ebd == "layer_norm" { + Some(layer_norm( + config.hidden_size, + config.layer_norm_eps, + vb.pp("LayerNorm"), + )?) + } else { + None + }; + + let conv: Option = if config.conv_kernel_size.unwrap_or(0) > 0 { + Some(ConvLayer::load(vb.pp("conv"), config)?) + } else { + None + }; + + Ok(Self { + layer, + relative_attention, + max_relative_positions, + position_buckets, + rel_embeddings, + norm_rel_ebd, + layer_norm, + conv, + device: vb.device().clone(), + }) + } + + pub fn forward( + &self, + hidden_states: &Tensor, + attention_mask: &Tensor, + query_states: Option<&Tensor>, + relative_pos: Option<&Tensor>, + ) -> Result { + let input_mask = if attention_mask.dims().len() <= 2 { + attention_mask.clone() + } else { + attention_mask + .sum_keepdim(attention_mask.rank() - 2)? + .gt(0.)? + }; + + let attention_mask = self.get_attention_mask(attention_mask.clone())?; + + let relative_pos = self.get_rel_pos(hidden_states, query_states, relative_pos)?; + + let mut next_kv: Tensor = hidden_states.clone(); + let rel_embeddings = self.get_rel_embedding()?; + let mut output_states = next_kv.to_owned(); + let mut query_states: Option = query_states.cloned(); + + for (i, layer_module) in self.layer.iter().enumerate() { + // NOTE: The original python code branches here if this model is being + // used for training vs. inferencing. For now, we will only handle the + // inferencing side of things + + output_states = layer_module.forward( + next_kv.as_ref(), + &attention_mask, + query_states.as_ref(), + relative_pos.as_ref(), + rel_embeddings.as_ref(), + )?; + + if i == 0 { + if let Some(conv) = &self.conv { + output_states = conv.forward(hidden_states, &output_states, &input_mask)?; + } + } + + if query_states.is_some() { + query_states = Some(output_states.clone()); + } else { + next_kv = output_states.clone(); + } + } + + Ok(output_states) + } + + fn get_attention_mask(&self, mut attention_mask: Tensor) -> Result { + match attention_mask.dims().len() { + 0..=2 => { + let extended_attention_mask = attention_mask.unsqueeze(1)?.unsqueeze(2)?; + attention_mask = extended_attention_mask.broadcast_mul( + &extended_attention_mask + .squeeze(D::Minus2)? + .unsqueeze(D::Minus1)?, + )?; + } + 3 => attention_mask = attention_mask.unsqueeze(1)?, + len => bail!("Unsupported attentiom mask size length: {len}"), + } + + Ok(attention_mask) + } + + fn get_rel_pos( + &self, + hidden_states: &Tensor, + query_states: Option<&Tensor>, + relative_pos: Option<&Tensor>, + ) -> Result> { + if self.relative_attention && relative_pos.is_none() { + let q = if let Some(query_states) = query_states { + query_states.dim(D::Minus2)? + } else { + hidden_states.dim(D::Minus2)? 
+ }; + + return Ok(Some(build_relative_position( + q, + hidden_states.dim(D::Minus2)?, + &self.device, + Some(self.position_buckets), + Some(self.max_relative_positions), + )?)); + } + + if relative_pos.is_some() { + Ok(relative_pos.cloned()) + } else { + Ok(None) + } + } + fn get_rel_embedding(&self) -> Result> { + if !self.relative_attention { + return Ok(None); + } + + let rel_embeddings = self + .rel_embeddings + .as_ref() + .context("self.rel_embeddings not present when using relative_attention")? + .embeddings() + .clone(); + + if !self.norm_rel_ebd.contains("layer_norm") { + return Ok(Some(rel_embeddings)); + } + + let layer_normed_embeddings = self + .layer_norm + .as_ref() + .context("DebertaV2Encoder layer_norm is None when norm_rel_ebd contains layer_norm")? + .forward(&rel_embeddings)?; + + Ok(Some(layer_normed_embeddings)) + } +} + +// https://github.com/huggingface/transformers/blob/78b2929c0554b79e0489b451ce4ece14d265ead2/src/transformers/models/deberta_v2/modeling_deberta_v2.py#L991 +pub struct DebertaV2Model { + embeddings: DebertaV2Embeddings, + encoder: DebertaV2Encoder, + z_steps: usize, + pub device: Device, +} + +impl DebertaV2Model { + pub fn load(vb: VarBuilder, config: &DebertaV2Config) -> Result { + let vb = vb.clone(); + let embeddings = DebertaV2Embeddings::load(vb.pp("embeddings"), config)?; + let encoder = DebertaV2Encoder::load(vb.pp("encoder"), config)?; + let z_steps: usize = 0; + + Ok(Self { + embeddings, + encoder, + z_steps, + device: vb.device().clone(), + }) + } + + pub fn forward( + &self, + input_ids: &Tensor, + token_type_ids: Option, + attention_mask: Option, + ) -> Result { + let input_ids_shape = input_ids.shape(); + + let attention_mask = match attention_mask { + Some(mask) => mask, + None => Tensor::ones(input_ids_shape, DType::I64, &self.device)?, + }; + + let token_type_ids = match token_type_ids { + Some(ids) => ids, + None => Tensor::zeros(input_ids_shape, DType::U32, &self.device)?, + }; + + let embedding_output = self.embeddings.forward( + Some(input_ids), + Some(&token_type_ids), + None, + Some(&attention_mask), + None, + )?; + + let encoder_output = + self.encoder + .forward(&embedding_output, &attention_mask, None, None)?; + + if self.z_steps > 1 { + todo!("Complete DebertaV2Model forward() when z_steps > 1 -- Needs a model to test this situation.") + } + + Ok(encoder_output) + } +} + +#[derive(Debug)] +pub struct NERItem { + pub entity: String, + pub word: String, + pub score: f32, + pub start: usize, + pub end: usize, + pub index: usize, +} + +#[derive(Debug)] +pub struct TextClassificationItem { + pub label: String, + pub score: f32, +} + +pub struct DebertaV2NERModel { + pub device: Device, + deberta: DebertaV2Model, + dropout: candle_nn::Dropout, + classifier: candle_nn::Linear, +} + +fn id2label_len(config: &DebertaV2Config, id2label: Option>) -> Result { + let id2label_len = match (&config.id2label, id2label) { + (None, None) => bail!("Id2Label is either not present in the model configuration or not passed into DebertaV2NERModel::load as a parameter"), + (None, Some(id2label_p)) => id2label_p.len(), + (Some(id2label_c), None) => id2label_c.len(), + (Some(id2label_c), Some(id2label_p)) => { + if *id2label_c == id2label_p { + id2label_c.len() + } else { + bail!("Id2Label is both present in the model configuration and provided as a parameter, and they are different.") + } + } + }; + Ok(id2label_len) +} + +impl DebertaV2NERModel { + pub fn load(vb: VarBuilder, config: &DebertaV2Config, id2label: Option) -> Result { + let 
id2label_len = id2label_len(config, id2label)?; + + let deberta = DebertaV2Model::load(vb.clone(), config)?; + let dropout = candle_nn::Dropout::new(config.hidden_dropout_prob as f32); + let classifier: candle_nn::Linear = candle_nn::linear_no_bias( + config.hidden_size, + id2label_len, + vb.root().pp("classifier"), + )?; + + Ok(Self { + device: vb.device().clone(), + deberta, + dropout, + classifier, + }) + } + + pub fn forward( + &self, + input_ids: &Tensor, + token_type_ids: Option, + attention_mask: Option, + ) -> Result { + let output = self + .deberta + .forward(input_ids, token_type_ids, attention_mask)?; + let output = self.dropout.forward(&output, false)?; + self.classifier.forward(&output) + } +} + +pub struct DebertaV2SeqClassificationModel { + pub device: Device, + deberta: DebertaV2Model, + pooler: DebertaV2ContextPooler, + classifier: candle_nn::Linear, +} + +impl DebertaV2SeqClassificationModel { + pub fn load(vb: VarBuilder, config: &DebertaV2Config, id2label: Option) -> Result { + let id2label_len = id2label_len(config, id2label)?; + let deberta = DebertaV2Model::load(vb.clone(), config)?; + let pooler = DebertaV2ContextPooler::load(vb.clone(), config)?; + let output_dim = pooler.output_dim()?; + let classifier = candle_nn::linear(output_dim, id2label_len, vb.root().pp("classifier"))?; + + Ok(Self { + device: vb.device().clone(), + deberta, + pooler, + classifier, + }) + } + + pub fn forward( + &self, + input_ids: &Tensor, + token_type_ids: Option, + attention_mask: Option, + ) -> Result { + let encoder_layer = self + .deberta + .forward(input_ids, token_type_ids, attention_mask)?; + let pooled_output = self.pooler.forward(&encoder_layer)?; + self.classifier.forward(&pooled_output) + } +} + +pub struct DebertaV2ContextPooler { + dense: candle_nn::Linear, + config: DebertaV2Config, +} + +// https://github.com/huggingface/transformers/blob/78b2929c0554b79e0489b451ce4ece14d265ead2/src/transformers/models/deberta_v2/modeling_deberta_v2.py#L49 +impl DebertaV2ContextPooler { + pub fn load(vb: VarBuilder, config: &DebertaV2Config) -> Result { + let pooler_hidden_size = config + .pooler_hidden_size + .context("config.pooler_hidden_size is required for DebertaV2ContextPooler")?; + + let pooler_dropout = config + .pooler_dropout + .context("config.pooler_dropout is required for DebertaV2ContextPooler")?; + + let dense = candle_nn::linear( + pooler_hidden_size, + pooler_hidden_size, + vb.root().pp("pooler.dense"), + )?; + + + Ok(Self { + dense, + config: config.clone(), + }) + } + + pub fn forward(&self, hidden_states: &Tensor) -> Result { + let context_token = hidden_states.narrow(1, 0, 1)?.squeeze(1)?; + + let pooled_output = self.dense.forward(&context_token.contiguous()?)?; + let pooler_hidden_act = self + .config + .pooler_hidden_act + .context("Could not obtain pooler hidden act from config")?; + + HiddenActLayer::new(pooler_hidden_act).forward(&pooled_output) + } + + pub fn output_dim(&self) -> Result { + self.config.pooler_hidden_size.context("DebertaV2ContextPooler cannot return output_dim (pooler_hidden_size) since it is not specified in the model config") + } +} + +// https://github.com/huggingface/transformers/blob/78b2929c0554b79e0489b451ce4ece14d265ead2/src/transformers/models/deberta_v2/modeling_deberta_v2.py#L557 +pub(crate) fn build_relative_position( + query_size: usize, + key_size: usize, + device: &Device, + bucket_size: Option, + max_position: Option, +) -> Result { + let q_ids = Tensor::arange(0, query_size as i64, device)?.unsqueeze(0)?; + let k_ids: Tensor = 
Tensor::arange(0, key_size as i64, device)?.unsqueeze(D::Minus1)?; + let mut rel_pos_ids = k_ids.broadcast_sub(&q_ids)?; + let bucket_size = bucket_size.unwrap_or(-1); + let max_position = max_position.unwrap_or(-1); + + if bucket_size > 0 && max_position > 0 { + rel_pos_ids = make_log_bucket_position(rel_pos_ids, bucket_size, max_position, device)?; + } + + rel_pos_ids = rel_pos_ids.to_dtype(DType::I64)?; + rel_pos_ids = rel_pos_ids.narrow(0, 0, query_size)?; + rel_pos_ids.unsqueeze(0) +} + +// https://github.com/huggingface/transformers/blob/78b2929c0554b79e0489b451ce4ece14d265ead2/src/transformers/models/deberta_v2/modeling_deberta_v2.py#L542 +pub(crate) fn make_log_bucket_position( + relative_pos: Tensor, + bucket_size: isize, + max_position: isize, + device: &Device, +) -> Result { + let sign = relative_pos.to_dtype(DType::F32)?.sign()?; + + let mid = bucket_size / 2; + + let lt_mid = relative_pos.lt(mid as i64)?; + let gt_neg_mid = relative_pos.gt(-mid as i64)?; + + let condition = lt_mid + .to_dtype(candle::DType::F32)? + .mul(>_neg_mid.to_dtype(candle::DType::F32)?)? + .to_dtype(DType::U8)?; + + let on_true = Tensor::new(&[(mid - 1) as u32], device)? + .broadcast_as(relative_pos.shape())? + .to_dtype(relative_pos.dtype())?; + + let on_false = relative_pos + .to_dtype(DType::F32)? + .abs()? + .to_dtype(DType::I64)?; + + let abs_pos = condition.where_cond(&on_true, &on_false)?; + + let mid_as_tensor = Tensor::from_slice(&[mid as f32], (1,), device)?; + + let log_pos = { + let first_log = abs_pos + .to_dtype(DType::F32)? + .broadcast_div(&mid_as_tensor)? + .log()?; + + let second_log = + Tensor::from_slice(&[((max_position as f32 - 1.0) / mid as f32)], (1,), device)? + .log()?; + + let first_div_second = first_log.broadcast_div(&second_log)?; + + let to_ceil = first_div_second + .broadcast_mul(Tensor::from_slice(&[(mid - 1) as f32], (1,), device)?.as_ref())?; + + let ceil = to_ceil.ceil()?; + + ceil.broadcast_add(&mid_as_tensor)? + }; + + Ok({ + let abs_pos_lte_mid = abs_pos.to_dtype(DType::F32)?.broadcast_le(&mid_as_tensor)?; + let relative_pos = relative_pos.to_dtype(relative_pos.dtype())?; + let log_pos_mul_sign = log_pos.broadcast_mul(&sign.to_dtype(DType::F32)?)?; + abs_pos_lte_mid.where_cond(&relative_pos.to_dtype(DType::F32)?, &log_pos_mul_sign)? 
+ }) +} \ No newline at end of file diff --git a/backends/candle/src/models/mod.rs b/backends/candle/src/models/mod.rs index 0d4d5506..65c6e8b5 100644 --- a/backends/candle/src/models/mod.rs +++ b/backends/candle/src/models/mod.rs @@ -5,6 +5,7 @@ extern crate intel_mkl_src; extern crate accelerate_src; mod bert; +mod debertav2; mod distilbert; mod jina; mod jina_code; @@ -49,6 +50,7 @@ mod qwen3; pub use bert::{BertConfig, BertModel, PositionEmbeddingType}; use candle::{Result, Tensor}; +pub use debertav2::{DebertaV2Config, DebertaV2Model}; pub use distilbert::{DistilBertConfig, DistilBertModel}; #[allow(unused_imports)] pub use gte::{GTEClassificationHead, GTEConfig, GTEModel, GTEMLP}; From 32ea2bee22b925575c6b0f028dc60aaf24f07c52 Mon Sep 17 00:00:00 2001 From: Vinay Damodaran Date: Mon, 14 Jul 2025 16:44:48 -0700 Subject: [PATCH 02/11] Remove NER and id2label --- backends/candle/src/models/debertav2.rs | 74 ++----------------------- 1 file changed, 5 insertions(+), 69 deletions(-) diff --git a/backends/candle/src/models/debertav2.rs b/backends/candle/src/models/debertav2.rs index 9446ad0b..1451825c 100644 --- a/backends/candle/src/models/debertav2.rs +++ b/backends/candle/src/models/debertav2.rs @@ -1147,79 +1147,12 @@ impl DebertaV2Model { } } -#[derive(Debug)] -pub struct NERItem { - pub entity: String, - pub word: String, - pub score: f32, - pub start: usize, - pub end: usize, - pub index: usize, -} - #[derive(Debug)] pub struct TextClassificationItem { pub label: String, pub score: f32, } -pub struct DebertaV2NERModel { - pub device: Device, - deberta: DebertaV2Model, - dropout: candle_nn::Dropout, - classifier: candle_nn::Linear, -} - -fn id2label_len(config: &DebertaV2Config, id2label: Option>) -> Result { - let id2label_len = match (&config.id2label, id2label) { - (None, None) => bail!("Id2Label is either not present in the model configuration or not passed into DebertaV2NERModel::load as a parameter"), - (None, Some(id2label_p)) => id2label_p.len(), - (Some(id2label_c), None) => id2label_c.len(), - (Some(id2label_c), Some(id2label_p)) => { - if *id2label_c == id2label_p { - id2label_c.len() - } else { - bail!("Id2Label is both present in the model configuration and provided as a parameter, and they are different.") - } - } - }; - Ok(id2label_len) -} - -impl DebertaV2NERModel { - pub fn load(vb: VarBuilder, config: &DebertaV2Config, id2label: Option) -> Result { - let id2label_len = id2label_len(config, id2label)?; - - let deberta = DebertaV2Model::load(vb.clone(), config)?; - let dropout = candle_nn::Dropout::new(config.hidden_dropout_prob as f32); - let classifier: candle_nn::Linear = candle_nn::linear_no_bias( - config.hidden_size, - id2label_len, - vb.root().pp("classifier"), - )?; - - Ok(Self { - device: vb.device().clone(), - deberta, - dropout, - classifier, - }) - } - - pub fn forward( - &self, - input_ids: &Tensor, - token_type_ids: Option, - attention_mask: Option, - ) -> Result { - let output = self - .deberta - .forward(input_ids, token_type_ids, attention_mask)?; - let output = self.dropout.forward(&output, false)?; - self.classifier.forward(&output) - } -} - pub struct DebertaV2SeqClassificationModel { pub device: Device, deberta: DebertaV2Model, @@ -1228,8 +1161,11 @@ pub struct DebertaV2SeqClassificationModel { } impl DebertaV2SeqClassificationModel { - pub fn load(vb: VarBuilder, config: &DebertaV2Config, id2label: Option) -> Result { - let id2label_len = id2label_len(config, id2label)?; + pub fn load(vb: VarBuilder, config: &DebertaV2Config) -> Result { + let 
id2label_len = match &config.id2label { + None => candle::bail!("`id2label` must be set for classifier models"), + Some(id2label) => id2label.len(), + }; let deberta = DebertaV2Model::load(vb.clone(), config)?; let pooler = DebertaV2ContextPooler::load(vb.clone(), config)?; let output_dim = pooler.output_dim()?; From e611686fddd0b2fe3eea8c57523756000370de4a Mon Sep 17 00:00:00 2001 From: Vinay Damodaran Date: Mon, 14 Jul 2025 17:49:03 -0700 Subject: [PATCH 03/11] Load debertav2 classification head but not implement trait/forward --- backends/candle/src/models/debertav2.rs | 81 +++++++++---------------- 1 file changed, 30 insertions(+), 51 deletions(-) diff --git a/backends/candle/src/models/debertav2.rs b/backends/candle/src/models/debertav2.rs index 1451825c..a740b5d0 100644 --- a/backends/candle/src/models/debertav2.rs +++ b/backends/candle/src/models/debertav2.rs @@ -5,6 +5,7 @@ use candle_nn::{ conv1d, embedding, layer_norm, Conv1d, Conv1dConfig, Embedding, LayerNorm, VarBuilder, }; use serde::{Deserialize, Deserializer}; +use text_embeddings_backend_core::{Batch, ModelType, Pool}; pub const DTYPE: DType = DType::F32; @@ -160,7 +161,7 @@ impl DebertaV2Embeddings { config.layer_norm_eps, vb.pp("LayerNorm"), )?; - + let position_ids = Tensor::arange(0, config.max_position_embeddings as u32, &device)?.unsqueeze(0)?; @@ -343,7 +344,7 @@ impl DebertaV2DisentangledSelfAttention { if position_buckets > 0 { pos_ebd_size = position_buckets } - + if !share_att_key { if config.pos_att_type.iter().any(|s| s == "c2p") { pos_key_proj = Some(candle_nn::linear( @@ -447,7 +448,7 @@ impl DebertaV2DisentangledSelfAttention { let mut attention_probs = XSoftmax::apply(&attention_scores, attention_mask, D::Minus1, &self.device)?; - + let mut context_layer = attention_probs .reshape(( (), @@ -561,7 +562,7 @@ impl DebertaV2DisentangledSelfAttention { )? .forward(&rel_embeddings)?, )? - .repeat(repeat_with)?, + .repeat(repeat_with)?, ) } if self.config.pos_att_type.iter().any(|s| s == "p2c") { @@ -582,7 +583,7 @@ impl DebertaV2DisentangledSelfAttention { &[(pos_key_layer.dim(D::Minus1)? * scale_factor) as f32], &self.device, )? - .sqrt()?; + .sqrt()?; let mut c2p_att = query_layer.matmul(&pos_key_layer.t()?)?; @@ -614,7 +615,7 @@ impl DebertaV2DisentangledSelfAttention { &[(pos_query_layer.dim(D::Minus1)? * scale_factor) as f32], &self.device, )? - .sqrt()?; + .sqrt()?; let r_pos = { if key_layer.dim(D::Minus2)? != query_layer.dim(D::Minus2)? { @@ -625,7 +626,7 @@ impl DebertaV2DisentangledSelfAttention { Some(self.position_buckets), Some(self.max_relative_positions), )? - .unsqueeze(0)? + .unsqueeze(0)? 
} else { relative_pos } @@ -709,10 +710,7 @@ impl DebertaV2SelfOutput { config.layer_norm_eps, vb.pp("LayerNorm"), )?; - Ok(Self { - dense, - layer_norm, - }) + Ok(Self { dense, layer_norm }) } pub fn forward(&self, hidden_states: &Tensor, input_tensor: &Tensor) -> Result { @@ -766,10 +764,7 @@ impl DebertaV2Output { config.layer_norm_eps, vb.pp("output.LayerNorm"), )?; - Ok(Self { - dense, - layer_norm, - }) + Ok(Self { dense, layer_norm }) } pub fn forward(&self, hidden_states: &Tensor, input_tensor: &Tensor) -> Result { @@ -863,7 +858,6 @@ impl ConvLayer { vb.pp("LayerNorm"), )?; - Ok(Self { _conv_act: conv_act, _conv: conv, @@ -1090,21 +1084,32 @@ impl DebertaV2Encoder { pub struct DebertaV2Model { embeddings: DebertaV2Embeddings, encoder: DebertaV2Encoder, - z_steps: usize, + classifier: Option, + pub device: Device, } impl DebertaV2Model { - pub fn load(vb: VarBuilder, config: &DebertaV2Config) -> Result { + pub fn load(vb: VarBuilder, config: &DebertaV2Config, model_type: ModelType) -> Result { let vb = vb.clone(); + let classifier = match model_type { + ModelType::Classifier => { + + let classifier = DebertaV2SeqClassificationHead::load(vb.clone(), config)?; + + Some(classifier) + } + ModelType::Embedding(pool) => { + None + } + }; let embeddings = DebertaV2Embeddings::load(vb.pp("embeddings"), config)?; let encoder = DebertaV2Encoder::load(vb.pp("encoder"), config)?; - let z_steps: usize = 0; Ok(Self { embeddings, encoder, - z_steps, + classifier, device: vb.device().clone(), }) } @@ -1139,56 +1144,35 @@ impl DebertaV2Model { self.encoder .forward(&embedding_output, &attention_mask, None, None)?; - if self.z_steps > 1 { - todo!("Complete DebertaV2Model forward() when z_steps > 1 -- Needs a model to test this situation.") - } - Ok(encoder_output) } } -#[derive(Debug)] -pub struct TextClassificationItem { - pub label: String, - pub score: f32, -} - -pub struct DebertaV2SeqClassificationModel { +pub struct DebertaV2SeqClassificationHead { pub device: Device, - deberta: DebertaV2Model, pooler: DebertaV2ContextPooler, classifier: candle_nn::Linear, } -impl DebertaV2SeqClassificationModel { +impl DebertaV2SeqClassificationHead { pub fn load(vb: VarBuilder, config: &DebertaV2Config) -> Result { let id2label_len = match &config.id2label { None => candle::bail!("`id2label` must be set for classifier models"), Some(id2label) => id2label.len(), }; - let deberta = DebertaV2Model::load(vb.clone(), config)?; let pooler = DebertaV2ContextPooler::load(vb.clone(), config)?; let output_dim = pooler.output_dim()?; let classifier = candle_nn::linear(output_dim, id2label_len, vb.root().pp("classifier"))?; Ok(Self { device: vb.device().clone(), - deberta, pooler, classifier, }) } - pub fn forward( - &self, - input_ids: &Tensor, - token_type_ids: Option, - attention_mask: Option, - ) -> Result { - let encoder_layer = self - .deberta - .forward(input_ids, token_type_ids, attention_mask)?; - let pooled_output = self.pooler.forward(&encoder_layer)?; + pub fn forward(&self, hidden_states: &Tensor) -> Result { + let pooled_output = self.pooler.forward(&hidden_states)?; self.classifier.forward(&pooled_output) } } @@ -1205,17 +1189,12 @@ impl DebertaV2ContextPooler { .pooler_hidden_size .context("config.pooler_hidden_size is required for DebertaV2ContextPooler")?; - let pooler_dropout = config - .pooler_dropout - .context("config.pooler_dropout is required for DebertaV2ContextPooler")?; - let dense = candle_nn::linear( pooler_hidden_size, pooler_hidden_size, vb.root().pp("pooler.dense"), )?; - Ok(Self { dense, 
config: config.clone(), @@ -1320,4 +1299,4 @@ pub(crate) fn make_log_bucket_position( let log_pos_mul_sign = log_pos.broadcast_mul(&sign.to_dtype(DType::F32)?)?; abs_pos_lte_mid.where_cond(&relative_pos.to_dtype(DType::F32)?, &log_pos_mul_sign)? }) -} \ No newline at end of file +} From 7cd8dd84c340dadbbdbe039b62f2d4975ba57ba5 Mon Sep 17 00:00:00 2001 From: Vinay Damodaran Date: Tue, 15 Jul 2025 16:20:46 -0700 Subject: [PATCH 04/11] [WIP] loaded weights but failing during warmup --- backends/candle/src/models/debertav2.rs | 268 ++++++++++++++++++++---- 1 file changed, 228 insertions(+), 40 deletions(-) diff --git a/backends/candle/src/models/debertav2.rs b/backends/candle/src/models/debertav2.rs index a740b5d0..9727872b 100644 --- a/backends/candle/src/models/debertav2.rs +++ b/backends/candle/src/models/debertav2.rs @@ -1,6 +1,7 @@ use std::collections::HashMap; -use candle::{bail, Context, DType, Device, Module, Result, Tensor, D}; +use candle::{bail, Context, DType, Device, IndexOp, Module, Result, Tensor, D}; +use crate::models::Model; use candle_nn::{ conv1d, embedding, layer_norm, Conv1d, Conv1dConfig, Embedding, LayerNorm, VarBuilder, }; @@ -47,7 +48,7 @@ enum PositionEmbeddingType { Absolute, } -pub type Id2Label = HashMap; +pub type Id2Label = HashMap; pub type Label2Id = HashMap; #[derive(Debug, Clone, PartialEq, Deserialize)] @@ -399,8 +400,6 @@ impl DebertaV2DisentangledSelfAttention { let key_layer = self.transpose_for_scores(&self.key_proj.forward(query_states)?)?; let value_layer = self.transpose_for_scores(&self.value_proj.forward(query_states)?)?; - let mut rel_att: Option = None; - let mut scale_factor: usize = 1; if self.config.pos_att_type.iter().any(|s| s == "c2p") { @@ -423,17 +422,21 @@ impl DebertaV2DisentangledSelfAttention { query_layer.matmul(&div)? }; - if self.relative_attention { + let rel_att = if self.relative_attention { if let Some(rel_embeddings) = rel_embeddings { - let rel_att = Some(self.disentangled_attention_bias( + Some(self.disentangled_attention_bias( query_layer, key_layer, relative_pos, rel_embeddings.clone(), scale_factor, - )?); + )?) + } else { + None } - } + } else { + None + }; if let Some(rel_att) = rel_att { attention_scores = attention_scores.broadcast_add(&rel_att)?; @@ -446,7 +449,7 @@ impl DebertaV2DisentangledSelfAttention { attention_scores.dim(D::Minus1)?, ))?; - let mut attention_probs = + let attention_probs = XSoftmax::apply(&attention_scores, attention_mask, D::Minus1, &self.device)?; let mut context_layer = attention_probs @@ -714,7 +717,7 @@ impl DebertaV2SelfOutput { } pub fn forward(&self, hidden_states: &Tensor, input_tensor: &Tensor) -> Result { - let mut hidden_states = self.dense.forward(hidden_states)?; + let hidden_states = self.dense.forward(hidden_states)?; self.layer_norm .forward(&hidden_states.broadcast_add(input_tensor)?) 
} @@ -1085,66 +1088,224 @@ pub struct DebertaV2Model { embeddings: DebertaV2Embeddings, encoder: DebertaV2Encoder, classifier: Option, - + pool: Option, pub device: Device, + pub dtype: DType, + span: tracing::Span, } impl DebertaV2Model { pub fn load(vb: VarBuilder, config: &DebertaV2Config, model_type: ModelType) -> Result { let vb = vb.clone(); - let classifier = match model_type { + let (classifier, pool) = match model_type { ModelType::Classifier => { - let classifier = DebertaV2SeqClassificationHead::load(vb.clone(), config)?; - - Some(classifier) + (Some(classifier), None) } ModelType::Embedding(pool) => { - None + (None, Some(pool)) } }; - let embeddings = DebertaV2Embeddings::load(vb.pp("embeddings"), config)?; - let encoder = DebertaV2Encoder::load(vb.pp("encoder"), config)?; + + // Try loading embeddings from "embeddings" first, then "deberta.embeddings" + let embeddings = match DebertaV2Embeddings::load(vb.pp("embeddings"), config) { + Ok(embeddings) => embeddings, + Err(_) => DebertaV2Embeddings::load(vb.pp("deberta.embeddings"), config)?, + }; + + // Try loading encoder from "encoder" first, then "deberta.encoder" + let encoder = match DebertaV2Encoder::load(vb.pp("encoder"), config) { + Ok(encoder) => encoder, + Err(_) => DebertaV2Encoder::load(vb.pp("deberta.encoder"), config)?, + }; Ok(Self { embeddings, encoder, classifier, + pool, device: vb.device().clone(), + dtype: vb.dtype(), + span: tracing::span!(tracing::Level::TRACE, "model"), }) } - pub fn forward( - &self, - input_ids: &Tensor, - token_type_ids: Option, - attention_mask: Option, - ) -> Result { - let input_ids_shape = input_ids.shape(); + pub fn forward(&self, batch: Batch) -> Result<(Option, Option)> { + let _enter = self.span.enter(); - let attention_mask = match attention_mask { - Some(mask) => mask, - None => Tensor::ones(input_ids_shape, DType::I64, &self.device)?, - }; + let batch_size = batch.len(); + let max_length = batch.max_length as usize; + + let shape = (batch_size, max_length); + + let (input_ids, token_type_ids, position_ids, input_lengths, attention_mask) = if batch_size > 1 { + // Prepare padded batch + let elems = batch_size * max_length; + + let mut input_ids = Vec::with_capacity(elems); + let mut token_type_ids = Vec::with_capacity(elems); + let mut position_ids = Vec::with_capacity(elems); + let mut attention_mask = Vec::with_capacity(elems); + let mut input_lengths = Vec::with_capacity(batch_size); + // Bool to know if we need to use the attention mask + let mut masking = false; + + for i in 0..batch_size { + let start = batch.cumulative_seq_lengths[i] as usize; + let end = batch.cumulative_seq_lengths[i + 1] as usize; + let seq_length = (end - start) as u32; + input_lengths.push(seq_length as f32); + + // Copy values + for j in start..end { + input_ids.push(batch.input_ids[j]); + token_type_ids.push(batch.token_type_ids[j]); + position_ids.push(batch.position_ids[j]); + attention_mask.push(1.0_f32); + } - let token_type_ids = match token_type_ids { - Some(ids) => ids, - None => Tensor::zeros(input_ids_shape, DType::U32, &self.device)?, + // Add padding if needed + let padding = batch.max_length - seq_length; + if padding > 0 { + // Set bool to use attention mask + masking = true; + for _ in 0..padding { + input_ids.push(0); + token_type_ids.push(0); + position_ids.push(0); + attention_mask.push(0.0_f32); + } + } + } + + let attention_mask = match masking { + true => { + let attention_mask = Tensor::from_vec( + attention_mask, + (batch_size, max_length, 1), + &self.device, + )? 
+ .to_dtype(self.dtype)?; + + Some(attention_mask) + } + false => None, + }; + + (input_ids, token_type_ids, position_ids, input_lengths, attention_mask) + } else { + ( + batch.input_ids, + batch.token_type_ids, + batch.position_ids, + vec![batch.max_length as f32], + None, + ) }; + let input_ids = Tensor::from_vec(input_ids, shape, &self.device)?; + let token_type_ids = Tensor::from_vec(token_type_ids, shape, &self.device)?; + let position_ids = Tensor::from_vec(position_ids, shape, &self.device)?; + let mut input_lengths = + Tensor::from_vec(input_lengths, (batch_size, 1), &self.device)?.to_dtype(self.dtype)?; + let embedding_output = self.embeddings.forward( - Some(input_ids), + Some(&input_ids), Some(&token_type_ids), - None, - Some(&attention_mask), + Some(&position_ids), + attention_mask.as_ref(), None, )?; - let encoder_output = - self.encoder - .forward(&embedding_output, &attention_mask, None, None)?; + let encoder_attention_mask = attention_mask.as_ref().cloned().unwrap_or_else(|| Tensor::ones(shape, DType::I64, &self.device).unwrap()); + let encoder_output = self.encoder.forward(&embedding_output, &encoder_attention_mask, None, None)?; + + let has_pooling_requests = !batch.pooled_indices.is_empty(); + let has_raw_requests = !batch.raw_indices.is_empty(); + + let pooled_embeddings = if has_pooling_requests { + let pooled_indices_length = batch.pooled_indices.len(); + let mut outputs = encoder_output.clone(); + + // Only use pooled_indices if at least one member of the batch ask for raw embeddings + let pooled_indices = if has_raw_requests { + let pooled_indices = + Tensor::from_vec(batch.pooled_indices, pooled_indices_length, &self.device)?; + + // Select values in the batch + outputs = outputs.index_select(&pooled_indices, 0)?; + Some(pooled_indices) + } else { + None + }; + + let pooled_embeddings = match self.pool { + // CLS pooling + Some(Pool::Cls) => outputs.i((.., 0))?, + // Last token pooling is not supported for this model + Some(Pool::LastToken) => unreachable!(), + // Mean pooling + Some(Pool::Mean) => { + if let Some(ref attention_mask) = attention_mask { + let mut attention_mask = attention_mask.clone(); + + if let Some(pooled_indices) = pooled_indices { + // Select values in the batch + attention_mask = attention_mask.index_select(&pooled_indices, 0)?; + input_lengths = input_lengths.index_select(&pooled_indices, 0)?; + }; + + // Mask padded values + outputs = outputs.broadcast_mul(&attention_mask)?; + } + + (outputs.sum(1)?.broadcast_div(&input_lengths))? 
+ } + Some(Pool::Splade) => unreachable!(), + None => outputs, + }; + Some(pooled_embeddings) + } else { + None + }; + + let raw_embeddings = if has_raw_requests { + // Reshape outputs + let (b, l, h) = encoder_output.shape().dims3()?; + let outputs = encoder_output.reshape((b * l, h))?; + + // We need to remove the padding tokens only if batch_size > 1 and there are some + // member of the batch that require pooling + // or if batch_size > 1 and the members of the batch have different lengths + if (attention_mask.is_some() || has_pooling_requests) && batch_size > 1 { + let mut final_indices: Vec = Vec::with_capacity(batch_size * max_length); + + for i in batch.raw_indices.into_iter() { + let start = i * batch.max_length; + let i = i as usize; + let length = + batch.cumulative_seq_lengths[i + 1] - batch.cumulative_seq_lengths[i]; + + for j in start..start + length { + // Add indices for the tokens of this specific member of the batch + final_indices.push(j); + } + } + + let final_indices_length = final_indices.len(); + let final_indices = + Tensor::from_vec(final_indices, final_indices_length, &self.device)?; + + // Select the tokens with final indices + Some(outputs.index_select(&final_indices, 0)?) + } else { + Some(outputs) + } + } else { + None + }; - Ok(encoder_output) + Ok((pooled_embeddings, raw_embeddings)) } } @@ -1162,7 +1323,12 @@ impl DebertaV2SeqClassificationHead { }; let pooler = DebertaV2ContextPooler::load(vb.clone(), config)?; let output_dim = pooler.output_dim()?; - let classifier = candle_nn::linear(output_dim, id2label_len, vb.root().pp("classifier"))?; + + // Try loading classifier from "classifier" first, then "deberta.classifier" + let classifier = match candle_nn::linear(output_dim, id2label_len, vb.root().pp("classifier")) { + Ok(classifier) => classifier, + Err(_) => candle_nn::linear(output_dim, id2label_len, vb.root().pp("deberta.classifier"))?, + }; Ok(Self { device: vb.device().clone(), @@ -1300,3 +1466,25 @@ pub(crate) fn make_log_bucket_position( abs_pos_lte_mid.where_cond(&relative_pos.to_dtype(DType::F32)?, &log_pos_mul_sign)? }) } + +impl Model for DebertaV2Model { + fn is_padded(&self) -> bool { + true + } + + fn embed(&self, batch: Batch) -> Result<(Option, Option)> { + self.forward(batch) + } + + fn predict(&self, batch: Batch) -> Result { + match &self.classifier { + None => candle::bail!("`predict` is not implemented for this model"), + Some(classifier) => { + let (_pooled_embeddings, raw_embeddings) = self.forward(batch)?; + let raw_embeddings = + raw_embeddings.expect("raw_embeddings is empty. This is a bug."); + classifier.forward(&raw_embeddings) + } + } + } +} \ No newline at end of file From 940b23961294ea684d75c313544ac9df431315ac Mon Sep 17 00:00:00 2001 From: Vinay Damodaran Date: Tue, 15 Jul 2025 16:32:26 -0700 Subject: [PATCH 05/11] IT WORKS!! but not for batches :( --- backends/candle/src/models/debertav2.rs | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/backends/candle/src/models/debertav2.rs b/backends/candle/src/models/debertav2.rs index 9727872b..60332053 100644 --- a/backends/candle/src/models/debertav2.rs +++ b/backends/candle/src/models/debertav2.rs @@ -1480,10 +1480,10 @@ impl Model for DebertaV2Model { match &self.classifier { None => candle::bail!("`predict` is not implemented for this model"), Some(classifier) => { - let (_pooled_embeddings, raw_embeddings) = self.forward(batch)?; - let raw_embeddings = - raw_embeddings.expect("raw_embeddings is empty. 
This is a bug."); - classifier.forward(&raw_embeddings) + let (pooled_embeddings, _raw_embeddings) = self.forward(batch)?; + let pooled_embeddings = + pooled_embeddings.expect("pooled_embeddings is empty. This is a bug."); + classifier.forward(&pooled_embeddings) } } } From 537f83a604001528211b2571566a3bd54e83d819 Mon Sep 17 00:00:00 2001 From: Vinay Damodaran Date: Wed, 29 Oct 2025 20:24:43 -0700 Subject: [PATCH 06/11] interim review changes --- backends/candle/src/models/debertav2.rs | 119 +++++++++++------------- 1 file changed, 54 insertions(+), 65 deletions(-) diff --git a/backends/candle/src/models/debertav2.rs b/backends/candle/src/models/debertav2.rs index 60332053..2c683787 100644 --- a/backends/candle/src/models/debertav2.rs +++ b/backends/candle/src/models/debertav2.rs @@ -3,7 +3,7 @@ use std::collections::HashMap; use candle::{bail, Context, DType, Device, IndexOp, Module, Result, Tensor, D}; use crate::models::Model; use candle_nn::{ - conv1d, embedding, layer_norm, Conv1d, Conv1dConfig, Embedding, LayerNorm, VarBuilder, + conv1d, embedding, layer_norm, Conv1d, Conv1dConfig, Embedding, LayerNorm, Linear, VarBuilder, }; use serde::{Deserialize, Deserializer}; use text_embeddings_backend_core::{Batch, ModelType, Pool}; @@ -59,8 +59,6 @@ pub struct DebertaV2Config { pub num_attention_heads: usize, pub intermediate_size: usize, pub hidden_act: HiddenAct, - pub hidden_dropout_prob: f64, - pub attention_probs_dropout_prob: f64, pub max_position_embeddings: usize, pub type_vocab_size: usize, pub initializer_range: f64, @@ -81,10 +79,8 @@ pub struct DebertaV2Config { pub conv_act: Option, pub id2label: Option, pub label2id: Option, - pub pooler_dropout: Option, pub pooler_hidden_act: Option, pub pooler_hidden_size: Option, - pub cls_dropout: Option, } fn deserialize_pos_att_type<'de, D>(deserializer: D) -> std::result::Result, D::Error> @@ -111,10 +107,10 @@ pub struct DebertaV2Embeddings { position_embeddings: Option, token_type_embeddings: Option, layer_norm: LayerNorm, - position_ids: Tensor, config: DebertaV2Config, embedding_size: usize, - embed_proj: Option, + embed_proj: Option, + span: tracing::Span, } impl DebertaV2Embeddings { @@ -147,7 +143,7 @@ impl DebertaV2Embeddings { None }; - let embed_proj: Option = if embedding_size != config.hidden_size { + let embed_proj: Option = if embedding_size != config.hidden_size { Some(candle_nn::linear_no_bias( embedding_size, config.hidden_size, @@ -163,58 +159,28 @@ impl DebertaV2Embeddings { vb.pp("LayerNorm"), )?; - let position_ids = - Tensor::arange(0, config.max_position_embeddings as u32, &device)?.unsqueeze(0)?; - Ok(Self { word_embeddings, position_embeddings, token_type_embeddings, layer_norm, - position_ids, device, config, embedding_size, embed_proj, + span: tracing::span!(tracing::Level::TRACE, "embeddings"), }) } pub fn forward( &self, - input_ids: Option<&Tensor>, - token_type_ids: Option<&Tensor>, - position_ids: Option<&Tensor>, + input_ids: &Tensor, + token_type_ids: &Tensor, + position_ids: &Tensor, mask: Option<&Tensor>, - inputs_embeds: Option<&Tensor>, ) -> Result { - let (input_shape, input_embeds) = match (input_ids, inputs_embeds) { - (Some(ids), None) => { - let embs = self.word_embeddings.forward(ids)?; - (ids.dims(), embs) - } - (None, Some(e)) => (e.dims(), e.clone()), - (None, None) => { - bail!("Must specify either input_ids or inputs_embeds") - } - (Some(_), Some(_)) => { - bail!("Can't specify both input_ids and inputs_embeds") - } - }; - - let seq_length = match input_shape.last() { - Some(v) => *v, - 
None => bail!("DebertaV2Embeddings invalid input shape"), - }; - - let position_ids = match position_ids { - Some(v) => v.clone(), - None => self.position_ids.narrow(1, 0, seq_length)?, - }; - - let token_type_ids = match token_type_ids { - Some(ids) => ids.clone(), - None => Tensor::zeros(input_shape, DType::U32, &self.device)?, - }; + let _enter = self.span.enter(); + let input_embeds = self.word_embeddings.forward(input_ids)?; let position_embeddings = match &self.position_embeddings { Some(emb) => emb.forward(&position_ids)?, @@ -291,17 +257,18 @@ impl XSoftmax { pub struct DebertaV2DisentangledSelfAttention { config: DebertaV2Config, num_attention_heads: usize, - query_proj: candle_nn::Linear, - key_proj: candle_nn::Linear, - value_proj: candle_nn::Linear, + query_proj: Linear, + key_proj: Linear, + value_proj: Linear, device: Device, relative_attention: bool, position_buckets: isize, max_relative_positions: isize, pos_ebd_size: isize, share_att_key: bool, - pos_key_proj: Option, - pos_query_proj: Option, + pos_key_proj: Option, + pos_query_proj: Option, + span: tracing::Span, } impl DebertaV2DisentangledSelfAttention { @@ -334,8 +301,8 @@ impl DebertaV2DisentangledSelfAttention { let mut pos_ebd_size: isize = 0; let position_buckets = config.position_buckets.unwrap_or(-1); - let mut pos_key_proj: Option = None; - let mut pos_query_proj: Option = None; + let mut pos_key_proj: Option = None; + let mut pos_query_proj: Option = None; if relative_attention { if max_relative_positions < 1 { @@ -380,6 +347,7 @@ impl DebertaV2DisentangledSelfAttention { share_att_key, pos_key_proj, pos_query_proj, + span: tracing::span!(tracing::Level::TRACE, "attention"), }) } @@ -391,6 +359,7 @@ impl DebertaV2DisentangledSelfAttention { relative_pos: Option<&Tensor>, rel_embeddings: Option<&Tensor>, ) -> Result { + let _enter = self.span.enter(); let query_states = match query_states { Some(qs) => qs, None => hidden_states, @@ -669,13 +638,14 @@ impl DebertaV2DisentangledSelfAttention { pub struct DebertaV2Attention { dsa: DebertaV2DisentangledSelfAttention, output: DebertaV2SelfOutput, + span: tracing::Span, } impl DebertaV2Attention { pub fn load(vb: VarBuilder, config: &DebertaV2Config) -> Result { let dsa = DebertaV2DisentangledSelfAttention::load(vb.pp("attention.self"), config)?; let output = DebertaV2SelfOutput::load(vb.pp("attention.output"), config)?; - Ok(Self { dsa, output }) + Ok(Self { dsa, output, span: tracing::span!(tracing::Level::TRACE, "attention") }) } fn forward( @@ -686,6 +656,7 @@ impl DebertaV2Attention { relative_pos: Option<&Tensor>, rel_embeddings: Option<&Tensor>, ) -> Result { + let _enter = self.span.enter(); let self_output = self.dsa.forward( hidden_states, attention_mask, @@ -701,8 +672,9 @@ impl DebertaV2Attention { // https://github.com/huggingface/transformers/blob/78b2929c0554b79e0489b451ce4ece14d265ead2/src/transformers/models/deberta_v2/modeling_deberta_v2.py#L255 pub struct DebertaV2SelfOutput { - dense: candle_nn::Linear, + dense: Linear, layer_norm: LayerNorm, + span: tracing::Span, } impl DebertaV2SelfOutput { @@ -713,10 +685,11 @@ impl DebertaV2SelfOutput { config.layer_norm_eps, vb.pp("LayerNorm"), )?; - Ok(Self { dense, layer_norm }) + Ok(Self { dense, layer_norm, span: tracing::span!(tracing::Level::TRACE, "self-output") }) } pub fn forward(&self, hidden_states: &Tensor, input_tensor: &Tensor) -> Result { + let _enter = self.span.enter(); let hidden_states = self.dense.forward(hidden_states)?; self.layer_norm 
.forward(&hidden_states.broadcast_add(input_tensor)?) @@ -725,8 +698,9 @@ impl DebertaV2SelfOutput { // https://github.com/huggingface/transformers/blob/78b2929c0554b79e0489b451ce4ece14d265ead2/src/transformers/models/deberta_v2/modeling_deberta_v2.py#L307 pub struct DebertaV2Intermediate { - dense: candle_nn::Linear, + dense: Linear, intermediate_act: HiddenActLayer, + span: tracing::Span, } impl DebertaV2Intermediate { @@ -740,10 +714,12 @@ impl DebertaV2Intermediate { Ok(Self { dense, intermediate_act, + span: tracing::span!(tracing::Level::TRACE, "intermediate"), }) } pub fn forward(&self, hidden_states: &Tensor) -> Result { + let _enter = self.span.enter(); self.intermediate_act .forward(&self.dense.forward(hidden_states)?) } @@ -751,8 +727,9 @@ impl DebertaV2Intermediate { // https://github.com/huggingface/transformers/blob/78b2929c0554b79e0489b451ce4ece14d265ead2/src/transformers/models/deberta_v2/modeling_deberta_v2.py#L323 pub struct DebertaV2Output { - dense: candle_nn::Linear, + dense: Linear, layer_norm: LayerNorm, + span: tracing::Span, } impl DebertaV2Output { @@ -767,10 +744,11 @@ impl DebertaV2Output { config.layer_norm_eps, vb.pp("output.LayerNorm"), )?; - Ok(Self { dense, layer_norm }) + Ok(Self { dense, layer_norm, span: tracing::span!(tracing::Level::TRACE, "output") }) } pub fn forward(&self, hidden_states: &Tensor, input_tensor: &Tensor) -> Result { + let _enter = self.span.enter(); let mut hidden_states = self.dense.forward(hidden_states)?; hidden_states = { let to_norm = hidden_states.broadcast_add(input_tensor)?; @@ -785,6 +763,7 @@ pub struct DebertaV2Layer { attention: DebertaV2Attention, intermediate: DebertaV2Intermediate, output: DebertaV2Output, + span: tracing::Span, } impl DebertaV2Layer { @@ -796,6 +775,7 @@ impl DebertaV2Layer { attention, intermediate, output, + span: tracing::span!(tracing::Level::TRACE, "layer"), }) } @@ -807,6 +787,7 @@ impl DebertaV2Layer { relative_pos: Option<&Tensor>, rel_embeddings: Option<&Tensor>, ) -> Result { + let _enter = self.span.enter(); let attention_output = self.attention.forward( hidden_states, attention_mask, @@ -890,6 +871,7 @@ pub struct DebertaV2Encoder { layer_norm: Option, conv: Option, device: Device, + span: tracing::Span, } impl DebertaV2Encoder { @@ -956,6 +938,7 @@ impl DebertaV2Encoder { layer_norm, conv, device: vb.device().clone(), + span: tracing::span!(tracing::Level::TRACE, "encoder"), }) } @@ -966,6 +949,7 @@ impl DebertaV2Encoder { query_states: Option<&Tensor>, relative_pos: Option<&Tensor>, ) -> Result { + let _enter = self.span.enter(); let input_mask = if attention_mask.dims().len() <= 2 { attention_mask.clone() } else { @@ -1210,14 +1194,13 @@ impl DebertaV2Model { Tensor::from_vec(input_lengths, (batch_size, 1), &self.device)?.to_dtype(self.dtype)?; let embedding_output = self.embeddings.forward( - Some(&input_ids), - Some(&token_type_ids), - Some(&position_ids), + &input_ids, + &token_type_ids, + &position_ids, attention_mask.as_ref(), - None, )?; - let encoder_attention_mask = attention_mask.as_ref().cloned().unwrap_or_else(|| Tensor::ones(shape, DType::I64, &self.device).unwrap()); + let encoder_attention_mask = attention_mask.as_ref().cloned().unwrap_or_else(|| Tensor::ones(shape, self.dtype, &self.device).unwrap()); let encoder_output = self.encoder.forward(&embedding_output, &encoder_attention_mask, None, None)?; let has_pooling_requests = !batch.pooled_indices.is_empty(); @@ -1312,7 +1295,8 @@ impl DebertaV2Model { pub struct DebertaV2SeqClassificationHead { pub device: Device, pooler: 
DebertaV2ContextPooler, - classifier: candle_nn::Linear, + classifier: Linear, + span: tracing::Span, } impl DebertaV2SeqClassificationHead { @@ -1323,7 +1307,7 @@ impl DebertaV2SeqClassificationHead { }; let pooler = DebertaV2ContextPooler::load(vb.clone(), config)?; let output_dim = pooler.output_dim()?; - + // Try loading classifier from "classifier" first, then "deberta.classifier" let classifier = match candle_nn::linear(output_dim, id2label_len, vb.root().pp("classifier")) { Ok(classifier) => classifier, @@ -1334,18 +1318,21 @@ impl DebertaV2SeqClassificationHead { device: vb.device().clone(), pooler, classifier, + span: tracing::span!(tracing::Level::TRACE, "classifier"), }) } pub fn forward(&self, hidden_states: &Tensor) -> Result { + let _enter = self.span.enter(); let pooled_output = self.pooler.forward(&hidden_states)?; self.classifier.forward(&pooled_output) } } pub struct DebertaV2ContextPooler { - dense: candle_nn::Linear, + dense: Linear, config: DebertaV2Config, + span: tracing::Span, } // https://github.com/huggingface/transformers/blob/78b2929c0554b79e0489b451ce4ece14d265ead2/src/transformers/models/deberta_v2/modeling_deberta_v2.py#L49 @@ -1364,10 +1351,12 @@ impl DebertaV2ContextPooler { Ok(Self { dense, config: config.clone(), + span: tracing::span!(tracing::Level::TRACE, "pooler"), }) } pub fn forward(&self, hidden_states: &Tensor) -> Result { + let _enter = self.span.enter(); let context_token = hidden_states.narrow(1, 0, 1)?.squeeze(1)?; let pooled_output = self.dense.forward(&context_token.contiguous()?)?; From f38caeaa92d2505261bb6c4c41c49441593bc06d Mon Sep 17 00:00:00 2001 From: Vinay Damodaran Date: Wed, 29 Oct 2025 21:36:05 -0700 Subject: [PATCH 07/11] interim changes again --- backends/candle/src/layers/linear.rs | 3 + backends/candle/src/models/debertav2.rs | 95 ++++++++++++++----------- 2 files changed, 56 insertions(+), 42 deletions(-) diff --git a/backends/candle/src/layers/linear.rs b/backends/candle/src/layers/linear.rs index e15ca8e8..820b0411 100644 --- a/backends/candle/src/layers/linear.rs +++ b/backends/candle/src/layers/linear.rs @@ -7,6 +7,7 @@ use serde::Deserialize; pub enum HiddenAct { #[serde(alias = "gelu_pytorch_tanh")] Gelu, + GeluExact, Relu, Silu, Swiglu, @@ -16,6 +17,7 @@ impl HiddenAct { pub fn forward(&self, x: &Tensor) -> Result { match self { Self::Gelu => x.gelu(), + Self::GeluExact => x.gelu_erf(), Self::Relu => x.relu(), Self::Silu => x.silu(), Self::Swiglu => candle_nn::ops::swiglu(x), @@ -85,6 +87,7 @@ impl Linear { if let Some(act) = &self.act { match act { HiddenAct::Gelu => x.gelu(), + HiddenAct::GeluExact => x.gelu_erf(), HiddenAct::Relu => x.relu(), HiddenAct::Silu => x.silu(), HiddenAct::Swiglu => candle_nn::ops::swiglu(&x), diff --git a/backends/candle/src/models/debertav2.rs b/backends/candle/src/models/debertav2.rs index 2c683787..3974ff59 100644 --- a/backends/candle/src/models/debertav2.rs +++ b/backends/candle/src/models/debertav2.rs @@ -1,6 +1,7 @@ use std::collections::HashMap; use candle::{bail, Context, DType, Device, IndexOp, Module, Result, Tensor, D}; +use crate::layers::HiddenAct; use crate::models::Model; use candle_nn::{ conv1d, embedding, layer_norm, Conv1d, Conv1dConfig, Embedding, LayerNorm, Linear, VarBuilder, @@ -10,37 +11,6 @@ use text_embeddings_backend_core::{Batch, ModelType, Pool}; pub const DTYPE: DType = DType::F32; -// NOTE: HiddenAct and HiddenActLayer are both direct copies from bert.rs. 
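// Editor's aside, not part of the patch: this deletion (continued below) drops the
// local HiddenAct copy in favour of the shared crate::layers::HiddenAct, which gains a
// GeluExact variant because DeBERTa checkpoints set `hidden_act = "gelu"` meaning the
// exact erf-based GELU, while the shared `Gelu` variant maps to candle's tanh
// approximation; the config deserializer added later in this patch routes "gelu" to
// GeluExact accordingly. A minimal, self-contained sketch of the difference using
// candle ops that already appear in this file (the sample values are made up):

use candle::{Device, Result, Tensor};

fn main() -> Result<()> {
    let x = Tensor::new(&[-1.0_f32, 0.0, 1.0, 2.0], &Device::Cpu)?;
    let exact = x.gelu_erf()?; // x * 0.5 * (1 + erf(x / sqrt(2)))
    let approx = x.gelu()?; // 0.5 * x * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x^3)))
    println!("exact GELU:       {exact}");
    println!("tanh approx GELU: {approx}");
    Ok(())
}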
-#[derive(Debug, Clone, Copy, PartialEq, Eq, Deserialize)] -#[serde(rename_all = "lowercase")] -pub enum HiddenAct { - Gelu, - GeluApproximate, - Relu, -} - -pub struct HiddenActLayer { - act: HiddenAct, - span: tracing::Span, -} - -impl HiddenActLayer { - fn new(act: HiddenAct) -> Self { - let span = tracing::span!(tracing::Level::TRACE, "hidden-act"); - Self { act, span } - } - - fn forward(&self, xs: &Tensor) -> Result { - let _enter = self.span.enter(); - match self.act { - // https://github.com/huggingface/transformers/blob/cd4584e3c809bb9e1392ccd3fe38b40daba5519a/src/transformers/activations.py#L213 - HiddenAct::Gelu => xs.gelu_erf(), - HiddenAct::GeluApproximate => xs.gelu(), - HiddenAct::Relu => xs.relu(), - } - } -} - #[derive(Debug, Clone, Copy, PartialEq, Eq, Deserialize, Default)] #[serde(rename_all = "lowercase")] enum PositionEmbeddingType { @@ -58,6 +28,7 @@ pub struct DebertaV2Config { pub num_hidden_layers: usize, pub num_attention_heads: usize, pub intermediate_size: usize, + #[serde(deserialize_with = "deberta_hidden_act_deserializer::deserialize")] pub hidden_act: HiddenAct, pub max_position_embeddings: usize, pub type_vocab_size: usize, @@ -79,6 +50,7 @@ pub struct DebertaV2Config { pub conv_act: Option, pub id2label: Option, pub label2id: Option, + #[serde(deserialize_with = "deberta_hidden_act_deserializer::deserialize_optional")] pub pooler_hidden_act: Option, pub pooler_hidden_size: Option, } @@ -100,6 +72,36 @@ where } } +// Custom deserializer for DeBERTa hidden_act: maps "gelu" to GeluExact (exact GELU) +mod deberta_hidden_act_deserializer { + use super::*; + + fn parse(s: &str) -> std::result::Result { + match s.to_lowercase().as_str() { + "gelu" => Ok(HiddenAct::GeluExact), + "relu" => Ok(HiddenAct::Relu), + _ => Err(format!("Unknown hidden_act: {}", s)), + } + } + + pub fn deserialize<'de, D>(deserializer: D) -> std::result::Result + where + D: Deserializer<'de>, + { + let s = String::deserialize(deserializer)?; + parse(&s).map_err(serde::de::Error::custom) + } + + pub fn deserialize_optional<'de, D>(deserializer: D) -> std::result::Result, D::Error> + where + D: Deserializer<'de>, + { + Option::::deserialize(deserializer)? 
+ .map(|s| parse(&s).map_err(serde::de::Error::custom)) + .transpose() + } +} + // https://github.com/huggingface/transformers/blob/78b2929c0554b79e0489b451ce4ece14d265ead2/src/transformers/models/deberta_v2/modeling_deberta_v2.py#L823 pub struct DebertaV2Embeddings { device: Device, @@ -268,6 +270,8 @@ pub struct DebertaV2DisentangledSelfAttention { share_att_key: bool, pos_key_proj: Option, pos_query_proj: Option, + is_c2p_attn: bool, + is_p2c_attn: bool, span: tracing::Span, } @@ -299,6 +303,10 @@ impl DebertaV2DisentangledSelfAttention { let relative_attention = config.relative_attention; let mut max_relative_positions = config.max_relative_positions; + // Precompute attention type checks + let is_c2p_attn = config.pos_att_type.iter().any(|s| s == "c2p"); + let is_p2c_attn = config.pos_att_type.iter().any(|s| s == "p2c"); + let mut pos_ebd_size: isize = 0; let position_buckets = config.position_buckets.unwrap_or(-1); let mut pos_key_proj: Option = None; @@ -314,14 +322,14 @@ impl DebertaV2DisentangledSelfAttention { } if !share_att_key { - if config.pos_att_type.iter().any(|s| s == "c2p") { + if is_c2p_attn { pos_key_proj = Some(candle_nn::linear( config.hidden_size, all_head_size, vb.pp("pos_key_proj"), )?); } - if config.pos_att_type.iter().any(|s| s == "p2c") { + if is_p2c_attn { pos_query_proj = Some(candle_nn::linear( config.hidden_size, all_head_size, @@ -347,6 +355,8 @@ impl DebertaV2DisentangledSelfAttention { share_att_key, pos_key_proj, pos_query_proj, + is_c2p_attn, + is_p2c_attn, span: tracing::span!(tracing::Level::TRACE, "attention"), }) } @@ -371,11 +381,11 @@ impl DebertaV2DisentangledSelfAttention { let mut scale_factor: usize = 1; - if self.config.pos_att_type.iter().any(|s| s == "c2p") { + if self.is_c2p_attn { scale_factor += 1; } - if self.config.pos_att_type.iter().any(|s| s == "p2c") { + if self.is_p2c_attn { scale_factor += 1; } @@ -523,7 +533,7 @@ impl DebertaV2DisentangledSelfAttention { .repeat(repeat_with)?, ) } else { - if self.config.pos_att_type.iter().any(|s| s == "c2p") { + if self.is_c2p_attn { pos_key_layer = Some( self.transpose_for_scores( &self @@ -537,7 +547,7 @@ impl DebertaV2DisentangledSelfAttention { .repeat(repeat_with)?, ) } - if self.config.pos_att_type.iter().any(|s| s == "p2c") { + if self.is_p2c_attn { pos_query_layer = Some(self.transpose_for_scores(&self .pos_query_proj .as_ref() @@ -548,7 +558,7 @@ impl DebertaV2DisentangledSelfAttention { let mut score = Tensor::new(&[0 as f32], &self.device)?; - if self.config.pos_att_type.iter().any(|s| s == "c2p") { + if self.is_c2p_attn { let pos_key_layer = pos_key_layer.context("c2p without pos_key_layer")?; let scale = Tensor::new( @@ -580,7 +590,7 @@ impl DebertaV2DisentangledSelfAttention { )?; } - if self.config.pos_att_type.iter().any(|s| s == "p2c") { + if self.is_p2c_attn { let pos_query_layer = pos_query_layer.context("p2c without pos_key_layer")?; let scale = Tensor::new( @@ -699,7 +709,7 @@ impl DebertaV2SelfOutput { // https://github.com/huggingface/transformers/blob/78b2929c0554b79e0489b451ce4ece14d265ead2/src/transformers/models/deberta_v2/modeling_deberta_v2.py#L307 pub struct DebertaV2Intermediate { dense: Linear, - intermediate_act: HiddenActLayer, + intermediate_act: HiddenAct, span: tracing::Span, } @@ -710,7 +720,7 @@ impl DebertaV2Intermediate { config.intermediate_size, vb.pp("intermediate.dense"), )?; - let intermediate_act = HiddenActLayer::new(config.hidden_act); + let intermediate_act = config.hidden_act.clone(); Ok(Self { dense, intermediate_act, @@ -1363,9 +1373,10 @@ 
impl DebertaV2ContextPooler { let pooler_hidden_act = self .config .pooler_hidden_act + .clone() .context("Could not obtain pooler hidden act from config")?; - HiddenActLayer::new(pooler_hidden_act).forward(&pooled_output) + pooler_hidden_act.forward(&pooled_output) } pub fn output_dim(&self) -> Result { From 64b4b3ab65caacdb7347624581ef2b7f866af9ad Mon Sep 17 00:00:00 2001 From: Vinay Damodaran Date: Thu, 30 Oct 2025 10:40:16 -0700 Subject: [PATCH 08/11] fix linting --- backends/candle/src/lib.rs | 11 +- backends/candle/src/models/debertav2.rs | 192 +++++++++++++----------- 2 files changed, 111 insertions(+), 92 deletions(-) diff --git a/backends/candle/src/lib.rs b/backends/candle/src/lib.rs index e9618d89..27bc3468 100644 --- a/backends/candle/src/lib.rs +++ b/backends/candle/src/lib.rs @@ -22,10 +22,11 @@ use crate::compute_cap::{ compatible_compute_cap, get_compile_compute_cap, get_runtime_compute_cap, }; use crate::models::{ - BertConfig, BertModel, Dense, DenseConfig, DenseLayer, DebertaV2Config, DebertaV2Model, DistilBertConfig, DistilBertModel, - GTEConfig, GTEModel, Gemma3Config, Gemma3Model, JinaBertModel, JinaCodeBertModel, MPNetConfig, - MPNetModel, MistralConfig, Model, ModernBertConfig, ModernBertModel, NomicBertModel, - NomicConfig, Qwen2Config, Qwen3Config, Qwen3Model, + BertConfig, BertModel, DebertaV2Config, DebertaV2Model, Dense, DenseConfig, DenseLayer, + DistilBertConfig, DistilBertModel, GTEConfig, GTEModel, Gemma3Config, Gemma3Model, + JinaBertModel, JinaCodeBertModel, MPNetConfig, MPNetModel, MistralConfig, Model, + ModernBertConfig, ModernBertModel, NomicBertModel, NomicConfig, Qwen2Config, Qwen3Config, + Qwen3Model, }; #[cfg(feature = "cuda")] use crate::models::{ @@ -255,7 +256,7 @@ impl CandleBackend { (Config::DebertaV2(config), Device::Cpu | Device::Metal(_)) => { tracing::info!("Starting DebertaV2 model on {:?}", device); Ok(Box::new(DebertaV2Model::load(vb, &config, model_type).s()?)) - }, + } ( Config::Camembert(config) | Config::Roberta(config) | Config::XlmRoberta(config), Device::Cpu | Device::Metal(_), diff --git a/backends/candle/src/models/debertav2.rs b/backends/candle/src/models/debertav2.rs index 3974ff59..3061fb16 100644 --- a/backends/candle/src/models/debertav2.rs +++ b/backends/candle/src/models/debertav2.rs @@ -1,16 +1,14 @@ use std::collections::HashMap; -use candle::{bail, Context, DType, Device, IndexOp, Module, Result, Tensor, D}; use crate::layers::HiddenAct; use crate::models::Model; +use candle::{bail, Context, DType, Device, IndexOp, Module, Result, Tensor, D}; use candle_nn::{ conv1d, embedding, layer_norm, Conv1d, Conv1dConfig, Embedding, LayerNorm, Linear, VarBuilder, }; use serde::{Deserialize, Deserializer}; use text_embeddings_backend_core::{Batch, ModelType, Pool}; -pub const DTYPE: DType = DType::F32; - #[derive(Debug, Clone, Copy, PartialEq, Eq, Deserialize, Default)] #[serde(rename_all = "lowercase")] enum PositionEmbeddingType { @@ -92,7 +90,9 @@ mod deberta_hidden_act_deserializer { parse(&s).map_err(serde::de::Error::custom) } - pub fn deserialize_optional<'de, D>(deserializer: D) -> std::result::Result, D::Error> + pub fn deserialize_optional<'de, D>( + deserializer: D, + ) -> std::result::Result, D::Error> where D: Deserializer<'de>, { @@ -104,7 +104,6 @@ mod deberta_hidden_act_deserializer { // https://github.com/huggingface/transformers/blob/78b2929c0554b79e0489b451ce4ece14d265ead2/src/transformers/models/deberta_v2/modeling_deberta_v2.py#L823 pub struct DebertaV2Embeddings { - device: Device, word_embeddings: 
Embedding, position_embeddings: Option, token_type_embeddings: Option, @@ -117,7 +116,6 @@ pub struct DebertaV2Embeddings { impl DebertaV2Embeddings { pub fn load(vb: VarBuilder, config: &DebertaV2Config) -> Result { - let device = vb.device().clone(); let config = config.clone(); let embedding_size = config.embedding_size.unwrap_or(config.hidden_size); @@ -166,7 +164,6 @@ impl DebertaV2Embeddings { position_embeddings, token_type_embeddings, layer_norm, - device, config, embedding_size, embed_proj, @@ -185,7 +182,7 @@ impl DebertaV2Embeddings { let input_embeds = self.word_embeddings.forward(input_ids)?; let position_embeddings = match &self.position_embeddings { - Some(emb) => emb.forward(&position_ids)?, + Some(emb) => emb.forward(position_ids)?, None => Tensor::zeros_like(&input_embeds)?, }; @@ -199,7 +196,7 @@ impl DebertaV2Embeddings { embeddings = self.token_type_embeddings.as_ref().map_or_else( || bail!("token_type_embeddings must be set when type_vocab_size > 0"), |token_type_embeddings| { - embeddings.add(&token_type_embeddings.forward(&token_type_ids)?) + embeddings.add(&token_type_embeddings.forward(token_type_ids)?) }, )?; } @@ -257,7 +254,6 @@ impl XSoftmax { // https://github.com/huggingface/transformers/blob/78b2929c0554b79e0489b451ce4ece14d265ead2/src/transformers/models/deberta_v2/modeling_deberta_v2.py#L605 pub struct DebertaV2DisentangledSelfAttention { - config: DebertaV2Config, num_attention_heads: usize, query_proj: Linear, key_proj: Linear, @@ -342,7 +338,6 @@ impl DebertaV2DisentangledSelfAttention { let device = vb.device().clone(); Ok(Self { - config, num_attention_heads, query_proj, key_proj, @@ -655,7 +650,11 @@ impl DebertaV2Attention { pub fn load(vb: VarBuilder, config: &DebertaV2Config) -> Result { let dsa = DebertaV2DisentangledSelfAttention::load(vb.pp("attention.self"), config)?; let output = DebertaV2SelfOutput::load(vb.pp("attention.output"), config)?; - Ok(Self { dsa, output, span: tracing::span!(tracing::Level::TRACE, "attention") }) + Ok(Self { + dsa, + output, + span: tracing::span!(tracing::Level::TRACE, "attention"), + }) } fn forward( @@ -695,7 +694,11 @@ impl DebertaV2SelfOutput { config.layer_norm_eps, vb.pp("LayerNorm"), )?; - Ok(Self { dense, layer_norm, span: tracing::span!(tracing::Level::TRACE, "self-output") }) + Ok(Self { + dense, + layer_norm, + span: tracing::span!(tracing::Level::TRACE, "self-output"), + }) } pub fn forward(&self, hidden_states: &Tensor, input_tensor: &Tensor) -> Result { @@ -754,7 +757,11 @@ impl DebertaV2Output { config.layer_norm_eps, vb.pp("output.LayerNorm"), )?; - Ok(Self { dense, layer_norm, span: tracing::span!(tracing::Level::TRACE, "output") }) + Ok(Self { + dense, + layer_norm, + span: tracing::span!(tracing::Level::TRACE, "output"), + }) } pub fn forward(&self, hidden_states: &Tensor, input_tensor: &Tensor) -> Result { @@ -1096,17 +1103,15 @@ impl DebertaV2Model { let classifier = DebertaV2SeqClassificationHead::load(vb.clone(), config)?; (Some(classifier), None) } - ModelType::Embedding(pool) => { - (None, Some(pool)) - } + ModelType::Embedding(pool) => (None, Some(pool)), }; - + // Try loading embeddings from "embeddings" first, then "deberta.embeddings" let embeddings = match DebertaV2Embeddings::load(vb.pp("embeddings"), config) { Ok(embeddings) => embeddings, Err(_) => DebertaV2Embeddings::load(vb.pp("deberta.embeddings"), config)?, }; - + // Try loading encoder from "encoder" first, then "deberta.encoder" let encoder = match DebertaV2Encoder::load(vb.pp("encoder"), config) { Ok(encoder) => encoder, 
@@ -1132,71 +1137,78 @@ impl DebertaV2Model { let shape = (batch_size, max_length); - let (input_ids, token_type_ids, position_ids, input_lengths, attention_mask) = if batch_size > 1 { - // Prepare padded batch - let elems = batch_size * max_length; - - let mut input_ids = Vec::with_capacity(elems); - let mut token_type_ids = Vec::with_capacity(elems); - let mut position_ids = Vec::with_capacity(elems); - let mut attention_mask = Vec::with_capacity(elems); - let mut input_lengths = Vec::with_capacity(batch_size); - // Bool to know if we need to use the attention mask - let mut masking = false; - - for i in 0..batch_size { - let start = batch.cumulative_seq_lengths[i] as usize; - let end = batch.cumulative_seq_lengths[i + 1] as usize; - let seq_length = (end - start) as u32; - input_lengths.push(seq_length as f32); - - // Copy values - for j in start..end { - input_ids.push(batch.input_ids[j]); - token_type_ids.push(batch.token_type_ids[j]); - position_ids.push(batch.position_ids[j]); - attention_mask.push(1.0_f32); - } + let (input_ids, token_type_ids, position_ids, input_lengths, attention_mask) = + if batch_size > 1 { + // Prepare padded batch + let elems = batch_size * max_length; + + let mut input_ids = Vec::with_capacity(elems); + let mut token_type_ids = Vec::with_capacity(elems); + let mut position_ids = Vec::with_capacity(elems); + let mut attention_mask = Vec::with_capacity(elems); + let mut input_lengths = Vec::with_capacity(batch_size); + // Bool to know if we need to use the attention mask + let mut masking = false; + + for i in 0..batch_size { + let start = batch.cumulative_seq_lengths[i] as usize; + let end = batch.cumulative_seq_lengths[i + 1] as usize; + let seq_length = (end - start) as u32; + input_lengths.push(seq_length as f32); + + // Copy values + for j in start..end { + input_ids.push(batch.input_ids[j]); + token_type_ids.push(batch.token_type_ids[j]); + position_ids.push(batch.position_ids[j]); + attention_mask.push(1.0_f32); + } - // Add padding if needed - let padding = batch.max_length - seq_length; - if padding > 0 { - // Set bool to use attention mask - masking = true; - for _ in 0..padding { - input_ids.push(0); - token_type_ids.push(0); - position_ids.push(0); - attention_mask.push(0.0_f32); + // Add padding if needed + let padding = batch.max_length - seq_length; + if padding > 0 { + // Set bool to use attention mask + masking = true; + for _ in 0..padding { + input_ids.push(0); + token_type_ids.push(0); + position_ids.push(0); + attention_mask.push(0.0_f32); + } } } - } - let attention_mask = match masking { - true => { - let attention_mask = Tensor::from_vec( - attention_mask, - (batch_size, max_length, 1), - &self.device, - )? - .to_dtype(self.dtype)?; + let attention_mask = match masking { + true => { + let attention_mask = Tensor::from_vec( + attention_mask, + (batch_size, max_length, 1), + &self.device, + )? 
+ .to_dtype(self.dtype)?; - Some(attention_mask) - } - false => None, + Some(attention_mask) + } + false => None, + }; + + ( + input_ids, + token_type_ids, + position_ids, + input_lengths, + attention_mask, + ) + } else { + ( + batch.input_ids, + batch.token_type_ids, + batch.position_ids, + vec![batch.max_length as f32], + None, + ) }; - (input_ids, token_type_ids, position_ids, input_lengths, attention_mask) - } else { - ( - batch.input_ids, - batch.token_type_ids, - batch.position_ids, - vec![batch.max_length as f32], - None, - ) - }; - let input_ids = Tensor::from_vec(input_ids, shape, &self.device)?; let token_type_ids = Tensor::from_vec(token_type_ids, shape, &self.device)?; let position_ids = Tensor::from_vec(position_ids, shape, &self.device)?; @@ -1210,8 +1222,13 @@ impl DebertaV2Model { attention_mask.as_ref(), )?; - let encoder_attention_mask = attention_mask.as_ref().cloned().unwrap_or_else(|| Tensor::ones(shape, self.dtype, &self.device).unwrap()); - let encoder_output = self.encoder.forward(&embedding_output, &encoder_attention_mask, None, None)?; + let encoder_attention_mask = attention_mask + .as_ref() + .cloned() + .unwrap_or_else(|| Tensor::ones(shape, self.dtype, &self.device).unwrap()); + let encoder_output = + self.encoder + .forward(&embedding_output, &encoder_attention_mask, None, None)?; let has_pooling_requests = !batch.pooled_indices.is_empty(); let has_raw_requests = !batch.raw_indices.is_empty(); @@ -1303,7 +1320,6 @@ impl DebertaV2Model { } pub struct DebertaV2SeqClassificationHead { - pub device: Device, pooler: DebertaV2ContextPooler, classifier: Linear, span: tracing::Span, @@ -1319,13 +1335,15 @@ impl DebertaV2SeqClassificationHead { let output_dim = pooler.output_dim()?; // Try loading classifier from "classifier" first, then "deberta.classifier" - let classifier = match candle_nn::linear(output_dim, id2label_len, vb.root().pp("classifier")) { - Ok(classifier) => classifier, - Err(_) => candle_nn::linear(output_dim, id2label_len, vb.root().pp("deberta.classifier"))?, - }; + let classifier = + match candle_nn::linear(output_dim, id2label_len, vb.root().pp("classifier")) { + Ok(classifier) => classifier, + Err(_) => { + candle_nn::linear(output_dim, id2label_len, vb.root().pp("deberta.classifier"))? 
+ } + }; Ok(Self { - device: vb.device().clone(), pooler, classifier, span: tracing::span!(tracing::Level::TRACE, "classifier"), @@ -1334,7 +1352,7 @@ impl DebertaV2SeqClassificationHead { pub fn forward(&self, hidden_states: &Tensor) -> Result { let _enter = self.span.enter(); - let pooled_output = self.pooler.forward(&hidden_states)?; + let pooled_output = self.pooler.forward(hidden_states)?; self.classifier.forward(&pooled_output) } } @@ -1487,4 +1505,4 @@ impl Model for DebertaV2Model { } } } -} \ No newline at end of file +} From 46eab77c26e79bcd4e0e46153c2052af06611f1c Mon Sep 17 00:00:00 2001 From: Vinay Damodaran Date: Thu, 30 Oct 2025 10:53:28 -0700 Subject: [PATCH 09/11] refactor our xsoftmax --- backends/candle/src/models/debertav2.rs | 34 +++++-------------------- 1 file changed, 7 insertions(+), 27 deletions(-) diff --git a/backends/candle/src/models/debertav2.rs b/backends/candle/src/models/debertav2.rs index 3061fb16..32882f71 100644 --- a/backends/candle/src/models/debertav2.rs +++ b/backends/candle/src/models/debertav2.rs @@ -228,30 +228,6 @@ impl DebertaV2Embeddings { } } -// https://github.com/huggingface/transformers/blob/78b2929c0554b79e0489b451ce4ece14d265ead2/src/transformers/models/deberta_v2/modeling_deberta_v2.py#L72 -struct XSoftmax {} - -impl XSoftmax { - pub fn apply(input: &Tensor, mask: &Tensor, dim: D, device: &Device) -> Result { - // NOTE: At the time of this writing, candle does not have a logical-not operator. - let mut rmask = mask.broadcast_as(input.shape())?.to_dtype(DType::F32)?; - - rmask = rmask - .broadcast_lt(&Tensor::new(&[1.0_f32], device)?)? - .to_dtype(DType::U8)?; - - let min_value_tensor = Tensor::new(&[f32::MIN], device)?.broadcast_as(input.shape())?; - let mut output = rmask.where_cond(&min_value_tensor, input)?; - - output = candle_nn::ops::softmax(&output, dim)?; - - let t_zeroes = Tensor::new(&[0f32], device)?.broadcast_as(input.shape())?; - output = rmask.where_cond(&t_zeroes, &output)?; - - Ok(output) - } -} - // https://github.com/huggingface/transformers/blob/78b2929c0554b79e0489b451ce4ece14d265ead2/src/transformers/models/deberta_v2/modeling_deberta_v2.py#L605 pub struct DebertaV2DisentangledSelfAttention { num_attention_heads: usize, @@ -423,8 +399,8 @@ impl DebertaV2DisentangledSelfAttention { attention_scores.dim(D::Minus1)?, ))?; - let attention_probs = - XSoftmax::apply(&attention_scores, attention_mask, D::Minus1, &self.device)?; + let attention_probs = attention_scores.broadcast_add(attention_mask)?; + let attention_probs = candle_nn::ops::softmax(&attention_probs, D::Minus1)?; let mut context_layer = attention_probs .reshape(( @@ -1027,7 +1003,11 @@ impl DebertaV2Encoder { len => bail!("Unsupported attentiom mask size length: {len}"), } - Ok(attention_mask) + // Convert binary mask to additive bias: 0 for valid positions, large negative for masked + let one = Tensor::ones_like(&attention_mask)?; + let bias = attention_mask.broadcast_sub(&one)?.broadcast_mul(&Tensor::new(&[10000.0_f32], &self.device)?)?; + + Ok(bias) } fn get_rel_pos( From 9abd6ebad092ec40f2b6528b18238743dc77fcf1 Mon Sep 17 00:00:00 2001 From: Vinay Damodaran Date: Thu, 30 Oct 2025 10:54:18 -0700 Subject: [PATCH 10/11] linting --- backends/candle/src/models/debertav2.rs | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/backends/candle/src/models/debertav2.rs b/backends/candle/src/models/debertav2.rs index 32882f71..2cdebe8a 100644 --- a/backends/candle/src/models/debertav2.rs +++ b/backends/candle/src/models/debertav2.rs @@ -1005,7 
+1005,9 @@ impl DebertaV2Encoder { // Convert binary mask to additive bias: 0 for valid positions, large negative for masked let one = Tensor::ones_like(&attention_mask)?; - let bias = attention_mask.broadcast_sub(&one)?.broadcast_mul(&Tensor::new(&[10000.0_f32], &self.device)?)?; + let bias = attention_mask + .broadcast_sub(&one)? + .broadcast_mul(&Tensor::new(&[10000.0_f32], &self.device)?)?; Ok(bias) } From 94ccfcd924e630212a83ea25f6e98485f897d704 Mon Sep 17 00:00:00 2001 From: Vinay Damodaran Date: Thu, 30 Oct 2025 12:03:14 -0700 Subject: [PATCH 11/11] solve compatibility with fp16 --- backends/candle/src/models/debertav2.rs | 11 +++++++---- 1 file changed, 7 insertions(+), 4 deletions(-) diff --git a/backends/candle/src/models/debertav2.rs b/backends/candle/src/models/debertav2.rs index 2cdebe8a..b3bb459e 100644 --- a/backends/candle/src/models/debertav2.rs +++ b/backends/candle/src/models/debertav2.rs @@ -399,7 +399,9 @@ impl DebertaV2DisentangledSelfAttention { attention_scores.dim(D::Minus1)?, ))?; - let attention_probs = attention_scores.broadcast_add(attention_mask)?; + // Add attention mask bias and apply softmax (ModernBERT approach) + let attention_mask = attention_mask.to_dtype(attention_scores.dtype())?; + let attention_probs = attention_scores.broadcast_add(&attention_mask)?; let attention_probs = candle_nn::ops::softmax(&attention_probs, D::Minus1)?; let mut context_layer = attention_probs @@ -527,7 +529,8 @@ impl DebertaV2DisentangledSelfAttention { } } - let mut score = Tensor::new(&[0 as f32], &self.device)?; + // Initialize score tensor with the same dtype as query_layer to avoid dtype mismatches + let mut score = Tensor::new(&[0 as f32], &self.device)?.to_dtype(query_layer.dtype())?; if self.is_c2p_attn { let pos_key_layer = pos_key_layer.context("c2p without pos_key_layer")?; @@ -541,7 +544,7 @@ impl DebertaV2DisentangledSelfAttention { let mut c2p_att = query_layer.matmul(&pos_key_layer.t()?)?; let c2p_pos = relative_pos - .broadcast_add(&Tensor::new(&[att_span as i64], &self.device)?)? + .broadcast_add(&Tensor::new(&[att_span as i64], &self.device)?.to_dtype(relative_pos.dtype())?)? .clamp(0 as f32, (att_span * 2 - 1) as f32)?; c2p_att = c2p_att.gather( @@ -1007,7 +1010,7 @@ impl DebertaV2Encoder { let one = Tensor::ones_like(&attention_mask)?; let bias = attention_mask .broadcast_sub(&one)? - .broadcast_mul(&Tensor::new(&[10000.0_f32], &self.device)?)?; + .broadcast_mul(&Tensor::new(&[10000.0_f32], &self.device)?.to_dtype(attention_mask.dtype())?)?; Ok(bias) }
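Taken together, patches 09-11 replace the XSoftmax helper with an additive mask: the binary padding mask becomes a bias of 0.0 for valid positions and -10000.0 for padded ones, cast to the score dtype so fp16 checkpoints work, and added to the raw attention scores before a plain softmax. The following is a minimal sketch of that path, not part of the patch, assuming only candle/candle-nn ops already used in this file; shapes and sample values are invented.

use candle::{Device, Result, Tensor, D};

// Sketch of the additive attention-mask path: bias = (mask - 1) * 10000,
// i.e. 0.0 where mask == 1 and -10000.0 where mask == 0, added before softmax.
fn masked_softmax(scores: &Tensor, binary_mask: &Tensor) -> Result<Tensor> {
    let device = scores.device();
    let one = Tensor::ones_like(binary_mask)?;
    let bias = binary_mask
        .broadcast_sub(&one)?
        .broadcast_mul(&Tensor::new(&[10000.0_f32], device)?.to_dtype(binary_mask.dtype())?)?;
    // Cast the bias to the score dtype (f16 or f32) before adding, as patch 11 does.
    let biased = scores.broadcast_add(&bias.to_dtype(scores.dtype())?)?;
    candle_nn::ops::softmax(&biased, D::Minus1)
}

fn main() -> Result<()> {
    let device = Device::Cpu;
    // One sequence of length 4 whose last position is padding.
    let scores = Tensor::new(&[[0.5_f32, 1.0, -0.3, 2.0]], &device)?;
    let mask = Tensor::new(&[[1.0_f32, 1.0, 1.0, 0.0]], &device)?;
    let probs = masked_softmax(&scores, &mask)?;
    println!("{probs}"); // the masked column collapses to ~0 probability
    Ok(())
}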