Skip to content

Commit ad2a460

Browse files
committed
[MLIR][OpenMP] Remove Generic-SPMD early detection
This patch removes logic from MLIR to attempt identifying Generic kernels that could be executed in SPMD mode. This optimization is done by the OpenMPOpt pass for Clang and is only required here to circumvent missing support for the new DeviceRTL APIs used in MLIR to LLVM IR translation that Clang doesn't currently use (e.g. `kmpc_distribute_static_loop`). Removing checks in MLIR avoids duplicating the logic that should be centralized in the OpenMPOpt pass. Additionally, offloading kernels currently compiled through the OpenMP dialect fail to run parallel regions properly when in Generic mode. By disabling early detection, this issue becomes apparent for a range of kernels where this was masked by having them run in SPMD mode.
1 parent b92ff6b commit ad2a460

File tree

4 files changed

+23
-63
lines changed

4 files changed

+23
-63
lines changed

mlir/include/mlir/Dialect/OpenMP/OpenMPEnums.td

Lines changed: 5 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -227,24 +227,21 @@ def ScheduleModifierAttr : OpenMP_EnumAttr<ScheduleModifier, "sched_mod">;
227227
//===----------------------------------------------------------------------===//
228228

229229
def TargetRegionFlagsNone : I32BitEnumAttrCaseNone<"none">;
230-
def TargetRegionFlagsGeneric : I32BitEnumAttrCaseBit<"generic", 0>;
231-
def TargetRegionFlagsSpmd : I32BitEnumAttrCaseBit<"spmd", 1>;
232-
def TargetRegionFlagsTripCount : I32BitEnumAttrCaseBit<"trip_count", 2>;
233-
def TargetRegionFlagsNoLoop : I32BitEnumAttrCaseBit<"no_loop", 3>;
230+
def TargetRegionFlagsSpmd : I32BitEnumAttrCaseBit<"spmd", 0>;
231+
def TargetRegionFlagsTripCount : I32BitEnumAttrCaseBit<"trip_count", 1>;
232+
def TargetRegionFlagsNoLoop : I32BitEnumAttrCaseBit<"no_loop", 2>;
234233

235234
def TargetRegionFlags : OpenMP_BitEnumAttr<
236235
"TargetRegionFlags",
237236
"These flags describe properties of the target kernel. "
238-
"TargetRegionFlagsGeneric - denotes generic kernel. "
239237
"TargetRegionFlagsSpmd - denotes SPMD kernel. "
240238
"TargetRegionFlagsNoLoop - denotes kernel where "
241239
"num_teams * num_threads >= loop_trip_count. It allows the conversion "
242240
"of loops into sequential code by ensuring that each team/thread "
243241
"executes at most one iteration. "
244-
"TargetRegionFlagsTripCount - checks if the loop trip count should be "
245-
"calculated.", [
242+
"TargetRegionFlagsTripCount - checks if a singular loop trip count should "
243+
"be calculated for the target region.", [
246244
TargetRegionFlagsNone,
247-
TargetRegionFlagsGeneric,
248245
TargetRegionFlagsSpmd,
249246
TargetRegionFlagsTripCount,
250247
TargetRegionFlagsNoLoop

mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp

Lines changed: 11 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -2373,7 +2373,7 @@ TargetRegionFlags TargetOp::getKernelExecFlags(Operation *capturedOp) {
23732373

23742374
// If it's not capturing a loop, it's a default target region.
23752375
if (!isa_and_present<LoopNestOp>(capturedOp))
2376-
return TargetRegionFlags::generic;
2376+
return TargetRegionFlags::none;
23772377

23782378
// Get the innermost non-simd loop wrapper.
23792379
SmallVector<LoopWrapperInterface> loopWrappers;
@@ -2386,25 +2386,25 @@ TargetRegionFlags TargetOp::getKernelExecFlags(Operation *capturedOp) {
23862386

23872387
auto numWrappers = std::distance(innermostWrapper, loopWrappers.end());
23882388
if (numWrappers != 1 && numWrappers != 2)
2389-
return TargetRegionFlags::generic;
2389+
return TargetRegionFlags::none;
23902390

23912391
// Detect target-teams-distribute-parallel-wsloop[-simd].
23922392
if (numWrappers == 2) {
23932393
WsloopOp *wsloopOp = dyn_cast<WsloopOp>(innermostWrapper);
23942394
if (!wsloopOp)
2395-
return TargetRegionFlags::generic;
2395+
return TargetRegionFlags::none;
23962396

23972397
innermostWrapper = std::next(innermostWrapper);
23982398
if (!isa<DistributeOp>(innermostWrapper))
2399-
return TargetRegionFlags::generic;
2399+
return TargetRegionFlags::none;
24002400

24012401
Operation *parallelOp = (*innermostWrapper)->getParentOp();
24022402
if (!isa_and_present<ParallelOp>(parallelOp))
2403-
return TargetRegionFlags::generic;
2403+
return TargetRegionFlags::none;
24042404

24052405
TeamsOp teamsOp = dyn_cast<TeamsOp>(parallelOp->getParentOp());
24062406
if (!teamsOp)
2407-
return TargetRegionFlags::generic;
2407+
return TargetRegionFlags::none;
24082408

24092409
if (teamsOp->getParentOp() == targetOp.getOperation()) {
24102410
TargetRegionFlags result =
@@ -2418,53 +2418,27 @@ TargetRegionFlags TargetOp::getKernelExecFlags(Operation *capturedOp) {
24182418
else if (isa<DistributeOp, LoopOp>(innermostWrapper)) {
24192419
Operation *teamsOp = (*innermostWrapper)->getParentOp();
24202420
if (!isa_and_present<TeamsOp>(teamsOp))
2421-
return TargetRegionFlags::generic;
2421+
return TargetRegionFlags::none;
24222422

24232423
if (teamsOp->getParentOp() != targetOp.getOperation())
2424-
return TargetRegionFlags::generic;
2424+
return TargetRegionFlags::none;
24252425

24262426
if (isa<LoopOp>(innermostWrapper))
24272427
return TargetRegionFlags::spmd | TargetRegionFlags::trip_count;
24282428

2429-
// Find single immediately nested captured omp.parallel and add spmd flag
2430-
// (generic-spmd case).
2431-
//
2432-
// TODO: This shouldn't have to be done here, as it is too easy to break.
2433-
// The openmp-opt pass should be updated to be able to promote kernels like
2434-
// this from "Generic" to "Generic-SPMD". However, the use of the
2435-
// `kmpc_distribute_static_loop` family of functions produced by the
2436-
// OMPIRBuilder for these kernels prevents that from working.
2437-
Dialect *ompDialect = targetOp->getDialect();
2438-
Operation *nestedCapture = findCapturedOmpOp(
2439-
capturedOp, /*checkSingleMandatoryExec=*/false,
2440-
[&](Operation *sibling) {
2441-
return sibling && (ompDialect != sibling->getDialect() ||
2442-
sibling->hasTrait<OpTrait::IsTerminator>());
2443-
});
2444-
2445-
TargetRegionFlags result =
2446-
TargetRegionFlags::generic | TargetRegionFlags::trip_count;
2447-
2448-
if (!nestedCapture)
2449-
return result;
2450-
2451-
while (nestedCapture->getParentOp() != capturedOp)
2452-
nestedCapture = nestedCapture->getParentOp();
2453-
2454-
return isa<ParallelOp>(nestedCapture) ? result | TargetRegionFlags::spmd
2455-
: result;
2429+
return TargetRegionFlags::trip_count;
24562430
}
24572431
// Detect target-parallel-wsloop[-simd].
24582432
else if (isa<WsloopOp>(innermostWrapper)) {
24592433
Operation *parallelOp = (*innermostWrapper)->getParentOp();
24602434
if (!isa_and_present<ParallelOp>(parallelOp))
2461-
return TargetRegionFlags::generic;
2435+
return TargetRegionFlags::none;
24622436

24632437
if (parallelOp->getParentOp() == targetOp.getOperation())
24642438
return TargetRegionFlags::spmd;
24652439
}
24662440

2467-
return TargetRegionFlags::generic;
2441+
return TargetRegionFlags::none;
24682442
}
24692443

24702444
//===----------------------------------------------------------------------===//

mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp

Lines changed: 6 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -2605,9 +2605,7 @@ convertOmpWsloop(Operation &opInst, llvm::IRBuilderBase &builder,
26052605
targetOp.getKernelExecFlags(targetCapturedOp);
26062606
if (omp::bitEnumContainsAll(kernelFlags,
26072607
omp::TargetRegionFlags::spmd |
2608-
omp::TargetRegionFlags::no_loop) &&
2609-
!omp::bitEnumContainsAny(kernelFlags,
2610-
omp::TargetRegionFlags::generic))
2608+
omp::TargetRegionFlags::no_loop))
26112609
noLoopMode = true;
26122610
}
26132611
}
@@ -5438,21 +5436,12 @@ initTargetDefaultAttrs(omp::TargetOp targetOp, Operation *capturedOp,
54385436

54395437
// Update kernel bounds structure for the `OpenMPIRBuilder` to use.
54405438
omp::TargetRegionFlags kernelFlags = targetOp.getKernelExecFlags(capturedOp);
5441-
assert(
5442-
omp::bitEnumContainsAny(kernelFlags, omp::TargetRegionFlags::generic |
5443-
omp::TargetRegionFlags::spmd) &&
5444-
"invalid kernel flags");
54455439
attrs.ExecFlags =
5446-
omp::bitEnumContainsAny(kernelFlags, omp::TargetRegionFlags::generic)
5447-
? omp::bitEnumContainsAny(kernelFlags, omp::TargetRegionFlags::spmd)
5448-
? llvm::omp::OMP_TGT_EXEC_MODE_GENERIC_SPMD
5449-
: llvm::omp::OMP_TGT_EXEC_MODE_GENERIC
5450-
: llvm::omp::OMP_TGT_EXEC_MODE_SPMD;
5451-
if (omp::bitEnumContainsAll(kernelFlags,
5452-
omp::TargetRegionFlags::spmd |
5453-
omp::TargetRegionFlags::no_loop) &&
5454-
!omp::bitEnumContainsAny(kernelFlags, omp::TargetRegionFlags::generic))
5455-
attrs.ExecFlags = llvm::omp::OMP_TGT_EXEC_MODE_SPMD_NO_LOOP;
5440+
omp::bitEnumContainsAny(kernelFlags, omp::TargetRegionFlags::spmd)
5441+
? omp::bitEnumContainsAny(kernelFlags, omp::TargetRegionFlags::no_loop)
5442+
? llvm::omp::OMP_TGT_EXEC_MODE_SPMD_NO_LOOP
5443+
: llvm::omp::OMP_TGT_EXEC_MODE_SPMD
5444+
: llvm::omp::OMP_TGT_EXEC_MODE_GENERIC;
54565445

54575446
attrs.MinTeams = minTeamsVal;
54585447
attrs.MaxTeams.front() = maxTeamsVal;

mlir/test/Target/LLVMIR/openmp-target-generic-spmd.mlir

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -84,7 +84,7 @@ module attributes {dlti.dl_spec = #dlti.dl_spec<#dlti.dl_entry<"dlti.alloca_memo
8484
}
8585
}
8686

87-
// DEVICE: @[[KERNEL_NAME:.*]]_exec_mode = weak protected constant i8 [[EXEC_MODE:3]]
87+
// DEVICE: @[[KERNEL_NAME:.*]]_exec_mode = weak protected constant i8 [[EXEC_MODE:1]]
8888
// DEVICE: @llvm.compiler.used = appending global [1 x ptr] [ptr @[[KERNEL_NAME]]_exec_mode], section "llvm.metadata"
8989
// DEVICE: @[[KERNEL_NAME]]_kernel_environment = weak_odr protected constant %struct.KernelEnvironmentTy {
9090
// DEVICE-SAME: %struct.ConfigurationEnvironmentTy { i8 1, i8 1, i8 [[EXEC_MODE]], {{.*}}},

0 commit comments

Comments
 (0)