@@ -51,20 +51,31 @@ def check_vllm_server(url: str, timeout=5, retries=3) -> bool:
    return False


-def run_simple_prompt(base_url: str, model_name: str,
-                      input_prompt: str) -> str:
+def run_simple_prompt(base_url: str, model_name: str, input_prompt: str,
+                      use_chat_endpoint: bool) -> str:
    client = openai.OpenAI(api_key="EMPTY", base_url=base_url)
-    completion = client.completions.create(model=model_name,
-                                           prompt=input_prompt,
-                                           max_tokens=MAX_OUTPUT_LEN,
-                                           temperature=0.0,
-                                           seed=42)
+    if use_chat_endpoint:
+        completion = client.chat.completions.create(
+            model=model_name,
+            messages=[{
+                "role": "user",
+                "content": [{
+                    "type": "text",
+                    "text": input_prompt
+                }]
+            }],
+            max_completion_tokens=MAX_OUTPUT_LEN,
+            temperature=0.0,
+            seed=42)
+        return completion.choices[0].message.content
+    else:
+        completion = client.completions.create(model=model_name,
+                                               prompt=input_prompt,
+                                               max_tokens=MAX_OUTPUT_LEN,
+                                               temperature=0.0,
+                                               seed=42)

-    # print("-" * 50)
-    # print(f"Completion results for {model_name}:")
-    # print(completion)
-    # print("-" * 50)
-    return completion.choices[0].text
+    return completion.choices[0].text


def main():
@@ -125,10 +136,12 @@ def main():
            f"vllm server: {args.service_url} is not ready yet!")

    output_strs = dict()
-    for prompt in SAMPLE_PROMPTS:
+    for i, prompt in enumerate(SAMPLE_PROMPTS):
+        use_chat_endpoint = (i % 2 == 1)
        output_str = run_simple_prompt(base_url=service_url,
                                       model_name=args.model_name,
-                                       input_prompt=prompt)
+                                       input_prompt=prompt,
+                                       use_chat_endpoint=use_chat_endpoint)
        print(f"Prompt: {prompt}, output: {output_str}")
        output_strs[prompt] = output_str
