Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 9 additions & 0 deletions csrc/ops.h
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,15 @@ void rotary_embedding(
torch::Tensor& cos_sin_cache,
bool is_neox);

// Multi-modal rotary embedding (M-RoPE), used by models such as Qwen2-VL.
// Rotates query (and key, when present) in place, with the rotation
// dimensions partitioned across positional axes per mrope_section.
void multimodal_rotary_embedding(
    torch::Tensor& positions,  // [num_mrope_sections, num_tokens]
    torch::Tensor& query,
    std::optional<torch::Tensor> key,
    int64_t head_size,
    torch::Tensor& cos_sin_cache,  // [max_position, rot_dim]
    bool is_neox,
    std::vector<int64_t> mrope_section);  // host int list [num_mrope_sections]

void reshape_and_cache(
torch::Tensor& key,
torch::Tensor& value,
Expand Down
232 changes: 232 additions & 0 deletions csrc/pos_encoding_kernels.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -160,6 +160,103 @@ class rotary_embedding_kernel {
const int head_size;
};


// Maximum number of M-RoPE sections supported (e.g. 3 for Qwen2-VL:
// temporal / height / width).
constexpr int MROPE_MAX_SECTIONS = 4;
// Maximum rot_dim supported for the merged cache buffer (cos + sin).
constexpr int MROPE_MAX_ROT_DIM = 512;

template <typename scalar_t, bool IS_NEOX>
class multimodal_rotary_embedding_kernel {
public:
multimodal_rotary_embedding_kernel(
const int64_t* __restrict__ positions_, // [num_mrope_sections,
// num_tokens]
scalar_t* __restrict__ query_,
scalar_t* __restrict__ key_,
const scalar_t* __restrict__ cos_sin_cache_, // [max_position, rot_dim]
const int* mrope_section_data, // host array [num_mrope_sections]
const int num_mrope_sections_,
const int num_tokens_,
const int rot_dim_,
const int64_t query_stride_,
const int64_t key_stride_,
const int64_t head_stride_,
const int num_heads_,
const int num_kv_heads_,
const int head_size_)
: positions(positions_),
query(query_),
key(key_),
cos_sin_cache(cos_sin_cache_),
num_mrope_sections(num_mrope_sections_),
num_tokens(num_tokens_),
rot_dim(rot_dim_),
query_stride(query_stride_),
key_stride(key_stride_),
head_stride(head_stride_),
num_heads(num_heads_),
num_kv_heads(num_kv_heads_),
head_size(head_size_) {
for (int s = 0; s < num_mrope_sections_; ++s)
mrope_section[s] = mrope_section_data[s];
}

void operator() [[sycl::reqd_sub_group_size(32)]] (
const sycl::nd_item<3>& item_ct1) const {
// Each work-group handles one token.
const int token_idx = item_ct1.get_group(2);
const int embed_dim = rot_dim / 2;

// Build cos/sin cache for this token (private memory per thread).
scalar_t merged_cache[MROPE_MAX_ROT_DIM]; // [cos | sin], size = rot_dim
int cumsum = 0;
for (int s = 0; s < num_mrope_sections; ++s) {
const int lo = cumsum;
const int hi = lo + mrope_section[s];
cumsum = hi;
const int64_t pos = positions[s * num_tokens + token_idx];
const scalar_t* src = cos_sin_cache + pos * rot_dim;
for (int r = lo; r < hi; ++r) {
merged_cache[r] = src[r]; // cos slice
merged_cache[embed_dim + r] = src[embed_dim + r]; // sin slice
}
}
Comment on lines +213 to +225
Copy link

Copilot AI Mar 12, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The merged_cache array is not zero-initialized, and only the ranges covered by mrope_section entries are populated. If the section values don't sum to exactly embed_dim (= rot_dim / 2), apply_rotary_embedding will read uninitialized values from the gaps. Consider either zero-initializing merged_cache or adding a TORCH_CHECK that the section values sum to embed_dim.

Copilot uses AI. Check for mistakes.
const scalar_t* cache_ptr = merged_cache;

apply_rotary_embedding<scalar_t, IS_NEOX>(
query,
key,
cache_ptr,
head_size,
num_heads,
num_kv_heads,
rot_dim,
token_idx,
query_stride,
key_stride,
head_stride,
item_ct1);
}

private:
const int64_t* __restrict__ positions;
scalar_t* __restrict__ query;
scalar_t* __restrict__ key;
const scalar_t* __restrict__ cos_sin_cache;
int mrope_section[MROPE_MAX_SECTIONS]; // embedded, copied from host list
const int num_mrope_sections;
const int num_tokens;
const int rot_dim;
const int64_t query_stride;
const int64_t key_stride;
const int64_t head_stride;
const int num_heads;
const int num_kv_heads;
const int head_size;
};

} // namespace vllm

template <typename scalar_t>
Expand Down Expand Up @@ -284,3 +381,138 @@ void rotary_embedding(
positions, query, key, head_size, cos_sin_cache, is_neox);
});
}

// ── Multi-Modal Rotary Embedding (M-RoPE) ──────────────────────────────────
// Used by models such as Qwen2-VL that need per-section position encoding.
//
// positions     : [num_mrope_sections, num_tokens] int64, on device
// query         : [num_tokens, num_heads * head_size] or
//                 [num_tokens, num_heads, head_size]
// key           : same shapes as query but kv_heads, or nullopt
// cos_sin_cache : [max_position, rot_dim]
// mrope_section : [num_mrope_sections] host int64 list; values in
//                 embed_dim units summing to rot_dim / 2

template <typename scalar_t>
void call_multimodal_rotary_embedding_kernel(
    torch::Tensor& positions,
    torch::Tensor& query,
    std::optional<torch::Tensor> key,
    int64_t head_size,
    torch::Tensor& cos_sin_cache,
    bool is_neox,
    const std::vector<int64_t>& mrope_section) {
  using sycl_t = typename vllm::xpu::SyclTypeTrait<scalar_t>::Type;

  TORCH_CHECK(
      positions.dim() == 2,
      "positions must have shape [num_mrope_sections, num_tokens]");
  const int num_mrope_sections = positions.size(0);
  const int64_t num_tokens = positions.size(1);

  TORCH_CHECK(
      (int)mrope_section.size() == num_mrope_sections,
      "mrope_section length must equal positions.size(0)");
  TORCH_CHECK(
      num_mrope_sections <= vllm::MROPE_MAX_SECTIONS,
      "num_mrope_sections exceeds MROPE_MAX_SECTIONS=",
      vllm::MROPE_MAX_SECTIONS);

  const int query_hidden_size = query.numel() / num_tokens;
  const int key_hidden_size =
      key.has_value() ? key->numel() / num_tokens : 0;
  TORCH_CHECK(query_hidden_size % head_size == 0);
  TORCH_CHECK(key_hidden_size % head_size == 0);

  const int num_heads = query_hidden_size / head_size;
  const int num_kv_heads =
      key.has_value() ? key_hidden_size / head_size : num_heads;
  TORCH_CHECK(num_heads % num_kv_heads == 0);

  const int rot_dim = cos_sin_cache.size(1);
  // The kernel stages the per-token [cos | sin] cache in a fixed-size
  // private array of MROPE_MAX_ROT_DIM elements; a larger rot_dim would
  // overflow it.
  TORCH_CHECK(
      rot_dim <= vllm::MROPE_MAX_ROT_DIM,
      "rot_dim exceeds MROPE_MAX_ROT_DIM=",
      vllm::MROPE_MAX_ROT_DIM,
      ", got rot_dim=",
      rot_dim);
  // The section widths must exactly tile the cos (and sin) half of the
  // cache; otherwise the kernel would read gaps it never wrote, or write
  // past the end of its private buffer.
  int64_t section_sum = 0;
  for (const int64_t width : mrope_section) section_sum += width;
  TORCH_CHECK(
      section_sum == rot_dim / 2,
      "sum(mrope_section)=",
      section_sum,
      " must equal rot_dim/2=",
      rot_dim / 2);

  // query is always [num_tokens, ...] in the M-RoPE path.
  const int64_t query_stride = query.stride(0);
  const int64_t key_stride = key.has_value() ? key->stride(0) : 0;
  const int query_ndim = query.dim();
  // For [num_tokens, num_heads, head_size] use stride(-2), else head_size.
  const int64_t head_stride =
      (query_ndim == 3) ? query.stride(-2) : head_size;

  // Ensure positions is contiguous so that raw pointer arithmetic
  // s * num_tokens + t correctly addresses positions[s, t].
  at::Tensor positions_contig = positions.contiguous();
  auto positions_ptr = positions_contig.data_ptr<int64_t>();
  auto query_ptr = query.data_ptr<scalar_t>();
  auto key_ptr = key.has_value() ? key->data_ptr<scalar_t>() : nullptr;
  auto cos_sin_cache_ptr = cos_sin_cache.data_ptr<scalar_t>();

  // Convert int64 list to int array for the kernel.
  int mrope_section_arr[vllm::MROPE_MAX_SECTIONS] = {};
  for (int s = 0; s < num_mrope_sections; ++s)
    mrope_section_arr[s] = static_cast<int>(mrope_section[s]);

  sycl::range<3> grid(1, 1, num_tokens);
  sycl::range<3> block(1, 1, std::min<int64_t>(num_heads * rot_dim / 2, 512));

  at::DeviceGuard device_guard(query.device());
  auto& queue = vllm::xpu::vllmGetQueue();
  if (is_neox) {
    queue.submit([&](sycl::handler& cgh) {
      cgh.parallel_for(
          sycl::nd_range<3>(grid * block, block),
          vllm::multimodal_rotary_embedding_kernel<sycl_t, true>(
              positions_ptr,
              (sycl_t*)query_ptr,
              (sycl_t*)key_ptr,
              (sycl_t*)cos_sin_cache_ptr,
              mrope_section_arr,
              num_mrope_sections,
              num_tokens,
              rot_dim,
              query_stride,
              key_stride,
              head_stride,
              num_heads,
              num_kv_heads,
              head_size));
    });
  } else {
    queue.submit([&](sycl::handler& cgh) {
      cgh.parallel_for(
          sycl::nd_range<3>(grid * block, block),
          vllm::multimodal_rotary_embedding_kernel<sycl_t, false>(
              positions_ptr,
              (sycl_t*)query_ptr,
              (sycl_t*)key_ptr,
              (sycl_t*)cos_sin_cache_ptr,
              mrope_section_arr,
              num_mrope_sections,
              num_tokens,
              rot_dim,
              query_stride,
              key_stride,
              head_stride,
              num_heads,
              num_kv_heads,
              head_size));
    });
  }
}

// Entry point for multi-modal rotary embedding (M-RoPE): dispatches on the
// query dtype and forwards to the typed kernel launcher. query (and key,
// when present) are rotated in place.
void multimodal_rotary_embedding(
    torch::Tensor& positions, // [num_mrope_sections, num_tokens]
    torch::Tensor& query,
    std::optional<torch::Tensor> key,
    int64_t head_size,
    torch::Tensor& cos_sin_cache, // [max_position, rot_dim]
    bool is_neox,
    std::vector<int64_t> mrope_section) // host int list [num_mrope_sections]
{
  VLLM_DISPATCH_FLOATING_TYPES(
      query.scalar_type(), "multimodal_rotary_embedding", [&] {
        call_multimodal_rotary_embedding_kernel<scalar_t>(
            positions, query, key, head_size, cos_sin_cache, is_neox,
            mrope_section);
      });
}
12 changes: 12 additions & 0 deletions csrc/torch_bindings.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,18 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
" Tensor cos_sin_cache, bool is_neox) -> ()");
ops.impl("rotary_embedding", torch::kXPU, &rotary_embedding);

// Multi-modal Rotary Embedding (M-RoPE) — used by e.g. Qwen2-VL.
// positions has shape [num_mrope_sections, num_tokens]; mrope_section is
// a host int list (schema type int[]) of length num_mrope_sections that
// partitions the rotation dimensions across positional axes
// (e.g. time / height / width).
ops.def(
    "multimodal_rotary_embedding(Tensor positions, Tensor! query,"
    " Tensor!? key, int head_size,"
    " Tensor cos_sin_cache, bool is_neox,"
    " int[] mrope_section) -> ()");
ops.impl("multimodal_rotary_embedding", torch::kXPU,
         &multimodal_rotary_embedding);

// Compute FP8 quantized tensor for given scaling factor.
ops.def(
"static_scaled_fp8_quant(Tensor! result, Tensor input, Tensor scale, "
Expand Down
13 changes: 13 additions & 0 deletions tests/register_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,19 @@ def rotary_embedding(
torch.ops._C.rotary_embedding(positions, query, key, head_size,
cos_sin_cache, is_neox)

def multimodal_rotary_embedding(
    positions: torch.Tensor,
    query: torch.Tensor,
    key: Optional[torch.Tensor],
    head_size: int,
    cos_sin_cache: torch.Tensor,
    is_neox: bool,
    mrope_section: list[int],
) -> None:
    """Apply multi-modal rotary embedding (M-RoPE) via the custom XPU op.

    query (and key, when given) are mutated in place (the op schema marks
    them ``Tensor!``).

    Args:
        positions: ``[num_mrope_sections, num_tokens]`` int64 tensor of
            per-section position ids.
        query: query tensor to rotate.
        key: optional key tensor to rotate, or ``None``.
        head_size: size of each attention head.
        cos_sin_cache: ``[max_position, rot_dim]`` precomputed cos/sin table.
        is_neox: whether to use the NeoX (rotate-half) rotary layout.
        mrope_section: host int list of per-section widths partitioning the
            rotation dimensions across positional axes.
    """
    torch.ops._C.multimodal_rotary_embedding(positions, query, key,
                                             head_size, cos_sin_cache,
                                             is_neox, mrope_section)


def deepseek_scaling_rope(
positions: torch.Tensor,
Expand Down
Loading
Loading