Skip to content

[AMDGPU][FixIrreducible][UnifyLoopExits] Support callbr with inline-asm #149308

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 3 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion llvm/include/llvm/Support/GenericLoopInfoImpl.h
Original file line number Diff line number Diff line change
Expand Up @@ -355,7 +355,7 @@ void LoopBase<BlockT, LoopT>::verifyLoop() const {
if (BB == getHeader()) {
assert(!OutsideLoopPreds.empty() && "Loop is unreachable!");
} else if (!OutsideLoopPreds.empty()) {
// A non-header loop shouldn't be reachable from outside the loop,
// A non-header loop block shouldn't be reachable from outside the loop,
// though it is permitted if the predecessor is not itself actually
// reachable.
BlockT *EntryBB = &BB->getParent()->front();
Expand Down
12 changes: 11 additions & 1 deletion llvm/include/llvm/Transforms/Utils/BasicBlockUtils.h
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
#include "llvm/IR/BasicBlock.h"
#include "llvm/IR/Dominators.h"
#include "llvm/Support/Compiler.h"
#include "llvm/Support/Printable.h"
#include <cassert>

namespace llvm {
Expand Down Expand Up @@ -607,10 +608,19 @@ LLVM_ABI bool SplitIndirectBrCriticalEdges(Function &F,
// successors
LLVM_ABI void InvertBranch(BranchInst *PBI, IRBuilderBase &Builder);

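// Check whether every block in the function ends in a terminator whose
// instruction type is one of the given TermInst types.
// NOTE(review): the definition lives in BasicBlockUtils.cpp, so other
// translation units can presumably only rely on instantiations made there;
// confirm before using this template elsewhere.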
template <typename... TermInst>
LLVM_ABI bool hasOnlyGivenTerminators(const Function &F);

// Check whether the function only has blocks with simple terminators:
// br/brcond/unreachable/ret
LLVM_ABI bool hasOnlySimpleTerminator(const Function &F);

// Check whether the function only has blocks with simple terminators
// (br/brcond/unreachable/ret) or callbr.
LLVM_ABI bool hasOnlySimpleTerminatorOrCallBr(const Function &F);

LLVM_ABI Printable printBBPtr(const BasicBlock *BB);

} // end namespace llvm

#endif // LLVM_TRANSFORMS_UTILS_BASICBLOCKUTILS_H
36 changes: 35 additions & 1 deletion llvm/include/llvm/Transforms/Utils/ControlFlowUtils.h
Original file line number Diff line number Diff line change
Expand Up @@ -15,10 +15,13 @@

#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/StringRef.h"
#include "llvm/IR/CycleInfo.h"

namespace llvm {

class BasicBlock;
class CallBrInst;
class LoopInfo;
class DomTreeUpdater;

/// Given a set of branch descriptors [BB, Succ0, Succ1], create a "hub" such
Expand Down Expand Up @@ -104,7 +107,8 @@ struct ControlFlowHub {
: BB(BB), Succ0(Succ0), Succ1(Succ1) {}
};

void addBranch(BasicBlock *BB, BasicBlock *Succ0, BasicBlock *Succ1) {
void addBranch(BasicBlock *BB, BasicBlock *Succ0,
BasicBlock *Succ1 = nullptr) {
assert(BB);
assert(Succ0 || Succ1);
Branches.emplace_back(BB, Succ0, Succ1);
Expand All @@ -118,6 +122,36 @@ struct ControlFlowHub {
std::optional<unsigned> MaxControlFlowBooleans = std::nullopt);

SmallVector<BranchDescriptor> Branches;

/// \brief Create a new intermediate target block for a callbr edge.
///
/// This function creates a new basic block (the "target block") that sits
/// between a callbr instruction and one of its successors. The callbr's
/// successor is rewired to this new block, and the new block unconditionally
/// branches to the original successor. This is useful for normalizing control
/// flow, e.g., when transforming irreducible loops.
///
/// \param CallBr The callbr instruction whose edge is to be split.
/// \param Succ The original successor basic block to be reached.
/// \param SuccIdx The index of the successor in the callbr
/// instruction.
/// \param AttachToCallBr If true, the new block is added to the loop/cycle
/// of the callbr's parent, unless the successor is in
/// a loop/cycle disjoint from it, in which case the
/// new block is added to no loop/cycle.
/// If false, the new block is added to the loop/cycle
/// of the callbr's successor.
/// \param CI Optional CycleInfo for updating cycle membership.
/// \param DTU Optional DomTreeUpdater for updating the dominator
/// tree.
/// \param LI Optional LoopInfo for updating loop membership.
///
/// \returns The newly created intermediate target block.
///
/// \note This function updates PHI nodes, dominator tree, loop info, and
/// cycle info as needed.
static BasicBlock *
createCallBrTarget(CallBrInst *CallBr, BasicBlock *Succ, unsigned SuccIdx,
bool AttachToCallBr = true, CycleInfo *CI = nullptr,
DomTreeUpdater *DTU = nullptr, LoopInfo *LI = nullptr);
};

} // end namespace llvm
Expand Down
22 changes: 19 additions & 3 deletions llvm/lib/Transforms/Utils/BasicBlockUtils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1766,12 +1766,28 @@ void llvm::InvertBranch(BranchInst *PBI, IRBuilderBase &Builder) {
PBI->swapSuccessors();
}

bool llvm::hasOnlySimpleTerminator(const Function &F) {
template <typename... TermInst>
bool llvm::hasOnlyGivenTerminators(const Function &F) {
for (auto &BB : F) {
auto *Term = BB.getTerminator();
if (!(isa<ReturnInst>(Term) || isa<UnreachableInst>(Term) ||
isa<BranchInst>(Term)))
if (!(isa<TermInst>(Term) || ...))
return false;
}
return true;
}

// Returns true when every block of F ends in a return, unreachable, or branch
// instruction; forwards to hasOnlyGivenTerminators with exactly those types.
bool llvm::hasOnlySimpleTerminator(const Function &F) {
return hasOnlyGivenTerminators<ReturnInst, UnreachableInst, BranchInst>(F);
}

// Returns true when every block of F ends in a return, unreachable, branch, or
// callbr instruction; same check as hasOnlySimpleTerminator with CallBrInst
// additionally accepted.
bool llvm::hasOnlySimpleTerminatorOrCallBr(const Function &F) {
return hasOnlyGivenTerminators<ReturnInst, UnreachableInst, BranchInst,
CallBrInst>(F);
}

// Wraps BB in a Printable that writes the block as an operand when streamed.
// A null BB produces no output at all.
Printable llvm::printBBPtr(const BasicBlock *BB) {
  return Printable([BB](raw_ostream &OS) {
    // Nothing to print for a missing block.
    if (!BB)
      return;
    BB->printAsOperand(OS);
  });
}
57 changes: 56 additions & 1 deletion llvm/lib/Transforms/Utils/ControlFlowUtils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
#include "llvm/ADT/SetVector.h"
#include "llvm/ADT/SmallSet.h"
#include "llvm/Analysis/DomTreeUpdater.h"
#include "llvm/Analysis/LoopInfo.h"
#include "llvm/IR/Constants.h"
#include "llvm/IR/Instructions.h"
#include "llvm/IR/ValueHandle.h"
Expand Down Expand Up @@ -282,7 +283,9 @@ std::pair<BasicBlock *, bool> ControlFlowHub::finalize(

for (auto [BB, Succ0, Succ1] : Branches) {
#ifndef NDEBUG
assert(Incoming.insert(BB).second && "Duplicate entry for incoming block.");
assert(
(Incoming.insert(BB).second || isa<CallBrInst>(BB->getTerminator())) &&
"Duplicate entry for incoming block.");
#endif
if (Succ0)
Outgoing.insert(Succ0);
Expand Down Expand Up @@ -342,3 +345,55 @@ std::pair<BasicBlock *, bool> ControlFlowHub::finalize(

return {FirstGuardBlock, true};
}

// Splits the edge from CallBr to its successor number SuccIdx (which must be
// Succ) by inserting a new block that branches unconditionally to Succ, and
// returns that new block.
//
// PHI nodes in Succ that named the callbr's block are rewritten to name the new
// block. LI, CI and DTU are each optional; when non-null they are updated:
//  - With AttachToCallBr set, the new block joins the loop/cycle of the
//    callbr's block, unless Succ is in a loop/cycle disjoint from it, in which
//    case the new block joins none. Otherwise the new block joins the
//    loop/cycle of Succ.
//  - The dominator tree learns about the new CallBrBlock -> target edge, and,
//    when the callbr's block dominates Succ, about the rerouted edge into Succ.
BasicBlock *ControlFlowHub::createCallBrTarget(
    CallBrInst *CallBr, BasicBlock *Succ, unsigned SuccIdx, bool AttachToCallBr,
    CycleInfo *CI, DomTreeUpdater *DTU, LoopInfo *LI) {
  assert(CallBr && Succ && "expected a callbr and a successor block");
  assert(SuccIdx < CallBr->getNumSuccessors() &&
         CallBr->getSuccessor(SuccIdx) == Succ && "invalid successor index");

  BasicBlock *CallBrBlock = CallBr->getParent();
  BasicBlock *CallBrTarget =
      BasicBlock::Create(CallBrBlock->getContext(),
                         CallBrBlock->getName() + ".target." + Succ->getName(),
                         CallBrBlock->getParent());
  // Rewire control flow from callbr to the new target block.
  Succ->replacePhiUsesWith(CallBrBlock, CallBrTarget);
  CallBr->setSuccessor(SuccIdx, CallBrTarget);
  // Jump from the new target block to the original successor.
  BranchInst::Create(Succ, CallBrTarget);

  // The block whose loop/cycle the new target block should inherit.
  BasicBlock *MembershipSource = AttachToCallBr ? CallBrBlock : Succ;

  if (LI) {
    if (Loop *L = LI->getLoopFor(MembershipSource)) {
      bool AddToLoop = true;
      if (AttachToCallBr) {
        // Check if the loops are disjoint. In that case, we do not add the
        // intermediate target to any loop.
        Loop *SuccLoop = LI->getLoopFor(Succ);
        if (SuccLoop && !L->contains(SuccLoop) && !SuccLoop->contains(L))
          AddToLoop = false;
      }
      if (AddToLoop)
        L->addBasicBlockToLoop(CallBrTarget, *LI);
    }
  }

  if (CI) {
    if (auto *C = CI->getCycle(MembershipSource)) {
      bool AddToCycle = true;
      if (AttachToCallBr) {
        // Check if the cycles are disjoint. In that case, we do not add the
        // intermediate target to any cycle.
        if (auto *SuccCycle = CI->getCycle(Succ)) {
          auto *CommonC = CI->getSmallestCommonCycle(C, SuccCycle);
          if (CommonC != C && CommonC != SuccCycle)
            AddToCycle = false;
        }
      }
      if (AddToCycle)
        CI->addBlockToCycle(CallBrTarget, C);
    }
  }

  if (DTU) {
    DTU->applyUpdates({{DominatorTree::Insert, CallBrBlock, CallBrTarget}});
    if (DTU->getDomTree().dominates(CallBrBlock, Succ)) {
      // The callbr may name Succ in more than one successor slot. Only the
      // slot SuccIdx was redirected above, so the CallBrBlock -> Succ edge is
      // gone from the CFG only if no remaining slot still refers to Succ.
      bool StillBranchesToSucc = false;
      for (unsigned I = 0, E = CallBr->getNumSuccessors(); I != E; ++I) {
        if (CallBr->getSuccessor(I) == Succ) {
          StillBranchesToSucc = true;
          break;
        }
      }
      SmallVector<DominatorTree::UpdateType, 2> Updates;
      if (!StillBranchesToSucc)
        Updates.emplace_back(DominatorTree::Delete, CallBrBlock, Succ);
      Updates.emplace_back(DominatorTree::Insert, CallBrTarget, Succ);
      DTU->applyUpdates(Updates);
    }
  }
  return CallBrTarget;
}
125 changes: 100 additions & 25 deletions llvm/lib/Transforms/Utils/FixIrreducible.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,53 @@
// Limitation: The pass cannot handle switch statements and indirect
// branches. Both must be lowered to plain branches first.
//
// CallBr support: CallBr is handled as a more general branch instruction which
// can have multiple successors. The pass redirects the affected callbr edges to
// intermediate target blocks that unconditionally branch to the original callbr
// successors. This lets the control flow hub know which of the original
// successors it has to jump to.
// Example input CFG:
// Entry (callbr)
// / \
// v v
// H ----> B
// ^ /|
// `----' |
// v
// Exit
//
// becomes:
// Entry (callbr)
// / \
// v v
// target.H target.B
// | |
// v v
// H ----> B
// ^ /|
// `----' |
// v
// Exit
//
// Output CFG: converted to a natural loop with a new header N.
//
// Entry (callbr)
// / \
// v v
// target.H target.B
// \ /
// \ /
// v v
// N <---.
// / \ \
// / \ |
// v v /
// H --> B --'
// |
// v
// Exit
//
//===----------------------------------------------------------------------===//

#include "llvm/Transforms/Utils/FixIrreducible.h"
Expand Down Expand Up @@ -231,6 +278,7 @@ static bool fixIrreducible(Cycle &C, CycleInfo &CI, DominatorTree &DT,
return false;
LLVM_DEBUG(dbgs() << "Processing cycle:\n" << CI.print(&C) << "\n";);

DomTreeUpdater DTU(DT, DomTreeUpdater::UpdateStrategy::Eager);
ControlFlowHub CHub;
SetVector<BasicBlock *> Predecessors;

Expand All @@ -242,18 +290,32 @@ static bool fixIrreducible(Cycle &C, CycleInfo &CI, DominatorTree &DT,
}

for (BasicBlock *P : Predecessors) {
auto *Branch = cast<BranchInst>(P->getTerminator());
// Exactly one of the two successors is the header.
BasicBlock *Succ0 = Branch->getSuccessor(0) == Header ? Header : nullptr;
BasicBlock *Succ1 = Succ0 ? nullptr : Header;
if (!Succ0)
assert(Branch->getSuccessor(1) == Header);
assert(Succ0 || Succ1);
CHub.addBranch(P, Succ0, Succ1);

LLVM_DEBUG(dbgs() << "Added internal branch: " << P->getName() << " -> "
<< (Succ0 ? Succ0->getName() : "") << " "
<< (Succ1 ? Succ1->getName() : "") << "\n");
if (BranchInst *Branch = dyn_cast<BranchInst>(P->getTerminator())) {
// Exactly one of the two successors is the header.
BasicBlock *Succ0 = Branch->getSuccessor(0) == Header ? Header : nullptr;
BasicBlock *Succ1 = Succ0 ? nullptr : Header;
if (!Succ0)
assert(Branch->getSuccessor(1) == Header);
assert(Succ0 || Succ1);
CHub.addBranch(P, Succ0, Succ1);

LLVM_DEBUG(dbgs() << "Added internal branch: " << printBBPtr(P) << " -> "
<< printBBPtr(Succ0) << (Succ0 && Succ1 ? " " : "")
<< printBBPtr(Succ1) << "\n");
} else if (CallBrInst *CallBr = dyn_cast<CallBrInst>(P->getTerminator())) {
for (unsigned I = 0; I < CallBr->getNumSuccessors(); ++I) {
BasicBlock *Succ = CallBr->getSuccessor(I);
if (Succ != Header)
continue;
BasicBlock *NewSucc = ControlFlowHub::createCallBrTarget(
CallBr, Succ, I, false, &CI, &DTU, LI);
CHub.addBranch(NewSucc, Succ);
LLVM_DEBUG(dbgs() << "Added internal branch: " << printBBPtr(NewSucc)
<< " -> " << printBBPtr(Succ) << "\n");
}
} else {
llvm_unreachable("unsupported block terminator");
}
}

// Redirect external incoming edges. This includes the edges on the header.
Expand All @@ -266,17 +328,31 @@ static bool fixIrreducible(Cycle &C, CycleInfo &CI, DominatorTree &DT,
}

for (BasicBlock *P : Predecessors) {
auto *Branch = cast<BranchInst>(P->getTerminator());
BasicBlock *Succ0 = Branch->getSuccessor(0);
Succ0 = C.contains(Succ0) ? Succ0 : nullptr;
BasicBlock *Succ1 =
Branch->isUnconditional() ? nullptr : Branch->getSuccessor(1);
Succ1 = Succ1 && C.contains(Succ1) ? Succ1 : nullptr;
CHub.addBranch(P, Succ0, Succ1);

LLVM_DEBUG(dbgs() << "Added external branch: " << P->getName() << " -> "
<< (Succ0 ? Succ0->getName() : "") << " "
<< (Succ1 ? Succ1->getName() : "") << "\n");
if (BranchInst *Branch = dyn_cast<BranchInst>(P->getTerminator()); Branch) {
BasicBlock *Succ0 = Branch->getSuccessor(0);
Succ0 = C.contains(Succ0) ? Succ0 : nullptr;
BasicBlock *Succ1 =
Branch->isUnconditional() ? nullptr : Branch->getSuccessor(1);
Succ1 = Succ1 && C.contains(Succ1) ? Succ1 : nullptr;
CHub.addBranch(P, Succ0, Succ1);

LLVM_DEBUG(dbgs() << "Added external branch: " << printBBPtr(P) << " -> "
<< printBBPtr(Succ0) << (Succ0 && Succ1 ? " " : "")
<< printBBPtr(Succ1) << "\n");
} else if (CallBrInst *CallBr = dyn_cast<CallBrInst>(P->getTerminator())) {
for (unsigned I = 0; I < CallBr->getNumSuccessors(); ++I) {
BasicBlock *Succ = CallBr->getSuccessor(I);
if (!C.contains(Succ))
continue;
BasicBlock *NewSucc = ControlFlowHub::createCallBrTarget(
CallBr, Succ, I, true, &CI, &DTU, LI);
CHub.addBranch(NewSucc, Succ);
LLVM_DEBUG(dbgs() << "Added external branch: " << printBBPtr(NewSucc)
<< " -> " << printBBPtr(Succ) << "\n");
}
} else {
llvm_unreachable("unsupported block terminator");
}
}

// Redirect all the backedges through a "hub" consisting of a series
Expand All @@ -292,7 +368,6 @@ static bool fixIrreducible(Cycle &C, CycleInfo &CI, DominatorTree &DT,
SetVector<BasicBlock *> Entries;
Entries.insert(C.entry_rbegin(), C.entry_rend());

DomTreeUpdater DTU(DT, DomTreeUpdater::UpdateStrategy::Eager);
CHub.finalize(&DTU, GuardBlocks, "irr");
#if defined(EXPENSIVE_CHECKS)
assert(DT.verify(DominatorTree::VerificationLevel::Full));
Expand Down Expand Up @@ -325,7 +400,7 @@ static bool FixIrreducibleImpl(Function &F, CycleInfo &CI, DominatorTree &DT,
LLVM_DEBUG(dbgs() << "===== Fix irreducible control-flow in function: "
<< F.getName() << "\n");

assert(hasOnlySimpleTerminator(F) && "Unsupported block terminator.");
assert(hasOnlySimpleTerminatorOrCallBr(F) && "Unsupported block terminator.");
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is this assert really necessary, or will it hit one of the later asserts when a particular block is processed?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It would only hit a later assert if those blocks are processed while dealing with a loop/cycle.
However, they wouldn't actually be a problem if they are completely unrelated to any loop/cycle.
I guess the assert is mostly there to document the restriction of the pass inside the code as well.


bool Changed = false;
for (Cycle *TopCycle : CI.toplevel_cycles()) {
Expand Down
Loading