diff --git a/kernel/riscv64/sbgemm_kernel_16x8_zvl256b.c b/kernel/riscv64/sbgemm_kernel_16x8_zvl256b.c
index bded873b8e..6e7b06884d 100644
--- a/kernel/riscv64/sbgemm_kernel_16x8_zvl256b.c
+++ b/kernel/riscv64/sbgemm_kernel_16x8_zvl256b.c
@@ -6,6 +6,8 @@ int CNAME(BLASLONG M, BLASLONG N, BLASLONG K, FLOAT alpha, IFLOAT *A, IFLOAT *B,
    BLASLONG gvl = 0;
    BLASLONG m_top = 0;
    BLASLONG n_top = 0;
+   __bf16 *BB = (__bf16 *)(B);
+   __bf16 *AA = (__bf16 *)(A);
 
    // -- MAIN PASS
    for (BLASLONG j=0; j

             for (i = m; i > 0; i -= vl) {
                 vl = VSETVL(i);
                 vy = VLEV_FLOAT(y_ptr, vl);
@@ -88,7 +99,7 @@ int CNAME(BLASLONG m, BLASLONG n, FLOAT alpha, IFLOAT *a, BLASLONG lda, IFLOAT *
                 y_ptr += vl;
                 a_ptr += vl;
             }
-            x += inc_x;
+            x_ptr += inc_x;
             a += lda;
         }
     } else {
@@ -110,9 +121,14 @@ int CNAME(BLASLONG m, BLASLONG n, FLOAT alpha, IFLOAT *a, BLASLONG lda, IFLOAT *
             }
         }
         for (j = 0; j < n; j++) {
-            temp = (IFLOAT)(alpha * (FLOAT)(x[0]));
+#if defined(HFLOAT16)
+            temp = (_Float16)(alpha * (FLOAT)(x_ptr[0]));
+            a_ptr = (_Float16 *)(a);
+#else
+            temp = (__bf16)(alpha * (FLOAT)(x_ptr[0]));
+            a_ptr = (__bf16 *)(a);
+#endif
             y_ptr = y;
-            a_ptr = a;
             for (i = m; i > 0; i -= vl) {
                 vl = VSETVL(i);
                 vy = VLSEV_FLOAT(y_ptr, stride_y, vl);
@@ -122,7 +138,7 @@ int CNAME(BLASLONG m, BLASLONG n, FLOAT alpha, IFLOAT *a, BLASLONG lda, IFLOAT *
                 y_ptr += vl * inc_y;
                 a_ptr += vl;
             }
-            x += inc_x;
+            x_ptr += inc_x;
             a += lda;
         }
     }
diff --git a/kernel/riscv64/sbgemv_t_vector.c b/kernel/riscv64/sbgemv_t_vector.c
index f537ca4ead..136a1f7c1f 100644
--- a/kernel/riscv64/sbgemv_t_vector.c
+++ b/kernel/riscv64/sbgemv_t_vector.c
@@ -58,7 +58,13 @@ int CNAME(BLASLONG m, BLASLONG n, FLOAT alpha, IFLOAT *a, BLASLONG lda, IFLOAT *
 {
     BLASLONG i = 0, j = 0, k = 0;
     BLASLONG ix = 0, iy = 0;
-    IFLOAT *a_ptr = a;
+#if defined(HFLOAT16)
+    _Float16 *a_ptr = (_Float16 *)(a);
+    _Float16 *x_ptr = (_Float16 *)(x);
+#else
+    __bf16 *a_ptr = (__bf16 *)(a);
+    __bf16 *x_ptr = (__bf16 *)(x);
+#endif
     FLOAT temp;
 
     IFLOAT_V_T va, vx;
@@ -79,7 +85,7 @@ int CNAME(BLASLONG m, BLASLONG n, FLOAT alpha, IFLOAT *a, BLASLONG lda, IFLOAT *
 #endif
         for (k = 0; k < m/gvl; k++) {
             va = VLEV_IFLOAT(&a_ptr[j], gvl);
-            vx = VLEV_IFLOAT(&x[j], gvl);
+            vx = VLEV_IFLOAT(&x_ptr[j], gvl);
             vr = VFMACCVV_FLOAT(vz, va, vx, gvl); // could vfmacc here and reduce outside loop
             v_res = VFREDSUM_FLOAT(vr, v_res, gvl); // but that reordering diverges far enough from scalar path to make tests fail
             j += gvl;
@@ -87,7 +93,7 @@ int CNAME(BLASLONG m, BLASLONG n, FLOAT alpha, IFLOAT *a, BLASLONG lda, IFLOAT *
         if (j < m) {
             gvl = VSETVL(m-j);
             va = VLEV_IFLOAT(&a_ptr[j], gvl);
-            vx = VLEV_IFLOAT(&x[j], gvl);
+            vx = VLEV_IFLOAT(&x_ptr[j], gvl);
             vr = VFMACCVV_FLOAT(vz, va, vx, gvl);
             v_res = VFREDSUM_FLOAT(vr, v_res, gvl);
         }
@@ -109,7 +115,7 @@ int CNAME(BLASLONG m, BLASLONG n, FLOAT alpha, IFLOAT *a, BLASLONG lda, IFLOAT *
 #endif
         for (k = 0; k < m/gvl; k++) {
             va = VLEV_IFLOAT(&a_ptr[j], gvl);
-            vx = VLSEV_IFLOAT(&x[ix], stride_x, gvl);
+            vx = VLSEV_IFLOAT(&x_ptr[ix], stride_x, gvl);
             vr = VFMACCVV_FLOAT(vz, va, vx, gvl);
             v_res = VFREDSUM_FLOAT(vr, v_res, gvl);
             j += gvl;
@@ -118,7 +124,7 @@ int CNAME(BLASLONG m, BLASLONG n, FLOAT alpha, IFLOAT *a, BLASLONG lda, IFLOAT *
         if (j < m) {
             gvl = VSETVL(m-j);
             va = VLEV_IFLOAT(&a_ptr[j], gvl);
-            vx = VLSEV_IFLOAT(&x[ix], stride_x, gvl);
+            vx = VLSEV_IFLOAT(&x_ptr[ix], stride_x, gvl);
             vr = VFMACCVV_FLOAT(vz, va, vx, gvl);
             v_res = VFREDSUM_FLOAT(vr, v_res, gvl);
         }
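
Note on the casts introduced above (editorial aside, not part of the patch): OpenBLAS backs its BF16 storage type with a plain 16-bit integer typedef, so widening an element with a bare (FLOAT) cast converts the raw bit pattern as an integer. Reinterpreting the IFLOAT pointers as __bf16 * (or _Float16 * when HFLOAT16 is defined) makes expressions such as (FLOAT)(x_ptr[0]) perform a real 16-bit-float-to-float conversion, which appears to be the point of these hunks. A minimal stand-alone sketch of the difference, using the hypothetical names bf16_raw and bf16_to_float (uint16_t stands in for the storage type; the manual shift mirrors what the compiler emits for a __bf16 widening):

#include <stdint.h>
#include <stdio.h>
#include <string.h>

typedef uint16_t bf16_raw;               /* stand-in for the integer-backed storage type */

static float bf16_to_float(bf16_raw h)
{
    uint32_t bits = (uint32_t)h << 16;   /* bf16 is the upper 16 bits of an IEEE-754 binary32 */
    float f;
    memcpy(&f, &bits, sizeof f);
    return f;
}

int main(void)
{
    bf16_raw v = 0x3FC0;                                /* bf16 encoding of 1.5 */
    printf("integer widening: %f\n", (float)v);         /* 16320.000000: raw bits treated as an int */
    printf("bf16 widening:    %f\n", bf16_to_float(v)); /* 1.500000: proper numeric conversion */
    return 0;
}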