|
10 | 10 | import dataclasses |
11 | 11 | import logging |
12 | 12 | import os |
| 13 | +from typing import Any |
13 | 14 |
|
14 | 15 | from anemoi.inference.clusters import cluster_registry |
15 | 16 | from anemoi.inference.clusters.client import ComputeClientFactory |
|
19 | 20 |
|
20 | 21 | @dataclasses.dataclass |
21 | 22 | class EnvMapping: |
22 | | - local_rank: str |
23 | | - global_rank: str |
24 | | - world_size: str |
| 23 | + """Dataclass to hold environment variable mappings for cluster configuration. |
25 | 24 |
|
26 | | - master_addr: str |
27 | | - master_port: str |
| 25 | + Elements can be either strings or lists of strings. |
| 26 | + If a list is provided, the first found environment variable will be used. |
| 27 | + """ |
| 28 | + |
| 29 | + local_rank: str | list[str] |
| 30 | + global_rank: str | list[str] |
| 31 | + world_size: str | list[str] |
| 32 | + |
| 33 | + master_addr: str | list[str] |
| 34 | + master_port: str | list[str] |
28 | 35 |
|
29 | 36 | backend: str | None = None |
30 | 37 | init_method: str = "env://" |
31 | 38 |
|
| 39 | + def get_env(self, key: str, default: Any = None): |
| 40 | + """Get the environment variable value for the given key.""" |
| 41 | + mapped_value = getattr(self, key) |
| 42 | + if mapped_value is None: |
| 43 | + return default |
| 44 | + |
| 45 | + for env_var in (mapped_value if isinstance(mapped_value, list) else [mapped_value]): |
| 46 | + value = os.environ.get(env_var) |
| 47 | + if value is not None: |
| 48 | + return value |
| 49 | + return default |
| 50 | + |
32 | 51 |
|
33 | 52 | @cluster_registry.register("custom") |
34 | 53 | class MappingCluster(ComputeClientFactory): |
@@ -88,27 +107,27 @@ def backend(self) -> str: |
@property
def world_size(self) -> int:
    """Total number of processes, taken from the mapped variable (1 if unset)."""
    raw_size = self._mapping.get_env("world_size", 1)
    return int(raw_size)
92 | 111 |
|
@property
def global_rank(self) -> int:
    """Global rank of this process, taken from the mapped variable (0 if unset)."""
    raw_rank = self._mapping.get_env("global_rank", 0)
    return int(raw_rank)
97 | 116 |
|
@property
def local_rank(self) -> int:
    """Local rank of this process; falls back to ``global_rank`` if unset."""
    lookup = self._mapping.get_env
    # The fallback is computed eagerly, before the environment is consulted.
    raw_rank = lookup("local_rank", self.global_rank)
    return int(raw_rank)
102 | 121 |
|
@property
def master_addr(self) -> str:
    """Master address from the mapped variable (empty string if unset)."""
    address = self._mapping.get_env("master_addr", "")
    return address
107 | 126 |
|
@property
def master_port(self) -> int:
    """Master port from the mapped variable, as an integer (0 if unset)."""
    raw_port = self._mapping.get_env("master_port", 0)
    return int(raw_port)
112 | 131 |
|
113 | 132 | @classmethod |
114 | 133 | def used(cls) -> bool: |
|
0 commit comments