Commit 2c291bc

[bugfix] layerwise D first plan (#3866)
### What this PR does / why we need it?
Refactors the layerwise code to send each request to the D (decode) node first, preventing P-node hangs caused by communication timeouts when DP > 1.

### Does this PR introduce _any_ user-facing change?
No

### How was this patch tested?
By CI

- vLLM version: v0.11.0
- vLLM main: vllm-project/vllm@83f478b

---------

Signed-off-by: wangxiaoteng <[email protected]>
Signed-off-by: liziyu <[email protected]>
Co-authored-by: liziyu <[email protected]>
1 parent 627f20c commit 2c291bc
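The commit message describes the new dispatch order, and the diff below is long, so here is a minimal, self-contained sketch of the decode-first flow it implements. The names `fake_decoder`, `fake_prefiller`, `handle_completions` and `METASERVER_URL` are illustrative stand-ins, not the shipped code: the real example script uses FastAPI, httpx clients and `proxy_state`. The idea is that the proxy now stashes the request, streams from the decoder (D node) first, and only dispatches to a prefiller (P node) once the decoder calls back `/v1/metaserver`, so a P node never blocks waiting on a decoder that has not been scheduled yet.

```python
import asyncio

# Hypothetical address of the proxy's callback endpoint (illustrative only).
METASERVER_URL = "http://proxy-host:10001/v1/metaserver"

# request_id -> (req_data, request_length, api), mirroring proxy_state.req_data_dict.
req_data_dict = {}


async def fake_prefiller(req_data):
    # Stand-in for a P node: it only runs after the metaserver callback,
    # so it never hangs waiting for a decoder that was not scheduled yet.
    return {"kv_transfer_params": req_data["kv_transfer_params"]}


async def metaserver(kv_transfer_params):
    # Mirrors the reworked /v1/metaserver handler: look up the stashed
    # request and dispatch it to a prefiller only now.
    req_data, _, _ = req_data_dict[kv_transfer_params["request_id"]]
    req_data["kv_transfer_params"] = kv_transfer_params
    await fake_prefiller(req_data)


async def fake_decoder(req_data):
    # Stand-in for a D node: it is contacted first and calls the proxy back
    # once it knows where its KV cache should come from.
    callback = dict(req_data["kv_transfer_params"],
                    request_id=req_data["request_id"])
    await metaserver(callback)
    return "decoded tokens"


async def handle_completions(req_data):
    # Mirrors the reworked _handle_completions: tag and stash the request,
    # then go to the decoder first instead of the prefiller.
    request_id = "cmpl-demo-0"
    req_data["request_id"] = request_id
    req_data_dict[request_id] = (req_data, len(str(req_data)), "/completions")
    req_data["kv_transfer_params"] = {
        "do_remote_decode": False,
        "do_remote_prefill": True,
        "metaserver": METASERVER_URL,
    }
    return await fake_decoder(req_data)


print(asyncio.run(handle_completions({"prompt": "hello"})))
```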

File tree

4 files changed: +942 -1333 lines changed


examples/disaggregated_prefill_v1/load_balance_proxy_layerwise_server_example.py

Lines changed: 73 additions & 179 deletions
@@ -88,26 +88,24 @@
 import asyncio
 import functools
 import heapq
-import json
 import os
 import sys
 import threading
 import uuid
 from contextlib import asynccontextmanager
-from dataclasses import dataclass
-from typing import Any, List
+from typing import List
 
 import httpx
 from fastapi import FastAPI, Request
 from fastapi.responses import StreamingResponse
+from transformers import AutoTokenizer
 from vllm.logger import init_logger
 
 logger = init_logger(__name__)
 
 # Add uvloop for faster event loop if available
 try:
     import uvloop
-
     asyncio.set_event_loop_policy(uvloop.EventLoopPolicy())
 except ImportError:
     pass
@@ -154,6 +152,9 @@ def __init__(self, prefiller_instances, decoder_instances):
         heapq.heapify(self.prefiller_heap)
         heapq.heapify(self.decoder_heap)
         self.req_id_future = {}
+        self.req_data_dict = {}
+        self.tokenizer = AutoTokenizer.from_pretrained(
+            global_args.tokenizer_dir)
 
     def _update_prefiller_priority(self, server_idx: int):
         """Update the priority of a prefiller server in the heap."""
@@ -280,6 +281,10 @@ def parse_args():
                         nargs="+",
                         default=["localhost"])
     parser.add_argument("--decoder-ports", type=int, nargs="+", default=[8002])
+    parser.add_argument("--tokenizer-dir",
+                        type=str,
+                        default="/mnt/weight/Qwen3-235B-A22B-W8A8",
+                        help="Path to the tokenizer directory")
     parser.add_argument("--max-retries",
                         type=int,
                         default=3,
@@ -356,17 +361,6 @@ async def send_request_to_service(client: httpx.AsyncClient,
     aborted_requests = proxy_state.aquire_aborted_prefiller_requests(
         prefiller_id)
     req_data = req_data.copy()
-    req_data['kv_transfer_params'] = {
-        "do_remote_decode": True,
-        "do_remote_prefill": False,
-        "remote_engine_id": None,
-        "remote_block_ids": None,
-        "remote_host": None,
-        "remote_port": None,
-        "aborted_request": list(aborted_requests),
-        "metaserver":
-        f"http://{global_args.host}:{global_args.port}/v1/metaserver"
-    }
     req_data["stream"] = False
     req_data["max_tokens"] = 1
     if "stream_options" in req_data:
@@ -458,180 +452,59 @@ def get_api_request_id(api, req_id):
         return "chatcmpl-" + req_id
 
 
-async def _handle_select_instance(api: str, req_data: Any,
-                                  request_length: int):
-    prefiller_score = proxy_state.calculate_prefill_scores(request_length)
-    logger.debug(
-        f"Request length: {request_length}, Prefiller score: {prefiller_score}"
-    )
-    request_id = await proxy_state.next_req_id()
-    # Select prefiller
-    prefiller_idx = proxy_state.select_prefiller(prefiller_score)
-    prefiller = proxy_state.prefillers[prefiller_idx]
-    result_future = asyncio.Future()  # type: ignore
-    request_id_api = get_api_request_id(api, request_id)
-    proxy_state.req_id_future[request_id_api] = result_future
-    # Send request to prefiller
-    asyncio.get_running_loop().create_task(
-        send_request_to_service(prefiller.client,
-                                prefiller_idx,
-                                api,
-                                req_data,
-                                request_id,
-                                max_retries=global_args.max_retries,
-                                base_delay=global_args.retry_delay))
-    proxy_state.release_prefiller(prefiller_idx, prefiller_score)
-
-    response = await result_future
-    del proxy_state.req_id_future[request_id_api]
-    req_data["kv_transfer_params"] = response
-
-    # Select decoder
-    decoder_score = proxy_state.calculate_decode_scores(request_length)
-    logger.debug("Decoder score: %f", decoder_score)
-    # Use the prefiller's kv_transfer_params to select decoder
-    decoder_idx = proxy_state.select_decoder(decoder_score)
-    decoder = proxy_state.decoders[decoder_idx]
-    logger.debug("Using %s %s", prefiller.url, decoder.url)
-    return InstanceInfo(request_id=request_id,
-                        prefiller_idx=prefiller_idx,
-                        prefiller_score=prefiller_score,
-                        prefiller=prefiller,
-                        decoder=decoder,
-                        decoder_idx=decoder_idx,
-                        decoder_score=decoder_score)
-
-
-@dataclass
-class InstanceInfo:
-    request_id: str
-    prefiller_idx: int
-    prefiller_score: float
-    prefiller: ServerState
-    decoder_idx: int
-    decoder_score: float
-    decoder: ServerState
+def get_origin_request_id(api, req_id):
+    if api == "/completions":
+        return req_id.replace("cmpl-", "").replace("-0", "")
+    elif api == "/chat/completions":
+        return req_id.replace("chatcmpl-", "")
 
 
 async def _handle_completions(api: str, request: Request):
     try:
         req_data = await request.json()
         req_body = await request.body()
         request_length = len(req_body)
-        instance_info = await _handle_select_instance(api, req_data,
-                                                      request_length)
-        stream_flag = bool(req_data.get("stream", False))
-        chat_flag = "messages" in req_data
-
-        if "prompt" in req_data:
-            origin_prompt = req_data["prompt"]
-        elif chat_flag:
-            messages = req_data["messages"]
-            origin_prompt = messages[0].get("content", "")
-        else:
-            origin_prompt = ""
-        # refer to vLLM sampling_params: max_token default value
-        origin_max_tokens = req_data.get("max_tokens", 16)
+        request_id = await proxy_state.next_req_id()
+        request_id_api = get_api_request_id(api, request_id)
+        proxy_state.req_data_dict[request_id_api] = (req_data, request_length,
+                                                     api)
+        req_data['kv_transfer_params'] = {
+            "do_remote_decode":
+            False,
+            "do_remote_prefill":
+            True,
+            "metaserver":
+            f"http://{global_args.host}:{global_args.port}/v1/metaserver"
+        }
+        # Select decoder
+        decoder_score = proxy_state.calculate_decode_scores(request_length)
+        logger.debug("Decoder score: %f", decoder_score)
+        # Use the prefiller's kv_transfer_params to select decoder
+        decoder_idx = proxy_state.select_decoder(decoder_score)
+        decoder = proxy_state.decoders[decoder_idx]
+        # logger.debug("Using %s %s", prefiller.url, decoder.url)
+        # Stream response from decoder
+        released_kv = False
 
         async def generate_stream():
-            nonlocal instance_info
-            generated_token = ""
-            released_kv = False
-            retry_count = 0
-            retry = True
-            completion_tokens = 0
+            nonlocal released_kv
             # Only one await per chunk, minimal logic in loop
             try:
-                while retry:
-                    retry = False
-                    async for chunk in stream_service_response_with_retry(
-                            instance_info.decoder.client,
-                            api,
-                            req_data,
-                            request_id=instance_info.request_id,
-                            max_retries=global_args.max_retries,
-                            base_delay=global_args.retry_delay):
-                        if not released_kv and chunk:
-                            proxy_state.release_prefiller_kv(
-                                instance_info.prefiller_idx,
-                                instance_info.prefiller_score)
-                            released_kv = True
-                        try:
-                            chunk_str = chunk.decode("utf-8").strip()
-                        except UnicodeDecodeError:
-                            logger.debug(
-                                f"Skipping chunk: {chunk}")
-                            yield chunk
-                            continue
-                        if not chunk_str:
-                            continue
-                        if chunk_str.startswith("data: "):
-                            chunk_str = chunk_str[len("data: "):]
-                        try:
-                            chunk_json = json.loads(chunk_str)
-                        except json.JSONDecodeError:
-                            # if chunk is [done], skip it.
-                            logger.debug(
-                                f"Skipping chunk: {chunk_str}")
-                            yield chunk
-                            continue
-                        choices = chunk_json.get("choices", [])
-                        if not choices:
-                            yield chunk
-                            continue
-
-                        choice = choices[0]
-                        delta = choice.get("delta") or {}
-                        message = choice.get("message") or {}
-                        content = (
-                            delta.get("content")
-                            or message.get("content")
-                            or choice.get("text")
-                            or ""
-                        )
-                        generated_token += content
-
-                        stop_reason = choice.get(
-                            "stop_reason")
-                        usage = chunk_json.get("usage", {})
-                        completion_tokens = (completion_tokens + 1) if stream_flag else \
-                            (completion_tokens + usage.get("completion_tokens"))
-                        if stop_reason == "recomputed":
-                            retry = True
-                            retry_count += 1
-                            if chat_flag:
-                                messages[0][
-                                    "content"] = origin_prompt + generated_token
-                            else:
-                                req_data[
-                                    "prompt"] = origin_prompt + generated_token
-                            req_data[
-                                "max_tokens"] = origin_max_tokens - completion_tokens + retry_count
-                            tmp_request_length = len(
-                                json.dumps(req_data).encode("utf-8"))
-                            instance_info = await _handle_select_instance(
-                                api, req_data, tmp_request_length)
-                            break
-                        if retry_count > 0 and not stream_flag:
-                            if chat_flag:
-                                choices[0]["message"][
-                                    "content"] = generated_token
-                            else:
-                                choices[0]["text"] = generated_token
-                            chunk = json.dumps(chunk_json).encode("utf-8")
-                        yield chunk
+                async for chunk in stream_service_response_with_retry(
+                        decoder.client,
+                        api,
+                        req_data,
+                        request_id=request_id,
+                        max_retries=global_args.max_retries,
+                        base_delay=global_args.retry_delay):
+                    yield chunk
             except Exception as e:
                 logger.error(
-                    f"Error during streaming from decoder {instance_info.decoder.url}: {str(e)} the aborted request {instance_info.request_id} will be routing to the target prefiller when new request is ready to dispatch to it"
+                    f"Error during streaming from decoder {decoder.url}: {str(e)} the aborted request {request_id} will be routing to the target prefiller when new request is ready to dispatch to it"
                 )
-                proxy_state.abort_prefiller_request(
-                    instance_info.prefiller_idx, instance_info.request_id)
-                proxy_state.release_prefiller_kv(instance_info.prefiller_idx,
-                                                 instance_info.prefiller_score)
 
         # After streaming done, release tokens
-        proxy_state.release_decoder(instance_info.decoder_idx,
-                                    instance_info.decoder_score)
+        proxy_state.release_decoder(decoder_idx, decoder_score)
 
         return StreamingResponse(generate_stream(),
                                  media_type="application/json")
@@ -669,11 +542,33 @@ async def healthcheck():
 @app.post("/v1/metaserver")
 async def metaserver(request: Request):
     try:
-        req_data = await request.json()
-        request_id = req_data.pop("request_id", None)
-        if request_id in proxy_state.req_id_future:
-            result_future = proxy_state.req_id_future[request_id]
-            result_future.set_result(req_data)
+        kv_transfer_params = await request.json()
+
+        request_id = kv_transfer_params["request_id"]
+        assert request_id in proxy_state.req_data_dict
+        req_data, request_length, api = proxy_state.req_data_dict[request_id]
+        request_id = get_origin_request_id(api, request_id)
+        req_data["kv_transfer_params"] = kv_transfer_params
+        prefiller_score = proxy_state.calculate_prefill_scores(request_length)
+        logger.debug(
+            f"Request length: {request_length}, Prefiller score: {prefiller_score}"
+        )
+
+        # Select prefiller
+        prefiller_idx = proxy_state.select_prefiller(prefiller_score)
+        prefiller = proxy_state.prefillers[prefiller_idx]
+        logger.debug(f"Using prefill {prefiller.url=} {req_data=}")
+        # Send request to prefiller
+        response = await send_request_to_service(
+            prefiller.client,
+            prefiller_idx,
+            api,
+            req_data,
+            request_id,
+            max_retries=global_args.max_retries,
+            base_delay=global_args.retry_delay)
+        proxy_state.release_prefiller(prefiller_idx, prefiller_score)
+
     except Exception as e:
         logger.error(f"Post metaserver failed with: {str(e)}")
 
@@ -682,5 +577,4 @@ async def metaserver(request: Request):
     global global_args
     global_args = parse_args()
     import uvicorn
-
     uvicorn.run(app, host=global_args.host, port=global_args.port)
