Skip to content

Commit 39800fa

Browse files
committed
[MLIR][OpenMP] Remove Generic-SPMD early detection
This patch removes logic from MLIR to attempt identifying Generic kernels that could be executed in SPMD mode. This optimization is done by the OpenMPOpt pass for Clang and is only required here to circumvent missing support for the new DeviceRTL APIs used in MLIR to LLVM IR translation that Clang doesn't currently use (e.g. `kmpc_distribute_static_loop` ). Removing checks in MLIR avoids duplicating the logic that should be centralized in the OpenMPOpt pass. Additionally, offloading kernels currently compiled through the OpenMP dialect fail to run parallel regions properly when in Generic mode. By disabling early detection, this issue becomes apparent for a range of kernels where this was masked by having them run in SPMD mode.
1 parent b09b05a commit 39800fa

File tree

4 files changed

+17
-51
lines changed

4 files changed

+17
-51
lines changed

mlir/include/mlir/Dialect/OpenMP/OpenMPEnums.td

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -227,15 +227,13 @@ def ScheduleModifierAttr : OpenMP_EnumAttr<ScheduleModifier, "sched_mod">;
227227
//===----------------------------------------------------------------------===//
228228

229229
def TargetRegionFlagsNone : I32BitEnumAttrCaseNone<"none">;
230-
def TargetRegionFlagsGeneric : I32BitEnumAttrCaseBit<"generic", 0>;
231-
def TargetRegionFlagsSpmd : I32BitEnumAttrCaseBit<"spmd", 1>;
232-
def TargetRegionFlagsTripCount : I32BitEnumAttrCaseBit<"trip_count", 2>;
230+
def TargetRegionFlagsSpmd : I32BitEnumAttrCaseBit<"spmd", 0>;
231+
def TargetRegionFlagsTripCount : I32BitEnumAttrCaseBit<"trip_count", 1>;
233232

234233
def TargetRegionFlags : OpenMP_BitEnumAttr<
235234
"TargetRegionFlags",
236235
"target region property flags", [
237236
TargetRegionFlagsNone,
238-
TargetRegionFlagsGeneric,
239237
TargetRegionFlagsSpmd,
240238
TargetRegionFlagsTripCount
241239
]>;

mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp

Lines changed: 11 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -2117,7 +2117,7 @@ TargetRegionFlags TargetOp::getKernelExecFlags(Operation *capturedOp) {
21172117

21182118
// If it's not capturing a loop, it's a default target region.
21192119
if (!isa_and_present<LoopNestOp>(capturedOp))
2120-
return TargetRegionFlags::generic;
2120+
return TargetRegionFlags::none;
21212121

21222122
// Get the innermost non-simd loop wrapper.
21232123
SmallVector<LoopWrapperInterface> loopWrappers;
@@ -2130,24 +2130,24 @@ TargetRegionFlags TargetOp::getKernelExecFlags(Operation *capturedOp) {
21302130

21312131
auto numWrappers = std::distance(innermostWrapper, loopWrappers.end());
21322132
if (numWrappers != 1 && numWrappers != 2)
2133-
return TargetRegionFlags::generic;
2133+
return TargetRegionFlags::none;
21342134

21352135
// Detect target-teams-distribute-parallel-wsloop[-simd].
21362136
if (numWrappers == 2) {
21372137
if (!isa<WsloopOp>(innermostWrapper))
2138-
return TargetRegionFlags::generic;
2138+
return TargetRegionFlags::none;
21392139

21402140
innermostWrapper = std::next(innermostWrapper);
21412141
if (!isa<DistributeOp>(innermostWrapper))
2142-
return TargetRegionFlags::generic;
2142+
return TargetRegionFlags::none;
21432143

21442144
Operation *parallelOp = (*innermostWrapper)->getParentOp();
21452145
if (!isa_and_present<ParallelOp>(parallelOp))
2146-
return TargetRegionFlags::generic;
2146+
return TargetRegionFlags::none;
21472147

21482148
Operation *teamsOp = parallelOp->getParentOp();
21492149
if (!isa_and_present<TeamsOp>(teamsOp))
2150-
return TargetRegionFlags::generic;
2150+
return TargetRegionFlags::none;
21512151

21522152
if (teamsOp->getParentOp() == targetOp.getOperation())
21532153
return TargetRegionFlags::spmd | TargetRegionFlags::trip_count;
@@ -2156,53 +2156,27 @@ TargetRegionFlags TargetOp::getKernelExecFlags(Operation *capturedOp) {
21562156
else if (isa<DistributeOp, LoopOp>(innermostWrapper)) {
21572157
Operation *teamsOp = (*innermostWrapper)->getParentOp();
21582158
if (!isa_and_present<TeamsOp>(teamsOp))
2159-
return TargetRegionFlags::generic;
2159+
return TargetRegionFlags::none;
21602160

21612161
if (teamsOp->getParentOp() != targetOp.getOperation())
2162-
return TargetRegionFlags::generic;
2162+
return TargetRegionFlags::none;
21632163

21642164
if (isa<LoopOp>(innermostWrapper))
21652165
return TargetRegionFlags::spmd | TargetRegionFlags::trip_count;
21662166

2167-
// Find single immediately nested captured omp.parallel and add spmd flag
2168-
// (generic-spmd case).
2169-
//
2170-
// TODO: This shouldn't have to be done here, as it is too easy to break.
2171-
// The openmp-opt pass should be updated to be able to promote kernels like
2172-
// this from "Generic" to "Generic-SPMD". However, the use of the
2173-
// `kmpc_distribute_static_loop` family of functions produced by the
2174-
// OMPIRBuilder for these kernels prevents that from working.
2175-
Dialect *ompDialect = targetOp->getDialect();
2176-
Operation *nestedCapture = findCapturedOmpOp(
2177-
capturedOp, /*checkSingleMandatoryExec=*/false,
2178-
[&](Operation *sibling) {
2179-
return sibling && (ompDialect != sibling->getDialect() ||
2180-
sibling->hasTrait<OpTrait::IsTerminator>());
2181-
});
2182-
2183-
TargetRegionFlags result =
2184-
TargetRegionFlags::generic | TargetRegionFlags::trip_count;
2185-
2186-
if (!nestedCapture)
2187-
return result;
2188-
2189-
while (nestedCapture->getParentOp() != capturedOp)
2190-
nestedCapture = nestedCapture->getParentOp();
2191-
2192-
return isa<ParallelOp>(nestedCapture) ? result | TargetRegionFlags::spmd
2193-
: result;
2167+
return TargetRegionFlags::trip_count;
21942168
}
21952169
// Detect target-parallel-wsloop[-simd].
21962170
else if (isa<WsloopOp>(innermostWrapper)) {
21972171
Operation *parallelOp = (*innermostWrapper)->getParentOp();
21982172
if (!isa_and_present<ParallelOp>(parallelOp))
2199-
return TargetRegionFlags::generic;
2173+
return TargetRegionFlags::none;
22002174

22012175
if (parallelOp->getParentOp() == targetOp.getOperation())
22022176
return TargetRegionFlags::spmd;
22032177
}
22042178

2205-
return TargetRegionFlags::generic;
2179+
return TargetRegionFlags::none;
22062180
}
22072181

22082182
//===----------------------------------------------------------------------===//

mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp

Lines changed: 3 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -5355,16 +5355,10 @@ initTargetDefaultAttrs(omp::TargetOp targetOp, Operation *capturedOp,
53555355

53565356
// Update kernel bounds structure for the `OpenMPIRBuilder` to use.
53575357
omp::TargetRegionFlags kernelFlags = targetOp.getKernelExecFlags(capturedOp);
5358-
assert(
5359-
omp::bitEnumContainsAny(kernelFlags, omp::TargetRegionFlags::generic |
5360-
omp::TargetRegionFlags::spmd) &&
5361-
"invalid kernel flags");
53625358
attrs.ExecFlags =
5363-
omp::bitEnumContainsAny(kernelFlags, omp::TargetRegionFlags::generic)
5364-
? omp::bitEnumContainsAny(kernelFlags, omp::TargetRegionFlags::spmd)
5365-
? llvm::omp::OMP_TGT_EXEC_MODE_GENERIC_SPMD
5366-
: llvm::omp::OMP_TGT_EXEC_MODE_GENERIC
5367-
: llvm::omp::OMP_TGT_EXEC_MODE_SPMD;
5359+
omp::bitEnumContainsAny(kernelFlags, omp::TargetRegionFlags::spmd)
5360+
? llvm::omp::OMP_TGT_EXEC_MODE_SPMD
5361+
: llvm::omp::OMP_TGT_EXEC_MODE_GENERIC;
53685362
attrs.MinTeams = minTeamsVal;
53695363
attrs.MaxTeams.front() = maxTeamsVal;
53705364
attrs.MinThreads = 1;

mlir/test/Target/LLVMIR/openmp-target-generic-spmd.mlir

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -87,7 +87,7 @@ module attributes {dlti.dl_spec = #dlti.dl_spec<#dlti.dl_entry<"dlti.alloca_memo
8787
}
8888
}
8989

90-
// DEVICE: @[[KERNEL_NAME:.*]]_exec_mode = weak protected constant i8 [[EXEC_MODE:3]]
90+
// DEVICE: @[[KERNEL_NAME:.*]]_exec_mode = weak protected constant i8 [[EXEC_MODE:1]]
9191
// DEVICE: @llvm.compiler.used = appending global [1 x ptr] [ptr @[[KERNEL_NAME]]_exec_mode], section "llvm.metadata"
9292
// DEVICE: @[[KERNEL_NAME]]_kernel_environment = weak_odr protected constant %struct.KernelEnvironmentTy {
9393
// DEVICE-SAME: %struct.ConfigurationEnvironmentTy { i8 1, i8 1, i8 [[EXEC_MODE]], {{.*}}},

0 commit comments

Comments
 (0)