Commit ddbebf9

[PD] add test for chat completions endpoint
Signed-off-by: Abirdcfly <[email protected]>
1 parent 5bbaf49 commit ddbebf9

2 files changed: 21 additions & 13 deletions

tests/v1/kv_connector/nixl_integration/test_disagg_accuracy.py

Lines changed: 19 additions & 13 deletions
@@ -52,19 +52,23 @@ def check_vllm_server(url: str, timeout=5, retries=3) -> bool:
 
 
 def run_simple_prompt(base_url: str, model_name: str,
-                      input_prompt: str) -> str:
+                      input_prompt: str, use_chat_endpoint: bool) -> str:
     client = openai.OpenAI(api_key="EMPTY", base_url=base_url)
-    completion = client.completions.create(model=model_name,
-                                           prompt=input_prompt,
-                                           max_tokens=MAX_OUTPUT_LEN,
-                                           temperature=0.0,
-                                           seed=42)
+    if use_chat_endpoint:
+        completion = client.chat.completions.create(model=model_name,
+            messages=[{"role": "user", "content": [{"type": "text", "text": input_prompt}]}],
+            max_completion_tokens=MAX_OUTPUT_LEN,
+            temperature=0.0,
+            seed=42)
+        return completion.choices[0].message.content
+    else:
+        completion = client.completions.create(model=model_name,
+                                               prompt=input_prompt,
+                                               max_tokens=MAX_OUTPUT_LEN,
+                                               temperature=0.0,
+                                               seed=42)
 
-    # print("-" * 50)
-    # print(f"Completion results for {model_name}:")
-    # print(completion)
-    # print("-" * 50)
-    return completion.choices[0].text
+    return completion.choices[0].text
 
 
 def main():

@@ -125,10 +129,12 @@ def main():
             f"vllm server: {args.service_url} is not ready yet!")
 
     output_strs = dict()
-    for prompt in SAMPLE_PROMPTS:
+    for i, prompt in enumerate(SAMPLE_PROMPTS):
+        use_chat_endpoint = (i % 2 == 1)
         output_str = run_simple_prompt(base_url=service_url,
                                        model_name=args.model_name,
-                                       input_prompt=prompt)
+                                       input_prompt=prompt,
+                                       use_chat_endpoint=use_chat_endpoint)
         print(f"Prompt: {prompt}, output: {output_str}")
         output_strs[prompt] = output_str
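In short, the test now alternates endpoints: even-indexed prompts go through the legacy completions API, odd-indexed prompts through the chat completions API, whose token limit is max_completion_tokens and whose output is read from choices[0].message.content rather than choices[0].text. Below is a minimal standalone sketch of the same pattern against an OpenAI-compatible server; the base URL, model name, prompt list, and token limit are placeholders, not values from the test.

# Sketch only: base_url, MODEL, and PROMPTS are placeholders.
import openai

client = openai.OpenAI(api_key="EMPTY", base_url="http://localhost:8000/v1")
MODEL = "my-model"
PROMPTS = ["Hello, my name is", "The capital of France is"]

for i, prompt in enumerate(PROMPTS):
    if i % 2 == 1:
        # Odd-indexed prompts exercise /v1/chat/completions.
        resp = client.chat.completions.create(
            model=MODEL,
            messages=[{"role": "user", "content": prompt}],
            max_completion_tokens=32,
            temperature=0.0,
            seed=42)
        text = resp.choices[0].message.content
    else:
        # Even-indexed prompts exercise the legacy /v1/completions.
        resp = client.completions.create(
            model=MODEL,
            prompt=prompt,
            max_tokens=32,
            temperature=0.0,
            seed=42)
        text = resp.choices[0].text
    print(f"Prompt: {prompt}, output: {text}")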

tests/v1/kv_connector/nixl_integration/toy_proxy_server.py

Lines changed: 2 additions & 0 deletions
@@ -162,6 +162,8 @@ async def send_request_to_service(client_info: dict, endpoint: str,
     }
     req_data["stream"] = False
     req_data["max_tokens"] = 1
+    if "max_completion_tokens" in req_data:
+        req_data["max_completion_tokens"] = 1
     if "stream_options" in req_data:
         del req_data["stream_options"]
     headers = {
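Why the proxy change is needed: when forwarding the prefill leg of a request, the toy proxy clamps the output length to one token so the prefill worker populates the KV cache without generating further output. The legacy endpoint's limit is max_tokens, but chat completion requests carry max_completion_tokens, so that field must now be clamped as well. Below is a sketch of the clamping step in isolation; the helper name and the shape of req_data are assumptions for illustration, and only the field names come from the diff.

# Hypothetical helper mirroring the clamping logic in send_request_to_service.
def clamp_prefill_request(req_data: dict) -> dict:
    # The prefill leg never streams and should generate exactly one token.
    req_data["stream"] = False
    req_data["max_tokens"] = 1
    # Chat completion requests use max_completion_tokens instead of
    # max_tokens, so clamp it too when the client supplied it.
    if "max_completion_tokens" in req_data:
        req_data["max_completion_tokens"] = 1
    # Stream options are meaningless once streaming is disabled.
    if "stream_options" in req_data:
        del req_data["stream_options"]
    return req_data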
