Skip to content

[MLIR][OpenMP] Remove Generic-SPMD early detection #150922

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 2 additions & 4 deletions mlir/include/mlir/Dialect/OpenMP/OpenMPEnums.td
Original file line number Diff line number Diff line change
Expand Up @@ -227,15 +227,13 @@ def ScheduleModifierAttr : OpenMP_EnumAttr<ScheduleModifier, "sched_mod">;
//===----------------------------------------------------------------------===//

def TargetRegionFlagsNone : I32BitEnumAttrCaseNone<"none">;
def TargetRegionFlagsGeneric : I32BitEnumAttrCaseBit<"generic", 0>;
def TargetRegionFlagsSpmd : I32BitEnumAttrCaseBit<"spmd", 1>;
def TargetRegionFlagsTripCount : I32BitEnumAttrCaseBit<"trip_count", 2>;
def TargetRegionFlagsSpmd : I32BitEnumAttrCaseBit<"spmd", 0>;
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Wasn't this intended to correspond to the OMP_TGT_EXEC_MODE_SPMD = 1 << 1 flag? Are there any comments/doxygen explaining what those flags mean?

def TargetRegionFlagsTripCount : I32BitEnumAttrCaseBit<"trip_count", 1>;

def TargetRegionFlags : OpenMP_BitEnumAttr<
"TargetRegionFlags",
"target region property flags", [
TargetRegionFlagsNone,
TargetRegionFlagsGeneric,
TargetRegionFlagsSpmd,
TargetRegionFlagsTripCount
]>;
Expand Down
48 changes: 11 additions & 37 deletions mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2117,7 +2117,7 @@ TargetRegionFlags TargetOp::getKernelExecFlags(Operation *capturedOp) {

// If it's not capturing a loop, it's a default target region.
if (!isa_and_present<LoopNestOp>(capturedOp))
return TargetRegionFlags::generic;
return TargetRegionFlags::none;

// Get the innermost non-simd loop wrapper.
SmallVector<LoopWrapperInterface> loopWrappers;
Expand All @@ -2130,24 +2130,24 @@ TargetRegionFlags TargetOp::getKernelExecFlags(Operation *capturedOp) {

auto numWrappers = std::distance(innermostWrapper, loopWrappers.end());
if (numWrappers != 1 && numWrappers != 2)
return TargetRegionFlags::generic;
return TargetRegionFlags::none;

// Detect target-teams-distribute-parallel-wsloop[-simd].
if (numWrappers == 2) {
if (!isa<WsloopOp>(innermostWrapper))
return TargetRegionFlags::generic;
return TargetRegionFlags::none;

innermostWrapper = std::next(innermostWrapper);
if (!isa<DistributeOp>(innermostWrapper))
return TargetRegionFlags::generic;
return TargetRegionFlags::none;

Operation *parallelOp = (*innermostWrapper)->getParentOp();
if (!isa_and_present<ParallelOp>(parallelOp))
return TargetRegionFlags::generic;
return TargetRegionFlags::none;

Operation *teamsOp = parallelOp->getParentOp();
if (!isa_and_present<TeamsOp>(teamsOp))
return TargetRegionFlags::generic;
return TargetRegionFlags::none;

if (teamsOp->getParentOp() == targetOp.getOperation())
return TargetRegionFlags::spmd | TargetRegionFlags::trip_count;
Expand All @@ -2156,53 +2156,27 @@ TargetRegionFlags TargetOp::getKernelExecFlags(Operation *capturedOp) {
else if (isa<DistributeOp, LoopOp>(innermostWrapper)) {
Operation *teamsOp = (*innermostWrapper)->getParentOp();
if (!isa_and_present<TeamsOp>(teamsOp))
return TargetRegionFlags::generic;
return TargetRegionFlags::none;

if (teamsOp->getParentOp() != targetOp.getOperation())
return TargetRegionFlags::generic;
return TargetRegionFlags::none;

if (isa<LoopOp>(innermostWrapper))
return TargetRegionFlags::spmd | TargetRegionFlags::trip_count;

// Find single immediately nested captured omp.parallel and add spmd flag
// (generic-spmd case).
//
// TODO: This shouldn't have to be done here, as it is too easy to break.
// The openmp-opt pass should be updated to be able to promote kernels like
// this from "Generic" to "Generic-SPMD". However, the use of the
// `kmpc_distribute_static_loop` family of functions produced by the
// OMPIRBuilder for these kernels prevents that from working.
Dialect *ompDialect = targetOp->getDialect();
Operation *nestedCapture = findCapturedOmpOp(
capturedOp, /*checkSingleMandatoryExec=*/false,
[&](Operation *sibling) {
return sibling && (ompDialect != sibling->getDialect() ||
sibling->hasTrait<OpTrait::IsTerminator>());
});

TargetRegionFlags result =
TargetRegionFlags::generic | TargetRegionFlags::trip_count;

if (!nestedCapture)
return result;

while (nestedCapture->getParentOp() != capturedOp)
nestedCapture = nestedCapture->getParentOp();

return isa<ParallelOp>(nestedCapture) ? result | TargetRegionFlags::spmd
: result;
return TargetRegionFlags::trip_count;
}
// Detect target-parallel-wsloop[-simd].
else if (isa<WsloopOp>(innermostWrapper)) {
Operation *parallelOp = (*innermostWrapper)->getParentOp();
if (!isa_and_present<ParallelOp>(parallelOp))
return TargetRegionFlags::generic;
return TargetRegionFlags::none;

if (parallelOp->getParentOp() == targetOp.getOperation())
return TargetRegionFlags::spmd;
}

return TargetRegionFlags::generic;
return TargetRegionFlags::none;
}

//===----------------------------------------------------------------------===//
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5324,16 +5324,10 @@ initTargetDefaultAttrs(omp::TargetOp targetOp, Operation *capturedOp,

// Update kernel bounds structure for the `OpenMPIRBuilder` to use.
omp::TargetRegionFlags kernelFlags = targetOp.getKernelExecFlags(capturedOp);
assert(
omp::bitEnumContainsAny(kernelFlags, omp::TargetRegionFlags::generic |
omp::TargetRegionFlags::spmd) &&
"invalid kernel flags");
attrs.ExecFlags =
omp::bitEnumContainsAny(kernelFlags, omp::TargetRegionFlags::generic)
? omp::bitEnumContainsAny(kernelFlags, omp::TargetRegionFlags::spmd)
? llvm::omp::OMP_TGT_EXEC_MODE_GENERIC_SPMD
: llvm::omp::OMP_TGT_EXEC_MODE_GENERIC
: llvm::omp::OMP_TGT_EXEC_MODE_SPMD;
omp::bitEnumContainsAny(kernelFlags, omp::TargetRegionFlags::spmd)
? llvm::omp::OMP_TGT_EXEC_MODE_SPMD
: llvm::omp::OMP_TGT_EXEC_MODE_GENERIC;
attrs.MinTeams = minTeamsVal;
attrs.MaxTeams.front() = maxTeamsVal;
attrs.MinThreads = 1;
Expand Down
2 changes: 1 addition & 1 deletion mlir/test/Target/LLVMIR/openmp-target-generic-spmd.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,7 @@ module attributes {dlti.dl_spec = #dlti.dl_spec<#dlti.dl_entry<"dlti.alloca_memo
}
}

// DEVICE: @[[KERNEL_NAME:.*]]_exec_mode = weak protected constant i8 [[EXEC_MODE:3]]
// DEVICE: @[[KERNEL_NAME:.*]]_exec_mode = weak protected constant i8 [[EXEC_MODE:1]]
// DEVICE: @llvm.compiler.used = appending global [1 x ptr] [ptr @[[KERNEL_NAME]]_exec_mode], section "llvm.metadata"
// DEVICE: @[[KERNEL_NAME]]_kernel_environment = weak_odr protected constant %struct.KernelEnvironmentTy {
// DEVICE-SAME: %struct.ConfigurationEnvironmentTy { i8 1, i8 1, i8 [[EXEC_MODE]], {{.*}}},
Expand Down