from vllm.compilation.decorators import (ignore_torch_compile,
                                         support_torch_compile)
from vllm.config import (CacheConfig, CompilationConfig, CompilationLevel,
-                         VllmConfig, set_current_vllm_config)
-from vllm.forward_context import set_forward_context
+                         CUDAGraphMode, VllmConfig, set_current_vllm_config)
+from vllm.forward_context import BatchDescriptor, set_forward_context
from vllm.utils import direct_register_custom_op

# create a library to hold the custom op
@@ -40,6 +40,39 @@ def silly_attention_fake(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor,
)


+@torch.inference_mode
+def run_model(vllm_config: VllmConfig, model: nn.Module,
+              cudagraph_runtime_mode: CUDAGraphMode):
+    with set_forward_context({}, vllm_config=vllm_config):
+        # warmup for the model with cudagraph_mode NONE
+        model(torch.randn(BATCH_SIZE, MLP_SIZE).cuda())
+
+    # simulate cudagraphs capturing
+    with set_forward_context({},
+                             vllm_config=vllm_config,
+                             cudagraph_runtime_mode=cudagraph_runtime_mode,
+                             batch_descriptor=BatchDescriptor(
+                                 num_tokens=2, )):
+        model(torch.randn(2, MLP_SIZE).cuda())
+    with set_forward_context({},
+                             vllm_config=vllm_config,
+                             cudagraph_runtime_mode=cudagraph_runtime_mode,
+                             batch_descriptor=BatchDescriptor(
+                                 num_tokens=1, )):
+        model(torch.randn(1, MLP_SIZE).cuda())
+
+    # simulate cudagraphs replay
+    with set_forward_context({},
+                             vllm_config=vllm_config,
+                             cudagraph_runtime_mode=cudagraph_runtime_mode,
+                             batch_descriptor=BatchDescriptor(
+                                 num_tokens=2, )):
+        output = model(torch.randn(2, MLP_SIZE).cuda())
+
+    output = output.cpu()
+    return output.cpu()
+
+
def test_ignore_torch_compile_decorator():
    # piecewise
    vllm_config = VllmConfig(compilation_config=CompilationConfig(
@@ -48,6 +81,7 @@ def test_ignore_torch_compile_decorator():
        splitting_ops=["silly.attention"],
        cudagraph_capture_sizes=[1, 2],
    ))
+    cudagraph_runtime_mode = CUDAGraphMode.PIECEWISE

    @support_torch_compile
    class A(nn.Module):
@@ -86,12 +120,8 @@ class C(B):
            num_backend_compilations=2,
            num_cudagraph_captured=4,
            # num_cudagraph_sizes * num_piecewise_capturable_graphs_seen
-    ), set_forward_context({}, vllm_config=vllm_config):
-        # first run is for compile
-        mod_A(torch.randn(BATCH_SIZE, MLP_SIZE).cuda())
-        # run cudagraph captured sizes
-        mod_A(torch.randn(2, MLP_SIZE).cuda())
-        mod_A(torch.randn(1, MLP_SIZE).cuda())
+    ):
+        run_model(vllm_config, mod_A, cudagraph_runtime_mode)

    with set_current_vllm_config(vllm_config):
        mod_B = B(vllm_config=vllm_config, prefix='').eval().cuda()
@@ -103,10 +133,8 @@ class C(B):
            num_piecewise_capturable_graphs_seen=0,
            num_backend_compilations=0,
            num_cudagraph_captured=0,
-    ), set_forward_context({}, vllm_config=vllm_config):
-        mod_B(torch.randn(BATCH_SIZE, MLP_SIZE).cuda())
-        mod_B(torch.randn(2, MLP_SIZE).cuda())
-        mod_B(torch.randn(1, MLP_SIZE).cuda())
+    ):
+        run_model(vllm_config, mod_B, cudagraph_runtime_mode)

    with set_current_vllm_config(vllm_config):
        mod_C = C(vllm_config=vllm_config, prefix='').eval().cuda()
@@ -119,10 +147,8 @@ class C(B):
            num_backend_compilations=2,
            num_cudagraph_captured=4,
            # num_cudagraph_sizes * num_piecewise_capturable_graphs_seen
-    ), set_forward_context({}, vllm_config=vllm_config):
-        mod_C(torch.randn(BATCH_SIZE, MLP_SIZE).cuda())
-        mod_C(torch.randn(2, MLP_SIZE).cuda())
-        mod_C(torch.randn(1, MLP_SIZE).cuda())
+    ):
+        run_model(vllm_config, mod_C, cudagraph_runtime_mode)


# Only enable torch.compile if
@@ -180,6 +206,7 @@ def test_conditional_compile_enable_if():
        splitting_ops=["silly.attention"],
        cudagraph_capture_sizes=[1, 2],
    ))
+    cudagraph_runtime_mode = CUDAGraphMode.PIECEWISE

    with set_current_vllm_config(vllm_config):
        mod_A = A(vllm_config=vllm_config, prefix='').eval().cuda()
@@ -195,12 +222,8 @@ def test_conditional_compile_enable_if():
            num_backend_compilations=4,
            num_cudagraph_captured=8,
            # num_cudagraph_sizes * num_piecewise_capturable_graphs_seen
-    ), set_forward_context({}, vllm_config=vllm_config):
-        # first run is for compile
-        mod_A(torch.randn(BATCH_SIZE, MLP_SIZE).cuda())
-        # run cudagraph captured sizes
-        mod_A(torch.randn(2, MLP_SIZE).cuda())
-        mod_A(torch.randn(1, MLP_SIZE).cuda())
+    ):
+        run_model(vllm_config, mod_A, cudagraph_runtime_mode)

    # Set kv_sharing_fast_prefill=False
    # which will cause A to be compiled and B to not be compiled
@@ -224,9 +247,5 @@ def test_conditional_compile_enable_if():
            num_backend_compilations=4,
            num_cudagraph_captured=8,
            # num_cudagraph_sizes * num_piecewise_capturable_graphs_seen
-    ), set_forward_context({}, vllm_config=vllm_config):
-        # first run is for compile
-        mod_A(torch.randn(BATCH_SIZE, MLP_SIZE).cuda())
-        # run cudagraph captured sizes
-        mod_A(torch.randn(2, MLP_SIZE).cuda())
-        mod_A(torch.randn(1, MLP_SIZE).cuda())
+    ):
+        run_model(vllm_config, mod_A, cudagraph_runtime_mode)
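
For reference, a minimal usage sketch of the new run_model() helper, not part of the diff. It assumes it lives in this same test file (so run_model, MLP_SIZE, and the existing imports are in scope); ToyModel and test_run_model_piecewise_sketch are illustrative names, and the CompilationConfig fields simply mirror the existing tests rather than prescribing anything.

@support_torch_compile
class ToyModel(nn.Module):
    # stands in for the A/B/C classes used by the real tests

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = '') -> None:
        super().__init__()
        self.fc = nn.Linear(MLP_SIZE, MLP_SIZE)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.fc(x)


def test_run_model_piecewise_sketch():
    vllm_config = VllmConfig(compilation_config=CompilationConfig(
        level=CompilationLevel.PIECEWISE,
        splitting_ops=["silly.attention"],
        cudagraph_capture_sizes=[1, 2],
    ))
    with set_current_vllm_config(vllm_config):
        model = ToyModel(vllm_config=vllm_config, prefix='').eval().cuda()

    # warm up, capture, and replay exactly as run_model() does above
    output = run_model(vllm_config, model, CUDAGraphMode.PIECEWISE)
    assert output.shape == (2, MLP_SIZE)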