-
Notifications
You must be signed in to change notification settings - Fork 14.9k
[SimplifyCFG]: Switch on umin replaces default #164097
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Conversation
@llvm/pr-subscribers-llvm-ir @llvm/pr-subscribers-llvm-transforms Author: None (kper) ChangesA switch on Proof: https://alive2.llvm.org/ce/z/_N6nfs Full diff: https://github.com/llvm/llvm-project/pull/164097.diff 2 Files Affected:
diff --git a/llvm/lib/Transforms/Utils/SimplifyCFG.cpp b/llvm/lib/Transforms/Utils/SimplifyCFG.cpp
index d831c2737e5f8..8d3c91e69ad48 100644
--- a/llvm/lib/Transforms/Utils/SimplifyCFG.cpp
+++ b/llvm/lib/Transforms/Utils/SimplifyCFG.cpp
@@ -7540,6 +7540,62 @@ static bool reduceSwitchRange(SwitchInst *SI, IRBuilder<> &Builder,
return true;
}
+/// Tries to transform the switch when the condition is umin and a constant.
+/// In that case, the default branch can be replaced by the constant's branch.
+/// For example:
+/// switch(umin(a, 3)) {
+/// case 0:
+/// case 1:
+/// case 2:
+/// case 3:
+/// // ...
+/// default:
+/// unreachable
+/// }
+///
+/// Transforms into:
+///
+/// switch(umin(a, 3)) {
+/// case 0:
+/// case 1:
+/// case 2:
+/// default:
+/// // This is case 3
+/// }
+static bool simplifySwitchWhenUMin(SwitchInst *SI, IRBuilder<> &Builder) {
+ auto *Call = dyn_cast<IntrinsicInst>(SI->getCondition());
+
+ if (!Call)
+ return false;
+
+ if (Call->getIntrinsicID() != Intrinsic::umin)
+ return false;
+
+ if (!SI->defaultDestUnreachable())
+ return false;
+
+ // Extract the constant operand from the intrinsic.
+ auto *Constant = dyn_cast<ConstantInt>(Call->getArgOperand(1));
+
+ if (!Constant) {
+ return false;
+ }
+
+ for (auto Case = SI->case_begin(), e = SI->case_end(); Case != e; Case++) {
+ uint64_t CaseValue = Case->getCaseValue()->getValue().getZExtValue();
+
+ // We found the case which is equal to the case's umin argument.
+ // We can make the case the default case.
+ if (Constant->equalsInt(CaseValue)) {
+ SI->setDefaultDest(Case->getCaseSuccessor());
+ SI->removeCase(Case);
+ return true;
+ }
+ }
+
+ return false;
+}
+
/// Tries to transform switch of powers of two to reduce switch range.
/// For example, switch like:
/// switch (C) { case 1: case 2: case 64: case 128: }
@@ -7966,6 +8022,9 @@ bool SimplifyCFGOpt::simplifySwitch(SwitchInst *SI, IRBuilder<> &Builder) {
if (simplifyDuplicateSwitchArms(SI, DTU))
return requestResimplify();
+ if (simplifySwitchWhenUMin(SI, Builder))
+ return requestResimplify();
+
return false;
}
diff --git a/llvm/test/Transforms/SimplifyCFG/switch-umin.ll b/llvm/test/Transforms/SimplifyCFG/switch-umin.ll
new file mode 100644
index 0000000000000..69f78bf377dae
--- /dev/null
+++ b/llvm/test/Transforms/SimplifyCFG/switch-umin.ll
@@ -0,0 +1,105 @@
+; NOTE: Assertions have been autogenerated by utils/update_test_checks.py UTC_ARGS: --version 6
+; RUN: opt -S -passes=simplifycfg < %s | FileCheck %s
+
+declare void @a()
+declare void @b()
+declare void @c()
+
+define void @switch_replace_default(i32 %x) {
+; CHECK-LABEL: define void @switch_replace_default(
+; CHECK-SAME: i32 [[X:%.*]]) {
+; CHECK-NEXT: [[MIN:%.*]] = call i32 @llvm.umin.i32(i32 [[X]], i32 3)
+; CHECK-NEXT: switch i32 [[MIN]], label %[[COMMON_RET:.*]] [
+; CHECK-NEXT: i32 0, label %[[CASE0:.*]]
+; CHECK-NEXT: i32 1, label %[[CASE1:.*]]
+; CHECK-NEXT: i32 2, label %[[CASE2:.*]]
+; CHECK-NEXT: ]
+; CHECK: [[COMMON_RET]]:
+; CHECK-NEXT: ret void
+; CHECK: [[CASE0]]:
+; CHECK-NEXT: call void @a()
+; CHECK-NEXT: br label %[[COMMON_RET]]
+; CHECK: [[CASE1]]:
+; CHECK-NEXT: call void @b()
+; CHECK-NEXT: br label %[[COMMON_RET]]
+; CHECK: [[CASE2]]:
+; CHECK-NEXT: call void @c()
+; CHECK-NEXT: br label %[[COMMON_RET]]
+;
+ %min = call i32 @llvm.umin.i32(i32 %x, i32 3)
+ switch i32 %min, label %unreachable [
+ i32 0, label %case0
+ i32 1, label %case1
+ i32 2, label %case2
+ i32 3, label %case3
+ ]
+
+case0:
+ call void @a()
+ ret void
+
+case1:
+ call void @b()
+ ret void
+
+case2:
+ call void @c()
+ ret void
+
+case3:
+ ret void
+
+unreachable:
+ unreachable
+}
+
+define void @do_not_switch_replace_default(i32 %x, i32 %y) {
+; CHECK-LABEL: define void @do_not_switch_replace_default(
+; CHECK-SAME: i32 [[X:%.*]], i32 [[Y:%.*]]) {
+; CHECK-NEXT: [[MIN:%.*]] = call i32 @llvm.umin.i32(i32 [[X]], i32 [[Y]])
+; CHECK-NEXT: switch i32 [[MIN]], label %[[UNREACHABLE:.*]] [
+; CHECK-NEXT: i32 0, label %[[CASE0:.*]]
+; CHECK-NEXT: i32 1, label %[[CASE1:.*]]
+; CHECK-NEXT: i32 2, label %[[CASE2:.*]]
+; CHECK-NEXT: i32 3, label %[[COMMON_RET:.*]]
+; CHECK-NEXT: ]
+; CHECK: [[COMMON_RET]]:
+; CHECK-NEXT: ret void
+; CHECK: [[CASE0]]:
+; CHECK-NEXT: call void @a()
+; CHECK-NEXT: br label %[[COMMON_RET]]
+; CHECK: [[CASE1]]:
+; CHECK-NEXT: call void @b()
+; CHECK-NEXT: br label %[[COMMON_RET]]
+; CHECK: [[CASE2]]:
+; CHECK-NEXT: call void @c()
+; CHECK-NEXT: br label %[[COMMON_RET]]
+; CHECK: [[UNREACHABLE]]:
+; CHECK-NEXT: unreachable
+;
+ %min = call i32 @llvm.umin.i32(i32 %x, i32 %y)
+ switch i32 %min, label %unreachable [
+ i32 0, label %case0
+ i32 1, label %case1
+ i32 2, label %case2
+ i32 3, label %case3
+ ]
+
+case0:
+ call void @a()
+ ret void
+
+case1:
+ call void @b()
+ ret void
+
+case2:
+ call void @c()
+ ret void
+
+case3:
+ ret void
+
+unreachable:
+ unreachable
+}
|
/// | ||
/// Transforms into: | ||
/// | ||
/// switch(umin(a, 3)) { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
/// switch(umin(a, 3)) { | |
/// switch(a) { |
/// In that case, the default branch can be replaced by the constant's branch. | ||
/// For example: | ||
/// switch(umin(a, 3)) { | ||
/// case 0: |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
What about [4,1,2,3]
or [1,2,3]
?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Dead edges should be removed. Otherwise it will cause miscompilation: https://alive2.llvm.org/ce/z/Faeck4
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Sorry, I mean we can remove dead edges: https://alive2.llvm.org/ce/z/hC3Dbm.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think this needs to be explicitly checked in the transform, we shouldn't rely on eliminateDeadSwitchCases having removed such cases (looking at the implementation, it uses known bits rather than ranges, so I think it may not eliminate all dead cases if the umin is not at a power of two boundary).
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks :)
I have added a new commit which deletes cases where the value is higher than the constant
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Missing test cases? I think you need add both tests of [1,2,3] that has holes and [4,1,2,3]?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM. Thanks.
It should be free to handle smax/smin/umax as well, though these patterns are rare.
Output of https://github.com/dtcxzyw/llvm-tools/blob/main/minmax-switch.cpp on llvm-opt-benchmark:
Count: 4
llvm.smax 2
llvm.smin 4
llvm.umax 9
llvm.umin 10395
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Test coverage is needed for the new change.
|
||
// A case is dead when its value is higher than the Constant. | ||
SmallVector<ConstantInt *, 4> DeadCases; | ||
for (auto Case = SI->case_begin(), e = SI->case_end(); Case != e; Case++) { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
for (auto Case = SI->case_begin(), e = SI->case_end(); Case != e; Case++) { | |
for (auto Case : SI->cases()) { |
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I would expect the dead cases to be removed even if no optimization opportunities in the motivating example.
for (auto Case : SI->cases()) { | ||
if (Case.getCaseValue()->getValue().ugt(Constant->getValue())) { | ||
DeadCases.push_back(Case.getCaseValue()); | ||
} | ||
} |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
for (auto Case : SI->cases()) { | |
if (Case.getCaseValue()->getValue().ugt(Constant->getValue())) { | |
DeadCases.push_back(Case.getCaseValue()); | |
} | |
} | |
for (auto Case : SI->cases()) | |
if (Case.getCaseValue()->getValue().ugt(Constant->getValue())) | |
DeadCases.push_back(Case.getCaseValue()); |
llvm/include/llvm/IR/Instructions.h
Outdated
|
||
/// Delegate the call to the underlying SwitchInst::setDefaultCase and | ||
/// remove correspondent branch weight. | ||
LLVM_ABI void setDefaultDest(SwitchInst::CaseIt I); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
replaceDefaultDest?
llvm/lib/IR/Instructions.cpp
Outdated
auto DestBlock = I->getCaseSuccessor(); | ||
if (Weights) { | ||
auto Weight = getSuccessorWeight(I->getCaseIndex() + 1); | ||
if (Weight.has_value()) { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Weight
should be available if Weights
is available. Remove the check or use an assertion here.
I have included the recent suggestions. Also, I refactored the code s.t the dead case elimination happens first. I added a test which does not trigger the simplification because the dest is reachable but the dead cases were eliminated. |
BasicBlock *DeadCaseBB = DeadCase->getCaseSuccessor(); | ||
DeadCaseBB->removePredecessor(SIW->getParent()); | ||
SIW.removeCase(DeadCase); | ||
Updates.push_back({DominatorTree::Delete, BB, DeadCaseBB}); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Updates.push_back({DominatorTree::Delete, BB, DeadCaseBB}); | |
Updates.push_back({DominatorTree::Delete, BB, DeadCaseBB}); | |
Changed = true; |
if (!SI->defaultDestUnreachable() || Case == SI->case_default()) { | ||
if (DTU) | ||
DTU->applyUpdates(Updates); | ||
return false; | ||
} | ||
|
||
BasicBlock *Unreachable = SI->getDefaultDest(); | ||
SIW.replaceDefaultDest(Case); | ||
SIW.removeCase(Case); | ||
SIW->setCondition(A); | ||
|
||
Updates.push_back({DominatorTree::Delete, BB, Unreachable}); | ||
|
||
if (DTU) | ||
DTU->applyUpdates(Updates); | ||
|
||
return true; |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Why does it return false after removing the dead cases?
if (!SI->defaultDestUnreachable() || Case == SI->case_default()) { | |
if (DTU) | |
DTU->applyUpdates(Updates); | |
return false; | |
} | |
BasicBlock *Unreachable = SI->getDefaultDest(); | |
SIW.replaceDefaultDest(Case); | |
SIW.removeCase(Case); | |
SIW->setCondition(A); | |
Updates.push_back({DominatorTree::Delete, BB, Unreachable}); | |
if (DTU) | |
DTU->applyUpdates(Updates); | |
return true; | |
if (SI->defaultDestUnreachable() && Case != SI->case_default()) { | |
BasicBlock *Unreachable = SI->getDefaultDest(); | |
SIW.replaceDefaultDest(Case); | |
SIW.removeCase(Case); | |
SIW->setCondition(A); | |
Updates.push_back({DominatorTree::Delete, BB, Unreachable}); | |
Changed = true; | |
} | |
if (DTU) | |
DTU->applyUpdates(Updates); | |
return Changed; |
A switch on
umin
can eliminate the default case by making theumin
's constant the default case.Proof: https://alive2.llvm.org/ce/z/_N6nfs
Fixes: #162111