[feature] support local scheduler #427
base: main
```diff
@@ -330,14 +330,11 @@ async def _rollout_thread_async(self):
         try:
             while not self.exiting.is_set():
                 # Check capacity
-                capacity = self.get_capacity()
+                # capacity = self.get_capacity()
                 # self.logger.info(f"Current rollout capacity: {capacity}")
                 # Create new rollout task
                 self.lock.acquire()
-                while (
-                    capacity > 0
-                    and not self.paused.is_set()
-                    and self.input_queue.qsize() > 0
-                ):
+                while not self.paused.is_set() and self.input_queue.qsize() > 0:
                     x = self.input_queue.get_nowait()
                     x: _RolloutTaskInput
                     self.logger.debug(f"Get data from puller: {x.data}")
@@ -357,7 +354,7 @@ async def _rollout_thread_async(self):
                         f"running: {self.rollout_stat.running}, "
                         f"accepted: {self.rollout_stat.accepted}."
                     )
-                    capacity -= 1
+                    # capacity -= 1
                     rid += 1
                 tasks = [x.task for x in rollout_tasks.values()]
                 self.lock.release()
```
```diff
@@ -67,7 +67,7 @@ def calc_logprobs(logits, input_data):
             aggregate_fn=lambda xs: torch.cat(xs, dim=-1),
         )
 
-    def compute_advantages(self, data: Dict[str, Any]) -> None:
+    def compute_advantages(self, data: Dict[str, Any]) -> Dict[str, Any]:
         bs = data["input_ids"].shape[0]
         max_seqlen = data["input_ids"].shape[1]
         batch_indices = torch.arange(
```
```diff
@@ -162,6 +162,8 @@ def compute_advantages(self, data: Dict[str, Any]) -> None:
         # because we have rolled old_logp by -1
         data["logprobs"] = old_logp
 
+        return data
+
     def ppo_update(self, data: Dict[str, Any]) -> List[Dict[str, float]]:
 
         if self.dynamic_sampling and len(data["rewards"]) % self.group_size == 0:
```
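With `compute_advantages` now returning the batch, callers would presumably rebind the result rather than rely on in-place mutation (which would not propagate across a process boundary). A hypothetical call-site excerpt; the names are illustrative and not taken from this diff:

```python
# Hypothetical training-loop excerpt: the advantage-augmented batch is handed
# back by compute_advantages and then fed into the PPO update.
batch = actor.compute_advantages(batch)
stats = actor.ppo_update(batch)
```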
```diff
@@ -286,8 +288,8 @@ def compute_logp(self, *args, **kwargs) -> torch.Tensor | None:
         return self.actor.compute_logp(*args, **kwargs)
 
     @torch.no_grad()
-    def compute_advantages(self, *args, **kwargs) -> None:
-        self.actor.compute_advantages(*args, **kwargs)
+    def compute_advantages(self, *args, **kwargs):
+        return self.actor.compute_advantages(*args, **kwargs)
 
     def ppo_update(self, *args, **kwargs) -> List[Dict[str, float]]:
         return self.actor.ppo_update(*args, **kwargs)
```

Comment on lines +291 to +292:

The return type hint for `compute_advantages` is missing here. Since the wrapped actor method now returns `Dict[str, Any]`, the wrapper could carry the same annotation.
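A sketch of the change being suggested, assuming the intent is simply to mirror the actor's new annotation (a reconstruction, not the reviewer's verbatim suggestion block):

```python
# Assumed shape of the suggestion: annotate the wrapper so its return type
# matches the wrapped actor's compute_advantages, which now returns the batch.
@torch.no_grad()
def compute_advantages(self, *args, **kwargs) -> Dict[str, Any]:
    return self.actor.compute_advantages(*args, **kwargs)
```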
Is this reward function used anywhere?
If the reward_fn is defined in the main program, the above error occurs because, in single-controller mode, the process hosting the train engine is different from the process hosting the controller. After the reward_fn is serialized and sent to the remote process, reward_api.py in the train engine serializes it again and sends it to a subprocess. At that point the definition of reward_fn can no longer be found for serialization: it tries to resolve reward_fn from `__main__`, but the subprocess's `__main__` is not the same module as the controller's `__main__`. Therefore, the reward_fn should be defined in a separate file.
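For context, a minimal standalone sketch of the pickling behavior being described (plain Python, not AReaL code; the names are illustrative):

```python
import pickle


def reward_fn(completions, answer):
    # Illustrative reward function defined at module level in the launcher script.
    return 0.0


# pickle stores plain functions by reference (module name + qualified name),
# e.g. "__main__.reward_fn" when the defining file is run as a script.
payload = pickle.dumps(reward_fn)

# Unpickling in the same process succeeds, because the module still defines it.
assert pickle.loads(payload) is reward_fn

# In a remote worker or subprocess, __main__ is a different script, so looking
# up "__main__.reward_fn" fails. Defining the function in an importable module
# (as this PR does with gsm8k_reward_fn) lets every process resolve it.
```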
```diff
@@ -0,0 +1,5 @@
+from areal.reward.math_parser import process_results
+
+
+def gsm8k_reward_fn(prompt, completions, prompt_ids, completion_ids, answer, **kwargs):
+    return int(process_results(completions, answer)[0])
```
Why do we need to change this file?
Under my configuration, the run hangs due to insufficient capacity. The calculation of `capacity` seems to be a one-time snapshot for each round of rollout: within one round, even after some queries have finished their rollout, `capacity` does not increase. The capacity calculation logic here seems to have been problematic.
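One possible direction, sketched under the assumption that `get_capacity()` reflects slots freed by finished rollouts when re-queried (a sketch, not part of this PR): re-check capacity inside the loop instead of snapshotting it once per round.

```python
# Sketch only: re-evaluate capacity on every iteration so that slots freed by
# finished rollouts become visible within the same round. Attribute names
# follow the diff above; task-creation details are elided.
while not self.exiting.is_set():
    self.lock.acquire()
    while (
        self.get_capacity() > 0          # re-checked each iteration
        and not self.paused.is_set()
        and self.input_queue.qsize() > 0
    ):
        x: _RolloutTaskInput = self.input_queue.get_nowait()
        ...  # create the rollout task and update rollout_stat as before
        rid += 1
    tasks = [x.task for x in rollout_tasks.values()]
    self.lock.release()
```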