Skip to content

Commit 28af941

Browse files
author
daihao
committed
single controller: add train controller
1 parent 6138e3a commit 28af941

File tree

5 files changed

+248
-21
lines changed

5 files changed

+248
-21
lines changed

areal/api/engine_api.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,8 @@ class Scheduling:
2525
cpu: int
2626
gpu: int
2727
mem: int
28+
port_count: int
29+
cmd: str | None = None
2830
nodelist: str | None = None
2931
exclude: str | None = None
3032
partition: str | None = None
@@ -138,7 +140,7 @@ def parallelism_group(self) -> dist.ProcessGroup:
138140
"""
139141
raise NotImplementedError()
140142

141-
def get_scheduling_config(self) -> Scheduling:
143+
def get_scheduling_config(self) -> List[Scheduling]:
142144
"""Get the scheduling configuration for the engine.
143145
144146
This includes configuration such as container image, CPU/GPU/memory size.

areal/api/scheduler_api.py

Lines changed: 12 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -1,47 +1,40 @@
11
import abc
22
from dataclasses import dataclass, field
3-
from typing import Dict, List
3+
from typing import List, Literal
4+
5+
from areal.api.engine_api import Scheduling
46

57

68
@dataclass
79
class Worker:
810
id: str
911
ip: str
10-
ports: List[str] = field(default_factory=list)
11-
12-
13-
@dataclass
14-
class ContainerSpec:
15-
cpu: int = 0
16-
gpu: int = 0
17-
mem: int = 0
18-
container_image: str = ""
19-
cmd: str = ""
20-
env_vars: Dict[str, str] = field(default_factory=dict)
21-
port_count: int = 2
12+
serve_port: str
13+
extra_ports: List[str] = field(default_factory=list)
2214

2315

2416
@dataclass
2517
class ScheduleStrategy:
26-
type: str = ""
18+
type: Literal["colocation", "separation", ""] = ""
2719
uid: str = ""
2820

2921

3022
@dataclass
31-
class SchedulingConfig:
23+
class Job:
3224
replicas: int = 0
33-
specs: List[ContainerSpec] = field(default_factory=list)
25+
tasks: List[Scheduling] = field(default_factory=list)
3426
schedule_strategy: ScheduleStrategy | None = None
3527
role: str = ""
3628

3729

3830
class Scheduler(abc.ABC):
39-
def create_workers(self, worker_key, scheduler_config, *args, **kwargs) -> str:
31+
def create_workers(self, job: Job, *args, **kwargs):
4032
"""
41-
Start workers, return job id
33+
Start workers
4234
"""
35+
raise NotImplementedError()
4336

44-
def get_workers(self, worker_key, timeout=None) -> List[Worker]:
37+
def get_workers(self, role: str, timeout=None) -> List[Worker]:
4538
"""
4639
Wait and return worker list, including scheduling results such as ip and engine ports
4740
(worker id, ip, ports)
Lines changed: 144 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,144 @@
1+
from concurrent.futures import ThreadPoolExecutor
2+
from functools import partial
3+
from typing import Any, Callable, Dict, List
4+
5+
import torch
6+
7+
from areal.api.alloc_mode import ParallelStrategy
8+
from areal.api.cli_args import TrainEngineConfig
9+
from areal.api.controller_api import DistributedBatch, TrainController
10+
from areal.api.engine_api import TrainEngine
11+
from areal.api.io_struct import (
12+
AllocationMode,
13+
FinetuneSpec,
14+
ParamSpec,
15+
SaveLoadMeta,
16+
WeightUpdateMeta,
17+
)
18+
from areal.api.scheduler_api import Job, Scheduler, ScheduleStrategy, Worker
19+
from areal.controller.utils import create_engine_with_retry, rpc_call
20+
from areal.utils import logging
21+
22+
logger = logging.getLogger("DistributedTrainController")
23+
24+
25+
class DistributedTrainController(TrainController):
    """Train controller that creates workers through a scheduler and sets up
    a train engine on each of them.

    Only ``initialize``, ``create_process_group`` and ``destroy`` do real work
    in this class; the remaining methods are placeholders whose bodies are
    ``pass`` (they return ``None`` regardless of their annotations).
    """

    def __init__(
        self, train_engine: TrainEngine, config: TrainEngineConfig, scheduler: Scheduler
    ):
        super().__init__(train_engine, config, scheduler)

        # Role name used when creating the job and when querying its workers.
        self.role: str = "train"
        self.group_size = 0
        # The attributes below are only declared here; they are assigned in
        # initialize().
        self.alloc_mode: AllocationMode
        self.uid: str
        self.workers: List[Worker]

    # todo: delete this method
    def create_process_group(self, parallel_strategy: ParallelStrategy | None = None):
        """Call ``create_process_group`` on every worker via ``rpc_call``.

        Reads ``self.workers``, which is assigned by ``initialize``.
        """
        assert self.workers is not None
        rpc_call(
            self.scheduler, self.workers, "create_process_group", parallel_strategy
        )

    def initialize(
        self,
        alloc_mode_str: str,
        ft_spec: FinetuneSpec,
        schedule_strategy: ScheduleStrategy,
    ):
        """Initialize environments for distributed training and load models.

        Parses the allocation mode, submits a ``Job`` with one replica per
        train rank (``alloc_mode.train.world_size``), waits for the workers and
        then creates an engine on every worker concurrently.

        Args:
            alloc_mode_str: Allocation mode string parsed with
                ``AllocationMode.from_str``.
            ft_spec: Finetune spec forwarded to ``scheduler.create_engine``.
            schedule_strategy: Strategy stored on the submitted ``Job``.

        Raises:
            RuntimeError: If creating the engine fails on any worker; the
                original exception is attached as ``__cause__``.
        """
        self.alloc_mode = AllocationMode.from_str(alloc_mode_str)
        self.ft_spec = ft_spec

        job = Job(
            replicas=self.alloc_mode.train.world_size,
            tasks=self.train_engine.get_scheduling_config(),
            schedule_strategy=schedule_strategy,
            role=self.role,
        )

        logger.info(f"Start to create job: {job}")

        self.uid = self.scheduler.create_workers(job)

        # timeout=1800 is passed through to the scheduler — presumably
        # seconds; TODO confirm against the Scheduler implementation.
        self.workers = self.scheduler.get_workers(self.role, timeout=1800)

        with ThreadPoolExecutor(max_workers=len(self.workers)) as executor:
            # Bind the engine arguments to create_engine itself and hand the
            # retry helper a zero-argument callable. Passing them positionally
            # to create_engine_with_retry would bind worker.id / train_engine
            # to its max_retries / retry_delay parameters.
            futures = [
                executor.submit(
                    create_engine_with_retry,
                    partial(
                        self.scheduler.create_engine,
                        worker.id,
                        self.train_engine,
                        None,
                        self.ft_spec,
                    ),
                )
                for worker in self.workers
            ]
            try:
                for worker_index, future in enumerate(futures):
                    rank_info = future.result()
                    # NOTE(review): rank_info is not assigned anywhere in this
                    # class — presumably provided by TrainController; confirm.
                    self.rank_info[worker_index] = rank_info
                    logger.info(f"worker_index: {worker_index}, rank_info: {rank_info}")
            except Exception as e:
                raise RuntimeError(
                    f"Failed to initialize worker_index: {worker_index}, error: {e}"
                ) from e

    def destroy(self):
        """Ask the scheduler to delete the workers."""
        self.scheduler.delete_workers()

    def train(self, mode: bool = True):
        """Placeholder: not implemented yet."""
        pass

    def upload_weights(self, meta: WeightUpdateMeta):
        """Placeholder: not implemented yet."""
        pass

    def get_param_specs(
        self, weight_chunked_mem_mb: int = 1024
    ) -> List[List[ParamSpec]]:
        """Placeholder: not implemented yet (currently returns ``None``)."""
        pass

    def set_version(self, version: int):
        """Placeholder: not implemented yet."""
        pass

    def get_version(self) -> int:
        """Placeholder: not implemented yet (currently returns ``None``)."""
        pass

    def save(self, meta: SaveLoadMeta):
        """Placeholder: not implemented yet."""
        pass

    def load(self, meta: SaveLoadMeta):
        """Placeholder: not implemented yet."""
        pass

    def step_lr_scheduler(self):
        """Placeholder: not implemented yet."""
        pass

    def train_batch(
        self,
        input_: DistributedBatch,
        loss_fn: Callable[[torch.Tensor, Dict[str, Any]], torch.Tensor],
        loss_weight_fn: Callable[[Dict[str, Any]], torch.Tensor],
    ) -> Dict[str, float]:
        """Placeholder: not implemented yet (currently returns ``None``)."""
        pass

    def eval_batch(
        self,
        input_: DistributedBatch,
        loss_fn: Callable[[torch.Tensor, Dict[str, Any]], torch.Tensor],
        loss_weight_fn: Callable[[Dict[str, Any]], torch.Tensor],
    ) -> torch.Tensor | None:
        """Placeholder: not implemented yet."""
        pass

    def forward(
        self,
        input_: DistributedBatch,
        output_seqlens: List[int] | None = None,
        post_hook: Callable[[torch.Tensor, Dict[str, Any]], Any] | None = None,
        aggregate_fn: Callable[[List[Any]], Any] = torch.cat,
    ) -> Any | None:
        """Placeholder: not implemented yet."""
        pass

areal/controller/utils.py

Lines changed: 61 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,61 @@
1+
import time
2+
from concurrent.futures import ThreadPoolExecutor
3+
from typing import Any, List
4+
5+
from requests.exceptions import ConnectionError
6+
7+
from areal.api.scheduler_api import Scheduler, Worker
8+
from areal.utils import logging
9+
from areal.utils.http import wait_future_ordered
10+
11+
logger = logging.getLogger("ControllerUtil")
12+
13+
14+
def create_engine_with_retry(
    create_engine_func, *args, max_retries=60, retry_delay=10, **kwargs
):
    """Call ``create_engine_func`` until it stops raising ``ConnectionError``.

    ``max_retries`` and ``retry_delay`` are keyword-only so that positional
    arguments after ``create_engine_func`` are always forwarded to it instead
    of being consumed as retry settings.

    Args:
        create_engine_func: Callable that creates the engine.
        *args: Positional arguments forwarded to ``create_engine_func``.
        max_retries: Maximum number of attempts.
        retry_delay: Value passed to ``time.sleep`` after a failed attempt.
        **kwargs: Keyword arguments forwarded to ``create_engine_func``.

    Returns:
        Whatever ``create_engine_func`` returns on the first successful call.

    Raises:
        RuntimeError: If every attempt raised ``ConnectionError``; the last
            one is attached as ``__cause__``.
        Exception: Any other exception from ``create_engine_func`` is logged
            and re-raised immediately without retrying.
    """
    logger.info(
        f"Create engine with retry: {max_retries}, {retry_delay}, {args}, {kwargs}"
    )
    last_error = None
    for _ in range(max_retries):
        try:
            return create_engine_func(*args, **kwargs)
        except ConnectionError as e:
            last_error = e
            logger.info(
                f"Worker is not ready, exception: {e}, retrying in {retry_delay} seconds..."
            )
            time.sleep(retry_delay)
        except Exception as e:
            logger.error(f"Connection failed: {e}. unknown exception")
            # Bare raise keeps the original traceback intact.
            raise

    raise RuntimeError(
        "Failed to connect to remote service after maximum retries."
    ) from last_error
35+
36+
37+
def rpc_call(
    scheduler: Scheduler, workers: List[Worker], method: str, *args, **kwargs
) -> List[Any]:
    """Utility: invoke an engine method on all workers concurrently.

    Args:
        scheduler: Scheduler object; must provide
            ``call_engine(worker_id, method, *args, **kwargs)``.
        workers: List of workers; each worker must have an ``id`` attribute.
        method: Name of the method to call.
        *args: Positional arguments passed through to ``call_engine``.
        **kwargs: Keyword arguments passed through to ``call_engine``.

    Returns:
        The list produced by ``wait_future_ordered`` for the submitted calls,
        or an empty list when ``workers`` is empty.

    Raises:
        RuntimeError: If waiting on the calls raises; the original exception
            is attached as ``__cause__``.
    """
    logger.info(f"Start to rpc call, method: {method}, args: {args}, kwargs: {kwargs}")
    if not workers:
        # ThreadPoolExecutor raises ValueError for max_workers=0, and there is
        # nothing to call anyway.
        return []

    with ThreadPoolExecutor(max_workers=len(workers)) as executor:
        futures = [
            executor.submit(scheduler.call_engine, worker.id, method, *args, **kwargs)
            for worker in workers
        ]
        try:
            results = wait_future_ordered(futures, exit_on_exception=True)
        except Exception as e:
            raise RuntimeError(f"{method} failed, error: {e}") from e

    return results

areal/utils/http.py

Lines changed: 28 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,10 @@
11
import asyncio
2+
import os
3+
import signal
4+
import traceback
5+
from concurrent.futures import Future, as_completed
26
from http import HTTPStatus
3-
from typing import Any, Dict, Optional
7+
from typing import Any, Dict, List, Optional
48

59
import aiohttp
610

@@ -96,3 +100,26 @@ def response_ok(http_code: int) -> bool:
96100

97101
def response_retryable(http_code: int) -> bool:
98102
return http_code == HTTPStatus.REQUEST_TIMEOUT
103+
104+
105+
def wait_future_ordered(
    futures: List[Future], exit_on_exception: bool = False
) -> List[Any]:
    """Wait for all futures and return their results in submission order.

    Futures are consumed as they complete, but each result is stored at the
    position its future occupies in ``futures``.

    Args:
        futures: Futures to wait on.
        exit_on_exception: When true, a failing future is logged and SIGTERM
            is sent to the current process; otherwise the exception is
            re-raised to the caller.

    Returns:
        A list with one entry per future; entries of futures whose result
        was not stored remain ``None``.
    """
    position_of = dict(zip(futures, range(len(futures))))
    ordered = [None for _ in futures]
    for finished in as_completed(futures):
        slot = position_of[finished]
        try:
            outcome = finished.result()
        except Exception as e:
            logger.warning(f"Exception caught when waiting for future: {e}")
            logger.warning(traceback.format_exc())
            if not exit_on_exception:
                raise e
            logger.info("Exiting due to exception in future.")
            os.kill(os.getpid(), signal.SIGTERM)
        else:
            ordered[slot] = outcome
    return ordered

0 commit comments

Comments
 (0)