Skip to content

[OpenMP][OMPIRBuilder] Support parallel in Generic kernels #150926

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 1 commit into
base: users/skatrak/flang-generic-04-parallel-args
Choose a base branch
from

Conversation

skatrak
Copy link
Member

@skatrak skatrak commented Jul 28, 2025

This patch introduces codegen logic to produce a wrapper function argument for the __kmpc_parallel_51 DeviceRTL function needed to handle arguments passed using device shared memory in Generic mode.

@llvmbot
Copy link
Member

llvmbot commented Jul 28, 2025

@llvm/pr-subscribers-mlir-llvm

@llvm/pr-subscribers-flang-openmp

Author: Sergio Afonso (skatrak)

Changes

This patch introduces codegen logic to produce a wrapper function argument for the __kmpc_parallel_51 DeviceRTL function needed to handle arguments passed using device shared memory in Generic mode.


Full diff: https://github.com/llvm/llvm-project/pull/150926.diff

2 Files Affected:

  • (modified) llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp (+94-6)
  • (modified) mlir/test/Target/LLVMIR/omptarget-parallel-llvm.mlir (+22-3)
diff --git a/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp b/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp
index a913958c0de9a..0005a72e86324 100644
--- a/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp
+++ b/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp
@@ -1323,6 +1323,86 @@ Error OpenMPIRBuilder::emitCancelationCheckImpl(
   return Error::success();
 }
 
+// Create wrapper function used to gather the outlined function's argument
+// structure from a shared buffer and to forward them to it when running in
+// Generic mode.
+//
+// The outlined function is expected to receive 2 integer arguments followed by
+// an optional pointer argument to an argument structure holding the rest.
+static Function *createTargetParallelWrapper(OpenMPIRBuilder *OMPIRBuilder,
+                                             Function &OutlinedFn) {
+  size_t NumArgs = OutlinedFn.arg_size();
+  assert((NumArgs == 2 || NumArgs == 3) &&
+         "expected a 2-3 argument parallel outlined function");
+  bool UseArgStruct = NumArgs == 3;
+
+  IRBuilder<> &Builder = OMPIRBuilder->Builder;
+  IRBuilder<>::InsertPointGuard IPG(Builder);
+  auto *FnTy = FunctionType::get(Builder.getVoidTy(),
+                                 {Builder.getInt16Ty(), Builder.getInt32Ty()},
+                                 /*isVarArg=*/false);
+  auto *WrapperFn =
+      Function::Create(FnTy, GlobalValue::InternalLinkage,
+                       OutlinedFn.getName() + ".wrapper", OMPIRBuilder->M);
+
+  WrapperFn->addParamAttr(0, Attribute::NoUndef);
+  WrapperFn->addParamAttr(0, Attribute::ZExt);
+  WrapperFn->addParamAttr(1, Attribute::NoUndef);
+
+  BasicBlock *EntryBB =
+      BasicBlock::Create(OMPIRBuilder->M.getContext(), "entry", WrapperFn);
+  Builder.SetInsertPoint(EntryBB);
+
+  // Allocation.
+  Value *AddrAlloca = Builder.CreateAlloca(Builder.getInt32Ty(),
+                                           /*ArraySize=*/nullptr, "addr");
+  AddrAlloca = Builder.CreatePointerBitCastOrAddrSpaceCast(
+      AddrAlloca, Builder.getPtrTy(/*AddrSpace=*/0),
+      AddrAlloca->getName() + ".ascast");
+
+  Value *ZeroAlloca = Builder.CreateAlloca(Builder.getInt32Ty(),
+                                           /*ArraySize=*/nullptr, "zero");
+  ZeroAlloca = Builder.CreatePointerBitCastOrAddrSpaceCast(
+      ZeroAlloca, Builder.getPtrTy(/*AddrSpace=*/0),
+      ZeroAlloca->getName() + ".ascast");
+
+  Value *ArgsAlloca = nullptr;
+  if (UseArgStruct) {
+    ArgsAlloca = Builder.CreateAlloca(Builder.getPtrTy(),
+                                      /*ArraySize=*/nullptr, "global_args");
+    ArgsAlloca = Builder.CreatePointerBitCastOrAddrSpaceCast(
+        ArgsAlloca, Builder.getPtrTy(/*AddrSpace=*/0),
+        ArgsAlloca->getName() + ".ascast");
+  }
+
+  // Initialization.
+  Builder.CreateStore(WrapperFn->getArg(1), AddrAlloca);
+  Builder.CreateStore(Builder.getInt32(0), ZeroAlloca);
+  if (UseArgStruct) {
+    Builder.CreateCall(
+        OMPIRBuilder->getOrCreateRuntimeFunctionPtr(
+            llvm::omp::RuntimeFunction::OMPRTL___kmpc_get_shared_variables),
+        {ArgsAlloca});
+  }
+
+  SmallVector<Value *, 3> Args{AddrAlloca, ZeroAlloca};
+
+  // Load structArg from global_args.
+  if (UseArgStruct) {
+    Value *StructArg = Builder.CreateLoad(Builder.getPtrTy(), ArgsAlloca);
+    StructArg = Builder.CreateInBoundsGEP(Builder.getPtrTy(), StructArg,
+                                          {Builder.getInt64(0)});
+    StructArg = Builder.CreateLoad(Builder.getPtrTy(), StructArg, "structArg");
+    Args.push_back(StructArg);
+  }
+
+  // Call the outlined function holding the parallel body.
+  Builder.CreateCall(&OutlinedFn, Args);
+  Builder.CreateRetVoid();
+
+  return WrapperFn;
+}
+
 // Callback used to create OpenMP runtime calls to support
 // omp parallel clause for the device.
 // We need to use this callback to replace call to the OutlinedFn in OuterFn
@@ -1332,6 +1412,10 @@ static void targetParallelCallback(
     BasicBlock *OuterAllocaBB, Value *Ident, Value *IfCondition,
     Value *NumThreads, Instruction *PrivTID, AllocaInst *PrivTIDAddr,
     Value *ThreadID, const SmallVector<Instruction *, 4> &ToBeDeleted) {
+  assert(OutlinedFn.arg_size() >= 2 &&
+         "Expected at least tid and bounded tid as arguments");
+  unsigned NumCapturedVars = OutlinedFn.arg_size() - /* tid & bounded tid */ 2;
+
   // Add some known attributes.
   IRBuilder<> &Builder = OMPIRBuilder->Builder;
   OutlinedFn.addParamAttr(0, Attribute::NoAlias);
@@ -1340,17 +1424,12 @@ static void targetParallelCallback(
   OutlinedFn.addParamAttr(1, Attribute::NoUndef);
   OutlinedFn.addFnAttr(Attribute::NoUnwind);
 
-  assert(OutlinedFn.arg_size() >= 2 &&
-         "Expected at least tid and bounded tid as arguments");
-  unsigned NumCapturedVars = OutlinedFn.arg_size() - /* tid & bounded tid */ 2;
-
   CallInst *CI = cast<CallInst>(OutlinedFn.user_back());
   assert(CI && "Expected call instruction to outlined function");
   CI->getParent()->setName("omp_parallel");
 
   Builder.SetInsertPoint(CI);
   Type *PtrTy = OMPIRBuilder->VoidPtr;
-  Value *NullPtrValue = Constant::getNullValue(PtrTy);
 
   // Add alloca for kernel args
   OpenMPIRBuilder ::InsertPointTy CurrentIP = Builder.saveIP();
@@ -1376,6 +1455,15 @@ static void targetParallelCallback(
       IfCondition ? Builder.CreateSExtOrTrunc(IfCondition, OMPIRBuilder->Int32)
                   : Builder.getInt32(1);
 
+  // If this is not a Generic kernel, we can skip generating the wrapper.
+  std::optional<omp::OMPTgtExecModeFlags> ExecMode =
+      getTargetKernelExecMode(*OuterFn);
+  Value *WrapperFn;
+  if (ExecMode && *ExecMode & OMP_TGT_EXEC_MODE_GENERIC)
+    WrapperFn = createTargetParallelWrapper(OMPIRBuilder, OutlinedFn);
+  else
+    WrapperFn = Constant::getNullValue(PtrTy);
+
   // Build kmpc_parallel_51 call
   Value *Parallel51CallArgs[] = {
       /* identifier*/ Ident,
@@ -1384,7 +1472,7 @@ static void targetParallelCallback(
       /* number of threads */ NumThreads ? NumThreads : Builder.getInt32(-1),
       /* Proc bind */ Builder.getInt32(-1),
       /* outlined function */ &OutlinedFn,
-      /* wrapper function */ NullPtrValue,
+      /* wrapper function */ WrapperFn,
       /* arguments of the outlined funciton*/ Args,
       /* number of arguments */ Builder.getInt64(NumCapturedVars)};
 
diff --git a/mlir/test/Target/LLVMIR/omptarget-parallel-llvm.mlir b/mlir/test/Target/LLVMIR/omptarget-parallel-llvm.mlir
index 504e39c96f008..ca998b4672ba0 100644
--- a/mlir/test/Target/LLVMIR/omptarget-parallel-llvm.mlir
+++ b/mlir/test/Target/LLVMIR/omptarget-parallel-llvm.mlir
@@ -69,7 +69,7 @@ module attributes {dlti.dl_spec = #dlti.dl_spec<#dlti.dl_entry<"dlti.alloca_memo
 // CHECK:         store ptr %[[TMP6]], ptr %[[GEP_]], align 8
 // CHECK:         %[[TMP7:.*]] = getelementptr inbounds [1 x ptr], ptr %[[TMP2]], i64 0, i64 0
 // CHECK:         store ptr %[[STRUCTARG]], ptr %[[TMP7]], align 8
-// CHECK:         call void @__kmpc_parallel_51(ptr addrspacecast (ptr addrspace(1) @[[GLOB1]] to ptr), i32 %[[OMP_GLOBAL_THREAD_NUM]], i32 1, i32 -1, i32 -1, ptr @[[FUNC1:.*]], ptr null, ptr %[[TMP2]], i64 1)
+// CHECK:         call void @__kmpc_parallel_51(ptr addrspacecast (ptr addrspace(1) @[[GLOB1]] to ptr), i32 %[[OMP_GLOBAL_THREAD_NUM]], i32 1, i32 -1, i32 -1, ptr @[[FUNC1:.*]], ptr @[[FUNC1_WRAPPER:.*]], ptr %[[TMP2]], i64 1)
 // CHECK:         call void @__kmpc_free_shared(ptr %[[STRUCTARG]], i64 8)
 // CHECK:         call void @__kmpc_target_deinit()
 
@@ -84,7 +84,7 @@ module attributes {dlti.dl_spec = #dlti.dl_spec<#dlti.dl_entry<"dlti.alloca_memo
 // CHECK:         call void @__kmpc_parallel_51(ptr addrspacecast (
 // CHECK-SAME:  ptr addrspace(1) @[[NUM_THREADS_GLOB:[0-9]+]] to ptr),
 // CHECK-SAME:  i32 [[NUM_THREADS_TMP0:%.*]], i32 1, i32 156,
-// CHECK-SAME:  i32 -1,  ptr [[FUNC_NUM_THREADS1:@.*]], ptr null, ptr [[NUM_THREADS_TMP1:%.*]], i64 1)
+// CHECK-SAME:  i32 -1, ptr @[[FUNC_NUM_THREADS1:.*]], ptr @[[FUNC2_WRAPPER:.*]], ptr [[NUM_THREADS_TMP1:%.*]], i64 1)
 
 // One of the arguments of  kmpc_parallel_51 function is responsible for handling if clause
 // of omp parallel construct for target region. If this  argument is nonzero,
@@ -105,4 +105,23 @@ module attributes {dlti.dl_spec = #dlti.dl_spec<#dlti.dl_entry<"dlti.alloca_memo
 // CHECK:         call void @__kmpc_parallel_51(ptr addrspacecast (
 // CHECK-SAME:  ptr addrspace(1) {{.*}} to ptr),
 // CHECK-SAME:  i32 {{.*}}, i32 %[[IFCOND_TMP4]], i32 -1,
-// CHECK-SAME:  i32 -1,  ptr {{.*}}, ptr null, ptr {{.*}}, i64 1)
+// CHECK-SAME:  i32 -1,  ptr {{.*}}, ptr {{.*}}, ptr {{.*}}, i64 1)
+
+// CHECK: define internal void @[[FUNC1_WRAPPER]](i16 noundef zeroext %{{.*}}, i32 noundef %[[ADDR:.*]])
+// CHECK: %[[ADDR_ALLOCA:.*]] = alloca i32, align 4, addrspace(5)
+// CHECK: %[[ADDR_ASCAST:.*]] = addrspacecast ptr addrspace(5) %[[ADDR_ALLOCA]] to ptr
+// CHECK: %[[ZERO_ALLOCA:.*]] = alloca i32, align 4, addrspace(5)
+// CHECK: %[[ZERO_ASCAST:.*]] = addrspacecast ptr addrspace(5) %[[ZERO_ALLOCA]] to ptr
+// CHECK: %[[ARGS_ALLOCA:.*]] = alloca ptr, align 8, addrspace(5)
+// CHECK: %[[ARGS_ASCAST:.*]] = addrspacecast ptr addrspace(5) %[[ARGS_ALLOCA]] to ptr
+// CHECK: store i32 %[[ADDR]], ptr %[[ADDR_ASCAST]]
+// CHECK: store i32 0, ptr %[[ZERO_ASCAST]]
+// CHECK: call void @__kmpc_get_shared_variables(ptr %[[ARGS_ASCAST]])
+// CHECK: %[[LOAD_ARGS:.*]] = load ptr, ptr %[[ARGS_ASCAST]], align 8
+// CHECK: %[[FIRST_ARG:.*]] = getelementptr inbounds ptr, ptr %[[LOAD_ARGS]], i64 0
+// CHECK: %[[STRUCTARG:.*]] = load ptr, ptr %[[FIRST_ARG]], align 8
+// CHECK: call void @[[FUNC1]](ptr %[[ADDR_ASCAST]], ptr %[[ZERO_ASCAST]], ptr %[[STRUCTARG]])
+
+// CHECK: define internal void @[[FUNC2_WRAPPER]](i16 noundef zeroext %{{.*}}, i32 noundef %{{.*}})
+// CHECK-NOT: define
+// CHECK: call void @[[FUNC_NUM_THREADS1]]({{.*}})

@skatrak skatrak force-pushed the users/skatrak/flang-generic-04-parallel-args branch from 6a97ff2 to 8b34402 Compare July 28, 2025 11:22
@skatrak skatrak force-pushed the users/skatrak/flang-generic-05-parallel-wrapper branch from 8771c0f to b6e9849 Compare July 28, 2025 11:23
This patch introduces codegen logic to produce a wrapper function argument for
the `__kmpc_parallel_51` DeviceRTL function needed to handle arguments passed
using device shared memory in Generic mode.
@skatrak skatrak force-pushed the users/skatrak/flang-generic-05-parallel-wrapper branch from b6e9849 to 6a81001 Compare August 6, 2025 12:54
Copy link
Member

@Meinersbur Meinersbur left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM

Comment on lines +1326 to +1331
// Create wrapper function used to gather the outlined function's argument
// structure from a shared buffer and to forward them to it when running in
// Generic mode.
//
// The outlined function is expected to receive 2 integer arguments followed by
// an optional pointer argument to an argument structure holding the rest.
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Consider doxygen /// coments

std::optional<omp::OMPTgtExecModeFlags> ExecMode =
getTargetKernelExecMode(*OuterFn);
Value *WrapperFn;
if (ExecMode && *ExecMode & OMP_TGT_EXEC_MODE_GENERIC)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
if (ExecMode && *ExecMode & OMP_TGT_EXEC_MODE_GENERIC)
if (ExecMode && (*ExecMode & OMP_TGT_EXEC_MODE_GENERIC))

[suggestion]

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants