Skip to content

Commit 688b614

Browse files
committed
[OpenMP][OMPIRBuilder] Use device shared memory for arg structures
Argument structures are created when sections of the LLVM IR corresponding to an OpenMP construct are outlined into their own function. For this, stack allocations are used. This patch modifies this behavior when compiling for a target device and outlining `parallel`-related IR, so that it uses device shared memory instead of private stack space. This is needed in order for threads to have access to these arguments.
1 parent 1b7dd6c commit 688b614

File tree

5 files changed

+187
-34
lines changed

5 files changed

+187
-34
lines changed

llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2244,7 +2244,13 @@ class OpenMPIRBuilder {
22442244
/// during finalization.
22452245
struct OutlineInfo {
22462246
using PostOutlineCBTy = std::function<void(Function &)>;
2247+
using CustomArgAllocatorCBTy = std::function<Instruction *(
2248+
BasicBlock *, BasicBlock::iterator, Type *, const Twine &)>;
2249+
using CustomArgDeallocatorCBTy = std::function<Instruction *(
2250+
BasicBlock *, BasicBlock::iterator, Value *, Type *)>;
22472251
PostOutlineCBTy PostOutlineCB;
2252+
CustomArgAllocatorCBTy CustomArgAllocatorCB;
2253+
CustomArgDeallocatorCBTy CustomArgDeallocatorCB;
22482254
BasicBlock *EntryBB, *ExitBB, *OuterAllocaBB;
22492255
SmallVector<Value *, 2> ExcludeArgsFromAggregate;
22502256

llvm/include/llvm/Transforms/Utils/CodeExtractor.h

Lines changed: 31 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -17,14 +17,14 @@
1717
#include "llvm/ADT/ArrayRef.h"
1818
#include "llvm/ADT/DenseMap.h"
1919
#include "llvm/ADT/SetVector.h"
20+
#include "llvm/IR/BasicBlock.h"
2021
#include "llvm/Support/Compiler.h"
2122
#include <limits>
2223

2324
namespace llvm {
2425

2526
template <typename PtrType> class SmallPtrSetImpl;
2627
class AllocaInst;
27-
class BasicBlock;
2828
class BlockFrequency;
2929
class BlockFrequencyInfo;
3030
class BranchProbabilityInfo;
@@ -85,6 +85,10 @@ class CodeExtractorAnalysisCache {
8585
/// 3) Add allocas for any scalar outputs, adding all of the outputs' allocas
8686
/// as arguments, and inserting stores to the arguments for any scalars.
8787
class CodeExtractor {
88+
using CustomArgAllocatorCBTy = std::function<Instruction *(
89+
BasicBlock *, BasicBlock::iterator, Type *, const Twine &)>;
90+
using CustomArgDeallocatorCBTy = std::function<Instruction *(
91+
BasicBlock *, BasicBlock::iterator, Value *, Type *)>;
8892
using ValueSet = SetVector<Value *>;
8993

9094
// Various bits of state computed on construction.
@@ -133,6 +137,25 @@ class CodeExtractorAnalysisCache {
133137
// space.
134138
bool ArgsInZeroAddressSpace;
135139

140+
// If set, this callback will be used to allocate the arguments in the
141+
// caller before passing them to the outlined function holding the extracted
142+
// piece of code.
143+
CustomArgAllocatorCBTy *CustomArgAllocatorCB;
144+
145+
// A block outside of the extraction set where previously introduced
146+
// intermediate allocations can be deallocated. This is only used when a
147+
// custom deallocator is specified.
148+
BasicBlock *DeallocationBlock;
149+
150+
// If set, this callback will be used to deallocate the arguments in the
151+
// caller after running the outlined function holding the extracted piece of
152+
// code. It will not be called if a custom allocator isn't also present.
153+
//
154+
// By default, this will be done at the end of the basic block containing
155+
// the call to the outlined function, except if a deallocation block is
156+
// specified. In that case, that will take precedence.
157+
CustomArgDeallocatorCBTy *CustomArgDeallocatorCB;
158+
136159
public:
137160
/// Create a code extractor for a sequence of blocks.
138161
///
@@ -149,15 +172,20 @@ class CodeExtractorAnalysisCache {
149172
/// the function from which the code is being extracted.
150173
/// If ArgsInZeroAddressSpace param is set to true, then the aggregate
151174
/// param pointer of the outlined function is declared in zero address
152-
/// space.
175+
/// space. If a CustomArgAllocatorCB callback is specified, it will be used
176+
/// to allocate any structures or variable copies needed to pass arguments
177+
/// to the outlined function, rather than using regular allocas.
153178
LLVM_ABI
154179
CodeExtractor(ArrayRef<BasicBlock *> BBs, DominatorTree *DT = nullptr,
155180
bool AggregateArgs = false, BlockFrequencyInfo *BFI = nullptr,
156181
BranchProbabilityInfo *BPI = nullptr,
157182
AssumptionCache *AC = nullptr, bool AllowVarArgs = false,
158183
bool AllowAlloca = false,
159184
BasicBlock *AllocationBlock = nullptr,
160-
std::string Suffix = "", bool ArgsInZeroAddressSpace = false);
185+
std::string Suffix = "", bool ArgsInZeroAddressSpace = false,
186+
CustomArgAllocatorCBTy *CustomArgAllocatorCB = nullptr,
187+
BasicBlock *DeallocationBlock = nullptr,
188+
CustomArgDeallocatorCBTy *CustomArgDeallocatorCB = nullptr);
161189

162190
/// Perform the extraction, returning the new function.
163191
///

llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp

Lines changed: 89 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -280,6 +280,38 @@ computeOpenMPScheduleType(ScheduleKind ClauseKind, bool HasChunks,
280280
return Result;
281281
}
282282

283+
/// Given a function, if it represents the entry point of a target kernel, this
284+
/// returns the execution mode flags associated to that kernel.
285+
static std::optional<omp::OMPTgtExecModeFlags>
286+
getTargetKernelExecMode(Function &Kernel) {
287+
CallInst *TargetInitCall = nullptr;
288+
for (Instruction &Inst : Kernel.getEntryBlock()) {
289+
if (auto *Call = dyn_cast<CallInst>(&Inst)) {
290+
if (Call->getCalledFunction()->getName() == "__kmpc_target_init") {
291+
TargetInitCall = Call;
292+
break;
293+
}
294+
}
295+
}
296+
297+
if (!TargetInitCall)
298+
return std::nullopt;
299+
300+
// Get the kernel mode information from the global variable associated to the
301+
// first argument to the call to __kmpc_target_init. Refer to
302+
// createTargetInit() to see how this is initialized.
303+
Value *InitOperand = TargetInitCall->getArgOperand(0);
304+
GlobalVariable *KernelEnv = nullptr;
305+
if (auto *Cast = dyn_cast<ConstantExpr>(InitOperand))
306+
KernelEnv = cast<GlobalVariable>(Cast->getOperand(0));
307+
else
308+
KernelEnv = cast<GlobalVariable>(InitOperand);
309+
auto *KernelEnvInit = cast<ConstantStruct>(KernelEnv->getInitializer());
310+
auto *ConfigEnv = cast<ConstantStruct>(KernelEnvInit->getOperand(0));
311+
auto *KernelMode = cast<ConstantInt>(ConfigEnv->getOperand(2));
312+
return static_cast<OMPTgtExecModeFlags>(KernelMode->getZExtValue());
313+
}
314+
283315
/// Make \p Source branch to \p Target.
284316
///
285317
/// Handles two situations:
@@ -714,15 +746,19 @@ void OpenMPIRBuilder::finalize(Function *Fn) {
714746
// CodeExtractor generates correct code for extracted functions
715747
// which are used by OpenMP runtime.
716748
bool ArgsInZeroAddressSpace = Config.isTargetDevice();
717-
CodeExtractor Extractor(Blocks, /* DominatorTree */ nullptr,
718-
/* AggregateArgs */ true,
719-
/* BlockFrequencyInfo */ nullptr,
720-
/* BranchProbabilityInfo */ nullptr,
721-
/* AssumptionCache */ nullptr,
722-
/* AllowVarArgs */ true,
723-
/* AllowAlloca */ true,
724-
/* AllocaBlock*/ OI.OuterAllocaBB,
725-
/* Suffix */ ".omp_par", ArgsInZeroAddressSpace);
749+
CodeExtractor Extractor(
750+
Blocks, /* DominatorTree */ nullptr,
751+
/* AggregateArgs */ true,
752+
/* BlockFrequencyInfo */ nullptr,
753+
/* BranchProbabilityInfo */ nullptr,
754+
/* AssumptionCache */ nullptr,
755+
/* AllowVarArgs */ true,
756+
/* AllowAlloca */ true,
757+
/* AllocaBlock*/ OI.OuterAllocaBB,
758+
/* Suffix */ ".omp_par", ArgsInZeroAddressSpace,
759+
OI.CustomArgAllocatorCB ? &OI.CustomArgAllocatorCB : nullptr,
760+
/* DeallocationBlock */ OI.ExitBB,
761+
OI.CustomArgDeallocatorCB ? &OI.CustomArgDeallocatorCB : nullptr);
726762

727763
LLVM_DEBUG(dbgs() << "Before outlining: " << *OuterFn << "\n");
728764
LLVM_DEBUG(dbgs() << "Entry " << OI.EntryBB->getName()
@@ -1625,6 +1661,50 @@ OpenMPIRBuilder::InsertPointOrErrorTy OpenMPIRBuilder::createParallel(
16251661
IfCondition, NumThreads, PrivTID, PrivTIDAddr,
16261662
ThreadID, ToBeDeletedVec);
16271663
};
1664+
1665+
std::optional<omp::OMPTgtExecModeFlags> ExecMode =
1666+
getTargetKernelExecMode(*OuterFn);
1667+
1668+
// If OuterFn is not a Generic kernel, skip custom allocation. This causes
1669+
// the CodeExtractor to follow its default behavior. Otherwise, we need to
1670+
// use device shared memory to allocate argument structures.
1671+
if (ExecMode && *ExecMode & OMP_TGT_EXEC_MODE_GENERIC) {
1672+
OI.CustomArgAllocatorCB = [this,
1673+
EntryBB](BasicBlock *, BasicBlock::iterator,
1674+
Type *ArgTy, const Twine &Name) {
1675+
// Instead of using the insertion point provided by the CodeExtractor,
1676+
// here we need to use the block that eventually calls the outlined
1677+
// function for the `parallel` construct.
1678+
//
1679+
// The reason is that the explicit deallocation call will be inserted
1680+
// within the outlined function, whereas the alloca insertion point
1681+
// might actually be located somewhere else in the caller. This becomes
1682+
// a problem when e.g. `parallel` is inside of a `distribute` construct,
1683+
// because the deallocation would be executed multiple times and the
1684+
// allocation just once (outside of the loop).
1685+
//
1686+
// TODO: Ideally, we'd want to do the allocation and deallocation
1687+
// outside of the `parallel` outlined function, hence using here the
1688+
// insertion point provided by the CodeExtractor. We can't do this at
1689+
// the moment because there is currently no way of passing an eligible
1690+
// insertion point for the explicit deallocation to the CodeExtractor,
1691+
// as that block is created (at least when nested inside of
1692+
// `distribute`) sometime after createParallel() completed, so it can't
1693+
// be stored in the OutlineInfo structure here.
1694+
//
1695+
// The current approach results in an explicit allocation and
1696+
// deallocation pair for each `distribute` loop iteration in that case,
1697+
// which is suboptimal.
1698+
return createOMPAllocShared(
1699+
InsertPointTy(EntryBB, EntryBB->getFirstInsertionPt()), ArgTy,
1700+
Name);
1701+
};
1702+
OI.CustomArgDeallocatorCB =
1703+
[this](BasicBlock *BB, BasicBlock::iterator AllocIP, Value *Arg,
1704+
Type *ArgTy) -> Instruction * {
1705+
return createOMPFreeShared(InsertPointTy(BB, AllocIP), Arg, ArgTy);
1706+
};
1707+
}
16281708
} else {
16291709
// Generate OpenMP host runtime call
16301710
OI.PostOutlineCB = [=, ToBeDeletedVec =

llvm/lib/Transforms/Utils/CodeExtractor.cpp

Lines changed: 56 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,6 @@
2525
#include "llvm/Analysis/BranchProbabilityInfo.h"
2626
#include "llvm/IR/Argument.h"
2727
#include "llvm/IR/Attributes.h"
28-
#include "llvm/IR/BasicBlock.h"
2928
#include "llvm/IR/CFG.h"
3029
#include "llvm/IR/Constant.h"
3130
#include "llvm/IR/Constants.h"
@@ -265,12 +264,18 @@ CodeExtractor::CodeExtractor(ArrayRef<BasicBlock *> BBs, DominatorTree *DT,
265264
BranchProbabilityInfo *BPI, AssumptionCache *AC,
266265
bool AllowVarArgs, bool AllowAlloca,
267266
BasicBlock *AllocationBlock, std::string Suffix,
268-
bool ArgsInZeroAddressSpace)
267+
bool ArgsInZeroAddressSpace,
268+
CustomArgAllocatorCBTy *CustomArgAllocatorCB,
269+
BasicBlock *DeallocationBlock,
270+
CustomArgDeallocatorCBTy *CustomArgDeallocatorCB)
269271
: DT(DT), AggregateArgs(AggregateArgs || AggregateArgsOpt), BFI(BFI),
270272
BPI(BPI), AC(AC), AllocationBlock(AllocationBlock),
271273
AllowVarArgs(AllowVarArgs),
272274
Blocks(buildExtractionBlockSet(BBs, DT, AllowVarArgs, AllowAlloca)),
273-
Suffix(Suffix), ArgsInZeroAddressSpace(ArgsInZeroAddressSpace) {}
275+
Suffix(Suffix), ArgsInZeroAddressSpace(ArgsInZeroAddressSpace),
276+
CustomArgAllocatorCB(CustomArgAllocatorCB),
277+
DeallocationBlock(DeallocationBlock),
278+
CustomArgDeallocatorCB(CustomArgDeallocatorCB) {}
274279

275280
/// definedInRegion - Return true if the specified value is defined in the
276281
/// extracted region.
@@ -1850,24 +1855,38 @@ CallInst *CodeExtractor::emitReplacerCall(
18501855
if (StructValues.contains(output))
18511856
continue;
18521857

1853-
AllocaInst *alloca = new AllocaInst(
1854-
output->getType(), DL.getAllocaAddrSpace(), nullptr,
1855-
output->getName() + ".loc", AllocaBlock->getFirstInsertionPt());
1856-
params.push_back(alloca);
1857-
ReloadOutputs.push_back(alloca);
1858+
Value *OutAlloc;
1859+
if (CustomArgAllocatorCB)
1860+
OutAlloc = (*CustomArgAllocatorCB)(
1861+
AllocaBlock, AllocaBlock->getFirstInsertionPt(), output->getType(),
1862+
output->getName() + ".loc");
1863+
else
1864+
OutAlloc = new AllocaInst(output->getType(), DL.getAllocaAddrSpace(),
1865+
nullptr, output->getName() + ".loc",
1866+
AllocaBlock->getFirstInsertionPt());
1867+
1868+
params.push_back(OutAlloc);
1869+
ReloadOutputs.push_back(OutAlloc);
18581870
}
18591871

1860-
AllocaInst *Struct = nullptr;
1872+
Instruction *Struct = nullptr;
18611873
if (!StructValues.empty()) {
1862-
Struct = new AllocaInst(StructArgTy, DL.getAllocaAddrSpace(), nullptr,
1863-
"structArg", AllocaBlock->getFirstInsertionPt());
1864-
if (ArgsInZeroAddressSpace && DL.getAllocaAddrSpace() != 0) {
1865-
auto *StructSpaceCast = new AddrSpaceCastInst(
1866-
Struct, PointerType ::get(Context, 0), "structArg.ascast");
1867-
StructSpaceCast->insertAfter(Struct->getIterator());
1868-
params.push_back(StructSpaceCast);
1869-
} else {
1874+
BasicBlock::iterator StructArgIP = AllocaBlock->getFirstInsertionPt();
1875+
if (CustomArgAllocatorCB) {
1876+
Struct = (*CustomArgAllocatorCB)(AllocaBlock, StructArgIP, StructArgTy,
1877+
"structArg");
18701878
params.push_back(Struct);
1879+
} else {
1880+
Struct = new AllocaInst(StructArgTy, DL.getAllocaAddrSpace(), nullptr,
1881+
"structArg", StructArgIP);
1882+
if (ArgsInZeroAddressSpace && DL.getAllocaAddrSpace() != 0) {
1883+
auto *StructSpaceCast = new AddrSpaceCastInst(
1884+
Struct, PointerType ::get(Context, 0), "structArg.ascast");
1885+
StructSpaceCast->insertAfter(Struct->getIterator());
1886+
params.push_back(StructSpaceCast);
1887+
} else {
1888+
params.push_back(Struct);
1889+
}
18711890
}
18721891

18731892
unsigned AggIdx = 0;
@@ -2011,6 +2030,26 @@ CallInst *CodeExtractor::emitReplacerCall(
20112030
insertLifetimeMarkersSurroundingCall(oldFunction->getParent(), LifetimesStart,
20122031
{}, call);
20132032

2033+
// Deallocate variables that used a custom allocator.
2034+
if (CustomArgAllocatorCB && CustomArgDeallocatorCB) {
2035+
BasicBlock *DeallocBlock = codeReplacer;
2036+
BasicBlock::iterator DeallocIP = codeReplacer->end();
2037+
if (DeallocationBlock) {
2038+
DeallocBlock = DeallocationBlock;
2039+
DeallocIP = DeallocationBlock->getFirstInsertionPt();
2040+
}
2041+
2042+
int Index = 0;
2043+
for (Value *Output : outputs) {
2044+
if (!StructValues.contains(Output))
2045+
(*CustomArgDeallocatorCB)(DeallocBlock, DeallocIP,
2046+
ReloadOutputs[Index++], Output->getType());
2047+
}
2048+
2049+
if (Struct)
2050+
(*CustomArgDeallocatorCB)(DeallocBlock, DeallocIP, Struct, StructArgTy);
2051+
}
2052+
20142053
return call;
20152054
}
20162055

mlir/test/Target/LLVMIR/omptarget-parallel-llvm.mlir

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -56,21 +56,21 @@ module attributes {dlti.dl_spec = #dlti.dl_spec<#dlti.dl_entry<"dlti.alloca_memo
5656
// CHECK-SAME: ptr %[[TMP:.*]], ptr %[[TMP0:.*]]) #{{[0-9]+}} {
5757
// CHECK: %[[TMP1:.*]] = alloca [1 x ptr], align 8, addrspace(5)
5858
// CHECK: %[[TMP2:.*]] = addrspacecast ptr addrspace(5) %[[TMP1]] to ptr
59-
// CHECK: %[[STRUCTARG:.*]] = alloca { ptr }, align 8, addrspace(5)
60-
// CHECK: %[[STRUCTARG_ASCAST:.*]] = addrspacecast ptr addrspace(5) %[[STRUCTARG]] to ptr
6159
// CHECK: %[[TMP3:.*]] = alloca ptr, align 8, addrspace(5)
6260
// CHECK: %[[TMP4:.*]] = addrspacecast ptr addrspace(5) %[[TMP3]] to ptr
6361
// CHECK: store ptr %[[TMP0]], ptr %[[TMP4]], align 8
6462
// CHECK: %[[TMP5:.*]] = call i32 @__kmpc_target_init(ptr addrspacecast (ptr addrspace(1) @{{.*}} to ptr), ptr %[[TMP]])
6563
// CHECK: %[[EXEC_USER_CODE:.*]] = icmp eq i32 %[[TMP5]], -1
6664
// CHECK: br i1 %[[EXEC_USER_CODE]], label %[[USER_CODE_ENTRY:.*]], label %[[WORKER_EXIT:.*]]
6765
// CHECK: %[[TMP6:.*]] = load ptr, ptr %[[TMP4]], align 8
66+
// CHECK: %[[STRUCTARG:.*]] = call align 8 ptr @__kmpc_alloc_shared(i64 8)
6867
// CHECK: %[[OMP_GLOBAL_THREAD_NUM:.*]] = call i32 @__kmpc_global_thread_num(ptr addrspacecast (ptr addrspace(1) @[[GLOB1:[0-9]+]] to ptr))
69-
// CHECK: %[[GEP_:.*]] = getelementptr { ptr }, ptr addrspace(5) %[[STRUCTARG]], i32 0, i32 0
70-
// CHECK: store ptr %[[TMP6]], ptr addrspace(5) %[[GEP_]], align 8
68+
// CHECK: %[[GEP_:.*]] = getelementptr { ptr }, ptr %[[STRUCTARG]], i32 0, i32 0
69+
// CHECK: store ptr %[[TMP6]], ptr %[[GEP_]], align 8
7170
// CHECK: %[[TMP7:.*]] = getelementptr inbounds [1 x ptr], ptr %[[TMP2]], i64 0, i64 0
72-
// CHECK: store ptr %[[STRUCTARG_ASCAST]], ptr %[[TMP7]], align 8
71+
// CHECK: store ptr %[[STRUCTARG]], ptr %[[TMP7]], align 8
7372
// CHECK: call void @__kmpc_parallel_51(ptr addrspacecast (ptr addrspace(1) @[[GLOB1]] to ptr), i32 %[[OMP_GLOBAL_THREAD_NUM]], i32 1, i32 -1, i32 -1, ptr @[[FUNC1:.*]], ptr null, ptr %[[TMP2]], i64 1)
73+
// CHECK: call void @__kmpc_free_shared(ptr %[[STRUCTARG]], i64 8)
7474
// CHECK: call void @__kmpc_target_deinit()
7575

7676
// CHECK: define internal void @[[FUNC1]](

0 commit comments

Comments
 (0)