Skip to content

[MLIR][OpenMP] Remove Generic-SPMD early detection #150922

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 1 commit into
base: main
Choose a base branch
from

Conversation

skatrak
Copy link
Member

@skatrak skatrak commented Jul 28, 2025

This patch removes logic from MLIR to attempt identifying Generic kernels that could be executed in SPMD mode.

This optimization is done by the OpenMPOpt pass for Clang and is only required here to circumvent missing support for the new DeviceRTL APIs used in MLIR to LLVM IR translation that Clang doesn't currently use (e.g. kmpc_distribute_static_loop ). Removing checks in MLIR avoids duplicating the logic that should be centralized in the OpenMPOpt pass.

Additionally, offloading kernels currently compiled through the OpenMP dialect fail to run parallel regions properly when in Generic mode. By disabling early detection, this issue becomes apparent for a range of kernels where this was masked by having them run in SPMD mode.

This patch removes logic from MLIR to attempt identifying Generic kernels that
could be executed in SPMD mode.

This optimization is done by the OpenMPOpt pass for Clang and is only required
here to circumvent missing support for the new DeviceRTL APIs used in MLIR to
LLVM IR translation that Clang doesn't currently use (e.g.
`kmpc_distribute_static_loop` ). Removing checks in MLIR avoids duplicating the
logic that should be centralized in the OpenMPOpt pass.

Additionally, offloading kernels currently compiled through the OpenMP dialect
fail to run parallel regions properly when in Generic mode. By disabling early
detection, this issue becomes apparent for a range of kernels where this was
masked by having them run in SPMD mode.
@llvmbot
Copy link
Member

llvmbot commented Jul 28, 2025

@llvm/pr-subscribers-mlir
@llvm/pr-subscribers-mlir-openmp

@llvm/pr-subscribers-flang-openmp

Author: Sergio Afonso (skatrak)

Changes

This patch removes logic from MLIR to attempt identifying Generic kernels that could be executed in SPMD mode.

This optimization is done by the OpenMPOpt pass for Clang and is only required here to circumvent missing support for the new DeviceRTL APIs used in MLIR to LLVM IR translation that Clang doesn't currently use (e.g. kmpc_distribute_static_loop ). Removing checks in MLIR avoids duplicating the logic that should be centralized in the OpenMPOpt pass.

Additionally, offloading kernels currently compiled through the OpenMP dialect fail to run parallel regions properly when in Generic mode. By disabling early detection, this issue becomes apparent for a range of kernels where this was masked by having them run in SPMD mode.


Full diff: https://github.com/llvm/llvm-project/pull/150922.diff

4 Files Affected:

  • (modified) mlir/include/mlir/Dialect/OpenMP/OpenMPEnums.td (+2-4)
  • (modified) mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp (+11-37)
  • (modified) mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp (+3-9)
  • (modified) mlir/test/Target/LLVMIR/openmp-target-generic-spmd.mlir (+1-1)
diff --git a/mlir/include/mlir/Dialect/OpenMP/OpenMPEnums.td b/mlir/include/mlir/Dialect/OpenMP/OpenMPEnums.td
index c080c3fac87d4..ce0ebabd58125 100644
--- a/mlir/include/mlir/Dialect/OpenMP/OpenMPEnums.td
+++ b/mlir/include/mlir/Dialect/OpenMP/OpenMPEnums.td
@@ -227,15 +227,13 @@ def ScheduleModifierAttr : OpenMP_EnumAttr<ScheduleModifier, "sched_mod">;
 //===----------------------------------------------------------------------===//
 
 def TargetRegionFlagsNone : I32BitEnumAttrCaseNone<"none">;
-def TargetRegionFlagsGeneric : I32BitEnumAttrCaseBit<"generic", 0>;
-def TargetRegionFlagsSpmd : I32BitEnumAttrCaseBit<"spmd", 1>;
-def TargetRegionFlagsTripCount : I32BitEnumAttrCaseBit<"trip_count", 2>;
+def TargetRegionFlagsSpmd : I32BitEnumAttrCaseBit<"spmd", 0>;
+def TargetRegionFlagsTripCount : I32BitEnumAttrCaseBit<"trip_count", 1>;
 
 def TargetRegionFlags : OpenMP_BitEnumAttr<
     "TargetRegionFlags",
     "target region property flags", [
       TargetRegionFlagsNone,
-      TargetRegionFlagsGeneric,
       TargetRegionFlagsSpmd,
       TargetRegionFlagsTripCount
     ]>;
diff --git a/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp b/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
index c1c1767ef90b0..8854e908c71f3 100644
--- a/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
+++ b/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
@@ -2117,7 +2117,7 @@ TargetRegionFlags TargetOp::getKernelExecFlags(Operation *capturedOp) {
 
   // If it's not capturing a loop, it's a default target region.
   if (!isa_and_present<LoopNestOp>(capturedOp))
-    return TargetRegionFlags::generic;
+    return TargetRegionFlags::none;
 
   // Get the innermost non-simd loop wrapper.
   SmallVector<LoopWrapperInterface> loopWrappers;
@@ -2130,24 +2130,24 @@ TargetRegionFlags TargetOp::getKernelExecFlags(Operation *capturedOp) {
 
   auto numWrappers = std::distance(innermostWrapper, loopWrappers.end());
   if (numWrappers != 1 && numWrappers != 2)
-    return TargetRegionFlags::generic;
+    return TargetRegionFlags::none;
 
   // Detect target-teams-distribute-parallel-wsloop[-simd].
   if (numWrappers == 2) {
     if (!isa<WsloopOp>(innermostWrapper))
-      return TargetRegionFlags::generic;
+      return TargetRegionFlags::none;
 
     innermostWrapper = std::next(innermostWrapper);
     if (!isa<DistributeOp>(innermostWrapper))
-      return TargetRegionFlags::generic;
+      return TargetRegionFlags::none;
 
     Operation *parallelOp = (*innermostWrapper)->getParentOp();
     if (!isa_and_present<ParallelOp>(parallelOp))
-      return TargetRegionFlags::generic;
+      return TargetRegionFlags::none;
 
     Operation *teamsOp = parallelOp->getParentOp();
     if (!isa_and_present<TeamsOp>(teamsOp))
-      return TargetRegionFlags::generic;
+      return TargetRegionFlags::none;
 
     if (teamsOp->getParentOp() == targetOp.getOperation())
       return TargetRegionFlags::spmd | TargetRegionFlags::trip_count;
@@ -2156,53 +2156,27 @@ TargetRegionFlags TargetOp::getKernelExecFlags(Operation *capturedOp) {
   else if (isa<DistributeOp, LoopOp>(innermostWrapper)) {
     Operation *teamsOp = (*innermostWrapper)->getParentOp();
     if (!isa_and_present<TeamsOp>(teamsOp))
-      return TargetRegionFlags::generic;
+      return TargetRegionFlags::none;
 
     if (teamsOp->getParentOp() != targetOp.getOperation())
-      return TargetRegionFlags::generic;
+      return TargetRegionFlags::none;
 
     if (isa<LoopOp>(innermostWrapper))
       return TargetRegionFlags::spmd | TargetRegionFlags::trip_count;
 
-    // Find single immediately nested captured omp.parallel and add spmd flag
-    // (generic-spmd case).
-    //
-    // TODO: This shouldn't have to be done here, as it is too easy to break.
-    // The openmp-opt pass should be updated to be able to promote kernels like
-    // this from "Generic" to "Generic-SPMD". However, the use of the
-    // `kmpc_distribute_static_loop` family of functions produced by the
-    // OMPIRBuilder for these kernels prevents that from working.
-    Dialect *ompDialect = targetOp->getDialect();
-    Operation *nestedCapture = findCapturedOmpOp(
-        capturedOp, /*checkSingleMandatoryExec=*/false,
-        [&](Operation *sibling) {
-          return sibling && (ompDialect != sibling->getDialect() ||
-                             sibling->hasTrait<OpTrait::IsTerminator>());
-        });
-
-    TargetRegionFlags result =
-        TargetRegionFlags::generic | TargetRegionFlags::trip_count;
-
-    if (!nestedCapture)
-      return result;
-
-    while (nestedCapture->getParentOp() != capturedOp)
-      nestedCapture = nestedCapture->getParentOp();
-
-    return isa<ParallelOp>(nestedCapture) ? result | TargetRegionFlags::spmd
-                                          : result;
+    return TargetRegionFlags::trip_count;
   }
   // Detect target-parallel-wsloop[-simd].
   else if (isa<WsloopOp>(innermostWrapper)) {
     Operation *parallelOp = (*innermostWrapper)->getParentOp();
     if (!isa_and_present<ParallelOp>(parallelOp))
-      return TargetRegionFlags::generic;
+      return TargetRegionFlags::none;
 
     if (parallelOp->getParentOp() == targetOp.getOperation())
       return TargetRegionFlags::spmd;
   }
 
-  return TargetRegionFlags::generic;
+  return TargetRegionFlags::none;
 }
 
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
index 9f18199c75b4b..34358cdcece3c 100644
--- a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
+++ b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
@@ -5324,16 +5324,10 @@ initTargetDefaultAttrs(omp::TargetOp targetOp, Operation *capturedOp,
 
   // Update kernel bounds structure for the `OpenMPIRBuilder` to use.
   omp::TargetRegionFlags kernelFlags = targetOp.getKernelExecFlags(capturedOp);
-  assert(
-      omp::bitEnumContainsAny(kernelFlags, omp::TargetRegionFlags::generic |
-                                               omp::TargetRegionFlags::spmd) &&
-      "invalid kernel flags");
   attrs.ExecFlags =
-      omp::bitEnumContainsAny(kernelFlags, omp::TargetRegionFlags::generic)
-          ? omp::bitEnumContainsAny(kernelFlags, omp::TargetRegionFlags::spmd)
-                ? llvm::omp::OMP_TGT_EXEC_MODE_GENERIC_SPMD
-                : llvm::omp::OMP_TGT_EXEC_MODE_GENERIC
-          : llvm::omp::OMP_TGT_EXEC_MODE_SPMD;
+      omp::bitEnumContainsAny(kernelFlags, omp::TargetRegionFlags::spmd)
+          ? llvm::omp::OMP_TGT_EXEC_MODE_SPMD
+          : llvm::omp::OMP_TGT_EXEC_MODE_GENERIC;
   attrs.MinTeams = minTeamsVal;
   attrs.MaxTeams.front() = maxTeamsVal;
   attrs.MinThreads = 1;
diff --git a/mlir/test/Target/LLVMIR/openmp-target-generic-spmd.mlir b/mlir/test/Target/LLVMIR/openmp-target-generic-spmd.mlir
index 9bb2b40a43def..fd190a7b95f66 100644
--- a/mlir/test/Target/LLVMIR/openmp-target-generic-spmd.mlir
+++ b/mlir/test/Target/LLVMIR/openmp-target-generic-spmd.mlir
@@ -87,7 +87,7 @@ module attributes {dlti.dl_spec = #dlti.dl_spec<#dlti.dl_entry<"dlti.alloca_memo
   }
 }
 
-// DEVICE:      @[[KERNEL_NAME:.*]]_exec_mode = weak protected constant i8 [[EXEC_MODE:3]]
+// DEVICE:      @[[KERNEL_NAME:.*]]_exec_mode = weak protected constant i8 [[EXEC_MODE:1]]
 // DEVICE:      @llvm.compiler.used = appending global [1 x ptr] [ptr @[[KERNEL_NAME]]_exec_mode], section "llvm.metadata"
 // DEVICE:      @[[KERNEL_NAME]]_kernel_environment = weak_odr protected constant %struct.KernelEnvironmentTy {
 // DEVICE-SAME: %struct.ConfigurationEnvironmentTy { i8 1, i8 1, i8 [[EXEC_MODE]], {{.*}}},

def TargetRegionFlagsGeneric : I32BitEnumAttrCaseBit<"generic", 0>;
def TargetRegionFlagsSpmd : I32BitEnumAttrCaseBit<"spmd", 1>;
def TargetRegionFlagsTripCount : I32BitEnumAttrCaseBit<"trip_count", 2>;
def TargetRegionFlagsSpmd : I32BitEnumAttrCaseBit<"spmd", 0>;
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Wasn't this intended to correspond to the OMP_TGT_EXEC_MODE_SPMD = 1 << 1 flag? Is there some comments/doxygen explaining what those flags mean?

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants