[Backend][WIP] Integrate MPK (Mirage) compiler as an experimental execution backend to vLLM #27218
Conversation
👋 Hi! Thank you for contributing to the vLLM project. 💬 Join our developer Slack at https://slack.vllm.ai to discuss your PR in #pr-reviews, coordinate on features in #feat- channels, or join special interest groups in #sig- channels. Just a reminder: PRs do not trigger a full CI run by default; only a limited subset of checks runs automatically, and you can ask your reviewers to trigger select CI tests on top of that. Once the PR is approved and ready to go, your PR reviewer(s) can run CI to test the changes comprehensively before merging. If you have any questions, please reach out to us on Slack at https://slack.vllm.ai. 🚀
💡 Codex Review
Here are some automated review suggestions for this pull request.
vllm_config: VllmConfig,
forward_context: ForwardContext,
args: list[Any],
transfered_tensor_names: list[str],
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Import typing.Any to avoid module load failure
The new backend annotates several parameters with list[Any] but never imports Any from typing. Because annotations are evaluated eagerly, importing this module will immediately raise NameError: name 'Any' is not defined before the backend can be used, even when all other dependencies are present. Add the missing import so the module can load.
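The fix is the usual one-line import at the top of the affected module:

```python
from typing import Any
```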
vllm/v1/attention/backends/mirage.py
Outdated
assert self.q_data_type == torch.bfloat16, "MirageAttentionBackend currently only supports bfloat16"

attn_metadata = MirageAttentionMetadata(
    num_actual_tokens=num_actual_tokens,
    q_data_type=self.q_data_type,
Initialize q_data_type before using it in the metadata builder
MirageAttentionMetadataBuilder.build asserts self.q_data_type == torch.bfloat16 and passes self.q_data_type into the metadata, but the constructor never defines self.q_data_type (unlike other builders, which set it from the model config or cache dtype). The first call to build will therefore raise AttributeError: 'MirageAttentionMetadataBuilder' object has no attribute 'q_data_type'. Set this attribute during initialization based on the model/cache dtype before using it.
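A minimal sketch of one way to do this, following the pattern the comment mentions for other builders (the constructor signature and config field names here are assumptions, not code from this PR):

```python
# Inside MirageAttentionMetadataBuilder.__init__ (sketch; the builder is
# assumed to receive the VllmConfig like other v1 metadata builders do).
self.model_config = vllm_config.model_config
# Derive the query dtype from the model dtype so build() can later assert
# that it is torch.bfloat16.
self.q_data_type = self.model_config.dtype
```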
Code Review
This pull request introduces experimental support for the Mirage (MPK) compiler as a new execution backend in vLLM. This involves adding a new compilation backend (MirageBackend) that integrates with torch.compile, as well as a corresponding attention backend (MirageAttentionBackend).
My review focuses on the correctness and maintainability of the new integration. I've identified a critical issue in the attention backend implementation that could lead to incorrect model outputs, as well as some high-severity maintainability concerns in the compilation backend code that should be addressed to improve clarity and prevent future bugs.
Key feedback points:
- MirageAttentionImpl returns uninitialized data, which is a critical bug.
- The new mirage_backend.py file has a misleading docstring and uses a wildcard import, which harms maintainability.
Overall, this is a significant feature addition, and addressing these points will help ensure its stability and correctness as it moves from a work-in-progress to a stable feature.
def forward(
    self,
    layer: torch.nn.Module,
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    kv_cache: torch.Tensor,
    attn_metadata: MirageAttentionMetadata,
    output: torch.Tensor | None = None,
    output_scale: torch.Tensor | None = None,
    output_block_scale: torch.Tensor | None = None,
) -> torch.Tensor:
    """Forward pass that do nothing.
    Args:
        query: shape = [num_tokens, num_heads, head_size]
        key: shape = [num_tokens, num_kv_heads, head_size]
        value: shape = [num_tokens, num_kv_heads, head_size]
        kv_cache: KV cache tensor with different possible shapes:
        - NHD: [num_blocks, 2, block_size, num_kv_heads, head_size]
        - HND: [num_blocks, 2, num_kv_heads, block_size, head_size]
        attn_metadata: Metadata for attention.
    Returns:
        shape = [num_tokens, num_heads * head_size]
    """
    assert output is not None, "Output tensor must be provided."

    if attn_metadata is None:
        # Profiling run.
        return output.fill_(0)

    return output
The forward method of MirageAttentionImpl currently returns the output tensor without modification. Since the output tensor is typically initialized with torch.empty_like, this means the method returns uninitialized memory, which will lead to incorrect model outputs and behavior.
This seems to be a placeholder implementation, as indicated by the docstring "Forward pass that do nothing" and the empty __init__ method. If the intention is for the Mirage compiler to handle the entire model execution (including attention), this attention implementation should likely not be used or called at all. If it is meant to be used, it must be implemented to correctly compute attention. Returning uninitialized data is a critical bug.
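If the placeholder is kept because attention is expected to run entirely inside the Mirage-compiled graph, one way to avoid silently returning garbage is to fail loudly outside the profiling path. This is only a sketch of that idea, not the PR's implementation:

```python
assert output is not None, "Output tensor must be provided."
if attn_metadata is None:
    # Profiling run: zero-fill so downstream shape checks still pass.
    return output.fill_(0)
# At execution time this impl should never be reached if MPK handles
# attention itself; raising makes the placeholder status explicit instead
# of returning uninitialized memory.
raise NotImplementedError(
    "MirageAttentionImpl.forward is a placeholder; attention is expected "
    "to be executed by the Mirage (MPK) compiled graph."
)
```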
from collections import defaultdict
from .backends import *
from mirage import MPK, MPKMetadata, MirageModelConfig
import re
from vllm.config import ModelConfig, get_current_vllm_config
from vllm.forward_context import ForwardContext, get_forward_context
from vllm.model_executor.models.utils import extract_layer_index
import torch
from vllm.logger import init_logger
Using a wildcard import (from .backends import *) makes it difficult to determine the origin of names, which can harm code readability and maintainability. It's better to import names explicitly.
This file implicitly imports VllmConfig, CompilationConfig, PostGradPassManager, compilation_counter, model_tag, time, Any, and torch.fx through the wildcard. This can be confusing and may lead to issues if vllm/compilation/backends.py changes.
Please replace the wildcard import with explicit imports for better clarity:
from collections import defaultdict
import re
import time
from typing import Any
import torch
import torch.fx as fx
from mirage import MPK, MPKMetadata, MirageModelConfig
from vllm.compilation.backends import model_tag
from vllm.compilation.counter import compilation_counter
from vllm.compilation.pass_manager import PostGradPassManager
from vllm.config import CompilationConfig, ModelConfig, VllmConfig
from vllm.forward_context import ForwardContext, get_forward_context
from vllm.logger import init_logger
from vllm.model_executor.models.utils import extract_layer_index
The major work of this backend is to split the graph into
piecewise graphs, and pass them to the piecewise backend.
The docstring for MirageBackend states that its major work is to split the graph into piecewise graphs. However, the implementation does not perform any graph splitting; it seems to compile the graph it receives as a whole. This docstring appears to be copied from VllmBackend and is misleading. Please update it to accurately describe what MirageBackend does.
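For example, a docstring along these lines would be more accurate (a suggestion only, assuming the backend really does compile the received graph as a whole):

```python
class MirageBackend:
    """Compilation backend that hands the whole fx graph to the Mirage (MPK)
    compiler, rather than splitting it into piecewise graphs and dispatching
    them to the piecewise backend."""
```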
vllm/config/compilation.py
Outdated
from vllm.compilation.backends import VllmBackend

return VllmBackend(vllm_config)
# Default piecewise compilation path (vLLM backend driving Inductor/Eager)
Instead of this approach, Mirage should use CompilationMode.DYNAMO_TRACE_ONCE and then be passed directly to torch.compile(backend=...), like it is now. VLLM_COMPILE implies VllmBackend. Later, if we want to add support for caching and other VllmBackend optimizations, we can add this code path and initialize VllmBackend, which would use a nested MirageBackend for the actual compilation, just as it currently uses InductorAdaptor for backend="inductor", but that will come much later.
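A rough sketch of the proposed wiring (hypothetical, not code from this PR; model and vllm_config are assumed to be in scope, and the module path is taken from this PR's new file):

```python
import torch

from vllm.compilation.mirage_backend import MirageBackend  # path assumed from this PR

# torch.compile accepts any callable backend taking (fx.GraphModule, example_inputs);
# with DYNAMO_TRACE_ONCE, Dynamo traces the model a single time and hands the
# whole graph to this backend.
backend = MirageBackend(vllm_config)
compiled_model = torch.compile(model, backend=backend, fullgraph=True)
```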
self.compiled = True

logger.info(f"[Mirage] Calling the compiled result...")
return self.mpk()
Shouldn't we try to pass at least the input tensors to mpk here? Or do those have to be static too? If yes, I'm not sure that's guaranteed by the GPUModelRunner
]:
    if self.backend in torch_backends:
        return self.backend
    return resolve_obj_by_qualname(self.backend)
Suggested change (replacing the plain return resolve_obj_by_qualname(self.backend)):

if self.backend == "mirage":
    from vllm.compilation.mirage_backend import MirageBackend
    return MirageBackend(vllm_config)
return resolve_obj_by_qualname(self.backend)
Nice work! I know this is still a draft but it's going in the right direction - I left some comments, ping me when this is ready for full review, and I'll try it out locally as well!
Purpose
See #22201 and mirage-project/mirage#522.
TODO: