Skip to content

Commit 63d1013

Browse files
committed
Restore is_query_prescaled heuristic and can_move_scale_after_matmul fallback

Address PR #34177 review comments:
- [HIGH] Restore the can_move_scale_after_matmul() size-based heuristic as a performance fallback for non-prescaled query cases (e.g. decode, S_q = 1).
- [LOW] Reword comments so they do not imply SDPAFusion is always involved.

Three-way scale placement logic:
1. Q pre-scaled (Multiply(Q, scalar_const)) -> scale K^T (precision fix)
2. can_move_scale_after_matmul -> scale after MatMul (performance optimization)
3. Default -> scale Q
1 parent fa89e73 commit 63d1013

File tree

2 files changed

+86
-13
lines changed

2 files changed

+86
-13
lines changed

src/common/transformations/src/transformations/op_conversions/scaled_dot_product_attention_decomposition.cpp

Lines changed: 63 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,55 @@ namespace v3 = ov::op::v3;
4343
namespace v4 = ov::op::v4;
4444
namespace v8 = ov::op::v8;
4545
namespace v13 = ov::op::v13;
46+
namespace {
47+
48+
// Checks if query is Multiply(input, scalar_constant), indicating Q was pre-scaled
49+
// (common when PyTorch exports symmetric Q/K scaling via scaled_dot_product_attention).
50+
// When detected, applying SDPA scale to K^T instead of Q preserves the original
51+
// computation order and minimizes FP rounding divergence across transformer layers.
52+
bool is_query_prescaled(const ov::Output<ov::Node>& query) {
53+
auto mul = ov::as_type_ptr<v1::Multiply>(query.get_node_shared_ptr());
54+
if (!mul)
55+
return false;
56+
for (size_t i = 0; i < 2; ++i) {
57+
auto constant = ov::as_type_ptr<v0::Constant>(mul->input_value(i).get_node_shared_ptr());
58+
if (constant) {
59+
const auto& shape = constant->get_shape();
60+
if (ov::shape_size(shape) == 1)
61+
return true;
62+
}
63+
}
64+
return false;
65+
}
66+
67+
// Decides whether applying the SDPA scale AFTER the Q*K^T MatMul is profitable.
// Moving the scalar Multiply past the MatMul saves work when the attention
// matrix (S_q x S_k) is smaller than the query tensor (e.g. decode with S_q = 1).
// Returns false whenever the shapes are dynamic, since the comparison needs
// static element counts on both sides.
bool can_move_scale_after_matmul(const ov::Output<ov::Node>& query,
                                 const ov::Output<ov::Node>& kT,
                                 const ov::Output<ov::Node>& scale) {
    const auto& scale_pshape = scale.get_partial_shape();
    const auto& query_pshape = query.get_partial_shape();
    if (scale_pshape.is_dynamic() || query_pshape.is_dynamic()) {
        return false;
    }

    // According to the ov SDPA specification, the scale input has to be 1D with
    // a single element, or a scalar.
    if (ov::shape_size(scale_pshape.to_shape()) != 1) {
        return false;
    }

    // Use the original op arrangement to infer the result shape via shape
    // propagation. Moving the scale after MatMul pays off only if the tensor
    // produced by MatMul is smaller than the query tensor.
    auto q_scaled = std::make_shared<v1::Multiply>(query, scale);
    auto scaled_attn = std::make_shared<v0::MatMul>(q_scaled, kT);
    const auto& scaled_attn_pshape = scaled_attn->output(0).get_partial_shape();
    if (scaled_attn_pshape.is_static()) {
        return ov::shape_size(query_pshape.to_shape()) > ov::shape_size(scaled_attn_pshape.to_shape());
    }
    return false;
}
92+
93+
} // namespace
94+
4695
ov::pass::ScaledDotProductAttentionDecomposition::ScaledDotProductAttentionDecomposition() {
4796
MATCHER_SCOPE(ScaledDotProductAttentionDecomposition);
4897
auto pattern_node = ov::pass::pattern::wrap_type<v13::ScaledDotProductAttention>();
@@ -123,10 +172,20 @@ std::shared_ptr<ov::Node> ov::pass::ScaledDotProductAttentionDecomposition::deco
123172
register_new_node<v0::Concat>(OutputVector{k_dims_before_transpose, k_last_dim, k_next_dim}, 0);
124173
auto k_transposed = register_new_node<v1::Transpose>(key, transpose_dims);
125174

126-
// Always apply scale to K^T. SDPAFusion absorbs K-side scale into the SDPA scale
127-
// parameter, so restoring it on K^T preserves the original computation order.
128-
auto k_scaled = register_new_node<v1::Multiply>(k_transposed, scale);
129-
auto scaled_atten = register_new_node<v0::MatMul>(query, k_scaled)->output(0);
175+
ov::Output<Node> scaled_atten;
176+
if (is_query_prescaled(query)) {
177+
// Q is already pre-scaled (e.g., Multiply(Q, scalar_constant)).
178+
// Apply scale to K^T to preserve the original computation order
179+
// and minimize FP rounding divergence across transformer layers.
180+
auto k_scaled = register_new_node<v1::Multiply>(k_transposed, scale);
181+
scaled_atten = register_new_node<v0::MatMul>(query, k_scaled)->output(0);
182+
} else if (can_move_scale_after_matmul(query, k_transposed, scale)) {
183+
auto atten = register_new_node<v0::MatMul>(query, k_transposed)->output(0);
184+
scaled_atten = register_new_node<v1::Multiply>(atten, scale)->output(0);
185+
} else {
186+
auto q_scaled = register_new_node<v1::Multiply>(query, scale);
187+
scaled_atten = register_new_node<v0::MatMul>(q_scaled, k_transposed)->output(0);
188+
}
130189

131190
minus_inf = register_new_node<v1::ConvertLike>(minus_inf, scaled_atten);
132191

src/common/transformations/tests/op_conversions/scaled_dot_product_decomposition_test.cpp

Lines changed: 23 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,8 @@ const std::shared_ptr<ov::Node> scaled_dot_product_attention_decomposition(std::
4646
std::shared_ptr<ov::Node> attention_mask,
4747
std::shared_ptr<ov::Node> scale,
4848
bool casual,
49+
bool scale_on_k = false,
50+
bool scale_after_matmul = false,
4951
std::shared_ptr<ov::Node> sinks = nullptr);
5052

5153
TEST_F(TransformationTestsF, ScaledDotProductAttentionDecompositionStaticBasic) {
@@ -132,7 +134,7 @@ TEST_F(TransformationTestsF, ScaledDotProductAttentionDecompositionStaticBroadca
132134

133135
{
134136
const auto scaled_dot_product_attention =
135-
scaled_dot_product_attention_decomposition(query, key, value, attention_mask, scale, casual);
137+
scaled_dot_product_attention_decomposition(query, key, value, attention_mask, scale, casual, false, true);
136138
model_ref = std::make_shared<ov::Model>(OutputVector{scaled_dot_product_attention},
137139
ParameterVector{query, key, value, attention_mask, scale});
138140
}
@@ -196,7 +198,7 @@ TEST_F(TransformationTestsF, ScaledDotProductAttentionDecompositionDynamic) {
196198
}
197199
}
198200

199-
TEST_F(TransformationTestsF, ScaledDotProductAttentionDecomposition_ScalarScale_MultiplyOnK) {
201+
TEST_F(TransformationTestsF, ScaledDotProductAttentionDecomposition_ScalarScale_MultiplyAfterMatMul) {
200202
const PartialShape query_shape{1, 32, 64};
201203
const PartialShape key_shape{1, 32, 64};
202204
const PartialShape value_shape{1, 32, 64};
@@ -219,12 +221,13 @@ TEST_F(TransformationTestsF, ScaledDotProductAttentionDecomposition_ScalarScale_
219221
}
220222

221223
{
222-
auto ref = scaled_dot_product_attention_decomposition(query, key, value, attention_mask, scale, casual);
224+
auto ref =
225+
scaled_dot_product_attention_decomposition(query, key, value, attention_mask, scale, casual, false, true);
223226
model_ref = std::make_shared<ov::Model>(OutputVector{ref}, ParameterVector{query, key, value, attention_mask});
224227
}
225228
}
226229

227-
TEST_F(TransformationTestsF, ScaledDotProductAttentionDecomposition_DynamicScale_MultiplyOnK) {
230+
TEST_F(TransformationTestsF, ScaledDotProductAttentionDecomposition_DynamicScale_MultiplyBeforeMatMul) {
228231
const PartialShape query_shape{-1, -1, 64};
229232
const PartialShape key_shape{-1, -1, 64};
230233
const PartialShape value_shape{-1, -1, 64};
@@ -259,6 +262,8 @@ const std::shared_ptr<ov::Node> scaled_dot_product_attention_decomposition(std::
259262
std::shared_ptr<ov::Node> attention_mask,
260263
std::shared_ptr<ov::Node> scale,
261264
bool casual,
265+
bool scale_on_k,
266+
bool scale_after_matmul,
262267
std::shared_ptr<ov::Node> sinks) {
263268
const auto q_shape = std::make_shared<v3::ShapeOf>(query, element::i32);
264269
const auto k_shape = std::make_shared<v3::ShapeOf>(key, element::i32);
@@ -298,8 +303,17 @@ const std::shared_ptr<ov::Node> scaled_dot_product_attention_decomposition(std::
298303
std::make_shared<v0::Concat>(OutputVector{k_dims_before_transpose, k_last_dim, k_next_dim}, 0);
299304
const auto k_transposed = std::make_shared<v1::Transpose>(key, transpose_dims);
300305

301-
const auto k_scaled = std::make_shared<v1::Multiply>(k_transposed, scale);
302-
Output<Node> scaled_atten = std::make_shared<v0::MatMul>(query, k_scaled)->output(0);
306+
Output<Node> scaled_atten;
307+
if (scale_on_k) {
308+
const auto k_scaled = std::make_shared<v1::Multiply>(k_transposed, scale);
309+
scaled_atten = std::make_shared<v0::MatMul>(query, k_scaled)->output(0);
310+
} else if (scale_after_matmul) {
311+
const auto atten = std::make_shared<v0::MatMul>(query, k_transposed)->output(0);
312+
scaled_atten = std::make_shared<v1::Multiply>(atten, scale);
313+
} else {
314+
const auto q_scaled = std::make_shared<v1::Multiply>(query, scale);
315+
scaled_atten = std::make_shared<v0::MatMul>(q_scaled, k_transposed)->output(0);
316+
}
303317
minus_inf = std::make_shared<v1::ConvertLike>(minus_inf, scaled_atten);
304318

305319
Output<Node> mask;
@@ -388,9 +402,9 @@ TEST_F(TransformationTestsF, ScaledDotProductAttentionDecomposition_PreScaledQue
388402
}
389403

390404
{
391-
// Expected: scale applied to K^T (always, unconditionally)
405+
// Expected: scale applied to K^T (Q is pre-scaled)
392406
auto ref =
393-
scaled_dot_product_attention_decomposition(query_prescaled, key, value, attention_mask, sdpa_scale, casual);
407+
scaled_dot_product_attention_decomposition(query_prescaled, key, value, attention_mask, sdpa_scale, casual, true);
394408
model_ref =
395409
std::make_shared<ov::Model>(OutputVector{ref}, ParameterVector{raw_query, key, value, attention_mask});
396410
}
@@ -422,7 +436,7 @@ TEST_F(TransformationTestsF, ScaledDotProductAttentionDecomposition_Sinks) {
422436
}
423437

424438
{
425-
auto ref = scaled_dot_product_attention_decomposition(query, key, value, attention_mask, scale, casual, sinks);
439+
auto ref = scaled_dot_product_attention_decomposition(query, key, value, attention_mask, scale, casual, false, false, sinks);
426440
model_ref = std::make_shared<ov::Model>(OutputVector{ref},
427441
ParameterVector{query, key, value, attention_mask, scale, sinks});
428442
}

0 commit comments

Comments (0)