[WIP] single controller: add rollout controller #415
base: main
Changes from all commits: 21d5385, 7237882, 9ae69d2, 076c3ba, 3159d39

```diff
@@ -330,14 +330,11 @@ async def _rollout_thread_async(self):
         try:
             while not self.exiting.is_set():
                 # Check capacity
-                capacity = self.get_capacity()
+                # capacity = self.get_capacity()
+                # self.logger.info(f"Current rollout capacity: {capacity}")
                 # Create new rollout task
                 self.lock.acquire()
-                while (
-                    capacity > 0
-                    and not self.paused.is_set()
-                    and self.input_queue.qsize() > 0
-                ):
+                while not self.paused.is_set() and self.input_queue.qsize() > 0:
                     x = self.input_queue.get_nowait()
                     x: _RolloutTaskInput
                     self.logger.debug(f"Get data from puller: {x.data}")
```
```diff
@@ -357,7 +354,7 @@ async def _rollout_thread_async(self):
                         f"running: {self.rollout_stat.running}, "
                         f"accepted: {self.rollout_stat.accepted}."
                     )
-                    capacity -= 1
+                    # capacity -= 1
                     rid += 1
                 tasks = [x.task for x in rollout_tasks.values()]
                 self.lock.release()
```
```diff
@@ -524,7 +521,7 @@ def rollout_batch(
 
     def prepare_batch(
         self,
-        dataloader: StatefulDataLoader,
+        dataloader: StatefulDataLoader | List[Dict[str, Any]],
         workflow: "RolloutWorkflow" | None = None,
         workflow_builder: Callable | None = None,
         should_accept: Callable | None = None,
```
```diff
@@ -533,28 +530,62 @@ def prepare_batch(
 
         See :meth:`~areal.api.engine_api.InferenceEngine.prepare_batch` for detailed documentation.
         """
-        if not hasattr(self, "data_generator"):
-            self.data_generator = cycle_dataloader(dataloader)
-        assert dataloader.batch_size is not None
-        while True:
-            # Submit at least two batches to allow maximum overlap
-            if (
-                self.get_capacity() + dataloader.batch_size > 0
-                and self.input_queue.qsize() + dataloader.batch_size
-                < self.input_queue.maxsize
-            ):
-                data = next(self.data_generator)
-                for item in data:
-                    self.submit(
-                        item,
-                        workflow=workflow,
-                        workflow_builder=workflow_builder,
-                        should_accept=should_accept,
-                    )
-            try:
-                return self.wait(dataloader.batch_size, timeout=1)
-            except TimeoutError:
-                pass
+        if isinstance(dataloader, StatefulDataLoader):
+            # Handle the StatefulDataLoader case - keep the original logic unchanged
+            if not hasattr(self, "data_generator"):
+                self.data_generator = cycle_dataloader(dataloader)
+            assert dataloader.batch_size is not None
+            batch_size = dataloader.batch_size
+
+            while True:
+                # Submit at least two batches to allow maximum overlap
+                if (
+                    self.get_capacity() + batch_size > 0
+                    and self.input_queue.qsize() + batch_size
+                    < self.input_queue.maxsize
+                ):
+                    data = next(self.data_generator)
+                    for item in data:
+                        self.submit(
+                            item,
+                            workflow=workflow,
+                            workflow_builder=workflow_builder,
+                            should_accept=should_accept,
+                        )
+                try:
+                    return self.wait(batch_size, timeout=1)
+                except TimeoutError:
+                    pass
+        else:
+            self.data_list_index = 0
```
Review comment (on `self.data_list_index = 0`): The instance attribute …

Suggested change: …
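The comment above is truncated, so the exact concern is partly an assumption: it appears to flag that `data_list_index` is introduced as a new instance attribute inside `prepare_batch` rather than in the constructor. A minimal sketch of that kind of fix follows; the class and constructor shown are placeholders, not the PR's code.

```python
class EngineSketch:
    """Placeholder class; only illustrates where the list cursor could be created."""

    def __init__(self) -> None:
        # Cursor into a plain-list `dataloader`; creating it here means
        # prepare_batch() never has to assign a brand-new attribute on the fly.
        self.data_list_index = 0
```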
```diff
+
+            # For the List case, use a fixed batch_size of 1
+            batch_size = 1
+
+            while True:
+                # Submit at least two batches to allow maximum overlap
+                if (
+                    self.get_capacity() + batch_size > 0
+                    and self.input_queue.qsize() + batch_size
+                    < self.input_queue.maxsize
+                ):
+                    # Fetch data from the list, supporting cyclic access
+                    if self.data_list_index >= len(dataloader):
+                        self.data_list_index = 0  # wrap around
+
+                    item = dataloader[self.data_list_index]
+                    self.data_list_index += 1
+
+                    self.submit(
+                        item,
+                        workflow=workflow,
+                        workflow_builder=workflow_builder,
+                        should_accept=should_accept,
+                    )
+                try:
+                    return self.wait(batch_size, timeout=1)
+                except TimeoutError:
+                    pass
```
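For context, a hedged usage sketch of the new list-based path. The engine and workflow objects are assumed to exist already and the sample fields are illustrative; only the `prepare_batch(dataloader, workflow=...)` call shape comes from this diff.

```python
from typing import Any, Dict, List


def run_one_batch(engine, workflow) -> Any:
    """Illustrative driver for the list path of prepare_batch() (not PR code)."""
    samples: List[Dict[str, Any]] = [
        {"prompt": "What is 2 + 2?"},
        {"prompt": "Name a prime number greater than 10."},
    ]
    # With a plain list, the patched prepare_batch() uses batch_size = 1 and
    # cycles through `samples` until wait() returns a completed batch.
    return engine.prepare_batch(samples, workflow=workflow)
```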
Review comment (on lines +533 to +588): There is significant code duplication between the two branches of `prepare_batch`.
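One way the duplication could be factored out, sketched under the assumption that the two branches differ only in batch size and in how the next items are produced; the helper name below is hypothetical.

```python
# Hypothetical helper method on the engine class (name is illustrative): the
# capacity check, submission loop, and wait/retry logic are identical in both
# branches, so they can be written once and parameterized by an item source.
def _drive_rollout(self, get_items, batch_size, workflow, workflow_builder, should_accept):
    while True:
        # Submit at least two batches to allow maximum overlap.
        if (
            self.get_capacity() + batch_size > 0
            and self.input_queue.qsize() + batch_size < self.input_queue.maxsize
        ):
            for item in get_items():
                self.submit(
                    item,
                    workflow=workflow,
                    workflow_builder=workflow_builder,
                    should_accept=should_accept,
                )
        try:
            return self.wait(batch_size, timeout=1)
        except TimeoutError:
            pass
```

`prepare_batch` would then only pick `batch_size` and a `get_items` callable per source type and delegate to this shared helper.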
```diff
 
     def pause(self):
         """Pause request submission for async rollout.
```
Review comment: The docstring indicates that this method returns `Scheduling`, but the function's type hint specifies `List[Scheduling]`. This inconsistency can be misleading for developers using this API. Please update the docstring to match the return type hint.
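A minimal sketch of the requested fix; the method name and summary line are placeholders, and only the `List[Scheduling]` annotation comes from the comment above.

```python
from typing import List


def get_scheduling(self) -> List["Scheduling"]:  # placeholder name, not PR code
    """Return the scheduling specs for this engine.

    Returns:
        List[Scheduling]: one entry per worker, matching the return type hint
        (the docstring previously described a single ``Scheduling``).
    """
    ...
```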