Skip to content

Commit e3a67f6

Browse files
committed
Update TargetRegionFlags to mirror OMPTgtExecModeFlags
1 parent ad2a460 commit e3a67f6

File tree

5 files changed

+73
-61
lines changed

5 files changed

+73
-61
lines changed

llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6775,7 +6775,8 @@ OpenMPIRBuilder::InsertPointTy OpenMPIRBuilder::createTargetInit(
67756775
Constant *Ident = getOrCreateIdent(SrcLocStr, SrcLocStrSize);
67766776
Constant *IsSPMDVal = ConstantInt::getSigned(Int8, Attrs.ExecFlags);
67776777
Constant *UseGenericStateMachineVal = ConstantInt::getSigned(
6778-
Int8, Attrs.ExecFlags != omp::OMP_TGT_EXEC_MODE_SPMD);
6778+
Int8, Attrs.ExecFlags != omp::OMP_TGT_EXEC_MODE_SPMD &&
6779+
Attrs.ExecFlags != omp::OMP_TGT_EXEC_MODE_SPMD_NO_LOOP);
67796780
Constant *MayUseNestedParallelismVal = ConstantInt::getSigned(Int8, true);
67806781
Constant *DebugIndentionLevelVal = ConstantInt::getSigned(Int16, 0);
67816782

mlir/include/mlir/Dialect/OpenMP/OpenMPEnums.td

Lines changed: 13 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -223,28 +223,21 @@ def ScheduleModifier : OpenMP_I32EnumAttr<
223223
def ScheduleModifierAttr : OpenMP_EnumAttr<ScheduleModifier, "sched_mod">;
224224

225225
//===----------------------------------------------------------------------===//
226-
// target_region_flags enum.
226+
// target_exec_mode enum.
227227
//===----------------------------------------------------------------------===//
228228

229-
def TargetRegionFlagsNone : I32BitEnumAttrCaseNone<"none">;
230-
def TargetRegionFlagsSpmd : I32BitEnumAttrCaseBit<"spmd", 0>;
231-
def TargetRegionFlagsTripCount : I32BitEnumAttrCaseBit<"trip_count", 1>;
232-
def TargetRegionFlagsNoLoop : I32BitEnumAttrCaseBit<"no_loop", 2>;
233-
234-
def TargetRegionFlags : OpenMP_BitEnumAttr<
235-
"TargetRegionFlags",
236-
"These flags describe properties of the target kernel. "
237-
"TargetRegionFlagsSpmd - denotes SPMD kernel. "
238-
"TargetRegionFlagsNoLoop - denotes kernel where "
239-
"num_teams * num_threads >= loop_trip_count. It allows the conversion "
240-
"of loops into sequential code by ensuring that each team/thread "
241-
"executes at most one iteration. "
242-
"TargetRegionFlagsTripCount - checks if a singular loop trip count should "
243-
"be calculated for the target region.", [
244-
TargetRegionFlagsNone,
245-
TargetRegionFlagsSpmd,
246-
TargetRegionFlagsTripCount,
247-
TargetRegionFlagsNoLoop
229+
def TargetExecModeBare : I32EnumAttrCase<"bare", 0>;
230+
def TargetExecModeGeneric : I32EnumAttrCase<"generic", 1>;
231+
def TargetExecModeSpmd : I32EnumAttrCase<"spmd", 2>;
232+
def TargetExecModeSpmdNoLoop : I32EnumAttrCase<"no_loop", 3>;
233+
234+
def TargetExecMode : OpenMP_I32EnumAttr<
235+
"TargetExecMode",
236+
"target execution mode, mirroring the `OMPTgtExecModeFlags` LLVM enum", [
237+
TargetExecModeBare,
238+
TargetExecModeGeneric,
239+
TargetExecModeSpmd,
240+
TargetExecModeSpmdNoLoop,
248241
]>;
249242

250243
//===----------------------------------------------------------------------===//

mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1522,13 +1522,17 @@ def TargetOp : OpenMP_Op<"target", traits = [
15221522
/// operations, the top level one will be the one captured.
15231523
Operation *getInnermostCapturedOmpOp();
15241524

1525-
/// Infers the kernel type (Generic, SPMD or Generic-SPMD) based on the
1526-
/// contents of the target region.
1525+
/// Infers the kernel type (Bare, Generic, SPMD or SPMD-No-Loop) based on
1526+
/// the contents of the target region.
15271527
///
15281528
/// \param capturedOp result of a still valid (no modifications made to any
15291529
/// nested operations) previous call to `getInnermostCapturedOmpOp()`.
1530-
static ::mlir::omp::TargetRegionFlags
1531-
getKernelExecFlags(Operation *capturedOp);
1530+
/// \param hostEvalTripCount optional output argument (may be null) to store
1531+
/// whether this kernel wraps a loop whose bounds must be evaluated on the
1532+
/// host prior to launching it.
1533+
static ::mlir::omp::TargetExecMode
1534+
getKernelExecFlags(Operation *capturedOp,
1535+
bool *hostEvalTripCount = nullptr);
15321536
}] # clausesExtraClassDeclaration;
15331537

15341538
let assemblyFormat = clausesAssemblyFormat # [{

mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp

Lines changed: 30 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -2205,8 +2205,9 @@ LogicalResult TargetOp::verifyRegions() {
22052205
return emitError("target containing multiple 'omp.teams' nested ops");
22062206

22072207
// Check that host_eval values are only used in legal ways.
2208+
bool hostEvalTripCount;
22082209
Operation *capturedOp = getInnermostCapturedOmpOp();
2209-
TargetRegionFlags execFlags = getKernelExecFlags(capturedOp);
2210+
TargetExecMode execMode = getKernelExecFlags(capturedOp, &hostEvalTripCount);
22102211
for (Value hostEvalArg :
22112212
cast<BlockArgOpenMPOpInterface>(getOperation()).getHostEvalBlockArgs()) {
22122213
for (Operation *user : hostEvalArg.getUsers()) {
@@ -2221,7 +2222,7 @@ LogicalResult TargetOp::verifyRegions() {
22212222
"and 'thread_limit' in 'omp.teams'";
22222223
}
22232224
if (auto parallelOp = dyn_cast<ParallelOp>(user)) {
2224-
if (bitEnumContainsAny(execFlags, TargetRegionFlags::spmd) &&
2225+
if (execMode == TargetExecMode::spmd &&
22252226
parallelOp->isAncestor(capturedOp) &&
22262227
hostEvalArg == parallelOp.getNumThreads())
22272228
continue;
@@ -2231,8 +2232,7 @@ LogicalResult TargetOp::verifyRegions() {
22312232
"'omp.parallel' when representing target SPMD";
22322233
}
22332234
if (auto loopNestOp = dyn_cast<LoopNestOp>(user)) {
2234-
if (bitEnumContainsAny(execFlags, TargetRegionFlags::trip_count) &&
2235-
loopNestOp.getOperation() == capturedOp &&
2235+
if (hostEvalTripCount && loopNestOp.getOperation() == capturedOp &&
22362236
(llvm::is_contained(loopNestOp.getLoopLowerBounds(), hostEvalArg) ||
22372237
llvm::is_contained(loopNestOp.getLoopUpperBounds(), hostEvalArg) ||
22382238
llvm::is_contained(loopNestOp.getLoopSteps(), hostEvalArg)))
@@ -2362,7 +2362,9 @@ static bool canPromoteToNoLoop(Operation *capturedOp, TeamsOp teamsOp,
23622362
ompFlags.getAssumeThreadsOversubscription();
23632363
}
23642364

2365-
TargetRegionFlags TargetOp::getKernelExecFlags(Operation *capturedOp) {
2365+
TargetExecMode TargetOp::getKernelExecFlags(Operation *capturedOp,
2366+
bool *hostEvalTripCount) {
2367+
// TODO: Support detection of bare kernel mode.
23662368
// A non-null captured op is only valid if it resides inside of a TargetOp
23672369
// and is the result of calling getInnermostCapturedOmpOp() on it.
23682370
TargetOp targetOp =
@@ -2371,9 +2373,12 @@ TargetRegionFlags TargetOp::getKernelExecFlags(Operation *capturedOp) {
23712373
(targetOp && targetOp.getInnermostCapturedOmpOp() == capturedOp)) &&
23722374
"unexpected captured op");
23732375

2376+
if (hostEvalTripCount)
2377+
*hostEvalTripCount = false;
2378+
23742379
// If it's not capturing a loop, it's a generic (non-SPMD) target region.
23752380
if (!isa_and_present<LoopNestOp>(capturedOp))
2376-
return TargetRegionFlags::none;
2381+
return TargetExecMode::generic;
23772382

23782383
// Get the innermost non-simd loop wrapper.
23792384
SmallVector<LoopWrapperInterface> loopWrappers;
@@ -2386,59 +2391,63 @@ TargetRegionFlags TargetOp::getKernelExecFlags(Operation *capturedOp) {
23862391

23872392
auto numWrappers = std::distance(innermostWrapper, loopWrappers.end());
23882393
if (numWrappers != 1 && numWrappers != 2)
2389-
return TargetRegionFlags::none;
2394+
return TargetExecMode::generic;
23902395

23912396
// Detect target-teams-distribute-parallel-wsloop[-simd].
23922397
if (numWrappers == 2) {
23932398
WsloopOp *wsloopOp = dyn_cast<WsloopOp>(innermostWrapper);
23942399
if (!wsloopOp)
2395-
return TargetRegionFlags::none;
2400+
return TargetExecMode::generic;
23962401

23972402
innermostWrapper = std::next(innermostWrapper);
23982403
if (!isa<DistributeOp>(innermostWrapper))
2399-
return TargetRegionFlags::none;
2404+
return TargetExecMode::generic;
24002405

24012406
Operation *parallelOp = (*innermostWrapper)->getParentOp();
24022407
if (!isa_and_present<ParallelOp>(parallelOp))
2403-
return TargetRegionFlags::none;
2408+
return TargetExecMode::generic;
24042409

24052410
TeamsOp teamsOp = dyn_cast<TeamsOp>(parallelOp->getParentOp());
24062411
if (!teamsOp)
2407-
return TargetRegionFlags::none;
2412+
return TargetExecMode::generic;
24082413

24092414
if (teamsOp->getParentOp() == targetOp.getOperation()) {
2410-
TargetRegionFlags result =
2411-
TargetRegionFlags::spmd | TargetRegionFlags::trip_count;
2415+
TargetExecMode result = TargetExecMode::spmd;
24122416
if (canPromoteToNoLoop(capturedOp, teamsOp, wsloopOp))
2413-
result = result | TargetRegionFlags::no_loop;
2417+
result = TargetExecMode::no_loop;
2418+
if (hostEvalTripCount)
2419+
*hostEvalTripCount = true;
24142420
return result;
24152421
}
24162422
}
24172423
// Detect target-teams-distribute[-simd] and target-teams-loop.
24182424
else if (isa<DistributeOp, LoopOp>(innermostWrapper)) {
24192425
Operation *teamsOp = (*innermostWrapper)->getParentOp();
24202426
if (!isa_and_present<TeamsOp>(teamsOp))
2421-
return TargetRegionFlags::none;
2427+
return TargetExecMode::generic;
24222428

24232429
if (teamsOp->getParentOp() != targetOp.getOperation())
2424-
return TargetRegionFlags::none;
2430+
return TargetExecMode::generic;
2431+
2432+
if (hostEvalTripCount)
2433+
*hostEvalTripCount = true;
24252434

24262435
if (isa<LoopOp>(innermostWrapper))
2427-
return TargetRegionFlags::spmd | TargetRegionFlags::trip_count;
2436+
return TargetExecMode::spmd;
24282437

2429-
return TargetRegionFlags::trip_count;
2438+
return TargetExecMode::generic;
24302439
}
24312440
// Detect target-parallel-wsloop[-simd].
24322441
else if (isa<WsloopOp>(innermostWrapper)) {
24332442
Operation *parallelOp = (*innermostWrapper)->getParentOp();
24342443
if (!isa_and_present<ParallelOp>(parallelOp))
2435-
return TargetRegionFlags::none;
2444+
return TargetExecMode::generic;
24362445

24372446
if (parallelOp->getParentOp() == targetOp.getOperation())
2438-
return TargetRegionFlags::spmd;
2447+
return TargetExecMode::spmd;
24392448
}
24402449

2441-
return TargetRegionFlags::none;
2450+
return TargetExecMode::generic;
24422451
}
24432452

24442453
//===----------------------------------------------------------------------===//

mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp

Lines changed: 20 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -2601,11 +2601,8 @@ convertOmpWsloop(Operation &opInst, llvm::IRBuilderBase &builder,
26012601
// for every omp.wsloop nested inside a no-loop SPMD target region, even if
26022602
// that loop is not the top-level SPMD one.
26032603
if (loopOp == targetCapturedOp) {
2604-
omp::TargetRegionFlags kernelFlags =
2605-
targetOp.getKernelExecFlags(targetCapturedOp);
2606-
if (omp::bitEnumContainsAll(kernelFlags,
2607-
omp::TargetRegionFlags::spmd |
2608-
omp::TargetRegionFlags::no_loop))
2604+
if (targetOp.getKernelExecFlags(targetCapturedOp) ==
2605+
omp::TargetExecMode::no_loop)
26092606
noLoopMode = true;
26102607
}
26112608
}
@@ -5435,14 +5432,21 @@ initTargetDefaultAttrs(omp::TargetOp targetOp, Operation *capturedOp,
54355432
}
54365433

54375434
// Update kernel bounds structure for the `OpenMPIRBuilder` to use.
5438-
omp::TargetRegionFlags kernelFlags = targetOp.getKernelExecFlags(capturedOp);
5439-
attrs.ExecFlags =
5440-
omp::bitEnumContainsAny(kernelFlags, omp::TargetRegionFlags::spmd)
5441-
? omp::bitEnumContainsAny(kernelFlags, omp::TargetRegionFlags::no_loop)
5442-
? llvm::omp::OMP_TGT_EXEC_MODE_SPMD_NO_LOOP
5443-
: llvm::omp::OMP_TGT_EXEC_MODE_SPMD
5444-
: llvm::omp::OMP_TGT_EXEC_MODE_GENERIC;
5445-
5435+
omp::TargetExecMode execMode = targetOp.getKernelExecFlags(capturedOp);
5436+
switch (execMode) {
5437+
case omp::TargetExecMode::bare:
5438+
attrs.ExecFlags = llvm::omp::OMP_TGT_EXEC_MODE_BARE;
5439+
break;
5440+
case omp::TargetExecMode::generic:
5441+
attrs.ExecFlags = llvm::omp::OMP_TGT_EXEC_MODE_GENERIC;
5442+
break;
5443+
case omp::TargetExecMode::spmd:
5444+
attrs.ExecFlags = llvm::omp::OMP_TGT_EXEC_MODE_SPMD;
5445+
break;
5446+
case omp::TargetExecMode::no_loop:
5447+
attrs.ExecFlags = llvm::omp::OMP_TGT_EXEC_MODE_SPMD_NO_LOOP;
5448+
break;
5449+
}
54465450
attrs.MinTeams = minTeamsVal;
54475451
attrs.MaxTeams.front() = maxTeamsVal;
54485452
attrs.MinThreads = 1;
@@ -5492,8 +5496,9 @@ initTargetRuntimeAttrs(llvm::IRBuilderBase &builder,
54925496
if (numThreads)
54935497
attrs.MaxThreads = moduleTranslation.lookupValue(numThreads);
54945498

5495-
if (omp::bitEnumContainsAny(targetOp.getKernelExecFlags(capturedOp),
5496-
omp::TargetRegionFlags::trip_count)) {
5499+
bool hostEvalTripCount;
5500+
targetOp.getKernelExecFlags(capturedOp, &hostEvalTripCount);
5501+
if (hostEvalTripCount) {
54975502
llvm::OpenMPIRBuilder *ompBuilder = moduleTranslation.getOpenMPBuilder();
54985503
attrs.LoopTripCount = nullptr;
54995504

0 commit comments

Comments
 (0)