3 changes: 2 additions & 1 deletion areal/api/scheduler_api.py
@@ -36,10 +36,11 @@ class SchedulingConfig:


class Scheduler(abc.ABC):
def create_workers(self, worker_key, scheduler_config, *args, **kwargs) -> str:
def create_workers(self, worker_key, scheduler_config, *args, **kwargs) -> None:
"""
Start workers, return job id
"""
raise NotImplementedError()

def get_workers(self, worker_key, timeout=None) -> List[Worker]:
"""
11 changes: 4 additions & 7 deletions areal/api/workflow_api.py
Collaborator:

Why do we need to change this file?

Collaborator (Author):

Under my configuration, training hangs due to insufficient capacity.
The capacity seems to be computed only once per round of rollout: within a round, even after some queries finish their rollouts, the capacity never increases. The capacity calculation logic appears to be the problem here.
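
To make the diagnosis concrete, here is a minimal, self-contained sketch (not code from this PR; all names are placeholders) contrasting a capacity snapshot taken once per round with a value re-queried on every pass:

```python
import queue

# Toy illustration of the issue described above, not AReaL code.
# `get_capacity` is assumed to return how many rollouts may still be launched.

def drain_with_snapshot(inputs: queue.Queue, get_capacity, launch) -> None:
    # Capacity is read once; slots freed by rollouts that finish during this
    # round are never seen, so leftover inputs wait for the next round.
    capacity = get_capacity()
    while capacity > 0 and not inputs.empty():
        launch(inputs.get_nowait())
        capacity -= 1

def drain_with_recheck(inputs: queue.Queue, get_capacity, launch) -> None:
    # Capacity is re-queried on each pass, so freed slots are reused immediately.
    while get_capacity() > 0 and not inputs.empty():
        launch(inputs.get_nowait())
```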

@@ -330,14 +330,11 @@ async def _rollout_thread_async(self):
try:
while not self.exiting.is_set():
# Check capacity
capacity = self.get_capacity()
# capacity = self.get_capacity()
# self.logger.info(f"Current rollout capacity: {capacity}")
# Create new rollout task
self.lock.acquire()
while (
capacity > 0
and not self.paused.is_set()
and self.input_queue.qsize() > 0
):
while not self.paused.is_set() and self.input_queue.qsize() > 0:
Comment on lines +333 to +337
Contributor (medium):

The capacity checking logic has been commented out, which changes the behavior of the rollout task creation loop. If this is an intentional change to remove the capacity limit, please remove the commented-out code to improve readability and avoid confusion for future maintainers.

x = self.input_queue.get_nowait()
x: _RolloutTaskInput
self.logger.debug(f"Get data from puller: {x.data}")
@@ -357,7 +354,7 @@ async def _rollout_thread_async(self):
f"running: {self.rollout_stat.running}, "
f"accepted: {self.rollout_stat.accepted}."
)
capacity -= 1
# capacity -= 1
rid += 1
tasks = [x.task for x in rollout_tasks.values()]
self.lock.release()
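
For reference against the reviewer's note above, the inner loop could be written with the commented-out capacity bookkeeping removed entirely. This is only a sketch assembled from the lines visible in this hunk; the elided task-creation code is left as a comment:

```python
# Sketch of the hunk with the dead capacity lines dropped (names as in the diff).
while not self.exiting.is_set():
    self.lock.acquire()
    while not self.paused.is_set() and self.input_queue.qsize() > 0:
        x = self.input_queue.get_nowait()
        self.logger.debug(f"Get data from puller: {x.data}")
        # ... create the rollout task and update rollout_stat as before ...
        rid += 1
    tasks = [x.task for x in rollout_tasks.values()]
    self.lock.release()
```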
2 changes: 1 addition & 1 deletion areal/engine/base_hf_engine.py
@@ -73,7 +73,7 @@ def __init__(self, config: TrainEngineConfig):
)
self.is_vision_model = is_valid_vision_model(self.model_config.model_type)

self.world_size = int(os.environ["WORLD_SIZE"])
self.world_size: int

def set_version(self, version: int):
self._version = version
10 changes: 9 additions & 1 deletion areal/engine/fsdp_engine.py
@@ -121,9 +121,17 @@ def create_process_group(self, parallel_strategy: ParallelStrategy | None = None
self.dp_head = int(self.world_mesh["sp_tp"].mesh[0].item())
self.dp_rank = dist.get_rank(self.dp_group)

self.world_size = int(os.environ["WORLD_SIZE"])

self.logger.info(f"Data parallel head {self.dp_head} and rank {self.dp_rank}")

def initialize(self, addr: str | None, ft_spec: FinetuneSpec | None):
def initialize(
self,
addr: str | None,
ft_spec: FinetuneSpec | None,
parallel_strategy: ParallelStrategy | None = None,
):
self.create_process_group(parallel_strategy)
# Initialize distributed environments and load model.
assert addr is None, "FSDPEngine does not support remote initialization."
assert ft_spec is not None, "FSDPEngine requires FinetuneSpec to initialize."
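
A hedged usage sketch of the new initialize signature; the surrounding setup (how the config, FinetuneSpec, and ParallelStrategy objects are built) is assumed and not part of this PR:

```python
# Hypothetical call site for the updated FSDPEngine.initialize signature.
engine = FSDPEngine(config=train_config)    # train_config is a placeholder
engine.initialize(
    addr=None,                               # FSDPEngine rejects remote initialization
    ft_spec=ft_spec,                         # a FinetuneSpec instance
    parallel_strategy=parallel_strategy,     # optional; forwarded to create_process_group
)
```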
8 changes: 5 additions & 3 deletions areal/engine/ppo/actor.py
@@ -67,7 +67,7 @@ def calc_logprobs(logits, input_data):
aggregate_fn=lambda xs: torch.cat(xs, dim=-1),
)

def compute_advantages(self, data: Dict[str, Any]) -> None:
def compute_advantages(self, data: Dict[str, Any]) -> Dict[str, Any]:
bs = data["input_ids"].shape[0]
max_seqlen = data["input_ids"].shape[1]
batch_indices = torch.arange(
@@ -162,6 +162,8 @@ def compute_advantages(self, data: Dict[str, Any]) -> None:
# because we have rolled old_logp by -1
data["logprobs"] = old_logp

return data

def ppo_update(self, data: Dict[str, Any]) -> List[Dict[str, float]]:

if self.dynamic_sampling and len(data["rewards"]) % self.group_size == 0:
@@ -286,8 +288,8 @@ def compute_logp(self, *args, **kwargs) -> torch.Tensor | None:
return self.actor.compute_logp(*args, **kwargs)

@torch.no_grad()
def compute_advantages(self, *args, **kwargs) -> None:
self.actor.compute_advantages(*args, **kwargs)
def compute_advantages(self, *args, **kwargs):
return self.actor.compute_advantages(*args, **kwargs)
Comment on lines +291 to +292
Contributor (medium):

The return type hint for compute_advantages is missing. Based on the wrapped self.actor.compute_advantages method, it should be -> Dict[str, Any]. Adding type hints improves code clarity and allows static analysis tools to catch potential bugs.

Suggested change:
-    def compute_advantages(self, *args, **kwargs):
-        return self.actor.compute_advantages(*args, **kwargs)
+    def compute_advantages(self, *args, **kwargs) -> Dict[str, Any]:
+        return self.actor.compute_advantages(*args, **kwargs)


def ppo_update(self, *args, **kwargs) -> List[Dict[str, float]]:
return self.actor.ppo_update(*args, **kwargs)
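
Since compute_advantages now returns the batch instead of None, call sites can rebind the result; a hedged sketch with placeholder names:

```python
# Hypothetical call site: the actor updates the batch in place and now also
# returns it, so the result can be rebound and fed straight into ppo_update.
batch = actor.compute_advantages(batch)
stats = actor.ppo_update(batch)
```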
12 changes: 7 additions & 5 deletions areal/engine/sglang_remote.py
@@ -45,11 +45,8 @@ def __init__(self, config: InferenceEngineConfig):
self.distributed_weight_update_initialized = False
self._version = 0

self.lock = Lock()
self.workflow_executor = WorkflowExecutor(
config=config,
inference_engine=self,
)
self.lock: Lock
self.workflow_executor: WorkflowExecutor

def _wait_for_server(self, address):
base_url = f"http://{address}"
@@ -74,6 +71,11 @@ def initialize(
addr: str | List[str] | None = None,
train_data_parallel_size: int | None = None,
):
self.lock = Lock()
self.workflow_executor = WorkflowExecutor(
config=self.config,
inference_engine=self,
)
if engine_id is None:
if dist.is_initialized():
engine_id = str(dist.get_rank())
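
With the lock and workflow executor now built inside initialize() rather than __init__, initialize() has to run before anything touches them. A hedged ordering sketch (the engine class name and arguments are assumed, not verified against this PR):

```python
# Hypothetical call order implied by this diff.
engine = RemoteSGLangEngine(config=infer_config)  # constructor no longer builds the lock/executor
engine.initialize(addr=server_addrs)              # Lock and WorkflowExecutor are created here
# Only after initialize() may code use engine.lock or engine.workflow_executor.
```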
5 changes: 5 additions & 0 deletions areal/reward/gsm8k_reward.py
Collaborator:

Is this reward function used anywhere?

Collaborator (Author):

concurrent.futures.process._RemoteTraceback:
"""
Traceback (most recent call last):
  File "/usr/lib/python3.12/multiprocessing/queues.py", line 264, in _feed
    obj = _ForkingPickler.dumps(obj)
          ^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/lib/python3.12/multiprocessing/reduction.py", line 51, in dumps
    cls(buf, protocol).dump(obj)
_pickle.PicklingError: Can't pickle <function gsm8k_reward_fn at 0x7f82ee9e5d00>: attribute lookup gsm8k_reward_fn on __main__ failed
"""

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "/storage/openpsi/codes/wht125/project/github_areal/AReaL/areal/api/reward_api.py", line 122, in __call__
    reward = await asyncio.wait_for(
             ^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/lib/python3.12/asyncio/tasks.py", line 520, in wait_for
    return await fut
           ^^^^^^^^^
  File "/usr/lib/python3.12/multiprocessing/queues.py", line 264, in _feed
    obj = _ForkingPickler.dumps(obj)
          ^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/lib/python3.12/multiprocessing/reduction.py", line 51, in dumps
    cls(buf, protocol).dump(obj)
_pickle.PicklingError: Can't pickle <function gsm8k_reward_fn at 0x7f82ee9e5d00>: attribute lookup gsm8k_reward_fn on __main__ failed

If this reward_fn is defined in the main program, the above error occurs because, in single-controller mode, the train engine runs in a different process from the controller. After the reward_fn is serialized and sent to the remote process, reward_api.py in the train engine serializes it again and sends it to a subprocess. At that point the reward_fn definition cannot be found for serialization: pickling looks it up on __main__, but the subprocess's __main__ is not the same module as the controller's __main__. Therefore the reward_fn should be defined in a separate file.
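
A small sketch of the pickling behavior behind this explanation; only the areal.reward.gsm8k_reward import is taken from this PR, the rest is illustrative:

```python
import pickle

# Functions pickle by reference (module name + qualified name), not by value.
from areal.reward.gsm8k_reward import gsm8k_reward_fn

# Importable from a real module, so the round-trip resolves in any process:
assert pickle.loads(pickle.dumps(gsm8k_reward_fn)) is gsm8k_reward_fn

# A function defined inline in the controller script would instead be pickled
# as "__main__.gsm8k_reward_fn"; when the train engine's subprocess pool tries
# to pickle it again, its own __main__ has no such attribute, giving
# _pickle.PicklingError: ... attribute lookup gsm8k_reward_fn on __main__ failed
```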

@@ -0,0 +1,5 @@
from areal.reward.math_parser import process_results


def gsm8k_reward_fn(prompt, completions, prompt_ids, completion_ids, answer, **kwargs):
return int(process_results(completions, answer)[0])
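
For context, a hypothetical call to the new reward function; the argument values are invented, and only completions and answer are consumed by this implementation:

```python
from areal.reward.gsm8k_reward import gsm8k_reward_fn

# Invented example data; prompt_ids/completion_ids are ignored by this reward.
score = gsm8k_reward_fn(
    prompt="Natalia sold 48 clips in April and half as many in May. How many in total?",
    completions="48 + 24 = 72, so the answer is \\boxed{72}.",
    prompt_ids=[],
    completion_ids=[],
    answer="72",
)
print(score)  # expected to be 1 when process_results judges the answer correct, else 0
```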
Empty file added areal/scheduler/__init__.py