@@ -335,6 +335,8 @@ SPIRVLegalizerInfo::SPIRVLegalizerInfo(const SPIRVSubtarget &ST) {
335
335
getActionDefinitionsBuilder ({G_SMULH, G_UMULH}).alwaysLegal ();
336
336
}
337
337
338
+ getActionDefinitionsBuilder (G_IS_FPCLASS).custom ();
339
+
338
340
getLegacyLegalizerInfo ().computeTables ();
339
341
verify (*ST.getInstrInfo ());
340
342
}
@@ -355,9 +357,14 @@ static Register convertPtrToInt(Register Reg, LLT ConvTy, SPIRVType *SpvType,
355
357
bool SPIRVLegalizerInfo::legalizeCustom (
356
358
LegalizerHelper &Helper, MachineInstr &MI,
357
359
LostDebugLocObserver &LocObserver) const {
358
- auto Opc = MI.getOpcode ();
359
360
MachineRegisterInfo &MRI = MI.getMF ()->getRegInfo ();
360
- if (Opc == TargetOpcode::G_ICMP) {
361
+ switch (MI.getOpcode ()) {
362
+ default :
363
+ // TODO: implement legalization for other opcodes.
364
+ return true ;
365
+ case TargetOpcode::G_IS_FPCLASS:
366
+ return legalizeIsFPClass (Helper, MI, LocObserver);
367
+ case TargetOpcode::G_ICMP: {
361
368
assert (GR->getSPIRVTypeForVReg (MI.getOperand (0 ).getReg ()));
362
369
auto &Op0 = MI.getOperand (2 );
363
370
auto &Op1 = MI.getOperand (3 );
@@ -378,6 +385,238 @@ bool SPIRVLegalizerInfo::legalizeCustom(
378
385
}
379
386
return true ;
380
387
}
381
- // TODO: implement legalization for other opcodes.
388
+ }
389
+ }
390
+
391
// Note this code was copied from LegalizerHelper::lowerISFPCLASS and adjusted
// to ensure that all instructions created during the lowering have SPIR-V types
// assigned to them.
//
// Lowers a G_IS_FPCLASS instruction into a sequence of integer bitmask tests
// on the bit pattern of the floating-point source. The source is bitcast to
// an equally-sized integer (via a SPIR-V OpBitcast, see note below), and each
// requested FPClassTest class (zero, subnormal, normal, inf, nan, ...) is
// checked with integer compares whose results are OR-ed together into the
// destination. Returns true on success (the only outcome here); the original
// G_IS_FPCLASS is erased from its parent block.
//
// \param Helper      Legalizer helper whose MIRBuilder is used to emit code.
// \param MI          The G_IS_FPCLASS instruction being replaced; operand 0 is
//                    the boolean-typed destination, operand 1 the FP source,
//                    operand 2 the immediate FPClassTest mask.
// \param LocObserver Unused here; part of the legalizeCustom-style signature.
bool SPIRVLegalizerInfo::legalizeIsFPClass(
    LegalizerHelper &Helper, MachineInstr &MI,
    LostDebugLocObserver &LocObserver) const {
  auto [DstReg, DstTy, SrcReg, SrcTy] = MI.getFirst2RegLLTs();
  FPClassTest Mask = static_cast<FPClassTest>(MI.getOperand(2).getImm());

  auto &MIRBuilder = Helper.MIRBuilder;
  auto &MF = MIRBuilder.getMF();
  MachineRegisterInfo &MRI = MF.getRegInfo();

  // Build the SPIR-V type mirroring the destination LLT (a scalar or vector
  // of i1-sized-per-LLT booleans) so every value we create can be annotated.
  Type *LLVMDstTy =
      IntegerType::get(MIRBuilder.getContext(), DstTy.getScalarSizeInBits());
  if (DstTy.isVector())
    LLVMDstTy = VectorType::get(LLVMDstTy, DstTy.getElementCount());
  SPIRVType *SPIRVDstTy = GR->getOrCreateSPIRVType(
      LLVMDstTy, MIRBuilder, SPIRV::AccessQualifier::ReadWrite,
      /*EmitIR*/ true);

  unsigned BitSize = SrcTy.getScalarSizeInBits();
  const fltSemantics &Semantics = getFltSemanticForLLT(SrcTy.getScalarType());

  // Integer LLT/SPIR-V type of the same width (and element count) as the FP
  // source; all classification tests operate on this integer view.
  LLT IntTy = LLT::scalar(BitSize);
  Type *LLVMIntTy = IntegerType::get(MIRBuilder.getContext(), BitSize);
  if (SrcTy.isVector()) {
    IntTy = LLT::vector(SrcTy.getElementCount(), IntTy);
    LLVMIntTy = VectorType::get(LLVMIntTy, SrcTy.getElementCount());
  }
  SPIRVType *SPIRVIntTy = GR->getOrCreateSPIRVType(
      LLVMIntTy, MIRBuilder, SPIRV::AccessQualifier::ReadWrite,
      /*EmitIR*/ true);

  // Clang doesn't support capture of structured bindings:
  LLT DstTyCopy = DstTy;
  const auto assignSPIRVTy = [&](MachineInstrBuilder &&MI) {
    // Assign this MI's (assumed only) destination to one of the two types we
    // expect: either the G_IS_FPCLASS's destination type, or the integer type
    // bitcast from the source type.
    LLT MITy = MRI.getType(MI.getReg(0));
    assert((MITy == IntTy || MITy == DstTyCopy) &&
           "Unexpected LLT type while lowering G_IS_FPCLASS");
    auto *SPVTy = MITy == IntTy ? SPIRVIntTy : SPIRVDstTy;
    GR->assignSPIRVTypeToVReg(SPVTy, MI.getReg(0), MF);
    return MI;
  };

  // Helper to build and assign a constant in one go
  const auto buildSPIRVConstant = [&](LLT Ty, auto &&C) -> MachineInstrBuilder {
    if (!Ty.isFixedVector())
      return assignSPIRVTy(MIRBuilder.buildConstant(Ty, C));
    // Vector case: build the scalar constant, annotate it with the element's
    // SPIR-V type, then splat it to the vector type.
    auto ScalarC = MIRBuilder.buildConstant(Ty.getScalarType(), C);
    assert((Ty == IntTy || Ty == DstTyCopy) &&
           "Unexpected LLT type while lowering constant for G_IS_FPCLASS");
    SPIRVType *VecEltTy = GR->getOrCreateSPIRVType(
        (Ty == IntTy ? LLVMIntTy : LLVMDstTy)->getScalarType(), MIRBuilder,
        SPIRV::AccessQualifier::ReadWrite,
        /*EmitIR*/ true);
    GR->assignSPIRVTypeToVReg(VecEltTy, ScalarC.getReg(0), MF);
    return assignSPIRVTy(MIRBuilder.buildSplatBuildVector(Ty, ScalarC));
  };

  // Trivial masks: no class requested => constant false; all classes => true.
  if (Mask == fcNone) {
    MIRBuilder.buildCopy(DstReg, buildSPIRVConstant(DstTy, 0));
    MI.eraseFromParent();
    return true;
  }
  if (Mask == fcAllFlags) {
    MIRBuilder.buildCopy(DstReg, buildSPIRVConstant(DstTy, 1));
    MI.eraseFromParent();
    return true;
  }

  // Note that rather than creating a COPY here (between a floating-point and
  // integer type of the same size) we create a SPIR-V bitcast immediately. We
  // can't create a G_BITCAST because the LLTs are the same, and we can't seem
  // to correctly lower COPYs to SPIR-V bitcasts at this moment.
  Register ResVReg = MRI.createGenericVirtualRegister(IntTy);
  MRI.setRegClass(ResVReg, GR->getRegClass(SPIRVIntTy));
  GR->assignSPIRVTypeToVReg(SPIRVIntTy, ResVReg, Helper.MIRBuilder.getMF());
  auto AsInt = MIRBuilder.buildInstr(SPIRV::OpBitcast)
                   .addDef(ResVReg)
                   .addUse(GR->getSPIRVTypeID(SPIRVIntTy))
                   .addUse(SrcReg);
  AsInt = assignSPIRVTy(std::move(AsInt));

  // Various masks.
  APInt SignBit = APInt::getSignMask(BitSize);
  APInt ValueMask = APInt::getSignedMaxValue(BitSize);     // All bits but sign.
  APInt Inf = APFloat::getInf(Semantics).bitcastToAPInt(); // Exp and int bit.
  APInt ExpMask = Inf;
  APInt AllOneMantissa = APFloat::getLargest(Semantics).bitcastToAPInt() & ~Inf;
  APInt QNaNBitMask =
      APInt::getOneBitSet(BitSize, AllOneMantissa.getActiveBits() - 1);
  APInt InversionMask = APInt::getAllOnes(DstTy.getScalarSizeInBits());

  auto SignBitC = buildSPIRVConstant(IntTy, SignBit);
  auto ValueMaskC = buildSPIRVConstant(IntTy, ValueMask);
  auto InfC = buildSPIRVConstant(IntTy, Inf);
  auto ExpMaskC = buildSPIRVConstant(IntTy, ExpMask);
  auto ZeroC = buildSPIRVConstant(IntTy, 0);

  // Abs = bit pattern with the sign bit cleared; Sign = (AsInt != Abs), i.e.
  // true exactly when the sign bit was set.
  auto Abs = assignSPIRVTy(MIRBuilder.buildAnd(IntTy, AsInt, ValueMaskC));
  auto Sign = assignSPIRVTy(
      MIRBuilder.buildICmp(CmpInst::Predicate::ICMP_NE, DstTy, AsInt, Abs));

  // Accumulator for the final result; each matched class OR-s into it.
  auto Res = buildSPIRVConstant(DstTy, 0);

  const auto appendToRes = [&](MachineInstrBuilder &&ToAppend) {
    Res = assignSPIRVTy(
        MIRBuilder.buildOr(DstTyCopy, Res, assignSPIRVTy(std::move(ToAppend))));
  };

  // Tests that involve more than one class should be processed first.
  if ((Mask & fcFinite) == fcFinite) {
    // finite(V) ==> abs(V) u< exp_mask
    appendToRes(MIRBuilder.buildICmp(CmpInst::Predicate::ICMP_ULT, DstTy, Abs,
                                     ExpMaskC));
    Mask &= ~fcFinite;
  } else if ((Mask & fcFinite) == fcPosFinite) {
    // finite(V) && V > 0 ==> V u< exp_mask
    appendToRes(MIRBuilder.buildICmp(CmpInst::Predicate::ICMP_ULT, DstTy, AsInt,
                                     ExpMaskC));
    Mask &= ~fcPosFinite;
  } else if ((Mask & fcFinite) == fcNegFinite) {
    // finite(V) && V < 0 ==> abs(V) u< exp_mask && signbit == 1
    auto Cmp = assignSPIRVTy(MIRBuilder.buildICmp(CmpInst::Predicate::ICMP_ULT,
                                                  DstTy, Abs, ExpMaskC));
    appendToRes(MIRBuilder.buildAnd(DstTy, Cmp, Sign));
    Mask &= ~fcNegFinite;
  }

  if (FPClassTest PartialCheck = Mask & (fcZero | fcSubnormal)) {
    // fcZero | fcSubnormal => test all exponent bits are 0
    // TODO: Handle sign bit specific cases
    // TODO: Handle inverted case
    if (PartialCheck == (fcZero | fcSubnormal)) {
      auto ExpBits = assignSPIRVTy(MIRBuilder.buildAnd(IntTy, AsInt, ExpMaskC));
      appendToRes(MIRBuilder.buildICmp(CmpInst::Predicate::ICMP_EQ, DstTy,
                                       ExpBits, ZeroC));
      Mask &= ~PartialCheck;
    }
  }

  // Check for individual classes.
  if (FPClassTest PartialCheck = Mask & fcZero) {
    if (PartialCheck == fcPosZero)
      appendToRes(MIRBuilder.buildICmp(CmpInst::Predicate::ICMP_EQ, DstTy,
                                       AsInt, ZeroC));
    else if (PartialCheck == fcZero)
      appendToRes(
          MIRBuilder.buildICmp(CmpInst::Predicate::ICMP_EQ, DstTy, Abs, ZeroC));
    else // fcNegZero
      appendToRes(MIRBuilder.buildICmp(CmpInst::Predicate::ICMP_EQ, DstTy,
                                       AsInt, SignBitC));
  }

  if (FPClassTest PartialCheck = Mask & fcSubnormal) {
    // issubnormal(V) ==> unsigned(abs(V) - 1) u< (all mantissa bits set)
    // issubnormal(V) && V>0 ==> unsigned(V - 1) u< (all mantissa bits set)
    auto V = (PartialCheck == fcPosSubnormal) ? AsInt : Abs;
    auto OneC = buildSPIRVConstant(IntTy, 1);
    auto VMinusOne = MIRBuilder.buildSub(IntTy, V, OneC);
    auto SubnormalRes = assignSPIRVTy(
        MIRBuilder.buildICmp(CmpInst::Predicate::ICMP_ULT, DstTy, VMinusOne,
                             buildSPIRVConstant(IntTy, AllOneMantissa)));
    if (PartialCheck == fcNegSubnormal)
      SubnormalRes = MIRBuilder.buildAnd(DstTy, SubnormalRes, Sign);
    appendToRes(std::move(SubnormalRes));
  }

  if (FPClassTest PartialCheck = Mask & fcInf) {
    if (PartialCheck == fcPosInf)
      appendToRes(MIRBuilder.buildICmp(CmpInst::Predicate::ICMP_EQ, DstTy,
                                       AsInt, InfC));
    else if (PartialCheck == fcInf)
      appendToRes(
          MIRBuilder.buildICmp(CmpInst::Predicate::ICMP_EQ, DstTy, Abs, InfC));
    else { // fcNegInf
      APInt NegInf = APFloat::getInf(Semantics, true).bitcastToAPInt();
      auto NegInfC = buildSPIRVConstant(IntTy, NegInf);
      appendToRes(MIRBuilder.buildICmp(CmpInst::Predicate::ICMP_EQ, DstTy,
                                       AsInt, NegInfC));
    }
  }

  if (FPClassTest PartialCheck = Mask & fcNan) {
    auto InfWithQnanBitC = buildSPIRVConstant(IntTy, Inf | QNaNBitMask);
    if (PartialCheck == fcNan) {
      // isnan(V) ==> abs(V) u> int(inf)
      appendToRes(
          MIRBuilder.buildICmp(CmpInst::Predicate::ICMP_UGT, DstTy, Abs, InfC));
    } else if (PartialCheck == fcQNan) {
      // isquiet(V) ==> abs(V) u>= (unsigned(Inf) | quiet_bit)
      appendToRes(MIRBuilder.buildICmp(CmpInst::Predicate::ICMP_UGE, DstTy, Abs,
                                       InfWithQnanBitC));
    } else { // fcSNan
      // issignaling(V) ==> abs(V) u> unsigned(Inf) &&
      //                    abs(V) u< (unsigned(Inf) | quiet_bit)
      auto IsNan = assignSPIRVTy(
          MIRBuilder.buildICmp(CmpInst::Predicate::ICMP_UGT, DstTy, Abs, InfC));
      auto IsNotQnan = assignSPIRVTy(MIRBuilder.buildICmp(
          CmpInst::Predicate::ICMP_ULT, DstTy, Abs, InfWithQnanBitC));
      appendToRes(MIRBuilder.buildAnd(DstTy, IsNan, IsNotQnan));
    }
  }

  if (FPClassTest PartialCheck = Mask & fcNormal) {
    // isnormal(V) ==> (0 u< exp u< max_exp) ==> (unsigned(exp-1) u<
    // (max_exp-1))
    APInt ExpLSB = ExpMask & ~(ExpMask.shl(1));
    auto ExpMinusOne = assignSPIRVTy(
        MIRBuilder.buildSub(IntTy, Abs, buildSPIRVConstant(IntTy, ExpLSB)));
    APInt MaxExpMinusOne = ExpMask - ExpLSB;
    auto NormalRes = assignSPIRVTy(
        MIRBuilder.buildICmp(CmpInst::Predicate::ICMP_ULT, DstTy, ExpMinusOne,
                             buildSPIRVConstant(IntTy, MaxExpMinusOne)));
    if (PartialCheck == fcNegNormal)
      NormalRes = MIRBuilder.buildAnd(DstTy, NormalRes, Sign);
    else if (PartialCheck == fcPosNormal) {
      // Positive-normal: AND with NOT(Sign), built as Sign XOR all-ones.
      auto PosSign = assignSPIRVTy(MIRBuilder.buildXor(
          DstTy, Sign, buildSPIRVConstant(DstTy, InversionMask)));
      NormalRes = MIRBuilder.buildAnd(DstTy, NormalRes, PosSign);
    }
    appendToRes(std::move(NormalRes));
  }

  // Copy the accumulated result into the original destination and drop the
  // now-lowered G_IS_FPCLASS.
  MIRBuilder.buildCopy(DstReg, Res);
  MI.eraseFromParent();
  return true;
}
0 commit comments