Skip to content

Commit 9e948a5

Browse files
committed
Update TargetRegionFlags to mirror OMPTgtExecModeFlags
1 parent 39800fa commit 9e948a5

File tree

4 files changed

+64
-41
lines changed

4 files changed

+64
-41
lines changed

mlir/include/mlir/Dialect/OpenMP/OpenMPEnums.td

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -223,19 +223,19 @@ def ScheduleModifier : OpenMP_I32EnumAttr<
223223
def ScheduleModifierAttr : OpenMP_EnumAttr<ScheduleModifier, "sched_mod">;
224224

225225
//===----------------------------------------------------------------------===//
226-
// target_region_flags enum.
226+
// target_exec_mode enum.
227227
//===----------------------------------------------------------------------===//
228228

229-
def TargetRegionFlagsNone : I32BitEnumAttrCaseNone<"none">;
230-
def TargetRegionFlagsSpmd : I32BitEnumAttrCaseBit<"spmd", 0>;
231-
def TargetRegionFlagsTripCount : I32BitEnumAttrCaseBit<"trip_count", 1>;
229+
def TargetExecModeBare : I32EnumAttrCase<"bare", 0>;
230+
def TargetExecModeGeneric : I32EnumAttrCase<"generic", 1>;
231+
def TargetExecModeSpmd : I32EnumAttrCase<"spmd", 2>;
232232

233-
def TargetRegionFlags : OpenMP_BitEnumAttr<
234-
"TargetRegionFlags",
235-
"target region property flags", [
236-
TargetRegionFlagsNone,
237-
TargetRegionFlagsSpmd,
238-
TargetRegionFlagsTripCount
233+
def TargetExecMode : OpenMP_I32EnumAttr<
234+
"TargetExecMode",
235+
"target execution mode, mirroring the `OMPTgtExecModeFlags` LLVM enum", [
236+
TargetExecModeBare,
237+
TargetExecModeGeneric,
238+
TargetExecModeSpmd,
239239
]>;
240240

241241
//===----------------------------------------------------------------------===//

mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1517,13 +1517,17 @@ def TargetOp : OpenMP_Op<"target", traits = [
15171517
/// operations, the top level one will be the one captured.
15181518
Operation *getInnermostCapturedOmpOp();
15191519

1520-
/// Infers the kernel type (Generic, SPMD or Generic-SPMD) based on the
1521-
/// contents of the target region.
1520+
/// Infers the kernel type (Generic or SPMD; Bare is not detected yet) based
1521+
/// on the contents of the target region.
15221522
///
15231523
/// \param capturedOp result of a still valid (no modifications made to any
15241524
/// nested operations) previous call to `getInnermostCapturedOmpOp()`.
1525-
static ::mlir::omp::TargetRegionFlags
1526-
getKernelExecFlags(Operation *capturedOp);
1525+
/// \param hostEvalTripCount optional (may be null) output argument set to
1526+
/// whether this kernel wraps a loop whose bounds must be evaluated on the
1527+
/// host prior to launching it.
1528+
static ::mlir::omp::TargetExecMode
1529+
getKernelExecFlags(Operation *capturedOp,
1530+
bool *hostEvalTripCount = nullptr);
15271531
}] # clausesExtraClassDeclaration;
15281532

15291533
let assemblyFormat = clausesAssemblyFormat # [{

mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp

Lines changed: 31 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -1974,8 +1974,9 @@ LogicalResult TargetOp::verifyRegions() {
19741974
return emitError("target containing multiple 'omp.teams' nested ops");
19751975

19761976
// Check that host_eval values are only used in legal ways.
1977+
bool hostEvalTripCount;
19771978
Operation *capturedOp = getInnermostCapturedOmpOp();
1978-
TargetRegionFlags execFlags = getKernelExecFlags(capturedOp);
1979+
TargetExecMode execMode = getKernelExecFlags(capturedOp, &hostEvalTripCount);
19791980
for (Value hostEvalArg :
19801981
cast<BlockArgOpenMPOpInterface>(getOperation()).getHostEvalBlockArgs()) {
19811982
for (Operation *user : hostEvalArg.getUsers()) {
@@ -1990,7 +1991,7 @@ LogicalResult TargetOp::verifyRegions() {
19901991
"and 'thread_limit' in 'omp.teams'";
19911992
}
19921993
if (auto parallelOp = dyn_cast<ParallelOp>(user)) {
1993-
if (bitEnumContainsAny(execFlags, TargetRegionFlags::spmd) &&
1994+
if (execMode == TargetExecMode::spmd &&
19941995
parallelOp->isAncestor(capturedOp) &&
19951996
hostEvalArg == parallelOp.getNumThreads())
19961997
continue;
@@ -2000,8 +2001,7 @@ LogicalResult TargetOp::verifyRegions() {
20002001
"'omp.parallel' when representing target SPMD";
20012002
}
20022003
if (auto loopNestOp = dyn_cast<LoopNestOp>(user)) {
2003-
if (bitEnumContainsAny(execFlags, TargetRegionFlags::trip_count) &&
2004-
loopNestOp.getOperation() == capturedOp &&
2004+
if (hostEvalTripCount && loopNestOp.getOperation() == capturedOp &&
20052005
(llvm::is_contained(loopNestOp.getLoopLowerBounds(), hostEvalArg) ||
20062006
llvm::is_contained(loopNestOp.getLoopUpperBounds(), hostEvalArg) ||
20072007
llvm::is_contained(loopNestOp.getLoopSteps(), hostEvalArg)))
@@ -2106,7 +2106,9 @@ Operation *TargetOp::getInnermostCapturedOmpOp() {
21062106
});
21072107
}
21082108

2109-
TargetRegionFlags TargetOp::getKernelExecFlags(Operation *capturedOp) {
2109+
TargetExecMode TargetOp::getKernelExecFlags(Operation *capturedOp,
2110+
bool *hostEvalTripCount) {
2111+
// TODO: Support detection of bare kernel mode.
21102112
// A non-null captured op is only valid if it resides inside of a TargetOp
21112113
// and is the result of calling getInnermostCapturedOmpOp() on it.
21122114
TargetOp targetOp =
@@ -2115,9 +2117,12 @@ TargetRegionFlags TargetOp::getKernelExecFlags(Operation *capturedOp) {
21152117
(targetOp && targetOp.getInnermostCapturedOmpOp() == capturedOp)) &&
21162118
"unexpected captured op");
21172119

2120+
if (hostEvalTripCount)
2121+
*hostEvalTripCount = false;
2122+
21182123
// If it's not capturing a loop, it's a generic target region.
21192124
if (!isa_and_present<LoopNestOp>(capturedOp))
2120-
return TargetRegionFlags::none;
2125+
return TargetExecMode::generic;
21212126

21222127
// Get the innermost non-simd loop wrapper.
21232128
SmallVector<LoopWrapperInterface> loopWrappers;
@@ -2130,53 +2135,59 @@ TargetRegionFlags TargetOp::getKernelExecFlags(Operation *capturedOp) {
21302135

21312136
auto numWrappers = std::distance(innermostWrapper, loopWrappers.end());
21322137
if (numWrappers != 1 && numWrappers != 2)
2133-
return TargetRegionFlags::none;
2138+
return TargetExecMode::generic;
21342139

21352140
// Detect target-teams-distribute-parallel-wsloop[-simd].
21362141
if (numWrappers == 2) {
21372142
if (!isa<WsloopOp>(innermostWrapper))
2138-
return TargetRegionFlags::none;
2143+
return TargetExecMode::generic;
21392144

21402145
innermostWrapper = std::next(innermostWrapper);
21412146
if (!isa<DistributeOp>(innermostWrapper))
2142-
return TargetRegionFlags::none;
2147+
return TargetExecMode::generic;
21432148

21442149
Operation *parallelOp = (*innermostWrapper)->getParentOp();
21452150
if (!isa_and_present<ParallelOp>(parallelOp))
2146-
return TargetRegionFlags::none;
2151+
return TargetExecMode::generic;
21472152

21482153
Operation *teamsOp = parallelOp->getParentOp();
21492154
if (!isa_and_present<TeamsOp>(teamsOp))
2150-
return TargetRegionFlags::none;
2155+
return TargetExecMode::generic;
21512156

2152-
if (teamsOp->getParentOp() == targetOp.getOperation())
2153-
return TargetRegionFlags::spmd | TargetRegionFlags::trip_count;
2157+
if (teamsOp->getParentOp() == targetOp.getOperation()) {
2158+
if (hostEvalTripCount)
2159+
*hostEvalTripCount = true;
2160+
return TargetExecMode::spmd;
2161+
}
21542162
}
21552163
// Detect target-teams-distribute[-simd] and target-teams-loop.
21562164
else if (isa<DistributeOp, LoopOp>(innermostWrapper)) {
21572165
Operation *teamsOp = (*innermostWrapper)->getParentOp();
21582166
if (!isa_and_present<TeamsOp>(teamsOp))
2159-
return TargetRegionFlags::none;
2167+
return TargetExecMode::generic;
21602168

21612169
if (teamsOp->getParentOp() != targetOp.getOperation())
2162-
return TargetRegionFlags::none;
2170+
return TargetExecMode::generic;
2171+
2172+
if (hostEvalTripCount)
2173+
*hostEvalTripCount = true;
21632174

21642175
if (isa<LoopOp>(innermostWrapper))
2165-
return TargetRegionFlags::spmd | TargetRegionFlags::trip_count;
2176+
return TargetExecMode::spmd;
21662177

2167-
return TargetRegionFlags::trip_count;
2178+
return TargetExecMode::generic;
21682179
}
21692180
// Detect target-parallel-wsloop[-simd].
21702181
else if (isa<WsloopOp>(innermostWrapper)) {
21712182
Operation *parallelOp = (*innermostWrapper)->getParentOp();
21722183
if (!isa_and_present<ParallelOp>(parallelOp))
2173-
return TargetRegionFlags::none;
2184+
return TargetExecMode::generic;
21742185

21752186
if (parallelOp->getParentOp() == targetOp.getOperation())
2176-
return TargetRegionFlags::spmd;
2187+
return TargetExecMode::spmd;
21772188
}
21782189

2179-
return TargetRegionFlags::none;
2190+
return TargetExecMode::generic;
21802191
}
21812192

21822193
//===----------------------------------------------------------------------===//

mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp

Lines changed: 15 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -5354,11 +5354,18 @@ initTargetDefaultAttrs(omp::TargetOp targetOp, Operation *capturedOp,
53545354
}
53555355

53565356
// Update kernel bounds structure for the `OpenMPIRBuilder` to use.
5357-
omp::TargetRegionFlags kernelFlags = targetOp.getKernelExecFlags(capturedOp);
5358-
attrs.ExecFlags =
5359-
omp::bitEnumContainsAny(kernelFlags, omp::TargetRegionFlags::spmd)
5360-
? llvm::omp::OMP_TGT_EXEC_MODE_SPMD
5361-
: llvm::omp::OMP_TGT_EXEC_MODE_GENERIC;
5357+
omp::TargetExecMode execMode = targetOp.getKernelExecFlags(capturedOp);
5358+
switch (execMode) {
5359+
case omp::TargetExecMode::bare:
5360+
attrs.ExecFlags = llvm::omp::OMP_TGT_EXEC_MODE_BARE;
5361+
break;
5362+
case omp::TargetExecMode::generic:
5363+
attrs.ExecFlags = llvm::omp::OMP_TGT_EXEC_MODE_GENERIC;
5364+
break;
5365+
case omp::TargetExecMode::spmd:
5366+
attrs.ExecFlags = llvm::omp::OMP_TGT_EXEC_MODE_SPMD;
5367+
break;
5368+
}
53625369
attrs.MinTeams = minTeamsVal;
53635370
attrs.MaxTeams.front() = maxTeamsVal;
53645371
attrs.MinThreads = 1;
@@ -5408,8 +5415,9 @@ initTargetRuntimeAttrs(llvm::IRBuilderBase &builder,
54085415
if (numThreads)
54095416
attrs.MaxThreads = moduleTranslation.lookupValue(numThreads);
54105417

5411-
if (omp::bitEnumContainsAny(targetOp.getKernelExecFlags(capturedOp),
5412-
omp::TargetRegionFlags::trip_count)) {
5418+
bool hostEvalTripCount;
5419+
targetOp.getKernelExecFlags(capturedOp, &hostEvalTripCount);
5420+
if (hostEvalTripCount) {
54135421
llvm::OpenMPIRBuilder *ompBuilder = moduleTranslation.getOpenMPBuilder();
54145422
attrs.LoopTripCount = nullptr;
54155423

0 commit comments

Comments
 (0)