Skip to content

[SelectionDAG] Verify SDTCisVT and SDTCVecEltisVT constraints #150125

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 3 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 15 additions & 3 deletions llvm/include/llvm/CodeGen/SDNodeInfo.h
Original file line number Diff line number Diff line change
Expand Up @@ -48,11 +48,21 @@ enum SDNF {
SDNFIsStrictFP,
};

struct VTByHwModePair {
uint8_t Mode;
MVT::SimpleValueType VT;
};

struct SDTypeConstraint {
SDTC Kind;
uint8_t OpNo;
uint8_t OtherOpNo;
MVT::SimpleValueType VT;
/// For Kind == SDTCisVT or SDTCVecEltisVT:
/// - if not using HwMode, NumHwModes == 0 and VT is MVT::SimpleValueType;
/// - otherwise, VT is offset into VTByHwModeTable and NumHwModes specifies
/// the number of entries.
uint8_t NumHwModes;
uint16_t VT;
};

using SDNodeTSFlags = uint32_t;
Expand All @@ -76,13 +86,15 @@ class SDNodeInfo final {
unsigned NumOpcodes;
const SDNodeDesc *Descs;
StringTable Names;
const VTByHwModePair *VTByHwModeTable;
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Of all targets, only RISCV and LoongArch got a non-empty (two element) table.

const SDTypeConstraint *Constraints;

public:
constexpr SDNodeInfo(unsigned NumOpcodes, const SDNodeDesc *Descs,
StringTable Names, const SDTypeConstraint *Constraints)
StringTable Names, const VTByHwModePair *VTByHwModeTable,
const SDTypeConstraint *Constraints)
: NumOpcodes(NumOpcodes), Descs(Descs), Names(Names),
Constraints(Constraints) {}
VTByHwModeTable(VTByHwModeTable), Constraints(Constraints) {}

/// Returns true if there is a generated description for a node with the given
/// target-specific opcode.
Expand Down
105 changes: 105 additions & 0 deletions llvm/lib/CodeGen/SelectionDAG/SDNodeInfo.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,10 @@
//===----------------------------------------------------------------------===//

#include "llvm/CodeGen/SDNodeInfo.h"
#include "llvm/CodeGen/SelectionDAG.h"
#include "llvm/CodeGen/SelectionDAGNodes.h"
#include "llvm/CodeGen/TargetLowering.h"
#include "llvm/CodeGen/TargetSubtargetInfo.h"

using namespace llvm;

Expand Down Expand Up @@ -40,6 +43,26 @@ static void checkOperandType(const SelectionDAG &DAG, const SDNode *N,
ExpectedVT.getEVTString() + ", got " + ActualVT.getEVTString());
}

namespace {

struct ConstraintOp {
const SDNode *N;
unsigned Idx;
bool IsRes;

SDValue getValue() const {
return IsRes ? SDValue(const_cast<SDNode *>(N), Idx) : N->getOperand(Idx);
}

EVT getValueType() const { return getValue().getValueType(); }
};

raw_ostream &operator<<(raw_ostream &OS, const ConstraintOp &Op) {
return OS << (Op.IsRes ? "result" : "operand") << " #" << Op.Idx;
}

} // namespace

void SDNodeInfo::verifyNode(const SelectionDAG &DAG, const SDNode *N) const {
const SDNodeDesc &Desc = getDesc(N->getOpcode());
bool HasChain = Desc.hasProperty(SDNPHasChain);
Expand Down Expand Up @@ -125,4 +148,86 @@ void SDNodeInfo::verifyNode(const SelectionDAG &DAG, const SDNode *N) const {
" must be Register or RegisterMask");
}
}

unsigned VTHwMode =
DAG.getSubtarget().getHwMode(MCSubtargetInfo::HwMode_ValueType);

auto GetConstraintOp = [&](unsigned Idx) {
if (Idx < Desc.NumResults)
return ConstraintOp{N, Idx, /*IsRes=*/true};
return ConstraintOp{N, HasChain + (Idx - Desc.NumResults), /*IsRes=*/false};
};
Comment on lines +155 to +159
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Move to static member ConstraintOp::get?


auto GetConstraintVT = [&](const SDTypeConstraint &C) {
if (!C.NumHwModes)
return static_cast<MVT::SimpleValueType>(C.VT);
for (auto [Mode, VT] : ArrayRef(&VTByHwModeTable[C.VT], C.NumHwModes))
if (Mode == VTHwMode)
return VT;
llvm_unreachable("No value type for this HW mode");
};

SmallString<128> ES;
raw_svector_ostream SS(ES);
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

  1. Currently, it only diagnoses the first constraint violation. Would it be helpful if it reported all problems?
  2. Would it be helpful if the diagnostics included the constraint as written in *.td files (e.g. "SDTCisVT<2, i32>")?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

1: preferably, yes, but this could come later?

2: I'm not sure the syntax has to be 1:1, but something to reference back to the constraint as written would be good.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks, will do that after implementing more constraint checks.


for (const SDTypeConstraint &C : getConstraints(N->getOpcode())) {
ConstraintOp Op = GetConstraintOp(C.OpNo);
EVT OpVT = Op.getValueType();

switch (C.Kind) {
case SDTCisVT: {
EVT ExpectedVT = GetConstraintVT(C);

bool IsPtr = ExpectedVT == MVT::iPTR;
if (IsPtr)
ExpectedVT =
DAG.getTargetLoweringInfo().getPointerTy(DAG.getDataLayout());

if (OpVT != ExpectedVT) {
SS << Op << " must have type " << ExpectedVT;
if (IsPtr)
SS << " (iPTR)";
SS << ", but has type " << OpVT;
reportNodeError(DAG, N, SS.str());
}
break;
}
case SDTCisPtrTy:
break;
case SDTCisInt:
break;
case SDTCisFP:
break;
case SDTCisVec:
break;
case SDTCisSameAs:
break;
case SDTCisVTSmallerThanOp:
break;
case SDTCisOpSmallerThanOp:
break;
case SDTCisEltOfVec:
break;
case SDTCisSubVecOfVec:
break;
case SDTCVecEltisVT: {
EVT ExpectedVT = GetConstraintVT(C);

if (!OpVT.isVector()) {
SS << Op << " must have vector type";
reportNodeError(DAG, N, SS.str());
}
if (OpVT.getVectorElementType() != ExpectedVT) {
SS << Op << " must have " << ExpectedVT << " element type, but has "
<< OpVT.getVectorElementType() << " element type";
reportNodeError(DAG, N, SS.str());
}
break;
}
case SDTCisSameNumEltsAs:
break;
case SDTCisSameSizeAs:
break;
}
}
}
2 changes: 1 addition & 1 deletion llvm/lib/Target/AArch64/AArch64InstrInfo.td
Original file line number Diff line number Diff line change
Expand Up @@ -1144,7 +1144,7 @@ def AArch64msrr : SDNode<"AArch64ISD::MSRR",
SDTCisVT<2, i64>]>,
[SDNPHasChain]>;

def SD_AArch64rshrnb : SDTypeProfile<1, 2, [SDTCisVec<0>, SDTCisVec<1>, SDTCisInt<2>]>;
def SD_AArch64rshrnb : SDTypeProfile<1, 2, [SDTCisVec<0>, SDTCisVec<1>, SDTCisVT<2, i32>]>;
// Vector narrowing shift by immediate (bottom)
def AArch64rshrnb : SDNode<"AArch64ISD::RSHRNB_I", SD_AArch64rshrnb>;
def AArch64rshrnb_pf : PatFrags<(ops node:$rs, node:$i),
Expand Down
44 changes: 31 additions & 13 deletions llvm/lib/Target/AArch64/AArch64SelectionDAGInfo.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -32,16 +32,41 @@ AArch64SelectionDAGInfo::AArch64SelectionDAGInfo()

void AArch64SelectionDAGInfo::verifyTargetNode(const SelectionDAG &DAG,
const SDNode *N) const {
switch (N->getOpcode()) {
case AArch64ISD::ADC:
case AArch64ISD::SBC:
case AArch64ISD::ADCS:
case AArch64ISD::SBCS:
// operand #2 must have type i32, but has type glue
return;
case AArch64ISD::SUBS:
// result #1 must have type i32, but has type glue
return;
case AArch64ISD::CSEL:
case AArch64ISD::CSINC:
case AArch64ISD::BRCOND:
// operand #3 must have type i32, but has type glue
return;
case AArch64ISD::WrapperLarge:
// operand #0 must have type i32, but has type i64
return;
case AArch64ISD::LDNP:
// result #0 must have type v4i32, but has type v2f64
return;
case AArch64ISD::STNP:
// operand #1 must have type v4i32, but has type v2i64
return;
}

SelectionDAGGenTargetInfo::verifyTargetNode(DAG, N);

#ifndef NDEBUG
// Some additional checks not yet implemented by verifyTargetNode.
switch (N->getOpcode()) {
default:
return SelectionDAGGenTargetInfo::verifyTargetNode(DAG, N);
case AArch64ISD::SADDWT:
case AArch64ISD::SADDWB:
case AArch64ISD::UADDWT:
case AArch64ISD::UADDWB: {
assert(N->getNumValues() == 1 && "Expected one result!");
assert(N->getNumOperands() == 2 && "Expected two operands!");
EVT VT = N->getValueType(0);
EVT Op0VT = N->getOperand(0).getValueType();
EVT Op1VT = N->getOperand(1).getValueType();
Expand All @@ -61,8 +86,6 @@ void AArch64SelectionDAGInfo::verifyTargetNode(const SelectionDAG &DAG,
case AArch64ISD::SUNPKHI:
case AArch64ISD::UUNPKLO:
case AArch64ISD::UUNPKHI: {
assert(N->getNumValues() == 1 && "Expected one result!");
assert(N->getNumOperands() == 1 && "Expected one operand!");
EVT VT = N->getValueType(0);
EVT OpVT = N->getOperand(0).getValueType();
assert(OpVT.isVector() && VT.isVector() && OpVT.isInteger() &&
Expand All @@ -79,8 +102,6 @@ void AArch64SelectionDAGInfo::verifyTargetNode(const SelectionDAG &DAG,
case AArch64ISD::UZP2:
case AArch64ISD::ZIP1:
case AArch64ISD::ZIP2: {
assert(N->getNumValues() == 1 && "Expected one result!");
assert(N->getNumOperands() == 2 && "Expected two operands!");
EVT VT = N->getValueType(0);
EVT Op0VT = N->getOperand(0).getValueType();
EVT Op1VT = N->getOperand(1).getValueType();
Expand All @@ -90,11 +111,8 @@ void AArch64SelectionDAGInfo::verifyTargetNode(const SelectionDAG &DAG,
break;
}
case AArch64ISD::RSHRNB_I: {
assert(N->getNumValues() == 1 && "Expected one result!");
assert(N->getNumOperands() == 2 && "Expected two operands!");
EVT VT = N->getValueType(0);
EVT Op0VT = N->getOperand(0).getValueType();
EVT Op1VT = N->getOperand(1).getValueType();
assert(VT.isVector() && VT.isInteger() &&
"Expected integer vector result type!");
assert(Op0VT.isVector() && Op0VT.isInteger() &&
Expand All @@ -103,8 +121,8 @@ void AArch64SelectionDAGInfo::verifyTargetNode(const SelectionDAG &DAG,
"Expected vectors of equal size!");
assert(VT.getVectorElementCount() == Op0VT.getVectorElementCount() * 2 &&
"Expected input vector with half the lanes of its result!");
assert(Op1VT == MVT::i32 && isa<ConstantSDNode>(N->getOperand(1)) &&
"Expected second operand to be a constant i32!");
assert(isa<ConstantSDNode>(N->getOperand(1)) &&
"Expected second operand to be a constant!");
break;
}
}
Expand Down
15 changes: 15 additions & 0 deletions llvm/lib/Target/M68k/M68kSelectionDAGInfo.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -16,4 +16,19 @@ using namespace llvm;
M68kSelectionDAGInfo::M68kSelectionDAGInfo()
: SelectionDAGGenTargetInfo(M68kGenSDNodeInfo) {}

void M68kSelectionDAGInfo::verifyTargetNode(const SelectionDAG &DAG,
const SDNode *N) const {
switch (N->getOpcode()) {
case M68kISD::ADD:
case M68kISD::SUBX:
// result #1 must have type i8, but has type i32
return;
case M68kISD::SETCC:
// operand #1 must have type i8, but has type i32
return;
}

SelectionDAGGenTargetInfo::verifyTargetNode(DAG, N);
}

M68kSelectionDAGInfo::~M68kSelectionDAGInfo() = default;
3 changes: 3 additions & 0 deletions llvm/lib/Target/M68k/M68kSelectionDAGInfo.h
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,9 @@ class M68kSelectionDAGInfo : public SelectionDAGGenTargetInfo {
M68kSelectionDAGInfo();

~M68kSelectionDAGInfo() override;

void verifyTargetNode(const SelectionDAG &DAG,
const SDNode *N) const override;
};

} // namespace llvm
Expand Down
19 changes: 5 additions & 14 deletions llvm/lib/Target/RISCV/RISCVSelectionDAGInfo.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -20,27 +20,22 @@ RISCVSelectionDAGInfo::~RISCVSelectionDAGInfo() = default;

void RISCVSelectionDAGInfo::verifyTargetNode(const SelectionDAG &DAG,
const SDNode *N) const {
SelectionDAGGenTargetInfo::verifyTargetNode(DAG, N);

#ifndef NDEBUG
// Some additional checks not yet implemented by verifyTargetNode.
switch (N->getOpcode()) {
default:
return SelectionDAGGenTargetInfo::verifyTargetNode(DAG, N);
case RISCVISD::TUPLE_EXTRACT:
assert(N->getNumOperands() == 2 && "Expected three operands!");
assert(N->getOperand(1).getOpcode() == ISD::TargetConstant &&
N->getOperand(1).getValueType() == MVT::i32 &&
"Expected index to be an i32 target constant!");
"Expected index to be a target constant!");
break;
case RISCVISD::TUPLE_INSERT:
assert(N->getNumOperands() == 3 && "Expected three operands!");
assert(N->getOperand(2).getOpcode() == ISD::TargetConstant &&
N->getOperand(2).getValueType() == MVT::i32 &&
"Expected index to be an i32 target constant!");
"Expected index to be a target constant!");
break;
case RISCVISD::VQDOT_VL:
case RISCVISD::VQDOTU_VL:
case RISCVISD::VQDOTSU_VL: {
assert(N->getNumValues() == 1 && "Expected one result!");
assert(N->getNumOperands() == 5 && "Expected five operands!");
EVT VT = N->getValueType(0);
assert(VT.isScalableVector() && VT.getVectorElementType() == MVT::i32 &&
"Expected result to be an i32 scalable vector");
Expand All @@ -50,13 +45,9 @@ void RISCVSelectionDAGInfo::verifyTargetNode(const SelectionDAG &DAG,
"Expected result and first 3 operands to have the same type!");
EVT MaskVT = N->getOperand(3).getValueType();
assert(MaskVT.isScalableVector() &&
MaskVT.getVectorElementType() == MVT::i1 &&
MaskVT.getVectorElementCount() == VT.getVectorElementCount() &&
"Expected mask VT to be an i1 scalable vector with same number of "
"elements as the result");
assert((N->getOperand(4).getValueType() == MVT::i32 ||
N->getOperand(4).getValueType() == MVT::i64) &&
"Expect VL operand to be i32 or i64");
break;
}
}
Expand Down
2 changes: 1 addition & 1 deletion llvm/lib/Target/Sparc/SparcInstrInfo.td
Original file line number Diff line number Diff line change
Expand Up @@ -352,7 +352,7 @@ def callseq_start : SDNode<"ISD::CALLSEQ_START", SDT_SPCallSeqStart,
def callseq_end : SDNode<"ISD::CALLSEQ_END", SDT_SPCallSeqEnd,
[SDNPHasChain, SDNPOptInGlue, SDNPOutGlue]>;

def SDT_SPCall : SDTypeProfile<0, -1, [SDTCisVT<0, i32>]>;
def SDT_SPCall : SDTypeProfile<0, -1, [SDTCisVT<0, iPTR>]>;
def call : SDNode<"SPISD::CALL", SDT_SPCall,
[SDNPHasChain, SDNPOptInGlue, SDNPOutGlue,
SDNPVariadic]>;
Expand Down
40 changes: 22 additions & 18 deletions llvm/test/TableGen/SDNodeInfoEmitter/advanced.td
Original file line number Diff line number Diff line change
Expand Up @@ -65,22 +65,26 @@ def my_node_3 : SDNode<
// CHECK-NEXT: "MyTargetISD::NODE_3\0"
// CHECK-NEXT: ;

// CHECK: static const SDTypeConstraint MyTargetSDTypeConstraints[] = {
// CHECK-NEXT: /* 0 */ {SDTCisVT, 1, 0, MVT::i2},
// CHECK-SAME: {SDTCisVT, 0, 0, MVT::i1},
// CHECK-NEXT: /* 2 */ {SDTCisSameSizeAs, 19, 18, MVT::INVALID_SIMPLE_VALUE_TYPE},
// CHECK-SAME: {SDTCisSameNumEltsAs, 17, 16, MVT::INVALID_SIMPLE_VALUE_TYPE},
// CHECK-SAME: {SDTCVecEltisVT, 15, 0, MVT::i32},
// CHECK-SAME: {SDTCisSubVecOfVec, 14, 13, MVT::INVALID_SIMPLE_VALUE_TYPE},
// CHECK-SAME: {SDTCisEltOfVec, 12, 11, MVT::INVALID_SIMPLE_VALUE_TYPE},
// CHECK-SAME: {SDTCisOpSmallerThanOp, 10, 9, MVT::INVALID_SIMPLE_VALUE_TYPE},
// CHECK-SAME: {SDTCisVTSmallerThanOp, 8, 7, MVT::INVALID_SIMPLE_VALUE_TYPE},
// CHECK-SAME: {SDTCisSameAs, 6, 5, MVT::INVALID_SIMPLE_VALUE_TYPE},
// CHECK-SAME: {SDTCisVec, 4, 0, MVT::INVALID_SIMPLE_VALUE_TYPE},
// CHECK-SAME: {SDTCisFP, 3, 0, MVT::INVALID_SIMPLE_VALUE_TYPE},
// CHECK-SAME: {SDTCisInt, 2, 0, MVT::INVALID_SIMPLE_VALUE_TYPE},
// CHECK-SAME: {SDTCisPtrTy, 1, 0, MVT::INVALID_SIMPLE_VALUE_TYPE},
// CHECK-SAME: {SDTCisVT, 0, 0, MVT::i1},
// CHECK: static const VTByHwModePair MyTargetVTByHwModeTable[] = {
// CHECK-NEXT: /* dummy */ {0, MVT::INVALID_SIMPLE_VALUE_TYPE}
// CHECK-NEXT: };
// CHECK-EMPTY:
// CHECK-NEXT: static const SDTypeConstraint MyTargetSDTypeConstraints[] = {
// CHECK-NEXT: /* 0 */ {SDTCisVT, 1, 0, 0, MVT::i2},
// CHECK-SAME: {SDTCisVT, 0, 0, 0, MVT::i1},
// CHECK-NEXT: /* 2 */ {SDTCisSameSizeAs, 19, 18, 0, MVT::INVALID_SIMPLE_VALUE_TYPE},
// CHECK-SAME: {SDTCisSameNumEltsAs, 17, 16, 0, MVT::INVALID_SIMPLE_VALUE_TYPE},
// CHECK-SAME: {SDTCVecEltisVT, 15, 0, 0, MVT::i32},
// CHECK-SAME: {SDTCisSubVecOfVec, 14, 13, 0, MVT::INVALID_SIMPLE_VALUE_TYPE},
// CHECK-SAME: {SDTCisEltOfVec, 12, 11, 0, MVT::INVALID_SIMPLE_VALUE_TYPE},
// CHECK-SAME: {SDTCisOpSmallerThanOp, 10, 9, 0, MVT::INVALID_SIMPLE_VALUE_TYPE},
// CHECK-SAME: {SDTCisVTSmallerThanOp, 8, 7, 0, MVT::INVALID_SIMPLE_VALUE_TYPE},
// CHECK-SAME: {SDTCisSameAs, 6, 5, 0, MVT::INVALID_SIMPLE_VALUE_TYPE},
// CHECK-SAME: {SDTCisVec, 4, 0, 0, MVT::INVALID_SIMPLE_VALUE_TYPE},
// CHECK-SAME: {SDTCisFP, 3, 0, 0, MVT::INVALID_SIMPLE_VALUE_TYPE},
// CHECK-SAME: {SDTCisInt, 2, 0, 0, MVT::INVALID_SIMPLE_VALUE_TYPE},
// CHECK-SAME: {SDTCisPtrTy, 1, 0, 0, MVT::INVALID_SIMPLE_VALUE_TYPE},
// CHECK-SAME: {SDTCisVT, 0, 0, 0, MVT::i1},
// CHECK-NEXT: };
// CHECK-EMPTY:
// CHECK-NEXT: static const SDNodeDesc MyTargetSDNodeDescs[] = {
Expand All @@ -90,5 +94,5 @@ def my_node_3 : SDNode<
// CHECK-NEXT: };
// CHECK-EMPTY:
// CHECK-NEXT: static const SDNodeInfo MyTargetGenSDNodeInfo(
// CHECK-NEXT: /*NumOpcodes=*/3, MyTargetSDNodeDescs,
// CHECK-NEXT: MyTargetSDNodeNames, MyTargetSDTypeConstraints);
// CHECK-NEXT: /*NumOpcodes=*/3, MyTargetSDNodeDescs, MyTargetSDNodeNames,
// CHECK-NEXT: MyTargetVTByHwModeTable, MyTargetSDTypeConstraints);
Loading