diff --git a/src/cpu/aarch64/injectors/jit_uni_eltwise_injector.cpp b/src/cpu/aarch64/injectors/jit_uni_eltwise_injector.cpp index 108d81493c7..babbd365c27 100644 --- a/src/cpu/aarch64/injectors/jit_uni_eltwise_injector.cpp +++ b/src/cpu/aarch64/injectors/jit_uni_eltwise_injector.cpp @@ -55,9 +55,11 @@ bool is_supported(cpu_isa_t isa, alg_kind_t alg) { if (is_isa_supported(isa)) { if (isa == asimd) { using namespace alg_kind; - return utils::one_of(alg, eltwise_relu, eltwise_square, eltwise_abs, - eltwise_sqrt, eltwise_linear, eltwise_exp, - eltwise_hardsigmoid, eltwise_hardswish, eltwise_clip, + return utils::one_of(alg, eltwise_relu, eltwise_tanh, eltwise_elu, + eltwise_square, eltwise_abs, eltwise_sqrt, eltwise_linear, + eltwise_soft_relu, eltwise_logistic, eltwise_mish, + eltwise_exp, eltwise_gelu_tanh, eltwise_hardsigmoid, + eltwise_hardswish, eltwise_swish, eltwise_clip, eltwise_clip_v2, eltwise_round, eltwise_clip_v2_use_dst_for_bwd); } else { @@ -232,7 +234,9 @@ void jit_uni_eltwise_injector_t::set_coef_to_regs() { if (alpha_ != 0.f) table_val(alpha, z_tmp); break; case eltwise_elu_use_dst_for_bwd: - case eltwise_elu: table_val(alpha, vmm_aux4); break; + case eltwise_elu: + table_val(alpha, isa == asimd ? vmm_aux4 : vmm_aux2); + break; case eltwise_tanh_use_dst_for_bwd: case eltwise_tanh: case eltwise_square: @@ -457,31 +461,35 @@ void jit_uni_eltwise_injector_t::blend_with_mask( template void jit_uni_eltwise_injector_t::exp_compute_vector_fwd( const TRegS &vmm_src) { + exp_compute_vector_fwd(vmm_src, min_input_, max_input_); +} - const auto &t0 = ZRegS(IDX(vmm_src)); - const auto &t1 = ZRegS(IDX(vmm_aux0)); - const auto &t2 = ZRegS(IDX(vmm_aux1)); +template +void jit_uni_eltwise_injector_t::exp_compute_vector_fwd( + const TRegS &vmm_src, float min_input, float max_input) { + + const auto &t0 = vmm_src; + const auto &t1 = vmm_aux0; + const auto &t2 = vmm_aux1; const float ln_flt_max = logf(FLT_MAX); const float ln_flt_min = logf(FLT_MIN); - - if (max_input_ > ln_flt_max) - h->fmin(t0, p_all, ZRegS(IDX(table_val(exp_ln_flt_max_f, z_tmp)))); - if (min_input_ < ln_flt_min) - h->fmax(t0, p_all, ZRegS(IDX(table_val(exp_ln_flt_min_f, z_tmp)))); - h->fmul(t0, t0, ZRegS(IDX(table_val(exp_log2ef, t1)))); + if (max_input > ln_flt_max) + h->fmin(t0, p_all, table_val(exp_ln_flt_max_f, z_tmp)); + if (min_input < ln_flt_min) + h->fmax(t0, p_all, table_val(exp_ln_flt_min_f, z_tmp)); + h->fmul(t0, t0, table_val(exp_log2ef, t1)); h->frintm(t1, p_all, t0); h->fcvtzs(t2, p_all, t1); h->fsub(t1, t0, t1); - h->fadd(t0, t1, ZRegS(IDX(table_val(one, z_tmp)))); + h->fadd(t0, t1, table_val(one, z_tmp)); h->lsr(t1, t0, 17); h->fexpa(t1, t1); h->fscale(t1, p_all, t2); h->and_(ZRegD(t2.getIdx()), ZRegD(t0.getIdx()), ZRegD(IDX(table_val(exp_not_mask17, z_tmp)))); h->fsub(t2, t0, t2); - h->fmad(ZRegS(IDX(table_val(exp_coeff2, t0))), p_all, t2, - ZRegS(IDX(table_val(exp_coeff1, z_tmp)))); - h->fmad(t0, p_all, t2, ZRegS(IDX(table_val(one, z_tmp)))); + h->fmad(table_val(exp_coeff2, t0), p_all, t2, table_val(exp_coeff1, z_tmp)); + h->fmad(t0, p_all, t2, table_val(one, z_tmp)); h->fmul(t0, t1, t0); } @@ -506,15 +514,18 @@ void jit_uni_eltwise_injector_t::relu_zero_ns_compute_vector_fwd( template void jit_uni_eltwise_injector_t::elu_compute_vector_fwd( const TRegS &vmm_src) { - // IMPORTANT: we use vmm_aux3 for the mask as exp_compute does not use it. + // IMPORTANT: we use vmm_aux3 to save src as exp does not use it. h->mov(ZRegD(vmm_aux3.getIdx()), ZRegD(vmm_src.getIdx())); // compute exponent - exp_compute_vector_fwd(vmm_src); + exp_compute_vector_fwd(vmm_src, min_input_, 0.f); - // alpha * (exp(x) - 1) - h->fsub(vmm_src, p_all / T_m, 1.); - h->fmul(vmm_src, vmm_src, vmm_aux4); + // alpha * (exp(x) - 1) = alpha * exp(x) - alpha + if (alpha_ != 1.f) { + h->fnmsb(vmm_src, p_all / T_m, vmm_aux2, vmm_aux2); + } else { + h->fsub(vmm_src, p_all / T_m, 1.); + } // combine with mask h->fcmgt(p_mask.s, p_all / T_z, vmm_aux3, 0.); @@ -532,9 +543,9 @@ void jit_uni_eltwise_injector_t::tanh_polynomial_approx_compute_vector_fwd( const int tanh_n_polynomials = 32; // Register mapping - TRegS vmm_dst = vmm_aux1, vmm_src_shift = vmm_aux1, vmm_coeff = vmm_aux1, - vmm_pol = vmm_aux2, vmm_indices = vmm_aux3, vmm_tmp = vmm_aux3, - vmm_src_pos = vmm_aux4, vmm_sign = vmm_aux4; + TRegS vmm_dst = vmm_aux0, vmm_src_shift = vmm_aux0, vmm_coeff = vmm_aux0, + vmm_pol = vmm_aux1, vmm_indices = vmm_aux2, vmm_tmp = vmm_aux2, + vmm_src_pos = vmm_aux3, vmm_sign = vmm_aux3; const auto &mask = PReg(6); // avoid pred regs used in *conv_kernel* @@ -544,8 +555,7 @@ void jit_uni_eltwise_injector_t::tanh_polynomial_approx_compute_vector_fwd( h->add_imm(h->X_TMP_1, x_table, table_off(tanh_pol_table, coeff_idx * tanh_n_polynomials), h->X_TMP_0); - h->ld1w(ZRegS(IDX(vmm_coeff)), p_all, - ptr(h->X_TMP_1, ZRegS(IDX(vmm_pol_idx)), SXTW)); + h->ld1w(vmm_coeff, p_all, ptr(h->X_TMP_1, vmm_pol_idx, SXTW)); }; // because tanh(x) = -tanh(-x), we extract sign to make x postive @@ -553,8 +563,7 @@ void jit_uni_eltwise_injector_t::tanh_polynomial_approx_compute_vector_fwd( h->fabs(vmm_src_pos, p_all / T_z, vmm_src); // Compute indices for the table lookup - h->sub(ZRegS(IDX(vmm_indices)), ZRegS(IDX(vmm_src_pos)), - ZRegS(IDX(table_val(tanh_idx_bias, z_tmp)))); + h->sub(vmm_indices, vmm_src_pos, table_val(tanh_idx_bias, z_tmp)); h->and_(ZRegD(IDX(vmm_indices)), ZRegD(IDX(vmm_indices)), ZRegD(IDX(table_val(tanh_idx_mask, z_tmp)))); h->lsr(ZRegD(IDX(vmm_indices)), ZRegD(IDX(vmm_indices)), 20); @@ -562,7 +571,7 @@ void jit_uni_eltwise_injector_t::tanh_polynomial_approx_compute_vector_fwd( // Argument reduction h->and_(ZRegD(IDX(vmm_src_shift)), ZRegD(IDX(vmm_src_pos)), ZRegD(IDX(table_val(tanh_idx_mask, z_tmp)))); - h->fsub(vmm_src_pos, vmm_src_pos, ZRegS(IDX(vmm_src_shift))); + h->fsub(vmm_src_pos, vmm_src_pos, vmm_src_shift); gather_coefficient(vmm_pol, 6, vmm_indices); for (int deg = 5; deg >= 0; --deg) { @@ -605,18 +614,17 @@ void jit_uni_eltwise_injector_t::tanh_compute_vector_fwd( // tanh(x) = x(1 + (-1/3)x^2) for |x| < tanh_range // tanh(x) = 1 - 2/(1 + exp(2 x)) for otherwise - const auto &t0 = ZRegS(IDX(vmm_src)); - const auto &t1 = ZRegS(IDX(vmm_aux1)); - const auto &t2 = ZRegS(IDX(vmm_aux2)); - const auto &t3 = ZRegS(IDX(vmm_aux3)); - const auto &oneS = ZRegS(IDX(vmm_aux4)); + const auto &t0 = vmm_src; + const auto &t1 = vmm_aux1; + const auto &t2 = vmm_aux2; + const auto &oneS = vmm_aux3; const auto &mask = PReg(6); // avoid pred regs used in *conv_kernel* h->fcpy(oneS, p_all, 1); // make mask for small x - h->mov(t3, p_all, t0); + h->mov(t2, p_all, t0); h->fabs(t1, p_all, t0); - h->cmplt(mask.s, p_all, t1, ZRegS(IDX(table_val(tanh_range, z_tmp)))); + h->cmplt(mask.s, p_all, t1, table_val(tanh_range, z_tmp)); // 2x h->fadd(t0, t0, t0); @@ -624,25 +632,15 @@ void jit_uni_eltwise_injector_t::tanh_compute_vector_fwd( exp_compute_vector_fwd(t0); // 1+exp(2x) h->fadd(t0, t0, oneS); - // 1/(1+exp(2x)) - // 1st aprox ; a = 1/x + e - h->frecpe(t1, t0); - // 2nd aprox ; a' = (2 - ax)a = 1/x - e^2 x - h->frecps(t2, t0, t1); - h->fmul(t2, t2, t1); - // 3rd aprox ; a'' = (2 - a'x)a' - h->frecps(t0, t0, t2); - h->fmul(t0, t0, t2); - // 2/(1+exp(2x)) - h->fadd(t0, t0, t0); + h->fdiv(table_val(two, t1), p_all, t0); // 1-2/(1+exp(2x)) - h->fsub(t0, oneS, t0); + h->fsub(t0, oneS, t1); // tanh(x) = x(1 - x^2/3) for |x| < tanh_range - h->fmul(t1, t3, t3); - h->fmad(t1, p_all, ZRegS(IDX(table_val(tanh_m1d3, z_tmp))), oneS); - h->fmul(t1, p_all, t3); + h->fmul(t1, t2, t2); + h->fmad(t1, p_all, table_val(tanh_m1d3, z_tmp), oneS); + h->fmul(t1, p_all, t2); // select the correct value according to mask h->mov(t0, mask, t1); } @@ -650,37 +648,22 @@ void jit_uni_eltwise_injector_t::tanh_compute_vector_fwd( template void jit_uni_eltwise_injector_t::gelu_tanh_compute_vector_fwd( const TRegS &vmm_src) { - h->mov(ZRegD(IDX(vmm_aux0)), ZRegD(IDX(vmm_src))); + // IMPORTANT: we use vmm_aux4 to save src as tanh does not use it. + h->mov(ZRegD(IDX(vmm_aux4)), ZRegD(IDX(vmm_src))); // compute G(x) = sqrt_root_two_over_pi * x * (1 + fitting_const * x * x) h->fmul(vmm_src, vmm_src, vmm_src); - h->mov(ZRegD(IDX(vmm_aux1)), - ZRegD(IDX(table_val(gelu_tanh_fitting_const, z_tmp)))); - /* Do not use 1.f, which is a float constant, - but 1., which is a double constant. */ - h->fmov(z_tmp, 1.); - h->fmad(vmm_src, p_all / T_m, vmm_aux1, ZRegS(IDX(table_val(one, z_tmp)))); - h->fmul(vmm_src, vmm_src, vmm_aux0); - h->fmul(vmm_src, vmm_src, - ZRegS(IDX(table_val(gelu_tanh_sqrt_two_over_pi, z_tmp)))); - - // save x on stack as tanh uses vmm_aux0 - h->sub_imm(h->X_SP, h->X_SP, vlen, h->X_TMP_0); - - h->add_imm(h->X_TMP_0, h->X_SP, 0, h->X_TMP_1); - h->str(ZReg(IDX(vmm_aux0)), ptr(h->X_TMP_0)); + h->fmad(vmm_src, p_all / T_m, table_val(gelu_tanh_fitting_const, vmm_aux0), + table_val(one, z_tmp)); + h->fmul(vmm_aux2, vmm_aux4, table_val(gelu_tanh_sqrt_two_over_pi, z_tmp)); + h->fmul(vmm_src, vmm_src, vmm_aux2); // compute tanh(G(x)) tanh_compute_vector_fwd(vmm_src); - h->add_imm(h->X_TMP_0, h->X_SP, 0, h->X_TMP_1); - h->ldr(ZReg(IDX(vmm_aux0)), ptr(h->X_TMP_0)); - h->add_imm(h->X_SP, h->X_SP, vlen, h->X_TMP_0); - // compute 0.5 * x * (1 + tanh(G(x))) - h->fadd(vmm_src, p_all / T_m, 1.); + h->fmad(vmm_src, p_all / T_m, vmm_aux4, vmm_aux4); h->fmul(vmm_src, p_all / T_m, 0.5f); - h->fmul(vmm_src, vmm_src, vmm_aux0); } template @@ -742,24 +725,24 @@ void jit_uni_eltwise_injector_t::mish_compute_vector_fwd( // than exp, and also requires more constants to be stored in memory, // making the algorithm slower. - // IMPORTANT: we use vmm_aux3 to save src as exp does not use it. - h->mov(ZRegD(vmm_aux3.getIdx()), ZRegD(vmm_src.getIdx())); // vmm_aux3 = x + // IMPORTANT: we use vmm_aux2 to save src as exp does not use it. + h->mov(ZRegD(vmm_aux2.getIdx()), ZRegD(vmm_src.getIdx())); + + // mish uses (exp(x)+1)^2; + // avoid overflow by x <= log(sqrt(FLT_MAX)) h->fminnm(vmm_src, p_all / T_m, table_val(fwd_mish_max_x_for_equation_f, z_tmp)); - exp_compute_vector_fwd(vmm_src); + exp_compute_vector_fwd(vmm_src, min_input_, 0.5f * logf(FLT_MAX)); // (e^x+1)^2 h->fadd(vmm_src, p_all / T_m, 1.); h->fmul(vmm_src, vmm_src, vmm_src); - // save (e^x+1)^2 as it appears in both the denominator and the numerator - h->mov(ZRegD(vmm_aux1.getIdx()), ZRegD(vmm_src.getIdx())); - // x * ((e^x + 1)^2 - 1) / ((e^x + 1)^2 + 1) + h->fadd(vmm_aux0, vmm_src, table_val(one, z_tmp)); h->fsub(vmm_src, p_all / T_m, 1.); - h->fadd(vmm_aux1, p_all / T_m, 1.); - h->fdiv(vmm_src, p_all / T_m, vmm_aux1); - h->fmul(vmm_src, vmm_src, vmm_aux3); + h->fdiv(vmm_src, p_all / T_m, vmm_aux0); + h->fmul(vmm_src, vmm_src, vmm_aux2); } template @@ -788,11 +771,13 @@ void jit_uni_eltwise_injector_t::soft_relu_compute_vector_fwd( // alpha scaling if (alpha_ == 1.f) { // Nothing to do - } - if (alpha_ == 0.5 || alpha_ == 2.0) + } else if (alpha_ == -1) { + h->fneg(vmm_src, p_all / T_m, vmm_src); + } else if (alpha_ == 0.5 || alpha_ == 2.0) { h->fmul(vmm_src, p_all / T_m, alpha_); - else + } else { h->fmul(vmm_src, vmm_src, table_val(alpha, z_tmp)); + } // ln(1 + exp(x)) = // = ln(1 + exp(n * ln(2) + r)) // divide x by ln(2) and get quot and rem @@ -805,16 +790,13 @@ void jit_uni_eltwise_injector_t::soft_relu_compute_vector_fwd( // keep src for further computations h->mov(ZRegD(IDX(vmm_aux2)), ZRegD(IDX(vmm_src))); - h->fminnm(ZRegS(IDX(table_val(exp_ln_flt_max_f, z_tmp))), p_all, vmm_src); - h->mov(ZRegD(IDX(vmm_src)), ZRegD(IDX(z_tmp))); - h->fmaxnm(ZRegS(IDX(table_val(exp_ln_flt_min_f, z_tmp))), p_all, vmm_src); - - h->mov(ZRegD(IDX(vmm_src)), ZRegD(IDX(z_tmp))); + h->fminnm(vmm_src, p_all / T_m, table_val(exp_ln_flt_max_f, z_tmp)); + h->fmaxnm(vmm_src, p_all / T_m, table_val(exp_ln_flt_min_f, z_tmp)); h->mov(ZRegD(IDX(vmm_aux1)), ZRegD(IDX(vmm_src))); // calculate exp(x) // fx = x * log2ef + 0.5 - h->fmul(vmm_src, vmm_src, ZRegS(IDX(table_val(exp_log2ef, z_tmp)))); + h->fmul(vmm_src, vmm_src, table_val(exp_log2ef, z_tmp)); h->fadd(vmm_src, p_all / T_m, 0.5f); // tmp = floorf(fx) @@ -824,19 +806,15 @@ void jit_uni_eltwise_injector_t::soft_relu_compute_vector_fwd( h->mov(ZRegD(IDX(vmm_src)), ZRegD(IDX(vmm_aux0))); // x = x - fx * ln2 - h->fmul(vmm_aux0, vmm_aux0, ZRegS(IDX(table_val(ln2f, z_tmp)))); + h->fmul(vmm_aux0, vmm_aux0, table_val(ln2f, z_tmp)); h->fsub(vmm_aux1, vmm_aux1, vmm_aux0); // compute exponent polynomial - h->mov(ZRegD(IDX(vmm_aux3)), ZRegD(IDX(table_val(exp_pol, z_tmp, 4)))); - h->fmad(vmm_aux3, p_all / T_m, vmm_aux1, - ZRegS(IDX(table_val(exp_pol, z_tmp, 3)))); - h->fmad(vmm_aux3, p_all / T_m, vmm_aux1, - ZRegS(IDX(table_val(exp_pol, z_tmp, 2)))); - h->fmad(vmm_aux3, p_all / T_m, vmm_aux1, - ZRegS(IDX(table_val(exp_pol, z_tmp, 1)))); - h->fmad(vmm_aux3, p_all / T_m, vmm_aux1, - ZRegS(IDX(table_val(exp_pol, z_tmp, 0)))); - h->fmad(vmm_aux3, p_all / T_m, vmm_aux1, ZRegS(IDX(table_val(one, z_tmp)))); + table_val(exp_pol, vmm_aux3, 4); + h->fmad(vmm_aux3, p_all / T_m, vmm_aux1, table_val(exp_pol, z_tmp, 3)); + h->fmad(vmm_aux3, p_all / T_m, vmm_aux1, table_val(exp_pol, z_tmp, 2)); + h->fmad(vmm_aux3, p_all / T_m, vmm_aux1, table_val(exp_pol, z_tmp, 1)); + h->fmad(vmm_aux3, p_all / T_m, vmm_aux1, table_val(exp_pol, z_tmp, 0)); + h->fmad(vmm_aux3, p_all / T_m, vmm_aux1, table_val(one, z_tmp)); // We do not count 2^-n here, because n can reach 128 and 2^(-128) is not // representable by fp32, so to get around this problem, instead of computing @@ -847,13 +825,12 @@ void jit_uni_eltwise_injector_t::soft_relu_compute_vector_fwd( // vmm_src now represents n-1 h->fsub(vmm_src, p_all / T_m, 1.); h->fneg(vmm_aux1, p_all / T_m, vmm_src); - - h->frinti(vmm_aux1, p_all / T_m, vmm_aux1); h->fcvtzs(vmm_aux1, p_all / T_m, vmm_aux1); // restore vmm_src to n h->fadd(vmm_src, p_all / T_m, 1.); - h->add(vmm_aux1, vmm_aux1, ZRegS(IDX(table_val(exponent_bias, z_tmp)))); + const auto exponent_bias = 0x0000007f; + h->add(vmm_aux1, exponent_bias); h->lsl(vmm_aux1, vmm_aux1, n_mantissa_bits); // calculate ln(1 + y) h->fmul(vmm_aux3, p_all / T_m, 2.); // 2*exp(r) @@ -865,8 +842,7 @@ void jit_uni_eltwise_injector_t::soft_relu_compute_vector_fwd( h->lsr(vmm_src, vmm_aux3, n_mantissa_bits); h->scvtf(vmm_src, p_all / T_m, vmm_src); // got n. where n is x = 2^n * y. y = 0.5 .. 1 - h->fsub(vmm_src, vmm_src, - ZRegS(IDX(table_val(soft_relu_one_twenty_six, z_tmp)))); + h->fsub(vmm_src, vmm_src, table_val(soft_relu_one_twenty_six, z_tmp)); // and with mask (to get 0.5 * mantissa) h->and_(ZRegD(IDX(vmm_aux3)), ZRegD(IDX(vmm_aux3)), @@ -881,23 +857,23 @@ void jit_uni_eltwise_injector_t::soft_relu_compute_vector_fwd( h->mov(ZRegD(IDX(vmm_aux1)), ZRegD(IDX(table_val(soft_relu_pol, z_tmp, 8)))); h->fmad(vmm_aux1, p_all / T_m, vmm_aux3, - ZRegS(IDX(table_val(soft_relu_pol, z_tmp, 7)))); + table_val(soft_relu_pol, z_tmp, 7)); h->fmad(vmm_aux1, p_all / T_m, vmm_aux3, - ZRegS(IDX(table_val(soft_relu_pol, z_tmp, 6)))); + table_val(soft_relu_pol, z_tmp, 6)); h->fmad(vmm_aux1, p_all / T_m, vmm_aux3, - ZRegS(IDX(table_val(soft_relu_pol, z_tmp, 5)))); + table_val(soft_relu_pol, z_tmp, 5)); h->fmad(vmm_aux1, p_all / T_m, vmm_aux3, - ZRegS(IDX(table_val(soft_relu_pol, z_tmp, 4)))); + table_val(soft_relu_pol, z_tmp, 4)); h->fmad(vmm_aux1, p_all / T_m, vmm_aux3, - ZRegS(IDX(table_val(soft_relu_pol, z_tmp, 3)))); + table_val(soft_relu_pol, z_tmp, 3)); h->fmad(vmm_aux1, p_all / T_m, vmm_aux3, - ZRegS(IDX(table_val(soft_relu_pol, z_tmp, 2)))); + table_val(soft_relu_pol, z_tmp, 2)); h->fmad(vmm_aux1, p_all / T_m, vmm_aux3, - ZRegS(IDX(table_val(soft_relu_pol, z_tmp, 1)))); + table_val(soft_relu_pol, z_tmp, 1)); h->fmad(vmm_aux1, p_all / T_m, vmm_aux3, - ZRegS(IDX(table_val(soft_relu_pol, z_tmp, 0)))); + table_val(soft_relu_pol, z_tmp, 0)); //calculate ln(2) * n - h->fmul(vmm_src, vmm_src, ZRegS(IDX(table_val(ln2f, z_tmp)))); + h->fmul(vmm_src, vmm_src, table_val(ln2f, z_tmp)); h->fadd(vmm_src, vmm_src, vmm_aux1); h->fadd(vmm_src, vmm_src, vmm_aux0); @@ -908,10 +884,9 @@ void jit_uni_eltwise_injector_t::soft_relu_compute_vector_fwd( if (alpha_ == 1.f) { // standard soft_relu case // Skip an instruction. } else if (alpha_ == -1) { // logsigmoid case - /* Do not use -1.f, which is a float constant, - but -1., which is a double constant. */ - h->fmov(z_tmp, -1.); - h->fmul(vmm_src, vmm_src, z_tmp); + h->fneg(vmm_src, p_all / T_m, vmm_src); + } else if (alpha_ == 0.5 || alpha_ == 2.0) { + h->fmul(vmm_src, p_all / T_m, 1.f / alpha_); } else { // General case. h->fdiv(vmm_src, p_all / T_m, table_val(alpha, z_tmp)); } @@ -922,64 +897,50 @@ void jit_uni_eltwise_injector_t::logistic_compute_vector_fwd( const TRegS &vmm_src) { // To avoid exp(x) overflow happened at x > logf(FLT_MAX), negate positive, // compute exp(x), where x <= 0 to get 0 <= exp(x) <= 1 and restore value - // sign at the end. This is possible due to logistic is symmetric function. - // IMPORTANT: we use vmm_aux3 for the mask as exp_compute does not use it. - h->mov(ZRegD(IDX(vmm_aux3)), ZRegD(IDX(vmm_src))); - // we store the original sign and make x negative - h->and_(ZRegD(IDX(vmm_aux3)), ZRegD(IDX(vmm_aux3)), - ZRegD(IDX(table_val(sign_mask, z_tmp)))); - h->orr(ZRegD(IDX(vmm_src)), ZRegD(IDX(vmm_src)), - ZRegD(IDX(table_val(sign_mask, z_tmp)))); + // sign at the end. This is possible due to logistic being a symmetric function. - exp_compute_vector_fwd(vmm_src); + // Store the original sign and make x negative + constexpr unsigned sign_mask = 0x80000000; + h->fcmgt(p_mask.s, p_all / T_z, vmm_src, 0.); + h->orr(vmm_src, sign_mask); + + // Compute exponent + exp_compute_vector_fwd(vmm_src, min_input_, 0.f); - // dup exp(x) - h->mov(ZRegD(IDX(vmm_aux1)), ZRegD(IDX(vmm_src))); // (exp(x) + 1) - h->fadd(vmm_aux1, vmm_aux1, ZRegS(IDX(table_val(one, z_tmp)))); + h->fadd(vmm_aux0, vmm_src, table_val(one, z_tmp)); // y = exp(x) / (exp(x) + 1) - h->fdiv(vmm_src, p_all, vmm_aux1); + h->fdiv(vmm_src, p_all, vmm_aux0); // Now we have to apply the "symmetry" based on original sign - h->mov(ZRegD(IDX(vmm_aux2)), ZRegD(IDX(table_val(one, z_tmp)))); - h->fsub(vmm_aux2, vmm_aux2, vmm_src); - - h->and_(ZRegD(IDX(z_tmp)), ZRegD(IDX(vmm_aux3)), ZRegD(IDX(vmm_aux3))); - h->cmpne(PRegS(IDX(p_mask)), p_all / T_z, z_tmp, 0); + h->fsub(vmm_aux0, z_tmp, vmm_src); - blend_with_mask(vmm_aux2, vmm_src); - - h->mov(ZRegD(IDX(vmm_src)), ZRegD(IDX(vmm_aux2))); + // Combine using mask + blend_with_mask(vmm_src, vmm_aux0); } template void jit_uni_eltwise_injector_t::swish_compute_vector_fwd( const TRegS &vmm_src) { - // Save src data on stack for later usage - h->sub_imm(h->X_SP, h->X_SP, vlen, h->X_TMP_0); - h->add_imm(h->X_TMP_0, h->X_SP, 0, h->X_TMP_1); - h->str(ZReg(IDX(vmm_src)), ptr(h->X_TMP_0)); + // IMPORTANT: we use vmm_aux2 to save src as logistic does not use it. + h->mov(ZRegD(vmm_aux2.getIdx()), ZRegD(IDX(vmm_src))); // x*alpha - h->fmul(vmm_src, vmm_src, ZRegS(IDX(table_val(alpha, z_tmp)))); + if (alpha_ != 1.f) { h->fmul(vmm_src, vmm_src, table_val(alpha, z_tmp)); } // sigmoid(x*alpha) logistic_compute_vector_fwd(vmm_src); // x*sigmoid(alpha*x) - h->add_imm(h->X_TMP_0, h->X_SP, 0, h->X_TMP_1); - h->ldr(ZReg(IDX(vmm_aux0)), ptr(h->X_TMP_0)); - h->add_imm(h->X_SP, h->X_SP, vlen, h->X_TMP_0); - - h->fmul(vmm_src, vmm_src, vmm_aux0); + h->fmul(vmm_src, vmm_src, vmm_aux2); } template void jit_uni_eltwise_injector_t::log_compute_vector_fwd( const TRegS &vmm_src) { - const auto &t0 = ZRegS(IDX(vmm_src)); - const auto &t1 = ZRegS(IDX(vmm_aux1)); - const auto &t2 = ZRegS(IDX(vmm_aux2)); - const auto &t3 = ZRegS(IDX(vmm_aux3)); - const auto &t4 = ZRegS(IDX(vmm_aux4)); + const auto &t0 = vmm_src; + const auto &t1 = vmm_aux1; + const auto &t2 = vmm_aux2; + const auto &t3 = vmm_aux3; + const auto &t4 = vmm_aux4; const auto &mask = p_tmp0.s; const auto &wt0 = h->W_TMP_0; const auto &xt0 = h->X_TMP_0; @@ -1098,8 +1059,7 @@ void jit_uni_eltwise_injector_t< table_off(gelu_erf_minimax_pol, coeff_idx * gelu_erf_n_polynomials), h->X_TMP_0); - h->ld1w(ZRegS(IDX(vmm_coeff)), p_all / T_z, - ptr(h->X_TMP_1, ZRegS(IDX(vmm_pol_idx)), SXTW)); + h->ld1w(vmm_coeff, p_all / T_z, ptr(h->X_TMP_1, vmm_pol_idx, SXTW)); }; // we use the erf function symmetry erf(-x) = -erf(x) @@ -1107,30 +1067,27 @@ void jit_uni_eltwise_injector_t< h->fabs(vmm_src_pos, p_all / T_z, vmm_src); // Compute indices for table lookup - h->add(vmm_indices, vmm_src_pos, - ZRegS(IDX(table_val(gelu_erf_idx_bias, z_tmp, 0)))); + h->add(vmm_indices, vmm_src_pos, table_val(gelu_erf_idx_bias, z_tmp, 0)); // An arithmetic shift is needed to properly map denormals to // their polynomial. we shift by 21 as we use 2 bits of mantissa // for indexing. - h->asr(ZRegS(IDX(vmm_indices)), ZRegS(IDX(vmm_indices)), 21); + h->asr(vmm_indices, vmm_indices, 21); // Apply special rules - h->smax(vmm_indices, p_all / T_z, - ZRegS(IDX(table_val(gelu_erf_one, z_tmp)))); - h->smin(vmm_indices, p_all / T_z, - ZRegS(IDX(table_val(gelu_erf_twenty_four, z_tmp)))); + h->smax(vmm_indices, p_all / T_z, table_val(gelu_erf_one, z_tmp)); + h->smin(vmm_indices, p_all / T_z, table_val(gelu_erf_twenty_four, z_tmp)); // We have to check // index = x_pos > rbound ? 23 : index; // for erf to return -1/1 when we should. h->fcmlt(p_mask.s, p_all / T_z, vmm_src_pos, - ZRegS(IDX(table_val(gelu_erf_rbound, z_tmp)))); + table_val(gelu_erf_rbound, z_tmp)); h->sel(vmm_indices, p_mask, vmm_indices, - ZRegS(IDX(table_val(gelu_erf_twenty_three, z_tmp)))); + table_val(gelu_erf_twenty_three, z_tmp)); // Adjusting indices - h->mul(ZRegS(IDX(vmm_indices)), sizeof(float)); + h->mul(vmm_indices, sizeof(float)); // Evaluate the polynomial gather_coefficient(vmm_pol, 5, vmm_indices); @@ -1146,9 +1103,9 @@ void jit_uni_eltwise_injector_t< h->eor(ZRegD(IDX(vmm_pol)), p_all / T_z, ZRegD(IDX(vmm_tmp))); // Compute the final output - h->fadd(vmm_pol, vmm_pol, ZRegS(IDX(table_val(one, z_tmp)))); + h->fadd(vmm_pol, vmm_pol, table_val(one, z_tmp)); h->fmul(vmm_src, p_all / T_z, vmm_pol); - h->fmul(vmm_src, vmm_src, ZRegS(IDX(table_val(half, z_tmp)))); + h->fmul(vmm_src, vmm_src, table_val(half, z_tmp)); } template void jit_uni_eltwise_injector_t::gelu_erf_compute_vector_fwd( @@ -1173,8 +1130,7 @@ void jit_uni_eltwise_injector_t::gelu_erf_compute_vector_fwd( h->mov(ZRegD(IDX(vmm_aux3)), ZRegD(IDX(vmm_src))); // x = s / sqrt(2) - h->fmul(vmm_src, vmm_src, - ZRegS(IDX(table_val(gelu_erf_one_over_sqrt_two, z_tmp)))); + h->fmul(vmm_src, vmm_src, table_val(gelu_erf_one_over_sqrt_two, z_tmp)); // abs(x) h->fabs(vmm_aux1, p_all / T_m, vmm_src); @@ -1200,17 +1156,13 @@ void jit_uni_eltwise_injector_t::gelu_erf_compute_vector_fwd( // compute polynomialial r table_val(gelu_erf_pol, vmm_aux1, 4); - h->fmad(vmm_aux1, p_all / T_m, vmm_aux4, - ZRegS(IDX(table_val(gelu_erf_pol, z_tmp, 3)))); - h->fmad(vmm_aux1, p_all / T_m, vmm_aux4, - ZRegS(IDX(table_val(gelu_erf_pol, z_tmp, 2)))); - h->fmad(vmm_aux1, p_all / T_m, vmm_aux4, - ZRegS(IDX(table_val(gelu_erf_pol, z_tmp, 1)))); - h->fmad(vmm_aux1, p_all / T_m, vmm_aux4, - ZRegS(IDX(table_val(gelu_erf_pol, z_tmp, 0)))); + h->fmad(vmm_aux1, p_all / T_m, vmm_aux4, table_val(gelu_erf_pol, z_tmp, 3)); + h->fmad(vmm_aux1, p_all / T_m, vmm_aux4, table_val(gelu_erf_pol, z_tmp, 2)); + h->fmad(vmm_aux1, p_all / T_m, vmm_aux4, table_val(gelu_erf_pol, z_tmp, 1)); + h->fmad(vmm_aux1, p_all / T_m, vmm_aux4, table_val(gelu_erf_pol, z_tmp, 0)); // erf = sign * (1 - r * t * exp(-x*x)) - h->fmad(vmm_src, p_all / T_m, vmm_aux1, ZRegS(IDX(table_val(one, z_tmp)))); + h->fmad(vmm_src, p_all / T_m, vmm_aux1, table_val(one, z_tmp)); h->eor(ZRegD(IDX(vmm_src)), ZRegD(IDX(vmm_src)), ZRegD(IDX(vmm_aux0))); // S = 0.5 * s @@ -1238,12 +1190,12 @@ void jit_uni_eltwise_injector_t::elu_compute_vector_bwd( // after exponentiation, get mask by comparing with exp(0)=1.f, not 0.f compute_cmp_mask(vmm_src, table_val(one, z_tmp), _cmp_gt_os); // R * alpha, then blend with 1.f - h->fmul(vmm_src, vmm_src, ZRegS(IDX(table_val(alpha, z_tmp)))); + h->fmul(vmm_src, vmm_src, table_val(alpha, z_tmp)); } else { // get mask of `d` > 0 compute_cmp_mask(vmm_src, table_val(zero, z_tmp), _cmp_gt_os); // R = `d` + alpha, then blend with 1.f - h->fadd(vmm_src, vmm_src, ZRegS(IDX(table_val(alpha, z_tmp)))); + h->fadd(vmm_src, vmm_src, table_val(alpha, z_tmp)); } blend_with_mask(vmm_src, table_val(one, z_tmp)); } @@ -1272,13 +1224,12 @@ void jit_uni_eltwise_injector_t::gelu_tanh_compute_vector_bwd( // keep G2 in a separate register h->mov(ZRegD(IDX(vmm_aux2)), ZRegD(IDX(table_val(gelu_tanh_fitting_const_times_three, z_tmp)))); - h->fmad(vmm_aux2, p_all / T_m, vmm_src, ZRegS(IDX(table_val(one, z_tmp)))); + h->fmad(vmm_aux2, p_all / T_m, vmm_src, table_val(one, z_tmp)); h->mov(ZRegD(IDX(vmm_aux1)), ZRegD(IDX(table_val(gelu_tanh_fitting_const, z_tmp)))); - h->fmad(vmm_src, p_all / T_m, vmm_aux1, ZRegS(IDX(table_val(one, z_tmp)))); - h->fmul(vmm_aux0, vmm_aux0, - ZRegS(IDX(table_val(gelu_tanh_sqrt_two_over_pi, z_tmp)))); + h->fmad(vmm_src, p_all / T_m, vmm_aux1, table_val(one, z_tmp)); + h->fmul(vmm_aux0, vmm_aux0, table_val(gelu_tanh_sqrt_two_over_pi, z_tmp)); h->fmul(vmm_src, vmm_src, vmm_aux0); h->fmul(vmm_aux2, vmm_aux2, vmm_aux0); @@ -1298,11 +1249,11 @@ void jit_uni_eltwise_injector_t::gelu_tanh_compute_vector_bwd( // 1) R = G2 * (1 - T) = G2 - G2 * T h->fmls(vmm_aux2, p_all / T_m, vmm_aux2, vmm_src); // 2) Q = 1 + T - h->fadd(vmm_src, vmm_src, ZRegS(IDX(table_val(one, z_tmp)))); + h->fadd(vmm_src, vmm_src, table_val(one, z_tmp)); // 3) res = Q * (1 + R) = Q + Q * R h->fmla(vmm_src, p_all / T_m, vmm_src, vmm_aux2); - h->fmul(vmm_src, vmm_src, ZRegS(IDX(table_val(half, z_tmp)))); + h->fmul(vmm_src, vmm_src, table_val(half, z_tmp)); } template @@ -1418,7 +1369,7 @@ template void jit_uni_eltwise_injector_t::swish_compute_vector_bwd( const TRegS &vmm_src) { // R = alpha * s - h->fmul(vmm_src, vmm_src, ZRegS(IDX(table_val(alpha, z_tmp)))); + h->fmul(vmm_src, vmm_src, table_val(alpha, z_tmp)); // Save R on stack for later usage h->sub_imm(h->X_SP, h->X_SP, vlen, h->X_TMP_0); @@ -1475,8 +1426,7 @@ template void jit_uni_eltwise_injector_t::gelu_erf_compute_vector_bwd( const TRegS &vmm_src) { // R = s / sqrt(2) - h->fmul(vmm_src, vmm_src, - ZRegS(IDX(table_val(gelu_erf_one_over_sqrt_two, z_tmp)))); + h->fmul(vmm_src, vmm_src, table_val(gelu_erf_one_over_sqrt_two, z_tmp)); // Save R on stack for later usage h->sub_imm(h->X_SP, h->X_SP, vlen, h->X_TMP_0); @@ -1492,8 +1442,7 @@ void jit_uni_eltwise_injector_t::gelu_erf_compute_vector_bwd( // T = R / sqrt(pi) * Q h->add_imm(h->X_TMP_0, h->X_SP, 0, h->X_TMP_1); h->ldr(ZReg(IDX(vmm_aux2)), ptr(h->X_TMP_0)); - h->fmul(vmm_aux2, vmm_aux2, - ZRegS(IDX(table_val(gelu_erf_one_over_sqrt_pi, z_tmp)))); + h->fmul(vmm_aux2, vmm_aux2, table_val(gelu_erf_one_over_sqrt_pi, z_tmp)); h->fmul(vmm_aux2, vmm_aux2, vmm_src); // -Q @@ -1525,23 +1474,19 @@ void jit_uni_eltwise_injector_t::gelu_erf_compute_vector_bwd( // compute polynomial r h->mov(ZRegD(IDX(vmm_aux1)), ZRegD(IDX(table_val(gelu_erf_pol, z_tmp, 4)))); - h->fmad(vmm_aux1, p_all / T_m, vmm_aux4, - ZRegS(IDX(table_val(gelu_erf_pol, z_tmp, 3)))); - h->fmad(vmm_aux1, p_all / T_m, vmm_aux4, - ZRegS(IDX(table_val(gelu_erf_pol, z_tmp, 2)))); - h->fmad(vmm_aux1, p_all / T_m, vmm_aux4, - ZRegS(IDX(table_val(gelu_erf_pol, z_tmp, 1)))); - h->fmad(vmm_aux1, p_all / T_m, vmm_aux4, - ZRegS(IDX(table_val(gelu_erf_pol, z_tmp, 0)))); + h->fmad(vmm_aux1, p_all / T_m, vmm_aux4, table_val(gelu_erf_pol, z_tmp, 3)); + h->fmad(vmm_aux1, p_all / T_m, vmm_aux4, table_val(gelu_erf_pol, z_tmp, 2)); + h->fmad(vmm_aux1, p_all / T_m, vmm_aux4, table_val(gelu_erf_pol, z_tmp, 1)); + h->fmad(vmm_aux1, p_all / T_m, vmm_aux4, table_val(gelu_erf_pol, z_tmp, 0)); // erf = sign * (1 - r * t * exp(-x*x)) - h->fmad(vmm_src, p_all / T_m, vmm_aux1, ZRegS(IDX(table_val(one, z_tmp)))); + h->fmad(vmm_src, p_all / T_m, vmm_aux1, table_val(one, z_tmp)); h->eor(ZRegD(IDX(vmm_src)), ZRegD(IDX(vmm_src)), ZRegD(IDX(vmm_aux0))); // P = T + 0.5 - h->fadd(vmm_aux2, vmm_aux2, ZRegS(IDX(table_val(half, z_tmp)))); + h->fadd(vmm_aux2, vmm_aux2, table_val(half, z_tmp)); // res = P + 0.5 * erf - h->fmla(vmm_aux2, p_all / T_m, vmm_src, ZRegS(IDX(table_val(half, z_tmp)))); + h->fmla(vmm_aux2, p_all / T_m, vmm_src, table_val(half, z_tmp)); h->mov(ZRegD(IDX(vmm_src)), ZRegD(IDX(vmm_aux2))); } @@ -1607,22 +1552,25 @@ size_t jit_uni_eltwise_injector_t::aux_vecs_count() { case eltwise_relu_use_dst_for_bwd: case eltwise_relu: return (alpha_ == 0.f) ? 2 : 3; case eltwise_elu_use_dst_for_bwd: - case eltwise_elu: return 6; /* = exp + 2 */ + case eltwise_elu: return (isa == asimd) ? 7 : 5; /* = exp + 2 */ case eltwise_tanh_use_dst_for_bwd: - case eltwise_tanh: return 9; + case eltwise_tanh: return (isa == asimd) ? 7 : 5; /* = exp + 2 */ case eltwise_square: return 0; case eltwise_abs: return 2; case eltwise_sqrt_use_dst_for_bwd: case eltwise_sqrt: return 0; case eltwise_linear: return 2; - case eltwise_soft_relu: return 5; - case eltwise_mish: return 5; /* = exp + 1 */ + case eltwise_soft_relu: return (isa == asimd) ? 7 : 5; + case eltwise_mish: return (isa == asimd) ? 6 : 4; /* = exp + 1 */ case eltwise_logistic_use_dst_for_bwd: - case eltwise_logistic: return 5; /* = exp + 1 */ + case eltwise_logistic: + return (isa == asimd) ? 6 : 3; /* = exp (+ 1) */ case eltwise_exp_use_dst_for_bwd: case eltwise_exp: return (isa == asimd) ? 5 : 3; - case eltwise_gelu_tanh: return 9; /* = tanh */ - case eltwise_swish: return 6; /* = logistic */ + case eltwise_gelu_tanh: + return (isa == asimd) ? 8 : 6; /* = tanh + 1 */ + case eltwise_swish: + return (isa == asimd) ? 7 : 4; /* = logistic + 1 */ case eltwise_log: return 6; case eltwise_clip: case eltwise_clip_v2_use_dst_for_bwd: @@ -1776,8 +1724,7 @@ void jit_uni_eltwise_injector_t::compute_body( } } if (scale_ != 1.f) { - h->fmul(ZRegS(IDX(TRegS(idx))), ZRegS(IDX(TRegS(idx))), - ZRegS(IDX(table_val(scale, vmm_mask)))); + h->fmul(TRegS(idx), TRegS(idx), table_val(scale, vmm_mask)); } }); } @@ -2520,7 +2467,7 @@ size_t jit_uni_eltwise_injector_t::get_vec_len() { template <> void jit_uni_eltwise_injector_t::exp_compute_vector_fwd( - const TRegS &vmm_src) { + const TRegS &vmm_src, float min_input, float max_input) { /* * Based on the expf implementation from Arm Optimized Routines (AOR) * Accuracy: maxerr ≈ 1.95 ULP. @@ -2536,17 +2483,17 @@ void jit_uni_eltwise_injector_t::exp_compute_vector_fwd( */ Xbyak_aarch64::Label L_done, L_special; - const auto &t0 = VReg4S(vmm_src.getIdx()); - const auto &t1 = VReg4S(vmm_aux0.getIdx()); - const auto &t2 = VReg4S(vmm_aux1.getIdx()); - const auto &t3 = VReg4S(vmm_aux2.getIdx()); - const auto &t4 = VReg4S(vmm_aux3.getIdx()); - const auto &t_tmp = VReg4S(vmm_tmp.getIdx()); + const auto &t0 = vmm_src; + const auto &t1 = vmm_aux0; + const auto &t2 = vmm_aux1; + const auto &t3 = vmm_aux2; + const auto &t4 = vmm_aux3; + const auto &t_tmp = vmm_tmp; const float special_case_input_threshold = 126.5f * logf(2.0f); // ~87.6831f const float ln_flt_min = logf(FLT_MIN); // ~-87.3365f - bool need_clamp = min_input_ < ln_flt_min; - bool need_special_case = max_input_ >= special_case_input_threshold; + bool need_clamp = min_input < ln_flt_min; + bool need_special_case = max_input >= special_case_input_threshold; if (!need_special_case && need_clamp) { // Clamp x to avoid overflow of f32 exponent bits @@ -2695,6 +2642,86 @@ void jit_uni_eltwise_injector_t::relu_compute_vector_fwd( h->fmaxnm(vmm_src, vmm_src, vmm_aux0); } +template <> +void jit_uni_eltwise_injector_t::elu_compute_vector_fwd( + const TRegS &vmm_src) { + // IMPORTANT: keep original src in vmm_src for blending positive lanes. + h->mov(VReg16B(vmm_aux5.getIdx()), VReg16B(vmm_src.getIdx())); + + // Compute exponent + exp_compute_vector_fwd(vmm_aux5, min_input_, 0.f); + + // alpha * (exp(x) - 1) = -alpha - (-alpha * exp(x)) + if (alpha_ != 1.f) { + h->fneg(vmm_aux1, vmm_aux4); + h->fmls(vmm_aux1, vmm_aux1, vmm_aux5); + } else { + h->fsub(vmm_aux1, vmm_aux5, vmm_aux4); + } + + // Combine with mask + h->fcmle(vmm_aux0, vmm_src, 0.); + h->bit(VReg16B(vmm_src.getIdx()), VReg16B(vmm_aux1.getIdx()), + VReg16B(vmm_aux0.getIdx())); +} + +template <> +void jit_uni_eltwise_injector_t::tanh_compute_vector_fwd( + const TRegS &vmm_src) { + // tanh(x) = x(1 + (-1/3)x^2) for |x| < tanh_range + // tanh(x) = 1 - 2/(1 + exp(2 x)) for otherwise + + const auto &t0 = vmm_src; + const auto &t1 = vmm_aux0; + const auto &t2 = vmm_aux1; + const auto &t3 = vmm_aux4; + const auto &mask = vmm_aux5; + + // Make mask for small x + h->mov(VReg16B(t3.getIdx()), VReg16B(t0.getIdx())); + h->facge(mask, table_val(tanh_range, z_tmp), t0); + + // 2x + h->fadd(t0, t0, t0); + // exp(2x) + exp_compute_vector_fwd(t0); + + // 1+exp(2x) + h->fadd(t0, t0, table_val(one, t2)); + // 2/(1+exp(2x)) + h->fdiv(t0, table_val(two, t1), t0); + // 1-2/(1+exp(2x)) + h->fsub(t0, t2, t0); + + // tanh(x) = x(1 - x^2/3) for |x| < tanh_range + h->fmul(t1, t3, t3); + h->fmla(t2, table_val(tanh_m1d3, z_tmp), t1); + h->fmul(t1, t3, t2); + // Select the correct value according to mask + h->bit(VReg16B(t0.getIdx()), VReg16B(t1.getIdx()), VReg16B(mask.getIdx())); +} + +template <> +void jit_uni_eltwise_injector_t::gelu_tanh_compute_vector_fwd( + const TRegS &vmm_src) { + // IMPORTANT: we use vmm_aux6 to save src as tanh does not use it. + h->mov(VReg16B(IDX(vmm_aux6)), VReg16B(IDX(vmm_src))); + + // Compute G(x) = sqrt_root_two_over_pi * x * (1 + fitting_const * x * x) + h->fmul(vmm_src, vmm_src, vmm_src); + h->fmla(table_val(one, vmm_aux1), + table_val(gelu_tanh_fitting_const, vmm_aux0), vmm_src); + h->fmul(vmm_aux2, vmm_aux6, table_val(gelu_tanh_sqrt_two_over_pi, z_tmp)); + h->fmul(vmm_src, vmm_aux1, vmm_aux2); + + // Compute tanh(G(x)) + tanh_compute_vector_fwd(vmm_src); + + // Compute 0.5 * x * (1 + tanh(G(x))) + h->fmla(vmm_aux6, vmm_src, vmm_aux6); + h->fmul(vmm_src, vmm_aux6, table_val(half, z_tmp)); +} + template <> void jit_uni_eltwise_injector_t::abs_compute_vector_fwd( const TReg &vmm_src) { @@ -2733,6 +2760,40 @@ void jit_uni_eltwise_injector_t::clip_compute_vector_fwd( } } +template <> +void jit_uni_eltwise_injector_t::mish_compute_vector_fwd( + const TRegS &vmm_src) { + // An equation other than mish(x) = x*tanh(srelu(x)) was used + // to calculate mish, but it should be remembered that it is equivalent + // equation, it uses the following rule: + // tanh(x) = (e^x - e^-x) / (e^x + e^-x), + // hence the equation for mish can take the form: + // mish(x) = x * ((e^x + 1)^2 - 1)/((e^x + 1)^2 + 1). + // This option was chosen because computing tanh requires more registers + // than exp, and also requires more constants to be stored in memory, + // making the algorithm slower. + + // IMPORTANT: we use vmm_aux4 to save src as exp does not use it. + h->mov(VReg16B(vmm_aux4.getIdx()), + VReg16B(vmm_src.getIdx())); // vmm_aux4 = x + + // mish uses (exp(x)+1)^2; + // avoid overflow by x <= log(sqrt(FLT_MAX)) + h->fminnm( + vmm_src, vmm_src, table_val(fwd_mish_max_x_for_equation_f, z_tmp)); + exp_compute_vector_fwd(vmm_src, min_input_, 0.5f * logf(FLT_MAX)); + + // (e^x+1)^2 + h->fadd(vmm_src, vmm_src, table_val(one, z_tmp)); + h->fmul(vmm_src, vmm_src, vmm_src); + + // x * ((e^x + 1)^2 - 1) / ((e^x + 1)^2 + 1) + h->fsub(vmm_aux0, vmm_src, z_tmp); + h->fadd(vmm_aux1, vmm_src, z_tmp); + h->fdiv(vmm_src, vmm_aux0, vmm_aux1); + h->fmul(vmm_src, vmm_src, vmm_aux4); +} + template <> void jit_uni_eltwise_injector_t::hardsigmoid_compute_vector_fwd( const TRegS &vmm_src) { @@ -2753,6 +2814,168 @@ void jit_uni_eltwise_injector_t::hardswish_compute_vector_fwd( h->fmul(vmm_src, vmm_src, vmm_aux2); } +template <> +void jit_uni_eltwise_injector_t::soft_relu_compute_vector_fwd( + const TRegS &vmm_src) { + + const auto &oneS = table_val(one, vmm_aux4); + const auto &halfS = table_val(half, vmm_aux5); + + // alpha scaling + if (alpha_ == 1.f) { + // Nothing to do + } else if (alpha_ == -1) { + h->fneg(vmm_src, vmm_src); + } else { + h->fmul(vmm_src, vmm_src, table_val(alpha, z_tmp)); + } + + // ln(1 + exp(x)) = + // = ln(1 + exp(n * ln(2) + r)) // divide x by ln(2) and get quot and rem + // = ln(1 + 2^n * exp(r)) // simplify the exp(n*ln(2)) expression + // = ln(2 ^ 0 + 2^n * exp(r)) // note 1 = 2^0 + // = ln(2 ^ (n - n) + 2^n * exp(r)) // 2^0 = 2^(n-n) + // = ln(2 ^ n * (2^-n + exp(r))) // factorize with 2^n + // = n * ln(2) + ln(2^-n + exp(r)) // take the 2^n factor out of the ln + + // Keep src for further computations + h->mov(VReg16B(IDX(vmm_aux2)), VReg16B(IDX(vmm_src))); + + h->fminnm(vmm_src, table_val(exp_ln_flt_max_f, z_tmp), vmm_src); + h->fmaxnm(vmm_src, table_val(exp_ln_flt_min_f, z_tmp), vmm_src); + h->mov(VReg16B(IDX(vmm_aux1)), VReg16B(IDX(vmm_src))); + + // Calculate exp(x) + // fx = x * log2ef + 0.5 + h->fmul(vmm_src, vmm_src, table_val(exp_log2ef, z_tmp)); + h->fadd(vmm_src, vmm_src, halfS); + + // tmp = floorf(fx) + h->frintm(vmm_aux0, vmm_src); + + // Keep vmm_src = fx for further computations + h->mov(VReg16B(IDX(vmm_src)), VReg16B(IDX(vmm_aux0))); + + // x = x - fx * ln2 + h->fmul(vmm_aux0, vmm_aux0, table_val(ln2f, z_tmp)); + h->fsub(vmm_aux1, vmm_aux1, vmm_aux0); + // Compute exponent polynomial + table_val(exp_pol, z_tmp, 4); + h->fmla(table_val(exp_pol, vmm_aux3, 3), z_tmp, vmm_aux1); + h->fmla(table_val(exp_pol, z_tmp, 2), vmm_aux3, vmm_aux1); + h->fmla(table_val(exp_pol, vmm_aux3, 1), z_tmp, vmm_aux1); + h->fmla(table_val(exp_pol, z_tmp, 0), vmm_aux3, vmm_aux1); + h->mov(VReg16B(IDX(vmm_aux3)), VReg16B(IDX(oneS))); + h->fmla(vmm_aux3, z_tmp, vmm_aux1); + + // We do not count 2^-n here, because n can reach 128 and 2^(-128) is not + // representable by fp32, so to get around this problem, instead of computing + // 2^-n + exp(r) will be counted (2^-(n-1) + 2*exp(r))/2, because 2^(-127) + // and 2 are numbers representable in fp32. + + // Compute 2^-(n-1) + h->fsub(vmm_aux1, vmm_src, oneS); + h->fneg(vmm_aux1, vmm_aux1); + h->fcvtzs(vmm_aux1, vmm_aux1); + h->add(vmm_aux1, vmm_aux1, table_val(exponent_bias, z_tmp)); + h->shl(vmm_aux1, vmm_aux1, n_mantissa_bits); + + // Calculate ln(1 + y) + h->fadd(vmm_aux3, vmm_aux3, vmm_aux3); // 2*exp(r) + h->fadd(vmm_aux3, vmm_aux3, + vmm_aux1); // 2^-(n-1) + 2*exp(r) + h->fmul(vmm_aux3, vmm_aux3, halfS); // (2^-(n-1) + 2*exp(r))/2 + + // frexp() + h->ushr(vmm_src, vmm_aux3, n_mantissa_bits); + h->scvtf(vmm_src, vmm_src); + // Got n. where n is x = 2^n * y. y = 0.5 .. 1 + h->fsub(vmm_src, vmm_src, table_val(soft_relu_one_twenty_six, z_tmp)); + + // And with mask (to get 0.5 * mantissa) + h->and_(VReg16B(IDX(vmm_aux3)), VReg16B(IDX(vmm_aux3)), + VReg16B(IDX(table_val(soft_relu_mantissa_sign_mask, z_tmp)))); + // Got y. (mantisa) 0.5 < y < 1 (or with (to get 0.5 * mantissa)) + h->orr(VReg16B(IDX(vmm_aux3)), VReg16B(IDX(vmm_aux3)), VReg16B(IDX(halfS))); + // y = y - 1 + h->fsub(vmm_aux3, vmm_aux3, oneS); + + // Compute log1p polynomial + table_val(soft_relu_pol, vmm_aux1, 8); + h->fmla(table_val(soft_relu_pol, z_tmp, 7), vmm_aux1, vmm_aux3); + h->fmla(table_val(soft_relu_pol, vmm_aux1, 6), z_tmp, vmm_aux3); + h->fmla(table_val(soft_relu_pol, z_tmp, 5), vmm_aux1, vmm_aux3); + h->fmla(table_val(soft_relu_pol, vmm_aux1, 4), z_tmp, vmm_aux3); + h->fmla(table_val(soft_relu_pol, z_tmp, 3), vmm_aux1, vmm_aux3); + h->fmla(table_val(soft_relu_pol, vmm_aux1, 2), z_tmp, vmm_aux3); + h->fmla(table_val(soft_relu_pol, z_tmp, 1), vmm_aux1, vmm_aux3); + h->fmla(table_val(soft_relu_pol, vmm_aux1, 0), z_tmp, vmm_aux3); + + // Calculate ln(2) * n + h->fmul(vmm_src, vmm_src, table_val(ln2f, z_tmp)); + h->fadd(vmm_src, vmm_src, vmm_aux1); + h->fadd(vmm_src, vmm_src, vmm_aux0); + + // Get vmm_mask = src > max logf + // y = (x < max log f) ? soft_relu(x) : x + h->fcmgt(vmm_aux1, vmm_aux2, table_val(exp_ln_flt_max_f, z_tmp)); + h->bit(VReg16B(IDX(vmm_src)), VReg16B(IDX(vmm_aux2)), + VReg16B(IDX(vmm_aux1))); + if (alpha_ == 1.f) { // standard soft_relu case + // Skip an instruction. + } else if (alpha_ == -1) { // logsigmoid case + h->fneg(vmm_src, vmm_src); + } else if (alpha_ == 0.5) { + h->fadd(vmm_src, vmm_src, vmm_src); + } else if (alpha_ == 2.0) { + h->fmul(vmm_src, vmm_src, halfS); + } else { // General case. + h->fdiv(vmm_src, vmm_src, table_val(alpha, z_tmp)); + } +} + +template <> +void jit_uni_eltwise_injector_t::logistic_compute_vector_fwd( + const TRegS &vmm_src) { + // To avoid exp(x) overflow happening at x > logf(FLT_MAX), negate positive, + // compute exp(x), where x <= 0 to get 0 <= exp(x) <= 1 and restore value + // sign at the end. This is possible due to logistic being a symmetric function. + // IMPORTANT: we use vmm_aux4 for the mask as exp_compute does not use it. + h->fcmgt(vmm_aux4, vmm_src, 0.); + + // Force input negative + h->orr(VReg16B(IDX(vmm_src)), VReg16B(IDX(vmm_src)), + VReg16B(IDX(table_val(sign_mask, z_tmp)))); + + // Compute exponent + exp_compute_vector_fwd(vmm_src, min_input_, 0.f); + + // (exp(x) + 1) + h->fadd(vmm_aux0, vmm_src, table_val(one, z_tmp)); + // y = exp(x) / (exp(x) + 1) + h->fdiv(vmm_src, vmm_src, vmm_aux0); + + // Now we have to apply the "symmetry" based on original sign + h->fsub(vmm_aux0, z_tmp, vmm_src); + + // Combine with mask + h->bit(VReg16B(IDX(vmm_src)), VReg16B(IDX(vmm_aux0)), + VReg16B(IDX(vmm_aux4))); +} + +template <> +void jit_uni_eltwise_injector_t::swish_compute_vector_fwd( + const TRegS &vmm_src) { + // IMPORTANT: we use vmm_aux5 to save src as logistic does not use it. + h->mov(VReg16B(vmm_aux5.getIdx()), VReg16B(vmm_src.getIdx())); + // x*alpha + if (alpha_ != 1.f) { h->fmul(vmm_src, vmm_src, table_val(alpha, z_tmp)); } + // sigmoid(x*alpha) + logistic_compute_vector_fwd(vmm_src); + // x*sigmoid(alpha*x) + h->fmul(vmm_src, vmm_src, vmm_aux5); +} + template <> void jit_uni_eltwise_injector_t::round_compute_vector_fwd( const TRegS &vmm_src) { @@ -2784,13 +3007,6 @@ void jit_uni_eltwise_injector_t::load_vector( #define DEFINE_ASIMD_EMPTY_FUNC(func_name) \ template <> \ void jit_uni_eltwise_injector_t::func_name(const TRegS &) {} -DEFINE_ASIMD_EMPTY_FUNC(mish_compute_vector_fwd); -DEFINE_ASIMD_EMPTY_FUNC(elu_compute_vector_fwd); -DEFINE_ASIMD_EMPTY_FUNC(tanh_compute_vector_fwd); -DEFINE_ASIMD_EMPTY_FUNC(gelu_tanh_compute_vector_fwd); -DEFINE_ASIMD_EMPTY_FUNC(soft_relu_compute_vector_fwd); -DEFINE_ASIMD_EMPTY_FUNC(logistic_compute_vector_fwd); -DEFINE_ASIMD_EMPTY_FUNC(swish_compute_vector_fwd); DEFINE_ASIMD_EMPTY_FUNC(log_compute_vector_fwd); DEFINE_ASIMD_EMPTY_FUNC(gelu_erf_compute_vector_fwd); DEFINE_ASIMD_EMPTY_FUNC(tanh_polynomial_approx_compute_vector_fwd); diff --git a/src/cpu/aarch64/injectors/jit_uni_eltwise_injector.hpp b/src/cpu/aarch64/injectors/jit_uni_eltwise_injector.hpp index e6cfe8c8e44..82866305213 100644 --- a/src/cpu/aarch64/injectors/jit_uni_eltwise_injector.hpp +++ b/src/cpu/aarch64/injectors/jit_uni_eltwise_injector.hpp @@ -225,6 +225,8 @@ struct jit_uni_eltwise_injector_t { size_t get_vec_len(); void exp_compute_vector_fwd(const TRegS &vmm_src); + void exp_compute_vector_fwd( + const TRegS &vmm_src, float min_input, float max_input); void relu_compute_vector_fwd(const TRegS &vmm_src); void relu_zero_ns_compute_vector_fwd(const TReg &vmm_src); void elu_compute_vector_fwd(const TRegS &vmm_src);