
Conversation

@yaoyu-33
Contributor

Refactor provider_bridge for Llama and Qwen models

Summary

This PR refactors the model bridge architecture to centralize model-specific configurations within the provider_bridge method of each model bridge. It eliminates redundant per-model provider classes (e.g., LlamaModelProvider, Qwen2ModelProvider) in favor of using GPTModelProvider directly with model-specific defaults defined in bridge classes.

Motivation

Before:

LlamaForCausalLM (HF) 
    ↓ 
LlamaBridge.provider_bridge() 
    ↓ 
LlamaModelProvider (custom subclass)
    ↓
Llama2ModelProvider7B / Llama3ModelProvider / etc.

After:

LlamaForCausalLM (HF) 
    ↓ 
LlamaBridge.provider_bridge()  ← Uses base class + MEGATRON_DEFAULTS
    ↓ 
GPTModelProvider (generic, configured via kwargs)

This simplifies the codebase by:

  1. Eliminating redundant provider classes per model size
  2. Using GPTModelProvider directly with model-specific defaults defined in bridges
  3. Enabling bidirectional config conversion (HF ↔ Megatron)
  4. Leveraging base class helpers for common logic
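To make the new flow concrete, here is a rough usage sketch. The class names and module paths are inferred from this PR's Files Changed list and migration guide; the zero-argument bridge construction and the hf_pretrained stand-in are assumptions, not the actual conversion entry point.

# Sketch only: illustrates the refactored flow, not a verbatim API walkthrough.
from megatron.bridge.models.gpt_provider import GPTModelProvider
from megatron.bridge.models.llama.llama_bridge import LlamaBridge

# `hf_pretrained` stands in for whatever HF-pretrained wrapper the bridge
# receives; the only attribute relied on in this PR's snippets is `.config`.
hf_pretrained = ...

bridge = LlamaBridge()
provider = bridge.provider_bridge(hf_pretrained)

# The result is the generic GPTModelProvider, configured from the HF config
# (via CONFIG_MAPPING) plus the bridge's MEGATRON_DEFAULTS.
assert isinstance(provider, GPTModelProvider)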

Key Changes

Base Class Enhancements (MegatronModelBridge)

  • CONFIG_MAPPING: Common bidirectional field name mapping (HF ↔ Megatron):

    CONFIG_MAPPING = [
        ("num_hidden_layers", "num_layers"),
        ("hidden_size", "hidden_size"),
        ("intermediate_size", "ffn_hidden_size"),
        ("num_attention_heads", "num_attention_heads"),
        ("num_key_value_heads", "num_query_groups"),
        # ... and more
    ]
  • MoE-related field mappings added to base class:

    • num_experts → num_moe_experts
    • num_experts_per_tok → moe_router_topk
    • moe_intermediate_size → moe_ffn_hidden_size
  • ACTIVATION_MAPPING: Common activation function mapping (silu, gelu, relu, tanh)

  • Helper methods (see the sketch after this list):

    • hf_config_to_provider_kwargs(): HF config → Megatron provider kwargs
    • provider_to_hf_config(): Megatron provider → HF config dict
    • hf_to_megatron_activation(): Activation string → function
    • megatron_to_hf_activation(): Activation function → string
  • Default provider_bridge() implementation: Subclasses no longer need to override unless special handling is required
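A minimal sketch of how these helpers could be wired together, assuming a straightforward attribute walk over CONFIG_MAPPING; the stand-in class below is illustrative, not the actual MegatronModelBridge implementation.

# Illustrative stand-in for the base-class helpers; not the real implementation.
class _SketchModelBridge:
    CONFIG_MAPPING = [
        ("num_hidden_layers", "num_layers"),
        ("hidden_size", "hidden_size"),
        ("intermediate_size", "ffn_hidden_size"),
        ("num_attention_heads", "num_attention_heads"),
        ("num_key_value_heads", "num_query_groups"),
    ]
    MEGATRON_DEFAULTS: dict = {}

    @classmethod
    def hf_config_to_provider_kwargs(cls, hf_config) -> dict:
        # HF config -> Megatron provider kwargs, then layer model defaults on top.
        kwargs = {
            megatron_name: getattr(hf_config, hf_name)
            for hf_name, megatron_name in cls.CONFIG_MAPPING
            if hasattr(hf_config, hf_name)
        }
        kwargs.update(cls.MEGATRON_DEFAULTS)
        return kwargs

    @classmethod
    def provider_to_hf_config(cls, provider) -> dict:
        # Reverse direction: Megatron provider attributes -> HF config dict.
        return {
            hf_name: getattr(provider, megatron_name)
            for hf_name, megatron_name in cls.CONFIG_MAPPING
            if hasattr(provider, megatron_name)
        }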

Model Bridge Refactoring

| Model | Changes |
| --- | --- |
| LlamaBridge | Use MEGATRON_DEFAULTS and HF_DEFAULTS class attributes; override provider_bridge() only for RoPE scaling (Llama 3.1/3.2) |
| Qwen2Bridge | Use MEGATRON_DEFAULTS (add_qkv_bias=True) and HF_DEFAULTS; no provider_bridge() override needed |
| Qwen3Bridge | Use MEGATRON_DEFAULTS (qk_layernorm=True) and HF_DEFAULTS; no provider_bridge() override needed |
| Qwen3MoEBridge | Use MEGATRON_DEFAULTS with MoE settings and HF_DEFAULTS; no provider_bridge() override needed |

Code Style: Minimal Overrides

# ✅ Best: No override if no special logic needed
class Qwen2Bridge(MegatronModelBridge):
    MEGATRON_DEFAULTS = {"add_qkv_bias": True, ...}
    HF_DEFAULTS = {"model_type": "qwen2", ...}
    # Base class provider_bridge() handles everything!

# ✅ Good: Override only for model-specific handling
class LlamaBridge(MegatronModelBridge):
    MEGATRON_DEFAULTS = {...}
    HF_DEFAULTS = {...}
    
    def provider_bridge(self, hf_pretrained):
        provider = super().provider_bridge(hf_pretrained)
        # RoPE scaling for Llama 3.1/3.2 only
        if rope_scaling := getattr(hf_pretrained.config, "rope_scaling", None):
            provider._rope_scaling = {...}
        return provider
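The MoE case follows the same shape. A sketch of what Qwen3MoEBridge could look like is below; the specific default keys and values shown are illustrative assumptions, not the ones shipped in this PR.

# ✅ Also good: MoE bridge with no override (sketch; defaults shown are illustrative)
class Qwen3MoEBridge(MegatronModelBridge):
    MEGATRON_DEFAULTS = {"qk_layernorm": True}  # plus MoE-specific settings (assumed)
    HF_DEFAULTS = {"model_type": "qwen3_moe"}
    # num_experts / num_experts_per_tok / moe_intermediate_size come straight from
    # the HF config via the base-class CONFIG_MAPPING, so no provider_bridge()
    # override is needed here.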

Files Changed

Core

  • src/megatron/bridge/models/conversion/model_bridge.py - Added CONFIG_MAPPING, ACTIVATION_MAPPING, helper methods, default provider_bridge()
  • src/megatron/bridge/models/llama/llama_bridge.py - Simplified to use base class
  • src/megatron/bridge/models/qwen/qwen2_bridge.py - Simplified to use base class
  • src/megatron/bridge/models/qwen/qwen3_bridge.py - Simplified to use base class
  • src/megatron/bridge/models/qwen/qwen3_moe_bridge.py - Simplified to use base class

Tests

  • tests/unit_tests/models/llama/test_llama_bridge.py - Updated to test new structure, expect GPTModelProvider
  • tests/unit_tests/models/qwen/test_qwen3_bridge.py - Updated to expect GPTModelProvider
  • tests/unit_tests/models/qwen/test_qwen3_moe_bridge.py - Updated to expect GPTModelProvider

Verification Scripts

  • scripts/verify_llama_provider_refactor.py - Verifies Llama conversion parity
  • scripts/verify_qwen_provider_refactor.py - Verifies Qwen conversion parity

Verification

Tested Models

| Model | Status |
| --- | --- |
| Qwen/Qwen2-0.5B | ✅ PASS |
| Qwen/Qwen2-7B | ✅ PASS |
| Qwen/Qwen3-0.6B | ✅ PASS |
| Qwen/Qwen3-1.7B | ✅ PASS |
| Qwen/Qwen3-30B-A3B (MoE) | ✅ PASS |

Running Verification

# Verify Llama bridges
uv run python scripts/verify_llama_provider_refactor.py

# Verify Qwen bridges
uv run python scripts/verify_qwen_provider_refactor.py
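The scripts themselves are not reproduced in this description; roughly, the parity check they perform can be sketched as follows. The function name, the bridge method signature, and the comparison strategy below are assumptions, not the actual script contents.

# Rough sketch of a config-parity check; not the actual contents of
# scripts/verify_llama_provider_refactor.py / verify_qwen_provider_refactor.py.
from transformers import AutoConfig


def check_parity(bridge, hf_model_id: str) -> None:
    """Compare mapped HF config fields against the kwargs the bridge produces."""
    hf_config = AutoConfig.from_pretrained(hf_model_id)
    kwargs = bridge.hf_config_to_provider_kwargs(hf_config)  # signature assumed

    # Assumes MEGATRON_DEFAULTS do not override fields that come from the HF config.
    mismatches = {}
    for hf_name, megatron_name in bridge.CONFIG_MAPPING:
        if not hasattr(hf_config, hf_name):
            continue
        hf_value = getattr(hf_config, hf_name)
        if kwargs.get(megatron_name) != hf_value:
            mismatches[megatron_name] = (hf_value, kwargs.get(megatron_name))

    if mismatches:
        raise AssertionError(f"HF/Megatron mismatch for {hf_model_id}: {mismatches}")
    print(f"{hf_model_id}: PASS")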

Breaking Changes

  • Tests updated to expect GPTModelProvider instead of model-specific providers (e.g., LlamaModelProvider, Qwen2ModelProvider)
  • No public API changes: the provider_bridge() return type is still compatible

Migration Guide

For downstream code that was type-checking for specific provider classes:

# Before
from megatron.bridge.models.llama.llama_provider import LlamaModelProvider
assert isinstance(provider, LlamaModelProvider)

# After
from megatron.bridge.models.gpt_provider import GPTModelProvider
assert isinstance(provider, GPTModelProvider)

Design Document

See docs/proposals/provider_bridge_refactor.md for the complete refactoring guide and design principles.

Checklist

  • Base class CONFIG_MAPPING covers common HF ↔ Megatron field mappings
  • Base class ACTIVATION_MAPPING covers common activation functions
  • MoE-related mappings added to base class
  • LlamaBridge refactored with minimal override (RoPE scaling only)
  • Qwen2Bridge refactored with no override needed
  • Qwen3Bridge refactored with no override needed
  • Qwen3MoEBridge refactored with no override needed
  • Unit tests updated
  • Verification scripts added
  • Tested on remote server
  • Run full CI test suite
  • Update documentation for other model bridges

Future Work

Other models to refactor using this pattern:

  • Mistral
  • Phi
  • Gemma
  • DeepSeek

This refactoring centralizes model-specific configurations within the
provider_bridge method of each model bridge.

Changes:
- Add MoE-related field mappings to base class CONFIG_MAPPING:
  - num_experts -> num_moe_experts
  - num_experts_per_tok -> moe_router_topk
  - moe_intermediate_size -> moe_ffn_hidden_size

- Refactor LlamaBridge:
  - Use MEGATRON_DEFAULTS and HF_DEFAULTS class attributes
  - Override provider_bridge only for RoPE scaling (Llama 3.1/3.2)

- Refactor Qwen2Bridge:
  - Use MEGATRON_DEFAULTS (add_qkv_bias=True) and HF_DEFAULTS
  - No provider_bridge override needed

- Refactor Qwen3Bridge:
  - Use MEGATRON_DEFAULTS (qk_layernorm=True) and HF_DEFAULTS
  - No provider_bridge override needed

- Refactor Qwen3MoEBridge:
  - Use MEGATRON_DEFAULTS with MoE settings and HF_DEFAULTS
  - No provider_bridge override needed

- Update tests to expect GPTModelProvider instead of model-specific providers
- Add verification scripts for both Llama and Qwen bridges

Verified on remote server:
- Qwen/Qwen2-0.5B: PASS
- Qwen/Qwen2-7B: PASS
- Qwen/Qwen3-0.6B: PASS
- Qwen/Qwen3-1.7B: PASS
- Qwen/Qwen3-30B-A3B: PASS
copy-pr-bot (bot) commented Jan 23, 2026: This pull request requires additional validation before any workflows can run on NVIDIA's runners.

…dels

- Add MLAModelProvider as unified base for Multi-Latent Attention models
- Refactor DeepSeek V2/V3 bridges to use MLAModelProvider
- Refactor Kimi K2 bridge to use MLAModelProvider
- Move model-specific defaults from providers to MEGATRON_DEFAULTS in bridges
- Add model_type parameter to @register_bridge decorator for auto HF config
- Simplify provider files to deprecated backward-compatible aliases

Verified: DeepSeek-V2-Lite, DeepSeek-V2, DeepSeek-V3, Moonlight-16B, Kimi-K2
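How the new model_type argument to @register_bridge might be used is sketched below. The DeepSeek class names, the string form of source, the provider argument, and the idea that model_type seeds the auto-built HF config are assumptions based on this commit message and the decorator signature quoted later in the review.

# Sketch only: names and the exact effect of model_type are assumptions.
@register_bridge(
    source="DeepseekV2ForCausalLM",   # HF model class, by name (str accepted per the signature)
    target=GPTModel,                  # Megatron target class (assumed)
    provider=MLAModelProvider,        # unified MLA provider introduced in this commit
    model_type="deepseek_v2",         # used to build the HF config automatically (assumed)
)
class DeepSeekV2Bridge(MegatronModelBridge):
    MEGATRON_DEFAULTS = {...}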
- Register GemmaModelProvider, Gemma2ModelProvider, Gemma3ModelProvider via decorator
- Add MEGATRON_DEFAULTS to Gemma/Gemma2 bridges for explicit config defaults
- Add gelu_pytorch_tanh -> fast_gelu to ACTIVATION_MAPPING in model_bridge.py
- Add verification script for Gemma provider refactoring

Verified: gemma-2b, gemma-7b, gemma-2-2b, gemma-2-9b, gemma-2-27b,
         gemma-3-4b-it, gemma-3-12b-it, gemma-3-27b-it
Signed-off-by: yaoyu-33 <[email protected]>
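For reference, the gelu_pytorch_tanh entry above implies a mapping shaped roughly like the sketch below. fast_gelu here is a local stand-in, not Megatron's actual fused implementation, and the other entries simply mirror the activations listed in the PR description.

# Sketch of the ACTIVATION_MAPPING shape; fast_gelu is a local stand-in.
import math

import torch
import torch.nn.functional as F


def fast_gelu(x: torch.Tensor) -> torch.Tensor:
    # tanh-approximated GELU, which is what HF's "gelu_pytorch_tanh" denotes
    return 0.5 * x * (1.0 + torch.tanh(math.sqrt(2.0 / math.pi) * (x + 0.044715 * x * x * x)))


ACTIVATION_MAPPING = {
    "silu": F.silu,
    "gelu": F.gelu,
    "relu": F.relu,
    "tanh": torch.tanh,
    "gelu_pytorch_tanh": fast_gelu,
}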
        for hf_name, megatron_func in cls.ACTIVATION_MAPPING.items():
            if activation_func is megatron_func:
                return hf_name
        # Default to silu if not found
@yaoyu-33 (Contributor, Author): Raise an error if the activation is not found, rather than silently defaulting to silu.
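A minimal sketch of the suggested change, keeping the names from the snippet above (the surrounding MegatronModelBridge class body is omitted):

# Sketch of the reviewer's suggestion: raise instead of silently defaulting.
@classmethod
def megatron_to_hf_activation(cls, activation_func) -> str:
    for hf_name, megatron_func in cls.ACTIVATION_MAPPING.items():
        if activation_func is megatron_func:
            return hf_name
    raise ValueError(f"No HF activation name registered for {activation_func!r}")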

        return provider_class(**provider_kwargs)

    @classmethod
    def megatron_to_hf_config(cls, provider) -> dict:
@yaoyu-33 (Contributor, Author): Combine this with provider_to_hf_config.

    *,
    source: Type[PreTrainedModel] | str,
    target: Type[MegatronModel],
    provider: Optional[Type[ModelProviderTarget]] = None,
@yaoyu-33 (Contributor, Author), Jan 26, 2026: Do not use Optional; use Type[ModelProviderTarget] | None instead. Same for the other changes.
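That is, the parameters shown above would become (sketch of the review suggestion):

    # Suggested spelling (sketch): drop Optional in favor of "| None"
    *,
    source: Type[PreTrainedModel] | str,
    target: Type[MegatronModel],
    provider: Type[ModelProviderTarget] | None = None,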

ffn_hidden_size: int = 12288
num_moe_experts: int = 160
moe_ffn_hidden_size: int = 1536
moe_shared_expert_intermediate_size: int = 3072 # 1536 * 2 shared experts
@yaoyu-33 (Contributor, Author): Revert these changes.


@dataclass
class DeepSeekV2ModelProvider(DeepSeekModelProvider):
"""
@yaoyu-33 (Contributor, Author): Revert the docstring changes.

Signed-off-by: yaoyu-33 <[email protected]>