Conversation

Contributor

@peakcrosser7 peakcrosser7 commented Feb 12, 2026

Purpose

In Mamba cache align mode, prefill requests must be scheduled with a block-aligned number of tokens per scheduling step. If max_num_batched_tokens is smaller than block_size while the request length exceeds block_size, the _mamba_block_aligned_split() function returns a num_new_tokens of 0 because of this alignment constraint. The request can therefore never be scheduled, and the engine eventually hangs.
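
The failure mode can be illustrated with a small standalone sketch. The function below is hypothetical and only mimics the rounding behaviour; it is not vLLM's actual _mamba_block_aligned_split() implementation:

def block_aligned_split(num_new_tokens: int, token_budget: int, block_size: int) -> int:
    """Round the per-step schedulable token count down to a multiple of block_size."""
    schedulable = min(num_new_tokens, token_budget)
    return (schedulable // block_size) * block_size

# With block_size=544 and a per-step budget of max_num_batched_tokens=512,
# the aligned count is always 0, so the prefill request never makes progress.
print(block_aligned_split(num_new_tokens=2000, token_budget=512, block_size=544))  # -> 0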

This PR adds a validation check to ensure that block_size is not larger than max_num_batched_tokens when Mamba cache align mode is enabled.
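
A minimal sketch of the kind of startup-time check this adds (the function name and signature are illustrative, not the exact vLLM config code; the assertion message mirrors the one in the test result below):

def validate_mamba_align_block_size(
    mamba_cache_mode: str, block_size: int, max_num_batched_tokens: int
) -> None:
    if mamba_cache_mode == "align":
        assert block_size <= max_num_batched_tokens, (
            f"In Mamba cache align mode, block_size ({block_size}) must be "
            f"<= max_num_batched_tokens ({max_num_batched_tokens})."
        )

# With the configuration from the test plan below, this raises immediately:
# validate_mamba_align_block_size("align", block_size=544, max_num_batched_tokens=512)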

Test Plan

import time

from vllm import LLM, SamplingParams
from vllm.distributed import cleanup_dist_env_and_memory


def main():
    MODEL = "Qwen/Qwen3-Next-80B-A3B-Instruct"  # gdn
    PROMPT_MULTIPLE = 310
    sampling_params = SamplingParams(temperature=0.0, max_tokens=128)
    prefix = ( # examples/offline_inference/prefix_caching.py
        "You are an expert school principal, skilled in effectively managing "
        "faculty and staff. Draft 10-15 questions for a potential first grade "
        "Head Teacher for my K-12, all-girls', independent school that emphasizes "
        "community, joyful discovery, and life-long learning. The candidate is "
        "coming in for a first-round panel interview for a 8th grade Math "
        "teaching role. They have 5 years of previous teaching experience "
        "as an assistant teacher at a co-ed, public school with experience "
        "in middle school math teaching. ")
    prefix2 = ("Based on these information, fulfill "
               "the following paragraph: ")
    prompt = PROMPT_MULTIPLE * prefix + prefix2 + "Hello, my name is"
    print('Prompt length:', len(prompt))
    for APC in [
        True
    ]:
        engine = LLM(
            model=MODEL,
            enable_prefix_caching=APC,
            max_num_batched_tokens=512,  # smaller than block_size=544
            tensor_parallel_size=2,
            gpu_memory_utilization=0.9,
            disable_log_stats=False,
            mamba_cache_mode="align",  # Mamba cache align mode: the configuration under test
        )
        for i in range(3):
            if i == 0:
                print('Warm-up')
            if i == 1:
                print('Measuring')
                start_time = time.time()
            outputs = engine.generate(prompt, sampling_params)
            print('APC:', APC, i, f"Generated text: {outputs[0].outputs[0].text!r}")
            for m in engine.llm_engine.get_metrics():
                if 'vllm:prefix_cache_hits' in m.name:
                    print(m.name, m.value)
        print('APC:', APC, "loop took --- %s seconds ---" % (time.time() - start_time))
        del engine
        cleanup_dist_env_and_memory()


if __name__ == "__main__":
    main()

Test Result

Before fix: the engine hangs and the prompt is never processed.

Warm-up

Adding requests:   0%|          | 0/1 [00:00<?, ?it/s]
Adding requests: 100%|██████████| 1/1 [00:00<00:00, 400.95it/s]

Processed prompts:   0%|          | 0/1 [00:00<?, ?it/s, est. speed input: 0.00 toks/s, output: 0.00 toks/s]

After fix: a clear validation error is raised when the engine is constructed.

Traceback (most recent call last):
  File "/root/huanghy/vllm_opsrc/my_tests/test_lpc_offline.py", line 56, in <module>
    main()
  File "/root/huanghy/vllm_opsrc/my_tests/test_lpc_offline.py", line 31, in main
    engine = LLM(
             ^^^^
  File "/root/huanghy/vllm_opsrc/vllm/entrypoints/llm.py", line 346, in __init__
    self.llm_engine = LLMEngine.from_engine_args(
                      ^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/root/huanghy/vllm_opsrc/vllm/v1/engine/llm_engine.py", line 166, in from_engine_args
    vllm_config = engine_args.create_engine_config(usage_context)
                  ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/root/huanghy/vllm_opsrc/vllm/engine/arg_utils.py", line 1809, in create_engine_config
    config = VllmConfig(
             ^^^^^^^^^^^
  File "/opt/conda/lib/python3.11/site-packages/pydantic/_internal/_dataclasses.py", line 121, in __init__
    s.__pydantic_validator__.validate_python(ArgsKwargs(args, kwargs), self_instance=s)
pydantic_core._pydantic_core.ValidationError: 1 validation error for VllmConfig
  Assertion failed, In Mamba cache align mode, block_size (544) must be <= max_num_batched_tokens (512). [type=assertion_error, input_value=ArgsKwargs((), {'model_co...transfer_config': None}), input_type=ArgsKwargs]
    For further information visit https://errors.pydantic.dev/2.12/v/assertion_error

Essential Elements of an Effective PR Description Checklist
  • The purpose of the PR, such as "Fix some issue (link existing issues this PR will resolve)".
  • The test plan, such as providing test command.
  • The test results, such as pasting the results comparison before and after, or e2e results
  • (Optional) The necessary documentation update, such as updating supported_models.md and examples for a new model.
  • (Optional) Release notes update. If your change is user facing, please update the release notes draft in the Google Doc.

Contributor

@gemini-code-assist gemini-code-assist bot left a comment

Code Review

This pull request introduces a crucial validation check to prevent the engine from hanging when using Mamba's cache align mode with an invalid configuration where block_size exceeds max_num_batched_tokens. The fix is correct and well-placed. I've suggested a minor improvement to the assertion's error message to make it more informative for users encountering this configuration error.

Signed-off-by: huanghaoyan.hhy <[email protected]>
Collaborator

@heheda12345 heheda12345 left a comment

LGTM!

@heheda12345 heheda12345 enabled auto-merge (squash) February 12, 2026 19:31
@github-actions github-actions bot added the ready label (ONLY add when PR is ready to merge/full CI is needed) on Feb 12, 2026

Labels

bug (Something isn't working) · ready (ONLY add when PR is ready to merge/full CI is needed)
