Skip to content

Commit dd3d30a

Browse files
committed
[ARM] Have custom lowering for ucmp and scmp
Limited to non-thumb at the moment, but we can do this for i32 in 3 steps, using subs to set the flags initially.
1 parent b9adc4a commit dd3d30a

File tree

4 files changed

+183
-48
lines changed

4 files changed

+183
-48
lines changed

llvm/lib/Target/ARM/ARMISelLowering.cpp

Lines changed: 145 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -802,6 +802,11 @@ ARMTargetLowering::ARMTargetLowering(const TargetMachine &TM_,
802802
setOperationAction(ISD::BSWAP, VT, Expand);
803803
}
804804

805+
if (!Subtarget->isThumb()) {
806+
setOperationAction(ISD::SCMP, MVT::i32, Custom);
807+
setOperationAction(ISD::UCMP, MVT::i32, Custom);
808+
}
809+
805810
setOperationAction(ISD::ConstantFP, MVT::f32, Custom);
806811
setOperationAction(ISD::ConstantFP, MVT::f64, Custom);
807812

@@ -10614,6 +10619,142 @@ SDValue ARMTargetLowering::LowerFP_TO_BF16(SDValue Op,
1061410619
return DAG.getBitcast(MVT::i32, Res);
1061510620
}
1061610621

10622+
SDValue ARMTargetLowering::LowerSCMP(SDValue Op, SelectionDAG &DAG) const {
10623+
SDLoc dl(Op);
10624+
SDValue LHS = Op.getOperand(0);
10625+
SDValue RHS = Op.getOperand(1);
10626+
10627+
// For the ARM assembly pattern:
10628+
// subs r0, r0, r1 ; subtract RHS from LHS and set flags
10629+
// movgt r0, #1 ; if LHS > RHS, set result to 1
10630+
// mvnlt r0, #0 ; if LHS < RHS, set result to -1 (mvn #0 = -1)
10631+
// ; if LHS == RHS, result remains 0 from the subs
10632+
10633+
// Optimization: if RHS is a subtraction against 0, use ADDC instead of SUBC
10634+
// Check if RHS is (0 - something), and if so use ADDC with LHS + something
10635+
SDValue SubResult, Flags;
10636+
bool CanUseAdd = false;
10637+
SDValue AddOperand;
10638+
10639+
// Check if RHS is a subtraction against 0: (0 - X)
10640+
if (RHS.getOpcode() == ISD::SUB) {
10641+
SDValue SubLHS = RHS.getOperand(0);
10642+
SDValue SubRHS = RHS.getOperand(1);
10643+
10644+
// Check if it's 0 - X
10645+
if (isNullConstant(SubLHS)) {
10646+
// For SCMP: only if X is known to never be INT_MIN (to avoid overflow)
10647+
if (RHS->getFlags().hasNoSignedWrap() || !DAG.computeKnownBits(SubRHS)
10648+
.getSignedMinValue()
10649+
.isMinSignedValue()) {
10650+
CanUseAdd = true;
10651+
AddOperand = SubRHS; // Replace RHS with X, so we do LHS + X instead of
10652+
// LHS - (0 - X)
10653+
}
10654+
}
10655+
}
10656+
10657+
if (CanUseAdd) {
10658+
// Use ADDC: LHS + AddOperand (where RHS was 0 - AddOperand)
10659+
SDValue AddWithFlags = DAG.getNode(
10660+
ARMISD::ADDC, dl, DAG.getVTList(MVT::i32, FlagsVT), LHS, AddOperand);
10661+
SubResult = AddWithFlags.getValue(0); // The addition result
10662+
Flags = AddWithFlags.getValue(1); // The flags from ADDS
10663+
} else {
10664+
// Use ARMISD::SUBC to generate SUBS instruction (subtract with flags)
10665+
SDValue SubWithFlags = DAG.getNode(
10666+
ARMISD::SUBC, dl, DAG.getVTList(MVT::i32, FlagsVT), LHS, RHS);
10667+
SubResult = SubWithFlags.getValue(0); // The subtraction result
10668+
Flags = SubWithFlags.getValue(1); // The flags from SUBS
10669+
}
10670+
10671+
// Constants for conditional moves
10672+
SDValue One = DAG.getConstant(1, dl, MVT::i32);
10673+
SDValue MinusOne = DAG.getConstant(0xFFFFFFFF, dl, MVT::i32);
10674+
10675+
// movgt: if greater than, set to 1
10676+
SDValue GTCond = DAG.getConstant(ARMCC::GT, dl, MVT::i32);
10677+
SDValue Result1 =
10678+
DAG.getNode(ARMISD::CMOV, dl, MVT::i32, SubResult, One, GTCond, Flags);
10679+
10680+
// mvnlt: if less than, set to -1 (equivalent to mvn #0)
10681+
SDValue LTCond = DAG.getConstant(ARMCC::LT, dl, MVT::i32);
10682+
SDValue Result2 =
10683+
DAG.getNode(ARMISD::CMOV, dl, MVT::i32, Result1, MinusOne, LTCond, Flags);
10684+
10685+
if (Op.getValueType() != MVT::i32)
10686+
Result2 = DAG.getSExtOrTrunc(Result2, dl, Op.getValueType());
10687+
10688+
return Result2;
10689+
}
10690+
10691+
SDValue ARMTargetLowering::LowerUCMP(SDValue Op, SelectionDAG &DAG) const {
10692+
SDLoc dl(Op);
10693+
SDValue LHS = Op.getOperand(0);
10694+
SDValue RHS = Op.getOperand(1);
10695+
10696+
// For the ARM assembly pattern (unsigned version):
10697+
// subs r0, r0, r1 ; subtract RHS from LHS and set flags
10698+
// movhi r0, #1 ; if LHS > RHS (unsigned), set result to 1
10699+
// mvnlo r0, #0 ; if LHS < RHS (unsigned), set result to -1
10700+
// ; if LHS == RHS, result remains 0 from the subs
10701+
10702+
// Optimization: if RHS is a subtraction against 0, use ADDC instead of SUBC
10703+
// Check if RHS is (0 - something), and if so use ADDC with LHS + something
10704+
SDValue SubResult, Flags;
10705+
bool CanUseAdd = false;
10706+
SDValue AddOperand;
10707+
10708+
// Check if RHS is a subtraction against 0: (0 - X)
10709+
if (RHS.getOpcode() == ISD::SUB) {
10710+
SDValue SubLHS = RHS.getOperand(0);
10711+
SDValue SubRHS = RHS.getOperand(1);
10712+
10713+
// Check if it's 0 - X
10714+
if (isNullConstant(SubLHS)) {
10715+
// For UCMP: only if X is known to never be zero
10716+
if (DAG.isKnownNeverZero(SubRHS)) {
10717+
CanUseAdd = true;
10718+
AddOperand = SubRHS; // Replace RHS with X, so we do LHS + X instead of
10719+
// LHS - (0 - X)
10720+
}
10721+
}
10722+
}
10723+
10724+
if (CanUseAdd) {
10725+
// Use ADDC: LHS + AddOperand (where RHS was 0 - AddOperand)
10726+
SDValue AddWithFlags = DAG.getNode(
10727+
ARMISD::ADDC, dl, DAG.getVTList(MVT::i32, FlagsVT), LHS, AddOperand);
10728+
SubResult = AddWithFlags.getValue(0); // The addition result
10729+
Flags = AddWithFlags.getValue(1); // The flags from ADDS
10730+
} else {
10731+
// Use ARMISD::SUBC to generate SUBS instruction (subtract with flags)
10732+
SDValue SubWithFlags = DAG.getNode(
10733+
ARMISD::SUBC, dl, DAG.getVTList(MVT::i32, FlagsVT), LHS, RHS);
10734+
SubResult = SubWithFlags.getValue(0); // The subtraction result
10735+
Flags = SubWithFlags.getValue(1); // The flags from SUBS
10736+
}
10737+
10738+
// Constants for conditional moves
10739+
SDValue One = DAG.getConstant(1, dl, MVT::i32);
10740+
SDValue MinusOne = DAG.getConstant(0xFFFFFFFF, dl, MVT::i32);
10741+
10742+
// movhi: if higher (unsigned greater than), set to 1
10743+
SDValue HICond = DAG.getConstant(ARMCC::HI, dl, MVT::i32);
10744+
SDValue Result1 =
10745+
DAG.getNode(ARMISD::CMOV, dl, MVT::i32, SubResult, One, HICond, Flags);
10746+
10747+
// mvnlo: if lower (unsigned less than), set to -1
10748+
SDValue LOCond = DAG.getConstant(ARMCC::LO, dl, MVT::i32);
10749+
SDValue Result2 =
10750+
DAG.getNode(ARMISD::CMOV, dl, MVT::i32, Result1, MinusOne, LOCond, Flags);
10751+
10752+
if (Op.getValueType() != MVT::i32)
10753+
Result2 = DAG.getSExtOrTrunc(Result2, dl, Op.getValueType());
10754+
10755+
return Result2;
10756+
}
10757+
1061710758
SDValue ARMTargetLowering::LowerOperation(SDValue Op, SelectionDAG &DAG) const {
1061810759
LLVM_DEBUG(dbgs() << "Lowering node: "; Op.dump());
1061910760
switch (Op.getOpcode()) {
@@ -10742,6 +10883,10 @@ SDValue ARMTargetLowering::LowerOperation(SDValue Op, SelectionDAG &DAG) const {
1074210883
case ISD::FP_TO_BF16:
1074310884
return LowerFP_TO_BF16(Op, DAG);
1074410885
case ARMISD::WIN__DBZCHK: return SDValue();
10886+
case ISD::SCMP:
10887+
return LowerSCMP(Op, DAG);
10888+
case ISD::UCMP:
10889+
return LowerUCMP(Op, DAG);
1074510890
}
1074610891
}
1074710892

llvm/lib/Target/ARM/ARMISelLowering.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -903,6 +903,8 @@ class VectorType;
903903
void LowerLOAD(SDNode *N, SmallVectorImpl<SDValue> &Results,
904904
SelectionDAG &DAG) const;
905905
SDValue LowerFP_TO_BF16(SDValue Op, SelectionDAG &DAG) const;
906+
SDValue LowerSCMP(SDValue Op, SelectionDAG &DAG) const;
907+
SDValue LowerUCMP(SDValue Op, SelectionDAG &DAG) const;
906908

907909
Register getRegisterByName(const char* RegName, LLT VT,
908910
const MachineFunction &MF) const override;

llvm/test/CodeGen/ARM/scmp.ll

Lines changed: 24 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -4,12 +4,9 @@
44
define i8 @scmp_8_8(i8 signext %x, i8 signext %y) nounwind {
55
; CHECK-LABEL: scmp_8_8:
66
; CHECK: @ %bb.0:
7-
; CHECK-NEXT: cmp r0, r1
8-
; CHECK-NEXT: mov r0, #0
9-
; CHECK-NEXT: mov r2, #0
10-
; CHECK-NEXT: movwlt r0, #1
11-
; CHECK-NEXT: movwgt r2, #1
12-
; CHECK-NEXT: sub r0, r2, r0
7+
; CHECK-NEXT: subs r0, r0, r1
8+
; CHECK-NEXT: movwgt r0, #1
9+
; CHECK-NEXT: mvnlt r0, #0
1310
; CHECK-NEXT: bx lr
1411
%1 = call i8 @llvm.scmp(i8 %x, i8 %y)
1512
ret i8 %1
@@ -18,12 +15,9 @@ define i8 @scmp_8_8(i8 signext %x, i8 signext %y) nounwind {
1815
define i8 @scmp_8_16(i16 signext %x, i16 signext %y) nounwind {
1916
; CHECK-LABEL: scmp_8_16:
2017
; CHECK: @ %bb.0:
21-
; CHECK-NEXT: cmp r0, r1
22-
; CHECK-NEXT: mov r0, #0
23-
; CHECK-NEXT: mov r2, #0
24-
; CHECK-NEXT: movwlt r0, #1
25-
; CHECK-NEXT: movwgt r2, #1
26-
; CHECK-NEXT: sub r0, r2, r0
18+
; CHECK-NEXT: subs r0, r0, r1
19+
; CHECK-NEXT: movwgt r0, #1
20+
; CHECK-NEXT: mvnlt r0, #0
2721
; CHECK-NEXT: bx lr
2822
%1 = call i8 @llvm.scmp(i16 %x, i16 %y)
2923
ret i8 %1
@@ -32,12 +26,9 @@ define i8 @scmp_8_16(i16 signext %x, i16 signext %y) nounwind {
3226
define i8 @scmp_8_32(i32 %x, i32 %y) nounwind {
3327
; CHECK-LABEL: scmp_8_32:
3428
; CHECK: @ %bb.0:
35-
; CHECK-NEXT: cmp r0, r1
36-
; CHECK-NEXT: mov r0, #0
37-
; CHECK-NEXT: mov r2, #0
38-
; CHECK-NEXT: movwlt r0, #1
39-
; CHECK-NEXT: movwgt r2, #1
40-
; CHECK-NEXT: sub r0, r2, r0
29+
; CHECK-NEXT: subs r0, r0, r1
30+
; CHECK-NEXT: movwgt r0, #1
31+
; CHECK-NEXT: mvnlt r0, #0
4132
; CHECK-NEXT: bx lr
4233
%1 = call i8 @llvm.scmp(i32 %x, i32 %y)
4334
ret i8 %1
@@ -92,17 +83,26 @@ define i8 @scmp_8_128(i128 %x, i128 %y) nounwind {
9283
define i32 @scmp_32_32(i32 %x, i32 %y) nounwind {
9384
; CHECK-LABEL: scmp_32_32:
9485
; CHECK: @ %bb.0:
95-
; CHECK-NEXT: cmp r0, r1
96-
; CHECK-NEXT: mov r0, #0
97-
; CHECK-NEXT: mov r2, #0
98-
; CHECK-NEXT: movwlt r0, #1
99-
; CHECK-NEXT: movwgt r2, #1
100-
; CHECK-NEXT: sub r0, r2, r0
86+
; CHECK-NEXT: subs r0, r0, r1
87+
; CHECK-NEXT: movwgt r0, #1
88+
; CHECK-NEXT: mvnlt r0, #0
10189
; CHECK-NEXT: bx lr
10290
%1 = call i32 @llvm.scmp(i32 %x, i32 %y)
10391
ret i32 %1
10492
}
10593

94+
define i32 @scmp_neg(i32 %x, i32 %y) nounwind {
95+
; CHECK-LABEL: scmp_neg:
96+
; CHECK: @ %bb.0:
97+
; CHECK-NEXT: adds r0, r0, r1
98+
; CHECK-NEXT: movwgt r0, #1
99+
; CHECK-NEXT: mvnlt r0, #0
100+
; CHECK-NEXT: bx lr
101+
%yy = sub nsw i32 0, %y
102+
%1 = call i32 @llvm.scmp(i32 %x, i32 %yy)
103+
ret i32 %1
104+
}
105+
106106
define i32 @scmp_32_64(i64 %x, i64 %y) nounwind {
107107
; CHECK-LABEL: scmp_32_64:
108108
; CHECK: @ %bb.0:

llvm/test/CodeGen/ARM/ucmp.ll

Lines changed: 12 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -4,12 +4,9 @@
44
define i8 @ucmp_8_8(i8 zeroext %x, i8 zeroext %y) nounwind {
55
; CHECK-LABEL: ucmp_8_8:
66
; CHECK: @ %bb.0:
7-
; CHECK-NEXT: cmp r0, r1
8-
; CHECK-NEXT: mov r0, #0
9-
; CHECK-NEXT: mov r2, #0
10-
; CHECK-NEXT: movwlo r0, #1
11-
; CHECK-NEXT: movwhi r2, #1
12-
; CHECK-NEXT: sub r0, r2, r0
7+
; CHECK-NEXT: subs r0, r0, r1
8+
; CHECK-NEXT: movwhi r0, #1
9+
; CHECK-NEXT: mvnlo r0, #0
1310
; CHECK-NEXT: bx lr
1411
%1 = call i8 @llvm.ucmp(i8 %x, i8 %y)
1512
ret i8 %1
@@ -18,12 +15,9 @@ define i8 @ucmp_8_8(i8 zeroext %x, i8 zeroext %y) nounwind {
1815
define i8 @ucmp_8_16(i16 zeroext %x, i16 zeroext %y) nounwind {
1916
; CHECK-LABEL: ucmp_8_16:
2017
; CHECK: @ %bb.0:
21-
; CHECK-NEXT: cmp r0, r1
22-
; CHECK-NEXT: mov r0, #0
23-
; CHECK-NEXT: mov r2, #0
24-
; CHECK-NEXT: movwlo r0, #1
25-
; CHECK-NEXT: movwhi r2, #1
26-
; CHECK-NEXT: sub r0, r2, r0
18+
; CHECK-NEXT: subs r0, r0, r1
19+
; CHECK-NEXT: movwhi r0, #1
20+
; CHECK-NEXT: mvnlo r0, #0
2721
; CHECK-NEXT: bx lr
2822
%1 = call i8 @llvm.ucmp(i16 %x, i16 %y)
2923
ret i8 %1
@@ -32,12 +26,9 @@ define i8 @ucmp_8_16(i16 zeroext %x, i16 zeroext %y) nounwind {
3226
define i8 @ucmp_8_32(i32 %x, i32 %y) nounwind {
3327
; CHECK-LABEL: ucmp_8_32:
3428
; CHECK: @ %bb.0:
35-
; CHECK-NEXT: cmp r0, r1
36-
; CHECK-NEXT: mov r0, #0
37-
; CHECK-NEXT: mov r2, #0
38-
; CHECK-NEXT: movwlo r0, #1
39-
; CHECK-NEXT: movwhi r2, #1
40-
; CHECK-NEXT: sub r0, r2, r0
29+
; CHECK-NEXT: subs r0, r0, r1
30+
; CHECK-NEXT: movwhi r0, #1
31+
; CHECK-NEXT: mvnlo r0, #0
4132
; CHECK-NEXT: bx lr
4233
%1 = call i8 @llvm.ucmp(i32 %x, i32 %y)
4334
ret i8 %1
@@ -92,12 +83,9 @@ define i8 @ucmp_8_128(i128 %x, i128 %y) nounwind {
9283
define i32 @ucmp_32_32(i32 %x, i32 %y) nounwind {
9384
; CHECK-LABEL: ucmp_32_32:
9485
; CHECK: @ %bb.0:
95-
; CHECK-NEXT: cmp r0, r1
96-
; CHECK-NEXT: mov r0, #0
97-
; CHECK-NEXT: mov r2, #0
98-
; CHECK-NEXT: movwlo r0, #1
99-
; CHECK-NEXT: movwhi r2, #1
100-
; CHECK-NEXT: sub r0, r2, r0
86+
; CHECK-NEXT: subs r0, r0, r1
87+
; CHECK-NEXT: movwhi r0, #1
88+
; CHECK-NEXT: mvnlo r0, #0
10189
; CHECK-NEXT: bx lr
10290
%1 = call i32 @llvm.ucmp(i32 %x, i32 %y)
10391
ret i32 %1

0 commit comments

Comments
 (0)