Skip to content

Commit 3b62a2b

Browse files
committed
[NVPTX] Fixup v2i8 call lowering, use generic load/store nodes for call params
1 parent 0f2484a commit 3b62a2b

39 files changed

+1645
-1655
lines changed

clang/test/CodeGenCUDA/bf16.cu

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -35,8 +35,8 @@ __device__ __bf16 external_func( __bf16 in);
3535
// CHECK: .param .align 2 .b8 _Z9test_callDF16b_param_0[2]
3636
__device__ __bf16 test_call( __bf16 in) {
3737
// CHECK: ld.param.b16 %[[R:rs[0-9]+]], [_Z9test_callDF16b_param_0];
38-
// CHECK: st.param.b16 [param0], %[[R]];
3938
// CHECK: .param .align 2 .b8 retval0[2];
39+
// CHECK: st.param.b16 [param0], %[[R]];
4040
// CHECK: call.uni (retval0), _Z13external_funcDF16b, (param0);
4141
// CHECK: ld.param.b16 %[[RET:rs[0-9]+]], [retval0];
4242
return external_func(in);

llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.cpp

Lines changed: 0 additions & 273 deletions
Original file line number | Diff line number | Diff line change
@@ -145,18 +145,6 @@ void NVPTXDAGToDAGISel::Select(SDNode *N) {
145145
if (tryStoreVector(N))
146146
return;
147147
break;
148-
case NVPTXISD::LoadParam:
149-
case NVPTXISD::LoadParamV2:
150-
case NVPTXISD::LoadParamV4:
151-
if (tryLoadParam(N))
152-
return;
153-
break;
154-
case NVPTXISD::StoreParam:
155-
case NVPTXISD::StoreParamV2:
156-
case NVPTXISD::StoreParamV4:
157-
if (tryStoreParam(N))
158-
return;
159-
break;
160148
case ISD::INTRINSIC_W_CHAIN:
161149
if (tryIntrinsicChain(N))
162150
return;
@@ -1462,267 +1450,6 @@ bool NVPTXDAGToDAGISel::tryStoreVector(SDNode *N) {
14621450
return true;
14631451
}
14641452

1465-
bool NVPTXDAGToDAGISel::tryLoadParam(SDNode *Node) {
1466-
SDValue Chain = Node->getOperand(0);
1467-
SDValue Offset = Node->getOperand(2);
1468-
SDValue Glue = Node->getOperand(3);
1469-
SDLoc DL(Node);
1470-
MemSDNode *Mem = cast<MemSDNode>(Node);
1471-
1472-
unsigned VecSize;
1473-
switch (Node->getOpcode()) {
1474-
default:
1475-
return false;
1476-
case NVPTXISD::LoadParam:
1477-
VecSize = 1;
1478-
break;
1479-
case NVPTXISD::LoadParamV2:
1480-
VecSize = 2;
1481-
break;
1482-
case NVPTXISD::LoadParamV4:
1483-
VecSize = 4;
1484-
break;
1485-
}
1486-
1487-
EVT EltVT = Node->getValueType(0);
1488-
EVT MemVT = Mem->getMemoryVT();
1489-
1490-
std::optional<unsigned> Opcode;
1491-
1492-
switch (VecSize) {
1493-
default:
1494-
return false;
1495-
case 1:
1496-
Opcode = pickOpcodeForVT(MemVT.getSimpleVT().SimpleTy,
1497-
NVPTX::LoadParamMemI8, NVPTX::LoadParamMemI16,
1498-
NVPTX::LoadParamMemI32, NVPTX::LoadParamMemI64);
1499-
break;
1500-
case 2:
1501-
Opcode =
1502-
pickOpcodeForVT(MemVT.getSimpleVT().SimpleTy, NVPTX::LoadParamMemV2I8,
1503-
NVPTX::LoadParamMemV2I16, NVPTX::LoadParamMemV2I32,
1504-
NVPTX::LoadParamMemV2I64);
1505-
break;
1506-
case 4:
1507-
Opcode = pickOpcodeForVT(MemVT.getSimpleVT().SimpleTy,
1508-
NVPTX::LoadParamMemV4I8, NVPTX::LoadParamMemV4I16,
1509-
NVPTX::LoadParamMemV4I32, {/* no v4i64 */});
1510-
break;
1511-
}
1512-
if (!Opcode)
1513-
return false;
1514-
1515-
SDVTList VTs;
1516-
if (VecSize == 1) {
1517-
VTs = CurDAG->getVTList(EltVT, MVT::Other, MVT::Glue);
1518-
} else if (VecSize == 2) {
1519-
VTs = CurDAG->getVTList(EltVT, EltVT, MVT::Other, MVT::Glue);
1520-
} else {
1521-
EVT EVTs[] = { EltVT, EltVT, EltVT, EltVT, MVT::Other, MVT::Glue };
1522-
VTs = CurDAG->getVTList(EVTs);
1523-
}
1524-
1525-
unsigned OffsetVal = Offset->getAsZExtVal();
1526-
1527-
SmallVector<SDValue, 2> Ops(
1528-
{CurDAG->getTargetConstant(OffsetVal, DL, MVT::i32), Chain, Glue});
1529-
1530-
ReplaceNode(Node, CurDAG->getMachineNode(*Opcode, DL, VTs, Ops));
1531-
return true;
1532-
}
1533-
1534-
// Helpers for constructing an opcode name (e.g. NVPTX::StoreParamV4F32_iiri)
1535-
#define getOpcV2H(ty, opKind0, opKind1) \
1536-
NVPTX::StoreParamV2##ty##_##opKind0##opKind1
1537-
1538-
#define getOpcV2H1(ty, opKind0, isImm1) \
1539-
(isImm1) ? getOpcV2H(ty, opKind0, i) : getOpcV2H(ty, opKind0, r)
1540-
1541-
#define getOpcodeForVectorStParamV2(ty, isimm) \
1542-
(isimm[0]) ? getOpcV2H1(ty, i, isimm[1]) : getOpcV2H1(ty, r, isimm[1])
1543-
1544-
#define getOpcV4H(ty, opKind0, opKind1, opKind2, opKind3) \
1545-
NVPTX::StoreParamV4##ty##_##opKind0##opKind1##opKind2##opKind3
1546-
1547-
#define getOpcV4H3(ty, opKind0, opKind1, opKind2, isImm3) \
1548-
(isImm3) ? getOpcV4H(ty, opKind0, opKind1, opKind2, i) \
1549-
: getOpcV4H(ty, opKind0, opKind1, opKind2, r)
1550-
1551-
#define getOpcV4H2(ty, opKind0, opKind1, isImm2, isImm3) \
1552-
(isImm2) ? getOpcV4H3(ty, opKind0, opKind1, i, isImm3) \
1553-
: getOpcV4H3(ty, opKind0, opKind1, r, isImm3)
1554-
1555-
#define getOpcV4H1(ty, opKind0, isImm1, isImm2, isImm3) \
1556-
(isImm1) ? getOpcV4H2(ty, opKind0, i, isImm2, isImm3) \
1557-
: getOpcV4H2(ty, opKind0, r, isImm2, isImm3)
1558-
1559-
#define getOpcodeForVectorStParamV4(ty, isimm) \
1560-
(isimm[0]) ? getOpcV4H1(ty, i, isimm[1], isimm[2], isimm[3]) \
1561-
: getOpcV4H1(ty, r, isimm[1], isimm[2], isimm[3])
1562-
1563-
#define getOpcodeForVectorStParam(n, ty, isimm) \
1564-
(n == 2) ? getOpcodeForVectorStParamV2(ty, isimm) \
1565-
: getOpcodeForVectorStParamV4(ty, isimm)
1566-
1567-
static unsigned pickOpcodeForVectorStParam(SmallVector<SDValue, 8> &Ops,
1568-
unsigned NumElts,
1569-
MVT::SimpleValueType MemTy,
1570-
SelectionDAG *CurDAG, SDLoc DL) {
1571-
// Determine which inputs are registers and which are immediates, and replace
1572-
// each immediate operand with an equivalent target constant.
1573-
SmallVector<bool, 4> IsImm(NumElts, false);
1574-
for (unsigned i = 0; i < NumElts; i++) {
1575-
IsImm[i] = (isa<ConstantSDNode>(Ops[i]) || isa<ConstantFPSDNode>(Ops[i]));
1576-
if (IsImm[i]) {
1577-
SDValue Imm = Ops[i];
1578-
if (MemTy == MVT::f32 || MemTy == MVT::f64) {
1579-
const ConstantFPSDNode *ConstImm = cast<ConstantFPSDNode>(Imm);
1580-
const ConstantFP *CF = ConstImm->getConstantFPValue();
1581-
Imm = CurDAG->getTargetConstantFP(*CF, DL, Imm->getValueType(0));
1582-
} else {
1583-
const ConstantSDNode *ConstImm = cast<ConstantSDNode>(Imm);
1584-
const ConstantInt *CI = ConstImm->getConstantIntValue();
1585-
Imm = CurDAG->getTargetConstant(*CI, DL, Imm->getValueType(0));
1586-
}
1587-
Ops[i] = Imm;
1588-
}
1589-
}
1590-
1591-
// Get opcode for MemTy, size, and register/immediate operand ordering
1592-
switch (MemTy) {
1593-
case MVT::i8:
1594-
return getOpcodeForVectorStParam(NumElts, I8, IsImm);
1595-
case MVT::i16:
1596-
return getOpcodeForVectorStParam(NumElts, I16, IsImm);
1597-
case MVT::i32:
1598-
return getOpcodeForVectorStParam(NumElts, I32, IsImm);
1599-
case MVT::i64:
1600-
assert(NumElts == 2 && "MVT too large for NumElts > 2");
1601-
return getOpcodeForVectorStParamV2(I64, IsImm);
1602-
case MVT::f32:
1603-
return getOpcodeForVectorStParam(NumElts, F32, IsImm);
1604-
case MVT::f64:
1605-
assert(NumElts == 2 && "MVT too large for NumElts > 2");
1606-
return getOpcodeForVectorStParamV2(F64, IsImm);
1607-
1608-
// These cases don't support immediates; just use the all-register version
1609-
// and generate moves.
1610-
case MVT::i1:
1611-
return (NumElts == 2) ? NVPTX::StoreParamV2I8_rr
1612-
: NVPTX::StoreParamV4I8_rrrr;
1613-
case MVT::f16:
1614-
case MVT::bf16:
1615-
return (NumElts == 2) ? NVPTX::StoreParamV2I16_rr
1616-
: NVPTX::StoreParamV4I16_rrrr;
1617-
case MVT::v2f16:
1618-
case MVT::v2bf16:
1619-
case MVT::v2i16:
1620-
case MVT::v4i8:
1621-
return (NumElts == 2) ? NVPTX::StoreParamV2I32_rr
1622-
: NVPTX::StoreParamV4I32_rrrr;
1623-
default:
1624-
llvm_unreachable("Cannot select st.param for unknown MemTy");
1625-
}
1626-
}
1627-
1628-
bool NVPTXDAGToDAGISel::tryStoreParam(SDNode *N) {
1629-
SDLoc DL(N);
1630-
SDValue Chain = N->getOperand(0);
1631-
SDValue Param = N->getOperand(1);
1632-
unsigned ParamVal = Param->getAsZExtVal();
1633-
SDValue Offset = N->getOperand(2);
1634-
unsigned OffsetVal = Offset->getAsZExtVal();
1635-
MemSDNode *Mem = cast<MemSDNode>(N);
1636-
SDValue Glue = N->getOperand(N->getNumOperands() - 1);
1637-
1638-
// How many elements do we have?
1639-
unsigned NumElts;
1640-
switch (N->getOpcode()) {
1641-
default:
1642-
llvm_unreachable("Unexpected opcode");
1643-
case NVPTXISD::StoreParam:
1644-
NumElts = 1;
1645-
break;
1646-
case NVPTXISD::StoreParamV2:
1647-
NumElts = 2;
1648-
break;
1649-
case NVPTXISD::StoreParamV4:
1650-
NumElts = 4;
1651-
break;
1652-
}
1653-
1654-
// Build vector of operands
1655-
SmallVector<SDValue, 8> Ops;
1656-
for (unsigned i = 0; i < NumElts; ++i)
1657-
Ops.push_back(N->getOperand(i + 3));
1658-
Ops.append({CurDAG->getTargetConstant(ParamVal, DL, MVT::i32),
1659-
CurDAG->getTargetConstant(OffsetVal, DL, MVT::i32), Chain, Glue});
1660-
1661-
// Determine target opcode
1662-
// If we have an i1, use an 8-bit store. The lowering code in
1663-
// NVPTXISelLowering will have already emitted an upcast.
1664-
std::optional<unsigned> Opcode;
1665-
switch (NumElts) {
1666-
default:
1667-
llvm_unreachable("Unexpected NumElts");
1668-
case 1: {
1669-
MVT::SimpleValueType MemTy = Mem->getMemoryVT().getSimpleVT().SimpleTy;
1670-
SDValue Imm = Ops[0];
1671-
if (MemTy != MVT::f16 && MemTy != MVT::bf16 &&
1672-
(isa<ConstantSDNode>(Imm) || isa<ConstantFPSDNode>(Imm))) {
1673-
// Convert immediate to target constant
1674-
if (MemTy == MVT::f32 || MemTy == MVT::f64) {
1675-
const ConstantFPSDNode *ConstImm = cast<ConstantFPSDNode>(Imm);
1676-
const ConstantFP *CF = ConstImm->getConstantFPValue();
1677-
Imm = CurDAG->getTargetConstantFP(*CF, DL, Imm->getValueType(0));
1678-
} else {
1679-
const ConstantSDNode *ConstImm = cast<ConstantSDNode>(Imm);
1680-
const ConstantInt *CI = ConstImm->getConstantIntValue();
1681-
Imm = CurDAG->getTargetConstant(*CI, DL, Imm->getValueType(0));
1682-
}
1683-
Ops[0] = Imm;
1684-
// Use immediate version of store param
1685-
Opcode =
1686-
pickOpcodeForVT(MemTy, NVPTX::StoreParamI8_i, NVPTX::StoreParamI16_i,
1687-
NVPTX::StoreParamI32_i, NVPTX::StoreParamI64_i);
1688-
} else
1689-
Opcode = pickOpcodeForVT(Mem->getMemoryVT().getSimpleVT().SimpleTy,
1690-
NVPTX::StoreParamI8_r, NVPTX::StoreParamI16_r,
1691-
NVPTX::StoreParamI32_r, NVPTX::StoreParamI64_r);
1692-
if (Opcode == NVPTX::StoreParamI8_r) {
1693-
// Fine tune the opcode depending on the size of the operand.
1694-
// This helps to avoid creating redundant COPY instructions in
1695-
// InstrEmitter::AddRegisterOperand().
1696-
switch (Ops[0].getSimpleValueType().SimpleTy) {
1697-
default:
1698-
break;
1699-
case MVT::i32:
1700-
Opcode = NVPTX::StoreParamI8TruncI32_r;
1701-
break;
1702-
case MVT::i64:
1703-
Opcode = NVPTX::StoreParamI8TruncI64_r;
1704-
break;
1705-
}
1706-
}
1707-
break;
1708-
}
1709-
case 2:
1710-
case 4: {
1711-
MVT::SimpleValueType MemTy = Mem->getMemoryVT().getSimpleVT().SimpleTy;
1712-
Opcode = pickOpcodeForVectorStParam(Ops, NumElts, MemTy, CurDAG, DL);
1713-
break;
1714-
}
1715-
}
1716-
1717-
SDVTList RetVTs = CurDAG->getVTList(MVT::Other, MVT::Glue);
1718-
SDNode *Ret = CurDAG->getMachineNode(*Opcode, DL, RetVTs, Ops);
1719-
MachineMemOperand *MemRef = cast<MemSDNode>(N)->getMemOperand();
1720-
CurDAG->setNodeMemRefs(cast<MachineSDNode>(Ret), {MemRef});
1721-
1722-
ReplaceNode(N, Ret);
1723-
return true;
1724-
}
1725-
17261453
/// tryBFE - Look for instruction sequences that can be made more efficient
17271454
/// by using the 'bfe' (bit-field extract) PTX instruction
17281455
bool NVPTXDAGToDAGISel::tryBFE(SDNode *N) {

llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.h

Lines changed: 0 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -78,8 +78,6 @@ class LLVM_LIBRARY_VISIBILITY NVPTXDAGToDAGISel : public SelectionDAGISel {
7878
bool tryLDG(MemSDNode *N);
7979
bool tryStore(SDNode *N);
8080
bool tryStoreVector(SDNode *N);
81-
bool tryLoadParam(SDNode *N);
82-
bool tryStoreParam(SDNode *N);
8381
bool tryFence(SDNode *N);
8482
void SelectAddrSpaceCast(SDNode *N);
8583
bool tryBFE(SDNode *N);

0 commit comments

Comments
 (0)