Skip to content

Commit 8a307ae

Browse files
authored
[LoongArch] Fix failure to widen operand for [X]VMSK{LT,GE,NE}Z (#149442)
Reported-by: tangyan <[email protected]>
1 parent 2320cdd commit 8a307ae

File tree

2 files changed

+139
-97
lines changed

2 files changed

+139
-97
lines changed

llvm/lib/Target/LoongArch/LoongArchISelLowering.cpp

Lines changed: 124 additions & 97 deletions
Original file line numberDiff line numberDiff line change
@@ -4560,6 +4560,80 @@ static SDValue signExtendBitcastSrcVector(SelectionDAG &DAG, EVT SExtVT,
45604560
llvm_unreachable("Unexpected node type for vXi1 sign extension");
45614561
}
45624562

/// Try to replace a node whose operand 0 is a single-use vector SETCC with one
/// [X]VMSK{EQ,GE,LT,NE}Z node applied directly to the SETCC's first operand.
///
/// \param N         Node being combined; operand 0 is inspected as the SETCC.
/// \param DAG       SelectionDAG used to create the replacement nodes.
/// \param DCI       Combiner info; not referenced in this function body.
/// \param Subtarget Queried for LSX / LASX / 32S availability.
/// \returns The replacement value bitcast to N's result type, or an empty
///          SDValue when the operand is not a single-use SETCC, the compared
///          type is not exactly 128 bits (LSX) or 256 bits (32S + LASX), or
///          the predicate/operand pair matches none of the cases below.
4563+
static SDValue
4564+
performSETCC_BITCASTCombine(SDNode *N, SelectionDAG &DAG,
4565+
TargetLowering::DAGCombinerInfo &DCI,
4566+
const LoongArchSubtarget &Subtarget) {
4567+
SDLoc DL(N);
4568+
EVT VT = N->getValueType(0);
4569+
SDValue Src = N->getOperand(0);
4570+
EVT SrcVT = Src.getValueType();
4571+
// Only a SETCC with no other users can be folded into the mask node.
4572+
if (Src.getOpcode() != ISD::SETCC || !Src.hasOneUse())
4573+
return SDValue();
4574+
// UseLASX is assigned on every path below that does not return early.
4575+
bool UseLASX;
// ISD::DELETED_NODE serves as the "no pattern matched" sentinel for Opc.
4576+
unsigned Opc = ISD::DELETED_NODE;
4577+
EVT CmpVT = Src.getOperand(0).getValueType();
4578+
EVT EltVT = CmpVT.getVectorElementType();
4579+
// The size tests are exact matches: a compared type that is neither exactly
// 128 bits (with LSX) nor exactly 256 bits (with 32S and LASX) is rejected
// here rather than being fed to a mask node.
4580+
if (Subtarget.hasExtLSX() && CmpVT.getSizeInBits() == 128)
4581+
UseLASX = false;
4582+
else if (Subtarget.has32S() && Subtarget.hasExtLASX() &&
4583+
CmpVT.getSizeInBits() == 256)
4584+
UseLASX = true;
4585+
else
4586+
return SDValue();
4587+
// Every case below requires the SETCC's second operand to be an all-zeros or
// all-ones build vector. The EQ/NE/GE/GT forms are matched only for i8
// elements; the LT/LE forms are matched for i8, i16, i32 and i64 elements.
4588+
SDValue SrcN1 = Src.getOperand(1);
4589+
switch (cast<CondCodeSDNode>(Src.getOperand(2))->get()) {
4590+
default:
4591+
break;
4592+
case ISD::SETEQ:
4593+
// x == 0 => not (vmsknez.b x)
4594+
if (ISD::isBuildVectorAllZeros(SrcN1.getNode()) && EltVT == MVT::i8)
4595+
Opc = UseLASX ? LoongArchISD::XVMSKEQZ : LoongArchISD::VMSKEQZ;
4596+
break;
4597+
case ISD::SETGT:
4598+
// x > -1 => vmskgez.b x
4599+
if (ISD::isBuildVectorAllOnes(SrcN1.getNode()) && EltVT == MVT::i8)
4600+
Opc = UseLASX ? LoongArchISD::XVMSKGEZ : LoongArchISD::VMSKGEZ;
4601+
break;
4602+
case ISD::SETGE:
4603+
// x >= 0 => vmskgez.b x
4604+
if (ISD::isBuildVectorAllZeros(SrcN1.getNode()) && EltVT == MVT::i8)
4605+
Opc = UseLASX ? LoongArchISD::XVMSKGEZ : LoongArchISD::VMSKGEZ;
4606+
break;
4607+
case ISD::SETLT:
4608+
// x < 0 => vmskltz.{b,h,w,d} x
4609+
if (ISD::isBuildVectorAllZeros(SrcN1.getNode()) &&
4610+
(EltVT == MVT::i8 || EltVT == MVT::i16 || EltVT == MVT::i32 ||
4611+
EltVT == MVT::i64))
4612+
Opc = UseLASX ? LoongArchISD::XVMSKLTZ : LoongArchISD::VMSKLTZ;
4613+
break;
4614+
case ISD::SETLE:
4615+
// x <= -1 => vmskltz.{b,h,w,d} x
4616+
if (ISD::isBuildVectorAllOnes(SrcN1.getNode()) &&
4617+
(EltVT == MVT::i8 || EltVT == MVT::i16 || EltVT == MVT::i32 ||
4618+
EltVT == MVT::i64))
4619+
Opc = UseLASX ? LoongArchISD::XVMSKLTZ : LoongArchISD::VMSKLTZ;
4620+
break;
4621+
case ISD::SETNE:
4622+
// x != 0 => vmsknez.b x
4623+
if (ISD::isBuildVectorAllZeros(SrcN1.getNode()) && EltVT == MVT::i8)
4624+
Opc = UseLASX ? LoongArchISD::XVMSKNEZ : LoongArchISD::VMSKNEZ;
4625+
break;
4626+
}
4627+
// No case selected an opcode: leave the node for other combines.
4628+
if (Opc == ISD::DELETED_NODE)
4629+
return SDValue();
4630+
// Build the mask node with an i64 result from the SETCC's first operand, then
// zero-extend or truncate it to an integer that has one bit per element of
// SrcVT before bitcasting to the result type of N.
4631+
SDValue V = DAG.getNode(Opc, DL, MVT::i64, Src.getOperand(0));
4632+
EVT T = EVT::getIntegerVT(*DAG.getContext(), SrcVT.getVectorNumElements());
4633+
V = DAG.getZExtOrTrunc(V, DL, T);
4634+
return DAG.getBitcast(VT, V);
4635+
}
4636+
45634637
static SDValue performBITCASTCombine(SDNode *N, SelectionDAG &DAG,
45644638
TargetLowering::DAGCombinerInfo &DCI,
45654639
const LoongArchSubtarget &Subtarget) {
@@ -4574,110 +4648,63 @@ static SDValue performBITCASTCombine(SDNode *N, SelectionDAG &DAG,
45744648
if (!SrcVT.isSimple() || SrcVT.getScalarType() != MVT::i1)
45754649
return SDValue();
45764650

4577-
unsigned Opc = ISD::DELETED_NODE;
45784651
// Combine SETCC and BITCAST into [X]VMSK{LT,GE,NE} when possible
4652+
SDValue Res = performSETCC_BITCASTCombine(N, DAG, DCI, Subtarget);
4653+
if (Res)
4654+
return Res;
4655+
4656+
// Generate vXi1 using [X]VMSKLTZ
4657+
MVT SExtVT;
4658+
unsigned Opc;
4659+
bool UseLASX = false;
4660+
bool PropagateSExt = false;
4661+
45794662
if (Src.getOpcode() == ISD::SETCC && Src.hasOneUse()) {
4580-
bool UseLASX;
45814663
EVT CmpVT = Src.getOperand(0).getValueType();
4582-
EVT EltVT = CmpVT.getVectorElementType();
4583-
4584-
if (Subtarget.hasExtLSX() && CmpVT.getSizeInBits() <= 128)
4585-
UseLASX = false;
4586-
else if (Subtarget.has32S() && Subtarget.hasExtLASX() &&
4587-
CmpVT.getSizeInBits() <= 256)
4588-
UseLASX = true;
4589-
else
4664+
if (CmpVT.getSizeInBits() > 256)
45904665
return SDValue();
4591-
4592-
SDValue SrcN1 = Src.getOperand(1);
4593-
switch (cast<CondCodeSDNode>(Src.getOperand(2))->get()) {
4594-
default:
4595-
break;
4596-
case ISD::SETEQ:
4597-
// x == 0 => not (vmsknez.b x)
4598-
if (ISD::isBuildVectorAllZeros(SrcN1.getNode()) && EltVT == MVT::i8)
4599-
Opc = UseLASX ? LoongArchISD::XVMSKEQZ : LoongArchISD::VMSKEQZ;
4600-
break;
4601-
case ISD::SETGT:
4602-
// x > -1 => vmskgez.b x
4603-
if (ISD::isBuildVectorAllOnes(SrcN1.getNode()) && EltVT == MVT::i8)
4604-
Opc = UseLASX ? LoongArchISD::XVMSKGEZ : LoongArchISD::VMSKGEZ;
4605-
break;
4606-
case ISD::SETGE:
4607-
// x >= 0 => vmskgez.b x
4608-
if (ISD::isBuildVectorAllZeros(SrcN1.getNode()) && EltVT == MVT::i8)
4609-
Opc = UseLASX ? LoongArchISD::XVMSKGEZ : LoongArchISD::VMSKGEZ;
4610-
break;
4611-
case ISD::SETLT:
4612-
// x < 0 => vmskltz.{b,h,w,d} x
4613-
if (ISD::isBuildVectorAllZeros(SrcN1.getNode()) &&
4614-
(EltVT == MVT::i8 || EltVT == MVT::i16 || EltVT == MVT::i32 ||
4615-
EltVT == MVT::i64))
4616-
Opc = UseLASX ? LoongArchISD::XVMSKLTZ : LoongArchISD::VMSKLTZ;
4617-
break;
4618-
case ISD::SETLE:
4619-
// x <= -1 => vmskltz.{b,h,w,d} x
4620-
if (ISD::isBuildVectorAllOnes(SrcN1.getNode()) &&
4621-
(EltVT == MVT::i8 || EltVT == MVT::i16 || EltVT == MVT::i32 ||
4622-
EltVT == MVT::i64))
4623-
Opc = UseLASX ? LoongArchISD::XVMSKLTZ : LoongArchISD::VMSKLTZ;
4624-
break;
4625-
case ISD::SETNE:
4626-
// x != 0 => vmsknez.b x
4627-
if (ISD::isBuildVectorAllZeros(SrcN1.getNode()) && EltVT == MVT::i8)
4628-
Opc = UseLASX ? LoongArchISD::XVMSKNEZ : LoongArchISD::VMSKNEZ;
4629-
break;
4630-
}
46314666
}
46324667

4633-
// Generate vXi1 using [X]VMSKLTZ
4634-
if (Opc == ISD::DELETED_NODE) {
4635-
MVT SExtVT;
4636-
bool UseLASX = false;
4637-
bool PropagateSExt = false;
4638-
switch (SrcVT.getSimpleVT().SimpleTy) {
4639-
default:
4640-
return SDValue();
4641-
case MVT::v2i1:
4642-
SExtVT = MVT::v2i64;
4643-
break;
4644-
case MVT::v4i1:
4645-
SExtVT = MVT::v4i32;
4646-
if (Subtarget.hasExtLASX() && checkBitcastSrcVectorSize(Src, 256, 0)) {
4647-
SExtVT = MVT::v4i64;
4648-
UseLASX = true;
4649-
PropagateSExt = true;
4650-
}
4651-
break;
4652-
case MVT::v8i1:
4653-
SExtVT = MVT::v8i16;
4654-
if (Subtarget.hasExtLASX() && checkBitcastSrcVectorSize(Src, 256, 0)) {
4655-
SExtVT = MVT::v8i32;
4656-
UseLASX = true;
4657-
PropagateSExt = true;
4658-
}
4659-
break;
4660-
case MVT::v16i1:
4661-
SExtVT = MVT::v16i8;
4662-
if (Subtarget.hasExtLASX() && checkBitcastSrcVectorSize(Src, 256, 0)) {
4663-
SExtVT = MVT::v16i16;
4664-
UseLASX = true;
4665-
PropagateSExt = true;
4666-
}
4667-
break;
4668-
case MVT::v32i1:
4669-
SExtVT = MVT::v32i8;
4668+
switch (SrcVT.getSimpleVT().SimpleTy) {
4669+
default:
4670+
return SDValue();
4671+
case MVT::v2i1:
4672+
SExtVT = MVT::v2i64;
4673+
break;
4674+
case MVT::v4i1:
4675+
SExtVT = MVT::v4i32;
4676+
if (Subtarget.hasExtLASX() && checkBitcastSrcVectorSize(Src, 256, 0)) {
4677+
SExtVT = MVT::v4i64;
46704678
UseLASX = true;
4671-
break;
4672-
};
4673-
if (UseLASX && !Subtarget.has32S() && !Subtarget.hasExtLASX())
4674-
return SDValue();
4675-
Src = PropagateSExt ? signExtendBitcastSrcVector(DAG, SExtVT, Src, DL)
4676-
: DAG.getNode(ISD::SIGN_EXTEND, DL, SExtVT, Src);
4677-
Opc = UseLASX ? LoongArchISD::XVMSKLTZ : LoongArchISD::VMSKLTZ;
4678-
} else {
4679-
Src = Src.getOperand(0);
4680-
}
4679+
PropagateSExt = true;
4680+
}
4681+
break;
4682+
case MVT::v8i1:
4683+
SExtVT = MVT::v8i16;
4684+
if (Subtarget.hasExtLASX() && checkBitcastSrcVectorSize(Src, 256, 0)) {
4685+
SExtVT = MVT::v8i32;
4686+
UseLASX = true;
4687+
PropagateSExt = true;
4688+
}
4689+
break;
4690+
case MVT::v16i1:
4691+
SExtVT = MVT::v16i8;
4692+
if (Subtarget.hasExtLASX() && checkBitcastSrcVectorSize(Src, 256, 0)) {
4693+
SExtVT = MVT::v16i16;
4694+
UseLASX = true;
4695+
PropagateSExt = true;
4696+
}
4697+
break;
4698+
case MVT::v32i1:
4699+
SExtVT = MVT::v32i8;
4700+
UseLASX = true;
4701+
break;
4702+
};
4703+
if (UseLASX && !(Subtarget.has32S() && Subtarget.hasExtLASX()))
4704+
return SDValue();
4705+
Src = PropagateSExt ? signExtendBitcastSrcVector(DAG, SExtVT, Src, DL)
4706+
: DAG.getNode(ISD::SIGN_EXTEND, DL, SExtVT, Src);
4707+
Opc = UseLASX ? LoongArchISD::XVMSKLTZ : LoongArchISD::VMSKLTZ;
46814708

46824709
SDValue V = DAG.getNode(Opc, DL, MVT::i64, Src);
46834710
EVT T = EVT::getIntegerVT(*DAG.getContext(), SrcVT.getVectorNumElements());

llvm/test/CodeGen/LoongArch/lsx/vmskcond.ll

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -588,3 +588,18 @@ define i2 @vmsk_trunc_i64(<2 x i64> %a) {
588588
%res = bitcast <2 x i1> %y to i2
589589
ret i2 %res
590590
}
591+
; Compares a <4 x i8> vector (32 bits, narrower than a 128-bit vector) for
; equality with zero and bitcasts the <4 x i1> result to i4. The expected
; assembly performs the compare with vseqi.b, widens the lanes with vilvl.b,
; vilvl.h and vslli.w, and extracts the mask with vmskltz.w.
592+
define i4 @vmsk_eq_allzeros_v4i8(<4 x i8> %a) {
593+
; CHECK-LABEL: vmsk_eq_allzeros_v4i8:
594+
; CHECK: # %bb.0:
595+
; CHECK-NEXT: vseqi.b $vr0, $vr0, 0
596+
; CHECK-NEXT: vilvl.b $vr0, $vr0, $vr0
597+
; CHECK-NEXT: vilvl.h $vr0, $vr0, $vr0
598+
; CHECK-NEXT: vslli.w $vr0, $vr0, 24
599+
; CHECK-NEXT: vmskltz.w $vr0, $vr0
600+
; CHECK-NEXT: vpickve2gr.hu $a0, $vr0, 0
601+
; CHECK-NEXT: ret
602+
%1 = icmp eq <4 x i8> %a, zeroinitializer
603+
%2 = bitcast <4 x i1> %1 to i4
604+
ret i4 %2
605+
}

0 commit comments

Comments
 (0)