-
Notifications
You must be signed in to change notification settings - Fork 31
Support multimodal rotary embedding #192
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change | ||||||||||||||||
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
|
|
@@ -160,6 +160,103 @@ class rotary_embedding_kernel { | |||||||||||||||||
| const int head_size; | ||||||||||||||||||
| }; | ||||||||||||||||||
|
|
||||||||||||||||||
|
|
||||||||||||||||||
| // Maximum number of M-RoPE sections supported (e.g. 3 for Qwen2-VL: | ||||||||||||||||||
| // temporal / height / width). | ||||||||||||||||||
| constexpr int MROPE_MAX_SECTIONS = 4; | ||||||||||||||||||
| // Maximum rot_dim supported for the merged cache buffer (cos + sin). | ||||||||||||||||||
| constexpr int MROPE_MAX_ROT_DIM = 512; | ||||||||||||||||||
|
|
||||||||||||||||||
| template <typename scalar_t, bool IS_NEOX> | ||||||||||||||||||
| class multimodal_rotary_embedding_kernel { | ||||||||||||||||||
| public: | ||||||||||||||||||
| multimodal_rotary_embedding_kernel( | ||||||||||||||||||
| const int64_t* __restrict__ positions_, // [num_mrope_sections, | ||||||||||||||||||
| // num_tokens] | ||||||||||||||||||
| scalar_t* __restrict__ query_, | ||||||||||||||||||
| scalar_t* __restrict__ key_, | ||||||||||||||||||
| const scalar_t* __restrict__ cos_sin_cache_, // [max_position, rot_dim] | ||||||||||||||||||
| const int* mrope_section_data, // host array [num_mrope_sections] | ||||||||||||||||||
| const int num_mrope_sections_, | ||||||||||||||||||
| const int num_tokens_, | ||||||||||||||||||
| const int rot_dim_, | ||||||||||||||||||
| const int64_t query_stride_, | ||||||||||||||||||
| const int64_t key_stride_, | ||||||||||||||||||
| const int64_t head_stride_, | ||||||||||||||||||
| const int num_heads_, | ||||||||||||||||||
| const int num_kv_heads_, | ||||||||||||||||||
| const int head_size_) | ||||||||||||||||||
| : positions(positions_), | ||||||||||||||||||
| query(query_), | ||||||||||||||||||
| key(key_), | ||||||||||||||||||
| cos_sin_cache(cos_sin_cache_), | ||||||||||||||||||
| num_mrope_sections(num_mrope_sections_), | ||||||||||||||||||
| num_tokens(num_tokens_), | ||||||||||||||||||
| rot_dim(rot_dim_), | ||||||||||||||||||
| query_stride(query_stride_), | ||||||||||||||||||
| key_stride(key_stride_), | ||||||||||||||||||
| head_stride(head_stride_), | ||||||||||||||||||
| num_heads(num_heads_), | ||||||||||||||||||
| num_kv_heads(num_kv_heads_), | ||||||||||||||||||
| head_size(head_size_) { | ||||||||||||||||||
| for (int s = 0; s < num_mrope_sections_; ++s) | ||||||||||||||||||
| mrope_section[s] = mrope_section_data[s]; | ||||||||||||||||||
| } | ||||||||||||||||||
|
|
||||||||||||||||||
| void operator() [[sycl::reqd_sub_group_size(32)]] ( | ||||||||||||||||||
| const sycl::nd_item<3>& item_ct1) const { | ||||||||||||||||||
| // Each work-group handles one token. | ||||||||||||||||||
| const int token_idx = item_ct1.get_group(2); | ||||||||||||||||||
| const int embed_dim = rot_dim / 2; | ||||||||||||||||||
|
|
||||||||||||||||||
| // Build cos/sin cache for this token (private memory per thread). | ||||||||||||||||||
| scalar_t merged_cache[MROPE_MAX_ROT_DIM]; // [cos | sin], size = rot_dim | ||||||||||||||||||
| int cumsum = 0; | ||||||||||||||||||
| for (int s = 0; s < num_mrope_sections; ++s) { | ||||||||||||||||||
| const int lo = cumsum; | ||||||||||||||||||
| const int hi = lo + mrope_section[s]; | ||||||||||||||||||
| cumsum = hi; | ||||||||||||||||||
| const int64_t pos = positions[s * num_tokens + token_idx]; | ||||||||||||||||||
| const scalar_t* src = cos_sin_cache + pos * rot_dim; | ||||||||||||||||||
| for (int r = lo; r < hi; ++r) { | ||||||||||||||||||
| merged_cache[r] = src[r]; // cos slice | ||||||||||||||||||
| merged_cache[embed_dim + r] = src[embed_dim + r]; // sin slice | ||||||||||||||||||
| } | ||||||||||||||||||
| } | ||||||||||||||||||
| const scalar_t* cache_ptr = merged_cache; | ||||||||||||||||||
|
|
||||||||||||||||||
| apply_rotary_embedding<scalar_t, IS_NEOX>( | ||||||||||||||||||
| query, | ||||||||||||||||||
| key, | ||||||||||||||||||
| cache_ptr, | ||||||||||||||||||
| head_size, | ||||||||||||||||||
| num_heads, | ||||||||||||||||||
| num_kv_heads, | ||||||||||||||||||
| rot_dim, | ||||||||||||||||||
| token_idx, | ||||||||||||||||||
| query_stride, | ||||||||||||||||||
| key_stride, | ||||||||||||||||||
| head_stride, | ||||||||||||||||||
| item_ct1); | ||||||||||||||||||
| } | ||||||||||||||||||
|
|
||||||||||||||||||
| private: | ||||||||||||||||||
| const int64_t* __restrict__ positions; | ||||||||||||||||||
| scalar_t* __restrict__ query; | ||||||||||||||||||
| scalar_t* __restrict__ key; | ||||||||||||||||||
| const scalar_t* __restrict__ cos_sin_cache; | ||||||||||||||||||
| int mrope_section[MROPE_MAX_SECTIONS]; // embedded, copied from host list | ||||||||||||||||||
| const int num_mrope_sections; | ||||||||||||||||||
| const int num_tokens; | ||||||||||||||||||
| const int rot_dim; | ||||||||||||||||||
| const int64_t query_stride; | ||||||||||||||||||
| const int64_t key_stride; | ||||||||||||||||||
| const int64_t head_stride; | ||||||||||||||||||
| const int num_heads; | ||||||||||||||||||
| const int num_kv_heads; | ||||||||||||||||||
| const int head_size; | ||||||||||||||||||
| }; | ||||||||||||||||||
|
|
||||||||||||||||||
| } // namespace vllm | ||||||||||||||||||
|
|
||||||||||||||||||
| template <typename scalar_t> | ||||||||||||||||||
|
|
@@ -284,3 +381,138 @@ void rotary_embedding( | |||||||||||||||||
| positions, query, key, head_size, cos_sin_cache, is_neox); | ||||||||||||||||||
| }); | ||||||||||||||||||
| } | ||||||||||||||||||
|
|
||||||||||||||||||
| // ── Multi-Modal Rotary Embedding (M-RoPE) ────────────────────────────────── | ||||||||||||||||||
| // Used by models such as Qwen2-VL that need per-section position encoding. | ||||||||||||||||||
| // | ||||||||||||||||||
| // positions : [num_mrope_sections, num_tokens] int64, on device | ||||||||||||||||||
| // query : [num_tokens, num_heads * head_size] or | ||||||||||||||||||
| // [num_tokens, num_heads, head_size] | ||||||||||||||||||
| // key : same shapes as query but kv_heads, or nullopt | ||||||||||||||||||
| // cos_sin_cache : [max_position, rot_dim] | ||||||||||||||||||
| // mrope_section : [num_mrope_sections] int32, on device; | ||||||||||||||||||
| // values in embed_dim units summing to rot_dim / 2 | ||||||||||||||||||
|
|
||||||||||||||||||
| template <typename scalar_t> | ||||||||||||||||||
| void call_multimodal_rotary_embedding_kernel( | ||||||||||||||||||
| torch::Tensor& positions, | ||||||||||||||||||
| torch::Tensor& query, | ||||||||||||||||||
| std::optional<torch::Tensor> key, | ||||||||||||||||||
| int64_t head_size, | ||||||||||||||||||
| torch::Tensor& cos_sin_cache, | ||||||||||||||||||
| bool is_neox, | ||||||||||||||||||
| const std::vector<int64_t>& mrope_section) { | ||||||||||||||||||
| using sycl_t = typename vllm::xpu::SyclTypeTrait<scalar_t>::Type; | ||||||||||||||||||
|
|
||||||||||||||||||
| TORCH_CHECK( | ||||||||||||||||||
| positions.dim() == 2, | ||||||||||||||||||
| "positions must have shape [num_mrope_sections, num_tokens]"); | ||||||||||||||||||
| const int num_mrope_sections = positions.size(0); | ||||||||||||||||||
| const int64_t num_tokens = positions.size(1); | ||||||||||||||||||
|
|
||||||||||||||||||
| TORCH_CHECK( | ||||||||||||||||||
| (int)mrope_section.size() == num_mrope_sections, | ||||||||||||||||||
| "mrope_section length must equal positions.size(0)"); | ||||||||||||||||||
| TORCH_CHECK( | ||||||||||||||||||
| num_mrope_sections <= vllm::MROPE_MAX_SECTIONS, | ||||||||||||||||||
| "num_mrope_sections exceeds MROPE_MAX_SECTIONS=", | ||||||||||||||||||
| vllm::MROPE_MAX_SECTIONS); | ||||||||||||||||||
|
|
||||||||||||||||||
| const int query_hidden_size = query.numel() / num_tokens; | ||||||||||||||||||
| const int key_hidden_size = | ||||||||||||||||||
| key.has_value() ? key->numel() / num_tokens : 0; | ||||||||||||||||||
| TORCH_CHECK(query_hidden_size % head_size == 0); | ||||||||||||||||||
| TORCH_CHECK(key_hidden_size % head_size == 0); | ||||||||||||||||||
|
|
||||||||||||||||||
| const int num_heads = query_hidden_size / head_size; | ||||||||||||||||||
| const int num_kv_heads = | ||||||||||||||||||
| key.has_value() ? key_hidden_size / head_size : num_heads; | ||||||||||||||||||
| TORCH_CHECK(num_heads % num_kv_heads == 0); | ||||||||||||||||||
|
|
||||||||||||||||||
| const int rot_dim = cos_sin_cache.size(1); | ||||||||||||||||||
|
||||||||||||||||||
| const int rot_dim = cos_sin_cache.size(1); | |
| const int rot_dim = cos_sin_cache.size(1); | |
| TORCH_CHECK( | |
| rot_dim <= vllm::MROPE_MAX_ROT_DIM, | |
| "rot_dim exceeds MROPE_MAX_ROT_DIM=", | |
| vllm::MROPE_MAX_ROT_DIM, | |
| ", got rot_dim=", | |
| rot_dim); |
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -64,6 +64,18 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) { | |
| " Tensor cos_sin_cache, bool is_neox) -> ()"); | ||
| ops.impl("rotary_embedding", torch::kXPU, &rotary_embedding); | ||
|
|
||
| // Multi-modal Rotary Embedding (M-RoPE) — used by e.g. Qwen2-VL. | ||
| // positions has shape [num_mrope_sections, num_tokens]; mrope_section is | ||
| // an int32 device tensor of length num_mrope_sections that partitions the | ||
| // rotation dimensions across positional axes (e.g. time / height / width). | ||
| ops.def( | ||
| "multimodal_rotary_embedding(Tensor positions, Tensor! query," | ||
| " Tensor!? key, int head_size," | ||
| " Tensor cos_sin_cache, bool is_neox," | ||
| " int[] mrope_section) -> ()"); | ||
| ops.impl("multimodal_rotary_embedding", torch::kXPU, | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. vllm code base doesn't have this op, we'd better move it to |
||
| &multimodal_rotary_embedding); | ||
|
|
||
| // Compute FP8 quantized tensor for given scaling factor. | ||
| ops.def( | ||
| "static_scaled_fp8_quant(Tensor! result, Tensor input, Tensor scale, " | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The
merged_cachearray is not zero-initialized, and only the ranges covered bymrope_sectionentries are populated. If the section values don't sum to exactlyembed_dim(=rot_dim / 2),apply_rotary_embeddingwill read uninitialized values from the gaps. Consider either zero-initializingmerged_cacheor adding aTORCH_CHECKthat the section values sum toembed_dim.