Skip to content

Commit ac4cff3

Browse files
authored
Expose PyTorch profiler configuration to environment variables
* [FEAT] Expose PyTorch profiler config via environment variables * [DOC] update profiling.md * NOT change profiler in v0 worker. related to #18571
1 parent 2cc5711 commit ac4cff3

File tree

4 files changed

+56
-4
lines changed

4 files changed

+56
-4
lines changed

docs/contributing/profiling.md

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,15 @@
55

66
## Profile with PyTorch Profiler
77

8-
We support tracing vLLM workers using the `torch.profiler` module. You can enable tracing by setting the `VLLM_TORCH_PROFILER_DIR` environment variable to the directory where you want to save the traces: `VLLM_TORCH_PROFILER_DIR=/mnt/traces/`
8+
We support tracing vLLM workers using the `torch.profiler` module. You can enable tracing by setting the `VLLM_TORCH_PROFILER_DIR` environment variable to the directory where you want to save the traces: `VLLM_TORCH_PROFILER_DIR=/mnt/traces/`. Additionally, you can control the profiling content by specifying the following environment variables:
9+
10+
`VLLM_TORCH_PROFILER_RECORD_SHAPES=1` to enable recording Tensor Shapes, off by default
11+
12+
`VLLM_TORCH_PROFILER_WITH_PROFILE_MEMORY=1` to record memory, off by default
13+
14+
`VLLM_TORCH_PROFILER_WITH_STACK=1` to enable recording stack information, on by default (set `VLLM_TORCH_PROFILER_WITH_STACK=0` to disable it)
15+
16+
`VLLM_TORCH_PROFILER_WITH_FLOPS=1` to enable recording FLOPs, off by default
917

1018
The OpenAI server also needs to be started with the `VLLM_TORCH_PROFILER_DIR` environment variable set.
1119

vllm/envs.py

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -79,6 +79,10 @@
7979
VLLM_PLUGINS: Optional[list[str]] = None
8080
VLLM_LORA_RESOLVER_CACHE_DIR: Optional[str] = None
8181
VLLM_TORCH_PROFILER_DIR: Optional[str] = None
82+
VLLM_TORCH_PROFILER_RECORD_SHAPES: bool = False
83+
VLLM_TORCH_PROFILER_WITH_PROFILE_MEMORY: bool = False
84+
VLLM_TORCH_PROFILER_WITH_STACK: bool = True
85+
VLLM_TORCH_PROFILER_WITH_FLOPS: bool = False
8286
VLLM_USE_TRITON_AWQ: bool = False
8387
VLLM_ALLOW_RUNTIME_LORA_UPDATING: bool = False
8488
VLLM_SKIP_P2P_CHECK: bool = False
@@ -621,6 +625,26 @@ def get_vllm_port() -> Optional[int]:
621625
lambda: (None if os.getenv("VLLM_TORCH_PROFILER_DIR", None) is None else os
622626
.path.expanduser(os.getenv("VLLM_TORCH_PROFILER_DIR", "."))),
623627

628+
# Enable torch profiler to record shapes if set VLLM_TORCH_PROFILER_RECORD_SHAPES=1.
629+
# If not set, torch profiler will not record shapes.
630+
"VLLM_TORCH_PROFILER_RECORD_SHAPES":
631+
lambda: bool(os.getenv("VLLM_TORCH_PROFILER_RECORD_SHAPES", "0") != "0"),
632+
633+
# Enable torch profiler to profile memory if set VLLM_TORCH_PROFILER_WITH_PROFILE_MEMORY=1.
634+
# If not set, torch profiler will not profile memory.
635+
"VLLM_TORCH_PROFILER_WITH_PROFILE_MEMORY":
636+
lambda: bool(os.getenv("VLLM_TORCH_PROFILER_WITH_PROFILE_MEMORY", "0") != "0"),
637+
638+
# Torch profiler records stack information unless VLLM_TORCH_PROFILER_WITH_STACK=0 is set.
639+
# If not set, torch profiler WILL profile stack by default.
640+
"VLLM_TORCH_PROFILER_WITH_STACK":
641+
lambda: bool(os.getenv("VLLM_TORCH_PROFILER_WITH_STACK", "1") != "0"),
642+
643+
# Enable torch profiler to profile flops if set VLLM_TORCH_PROFILER_WITH_FLOPS=1.
644+
# If not set, torch profiler will not profile flops.
645+
"VLLM_TORCH_PROFILER_WITH_FLOPS":
646+
lambda: bool(os.getenv("VLLM_TORCH_PROFILER_WITH_FLOPS", "0") != "0"),
647+
624648
# If set, vLLM will use Triton implementations of AWQ.
625649
"VLLM_USE_TRITON_AWQ":
626650
lambda: bool(int(os.getenv("VLLM_USE_TRITON_AWQ", "0"))),

vllm/v1/worker/gpu_worker.py

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -71,12 +71,22 @@ def __init__(
7171
torch_profiler_trace_dir = envs.VLLM_TORCH_PROFILER_DIR
7272
logger.info("Profiling enabled. Traces will be saved to: %s",
7373
torch_profiler_trace_dir)
74+
logger.debug(
75+
"Profiler config: record_shapes=%s, profile_memory=%s, with_stack=%s, with_flops=%s",
76+
envs.VLLM_TORCH_PROFILER_RECORD_SHAPES,
77+
envs.VLLM_TORCH_PROFILER_WITH_PROFILE_MEMORY,
78+
envs.VLLM_TORCH_PROFILER_WITH_STACK,
79+
envs.VLLM_TORCH_PROFILER_WITH_FLOPS,
80+
)
7481
self.profiler = torch.profiler.profile(
7582
activities=[
7683
torch.profiler.ProfilerActivity.CPU,
7784
torch.profiler.ProfilerActivity.CUDA,
7885
],
79-
with_stack=True,
86+
record_shapes=envs.VLLM_TORCH_PROFILER_RECORD_SHAPES,
87+
profile_memory=envs.VLLM_TORCH_PROFILER_WITH_PROFILE_MEMORY,
88+
with_stack=envs.VLLM_TORCH_PROFILER_WITH_STACK,
89+
with_flops=envs.VLLM_TORCH_PROFILER_WITH_FLOPS,
8090
on_trace_ready=torch.profiler.tensorboard_trace_handler(
8191
torch_profiler_trace_dir, use_gzip=True))
8292
else:
@@ -209,7 +219,7 @@ def reload_weights(self) -> None:
209219

210220
@torch.inference_mode()
211221
def determine_available_memory(self) -> int:
212-
"""Profiles the peak memory usage of the model to determine how much
222+
"""Profiles the peak memory usage of the model to determine how much
213223
memory can be used for KV cache without OOMs.
214224
215225
The engine will first conduct a profiling of the existing memory usage.

vllm/v1/worker/xpu_worker.py

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -41,12 +41,22 @@ def __init__(
4141
torch_profiler_trace_dir = envs.VLLM_TORCH_PROFILER_DIR
4242
logger.info("Profiling enabled. Traces will be saved to: %s",
4343
torch_profiler_trace_dir)
44+
logger.debug(
45+
"Profiler config: record_shapes=%s, profile_memory=%s, with_stack=%s, with_flops=%s",
46+
envs.VLLM_TORCH_PROFILER_RECORD_SHAPES,
47+
envs.VLLM_TORCH_PROFILER_WITH_PROFILE_MEMORY,
48+
envs.VLLM_TORCH_PROFILER_WITH_STACK,
49+
envs.VLLM_TORCH_PROFILER_WITH_FLOPS,
50+
)
4451
self.profiler = torch.profiler.profile(
4552
activities=[
4653
torch.profiler.ProfilerActivity.CPU,
4754
torch.profiler.ProfilerActivity.XPU,
4855
],
49-
with_stack=True,
56+
record_shapes=envs.VLLM_TORCH_PROFILER_RECORD_SHAPES,
57+
profile_memory=envs.VLLM_TORCH_PROFILER_WITH_PROFILE_MEMORY,
58+
with_stack=envs.VLLM_TORCH_PROFILER_WITH_STACK,
59+
with_flops=envs.VLLM_TORCH_PROFILER_WITH_FLOPS,
5060
on_trace_ready=torch.profiler.tensorboard_trace_handler(
5161
torch_profiler_trace_dir, use_gzip=True))
5262
else:

0 commit comments

Comments
 (0)