
Conversation

@yulangz (Collaborator) commented Oct 10, 2025

Support local scheduler

  • Move non-serializable members of SGLangEngine and FSDPEngine so that they are initialized in the initialize method rather than in the constructor.
  • Optimize the data transmission logic of the RPC server: data is moved to the CPU before being returned to the Controller (see the sketch after this list).
  • Add a simple local_scheduler test that starts the rollout and actor separately and completes a full interaction round from rollout to training.
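For illustration, here is a minimal sketch of the CPU-transfer step mentioned above; the helper name move_to_cpu and its structure are assumptions for this example, not the PR's actual code.

import torch

def move_to_cpu(obj):
    # Recursively detach tensors and move them to the CPU so the RPC
    # response can be pickled and returned to the Controller without
    # carrying GPU device handles.
    if torch.is_tensor(obj):
        return obj.detach().cpu()
    if isinstance(obj, dict):
        return {k: move_to_cpu(v) for k, v in obj.items()}
    if isinstance(obj, (list, tuple)):
        return type(obj)(move_to_cpu(v) for v in obj)
    return obj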

Usage

cd areal/scheduler

export PYTHONPATH=/path/to/AReaL:$PYTHONPATH
pkill -9 python
pkill -9 sglang

python test_local.py --config ../../examples/math/gsm8k_grpo_single_controller.yaml train_dataset.batch_size=8 trial_name=trial-1 | tee log_of_test.log


Summary of Changes

Hello @yulangz, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request significantly enhances the system's flexibility by integrating a local scheduler, allowing for the execution of complex reinforcement learning workflows on a single machine. The changes involve crucial refactoring of engine initialization logic to accommodate local execution, alongside improvements to RPC data handling for better performance and stability. A new, detailed test suite ensures the robustness of this local execution capability.

Highlights

  • Local Scheduler Support: Introduced a new LocalScheduler that enables running rollout and actor components locally, simplifying development and testing workflows.
  • Engine Initialization Refactoring: Non-serializable members in SGLangEngine and FSDPEngine are now initialized within their respective initialize methods, improving serialization compatibility for distributed setups (see the sketch after this list).
  • RPC Data Transmission Optimization: The RPC server now explicitly moves data to the CPU before returning it to the Controller, optimizing data handling and preventing potential device-related issues.
  • Comprehensive Local Test: A new test script (test_local.py) has been added to validate the local scheduler, covering a full interaction round from rollout to training.
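As an illustration of this deferred-initialization pattern, here is a hedged sketch built around a hypothetical ExampleEngine class; it is not the PR's actual engine code.

import threading

class ExampleEngine:
    # __init__ keeps only plain, picklable configuration so the engine object
    # can be serialized and shipped to a worker process by the scheduler.
    def __init__(self, config: dict):
        self.config = config
        self.lock = None    # non-picklable, created later
        self.model = None   # heavyweight runtime state, created later

    def initialize(self):
        # Called inside the worker process after deserialization; all
        # non-serializable members are constructed here.
        self.lock = threading.Lock()
        self.model = object()  # stand-in for loading the real model/engine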

@gemini-code-assist (bot) left a comment

Code Review

This pull request introduces a local scheduler, which is a significant feature for enabling local development and testing. The refactoring to move non-serializable members to initialize methods is a great improvement for the architecture, making engines more flexible. The optimization in the RPC server to move data to the CPU before transmission is also a good performance enhancement. My review focuses on improving the robustness and maintainability of the new local scheduler implementation, cleaning up some leftover code from the refactoring, and ensuring API consistency. I've pointed out a potentially fragile command construction in the local launcher, a problematic use of __del__, and some missing type hints and commented-out code. Overall, this is a solid contribution.

Comment on lines +106 to +107

def __del__(self):
    self.wait()

Severity: high

The __del__ method calls self.wait(). This can be problematic. __del__ is called during garbage collection, and its execution context is unpredictable. If self.wait() blocks for a long time (e.g., waiting for long-running subprocesses), it can delay or interfere with program shutdown. Also, if an exception occurs in wait(), it will be ignored. It's generally recommended to avoid complex logic in __del__. A better pattern is to use a context manager (__enter__ and __exit__). If you must use __del__, consider calling self.stop_all() instead to terminate jobs, which is a more predictable cleanup operation.

Suggested change:

def __del__(self):
    self.stop_all()
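For reference, a self-contained sketch of the context-manager pattern mentioned above; the ManagedJobs class, its submit method, and the stop_all behavior are illustrative assumptions, not the launcher's actual API.

import subprocess
from typing import List

class ManagedJobs:
    def __init__(self):
        self.procs: List[subprocess.Popen] = []

    def submit(self, cmd: List[str]) -> subprocess.Popen:
        # Launch a subprocess and remember it for later cleanup.
        proc = subprocess.Popen(cmd)
        self.procs.append(proc)
        return proc

    def stop_all(self) -> None:
        # Terminate any jobs that are still running.
        for proc in self.procs:
            if proc.poll() is None:
                proc.terminate()

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc, tb):
        # Runs deterministically on normal exit and on exceptions, unlike
        # __del__, whose timing during garbage collection is unpredictable.
        self.stop_all()
        return False

# with ManagedJobs() as jobs:
#     jobs.submit(["sleep", "60"])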

Comment on lines +133 to +137

c = (
    " ".join(str(k) + "=" + str(v) for k, v in env_vars.items())
    + " stdbuf -oL "
    + cmd[i]
)

Severity: high

The current method of constructing the command string by joining environment variables is not robust. If an environment variable's value contains spaces or other shell special characters, it will break the command. A safer approach is to use the env parameter of subprocess.Popen to pass environment variables. This avoids shell injection issues and handles special characters correctly. You would need to create a copy of os.environ, update it with env_vars, and pass it to Popen on line 140.
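A minimal sketch of the env-based approach, assuming env_vars is a plain dict and that the launcher builds one command string per worker; the launch helper below is illustrative, not the PR's actual code.

import os
import subprocess

def launch(cmd: str, env_vars: dict) -> subprocess.Popen:
    # Pass environment variables through Popen's env parameter instead of
    # prefixing KEY=VALUE pairs onto the shell command, so values containing
    # spaces or shell metacharacters are handled safely.
    env = os.environ.copy()
    env.update({k: str(v) for k, v in env_vars.items()})
    return subprocess.Popen("stdbuf -oL " + cmd, shell=True, env=env)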

Comment on lines +333 to +337 (diff excerpt)

  # capacity = self.get_capacity()
  # self.logger.info(f"Current rollout capacity: {capacity}")
  # Create new rollout task
  self.lock.acquire()
- while (
-     capacity > 0
-     and not self.paused.is_set()
-     and self.input_queue.qsize() > 0
- ):
+ while not self.paused.is_set() and self.input_queue.qsize() > 0:

Severity: medium

The capacity checking logic has been commented out, which changes the behavior of the rollout task creation loop. If this is an intentional change to remove the capacity limit, please remove the commented-out code to improve readability and avoid confusion for future maintainers.

Comment on lines +291 to +292

def compute_advantages(self, *args, **kwargs):
    return self.actor.compute_advantages(*args, **kwargs)

Severity: medium

The return type hint for compute_advantages is missing. Based on the wrapped self.actor.compute_advantages method, it should be -> Dict[str, Any]. Adding type hints improves code clarity and allows static analysis tools to catch potential bugs.

Suggested change:

def compute_advantages(self, *args, **kwargs) -> Dict[str, Any]:
    return self.actor.compute_advantages(*args, **kwargs)

env_vars=dict(
    **get_env_vars(
        config.cluster.cluster_name,
        # config.launcher.worker_env_vars,

Severity: medium

This line is commented out. Please remove it to improve code clarity.

        config.cluster.cluster_name,
        config.launcher.trainer_env_vars,
    ),
    # AREAL_LLM_SERVER_ADDRS=",".join(server_addrs), # not need?

Severity: medium

This line is commented out. Please remove it to improve code clarity.

Comment on lines +55 to 56

    self, worker_id: str, method: str, max_retries: int, *args, **kwargs
) -> Any:

Severity: medium

The default value for max_retries has been removed, making it a required argument. This could be an unintentional breaking change for callers of this method. If the intention is to always require it, this is fine, but if a default is desirable, consider re-adding it (e.g., max_retries: int = 3). The only call site in this PR (areal/scheduler/local.py:407) hardcodes it to 3.

Suggested change:

    self, worker_id: str, method: str, max_retries: int = 3, *args, **kwargs
) -> Any:

Comment on lines +130 to +134

# NOTE: We must use HTTPServer rather than ThreadingHTTPServer here, since the rank and device info
# of pytorch is thread level, if use ThreadingHTTPServer, the device set by create_engine thread
# will not be seen by call_engine thread.
# server = ThreadingHTTPServer(("0.0.0.0", port), EngineRPCServer)
server = HTTPServer(("0.0.0.0", port), EngineRPCServer)

Severity: medium

The comment explaining the switch from ThreadingHTTPServer to HTTPServer is excellent. To improve code clarity, the commented-out ThreadingHTTPServer line could be removed, as the explanation is sufficient.

Suggested change:

# NOTE: We must use HTTPServer rather than ThreadingHTTPServer here, since the rank and device info
# of pytorch is thread level, if use ThreadingHTTPServer, the device set by create_engine thread
# will not be seen by call_engine thread.
server = HTTPServer(("0.0.0.0", port), EngineRPCServer)

Comment on lines +50 to +51

print("[wht debug] rollout workers:", rollout_workers)
print("[wht debug] actor workers:", actor_workers)

Severity: medium

This test file contains many print statements with [wht debug] prefixes. While useful for debugging, it's better practice to use the logging module. This allows for configurable log levels and makes it easier to manage log output, especially in CI/CD environments. Consider replacing these prints with logger.debug() or logger.info().
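A hedged example of the logging-based replacement; the logger name and the placeholder worker lists are assumptions for illustration.

import logging

logging.basicConfig(level=logging.DEBUG)
logger = logging.getLogger("areal.tests.test_local")  # name is an assumption

rollout_workers = ["rollout/0"]  # placeholder values for illustration
actor_workers = ["actor/0"]

# instead of: print("[wht debug] rollout workers:", rollout_workers)
logger.debug("rollout workers: %s", rollout_workers)
logger.debug("actor workers: %s", actor_workers)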

Collaborator:

Why do we need to change this file?

Collaborator (Author):

Under my configuration, the run hangs due to insufficient capacity.
The capacity appears to be computed only once per rollout round: within a round, even after some queries have finished their rollout, the capacity does not increase. The capacity calculation logic seems to be problematic here.

"torch", "2.4.0"
), f"areal only supports FSDP2, which requires torch>=2.4.0"

self.world_size = int(os.environ["WORLD_SIZE"])

Collaborator:
Should be initialized in create_process_group

Collaborator (Author):

Moved it to create_process_group.

pass


class LocalLauncher:

Collaborator:
Why do we need this? Can't we just import from areal.launcher.local?

AREAL_LLM_SERVER_ADDRS=server_addrs[
    i % alloc_mode.gen.dp_size
],
AREAL_RECOVER_RUN=str(int(is_recover_run)),

Collaborator:

This environment variable is not used.

Collaborator:

This file should be changed so that the test can be run with pytest areal/scheduler/test_local.py.

Also, move this file to areal/tests/.

Collaborator:

We don't need this config if there is no corresponding training script.

Collaborator (Author):

It is left here temporarily as an example and will be deleted before merging.

Collaborator:

Is this reward function used anywhere?

Collaborator (Author):

concurrent.futures.process._RemoteTraceback:
"""
Traceback (most recent call last):
  File "/usr/lib/python3.12/multiprocessing/queues.py", line 264, in _feed
    obj = _ForkingPickler.dumps(obj)
          ^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/lib/python3.12/multiprocessing/reduction.py", line 51, in dumps
    cls(buf, protocol).dump(obj)
_pickle.PicklingError: Can't pickle <function gsm8k_reward_fn at 0x7f82ee9e5d00>: attribute lookup gsm8k_reward_fn on __main__ failed
"""

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "/storage/openpsi/codes/wht125/project/github_areal/AReaL/areal/api/reward_api.py", line 122, in __call__
    reward = await asyncio.wait_for(
             ^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/lib/python3.12/asyncio/tasks.py", line 520, in wait_for
    return await fut
           ^^^^^^^^^
  File "/usr/lib/python3.12/multiprocessing/queues.py", line 264, in _feed
    obj = _ForkingPickler.dumps(obj)
          ^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/lib/python3.12/multiprocessing/reduction.py", line 51, in dumps
    cls(buf, protocol).dump(obj)
_pickle.PicklingError: Can't pickle <function gsm8k_reward_fn at 0x7f82ee9e5d00>: attribute lookup gsm8k_reward_fn on __main__ failed

If the reward_fn is defined in the main program, the error above occurs because, in single-controller mode, the train engine runs in a different process from the controller. After the reward_fn is serialized and sent to the remote process, reward_api.py in the train engine serializes it again and sends it to a subprocess. At that point the reward_fn definition cannot be found for serialization (pickle tries to look up reward_fn on __main__, but the subprocess's __main__ is not the controller's __main__). Therefore, the reward_fn should be defined in a separate file.
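A minimal sketch of the suggested fix; the module name my_rewards.py, the function signature, and the placeholder scoring logic are assumptions for illustration, not AReaL's actual reward API.

# my_rewards.py — the reward function lives at module level in an importable
# file, so pickle resolves it by qualified name (my_rewards.gsm8k_reward_fn)
# instead of looking it up on __main__ inside the worker subprocess.
def gsm8k_reward_fn(prompt: str, completion: str, answer: str, **kwargs) -> float:
    # placeholder scoring logic for illustration only
    return 1.0 if str(answer) in completion else 0.0

# In the controller/training script:
# from my_rewards import gsm8k_reward_fn
# ...and pass gsm8k_reward_fn to the workflow instead of defining it in __main__.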

config.experiment_name, config.trial_name, config.cluster.fileroot
)

def create_workers(self, worker_role, config, *args, **kwargs) -> None:

Collaborator:

The base method's signature uses worker_key instead of worker_role.

仲青 added 2 commits October 11, 2025 11:46