@@ -945,6 +945,53 @@ static bool upgradeArmOrAarch64IntrinsicFunction(bool IsArm, Function *F,
945
945
return false ; // No other 'arm.*', 'aarch64.*'.
946
946
}
947
947
948
+ static Intrinsic::ID shouldUpgradeNVPTXTMAG2SIntrinsics (Function *F,
949
+ StringRef Name) {
950
+ if (Name.consume_front (" cp.async.bulk.tensor.g2s." )) {
951
+ Intrinsic::ID ID =
952
+ StringSwitch<Intrinsic::ID>(Name)
953
+ .Case (" im2col.3d" ,
954
+ Intrinsic::nvvm_cp_async_bulk_tensor_g2s_im2col_3d)
955
+ .Case (" im2col.4d" ,
956
+ Intrinsic::nvvm_cp_async_bulk_tensor_g2s_im2col_4d)
957
+ .Case (" im2col.5d" ,
958
+ Intrinsic::nvvm_cp_async_bulk_tensor_g2s_im2col_5d)
959
+ .Case (" tile.1d" , Intrinsic::nvvm_cp_async_bulk_tensor_g2s_tile_1d)
960
+ .Case (" tile.2d" , Intrinsic::nvvm_cp_async_bulk_tensor_g2s_tile_2d)
961
+ .Case (" tile.3d" , Intrinsic::nvvm_cp_async_bulk_tensor_g2s_tile_3d)
962
+ .Case (" tile.4d" , Intrinsic::nvvm_cp_async_bulk_tensor_g2s_tile_4d)
963
+ .Case (" tile.5d" , Intrinsic::nvvm_cp_async_bulk_tensor_g2s_tile_5d)
964
+ .Default (Intrinsic::not_intrinsic);
965
+
966
+ if (ID == Intrinsic::not_intrinsic)
967
+ return ID;
968
+
969
+ // These intrinsics may need upgrade for two reasons:
970
+ // (1) When the address-space of the first argument is shared[AS=3]
971
+ // (and we upgrade it to use shared_cluster address-space[AS=7])
972
+ if (F->getArg (0 )->getType ()->getPointerAddressSpace () ==
973
+ NVPTXAS::ADDRESS_SPACE_SHARED)
974
+ return ID;
975
+
976
+ // (2) When there are only two boolean flag arguments at the end:
977
+ //
978
+ // The last three parameters of the older version of these
979
+ // intrinsics are: arg1, arg2, .. i64 ch, i1 mc_flag, i1 ch_flag
980
+ //
981
+ // The newer version reads as:
982
+ // arg1, arg2, .. i64 ch, i1 mc_flag, i1 ch_flag, i32 cta_group_flag
983
+ //
984
+ // So, when the type of the [N-3]rd argument is "not i1", then
985
+ // it is the older version and we need to upgrade.
986
+ size_t FlagStartIndex = F->getFunctionType ()->getNumParams () - 3 ;
987
+ Type *ArgType = F->getFunctionType ()->getParamType (FlagStartIndex);
988
+ if (!ArgType->isIntegerTy (1 ))
989
+ return ID;
990
+ }
991
+
992
+ return Intrinsic::not_intrinsic;
993
+ }
994
+
948
995
static Intrinsic::ID shouldUpgradeNVPTXSharedClusterIntrinsic (Function *F,
949
996
StringRef Name) {
950
997
if (Name.consume_front (" mapa.shared.cluster" ))
@@ -959,22 +1006,6 @@ static Intrinsic::ID shouldUpgradeNVPTXSharedClusterIntrinsic(Function *F,
959
1006
Intrinsic::nvvm_cp_async_bulk_global_to_shared_cluster)
960
1007
.Case (" shared.cta.to.cluster" ,
961
1008
Intrinsic::nvvm_cp_async_bulk_shared_cta_to_cluster)
962
- .Case (" tensor.g2s.im2col.3d" ,
963
- Intrinsic::nvvm_cp_async_bulk_tensor_g2s_im2col_3d)
964
- .Case (" tensor.g2s.im2col.4d" ,
965
- Intrinsic::nvvm_cp_async_bulk_tensor_g2s_im2col_4d)
966
- .Case (" tensor.g2s.im2col.5d" ,
967
- Intrinsic::nvvm_cp_async_bulk_tensor_g2s_im2col_5d)
968
- .Case (" tensor.g2s.tile.1d" ,
969
- Intrinsic::nvvm_cp_async_bulk_tensor_g2s_tile_1d)
970
- .Case (" tensor.g2s.tile.2d" ,
971
- Intrinsic::nvvm_cp_async_bulk_tensor_g2s_tile_2d)
972
- .Case (" tensor.g2s.tile.3d" ,
973
- Intrinsic::nvvm_cp_async_bulk_tensor_g2s_tile_3d)
974
- .Case (" tensor.g2s.tile.4d" ,
975
- Intrinsic::nvvm_cp_async_bulk_tensor_g2s_tile_4d)
976
- .Case (" tensor.g2s.tile.5d" ,
977
- Intrinsic::nvvm_cp_async_bulk_tensor_g2s_tile_5d)
978
1009
.Default (Intrinsic::not_intrinsic);
979
1010
980
1011
if (ID != Intrinsic::not_intrinsic)
@@ -1339,6 +1370,14 @@ static bool upgradeIntrinsicFunction1(Function *F, Function *&NewFn,
1339
1370
return true ;
1340
1371
}
1341
1372
1373
+ // Upgrade TMA copy G2S Intrinsics
1374
+ IID = shouldUpgradeNVPTXTMAG2SIntrinsics (F, Name);
1375
+ if (IID != Intrinsic::not_intrinsic) {
1376
+ rename (F);
1377
+ NewFn = Intrinsic::getOrInsertDeclaration (F->getParent (), IID);
1378
+ return true ;
1379
+ }
1380
+
1342
1381
// The following nvvm intrinsics correspond exactly to an LLVM idiom, but
1343
1382
// not to an intrinsic alone. We expand them in UpgradeIntrinsicCall.
1344
1383
//
@@ -4831,7 +4870,18 @@ void llvm::UpgradeIntrinsicCall(CallBase *CI, Function *NewFn) {
4831
4870
return ;
4832
4871
}
4833
4872
case Intrinsic::nvvm_cp_async_bulk_global_to_shared_cluster:
4834
- case Intrinsic::nvvm_cp_async_bulk_shared_cta_to_cluster:
4873
+ case Intrinsic::nvvm_cp_async_bulk_shared_cta_to_cluster: {
4874
+ // Create a new call with the correct address space.
4875
+ SmallVector<Value *, 4 > Args (CI->args ());
4876
+ Args[0 ] = Builder.CreateAddrSpaceCast (
4877
+ Args[0 ], Builder.getPtrTy (NVPTXAS::ADDRESS_SPACE_SHARED_CLUSTER));
4878
+
4879
+ NewCall = Builder.CreateCall (NewFn, Args);
4880
+ NewCall->takeName (CI);
4881
+ CI->replaceAllUsesWith (NewCall);
4882
+ CI->eraseFromParent ();
4883
+ return ;
4884
+ }
4835
4885
case Intrinsic::nvvm_cp_async_bulk_tensor_g2s_im2col_3d:
4836
4886
case Intrinsic::nvvm_cp_async_bulk_tensor_g2s_im2col_4d:
4837
4887
case Intrinsic::nvvm_cp_async_bulk_tensor_g2s_im2col_5d:
@@ -4840,10 +4890,22 @@ void llvm::UpgradeIntrinsicCall(CallBase *CI, Function *NewFn) {
4840
4890
case Intrinsic::nvvm_cp_async_bulk_tensor_g2s_tile_3d:
4841
4891
case Intrinsic::nvvm_cp_async_bulk_tensor_g2s_tile_4d:
4842
4892
case Intrinsic::nvvm_cp_async_bulk_tensor_g2s_tile_5d: {
4843
- // Create a new call with the correct address space.
4844
- SmallVector<Value *, 4 > Args (CI->args ());
4845
- Args[0 ] = Builder.CreateAddrSpaceCast (
4846
- Args[0 ], Builder.getPtrTy (NVPTXAS::ADDRESS_SPACE_SHARED_CLUSTER));
4893
+ SmallVector<Value *, 16 > Args (CI->args ());
4894
+
4895
+ // Create AddrSpaceCast to shared_cluster if needed.
4896
+ // This handles case (1) in shouldUpgradeNVPTXTMAG2SIntrinsics().
4897
+ unsigned AS = CI->getArgOperand (0 )->getType ()->getPointerAddressSpace ();
4898
+ if (AS == NVPTXAS::ADDRESS_SPACE_SHARED)
4899
+ Args[0 ] = Builder.CreateAddrSpaceCast (
4900
+ Args[0 ], Builder.getPtrTy (NVPTXAS::ADDRESS_SPACE_SHARED_CLUSTER));
4901
+
4902
+ // Attach the flag argument for cta_group, with a
4903
+ // default value of 0. This handles case (2) in
4904
+ // shouldUpgradeNVPTXTMAG2SIntrinsics().
4905
+ size_t NumArgs = CI->arg_size ();
4906
+ Value *FlagArg = CI->getArgOperand (NumArgs - 3 );
4907
+ if (!FlagArg->getType ()->isIntegerTy (1 ))
4908
+ Args.push_back (ConstantInt::get (Builder.getInt32Ty (), 0 ));
4847
4909
4848
4910
NewCall = Builder.CreateCall (NewFn, Args);
4849
4911
NewCall->takeName (CI);
0 commit comments