Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
206 changes: 104 additions & 102 deletions kernel/riscv64/sbgemm_kernel_16x8_zvl256b.c
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@ int CNAME(BLASLONG M, BLASLONG N, BLASLONG K, FLOAT alpha, IFLOAT *A, IFLOAT *B,
BLASLONG gvl = 0;
BLASLONG m_top = 0;
BLASLONG n_top = 0;
__bf16 *BB = (__bf16 *)(B);
__bf16 *AA = (__bf16 *)(A);

// -- MAIN PASS
for (BLASLONG j=0; j<N/8; j+=1) {
Expand All @@ -26,17 +28,17 @@ int CNAME(BLASLONG M, BLASLONG N, BLASLONG K, FLOAT alpha, IFLOAT *A, IFLOAT *B,
vfloat32m2_t result7 = __riscv_vfmv_v_f_f32m2(0.0f, gvl);

for (BLASLONG k=0; k<K; k++) {
__bf16 B0 = B[bi+0];
__bf16 B1 = B[bi+1];
__bf16 B2 = B[bi+2];
__bf16 B3 = B[bi+3];
__bf16 B4 = B[bi+4];
__bf16 B5 = B[bi+5];
__bf16 B6 = B[bi+6];
__bf16 B7 = B[bi+7];
__bf16 B0 = BB[bi+0];
__bf16 B1 = BB[bi+1];
__bf16 B2 = BB[bi+2];
__bf16 B3 = BB[bi+3];
__bf16 B4 = BB[bi+4];
__bf16 B5 = BB[bi+5];
__bf16 B6 = BB[bi+6];
__bf16 B7 = BB[bi+7];
bi += 8;

vbfloat16m1_t A0 = __riscv_vle16_v_bf16m1( &A[ai+0*gvl], gvl );
vbfloat16m1_t A0 = __riscv_vle16_v_bf16m1( &AA[ai+0*gvl], gvl );
ai += 16;

result0 = __riscv_vfwmaccbf16_vf_f32m2(result0, B0, A0, gvl);
Expand Down Expand Up @@ -100,17 +102,17 @@ int CNAME(BLASLONG M, BLASLONG N, BLASLONG K, FLOAT alpha, IFLOAT *A, IFLOAT *B,
vfloat32m1_t result7 = __riscv_vfmv_v_f_f32m1(0.0f, gvl);

for (BLASLONG k=0; k<K; k++) {
__bf16 B0 = B[bi+0];
__bf16 B1 = B[bi+1];
__bf16 B2 = B[bi+2];
__bf16 B3 = B[bi+3];
__bf16 B4 = B[bi+4];
__bf16 B5 = B[bi+5];
__bf16 B6 = B[bi+6];
__bf16 B7 = B[bi+7];
__bf16 B0 = BB[bi+0];
__bf16 B1 = BB[bi+1];
__bf16 B2 = BB[bi+2];
__bf16 B3 = BB[bi+3];
__bf16 B4 = BB[bi+4];
__bf16 B5 = BB[bi+5];
__bf16 B6 = BB[bi+6];
__bf16 B7 = BB[bi+7];
bi += 8;

vbfloat16mf2_t A0 = __riscv_vle16_v_bf16mf2( &A[ai+0*gvl], gvl );
vbfloat16mf2_t A0 = __riscv_vle16_v_bf16mf2( &AA[ai+0*gvl], gvl );
ai += 8;

result0 = __riscv_vfwmaccbf16_vf_f32m1(result0, B0, A0, gvl);
Expand Down Expand Up @@ -172,17 +174,17 @@ int CNAME(BLASLONG M, BLASLONG N, BLASLONG K, FLOAT alpha, IFLOAT *A, IFLOAT *B,
vfloat32m1_t result7 = __riscv_vfmv_v_f_f32m1(0.0f, gvl);

for (BLASLONG k=0; k < K; ++k) {
__bf16 B0 = B[bi+0];
__bf16 B1 = B[bi+1];
__bf16 B2 = B[bi+2];
__bf16 B3 = B[bi+3];
__bf16 B4 = B[bi+4];
__bf16 B5 = B[bi+5];
__bf16 B6 = B[bi+6];
__bf16 B7 = B[bi+7];
__bf16 B0 = BB[bi+0];
__bf16 B1 = BB[bi+1];
__bf16 B2 = BB[bi+2];
__bf16 B3 = BB[bi+3];
__bf16 B4 = BB[bi+4];
__bf16 B5 = BB[bi+5];
__bf16 B6 = BB[bi+6];
__bf16 B7 = BB[bi+7];
bi += 8;

vbfloat16mf2_t A0 = __riscv_vle16_v_bf16mf2( &A[ai+0*gvl], gvl );
vbfloat16mf2_t A0 = __riscv_vle16_v_bf16mf2( &AA[ai+0*gvl], gvl );
ai += 4;

result0 = __riscv_vfwmaccbf16_vf_f32m1(result0, B0, A0, gvl);
Expand Down Expand Up @@ -256,22 +258,22 @@ int CNAME(BLASLONG M, BLASLONG N, BLASLONG K, FLOAT alpha, IFLOAT *A, IFLOAT *B,
BLASLONG bi = n_top * K;

for (BLASLONG k=0; k<K; k++) {
result0+=(float)(A[ai+0])*(float)(B[bi+0]);
result1+=(float)(A[ai+1])*(float)(B[bi+0]);
result2+=(float)(A[ai+0])*(float)(B[bi+1]);
result3+=(float)(A[ai+1])*(float)(B[bi+1]);
result4+=(float)(A[ai+0])*(float)(B[bi+2]);
result5+=(float)(A[ai+1])*(float)(B[bi+2]);
result6+=(float)(A[ai+0])*(float)(B[bi+3]);
result7+=(float)(A[ai+1])*(float)(B[bi+3]);
result8+=(float)(A[ai+0])*(float)(B[bi+4]);
result9+=(float)(A[ai+1])*(float)(B[bi+4]);
result10+=(float)(A[ai+0])*(float)(B[bi+5]);
result11+=(float)(A[ai+1])*(float)(B[bi+5]);
result12+=(float)(A[ai+0])*(float)(B[bi+6]);
result13+=(float)(A[ai+1])*(float)(B[bi+6]);
result14+=(float)(A[ai+0])*(float)(B[bi+7]);
result15+=(float)(A[ai+1])*(float)(B[bi+7]);
result0+=(float)(AA[ai+0])*(float)(BB[bi+0]);
result1+=(float)(AA[ai+1])*(float)(BB[bi+0]);
result2+=(float)(AA[ai+0])*(float)(BB[bi+1]);
result3+=(float)(AA[ai+1])*(float)(BB[bi+1]);
result4+=(float)(AA[ai+0])*(float)(BB[bi+2]);
result5+=(float)(AA[ai+1])*(float)(BB[bi+2]);
result6+=(float)(AA[ai+0])*(float)(BB[bi+3]);
result7+=(float)(AA[ai+1])*(float)(BB[bi+3]);
result8+=(float)(AA[ai+0])*(float)(BB[bi+4]);
result9+=(float)(AA[ai+1])*(float)(BB[bi+4]);
result10+=(float)(AA[ai+0])*(float)(BB[bi+5]);
result11+=(float)(AA[ai+1])*(float)(BB[bi+5]);
result12+=(float)(AA[ai+0])*(float)(BB[bi+6]);
result13+=(float)(AA[ai+1])*(float)(BB[bi+6]);
result14+=(float)(AA[ai+0])*(float)(BB[bi+7]);
result15+=(float)(AA[ai+1])*(float)(BB[bi+7]);
ai+=2;
bi+=8;
}
Expand Down Expand Up @@ -314,14 +316,14 @@ int CNAME(BLASLONG M, BLASLONG N, BLASLONG K, FLOAT alpha, IFLOAT *A, IFLOAT *B,
BLASLONG bi = n_top * K;

for (BLASLONG k=0; k<K; k++) {
result0+=(float)(A[ai+0])*(float)(B[bi+0]);
result1+=(float)(A[ai+0])*(float)(B[bi+1]);
result2+=(float)(A[ai+0])*(float)(B[bi+2]);
result3+=(float)(A[ai+0])*(float)(B[bi+3]);
result4+=(float)(A[ai+0])*(float)(B[bi+4]);
result5+=(float)(A[ai+0])*(float)(B[bi+5]);
result6+=(float)(A[ai+0])*(float)(B[bi+6]);
result7+=(float)(A[ai+0])*(float)(B[bi+7]);
result0+=(float)(AA[ai+0])*(float)(BB[bi+0]);
result1+=(float)(AA[ai+0])*(float)(BB[bi+1]);
result2+=(float)(AA[ai+0])*(float)(BB[bi+2]);
result3+=(float)(AA[ai+0])*(float)(BB[bi+3]);
result4+=(float)(AA[ai+0])*(float)(BB[bi+4]);
result5+=(float)(AA[ai+0])*(float)(BB[bi+5]);
result6+=(float)(AA[ai+0])*(float)(BB[bi+6]);
result7+=(float)(AA[ai+0])*(float)(BB[bi+7]);
ai+=1;
bi+=8;
}
Expand Down Expand Up @@ -354,13 +356,13 @@ int CNAME(BLASLONG M, BLASLONG N, BLASLONG K, FLOAT alpha, IFLOAT *A, IFLOAT *B,
vfloat32m2_t result3 = __riscv_vfmv_v_f_f32m2(0.0f, gvl);

for (BLASLONG k=0; k<K; k++) {
__bf16 B0 = B[bi+0];
__bf16 B1 = B[bi+1];
__bf16 B2 = B[bi+2];
__bf16 B3 = B[bi+3];
__bf16 B0 = BB[bi+0];
__bf16 B1 = BB[bi+1];
__bf16 B2 = BB[bi+2];
__bf16 B3 = BB[bi+3];
bi += 4;

vbfloat16m1_t A0 = __riscv_vle16_v_bf16m1( &A[ai+0*gvl], gvl );
vbfloat16m1_t A0 = __riscv_vle16_v_bf16m1( &AA[ai+0*gvl], gvl );
ai += 16;

result0 = __riscv_vfwmaccbf16_vf_f32m2(result0, B0, A0, gvl);
Expand Down Expand Up @@ -401,13 +403,13 @@ int CNAME(BLASLONG M, BLASLONG N, BLASLONG K, FLOAT alpha, IFLOAT *A, IFLOAT *B,
vfloat32m1_t result3 = __riscv_vfmv_v_f_f32m1(0.0f, gvl);

for (BLASLONG k=0; k<K; k++) {
__bf16 B0 = B[bi+0];
__bf16 B1 = B[bi+1];
__bf16 B2 = B[bi+2];
__bf16 B3 = B[bi+3];
__bf16 B0 = BB[bi+0];
__bf16 B1 = BB[bi+1];
__bf16 B2 = BB[bi+2];
__bf16 B3 = BB[bi+3];
bi += 4;

vbfloat16mf2_t A0 = __riscv_vle16_v_bf16mf2( &A[ai+0*gvl], gvl );
vbfloat16mf2_t A0 = __riscv_vle16_v_bf16mf2( &AA[ai+0*gvl], gvl );
ai += 8;

result0 = __riscv_vfwmaccbf16_vf_f32m1(result0, B0, A0, gvl);
Expand Down Expand Up @@ -449,13 +451,13 @@ int CNAME(BLASLONG M, BLASLONG N, BLASLONG K, FLOAT alpha, IFLOAT *A, IFLOAT *B,
vfloat32m1_t result3 = __riscv_vfmv_v_f_f32m1(0.0f, gvl);

for (BLASLONG k=0; k < K; ++k) {
__bf16 B0 = B[bi+0];
__bf16 B1 = B[bi+1];
__bf16 B2 = B[bi+2];
__bf16 B3 = B[bi+3];
__bf16 B0 = BB[bi+0];
__bf16 B1 = BB[bi+1];
__bf16 B2 = BB[bi+2];
__bf16 B3 = BB[bi+3];
bi += 4;

vbfloat16mf2_t A0 = __riscv_vle16_v_bf16mf2( &A[ai+0*gvl], gvl );
vbfloat16mf2_t A0 = __riscv_vle16_v_bf16mf2( &AA[ai+0*gvl], gvl );
ai += 4;

result0 = __riscv_vfwmaccbf16_vf_f32m1(result0, B0, A0, gvl);
Expand Down Expand Up @@ -501,14 +503,14 @@ int CNAME(BLASLONG M, BLASLONG N, BLASLONG K, FLOAT alpha, IFLOAT *A, IFLOAT *B,
BLASLONG bi = n_top * K;

for (BLASLONG k=0; k<K; k++) {
result0+=(float)(A[ai+0])*(float)(B[bi+0]);
result1+=(float)(A[ai+1])*(float)(B[bi+0]);
result2+=(float)(A[ai+0])*(float)(B[bi+1]);
result3+=(float)(A[ai+1])*(float)(B[bi+1]);
result4+=(float)(A[ai+0])*(float)(B[bi+2]);
result5+=(float)(A[ai+1])*(float)(B[bi+2]);
result6+=(float)(A[ai+0])*(float)(B[bi+3]);
result7+=(float)(A[ai+1])*(float)(B[bi+3]);
result0+=(float)(AA[ai+0])*(float)(BB[bi+0]);
result1+=(float)(AA[ai+1])*(float)(BB[bi+0]);
result2+=(float)(AA[ai+0])*(float)(BB[bi+1]);
result3+=(float)(AA[ai+1])*(float)(BB[bi+1]);
result4+=(float)(AA[ai+0])*(float)(BB[bi+2]);
result5+=(float)(AA[ai+1])*(float)(BB[bi+2]);
result6+=(float)(AA[ai+0])*(float)(BB[bi+3]);
result7+=(float)(AA[ai+1])*(float)(BB[bi+3]);
ai+=2;
bi+=4;
}
Expand Down Expand Up @@ -537,10 +539,10 @@ int CNAME(BLASLONG M, BLASLONG N, BLASLONG K, FLOAT alpha, IFLOAT *A, IFLOAT *B,
BLASLONG bi = n_top * K;

for (BLASLONG k=0; k<K; k++) {
result0+=(float)(A[ai+0])*(float)(B[bi+0]);
result1+=(float)(A[ai+0])*(float)(B[bi+1]);
result2+=(float)(A[ai+0])*(float)(B[bi+2]);
result3+=(float)(A[ai+0])*(float)(B[bi+3]);
result0+=(float)(AA[ai+0])*(float)(BB[bi+0]);
result1+=(float)(AA[ai+0])*(float)(BB[bi+1]);
result2+=(float)(AA[ai+0])*(float)(BB[bi+2]);
result3+=(float)(AA[ai+0])*(float)(BB[bi+3]);
ai+=1;
bi+=4;
}
Expand Down Expand Up @@ -569,11 +571,11 @@ int CNAME(BLASLONG M, BLASLONG N, BLASLONG K, FLOAT alpha, IFLOAT *A, IFLOAT *B,
vfloat32m2_t result1 = __riscv_vfmv_v_f_f32m2(0.0f, gvl);

for (BLASLONG k=0; k<K; k++) {
__bf16 B0 = B[bi+0];
__bf16 B1 = B[bi+1];
__bf16 B0 = BB[bi+0];
__bf16 B1 = BB[bi+1];
bi += 2;

vbfloat16m1_t A0 = __riscv_vle16_v_bf16m1( &A[ai+0*gvl], gvl );
vbfloat16m1_t A0 = __riscv_vle16_v_bf16m1( &AA[ai+0*gvl], gvl );
ai += 16;

result0 = __riscv_vfwmaccbf16_vf_f32m2(result0, B0, A0, gvl);
Expand Down Expand Up @@ -603,11 +605,11 @@ int CNAME(BLASLONG M, BLASLONG N, BLASLONG K, FLOAT alpha, IFLOAT *A, IFLOAT *B,
vfloat32m1_t result1 = __riscv_vfmv_v_f_f32m1(0.0f, gvl);

for (BLASLONG k=0; k<K; k++) {
__bf16 B0 = B[bi+0];
__bf16 B1 = B[bi+1];
__bf16 B0 = BB[bi+0];
__bf16 B1 = BB[bi+1];
bi += 2;

vbfloat16mf2_t A0 = __riscv_vle16_v_bf16mf2( &A[ai+0*gvl], gvl );
vbfloat16mf2_t A0 = __riscv_vle16_v_bf16mf2( &AA[ai+0*gvl], gvl );
ai += 8;

result0 = __riscv_vfwmaccbf16_vf_f32m1(result0, B0, A0, gvl);
Expand Down Expand Up @@ -639,11 +641,11 @@ int CNAME(BLASLONG M, BLASLONG N, BLASLONG K, FLOAT alpha, IFLOAT *A, IFLOAT *B,
vfloat32m1_t result1 = __riscv_vfmv_v_f_f32m1(0.0f, gvl);

for (BLASLONG k=0; k < K; ++k) {
__bf16 B0 = B[bi+0];
__bf16 B1 = B[bi+1];
__bf16 B0 = BB[bi+0];
__bf16 B1 = BB[bi+1];
bi += 2;

vbfloat16mf2_t A0 = __riscv_vle16_v_bf16mf2( &A[ai+0*gvl], gvl );
vbfloat16mf2_t A0 = __riscv_vle16_v_bf16mf2( &AA[ai+0*gvl], gvl );
ai += 4;

result0 = __riscv_vfwmaccbf16_vf_f32m1(result0, B0, A0, gvl);
Expand Down Expand Up @@ -675,10 +677,10 @@ int CNAME(BLASLONG M, BLASLONG N, BLASLONG K, FLOAT alpha, IFLOAT *A, IFLOAT *B,
BLASLONG bi = n_top * K;

for (BLASLONG k=0; k<K; k++) {
result0+=(float)(A[ai+0])*(float)(B[bi+0]);
result1+=(float)(A[ai+1])*(float)(B[bi+0]);
result2+=(float)(A[ai+0])*(float)(B[bi+1]);
result3+=(float)(A[ai+1])*(float)(B[bi+1]);
result0+=(float)(AA[ai+0])*(float)(BB[bi+0]);
result1+=(float)(AA[ai+1])*(float)(BB[bi+0]);
result2+=(float)(AA[ai+0])*(float)(BB[bi+1]);
result3+=(float)(AA[ai+1])*(float)(BB[bi+1]);
ai+=2;
bi+=2;
}
Expand All @@ -701,8 +703,8 @@ int CNAME(BLASLONG M, BLASLONG N, BLASLONG K, FLOAT alpha, IFLOAT *A, IFLOAT *B,
BLASLONG bi = n_top * K;

for (BLASLONG k=0; k<K; k++) {
result0+=(float)(A[ai+0])*(float)(B[bi+0]);
result1+=(float)(A[ai+0])*(float)(B[bi+1]);
result0+=(float)(AA[ai+0])*(float)(BB[bi+0]);
result1+=(float)(AA[ai+0])*(float)(BB[bi+1]);
ai+=1;
bi+=2;
}
Expand All @@ -728,10 +730,10 @@ int CNAME(BLASLONG M, BLASLONG N, BLASLONG K, FLOAT alpha, IFLOAT *A, IFLOAT *B,
vfloat32m2_t result0 = __riscv_vfmv_v_f_f32m2(0.0f, gvl);

for (BLASLONG k=0; k<K; k++) {
__bf16 B0 = B[bi+0];
__bf16 B0 = BB[bi+0];
bi += 1;

vbfloat16m1_t A0 = __riscv_vle16_v_bf16m1( &A[ai+0*gvl], gvl );
vbfloat16m1_t A0 = __riscv_vle16_v_bf16m1( &AA[ai+0*gvl], gvl );
ai += 16;

result0 = __riscv_vfwmaccbf16_vf_f32m2(result0, B0, A0, gvl);
Expand All @@ -757,10 +759,10 @@ int CNAME(BLASLONG M, BLASLONG N, BLASLONG K, FLOAT alpha, IFLOAT *A, IFLOAT *B,
vfloat32m1_t result0 = __riscv_vfmv_v_f_f32m1(0.0f, gvl);

for (BLASLONG k=0; k<K; k++) {
__bf16 B0 = B[bi+0];
__bf16 B0 = BB[bi+0];
bi += 1;

vbfloat16mf2_t A0 = __riscv_vle16_v_bf16mf2( &A[ai+0*gvl], gvl );
vbfloat16mf2_t A0 = __riscv_vle16_v_bf16mf2( &AA[ai+0*gvl], gvl );
ai += 8;

result0 = __riscv_vfwmaccbf16_vf_f32m1(result0, B0, A0, gvl);
Expand All @@ -787,10 +789,10 @@ int CNAME(BLASLONG M, BLASLONG N, BLASLONG K, FLOAT alpha, IFLOAT *A, IFLOAT *B,
vfloat32m1_t result0 = __riscv_vfmv_v_f_f32m1(0.0f, gvl);

for (BLASLONG k=0; k < K; ++k) {
__bf16 B0 = B[bi+0];
__bf16 B0 = BB[bi+0];
bi += 1;

vbfloat16mf2_t A0 = __riscv_vle16_v_bf16mf2( &A[ai+0*gvl], gvl );
vbfloat16mf2_t A0 = __riscv_vle16_v_bf16mf2( &AA[ai+0*gvl], gvl );
ai += 4;

result0 = __riscv_vfwmaccbf16_vf_f32m1(result0, B0, A0, gvl);
Expand All @@ -814,8 +816,8 @@ int CNAME(BLASLONG M, BLASLONG N, BLASLONG K, FLOAT alpha, IFLOAT *A, IFLOAT *B,
BLASLONG bi = n_top * K;

for (BLASLONG k=0; k<K; k++) {
result0+=(float)(A[ai+0])*(float)(B[bi+0]);
result1+=(float)(A[ai+1])*(float)(B[bi+0]);
result0+=(float)(AA[ai+0])*(float)(BB[bi+0]);
result1+=(float)(AA[ai+1])*(float)(BB[bi+0]);
ai+=2;
bi+=1;
}
Expand All @@ -835,7 +837,7 @@ int CNAME(BLASLONG M, BLASLONG N, BLASLONG K, FLOAT alpha, IFLOAT *A, IFLOAT *B,
BLASLONG bi = n_top * K;

for (BLASLONG k=0; k<K; k++) {
result0+=(float)(A[ai+0])*(float)(B[bi+0]);
result0+=(float)(AA[ai+0])*(float)(BB[bi+0]);
ai+=1;
bi+=1;
}
Expand Down
Loading
Loading