@@ -33,11 +33,6 @@ USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
3333#include <arm_neon.h>
3434#include "common.h"
3535
36- static inline float bf16_to_fp32 (bfloat16 bf16 ) {
37- uint32_t fp32 = (uint32_t )bf16 << 16 ;
38- return * ((float * )& fp32 );
39- }
40-
4136int CNAME (BLASLONG m , BLASLONG n , float alpha , bfloat16 * a , BLASLONG lda , bfloat16 * x , BLASLONG incx , float beta , float * y , BLASLONG incy )
4237{
4338 if (m < 1 || n < 1 ) return (0 );
@@ -132,10 +127,10 @@ int CNAME(BLASLONG m, BLASLONG n, float alpha, bfloat16 *a, BLASLONG lda, bfloat
132127 }
133128
134129 for (; i < m ; ++ i ) {
135- y0_ptr [iy ] += alpha * a0_ptr [i ] * x_ptr [i ];
136- y1_ptr [iy ] += alpha * a1_ptr [i ] * x_ptr [i ];
137- y2_ptr [iy ] += alpha * a2_ptr [i ] * x_ptr [i ];
138- y3_ptr [iy ] += alpha * a3_ptr [i ] * x_ptr [i ];
130+ y0_ptr [iy ] += alpha * vcvtah_f32_bf16 ( a0_ptr [i ]) * vcvtah_f32_bf16 ( x_ptr [i ]) ;
131+ y1_ptr [iy ] += alpha * vcvtah_f32_bf16 ( a1_ptr [i ]) * vcvtah_f32_bf16 ( x_ptr [i ]) ;
132+ y2_ptr [iy ] += alpha * vcvtah_f32_bf16 ( a2_ptr [i ]) * vcvtah_f32_bf16 ( x_ptr [i ]) ;
133+ y3_ptr [iy ] += alpha * vcvtah_f32_bf16 ( a3_ptr [i ]) * vcvtah_f32_bf16 ( x_ptr [i ]) ;
139134 }
140135
141136 iy += incy ;
@@ -177,7 +172,7 @@ int CNAME(BLASLONG m, BLASLONG n, float alpha, bfloat16 *a, BLASLONG lda, bfloat
177172 }
178173
179174 for (; i < m ; ++ i ) {
180- y_ptr [iy ] += alpha * a_ptr [i ] * x_ptr [i ];
175+ y_ptr [iy ] += alpha * vcvtah_f32_bf16 ( a_ptr [i ]) * vcvtah_f32_bf16 ( x_ptr [i ]) ;
181176 }
182177
183178 iy += incy ;
@@ -191,7 +186,7 @@ int CNAME(BLASLONG m, BLASLONG n, float alpha, bfloat16 *a, BLASLONG lda, bfloat
191186 temp = 0.0 ;
192187 ix = 0 ;
193188 for (i = 0 ; i < m ; i ++ ) {
194- temp += bf16_to_fp32 ( a [i ]) * bf16_to_fp32 ( x [ix ]);
189+ temp += vcvtah_f32_bf16 ( a_ptr [i ]) * vcvtah_f32_bf16 ( x_ptr [ix ]);
195190 ix += incx ;
196191 }
197192 if (beta == 0.0f ) {
0 commit comments