feat: Add Azure ML compatibility to ParallelRunner #329
base: main
@@ -276,6 +276,26 @@ def _bootstrap_processes(self) -> None:
                 LOG.warning(
                     f"world size ({self.config.world_size}) set in the config is ignored because we are launching via srun, using 'SLURM_NTASKS' instead"
                 )
+        elif "RANK" in os.environ and "WORLD_SIZE" in os.environ:
+            # TODO(refactor): Extract AzureML/general env bootstrap (RANK/WORLD_SIZE/LOCAL_RANK and
+            # MASTER_ADDR/MASTER_PORT wiring) into a delegated ClusterEnvironment
+            # (e.g., AzureMLEnvironment.bootstrap()) in a follow-up PR.
+            # New branch for Azure ML / general distributed env (e.g., env:// mode)

Contributor: Can you please remove mention of Azure ML here.

+            self.global_rank = int(os.environ["RANK"])
+            self.local_rank = int(
+                os.environ.get("LOCAL_RANK", self.global_rank)
+            )  # Fallback to global if LOCAL_RANK unset
+            self.world_size = int(os.environ["WORLD_SIZE"])
+            self.master_addr = os.environ.get("MASTER_ADDR")
+            self.master_port = os.environ.get("MASTER_PORT")
+            if self.master_addr is None or self.master_port is None:
+                raise ValueError(
+                    "MASTER_ADDR and MASTER_PORT must be set for distributed initialization (e.g., in Azure ML)"

Contributor: Can you also remove mention of Azure ML here, please.

+                )
+            if self.config.world_size != 1 and self.config.world_size != self.world_size:
+                LOG.warning(
+                    f"Config world_size ({self.config.world_size}) ignored; using WORLD_SIZE from environment ({self.world_size})"
+                )
         else:
             # If srun is not available, spawn procs manually on a node

@@ -377,19 +397,31 @@ def _init_parallel(self) -> Optional["torch.distributed.ProcessGroup"]:
             else:
                 backend = "gloo"

-            dist.init_process_group(
-                backend=backend,
-                init_method=f"tcp://{self.master_addr}:{self.master_port}",
-                timeout=datetime.timedelta(minutes=3),
-                world_size=self.world_size,
-                rank=self.global_rank,
-            )
+            if backend == "mpi":
+                # MPI backend: No init_method or explicit sizes needed
+                dist.init_process_group(backend="mpi")
+                model_comm_group_ranks = np.arange(self.world_size, dtype=int)
+                model_comm_group = dist.new_group(model_comm_group_ranks)
+            else:
+                # TODO(refactor): Delegate backend + init_method selection and dist.init_process_group(...)
+                # to a ClusterEnvironment (e.g., env.init_parallel()) so ParallelRunner has no branching here.
+                if self._using_distributed_env():
+                    init_method = "env://"  # Azure ML recommended

Contributor: Can you remove Azure ML here please.

+                else:
+                    init_method = f"tcp://{self.master_addr}:{self.master_port}"
+                dist.init_process_group(
+                    backend=backend,
+                    init_method=init_method,
+                    timeout=datetime.timedelta(minutes=3),
+                    world_size=self.world_size,
+                    rank=self.global_rank,
+                )
+                model_comm_group_ranks = np.arange(self.world_size, dtype=int)
+                model_comm_group = dist.new_group(model_comm_group_ranks)
             LOG.info(f"Creating a model communication group with {self.world_size} devices with the {backend} backend")

-            model_comm_group_ranks = np.arange(self.world_size, dtype=int)
-            model_comm_group = dist.new_group(model_comm_group_ranks)
         else:
             model_comm_group = None
             LOG.warning("ParallelRunner selected but world size of 1 detected")

         return model_comm_group

@@ -406,3 +438,11 @@ def _get_parallel_info_from_slurm(self) -> tuple[int, int, int]:
         world_size = int(os.environ.get("SLURM_NTASKS", 1))  # Total number of processes

         return global_rank, local_rank, world_size

+    def _using_distributed_env(self) -> bool:
+        """Checks for distributed env vars like those in Azure ML."""

Contributor: And remove Azure ML here too, please.

+        return "RANK" in os.environ and "WORLD_SIZE" in os.environ
+
+    def _is_mpi_env(self) -> bool:
+        """Detects common MPI implementations (optional, for generality)."""
+        return "OMPI_COMM_WORLD_SIZE" in os.environ or "PMI_SIZE" in os.environ
I am worried RANK and WORLD_SIZE are too generic, and we would come in here too much by mistake. Also, this block will fail if MASTER_ADDR/PORT are not set.

What about changing this line to

elif "MASTER_ADDR" in os.environ and "WORLD_SIZE" in os.environ: