Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 13 additions & 5 deletions src/cpu/x64/gemm/f32/jit_avx512_common_gemm_f32.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -126,6 +126,7 @@ struct xbyak_gemm_t : public jit_generator_t {
auto ALPHA = qword[rsp + 48];
auto BETA = qword[rsp + 64];
auto ORIG_A = qword[rsp + 80];
auto WS_BUF = qword[rsp + 96];
auto ORIG_SP = qword[rsp + 120];

auto ZSTRIDE = zmm4;
Expand All @@ -147,7 +148,8 @@ struct xbyak_gemm_t : public jit_generator_t {
Label pack2, pack3, pack4, pack10;

mov(BO1, A);
lea(AO1, ptr[rsp + 128 + OFFSET * SIZE]);
mov(AO1, WS_BUF);
lea(AO1, ptr[AO1 + OFFSET * SIZE]);
mov(LL, K);
sar(LL, 2);
jle(pack3, T_NEAR);
Expand Down Expand Up @@ -833,13 +835,15 @@ struct xbyak_gemm_t : public jit_generator_t {
auto kernel = [&](int unroll_m, int unroll_n, bool isDirect,
bool isCopy, bool isUnmasked = true) {
if (!isDirect) {
lea(AO1, ptr[rsp + 128 + OFFSET * SIZE]);
mov(AO1, WS_BUF);
lea(AO1, ptr[AO1 + OFFSET * SIZE]);
} else {
mov(AO1, A);
}

if (isCopy) {
lea(LDA4, ptr[rsp + 128 + OFFSET * SIZE]);
mov(LDA4, WS_BUF);
lea(LDA4, ptr[LDA4 + OFFSET * SIZE]);
} else {
auto step = 2;
lea(LDA4, ptr[LDA * step + (16 - 1 - OFFSET) * SIZE]);
Expand Down Expand Up @@ -1454,18 +1458,22 @@ struct xbyak_gemm_t : public jit_generator_t {
cmp(K, STACK_K_CAPACITY);
jg(buffer_in_ws, T_NEAR);

// Create buffer and align to 4kB page
// Using 4kB aligned buffer on stack as workspace
lea(rax, ptr[K * SIZE]);
imul(rax, rax, 0x30);
add(rax, 256);
sub(rsp, rax);
and_(rsp, -PAGE_4K);
lea(rax, ptr[rsp + 128]);
jmp(buffer_allocated, T_NEAR);

L(buffer_in_ws);
mov(rsp, ARG_WS);
// Using buffer in heap as workspace
mov(rax, ARG_WS);
sub(rsp, 256);

L(buffer_allocated);
mov(WS_BUF, rax);

mov(ORIG_SP, rbp);
mov(M, ARG_M);
Expand Down
21 changes: 15 additions & 6 deletions src/cpu/x64/gemm/f32/jit_avx_gemm_f32.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -811,13 +811,15 @@ struct xbyak_gemm_t : public jit_generator_t {
const Ymm &reg18, const Ymm &reg19, const Ymm &reg20,
const Ymm &reg21, const Ymm &reg22, const Ymm &reg23) {
if (!isDirect) {
lea(AO1, ptr[rsp + 256 + OFFSET * SIZE]);
mov(AO1, WS_BUF);
lea(AO1, ptr[AO1 + OFFSET * SIZE]);
} else {
mov(AO1, A);
}

if (isCopy) {
lea(LDA4, ptr[rsp + 256 + OFFSET * SIZE]);
mov(LDA4, WS_BUF);
lea(LDA4, ptr[LDA4 + OFFSET * SIZE]);
} else {
lea(LDA4, ptr[LDA * 8 + (8 - 1 - OFFSET) * SIZE]);
}
Expand Down Expand Up @@ -1313,7 +1315,8 @@ struct xbyak_gemm_t : public jit_generator_t {
Reg64 reg;

mov(BO1, A);
lea(AO1, ptr[rsp + 256 + OFFSET * SIZE]);
mov(AO1, WS_BUF);
lea(AO1, ptr[AO1 + OFFSET * SIZE]);

if (isTransA) {
lea(BO2, ptr[BO1 + LDA * 4]);
Expand Down Expand Up @@ -1983,18 +1986,22 @@ struct xbyak_gemm_t : public jit_generator_t {
cmp(K, STACK_K_CAPACITY);
jg(buffer_in_ws, T_NEAR);

// Create buffer and align to 4kB page
// Using 4kB aligned buffer on stack as workspace
lea(rax, ptr[K * SIZE]);
sal(rax, math::ilog2q(UNROLL_M));
add(rax, 256);
sub(rsp, rax);
and_(rsp, -PAGE_4K);
lea(rax, ptr[rsp + 256]);
jmp(buffer_allocated, T_NEAR);

L(buffer_in_ws);
mov(rsp, ARG_WS);
// Using buffer in heap as workspace
mov(rax, ARG_WS);
sub(rsp, 256);

L(buffer_allocated);
mov(WS_BUF, rax);

mov(ORIG_SP, rbp);
mov(M, ARG_M);
Expand Down Expand Up @@ -2162,8 +2169,10 @@ struct xbyak_gemm_t : public jit_generator_t {
const Address BETA = qword[rsp + 64];
const Address ORIG_A = qword[rsp + 80];
const Address MASK = dword[rsp + 88];
// STRIDE requires padding to 32 bytes to accomodate ymm loads/stores
const Address STRIDE = qword[rsp + 120];
const Address ORIG_SP = qword[rsp + 152];
const Address WS_BUF = qword[rsp + 152];
const Address ORIG_SP = qword[rsp + 160];

const Ymm VALPHA = ymm1;
const Ymm VBETA = ymm2;
Expand Down
Loading