Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
59 changes: 8 additions & 51 deletions areal/api/controller_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -315,7 +315,7 @@ def set_version(self, version: int):
"""
raise NotImplementedError()

def get_version(self) -> int:
def get_version(self) -> List[int]:
"""Get the current weight version in the training engine.

Returns
Expand Down Expand Up @@ -359,7 +359,7 @@ def train_batch(
input_: DistributedBatch,
loss_fn: Callable[[torch.Tensor, Dict[str, Any]], torch.Tensor],
loss_weight_fn: Callable[[Dict[str, Any]], torch.Tensor],
) -> Dict[str, float]:
) -> List[Dict[str, float]]:
"""Update the model with a batch of data and a loss function.

Note
Expand All @@ -382,7 +382,7 @@ def train_batch(

Returns
-------
Dict[str, float]
List[Dict[str, float]]
Scalar statistics after training, e.g., the current learning rate,
gradient norm, etc.
"""
Expand All @@ -394,7 +394,7 @@ def eval_batch(
input_: DistributedBatch,
loss_fn: Callable[[torch.Tensor, Dict[str, Any]], torch.Tensor],
loss_weight_fn: Callable[[Dict[str, Any]], torch.Tensor],
) -> torch.Tensor | None:
) -> List[torch.Tensor]:
"""Evaluate the model using the forward pass and loss function.

Note
Expand Down Expand Up @@ -458,7 +458,6 @@ def forward(
"""
raise NotImplementedError()


class RolloutController(abc.ABC):
"""A centralized controller that manages multiple distributed InferenceEngine workers for rollout generation.

Expand Down Expand Up @@ -508,21 +507,6 @@ def destroy(self):
"""Destroy the engine and release GPU memory for the local inference engine."""
raise NotImplementedError()

async def agenerate(self, req: ModelRequest) -> ModelResponse:
"""Asynchronously generate a response for the given request.

Parameters
----------
req : ModelRequest
The model request containing input data and generation parameters

Returns
-------
ModelResponse
The generated response from the model
"""
raise NotImplementedError()

def update_weights(self, meta: WeightUpdateMeta) -> Future:
"""Update weights in the inference engine in a non-blocking manner.

Expand Down Expand Up @@ -571,7 +555,7 @@ def get_version(self) -> int:

def submit(
self,
data: Dict[str, Any],
data: DistributedBatch,
workflow: Optional["RolloutWorkflow"] = None,
workflow_builder: Optional[Callable] = None,
should_accept: Callable | None = None,
Expand Down Expand Up @@ -623,7 +607,7 @@ def wait(self, count: int, timeout: float | None = None) -> DistributedBatch:

def rollout_batch(
self,
data: List[Dict[str, Any]],
data: DistributedBatch,
workflow: Optional["RolloutWorkflow"] = None,
workflow_builder: Optional[Callable] = None,
should_accept: Callable | None = None,
Expand Down Expand Up @@ -652,7 +636,7 @@ def rollout_batch(

def prepare_batch(
self,
dataloader: StatefulDataLoader,
dataloader: DistributedBatch,
workflow: Optional["RolloutWorkflow"] = None,
workflow_builder: Optional[Callable] = None,
should_accept: Callable | None = None,
Expand Down Expand Up @@ -688,31 +672,4 @@ def pause(self):

def resume(self):
"""Resume request submission for async rollout."""
raise NotImplementedError()

def register_callback_to_all_worker(
self, method: str, callback: Callable, **kwargs
):
"""Register a callback function for the specified method across all workers.

Partial rollout API. After successful registration, the controller will poll
and call the specified method in a background thread. When the return value
is obtained, it will be used as a parameter to call the `callback` function.

Parameters
----------
method : str
The name of the method to register the callback for
callback : Callable
The callback function to be called with the method's return value
**kwargs
Additional keyword arguments for the callback registration
"""
raise NotImplementedError()

def abort_all_requests(self) -> None:
"""Abort all ongoing requests in the inference engine.

Partial rollout API for canceling all queued and in-progress requests.
"""
raise NotImplementedError()
raise NotImplementedError()
16 changes: 15 additions & 1 deletion areal/api/engine_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,8 @@ class Scheduling:
cpu: int
gpu: int
mem: int
port_count: int
cmd: str | None = None
nodelist: str | None = None
exclude: str | None = None
partition: str | None = None
Expand Down Expand Up @@ -138,7 +140,7 @@ def parallelism_group(self) -> dist.ProcessGroup:
"""
raise NotImplementedError()

def get_scheduling_config(self) -> Scheduling:
def get_scheduling_config(self) -> List[Scheduling]:
"""Get the scheduling configuration for the engine.

This includes configuration such as container image, CPU/GPU/memory size.
Expand Down Expand Up @@ -553,3 +555,15 @@ def pause(self):
def resume(self):
"""Resume request submission for async rollout."""
raise NotImplementedError()

def get_scheduling_config(self) -> List[Scheduling]:
"""Get the scheduling configuration for the engine.

This includes configuration such as container image, CPU/GPU/memory size.

Returns
-------
List[Scheduling]
The scheduling configurations for the engine
Comment on lines +564 to +567
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

The docstring indicates that this method returns Scheduling, but the function's type hint specifies List[Scheduling]. This inconsistency can be misleading for developers using this API. Please update the docstring to match the return type hint.

"""
raise NotImplementedError()
33 changes: 13 additions & 20 deletions areal/api/scheduler_api.py
Original file line number Diff line number Diff line change
@@ -1,47 +1,40 @@
import abc
from dataclasses import dataclass, field
from typing import Dict, List
from typing import List, Literal

from areal.api.engine_api import Scheduling


@dataclass
class Worker:
id: str
ip: str
ports: List[str] = field(default_factory=list)


@dataclass
class ContainerSpec:
cpu: int = 0
gpu: int = 0
mem: int = 0
container_image: str = ""
cmd: str = ""
env_vars: Dict[str, str] = field(default_factory=dict)
port_count: int = 2
serve_port: str
extra_ports: List[str] = field(default_factory=list)


@dataclass
class ScheduleStrategy:
type: str = ""
uid: str = ""
type: Literal["colocation", "separation", ""] = ""
target: str = ""


@dataclass
class SchedulingConfig:
class Job:
replicas: int = 0
specs: List[ContainerSpec] = field(default_factory=list)
tasks: List[Scheduling] = field(default_factory=list)
schedule_strategy: ScheduleStrategy | None = None
role: str = ""


class Scheduler(abc.ABC):
def create_workers(self, worker_key, scheduler_config, *args, **kwargs) -> str:
def create_workers(self, worker_key, scheduler_config, *args, **kwargs) -> None:
"""
Start workers, return job id
Start workers
"""
raise NotImplementedError()

def get_workers(self, worker_key, timeout=None) -> List[Worker]:
def get_workers(self, role: str, timeout=None) -> List[Worker]:
"""
Wait and return worker list, including scheduling results such as ip and engine ports
(worker id, ip, ports)
Expand Down
79 changes: 55 additions & 24 deletions areal/api/workflow_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -330,14 +330,11 @@ async def _rollout_thread_async(self):
try:
while not self.exiting.is_set():
# Check capacity
capacity = self.get_capacity()
# capacity = self.get_capacity()
# self.logger.info(f"Current rollout capacity: {capacity}")
# Create new rollout task
self.lock.acquire()
while (
capacity > 0
and not self.paused.is_set()
and self.input_queue.qsize() > 0
):
while not self.paused.is_set() and self.input_queue.qsize() > 0:
x = self.input_queue.get_nowait()
x: _RolloutTaskInput
self.logger.debug(f"Get data from puller: {x.data}")
Expand All @@ -357,7 +354,7 @@ async def _rollout_thread_async(self):
f"running: {self.rollout_stat.running}, "
f"accepted: {self.rollout_stat.accepted}."
)
capacity -= 1
# capacity -= 1
rid += 1
tasks = [x.task for x in rollout_tasks.values()]
self.lock.release()
Expand Down Expand Up @@ -524,7 +521,7 @@ def rollout_batch(

def prepare_batch(
self,
dataloader: StatefulDataLoader,
dataloader: StatefulDataLoader | List[Dict[str, Any]],
workflow: "RolloutWorkflow" | None = None,
workflow_builder: Callable | None = None,
should_accept: Callable | None = None,
Expand All @@ -533,28 +530,62 @@ def prepare_batch(

See :meth:`~areal.api.engine_api.InferenceEngine.prepare_batch` for detailed documentation.
"""
if not hasattr(self, "data_generator"):
self.data_generator = cycle_dataloader(dataloader)
assert dataloader.batch_size is not None
while True:
# Submit at least two batches to allow maximum overlap
if (
self.get_capacity() + dataloader.batch_size > 0
and self.input_queue.qsize() + dataloader.batch_size
< self.input_queue.maxsize
):
data = next(self.data_generator)
for item in data:
if isinstance(dataloader, StatefulDataLoader):
# Handle the StatefulDataLoader case - keep the original logic unchanged
if not hasattr(self, "data_generator"):
self.data_generator = cycle_dataloader(dataloader)
assert dataloader.batch_size is not None
batch_size = dataloader.batch_size

while True:
# Submit at least two batches to allow maximum overlap
if (
self.get_capacity() + batch_size > 0
and self.input_queue.qsize() + batch_size
< self.input_queue.maxsize
):
data = next(self.data_generator)
for item in data:
self.submit(
item,
workflow=workflow,
workflow_builder=workflow_builder,
should_accept=should_accept,
)
try:
return self.wait(batch_size, timeout=1)
except TimeoutError:
pass
else:
self.data_list_index = 0
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

The instance attribute self.data_list_index is initialized here for the first time. It's a best practice to declare all instance attributes in the __init__ method of the class (WorkflowExecutor). This prevents potential AttributeError exceptions if other methods are called before this one and improves code readability by providing a single place to see all attributes of an object.

Suggested change
self.data_list_index = 0
# This should be initialized in WorkflowExecutor.__init__
if not hasattr(self, "data_list_index"):
self.data_list_index = 0


# For the List case, use a fixed batch_size=1
batch_size = 1

while True:
# Submit at least two batches to allow maximum overlap
if (
self.get_capacity() + batch_size > 0
and self.input_queue.qsize() + batch_size
< self.input_queue.maxsize
):
# Fetch data from the list, with support for cyclic access
if self.data_list_index >= len(dataloader):
self.data_list_index = 0 # wrap around (cyclic access)
Comment on lines +534 to +574
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

The code contains comments in Chinese. To ensure the codebase is accessible and maintainable for all contributors, please translate these comments into English.


item = dataloader[self.data_list_index]
self.data_list_index += 1

self.submit(
item,
workflow=workflow,
workflow_builder=workflow_builder,
should_accept=should_accept,
)
try:
return self.wait(dataloader.batch_size, timeout=1)
except TimeoutError:
pass
try:
return self.wait(batch_size, timeout=1)
except TimeoutError:
pass
Comment on lines +533 to +588
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

There is significant code duplication between the if isinstance(dataloader, StatefulDataLoader): block and the else: block. Both branches contain a while True: loop with very similar logic for submitting tasks and waiting for results. This duplication makes the code harder to maintain and read. Consider refactoring this method to extract the common logic into a helper function or by using a more abstract way to iterate over the different dataloader types.


def pause(self):
"""Pause request submission for async rollout.
Expand Down
Loading