1
1
from functools import partial
2
- from typing import Callable , List , Optional , Sequence , Tuple , Union
2
+ from typing import Callable , Dict , List , Optional , Sequence , Tuple , Union
3
3
4
4
import torch
5
5
import torch .nn as nn
@@ -129,12 +129,12 @@ def __init__(
129
129
num_classes : int = 1000 ,
130
130
in_chans : int = 3 ,
131
131
stem_size : int = 16 ,
132
- stem_bias : bool = False ,
132
+ stem_bias : bool = True ,
133
133
fix_stem : bool = False ,
134
134
num_features : int = 2048 ,
135
135
pad_type : str = '' ,
136
136
use_msfa : bool = True ,
137
- msfa_indices : List [int ] = (- 3 , - 2 , - 1 ),
137
+ msfa_indices : List [int ] = (- 2 , - 1 ),
138
138
msfa_output_resolution : int = 16 ,
139
139
act_layer : Optional [LayerType ] = None ,
140
140
norm_layer : Optional [LayerType ] = None ,
@@ -574,6 +574,19 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
574
574
return self .forward_features (x )
575
575
576
576
577
+ def checkpoint_filter_fn (
578
+ state_dict : Dict [str , torch .Tensor ],
579
+ model ,
580
+ ) -> Dict [str , torch .Tensor ]:
581
+ """ convert weights from gemma encoders """
582
+ state_dict = state_dict .get ('model' , state_dict )
583
+ state_dict = state_dict .get ('state_dict' , state_dict )
584
+ if 'model.vision_tower.timm_model.conv_stem.conv.weight' in state_dict :
585
+ prefix = 'model.vision_tower.timm_model.'
586
+ state_dict = {k .replace (prefix , '' ): v for k , v in state_dict .items () if prefix in k }
587
+ return state_dict
588
+
589
+
577
590
def _create_mnv5_encoder (variant : str , pretrained : bool = False , ** kwargs ) -> MobileNetV5Encoder :
578
591
out_indices = kwargs .pop ('out_indices' , (0 , 1 , 2 , 3 , 4 ))
579
592
feature_cfg = dict (out_indices = out_indices , feature_cls = 'getter' )
@@ -590,21 +603,22 @@ def _create_mnv5_encoder(variant: str, pretrained: bool = False, **kwargs) -> Mo
590
603
variant ,
591
604
pretrained ,
592
605
pretrained_strict = False ,
606
+ pretrained_filter_fn = checkpoint_filter_fn ,
593
607
feature_cfg = feature_cfg ,
594
608
kwargs_filter = kwargs_filter ,
595
609
** kwargs ,
596
610
)
597
611
return model
598
612
599
613
600
- def _create_mnv5 (variant : str , pretrained : bool = False , ** kwargs ) -> MobileNetV5Encoder :
614
+ def _create_mnv5 (variant : str , pretrained : bool = False , ** kwargs ) -> MobileNetV5 :
601
615
out_indices = kwargs .pop ('out_indices' , (0 , 1 , 2 , 3 , 4 ))
602
616
feature_cfg = dict (out_indices = out_indices , feature_cls = 'getter' )
603
617
model = build_model_with_cfg (
604
618
MobileNetV5 ,
605
619
variant ,
606
620
pretrained ,
607
- pretrained_strict = False ,
621
+ pretrained_filter_fn = checkpoint_filter_fn ,
608
622
feature_cfg = feature_cfg ,
609
623
** kwargs ,
610
624
)
@@ -809,8 +823,8 @@ def _cfg(url: str = '', **kwargs):
809
823
num_classes = 0 ),
810
824
811
825
# WIP classification configs for testing
812
- 'mobilenetv5_300m' : _cfg (
813
- # hf_hub_id='timm/',
826
+ 'mobilenetv5_300m.gemma3n ' : _cfg (
827
+ hf_hub_id = 'timm/' ,
814
828
mean = (0. , 0. , 0. ), std = (1. , 1. , 1. ),
815
829
input_size = (3 , 768 , 768 ),
816
830
num_classes = 0 ),
0 commit comments