
Commit 70c90cd

sarckk authored and epwalsh committed
[torch.compile] Support conditional torch.compile per module (vllm-project#22269)
Signed-off-by: Yong Hoon Shin <[email protected]>
1 parent bc24971 · commit 70c90cd
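
The feature under test here is the decorator pair for conditional per-module torch.compile. The test removed from test_multiple_graphs.py below (and, judging by the pipeline change, relocated to compile/test_decorator.py) exercises the composition rules. A minimal sketch of those rules, with illustrative class names that are not in the commit:

    import torch
    from torch import nn

    from vllm.compilation.decorators import (ignore_torch_compile,
                                             support_torch_compile)
    from vllm.config import VllmConfig


    @support_torch_compile
    class CompiledBlock(nn.Module):
        # compiled (piecewise) whenever compilation is enabled in vllm_config

        def __init__(self, *, vllm_config: VllmConfig, prefix: str = "") -> None:
            super().__init__()

        def forward(self, x: torch.Tensor) -> torch.Tensor:
            return x * 2


    @ignore_torch_compile
    class EagerBlock(CompiledBlock):
        # ignore_torch_compile overrides the inherited support_torch_compile,
        # so this subclass always runs eagerly
        ...


    @support_torch_compile
    class RecompiledBlock(EagerBlock):
        # support_torch_compile on a further subclass overrides the inherited
        # ignore, re-enabling compilation
        ...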

4 files changed: +308 / -103 lines

.buildkite/test-pipeline.yaml (2 additions, 0 deletions)

@@ -328,6 +328,7 @@ steps:
     - pytest -v -s compile/test_sequence_parallelism.py
     - pytest -v -s compile/test_async_tp.py
     - pytest -v -s compile/test_fusion_all_reduce.py
+    - pytest -v -s compile/test_decorator.py
 
 - label: PyTorch Fullgraph Smoke Test # 9min
   mirror_hardwares: [amdexperimental]
@@ -341,6 +342,7 @@ steps:
     - pytest -v -s compile/piecewise/test_simple.py
     - pytest -v -s compile/piecewise/test_toy_llama.py
     - pytest -v -s compile/piecewise/test_full_cudagraph.py
+    - pytest -v -s compile/piecewise/test_multiple_graphs.py
 
 - label: PyTorch Fullgraph Test # 18min
   mirror_hardwares: [amdexperimental]
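
To reproduce the two newly wired-in steps locally (assuming, as the sibling steps suggest, that the paths are relative to the tests/ directory):

    pytest -v -s compile/test_decorator.py
    pytest -v -s compile/piecewise/test_multiple_graphs.py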

tests/compile/piecewise/test_multiple_graphs.py (36 additions, 101 deletions)

@@ -12,10 +12,9 @@
 from vllm.compilation.counter import compilation_counter
 from vllm.compilation.decorators import (ignore_torch_compile,
                                          support_torch_compile)
-from vllm.config import (CompilationConfig, CompilationLevel, VllmConfig,
-                         set_current_vllm_config)
-from vllm.envs import VLLM_USE_V1
-from vllm.forward_context import set_forward_context
+from vllm.config import (CompilationConfig, CompilationLevel, CUDAGraphMode,
+                         VllmConfig, set_current_vllm_config)
+from vllm.forward_context import BatchDescriptor, set_forward_context
 from vllm.utils import direct_register_custom_op
 
 # create a library to hold the custom op
@@ -164,104 +163,34 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
         return x
 
 
-def test_ignore_torch_compile_decorator():
-    assert VLLM_USE_V1
-
-    # piecewise
-    vllm_config = VllmConfig(compilation_config=CompilationConfig(
-        level=CompilationLevel.PIECEWISE,
-        use_cudagraph=True,
-        splitting_ops=["silly.attention"],
-        cudagraph_capture_sizes=[1, 2],
-    ))
-
-    @support_torch_compile
-    class A(nn.Module):
-
-        def __init__(self,
-                     *,
-                     vllm_config: VllmConfig,
-                     prefix: str = '',
-                     **kwargs) -> None:
-            super().__init__()
-
-        def forward(self, x: torch.Tensor) -> torch.Tensor:
-            x = x + x
-            attn_output = torch.empty_like(x)
-            torch.ops.silly.attention(x, x, x, attn_output)
-            x = attn_output
-            x = x * 3
-            return x
-
-    @ignore_torch_compile
-    class B(A):
-        ...
-
-    @support_torch_compile
-    class C(B):
-        ...
-
-    with set_current_vllm_config(vllm_config):
-        mod_A = A(vllm_config=vllm_config, prefix='').eval().cuda()
-
-    # A has support_torch_compile
-    with compilation_counter.expect(
-            num_graphs_seen=1,
-            num_piecewise_graphs_seen=3,
-            num_piecewise_capturable_graphs_seen=2,
-            num_backend_compilations=2,
-            num_cudagraph_captured=4,
-            # num_cudagraph_sizes * num_piecewise_capturable_graphs_seen
-    ), set_forward_context({}, vllm_config=vllm_config):
-        # first run is for compile
-        mod_A(torch.randn(BATCH_SIZE, MLP_SIZE).cuda())
-        # run cudagraph captured sizes
-        mod_A(torch.randn(2, MLP_SIZE).cuda())
-        mod_A(torch.randn(1, MLP_SIZE).cuda())
-
-    with set_current_vllm_config(vllm_config):
-        mod_B = B(vllm_config=vllm_config, prefix='').eval().cuda()
-
-    # B's ignore_torch_compile should override A's support_torch_compile
-    with compilation_counter.expect(
-            num_graphs_seen=0,
-            num_piecewise_graphs_seen=0,
-            num_piecewise_capturable_graphs_seen=0,
-            num_backend_compilations=0,
-            num_cudagraph_captured=0,
-    ), set_forward_context({}, vllm_config=vllm_config):
-        mod_B(torch.randn(BATCH_SIZE, MLP_SIZE).cuda())
-        mod_B(torch.randn(2, MLP_SIZE).cuda())
-        mod_B(torch.randn(1, MLP_SIZE).cuda())
-
-    with set_current_vllm_config(vllm_config):
-        mod_C = C(vllm_config=vllm_config, prefix='').eval().cuda()
-
-    # C's support_torch_compile should override B's ignore_torch_compile
-    with compilation_counter.expect(
-            num_graphs_seen=1,
-            num_piecewise_graphs_seen=3,
-            num_piecewise_capturable_graphs_seen=2,
-            num_backend_compilations=2,
-            num_cudagraph_captured=4,
-            # num_cudagraph_sizes * num_piecewise_capturable_graphs_seen
-    ), set_forward_context({}, vllm_config=vllm_config):
-        mod_C(torch.randn(BATCH_SIZE, MLP_SIZE).cuda())
-        mod_C(torch.randn(2, MLP_SIZE).cuda())
-        mod_C(torch.randn(1, MLP_SIZE).cuda())
-
-
 @torch.inference_mode
-def run_model(vllm_config, model: nn.Module, inputs: torch.Tensor):
+def run_model(vllm_config: VllmConfig, model: nn.Module, inputs: torch.Tensor,
+              cudagraph_runtime_mode: CUDAGraphMode):
     with set_forward_context({}, vllm_config=vllm_config):
-        # First run is for compile
+        # warmup for the model with cudagraph_mode NONE
         model(inputs)
 
-        # Run CUDAGraph captured sizes
-        model(inputs[:2])
-        model(inputs[:1])
-
-        output = model(inputs[:2])
+    # simulate cudagraphs capturing
+    with set_forward_context({},
+                             vllm_config=vllm_config,
+                             cudagraph_runtime_mode=cudagraph_runtime_mode,
+                             batch_descriptor=BatchDescriptor(
+                                 num_tokens=2, )):
+        model(inputs[:2])
+    with set_forward_context({},
+                             vllm_config=vllm_config,
+                             cudagraph_runtime_mode=cudagraph_runtime_mode,
+                             batch_descriptor=BatchDescriptor(
+                                 num_tokens=1, )):
+        model(inputs[:1])
+
+    # simulate cudagraphs replay
+    with set_forward_context({},
+                             vllm_config=vllm_config,
+                             cudagraph_runtime_mode=cudagraph_runtime_mode,
+                             batch_descriptor=BatchDescriptor(
+                                 num_tokens=2, )):
+        output = model(inputs[:2])
 
     output = output.cpu()
     return output.cpu()
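
The rewritten run_model mirrors the runtime's three phases: an eager warmup pass, one capture pass per cudagraph size under a forward context whose BatchDescriptor matches the token count, and a replay pass under the same descriptor. The repeated with-blocks could be collapsed into a helper like the sketch below; run_for_size is illustrative and not part of this commit:

    import torch
    from torch import nn

    from vllm.config import CUDAGraphMode, VllmConfig
    from vllm.forward_context import BatchDescriptor, set_forward_context


    def run_for_size(model: nn.Module, inputs: torch.Tensor,
                     vllm_config: VllmConfig, mode: CUDAGraphMode,
                     num_tokens: int) -> torch.Tensor:
        # each forward runs under a context whose BatchDescriptor matches
        # the number of tokens in the input slice it is fed
        with set_forward_context({},
                                 vllm_config=vllm_config,
                                 cudagraph_runtime_mode=mode,
                                 batch_descriptor=BatchDescriptor(
                                     num_tokens=num_tokens)):
            return model(inputs[:num_tokens])
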
@@ -277,6 +206,7 @@ def test_multi_graph_piecewise_compile_outputs_equal():
         splitting_ops=["silly.attention"],
         cudagraph_capture_sizes=[1, 2],
     ))
+    cudagraph_runtime_mode = CUDAGraphMode.PIECEWISE
 
     with set_current_vllm_config(vllm_config):
         model = SimpleModelWithTwoGraphs(mlp_size=MLP_SIZE,
@@ -299,11 +229,13 @@ def test_multi_graph_piecewise_compile_outputs_equal():
             num_cudagraph_captured=8,
             # num_cudagraph_sizes * num_piecewise_capturable_graphs_seen
     ):
-        outputs.append(run_model(vllm_config, model, inputs))
+        outputs.append(
+            run_model(vllm_config, model, inputs, cudagraph_runtime_mode))
 
     # no compile or cudagraph
     vllm_config = VllmConfig(compilation_config=CompilationConfig(
         level=CompilationLevel.NO_COMPILATION, ))
+    cudagraph_runtime_mode = CUDAGraphMode.NONE
 
     with set_current_vllm_config(vllm_config):
         model = SimpleModelWithTwoGraphs(mlp_size=MLP_SIZE,
@@ -318,14 +250,16 @@ def test_multi_graph_piecewise_compile_outputs_equal():
             num_backend_compilations=0,
             num_cudagraph_captured=0,
     ):
-        outputs.append(run_model(vllm_config, model, inputs))
+        outputs.append(
+            run_model(vllm_config, model, inputs, cudagraph_runtime_mode))
 
     # piecewise compile without CUDA graph
     vllm_config = VllmConfig(compilation_config=CompilationConfig(
         level=CompilationLevel.PIECEWISE,
         use_cudagraph=False,
         splitting_ops=["silly.attention"],
     ))
+    cudagraph_runtime_mode = CUDAGraphMode.PIECEWISE
 
     with set_current_vllm_config(vllm_config):
         model = SimpleModelWithTwoGraphs(mlp_size=MLP_SIZE,
@@ -340,7 +274,8 @@ def test_multi_graph_piecewise_compile_outputs_equal():
             num_backend_compilations=4,
             num_cudagraph_captured=0,  # no cudagraph captured
     ):
-        outputs.append(run_model(vllm_config, model, inputs))
+        outputs.append(
+            run_model(vllm_config, model, inputs, cudagraph_runtime_mode))
 
     # Generally don't expect outputs with and without inductor
     # to be bitwise equivalent
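
For reference, the three configurations this test now compares, each paired with the CUDAGraphMode it threads into run_model (restated from the diff above, not new code in the commit):

    from vllm.config import CompilationConfig, CompilationLevel, CUDAGraphMode

    # 1. piecewise compile with cudagraphs: capture and replay in PIECEWISE mode
    piecewise_cfg = CompilationConfig(level=CompilationLevel.PIECEWISE,
                                      use_cudagraph=True,
                                      splitting_ops=["silly.attention"],
                                      cudagraph_capture_sizes=[1, 2])
    piecewise_mode = CUDAGraphMode.PIECEWISE

    # 2. no compilation: nothing is captured, so the mode is NONE
    eager_cfg = CompilationConfig(level=CompilationLevel.NO_COMPILATION)
    eager_mode = CUDAGraphMode.NONE

    # 3. piecewise compile without cudagraphs: still driven in PIECEWISE mode,
    #    and the counter asserts num_cudagraph_captured == 0
    no_graphs_cfg = CompilationConfig(level=CompilationLevel.PIECEWISE,
                                      use_cudagraph=False,
                                      splitting_ops=["silly.attention"])
    no_graphs_mode = CUDAGraphMode.PIECEWISE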
