@@ -1974,8 +1974,9 @@ LogicalResult TargetOp::verifyRegions() {
1974
1974
return emitError (" target containing multiple 'omp.teams' nested ops" );
1975
1975
1976
1976
// Check that host_eval values are only used in legal ways.
1977
+ bool hostEvalTripCount;
1977
1978
Operation *capturedOp = getInnermostCapturedOmpOp ();
1978
- TargetRegionFlags execFlags = getKernelExecFlags (capturedOp);
1979
+ TargetExecMode execMode = getKernelExecFlags (capturedOp, &hostEvalTripCount );
1979
1980
for (Value hostEvalArg :
1980
1981
cast<BlockArgOpenMPOpInterface>(getOperation ()).getHostEvalBlockArgs ()) {
1981
1982
for (Operation *user : hostEvalArg.getUsers ()) {
@@ -1990,7 +1991,7 @@ LogicalResult TargetOp::verifyRegions() {
1990
1991
" and 'thread_limit' in 'omp.teams'" ;
1991
1992
}
1992
1993
if (auto parallelOp = dyn_cast<ParallelOp>(user)) {
1993
- if (bitEnumContainsAny (execFlags, TargetRegionFlags ::spmd) &&
1994
+ if (execMode == TargetExecMode ::spmd &&
1994
1995
parallelOp->isAncestor (capturedOp) &&
1995
1996
hostEvalArg == parallelOp.getNumThreads ())
1996
1997
continue ;
@@ -2000,8 +2001,7 @@ LogicalResult TargetOp::verifyRegions() {
2000
2001
" 'omp.parallel' when representing target SPMD" ;
2001
2002
}
2002
2003
if (auto loopNestOp = dyn_cast<LoopNestOp>(user)) {
2003
- if (bitEnumContainsAny (execFlags, TargetRegionFlags::trip_count) &&
2004
- loopNestOp.getOperation () == capturedOp &&
2004
+ if (hostEvalTripCount && loopNestOp.getOperation () == capturedOp &&
2005
2005
(llvm::is_contained (loopNestOp.getLoopLowerBounds (), hostEvalArg) ||
2006
2006
llvm::is_contained (loopNestOp.getLoopUpperBounds (), hostEvalArg) ||
2007
2007
llvm::is_contained (loopNestOp.getLoopSteps (), hostEvalArg)))
@@ -2106,7 +2106,9 @@ Operation *TargetOp::getInnermostCapturedOmpOp() {
2106
2106
});
2107
2107
}
2108
2108
2109
- TargetRegionFlags TargetOp::getKernelExecFlags (Operation *capturedOp) {
2109
+ TargetExecMode TargetOp::getKernelExecFlags (Operation *capturedOp,
2110
+ bool *hostEvalTripCount) {
2111
+ // TODO: Support detection of bare kernel mode.
2110
2112
// A non-null captured op is only valid if it resides inside of a TargetOp
2111
2113
// and is the result of calling getInnermostCapturedOmpOp() on it.
2112
2114
TargetOp targetOp =
@@ -2115,9 +2117,12 @@ TargetRegionFlags TargetOp::getKernelExecFlags(Operation *capturedOp) {
2115
2117
(targetOp && targetOp.getInnermostCapturedOmpOp () == capturedOp)) &&
2116
2118
" unexpected captured op" );
2117
2119
2120
+ if (hostEvalTripCount)
2121
+ *hostEvalTripCount = false ;
2122
+
2118
2123
// If it's not capturing a loop, it's a default target region.
2119
2124
if (!isa_and_present<LoopNestOp>(capturedOp))
2120
- return TargetRegionFlags::none ;
2125
+ return TargetExecMode::generic ;
2121
2126
2122
2127
// Get the innermost non-simd loop wrapper.
2123
2128
SmallVector<LoopWrapperInterface> loopWrappers;
@@ -2130,53 +2135,59 @@ TargetRegionFlags TargetOp::getKernelExecFlags(Operation *capturedOp) {
2130
2135
2131
2136
auto numWrappers = std::distance (innermostWrapper, loopWrappers.end ());
2132
2137
if (numWrappers != 1 && numWrappers != 2 )
2133
- return TargetRegionFlags::none ;
2138
+ return TargetExecMode::generic ;
2134
2139
2135
2140
// Detect target-teams-distribute-parallel-wsloop[-simd].
2136
2141
if (numWrappers == 2 ) {
2137
2142
if (!isa<WsloopOp>(innermostWrapper))
2138
- return TargetRegionFlags::none ;
2143
+ return TargetExecMode::generic ;
2139
2144
2140
2145
innermostWrapper = std::next (innermostWrapper);
2141
2146
if (!isa<DistributeOp>(innermostWrapper))
2142
- return TargetRegionFlags::none ;
2147
+ return TargetExecMode::generic ;
2143
2148
2144
2149
Operation *parallelOp = (*innermostWrapper)->getParentOp ();
2145
2150
if (!isa_and_present<ParallelOp>(parallelOp))
2146
- return TargetRegionFlags::none ;
2151
+ return TargetExecMode::generic ;
2147
2152
2148
2153
Operation *teamsOp = parallelOp->getParentOp ();
2149
2154
if (!isa_and_present<TeamsOp>(teamsOp))
2150
- return TargetRegionFlags::none ;
2155
+ return TargetExecMode::generic ;
2151
2156
2152
- if (teamsOp->getParentOp () == targetOp.getOperation ())
2153
- return TargetRegionFlags::spmd | TargetRegionFlags::trip_count;
2157
+ if (teamsOp->getParentOp () == targetOp.getOperation ()) {
2158
+ if (hostEvalTripCount)
2159
+ *hostEvalTripCount = true ;
2160
+ return TargetExecMode::spmd;
2161
+ }
2154
2162
}
2155
2163
// Detect target-teams-distribute[-simd] and target-teams-loop.
2156
2164
else if (isa<DistributeOp, LoopOp>(innermostWrapper)) {
2157
2165
Operation *teamsOp = (*innermostWrapper)->getParentOp ();
2158
2166
if (!isa_and_present<TeamsOp>(teamsOp))
2159
- return TargetRegionFlags::none ;
2167
+ return TargetExecMode::generic ;
2160
2168
2161
2169
if (teamsOp->getParentOp () != targetOp.getOperation ())
2162
- return TargetRegionFlags::none;
2170
+ return TargetExecMode::generic;
2171
+
2172
+ if (hostEvalTripCount)
2173
+ *hostEvalTripCount = true ;
2163
2174
2164
2175
if (isa<LoopOp>(innermostWrapper))
2165
- return TargetRegionFlags ::spmd | TargetRegionFlags::trip_count ;
2176
+ return TargetExecMode ::spmd;
2166
2177
2167
- return TargetRegionFlags::trip_count ;
2178
+ return TargetExecMode::generic ;
2168
2179
}
2169
2180
// Detect target-parallel-wsloop[-simd].
2170
2181
else if (isa<WsloopOp>(innermostWrapper)) {
2171
2182
Operation *parallelOp = (*innermostWrapper)->getParentOp ();
2172
2183
if (!isa_and_present<ParallelOp>(parallelOp))
2173
- return TargetRegionFlags::none ;
2184
+ return TargetExecMode::generic ;
2174
2185
2175
2186
if (parallelOp->getParentOp () == targetOp.getOperation ())
2176
- return TargetRegionFlags ::spmd;
2187
+ return TargetExecMode ::spmd;
2177
2188
}
2178
2189
2179
- return TargetRegionFlags::none ;
2190
+ return TargetExecMode::generic ;
2180
2191
}
2181
2192
2182
2193
// ===----------------------------------------------------------------------===//
0 commit comments