Skip to content

Commit cc4e8fb

Browse files
committed
[Snippets][CPU] Support static and dynamic offsets in JIT Gemm and GemmCopyB emitters
1 parent d9a8e93 commit cc4e8fb

File tree

7 files changed

+331
-11
lines changed

7 files changed

+331
-11
lines changed

src/plugins/intel_cpu/src/emitters/snippets/aarch64/jit_gemm_copy_b_emitter.cpp

Lines changed: 43 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,8 @@
1717
#include <vector>
1818

1919
#include "emitters/snippets/aarch64/kernel_executors/gemm_copy_b.hpp"
20+
#include "emitters/snippets/aarch64/utils.hpp"
21+
#include "emitters/snippets/jit_snippets_call_args.hpp"
2022
#include "emitters/utils.hpp"
2123
#include "openvino/core/node.hpp"
2224
#include "openvino/core/type.hpp"
@@ -53,6 +55,13 @@ jit_gemm_copy_b_emitter::jit_gemm_copy_b_emitter(jit_generator* h,
5355
OV_CPU_JIT_EMITTER_ASSERT(n_blk_size > 0, "n_blk_size of gemm_repack is expected to be greater than 0.");
5456
GemmCopyBKernelKaiConfig kernel_config(n_blk_size);
5557
m_kernel_executor = kernel_table->register_kernel<GemmCopyBKaiKernelExecutor>(expr, kernel_config);
58+
59+
// Initialize memory offsets similar to x64 brgemm_copy_b implementation
60+
m_memory_offsets = {gemm_repack->get_offset_in(), gemm_repack->get_offset_out()};
61+
62+
// Initialize buffer IDs using the utils function
63+
m_buffer_ids = {utils::get_buffer_cluster_id(expr->get_input_port(0)),
64+
utils::get_buffer_cluster_id(expr->get_output_port(0))};
5665
}
5766

5867
std::set<std::vector<element::Type>> jit_gemm_copy_b_emitter::get_supported_precisions(
@@ -64,6 +73,7 @@ std::set<std::vector<element::Type>> jit_gemm_copy_b_emitter::get_supported_prec
6473
void jit_gemm_copy_b_emitter::validate_arguments(const std::vector<size_t>& in, const std::vector<size_t>& out) const {
6574
OV_CPU_JIT_EMITTER_ASSERT(in.size() == 1, "Expects 1 input reg, got", in.size());
6675
OV_CPU_JIT_EMITTER_ASSERT(out.size() == 1, "Expects 1 output reg, got", out.size());
76+
OV_CPU_JIT_EMITTER_ASSERT(m_memory_offsets.size() == 2, "Expected 2 memory offsets for input and output");
6777
}
6878

6979
void jit_gemm_copy_b_emitter::emit_impl(const std::vector<size_t>& in, const std::vector<size_t>& out) const {
@@ -75,17 +85,46 @@ void jit_gemm_copy_b_emitter::emit_impl(const std::vector<size_t>& in, const std
7585
Xbyak_aarch64::XReg x0(0);
7686
Xbyak_aarch64::XReg x1(1);
7787
Xbyak_aarch64::XReg x2(2);
78-
h->str(Xbyak_aarch64::XReg(in[0]), pre_ptr(h->sp, -get_vec_length()));
79-
h->str(Xbyak_aarch64::XReg(out[0]), pre_ptr(h->sp, -get_vec_length()));
80-
h->ldr(x2, post_ptr(h->sp, get_vec_length()));
81-
h->ldr(x1, post_ptr(h->sp, get_vec_length()));
88+
Xbyak_aarch64::XReg aux_reg(3);
89+
90+
// Prepare memory pointers with offsets
91+
std::vector<size_t> mem_ptrs_idxs{in[0], out[0]};
92+
const auto& mem_ptrs = utils::transform_idxs_to_regs(mem_ptrs_idxs);
93+
94+
// Apply memory offsets to input/output pointers
95+
// Reserve space on stack for input/output pointers
96+
h->sub(h->sp, h->sp, 2 * get_vec_length());
97+
98+
// Apply offsets and store pointers on stack
99+
for (size_t i = 0; i < mem_ptrs.size(); i++) {
100+
const auto& ptr_reg = mem_ptrs[i];
101+
int32_t stack_offset = i * get_vec_length();
102+
103+
if (ov::snippets::utils::is_dynamic_value(m_memory_offsets[i])) {
104+
// Dynamic offset: read from runtime parameters
105+
size_t runtime_offset = GET_OFF(buffer_offsets) + m_buffer_ids[i] * sizeof(size_t);
106+
utils::push_ptr_with_runtime_offset_on_stack(h, stack_offset, ptr_reg, aux_reg, runtime_offset);
107+
} else {
108+
// Static offset: add compile-time constant
109+
utils::push_ptr_with_static_offset_on_stack(h, stack_offset, ptr_reg, m_memory_offsets[i]);
110+
}
111+
}
112+
113+
// Load back the adjusted pointers for function call
114+
h->ldr(x2, Xbyak_aarch64::ptr(h->sp, 1 * get_vec_length())); // output pointer
115+
h->ldr(x1, Xbyak_aarch64::ptr(h->sp, 0 * get_vec_length())); // input pointer
116+
117+
// Set up executor pointer as first argument
82118
const auto& compiled_kernel = get_compiled_kernel_ptr();
83119
h->mov(x0, compiled_kernel);
84120

85121
Xbyak_aarch64::XReg func_reg(9);
86122
h->mov(func_reg, get_execute_function_ptr());
87123
h->blr(func_reg);
88124

125+
// Restore stack pointer
126+
h->add(h->sp, h->sp, 2 * get_vec_length());
127+
89128
restore_context(exclude);
90129
}
91130

src/plugins/intel_cpu/src/emitters/snippets/aarch64/jit_gemm_copy_b_emitter.hpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,8 @@ class jit_gemm_copy_b_emitter : public jit_emitter {
3030
const uintptr_t get_compiled_kernel_ptr() const;
3131

3232
std::shared_ptr<GemmCopyBKaiKernelExecutor> m_kernel_executor = nullptr;
33+
std::vector<size_t> m_memory_offsets;
34+
std::vector<size_t> m_buffer_ids;
3335
};
3436

3537
} // namespace ov::intel_cpu::aarch64

src/plugins/intel_cpu/src/emitters/snippets/aarch64/jit_gemm_emitter.cpp

Lines changed: 50 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -17,11 +17,15 @@
1717
#include <vector>
1818

1919
#include "emitters/snippets/aarch64/kernel_executors/gemm.hpp"
20+
#include "emitters/snippets/aarch64/utils.hpp"
21+
#include "emitters/snippets/jit_snippets_call_args.hpp"
2022
#include "emitters/utils.hpp"
2123
#include "openvino/core/node.hpp"
2224
#include "openvino/core/type/element_type.hpp"
2325
#include "snippets/kernel_executor_table.hpp"
2426
#include "snippets/lowered/expression.hpp"
27+
#include "snippets/utils/utils.hpp"
28+
#include "transformations/snippets/aarch64/op/gemm_cpu.hpp"
2529

2630
using namespace Xbyak_aarch64;
2731

@@ -39,6 +43,17 @@ jit_gemm_emitter::jit_gemm_emitter(jit_generator* h,
3943
in_out_type_ = emitter_in_out_map::gpr_to_gpr;
4044
GemmKernelKaiConfig kernel_config;
4145
m_kernel_executor_kai = kernel_table->register_kernel<GemmKaiKernelExecutor>(expr, kernel_config);
46+
47+
const auto gemm_node = as_type_ptr<GemmCPU>(expr->get_node());
48+
OV_CPU_JIT_EMITTER_ASSERT(gemm_node, "Expected GemmCPU node");
49+
50+
// Initialize memory offsets similar to x64 brgemm implementation
51+
m_memory_offsets = {gemm_node->get_offset_a(), gemm_node->get_offset_b(), gemm_node->get_offset_c()};
52+
53+
// Initialize buffer IDs using the utils function
54+
m_buffer_ids = {utils::get_buffer_cluster_id(expr->get_input_port(0)),
55+
utils::get_buffer_cluster_id(expr->get_input_port(1)),
56+
utils::get_buffer_cluster_id(expr->get_output_port(0))};
4257
}
4358

4459
std::set<std::vector<element::Type>> jit_gemm_emitter::get_supported_precisions(
@@ -50,6 +65,7 @@ std::set<std::vector<element::Type>> jit_gemm_emitter::get_supported_precisions(
5065
void jit_gemm_emitter::validate_arguments(const std::vector<size_t>& in, const std::vector<size_t>& out) const {
5166
OV_CPU_JIT_EMITTER_ASSERT(in.size() == 2, "Expects 2 input regs, got", in.size());
5267
OV_CPU_JIT_EMITTER_ASSERT(out.size() == 1, "Expects 1 output reg, got", out.size());
68+
OV_CPU_JIT_EMITTER_ASSERT(m_memory_offsets.size() == 3, "Expected 3 memory offsets for A, B, C");
5369
}
5470

5571
void jit_gemm_emitter::emit_impl(const std::vector<size_t>& in, const std::vector<size_t>& out) const {
@@ -62,19 +78,47 @@ void jit_gemm_emitter::emit_impl(const std::vector<size_t>& in, const std::vecto
6278
Xbyak_aarch64::XReg x1(1);
6379
Xbyak_aarch64::XReg x2(2);
6480
Xbyak_aarch64::XReg x3(3);
65-
h->str(Xbyak_aarch64::XReg(in[0]), pre_ptr(h->sp, -get_vec_length()));
66-
h->str(Xbyak_aarch64::XReg(in[1]), pre_ptr(h->sp, -get_vec_length()));
67-
h->str(Xbyak_aarch64::XReg(out[0]), pre_ptr(h->sp, -get_vec_length()));
68-
h->ldr(x3, post_ptr(h->sp, get_vec_length()));
69-
h->ldr(x2, post_ptr(h->sp, get_vec_length()));
70-
h->ldr(x1, post_ptr(h->sp, get_vec_length()));
81+
Xbyak_aarch64::XReg aux_reg(5);
82+
83+
// Prepare memory pointers with offsets
84+
std::vector<size_t> mem_ptrs_idxs{in[0], in[1], out[0]};
85+
const auto& mem_ptrs = utils::transform_idxs_to_regs(mem_ptrs_idxs);
86+
87+
// Apply memory offsets to input/output pointers
88+
// Reserve space on stack for matrix pointers - use pre-decrement addressing
89+
h->sub(h->sp, h->sp, 3 * get_vec_length());
90+
91+
// Apply offsets and store pointers on stack
92+
for (size_t i = 0; i < mem_ptrs.size(); i++) {
93+
const auto& ptr_reg = mem_ptrs[i];
94+
int32_t stack_offset = i * get_vec_length();
95+
96+
if (ov::snippets::utils::is_dynamic_value(m_memory_offsets[i])) {
97+
// Dynamic offset: read from runtime parameters
98+
size_t runtime_offset = GET_OFF(buffer_offsets) + m_buffer_ids[i] * sizeof(size_t);
99+
utils::push_ptr_with_runtime_offset_on_stack(h, stack_offset, ptr_reg, aux_reg, runtime_offset);
100+
} else {
101+
// Static offset: add compile-time constant
102+
utils::push_ptr_with_static_offset_on_stack(h, stack_offset, ptr_reg, m_memory_offsets[i]);
103+
}
104+
}
105+
106+
// Load back the adjusted pointers for function call
107+
h->ldr(x3, Xbyak_aarch64::ptr(h->sp, 2 * get_vec_length())); // matrix C (out)
108+
h->ldr(x2, Xbyak_aarch64::ptr(h->sp, 1 * get_vec_length())); // matrix B (in1)
109+
h->ldr(x1, Xbyak_aarch64::ptr(h->sp, 0 * get_vec_length())); // matrix A (in0)
110+
111+
// Set up executor pointer as first argument
71112
const auto& compiled_kernel = get_compiled_kernel_ptr();
72113
h->mov(x0, compiled_kernel);
73114

74115
Xbyak_aarch64::XReg func_reg(9);
75116
h->mov(func_reg, get_execute_function_ptr());
76117
h->blr(func_reg);
77118

119+
// Restore stack pointer
120+
h->add(h->sp, h->sp, 3 * get_vec_length());
121+
78122
restore_context(exclude);
79123
}
80124

src/plugins/intel_cpu/src/emitters/snippets/aarch64/jit_gemm_emitter.hpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,8 @@ class jit_gemm_emitter : public jit_emitter {
3232
const uintptr_t get_compiled_kernel_ptr() const;
3333

3434
std::shared_ptr<GemmKaiKernelExecutor> m_kernel_executor_kai = nullptr;
35+
std::vector<size_t> m_memory_offsets;
36+
std::vector<size_t> m_buffer_ids;
3537
};
3638

3739
} // namespace ov::intel_cpu::aarch64
Lines changed: 150 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,150 @@
1+
// Copyright (C) 2025 Intel Corporation
2+
// SPDX-License-Identifier: Apache-2.0
3+
//
4+
5+
#include "utils.hpp"
6+
7+
#include <algorithm>
8+
#include <common/utils.hpp>
9+
#include <cstddef>
10+
#include <cstdint>
11+
#include <memory>
12+
#include <set>
13+
#include <unordered_set>
14+
#include <vector>
15+
16+
#include "emitters/utils.hpp"
17+
#include "openvino/core/except.hpp"
18+
#include "openvino/core/type.hpp"
19+
#include "snippets/emitter.hpp"
20+
#include "snippets/lowered/expression_port.hpp"
21+
#include "snippets/lowered/expressions/buffer_expression.hpp"
22+
#include "snippets/op/loop.hpp"
23+
#include "snippets/op/memory_access.hpp"
24+
#include "snippets/utils/utils.hpp"
25+
26+
using namespace dnnl::impl::cpu::aarch64;
27+
28+
namespace ov::intel_cpu::aarch64::utils {
29+
30+
size_t get_buffer_cluster_id(const ov::snippets::lowered::ExpressionPort& port) {
    // Returns the Buffer cluster id associated with the given expression port,
    // or SIZE_MAX when the port is not connected to a Buffer expression.
    // The cluster id is what the emitters use at runtime to look up dynamic
    // memory offsets in the kernel call args (buffer_offsets table).
    auto get_cluster_id = [](const snippets::lowered::ExpressionPort& p) {
        const auto buffer = ov::as_type_ptr<ov::snippets::lowered::BufferExpression>(p.get_expr());
        return buffer ? buffer->get_cluster_id() : SIZE_MAX;
    };
    const auto& ma_op = std::dynamic_pointer_cast<ov::snippets::modifier::MemoryAccess>(port.get_expr()->get_node());
    OPENVINO_ASSERT(ma_op, "Expected MemoryAccess op!");
    auto offset = ov::snippets::utils::get_dynamic_value<size_t>();
    size_t id = SIZE_MAX;
    switch (port.get_type()) {
    case ov::snippets::lowered::ExpressionPort::Type::Input:
        // For an input port the Buffer (if any) is the source of the connector.
        offset = ma_op->get_input_offset(port.get_index());
        id = get_cluster_id(port.get_port_connector_ptr()->get_source());
        break;
    case ov::snippets::lowered::ExpressionPort::Type::Output:
        // For an output port scan all consumers, skipping LoopEnd pseudo-consumers.
        offset = ma_op->get_output_offset(port.get_index());
        for (const auto& child : port.get_connected_ports()) {
            if (!ov::is_type<snippets::op::LoopEnd>(child.get_expr()->get_node())) {
                id = get_cluster_id(child);
            }
        }
        break;
    default:
        // Fixed typo in the error message ("Uknown" -> "Unknown").
        OV_CPU_JIT_EMITTER_THROW("Unknown type of expression port!");
    }
    // A dynamic offset can only be resolved at runtime via the cluster id,
    // so the id must be known in that case.
    OV_CPU_JIT_EMITTER_ASSERT(IMPLICATION(ov::snippets::utils::is_dynamic_value(offset), id != SIZE_MAX),
                              "In dynamic case Buffer Cluster ID must be known!");
    return id;
}
59+
60+
Xbyak_aarch64::XReg get_aux_gpr(const std::vector<size_t>& used_gpr_idxs) {
    // Picks a general-purpose register that is neither in use by the caller
    // nor reserved. Reserved registers: SP (31) must be preserved, X0/X1 carry
    // the kernel runtime parameters, X18 is the AArch64 platform register.
    static const std::unordered_set<size_t> reserved_idxs = {31, 0, 1, 18};

    // Allocate from the highest register index downwards (X30 .. X0).
    for (int candidate = 30; candidate >= 0; --candidate) {
        const auto idx = static_cast<size_t>(candidate);
        if (reserved_idxs.count(idx) > 0) {
            continue;
        }
        const bool in_use = std::find(used_gpr_idxs.cbegin(), used_gpr_idxs.cend(), idx) != used_gpr_idxs.cend();
        if (!in_use) {
            return Xbyak_aarch64::XReg(idx);
        }
    }
    OV_CPU_JIT_EMITTER_THROW("Failed to allocate aux GPR");
}
83+
84+
Xbyak_aarch64::XReg init_memory_access_aux_gpr(const std::vector<size_t>& used_gpr_reg_idxs,
                                               const std::vector<size_t>& aux_gpr_idxs,
                                               std::set<snippets::Reg>& regs_to_spill) {
    // Prefer an auxiliary GPR already granted by the register allocator.
    if (!aux_gpr_idxs.empty()) {
        return Xbyak_aarch64::XReg(static_cast<int>(aux_gpr_idxs[0]));
    }
    // None was granted: pick a free register ourselves and record it for
    // spilling, since the allocator does not know we will clobber it.
    const auto scratch = ov::intel_cpu::aarch64::utils::get_aux_gpr(used_gpr_reg_idxs);
    regs_to_spill.emplace(snippets::RegType::gpr, scratch.getIdx());
    return scratch;
}
94+
95+
void push_ptr_with_runtime_offset_on_stack(dnnl::impl::cpu::aarch64::jit_generator* h,
                                           int32_t stack_offset,
                                           const Xbyak_aarch64::XReg& ptr_reg,
                                           const Xbyak_aarch64::XReg& aux_reg,
                                           size_t runtime_offset) {
    // Stores (ptr_reg + runtime byte offset) at [sp + stack_offset].
    // The offset value is loaded from the runtime-parameters structure whose
    // address is in abi_param1 (X0), at byte position `runtime_offset`.
    //
    // Fix vs. the previous revision: only the caller-provided `aux_reg` is
    // clobbered. The old code also wrote hardcoded X4 (and X6 for large
    // offsets), which could silently corrupt pointer registers still live in
    // the caller's loop over mem_ptrs.
    // NOTE(review): assumes abi_param1 (X0) still holds the call-args pointer
    // and that neither ptr_reg nor aux_reg aliases X0 — confirm at call sites.
    Xbyak_aarch64::XReg abi_param1(0);

    if (runtime_offset > 4095) {
        // Offset does not fit the immediate addressing form: materialize the
        // absolute address of the field in aux_reg first, then load through it.
        h->mov(aux_reg, static_cast<uint64_t>(runtime_offset));
        h->add(aux_reg, abi_param1, aux_reg);
        h->ldr(aux_reg, Xbyak_aarch64::ptr(aux_reg));
    } else {
        h->ldr(aux_reg, Xbyak_aarch64::ptr(abi_param1, static_cast<int32_t>(runtime_offset)));
    }

    // aux_reg = base pointer + runtime offset; ptr_reg is left untouched.
    h->add(aux_reg, aux_reg, ptr_reg);

    // Publish the adjusted pointer on the stack for the upcoming call.
    h->str(aux_reg, Xbyak_aarch64::ptr(h->sp, stack_offset));
}
122+
123+
void push_ptr_with_static_offset_on_stack(dnnl::impl::cpu::aarch64::jit_generator* h,
                                          int32_t stack_offset,
                                          const Xbyak_aarch64::XReg& ptr_reg,
                                          size_t ptr_offset) {
    // Stores (ptr_reg + compile-time constant ptr_offset) at [sp + stack_offset].

    // Fast path: zero offset needs no arithmetic and no scratch register.
    if (ptr_offset == 0) {
        h->str(ptr_reg, Xbyak_aarch64::ptr(h->sp, stack_offset));
        return;
    }

    // Fix vs. the previous revision: a single scratch register (X4) is used
    // for all offset sizes — the old code additionally clobbered X6 for large
    // offsets and spent an extra `mov` in the small-offset path.
    // NOTE(review): X4 is clobbered here; callers must not keep live data in
    // X4 across this call — confirm at call sites.
    Xbyak_aarch64::XReg temp_reg(4);
    if (ptr_offset <= 4095) {
        // Fits the 12-bit unsigned immediate of ADD (immediate).
        h->add(temp_reg, ptr_reg, static_cast<uint32_t>(ptr_offset));
    } else {
        // Materialize the large offset, then add the base pointer.
        h->mov(temp_reg, static_cast<uint64_t>(ptr_offset));
        h->add(temp_reg, temp_reg, ptr_reg);
    }

    // Publish the adjusted pointer on the stack for the upcoming call.
    h->str(temp_reg, Xbyak_aarch64::ptr(h->sp, stack_offset));
}
149+
150+
} // namespace ov::intel_cpu::aarch64::utils

0 commit comments

Comments
 (0)