Support multimodal rotary embedding #192
Conversation
Pull request overview
This PR adds a multi-modal rotary embedding (M-RoPE) SYCL kernel for XPU, used by models like Qwen2-VL and Qwen3-omni that partition rotation dimensions across positional axes (e.g., temporal/height/width).
Changes:
- New SYCL kernel `multimodal_rotary_embedding_kernel` that builds a per-token merged cos/sin cache from per-section positions, then delegates to the existing `apply_rotary_embedding` helper
- Torch bindings and op registration for `multimodal_rotary_embedding`
- Comprehensive test file comparing kernel output against a pure-Python reference, including a test that single-section M-RoPE matches standard RoPE
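For intuition, the merged-cache construction described above can be sketched as a pure-Python reference (a hypothetical mirror of the kernel logic, not the PR's actual test code; the name `mrope_merged_cache` and the list-based layout are illustrative):

```python
def mrope_merged_cache(positions, cos_sin_cache, mrope_section):
    """Build per-token merged [cos | sin] rows from per-section positions.

    positions:     per-section position ids, shape [num_sections][num_tokens]
    cos_sin_cache: rows of length rot_dim, laid out as [cos | sin]
    mrope_section: slice widths per section; must sum to rot_dim // 2
    """
    rot_dim = len(cos_sin_cache[0])
    embed_dim = rot_dim // 2
    assert sum(mrope_section) == embed_dim, "sections must cover embed_dim"

    num_tokens = len(positions[0])
    merged = []
    for t in range(num_tokens):
        row = [0.0] * rot_dim
        lo = 0
        for s, width in enumerate(mrope_section):
            hi = lo + width
            # Each section indexes the cache with its own position id
            # (e.g. temporal/height/width for Qwen2-VL).
            src = cos_sin_cache[positions[s][t]]
            row[lo:hi] = src[lo:hi]                                  # cos slice
            row[embed_dim + lo:embed_dim + hi] = src[embed_dim + lo:embed_dim + hi]  # sin slice
            lo = hi
        merged.append(row)
    return merged

# With a single section covering embed_dim, each merged row is just the
# cache row for that token's position -- the standard-RoPE case the tests
# compare against.
cache = [[float(8 * p + i) for i in range(8)] for p in range(3)]
out = mrope_merged_cache([[0, 2, 1]], cache, [4])
assert out == [cache[0], cache[2], cache[1]]
```

A per-token row of this merged cache is what the kernel then feeds to the shared rotation helper, so M-RoPE reuses the standard RoPE math unchanged.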
Reviewed changes
Copilot reviewed 5 out of 5 changed files in this pull request and generated 2 comments.
Show a summary per file
| File | Description |
|---|---|
| csrc/pos_encoding_kernels.cpp | New M-RoPE kernel class and host-side launch function |
| csrc/torch_bindings.cpp | Op schema registration and XPU dispatch for the new kernel |
| csrc/ops.h | Declaration of multimodal_rotary_embedding |
| tests/register_ops.py | Python wrapper for the new op |
| tests/test_multimodal_rotary_embedding.py | Tests with reference implementation |
```cpp
      key.has_value() ? key_hidden_size / head_size : num_heads;
  TORCH_CHECK(num_heads % num_kv_heads == 0);

  const int rot_dim = cos_sin_cache.size(1);
```
Missing validation that rot_dim <= MROPE_MAX_ROT_DIM (512). The kernel allocates a fixed-size merged_cache[MROPE_MAX_ROT_DIM] array in private memory, but there's no TORCH_CHECK in call_multimodal_rotary_embedding_kernel to ensure rot_dim doesn't exceed this limit. This could cause a buffer overflow for models with large rotation dimensions. Please add a check like TORCH_CHECK(rot_dim <= vllm::MROPE_MAX_ROT_DIM, ...).
Suggested change:

```cpp
  const int rot_dim = cos_sin_cache.size(1);
  TORCH_CHECK(
      rot_dim <= vllm::MROPE_MAX_ROT_DIM,
      "rot_dim exceeds MROPE_MAX_ROT_DIM=",
      vllm::MROPE_MAX_ROT_DIM,
      ", got rot_dim=",
      rot_dim);
```
```cpp
    scalar_t merged_cache[MROPE_MAX_ROT_DIM];  // [cos | sin], size = rot_dim
    int cumsum = 0;
    for (int s = 0; s < num_mrope_sections; ++s) {
      const int lo = cumsum;
      const int hi = lo + mrope_section[s];
      cumsum = hi;
      const int64_t pos = positions[s * num_tokens + token_idx];
      const scalar_t* src = cos_sin_cache + pos * rot_dim;
      for (int r = lo; r < hi; ++r) {
        merged_cache[r] = src[r];                          // cos slice
        merged_cache[embed_dim + r] = src[embed_dim + r];  // sin slice
      }
    }
```
The merged_cache array is not zero-initialized, and only the ranges covered by mrope_section entries are populated. If the section values don't sum to exactly embed_dim (= rot_dim / 2), apply_rotary_embedding will read uninitialized values from the gaps. Consider either zero-initializing merged_cache or adding a TORCH_CHECK that the section values sum to embed_dim.
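To illustrate the reviewer's point, here is a small hypothetical Python model of the gap (a `None` sentinel stands in for uninitialized private memory; the values and section widths are made up for the demo):

```python
rot_dim = 8
embed_dim = rot_dim // 2
mrope_section = [1, 2]        # sums to 3, not embed_dim = 4 -> one-slot gap
merged = [None] * rot_dim     # None models uninitialized private memory
src = list(range(rot_dim))    # stand-in for one cos_sin_cache row

lo = 0
for width in mrope_section:
    hi = lo + width
    merged[lo:hi] = src[lo:hi]                                      # cos slice
    merged[embed_dim + lo:embed_dim + hi] = src[embed_dim + lo:embed_dim + hi]  # sin slice
    lo = hi

# Indices 3 (cos half) and 7 (sin half) were never written, so a later
# rotation over the full embed_dim would consume garbage at those slots.
assert merged[3] is None and merged[7] is None
```

Either zero-initializing the array or rejecting mismatched section sums at op entry (via `TORCH_CHECK`) would close this hole; the check has the advantage of surfacing the caller's bug instead of silently zeroing rotations.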
jikunshang left a comment:
Please fix DCO and pre-commit.
| " Tensor!? key, int head_size," | ||
| " Tensor cos_sin_cache, bool is_neox," | ||
| " int[] mrope_section) -> ()"); | ||
| ops.impl("multimodal_rotary_embedding", torch::kXPU, |
The vLLM code base doesn't have this op; we'd better move it to the csrc/xpu/ folder.
Purpose
Support the multimodal rotary embedding (M-RoPE) used by Qwen3-omni.
Test Plan
python -m pytest tests/test_multimodal_rotary_embedding.py -v
Test Result
All tests pass.
(Optional) Documentation Update