@@ -2373,7 +2373,7 @@ TargetRegionFlags TargetOp::getKernelExecFlags(Operation *capturedOp) {
2373
2373
2374
2374
// If it's not capturing a loop, it's a default target region.
2375
2375
if (!isa_and_present<LoopNestOp>(capturedOp))
2376
- return TargetRegionFlags::generic ;
2376
+ return TargetRegionFlags::none ;
2377
2377
2378
2378
// Get the innermost non-simd loop wrapper.
2379
2379
SmallVector<LoopWrapperInterface> loopWrappers;
@@ -2386,25 +2386,25 @@ TargetRegionFlags TargetOp::getKernelExecFlags(Operation *capturedOp) {
2386
2386
2387
2387
auto numWrappers = std::distance (innermostWrapper, loopWrappers.end ());
2388
2388
if (numWrappers != 1 && numWrappers != 2 )
2389
- return TargetRegionFlags::generic ;
2389
+ return TargetRegionFlags::none ;
2390
2390
2391
2391
// Detect target-teams-distribute-parallel-wsloop[-simd].
2392
2392
if (numWrappers == 2 ) {
2393
2393
WsloopOp *wsloopOp = dyn_cast<WsloopOp>(innermostWrapper);
2394
2394
if (!wsloopOp)
2395
- return TargetRegionFlags::generic ;
2395
+ return TargetRegionFlags::none ;
2396
2396
2397
2397
innermostWrapper = std::next (innermostWrapper);
2398
2398
if (!isa<DistributeOp>(innermostWrapper))
2399
- return TargetRegionFlags::generic ;
2399
+ return TargetRegionFlags::none ;
2400
2400
2401
2401
Operation *parallelOp = (*innermostWrapper)->getParentOp ();
2402
2402
if (!isa_and_present<ParallelOp>(parallelOp))
2403
- return TargetRegionFlags::generic ;
2403
+ return TargetRegionFlags::none ;
2404
2404
2405
2405
TeamsOp teamsOp = dyn_cast<TeamsOp>(parallelOp->getParentOp ());
2406
2406
if (!teamsOp)
2407
- return TargetRegionFlags::generic ;
2407
+ return TargetRegionFlags::none ;
2408
2408
2409
2409
if (teamsOp->getParentOp () == targetOp.getOperation ()) {
2410
2410
TargetRegionFlags result =
@@ -2418,53 +2418,27 @@ TargetRegionFlags TargetOp::getKernelExecFlags(Operation *capturedOp) {
2418
2418
else if (isa<DistributeOp, LoopOp>(innermostWrapper)) {
2419
2419
Operation *teamsOp = (*innermostWrapper)->getParentOp ();
2420
2420
if (!isa_and_present<TeamsOp>(teamsOp))
2421
- return TargetRegionFlags::generic ;
2421
+ return TargetRegionFlags::none ;
2422
2422
2423
2423
if (teamsOp->getParentOp () != targetOp.getOperation ())
2424
- return TargetRegionFlags::generic ;
2424
+ return TargetRegionFlags::none ;
2425
2425
2426
2426
if (isa<LoopOp>(innermostWrapper))
2427
2427
return TargetRegionFlags::spmd | TargetRegionFlags::trip_count;
2428
2428
2429
- // Find single immediately nested captured omp.parallel and add spmd flag
2430
- // (generic-spmd case).
2431
- //
2432
- // TODO: This shouldn't have to be done here, as it is too easy to break.
2433
- // The openmp-opt pass should be updated to be able to promote kernels like
2434
- // this from "Generic" to "Generic-SPMD". However, the use of the
2435
- // `kmpc_distribute_static_loop` family of functions produced by the
2436
- // OMPIRBuilder for these kernels prevents that from working.
2437
- Dialect *ompDialect = targetOp->getDialect ();
2438
- Operation *nestedCapture = findCapturedOmpOp (
2439
- capturedOp, /* checkSingleMandatoryExec=*/ false ,
2440
- [&](Operation *sibling) {
2441
- return sibling && (ompDialect != sibling->getDialect () ||
2442
- sibling->hasTrait <OpTrait::IsTerminator>());
2443
- });
2444
-
2445
- TargetRegionFlags result =
2446
- TargetRegionFlags::generic | TargetRegionFlags::trip_count;
2447
-
2448
- if (!nestedCapture)
2449
- return result;
2450
-
2451
- while (nestedCapture->getParentOp () != capturedOp)
2452
- nestedCapture = nestedCapture->getParentOp ();
2453
-
2454
- return isa<ParallelOp>(nestedCapture) ? result | TargetRegionFlags::spmd
2455
- : result;
2429
+ return TargetRegionFlags::trip_count;
2456
2430
}
2457
2431
// Detect target-parallel-wsloop[-simd].
2458
2432
else if (isa<WsloopOp>(innermostWrapper)) {
2459
2433
Operation *parallelOp = (*innermostWrapper)->getParentOp ();
2460
2434
if (!isa_and_present<ParallelOp>(parallelOp))
2461
- return TargetRegionFlags::generic ;
2435
+ return TargetRegionFlags::none ;
2462
2436
2463
2437
if (parallelOp->getParentOp () == targetOp.getOperation ())
2464
2438
return TargetRegionFlags::spmd;
2465
2439
}
2466
2440
2467
- return TargetRegionFlags::generic ;
2441
+ return TargetRegionFlags::none ;
2468
2442
}
2469
2443
2470
2444
// ===----------------------------------------------------------------------===//
0 commit comments