Skip to content

Commit c9c7f35

Browse files
committed
Add list mappings
1 parent d687b8c commit c9c7f35

File tree

5 files changed

+65
-19
lines changed

5 files changed

+65
-19
lines changed

src/anemoi/inference/clusters/distributed.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,6 @@
77
# nor does it submit to any jurisdiction.
88
#
99

10-
import os
1110

1211
from anemoi.inference.clusters import cluster_registry
1312
from anemoi.inference.clusters.mapping import EnvMapping
@@ -32,4 +31,4 @@ def __init__(self) -> None:
3231

3332
@classmethod
def used(cls) -> bool:
    """Report whether both the world size and the global rank resolve to a non-empty value."""
    required_keys = ("world_size", "global_rank")
    # all() stops at the first missing key, matching the short-circuit of `and`.
    return all(bool(DISTRIBUTED_MAPPING.get_env(key)) for key in required_keys)

src/anemoi/inference/clusters/mapping.py

Lines changed: 29 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
import dataclasses
1111
import logging
1212
import os
13+
from typing import Any
1314

1415
from anemoi.inference.clusters import cluster_registry
1516
from anemoi.inference.clusters.client import ComputeClientFactory
@@ -19,16 +20,34 @@
1920

2021
@dataclasses.dataclass
2122
class EnvMapping:
22-
local_rank: str
23-
global_rank: str
24-
world_size: str
23+
"""Dataclass to hold environment variable mappings for cluster configuration.
2524
26-
master_addr: str
27-
master_port: str
25+
Elements can be either strings or lists of strings.
26+
If a list is provided, the first found environment variable will be used.
27+
"""
28+
29+
local_rank: str | list[str]
30+
global_rank: str | list[str]
31+
world_size: str | list[str]
32+
33+
master_addr: str | list[str]
34+
master_port: str | list[str]
2835

2936
backend: str | None = None
3037
init_method: str = "env://"
3138

39+
def get_env(self, key: str, default: Any = None):
40+
"""Get the environment variable value for the given key."""
41+
mapped_value = getattr(self, key)
42+
if mapped_value is None:
43+
return default
44+
45+
for env_var in (mapped_value if isinstance(mapped_value, list) else [mapped_value]):
46+
value = os.environ.get(env_var)
47+
if value is not None:
48+
return value
49+
return default
50+
3251

3352
@cluster_registry.register("custom")
3453
class MappingCluster(ComputeClientFactory):
@@ -88,27 +107,27 @@ def backend(self) -> str:
88107
@property
def world_size(self) -> int:
    """Total number of processes in the cluster, or 1 when the mapping finds no value."""
    raw_size = self._mapping.get_env("world_size", 1)
    return int(raw_size)
92111

93112
@property
def global_rank(self) -> int:
    """Rank of the current process, or 0 when the mapping finds no value."""
    raw_rank = self._mapping.get_env("global_rank", 0)
    return int(raw_rank)
97116

98117
@property
def local_rank(self) -> int:
    """Return the local rank of the current process.

    Falls back to ``self.global_rank`` when the mapping finds no value for
    ``"local_rank"``.
    """
    raw_rank = self._mapping.get_env("local_rank")
    if raw_rank is None:
        # Resolve the fallback only when needed, so that a missing or
        # malformed global rank cannot break an explicitly set local rank.
        return self.global_rank
    return int(raw_rank)
102121

103122
@property
def master_addr(self) -> str:
    """Master address from the mapping, or an empty string when no value is found."""
    address = self._mapping.get_env("master_addr", "")
    return address
107126

108127
@property
def master_port(self) -> int:
    """Master port from the mapping as an integer, or 0 when no value is found."""
    raw_port = self._mapping.get_env("master_port", 0)
    return int(raw_port)
112131

113132
@classmethod
114133
def used(cls) -> bool:

src/anemoi/inference/clusters/mpi.py

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,6 @@
88
#
99

1010
import logging
11-
import os
1211

1312
from anemoi.inference.clusters import cluster_registry
1413
from anemoi.inference.clusters.mapping import EnvMapping
@@ -18,9 +17,9 @@
1817
LOG = logging.getLogger(__name__)
1918

2019
MPI_MAPPING = EnvMapping(
21-
local_rank="OMPI_COMM_WORLD_LOCAL_RANK",
22-
global_rank="OMPI_COMM_WORLD_RANK",
23-
world_size="OMPI_COMM_WORLD_SIZE",
20+
local_rank=["OMPI_COMM_WORLD_LOCAL_RANK", "PMI_RANK"],
21+
global_rank=["OMPI_COMM_WORLD_RANK", "PMI_RANK"],
22+
world_size=["OMPI_COMM_WORLD_SIZE", "PMI_SIZE"],
2423
master_addr="MASTER_ADDR",
2524
master_port="MASTER_PORT",
2625
init_method="tcp://{master_addr}:{master_port}",
@@ -44,7 +43,7 @@ def __init__(self, use_mpi_backend: bool = False, **kwargs) -> None:
4443

4544
@classmethod
def used(cls) -> bool:
    """Report whether any world-size variable listed in ``MPI_MAPPING`` has a non-empty value."""
    detected_size = MPI_MAPPING.get_env("world_size")
    return True if detected_size else False
4847

4948
@property
5049
def backend(self) -> str:

src/anemoi/inference/clusters/slurm.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,7 @@ def __init__(self) -> None:
4242
def used(cls) -> bool:
    """Report whether a world size is set and the Slurm job name is neither "bash" nor "interactive"."""
    # from pytorch lightning
    # https://github.com/Lightning-AI/pytorch-lightning/blob/a944e7744e57a5a2c13f3c73b9735edf2f71e329/src/lightning/fabric/plugins/environments/slurm.py
    if not SLURM_MAPPING.get_env("world_size"):
        return False
    excluded_job_names = (
        "bash",
        "interactive",
    )
    return os.environ.get("SLURM_JOB_NAME") not in excluded_job_names

tests/unit/test_clusters.py

Lines changed: 30 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -320,7 +320,7 @@ def test_distributed_cluster_used_detection(self):
320320
assert not DistributedCluster.used()
321321

322322
# In distributed environment (torchrun)
323-
with patch.dict(os.environ, {"RANK": "3", "LOCAL_RANK": "1"}):
323+
with patch.dict(os.environ, {"RANK": "3", "WORLD_SIZE": "4"}):
324324
assert DistributedCluster.used()
325325

326326
def test_distributed_cluster_initialization(self):
@@ -422,6 +422,35 @@ def test_mapping_cluster_not_used(self):
422422
"""Test that MappingCluster.used() returns False."""
423423
assert not MappingCluster.used()
424424

425+
def test_mapping_cluster_list(self):
    """Check that, for list-valued mappings, the first candidate found in the environment wins."""
    candidate_mapping = EnvMapping(
        local_rank=["MY_LOCAL_RANK_A", "MY_LOCAL_RANK_B"],
        global_rank=["MY_GLOBAL_RANK_A", "MY_GLOBAL_RANK_B"],
        world_size=["MY_WORLD_SIZE_A", "MY_WORLD_SIZE_B"],
        master_addr=["MY_MASTER_ADDR_A", "MY_MASTER_ADDR_B"],
        master_port=["MY_MASTER_PORT_A", "MY_MASTER_PORT_B"],
        init_method="tcp://{master_addr}:{master_port}",
    )

    # Exactly one candidate per field is set, mixing first and second entries.
    fake_env = {
        "MY_WORLD_SIZE_B": "16",
        "MY_GLOBAL_RANK_A": "5",
        "MY_LOCAL_RANK_B": "2",
        "MY_MASTER_ADDR_A": "192.168.1.1",
        "MY_MASTER_PORT_B": "40000",
    }

    with patch.dict(os.environ, fake_env):
        cluster = MappingCluster(mapping=candidate_mapping)

        assert cluster.world_size == 16
        assert cluster.global_rank == 5
        assert cluster.local_rank == 2
        assert cluster.master_addr == "192.168.1.1"
        assert cluster.master_port == 40000
453+
425454

426455
class TestClusterRegistry:
427456
"""Tests for cluster registry and creation."""

0 commit comments

Comments
 (0)