File tree Expand file tree Collapse file tree 1 file changed +16
-5
lines changed Expand file tree Collapse file tree 1 file changed +16
-5
lines changed Original file line number Diff line number Diff line change @@ -260,11 +260,22 @@ struct Config {
260
260
261
261
impl Config {
262
262
fn get_head_dim ( & self ) -> Option < usize > {
263
- self . head_dim . or_else ( || {
264
- self . text_config
265
- . as_ref ( )
266
- . and_then ( |text_config| text_config. head_dim )
267
- } )
263
+ if let Some ( head_dim) = self . head_dim {
264
+ return Some ( head_dim) ;
265
+ }
266
+
267
+ let text_config = self . text_config . as_ref ( ) ?;
268
+ if let Some ( head_size) = text_config. head_dim {
269
+ return Some ( head_size) ;
270
+ }
271
+
272
+ match self . model_type . as_deref ( ) {
273
+ // We special-case gemma3 here, since we need flashinfer for
274
+ // handling bidirectional masks. And flashinfer can only be
275
+ // used when the head size is known.
276
+ Some ( "gemma3" ) => Some ( 256 ) ,
277
+ _ => None ,
278
+ }
268
279
}
269
280
270
281
fn flop ( & self ) -> Option < u64 > {
You can’t perform that action at this time.
0 commit comments