Skip to content

Commit 638383c

Browse files
authored
[SPIRV] Support G_IS_FPCLASS (#148637)
This commit adds custom legalization for G_IS_FPCLASS, corresponding to the @llvm.is.fpclass intrinsic. The lowering strategy is essentially copied and adjusted from the target-agnostic LegalizeHelper::lowerISFPCLASS legalization. The reason we can't just use that directly is that the series of instruction it expands to aren't logged in the SPIR-V backend's register/type book-keeping, leading to issues later on in the compilation process. As such the code introduced here was copied from the aforementioned helper method, with some notable changes: * Each new instruction's destination register must have a SPIR-V type registered to it. * Instead of a COPY from the floating-point type to integer, we issue a SPIR-V OpBitcast directly. The backend doesn't currently appear to handle bitcast-like COPYs. Fixes #72862
1 parent ffcee26 commit 638383c

File tree

3 files changed

+654
-3
lines changed

3 files changed

+654
-3
lines changed

llvm/lib/Target/SPIRV/SPIRVLegalizerInfo.cpp

Lines changed: 242 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -335,6 +335,8 @@ SPIRVLegalizerInfo::SPIRVLegalizerInfo(const SPIRVSubtarget &ST) {
335335
getActionDefinitionsBuilder({G_SMULH, G_UMULH}).alwaysLegal();
336336
}
337337

338+
getActionDefinitionsBuilder(G_IS_FPCLASS).custom();
339+
338340
getLegacyLegalizerInfo().computeTables();
339341
verify(*ST.getInstrInfo());
340342
}
@@ -355,9 +357,14 @@ static Register convertPtrToInt(Register Reg, LLT ConvTy, SPIRVType *SpvType,
355357
bool SPIRVLegalizerInfo::legalizeCustom(
356358
LegalizerHelper &Helper, MachineInstr &MI,
357359
LostDebugLocObserver &LocObserver) const {
358-
auto Opc = MI.getOpcode();
359360
MachineRegisterInfo &MRI = MI.getMF()->getRegInfo();
360-
if (Opc == TargetOpcode::G_ICMP) {
361+
switch (MI.getOpcode()) {
362+
default:
363+
// TODO: implement legalization for other opcodes.
364+
return true;
365+
case TargetOpcode::G_IS_FPCLASS:
366+
return legalizeIsFPClass(Helper, MI, LocObserver);
367+
case TargetOpcode::G_ICMP: {
361368
assert(GR->getSPIRVTypeForVReg(MI.getOperand(0).getReg()));
362369
auto &Op0 = MI.getOperand(2);
363370
auto &Op1 = MI.getOperand(3);
@@ -378,6 +385,238 @@ bool SPIRVLegalizerInfo::legalizeCustom(
378385
}
379386
return true;
380387
}
381-
// TODO: implement legalization for other opcodes.
388+
}
389+
}
390+
391+
// Note this code was copied from LegalizerHelper::lowerISFPCLASS and adjusted
392+
// to ensure that all instructions created during the lowering have SPIR-V types
393+
// assigned to them.
394+
bool SPIRVLegalizerInfo::legalizeIsFPClass(
395+
LegalizerHelper &Helper, MachineInstr &MI,
396+
LostDebugLocObserver &LocObserver) const {
397+
auto [DstReg, DstTy, SrcReg, SrcTy] = MI.getFirst2RegLLTs();
398+
FPClassTest Mask = static_cast<FPClassTest>(MI.getOperand(2).getImm());
399+
400+
auto &MIRBuilder = Helper.MIRBuilder;
401+
auto &MF = MIRBuilder.getMF();
402+
MachineRegisterInfo &MRI = MF.getRegInfo();
403+
404+
Type *LLVMDstTy =
405+
IntegerType::get(MIRBuilder.getContext(), DstTy.getScalarSizeInBits());
406+
if (DstTy.isVector())
407+
LLVMDstTy = VectorType::get(LLVMDstTy, DstTy.getElementCount());
408+
SPIRVType *SPIRVDstTy = GR->getOrCreateSPIRVType(
409+
LLVMDstTy, MIRBuilder, SPIRV::AccessQualifier::ReadWrite,
410+
/*EmitIR*/ true);
411+
412+
unsigned BitSize = SrcTy.getScalarSizeInBits();
413+
const fltSemantics &Semantics = getFltSemanticForLLT(SrcTy.getScalarType());
414+
415+
LLT IntTy = LLT::scalar(BitSize);
416+
Type *LLVMIntTy = IntegerType::get(MIRBuilder.getContext(), BitSize);
417+
if (SrcTy.isVector()) {
418+
IntTy = LLT::vector(SrcTy.getElementCount(), IntTy);
419+
LLVMIntTy = VectorType::get(LLVMIntTy, SrcTy.getElementCount());
420+
}
421+
SPIRVType *SPIRVIntTy = GR->getOrCreateSPIRVType(
422+
LLVMIntTy, MIRBuilder, SPIRV::AccessQualifier::ReadWrite,
423+
/*EmitIR*/ true);
424+
425+
// Clang doesn't support capture of structured bindings:
426+
LLT DstTyCopy = DstTy;
427+
const auto assignSPIRVTy = [&](MachineInstrBuilder &&MI) {
428+
// Assign this MI's (assumed only) destination to one of the two types we
429+
// expect: either the G_IS_FPCLASS's destination type, or the integer type
430+
// bitcast from the source type.
431+
LLT MITy = MRI.getType(MI.getReg(0));
432+
assert((MITy == IntTy || MITy == DstTyCopy) &&
433+
"Unexpected LLT type while lowering G_IS_FPCLASS");
434+
auto *SPVTy = MITy == IntTy ? SPIRVIntTy : SPIRVDstTy;
435+
GR->assignSPIRVTypeToVReg(SPVTy, MI.getReg(0), MF);
436+
return MI;
437+
};
438+
439+
// Helper to build and assign a constant in one go
440+
const auto buildSPIRVConstant = [&](LLT Ty, auto &&C) -> MachineInstrBuilder {
441+
if (!Ty.isFixedVector())
442+
return assignSPIRVTy(MIRBuilder.buildConstant(Ty, C));
443+
auto ScalarC = MIRBuilder.buildConstant(Ty.getScalarType(), C);
444+
assert((Ty == IntTy || Ty == DstTyCopy) &&
445+
"Unexpected LLT type while lowering constant for G_IS_FPCLASS");
446+
SPIRVType *VecEltTy = GR->getOrCreateSPIRVType(
447+
(Ty == IntTy ? LLVMIntTy : LLVMDstTy)->getScalarType(), MIRBuilder,
448+
SPIRV::AccessQualifier::ReadWrite,
449+
/*EmitIR*/ true);
450+
GR->assignSPIRVTypeToVReg(VecEltTy, ScalarC.getReg(0), MF);
451+
return assignSPIRVTy(MIRBuilder.buildSplatBuildVector(Ty, ScalarC));
452+
};
453+
454+
if (Mask == fcNone) {
455+
MIRBuilder.buildCopy(DstReg, buildSPIRVConstant(DstTy, 0));
456+
MI.eraseFromParent();
457+
return true;
458+
}
459+
if (Mask == fcAllFlags) {
460+
MIRBuilder.buildCopy(DstReg, buildSPIRVConstant(DstTy, 1));
461+
MI.eraseFromParent();
462+
return true;
463+
}
464+
465+
// Note that rather than creating a COPY here (between a floating-point and
466+
// integer type of the same size) we create a SPIR-V bitcast immediately. We
467+
// can't create a G_BITCAST because the LLTs are the same, and we can't seem
468+
// to correctly lower COPYs to SPIR-V bitcasts at this moment.
469+
Register ResVReg = MRI.createGenericVirtualRegister(IntTy);
470+
MRI.setRegClass(ResVReg, GR->getRegClass(SPIRVIntTy));
471+
GR->assignSPIRVTypeToVReg(SPIRVIntTy, ResVReg, Helper.MIRBuilder.getMF());
472+
auto AsInt = MIRBuilder.buildInstr(SPIRV::OpBitcast)
473+
.addDef(ResVReg)
474+
.addUse(GR->getSPIRVTypeID(SPIRVIntTy))
475+
.addUse(SrcReg);
476+
AsInt = assignSPIRVTy(std::move(AsInt));
477+
478+
// Various masks.
479+
APInt SignBit = APInt::getSignMask(BitSize);
480+
APInt ValueMask = APInt::getSignedMaxValue(BitSize); // All bits but sign.
481+
APInt Inf = APFloat::getInf(Semantics).bitcastToAPInt(); // Exp and int bit.
482+
APInt ExpMask = Inf;
483+
APInt AllOneMantissa = APFloat::getLargest(Semantics).bitcastToAPInt() & ~Inf;
484+
APInt QNaNBitMask =
485+
APInt::getOneBitSet(BitSize, AllOneMantissa.getActiveBits() - 1);
486+
APInt InversionMask = APInt::getAllOnes(DstTy.getScalarSizeInBits());
487+
488+
auto SignBitC = buildSPIRVConstant(IntTy, SignBit);
489+
auto ValueMaskC = buildSPIRVConstant(IntTy, ValueMask);
490+
auto InfC = buildSPIRVConstant(IntTy, Inf);
491+
auto ExpMaskC = buildSPIRVConstant(IntTy, ExpMask);
492+
auto ZeroC = buildSPIRVConstant(IntTy, 0);
493+
494+
auto Abs = assignSPIRVTy(MIRBuilder.buildAnd(IntTy, AsInt, ValueMaskC));
495+
auto Sign = assignSPIRVTy(
496+
MIRBuilder.buildICmp(CmpInst::Predicate::ICMP_NE, DstTy, AsInt, Abs));
497+
498+
auto Res = buildSPIRVConstant(DstTy, 0);
499+
500+
const auto appendToRes = [&](MachineInstrBuilder &&ToAppend) {
501+
Res = assignSPIRVTy(
502+
MIRBuilder.buildOr(DstTyCopy, Res, assignSPIRVTy(std::move(ToAppend))));
503+
};
504+
505+
// Tests that involve more than one class should be processed first.
506+
if ((Mask & fcFinite) == fcFinite) {
507+
// finite(V) ==> abs(V) u< exp_mask
508+
appendToRes(MIRBuilder.buildICmp(CmpInst::Predicate::ICMP_ULT, DstTy, Abs,
509+
ExpMaskC));
510+
Mask &= ~fcFinite;
511+
} else if ((Mask & fcFinite) == fcPosFinite) {
512+
// finite(V) && V > 0 ==> V u< exp_mask
513+
appendToRes(MIRBuilder.buildICmp(CmpInst::Predicate::ICMP_ULT, DstTy, AsInt,
514+
ExpMaskC));
515+
Mask &= ~fcPosFinite;
516+
} else if ((Mask & fcFinite) == fcNegFinite) {
517+
// finite(V) && V < 0 ==> abs(V) u< exp_mask && signbit == 1
518+
auto Cmp = assignSPIRVTy(MIRBuilder.buildICmp(CmpInst::Predicate::ICMP_ULT,
519+
DstTy, Abs, ExpMaskC));
520+
appendToRes(MIRBuilder.buildAnd(DstTy, Cmp, Sign));
521+
Mask &= ~fcNegFinite;
522+
}
523+
524+
if (FPClassTest PartialCheck = Mask & (fcZero | fcSubnormal)) {
525+
// fcZero | fcSubnormal => test all exponent bits are 0
526+
// TODO: Handle sign bit specific cases
527+
// TODO: Handle inverted case
528+
if (PartialCheck == (fcZero | fcSubnormal)) {
529+
auto ExpBits = assignSPIRVTy(MIRBuilder.buildAnd(IntTy, AsInt, ExpMaskC));
530+
appendToRes(MIRBuilder.buildICmp(CmpInst::Predicate::ICMP_EQ, DstTy,
531+
ExpBits, ZeroC));
532+
Mask &= ~PartialCheck;
533+
}
534+
}
535+
536+
// Check for individual classes.
537+
if (FPClassTest PartialCheck = Mask & fcZero) {
538+
if (PartialCheck == fcPosZero)
539+
appendToRes(MIRBuilder.buildICmp(CmpInst::Predicate::ICMP_EQ, DstTy,
540+
AsInt, ZeroC));
541+
else if (PartialCheck == fcZero)
542+
appendToRes(
543+
MIRBuilder.buildICmp(CmpInst::Predicate::ICMP_EQ, DstTy, Abs, ZeroC));
544+
else // fcNegZero
545+
appendToRes(MIRBuilder.buildICmp(CmpInst::Predicate::ICMP_EQ, DstTy,
546+
AsInt, SignBitC));
547+
}
548+
549+
if (FPClassTest PartialCheck = Mask & fcSubnormal) {
550+
// issubnormal(V) ==> unsigned(abs(V) - 1) u< (all mantissa bits set)
551+
// issubnormal(V) && V>0 ==> unsigned(V - 1) u< (all mantissa bits set)
552+
auto V = (PartialCheck == fcPosSubnormal) ? AsInt : Abs;
553+
auto OneC = buildSPIRVConstant(IntTy, 1);
554+
auto VMinusOne = MIRBuilder.buildSub(IntTy, V, OneC);
555+
auto SubnormalRes = assignSPIRVTy(
556+
MIRBuilder.buildICmp(CmpInst::Predicate::ICMP_ULT, DstTy, VMinusOne,
557+
buildSPIRVConstant(IntTy, AllOneMantissa)));
558+
if (PartialCheck == fcNegSubnormal)
559+
SubnormalRes = MIRBuilder.buildAnd(DstTy, SubnormalRes, Sign);
560+
appendToRes(std::move(SubnormalRes));
561+
}
562+
563+
if (FPClassTest PartialCheck = Mask & fcInf) {
564+
if (PartialCheck == fcPosInf)
565+
appendToRes(MIRBuilder.buildICmp(CmpInst::Predicate::ICMP_EQ, DstTy,
566+
AsInt, InfC));
567+
else if (PartialCheck == fcInf)
568+
appendToRes(
569+
MIRBuilder.buildICmp(CmpInst::Predicate::ICMP_EQ, DstTy, Abs, InfC));
570+
else { // fcNegInf
571+
APInt NegInf = APFloat::getInf(Semantics, true).bitcastToAPInt();
572+
auto NegInfC = buildSPIRVConstant(IntTy, NegInf);
573+
appendToRes(MIRBuilder.buildICmp(CmpInst::Predicate::ICMP_EQ, DstTy,
574+
AsInt, NegInfC));
575+
}
576+
}
577+
578+
if (FPClassTest PartialCheck = Mask & fcNan) {
579+
auto InfWithQnanBitC = buildSPIRVConstant(IntTy, Inf | QNaNBitMask);
580+
if (PartialCheck == fcNan) {
581+
// isnan(V) ==> abs(V) u> int(inf)
582+
appendToRes(
583+
MIRBuilder.buildICmp(CmpInst::Predicate::ICMP_UGT, DstTy, Abs, InfC));
584+
} else if (PartialCheck == fcQNan) {
585+
// isquiet(V) ==> abs(V) u>= (unsigned(Inf) | quiet_bit)
586+
appendToRes(MIRBuilder.buildICmp(CmpInst::Predicate::ICMP_UGE, DstTy, Abs,
587+
InfWithQnanBitC));
588+
} else { // fcSNan
589+
// issignaling(V) ==> abs(V) u> unsigned(Inf) &&
590+
// abs(V) u< (unsigned(Inf) | quiet_bit)
591+
auto IsNan = assignSPIRVTy(
592+
MIRBuilder.buildICmp(CmpInst::Predicate::ICMP_UGT, DstTy, Abs, InfC));
593+
auto IsNotQnan = assignSPIRVTy(MIRBuilder.buildICmp(
594+
CmpInst::Predicate::ICMP_ULT, DstTy, Abs, InfWithQnanBitC));
595+
appendToRes(MIRBuilder.buildAnd(DstTy, IsNan, IsNotQnan));
596+
}
597+
}
598+
599+
if (FPClassTest PartialCheck = Mask & fcNormal) {
600+
// isnormal(V) ==> (0 u< exp u< max_exp) ==> (unsigned(exp-1) u<
601+
// (max_exp-1))
602+
APInt ExpLSB = ExpMask & ~(ExpMask.shl(1));
603+
auto ExpMinusOne = assignSPIRVTy(
604+
MIRBuilder.buildSub(IntTy, Abs, buildSPIRVConstant(IntTy, ExpLSB)));
605+
APInt MaxExpMinusOne = ExpMask - ExpLSB;
606+
auto NormalRes = assignSPIRVTy(
607+
MIRBuilder.buildICmp(CmpInst::Predicate::ICMP_ULT, DstTy, ExpMinusOne,
608+
buildSPIRVConstant(IntTy, MaxExpMinusOne)));
609+
if (PartialCheck == fcNegNormal)
610+
NormalRes = MIRBuilder.buildAnd(DstTy, NormalRes, Sign);
611+
else if (PartialCheck == fcPosNormal) {
612+
auto PosSign = assignSPIRVTy(MIRBuilder.buildXor(
613+
DstTy, Sign, buildSPIRVConstant(DstTy, InversionMask)));
614+
NormalRes = MIRBuilder.buildAnd(DstTy, NormalRes, PosSign);
615+
}
616+
appendToRes(std::move(NormalRes));
617+
}
618+
619+
MIRBuilder.buildCopy(DstReg, Res);
620+
MI.eraseFromParent();
382621
return true;
383622
}

llvm/lib/Target/SPIRV/SPIRVLegalizerInfo.h

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,10 @@ class SPIRVLegalizerInfo : public LegalizerInfo {
3030
bool legalizeCustom(LegalizerHelper &Helper, MachineInstr &MI,
3131
LostDebugLocObserver &LocObserver) const override;
3232
SPIRVLegalizerInfo(const SPIRVSubtarget &ST);
33+
34+
private:
35+
bool legalizeIsFPClass(LegalizerHelper &Helper, MachineInstr &MI,
36+
LostDebugLocObserver &LocObserver) const;
3337
};
3438
} // namespace llvm
3539
#endif // LLVM_LIB_TARGET_SPIRV_SPIRVMACHINELEGALIZER_H

0 commit comments

Comments
 (0)