Skip to content
3 changes: 3 additions & 0 deletions backends/candle/src/layers/linear.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ use serde::Deserialize;
pub enum HiddenAct {
#[serde(alias = "gelu_pytorch_tanh")]
Gelu,
GeluExact,
Relu,
Silu,
Swiglu,
Expand All @@ -16,6 +17,7 @@ impl HiddenAct {
pub fn forward(&self, x: &Tensor) -> Result<Tensor> {
match self {
Self::Gelu => x.gelu(),
Self::GeluExact => x.gelu_erf(),
Self::Relu => x.relu(),
Self::Silu => x.silu(),
Self::Swiglu => candle_nn::ops::swiglu(x),
Expand Down Expand Up @@ -85,6 +87,7 @@ impl Linear {
if let Some(act) = &self.act {
match act {
HiddenAct::Gelu => x.gelu(),
HiddenAct::GeluExact => x.gelu_erf(),
HiddenAct::Relu => x.relu(),
HiddenAct::Silu => x.silu(),
HiddenAct::Swiglu => candle_nn::ops::swiglu(&x),
Expand Down
15 changes: 11 additions & 4 deletions backends/candle/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -22,10 +22,11 @@ use crate::compute_cap::{
compatible_compute_cap, get_compile_compute_cap, get_runtime_compute_cap,
};
use crate::models::{
BertConfig, BertModel, Dense, DenseConfig, DenseLayer, DistilBertConfig, DistilBertModel,
GTEConfig, GTEModel, Gemma3Config, Gemma3Model, JinaBertModel, JinaCodeBertModel, MPNetConfig,
MPNetModel, MistralConfig, Model, ModernBertConfig, ModernBertModel, NomicBertModel,
NomicConfig, Qwen2Config, Qwen3Config, Qwen3Model,
BertConfig, BertModel, DebertaV2Config, DebertaV2Model, Dense, DenseConfig, DenseLayer,
DistilBertConfig, DistilBertModel, GTEConfig, GTEModel, Gemma3Config, Gemma3Model,
JinaBertModel, JinaCodeBertModel, MPNetConfig, MPNetModel, MistralConfig, Model,
ModernBertConfig, ModernBertModel, NomicBertModel, NomicConfig, Qwen2Config, Qwen3Config,
Qwen3Model,
};
#[cfg(feature = "cuda")]
use crate::models::{
Expand Down Expand Up @@ -92,6 +93,8 @@ impl<'de> Deserialize<'de> for BertConfigWrapper {
#[serde(tag = "model_type", rename_all = "kebab-case")]
enum Config {
Bert(BertConfigWrapper),
#[serde(rename(deserialize = "deberta-v2"))]
DebertaV2(DebertaV2Config),
Camembert(BertConfig),
#[serde(rename(deserialize = "distilbert"))]
DistilBert(DistilBertConfig),
Expand Down Expand Up @@ -250,6 +253,10 @@ impl CandleBackend {
Ok(Box::new(BertModel::load(vb, &config, model_type).s()?))
}
},
// This arm matches a DeBERTa-v2 config when the device is CPU or Metal.
(Config::DebertaV2(config), Device::Cpu | Device::Metal(_)) => {
tracing::info!("Starting DebertaV2 model on {:?}", device);
// Load the model and box it for the caller.
// NOTE(review): `.s()` presumably converts the load error into the backend's error type — confirm.
Ok(Box::new(DebertaV2Model::load(vb, &config, model_type).s()?))
}
(
Config::Camembert(config) | Config::Roberta(config) | Config::XlmRoberta(config),
Device::Cpu | Device::Metal(_),
Expand Down
Loading