@@ -3,6 +3,7 @@ use crate::models::Model;
33use candle:: { DType , Device , IndexOp , Module , Result , Tensor , D } ;
44use candle_nn:: { Embedding , VarBuilder } ;
55use serde:: Deserialize ;
6+ use std:: collections:: HashMap ;
67use text_embeddings_backend_core:: { Batch , ModelType , Pool } ;
78
89#[ derive( Debug , Clone , PartialEq , Deserialize ) ]
@@ -16,6 +17,8 @@ pub struct DistilBertConfig {
1617 pub max_position_embeddings : usize ,
1718 pub pad_token_id : usize ,
1819 pub model_type : Option < String > ,
20+ pub classifier_dropout : Option < f64 > ,
21+ pub id2label : Option < HashMap < String , String > > ,
1922}
2023
2124#[ derive( Debug ) ]
@@ -318,6 +321,56 @@ impl DistilBertEncoder {
318321 }
319322}
320323
324+ pub trait ClassificationHead {
325+ fn forward ( & self , hidden_states : & Tensor ) -> Result < Tensor > ;
326+ }
327+
328+ pub struct DistilBertClassificationHead {
329+ pre_classifier : Linear ,
330+ classifier : Linear ,
331+ span : tracing:: Span ,
332+ }
333+
334+ impl DistilBertClassificationHead {
335+ pub ( crate ) fn load ( vb : VarBuilder , config : & DistilBertConfig ) -> Result < Self > {
336+ let n_classes = match & config. id2label {
337+ None => candle:: bail!( "`id2label` must be set for classifier models" ) ,
338+ Some ( id2label) => id2label. len ( ) ,
339+ } ;
340+
341+ let pre_classifier_weight = vb
342+ . pp ( "pre_classifier" )
343+ . get ( ( config. dim , config. dim ) , "weight" ) ?;
344+ let pre_classifier_bias = vb. pp ( "pre_classifier" ) . get ( config. dim , "bias" ) ?;
345+ let pre_classifier = Linear :: new ( pre_classifier_weight, Some ( pre_classifier_bias) , None ) ;
346+
347+ let classifier_weight = vb. pp ( "classifier" ) . get ( ( n_classes, config. dim ) , "weight" ) ?;
348+ let classifier_bias = vb. pp ( "classifier" ) . get ( n_classes, "bias" ) ?;
349+ let classifier = Linear :: new ( classifier_weight, Some ( classifier_bias) , None ) ;
350+
351+ Ok ( Self {
352+ pre_classifier,
353+ classifier,
354+ span : tracing:: span!( tracing:: Level :: TRACE , "classifier" ) ,
355+ } )
356+ }
357+ }
358+
359+ impl ClassificationHead for DistilBertClassificationHead {
360+ fn forward ( & self , hidden_states : & Tensor ) -> Result < Tensor > {
361+ let _enter = self . span . enter ( ) ;
362+
363+ let hidden_states = hidden_states. unsqueeze ( 1 ) ?;
364+
365+ let hidden_states = self . pre_classifier . forward ( & hidden_states) ?;
366+ let hidden_states = hidden_states. relu ( ) ?;
367+
368+ let hidden_states = self . classifier . forward ( & hidden_states) ?;
369+ let hidden_states = hidden_states. squeeze ( 1 ) ?;
370+ Ok ( hidden_states)
371+ }
372+ }
373+
321374#[ derive( Debug ) ]
322375pub struct DistilBertSpladeHead {
323376 vocab_transform : Linear ,
@@ -368,11 +421,11 @@ impl DistilBertSpladeHead {
368421 }
369422}
370423
371- #[ derive( Debug ) ]
372424pub struct DistilBertModel {
373425 embeddings : DistilBertEmbeddings ,
374426 encoder : DistilBertEncoder ,
375427 pool : Pool ,
428+ classifier : Option < Box < dyn ClassificationHead + Send > > ,
376429 splade : Option < DistilBertSpladeHead > ,
377430
378431 num_attention_heads : usize ,
@@ -385,15 +438,21 @@ pub struct DistilBertModel {
385438
386439impl DistilBertModel {
387440 pub fn load ( vb : VarBuilder , config : & DistilBertConfig , model_type : ModelType ) -> Result < Self > {
388- let pool = match model_type {
441+ let ( pool, classifier) = match model_type {
442+ // Classifier models always use CLS pooling
389443 ModelType :: Classifier => {
390- candle:: bail!( "`classifier` model type is not supported for DistilBert" )
444+ let pool = Pool :: Cls ;
445+
446+ let classifier: Box < dyn ClassificationHead + Send > =
447+ Box :: new ( DistilBertClassificationHead :: load ( vb. clone ( ) , config) ?) ;
448+ ( pool, Some ( classifier) )
391449 }
392450 ModelType :: Embedding ( pool) => {
393451 if pool == Pool :: LastToken {
394452 candle:: bail!( "`last_token` is not supported for DistilBert" ) ;
395453 }
396- pool
454+
455+ ( pool, None )
397456 }
398457 } ;
399458
@@ -424,6 +483,7 @@ impl DistilBertModel {
424483 embeddings,
425484 encoder,
426485 pool,
486+ classifier,
427487 splade,
428488 num_attention_heads : config. n_heads ,
429489 device : vb. device ( ) . clone ( ) ,
@@ -660,4 +720,16 @@ impl Model for DistilBertModel {
660720 fn embed ( & self , batch : Batch ) -> Result < ( Option < Tensor > , Option < Tensor > ) > {
661721 self . forward ( batch)
662722 }
723+
724+ fn predict ( & self , batch : Batch ) -> Result < Tensor > {
725+ match & self . classifier {
726+ None => candle:: bail!( "`predict` is not implemented for this model" ) ,
727+ Some ( classifier) => {
728+ let ( pooled_embeddings, _raw_embeddings) = self . forward ( batch) ?;
729+ let pooled_embeddings =
730+ pooled_embeddings. expect ( "pooled_embeddings is empty. This is a bug." ) ;
731+ classifier. forward ( & pooled_embeddings)
732+ }
733+ }
734+ }
663735}
0 commit comments