@@ -2117,7 +2117,7 @@ TargetRegionFlags TargetOp::getKernelExecFlags(Operation *capturedOp) {
2117
2117
2118
2118
// If it's not capturing a loop, it's a default target region.
2119
2119
if (!isa_and_present<LoopNestOp>(capturedOp))
2120
- return TargetRegionFlags::generic ;
2120
+ return TargetRegionFlags::none ;
2121
2121
2122
2122
// Get the innermost non-simd loop wrapper.
2123
2123
SmallVector<LoopWrapperInterface> loopWrappers;
@@ -2130,24 +2130,24 @@ TargetRegionFlags TargetOp::getKernelExecFlags(Operation *capturedOp) {
2130
2130
2131
2131
auto numWrappers = std::distance (innermostWrapper, loopWrappers.end ());
2132
2132
if (numWrappers != 1 && numWrappers != 2 )
2133
- return TargetRegionFlags::generic ;
2133
+ return TargetRegionFlags::none ;
2134
2134
2135
2135
// Detect target-teams-distribute-parallel-wsloop[-simd].
2136
2136
if (numWrappers == 2 ) {
2137
2137
if (!isa<WsloopOp>(innermostWrapper))
2138
- return TargetRegionFlags::generic ;
2138
+ return TargetRegionFlags::none ;
2139
2139
2140
2140
innermostWrapper = std::next (innermostWrapper);
2141
2141
if (!isa<DistributeOp>(innermostWrapper))
2142
- return TargetRegionFlags::generic ;
2142
+ return TargetRegionFlags::none ;
2143
2143
2144
2144
Operation *parallelOp = (*innermostWrapper)->getParentOp ();
2145
2145
if (!isa_and_present<ParallelOp>(parallelOp))
2146
- return TargetRegionFlags::generic ;
2146
+ return TargetRegionFlags::none ;
2147
2147
2148
2148
Operation *teamsOp = parallelOp->getParentOp ();
2149
2149
if (!isa_and_present<TeamsOp>(teamsOp))
2150
- return TargetRegionFlags::generic ;
2150
+ return TargetRegionFlags::none ;
2151
2151
2152
2152
if (teamsOp->getParentOp () == targetOp.getOperation ())
2153
2153
return TargetRegionFlags::spmd | TargetRegionFlags::trip_count;
@@ -2156,53 +2156,27 @@ TargetRegionFlags TargetOp::getKernelExecFlags(Operation *capturedOp) {
2156
2156
else if (isa<DistributeOp, LoopOp>(innermostWrapper)) {
2157
2157
Operation *teamsOp = (*innermostWrapper)->getParentOp ();
2158
2158
if (!isa_and_present<TeamsOp>(teamsOp))
2159
- return TargetRegionFlags::generic ;
2159
+ return TargetRegionFlags::none ;
2160
2160
2161
2161
if (teamsOp->getParentOp () != targetOp.getOperation ())
2162
- return TargetRegionFlags::generic ;
2162
+ return TargetRegionFlags::none ;
2163
2163
2164
2164
if (isa<LoopOp>(innermostWrapper))
2165
2165
return TargetRegionFlags::spmd | TargetRegionFlags::trip_count;
2166
2166
2167
- // Find single immediately nested captured omp.parallel and add spmd flag
2168
- // (generic-spmd case).
2169
- //
2170
- // TODO: This shouldn't have to be done here, as it is too easy to break.
2171
- // The openmp-opt pass should be updated to be able to promote kernels like
2172
- // this from "Generic" to "Generic-SPMD". However, the use of the
2173
- // `kmpc_distribute_static_loop` family of functions produced by the
2174
- // OMPIRBuilder for these kernels prevents that from working.
2175
- Dialect *ompDialect = targetOp->getDialect ();
2176
- Operation *nestedCapture = findCapturedOmpOp (
2177
- capturedOp, /* checkSingleMandatoryExec=*/ false ,
2178
- [&](Operation *sibling) {
2179
- return sibling && (ompDialect != sibling->getDialect () ||
2180
- sibling->hasTrait <OpTrait::IsTerminator>());
2181
- });
2182
-
2183
- TargetRegionFlags result =
2184
- TargetRegionFlags::generic | TargetRegionFlags::trip_count;
2185
-
2186
- if (!nestedCapture)
2187
- return result;
2188
-
2189
- while (nestedCapture->getParentOp () != capturedOp)
2190
- nestedCapture = nestedCapture->getParentOp ();
2191
-
2192
- return isa<ParallelOp>(nestedCapture) ? result | TargetRegionFlags::spmd
2193
- : result;
2167
+ return TargetRegionFlags::trip_count;
2194
2168
}
2195
2169
// Detect target-parallel-wsloop[-simd].
2196
2170
else if (isa<WsloopOp>(innermostWrapper)) {
2197
2171
Operation *parallelOp = (*innermostWrapper)->getParentOp ();
2198
2172
if (!isa_and_present<ParallelOp>(parallelOp))
2199
- return TargetRegionFlags::generic ;
2173
+ return TargetRegionFlags::none ;
2200
2174
2201
2175
if (parallelOp->getParentOp () == targetOp.getOperation ())
2202
2176
return TargetRegionFlags::spmd;
2203
2177
}
2204
2178
2205
- return TargetRegionFlags::generic ;
2179
+ return TargetRegionFlags::none ;
2206
2180
}
2207
2181
2208
2182
// ===----------------------------------------------------------------------===//
0 commit comments