@@ -2205,8 +2205,9 @@ LogicalResult TargetOp::verifyRegions() {
2205
2205
return emitError (" target containing multiple 'omp.teams' nested ops" );
2206
2206
2207
2207
// Check that host_eval values are only used in legal ways.
2208
+ bool hostEvalTripCount;
2208
2209
Operation *capturedOp = getInnermostCapturedOmpOp ();
2209
- TargetRegionFlags execFlags = getKernelExecFlags (capturedOp);
2210
+ TargetExecMode execMode = getKernelExecFlags (capturedOp, &hostEvalTripCount );
2210
2211
for (Value hostEvalArg :
2211
2212
cast<BlockArgOpenMPOpInterface>(getOperation ()).getHostEvalBlockArgs ()) {
2212
2213
for (Operation *user : hostEvalArg.getUsers ()) {
@@ -2221,7 +2222,7 @@ LogicalResult TargetOp::verifyRegions() {
2221
2222
" and 'thread_limit' in 'omp.teams'" ;
2222
2223
}
2223
2224
if (auto parallelOp = dyn_cast<ParallelOp>(user)) {
2224
- if (bitEnumContainsAny (execFlags, TargetRegionFlags ::spmd) &&
2225
+ if (execMode == TargetExecMode ::spmd &&
2225
2226
parallelOp->isAncestor (capturedOp) &&
2226
2227
hostEvalArg == parallelOp.getNumThreads ())
2227
2228
continue ;
@@ -2231,8 +2232,7 @@ LogicalResult TargetOp::verifyRegions() {
2231
2232
" 'omp.parallel' when representing target SPMD" ;
2232
2233
}
2233
2234
if (auto loopNestOp = dyn_cast<LoopNestOp>(user)) {
2234
- if (bitEnumContainsAny (execFlags, TargetRegionFlags::trip_count) &&
2235
- loopNestOp.getOperation () == capturedOp &&
2235
+ if (hostEvalTripCount && loopNestOp.getOperation () == capturedOp &&
2236
2236
(llvm::is_contained (loopNestOp.getLoopLowerBounds (), hostEvalArg) ||
2237
2237
llvm::is_contained (loopNestOp.getLoopUpperBounds (), hostEvalArg) ||
2238
2238
llvm::is_contained (loopNestOp.getLoopSteps (), hostEvalArg)))
@@ -2362,7 +2362,9 @@ static bool canPromoteToNoLoop(Operation *capturedOp, TeamsOp teamsOp,
2362
2362
ompFlags.getAssumeThreadsOversubscription ();
2363
2363
}
2364
2364
2365
- TargetRegionFlags TargetOp::getKernelExecFlags (Operation *capturedOp) {
2365
+ TargetExecMode TargetOp::getKernelExecFlags (Operation *capturedOp,
2366
+ bool *hostEvalTripCount) {
2367
+ // TODO: Support detection of bare kernel mode.
2366
2368
// A non-null captured op is only valid if it resides inside of a TargetOp
2367
2369
// and is the result of calling getInnermostCapturedOmpOp() on it.
2368
2370
TargetOp targetOp =
@@ -2371,9 +2373,12 @@ TargetRegionFlags TargetOp::getKernelExecFlags(Operation *capturedOp) {
2371
2373
(targetOp && targetOp.getInnermostCapturedOmpOp () == capturedOp)) &&
2372
2374
" unexpected captured op" );
2373
2375
2376
+ if (hostEvalTripCount)
2377
+ *hostEvalTripCount = false ;
2378
+
2374
2379
// If it's not capturing a loop, it's a default target region.
2375
2380
if (!isa_and_present<LoopNestOp>(capturedOp))
2376
- return TargetRegionFlags::none ;
2381
+ return TargetExecMode::generic ;
2377
2382
2378
2383
// Get the innermost non-simd loop wrapper.
2379
2384
SmallVector<LoopWrapperInterface> loopWrappers;
@@ -2386,59 +2391,63 @@ TargetRegionFlags TargetOp::getKernelExecFlags(Operation *capturedOp) {
2386
2391
2387
2392
auto numWrappers = std::distance (innermostWrapper, loopWrappers.end ());
2388
2393
if (numWrappers != 1 && numWrappers != 2 )
2389
- return TargetRegionFlags::none ;
2394
+ return TargetExecMode::generic ;
2390
2395
2391
2396
// Detect target-teams-distribute-parallel-wsloop[-simd].
2392
2397
if (numWrappers == 2 ) {
2393
2398
WsloopOp *wsloopOp = dyn_cast<WsloopOp>(innermostWrapper);
2394
2399
if (!wsloopOp)
2395
- return TargetRegionFlags::none ;
2400
+ return TargetExecMode::generic ;
2396
2401
2397
2402
innermostWrapper = std::next (innermostWrapper);
2398
2403
if (!isa<DistributeOp>(innermostWrapper))
2399
- return TargetRegionFlags::none ;
2404
+ return TargetExecMode::generic ;
2400
2405
2401
2406
Operation *parallelOp = (*innermostWrapper)->getParentOp ();
2402
2407
if (!isa_and_present<ParallelOp>(parallelOp))
2403
- return TargetRegionFlags::none ;
2408
+ return TargetExecMode::generic ;
2404
2409
2405
2410
TeamsOp teamsOp = dyn_cast<TeamsOp>(parallelOp->getParentOp ());
2406
2411
if (!teamsOp)
2407
- return TargetRegionFlags::none ;
2412
+ return TargetExecMode::generic ;
2408
2413
2409
2414
if (teamsOp->getParentOp () == targetOp.getOperation ()) {
2410
- TargetRegionFlags result =
2411
- TargetRegionFlags::spmd | TargetRegionFlags::trip_count;
2415
+ TargetExecMode result = TargetExecMode::spmd;
2412
2416
if (canPromoteToNoLoop (capturedOp, teamsOp, wsloopOp))
2413
- result = result | TargetRegionFlags::no_loop;
2417
+ result = TargetExecMode::no_loop;
2418
+ if (hostEvalTripCount)
2419
+ *hostEvalTripCount = true ;
2414
2420
return result;
2415
2421
}
2416
2422
}
2417
2423
// Detect target-teams-distribute[-simd] and target-teams-loop.
2418
2424
else if (isa<DistributeOp, LoopOp>(innermostWrapper)) {
2419
2425
Operation *teamsOp = (*innermostWrapper)->getParentOp ();
2420
2426
if (!isa_and_present<TeamsOp>(teamsOp))
2421
- return TargetRegionFlags::none ;
2427
+ return TargetExecMode::generic ;
2422
2428
2423
2429
if (teamsOp->getParentOp () != targetOp.getOperation ())
2424
- return TargetRegionFlags::none;
2430
+ return TargetExecMode::generic;
2431
+
2432
+ if (hostEvalTripCount)
2433
+ *hostEvalTripCount = true ;
2425
2434
2426
2435
if (isa<LoopOp>(innermostWrapper))
2427
- return TargetRegionFlags ::spmd | TargetRegionFlags::trip_count ;
2436
+ return TargetExecMode ::spmd;
2428
2437
2429
- return TargetRegionFlags::trip_count ;
2438
+ return TargetExecMode::generic ;
2430
2439
}
2431
2440
// Detect target-parallel-wsloop[-simd].
2432
2441
else if (isa<WsloopOp>(innermostWrapper)) {
2433
2442
Operation *parallelOp = (*innermostWrapper)->getParentOp ();
2434
2443
if (!isa_and_present<ParallelOp>(parallelOp))
2435
- return TargetRegionFlags::none ;
2444
+ return TargetExecMode::generic ;
2436
2445
2437
2446
if (parallelOp->getParentOp () == targetOp.getOperation ())
2438
- return TargetRegionFlags ::spmd;
2447
+ return TargetExecMode ::spmd;
2439
2448
}
2440
2449
2441
- return TargetRegionFlags::none ;
2450
+ return TargetExecMode::generic ;
2442
2451
}
2443
2452
2444
2453
// ===----------------------------------------------------------------------===//
0 commit comments