
Commit 04e3850

[Bugfix] VLLM_V1 supports passing other compilation levels (#19340)
Signed-off-by: Richard Zou <[email protected]>
1 parent ab71413 commit 04e3850

5 files changed: +88 −5 lines changed

tests/compile/test_config.py

Lines changed: 53 additions & 2 deletions
@@ -26,15 +26,17 @@ def test_use_cudagraphs_dynamic(monkeypatch):
     assert not vllm_config.compilation_config.use_cudagraph
 
 
+# forked needed to workaround https://github.com/vllm-project/vllm/issues/21073
+@pytest.mark.forked
 # NB: We don't test VLLM_DISABLE_COMPILE_CACHE=0 because that depends
 # on the state of the cache directory on the current machine, which
 # may be influenced by other tests.
 @pytest.mark.parametrize("val", ["1"])
 def test_VLLM_DISABLE_COMPILE_CACHE(vllm_runner, monkeypatch, val):
     assert vllm.envs.VLLM_USE_V1
 
-    # spawn means that the counters are in the same process.
-    monkeypatch.setenv('VLLM_WORKER_MULTIPROC_METHOD', "spawn")
+    # Disable multiprocessing so that the counter is in the same process
+    monkeypatch.setenv('VLLM_ENABLE_V1_MULTIPROCESSING', '0')
     monkeypatch.setenv('VLLM_DISABLE_COMPILE_CACHE', val)
 
     compilation_config = {
@@ -50,6 +52,8 @@ def test_VLLM_DISABLE_COMPILE_CACHE(vllm_runner, monkeypatch, val):
         pass
 
 
+# forked needed to workaround https://github.com/vllm-project/vllm/issues/21073
+@pytest.mark.forked
 @pytest.mark.parametrize("enabled", [True, False])
 def test_use_cudagraphs(vllm_runner, monkeypatch, enabled):
     assert vllm.envs.VLLM_USE_V1
@@ -72,3 +76,50 @@ def test_use_cudagraphs(vllm_runner, monkeypatch, enabled):
             compilation_config=compilation_config,
             gpu_memory_utilization=0.4) as _):
         pass
+
+
+# forked needed to workaround https://github.com/vllm-project/vllm/issues/21073
+@pytest.mark.forked
+def test_dynamo_as_is(vllm_runner, monkeypatch):
+    # Disable multiprocessing so that the counter is in the same process
+    monkeypatch.setenv('VLLM_ENABLE_V1_MULTIPROCESSING', '0')
+
+    with (
+            compilation_counter.expect(dynamo_as_is_count=1),
+            # loading the model causes compilation (if enabled) to happen
+            vllm_runner('facebook/opt-125m',
+                        compilation_config={"level": 1},
+                        gpu_memory_utilization=0.4) as _):
+        pass
+
+
+# forked needed to workaround https://github.com/vllm-project/vllm/issues/21073
+@pytest.mark.forked
+def test_no_compilation(vllm_runner, monkeypatch):
+    # Disable multiprocessing so that the counter is in the same process
+    monkeypatch.setenv('VLLM_ENABLE_V1_MULTIPROCESSING', '0')
+
+    with (
+            compilation_counter.expect(num_graphs_seen=0,
+                                       dynamo_as_is_count=0),
+            # loading the model causes compilation (if enabled) to happen
+            vllm_runner('facebook/opt-125m',
+                        compilation_config={"level": 0},
+                        gpu_memory_utilization=0.4) as _):
+        pass
+
+
+# forked needed to workaround https://github.com/vllm-project/vllm/issues/21073
+@pytest.mark.forked
+def test_enforce_eager(vllm_runner, monkeypatch):
+    # Disable multiprocessing so that the counter is in the same process
+    monkeypatch.setenv('VLLM_ENABLE_V1_MULTIPROCESSING', '0')
+
+    with (
+            compilation_counter.expect(num_graphs_seen=0,
+                                       dynamo_as_is_count=0),
+            # loading the model causes compilation (if enabled) to happen
+            vllm_runner('facebook/opt-125m',
+                        enforce_eager=True,
+                        gpu_memory_utilization=0.4) as _):
+        pass
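
The tests above exercise the fix end to end through the vllm_runner fixture. As a minimal sketch of the same behavior from user code (assuming compilation_config is forwarded to the LLM constructor the same way the fixture forwards it):

    from vllm import LLM

    # level 1 (DYNAMO_AS_IS) is now honored by the V1 engine instead of
    # being silently overwritten with PIECEWISE; level 0 disables compilation.
    llm = LLM(model="facebook/opt-125m",
              compilation_config={"level": 1},
              gpu_memory_utilization=0.4)
    print(llm.generate("Hello, my name is"))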

vllm/compilation/counter.py

Lines changed: 2 additions & 0 deletions
@@ -27,6 +27,8 @@ class CompilationCounter:
     num_cache_entries_updated: int = 0
     # The number of standalone_compile compiled artifacts saved
     num_compiled_artifacts_saved: int = 0
+    # Number of times a model was loaded with CompilationLevel.DYNAMO_AS_IS
+    dynamo_as_is_count: int = 0
 
     def clone(self) -> "CompilationCounter":
         return copy.deepcopy(self)
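
The tests consume this new field through compilation_counter.expect(...). The real helper lives elsewhere in counter.py; a hedged sketch of how such an expect() context manager can be built on this dataclass (the field names match the diff, the body is an assumption):

    import copy
    import dataclasses
    from contextlib import contextmanager

    @dataclasses.dataclass
    class CompilationCounter:
        num_graphs_seen: int = 0
        dynamo_as_is_count: int = 0

        def clone(self) -> "CompilationCounter":
            return copy.deepcopy(self)

        @contextmanager
        def expect(self, **deltas: int):
            # Snapshot on entry; on exit, each named field must have grown
            # by exactly the expected delta.
            before = self.clone()
            yield
            for name, want in deltas.items():
                got = getattr(self, name) - getattr(before, name)
                assert got == want, f"{name}: expected +{want}, got +{got}"

    compilation_counter = CompilationCounter()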

vllm/config.py

Lines changed: 19 additions & 2 deletions
@@ -4106,9 +4106,11 @@ class CompilationConfig:
     certain small batchsizes, where inductor is good at optimizing.
     """
     # Top-level Compilation control
-    level: int = 0
+    level: Optional[int] = None
     """The level of compilation:
 
+    - None: If None, we will select the default compilation level.
+      For V1 engine this is 3, for V0 engine this is 0.
     - 0: no compilation.
     - 1: dynamo as is.
     - 2: dynamo once.
@@ -4664,6 +4666,22 @@ def __post_init__(self):
                 "To workaround this limitation, vLLM will set 'ieee' input "
                 "precision for chunked prefill triton kernels.")
 
+        # If the user does not explicitly set a compilation level, then
+        # we use the default level. The default level depends on other
+        # settings (see the below code).
+        if self.compilation_config.level is None:
+            if envs.VLLM_USE_V1:
+                if (self.model_config is not None
+                        and not self.model_config.enforce_eager):
+                    self.compilation_config.level = CompilationLevel.PIECEWISE
+                else:
+                    self.compilation_config.level = \
+                        CompilationLevel.NO_COMPILATION
+            else:
+                # NB: Passing both --enforce-eager and a compilation level
+                # in V0 means the compilation level wins out.
+                self.compilation_config.level = CompilationLevel.NO_COMPILATION
+
         # async tp is built on top of sequence parallelism
         # and requires it to be enabled.
         if self.compilation_config.pass_config.enable_async_tp:
@@ -4676,7 +4694,6 @@ def __post_init__(self):
             # By default, V1 uses piecewise CUDA graphs. If full_cuda_graph
             # is set to True, full CUDA graphs will be used.
             self.compilation_config.cudagraph_num_of_warmups = 1
-            self.compilation_config.level = CompilationLevel.PIECEWISE
             self.compilation_config.set_splitting_ops_for_v1()
 
         self._set_cudagraph_sizes()
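
Condensed, the new __post_init__ branch implements this decision table for the default level (a standalone sketch; resolve_default_level is a hypothetical helper, and the numeric values follow the CompilationLevel docstring above):

    from enum import IntEnum
    from typing import Optional

    class CompilationLevel(IntEnum):
        NO_COMPILATION = 0
        DYNAMO_AS_IS = 1
        DYNAMO_ONCE = 2
        PIECEWISE = 3

    def resolve_default_level(level: Optional[int], use_v1: bool,
                              enforce_eager: bool) -> int:
        # An explicit level always wins; defaults only fill in None.
        if level is not None:
            return level
        if use_v1 and not enforce_eager:
            return CompilationLevel.PIECEWISE
        return CompilationLevel.NO_COMPILATION

    assert resolve_default_level(None, use_v1=True, enforce_eager=False) == 3
    assert resolve_default_level(None, use_v1=True, enforce_eager=True) == 0
    assert resolve_default_level(1, use_v1=True, enforce_eager=False) == 1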

vllm/v1/worker/gpu_model_runner.py

Lines changed: 12 additions & 1 deletion
@@ -43,7 +43,7 @@
 from vllm.tasks import GenerationTask, PoolingTask, SupportedTask
 from vllm.utils import (STR_DTYPE_TO_TORCH_DTYPE, DeviceMemoryProfiler,
                         GiB_bytes, LazyLoader, check_use_alibi, get_dtype_size,
-                        is_pin_memory_available, round_up)
+                        is_pin_memory_available, round_up, supports_dynamo)
 from vllm.v1.attention.backends.mamba_selectors import get_mamba_attn_backend
 from vllm.v1.attention.backends.utils import (
     AttentionMetadataBuilder, CommonAttentionMetadata,
@@ -1930,6 +1930,17 @@ def load_model(self, eep_scale_up: bool = False) -> None:
                 rank_mapping,
             )
 
+        if (
+            self.vllm_config.compilation_config.level == \
+            CompilationLevel.DYNAMO_AS_IS and supports_dynamo()
+        ):
+            backend = self.vllm_config.compilation_config.init_backend(
+                self.vllm_config)
+            compilation_counter.dynamo_as_is_count += 1
+            self.model.compile(
+                fullgraph=envs.VLLM_TEST_DYNAMO_FULLGRAPH_CAPTURE,
+                backend=backend)
+
     def reload_weights(self) -> None:
         assert getattr(self, "model", None) is not None, \
             "Cannot reload weights before model is loaded."

vllm/worker/model_runner.py

Lines changed: 2 additions & 0 deletions
@@ -22,6 +22,7 @@
 from vllm.attention import AttentionMetadata, get_attn_backend
 from vllm.attention.backends.abstract import AttentionState
 from vllm.attention.backends.utils import CommonAttentionState
+from vllm.compilation.counter import compilation_counter
 from vllm.config import CompilationLevel, VllmConfig
 from vllm.core.scheduler import SchedulerOutputs
 from vllm.distributed import broadcast_tensor_dict, get_pp_group
@@ -1121,6 +1122,7 @@ def load_model(self) -> None:
                 CompilationLevel.DYNAMO_AS_IS and supports_dynamo():
             backend = self.vllm_config.compilation_config.init_backend(
                 self.vllm_config)
+            compilation_counter.dynamo_as_is_count += 1
             self.model = torch.compile(
                 self.model,
                 fullgraph=envs.VLLM_TEST_DYNAMO_FULLGRAPH_CAPTURE,
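
For the V0 path the only functional change is the counter bump beside the pre-existing torch.compile call. Unlike nn.Module.compile(), torch.compile() returns a new wrapper object, which is why this runner reassigns self.model; a minimal illustration:

    import torch
    import torch.nn as nn

    model = nn.Linear(8, 8)
    # torch.compile returns an OptimizedModule wrapping the original module,
    # so the result must be assigned back to be used.
    compiled = torch.compile(model, fullgraph=False)
    print(type(compiled).__name__)  # OptimizedModule
    print(compiled(torch.randn(2, 8)).shape)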
