55#include " snippets/lowered/pass/mha_parallel_wa_optimizer.hpp"
66
77#include < algorithm>
8+ #include < cmath>
89#include < cstddef>
910#include < iterator>
11+ #include < numeric>
1012#include < set>
1113#include < unordered_set>
1214#include < vector>
2224#include " snippets/lowered/loop_port.hpp"
2325#include " snippets/lowered/pass/runtime_optimizer.hpp"
2426#include " snippets/op/brgemm.hpp"
25- #include " snippets/pass/split_dimension_m.hpp"
2627#include " snippets/runtime_configurator.hpp"
2728#include " snippets/utils/loop_utils.hpp"
2829#include " snippets/utils/utils.hpp"
2930
3031namespace ov ::snippets::lowered::pass {
32+ namespace {
std::vector<size_t> get_updated_order(const std::vector<size_t>& order, size_t m_index) {
    // The split turns dimension 'm_index' into two adjacent dimensions, so every
    // order entry referring to a later dimension must be shifted by one, and the
    // entry equal to 'm_index' is expanded into the pair (m_index, m_index + 1).
    std::vector<size_t> updated(order.size() + 1, 0);
    size_t offset = 0;
    for (size_t i = 0; i < order.size(); ++i) {
        const size_t dim = order[i];
        if (dim == m_index) {
            updated[i + offset] = dim;
            ++offset;
            updated[i + offset] = dim + 1;
        } else {
            updated[i + offset] = dim < m_index ? dim : dim + 1;
        }
    }
    return updated;
}
48+
49+ ov::snippets::VectorDims unsqueeze_m_dim (ov::snippets::VectorDims shape, size_t m_index) {
50+ shape.insert (shape.begin () + m_index, 1 );
51+ return shape;
52+ }
53+
54+ ov::snippets::VectorDims reshape_m_dim (ov::snippets::VectorDims shape,
55+ size_t m_index,
56+ size_t batch_m_dim,
57+ size_t new_m_dim) {
58+ if (shape[m_index] == 1 )
59+ return unsqueeze_m_dim (std::move (shape), m_index);
60+ shape[m_index] = new_m_dim;
61+ shape.insert (shape.begin () + m_index, batch_m_dim);
62+ return shape;
63+ }
64+
bool is_prime_number(size_t value) {
    // 2 and 3 are prime; any other value that is 0, 1, or a multiple of 2 or 3 is not.
    if (value == 2 || value == 3)
        return true;
    if (value < 2 || value % 2 == 0 || value % 3 == 0)
        return false;
    // Trial division over the 6k +/- 1 candidates up to sqrt(value):
    // all remaining primes fall into that pattern.
    const double limit = std::sqrt(value) + 1;
    for (size_t candidate = 5; candidate < limit; candidate += 6) {
        if (value % candidate == 0)
            return false;
        if (value % (candidate + 2) == 0)
            return false;
    }
    return true;
}
77+ } // namespace
3178
// Index used to locate the M dimension inside each port's layout via
// utils::get_input/output_dim_idx in the constructor; value 1 presumably selects
// the second-innermost dimension (M in [.., M, K/N] shapes) — TODO confirm
// against utils::get_*_dim_idx semantics.
const size_t MHAParallelWAOptimizer::m_dim_M_idx = 1;
// Lower bound for the kernel's M dimension accepted by the splitting heuristics
// (see split_ideally and split_minimize_kernel_wa).
const size_t MHAParallelWAOptimizer::m_min_kernel_m = 32;
3381
3482MHAParallelWAOptimizer::MHAParallelWAOptimizer (const lowered::LinearIRCPtr& linear_ir,
3583 const RuntimeConfigurator* configurator)
@@ -56,18 +104,15 @@ MHAParallelWAOptimizer::MHAParallelWAOptimizer(const lowered::LinearIRCPtr& line
56104 : utils::get_output_dim_idx (layout, m_dim_M_idx);
57105 m_dim_M_idces[i] = dim_idx;
58106 const auto m_idx = i < configurator->get_in_num () ? dim_idx : layout.size () - 2 ;
59- m_optimized_layouts[i] = ov::snippets::pass::SplitDimensionM:: get_updated_order (layout, m_idx);
107+ m_optimized_layouts[i] = get_updated_order (layout, m_idx);
60108 }
61109}
62110
63111bool MHAParallelWAOptimizer::run (const lowered::LinearIR& linear_ir) {
64112 OV_ITT_SCOPED_TASK (ov::pass::itt::domains::SnippetsTransform, " Snippets::MHAParallelWAOptimizer" )
65113 const auto & config = m_configurator->get_config ();
66114 size_t new_batch_dim = 0 , new_kernel_dim = 0 ;
67- if (!ov::snippets::pass::SplitDimensionM::split (config->master_shape ,
68- m_concurrency,
69- new_batch_dim,
70- new_kernel_dim)) {
115+ if (!split (config->master_shape , m_concurrency, new_batch_dim, new_kernel_dim)) {
71116 return false ;
72117 }
73118 auto & master_shape = config->master_shape ;
@@ -98,13 +143,9 @@ bool MHAParallelWAOptimizer::run(const lowered::LinearIR& linear_ir) {
98143 }
99144
100145 for (size_t i = 0 ; i < m_configurator->get_io_num (); ++i) {
101- config->io_shapes [i] =
102- m_unsqueezed_params.count (i)
103- ? ov::snippets::pass::SplitDimensionM::unsqueeze_m_dim (config->io_shapes [i], m_dim_M_idces[i])
104- : ov::snippets::pass::SplitDimensionM::reshape_m_dim (config->io_shapes [i],
105- m_dim_M_idces[i],
106- new_batch_dim,
107- new_kernel_dim);
146+ config->io_shapes [i] = m_unsqueezed_params.count (i)
147+ ? unsqueeze_m_dim (config->io_shapes [i], m_dim_M_idces[i])
148+ : reshape_m_dim (config->io_shapes [i], m_dim_M_idces[i], new_batch_dim, new_kernel_dim);
108149 }
109150 config->io_layouts = m_optimized_layouts;
110151 return true ;
@@ -211,4 +252,99 @@ std::vector<lowered::ExpandedLoopInfoPtr> MHAParallelWAOptimizer::find_loops_to_
211252 return loops_to_split;
212253}
213254
255+ bool MHAParallelWAOptimizer::split (const ov::Shape& shape,
256+ size_t optimal_parallelism_work_amount,
257+ size_t & batch_m_dim,
258+ size_t & new_m_dim) {
259+ const auto batch_dim = std::accumulate (shape.rbegin () + 2 , shape.rend (), size_t (1 ), std::multiplies<size_t >());
260+ const auto m_dim = *(shape.rbegin () + 1 );
261+ if (is_prime_number (m_dim))
262+ return false ;
263+
264+ // We skip optimization if the current batch is optimal for concurrency
265+ if (batch_dim % optimal_parallelism_work_amount == 0 )
266+ return false ;
267+
268+ auto split_is_done = [&batch_m_dim]() {
269+ return batch_m_dim != 1 ;
270+ };
271+
272+ std::tie (batch_m_dim, new_m_dim) = split_ideally (batch_dim, m_dim, optimal_parallelism_work_amount);
273+ if (split_is_done ())
274+ return true ;
275+
276+ std::tie (batch_m_dim, new_m_dim) = split_minimize_kernel_wa (batch_dim, m_dim, optimal_parallelism_work_amount);
277+ if (split_is_done ())
278+ return true ;
279+ // If all the previous heuristics failed, fallback heuristic is used, which reflects the old splitting behavior
280+ if (batch_dim < optimal_parallelism_work_amount)
281+ std::tie (batch_m_dim, new_m_dim) =
282+ split_fallback_increase_parallel_wa (batch_dim, m_dim, optimal_parallelism_work_amount);
283+ return split_is_done ();
284+ }
285+
286+ std::pair<size_t , size_t > MHAParallelWAOptimizer::split_ideally (size_t batch_dim,
287+ size_t m_dim,
288+ size_t optimal_parallelism_work_amount) {
289+ // Ideal case #1: M can be split on the parts one of which complements the batch dimension to the optimal parallel
290+ // work amount In this case, each thread will execute the Snippets kernel once
291+ const size_t lower_bound = optimal_parallelism_work_amount / batch_dim;
292+ if (lower_bound * batch_dim == optimal_parallelism_work_amount && m_dim % lower_bound == 0 )
293+ return std::make_pair (lower_bound, m_dim / lower_bound);
294+
295+ // Ideal case #2: M is divisible by optimal parallel work amount, and the new_m_dim is big enough
296+ // In this case, each thread will execute the Snippets kernel 'batch_dim' times
297+ if (m_dim % optimal_parallelism_work_amount == 0 ) {
298+ const auto new_m_dim = m_dim / optimal_parallelism_work_amount;
299+ if (new_m_dim >= m_min_kernel_m)
300+ return std::make_pair (optimal_parallelism_work_amount, new_m_dim);
301+ }
302+
303+ return std::make_pair (1 , m_dim);
304+ }
305+
306+ std::pair<size_t , size_t > MHAParallelWAOptimizer::split_fallback_increase_parallel_wa (
307+ size_t batch_dim,
308+ size_t m_dim,
309+ size_t optimal_parallelism_work_amount) {
310+ std::pair<size_t , size_t > splited = {1 , m_dim};
311+ const size_t upper_bound = utils::div_up (2 * optimal_parallelism_work_amount, batch_dim);
312+ for (size_t divisor_0 = upper_bound - 1 ; divisor_0 > 1 ; divisor_0--) {
313+ size_t divisor_1 = m_dim / divisor_0;
314+ if (divisor_1 * divisor_0 == m_dim)
315+ return divisor_0 * batch_dim >= optimal_parallelism_work_amount ? std::make_pair (divisor_0, divisor_1)
316+ : splited;
317+ }
318+ return splited;
319+ }
320+
321+ std::pair<size_t , size_t > MHAParallelWAOptimizer::split_minimize_kernel_wa (size_t batch_dim,
322+ size_t m_dim,
323+ size_t optimal_parallelism_work_amount) {
324+ // This heuristic minimizes 'm_kernel' (=> maximizes 'm_batch') with a limitation that 'm_kernel >= min_kernel_m'.
325+ // In other words, it tries to find 'm_kernel' bigger than 'm_min_kernel_m' and at the same time as close as possible
326+ // to this value.
327+ std::pair<size_t , size_t > best_result = {1 , m_dim};
328+ for (size_t divisor = 2 ; divisor < std::sqrt (m_dim); ++divisor) {
329+ if (m_dim % divisor != 0 )
330+ continue ;
331+ // If divisor is more than 'm_min_kernel_m', divisor becomes 'm_kernel',
332+ // guaranteeing the most optimal implementation from 'm_kernel' minimization perspective.
333+ if (divisor >= m_min_kernel_m)
334+ return std::make_pair (m_dim / divisor, divisor);
335+
336+ // If divisor is less than 'm_min_kernel_m', divisor becomes m_batch.
337+ // However, it is not guaranteed that the current 'm_kernel = m_dim / divisor' is minimized, as one of the next
338+ // divisors can be more optimal. So in this case the best result is remembered
339+ const size_t m_kernel = m_dim / divisor;
340+ if (m_kernel >= m_min_kernel_m) {
341+ best_result.first = divisor;
342+ best_result.second = m_kernel;
343+ }
344+ }
345+ if (best_result.first * batch_dim >= optimal_parallelism_work_amount)
346+ return best_result;
347+ return std::make_pair (1 , m_dim);
348+ }
349+
214350} // namespace ov::snippets::lowered::pass
0 commit comments