run_benchmarks.py (7 changes: 4 additions, 3 deletions)
@@ -4,7 +4,7 @@
import argparse
from vllm_benchmark import run_benchmark

-async def run_all_benchmarks(vllm_url, api_key, use_long_context):
+async def run_all_benchmarks(model, vllm_url, api_key, use_long_context):
configurations = [
{"num_requests": 10, "concurrency": 1, "output_tokens": 100},
{"num_requests": 100, "concurrency": 10, "output_tokens": 100},
@@ -16,20 +16,21 @@ async def run_all_benchmarks(vllm_url, api_key, use_long_context):

for config in configurations:
print(f"Running benchmark with concurrency {config['concurrency']}...")
-results = await run_benchmark(config['num_requests'], config['concurrency'], 30, config['output_tokens'], vllm_url, api_key, use_long_context)
+results = await run_benchmark(model, config['num_requests'], config['concurrency'], 30, config['output_tokens'], vllm_url, api_key, use_long_context)
all_results.append(results)
time.sleep(5) # Wait a bit between runs to let the system cool down

return all_results

def main():
parser = argparse.ArgumentParser(description="Run vLLM benchmarks with various configurations")
parser.add_argument("--model", type=str, required=True, help="Model to benchmark")
parser.add_argument("--vllm_url", type=str, required=True, help="URL of the vLLM server")
parser.add_argument("--api_key", type=str, required=True, help="API key for vLLM server")
parser.add_argument("--use_long_context", action="store_true", help="Use long context prompt pairs instead of short prompts")
args = parser.parse_args()

-all_results = asyncio.run(run_all_benchmarks(args.vllm_url, args.api_key, args.use_long_context))
+all_results = asyncio.run(run_all_benchmarks(args.model, args.vllm_url, args.api_key, args.use_long_context))

with open('benchmark_results.json', 'w') as f:
json.dump(all_results, f, indent=2)
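
With this change, --model is a required flag alongside --vllm_url and --api_key. A minimal invocation might look like the following sketch; the server URL and API key are placeholders, and the model name shown is simply the value that was previously hard-coded in vllm_benchmark.py. Add --use_long_context to benchmark with the long prompt pairs; results for all configurations are written to benchmark_results.json.

python run_benchmarks.py \
  --model NousResearch/Meta-Llama-3.1-8B-Instruct \
  --vllm_url http://localhost:8000/v1 \
  --api_key token-abc123
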
vllm_benchmark.py (15 changes: 8 additions, 7 deletions)
@@ -118,7 +118,7 @@ async def process_stream(stream):
break
return first_token_time, total_tokens

-async def make_request(client, output_tokens, request_timeout, use_long_context):
+async def make_request(model, client, output_tokens, request_timeout, use_long_context):
start_time = time.time()
if use_long_context:
prompt_pair = random.choice(LONG_PROMPT_PAIRS)
@@ -128,7 +128,7 @@ async def make_request(client, output_tokens, request_timeout, use_long_context)

try:
stream = await client.chat.completions.create(
model="NousResearch/Meta-Llama-3.1-8B-Instruct",
model=model,
messages=[
{"role": "user", "content": content}
],
@@ -150,15 +150,15 @@ async def make_request(client, output_tokens, request_timeout, use_long_context)
logging.error(f"Error during request: {str(e)}")
return None

-async def worker(client, semaphore, queue, results, output_tokens, request_timeout, use_long_context):
+async def worker(model, client, semaphore, queue, results, output_tokens, request_timeout, use_long_context):
while True:
async with semaphore:
task_id = await queue.get()
if task_id is None:
queue.task_done()
break
logging.info(f"Starting request {task_id}")
-result = await make_request(client, output_tokens, request_timeout, use_long_context)
+result = await make_request(model, client, output_tokens, request_timeout, use_long_context)
if result:
results.append(result)
else:
@@ -173,7 +173,7 @@ def calculate_percentile(values, percentile, reverse=False):
return np.percentile(values, 100 - percentile)
return np.percentile(values, percentile)

-async def run_benchmark(num_requests, concurrency, request_timeout, output_tokens, vllm_url, api_key, use_long_context):
+async def run_benchmark(model, num_requests, concurrency, request_timeout, output_tokens, vllm_url, api_key, use_long_context):
client = AsyncOpenAI(base_url=vllm_url, api_key=api_key)
semaphore = asyncio.Semaphore(concurrency)
queue = asyncio.Queue()
@@ -188,7 +188,7 @@ async def run_benchmark(num_requests, concurrency, request_timeout, output_token
await queue.put(None)

# Create worker tasks
-workers = [asyncio.create_task(worker(client, semaphore, queue, results, output_tokens, request_timeout, use_long_context)) for _ in range(concurrency)]
+workers = [asyncio.create_task(worker(model, client, semaphore, queue, results, output_tokens, request_timeout, use_long_context)) for _ in range(concurrency)]

start_time = time.time()

@@ -252,6 +252,7 @@ def print_results(results):

if __name__ == "__main__":
parser = argparse.ArgumentParser(description="Benchmark LLaMA-3 model with vLLM")
parser.add_argument("--model", type=str, required=True, help="Model to benchmark")
parser.add_argument("--num_requests", type=int, required=True, help="Number of requests to make")
parser.add_argument("--concurrency", type=int, required=True, help="Number of concurrent requests")
parser.add_argument("--request_timeout", type=int, default=30, help="Timeout for each request in seconds (default: 30)")
@@ -261,7 +262,7 @@ def print_results(results):
parser.add_argument("--use_long_context", action="store_true", help="Use long context prompt pairs instead of short prompts")
args = parser.parse_args()

-results = asyncio.run(run_benchmark(args.num_requests, args.concurrency, args.request_timeout, args.output_tokens, args.vllm_url, args.api_key, args.use_long_context))
+results = asyncio.run(run_benchmark(args.model, args.num_requests, args.concurrency, args.request_timeout, args.output_tokens, args.vllm_url, args.api_key, args.use_long_context))
print_results(results)
else:
# When imported as a module, provide the run_benchmark function
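
Since the module still exposes run_benchmark when imported, a minimal sketch of calling it directly after this change is shown below; the model is now the first positional argument, the URL and API key are placeholders, and the remaining values mirror the first configuration in run_benchmarks.py.

import asyncio
from vllm_benchmark import run_benchmark, print_results

results = asyncio.run(run_benchmark(
    "NousResearch/Meta-Llama-3.1-8B-Instruct",  # model, passed through to the chat.completions call
    10,                          # num_requests
    1,                           # concurrency
    30,                          # request_timeout in seconds
    100,                         # output_tokens
    "http://localhost:8000/v1",  # vllm_url (placeholder)
    "token-abc123",              # api_key (placeholder)
    False,                       # use_long_context
))
print_results(results)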