 import asyncio
 import functools
 import heapq
-import json
 import os
 import sys
 import threading
 import uuid
 from contextlib import asynccontextmanager
-from dataclasses import dataclass
-from typing import Any, List
+from typing import List
 
 import httpx
 from fastapi import FastAPI, Request
 from fastapi.responses import StreamingResponse
+from transformers import AutoTokenizer
 from vllm.logger import init_logger
 
 logger = init_logger(__name__)
 
 # Add uvloop for faster event loop if available
 try:
     import uvloop
-
     asyncio.set_event_loop_policy(uvloop.EventLoopPolicy())
 except ImportError:
     pass
@@ -154,6 +152,9 @@ def __init__(self, prefiller_instances, decoder_instances):
         heapq.heapify(self.prefiller_heap)
         heapq.heapify(self.decoder_heap)
         self.req_id_future = {}
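+        # Pending request bookkeeping: request_id -> (req_data, request_length,
+        # api), consumed when the decoder reports back via /v1/metaserver.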
+        self.req_data_dict = {}
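+        # Tokenizer for the served model, used to decode token ids returned in
+        # the metaserver callback (the decode call there is currently commented
+        # out).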
+        self.tokenizer = AutoTokenizer.from_pretrained(
+            global_args.tokenizer_dir)
 
     def _update_prefiller_priority(self, server_idx: int):
         """Update the priority of a prefiller server in the heap."""
@@ -280,6 +281,10 @@ def parse_args():
                         nargs="+",
                         default=["localhost"])
     parser.add_argument("--decoder-ports", type=int, nargs="+", default=[8002])
+    parser.add_argument("--tokenizer-dir",
+                        type=str,
+                        default="/mnt/weight/Qwen3-235B-A22B-W8A8",
+                        help="Path to the tokenizer directory")
     parser.add_argument("--max-retries",
                         type=int,
                         default=3,
@@ -356,17 +361,16 @@ async def send_request_to_service(client: httpx.AsyncClient,
     aborted_requests = proxy_state.aquire_aborted_prefiller_requests(
         prefiller_id)
     req_data = req_data.copy()
-    req_data['kv_transfer_params'] = {
-        "do_remote_decode": True,
-        "do_remote_prefill": False,
-        "remote_engine_id": None,
-        "remote_block_ids": None,
-        "remote_host": None,
-        "remote_port": None,
-        "aborted_request": list(aborted_requests),
-        "metaserver":
-        f"http://{global_args.host}:{global_args.port}/v1/metaserver"
-    }
+    # req_data['kv_transfer_params'] = {
+    #     "do_remote_decode": True,
+    #     "do_remote_prefill": False,
+    #     "remote_engine_id": None,
+    #     "remote_block_ids": None,
+    #     "remote_host": None,
+    #     "remote_port": None,
+    #     "aborted_request": list(aborted_requests),
+    #     "metaserver": f"http://{global_args.host}:{global_args.port}/v1/metaserver"
+    # }
     req_data["stream"] = False
     req_data["max_tokens"] = 1
     if "stream_options" in req_data:
@@ -458,180 +462,89 @@ def get_api_request_id(api, req_id):
         return "chatcmpl-" + req_id
 
 
-async def _handle_select_instance(api: str, req_data: Any,
-                                  request_length: int):
-    prefiller_score = proxy_state.calculate_prefill_scores(request_length)
-    logger.debug(
-        f"Request length: {request_length}, Prefiller score: {prefiller_score}"
-    )
-    request_id = await proxy_state.next_req_id()
-    # Select prefiller
-    prefiller_idx = proxy_state.select_prefiller(prefiller_score)
-    prefiller = proxy_state.prefillers[prefiller_idx]
-    result_future = asyncio.Future()  # type: ignore
-    request_id_api = get_api_request_id(api, request_id)
-    proxy_state.req_id_future[request_id_api] = result_future
-    # Send request to prefiller
-    asyncio.get_running_loop().create_task(
-        send_request_to_service(prefiller.client,
-                                prefiller_idx,
-                                api,
-                                req_data,
-                                request_id,
-                                max_retries=global_args.max_retries,
-                                base_delay=global_args.retry_delay))
-    proxy_state.release_prefiller(prefiller_idx, prefiller_score)
-
-    response = await result_future
-    del proxy_state.req_id_future[request_id_api]
-    req_data["kv_transfer_params"] = response
-
-    # Select decoder
-    decoder_score = proxy_state.calculate_decode_scores(request_length)
-    logger.debug("Decoder score: %f", decoder_score)
-    # Use the prefiller's kv_transfer_params to select decoder
-    decoder_idx = proxy_state.select_decoder(decoder_score)
-    decoder = proxy_state.decoders[decoder_idx]
-    logger.debug("Using %s %s", prefiller.url, decoder.url)
-    return InstanceInfo(request_id=request_id,
-                        prefiller_idx=prefiller_idx,
-                        prefiller_score=prefiller_score,
-                        prefiller=prefiller,
-                        decoder=decoder,
-                        decoder_idx=decoder_idx,
-                        decoder_score=decoder_score)
-
-
-@dataclass
-class InstanceInfo:
-    request_id: str
-    prefiller_idx: int
-    prefiller_score: float
-    prefiller: ServerState
-    decoder_idx: int
-    decoder_score: float
-    decoder: ServerState
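+# Inverse of get_api_request_id: strip the "cmpl-"/"chatcmpl-" decoration to
+# recover the proxy-generated request id before forwarding to the prefiller.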
+def get_origin_request_id(api, req_id):
+    if api == "/completions":
+        return req_id.replace("cmpl-", "").replace("-0", "")
+    elif api == "/chat/completions":
+        return req_id.replace("chatcmpl-", "")
 
 
 async def _handle_completions(api: str, request: Request):
     try:
         req_data = await request.json()
         req_body = await request.body()
         request_length = len(req_body)
-        instance_info = await _handle_select_instance(api, req_data,
-                                                      request_length)
-        stream_flag = bool(req_data.get("stream", False))
-        chat_flag = "messages" in req_data
-
-        if "prompt" in req_data:
-            origin_prompt = req_data["prompt"]
-        elif chat_flag:
-            messages = req_data["messages"]
-            origin_prompt = messages[0].get("content", "")
-        else:
-            origin_prompt = ""
-        # refer to vLLM sampling_params: max_token default value
-        origin_max_tokens = req_data.get("max_tokens", 16)
+        # prefiller_score = proxy_state.calculate_prefill_scores(request_length)
+        # logger.debug(
+        #     f"Request length: {request_length}, Prefiller score: {prefiller_score}"
+        # )
+        request_id = await proxy_state.next_req_id()
+        request_id_api = get_api_request_id(api, request_id)
+        proxy_state.req_data_dict[request_id_api] = (req_data, request_length,
+                                                     api)
+        # # Select prefiller
+        # prefiller_idx = proxy_state.select_prefiller(prefiller_score)
+        # prefiller = proxy_state.prefillers[prefiller_idx]
+        # result_future = asyncio.Future()  # type: ignore
+        # proxy_state.req_id_future[request_id_api] = result_future
+        # # Send request to prefiller
+        # asyncio.get_running_loop().create_task(send_request_to_service(
+        #     prefiller.client,
+        #     prefiller_idx,
+        #     api,
+        #     req_data,
+        #     request_id,
+        #     max_retries=global_args.max_retries,
+        #     base_delay=global_args.retry_delay))
+        # proxy_state.release_prefiller(prefiller_idx, prefiller_score)
+
+        # response = await result_future
+        # del proxy_state.req_id_future[request_id_api]
+        # req_data["kv_transfer_params"] = response
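+        # Decoder-first dispatch: send the request straight to a decoder with
+        # do_remote_prefill=True; the decoder side calls back into
+        # /v1/metaserver below, which then schedules the actual prefill.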
+        req_data['kv_transfer_params'] = {
+            "do_remote_decode": False,
+            "do_remote_prefill": True,
+            "metaserver":
+            f"http://{global_args.host}:{global_args.port}/v1/metaserver"
+        }
+        # Select decoder
+        decoder_score = proxy_state.calculate_decode_scores(request_length)
+        logger.debug("Decoder score: %f", decoder_score)
+        # No prefiller kv_transfer_params exist yet; pick the decoder by score alone
+        decoder_idx = proxy_state.select_decoder(decoder_score)
+        decoder = proxy_state.decoders[decoder_idx]
+        # logger.debug("Using %s %s", prefiller.url, decoder.url)
+        # Stream response from decoder
+        released_kv = False
 
         async def generate_stream():
-            nonlocal instance_info
-            generated_token = ""
-            released_kv = False
-            retry_count = 0
-            retry = True
-            completion_tokens = 0
+            nonlocal released_kv
             # Only one await per chunk, minimal logic in loop
             try:
-                while retry:
-                    retry = False
-                    async for chunk in stream_service_response_with_retry(
-                            instance_info.decoder.client,
-                            api,
-                            req_data,
-                            request_id=instance_info.request_id,
-                            max_retries=global_args.max_retries,
-                            base_delay=global_args.retry_delay):
-                        if not released_kv and chunk:
-                            proxy_state.release_prefiller_kv(
-                                instance_info.prefiller_idx,
-                                instance_info.prefiller_score)
-                            released_kv = True
-                        try:
-                            chunk_str = chunk.decode("utf-8").strip()
-                        except UnicodeDecodeError:
-                            logger.debug(f"Skipping chunk: {chunk}")
-                            yield chunk
-                            continue
-                        if not chunk_str:
-                            continue
-                        if chunk_str.startswith("data: "):
-                            chunk_str = chunk_str[len("data: "):]
-                        try:
-                            chunk_json = json.loads(chunk_str)
-                        except json.JSONDecodeError:
-                            # if chunk is [done], skip it.
-                            logger.debug(f"Skipping chunk: {chunk_str}")
-                            yield chunk
-                            continue
-                        choices = chunk_json.get("choices", [])
-                        if not choices:
-                            yield chunk
-                            continue
-
-                        choice = choices[0]
-                        delta = choice.get("delta") or {}
-                        message = choice.get("message") or {}
-                        content = (delta.get("content")
-                                   or message.get("content")
-                                   or choice.get("text") or "")
-                        generated_token += content
-
-                        stop_reason = choice.get("stop_reason")
-                        usage = chunk_json.get("usage", {})
-                        completion_tokens = (completion_tokens + 1) if stream_flag else \
-                            (completion_tokens + usage.get("completion_tokens"))
-                        if stop_reason == "recomputed":
-                            retry = True
-                            retry_count += 1
-                            if chat_flag:
-                                messages[0]["content"] = origin_prompt + generated_token
-                            else:
-                                req_data["prompt"] = origin_prompt + generated_token
-                            req_data["max_tokens"] = \
-                                origin_max_tokens - completion_tokens + retry_count
-                            tmp_request_length = len(
-                                json.dumps(req_data).encode("utf-8"))
-                            instance_info = await _handle_select_instance(
-                                api, req_data, tmp_request_length)
-                            break
-                        if retry_count > 0 and not stream_flag:
-                            if chat_flag:
-                                choices[0]["message"]["content"] = generated_token
-                            else:
-                                choices[0]["text"] = generated_token
-                            chunk = json.dumps(chunk_json).encode("utf-8")
-                        yield chunk
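+                # Chunks now pass through untouched; the recompute/retry
+                # handling was removed along with _handle_select_instance.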
+                async for chunk in stream_service_response_with_retry(
+                        decoder.client,
+                        api,
+                        req_data,
+                        request_id=request_id,
+                        max_retries=global_args.max_retries,
+                        base_delay=global_args.retry_delay):
+                    # if not released_kv and chunk:
+                    #     proxy_state.release_prefiller_kv(
+                    #         prefiller_idx, prefiller_score)
+                    #     released_kv = True
+                    yield chunk
             except Exception as e:
                 logger.error(
-                    f"Error during streaming from decoder {instance_info.decoder.url}: {str(e)} the aborted request {instance_info.request_id} will be routing to the target prefiller when new request is ready to dispatch to it"
+                    f"Error during streaming from decoder {decoder.url}: {str(e)}; the aborted request {request_id} will be routed to the target prefiller when a new request is ready to dispatch to it"
                 )
-            proxy_state.abort_prefiller_request(
-                instance_info.prefiller_idx, instance_info.request_id)
-            proxy_state.release_prefiller_kv(instance_info.prefiller_idx,
-                                             instance_info.prefiller_score)
+            # proxy_state.abort_prefiller_request(prefiller_idx, request_id)
+            # proxy_state.release_prefiller_kv(prefiller_idx,
+            #                                  prefiller_score)
 
             # After streaming done, release tokens
-            proxy_state.release_decoder(instance_info.decoder_idx,
-                                        instance_info.decoder_score)
+            proxy_state.release_decoder(decoder_idx, decoder_score)
 
         return StreamingResponse(generate_stream(),
                                  media_type="application/json")
@@ -669,11 +582,38 @@ async def healthcheck():
 @app.post("/v1/metaserver")
 async def metaserver(request: Request):
     try:
-        req_data = await request.json()
-        request_id = req_data.pop("request_id", None)
-        if request_id in proxy_state.req_id_future:
-            result_future = proxy_state.req_id_future[request_id]
-            result_future.set_result(req_data)
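+        # Callback from the decoder side: look up the stored request and
+        # forward it to a prefiller with the decoder's kv_transfer_params
+        # attached.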
+        kv_transfer_params = await request.json()
+
+        request_id = kv_transfer_params["request_id"]
+        assert request_id in proxy_state.req_data_dict
+        req_data, request_length, api = proxy_state.req_data_dict[request_id]
+        # output_prompt = proxy_state.tokenizer.decode(kv_transfer_params["token_ids"])
+        # req_data["prompt"] = output_prompt
+        # del kv_transfer_params['token_ids']
+        request_id = get_origin_request_id(api, request_id)
+        req_data["kv_transfer_params"] = kv_transfer_params
+        prefiller_score = proxy_state.calculate_prefill_scores(request_length)
+        logger.debug(
+            f"Request length: {request_length}, Prefiller score: {prefiller_score}"
+        )
+
+        # Select prefiller
+        prefiller_idx = proxy_state.select_prefiller(prefiller_score)
+        prefiller = proxy_state.prefillers[prefiller_idx]
+        logger.debug(f"Using prefill {prefiller.url=} {req_data=}")
+        # Send request to prefiller
+        response = await send_request_to_service(
+            prefiller.client,
+            prefiller_idx,
+            api,
+            req_data,
+            request_id,
+            max_retries=global_args.max_retries,
+            base_delay=global_args.retry_delay)
+        proxy_state.release_prefiller(prefiller_idx, prefiller_score)
+
+        # del req_data["prompt"]
+
     except Exception as e:
         logger.error(f"Post metaserver failed with: {str(e)}")
 
@@ -682,5 +622,4 @@ async def metaserver(request: Request):
     global global_args
     global_args = parse_args()
     import uvicorn
-
-    uvicorn.run(app, host=global_args.host, port=global_args.port)
+    uvicorn.run(app, host=global_args.host, port=global_args.port)