8 changes: 4 additions & 4 deletions areal/api/controller_api.py
@@ -315,7 +315,7 @@ def set_version(self, version: int):
         """
         raise NotImplementedError()
 
-    def get_version(self) -> int:
+    def get_version(self) -> List[int]:
         """Get the current weight version in the training engine.
 
         Returns
@@ -359,7 +359,7 @@ def train_batch(
         input_: DistributedBatch,
         loss_fn: Callable[[torch.Tensor, Dict[str, Any]], torch.Tensor],
         loss_weight_fn: Callable[[Dict[str, Any]], torch.Tensor],
-    ) -> Dict[str, float]:
+    ) -> List[Dict[str, float]]:
         """Update the model with a batch of data and a loss function.
 
         Note
@@ -382,7 +382,7 @@
 
         Returns
         -------
-        Dict[str, float]
+        List[Dict[str, float]]
             Scalar statistics after training, e.g., the current learning rate,
             gradient norm, etc.
         """
@@ -394,7 +394,7 @@ def eval_batch(
         input_: DistributedBatch,
         loss_fn: Callable[[torch.Tensor, Dict[str, Any]], torch.Tensor],
         loss_weight_fn: Callable[[Dict[str, Any]], torch.Tensor],
-    ) -> torch.Tensor | None:
+    ) -> List[torch.Tensor]:
         """Evaluate the model using the forward pass and loss function.
 
         Note
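These signature changes reflect the new controller semantics: the controller now fans each call out to every training engine and returns one result per engine instead of a single value. As a minimal sketch of a caller consuming the list-valued results (the controller, batch, and loss callables are placeholders, and the "grad_norm" key is only an example of a returned stat):

    # Hypothetical caller of the list-returning controller API.
    stats_per_engine = controller.train_batch(batch, loss_fn, loss_weight_fn)
    # One Dict[str, float] per engine; reduce a scalar stat across engines.
    mean_grad_norm = sum(s["grad_norm"] for s in stats_per_engine) / len(stats_per_engine)

    # get_version() now returns List[int], one weight version per engine.
    versions = controller.get_version()
    assert len(set(versions)) == 1, "engines should agree on the weight version"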
4 changes: 3 additions & 1 deletion areal/api/engine_api.py
@@ -25,6 +25,8 @@ class Scheduling:
     cpu: int
     gpu: int
     mem: int
+    port_count: int
+    cmd: str | None = None
     nodelist: str | None = None
     exclude: str | None = None
     partition: str | None = None
@@ -138,7 +140,7 @@ def parallelism_group(self) -> dist.ProcessGroup:
         """
         raise NotImplementedError()
 
-    def get_scheduling_config(self) -> Scheduling:
+    def get_scheduling_config(self) -> List[Scheduling]:
         """Get the scheduling configuration for the engine.
 
         This includes configuration such as container image, CPU/GPU/memory size.
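With port_count and cmd added, an engine can tell the scheduler how many ports each task needs and how to launch it. A small illustrative construction (the concrete values, the memory unit, and the launch command are assumptions, not taken from this diff; fields not visible in the hunk are left at their defaults):

    from areal.api.engine_api import Scheduling

    task = Scheduling(
        cpu=16,
        gpu=1,
        mem=65536,                  # memory size; MB assumed
        port_count=2,               # new: ports to allocate per worker
        cmd="python -m my_worker",  # new, optional launch command (hypothetical)
    )

Since get_scheduling_config() now returns List[Scheduling], an engine may also describe several task specs within one job.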
31 changes: 12 additions & 19 deletions areal/api/scheduler_api.py
@@ -1,47 +1,40 @@
 import abc
 from dataclasses import dataclass, field
-from typing import Dict, List
+from typing import List, Literal
+
+from areal.api.engine_api import Scheduling
 
 
 @dataclass
 class Worker:
     id: str
     ip: str
-    ports: List[str] = field(default_factory=list)
-
-
-@dataclass
-class ContainerSpec:
-    cpu: int = 0
-    gpu: int = 0
-    mem: int = 0
-    container_image: str = ""
-    cmd: str = ""
-    env_vars: Dict[str, str] = field(default_factory=dict)
-    port_count: int = 2
+    serve_port: str
+    extra_ports: List[str] = field(default_factory=list)
 
 
 @dataclass
 class ScheduleStrategy:
-    type: str = ""
+    type: Literal["colocation", "separation", ""] = ""
     uid: str = ""
 
 
 @dataclass
-class SchedulingConfig:
+class Job:
     replicas: int = 0
-    specs: List[ContainerSpec] = field(default_factory=list)
+    tasks: List[Scheduling] = field(default_factory=list)
     schedule_strategy: ScheduleStrategy | None = None
     role: str = ""
 
 
 class Scheduler(abc.ABC):
-    def create_workers(self, worker_key, scheduler_config, *args, **kwargs) -> str:
+    def create_workers(self, job: Job, *args, **kwargs):
[Inline review comment from a Contributor, severity: medium]

The create_workers method is missing a return type annotation. Based on its usage in areal/controller/train_controller.py (where its return value is assigned to self.uid), it is expected to return a string representing the job ID. Adding the -> str type hint improves code clarity and enables better static analysis.

Suggested change:
-    def create_workers(self, job: Job, *args, **kwargs):
+    def create_workers(self, job: Job, *args, **kwargs) -> str:
"""
Start workers, return job id
Start workers
"""
raise NotImplementedError()

def get_workers(self, worker_key, timeout=None) -> List[Worker]:
def get_workers(self, role: str, timeout=None) -> List[Worker]:
"""
Wait and return worker list, including scheduling results such as ip and engine ports
(worker id, ip, ports)
Expand Down
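Putting the reworked scheduler API together, creating and resolving workers might read as in the sketch below; scheduler and engine stand in for concrete implementations, and the flow mirrors DistributedTrainController.initialize further down:

    from areal.api.scheduler_api import Job, ScheduleStrategy

    job = Job(
        replicas=4,                             # number of workers to launch
        tasks=engine.get_scheduling_config(),   # List[Scheduling], one per task
        schedule_strategy=ScheduleStrategy(type="separation", uid="demo"),
        role="train",
    )
    scheduler.create_workers(job)
    workers = scheduler.get_workers("train", timeout=1800)
    for w in workers:
        print(w.id, w.ip, w.serve_port, w.extra_ports)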
193 changes: 193 additions & 0 deletions areal/controller/train_controller.py
@@ -0,0 +1,193 @@
from concurrent.futures import ThreadPoolExecutor
from functools import partial
from typing import Any, Callable, Dict, List

import torch

from areal.api.alloc_mode import ParallelStrategy
from areal.api.cli_args import TrainEngineConfig
from areal.api.controller_api import DistributedBatch, TrainController
from areal.api.engine_api import TrainEngine
from areal.api.io_struct import (
AllocationMode,
FinetuneSpec,
ParamSpec,
SaveLoadMeta,
WeightUpdateMeta,
)
from areal.api.scheduler_api import Job, Scheduler, ScheduleStrategy, Worker
from areal.controller.utils import create_engine_with_retry, rpc_call
from areal.utils import logging
from areal.utils.http import wait_future_ordered

logger = logging.getLogger("DistributedTrainController")


class DistributedTrainController(TrainController):
def __init__(
self, train_engine: TrainEngine, config: TrainEngineConfig, scheduler: Scheduler
):
super().__init__(train_engine, config, scheduler)

self.role: str = "train"
        # The following attributes are populated in initialize().
        self.group_size: int
        self.alloc_mode: AllocationMode
        self.workers: List[Worker]
        self.engine_dp_ranks: List[int]

def create_process_group(self, parallel_strategy: ParallelStrategy | None = None):
assert self.workers is not None, "Workers are not created"
self.custom_function_call("create_process_group", parallel_strategy)

def initialize(
self,
alloc_mode_str: str,
ft_spec: FinetuneSpec,
schedule_strategy: ScheduleStrategy,
group_size: int = 1,
):
"""Initialize environments for distributed training and load models."""
self.alloc_mode = AllocationMode.from_str(alloc_mode_str)
self.ft_spec = ft_spec
self.group_size = group_size

job = Job(
replicas=self.alloc_mode.train.world_size,
tasks=self.train_engine.get_scheduling_config(),
schedule_strategy=schedule_strategy,
role=self.role,
)
logger.info(f"Start to create job: {job}")
self.scheduler.create_workers(job)
# after get workers, all rpc server is ready
self.workers = self.scheduler.get_workers(self.role, timeout=1800)

logger.info(f"Start to create process group")
self.create_process_group(self.alloc_mode.train)

logger.info(f"Start to initialize engine")
with ThreadPoolExecutor(max_workers=len(self.workers)) as executor:
futures = [
executor.submit(
partial(
create_engine_with_retry,
self.scheduler.create_engine,
worker.id,
self.train_engine,
None,
self.ft_spec,
)
)
for worker in self.workers
]

wait_future_ordered(futures, exit_on_exception=True)

logger.info(f"Start to get rank info from engine")
self.engine_dp_ranks = rpc_call(
self.scheduler, self.workers, "data_parallel_rank"
)
logger.info(f"Initialize train engines succeeded!")

def destroy(self):
self.scheduler.delete_workers()

def train(self, mode: bool = True):
self.custom_function_call("train", mode)

def upload_weights(self, meta: WeightUpdateMeta):
self.custom_function_call("upload_weights", meta)

def get_param_specs(
self, weight_chunked_mem_mb: int = 1024
) -> List[List[ParamSpec]]:
        # Each worker returns a List[List[ParamSpec]]; flatten one level so the
        # result is a single List[List[ParamSpec]] across all workers.
        ret: List[List[List[ParamSpec]]] = self.custom_function_call(
            "get_param_specs", weight_chunked_mem_mb
        )
        flattened = [inner for outer in ret for inner in outer]
        return flattened

def set_version(self, version: int):
return self.custom_function_call("set_version", version)

def get_version(self) -> List[int]:
return self.custom_function_call("get_version")

def save(self, meta: SaveLoadMeta):
self.custom_function_call("save", meta)

def load(self, meta: SaveLoadMeta):
self.custom_function_call("load", meta)

def step_lr_scheduler(self):
self.custom_function_call("step_lr_scheduler")

def custom_function_call(self, method: str, *args, **kwargs):
return rpc_call(self.scheduler, self.workers, method, None, args, kwargs)

def _align_batches_with_dp(
self, input_: DistributedBatch, rebalance=True
) -> List[DistributedBatch]:
        if rebalance:
            # chunk_by_ffd balances chunks across DP ranks (first-fit-decreasing).
            inputs = input_.chunk_by_ffd(self.group_size, self.alloc_mode.train.dp_size)
        else:
            inputs = input_.chunk(self.alloc_mode.train.dp_size)

        # Reorder chunks so each worker receives the batch for its own DP rank.
        batches = []
        for dp_rank in self.engine_dp_ranks:
            batches.append(inputs[dp_rank])

        return batches

def train_batch(
self,
input_: DistributedBatch,
loss_fn: Callable[[torch.Tensor, Dict[str, Any]], torch.Tensor],
loss_weight_fn: Callable[[Dict[str, Any]], torch.Tensor],
) -> List[Dict[str, float]]:

batches = self._align_batches_with_dp(input_, True)
train_stats = rpc_call(
self.scheduler,
self.workers,
"train_batch",
batches,
loss_fn,
loss_weight_fn,
)

return train_stats

def eval_batch(
self,
input_: DistributedBatch,
loss_fn: Callable[[torch.Tensor, Dict[str, Any]], torch.Tensor],
loss_weight_fn: Callable[[Dict[str, Any]], torch.Tensor],
) -> List[torch.Tensor]:

batches = self._align_batches_with_dp(input_, True)
eval_stats = rpc_call(
self.scheduler, self.workers, "eval_batch", batches, loss_fn, loss_weight_fn
)

return eval_stats

def forward(
self,
input_: DistributedBatch,
output_seqlens: List[int] | None = None,
post_hook: Callable[[torch.Tensor, Dict[str, Any]], Any] | None = None,
aggregate_fn: Callable[[List[Any]], Any] = torch.cat,
) -> List[Any]:
batches = self._align_batches_with_dp(input_, False)
forward_stats = rpc_call(
self.scheduler,
self.workers,
"forward",
batches,
output_seqlens,
post_hook,
aggregate_fn,
)

return forward_stats
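Based only on the methods above, an end-to-end driver for this controller might look like the following sketch; engine, config, scheduler, ft_spec, batch, and the loss callables are placeholders, and the allocation-mode string is left as an explicit stand-in for whatever AllocationMode.from_str accepts:

    from areal.api.scheduler_api import ScheduleStrategy

    controller = DistributedTrainController(engine, config, scheduler)
    controller.initialize(
        alloc_mode_str="...",  # parsed by AllocationMode.from_str
        ft_spec=ft_spec,
        schedule_strategy=ScheduleStrategy(type="colocation", uid="job-0"),
        group_size=1,
    )
    stats = controller.train_batch(batch, loss_fn, loss_weight_fn)  # one dict per engine
    controller.step_lr_scheduler()
    controller.destroy()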