Skip to content

Commit 32cc319

Browse files
committed
launcher: ensure correct detection of Gemma 3 head size
1 parent 3d71c06 commit 32cc319

File tree

1 file changed

+16
-5
lines changed

1 file changed

+16
-5
lines changed

launcher/src/main.rs

Lines changed: 16 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -260,11 +260,22 @@ struct Config {
260260

261261
impl Config {
262262
fn get_head_dim(&self) -> Option<usize> {
263-
self.head_dim.or_else(|| {
264-
self.text_config
265-
.as_ref()
266-
.and_then(|text_config| text_config.head_dim)
267-
})
263+
if let Some(head_dim) = self.head_dim {
264+
return Some(head_dim);
265+
}
266+
267+
let text_config = self.text_config.as_ref()?;
268+
if let Some(head_size) = text_config.head_dim {
269+
return Some(head_size);
270+
}
271+
272+
match self.model_type.as_deref() {
273+
// We special-case gemma3 here, since we need flashinfer for
274+
// handling bidirectional masks. And flashinfer can only be
275+
// used when the head size is known.
276+
Some("gemma3") => Some(256),
277+
_ => None,
278+
}
268279
}
269280

270281
fn flop(&self) -> Option<u64> {

0 commit comments

Comments
 (0)