From ff62894c9ecfebae86f470995c89b0ad4a6a18ce Mon Sep 17 00:00:00 2001 From: Vladislav Golubev Date: Wed, 4 Mar 2026 19:18:09 +0100 Subject: [PATCH 1/4] [Snippets] SplitM pass removed --- .../pass/mha_parallel_wa_optimizer.hpp | 35 ++++ .../pass/mha_parallel_wa_optimizer.cpp | 162 ++++++++++++++-- .../src/pass/common_optimizations.cpp | 5 - .../tests/src/pass/mha_tokenization.cpp | 178 ------------------ .../snippets/tests/src/utils/split_dim_m.cpp | 10 +- .../transformation_pipeline.cpp | 29 +-- .../skip_tests_config.cpp | 2 - .../snippets/mha_split_dim_m.cpp | 60 ++---- .../include/subgraph_mha.hpp | 13 -- .../ov_snippets_models/src/subgraph_mha.cpp | 63 ------- 10 files changed, 211 insertions(+), 346 deletions(-) diff --git a/src/common/snippets/include/snippets/lowered/pass/mha_parallel_wa_optimizer.hpp b/src/common/snippets/include/snippets/lowered/pass/mha_parallel_wa_optimizer.hpp index 35ef18096973ea..4d5c4382c8d3db 100644 --- a/src/common/snippets/include/snippets/lowered/pass/mha_parallel_wa_optimizer.hpp +++ b/src/common/snippets/include/snippets/lowered/pass/mha_parallel_wa_optimizer.hpp @@ -40,7 +40,41 @@ class MHAParallelWAOptimizer : public lowered::pass::RuntimeOptimizer { return !m_loops_to_split.empty(); } + /** + * @brief Tries to split M dimension in "shape" in accordance to optimal parallel work amount + * @param shape Original shape + * @param optimal_parallelism_work_amount Optimal work amount + * @param batch_m_dim reference on batch's part of the split M + * @param new_m_dim reference on new M dim after the split + * @return true if split was successfull, otherwise false + */ + static bool split(const ov::Shape& shape, + size_t optimal_parallelism_work_amount, + size_t& batch_m_dim, + size_t& new_m_dim); + private: + /** + * @brief Contains splitM approaches allowing to get the batch ideally divisible by + * optimal_parallelism_work_amount + */ + static std::pair split_ideally(size_t batch_dim, + size_t m_dim, + size_t optimal_parallelism_work_amount); + /** + * @brief Splits m_dim to minimize kernel_m in order to reduce waiting time for idle threads at the last + * parallel loop iteration. + */ + static std::pair split_minimize_kernel_wa(size_t batch_dim, + size_t m_dim, + size_t optimal_parallelism_work_amount); + /** + * @brief Splits m_dim to get the batch in (optimal_parallelism_work_amount, 2 * + * optimal_parallelism_work_amount) interval + */ + static std::pair split_fallback_increase_parallel_wa(size_t batch_dim, + size_t m_dim, + size_t optimal_parallelism_work_amount); static std::unordered_set find_applicable_brgemms(const lowered::LinearIRCPtr& linear_ir, bool check_dynamic_wa = true); @@ -58,6 +92,7 @@ class MHAParallelWAOptimizer : public lowered::pass::RuntimeOptimizer { size_t m_concurrency = 0; static const size_t m_dim_M_idx; + static const size_t m_min_kernel_m; }; } // namespace ov::snippets::lowered::pass diff --git a/src/common/snippets/src/lowered/pass/mha_parallel_wa_optimizer.cpp b/src/common/snippets/src/lowered/pass/mha_parallel_wa_optimizer.cpp index f327bf375a1792..f33acf18984bdd 100644 --- a/src/common/snippets/src/lowered/pass/mha_parallel_wa_optimizer.cpp +++ b/src/common/snippets/src/lowered/pass/mha_parallel_wa_optimizer.cpp @@ -5,8 +5,10 @@ #include "snippets/lowered/pass/mha_parallel_wa_optimizer.hpp" #include +#include #include #include +#include #include #include #include @@ -22,14 +24,60 @@ #include "snippets/lowered/loop_port.hpp" #include "snippets/lowered/pass/runtime_optimizer.hpp" #include "snippets/op/brgemm.hpp" -#include "snippets/pass/split_dimension_m.hpp" #include "snippets/runtime_configurator.hpp" #include "snippets/utils/loop_utils.hpp" #include "snippets/utils/utils.hpp" namespace ov::snippets::lowered::pass { +namespace { +std::vector get_updated_order(const std::vector& order, size_t m_index) { + std::vector new_order(order.size() + 1, 0); + size_t shift_idx = 0; + for (size_t i = 0; i < order.size(); ++i) { + if (order[i] < m_index) { + new_order[i + shift_idx] = order[i]; + } else if (order[i] == m_index) { + new_order[i + shift_idx++] = order[i]; + new_order[i + shift_idx] = order[i] + 1; + } else { + new_order[i + shift_idx] = order[i] + 1; + } + } + return new_order; +} + +ov::snippets::VectorDims unsqueeze_m_dim(ov::snippets::VectorDims shape, size_t m_index) { + shape.insert(shape.begin() + m_index, 1); + return shape; +} + +ov::snippets::VectorDims reshape_m_dim(ov::snippets::VectorDims shape, + size_t m_index, + size_t batch_m_dim, + size_t new_m_dim) { + if (shape[m_index] == 1) + return unsqueeze_m_dim(std::move(shape), m_index); + shape[m_index] = new_m_dim; + shape.insert(shape.begin() + m_index, batch_m_dim); + return shape; +} + +bool is_prime_number(size_t value) { + if (value == 2lu || value == 3lu) + return true; + if (value == 1 || value % 2 == 0 || value % 3 == 0) + return false; + const auto root = std::sqrt(value) + 1; + for (size_t divisor = 5; divisor < root; divisor += 6) { + if ((value % divisor == 0) || (value % (divisor + 2) == 0)) + return false; + } + return true; +} +} // namespace const size_t MHAParallelWAOptimizer::m_dim_M_idx = 1; +const size_t MHAParallelWAOptimizer::m_min_kernel_m = 32; MHAParallelWAOptimizer::MHAParallelWAOptimizer(const lowered::LinearIRCPtr& linear_ir, const RuntimeConfigurator* configurator) @@ -56,7 +104,7 @@ MHAParallelWAOptimizer::MHAParallelWAOptimizer(const lowered::LinearIRCPtr& line : utils::get_output_dim_idx(layout, m_dim_M_idx); m_dim_M_idces[i] = dim_idx; const auto m_idx = i < configurator->get_in_num() ? dim_idx : layout.size() - 2; - m_optimized_layouts[i] = ov::snippets::pass::SplitDimensionM::get_updated_order(layout, m_idx); + m_optimized_layouts[i] = get_updated_order(layout, m_idx); } } @@ -64,10 +112,7 @@ bool MHAParallelWAOptimizer::run(const lowered::LinearIR& linear_ir) { OV_ITT_SCOPED_TASK(ov::pass::itt::domains::SnippetsTransform, "Snippets::MHAParallelWAOptimizer") const auto& config = m_configurator->get_config(); size_t new_batch_dim = 0, new_kernel_dim = 0; - if (!ov::snippets::pass::SplitDimensionM::split(config->master_shape, - m_concurrency, - new_batch_dim, - new_kernel_dim)) { + if (!split(config->master_shape, m_concurrency, new_batch_dim, new_kernel_dim)) { return false; } auto& master_shape = config->master_shape; @@ -98,13 +143,9 @@ bool MHAParallelWAOptimizer::run(const lowered::LinearIR& linear_ir) { } for (size_t i = 0; i < m_configurator->get_io_num(); ++i) { - config->io_shapes[i] = - m_unsqueezed_params.count(i) - ? ov::snippets::pass::SplitDimensionM::unsqueeze_m_dim(config->io_shapes[i], m_dim_M_idces[i]) - : ov::snippets::pass::SplitDimensionM::reshape_m_dim(config->io_shapes[i], - m_dim_M_idces[i], - new_batch_dim, - new_kernel_dim); + config->io_shapes[i] = m_unsqueezed_params.count(i) + ? unsqueeze_m_dim(config->io_shapes[i], m_dim_M_idces[i]) + : reshape_m_dim(config->io_shapes[i], m_dim_M_idces[i], new_batch_dim, new_kernel_dim); } config->io_layouts = m_optimized_layouts; return true; @@ -211,4 +252,99 @@ std::vector MHAParallelWAOptimizer::find_loops_to_ return loops_to_split; } +bool MHAParallelWAOptimizer::split(const ov::Shape& shape, + size_t optimal_parallelism_work_amount, + size_t& batch_m_dim, + size_t& new_m_dim) { + const auto batch_dim = std::accumulate(shape.rbegin() + 2, shape.rend(), size_t(1), std::multiplies()); + const auto m_dim = *(shape.rbegin() + 1); + if (is_prime_number(m_dim)) + return false; + + // We skip optimization if the current batch is optimal for concurrency + if (batch_dim % optimal_parallelism_work_amount == 0) + return false; + + auto split_is_done = [&batch_m_dim]() { + return batch_m_dim != 1; + }; + + std::tie(batch_m_dim, new_m_dim) = split_ideally(batch_dim, m_dim, optimal_parallelism_work_amount); + if (split_is_done()) + return true; + + std::tie(batch_m_dim, new_m_dim) = split_minimize_kernel_wa(batch_dim, m_dim, optimal_parallelism_work_amount); + if (split_is_done()) + return true; + // If all the previous heuristics failed, fallback heuristic is used, which reflects the old splitting behavior + if (batch_dim < optimal_parallelism_work_amount) + std::tie(batch_m_dim, new_m_dim) = + split_fallback_increase_parallel_wa(batch_dim, m_dim, optimal_parallelism_work_amount); + return split_is_done(); +} + +std::pair MHAParallelWAOptimizer::split_ideally(size_t batch_dim, + size_t m_dim, + size_t optimal_parallelism_work_amount) { + // Ideal case #1: M can be split on the parts one of which complements the batch dimension to the optimal parallel + // work amount In this case, each thread will execute the Snippets kernel once + const size_t lower_bound = optimal_parallelism_work_amount / batch_dim; + if (lower_bound * batch_dim == optimal_parallelism_work_amount && m_dim % lower_bound == 0) + return std::make_pair(lower_bound, m_dim / lower_bound); + + // Ideal case #2: M is divisible by optimal parallel work amount, and the new_m_dim is big enough + // In this case, each thread will execute the Snippets kernel 'batch_dim' times + if (m_dim % optimal_parallelism_work_amount == 0) { + const auto new_m_dim = m_dim / optimal_parallelism_work_amount; + if (new_m_dim >= m_min_kernel_m) + return std::make_pair(optimal_parallelism_work_amount, new_m_dim); + } + + return std::make_pair(1, m_dim); +} + +std::pair MHAParallelWAOptimizer::split_fallback_increase_parallel_wa( + size_t batch_dim, + size_t m_dim, + size_t optimal_parallelism_work_amount) { + std::pair splited = {1, m_dim}; + const size_t upper_bound = utils::div_up(2 * optimal_parallelism_work_amount, batch_dim); + for (size_t divisor_0 = upper_bound - 1; divisor_0 > 1; divisor_0--) { + size_t divisor_1 = m_dim / divisor_0; + if (divisor_1 * divisor_0 == m_dim) + return divisor_0 * batch_dim >= optimal_parallelism_work_amount ? std::make_pair(divisor_0, divisor_1) + : splited; + } + return splited; +} + +std::pair MHAParallelWAOptimizer::split_minimize_kernel_wa(size_t batch_dim, + size_t m_dim, + size_t optimal_parallelism_work_amount) { + // This heuristic minimizes 'm_kernel' (=> maximizes 'm_batch') with a limitation that 'm_kernel >= min_kernel_m'. + // In other words, it tries to find 'm_kernel' bigger than 'm_min_kernel_m' and at the same time as close as possible + // to this value. + std::pair best_result = {1, m_dim}; + for (size_t divisor = 2; divisor < std::sqrt(m_dim); ++divisor) { + if (m_dim % divisor != 0) + continue; + // If divisor is more than 'm_min_kernel_m', divisor becomes 'm_kernel', + // guaranteeing the most optimal implementation from 'm_kernel' minimization perspective. + if (divisor >= m_min_kernel_m) + return std::make_pair(m_dim / divisor, divisor); + + // If divisor is less than 'm_min_kernel_m', divisor becomes m_batch. + // However, it is not guaranteed that the current 'm_kernel = m_dim / divisor' is minimized, as one of the next + // divisors can be more optimal. So in this case the best result is remembered + const size_t m_kernel = m_dim / divisor; + if (m_kernel >= m_min_kernel_m) { + best_result.first = divisor; + best_result.second = m_kernel; + } + } + if (best_result.first * batch_dim >= optimal_parallelism_work_amount) + return best_result; + return std::make_pair(1, m_dim); +} + } // namespace ov::snippets::lowered::pass diff --git a/src/common/snippets/src/pass/common_optimizations.cpp b/src/common/snippets/src/pass/common_optimizations.cpp index 798a19c8ae19fc..73831f907f65d4 100644 --- a/src/common/snippets/src/pass/common_optimizations.cpp +++ b/src/common/snippets/src/pass/common_optimizations.cpp @@ -18,7 +18,6 @@ #include "snippets/pass/extract_unsupported_transposes.hpp" #include "snippets/pass/fq_decomposition.hpp" #include "snippets/pass/softmax_reshape_elimination.hpp" -#include "snippets/pass/split_dimension_m.hpp" #include "snippets/pass/subgraph_manager.hpp" #include "snippets/pass/transform_convert.hpp" #include "snippets/pass/validate.hpp" @@ -60,10 +59,6 @@ CommonOptimizations::CommonOptimizations(const CommonOptimizations::Config& conf ov::snippets::pass::ExtractUnsupportedTransposes, is_domain_sensitive, config.get_transpose_support_callback()); - REGISTER_SNIPPETS_PASS(subgraph_manager, - ov::snippets::pass::SplitDimensionM, - is_domain_sensitive && config.get_split_m_dimension(), - config.get_concurrency()); subgraph_manager.run_passes(subgraph); // Validate the body after all common optimizations diff --git a/src/common/snippets/tests/src/pass/mha_tokenization.cpp b/src/common/snippets/tests/src/pass/mha_tokenization.cpp index e6d85987416434..3fed3cecfe4957 100644 --- a/src/common/snippets/tests/src/pass/mha_tokenization.cpp +++ b/src/common/snippets/tests/src/pass/mha_tokenization.cpp @@ -286,184 +286,6 @@ TEST_F(TokenizeMHASnippetsTests, smoke_Snippets_MHA_Dynamic_Transpose_fusion) { execute_and_validate_function(*this, f); } -TEST_F(TokenizeMHASnippetsTests, smoke_Snippets_MHA3D_SplitM) { - const auto& f = MHASplitMFunction( - std::vector{{128, 12, 64}, {128, 12, 64}, {12, 128, 128}, {128, 12, 64}}, - std::vector({ov::element::f32, ov::element::f32, ov::element::f32, ov::element::f32}), - std::vector{{2, 64, 12, 64}, {12, 1, 64, 128}, {12, 2, 64, 128}, {1, 128, 12, 64}, {128, 12, 64}}, - false); - common_config = get_default_common_optimizations_config(); - common_config.set_concurrency(24); - execute_and_validate_function(*this, f); -} - -TEST_F(TokenizeMHASnippetsTests, smoke_Snippets_MHA3D_SplitM_Const_B) { - const auto& f = MHASplitMFunction( - std::vector{{128, 12, 64}, {128, 12, 64}, {12, 128, 128}, {128, 12, 64}}, - std::vector({ov::element::f32, ov::element::f32, ov::element::f32, ov::element::f32}), - std::vector{{2, 64, 12, 64}, {12, 1, 64, 128}, {12, 2, 64, 128}, {1, 128, 12, 64}, {128, 12, 64}}, - false); - common_config = get_default_common_optimizations_config(); - common_config.set_concurrency(24); - execute_and_validate_function(*this, f); -} - -TEST_F(TokenizeMHASnippetsTests, smoke_Snippets_MHA3D_SplitM_withMul) { - const auto& f = MHASplitMFunction( - std::vector{{128, 12, 64}, {128, 12, 64}, {12, 128, 128}, {128, 12, 64}}, - std::vector({ov::element::f32, ov::element::f32, ov::element::f32, ov::element::f32}), - std::vector{{4, 32, 12, 64}, {12, 1, 64, 128}, {12, 4, 32, 128}, {1, 128, 12, 64}, {128, 12, 64}}, - true); - common_config = get_default_common_optimizations_config(); - common_config.set_concurrency(16); - execute_and_validate_function(*this, f); -} - -TEST_F(TokenizeMHASnippetsTests, smoke_Snippets_MHA3D_SplitM_withMul_Const_B) { - const auto& f = MHASplitMFunction( - std::vector{{128, 12, 64}, {128, 12, 64}, {12, 128, 128}, {128, 12, 64}}, - std::vector({ov::element::f32, ov::element::f32, ov::element::f32, ov::element::f32}), - std::vector{{4, 32, 12, 64}, {12, 1, 64, 128}, {12, 4, 32, 128}, {1, 128, 12, 64}, {128, 12, 64}}, - true); - common_config = get_default_common_optimizations_config(); - common_config.set_concurrency(16); - execute_and_validate_function(*this, f); -} - -TEST_F(TokenizeMHASnippetsTests, smoke_Snippets_MHA4D_SplitM) { - const auto& f = MHASplitMFunction( - std::vector{{1, 384, 16, 64}, {1, 384, 16, 64}, {1, 1, 1, 384}, {1, 384, 16, 64}}, - std::vector({ov::element::f32, ov::element::f32, ov::element::f32, ov::element::f32}), - std::vector{{1, 12, 32, 16, 64}, - {1, 16, 1, 64, 384}, - {1, 1, 1, 1, 384}, - {1, 1, 384, 16, 64}, - {1, 384, 16, 64}}, - false); - common_config = get_default_common_optimizations_config(); - common_config.set_concurrency(60); - execute_and_validate_function(*this, f); -} - -TEST_F(TokenizeMHASnippetsTests, smoke_Snippets_MHA4D_SplitM_Const_B) { - const auto& f = MHASplitMFunction( - std::vector{{1, 384, 16, 64}, {1, 384, 16, 64}, {1, 1, 1, 384}, {1, 384, 16, 64}}, - std::vector({ov::element::f32, ov::element::f32, ov::element::f32, ov::element::f32}), - std::vector{{1, 12, 32, 16, 64}, - {1, 16, 1, 64, 384}, - {1, 1, 1, 1, 384}, - {1, 1, 384, 16, 64}, - {1, 384, 16, 64}}, - false); - common_config = get_default_common_optimizations_config(); - common_config.set_concurrency(60); - execute_and_validate_function(*this, f); -} - -TEST_F(TokenizeMHASnippetsTests, smoke_Snippets_MHA4D_SplitM_withMul) { - const auto& f = MHASplitMFunction( - std::vector{{1, 384, 16, 64}, {1, 384, 16, 64}, {1, 1, 1, 384}, {1, 384, 16, 64}}, - std::vector({ov::element::f32, ov::element::f32, ov::element::f32, ov::element::f32}), - std::vector{{1, 12, 32, 16, 64}, - {1, 16, 1, 64, 384}, - {1, 1, 1, 1, 384}, - {1, 1, 384, 16, 64}, - {1, 384, 16, 64}}, - true); - common_config = get_default_common_optimizations_config(); - common_config.set_concurrency(60); - execute_and_validate_function(*this, f); -} - -TEST_F(TokenizeMHASnippetsTests, smoke_Snippets_MHA4D_SplitM_withMul_Const_B) { - const auto& f = MHASplitMFunction( - std::vector{{1, 384, 16, 64}, {1, 384, 16, 64}, {1, 1, 1, 384}, {1, 384, 16, 64}}, - std::vector({ov::element::f32, ov::element::f32, ov::element::f32, ov::element::f32}), - std::vector{{1, 12, 32, 16, 64}, - {1, 16, 1, 64, 384}, - {1, 1, 1, 1, 384}, - {1, 1, 384, 16, 64}, - {1, 384, 16, 64}}, - true); - common_config = get_default_common_optimizations_config(); - common_config.set_concurrency(60); - execute_and_validate_function(*this, f); -} - -TEST_F(TokenizeMHASnippetsTests, smoke_Snippets_MHAWOTranspose_SplitM) { - const auto& f = MHAWOTransposeSplitMFunction( - std::vector{{10, 9216, 128}, {10, 128, 9216}, {10, 9216, 128}}, - std::vector({ov::element::f32, ov::element::f32, ov::element::f32}), - std::vector{{10, 18, 512, 128}, {10, 1, 128, 9216}, {10, 1, 9216, 128}, {10, 9216, 128}}); - common_config = get_default_common_optimizations_config(); - common_config.set_concurrency(18); - execute_and_validate_function(*this, f); -} - -TEST_F(TokenizeMHASnippetsTests, smoke_Snippets_MHAWOTranspose_SplitM_Const_B) { - const auto& f = MHAWOTransposeSplitMFunction( - std::vector{{10, 9216, 128}, {10, 128, 9216}, {10, 9216, 128}}, - std::vector({ov::element::f32, ov::element::f32, ov::element::f32}), - std::vector{{10, 18, 512, 128}, {10, 1, 128, 9216}, {10, 1, 9216, 128}, {10, 9216, 128}}); - common_config = get_default_common_optimizations_config(); - common_config.set_concurrency(18); - execute_and_validate_function(*this, f); -} - -TEST_F(TokenizeMHASnippetsTests, smoke_Snippets_MHA_SplitM_AlmostAllThreads) { - const auto& f = MHAWOTransposeSplitMFunction( - std::vector{{5, 30, 32}, {5, 32, 30}, {5, 30, 32}}, - std::vector({ov::element::f32, ov::element::f32, ov::element::f32}), - std::vector{{5, 10, 3, 32}, {5, 1, 32, 30}, {5, 1, 30, 32}, {5, 30, 32}}); - common_config = get_default_common_optimizations_config(); - common_config.set_concurrency(32); - execute_and_validate_function(*this, f); -} - -TEST_F(TokenizeMHASnippetsTests, smoke_Snippets_MHA_SplitM_AlmostAllThreads_Const_B) { - const auto& f = MHAWOTransposeSplitMFunction( - std::vector{{5, 30, 32}, {5, 32, 30}, {5, 30, 32}}, - std::vector({ov::element::f32, ov::element::f32, ov::element::f32}), - std::vector{{5, 10, 3, 32}, {5, 1, 32, 30}, {5, 1, 30, 32}, {5, 30, 32}}); - common_config = get_default_common_optimizations_config(); - common_config.set_concurrency(32); - execute_and_validate_function(*this, f); -} - -TEST_F(TokenizeMHASnippetsTests, smoke_Snippets_MHA_4D_SplitM_DynamicParameter) { - const auto& f = MHAFunction( - std::vector{{1, 128, 16, 64}, {1, 128, 16, 64}, {1, 16, 128, -1}, {1, 128, 16, 64}}, - std::vector({ov::element::f32, ov::element::f32, ov::element::f32, ov::element::f32}), - false, - false); - common_config = get_default_common_optimizations_config(); - common_config.set_concurrency(32); - execute_and_validate_function(*this, f); -} - -TEST_F(TokenizeMHASnippetsTests, smoke_Snippets_MHASelect_SplitM) { - const auto& f = MHASelectSplitMFunction( - std::vector{{8, 512, 18}, {8, 18, 64}, {1, 512, 64}, {1, 1, 64}, {8, 64, 512}}, - std::vector{{8, 2, 256, 18}, - {8, 1, 18, 64}, - {1, 2, 256, 64}, - {1, 1, 1, 64}, - {8, 1, 64, 512}, - {8, 512, 512}}); - common_config = get_default_common_optimizations_config(); - common_config.set_concurrency(16); - execute_and_validate_function(*this, f); -} - -TEST_F(TokenizeMHASnippetsTests, smoke_Snippets_MHASelect_SplitM_ScalarParams) { - const auto& f = MHASelectSplitMFunction( - std::vector{{8, 512, 18}, {8, 18, 64}, {1}, {64}, {8, 64, 512}}, - std::vector{{8, 2, 256, 18}, {8, 1, 18, 64}, {}, {}, {8, 1, 64, 512}, {8, 512, 512}}); - common_config = get_default_common_optimizations_config(); - common_config.set_concurrency(16); - execute_and_validate_function(*this, f); -} - TEST_F(TokenizeMHASnippetsTests, smoke_Snippets_MHA_Reshape_extraction) { const auto& f = MHAWithExtractedReshapeFunction(std::vector{{400, 196, 80}, {400, 80, 196}, diff --git a/src/common/snippets/tests/src/utils/split_dim_m.cpp b/src/common/snippets/tests/src/utils/split_dim_m.cpp index c65423a509181e..33bc81361734da 100644 --- a/src/common/snippets/tests/src/utils/split_dim_m.cpp +++ b/src/common/snippets/tests/src/utils/split_dim_m.cpp @@ -5,7 +5,7 @@ #include "utils/split_dim_m.hpp" #include "common_test_utils/ov_test_utils.hpp" -#include "snippets/pass/split_dimension_m.hpp" +#include "snippets/lowered/pass/mha_parallel_wa_optimizer.hpp" #include "snippets/utils/utils.hpp" namespace ov { @@ -42,10 +42,10 @@ TEST_P(SplitDimensionMTest, SplitDimensionM) { shape = {input.cur_m, last_dim}; } size_t batch_m_dim, new_m_dim; - bool result = ov::snippets::pass::SplitDimensionM::split(shape, - input.concurrency, - batch_m_dim, - new_m_dim); + bool result = ov::snippets::lowered::pass::MHAParallelWAOptimizer::split(shape, + input.concurrency, + batch_m_dim, + new_m_dim); ASSERT_EQ(result, reference.is_split); if (result) { diff --git a/src/plugins/intel_cpu/src/transformations/transformation_pipeline.cpp b/src/plugins/intel_cpu/src/transformations/transformation_pipeline.cpp index 9607f7205defa8..b79ca5880222ef 100644 --- a/src/plugins/intel_cpu/src/transformations/transformation_pipeline.cpp +++ b/src/plugins/intel_cpu/src/transformations/transformation_pipeline.cpp @@ -198,7 +198,6 @@ # include "openvino/op/softmax.hpp" # include "openvino/op/subtract.hpp" # include "snippets/pass/common_optimizations.hpp" -# include "snippets/pass/split_dimension_m.hpp" # include "snippets/utils/tokenization_utils.hpp" # include "transformations/common_optimizations/rms_fusion.hpp" # include "transformations/cpu_opset/common/op/sdpa.hpp" @@ -1378,23 +1377,6 @@ void Transformations::MainSnippets() { (is_fp16 && ov::intel_cpu::brgemm_utils::is_fp16_supported()) || (is_int8 && ov::intel_cpu::brgemm_utils::is_i8_supported()); }; - auto is_unsupported_parallel_work_amount = [&](const std::shared_ptr& n, - const ov::PartialShape& shape) { - // SplitDimensionM transformation doesn't support dynamic shapes, so M dim is split in runtime configurator - if (shape.is_dynamic()) { - return false; - } - auto parallel_work_amount = ov::Dimension(1); - if (shape.size() > 2) { - parallel_work_amount = - std::accumulate(shape.rbegin() + 2, shape.rend(), parallel_work_amount, std::multiplies<>()); - } - // Ticket 160154: enable tokenization for MHA with insufficient parallel work amount - const auto is_unsupported_parallel_work_amount = - static_cast(parallel_work_amount.get_length()) < common_optimizations_config.get_concurrency() && - !SplitDimensionM::can_be_optimized(n, common_optimizations_config.get_concurrency()); - return is_unsupported_parallel_work_amount; - }; #endif // OPENVINO_ARCH_X86_64 auto is_supported_op = []([[maybe_unused]] const std::shared_ptr& n) -> bool { @@ -1482,11 +1464,7 @@ void Transformations::MainSnippets() { while (!ov::is_type(child)) { child = child->get_output_target_inputs(0).begin()->get_node()->shared_from_this(); } - if (!is_supported_matmul(child)) - return true; - - const auto& pshape = child->get_input_partial_shape(0); - return is_unsupported_parallel_work_amount(n, pshape); + return !is_supported_matmul(child); }, TokenizeMHASnippets); CPU_SET_CALLBACK_X64( @@ -1494,8 +1472,6 @@ void Transformations::MainSnippets() { [&](const std::shared_ptr& n) -> bool { if (!is_supported_matmul(n)) return true; - if (is_unsupported_parallel_work_amount(n, n->get_output_partial_shape(0))) - return true; // We've only tested MLP sequence tokenization on small model shapes // So we limit tokenization to sequences with small shapes to avoid unexpected behavior @@ -1514,8 +1490,7 @@ void Transformations::MainSnippets() { CPU_SET_CALLBACK_X64( snippetsManager, [&](const std::shared_ptr& n) -> bool { - return !is_supported_matmul(n) || - is_unsupported_parallel_work_amount(n, n->get_output_partial_shape(0)); + return !is_supported_matmul(n); }, ExtractReshapesFromMHA); } diff --git a/src/plugins/intel_cpu/tests/functional/shared_tests_instances/skip_tests_config.cpp b/src/plugins/intel_cpu/tests/functional/shared_tests_instances/skip_tests_config.cpp index b1db74d4b463e1..b8eb7f2e34ec91 100644 --- a/src/plugins/intel_cpu/tests/functional/shared_tests_instances/skip_tests_config.cpp +++ b/src/plugins/intel_cpu/tests/functional/shared_tests_instances/skip_tests_config.cpp @@ -132,8 +132,6 @@ const std::vector& disabled_test_patterns() { std::regex(R"(.*FakeConvertLayerTest.*dataPrecision=bf16.*)"), // Need to generate sequence exactly in the i64 data type. Enable in scope of i64 enabling. std::regex(R"(.*RandomUniformLayerTestCPU.*OutPrc=i64.*)"), - // Issue: 123815 (Tests are sensintive to available thread count on testing machines) - std::regex(R"(.*smoke_Snippets_MHA_.?D_SplitDimensionM_static.*)"), // Issue: 126095 std::regex(R"(^smoke_Multinomial(?:Static|Dynamic)+(?:Log)*.*seed_g=0_seed_o=0.*device=CPU.*)"), // Issue: 129931 diff --git a/src/plugins/intel_cpu/tests/functional/shared_tests_instances/snippets/mha_split_dim_m.cpp b/src/plugins/intel_cpu/tests/functional/shared_tests_instances/snippets/mha_split_dim_m.cpp index f26644548cba04..ef032328ecba00 100644 --- a/src/plugins/intel_cpu/tests/functional/shared_tests_instances/snippets/mha_split_dim_m.cpp +++ b/src/plugins/intel_cpu/tests/functional/shared_tests_instances/snippets/mha_split_dim_m.cpp @@ -9,45 +9,19 @@ namespace ov { namespace test { namespace snippets { - namespace { -static ov::AnyMap enable_callback(size_t num_threads) { - return ov::AnyMap({ov::intel_cpu::snippets_mode(ov::intel_cpu::SnippetsMode::ENABLE), - ov::inference_num_threads(num_threads)}); -} - static ov::AnyMap set_num_threads(size_t num_threads) { return ov::AnyMap({ov::inference_num_threads(num_threads)}); } -INSTANTIATE_TEST_SUITE_P( - smoke_Snippets_MHA_4D_SplitDimensionM_static, - MHAWithThreadCount, - ::testing::Combine(::testing::ValuesIn(SNIPPETS_TESTS_STATIC_SHAPES({{1, 128, 2, 64}, {1, 128, 2, 64}, {1, 1, 1, 1}, {1, 128, 2, 64}})), - ::testing::ValuesIn(precision_f32(4)), - ::testing::Values(ov::element::f32), - ::testing::Values(true), - ::testing::Values(7), // Subgraph + 4 Reshapes, Transpose1 on inputs and 1 Reshape on output - ::testing::Values(2), - ::testing::Values(ov::test::utils::DEVICE_CPU), - ::testing::Values(enable_callback(4))), // 4 Threads - MHAWithThreadCount::getTestCaseName); - -INSTANTIATE_TEST_SUITE_P( - smoke_Snippets_MHA_3D_SplitDimensionM_static, - MHAWithThreadCount, - ::testing::Combine(::testing::ValuesIn(SNIPPETS_TESTS_STATIC_SHAPES({{384, 2, 64}, {384, 2, 64}, {1, 384, 384}, {384, 2, 64}})), - ::testing::ValuesIn(precision_f32(4)), - ::testing::Values(ov::element::f32), - ::testing::Values(true), - ::testing::Values(10), // Subgraph + 4 Reshapes on inputs and 1 Reshape on output + 4 Transposes - ::testing::Values(1), // MHA - ::testing::Values(ov::test::utils::DEVICE_CPU), - ::testing::Values(enable_callback(4))), // 4 Threads - MHAWithThreadCount::getTestCaseName); - -std::vector> splitm_dynamic_shapes_4d = { +std::vector> splitm_shapes_4d = { + { + {PartialShape{}, {{1, 128, 2, 64}}}, + {PartialShape{}, {{1, 128, 2, 64}}}, + {PartialShape{}, {{1, 1, 1, 1}}}, + {PartialShape{}, {{1, 128, 2, 64}}}, + }, { {PartialShape{-1, -1, -1, -1}, {{1, 128, 2, 64}, {1, 17, 2, 64}, {1, 128, 2, 64}}}, {PartialShape{-1, -1, -1, -1}, {{1, 128, 2, 64}, {1, 17, 2, 64}, {1, 128, 2, 64}}}, @@ -77,19 +51,25 @@ std::vector> splitm_dynamic_shapes_4d = { static constexpr size_t expected_nodes_mha_splitm_4d_dyn = 2; INSTANTIATE_TEST_SUITE_P( - smoke_Snippets_MHA_4D_SplitDimensionM_dynamic, + smoke_Snippets_MHA_4D_SplitDimensionM, MHAWithThreadCount, - ::testing::Combine(::testing::ValuesIn(splitm_dynamic_shapes_4d), + ::testing::Combine(::testing::ValuesIn(splitm_shapes_4d), ::testing::ValuesIn(precision_f32(4)), ::testing::Values(ov::element::f32), ::testing::Values(false), ::testing::Values(expected_nodes_mha_splitm_4d_dyn), ::testing::Values(2), // Transpose1 + MHA ::testing::Values(ov::test::utils::DEVICE_CPU), - ::testing::Values(set_num_threads(4))), // 4 Threads + ::testing::Values(set_num_threads(4))), MHAWithThreadCount::getTestCaseName); -std::vector> splitm_dynamic_shapes_3d = { +std::vector> splitm_shapes_3d = { + { + {PartialShape{}, {{384, 2, 64}}}, + {PartialShape{}, {{384, 2, 64}}}, + {PartialShape{}, {{1, 384, 384}}}, + {PartialShape{}, {{384, 2, 64}}}, + }, { {PartialShape{-1, -1, -1}, {{128, 2, 64}, {17, 2, 64}, {128, 2, 64}}}, {PartialShape{-1, -1, -1}, {{128, 2, 64}, {17, 2, 64}, {128, 2, 64}}}, @@ -105,16 +85,16 @@ std::vector> splitm_dynamic_shapes_3d = { }; INSTANTIATE_TEST_SUITE_P( - smoke_Snippets_MHA_3D_SplitDimensionM_dynamic, + smoke_Snippets_MHA_3D_SplitDimensionM, MHAWithThreadCount, - ::testing::Combine(::testing::ValuesIn(splitm_dynamic_shapes_3d), + ::testing::Combine(::testing::ValuesIn(splitm_shapes_3d), ::testing::ValuesIn(precision_f32(4)), ::testing::Values(ov::element::f32), ::testing::Values(false), ::testing::Values(5), // Subgraph + 4 Transpose ::testing::Values(2), // MHA + one of the transposes is executed via Subgraph (because callback is disabled) ::testing::Values(ov::test::utils::DEVICE_CPU), - ::testing::Values(set_num_threads(4))), // 4 Threads + ::testing::Values(set_num_threads(4))), MHAWithThreadCount::getTestCaseName); } // namespace diff --git a/src/tests/ov_helpers/ov_snippets_models/include/subgraph_mha.hpp b/src/tests/ov_helpers/ov_snippets_models/include/subgraph_mha.hpp index aabcb532f364f0..5f674828edc8e8 100644 --- a/src/tests/ov_helpers/ov_snippets_models/include/subgraph_mha.hpp +++ b/src/tests/ov_helpers/ov_snippets_models/include/subgraph_mha.hpp @@ -107,19 +107,6 @@ class MHA2DFunction : public SnippetsFunctionBase { const std::vector precisions; }; -class MHASplitMFunction : public MHAFunction { -public: - explicit MHASplitMFunction(const std::vector& inputShapes, const std::vector& precisions, - const std::vector& reshapes, bool with_mul = true) - : MHAFunction(inputShapes, precisions, with_mul), reshapes(reshapes) { - OPENVINO_ASSERT(reshapes.size() == 5, "Got invalid number of Reshape shapes"); - } -protected: - std::shared_ptr initReference() const override; - - std::vector reshapes; -}; - /* Graph: * Transpose1[0,2,3,1] Parameter * \ / diff --git a/src/tests/ov_helpers/ov_snippets_models/src/subgraph_mha.cpp b/src/tests/ov_helpers/ov_snippets_models/src/subgraph_mha.cpp index 3dd7492720f648..5ef33b9f5f0cd4 100644 --- a/src/tests/ov_helpers/ov_snippets_models/src/subgraph_mha.cpp +++ b/src/tests/ov_helpers/ov_snippets_models/src/subgraph_mha.cpp @@ -291,69 +291,6 @@ std::shared_ptr MHA2DFunction::initReference() const { return std::make_shared(OutputVector{subgraph}, ngraphParams); } -std::shared_ptr MHASplitMFunction::initReference() const { - auto data0 = std::make_shared(precisions[0], input_shapes[0]); - auto data1 = std::make_shared(precisions[1], input_shapes[1]); - auto data2 = std::make_shared(precisions[2], input_shapes[2]); - auto data3 = std::make_shared(precisions[3], input_shapes[3]); - ov::ParameterVector ngraphParams = {data0, data1, data2, data3}; - - const auto rank_before = input_shapes[1].size(); - const auto transpose1Const = ov::op::v0::Constant::create(ov::element::i64, ov::Shape{rank_before}, get_decomposed_order(rank_before)); - const auto transpose1 = std::make_shared(data1, transpose1Const); - - std::shared_ptr subgraph_parent1 = transpose1; - if (with_mul) { - ov::Shape shape(rank_before, 1); - if (transpose1->get_output_partial_shape(0).is_static()) { - shape[rank_before - 3] = transpose1->get_output_shape(0)[rank_before - 3]; - } - const auto mulConst = ov::test::utils::make_constant(precisions[1], shape); - subgraph_parent1 = std::make_shared(transpose1, mulConst); - } - - auto make_reshape = [](const std::shared_ptr& node, const ov::Shape& new_shape) { - auto shape_const = ov::op::v0::Constant::create(ov::element::i32, {new_shape.size()}, new_shape); - return std::make_shared(node, shape_const, true); - }; - - auto reshape0 = make_reshape(data0, reshapes[0]); - auto reshape1 = make_reshape(subgraph_parent1, reshapes[1]); - auto reshape2 = make_reshape(data2, reshapes[2]); - auto reshape3 = make_reshape(data3, reshapes[3]); - OutputVector subgraph_inputs = {reshape0, reshape1, reshape2, reshape3}; - - auto transpose0Param = std::make_shared(precisions[0], reshape0->get_shape()); - auto brgemm1Param = std::make_shared(precisions[1], reshape1->get_shape()); - auto addParam = std::make_shared(precisions[2], reshape2->get_shape()); - auto transpose2Param = std::make_shared(precisions[3], reshape3->get_shape()); - ov::ParameterVector subgraph_params = {transpose0Param, brgemm1Param, addParam, transpose2Param}; - - const auto rank = input_shapes[0].size() + 1; - const auto transpose0Const = ov::op::v0::Constant::create(ov::element::i64, ov::Shape{rank}, get_fusion_order_after_split_m(rank, true)); - const auto transpose2Const = ov::op::v0::Constant::create(ov::element::i64, ov::Shape{rank}, get_fusion_order_after_split_m(rank, true)); - const auto transpose3Const = ov::op::v0::Constant::create(ov::element::i64, ov::Shape{rank}, get_fusion_order_after_split_m(rank, false)); - - const auto transpose0 = std::make_shared(transpose0Param, transpose0Const); - - const auto matMul0 = std::make_shared(transpose0, brgemm1Param); - const auto add = std::make_shared(matMul0, addParam); - const auto softMax = std::make_shared(add, rank - 1); - const auto transpose2 = std::make_shared(transpose2Param, transpose2Const); - const auto matMul1 = std::make_shared(softMax, transpose2); - const auto transpose3 = std::make_shared(matMul1, transpose3Const); - - const auto snippets_result = std::make_shared(transpose3); - - const auto subgraph = std::make_shared(subgraph_inputs, - std::make_shared(ov::OutputVector{snippets_result}, - subgraph_params)); - - auto reshape4 = make_reshape(subgraph, reshapes[4]); - ov::ResultVector results{std::make_shared(reshape4)}; - return std::make_shared(results, ngraphParams, "mha"); -} - std::shared_ptr MHAWithDynamicMulFunction::initOriginal() const { auto transpose0Param = std::make_shared(precisions[0], input_shapes[0]); auto transpose1Param = std::make_shared(precisions[1], input_shapes[1]); From d8128b026ba85b97dc78d54b093dd0cd9cab0bfb Mon Sep 17 00:00:00 2001 From: Vladislav Golubev Date: Thu, 5 Mar 2026 13:17:29 +0100 Subject: [PATCH 2/4] tidy+format --- .../pass/mha_parallel_wa_optimizer.hpp | 6 +- .../pass/mha_parallel_wa_optimizer.cpp | 62 ++++++++++++------- 2 files changed, 44 insertions(+), 24 deletions(-) diff --git a/src/common/snippets/include/snippets/lowered/pass/mha_parallel_wa_optimizer.hpp b/src/common/snippets/include/snippets/lowered/pass/mha_parallel_wa_optimizer.hpp index 4d5c4382c8d3db..f1eb567d4c7ebc 100644 --- a/src/common/snippets/include/snippets/lowered/pass/mha_parallel_wa_optimizer.hpp +++ b/src/common/snippets/include/snippets/lowered/pass/mha_parallel_wa_optimizer.hpp @@ -49,9 +49,9 @@ class MHAParallelWAOptimizer : public lowered::pass::RuntimeOptimizer { * @return true if split was successfull, otherwise false */ static bool split(const ov::Shape& shape, - size_t optimal_parallelism_work_amount, - size_t& batch_m_dim, - size_t& new_m_dim); + size_t optimal_parallelism_work_amount, + size_t& batch_m_dim, + size_t& new_m_dim); private: /** diff --git a/src/common/snippets/src/lowered/pass/mha_parallel_wa_optimizer.cpp b/src/common/snippets/src/lowered/pass/mha_parallel_wa_optimizer.cpp index f33acf18984bdd..7e36cd42af18b4 100644 --- a/src/common/snippets/src/lowered/pass/mha_parallel_wa_optimizer.cpp +++ b/src/common/snippets/src/lowered/pass/mha_parallel_wa_optimizer.cpp @@ -7,10 +7,13 @@ #include #include #include +#include #include #include #include +#include #include +#include #include #include "openvino/core/except.hpp" @@ -55,22 +58,26 @@ ov::snippets::VectorDims reshape_m_dim(ov::snippets::VectorDims shape, size_t m_index, size_t batch_m_dim, size_t new_m_dim) { - if (shape[m_index] == 1) + if (shape[m_index] == 1) { return unsqueeze_m_dim(std::move(shape), m_index); + } shape[m_index] = new_m_dim; shape.insert(shape.begin() + m_index, batch_m_dim); return shape; } bool is_prime_number(size_t value) { - if (value == 2lu || value == 3lu) + if (value == 2LU || value == 3LU) { return true; - if (value == 1 || value % 2 == 0 || value % 3 == 0) + } + if (value == 1 || value % 2 == 0 || value % 3 == 0) { return false; + } const auto root = std::sqrt(value) + 1; for (size_t divisor = 5; divisor < root; divisor += 6) { - if ((value % divisor == 0) || (value % (divisor + 2) == 0)) + if ((value % divisor == 0) || (value % (divisor + 2) == 0)) { return false; + } } return true; } @@ -143,9 +150,10 @@ bool MHAParallelWAOptimizer::run(const lowered::LinearIR& linear_ir) { } for (size_t i = 0; i < m_configurator->get_io_num(); ++i) { - config->io_shapes[i] = m_unsqueezed_params.count(i) - ? unsqueeze_m_dim(config->io_shapes[i], m_dim_M_idces[i]) - : reshape_m_dim(config->io_shapes[i], m_dim_M_idces[i], new_batch_dim, new_kernel_dim); + config->io_shapes[i] = + m_unsqueezed_params.count(i) + ? unsqueeze_m_dim(config->io_shapes[i], m_dim_M_idces[i]) + : reshape_m_dim(config->io_shapes[i], m_dim_M_idces[i], new_batch_dim, new_kernel_dim); } config->io_layouts = m_optimized_layouts; return true; @@ -256,30 +264,36 @@ bool MHAParallelWAOptimizer::split(const ov::Shape& shape, size_t optimal_parallelism_work_amount, size_t& batch_m_dim, size_t& new_m_dim) { - const auto batch_dim = std::accumulate(shape.rbegin() + 2, shape.rend(), size_t(1), std::multiplies()); + const auto batch_dim = + std::accumulate(shape.rbegin() + 2, shape.rend(), static_cast(1), std::multiplies<>()); const auto m_dim = *(shape.rbegin() + 1); - if (is_prime_number(m_dim)) + if (is_prime_number(m_dim)) { return false; + } // We skip optimization if the current batch is optimal for concurrency - if (batch_dim % optimal_parallelism_work_amount == 0) + if (batch_dim % optimal_parallelism_work_amount == 0) { return false; + } auto split_is_done = [&batch_m_dim]() { return batch_m_dim != 1; }; std::tie(batch_m_dim, new_m_dim) = split_ideally(batch_dim, m_dim, optimal_parallelism_work_amount); - if (split_is_done()) + if (split_is_done()) { return true; + } std::tie(batch_m_dim, new_m_dim) = split_minimize_kernel_wa(batch_dim, m_dim, optimal_parallelism_work_amount); - if (split_is_done()) + if (split_is_done()) { return true; + } // If all the previous heuristics failed, fallback heuristic is used, which reflects the old splitting behavior - if (batch_dim < optimal_parallelism_work_amount) + if (batch_dim < optimal_parallelism_work_amount) { std::tie(batch_m_dim, new_m_dim) = split_fallback_increase_parallel_wa(batch_dim, m_dim, optimal_parallelism_work_amount); + } return split_is_done(); } @@ -289,15 +303,17 @@ std::pair MHAParallelWAOptimizer::split_ideally(size_t batch_dim // Ideal case #1: M can be split on the parts one of which complements the batch dimension to the optimal parallel // work amount In this case, each thread will execute the Snippets kernel once const size_t lower_bound = optimal_parallelism_work_amount / batch_dim; - if (lower_bound * batch_dim == optimal_parallelism_work_amount && m_dim % lower_bound == 0) + if (lower_bound * batch_dim == optimal_parallelism_work_amount && m_dim % lower_bound == 0) { return std::make_pair(lower_bound, m_dim / lower_bound); + } // Ideal case #2: M is divisible by optimal parallel work amount, and the new_m_dim is big enough // In this case, each thread will execute the Snippets kernel 'batch_dim' times if (m_dim % optimal_parallelism_work_amount == 0) { const auto new_m_dim = m_dim / optimal_parallelism_work_amount; - if (new_m_dim >= m_min_kernel_m) + if (new_m_dim >= m_min_kernel_m) { return std::make_pair(optimal_parallelism_work_amount, new_m_dim); + } } return std::make_pair(1, m_dim); @@ -311,9 +327,10 @@ std::pair MHAParallelWAOptimizer::split_fallback_increase_parall const size_t upper_bound = utils::div_up(2 * optimal_parallelism_work_amount, batch_dim); for (size_t divisor_0 = upper_bound - 1; divisor_0 > 1; divisor_0--) { size_t divisor_1 = m_dim / divisor_0; - if (divisor_1 * divisor_0 == m_dim) + if (divisor_1 * divisor_0 == m_dim) { return divisor_0 * batch_dim >= optimal_parallelism_work_amount ? std::make_pair(divisor_0, divisor_1) : splited; + } } return splited; } @@ -322,16 +339,18 @@ std::pair MHAParallelWAOptimizer::split_minimize_kernel_wa(size_ size_t m_dim, size_t optimal_parallelism_work_amount) { // This heuristic minimizes 'm_kernel' (=> maximizes 'm_batch') with a limitation that 'm_kernel >= min_kernel_m'. - // In other words, it tries to find 'm_kernel' bigger than 'm_min_kernel_m' and at the same time as close as possible - // to this value. + // In other words, it tries to find 'm_kernel' bigger than 'm_min_kernel_m' and at the same time as close as + // possible to this value. std::pair best_result = {1, m_dim}; for (size_t divisor = 2; divisor < std::sqrt(m_dim); ++divisor) { - if (m_dim % divisor != 0) + if (m_dim % divisor != 0) { continue; + } // If divisor is more than 'm_min_kernel_m', divisor becomes 'm_kernel', // guaranteeing the most optimal implementation from 'm_kernel' minimization perspective. - if (divisor >= m_min_kernel_m) + if (divisor >= m_min_kernel_m) { return std::make_pair(m_dim / divisor, divisor); + } // If divisor is less than 'm_min_kernel_m', divisor becomes m_batch. // However, it is not guaranteed that the current 'm_kernel = m_dim / divisor' is minimized, as one of the next @@ -342,8 +361,9 @@ std::pair MHAParallelWAOptimizer::split_minimize_kernel_wa(size_ best_result.second = m_kernel; } } - if (best_result.first * batch_dim >= optimal_parallelism_work_amount) + if (best_result.first * batch_dim >= optimal_parallelism_work_amount) { return best_result; + } return std::make_pair(1, m_dim); } From 365cce4a3dbcefaddf68cedb7c2889873c66365f Mon Sep 17 00:00:00 2001 From: Vladislav Golubev Date: Thu, 5 Mar 2026 14:55:43 +0100 Subject: [PATCH 3/4] tidy --- .../snippets/src/lowered/pass/mha_parallel_wa_optimizer.cpp | 1 + 1 file changed, 1 insertion(+) diff --git a/src/common/snippets/src/lowered/pass/mha_parallel_wa_optimizer.cpp b/src/common/snippets/src/lowered/pass/mha_parallel_wa_optimizer.cpp index 7e36cd42af18b4..74fb031336daa8 100644 --- a/src/common/snippets/src/lowered/pass/mha_parallel_wa_optimizer.cpp +++ b/src/common/snippets/src/lowered/pass/mha_parallel_wa_optimizer.cpp @@ -28,6 +28,7 @@ #include "snippets/lowered/pass/runtime_optimizer.hpp" #include "snippets/op/brgemm.hpp" #include "snippets/runtime_configurator.hpp" +#include "snippets/shape_types.hpp" #include "snippets/utils/loop_utils.hpp" #include "snippets/utils/utils.hpp" From eee797e5c0ed7e97f17ed2d23696289d787c61bd Mon Sep 17 00:00:00 2001 From: Vladislav Golubev Date: Thu, 5 Mar 2026 21:02:58 +0100 Subject: [PATCH 4/4] tidy --- .../intel_cpu/src/transformations/transformation_pipeline.cpp | 3 --- 1 file changed, 3 deletions(-) diff --git a/src/plugins/intel_cpu/src/transformations/transformation_pipeline.cpp b/src/plugins/intel_cpu/src/transformations/transformation_pipeline.cpp index b79ca5880222ef..86801c33d670d5 100644 --- a/src/plugins/intel_cpu/src/transformations/transformation_pipeline.cpp +++ b/src/plugins/intel_cpu/src/transformations/transformation_pipeline.cpp @@ -181,9 +181,6 @@ #endif #if defined(OPENVINO_ARCH_X86) || defined(OPENVINO_ARCH_X86_64) -# include -# include - # include "low_precision/convolution_backprop_data.hpp" # include "low_precision/fold_convert.hpp" # include "low_precision/fuse_convert.hpp"