diff --git a/src/anemoi/inference/runners/parallel.py b/src/anemoi/inference/runners/parallel.py
index f73368e4..c13ccc90 100644
--- a/src/anemoi/inference/runners/parallel.py
+++ b/src/anemoi/inference/runners/parallel.py
@@ -276,6 +276,26 @@ def _bootstrap_processes(self) -> None:
                 LOG.warning(
                     f"world size ({self.config.world_size}) set in the config is ignored because we are launching via srun, using 'SLURM_NTASKS' instead"
                 )
+        elif "RANK" in os.environ and "WORLD_SIZE" in os.environ:
+            # TODO(refactor): Extract AzureML/general env bootstrap (RANK/WORLD_SIZE/LOCAL_RANK and
+            # MASTER_ADDR/MASTER_PORT wiring) into a delegated ClusterEnvironment
+            # (e.g., AzureMLEnvironment.bootstrap()) in a follow-up PR.
+            # New branch for Azure ML / general distributed env (e.g., env:// mode)
+            self.global_rank = int(os.environ["RANK"])
+            self.local_rank = int(
+                os.environ.get("LOCAL_RANK", self.global_rank)
+            )  # Fallback to global if LOCAL_RANK unset
+            self.world_size = int(os.environ["WORLD_SIZE"])
+            self.master_addr = os.environ.get("MASTER_ADDR")
+            self.master_port = os.environ.get("MASTER_PORT")
+            if self.master_addr is None or self.master_port is None:
+                raise ValueError(
+                    "MASTER_ADDR and MASTER_PORT must be set for distributed initialization (e.g., in Azure ML)"
+                )
+            if self.config.world_size != 1 and self.config.world_size != self.world_size:
+                LOG.warning(
+                    f"Config world_size ({self.config.world_size}) ignored; using WORLD_SIZE from environment ({self.world_size})"
+                )
         else:
             # If srun is not available, spawn procs manually on a node
@@ -377,19 +397,31 @@ def _init_parallel(self) -> Optional["torch.distributed.ProcessGroup"]:
             else:
                 backend = "gloo"
 
-            dist.init_process_group(
-                backend=backend,
-                init_method=f"tcp://{self.master_addr}:{self.master_port}",
-                timeout=datetime.timedelta(minutes=3),
-                world_size=self.world_size,
-                rank=self.global_rank,
-            )
+            if backend == "mpi":
+                # MPI backend: No init_method or explicit sizes needed
+                dist.init_process_group(backend="mpi")
+                model_comm_group_ranks = np.arange(self.world_size, dtype=int)
+                model_comm_group = dist.new_group(model_comm_group_ranks)
+            else:
+                # TODO(refactor): Delegate backend + init_method selection and dist.init_process_group(...)
+                # to a ClusterEnvironment (e.g., env.init_parallel()) so ParallelRunner has no branching here.
+                if self._using_distributed_env():
+                    init_method = "env://"  # Azure ML recommended
+                else:
+                    init_method = f"tcp://{self.master_addr}:{self.master_port}"
+                dist.init_process_group(
+                    backend=backend,
+                    init_method=init_method,
+                    timeout=datetime.timedelta(minutes=3),
+                    world_size=self.world_size,
+                    rank=self.global_rank,
+                )
+                model_comm_group_ranks = np.arange(self.world_size, dtype=int)
+                model_comm_group = dist.new_group(model_comm_group_ranks)
 
             LOG.info(f"Creating a model communication group with {self.world_size} devices with the {backend} backend")
-
-            model_comm_group_ranks = np.arange(self.world_size, dtype=int)
-            model_comm_group = dist.new_group(model_comm_group_ranks)
         else:
             model_comm_group = None
+            LOG.warning("ParallelRunner selected but world size of 1 detected")
 
         return model_comm_group
@@ -406,3 +438,11 @@ def _get_parallel_info_from_slurm(self) -> tuple[int, int, int]:
         world_size = int(os.environ.get("SLURM_NTASKS", 1))  # Total number of processes
 
         return global_rank, local_rank, world_size
+
+    def _using_distributed_env(self) -> bool:
+        """Checks for distributed env vars like those in Azure ML."""
+        return "RANK" in os.environ and "WORLD_SIZE" in os.environ
+
+    def _is_mpi_env(self) -> bool:
+        """Detects common MPI implementations (optional, for generality)."""
+        return "OMPI_COMM_WORLD_SIZE" in os.environ or "PMI_SIZE" in os.environ
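
The env:// branch above can be smoke-tested outside Azure ML, since torchrun populates the same variables the new bootstrap branch reads (RANK, LOCAL_RANK, WORLD_SIZE, MASTER_ADDR, MASTER_PORT). The following is a minimal sketch, not part of the patch, and the file name check_env_init.py is made up for illustration:

# check_env_init.py -- minimal sketch of the env:// initialization path.
# Launch with: torchrun --nproc_per_node=2 check_env_init.py
import datetime
import os

import torch.distributed as dist


def main() -> None:
    # torchrun exports RANK/WORLD_SIZE, mirroring what _bootstrap_processes reads
    rank = int(os.environ["RANK"])
    world_size = int(os.environ["WORLD_SIZE"])
    # env:// picks up MASTER_ADDR/MASTER_PORT from the environment, so no
    # tcp://host:port URL has to be assembled, matching the new branch above
    dist.init_process_group(
        backend="gloo",
        init_method="env://",
        timeout=datetime.timedelta(minutes=3),
        world_size=world_size,
        rank=rank,
    )
    group = dist.new_group(list(range(world_size)))
    print(f"rank {rank}/{world_size} joined a group of size {dist.get_world_size(group)}")
    dist.destroy_process_group()


if __name__ == "__main__":
    main()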
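
For the ClusterEnvironment follow-up flagged in the TODOs, one possible shape is sketched below. Only ClusterEnvironment, AzureMLEnvironment, bootstrap() and init_parallel() come from the TODO comments; everything else is invented for this sketch and is not a proposal the patch commits to:

# Illustrative sketch of the TODO'd ClusterEnvironment delegation, so that
# ParallelRunner no longer branches on the launch environment itself.
import datetime
import os

import torch.distributed as dist


class AzureMLEnvironment:
    """Owns env-var bootstrap and process-group setup, per the TODOs."""

    def bootstrap(self) -> None:
        # Same wiring as the new elif branch in _bootstrap_processes
        self.global_rank = int(os.environ["RANK"])
        self.local_rank = int(os.environ.get("LOCAL_RANK", self.global_rank))
        self.world_size = int(os.environ["WORLD_SIZE"])

    def init_parallel(self, backend: str):
        # env:// reads MASTER_ADDR/MASTER_PORT itself, so no tcp:// URL
        # has to be assembled by the runner
        dist.init_process_group(
            backend=backend,
            init_method="env://",
            timeout=datetime.timedelta(minutes=3),
            world_size=self.world_size,
            rank=self.global_rank,
        )
        return dist.new_group(list(range(self.world_size)))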