from concurrent.futures import ThreadPoolExecutor
from functools import partial
from typing import Any, Callable, Dict, List

import torch

from areal.api.alloc_mode import ParallelStrategy
from areal.api.cli_args import TrainEngineConfig
from areal.api.controller_api import DistributedBatch, TrainController
from areal.api.engine_api import TrainEngine
from areal.api.io_struct import (
    AllocationMode,
    FinetuneSpec,
    ParamSpec,
    SaveLoadMeta,
    WeightUpdateMeta,
)
from areal.api.scheduler_api import Job, Scheduler, ScheduleStrategy, Worker
from areal.controller.utils import create_engine_with_retry, rpc_call
from areal.utils import logging

logger = logging.getLogger("DistributedTrainController")


class DistributedTrainController(TrainController):
    """Controller that creates train-engine workers via a scheduler and drives them over RPC."""

    def __init__(
        self, train_engine: TrainEngine, config: TrainEngineConfig, scheduler: Scheduler
    ):
        super().__init__(train_engine, config, scheduler)

        self.role: str = "train"
        self.group_size = 0
        self.alloc_mode: AllocationMode
        self.uid: str
        self.workers: List[Worker]
        # Per-worker rank information, populated in initialize().
        self.rank_info: Dict[int, Any] = {}

    # TODO: delete this method
    def create_process_group(self, parallel_strategy: ParallelStrategy | None = None):
        assert self.workers is not None
        rpc_call(
            self.scheduler, self.workers, "create_process_group", parallel_strategy
        )

    def initialize(
        self,
        alloc_mode_str: str,
        ft_spec: FinetuneSpec,
        schedule_strategy: ScheduleStrategy,
    ):
        """Initialize environments for distributed training and load models."""
        self.alloc_mode = AllocationMode.from_str(alloc_mode_str)
        self.ft_spec = ft_spec

        job = Job(
            replicas=self.alloc_mode.train.world_size,
            tasks=self.train_engine.get_scheduling_config(),
            schedule_strategy=schedule_strategy,
            role=self.role,
        )

        logger.info(f"Creating job: {job}")

        self.uid = self.scheduler.create_workers(job)
        self.workers = self.scheduler.get_workers(self.role, timeout=1800)

        # Create one engine per worker concurrently; each call retries until the
        # worker becomes reachable.
        with ThreadPoolExecutor(max_workers=len(self.workers)) as executor:
            futures = [
                executor.submit(
                    partial(
                        create_engine_with_retry,
                        self.scheduler.create_engine,
                        worker.id,
                        self.train_engine,
                        None,
                        self.ft_spec,
                    )
                )
                for worker in self.workers
            ]
            for worker_index, future in enumerate(futures):
                try:
                    rank_info = future.result()
                except Exception as e:
                    raise RuntimeError(
                        f"Failed to initialize worker_index: {worker_index}"
                    ) from e
                self.rank_info[worker_index] = rank_info
                logger.info(f"worker_index: {worker_index}, rank_info: {rank_info}")

    def destroy(self):
        self.scheduler.delete_workers()

    def train(self, mode: bool = True):
        pass

    def upload_weights(self, meta: WeightUpdateMeta):
        pass

    def get_param_specs(
        self, weight_chunked_mem_mb: int = 1024
    ) -> List[List[ParamSpec]]:
        pass

    def set_version(self, version: int):
        pass

    def get_version(self) -> int:
        pass

    def save(self, meta: SaveLoadMeta):
        pass

    def load(self, meta: SaveLoadMeta):
        pass

    def step_lr_scheduler(self):
        pass

    def train_batch(
        self,
        input_: DistributedBatch,
        loss_fn: Callable[[torch.Tensor, Dict[str, Any]], torch.Tensor],
        loss_weight_fn: Callable[[Dict[str, Any]], torch.Tensor],
    ) -> Dict[str, float]:
        pass

    def eval_batch(
        self,
        input_: DistributedBatch,
        loss_fn: Callable[[torch.Tensor, Dict[str, Any]], torch.Tensor],
        loss_weight_fn: Callable[[Dict[str, Any]], torch.Tensor],
    ) -> torch.Tensor | None:
        pass

    def forward(
        self,
        input_: DistributedBatch,
        output_seqlens: List[int] | None = None,
        post_hook: Callable[[torch.Tensor, Dict[str, Any]], Any] | None = None,
        aggregate_fn: Callable[[List[Any]], Any] = torch.cat,
    ) -> Any | None:
        pass
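

# --- Hypothetical usage sketch (not part of this module) ---
# A minimal illustration of how this controller might be wired together. It
# assumes a concrete TrainEngine implementation (`MyTrainEngine`), a concrete
# Scheduler (`MyScheduler`), and an allocation-mode string accepted by
# AllocationMode.from_str(); those names and values are placeholders for
# illustration, not APIs confirmed by this file.
#
#   engine = MyTrainEngine(config)                      # assumed TrainEngine impl
#   controller = DistributedTrainController(engine, config, MyScheduler())
#   controller.initialize(
#       alloc_mode_str=...,                             # parsed by AllocationMode.from_str()
#       ft_spec=FinetuneSpec(...),
#       schedule_strategy=ScheduleStrategy(...),
#   )
#   # ... drive training via train_batch()/eval_batch() once they are implemented ...
#   controller.destroy()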