diff --git a/areal/api/controller_api.py b/areal/api/controller_api.py
index 1df593161..5ce5ec5dc 100644
--- a/areal/api/controller_api.py
+++ b/areal/api/controller_api.py
@@ -315,7 +315,7 @@ def set_version(self, version: int):
         """
         raise NotImplementedError()
 
-    def get_version(self) -> int:
+    def get_version(self) -> List[int]:
         """Get the current weight version in the training engine.
 
         Returns
@@ -359,7 +359,7 @@ def train_batch(
         input_: DistributedBatch,
         loss_fn: Callable[[torch.Tensor, Dict[str, Any]], torch.Tensor],
         loss_weight_fn: Callable[[Dict[str, Any]], torch.Tensor],
-    ) -> Dict[str, float]:
+    ) -> List[Dict[str, float]]:
         """Update the model with a batch of data and a loss function.
 
         Note
@@ -382,7 +382,7 @@ def train_batch(
 
         Returns
         -------
-        Dict[str, float]
+        List[Dict[str, float]]
             Scalar statistics after training, e.g., the current learning rate,
             gradient norm, etc.
         """
@@ -394,7 +394,7 @@ def eval_batch(
         input_: DistributedBatch,
         loss_fn: Callable[[torch.Tensor, Dict[str, Any]], torch.Tensor],
         loss_weight_fn: Callable[[Dict[str, Any]], torch.Tensor],
-    ) -> torch.Tensor | None:
+    ) -> List[torch.Tensor]:
         """Evaluate the model using the forward pass and loss function.
 
         Note
diff --git a/areal/api/engine_api.py b/areal/api/engine_api.py
index 021c0a2ea..cb82552dc 100644
--- a/areal/api/engine_api.py
+++ b/areal/api/engine_api.py
@@ -25,6 +25,8 @@ class Scheduling:
     cpu: int
     gpu: int
     mem: int
+    port_count: int
+    cmd: str | None = None
     nodelist: str | None = None
     exclude: str | None = None
     partition: str | None = None
@@ -138,7 +140,7 @@ def parallelism_group(self) -> dist.ProcessGroup:
         """
         raise NotImplementedError()
 
-    def get_scheduling_config(self) -> Scheduling:
+    def get_scheduling_config(self) -> List[Scheduling]:
         """Get the scheduling configuration for the engine.
 
         This includes configuration such as container image, CPU/GPU/memory size.
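
Since the controller methods above now fan out to one engine per worker, callers get back a list with one entry per engine instead of a single value. A minimal, self-contained sketch of how a caller might aggregate the new `List[Dict[str, float]]` returned by `train_batch`; the `mean_stats` helper and the sample values are illustrative and not part of this PR:

    from typing import Dict, List

    def mean_stats(per_worker: List[Dict[str, float]]) -> Dict[str, float]:
        # Average each scalar stat across workers; assumes all workers report the same keys.
        keys = per_worker[0].keys()
        return {k: sum(s[k] for s in per_worker) / len(per_worker) for k in keys}

    # Shaped like the new train_batch return value: one stats dict per engine.
    stats = [{"lr": 1e-5, "grad_norm": 0.8}, {"lr": 1e-5, "grad_norm": 1.2}]
    print(mean_stats(stats))  # {'lr': 1e-05, 'grad_norm': 1.0}
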
diff --git a/areal/api/scheduler_api.py b/areal/api/scheduler_api.py
index f7e9fb941..cab700516 100644
--- a/areal/api/scheduler_api.py
+++ b/areal/api/scheduler_api.py
@@ -1,47 +1,40 @@
 import abc
 from dataclasses import dataclass, field
-from typing import Dict, List
+from typing import List, Literal
+
+from areal.api.engine_api import Scheduling
 
 
 @dataclass
 class Worker:
     id: str
     ip: str
-    ports: List[str] = field(default_factory=list)
-
-
-@dataclass
-class ContainerSpec:
-    cpu: int = 0
-    gpu: int = 0
-    mem: int = 0
-    container_image: str = ""
-    cmd: str = ""
-    env_vars: Dict[str, str] = field(default_factory=dict)
-    port_count: int = 2
+    serve_port: str
+    extra_ports: List[str] = field(default_factory=list)
 
 
 @dataclass
 class ScheduleStrategy:
-    type: str = ""
+    type: Literal["colocation", "separation", ""] = ""
     uid: str = ""
 
 
 @dataclass
-class SchedulingConfig:
+class Job:
     replicas: int = 0
-    specs: List[ContainerSpec] = field(default_factory=list)
+    tasks: List[Scheduling] = field(default_factory=list)
     schedule_strategy: ScheduleStrategy | None = None
     role: str = ""
 
 
 class Scheduler(abc.ABC):
-    def create_workers(self, worker_key, scheduler_config, *args, **kwargs) -> str:
+    def create_workers(self, job: Job, *args, **kwargs):
         """
-        Start workers, return job id
+        Start workers
         """
+        raise NotImplementedError()
 
-    def get_workers(self, worker_key, timeout=None) -> List[Worker]:
+    def get_workers(self, role: str, timeout=None) -> List[Worker]:
         """
         Wait and return worker list, including scheduling results such as
         ip and engine ports (worker id, ip, ports)
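
To make the reshaped scheduler API concrete, here is a rough sketch of the objects it now passes around, assuming this PR is applied; the field values are illustrative, and `tasks` would normally come from `engine.get_scheduling_config()`:

    from areal.api.scheduler_api import Job, ScheduleStrategy, Worker

    # What a controller submits: one Job per role, with one Scheduling entry per task.
    job = Job(
        replicas=2,
        tasks=[],  # normally engine.get_scheduling_config()
        schedule_strategy=ScheduleStrategy(type="separation", uid="train-0"),
        role="train",
    )

    # What the scheduler hands back from get_workers(role): one Worker per replica.
    workers = [
        Worker(id=f"train-{i}", ip="127.0.0.1", serve_port=str(8000 + i), extra_ports=["8100"])
        for i in range(job.replicas)
    ]
    print(job.role, [w.serve_port for w in workers])  # train ['8000', '8001']
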
diff --git a/areal/controller/train_controller.py b/areal/controller/train_controller.py
new file mode 100644
index 000000000..99ed02dcb
--- /dev/null
+++ b/areal/controller/train_controller.py
@@ -0,0 +1,193 @@
+from concurrent.futures import ThreadPoolExecutor
+from functools import partial
+from typing import Any, Callable, Dict, List
+
+import torch
+
+from areal.api.alloc_mode import ParallelStrategy
+from areal.api.cli_args import TrainEngineConfig
+from areal.api.controller_api import DistributedBatch, TrainController
+from areal.api.engine_api import TrainEngine
+from areal.api.io_struct import (
+    AllocationMode,
+    FinetuneSpec,
+    ParamSpec,
+    SaveLoadMeta,
+    WeightUpdateMeta,
+)
+from areal.api.scheduler_api import Job, Scheduler, ScheduleStrategy, Worker
+from areal.controller.utils import create_engine_with_retry, rpc_call
+from areal.utils import logging
+from areal.utils.http import wait_future_ordered
+
+logger = logging.getLogger("DistributedTrainController")
+
+
+class DistributedTrainController(TrainController):
+    def __init__(
+        self, train_engine: TrainEngine, config: TrainEngineConfig, scheduler: Scheduler
+    ):
+        super().__init__(train_engine, config, scheduler)
+
+        self.role: str = "train"
+        self.group_size: int
+        self.alloc_mode: AllocationMode
+        self.workers: List[Worker]
+        self.engine_dp_ranks: List[int]
+
+    def create_process_group(self, parallel_strategy: ParallelStrategy | None = None):
+        assert self.workers is not None, "Workers are not created"
+        self.custom_function_call("create_process_group", parallel_strategy)
+
+    def initialize(
+        self,
+        alloc_mode_str: str,
+        ft_spec: FinetuneSpec,
+        schedule_strategy: ScheduleStrategy,
+        group_size: int = 1,
+    ):
+        """Initialize environments for distributed training and load models."""
+        self.alloc_mode = AllocationMode.from_str(alloc_mode_str)
+        self.ft_spec = ft_spec
+        self.group_size = group_size
+
+        job = Job(
+            replicas=self.alloc_mode.train.world_size,
+            tasks=self.train_engine.get_scheduling_config(),
+            schedule_strategy=schedule_strategy,
+            role=self.role,
+        )
+        logger.info(f"Creating job: {job}")
+        self.scheduler.create_workers(job)
+        # Once get_workers returns, the RPC servers on all workers are ready.
+        self.workers = self.scheduler.get_workers(self.role, timeout=1800)
+
+        logger.info("Creating process group")
+        self.create_process_group(self.alloc_mode.train)
+
+        logger.info("Initializing engines")
+        with ThreadPoolExecutor(max_workers=len(self.workers)) as executor:
+            futures = [
+                executor.submit(
+                    partial(
+                        create_engine_with_retry,
+                        self.scheduler.create_engine,
+                        worker.id,
+                        self.train_engine,
+                        None,
+                        self.ft_spec,
+                    )
+                )
+                for worker in self.workers
+            ]
+
+            wait_future_ordered(futures, exit_on_exception=True)
+
+        logger.info("Fetching data-parallel ranks from engines")
+        self.engine_dp_ranks = rpc_call(
+            self.scheduler, self.workers, "data_parallel_rank"
+        )
+        logger.info("Train engines initialized successfully!")
+
+    def destroy(self):
+        self.scheduler.delete_workers()
+
+    def train(self, mode: bool = True):
+        self.custom_function_call("train", mode)
+
+    def upload_weights(self, meta: WeightUpdateMeta):
+        self.custom_function_call("upload_weights", meta)
+
+    def get_param_specs(
+        self, weight_chunked_mem_mb: int = 1024
+    ) -> List[List[ParamSpec]]:
+        ret: List[List[List[ParamSpec]]] = self.custom_function_call(
+            "get_param_specs", weight_chunked_mem_mb
+        )
+        # Flatten the per-worker results into a single list of param-spec groups.
+        flattened = [inner for outer in ret for inner in outer]
+        return flattened
+
+    def set_version(self, version: int):
+        return self.custom_function_call("set_version", version)
+
+    def get_version(self) -> List[int]:
+        return self.custom_function_call("get_version")
+
+    def save(self, meta: SaveLoadMeta):
+        self.custom_function_call("save", meta)
+
+    def load(self, meta: SaveLoadMeta):
+        self.custom_function_call("load", meta)
+
+    def step_lr_scheduler(self):
+        self.custom_function_call("step_lr_scheduler")
+
+    def custom_function_call(self, method: str, *args, **kwargs):
+        return rpc_call(self.scheduler, self.workers, method, None, *args, **kwargs)
+
+    def _align_batches_with_dp(
+        self, input_: DistributedBatch, rebalance=True
+    ) -> List[DistributedBatch]:
+        if rebalance:
+            inputs = input_.chunk_by_ffd(self.group_size, self.alloc_mode.train.dp_size)
+        else:
+            inputs = input_.chunk(self.alloc_mode.train.dp_size)
+
+        # Each engine receives the chunk matching its data-parallel rank, so engines
+        # in the same DP group see identical data.
+        batches = []
+        for dp_rank in self.engine_dp_ranks:
+            batches.append(inputs[dp_rank])
+
+        return batches
+
+    def train_batch(
+        self,
+        input_: DistributedBatch,
+        loss_fn: Callable[[torch.Tensor, Dict[str, Any]], torch.Tensor],
+        loss_weight_fn: Callable[[Dict[str, Any]], torch.Tensor],
+    ) -> List[Dict[str, float]]:
+        batches = self._align_batches_with_dp(input_, True)
+        train_stats = rpc_call(
+            self.scheduler,
+            self.workers,
+            "train_batch",
+            batches,
+            loss_fn,
+            loss_weight_fn,
+        )
+
+        return train_stats
+
+    def eval_batch(
+        self,
+        input_: DistributedBatch,
+        loss_fn: Callable[[torch.Tensor, Dict[str, Any]], torch.Tensor],
+        loss_weight_fn: Callable[[Dict[str, Any]], torch.Tensor],
+    ) -> List[torch.Tensor]:
+        batches = self._align_batches_with_dp(input_, True)
+        eval_stats = rpc_call(
+            self.scheduler, self.workers, "eval_batch", batches, loss_fn, loss_weight_fn
+        )
+
+        return eval_stats
+
+    def forward(
+        self,
+        input_: DistributedBatch,
+        output_seqlens: List[int] | None = None,
+        post_hook: Callable[[torch.Tensor, Dict[str, Any]], Any] | None = None,
+        aggregate_fn: Callable[[List[Any]], Any] = torch.cat,
+    ) -> List[Any]:
+        batches = self._align_batches_with_dp(input_, False)
+        forward_stats = rpc_call(
+            self.scheduler,
+            self.workers,
+            "forward",
+            batches,
+            output_seqlens,
+            post_hook,
+            aggregate_fn,
+        )
+
+        return forward_stats
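
The batch-alignment logic in `_align_batches_with_dp` is the least obvious part of the controller: the batch is split into `dp_size` chunks, and every engine then receives the chunk matching its data-parallel rank, so engines that share a DP group (e.g. tensor-parallel peers) see identical data. A toy, list-based stand-in (plain Python lists instead of `DistributedBatch`, round-robin splitting instead of `chunk_by_ffd`; all names are illustrative):

    from typing import List

    def align_with_dp(items: List[int], dp_size: int, engine_dp_ranks: List[int]) -> List[List[int]]:
        # Stand-in for DistributedBatch.chunk: split into dp_size chunks (here, round-robin).
        chunks = [items[i::dp_size] for i in range(dp_size)]
        # One chunk per engine, indexed by that engine's data-parallel rank.
        return [chunks[rank] for rank in engine_dp_ranks]

    # Four engines, two DP groups (e.g. TP=2): engines with the same DP rank get the same chunk.
    print(align_with_dp(list(range(8)), dp_size=2, engine_dp_ranks=[0, 0, 1, 1]))
    # [[0, 2, 4, 6], [0, 2, 4, 6], [1, 3, 5, 7], [1, 3, 5, 7]]
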
diff --git a/areal/controller/utils.py b/areal/controller/utils.py
new file mode 100644
index 000000000..63b427f56
--- /dev/null
+++ b/areal/controller/utils.py
@@ -0,0 +1,95 @@
+import time
+from concurrent.futures import ThreadPoolExecutor
+from typing import Any, List, Optional
+
+from requests.exceptions import ConnectionError
+
+from areal.api.scheduler_api import Scheduler, Worker
+from areal.utils import logging
+from areal.utils.http import wait_future_ordered
+
+logger = logging.getLogger("ControllerUtil")
+
+
+def create_engine_with_retry(
+    create_engine_func, *args, max_retries=60, retry_delay=10, **kwargs
+):
+    """
+    Attempt to create an engine, retrying while the remote worker is not yet reachable.
+
+    :param create_engine_func: Callable used to create the engine.
+    :param args: Positional arguments to pass to create_engine_func.
+    :param max_retries: Maximum number of retries before giving up (keyword-only).
+    :param retry_delay: Seconds to wait between retries (keyword-only).
+    :param kwargs: Keyword arguments to pass to create_engine_func.
+    :return: Engine instance created by create_engine_func.
+    :raises RuntimeError: If the connection still fails after max_retries attempts.
+    """
+    logger.info(
+        f"Create engine with retry: max_retries={max_retries}, retry_delay={retry_delay}, "
+        f"args={args}, kwargs={kwargs}"
+    )
+    retries = 0
+    while retries < max_retries:
+        try:
+            return create_engine_func(*args, **kwargs)
+        except ConnectionError as e:
+            logger.info(
+                f"Worker is not ready, exception: {e}, retrying in {retry_delay} seconds..."
+            )
+            time.sleep(retry_delay)
+            retries += 1
+        except Exception as e:
+            logger.error(f"Engine creation failed with an unexpected exception: {e}")
+            raise
+
+    raise RuntimeError("Failed to connect to remote service after maximum retries.")
+
+
+def rpc_call(
+    scheduler: Scheduler,
+    workers: List[Worker],
+    method: str,
+    batches: Optional[List[Any]] = None,
+    *args,
+    **kwargs,
+) -> List[Any]:
+    """
+    Utility method: perform concurrent RPC calls to multiple workers.
+
+    :param scheduler: Scheduler object with a call_engine(worker_id, method, *args, **kwargs) method.
+    :param workers: List of worker instances. Each worker must have an 'id' attribute.
+    :param method: Name of the method to invoke on each worker.
+    :param batches: Optional list of batches; each batch is passed to the corresponding worker.
+        If provided, its length must match the number of workers.
+    :param args: Positional arguments to pass to call_engine.
+    :param kwargs: Keyword arguments to pass to call_engine.
+    :return: List of results returned in the order of workers.
+    :raises ValueError: If batches is provided but its length does not match the number of workers.
+    :raises RuntimeError: If any exception occurs during RPC execution.
+    """
+    if batches is not None and len(batches) != len(workers):
+        raise ValueError(
+            f"Batches length ({len(batches)}) must match workers count ({len(workers)})"
+        )
+    logger.info(f"Start RPC call, method: {method}")
+
+    with ThreadPoolExecutor(max_workers=len(workers)) as executor:
+        futures = []
+        for i, worker in enumerate(workers):
+            # Build the call arguments for this worker.
+            if batches is not None:
+                # When batches are given, pass this worker's batch as the first positional argument.
+                worker_args = (batches[i],) + args
+                future = executor.submit(
+                    scheduler.call_engine, worker.id, method, *worker_args, **kwargs
+                )
+            else:
+                future = executor.submit(
+                    scheduler.call_engine, worker.id, method, *args, **kwargs
+                )
+            futures.append(future)
+        try:
+            results = wait_future_ordered(futures, exit_on_exception=True)
+        except Exception as e:
+            raise RuntimeError(f"{method} failed, error: {e}") from e
+
+    return results
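
A small self-contained sketch of the `rpc_call` fan-out, using a fake scheduler whose `call_engine` just echoes its arguments (the `FakeScheduler` class is purely illustrative); it shows how `batches[i]` is prepended to the positional arguments sent to worker `i`:

    from areal.api.scheduler_api import Worker
    from areal.controller.utils import rpc_call

    class FakeScheduler:
        # Only the call_engine(worker_id, method, *args, **kwargs) hook that rpc_call relies on.
        def call_engine(self, worker_id, method, *args, **kwargs):
            return f"{worker_id}.{method}{args}"

    workers = [Worker(id=f"w{i}", ip="127.0.0.1", serve_port=str(8000 + i)) for i in range(2)]
    # batches[i] goes to workers[i]; results come back in worker order.
    print(rpc_call(FakeScheduler(), workers, "train_batch", ["b0", "b1"], "loss_fn"))
    # ["w0.train_batch('b0', 'loss_fn')", "w1.train_batch('b1', 'loss_fn')"]
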
+ """ + + if batches is not None and len(batches) != len(workers): + raise ValueError( + f"Batches length ({len(batches)}) must match workers count ({len(workers)})" + ) + logger.info(f"Start to rpc call, method: {method}") + + with ThreadPoolExecutor(max_workers=len(workers)) as executor: + futures = [] + for i, worker in enumerate(workers): + # 构建调用参数 + if batches is not None: + # 当有batch参数时:将batch作为第一位置参数 + worker_args = (batches[i],) + args + future = executor.submit( + scheduler.call_engine, worker.id, method, *worker_args, **kwargs + ) + else: + future = executor.submit( + scheduler.call_engine, worker.id, method, *args, **kwargs + ) + futures.append(future) + try: + results = wait_future_ordered(futures, exit_on_exception=True) + except Exception as e: + raise RuntimeError(f"{method} failed, error: {e}") + + return results diff --git a/areal/utils/http.py b/areal/utils/http.py index 5dc88f1df..140e7e474 100644 --- a/areal/utils/http.py +++ b/areal/utils/http.py @@ -1,6 +1,10 @@ import asyncio +import os +import signal +import traceback +from concurrent.futures import Future, as_completed from http import HTTPStatus -from typing import Any, Dict, Optional +from typing import Any, Dict, List, Optional import aiohttp @@ -96,3 +100,31 @@ def response_ok(http_code: int) -> bool: def response_retryable(http_code: int) -> bool: return http_code == HTTPStatus.REQUEST_TIMEOUT + + +def wait_future_ordered( + futures: List[Future], exit_on_exception: bool = False +) -> List[Any]: + """ + Waits for a list of futures to complete and returns the results in the order the futures were submitted. + :param futures: List of Future objects to wait for. + :param exit_on_exception: If True, terminate the process upon an exception in any future. + If False, raise the exception. + :return: List of results in the same order as the input futures. + :raises Exception: If exit_on_exception is False and any future raises an exception. + """ + results = [None] * len(futures) + future_index_map = {future: i for i, future in enumerate(futures)} + for future in as_completed(futures): + index = future_index_map[future] + try: + results[index] = future.result() + except Exception as e: + logger.warning(f"Exception caught when waiting for future: {e}") + logger.warning(traceback.format_exc()) + if exit_on_exception: + logger.info("Exiting due to exception in future.") + os.kill(os.getpid(), signal.SIGTERM) + else: + raise e + return results