
Commit 310869a

Author: Ali Roshan Ghias (committed)
feat(mimo): phase 2 (model provider, DDP wrapping, process groups)
Signed-off-by: Ali Roshan Ghias <aliroshanghias@nvidia.com>
1 parent 84041e6 commit 310869a

File tree

7 files changed: +846 -7 lines


src/megatron/bridge/models/mimo/mimo_builder.py

Lines changed: 70 additions & 1 deletion
@@ -1,9 +1,14 @@
 from __future__ import annotations
 
-from typing import Dict, List, Optional
+from typing import TYPE_CHECKING, Dict, List, Optional, Tuple
+
+import torch.distributed as dist
 
 from megatron.bridge.training.mimo_config import MimoParallelismConfig
 
+if TYPE_CHECKING:
+    from megatron.core.hyper_comm_grid import HyperCommGrid
+
 
 def build_hypercomm_grids(
     mimo_parallelism_config: MimoParallelismConfig,
@@ -66,3 +71,67 @@ def build_colocated_comm_config(
         topology=topology,
         dim_mapping={"b": 0, "s": 1, "h": 2},
     )
+
+
+def populate_embedding_and_position_groups(
+    pp_group: dist.ProcessGroup,
+) -> Tuple[Optional[dist.ProcessGroup], Optional[dist.ProcessGroup]]:
+    """Create embedding-related process groups from PP group ranks.
+
+    Following MCore semantics:
+    - pos_embd_pg: Only rank 0 of PP (first stage) - for position embeddings
+    - embd_pg: Ranks 0 and -1 of PP (first and last stages) - for tied word embeddings
+
+    IMPORTANT: This calls dist.new_group which is a collective operation.
+    Must be called on all ranks that could participate.
+
+    Args:
+        pp_group: The pipeline parallel process group.
+
+    Returns:
+        Tuple of (pos_embd_pg, embd_pg). Returns (None, None) if pp_group is None.
+    """
+    if pp_group is None:
+        return None, None
+
+    pp_ranks = sorted(dist.get_process_group_ranks(pp_group))
+
+    # Position embeddings only on first PP stage
+    pos_embd_ranks = [pp_ranks[0]]
+    pos_embd_pg = dist.new_group(ranks=pos_embd_ranks)
+
+    # Word embeddings on first and last PP stages (for tied embeddings)
+    embd_ranks = [pp_ranks[0]]
+    if len(pp_ranks) > 1 and pp_ranks[-1] != pp_ranks[0]:
+        embd_ranks.append(pp_ranks[-1])
+    embd_pg = dist.new_group(ranks=embd_ranks)
+
+    return pos_embd_pg, embd_pg
+
+
+def is_pp_first_stage(pp_group: Optional[dist.ProcessGroup]) -> bool:
+    """Check if current rank is first stage in pipeline."""
+    if pp_group is None:
+        return True
+    pp_ranks = sorted(dist.get_process_group_ranks(pp_group))
+    return dist.get_rank() == pp_ranks[0]
+
+
+def is_pp_last_stage(pp_group: Optional[dist.ProcessGroup]) -> bool:
+    """Check if current rank is last stage in pipeline."""
+    if pp_group is None:
+        return True
+    pp_ranks = sorted(dist.get_process_group_ranks(pp_group))
+    return dist.get_rank() == pp_ranks[-1]
+
+
+def is_current_rank_in_grid(grid: "HyperCommGrid") -> bool:
+    """Check if the current rank participates in this grid.
+
+    Args:
+        grid: A HyperCommGrid instance.
+
+    Returns:
+        True if dist.get_rank() is within the grid's rank range.
+    """
+    return grid.rank_offset <= dist.get_rank() < (grid.rank_offset + grid.size)
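Example (not part of the commit): a minimal usage sketch of the new helpers. The torchrun/gloo setup and the choice to treat every rank as one pipeline-parallel group are illustrative assumptions made here; only populate_embedding_and_position_groups, is_pp_first_stage, and is_pp_last_stage come from this diff, and the script assumes this branch of megatron.bridge is installed.

import torch.distributed as dist

from megatron.bridge.models.mimo.mimo_builder import (
    is_pp_first_stage,
    is_pp_last_stage,
    populate_embedding_and_position_groups,
)


def main() -> None:
    # Illustrative only: put every rank of the job into one pipeline-parallel group.
    dist.init_process_group(backend="gloo")
    pp_group = dist.new_group(ranks=list(range(dist.get_world_size())))

    # Collective call: every rank must reach this point, even ranks that will not
    # be members of the resulting embedding groups.
    pos_embd_pg, embd_pg = populate_embedding_and_position_groups(pp_group)

    if is_pp_first_stage(pp_group):
        print(f"rank {dist.get_rank()} holds position embeddings", pos_embd_pg is not None)
    if is_pp_first_stage(pp_group) or is_pp_last_stage(pp_group):
        print(f"rank {dist.get_rank()} is in the tied word-embedding group", embd_pg is not None)

    dist.destroy_process_group()


if __name__ == "__main__":
    main()  # launch with: torchrun --nproc_per_node=2 this_script.py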

src/megatron/bridge/models/mimo/mimo_provider.py

Lines changed: 18 additions & 1 deletion
@@ -20,6 +20,9 @@
     build_hypercomm_grids,
     build_colocated_comm_config,
     _default_topology,
+    populate_embedding_and_position_groups,
+    is_pp_first_stage,
+    is_pp_last_stage,
 )
 
 if TYPE_CHECKING:
@@ -109,6 +112,7 @@ def _get_pg_collections_from_grids(
 ) -> Dict[str, Optional[ProcessGroupCollection]]:
     """Get ProcessGroupCollections from HyperCommGrids.
 
+    Creates all standard process groups plus embedding groups for PP > 1.
     Returns None for modules this rank doesn't participate in.
     """
     pg_collections: Dict[str, Optional[ProcessGroupCollection]] = {}
@@ -117,13 +121,26 @@ def _get_pg_collections_from_grids(
     for module_name, grid in grids.items():
         # Check if current rank is in this grid's range
         if grid.rank_offset <= current_rank < (grid.rank_offset + grid.size):
+            pp_group = grid.get_pg(["pp"])
+
+            # Create embedding groups for PP > 1 (collective operation on all PP ranks)
+            pos_embd_pg, embd_pg = populate_embedding_and_position_groups(pp_group)
+
+            # Only assign embedding groups to ranks that should have them
+            first_stage = is_pp_first_stage(pp_group)
+            last_stage = is_pp_last_stage(pp_group)
+
             pg_collections[module_name] = ProcessGroupCollection(
                 tp=grid.get_pg(["tp"]),
                 dp=grid.get_pg(["dp"]),
-                pp=grid.get_pg(["pp"]),
+                pp=pp_group,
                 cp=grid.get_pg(["cp"]),
                 ep=grid.get_pg(["ep"]),
                 dp_cp=grid.get_pg(["dp", "cp"]),
+                # Position embeddings only on first PP stage
+                pos_embd=pos_embd_pg if first_stage else None,
+                # Word embeddings on first and last PP stages (for tied embeddings)
+                embd=embd_pg if (first_stage or last_stage) else None,
            )
         else:
             pg_collections[module_name] = None
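Example (not part of the commit): the first/last-stage assignment rule above, restated as a tiny standalone helper. embedding_group_membership is a hypothetical function written only to illustrate which global ranks end up with the pos_embd and embd groups.

from typing import List, Tuple


def embedding_group_membership(rank: int, pp_ranks: List[int]) -> Tuple[bool, bool]:
    """Return (has_pos_embd, has_embd) for a global rank, mirroring the rule above."""
    first, last = min(pp_ranks), max(pp_ranks)
    return rank == first, rank in (first, last)


# A 4-stage pipeline over global ranks 0..3:
for r in range(4):
    print(r, embedding_group_membership(r, [0, 1, 2, 3]))
# rank 0 -> (True, True); ranks 1-2 -> (False, False); rank 3 -> (False, True)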

src/megatron/bridge/training/mimo_config.py

Lines changed: 33 additions & 3 deletions
@@ -107,18 +107,48 @@ def _validate_homogeneous(self) -> None:
     def _validate_heterogeneous(self) -> None:
         # "heterogeneous" describes rank placement across distinct modules.
         ranges = []
-        for parallelism in self.module_parallelisms.values():
+        for name, parallelism in self.module_parallelisms.items():
             if parallelism.data_parallel is None:
                 raise ValueError("data_parallel must be set for heterogeneous deployment.")
-            ranges.append((parallelism.rank_offset, parallelism.rank_offset + parallelism.total_ranks))
+            ranges.append((parallelism.rank_offset, parallelism.rank_offset + parallelism.total_ranks, name))
 
-        ranges.sort()
+        ranges.sort(key=lambda x: x[0])
         for idx in range(1, len(ranges)):
             prev_end = ranges[idx - 1][1]
             cur_start = ranges[idx][0]
             if cur_start < prev_end:
                 raise ValueError("rank_offset ranges overlap in heterogeneous deployment.")
 
+        # Check for gaps between modules (likely misconfiguration)
+        # Gaps in the middle are errors; leading gaps (rank_offset > 0) are warnings
+        if ranges:
+            min_rank = ranges[0][0]  # Already sorted by rank_offset
+            max_rank = ranges[-1][1]
+
+            # Collect all covered ranks
+            covered_ranks = set()
+            for parallelism in self.module_parallelisms.values():
+                start = parallelism.rank_offset
+                end = start + parallelism.total_ranks
+                covered_ranks.update(range(start, end))
+
+            # Check for gaps between min and max (error - likely misconfiguration)
+            expected_middle = set(range(min_rank, max_rank))
+            gaps_in_middle = expected_middle - covered_ranks
+            if gaps_in_middle:
+                raise ValueError(
+                    f"Ranks {sorted(gaps_in_middle)} are not assigned to any module in heterogeneous "
+                    f"deployment. This creates a gap between modules which is not allowed."
+                )
+
+            # Check for leading gap (ranks 0 to min_rank-1 unused) - warning only
+            if min_rank > 0:
+                warnings.warn(
+                    f"Ranks {list(range(min_rank))} (before first module) are not assigned to any "
+                    f"module in heterogeneous deployment. These ranks will be idle during training.",
+                    stacklevel=3,
+                )
+
     def finalize(self, world_size: Optional[int]) -> None:
         if self.llm_module_name not in self.module_parallelisms:
             raise ValueError(f"LLM module '{self.llm_module_name}' not in module_parallelisms.")
Lines changed: 96 additions & 0 deletions
@@ -0,0 +1,96 @@
+"""DDP wrapping utilities for MIMO models.
+
+Called from the training layer after MimoModelProvider.provide().
+
+Note: This module only supports DDP wrapping. FSDP is not yet implemented.
+"""
+from __future__ import annotations
+
+from typing import TYPE_CHECKING, Dict, Optional
+
+from megatron.bridge.models.mimo.mimo_builder import is_current_rank_in_grid
+
+if TYPE_CHECKING:
+    from megatron.core.distributed import DistributedDataParallelConfig
+    from megatron.core.hyper_comm_grid import HyperCommGrid
+    from megatron.core.models.mimo import MimoModel
+    from megatron.core.process_groups_config import ProcessGroupCollection
+    from megatron.bridge.training.mimo_config import MimoParallelismConfig
+
+
+def wrap_mimo_model_distributed(
+    mimo_model: "MimoModel",
+    ddp_config: "DistributedDataParallelConfig",
+    mimo_parallelism_config: "MimoParallelismConfig",
+    grids: Dict[str, "HyperCommGrid"],
+    pg_collections: Dict[str, Optional["ProcessGroupCollection"]],
+) -> "MimoModel":
+    """Wrap MIMO model's submodules with DDP.
+
+    Modifies mimo_model in-place and returns it.
+
+    Args:
+        mimo_model: The MimoModel to wrap.
+        ddp_config: DDP configuration from Bridge.
+        mimo_parallelism_config: MIMO parallelism configuration.
+        grids: Module name to HyperCommGrid mapping.
+        pg_collections: Module name to ProcessGroupCollection mapping.
+
+    Returns:
+        The same mimo_model with wrapped submodules.
+    """
+    from megatron.core.distributed import DistributedDataParallel
+
+    llm_name = mimo_parallelism_config.llm_module_name
+
+    # Wrap language model if present and rank participates
+    if mimo_model.language_model is not None:
+        llm_grid = grids.get(llm_name)
+        if llm_grid is not None and is_current_rank_in_grid(llm_grid):
+            llm_pg = pg_collections.get(llm_name)
+            if llm_pg is not None:
+                mimo_model.language_model = DistributedDataParallel(
+                    config=mimo_model.language_model.config,
+                    ddp_config=ddp_config,
+                    module=mimo_model.language_model,
+                    pg_collection=llm_pg,
+                )
+
+    # Wrap modality submodules
+    if hasattr(mimo_model, 'modality_submodules'):
+        for module_name, submodule in mimo_model.modality_submodules.items():
+            if submodule is None:
+                continue
+            module_grid = grids.get(module_name)
+            if module_grid is None:
+                continue
+            if not is_current_rank_in_grid(module_grid):
+                continue
+
+            module_pg = pg_collections.get(module_name)
+            if module_pg is None:
+                continue
+
+            # Get config from first encoder in the submodule.
+            # Note: We use the first encoder's config for DDP bucket sizing.
+            # This assumes all encoders in a modality submodule share similar
+            # parallelism settings, which is typical for MIMO models.
+            if hasattr(submodule, 'encoders') and submodule.encoders:
+                encoder_key = next(iter(submodule.encoders.keys()))
+                first_encoder = submodule.encoders[encoder_key]
+
+                if not hasattr(first_encoder, 'config'):
+                    raise AttributeError(
+                        f"Encoder '{encoder_key}' in modality '{module_name}' does not have "
+                        f"a 'config' attribute. Encoders must be MegatronModule subclasses."
+                    )
+
+                wrapped = DistributedDataParallel(
+                    config=first_encoder.config,
+                    ddp_config=ddp_config,
+                    module=submodule,
+                    pg_collection=module_pg,
+                )
+                mimo_model.modality_submodules[module_name] = wrapped
+
+    return mimo_model
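Example (not part of the commit): the wrapping decision above reduces to "a grid exists for the module, this rank lies inside it, and a ProcessGroupCollection was built for it". A standalone sketch of that rule with a hypothetical FakeGrid stand-in; the new module's import path is not shown in this capture, so nothing is imported from it here.

from dataclasses import dataclass
from typing import Optional


@dataclass
class FakeGrid:
    """Stand-in for HyperCommGrid: just the fields the wrapping decision needs."""
    rank_offset: int
    size: int


def should_wrap(rank: int, grid: Optional[FakeGrid], pg_collection: Optional[object]) -> bool:
    """A submodule is DDP-wrapped only if a grid exists, the rank lies inside it,
    and a ProcessGroupCollection was built for that module on this rank."""
    return (
        grid is not None
        and grid.rank_offset <= rank < grid.rank_offset + grid.size
        and pg_collection is not None
    )


print(should_wrap(3, FakeGrid(rank_offset=0, size=4), object()))  # True: inside the grid
print(should_wrap(4, FakeGrid(rank_offset=0, size=4), object()))  # False: outside the grid
print(should_wrap(1, None, object()))                             # False: module has no grid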
