Skip to content

Commit 950dea9

Browse files
committed
cpu: aarch64: style: remove redundant ZRegS(IDX(...)) wrappers
This is a non-functional commit which removes redundant `ZRegS(IDX(...))` wrappers in the eltwise injector.
1 parent e487fd8 commit 950dea9

File tree

1 file changed

+68
-93
lines changed

1 file changed

+68
-93
lines changed

src/cpu/aarch64/injectors/jit_uni_eltwise_injector.cpp

Lines changed: 68 additions & 93 deletions
Original file line numberDiff line numberDiff line change
@@ -468,29 +468,28 @@ template <cpu_isa_t isa>
468468
void jit_uni_eltwise_injector_t<isa>::exp_compute_vector_fwd(
469469
const TRegS &vmm_src, float min_input, float max_input) {
470470

471-
const auto &t0 = ZRegS(IDX(vmm_src));
472-
const auto &t1 = ZRegS(IDX(vmm_aux0));
473-
const auto &t2 = ZRegS(IDX(vmm_aux1));
471+
const auto &t0 = vmm_src;
472+
const auto &t1 = vmm_aux0;
473+
const auto &t2 = vmm_aux1;
474474
const float ln_flt_max = logf(FLT_MAX);
475475
const float ln_flt_min = logf(FLT_MIN);
476476
if (max_input > ln_flt_max)
477-
h->fmin(t0, p_all, ZRegS(IDX(table_val(exp_ln_flt_max_f, z_tmp))));
477+
h->fmin(t0, p_all, table_val(exp_ln_flt_max_f, z_tmp));
478478
if (min_input < ln_flt_min)
479-
h->fmax(t0, p_all, ZRegS(IDX(table_val(exp_ln_flt_min_f, z_tmp))));
480-
h->fmul(t0, t0, ZRegS(IDX(table_val(exp_log2ef, t1))));
479+
h->fmax(t0, p_all, table_val(exp_ln_flt_min_f, z_tmp));
480+
h->fmul(t0, t0, table_val(exp_log2ef, t1));
481481
h->frintm(t1, p_all, t0);
482482
h->fcvtzs(t2, p_all, t1);
483483
h->fsub(t1, t0, t1);
484-
h->fadd(t0, t1, ZRegS(IDX(table_val(one, z_tmp))));
484+
h->fadd(t0, t1, table_val(one, z_tmp));
485485
h->lsr(t1, t0, 17);
486486
h->fexpa(t1, t1);
487487
h->fscale(t1, p_all, t2);
488488
h->and_(ZRegD(t2.getIdx()), ZRegD(t0.getIdx()),
489489
ZRegD(IDX(table_val(exp_not_mask17, z_tmp))));
490490
h->fsub(t2, t0, t2);
491-
h->fmad(ZRegS(IDX(table_val(exp_coeff2, t0))), p_all, t2,
492-
ZRegS(IDX(table_val(exp_coeff1, z_tmp))));
493-
h->fmad(t0, p_all, t2, ZRegS(IDX(table_val(one, z_tmp))));
491+
h->fmad(table_val(exp_coeff2, t0), p_all, t2, table_val(exp_coeff1, z_tmp));
492+
h->fmad(t0, p_all, t2, table_val(one, z_tmp));
494493
h->fmul(t0, t1, t0);
495494
}
496495

@@ -556,25 +555,23 @@ void jit_uni_eltwise_injector_t<isa>::tanh_polynomial_approx_compute_vector_fwd(
556555
h->add_imm(h->X_TMP_1, x_table,
557556
table_off(tanh_pol_table, coeff_idx * tanh_n_polynomials),
558557
h->X_TMP_0);
559-
h->ld1w(ZRegS(IDX(vmm_coeff)), p_all,
560-
ptr(h->X_TMP_1, ZRegS(IDX(vmm_pol_idx)), SXTW));
558+
h->ld1w(vmm_coeff, p_all, ptr(h->X_TMP_1, vmm_pol_idx, SXTW));
561559
};
562560

563561
// because tanh(x) = -tanh(-x), we extract sign to make x positive
564562
// and reapply sign at the end
565563
h->fabs(vmm_src_pos, p_all / T_z, vmm_src);
566564

567565
// Compute indices for the table lookup
568-
h->sub(ZRegS(IDX(vmm_indices)), ZRegS(IDX(vmm_src_pos)),
569-
ZRegS(IDX(table_val(tanh_idx_bias, z_tmp))));
566+
h->sub(vmm_indices, vmm_src_pos, table_val(tanh_idx_bias, z_tmp));
570567
h->and_(ZRegD(IDX(vmm_indices)), ZRegD(IDX(vmm_indices)),
571568
ZRegD(IDX(table_val(tanh_idx_mask, z_tmp))));
572569
h->lsr(ZRegD(IDX(vmm_indices)), ZRegD(IDX(vmm_indices)), 20);
573570

574571
// Argument reduction
575572
h->and_(ZRegD(IDX(vmm_src_shift)), ZRegD(IDX(vmm_src_pos)),
576573
ZRegD(IDX(table_val(tanh_idx_mask, z_tmp))));
577-
h->fsub(vmm_src_pos, vmm_src_pos, ZRegS(IDX(vmm_src_shift)));
574+
h->fsub(vmm_src_pos, vmm_src_pos, vmm_src_shift);
578575

579576
gather_coefficient(vmm_pol, 6, vmm_indices);
580577
for (int deg = 5; deg >= 0; --deg) {
@@ -799,7 +796,7 @@ void jit_uni_eltwise_injector_t<isa>::soft_relu_compute_vector_fwd(
799796

800797
// calculate exp(x)
801798
// fx = x * log2ef + 0.5
802-
h->fmul(vmm_src, vmm_src, ZRegS(IDX(table_val(exp_log2ef, z_tmp))));
799+
h->fmul(vmm_src, vmm_src, table_val(exp_log2ef, z_tmp));
803800
h->fadd(vmm_src, p_all / T_m, 0.5f);
804801

805802
// tmp = floorf(fx)
@@ -809,19 +806,15 @@ void jit_uni_eltwise_injector_t<isa>::soft_relu_compute_vector_fwd(
809806
h->mov(ZRegD(IDX(vmm_src)), ZRegD(IDX(vmm_aux0)));
810807

811808
// x = x - fx * ln2
812-
h->fmul(vmm_aux0, vmm_aux0, ZRegS(IDX(table_val(ln2f, z_tmp))));
809+
h->fmul(vmm_aux0, vmm_aux0, table_val(ln2f, z_tmp));
813810
h->fsub(vmm_aux1, vmm_aux1, vmm_aux0);
814811
// compute exponent polynomial
815812
table_val(exp_pol, vmm_aux3, 4);
816-
h->fmad(vmm_aux3, p_all / T_m, vmm_aux1,
817-
ZRegS(IDX(table_val(exp_pol, z_tmp, 3))));
818-
h->fmad(vmm_aux3, p_all / T_m, vmm_aux1,
819-
ZRegS(IDX(table_val(exp_pol, z_tmp, 2))));
820-
h->fmad(vmm_aux3, p_all / T_m, vmm_aux1,
821-
ZRegS(IDX(table_val(exp_pol, z_tmp, 1))));
822-
h->fmad(vmm_aux3, p_all / T_m, vmm_aux1,
823-
ZRegS(IDX(table_val(exp_pol, z_tmp, 0))));
824-
h->fmad(vmm_aux3, p_all / T_m, vmm_aux1, ZRegS(IDX(table_val(one, z_tmp))));
813+
h->fmad(vmm_aux3, p_all / T_m, vmm_aux1, table_val(exp_pol, z_tmp, 3));
814+
h->fmad(vmm_aux3, p_all / T_m, vmm_aux1, table_val(exp_pol, z_tmp, 2));
815+
h->fmad(vmm_aux3, p_all / T_m, vmm_aux1, table_val(exp_pol, z_tmp, 1));
816+
h->fmad(vmm_aux3, p_all / T_m, vmm_aux1, table_val(exp_pol, z_tmp, 0));
817+
h->fmad(vmm_aux3, p_all / T_m, vmm_aux1, table_val(one, z_tmp));
825818

826819
// We do not count 2^-n here, because n can reach 128 and 2^(-128) is not
827820
// representable by fp32, so to get around this problem, instead of computing
@@ -849,8 +842,7 @@ void jit_uni_eltwise_injector_t<isa>::soft_relu_compute_vector_fwd(
849842
h->lsr(vmm_src, vmm_aux3, n_mantissa_bits);
850843
h->scvtf(vmm_src, p_all / T_m, vmm_src);
851844
// got n, where x = 2^n * y and y = 0.5 .. 1
852-
h->fsub(vmm_src, vmm_src,
853-
ZRegS(IDX(table_val(soft_relu_one_twenty_six, z_tmp))));
845+
h->fsub(vmm_src, vmm_src, table_val(soft_relu_one_twenty_six, z_tmp));
854846

855847
// and with mask (to get 0.5 * mantissa)
856848
h->and_(ZRegD(IDX(vmm_aux3)), ZRegD(IDX(vmm_aux3)),
@@ -865,23 +857,23 @@ void jit_uni_eltwise_injector_t<isa>::soft_relu_compute_vector_fwd(
865857
h->mov(ZRegD(IDX(vmm_aux1)),
866858
ZRegD(IDX(table_val(soft_relu_pol, z_tmp, 8))));
867859
h->fmad(vmm_aux1, p_all / T_m, vmm_aux3,
868-
ZRegS(IDX(table_val(soft_relu_pol, z_tmp, 7))));
860+
table_val(soft_relu_pol, z_tmp, 7));
869861
h->fmad(vmm_aux1, p_all / T_m, vmm_aux3,
870-
ZRegS(IDX(table_val(soft_relu_pol, z_tmp, 6))));
862+
table_val(soft_relu_pol, z_tmp, 6));
871863
h->fmad(vmm_aux1, p_all / T_m, vmm_aux3,
872-
ZRegS(IDX(table_val(soft_relu_pol, z_tmp, 5))));
864+
table_val(soft_relu_pol, z_tmp, 5));
873865
h->fmad(vmm_aux1, p_all / T_m, vmm_aux3,
874-
ZRegS(IDX(table_val(soft_relu_pol, z_tmp, 4))));
866+
table_val(soft_relu_pol, z_tmp, 4));
875867
h->fmad(vmm_aux1, p_all / T_m, vmm_aux3,
876-
ZRegS(IDX(table_val(soft_relu_pol, z_tmp, 3))));
868+
table_val(soft_relu_pol, z_tmp, 3));
877869
h->fmad(vmm_aux1, p_all / T_m, vmm_aux3,
878-
ZRegS(IDX(table_val(soft_relu_pol, z_tmp, 2))));
870+
table_val(soft_relu_pol, z_tmp, 2));
879871
h->fmad(vmm_aux1, p_all / T_m, vmm_aux3,
880-
ZRegS(IDX(table_val(soft_relu_pol, z_tmp, 1))));
872+
table_val(soft_relu_pol, z_tmp, 1));
881873
h->fmad(vmm_aux1, p_all / T_m, vmm_aux3,
882-
ZRegS(IDX(table_val(soft_relu_pol, z_tmp, 0))));
874+
table_val(soft_relu_pol, z_tmp, 0));
883875
//calculate ln(2) * n
884-
h->fmul(vmm_src, vmm_src, ZRegS(IDX(table_val(ln2f, z_tmp))));
876+
h->fmul(vmm_src, vmm_src, table_val(ln2f, z_tmp));
885877
h->fadd(vmm_src, vmm_src, vmm_aux1);
886878
h->fadd(vmm_src, vmm_src, vmm_aux0);
887879

@@ -944,11 +936,11 @@ template <cpu_isa_t isa>
944936
void jit_uni_eltwise_injector_t<isa>::log_compute_vector_fwd(
945937
const TRegS &vmm_src) {
946938

947-
const auto &t0 = ZRegS(IDX(vmm_src));
948-
const auto &t1 = ZRegS(IDX(vmm_aux1));
949-
const auto &t2 = ZRegS(IDX(vmm_aux2));
950-
const auto &t3 = ZRegS(IDX(vmm_aux3));
951-
const auto &t4 = ZRegS(IDX(vmm_aux4));
939+
const auto &t0 = vmm_src;
940+
const auto &t1 = vmm_aux1;
941+
const auto &t2 = vmm_aux2;
942+
const auto &t3 = vmm_aux3;
943+
const auto &t4 = vmm_aux4;
952944
const auto &mask = p_tmp0.s;
953945
const auto &wt0 = h->W_TMP_0;
954946
const auto &xt0 = h->X_TMP_0;
@@ -1067,39 +1059,35 @@ void jit_uni_eltwise_injector_t<
10671059
table_off(gelu_erf_minimax_pol,
10681060
coeff_idx * gelu_erf_n_polynomials),
10691061
h->X_TMP_0);
1070-
h->ld1w(ZRegS(IDX(vmm_coeff)), p_all / T_z,
1071-
ptr(h->X_TMP_1, ZRegS(IDX(vmm_pol_idx)), SXTW));
1062+
h->ld1w(vmm_coeff, p_all / T_z, ptr(h->X_TMP_1, vmm_pol_idx, SXTW));
10721063
};
10731064

10741065
// we use the erf function symmetry erf(-x) = -erf(x)
10751066
// So we make x positive, we will reapply the sign after erf evaluation
10761067
h->fabs(vmm_src_pos, p_all / T_z, vmm_src);
10771068

10781069
// Compute indices for table lookup
1079-
h->add(vmm_indices, vmm_src_pos,
1080-
ZRegS(IDX(table_val(gelu_erf_idx_bias, z_tmp, 0))));
1070+
h->add(vmm_indices, vmm_src_pos, table_val(gelu_erf_idx_bias, z_tmp, 0));
10811071

10821072
// An arithmetic shift is needed to properly map denormals to
10831073
// their polynomial. we shift by 21 as we use 2 bits of mantissa
10841074
// for indexing.
1085-
h->asr(ZRegS(IDX(vmm_indices)), ZRegS(IDX(vmm_indices)), 21);
1075+
h->asr(vmm_indices, vmm_indices, 21);
10861076

10871077
// Apply special rules
1088-
h->smax(vmm_indices, p_all / T_z,
1089-
ZRegS(IDX(table_val(gelu_erf_one, z_tmp))));
1090-
h->smin(vmm_indices, p_all / T_z,
1091-
ZRegS(IDX(table_val(gelu_erf_twenty_four, z_tmp))));
1078+
h->smax(vmm_indices, p_all / T_z, table_val(gelu_erf_one, z_tmp));
1079+
h->smin(vmm_indices, p_all / T_z, table_val(gelu_erf_twenty_four, z_tmp));
10921080

10931081
// We have to check
10941082
// index = x_pos > rbound ? 23 : index;
10951083
// for erf to return -1/1 when we should.
10961084
h->fcmlt(p_mask.s, p_all / T_z, vmm_src_pos,
1097-
ZRegS(IDX(table_val(gelu_erf_rbound, z_tmp))));
1085+
table_val(gelu_erf_rbound, z_tmp));
10981086
h->sel(vmm_indices, p_mask, vmm_indices,
1099-
ZRegS(IDX(table_val(gelu_erf_twenty_three, z_tmp))));
1087+
table_val(gelu_erf_twenty_three, z_tmp));
11001088

11011089
// Adjusting indices
1102-
h->mul(ZRegS(IDX(vmm_indices)), sizeof(float));
1090+
h->mul(vmm_indices, sizeof(float));
11031091

11041092
// Evaluate the polynomial
11051093
gather_coefficient(vmm_pol, 5, vmm_indices);
@@ -1115,9 +1103,9 @@ void jit_uni_eltwise_injector_t<
11151103
h->eor(ZRegD(IDX(vmm_pol)), p_all / T_z, ZRegD(IDX(vmm_tmp)));
11161104

11171105
// Compute the final output
1118-
h->fadd(vmm_pol, vmm_pol, ZRegS(IDX(table_val(one, z_tmp))));
1106+
h->fadd(vmm_pol, vmm_pol, table_val(one, z_tmp));
11191107
h->fmul(vmm_src, p_all / T_z, vmm_pol);
1120-
h->fmul(vmm_src, vmm_src, ZRegS(IDX(table_val(half, z_tmp))));
1108+
h->fmul(vmm_src, vmm_src, table_val(half, z_tmp));
11211109
}
11221110
template <cpu_isa_t isa>
11231111
void jit_uni_eltwise_injector_t<isa>::gelu_erf_compute_vector_fwd(
@@ -1142,8 +1130,7 @@ void jit_uni_eltwise_injector_t<isa>::gelu_erf_compute_vector_fwd(
11421130
h->mov(ZRegD(IDX(vmm_aux3)), ZRegD(IDX(vmm_src)));
11431131

11441132
// x = s / sqrt(2)
1145-
h->fmul(vmm_src, vmm_src,
1146-
ZRegS(IDX(table_val(gelu_erf_one_over_sqrt_two, z_tmp))));
1133+
h->fmul(vmm_src, vmm_src, table_val(gelu_erf_one_over_sqrt_two, z_tmp));
11471134

11481135
// abs(x)
11491136
h->fabs(vmm_aux1, p_all / T_m, vmm_src);
@@ -1169,17 +1156,13 @@ void jit_uni_eltwise_injector_t<isa>::gelu_erf_compute_vector_fwd(
11691156

11701157
// compute polynomial r
11711158
table_val(gelu_erf_pol, vmm_aux1, 4);
1172-
h->fmad(vmm_aux1, p_all / T_m, vmm_aux4,
1173-
ZRegS(IDX(table_val(gelu_erf_pol, z_tmp, 3))));
1174-
h->fmad(vmm_aux1, p_all / T_m, vmm_aux4,
1175-
ZRegS(IDX(table_val(gelu_erf_pol, z_tmp, 2))));
1176-
h->fmad(vmm_aux1, p_all / T_m, vmm_aux4,
1177-
ZRegS(IDX(table_val(gelu_erf_pol, z_tmp, 1))));
1178-
h->fmad(vmm_aux1, p_all / T_m, vmm_aux4,
1179-
ZRegS(IDX(table_val(gelu_erf_pol, z_tmp, 0))));
1159+
h->fmad(vmm_aux1, p_all / T_m, vmm_aux4, table_val(gelu_erf_pol, z_tmp, 3));
1160+
h->fmad(vmm_aux1, p_all / T_m, vmm_aux4, table_val(gelu_erf_pol, z_tmp, 2));
1161+
h->fmad(vmm_aux1, p_all / T_m, vmm_aux4, table_val(gelu_erf_pol, z_tmp, 1));
1162+
h->fmad(vmm_aux1, p_all / T_m, vmm_aux4, table_val(gelu_erf_pol, z_tmp, 0));
11801163

11811164
// erf = sign * (1 - r * t * exp(-x*x))
1182-
h->fmad(vmm_src, p_all / T_m, vmm_aux1, ZRegS(IDX(table_val(one, z_tmp))));
1165+
h->fmad(vmm_src, p_all / T_m, vmm_aux1, table_val(one, z_tmp));
11831166
h->eor(ZRegD(IDX(vmm_src)), ZRegD(IDX(vmm_src)), ZRegD(IDX(vmm_aux0)));
11841167

11851168
// S = 0.5 * s
@@ -1207,12 +1190,12 @@ void jit_uni_eltwise_injector_t<isa>::elu_compute_vector_bwd(
12071190
// after exponentiation, get mask by comparing with exp(0)=1.f, not 0.f
12081191
compute_cmp_mask(vmm_src, table_val(one, z_tmp), _cmp_gt_os);
12091192
// R * alpha, then blend with 1.f
1210-
h->fmul(vmm_src, vmm_src, ZRegS(IDX(table_val(alpha, z_tmp))));
1193+
h->fmul(vmm_src, vmm_src, table_val(alpha, z_tmp));
12111194
} else {
12121195
// get mask of `d` > 0
12131196
compute_cmp_mask(vmm_src, table_val(zero, z_tmp), _cmp_gt_os);
12141197
// R = `d` + alpha, then blend with 1.f
1215-
h->fadd(vmm_src, vmm_src, ZRegS(IDX(table_val(alpha, z_tmp))));
1198+
h->fadd(vmm_src, vmm_src, table_val(alpha, z_tmp));
12161199
}
12171200
blend_with_mask(vmm_src, table_val(one, z_tmp));
12181201
}
@@ -1241,13 +1224,12 @@ void jit_uni_eltwise_injector_t<isa>::gelu_tanh_compute_vector_bwd(
12411224
// keep G2 in a separate register
12421225
h->mov(ZRegD(IDX(vmm_aux2)),
12431226
ZRegD(IDX(table_val(gelu_tanh_fitting_const_times_three, z_tmp))));
1244-
h->fmad(vmm_aux2, p_all / T_m, vmm_src, ZRegS(IDX(table_val(one, z_tmp))));
1227+
h->fmad(vmm_aux2, p_all / T_m, vmm_src, table_val(one, z_tmp));
12451228

12461229
h->mov(ZRegD(IDX(vmm_aux1)),
12471230
ZRegD(IDX(table_val(gelu_tanh_fitting_const, z_tmp))));
1248-
h->fmad(vmm_src, p_all / T_m, vmm_aux1, ZRegS(IDX(table_val(one, z_tmp))));
1249-
h->fmul(vmm_aux0, vmm_aux0,
1250-
ZRegS(IDX(table_val(gelu_tanh_sqrt_two_over_pi, z_tmp))));
1231+
h->fmad(vmm_src, p_all / T_m, vmm_aux1, table_val(one, z_tmp));
1232+
h->fmul(vmm_aux0, vmm_aux0, table_val(gelu_tanh_sqrt_two_over_pi, z_tmp));
12511233
h->fmul(vmm_src, vmm_src, vmm_aux0);
12521234
h->fmul(vmm_aux2, vmm_aux2, vmm_aux0);
12531235

@@ -1267,11 +1249,11 @@ void jit_uni_eltwise_injector_t<isa>::gelu_tanh_compute_vector_bwd(
12671249
// 1) R = G2 * (1 - T) = G2 - G2 * T
12681250
h->fmls(vmm_aux2, p_all / T_m, vmm_aux2, vmm_src);
12691251
// 2) Q = 1 + T
1270-
h->fadd(vmm_src, vmm_src, ZRegS(IDX(table_val(one, z_tmp))));
1252+
h->fadd(vmm_src, vmm_src, table_val(one, z_tmp));
12711253
// 3) res = Q * (1 + R) = Q + Q * R
12721254
h->fmla(vmm_src, p_all / T_m, vmm_src, vmm_aux2);
12731255

1274-
h->fmul(vmm_src, vmm_src, ZRegS(IDX(table_val(half, z_tmp))));
1256+
h->fmul(vmm_src, vmm_src, table_val(half, z_tmp));
12751257
}
12761258

12771259
template <cpu_isa_t isa>
@@ -1387,7 +1369,7 @@ template <cpu_isa_t isa>
13871369
void jit_uni_eltwise_injector_t<isa>::swish_compute_vector_bwd(
13881370
const TRegS &vmm_src) {
13891371
// R = alpha * s
1390-
h->fmul(vmm_src, vmm_src, ZRegS(IDX(table_val(alpha, z_tmp))));
1372+
h->fmul(vmm_src, vmm_src, table_val(alpha, z_tmp));
13911373

13921374
// Save R on stack for later usage
13931375
h->sub_imm(h->X_SP, h->X_SP, vlen, h->X_TMP_0);
@@ -1444,8 +1426,7 @@ template <cpu_isa_t isa>
14441426
void jit_uni_eltwise_injector_t<isa>::gelu_erf_compute_vector_bwd(
14451427
const TRegS &vmm_src) {
14461428
// R = s / sqrt(2)
1447-
h->fmul(vmm_src, vmm_src,
1448-
ZRegS(IDX(table_val(gelu_erf_one_over_sqrt_two, z_tmp))));
1429+
h->fmul(vmm_src, vmm_src, table_val(gelu_erf_one_over_sqrt_two, z_tmp));
14491430

14501431
// Save R on stack for later usage
14511432
h->sub_imm(h->X_SP, h->X_SP, vlen, h->X_TMP_0);
@@ -1461,8 +1442,7 @@ void jit_uni_eltwise_injector_t<isa>::gelu_erf_compute_vector_bwd(
14611442
// T = R / sqrt(pi) * Q
14621443
h->add_imm(h->X_TMP_0, h->X_SP, 0, h->X_TMP_1);
14631444
h->ldr(ZReg(IDX(vmm_aux2)), ptr(h->X_TMP_0));
1464-
h->fmul(vmm_aux2, vmm_aux2,
1465-
ZRegS(IDX(table_val(gelu_erf_one_over_sqrt_pi, z_tmp))));
1445+
h->fmul(vmm_aux2, vmm_aux2, table_val(gelu_erf_one_over_sqrt_pi, z_tmp));
14661446
h->fmul(vmm_aux2, vmm_aux2, vmm_src);
14671447

14681448
// -Q
@@ -1494,23 +1474,19 @@ void jit_uni_eltwise_injector_t<isa>::gelu_erf_compute_vector_bwd(
14941474

14951475
// compute polynomial r
14961476
h->mov(ZRegD(IDX(vmm_aux1)), ZRegD(IDX(table_val(gelu_erf_pol, z_tmp, 4))));
1497-
h->fmad(vmm_aux1, p_all / T_m, vmm_aux4,
1498-
ZRegS(IDX(table_val(gelu_erf_pol, z_tmp, 3))));
1499-
h->fmad(vmm_aux1, p_all / T_m, vmm_aux4,
1500-
ZRegS(IDX(table_val(gelu_erf_pol, z_tmp, 2))));
1501-
h->fmad(vmm_aux1, p_all / T_m, vmm_aux4,
1502-
ZRegS(IDX(table_val(gelu_erf_pol, z_tmp, 1))));
1503-
h->fmad(vmm_aux1, p_all / T_m, vmm_aux4,
1504-
ZRegS(IDX(table_val(gelu_erf_pol, z_tmp, 0))));
1477+
h->fmad(vmm_aux1, p_all / T_m, vmm_aux4, table_val(gelu_erf_pol, z_tmp, 3));
1478+
h->fmad(vmm_aux1, p_all / T_m, vmm_aux4, table_val(gelu_erf_pol, z_tmp, 2));
1479+
h->fmad(vmm_aux1, p_all / T_m, vmm_aux4, table_val(gelu_erf_pol, z_tmp, 1));
1480+
h->fmad(vmm_aux1, p_all / T_m, vmm_aux4, table_val(gelu_erf_pol, z_tmp, 0));
15051481

15061482
// erf = sign * (1 - r * t * exp(-x*x))
1507-
h->fmad(vmm_src, p_all / T_m, vmm_aux1, ZRegS(IDX(table_val(one, z_tmp))));
1483+
h->fmad(vmm_src, p_all / T_m, vmm_aux1, table_val(one, z_tmp));
15081484
h->eor(ZRegD(IDX(vmm_src)), ZRegD(IDX(vmm_src)), ZRegD(IDX(vmm_aux0)));
15091485

15101486
// P = T + 0.5
1511-
h->fadd(vmm_aux2, vmm_aux2, ZRegS(IDX(table_val(half, z_tmp))));
1487+
h->fadd(vmm_aux2, vmm_aux2, table_val(half, z_tmp));
15121488
// res = P + 0.5 * erf
1513-
h->fmla(vmm_aux2, p_all / T_m, vmm_src, ZRegS(IDX(table_val(half, z_tmp))));
1489+
h->fmla(vmm_aux2, p_all / T_m, vmm_src, table_val(half, z_tmp));
15141490
h->mov(ZRegD(IDX(vmm_src)), ZRegD(IDX(vmm_aux2)));
15151491
}
15161492

@@ -1748,8 +1724,7 @@ void jit_uni_eltwise_injector_t<isa>::compute_body(
17481724
}
17491725
}
17501726
if (scale_ != 1.f) {
1751-
h->fmul(ZRegS(IDX(TRegS(idx))), ZRegS(IDX(TRegS(idx))),
1752-
ZRegS(IDX(table_val(scale, vmm_mask))));
1727+
h->fmul(TRegS(idx), TRegS(idx), table_val(scale, vmm_mask));
17531728
}
17541729
});
17551730
}

0 commit comments

Comments
 (0)