@@ -1386,26 +1386,6 @@ SDValue NVPTXTargetLowering::LowerCall(TargetLowering::CallLoweringInfo &CLI,
1386
1386
return DAG.getConstant (I, dl, MVT::i32 );
1387
1387
};
1388
1388
1389
- // Variadic arguments.
1390
- //
1391
- // Normally, for each argument, we declare a param scalar or a param
1392
- // byte array in the .param space, and store the argument value to that
1393
- // param scalar or array starting at offset 0.
1394
- //
1395
- // In the case of the first variadic argument, we declare a vararg byte array
1396
- // with size 0. The exact size of this array isn't known at this point, so
1397
- // it'll be patched later. All the variadic arguments will be stored to this
1398
- // array at a certain offset (which gets tracked by 'VAOffset'). The offset is
1399
- // initially set to 0, so it can be used for non-variadic arguments (which use
1400
- // 0 offset) to simplify the code.
1401
- //
1402
- // After all vararg is processed, 'VAOffset' holds the size of the
1403
- // vararg byte array.
1404
-
1405
- SDValue VADeclareParam = SDValue (); // vararg byte array
1406
- const unsigned FirstVAArg = CLI.NumFixedArgs ; // position of first variadic
1407
- unsigned VAOffset = 0 ; // current offset in the param array
1408
-
1409
1389
const unsigned UniqueCallSite = GlobalUniqueCallSite++;
1410
1390
const SDValue CallChain = CLI.Chain ;
1411
1391
const SDValue StartChain =
@@ -1414,7 +1394,7 @@ SDValue NVPTXTargetLowering::LowerCall(TargetLowering::CallLoweringInfo &CLI,
1414
1394
1415
1395
SmallVector<SDValue, 16 > CallPrereqs{StartChain};
1416
1396
1417
- const auto DeclareScalarParam = [&](SDValue Symbol, unsigned Size) {
1397
+ const auto MakeDeclareScalarParam = [&](SDValue Symbol, unsigned Size) {
1418
1398
// PTX ABI requires integral types to be at least 32 bits in size. FP16 is
1419
1399
// loaded/stored using i16, so it's handled here as well.
1420
1400
const unsigned SizeBits = promoteScalarArgumentSize (Size * 8 );
@@ -1426,8 +1406,8 @@ SDValue NVPTXTargetLowering::LowerCall(TargetLowering::CallLoweringInfo &CLI,
1426
1406
return Declare;
1427
1407
};
1428
1408
1429
- const auto DeclareArrayParam = [&](SDValue Symbol, Align Align,
1430
- unsigned Size) {
1409
+ const auto MakeDeclareArrayParam = [&](SDValue Symbol, Align Align,
1410
+ unsigned Size) {
1431
1411
SDValue Declare = DAG.getNode (
1432
1412
NVPTXISD::DeclareArrayParam, dl, {MVT::Other, MVT::Glue},
1433
1413
{StartChain, Symbol, GetI32 (Align.value ()), GetI32 (Size), DeclareGlue});
@@ -1436,6 +1416,33 @@ SDValue NVPTXTargetLowering::LowerCall(TargetLowering::CallLoweringInfo &CLI,
1436
1416
return Declare;
1437
1417
};
1438
1418
1419
+ // Variadic arguments.
1420
+ //
1421
+ // Normally, for each argument, we declare a param scalar or a param
1422
+ // byte array in the .param space, and store the argument value to that
1423
+ // param scalar or array starting at offset 0.
1424
+ //
1425
+ // In the case of the first variadic argument, we declare a vararg byte array
1426
+ // with size 0. The exact size of this array isn't known at this point, so
1427
+ // it'll be patched later. All the variadic arguments will be stored to this
1428
+ // array at a certain offset (which gets tracked by 'VAOffset'). The offset is
1429
+ // initially set to 0, so it can be used for non-variadic arguments (which use
1430
+ // 0 offset) to simplify the code.
1431
+ //
1432
+ // After all varargs are processed, 'VAOffset' holds the size of the
1433
+ // vararg byte array.
1434
+ assert ((CLI.IsVarArg || CLI.Args .size () == CLI.NumFixedArgs ) &&
1435
+ " Non-VarArg function with extra arguments" );
1436
+
1437
+ const unsigned FirstVAArg = CLI.NumFixedArgs ; // position of first variadic
1438
+ unsigned VAOffset = 0 ; // current offset in the param array
1439
+
1440
+ const SDValue VADeclareParam =
1441
+ CLI.Args .size () > FirstVAArg
1442
+ ? MakeDeclareArrayParam (getCallParamSymbol (DAG, FirstVAArg, MVT::i32 ),
1443
+ Align (STI.getMaxRequiredAlignment ()), 0 )
1444
+ : SDValue ();
1445
+
1439
1446
// Args.size() and Outs.size() need not match.
1440
1447
// Outs.size() will be larger
1441
1448
// * if there is an aggregate argument with multiple fields (each field
@@ -1496,21 +1503,17 @@ SDValue NVPTXTargetLowering::LowerCall(TargetLowering::CallLoweringInfo &CLI,
1496
1503
" type size mismatch" );
1497
1504
1498
1505
const SDValue ArgDeclare = [&]() {
1499
- if (IsVAArg) {
1500
- if (ArgI == FirstVAArg)
1501
- VADeclareParam = DeclareArrayParam (
1502
- ParamSymbol, Align (STI.getMaxRequiredAlignment ()), 0 );
1506
+ if (IsVAArg)
1503
1507
return VADeclareParam;
1504
- }
1505
1508
1506
1509
if (IsByVal || shouldPassAsArray (Arg.Ty ))
1507
- return DeclareArrayParam (ParamSymbol, ArgAlign, TypeSize);
1510
+ return MakeDeclareArrayParam (ParamSymbol, ArgAlign, TypeSize);
1508
1511
1509
1512
assert (ArgOuts.size () == 1 && " We must pass only one value as non-array" );
1510
1513
assert ((ArgOuts[0 ].VT .isInteger () || ArgOuts[0 ].VT .isFloatingPoint ()) &&
1511
1514
" Only int and float types are supported as non-array arguments" );
1512
1515
1513
- return DeclareScalarParam (ParamSymbol, TypeSize);
1516
+ return MakeDeclareScalarParam (ParamSymbol, TypeSize);
1514
1517
}();
1515
1518
1516
1519
// PTX Interoperability Guide 3.3(A): [Integer] Values shorter
@@ -1570,7 +1573,7 @@ SDValue NVPTXTargetLowering::LowerCall(TargetLowering::CallLoweringInfo &CLI,
1570
1573
if (NumElts == 1 ) {
1571
1574
Val = GetStoredValue (J, EltVT, CurrentAlign);
1572
1575
} else {
1573
- SmallVector<SDValue, 6 > StoreVals;
1576
+ SmallVector<SDValue, 8 > StoreVals;
1574
1577
for (const unsigned K : llvm::seq (NumElts)) {
1575
1578
SDValue ValJ = GetStoredValue (J + K, EltVT, CurrentAlign);
1576
1579
if (ValJ.getValueType ().isVector ())
@@ -1611,9 +1614,9 @@ SDValue NVPTXTargetLowering::LowerCall(TargetLowering::CallLoweringInfo &CLI,
1611
1614
const unsigned ResultSize = DL.getTypeAllocSize (RetTy);
1612
1615
if (shouldPassAsArray (RetTy)) {
1613
1616
const Align RetAlign = getArgumentAlignment (CB, RetTy, 0 , DL);
1614
- DeclareArrayParam (RetSymbol, RetAlign, ResultSize);
1617
+ MakeDeclareArrayParam (RetSymbol, RetAlign, ResultSize);
1615
1618
} else {
1616
- DeclareScalarParam (RetSymbol, ResultSize);
1619
+ MakeDeclareScalarParam (RetSymbol, ResultSize);
1617
1620
}
1618
1621
}
1619
1622
@@ -1737,17 +1740,16 @@ SDValue NVPTXTargetLowering::LowerCall(TargetLowering::CallLoweringInfo &CLI,
1737
1740
1738
1741
LoadChains.push_back (R.getValue (1 ));
1739
1742
1740
- if (NumElts == 1 ) {
1743
+ if (NumElts == 1 )
1741
1744
ProxyRegOps.push_back (R);
1742
- } else {
1745
+ else
1743
1746
for (const unsigned J : llvm::seq (NumElts)) {
1744
1747
SDValue Elt = DAG.getNode (
1745
1748
LoadVT.isVector () ? ISD::EXTRACT_SUBVECTOR
1746
1749
: ISD::EXTRACT_VECTOR_ELT,
1747
1750
dl, LoadVT, R, DAG.getVectorIdxConstant (J * PackingAmt, dl));
1748
1751
ProxyRegOps.push_back (Elt);
1749
1752
}
1750
- }
1751
1753
I += NumElts;
1752
1754
}
1753
1755
}
@@ -5767,7 +5769,7 @@ static SDValue sinkProxyReg(SDValue R, SDValue Chain,
5767
5769
{Chain, R});
5768
5770
}
5769
5771
case ISD::BUILD_VECTOR: {
5770
- if (DCI.isAfterLegalizeDAG ())
5772
+ if (DCI.isBeforeLegalize ())
5771
5773
return SDValue ();
5772
5774
5773
5775
SmallVector<SDValue, 16 > Ops;
@@ -5779,6 +5781,15 @@ static SDValue sinkProxyReg(SDValue R, SDValue Chain,
5779
5781
}
5780
5782
return DCI.DAG .getNode (ISD::BUILD_VECTOR, SDLoc (R), R.getValueType (), Ops);
5781
5783
}
5784
+ case ISD::EXTRACT_VECTOR_ELT: {
5785
+ if (DCI.isBeforeLegalize ())
5786
+ return SDValue ();
5787
+
5788
+ if (SDValue V = sinkProxyReg (R.getOperand (0 ), Chain, DCI))
5789
+ return DCI.DAG .getNode (ISD::EXTRACT_VECTOR_ELT, SDLoc (R), R.getValueType (),
5790
+ V, R.getOperand (1 ));
5791
+ return SDValue ();
5792
+ }
5782
5793
default :
5783
5794
return SDValue ();
5784
5795
}
0 commit comments