Skip to content

Commit 6a97ff2

Browse files
committed
[OpenMP][OMPIRBuilder] Use device shared memory for arg structures
Argument structures are created when sections of the LLVM IR corresponding to an OpenMP construct are outlined into their own function. For this, stack allocations are used. This patch modifies this behavior when compiling for a target device and outlining `parallel`-related IR, so that it uses device shared memory instead of private stack space. This is needed in order for threads to have access to these arguments.
1 parent 0586e88 commit 6a97ff2

File tree

5 files changed

+190
-36
lines changed

5 files changed

+190
-36
lines changed

llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2159,7 +2159,13 @@ class OpenMPIRBuilder {
21592159
/// during finalization.
21602160
struct OutlineInfo {
21612161
using PostOutlineCBTy = std::function<void(Function &)>;
2162+
using CustomArgAllocatorCBTy = std::function<Instruction *(
2163+
BasicBlock *, BasicBlock::iterator, Type *, const Twine &)>;
2164+
using CustomArgDeallocatorCBTy = std::function<Instruction *(
2165+
BasicBlock *, BasicBlock::iterator, Value *, Type *)>;
21622166
PostOutlineCBTy PostOutlineCB;
2167+
CustomArgAllocatorCBTy CustomArgAllocatorCB;
2168+
CustomArgDeallocatorCBTy CustomArgDeallocatorCB;
21632169
BasicBlock *EntryBB, *ExitBB, *OuterAllocaBB;
21642170
SmallVector<Value *, 2> ExcludeArgsFromAggregate;
21652171

llvm/include/llvm/Transforms/Utils/CodeExtractor.h

Lines changed: 34 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -17,14 +17,14 @@
1717
#include "llvm/ADT/ArrayRef.h"
1818
#include "llvm/ADT/DenseMap.h"
1919
#include "llvm/ADT/SetVector.h"
20+
#include "llvm/IR/BasicBlock.h"
2021
#include "llvm/Support/Compiler.h"
2122
#include <limits>
2223

2324
namespace llvm {
2425

2526
template <typename PtrType> class SmallPtrSetImpl;
2627
class AllocaInst;
27-
class BasicBlock;
2828
class BlockFrequency;
2929
class BlockFrequencyInfo;
3030
class BranchProbabilityInfo;
@@ -85,6 +85,10 @@ class CodeExtractorAnalysisCache {
8585
/// 3) Add allocas for any scalar outputs, adding all of the outputs' allocas
8686
/// as arguments, and inserting stores to the arguments for any scalars.
8787
class CodeExtractor {
88+
using CustomArgAllocatorCBTy = std::function<Instruction *(
89+
BasicBlock *, BasicBlock::iterator, Type *, const Twine &)>;
90+
using CustomArgDeallocatorCBTy = std::function<Instruction *(
91+
BasicBlock *, BasicBlock::iterator, Value *, Type *)>;
8892
using ValueSet = SetVector<Value *>;
8993

9094
// Various bits of state computed on construction.
@@ -133,6 +137,25 @@ class CodeExtractorAnalysisCache {
133137
// space.
134138
bool ArgsInZeroAddressSpace;
135139

140+
// If set, this callback will be used to allocate the arguments in the
141+
// caller before passing it to the outlined function holding the extracted
142+
// piece of code.
143+
CustomArgAllocatorCBTy *CustomArgAllocatorCB;
144+
145+
// A block outside of the extraction set where previously introduced
146+
// intermediate allocations can be deallocated. This is only used when an
147+
// custom deallocator is specified.
148+
BasicBlock *DeallocationBlock;
149+
150+
// If set, this callback will be used to deallocate the arguments in the
151+
// caller after running the outlined function holding the extracted piece of
152+
// code. It will not be called if a custom allocator isn't also present.
153+
//
154+
// By default, this will be done at the end of the basic block containing
155+
// the call to the outlined function, except if a deallocation block is
156+
// specified. In that case, that will take precedence.
157+
CustomArgDeallocatorCBTy *CustomArgDeallocatorCB;
158+
136159
public:
137160
/// Create a code extractor for a sequence of blocks.
138161
///
@@ -149,15 +172,20 @@ class CodeExtractorAnalysisCache {
149172
/// the function from which the code is being extracted.
150173
/// If ArgsInZeroAddressSpace param is set to true, then the aggregate
151174
/// param pointer of the outlined function is declared in zero address
152-
/// space.
175+
/// space. If a CustomArgAllocatorCB callback is specified, it will be used
176+
/// to allocate any structures or variable copies needed to pass arguments
177+
/// to the outlined function, rather than using regular allocas.
153178
LLVM_ABI
154179
CodeExtractor(ArrayRef<BasicBlock *> BBs, DominatorTree *DT = nullptr,
155180
bool AggregateArgs = false, BlockFrequencyInfo *BFI = nullptr,
156181
BranchProbabilityInfo *BPI = nullptr,
157182
AssumptionCache *AC = nullptr, bool AllowVarArgs = false,
158183
bool AllowAlloca = false,
159184
BasicBlock *AllocationBlock = nullptr,
160-
std::string Suffix = "", bool ArgsInZeroAddressSpace = false);
185+
std::string Suffix = "", bool ArgsInZeroAddressSpace = false,
186+
CustomArgAllocatorCBTy *CustomArgAllocatorCB = nullptr,
187+
BasicBlock *DeallocationBlock = nullptr,
188+
CustomArgDeallocatorCBTy *CustomArgDeallocatorCB = nullptr);
161189

162190
/// Perform the extraction, returning the new function.
163191
///
@@ -177,8 +205,9 @@ class CodeExtractorAnalysisCache {
177205
/// newly outlined function.
178206
/// \returns zero when called on a CodeExtractor instance where isEligible
179207
/// returns false.
180-
LLVM_ABI Function *extractCodeRegion(const CodeExtractorAnalysisCache &CEAC,
181-
ValueSet &Inputs, ValueSet &Outputs);
208+
LLVM_ABI Function *
209+
extractCodeRegion(const CodeExtractorAnalysisCache &CEAC, ValueSet &Inputs,
210+
ValueSet &Outputs);
182211

183212
/// Verify that assumption cache isn't stale after a region is extracted.
184213
/// Returns true when verifier finds errors. AssumptionCache is passed as

llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp

Lines changed: 89 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -268,6 +268,38 @@ computeOpenMPScheduleType(ScheduleKind ClauseKind, bool HasChunks,
268268
return Result;
269269
}
270270

271+
/// Given a function, if it represents the entry point of a target kernel, this
272+
/// returns the execution mode flags associated to that kernel.
273+
static std::optional<omp::OMPTgtExecModeFlags>
274+
getTargetKernelExecMode(Function &Kernel) {
275+
CallInst *TargetInitCall = nullptr;
276+
for (Instruction &Inst : Kernel.getEntryBlock()) {
277+
if (auto *Call = dyn_cast<CallInst>(&Inst)) {
278+
if (Call->getCalledFunction()->getName() == "__kmpc_target_init") {
279+
TargetInitCall = Call;
280+
break;
281+
}
282+
}
283+
}
284+
285+
if (!TargetInitCall)
286+
return std::nullopt;
287+
288+
// Get the kernel mode information from the global variable associated to the
289+
// first argument to the call to __kmpc_target_init. Refer to
290+
// createTargetInit() to see how this is initialized.
291+
Value *InitOperand = TargetInitCall->getArgOperand(0);
292+
GlobalVariable *KernelEnv = nullptr;
293+
if (auto *Cast = dyn_cast<ConstantExpr>(InitOperand))
294+
KernelEnv = cast<GlobalVariable>(Cast->getOperand(0));
295+
else
296+
KernelEnv = cast<GlobalVariable>(InitOperand);
297+
auto *KernelEnvInit = cast<ConstantStruct>(KernelEnv->getInitializer());
298+
auto *ConfigEnv = cast<ConstantStruct>(KernelEnvInit->getOperand(0));
299+
auto *KernelMode = cast<ConstantInt>(ConfigEnv->getOperand(2));
300+
return static_cast<OMPTgtExecModeFlags>(KernelMode->getZExtValue());
301+
}
302+
271303
/// Make \p Source branch to \p Target.
272304
///
273305
/// Handles two situations:
@@ -702,15 +734,19 @@ void OpenMPIRBuilder::finalize(Function *Fn) {
702734
// CodeExtractor generates correct code for extracted functions
703735
// which are used by OpenMP runtime.
704736
bool ArgsInZeroAddressSpace = Config.isTargetDevice();
705-
CodeExtractor Extractor(Blocks, /* DominatorTree */ nullptr,
706-
/* AggregateArgs */ true,
707-
/* BlockFrequencyInfo */ nullptr,
708-
/* BranchProbabilityInfo */ nullptr,
709-
/* AssumptionCache */ nullptr,
710-
/* AllowVarArgs */ true,
711-
/* AllowAlloca */ true,
712-
/* AllocaBlock*/ OI.OuterAllocaBB,
713-
/* Suffix */ ".omp_par", ArgsInZeroAddressSpace);
737+
CodeExtractor Extractor(
738+
Blocks, /* DominatorTree */ nullptr,
739+
/* AggregateArgs */ true,
740+
/* BlockFrequencyInfo */ nullptr,
741+
/* BranchProbabilityInfo */ nullptr,
742+
/* AssumptionCache */ nullptr,
743+
/* AllowVarArgs */ true,
744+
/* AllowAlloca */ true,
745+
/* AllocaBlock*/ OI.OuterAllocaBB,
746+
/* Suffix */ ".omp_par", ArgsInZeroAddressSpace,
747+
OI.CustomArgAllocatorCB ? &OI.CustomArgAllocatorCB : nullptr,
748+
/* DeallocationBlock */ OI.ExitBB,
749+
OI.CustomArgDeallocatorCB ? &OI.CustomArgDeallocatorCB : nullptr);
714750

715751
LLVM_DEBUG(dbgs() << "Before outlining: " << *OuterFn << "\n");
716752
LLVM_DEBUG(dbgs() << "Entry " << OI.EntryBB->getName()
@@ -1614,6 +1650,50 @@ OpenMPIRBuilder::InsertPointOrErrorTy OpenMPIRBuilder::createParallel(
16141650
IfCondition, NumThreads, PrivTID, PrivTIDAddr,
16151651
ThreadID, ToBeDeletedVec);
16161652
};
1653+
1654+
std::optional<omp::OMPTgtExecModeFlags> ExecMode =
1655+
getTargetKernelExecMode(*OuterFn);
1656+
1657+
// If OuterFn is not a Generic kernel, skip custom allocation. This causes
1658+
// the CodeExtractor to follow its default behavior. Otherwise, we need to
1659+
// use device shared memory to allocate argument structures.
1660+
if (ExecMode && *ExecMode & OMP_TGT_EXEC_MODE_GENERIC) {
1661+
OI.CustomArgAllocatorCB = [this,
1662+
EntryBB](BasicBlock *, BasicBlock::iterator,
1663+
Type *ArgTy, const Twine &Name) {
1664+
// Instead of using the insertion point provided by the CodeExtractor,
1665+
// here we need to use the block that eventually calls the outlined
1666+
// function for the `parallel` construct.
1667+
//
1668+
// The reason is that the explicit deallocation call will be inserted
1669+
// within the outlined function, whereas the alloca insertion point
1670+
// might actually be located somewhere else in the caller. This becomes
1671+
// a problem when e.g. `parallel` is inside of a `distribute` construct,
1672+
// because the deallocation would be executed multiple times and the
1673+
// allocation just once (outside of the loop).
1674+
//
1675+
// TODO: Ideally, we'd want to do the allocation and deallocation
1676+
// outside of the `parallel` outlined function, hence using here the
1677+
// insertion point provided by the CodeExtractor. We can't do this at
1678+
// the moment because there is currently no way of passing an eligible
1679+
// insertion point for the explicit deallocation to the CodeExtractor,
1680+
// as that block is created (at least when nested inside of
1681+
// `distribute`) sometime after createParallel() completed, so it can't
1682+
// be stored in the OutlineInfo structure here.
1683+
//
1684+
// The current approach results in an explicit allocation and
1685+
// deallocation pair for each `distribute` loop iteration in that case,
1686+
// which is suboptimal.
1687+
return createOMPAllocShared(
1688+
InsertPointTy(EntryBB, EntryBB->getFirstInsertionPt()), ArgTy,
1689+
Name);
1690+
};
1691+
OI.CustomArgDeallocatorCB =
1692+
[this](BasicBlock *BB, BasicBlock::iterator AllocIP, Value *Arg,
1693+
Type *ArgTy) -> Instruction * {
1694+
return createOMPFreeShared(InsertPointTy(BB, AllocIP), Arg, ArgTy);
1695+
};
1696+
}
16171697
} else {
16181698
// Generate OpenMP host runtime call
16191699
OI.PostOutlineCB = [=, ToBeDeletedVec =

llvm/lib/Transforms/Utils/CodeExtractor.cpp

Lines changed: 56 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,6 @@
2525
#include "llvm/Analysis/BranchProbabilityInfo.h"
2626
#include "llvm/IR/Argument.h"
2727
#include "llvm/IR/Attributes.h"
28-
#include "llvm/IR/BasicBlock.h"
2928
#include "llvm/IR/CFG.h"
3029
#include "llvm/IR/Constant.h"
3130
#include "llvm/IR/Constants.h"
@@ -265,12 +264,18 @@ CodeExtractor::CodeExtractor(ArrayRef<BasicBlock *> BBs, DominatorTree *DT,
265264
BranchProbabilityInfo *BPI, AssumptionCache *AC,
266265
bool AllowVarArgs, bool AllowAlloca,
267266
BasicBlock *AllocationBlock, std::string Suffix,
268-
bool ArgsInZeroAddressSpace)
267+
bool ArgsInZeroAddressSpace,
268+
CustomArgAllocatorCBTy *CustomArgAllocatorCB,
269+
BasicBlock *DeallocationBlock,
270+
CustomArgDeallocatorCBTy *CustomArgDeallocatorCB)
269271
: DT(DT), AggregateArgs(AggregateArgs || AggregateArgsOpt), BFI(BFI),
270272
BPI(BPI), AC(AC), AllocationBlock(AllocationBlock),
271273
AllowVarArgs(AllowVarArgs),
272274
Blocks(buildExtractionBlockSet(BBs, DT, AllowVarArgs, AllowAlloca)),
273-
Suffix(Suffix), ArgsInZeroAddressSpace(ArgsInZeroAddressSpace) {}
275+
Suffix(Suffix), ArgsInZeroAddressSpace(ArgsInZeroAddressSpace),
276+
CustomArgAllocatorCB(CustomArgAllocatorCB),
277+
DeallocationBlock(DeallocationBlock),
278+
CustomArgDeallocatorCB(CustomArgDeallocatorCB) {}
274279

275280
/// definedInRegion - Return true if the specified value is defined in the
276281
/// extracted region.
@@ -1852,24 +1857,38 @@ CallInst *CodeExtractor::emitReplacerCall(
18521857
if (StructValues.contains(output))
18531858
continue;
18541859

1855-
AllocaInst *alloca = new AllocaInst(
1856-
output->getType(), DL.getAllocaAddrSpace(), nullptr,
1857-
output->getName() + ".loc", AllocaBlock->getFirstInsertionPt());
1858-
params.push_back(alloca);
1859-
ReloadOutputs.push_back(alloca);
1860+
Value *OutAlloc;
1861+
if (CustomArgAllocatorCB)
1862+
OutAlloc = (*CustomArgAllocatorCB)(
1863+
AllocaBlock, AllocaBlock->getFirstInsertionPt(), output->getType(),
1864+
output->getName() + ".loc");
1865+
else
1866+
OutAlloc = new AllocaInst(output->getType(), DL.getAllocaAddrSpace(),
1867+
nullptr, output->getName() + ".loc",
1868+
AllocaBlock->getFirstInsertionPt());
1869+
1870+
params.push_back(OutAlloc);
1871+
ReloadOutputs.push_back(OutAlloc);
18601872
}
18611873

1862-
AllocaInst *Struct = nullptr;
1874+
Instruction *Struct = nullptr;
18631875
if (!StructValues.empty()) {
1864-
Struct = new AllocaInst(StructArgTy, DL.getAllocaAddrSpace(), nullptr,
1865-
"structArg", AllocaBlock->getFirstInsertionPt());
1866-
if (ArgsInZeroAddressSpace && DL.getAllocaAddrSpace() != 0) {
1867-
auto *StructSpaceCast = new AddrSpaceCastInst(
1868-
Struct, PointerType ::get(Context, 0), "structArg.ascast");
1869-
StructSpaceCast->insertAfter(Struct->getIterator());
1870-
params.push_back(StructSpaceCast);
1871-
} else {
1876+
BasicBlock::iterator StructArgIP = AllocaBlock->getFirstInsertionPt();
1877+
if (CustomArgAllocatorCB) {
1878+
Struct = (*CustomArgAllocatorCB)(AllocaBlock, StructArgIP, StructArgTy,
1879+
"structArg");
18721880
params.push_back(Struct);
1881+
} else {
1882+
Struct = new AllocaInst(StructArgTy, DL.getAllocaAddrSpace(), nullptr,
1883+
"structArg", StructArgIP);
1884+
if (ArgsInZeroAddressSpace && DL.getAllocaAddrSpace() != 0) {
1885+
auto *StructSpaceCast = new AddrSpaceCastInst(
1886+
Struct, PointerType ::get(Context, 0), "structArg.ascast");
1887+
StructSpaceCast->insertAfter(Struct->getIterator());
1888+
params.push_back(StructSpaceCast);
1889+
} else {
1890+
params.push_back(Struct);
1891+
}
18731892
}
18741893

18751894
unsigned AggIdx = 0;
@@ -2013,6 +2032,26 @@ CallInst *CodeExtractor::emitReplacerCall(
20132032
insertLifetimeMarkersSurroundingCall(oldFunction->getParent(), LifetimesStart,
20142033
{}, call);
20152034

2035+
// Deallocate variables that used a custom allocator.
2036+
if (CustomArgAllocatorCB && CustomArgDeallocatorCB) {
2037+
BasicBlock *DeallocBlock = codeReplacer;
2038+
BasicBlock::iterator DeallocIP = codeReplacer->end();
2039+
if (DeallocationBlock) {
2040+
DeallocBlock = DeallocationBlock;
2041+
DeallocIP = DeallocationBlock->getFirstInsertionPt();
2042+
}
2043+
2044+
int Index = 0;
2045+
for (Value *Output : outputs) {
2046+
if (!StructValues.contains(Output))
2047+
(*CustomArgDeallocatorCB)(DeallocBlock, DeallocIP,
2048+
ReloadOutputs[Index++], Output->getType());
2049+
}
2050+
2051+
if (Struct)
2052+
(*CustomArgDeallocatorCB)(DeallocBlock, DeallocIP, Struct, StructArgTy);
2053+
}
2054+
20162055
return call;
20172056
}
20182057

mlir/test/Target/LLVMIR/omptarget-parallel-llvm.mlir

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -56,21 +56,21 @@ module attributes {dlti.dl_spec = #dlti.dl_spec<#dlti.dl_entry<"dlti.alloca_memo
5656
// CHECK-SAME: ptr %[[TMP:.*]], ptr %[[TMP0:.*]]) #{{[0-9]+}} {
5757
// CHECK: %[[TMP1:.*]] = alloca [1 x ptr], align 8, addrspace(5)
5858
// CHECK: %[[TMP2:.*]] = addrspacecast ptr addrspace(5) %[[TMP1]] to ptr
59-
// CHECK: %[[STRUCTARG:.*]] = alloca { ptr }, align 8, addrspace(5)
60-
// CHECK: %[[STRUCTARG_ASCAST:.*]] = addrspacecast ptr addrspace(5) %[[STRUCTARG]] to ptr
6159
// CHECK: %[[TMP3:.*]] = alloca ptr, align 8, addrspace(5)
6260
// CHECK: %[[TMP4:.*]] = addrspacecast ptr addrspace(5) %[[TMP3]] to ptr
6361
// CHECK: store ptr %[[TMP0]], ptr %[[TMP4]], align 8
6462
// CHECK: %[[TMP5:.*]] = call i32 @__kmpc_target_init(ptr addrspacecast (ptr addrspace(1) @{{.*}} to ptr), ptr %[[TMP]])
6563
// CHECK: %[[EXEC_USER_CODE:.*]] = icmp eq i32 %[[TMP5]], -1
6664
// CHECK: br i1 %[[EXEC_USER_CODE]], label %[[USER_CODE_ENTRY:.*]], label %[[WORKER_EXIT:.*]]
6765
// CHECK: %[[TMP6:.*]] = load ptr, ptr %[[TMP4]], align 8
66+
// CHECK: %[[STRUCTARG:.*]] = call align 8 ptr @__kmpc_alloc_shared(i64 8)
6867
// CHECK: %[[OMP_GLOBAL_THREAD_NUM:.*]] = call i32 @__kmpc_global_thread_num(ptr addrspacecast (ptr addrspace(1) @[[GLOB1:[0-9]+]] to ptr))
69-
// CHECK: %[[GEP_:.*]] = getelementptr { ptr }, ptr addrspace(5) %[[STRUCTARG]], i32 0, i32 0
70-
// CHECK: store ptr %[[TMP6]], ptr addrspace(5) %[[GEP_]], align 8
68+
// CHECK: %[[GEP_:.*]] = getelementptr { ptr }, ptr %[[STRUCTARG]], i32 0, i32 0
69+
// CHECK: store ptr %[[TMP6]], ptr %[[GEP_]], align 8
7170
// CHECK: %[[TMP7:.*]] = getelementptr inbounds [1 x ptr], ptr %[[TMP2]], i64 0, i64 0
72-
// CHECK: store ptr %[[STRUCTARG_ASCAST]], ptr %[[TMP7]], align 8
71+
// CHECK: store ptr %[[STRUCTARG]], ptr %[[TMP7]], align 8
7372
// CHECK: call void @__kmpc_parallel_51(ptr addrspacecast (ptr addrspace(1) @[[GLOB1]] to ptr), i32 %[[OMP_GLOBAL_THREAD_NUM]], i32 1, i32 -1, i32 -1, ptr @[[FUNC1:.*]], ptr null, ptr %[[TMP2]], i64 1)
73+
// CHECK: call void @__kmpc_free_shared(ptr %[[STRUCTARG]], i64 8)
7474
// CHECK: call void @__kmpc_target_deinit()
7575

7676
// CHECK: define internal void @[[FUNC1]](

0 commit comments

Comments
 (0)