src/anemoi/inference/runners/parallel.py (60 changes: 50 additions & 10 deletions)
@@ -276,6 +276,26 @@ def _bootstrap_processes(self) -> None:
LOG.warning(
f"world size ({self.config.world_size}) set in the config is ignored because we are launching via srun, using 'SLURM_NTASKS' instead"
)
elif "RANK" in os.environ and "WORLD_SIZE" in os.environ:
Contributor:
I am worried RANK and WORLD_SIZE are too generic, and we would come in here too often by mistake. Also, this block will fail if MASTER_ADDR/PORT are not set.

What about changing this line to

elif "MASTER_ADDR" in os.environ and "WORLD_SIZE" in os.environ:

# TODO(refactor): Extract AzureML/general env bootstrap (RANK/WORLD_SIZE/LOCAL_RANK and
# MASTER_ADDR/MASTER_PORT wiring) into a delegated ClusterEnvironment
# (e.g., AzureMLEnvironment.bootstrap()) in a follow-up PR.
# New branch for Azure ML / general distributed env (e.g., env:// mode)
Contributor:
Can you please remove the mention of Azure ML here?
You have added a way to bootstrap torch dist from env://, one use case of which is Azure ML, but I would rather keep the code more generic and not mention Azure ML.

self.global_rank = int(os.environ["RANK"])
self.local_rank = int(
os.environ.get("LOCAL_RANK", self.global_rank)
) # Fallback to global if LOCAL_RANK unset
self.world_size = int(os.environ["WORLD_SIZE"])
self.master_addr = os.environ.get("MASTER_ADDR")
self.master_port = os.environ.get("MASTER_PORT")
if self.master_addr is None or self.master_port is None:
raise ValueError(
"MASTER_ADDR and MASTER_PORT must be set for distributed initialization (e.g., in Azure ML)"
Contributor:
Can you also remove the mention of Azure ML here, please?

)
if self.config.world_size != 1 and self.config.world_size != self.world_size:
LOG.warning(
f"Config world_size ({self.config.world_size}) ignored; using WORLD_SIZE from environment ({self.world_size})"
)
else:
# If srun is not available, spawn procs manually on a node

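The TODO(refactor) comment above sketches a refactor in which each launcher gets its own environment class. As a rough illustration only, and assuming hypothetical names (ClusterEnvironment, EnvVarEnvironment, detects(), bootstrap()) that are not part of this PR, such a delegated bootstrap could look like the sketch below, using the stricter MASTER_ADDR guard suggested in the review comment:

import os
from typing import Optional


class ClusterEnvironment:
    # Hypothetical base class: one subclass per launcher (srun, torchrun, plain env vars, ...).

    def detects(self) -> bool:
        raise NotImplementedError

    def bootstrap(self) -> tuple[int, int, int, Optional[str], Optional[str]]:
        # Returns (global_rank, local_rank, world_size, master_addr, master_port).
        raise NotImplementedError


class EnvVarEnvironment(ClusterEnvironment):
    # Reads the generic torch.distributed variables from the environment.

    def detects(self) -> bool:
        # Stricter guard than RANK/WORLD_SIZE alone, as suggested in the review comment above.
        return "MASTER_ADDR" in os.environ and "WORLD_SIZE" in os.environ

    def bootstrap(self) -> tuple[int, int, int, Optional[str], Optional[str]]:
        global_rank = int(os.environ["RANK"])
        local_rank = int(os.environ.get("LOCAL_RANK", global_rank))
        world_size = int(os.environ["WORLD_SIZE"])
        return global_rank, local_rank, world_size, os.environ.get("MASTER_ADDR"), os.environ.get("MASTER_PORT")

With something like this, _bootstrap_processes would only iterate over the registered environments and delegate to the first one whose detects() returns True, instead of growing this if/elif chain.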
@@ -377,19 +397,31 @@ def _init_parallel(self) -> Optional["torch.distributed.ProcessGroup"]:
else:
backend = "gloo"

dist.init_process_group(
backend=backend,
init_method=f"tcp://{self.master_addr}:{self.master_port}",
timeout=datetime.timedelta(minutes=3),
world_size=self.world_size,
rank=self.global_rank,
)
if backend == "mpi":
# MPI backend: No init_method or explicit sizes needed
dist.init_process_group(backend="mpi")
model_comm_group_ranks = np.arange(self.world_size, dtype=int)
model_comm_group = dist.new_group(model_comm_group_ranks)
else:
# TODO(refactor): Delegate backend + init_method selection and dist.init_process_group(...)
# to a ClusterEnvironment (e.g., env.init_parallel()) so ParallelRunner has no branching here.
if self._using_distributed_env():
init_method = "env://" # Azure ML recommended
Contributor:
Can you remove Azure ML here, please?

else:
init_method = f"tcp://{self.master_addr}:{self.master_port}"
dist.init_process_group(
backend=backend,
init_method=init_method,
timeout=datetime.timedelta(minutes=3),
world_size=self.world_size,
rank=self.global_rank,
)
model_comm_group_ranks = np.arange(self.world_size, dtype=int)
model_comm_group = dist.new_group(model_comm_group_ranks)
LOG.info(f"Creating a model communication group with {self.world_size} devices with the {backend} backend")

model_comm_group_ranks = np.arange(self.world_size, dtype=int)
model_comm_group = dist.new_group(model_comm_group_ranks)
else:
model_comm_group = None
LOG.warning("ParallelRunner selected but world size of 1 detected")

return model_comm_group
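For context on the branch above, here is a minimal standalone sketch of the two initialisation paths it chooses between (the function names are mine, not part of the PR): with init_method="env://" torch.distributed reads RANK, WORLD_SIZE, MASTER_ADDR and MASTER_PORT from the environment itself, whereas the tcp:// form needs the rendezvous address and process layout passed explicitly.

import datetime

import torch.distributed as dist


def init_from_env(backend: str = "gloo") -> None:
    # env:// rendezvous: RANK, WORLD_SIZE, MASTER_ADDR and MASTER_PORT are taken
    # from the environment, so no explicit rank/world_size arguments are required here.
    dist.init_process_group(
        backend=backend,
        init_method="env://",
        timeout=datetime.timedelta(minutes=3),
    )


def init_from_tcp(backend: str, master_addr: str, master_port: str, world_size: int, rank: int) -> None:
    # tcp:// rendezvous: the master address and the process layout are spelled out explicitly.
    dist.init_process_group(
        backend=backend,
        init_method=f"tcp://{master_addr}:{master_port}",
        timeout=datetime.timedelta(minutes=3),
        world_size=world_size,
        rank=rank,
    )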

@@ -406,3 +438,11 @@ def _get_parallel_info_from_slurm(self) -> tuple[int, int, int]:
world_size = int(os.environ.get("SLURM_NTASKS", 1)) # Total number of processes

return global_rank, local_rank, world_size

def _using_distributed_env(self) -> bool:
"""Checks for distributed env vars like those in Azure ML."""
Contributor:
And remove Azure ML here too, please.

return "RANK" in os.environ and "WORLD_SIZE" in os.environ

def _is_mpi_env(self) -> bool:
"""Detects common MPI implementations (optional, for generality)."""
return "OMPI_COMM_WORLD_SIZE" in os.environ or "PMI_SIZE" in os.environ