@@ -12,7 +12,7 @@ use crate::compute_cap::{
 };
 use crate::models::{
     BertConfig, BertModel, DistilBertConfig, DistilBertModel, JinaBertModel, JinaCodeBertModel,
-    JinaCodeConfig, JinaConfig, Model, NomicBertModel, NomicConfig,
+    Model, NomicBertModel, NomicConfig,
 };
 #[cfg(feature = "cuda")]
 use crate::models::{
@@ -30,17 +30,28 @@ use text_embeddings_backend_core::{
     Backend, BackendError, Batch, Embedding, Embeddings, ModelType, Predictions,
 };
 
+/// This enum is needed to be able to differentiate between jina models that also use
+/// the `bert` model type and valid Bert models.
+/// We use the `_name_or_path` field in the config to do so. This might not be robust in the long
+/// run but is still better than the other options...
+#[derive(Debug, Clone, PartialEq, Deserialize)]
+#[serde(tag = "_name_or_path")]
+pub enum BertConfigWrapper {
+    #[serde(rename = "jinaai/jina-bert-implementation")]
+    JinaBert(BertConfig),
+    #[serde(rename = "jinaai/jina-bert-v2-qk-post-norm")]
+    JinaCodeBert(BertConfig),
+    #[serde(untagged)]
+    Bert(BertConfig),
+}
+
 #[derive(Deserialize)]
 #[serde(tag = "model_type", rename_all = "kebab-case")]
 enum Config {
-    Bert(BertConfig),
+    Bert(BertConfigWrapper),
     XlmRoberta(BertConfig),
     Camembert(BertConfig),
     Roberta(BertConfig),
-    #[serde(rename(deserialize = "jina_bert"))]
-    JinaBert(JinaConfig),
-    #[serde(rename(deserialize = "jina_code_bert"))]
-    JinaCodeBert(JinaCodeConfig),
     #[serde(rename(deserialize = "distilbert"))]
     DistilBert(DistilBertConfig),
     #[serde(rename(deserialize = "nomic_bert"))]
@@ -76,7 +87,7 @@ impl CandleBackend {
7687 "Runtime compute cap {} is not compatible with compile time compute cap {}" ,
7788 get_runtime_compute_cap( ) . unwrap( ) ,
7889 get_compile_compute_cap( ) . unwrap( )
79- ) ) )
90+ ) ) ) ;
8091 }
8192 Err ( err) => {
8293 tracing:: warn!( "Could not find a compatible CUDA device on host: {err:?}" ) ;
@@ -123,20 +134,22 @@ impl CandleBackend {
             (_, Device::Cuda(_)) => Err(BackendError::Start(
                 "`cuda` feature is not enabled".to_string(),
             )),
-            (Config::Bert(config), Device::Cpu | Device::Metal(_)) => {
-                tracing::info!("Starting Bert model on {:?}", device);
-                Ok(Box::new(BertModel::load(vb, &config, model_type).s()?))
-            }
-            (Config::JinaBert(config), Device::Cpu | Device::Metal(_)) => {
-                tracing::info!("Starting JinaBertModel model on {:?}", device);
-                Ok(Box::new(JinaBertModel::load(vb, &config, model_type).s()?))
-            }
-            (Config::JinaCodeBert(config), Device::Cpu | Device::Metal(_)) => {
-                tracing::info!("Starting JinaCodeBertModel model on {:?}", device);
-                Ok(Box::new(
-                    JinaCodeBertModel::load(vb, &config, model_type).s()?,
-                ))
-            }
+            (Config::Bert(config), Device::Cpu | Device::Metal(_)) => match config {
+                BertConfigWrapper::JinaBert(config) => {
+                    tracing::info!("Starting JinaBertModel model on {:?}", device);
+                    Ok(Box::new(JinaBertModel::load(vb, &config, model_type).s()?))
+                }
+                BertConfigWrapper::JinaCodeBert(config) => {
+                    tracing::info!("Starting JinaCodeBert model on {:?}", device);
+                    Ok(Box::new(
+                        JinaCodeBertModel::load(vb, &config, model_type).s()?,
+                    ))
+                }
+                BertConfigWrapper::Bert(config) => {
+                    tracing::info!("Starting Bert model on {:?}", device);
+                    Ok(Box::new(BertModel::load(vb, &config, model_type).s()?))
+                }
+            },
             (
                 Config::XlmRoberta(config) | Config::Camembert(config) | Config::Roberta(config),
                 Device::Cpu | Device::Metal(_),
@@ -160,56 +173,45 @@ impl CandleBackend {
             (Config::Bert(config), Device::Cuda(_)) => {
                 if cfg!(any(feature = "flash-attn", feature = "flash-attn-v1"))
                     && dtype == DType::F16
-                    && ((config.position_embedding_type == PositionEmbeddingType::Absolute) | (config.position_embedding_type == PositionEmbeddingType::Alibi))
                     // Allow disabling because of flash attention v1 precision problems
                     // See: https://github.com/huggingface/text-embeddings-inference/issues/37
                     && &std::env::var("USE_FLASH_ATTENTION").unwrap_or("True".to_string()).to_lowercase() == "true"
                 {
-                    if config.position_embedding_type == PositionEmbeddingType::Alibi {
-                        tracing::info!("Starting FlashBert model on {:?}", device);
-                        Ok(Box::new(FlashBertModel::load(vb, &config, model_type).s()?))
-                    } else {
-                        tracing::info!("Starting Bert model on {:?}", device);
-                        Ok(Box::new(BertModel::load(vb, &config, model_type).s()?))
+                    match config {
+                        BertConfigWrapper::JinaBert(config) => {
+                            tracing::info!("Starting FlashJinaBert model on {:?}", device);
+                            Ok(Box::new(
+                                FlashJinaBertModel::load(vb, &config, model_type).s()?,
+                            ))
+                        }
+                        BertConfigWrapper::JinaCodeBert(config) => {
+                            tracing::info!("Starting FlashJinaCodeBert model on {:?}", device);
+                            Ok(Box::new(
+                                FlashJinaCodeBertModel::load(vb, &config, model_type).s()?,
+                            ))
+                        }
+                        BertConfigWrapper::Bert(config) => {
+                            tracing::info!("Starting FlashBert model on {:?}", device);
+                            Ok(Box::new(FlashBertModel::load(vb, &config, model_type).s()?))
+                        }
                     }
-                }
-            }
-            #[cfg(feature = "cuda")]
-            (Config::JinaBert(config), Device::Cuda(_)) => {
-                if cfg!(any(feature = "flash-attn", feature = "flash-attn-v1"))
-                    && dtype == DType::F16
-                    && ((config.position_embedding_type == PositionEmbeddingType::Absolute) | (config.position_embedding_type == PositionEmbeddingType::Alibi))
-                    // Allow disabling because of flash attention v1 precision problems
-                    // See: https://github.com/huggingface/text-embeddings-inference/issues/37
-                    && &std::env::var("USE_FLASH_ATTENTION").unwrap_or("True".to_string()).to_lowercase() == "true"
-                {
-                    tracing::info!("Starting FlashJinaBertModel model on {:?}", device);
-                    Ok(Box::new(
-                        FlashJinaBertModel::load(vb, &config, model_type).s()?,
-                    ))
-                } else {
-                    tracing::info!("Starting JinaBertModel model on {:?}", device);
-                    Ok(Box::new(JinaBertModel::load(vb, &config, model_type).s()?))
-                }
-            }
-            #[cfg(feature = "cuda")]
-            (Config::JinaCodeBert(config), Device::Cuda(_)) => {
-                if cfg!(any(feature = "flash-attn", feature = "flash-attn-v1"))
-                    && dtype == DType::F16
-                    && ((config.position_embedding_type == PositionEmbeddingType::Absolute) | (config.position_embedding_type == PositionEmbeddingType::Alibi))
-                    // Allow disabling because of flash attention v1 precision problems
-                    // See: https://github.com/huggingface/text-embeddings-inference/issues/37
-                    && &std::env::var("USE_FLASH_ATTENTION").unwrap_or("True".to_string()).to_lowercase() == "true"
-                {
-                    tracing::info!("Starting FlashJinaCodeBertModel model on {:?}", device);
-                    Ok(Box::new(
-                        FlashJinaCodeBertModel::load(vb, &config, model_type).s()?,
-                    ))
                 } else {
-                    tracing::info!("Starting JinaCodeBertModel model on {:?}", device);
-                    Ok(Box::new(
-                        JinaCodeBertModel::load(vb, &config, model_type).s()?,
-                    ))
+                    match config {
+                        BertConfigWrapper::JinaBert(config) => {
+                            tracing::info!("Starting JinaBertModel model on {:?}", device);
+                            Ok(Box::new(JinaBertModel::load(vb, &config, model_type).s()?))
+                        }
+                        BertConfigWrapper::JinaCodeBert(config) => {
+                            tracing::info!("Starting JinaCodeBert model on {:?}", device);
+                            Ok(Box::new(
+                                JinaCodeBertModel::load(vb, &config, model_type).s()?,
+                            ))
+                        }
+                        BertConfigWrapper::Bert(config) => {
+                            tracing::info!("Starting Bert model on {:?}", device);
+                            Ok(Box::new(BertModel::load(vb, &config, model_type).s()?))
+                        }
+                    }
                 }
             }
             #[cfg(feature = "cuda")]
@@ -219,7 +221,6 @@ impl CandleBackend {
             ) => {
                 if cfg!(any(feature = "flash-attn", feature = "flash-attn-v1"))
                     && dtype == DType::F16
-                    && ((config.position_embedding_type == PositionEmbeddingType::Absolute) | (config.position_embedding_type == PositionEmbeddingType::Alibi))
                     // Allow disabling because of flash attention v1 precision problems
                     // See: https://github.com/huggingface/text-embeddings-inference/issues/37
                     && &std::env::var("USE_FLASH_ATTENTION").unwrap_or("True".to_string()).to_lowercase() == "true"