From 39800face19f2966a4456d2a4c583cd87a693c7e Mon Sep 17 00:00:00 2001 From: Sergio Afonso Date: Wed, 21 May 2025 16:29:23 +0100 Subject: [PATCH 1/2] [MLIR][OpenMP] Remove Generic-SPMD early detection This patch removes logic from MLIR to attempt identifying Generic kernels that could be executed in SPMD mode. This optimization is done by the OpenMPOpt pass for Clang and is only required here to circumvent missing support for the new DeviceRTL APIs used in MLIR to LLVM IR translation that Clang doesn't currently use (e.g. `kmpc_distribute_static_loop` ). Removing checks in MLIR avoids duplicating the logic that should be centralized in the OpenMPOpt pass. Additionally, offloading kernels currently compiled through the OpenMP dialect fail to run parallel regions properly when in Generic mode. By disabling early detection, this issue becomes apparent for a range of kernels where this was masked by having them run in SPMD mode. --- .../mlir/Dialect/OpenMP/OpenMPEnums.td | 6 +-- mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp | 48 +++++-------------- .../OpenMP/OpenMPToLLVMIRTranslation.cpp | 12 ++--- .../LLVMIR/openmp-target-generic-spmd.mlir | 2 +- 4 files changed, 17 insertions(+), 51 deletions(-) diff --git a/mlir/include/mlir/Dialect/OpenMP/OpenMPEnums.td b/mlir/include/mlir/Dialect/OpenMP/OpenMPEnums.td index c080c3fac87d4..ce0ebabd58125 100644 --- a/mlir/include/mlir/Dialect/OpenMP/OpenMPEnums.td +++ b/mlir/include/mlir/Dialect/OpenMP/OpenMPEnums.td @@ -227,15 +227,13 @@ def ScheduleModifierAttr : OpenMP_EnumAttr; //===----------------------------------------------------------------------===// def TargetRegionFlagsNone : I32BitEnumAttrCaseNone<"none">; -def TargetRegionFlagsGeneric : I32BitEnumAttrCaseBit<"generic", 0>; -def TargetRegionFlagsSpmd : I32BitEnumAttrCaseBit<"spmd", 1>; -def TargetRegionFlagsTripCount : I32BitEnumAttrCaseBit<"trip_count", 2>; +def TargetRegionFlagsSpmd : I32BitEnumAttrCaseBit<"spmd", 0>; +def TargetRegionFlagsTripCount : I32BitEnumAttrCaseBit<"trip_count", 1>; def TargetRegionFlags : OpenMP_BitEnumAttr< "TargetRegionFlags", "target region property flags", [ TargetRegionFlagsNone, - TargetRegionFlagsGeneric, TargetRegionFlagsSpmd, TargetRegionFlagsTripCount ]>; diff --git a/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp b/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp index c1c1767ef90b0..8854e908c71f3 100644 --- a/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp +++ b/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp @@ -2117,7 +2117,7 @@ TargetRegionFlags TargetOp::getKernelExecFlags(Operation *capturedOp) { // If it's not capturing a loop, it's a default target region. if (!isa_and_present(capturedOp)) - return TargetRegionFlags::generic; + return TargetRegionFlags::none; // Get the innermost non-simd loop wrapper. SmallVector loopWrappers; @@ -2130,24 +2130,24 @@ TargetRegionFlags TargetOp::getKernelExecFlags(Operation *capturedOp) { auto numWrappers = std::distance(innermostWrapper, loopWrappers.end()); if (numWrappers != 1 && numWrappers != 2) - return TargetRegionFlags::generic; + return TargetRegionFlags::none; // Detect target-teams-distribute-parallel-wsloop[-simd]. if (numWrappers == 2) { if (!isa(innermostWrapper)) - return TargetRegionFlags::generic; + return TargetRegionFlags::none; innermostWrapper = std::next(innermostWrapper); if (!isa(innermostWrapper)) - return TargetRegionFlags::generic; + return TargetRegionFlags::none; Operation *parallelOp = (*innermostWrapper)->getParentOp(); if (!isa_and_present(parallelOp)) - return TargetRegionFlags::generic; + return TargetRegionFlags::none; Operation *teamsOp = parallelOp->getParentOp(); if (!isa_and_present(teamsOp)) - return TargetRegionFlags::generic; + return TargetRegionFlags::none; if (teamsOp->getParentOp() == targetOp.getOperation()) return TargetRegionFlags::spmd | TargetRegionFlags::trip_count; @@ -2156,53 +2156,27 @@ TargetRegionFlags TargetOp::getKernelExecFlags(Operation *capturedOp) { else if (isa(innermostWrapper)) { Operation *teamsOp = (*innermostWrapper)->getParentOp(); if (!isa_and_present(teamsOp)) - return TargetRegionFlags::generic; + return TargetRegionFlags::none; if (teamsOp->getParentOp() != targetOp.getOperation()) - return TargetRegionFlags::generic; + return TargetRegionFlags::none; if (isa(innermostWrapper)) return TargetRegionFlags::spmd | TargetRegionFlags::trip_count; - // Find single immediately nested captured omp.parallel and add spmd flag - // (generic-spmd case). - // - // TODO: This shouldn't have to be done here, as it is too easy to break. - // The openmp-opt pass should be updated to be able to promote kernels like - // this from "Generic" to "Generic-SPMD". However, the use of the - // `kmpc_distribute_static_loop` family of functions produced by the - // OMPIRBuilder for these kernels prevents that from working. - Dialect *ompDialect = targetOp->getDialect(); - Operation *nestedCapture = findCapturedOmpOp( - capturedOp, /*checkSingleMandatoryExec=*/false, - [&](Operation *sibling) { - return sibling && (ompDialect != sibling->getDialect() || - sibling->hasTrait()); - }); - - TargetRegionFlags result = - TargetRegionFlags::generic | TargetRegionFlags::trip_count; - - if (!nestedCapture) - return result; - - while (nestedCapture->getParentOp() != capturedOp) - nestedCapture = nestedCapture->getParentOp(); - - return isa(nestedCapture) ? result | TargetRegionFlags::spmd - : result; + return TargetRegionFlags::trip_count; } // Detect target-parallel-wsloop[-simd]. else if (isa(innermostWrapper)) { Operation *parallelOp = (*innermostWrapper)->getParentOp(); if (!isa_and_present(parallelOp)) - return TargetRegionFlags::generic; + return TargetRegionFlags::none; if (parallelOp->getParentOp() == targetOp.getOperation()) return TargetRegionFlags::spmd; } - return TargetRegionFlags::generic; + return TargetRegionFlags::none; } //===----------------------------------------------------------------------===// diff --git a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp index eb96cb211fdd5..88601ef45911e 100644 --- a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp +++ b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp @@ -5355,16 +5355,10 @@ initTargetDefaultAttrs(omp::TargetOp targetOp, Operation *capturedOp, // Update kernel bounds structure for the `OpenMPIRBuilder` to use. omp::TargetRegionFlags kernelFlags = targetOp.getKernelExecFlags(capturedOp); - assert( - omp::bitEnumContainsAny(kernelFlags, omp::TargetRegionFlags::generic | - omp::TargetRegionFlags::spmd) && - "invalid kernel flags"); attrs.ExecFlags = - omp::bitEnumContainsAny(kernelFlags, omp::TargetRegionFlags::generic) - ? omp::bitEnumContainsAny(kernelFlags, omp::TargetRegionFlags::spmd) - ? llvm::omp::OMP_TGT_EXEC_MODE_GENERIC_SPMD - : llvm::omp::OMP_TGT_EXEC_MODE_GENERIC - : llvm::omp::OMP_TGT_EXEC_MODE_SPMD; + omp::bitEnumContainsAny(kernelFlags, omp::TargetRegionFlags::spmd) + ? llvm::omp::OMP_TGT_EXEC_MODE_SPMD + : llvm::omp::OMP_TGT_EXEC_MODE_GENERIC; attrs.MinTeams = minTeamsVal; attrs.MaxTeams.front() = maxTeamsVal; attrs.MinThreads = 1; diff --git a/mlir/test/Target/LLVMIR/openmp-target-generic-spmd.mlir b/mlir/test/Target/LLVMIR/openmp-target-generic-spmd.mlir index 9bb2b40a43def..fd190a7b95f66 100644 --- a/mlir/test/Target/LLVMIR/openmp-target-generic-spmd.mlir +++ b/mlir/test/Target/LLVMIR/openmp-target-generic-spmd.mlir @@ -87,7 +87,7 @@ module attributes {dlti.dl_spec = #dlti.dl_spec<#dlti.dl_entry<"dlti.alloca_memo } } -// DEVICE: @[[KERNEL_NAME:.*]]_exec_mode = weak protected constant i8 [[EXEC_MODE:3]] +// DEVICE: @[[KERNEL_NAME:.*]]_exec_mode = weak protected constant i8 [[EXEC_MODE:1]] // DEVICE: @llvm.compiler.used = appending global [1 x ptr] [ptr @[[KERNEL_NAME]]_exec_mode], section "llvm.metadata" // DEVICE: @[[KERNEL_NAME]]_kernel_environment = weak_odr protected constant %struct.KernelEnvironmentTy { // DEVICE-SAME: %struct.ConfigurationEnvironmentTy { i8 1, i8 1, i8 [[EXEC_MODE]], {{.*}}}, From 9e948a58af729d8c142c6e1c4a252a01fd2e6dbd Mon Sep 17 00:00:00 2001 From: Sergio Afonso Date: Wed, 13 Aug 2025 12:53:34 +0100 Subject: [PATCH 2/2] Update TargetRegionFlags to mirror OMPTgtExecModeFlags --- .../mlir/Dialect/OpenMP/OpenMPEnums.td | 20 ++++---- mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td | 12 +++-- mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp | 51 +++++++++++-------- .../OpenMP/OpenMPToLLVMIRTranslation.cpp | 22 +++++--- 4 files changed, 64 insertions(+), 41 deletions(-) diff --git a/mlir/include/mlir/Dialect/OpenMP/OpenMPEnums.td b/mlir/include/mlir/Dialect/OpenMP/OpenMPEnums.td index ce0ebabd58125..deb2fba1cd796 100644 --- a/mlir/include/mlir/Dialect/OpenMP/OpenMPEnums.td +++ b/mlir/include/mlir/Dialect/OpenMP/OpenMPEnums.td @@ -223,19 +223,19 @@ def ScheduleModifier : OpenMP_I32EnumAttr< def ScheduleModifierAttr : OpenMP_EnumAttr; //===----------------------------------------------------------------------===// -// target_region_flags enum. +// target_exec_mode enum. //===----------------------------------------------------------------------===// -def TargetRegionFlagsNone : I32BitEnumAttrCaseNone<"none">; -def TargetRegionFlagsSpmd : I32BitEnumAttrCaseBit<"spmd", 0>; -def TargetRegionFlagsTripCount : I32BitEnumAttrCaseBit<"trip_count", 1>; +def TargetExecModeBare : I32EnumAttrCase<"bare", 0>; +def TargetExecModeGeneric : I32EnumAttrCase<"generic", 1>; +def TargetExecModeSpmd : I32EnumAttrCase<"spmd", 2>; -def TargetRegionFlags : OpenMP_BitEnumAttr< - "TargetRegionFlags", - "target region property flags", [ - TargetRegionFlagsNone, - TargetRegionFlagsSpmd, - TargetRegionFlagsTripCount +def TargetExecMode : OpenMP_I32EnumAttr< + "TargetExecMode", + "target execution mode, mirroring the `OMPTgtExecModeFlags` LLVM enum", [ + TargetExecModeBare, + TargetExecModeGeneric, + TargetExecModeSpmd, ]>; //===----------------------------------------------------------------------===// diff --git a/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td b/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td index be114ea4fb631..6569905c5fae4 100644 --- a/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td +++ b/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td @@ -1517,13 +1517,17 @@ def TargetOp : OpenMP_Op<"target", traits = [ /// operations, the top level one will be the one captured. Operation *getInnermostCapturedOmpOp(); - /// Infers the kernel type (Generic, SPMD or Generic-SPMD) based on the - /// contents of the target region. + /// Infers the kernel type (Bare, Generic or SPMD) based on the contents of + /// the target region. /// /// \param capturedOp result of a still valid (no modifications made to any /// nested operations) previous call to `getInnermostCapturedOmpOp()`. - static ::mlir::omp::TargetRegionFlags - getKernelExecFlags(Operation *capturedOp); + /// \param hostEvalTripCount output argument to store whether this kernel + /// wraps a loop whose bounds must be evaluated on the host prior to + /// launching it. + static ::mlir::omp::TargetExecMode + getKernelExecFlags(Operation *capturedOp, + bool *hostEvalTripCount = nullptr); }] # clausesExtraClassDeclaration; let assemblyFormat = clausesAssemblyFormat # [{ diff --git a/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp b/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp index 8854e908c71f3..c3c17006fe571 100644 --- a/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp +++ b/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp @@ -1974,8 +1974,9 @@ LogicalResult TargetOp::verifyRegions() { return emitError("target containing multiple 'omp.teams' nested ops"); // Check that host_eval values are only used in legal ways. + bool hostEvalTripCount; Operation *capturedOp = getInnermostCapturedOmpOp(); - TargetRegionFlags execFlags = getKernelExecFlags(capturedOp); + TargetExecMode execMode = getKernelExecFlags(capturedOp, &hostEvalTripCount); for (Value hostEvalArg : cast(getOperation()).getHostEvalBlockArgs()) { for (Operation *user : hostEvalArg.getUsers()) { @@ -1990,7 +1991,7 @@ LogicalResult TargetOp::verifyRegions() { "and 'thread_limit' in 'omp.teams'"; } if (auto parallelOp = dyn_cast(user)) { - if (bitEnumContainsAny(execFlags, TargetRegionFlags::spmd) && + if (execMode == TargetExecMode::spmd && parallelOp->isAncestor(capturedOp) && hostEvalArg == parallelOp.getNumThreads()) continue; @@ -2000,8 +2001,7 @@ LogicalResult TargetOp::verifyRegions() { "'omp.parallel' when representing target SPMD"; } if (auto loopNestOp = dyn_cast(user)) { - if (bitEnumContainsAny(execFlags, TargetRegionFlags::trip_count) && - loopNestOp.getOperation() == capturedOp && + if (hostEvalTripCount && loopNestOp.getOperation() == capturedOp && (llvm::is_contained(loopNestOp.getLoopLowerBounds(), hostEvalArg) || llvm::is_contained(loopNestOp.getLoopUpperBounds(), hostEvalArg) || llvm::is_contained(loopNestOp.getLoopSteps(), hostEvalArg))) @@ -2106,7 +2106,9 @@ Operation *TargetOp::getInnermostCapturedOmpOp() { }); } -TargetRegionFlags TargetOp::getKernelExecFlags(Operation *capturedOp) { +TargetExecMode TargetOp::getKernelExecFlags(Operation *capturedOp, + bool *hostEvalTripCount) { + // TODO: Support detection of bare kernel mode. // A non-null captured op is only valid if it resides inside of a TargetOp // and is the result of calling getInnermostCapturedOmpOp() on it. TargetOp targetOp = @@ -2115,9 +2117,12 @@ TargetRegionFlags TargetOp::getKernelExecFlags(Operation *capturedOp) { (targetOp && targetOp.getInnermostCapturedOmpOp() == capturedOp)) && "unexpected captured op"); + if (hostEvalTripCount) + *hostEvalTripCount = false; + // If it's not capturing a loop, it's a default target region. if (!isa_and_present(capturedOp)) - return TargetRegionFlags::none; + return TargetExecMode::generic; // Get the innermost non-simd loop wrapper. SmallVector loopWrappers; @@ -2130,53 +2135,59 @@ TargetRegionFlags TargetOp::getKernelExecFlags(Operation *capturedOp) { auto numWrappers = std::distance(innermostWrapper, loopWrappers.end()); if (numWrappers != 1 && numWrappers != 2) - return TargetRegionFlags::none; + return TargetExecMode::generic; // Detect target-teams-distribute-parallel-wsloop[-simd]. if (numWrappers == 2) { if (!isa(innermostWrapper)) - return TargetRegionFlags::none; + return TargetExecMode::generic; innermostWrapper = std::next(innermostWrapper); if (!isa(innermostWrapper)) - return TargetRegionFlags::none; + return TargetExecMode::generic; Operation *parallelOp = (*innermostWrapper)->getParentOp(); if (!isa_and_present(parallelOp)) - return TargetRegionFlags::none; + return TargetExecMode::generic; Operation *teamsOp = parallelOp->getParentOp(); if (!isa_and_present(teamsOp)) - return TargetRegionFlags::none; + return TargetExecMode::generic; - if (teamsOp->getParentOp() == targetOp.getOperation()) - return TargetRegionFlags::spmd | TargetRegionFlags::trip_count; + if (teamsOp->getParentOp() == targetOp.getOperation()) { + if (hostEvalTripCount) + *hostEvalTripCount = true; + return TargetExecMode::spmd; + } } // Detect target-teams-distribute[-simd] and target-teams-loop. else if (isa(innermostWrapper)) { Operation *teamsOp = (*innermostWrapper)->getParentOp(); if (!isa_and_present(teamsOp)) - return TargetRegionFlags::none; + return TargetExecMode::generic; if (teamsOp->getParentOp() != targetOp.getOperation()) - return TargetRegionFlags::none; + return TargetExecMode::generic; + + if (hostEvalTripCount) + *hostEvalTripCount = true; if (isa(innermostWrapper)) - return TargetRegionFlags::spmd | TargetRegionFlags::trip_count; + return TargetExecMode::spmd; - return TargetRegionFlags::trip_count; + return TargetExecMode::generic; } // Detect target-parallel-wsloop[-simd]. else if (isa(innermostWrapper)) { Operation *parallelOp = (*innermostWrapper)->getParentOp(); if (!isa_and_present(parallelOp)) - return TargetRegionFlags::none; + return TargetExecMode::generic; if (parallelOp->getParentOp() == targetOp.getOperation()) - return TargetRegionFlags::spmd; + return TargetExecMode::spmd; } - return TargetRegionFlags::none; + return TargetExecMode::generic; } //===----------------------------------------------------------------------===// diff --git a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp index 88601ef45911e..d49cc38cd7925 100644 --- a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp +++ b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp @@ -5354,11 +5354,18 @@ initTargetDefaultAttrs(omp::TargetOp targetOp, Operation *capturedOp, } // Update kernel bounds structure for the `OpenMPIRBuilder` to use. - omp::TargetRegionFlags kernelFlags = targetOp.getKernelExecFlags(capturedOp); - attrs.ExecFlags = - omp::bitEnumContainsAny(kernelFlags, omp::TargetRegionFlags::spmd) - ? llvm::omp::OMP_TGT_EXEC_MODE_SPMD - : llvm::omp::OMP_TGT_EXEC_MODE_GENERIC; + omp::TargetExecMode execMode = targetOp.getKernelExecFlags(capturedOp); + switch (execMode) { + case omp::TargetExecMode::bare: + attrs.ExecFlags = llvm::omp::OMP_TGT_EXEC_MODE_BARE; + break; + case omp::TargetExecMode::generic: + attrs.ExecFlags = llvm::omp::OMP_TGT_EXEC_MODE_GENERIC; + break; + case omp::TargetExecMode::spmd: + attrs.ExecFlags = llvm::omp::OMP_TGT_EXEC_MODE_SPMD; + break; + } attrs.MinTeams = minTeamsVal; attrs.MaxTeams.front() = maxTeamsVal; attrs.MinThreads = 1; @@ -5408,8 +5415,9 @@ initTargetRuntimeAttrs(llvm::IRBuilderBase &builder, if (numThreads) attrs.MaxThreads = moduleTranslation.lookupValue(numThreads); - if (omp::bitEnumContainsAny(targetOp.getKernelExecFlags(capturedOp), - omp::TargetRegionFlags::trip_count)) { + bool hostEvalTripCount; + targetOp.getKernelExecFlags(capturedOp, &hostEvalTripCount); + if (hostEvalTripCount) { llvm::OpenMPIRBuilder *ompBuilder = moduleTranslation.getOpenMPBuilder(); attrs.LoopTripCount = nullptr;