Skip to content

[0.9.1][Dist][Bugfix] Fix mc2 process group to resolve self.cpu_group is None #1831

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 1 commit into
base: v0.9.1-dev
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 7 additions & 7 deletions vllm_ascend/distributed/parallel_state.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,22 +19,22 @@ def model_parallel_initialized():

def init_ascend_model_parallel(
expert_parallel_size: int = 1,
world_size: Optional[int] = None,
backend: Optional[str] = None,
):
if model_parallel_initialized():
return
assert torch.distributed.is_initialized()
world_size = world_size or torch.distributed.get_world_size()
world_size = torch.distributed.get_world_size()
backend = backend or torch.distributed.get_backend(
get_world_group().device_group)
num_expert_parallel_groups = world_size // expert_parallel_size

# The layout of all ranks: ExternalDP * EP
# ExternalDP is the data parallel group that is not part of the model,
# every dp rank can generate independently (in verl integration).
all_ranks = torch.arange(world_size).reshape(-1, expert_parallel_size)
global _MC2
group_ranks = []
for i in range(num_expert_parallel_groups):
ranks = list(range(i, world_size, num_expert_parallel_groups))
group_ranks.append(ranks)
group_ranks = all_ranks.unbind(0)
group_ranks = [x.tolist() for x in group_ranks]

_MC2 = init_model_parallel_group(group_ranks,
get_world_group().local_rank,
Expand Down
3 changes: 1 addition & 2 deletions vllm_ascend/worker/worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -546,8 +546,7 @@ def _init_worker_distributed_environment(
ensure_model_parallel_initialized(
parallel_config.tensor_parallel_size,
parallel_config.pipeline_parallel_size)
init_ascend_model_parallel(parallel_config.expert_parallel_size,
parallel_config.world_size_across_dp)
init_ascend_model_parallel(parallel_config.expert_parallel_size)
ensure_kv_transfer_initialized(vllm_config)


Expand Down
3 changes: 1 addition & 2 deletions vllm_ascend/worker/worker_v1.py
Original file line number Diff line number Diff line change
Expand Up @@ -261,8 +261,7 @@ def _init_worker_distributed_environment(self) -> None:
ensure_model_parallel_initialized(
self.parallel_config.tensor_parallel_size,
self.parallel_config.pipeline_parallel_size)
init_ascend_model_parallel(self.parallel_config.expert_parallel_size,
self.parallel_config.world_size_across_dp)
init_ascend_model_parallel(self.parallel_config.expert_parallel_size)
ensure_kv_transfer_initialized(self.vllm_config)

def _init_profiler(self):
Expand Down