feat: support single controller and balance batch #410
Conversation
Co-authored-by: shun001 <[email protected]>
Co-authored-by: flemingpau <[email protected]>
Co-authored-by: Moocharr <[email protected]>
Co-authored-by: HUZZZW <[email protected]>
Co-authored-by: ChengQianqian <[email protected]>
Co-authored-by: zx506 <[email protected]>
Co-authored-by: Shengzhou Lyu <[email protected]>
Co-authored-by: flemingpau <[email protected]>
Co-authored-by: casparcwang <[email protected]>
Co-authored-by: dazhuangzhuang1024 <[email protected]>
Co-authored-by: HShan886 <[email protected]>
Summary of Changes

Hello @Ronbogo, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed. This pull request introduces significant enhancements to distributed training by adding a single-control mode and a load-balancing mechanism. The single-control mode addresses potential timeouts in scenarios with high `max_concurrent_rollouts` and `max_head_offpolicyness`.
Code Review
This pull request introduces a single-controller mode and a batch-balancing feature to improve performance and avoid timeouts in cases of imbalanced rollouts. The core logic for batch balancing is implemented in areal/utils/data.py, and examples/math/boba_grpo.py is updated to demonstrate its usage. My review has identified a critical bug in the example's rollout logic, an incorrect type hint in the workflow API, and several medium-severity issues related to documentation clarity, code style, and redundancy. Addressing these points will improve the correctness and maintainability of the new features.
examples/math/boba_grpo.py (outdated)
```python
try:
    data = next(data_generator)
except StopIteration:
    data_generator = iter(train_dataloader)
    data = next(data_generator)
    batch = rollout.rollout_batch(
        data=data,
        workflow=workflow,
        should_accept=lambda sample: True,
    )
```
There is a bug in the non-asynchronous rollout logic. The `rollout.rollout_batch(...)` call is inside the `except StopIteration` block. This means that if `next(data_generator)` succeeds, `batch` will not be assigned and will remain `None`, causing issues later. The `rollout_batch` call should be moved outside the `try-except` block to ensure it's always executed.
Suggested change:

```diff
 try:
     data = next(data_generator)
 except StopIteration:
     data_generator = iter(train_dataloader)
     data = next(data_generator)
-    batch = rollout.rollout_batch(
-        data=data,
-        workflow=workflow,
-        should_accept=lambda sample: True,
-    )
+batch = rollout.rollout_batch(
+    data=data,
+    workflow=workflow,
+    should_accept=lambda sample: True,
+)
```
areal/api/workflow_api.py (outdated)
```diff
-def wait(self, count: int, timeout: float | None = None) -> Dict[str, Any]:
+def wait(
+    self, count: int, timeout: float | None = None, single_rank_load: bool = False
+) -> Dict[str, Any]:
```
The return type hint for the `wait` method is now incorrect. With the introduction of `single_rank_load`, the method can return either a `Dict[str, Any]` or a `List[Dict[str, Any]]`. The type hint should be updated to reflect this.
Suggested change:

```diff
-) -> Dict[str, Any]:
+) -> Dict[str, Any] | List[Dict[str, Any]]:
```
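For illustration, here is a minimal sketch of how a caller could normalize the widened return type. The per-rank-list semantics of `single_rank_load` are an assumption based on the review comment, and `consume_wait_result` is a hypothetical helper, not part of this PR:

```python
from typing import Any, Dict, List, Union

def consume_wait_result(
    result: Union[Dict[str, Any], List[Dict[str, Any]]]
) -> List[Dict[str, Any]]:
    """Normalize wait()'s result to a list of per-rank batch dicts."""
    if isinstance(result, list):
        # Assumed single_rank_load=True case: one batch dict per DP rank.
        return result
    # Default case: a single batch dict for the local rank.
    return [result]
```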
```python
metadata={
    "help": "balance all rollouts across dp ranks by total tokens.now, it works only when single_rank_load was set true."
},
```
The help text for `balance_batch` is a bit unclear and contains a typo. It can be improved for better readability and correctness.
Suggested change:

```diff
 metadata={
-    "help": "balance all rollouts across dp ranks by total tokens.now, it works only when single_rank_load was set true."
+    "help": "Balance all rollouts across DP ranks by total tokens. Note: this only works when `single_rank_load` is set to True."
 },
```
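For context, a minimal sketch of how such a flag is typically declared with `dataclasses.field` metadata, mirroring the pattern in the diff above; the surrounding `DataLoadingConfig` class and its defaults are assumptions, not the PR's actual config class:

```python
from dataclasses import dataclass, field

@dataclass
class DataLoadingConfig:  # hypothetical container, for illustration only
    single_rank_load: bool = field(
        default=False,
        metadata={"help": "Use single-rank rollout send/receive or not."},
    )
    balance_batch: bool = field(
        default=False,
        metadata={
            "help": "Balance all rollouts across DP ranks by total tokens. "
            "Note: this only works when `single_rank_load` is set to True."
        },
    )
```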
```python
print(
    f"-----cv_original {cv_original} total_tokens_per_rank_original {total_tokens_per_rank_original}"
)
print(f"-----cv_balance {cv_balance} total_tokens_per_rank {total_tokens_per_rank}")
print("*****************************************************")
```
areal/utils/data.py (outdated)
```python
    balance_batch_enabled,
):
    """
    Broadcast data when using signle controller (single_rank_load==True)
```
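As a rough illustration of what the quoted docstring describes, here is a minimal sketch of broadcasting a batch from the controller rank to the other DP ranks with `torch.distributed`. This is an assumption about the mechanism, not the PR's implementation, and it presumes a process group is already initialized:

```python
from typing import Any, Dict, Optional

import torch.distributed as dist

def broadcast_batch(batch: Optional[Dict[str, Any]], src: int = 0) -> Dict[str, Any]:
    """Rank `src` passes the real batch; other ranks pass None and receive a copy."""
    obj_list = [batch]
    dist.broadcast_object_list(obj_list, src=src)
    return obj_list[0]
```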
```python
def karmarkar_karp(seqlen_list: List[int], k_partitions: int, equal_size: bool):
    import heapq
```
```python
def spread(self):
    return self.sets[0].sum - self.sets[-1].sum
```
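For readers unfamiliar with the approach, below is a simplified sketch of the goal the quoted helper serves: splitting sequence lengths into k partitions with roughly equal token totals. It uses a greedy longest-first heap assignment rather than the Karmarkar-Karp differencing method the PR appears to implement, so treat it as an illustration only:

```python
import heapq
from typing import List

def greedy_balance(seqlen_list: List[int], k_partitions: int) -> List[List[int]]:
    """Assign sequence indices to partitions so per-partition total tokens are roughly equal."""
    # Min-heap of (current total tokens, partition index).
    heap = [(0, p) for p in range(k_partitions)]
    heapq.heapify(heap)
    partitions: List[List[int]] = [[] for _ in range(k_partitions)]
    # Place the longest sequences first, always into the currently lightest partition.
    for idx in sorted(range(len(seqlen_list)), key=lambda i: -seqlen_list[i]):
        total, p = heapq.heappop(heap)
        partitions[p].append(idx)
        heapq.heappush(heap, (total + seqlen_list[idx], p))
    return partitions

print(greedy_balance([512, 2048, 128, 1024, 896, 4096], k_partitions=2))
```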
| `num_workers` | integer | `0` | Number of worker processes for data loading |
| `drop_last` | boolean | `True` | Drop the last incomplete batch |
| `single_rank_load` | boolean | `False` | Use single rank rollout send/recive or not |
| `balance_batch` | boolean | `False` | balance all rollouts across dp ranks by total tokens.now, it works only when single_rank_load was set true. |
The description for `balance_batch` has some grammatical issues and can be made clearer for better user understanding.
Suggested change:

```diff
-| `balance_batch` | boolean | `False` | balance all rollouts across dp ranks by total tokens.now, it works only when single_rank_load was set true. |
+| `balance_batch` | boolean | `False` | balance all rollouts across dp ranks by total tokens.now, it works only when single_rank_load is set true. |
```
| `model` | string | `""` | - |
| `seed` | integer | `1` | - |
| `skip_tokenizer_init` | boolean | `False` | - |
| `enforce_eager` | boolean | `True` | - |
| `dtype` | string | `"bfloat16"` | - |
| `max_num_seqs` | integer | `256` | - |
| `block_size` | integer | `16` | - |
| `swap_space` | integer | `4` | - |
| `cpu_offload_gb` | float | `0` | - |
| `max_seq_len_to_capture` | integer | `32768` | - |
| `disable_sliding_window` | boolean | `True` | - |
| `max_model_len` | integer \| None | `32768` | - |
| `enable_chunked_prefill` | boolean | `False` | - |
| `enable_prefix_caching` | boolean | `False` | - |
| `gpu_memory_utilization` | float | `0.9` | - |
| `worker_extension_cls` | string | `"areal.thirdparty.vllm.vllm_worker_extension.VLLMWorkerExtension"` | - |
| `enable_sleep_mode` | boolean | `False` | - |
examples/math/boba_grpo.py (outdated)
```python
    config.rollout.consumer_batch_size *= world_size
    config.rollout.max_concurrent_rollouts *= world_size
else:
    # Create empty datqaloader for other ranks when using single rank load
```
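A minimal sketch of the "empty dataloader for other ranks" idea mentioned in the comment above; the `EmptyDataset` class is a hypothetical stand-in, not necessarily how the example implements it:

```python
from torch.utils.data import DataLoader, Dataset

class EmptyDataset(Dataset):
    """Zero-length dataset so non-controller ranks keep the same loop structure."""
    def __len__(self) -> int:
        return 0

    def __getitem__(self, idx):
        raise IndexError("empty dataset")

empty_dataloader = DataLoader(EmptyDataset(), batch_size=1)  # yields nothing
```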
@Ronbogo Hi, thanks for this important feature! FYI, we plan to publish a new release (v0.3.4) by 10.10 or a nearby date, and we'd like to merge this feature after that. :)
OK |
Overview
Add single-control mode, enabled by setting single_rank_load=True. This feature avoids possible timeouts caused by severe imbalance in rollout results across trainer DP ranks when max_concurrent_rollouts and max_head_offpolicyness are large.

Introduce DP load-balancing for single-control mode, inspired by verl's seqlen balancer. The algorithm actively rebalances workload among DP ranks, ensuring more even utilization and improving performance.