Skip to content

Commit b6e9849

Browse files
committed
[OpenMP][OMPIRBuilder] Support parallel in Generic kernels
This patch introduces codegen logic to produce a wrapper function argument for the `__kmpc_parallel_51` DeviceRTL function needed to handle arguments passed using device shared memory in Generic mode.
1 parent 8b34402 commit b6e9849

File tree

2 files changed

+116
-9
lines changed

2 files changed

+116
-9
lines changed

llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp

Lines changed: 94 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1323,6 +1323,86 @@ Error OpenMPIRBuilder::emitCancelationCheckImpl(
13231323
return Error::success();
13241324
}
13251325

1326+
// Create wrapper function used to gather the outlined function's argument
1327+
// structure from a shared buffer and to forward them to it when running in
1328+
// Generic mode.
1329+
//
1330+
// The outlined function is expected to receive 2 pointer-to-i32 arguments
// (the wrapper passes the addresses of its "addr" and "zero" slots) followed by
1331+
// an optional pointer argument to an argument structure holding the rest.
//
// The generated wrapper has the signature `void(i16, i32)`, internal linkage
// and is named "<OutlinedFn name>.wrapper". It is inserted into
// OMPIRBuilder->M and returned to the caller.
//
// \param OMPIRBuilder Builder whose IRBuilder and module are used for emission.
// \param OutlinedFn   Parallel region body; must take exactly 2 or 3 arguments.
// \returns The newly created wrapper function.
1332+
static Function *createTargetParallelWrapper(OpenMPIRBuilder *OMPIRBuilder,
1333+
Function &OutlinedFn) {
1334+
size_t NumArgs = OutlinedFn.arg_size();
1335+
// Only the two fixed arguments, optionally followed by one argument-structure
// pointer, are supported.
assert((NumArgs == 2 || NumArgs == 3) &&
1336+
"expected a 2-3 argument parallel outlined function");
1337+
// A third argument means captured values are passed through a struct pointer.
bool UseArgStruct = NumArgs == 3;
1338+
1339+
IRBuilder<> &Builder = OMPIRBuilder->Builder;
1340+
// The wrapper is emitted with the shared builder; the guard restores the
// caller's insertion point when this function returns.
IRBuilder<>::InsertPointGuard IPG(Builder);
1341+
auto *FnTy = FunctionType::get(Builder.getVoidTy(),
1342+
{Builder.getInt16Ty(), Builder.getInt32Ty()},
1343+
/*isVarArg=*/false);
1344+
auto *WrapperFn =
1345+
Function::Create(FnTy, GlobalValue::InternalLinkage,
1346+
OutlinedFn.getName() + ".wrapper", OMPIRBuilder->M);
1347+
1348+
// Parameter 0 (i16) is noundef + zeroext; parameter 1 (i32) is noundef.
// Only parameter 1 is read by the wrapper body below.
WrapperFn->addParamAttr(0, Attribute::NoUndef);
1349+
WrapperFn->addParamAttr(0, Attribute::ZExt);
1350+
WrapperFn->addParamAttr(1, Attribute::NoUndef);
1351+
1352+
BasicBlock *EntryBB =
1353+
BasicBlock::Create(OMPIRBuilder->M.getContext(), "entry", WrapperFn);
1354+
Builder.SetInsertPoint(EntryBB);
1355+
1356+
// Allocation.
// Each stack slot is cast to an address-space-0 pointer so it can be passed
// to the outlined function and the runtime as a plain `ptr`.
1357+
Value *AddrAlloca = Builder.CreateAlloca(Builder.getInt32Ty(),
1358+
/*ArraySize=*/nullptr, "addr");
1359+
AddrAlloca = Builder.CreatePointerBitCastOrAddrSpaceCast(
1360+
AddrAlloca, Builder.getPtrTy(/*AddrSpace=*/0),
1361+
AddrAlloca->getName() + ".ascast");
1362+
1363+
Value *ZeroAlloca = Builder.CreateAlloca(Builder.getInt32Ty(),
1364+
/*ArraySize=*/nullptr, "zero");
1365+
ZeroAlloca = Builder.CreatePointerBitCastOrAddrSpaceCast(
1366+
ZeroAlloca, Builder.getPtrTy(/*AddrSpace=*/0),
1367+
ZeroAlloca->getName() + ".ascast");
1368+
1369+
// The slot for the shared-arguments pointer is only needed when the outlined
// function takes an argument structure.
Value *ArgsAlloca = nullptr;
1370+
if (UseArgStruct) {
1371+
ArgsAlloca = Builder.CreateAlloca(Builder.getPtrTy(),
1372+
/*ArraySize=*/nullptr, "global_args");
1373+
ArgsAlloca = Builder.CreatePointerBitCastOrAddrSpaceCast(
1374+
ArgsAlloca, Builder.getPtrTy(/*AddrSpace=*/0),
1375+
ArgsAlloca->getName() + ".ascast");
1376+
}
1377+
1378+
// Initialization.
// "addr" receives the wrapper's i32 parameter; "zero" is set to constant 0.
1379+
Builder.CreateStore(WrapperFn->getArg(1), AddrAlloca);
1380+
Builder.CreateStore(Builder.getInt32(0), ZeroAlloca);
1381+
if (UseArgStruct) {
1382+
// NOTE(review): presumably the runtime writes a pointer to the shared
// argument array into global_args -- confirm against the DeviceRTL
// definition of __kmpc_get_shared_variables.
Builder.CreateCall(
1383+
OMPIRBuilder->getOrCreateRuntimeFunctionPtr(
1384+
llvm::omp::RuntimeFunction::OMPRTL___kmpc_get_shared_variables),
1385+
{ArgsAlloca});
1386+
}
1387+
1388+
// The first two outlined-function arguments are always the two i32 slots.
SmallVector<Value *, 3> Args{AddrAlloca, ZeroAlloca};
1389+
1390+
// Load structArg from global_args.
// The pointer stored in global_args is treated as an array of pointers; its
// element 0 is loaded and forwarded as the third argument.
1391+
if (UseArgStruct) {
1392+
Value *StructArg = Builder.CreateLoad(Builder.getPtrTy(), ArgsAlloca);
1393+
StructArg = Builder.CreateInBoundsGEP(Builder.getPtrTy(), StructArg,
1394+
{Builder.getInt64(0)});
1395+
StructArg = Builder.CreateLoad(Builder.getPtrTy(), StructArg, "structArg");
1396+
Args.push_back(StructArg);
1397+
}
1398+
1399+
// Call the outlined function holding the parallel body.
1400+
Builder.CreateCall(&OutlinedFn, Args);
1401+
Builder.CreateRetVoid();
1402+
1403+
return WrapperFn;
1404+
}
1405+
13261406
// Callback used to create OpenMP runtime calls to support
13271407
// omp parallel clause for the device.
13281408
// We need to use this callback to replace call to the OutlinedFn in OuterFn
@@ -1332,6 +1412,10 @@ static void targetParallelCallback(
13321412
BasicBlock *OuterAllocaBB, Value *Ident, Value *IfCondition,
13331413
Value *NumThreads, Instruction *PrivTID, AllocaInst *PrivTIDAddr,
13341414
Value *ThreadID, const SmallVector<Instruction *, 4> &ToBeDeleted) {
1415+
assert(OutlinedFn.arg_size() >= 2 &&
1416+
"Expected at least tid and bounded tid as arguments");
1417+
unsigned NumCapturedVars = OutlinedFn.arg_size() - /* tid & bounded tid */ 2;
1418+
13351419
// Add some known attributes.
13361420
IRBuilder<> &Builder = OMPIRBuilder->Builder;
13371421
OutlinedFn.addParamAttr(0, Attribute::NoAlias);
@@ -1340,17 +1424,12 @@ static void targetParallelCallback(
13401424
OutlinedFn.addParamAttr(1, Attribute::NoUndef);
13411425
OutlinedFn.addFnAttr(Attribute::NoUnwind);
13421426

1343-
assert(OutlinedFn.arg_size() >= 2 &&
1344-
"Expected at least tid and bounded tid as arguments");
1345-
unsigned NumCapturedVars = OutlinedFn.arg_size() - /* tid & bounded tid */ 2;
1346-
13471427
CallInst *CI = cast<CallInst>(OutlinedFn.user_back());
13481428
assert(CI && "Expected call instruction to outlined function");
13491429
CI->getParent()->setName("omp_parallel");
13501430

13511431
Builder.SetInsertPoint(CI);
13521432
Type *PtrTy = OMPIRBuilder->VoidPtr;
1353-
Value *NullPtrValue = Constant::getNullValue(PtrTy);
13541433

13551434
// Add alloca for kernel args
13561435
OpenMPIRBuilder ::InsertPointTy CurrentIP = Builder.saveIP();
@@ -1376,6 +1455,15 @@ static void targetParallelCallback(
13761455
IfCondition ? Builder.CreateSExtOrTrunc(IfCondition, OMPIRBuilder->Int32)
13771456
: Builder.getInt32(1);
13781457

1458+
// If this is not a Generic kernel, we can skip generating the wrapper.
1459+
std::optional<omp::OMPTgtExecModeFlags> ExecMode =
1460+
getTargetKernelExecMode(*OuterFn);
1461+
Value *WrapperFn;
1462+
if (ExecMode && *ExecMode & OMP_TGT_EXEC_MODE_GENERIC)
1463+
WrapperFn = createTargetParallelWrapper(OMPIRBuilder, OutlinedFn);
1464+
else
1465+
WrapperFn = Constant::getNullValue(PtrTy);
1466+
13791467
// Build kmpc_parallel_51 call
13801468
Value *Parallel51CallArgs[] = {
13811469
/* identifier*/ Ident,
@@ -1384,7 +1472,7 @@ static void targetParallelCallback(
13841472
/* number of threads */ NumThreads ? NumThreads : Builder.getInt32(-1),
13851473
/* Proc bind */ Builder.getInt32(-1),
13861474
/* outlined function */ &OutlinedFn,
1387-
/* wrapper function */ NullPtrValue,
1475+
/* wrapper function */ WrapperFn,
13881476
/* arguments of the outlined function*/ Args,
13891477
/* number of arguments */ Builder.getInt64(NumCapturedVars)};
13901478

mlir/test/Target/LLVMIR/omptarget-parallel-llvm.mlir

Lines changed: 22 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -69,7 +69,7 @@ module attributes {dlti.dl_spec = #dlti.dl_spec<#dlti.dl_entry<"dlti.alloca_memo
6969
// CHECK: store ptr %[[TMP6]], ptr %[[GEP_]], align 8
7070
// CHECK: %[[TMP7:.*]] = getelementptr inbounds [1 x ptr], ptr %[[TMP2]], i64 0, i64 0
7171
// CHECK: store ptr %[[STRUCTARG]], ptr %[[TMP7]], align 8
72-
// CHECK: call void @__kmpc_parallel_51(ptr addrspacecast (ptr addrspace(1) @[[GLOB1]] to ptr), i32 %[[OMP_GLOBAL_THREAD_NUM]], i32 1, i32 -1, i32 -1, ptr @[[FUNC1:.*]], ptr null, ptr %[[TMP2]], i64 1)
72+
// CHECK: call void @__kmpc_parallel_51(ptr addrspacecast (ptr addrspace(1) @[[GLOB1]] to ptr), i32 %[[OMP_GLOBAL_THREAD_NUM]], i32 1, i32 -1, i32 -1, ptr @[[FUNC1:.*]], ptr @[[FUNC1_WRAPPER:.*]], ptr %[[TMP2]], i64 1)
7373
// CHECK: call void @__kmpc_free_shared(ptr %[[STRUCTARG]], i64 8)
7474
// CHECK: call void @__kmpc_target_deinit()
7575

@@ -84,7 +84,7 @@ module attributes {dlti.dl_spec = #dlti.dl_spec<#dlti.dl_entry<"dlti.alloca_memo
8484
// CHECK: call void @__kmpc_parallel_51(ptr addrspacecast (
8585
// CHECK-SAME: ptr addrspace(1) @[[NUM_THREADS_GLOB:[0-9]+]] to ptr),
8686
// CHECK-SAME: i32 [[NUM_THREADS_TMP0:%.*]], i32 1, i32 156,
87-
// CHECK-SAME: i32 -1, ptr [[FUNC_NUM_THREADS1:@.*]], ptr null, ptr [[NUM_THREADS_TMP1:%.*]], i64 1)
87+
// CHECK-SAME: i32 -1, ptr @[[FUNC_NUM_THREADS1:.*]], ptr @[[FUNC2_WRAPPER:.*]], ptr [[NUM_THREADS_TMP1:%.*]], i64 1)
8888

8989
// One of the arguments of kmpc_parallel_51 function is responsible for handling if clause
9090
// of omp parallel construct for target region. If this argument is nonzero,
@@ -105,4 +105,23 @@ module attributes {dlti.dl_spec = #dlti.dl_spec<#dlti.dl_entry<"dlti.alloca_memo
105105
// CHECK: call void @__kmpc_parallel_51(ptr addrspacecast (
106106
// CHECK-SAME: ptr addrspace(1) {{.*}} to ptr),
107107
// CHECK-SAME: i32 {{.*}}, i32 %[[IFCOND_TMP4]], i32 -1,
108-
// CHECK-SAME: i32 -1, ptr {{.*}}, ptr null, ptr {{.*}}, i64 1)
108+
// CHECK-SAME: i32 -1, ptr {{.*}}, ptr {{.*}}, ptr {{.*}}, i64 1)
109+
110+
// CHECK: define internal void @[[FUNC1_WRAPPER]](i16 noundef zeroext %{{.*}}, i32 noundef %[[ADDR:.*]])
111+
// CHECK: %[[ADDR_ALLOCA:.*]] = alloca i32, align 4, addrspace(5)
112+
// CHECK: %[[ADDR_ASCAST:.*]] = addrspacecast ptr addrspace(5) %[[ADDR_ALLOCA]] to ptr
113+
// CHECK: %[[ZERO_ALLOCA:.*]] = alloca i32, align 4, addrspace(5)
114+
// CHECK: %[[ZERO_ASCAST:.*]] = addrspacecast ptr addrspace(5) %[[ZERO_ALLOCA]] to ptr
115+
// CHECK: %[[ARGS_ALLOCA:.*]] = alloca ptr, align 8, addrspace(5)
116+
// CHECK: %[[ARGS_ASCAST:.*]] = addrspacecast ptr addrspace(5) %[[ARGS_ALLOCA]] to ptr
117+
// CHECK: store i32 %[[ADDR]], ptr %[[ADDR_ASCAST]]
118+
// CHECK: store i32 0, ptr %[[ZERO_ASCAST]]
119+
// CHECK: call void @__kmpc_get_shared_variables(ptr %[[ARGS_ASCAST]])
120+
// CHECK: %[[LOAD_ARGS:.*]] = load ptr, ptr %[[ARGS_ASCAST]], align 8
121+
// CHECK: %[[FIRST_ARG:.*]] = getelementptr inbounds ptr, ptr %[[LOAD_ARGS]], i64 0
122+
// CHECK: %[[STRUCTARG:.*]] = load ptr, ptr %[[FIRST_ARG]], align 8
123+
// CHECK: call void @[[FUNC1]](ptr %[[ADDR_ASCAST]], ptr %[[ZERO_ASCAST]], ptr %[[STRUCTARG]])
124+
125+
// CHECK: define internal void @[[FUNC2_WRAPPER]](i16 noundef zeroext %{{.*}}, i32 noundef %{{.*}})
126+
// CHECK-NOT: define
127+
// CHECK: call void @[[FUNC_NUM_THREADS1]]({{.*}})

0 commit comments

Comments
 (0)