Commit 8312d19

Author: Ali Roshan Ghias
Parent: 37dd94f

feat(mimo): phase 2 (model provider, DDP wrapping, process groups)

Signed-off-by: Ali Roshan Ghias <[email protected]>

8 files changed: +796 −13 lines


src/megatron/bridge/models/__init__.py

Lines changed: 6 additions & 0 deletions
@@ -210,6 +210,10 @@
     EncoderTransformerConfig,
     GenericVisionEncoderProvider,
 )
+from megatron.bridge.models.mimo_provider import (
+    MIMOModelProvider,
+    MIMOModelProviderResult,
+)
 from megatron.bridge.models.t5_provider import T5ModelProvider


@@ -253,6 +257,8 @@
     "EncoderProvider",
     "EncoderTransformerConfig",
     "GenericVisionEncoderProvider",
+    "MIMOModelProvider",
+    "MIMOModelProviderResult",
     "LlamaModelProvider",
     "Llama2ModelProvider7B",
     "Llama2ModelProvider13B",
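
With these re-exports in place, downstream code can pull the new MIMO symbols from the package root instead of the private module path. A minimal sketch (only the import is shown; the constructor signature of MIMOModelProvider lives in mimo_provider.py, which this diff does not display):

    # Both names now resolve via the package root, per the __all__ update above.
    from megatron.bridge.models import MIMOModelProvider, MIMOModelProviderResult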

src/megatron/bridge/models/encoder_provider.py

Lines changed: 60 additions & 7 deletions
@@ -2,45 +2,98 @@

 from abc import ABC, abstractmethod
 from dataclasses import dataclass
-from typing import Optional
+from typing import Any, Optional

 from megatron.core.transformer.spec_utils import ModuleSpec


 @dataclass
 class EncoderTransformerConfig:
-    """Lightweight base config for encoder providers."""
+    """Lightweight base config for encoder providers.
+
+    Attributes:
+        num_layers: Number of transformer layers in the encoder.
+        hidden_size: Hidden dimension size of the encoder.
+        num_attention_heads: Number of attention heads.
+        seq_length: Sequence length for the encoder.
+        projector_type: Type of projector (e.g., "mlp", "linear", "qformer").
+            None means no projection is needed.
+        projector_input_size: Input size for projector. Defaults to hidden_size.
+        projector_output_size: Output size for projector (e.g., LLM hidden size).
+            Required if projector_type is set.
+        projector_config: Optional TransformerConfig for the projector module.
+    """

     num_layers: int
     hidden_size: int
     num_attention_heads: int
     seq_length: int

+    # Projector support for VLM setups
+    projector_type: Optional[str] = None
+    projector_input_size: Optional[int] = None
+    projector_output_size: Optional[int] = None
+    projector_config: Optional[Any] = None
+
+    def __post_init__(self) -> None:
+        """Set default projector_input_size to hidden_size if not specified."""
+        if self.projector_input_size is None:
+            self.projector_input_size = self.hidden_size
+

 class EncoderProvider(ABC):
-    """Interface for encoder providers used in MIMO setups."""
+    """Interface for encoder providers used in MIMO setups.
+
+    Subclasses must set the `config` attribute to an EncoderTransformerConfig.
+    """
+
+    config: EncoderTransformerConfig

     @abstractmethod
     def provide_model(self, pg_collection) -> object:
         """Create the encoder module (unwrapped)."""

     @abstractmethod
     def get_transformer_layer_spec(self) -> ModuleSpec:
-        """Return the ModuleSpec for the encoder stack."""
+        """Return the ModuleSpec for the encoder transformer layers."""

     @abstractmethod
     def get_projection_spec(self) -> Optional[ModuleSpec]:
-        """Optional projection ModuleSpec for encoder outputs."""
+        """Optional projection ModuleSpec for encoder outputs.
+
+        Returns None if no projection is needed.
+        """
+
+    def has_projector(self) -> bool:
+        """Check if this encoder requires a projector."""
+        return self.config.projector_type is not None
+
+    def validate_projector_config(self) -> None:
+        """Validate projector configuration consistency.
+
+        Raises:
+            ValueError: If projector_type is set but required fields are missing,
+                or if projector_type is set but get_projection_spec() returns None.
+        """
+        if self.config.projector_type is not None:
+            if self.config.projector_output_size is None:
+                raise ValueError(
+                    f"projector_output_size must be set when projector_type='{self.config.projector_type}'"
+                )
+            if self.get_projection_spec() is None:
+                raise ValueError(
+                    f"get_projection_spec() must return a ModuleSpec when "
+                    f"projector_type='{self.config.projector_type}'"
+                )


 class GenericVisionEncoderProvider(EncoderProvider):
-    """Minimal stub encoder provider for Phase 1 wiring."""
+    """Minimal stub encoder provider for Phase 1/2 wiring."""

     def __init__(self, config: EncoderTransformerConfig) -> None:
         self.config = config

     def provide_model(self, pg_collection) -> object:
-        # Stub: actual encoder creation will be implemented in Phase 2.
         raise NotImplementedError("GenericVisionEncoderProvider.provide_model not implemented.")

     def get_transformer_layer_spec(self) -> ModuleSpec:
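
Taken together, the new projector fields and validation hooks let a concrete encoder declare whether its outputs need projecting into the LLM's hidden space. A minimal sketch of how a caller might exercise them; the sizes below are illustrative, and a real subclass would implement get_projection_spec() to return a ModuleSpec whenever a projector is configured:

    from megatron.bridge.models import EncoderTransformerConfig, GenericVisionEncoderProvider

    # Illustrative sizes only; a real VLM setup would take these from its checkpoints.
    config = EncoderTransformerConfig(
        num_layers=24,
        hidden_size=1024,            # encoder hidden dim
        num_attention_heads=16,
        seq_length=576,
        projector_type="mlp",
        projector_output_size=4096,  # e.g., the LLM hidden size
    )

    # __post_init__ defaults projector_input_size to hidden_size.
    assert config.projector_input_size == 1024

    provider = GenericVisionEncoderProvider(config)
    assert provider.has_projector()

    # validate_projector_config() raises ValueError if projector_output_size is
    # unset, or if get_projection_spec() returns None while projector_type is set.
    # The stub provider's get_projection_spec() is truncated in this diff, so a
    # concrete subclass is needed before this call can succeed:
    # provider.validate_projector_config()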
