diff --git a/run_benchmarks.py b/run_benchmarks.py
index 9e73cf8..ae8d3de 100644
--- a/run_benchmarks.py
+++ b/run_benchmarks.py
@@ -4,7 +4,7 @@ import argparse
 
 from vllm_benchmark import run_benchmark
 
-async def run_all_benchmarks(vllm_url, api_key, use_long_context):
+async def run_all_benchmarks(model, vllm_url, api_key, use_long_context):
     configurations = [
         {"num_requests": 10, "concurrency": 1, "output_tokens": 100},
         {"num_requests": 100, "concurrency": 10, "output_tokens": 100},
@@ -16,7 +16,7 @@ async def run_all_benchmarks(vllm_url, api_key, use_long_context):
 
     for config in configurations:
         print(f"Running benchmark with concurrency {config['concurrency']}...")
-        results = await run_benchmark(config['num_requests'], config['concurrency'], 30, config['output_tokens'], vllm_url, api_key, use_long_context)
+        results = await run_benchmark(model, config['num_requests'], config['concurrency'], 30, config['output_tokens'], vllm_url, api_key, use_long_context)
         all_results.append(results)
         time.sleep(5)  # Wait a bit between runs to let the system cool down
 
@@ -24,12 +24,13 @@ async def run_all_benchmarks(vllm_url, api_key, use_long_context):
 
 def main():
     parser = argparse.ArgumentParser(description="Run vLLM benchmarks with various configurations")
+    parser.add_argument("--model", type=str, required=True, help="Model to benchmark")
     parser.add_argument("--vllm_url", type=str, required=True, help="URL of the vLLM server")
     parser.add_argument("--api_key", type=str, required=True, help="API key for vLLM server")
     parser.add_argument("--use_long_context", action="store_true", help="Use long context prompt pairs instead of short prompts")
     args = parser.parse_args()
 
-    all_results = asyncio.run(run_all_benchmarks(args.vllm_url, args.api_key, args.use_long_context))
+    all_results = asyncio.run(run_all_benchmarks(args.model, args.vllm_url, args.api_key, args.use_long_context))
 
     with open('benchmark_results.json', 'w') as f:
         json.dump(all_results, f, indent=2)
diff --git a/vllm_benchmark.py b/vllm_benchmark.py
index 2b8ad39..14cb5f4 100644
--- a/vllm_benchmark.py
+++ b/vllm_benchmark.py
@@ -118,7 +118,7 @@ async def process_stream(stream):
                 break
     return first_token_time, total_tokens
 
-async def make_request(client, output_tokens, request_timeout, use_long_context):
+async def make_request(model, client, output_tokens, request_timeout, use_long_context):
     start_time = time.time()
     if use_long_context:
         prompt_pair = random.choice(LONG_PROMPT_PAIRS)
@@ -128,7 +128,7 @@ async def make_request(client, output_tokens, request_timeout, use_long_context)
 
     try:
         stream = await client.chat.completions.create(
-            model="NousResearch/Meta-Llama-3.1-8B-Instruct",
+            model=model,
             messages=[
                 {"role": "user", "content": content}
             ],
@@ -150,7 +150,7 @@ async def make_request(client, output_tokens, request_timeout, use_long_context)
         logging.error(f"Error during request: {str(e)}")
         return None
 
-async def worker(client, semaphore, queue, results, output_tokens, request_timeout, use_long_context):
+async def worker(model, client, semaphore, queue, results, output_tokens, request_timeout, use_long_context):
     while True:
         async with semaphore:
             task_id = await queue.get()
@@ -158,7 +158,7 @@ async def worker(client, semaphore, queue, results, output_tokens, request_timeo
                 queue.task_done()
                 break
             logging.info(f"Starting request {task_id}")
-            result = await make_request(client, output_tokens, request_timeout, use_long_context)
+            result = await make_request(model, client, output_tokens, request_timeout, use_long_context)
             if result:
                 results.append(result)
             else:
@@ -173,7 +173,7 @@ def calculate_percentile(values, percentile, reverse=False):
         return np.percentile(values, 100 - percentile)
     return np.percentile(values, percentile)
 
-async def run_benchmark(num_requests, concurrency, request_timeout, output_tokens, vllm_url, api_key, use_long_context):
+async def run_benchmark(model, num_requests, concurrency, request_timeout, output_tokens, vllm_url, api_key, use_long_context):
    client = AsyncOpenAI(base_url=vllm_url, api_key=api_key)
    semaphore = asyncio.Semaphore(concurrency)
    queue = asyncio.Queue()
@@ -188,7 +188,7 @@ async def run_benchmark(num_requests, concurrency, request_timeout, output_token
        await queue.put(None)
 
    # Create worker tasks
-    workers = [asyncio.create_task(worker(client, semaphore, queue, results, output_tokens, request_timeout, use_long_context)) for _ in range(concurrency)]
+    workers = [asyncio.create_task(worker(model, client, semaphore, queue, results, output_tokens, request_timeout, use_long_context)) for _ in range(concurrency)]
 
    start_time = time.time()
 
@@ -252,6 +252,7 @@ def print_results(results):
 
 if __name__ == "__main__":
     parser = argparse.ArgumentParser(description="Benchmark LLaMA-3 model with vLLM")
+    parser.add_argument("--model", type=str, required=True, help="Model to benchmark")
     parser.add_argument("--num_requests", type=int, required=True, help="Number of requests to make")
     parser.add_argument("--concurrency", type=int, required=True, help="Number of concurrent requests")
     parser.add_argument("--request_timeout", type=int, default=30, help="Timeout for each request in seconds (default: 30)")
@@ -261,7 +262,7 @@ def print_results(results):
     parser.add_argument("--use_long_context", action="store_true", help="Use long context prompt pairs instead of short prompts")
     args = parser.parse_args()
 
-    results = asyncio.run(run_benchmark(args.num_requests, args.concurrency, args.request_timeout, args.output_tokens, args.vllm_url, args.api_key, args.use_long_context))
+    results = asyncio.run(run_benchmark(args.model, args.num_requests, args.concurrency, args.request_timeout, args.output_tokens, args.vllm_url, args.api_key, args.use_long_context))
     print_results(results)
else:
    # When imported as a module, provide the run_benchmark function
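
For reference, a minimal usage sketch against the updated run_benchmark signature when vllm_benchmark is imported as a module. The server URL and API key below are placeholder values (not part of this change); the model name shown is the one that was previously hardcoded.

    import asyncio
    from vllm_benchmark import run_benchmark

    # Placeholder values for illustration only; substitute your own server URL and key.
    results = asyncio.run(run_benchmark(
        "NousResearch/Meta-Llama-3.1-8B-Instruct",  # model (new leading parameter)
        10,                           # num_requests
        1,                            # concurrency
        30,                           # request_timeout in seconds
        100,                          # output_tokens
        "http://localhost:8000/v1",   # vllm_url (placeholder)
        "sk-placeholder",             # api_key (placeholder)
        False,                        # use_long_context
    ))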