Skip to content

Commit 0daff51

Browse files
committed
address comments
1 parent d31bc98 commit 0daff51

File tree

3 files changed

+63
-50
lines changed

3 files changed

+63
-50
lines changed

llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp

Lines changed: 48 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -1386,26 +1386,6 @@ SDValue NVPTXTargetLowering::LowerCall(TargetLowering::CallLoweringInfo &CLI,
13861386
return DAG.getConstant(I, dl, MVT::i32);
13871387
};
13881388

1389-
// Variadic arguments.
1390-
//
1391-
// Normally, for each argument, we declare a param scalar or a param
1392-
// byte array in the .param space, and store the argument value to that
1393-
// param scalar or array starting at offset 0.
1394-
//
1395-
// In the case of the first variadic argument, we declare a vararg byte array
1396-
// with size 0. The exact size of this array isn't known at this point, so
1397-
// it'll be patched later. All the variadic arguments will be stored to this
1398-
// array at a certain offset (which gets tracked by 'VAOffset'). The offset is
1399-
// initially set to 0, so it can be used for non-variadic arguments (which use
1400-
// 0 offset) to simplify the code.
1401-
//
1402-
// After all vararg is processed, 'VAOffset' holds the size of the
1403-
// vararg byte array.
1404-
1405-
SDValue VADeclareParam = SDValue(); // vararg byte array
1406-
const unsigned FirstVAArg = CLI.NumFixedArgs; // position of first variadic
1407-
unsigned VAOffset = 0; // current offset in the param array
1408-
14091389
const unsigned UniqueCallSite = GlobalUniqueCallSite++;
14101390
const SDValue CallChain = CLI.Chain;
14111391
const SDValue StartChain =
@@ -1414,7 +1394,7 @@ SDValue NVPTXTargetLowering::LowerCall(TargetLowering::CallLoweringInfo &CLI,
14141394

14151395
SmallVector<SDValue, 16> CallPrereqs{StartChain};
14161396

1417-
const auto DeclareScalarParam = [&](SDValue Symbol, unsigned Size) {
1397+
const auto MakeDeclareScalarParam = [&](SDValue Symbol, unsigned Size) {
14181398
// PTX ABI requires integral types to be at least 32 bits in size. FP16 is
14191399
// loaded/stored using i16, so it's handled here as well.
14201400
const unsigned SizeBits = promoteScalarArgumentSize(Size * 8);
@@ -1426,8 +1406,8 @@ SDValue NVPTXTargetLowering::LowerCall(TargetLowering::CallLoweringInfo &CLI,
14261406
return Declare;
14271407
};
14281408

1429-
const auto DeclareArrayParam = [&](SDValue Symbol, Align Align,
1430-
unsigned Size) {
1409+
const auto MakeDeclareArrayParam = [&](SDValue Symbol, Align Align,
1410+
unsigned Size) {
14311411
SDValue Declare = DAG.getNode(
14321412
NVPTXISD::DeclareArrayParam, dl, {MVT::Other, MVT::Glue},
14331413
{StartChain, Symbol, GetI32(Align.value()), GetI32(Size), DeclareGlue});
@@ -1436,6 +1416,33 @@ SDValue NVPTXTargetLowering::LowerCall(TargetLowering::CallLoweringInfo &CLI,
14361416
return Declare;
14371417
};
14381418

1419+
// Variadic arguments.
1420+
//
1421+
// Normally, for each argument, we declare a param scalar or a param
1422+
// byte array in the .param space, and store the argument value to that
1423+
// param scalar or array starting at offset 0.
1424+
//
1425+
// In the case of the first variadic argument, we declare a vararg byte array
1426+
// with size 0. The exact size of this array isn't known at this point, so
1427+
// it'll be patched later. All the variadic arguments will be stored to this
1428+
// array at a certain offset (which gets tracked by 'VAOffset'). The offset is
1429+
// initially set to 0, so it can be used for non-variadic arguments (which use
1430+
// 0 offset) to simplify the code.
1431+
//
1432+
// After all varargs are processed, 'VAOffset' holds the size of the
1433+
// vararg byte array.
1434+
assert((CLI.IsVarArg || CLI.Args.size() == CLI.NumFixedArgs) &&
1435+
"Non-VarArg function with extra arguments");
1436+
1437+
const unsigned FirstVAArg = CLI.NumFixedArgs; // position of first variadic
1438+
unsigned VAOffset = 0; // current offset in the param array
1439+
1440+
const SDValue VADeclareParam =
1441+
CLI.Args.size() > FirstVAArg
1442+
? MakeDeclareArrayParam(getCallParamSymbol(DAG, FirstVAArg, MVT::i32),
1443+
Align(STI.getMaxRequiredAlignment()), 0)
1444+
: SDValue();
1445+
14391446
// Args.size() and Outs.size() need not match.
14401447
// Outs.size() will be larger
14411448
// * if there is an aggregate argument with multiple fields (each field
@@ -1496,21 +1503,17 @@ SDValue NVPTXTargetLowering::LowerCall(TargetLowering::CallLoweringInfo &CLI,
14961503
"type size mismatch");
14971504

14981505
const SDValue ArgDeclare = [&]() {
1499-
if (IsVAArg) {
1500-
if (ArgI == FirstVAArg)
1501-
VADeclareParam = DeclareArrayParam(
1502-
ParamSymbol, Align(STI.getMaxRequiredAlignment()), 0);
1506+
if (IsVAArg)
15031507
return VADeclareParam;
1504-
}
15051508

15061509
if (IsByVal || shouldPassAsArray(Arg.Ty))
1507-
return DeclareArrayParam(ParamSymbol, ArgAlign, TypeSize);
1510+
return MakeDeclareArrayParam(ParamSymbol, ArgAlign, TypeSize);
15081511

15091512
assert(ArgOuts.size() == 1 && "We must pass only one value as non-array");
15101513
assert((ArgOuts[0].VT.isInteger() || ArgOuts[0].VT.isFloatingPoint()) &&
15111514
"Only int and float types are supported as non-array arguments");
15121515

1513-
return DeclareScalarParam(ParamSymbol, TypeSize);
1516+
return MakeDeclareScalarParam(ParamSymbol, TypeSize);
15141517
}();
15151518

15161519
// PTX Interoperability Guide 3.3(A): [Integer] Values shorter
@@ -1570,7 +1573,7 @@ SDValue NVPTXTargetLowering::LowerCall(TargetLowering::CallLoweringInfo &CLI,
15701573
if (NumElts == 1) {
15711574
Val = GetStoredValue(J, EltVT, CurrentAlign);
15721575
} else {
1573-
SmallVector<SDValue, 6> StoreVals;
1576+
SmallVector<SDValue, 8> StoreVals;
15741577
for (const unsigned K : llvm::seq(NumElts)) {
15751578
SDValue ValJ = GetStoredValue(J + K, EltVT, CurrentAlign);
15761579
if (ValJ.getValueType().isVector())
@@ -1611,9 +1614,9 @@ SDValue NVPTXTargetLowering::LowerCall(TargetLowering::CallLoweringInfo &CLI,
16111614
const unsigned ResultSize = DL.getTypeAllocSize(RetTy);
16121615
if (shouldPassAsArray(RetTy)) {
16131616
const Align RetAlign = getArgumentAlignment(CB, RetTy, 0, DL);
1614-
DeclareArrayParam(RetSymbol, RetAlign, ResultSize);
1617+
MakeDeclareArrayParam(RetSymbol, RetAlign, ResultSize);
16151618
} else {
1616-
DeclareScalarParam(RetSymbol, ResultSize);
1619+
MakeDeclareScalarParam(RetSymbol, ResultSize);
16171620
}
16181621
}
16191622

@@ -1737,17 +1740,16 @@ SDValue NVPTXTargetLowering::LowerCall(TargetLowering::CallLoweringInfo &CLI,
17371740

17381741
LoadChains.push_back(R.getValue(1));
17391742

1740-
if (NumElts == 1) {
1743+
if (NumElts == 1)
17411744
ProxyRegOps.push_back(R);
1742-
} else {
1745+
else
17431746
for (const unsigned J : llvm::seq(NumElts)) {
17441747
SDValue Elt = DAG.getNode(
17451748
LoadVT.isVector() ? ISD::EXTRACT_SUBVECTOR
17461749
: ISD::EXTRACT_VECTOR_ELT,
17471750
dl, LoadVT, R, DAG.getVectorIdxConstant(J * PackingAmt, dl));
17481751
ProxyRegOps.push_back(Elt);
17491752
}
1750-
}
17511753
I += NumElts;
17521754
}
17531755
}
@@ -5767,7 +5769,7 @@ static SDValue sinkProxyReg(SDValue R, SDValue Chain,
57675769
{Chain, R});
57685770
}
57695771
case ISD::BUILD_VECTOR: {
5770-
if (DCI.isAfterLegalizeDAG())
5772+
if (DCI.isBeforeLegalize())
57715773
return SDValue();
57725774

57735775
SmallVector<SDValue, 16> Ops;
@@ -5779,6 +5781,15 @@ static SDValue sinkProxyReg(SDValue R, SDValue Chain,
57795781
}
57805782
return DCI.DAG.getNode(ISD::BUILD_VECTOR, SDLoc(R), R.getValueType(), Ops);
57815783
}
5784+
case ISD::EXTRACT_VECTOR_ELT: {
5785+
if (DCI.isBeforeLegalize())
5786+
return SDValue();
5787+
5788+
if (SDValue V = sinkProxyReg(R.getOperand(0), Chain, DCI))
5789+
return DCI.DAG.getNode(ISD::EXTRACT_VECTOR_ELT, SDLoc(R), R.getValueType(),
5790+
V, R.getOperand(1));
5791+
return SDValue();
5792+
}
57825793
default:
57835794
return SDValue();
57845795
}

llvm/test/CodeGen/NVPTX/convert-call-to-indirect.ll

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -173,8 +173,8 @@ define i64 @test_param_type_mismatch_variadic(ptr %p) {
173173
; CHECK-NEXT: // %bb.0:
174174
; CHECK-NEXT: ld.param.b64 %rd1, [test_param_type_mismatch_variadic_param_0];
175175
; CHECK-NEXT: { // callseq 4, 0
176-
; CHECK-NEXT: .param .b64 param0;
177176
; CHECK-NEXT: .param .align 8 .b8 param1[8];
177+
; CHECK-NEXT: .param .b64 param0;
178178
; CHECK-NEXT: .param .b64 retval0;
179179
; CHECK-NEXT: st.param.b64 [param0], %rd1;
180180
; CHECK-NEXT: st.param.b64 [param1], 7;
@@ -195,8 +195,8 @@ define i64 @test_param_count_mismatch_variadic(ptr %p) {
195195
; CHECK-NEXT: // %bb.0:
196196
; CHECK-NEXT: ld.param.b64 %rd1, [test_param_count_mismatch_variadic_param_0];
197197
; CHECK-NEXT: { // callseq 5, 0
198-
; CHECK-NEXT: .param .b64 param0;
199198
; CHECK-NEXT: .param .align 8 .b8 param1[8];
199+
; CHECK-NEXT: .param .b64 param0;
200200
; CHECK-NEXT: .param .b64 retval0;
201201
; CHECK-NEXT: st.param.b64 [param0], %rd1;
202202
; CHECK-NEXT: st.param.b64 [param1], 7;

llvm/test/CodeGen/NVPTX/param-load-store.ll

Lines changed: 13 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -243,7 +243,7 @@ define <4 x i8> @test_v4i8(<4 x i8> %a) {
243243
; CHECK: call.uni (retval0), test_v5i8,
244244
; CHECK-DAG: ld.param.b32 [[RE0:%r[0-9]+]], [retval0];
245245
; CHECK-DAG: ld.param.b8 [[RE4:%rs[0-9]+]], [retval0+4];
246-
; CHECK-DAG: st.param.b32 [func_retval0], {{%r[0-9]+}};
246+
; CHECK-DAG: st.param.b32 [func_retval0], [[RE0]];
247247
; CHECK-DAG: st.param.b8 [func_retval0+4], [[RE4]];
248248
; CHECK-NEXT: ret;
249249
define <5 x i8> @test_v5i8(<5 x i8> %a) {
@@ -311,8 +311,9 @@ define signext i16 @test_i16s(i16 signext %a) {
311311
; CHECK-DAG: st.param.b32 [param0], [[E0]];
312312
; CHECK-DAG: st.param.b16 [param0+4], [[E2]];
313313
; CHECK: call.uni (retval0), test_v3i16,
314-
; CHECK: ld.param.v2.b16 {[[RE0:%rs[0-9]+]], [[RE1:%rs[0-9]+]]}, [retval0];
314+
; CHECK: ld.param.b32 [[RE:%r[0-9]+]], [retval0];
315315
; CHECK: ld.param.b16 [[RE2:%rs[0-9]+]], [retval0+4];
316+
; CHECK-DAG: mov.b32 {[[RE0:%rs[0-9]+]], [[RE1:%rs[0-9]+]]}, [[RE]];
316317
; CHECK-DAG: st.param.v2.b16 [func_retval0], {[[RE0]], [[RE1]]};
317318
; CHECK-DAG: st.param.b16 [func_retval0+4], [[RE2]];
318319
; CHECK-NEXT: ret;
@@ -347,9 +348,9 @@ define <4 x i16> @test_v4i16(<4 x i16> %a) {
347348
; CHECK-DAG: st.param.v2.b32 [param0], {[[E0]], [[E1]]};
348349
; CHECK-DAG: st.param.b16 [param0+8], [[E4]];
349350
; CHECK: call.uni (retval0), test_v5i16,
350-
; CHECK-DAG: ld.param.v4.b16 {[[RE0:%rs[0-9]+]], [[RE1:%rs[0-9]+]], [[RE2:%rs[0-9]+]], [[RE3:%rs[0-9]+]]}, [retval0];
351+
; CHECK-DAG: ld.param.v2.b32 {[[RE0:%r[0-9]+]], [[RE1:%r[0-9]+]]}, [retval0];
351352
; CHECK-DAG: ld.param.b16 [[RE4:%rs[0-9]+]], [retval0+8];
352-
; CHECK-DAG: st.param.v4.b16 [func_retval0], {[[RE0]], [[RE1]], [[RE2]], [[RE3]]}
353+
; CHECK-DAG: st.param.v2.b32 [func_retval0], {[[RE0]], [[RE1]]}
353354
; CHECK-DAG: st.param.b16 [func_retval0+8], [[RE4]];
354355
; CHECK-NEXT: ret;
355356
define <5 x i16> @test_v5i16(<5 x i16> %a) {
@@ -432,8 +433,9 @@ define <2 x bfloat> @test_v2bf16(<2 x bfloat> %a) {
432433
; CHECK-DAG: st.param.b32 [param0], [[E0]];
433434
; CHECK-DAG: st.param.b16 [param0+4], [[E2]];
434435
; CHECK: call.uni (retval0), test_v3f16,
435-
; CHECK-DAG: ld.param.v2.b16 {[[R0:%rs[0-9]+]], [[R1:%rs[0-9]+]]}, [retval0];
436+
; CHECK-DAG: ld.param.b32 [[R:%r[0-9]+]], [retval0];
436437
; CHECK-DAG: ld.param.b16 [[R2:%rs[0-9]+]], [retval0+4];
438+
; CHECK-DAG: mov.b32 {[[R0:%rs[0-9]+]], [[R1:%rs[0-9]+]]}, [[R]];
437439
; CHECK-DAG: st.param.v2.b16 [func_retval0], {[[R0]], [[R1]]};
438440
; CHECK-DAG: st.param.b16 [func_retval0+4], [[R2]];
439441
; CHECK: ret;
@@ -468,9 +470,9 @@ define <4 x half> @test_v4f16(<4 x half> %a) {
468470
; CHECK-DAG: st.param.v2.b32 [param0], {[[E0]], [[E1]]};
469471
; CHECK-DAG: st.param.b16 [param0+8], [[E4]];
470472
; CHECK: call.uni (retval0), test_v5f16,
471-
; CHECK-DAG: ld.param.v4.b16 {[[R0:%rs[0-9]+]], [[R1:%rs[0-9]+]], [[R2:%rs[0-9]+]], [[R3:%rs[0-9]+]]}, [retval0];
473+
; CHECK-DAG: ld.param.v2.b32 {[[R0:%r[0-9]+]], [[R1:%r[0-9]+]]}, [retval0];
472474
; CHECK-DAG: ld.param.b16 [[R4:%rs[0-9]+]], [retval0+8];
473-
; CHECK-DAG: st.param.v4.b16 [func_retval0], {[[R0]], [[R1]], [[R2]], [[R3]]};
475+
; CHECK-DAG: st.param.v2.b32 [func_retval0], {[[R0]], [[R1]]};
474476
; CHECK-DAG: st.param.b16 [func_retval0+8], [[R4]];
475477
; CHECK: ret;
476478
define <5 x half> @test_v5f16(<5 x half> %a) {
@@ -506,11 +508,11 @@ define <8 x half> @test_v8f16(<8 x half> %a) {
506508
; CHECK-DAG: st.param.v2.b32 [param0+8], {[[E2]], [[E3]]};
507509
; CHECK-DAG: st.param.b16 [param0+16], [[E8]];
508510
; CHECK: call.uni (retval0), test_v9f16,
509-
; CHECK-DAG: ld.param.v4.b16 {[[R0:%rs[0-9]+]], [[R1:%rs[0-9]+]], [[R2:%rs[0-9]+]], [[R3:%rs[0-9]+]]}, [retval0];
510-
; CHECK-DAG: ld.param.v4.b16 {[[R4:%rs[0-9]+]], [[R5:%rs[0-9]+]], [[R6:%rs[0-9]+]], [[R7:%rs[0-9]+]]}, [retval0+8];
511+
; CHECK-DAG: ld.param.v2.b32 {[[R0:%r[0-9]+]], [[R1:%r[0-9]+]]}, [retval0];
512+
; CHECK-DAG: ld.param.v2.b32 {[[R2:%r[0-9]+]], [[R3:%r[0-9]+]]}, [retval0+8];
511513
; CHECK-DAG: ld.param.b16 [[R8:%rs[0-9]+]], [retval0+16];
512-
; CHECK-DAG: st.param.v4.b16 [func_retval0], {[[R0]], [[R1]], [[R2]], [[R3]]};
513-
; CHECK-DAG: st.param.v4.b16 [func_retval0+8], {[[R4]], [[R5]], [[R6]], [[R7]]};
514+
; CHECK-DAG: st.param.v2.b32 [func_retval0], {[[R0]], [[R1]]};
515+
; CHECK-DAG: st.param.v2.b32 [func_retval0+8], {[[R2]], [[R3]]};
514516
; CHECK-DAG: st.param.b16 [func_retval0+16], [[R8]];
515517
; CHECK: ret;
516518
define <9 x half> @test_v9f16(<9 x half> %a) {

0 commit comments

Comments
 (0)