Commit b917da4

Csrayz authored
Expose PyTorch profiler configuration to environment variables (#21803)
Signed-off-by: Csrayz <[email protected]>
1 parent fb58e3a commit b917da4

4 files changed: +60 −4 lines changed

docs/contributing/profiling.md

Lines changed: 6 additions & 1 deletion

@@ -5,7 +5,12 @@

 ## Profile with PyTorch Profiler

-We support tracing vLLM workers using the `torch.profiler` module. You can enable tracing by setting the `VLLM_TORCH_PROFILER_DIR` environment variable to the directory where you want to save the traces: `VLLM_TORCH_PROFILER_DIR=/mnt/traces/`
+We support tracing vLLM workers using the `torch.profiler` module. You can enable tracing by setting the `VLLM_TORCH_PROFILER_DIR` environment variable to the directory where you want to save the traces: `VLLM_TORCH_PROFILER_DIR=/mnt/traces/`. Additionally, you can control the profiling content by specifying the following environment variables:
+
+- `VLLM_TORCH_PROFILER_RECORD_SHAPES=1` to enable recording tensor shapes, off by default
+- `VLLM_TORCH_PROFILER_WITH_PROFILE_MEMORY=1` to enable recording memory usage, off by default
+- `VLLM_TORCH_PROFILER_WITH_STACK=1` to enable recording stack information, on by default
+- `VLLM_TORCH_PROFILER_WITH_FLOPS=1` to enable recording FLOPs, off by default

 The OpenAI server also needs to be started with the `VLLM_TORCH_PROFILER_DIR` environment variable set.
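As a quick usage sketch of the documented knobs (the model name and trace directory below are illustrative placeholders, not part of this change), the server can be launched with the new variables set alongside `VLLM_TORCH_PROFILER_DIR`:

```python
import os
import subprocess

# Start the OpenAI-compatible server with profiling enabled.
# Model and trace directory are placeholders for illustration.
env = dict(
    os.environ,
    VLLM_TORCH_PROFILER_DIR="/mnt/traces/",
    VLLM_TORCH_PROFILER_RECORD_SHAPES="1",  # record tensor shapes
    VLLM_TORCH_PROFILER_WITH_FLOPS="1",     # estimate FLOPs
)
subprocess.run(["vllm", "serve", "facebook/opt-125m"], env=env, check=True)
```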

vllm/envs.py

Lines changed: 29 additions & 0 deletions

@@ -80,6 +80,10 @@
     VLLM_PLUGINS: Optional[list[str]] = None
     VLLM_LORA_RESOLVER_CACHE_DIR: Optional[str] = None
     VLLM_TORCH_PROFILER_DIR: Optional[str] = None
+    VLLM_TORCH_PROFILER_RECORD_SHAPES: bool = False
+    VLLM_TORCH_PROFILER_WITH_PROFILE_MEMORY: bool = False
+    VLLM_TORCH_PROFILER_WITH_STACK: bool = True
+    VLLM_TORCH_PROFILER_WITH_FLOPS: bool = False
     VLLM_USE_TRITON_AWQ: bool = False
     VLLM_ALLOW_RUNTIME_LORA_UPDATING: bool = False
     VLLM_SKIP_P2P_CHECK: bool = False

@@ -629,6 +633,31 @@ def get_vllm_port() -> Optional[int]:
     lambda: (None if os.getenv("VLLM_TORCH_PROFILER_DIR", None) is None else os
              .path.expanduser(os.getenv("VLLM_TORCH_PROFILER_DIR", "."))),

+    # Enable torch profiler to record shapes if set
+    # VLLM_TORCH_PROFILER_RECORD_SHAPES=1. If not set, torch profiler will
+    # not record shapes.
+    "VLLM_TORCH_PROFILER_RECORD_SHAPES":
+    lambda: bool(os.getenv("VLLM_TORCH_PROFILER_RECORD_SHAPES", "0") != "0"),
+
+    # Enable torch profiler to profile memory if set
+    # VLLM_TORCH_PROFILER_WITH_PROFILE_MEMORY=1. If not set, torch profiler
+    # will not profile memory.
+    "VLLM_TORCH_PROFILER_WITH_PROFILE_MEMORY":
+    lambda: bool(
+        os.getenv("VLLM_TORCH_PROFILER_WITH_PROFILE_MEMORY", "0") != "0"),
+
+    # Enable torch profiler to profile stack if set
+    # VLLM_TORCH_PROFILER_WITH_STACK=1. If not set, torch profiler will
+    # profile stack by default.
+    "VLLM_TORCH_PROFILER_WITH_STACK":
+    lambda: bool(os.getenv("VLLM_TORCH_PROFILER_WITH_STACK", "1") != "0"),
+
+    # Enable torch profiler to profile flops if set
+    # VLLM_TORCH_PROFILER_WITH_FLOPS=1. If not set, torch profiler will
+    # not profile flops.
+    "VLLM_TORCH_PROFILER_WITH_FLOPS":
+    lambda: bool(os.getenv("VLLM_TORCH_PROFILER_WITH_FLOPS", "0") != "0"),
+
     # If set, vLLM will use Triton implementations of AWQ.
     "VLLM_USE_TRITON_AWQ":
     lambda: bool(int(os.getenv("VLLM_USE_TRITON_AWQ", "0"))),
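Note the parsing pattern in these lambdas: a flag counts as enabled whenever the variable is set to anything other than the literal string "0". A minimal sketch of that behavior (the `_env_flag` helper is hypothetical, written here only to mirror the pattern):

```python
import os

def _env_flag(name: str, default: str) -> bool:
    # Mirrors the lambdas in vllm/envs.py: any value other than the
    # literal string "0" (including "false" or "no") enables the flag.
    return os.getenv(name, default) != "0"

os.environ["VLLM_TORCH_PROFILER_RECORD_SHAPES"] = "false"
assert _env_flag("VLLM_TORCH_PROFILER_RECORD_SHAPES", "0") is True  # still on!
assert _env_flag("VLLM_TORCH_PROFILER_WITH_STACK", "1") is True     # on by default
```

In particular, since stack recording defaults to on, `VLLM_TORCH_PROFILER_WITH_STACK=0` is the only value that disables it.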

vllm/v1/worker/gpu_worker.py

Lines changed: 13 additions & 2 deletions

@@ -71,12 +71,23 @@ def __init__(
             torch_profiler_trace_dir = envs.VLLM_TORCH_PROFILER_DIR
             logger.info("Profiling enabled. Traces will be saved to: %s",
                         torch_profiler_trace_dir)
+            logger.debug(
+                "Profiler config: record_shapes=%s,"
+                "profile_memory=%s,with_stack=%s,with_flops=%s",
+                envs.VLLM_TORCH_PROFILER_RECORD_SHAPES,
+                envs.VLLM_TORCH_PROFILER_WITH_PROFILE_MEMORY,
+                envs.VLLM_TORCH_PROFILER_WITH_STACK,
+                envs.VLLM_TORCH_PROFILER_WITH_FLOPS,
+            )
             self.profiler = torch.profiler.profile(
                 activities=[
                     torch.profiler.ProfilerActivity.CPU,
                     torch.profiler.ProfilerActivity.CUDA,
                 ],
-                with_stack=True,
+                record_shapes=envs.VLLM_TORCH_PROFILER_RECORD_SHAPES,
+                profile_memory=envs.VLLM_TORCH_PROFILER_WITH_PROFILE_MEMORY,
+                with_stack=envs.VLLM_TORCH_PROFILER_WITH_STACK,
+                with_flops=envs.VLLM_TORCH_PROFILER_WITH_FLOPS,
                 on_trace_ready=torch.profiler.tensorboard_trace_handler(
                     torch_profiler_trace_dir, use_gzip=True))
         else:

@@ -209,7 +220,7 @@ def reload_weights(self) -> None:

     @torch.inference_mode()
     def determine_available_memory(self) -> int:
-        """Profiles the peak memory usage of the model to determine how much
+        """Profiles the peak memory usage of the model to determine how much
         memory can be used for KV cache without OOMs.

         The engine will first conduct a profiling of the existing memory usage.
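For reference, here is a self-contained sketch of the profiler the worker now constructs. Reading the environment directly stands in for `vllm.envs`, the `_flag` helper is hypothetical, and the fallback trace directory is a placeholder:

```python
import os
from torch.profiler import (ProfilerActivity, profile,
                            tensorboard_trace_handler)

def _flag(name: str, default: str) -> bool:
    # Same semantics as the vllm/envs.py lambdas above.
    return os.getenv(name, default) != "0"

trace_dir = os.getenv("VLLM_TORCH_PROFILER_DIR", "/tmp/vllm-traces")  # placeholder default
profiler = profile(
    activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA],
    record_shapes=_flag("VLLM_TORCH_PROFILER_RECORD_SHAPES", "0"),
    profile_memory=_flag("VLLM_TORCH_PROFILER_WITH_PROFILE_MEMORY", "0"),
    with_stack=_flag("VLLM_TORCH_PROFILER_WITH_STACK", "1"),
    with_flops=_flag("VLLM_TORCH_PROFILER_WITH_FLOPS", "0"),
    on_trace_ready=tensorboard_trace_handler(trace_dir, use_gzip=True),
)
with profiler:
    pass  # run the workload to be traced here
```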

vllm/v1/worker/xpu_worker.py

Lines changed: 12 additions & 1 deletion

@@ -41,12 +41,23 @@ def __init__(
             torch_profiler_trace_dir = envs.VLLM_TORCH_PROFILER_DIR
             logger.info("Profiling enabled. Traces will be saved to: %s",
                         torch_profiler_trace_dir)
+            logger.debug(
+                "Profiler config: record_shapes=%s,"
+                "profile_memory=%s,with_stack=%s,with_flops=%s",
+                envs.VLLM_TORCH_PROFILER_RECORD_SHAPES,
+                envs.VLLM_TORCH_PROFILER_WITH_PROFILE_MEMORY,
+                envs.VLLM_TORCH_PROFILER_WITH_STACK,
+                envs.VLLM_TORCH_PROFILER_WITH_FLOPS,
+            )
             self.profiler = torch.profiler.profile(
                 activities=[
                     torch.profiler.ProfilerActivity.CPU,
                     torch.profiler.ProfilerActivity.XPU,
                 ],
-                with_stack=True,
+                record_shapes=envs.VLLM_TORCH_PROFILER_RECORD_SHAPES,
+                profile_memory=envs.VLLM_TORCH_PROFILER_WITH_PROFILE_MEMORY,
+                with_stack=envs.VLLM_TORCH_PROFILER_WITH_STACK,
+                with_flops=envs.VLLM_TORCH_PROFILER_WITH_FLOPS,
                 on_trace_ready=torch.profiler.tensorboard_trace_handler(
                     torch_profiler_trace_dir, use_gzip=True))
         else:
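The XPU worker applies the same configuration; the only difference from the CUDA path is the accelerator activity. A hypothetical helper showing that selection (not part of the commit; `ProfilerActivity.XPU` assumes a PyTorch build with XPU support):

```python
import torch
from torch.profiler import ProfilerActivity

def accelerator_activity() -> ProfilerActivity:
    # Pick the device activity to pair with CPU, mirroring the
    # CUDA/XPU split between gpu_worker.py and xpu_worker.py.
    if torch.cuda.is_available():
        return ProfilerActivity.CUDA
    return ProfilerActivity.XPU

activities = [ProfilerActivity.CPU, accelerator_activity()]
```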
