@@ -4006,7 +4006,7 @@ def _capture_model(self):
40064006 if aclgraph_mode .mixed_mode () != CUDAGraphMode .NONE :
40074007 aclgraph_runtime_mode = aclgraph_mode .mixed_mode ()
40084008
4009- compilation_cases = list ( reversed ( self .aclgraph_batch_sizes ) )
4009+ compilation_cases = sorted ( self .aclgraph_batch_sizes )
40104010
40114011 try :
40124012 self ._capture_aclgraphs (
@@ -4041,8 +4041,17 @@ def _capture_model(self):
40414041 x for x in self .aclgraph_batch_sizes if x <= max_num_tokens
40424042 and x >= self .uniform_decode_query_len
40434043 ]
4044- compilation_cases_decode = list (
4045- reversed (decode_cudagraph_batch_sizes ))
4044+ compilation_cases_decode = sorted (decode_cudagraph_batch_sizes )
4045+ # TODO: refactor this when vLLM supports mtp>1
4046+ if not all (x % self .uniform_decode_query_len == 0
4047+ for x in decode_cudagraph_batch_sizes ):
4048+ raise ValueError (
4049+ "In the MTP fullgraph scenario, each graph size must be an integer multiple of "
4050+ f"(num_speculative_tokens + 1): { self .uniform_decode_query_len } . "
4051+ f"Please modify the cudagraph_capture_sizes variable to be integer multiple of { self .uniform_decode_query_len } , "
4052+ f"while ensuring the maximum cudagraph_capture_sizes does not exceed max_num_seqs * (num_speculative_tokens + 1): { max_num_tokens } . "
4053+ "For example, with MTP=2 and max_num_seqs=16, we recommend setting cudagraph_capture_sizes to [48]."
4054+ )
40464055 self ._capture_aclgraphs (
40474056 compilation_cases = compilation_cases_decode ,
40484057 aclgraph_runtime_mode = CUDAGraphMode .FULL ,
0 commit comments