-
Notifications
You must be signed in to change notification settings - Fork 14.7k
[OpenMP][OMPIRBuilder] Support parallel in Generic kernels #150926
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: users/skatrak/flang-generic-04-parallel-args
Are you sure you want to change the base?
[OpenMP][OMPIRBuilder] Support parallel in Generic kernels #150926
Conversation
@llvm/pr-subscribers-mlir-llvm @llvm/pr-subscribers-flang-openmp Author: Sergio Afonso (skatrak) ChangesThis patch introduces codegen logic to produce a wrapper function argument for the Full diff: https://github.com/llvm/llvm-project/pull/150926.diff 2 Files Affected:
diff --git a/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp b/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp
index a913958c0de9a..0005a72e86324 100644
--- a/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp
+++ b/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp
@@ -1323,6 +1323,86 @@ Error OpenMPIRBuilder::emitCancelationCheckImpl(
return Error::success();
}
+// Create wrapper function used to gather the outlined function's argument
+// structure from a shared buffer and to forward them to it when running in
+// Generic mode.
+//
+// The outlined function is expected to receive 2 integer arguments followed by
+// an optional pointer argument to an argument structure holding the rest.
+static Function *createTargetParallelWrapper(OpenMPIRBuilder *OMPIRBuilder,
+ Function &OutlinedFn) {
+ size_t NumArgs = OutlinedFn.arg_size();
+ assert((NumArgs == 2 || NumArgs == 3) &&
+ "expected a 2-3 argument parallel outlined function");
+ bool UseArgStruct = NumArgs == 3;
+
+ IRBuilder<> &Builder = OMPIRBuilder->Builder;
+ IRBuilder<>::InsertPointGuard IPG(Builder);
+ auto *FnTy = FunctionType::get(Builder.getVoidTy(),
+ {Builder.getInt16Ty(), Builder.getInt32Ty()},
+ /*isVarArg=*/false);
+ auto *WrapperFn =
+ Function::Create(FnTy, GlobalValue::InternalLinkage,
+ OutlinedFn.getName() + ".wrapper", OMPIRBuilder->M);
+
+ WrapperFn->addParamAttr(0, Attribute::NoUndef);
+ WrapperFn->addParamAttr(0, Attribute::ZExt);
+ WrapperFn->addParamAttr(1, Attribute::NoUndef);
+
+ BasicBlock *EntryBB =
+ BasicBlock::Create(OMPIRBuilder->M.getContext(), "entry", WrapperFn);
+ Builder.SetInsertPoint(EntryBB);
+
+ // Allocation.
+ Value *AddrAlloca = Builder.CreateAlloca(Builder.getInt32Ty(),
+ /*ArraySize=*/nullptr, "addr");
+ AddrAlloca = Builder.CreatePointerBitCastOrAddrSpaceCast(
+ AddrAlloca, Builder.getPtrTy(/*AddrSpace=*/0),
+ AddrAlloca->getName() + ".ascast");
+
+ Value *ZeroAlloca = Builder.CreateAlloca(Builder.getInt32Ty(),
+ /*ArraySize=*/nullptr, "zero");
+ ZeroAlloca = Builder.CreatePointerBitCastOrAddrSpaceCast(
+ ZeroAlloca, Builder.getPtrTy(/*AddrSpace=*/0),
+ ZeroAlloca->getName() + ".ascast");
+
+ Value *ArgsAlloca = nullptr;
+ if (UseArgStruct) {
+ ArgsAlloca = Builder.CreateAlloca(Builder.getPtrTy(),
+ /*ArraySize=*/nullptr, "global_args");
+ ArgsAlloca = Builder.CreatePointerBitCastOrAddrSpaceCast(
+ ArgsAlloca, Builder.getPtrTy(/*AddrSpace=*/0),
+ ArgsAlloca->getName() + ".ascast");
+ }
+
+ // Initialization.
+ Builder.CreateStore(WrapperFn->getArg(1), AddrAlloca);
+ Builder.CreateStore(Builder.getInt32(0), ZeroAlloca);
+ if (UseArgStruct) {
+ Builder.CreateCall(
+ OMPIRBuilder->getOrCreateRuntimeFunctionPtr(
+ llvm::omp::RuntimeFunction::OMPRTL___kmpc_get_shared_variables),
+ {ArgsAlloca});
+ }
+
+ SmallVector<Value *, 3> Args{AddrAlloca, ZeroAlloca};
+
+ // Load structArg from global_args.
+ if (UseArgStruct) {
+ Value *StructArg = Builder.CreateLoad(Builder.getPtrTy(), ArgsAlloca);
+ StructArg = Builder.CreateInBoundsGEP(Builder.getPtrTy(), StructArg,
+ {Builder.getInt64(0)});
+ StructArg = Builder.CreateLoad(Builder.getPtrTy(), StructArg, "structArg");
+ Args.push_back(StructArg);
+ }
+
+ // Call the outlined function holding the parallel body.
+ Builder.CreateCall(&OutlinedFn, Args);
+ Builder.CreateRetVoid();
+
+ return WrapperFn;
+}
+
// Callback used to create OpenMP runtime calls to support
// omp parallel clause for the device.
// We need to use this callback to replace call to the OutlinedFn in OuterFn
@@ -1332,6 +1412,10 @@ static void targetParallelCallback(
BasicBlock *OuterAllocaBB, Value *Ident, Value *IfCondition,
Value *NumThreads, Instruction *PrivTID, AllocaInst *PrivTIDAddr,
Value *ThreadID, const SmallVector<Instruction *, 4> &ToBeDeleted) {
+ assert(OutlinedFn.arg_size() >= 2 &&
+ "Expected at least tid and bounded tid as arguments");
+ unsigned NumCapturedVars = OutlinedFn.arg_size() - /* tid & bounded tid */ 2;
+
// Add some known attributes.
IRBuilder<> &Builder = OMPIRBuilder->Builder;
OutlinedFn.addParamAttr(0, Attribute::NoAlias);
@@ -1340,17 +1424,12 @@ static void targetParallelCallback(
OutlinedFn.addParamAttr(1, Attribute::NoUndef);
OutlinedFn.addFnAttr(Attribute::NoUnwind);
- assert(OutlinedFn.arg_size() >= 2 &&
- "Expected at least tid and bounded tid as arguments");
- unsigned NumCapturedVars = OutlinedFn.arg_size() - /* tid & bounded tid */ 2;
-
CallInst *CI = cast<CallInst>(OutlinedFn.user_back());
assert(CI && "Expected call instruction to outlined function");
CI->getParent()->setName("omp_parallel");
Builder.SetInsertPoint(CI);
Type *PtrTy = OMPIRBuilder->VoidPtr;
- Value *NullPtrValue = Constant::getNullValue(PtrTy);
// Add alloca for kernel args
OpenMPIRBuilder ::InsertPointTy CurrentIP = Builder.saveIP();
@@ -1376,6 +1455,15 @@ static void targetParallelCallback(
IfCondition ? Builder.CreateSExtOrTrunc(IfCondition, OMPIRBuilder->Int32)
: Builder.getInt32(1);
+ // If this is not a Generic kernel, we can skip generating the wrapper.
+ std::optional<omp::OMPTgtExecModeFlags> ExecMode =
+ getTargetKernelExecMode(*OuterFn);
+ Value *WrapperFn;
+ if (ExecMode && *ExecMode & OMP_TGT_EXEC_MODE_GENERIC)
+ WrapperFn = createTargetParallelWrapper(OMPIRBuilder, OutlinedFn);
+ else
+ WrapperFn = Constant::getNullValue(PtrTy);
+
// Build kmpc_parallel_51 call
Value *Parallel51CallArgs[] = {
/* identifier*/ Ident,
@@ -1384,7 +1472,7 @@ static void targetParallelCallback(
/* number of threads */ NumThreads ? NumThreads : Builder.getInt32(-1),
/* Proc bind */ Builder.getInt32(-1),
/* outlined function */ &OutlinedFn,
- /* wrapper function */ NullPtrValue,
+ /* wrapper function */ WrapperFn,
/* arguments of the outlined funciton*/ Args,
/* number of arguments */ Builder.getInt64(NumCapturedVars)};
diff --git a/mlir/test/Target/LLVMIR/omptarget-parallel-llvm.mlir b/mlir/test/Target/LLVMIR/omptarget-parallel-llvm.mlir
index 504e39c96f008..ca998b4672ba0 100644
--- a/mlir/test/Target/LLVMIR/omptarget-parallel-llvm.mlir
+++ b/mlir/test/Target/LLVMIR/omptarget-parallel-llvm.mlir
@@ -69,7 +69,7 @@ module attributes {dlti.dl_spec = #dlti.dl_spec<#dlti.dl_entry<"dlti.alloca_memo
// CHECK: store ptr %[[TMP6]], ptr %[[GEP_]], align 8
// CHECK: %[[TMP7:.*]] = getelementptr inbounds [1 x ptr], ptr %[[TMP2]], i64 0, i64 0
// CHECK: store ptr %[[STRUCTARG]], ptr %[[TMP7]], align 8
-// CHECK: call void @__kmpc_parallel_51(ptr addrspacecast (ptr addrspace(1) @[[GLOB1]] to ptr), i32 %[[OMP_GLOBAL_THREAD_NUM]], i32 1, i32 -1, i32 -1, ptr @[[FUNC1:.*]], ptr null, ptr %[[TMP2]], i64 1)
+// CHECK: call void @__kmpc_parallel_51(ptr addrspacecast (ptr addrspace(1) @[[GLOB1]] to ptr), i32 %[[OMP_GLOBAL_THREAD_NUM]], i32 1, i32 -1, i32 -1, ptr @[[FUNC1:.*]], ptr @[[FUNC1_WRAPPER:.*]], ptr %[[TMP2]], i64 1)
// CHECK: call void @__kmpc_free_shared(ptr %[[STRUCTARG]], i64 8)
// CHECK: call void @__kmpc_target_deinit()
@@ -84,7 +84,7 @@ module attributes {dlti.dl_spec = #dlti.dl_spec<#dlti.dl_entry<"dlti.alloca_memo
// CHECK: call void @__kmpc_parallel_51(ptr addrspacecast (
// CHECK-SAME: ptr addrspace(1) @[[NUM_THREADS_GLOB:[0-9]+]] to ptr),
// CHECK-SAME: i32 [[NUM_THREADS_TMP0:%.*]], i32 1, i32 156,
-// CHECK-SAME: i32 -1, ptr [[FUNC_NUM_THREADS1:@.*]], ptr null, ptr [[NUM_THREADS_TMP1:%.*]], i64 1)
+// CHECK-SAME: i32 -1, ptr @[[FUNC_NUM_THREADS1:.*]], ptr @[[FUNC2_WRAPPER:.*]], ptr [[NUM_THREADS_TMP1:%.*]], i64 1)
// One of the arguments of kmpc_parallel_51 function is responsible for handling if clause
// of omp parallel construct for target region. If this argument is nonzero,
@@ -105,4 +105,23 @@ module attributes {dlti.dl_spec = #dlti.dl_spec<#dlti.dl_entry<"dlti.alloca_memo
// CHECK: call void @__kmpc_parallel_51(ptr addrspacecast (
// CHECK-SAME: ptr addrspace(1) {{.*}} to ptr),
// CHECK-SAME: i32 {{.*}}, i32 %[[IFCOND_TMP4]], i32 -1,
-// CHECK-SAME: i32 -1, ptr {{.*}}, ptr null, ptr {{.*}}, i64 1)
+// CHECK-SAME: i32 -1, ptr {{.*}}, ptr {{.*}}, ptr {{.*}}, i64 1)
+
+// CHECK: define internal void @[[FUNC1_WRAPPER]](i16 noundef zeroext %{{.*}}, i32 noundef %[[ADDR:.*]])
+// CHECK: %[[ADDR_ALLOCA:.*]] = alloca i32, align 4, addrspace(5)
+// CHECK: %[[ADDR_ASCAST:.*]] = addrspacecast ptr addrspace(5) %[[ADDR_ALLOCA]] to ptr
+// CHECK: %[[ZERO_ALLOCA:.*]] = alloca i32, align 4, addrspace(5)
+// CHECK: %[[ZERO_ASCAST:.*]] = addrspacecast ptr addrspace(5) %[[ZERO_ALLOCA]] to ptr
+// CHECK: %[[ARGS_ALLOCA:.*]] = alloca ptr, align 8, addrspace(5)
+// CHECK: %[[ARGS_ASCAST:.*]] = addrspacecast ptr addrspace(5) %[[ARGS_ALLOCA]] to ptr
+// CHECK: store i32 %[[ADDR]], ptr %[[ADDR_ASCAST]]
+// CHECK: store i32 0, ptr %[[ZERO_ASCAST]]
+// CHECK: call void @__kmpc_get_shared_variables(ptr %[[ARGS_ASCAST]])
+// CHECK: %[[LOAD_ARGS:.*]] = load ptr, ptr %[[ARGS_ASCAST]], align 8
+// CHECK: %[[FIRST_ARG:.*]] = getelementptr inbounds ptr, ptr %[[LOAD_ARGS]], i64 0
+// CHECK: %[[STRUCTARG:.*]] = load ptr, ptr %[[FIRST_ARG]], align 8
+// CHECK: call void @[[FUNC1]](ptr %[[ADDR_ASCAST]], ptr %[[ZERO_ASCAST]], ptr %[[STRUCTARG]])
+
+// CHECK: define internal void @[[FUNC2_WRAPPER]](i16 noundef zeroext %{{.*}}, i32 noundef %{{.*}})
+// CHECK-NOT: define
+// CHECK: call void @[[FUNC_NUM_THREADS1]]({{.*}})
|
6a97ff2
to
8b34402
Compare
8771c0f
to
b6e9849
Compare
This patch introduces codegen logic to produce a wrapper function argument for the `__kmpc_parallel_51` DeviceRTL function needed to handle arguments passed using device shared memory in Generic mode.
b6e9849
to
6a81001
Compare
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM
// Create wrapper function used to gather the outlined function's argument | ||
// structure from a shared buffer and to forward them to it when running in | ||
// Generic mode. | ||
// | ||
// The outlined function is expected to receive 2 integer arguments followed by | ||
// an optional pointer argument to an argument structure holding the rest. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Consider doxygen ///
coments
std::optional<omp::OMPTgtExecModeFlags> ExecMode = | ||
getTargetKernelExecMode(*OuterFn); | ||
Value *WrapperFn; | ||
if (ExecMode && *ExecMode & OMP_TGT_EXEC_MODE_GENERIC) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
if (ExecMode && *ExecMode & OMP_TGT_EXEC_MODE_GENERIC) | |
if (ExecMode && (*ExecMode & OMP_TGT_EXEC_MODE_GENERIC)) |
[suggestion]
This patch introduces codegen logic to produce a wrapper function argument for the
__kmpc_parallel_51
DeviceRTL function needed to handle arguments passed using device shared memory in Generic mode.