Skip to content

Commit d1bf7f0

Browse files
committed
add example for enabling cuda graph
Signed-off-by: vllmellm <vllm.ellm@embeddedllm.com>
1 parent 45fd5d6 commit d1bf7f0

File tree

1 file changed

+73
-0
lines changed

1 file changed

+73
-0
lines changed
Lines changed: 73 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,73 @@
1+
# SPDX-License-Identifier: Apache-2.0
2+
3+
from vllm import LLM, EngineArgs
4+
from vllm.utils import FlexibleArgumentParser
5+
6+
7+
def create_parser():
    """Build the CLI parser: vLLM engine arguments plus sampling overrides."""
    parser = FlexibleArgumentParser()

    # Engine arguments; default to a small Llama checkpoint so the example
    # runs without extra flags.
    engine_group = parser.add_argument_group("Engine arguments")
    EngineArgs.add_cli_args(engine_group)
    engine_group.set_defaults(model="meta-llama/Llama-3.2-1B-Instruct")

    # Optional sampling parameters (None when not supplied on the CLI).
    sampling_group = parser.add_argument_group("Sampling parameters")
    for flag, flag_type in (("--max-tokens", int),
                            ("--temperature", float),
                            ("--top-p", float),
                            ("--top-k", int)):
        sampling_group.add_argument(flag, type=flag_type)

    return parser
def main(args: dict):
    """Run batched generation with full CUDA graph capture enabled.

    *args* is the parsed CLI namespace as a dict; sampling-only keys are
    removed before the remainder is forwarded to the LLM constructor.
    """
    # Pull out the sampling overrides -- these are not LLM() kwargs.
    # (Same pop order as the CLI flags: max_tokens, temperature, top_p, top_k.)
    overrides = {
        name: args.pop(name)
        for name in ("max_tokens", "temperature", "top_p", "top_k")
    }

    # This example pins these two settings itself, so drop any CLI-provided
    # values to avoid passing duplicate keyword arguments below.
    args.pop("compilation_config", None)
    args.pop("max_num_seqs", None)

    # Create an LLM with full CUDA graphs captured at the chosen batch sizes.
    llm = LLM(**args,
              max_num_seqs=256,
              compilation_config={
                  "full_cuda_graph": True,
                  "cudagraph_capture_sizes": [64, 256]
              })

    # Start from the model's default sampling params and apply only the
    # overrides the user actually provided.
    sampling_params = llm.get_default_sampling_params()
    for name, value in overrides.items():
        if value is not None:
            setattr(sampling_params, name, value)

    # Generate texts from the prompts. The output is a list of RequestOutput
    # objects that contain the prompt, generated text, and other information.
    prompts = [
        "Hello, my name is",
        "The president of the United States is",
        "The capital of France is",
        "The future of AI is",
    ]
    outputs = llm.generate(prompts, sampling_params)

    # Print the outputs.
    print("-" * 50)
    for output in outputs:
        print(f"Prompt: {output.prompt!r}\nGenerated text: {output.outputs[0].text!r}")
        print("-" * 50)
if __name__ == "__main__":
    # Parse the command line into a plain dict and hand it to main().
    cli_args: dict = vars(create_parser().parse_args())
    main(cli_args)

0 commit comments

Comments
 (0)