Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
217 changes: 79 additions & 138 deletions src/sgemm_kernel.rs
Original file line number Diff line number Diff line change
Expand Up @@ -325,177 +325,118 @@ unsafe fn kernel_x86_avx<MA>(k: usize, alpha: T, a: *const T, b: *const T,
let (mut a, mut b) = if prefer_row_major_c { (a, b) } else { (b, a) };
let (rsc, csc) = if prefer_row_major_c { (rsc, csc) } else { (csc, rsc) };

// Packs four selector expressions into one immediate value for a shuffle:
// the first argument is shifted into bits 7:6, the second into bits 5:4,
// the third into bits 3:2, and the last is OR-ed into the low bits as-is.
// The arguments are not range-checked or masked; callers are expected to
// pass values that fit in two bits each.
macro_rules! shuffle_mask {
    ($sel3:expr, $sel2:expr, $sel1:expr, $sel0:expr) => {
        ($sel3 << 6)
            | ($sel2 << 4)
            | ($sel1 << 2)
            | $sel0
    };
}
// Packs four selector expressions into one immediate value for a
// within-lane permute: first argument goes to bits 7:6, second to bits 5:4,
// third to bits 3:2, and the last is OR-ed into the low bits unchanged.
// No masking is applied, so each argument is expected to fit in two bits.
macro_rules! permute_mask {
    ($idx3:expr, $idx2:expr, $idx1:expr, $idx0:expr) => {
        ($idx3 << 6)
            | ($idx2 << 4)
            | ($idx1 << 2)
            | $idx0
    };
}

// Packs two selector expressions into one immediate value for a 128-bit
// lane permute: the first argument is shifted into the high nibble
// (bits 7:4) and the second is OR-ed into the low bits unchanged.
// No masking is applied to either argument.
macro_rules! permute2f128_mask {
    ($hi_sel:expr, $lo_sel:expr) => {
        (
            ($hi_sel << 4)
                | $lo_sel
        )
    };
}

// Start data load before each iteration
let mut av = _mm256_load_ps(a);
let mut bv = _mm256_load_ps(b);
let mut bvl = _mm256_broadcast_ps(&*(b.add(0) as *const _));
let mut bvh = _mm256_broadcast_ps(&*(b.add(4) as *const _));

// Compute A B
unroll_by_with_last!(4 => k, is_last, {
// We compute abij = ai bj
//
// Load b as one contiguous vector
// Load a as striped vectors
//
// Shuffle the abij elements in order after the loop.
//
// Note this scheme copied and transposed from the BLIS 8x8 sgemm
// microkernel.
//
// Our a indices are striped and our b indices are linear. In
// the variable names below, we always have doubled indices so
// for example a0246 corresponds to a vector of a0 a0 a2 a2 a4 a4 a6 a6.
//
// ab0246: ab2064: ab4602: ab6420:
// ( ab00 ( ab20 ( ab40 ( ab60
// ab01 ab21 ab41 ab61
// ab22 ab02 ab62 ab42
// ab23 ab03 ab63 ab43
// ab44 ab64 ab04 ab24
// ab45 ab65 ab05 ab25
// ab66 ab46 ab26 ab06
// ab67 ) ab47 ) ab27 ) ab07 )
// ab0: ab1: ab2: ab3:
// ( ab00 ( ab10 ( ab20 ( ab30
// ab11 ab21 ab31 ab01
// ab22 ab32 ab02 ab12
// ab33 ab03 ab13 ab23
// ab40 ab50 ab60 ab70
// ab51 ab61 ab71 ab41
// ab62 ab72 ab42 ab52
// ab73 ) ab43 ) ab53 ) ab63 )
//
// ab1357: ab3175: ab5713: ab7531:
// ( ab10 ( ab30 ( ab50 ( ab70
// ab11 ab31 ab51 ab71
// ab32 ab12 ab72 ab52
// ab33 ab13 ab73 ab53
// ab54 ab74 ab14 ab34
// ab55 ab75 ab15 ab35
// ab76 ab56 ab36 ab16
// ab77 ) ab57 ) ab37 ) ab17 )

const PERM32_2301: i32 = permute_mask!(1, 0, 3, 2);
const PERM128_30: i32 = permute2f128_mask!(0, 3);

// _mm256_moveldup_ps(av):
// vmovsldup ymm2, ymmword ptr [rax]
//
// Load and duplicate each even word:
// ymm2 ← [a0 a0 a2 a2 a4 a4 a6 a6]
//
// _mm256_movehdup_ps(av):
// vmovshdup ymm2, ymmword ptr [rax]
//
// Load and duplicate each odd word:
// ymm2 ← [a1 a1 a3 a3 a5 a5 a7 a7]
//
// ( ab04 ( ab14 ( ab24 ( ab34
// ab15 ab25 ab35 ab05
// ab26 ab36 ab06 ab16
// ab37 ab07 ab17 ab27
// ab44 ab54 ab64 ab74
// ab55 ab65 ab75 ab45
// ab66 ab76 ab46 ab56
// ab77 ) ab47 ) ab57 ) ab67 )

let a0246 = _mm256_moveldup_ps(av); // Load: a0 a0 a2 a2 a4 a4 a6 a6
let a2064 = _mm256_permute_ps(a0246, PERM32_2301);
let a01234567 = av;
let a12305674 = _mm256_permute_ps(av, permute_mask!(0, 3, 2, 1));
let a23016745 = _mm256_permute_ps(av, permute_mask!(1, 0, 3, 2));
let a30127456 = _mm256_permute_ps(av, permute_mask!(2, 1, 0, 3));

let a1357 = _mm256_movehdup_ps(av); // Load: a1 a1 a3 a3 a5 a5 a7 a7
let a3175 = _mm256_permute_ps(a1357, PERM32_2301);
ab[0] = MA::multiply_add(a01234567, bvl, ab[0]);
ab[4] = MA::multiply_add(a01234567, bvh, ab[4]);

let bv_lh = _mm256_permute2f128_ps(bv, bv, PERM128_30);
ab[1] = MA::multiply_add(a12305674, bvl, ab[1]);
ab[5] = MA::multiply_add(a12305674, bvh, ab[5]);

ab[0] = MA::multiply_add(a0246, bv, ab[0]);
ab[1] = MA::multiply_add(a2064, bv, ab[1]);
ab[2] = MA::multiply_add(a0246, bv_lh, ab[2]);
ab[3] = MA::multiply_add(a2064, bv_lh, ab[3]);
ab[2] = MA::multiply_add(a23016745, bvl, ab[2]);
ab[6] = MA::multiply_add(a23016745, bvh, ab[6]);

ab[4] = MA::multiply_add(a1357, bv, ab[4]);
ab[5] = MA::multiply_add(a3175, bv, ab[5]);
ab[6] = MA::multiply_add(a1357, bv_lh, ab[6]);
ab[7] = MA::multiply_add(a3175, bv_lh, ab[7]);
ab[3] = MA::multiply_add(a30127456, bvl, ab[3]);
ab[7] = MA::multiply_add(a30127456, bvh, ab[7]);

if !is_last {
a = a.add(MR);
b = b.add(NR);

bv = _mm256_load_ps(b);
bvl = _mm256_broadcast_ps(&*(b.add(0) as *const _));
bvh = _mm256_broadcast_ps(&*(b.add(4) as *const _));
av = _mm256_load_ps(a);
}
});

let alphav = _mm256_set1_ps(alpha);

// Permute to put the abij elements in order
//
// shufps 0xe4: 22006644 00224466 -> 22226666
//
// vperm2 0x30: 00004444 44440000 -> 00000000
// vperm2 0x12: 00004444 44440000 -> 44444444
//

let ab0246 = ab[0];
let ab2064 = ab[1];
let ab4602 = ab[2]; // reverse order
let ab6420 = ab[3]; // reverse order

let ab1357 = ab[4];
let ab3175 = ab[5];
let ab5713 = ab[6]; // reverse order
let ab7531 = ab[7]; // reverse order

const SHUF_0123: i32 = shuffle_mask!(3, 2, 1, 0);
debug_assert_eq!(SHUF_0123, 0xE4);

const PERM128_02: i32 = permute2f128_mask!(2, 0);
const PERM128_31: i32 = permute2f128_mask!(1, 3);

// No elements are "shuffled" in truth, they all stay at their index
// but we combine vectors to de-stripe them.
//
// For example, the first shuffle below uses 0 1 2 3 which
// corresponds to the X0 X1 Y2 Y3 sequence etc:
//
// variable
// X ab00 ab01 ab22 ab23 ab44 ab45 ab66 ab67 ab0246
// Y ab20 ab21 ab02 ab03 ab64 ab65 ab46 ab47 ab2064
//
// X0 X1 Y2 Y3 X4 X5 Y6 Y7
// = ab00 ab01 ab02 ab03 ab44 ab45 ab46 ab47 ab0044

let ab0044 = _mm256_shuffle_ps(ab0246, ab2064, SHUF_0123);
let ab2266 = _mm256_shuffle_ps(ab2064, ab0246, SHUF_0123);

let ab4400 = _mm256_shuffle_ps(ab4602, ab6420, SHUF_0123);
let ab6622 = _mm256_shuffle_ps(ab6420, ab4602, SHUF_0123);

let ab1155 = _mm256_shuffle_ps(ab1357, ab3175, SHUF_0123);
let ab3377 = _mm256_shuffle_ps(ab3175, ab1357, SHUF_0123);

let ab5511 = _mm256_shuffle_ps(ab5713, ab7531, SHUF_0123);
let ab7733 = _mm256_shuffle_ps(ab7531, ab5713, SHUF_0123);

let ab0000 = _mm256_permute2f128_ps(ab0044, ab4400, PERM128_02);
let ab4444 = _mm256_permute2f128_ps(ab0044, ab4400, PERM128_31);

let ab2222 = _mm256_permute2f128_ps(ab2266, ab6622, PERM128_02);
let ab6666 = _mm256_permute2f128_ps(ab2266, ab6622, PERM128_31);

let ab1111 = _mm256_permute2f128_ps(ab1155, ab5511, PERM128_02);
let ab5555 = _mm256_permute2f128_ps(ab1155, ab5511, PERM128_31);

let ab3333 = _mm256_permute2f128_ps(ab3377, ab7733, PERM128_02);
let ab7777 = _mm256_permute2f128_ps(ab3377, ab7733, PERM128_31);

ab[0] = ab0000;
ab[1] = ab1111;
ab[2] = ab2222;
ab[3] = ab3333;
ab[4] = ab4444;
ab[5] = ab5555;
ab[6] = ab6666;
ab[7] = ab7777;
let t0 = ab[0];
let t1 = ab[1];
let t2 = ab[2];
let t3 = ab[3];

let (t0, t1, t2, t3) = (
_mm256_blend_ps(t0, t3, 0b10101010),
_mm256_blend_ps(t1, t0, 0b10101010),
_mm256_blend_ps(t2, t1, 0b10101010),
_mm256_blend_ps(t3, t2, 0b10101010),
);

let (t0, t1, t2, t3) = (
_mm256_blend_ps(t0, t2, 0b11001100),
_mm256_blend_ps(t1, t3, 0b11001100),
_mm256_blend_ps(t2, t0, 0b11001100),
_mm256_blend_ps(t3, t1, 0b11001100),
);

let t4 = ab[4];
let t5 = ab[5];
let t6 = ab[6];
let t7 = ab[7];

let (t4, t5, t6, t7) = (
_mm256_blend_ps(t4, t7, 0b10101010),
_mm256_blend_ps(t5, t4, 0b10101010),
_mm256_blend_ps(t6, t5, 0b10101010),
_mm256_blend_ps(t7, t6, 0b10101010),
);

let (t4, t5, t6, t7) = (
_mm256_blend_ps(t4, t6, 0b11001100),
_mm256_blend_ps(t5, t7, 0b11001100),
_mm256_blend_ps(t6, t4, 0b11001100),
_mm256_blend_ps(t7, t5, 0b11001100),
);

ab[0] = _mm256_permute2f128_ps(t0, t4, 0x20);
ab[1] = _mm256_permute2f128_ps(t1, t5, 0x20);
ab[2] = _mm256_permute2f128_ps(t2, t6, 0x20);
ab[3] = _mm256_permute2f128_ps(t3, t7, 0x20);
ab[4] = _mm256_permute2f128_ps(t0, t4, 0x31);
ab[5] = _mm256_permute2f128_ps(t1, t5, 0x31);
ab[6] = _mm256_permute2f128_ps(t2, t6, 0x31);
ab[7] = _mm256_permute2f128_ps(t3, t7, 0x31);

// Compute α (A B)
// Compute here if we don't have fma, else pick up α further down
Expand Down
Loading