Skip to content

Commit e125823

Browse files
author
白永斌
committed
fix pre-commit
Signed-off-by: 白永斌 <[email protected]>
1 parent bc45968 commit e125823

File tree

4 files changed

+46
-40
lines changed

4 files changed

+46
-40
lines changed

vllm/distributed/eplb/eplb_expert_mapper.py

Lines changed: 10 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,11 @@
11
# SPDX-License-Identifier: Apache-2.0
22
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3+
34
import networkx as nx
45
import numpy as np
56
import torch
67

78

8-
from typing import Dict, List
9-
109
class ComposeExpertUpdate:
1110
def __init__(self, updated_expert_maps, current_expert_maps):
1211
self.updated_org = updated_expert_maps
@@ -39,8 +38,8 @@ def generate(self):
3938
updated_layer = self.updated[layer_id]
4039
current_layer = self.current[layer_id]
4140

42-
expert_send_info_this_layer: Dict[int, List[int]] = {}
43-
expert_recv_info_this_layer: Dict[int, List[int]] = {}
41+
expert_send_info_this_layer: dict[int, list[int]] = {}
42+
expert_recv_info_this_layer: dict[int, list[int]] = {}
4443

4544
# Guard Clause: if there is no expert weight update,
4645
# avoid subsequent processing.
@@ -49,7 +48,7 @@ def generate(self):
4948
expert_send_info_this_layer,
5049
expert_recv_info_this_layer,
5150
self._map_to_yield(layer_id),
52-
layer_id
51+
layer_id,
5352
)
5453

5554
# Main planning
@@ -66,7 +65,7 @@ def generate(self):
6665
expert_send_info_this_layer,
6766
expert_recv_info_this_layer,
6867
self._map_to_yield(layer_id),
69-
layer_id
68+
layer_id,
7069
)
7170

7271

@@ -97,9 +96,9 @@ def _plan_transfers(
9796
for idx in range(len(dst_rank_indices)):
9897
expert_id = experts_to_recv[idx].item()
9998
if expert_id not in src_ranks_set:
100-
src_ranks_set[expert_id] = np.where(
101-
current_layer[:, expert_id] != -1
102-
)[0]
99+
src_ranks_set[expert_id] = np.where(current_layer[:, expert_id] != -1)[
100+
0
101+
]
103102

104103
# Loop until all experts are scheduled
105104
while len(dst_rank_indices) > 0:
@@ -129,7 +128,7 @@ def _plan_transfers(
129128
for src_rank, dst_rank in all_matches.items():
130129
dst_rank = int(dst_rank)
131130
assert src_rank != dst_rank
132-
if graph_expert_update.nodes[src_rank].get('bipartite') == 0:
131+
if graph_expert_update.nodes[src_rank].get("bipartite") == 0:
133132
# experts destined for rank dst_rank that are not yet scheduled
134133
experts_v = experts_to_recv[np.where(dst_rank_indices == dst_rank)]
135134
# src: src_rank, dest: dst_rank, expert: expert_id
@@ -161,6 +160,7 @@ def _plan_transfers(
161160

162161
class GreedyExpertUpdate(ComposeExpertUpdate):
163162
"""Greedy version."""
163+
164164
def _prepare_internal(self, updated, current):
165165
# align devices
166166
if not torch.is_tensor(updated):
@@ -209,4 +209,3 @@ def _plan_transfers(
209209

210210
send_dict[src_rank_id].append((dst_rank_id, expert_id))
211211
recv_dict[dst_rank_id].append((src_rank_id, expert_id))
212-

vllm/distributed/eplb/eplb_process.py

Lines changed: 10 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -2,9 +2,9 @@
22
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
33
import multiprocessing as mp
44
import random
5+
from collections.abc import Callable
56
from contextlib import suppress
67
from multiprocessing import Queue
7-
from typing import Callable
88
from queue import Empty
99

1010
import torch
@@ -13,7 +13,7 @@
1313
from vllm.logger import init_logger
1414

1515
from .eplb_expert_mapper import BipartiteExpertUpdate, GreedyExpertUpdate
16-
from .eplb_state import RebalanceTaskArgs, ExpertMapperArgs
16+
from .eplb_state import ExpertMapperArgs, RebalanceTaskArgs
1717

1818
logger = init_logger(__name__)
1919

@@ -44,7 +44,6 @@ def __init__(self, target_func: Callable, num_wait_worker_iterations: int):
4444
self._exception_queue: Queue | None = None
4545
self._step_counter = 0
4646
self._result: tuple | None = None
47-
self._args: tuple | None = None
4847
self._is_running = False
4948
self._has_pending_task = False
5049
self._is_post_processing = False
@@ -65,13 +64,13 @@ def _initialize_process(self) -> None:
6564
self._process = mp.Process(
6665
target=self._worker_loop,
6766
name="EPLBProcess",
68-
args=(self._input_queue, self._result_queue, self._exception_queue)
67+
args=(self._input_queue, self._result_queue, self._exception_queue),
6968
)
7069
self._process.start()
7170
self._is_running = True
7271
logger.debug("EPLB background process started")
7372

74-
except Exception as e:
73+
except Exception:
7574
self.cleanup()
7675
raise
7776

@@ -126,9 +125,8 @@ def generate_log2phy_map(self, expert_map):
126125
num_ranks, num_global_expert = log2phy_map.shape
127126

128127
row_indices = (
129-
torch.arange(num_ranks).view(-1, 1).expand(
130-
num_ranks, num_global_expert
131-
) * num_local_experts
128+
torch.arange(num_ranks).view(-1, 1).expand(num_ranks, num_global_expert)
129+
* num_local_experts
132130
)
133131
log2phy_map[log2phy_map != -1] += row_indices[log2phy_map != -1]
134132

@@ -182,7 +180,8 @@ def _worker_loop(
182180
expert_mapper_args.num_moe_layers,
183181
args.num_gpus,
184182
-1,
185-
))
183+
)
184+
)
186185
if policy_type == "bipartite":
187186
update_info = BipartiteExpertUpdate(
188187
new_deployment, old_deployment
@@ -212,9 +211,7 @@ def _worker_loop(
212211
logger.debug("EPLB worker process exiting")
213212

214213
def submit_task(
215-
self,
216-
args: RebalanceTaskArgs,
217-
expert_mapper_args: ExpertMapperArgs
214+
self, args: RebalanceTaskArgs, expert_mapper_args: ExpertMapperArgs
218215
) -> bool:
219216
"""
220217
Submit a task to the asynchronous process
@@ -242,7 +239,6 @@ def submit_task(
242239
# Put arguments to the input queue
243240
combined_args = (args, expert_mapper_args)
244241
self._input_queue.put(combined_args)
245-
self._args = args
246242
self._has_pending_task = True
247243
self._step_counter = 0
248244
self._result = None
@@ -302,7 +298,7 @@ def cleanup(self) -> None:
302298
# Send sentinel value to stop the process
303299
if self._input_queue:
304300
with suppress(Exception):
305-
self._input_queue.put(None, None)
301+
self._input_queue.put(None)
306302

307303
if self._process:
308304
if self._process.is_alive():

vllm/distributed/eplb/eplb_state.py

Lines changed: 24 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,7 @@
5151

5252
logger = init_logger(__name__)
5353

54+
5455
@dataclass
5556
class RebalanceTaskArgs:
5657
global_expert_load_window: torch.Tensor
@@ -62,7 +63,7 @@ class RebalanceTaskArgs:
6263
@dataclass
6364
class ExpertMapperArgs:
6465
num_moe_layers: int
65-
policy_type: Literal["greedy","bipartite"]
66+
policy_type: Literal["greedy", "bipartite"]
6667
phyhsical_to_logical_map: torch.Tensor
6768

6869
@dataclass
@@ -211,7 +212,7 @@ class EplbState:
211212
"""
212213
Records the current moe layer being processed for expert weight transfer.
213214
"""
214-
215+
215216
@staticmethod
216217
def build_initial_global_physical_to_logical_map(
217218
num_routed_experts: int,
@@ -381,9 +382,11 @@ def build(
381382
rank_mapping,
382383
)
383384
expert_rearrangement_step = 0
384-
expert_mapper_args = ExpertMapperArgs()
385-
expert_mapper_args.num_moe_layers = model.num_moe_layers
386-
expert_mapper_args.policy_type = parallel_config.eplb_config.expert_mapper_policy_type
385+
expert_mapper_args = ExpertMapperArgs(
386+
model.num_moe_layers,
387+
parallel_config.eplb_config.expert_mapper_policy_type,
388+
None
389+
)
387390
return cls(
388391
physical_to_logical_map,
389392
logical_to_physical_map,
@@ -393,9 +396,11 @@ def build(
393396
expert_load_window_size=expert_load_window_size,
394397
expert_rearrangement_step=expert_rearrangement_step,
395398
expert_rearrangement_step_interval=eplb_step_interval,
396-
num_wait_worker_iterations=parallel_config.eplb_config.num_wait_worker_iterations,
399+
num_wait_worker_iterations=(
400+
parallel_config.eplb_config.num_wait_worker_iterations
401+
),
397402
enable_async=parallel_config.eplb_config.enable_async,
398-
expert_mapper_args=expert_mapper_args
403+
expert_mapper_args=expert_mapper_args,
399404
)
400405

401406
def __post_init__(self):
@@ -518,7 +523,13 @@ def step(
518523
)
519524
input_args = self.rebalance_task_args
520525

521-
self.expert_mapper_args.phyhsical_to_logical_map = self.physical_to_logical_map.cpu()
526+
assert(
527+
self.expert_mapper_args is not None,
528+
"expert_mapper_args is not initialized",
529+
)
530+
self.expert_mapper_args.phyhsical_to_logical_map = (
531+
self.physical_to_logical_map.cpu()
532+
)
522533
expert_mapper_args = self.expert_mapper_args
523534

524535
self.rebalance_task(input_args, expert_mapper_args)
@@ -529,7 +540,7 @@ def step(
529540
+ self.num_wait_worker_iterations
530541
+ model.num_moe_layers
531542
):
532-
self.expert_rearrangement_step = 0
543+
self.expert_rearrangement_step = 0
533544

534545
def rearrange(
535546
self,
@@ -747,9 +758,10 @@ def get_at_index(self, model, result, layer_id) -> list[Any]:
747758
size = len(result)
748759
# check if queue length matches the number of moe layers
749760
if size != model.num_moe_layers:
750-
logger.info(f"size={size}, num_moe_layers={model.num_moe_layers}")
761+
logger.info("size=%s, num_moe_layers=%s", size, model.num_moe_layers)
751762
raise ValueError(
752-
f"Queue length {size} does not match the number of moe layers in the model"
763+
f"Queue length {size} does not match "
764+
"the number of moe layers in the model"
753765
)
754766
if layer_id < 0 or layer_id >= size:
755767
raise ValueError(f"Index {layer_id} out of range for queue of size {size}")
@@ -884,6 +896,7 @@ def should_trigger_rebalance(self):
884896
return self.expert_rearrangement_step == (
885897
self.expert_rearrangement_step_interval - 1
886898
)
899+
887900
def compute_and_set_moe_load(self):
888901
"""
889902
Computes the MoE load across all ranks and sets it in the shared dictionary.
@@ -992,7 +1005,6 @@ def rebalance_task(self, input_args, expert_mapper_args):
9921005
logger.error("Failed to submit rebalance task to async process")
9931006
return None
9941007

995-
9961008
def __del__(self):
9971009
"""Clean up async process resources"""
9981010
if self._async_processor:

vllm/v1/worker/gpu_model_runner.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2410,9 +2410,8 @@ def execute_model(
24102410
) -> ModelRunnerOutput | AsyncModelRunnerOutput | IntermediateTensors:
24112411
with record_function_or_nullcontext("Preprocess"):
24122412
with self.synchronize_input_prep():
2413-
if self.parallel_config.eplb_config.enable_async:
2414-
if self.eplb_state is not None:
2415-
self.eplb_state.step_before_forward(self.get_model())
2413+
if self.parallel_config.eplb_config.enable_async and self.eplb_state:
2414+
self.eplb_state.step_before_forward(self.get_model())
24162415
# Update persistent batch states.
24172416
self._update_states(scheduler_output)
24182417

0 commit comments

Comments
 (0)