Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,41 @@ class MHAParallelWAOptimizer : public lowered::pass::RuntimeOptimizer {
return !m_loops_to_split.empty();
}

/**
* @brief Tries to split M dimension in "shape" in accordance to optimal parallel work amount
* @param shape Original shape
* @param optimal_parallelism_work_amount Optimal work amount
* @param batch_m_dim reference on batch's part of the split M
* @param new_m_dim reference on new M dim after the split
* @return true if split was successfull, otherwise false
*/
static bool split(const ov::Shape& shape,
size_t optimal_parallelism_work_amount,
size_t& batch_m_dim,
size_t& new_m_dim);

private:
/**
* @brief Contains splitM approaches allowing to get the batch ideally divisible by
* optimal_parallelism_work_amount
*/
static std::pair<size_t, size_t> split_ideally(size_t batch_dim,
size_t m_dim,
size_t optimal_parallelism_work_amount);
/**
* @brief Splits m_dim to minimize kernel_m in order to reduce waiting time for idle threads at the last
* parallel loop iteration.
*/
static std::pair<size_t, size_t> split_minimize_kernel_wa(size_t batch_dim,
size_t m_dim,
size_t optimal_parallelism_work_amount);
/**
* @brief Splits m_dim to get the batch in (optimal_parallelism_work_amount, 2 *
* optimal_parallelism_work_amount) interval
*/
static std::pair<size_t, size_t> split_fallback_increase_parallel_wa(size_t batch_dim,
size_t m_dim,
size_t optimal_parallelism_work_amount);
static std::unordered_set<lowered::ExpressionPtr> find_applicable_brgemms(const lowered::LinearIRCPtr& linear_ir,
bool check_dynamic_wa = true);

Expand All @@ -58,6 +92,7 @@ class MHAParallelWAOptimizer : public lowered::pass::RuntimeOptimizer {
size_t m_concurrency = 0;

static const size_t m_dim_M_idx;
static const size_t m_min_kernel_m;
};

} // namespace ov::snippets::lowered::pass
179 changes: 168 additions & 11 deletions src/common/snippets/src/lowered/pass/mha_parallel_wa_optimizer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,15 @@
#include "snippets/lowered/pass/mha_parallel_wa_optimizer.hpp"

#include <algorithm>
#include <cmath>
#include <cstddef>
#include <functional>
#include <iterator>
#include <numeric>
#include <set>
#include <tuple>
#include <unordered_set>
#include <utility>
#include <vector>

#include "openvino/core/except.hpp"
Expand All @@ -22,14 +27,65 @@
#include "snippets/lowered/loop_port.hpp"
#include "snippets/lowered/pass/runtime_optimizer.hpp"
#include "snippets/op/brgemm.hpp"
#include "snippets/pass/split_dimension_m.hpp"
#include "snippets/runtime_configurator.hpp"
#include "snippets/shape_types.hpp"
#include "snippets/utils/loop_utils.hpp"
#include "snippets/utils/utils.hpp"

namespace ov::snippets::lowered::pass {
namespace {
std::vector<size_t> get_updated_order(const std::vector<size_t>& order, size_t m_index) {
std::vector<size_t> new_order(order.size() + 1, 0);
size_t shift_idx = 0;
for (size_t i = 0; i < order.size(); ++i) {
if (order[i] < m_index) {
new_order[i + shift_idx] = order[i];
} else if (order[i] == m_index) {
new_order[i + shift_idx++] = order[i];
new_order[i + shift_idx] = order[i] + 1;
} else {
new_order[i + shift_idx] = order[i] + 1;
}
}
return new_order;
}

ov::snippets::VectorDims unsqueeze_m_dim(ov::snippets::VectorDims shape, size_t m_index) {
shape.insert(shape.begin() + m_index, 1);
return shape;
}

ov::snippets::VectorDims reshape_m_dim(ov::snippets::VectorDims shape,
size_t m_index,
size_t batch_m_dim,
size_t new_m_dim) {
if (shape[m_index] == 1) {
return unsqueeze_m_dim(std::move(shape), m_index);
}
shape[m_index] = new_m_dim;
shape.insert(shape.begin() + m_index, batch_m_dim);
return shape;
}

bool is_prime_number(size_t value) {
if (value == 2LU || value == 3LU) {
return true;
}
if (value == 1 || value % 2 == 0 || value % 3 == 0) {
return false;
}
const auto root = std::sqrt(value) + 1;
for (size_t divisor = 5; divisor < root; divisor += 6) {
if ((value % divisor == 0) || (value % (divisor + 2) == 0)) {
return false;
}
}
return true;
}
} // namespace

const size_t MHAParallelWAOptimizer::m_dim_M_idx = 1;
const size_t MHAParallelWAOptimizer::m_min_kernel_m = 32;

MHAParallelWAOptimizer::MHAParallelWAOptimizer(const lowered::LinearIRCPtr& linear_ir,
const RuntimeConfigurator* configurator)
Expand All @@ -56,18 +112,15 @@ MHAParallelWAOptimizer::MHAParallelWAOptimizer(const lowered::LinearIRCPtr& line
: utils::get_output_dim_idx(layout, m_dim_M_idx);
m_dim_M_idces[i] = dim_idx;
const auto m_idx = i < configurator->get_in_num() ? dim_idx : layout.size() - 2;
m_optimized_layouts[i] = ov::snippets::pass::SplitDimensionM::get_updated_order(layout, m_idx);
m_optimized_layouts[i] = get_updated_order(layout, m_idx);
}
}

bool MHAParallelWAOptimizer::run(const lowered::LinearIR& linear_ir) {
OV_ITT_SCOPED_TASK(ov::pass::itt::domains::SnippetsTransform, "Snippets::MHAParallelWAOptimizer")
const auto& config = m_configurator->get_config();
size_t new_batch_dim = 0, new_kernel_dim = 0;
if (!ov::snippets::pass::SplitDimensionM::split(config->master_shape,
m_concurrency,
new_batch_dim,
new_kernel_dim)) {
if (!split(config->master_shape, m_concurrency, new_batch_dim, new_kernel_dim)) {
return false;
}
auto& master_shape = config->master_shape;
Expand Down Expand Up @@ -100,11 +153,8 @@ bool MHAParallelWAOptimizer::run(const lowered::LinearIR& linear_ir) {
for (size_t i = 0; i < m_configurator->get_io_num(); ++i) {
config->io_shapes[i] =
m_unsqueezed_params.count(i)
? ov::snippets::pass::SplitDimensionM::unsqueeze_m_dim(config->io_shapes[i], m_dim_M_idces[i])
: ov::snippets::pass::SplitDimensionM::reshape_m_dim(config->io_shapes[i],
m_dim_M_idces[i],
new_batch_dim,
new_kernel_dim);
? unsqueeze_m_dim(config->io_shapes[i], m_dim_M_idces[i])
: reshape_m_dim(config->io_shapes[i], m_dim_M_idces[i], new_batch_dim, new_kernel_dim);
}
config->io_layouts = m_optimized_layouts;
return true;
Expand Down Expand Up @@ -211,4 +261,111 @@ std::vector<lowered::ExpandedLoopInfoPtr> MHAParallelWAOptimizer::find_loops_to_
return loops_to_split;
}

bool MHAParallelWAOptimizer::split(const ov::Shape& shape,
size_t optimal_parallelism_work_amount,
size_t& batch_m_dim,
size_t& new_m_dim) {
const auto batch_dim =
std::accumulate(shape.rbegin() + 2, shape.rend(), static_cast<size_t>(1), std::multiplies<>());
const auto m_dim = *(shape.rbegin() + 1);
if (is_prime_number(m_dim)) {
return false;
}

// We skip optimization if the current batch is optimal for concurrency
if (batch_dim % optimal_parallelism_work_amount == 0) {
return false;
}

auto split_is_done = [&batch_m_dim]() {
return batch_m_dim != 1;
};

std::tie(batch_m_dim, new_m_dim) = split_ideally(batch_dim, m_dim, optimal_parallelism_work_amount);
if (split_is_done()) {
return true;
}

std::tie(batch_m_dim, new_m_dim) = split_minimize_kernel_wa(batch_dim, m_dim, optimal_parallelism_work_amount);
if (split_is_done()) {
return true;
}
// If all the previous heuristics failed, fallback heuristic is used, which reflects the old splitting behavior
if (batch_dim < optimal_parallelism_work_amount) {
std::tie(batch_m_dim, new_m_dim) =
split_fallback_increase_parallel_wa(batch_dim, m_dim, optimal_parallelism_work_amount);
}
return split_is_done();
}

std::pair<size_t, size_t> MHAParallelWAOptimizer::split_ideally(size_t batch_dim,
size_t m_dim,
size_t optimal_parallelism_work_amount) {
// Ideal case #1: M can be split on the parts one of which complements the batch dimension to the optimal parallel
// work amount In this case, each thread will execute the Snippets kernel once
const size_t lower_bound = optimal_parallelism_work_amount / batch_dim;
if (lower_bound * batch_dim == optimal_parallelism_work_amount && m_dim % lower_bound == 0) {
return std::make_pair(lower_bound, m_dim / lower_bound);
}

// Ideal case #2: M is divisible by optimal parallel work amount, and the new_m_dim is big enough
// In this case, each thread will execute the Snippets kernel 'batch_dim' times
if (m_dim % optimal_parallelism_work_amount == 0) {
const auto new_m_dim = m_dim / optimal_parallelism_work_amount;
if (new_m_dim >= m_min_kernel_m) {
return std::make_pair(optimal_parallelism_work_amount, new_m_dim);
}
}

return std::make_pair(1, m_dim);
}

std::pair<size_t, size_t> MHAParallelWAOptimizer::split_fallback_increase_parallel_wa(
size_t batch_dim,
size_t m_dim,
size_t optimal_parallelism_work_amount) {
std::pair<size_t, size_t> splited = {1, m_dim};
const size_t upper_bound = utils::div_up(2 * optimal_parallelism_work_amount, batch_dim);
for (size_t divisor_0 = upper_bound - 1; divisor_0 > 1; divisor_0--) {
size_t divisor_1 = m_dim / divisor_0;
if (divisor_1 * divisor_0 == m_dim) {
return divisor_0 * batch_dim >= optimal_parallelism_work_amount ? std::make_pair(divisor_0, divisor_1)
: splited;
}
}
return splited;
}

std::pair<size_t, size_t> MHAParallelWAOptimizer::split_minimize_kernel_wa(size_t batch_dim,
size_t m_dim,
size_t optimal_parallelism_work_amount) {
// This heuristic minimizes 'm_kernel' (=> maximizes 'm_batch') with a limitation that 'm_kernel >= min_kernel_m'.
// In other words, it tries to find 'm_kernel' bigger than 'm_min_kernel_m' and at the same time as close as
// possible to this value.
std::pair<size_t, size_t> best_result = {1, m_dim};
for (size_t divisor = 2; divisor < std::sqrt(m_dim); ++divisor) {
if (m_dim % divisor != 0) {
continue;
}
// If divisor is more than 'm_min_kernel_m', divisor becomes 'm_kernel',
// guaranteeing the most optimal implementation from 'm_kernel' minimization perspective.
if (divisor >= m_min_kernel_m) {
return std::make_pair(m_dim / divisor, divisor);
}

// If divisor is less than 'm_min_kernel_m', divisor becomes m_batch.
// However, it is not guaranteed that the current 'm_kernel = m_dim / divisor' is minimized, as one of the next
// divisors can be more optimal. So in this case the best result is remembered
const size_t m_kernel = m_dim / divisor;
if (m_kernel >= m_min_kernel_m) {
best_result.first = divisor;
best_result.second = m_kernel;
}
}
if (best_result.first * batch_dim >= optimal_parallelism_work_amount) {
return best_result;
}
return std::make_pair(1, m_dim);
}

} // namespace ov::snippets::lowered::pass
5 changes: 0 additions & 5 deletions src/common/snippets/src/pass/common_optimizations.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,6 @@
#include "snippets/pass/extract_unsupported_transposes.hpp"
#include "snippets/pass/fq_decomposition.hpp"
#include "snippets/pass/softmax_reshape_elimination.hpp"
#include "snippets/pass/split_dimension_m.hpp"
#include "snippets/pass/subgraph_manager.hpp"
#include "snippets/pass/transform_convert.hpp"
#include "snippets/pass/validate.hpp"
Expand Down Expand Up @@ -60,10 +59,6 @@ CommonOptimizations::CommonOptimizations(const CommonOptimizations::Config& conf
ov::snippets::pass::ExtractUnsupportedTransposes,
is_domain_sensitive,
config.get_transpose_support_callback());
REGISTER_SNIPPETS_PASS(subgraph_manager,
ov::snippets::pass::SplitDimensionM,
is_domain_sensitive && config.get_split_m_dimension(),
config.get_concurrency());
subgraph_manager.run_passes(subgraph);

// Validate the body after all common optimizations
Expand Down
Loading
Loading