@@ -66,6 +66,7 @@ void jit_gemm_emitter::validate_arguments(const std::vector<size_t>& in, const s
66
66
OV_CPU_JIT_EMITTER_ASSERT (in.size () == 2 , " Expects 2 input regs, got" , in.size ());
67
67
OV_CPU_JIT_EMITTER_ASSERT (out.size () == 1 , " Expects 1 output reg, got" , out.size ());
68
68
OV_CPU_JIT_EMITTER_ASSERT (m_memory_offsets.size () == 3 , " Expected 3 memory offsets for A, B, C" );
69
+ OV_CPU_JIT_EMITTER_ASSERT (m_buffer_ids.size () == 3 , " Expected 3 buffer IDs for A, B, C" );
69
70
}
70
71
71
72
void jit_gemm_emitter::emit_impl (const std::vector<size_t >& in, const std::vector<size_t >& out) const {
@@ -104,9 +105,12 @@ void jit_gemm_emitter::emit_impl(const std::vector<size_t>& in, const std::vecto
104
105
}
105
106
106
107
// Load back the adjusted pointers for function call
108
+ h->ldr (x1, Xbyak_aarch64::ptr (h->sp )); // matrix A (in0)
109
+ h->ldr (x2, Xbyak_aarch64::ptr (h->sp , get_vec_length ())); // matrix B (in1)
107
110
h->ldr (x3, Xbyak_aarch64::ptr (h->sp , 2 * get_vec_length ())); // matrix C (out)
108
- h->ldr (x2, Xbyak_aarch64::ptr (h->sp , 1 * get_vec_length ())); // matrix B (in1)
109
- h->ldr (x1, Xbyak_aarch64::ptr (h->sp , 0 * get_vec_length ())); // matrix A (in0)
111
+
112
+ // Restore stack pointer
113
+ h->add (h->sp , h->sp , 3 * get_vec_length ());
110
114
111
115
// Set up executor pointer as first argument
112
116
const auto & compiled_kernel = get_compiled_kernel_ptr ();
@@ -116,9 +120,6 @@ void jit_gemm_emitter::emit_impl(const std::vector<size_t>& in, const std::vecto
116
120
h->mov (func_reg, get_execute_function_ptr ());
117
121
h->blr (func_reg);
118
122
119
- // Restore stack pointer
120
- h->add (h->sp , h->sp , 3 * get_vec_length ());
121
-
122
123
restore_context (exclude);
123
124
}
124
125
0 commit comments