diff --git a/areal/api/scheduler_api.py b/areal/api/scheduler_api.py index f7e9fb941..b381bbef3 100644 --- a/areal/api/scheduler_api.py +++ b/areal/api/scheduler_api.py @@ -36,10 +36,11 @@ class SchedulingConfig: class Scheduler(abc.ABC): - def create_workers(self, worker_key, scheduler_config, *args, **kwargs) -> str: + def create_workers(self, worker_key, scheduler_config, *args, **kwargs) -> None: """ Start workers, return job id """ + raise NotImplementedError() def get_workers(self, worker_key, timeout=None) -> List[Worker]: """ diff --git a/areal/api/workflow_api.py b/areal/api/workflow_api.py index 22d4facde..bd4fbe047 100644 --- a/areal/api/workflow_api.py +++ b/areal/api/workflow_api.py @@ -330,14 +330,11 @@ async def _rollout_thread_async(self): try: while not self.exiting.is_set(): # Check capacity - capacity = self.get_capacity() + # capacity = self.get_capacity() + # self.logger.info(f"Current rollout capacity: {capacity}") # Create new rollout task self.lock.acquire() - while ( - capacity > 0 - and not self.paused.is_set() - and self.input_queue.qsize() > 0 - ): + while not self.paused.is_set() and self.input_queue.qsize() > 0: x = self.input_queue.get_nowait() x: _RolloutTaskInput self.logger.debug(f"Get data from puller: {x.data}") @@ -357,7 +354,7 @@ async def _rollout_thread_async(self): f"running: {self.rollout_stat.running}, " f"accepted: {self.rollout_stat.accepted}." 
) - capacity -= 1 + # capacity -= 1 rid += 1 tasks = [x.task for x in rollout_tasks.values()] self.lock.release() diff --git a/areal/engine/base_hf_engine.py b/areal/engine/base_hf_engine.py index f6095a18a..0f25bce47 100644 --- a/areal/engine/base_hf_engine.py +++ b/areal/engine/base_hf_engine.py @@ -73,7 +73,7 @@ def __init__(self, config: TrainEngineConfig): ) self.is_vision_model = is_valid_vision_model(self.model_config.model_type) - self.world_size = int(os.environ["WORLD_SIZE"]) + self.world_size: int def set_version(self, version: int): self._version = version diff --git a/areal/engine/fsdp_engine.py b/areal/engine/fsdp_engine.py index e1a1a6b7f..45eed3159 100644 --- a/areal/engine/fsdp_engine.py +++ b/areal/engine/fsdp_engine.py @@ -121,9 +121,17 @@ def create_process_group(self, parallel_strategy: ParallelStrategy | None = None self.dp_head = int(self.world_mesh["sp_tp"].mesh[0].item()) self.dp_rank = dist.get_rank(self.dp_group) + self.world_size = int(os.environ["WORLD_SIZE"]) + self.logger.info(f"Data parallel head {self.dp_head} and rank {self.dp_rank}") - def initialize(self, addr: str | None, ft_spec: FinetuneSpec | None): + def initialize( + self, + addr: str | None, + ft_spec: FinetuneSpec | None, + parallel_strategy: ParallelStrategy | None = None, + ): + self.create_process_group(parallel_strategy) # Initialize distributed enviroments and load model. assert addr is None, "FSDPEngine does not support remote initialization." assert ft_spec is not None, "FSDPEngine requires FinetuneSpec to initialize." 
diff --git a/areal/engine/ppo/actor.py b/areal/engine/ppo/actor.py index 8d6329732..d2122a515 100644 --- a/areal/engine/ppo/actor.py +++ b/areal/engine/ppo/actor.py @@ -67,7 +67,7 @@ def calc_logprobs(logits, input_data): aggregate_fn=lambda xs: torch.cat(xs, dim=-1), ) - def compute_advantages(self, data: Dict[str, Any]) -> None: + def compute_advantages(self, data: Dict[str, Any]) -> Dict[str, Any]: bs = data["input_ids"].shape[0] max_seqlen = data["input_ids"].shape[1] batch_indices = torch.arange( @@ -162,6 +162,8 @@ def compute_advantages(self, data: Dict[str, Any]) -> None: # because we have rolled old_logp by -1 data["logprobs"] = old_logp + return data + def ppo_update(self, data: Dict[str, Any]) -> List[Dict[str, float]]: if self.dynamic_sampling and len(data["rewards"]) % self.group_size == 0: @@ -286,8 +288,8 @@ def compute_logp(self, *args, **kwargs) -> torch.Tensor | None: return self.actor.compute_logp(*args, **kwargs) @torch.no_grad() - def compute_advantages(self, *args, **kwargs) -> None: - self.actor.compute_advantages(*args, **kwargs) + def compute_advantages(self, *args, **kwargs): + return self.actor.compute_advantages(*args, **kwargs) def ppo_update(self, *args, **kwargs) -> List[Dict[str, float]]: return self.actor.ppo_update(*args, **kwargs) diff --git a/areal/engine/sglang_remote.py b/areal/engine/sglang_remote.py index 9104e78e4..4b4fe8524 100644 --- a/areal/engine/sglang_remote.py +++ b/areal/engine/sglang_remote.py @@ -45,11 +45,8 @@ def __init__(self, config: InferenceEngineConfig): self.distributed_weight_update_initialized = False self._version = 0 - self.lock = Lock() - self.workflow_executor = WorkflowExecutor( - config=config, - inference_engine=self, - ) + self.lock: Lock + self.workflow_executor: WorkflowExecutor def _wait_for_server(self, address): base_url = f"http://{address}" @@ -74,6 +71,11 @@ def initialize( addr: str | List[str] | None = None, train_data_parallel_size: int | None = None, ): + self.lock = Lock() + 
self.workflow_executor = WorkflowExecutor( + config=self.config, + inference_engine=self, + ) if engine_id is None: if dist.is_initialized(): engine_id = str(dist.get_rank()) diff --git a/areal/reward/gsm8k_reward.py b/areal/reward/gsm8k_reward.py new file mode 100644 index 000000000..5a32cecd0 --- /dev/null +++ b/areal/reward/gsm8k_reward.py @@ -0,0 +1,5 @@ +from areal.reward.math_parser import process_results + + +def gsm8k_reward_fn(prompt, completions, prompt_ids, completion_ids, answer, **kwargs): + return int(process_results(completions, answer)[0]) diff --git a/areal/scheduler/__init__.py b/areal/scheduler/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/areal/scheduler/local.py b/areal/scheduler/local.py new file mode 100644 index 000000000..473089b8e --- /dev/null +++ b/areal/scheduler/local.py @@ -0,0 +1,408 @@ +import getpass +import os +import re +import signal as signal_module +import subprocess +import time +import uuid +from collections import defaultdict +from typing import Dict, List, Optional, Tuple, Union + +import psutil + +from areal.api.alloc_mode import AllocationMode, AllocationType +from areal.api.cli_args import ( + ClusterSpecConfig, + LauncherConfig, + RecoverConfig, + SGLangConfig, + to_structured_cfg, +) +from areal.api.scheduler_api import Scheduler, Worker +from areal.platforms import current_platform +from areal.scheduler.rpc.rpc_client import RPCClient +from areal.scheduler.rpc.rpc_server import build_rpc_server_start_command +from areal.utils import logging, name_resolve, names +from areal.utils.launcher import JobException, JobInfo, JobState, get_env_vars +from areal.utils.network import find_free_ports, gethostip +from areal.utils.recover import check_if_recover + +logger = logging.getLogger("LocalScheduler") +JOB_STATE_TO_PROCESS_STATUS = { + JobState.NOT_FOUND: [], + JobState.PENDING: [psutil.STATUS_PARKED], + JobState.RUNNING: [ + psutil.STATUS_RUNNING, + psutil.STATUS_SLEEPING, + 
psutil.STATUS_DISK_SLEEP, + psutil.STATUS_TRACING_STOP, + psutil.STATUS_WAKING, + psutil.STATUS_WAITING, + psutil.STATUS_LOCKED, + psutil.STATUS_IDLE, + ], + JobState.COMPLETED: [ + psutil.STATUS_DEAD, + psutil.STATUS_STOPPED, + psutil.STATUS_ZOMBIE, + ], + JobState.FAILED: [], + JobState.CANCELLED: [], +} +RECOVER_TIME_INTERVAL = 10 # seconds + +PROCESS_STATUS_TO_JOB_STATE = {} +for job_state, process_statuses in JOB_STATE_TO_PROCESS_STATUS.items(): + for process_status in process_statuses: + PROCESS_STATUS_TO_JOB_STATE[process_status] = job_state + + +def terminate_process_and_children(pid: int, signal: Optional[Union[str, int]] = None): + if signal is None: + signal = signal_module.SIGKILL + if isinstance(signal, str): + signal = getattr(signal_module, signal) + try: + parent = psutil.Process(pid) + children = parent.children(recursive=True) + for child in children: + # propagate the requested signal so children are not hard-killed + # when the caller asked for a graceful signal (e.g. SIGTERM) + terminate_process_and_children(child.pid, signal=signal) + parent.send_signal(signal) + except psutil.NoSuchProcess: + pass + + +class LocalLauncher: + def __init__(self, experiment_name: str, trial_name: str, fileroot: str): + self.experiment_name = experiment_name + self.trial_name = trial_name + self.fileroot = fileroot + + self._jobs: Dict[str, subprocess.Popen] = {} + self._job_counter: Dict[str, int] = defaultdict(int) + self._job_states = {} + + self._gpu_counter = 0 + self._gpu_devices: List[str] = os.environ.get( + current_platform.device_control_env_var, + ",".join(map(str, range(current_platform.device_count()))), + ).split(",") + if len(self._gpu_devices) < 1: + raise RuntimeError( + f"Local mode can only run when there is at least one GPU. " + f"{current_platform.device_control_env_var} is currently" + f" set to: `{os.environ.get(current_platform.device_control_env_var, '')}`." 
+ ) + + @property + def run_name(self): + return f"{self.experiment_name}_{self.trial_name}" + + def log_path_of(self, job_name: str) -> str: + log_path = f"{self.fileroot}/logs/{getpass.getuser()}/{self.experiment_name}/{self.trial_name}" + os.makedirs(log_path, exist_ok=True) + return os.path.join(log_path, f"{job_name}.log") + + def __del__(self): + self.wait() + + def submit_array( + self, + job_name: str, + cmd: str | List[str], + count: int = 1, + gpu: int = 0, + env_vars: Optional[Dict] = None, + ): + if env_vars is None: + env_vars = {} + if not isinstance(cmd, list): + cmd = [cmd] * count + offset = self._job_counter[job_name] + for i in range(count): + if gpu > 0: + # Allocate GPUs in a round-robin manner + visible_devices = [] + for _ in range(gpu): + available_device_id = self._gpu_counter % len(self._gpu_devices) + self._gpu_counter += 1 + visible_devices.append(available_device_id) + env_vars[current_platform.device_control_env_var] = ",".join( + str(self._gpu_devices[j]) for j in visible_devices + ) + c = ( + " ".join(str(k) + "=" + str(v) for k, v in env_vars.items()) + + " stdbuf -oL " + + cmd[i] + ) + c = f"{c} 2>&1 | tee -a {self.log_path_of(job_name)}" + logger.info("Starting local process with command: %s", c) + process = subprocess.Popen(c, shell=isinstance(c, str)) + self._jobs[f"{job_name}/{offset + i}"] = process + self._job_counter[job_name] += 1 + + def submit( + self, + job_name: str, + cmd: str | List[str], + gpu: int = 0, + env_vars: Optional[Dict] = None, + ): + self.submit_array(job_name=job_name, cmd=cmd, gpu=gpu, env_vars=env_vars) + + def stop(self, job_name, signal=None): + assert any(k.startswith(job_name) for k in self._jobs) + keys = [k for k, p in self._jobs.items() if k.startswith(job_name)] + procs = [p for k, p in self._jobs.items() if k.startswith(job_name)] + logger.info( + f"Stopping local process with signal {signal if signal else 'SIGKILL'}, " + f"pid: {[p.pid for p in procs]}" + ) + for p in procs: + 
terminate_process_and_children(p.pid, signal=signal) + for p in procs: + p.wait() + for k, p in zip(keys, procs): + self._jobs.pop(k) + del p + + def stop_all(self, signal=None): + # signal argument is ignored in local stop_all + for name in self._job_counter: + self.stop(name, signal=signal) + + def find(self, job_name): + if job_name in self._jobs: + return JobInfo(name=job_name, state=JobState.RUNNING, host="localhost") + else: + return JobInfo(name=job_name, state=JobState.NOT_FOUND) + + def find_all(self, job_name_regex=".*"): + rs = [] + for name in self._jobs: + if re.fullmatch(job_name_regex, name): + rs.append(self.find(name)) + return rs + + def wait( + self, + timeout=None, + check_status: Tuple[JobState, ...] = ( + JobState.CANCELLED, + JobState.FAILED, + JobState.NOT_FOUND, + ), + remove_status: Tuple[JobState, ...] = (JobState.COMPLETED,), + update=False, + ): + deadline = None if timeout is None else time.time() + timeout + logger.info( + "Waiting for %d local running processes, pids: %s", + len(self._jobs), + " ".join(str(job.pid) for job in self._jobs.values()), + ) + left = set(self._jobs.keys()) + num_jobs_left = len(left) + + while len(left) > 0: + to_remove = [] + if len(left) < num_jobs_left: + num_jobs_left = len(left) + logger.info(f"Waiting for {num_jobs_left} jobs.") + if deadline is not None and time.time() > deadline: + raise TimeoutError( + f"Timeout waiting for {self.run_name}: {', '.join(sorted(left))}" + ) + # update job states + for job_name in list(left): + job = self._jobs[job_name] + pid = job.pid + try: + process = psutil.Process(pid) + self._job_states[job_name] = PROCESS_STATUS_TO_JOB_STATE.get( + process.status(), JobState.NOT_FOUND + ) + except psutil.NoSuchProcess: + self._job_states[job_name] = JobState.NOT_FOUND + + for job_name in list(left): + state = self._job_states[job_name] + if state in check_status: + raise JobException( + run_name=self.run_name, + worker_type=job_name.split("/")[0], + host="local", + 
reason=state, + ) + if state in remove_status: + logger.info(f"Job {job_name} is {state}.(Removed)") + left.remove(job_name) + to_remove.append(job_name) + + if update: + for k in to_remove: + self._jobs.pop(k) + worker_type = k.split("/")[0] + assert worker_type in self._job_counter + self._job_counter[worker_type] -= 1 + if self._job_counter[worker_type] <= 0: + self._job_counter.pop(worker_type) + + time.sleep(2) + + +class LocalScheduler(Scheduler): + def __init__(self, config): + self.procs = [] # Store subprocess objects + self.engine_workers: Dict[str, List[str]] = defaultdict( + list + ) # role -> [worker_id] + self.rpc_client = RPCClient() + self.launcher = LocalLauncher( + config.experiment_name, config.trial_name, config.cluster.fileroot + ) + + def create_workers(self, worker_role, config, *args, **kwargs) -> None: + config.launcher = to_structured_cfg(config.launcher, LauncherConfig) + config.recover = to_structured_cfg(config.recover, RecoverConfig) + config.cluster = to_structured_cfg(config.cluster, ClusterSpecConfig) + is_recover_run = check_if_recover(config.recover, run_id=0) + + name_resolve.reconfigure(config.cluster.name_resolve) + name_resolve.clear_subtree( + names.trial_root( + experiment_name=config.experiment_name, trial_name=config.trial_name + ) + ) + alloc_mode = AllocationMode.from_str(config.allocation_mode) + logger.info( + f"experiment_name={config.experiment_name}, " + f"trial_name={config.trial_name}, fileroot={config.cluster.fileroot}, " + f"is_recover_run={is_recover_run}" + ) + + server_cmd = [] + server_addrs = [] + if worker_role == "rollout": + if alloc_mode.gen_backend == "sglang": + # launch sglang servers + base_seed = config.sglang.random_seed + config.sglang = to_structured_cfg(config.sglang, SGLangConfig) + # each sglang need 2 ports + ports = find_free_ports( + alloc_mode.gen.dp_size * 2, port_range=(10000, 50000) + ) + host_ip = gethostip() + host = "localhost" if not config.sglang.enable_metrics else host_ip + for 
i in range(alloc_mode.gen.dp_size): + config.sglang.random_seed = base_seed + i + cmd = SGLangConfig.build_cmd( + config.sglang, + host=host, + tp_size=alloc_mode.gen.tp_size, + base_gpu_id=0, + port=ports[i * 2], + dist_init_addr=f"localhost:{ports[i*2+1]}", + ) + server_cmd.append(cmd) + server_addrs.append(f"{host}:{ports[i * 2]}") + + # Launch inference servers. + self.launcher.submit_array( + job_name="llm_server", + cmd=server_cmd, + count=alloc_mode.gen.dp_size, + gpu=alloc_mode.gen.pp_size * alloc_mode.gen.tp_size, + env_vars=get_env_vars( + config.cluster.cluster_name, + config.launcher.inference_server_env_vars, + ), + ) + logger.info( + f"LLM inference server launched at: AREAL_LLM_SERVER_ADDRS={','.join(server_addrs)}" + ) + + # create rpc server workers + worker_ports = find_free_ports( + alloc_mode.gen.world_size, port_range=(10000, 50000) + ) # each sglang need 2 ports + for i in range(alloc_mode.gen.world_size): + cmd = build_rpc_server_start_command(worker_ports[i]) + + self.launcher.submit( + job_name="rollout_worker", + cmd=cmd, + gpu=0, + env_vars=dict( + **get_env_vars( + config.cluster.cluster_name, + # config.launcher.worker_env_vars, + ), + AREAL_LLM_SERVER_ADDRS=server_addrs[ + i % alloc_mode.gen.dp_size + ], + AREAL_RECOVER_RUN=str(int(is_recover_run)), + ), + ) + + logger.info( + f"RPC server for rollout worker launched at port: {worker_ports[i]}" + ) + + worker_id = f"rollout_{i}_{uuid.uuid4().hex[:8]}" + self.rpc_client.register(worker_id, "localhost", worker_ports[i]) + self.engine_workers.setdefault(worker_role, []).append(worker_id) + + else: + raise NotImplementedError(f"Unsupported allocation mode: {alloc_mode}") + elif worker_role == "actor": + if alloc_mode.type_ == AllocationType.DECOUPLED_EVAL: + gpu = 0 + nprocs = 1 + else: + gpu = nprocs = alloc_mode.train.world_size + + worker_ports = find_free_ports(alloc_mode.gen.world_size, (10000, 50000)) + + self.launcher.submit( + job_name="trainer", + cmd=f"torchrun --nnodes 1 
--nproc-per-node {nprocs} " + f"--master-addr localhost --master-port {find_free_ports(1, (10000, 50000))[0]} " + f"-m areal.scheduler.rpc.rpc_server --rpc_ports {','.join(map(str, worker_ports))}", + gpu=gpu, + env_vars=dict( + **get_env_vars( + config.cluster.cluster_name, + config.launcher.trainer_env_vars, + ), + # AREAL_LLM_SERVER_ADDRS=",".join(server_addrs), # not need? + AREAL_RECOVER_RUN=str(int(is_recover_run)), + ), + ) + + for i in range(alloc_mode.gen.world_size): + worker_id = f"actor_{i}_{uuid.uuid4().hex[:8]}" + self.rpc_client.register(worker_id, "localhost", worker_ports[i]) + self.engine_workers.setdefault(worker_role, []).append(worker_id) + else: + raise ValueError(f"Unknown worker role: {worker_role}") + + def get_workers(self, worker_role, timeout: float = 60.0) -> List[Worker]: + workers = [] + for worker_id in self.engine_workers.get(worker_role, []): + ip, port = self.rpc_client.get_info(worker_id) + worker = Worker(id=worker_id, ip=ip, ports=[str(port)]) + workers.append(worker) + return workers + + def delete_workers(self): + raise NotImplementedError("LocalScheduler does not support delete_workers") + + # Other methods remain the same + def create_engine(self, worker_id, engine_obj, *args, **kwargs): + # launch engine rpc server on the worker + self.rpc_client.create_engine(worker_id, engine_obj, *args, **kwargs) + + def call_engine(self, worker_id, method, *args, **kwargs): + ret = self.rpc_client.call_engine(worker_id, method, 3, *args, **kwargs) + return ret diff --git a/areal/scheduler/rpc/rpc_client.py b/areal/scheduler/rpc/rpc_client.py index 28f4b8082..b25c6b120 100644 --- a/areal/scheduler/rpc/rpc_client.py +++ b/areal/scheduler/rpc/rpc_client.py @@ -6,7 +6,6 @@ import cloudpickle import requests -from areal.api.cli_args import InferenceEngineConfig, TrainEngineConfig from areal.api.engine_api import InferenceEngine, TrainEngine from areal.utils import logging from areal.utils.http import response_ok, response_retryable @@ 
-22,16 +21,20 @@ def register(self, worker_id: str, ip: str, port: int) -> None: self._addrs[worker_id] = (ip, port) logger.info(f"Registered worker {worker_id} at {ip}:{port}") + def get_info(self, worker_id: str) -> tuple[str, int]: + return self._addrs[worker_id] + def create_engine( self, worker_id: str, engine_obj: Union[InferenceEngine, TrainEngine], - init_config: Union[InferenceEngineConfig, TrainEngineConfig], + *args, + **kwargs, ) -> None: ip, port = self._addrs[worker_id] url = f"http://{ip}:{port}/create_engine" logger.info(f"send create_engine to {worker_id} ({ip}:{port})") - payload = (engine_obj, init_config) + payload = (engine_obj, args, kwargs) serialized_data = cloudpickle.dumps(payload) serialized_obj = gzip.compress(serialized_data) resp = requests.post(url, data=serialized_obj) @@ -48,7 +51,7 @@ def create_engine( ) def call_engine( - self, worker_id: str, method: str, max_retries: int = 3, *args, **kwargs + self, worker_id: str, method: str, max_retries: int, *args, **kwargs ) -> Any: """ call the rpc server with method name and args, retry on failure diff --git a/areal/scheduler/rpc/rpc_server.py b/areal/scheduler/rpc/rpc_server.py index b2bc3d612..4ac8ef0f7 100644 --- a/areal/scheduler/rpc/rpc_server.py +++ b/areal/scheduler/rpc/rpc_server.py @@ -3,20 +3,46 @@ import os import traceback from http import HTTPStatus -from http.server import BaseHTTPRequestHandler, ThreadingHTTPServer -from typing import AnyStr +from http.server import BaseHTTPRequestHandler, HTTPServer +from typing import Any, AnyStr, Dict, List import cloudpickle +import torch from tensordict import TensorDict from areal.api.controller_api import DistributedBatch +from areal.api.engine_api import InferenceEngine from areal.controller.batch import DistributedBatchMemory from areal.utils import logging logger = logging.getLogger("RPCServer") -def process_input_to_distributed_batch(*args, **kwargs): +def tensor_container_to_safe( + d: Dict[str, Any] | torch.Tensor | 
List[torch.Tensor], *args, **kwargs +): + """Apply `t.to(*args, **kwargs)` to all tensors in the dictionary. + Support nested dictionaries. + """ + new_dict = {} + if torch.is_tensor(d): + return d.to(*args, **kwargs) + elif isinstance(d, list): + return [tensor_container_to_safe(v, *args, **kwargs) for v in d] + elif isinstance(d, dict): + for key, value in d.items(): + if isinstance(value, dict) or isinstance(value, list): + new_dict[key] = tensor_container_to_safe(value, *args, **kwargs) + elif torch.is_tensor(value): + new_dict[key] = value.to(*args, **kwargs) + else: + new_dict[key] = value + return new_dict + else: + return d + + +def process_input_to_distributed_batch(to_device, *args, **kwargs): for i in range(len(args)): if isinstance(args[i], DistributedBatch): args = list(args) @@ -27,10 +53,14 @@ def process_input_to_distributed_batch(*args, **kwargs): if isinstance(kwargs[k], DistributedBatch): kwargs[k] = kwargs[k].get_data() + args = tuple(tensor_container_to_safe(list(args), to_device)) + kwargs = tensor_container_to_safe(kwargs, to_device) + return args, kwargs def process_output_to_distributed_batch(result): + result = tensor_container_to_safe(result, "cpu") if isinstance(result, dict): return DistributedBatchMemory.from_dict(result) elif isinstance(result, TensorDict): @@ -76,9 +106,9 @@ def do_POST(self): try: if self.path == "/create_engine": decompressed_data = gzip.decompress(data) - engine_obj, init_args = cloudpickle.loads(decompressed_data) + engine_obj, args, kwargs = cloudpickle.loads(decompressed_data) EngineRPCServer.engine = engine_obj - result = EngineRPCServer.engine.initialize(init_args) + result = EngineRPCServer.engine.initialize(*args, **kwargs) logger.info(f"Engine created and initialized on RPC server: {result}") self.send_response(HTTPStatus.OK) self.end_headers() @@ -93,8 +123,14 @@ def do_POST(self): action, args, kwargs = cloudpickle.loads(data) method = getattr(EngineRPCServer.engine, action) # NOTE: DO NOT print args 
here, args may be a very huge tensor - logger.info(f"RPC server calling engine method: {action}") - args, kwargs = process_input_to_distributed_batch(*args, **kwargs) + if isinstance(EngineRPCServer.engine, InferenceEngine): + device = "cpu" + else: # actor + device = EngineRPCServer.engine.device + + args, kwargs = process_input_to_distributed_batch( + device, *args, **kwargs + ) result = method(*args, **kwargs) result = process_output_to_distributed_batch(result) self.send_response(HTTPStatus.OK) @@ -113,36 +149,36 @@ def do_POST(self): def start_rpc_server(port): - server = ThreadingHTTPServer(("0.0.0.0", port), EngineRPCServer) + # NOTE: We must use HTTPServer rather than ThreadingHTTPServer here, since the rank and device info + # of pytorch is thread level, if use ThreadingHTTPServer, the device set by create_engine thread + # will not be seen by call_engine thread. + # server = ThreadingHTTPServer(("0.0.0.0", port), EngineRPCServer) + server = HTTPServer(("0.0.0.0", port), EngineRPCServer) server.serve_forever() -def get_serve_port(args): - port = args.port - port_str = os.environ.get("PORT_LIST", "").strip() - - # Check if PORT_LIST is set - if port_str: - # Split by comma and strip whitespace - ports = [p.strip() for p in port_str.split(",")] - # Use the first valid port from the list - if ports and ports[0]: - try: - return int(ports[0]) - except ValueError: - logger.warning( - f"Invalid port '{ports[0]}' in PORT_LIST. Falling back to --port argument." 
- ) - return port +def get_server_ports(ports_str: str) -> int: + ports = [p.strip() for p in ports_str.split(",")] + world_size = int(os.environ.get("WORLD_SIZE", "1")) + rank = int(os.environ.get("RANK", "0")) + if len(ports) < world_size: + raise ValueError( + f"Not enough ports for the world size {world_size}, got {ports_str}" + ) + return int(ports[rank]) + + +def build_rpc_server_start_command(port): + return f"python3 -m areal.scheduler.rpc.rpc_server --rpc_ports {port}" if __name__ == "__main__": parser = argparse.ArgumentParser() - parser.add_argument("--port", type=int, required=False) + parser.add_argument("--rpc_ports", type=str, required=True) args, unknown = parser.parse_known_args() - port = get_server_ports(args.rpc_ports) + port = get_server_ports(args.rpc_ports) logger.info(f"About to start RPC server on {port}") diff --git a/areal/scheduler/test_local.py b/areal/scheduler/test_local.py new file mode 100644 index 000000000..564038bc6 --- /dev/null +++ b/areal/scheduler/test_local.py @@ -0,0 +1,226 @@ +import os +import sys +import time +from concurrent.futures import ThreadPoolExecutor + +from torchdata.stateful_dataloader import StatefulDataLoader + +from areal.api.alloc_mode import AllocationMode +from areal.api.cli_args import ( + GRPOConfig, + parse_cli_args, + to_structured_cfg, +) +from areal.api.io_struct import FinetuneSpec +from areal.dataset import get_custom_dataset +from areal.engine.ppo.actor import FSDPPPOActor +from areal.engine.sglang_remote import RemoteSGLangEngine +from areal.scheduler.local import LocalScheduler +from areal.utils import name_resolve +from areal.utils.data import ( + cycle_dataloader, +) +from areal.utils.hf_utils import load_hf_tokenizer +from areal.utils.stats_logger import StatsLogger +from areal.workflow.rlvr import RLVRWorkflow + +# init_config = {} + +create_workers_config, _ = parse_cli_args(sys.argv[1:]) + +from omegaconf import OmegaConf + +# config, _ = load_expr_config(sys.argv[1:]) +config = 
to_structured_cfg(create_workers_config, config_cls=GRPOConfig) +config = OmegaConf.to_object(config) +name_resolve.reconfigure(config.cluster.name_resolve) +config: GRPOConfig +# seeding.set_random_seed(config.seed, key=f"trainer{rank}") +allocation_mode = AllocationMode.from_str(config.allocation_mode) +parallel_strategy = allocation_mode.train + + +shcheduler = LocalScheduler(create_workers_config) +shcheduler.create_workers("rollout", create_workers_config) +shcheduler.create_workers("actor", create_workers_config) + +rollout_workers = shcheduler.get_workers("rollout", timeout=300) +actor_workers = shcheduler.get_workers("actor", timeout=300) + +print("[wht debug] rollout workers:", rollout_workers) +print("[wht debug] actor workers:", actor_workers) + +time.sleep(20) + + +rollout = RemoteSGLangEngine(config.rollout) +with ThreadPoolExecutor(max_workers=len(rollout_workers)) as executor: + + def create_engine_and_init(worker_id): + print(f"[wht debug] start create rollout engine and init {worker_id}") + shcheduler.create_engine( + worker_id, rollout, train_data_parallel_size=parallel_strategy.dp_size + ) + print(f"[wht debug] end create rollout engine and init {worker_id}") + + futures = [] + for i in range(len(rollout_workers)): + futures.append(executor.submit(create_engine_and_init, rollout_workers[i].id)) + + for future in futures: + future.result() + +ft_spec = FinetuneSpec( + total_train_epochs=config.total_train_epochs, + dataset_size=1024, # dummy value + train_batch_size=config.train_dataset.batch_size, +) + +actor = FSDPPPOActor(config=config.actor) +with ThreadPoolExecutor(max_workers=len(actor_workers)) as executor: + + def create_engine_and_init(worker_id): + print(f"[wht debug] start create actor engine and init {worker_id}") + shcheduler.create_engine( + worker_id, actor, None, ft_spec, parallel_strategy=parallel_strategy + ) + print(f"[wht debug] end create actor engine and init {worker_id}") + + futures = [] + for i in 
range(len(actor_workers)): + futures.append(executor.submit(create_engine_and_init, actor_workers[i].id)) + + for future in futures: + future.result() + +print("[wht debug] all engines created and initialized.") + + +tokenizer = load_hf_tokenizer(config.tokenizer_path) +train_dataset = get_custom_dataset( + path=config.train_dataset.path, + rank=0, + world_size=1, + split="train", + max_length=config.train_dataset.max_length, + type=config.train_dataset.type, + tokenizer=tokenizer, +) +train_dataloader = StatefulDataLoader( + train_dataset, + batch_size=config.train_dataset.batch_size, + shuffle=config.train_dataset.shuffle, + num_workers=config.train_dataset.num_workers, + collate_fn=lambda x: x, + drop_last=config.train_dataset.drop_last, +) +data_generator = cycle_dataloader(train_dataloader) +data = next(data_generator) + +print(f"[wht debug] get data batch: {data[0]}") + +from areal.reward.gsm8k_reward import gsm8k_reward_fn + +workflow = RLVRWorkflow( + reward_fn=gsm8k_reward_fn, + gconfig=config.gconfig, + tokenizer=tokenizer, + enable_thinking=False, + dump_dir=os.path.join(StatsLogger.get_log_path(config.stats_logger), "generated"), +) + +batch = None +with ThreadPoolExecutor(max_workers=len(rollout_workers)) as executor: + + def call_rollout(worker_id, data): + try: + batch = shcheduler.call_engine( + worker_id, + "rollout_batch", + data, + workflow=workflow, + should_accept=lambda sample: True, + ) + print(f"[wht debug] rollout {worker_id} done, got batch: {batch}") + return batch + except Exception as e: + print(f"[wht debug] rollout {worker_id} failed, error: {e}") + raise e + + futures = [] + for i in range(len(rollout_workers)): + futures.append(executor.submit(call_rollout, rollout_workers[i].id, data)) + for future in futures: + r = future.result() + print(f"[wht debug] rollout result: {r}") + batch = r + +print("[wht debug] all rollout done.") + +assert not config.actor.use_decoupled_loss and not config.actor.recompute_logprob + +with 
ThreadPoolExecutor(max_workers=len(actor_workers)) as executor: + + def call_compute_advantages(worker_id, data): + try: + batch = shcheduler.call_engine(worker_id, "compute_advantages", data) + print( + f"[wht debug] compute_advantages {worker_id} done, got batch: {batch}" + ) + return batch + except Exception as e: + print(f"[wht debug] compute_advantages {worker_id} failed, error: {e}") + raise e + + futures = [] + for i in range(len(actor_workers)): + futures.append( + executor.submit(call_compute_advantages, actor_workers[i].id, batch) + ) + for future in futures: + r = future.result() + print(f"[wht debug] compute_advantages result: {r}") + batch = r + +print("[wht debug] all compute_advantages done.") + +with ThreadPoolExecutor(max_workers=len(actor_workers)) as executor: + + def call_ppo_update(worker_id, data): + try: + batch = shcheduler.call_engine(worker_id, "ppo_update", data) + print(f"[wht debug] ppo_update {worker_id} done, got batch: {batch}") + return batch + except Exception as e: + print(f"[wht debug] ppo_update {worker_id} failed, error: {e}") + raise e + + futures = [] + for i in range(len(actor_workers)): + futures.append(executor.submit(call_ppo_update, actor_workers[i].id, batch)) + + for future in futures: + r = future.result() + print(f"[wht debug] ppo_update result: {r}") + +print("[wht debug] all ppo_update done.") + +with ThreadPoolExecutor(max_workers=len(actor_workers)) as executor: + + def call_step_lr_scheduler(worker_id): + try: + res = shcheduler.call_engine(worker_id, "step_lr_scheduler") + print(f"[wht debug] step_lr_scheduler {worker_id} done, got res: {res}") + return res + except Exception as e: + print(f"[wht debug] step_lr_scheduler {worker_id} failed, error: {e}") + raise e + + futures = [] + for i in range(len(actor_workers)): + futures.append(executor.submit(call_step_lr_scheduler, actor_workers[i].id)) + for future in futures: + r = future.result() + print(f"[wht debug] step_lr_scheduler result: {r}") + +print("[wht 
debug] all step_lr_scheduler done.") diff --git a/areal/tests/test_rpc.py b/areal/tests/test_rpc.py index 2f5ab493a..58590a407 100644 --- a/areal/tests/test_rpc.py +++ b/areal/tests/test_rpc.py @@ -16,7 +16,7 @@ from areal.scheduler.rpc.rpc_client import RPCClient from areal.scheduler.rpc.rpc_server import ( EngineRPCServer, - get_serve_port, + get_server_ports, process_input_to_distributed_batch, process_output_to_distributed_batch, start_rpc_server, @@ -175,61 +175,53 @@ def test_process_output_to_distributed_batch_other_types(): def test_get_serve_port_from_args(): """Test getting port from command line arguments""" mock_args = Mock() - mock_args.port = 8080 + mock_args.rpc_port = "8080" with patch.dict("os.environ", {}, clear=True): - port = get_serve_port(mock_args) + port = get_server_ports(mock_args.rpc_port) assert port == 8080 -def test_get_serve_port_from_env_single_port(): +def test_get_server_ports_default_from_multi_ports(): """Test getting single port from PORT_LIST environment variable""" mock_args = Mock() - mock_args.port = 8080 + mock_args.rpc_port = "8080,8081,8082,8083" - with patch.dict("os.environ", {"PORT_LIST": "9000"}): - port = get_serve_port(mock_args) - assert port == 9000 - - -def test_get_serve_port_from_env_multiple_ports(): - """Test getting first port from multiple ports in PORT_LIST environment variable""" - mock_args = Mock() - mock_args.port = 8080 - - with patch.dict("os.environ", {"PORT_LIST": "9000, 9001, 9002"}): - port = get_serve_port(mock_args) - assert port == 9000 + with patch.dict("os.environ", {}, clear=True): + port = get_server_ports(mock_args.rpc_port) + assert port == 8080 -def test_get_serve_port_invalid_env_port(): - """Test fallback when PORT_LIST contains invalid ports""" +def test_get_serve_port_from_multi_ports(): + """Test getting single port from PORT_LIST environment variable""" mock_args = Mock() - mock_args.port = 8080 + mock_args.rpc_port = "8080,8081,8082,8083" - with patch.dict("os.environ", 
{"PORT_LIST": "invalid_port, 9001"}): - port = get_serve_port(mock_args) + with patch.dict("os.environ", {"WORLD_SIZE": "4", "RANK": "0"}): + port = get_server_ports(mock_args.rpc_port) assert port == 8080 + with patch.dict("os.environ", {"WORLD_SIZE": "4", "RANK": "1"}): + port = get_server_ports(mock_args.rpc_port) + assert port == 8081 -def test_get_serve_port_empty_env(): - """Test fallback when PORT_LIST is empty""" - mock_args = Mock() - mock_args.port = 8080 + with patch.dict("os.environ", {"WORLD_SIZE": "4", "RANK": "2"}): + port = get_server_ports(mock_args.rpc_port) + assert port == 8082 - with patch.dict("os.environ", {"PORT_LIST": ""}): - port = get_serve_port(mock_args) - assert port == 8080 + with patch.dict("os.environ", {"WORLD_SIZE": "4", "RANK": "3"}): + port = get_server_ports(mock_args.rpc_port) + assert port == 8083 -def test_get_serve_port_whitespace_env(): - """Test fallback when PORT_LIST contains only whitespace""" +def test_get_serve_port_not_enough_ports(): + """Test error when not enough ports for WORLD_SIZE""" mock_args = Mock() - mock_args.port = 8080 + mock_args.rpc_port = "8080,8081" - with patch.dict("os.environ", {"PORT_LIST": " "}): - port = get_serve_port(mock_args) - assert port == 8080 + with patch.dict("os.environ", {"WORLD_SIZE": "4", "RANK": "0"}): + with pytest.raises(ValueError, match="Not enough ports for the world size"): + get_server_ports(mock_args.rpc_port) # RPC client and server integration tests diff --git a/areal/workflow/rlvr.py b/areal/workflow/rlvr.py index 34a6994ae..de3b26ff8 100644 --- a/areal/workflow/rlvr.py +++ b/areal/workflow/rlvr.py @@ -52,13 +52,15 @@ def __init__( self.enable_thinking = enable_thinking self.dump_dir = dump_dir self.rollout_stat_scope = rollout_stat_scope - self.async_reward_fn = AsyncRewardWrapper(reward_fn) + self.async_reward_fn = None self.get_input_ids_fn = get_input_ids_fn self.data_extract_prompt_fn = data_extract_prompt_fn if self.dump_dir is not None and not 
os.path.exists(self.dump_dir): os.makedirs(self.dump_dir, exist_ok=True) async def arun_episode(self, engine: InferenceEngine, data): + if self.async_reward_fn is None: + self.async_reward_fn = AsyncRewardWrapper(self.reward_fn) input_ids = self.get_input_ids_fn( self.data_extract_prompt_fn(data), self.tokenizer, self.enable_thinking ) diff --git a/examples/math/gsm8k_grpo_single_controller.yaml b/examples/math/gsm8k_grpo_single_controller.yaml new file mode 100644 index 000000000..6cf4387c0 --- /dev/null +++ b/examples/math/gsm8k_grpo_single_controller.yaml @@ -0,0 +1,153 @@ +experiment_name: gsm8k-grpo +trial_name: trial0 + +seed: 1 +total_train_epochs: 10 +tokenizer_path: ${actor.path} +async_training: true + +cluster: + n_nodes: 1 + n_gpus_per_node: 8 + fileroot: /tmp/areal/experiments + name_resolve: + type: nfs + nfs_record_root: /tmp/areal/name_resolve + +allocation_mode: sglang.d4p1t1+d4p1t1 + +rollout: + experiment_name: ${experiment_name} + trial_name: ${trial_name} + max_concurrent_rollouts: 256 + queue_size: null + consumer_batch_size: ${train_dataset.batch_size} + max_head_offpolicyness: 2 + enable_rollout_tracing: false + +gconfig: + n_samples: 4 + min_new_tokens: 0 + max_new_tokens: 1024 + greedy: false + temperature: 1.0 + +actor: + experiment_name: ${experiment_name} + trial_name: ${trial_name} + path: Qwen/Qwen2.5-1.5B-Instruct + init_from_scratch: false + disable_dropout: true + gradient_checkpointing: false + dtype: bfloat16 + mb_spec: + max_tokens_per_mb: 10240 + optimizer: + type: adam + lr: 1.70e-5 + weight_decay: 0.017 + beta1: 0.9 + beta2: 0.999 + eps: 1e-8 + lr_scheduler_type: constant + gradient_clipping: 1.0 + warmup_steps_proportion: 0.001 + backend: fsdp + group_size: ${gconfig.n_samples} + eps_clip: 0.4 + temperature: ${gconfig.temperature} + reward_scaling: 10.0 + reward_bias: -0.5 + kl_ctl: 0.0 + ppo_n_minibatches: 1 + recompute_logprob: false + use_decoupled_loss: false + behav_imp_weight_cap: 5.0 + dynamic_sampling: false + 
reward_norm: + mean_level: group + std_level: group + group_size: ${gconfig.n_samples} + adv_norm: + mean_level: batch + std_level: batch + max_new_tokens: ${gconfig.max_new_tokens} + +ref: + experiment_name: ${experiment_name} + trial_name: ${trial_name} + path: ${actor.path} + init_from_scratch: false + disable_dropout: true + dtype: ${actor.dtype} + mb_spec: + max_tokens_per_mb: 10240 + optimizer: null + backend: fsdp + +# SGLang +sglang: + model_path: ${actor.path} + random_seed: ${seed} + skip_tokenizer_init: true + dtype: ${actor.dtype} + max_running_requests: null + context_length: 32768 + mem_fraction_static: 0.8 + +# datasets +train_dataset: + batch_size: 8 + shuffle: true + pin_memory: true + num_workers: 4 + path: openai/gsm8k + type: rl + max_length: 1024 + +valid_dataset: + batch_size: 256 + shuffle: true + pin_memory: true + num_workers: 4 + path: openai/gsm8k + type: rl + +# Utilities +saver: + experiment_name: ${experiment_name} + trial_name: ${trial_name} + fileroot: ${cluster.fileroot} + freq_epochs: 1 + freq_steps: null + freq_secs: null + +recover: + mode: disabled + experiment_name: ${experiment_name} + trial_name: ${trial_name} + fileroot: ${cluster.fileroot} + freq_epochs: 1 + freq_steps: null + freq_secs: 3600 + +evaluator: + experiment_name: ${experiment_name} + trial_name: ${trial_name} + fileroot: ${cluster.fileroot} + freq_epochs: 1 + freq_steps: null + freq_secs: null + +stats_logger: + experiment_name: ${experiment_name} + trial_name: ${trial_name} + fileroot: ${cluster.fileroot} + wandb: + mode: disabled + +launcher: + inference_server_cpus_per_gpu: 4 + inference_server_mem_per_gpu: 32768 + trainer_cpus_per_gpu: 4 + trainer_mem_per_gpu: 32768