jit_emitter.hpp
@@ -70,6 +70,8 @@ class jit_emitter : public ov::snippets::Emitter {
     static std::set<std::vector<element::Type>> get_supported_precisions(
         const std::shared_ptr<ov::Node>& node = nullptr);

+    static constexpr int sp_alignment = 16;
+
 protected:
     static size_t get_max_vecs_count();
     static int32_t get_vec_length();
@@ -155,6 +157,10 @@ class jit_emitter : public ov::snippets::Emitter {
         }
     }

+    static int32_t get_gpr_length() {
+        return 8;
+    }
+
 private:
     mutable std::vector<size_t> preserved_vec_idxs;
     mutable std::vector<size_t> preserved_gpr_idxs;
@@ -179,10 +185,6 @@
         return 32;
     }

-    int32_t get_gpr_length() const {
-        return h->x0.getBit() / 8;
-    }
-
     void store_context(const std::vector<size_t>& gpr_regs,
                        const std::vector<size_t>& vec_regs,
                        const std::unordered_set<size_t>& ignore_vec_regs = {}) const;
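The removed member queried the generator at runtime (h->x0.getBit() / 8), but on AArch64 the answer never varies: every X register is 64-bit, so a GPR spill slot is 8 bytes, and making the helper static lets free functions such as the new utils.cpp use it without an emitter instance. The companion constant sp_alignment = 16 encodes the AAPCS64 rule that SP must stay 16-byte aligned; the old emit_impl code sidestepped this by bumping SP a full get_vec_length() (16 bytes) per pointer, while the new helpers pack 8-byte slots into one aligned block. A minimal standalone sketch of the slot arithmetic these two constants drive (illustrative only, not part of the PR; stack_bytes stands in for the rnd_up helper from utils/general_utils.h):

#include <cstddef>

constexpr size_t gpr_length = 8;     // X registers are 64-bit -> 8-byte spill slots
constexpr size_t sp_alignment = 16;  // AAPCS64: SP must stay 16-byte aligned

// Same rounding rnd_up() performs when reserving stack space for spilled pointers.
constexpr size_t stack_bytes(size_t num_ptrs) {
    return (num_ptrs * gpr_length + sp_alignment - 1) / sp_alignment * sp_alignment;
}

static_assert(stack_bytes(2) == 16, "two pointers fill exactly one aligned block");
static_assert(stack_bytes(3) == 32, "three pointers round up from 24 to 32 bytes");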
jit_gemm_copy_b_emitter.cpp
@@ -4,7 +4,6 @@

#include "jit_gemm_copy_b_emitter.hpp"

#include <xbyak_aarch64/xbyak_aarch64/xbyak_aarch64_adr.h>
#include <xbyak_aarch64/xbyak_aarch64/xbyak_aarch64_reg.h>

#include <cpu/aarch64/cpu_isa_traits.hpp>
@@ -17,6 +16,8 @@
 #include <vector>

 #include "emitters/snippets/aarch64/kernel_executors/gemm_copy_b.hpp"
+#include "emitters/snippets/aarch64/utils.hpp"
+#include "emitters/snippets/utils/utils.hpp"
 #include "emitters/utils.hpp"
 #include "openvino/core/node.hpp"
 #include "openvino/core/type.hpp"
@@ -53,6 +54,13 @@ jit_gemm_copy_b_emitter::jit_gemm_copy_b_emitter(jit_generator* h,
     OV_CPU_JIT_EMITTER_ASSERT(n_blk_size > 0, "n_blk_size of gemm_repack is expected to be greater than 0.");
     GemmCopyBKernelKaiConfig kernel_config(n_blk_size);
     m_kernel_executor = kernel_table->register_kernel<GemmCopyBKaiKernelExecutor>(expr, kernel_config);
+
+    // Initialize memory offsets similar to x64 brgemm_copy_b implementation
+    m_memory_offsets = {gemm_repack->get_offset_in(), gemm_repack->get_offset_out()};
+
+    // Initialize buffer IDs using the utils function
+    m_buffer_ids = {ov::intel_cpu::utils::get_buffer_cluster_id(expr->get_input_port(0)),
+                    ov::intel_cpu::utils::get_buffer_cluster_id(expr->get_output_port(0))};
 }

 std::set<std::vector<element::Type>> jit_gemm_copy_b_emitter::get_supported_precisions(
@@ -64,6 +72,8 @@ std::set<std::vector<element::Type>> jit_gemm_copy_b_emitter::get_supported_precisions(
 void jit_gemm_copy_b_emitter::validate_arguments(const std::vector<size_t>& in, const std::vector<size_t>& out) const {
     OV_CPU_JIT_EMITTER_ASSERT(in.size() == 1, "Expects 1 input reg, got", in.size());
     OV_CPU_JIT_EMITTER_ASSERT(out.size() == 1, "Expects 1 output reg, got", out.size());
+    OV_CPU_JIT_EMITTER_ASSERT(m_memory_offsets.size() == 2, "Expected 2 memory offsets for input and output");
+    OV_CPU_JIT_EMITTER_ASSERT(m_buffer_ids.size() == 2, "Expected 2 buffer IDs for input and output");
 }

 void jit_gemm_copy_b_emitter::emit_impl(const std::vector<size_t>& in, const std::vector<size_t>& out) const {
@@ -75,10 +85,29 @@ void jit_gemm_copy_b_emitter::emit_impl(const std::vector<size_t>& in, const std::vector<size_t>& out) const {
     Xbyak_aarch64::XReg x0(0);
     Xbyak_aarch64::XReg x1(1);
     Xbyak_aarch64::XReg x2(2);
-    h->str(Xbyak_aarch64::XReg(in[0]), pre_ptr(h->sp, -get_vec_length()));
-    h->str(Xbyak_aarch64::XReg(out[0]), pre_ptr(h->sp, -get_vec_length()));
-    h->ldr(x2, post_ptr(h->sp, get_vec_length()));
-    h->ldr(x1, post_ptr(h->sp, get_vec_length()));
+    Xbyak_aarch64::XReg aux_reg(3);
+
+    // Prepare memory pointers with offsets
+    std::vector<size_t> mem_ptrs_idxs{in[0], out[0]};
+    const auto& mem_ptrs = utils::transform_idxs_to_regs(mem_ptrs_idxs);
+
+    // Apply memory offsets and load adjusted pointers
+    std::vector<Xbyak_aarch64::XReg> load_regs{x1, x2};
+
+    // Dynamically choose safe auxiliary registers that don't conflict with mem_ptrs or load_regs
+    std::vector<size_t> used_indices;
+    used_indices.reserve(mem_ptrs.size());
+    for (const auto& reg : mem_ptrs) {
+        used_indices.push_back(reg.getIdx());
+    }
+    for (const auto& reg : load_regs) {
+        used_indices.push_back(reg.getIdx());
+    }
+    std::vector<Xbyak_aarch64::XReg> aux_regs = utils::get_aux_gprs(used_indices);
+
+    utils::push_and_load_ptrs_with_offsets(h, mem_ptrs, m_memory_offsets, m_buffer_ids, aux_regs, load_regs);
+
+    // Set up executor pointer as first argument
     const auto& compiled_kernel = get_compiled_kernel_ptr();
     h->mov(x0, compiled_kernel);

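The rewritten emit_impl no longer spills through pre_ptr/post_ptr pushes; it delegates to push_and_load_ptrs_with_offsets, after which x1 and x2 hold the offset-adjusted source and destination pointers and x0 is loaded with the executor object. That register layout implies a plain three-argument entry point; a hedged sketch of its shape (the real signature lives in gemm_copy_b.hpp and may differ):

// Assumed from the register setup above: x0 = executor, x1 = in[0] + offset,
// x2 = out[0] + offset. Hypothetical declaration, for illustration only.
extern "C" void gemm_copy_b_kernel_entry(void* executor, const void* src, void* dst);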
jit_gemm_copy_b_emitter.hpp
@@ -30,6 +32,8 @@ class jit_gemm_copy_b_emitter : public jit_emitter {
     uintptr_t get_compiled_kernel_ptr() const;

     std::shared_ptr<GemmCopyBKaiKernelExecutor> m_kernel_executor = nullptr;
+    std::vector<size_t> m_memory_offsets;
+    std::vector<size_t> m_buffer_ids;
 };

 }  // namespace ov::intel_cpu::aarch64
jit_gemm_emitter.cpp
@@ -4,7 +4,6 @@

#include "jit_gemm_emitter.hpp"

#include <xbyak_aarch64/xbyak_aarch64/xbyak_aarch64_adr.h>
#include <xbyak_aarch64/xbyak_aarch64/xbyak_aarch64_reg.h>

#include <cpu/aarch64/cpu_isa_traits.hpp>
@@ -17,11 +16,15 @@
 #include <vector>

 #include "emitters/snippets/aarch64/kernel_executors/gemm.hpp"
+#include "emitters/snippets/aarch64/utils.hpp"
+#include "emitters/snippets/utils/utils.hpp"
 #include "emitters/utils.hpp"
 #include "openvino/core/node.hpp"
 #include "openvino/core/type.hpp"
 #include "openvino/core/type/element_type.hpp"
 #include "snippets/kernel_executor_table.hpp"
 #include "snippets/lowered/expression.hpp"
+#include "transformations/snippets/aarch64/op/gemm_cpu.hpp"

 using namespace Xbyak_aarch64;

@@ -39,6 +42,17 @@ jit_gemm_emitter::jit_gemm_emitter(jit_generator* h,
     in_out_type_ = emitter_in_out_map::gpr_to_gpr;
     GemmKernelKaiConfig kernel_config;
     m_kernel_executor_kai = kernel_table->register_kernel<GemmKaiKernelExecutor>(expr, kernel_config);
+
+    const auto gemm_node = as_type_ptr<GemmCPU>(expr->get_node());
+    OV_CPU_JIT_EMITTER_ASSERT(gemm_node, "Expected GemmCPU node");
+
+    // Initialize memory offsets similar to x64 brgemm implementation
+    m_memory_offsets = {gemm_node->get_offset_a(), gemm_node->get_offset_b(), gemm_node->get_offset_c()};
+
+    // Initialize buffer IDs using the utils function
+    m_buffer_ids = {ov::intel_cpu::utils::get_buffer_cluster_id(expr->get_input_port(0)),
+                    ov::intel_cpu::utils::get_buffer_cluster_id(expr->get_input_port(1)),
+                    ov::intel_cpu::utils::get_buffer_cluster_id(expr->get_output_port(0))};
 }

 std::set<std::vector<element::Type>> jit_gemm_emitter::get_supported_precisions(
@@ -50,6 +64,8 @@ std::set<std::vector<element::Type>> jit_gemm_emitter::get_supported_precisions(
 void jit_gemm_emitter::validate_arguments(const std::vector<size_t>& in, const std::vector<size_t>& out) const {
     OV_CPU_JIT_EMITTER_ASSERT(in.size() == 2, "Expects 2 input regs, got", in.size());
     OV_CPU_JIT_EMITTER_ASSERT(out.size() == 1, "Expects 1 output reg, got", out.size());
+    OV_CPU_JIT_EMITTER_ASSERT(m_memory_offsets.size() == 3, "Expected 3 memory offsets for A, B, C");
+    OV_CPU_JIT_EMITTER_ASSERT(m_buffer_ids.size() == 3, "Expected 3 buffer IDs for A, B, C");
 }

 void jit_gemm_emitter::emit_impl(const std::vector<size_t>& in, const std::vector<size_t>& out) const {
@@ -62,12 +78,29 @@ void jit_gemm_emitter::emit_impl(const std::vector<size_t>& in, const std::vector<size_t>& out) const {
     Xbyak_aarch64::XReg x1(1);
     Xbyak_aarch64::XReg x2(2);
     Xbyak_aarch64::XReg x3(3);
-    h->str(Xbyak_aarch64::XReg(in[0]), pre_ptr(h->sp, -get_vec_length()));
-    h->str(Xbyak_aarch64::XReg(in[1]), pre_ptr(h->sp, -get_vec_length()));
-    h->str(Xbyak_aarch64::XReg(out[0]), pre_ptr(h->sp, -get_vec_length()));
-    h->ldr(x3, post_ptr(h->sp, get_vec_length()));
-    h->ldr(x2, post_ptr(h->sp, get_vec_length()));
-    h->ldr(x1, post_ptr(h->sp, get_vec_length()));
+    Xbyak_aarch64::XReg aux_reg(5);
+
+    // Prepare memory pointers with offsets
+    std::vector<size_t> mem_ptrs_idxs{in[0], in[1], out[0]};
+    const auto& mem_ptrs = utils::transform_idxs_to_regs(mem_ptrs_idxs);
+
+    // Apply memory offsets and load adjusted pointers
+    std::vector<Xbyak_aarch64::XReg> load_regs{x1, x2, x3};
+
+    // Dynamically choose safe auxiliary registers that don't conflict with mem_ptrs or load_regs
+    std::vector<size_t> used_indices;
+    used_indices.reserve(mem_ptrs.size());
+    for (const auto& reg : mem_ptrs) {
+        used_indices.push_back(reg.getIdx());
+    }
+    for (const auto& reg : load_regs) {
+        used_indices.push_back(reg.getIdx());
+    }
+    std::vector<Xbyak_aarch64::XReg> aux_regs = utils::get_aux_gprs(used_indices);
+
+    utils::push_and_load_ptrs_with_offsets(h, mem_ptrs, m_memory_offsets, m_buffer_ids, aux_regs, load_regs);
+
+    // Set up executor pointer as first argument
     const auto& compiled_kernel = get_compiled_kernel_ptr();
     h->mov(x0, compiled_kernel);

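It is worth tracing how the auxiliary-register selection behaves here. Suppose (hypothetically) the register allocator assigned in = {X11, X12} and out = {X13}: used_indices then holds {11, 12, 13, 1, 2, 3}, and the scan in utils.cpp walks down from X30, skipping used and blacklisted indices (X0, X1, X18, SP), so it would hand back {X30, X29, X28} — assuming utils.hpp defaults count to the three scratch registers push_ptr_with_runtime_offset_on_stack needs. A toy restatement of that scan (indices only, not part of the PR):

#include <cstddef>
#include <cstdint>
#include <unordered_set>
#include <vector>

// Toy model of get_aux_gprs' downward scan, for intuition only.
std::vector<size_t> pick_aux(const std::vector<size_t>& used, size_t count) {
    const std::unordered_set<size_t> blocked = {0, 1, 18, 31};  // X0, X1, X18, SP
    const std::unordered_set<size_t> used_set(used.begin(), used.end());
    std::vector<size_t> picked;
    for (size_t idx = 30; idx != SIZE_MAX && picked.size() < count; --idx) {
        if (used_set.count(idx) == 0 && blocked.count(idx) == 0) {
            picked.push_back(idx);
        }
    }
    return picked;
}
// pick_aux({11, 12, 13, 1, 2, 3}, 3) -> {30, 29, 28}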
jit_gemm_emitter.hpp
@@ -32,6 +32,8 @@ class jit_gemm_emitter : public jit_emitter {
     uintptr_t get_compiled_kernel_ptr() const;

     std::shared_ptr<GemmKaiKernelExecutor> m_kernel_executor_kai = nullptr;
+    std::vector<size_t> m_memory_offsets;
+    std::vector<size_t> m_buffer_ids;
 };

 }  // namespace ov::intel_cpu::aarch64
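A note on the two new members: m_memory_offsets stores the byte offset of each operand inside its buffer, and each entry is either a compile-time constant or, when the offset is only known at runtime, the snippets dynamic-value sentinel that ov::snippets::utils::is_dynamic_value() recognizes; m_buffer_ids then supplies the index into the runtime buffer_offsets table for exactly those dynamic entries. Conceptually (hypothetical values, sketch only — the sentinel helper name is an assumption):

// First operand: static 128-byte offset, resolved at compile time.
// Second operand: unknown until runtime -> sentinel; the emitter will read
// buffer_offsets[m_buffer_ids[1]] from the runtime call args instead.
m_memory_offsets = {128, ov::snippets::utils::get_dynamic_value<size_t>()};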
src/plugins/intel_cpu/src/emitters/snippets/aarch64/utils.cpp (new file: 181 additions, 0 deletions)
@@ -0,0 +1,181 @@
// Copyright (C) 2025 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//

#include "utils.hpp"

#include <xbyak_aarch64/xbyak_aarch64/xbyak_aarch64_adr.h>
#include <xbyak_aarch64/xbyak_aarch64/xbyak_aarch64_reg.h>

#include <cpu/aarch64/jit_generator.hpp>
#include <cstddef>
#include <cstdint>
#include <set>
#include <unordered_set>
#include <vector>

#include "emitters/snippets/jit_snippets_call_args.hpp"
#include "emitters/utils.hpp"
#include "openvino/core/except.hpp"
#include "snippets/emitter.hpp"
#include "snippets/utils/utils.hpp"
#include "utils/general_utils.h"

using namespace dnnl::impl::cpu::aarch64;

namespace ov::intel_cpu::aarch64::utils {

std::vector<Xbyak_aarch64::XReg> get_aux_gprs(const std::vector<size_t>& used_gpr_idxs, size_t count) {
    // X0 and X1 - runtime parameter registers in the kernel
    // X18 - platform register should not be used
    // SP - stack pointer should be preserved
    static const std::unordered_set<size_t> blacklist_gpr_idxs = {
        0,   // abi_param1 (X0)
        1,   // abi_param2 (X1)
        18,  // Platform register (X18)
        31,  // Stack pointer (SP)
    };

    OPENVINO_ASSERT(count <= 32 - blacklist_gpr_idxs.size(),
                    "Cannot allocate more than ",
                    32 - blacklist_gpr_idxs.size(),
                    " auxiliary registers");

    // Convert used_gpr_idxs to unordered_set for O(1) lookups
    const std::unordered_set<size_t> used_set(used_gpr_idxs.begin(), used_gpr_idxs.end());

    std::vector<Xbyak_aarch64::XReg> aux_regs;
    aux_regs.reserve(count);

    // Iterate from X30 down to X0 (allocate from the end)
    for (size_t idx = 30; idx != SIZE_MAX; --idx) {
        if (used_set.count(idx) || blacklist_gpr_idxs.count(idx)) {
            continue;
        }
        aux_regs.emplace_back(idx);
        if (aux_regs.size() == count) {
            break;
        }
    }

    OPENVINO_ASSERT(aux_regs.size() == count, "Expected ", count, " auxiliary registers, but got ", aux_regs.size());
    return aux_regs;
}

Xbyak_aarch64::XReg get_aux_gpr(const std::vector<size_t>& used_gpr_idxs) {
    return get_aux_gprs(used_gpr_idxs, 1)[0];
}

Xbyak_aarch64::XReg init_memory_access_aux_gpr(const std::vector<size_t>& used_gpr_reg_idxs,
                                               const std::vector<size_t>& aux_gpr_idxs,
                                               std::set<snippets::Reg>& regs_to_spill) {
    if (!aux_gpr_idxs.empty()) {
        return Xbyak_aarch64::XReg(static_cast<int>(aux_gpr_idxs[0]));
    }
    const auto aux_reg = get_aux_gpr(used_gpr_reg_idxs);
    regs_to_spill.emplace(snippets::RegType::gpr, aux_reg.getIdx());
    return aux_reg;
}

void push_ptr_with_runtime_offset_on_stack(dnnl::impl::cpu::aarch64::jit_generator* h,
                                           int32_t stack_offset,
                                           const Xbyak_aarch64::XReg& ptr_reg,
                                           const std::vector<Xbyak_aarch64::XReg>& aux_regs,
                                           size_t runtime_offset) {
    // Safety assertions as suggested
    OV_CPU_JIT_EMITTER_ASSERT(aux_regs.size() >= 3, "aux_regs must contain at least 3 registers");

    // Assert that ptr_reg is not in aux_regs
    for (const auto& reg : aux_regs) {
        OV_CPU_JIT_EMITTER_ASSERT(reg.getIdx() != ptr_reg.getIdx(), "ptr_reg must not be in aux_regs");
    }

    // Use safe auxiliary registers from the provided set
    const Xbyak_aarch64::XReg aux_reg = aux_regs[0];   // For storing adjusted pointer
    const Xbyak_aarch64::XReg temp_reg = aux_regs[1];  // For temporary calculations
    const Xbyak_aarch64::XReg addr_reg = aux_regs[2];  // For address calculations in add_imm

    // Copy pointer to aux register
    h->mov(aux_reg, ptr_reg);

    // Load the runtime offset from abi_param1 (X0) and add it to the pointer
    Xbyak_aarch64::XReg abi_param1(0);

    // Load the offset value from the runtime parameter location
    h->add_imm(temp_reg, abi_param1, runtime_offset, addr_reg);
    h->ldr(temp_reg, Xbyak_aarch64::ptr(temp_reg));

    h->add(aux_reg, aux_reg, temp_reg);

    // Store the adjusted pointer on stack
    h->str(aux_reg, Xbyak_aarch64::ptr(h->sp, stack_offset));
}

void push_ptr_with_static_offset_on_stack(dnnl::impl::cpu::aarch64::jit_generator* h,
                                          int32_t stack_offset,
                                          const Xbyak_aarch64::XReg& ptr_reg,
                                          const std::vector<Xbyak_aarch64::XReg>& aux_regs,
                                          size_t ptr_offset) {
    // If there's no static offset, just store the pointer
    if (ptr_offset == 0) {
        h->str(ptr_reg, Xbyak_aarch64::ptr(h->sp, stack_offset));
        return;
    }

    // Safety assertions as suggested
    OV_CPU_JIT_EMITTER_ASSERT(aux_regs.size() >= 2, "aux_regs must contain at least 2 registers");

    // Assert that ptr_reg is not in aux_regs
    for (const auto& reg : aux_regs) {
        OV_CPU_JIT_EMITTER_ASSERT(reg.getIdx() != ptr_reg.getIdx(), "ptr_reg must not be in aux_regs");
    }

    // Use safe auxiliary registers from the provided vector
    const Xbyak_aarch64::XReg temp_reg = aux_regs[0];  // For storing adjusted pointer
    const Xbyak_aarch64::XReg addr_reg = aux_regs[1];  // For address calculations in add_imm

    // For non-zero offsets, apply the offset and then store
    h->add_imm(temp_reg, ptr_reg, ptr_offset, addr_reg);

    // Store the adjusted pointer on stack
    h->str(temp_reg, Xbyak_aarch64::ptr(h->sp, stack_offset));
}

void push_and_load_ptrs_with_offsets(dnnl::impl::cpu::aarch64::jit_generator* h,
                                     const std::vector<Xbyak_aarch64::XReg>& mem_ptrs,
                                     const std::vector<size_t>& memory_offsets,
                                     const std::vector<size_t>& buffer_ids,
                                     const std::vector<Xbyak_aarch64::XReg>& aux_regs,
                                     const std::vector<Xbyak_aarch64::XReg>& load_regs) {
    const size_t gpr_length = 8;     // 64-bit register length
    const size_t sp_alignment = 16;  // AArch64 stack alignment requirement

    // Allocate stack space for all pointers
    const auto sp_size = rnd_up(mem_ptrs.size() * gpr_length, sp_alignment);
    h->sub(h->sp, h->sp, sp_size);

    // Push all pointers with offsets onto stack
    for (size_t i = 0; i < mem_ptrs.size(); i++) {
        const auto& ptr_reg = mem_ptrs[i];
        int32_t stack_offset = i * gpr_length;

        if (ov::snippets::utils::is_dynamic_value(memory_offsets[i])) {
            // Dynamic offset: read from runtime parameters
            size_t runtime_offset = GET_OFF(buffer_offsets) + buffer_ids[i] * sizeof(size_t);
            push_ptr_with_runtime_offset_on_stack(h, stack_offset, ptr_reg, aux_regs, runtime_offset);
        } else {
            // Static offset: add compile-time constant
            push_ptr_with_static_offset_on_stack(h, stack_offset, ptr_reg, aux_regs, memory_offsets[i]);
        }
    }

    // Load back the adjusted pointers to specified registers
    for (size_t i = 0; i < load_regs.size() && i < mem_ptrs.size(); i++) {
        h->ldr(load_regs[i], Xbyak_aarch64::ptr(h->sp, static_cast<int32_t>(i * gpr_length)));
    }

    // Restore stack pointer
    h->add(h->sp, h->sp, sp_size);
}

}  // namespace ov::intel_cpu::aarch64::utils
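To make the stack discipline in push_and_load_ptrs_with_offsets concrete: with the three GEMM pointers, rnd_up(3 * 8, 16) reserves 32 bytes, each adjusted pointer is stored at sp + i * 8 (the last 8 bytes are alignment padding), the load_regs are filled back from those slots, and SP is restored before the kernel call. Sketched for the three-pointer case (illustrative trace, not PR code):

// After h->sub(h->sp, h->sp, 32):
//   sp + 0  : A + offset_a   -> later ldr x1, [sp, #0]
//   sp + 8  : B + offset_b   -> later ldr x2, [sp, #8]
//   sp + 16 : C + offset_c   -> later ldr x3, [sp, #16]
//   sp + 24 : (padding, keeps SP 16-byte aligned)
// h->add(h->sp, h->sp, 32) restores SP; the adjusted pointers now live in x1-x3.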