Skip to content

[LoongArch] Fix failure to widen operand for [X]VMSK{LT,GE,NE}Z #149442

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft
wants to merge 1 commit into
base: main
Choose a base branch
from

Conversation

heiher
Copy link
Member

@heiher heiher commented Jul 18, 2025

Reported-by: tangyan [email protected]

@llvmbot
Copy link
Member

llvmbot commented Jul 18, 2025

@llvm/pr-subscribers-backend-loongarch

Author: hev (heiher)

Changes

Reported-by: tangyan <[email protected]>


Full diff: https://github.com/llvm/llvm-project/pull/149442.diff

2 Files Affected:

  • (modified) llvm/lib/Target/LoongArch/LoongArchISelLowering.cpp (+120-100)
  • (modified) llvm/test/CodeGen/LoongArch/lsx/vmskcond.ll (+15)
diff --git a/llvm/lib/Target/LoongArch/LoongArchISelLowering.cpp b/llvm/lib/Target/LoongArch/LoongArchISelLowering.cpp
index 2378664ca8155..c870271213dc6 100644
--- a/llvm/lib/Target/LoongArch/LoongArchISelLowering.cpp
+++ b/llvm/lib/Target/LoongArch/LoongArchISelLowering.cpp
@@ -4560,6 +4560,80 @@ static SDValue signExtendBitcastSrcVector(SelectionDAG &DAG, EVT SExtVT,
   llvm_unreachable("Unexpected node type for vXi1 sign extension");
 }
 
+static SDValue
+performSETCC_BITCASTCombine(SDNode *N, SelectionDAG &DAG,
+                            TargetLowering::DAGCombinerInfo &DCI,
+                            const LoongArchSubtarget &Subtarget) {
+  SDLoc DL(N);
+  EVT VT = N->getValueType(0);
+  SDValue Src = N->getOperand(0);
+  EVT SrcVT = Src.getValueType();
+
+  if (Src.getOpcode() != ISD::SETCC || !Src.hasOneUse())
+    return SDValue();
+
+  bool UseLASX;
+  unsigned Opc = ISD::DELETED_NODE;
+  EVT CmpVT = Src.getOperand(0).getValueType();
+  EVT EltVT = CmpVT.getVectorElementType();
+
+  if (Subtarget.hasExtLSX() && CmpVT.getSizeInBits() == 128)
+    UseLASX = false;
+  else if (Subtarget.has32S() && Subtarget.hasExtLASX() &&
+           CmpVT.getSizeInBits() == 256)
+    UseLASX = true;
+  else
+    return SDValue();
+
+  SDValue SrcN1 = Src.getOperand(1);
+  switch (cast<CondCodeSDNode>(Src.getOperand(2))->get()) {
+  default:
+    break;
+  case ISD::SETEQ:
+    // x == 0 => not (vmsknez.b x)
+    if (ISD::isBuildVectorAllZeros(SrcN1.getNode()) && EltVT == MVT::i8)
+      Opc = UseLASX ? LoongArchISD::XVMSKEQZ : LoongArchISD::VMSKEQZ;
+    break;
+  case ISD::SETGT:
+    // x > -1 => vmskgez.b x
+    if (ISD::isBuildVectorAllOnes(SrcN1.getNode()) && EltVT == MVT::i8)
+      Opc = UseLASX ? LoongArchISD::XVMSKGEZ : LoongArchISD::VMSKGEZ;
+    break;
+  case ISD::SETGE:
+    // x >= 0 => vmskgez.b x
+    if (ISD::isBuildVectorAllZeros(SrcN1.getNode()) && EltVT == MVT::i8)
+      Opc = UseLASX ? LoongArchISD::XVMSKGEZ : LoongArchISD::VMSKGEZ;
+    break;
+  case ISD::SETLT:
+    // x < 0 => vmskltz.{b,h,w,d} x
+    if (ISD::isBuildVectorAllZeros(SrcN1.getNode()) &&
+        (EltVT == MVT::i8 || EltVT == MVT::i16 || EltVT == MVT::i32 ||
+         EltVT == MVT::i64))
+      Opc = UseLASX ? LoongArchISD::XVMSKLTZ : LoongArchISD::VMSKLTZ;
+    break;
+  case ISD::SETLE:
+    // x <= -1 => vmskltz.{b,h,w,d} x
+    if (ISD::isBuildVectorAllOnes(SrcN1.getNode()) &&
+        (EltVT == MVT::i8 || EltVT == MVT::i16 || EltVT == MVT::i32 ||
+         EltVT == MVT::i64))
+      Opc = UseLASX ? LoongArchISD::XVMSKLTZ : LoongArchISD::VMSKLTZ;
+    break;
+  case ISD::SETNE:
+    // x != 0 => vmsknez.b x
+    if (ISD::isBuildVectorAllZeros(SrcN1.getNode()) && EltVT == MVT::i8)
+      Opc = UseLASX ? LoongArchISD::XVMSKNEZ : LoongArchISD::VMSKNEZ;
+    break;
+  }
+
+  if (Opc == ISD::DELETED_NODE)
+    return SDValue();
+
+  SDValue V = DAG.getNode(Opc, DL, MVT::i64, Src.getOperand(0));
+  EVT T = EVT::getIntegerVT(*DAG.getContext(), SrcVT.getVectorNumElements());
+  V = DAG.getZExtOrTrunc(V, DL, T);
+  return DAG.getBitcast(VT, V);
+}
+
 static SDValue performBITCASTCombine(SDNode *N, SelectionDAG &DAG,
                                      TargetLowering::DAGCombinerInfo &DCI,
                                      const LoongArchSubtarget &Subtarget) {
@@ -4574,110 +4648,56 @@ static SDValue performBITCASTCombine(SDNode *N, SelectionDAG &DAG,
   if (!SrcVT.isSimple() || SrcVT.getScalarType() != MVT::i1)
     return SDValue();
 
-  unsigned Opc = ISD::DELETED_NODE;
   // Combine SETCC and BITCAST into [X]VMSK{LT,GE,NE} when possible
-  if (Src.getOpcode() == ISD::SETCC && Src.hasOneUse()) {
-    bool UseLASX;
-    EVT CmpVT = Src.getOperand(0).getValueType();
-    EVT EltVT = CmpVT.getVectorElementType();
-
-    if (Subtarget.hasExtLSX() && CmpVT.getSizeInBits() <= 128)
-      UseLASX = false;
-    else if (Subtarget.has32S() && Subtarget.hasExtLASX() &&
-             CmpVT.getSizeInBits() <= 256)
-      UseLASX = true;
-    else
-      return SDValue();
-
-    SDValue SrcN1 = Src.getOperand(1);
-    switch (cast<CondCodeSDNode>(Src.getOperand(2))->get()) {
-    default:
-      break;
-    case ISD::SETEQ:
-      // x == 0 => not (vmsknez.b x)
-      if (ISD::isBuildVectorAllZeros(SrcN1.getNode()) && EltVT == MVT::i8)
-        Opc = UseLASX ? LoongArchISD::XVMSKEQZ : LoongArchISD::VMSKEQZ;
-      break;
-    case ISD::SETGT:
-      // x > -1 => vmskgez.b x
-      if (ISD::isBuildVectorAllOnes(SrcN1.getNode()) && EltVT == MVT::i8)
-        Opc = UseLASX ? LoongArchISD::XVMSKGEZ : LoongArchISD::VMSKGEZ;
-      break;
-    case ISD::SETGE:
-      // x >= 0 => vmskgez.b x
-      if (ISD::isBuildVectorAllZeros(SrcN1.getNode()) && EltVT == MVT::i8)
-        Opc = UseLASX ? LoongArchISD::XVMSKGEZ : LoongArchISD::VMSKGEZ;
-      break;
-    case ISD::SETLT:
-      // x < 0 => vmskltz.{b,h,w,d} x
-      if (ISD::isBuildVectorAllZeros(SrcN1.getNode()) &&
-          (EltVT == MVT::i8 || EltVT == MVT::i16 || EltVT == MVT::i32 ||
-           EltVT == MVT::i64))
-        Opc = UseLASX ? LoongArchISD::XVMSKLTZ : LoongArchISD::VMSKLTZ;
-      break;
-    case ISD::SETLE:
-      // x <= -1 => vmskltz.{b,h,w,d} x
-      if (ISD::isBuildVectorAllOnes(SrcN1.getNode()) &&
-          (EltVT == MVT::i8 || EltVT == MVT::i16 || EltVT == MVT::i32 ||
-           EltVT == MVT::i64))
-        Opc = UseLASX ? LoongArchISD::XVMSKLTZ : LoongArchISD::VMSKLTZ;
-      break;
-    case ISD::SETNE:
-      // x != 0 => vmsknez.b x
-      if (ISD::isBuildVectorAllZeros(SrcN1.getNode()) && EltVT == MVT::i8)
-        Opc = UseLASX ? LoongArchISD::XVMSKNEZ : LoongArchISD::VMSKNEZ;
-      break;
-    }
-  }
+  SDValue Res = performSETCC_BITCASTCombine(N, DAG, DCI, Subtarget);
+  if (Res)
+    return Res;
 
   // Generate vXi1 using [X]VMSKLTZ
-  if (Opc == ISD::DELETED_NODE) {
-    MVT SExtVT;
-    bool UseLASX = false;
-    bool PropagateSExt = false;
-    switch (SrcVT.getSimpleVT().SimpleTy) {
-    default:
-      return SDValue();
-    case MVT::v2i1:
-      SExtVT = MVT::v2i64;
-      break;
-    case MVT::v4i1:
-      SExtVT = MVT::v4i32;
-      if (Subtarget.hasExtLASX() && checkBitcastSrcVectorSize(Src, 256, 0)) {
-        SExtVT = MVT::v4i64;
-        UseLASX = true;
-        PropagateSExt = true;
-      }
-      break;
-    case MVT::v8i1:
-      SExtVT = MVT::v8i16;
-      if (Subtarget.hasExtLASX() && checkBitcastSrcVectorSize(Src, 256, 0)) {
-        SExtVT = MVT::v8i32;
-        UseLASX = true;
-        PropagateSExt = true;
-      }
-      break;
-    case MVT::v16i1:
-      SExtVT = MVT::v16i8;
-      if (Subtarget.hasExtLASX() && checkBitcastSrcVectorSize(Src, 256, 0)) {
-        SExtVT = MVT::v16i16;
-        UseLASX = true;
-        PropagateSExt = true;
-      }
-      break;
-    case MVT::v32i1:
-      SExtVT = MVT::v32i8;
+  MVT SExtVT;
+  unsigned Opc;
+  bool UseLASX = false;
+  bool PropagateSExt = false;
+  switch (SrcVT.getSimpleVT().SimpleTy) {
+  default:
+    return SDValue();
+  case MVT::v2i1:
+    SExtVT = MVT::v2i64;
+    break;
+  case MVT::v4i1:
+    SExtVT = MVT::v4i32;
+    if (Subtarget.hasExtLASX() && checkBitcastSrcVectorSize(Src, 256, 0)) {
+      SExtVT = MVT::v4i64;
       UseLASX = true;
-      break;
-    };
-    if (UseLASX && !Subtarget.has32S() && !Subtarget.hasExtLASX())
-      return SDValue();
-    Src = PropagateSExt ? signExtendBitcastSrcVector(DAG, SExtVT, Src, DL)
-                        : DAG.getNode(ISD::SIGN_EXTEND, DL, SExtVT, Src);
-    Opc = UseLASX ? LoongArchISD::XVMSKLTZ : LoongArchISD::VMSKLTZ;
-  } else {
-    Src = Src.getOperand(0);
-  }
+      PropagateSExt = true;
+    }
+    break;
+  case MVT::v8i1:
+    SExtVT = MVT::v8i16;
+    if (Subtarget.hasExtLASX() && checkBitcastSrcVectorSize(Src, 256, 0)) {
+      SExtVT = MVT::v8i32;
+      UseLASX = true;
+      PropagateSExt = true;
+    }
+    break;
+  case MVT::v16i1:
+    SExtVT = MVT::v16i8;
+    if (Subtarget.hasExtLASX() && checkBitcastSrcVectorSize(Src, 256, 0)) {
+      SExtVT = MVT::v16i16;
+      UseLASX = true;
+      PropagateSExt = true;
+    }
+    break;
+  case MVT::v32i1:
+    SExtVT = MVT::v32i8;
+    UseLASX = true;
+    break;
+  };
+  if (UseLASX && !Subtarget.has32S() && !Subtarget.hasExtLASX())
+    return SDValue();
+  Src = PropagateSExt ? signExtendBitcastSrcVector(DAG, SExtVT, Src, DL)
+                      : DAG.getNode(ISD::SIGN_EXTEND, DL, SExtVT, Src);
+  Opc = UseLASX ? LoongArchISD::XVMSKLTZ : LoongArchISD::VMSKLTZ;
 
   SDValue V = DAG.getNode(Opc, DL, MVT::i64, Src);
   EVT T = EVT::getIntegerVT(*DAG.getContext(), SrcVT.getVectorNumElements());
diff --git a/llvm/test/CodeGen/LoongArch/lsx/vmskcond.ll b/llvm/test/CodeGen/LoongArch/lsx/vmskcond.ll
index 0ee30120f77a6..ad57bbf9ee5c0 100644
--- a/llvm/test/CodeGen/LoongArch/lsx/vmskcond.ll
+++ b/llvm/test/CodeGen/LoongArch/lsx/vmskcond.ll
@@ -588,3 +588,18 @@ define i2 @vmsk_trunc_i64(<2 x i64> %a) {
   %res = bitcast <2 x i1> %y to i2
   ret i2 %res
 }
+
+define i4 @vmsk_eq_allzeros_v4i8(<4 x i8> %a) {
+; CHECK-LABEL: vmsk_eq_allzeros_v4i8:
+; CHECK:       # %bb.0:
+; CHECK-NEXT:    vseqi.b $vr0, $vr0, 0
+; CHECK-NEXT:    vilvl.b $vr0, $vr0, $vr0
+; CHECK-NEXT:    vilvl.h $vr0, $vr0, $vr0
+; CHECK-NEXT:    vslli.w $vr0, $vr0, 24
+; CHECK-NEXT:    vmskltz.w $vr0, $vr0
+; CHECK-NEXT:    vpickve2gr.hu $a0, $vr0, 0
+; CHECK-NEXT:    ret
+  %1 = icmp eq <4 x i8> %a, zeroinitializer
+  %2 = bitcast <4 x i1> %1 to i4
+  ret i4 %2
+}

EVT CmpVT = Src.getOperand(0).getValueType();
EVT EltVT = CmpVT.getVectorElementType();

if (Subtarget.hasExtLSX() && CmpVT.getSizeInBits() == 128)
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

NOTE: performSETCC_BITCASTCombine was split out from performBITCASTCombine. Additionally, the condition was changed from CmpVT.getSizeInBits() <= 128 to CmpVT.getSizeInBits() == 128. The same applies to the 256-bit case.

@heiher heiher marked this pull request as draft July 18, 2025 06:59
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants