[feature] support local scheduler #427
Conversation
Summary of Changes

Hello @yulangz, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed. This pull request significantly enhances the system's flexibility by integrating a local scheduler, allowing complex reinforcement learning workflows to run on a single machine. The changes involve refactoring of engine initialization logic to accommodate local execution, alongside improvements to RPC data handling for better performance and stability. A new, detailed test suite ensures the robustness of this local execution capability.
Code Review
This pull request introduces a local scheduler, which is a significant feature for enabling local development and testing. The refactoring to move non-serializable members into initialize methods is a great improvement for the architecture, making engines more flexible. The optimization in the RPC server to move data to the CPU before transmission is also a good performance enhancement. My review focuses on improving the robustness and maintainability of the new local scheduler implementation, cleaning up some leftover code from the refactoring, and ensuring API consistency. I've pointed out a potentially fragile command construction in the local launcher, a problematic use of __del__, and some missing type hints and commented-out code. Overall, this is a solid contribution.
def __del__(self):
    self.wait()
The __del__ method calls self.wait(). This can be problematic: __del__ is called during garbage collection, and its execution context is unpredictable. If self.wait() blocks for a long time (e.g., waiting for long-running subprocesses), it can delay or interfere with program shutdown. Also, if an exception occurs in wait(), it will be ignored. It's generally recommended to avoid complex logic in __del__. A better pattern is to use a context manager (__enter__ and __exit__). If you must use __del__, consider calling self.stop_all() instead to terminate jobs, which is a more predictable cleanup operation.
Suggested change:
-    def __del__(self):
-        self.wait()
+    def __del__(self):
+        self.stop_all()
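A minimal sketch of the context-manager pattern suggested above, assuming the launcher exposes stop_all() as referenced in this thread (everything else here is illustrative, not code from the PR):

class LocalLauncher:
    def stop_all(self):
        # placeholder for the real job-termination logic in the PR
        ...

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        # Deterministic cleanup at a well-defined point instead of relying on GC timing.
        self.stop_all()
        return False

    def __del__(self):
        # Best-effort fallback only; stop_all() is quicker and more predictable than
        # blocking in wait() during interpreter shutdown.
        self.stop_all()

Callers could then write "with LocalLauncher(...) as launcher: ..." so jobs are stopped even if the block raises.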
c = (
    " ".join(str(k) + "=" + str(v) for k, v in env_vars.items())
    + " stdbuf -oL "
    + cmd[i]
)
The current method of constructing the command string by joining environment variables is not robust. If an environment variable's value contains spaces or other shell special characters, it will break the command. A safer approach is to use the env parameter of subprocess.Popen to pass environment variables. This avoids shell injection issues and handles special characters correctly. You would need to create a copy of os.environ, update it with env_vars, and pass it to Popen on line 140.
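A minimal sketch of that approach; the env_vars and command values below are stand-ins for the launcher's cmd[i] and env_vars, and log-file redirection is omitted:

import os
import subprocess

env_vars = {"AREAL_RECOVER_RUN": "0", "WORLD_SIZE": "2"}  # stand-in for the launcher's env_vars
command = "python -m areal.worker"                        # stand-in for cmd[i]

env = os.environ.copy()
env.update({k: str(v) for k, v in env_vars.items()})  # values pass through intact, no shell quoting needed

proc = subprocess.Popen(
    "stdbuf -oL " + command,
    shell=True,
    env=env,  # environment travels via Popen, not the command string
)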
# capacity = self.get_capacity()
# self.logger.info(f"Current rollout capacity: {capacity}")
# Create new rollout task
self.lock.acquire()
-while (
-    capacity > 0
-    and not self.paused.is_set()
-    and self.input_queue.qsize() > 0
-):
+while not self.paused.is_set() and self.input_queue.qsize() > 0:
def compute_advantages(self, *args, **kwargs):
    return self.actor.compute_advantages(*args, **kwargs)
The return type hint for compute_advantages is missing. Based on the wrapped self.actor.compute_advantages method, it should be -> Dict[str, Any]. Adding type hints improves code clarity and allows static analysis tools to catch potential bugs.
Suggested change:
-    def compute_advantages(self, *args, **kwargs):
-        return self.actor.compute_advantages(*args, **kwargs)
+    def compute_advantages(self, *args, **kwargs) -> Dict[str, Any]:
+        return self.actor.compute_advantages(*args, **kwargs)
env_vars=dict(
    **get_env_vars(
        config.cluster.cluster_name,
        # config.launcher.worker_env_vars,
        config.cluster.cluster_name,
        config.launcher.trainer_env_vars,
    ),
    # AREAL_LLM_SERVER_ADDRS=",".join(server_addrs), # not need?
    self, worker_id: str, method: str, max_retries: int, *args, **kwargs
) -> Any:
The default value for max_retries has been removed, making it a required argument. This could be an unintentional breaking change for callers of this method. If the intention is to always require it, this is fine, but if a default is desirable, consider re-adding it (e.g., max_retries: int = 3). The only call site in this PR (areal/scheduler/local.py:407) hardcodes it to 3.
Suggested change:
-        self, worker_id: str, method: str, max_retries: int, *args, **kwargs
-    ) -> Any:
+        self, worker_id: str, method: str, max_retries: int = 3, *args, **kwargs
+    ) -> Any:
# NOTE: We must use HTTPServer rather than ThreadingHTTPServer here, since the rank and device info
# of pytorch is thread level, if use ThreadingHTTPServer, the device set by create_engine thread
# will not be seen by call_engine thread.
# server = ThreadingHTTPServer(("0.0.0.0", port), EngineRPCServer)
server = HTTPServer(("0.0.0.0", port), EngineRPCServer)
The comment explaining the switch from ThreadingHTTPServer to HTTPServer is excellent. To improve code clarity, the commented-out ThreadingHTTPServer line could be removed, as the explanation is sufficient.
Suggested change:
 # NOTE: We must use HTTPServer rather than ThreadingHTTPServer here, since the rank and device info
 # of pytorch is thread level, if use ThreadingHTTPServer, the device set by create_engine thread
 # will not be seen by call_engine thread.
-# server = ThreadingHTTPServer(("0.0.0.0", port), EngineRPCServer)
 server = HTTPServer(("0.0.0.0", port), EngineRPCServer)
print("[wht debug] rollout workers:", rollout_workers) | ||
print("[wht debug] actor workers:", actor_workers) |
This test file contains many print statements with [wht debug] prefixes. While useful for debugging, it's better practice to use the logging module. This allows for configurable log levels and makes it easier to manage log output, especially in CI/CD environments. Consider replacing these prints with logger.debug() or logger.info().
Why do we need to change this file?
Under my configuration, it will hang due to insufficient capacity.
The calculation of capacity seems to be done only once per round of rollout. Within one round, even if some queries have completed their rollouts, the capacity does not increase. The capacity calculation logic seems to be problematic here.
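One possible shape for the fix, as a sketch only (the get_capacity(), paused, input_queue, and lock members are taken from the hunk above; the method name and loop body are assumptions): refresh the capacity on every iteration so slots freed by completed rollouts are reused within the same round.

def _schedule_rollouts(self):
    # Sketch: hold the lock only while scheduling, and re-check capacity each iteration.
    with self.lock:
        while not self.paused.is_set() and self.input_queue.qsize() > 0:
            capacity = self.get_capacity()  # re-read so completed rollouts free slots mid-round
            if capacity <= 0:
                break
            # ... dequeue one prompt and submit a rollout task ...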
areal/engine/fsdp_engine.py (Outdated)
"torch", "2.4.0" | ||
), f"areal only supports FSDP2, which requires torch>=2.4.0" | ||
|
||
self.world_size = int(os.environ["WORLD_SIZE"]) |
Should be initialized in create_process_group
move it to create_process_group
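A hedged sketch of what both comments ask for, moving the environment read out of __init__ and into create_process_group; the torch.distributed call, the RANK handling, and the class stub are assumptions, not code from the PR:

import os
import torch.distributed as dist

class FSDPEngine:  # illustrative stub only
    def create_process_group(self):
        # Resolve topology from the environment here rather than in __init__.
        self.world_size = int(os.environ["WORLD_SIZE"])
        self.rank = int(os.environ["RANK"])  # assumed companion to WORLD_SIZE
        if not dist.is_initialized():
            dist.init_process_group(
                backend="nccl", rank=self.rank, world_size=self.world_size
            )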
pass

class LocalLauncher:
Why do we need this? Can't we just import from areal.launcher.local?
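If the existing implementation already fits, the duplicate definition could be replaced with a plain import, as the comment suggests (sketch):

from areal.launcher.local import LocalLauncher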
AREAL_LLM_SERVER_ADDRS=server_addrs[
    i % alloc_mode.gen.dp_size
],
AREAL_RECOVER_RUN=str(int(is_recover_run)),
This environment variable is not used.
This file should be changed such that we can run the test with pytest areal/scheduler/test_local.py.
Besides, move this file to areal/tests/.
We don't need this config if we don't have a corresponding training script.
It is temporarily left here as an example and will be deleted before merging.
Is this reward function used anywhere?
concurrent.futures.process._RemoteTraceback:
"""
Traceback (most recent call last):
File "/usr/lib/python3.12/multiprocessing/queues.py", line 264, in _feed
obj = _ForkingPickler.dumps(obj)
^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/usr/lib/python3.12/multiprocessing/reduction.py", line 51, in dumps
cls(buf, protocol).dump(obj)
_pickle.PicklingError: Can't pickle <function gsm8k_reward_fn at 0x7f82ee9e5d00>: attribute lookup gsm8k_reward_fn on __main__ failed
"""
The above exception was the direct cause of the following exception:
Traceback (most recent call last):
File "/storage/openpsi/codes/wht125/project/github_areal/AReaL/areal/api/reward_api.py", line 122, in __call__
reward = await asyncio.wait_for(
^^^^^^^^^^^^^^^^^^^^^^^
File "/usr/lib/python3.12/asyncio/tasks.py", line 520, in wait_for
return await fut
^^^^^^^^^
File "/usr/lib/python3.12/multiprocessing/queues.py", line 264, in _feed
obj = _ForkingPickler.dumps(obj)
^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/usr/lib/python3.12/multiprocessing/reduction.py", line 51, in dumps
cls(buf, protocol).dump(obj)
_pickle.PicklingError: Can't pickle <function gsm8k_reward_fn at 0x7f82ee9e5d00>: attribute lookup gsm8k_reward_fn on __main__ failed
If this reward_fn is placed in the main program, the above error occurs because, in single-controller mode, the process where the train engine resides is different from the process where the controller resides. After the reward_fn is serialized and sent to the remote process, reward_api.py in the train engine serializes it again and sends it to a subprocess. At that point the reward_fn definition cannot be found for serialization (pickling tries to look it up on __main__, but the subprocess's __main__ is not the same as the controller's __main__). Therefore, the reward_fn should be defined in a separate, importable file (see the sketch below).
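A minimal sketch of that fix; the module name, function signature, and reward logic are illustrative, not taken from the PR:

# my_rewards.py -- a real importable module, not the controller's __main__
def gsm8k_reward_fn(prompt: str, completion: str, answer: str, **kwargs) -> float:
    # Illustrative reward: exact match on the final answer.
    return float(completion.strip() == answer.strip())

# In the training script / controller (separate file):
#     from my_rewards import gsm8k_reward_fn
# The function now pickles by reference as my_rewards.gsm8k_reward_fn, which both
# the engine process and its reward subprocess can resolve.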
    config.experiment_name, config.trial_name, config.cluster.fileroot
)

def create_workers(self, worker_role, config, *args, **kwargs) -> None:
The signature of the base method uses worker_key instead of worker_role.
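A sketch of the aligned signature, assuming only the parameter name needs to change:

# Align with the base-class parameter name; everything else unchanged.
def create_workers(self, worker_key, config, *args, **kwargs) -> None:
    ...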