Refactor `provider_bridge` for Llama and Qwen models #2052
base: main
Conversation
This refactoring centralizes model-specific configurations within the provider_bridge method of each model bridge.

Changes:
- Add MoE-related field mappings to base class CONFIG_MAPPING:
  - num_experts -> num_moe_experts
  - num_experts_per_tok -> moe_router_topk
  - moe_intermediate_size -> moe_ffn_hidden_size
- Refactor LlamaBridge:
  - Use MEGATRON_DEFAULTS and HF_DEFAULTS class attributes
  - Override provider_bridge only for RoPE scaling (Llama 3.1/3.2)
- Refactor Qwen2Bridge:
  - Use MEGATRON_DEFAULTS (add_qkv_bias=True) and HF_DEFAULTS
  - No provider_bridge override needed
- Refactor Qwen3Bridge:
  - Use MEGATRON_DEFAULTS (qk_layernorm=True) and HF_DEFAULTS
  - No provider_bridge override needed
- Refactor Qwen3MoEBridge:
  - Use MEGATRON_DEFAULTS with MoE settings and HF_DEFAULTS
  - No provider_bridge override needed
- Update tests to expect GPTModelProvider instead of model-specific providers
- Add verification scripts for both Llama and Qwen bridges

Verified on remote server:
- Qwen/Qwen2-0.5B: PASS
- Qwen/Qwen2-7B: PASS
- Qwen/Qwen3-0.6B: PASS
- Qwen/Qwen3-1.7B: PASS
- Qwen/Qwen3-30B-A3B: PASS
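A sketch of where the new MoE entries land in the shared mapping; the class name and the MoE entries come from this PR, while the pre-existing (non-MoE) entries shown here are assumptions added for context:

```python
class MegatronModelBridge:
    # HF config field name -> Megatron provider field name
    CONFIG_MAPPING = {
        # pre-existing entries (illustrative, not copied from the PR)
        "num_hidden_layers": "num_layers",
        "hidden_size": "hidden_size",
        "num_attention_heads": "num_attention_heads",
        # MoE-related mappings added in this PR
        "num_experts": "num_moe_experts",
        "num_experts_per_tok": "moe_router_topk",
        "moe_intermediate_size": "moe_ffn_hidden_size",
    }
```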
…dels
- Add MLAModelProvider as unified base for Multi-Latent Attention models
- Refactor DeepSeek V2/V3 bridges to use MLAModelProvider
- Refactor Kimi K2 bridge to use MLAModelProvider
- Move model-specific defaults from providers to MEGATRON_DEFAULTS in bridges
- Add model_type parameter to @register_bridge decorator for auto HF config
- Simplify provider files to deprecated backward-compatible aliases

Verified: DeepSeek-V2-Lite, DeepSeek-V2, DeepSeek-V3, Moonlight-16B, Kimi-K2
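A hedged sketch of how a bridge might register itself with the new `model_type` parameter. The keyword names `source`, `target`, and `provider` appear in the review diff further down; the argument values, the `GPTModel` target, and the MLA default shown here are assumptions, and imports are omitted:

```python
# Registration sketch; values are illustrative, not the PR's actual code.
@register_bridge(
    source="DeepseekV2ForCausalLM",     # assumed HF model class name
    target=GPTModel,                    # assumed Megatron target class
    provider=MLAModelProvider,
    model_type="deepseek_v2",           # new parameter: enables automatic HF config selection
)
class DeepSeekV2Bridge(MegatronModelBridge):
    MEGATRON_DEFAULTS = {
        "multi_latent_attention": True,  # illustrative MLA-specific default
    }
```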
- Register GemmaModelProvider, Gemma2ModelProvider, Gemma3ModelProvider via decorator
- Add MEGATRON_DEFAULTS to Gemma/Gemma2 bridges for explicit config defaults
- Add gelu_pytorch_tanh -> fast_gelu to ACTIVATION_MAPPING in model_bridge.py
- Add verification script for Gemma provider refactoring
Verified: gemma-2b, gemma-7b, gemma-2-2b, gemma-2-9b, gemma-2-27b,
gemma-3-4b-it, gemma-3-12b-it, gemma-3-27b-it
Signed-off-by: yaoyu-33 <[email protected]>
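For context, a sketch of the activation table that the gelu_pytorch_tanh entry lands in; only that entry comes from the commit above, the other entries are assumptions, and `fast_gelu` is written out as a stand-in rather than imported from its real Megatron location:

```python
import torch
import torch.nn.functional as F


def fast_gelu(x: torch.Tensor) -> torch.Tensor:
    # Stand-in for Megatron's tanh-approximate GELU (actual import path not shown here).
    return 0.5 * x * (1.0 + torch.tanh(0.79788456 * x * (1.0 + 0.044715 * x * x)))


# HF activation string -> Megatron activation callable.
# Only the gelu_pytorch_tanh entry is from this commit; the rest are illustrative.
ACTIVATION_MAPPING = {
    "silu": F.silu,
    "gelu": F.gelu,
    "relu": F.relu,
    "gelu_pytorch_tanh": fast_gelu,
}
```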
```python
for hf_name, megatron_func in cls.ACTIVATION_MAPPING.items():
    if activation_func is megatron_func:
        return hf_name
# Default to silu if not found
```
Review comment: raise an error if not found.
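A minimal sketch of the change being requested, assuming the quoted lines sit inside the base class's `megatron_to_hf_activation()` helper:

```python
@classmethod
def megatron_to_hf_activation(cls, activation_func) -> str:
    """Map a Megatron activation callable back to its HF string name."""
    for hf_name, megatron_func in cls.ACTIVATION_MAPPING.items():
        if activation_func is megatron_func:
            return hf_name
    # Raise instead of silently defaulting to "silu".
    raise ValueError(
        f"Unknown activation function {activation_func!r}; add it to ACTIVATION_MAPPING."
    )
```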
```python
        return provider_class(**provider_kwargs)

    @classmethod
    def megatron_to_hf_config(cls, provider) -> dict:
```
Review comment: combine this with `provider_to_hf_config`.
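One possible shape of the consolidation being suggested; the body here is an assumption, not the PR's implementation:

```python
@classmethod
def provider_to_hf_config(cls, provider) -> dict:
    """Single Megatron provider -> HF config entry point (absorbs megatron_to_hf_config)."""
    hf_config = {}
    for hf_name, megatron_name in cls.CONFIG_MAPPING.items():
        if hasattr(provider, megatron_name):
            hf_config[hf_name] = getattr(provider, megatron_name)
    # Assumed field names: HF uses hidden_act, Megatron providers carry activation_func.
    hf_config["hidden_act"] = cls.megatron_to_hf_activation(provider.activation_func)
    hf_config.update(cls.HF_DEFAULTS)
    return hf_config
```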
```python
    *,
    source: Type[PreTrainedModel] | str,
    target: Type[MegatronModel],
    provider: Optional[Type[ModelProviderTarget]] = None,
```
Review comment: do not use `Optional`; use `Type[ModelProviderTarget] | None` instead. Same for the other changes.
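The annotation style being requested, shown on a hypothetical signature (`register_bridge_old`/`register_bridge_new` are stand-in names for illustration):

```python
from typing import Optional, Type

# Before
def register_bridge_old(*, provider: Optional[Type["ModelProviderTarget"]] = None): ...

# After: PEP 604 union syntax; typing.Optional is no longer needed
def register_bridge_new(*, provider: Type["ModelProviderTarget"] | None = None): ...
```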
```python
    ffn_hidden_size: int = 12288
    num_moe_experts: int = 160
    moe_ffn_hidden_size: int = 1536
    moe_shared_expert_intermediate_size: int = 3072  # 1536 * 2 shared experts
```
Review comment: revert these changes.
```python
@dataclass
class DeepSeekV2ModelProvider(DeepSeekModelProvider):
    """
```
Review comment: revert the docstring changes.
Refactor `provider_bridge` for Llama and Qwen models

Summary
This PR refactors the model bridge architecture to centralize model-specific configurations within the `provider_bridge` method of each model bridge. It eliminates redundant per-model provider classes (e.g., `LlamaModelProvider`, `Qwen2ModelProvider`) in favor of using `GPTModelProvider` directly with model-specific defaults defined in bridge classes.

Motivation
Before: each model family defined its own provider subclass (e.g., `LlamaModelProvider`) and its bridge overrode `provider_bridge()` to construct and populate it.

After: bridges declare `MEGATRON_DEFAULTS` and `HF_DEFAULTS` class attributes, and the base class builds a `GPTModelProvider` directly; `provider_bridge()` is overridden only for genuinely special handling (e.g., Llama 3.1/3.2 RoPE scaling).
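A rough sketch of the contrast, using Qwen2 as the example; the class bodies and field values are illustrative, and the import paths are assumed from the "Files Changed" list further down:

```python
from dataclasses import dataclass

# Assumed import paths; adjust to the real package layout.
from megatron.bridge.models.conversion.model_bridge import MegatronModelBridge
from megatron.bridge.models.gpt_provider import GPTModelProvider  # assumed module


# Before: a dedicated provider subclass plus a hand-written bridge override.
@dataclass
class Qwen2ModelProvider(GPTModelProvider):
    add_qkv_bias: bool = True  # model-specific default lived on the provider


class Qwen2Bridge(MegatronModelBridge):
    def provider_bridge(self, hf_config):
        return Qwen2ModelProvider(
            num_layers=hf_config.num_hidden_layers,
            hidden_size=hf_config.hidden_size,
            # ... every remaining field mapped by hand ...
        )


# After: no per-model provider class; the bridge only declares defaults and the
# base class's default provider_bridge() returns a GPTModelProvider built from
# CONFIG_MAPPING plus these defaults.
class Qwen2Bridge(MegatronModelBridge):  # noqa: F811 (redefined for contrast)
    MEGATRON_DEFAULTS = {"add_qkv_bias": True}
    HF_DEFAULTS = {"hidden_act": "silu"}  # illustrative
```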
This simplifies the codebase by:
- Using `GPTModelProvider` directly with model-specific defaults defined in bridges

Key Changes
Base Class Enhancements (`MegatronModelBridge`)

- `CONFIG_MAPPING`: common bidirectional field-name mapping (HF ↔ Megatron). MoE-related field mappings added to the base class:
  - `num_experts` → `num_moe_experts`
  - `num_experts_per_tok` → `moe_router_topk`
  - `moe_intermediate_size` → `moe_ffn_hidden_size`
- `ACTIVATION_MAPPING`: common activation function mapping (silu, gelu, relu, tanh)
- Helper methods:
  - `hf_config_to_provider_kwargs()`: HF config → Megatron provider kwargs
  - `provider_to_hf_config()`: Megatron provider → HF config dict
  - `hf_to_megatron_activation()`: activation string → function
  - `megatron_to_hf_activation()`: activation function → string
- Default `provider_bridge()` implementation: subclasses no longer need to override it unless special handling is required
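A minimal sketch of how the default `provider_bridge()` might tie these helpers together; the body is an assumption based on the helper names above and on the `return provider_class(**provider_kwargs)` line visible in the review diff, not the actual implementation:

```python
def provider_bridge(self, hf_config):
    """Default HF config -> Megatron provider conversion (sketch)."""
    # 1. Translate HF field names to Megatron names via CONFIG_MAPPING.
    provider_kwargs = self.hf_config_to_provider_kwargs(hf_config)
    # 2. Translate the HF activation string (e.g. "silu") into a Megatron callable.
    provider_kwargs["activation_func"] = self.hf_to_megatron_activation(hf_config.hidden_act)
    # 3. Layer the bridge's model-specific defaults on top.
    provider_kwargs.update(self.MEGATRON_DEFAULTS)
    return GPTModelProvider(**provider_kwargs)
```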
Model Bridge Refactoring

- `LlamaBridge`: uses `MEGATRON_DEFAULTS` and `HF_DEFAULTS` class attributes; overrides `provider_bridge()` only for RoPE scaling (Llama 3.1/3.2)
- `Qwen2Bridge`: uses `MEGATRON_DEFAULTS` (`add_qkv_bias=True`) and `HF_DEFAULTS`; no `provider_bridge()` override needed
- `Qwen3Bridge`: uses `MEGATRON_DEFAULTS` (`qk_layernorm=True`) and `HF_DEFAULTS`; no `provider_bridge()` override needed
- `Qwen3MoEBridge`: uses `MEGATRON_DEFAULTS` with MoE settings and `HF_DEFAULTS`; no `provider_bridge()` override needed

Code Style: Minimal Overrides
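To illustrate the minimal-override style: a sketch of what the refactored `LlamaBridge` could look like, keeping only the RoPE-scaling special case. The default values, the provider field names for RoPE scaling, and the `super()` delegation are assumptions rather than the PR's exact code:

```python
class LlamaBridge(MegatronModelBridge):
    # Model-specific defaults; everything else flows through CONFIG_MAPPING.
    MEGATRON_DEFAULTS = {"gated_linear_unit": True}  # illustrative
    HF_DEFAULTS = {"hidden_act": "silu"}             # illustrative

    def provider_bridge(self, hf_config):
        provider = super().provider_bridge(hf_config)
        # Llama 3.1/3.2 HF configs carry a rope_scaling block with
        # rope_type == "llama3"; this is the only per-model special case left.
        rope_scaling = getattr(hf_config, "rope_scaling", None) or {}
        if rope_scaling.get("rope_type") == "llama3":
            provider.rope_scaling = True                            # assumed field
            provider.rope_scaling_factor = rope_scaling["factor"]   # assumed field
        return provider
```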
Files Changed
Core

- `src/megatron/bridge/models/conversion/model_bridge.py` - added `CONFIG_MAPPING`, `ACTIVATION_MAPPING`, helper methods, and the default `provider_bridge()`
- `src/megatron/bridge/models/llama/llama_bridge.py` - simplified to use the base class
- `src/megatron/bridge/models/qwen/qwen2_bridge.py` - simplified to use the base class
- `src/megatron/bridge/models/qwen/qwen3_bridge.py` - simplified to use the base class
- `src/megatron/bridge/models/qwen/qwen3_moe_bridge.py` - simplified to use the base class

Tests
- `tests/unit_tests/models/llama/test_llama_bridge.py` - updated to test the new structure; expects `GPTModelProvider`
- `tests/unit_tests/models/qwen/test_qwen3_bridge.py` - updated to expect `GPTModelProvider`
- `tests/unit_tests/models/qwen/test_qwen3_moe_bridge.py` - updated to expect `GPTModelProvider`

Verification Scripts
- `scripts/verify_llama_provider_refactor.py` - verifies Llama conversion parity
- `scripts/verify_qwen_provider_refactor.py` - verifies Qwen conversion parity

Verification
Tested Models
Running Verification
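The exact invocation isn't reproduced here; a plausible way to run the scripts listed above (any flags beyond the script paths are unknown, so none are shown):

```python
# Equivalent shell commands:
#   python scripts/verify_llama_provider_refactor.py
#   python scripts/verify_qwen_provider_refactor.py
import subprocess

for script in (
    "scripts/verify_llama_provider_refactor.py",
    "scripts/verify_qwen_provider_refactor.py",
):
    subprocess.run(["python", script], check=True)
```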
Breaking Changes
- `provider_bridge()` now returns `GPTModelProvider` instead of model-specific providers (e.g., `LlamaModelProvider`, `Qwen2ModelProvider`)
- The `provider_bridge()` return type is still compatible

Migration Guide
For downstream code that was type-checking for specific provider classes:
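The PR's actual snippet isn't reproduced here; a sketch of the kind of change downstream type checks need, with `bridge` and `hf_config` as hypothetical stand-ins and an assumed import path:

```python
from megatron.bridge.models.gpt_provider import GPTModelProvider  # assumed import path

provider = bridge.provider_bridge(hf_config)

# Before: checks against per-model provider classes no longer hold.
# assert isinstance(provider, LlamaModelProvider)

# After: check against the shared provider type, or duck-type on the fields you need.
assert isinstance(provider, GPTModelProvider)
assert provider.num_layers == hf_config.num_hidden_layers
```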
Design Document

See `docs/proposals/provider_bridge_refactor.md` for the complete refactoring guide and design principles.

Checklist
- `CONFIG_MAPPING` covers common HF ↔ Megatron field mappings
- `ACTIVATION_MAPPING` covers common activation functions
- `LlamaBridge` refactored with minimal override (RoPE scaling only)
- `Qwen2Bridge` refactored with no override needed
- `Qwen3Bridge` refactored with no override needed
- `Qwen3MoEBridge` refactored with no override needed

Future Work
Other models to refactor using this pattern: