Commit a9de1ab

[NVPTX] Disable v2f32 registers when no operations supported, or via cl::opt (#154476)
The addition of v2f32 as a legal type, supported by the B64 register class, has caused performance regressions, broken inline assembly, and resulted in a couple of (now fixed) miscompilations. To mitigate these issues, only mark v2f32 as a legal type when operations that support it exist, since on targets where they do not it serves no purpose. To enable further debugging, add an option to disable v2f32. To allow for a target-dependent set of legal types, ComputePTXValueVTs has been fully rewritten to take advantage of the TargetLowering call-lowering APIs.
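The new escape hatch can be exercised directly from llc. A minimal, illustrative invocation (the -mcpu/-mattr values simply describe an sm_100 / PTX 8.6 configuration that would otherwise enable f32x2; input.ll and out.ptx are placeholder file names):

llc -mtriple=nvptx64-nvidia-cuda -mcpu=sm_100 -mattr=+ptx86 -nvptx-no-f32x2 input.ll -o out.ptx

With the option set, hasF32x2Instructions() returns false, so v2f32 is never registered with the B64 register class and <2 x float> values are lowered through two b32 registers, matching the updated CHECK lines in aggregate-return.ll below.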
1 parent 0319a79 commit a9de1ab

20 files changed, +2564 −1720 lines changed

llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp

Lines changed: 50 additions & 92 deletions
@@ -196,7 +196,8 @@ static bool IsPTXVectorType(MVT VT) {
 // - unsigned int NumElts - The number of elements in the final vector
 // - EVT EltVT - The type of the elements in the final vector
 static std::optional<std::pair<unsigned int, MVT>>
-getVectorLoweringShape(EVT VectorEVT, bool CanLowerTo256Bit) {
+getVectorLoweringShape(EVT VectorEVT, const NVPTXSubtarget &STI,
+                       unsigned AddressSpace) {
   if (!VectorEVT.isSimple())
     return std::nullopt;
   const MVT VectorVT = VectorEVT.getSimpleVT();
@@ -213,6 +214,8 @@ getVectorLoweringShape(EVT VectorEVT, bool CanLowerTo256Bit) {
   // The size of the PTX virtual register that holds a packed type.
   unsigned PackRegSize;
 
+  bool CanLowerTo256Bit = STI.has256BitVectorLoadStore(AddressSpace);
+
   // We only handle "native" vector sizes for now, e.g. <4 x double> is not
   // legal. We can (and should) split that into 2 stores of <2 x double> here
   // but I'm leaving that as a TODO for now.
@@ -263,6 +266,8 @@ getVectorLoweringShape(EVT VectorEVT, bool CanLowerTo256Bit) {
     LLVM_FALLTHROUGH;
   case MVT::v2f32: // <1 x f32x2>
   case MVT::v4f32: // <2 x f32x2>
+    if (!STI.hasF32x2Instructions())
+      return std::pair(NumElts, EltVT);
     PackRegSize = 64;
     break;
   }
@@ -278,97 +283,44 @@ getVectorLoweringShape(EVT VectorEVT, bool CanLowerTo256Bit) {
 }
 
 /// ComputePTXValueVTs - For the given Type \p Ty, returns the set of primitive
-/// EVTs that compose it. Unlike ComputeValueVTs, this will break apart vectors
-/// into their primitive components.
+/// legal-ish MVTs that compose it. Unlike ComputeValueVTs, this will legalize
+/// the types as required by the calling convention (with special handling for
+/// i8s).
 /// NOTE: This is a band-aid for code that expects ComputeValueVTs to return the
 /// same number of types as the Ins/Outs arrays in LowerFormalArguments,
 /// LowerCall, and LowerReturn.
 static void ComputePTXValueVTs(const TargetLowering &TLI, const DataLayout &DL,
+                               LLVMContext &Ctx, CallingConv::ID CallConv,
                                Type *Ty, SmallVectorImpl<EVT> &ValueVTs,
-                               SmallVectorImpl<uint64_t> *Offsets = nullptr,
+                               SmallVectorImpl<uint64_t> &Offsets,
                                uint64_t StartingOffset = 0) {
   SmallVector<EVT, 16> TempVTs;
   SmallVector<uint64_t, 16> TempOffsets;
-
-  // Special case for i128 - decompose to (i64, i64)
-  if (Ty->isIntegerTy(128) || Ty->isFP128Ty()) {
-    ValueVTs.append({MVT::i64, MVT::i64});
-
-    if (Offsets)
-      Offsets->append({StartingOffset + 0, StartingOffset + 8});
-
-    return;
-  }
-
-  // Given a struct type, recursively traverse the elements with custom ComputePTXValueVTs.
-  if (StructType *STy = dyn_cast<StructType>(Ty)) {
-    auto const *SL = DL.getStructLayout(STy);
-    auto ElementNum = 0;
-    for(auto *EI : STy->elements()) {
-      ComputePTXValueVTs(TLI, DL, EI, ValueVTs, Offsets,
-                         StartingOffset + SL->getElementOffset(ElementNum));
-      ++ElementNum;
-    }
-    return;
-  }
-
-  // Given an array type, recursively traverse the elements with custom ComputePTXValueVTs.
-  if (ArrayType *ATy = dyn_cast<ArrayType>(Ty)) {
-    Type *EltTy = ATy->getElementType();
-    uint64_t EltSize = DL.getTypeAllocSize(EltTy);
-    for (int I : llvm::seq<int>(ATy->getNumElements()))
-      ComputePTXValueVTs(TLI, DL, EltTy, ValueVTs, Offsets, StartingOffset + I * EltSize);
-    return;
-  }
-
-  // Will split structs and arrays into member types, but will not split vector
-  // types. We do that manually below.
   ComputeValueVTs(TLI, DL, Ty, TempVTs, &TempOffsets, StartingOffset);
 
-  for (auto [VT, Off] : zip(TempVTs, TempOffsets)) {
-    // Split vectors into individual elements that fit into registers.
-    if (VT.isVector()) {
-      unsigned NumElts = VT.getVectorNumElements();
-      EVT EltVT = VT.getVectorElementType();
-      // Below we must maintain power-of-2 sized vectors because
-      // TargetLoweringBase::getVectorTypeBreakdown() which is invoked in
-      // ComputePTXValueVTs() cannot currently break down non-power-of-2 sized
-      // vectors.
-
-      // If the element type belongs to one of the supported packed vector types
-      // then we can pack multiples of this element into a single register.
-      if (VT == MVT::v2i8) {
-        // We can pack 2 i8s into a single 16-bit register. We only do this for
-        // loads and stores, which is why we have a separate case for it.
-        EltVT = MVT::v2i8;
-        NumElts = 1;
-      } else if (VT == MVT::v3i8) {
-        // We can also pack 3 i8s into 32-bit register, leaving the 4th
-        // element undefined.
-        EltVT = MVT::v4i8;
-        NumElts = 1;
-      } else if (NumElts > 1 && isPowerOf2_32(NumElts)) {
-        // Handle default packed types.
-        for (MVT PackedVT : NVPTX::packed_types()) {
-          const auto NumEltsPerReg = PackedVT.getVectorNumElements();
-          if (NumElts % NumEltsPerReg == 0 &&
-              EltVT == PackedVT.getVectorElementType()) {
-            EltVT = PackedVT;
-            NumElts /= NumEltsPerReg;
-            break;
-          }
-        }
-      }
+  for (const auto [VT, Off] : zip(TempVTs, TempOffsets)) {
+    MVT RegisterVT = TLI.getRegisterTypeForCallingConv(Ctx, CallConv, VT);
+    unsigned NumRegs = TLI.getNumRegistersForCallingConv(Ctx, CallConv, VT);
+
+    // Since we actually can load/store b8, we need to ensure that we'll use
+    // the original sized type for any i8s or i8 vectors.
+    if (VT.getScalarType() == MVT::i8) {
+      if (RegisterVT == MVT::i16)
+        RegisterVT = MVT::i8;
+      else if (RegisterVT == MVT::v2i16)
+        RegisterVT = MVT::v2i8;
+      else
+        assert(RegisterVT == MVT::v4i8 &&
+               "Expected v4i8, v2i16, or i16 for i8 RegisterVT");
+    }
 
-      for (unsigned J : seq(NumElts)) {
-        ValueVTs.push_back(EltVT);
-        if (Offsets)
-          Offsets->push_back(Off + J * EltVT.getStoreSize());
-      }
-    } else {
-      ValueVTs.push_back(VT);
-      if (Offsets)
-        Offsets->push_back(Off);
+    // TODO: This is horribly incorrect for cases where the vector elements are
+    // not a multiple of bytes (ex i1) and legal or i8. However, this problem
+    // has existed for as long as NVPTX has and no one has complained, so we'll
+    // leave it for now.
+    for (unsigned I : seq(NumRegs)) {
+      ValueVTs.push_back(RegisterVT);
+      Offsets.push_back(Off + I * RegisterVT.getStoreSize());
     }
   }
 }
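For readers unfamiliar with the call-lowering APIs the rewritten helper leans on: getRegisterTypeForCallingConv and getNumRegistersForCallingConv report how a given EVT is split into registers under a calling convention, which is exactly the per-register breakdown ComputePTXValueVTs now emits. A minimal sketch, not part of this commit, assuming a constructed NVPTX TargetLowering and an LLVMContext are already in scope:

#include "llvm/CodeGen/TargetLowering.h"
#include "llvm/IR/CallingConv.h"

// Sketch: query the calling-convention breakdown for <2 x float>.
// On an sm_100+/PTX 8.6+ subtarget (with -nvptx-no-f32x2 unset) the expected
// result is RegVT == MVT::v2f32 with NumRegs == 1; otherwise the type breaks
// down into two MVT::f32 registers.
static void queryV2F32Breakdown(const llvm::TargetLowering &TLI,
                                llvm::LLVMContext &Ctx) {
  llvm::EVT VT = llvm::MVT::v2f32;
  llvm::MVT RegVT =
      TLI.getRegisterTypeForCallingConv(Ctx, llvm::CallingConv::C, VT);
  unsigned NumRegs =
      TLI.getNumRegistersForCallingConv(Ctx, llvm::CallingConv::C, VT);
  (void)RegVT;
  (void)NumRegs;
}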
@@ -631,7 +583,9 @@ NVPTXTargetLowering::NVPTXTargetLowering(const NVPTXTargetMachine &TM,
   addRegisterClass(MVT::v2f16, &NVPTX::B32RegClass);
   addRegisterClass(MVT::bf16, &NVPTX::B16RegClass);
   addRegisterClass(MVT::v2bf16, &NVPTX::B32RegClass);
-  addRegisterClass(MVT::v2f32, &NVPTX::B64RegClass);
+
+  if (STI.hasF32x2Instructions())
+    addRegisterClass(MVT::v2f32, &NVPTX::B64RegClass);
 
   // Conversion to/from FP16/FP16x2 is always legal.
   setOperationAction(ISD::BUILD_VECTOR, MVT::v2f16, Custom);
@@ -672,7 +626,8 @@ NVPTXTargetLowering::NVPTXTargetLowering(const NVPTXTargetMachine &TM,
   setOperationAction(ISD::INSERT_VECTOR_ELT, MVT::v2f32, Expand);
   setOperationAction(ISD::VECTOR_SHUFFLE, MVT::v2f32, Expand);
   // Need custom lowering in case the index is dynamic.
-  setOperationAction(ISD::EXTRACT_VECTOR_ELT, MVT::v2f32, Custom);
+  if (STI.hasF32x2Instructions())
+    setOperationAction(ISD::EXTRACT_VECTOR_ELT, MVT::v2f32, Custom);
 
   // Custom conversions to/from v2i8.
   setOperationAction(ISD::BITCAST, MVT::v2i8, Custom);
@@ -1606,7 +1561,8 @@ SDValue NVPTXTargetLowering::LowerCall(TargetLowering::CallLoweringInfo &CLI,
     } else {
       SmallVector<EVT, 16> VTs;
       SmallVector<uint64_t, 16> Offsets;
-      ComputePTXValueVTs(*this, DL, Arg.Ty, VTs, &Offsets, VAOffset);
+      ComputePTXValueVTs(*this, DL, Ctx, CLI.CallConv, Arg.Ty, VTs, Offsets,
+                         VAOffset);
       assert(VTs.size() == Offsets.size() && "Size mismatch");
       assert(VTs.size() == ArgOuts.size() && "Size mismatch");
 
@@ -1756,7 +1712,7 @@ SDValue NVPTXTargetLowering::LowerCall(TargetLowering::CallLoweringInfo &CLI,
   if (!Ins.empty()) {
     SmallVector<EVT, 16> VTs;
     SmallVector<uint64_t, 16> Offsets;
-    ComputePTXValueVTs(*this, DL, RetTy, VTs, &Offsets);
+    ComputePTXValueVTs(*this, DL, Ctx, CLI.CallConv, RetTy, VTs, Offsets);
     assert(VTs.size() == Ins.size() && "Bad value decomposition");
 
     const Align RetAlign = getArgumentAlignment(CB, RetTy, 0, DL);
@@ -3217,8 +3173,8 @@ NVPTXTargetLowering::LowerSTOREVector(SDValue Op, SelectionDAG &DAG) const {
   if (ValVT != MemVT)
     return SDValue();
 
-  const auto NumEltsAndEltVT = getVectorLoweringShape(
-      ValVT, STI.has256BitVectorLoadStore(N->getAddressSpace()));
+  const auto NumEltsAndEltVT =
+      getVectorLoweringShape(ValVT, STI, N->getAddressSpace());
   if (!NumEltsAndEltVT)
     return SDValue();
   const auto [NumElts, EltVT] = NumEltsAndEltVT.value();
@@ -3386,6 +3342,7 @@ SDValue NVPTXTargetLowering::LowerFormalArguments(
     const SmallVectorImpl<ISD::InputArg> &Ins, const SDLoc &dl,
     SelectionDAG &DAG, SmallVectorImpl<SDValue> &InVals) const {
   const DataLayout &DL = DAG.getDataLayout();
+  LLVMContext &Ctx = *DAG.getContext();
   auto PtrVT = getPointerTy(DAG.getDataLayout());
 
   const Function &F = DAG.getMachineFunction().getFunction();
@@ -3457,7 +3414,7 @@ SDValue NVPTXTargetLowering::LowerFormalArguments(
     } else {
       SmallVector<EVT, 16> VTs;
       SmallVector<uint64_t, 16> Offsets;
-      ComputePTXValueVTs(*this, DL, Ty, VTs, &Offsets, 0);
+      ComputePTXValueVTs(*this, DL, Ctx, CallConv, Ty, VTs, Offsets);
       assert(VTs.size() == ArgIns.size() && "Size mismatch");
       assert(VTs.size() == Offsets.size() && "Size mismatch");
 
@@ -3469,7 +3426,7 @@ SDValue NVPTXTargetLowering::LowerFormalArguments(
       for (const unsigned NumElts : VI) {
         // i1 is loaded/stored as i8
         const EVT LoadVT = VTs[I] == MVT::i1 ? MVT::i8 : VTs[I];
-        const EVT VecVT = getVectorizedVT(LoadVT, NumElts, *DAG.getContext());
+        const EVT VecVT = getVectorizedVT(LoadVT, NumElts, Ctx);
 
         SDValue VecAddr = DAG.getObjectPtrOffset(
             dl, ArgSymbol, TypeSize::getFixed(Offsets[I]));
@@ -3514,6 +3471,7 @@ NVPTXTargetLowering::LowerReturn(SDValue Chain, CallingConv::ID CallConv,
   }
 
   const DataLayout &DL = DAG.getDataLayout();
+  LLVMContext &Ctx = *DAG.getContext();
 
   const SDValue RetSymbol = DAG.getExternalSymbol("func_retval0", MVT::i32);
   const auto RetAlign = getFunctionParamOptimizedAlign(&F, RetTy, DL);
@@ -3526,7 +3484,7 @@ NVPTXTargetLowering::LowerReturn(SDValue Chain, CallingConv::ID CallConv,
 
   SmallVector<EVT, 16> VTs;
   SmallVector<uint64_t, 16> Offsets;
-  ComputePTXValueVTs(*this, DL, RetTy, VTs, &Offsets);
+  ComputePTXValueVTs(*this, DL, Ctx, CallConv, RetTy, VTs, Offsets);
   assert(VTs.size() == OutVals.size() && "Bad return value decomposition");
 
   const auto GetRetVal = [&](unsigned I) -> SDValue {
@@ -5985,8 +5943,8 @@ static void replaceLoadVector(SDNode *N, SelectionDAG &DAG,
   if (ResVT != MemVT)
     return;
 
-  const auto NumEltsAndEltVT = getVectorLoweringShape(
-      ResVT, STI.has256BitVectorLoadStore(LD->getAddressSpace()));
+  const auto NumEltsAndEltVT =
+      getVectorLoweringShape(ResVT, STI, LD->getAddressSpace());
   if (!NumEltsAndEltVT)
     return;
   const auto [NumElts, EltVT] = NumEltsAndEltVT.value();

llvm/lib/Target/NVPTX/NVPTXSubtarget.cpp

Lines changed: 10 additions & 0 deletions
@@ -29,6 +29,12 @@ static cl::opt<bool>
     NoF16Math("nvptx-no-f16-math", cl::Hidden,
               cl::desc("NVPTX Specific: Disable generation of f16 math ops."),
               cl::init(false));
+
+static cl::opt<bool> NoF32x2("nvptx-no-f32x2", cl::Hidden,
+                             cl::desc("NVPTX Specific: Disable generation of "
+                                      "f32x2 instructions and registers."),
+                             cl::init(false));
+
 // Pin the vtable to this file.
 void NVPTXSubtarget::anchor() {}
 
@@ -70,6 +76,10 @@ bool NVPTXSubtarget::allowFP16Math() const {
   return hasFP16Math() && NoF16Math == false;
 }
 
+bool NVPTXSubtarget::hasF32x2Instructions() const {
+  return SmVersion >= 100 && PTXVersion >= 86 && !NoF32x2;
+}
+
 bool NVPTXSubtarget::hasNativeBF16Support(int Opcode) const {
   if (!hasBF16Math())
     return false;

llvm/lib/Target/NVPTX/NVPTXSubtarget.h

Lines changed: 1 addition & 3 deletions
@@ -117,9 +117,7 @@ class NVPTXSubtarget : public NVPTXGenSubtargetInfo {
     return HasTcgen05 && PTXVersion >= 86;
   }
   // f32x2 instructions in Blackwell family
-  bool hasF32x2Instructions() const {
-    return SmVersion >= 100 && PTXVersion >= 86;
-  }
+  bool hasF32x2Instructions() const;
 
   // TMA G2S copy with cta_group::1/2 support
   bool hasCpAsyncBulkTensorCTAGroupSupport() const {

llvm/test/CodeGen/NVPTX/aggregate-return.ll

Lines changed: 22 additions & 17 deletions
@@ -10,19 +10,20 @@ declare {float, float} @bars({float, float} %input)
 define void @test_v2f32(<2 x float> %input, ptr %output) {
 ; CHECK-LABEL: test_v2f32(
 ; CHECK: {
-; CHECK-NEXT: .reg .b64 %rd<4>;
+; CHECK-NEXT: .reg .b32 %r<5>;
+; CHECK-NEXT: .reg .b64 %rd<2>;
 ; CHECK-EMPTY:
 ; CHECK-NEXT: // %bb.0:
-; CHECK-NEXT: ld.param.b64 %rd1, [test_v2f32_param_0];
+; CHECK-NEXT: ld.param.v2.b32 {%r1, %r2}, [test_v2f32_param_0];
 ; CHECK-NEXT: { // callseq 0, 0
 ; CHECK-NEXT: .param .align 8 .b8 param0[8];
 ; CHECK-NEXT: .param .align 8 .b8 retval0[8];
-; CHECK-NEXT: st.param.b64 [param0], %rd1;
+; CHECK-NEXT: st.param.v2.b32 [param0], {%r1, %r2};
 ; CHECK-NEXT: call.uni (retval0), barv, (param0);
-; CHECK-NEXT: ld.param.b64 %rd2, [retval0];
+; CHECK-NEXT: ld.param.v2.b32 {%r3, %r4}, [retval0];
 ; CHECK-NEXT: } // callseq 0
-; CHECK-NEXT: ld.param.b64 %rd3, [test_v2f32_param_1];
-; CHECK-NEXT: st.b64 [%rd3], %rd2;
+; CHECK-NEXT: ld.param.b64 %rd1, [test_v2f32_param_1];
+; CHECK-NEXT: st.v2.b32 [%rd1], {%r3, %r4};
 ; CHECK-NEXT: ret;
   %call = tail call <2 x float> @barv(<2 x float> %input)
   store <2 x float> %call, ptr %output, align 8
@@ -32,24 +33,28 @@ define void @test_v2f32(<2 x float> %input, ptr %output) {
 define void @test_v3f32(<3 x float> %input, ptr %output) {
 ; CHECK-LABEL: test_v3f32(
 ; CHECK: {
-; CHECK-NEXT: .reg .b32 %r<3>;
-; CHECK-NEXT: .reg .b64 %rd<4>;
+; CHECK-NEXT: .reg .b32 %r<7>;
+; CHECK-NEXT: .reg .b64 %rd<6>;
 ; CHECK-EMPTY:
 ; CHECK-NEXT: // %bb.0:
-; CHECK-NEXT: ld.param.b64 %rd1, [test_v3f32_param_0];
-; CHECK-NEXT: ld.param.b32 %r1, [test_v3f32_param_0+8];
+; CHECK-NEXT: ld.param.v2.b32 {%r1, %r2}, [test_v3f32_param_0];
+; CHECK-NEXT: ld.param.b32 %r3, [test_v3f32_param_0+8];
 ; CHECK-NEXT: { // callseq 1, 0
 ; CHECK-NEXT: .param .align 16 .b8 param0[16];
 ; CHECK-NEXT: .param .align 16 .b8 retval0[16];
-; CHECK-NEXT: st.param.b32 [param0+8], %r1;
-; CHECK-NEXT: st.param.b64 [param0], %rd1;
+; CHECK-NEXT: st.param.b32 [param0+8], %r3;
+; CHECK-NEXT: st.param.v2.b32 [param0], {%r1, %r2};
 ; CHECK-NEXT: call.uni (retval0), barv3, (param0);
-; CHECK-NEXT: ld.param.b32 %r2, [retval0+8];
-; CHECK-NEXT: ld.param.b64 %rd2, [retval0];
+; CHECK-NEXT: ld.param.b32 %r4, [retval0+8];
+; CHECK-NEXT: ld.param.v2.b32 {%r5, %r6}, [retval0];
 ; CHECK-NEXT: } // callseq 1
-; CHECK-NEXT: ld.param.b64 %rd3, [test_v3f32_param_1];
-; CHECK-NEXT: st.b32 [%rd3+8], %r2;
-; CHECK-NEXT: st.b64 [%rd3], %rd2;
+; CHECK-NEXT: cvt.u64.u32 %rd1, %r5;
+; CHECK-NEXT: cvt.u64.u32 %rd2, %r6;
+; CHECK-NEXT: shl.b64 %rd3, %rd2, 32;
+; CHECK-NEXT: or.b64 %rd4, %rd1, %rd3;
+; CHECK-NEXT: ld.param.b64 %rd5, [test_v3f32_param_1];
+; CHECK-NEXT: st.b32 [%rd5+8], %r4;
+; CHECK-NEXT: st.b64 [%rd5], %rd4;
 ; CHECK-NEXT: ret;
   %call = tail call <3 x float> @barv3(<3 x float> %input)
 ; Make sure we don't load more values than than we need to.
