
Conversation


@NorthmanPKU NorthmanPKU commented Oct 20, 2025

Purpose

See #22201 and mirage-project/mirage#522.

TODO:


Essential Elements of an Effective PR Description Checklist
  • The purpose of the PR, such as "Fix some issue (link existing issues this PR will resolve)".
  • The test plan, such as providing a test command.
  • The test results, such as pasting the results comparison before and after, or e2e results.
  • (Optional) The necessary documentation update, such as updating supported_models.md and examples for a new model.
  • (Optional) Release notes update. If your change is user facing, please update the release notes draft in the Google Doc.


👋 Hi! Thank you for contributing to the vLLM project.

💬 Join our developer Slack at https://slack.vllm.ai to discuss your PR in #pr-reviews, coordinate on features in #feat- channels, or join special interest groups in #sig- channels.

Just a reminder: PRs do not trigger a full CI run by default. Instead, only the fastcheck CI runs, which covers a small, essential subset of CI tests to catch errors quickly.

You can ask your reviewers to trigger select CI tests on top of fastcheck CI.

Once the PR is approved and ready to go, your PR reviewer(s) can run CI to test the changes comprehensively before merging.

To run CI, PR reviewers can either add the ready label to the PR or enable auto-merge.

If you have any questions, please reach out to us on Slack at https://slack.vllm.ai.

🚀

@mergify mergify bot added the v1 label Oct 20, 2025
@NorthmanPKU NorthmanPKU marked this pull request as draft October 20, 2025 17:45

@chatgpt-codex-connector chatgpt-codex-connector bot left a comment


💡 Codex Review

Here are some automated review suggestions for this pull request.

ℹ️ About Codex in GitHub

Codex has been enabled to automatically review pull requests in this repo. Reviews are triggered when you

  • Open a pull request for review
  • Mark a draft as ready
  • Comment "@codex review".

If Codex has suggestions, it will comment; otherwise it will react with 👍.

When you sign up for Codex through ChatGPT, Codex can also answer questions or update the PR, like "@codex address that feedback".

Comment on lines 80 to 83
    vllm_config: VllmConfig,
    forward_context: ForwardContext,
    args: list[Any],
    transfered_tensor_names: list[str],


P1: Import typing.Any to avoid module load failure

The new backend annotates several parameters with list[Any]/-> Any but never imports Any from typing. Because annotations are evaluated eagerly, importing this module will immediately raise NameError: name 'Any' is not defined before the backend can be used, even when all other dependencies are present. Add the missing import so the module can load.
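
A minimal fix sketch (assuming the annotations sit at the top of the new Mirage attention backend module):

# Assumed location: top of the new Mirage attention backend module.
from typing import Any  # makes the list[Any] / -> Any annotations above resolvable at import time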

Useful? React with 👍 / 👎.

Comment on lines 254 to 258
        assert self.q_data_type == torch.bfloat16, "MirageAttentionBackend currently only supports bfloat16"

        attn_metadata = MirageAttentionMetadata(
            num_actual_tokens=num_actual_tokens,
            q_data_type=self.q_data_type,


P1: Initialize q_data_type before using it in the metadata builder

MirageAttentionMetadataBuilder.build asserts self.q_data_type == torch.bfloat16 and passes self.q_data_type into the metadata, but the constructor never defines self.q_data_type (unlike other builders that set it from the model config or cache dtype). The first call to build will therefore raise AttributeError: 'MirageAttentionMetadataBuilder' object has no attribute 'q_data_type'. Set this attribute during initialization based on the model/cache dtype before using it.
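
A hedged sketch of one way to initialize it, assuming the usual v1 AttentionMetadataBuilder constructor signature (names other than q_data_type are illustrative):

def __init__(self, kv_cache_spec, layer_names, vllm_config, device):
    self.vllm_config = vllm_config
    self.model_config = vllm_config.model_config
    self.device = device
    # Derive the query dtype up front so build() can assert on bfloat16 and
    # forward it into MirageAttentionMetadata without an AttributeError.
    self.q_data_type = self.model_config.dtype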

Useful? React with 👍 / 👎.

Contributor

@gemini-code-assist gemini-code-assist bot left a comment


Code Review

This pull request introduces experimental support for the Mirage (MPK) compiler as a new execution backend in vLLM. This involves adding a new compilation backend (MirageBackend) that integrates with torch.compile, as well as a corresponding attention backend (MirageAttentionBackend).

My review focuses on the correctness and maintainability of the new integration. I've identified a critical issue in the attention backend implementation that could lead to incorrect model outputs, as well as some high-severity maintainability concerns in the compilation backend code that should be addressed to improve clarity and prevent future bugs.

Key feedback points:

  • The MirageAttentionImpl returns uninitialized data, which is a critical bug.
  • The new mirage_backend.py file has a misleading docstring and uses a wildcard import, which harms maintainability.

Overall, this is a significant feature addition, and addressing these points will help ensure its stability and correctness as it moves from a work-in-progress to a stable feature.

Comment on lines +289 to +320
    def forward(
        self,
        layer: torch.nn.Module,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        kv_cache: torch.Tensor,
        attn_metadata: MirageAttentionMetadata,
        output: torch.Tensor | None = None,
        output_scale: torch.Tensor | None = None,
        output_block_scale: torch.Tensor | None = None,
    ) -> torch.Tensor:
        """Forward pass that do nothing.

        Args:
            query: shape = [num_tokens, num_heads, head_size]
            key: shape = [num_tokens, num_kv_heads, head_size]
            value: shape = [num_tokens, num_kv_heads, head_size]
            kv_cache: KV cache tensor with different possible shapes:
                - NHD: [num_blocks, 2, block_size, num_kv_heads, head_size]
                - HND: [num_blocks, 2, num_kv_heads, block_size, head_size]
            attn_metadata: Metadata for attention.
        Returns:
            shape = [num_tokens, num_heads * head_size]
        """
        assert output is not None, "Output tensor must be provided."

        if attn_metadata is None:
            # Profiling run.
            return output.fill_(0)

        return output
Contributor

critical

The forward method of MirageAttentionImpl currently returns the output tensor without modification. Since the output tensor is typically initialized with torch.empty_like, this means the method returns uninitialized memory, which will lead to incorrect model outputs and behavior.

This seems to be a placeholder implementation, as indicated by the docstring "Forward pass that do nothing" and the empty __init__ method. If the intention is for the Mirage compiler to handle the entire model execution (including attention), this attention implementation should likely not be used or called at all. If it is meant to be used, it must be implemented to correctly compute attention. Returning uninitialized data is a critical bug.
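
If attention is indeed meant to run entirely inside the Mirage-compiled graph, one hedged option (a sketch, not the author's stated plan) is to fail loudly rather than return uninitialized memory:

assert output is not None, "Output tensor must be provided."

if attn_metadata is None:
    # Profiling run.
    return output.fill_(0)

# Placeholder guard: until the Mirage-compiled path owns attention,
# surface the gap instead of returning whatever torch.empty_like left behind.
raise NotImplementedError(
    "MirageAttentionImpl.forward is a placeholder; attention is expected "
    "to execute inside the Mirage-compiled graph."
)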

Comment on lines 1 to 9
from collections import defaultdict
from .backends import *
from mirage import MPK, MPKMetadata, MirageModelConfig
import re
from vllm.config import ModelConfig, get_current_vllm_config
from vllm.forward_context import ForwardContext, get_forward_context
from vllm.model_executor.models.utils import extract_layer_index
import torch
from vllm.logger import init_logger
Contributor

high

Using a wildcard import (from .backends import *) makes it difficult to determine the origin of names, which can harm code readability and maintainability. It's better to import names explicitly.

This file implicitly imports VllmConfig, CompilationConfig, PostGradPassManager, compilation_counter, model_tag, time, Any, and torch.fx through the wildcard. This can be confusing and may lead to issues if vllm/compilation/backends.py changes.

Please replace the wildcard import with explicit imports for better clarity.

from collections import defaultdict
import re
import time
from typing import Any

import torch
import torch.fx as fx
from mirage import MPK, MPKMetadata, MirageModelConfig

from vllm.compilation.backends import model_tag
from vllm.compilation.counter import compilation_counter
from vllm.compilation.pass_manager import PostGradPassManager
from vllm.config import CompilationConfig, ModelConfig, VllmConfig
from vllm.forward_context import ForwardContext, get_forward_context
from vllm.logger import init_logger
from vllm.model_executor.models.utils import extract_layer_index

Comment on lines +173 to +175
The major work of this backend is to split the graph into
piecewise graphs, and pass them to the piecewise backend.
Contributor

high

The docstring for MirageBackend states that its major work is to split the graph into piecewise graphs. However, the implementation does not perform any graph splitting; it seems to compile the graph it receives as a whole. This docstring appears to be copied from VllmBackend and is misleading. Please update it to accurately describe what MirageBackend does.
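
A hedged sketch of a docstring that matches the behavior described here (wording illustrative, not the author's):

class MirageBackend:
    """torch.compile backend that hands the captured FX graph to the
    Mirage (MPK) compiler as a whole, without piecewise splitting."""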

from vllm.compilation.backends import VllmBackend

return VllmBackend(vllm_config)
# Default piecewise compilation path (vLLM backend driving Inductor/Eager)
Collaborator

Instead of this approach, Mirage should use CompilationMode.DYNAMO_TRACE_ONCE and then be directly passed to torch.compile(backend=...), like it is now. VLLM_COMPILE implies VllmBackend. Later, if we want to add support for caching and any other VllmBackend optimizations, we can add this code path and initialize VllmBackend which will use a nested MirageBackend for the actual compilation like it currently uses InductorAdaptor for backend="inductor", but that will come much later.
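
A rough sketch of the suggested wiring, assuming MirageBackend follows the standard torch.compile custom-backend protocol (a callable taking the FX GraphModule and example inputs); model and vllm_config are illustrative names:

import torch

# With CompilationMode.DYNAMO_TRACE_ONCE, the backend is handed directly to
# torch.compile instead of being nested inside VllmBackend.
mirage_backend = MirageBackend(vllm_config)  # callable(gm, example_inputs) -> compiled fn
compiled_model = torch.compile(model, backend=mirage_backend)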

self.compiled = True

logger.info(f"[Mirage] Calling the compiled result...")
return self.mpk()
Collaborator

Shouldn't we try to pass at least the input tensors to mpk here? Or do those have to be static too? If yes, I'm not sure that's guaranteed by the GPUModelRunner

]:
if self.backend in torch_backends:
return self.backend
return resolve_obj_by_qualname(self.backend)
Collaborator

Suggested change
- return resolve_obj_by_qualname(self.backend)
+ if self.backend == "mirage":
+     from vllm.compilation.mirage_backend import MirageBackend
+     return MirageBackend(vllm_config)
+ return resolve_obj_by_qualname(self.backend)

@ProExpertProg
Collaborator

Nice work! I know this is still a draft but it's going in the right direction - I left some comments, ping me when this is ready for full review, and I'll try it out locally as well!
