diff --git a/tests/models/language/generation/test_hybrid.py b/tests/models/language/generation/test_hybrid.py
index 2238924c1b50..9d82bec19b6b 100644
--- a/tests/models/language/generation/test_hybrid.py
+++ b/tests/models/language/generation/test_hybrid.py
@@ -387,3 +387,63 @@ def test_distributed_correctness(
         name_0="vllm_tp_1",
         name_1="vllm_tp_2",
     )
+
+
+@pytest.mark.parametrize("model", ["Zyphra/Zamba2-1.2B-instruct"])
+@pytest.mark.parametrize("max_tokens", [64])
+@pytest.mark.parametrize("num_logprobs", [5])
+def test_full_cuda_graph(
+    hf_runner,
+    vllm_runner,
+    example_prompts,
+    monkeypatch,
+    model: str,
+    max_tokens: int,
+    num_logprobs: int,
+) -> None:
+
+    try:
+        model_info = HF_EXAMPLE_MODELS.find_hf_info(model)
+        model_info.check_available_online(on_fail="skip")
+        model_info.check_transformers_version(on_fail="skip")
+    except ValueError:
+        pass
+
+    with hf_runner(model) as hf_model:
+        if model not in HF_UNSUPPORTED_MODELS:
+            hf_outputs = hf_model.generate_greedy_logprobs_limit(
+                example_prompts, max_tokens, num_logprobs)
+        else:
+            hf_outputs = None
+
+    with vllm_runner(model, max_num_seqs=MAX_NUM_SEQS) as vllm_model:
+        vllm_v0_outputs = vllm_model.generate_greedy_logprobs(
+            example_prompts, max_tokens, num_logprobs)
+
+    with monkeypatch.context() as m:
+        m.setenv("VLLM_USE_V1", "1")
+        if model in HYBRID_MODELS:
+            # required due to reorder_batch behaviour
+            m.setenv("VLLM_ATTENTION_BACKEND", "FLASHINFER")
+        with vllm_runner(model,
+                         max_num_seqs=MAX_NUM_SEQS,
+                         compilation_config={'full_cuda_graph': True},
+                         enable_prefix_caching=False) as vllm_model:
+            vllm_v1_outputs = vllm_model.generate_greedy_logprobs(
+                example_prompts, max_tokens, num_logprobs)
+
+    if hf_outputs is not None:
+        check_logprobs_close(
+            outputs_0_lst=hf_outputs,
+            outputs_1_lst=vllm_v0_outputs,
+            name_0="hf",
+            name_1="vllm-v0",
+        )
+
+    ref_outputs = hf_outputs if hf_outputs is not None else vllm_v0_outputs
+    check_logprobs_close(
+        outputs_0_lst=ref_outputs,
+        outputs_1_lst=vllm_v1_outputs,
+        name_0="hf" if hf_outputs is not None else "vllm-v0",
+        name_1="vllm-v1",
+    )
diff --git a/vllm/v1/attention/backends/mamba_attn.py b/vllm/v1/attention/backends/mamba_attn.py
index 66a8d91db89c..f86acbf6d8b4 100644
--- a/vllm/v1/attention/backends/mamba_attn.py
+++ b/vllm/v1/attention/backends/mamba_attn.py
@@ -8,7 +8,8 @@
 from vllm.attention.backends.abstract import AttentionBackend
 from vllm.config import VllmConfig
-from vllm.v1.attention.backends.utils import (AttentionMetadataBuilder,
+from vllm.v1.attention.backends.utils import (AttentionCGSupport,
+                                              AttentionMetadataBuilder,
                                               CommonAttentionMetadata,
                                               split_decodes_and_prefills)
 from vllm.v1.kv_cache_interface import AttentionSpec, MambaSpec
 
@@ -82,6 +83,8 @@ class Mamba2AttentionMetadata:
 
 class Mamba2AttentionMetadataBuilder(
         AttentionMetadataBuilder[Mamba2AttentionMetadata]):
+    attn_cudagraph_support: ClassVar[AttentionCGSupport] = \
+        AttentionCGSupport.PURE_DECODE_ONLY
 
     reorder_batch_threshold: ClassVar[int] = 1
 
@@ -160,3 +163,23 @@ def build(self,
             state_indices_tensor=state_indices_tensor,
         )
         return attn_metadata
+
+    def build_for_cudagraph_capture(
+            self, common_attn_metadata: CommonAttentionMetadata):
+        """
+        This method builds the metadata for full cudagraph capture.
+        Currently, only decode is supported for full cudagraphs with Mamba.
+        """
+        m = common_attn_metadata
+
+        assert m.num_reqs == m.num_actual_tokens, \
+            "Mamba only supports decode-only full CUDAGraph capture. " \
+            "Make sure all cudagraph capture sizes <= max_num_seq."
+
+        m.max_query_len = 1  # decode-only
+
+        return self.build(0, m)
+
+    def can_run_in_cudagraph(
+            self, common_attn_metadata: CommonAttentionMetadata) -> bool:
+        return common_attn_metadata.max_query_len == 1
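
For context, the sketch below shows how the decode-only full CUDA graph path added above might be exercised from the offline API, mirroring the settings used in test_full_cuda_graph. It is an illustrative sketch only: the prompt is made up, the model name, FLASHINFER backend override, prefix-caching setting, and greedy sampling values are taken from the test, and compilation_config / enable_prefix_caching are assumed to be accepted as engine arguments by vllm.LLM.

import os

from vllm import LLM, SamplingParams

# Hybrid (attention + Mamba) models need an attention backend whose batch
# reordering matches Mamba's; the test above sets FLASHINFER for this reason.
os.environ["VLLM_ATTENTION_BACKEND"] = "FLASHINFER"

llm = LLM(
    model="Zyphra/Zamba2-1.2B-instruct",           # model used by the test
    compilation_config={"full_cuda_graph": True},  # full CUDA graph capture
    enable_prefix_caching=False,                   # disabled, as in the test
)

# Greedy decoding, matching the logprob-comparison setup in the test.
params = SamplingParams(temperature=0.0, max_tokens=64)
outputs = llm.generate(["The capital of France is"], params)
print(outputs[0].outputs[0].text)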