10 changes: 9 additions & 1 deletion vllm/config.py
@@ -2247,6 +2247,10 @@ class SchedulerConfig:
"""Maximum length of a sequence (including prompt and generated text). This
is primarily set in `ModelConfig` and that value should be manually
duplicated here."""
max_waiting_queue_length: Optional[int] = None
"""Maximum number of sequences in the waiting queue. If None, no limit is
applied. If set, the scheduler will reject new requests when the queue
length exceeds this value."""

max_num_partial_prefills: int = 1
"""For chunked prefill, the maximum number of sequences that can be
@@ -2470,7 +2474,11 @@ def _verify_args(self) -> Self:
"max_num_batched_tokens and makes vLLM reject longer "
"sequences. Please increase max_num_batched_tokens or "
"decrease max_model_len.")

if (self.max_waiting_queue_length is not None
and self.max_waiting_queue_length <= 0):
raise ValueError(
"max_waiting_queue_length must be a positive integer. "
"Use None for an unlimited queue.")
if self.max_num_batched_tokens < self.max_num_seqs:
raise ValueError(
f"max_num_batched_tokens ({self.max_num_batched_tokens}) must "
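Reviewer note: the new `_verify_args` check treats `None` as "unlimited" and rejects any non-positive value. A minimal standalone sketch of that rule; the helper name below is hypothetical and not part of the PR:

```python
from typing import Optional


def validate_max_waiting_queue_length(value: Optional[int]) -> None:
    # Mirrors the check added to SchedulerConfig._verify_args():
    # None disables the limit; any configured value must be positive.
    if value is not None and value <= 0:
        raise ValueError(
            "max_waiting_queue_length must be a positive integer. "
            "Use None for an unlimited queue.")


validate_max_waiting_queue_length(None)  # OK: unlimited waiting queue
validate_max_waiting_queue_length(128)   # OK: queue bounded to 128 requests
validate_max_waiting_queue_length(0)     # raises ValueError
```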
7 changes: 7 additions & 0 deletions vllm/engine/arg_utils.py
@@ -437,6 +437,9 @@ class EngineArgs:
# DEPRECATED
enable_prompt_adapter: bool = False

max_waiting_queue_length: Optional[int] = (
SchedulerConfig.max_waiting_queue_length)

def __post_init__(self):
# support `EngineArgs(compilation_config={...})`
# without having to manually construct a
@@ -778,6 +781,9 @@ def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser:
title="SchedulerConfig",
description=SchedulerConfig.__doc__,
)
scheduler_group.add_argument(
"--max-waiting-queue-length",
**scheduler_kwargs["max_waiting_queue_length"])
scheduler_group.add_argument(
"--max-num-batched-tokens",
**scheduler_kwargs["max_num_batched_tokens"])
@@ -1215,6 +1221,7 @@ def create_engine_config(
max_num_batched_tokens=self.max_num_batched_tokens,
max_num_seqs=self.max_num_seqs,
max_model_len=model_config.max_model_len,
max_waiting_queue_length=self.max_waiting_queue_length,
cuda_graph_sizes=self.cuda_graph_sizes,
num_lookahead_slots=num_lookahead_slots,
delay_factor=self.scheduler_delay_factor,
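Reviewer note: with the new `EngineArgs` field and CLI flag, the limit can be set either as `--max-waiting-queue-length 256` on the command line or programmatically. A hedged sketch, assuming the usual `EngineArgs` entry point; the model id is a placeholder:

```python
from vllm.engine.arg_utils import EngineArgs

# Equivalent to passing --max-waiting-queue-length 256 on the CLI; the value
# is forwarded into SchedulerConfig by create_engine_config() as shown above.
args = EngineArgs(
    model="facebook/opt-125m",       # placeholder model id
    max_waiting_queue_length=256,    # bound the scheduler waiting queue
)
```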
8 changes: 7 additions & 1 deletion vllm/entrypoints/openai/api_server.py
@@ -630,7 +630,10 @@ async def cancel_responses(response_id: str, raw_request: Request):
},
HTTPStatus.INTERNAL_SERVER_ERROR.value: {
"model": ErrorResponse
}
},
HTTPStatus.SERVICE_UNAVAILABLE.value: {
"model": ErrorResponse
},
})
@with_cancellation
@load_aware_call
@@ -670,6 +673,9 @@ async def create_chat_completion(request: ChatCompletionRequest,
HTTPStatus.INTERNAL_SERVER_ERROR.value: {
"model": ErrorResponse
},
HTTPStatus.SERVICE_UNAVAILABLE.value: {
"model": ErrorResponse
},
})
@with_cancellation
@load_aware_call
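Reviewer note: the affected OpenAI-compatible endpoints can now return `503` with an `ErrorResponse` body when the waiting queue is full, so clients should treat that status as retryable. A hypothetical client-side sketch using `requests` with exponential backoff; the URL, model, and payload are placeholders, not part of the PR:

```python
import time

import requests


def post_with_backoff(url: str, payload: dict, retries: int = 3):
    # Retry only on 503 (waiting queue full); return any other response as-is.
    resp = None
    for attempt in range(retries):
        resp = requests.post(url, json=payload, timeout=60)
        if resp.status_code != 503:
            return resp
        time.sleep(2 ** attempt)  # back off: 1s, 2s, 4s, ...
    return resp


resp = post_with_backoff(
    "http://localhost:8000/v1/chat/completions",
    {
        "model": "facebook/opt-125m",
        "messages": [{"role": "user", "content": "hello"}],
    },
)
```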
8 changes: 3 additions & 5 deletions vllm/entrypoints/openai/serving_chat.py
@@ -283,8 +283,7 @@ async def create_chat_completion(
request, result_generator, request_id, model_name,
conversation, tokenizer, request_metadata)
except ValueError as e:
# TODO: Use a vllm-specific Validation Error
return self.create_error_response(str(e))
return self.create_error_response(e)

def get_chat_request_role(self, request: ChatCompletionRequest) -> str:
if request.add_generation_prompt:
@@ -929,8 +928,7 @@ async def chat_completion_stream_generator(

except Exception as e:
# TODO: Use a vllm-specific Validation Error
logger.exception("Error in chat completion stream generator.")
data = self.create_streaming_error_response(str(e))
data = self.create_streaming_error_response(e)
yield f"data: {data}\n\n"
# Send the final done message after all response.n are finished
yield "data: [DONE]\n\n"
@@ -956,7 +954,7 @@ async def chat_completion_full_generator(
return self.create_error_response("Client disconnected")
except ValueError as e:
# TODO: Use a vllm-specific Validation Error
return self.create_error_response(str(e))
return self.create_error_response(e)

assert final_res is not None

6 changes: 3 additions & 3 deletions vllm/entrypoints/openai/serving_completion.py
@@ -224,7 +224,7 @@ async def create_completion(
generators.append(generator)
except ValueError as e:
# TODO: Use a vllm-specific Validation Error
return self.create_error_response(str(e))
return self.create_error_response(e)

result_generator = merge_async_iterators(*generators)

@@ -288,7 +288,7 @@ async def create_completion(
return self.create_error_response("Client disconnected")
except ValueError as e:
# TODO: Use a vllm-specific Validation Error
return self.create_error_response(str(e))
return self.create_error_response(e)

# When user requests streaming but we don't stream, we still need to
# return a streaming response with a single event.
@@ -470,7 +470,7 @@ async def completion_stream_generator(

except Exception as e:
# TODO: Use a vllm-specific Validation Error
data = self.create_streaming_error_response(str(e))
data = self.create_streaming_error_response(e)
yield f"data: {data}\n\n"
yield "data: [DONE]\n\n"

17 changes: 14 additions & 3 deletions vllm/entrypoints/openai/serving_engine.py
@@ -75,6 +75,7 @@
from vllm.transformers_utils.tokenizer import AnyTokenizer, MistralTokenizer
from vllm.utils import (AsyncMicrobatchTokenizer, is_list_of,
merge_async_iterators, random_uuid)
from vllm.v1.engine.exceptions import SchedulerWaitingQueueFullError

logger = init_logger(__name__)

@@ -408,16 +409,26 @@ async def _collect_batch(

def create_error_response(
self,
message: str,
message: Union[str, Exception],
err_type: str = "BadRequestError",
status_code: HTTPStatus = HTTPStatus.BAD_REQUEST) -> ErrorResponse:
return ErrorResponse(message=message,
if isinstance(message, SchedulerWaitingQueueFullError):
return ErrorResponse(
message=str(message),
type="ServiceUnavailableError",
code=HTTPStatus.SERVICE_UNAVAILABLE.value,
)
elif isinstance(message, Exception):
message_str = str(message)
else:
message_str = message
return ErrorResponse(message=message_str,
type=err_type,
code=status_code.value)

def create_streaming_error_response(
self,
message: str,
message: Union[str, Exception],
err_type: str = "BadRequestError",
status_code: HTTPStatus = HTTPStatus.BAD_REQUEST) -> str:
json_str = json.dumps({
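Reviewer note: the branching added to `create_error_response` boils down to one type check. A self-contained mimic of that mapping; the exception class is redefined locally so the snippet runs on its own:

```python
from http import HTTPStatus


class SchedulerWaitingQueueFullError(ValueError):
    """Local stand-in for vllm.v1.engine.exceptions.SchedulerWaitingQueueFullError."""

    def __init__(self, request_id: str):
        super().__init__(request_id)
        self.request_id = request_id


def classify(exc_or_msg) -> tuple[str, int]:
    # Queue-full errors map to 503 ServiceUnavailableError; everything else
    # keeps the default 400 BadRequestError, matching the diff above.
    if isinstance(exc_or_msg, SchedulerWaitingQueueFullError):
        return "ServiceUnavailableError", HTTPStatus.SERVICE_UNAVAILABLE.value
    return "BadRequestError", HTTPStatus.BAD_REQUEST.value


print(classify(SchedulerWaitingQueueFullError("req-1")))  # ('ServiceUnavailableError', 503)
print(classify(ValueError("bad prompt")))                 # ('BadRequestError', 400)
```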
4 changes: 2 additions & 2 deletions vllm/entrypoints/openai/serving_responses.py
@@ -182,7 +182,7 @@ async def create_responses(
generators.append(generator)
except ValueError as e:
# TODO: Use a vllm-specific Validation Error
return self.create_error_response(str(e))
return self.create_error_response(e)

assert len(generators) == 1
result_generator, = generators
@@ -262,7 +262,7 @@ async def responses_full_generator(
return self.create_error_response("Client disconnected")
except ValueError as e:
# TODO: Use a vllm-specific Validation Error
return self.create_error_response(str(e))
return self.create_error_response(e)

assert final_res is not None
assert len(final_res.outputs) == 1
5 changes: 5 additions & 0 deletions vllm/v1/core/sched/scheduler.py
@@ -28,6 +28,7 @@
from vllm.v1.core.sched.utils import check_stop
from vllm.v1.engine import (EngineCoreEventType, EngineCoreOutput,
EngineCoreOutputs)
from vllm.v1.engine.exceptions import SchedulerWaitingQueueFullError
from vllm.v1.kv_cache_interface import KVCacheConfig
from vllm.v1.metrics.stats import SchedulerStats
from vllm.v1.outputs import ModelRunnerOutput
@@ -957,6 +958,10 @@ def get_request_counts(self) -> tuple[int, int]:
return len(self.running), len(self.waiting)

def add_request(self, request: Request) -> None:
if (self.scheduler_config.max_waiting_queue_length
and len(self.waiting)
>= self.scheduler_config.max_waiting_queue_length):
raise SchedulerWaitingQueueFullError(request_id=request.request_id)
self.waiting.add_request(request)
self.requests[request.request_id] = request
if self.log_stats:
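Reviewer note: the gate added to `Scheduler.add_request()` uses a truthiness check, so `None` (unlimited) skips it entirely, while a configured limit rejects the request before it enters the waiting queue. A minimal mimic of that behaviour with local stand-in names, not vLLM classes:

```python
from collections import deque
from typing import Optional


class WaitingQueueFullError(ValueError):
    """Local stand-in for SchedulerWaitingQueueFullError."""


class TinyScheduler:
    # Mimics the gate in Scheduler.add_request(): a falsy limit (None) means
    # unlimited; otherwise reject once the waiting queue is already full.
    def __init__(self, max_waiting_queue_length: Optional[int] = None):
        self.max_waiting_queue_length = max_waiting_queue_length
        self.waiting: deque = deque()

    def add_request(self, request_id: str) -> None:
        if (self.max_waiting_queue_length
                and len(self.waiting) >= self.max_waiting_queue_length):
            raise WaitingQueueFullError(request_id)
        self.waiting.append(request_id)


scheduler = TinyScheduler(max_waiting_queue_length=2)
scheduler.add_request("req-1")
scheduler.add_request("req-2")
try:
    scheduler.add_request("req-3")
except WaitingQueueFullError as e:
    print("rejected:", e)  # rejected: req-3
```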
9 changes: 9 additions & 0 deletions vllm/v1/engine/__init__.py
@@ -135,6 +135,13 @@ class UtilityOutput(
result: Any = None


class EngineErrorPayload(msgspec.Struct):
exc_type: str
exc_module: str
exc_args: list
exc_traceback: str


class EngineCoreOutputs(
msgspec.Struct,
array_like=True, # type: ignore[call-arg]
@@ -161,6 +168,8 @@ class EngineCoreOutputs(
# "old" wave, so the next wave needs to be started in other engines.
start_wave: Optional[int] = None

engine_error: Optional[EngineErrorPayload] = None

def __post_init__(self):
if self.timestamp == 0.0:
self.timestamp = time.monotonic()
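Reviewer note: `EngineErrorPayload` only needs to be msgspec-serializable so it can ride along in `EngineCoreOutputs`. A round-trip sketch under that assumption; the plain `msgspec.msgpack` calls are for illustration only, since vLLM's actual encoder/decoder wrap msgspec with custom hooks:

```python
import msgspec


class EngineErrorPayload(msgspec.Struct):
    # Same shape as the struct added in this file.
    exc_type: str
    exc_module: str
    exc_args: list
    exc_traceback: str


payload = EngineErrorPayload(
    exc_type="SchedulerWaitingQueueFullError",
    exc_module="vllm.v1.engine.exceptions",
    exc_args=["req-42"],
    exc_traceback="Traceback (most recent call last): ...",
)

# Round-trip through msgpack, the wire format between EngineCore and client.
buf = msgspec.msgpack.encode(payload)
decoded = msgspec.msgpack.decode(buf, type=EngineErrorPayload)
assert decoded == payload
```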
22 changes: 18 additions & 4 deletions vllm/v1/engine/async_llm.py
@@ -28,8 +28,9 @@
from vllm.usage.usage_lib import UsageContext
from vllm.utils import Device, cdiv
from vllm.v1.engine import EngineCoreRequest
from vllm.v1.engine.core_client import EngineCoreClient
from vllm.v1.engine.exceptions import EngineDeadError, EngineGenerateError
from vllm.v1.engine.core_client import EngineCoreClient, process_engine_error
from vllm.v1.engine.exceptions import (EngineDeadError, EngineGenerateError,
SchedulerWaitingQueueFullError)
from vllm.v1.engine.output_processor import (OutputProcessor,
RequestOutputCollector)
from vllm.v1.engine.parallel_sampling import ParentRequest
@@ -341,13 +342,16 @@ async def generate(
if self.log_requests:
logger.info("Request %s failed (engine dead).", request_id)
raise

except SchedulerWaitingQueueFullError:
if self.log_requests:
logger.info("Request %s failed (waiting queue full).",
request_id)
raise
# Request validation error.
except ValueError:
if self.log_requests:
logger.info("Request %s failed (bad request).", request_id)
raise

# Unexpected error in the generate() task (possibly recoverable).
except Exception as e:
await self.abort(request_id)
@@ -373,6 +377,10 @@ async def output_handler():
while True:
# 1) Pull EngineCoreOutputs from the EngineCore.
outputs = await engine_core.get_output_async()
if outputs.engine_error:
output_processor.propagate_error(
process_engine_error(outputs.engine_error))
continue
num_outputs = len(outputs.outputs)

iteration_stats = IterationStats() if (
@@ -494,6 +502,12 @@ async def encode(
logger.info("Request %s failed (engine dead).", request_id)
raise

except SchedulerWaitingQueueFullError:
if self.log_requests:
logger.info("Request %s failed (waiting queue full).",
request_id)
raise

# Request validation error.
except ValueError:
if self.log_requests:
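Reviewer note: the `output_handler()` change is essentially "if the batch carries an engine_error, rebuild and propagate it, then skip normal output processing". A simplified, self-contained mimic of that control flow; all names here are local stand-ins, not vLLM APIs:

```python
import asyncio


def propagate_error(exc: Exception) -> None:
    # Stand-in for OutputProcessor.propagate_error(): fail the in-flight
    # request streams with the reconstructed exception.
    print("propagating to in-flight requests:", exc)


async def output_handler(queue: asyncio.Queue) -> None:
    while True:
        outputs = await queue.get()
        if outputs.get("engine_error") is not None:
            # Mirrors the new branch: surface the error and skip this batch.
            propagate_error(RuntimeError(outputs["engine_error"]))
            continue
        for item in outputs["outputs"]:
            print("output:", item)
        if outputs.get("last"):
            return


async def main() -> None:
    queue: asyncio.Queue = asyncio.Queue()
    await queue.put({"outputs": ["tok-1", "tok-2"]})
    await queue.put({"engine_error": "waiting queue full", "outputs": []})
    await queue.put({"outputs": ["tok-3"], "last": True})
    await output_handler(queue)


asyncio.run(main())
```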
25 changes: 21 additions & 4 deletions vllm/v1/engine/core.py
@@ -6,6 +6,7 @@
import sys
import threading
import time
import traceback
from collections import deque
from collections.abc import Generator
from concurrent.futures import Future
@@ -32,7 +33,7 @@
from vllm.v1.core.sched.output import SchedulerOutput
from vllm.v1.core.sched.scheduler import Scheduler as V1Scheduler
from vllm.v1.engine import (EngineCoreOutputs, EngineCoreRequest,
EngineCoreRequestType,
EngineCoreRequestType, EngineErrorPayload,
ReconfigureDistributedRequest, ReconfigureRankType,
UtilityOutput)
from vllm.v1.engine.mm_input_cache import MirroredProcessingCache
@@ -621,9 +622,11 @@ def signal_handler(signum, frame):
engine_core = DPEngineCoreProc(*args, **kwargs)
else:
engine_core = EngineCoreProc(*args, **kwargs)

engine_core.run_busy_loop()

while True:
try:
engine_core.run_busy_loop()
except ValueError as e:
engine_core._send_engine_error(e)
except SystemExit:
logger.debug("EngineCore exiting.")
raise
@@ -734,6 +737,20 @@ def _send_engine_dead(self):
logger.fatal("vLLM shutdown signal from EngineCore failed "
"to send. Please report this issue.")

def _send_engine_error(self, exc: BaseException):
"""Send CustomEngineError status to the EngineCoreClient."""

# Put CustomEngineError in the queue.
self.output_queue.put_nowait((
0,
EngineCoreOutputs(engine_error=EngineErrorPayload(
exc_type=type(exc).__name__,
exc_module=type(exc).__module__,
exc_args=list(exc.args),
exc_traceback=traceback.format_exc(),
)),
))

def process_input_sockets(self, input_addresses: list[str],
coord_input_address: Optional[str],
identity: bytes):
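Reviewer note: `_send_engine_error` captures just enough metadata to rebuild the exception on the client side, and `traceback.format_exc()` must be called while the exception is still being handled. A standalone sketch of that serialization step; the helper name is hypothetical:

```python
import traceback


def serialize_exception(exc: BaseException) -> dict:
    # Mirrors the fields packed into EngineErrorPayload by _send_engine_error().
    return {
        "exc_type": type(exc).__name__,
        "exc_module": type(exc).__module__,
        "exc_args": list(exc.args),
        "exc_traceback": traceback.format_exc(),  # valid inside the except block
    }


try:
    raise ValueError("req-42")
except ValueError as e:
    payload = serialize_exception(e)
    print(payload["exc_type"], payload["exc_args"])  # ValueError ['req-42']
```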
16 changes: 15 additions & 1 deletion vllm/v1/engine/core_client.py
@@ -23,7 +23,7 @@
from vllm.lora.request import LoRARequest
from vllm.utils import get_open_port, get_open_zmq_inproc_path, make_zmq_socket
from vllm.v1.engine import (EngineCoreOutputs, EngineCoreRequest,
EngineCoreRequestType,
EngineCoreRequestType, EngineErrorPayload,
ReconfigureDistributedRequest, ReconfigureRankType,
UtilityOutput)
from vllm.v1.engine.coordinator import DPCoordinator
@@ -719,6 +719,9 @@ async def process_outputs_socket():
frames = await output_socket.recv_multipart(copy=False)
resources.validate_alive(frames)
outputs: EngineCoreOutputs = decoder.decode(frames)
if outputs.engine_error:
outputs_queue.put_nowait(outputs)
continue
if outputs.utility_output:
_process_utility_output(outputs.utility_output,
utility_results)
@@ -1234,3 +1237,14 @@ async def _scale_down_elastic_ep(self, cur_data_parallel_size: int,
logger.info(
"[Elastic EP] Scale down completed, new data parallel size: %s",
new_data_parallel_size)


def process_engine_error(engine_error: EngineErrorPayload) -> Exception:
"""Process an engine error payload and raise an exception."""
try:
module = sys.modules.get(engine_error.exc_module)
exc_class = getattr(module, engine_error.exc_type)
except Exception:
exc_class = RuntimeError # fallback
exc = exc_class(*engine_error.exc_args)
return exc
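Reviewer note: `process_engine_error` resolves the exception class via `sys.modules`, so it only reconstructs the original type when that module is already imported in the client process; otherwise it degrades to `RuntimeError`. A standalone mimic showing both paths; the function name below is a local stand-in:

```python
import sys


def rebuild_exception(exc_module: str, exc_type: str, exc_args: list) -> Exception:
    # Resolve the class only if its module is already imported; otherwise
    # fall back to RuntimeError, like process_engine_error() above.
    try:
        module = sys.modules.get(exc_module)
        exc_class = getattr(module, exc_type)
    except Exception:
        exc_class = RuntimeError
    return exc_class(*exc_args)


exc = rebuild_exception("builtins", "ValueError", ["bad request"])
print(type(exc).__name__, exc.args)  # ValueError ('bad request',)

# An unknown module degrades gracefully to RuntimeError.
exc = rebuild_exception("not.imported.module", "SomeError", ["oops"])
print(type(exc).__name__)            # RuntimeError
```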
10 changes: 10 additions & 0 deletions vllm/v1/engine/exceptions.py
@@ -2,6 +2,7 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
class EngineGenerateError(Exception):
"""Raised when a AsyncLLM.generate() fails. Recoverable."""

pass


@@ -15,3 +16,12 @@ def __init__(self, *args, suppress_context: bool = False, **kwargs):
# Make stack trace clearer when using with LLMEngine by
# silencing irrelevant ZMQError.
self.__suppress_context__ = suppress_context


class SchedulerWaitingQueueFullError(ValueError):
"""Raised when the scheduler's waiting queue is full and cannot accept
new requests."""

def __init__(self, request_id: str):
super().__init__(request_id)
self.request_id = request_id
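Reviewer note: because `SchedulerWaitingQueueFullError` subclasses `ValueError`, the existing `except ValueError` handlers in the serving layer still catch it, and `create_error_response` only has to special-case the subtype for the 503 mapping. A tiny demonstration; the class is redefined locally so the snippet is self-contained:

```python
class SchedulerWaitingQueueFullError(ValueError):
    # Local copy of the class added above, for a self-contained example.
    def __init__(self, request_id: str):
        super().__init__(request_id)
        self.request_id = request_id


try:
    raise SchedulerWaitingQueueFullError("req-7")
except ValueError as e:  # generic ValueError handlers still catch it
    if isinstance(e, SchedulerWaitingQueueFullError):
        print("waiting queue full for request", e.request_id)  # req-7
```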