@@ -4560,6 +4560,80 @@ static SDValue signExtendBitcastSrcVector(SelectionDAG &DAG, EVT SExtVT,
4560
4560
llvm_unreachable (" Unexpected node type for vXi1 sign extension" );
4561
4561
}
4562
4562
4563
+ static SDValue
4564
+ performSETCC_BITCASTCombine (SDNode *N, SelectionDAG &DAG,
4565
+ TargetLowering::DAGCombinerInfo &DCI,
4566
+ const LoongArchSubtarget &Subtarget) {
4567
+ SDLoc DL (N);
4568
+ EVT VT = N->getValueType (0 );
4569
+ SDValue Src = N->getOperand (0 );
4570
+ EVT SrcVT = Src.getValueType ();
4571
+
4572
+ if (Src.getOpcode () != ISD::SETCC || !Src.hasOneUse ())
4573
+ return SDValue ();
4574
+
4575
+ bool UseLASX;
4576
+ unsigned Opc = ISD::DELETED_NODE;
4577
+ EVT CmpVT = Src.getOperand (0 ).getValueType ();
4578
+ EVT EltVT = CmpVT.getVectorElementType ();
4579
+
4580
+ if (Subtarget.hasExtLSX () && CmpVT.getSizeInBits () == 128 )
4581
+ UseLASX = false ;
4582
+ else if (Subtarget.has32S () && Subtarget.hasExtLASX () &&
4583
+ CmpVT.getSizeInBits () == 256 )
4584
+ UseLASX = true ;
4585
+ else
4586
+ return SDValue ();
4587
+
4588
+ SDValue SrcN1 = Src.getOperand (1 );
4589
+ switch (cast<CondCodeSDNode>(Src.getOperand (2 ))->get ()) {
4590
+ default :
4591
+ break ;
4592
+ case ISD::SETEQ:
4593
+ // x == 0 => not (vmsknez.b x)
4594
+ if (ISD::isBuildVectorAllZeros (SrcN1.getNode ()) && EltVT == MVT::i8 )
4595
+ Opc = UseLASX ? LoongArchISD::XVMSKEQZ : LoongArchISD::VMSKEQZ;
4596
+ break ;
4597
+ case ISD::SETGT:
4598
+ // x > -1 => vmskgez.b x
4599
+ if (ISD::isBuildVectorAllOnes (SrcN1.getNode ()) && EltVT == MVT::i8 )
4600
+ Opc = UseLASX ? LoongArchISD::XVMSKGEZ : LoongArchISD::VMSKGEZ;
4601
+ break ;
4602
+ case ISD::SETGE:
4603
+ // x >= 0 => vmskgez.b x
4604
+ if (ISD::isBuildVectorAllZeros (SrcN1.getNode ()) && EltVT == MVT::i8 )
4605
+ Opc = UseLASX ? LoongArchISD::XVMSKGEZ : LoongArchISD::VMSKGEZ;
4606
+ break ;
4607
+ case ISD::SETLT:
4608
+ // x < 0 => vmskltz.{b,h,w,d} x
4609
+ if (ISD::isBuildVectorAllZeros (SrcN1.getNode ()) &&
4610
+ (EltVT == MVT::i8 || EltVT == MVT::i16 || EltVT == MVT::i32 ||
4611
+ EltVT == MVT::i64 ))
4612
+ Opc = UseLASX ? LoongArchISD::XVMSKLTZ : LoongArchISD::VMSKLTZ;
4613
+ break ;
4614
+ case ISD::SETLE:
4615
+ // x <= -1 => vmskltz.{b,h,w,d} x
4616
+ if (ISD::isBuildVectorAllOnes (SrcN1.getNode ()) &&
4617
+ (EltVT == MVT::i8 || EltVT == MVT::i16 || EltVT == MVT::i32 ||
4618
+ EltVT == MVT::i64 ))
4619
+ Opc = UseLASX ? LoongArchISD::XVMSKLTZ : LoongArchISD::VMSKLTZ;
4620
+ break ;
4621
+ case ISD::SETNE:
4622
+ // x != 0 => vmsknez.b x
4623
+ if (ISD::isBuildVectorAllZeros (SrcN1.getNode ()) && EltVT == MVT::i8 )
4624
+ Opc = UseLASX ? LoongArchISD::XVMSKNEZ : LoongArchISD::VMSKNEZ;
4625
+ break ;
4626
+ }
4627
+
4628
+ if (Opc == ISD::DELETED_NODE)
4629
+ return SDValue ();
4630
+
4631
+ SDValue V = DAG.getNode (Opc, DL, MVT::i64 , Src.getOperand (0 ));
4632
+ EVT T = EVT::getIntegerVT (*DAG.getContext (), SrcVT.getVectorNumElements ());
4633
+ V = DAG.getZExtOrTrunc (V, DL, T);
4634
+ return DAG.getBitcast (VT, V);
4635
+ }
4636
+
4563
4637
static SDValue performBITCASTCombine (SDNode *N, SelectionDAG &DAG,
4564
4638
TargetLowering::DAGCombinerInfo &DCI,
4565
4639
const LoongArchSubtarget &Subtarget) {
@@ -4574,110 +4648,63 @@ static SDValue performBITCASTCombine(SDNode *N, SelectionDAG &DAG,
4574
4648
if (!SrcVT.isSimple () || SrcVT.getScalarType () != MVT::i1)
4575
4649
return SDValue ();
4576
4650
4577
- unsigned Opc = ISD::DELETED_NODE;
4578
4651
// Combine SETCC and BITCAST into [X]VMSK{LT,GE,NE} when possible
4652
+ SDValue Res = performSETCC_BITCASTCombine (N, DAG, DCI, Subtarget);
4653
+ if (Res)
4654
+ return Res;
4655
+
4656
+ // Generate vXi1 using [X]VMSKLTZ
4657
+ MVT SExtVT;
4658
+ unsigned Opc;
4659
+ bool UseLASX = false ;
4660
+ bool PropagateSExt = false ;
4661
+
4579
4662
if (Src.getOpcode () == ISD::SETCC && Src.hasOneUse ()) {
4580
- bool UseLASX;
4581
4663
EVT CmpVT = Src.getOperand (0 ).getValueType ();
4582
- EVT EltVT = CmpVT.getVectorElementType ();
4583
-
4584
- if (Subtarget.hasExtLSX () && CmpVT.getSizeInBits () <= 128 )
4585
- UseLASX = false ;
4586
- else if (Subtarget.has32S () && Subtarget.hasExtLASX () &&
4587
- CmpVT.getSizeInBits () <= 256 )
4588
- UseLASX = true ;
4589
- else
4664
+ if (CmpVT.getSizeInBits () > 256 )
4590
4665
return SDValue ();
4591
-
4592
- SDValue SrcN1 = Src.getOperand (1 );
4593
- switch (cast<CondCodeSDNode>(Src.getOperand (2 ))->get ()) {
4594
- default :
4595
- break ;
4596
- case ISD::SETEQ:
4597
- // x == 0 => not (vmsknez.b x)
4598
- if (ISD::isBuildVectorAllZeros (SrcN1.getNode ()) && EltVT == MVT::i8 )
4599
- Opc = UseLASX ? LoongArchISD::XVMSKEQZ : LoongArchISD::VMSKEQZ;
4600
- break ;
4601
- case ISD::SETGT:
4602
- // x > -1 => vmskgez.b x
4603
- if (ISD::isBuildVectorAllOnes (SrcN1.getNode ()) && EltVT == MVT::i8 )
4604
- Opc = UseLASX ? LoongArchISD::XVMSKGEZ : LoongArchISD::VMSKGEZ;
4605
- break ;
4606
- case ISD::SETGE:
4607
- // x >= 0 => vmskgez.b x
4608
- if (ISD::isBuildVectorAllZeros (SrcN1.getNode ()) && EltVT == MVT::i8 )
4609
- Opc = UseLASX ? LoongArchISD::XVMSKGEZ : LoongArchISD::VMSKGEZ;
4610
- break ;
4611
- case ISD::SETLT:
4612
- // x < 0 => vmskltz.{b,h,w,d} x
4613
- if (ISD::isBuildVectorAllZeros (SrcN1.getNode ()) &&
4614
- (EltVT == MVT::i8 || EltVT == MVT::i16 || EltVT == MVT::i32 ||
4615
- EltVT == MVT::i64 ))
4616
- Opc = UseLASX ? LoongArchISD::XVMSKLTZ : LoongArchISD::VMSKLTZ;
4617
- break ;
4618
- case ISD::SETLE:
4619
- // x <= -1 => vmskltz.{b,h,w,d} x
4620
- if (ISD::isBuildVectorAllOnes (SrcN1.getNode ()) &&
4621
- (EltVT == MVT::i8 || EltVT == MVT::i16 || EltVT == MVT::i32 ||
4622
- EltVT == MVT::i64 ))
4623
- Opc = UseLASX ? LoongArchISD::XVMSKLTZ : LoongArchISD::VMSKLTZ;
4624
- break ;
4625
- case ISD::SETNE:
4626
- // x != 0 => vmsknez.b x
4627
- if (ISD::isBuildVectorAllZeros (SrcN1.getNode ()) && EltVT == MVT::i8 )
4628
- Opc = UseLASX ? LoongArchISD::XVMSKNEZ : LoongArchISD::VMSKNEZ;
4629
- break ;
4630
- }
4631
4666
}
4632
4667
4633
- // Generate vXi1 using [X]VMSKLTZ
4634
- if (Opc == ISD::DELETED_NODE) {
4635
- MVT SExtVT;
4636
- bool UseLASX = false ;
4637
- bool PropagateSExt = false ;
4638
- switch (SrcVT.getSimpleVT ().SimpleTy ) {
4639
- default :
4640
- return SDValue ();
4641
- case MVT::v2i1:
4642
- SExtVT = MVT::v2i64;
4643
- break ;
4644
- case MVT::v4i1:
4645
- SExtVT = MVT::v4i32;
4646
- if (Subtarget.hasExtLASX () && checkBitcastSrcVectorSize (Src, 256 , 0 )) {
4647
- SExtVT = MVT::v4i64;
4648
- UseLASX = true ;
4649
- PropagateSExt = true ;
4650
- }
4651
- break ;
4652
- case MVT::v8i1:
4653
- SExtVT = MVT::v8i16;
4654
- if (Subtarget.hasExtLASX () && checkBitcastSrcVectorSize (Src, 256 , 0 )) {
4655
- SExtVT = MVT::v8i32;
4656
- UseLASX = true ;
4657
- PropagateSExt = true ;
4658
- }
4659
- break ;
4660
- case MVT::v16i1:
4661
- SExtVT = MVT::v16i8;
4662
- if (Subtarget.hasExtLASX () && checkBitcastSrcVectorSize (Src, 256 , 0 )) {
4663
- SExtVT = MVT::v16i16;
4664
- UseLASX = true ;
4665
- PropagateSExt = true ;
4666
- }
4667
- break ;
4668
- case MVT::v32i1:
4669
- SExtVT = MVT::v32i8;
4668
+ switch (SrcVT.getSimpleVT ().SimpleTy ) {
4669
+ default :
4670
+ return SDValue ();
4671
+ case MVT::v2i1:
4672
+ SExtVT = MVT::v2i64;
4673
+ break ;
4674
+ case MVT::v4i1:
4675
+ SExtVT = MVT::v4i32;
4676
+ if (Subtarget.hasExtLASX () && checkBitcastSrcVectorSize (Src, 256 , 0 )) {
4677
+ SExtVT = MVT::v4i64;
4670
4678
UseLASX = true ;
4671
- break ;
4672
- };
4673
- if (UseLASX && !Subtarget.has32S () && !Subtarget.hasExtLASX ())
4674
- return SDValue ();
4675
- Src = PropagateSExt ? signExtendBitcastSrcVector (DAG, SExtVT, Src, DL)
4676
- : DAG.getNode (ISD::SIGN_EXTEND, DL, SExtVT, Src);
4677
- Opc = UseLASX ? LoongArchISD::XVMSKLTZ : LoongArchISD::VMSKLTZ;
4678
- } else {
4679
- Src = Src.getOperand (0 );
4680
- }
4679
+ PropagateSExt = true ;
4680
+ }
4681
+ break ;
4682
+ case MVT::v8i1:
4683
+ SExtVT = MVT::v8i16;
4684
+ if (Subtarget.hasExtLASX () && checkBitcastSrcVectorSize (Src, 256 , 0 )) {
4685
+ SExtVT = MVT::v8i32;
4686
+ UseLASX = true ;
4687
+ PropagateSExt = true ;
4688
+ }
4689
+ break ;
4690
+ case MVT::v16i1:
4691
+ SExtVT = MVT::v16i8;
4692
+ if (Subtarget.hasExtLASX () && checkBitcastSrcVectorSize (Src, 256 , 0 )) {
4693
+ SExtVT = MVT::v16i16;
4694
+ UseLASX = true ;
4695
+ PropagateSExt = true ;
4696
+ }
4697
+ break ;
4698
+ case MVT::v32i1:
4699
+ SExtVT = MVT::v32i8;
4700
+ UseLASX = true ;
4701
+ break ;
4702
+ };
4703
+ if (UseLASX && !(Subtarget.has32S () && Subtarget.hasExtLASX ()))
4704
+ return SDValue ();
4705
+ Src = PropagateSExt ? signExtendBitcastSrcVector (DAG, SExtVT, Src, DL)
4706
+ : DAG.getNode (ISD::SIGN_EXTEND, DL, SExtVT, Src);
4707
+ Opc = UseLASX ? LoongArchISD::XVMSKLTZ : LoongArchISD::VMSKLTZ;
4681
4708
4682
4709
SDValue V = DAG.getNode (Opc, DL, MVT::i64 , Src);
4683
4710
EVT T = EVT::getIntegerVT (*DAG.getContext (), SrcVT.getVectorNumElements ());
0 commit comments