Skip to content

8365772: RISC-V: correctly preserve NaN payload when converting from float to float16 in vector way #26883

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 2 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions src/hotspot/cpu/riscv/assembler_riscv.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -1988,6 +1988,7 @@ enum VectorMask {

// Vector Narrowing Integer Right Shift Instructions
INSN(vnsra_wi, 0b1010111, 0b011, 0b101101);
INSN(vnsrl_wi, 0b1010111, 0b011, 0b101100);

#undef INSN

Expand Down
66 changes: 51 additions & 15 deletions src/hotspot/cpu/riscv/c2_MacroAssembler_riscv.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2494,36 +2494,72 @@ static void float_to_float16_v_slow_path(C2_MacroAssembler& masm,

// mul is already set to mf2 in float_to_float16_v.

// preserve the payloads of non-canonical NaNs.
__ vnsra_wi(dst, src, 13, Assembler::v0_t);

// preserve the sign bit.
__ vnsra_wi(tmp, src, 26, Assembler::v0_t);
__ vsll_vi(tmp, tmp, 10, Assembler::v0_t);
__ mv(t0, 0x3ff);
__ vor_vx(tmp, tmp, t0, Assembler::v0_t);

// get the result by merging sign bit and payloads of preserved non-canonical NaNs.
__ vand_vv(dst, dst, tmp, Assembler::v0_t);
// Float (32 bits)
// Bit: 31 30 to 23 22 to 0
// +---+------------------+-----------------------------+
// | S | Exponent | Mantissa (Fraction) |
// +---+------------------+-----------------------------+
// 1 bit 8 bits 23 bits
//
// Float (16 bits)
// Bit: 15 14 to 10 9 to 0
// +---+----------------+------------------+
// | S | Exponent | Mantissa |
// +---+----------------+------------------+
// 1 bit 5 bits 10 bits
const int fp_sign_bits = 1;
const int fp32_bits = 32;
const int fp32_mantissa_2nd_part_bits = 9;
const int fp32_mantissa_3rd_part_bits = 4;
const int fp16_exponent_bits = 5;
const int fp16_mantissa_bits = 10;

// preserve the sign bit and exponent.
__ vnsra_wi(dst, src, fp32_bits - fp_sign_bits - fp16_exponent_bits, Assembler::v0_t);
__ vsll_vi(dst, dst, fp16_mantissa_bits, Assembler::v0_t);

// Preserve high order bit of float NaN in the
// binary16 result NaN (tenth bit); OR in remaining
// bits into lower 9 bits of binary 16 significand.
// | (doppel & 0x007f_e000) >> 13 // 10 bits
// | (doppel & 0x0000_1ff0) >> 4 // 9 bits
// | (doppel & 0x0000_000f)); // 4 bits
//
// Check j.l.Float.floatToFloat16 for more information.
// 10 bits
__ vnsrl_wi(tmp, src, fp32_mantissa_2nd_part_bits + fp32_mantissa_3rd_part_bits, Assembler::v0_t);
__ mv(t0, 0x3ff); // retain first part of mantissa in a float 32
__ vand_vx(tmp, tmp, t0, Assembler::v0_t);
__ vor_vv(dst, dst, tmp, Assembler::v0_t);
// 9 bits
__ vnsrl_wi(tmp, src, fp32_mantissa_3rd_part_bits, Assembler::v0_t);
__ mv(t0, 0x1ff); // retain second part of mantissa in a float 32
__ vand_vx(tmp, tmp, t0, Assembler::v0_t);
__ vor_vv(dst, dst, tmp, Assembler::v0_t);
// 4 bits
// A narrowing shift (by 0) is needed to move the data from 32-bit elements to 16-bit elements in the vector register.
__ vnsrl_wi(tmp, src, 0, Assembler::v0_t);
__ vand_vi(tmp, tmp, 0xf, Assembler::v0_t);
__ vor_vv(dst, dst, tmp, Assembler::v0_t);

__ j(stub.continuation());
#undef __
}

// j.l.Float.floatToFloat16
void C2_MacroAssembler::float_to_float16_v(VectorRegister dst, VectorRegister src, VectorRegister vtmp,
Register tmp, uint vector_length) {
void C2_MacroAssembler::float_to_float16_v(VectorRegister dst, VectorRegister src,
VectorRegister vtmp, Register tmp, uint vector_length) {
assert_different_registers(dst, src, vtmp);

auto stub = C2CodeStub::make<VectorRegister, VectorRegister, VectorRegister>
(dst, src, vtmp, 28, float_to_float16_v_slow_path);
(dst, src, vtmp, 56, float_to_float16_v_slow_path);

// On RISC-V, NaNs need special handling, as vfncvt_f_f_w does not work in that case.

vsetvli_helper(BasicType::T_FLOAT, vector_length, Assembler::m1);

// check whether there is a NaN.
// replace v_fclass with vmseq_vv as performance optimization.
// replace v_fclass with vmfne_vv as performance optimization.
vmfne_vv(v0, src, src);
vcpop_m(t0, v0);

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@

/**
* @test
* @key randomness
* @bug 8320646
* @summary Auto-vectorize Float.floatToFloat16, Float.float16ToFloat APIs, with NaN
* @requires vm.compiler2.enabled
Expand All @@ -37,9 +38,11 @@
package compiler.vectorization;

import java.util.HexFormat;
import java.util.Random;

import compiler.lib.ir_framework.*;
import jdk.test.lib.Asserts;
import jdk.test.lib.Utils;

public class TestFloatConversionsVectorNaN {
private static final int ARRLEN = 1024;
Expand Down Expand Up @@ -79,14 +82,16 @@ public void test_float_float16(short[] sout, float[] finp) {

@Run(test = {"test_float_float16"}, mode = RunMode.STANDALONE)
public void kernel_test_float_float16() {
Random rand = Utils.getRandomInstance();
int errno = 0;
finp = new float[ARRLEN];
sout = new short[ARRLEN];

// Setup
for (int i = 0; i < ARRLEN; i++) {
if (i%39 == 0) {
int x = 0x7f800000 + ((i/39) << 13);
if (i%3 == 0) {
int shift = rand.nextInt(13+1);
int x = 0x7f800000 + ((i/39) << shift);
x = (i%2 == 0) ? x : (x | 0x80000000);
finp[i] = Float.intBitsToFloat(x);
} else {
Expand Down Expand Up @@ -128,7 +133,8 @@ public void kernel_test_float_float16() {

static int assertEquals(int idx, float f, short expected, short actual) {
HexFormat hf = HexFormat.of();
String msg = "floatToFloat16 wrong result: idx: " + idx + ", \t" + f +
String msg = "floatToFloat16 wrong result: idx: " + idx +
", \t" + f + ", hex: " + Integer.toHexString(Float.floatToRawIntBits(f)) +
",\t expected: " + hf.toHexDigits(expected) +
",\t actual: " + hf.toHexDigits(actual);
if ((expected & 0x7c00) != 0x7c00) {
Expand Down Expand Up @@ -167,7 +173,7 @@ public void kernel_test_float16_float() {

// Setup
for (int i = 0; i < ARRLEN; i++) {
if (i%39 == 0) {
if (i%3 == 0) {
int x = 0x7c00 + i;
x = (i%2 == 0) ? x : (x | 0x8000);
sinp[i] = (short)x;
Expand Down