-
Notifications
You must be signed in to change notification settings - Fork 550
[bugfix_v0.11.0-dev] [P/D]layerwise D first plan #3907
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change | ||||||||||
|---|---|---|---|---|---|---|---|---|---|---|---|---|
|
|
@@ -88,26 +88,24 @@ | |||||||||||
| import asyncio | ||||||||||||
| import functools | ||||||||||||
| import heapq | ||||||||||||
| import json | ||||||||||||
| import os | ||||||||||||
| import sys | ||||||||||||
| import threading | ||||||||||||
| import uuid | ||||||||||||
| from contextlib import asynccontextmanager | ||||||||||||
| from dataclasses import dataclass | ||||||||||||
| from typing import Any, List | ||||||||||||
| from typing import List | ||||||||||||
|
|
||||||||||||
| import httpx | ||||||||||||
| from fastapi import FastAPI, Request | ||||||||||||
| from fastapi.responses import StreamingResponse | ||||||||||||
| from transformers import AutoTokenizer | ||||||||||||
| from vllm.logger import init_logger | ||||||||||||
|
|
||||||||||||
| logger = init_logger(__name__) | ||||||||||||
|
|
||||||||||||
| # Add uvloop for faster event loop if available | ||||||||||||
| try: | ||||||||||||
| import uvloop | ||||||||||||
|
|
||||||||||||
| asyncio.set_event_loop_policy(uvloop.EventLoopPolicy()) | ||||||||||||
| except ImportError: | ||||||||||||
| pass | ||||||||||||
|
|
@@ -154,6 +152,9 @@ def __init__(self, prefiller_instances, decoder_instances): | |||||||||||
| heapq.heapify(self.prefiller_heap) | ||||||||||||
| heapq.heapify(self.decoder_heap) | ||||||||||||
| self.req_id_future = {} | ||||||||||||
| self.req_data_dict = {} | ||||||||||||
| self.tokenizer = AutoTokenizer.from_pretrained( | ||||||||||||
| global_args.tokenizer_dir) | ||||||||||||
|
|
||||||||||||
| def _update_prefiller_priority(self, server_idx: int): | ||||||||||||
| """Update the priority of a prefiller server in the heap.""" | ||||||||||||
|
|
@@ -280,6 +281,10 @@ def parse_args(): | |||||||||||
| nargs="+", | ||||||||||||
| default=["localhost"]) | ||||||||||||
| parser.add_argument("--decoder-ports", type=int, nargs="+", default=[8002]) | ||||||||||||
| parser.add_argument("--tokenizer-dir", | ||||||||||||
| type=str, | ||||||||||||
| default="/mnt/weight/Qwen3-235B-A22B-W8A8", | ||||||||||||
| help="Maximum number of retries for HTTP requests") | ||||||||||||
| parser.add_argument("--max-retries", | ||||||||||||
| type=int, | ||||||||||||
| default=3, | ||||||||||||
|
|
@@ -356,17 +361,6 @@ async def send_request_to_service(client: httpx.AsyncClient, | |||||||||||
| aborted_requests = proxy_state.aquire_aborted_prefiller_requests( | ||||||||||||
| prefiller_id) | ||||||||||||
| req_data = req_data.copy() | ||||||||||||
| req_data['kv_transfer_params'] = { | ||||||||||||
| "do_remote_decode": True, | ||||||||||||
| "do_remote_prefill": False, | ||||||||||||
| "remote_engine_id": None, | ||||||||||||
| "remote_block_ids": None, | ||||||||||||
| "remote_host": None, | ||||||||||||
| "remote_port": None, | ||||||||||||
| "aborted_request": list(aborted_requests), | ||||||||||||
| "metaserver": | ||||||||||||
| f"http://{global_args.host}:{global_args.port}/v1/metaserver" | ||||||||||||
| } | ||||||||||||
| req_data["stream"] = False | ||||||||||||
| req_data["max_tokens"] = 1 | ||||||||||||
| if "stream_options" in req_data: | ||||||||||||
|
|
@@ -458,180 +452,59 @@ def get_api_request_id(api, req_id): | |||||||||||
| return "chatcmpl-" + req_id | ||||||||||||
|
|
||||||||||||
|
|
||||||||||||
| async def _handle_select_instance(api: str, req_data: Any, | ||||||||||||
| request_length: int): | ||||||||||||
| prefiller_score = proxy_state.calculate_prefill_scores(request_length) | ||||||||||||
| logger.debug( | ||||||||||||
| f"Request length: {request_length}, Prefiller score: {prefiller_score}" | ||||||||||||
| ) | ||||||||||||
| request_id = await proxy_state.next_req_id() | ||||||||||||
| # Select prefiller | ||||||||||||
| prefiller_idx = proxy_state.select_prefiller(prefiller_score) | ||||||||||||
| prefiller = proxy_state.prefillers[prefiller_idx] | ||||||||||||
| result_future = asyncio.Future() # type: ignore | ||||||||||||
| request_id_api = get_api_request_id(api, request_id) | ||||||||||||
| proxy_state.req_id_future[request_id_api] = result_future | ||||||||||||
| # Send request to prefiller | ||||||||||||
| asyncio.get_running_loop().create_task( | ||||||||||||
| send_request_to_service(prefiller.client, | ||||||||||||
| prefiller_idx, | ||||||||||||
| api, | ||||||||||||
| req_data, | ||||||||||||
| request_id, | ||||||||||||
| max_retries=global_args.max_retries, | ||||||||||||
| base_delay=global_args.retry_delay)) | ||||||||||||
| proxy_state.release_prefiller(prefiller_idx, prefiller_score) | ||||||||||||
|
|
||||||||||||
| response = await result_future | ||||||||||||
| del proxy_state.req_id_future[request_id_api] | ||||||||||||
| req_data["kv_transfer_params"] = response | ||||||||||||
|
|
||||||||||||
| # Select decoder | ||||||||||||
| decoder_score = proxy_state.calculate_decode_scores(request_length) | ||||||||||||
| logger.debug("Decoder score: %f", decoder_score) | ||||||||||||
| # Use the prefiller's kv_transfer_params to select decoder | ||||||||||||
| decoder_idx = proxy_state.select_decoder(decoder_score) | ||||||||||||
| decoder = proxy_state.decoders[decoder_idx] | ||||||||||||
| logger.debug("Using %s %s", prefiller.url, decoder.url) | ||||||||||||
| return InstanceInfo(request_id=request_id, | ||||||||||||
| prefiller_idx=prefiller_idx, | ||||||||||||
| prefiller_score=prefiller_score, | ||||||||||||
| prefiller=prefiller, | ||||||||||||
| decoder=decoder, | ||||||||||||
| decoder_idx=decoder_idx, | ||||||||||||
| decoder_score=decoder_score) | ||||||||||||
|
|
||||||||||||
|
|
||||||||||||
| @dataclass | ||||||||||||
| class InstanceInfo: | ||||||||||||
| request_id: str | ||||||||||||
| prefiller_idx: int | ||||||||||||
| prefiller_score: float | ||||||||||||
| prefiller: ServerState | ||||||||||||
| decoder_idx: int | ||||||||||||
| decoder_score: float | ||||||||||||
| decoder: ServerState | ||||||||||||
| def get_origin_request_id(api, req_id): | ||||||||||||
| if api == "/completions": | ||||||||||||
| return req_id.replace("cmpl-", "").replace("-0", "") | ||||||||||||
| elif api == "/chat/completions": | ||||||||||||
| return req_id.replace("chatcmpl-", "") | ||||||||||||
|
|
||||||||||||
|
|
||||||||||||
| async def _handle_completions(api: str, request: Request): | ||||||||||||
| try: | ||||||||||||
| req_data = await request.json() | ||||||||||||
| req_body = await request.body() | ||||||||||||
| request_length = len(req_body) | ||||||||||||
| instance_info = await _handle_select_instance(api, req_data, | ||||||||||||
| request_length) | ||||||||||||
| stream_flag = bool(req_data.get("stream", False)) | ||||||||||||
| chat_flag = "messages" in req_data | ||||||||||||
|
|
||||||||||||
| if "prompt" in req_data: | ||||||||||||
| origin_prompt = req_data["prompt"] | ||||||||||||
| elif chat_flag: | ||||||||||||
| messages = req_data["messages"] | ||||||||||||
| origin_prompt = messages[0].get("content", "") | ||||||||||||
| else: | ||||||||||||
| origin_prompt = "" | ||||||||||||
| # refer to vLLM sampling_params: max_token default value | ||||||||||||
| origin_max_tokens = req_data.get("max_tokens", 16) | ||||||||||||
| request_id = await proxy_state.next_req_id() | ||||||||||||
| request_id_api = get_api_request_id(api, request_id) | ||||||||||||
| proxy_state.req_data_dict[request_id_api] = (req_data, request_length, | ||||||||||||
| api) | ||||||||||||
| req_data['kv_transfer_params'] = { | ||||||||||||
| "do_remote_decode": | ||||||||||||
| False, | ||||||||||||
| "do_remote_prefill": | ||||||||||||
| True, | ||||||||||||
| "metaserver": | ||||||||||||
| f"http://{global_args.host}:{global_args.port}/v1/metaserver" | ||||||||||||
| } | ||||||||||||
| # Select decoder | ||||||||||||
| decoder_score = proxy_state.calculate_decode_scores(request_length) | ||||||||||||
| logger.debug("Decoder score: %f", decoder_score) | ||||||||||||
| # Use the prefiller's kv_transfer_params to select decoder | ||||||||||||
| decoder_idx = proxy_state.select_decoder(decoder_score) | ||||||||||||
| decoder = proxy_state.decoders[decoder_idx] | ||||||||||||
| # logger.debug("Using %s %s", prefiller.url, decoder.url) | ||||||||||||
| # Stream response from decoder | ||||||||||||
| released_kv = False | ||||||||||||
|
|
||||||||||||
| async def generate_stream(): | ||||||||||||
| nonlocal instance_info | ||||||||||||
| generated_token = "" | ||||||||||||
| released_kv = False | ||||||||||||
| retry_count = 0 | ||||||||||||
| retry = True | ||||||||||||
| completion_tokens = 0 | ||||||||||||
| nonlocal released_kv | ||||||||||||
| # Only one await per chunk, minimal logic in loop | ||||||||||||
| try: | ||||||||||||
| while retry: | ||||||||||||
| retry = False | ||||||||||||
| async for chunk in stream_service_response_with_retry( | ||||||||||||
| instance_info.decoder.client, | ||||||||||||
| api, | ||||||||||||
| req_data, | ||||||||||||
| request_id=instance_info.request_id, | ||||||||||||
| max_retries=global_args.max_retries, | ||||||||||||
| base_delay=global_args.retry_delay): | ||||||||||||
| if not released_kv and chunk: | ||||||||||||
| proxy_state.release_prefiller_kv( | ||||||||||||
| instance_info.prefiller_idx, | ||||||||||||
| instance_info.prefiller_score) | ||||||||||||
| released_kv = True | ||||||||||||
| try: | ||||||||||||
| chunk_str = chunk.decode("utf-8").strip() | ||||||||||||
| except UnicodeDecodeError: | ||||||||||||
| logger.debug( | ||||||||||||
| f"Skipping chunk: {chunk}") | ||||||||||||
| yield chunk | ||||||||||||
| continue | ||||||||||||
| if not chunk_str: | ||||||||||||
| continue | ||||||||||||
| if chunk_str.startswith("data: "): | ||||||||||||
| chunk_str = chunk_str[len("data: "):] | ||||||||||||
| try: | ||||||||||||
| chunk_json = json.loads(chunk_str) | ||||||||||||
| except json.JSONDecodeError: | ||||||||||||
| # if chunk is [done], skip it. | ||||||||||||
| logger.debug( | ||||||||||||
| f"Skipping chunk: {chunk_str}") | ||||||||||||
| yield chunk | ||||||||||||
| continue | ||||||||||||
| choices = chunk_json.get("choices", []) | ||||||||||||
| if not choices: | ||||||||||||
| yield chunk | ||||||||||||
| continue | ||||||||||||
|
|
||||||||||||
| choice = choices[0] | ||||||||||||
| delta = choice.get("delta") or {} | ||||||||||||
| message = choice.get("message") or {} | ||||||||||||
| content = ( | ||||||||||||
| delta.get("content") | ||||||||||||
| or message.get("content") | ||||||||||||
| or choice.get("text") | ||||||||||||
| or "" | ||||||||||||
| ) | ||||||||||||
| generated_token += content | ||||||||||||
|
|
||||||||||||
| stop_reason = choice.get( | ||||||||||||
| "stop_reason") | ||||||||||||
| usage = chunk_json.get("usage", {}) | ||||||||||||
| completion_tokens = (completion_tokens + 1) if stream_flag else \ | ||||||||||||
| (completion_tokens + usage.get("completion_tokens")) | ||||||||||||
| if stop_reason == "recomputed": | ||||||||||||
| retry = True | ||||||||||||
| retry_count += 1 | ||||||||||||
| if chat_flag: | ||||||||||||
| messages[0][ | ||||||||||||
| "content"] = origin_prompt + generated_token | ||||||||||||
| else: | ||||||||||||
| req_data[ | ||||||||||||
| "prompt"] = origin_prompt + generated_token | ||||||||||||
| req_data[ | ||||||||||||
| "max_tokens"] = origin_max_tokens - completion_tokens + retry_count | ||||||||||||
| tmp_request_length = len( | ||||||||||||
| json.dumps(req_data).encode("utf-8")) | ||||||||||||
| instance_info = await _handle_select_instance( | ||||||||||||
| api, req_data, tmp_request_length) | ||||||||||||
| break | ||||||||||||
| if retry_count > 0 and not stream_flag: | ||||||||||||
| if chat_flag: | ||||||||||||
| choices[0]["message"][ | ||||||||||||
| "content"] = generated_token | ||||||||||||
| else: | ||||||||||||
| choices[0]["text"] = generated_token | ||||||||||||
| chunk = json.dumps(chunk_json).encode("utf-8") | ||||||||||||
| yield chunk | ||||||||||||
| async for chunk in stream_service_response_with_retry( | ||||||||||||
| decoder.client, | ||||||||||||
| api, | ||||||||||||
| req_data, | ||||||||||||
| request_id=request_id, | ||||||||||||
| max_retries=global_args.max_retries, | ||||||||||||
| base_delay=global_args.retry_delay): | ||||||||||||
| yield chunk | ||||||||||||
| except Exception as e: | ||||||||||||
| logger.error( | ||||||||||||
| f"Error during streaming from decoder {instance_info.decoder.url}: {str(e)} the aborted request {instance_info.request_id} will be routing to the target prefiller when new request is ready to dispatch to it" | ||||||||||||
| f"Error during streaming from decoder {decoder.url}: {str(e)} the aborted request {request_id} will be routing to the target prefiller when new request is ready to dispatch to it" | ||||||||||||
| ) | ||||||||||||
| proxy_state.abort_prefiller_request( | ||||||||||||
| instance_info.prefiller_idx, instance_info.request_id) | ||||||||||||
| proxy_state.release_prefiller_kv(instance_info.prefiller_idx, | ||||||||||||
| instance_info.prefiller_score) | ||||||||||||
|
|
||||||||||||
| # After streaming done, release tokens | ||||||||||||
| proxy_state.release_decoder(instance_info.decoder_idx, | ||||||||||||
| instance_info.decoder_score) | ||||||||||||
| proxy_state.release_decoder(decoder_idx, decoder_score) | ||||||||||||
|
|
||||||||||||
| return StreamingResponse(generate_stream(), | ||||||||||||
| media_type="application/json") | ||||||||||||
|
|
@@ -669,11 +542,33 @@ async def healthcheck(): | |||||||||||
| @app.post("/v1/metaserver") | ||||||||||||
| async def metaserver(request: Request): | ||||||||||||
| try: | ||||||||||||
| req_data = await request.json() | ||||||||||||
| request_id = req_data.pop("request_id", None) | ||||||||||||
| if request_id in proxy_state.req_id_future: | ||||||||||||
| result_future = proxy_state.req_id_future[request_id] | ||||||||||||
| result_future.set_result(req_data) | ||||||||||||
| kv_transfer_params = await request.json() | ||||||||||||
|
|
||||||||||||
| request_id = kv_transfer_params["request_id"] | ||||||||||||
| assert request_id in proxy_state.req_data_dict | ||||||||||||
| req_data, request_length, api = proxy_state.req_data_dict[request_id] | ||||||||||||
| request_id = get_origin_request_id(api, request_id) | ||||||||||||
| req_data["kv_transfer_params"] = kv_transfer_params | ||||||||||||
| prefiller_score = proxy_state.calculate_prefill_scores(request_length) | ||||||||||||
| logger.debug( | ||||||||||||
| f"Request length: {request_length}, Prefiller score: {prefiller_score}" | ||||||||||||
| ) | ||||||||||||
|
|
||||||||||||
| # Select prefiller | ||||||||||||
| prefiller_idx = proxy_state.select_prefiller(prefiller_score) | ||||||||||||
| prefiller = proxy_state.prefillers[prefiller_idx] | ||||||||||||
| logger.debug(f"Using prefill {prefiller.url=} {req_data=}") | ||||||||||||
| # Send request to prefiller | ||||||||||||
| response = await send_request_to_service( | ||||||||||||
| prefiller.client, | ||||||||||||
| prefiller_idx, | ||||||||||||
| api, | ||||||||||||
| req_data, | ||||||||||||
| request_id, | ||||||||||||
| max_retries=global_args.max_retries, | ||||||||||||
| base_delay=global_args.retry_delay) | ||||||||||||
| proxy_state.release_prefiller(prefiller_idx, prefiller_score) | ||||||||||||
|
|
||||||||||||
| except Exception as e: | ||||||||||||
| logger.error(f"Post metaserver failed with: {str(e)}") | ||||||||||||
|
Comment on lines
572
to
573
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The broad
Suggested change
|
||||||||||||
|
|
||||||||||||
|
|
@@ -682,5 +577,4 @@ async def metaserver(request: Request): | |||||||||||
| global global_args | ||||||||||||
| global_args = parse_args() | ||||||||||||
| import uvicorn | ||||||||||||
|
|
||||||||||||
| uvicorn.run(app, host=global_args.host, port=global_args.port) | ||||||||||||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Using
assertfor input validation in an API endpoint is not robust. If the assertion fails (e.g., due to a race condition or a bug wherereq_data_dictis cleared prematurely), it will raise anAssertionErrorand cause a 500 Internal Server Error, which is not a clean API response. It's better to handle this case gracefully by returning a specific HTTP error, like a 404 Not Found, to provide a clearer error to the client. You will need to addJSONResponseto the imports fromfastapi.responses.