@@ -586,27 +586,15 @@ def _normalize_arch(
586
586
587
587
return architecture
588
588
589
- def _normalize_archs (
590
- self ,
591
- architectures : list [str ],
592
- model_config : ModelConfig ,
593
- ) -> list [str ]:
594
- if not architectures :
595
- logger .warning ("No model architectures are specified" )
596
-
597
- return [
598
- self ._normalize_arch (arch , model_config ) for arch in architectures
599
- ]
600
-
601
589
def inspect_model_cls (
602
590
self ,
603
591
architectures : Union [str , list [str ]],
604
592
model_config : ModelConfig ,
605
593
) -> tuple [_ModelInfo , str ]:
606
594
if isinstance (architectures , str ):
607
595
architectures = [architectures ]
608
-
609
- normalized_archs = self . _normalize_archs ( architectures , model_config )
596
+ if not architectures :
597
+ raise ValueError ( "No model architectures are specified" )
610
598
611
599
# Require transformers impl
612
600
if model_config .model_impl == ModelImpl .TRANSFORMERS :
@@ -617,13 +605,26 @@ def inspect_model_cls(
617
605
if model_info is not None :
618
606
return (model_info , arch )
619
607
620
- for arch , normalized_arch in zip (architectures , normalized_archs ):
608
+ # Fallback to transformers impl (after resolving convert_type)
609
+ if (all (arch not in self .models for arch in architectures )
610
+ and model_config .model_impl == ModelImpl .AUTO
611
+ and getattr (model_config , "convert_type" , "none" ) == "none" ):
612
+ arch = self ._try_resolve_transformers (architectures [0 ],
613
+ model_config )
614
+ if arch is not None :
615
+ model_info = self ._try_inspect_model_cls (arch )
616
+ if model_info is not None :
617
+ return (model_info , arch )
618
+
619
+ for arch in architectures :
620
+ normalized_arch = self ._normalize_arch (arch , model_config )
621
621
model_info = self ._try_inspect_model_cls (normalized_arch )
622
622
if model_info is not None :
623
623
return (model_info , arch )
624
624
625
- # Fallback to transformers impl
626
- if model_config .model_impl in (ModelImpl .AUTO , ModelImpl .TRANSFORMERS ):
625
+ # Fallback to transformers impl (before resolving runner_type)
626
+ if (all (arch not in self .models for arch in architectures )
627
+ and model_config .model_impl == ModelImpl .AUTO ):
627
628
arch = self ._try_resolve_transformers (architectures [0 ],
628
629
model_config )
629
630
if arch is not None :
@@ -640,8 +641,8 @@ def resolve_model_cls(
640
641
) -> tuple [type [nn .Module ], str ]:
641
642
if isinstance (architectures , str ):
642
643
architectures = [architectures ]
643
-
644
- normalized_archs = self . _normalize_archs ( architectures , model_config )
644
+ if not architectures :
645
+ raise ValueError ( "No model architectures are specified" )
645
646
646
647
# Require transformers impl
647
648
if model_config .model_impl == ModelImpl .TRANSFORMERS :
@@ -652,13 +653,26 @@ def resolve_model_cls(
652
653
if model_cls is not None :
653
654
return (model_cls , arch )
654
655
655
- for arch , normalized_arch in zip (architectures , normalized_archs ):
656
+ # Fallback to transformers impl (after resolving convert_type)
657
+ if (all (arch not in self .models for arch in architectures )
658
+ and model_config .model_impl == ModelImpl .AUTO
659
+ and getattr (model_config , "convert_type" , "none" ) == "none" ):
660
+ arch = self ._try_resolve_transformers (architectures [0 ],
661
+ model_config )
662
+ if arch is not None :
663
+ model_cls = self ._try_load_model_cls (arch )
664
+ if model_cls is not None :
665
+ return (model_cls , arch )
666
+
667
+ for arch in architectures :
668
+ normalized_arch = self ._normalize_arch (arch , model_config )
656
669
model_cls = self ._try_load_model_cls (normalized_arch )
657
670
if model_cls is not None :
658
671
return (model_cls , arch )
659
672
660
- # Fallback to transformers impl
661
- if model_config .model_impl in (ModelImpl .AUTO , ModelImpl .TRANSFORMERS ):
673
+ # Fallback to transformers impl (before resolving runner_type)
674
+ if (all (arch not in self .models for arch in architectures )
675
+ and model_config .model_impl == ModelImpl .AUTO ):
662
676
arch = self ._try_resolve_transformers (architectures [0 ],
663
677
model_config )
664
678
if arch is not None :
0 commit comments