@@ -88,26 +88,24 @@
import asyncio
import functools
import heapq
import json
import os
import sys
import threading
import uuid
from contextlib import asynccontextmanager
from dataclasses import dataclass
from typing import Any, List
from typing import List

import httpx
from fastapi import FastAPI, Request
from fastapi.responses import StreamingResponse
from transformers import AutoTokenizer
from vllm.logger import init_logger

logger = init_logger(__name__)

# Add uvloop for faster event loop if available
try:
import uvloop

asyncio.set_event_loop_policy(uvloop.EventLoopPolicy())
except ImportError:
pass
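
With the policy installed above, event loops created afterwards come from uvloop; a quick sanity check (a sketch, not part of this change):

import asyncio
print(type(asyncio.new_event_loop()))  # uvloop.Loop when the uvloop policy is active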
@@ -154,6 +152,9 @@ def __init__(self, prefiller_instances, decoder_instances):
heapq.heapify(self.prefiller_heap)
heapq.heapify(self.decoder_heap)
self.req_id_future = {}
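# req_data_dict maps request_id -> (req_data, request_length, api); it is
# populated in _handle_completions and consumed by the /v1/metaserver callback.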
self.req_data_dict = {}
self.tokenizer = AutoTokenizer.from_pretrained(
global_args.tokenizer_dir)

def _update_prefiller_priority(self, server_idx: int):
"""Update the priority of a prefiller server in the heap."""
@@ -280,6 +281,10 @@ def parse_args():
nargs="+",
default=["localhost"])
parser.add_argument("--decoder-ports", type=int, nargs="+", default=[8002])
parser.add_argument("--tokenizer-dir",
type=str,
default="/mnt/weight/Qwen3-235B-A22B-W8A8",
help="Maximum number of retries for HTTP requests")
parser.add_argument("--max-retries",
type=int,
default=3,
@@ -356,17 +361,6 @@ async def send_request_to_service(client: httpx.AsyncClient,
aborted_requests = proxy_state.aquire_aborted_prefiller_requests(
prefiller_id)
req_data = req_data.copy()
req_data['kv_transfer_params'] = {
"do_remote_decode": True,
"do_remote_prefill": False,
"remote_engine_id": None,
"remote_block_ids": None,
"remote_host": None,
"remote_port": None,
"aborted_request": list(aborted_requests),
"metaserver":
f"http://{global_args.host}:{global_args.port}/v1/metaserver"
}
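# The prefill pass only needs to build the KV cache, so request a single,
# non-streamed token.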
req_data["stream"] = False
req_data["max_tokens"] = 1
if "stream_options" in req_data:
@@ -458,180 +452,59 @@ def get_api_request_id(api, req_id):
return "chatcmpl-" + req_id


async def _handle_select_instance(api: str, req_data: Any,
request_length: int):
prefiller_score = proxy_state.calculate_prefill_scores(request_length)
logger.debug(
f"Request length: {request_length}, Prefiller score: {prefiller_score}"
)
request_id = await proxy_state.next_req_id()
# Select prefiller
prefiller_idx = proxy_state.select_prefiller(prefiller_score)
prefiller = proxy_state.prefillers[prefiller_idx]
result_future = asyncio.Future() # type: ignore
request_id_api = get_api_request_id(api, request_id)
proxy_state.req_id_future[request_id_api] = result_future
# Send request to prefiller
asyncio.get_running_loop().create_task(
send_request_to_service(prefiller.client,
prefiller_idx,
api,
req_data,
request_id,
max_retries=global_args.max_retries,
base_delay=global_args.retry_delay))
proxy_state.release_prefiller(prefiller_idx, prefiller_score)

response = await result_future
del proxy_state.req_id_future[request_id_api]
req_data["kv_transfer_params"] = response

# Select decoder
decoder_score = proxy_state.calculate_decode_scores(request_length)
logger.debug("Decoder score: %f", decoder_score)
# Use the prefiller's kv_transfer_params to select decoder
decoder_idx = proxy_state.select_decoder(decoder_score)
decoder = proxy_state.decoders[decoder_idx]
logger.debug("Using %s %s", prefiller.url, decoder.url)
return InstanceInfo(request_id=request_id,
prefiller_idx=prefiller_idx,
prefiller_score=prefiller_score,
prefiller=prefiller,
decoder=decoder,
decoder_idx=decoder_idx,
decoder_score=decoder_score)


@dataclass
class InstanceInfo:
request_id: str
prefiller_idx: int
prefiller_score: float
prefiller: ServerState
decoder_idx: int
decoder_score: float
decoder: ServerState
def get_origin_request_id(api, req_id):
if api == "/completions":
return req_id.replace("cmpl-", "").replace("-0", "")
elif api == "/chat/completions":
return req_id.replace("chatcmpl-", "")


async def _handle_completions(api: str, request: Request):
try:
req_data = await request.json()
req_body = await request.body()
request_length = len(req_body)
instance_info = await _handle_select_instance(api, req_data,
request_length)
stream_flag = bool(req_data.get("stream", False))
chat_flag = "messages" in req_data

if "prompt" in req_data:
origin_prompt = req_data["prompt"]
elif chat_flag:
messages = req_data["messages"]
origin_prompt = messages[0].get("content", "")
else:
origin_prompt = ""
# refer to vLLM SamplingParams: max_tokens default value (16)
origin_max_tokens = req_data.get("max_tokens", 16)
request_id = await proxy_state.next_req_id()
request_id_api = get_api_request_id(api, request_id)
proxy_state.req_data_dict[request_id_api] = (req_data, request_length,
api)
req_data['kv_transfer_params'] = {
"do_remote_decode":
False,
"do_remote_prefill":
True,
"metaserver":
f"http://{global_args.host}:{global_args.port}/v1/metaserver"
}
# Select decoder
decoder_score = proxy_state.calculate_decode_scores(request_length)
logger.debug("Decoder score: %f", decoder_score)
# Use the prefiller's kv_transfer_params to select decoder
decoder_idx = proxy_state.select_decoder(decoder_score)
decoder = proxy_state.decoders[decoder_idx]
# logger.debug("Using %s %s", prefiller.url, decoder.url)
# Stream response from decoder
released_kv = False

async def generate_stream():
nonlocal instance_info
generated_token = ""
released_kv = False
retry_count = 0
retry = True
completion_tokens = 0
nonlocal released_kv
# Only one await per chunk, minimal logic in loop
try:
while retry:
retry = False
async for chunk in stream_service_response_with_retry(
instance_info.decoder.client,
api,
req_data,
request_id=instance_info.request_id,
max_retries=global_args.max_retries,
base_delay=global_args.retry_delay):
if not released_kv and chunk:
proxy_state.release_prefiller_kv(
instance_info.prefiller_idx,
instance_info.prefiller_score)
released_kv = True
try:
chunk_str = chunk.decode("utf-8").strip()
except UnicodeDecodeError:
logger.debug(
f"Skipping chunk: {chunk}")
yield chunk
continue
if not chunk_str:
continue
if chunk_str.startswith("data: "):
chunk_str = chunk_str[len("data: "):]
try:
chunk_json = json.loads(chunk_str)
except json.JSONDecodeError:
# non-JSON chunk (e.g. the SSE "[DONE]" sentinel): forward it as-is.
logger.debug(
f"Skipping chunk: {chunk_str}")
yield chunk
continue
choices = chunk_json.get("choices", [])
if not choices:
yield chunk
continue

choice = choices[0]
delta = choice.get("delta") or {}
message = choice.get("message") or {}
content = (
delta.get("content")
or message.get("content")
or choice.get("text")
or ""
)
generated_token += content

stop_reason = choice.get(
"stop_reason")
usage = chunk_json.get("usage", {})
completion_tokens = (completion_tokens + 1) if stream_flag else \
(completion_tokens + usage.get("completion_tokens", 0))
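# A "recomputed" stop_reason triggers a retry: rebuild the prompt with the
# text generated so far and re-run prefill on a freshly selected instance.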
if stop_reason == "recomputed":
retry = True
retry_count += 1
if chat_flag:
messages[0][
"content"] = origin_prompt + generated_token
else:
req_data[
"prompt"] = origin_prompt + generated_token
req_data[
"max_tokens"] = origin_max_tokens - completion_tokens + retry_count
tmp_request_length = len(
json.dumps(req_data).encode("utf-8"))
instance_info = await _handle_select_instance(
api, req_data, tmp_request_length)
break
if retry_count > 0 and not stream_flag:
if chat_flag:
choices[0]["message"][
"content"] = generated_token
else:
choices[0]["text"] = generated_token
chunk = json.dumps(chunk_json).encode("utf-8")
yield chunk
async for chunk in stream_service_response_with_retry(
decoder.client,
api,
req_data,
request_id=request_id,
max_retries=global_args.max_retries,
base_delay=global_args.retry_delay):
yield chunk
except Exception as e:
logger.error(
f"Error during streaming from decoder {instance_info.decoder.url}: {str(e)} the aborted request {instance_info.request_id} will be routing to the target prefiller when new request is ready to dispatch to it"
f"Error during streaming from decoder {decoder.url}: {str(e)} the aborted request {request_id} will be routing to the target prefiller when new request is ready to dispatch to it"
)
proxy_state.abort_prefiller_request(
instance_info.prefiller_idx, instance_info.request_id)
proxy_state.release_prefiller_kv(instance_info.prefiller_idx,
instance_info.prefiller_score)

# After streaming done, release tokens
proxy_state.release_decoder(instance_info.decoder_idx,
instance_info.decoder_score)
proxy_state.release_decoder(decoder_idx, decoder_score)

return StreamingResponse(generate_stream(),
media_type="application/json")
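
For reference, an end-to-end request against the proxy might look like this (hypothetical host, port, and route; only the api suffixes "/completions" and "/chat/completions" are visible in this diff):

import httpx

resp = httpx.post(
    "http://localhost:8000/v1/completions",  # hypothetical proxy address/route
    json={"prompt": "Hello", "max_tokens": 16, "stream": False},
)
print(resp.json())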
@@ -669,11 +542,33 @@ async def healthcheck():
@app.post("/v1/metaserver")
async def metaserver(request: Request):
try:
req_data = await request.json()
request_id = req_data.pop("request_id", None)
if request_id in proxy_state.req_id_future:
result_future = proxy_state.req_id_future[request_id]
result_future.set_result(req_data)
kv_transfer_params = await request.json()

request_id = kv_transfer_params["request_id"]
assert request_id in proxy_state.req_data_dict
Contributor review comment (high):

Using assert for input validation in an API endpoint is not robust. If the assertion fails (e.g., due to a race condition or a bug where req_data_dict is cleared prematurely), it will raise an AssertionError and cause a 500 Internal Server Error, which is not a clean API response. It's better to handle this case gracefully by returning a specific HTTP error, like a 404 Not Found, to provide a clearer error to the client. You will need to add JSONResponse to the imports from fastapi.responses.

Suggested change:
-        assert request_id in proxy_state.req_data_dict
+        if request_id not in proxy_state.req_data_dict:
+            logger.error(f"Request ID {request_id} not found in req_data_dict.")
+            return JSONResponse(status_code=404, content={"error": f"Request ID {request_id} not found"})
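
The import change the reviewer calls for would look like this (a sketch against the import block near the top of the file):

from fastapi.responses import JSONResponse, StreamingResponse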

req_data, request_length, api = proxy_state.req_data_dict[request_id]
request_id = get_origin_request_id(api, request_id)
req_data["kv_transfer_params"] = kv_transfer_params
prefiller_score = proxy_state.calculate_prefill_scores(request_length)
logger.debug(
f"Request length: {request_length}, Prefiller score: {prefiller_score}"
)

# Select prefiller
prefiller_idx = proxy_state.select_prefiller(prefiller_score)
prefiller = proxy_state.prefillers[prefiller_idx]
logger.debug(f"Using prefill {prefiller.url=} {req_data=}")
# Send request to prefiller
response = await send_request_to_service(
prefiller.client,
prefiller_idx,
api,
req_data,
request_id,
max_retries=global_args.max_retries,
base_delay=global_args.retry_delay)
proxy_state.release_prefiller(prefiller_idx, prefiller_score)

except Exception as e:
logger.error(f"Post metaserver failed with: {str(e)}")
Contributor review comment on lines 572 to 573 (high):

The broad except Exception as e: block only logs the error. Since it doesn't return any response, the client (which is another service in this architecture) will hang until it times out. This can lead to resource exhaustion and cascading failures. It's better to return an appropriate HTTP error response (e.g., 500 Internal Server Error) to the client. You will need to add JSONResponse to the imports from fastapi.responses.

Suggested change:
         except Exception as e:
             logger.error(f"Post metaserver failed with: {str(e)}")
+            return JSONResponse(status_code=500, content={"error": "Internal server error in metaserver"})
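
Taken together, the two suggestions give a handler shaped roughly like this (a sketch, not the merged code; it assumes the JSONResponse import above and elides the dispatch body):

@app.post("/v1/metaserver")
async def metaserver(request: Request):
    try:
        kv_transfer_params = await request.json()
        request_id = kv_transfer_params["request_id"]
        if request_id not in proxy_state.req_data_dict:
            logger.error(f"Request ID {request_id} not found in req_data_dict.")
            return JSONResponse(status_code=404,
                                content={"error": f"Request ID {request_id} not found"})
        # ... look up req_data, select a prefiller, and re-dispatch as above ...
    except Exception as e:
        logger.error(f"Post metaserver failed with: {str(e)}")
        return JSONResponse(status_code=500,
                            content={"error": "Internal server error in metaserver"})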


@@ -682,5 +577,4 @@ async def metaserver(request: Request):
global global_args
global_args = parse_args()
import uvicorn

uvicorn.run(app, host=global_args.host, port=global_args.port)
@@ -1,8 +1,5 @@
from __future__ import annotations

import os
from unittest.mock import patch

import pytest
from vllm import SamplingParams
from vllm.config import CompilationConfig, CUDAGraphMode
2 changes: 1 addition & 1 deletion tests/ut/kv_connector/test_mooncake_connector.py
@@ -1136,4 +1136,4 @@ def test_device_id_selection_with_physical_devices(self):


if __name__ == '__main__':
unittest.main()
unittest.main()