diff --git a/mlir/include/mlir/Dialect/OpenMP/OpenMPEnums.td b/mlir/include/mlir/Dialect/OpenMP/OpenMPEnums.td index c080c3fac87d4..ce0ebabd58125 100644 --- a/mlir/include/mlir/Dialect/OpenMP/OpenMPEnums.td +++ b/mlir/include/mlir/Dialect/OpenMP/OpenMPEnums.td @@ -227,15 +227,13 @@ def ScheduleModifierAttr : OpenMP_EnumAttr; //===----------------------------------------------------------------------===// def TargetRegionFlagsNone : I32BitEnumAttrCaseNone<"none">; -def TargetRegionFlagsGeneric : I32BitEnumAttrCaseBit<"generic", 0>; -def TargetRegionFlagsSpmd : I32BitEnumAttrCaseBit<"spmd", 1>; -def TargetRegionFlagsTripCount : I32BitEnumAttrCaseBit<"trip_count", 2>; +def TargetRegionFlagsSpmd : I32BitEnumAttrCaseBit<"spmd", 0>; +def TargetRegionFlagsTripCount : I32BitEnumAttrCaseBit<"trip_count", 1>; def TargetRegionFlags : OpenMP_BitEnumAttr< "TargetRegionFlags", "target region property flags", [ TargetRegionFlagsNone, - TargetRegionFlagsGeneric, TargetRegionFlagsSpmd, TargetRegionFlagsTripCount ]>; diff --git a/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp b/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp index c1c1767ef90b0..8854e908c71f3 100644 --- a/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp +++ b/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp @@ -2117,7 +2117,7 @@ TargetRegionFlags TargetOp::getKernelExecFlags(Operation *capturedOp) { // If it's not capturing a loop, it's a default target region. if (!isa_and_present(capturedOp)) - return TargetRegionFlags::generic; + return TargetRegionFlags::none; // Get the innermost non-simd loop wrapper. SmallVector loopWrappers; @@ -2130,24 +2130,24 @@ TargetRegionFlags TargetOp::getKernelExecFlags(Operation *capturedOp) { auto numWrappers = std::distance(innermostWrapper, loopWrappers.end()); if (numWrappers != 1 && numWrappers != 2) - return TargetRegionFlags::generic; + return TargetRegionFlags::none; // Detect target-teams-distribute-parallel-wsloop[-simd]. if (numWrappers == 2) { if (!isa(innermostWrapper)) - return TargetRegionFlags::generic; + return TargetRegionFlags::none; innermostWrapper = std::next(innermostWrapper); if (!isa(innermostWrapper)) - return TargetRegionFlags::generic; + return TargetRegionFlags::none; Operation *parallelOp = (*innermostWrapper)->getParentOp(); if (!isa_and_present(parallelOp)) - return TargetRegionFlags::generic; + return TargetRegionFlags::none; Operation *teamsOp = parallelOp->getParentOp(); if (!isa_and_present(teamsOp)) - return TargetRegionFlags::generic; + return TargetRegionFlags::none; if (teamsOp->getParentOp() == targetOp.getOperation()) return TargetRegionFlags::spmd | TargetRegionFlags::trip_count; @@ -2156,53 +2156,27 @@ TargetRegionFlags TargetOp::getKernelExecFlags(Operation *capturedOp) { else if (isa(innermostWrapper)) { Operation *teamsOp = (*innermostWrapper)->getParentOp(); if (!isa_and_present(teamsOp)) - return TargetRegionFlags::generic; + return TargetRegionFlags::none; if (teamsOp->getParentOp() != targetOp.getOperation()) - return TargetRegionFlags::generic; + return TargetRegionFlags::none; if (isa(innermostWrapper)) return TargetRegionFlags::spmd | TargetRegionFlags::trip_count; - // Find single immediately nested captured omp.parallel and add spmd flag - // (generic-spmd case). - // - // TODO: This shouldn't have to be done here, as it is too easy to break. - // The openmp-opt pass should be updated to be able to promote kernels like - // this from "Generic" to "Generic-SPMD". However, the use of the - // `kmpc_distribute_static_loop` family of functions produced by the - // OMPIRBuilder for these kernels prevents that from working. - Dialect *ompDialect = targetOp->getDialect(); - Operation *nestedCapture = findCapturedOmpOp( - capturedOp, /*checkSingleMandatoryExec=*/false, - [&](Operation *sibling) { - return sibling && (ompDialect != sibling->getDialect() || - sibling->hasTrait()); - }); - - TargetRegionFlags result = - TargetRegionFlags::generic | TargetRegionFlags::trip_count; - - if (!nestedCapture) - return result; - - while (nestedCapture->getParentOp() != capturedOp) - nestedCapture = nestedCapture->getParentOp(); - - return isa(nestedCapture) ? result | TargetRegionFlags::spmd - : result; + return TargetRegionFlags::trip_count; } // Detect target-parallel-wsloop[-simd]. else if (isa(innermostWrapper)) { Operation *parallelOp = (*innermostWrapper)->getParentOp(); if (!isa_and_present(parallelOp)) - return TargetRegionFlags::generic; + return TargetRegionFlags::none; if (parallelOp->getParentOp() == targetOp.getOperation()) return TargetRegionFlags::spmd; } - return TargetRegionFlags::generic; + return TargetRegionFlags::none; } //===----------------------------------------------------------------------===// diff --git a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp index 9f18199c75b4b..34358cdcece3c 100644 --- a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp +++ b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp @@ -5324,16 +5324,10 @@ initTargetDefaultAttrs(omp::TargetOp targetOp, Operation *capturedOp, // Update kernel bounds structure for the `OpenMPIRBuilder` to use. omp::TargetRegionFlags kernelFlags = targetOp.getKernelExecFlags(capturedOp); - assert( - omp::bitEnumContainsAny(kernelFlags, omp::TargetRegionFlags::generic | - omp::TargetRegionFlags::spmd) && - "invalid kernel flags"); attrs.ExecFlags = - omp::bitEnumContainsAny(kernelFlags, omp::TargetRegionFlags::generic) - ? omp::bitEnumContainsAny(kernelFlags, omp::TargetRegionFlags::spmd) - ? llvm::omp::OMP_TGT_EXEC_MODE_GENERIC_SPMD - : llvm::omp::OMP_TGT_EXEC_MODE_GENERIC - : llvm::omp::OMP_TGT_EXEC_MODE_SPMD; + omp::bitEnumContainsAny(kernelFlags, omp::TargetRegionFlags::spmd) + ? llvm::omp::OMP_TGT_EXEC_MODE_SPMD + : llvm::omp::OMP_TGT_EXEC_MODE_GENERIC; attrs.MinTeams = minTeamsVal; attrs.MaxTeams.front() = maxTeamsVal; attrs.MinThreads = 1; diff --git a/mlir/test/Target/LLVMIR/openmp-target-generic-spmd.mlir b/mlir/test/Target/LLVMIR/openmp-target-generic-spmd.mlir index 9bb2b40a43def..fd190a7b95f66 100644 --- a/mlir/test/Target/LLVMIR/openmp-target-generic-spmd.mlir +++ b/mlir/test/Target/LLVMIR/openmp-target-generic-spmd.mlir @@ -87,7 +87,7 @@ module attributes {dlti.dl_spec = #dlti.dl_spec<#dlti.dl_entry<"dlti.alloca_memo } } -// DEVICE: @[[KERNEL_NAME:.*]]_exec_mode = weak protected constant i8 [[EXEC_MODE:3]] +// DEVICE: @[[KERNEL_NAME:.*]]_exec_mode = weak protected constant i8 [[EXEC_MODE:1]] // DEVICE: @llvm.compiler.used = appending global [1 x ptr] [ptr @[[KERNEL_NAME]]_exec_mode], section "llvm.metadata" // DEVICE: @[[KERNEL_NAME]]_kernel_environment = weak_odr protected constant %struct.KernelEnvironmentTy { // DEVICE-SAME: %struct.ConfigurationEnvironmentTy { i8 1, i8 1, i8 [[EXEC_MODE]], {{.*}}},