Skip to content

Commit ff62894

Browse files
committed
[Snippets] SplitM pass removed
1 parent 222c9d4 commit ff62894

File tree

10 files changed

+211
-346
lines changed

10 files changed

+211
-346
lines changed

src/common/snippets/include/snippets/lowered/pass/mha_parallel_wa_optimizer.hpp

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,41 @@ class MHAParallelWAOptimizer : public lowered::pass::RuntimeOptimizer {
4040
return !m_loops_to_split.empty();
4141
}
4242

43+
/**
44+
* @brief Tries to split the M dimension in "shape" in accordance with the optimal parallel work amount
45+
* @param shape Original shape
46+
* @param optimal_parallelism_work_amount Optimal work amount
47+
* @param batch_m_dim reference on batch's part of the split M
48+
* @param new_m_dim reference on new M dim after the split
49+
* @return true if the split was successful, otherwise false
50+
*/
51+
static bool split(const ov::Shape& shape,
52+
size_t optimal_parallelism_work_amount,
53+
size_t& batch_m_dim,
54+
size_t& new_m_dim);
55+
4356
private:
57+
/**
58+
* @brief Contains splitM approaches that make the batch ideally divisible by
59+
* optimal_parallelism_work_amount
60+
*/
61+
static std::pair<size_t, size_t> split_ideally(size_t batch_dim,
62+
size_t m_dim,
63+
size_t optimal_parallelism_work_amount);
64+
/**
65+
* @brief Splits m_dim to minimize kernel_m in order to reduce waiting time for idle threads at the last
66+
* parallel loop iteration.
67+
*/
68+
static std::pair<size_t, size_t> split_minimize_kernel_wa(size_t batch_dim,
69+
size_t m_dim,
70+
size_t optimal_parallelism_work_amount);
71+
/**
72+
* @brief Splits m_dim to get the batch in (optimal_parallelism_work_amount, 2 *
73+
* optimal_parallelism_work_amount) interval
74+
*/
75+
static std::pair<size_t, size_t> split_fallback_increase_parallel_wa(size_t batch_dim,
76+
size_t m_dim,
77+
size_t optimal_parallelism_work_amount);
4478
static std::unordered_set<lowered::ExpressionPtr> find_applicable_brgemms(const lowered::LinearIRCPtr& linear_ir,
4579
bool check_dynamic_wa = true);
4680

@@ -58,6 +92,7 @@ class MHAParallelWAOptimizer : public lowered::pass::RuntimeOptimizer {
5892
size_t m_concurrency = 0;
5993

6094
static const size_t m_dim_M_idx;
95+
static const size_t m_min_kernel_m;
6196
};
6297

6398
} // namespace ov::snippets::lowered::pass

src/common/snippets/src/lowered/pass/mha_parallel_wa_optimizer.cpp

Lines changed: 149 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -5,8 +5,10 @@
55
#include "snippets/lowered/pass/mha_parallel_wa_optimizer.hpp"
66

77
#include <algorithm>
8+
#include <cmath>
89
#include <cstddef>
910
#include <iterator>
11+
#include <numeric>
1012
#include <set>
1113
#include <unordered_set>
1214
#include <vector>
@@ -22,14 +24,60 @@
2224
#include "snippets/lowered/loop_port.hpp"
2325
#include "snippets/lowered/pass/runtime_optimizer.hpp"
2426
#include "snippets/op/brgemm.hpp"
25-
#include "snippets/pass/split_dimension_m.hpp"
2627
#include "snippets/runtime_configurator.hpp"
2728
#include "snippets/utils/loop_utils.hpp"
2829
#include "snippets/utils/utils.hpp"
2930

3031
namespace ov::snippets::lowered::pass {
32+
namespace {
33+
std::vector<size_t> get_updated_order(const std::vector<size_t>& order, size_t m_index) {
34+
std::vector<size_t> new_order(order.size() + 1, 0);
35+
size_t shift_idx = 0;
36+
for (size_t i = 0; i < order.size(); ++i) {
37+
if (order[i] < m_index) {
38+
new_order[i + shift_idx] = order[i];
39+
} else if (order[i] == m_index) {
40+
new_order[i + shift_idx++] = order[i];
41+
new_order[i + shift_idx] = order[i] + 1;
42+
} else {
43+
new_order[i + shift_idx] = order[i] + 1;
44+
}
45+
}
46+
return new_order;
47+
}
48+
49+
ov::snippets::VectorDims unsqueeze_m_dim(ov::snippets::VectorDims shape, size_t m_index) {
50+
shape.insert(shape.begin() + m_index, 1);
51+
return shape;
52+
}
53+
54+
ov::snippets::VectorDims reshape_m_dim(ov::snippets::VectorDims shape,
55+
size_t m_index,
56+
size_t batch_m_dim,
57+
size_t new_m_dim) {
58+
if (shape[m_index] == 1)
59+
return unsqueeze_m_dim(std::move(shape), m_index);
60+
shape[m_index] = new_m_dim;
61+
shape.insert(shape.begin() + m_index, batch_m_dim);
62+
return shape;
63+
}
64+
65+
bool is_prime_number(size_t value) {
66+
if (value == 2lu || value == 3lu)
67+
return true;
68+
if (value == 1 || value % 2 == 0 || value % 3 == 0)
69+
return false;
70+
const auto root = std::sqrt(value) + 1;
71+
for (size_t divisor = 5; divisor < root; divisor += 6) {
72+
if ((value % divisor == 0) || (value % (divisor + 2) == 0))
73+
return false;
74+
}
75+
return true;
76+
}
77+
} // namespace
3178

3279
const size_t MHAParallelWAOptimizer::m_dim_M_idx = 1;
80+
const size_t MHAParallelWAOptimizer::m_min_kernel_m = 32;
3381

3482
MHAParallelWAOptimizer::MHAParallelWAOptimizer(const lowered::LinearIRCPtr& linear_ir,
3583
const RuntimeConfigurator* configurator)
@@ -56,18 +104,15 @@ MHAParallelWAOptimizer::MHAParallelWAOptimizer(const lowered::LinearIRCPtr& line
56104
: utils::get_output_dim_idx(layout, m_dim_M_idx);
57105
m_dim_M_idces[i] = dim_idx;
58106
const auto m_idx = i < configurator->get_in_num() ? dim_idx : layout.size() - 2;
59-
m_optimized_layouts[i] = ov::snippets::pass::SplitDimensionM::get_updated_order(layout, m_idx);
107+
m_optimized_layouts[i] = get_updated_order(layout, m_idx);
60108
}
61109
}
62110

63111
bool MHAParallelWAOptimizer::run(const lowered::LinearIR& linear_ir) {
64112
OV_ITT_SCOPED_TASK(ov::pass::itt::domains::SnippetsTransform, "Snippets::MHAParallelWAOptimizer")
65113
const auto& config = m_configurator->get_config();
66114
size_t new_batch_dim = 0, new_kernel_dim = 0;
67-
if (!ov::snippets::pass::SplitDimensionM::split(config->master_shape,
68-
m_concurrency,
69-
new_batch_dim,
70-
new_kernel_dim)) {
115+
if (!split(config->master_shape, m_concurrency, new_batch_dim, new_kernel_dim)) {
71116
return false;
72117
}
73118
auto& master_shape = config->master_shape;
@@ -98,13 +143,9 @@ bool MHAParallelWAOptimizer::run(const lowered::LinearIR& linear_ir) {
98143
}
99144

100145
for (size_t i = 0; i < m_configurator->get_io_num(); ++i) {
101-
config->io_shapes[i] =
102-
m_unsqueezed_params.count(i)
103-
? ov::snippets::pass::SplitDimensionM::unsqueeze_m_dim(config->io_shapes[i], m_dim_M_idces[i])
104-
: ov::snippets::pass::SplitDimensionM::reshape_m_dim(config->io_shapes[i],
105-
m_dim_M_idces[i],
106-
new_batch_dim,
107-
new_kernel_dim);
146+
config->io_shapes[i] = m_unsqueezed_params.count(i)
147+
? unsqueeze_m_dim(config->io_shapes[i], m_dim_M_idces[i])
148+
: reshape_m_dim(config->io_shapes[i], m_dim_M_idces[i], new_batch_dim, new_kernel_dim);
108149
}
109150
config->io_layouts = m_optimized_layouts;
110151
return true;
@@ -211,4 +252,99 @@ std::vector<lowered::ExpandedLoopInfoPtr> MHAParallelWAOptimizer::find_loops_to_
211252
return loops_to_split;
212253
}
213254

255+
bool MHAParallelWAOptimizer::split(const ov::Shape& shape,
256+
size_t optimal_parallelism_work_amount,
257+
size_t& batch_m_dim,
258+
size_t& new_m_dim) {
259+
const auto batch_dim = std::accumulate(shape.rbegin() + 2, shape.rend(), size_t(1), std::multiplies<size_t>());
260+
const auto m_dim = *(shape.rbegin() + 1);
261+
if (is_prime_number(m_dim))
262+
return false;
263+
264+
// We skip optimization if the current batch is optimal for concurrency
265+
if (batch_dim % optimal_parallelism_work_amount == 0)
266+
return false;
267+
268+
auto split_is_done = [&batch_m_dim]() {
269+
return batch_m_dim != 1;
270+
};
271+
272+
std::tie(batch_m_dim, new_m_dim) = split_ideally(batch_dim, m_dim, optimal_parallelism_work_amount);
273+
if (split_is_done())
274+
return true;
275+
276+
std::tie(batch_m_dim, new_m_dim) = split_minimize_kernel_wa(batch_dim, m_dim, optimal_parallelism_work_amount);
277+
if (split_is_done())
278+
return true;
279+
// If all the previous heuristics failed, fallback heuristic is used, which reflects the old splitting behavior
280+
if (batch_dim < optimal_parallelism_work_amount)
281+
std::tie(batch_m_dim, new_m_dim) =
282+
split_fallback_increase_parallel_wa(batch_dim, m_dim, optimal_parallelism_work_amount);
283+
return split_is_done();
284+
}
285+
286+
std::pair<size_t, size_t> MHAParallelWAOptimizer::split_ideally(size_t batch_dim,
287+
size_t m_dim,
288+
size_t optimal_parallelism_work_amount) {
289+
// Ideal case #1: M can be split into parts, one of which complements the batch dimension to the optimal parallel
290+
// work amount. In this case, each thread will execute the Snippets kernel once
291+
const size_t lower_bound = optimal_parallelism_work_amount / batch_dim;
292+
if (lower_bound * batch_dim == optimal_parallelism_work_amount && m_dim % lower_bound == 0)
293+
return std::make_pair(lower_bound, m_dim / lower_bound);
294+
295+
// Ideal case #2: M is divisible by optimal parallel work amount, and the new_m_dim is big enough
296+
// In this case, each thread will execute the Snippets kernel 'batch_dim' times
297+
if (m_dim % optimal_parallelism_work_amount == 0) {
298+
const auto new_m_dim = m_dim / optimal_parallelism_work_amount;
299+
if (new_m_dim >= m_min_kernel_m)
300+
return std::make_pair(optimal_parallelism_work_amount, new_m_dim);
301+
}
302+
303+
return std::make_pair(1, m_dim);
304+
}
305+
306+
std::pair<size_t, size_t> MHAParallelWAOptimizer::split_fallback_increase_parallel_wa(
307+
size_t batch_dim,
308+
size_t m_dim,
309+
size_t optimal_parallelism_work_amount) {
310+
std::pair<size_t, size_t> splited = {1, m_dim};
311+
const size_t upper_bound = utils::div_up(2 * optimal_parallelism_work_amount, batch_dim);
312+
for (size_t divisor_0 = upper_bound - 1; divisor_0 > 1; divisor_0--) {
313+
size_t divisor_1 = m_dim / divisor_0;
314+
if (divisor_1 * divisor_0 == m_dim)
315+
return divisor_0 * batch_dim >= optimal_parallelism_work_amount ? std::make_pair(divisor_0, divisor_1)
316+
: splited;
317+
}
318+
return splited;
319+
}
320+
321+
std::pair<size_t, size_t> MHAParallelWAOptimizer::split_minimize_kernel_wa(size_t batch_dim,
322+
size_t m_dim,
323+
size_t optimal_parallelism_work_amount) {
324+
// This heuristic minimizes 'm_kernel' (=> maximizes 'm_batch') with a limitation that 'm_kernel >= min_kernel_m'.
325+
// In other words, it tries to find 'm_kernel' bigger than 'm_min_kernel_m' and at the same time as close as possible
326+
// to this value.
327+
std::pair<size_t, size_t> best_result = {1, m_dim};
328+
for (size_t divisor = 2; divisor < std::sqrt(m_dim); ++divisor) {
329+
if (m_dim % divisor != 0)
330+
continue;
331+
// If divisor is at least 'm_min_kernel_m', divisor becomes 'm_kernel',
332+
// guaranteeing the most optimal implementation from 'm_kernel' minimization perspective.
333+
if (divisor >= m_min_kernel_m)
334+
return std::make_pair(m_dim / divisor, divisor);
335+
336+
// If divisor is less than 'm_min_kernel_m', divisor becomes m_batch.
337+
// However, it is not guaranteed that the current 'm_kernel = m_dim / divisor' is minimized, as one of the next
338+
// divisors can be more optimal. So in this case the best result is remembered
339+
const size_t m_kernel = m_dim / divisor;
340+
if (m_kernel >= m_min_kernel_m) {
341+
best_result.first = divisor;
342+
best_result.second = m_kernel;
343+
}
344+
}
345+
if (best_result.first * batch_dim >= optimal_parallelism_work_amount)
346+
return best_result;
347+
return std::make_pair(1, m_dim);
348+
}
349+
214350
} // namespace ov::snippets::lowered::pass

src/common/snippets/src/pass/common_optimizations.cpp

Lines changed: 0 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,6 @@
1818
#include "snippets/pass/extract_unsupported_transposes.hpp"
1919
#include "snippets/pass/fq_decomposition.hpp"
2020
#include "snippets/pass/softmax_reshape_elimination.hpp"
21-
#include "snippets/pass/split_dimension_m.hpp"
2221
#include "snippets/pass/subgraph_manager.hpp"
2322
#include "snippets/pass/transform_convert.hpp"
2423
#include "snippets/pass/validate.hpp"
@@ -60,10 +59,6 @@ CommonOptimizations::CommonOptimizations(const CommonOptimizations::Config& conf
6059
ov::snippets::pass::ExtractUnsupportedTransposes,
6160
is_domain_sensitive,
6261
config.get_transpose_support_callback());
63-
REGISTER_SNIPPETS_PASS(subgraph_manager,
64-
ov::snippets::pass::SplitDimensionM,
65-
is_domain_sensitive && config.get_split_m_dimension(),
66-
config.get_concurrency());
6762
subgraph_manager.run_passes(subgraph);
6863

6964
// Validate the body after all common optimizations

0 commit comments

Comments
 (0)