[WIP] single controller: add rollout controller #415
Conversation
Summary of Changes

Hello @dingzhiqiang, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed. This pull request lays crucial groundwork for enabling scalable and fault-tolerant distributed machine learning workflows. It introduces dedicated controllers for managing both distributed model rollouts and training, alongside a comprehensive refactoring of the underlying API definitions. These changes streamline the orchestration of distributed tasks, improve inter-worker communication robustness, and provide greater flexibility in deploying distributed components, ultimately paving the way for more complex and efficient distributed ML systems.
Code Review

This pull request introduces new `DistributedRolloutController` and `DistributedTrainController` classes, along with significant refactoring of the controller, scheduler, and engine APIs. The changes aim to establish a more robust single-controller architecture. While the overall direction is good, I've identified several critical issues, primarily concerning violations of the Liskov Substitution Principle in the `DistributedRolloutController`, which break the established API contracts. Additionally, there are concerns regarding resource management, code quality (e.g., code duplication, magic strings, foreign-language comments), and robustness in error handling. Addressing these points will be crucial for the stability and maintainability of the new controller framework.
```python
def wait(self, counts: List[int], timeout: float | None = None) -> DistributedBatch:
    assert len(counts) == len(self.dp_head_workers)
    results = self.custom_function_call("wait", counts, timeout)
    return DistributedBatch.concat(results)
```
The signature of this `wait` method, `(self, counts: List[int], ...)`, is incompatible with the base class `RolloutController.wait`, which has the signature `(self, count: int, ...)`. This violates the Liskov Substitution Principle. Subclass methods should have signatures compatible with their parent classes to ensure polymorphism works as expected.
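For illustration, a signature-compatible override could accept the aggregate `count` and derive the per-worker counts internally. This is a minimal sketch assuming the base contract is `wait(self, count: int, timeout: float | None = None) -> DistributedBatch`; the even-split policy is a placeholder assumption:

```python
def wait(self, count: int, timeout: float | None = None) -> DistributedBatch:
    # Keep the base-class signature; derive per-worker counts internally
    # instead of pushing that burden onto callers.
    n = len(self.dp_head_workers)
    counts = [count // n + (1 if i < count % n else 0) for i in range(n)]
    results = self.custom_function_call("wait", counts, timeout)
    return DistributedBatch.concat(results)
```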
```python
def update_weights(self, meta: WeightUpdateMeta) -> None:
    """Update weights in the inference engine."""
    self.custom_function_call("update_weights", None, meta)
    return None
```
This implementation of `update_weights` violates the Liskov Substitution Principle. The base class `RolloutController.update_weights` is defined as a non-blocking method that returns a `Future`, but this implementation is blocking (due to `rpc_call`) and returns `None`. This breaks the API contract and can lead to deadlocks or unexpected behavior in client code that expects an asynchronous operation.
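A minimal sketch of a contract-conforming version, assuming the class owns a background executor (the `self._rpc_executor` attribute is a hypothetical addition, e.g. `ThreadPoolExecutor(max_workers=1)` created in `__init__`):

```python
from concurrent.futures import Future, ThreadPoolExecutor

def update_weights(self, meta: WeightUpdateMeta) -> Future:
    """Update weights in the inference engine without blocking the caller."""
    # Offload the blocking fan-out RPC to a worker thread so the method
    # returns a Future immediately, as the base class promises.
    return self._rpc_executor.submit(
        self.custom_function_call, "update_weights", None, meta
    )
```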
```python
def prepare_batch(self, data: DistributedBatch, workflow: RolloutWorkflow) -> None:
    """Asynchronously submit a request to the inference engine. Exits immediately."""
    batches = data.chunk(self.alloc_mode.gen.dp_size)
    self.custom_function_call("prepare_batch", batches, workflow)
    return None
```
This `prepare_batch` method returns `None`, whereas the base class `RolloutController.prepare_batch` is type-hinted to return a `DistributedBatch`. This violates the Liskov Substitution Principle and breaks the API contract. The implementation should be updated to conform to the base class signature and behavior.
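A sketch of a conforming implementation, assuming `custom_function_call` returns the per-worker results in order (as the `wait` implementation above implies):

```python
def prepare_batch(
    self, data: DistributedBatch, workflow: RolloutWorkflow
) -> DistributedBatch:
    """Submit a request to the inference engine and return the resulting
    batch, matching the base-class return type."""
    batches = data.chunk(self.alloc_mode.gen.dp_size)
    results = self.custom_function_call("prepare_batch", batches, workflow)
    return DistributedBatch.concat(results)
```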
```python
if exit_on_exception:
    logger.info("Exiting due to exception in future.")
    os.kill(os.getpid(), signal.SIGTERM)
```
Using `os.kill(os.getpid(), signal.SIGTERM)` to handle an exception is a very aggressive approach that can be dangerous. It terminates the entire process abruptly, preventing any cleanup code (in `finally` blocks or `atexit` handlers) from running. This can lead to corrupted state or resource leaks. A better approach is to re-raise the exception and let the caller decide on the appropriate action, which might include a graceful shutdown:
```python
if exit_on_exception:
    logger.info("Exiting due to exception in future.")
    raise e
```
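With the exception re-raised, a caller that owns resources can still clean up before exiting. A hypothetical call site (the `executor` object and its `shutdown` call are assumptions, not part of this PR):

```python
try:
    results = wait_future_ordered(futures, exit_on_exception=True)
except Exception:
    logger.exception("A worker future failed; cleaning up before exit.")
    executor.shutdown(wait=False)  # hypothetical resource cleanup
    raise
```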
```python
    except TimeoutError:
        pass
else:
    self.data_list_index = 0
```
The instance attribute `self.data_list_index` is initialized here for the first time. It's a best practice to declare all instance attributes in the `__init__` method of the class (`WorkflowExecutor`). This prevents potential `AttributeError` exceptions if other methods are called before this one, and it improves code readability by providing a single place to see all attributes of an object.
```diff
- self.data_list_index = 0
+ # This should be initialized in WorkflowExecutor.__init__
+ if not hasattr(self, "data_list_index"):
+     self.data_list_index = 0
```
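Alternatively, declaring the attribute in the constructor removes the need for the `hasattr` guard entirely. A minimal sketch, with the rest of `WorkflowExecutor.__init__` elided (the `*args, **kwargs` signature is a placeholder assumption):

```python
class WorkflowExecutor:
    def __init__(self, *args, **kwargs):
        ...
        # Read cursor into List-style dataloaders; declared here so any
        # method can safely reference it before prepare_batch runs.
        self.data_list_index = 0
```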
```python
def custom_function_call(self, method: str, batches, *args, **kwargs):
    return rpc_call(self.scheduler, self.dp_head_workers, method, batches, args, kwargs)
```
Using magic strings like `"update_weights"` and `"prepare_batch"` for RPC method names is brittle. Typos or changes in the remote method names will only be caught at runtime. It would be more robust to define these method names as constants in a shared API module; this allows static analysis and reduces the risk of runtime errors.
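A sketch of such a shared module; the module name and members are illustrative assumptions (the `str` mixin keeps existing call sites that expect plain strings working):

```python
# rpc_methods.py -- hypothetical shared API module
import enum

class RpcMethod(str, enum.Enum):
    """RPC method names shared by controllers and remote workers."""
    WAIT = "wait"
    UPDATE_WEIGHTS = "update_weights"
    PREPARE_BATCH = "prepare_batch"

# Call sites then reference the constant instead of a bare string, e.g.:
#     self.custom_function_call(RpcMethod.UPDATE_WEIGHTS.value, None, meta)
```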
```python
    results = wait_future_ordered(futures, exit_on_exception=True)
except Exception as e:
    raise RuntimeError(f"{method} failed, error: {e}")
```
When an exception occurs, it is wrapped in a `RuntimeError` without preserving the original exception's context. This can make debugging more difficult, as the original traceback is lost. It's better to use `raise RuntimeError(...) from e` to chain the exceptions and preserve the full context:
```diff
      results = wait_future_ordered(futures, exit_on_exception=True)
  except Exception as e:
-     raise RuntimeError(f"{method} failed, error: {e}")
+     raise RuntimeError(f"{method} failed, error: {e}") from e
```
```python
    return self.workflow_executor.resume()


def get_scheduling_config(self) -> List[Scheduling]:
    # 部署 launcher/sglang_server.py, local_scheduler 注入一个ENGINE_PORTS的端口环境变量,里面有两个端口
```
This comment is in Chinese. For consistency and maintainability of the codebase, please write all comments in English.
```diff
- # 部署 launcher/sglang_server.py, local_scheduler 注入一个ENGINE_PORTS的端口环境变量,里面有两个端口
+ # Deploy launcher/sglang_server.py; local_scheduler injects an ENGINE_PORTS environment variable that contains two ports
```
```python
Returns
-------
Scheduling
    The scheduling configuration for the engine
```
```python
    # Handle the StatefulDataLoader type -- keep the original logic unchanged
    if not hasattr(self, "data_generator"):
        self.data_generator = cycle_dataloader(dataloader)
    assert dataloader.batch_size is not None
    batch_size = dataloader.batch_size

    while True:
        # Submit at least two batches to allow maximum overlap
        if (
            self.get_capacity() + batch_size > 0
            and self.input_queue.qsize() + batch_size
            < self.input_queue.maxsize
        ):
            data = next(self.data_generator)
            for item in data:
                self.submit(
                    item,
                    workflow=workflow,
                    workflow_builder=workflow_builder,
                    should_accept=should_accept,
                )
        try:
            return self.wait(batch_size, timeout=1)
        except TimeoutError:
            pass
else:
    self.data_list_index = 0

    # For the List type, use a fixed batch_size of 1
    batch_size = 1

    while True:
        # Submit at least two batches to allow maximum overlap
        if (
            self.get_capacity() + batch_size > 0
            and self.input_queue.qsize() + batch_size
            < self.input_queue.maxsize
        ):
            # Fetch items from the List, with wrap-around for cyclic access
            if self.data_list_index >= len(dataloader):
                self.data_list_index = 0  # wrap around to the start
```