
Commit f368a24

[PD] add test for chat completions endpoint
Signed-off-by: Abirdcfly <[email protected]>
1 parent: 5bbaf49 · commit: f368a24

2 files changed: 29 additions, 14 deletions

tests/v1/kv_connector/nixl_integration/test_disagg_accuracy.py (27 additions, 14 deletions)

@@ -51,20 +51,31 @@ def check_vllm_server(url: str, timeout=5, retries=3) -> bool:
     return False
 
 
-def run_simple_prompt(base_url: str, model_name: str,
-                      input_prompt: str) -> str:
+def run_simple_prompt(base_url: str, model_name: str, input_prompt: str,
+                      use_chat_endpoint: bool) -> str:
     client = openai.OpenAI(api_key="EMPTY", base_url=base_url)
-    completion = client.completions.create(model=model_name,
-                                           prompt=input_prompt,
-                                           max_tokens=MAX_OUTPUT_LEN,
-                                           temperature=0.0,
-                                           seed=42)
+    if use_chat_endpoint:
+        completion = client.chat.completions.create(
+            model=model_name,
+            messages=[{
+                "role": "user",
+                "content": [{
+                    "type": "text",
+                    "text": input_prompt
+                }]
+            }],
+            max_completion_tokens=MAX_OUTPUT_LEN,
+            temperature=0.0,
+            seed=42)
+        return completion.choices[0].message.content
+    else:
+        completion = client.completions.create(model=model_name,
+                                               prompt=input_prompt,
+                                               max_tokens=MAX_OUTPUT_LEN,
+                                               temperature=0.0,
+                                               seed=42)
 
-    # print("-" * 50)
-    # print(f"Completion results for {model_name}:")
-    # print(completion)
-    # print("-" * 50)
-    return completion.choices[0].text
+    return completion.choices[0].text
 
 
 def main():
@@ -125,10 +136,12 @@ def main():
             f"vllm server: {args.service_url} is not ready yet!")
 
     output_strs = dict()
-    for prompt in SAMPLE_PROMPTS:
+    for i, prompt in enumerate(SAMPLE_PROMPTS):
+        use_chat_endpoint = (i % 2 == 1)
         output_str = run_simple_prompt(base_url=service_url,
                                        model_name=args.model_name,
-                                       input_prompt=prompt)
+                                       input_prompt=prompt,
+                                       use_chat_endpoint=use_chat_endpoint)
         print(f"Prompt: {prompt}, output: {output_str}")
         output_strs[prompt] = output_str
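With this change, main() exercises both OpenAI-compatible endpoints: even-indexed sample prompts go through the legacy /v1/completions endpoint and odd-indexed ones through /v1/chat/completions. The two endpoints differ in request shape (a prompt string vs. a messages list), in the name of the output-length cap (max_tokens vs. max_completion_tokens), and in where the generated text lands in the response, which is what the new branch handles. A minimal standalone sketch of the two shapes, using a placeholder server URL and model name (not values taken from the test):

import openai

# Placeholder base_url and model name; any OpenAI-compatible server works.
client = openai.OpenAI(api_key="EMPTY", base_url="http://localhost:8000/v1")

# /v1/completions: plain prompt in, text out via choices[0].text
text_out = client.completions.create(
    model="some-model",
    prompt="Hello",
    max_tokens=16,
).choices[0].text

# /v1/chat/completions: messages in, text out via choices[0].message.content;
# note the output cap is named max_completion_tokens on this endpoint
chat_out = client.chat.completions.create(
    model="some-model",
    messages=[{"role": "user", "content": "Hello"}],
    max_completion_tokens=16,
).choices[0].message.content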

tests/v1/kv_connector/nixl_integration/toy_proxy_server.py (2 additions, 0 deletions)

@@ -162,6 +162,8 @@ async def send_request_to_service(client_info: dict, endpoint: str,
     }
     req_data["stream"] = False
     req_data["max_tokens"] = 1
+    if "max_completion_tokens" in req_data:
+        req_data["max_completion_tokens"] = 1
     if "stream_options" in req_data:
         del req_data["stream_options"]
     headers = {
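This mirrors the test change on the proxy side. In the disaggregated prefill/decode setup, send_request_to_service forwards the request to the prefill worker with generation clamped to a single token (the prefill worker only needs to populate the KV cache before the decode worker produces the real output). Chat completion requests carry max_completion_tokens rather than max_tokens, so that field must now be clamped too when present. An illustrative restatement of the mutation (the helper name is hypothetical, not from the diff), assuming req_data is the parsed JSON request body:

def clamp_for_prefill(req_data: dict) -> dict:
    # Prefill-side request: disable streaming and cap generation at one
    # token, for whichever limit field the client's endpoint uses.
    req_data["stream"] = False
    req_data["max_tokens"] = 1  # /v1/completions requests
    if "max_completion_tokens" in req_data:
        req_data["max_completion_tokens"] = 1  # /v1/chat/completions requests
    req_data.pop("stream_options", None)  # irrelevant once streaming is off
    return req_data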
