@@ -52,19 +52,23 @@ def check_vllm_server(url: str, timeout=5, retries=3) -> bool:


 def run_simple_prompt(base_url: str, model_name: str,
-                      input_prompt: str) -> str:
+                      input_prompt: str, use_chat_endpoint: bool) -> str:
     client = openai.OpenAI(api_key="EMPTY", base_url=base_url)
-    completion = client.completions.create(model=model_name,
-                                           prompt=input_prompt,
-                                           max_tokens=MAX_OUTPUT_LEN,
-                                           temperature=0.0,
-                                           seed=42)
+    if use_chat_endpoint:
+        completion = client.chat.completions.create(model=model_name,
+                                                    messages=[{"role": "user", "content": [{"type": "text", "text": input_prompt}]}],
+                                                    max_completion_tokens=MAX_OUTPUT_LEN,
+                                                    temperature=0.0,
+                                                    seed=42)
+        return completion.choices[0].message.content
+    else:
+        completion = client.completions.create(model=model_name,
+                                               prompt=input_prompt,
+                                               max_tokens=MAX_OUTPUT_LEN,
+                                               temperature=0.0,
+                                               seed=42)

-    # print("-" * 50)
-    # print(f"Completion results for {model_name}:")
-    # print(completion)
-    # print("-" * 50)
-    return completion.choices[0].text
+        return completion.choices[0].text


 def main():
@@ -125,10 +129,12 @@ def main():
             f"vllm server: {args.service_url} is not ready yet!")

     output_strs = dict()
-    for prompt in SAMPLE_PROMPTS:
+    for i, prompt in enumerate(SAMPLE_PROMPTS):
+        use_chat_endpoint = (i % 2 == 1)
         output_str = run_simple_prompt(base_url=service_url,
                                        model_name=args.model_name,
-                                       input_prompt=prompt)
+                                       input_prompt=prompt,
+                                       use_chat_endpoint=use_chat_endpoint)
         print(f"Prompt: {prompt}, output: {output_str}")
         output_strs[prompt] = output_str

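With this change, even-indexed prompts go through /v1/completions and odd-indexed prompts through /v1/chat/completions, so a single test run exercises both server code paths. For reference, a minimal standalone sketch of the two request shapes run_simple_prompt now issues; the base URL, model name, and MAX_OUTPUT_LEN value below are illustrative assumptions, not values taken from the test.

# Standalone sketch (assumed values; not part of the patch) showing the two
# request shapes exercised against a vLLM OpenAI-compatible server.
import openai

MAX_OUTPUT_LEN = 128                     # assumed cap on generated tokens
BASE_URL = "http://localhost:8000/v1"    # assumed vLLM OpenAI-compatible endpoint
MODEL = "facebook/opt-125m"              # assumed model being served

client = openai.OpenAI(api_key="EMPTY", base_url=BASE_URL)

# /v1/completions: plain prompt in, output on .choices[0].text
completion = client.completions.create(model=MODEL,
                                       prompt="Hello, my name is",
                                       max_tokens=MAX_OUTPUT_LEN,
                                       temperature=0.0,
                                       seed=42)
print(completion.choices[0].text)

# /v1/chat/completions: chat messages in, output on .choices[0].message.content
chat = client.chat.completions.create(
    model=MODEL,
    messages=[{"role": "user",
               "content": [{"type": "text", "text": "Hello, my name is"}]}],
    max_completion_tokens=MAX_OUTPUT_LEN,
    temperature=0.0,
    seed=42)
print(chat.choices[0].message.content)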