@@ -349,6 +349,12 @@ def update_cudagraph_capture_sizes(vllm_config: VllmConfig,
349349
350350def update_aclgraph_sizes (vllm_config : VllmConfig ) -> None :
351351 """Update ACL graph capture sizes based on hardware limitations"""
352+ from vllm .config .compilation import CUDAGraphMode
353+ if vllm_config .compilation_config .cudagraph_mode == CUDAGraphMode .FULL_DECODE_ONLY :
354+ if vllm_config .speculative_config is not None and \
355+ vllm_config .speculative_config .num_speculative_tokens > 1 :
356+ _update_spec_aclgraph_sizes (vllm_config )
357+ return
352358 # NOTE: Currently, we can only capture 1800 graphs at most,
353359 # due to the limitation of ACL graph. This number is bounded by
354360 # the number of streams, which is 2048, we save 248 streams
@@ -459,28 +465,51 @@ def update_aclgraph_sizes(vllm_config: VllmConfig) -> None:
459465 vllm_config .model_config .architectures [0 ], num_hidden_layers ,
460466 len (original_sizes ))
461467
468+ if vllm_config .speculative_config is not None and \
469+ vllm_config .speculative_config .num_speculative_tokens > 1 :
470+ _update_spec_aclgraph_sizes (vllm_config )
471+
472+
def _update_spec_aclgraph_sizes(vllm_config: VllmConfig) -> None:
    """Scale the graph capture sizes to account for speculative decoding.

    The default or user-defined ``cudagraph_capture_sizes`` may not consider
    the ``num_speculative_tokens > 1`` scenario. Per the original note on this
    code, the largest capture size (taken here as ``cudagraph_capture_sizes[0]``)
    should be at least ``(num_speculative_tokens + 1) * max_num_seqs``,
    otherwise the draft model will run in eager mode.

    Three outcomes are possible:

    * ``FULL_DECODE_ONLY`` mode and at least one size is not a multiple of
      ``num_speculative_tokens + 1``: every size that is at least that query
      length, and whose scaled value still fits within
      ``max_num_seqs * (num_speculative_tokens + 1)``, is multiplied by the
      query length; the remaining sizes are dropped.
    * Otherwise, if the first size is below that token bound: every size is
      multiplied by the query length, with no filtering.
    * Otherwise the original sizes are restored unchanged.

    Args:
        vllm_config: Configuration whose ``compilation_config`` is updated in
            place. ``speculative_config`` must not be ``None``; callers are
            expected to have checked this before calling.
    """
    from vllm.config.compilation import CUDAGraphMode

    comp_cfg = vllm_config.compilation_config
    # Tokens processed per sequence in one decode step.
    query_len = vllm_config.speculative_config.num_speculative_tokens + 1
    token_budget = vllm_config.scheduler_config.max_num_seqs * query_len

    # Detach the configured list before re-initialising it below.
    # NOTE(review): presumably the re-initialisation helpers expect the field
    # to be unset first -- confirm against their implementation.
    configured_sizes = comp_cfg.cudagraph_capture_sizes
    comp_cfg.cudagraph_capture_sizes = None

    assert len(configured_sizes) > 0

    decode_only = comp_cfg.cudagraph_mode == CUDAGraphMode.FULL_DECODE_ONLY
    if decode_only and any(size % query_len != 0
                           for size in configured_sizes):
        scaled_sizes = [
            size * query_len for size in configured_sizes
            if size >= query_len and size * query_len <= token_budget
        ]
    elif configured_sizes[0] < token_budget:
        scaled_sizes = [size * query_len for size in configured_sizes]
    else:
        # Already large enough: put the original list back untouched.
        comp_cfg.cudagraph_capture_sizes = configured_sizes
        return

    # vLLM 0.11.0 exposes a different API for installing capture sizes.
    if vllm_version_is("0.11.0"):
        comp_cfg.init_with_cudagraph_sizes(scaled_sizes)
    else:
        update_cudagraph_capture_sizes(vllm_config, scaled_sizes)
    logger.info("Adjusted ACL graphs: %s → %s for speculative decoding",
                configured_sizes, scaled_sizes)
484513
485514
486515# TODO(wxy): Move to ops module
0 commit comments