Skip to content

Commit 6c5e3ce

Browse files
committed
Add support for Gemma 3n MobileNetV5 encoder weight loading
1 parent 1e1c637 commit 6c5e3ce

File tree

1 file changed

+21
-7
lines changed

1 file changed

+21
-7
lines changed

timm/models/mobilenetv5.py

Lines changed: 21 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
from functools import partial
2-
from typing import Callable, List, Optional, Sequence, Tuple, Union
2+
from typing import Callable, Dict, List, Optional, Sequence, Tuple, Union
33

44
import torch
55
import torch.nn as nn
@@ -129,12 +129,12 @@ def __init__(
129129
num_classes: int = 1000,
130130
in_chans: int = 3,
131131
stem_size: int = 16,
132-
stem_bias: bool = False,
132+
stem_bias: bool = True,
133133
fix_stem: bool = False,
134134
num_features: int = 2048,
135135
pad_type: str = '',
136136
use_msfa: bool = True,
137-
msfa_indices: List[int] = (-3, -2, -1),
137+
msfa_indices: List[int] = (-2, -1),
138138
msfa_output_resolution: int = 16,
139139
act_layer: Optional[LayerType] = None,
140140
norm_layer: Optional[LayerType] = None,
@@ -574,6 +574,19 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
574574
return self.forward_features(x)
575575

576576

577+
def checkpoint_filter_fn(
578+
state_dict: Dict[str, torch.Tensor],
579+
model,
580+
) -> Dict[str, torch.Tensor]:
581+
""" convert weights from gemma encoders """
582+
state_dict = state_dict.get('model', state_dict)
583+
state_dict = state_dict.get('state_dict', state_dict)
584+
if 'model.vision_tower.timm_model.conv_stem.conv.weight' in state_dict:
585+
prefix = 'model.vision_tower.timm_model.'
586+
state_dict = {k.replace(prefix, ''): v for k, v in state_dict.items() if prefix in k}
587+
return state_dict
588+
589+
577590
def _create_mnv5_encoder(variant: str, pretrained: bool = False, **kwargs) -> MobileNetV5Encoder:
578591
out_indices = kwargs.pop('out_indices', (0, 1, 2, 3, 4))
579592
feature_cfg = dict(out_indices=out_indices, feature_cls='getter')
@@ -590,21 +603,22 @@ def _create_mnv5_encoder(variant: str, pretrained: bool = False, **kwargs) -> Mo
590603
variant,
591604
pretrained,
592605
pretrained_strict=False,
606+
pretrained_filter_fn=checkpoint_filter_fn,
593607
feature_cfg=feature_cfg,
594608
kwargs_filter=kwargs_filter,
595609
**kwargs,
596610
)
597611
return model
598612

599613

600-
def _create_mnv5(variant: str, pretrained: bool = False, **kwargs) -> MobileNetV5Encoder:
614+
def _create_mnv5(variant: str, pretrained: bool = False, **kwargs) -> MobileNetV5:
601615
out_indices = kwargs.pop('out_indices', (0, 1, 2, 3, 4))
602616
feature_cfg = dict(out_indices=out_indices, feature_cls='getter')
603617
model = build_model_with_cfg(
604618
MobileNetV5,
605619
variant,
606620
pretrained,
607-
pretrained_strict=False,
621+
pretrained_filter_fn=checkpoint_filter_fn,
608622
feature_cfg=feature_cfg,
609623
**kwargs,
610624
)
@@ -809,8 +823,8 @@ def _cfg(url: str = '', **kwargs):
809823
num_classes=0),
810824

811825
# WIP classification configs for testing
812-
'mobilenetv5_300m': _cfg(
813-
# hf_hub_id='timm/',
826+
'mobilenetv5_300m.gemma3n': _cfg(
827+
hf_hub_id='timm/',
814828
mean=(0., 0., 0.), std=(1., 1., 1.),
815829
input_size=(3, 768, 768),
816830
num_classes=0),

0 commit comments

Comments
 (0)