@@ -196,7 +196,8 @@ static bool IsPTXVectorType(MVT VT) {
 // - unsigned int NumElts - The number of elements in the final vector
 // - EVT EltVT - The type of the elements in the final vector
 static std::optional<std::pair<unsigned int, MVT>>
-getVectorLoweringShape(EVT VectorEVT, bool CanLowerTo256Bit) {
+getVectorLoweringShape(EVT VectorEVT, const NVPTXSubtarget &STI,
+                       unsigned AddressSpace) {
   if (!VectorEVT.isSimple())
     return std::nullopt;
   const MVT VectorVT = VectorEVT.getSimpleVT();
@@ -213,6 +214,8 @@ getVectorLoweringShape(EVT VectorEVT, bool CanLowerTo256Bit) {
   // The size of the PTX virtual register that holds a packed type.
   unsigned PackRegSize;
 
+  bool CanLowerTo256Bit = STI.has256BitVectorLoadStore(AddressSpace);
+
   // We only handle "native" vector sizes for now, e.g. <4 x double> is not
   // legal. We can (and should) split that into 2 stores of <2 x double> here
   // but I'm leaving that as a TODO for now.
@@ -263,6 +266,8 @@ getVectorLoweringShape(EVT VectorEVT, bool CanLowerTo256Bit) {
     LLVM_FALLTHROUGH;
   case MVT::v2f32: // <1 x f32x2>
   case MVT::v4f32: // <2 x f32x2>
+    if (!STI.hasF32x2Instructions())
+      return std::pair(NumElts, EltVT);
     PackRegSize = 64;
     break;
   }
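Taken together, the three hunks above change how the lowering shape is computed: callers now hand over the subtarget and address space, and the helper derives both capabilities itself. Below is a minimal standalone model of that gating; the toy types and capability values are assumptions for illustration, not the LLVM API.

```cpp
#include <cstdio>
#include <optional>
#include <utility>

// Stand-in for NVPTXSubtarget; both predicate results are assumptions.
struct ToySubtarget {
  bool F32x2;
  bool Wide256;
  bool hasF32x2Instructions() const { return F32x2; }
  bool has256BitVectorLoadStore(unsigned AS) const {
    return Wide256 && AS == 1; // assume only one address space qualifies
  }
};

// Shape for an <NumElts x f32> load/store: how many pieces to emit, and
// how wide each piece's register is (in bits).
std::optional<std::pair<unsigned, unsigned>>
f32VectorShape(unsigned NumElts, const ToySubtarget &STI, unsigned AS) {
  const bool CanLowerTo256Bit = STI.has256BitVectorLoadStore(AS);
  if (NumElts * 32 > 128 && !CanLowerTo256Bit)
    return std::nullopt; // would need a 256-bit access; bail out
  if (!STI.hasF32x2Instructions())
    return std::pair(NumElts, 32u); // unpacked: one f32 per b32 register
  return std::pair(NumElts / 2, 64u); // pairs packed into b64 f32x2 regs
}

int main() {
  const ToySubtarget WithF32x2{true, true}, Baseline{false, false};
  if (auto S = f32VectorShape(4, WithF32x2, 1))
    std::printf("packed:   %u x b%u\n", S->first, S->second); // 2 x b64
  if (auto S = f32VectorShape(4, Baseline, 1))
    std::printf("unpacked: %u x b%u\n", S->first, S->second); // 4 x b32
}
```

The early return on `!hasF32x2Instructions()` is what keeps `<2 x f32>` and `<4 x f32>` on the element-wise path for subtargets without the packed instructions.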
@@ -278,97 +283,44 @@ getVectorLoweringShape(EVT VectorEVT, bool CanLowerTo256Bit) {
 }
 
 /// ComputePTXValueVTs - For the given Type \p Ty, returns the set of primitive
-/// EVTs that compose it. Unlike ComputeValueVTs, this will break apart vectors
-/// into their primitive components.
+/// legal-ish MVTs that compose it. Unlike ComputeValueVTs, this will legalize
+/// the types as required by the calling convention (with special handling for
+/// i8s).
 /// NOTE: This is a band-aid for code that expects ComputeValueVTs to return the
 /// same number of types as the Ins/Outs arrays in LowerFormalArguments,
 /// LowerCall, and LowerReturn.
 static void ComputePTXValueVTs(const TargetLowering &TLI, const DataLayout &DL,
+                               LLVMContext &Ctx, CallingConv::ID CallConv,
                                Type *Ty, SmallVectorImpl<EVT> &ValueVTs,
-                               SmallVectorImpl<uint64_t> *Offsets = nullptr,
+                               SmallVectorImpl<uint64_t> &Offsets,
                                uint64_t StartingOffset = 0) {
   SmallVector<EVT, 16> TempVTs;
   SmallVector<uint64_t, 16> TempOffsets;
-
-  // Special case for i128 - decompose to (i64, i64)
-  if (Ty->isIntegerTy(128) || Ty->isFP128Ty()) {
-    ValueVTs.append({MVT::i64, MVT::i64});
-
-    if (Offsets)
-      Offsets->append({StartingOffset + 0, StartingOffset + 8});
-
-    return;
-  }
-
-  // Given a struct type, recursively traverse the elements with custom ComputePTXValueVTs.
-  if (StructType *STy = dyn_cast<StructType>(Ty)) {
-    auto const *SL = DL.getStructLayout(STy);
-    auto ElementNum = 0;
-    for (auto *EI : STy->elements()) {
-      ComputePTXValueVTs(TLI, DL, EI, ValueVTs, Offsets,
-                         StartingOffset + SL->getElementOffset(ElementNum));
-      ++ElementNum;
-    }
-    return;
-  }
-
-  // Given an array type, recursively traverse the elements with custom ComputePTXValueVTs.
-  if (ArrayType *ATy = dyn_cast<ArrayType>(Ty)) {
-    Type *EltTy = ATy->getElementType();
-    uint64_t EltSize = DL.getTypeAllocSize(EltTy);
-    for (int I : llvm::seq<int>(ATy->getNumElements()))
-      ComputePTXValueVTs(TLI, DL, EltTy, ValueVTs, Offsets, StartingOffset + I * EltSize);
-    return;
-  }
-
-  // Will split structs and arrays into member types, but will not split vector
-  // types. We do that manually below.
   ComputeValueVTs(TLI, DL, Ty, TempVTs, &TempOffsets, StartingOffset);
 
-  for (auto [VT, Off] : zip(TempVTs, TempOffsets)) {
-    // Split vectors into individual elements that fit into registers.
-    if (VT.isVector()) {
-      unsigned NumElts = VT.getVectorNumElements();
-      EVT EltVT = VT.getVectorElementType();
-      // Below we must maintain power-of-2 sized vectors because
-      // TargetLoweringBase::getVectorTypeBreakdown() which is invoked in
-      // ComputePTXValueVTs() cannot currently break down non-power-of-2 sized
-      // vectors.
-
-      // If the element type belongs to one of the supported packed vector types
-      // then we can pack multiples of this element into a single register.
-      if (VT == MVT::v2i8) {
-        // We can pack 2 i8s into a single 16-bit register. We only do this for
-        // loads and stores, which is why we have a separate case for it.
-        EltVT = MVT::v2i8;
-        NumElts = 1;
-      } else if (VT == MVT::v3i8) {
-        // We can also pack 3 i8s into 32-bit register, leaving the 4th
-        // element undefined.
-        EltVT = MVT::v4i8;
-        NumElts = 1;
-      } else if (NumElts > 1 && isPowerOf2_32(NumElts)) {
-        // Handle default packed types.
-        for (MVT PackedVT : NVPTX::packed_types()) {
-          const auto NumEltsPerReg = PackedVT.getVectorNumElements();
-          if (NumElts % NumEltsPerReg == 0 &&
-              EltVT == PackedVT.getVectorElementType()) {
-            EltVT = PackedVT;
-            NumElts /= NumEltsPerReg;
-            break;
-          }
-        }
-      }
+  for (const auto [VT, Off] : zip(TempVTs, TempOffsets)) {
+    MVT RegisterVT = TLI.getRegisterTypeForCallingConv(Ctx, CallConv, VT);
+    unsigned NumRegs = TLI.getNumRegistersForCallingConv(Ctx, CallConv, VT);
+
+    // Since we actually can load/store b8, we need to ensure that we'll use
+    // the original sized type for any i8s or i8 vectors.
+    if (VT.getScalarType() == MVT::i8) {
+      if (RegisterVT == MVT::i16)
+        RegisterVT = MVT::i8;
+      else if (RegisterVT == MVT::v2i16)
+        RegisterVT = MVT::v2i8;
+      else
+        assert(RegisterVT == MVT::v4i8 &&
+               "Expected v4i8, v2i16, or i16 for i8 RegisterVT");
+    }
 
-      for (unsigned J : seq(NumElts)) {
-        ValueVTs.push_back(EltVT);
-        if (Offsets)
-          Offsets->push_back(Off + J * EltVT.getStoreSize());
-      }
-    } else {
-      ValueVTs.push_back(VT);
-      if (Offsets)
-        Offsets->push_back(Off);
+    // TODO: This is horribly incorrect for cases where the vector elements are
+    // not a multiple of bytes (ex i1) and legal or i8. However, this problem
+    // has existed for as long as NVPTX has and no one has complained, so we'll
+    // leave it for now.
+    for (unsigned I : seq(NumRegs)) {
+      ValueVTs.push_back(RegisterVT);
+      Offsets.push_back(Off + I * RegisterVT.getStoreSize());
     }
   }
 }
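Since the rewritten ComputePTXValueVTs is the heart of this patch, here is a hedged, self-contained sketch of what one iteration of its loop produces. The enum stands in for MVT, and the legalization result in main() (an i8-vector part mapping to two v4i8 registers) is an assumption for illustration, not a value queried from TargetLowering.

```cpp
#include <cassert>
#include <cstdint>
#include <cstdio>
#include <vector>

enum class VT { i8, i16, v2i8, v2i16, v4i8 };

// Mirror of the i8 re-narrowing above: legalization presumably promotes i8
// (and small i8 vectors) to 16-bit types, but PTX can load/store b8, so the
// byte-sized type is restored before being recorded.
VT narrowForI8(VT RegisterVT) {
  if (RegisterVT == VT::i16)
    return VT::i8;
  if (RegisterVT == VT::v2i16)
    return VT::v2i8;
  assert(RegisterVT == VT::v4i8 && "Expected v4i8, v2i16, or i16");
  return RegisterVT;
}

int main() {
  // Assumed legalization result for one temp VT: two v4i8 registers,
  // 4 bytes of store size each, starting at offset 0.
  const VT RegisterVT = narrowForI8(VT::v4i8);
  const unsigned NumRegs = 2;
  const uint64_t Off = 0, StoreSize = 4;

  std::vector<VT> ValueVTs;
  std::vector<uint64_t> Offsets;
  // One entry (and one offset) per register, as in the loop above; this is
  // why Offsets is now a required out-parameter and call sites can assert
  // VTs.size() == Offsets.size().
  for (unsigned I = 0; I < NumRegs; ++I) {
    ValueVTs.push_back(RegisterVT);
    Offsets.push_back(Off + I * StoreSize);
  }
  assert(ValueVTs.size() == Offsets.size() && "Size mismatch");
  std::printf("%zu registers, last offset %llu\n", ValueVTs.size(),
              (unsigned long long)Offsets.back());
}
```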
@@ -631,7 +583,9 @@ NVPTXTargetLowering::NVPTXTargetLowering(const NVPTXTargetMachine &TM,
   addRegisterClass(MVT::v2f16, &NVPTX::B32RegClass);
   addRegisterClass(MVT::bf16, &NVPTX::B16RegClass);
   addRegisterClass(MVT::v2bf16, &NVPTX::B32RegClass);
-  addRegisterClass(MVT::v2f32, &NVPTX::B64RegClass);
+
+  if (STI.hasF32x2Instructions())
+    addRegisterClass(MVT::v2f32, &NVPTX::B64RegClass);
 
   // Conversion to/from FP16/FP16x2 is always legal.
   setOperationAction(ISD::BUILD_VECTOR, MVT::v2f16, Custom);
@@ -672,7 +626,8 @@ NVPTXTargetLowering::NVPTXTargetLowering(const NVPTXTargetMachine &TM,
   setOperationAction(ISD::INSERT_VECTOR_ELT, MVT::v2f32, Expand);
   setOperationAction(ISD::VECTOR_SHUFFLE, MVT::v2f32, Expand);
   // Need custom lowering in case the index is dynamic.
-  setOperationAction(ISD::EXTRACT_VECTOR_ELT, MVT::v2f32, Custom);
+  if (STI.hasF32x2Instructions())
+    setOperationAction(ISD::EXTRACT_VECTOR_ELT, MVT::v2f32, Custom);
 
   // Custom conversions to/from v2i8.
   setOperationAction(ISD::BITCAST, MVT::v2i8, Custom);
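A note on why the register class (earlier hunk) and this operation action are gated by the same predicate, with a toy model below. The rationale is my reading of standard SelectionDAG behavior, not something stated in the diff: if v2f32 never becomes a legal register type, the generic type legalizer splits such values into two f32s, so a v2f32 operation action would never be consulted.

```cpp
#include <cstdio>
#include <map>
#include <string>

enum class Action { Expand, Custom };

// Toy "target": operation actions are only registered for types that have
// a register class, mirroring the constructor logic above.
struct ToyTarget {
  bool HasF32x2; // hypothetical subtarget capability
  std::map<std::string, Action> OpActions;

  explicit ToyTarget(bool F32x2) : HasF32x2(F32x2) {
    if (HasF32x2) {
      // v2f32 lives in a real 64-bit register; extraction with a dynamic
      // index needs a custom lowering.
      OpActions["EXTRACT_VECTOR_ELT:v2f32"] = Action::Custom;
    }
    // Without the register class, v2f32 is type-illegal and is split
    // before any operation action would be looked up.
  }

  Action actionFor(const std::string &Op) const {
    auto It = OpActions.find(Op);
    return It == OpActions.end() ? Action::Expand : It->second;
  }
};

int main() {
  ToyTarget New(true), Old(false);
  std::printf("with f32x2: %s\n",
              New.actionFor("EXTRACT_VECTOR_ELT:v2f32") == Action::Custom
                  ? "custom lowering" : "expanded/split");
  std::printf("without:    %s\n",
              Old.actionFor("EXTRACT_VECTOR_ELT:v2f32") == Action::Custom
                  ? "custom lowering" : "expanded/split");
}
```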
@@ -1606,7 +1561,8 @@ SDValue NVPTXTargetLowering::LowerCall(TargetLowering::CallLoweringInfo &CLI,
     } else {
       SmallVector<EVT, 16> VTs;
       SmallVector<uint64_t, 16> Offsets;
-      ComputePTXValueVTs(*this, DL, Arg.Ty, VTs, &Offsets, VAOffset);
+      ComputePTXValueVTs(*this, DL, Ctx, CLI.CallConv, Arg.Ty, VTs, Offsets,
+                         VAOffset);
       assert(VTs.size() == Offsets.size() && "Size mismatch");
       assert(VTs.size() == ArgOuts.size() && "Size mismatch");
@@ -1756,7 +1712,7 @@ SDValue NVPTXTargetLowering::LowerCall(TargetLowering::CallLoweringInfo &CLI,
   if (!Ins.empty()) {
     SmallVector<EVT, 16> VTs;
     SmallVector<uint64_t, 16> Offsets;
-    ComputePTXValueVTs(*this, DL, RetTy, VTs, &Offsets);
+    ComputePTXValueVTs(*this, DL, Ctx, CLI.CallConv, RetTy, VTs, Offsets);
     assert(VTs.size() == Ins.size() && "Bad value decomposition");
 
     const Align RetAlign = getArgumentAlignment(CB, RetTy, 0, DL);
@@ -3217,8 +3173,8 @@ NVPTXTargetLowering::LowerSTOREVector(SDValue Op, SelectionDAG &DAG) const {
   if (ValVT != MemVT)
     return SDValue();
 
-  const auto NumEltsAndEltVT = getVectorLoweringShape(
-      ValVT, STI.has256BitVectorLoadStore(N->getAddressSpace()));
+  const auto NumEltsAndEltVT =
+      getVectorLoweringShape(ValVT, STI, N->getAddressSpace());
   if (!NumEltsAndEltVT)
     return SDValue();
   const auto [NumElts, EltVT] = NumEltsAndEltVT.value();
@@ -3386,6 +3342,7 @@ SDValue NVPTXTargetLowering::LowerFormalArguments(
     const SmallVectorImpl<ISD::InputArg> &Ins, const SDLoc &dl,
     SelectionDAG &DAG, SmallVectorImpl<SDValue> &InVals) const {
   const DataLayout &DL = DAG.getDataLayout();
+  LLVMContext &Ctx = *DAG.getContext();
   auto PtrVT = getPointerTy(DAG.getDataLayout());
 
   const Function &F = DAG.getMachineFunction().getFunction();
@@ -3457,7 +3414,7 @@ SDValue NVPTXTargetLowering::LowerFormalArguments(
     } else {
       SmallVector<EVT, 16> VTs;
       SmallVector<uint64_t, 16> Offsets;
-      ComputePTXValueVTs(*this, DL, Ty, VTs, &Offsets, 0);
+      ComputePTXValueVTs(*this, DL, Ctx, CallConv, Ty, VTs, Offsets);
       assert(VTs.size() == ArgIns.size() && "Size mismatch");
       assert(VTs.size() == Offsets.size() && "Size mismatch");
@@ -3469,7 +3426,7 @@ SDValue NVPTXTargetLowering::LowerFormalArguments(
       for (const unsigned NumElts : VI) {
         // i1 is loaded/stored as i8
         const EVT LoadVT = VTs[I] == MVT::i1 ? MVT::i8 : VTs[I];
-        const EVT VecVT = getVectorizedVT(LoadVT, NumElts, *DAG.getContext());
+        const EVT VecVT = getVectorizedVT(LoadVT, NumElts, Ctx);
 
         SDValue VecAddr = DAG.getObjectPtrOffset(
             dl, ArgSymbol, TypeSize::getFixed(Offsets[I]));
@@ -3514,6 +3471,7 @@ NVPTXTargetLowering::LowerReturn(SDValue Chain, CallingConv::ID CallConv,
   }
 
   const DataLayout &DL = DAG.getDataLayout();
+  LLVMContext &Ctx = *DAG.getContext();
 
   const SDValue RetSymbol = DAG.getExternalSymbol("func_retval0", MVT::i32);
   const auto RetAlign = getFunctionParamOptimizedAlign(&F, RetTy, DL);
@@ -3526,7 +3484,7 @@ NVPTXTargetLowering::LowerReturn(SDValue Chain, CallingConv::ID CallConv,
 
   SmallVector<EVT, 16> VTs;
   SmallVector<uint64_t, 16> Offsets;
-  ComputePTXValueVTs(*this, DL, RetTy, VTs, &Offsets);
+  ComputePTXValueVTs(*this, DL, Ctx, CallConv, RetTy, VTs, Offsets);
   assert(VTs.size() == OutVals.size() && "Bad return value decomposition");
 
   const auto GetRetVal = [&](unsigned I) -> SDValue {
@@ -5985,8 +5943,8 @@ static void replaceLoadVector(SDNode *N, SelectionDAG &DAG,
   if (ResVT != MemVT)
     return;
 
-  const auto NumEltsAndEltVT = getVectorLoweringShape(
-      ResVT, STI.has256BitVectorLoadStore(LD->getAddressSpace()));
+  const auto NumEltsAndEltVT =
+      getVectorLoweringShape(ResVT, STI, LD->getAddressSpace());
   if (!NumEltsAndEltVT)
     return;
   const auto [NumElts, EltVT] = NumEltsAndEltVT.value();