@@ -1323,6 +1323,86 @@ Error OpenMPIRBuilder::emitCancelationCheckImpl(
1323
1323
return Error::success ();
1324
1324
}
1325
1325
1326
+ // Create a wrapper function used to gather the outlined function's argument
1327
+ // structure from a shared buffer and to forward it to OutlinedFn when running
1328
+ // in Generic mode. The wrapper is an internal-linkage void(i16, i32) function
1329
+ // added to OMPIRBuilder->M; a pointer to it is returned to the caller.
1330
+ // The outlined function is expected to receive 2 leading arguments (passed here
1331
+ // as pointers to i32 stack slots) followed by an optional pointer argument to an argument structure holding the rest.
1332
+ static Function *createTargetParallelWrapper (OpenMPIRBuilder *OMPIRBuilder,
1333
+ Function &OutlinedFn) {
1334
+ size_t NumArgs = OutlinedFn.arg_size ();
1335
+ assert ((NumArgs == 2 || NumArgs == 3 ) &&
1336
+ " expected a 2-3 argument parallel outlined function" );
1337
+ bool UseArgStruct = NumArgs == 3 ; // A third argument means an argument structure must be forwarded.
1338
+ // Emit through the builder owned by OMPIRBuilder; IPG is presumably there to restore its insertion point on return.
1339
+ IRBuilder<> &Builder = OMPIRBuilder->Builder ;
1340
+ IRBuilder<>::InsertPointGuard IPG (Builder);
1341
+ auto *FnTy = FunctionType::get (Builder.getVoidTy (),
1342
+ {Builder.getInt16Ty (), Builder.getInt32Ty ()},
1343
+ /* isVarArg=*/ false );
1344
+ auto *WrapperFn =
1345
+ Function::Create (FnTy, GlobalValue::InternalLinkage,
1346
+ OutlinedFn.getName () + " .wrapper" , OMPIRBuilder->M );
1347
+ // Parameter attributes: parameter 0 is NoUndef and ZExt, parameter 1 is NoUndef.
1348
+ WrapperFn->addParamAttr (0 , Attribute::NoUndef);
1349
+ WrapperFn->addParamAttr (0 , Attribute::ZExt);
1350
+ WrapperFn->addParamAttr (1 , Attribute::NoUndef);
1351
+ // Give the wrapper a single basic block and emit the whole body into it.
1352
+ BasicBlock *EntryBB =
1353
+ BasicBlock::Create (OMPIRBuilder->M .getContext (), " entry" , WrapperFn);
1354
+ Builder.SetInsertPoint (EntryBB);
1355
+
1356
+ // Allocation: two i32 stack slots for the leading arguments, each cast to an address-space-0 pointer.
1357
+ Value *AddrAlloca = Builder.CreateAlloca (Builder.getInt32Ty (),
1358
+ /* ArraySize=*/ nullptr , " addr" );
1359
+ AddrAlloca = Builder.CreatePointerBitCastOrAddrSpaceCast (
1360
+ AddrAlloca, Builder.getPtrTy (/* AddrSpace=*/ 0 ),
1361
+ AddrAlloca->getName () + " .ascast" );
1362
+
1363
+ Value *ZeroAlloca = Builder.CreateAlloca (Builder.getInt32Ty (),
1364
+ /* ArraySize=*/ nullptr , " zero" );
1365
+ ZeroAlloca = Builder.CreatePointerBitCastOrAddrSpaceCast (
1366
+ ZeroAlloca, Builder.getPtrTy (/* AddrSpace=*/ 0 ),
1367
+ ZeroAlloca->getName () + " .ascast" );
1368
+ // A pointer-typed slot is allocated only when an argument structure has to be forwarded.
1369
+ Value *ArgsAlloca = nullptr ;
1370
+ if (UseArgStruct) {
1371
+ ArgsAlloca = Builder.CreateAlloca (Builder.getPtrTy (),
1372
+ /* ArraySize=*/ nullptr , " global_args" );
1373
+ ArgsAlloca = Builder.CreatePointerBitCastOrAddrSpaceCast (
1374
+ ArgsAlloca, Builder.getPtrTy (/* AddrSpace=*/ 0 ),
1375
+ ArgsAlloca->getName () + " .ascast" );
1376
+ }
1377
+
1378
+ // Initialization: AddrAlloca receives the wrapper's second parameter, ZeroAlloca receives constant 0.
1379
+ Builder.CreateStore (WrapperFn->getArg (1 ), AddrAlloca);
1380
+ Builder.CreateStore (Builder.getInt32 (0 ), ZeroAlloca);
1381
+ if (UseArgStruct) { // ArgsAlloca is handed to the runtime call; the pointer loaded from it below is presumably written by that call.
1382
+ Builder.CreateCall (
1383
+ OMPIRBuilder->getOrCreateRuntimeFunctionPtr (
1384
+ llvm::omp::RuntimeFunction::OMPRTL___kmpc_get_shared_variables),
1385
+ {ArgsAlloca});
1386
+ }
1387
+ // Arguments forwarded to OutlinedFn, in order; the struct pointer is appended below when present.
1388
+ SmallVector<Value *, 3 > Args{AddrAlloca, ZeroAlloca};
1389
+
1390
+ // Load structArg from global_args: load the pointer held in ArgsAlloca, address its element 0, then load that element.
1391
+ if (UseArgStruct) {
1392
+ Value *StructArg = Builder.CreateLoad (Builder.getPtrTy (), ArgsAlloca);
1393
+ StructArg = Builder.CreateInBoundsGEP (Builder.getPtrTy (), StructArg,
1394
+ {Builder.getInt64 (0 )});
1395
+ StructArg = Builder.CreateLoad (Builder.getPtrTy (), StructArg, " structArg" );
1396
+ Args.push_back (StructArg);
1397
+ }
1398
+
1399
+ // Call the outlined function holding the parallel body, then return from the wrapper.
1400
+ Builder.CreateCall (&OutlinedFn, Args);
1401
+ Builder.CreateRetVoid ();
1402
+
1403
+ return WrapperFn;
1404
+ }
1405
+
1326
1406
// Callback used to create OpenMP runtime calls to support
1327
1407
// omp parallel clause for the device.
1328
1408
// We need to use this callback to replace call to the OutlinedFn in OuterFn
@@ -1332,6 +1412,10 @@ static void targetParallelCallback(
1332
1412
BasicBlock *OuterAllocaBB, Value *Ident, Value *IfCondition,
1333
1413
Value *NumThreads, Instruction *PrivTID, AllocaInst *PrivTIDAddr,
1334
1414
Value *ThreadID, const SmallVector<Instruction *, 4 > &ToBeDeleted) {
1415
+ assert (OutlinedFn.arg_size () >= 2 &&
1416
+ " Expected at least tid and bounded tid as arguments" );
1417
+ unsigned NumCapturedVars = OutlinedFn.arg_size () - /* tid & bounded tid */ 2 ;
1418
+
1335
1419
// Add some known attributes.
1336
1420
IRBuilder<> &Builder = OMPIRBuilder->Builder ;
1337
1421
OutlinedFn.addParamAttr (0 , Attribute::NoAlias);
@@ -1340,17 +1424,12 @@ static void targetParallelCallback(
1340
1424
OutlinedFn.addParamAttr (1 , Attribute::NoUndef);
1341
1425
OutlinedFn.addFnAttr (Attribute::NoUnwind);
1342
1426
1343
- assert (OutlinedFn.arg_size () >= 2 &&
1344
- " Expected at least tid and bounded tid as arguments" );
1345
- unsigned NumCapturedVars = OutlinedFn.arg_size () - /* tid & bounded tid */ 2 ;
1346
-
1347
1427
CallInst *CI = cast<CallInst>(OutlinedFn.user_back ());
1348
1428
assert (CI && " Expected call instruction to outlined function" );
1349
1429
CI->getParent ()->setName (" omp_parallel" );
1350
1430
1351
1431
Builder.SetInsertPoint (CI);
1352
1432
Type *PtrTy = OMPIRBuilder->VoidPtr ;
1353
- Value *NullPtrValue = Constant::getNullValue (PtrTy);
1354
1433
1355
1434
// Add alloca for kernel args
1356
1435
OpenMPIRBuilder ::InsertPointTy CurrentIP = Builder.saveIP ();
@@ -1376,6 +1455,15 @@ static void targetParallelCallback(
1376
1455
IfCondition ? Builder.CreateSExtOrTrunc (IfCondition, OMPIRBuilder->Int32 )
1377
1456
: Builder.getInt32 (1 );
1378
1457
1458
+ // If this is not a Generic kernel, we can skip generating the wrapper.
1459
+ std::optional<omp::OMPTgtExecModeFlags> ExecMode =
1460
+ getTargetKernelExecMode (*OuterFn);
1461
+ Value *WrapperFn;
1462
+ if (ExecMode && *ExecMode & OMP_TGT_EXEC_MODE_GENERIC)
1463
+ WrapperFn = createTargetParallelWrapper (OMPIRBuilder, OutlinedFn);
1464
+ else
1465
+ WrapperFn = Constant::getNullValue (PtrTy);
1466
+
1379
1467
// Build kmpc_parallel_51 call
1380
1468
Value *Parallel51CallArgs[] = {
1381
1469
/* identifier*/ Ident,
@@ -1384,7 +1472,7 @@ static void targetParallelCallback(
1384
1472
/* number of threads */ NumThreads ? NumThreads : Builder.getInt32 (-1 ),
1385
1473
/* Proc bind */ Builder.getInt32 (-1 ),
1386
1474
/* outlined function */ &OutlinedFn,
1387
- /* wrapper function */ NullPtrValue ,
1475
+ /* wrapper function */ WrapperFn ,
1388
1476
/* arguments of the outlined funciton*/ Args,
1389
1477
/* number of arguments */ Builder.getInt64 (NumCapturedVars)};
1390
1478
0 commit comments