Skip to content

Expose PyTorch profiler configuration to environment variables #21803

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 4 commits into from
Jul 30, 2025
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 9 additions & 1 deletion docs/contributing/profiling.md
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,15 @@

## Profile with PyTorch Profiler

We support tracing vLLM workers using the `torch.profiler` module. You can enable tracing by setting the `VLLM_TORCH_PROFILER_DIR` environment variable to the directory where you want to save the traces: `VLLM_TORCH_PROFILER_DIR=/mnt/traces/`
We support tracing vLLM workers using the `torch.profiler` module. You can enable tracing by setting the `VLLM_TORCH_PROFILER_DIR` environment variable to the directory where you want to save the traces: `VLLM_TORCH_PROFILER_DIR=/mnt/traces/`. Additionally, you can control the profiling content by specifying the following environment variables:

`VLLM_TORCH_PROFILER_RECORD_SHAPES=1` to enable recording Tensor Shapes, off by default

`VLLM_TORCH_PROFILER_WITH_PROFILE_MEMORY=1` to record memory, off by default

`VLLM_TORCH_PROFILER_WITH_STACK=1` to enable recording stack information, on by default

`VLLM_TORCH_PROFILER_WITH_FLOPS=1` to enable recording FLOPs, off by default

The OpenAI server also needs to be started with the `VLLM_TORCH_PROFILER_DIR` environment variable set.

Expand Down
29 changes: 29 additions & 0 deletions vllm/envs.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,10 @@
VLLM_PLUGINS: Optional[list[str]] = None
VLLM_LORA_RESOLVER_CACHE_DIR: Optional[str] = None
VLLM_TORCH_PROFILER_DIR: Optional[str] = None
VLLM_TORCH_PROFILER_RECORD_SHAPES: bool = False
VLLM_TORCH_PROFILER_WITH_PROFILE_MEMORY: bool = False
VLLM_TORCH_PROFILER_WITH_STACK: bool = True
VLLM_TORCH_PROFILER_WITH_FLOPS: bool = False
VLLM_USE_TRITON_AWQ: bool = False
VLLM_ALLOW_RUNTIME_LORA_UPDATING: bool = False
VLLM_SKIP_P2P_CHECK: bool = False
Expand Down Expand Up @@ -621,6 +625,31 @@ def get_vllm_port() -> Optional[int]:
lambda: (None if os.getenv("VLLM_TORCH_PROFILER_DIR", None) is None else os
.path.expanduser(os.getenv("VLLM_TORCH_PROFILER_DIR", "."))),

# Enable torch profiler to record shapes if set
# VLLM_TORCH_PROFILER_RECORD_SHAPES=1. If not set, torch profiler will
# not record shapes.
"VLLM_TORCH_PROFILER_RECORD_SHAPES":
lambda: bool(os.getenv("VLLM_TORCH_PROFILER_RECORD_SHAPES", "0") != "0"),

# Enable torch profiler to profile memory if set
# VLLM_TORCH_PROFILER_WITH_PROFILE_MEMORY=1. If not set, torch profiler
# will not profile memory.
"VLLM_TORCH_PROFILER_WITH_PROFILE_MEMORY":
lambda: bool(
os.getenv("VLLM_TORCH_PROFILER_WITH_PROFILE_MEMORY", "0") != "0"),

# Enable torch profiler to profile stack if set
# VLLM_TORCH_PROFILER_WITH_STACK=1. If not set, torch profiler WILL
# profile stack by default.
"VLLM_TORCH_PROFILER_WITH_STACK":
lambda: bool(os.getenv("VLLM_TORCH_PROFILER_WITH_STACK", "1") != "0"),

# Enable torch profiler to profile flops if set
# VLLM_TORCH_PROFILER_WITH_FLOPS=1. If not set, torch profiler will
# not profile flops.
"VLLM_TORCH_PROFILER_WITH_FLOPS":
lambda: bool(os.getenv("VLLM_TORCH_PROFILER_WITH_FLOPS", "0") != "0"),

# If set, vLLM will use Triton implementations of AWQ.
"VLLM_USE_TRITON_AWQ":
lambda: bool(int(os.getenv("VLLM_USE_TRITON_AWQ", "0"))),
Expand Down
15 changes: 13 additions & 2 deletions vllm/v1/worker/gpu_worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,12 +71,23 @@ def __init__(
torch_profiler_trace_dir = envs.VLLM_TORCH_PROFILER_DIR
logger.info("Profiling enabled. Traces will be saved to: %s",
torch_profiler_trace_dir)
logger.debug(
"Profiler config: record_shapes=%s,"
"profile_memory=%s,with_stack=%s,with_flops=%s",
envs.VLLM_TORCH_PROFILER_RECORD_SHAPES,
envs.VLLM_TORCH_PROFILER_WITH_PROFILE_MEMORY,
envs.VLLM_TORCH_PROFILER_WITH_STACK,
envs.VLLM_TORCH_PROFILER_WITH_FLOPS,
)
self.profiler = torch.profiler.profile(
activities=[
torch.profiler.ProfilerActivity.CPU,
torch.profiler.ProfilerActivity.CUDA,
],
with_stack=True,
record_shapes=envs.VLLM_TORCH_PROFILER_RECORD_SHAPES,
profile_memory=envs.VLLM_TORCH_PROFILER_WITH_PROFILE_MEMORY,
with_stack=envs.VLLM_TORCH_PROFILER_WITH_STACK,
with_flops=envs.VLLM_TORCH_PROFILER_WITH_FLOPS,
on_trace_ready=torch.profiler.tensorboard_trace_handler(
torch_profiler_trace_dir, use_gzip=True))
else:
Expand Down Expand Up @@ -209,7 +220,7 @@ def reload_weights(self) -> None:

@torch.inference_mode()
def determine_available_memory(self) -> int:
"""Profiles the peak memory usage of the model to determine how much
"""Profiles the peak memory usage of the model to determine how much
memory can be used for KV cache without OOMs.

The engine will first conduct a profiling of the existing memory usage.
Expand Down
13 changes: 12 additions & 1 deletion vllm/v1/worker/xpu_worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,12 +41,23 @@ def __init__(
torch_profiler_trace_dir = envs.VLLM_TORCH_PROFILER_DIR
logger.info("Profiling enabled. Traces will be saved to: %s",
torch_profiler_trace_dir)
logger.debug(
"Profiler config: record_shapes=%s,"
"profile_memory=%s,with_stack=%s,with_flops=%s",
envs.VLLM_TORCH_PROFILER_RECORD_SHAPES,
envs.VLLM_TORCH_PROFILER_WITH_PROFILE_MEMORY,
envs.VLLM_TORCH_PROFILER_WITH_STACK,
envs.VLLM_TORCH_PROFILER_WITH_FLOPS,
)
self.profiler = torch.profiler.profile(
activities=[
torch.profiler.ProfilerActivity.CPU,
torch.profiler.ProfilerActivity.XPU,
],
with_stack=True,
record_shapes=envs.VLLM_TORCH_PROFILER_RECORD_SHAPES,
profile_memory=envs.VLLM_TORCH_PROFILER_WITH_PROFILE_MEMORY,
with_stack=envs.VLLM_TORCH_PROFILER_WITH_STACK,
with_flops=envs.VLLM_TORCH_PROFILER_WITH_FLOPS,
on_trace_ready=torch.profiler.tensorboard_trace_handler(
torch_profiler_trace_dir, use_gzip=True))
else:
Expand Down