Skip to content
5 changes: 5 additions & 0 deletions llvm/include/llvm/IR/Instructions.h
Original file line number Diff line number Diff line change
Expand Up @@ -3556,6 +3556,11 @@ class SwitchInstProfUpdateWrapper {
/// correspondent branch weight.
LLVM_ABI SwitchInst::CaseIt removeCase(SwitchInst::CaseIt I);

/// Replace the default destination by given case. Delegate the call to
/// the underlying SwitchInst::setDefaultDest and remove correspondent branch
/// weight.
LLVM_ABI void replaceDefaultDest(SwitchInst::CaseIt I);

/// Delegate the call to the underlying SwitchInst::addCase() and set the
/// specified branch weight for the added case.
LLVM_ABI void addCase(ConstantInt *OnVal, BasicBlock *Dest, CaseWeightOpt W);
Expand Down
10 changes: 10 additions & 0 deletions llvm/lib/IR/Instructions.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -4171,6 +4171,16 @@ SwitchInstProfUpdateWrapper::removeCase(SwitchInst::CaseIt I) {
return SI.removeCase(I);
}

void SwitchInstProfUpdateWrapper::replaceDefaultDest(SwitchInst::CaseIt I) {
auto *DestBlock = I->getCaseSuccessor();
if (Weights) {
auto Weight = getSuccessorWeight(I->getCaseIndex() + 1);
(*Weights)[0] = Weight.value();
}

SI.setDefaultDest(DestBlock);
}

void SwitchInstProfUpdateWrapper::addCase(
ConstantInt *OnVal, BasicBlock *Dest,
SwitchInstProfUpdateWrapper::CaseWeightOpt W) {
Expand Down
79 changes: 79 additions & 0 deletions llvm/lib/Transforms/Utils/SimplifyCFG.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -7540,6 +7540,82 @@ static bool reduceSwitchRange(SwitchInst *SI, IRBuilder<> &Builder,
return true;
}

/// Tries to transform the switch when the condition is umin with a constant.
/// In that case, the default branch can be replaced by the constant's branch.
/// This method also removes dead cases when the simplification cannot replace
/// the default branch.
///
/// For example:
/// switch(umin(a, 3)) {
/// case 0:
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What about [4,1,2,3] or [1,2,3]?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Dead edges should be removed. Otherwise it will cause miscompilation: https://alive2.llvm.org/ce/z/Faeck4

Copy link
Member

@dianqk dianqk Oct 19, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sorry, I mean we can remove dead edges: https://alive2.llvm.org/ce/z/hC3Dbm.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think this needs to be explicitly checked in the transform, we shouldn't rely on eliminateDeadSwitchCases having removed such cases (looking at the implementation, it uses known bits rather than ranges, so I think it may not eliminate all dead cases if the umin is not at a power of two boundary).

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks :)
I have added a new commit which deletes cases where the value is higher than the constant

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Missing test cases? I think you need add both tests of [1,2,3] that has holes and [4,1,2,3]?

/// case 1:
/// case 2:
/// case 3:
/// case 4:
/// // ...
/// default:
/// unreachable
/// }
///
/// Transforms into:
///
/// switch(a) {
/// case 0:
/// case 1:
/// case 2:
/// default:
/// // This is case 3
/// }
static bool simplifySwitchWhenUMin(SwitchInst *SI, DomTreeUpdater *DTU) {
Value *A;
ConstantInt *Constant;

if (!match(SI->getCondition(), m_UMin(m_Value(A), m_ConstantInt(Constant))))
return false;

SmallVector<DominatorTree::UpdateType> Updates;
SwitchInstProfUpdateWrapper SIW(*SI);
BasicBlock *BB = SIW->getParent();

// Dead cases are removed even when the simplification fails.
// A case is dead when its value is higher than the Constant.
SmallVector<ConstantInt *, 4> DeadCases;
for (auto Case : SI->cases())
if (Case.getCaseValue()->getValue().ugt(Constant->getValue()))
DeadCases.push_back(Case.getCaseValue());

for (ConstantInt *DeadCaseVal : DeadCases) {
SwitchInst::CaseIt DeadCase = SIW->findCaseValue(DeadCaseVal);
BasicBlock *DeadCaseBB = DeadCase->getCaseSuccessor();
DeadCaseBB->removePredecessor(SIW->getParent());
SIW.removeCase(DeadCase);
Updates.push_back({DominatorTree::Delete, BB, DeadCaseBB});
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
Updates.push_back({DominatorTree::Delete, BB, DeadCaseBB});
Updates.push_back({DominatorTree::Delete, BB, DeadCaseBB});
Changed = true;

}

auto Case = SI->findCaseValue(Constant);
// If the case value is not found, `findCaseValue` returns the default case.
// In this scenario, since there is no explicit `case 3:`, the simplification
// fails. The simplification also fails when the switch’s default destination
// is reachable.
if (!SI->defaultDestUnreachable() || Case == SI->case_default()) {
if (DTU)
DTU->applyUpdates(Updates);
return !Updates.empty();
}

BasicBlock *Unreachable = SI->getDefaultDest();
SIW.replaceDefaultDest(Case);
SIW.removeCase(Case);
SIW->setCondition(A);

Updates.push_back({DominatorTree::Delete, BB, Unreachable});

if (DTU)
DTU->applyUpdates(Updates);

return true;
}

/// Tries to transform switch of powers of two to reduce switch range.
/// For example, switch like:
/// switch (C) { case 1: case 2: case 64: case 128: }
Expand Down Expand Up @@ -7966,6 +8042,9 @@ bool SimplifyCFGOpt::simplifySwitch(SwitchInst *SI, IRBuilder<> &Builder) {
if (simplifyDuplicateSwitchArms(SI, DTU))
return requestResimplify();

if (simplifySwitchWhenUMin(SI, DTU))
return requestResimplify();

return false;
}

Expand Down
246 changes: 246 additions & 0 deletions llvm/test/Transforms/SimplifyCFG/switch-umin.ll
Original file line number Diff line number Diff line change
@@ -0,0 +1,246 @@
; NOTE: Assertions have been autogenerated by utils/update_test_checks.py UTC_ARGS: --version 6
; RUN: opt -S -passes=simplifycfg < %s | FileCheck %s

declare void @a()
declare void @b()
declare void @c()
declare void @d()

define void @switch_replace_default(i32 %x) {
; CHECK-LABEL: define void @switch_replace_default(
; CHECK-SAME: i32 [[X:%.*]]) {
; CHECK-NEXT: [[MIN:%.*]] = call i32 @llvm.umin.i32(i32 [[X]], i32 3)
; CHECK-NEXT: switch i32 [[X]], label %[[COMMON_RET:.*]] [
; CHECK-NEXT: i32 0, label %[[CASE0:.*]]
; CHECK-NEXT: i32 1, label %[[CASE1:.*]]
; CHECK-NEXT: i32 2, label %[[CASE2:.*]]
; CHECK-NEXT: ], !prof [[PROF0:![0-9]+]]
; CHECK: [[COMMON_RET]]:
; CHECK-NEXT: ret void
; CHECK: [[CASE0]]:
; CHECK-NEXT: call void @a()
; CHECK-NEXT: br label %[[COMMON_RET]]
; CHECK: [[CASE1]]:
; CHECK-NEXT: call void @b()
; CHECK-NEXT: br label %[[COMMON_RET]]
; CHECK: [[CASE2]]:
; CHECK-NEXT: call void @c()
; CHECK-NEXT: br label %[[COMMON_RET]]
;
%min = call i32 @llvm.umin.i32(i32 %x, i32 3)
switch i32 %min, label %unreachable [
i32 0, label %case0
i32 1, label %case1
i32 2, label %case2
i32 3, label %case3
], !prof !0

case0:
call void @a()
ret void

case1:
call void @b()
ret void

case2:
call void @c()
ret void

case3:
ret void

unreachable:
unreachable
}

define void @switch_replace_default_and_remove_dead_cases(i32 %x) {
; CHECK-LABEL: define void @switch_replace_default_and_remove_dead_cases(
; CHECK-SAME: i32 [[X:%.*]]) {
; CHECK-NEXT: [[MIN:%.*]] = call i32 @llvm.umin.i32(i32 [[X]], i32 3)
; CHECK-NEXT: switch i32 [[X]], label %[[COMMON_RET:.*]] [
; CHECK-NEXT: i32 2, label %[[CASE2:.*]]
; CHECK-NEXT: i32 1, label %[[CASE1:.*]]
; CHECK-NEXT: ]
; CHECK: [[COMMON_RET]]:
; CHECK-NEXT: ret void
; CHECK: [[CASE1]]:
; CHECK-NEXT: call void @b()
; CHECK-NEXT: br label %[[COMMON_RET]]
; CHECK: [[CASE2]]:
; CHECK-NEXT: call void @c()
; CHECK-NEXT: br label %[[COMMON_RET]]
;
%min = call i32 @llvm.umin.i32(i32 %x, i32 3)
switch i32 %min, label %unreachable [
i32 4, label %case4
i32 1, label %case1
i32 2, label %case2
i32 3, label %case3
]

case4:
call void @a()
ret void

case1:
call void @b()
ret void

case2:
call void @c()
ret void

case3:
ret void

unreachable:
unreachable
}

define void @switch_replace_default_when_holes(i32 %x) {
; CHECK-LABEL: define void @switch_replace_default_when_holes(
; CHECK-SAME: i32 [[X:%.*]]) {
; CHECK-NEXT: [[MIN:%.*]] = call i32 @llvm.umin.i32(i32 [[X]], i32 3)
; CHECK-NEXT: switch i32 [[X]], label %[[COMMON_RET:.*]] [
; CHECK-NEXT: i32 1, label %[[CASE1:.*]]
; CHECK-NEXT: i32 2, label %[[CASE2:.*]]
; CHECK-NEXT: ]
; CHECK: [[COMMON_RET]]:
; CHECK-NEXT: ret void
; CHECK: [[CASE1]]:
; CHECK-NEXT: call void @b()
; CHECK-NEXT: br label %[[COMMON_RET]]
; CHECK: [[CASE2]]:
; CHECK-NEXT: call void @c()
; CHECK-NEXT: br label %[[COMMON_RET]]
;
%min = call i32 @llvm.umin.i32(i32 %x, i32 3)
switch i32 %min, label %unreachable [
i32 1, label %case1
i32 2, label %case2
i32 3, label %case3
]

case1:
call void @b()
ret void

case2:
call void @c()
ret void

case3:
ret void

unreachable:
unreachable
}

define void @do_not_switch_replace_default(i32 %x, i32 %y) {
; CHECK-LABEL: define void @do_not_switch_replace_default(
; CHECK-SAME: i32 [[X:%.*]], i32 [[Y:%.*]]) {
; CHECK-NEXT: [[MIN:%.*]] = call i32 @llvm.umin.i32(i32 [[X]], i32 [[Y]])
; CHECK-NEXT: switch i32 [[MIN]], label %[[UNREACHABLE:.*]] [
; CHECK-NEXT: i32 0, label %[[CASE0:.*]]
; CHECK-NEXT: i32 1, label %[[CASE1:.*]]
; CHECK-NEXT: i32 2, label %[[CASE2:.*]]
; CHECK-NEXT: i32 3, label %[[COMMON_RET:.*]]
; CHECK-NEXT: ]
; CHECK: [[COMMON_RET]]:
; CHECK-NEXT: ret void
; CHECK: [[CASE0]]:
; CHECK-NEXT: call void @a()
; CHECK-NEXT: br label %[[COMMON_RET]]
; CHECK: [[CASE1]]:
; CHECK-NEXT: call void @b()
; CHECK-NEXT: br label %[[COMMON_RET]]
; CHECK: [[CASE2]]:
; CHECK-NEXT: call void @c()
; CHECK-NEXT: br label %[[COMMON_RET]]
; CHECK: [[UNREACHABLE]]:
; CHECK-NEXT: unreachable
;
%min = call i32 @llvm.umin.i32(i32 %x, i32 %y)
switch i32 %min, label %unreachable [
i32 0, label %case0
i32 1, label %case1
i32 2, label %case2
i32 3, label %case3
]

case0:
call void @a()
ret void

case1:
call void @b()
ret void

case2:
call void @c()
ret void

case3:
ret void

unreachable:
unreachable
}

define void @do_not_replace_switch_default_but_remove_dead_cases(i32 %x) {
; CHECK-LABEL: define void @do_not_replace_switch_default_but_remove_dead_cases(
; CHECK-SAME: i32 [[X:%.*]]) {
; CHECK-NEXT: [[MIN:%.*]] = call i32 @llvm.umin.i32(i32 [[X]], i32 3)
; CHECK-NEXT: switch i32 [[MIN]], label %[[CASE0:.*]] [
; CHECK-NEXT: i32 3, label %[[COMMON_RET:.*]]
; CHECK-NEXT: i32 1, label %[[CASE1:.*]]
; CHECK-NEXT: i32 2, label %[[CASE2:.*]]
; CHECK-NEXT: ]
; CHECK: [[COMMON_RET]]:
; CHECK-NEXT: ret void
; CHECK: [[CASE0]]:
; CHECK-NEXT: call void @a()
; CHECK-NEXT: br label %[[COMMON_RET]]
; CHECK: [[CASE1]]:
; CHECK-NEXT: call void @b()
; CHECK-NEXT: br label %[[COMMON_RET]]
; CHECK: [[CASE2]]:
; CHECK-NEXT: call void @c()
; CHECK-NEXT: br label %[[COMMON_RET]]
;
%min = call i32 @llvm.umin.i32(i32 %x, i32 3)
switch i32 %min, label %case0 [ ; default is reachable, therefore simplification not triggered
i32 0, label %case0
i32 1, label %case1
i32 2, label %case2
i32 3, label %case3
i32 4, label %case4
]

case0:
call void @a()
ret void

case1:
call void @b()
ret void

case2:
call void @c()
ret void

case3:
ret void

case4:
call void @d()
ret void

}


!0 = !{!"branch_weights", i32 1, i32 2, i32 3, i32 99, i32 5}
;.
; CHECK: [[PROF0]] = !{!"branch_weights", i32 5, i32 2, i32 3, i32 99}
;.