From 01f10d620a4ce9361de2aa94bfcd989101d6d8f2 Mon Sep 17 00:00:00 2001
From: Thomas Parnell
Date: Sat, 2 Aug 2025 04:13:17 -0400
Subject: [PATCH 1/4] Enable decode-only FCG for mamba

Signed-off-by: Thomas Parnell
---
 vllm/v1/attention/backends/mamba_attn.py | 26 ++++++++++++++++++++++--
 1 file changed, 24 insertions(+), 2 deletions(-)

diff --git a/vllm/v1/attention/backends/mamba_attn.py b/vllm/v1/attention/backends/mamba_attn.py
index 8b702e28d67c..f99694a4cec5 100644
--- a/vllm/v1/attention/backends/mamba_attn.py
+++ b/vllm/v1/attention/backends/mamba_attn.py
@@ -2,14 +2,14 @@
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
 import math
 from dataclasses import dataclass
-from typing import TYPE_CHECKING, Optional
+from typing import TYPE_CHECKING, ClassVar, Optional
 
 import torch
 
 from vllm.attention.backends.abstract import AttentionBackend
 from vllm.config import VllmConfig
 from vllm.v1.attention.backends.utils import (
-    AttentionMetadataBuilder, CommonAttentionMetadata,
+    AttentionCGSupport, AttentionMetadataBuilder, CommonAttentionMetadata,
     reorder_batch_to_split_decodes_and_prefills, split_decodes_and_prefills)
 from vllm.v1.kv_cache_interface import AttentionSpec, MambaSpec
@@ -86,6 +86,8 @@ class Mamba2AttentionMetadata:
 
 class Mamba2AttentionMetadataBuilder(
         AttentionMetadataBuilder[Mamba2AttentionMetadata]):
+    attn_cudagraph_support: ClassVar[AttentionCGSupport] = \
+        AttentionCGSupport.PURE_DECODE_ONLY
 
     def __init__(self, kv_cache_spec: AttentionSpec, layer_names: list[str],
                  vllm_config: VllmConfig, device: torch.device):
@@ -168,3 +170,23 @@ def build(self,
             state_indices_tensor=state_indices_tensor,
         )
         return attn_metadata
+
+    def build_for_cudagraph_capture(
+            self, common_attn_metadata: CommonAttentionMetadata):
+        """
+        This method builds the metadata for full cudagraph capture.
+        Currently, only decode is supported for full cudagraphs with Mamba.
+        """
+        m = common_attn_metadata
+
+        assert m.num_reqs == m.num_actual_tokens, \
+            "Mamba only supports decode-only full CUDAGraph capture. " \
+            "Make sure all cudagraph capture sizes <= max_num_seq."
+
+        m.max_query_len = 1  # decode-only
+
+        return self.build(0, m)
+
+    def can_run_in_cudagraph(
+            self, common_attn_metadata: CommonAttentionMetadata) -> bool:
+        return common_attn_metadata.max_query_len == 1

From 17ab125b7ad37dc296f5fae06dd1281462e50cdb Mon Sep 17 00:00:00 2001
From: Thomas Parnell
Date: Sat, 2 Aug 2025 05:09:32 -0400
Subject: [PATCH 2/4] Add FCG test

Signed-off-by: Thomas Parnell
---
 .../models/language/generation/test_hybrid.py | 62 +++++++++++++++++++
 1 file changed, 62 insertions(+)

diff --git a/tests/models/language/generation/test_hybrid.py b/tests/models/language/generation/test_hybrid.py
index 2238924c1b50..8144c75aa536 100644
--- a/tests/models/language/generation/test_hybrid.py
+++ b/tests/models/language/generation/test_hybrid.py
@@ -387,3 +387,65 @@ def test_distributed_correctness(
         name_0="vllm_tp_1",
         name_1="vllm_tp_2",
     )
+
+
+@pytest.mark.parametrize(
+    "model",
+    ["nvidia/Nemotron-H-8B-Base-8K", "mistralai/Mamba-Codestral-7B-v0.1"])
+@pytest.mark.parametrize("max_tokens", [64])
+@pytest.mark.parametrize("num_logprobs", [5])
+def test_full_cuda_graph(
+    hf_runner,
+    vllm_runner,
+    example_prompts,
+    monkeypatch,
+    model: str,
+    max_tokens: int,
+    num_logprobs: int,
+) -> None:
+
+    try:
+        model_info = HF_EXAMPLE_MODELS.find_hf_info(model)
+        model_info.check_available_online(on_fail="skip")
+        model_info.check_transformers_version(on_fail="skip")
+    except ValueError:
+        pass
+
+    with hf_runner(model) as hf_model:
+        if model not in HF_UNSUPPORTED_MODELS:
+            hf_outputs = hf_model.generate_greedy_logprobs_limit(
+                example_prompts, max_tokens, num_logprobs)
+        else:
+            hf_outputs = None
+
+    with vllm_runner(model, max_num_seqs=MAX_NUM_SEQS) as vllm_model:
+        vllm_v0_outputs = vllm_model.generate_greedy_logprobs(
+            example_prompts, max_tokens, num_logprobs)
+
+    with monkeypatch.context() as m:
+        m.setenv("VLLM_USE_V1", "1")
+        if model in HYBRID_MODELS:
+            # required due to reorder_batch behaviour
+            m.setenv("VLLM_ATTENTION_BACKEND", "FLASHINFER")
+        with vllm_runner(model,
+                         max_num_seqs=MAX_NUM_SEQS,
+                         compilation_config={'full_cuda_graph': True},
+                         enable_prefix_caching=False) as vllm_model:
+            vllm_v1_outputs = vllm_model.generate_greedy_logprobs(
+                example_prompts, max_tokens, num_logprobs)
+
+    if hf_outputs is not None:
+        check_logprobs_close(
+            outputs_0_lst=hf_outputs,
+            outputs_1_lst=vllm_v0_outputs,
+            name_0="hf",
+            name_1="vllm-v0",
+        )
+
+    ref_outputs = hf_outputs if hf_outputs is not None else vllm_v0_outputs
+    check_logprobs_close(
+        outputs_0_lst=ref_outputs,
+        outputs_1_lst=vllm_v1_outputs,
+        name_0="hf" if hf_outputs is not None else "vllm-v0",
+        name_1="vllm-v1",
+    )

From c3f2224223dc3fb0836576c837ffe5e9da97c6d1 Mon Sep 17 00:00:00 2001
From: Thomas Parnell
Date: Mon, 4 Aug 2025 17:49:19 -0400
Subject: [PATCH 3/4] Use smaller model in test that is supported by transformers

Signed-off-by: Thomas Parnell
---
 tests/models/language/generation/test_hybrid.py | 4 +---
 1 file changed, 1 insertion(+), 3 deletions(-)

diff --git a/tests/models/language/generation/test_hybrid.py b/tests/models/language/generation/test_hybrid.py
index 8144c75aa536..9d82bec19b6b 100644
--- a/tests/models/language/generation/test_hybrid.py
+++ b/tests/models/language/generation/test_hybrid.py
@@ -389,9 +389,7 @@ def test_distributed_correctness(
     )
 
 
-@pytest.mark.parametrize(
-    "model",
-    ["nvidia/Nemotron-H-8B-Base-8K", "mistralai/Mamba-Codestral-7B-v0.1"])
+@pytest.mark.parametrize("model", ["Zyphra/Zamba2-1.2B-instruct"])
 @pytest.mark.parametrize("max_tokens", [64])
 @pytest.mark.parametrize("num_logprobs", [5])
 def test_full_cuda_graph(

From 2868ba7a17ba84d9f12a52b98e654470398e6943 Mon Sep 17 00:00:00 2001
From: Thomas Parnell
Date: Sat, 9 Aug 2025 12:13:33 -0400
Subject: [PATCH 4/4] Fix padding issue

Signed-off-by: Thomas Parnell
---
 vllm/v1/attention/backends/mamba_attn.py | 19 +++++++++++++++++++
 1 file changed, 19 insertions(+)

diff --git a/vllm/v1/attention/backends/mamba_attn.py b/vllm/v1/attention/backends/mamba_attn.py
index f86acbf6d8b4..7c1226049f69 100644
--- a/vllm/v1/attention/backends/mamba_attn.py
+++ b/vllm/v1/attention/backends/mamba_attn.py
@@ -7,6 +7,7 @@
 import torch
 
 from vllm.attention.backends.abstract import AttentionBackend
+from vllm.attention.backends.utils import PAD_SLOT_ID
 from vllm.config import VllmConfig
 from vllm.v1.attention.backends.utils import (AttentionCGSupport,
                                               AttentionMetadataBuilder,
@@ -93,8 +94,18 @@ def __init__(self, kv_cache_spec: AttentionSpec, layer_names: list[str],
         assert isinstance(kv_cache_spec, MambaSpec)
         self.kv_cache_spec = kv_cache_spec
         self.chunk_size = vllm_config.model_config.get_mamba_chunk_size()
+        self.vllm_config = vllm_config
+        self.compilation_config = vllm_config.compilation_config
         assert self.chunk_size is not None, (
             "chunk_size needs to be set in the model config for Mamba2 models")
+        self.decode_cudagraph_max_bs = min(
+            self.vllm_config.scheduler_config.max_num_seqs,
+            self.compilation_config.max_capture_size)
+        self.state_indices_tensor = torch.empty(
+            (self.decode_cudagraph_max_bs, ),
+            dtype=torch.int32,
+            device=device,
+        )
 
     def build(self,
               common_prefix_len: int,
@@ -147,6 +158,14 @@ def build(self,
                     query_start_loc_p, self.chunk_size,
                     num_prefill_tokens))
 
+        elif num_decodes <= self.decode_cudagraph_max_bs:
+            # Pad state tensor for CUDA graph
+            num_input_tokens = self.vllm_config.pad_for_cudagraph(num_decodes)
+            self.state_indices_tensor[:num_decodes].copy_(state_indices_tensor,
+                                                          non_blocking=True)
+            state_indices_tensor = self.state_indices_tensor[:num_input_tokens]
+            state_indices_tensor[num_decodes:] = PAD_SLOT_ID
+
         attn_metadata = Mamba2AttentionMetadata(
             num_prefills=num_prefills,
             num_prefill_tokens=num_prefill_tokens,
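
For context: the decode-only restriction in PATCH 1 comes down to replaying the captured graph only when every request in the batch contributes exactly one token, i.e. max_query_len == 1. A minimal standalone sketch of that dispatch decision follows; BatchMeta and dispatch are hypothetical stand-ins, not vLLM's actual classes.

    # Sketch of a decode-only full-CUDA-graph gate (hypothetical names,
    # not the vLLM implementation).
    from dataclasses import dataclass


    @dataclass
    class BatchMeta:
        num_reqs: int           # requests in this scheduling step
        num_actual_tokens: int  # total tokens scheduled this step
        max_query_len: int      # longest per-request query length


    def can_run_in_cudagraph(meta: BatchMeta) -> bool:
        # Pure decode: each request contributes exactly one token.
        return meta.max_query_len == 1


    def dispatch(meta: BatchMeta) -> str:
        if can_run_in_cudagraph(meta):
            return "replay full CUDA graph"
        return "run eagerly / piecewise"


    print(dispatch(BatchMeta(num_reqs=8, num_actual_tokens=8, max_query_len=1)))
    # replay full CUDA graph
    print(dispatch(BatchMeta(num_reqs=2, num_actual_tokens=37, max_query_len=36)))
    # run eagerly / piecewise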
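
PATCH 4 keeps a persistent state_indices_tensor sized for the largest decode batch that can be captured, copies the live indices into its head, and fills the tail with PAD_SLOT_ID so a graph captured at a larger batch size can be replayed safely on a smaller batch. Below is a simplified, self-contained sketch of that padding scheme; pad_state_indices and capture_sizes are illustrative stand-ins for vllm_config.pad_for_cudagraph and the configured capture sizes, not the real API.

    # Simplified sketch of decode-batch padding for CUDA graph replay
    # (illustrative only; not the vLLM implementation).
    import torch

    PAD_SLOT_ID = -1  # sentinel for padded, unused cache slots


    def pad_state_indices(state_indices: torch.Tensor,
                          persistent_buf: torch.Tensor,
                          capture_sizes: list[int]) -> torch.Tensor:
        """Copy per-request state indices into a persistent buffer and pad
        the batch up to the nearest CUDA graph capture size."""
        num_decodes = state_indices.numel()
        # Smallest capture size that fits the current decode batch.
        num_input_tokens = next(s for s in sorted(capture_sizes)
                                if s >= num_decodes)
        persistent_buf[:num_decodes].copy_(state_indices)
        padded = persistent_buf[:num_input_tokens]
        padded[num_decodes:] = PAD_SLOT_ID  # tail entries mark padded requests
        return padded


    # Example: a 3-request decode batch padded up to the capture size 4.
    buf = torch.empty(8, dtype=torch.int32)
    out = pad_state_indices(torch.tensor([5, 9, 2], dtype=torch.int32),
                            buf, capture_sizes=[1, 2, 4, 8])
    print(out)  # tensor([ 5,  9,  2, -1], dtype=torch.int32)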