Skip to content

[MLIR][OpenMP] Support allocations of device shared memory #150924

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 1 commit into
base: users/skatrak/flang-generic-02-ompirbuilder-shared-mem
Choose a base branch
from

Conversation

skatrak
Copy link
Member

@skatrak skatrak commented Jul 28, 2025

This patch updates the allocation of some reduction and private variables within target regions to use device shared memory rather than private memory. This is a prerequisite to produce working Generic kernels containing parallel regions.

In particular, the following situations result in the usage of device shared memory (only when compiling for the target device if they are placed inside of a target region representing a Generic kernel):

  • Reduction variables on teams constructs.
  • Private variables on teams and distribute constructs that are reduced or used inside of a parallel region.

Currently, there is no support for delayed privatization on teams constructs, so private variables on these constructs won't currently be affected. When support is added, if it uses the existing allocatePrivateVars and cleanupPrivateVars functions, usage of device shared memory will be introduced automatically.

This patch updates the allocation of some reduction and private variables
within target regions to use device shared memory rather than private memory.
This is a prerequisite to produce working Generic kernels containing parallel
regions.

In particular, the following situations result in the usage of device shared
memory (only when compiling for the target device if they are placed inside of
a target region representing a Generic kernel):
- Reduction variables on `teams` constructs.
- Private variables on `teams` and `distribute` constructs that are reduced or
used inside of a `parallel` region.

Currently, there is no support for delayed privatization on `teams` constructs,
so private variables on these constructs won't currently be affected. When
support is added, if it uses the existing `allocatePrivateVars` and
`cleanupPrivateVars` functions, usage of device shared memory will be
introduced automatically.
@llvmbot
Copy link
Member

llvmbot commented Jul 28, 2025

@llvm/pr-subscribers-mlir
@llvm/pr-subscribers-mlir-openmp
@llvm/pr-subscribers-mlir-llvm

@llvm/pr-subscribers-flang-openmp

Author: Sergio Afonso (skatrak)

Changes

This patch updates the allocation of some reduction and private variables within target regions to use device shared memory rather than private memory. This is a prerequisite to produce working Generic kernels containing parallel regions.

In particular, the following situations result in the usage of device shared memory (only when compiling for the target device if they are placed inside of a target region representing a Generic kernel):

  • Reduction variables on teams constructs.
  • Private variables on teams and distribute constructs that are reduced or used inside of a parallel region.

Currently, there is no support for delayed privatization on teams constructs, so private variables on these constructs won't currently be affected. When support is added, if it uses the existing allocatePrivateVars and cleanupPrivateVars functions, usage of device shared memory will be introduced automatically.


Patch is 23.49 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/150924.diff

2 Files Affected:

  • (modified) mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp (+167-60)
  • (added) mlir/test/Target/LLVMIR/omptarget-device-shared-memory.mlir (+86)
diff --git a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
index 34358cdcece3c..c5a26cab553cf 100644
--- a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
+++ b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
@@ -1102,12 +1102,64 @@ struct DeferredStore {
 };
 } // namespace
 
+/// Check whether allocations for the given operation might potentially have to
+/// be done in device shared memory. That means we're compiling for a offloading
+/// target, the operation is an `omp::TargetOp` or nested inside of one and that
+/// target region represents a Generic (non-SPMD) kernel.
+///
+/// This represents a necessary but not sufficient set of conditions to use
+/// device shared memory in place of regular allocas. Depending on the variable,
+/// its uses or the associated OpenMP construct might also need to be taken into
+/// account.
+static bool
+mightAllocInDeviceSharedMemory(Operation &op,
+                               const llvm::OpenMPIRBuilder &ompBuilder) {
+  if (!ompBuilder.Config.isTargetDevice())
+    return false;
+
+  auto targetOp = dyn_cast<omp::TargetOp>(op);
+  if (!targetOp)
+    targetOp = op.getParentOfType<omp::TargetOp>();
+
+  return targetOp &&
+         !bitEnumContainsAny(
+             targetOp.getKernelExecFlags(targetOp.getInnermostCapturedOmpOp()),
+             omp::TargetRegionFlags::spmd);
+}
+
+/// Check whether the entry block argument representing the private copy of a
+/// variable in an OpenMP construct must be allocated in device shared memory,
+/// based on what the uses of that copy are.
+///
+/// This must only be called if a previous call to
+/// \c mightAllocInDeviceSharedMemory has already returned \c true for the
+/// operation that owns the specified block argument.
+static bool mustAllocPrivateVarInDeviceSharedMemory(BlockArgument value) {
+  Operation *parentOp = value.getOwner()->getParentOp();
+  auto targetOp = dyn_cast<omp::TargetOp>(parentOp);
+  if (!targetOp)
+    targetOp = parentOp->getParentOfType<omp::TargetOp>();
+  assert(targetOp && "expected a parent omp.target operation");
+
+  for (auto *user : value.getUsers()) {
+    if (auto parallelOp = dyn_cast<omp::ParallelOp>(user)) {
+      if (llvm::is_contained(parallelOp.getReductionVars(), value))
+        return true;
+    } else if (auto parallelOp = user->getParentOfType<omp::ParallelOp>()) {
+      if (targetOp->isProperAncestor(parallelOp))
+        return true;
+    }
+  }
+
+  return false;
+}
+
 /// Allocate space for privatized reduction variables.
 /// `deferredStores` contains information to create store operations which needs
 /// to be inserted after all allocas
 template <typename T>
 static LogicalResult
-allocReductionVars(T loop, ArrayRef<BlockArgument> reductionArgs,
+allocReductionVars(T op, ArrayRef<BlockArgument> reductionArgs,
                    llvm::IRBuilderBase &builder,
                    LLVM::ModuleTranslation &moduleTranslation,
                    const llvm::OpenMPIRBuilder::InsertPointTy &allocaIP,
@@ -1119,10 +1171,14 @@ allocReductionVars(T loop, ArrayRef<BlockArgument> reductionArgs,
   llvm::IRBuilderBase::InsertPointGuard guard(builder);
   builder.SetInsertPoint(allocaIP.getBlock()->getTerminator());
 
+  llvm::OpenMPIRBuilder *ompBuilder = moduleTranslation.getOpenMPBuilder();
+  bool useDeviceSharedMem =
+      isa<omp::TeamsOp>(op) && mightAllocInDeviceSharedMemory(*op, *ompBuilder);
+
   // delay creating stores until after all allocas
-  deferredStores.reserve(loop.getNumReductionVars());
+  deferredStores.reserve(op.getNumReductionVars());
 
-  for (std::size_t i = 0; i < loop.getNumReductionVars(); ++i) {
+  for (std::size_t i = 0; i < op.getNumReductionVars(); ++i) {
     Region &allocRegion = reductionDecls[i].getAllocRegion();
     if (isByRefs[i]) {
       if (allocRegion.empty())
@@ -1131,7 +1187,7 @@ allocReductionVars(T loop, ArrayRef<BlockArgument> reductionArgs,
       SmallVector<llvm::Value *, 1> phis;
       if (failed(inlineConvertOmpRegions(allocRegion, "omp.reduction.alloc",
                                          builder, moduleTranslation, &phis)))
-        return loop.emitError(
+        return op.emitError(
             "failed to inline `alloc` region of `omp.declare_reduction`");
 
       assert(phis.size() == 1 && "expected one allocation to be yielded");
@@ -1139,33 +1195,43 @@ allocReductionVars(T loop, ArrayRef<BlockArgument> reductionArgs,
 
       // Allocate reduction variable (which is a pointer to the real reduction
       // variable allocated in the inlined region)
-      llvm::Value *var = builder.CreateAlloca(
-          moduleTranslation.convertType(reductionDecls[i].getType()));
-
       llvm::Type *ptrTy = builder.getPtrTy();
-      llvm::Value *castVar =
-          builder.CreatePointerBitCastOrAddrSpaceCast(var, ptrTy);
+      llvm::Type *varTy =
+          moduleTranslation.convertType(reductionDecls[i].getType());
+      llvm::Value *var;
+      if (useDeviceSharedMem) {
+        var = ompBuilder->createOMPAllocShared(builder, varTy);
+      } else {
+        var = builder.CreateAlloca(varTy);
+        var = builder.CreatePointerBitCastOrAddrSpaceCast(var, ptrTy);
+      }
+
       llvm::Value *castPhi =
           builder.CreatePointerBitCastOrAddrSpaceCast(phis[0], ptrTy);
 
-      deferredStores.emplace_back(castPhi, castVar);
+      deferredStores.emplace_back(castPhi, var);
 
-      privateReductionVariables[i] = castVar;
+      privateReductionVariables[i] = var;
       moduleTranslation.mapValue(reductionArgs[i], castPhi);
-      reductionVariableMap.try_emplace(loop.getReductionVars()[i], castPhi);
+      reductionVariableMap.try_emplace(op.getReductionVars()[i], castPhi);
     } else {
       assert(allocRegion.empty() &&
              "allocaction is implicit for by-val reduction");
-      llvm::Value *var = builder.CreateAlloca(
-          moduleTranslation.convertType(reductionDecls[i].getType()));
 
       llvm::Type *ptrTy = builder.getPtrTy();
-      llvm::Value *castVar =
-          builder.CreatePointerBitCastOrAddrSpaceCast(var, ptrTy);
+      llvm::Type *varTy =
+          moduleTranslation.convertType(reductionDecls[i].getType());
+      llvm::Value *var;
+      if (useDeviceSharedMem) {
+        var = ompBuilder->createOMPAllocShared(builder, varTy);
+      } else {
+        var = builder.CreateAlloca(varTy);
+        var = builder.CreatePointerBitCastOrAddrSpaceCast(var, ptrTy);
+      }
 
-      moduleTranslation.mapValue(reductionArgs[i], castVar);
-      privateReductionVariables[i] = castVar;
-      reductionVariableMap.try_emplace(loop.getReductionVars()[i], castVar);
+      moduleTranslation.mapValue(reductionArgs[i], var);
+      privateReductionVariables[i] = var;
+      reductionVariableMap.try_emplace(op.getReductionVars()[i], var);
     }
   }
 
@@ -1227,6 +1293,10 @@ initReductionVars(OP op, ArrayRef<BlockArgument> reductionArgs,
   if (op.getNumReductionVars() == 0)
     return success();
 
+  llvm::OpenMPIRBuilder *ompBuilder = moduleTranslation.getOpenMPBuilder();
+  bool useDeviceSharedMem =
+      isa<omp::TeamsOp>(op) && mightAllocInDeviceSharedMemory(*op, *ompBuilder);
+
   llvm::BasicBlock *initBlock = splitBB(builder, true, "omp.reduction.init");
   auto allocaIP = llvm::IRBuilderBase::InsertPoint(
       latestAllocaBlock, latestAllocaBlock->getTerminator()->getIterator());
@@ -1241,8 +1311,12 @@ initReductionVars(OP op, ArrayRef<BlockArgument> reductionArgs,
       // TODO: remove after all users of by-ref are updated to use the alloc
       // region: Allocate reduction variable (which is a pointer to the real
       // reduciton variable allocated in the inlined region)
-      byRefVars[i] = builder.CreateAlloca(
-          moduleTranslation.convertType(reductionDecls[i].getType()));
+      llvm::Type *varTy =
+          moduleTranslation.convertType(reductionDecls[i].getType());
+      if (useDeviceSharedMem)
+        byRefVars[i] = ompBuilder->createOMPAllocShared(builder, varTy);
+      else
+        byRefVars[i] = builder.CreateAlloca(varTy);
     }
   }
 
@@ -1438,10 +1512,20 @@ static LogicalResult createReductionsAndCleanup(
                   [](omp::DeclareReductionOp reductionDecl) {
                     return &reductionDecl.getCleanupRegion();
                   });
-  return inlineOmpRegionCleanup(reductionRegions, privateReductionVariables,
-                                moduleTranslation, builder,
-                                "omp.reduction.cleanup");
-  return success();
+  LogicalResult result = inlineOmpRegionCleanup(
+      reductionRegions, privateReductionVariables, moduleTranslation, builder,
+      "omp.reduction.cleanup");
+
+  bool useDeviceSharedMem =
+      isa<omp::TeamsOp>(op) && mightAllocInDeviceSharedMemory(*op, *ompBuilder);
+  if (useDeviceSharedMem) {
+    for (auto [var, reductionDecl] :
+         llvm::zip_equal(privateReductionVariables, reductionDecls))
+      ompBuilder->createOMPFreeShared(
+          builder, var, moduleTranslation.convertType(reductionDecl.getType()));
+  }
+
+  return result;
 }
 
 static ArrayRef<bool> getIsByRef(std::optional<ArrayRef<bool>> attr) {
@@ -1586,8 +1670,9 @@ initPrivateVars(llvm::IRBuilderBase &builder,
 /// Allocate and initialize delayed private variables. Returns the basic block
 /// which comes after all of these allocations. llvm::Value * for each of these
 /// private variables are populated in llvmPrivateVars.
+template <typename T>
 static llvm::Expected<llvm::BasicBlock *>
-allocatePrivateVars(llvm::IRBuilderBase &builder,
+allocatePrivateVars(T op, llvm::IRBuilderBase &builder,
                     LLVM::ModuleTranslation &moduleTranslation,
                     PrivateVarsInfo &privateVarsInfo,
                     const llvm::OpenMPIRBuilder::InsertPointTy &allocaIP,
@@ -1610,6 +1695,10 @@ allocatePrivateVars(llvm::IRBuilderBase &builder,
   llvm::DataLayout dataLayout = builder.GetInsertBlock()->getDataLayout();
   llvm::BasicBlock *afterAllocas = allocaTerminator->getSuccessor(0);
 
+  llvm::OpenMPIRBuilder *ompBuilder = moduleTranslation.getOpenMPBuilder();
+  bool mightUseDeviceSharedMem =
+      isa<omp::TeamsOp, omp::DistributeOp>(*op) &&
+      mightAllocInDeviceSharedMemory(*op, *ompBuilder);
   unsigned int allocaAS =
       moduleTranslation.getLLVMModule()->getDataLayout().getAllocaAddrSpace();
   unsigned int defaultAS = moduleTranslation.getLLVMModule()
@@ -1622,11 +1711,17 @@ allocatePrivateVars(llvm::IRBuilderBase &builder,
     llvm::Type *llvmAllocType =
         moduleTranslation.convertType(privDecl.getType());
     builder.SetInsertPoint(allocaIP.getBlock()->getTerminator());
-    llvm::Value *llvmPrivateVar = builder.CreateAlloca(
-        llvmAllocType, /*ArraySize=*/nullptr, "omp.private.alloc");
-    if (allocaAS != defaultAS)
-      llvmPrivateVar = builder.CreateAddrSpaceCast(llvmPrivateVar,
-                                                   builder.getPtrTy(defaultAS));
+    llvm::Value *llvmPrivateVar = nullptr;
+    if (mightUseDeviceSharedMem &&
+        mustAllocPrivateVarInDeviceSharedMemory(blockArg)) {
+      llvmPrivateVar = ompBuilder->createOMPAllocShared(builder, llvmAllocType);
+    } else {
+      llvmPrivateVar = builder.CreateAlloca(
+          llvmAllocType, /*ArraySize=*/nullptr, "omp.private.alloc");
+      if (allocaAS != defaultAS)
+        llvmPrivateVar = builder.CreateAddrSpaceCast(
+            llvmPrivateVar, builder.getPtrTy(defaultAS));
+    }
 
     privateVarsInfo.llvmVars.push_back(llvmPrivateVar);
   }
@@ -1698,24 +1793,41 @@ static LogicalResult copyFirstPrivateVars(
   return success();
 }
 
+template <typename T>
 static LogicalResult
-cleanupPrivateVars(llvm::IRBuilderBase &builder,
+cleanupPrivateVars(T op, llvm::IRBuilderBase &builder,
                    LLVM::ModuleTranslation &moduleTranslation, Location loc,
-                   SmallVectorImpl<llvm::Value *> &llvmPrivateVars,
-                   SmallVectorImpl<omp::PrivateClauseOp> &privateDecls) {
+                   PrivateVarsInfo &privateVarsInfo) {
   // private variable deallocation
   SmallVector<Region *> privateCleanupRegions;
-  llvm::transform(privateDecls, std::back_inserter(privateCleanupRegions),
+  llvm::transform(privateVarsInfo.privatizers,
+                  std::back_inserter(privateCleanupRegions),
                   [](omp::PrivateClauseOp privatizer) {
                     return &privatizer.getDeallocRegion();
                   });
 
-  if (failed(inlineOmpRegionCleanup(
-          privateCleanupRegions, llvmPrivateVars, moduleTranslation, builder,
-          "omp.private.dealloc", /*shouldLoadCleanupRegionArg=*/false)))
+  if (failed(inlineOmpRegionCleanup(privateCleanupRegions,
+                                    privateVarsInfo.llvmVars, moduleTranslation,
+                                    builder, "omp.private.dealloc",
+                                    /*shouldLoadCleanupRegionArg=*/false)))
     return mlir::emitError(loc, "failed to inline `dealloc` region of an "
                                 "`omp.private` op in");
 
+  llvm::OpenMPIRBuilder *ompBuilder = moduleTranslation.getOpenMPBuilder();
+  bool mightUseDeviceSharedMem =
+      isa<omp::TeamsOp, omp::DistributeOp>(*op) &&
+      mightAllocInDeviceSharedMemory(*op, *ompBuilder);
+  for (auto [privDecl, llvmPrivVar, blockArg] :
+       llvm::zip_equal(privateVarsInfo.privatizers, privateVarsInfo.llvmVars,
+                       privateVarsInfo.blockArgs)) {
+    if (mightUseDeviceSharedMem &&
+        mustAllocPrivateVarInDeviceSharedMemory(blockArg)) {
+      ompBuilder->createOMPFreeShared(
+          builder, llvmPrivVar,
+          moduleTranslation.convertType(privDecl.getType()));
+    }
+  }
+
   return success();
 }
 
@@ -2382,9 +2494,8 @@ convertOmpTaskOp(omp::TaskOp taskOp, llvm::IRBuilderBase &builder,
 
     builder.SetInsertPoint(continuationBlockOrError.get()->getTerminator());
 
-    if (failed(cleanupPrivateVars(builder, moduleTranslation, taskOp.getLoc(),
-                                  privateVarsInfo.llvmVars,
-                                  privateVarsInfo.privatizers)))
+    if (failed(cleanupPrivateVars(taskOp, builder, moduleTranslation,
+                                  taskOp.getLoc(), privateVarsInfo)))
       return llvm::make_error<PreviouslyReportedError>();
 
     // Free heap allocated task context structure at the end of the task.
@@ -2501,7 +2612,7 @@ convertOmpWsloop(Operation &opInst, llvm::IRBuilderBase &builder,
       wsloopOp.getNumReductionVars());
 
   llvm::Expected<llvm::BasicBlock *> afterAllocas = allocatePrivateVars(
-      builder, moduleTranslation, privateVarsInfo, allocaIP);
+      wsloopOp, builder, moduleTranslation, privateVarsInfo, allocaIP);
   if (handleError(afterAllocas, opInst).failed())
     return failure();
 
@@ -2627,9 +2738,8 @@ convertOmpWsloop(Operation &opInst, llvm::IRBuilderBase &builder,
           /*isTeamsReduction=*/false)))
     return failure();
 
-  return cleanupPrivateVars(builder, moduleTranslation, wsloopOp.getLoc(),
-                            privateVarsInfo.llvmVars,
-                            privateVarsInfo.privatizers);
+  return cleanupPrivateVars(wsloopOp, builder, moduleTranslation,
+                            wsloopOp.getLoc(), privateVarsInfo);
 }
 
 /// Converts the OpenMP parallel operation to LLVM IR.
@@ -2656,7 +2766,7 @@ convertOmpParallel(omp::ParallelOp opInst, llvm::IRBuilderBase &builder,
   auto bodyGenCB = [&](InsertPointTy allocaIP,
                        InsertPointTy codeGenIP) -> llvm::Error {
     llvm::Expected<llvm::BasicBlock *> afterAllocas = allocatePrivateVars(
-        builder, moduleTranslation, privateVarsInfo, allocaIP);
+        opInst, builder, moduleTranslation, privateVarsInfo, allocaIP);
     if (handleError(afterAllocas, *opInst).failed())
       return llvm::make_error<PreviouslyReportedError>();
 
@@ -2770,9 +2880,8 @@ convertOmpParallel(omp::ParallelOp opInst, llvm::IRBuilderBase &builder,
       return llvm::createStringError(
           "failed to inline `cleanup` region of `omp.declare_reduction`");
 
-    if (failed(cleanupPrivateVars(builder, moduleTranslation, opInst.getLoc(),
-                                  privateVarsInfo.llvmVars,
-                                  privateVarsInfo.privatizers)))
+    if (failed(cleanupPrivateVars(opInst, builder, moduleTranslation,
+                                  opInst.getLoc(), privateVarsInfo)))
       return llvm::make_error<PreviouslyReportedError>();
 
     builder.restoreIP(oldIP);
@@ -2844,7 +2953,7 @@ convertOmpSimd(Operation &opInst, llvm::IRBuilderBase &builder,
       findAllocaInsertPoint(builder, moduleTranslation);
 
   llvm::Expected<llvm::BasicBlock *> afterAllocas = allocatePrivateVars(
-      builder, moduleTranslation, privateVarsInfo, allocaIP);
+      simdOp, builder, moduleTranslation, privateVarsInfo, allocaIP);
   if (handleError(afterAllocas, opInst).failed())
     return failure();
 
@@ -2958,9 +3067,8 @@ convertOmpSimd(Operation &opInst, llvm::IRBuilderBase &builder,
                                     "omp.reduction.cleanup")))
     return failure();
 
-  return cleanupPrivateVars(builder, moduleTranslation, simdOp.getLoc(),
-                            privateVarsInfo.llvmVars,
-                            privateVarsInfo.privatizers);
+  return cleanupPrivateVars(simdOp, builder, moduleTranslation, simdOp.getLoc(),
+                            privateVarsInfo);
 }
 
 /// Converts an OpenMP loop nest into LLVM IR using OpenMPIRBuilder.
@@ -4776,8 +4884,8 @@ convertOmpDistribute(Operation &opInst, llvm::IRBuilderBase &builder,
     builder.restoreIP(codeGenIP);
     PrivateVarsInfo privVarsInfo(distributeOp);
 
-    llvm::Expected<llvm::BasicBlock *> afterAllocas =
-        allocatePrivateVars(builder, moduleTranslation, privVarsInfo, allocaIP);
+    llvm::Expected<llvm::BasicBlock *> afterAllocas = allocatePrivateVars(
+        distributeOp, builder, moduleTranslation, privVarsInfo, allocaIP);
     if (handleError(afterAllocas, opInst).failed())
       return llvm::make_error<PreviouslyReportedError>();
 
@@ -4830,9 +4938,8 @@ convertOmpDistribute(Operation &opInst, llvm::IRBuilderBase &builder,
         return wsloopIP.takeError();
     }
 
-    if (failed(cleanupPrivateVars(builder, moduleTranslation,
-                                  distributeOp.getLoc(), privVarsInfo.llvmVars,
-                                  privVarsInfo.privatizers)))
+    if (failed(cleanupPrivateVars(distributeOp, builder, moduleTranslation,
+                                  distributeOp.getLoc(), privVarsInfo)))
       return llvm::make_error<PreviouslyReportedError>();
 
     return llvm::Error::success();
@@ -5555,8 +5662,8 @@ convertOmpTarget(Operation &opInst, llvm::IRBuilderBase &builder,
     PrivateVarsInfo privateVarsInfo(targetOp);
 
     llvm::Expected<llvm::BasicBlock *> afterAllocas =
-        allocatePrivateVars(builder, moduleTranslation, privateVarsInfo,
-                            allocaIP, &mappedPrivateVars);
+        allocatePrivateVars(targetOp, builder, moduleTranslation,
+                            privateVarsInfo, allocaIP, &mappedPrivateVars);
 
     if (failed(handleError(afterAllocas, *targetOp)))
       return llvm::make_error<PreviouslyReportedError>();
diff --git a/mlir/test/Target/LLVMIR/omptarget-device-shared-memory.mlir b/mlir/test/Target/LLVMIR/omptarget-device-shared-memory.mlir
new file mode 100644
index 0000000000000..0e08b771633c6
--- /dev/null
+++ b/mlir/test/Target/LLVMIR/omptarget-device-shared-memory.mlir
@@ -0,0 +1,86 @@
+// RUN: mlir-translate -mlir-to-llvmir %s | FileCheck %s
+
+// This test checks that, when compiling for an offloading target, device shared
+// memory will be used in place of allocas for certain private variables.
+
+module attributes {dlti.dl_spec = #dlti.dl_spec<#dlti.dl_entry<"dlti.alloca_memory_space", 5 : ui32>>, llvm.data_layout = "e-p:64:64-p1:64:64-p2:32:32-p3:32:32-p4:64:64-p5:32:32-p6:32:32-p7:160:256:256:32-p8:128:128:128:48-i64:64-v16:16-v24:32-v32:32-v48:64-v96:128-v192:256-v256:256-v512:512-v1024:1024-v2048:2048-n32:64-S32-A5-G1-ni:7:8", llvm.target_triple = "amdgcn-amd-amdhsa", omp.is_gpu = true, omp.is_target_device = true} {
+  omp.private {type = private} @privatizer : i32
+  omp.declare_reduction @reduction : i32 init {
+  ^bb0...
[truncated]

Copy link
Member

@Meinersbur Meinersbur left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM

if (!ompBuilder.Config.isTargetDevice())
return false;

auto targetOp = dyn_cast<omp::TargetOp>(op);
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
auto targetOp = dyn_cast<omp::TargetOp>(op);
auto *targetOp = dyn_cast<omp::TargetOp>(op);

/// operation that owns the specified block argument.
static bool mustAllocPrivateVarInDeviceSharedMemory(BlockArgument value) {
Operation *parentOp = value.getOwner()->getParentOp();
auto targetOp = dyn_cast<omp::TargetOp>(parentOp);
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
auto targetOp = dyn_cast<omp::TargetOp>(parentOp);
auto *targetOp = dyn_cast<omp::TargetOp>(parentOp);

assert(targetOp && "expected a parent omp.target operation");

for (auto *user : value.getUsers()) {
if (auto parallelOp = dyn_cast<omp::ParallelOp>(user)) {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
if (auto parallelOp = dyn_cast<omp::ParallelOp>(user)) {
if (auto *parallelOp = dyn_cast<omp::ParallelOp>(user)) {

Copy link
Contributor

@bhandarkar-pranav bhandarkar-pranav left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for this PR, @skatrak. One nit and one clarifying question on my part.

///
/// This represents a necessary but not sufficient set of conditions to use
/// device shared memory in place of regular allocas. Depending on the variable,
/// its uses or the associated OpenMP construct might also need to be taken into
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

- The 'or' in this sentence seems to be off. Did you mean "Depending on the variable and its uses, the associated OpenMP construct might need to be taken.."?

if (llvm::is_contained(parallelOp.getReductionVars(), value))
return true;
} else if (auto parallelOp = user->getParentOfType<omp::ParallelOp>()) {
if (targetOp->isProperAncestor(parallelOp))
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

There seems to be a hole in my understanding of this. At this point in the code, we know that value iis a BlockArgument. We know that it has an an ancestor, targetOp that is an omp::TargetOp. Should't all the users of a BlockArgument be such that they are dominated by the BlockArgument. Therefore targetOp should trivially be an ancestor of all users, no? All this is to say that the use of a BlockArgument inside a omp::ParallelOp should be enough and this ancestor check is superfluous. Unless, of course, I am missing something that is obvious.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

4 participants