
Commit e3f6861

[GPU] Add IncreasePositionIdsPrecision for Qwen3-VL models (#34716)
### Description of the issue (symptom, root-cause, how it was resolved)

- Symptom: The Qwen3-VL-4B-Instruct INT4 model produces incorrect output on GPU for long input sequences (>2048 tokens). The first-token prediction is wrong, causing completely incoherent generated text. CPU output is correct.
- Root-cause: The `position_ids` (integer values) are converted to FP16 before the frequency MatMul in the RoPE computation path. FP16 has only 10 mantissa bits, so integers in the range 4096–8192 are rounded to the nearest multiple of 4 (e.g., 4173→4172, 4174→4176). This corrupts the sin/cos positional embeddings fed into every transformer layer. The existing `IncreasePositionIdsPrecision` transformation has 4 model-specific patterns, but none matches Qwen3-VL because: (1) Unsqueeze is decomposed to Reshape by the frontend, and (2) the path between MatMul and Sin/Cos includes a complex `Gather×3 → ScatterNDUpdate` chain for 3D position assignment (temporal, height, width) that is unique to Qwen3-VL.
- Resolution: Added an `IncreasePositionIdsPrecisionForQwen3VL` matcher pass that pattern-matches `Convert→Reshape|Unsqueeze→Convert(i32→f16)→MatMul(Broadcast,...)`, then uses a forward BFS from MatMul to locate the downstream Sin/Cos nodes. The transformation upgrades the position_ids computation path from f16 to f32 and inserts f32→f16 Converts after Sin/Cos to restore the original precision at the boundary.

#### The code and line that caused this issue (if it is not changed directly)

- intel_gpu/src/plugin/transformations/increase_position_ids_precision.cpp
- `IncreasePositionIdsPrecision::run_on_model()`: the 4 existing sub-passes (ForRoPE, ForQwen25VL, ForLtxVideo, ForGPTOSS) all failed to match the Qwen3-VL graph pattern, so no precision upgrade was applied.

#### Reproduction step and snapshot (if applicable. Do not attach for customer model)

- `python genai/tools/llm_bench/benchmark.py -m Qwen3-VL-4B-Instruct/INT4 -d GPU.1 --task visual_text_gen -pf raw_prompt.jsonl -ic 128 -lc config.json`
- where config.json = `{"ATTENTION_BACKEND": "PA", "CACHE_DIR": ""}`
- Input: 5545 tokens (tool-calling prompt without image)

#### Problematic graph

- Qwen3-VL RoPE position_ids path in the language model subgraph: <img width="509" height="1014" alt="image" src="https://github.com/user-attachments/assets/2c5632a0-ad75-440e-b218-4f42f16f9726" />
- The fix changes Convert(i32→f16) to Convert(i32→f32), inserts Convert(f16→f32) after Broadcast, and inserts Convert(f32→f16) after Sin/Cos.

#### Checklist

- [x] Is it a proper fix? (not a workaround)
- [x] Did you include a test case for this fix, if necessary?
- [x] Did you review an existing test that can be extended to cover this scenario? Which test did you review?
  - Reviewed the `IncreasePositionIdsPrecisionForQwen25VL` test and added a new dedicated test, `IncreasePositionIdsPrecisionForQwen3VL`.

### Tickets:

- [CVS-182656](https://jira.devtools.intel.com/browse/CVS-182656)

Signed-off-by: Andrew Park <andrew.park@intel.com>
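The FP16 rounding described in the root-cause can be reproduced without any model. The sketch below emulates fp16 rounding for normal-range values via `frexp`/`ldexp`; the helper name `round_to_fp16` and the emulation approach are illustrative, not OpenVINO code (real fp16 uses round-to-nearest-even, which agrees with `std::round` for the values shown).

```cpp
#include <cmath>

// Emulate IEEE fp16 rounding of a normal-range value: fp16 has a 10-bit
// mantissa, so the spacing (ulp) between representable values at
// magnitude [2^k, 2^(k+1)) is 2^(k-10).
double round_to_fp16(double v) {
    if (v == 0.0)
        return 0.0;
    int e;
    std::frexp(v, &e);                     // |v| = m * 2^e with m in [0.5, 1)
    double ulp = std::ldexp(1.0, e - 11);  // fp16 spacing at |v|
    return std::round(v / ulp) * ulp;
}
// Integers in [4096, 8192) have ulp = 4, so position ids get snapped to
// multiples of 4: round_to_fp16(4173) == 4172, round_to_fp16(4174) == 4176,
// while integers up to 2048 survive exactly.
```

This matches the corruption pattern in the symptom: every position id beyond 2048 can drift by up to half an ulp before it ever reaches the sin/cos tables.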
1 parent 6fab794 commit e3f6861

File tree

3 files changed (+215, −0 lines)


src/plugins/intel_gpu/src/plugin/transformations/increase_position_ids_precision.cpp

Lines changed: 110 additions & 0 deletions
@@ -4,6 +4,8 @@
 
 #include "increase_position_ids_precision.hpp"
 
+#include <set>
+
 #include "intel_gpu/op/gemm.hpp"
 #include "ov_ops/rotary_positional_embeddings.hpp"

@@ -186,6 +188,113 @@ IncreasePositionIdsPrecisionForQwen25VL::IncreasePositionIdsPrecisionForQwen25VL
     this->register_matcher(m, callback);
 }
 
+IncreasePositionIdsPrecisionForQwen3VL::IncreasePositionIdsPrecisionForQwen3VL() {
+    using namespace ov::pass::pattern;
+    using ov::pass::pattern::op::Or;
+
+    // Qwen3-VL RoPE pattern:
+    // position_ids -> Convert(i64->i32) -> Reshape(unsqueeze) -> Convert(i32->f16) -> MatMul(Broadcast, Convert)
+    //             -> Reshape(transpose) -> Gather(select_channel) x3 -> ScatterNDUpdate chain -> Reshape -> Concat(self,self)
+    //             -> Sin/Cos -> Reshape(unsqueeze) -> RoPE
+    //
+    // The intermediate path between MatMul and Sin/Cos is too complex to pattern-match,
+    // so we match the beginning (up to MatMul) and use graph traversal to find downstream Sin/Cos.
+    // Key difference from Qwen2.5-VL: Unsqueeze is decomposed to Reshape.
+    auto position_ids = any_input();
+    auto convert_to_i32 = wrap_type<ov::op::v0::Convert>({position_ids});
+    auto reshape_unsqueeze = wrap_type<ov::op::v1::Reshape>({convert_to_i32, wrap_type<ov::op::v0::Constant>()});
+    auto unsqueeze = wrap_type<ov::op::v0::Unsqueeze>({convert_to_i32, any_input()});
+    auto reshape_or_unsqueeze = std::make_shared<Or>(OutputVector{reshape_unsqueeze, unsqueeze});
+    auto convert_to_f16 = wrap_type<ov::op::v0::Convert>({reshape_or_unsqueeze});
+
+    auto broadcast_freq = wrap_type<ov::op::v3::Broadcast>({any_input(), any_input()});
+    auto matmul = wrap_type<ov::op::v0::MatMul>({broadcast_freq, convert_to_f16});
+
+    ov::matcher_pass_callback callback = [OV_CAPTURE_CPY_AND_THIS](ov::pass::pattern::Matcher& m) {
+        const auto& pattern_map = m.get_pattern_value_map();
+
+        auto convert_node = ov::as_type_ptr<ov::op::v0::Convert>(pattern_map.at(convert_to_f16).get_node_shared_ptr());
+        auto broadcast_node = pattern_map.at(broadcast_freq).get_node_shared_ptr();
+        auto matmul_node = ov::as_type_ptr<ov::op::v0::MatMul>(pattern_map.at(matmul).get_node_shared_ptr());
+
+        if (!convert_node || !matmul_node || transformation_callback(convert_node))
+            return false;
+
+        const auto desired_et = ov::element::f32;
+        const auto original_et = convert_node->get_output_element_type(0);
+        if (original_et == desired_et)
+            return false;
+
+        // Verify input is integer type (position_ids should be i32 or i64)
+        auto input_et = convert_node->input_value(0).get_element_type();
+        if (!input_et.is_integral())
+            return false;
+
+        // Walk forward from MatMul to find Sin and Cos nodes through the
+        // Reshape -> Gather -> ScatterNDUpdate -> Reshape -> Concat chain.
+        // Only follow floating-point data outputs to stay on the data path.
+        std::shared_ptr<ov::op::v0::Sin> sin_node;
+        std::shared_ptr<ov::op::v0::Cos> cos_node;
+
+        std::vector<ov::Node*> stack;
+        std::set<ov::Node*> visited;
+        stack.push_back(matmul_node.get());
+        constexpr size_t max_nodes = 30;
+        size_t nodes_visited = 0;
+
+        while (!stack.empty() && nodes_visited < max_nodes && (!sin_node || !cos_node)) {
+            auto* current = stack.back();
+            stack.pop_back();
+
+            for (auto& output : current->outputs()) {
+                if (!output.get_element_type().is_real())
+                    continue;
+                for (auto& target_input : output.get_target_inputs()) {
+                    auto consumer = target_input.get_node()->shared_from_this();
+                    if (!visited.insert(consumer.get()).second)
+                        continue;
+                    nodes_visited++;
+
+                    if (auto sin_ptr = ov::as_type_ptr<ov::op::v0::Sin>(consumer)) {
+                        sin_node = sin_ptr;
+                    } else if (auto cos_ptr = ov::as_type_ptr<ov::op::v0::Cos>(consumer)) {
+                        cos_node = cos_ptr;
+                    } else {
+                        stack.push_back(consumer.get());
+                    }
+                }
+            }
+        }
+
+        if (!sin_node || !cos_node)
+            return false;
+
+        // 1. Change Convert output from f16 to f32 (position_ids path)
+        auto new_convert = std::make_shared<ov::op::v0::Convert>(convert_node->input_value(0), desired_et);
+        new_convert->set_friendly_name(convert_node->get_friendly_name() + "_increase_precision");
+        copy_runtime_info(convert_node, new_convert);
+        ov::replace_node(convert_node, new_convert);
+
+        // 2. Insert Convert(f16->f32) after Broadcast (freq path) to match MatMul types
+        if (broadcast_node->get_output_element_type(0) != desired_et) {
+            auto broadcast_to_f32 = std::make_shared<ov::op::v0::Convert>(broadcast_node->output(0), desired_et);
+            broadcast_to_f32->set_friendly_name(broadcast_node->get_friendly_name() + "_to_f32");
+            copy_runtime_info(broadcast_node, broadcast_to_f32);
+            matmul_node->input(0).replace_source_output(broadcast_to_f32->output(0));
+        }
+
+        // 3. Insert Convert(f32->f16) after Sin/Cos to restore original precision
+        size_t output_idx = 0;
+        insert_converts_after_if_needed(sin_node, original_et, output_idx);
+        insert_converts_after_if_needed(cos_node, original_et, output_idx);
+
+        return true;
+    };
+
+    auto m = std::make_shared<ov::pass::pattern::Matcher>(matmul, "IncreasePositionIdsPrecisionForQwen3VL");
+    this->register_matcher(m, callback);
+}
+
 IncreasePositionIdsPrecisionForLtxVideo::IncreasePositionIdsPrecisionForLtxVideo() {
     using namespace ov::pass::pattern;
     using ov::pass::pattern::op::Or;

@@ -338,6 +447,7 @@ bool IncreasePositionIdsPrecision::run_on_model(const std::shared_ptr<ov::Model>
     auto symbolic_ctx_manager = symbolic_optimizations.get_manager();
     symbolic_ctx_manager->register_pass<IncreasePositionIdsPrecisionForRoPE>();
     symbolic_ctx_manager->register_pass<IncreasePositionIdsPrecisionForQwen25VL>();
+    symbolic_ctx_manager->register_pass<IncreasePositionIdsPrecisionForQwen3VL>();
    symbolic_ctx_manager->register_pass<IncreasePositionIdsPrecisionForLtxVideo>();
     symbolic_ctx_manager->register_pass<IncreasePositionIdsPrecisionForGPTOSS>();
     return symbolic_optimizations.run_on_model(model);
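The bounded forward traversal in the callback above can be illustrated with a toy graph. This sketch replaces `ov::Node` consumer edges with a plain adjacency map of strings (the function name, node names, and map layout are all hypothetical, not OpenVINO API); like the pass, it caps the number of visited nodes and stops descending once a target is reached.

```cpp
#include <map>
#include <set>
#include <string>
#include <vector>

// Toy sketch of the bounded stack-based forward traversal used to find
// Sin/Cos downstream of MatMul. `consumers[node]` lists the nodes fed by
// `node`'s outputs; `targets` is the set of node kinds we are looking for.
std::set<std::string> find_targets(const std::map<std::string, std::vector<std::string>>& consumers,
                                   const std::string& start,
                                   const std::set<std::string>& targets,
                                   size_t max_nodes) {
    std::set<std::string> found;
    std::set<std::string> visited;
    std::vector<std::string> stack{start};
    size_t nodes_visited = 0;

    while (!stack.empty() && nodes_visited < max_nodes && found.size() < targets.size()) {
        auto current = stack.back();
        stack.pop_back();

        auto it = consumers.find(current);
        if (it == consumers.end())
            continue;
        for (const auto& next : it->second) {
            if (!visited.insert(next).second)  // skip already-visited nodes
                continue;
            nodes_visited++;
            if (targets.count(next))
                found.insert(next);            // do not descend past a target
            else
                stack.push_back(next);
        }
    }
    return found;
}
```

On a chain MatMul → Reshape → Gather → Concat → {Sin, Cos}, calling `find_targets(g, "MatMul", {"Sin", "Cos"}, 30)` returns both targets; with a budget too small to reach Concat's consumers, it returns an empty set, which mirrors how the real pass bails out (`return false`) when Sin/Cos are not found within `max_nodes`.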

src/plugins/intel_gpu/src/plugin/transformations/increase_position_ids_precision.hpp

Lines changed: 6 additions & 0 deletions
@@ -20,6 +20,12 @@ class IncreasePositionIdsPrecisionForQwen25VL : public ov::pass::MatcherPass {
     IncreasePositionIdsPrecisionForQwen25VL();
 };
 
+class IncreasePositionIdsPrecisionForQwen3VL : public ov::pass::MatcherPass {
+public:
+    OPENVINO_MATCHER_PASS_RTTI("IncreasePositionIdsPrecisionForQwen3VL");
+    IncreasePositionIdsPrecisionForQwen3VL();
+};
+
 class IncreasePositionIdsPrecisionForLtxVideo : public ov::pass::MatcherPass {
 public:
     OPENVINO_MATCHER_PASS_RTTI("IncreasePositionIdsPrecisionForLtxVideo");
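A note on why f32 is a sufficient target precision for the upgraded path: a float's 23-bit mantissa represents every integer up to 2^24 exactly, far beyond any realistic position id, whereas fp16 already loses integer exactness past 2048. A minimal round-trip check (helper name illustrative):

```cpp
// A float (f32) has 23 explicit mantissa bits, so every integer n with
// |n| <= 2^24 = 16777216 round-trips exactly through float; beyond that,
// odd integers start getting rounded away.
bool exact_in_float(long long n) {
    float f = static_cast<float>(n);
    return static_cast<long long>(f) == n;
}
```

So the position ids that fp16 corrupts (e.g. 4173, 5545) are all represented exactly once the Convert targets f32.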

src/plugins/intel_gpu/tests/unit/transformations/increase_precision_test.cpp

Lines changed: 99 additions & 0 deletions
@@ -729,6 +729,105 @@ TEST_F(TransformationTestsF, IncreasePositionIdsPrecisionForQwen25VL) {
     comparator.enable(FunctionsComparator::CmpValues::ATTRIBUTES);
 }
 
+TEST_F(TransformationTestsF, IncreasePositionIdsPrecisionForQwen3VL) {
+    // Qwen3-VL pattern: position_ids -> Convert(i64->i32) -> Reshape(unsqueeze) -> Convert(i32->f16)
+    //   -> MatMul(Broadcast, Convert) -> Reshape(transpose) -> Gather -> Concat(self,self)
+    //   -> Sin/Cos -> Reshape(unsqueeze) -> RoPE
+    {
+        auto position_ids = std::make_shared<ov::op::v0::Parameter>(ov::element::i64, ov::PartialShape{ 3, -1 });
+        auto input_convert = std::make_shared<ov::op::v0::Convert>(position_ids, ov::element::i32);
+        // Qwen3-VL uses Reshape instead of Unsqueeze
+        auto input_reshape = std::make_shared<ov::op::v1::Reshape>(input_convert,
+            std::make_shared<ov::op::v0::Constant>(ov::element::i32, ov::Shape{4}, std::vector<int32_t>{3, -1, 1, 1}), true);
+        auto convert_2 = std::make_shared<ov::op::v0::Convert>(input_reshape, ov::element::f16);
+
+        auto shape_of = std::make_shared<ov::op::v3::ShapeOf>(input_convert, ov::element::i32);
+        auto gather_0 = std::make_shared<ov::op::v8::Gather>(shape_of,
+            std::make_shared<ov::op::v0::Constant>(ov::element::i32, ov::Shape{1}, std::vector<int64_t>{1}),
+            std::make_shared<ov::op::v0::Constant>(ov::element::i32, ov::Shape{}, std::vector<int64_t>{0}));
+        auto concat = std::make_shared<ov::op::v0::Concat>(ov::OutputVector{
+            std::make_shared<ov::op::v0::Constant>(ov::element::i32, ov::Shape{1}, std::vector<int64_t>{1}),
+            gather_0,
+            std::make_shared<ov::op::v0::Constant>(ov::element::i32, ov::Shape{1}, std::vector<int64_t>{64}),
+            std::make_shared<ov::op::v0::Constant>(ov::element::i32, ov::Shape{1}, std::vector<int64_t>{1})}, 0);
+        auto broadcast = std::make_shared<ov::op::v3::Broadcast>(
+            std::make_shared<ov::op::v0::Constant>(ov::element::f16, ov::Shape{1, 1, 64, 1}), concat);
+        auto matmul = std::make_shared<ov::op::v0::MatMul>(broadcast, convert_2);
+
+        // Reshape(transpose) -> Gather(select channel 0)
+        auto reshape_transpose = std::make_shared<ov::op::v1::Reshape>(matmul,
+            std::make_shared<ov::op::v0::Constant>(ov::element::i32, ov::Shape{4}, std::vector<int32_t>{3, -1, 1, 64}), true);
+        auto gather_ch0 = std::make_shared<ov::op::v8::Gather>(reshape_transpose,
+            std::make_shared<ov::op::v0::Constant>(ov::element::i32, ov::Shape{}, std::vector<int64_t>{0}),
+            std::make_shared<ov::op::v0::Constant>(ov::element::i32, ov::Shape{}, std::vector<int64_t>{0}));
+
+        // Concat(self, self) to produce [?, 1, 128]
+        auto concat_2 = std::make_shared<ov::op::v0::Concat>(ov::OutputVector{gather_ch0, gather_ch0}, 2);
+
+        auto cos = std::make_shared<ov::op::v0::Cos>(concat_2);
+        auto sin = std::make_shared<ov::op::v0::Sin>(concat_2);
+        auto cos_unsqueeze = std::make_shared<ov::op::v1::Reshape>(cos,
+            std::make_shared<ov::op::v0::Constant>(ov::element::i32, ov::Shape{4}, std::vector<int32_t>{-1, 1, 1, 128}), true);
+        auto sin_unsqueeze = std::make_shared<ov::op::v1::Reshape>(sin,
+            std::make_shared<ov::op::v0::Constant>(ov::element::i32, ov::Shape{4}, std::vector<int32_t>{-1, 1, 1, 128}), true);
+
+        auto input_2 = std::make_shared<ov::op::v0::Parameter>(ov::element::f16, ov::PartialShape{ -1, 8, -1, 128});
+        auto rope = std::make_shared<ov::op::internal::RoPE>(ov::OutputVector{input_2, cos_unsqueeze, sin_unsqueeze},
+                                                             ov::op::internal::RoPE::Config());
+
+        model = std::make_shared<ov::Model>(ov::OutputVector{rope}, ov::ParameterVector{position_ids, input_2});
+        manager.register_pass<IncreasePositionIdsPrecision>();
+    }
+    {
+        auto position_ids = std::make_shared<ov::op::v0::Parameter>(ov::element::i64, ov::PartialShape{ 3, -1 });
+        auto input_convert = std::make_shared<ov::op::v0::Convert>(position_ids, ov::element::i32);
+        auto input_reshape = std::make_shared<ov::op::v1::Reshape>(input_convert,
+            std::make_shared<ov::op::v0::Constant>(ov::element::i32, ov::Shape{4}, std::vector<int32_t>{3, -1, 1, 1}), true);
+        // Changed: Convert to f32 instead of f16
+        auto convert_2 = std::make_shared<ov::op::v0::Convert>(input_reshape, ov::element::f32);
+
+        auto shape_of = std::make_shared<ov::op::v3::ShapeOf>(input_convert, ov::element::i32);
+        auto gather_0 = std::make_shared<ov::op::v8::Gather>(shape_of,
+            std::make_shared<ov::op::v0::Constant>(ov::element::i32, ov::Shape{1}, std::vector<int64_t>{1}),
+            std::make_shared<ov::op::v0::Constant>(ov::element::i32, ov::Shape{}, std::vector<int64_t>{0}));
+        auto concat = std::make_shared<ov::op::v0::Concat>(ov::OutputVector{
+            std::make_shared<ov::op::v0::Constant>(ov::element::i32, ov::Shape{1}, std::vector<int64_t>{1}),
+            gather_0,
+            std::make_shared<ov::op::v0::Constant>(ov::element::i32, ov::Shape{1}, std::vector<int64_t>{64}),
+            std::make_shared<ov::op::v0::Constant>(ov::element::i32, ov::Shape{1}, std::vector<int64_t>{1})}, 0);
+        auto broadcast = std::make_shared<ov::op::v3::Broadcast>(
+            std::make_shared<ov::op::v0::Constant>(ov::element::f16, ov::Shape{1, 1, 64, 1}), concat);
+        // Changed: Insert Convert(f16->f32) after Broadcast
+        auto broadcast_to_f32 = std::make_shared<ov::op::v0::Convert>(broadcast, ov::element::f32);
+        auto matmul = std::make_shared<ov::op::v0::MatMul>(broadcast_to_f32, convert_2);
+
+        auto reshape_transpose = std::make_shared<ov::op::v1::Reshape>(matmul,
+            std::make_shared<ov::op::v0::Constant>(ov::element::i32, ov::Shape{4}, std::vector<int32_t>{3, -1, 1, 64}), true);
+        auto gather_ch0 = std::make_shared<ov::op::v8::Gather>(reshape_transpose,
+            std::make_shared<ov::op::v0::Constant>(ov::element::i32, ov::Shape{}, std::vector<int64_t>{0}),
+            std::make_shared<ov::op::v0::Constant>(ov::element::i32, ov::Shape{}, std::vector<int64_t>{0}));
+
+        auto concat_2 = std::make_shared<ov::op::v0::Concat>(ov::OutputVector{gather_ch0, gather_ch0}, 2);
+
+        auto cos = std::make_shared<ov::op::v0::Cos>(concat_2);
+        auto sin = std::make_shared<ov::op::v0::Sin>(concat_2);
+        // Changed: Insert Convert(f32->f16) after Cos and Sin
+        auto cos_to_f16 = std::make_shared<ov::op::v0::Convert>(cos, ov::element::f16);
+        auto sin_to_f16 = std::make_shared<ov::op::v0::Convert>(sin, ov::element::f16);
+        auto cos_unsqueeze = std::make_shared<ov::op::v1::Reshape>(cos_to_f16,
+            std::make_shared<ov::op::v0::Constant>(ov::element::i32, ov::Shape{4}, std::vector<int32_t>{-1, 1, 1, 128}), true);
+        auto sin_unsqueeze = std::make_shared<ov::op::v1::Reshape>(sin_to_f16,
+            std::make_shared<ov::op::v0::Constant>(ov::element::i32, ov::Shape{4}, std::vector<int32_t>{-1, 1, 1, 128}), true);
+
+        auto input_2 = std::make_shared<ov::op::v0::Parameter>(ov::element::f16, ov::PartialShape{ -1, 8, -1, 128});
+        auto rope = std::make_shared<ov::op::internal::RoPE>(ov::OutputVector{input_2, cos_unsqueeze, sin_unsqueeze},
+                                                             ov::op::internal::RoPE::Config());
+
+        model_ref = std::make_shared<ov::Model>(ov::OutputVector{rope}, ov::ParameterVector{position_ids, input_2});
+    }
+    comparator.enable(FunctionsComparator::CmpValues::ATTRIBUTES);
+}
+
 TEST_F(TransformationTestsF, IncreasePositionIdsPrecisionForGPTOSS) {
     {
         auto position_ids = std::make_shared<ov::op::v0::Parameter>(ov::element::i64, ov::PartialShape{ 3, -1, -1 });
