diff --git a/tests/v1/kv_connector/nixl_integration/test_disagg_accuracy.py b/tests/v1/kv_connector/nixl_integration/test_disagg_accuracy.py index 00e62f351ce3..697e101c3592 100644 --- a/tests/v1/kv_connector/nixl_integration/test_disagg_accuracy.py +++ b/tests/v1/kv_connector/nixl_integration/test_disagg_accuracy.py @@ -51,20 +51,31 @@ def check_vllm_server(url: str, timeout=5, retries=3) -> bool: return False -def run_simple_prompt(base_url: str, model_name: str, - input_prompt: str) -> str: +def run_simple_prompt(base_url: str, model_name: str, input_prompt: str, + use_chat_endpoint: bool) -> str: client = openai.OpenAI(api_key="EMPTY", base_url=base_url) - completion = client.completions.create(model=model_name, - prompt=input_prompt, - max_tokens=MAX_OUTPUT_LEN, - temperature=0.0, - seed=42) + if use_chat_endpoint: + completion = client.chat.completions.create( + model=model_name, + messages=[{ + "role": "user", + "content": [{ + "type": "text", + "text": input_prompt + }] + }], + max_completion_tokens=MAX_OUTPUT_LEN, + temperature=0.0, + seed=42) + return completion.choices[0].message.content + else: + completion = client.completions.create(model=model_name, + prompt=input_prompt, + max_tokens=MAX_OUTPUT_LEN, + temperature=0.0, + seed=42) - # print("-" * 50) - # print(f"Completion results for {model_name}:") - # print(completion) - # print("-" * 50) - return completion.choices[0].text + return completion.choices[0].text def main(): @@ -125,10 +136,12 @@ def main(): f"vllm server: {args.service_url} is not ready yet!") output_strs = dict() - for prompt in SAMPLE_PROMPTS: + for i, prompt in enumerate(SAMPLE_PROMPTS): + use_chat_endpoint = (i % 2 == 1) output_str = run_simple_prompt(base_url=service_url, model_name=args.model_name, - input_prompt=prompt) + input_prompt=prompt, + use_chat_endpoint=use_chat_endpoint) print(f"Prompt: {prompt}, output: {output_str}") output_strs[prompt] = output_str diff --git a/tests/v1/kv_connector/nixl_integration/toy_proxy_server.py b/tests/v1/kv_connector/nixl_integration/toy_proxy_server.py index 66e237da0f80..905ae0ea7172 100644 --- a/tests/v1/kv_connector/nixl_integration/toy_proxy_server.py +++ b/tests/v1/kv_connector/nixl_integration/toy_proxy_server.py @@ -162,6 +162,8 @@ async def send_request_to_service(client_info: dict, endpoint: str, } req_data["stream"] = False req_data["max_tokens"] = 1 + if "max_completion_tokens" in req_data: + req_data["max_completion_tokens"] = 1 if "stream_options" in req_data: del req_data["stream_options"] headers = {