@@ -468,29 +468,28 @@ template <cpu_isa_t isa>
468468void jit_uni_eltwise_injector_t <isa>::exp_compute_vector_fwd(
469469 const TRegS &vmm_src, float min_input, float max_input) {
470470
471- const auto &t0 = ZRegS ( IDX ( vmm_src)) ;
472- const auto &t1 = ZRegS ( IDX ( vmm_aux0)) ;
473- const auto &t2 = ZRegS ( IDX ( vmm_aux1)) ;
471+ const auto &t0 = vmm_src;
472+ const auto &t1 = vmm_aux0;
473+ const auto &t2 = vmm_aux1;
474474 const float ln_flt_max = logf (FLT_MAX);
475475 const float ln_flt_min = logf (FLT_MIN);
476476 if (max_input > ln_flt_max)
477- h->fmin (t0, p_all, ZRegS ( IDX ( table_val (exp_ln_flt_max_f, z_tmp)) ));
477+ h->fmin (t0, p_all, table_val (exp_ln_flt_max_f, z_tmp));
478478 if (min_input < ln_flt_min)
479- h->fmax (t0, p_all, ZRegS ( IDX ( table_val (exp_ln_flt_min_f, z_tmp)) ));
480- h->fmul (t0, t0, ZRegS ( IDX ( table_val (exp_log2ef, t1)) ));
479+ h->fmax (t0, p_all, table_val (exp_ln_flt_min_f, z_tmp));
480+ h->fmul (t0, t0, table_val (exp_log2ef, t1));
481481 h->frintm (t1, p_all, t0);
482482 h->fcvtzs (t2, p_all, t1);
483483 h->fsub (t1, t0, t1);
484- h->fadd (t0, t1, ZRegS ( IDX ( table_val (one, z_tmp)) ));
484+ h->fadd (t0, t1, table_val (one, z_tmp));
485485 h->lsr (t1, t0, 17 );
486486 h->fexpa (t1, t1);
487487 h->fscale (t1, p_all, t2);
488488 h->and_ (ZRegD (t2.getIdx ()), ZRegD (t0.getIdx ()),
489489 ZRegD (IDX (table_val (exp_not_mask17, z_tmp))));
490490 h->fsub (t2, t0, t2);
491- h->fmad (ZRegS (IDX (table_val (exp_coeff2, t0))), p_all, t2,
492- ZRegS (IDX (table_val (exp_coeff1, z_tmp))));
493- h->fmad (t0, p_all, t2, ZRegS (IDX (table_val (one, z_tmp))));
491+ h->fmad (table_val (exp_coeff2, t0), p_all, t2, table_val (exp_coeff1, z_tmp));
492+ h->fmad (t0, p_all, t2, table_val (one, z_tmp));
494493 h->fmul (t0, t1, t0);
495494}
496495
@@ -556,25 +555,23 @@ void jit_uni_eltwise_injector_t<isa>::tanh_polynomial_approx_compute_vector_fwd(
556555 h->add_imm (h->X_TMP_1 , x_table,
557556 table_off (tanh_pol_table, coeff_idx * tanh_n_polynomials),
558557 h->X_TMP_0 );
559- h->ld1w (ZRegS (IDX (vmm_coeff)), p_all,
560- ptr (h->X_TMP_1 , ZRegS (IDX (vmm_pol_idx)), SXTW));
558+ h->ld1w (vmm_coeff, p_all, ptr (h->X_TMP_1 , vmm_pol_idx, SXTW));
561559 };
562560
563561 // because tanh(x) = -tanh(-x), we extract sign to make x postive
564562 // and reapply sign at the end
565563 h->fabs (vmm_src_pos, p_all / T_z, vmm_src);
566564
567565 // Compute indices for the table lookup
568- h->sub (ZRegS (IDX (vmm_indices)), ZRegS (IDX (vmm_src_pos)),
569- ZRegS (IDX (table_val (tanh_idx_bias, z_tmp))));
566+ h->sub (vmm_indices, vmm_src_pos, table_val (tanh_idx_bias, z_tmp));
570567 h->and_ (ZRegD (IDX (vmm_indices)), ZRegD (IDX (vmm_indices)),
571568 ZRegD (IDX (table_val (tanh_idx_mask, z_tmp))));
572569 h->lsr (ZRegD (IDX (vmm_indices)), ZRegD (IDX (vmm_indices)), 20 );
573570
574571 // Argument reduction
575572 h->and_ (ZRegD (IDX (vmm_src_shift)), ZRegD (IDX (vmm_src_pos)),
576573 ZRegD (IDX (table_val (tanh_idx_mask, z_tmp))));
577- h->fsub (vmm_src_pos, vmm_src_pos, ZRegS ( IDX ( vmm_src_shift)) );
574+ h->fsub (vmm_src_pos, vmm_src_pos, vmm_src_shift);
578575
579576 gather_coefficient (vmm_pol, 6 , vmm_indices);
580577 for (int deg = 5 ; deg >= 0 ; --deg) {
@@ -799,7 +796,7 @@ void jit_uni_eltwise_injector_t<isa>::soft_relu_compute_vector_fwd(
799796
800797 // calculate exp(x)
801798 // fx = x * log2ef + 0.5
802- h->fmul (vmm_src, vmm_src, ZRegS ( IDX ( table_val (exp_log2ef, z_tmp)) ));
799+ h->fmul (vmm_src, vmm_src, table_val (exp_log2ef, z_tmp));
803800 h->fadd (vmm_src, p_all / T_m, 0 .5f );
804801
805802 // tmp = floorf(fx)
@@ -809,19 +806,15 @@ void jit_uni_eltwise_injector_t<isa>::soft_relu_compute_vector_fwd(
809806 h->mov (ZRegD (IDX (vmm_src)), ZRegD (IDX (vmm_aux0)));
810807
811808 // x = x - fx * ln2
812- h->fmul (vmm_aux0, vmm_aux0, ZRegS ( IDX ( table_val (ln2f, z_tmp)) ));
809+ h->fmul (vmm_aux0, vmm_aux0, table_val (ln2f, z_tmp));
813810 h->fsub (vmm_aux1, vmm_aux1, vmm_aux0);
814811 // compute exponent polynomial
815812 table_val (exp_pol, vmm_aux3, 4 );
816- h->fmad (vmm_aux3, p_all / T_m, vmm_aux1,
817- ZRegS (IDX (table_val (exp_pol, z_tmp, 3 ))));
818- h->fmad (vmm_aux3, p_all / T_m, vmm_aux1,
819- ZRegS (IDX (table_val (exp_pol, z_tmp, 2 ))));
820- h->fmad (vmm_aux3, p_all / T_m, vmm_aux1,
821- ZRegS (IDX (table_val (exp_pol, z_tmp, 1 ))));
822- h->fmad (vmm_aux3, p_all / T_m, vmm_aux1,
823- ZRegS (IDX (table_val (exp_pol, z_tmp, 0 ))));
824- h->fmad (vmm_aux3, p_all / T_m, vmm_aux1, ZRegS (IDX (table_val (one, z_tmp))));
813+ h->fmad (vmm_aux3, p_all / T_m, vmm_aux1, table_val (exp_pol, z_tmp, 3 ));
814+ h->fmad (vmm_aux3, p_all / T_m, vmm_aux1, table_val (exp_pol, z_tmp, 2 ));
815+ h->fmad (vmm_aux3, p_all / T_m, vmm_aux1, table_val (exp_pol, z_tmp, 1 ));
816+ h->fmad (vmm_aux3, p_all / T_m, vmm_aux1, table_val (exp_pol, z_tmp, 0 ));
817+ h->fmad (vmm_aux3, p_all / T_m, vmm_aux1, table_val (one, z_tmp));
825818
826819 // We do not count 2^-n here, because n can reach 128 and 2^(-128) is not
827820 // representable by fp32, so to get around this problem, instead of computing
@@ -849,8 +842,7 @@ void jit_uni_eltwise_injector_t<isa>::soft_relu_compute_vector_fwd(
849842 h->lsr (vmm_src, vmm_aux3, n_mantissa_bits);
850843 h->scvtf (vmm_src, p_all / T_m, vmm_src);
851844 // got n. where n is x = 2^n * y. y = 0.5 .. 1
852- h->fsub (vmm_src, vmm_src,
853- ZRegS (IDX (table_val (soft_relu_one_twenty_six, z_tmp))));
845+ h->fsub (vmm_src, vmm_src, table_val (soft_relu_one_twenty_six, z_tmp));
854846
855847 // and with mask (to get 0.5 * mantissa)
856848 h->and_ (ZRegD (IDX (vmm_aux3)), ZRegD (IDX (vmm_aux3)),
@@ -865,23 +857,23 @@ void jit_uni_eltwise_injector_t<isa>::soft_relu_compute_vector_fwd(
865857 h->mov (ZRegD (IDX (vmm_aux1)),
866858 ZRegD (IDX (table_val (soft_relu_pol, z_tmp, 8 ))));
867859 h->fmad (vmm_aux1, p_all / T_m, vmm_aux3,
868- ZRegS ( IDX ( table_val (soft_relu_pol, z_tmp, 7 )) ));
860+ table_val (soft_relu_pol, z_tmp, 7 ));
869861 h->fmad (vmm_aux1, p_all / T_m, vmm_aux3,
870- ZRegS ( IDX ( table_val (soft_relu_pol, z_tmp, 6 )) ));
862+ table_val (soft_relu_pol, z_tmp, 6 ));
871863 h->fmad (vmm_aux1, p_all / T_m, vmm_aux3,
872- ZRegS ( IDX ( table_val (soft_relu_pol, z_tmp, 5 )) ));
864+ table_val (soft_relu_pol, z_tmp, 5 ));
873865 h->fmad (vmm_aux1, p_all / T_m, vmm_aux3,
874- ZRegS ( IDX ( table_val (soft_relu_pol, z_tmp, 4 )) ));
866+ table_val (soft_relu_pol, z_tmp, 4 ));
875867 h->fmad (vmm_aux1, p_all / T_m, vmm_aux3,
876- ZRegS ( IDX ( table_val (soft_relu_pol, z_tmp, 3 )) ));
868+ table_val (soft_relu_pol, z_tmp, 3 ));
877869 h->fmad (vmm_aux1, p_all / T_m, vmm_aux3,
878- ZRegS ( IDX ( table_val (soft_relu_pol, z_tmp, 2 )) ));
870+ table_val (soft_relu_pol, z_tmp, 2 ));
879871 h->fmad (vmm_aux1, p_all / T_m, vmm_aux3,
880- ZRegS ( IDX ( table_val (soft_relu_pol, z_tmp, 1 )) ));
872+ table_val (soft_relu_pol, z_tmp, 1 ));
881873 h->fmad (vmm_aux1, p_all / T_m, vmm_aux3,
882- ZRegS ( IDX ( table_val (soft_relu_pol, z_tmp, 0 )) ));
874+ table_val (soft_relu_pol, z_tmp, 0 ));
883875 // calculate ln(2) * n
884- h->fmul (vmm_src, vmm_src, ZRegS ( IDX ( table_val (ln2f, z_tmp)) ));
876+ h->fmul (vmm_src, vmm_src, table_val (ln2f, z_tmp));
885877 h->fadd (vmm_src, vmm_src, vmm_aux1);
886878 h->fadd (vmm_src, vmm_src, vmm_aux0);
887879
@@ -944,11 +936,11 @@ template <cpu_isa_t isa>
944936void jit_uni_eltwise_injector_t <isa>::log_compute_vector_fwd(
945937 const TRegS &vmm_src) {
946938
947- const auto &t0 = ZRegS ( IDX ( vmm_src)) ;
948- const auto &t1 = ZRegS ( IDX ( vmm_aux1)) ;
949- const auto &t2 = ZRegS ( IDX ( vmm_aux2)) ;
950- const auto &t3 = ZRegS ( IDX ( vmm_aux3)) ;
951- const auto &t4 = ZRegS ( IDX ( vmm_aux4)) ;
939+ const auto &t0 = vmm_src;
940+ const auto &t1 = vmm_aux1;
941+ const auto &t2 = vmm_aux2;
942+ const auto &t3 = vmm_aux3;
943+ const auto &t4 = vmm_aux4;
952944 const auto &mask = p_tmp0.s ;
953945 const auto &wt0 = h->W_TMP_0 ;
954946 const auto &xt0 = h->X_TMP_0 ;
@@ -1067,39 +1059,35 @@ void jit_uni_eltwise_injector_t<
10671059 table_off (gelu_erf_minimax_pol,
10681060 coeff_idx * gelu_erf_n_polynomials),
10691061 h->X_TMP_0 );
1070- h->ld1w (ZRegS (IDX (vmm_coeff)), p_all / T_z,
1071- ptr (h->X_TMP_1 , ZRegS (IDX (vmm_pol_idx)), SXTW));
1062+ h->ld1w (vmm_coeff, p_all / T_z, ptr (h->X_TMP_1 , vmm_pol_idx, SXTW));
10721063 };
10731064
10741065 // we use the erf function symmetry erf(-x) = -erf(x)
10751066 // So we make x positive, we will reapply the sign after erf evaluation
10761067 h->fabs (vmm_src_pos, p_all / T_z, vmm_src);
10771068
10781069 // Compute indices for table lookup
1079- h->add (vmm_indices, vmm_src_pos,
1080- ZRegS (IDX (table_val (gelu_erf_idx_bias, z_tmp, 0 ))));
1070+ h->add (vmm_indices, vmm_src_pos, table_val (gelu_erf_idx_bias, z_tmp, 0 ));
10811071
10821072 // An arithmetic shift is needed to properly map denormals to
10831073 // their polynomial. we shift by 21 as we use 2 bits of mantissa
10841074 // for indexing.
1085- h->asr (ZRegS ( IDX ( vmm_indices)), ZRegS ( IDX ( vmm_indices)) , 21 );
1075+ h->asr (vmm_indices, vmm_indices, 21 );
10861076
10871077 // Apply special rules
1088- h->smax (vmm_indices, p_all / T_z,
1089- ZRegS (IDX (table_val (gelu_erf_one, z_tmp))));
1090- h->smin (vmm_indices, p_all / T_z,
1091- ZRegS (IDX (table_val (gelu_erf_twenty_four, z_tmp))));
1078+ h->smax (vmm_indices, p_all / T_z, table_val (gelu_erf_one, z_tmp));
1079+ h->smin (vmm_indices, p_all / T_z, table_val (gelu_erf_twenty_four, z_tmp));
10921080
10931081 // We have to check
10941082 // index = x_pos > rbound ? 23 : index;
10951083 // for erf to return -1/1 when we should.
10961084 h->fcmlt (p_mask.s , p_all / T_z, vmm_src_pos,
1097- ZRegS ( IDX ( table_val (gelu_erf_rbound, z_tmp)) ));
1085+ table_val (gelu_erf_rbound, z_tmp));
10981086 h->sel (vmm_indices, p_mask, vmm_indices,
1099- ZRegS ( IDX ( table_val (gelu_erf_twenty_three, z_tmp)) ));
1087+ table_val (gelu_erf_twenty_three, z_tmp));
11001088
11011089 // Adjusting indices
1102- h->mul (ZRegS ( IDX ( vmm_indices)) , sizeof (float ));
1090+ h->mul (vmm_indices, sizeof (float ));
11031091
11041092 // Evaluate the polynomial
11051093 gather_coefficient (vmm_pol, 5 , vmm_indices);
@@ -1115,9 +1103,9 @@ void jit_uni_eltwise_injector_t<
11151103 h->eor (ZRegD (IDX (vmm_pol)), p_all / T_z, ZRegD (IDX (vmm_tmp)));
11161104
11171105 // Compute the final output
1118- h->fadd (vmm_pol, vmm_pol, ZRegS ( IDX ( table_val (one, z_tmp)) ));
1106+ h->fadd (vmm_pol, vmm_pol, table_val (one, z_tmp));
11191107 h->fmul (vmm_src, p_all / T_z, vmm_pol);
1120- h->fmul (vmm_src, vmm_src, ZRegS ( IDX ( table_val (half, z_tmp)) ));
1108+ h->fmul (vmm_src, vmm_src, table_val (half, z_tmp));
11211109}
11221110template <cpu_isa_t isa>
11231111void jit_uni_eltwise_injector_t <isa>::gelu_erf_compute_vector_fwd(
@@ -1142,8 +1130,7 @@ void jit_uni_eltwise_injector_t<isa>::gelu_erf_compute_vector_fwd(
11421130 h->mov (ZRegD (IDX (vmm_aux3)), ZRegD (IDX (vmm_src)));
11431131
11441132 // x = s / sqrt(2)
1145- h->fmul (vmm_src, vmm_src,
1146- ZRegS (IDX (table_val (gelu_erf_one_over_sqrt_two, z_tmp))));
1133+ h->fmul (vmm_src, vmm_src, table_val (gelu_erf_one_over_sqrt_two, z_tmp));
11471134
11481135 // abs(x)
11491136 h->fabs (vmm_aux1, p_all / T_m, vmm_src);
@@ -1169,17 +1156,13 @@ void jit_uni_eltwise_injector_t<isa>::gelu_erf_compute_vector_fwd(
11691156
11701157 // compute polynomialial r
11711158 table_val (gelu_erf_pol, vmm_aux1, 4 );
1172- h->fmad (vmm_aux1, p_all / T_m, vmm_aux4,
1173- ZRegS (IDX (table_val (gelu_erf_pol, z_tmp, 3 ))));
1174- h->fmad (vmm_aux1, p_all / T_m, vmm_aux4,
1175- ZRegS (IDX (table_val (gelu_erf_pol, z_tmp, 2 ))));
1176- h->fmad (vmm_aux1, p_all / T_m, vmm_aux4,
1177- ZRegS (IDX (table_val (gelu_erf_pol, z_tmp, 1 ))));
1178- h->fmad (vmm_aux1, p_all / T_m, vmm_aux4,
1179- ZRegS (IDX (table_val (gelu_erf_pol, z_tmp, 0 ))));
1159+ h->fmad (vmm_aux1, p_all / T_m, vmm_aux4, table_val (gelu_erf_pol, z_tmp, 3 ));
1160+ h->fmad (vmm_aux1, p_all / T_m, vmm_aux4, table_val (gelu_erf_pol, z_tmp, 2 ));
1161+ h->fmad (vmm_aux1, p_all / T_m, vmm_aux4, table_val (gelu_erf_pol, z_tmp, 1 ));
1162+ h->fmad (vmm_aux1, p_all / T_m, vmm_aux4, table_val (gelu_erf_pol, z_tmp, 0 ));
11801163
11811164 // erf = sign * (1 - r * t * exp(-x*x))
1182- h->fmad (vmm_src, p_all / T_m, vmm_aux1, ZRegS ( IDX ( table_val (one, z_tmp)) ));
1165+ h->fmad (vmm_src, p_all / T_m, vmm_aux1, table_val (one, z_tmp));
11831166 h->eor (ZRegD (IDX (vmm_src)), ZRegD (IDX (vmm_src)), ZRegD (IDX (vmm_aux0)));
11841167
11851168 // S = 0.5 * s
@@ -1207,12 +1190,12 @@ void jit_uni_eltwise_injector_t<isa>::elu_compute_vector_bwd(
12071190 // after exponentiation, get mask by comparing with exp(0)=1.f, not 0.f
12081191 compute_cmp_mask (vmm_src, table_val (one, z_tmp), _cmp_gt_os);
12091192 // R * alpha, then blend with 1.f
1210- h->fmul (vmm_src, vmm_src, ZRegS ( IDX ( table_val (alpha, z_tmp)) ));
1193+ h->fmul (vmm_src, vmm_src, table_val (alpha, z_tmp));
12111194 } else {
12121195 // get mask of `d` > 0
12131196 compute_cmp_mask (vmm_src, table_val (zero, z_tmp), _cmp_gt_os);
12141197 // R = `d` + alpha, then blend with 1.f
1215- h->fadd (vmm_src, vmm_src, ZRegS ( IDX ( table_val (alpha, z_tmp)) ));
1198+ h->fadd (vmm_src, vmm_src, table_val (alpha, z_tmp));
12161199 }
12171200 blend_with_mask (vmm_src, table_val (one, z_tmp));
12181201}
@@ -1241,13 +1224,12 @@ void jit_uni_eltwise_injector_t<isa>::gelu_tanh_compute_vector_bwd(
12411224 // keep G2 in a separate register
12421225 h->mov (ZRegD (IDX (vmm_aux2)),
12431226 ZRegD (IDX (table_val (gelu_tanh_fitting_const_times_three, z_tmp))));
1244- h->fmad (vmm_aux2, p_all / T_m, vmm_src, ZRegS ( IDX ( table_val (one, z_tmp)) ));
1227+ h->fmad (vmm_aux2, p_all / T_m, vmm_src, table_val (one, z_tmp));
12451228
12461229 h->mov (ZRegD (IDX (vmm_aux1)),
12471230 ZRegD (IDX (table_val (gelu_tanh_fitting_const, z_tmp))));
1248- h->fmad (vmm_src, p_all / T_m, vmm_aux1, ZRegS (IDX (table_val (one, z_tmp))));
1249- h->fmul (vmm_aux0, vmm_aux0,
1250- ZRegS (IDX (table_val (gelu_tanh_sqrt_two_over_pi, z_tmp))));
1231+ h->fmad (vmm_src, p_all / T_m, vmm_aux1, table_val (one, z_tmp));
1232+ h->fmul (vmm_aux0, vmm_aux0, table_val (gelu_tanh_sqrt_two_over_pi, z_tmp));
12511233 h->fmul (vmm_src, vmm_src, vmm_aux0);
12521234 h->fmul (vmm_aux2, vmm_aux2, vmm_aux0);
12531235
@@ -1267,11 +1249,11 @@ void jit_uni_eltwise_injector_t<isa>::gelu_tanh_compute_vector_bwd(
12671249 // 1) R = G2 * (1 - T) = G2 - G2 * T
12681250 h->fmls (vmm_aux2, p_all / T_m, vmm_aux2, vmm_src);
12691251 // 2) Q = 1 + T
1270- h->fadd (vmm_src, vmm_src, ZRegS ( IDX ( table_val (one, z_tmp)) ));
1252+ h->fadd (vmm_src, vmm_src, table_val (one, z_tmp));
12711253 // 3) res = Q * (1 + R) = Q + Q * R
12721254 h->fmla (vmm_src, p_all / T_m, vmm_src, vmm_aux2);
12731255
1274- h->fmul (vmm_src, vmm_src, ZRegS ( IDX ( table_val (half, z_tmp)) ));
1256+ h->fmul (vmm_src, vmm_src, table_val (half, z_tmp));
12751257}
12761258
12771259template <cpu_isa_t isa>
@@ -1387,7 +1369,7 @@ template <cpu_isa_t isa>
13871369void jit_uni_eltwise_injector_t <isa>::swish_compute_vector_bwd(
13881370 const TRegS &vmm_src) {
13891371 // R = alpha * s
1390- h->fmul (vmm_src, vmm_src, ZRegS ( IDX ( table_val (alpha, z_tmp)) ));
1372+ h->fmul (vmm_src, vmm_src, table_val (alpha, z_tmp));
13911373
13921374 // Save R on stack for later usage
13931375 h->sub_imm (h->X_SP , h->X_SP , vlen, h->X_TMP_0 );
@@ -1444,8 +1426,7 @@ template <cpu_isa_t isa>
14441426void jit_uni_eltwise_injector_t <isa>::gelu_erf_compute_vector_bwd(
14451427 const TRegS &vmm_src) {
14461428 // R = s / sqrt(2)
1447- h->fmul (vmm_src, vmm_src,
1448- ZRegS (IDX (table_val (gelu_erf_one_over_sqrt_two, z_tmp))));
1429+ h->fmul (vmm_src, vmm_src, table_val (gelu_erf_one_over_sqrt_two, z_tmp));
14491430
14501431 // Save R on stack for later usage
14511432 h->sub_imm (h->X_SP , h->X_SP , vlen, h->X_TMP_0 );
@@ -1461,8 +1442,7 @@ void jit_uni_eltwise_injector_t<isa>::gelu_erf_compute_vector_bwd(
14611442 // T = R / sqrt(pi) * Q
14621443 h->add_imm (h->X_TMP_0 , h->X_SP , 0 , h->X_TMP_1 );
14631444 h->ldr (ZReg (IDX (vmm_aux2)), ptr (h->X_TMP_0 ));
1464- h->fmul (vmm_aux2, vmm_aux2,
1465- ZRegS (IDX (table_val (gelu_erf_one_over_sqrt_pi, z_tmp))));
1445+ h->fmul (vmm_aux2, vmm_aux2, table_val (gelu_erf_one_over_sqrt_pi, z_tmp));
14661446 h->fmul (vmm_aux2, vmm_aux2, vmm_src);
14671447
14681448 // -Q
@@ -1494,23 +1474,19 @@ void jit_uni_eltwise_injector_t<isa>::gelu_erf_compute_vector_bwd(
14941474
14951475 // compute polynomial r
14961476 h->mov (ZRegD (IDX (vmm_aux1)), ZRegD (IDX (table_val (gelu_erf_pol, z_tmp, 4 ))));
1497- h->fmad (vmm_aux1, p_all / T_m, vmm_aux4,
1498- ZRegS (IDX (table_val (gelu_erf_pol, z_tmp, 3 ))));
1499- h->fmad (vmm_aux1, p_all / T_m, vmm_aux4,
1500- ZRegS (IDX (table_val (gelu_erf_pol, z_tmp, 2 ))));
1501- h->fmad (vmm_aux1, p_all / T_m, vmm_aux4,
1502- ZRegS (IDX (table_val (gelu_erf_pol, z_tmp, 1 ))));
1503- h->fmad (vmm_aux1, p_all / T_m, vmm_aux4,
1504- ZRegS (IDX (table_val (gelu_erf_pol, z_tmp, 0 ))));
1477+ h->fmad (vmm_aux1, p_all / T_m, vmm_aux4, table_val (gelu_erf_pol, z_tmp, 3 ));
1478+ h->fmad (vmm_aux1, p_all / T_m, vmm_aux4, table_val (gelu_erf_pol, z_tmp, 2 ));
1479+ h->fmad (vmm_aux1, p_all / T_m, vmm_aux4, table_val (gelu_erf_pol, z_tmp, 1 ));
1480+ h->fmad (vmm_aux1, p_all / T_m, vmm_aux4, table_val (gelu_erf_pol, z_tmp, 0 ));
15051481
15061482 // erf = sign * (1 - r * t * exp(-x*x))
1507- h->fmad (vmm_src, p_all / T_m, vmm_aux1, ZRegS ( IDX ( table_val (one, z_tmp)) ));
1483+ h->fmad (vmm_src, p_all / T_m, vmm_aux1, table_val (one, z_tmp));
15081484 h->eor (ZRegD (IDX (vmm_src)), ZRegD (IDX (vmm_src)), ZRegD (IDX (vmm_aux0)));
15091485
15101486 // P = T + 0.5
1511- h->fadd (vmm_aux2, vmm_aux2, ZRegS ( IDX ( table_val (half, z_tmp)) ));
1487+ h->fadd (vmm_aux2, vmm_aux2, table_val (half, z_tmp));
15121488 // res = P + 0.5 * erf
1513- h->fmla (vmm_aux2, p_all / T_m, vmm_src, ZRegS ( IDX ( table_val (half, z_tmp)) ));
1489+ h->fmla (vmm_aux2, p_all / T_m, vmm_src, table_val (half, z_tmp));
15141490 h->mov (ZRegD (IDX (vmm_src)), ZRegD (IDX (vmm_aux2)));
15151491}
15161492
@@ -1748,8 +1724,7 @@ void jit_uni_eltwise_injector_t<isa>::compute_body(
17481724 }
17491725 }
17501726 if (scale_ != 1 .f ) {
1751- h->fmul (ZRegS (IDX (TRegS (idx))), ZRegS (IDX (TRegS (idx))),
1752- ZRegS (IDX (table_val (scale, vmm_mask))));
1727+ h->fmul (TRegS (idx), TRegS (idx), table_val (scale, vmm_mask));
17531728 }
17541729 });
17551730}
0 commit comments