diff --git a/docs/inference/parallel.rst b/docs/inference/parallel.rst index 6eb758ac..4b5ad692 100644 --- a/docs/inference/parallel.rst +++ b/docs/inference/parallel.rst @@ -2,52 +2,142 @@ Parallel Inference #################### -If the memory requirements of your model are too large to fit within a -single GPU, you can run Anemoi-Inference in parallel across multiple -GPUs. +.. contents:: Table of Contents + :local: + :depth: 2 -You have two options to launch parallel inference: - - Launch without Slurm. This allows you to run inference across - multiple GPUs **on a single node**. - - Launch via Slurm. Slurm is needed to run inference **across - multiple nodes**. +If the memory requirements of your model are too large to fit within a +single device, you can run Anemoi-Inference in parallel across multiple +devices. The parallel runner distributes the model across devices and +coordinates inference execution. *************** Prerequisites *************** -Parallel inference requires a certain minimum version of Anemoi-models ->= v0.4.2. If this breaks your checkpoints, you could cherry-pick `the -relevant PR `_ into your -old version of Anemoi-Models. +Parallel inference requires: + +- Anemoi-Models >= v0.4.2 (for model parallelism support) +- Multiple devices available on your system or cluster + +.. note:: + + If updating to Anemoi-Models v0.4.2 breaks your existing checkpoints, + you can cherry-pick `the relevant PR + `_ into your old + version of Anemoi-Models. *************** Configuration *************** -To run in parallel, you must add '``runner:parallel``' to your inference -config file. If you are running in parallel without Slurm, you must also -add a '``world_size: num_gpus``' field. This informs Anemoi-Inference -how many GPUs you want to run across. It cannot be greater then the -number of GPUs on a single node. +To run in parallel, add ``runner: parallel`` to your inference config +file. 
The parallel runner will automatically detect your cluster +environment (Slurm, MPI, torchrun, etc.) and configure itself +accordingly. -.. note:: +Basic Configuration +=================== - If you are launching parallel inference via Slurm, '``world_size``' - will be ignored in favour of the '``SLURM_NTASKS``' environment - variable. +For environments with automatic cluster detection (Slurm, MPI, +torchrun), a minimal configuration is sufficient: .. code:: yaml checkpoint: /path/to/inference-last.ckpt lead_time: 60 runner: parallel - world_size: 4 #Only required if running parallel inference without Slurm + + input: + grib: /path/to/input.grib + output: + grib: /path/to/output.grib + +Supported Cluster Types +======================= + +The following cluster types are automatically detected: + +.. list-table:: + :header-rows: 1 + :widths: 20 30 50 + + - - Cluster Type + - Detection Method + - Environment Variables Used + + - - **Slurm** + - ``SLURM_NTASKS`` set and ``SLURM_JOB_NAME`` not ``bash`` or ``interactive`` + - ``SLURM_PROCID``, ``SLURM_LOCALID``, ``SLURM_NTASKS``, + ``SLURM_NODELIST`` + + - - **MPI** + - Presence of ``OMPI_COMM_WORLD_SIZE`` or ``PMI_SIZE`` + - ``OMPI_COMM_WORLD_RANK``, ``OMPI_COMM_WORLD_LOCAL_RANK``, + ``OMPI_COMM_WORLD_SIZE`` (or ``PMI_RANK``, ``PMI_SIZE``) + + - - **Distributed (torchrun)** + - Presence of ``WORLD_SIZE`` and ``RANK`` + - ``RANK``, ``LOCAL_RANK``, ``WORLD_SIZE``, ``MASTER_ADDR``, + ``MASTER_PORT`` + +Manual Cluster Configuration +============================ + +If you are running in an environment without automatic detection, use +the manual cluster +(:class:`anemoi.inference.clusters.manual.ManualSpawner`) by specifying +the cluster as ``manual`` and the ``world_size`` (number of devices): + +.. code:: yaml + + checkpoint: /path/to/inference-last.ckpt + lead_time: 60 + runner: + parallel: + cluster: + manual: 4 # Use 4 devices + + input: + grib: /path/to/input.grib + output: + grib: /path/to/output.grib + +.. 
warning:: + + The ``world_size`` cannot exceed the number of available devices on + your system. + +Custom Cluster Mapping +====================== + +Additionally, if you have a custom cluster environment, you can specify +your own environment variable mapping: + +.. code:: yaml + + checkpoint: /path/to/inference-last.ckpt + lead_time: 60 + runner: + parallel: + cluster: + custom: + mapping: + local_rank: LOCAL_RANK_ENV_VAR + global_rank: GLOBAL_RANK_ENV_VAR + world_size: WORLD_SIZE_ENV_VAR + master_addr: MASTER_ADDR_ENV_VAR + master_port: MASTER_PORT_ENV_VAR + init_method: env:// + input: grib: /path/to/input.grib output: grib: /path/to/output.grib +Base Runner +----------- + By default, the `parallel` runner inherits from the `default` runner (:class:`anemoi.inference.runners.default.DefaultRunner`). If you want to run a different runner in parallel, you can pass the ``base_runner`` @@ -62,20 +152,23 @@ option: Any additional options passed to the `parallel` runner will be forwarded to the ``base_runner``. -********************************************* - Running inference in parallel without Slurm -********************************************* +******************************* + Running Inference in Parallel +******************************* + +Once you have configured ``runner: parallel`` in your config file, you +can launch parallel inference by calling ``anemoi-inference run +config.yaml`` as normal. -Once you have added '``runner:parallel``' and '``world_size: num_gpus``' -to your config file, you can launch parallel inference by calling -'``anemoi-inferece run config.yaml``' as normal. +If you are using a cluster manager like Slurm or MPI, you must launch +your job using the appropriate launcher (``srun``, ``mpirun``, etc). See +the examples below. 
-****************************************** - Running inference in parallel with Slurm -****************************************** +Parallel with Slurm +=================== Below is an example SLURM batch script to launch a parallel inference -job across 4 GPUs with SLURM. +job across 4 GPUs. .. code:: bash @@ -92,16 +185,16 @@ job across 4 GPUs with SLURM. .. warning:: - If you specify '``runner:parallel``' but you don't launch with - '``srun``', your anemoi-inference job may hang as only 1 process will - be launched. + If you specify ``runner: parallel`` but don't launch with ``srun``, + your anemoi-inference job may hang as only 1 process will be + launched. .. note:: - By default, anemoi-inference will determine your systems master - address and port itself. If this fails (i.e. when running - Anemoi-Inference inside a container), you can instead set these - values yourself via environment variables in your SLURM batch script: + By default, anemoi-inference will determine your system's master + address and port automatically. If this fails (e.g., when running + inside a container), you can set these values manually via + environment variables in your SLURM batch script: .. code:: bash @@ -110,3 +203,154 @@ job across 4 GPUs with SLURM. export MASTER_PORT=$((10000 + RANDOM % 10000)) srun anemoi-inference run parallel.yaml + +Parallel with MPI +================= + +To run parallel inference with MPI, use ``mpirun`` or ``mpiexec`` to +launch your job: + +.. 
code:: bash + + #!/bin/bash + #SBATCH --nodes=1 + #SBATCH --ntasks-per-node=4 + #SBATCH --gpus-per-node=4 + #SBATCH --cpus-per-task=8 + #SBATCH --time=0:05:00 + #SBATCH --output=outputs/parallel_inf_mpi.%j.out + + source /path/to/venv/bin/activate + + # Set master address and port for communication + MASTER_ADDR=$(scontrol show hostname $SLURM_NODELIST | head -n 1) + export MASTER_ADDR=$(nslookup $MASTER_ADDR | grep -oP '(?<=Address: ).*') + export MASTER_PORT=29500 + + mpirun -np 4 anemoi-inference run parallel.yaml + +.. note:: + + If your torch supports it (PyTorch must be compiled from source with + MPI support to use the MPI backend to torch.distributed), you can use + the ``mpi`` torch backend by configuring: + + .. code:: yaml + + runner: + parallel: + cluster: + mpi: + use_mpi_backend: true + +Parallel with torchrun +====================== + +For environments without a cluster manager, you can use PyTorch's +``torchrun`` utility: + +.. code:: bash + + #!/bin/bash + + source /path/to/venv/bin/activate + + torchrun --nproc_per_node=4 \ + --nnodes=1 \ + --node_rank=0 \ + --master_addr=localhost \ + --master_port=29500 \ + $(which anemoi-inference) run parallel.yaml + +.. note:: + + When using ``torchrun``, the distributed environment variables + (``RANK``, ``LOCAL_RANK``, ``WORLD_SIZE``, etc.) are automatically + set by torchrun. + +*********************** + Environment Variables +*********************** + +The following environment variables can be used to customise parallel +inference: + +.. list-table:: + :header-rows: 1 + :widths: 30 70 + + - - Environment Variable + - Description + + - - ``ANEMOI_BASE_SEED`` + + - Base seed for reproducible inference. Will be broadcast from + rank 0 to all ranks. Values < 1000 are automatically multiplied + by 1000. + +***************** + Troubleshooting +***************** + +Common Issues +============= + +.. 
list-table:: + :header-rows: 1 + :widths: 30 70 + + - - Issue + - Solution + + - - **Job hangs indefinitely** + + - Ensure you're launching with the appropriate launcher + (``srun``, ``mpirun``, ``torchrun``). Check that the number of + processes matches your configuration. + + - - **"No suitable cluster found" error** + - Add explicit cluster configuration using ``cluster: manual`` or + verify your environment variables are set correctly. + + - - **Version compatibility error** + - Upgrade to Anemoi-Models >= v0.4.2 or cherry-pick the `parallel + inference PR `_. + + - - **CUDA out of memory** + + - Increase the number of devices (``world_size``) to distribute + the model across more devices. Or, increase the chunking with + ``ANEMOI_INFERENCE_NUM_CHUNKS``. + + - - **Port already in use** + - Set ``MASTER_PORT`` to a different port number, or let Slurm + auto-generate one. + + - - **Communication timeout** + - Check firewall settings and ensure all nodes can communicate. + Verify ``MASTER_ADDR`` is accessible from all ranks. + +Verification Checklist +====================== + +Before running parallel inference, verify: + +#. ✓ Anemoi-Models version >= v0.4.2 +#. ✓ Multiple devices available (``nvidia-smi`` or equivalent) +#. ✓ Configuration includes ``runner: parallel`` +#. ✓ Using appropriate launcher (``srun``, ``mpirun``, or ``torchrun``) +#. ✓ Number of processes matches available devices +#. ✓ Network connectivity between nodes (multi-node only) + +Expected Output +=============== + +When parallel inference runs successfully, you should see log messages +indicating: + +- Cluster type detected (e.g., "Using compute client: SlurmCluster") +- Rank information (e.g., "rank00", "rank01", etc.) +- Model loading on each rank +- Inference progress from rank 0 (master) + +Only rank 0 produces output files; other ranks assist with computation. 
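The detection rules implemented by the ``used()`` classmethods in this change can be condensed into a standalone sketch, which may help when debugging a "No suitable cluster found" error. It does not import anemoi-inference; ``detect_cluster`` and ``first_set`` are hypothetical names introduced here, and the order in which the cluster types are tried is an assumption, since the registry's iteration order is not shown in this diff.

```python
"""Sketch: which cluster type would auto-detection pick for a given environment?

Standalone illustration only; it does not import anemoi-inference.
"""
from __future__ import annotations

import os
from collections.abc import Mapping

# Detection rules paraphrased from the `used()` classmethods in this change.
# The order in which they are tried below is an ASSUMPTION of this sketch.


def first_set(env: Mapping[str, str], names: list[str]) -> str | None:
    """Return the value of the first variable in `names` that is present in `env`."""
    for name in names:
        value = env.get(name)
        if value is not None:
            return value
    return None


def detect_cluster(env: Mapping[str, str]) -> str | None:
    """Hypothetical helper: name the cluster type the rules would select, or None."""
    # Slurm: task count is set and the job is not an interactive shell.
    if env.get("SLURM_NTASKS") and env.get("SLURM_JOB_NAME") not in ("bash", "interactive"):
        return "slurm"
    # MPI: a world size is exported by Open MPI or a PMI-based launcher.
    if first_set(env, ["OMPI_COMM_WORLD_SIZE", "PMI_SIZE"]):
        return "mpi"
    # torchrun-style launch: both WORLD_SIZE and RANK are non-empty.
    if env.get("WORLD_SIZE") and env.get("RANK"):
        return "distributed"
    return None


if __name__ == "__main__":
    print(f"Detected cluster type: {detect_cluster(os.environ)}")
```

Running the script inside an ``srun``, ``mpirun`` or ``torchrun`` launch prints the type each process would select under these rules; ``None`` corresponds to the case where an explicit ``cluster`` entry is needed in the config.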
diff --git a/src/anemoi/inference/clusters/__init__.py b/src/anemoi/inference/clusters/__init__.py new file mode 100644 index 00000000..e1911197 --- /dev/null +++ b/src/anemoi/inference/clusters/__init__.py @@ -0,0 +1,50 @@ +# (C) Copyright 2025- ECMWF. +# +# This software is licensed under the terms of the Apache Licence Version 2.0 +# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. +# In applying this licence, ECMWF does not waive the privileges and immunities +# granted to it by virtue of its status as an intergovernmental organisation +# nor does it submit to any jurisdiction. +# + +from typing import Any + +from anemoi.utils.registry import Registry + +from .client import ComputeClientFactory +from .spawner import ComputeSpawner + +cluster_registry: Registry[ComputeClientFactory | ComputeSpawner] = Registry(__name__) + + +def create_cluster(config: dict[str, Any] | str, *args, **kwargs) -> ComputeClientFactory | ComputeSpawner: + """Find and return the appropriate cluster for the current environment. + + Parameters + ---------- + config : dict + Configuration for the cluster. + Can be string or dict. + args : Any + Additional positional arguments. + kwargs : Any + Additional keyword arguments. + + Returns + ------- + Cluster + The created cluster instance. + """ + if config: + return cluster_registry.from_config(config, *args, **kwargs) + + for cluster in cluster_registry.factories: + cluster_cls = cluster_registry.lookup(cluster) + assert cluster_cls is not None + + if cluster_cls.used(): + return cluster_cls(*args, **kwargs) + + raise RuntimeError( + f"No suitable cluster found for the current environment,\nDiscovered implementations were {cluster_registry.registered}." + ) diff --git a/src/anemoi/inference/clusters/client.py b/src/anemoi/inference/clusters/client.py new file mode 100644 index 00000000..00ed591f --- /dev/null +++ b/src/anemoi/inference/clusters/client.py @@ -0,0 +1,144 @@ +# (C) Copyright 2025- ECMWF. 
+# +# This software is licensed under the terms of the Apache Licence Version 2.0 +# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. +# In applying this licence, ECMWF does not waive the privileges and immunities +# granted to it by virtue of its status as an intergovernmental organisation +# nor does it submit to any jurisdiction. +# + +import datetime +import logging +from abc import ABC +from abc import abstractmethod +from dataclasses import dataclass +from typing import Protocol + +from anemoi.inference.lazy import torch + +LOG = logging.getLogger(__name__) + + +class ClusterClientProtocol(Protocol): + @classmethod + def used(cls) -> bool: + """Check if this client is valid in the current environment.""" + ... + + +@dataclass +class ComputeClient: + world_size: int + + local_rank: int + global_rank: int + + master_addr: str + master_port: int + + process_group: "torch.distributed.ProcessGroup | None" + + @property + def is_master(self) -> bool: + """Return True if the current process is the master process.""" + return self.global_rank == 0 + + +class ComputeClientFactory(ABC): + """Abstract factory class for compute client creation.""" + + def create_client(self) -> ComputeClient: + """Create and return a ComputeClient instance.""" + return ComputeClient( + process_group=self.create_model_comm_group(), + world_size=self.world_size, + local_rank=self.local_rank, + global_rank=self.global_rank, + master_addr=self.master_addr, + master_port=self.master_port, + ) + + @classmethod + @abstractmethod + def used(cls) -> bool: + """Check if this client is valid in the current environment.""" + raise NotImplementedError + + @property + def init_method(self) -> str: + """Return the initialisation method string for distributed computing.""" + return f"tcp://{self.master_addr}:{self.master_port}" + + @property + def backend(self) -> str: + """Return the backend for distributed computing.""" + return "nccl" if torch.cuda.is_available() else "gloo" # 
type: ignore + + def create_model_comm_group(self) -> "torch.distributed.ProcessGroup | None": + """Create the communication group for model parallelism.""" + if self.world_size <= 1: + return None + + LOG.debug("Creating model communication group for parallel inference") + group = torch.distributed.init_process_group( + backend=self.backend, + init_method=self.init_method, + timeout=datetime.timedelta(minutes=3), + world_size=self.world_size, + rank=self.global_rank, + ) + + # Create a new process group for model communication + group = torch.distributed.new_group( + ranks=list(range(self.world_size)), + ) + LOG.info("Model communication group created") + + return group + + @property + def is_master(self) -> bool: + """Return True if the current process is the master process.""" + return self.global_rank == 0 + + @property + @abstractmethod + def local_rank(self) -> int: + """Return the rank of the current process.""" + raise NotImplementedError + + @property + def device_index(self) -> int: + """Return the device index for the current process, defaults to local rank.""" + return self.local_rank + + @property + @abstractmethod + def global_rank(self) -> int: + """Return the rank of the current process.""" + raise NotImplementedError + + @property + @abstractmethod + def world_size(self) -> int: + """Return the total number of processes in the cluster.""" + raise NotImplementedError + + @property + @abstractmethod + def master_addr(self) -> str: + """Return the master address.""" + raise NotImplementedError + + @property + @abstractmethod + def master_port(self) -> int: + """Return the master port.""" + raise NotImplementedError + + def __repr__(self) -> str: + return ( + f"{self.__class__.__name__}(world_size={self.world_size}, " + f"global_rank={self.global_rank}, local_rank={self.local_rank}, " + f"master_addr='{self.master_addr}', master_port={self.master_port})" + ) diff --git a/src/anemoi/inference/clusters/distributed.py 
b/src/anemoi/inference/clusters/distributed.py new file mode 100644 index 00000000..c4c3b14a --- /dev/null +++ b/src/anemoi/inference/clusters/distributed.py @@ -0,0 +1,34 @@ +# (C) Copyright 2025- ECMWF. +# +# This software is licensed under the terms of the Apache Licence Version 2.0 +# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. +# In applying this licence, ECMWF does not waive the privileges and immunities +# granted to it by virtue of its status as an intergovernmental organisation +# nor does it submit to any jurisdiction. +# + + +from anemoi.inference.clusters import cluster_registry +from anemoi.inference.clusters.mapping import EnvMapping +from anemoi.inference.clusters.mapping import MappingCluster + +DISTRIBUTED_MAPPING = EnvMapping( + local_rank="LOCAL_RANK", + global_rank="RANK", + world_size="WORLD_SIZE", + master_addr="MASTER_ADDR", + master_port="MASTER_PORT", + init_method="env://", +) + + +@cluster_registry.register("distributed") +class DistributedCluster(MappingCluster): + """Distributed cluster that uses environment variables for distributed setup.""" + + def __init__(self) -> None: + super().__init__(mapping=DISTRIBUTED_MAPPING) + + @classmethod + def used(cls) -> bool: + return bool(DISTRIBUTED_MAPPING.get_env("world_size")) and bool(DISTRIBUTED_MAPPING.get_env("global_rank")) diff --git a/src/anemoi/inference/clusters/manual.py b/src/anemoi/inference/clusters/manual.py new file mode 100644 index 00000000..0e52efc0 --- /dev/null +++ b/src/anemoi/inference/clusters/manual.py @@ -0,0 +1,132 @@ +# (C) Copyright 2025- ECMWF. +# +# This software is licensed under the terms of the Apache Licence Version 2.0 +# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. +# In applying this licence, ECMWF does not waive the privileges and immunities +# granted to it by virtue of its status as an intergovernmental organisation +# nor does it submit to any jurisdiction. 
+# + +import logging +import os + +from anemoi.inference.clusters import cluster_registry +from anemoi.inference.clusters.client import ComputeClientFactory +from anemoi.inference.clusters.spawner import SPAWN_FUNCTION +from anemoi.inference.clusters.spawner import ComputeSpawner +from anemoi.inference.config import Configuration + +LOG = logging.getLogger(__name__) + + +@cluster_registry.register("manual") +class ManualSpawner(ComputeSpawner): + """Manual cluster that uses user-defined world size for distributed setup. + + Example usage + ------------- + In the config + ```yaml + cluster: + manual: + world_size: 4 + port: 12345 + ``` + """ + + def __init__(self, world_size: int, port: int | None = None) -> None: + if world_size < 1: + raise ValueError("world_size must be at least 1.") + self._world_size = world_size + self._port = port + self._spawned_processes = [] + + @classmethod + def used(cls) -> bool: + return False + + def _create_port(self) -> int: + """Create a unique port based on the node name.""" + if self._port is not None: + return self._port + + import hashlib + + node_name = os.uname().nodename.encode() # Convert to bytes + hash_val = int(hashlib.md5(node_name).hexdigest(), 16) # Convert hash to int + master_port = 10000 + (hash_val % 9999) + return master_port + + def spawn(self, fn: SPAWN_FUNCTION, config: "Configuration") -> None: + import torch.multiprocessing as mp + + try: + mp.set_start_method("spawn") + except RuntimeError: + LOG.warning("Multiprocessing start method has already been set.") + + port = self._create_port() + + for pid in range(self._world_size): + factory = ManualClient( + world_size=self._world_size, local_rank=pid, global_rank=pid, master_addr="localhost", master_port=port + ) + process = mp.Process(target=fn, args=(config, factory)) + process.start() + self._spawned_processes.append(process) + + # Ensure all spawned processes complete execution + for process in self._spawned_processes: + process.join() + + def 
teardown(self) -> None: + """Tear down the cluster environment and join spawned processes.""" + # Join all spawned processes to ensure clean shutdown + for process in self._spawned_processes: + if not process.is_alive(): + continue + + process.terminate() + process.join(1) + if process.exitcode is None: + LOG.debug(f"Kill hung process - PID: {process.pid}") + process.kill() + + +class ManualClient(ComputeClientFactory): + def __init__(self, world_size: int, local_rank: int, global_rank: int, master_addr: str, master_port: int) -> None: + """Initialise the ManualClient.""" + self._world_size = world_size + self._local_rank = local_rank + self._global_rank = global_rank + self._master_addr = master_addr + self._master_port = master_port + + @classmethod + def used(cls) -> bool: + return True + + @property + def world_size(self) -> int: + """Return the total number of processes in the cluster.""" + return self._world_size + + @property + def global_rank(self) -> int: + """Return the rank of the current process.""" + return self._global_rank + + @property + def local_rank(self) -> int: + """Return the rank of the current process.""" + return self._local_rank + + @property + def master_addr(self) -> str: + """Return the master address.""" + return self._master_addr + + @property + def master_port(self) -> int: + """Return the master port.""" + return self._master_port diff --git a/src/anemoi/inference/clusters/mapping.py b/src/anemoi/inference/clusters/mapping.py new file mode 100644 index 00000000..98d58be4 --- /dev/null +++ b/src/anemoi/inference/clusters/mapping.py @@ -0,0 +1,134 @@ +# (C) Copyright 2025- ECMWF. +# +# This software is licensed under the terms of the Apache Licence Version 2.0 +# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. +# In applying this licence, ECMWF does not waive the privileges and immunities +# granted to it by virtue of its status as an intergovernmental organisation +# nor does it submit to any jurisdiction. 
+# + +import dataclasses +import logging +import os +from typing import Any + +from anemoi.inference.clusters import cluster_registry +from anemoi.inference.clusters.client import ComputeClientFactory + +LOG = logging.getLogger(__name__) + + +@dataclasses.dataclass +class EnvMapping: + """Dataclass to hold environment variable mappings for cluster configuration. + + Elements can be either strings or lists of strings. + If a list is provided, the first found environment variable will be used. + """ + + local_rank: str | list[str] + global_rank: str | list[str] + world_size: str | list[str] + + master_addr: str | list[str] + master_port: str | list[str] + + backend: str | None = None + init_method: str = "env://" + + def get_env(self, key: str, default: Any = None): + """Get the environment variable value for the given key.""" + mapped_value = getattr(self, key) + if mapped_value is None: + return default + + for env_var in (mapped_value if isinstance(mapped_value, list) else [mapped_value]): + value = os.environ.get(env_var) + if value is not None: + return value + return default + + +@cluster_registry.register("custom") +class MappingCluster(ComputeClientFactory): + """Custom cluster that uses user-defined environment variables for distributed setup. 
+ + Example usage + ------------- + + In the config + ```yaml + runner: + parallel: + cluster: + custom: + mapping: + local_rank: LOCAL_RANK_ENV_VAR + global_rank: GLOBAL_RANK_ENV_VAR + world_size: WORLD_SIZE_ENV_VAR + master_addr: MASTER_ADDR_ENV_VAR + master_port: MASTER_PORT_ENV_VAR + init_method: env:// + ``` + + ```python + from anemoi.inference.clusters.mapping import MappingCluster + cluster = MappingCluster(mapping={ + "local_rank": "LOCAL_RANK_ENV_VAR", + "global_rank": "GLOBAL_RANK_ENV_VAR", + "world_size": "WORLD_SIZE_ENV_VAR", + "master_addr": "MASTER_ADDR_ENV_VAR", + "master_port": "MASTER_PORT_ENV_VAR", + "init_method": "env://", + }) + ``` + """ + + def __init__(self, mapping: dict | EnvMapping) -> None: + """Initialise the MappingCluster. + + Parameters + ---------- + mapping : dict | EnvMapping + Mapping of environment variables to cluster properties + """ + self._mapping = EnvMapping(**mapping) if isinstance(mapping, dict) else mapping + + @property + def init_method(self) -> str: + """Return the initialisation method string for distributed computing.""" + return self._mapping.init_method.format(master_addr=self.master_addr, master_port=self.master_port) + + @property + def backend(self) -> str: + """Return the backend string for distributed computing.""" + return self._mapping.backend or super().backend + + @property + def world_size(self) -> int: + """Return the total number of processes in the cluster.""" + return int(self._mapping.get_env("world_size", 1)) + + @property + def global_rank(self) -> int: + """Return the global rank of the current process.""" + return int(self._mapping.get_env("global_rank", 0)) + + @property + def local_rank(self) -> int: + """Return the local rank of the current process.""" + return int(self._mapping.get_env("local_rank", self.global_rank)) + + @property + def master_addr(self) -> str: + """Return the master address.""" + return self._mapping.get_env("master_addr", "") + + @property + def master_port(self) -> int: + """Return 
the master port.""" + return int(self._mapping.get_env("master_port", 0)) + + @classmethod + def used(cls) -> bool: + return False diff --git a/src/anemoi/inference/clusters/mpi.py b/src/anemoi/inference/clusters/mpi.py new file mode 100644 index 00000000..91172472 --- /dev/null +++ b/src/anemoi/inference/clusters/mpi.py @@ -0,0 +1,74 @@ +# (C) Copyright 2025- ECMWF. +# +# This software is licensed under the terms of the Apache Licence Version 2.0 +# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. +# In applying this licence, ECMWF does not waive the privileges and immunities +# granted to it by virtue of its status as an intergovernmental organisation +# nor does it submit to any jurisdiction. +# + +import logging + +from anemoi.inference.clusters import cluster_registry +from anemoi.inference.clusters.mapping import EnvMapping +from anemoi.inference.clusters.mapping import MappingCluster +from anemoi.inference.lazy import torch + +LOG = logging.getLogger(__name__) + +MPI_MAPPING = EnvMapping( + local_rank=["OMPI_COMM_WORLD_LOCAL_RANK", "PMI_RANK"], + global_rank=["OMPI_COMM_WORLD_RANK", "PMI_RANK"], + world_size=["OMPI_COMM_WORLD_SIZE", "PMI_SIZE"], + master_addr="MASTER_ADDR", + master_port="MASTER_PORT", + init_method="tcp://{master_addr}:{master_port}", +) + + +@cluster_registry.register("mpi") +class MPICluster(MappingCluster): + """MPI cluster that uses MPI environment variables for distributed setup.""" + + def __init__(self, use_mpi_backend: bool = False, **kwargs) -> None: + """Initialise the MPICluster. 
+ + Parameters + ---------- + use_mpi_backend : bool, optional + Use the `mpi` backend in torch, by default False + """ + super().__init__(mapping=MPI_MAPPING, **kwargs) + self._use_mpi_backend = use_mpi_backend + + @classmethod + def used(cls) -> bool: + return bool(MPI_MAPPING.get_env("world_size")) + + @property + def backend(self) -> str: + """Return the backend string for distributed computing.""" + if self._use_mpi_backend: + return "mpi" + return super().backend + + def create_model_comm_group(self) -> "torch.distributed.ProcessGroup | None": + """Create the communication group for model parallelism.""" + if not self._use_mpi_backend: + return super().create_model_comm_group() + + if self.world_size <= 1: + return None + + LOG.debug("Creating model communication group for parallel inference") + group = torch.distributed.init_process_group( + backend=self.backend, + ) + + # Create a new process group for model communication + group = torch.distributed.new_group( + ranks=list(range(self.world_size)), + ) + LOG.info("Model communication group created") + + return group diff --git a/src/anemoi/inference/clusters/slurm.py b/src/anemoi/inference/clusters/slurm.py new file mode 100644 index 00000000..9b409395 --- /dev/null +++ b/src/anemoi/inference/clusters/slurm.py @@ -0,0 +1,103 @@ +# (C) Copyright 2025- ECMWF. +# +# This software is licensed under the terms of the Apache Licence Version 2.0 +# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. +# In applying this licence, ECMWF does not waive the privileges and immunities +# granted to it by virtue of its status as an intergovernmental organisation +# nor does it submit to any jurisdiction. 
+# + +import logging +import os +import socket +import subprocess + +from anemoi.inference.clusters import cluster_registry +from anemoi.inference.clusters.mapping import EnvMapping +from anemoi.inference.clusters.mapping import MappingCluster + +LOG = logging.getLogger(__name__) + +SLURM_MAPPING = EnvMapping( + local_rank="SLURM_LOCALID", + global_rank="SLURM_PROCID", + world_size="SLURM_NTASKS", + master_addr="MASTER_ADDR", + master_port="MASTER_PORT", + init_method="tcp://{master_addr}:{master_port}", +) + + +@cluster_registry.register("slurm") +class SlurmCluster(MappingCluster): + """Slurm cluster that uses SLURM environment variables for distributed setup.""" + + _master_addr: str | None = None + _master_port: int | None = None + + def __init__(self) -> None: + super().__init__(mapping=SLURM_MAPPING) + + @classmethod + def used(cls) -> bool: + # from pytorch lightning + # https://github.com/Lightning-AI/pytorch-lightning/blob/a944e7744e57a5a2c13f3c73b9735edf2f71e329/src/lightning/fabric/plugins/environments/slurm.py + return bool(SLURM_MAPPING.get_env("world_size")) and os.environ.get("SLURM_JOB_NAME") not in ( + "bash", + "interactive", + ) + + @property + def master_addr(self) -> str: + """Return the master address.""" + if self._master_addr is not None: + return self._master_addr + + # Get the master address from the SLURM_NODELIST environment variable + slurm_nodelist = os.environ.get("SLURM_NODELIST") + if not slurm_nodelist: + raise ValueError("SLURM_NODELIST environment variable is not set.") + + # Check if MASTER_ADDR is given, otherwise try set it using 'scontrol' + master_addr = super().master_addr + if not master_addr: + LOG.debug("'MASTER_ADDR' environment variable not set. 
Trying to set via SLURM") + try: + result = subprocess.run( + ["scontrol", "show", "hostname", slurm_nodelist], stdout=subprocess.PIPE, text=True, check=True + ) + except subprocess.CalledProcessError as err: + LOG.error( + "Python could not execute 'scontrol show hostname $SLURM_NODELIST' while calculating MASTER_ADDR. You could avoid this error by setting the MASTER_ADDR env var manually." + ) + raise err + + master_addr = result.stdout.splitlines()[0] + + # Resolve the master hostname to an IP address + try: + master_addr = socket.gethostbyname(master_addr) + except socket.gaierror: + raise ValueError(f"Could not resolve hostname: {master_addr}") + + self._master_addr = master_addr + return master_addr + + @property + def master_port(self) -> int: + """Return the master port.""" + if self._master_port is not None: + return self._master_port + + # Check if MASTER_PORT is given, otherwise generate one based on SLURM_JOBID + master_port = super().master_port + if master_port is None or master_port == 0: + LOG.debug("'MASTER_PORT' environment variable not set. Trying to set via SLURM") + slurm_jobid = os.environ.get("SLURM_JOBID") + if not slurm_jobid: + raise ValueError("SLURM_JOBID environment variable is not set.") + + master_port = 10000 + int(slurm_jobid[-4:]) + + self._master_port = master_port + return master_port diff --git a/src/anemoi/inference/clusters/spawner.py b/src/anemoi/inference/clusters/spawner.py new file mode 100644 index 00000000..9c74321f --- /dev/null +++ b/src/anemoi/inference/clusters/spawner.py @@ -0,0 +1,54 @@ +# (C) Copyright 2025- ECMWF. +# +# This software is licensed under the terms of the Apache Licence Version 2.0 +# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. +# In applying this licence, ECMWF does not waive the privileges and immunities +# granted to it by virtue of its status as an intergovernmental organisation +# nor does it submit to any jurisdiction. 
+# + +from abc import ABC +from abc import abstractmethod +from typing import TYPE_CHECKING +from typing import Callable + +if TYPE_CHECKING: + from anemoi.inference.clusters.client import ComputeClientFactory + from anemoi.inference.config import Configuration + +SPAWN_FUNCTION = Callable[["Configuration", "ComputeClientFactory"], None] + + +class ComputeSpawner(ABC): + """Abstract base class for cluster operations for parallel execution.""" + + @classmethod + @abstractmethod + def used(cls) -> bool: + """Check if this spawner is valid in the current environment.""" + raise NotImplementedError + + @abstractmethod + def spawn(self, fn: SPAWN_FUNCTION, config: "Configuration") -> None: + """Spawn processes for parallel execution. + + Parameters + ---------- + fn : SPAWN_FUNCTION + The function to run in each process. + Expects to receive the configuration and compute client factory as arguments. + config : Configuration + The configuration object for the runner. + """ + raise NotImplementedError + + @abstractmethod + def teardown(self) -> None: + """Tear down the cluster environment.""" + raise NotImplementedError + + def __enter__(self): + return self + + def __exit__(self, exc_type, exc_value, traceback): + self.teardown() diff --git a/src/anemoi/inference/config/run.py b/src/anemoi/inference/config/run.py index 3c4998e6..d36e7d4b 100644 --- a/src/anemoi/inference/config/run.py +++ b/src/anemoi/inference/config/run.py @@ -32,7 +32,7 @@ class RunConfiguration(Configuration): """A path to an Anemoi checkpoint file.""" runner: str | dict[str, Any] = "default" - """The runner to use.""" + """The runner to use. When using `parallel`, the `cluster` options can be set here.""" lead_time: str | int | datetime.timedelta = "10d" """The lead time for the forecast. This can be a string, an integer or a timedelta object. 
@@ -48,9 +48,6 @@ class RunConfiguration(Configuration): use_profiler: bool = False """If True, the inference will be profiled, producing time and memory report.""" - world_size: int | None = 1 - """Number of parallel processes, used for parallel inference without SLURM.""" - report_error: bool = False """If True, the runner list the training versions of the packages in case of error. (Deprecated, unused)""" diff --git a/src/anemoi/inference/runner.py b/src/anemoi/inference/runner.py index d3845ef2..e12f4848 100644 --- a/src/anemoi/inference/runner.py +++ b/src/anemoi/inference/runner.py @@ -271,6 +271,8 @@ def run( except (TypeError, ModuleNotFoundError, AttributeError): self.checkpoint.report_error() raise + finally: + self.complete_forecast_hook() def create_constant_forcings_inputs(self, input_state: State) -> list[Forcings]: @@ -1167,6 +1169,10 @@ def output_state_hook(self, state: State) -> None: """Hook used by coupled runners to send the input state.""" pass + def complete_forecast_hook(self) -> None: + """Hook called at the end of the forecast.""" + pass + def has_split_input(self) -> bool: # To be overridden by a subclass if the we use different inputs # for initial conditions, constants and dynamic forcings diff --git a/src/anemoi/inference/runners/parallel.py b/src/anemoi/inference/runners/parallel.py index 713b4ee2..4e69cc8c 100644 --- a/src/anemoi/inference/runners/parallel.py +++ b/src/anemoi/inference/runners/parallel.py @@ -7,17 +7,17 @@ # granted to it by virtue of its status as an intergovernmental organisation # nor does it submit to any jurisdiction. 
-import datetime import logging import os -import socket -import subprocess +import warnings from typing import Any -from typing import Optional -import numpy as np from anemoi.utils.logs import enable_logging_name +from anemoi.inference.clusters import create_cluster +from anemoi.inference.clusters.client import ComputeClient +from anemoi.inference.clusters.client import ComputeClientFactory +from anemoi.inference.clusters.spawner import ComputeSpawner from anemoi.inference.config import Configuration from anemoi.inference.lazy import torch from anemoi.inference.output import Output @@ -25,36 +25,62 @@ from ..decorators import main_argument from ..outputs import create_output from ..runner import Runner -from ..runners import create_runner from . import runner_registry -from .default import DefaultRunner LOG = logging.getLogger(__name__) -def create_parallel_runner(config: Configuration, pid: int) -> None: +def create_parallel_runner(config: Configuration, client_factory: ComputeClientFactory) -> None: """Creates and runs a parallel runner. Parameters ---------- config : Configuration The configuration object for the runner. - pid : int - The process ID. + client_factory : ComputeClientFactory + The compute client factory to use for distributed inference. 
""" - runner = create_runner(config, pid=pid) + runner_config: dict[str, Any] = config.runner.get("parallel", {}) # type: ignore + if isinstance(runner_config, str): + runner_config = {"base_runner": runner_config} + runner_config["cluster"] = client_factory.create_client() + + runner = ParallelRunnerFactory(config, **runner_config) # type: ignore runner.execute() +class NoOp: + """No operation class used when returning after spawning processes.""" + + def execute(self, *a, **k) -> None: + return None + + @runner_registry.register("parallel") @main_argument("base_runner") class ParallelRunnerFactory: - """Creates a ParallelRunner with a dynamic base class.""" + """Creates a ParallelRunner with a dynamic base class. - def __new__(cls, config: Any, base_runner: str = "default", *args, **kwargs): - assert base_runner != "parallel", "Base runner cannot be `parallel` itself." + Parameters + ---------- + config : Any + The config for the runner. + base_runner : str + The base runner to use for the parallel runner. + Must subclass from at least `DefaultRunner`. + cluster : str | dict[str, str] | ComputeClient | None, optional + The cluster configuration or instance to use for distributed inference, by default None + """ - enable_logging_name(f"rank{int(os.environ.get('SLURM_PROCID',0)):02d}") + def __new__( + cls, + config: Any, + base_runner: str = "default", + *args, + cluster: str | dict[str, str] | ComputeClient | None = None, + **kwargs, + ): + assert base_runner != "parallel", "Base runner cannot be `parallel` itself." try: base_class = runner_registry.lookup(base_runner) @@ -63,10 +89,23 @@ def __new__(cls, config: Any, base_runner: str = "default", *args, **kwargs): assert issubclass(base_class, Runner), f"Base runner '{base_runner}' must be a subclass of Runner." 
- LOG.info(f"Creating ParallelRunner from base runner: {base_runner} ({base_class.__name__})") + LOG.debug(f"Creating ParallelRunner from base runner: {base_runner} ({base_class.__name__})") ParallelRunner = cls.get_class(base_class) - return ParallelRunner(config, *args, **kwargs) + if not isinstance(cluster, (ComputeClient,)): + compute = create_cluster(cluster or {}) + else: + compute = cluster + + if isinstance(compute, ComputeSpawner): + with compute: + compute.spawn(create_parallel_runner, config) + return NoOp() + + compute_client = compute if isinstance(compute, ComputeClient) else compute.create_client() + + LOG.info(f"Using compute client provider: {compute!r}") + return ParallelRunner(config, *args, compute_client=compute_client, **kwargs) @staticmethod def get_class(base_class: Runner): @@ -74,64 +113,72 @@ def get_class(base_class: Runner): return type("ParallelRunner", (ParallelRunnerMixin, base_class), {}) -class ParallelRunnerMixin: +class ParallelRunnerMixin(Runner): """Runner which splits a model over multiple devices. Should be mixed in with a base runner class.""" - def __new__(cls, config, *args, **kwargs): - - if torch.cuda.is_available(): - return super().__new__(cls) - else: - LOG.warning("CUDA is not available. Falling back to DefaultRunner") - return DefaultRunner(config) - - def __init__(self, config: Any, pid: int = 0, **kwargs) -> None: - """Initializes the ParallelRunner. + def __init__(self, config: Any, compute_client: ComputeClient | None = None, **kwargs) -> None: + """Initialises the ParallelRunner. Parameters ---------- config : Any The config for the runner. - pid : int, optional - The process ID, by default 0. 
+ compute_client : ComputeClient, optional + The compute client to use for distributed inference """ super().__init__(config, **kwargs) - self.model_comm_group = None - self.pid = pid + compute_client = compute_client or create_cluster(config.cluster or {}).create_client() # type: ignore + assert isinstance(compute_client, ComputeClient), "Compute client must be an instance of ComputeClient." - # give the base class an opportunity to modify the parallel runner - super()._configure_parallel_runner() + # Set up logging name based on actual cluster rank + enable_logging_name(f"rank{compute_client.global_rank:02d}") + LOG.info(f"{compute_client!r}") - self._bootstrap_processes() + self.compute_client = compute_client - LOG.info( - f"ParallelRunner local/global ranks: {self.local_rank}/{self.global_rank}, host: {socket.gethostname()}" - ) + # give the base class an opportunity to modify the parallel runner + super()._configure_parallel_runner() if self.device.type == "cuda": - self.device = torch.device("cuda", index=self.local_rank) + self.device = torch.device("cuda", index=compute_client.local_rank) torch.cuda.set_device(self.device) - LOG.info(f"ParallelRunner changing to device `{self.device}`") + LOG.debug(f"ParallelRunner changing to device `{self.device}`") else: - LOG.info(f"ParallelRunner device `{self.device}` is unchanged") + LOG.warning(f"ParallelRunner device `{self.device}` is unchanged") + + self.compute_client = compute_client + self.is_master = compute_client.is_master + self.seed(compute_client.process_group) # disable most logging on non-zero ranks - if self.global_rank != 0 and self.verbosity == 0: - LOG.info("ParallelRunner logging disabled on non-zero rank") + if not self.is_master and self.verbosity == 0: + LOG.debug("ParallelRunner logging disabled on non-zero rank") logging.getLogger().setLevel(logging.WARNING) + warnings.filterwarnings("ignore") + + def seed(self, comm_group: "torch.distributed.ProcessGroup | None") -> None: + """Seed all 
processes in the cluster to ensure reproducibility.""" + seed = None + seed_threshold = 1000 + env_var = "ANEMOI_BASE_SEED" - # Create a model comm group for parallel inference - # A dummy comm group is created if only a single device is in use - if self.world_size > 1: - model_comm_group = self._init_parallel() - self.model_comm_group = model_comm_group + if env_var in os.environ: + seed = int(os.environ[env_var]) + if seed < seed_threshold: + seed *= seed_threshold # Ensure seed is sufficiently large - # Ensure each parallel model instance uses the same seed - self._seed_procs() + if self.is_master: + seed = seed or torch.initial_seed() + seed_list = [seed] + torch.distributed.broadcast_object_list(seed_list, src=0, group=comm_group) else: - LOG.warning("ParallelRunner selected but world size of 1 detected") + seed_list = [None] + torch.distributed.broadcast_object_list(seed_list, src=0, group=comm_group) + seed = seed_list[0] + + torch.manual_seed(seed) def predict_step(self, model: Any, input_tensor_torch: "torch.Tensor", **kwargs: Any) -> "torch.Tensor": """Performs a prediction step. @@ -153,17 +200,24 @@ def predict_step(self, model: Any, input_tensor_torch: "torch.Tensor", **kwargs: # call the predict_step of the base class since it might do some modifications # the base class is expected to forward the kwargs (including the comm group) to the model's predict_step method - if self.model_comm_group is None: + if self.compute_client.process_group is None: return super().predict_step(model, input_tensor_torch, **kwargs) else: try: - return super().predict_step(model, input_tensor_torch, model_comm_group=self.model_comm_group, **kwargs) + return super().predict_step( + model, input_tensor_torch, model_comm_group=self.compute_client.process_group, **kwargs + ) except TypeError as err: LOG.error( "Please upgrade to a newer version of anemoi-models (at least version v0.4.2) to use parallel inference. 
If updating breaks your checkpoints, you can try reverting to your original version of anemoi-models and cherry-picking 'https://github.com/ecmwf/anemoi-core/pull/77'" ) raise err + def complete_forecast_hook(self) -> None: + """Hook called at the end of the forecast.""" + super().complete_forecast_hook() + torch.distributed.destroy_process_group() + def create_output(self) -> Output: """Creates the real output on rank 0 and a `none` on the others. @@ -172,233 +226,8 @@ def create_output(self) -> Output: Output The created output. """ - if self.global_rank == 0: + if self.is_master: return super().create_output() else: output = create_output(self, "none") return output - - def __del__(self) -> None: - """Destructor to clean up resources.""" - if self.model_comm_group is not None: - torch.distributed.destroy_process_group() - - def _seed_procs(self) -> None: - """Ensures each process uses the same seed. - Will try read 'ANEMOI_BASE_SEED' from the environment. - Otherwise, the seed of process 0 will be shared to all processes. - """ - - seed = None - seed_threshold = 1000 - env_var_list = ["ANEMOI_BASE_SEED"] - for env_var in env_var_list: - if env_var in os.environ: - seed = int(os.environ.get(env_var)) - if seed < seed_threshold: - seed *= seed_threshold # make it (hopefully) big enough - break - - if self.global_rank == 0: - if seed is None: - seed = torch.initial_seed() - torch.distributed.broadcast_object_list([seed], src=0, group=self.model_comm_group) - else: - msg_buffer = np.array([1], dtype=np.uint64) - torch.distributed.broadcast_object_list(msg_buffer, src=0, group=self.model_comm_group) - seed = msg_buffer[0] - torch.manual_seed(seed) - - def _srun_used(self) -> bool: - """Checks if anemoi-inference was launched with srun. - - Returns - ------- - bool - True if srun is used, False otherwise. 
- """ - # from pytorch lightning - # https://github.com/Lightning-AI/pytorch-lightning/blob/a944e7744e57a5a2c13f3c73b9735edf2f71e329/src/lightning/fabric/plugins/environments/slurm.py - return "SLURM_NTASKS" in os.environ and os.environ.get("SLURM_JOB_NAME") not in ("bash", "interactive") - - def _spawn_parallel_procs(self, num_procs: int) -> None: - """When srun is not available, this method creates N-1 child processes within the same node for parallel inference. - - Parameters - ---------- - num_procs : int - The number of processes to spawn. - """ - LOG.debug(f"spawning {num_procs -1 } procs") - - # check num_procs <= num_gpus - if str(self.device).startswith("cuda"): - num_gpus = torch.cuda.device_count() - if num_procs > num_gpus: - raise ValueError( - f"You requested parallel inference over {num_procs} GPUs but your node only has {num_gpus} GPUs available." - ) - - # Create N-1 procs, each with a unique PID - import torch.multiprocessing as mp - - mp.set_start_method("spawn") - config = self.config - for pid in range(1, num_procs): - mp.Process(target=create_parallel_runner, args=(config, pid)).start() - - def _bootstrap_processes(self) -> None: - """Initialises processes and their network information. - - If srun is available, Slurm variables are read to determine network settings. - Otherwise, local processes are spawned and network info is inferred from the configuration. 
- """ - using_slurm = self._srun_used() - if using_slurm: - - # Determine world size and rank from slurm env vars - global_rank, local_rank, world_size = self._get_parallel_info_from_slurm() - self.global_rank = global_rank - self.local_rank = local_rank - self.world_size = world_size - - # determine master address and port from slurm/override env vars - slurm_addr, slurm_port = self._init_network_from_slurm() - self.master_addr = slurm_addr - self.master_port = slurm_port - - # the world size entry in the config is only needed when not launching via srun - if self.config.world_size != 1: - LOG.warning( - f"world size ({self.config.world_size}) set in the config is ignored because we are launching via srun, using 'SLURM_NTASKS' instead" - ) - else: - # If srun is not available, spawn procs manually on a node - - # Read the config to determine world_size and pid - self.global_rank = self.pid # only inference within a single node is supported when not using srun - self.local_rank = self.pid - self.world_size = self.config.world_size - if self.world_size == 1: - LOG.warning( - "You selected 'runner: parallel' but you have only set 'world_size: 1'. Please update world_size or launch via srun to make use of parallel inference" - ) - if self.world_size <= 0: - raise ValueError( - f"Error. 'world_size' must be greater then 1 to use parallel inference. {self.config.world_size=} set in the config is invalid." 
- ) - - # since we are running within a node, 'localhost' and any port can be used - self.master_addr = "localhost" - # generates a port between 10000 and 19999, based on the nodes hostname (which will be the same across all node-local procs) - import hashlib - - node_name = os.uname().nodename.encode() # Convert to bytes - hash_val = int(hashlib.md5(node_name).hexdigest(), 16) # Convert hash to int - self.master_port = 10000 + (hash_val % 9999) - - # Spawn the other processes manually - if self.local_rank == 0: - self._spawn_parallel_procs(self.world_size) - - def _init_network_from_slurm(self) -> tuple[str, str]: - """Reads Slurm environment to set master address and port for parallel communication. - - Returns - ------- - Tuple[str, str] - The master address and port. - """ - # Get the master address from the SLURM_NODELIST environment variable - slurm_nodelist = os.environ.get("SLURM_NODELIST") - if not slurm_nodelist: - raise ValueError("SLURM_NODELIST environment variable is not set.") - - # Check if MASTER_ADDR is given, otherwise try set it using 'scontrol' - master_addr = os.environ.get("MASTER_ADDR") - if master_addr is None: - LOG.debug("'MASTER_ADDR' environment variable not set. Trying to set via SLURM") - try: - result = subprocess.run( - ["scontrol", "show", "hostname", slurm_nodelist], stdout=subprocess.PIPE, text=True, check=True - ) - except subprocess.CalledProcessError as err: - LOG.error( - "Python could not execute 'scontrol show hostname $SLURM_NODELIST' while calculating MASTER_ADDR. You could avoid this error by setting the MASTER_ADDR env var manually." 
- ) - raise err - - master_addr = result.stdout.splitlines()[0] - - # Resolve the master address using nslookup - try: - master_addr = socket.gethostbyname(master_addr) - except socket.gaierror: - raise ValueError(f"Could not resolve hostname: {master_addr}") - - # Check if MASTER_PORT is given, otherwise generate one based on SLURM_JOBID - master_port = os.environ.get("MASTER_PORT") - if master_port is None: - LOG.debug("'MASTER_PORT' environment variable not set. Trying to set via SLURM") - slurm_jobid = os.environ.get("SLURM_JOBID") - if not slurm_jobid: - raise ValueError("SLURM_JOBID environment variable is not set.") - - master_port = str(10000 + int(slurm_jobid[-4:])) - - # Print the results for confirmation - LOG.debug(f"MASTER_ADDR: {master_addr}") - LOG.debug(f"MASTER_PORT: {master_port}") - - return master_addr, master_port - - def _init_parallel(self) -> Optional["torch.distributed.ProcessGroup"]: - """Creates a model communication group to be used for parallel inference. - - Returns - ------- - Optional[dist.ProcessGroup] - The model communication group. 
- """ - import torch.distributed as dist - - if self.world_size > 1: - - # use 'startswith' instead of '==' in case device is 'cuda:0' - if str(self.device).startswith("cuda"): - backend = "nccl" - else: - if dist.is_mpi_available(): - backend = "mpi" - else: - backend = "gloo" - - dist.init_process_group( - backend=backend, - init_method=f"tcp://{self.master_addr}:{self.master_port}", - timeout=datetime.timedelta(minutes=3), - world_size=self.world_size, - rank=self.global_rank, - ) - LOG.info(f"Creating a model communication group with {self.world_size} devices with the {backend} backend") - - model_comm_group_ranks = np.arange(self.world_size, dtype=int) - model_comm_group = dist.new_group(model_comm_group_ranks) - else: - model_comm_group = None - - return model_comm_group - - def _get_parallel_info_from_slurm(self) -> tuple[int, int, int]: - """Reads Slurm env vars, if they exist, to determine if inference is running in parallel. - - Returns - ------- - Tuple[int, int, int] - The global rank, local rank, and world size. - """ - local_rank = int(os.environ.get("SLURM_LOCALID", 0)) # Rank within a node, between 0 and num_gpus - global_rank = int(os.environ.get("SLURM_PROCID", 0)) # Rank within all nodes - world_size = int(os.environ.get("SLURM_NTASKS", 1)) # Total number of processes - - return global_rank, local_rank, world_size diff --git a/tests/unit/test_clusters.py b/tests/unit/test_clusters.py new file mode 100644 index 00000000..f7ed2443 --- /dev/null +++ b/tests/unit/test_clusters.py @@ -0,0 +1,568 @@ +# (C) Copyright 2025 Anemoi contributors. +# +# This software is licensed under the terms of the Apache Licence Version 2.0 +# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. +# +# In applying this licence, ECMWF does not waive the privileges and immunities +# granted to it by virtue of its status as an intergovernmental organisation +# nor does it submit to any jurisdiction. 
+ +import os +from unittest.mock import MagicMock +from unittest.mock import patch + +import pytest + +from anemoi.inference.clusters import cluster_registry +from anemoi.inference.clusters import create_cluster +from anemoi.inference.clusters.distributed import DistributedCluster +from anemoi.inference.clusters.manual import ManualClient +from anemoi.inference.clusters.manual import ManualSpawner +from anemoi.inference.clusters.mapping import EnvMapping +from anemoi.inference.clusters.mapping import MappingCluster +from anemoi.inference.clusters.mpi import MPICluster +from anemoi.inference.clusters.slurm import SlurmCluster +from anemoi.inference.context import Context + + +@pytest.fixture +def mock_context(): + """Create a mock context for testing.""" + context = MagicMock(spec=Context) + context.device = MagicMock() + context.device.type = "cuda" + context.use_grib_paramid = False + return context + + +class TestManualSpawner: + """Tests for ManualSpawner.""" + + def test_manual_spawner_initialization(self): + """Test ManualSpawner initialization.""" + # ManualSpawner is decorated with @main_argument("world_size") + # So world_size can be passed as the first positional argument + spawner = ManualSpawner(4) + + assert spawner._world_size == 4 + assert spawner._spawned_processes == [] + + def test_manual_spawner_invalid_world_size(self): + """Test ManualSpawner with invalid world_size.""" + with pytest.raises(ValueError, match="world_size must be at least 1"): + ManualSpawner(0) + + with pytest.raises(ValueError, match="world_size must be at least 1"): + ManualSpawner(-1) + + def test_manual_spawner_spawn(self): + """Test ManualSpawner spawn functionality.""" + spawner = ManualSpawner(4, port=12345) + + mock_fn = MagicMock() + mock_config = MagicMock() + + with patch("torch.multiprocessing.Process") as mock_process: + mock_process_instance = MagicMock() + mock_process.return_value = mock_process_instance + + spawner.spawn(mock_fn, mock_config) + + # 
Should spawn world_size processes (ranks 0, 1, 2, 3) + assert mock_process.call_count == 4 + assert len(spawner._spawned_processes) == 4 + + def test_manual_spawner_teardown(self): + """Test ManualSpawner teardown with process cleanup.""" + # Ensure environment marker is not set + with patch.dict(os.environ, {}, clear=True): + spawner = ManualSpawner(4) + + # Create mock processes + mock_processes = [MagicMock() for _ in range(4)] + for mock_proc in mock_processes: + mock_proc.is_alive.return_value = False + + spawner._spawned_processes = mock_processes + + spawner.teardown() + + # All processes should have is_alive checked + for mock_proc in mock_processes: + mock_proc.is_alive.assert_called() + + def test_manual_spawner_teardown_with_alive_processes(self): + """Test ManualSpawner teardown when processes are still alive.""" + # Ensure environment marker is not set + with patch.dict(os.environ, {}, clear=True): + spawner = ManualSpawner(4) + + # Create mock processes that are alive + mock_process = MagicMock() + mock_process.is_alive.return_value = True + mock_process.pid = 12345 + + spawner._spawned_processes = [mock_process] + + spawner.teardown() + + # Should try to join, then terminate + mock_process.join.assert_called() + mock_process.terminate.assert_called() + + def test_manual_spawner_not_used(self): + """Test that ManualSpawner.used() returns False.""" + assert not ManualSpawner.used() + + +class TestManualClient: + """Tests for ManualClient.""" + + def test_manual_client_initialization(self): + """Test ManualClient initialization.""" + client = ManualClient( + world_size=4, + local_rank=0, + global_rank=0, + master_addr="localhost", + master_port=12345, + ) + + assert client.world_size == 4 + assert client.global_rank == 0 + assert client.local_rank == 0 + assert client.master_addr == "localhost" + assert client.master_port == 12345 + + def test_manual_client_different_ranks(self): + """Test ManualClient with different process ranks.""" + client = 
ManualClient( + world_size=4, + local_rank=1, + global_rank=1, + master_addr="localhost", + master_port=12345, + ) + + assert client.global_rank == 1 + assert not client.is_master + + def test_manual_client_used(self): + """Test ManualClient.used() detection.""" + # ManualClient is now always available (returns True) since it's explicitly instantiated + assert ManualClient.used() + + def test_manual_client_repr(self): + """Test ManualClient string representation.""" + client = ManualClient( + world_size=4, + local_rank=2, + global_rank=2, + master_addr="localhost", + master_port=12345, + ) + repr_str = repr(client) + + assert "ManualClient" in repr_str + assert "world_size=4" in repr_str + assert "global_rank=2" in repr_str + + +class TestSlurmCluster: + """Tests for SlurmCluster.""" + + def test_slurm_cluster_used_detection(self): + """Test SlurmCluster.used() detection.""" + # Not in Slurm environment + with patch.dict(os.environ, {}, clear=True): + assert not SlurmCluster.used() + + # In Slurm environment + with patch.dict(os.environ, {"SLURM_NTASKS": "4", "SLURM_JOB_NAME": "test_job"}): + assert SlurmCluster.used() + + # Slurm but interactive shell (should not be used) + with patch.dict(os.environ, {"SLURM_NTASKS": "4", "SLURM_JOB_NAME": "bash"}): + assert not SlurmCluster.used() + + def test_slurm_cluster_initialization(self): + """Test SlurmCluster initialization.""" + with patch.dict( + os.environ, + { + "SLURM_NTASKS": "8", + "SLURM_PROCID": "3", + "SLURM_LOCALID": "1", + "SLURM_NODELIST": "node001", + "SLURM_JOBID": "12345", + "MASTER_ADDR": "192.168.1.1", + "MASTER_PORT": "29500", + }, + ): + cluster = SlurmCluster() + + assert cluster.world_size == 8 + assert cluster.global_rank == 3 + assert cluster.local_rank == 1 + assert cluster.master_addr == "192.168.1.1" + assert cluster.master_port == 29500 + + def test_slurm_cluster_master_addr_from_nodelist(self): + """Test SlurmCluster master_addr resolution from SLURM_NODELIST.""" + with patch.dict( + 
os.environ, + { + "SLURM_NTASKS": "4", + "SLURM_PROCID": "0", + "SLURM_LOCALID": "0", + "SLURM_NODELIST": "node[001-004]", + "SLURM_JOBID": "12345", + }, + clear=True, + ): + with patch("subprocess.run") as mock_run: + mock_run.return_value.stdout = "node001\nnode002\nnode003\nnode004\n" + + with patch("socket.gethostbyname", return_value="192.168.1.1"): + cluster = SlurmCluster() + + assert cluster.master_addr == "192.168.1.1" + mock_run.assert_called_once() + + def test_slurm_cluster_master_port_from_jobid(self): + """Test SlurmCluster master_port generation from SLURM_JOBID.""" + with patch.dict( + os.environ, + { + "SLURM_NTASKS": "4", + "SLURM_PROCID": "0", + "SLURM_LOCALID": "0", + "SLURM_NODELIST": "node001", + "SLURM_JOBID": "98765", + }, + clear=True, + ): + with patch("subprocess.run") as mock_run: + mock_run.return_value.stdout = "node001\n" + + with patch("socket.gethostbyname", return_value="192.168.1.1"): + cluster = SlurmCluster() + + # Port should be 10000 + last 4 digits of job ID + expected_port = 10000 + 8765 + assert cluster.master_port == expected_port + + def test_slurm_cluster_scontrol_failure(self): + """Test SlurmCluster when scontrol fails.""" + with patch.dict( + os.environ, + { + "SLURM_NTASKS": "4", + "SLURM_PROCID": "0", + "SLURM_LOCALID": "0", + "SLURM_NODELIST": "node001", + "SLURM_JOBID": "12345", + }, + clear=True, + ): + with patch("subprocess.run", side_effect=Exception("scontrol failed")): + with pytest.raises(Exception, match="scontrol failed"): + cluster = SlurmCluster() + _ = cluster.master_addr + + +class TestMPICluster: + """Tests for MPICluster.""" + + def test_mpi_cluster_used_detection(self): + """Test MPICluster.used() detection.""" + # Not in MPI environment + with patch.dict(os.environ, {}, clear=True): + assert not MPICluster.used() + + # In MPI environment (OpenMPI) + with patch.dict(os.environ, {"OMPI_COMM_WORLD_SIZE": "4"}): + assert MPICluster.used() + + # In MPI environment (PMI) + with patch.dict(os.environ, 
{"PMI_SIZE": "4"}): + assert MPICluster.used() + + def test_mpi_cluster_initialization(self): + """Test MPICluster initialization.""" + with patch.dict( + os.environ, + { + "OMPI_COMM_WORLD_SIZE": "8", + "OMPI_COMM_WORLD_RANK": "3", + "OMPI_COMM_WORLD_LOCAL_RANK": "1", + "MASTER_ADDR": "192.168.1.1", + "MASTER_PORT": "29500", + }, + ): + cluster = MPICluster(use_mpi_backend=True) + + assert cluster.world_size == 8 + assert cluster.global_rank == 3 + assert cluster.local_rank == 1 + assert cluster.master_addr == "192.168.1.1" + assert cluster.master_port == 29500 + assert cluster.backend == "mpi" + + +class TestDistributedCluster: + """Tests for DistributedCluster (torchrun).""" + + def test_distributed_cluster_used_detection(self): + """Test DistributedCluster.used() detection.""" + # Not in distributed environment + with patch.dict(os.environ, {}, clear=True): + assert not DistributedCluster.used() + + # In distributed environment (torchrun) + with patch.dict(os.environ, {"RANK": "3", "WORLD_SIZE": "4"}): + assert DistributedCluster.used() + + def test_distributed_cluster_initialization(self): + """Test DistributedCluster initialization.""" + with patch.dict( + os.environ, + {"WORLD_SIZE": "8", "RANK": "3", "LOCAL_RANK": "1", "MASTER_ADDR": "192.168.1.1", "MASTER_PORT": "29500"}, + ): + cluster = DistributedCluster() + + assert cluster.world_size == 8 + assert cluster.global_rank == 3 + assert cluster.local_rank == 1 + assert cluster.master_addr == "192.168.1.1" + assert cluster.master_port == 29500 + + +class TestMappingCluster: + """Tests for MappingCluster (custom mapping).""" + + def test_mapping_cluster_with_dict(self): + """Test MappingCluster with dict mapping.""" + mapping = { + "local_rank": "MY_LOCAL_RANK", + "global_rank": "MY_GLOBAL_RANK", + "world_size": "MY_WORLD_SIZE", + "master_addr": "MY_MASTER_ADDR", + "master_port": "MY_MASTER_PORT", + "init_method": "tcp://{master_addr}:{master_port}", + } + + with patch.dict( + os.environ, + { + 
"MY_WORLD_SIZE": "8", + "MY_GLOBAL_RANK": "3", + "MY_LOCAL_RANK": "1", + "MY_MASTER_ADDR": "192.168.1.1", + "MY_MASTER_PORT": "29500", + }, + ): + cluster = MappingCluster(mapping=mapping) + + assert cluster.world_size == 8 + assert cluster.global_rank == 3 + assert cluster.local_rank == 1 + assert cluster.master_addr == "192.168.1.1" + assert cluster.master_port == 29500 + + def test_mapping_cluster_with_env_mapping(self): + """Test MappingCluster with EnvMapping object.""" + mapping = EnvMapping( + local_rank="MY_LOCAL_RANK", + global_rank="MY_GLOBAL_RANK", + world_size="MY_WORLD_SIZE", + master_addr="MY_MASTER_ADDR", + master_port="MY_MASTER_PORT", + init_method="env://", + ) + + with patch.dict( + os.environ, + { + "MY_WORLD_SIZE": "4", + "MY_GLOBAL_RANK": "2", + "MY_LOCAL_RANK": "0", + "MY_MASTER_ADDR": "localhost", + "MY_MASTER_PORT": "12345", + }, + ): + cluster = MappingCluster(mapping=mapping) + + assert cluster.world_size == 4 + assert cluster.global_rank == 2 + assert cluster.local_rank == 0 + assert cluster.init_method == "env://" + + def test_mapping_cluster_defaults(self): + """Test MappingCluster with missing environment variables.""" + mapping = EnvMapping( + local_rank="MY_LOCAL_RANK", + global_rank="MY_GLOBAL_RANK", + world_size="MY_WORLD_SIZE", + master_addr="MY_MASTER_ADDR", + master_port="MY_MASTER_PORT", + ) + + with patch.dict(os.environ, {}, clear=True): + cluster = MappingCluster(mapping=mapping) + + # Should use defaults + assert cluster.world_size == 1 + assert cluster.global_rank == 0 + assert cluster.local_rank == 0 + assert cluster.master_addr == "" + assert cluster.master_port == 0 + + def test_mapping_cluster_not_used(self): + """Test that MappingCluster.used() returns False.""" + assert not MappingCluster.used() + + def test_mapping_cluster_list(self): + """Test MappingCluster with list mapping, such that the first found env var is used.""" + mapping = EnvMapping( + local_rank=["MY_LOCAL_RANK_A", "MY_LOCAL_RANK_B"], + 
global_rank=["MY_GLOBAL_RANK_A", "MY_GLOBAL_RANK_B"], + world_size=["MY_WORLD_SIZE_A", "MY_WORLD_SIZE_B"], + master_addr=["MY_MASTER_ADDR_A", "MY_MASTER_ADDR_B"], + master_port=["MY_MASTER_PORT_A", "MY_MASTER_PORT_B"], + init_method="tcp://{master_addr}:{master_port}", + ) + + with patch.dict( + os.environ, + { + "MY_WORLD_SIZE_B": "16", + "MY_GLOBAL_RANK_A": "5", + "MY_LOCAL_RANK_B": "2", + "MY_MASTER_ADDR_A": "192.168.1.1", + "MY_MASTER_PORT_B": "40000", + }, + ): + cluster = MappingCluster(mapping=mapping) + + assert cluster.world_size == 16 + assert cluster.global_rank == 5 + assert cluster.local_rank == 2 + assert cluster.master_addr == "192.168.1.1" + assert cluster.master_port == 40000 + + +class TestClusterRegistry: + """Tests for cluster registry and creation.""" + + def test_cluster_registry_contains_all_clusters(self): + """Test that all cluster types are registered.""" + registered = cluster_registry.registered + + assert "manual" in registered + assert "slurm" in registered + assert "mpi" in registered + assert "distributed" in registered + assert "custom" in registered + + def test_create_cluster_with_config(self): + """Test create_cluster with explicit config.""" + config = {"manual": {"world_size": 4}} + + # Ensure the environment marker is not set + with patch.dict(os.environ, {}, clear=True): + # create_cluster with manual config returns a ManualSpawner + cluster = create_cluster(config) + + assert isinstance(cluster, ManualSpawner) + assert cluster._world_size == 4 + + def test_create_cluster_auto_detection(self): + """Test create_cluster with auto-detection.""" + with patch.dict( + os.environ, + { + "SLURM_NTASKS": "4", + "SLURM_PROCID": "0", + "SLURM_LOCALID": "0", + "SLURM_NODELIST": "node001", + "SLURM_JOBID": "12345", + "SLURM_JOB_NAME": "test_job", + "MASTER_ADDR": "192.168.1.1", + "MASTER_PORT": "29500", + }, + ): + cluster = create_cluster({}) + + assert isinstance(cluster, SlurmCluster) + assert cluster.world_size == 4 + + def 
test_create_cluster_no_suitable_cluster(self): + """Test create_cluster when no suitable cluster found.""" + with patch.dict(os.environ, {}, clear=True): + with pytest.raises(RuntimeError, match="No suitable cluster found"): + create_cluster({}) + + +class TestClusterBase: + """Tests for base Cluster class functionality.""" + + def test_cluster_init_method(self): + """Test cluster init_method property.""" + client = ManualClient( + world_size=4, + local_rank=0, + global_rank=0, + master_addr="localhost", + master_port=12345, + ) + + init_method = client.init_method + assert init_method.startswith("tcp://") + assert "localhost" in init_method + assert str(client.master_port) in init_method + + def test_cluster_backend_cuda(self, mock_context): + """Test cluster backend selection for CUDA.""" + mock_context.device.type = "cuda" + with patch("anemoi.inference.lazy.torch.cuda.is_available", return_value=True): + client = ManualClient( + world_size=4, + local_rank=0, + global_rank=0, + master_addr="localhost", + master_port=12345, + ) + assert client.backend == "nccl" + + def test_cluster_backend_cpu(self): + """Test cluster backend selection for CPU.""" + with patch("anemoi.inference.lazy.torch.cuda.is_available", return_value=False): + client = ManualClient( + world_size=4, + local_rank=0, + global_rank=0, + master_addr="localhost", + master_port=12345, + ) + assert client.backend == "gloo" + + def test_cluster_is_master(self): + """Test cluster is_master property.""" + client0 = ManualClient( + world_size=4, + local_rank=0, + global_rank=0, + master_addr="localhost", + master_port=12345, + ) + assert client0.is_master + + client1 = ManualClient( + world_size=4, + local_rank=1, + global_rank=1, + master_addr="localhost", + master_port=12345, + ) + assert not client1.is_master