12
12
from vllm .compilation .counter import compilation_counter
13
13
from vllm .compilation .decorators import (ignore_torch_compile ,
14
14
support_torch_compile )
15
- from vllm .config import (CompilationConfig , CompilationLevel , VllmConfig ,
16
- set_current_vllm_config )
17
- from vllm .envs import VLLM_USE_V1
18
- from vllm .forward_context import set_forward_context
15
+ from vllm .config import (CompilationConfig , CompilationLevel , CUDAGraphMode ,
16
+ VllmConfig , set_current_vllm_config )
17
+ from vllm .forward_context import BatchDescriptor , set_forward_context
19
18
from vllm .utils import direct_register_custom_op
20
19
21
20
# create a library to hold the custom op
@@ -164,104 +163,34 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
164
163
return x
165
164
166
165
167
def test_ignore_torch_compile_decorator():
    """Check decorator precedence: the innermost ``@support_torch_compile`` /
    ``@ignore_torch_compile`` on a class wins over what it inherits.

    Verified via ``compilation_counter``: A (support) compiles and captures
    cudagraphs, B (ignore, subclass of A) compiles nothing, and C (support,
    subclass of B) compiles again.
    """
    assert VLLM_USE_V1

    # Piecewise compilation config with cudagraph capture at sizes 1 and 2,
    # splitting the graph around the custom silly.attention op.
    vllm_config = VllmConfig(compilation_config=CompilationConfig(
        level=CompilationLevel.PIECEWISE,
        use_cudagraph=True,
        splitting_ops=["silly.attention"],
        cudagraph_capture_sizes=[1, 2],
    ))

    @support_torch_compile
    class A(nn.Module):

        def __init__(self,
                     *,
                     vllm_config: VllmConfig,
                     prefix: str = '',
                     **kwargs) -> None:
            super().__init__()

        def forward(self, x: torch.Tensor) -> torch.Tensor:
            x = x + x
            attn_output = torch.empty_like(x)
            torch.ops.silly.attention(x, x, x, attn_output)
            x = attn_output
            x = x * 3
            return x

    @ignore_torch_compile
    class B(A):
        ...

    @support_torch_compile
    class C(B):
        ...

    with set_current_vllm_config(vllm_config):
        mod_A = A(vllm_config=vllm_config, prefix='').eval().cuda()

    # A has support_torch_compile: one graph, split into 3 pieces (2 of
    # which are capturable), 2 backend compilations, and
    # num_cudagraph_sizes * num_piecewise_capturable_graphs_seen = 2 * 2
    # cudagraphs captured.
    with compilation_counter.expect(
            num_graphs_seen=1,
            num_piecewise_graphs_seen=3,
            num_piecewise_capturable_graphs_seen=2,
            num_backend_compilations=2,
            num_cudagraph_captured=4,
    ), set_forward_context({}, vllm_config=vllm_config):
        # First call triggers compilation; the next two hit the
        # cudagraph-captured batch sizes.
        mod_A(torch.randn(BATCH_SIZE, MLP_SIZE).cuda())
        mod_A(torch.randn(2, MLP_SIZE).cuda())
        mod_A(torch.randn(1, MLP_SIZE).cuda())

    with set_current_vllm_config(vllm_config):
        mod_B = B(vllm_config=vllm_config, prefix='').eval().cuda()

    # B's ignore_torch_compile should override A's support_torch_compile:
    # nothing is compiled or captured.
    with compilation_counter.expect(
            num_graphs_seen=0,
            num_piecewise_graphs_seen=0,
            num_piecewise_capturable_graphs_seen=0,
            num_backend_compilations=0,
            num_cudagraph_captured=0,
    ), set_forward_context({}, vllm_config=vllm_config):
        mod_B(torch.randn(BATCH_SIZE, MLP_SIZE).cuda())
        mod_B(torch.randn(2, MLP_SIZE).cuda())
        mod_B(torch.randn(1, MLP_SIZE).cuda())

    with set_current_vllm_config(vllm_config):
        mod_C = C(vllm_config=vllm_config, prefix='').eval().cuda()

    # C's support_torch_compile should override B's ignore_torch_compile:
    # same counters as for A.
    with compilation_counter.expect(
            num_graphs_seen=1,
            num_piecewise_graphs_seen=3,
            num_piecewise_capturable_graphs_seen=2,
            num_backend_compilations=2,
            num_cudagraph_captured=4,
    ), set_forward_context({}, vllm_config=vllm_config):
        mod_C(torch.randn(BATCH_SIZE, MLP_SIZE).cuda())
        mod_C(torch.randn(2, MLP_SIZE).cuda())
        mod_C(torch.randn(1, MLP_SIZE).cuda())
254
166
@torch.inference_mode
def run_model(vllm_config: VllmConfig, model: nn.Module, inputs: torch.Tensor,
              cudagraph_runtime_mode: CUDAGraphMode):
    """Run ``model`` through the warmup / capture / replay sequence and
    return the final output on CPU.

    Args:
        vllm_config: Config installed into the forward context for each call.
        model: Module to execute; called directly with tensor batches.
        inputs: Input batch; slices of size 1 and 2 are taken from its
            leading dimension, so it must have at least 2 rows.
        cudagraph_runtime_mode: Mode passed to ``set_forward_context`` for
            the capture/replay calls (e.g. PIECEWISE or NONE).

    Returns:
        The model output for ``inputs[:2]`` from the replay call, moved to
        CPU so callers can compare outputs across configs.
    """
    # Warmup call: no cudagraph runtime mode, just the plain forward
    # (this is also where compilation happens for compiled configs).
    with set_forward_context({}, vllm_config=vllm_config):
        model(inputs)

    # Simulate cudagraph capturing: one call per capture size, each with a
    # BatchDescriptor matching its token count.
    with set_forward_context({},
                             vllm_config=vllm_config,
                             cudagraph_runtime_mode=cudagraph_runtime_mode,
                             batch_descriptor=BatchDescriptor(
                                 num_tokens=2, )):
        model(inputs[:2])
    with set_forward_context({},
                             vllm_config=vllm_config,
                             cudagraph_runtime_mode=cudagraph_runtime_mode,
                             batch_descriptor=BatchDescriptor(
                                 num_tokens=1, )):
        model(inputs[:1])

    # Simulate cudagraph replay at batch size 2 and keep its output.
    with set_forward_context({},
                             vllm_config=vllm_config,
                             cudagraph_runtime_mode=cudagraph_runtime_mode,
                             batch_descriptor=BatchDescriptor(
                                 num_tokens=2, )):
        output = model(inputs[:2])

    # Single .cpu() is sufficient (the original called it twice; the second
    # call on an already-CPU tensor is a no-op returning the same tensor).
    return output.cpu()
@@ -277,6 +206,7 @@ def test_multi_graph_piecewise_compile_outputs_equal():
277
206
splitting_ops = ["silly.attention" ],
278
207
cudagraph_capture_sizes = [1 , 2 ],
279
208
))
209
+ cudagraph_runtime_mode = CUDAGraphMode .PIECEWISE
280
210
281
211
with set_current_vllm_config (vllm_config ):
282
212
model = SimpleModelWithTwoGraphs (mlp_size = MLP_SIZE ,
@@ -299,11 +229,13 @@ def test_multi_graph_piecewise_compile_outputs_equal():
299
229
num_cudagraph_captured = 8 ,
300
230
# num_cudagraph_sizes * num_piecewise_capturable_graphs_seen
301
231
):
302
- outputs .append (run_model (vllm_config , model , inputs ))
232
+ outputs .append (
233
+ run_model (vllm_config , model , inputs , cudagraph_runtime_mode ))
303
234
304
235
# no compile or cudagraph
305
236
vllm_config = VllmConfig (compilation_config = CompilationConfig (
306
237
level = CompilationLevel .NO_COMPILATION , ))
238
+ cudagraph_runtime_mode = CUDAGraphMode .NONE
307
239
308
240
with set_current_vllm_config (vllm_config ):
309
241
model = SimpleModelWithTwoGraphs (mlp_size = MLP_SIZE ,
@@ -318,14 +250,16 @@ def test_multi_graph_piecewise_compile_outputs_equal():
318
250
num_backend_compilations = 0 ,
319
251
num_cudagraph_captured = 0 ,
320
252
):
321
- outputs .append (run_model (vllm_config , model , inputs ))
253
+ outputs .append (
254
+ run_model (vllm_config , model , inputs , cudagraph_runtime_mode ))
322
255
323
256
# piecewise compile without CUDA graph
324
257
vllm_config = VllmConfig (compilation_config = CompilationConfig (
325
258
level = CompilationLevel .PIECEWISE ,
326
259
use_cudagraph = False ,
327
260
splitting_ops = ["silly.attention" ],
328
261
))
262
+ cudagraph_runtime_mode = CUDAGraphMode .PIECEWISE
329
263
330
264
with set_current_vllm_config (vllm_config ):
331
265
model = SimpleModelWithTwoGraphs (mlp_size = MLP_SIZE ,
@@ -340,7 +274,8 @@ def test_multi_graph_piecewise_compile_outputs_equal():
340
274
num_backend_compilations = 4 ,
341
275
num_cudagraph_captured = 0 , # no cudagraph captured
342
276
):
343
- outputs .append (run_model (vllm_config , model , inputs ))
277
+ outputs .append (
278
+ run_model (vllm_config , model , inputs , cudagraph_runtime_mode ))
344
279
345
280
# Generally don't expect outputs with and without inductor
346
281
# to be bitwise equivalent
0 commit comments