Commit dd8ea3e

refactoring layerwise_proxy

Signed-off-by: liziyu <[email protected]>

1 parent 0a9bc7a commit dd8ea3e


examples/disaggregated_prefill_v1/load_balance_proxy_layerwise_server_example.py

Lines changed: 119 additions & 180 deletions
@@ -88,26 +88,24 @@
 import asyncio
 import functools
 import heapq
-import json
 import os
 import sys
 import threading
 import uuid
 from contextlib import asynccontextmanager
-from dataclasses import dataclass
-from typing import Any, List
+from typing import List

 import httpx
 from fastapi import FastAPI, Request
 from fastapi.responses import StreamingResponse
+from transformers import AutoTokenizer
 from vllm.logger import init_logger

 logger = init_logger(__name__)

 # Add uvloop for faster event loop if available
 try:
     import uvloop
-
     asyncio.set_event_loop_policy(uvloop.EventLoopPolicy())
 except ImportError:
     pass
@@ -154,6 +152,9 @@ def __init__(self, prefiller_instances, decoder_instances):
         heapq.heapify(self.prefiller_heap)
         heapq.heapify(self.decoder_heap)
         self.req_id_future = {}
+        self.req_data_dict = {}
+        self.tokenizer = AutoTokenizer.from_pretrained(
+            global_args.tokenizer_dir)

     def _update_prefiller_priority(self, server_idx: int):
         """Update the priority of a prefiller server in the heap."""
@@ -280,6 +281,10 @@ def parse_args():
                         nargs="+",
                         default=["localhost"])
     parser.add_argument("--decoder-ports", type=int, nargs="+", default=[8002])
+    parser.add_argument("--tokenizer-dir",
+                        type=str,
+                        default="/mnt/weight/Qwen3-235B-A22B-W8A8",
+                        help="Path to the tokenizer directory")
     parser.add_argument("--max-retries",
                         type=int,
                         default=3,
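Note: the new --tokenizer-dir argument feeds the AutoTokenizer created in the ProxyState constructor above, but in this commit the tokenizer is only referenced from commented-out code in the metaserver handler (decoding kv_transfer_params["token_ids"] back into a prompt). A minimal sketch of that intended use, where the token ids are made-up placeholders:

from transformers import AutoTokenizer

# Path matches the commit's default; any local tokenizer directory works.
tokenizer = AutoTokenizer.from_pretrained("/mnt/weight/Qwen3-235B-A22B-W8A8")
# Hypothetical stand-in for kv_transfer_params["token_ids"]:
prompt = tokenizer.decode([9707, 11, 1879])
print(prompt)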
@@ -356,17 +361,16 @@ async def send_request_to_service(client: httpx.AsyncClient,
     aborted_requests = proxy_state.aquire_aborted_prefiller_requests(
         prefiller_id)
     req_data = req_data.copy()
-    req_data['kv_transfer_params'] = {
-        "do_remote_decode": True,
-        "do_remote_prefill": False,
-        "remote_engine_id": None,
-        "remote_block_ids": None,
-        "remote_host": None,
-        "remote_port": None,
-        "aborted_request": list(aborted_requests),
-        "metaserver":
-        f"http://{global_args.host}:{global_args.port}/v1/metaserver"
-    }
+    # req_data['kv_transfer_params'] = {
+    #     "do_remote_decode": True,
+    #     "do_remote_prefill": False,
+    #     "remote_engine_id": None,
+    #     "remote_block_ids": None,
+    #     "remote_host": None,
+    #     "remote_port": None,
+    #     "aborted_request": list(aborted_requests),
+    #     "metaserver": f"http://{global_args.host}:{global_args.port}/v1/metaserver"
+    # }
     req_data["stream"] = False
     req_data["max_tokens"] = 1
     if "stream_options" in req_data:
@@ -458,180 +462,89 @@ def get_api_request_id(api, req_id):
         return "chatcmpl-" + req_id


-async def _handle_select_instance(api: str, req_data: Any,
-                                  request_length: int):
-    prefiller_score = proxy_state.calculate_prefill_scores(request_length)
-    logger.debug(
-        f"Request length: {request_length}, Prefiller score: {prefiller_score}"
-    )
-    request_id = await proxy_state.next_req_id()
-    # Select prefiller
-    prefiller_idx = proxy_state.select_prefiller(prefiller_score)
-    prefiller = proxy_state.prefillers[prefiller_idx]
-    result_future = asyncio.Future()  # type: ignore
-    request_id_api = get_api_request_id(api, request_id)
-    proxy_state.req_id_future[request_id_api] = result_future
-    # Send request to prefiller
-    asyncio.get_running_loop().create_task(
-        send_request_to_service(prefiller.client,
-                                prefiller_idx,
-                                api,
-                                req_data,
-                                request_id,
-                                max_retries=global_args.max_retries,
-                                base_delay=global_args.retry_delay))
-    proxy_state.release_prefiller(prefiller_idx, prefiller_score)
-
-    response = await result_future
-    del proxy_state.req_id_future[request_id_api]
-    req_data["kv_transfer_params"] = response
-
-    # Select decoder
-    decoder_score = proxy_state.calculate_decode_scores(request_length)
-    logger.debug("Decoder score: %f", decoder_score)
-    # Use the prefiller's kv_transfer_params to select decoder
-    decoder_idx = proxy_state.select_decoder(decoder_score)
-    decoder = proxy_state.decoders[decoder_idx]
-    logger.debug("Using %s %s", prefiller.url, decoder.url)
-    return InstanceInfo(request_id=request_id,
-                        prefiller_idx=prefiller_idx,
-                        prefiller_score=prefiller_score,
-                        prefiller=prefiller,
-                        decoder=decoder,
-                        decoder_idx=decoder_idx,
-                        decoder_score=decoder_score)
-
-
-@dataclass
-class InstanceInfo:
-    request_id: str
-    prefiller_idx: int
-    prefiller_score: float
-    prefiller: ServerState
-    decoder_idx: int
-    decoder_score: float
-    decoder: ServerState
+def get_origin_request_id(api, req_id):
+    if api == "/completions":
+        return req_id.replace("cmpl-", "").replace("-0", "")
+    elif api == "/chat/completions":
+        return req_id.replace("chatcmpl-", "")


 async def _handle_completions(api: str, request: Request):
     try:
         req_data = await request.json()
         req_body = await request.body()
         request_length = len(req_body)
-        instance_info = await _handle_select_instance(api, req_data,
-                                                      request_length)
-        stream_flag = bool(req_data.get("stream", False))
-        chat_flag = "messages" in req_data
-
-        if "prompt" in req_data:
-            origin_prompt = req_data["prompt"]
-        elif chat_flag:
-            messages = req_data["messages"]
-            origin_prompt = messages[0].get("content", "")
-        else:
-            origin_prompt = ""
-        # refer to vLLM sampling_params: max_token default value
-        origin_max_tokens = req_data.get("max_tokens", 16)
+        # prefiller_score = proxy_state.calculate_prefill_scores(request_length)
+        # logger.debug(
+        #     f"Request length: {request_length}, Prefiller score: {prefiller_score}"
+        # )
+        request_id = await proxy_state.next_req_id()
+        request_id_api = get_api_request_id(api, request_id)
+        proxy_state.req_data_dict[request_id_api] = (req_data, request_length,
+                                                     api)
+        # # Select prefiller
+        # prefiller_idx = proxy_state.select_prefiller(prefiller_score)
+        # prefiller = proxy_state.prefillers[prefiller_idx]
+        # result_future = asyncio.Future()  # type: ignore
+        # proxy_state.req_id_future[request_id_api] = result_future
+        # # Send request to prefiller
+        # asyncio.get_running_loop().create_task(send_request_to_service(
+        #     prefiller.client,
+        #     prefiller_idx,
+        #     api,
+        #     req_data,
+        #     request_id,
+        #     max_retries=global_args.max_retries,
+        #     base_delay=global_args.retry_delay))
+        # proxy_state.release_prefiller(prefiller_idx, prefiller_score)

+        # response = await result_future
+        # del proxy_state.req_id_future[request_id_api]
+        # req_data["kv_transfer_params"] = response
+        req_data['kv_transfer_params'] = {
+            "do_remote_decode":
+            False,
+            "do_remote_prefill":
+            True,
+            "metaserver":
+            f"http://{global_args.host}:{global_args.port}/v1/metaserver"
+        }
+        # Select decoder
+        decoder_score = proxy_state.calculate_decode_scores(request_length)
+        logger.debug("Decoder score: %f", decoder_score)
+        # Use the prefiller's kv_transfer_params to select decoder
+        decoder_idx = proxy_state.select_decoder(decoder_score)
+        decoder = proxy_state.decoders[decoder_idx]
+        # logger.debug("Using %s %s", prefiller.url, decoder.url)
+        # Stream response from decoder
+        released_kv = False

         async def generate_stream():
-            nonlocal instance_info
-            generated_token = ""
-            released_kv = False
-            retry_count = 0
-            retry = True
-            completion_tokens = 0
+            nonlocal released_kv
             # Only one await per chunk, minimal logic in loop
             try:
-                while retry:
-                    retry = False
-                    async for chunk in stream_service_response_with_retry(
-                            instance_info.decoder.client,
-                            api,
-                            req_data,
-                            request_id=instance_info.request_id,
-                            max_retries=global_args.max_retries,
-                            base_delay=global_args.retry_delay):
-                        if not released_kv and chunk:
-                            proxy_state.release_prefiller_kv(
-                                instance_info.prefiller_idx,
-                                instance_info.prefiller_score)
-                            released_kv = True
-                        try:
-                            chunk_str = chunk.decode("utf-8").strip()
-                        except UnicodeDecodeError:
-                            logger.debug(
-                                f"Skipping chunk: {chunk}")
-                            yield chunk
-                            continue
-                        if not chunk_str:
-                            continue
-                        if chunk_str.startswith("data: "):
-                            chunk_str = chunk_str[len("data: "):]
-                        try:
-                            chunk_json = json.loads(chunk_str)
-                        except json.JSONDecodeError:
-                            # if chunk is [done], skip it.
-                            logger.debug(
-                                f"Skipping chunk: {chunk_str}")
-                            yield chunk
-                            continue
-                        choices = chunk_json.get("choices", [])
-                        if not choices:
-                            yield chunk
-                            continue
-
-                        choice = choices[0]
-                        delta = choice.get("delta") or {}
-                        message = choice.get("message") or {}
-                        content = (
-                            delta.get("content")
-                            or message.get("content")
-                            or choice.get("text")
-                            or ""
-                        )
-                        generated_token += content
-
-                        stop_reason = choice.get(
-                            "stop_reason")
-                        usage = chunk_json.get("usage", {})
-                        completion_tokens = (completion_tokens + 1) if stream_flag else \
-                            (completion_tokens + usage.get("completion_tokens"))
-                        if stop_reason == "recomputed":
-                            retry = True
-                            retry_count += 1
-                            if chat_flag:
-                                messages[0][
-                                    "content"] = origin_prompt + generated_token
-                            else:
-                                req_data[
-                                    "prompt"] = origin_prompt + generated_token
-                            req_data[
-                                "max_tokens"] = origin_max_tokens - completion_tokens + retry_count
-                            tmp_request_length = len(
-                                json.dumps(req_data).encode("utf-8"))
-                            instance_info = await _handle_select_instance(
-                                api, req_data, tmp_request_length)
-                            break
-                        if retry_count > 0 and not stream_flag:
-                            if chat_flag:
-                                choices[0]["message"][
-                                    "content"] = generated_token
-                            else:
-                                choices[0]["text"] = generated_token
-                            chunk = json.dumps(chunk_json).encode("utf-8")
-                        yield chunk
+                async for chunk in stream_service_response_with_retry(
+                        decoder.client,
+                        api,
+                        req_data,
+                        request_id=request_id,
+                        max_retries=global_args.max_retries,
+                        base_delay=global_args.retry_delay):
+                    # if not released_kv and chunk:
+                    #     proxy_state.release_prefiller_kv(
+                    #         prefiller_idx, prefiller_score)
+                    #     released_kv = True
+                    yield chunk
             except Exception as e:
                 logger.error(
-                    f"Error during streaming from decoder {instance_info.decoder.url}: {str(e)} the aborted request {instance_info.request_id} will be routing to the target prefiller when new request is ready to dispatch to it"
+                    f"Error during streaming from decoder {decoder.url}: {str(e)}; the aborted request {request_id} will be routed to the target prefiller when a new request is ready to dispatch to it"
                 )
-                proxy_state.abort_prefiller_request(
-                    instance_info.prefiller_idx, instance_info.request_id)
-                proxy_state.release_prefiller_kv(instance_info.prefiller_idx,
-                                                 instance_info.prefiller_score)
+                # proxy_state.abort_prefiller_request(prefiller_idx, request_id)
+                # proxy_state.release_prefiller_kv(prefiller_idx,
+                #                                  prefiller_score)

         # After streaming done, release tokens
-        proxy_state.release_decoder(instance_info.decoder_idx,
-                                    instance_info.decoder_score)
+        proxy_state.release_decoder(decoder_idx, decoder_score)

         return StreamingResponse(generate_stream(),
                                  media_type="application/json")
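Note: this hunk inverts the dispatch order. The proxy now records (req_data, request_length, api) under the API-specific request id, tags the request with do_remote_prefill=True plus a metaserver callback URL, and streams straight from a decoder; a prefiller is only engaged later, from the /v1/metaserver handler below. Client-facing behavior is unchanged, so a smoke test is still an ordinary completion request. A minimal sketch, assuming the proxy listens on localhost:10001 and the served model name shown (both placeholders):

import httpx

def smoke_test():
    payload = {
        "model": "Qwen3-235B-A22B-W8A8",  # placeholder served-model name
        "prompt": "Hello",
        "max_tokens": 16,
        "stream": True,
    }
    # Stream the decoder's chunks back through the proxy, as generate_stream does.
    with httpx.Client(base_url="http://localhost:10001") as client:
        with client.stream("POST", "/v1/completions", json=payload) as resp:
            for chunk in resp.iter_bytes():
                print(chunk.decode("utf-8", errors="replace"), end="")

if __name__ == "__main__":
    smoke_test()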
@@ -669,11 +582,38 @@ async def healthcheck():
 @app.post("/v1/metaserver")
 async def metaserver(request: Request):
     try:
-        req_data = await request.json()
-        request_id = req_data.pop("request_id", None)
-        if request_id in proxy_state.req_id_future:
-            result_future = proxy_state.req_id_future[request_id]
-            result_future.set_result(req_data)
+        kv_transfer_params = await request.json()
+
+        request_id = kv_transfer_params["request_id"]
+        assert request_id in proxy_state.req_data_dict
+        req_data, request_length, api = proxy_state.req_data_dict[request_id]
+        # output_prompt = proxy_state.tokenizer.decode(kv_transfer_params["token_ids"])
+        # req_data["prompt"] = output_prompt
+        # del kv_transfer_params['token_ids']
+        request_id = get_origin_request_id(api, request_id)
+        req_data["kv_transfer_params"] = kv_transfer_params
+        prefiller_score = proxy_state.calculate_prefill_scores(request_length)
+        logger.debug(
+            f"Request length: {request_length}, Prefiller score: {prefiller_score}"
+        )
+
+        # Select prefiller
+        prefiller_idx = proxy_state.select_prefiller(prefiller_score)
+        prefiller = proxy_state.prefillers[prefiller_idx]
+        logger.debug(f"Using prefill {prefiller.url=} {req_data=}")
+        # Send request to prefiller
+        response = await send_request_to_service(
+            prefiller.client,
+            prefiller_idx,
+            api,
+            req_data,
+            request_id,
+            max_retries=global_args.max_retries,
+            base_delay=global_args.retry_delay)
+        proxy_state.release_prefiller(prefiller_idx, prefiller_score)
+
+        # del req_data["prompt"]
+
     except Exception as e:
         logger.error(f"Post metaserver failed with: {str(e)}")
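Note: the rewritten handler above only assumes the callback body is JSON whose "request_id" matches an entry in req_data_dict; everything else is forwarded to the prefiller as kv_transfer_params. A hedged sketch of such a callback, where every field except request_id is a made-up placeholder rather than the confirmed connector protocol:

import httpx

def notify_metaserver():
    # The id must be the API-specific one the proxy stored, e.g. for
    # /completions it looks like "cmpl-<req_id>-0" per get_origin_request_id.
    kv_transfer_params = {
        "request_id": "cmpl-<uuid>-0",
        "do_remote_prefill": True,      # placeholder kv-transfer fields
        "remote_block_ids": [0, 1, 2],  # placeholder
    }
    httpx.post("http://localhost:10001/v1/metaserver", json=kv_transfer_params)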

@@ -682,5 +622,4 @@ async def metaserver(request: Request):
     global global_args
     global_args = parse_args()
     import uvicorn
-
-    uvicorn.run(app, host=global_args.host, port=global_args.port)
+    uvicorn.run(app, host=global_args.host, port=global_args.port)
