@@ -46,6 +46,8 @@ const std::shared_ptr<ov::Node> scaled_dot_product_attention_decomposition(std::
4646 std::shared_ptr<ov::Node> attention_mask,
4747 std::shared_ptr<ov::Node> scale,
4848 bool casual,
49+ bool scale_on_k = false,
50+ bool scale_after_matmul = false,
4951 std::shared_ptr<ov::Node> sinks = nullptr);
5052
5153TEST_F (TransformationTestsF, ScaledDotProductAttentionDecompositionStaticBasic) {
@@ -132,7 +134,7 @@ TEST_F(TransformationTestsF, ScaledDotProductAttentionDecompositionStaticBroadca
132134
133135 {
134136 const auto scaled_dot_product_attention =
135- scaled_dot_product_attention_decomposition (query, key, value, attention_mask, scale, casual);
137+ scaled_dot_product_attention_decomposition(query, key, value, attention_mask, scale, casual, false, true);
136138 model_ref = std::make_shared<ov::Model>(OutputVector{scaled_dot_product_attention},
137139 ParameterVector{query, key, value, attention_mask, scale});
138140 }
@@ -196,7 +198,7 @@ TEST_F(TransformationTestsF, ScaledDotProductAttentionDecompositionDynamic) {
196198 }
197199}
198200
199- TEST_F (TransformationTestsF, ScaledDotProductAttentionDecomposition_ScalarScale_MultiplyOnK ) {
201+ TEST_F (TransformationTestsF, ScaledDotProductAttentionDecomposition_ScalarScale_MultiplyAfterMatMul ) {
200202 const PartialShape query_shape{1 , 32 , 64 };
201203 const PartialShape key_shape{1 , 32 , 64 };
202204 const PartialShape value_shape{1 , 32 , 64 };
@@ -219,12 +221,13 @@ TEST_F(TransformationTestsF, ScaledDotProductAttentionDecomposition_ScalarScale_
219221 }
220222
221223 {
222- auto ref = scaled_dot_product_attention_decomposition (query, key, value, attention_mask, scale, casual);
224+ auto ref =
225+ scaled_dot_product_attention_decomposition(query, key, value, attention_mask, scale, casual, false, true);
223226 model_ref = std::make_shared<ov::Model>(OutputVector{ref}, ParameterVector{query, key, value, attention_mask});
224227 }
225228}
226229
227- TEST_F (TransformationTestsF, ScaledDotProductAttentionDecomposition_DynamicScale_MultiplyOnK ) {
230+ TEST_F (TransformationTestsF, ScaledDotProductAttentionDecomposition_DynamicScale_MultiplyBeforeMatMul ) {
228231 const PartialShape query_shape{-1 , -1 , 64 };
229232 const PartialShape key_shape{-1 , -1 , 64 };
230233 const PartialShape value_shape{-1 , -1 , 64 };
@@ -259,6 +262,8 @@ const std::shared_ptr<ov::Node> scaled_dot_product_attention_decomposition(std::
259262 std::shared_ptr<ov::Node> attention_mask,
260263 std::shared_ptr<ov::Node> scale,
261264 bool casual,
265+ bool scale_on_k,
266+ bool scale_after_matmul,
262267 std::shared_ptr<ov::Node> sinks) {
263268 const auto q_shape = std::make_shared<v3::ShapeOf>(query, element::i32 );
264269 const auto k_shape = std::make_shared<v3::ShapeOf>(key, element::i32 );
@@ -298,8 +303,17 @@ const std::shared_ptr<ov::Node> scaled_dot_product_attention_decomposition(std::
298303 std::make_shared<v0::Concat>(OutputVector{k_dims_before_transpose, k_last_dim, k_next_dim}, 0 );
299304 const auto k_transposed = std::make_shared<v1::Transpose>(key, transpose_dims);
300305
301- const auto k_scaled = std::make_shared<v1::Multiply>(k_transposed, scale);
302- Output<Node> scaled_atten = std::make_shared<v0::MatMul>(query, k_scaled)->output (0 );
306+ Output<Node> scaled_atten;
307+ if (scale_on_k) {
308+ const auto k_scaled = std::make_shared<v1::Multiply>(k_transposed, scale);
309+ scaled_atten = std::make_shared<v0::MatMul>(query, k_scaled)->output(0);
310+ } else if (scale_after_matmul) {
311+ const auto atten = std::make_shared<v0::MatMul>(query, k_transposed)->output(0);
312+ scaled_atten = std::make_shared<v1::Multiply>(atten, scale);
313+ } else {
314+ const auto q_scaled = std::make_shared<v1::Multiply>(query, scale);
315+ scaled_atten = std::make_shared<v0::MatMul>(q_scaled, k_transposed)->output(0);
316+ }
303317 minus_inf = std::make_shared<v1::ConvertLike>(minus_inf, scaled_atten);
304318
305319 Output<Node> mask;
@@ -388,9 +402,9 @@ TEST_F(TransformationTestsF, ScaledDotProductAttentionDecomposition_PreScaledQue
388402 }
389403
390404 {
391- // Expected: scale applied to K^T (always, unconditionally )
405+ // Expected: scale applied to K^T (Q is pre-scaled)
392406 auto ref =
393- scaled_dot_product_attention_decomposition (query_prescaled, key, value, attention_mask, sdpa_scale, casual);
407+ scaled_dot_product_attention_decomposition(query_prescaled, key, value, attention_mask, sdpa_scale, casual, true);
394408 model_ref =
395409 std::make_shared<ov::Model>(OutputVector{ref}, ParameterVector{raw_query, key, value, attention_mask});
396410 }
@@ -422,7 +436,7 @@ TEST_F(TransformationTestsF, ScaledDotProductAttentionDecomposition_Sinks) {
422436 }
423437
424438 {
425- auto ref = scaled_dot_product_attention_decomposition (query, key, value, attention_mask, scale, casual, sinks);
439+ auto ref = scaled_dot_product_attention_decomposition(query, key, value, attention_mask, scale, casual, false, false, sinks);
426440 model_ref = std::make_shared<ov::Model>(OutputVector{ref},
427441 ParameterVector{query, key, value, attention_mask, scale, sinks});
428442 }