[Snippets][CPU] Support static and dynamic offsets in JIT Gemm and GemmCopyB emitters #31375


Open · wants to merge 12 commits into master
@@ -155,6 +155,10 @@ class jit_emitter : public ov::snippets::Emitter {
}
}

int32_t get_gpr_length() const {
return h->x0.getBit() / 8;
}

private:
mutable std::vector<size_t> preserved_vec_idxs;
mutable std::vector<size_t> preserved_gpr_idxs;
@@ -179,10 +183,6 @@ class jit_emitter : public ov::snippets::Emitter {
return 32;
}

int32_t get_gpr_length() const {
return h->x0.getBit() / 8;
}

void store_context(const std::vector<size_t>& gpr_regs,
const std::vector<size_t>& vec_regs,
const std::unordered_set<size_t>& ignore_vec_regs = {}) const;
@@ -17,6 +17,8 @@
#include <vector>

#include "emitters/snippets/aarch64/kernel_executors/gemm_copy_b.hpp"
#include "emitters/snippets/aarch64/utils.hpp"
#include "emitters/snippets/jit_snippets_call_args.hpp"
#include "emitters/utils.hpp"
#include "openvino/core/node.hpp"
#include "openvino/core/type.hpp"
@@ -53,6 +55,13 @@ jit_gemm_copy_b_emitter::jit_gemm_copy_b_emitter(jit_generator* h,
OV_CPU_JIT_EMITTER_ASSERT(n_blk_size > 0, "n_blk_size of gemm_repack is expected to be greater than 0.");
GemmCopyBKernelKaiConfig kernel_config(n_blk_size);
m_kernel_executor = kernel_table->register_kernel<GemmCopyBKaiKernelExecutor>(expr, kernel_config);

// Initialize memory offsets, mirroring the x64 brgemm_copy_b implementation
m_memory_offsets = {gemm_repack->get_offset_in(), gemm_repack->get_offset_out()};

// Initialize buffer IDs using the utils function
m_buffer_ids = {utils::get_buffer_cluster_id(expr->get_input_port(0)),
utils::get_buffer_cluster_id(expr->get_output_port(0))};
}

std::set<std::vector<element::Type>> jit_gemm_copy_b_emitter::get_supported_precisions(
@@ -64,6 +73,8 @@ std::set<std::vector<element::Type>> jit_gemm_copy_b_emitter::get_supported_precisions(
void jit_gemm_copy_b_emitter::validate_arguments(const std::vector<size_t>& in, const std::vector<size_t>& out) const {
OV_CPU_JIT_EMITTER_ASSERT(in.size() == 1, "Expects 1 input reg, got", in.size());
OV_CPU_JIT_EMITTER_ASSERT(out.size() == 1, "Expects 1 output reg, got", out.size());
OV_CPU_JIT_EMITTER_ASSERT(m_memory_offsets.size() == 2, "Expected 2 memory offsets for input and output");
OV_CPU_JIT_EMITTER_ASSERT(m_buffer_ids.size() == 2, "Expected 2 buffer IDs for input and output");
}

void jit_gemm_copy_b_emitter::emit_impl(const std::vector<size_t>& in, const std::vector<size_t>& out) const {
@@ -75,10 +86,39 @@ void jit_gemm_copy_b_emitter::emit_impl(const std::vector<size_t>& in, const std::vector<size_t>& out) const {
Xbyak_aarch64::XReg x0(0);
Xbyak_aarch64::XReg x1(1);
Xbyak_aarch64::XReg x2(2);
h->str(Xbyak_aarch64::XReg(in[0]), pre_ptr(h->sp, -get_vec_length()));
h->str(Xbyak_aarch64::XReg(out[0]), pre_ptr(h->sp, -get_vec_length()));
h->ldr(x2, post_ptr(h->sp, get_vec_length()));
h->ldr(x1, post_ptr(h->sp, get_vec_length()));
Xbyak_aarch64::XReg aux_reg(3);

// Prepare memory pointers with offsets
std::vector<size_t> mem_ptrs_idxs{in[0], out[0]};
const auto& mem_ptrs = utils::transform_idxs_to_regs(mem_ptrs_idxs);

// Apply memory offsets to input/output pointers
// Reserve space on stack for input/output pointers
h->sub(h->sp, h->sp, 2 * get_gpr_length());

// Apply offsets and store pointers on stack
for (size_t i = 0; i < mem_ptrs.size(); i++) {
const auto& ptr_reg = mem_ptrs[i];
int32_t stack_offset = i * get_gpr_length();

if (ov::snippets::utils::is_dynamic_value(m_memory_offsets[i])) {
// Dynamic offset: read from runtime parameters
size_t runtime_offset = GET_OFF(buffer_offsets) + m_buffer_ids[i] * sizeof(size_t);
utils::push_ptr_with_runtime_offset_on_stack(h, stack_offset, ptr_reg, aux_reg, runtime_offset);
} else {
// Static offset: add compile-time constant
utils::push_ptr_with_static_offset_on_stack(h, stack_offset, ptr_reg, m_memory_offsets[i]);
}
}

// Load back the adjusted pointers for function call
h->ldr(x1, Xbyak_aarch64::ptr(h->sp)); // input pointer
h->ldr(x2, Xbyak_aarch64::ptr(h->sp, get_gpr_length())); // output pointer

// Restore stack pointer
h->add(h->sp, h->sp, 2 * get_gpr_length());

// Set up executor pointer as first argument
const auto& compiled_kernel = get_compiled_kernel_ptr();
h->mov(x0, compiled_kernel);
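For reference, the setup above follows the AAPCS64 convention used by these emitters: x0 carries the compiled executor pointer, while x1 and x2 hold the offset-adjusted input and output pointers. A minimal sketch of the callee shape this appears to target, assuming a static execute entry point (the actual call sequence is below the fold of this diff, so the exact signature is an assumption):

// Assumed entry point invoked with (x0, x1, x2); names are illustrative.
static void execute(GemmCopyBKaiKernelExecutor* executor,  // x0: compiled kernel executor
                    void* in0,                             // x1: source B matrix
                    void* out0);                           // x2: repacked output buffer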

@@ -30,6 +30,8 @@ class jit_gemm_copy_b_emitter : public jit_emitter {
const uintptr_t get_compiled_kernel_ptr() const;

std::shared_ptr<GemmCopyBKaiKernelExecutor> m_kernel_executor = nullptr;
std::vector<size_t> m_memory_offsets;
std::vector<size_t> m_buffer_ids;
};

} // namespace ov::intel_cpu::aarch64
@@ -17,11 +17,15 @@
#include <vector>

#include "emitters/snippets/aarch64/kernel_executors/gemm.hpp"
#include "emitters/snippets/aarch64/utils.hpp"
#include "emitters/snippets/jit_snippets_call_args.hpp"
#include "emitters/utils.hpp"
#include "openvino/core/node.hpp"
#include "openvino/core/type/element_type.hpp"
#include "snippets/kernel_executor_table.hpp"
#include "snippets/lowered/expression.hpp"
#include "snippets/utils/utils.hpp"
#include "transformations/snippets/aarch64/op/gemm_cpu.hpp"

using namespace Xbyak_aarch64;

@@ -39,6 +43,17 @@ jit_gemm_emitter::jit_gemm_emitter(jit_generator* h,
in_out_type_ = emitter_in_out_map::gpr_to_gpr;
GemmKernelKaiConfig kernel_config;
m_kernel_executor_kai = kernel_table->register_kernel<GemmKaiKernelExecutor>(expr, kernel_config);

const auto gemm_node = as_type_ptr<GemmCPU>(expr->get_node());
OV_CPU_JIT_EMITTER_ASSERT(gemm_node, "Expected GemmCPU node");

// Initialize memory offsets, mirroring the x64 brgemm implementation
m_memory_offsets = {gemm_node->get_offset_a(), gemm_node->get_offset_b(), gemm_node->get_offset_c()};

// Initialize buffer IDs using the utils function
m_buffer_ids = {utils::get_buffer_cluster_id(expr->get_input_port(0)),
utils::get_buffer_cluster_id(expr->get_input_port(1)),
utils::get_buffer_cluster_id(expr->get_output_port(0))};
}

std::set<std::vector<element::Type>> jit_gemm_emitter::get_supported_precisions(
@@ -50,6 +65,8 @@ std::set<std::vector<element::Type>> jit_gemm_emitter::get_supported_precisions(
void jit_gemm_emitter::validate_arguments(const std::vector<size_t>& in, const std::vector<size_t>& out) const {
OV_CPU_JIT_EMITTER_ASSERT(in.size() == 2, "Expects 2 input regs, got", in.size());
OV_CPU_JIT_EMITTER_ASSERT(out.size() == 1, "Expects 1 output reg, got", out.size());
OV_CPU_JIT_EMITTER_ASSERT(m_memory_offsets.size() == 3, "Expected 3 memory offsets for A, B, C");
OV_CPU_JIT_EMITTER_ASSERT(m_buffer_ids.size() == 3, "Expected 3 buffer IDs for A, B, C");
}

void jit_gemm_emitter::emit_impl(const std::vector<size_t>& in, const std::vector<size_t>& out) const {
@@ -62,12 +79,40 @@ void jit_gemm_emitter::emit_impl(const std::vector<size_t>& in, const std::vector<size_t>& out) const {
Xbyak_aarch64::XReg x1(1);
Xbyak_aarch64::XReg x2(2);
Xbyak_aarch64::XReg x3(3);
h->str(Xbyak_aarch64::XReg(in[0]), pre_ptr(h->sp, -get_vec_length()));
h->str(Xbyak_aarch64::XReg(in[1]), pre_ptr(h->sp, -get_vec_length()));
h->str(Xbyak_aarch64::XReg(out[0]), pre_ptr(h->sp, -get_vec_length()));
h->ldr(x3, post_ptr(h->sp, get_vec_length()));
h->ldr(x2, post_ptr(h->sp, get_vec_length()));
h->ldr(x1, post_ptr(h->sp, get_vec_length()));
Xbyak_aarch64::XReg aux_reg(5);

// Prepare memory pointers with offsets
std::vector<size_t> mem_ptrs_idxs{in[0], in[1], out[0]};
const auto& mem_ptrs = utils::transform_idxs_to_regs(mem_ptrs_idxs);

// Apply memory offsets to input/output pointers
// Reserve stack space for the three matrix pointers
h->sub(h->sp, h->sp, 3 * get_gpr_length());

// Apply offsets and store pointers on stack
for (size_t i = 0; i < mem_ptrs.size(); i++) {
const auto& ptr_reg = mem_ptrs[i];
int32_t stack_offset = i * get_gpr_length();

if (ov::snippets::utils::is_dynamic_value(m_memory_offsets[i])) {
// Dynamic offset: read from runtime parameters
size_t runtime_offset = GET_OFF(buffer_offsets) + m_buffer_ids[i] * sizeof(size_t);
utils::push_ptr_with_runtime_offset_on_stack(h, stack_offset, ptr_reg, aux_reg, runtime_offset);
} else {
// Static offset: add compile-time constant
utils::push_ptr_with_static_offset_on_stack(h, stack_offset, ptr_reg, m_memory_offsets[i]);
}
}

// Load back the adjusted pointers for function call
h->ldr(x1, Xbyak_aarch64::ptr(h->sp)); // matrix A (in0)
h->ldr(x2, Xbyak_aarch64::ptr(h->sp, get_gpr_length())); // matrix B (in1)
h->ldr(x3, Xbyak_aarch64::ptr(h->sp, 2 * get_gpr_length())); // matrix C (out)

// Restore stack pointer
h->add(h->sp, h->sp, 3 * get_gpr_length());

// Set up executor pointer as first argument
const auto& compiled_kernel = get_compiled_kernel_ptr();
h->mov(x0, compiled_kernel);

@@ -32,6 +32,8 @@ class jit_gemm_emitter : public jit_emitter {
const uintptr_t get_compiled_kernel_ptr() const;

std::shared_ptr<GemmKaiKernelExecutor> m_kernel_executor_kai = nullptr;
std::vector<size_t> m_memory_offsets;
std::vector<size_t> m_buffer_ids;
};

} // namespace ov::intel_cpu::aarch64
150 changes: 150 additions & 0 deletions src/plugins/intel_cpu/src/emitters/snippets/aarch64/utils.cpp
@@ -0,0 +1,150 @@
// Copyright (C) 2025 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//

#include "utils.hpp"

#include <algorithm>
#include <common/utils.hpp>
#include <cstddef>
#include <cstdint>
#include <memory>
#include <set>
#include <unordered_set>
#include <vector>

#include "emitters/utils.hpp"
#include "openvino/core/except.hpp"
#include "openvino/core/type.hpp"
#include "snippets/emitter.hpp"
#include "snippets/lowered/expression_port.hpp"
#include "snippets/lowered/expressions/buffer_expression.hpp"
#include "snippets/op/loop.hpp"
#include "snippets/op/memory_access.hpp"
#include "snippets/utils/utils.hpp"

using namespace dnnl::impl::cpu::aarch64;

namespace ov::intel_cpu::aarch64::utils {

size_t get_buffer_cluster_id(const ov::snippets::lowered::ExpressionPort& port) {
Review comment (Contributor):
If the aarch64 helper and the x64 helper are the same, how about moving them into one common file, e.g. emitters/snippets/utils/util.*? These helpers don't depend on the architecture, so they can be reused on all platforms.

Reply (Contributor Author):
Done

auto get_cluster_id = [](const snippets::lowered::ExpressionPort& p) {
const auto buffer = ov::as_type_ptr<ov::snippets::lowered::BufferExpression>(p.get_expr());
return buffer ? buffer->get_cluster_id() : SIZE_MAX;
};
const auto& ma_op = std::dynamic_pointer_cast<ov::snippets::modifier::MemoryAccess>(port.get_expr()->get_node());
OPENVINO_ASSERT(ma_op, "Expected MemoryAccess op!");
auto offset = ov::snippets::utils::get_dynamic_value<size_t>();
size_t id = SIZE_MAX;
switch (port.get_type()) {
case ov::snippets::lowered::ExpressionPort::Type::Input:
offset = ma_op->get_input_offset(port.get_index());
id = get_cluster_id(port.get_port_connector_ptr()->get_source());
break;
case ov::snippets::lowered::ExpressionPort::Type::Output:
offset = ma_op->get_output_offset(port.get_index());
for (const auto& child : port.get_connected_ports()) {
if (!ov::is_type<snippets::op::LoopEnd>(child.get_expr()->get_node())) {
id = get_cluster_id(child);
}
}
break;
default:
OV_CPU_JIT_EMITTER_THROW("Uknown type of expression port!");
}
OV_CPU_JIT_EMITTER_ASSERT(IMPLICATION(ov::snippets::utils::is_dynamic_value(offset), id != SIZE_MAX),
"In dynamic case Buffer Cluster ID must be known!");
return id;
}
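As the review thread above notes, this helper ended up shared across architectures; a minimal sketch of the common declaration, assuming the reviewer's proposed location emitters/snippets/utils/util.* (path and namespace are assumptions, not confirmed by this diff):

// emitters/snippets/utils/utils.hpp (hypothetical shared location)
#pragma once

#include <cstddef>

#include "snippets/lowered/expression_port.hpp"

namespace ov::intel_cpu::utils {

// Arch-independent: resolves the cluster id of the Buffer connected to the
// given expression port, or SIZE_MAX when no Buffer neighbor exists.
size_t get_buffer_cluster_id(const ov::snippets::lowered::ExpressionPort& port);

}  // namespace ov::intel_cpu::utils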

Xbyak_aarch64::XReg get_aux_gpr(const std::vector<size_t>& used_gpr_idxs) {
// SP - stack pointer should be preserved, X0 and X1 - runtime parameter registers in the kernel
// X18 - platform register should not be used
static std::unordered_set<size_t> blacklist_gpr_idxs = {
31, // Stack pointer (SP)
0, // abi_param1 (X0)
1, // abi_param2 (X1)
18 // Platform register (X18)
};

// Iterate through available GPR registers (X0-X30, excluding X31 which is SP)
for (size_t gpr_idx = 0; gpr_idx <= 30; ++gpr_idx) {
size_t _idx = 30 - gpr_idx; // we allocate from the end
if (std::find(used_gpr_idxs.cbegin(), used_gpr_idxs.cend(), _idx) != used_gpr_idxs.cend()) {
continue;
}
if (blacklist_gpr_idxs.count(_idx) > 0) {
continue;
}
return Xbyak_aarch64::XReg(_idx);
}
OV_CPU_JIT_EMITTER_THROW("Failed to allocate aux GPR");
}

Xbyak_aarch64::XReg init_memory_access_aux_gpr(const std::vector<size_t>& used_gpr_reg_idxs,
const std::vector<size_t>& aux_gpr_idxs,
std::set<snippets::Reg>& regs_to_spill) {
if (!aux_gpr_idxs.empty()) {
return Xbyak_aarch64::XReg(static_cast<int>(aux_gpr_idxs[0]));
}
const auto aux_reg = ov::intel_cpu::aarch64::utils::get_aux_gpr(used_gpr_reg_idxs);
regs_to_spill.emplace(snippets::RegType::gpr, aux_reg.getIdx());
return aux_reg;
}

void push_ptr_with_runtime_offset_on_stack(dnnl::impl::cpu::aarch64::jit_generator* h,
int32_t stack_offset,
const Xbyak_aarch64::XReg& ptr_reg,
const Xbyak_aarch64::XReg& aux_reg,
size_t runtime_offset) {
// Copy pointer to aux register
h->mov(aux_reg, ptr_reg);

// Load the runtime offset from abi_param1 (X0) and add it to the pointer
Xbyak_aarch64::XReg abi_param1(0);
Xbyak_aarch64::XReg offset_reg(4);
Review comment (Contributor):
I still see a magic XReg(4) here. Please avoid initializing registers with hard-coded numbers.

Reply (Contributor Author):
Done


// Handle large runtime offsets by using a temporary register
if (runtime_offset > 4095) {
Xbyak_aarch64::XReg temp_offset_reg(6);
h->mov(temp_offset_reg, static_cast<uint64_t>(runtime_offset));
h->add(temp_offset_reg, abi_param1, temp_offset_reg);
h->ldr(offset_reg, Xbyak_aarch64::ptr(temp_offset_reg));
} else {
h->ldr(offset_reg, Xbyak_aarch64::ptr(abi_param1, static_cast<int32_t>(runtime_offset)));
}

h->add(aux_reg, aux_reg, offset_reg);

// Store the adjusted pointer on stack
h->str(aux_reg, Xbyak_aarch64::ptr(h->sp, stack_offset));
}
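Following the review note about the magic XReg(4) above, one fix is to let the caller supply the scratch register as well, as is already done for aux_reg. A hedged sketch of that variant (signature and register handling are illustrative, not the merged code):

// Hypothetical variant: offset_reg is chosen by the caller (e.g. via get_aux_gpr),
// so the helper never silently clobbers a live register.
void push_ptr_with_runtime_offset_on_stack(dnnl::impl::cpu::aarch64::jit_generator* h,
                                           int32_t stack_offset,
                                           const Xbyak_aarch64::XReg& ptr_reg,
                                           const Xbyak_aarch64::XReg& aux_reg,
                                           const Xbyak_aarch64::XReg& offset_reg,
                                           size_t runtime_offset) {
    Xbyak_aarch64::XReg abi_param1(0);
    h->mov(aux_reg, ptr_reg);
    if (runtime_offset > 4095) {
        // Materialize the large offset, form the absolute address, then load;
        // offset_reg doubles as the address temporary.
        h->mov(offset_reg, static_cast<uint64_t>(runtime_offset));
        h->add(offset_reg, abi_param1, offset_reg);
        h->ldr(offset_reg, Xbyak_aarch64::ptr(offset_reg));
    } else {
        h->ldr(offset_reg, Xbyak_aarch64::ptr(abi_param1, static_cast<int32_t>(runtime_offset)));
    }
    h->add(aux_reg, aux_reg, offset_reg);
    h->str(aux_reg, Xbyak_aarch64::ptr(h->sp, stack_offset));
}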

void push_ptr_with_static_offset_on_stack(dnnl::impl::cpu::aarch64::jit_generator* h,
int32_t stack_offset,
const Xbyak_aarch64::XReg& ptr_reg,
size_t ptr_offset) {
// If there's no static offset, just store the pointer
if (ptr_offset == 0) {
h->str(ptr_reg, Xbyak_aarch64::ptr(h->sp, stack_offset));
return;
}

// For non-zero offsets, apply the offset and then store
Xbyak_aarch64::XReg temp_reg(4);
Review comment (Contributor):
What is the magic number 4? What happens if ptr_reg is also XReg(4)? We could corrupt the initial value in XReg(4).
To solve this, can we pass temp_reg as an argument of the helper, as is done in push_ptr_with_runtime_offset_on_stack?

Reply (Contributor Author):
Fixed

h->mov(temp_reg, ptr_reg);

// For large offsets, use a register to hold the offset value
if (ptr_offset > 4095) { // 12-bit immediate limit for add instruction
Xbyak_aarch64::XReg offset_reg(6);
Review comment (Contributor):
Same comment: we should avoid initializing a register with an explicit number in any helper, since helpers may be called from arbitrary code. This is unsafe if XReg(6) is already a live register when push_ptr_with_static_offset_on_stack is called. Please avoid initializing registers with explicit numbers in helpers.

Review comment (Contributor):
Or could we just use TMP_X_0 here, for example?

Reply (Contributor Author):
Fixed

h->mov(offset_reg, static_cast<uint64_t>(ptr_offset));
h->add(temp_reg, temp_reg, offset_reg);
} else {
h->add(temp_reg, temp_reg, static_cast<int32_t>(ptr_offset));
}

// Store the adjusted pointer on stack
h->str(temp_reg, Xbyak_aarch64::ptr(h->sp, stack_offset));
}
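Per the two threads above, the scratch register can likewise be passed in by the caller (or the generator's TMP_X_0 could be used, as the reviewer suggests). A hedged sketch of the reworked helper (illustrative only):

// Hypothetical variant mirroring push_ptr_with_runtime_offset_on_stack:
// temp_reg is a caller-provided scratch register.
void push_ptr_with_static_offset_on_stack(dnnl::impl::cpu::aarch64::jit_generator* h,
                                          int32_t stack_offset,
                                          const Xbyak_aarch64::XReg& ptr_reg,
                                          const Xbyak_aarch64::XReg& temp_reg,
                                          size_t ptr_offset) {
    // No offset: store the pointer as-is.
    if (ptr_offset == 0) {
        h->str(ptr_reg, Xbyak_aarch64::ptr(h->sp, stack_offset));
        return;
    }
    if (ptr_offset > 4095) {  // beyond the 12-bit immediate of add
        h->mov(temp_reg, static_cast<uint64_t>(ptr_offset));
        h->add(temp_reg, ptr_reg, temp_reg);
    } else {
        h->add(temp_reg, ptr_reg, static_cast<int32_t>(ptr_offset));
    }
    h->str(temp_reg, Xbyak_aarch64::ptr(h->sp, stack_offset));
}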

} // namespace ov::intel_cpu::aarch64::utils