 import asyncio
 import functools
 import heapq
-import json
 import os
 import sys
 import threading
 import uuid
 from contextlib import asynccontextmanager
-from dataclasses import dataclass
-from typing import Any, List
+from typing import List
 
 import httpx
 from fastapi import FastAPI, Request
 from fastapi.responses import StreamingResponse
+from transformers import AutoTokenizer
 from vllm.logger import init_logger
 
 logger = init_logger(__name__)
 
 # Add uvloop for faster event loop if available
 try:
     import uvloop
-
     asyncio.set_event_loop_policy(uvloop.EventLoopPolicy())
 except ImportError:
     pass
@@ -154,6 +152,9 @@ def __init__(self, prefiller_instances, decoder_instances):
         heapq.heapify(self.prefiller_heap)
         heapq.heapify(self.decoder_heap)
         self.req_id_future = {}
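+        # Maps API-prefixed request id -> (req_data, request_length, api),
+        # buffered until the decoder's metaserver callback arrives.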
+        self.req_data_dict = {}
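+        # Tokenizer for the served model, loaded from --tokenizer-dir.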
+        self.tokenizer = AutoTokenizer.from_pretrained(
+            global_args.tokenizer_dir)
 
     def _update_prefiller_priority(self, server_idx: int):
         """Update the priority of a prefiller server in the heap."""
@@ -280,6 +281,10 @@ def parse_args():
                         nargs="+",
                         default=["localhost"])
     parser.add_argument("--decoder-ports", type=int, nargs="+", default=[8002])
+    parser.add_argument("--tokenizer-dir",
+                        type=str,
+                        default="/mnt/weight/Qwen3-235B-A22B-W8A8",
+                        help="Path to the tokenizer directory")
     parser.add_argument("--max-retries",
                         type=int,
                         default=3,
@@ -356,17 +361,6 @@ async def send_request_to_service(client: httpx.AsyncClient,
     aborted_requests = proxy_state.aquire_aborted_prefiller_requests(
         prefiller_id)
     req_data = req_data.copy()
-    req_data['kv_transfer_params'] = {
-        "do_remote_decode": True,
-        "do_remote_prefill": False,
-        "remote_engine_id": None,
-        "remote_block_ids": None,
-        "remote_host": None,
-        "remote_port": None,
-        "aborted_request": list(aborted_requests),
-        "metaserver":
-        f"http://{global_args.host}:{global_args.port}/v1/metaserver"
-    }
     req_data["stream"] = False
     req_data["max_tokens"] = 1
     if "stream_options" in req_data:
@@ -458,180 +452,59 @@ def get_api_request_id(api, req_id):
     return "chatcmpl-" + req_id
 
 
-async def _handle_select_instance(api: str, req_data: Any,
-                                  request_length: int):
-    prefiller_score = proxy_state.calculate_prefill_scores(request_length)
-    logger.debug(
-        f"Request length: {request_length}, Prefiller score: {prefiller_score}"
-    )
-    request_id = await proxy_state.next_req_id()
-    # Select prefiller
-    prefiller_idx = proxy_state.select_prefiller(prefiller_score)
-    prefiller = proxy_state.prefillers[prefiller_idx]
-    result_future = asyncio.Future()  # type: ignore
-    request_id_api = get_api_request_id(api, request_id)
-    proxy_state.req_id_future[request_id_api] = result_future
-    # Send request to prefiller
-    asyncio.get_running_loop().create_task(
-        send_request_to_service(prefiller.client,
-                                prefiller_idx,
-                                api,
-                                req_data,
-                                request_id,
-                                max_retries=global_args.max_retries,
-                                base_delay=global_args.retry_delay))
-    proxy_state.release_prefiller(prefiller_idx, prefiller_score)
-
-    response = await result_future
-    del proxy_state.req_id_future[request_id_api]
-    req_data["kv_transfer_params"] = response
-
-    # Select decoder
-    decoder_score = proxy_state.calculate_decode_scores(request_length)
-    logger.debug("Decoder score: %f", decoder_score)
-    # Use the prefiller's kv_transfer_params to select decoder
-    decoder_idx = proxy_state.select_decoder(decoder_score)
-    decoder = proxy_state.decoders[decoder_idx]
-    logger.debug("Using %s %s", prefiller.url, decoder.url)
-    return InstanceInfo(request_id=request_id,
-                        prefiller_idx=prefiller_idx,
-                        prefiller_score=prefiller_score,
-                        prefiller=prefiller,
-                        decoder=decoder,
-                        decoder_idx=decoder_idx,
-                        decoder_score=decoder_score)
-
-
-@dataclass
-class InstanceInfo:
-    request_id: str
-    prefiller_idx: int
-    prefiller_score: float
-    prefiller: ServerState
-    decoder_idx: int
-    decoder_score: float
-    decoder: ServerState
+def get_origin_request_id(api, req_id):
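+    # Inverse of get_api_request_id: strip the "cmpl-"/"chatcmpl-" prefix
+    # (and the "-0" choice suffix on completion ids).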
+    if api == "/completions":
+        return req_id.replace("cmpl-", "").replace("-0", "")
+    elif api == "/chat/completions":
+        return req_id.replace("chatcmpl-", "")
 
 
 async def _handle_completions(api: str, request: Request):
     try:
         req_data = await request.json()
         req_body = await request.body()
         request_length = len(req_body)
-        instance_info = await _handle_select_instance(api, req_data,
-                                                      request_length)
-        stream_flag = bool(req_data.get("stream", False))
-        chat_flag = "messages" in req_data
-
-        if "prompt" in req_data:
-            origin_prompt = req_data["prompt"]
-        elif chat_flag:
-            messages = req_data["messages"]
-            origin_prompt = messages[0].get("content", "")
-        else:
-            origin_prompt = ""
-        # refer to vLLM sampling_params: max_token default value
-        origin_max_tokens = req_data.get("max_tokens", 16)
+        request_id = await proxy_state.next_req_id()
+        request_id_api = get_api_request_id(api, request_id)
+        proxy_state.req_data_dict[request_id_api] = (req_data, request_length,
+                                                     api)
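+        # Decode-first flow: ask the decoder to pull the prefill remotely;
+        # it reports back via /v1/metaserver, which then dispatches the prefill.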
+        req_data['kv_transfer_params'] = {
+            "do_remote_decode": False,
+            "do_remote_prefill": True,
+            "metaserver":
+            f"http://{global_args.host}:{global_args.port}/v1/metaserver"
+        }
+        # Select decoder
+        decoder_score = proxy_state.calculate_decode_scores(request_length)
+        logger.debug("Decoder score: %f", decoder_score)
+        decoder_idx = proxy_state.select_decoder(decoder_score)
+        decoder = proxy_state.decoders[decoder_idx]
+        # Stream response from decoder
+        released_kv = False
 
         async def generate_stream():
-            nonlocal instance_info
-            generated_token = ""
-            released_kv = False
-            retry_count = 0
-            retry = True
-            completion_tokens = 0
+            nonlocal released_kv
             # Only one await per chunk, minimal logic in loop
             try:
-                while retry:
-                    retry = False
-                    async for chunk in stream_service_response_with_retry(
-                            instance_info.decoder.client,
-                            api,
-                            req_data,
-                            request_id=instance_info.request_id,
-                            max_retries=global_args.max_retries,
-                            base_delay=global_args.retry_delay):
-                        if not released_kv and chunk:
-                            proxy_state.release_prefiller_kv(
-                                instance_info.prefiller_idx,
-                                instance_info.prefiller_score)
-                            released_kv = True
-                        try:
-                            chunk_str = chunk.decode("utf-8").strip()
-                        except UnicodeDecodeError:
-                            logger.debug(
-                                f"Skipping chunk: {chunk}")
-                            yield chunk
-                            continue
-                        if not chunk_str:
-                            continue
-                        if chunk_str.startswith("data: "):
-                            chunk_str = chunk_str[len("data: "):]
-                        try:
-                            chunk_json = json.loads(chunk_str)
-                        except json.JSONDecodeError:
-                            # if chunk is [done], skip it.
-                            logger.debug(
-                                f"Skipping chunk: {chunk_str}")
-                            yield chunk
-                            continue
-                        choices = chunk_json.get("choices", [])
-                        if not choices:
-                            yield chunk
-                            continue
-
-                        choice = choices[0]
-                        delta = choice.get("delta") or {}
-                        message = choice.get("message") or {}
-                        content = (
-                            delta.get("content")
-                            or message.get("content")
-                            or choice.get("text")
-                            or ""
-                        )
-                        generated_token += content
-
-                        stop_reason = choice.get(
-                            "stop_reason")
-                        usage = chunk_json.get("usage", {})
-                        completion_tokens = (completion_tokens + 1) if stream_flag else \
-                            (completion_tokens + usage.get("completion_tokens"))
-                        if stop_reason == "recomputed":
-                            retry = True
-                            retry_count += 1
-                            if chat_flag:
-                                messages[0][
-                                    "content"] = origin_prompt + generated_token
-                            else:
-                                req_data[
-                                    "prompt"] = origin_prompt + generated_token
-                            req_data[
-                                "max_tokens"] = origin_max_tokens - completion_tokens + retry_count
-                            tmp_request_length = len(
-                                json.dumps(req_data).encode("utf-8"))
-                            instance_info = await _handle_select_instance(
-                                api, req_data, tmp_request_length)
-                            break
-                    if retry_count > 0 and not stream_flag:
-                        if chat_flag:
-                            choices[0]["message"][
-                                "content"] = generated_token
-                        else:
-                            choices[0]["text"] = generated_token
-                        chunk = json.dumps(chunk_json).encode("utf-8")
-                    yield chunk
+                async for chunk in stream_service_response_with_retry(
+                        decoder.client,
+                        api,
+                        req_data,
+                        request_id=request_id,
+                        max_retries=global_args.max_retries,
+                        base_delay=global_args.retry_delay):
+                    yield chunk
             except Exception as e:
                 logger.error(
-                    f"Error during streaming from decoder {instance_info.decoder.url}: {str(e)} the aborted request {instance_info.request_id} will be routing to the target prefiller when new request is ready to dispatch to it"
+                    f"Error during streaming from decoder {decoder.url}: {str(e)}; the aborted request {request_id} will be routed to the target prefiller when a new request is ready to dispatch to it"
                 )
-                proxy_state.abort_prefiller_request(
-                    instance_info.prefiller_idx, instance_info.request_id)
-                proxy_state.release_prefiller_kv(instance_info.prefiller_idx,
-                                                 instance_info.prefiller_score)
 
             # After streaming done, release tokens
-            proxy_state.release_decoder(instance_info.decoder_idx,
-                                        instance_info.decoder_score)
+            proxy_state.release_decoder(decoder_idx, decoder_score)
 
         return StreamingResponse(generate_stream(),
                                  media_type="application/json")
@@ -669,11 +542,33 @@ async def healthcheck():
 @app.post("/v1/metaserver")
 async def metaserver(request: Request):
     try:
-        req_data = await request.json()
-        request_id = req_data.pop("request_id", None)
-        if request_id in proxy_state.req_id_future:
-            result_future = proxy_state.req_id_future[request_id]
-            result_future.set_result(req_data)
+        kv_transfer_params = await request.json()
+
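+        # Callback from the decoder: look up the buffered request and only now
+        # dispatch the actual prefill.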
+        request_id = kv_transfer_params["request_id"]
+        assert request_id in proxy_state.req_data_dict
+        # Pop the entry so req_data_dict does not grow unboundedly.
+        req_data, request_length, api = proxy_state.req_data_dict.pop(
+            request_id)
+        request_id = get_origin_request_id(api, request_id)
+        req_data["kv_transfer_params"] = kv_transfer_params
+        prefiller_score = proxy_state.calculate_prefill_scores(request_length)
+        logger.debug(
+            f"Request length: {request_length}, Prefiller score: {prefiller_score}"
+        )
+
+        # Select prefiller
+        prefiller_idx = proxy_state.select_prefiller(prefiller_score)
+        prefiller = proxy_state.prefillers[prefiller_idx]
+        logger.debug(f"Using prefill {prefiller.url=} {req_data=}")
+        # Send request to prefiller
+        await send_request_to_service(
+            prefiller.client,
+            prefiller_idx,
+            api,
+            req_data,
+            request_id,
+            max_retries=global_args.max_retries,
+            base_delay=global_args.retry_delay)
+        proxy_state.release_prefiller(prefiller_idx, prefiller_score)
+
     except Exception as e:
         logger.error(f"Post metaserver failed with: {str(e)}")
 
@@ -682,5 +577,4 @@ async def metaserver(request: Request):
     global global_args
     global_args = parse_args()
     import uvicorn
-
     uvicorn.run(app, host=global_args.host, port=global_args.port)