Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions llvm/include/llvm/CodeGen/ValueTypes.td
Original file line number Diff line number Diff line change
Expand Up @@ -367,6 +367,9 @@ def aarch64mfp8 : ValueType<8, 253>; // 8-bit value in FPR (AArch64)
def c64 : VTCheriCapability<64, 254>; // 64-bit CHERI capability value
def c128 : VTCheriCapability<128, 255>; // 128-bit CHERI capability value

// Pseudo valuetype mapped to the current CHERI capability pointer size.
def cPTR : VTAny<503>;

let isNormalValueType = false in {
def token : ValueType<0, 504>; // TokenTy
def MetadataVT : ValueType<0, 505> { // Metadata
Expand Down
6 changes: 6 additions & 0 deletions llvm/include/llvm/CodeGenTypes/MachineValueType.h
Original file line number Diff line number Diff line change
Expand Up @@ -582,6 +582,12 @@ namespace llvm {
MVT::LAST_FP_SCALABLE_VECTOR_VALUETYPE,
force_iteration_on_noniterable_enum);
}

static auto cheri_capability_valuetypes() {
return enum_seq_inclusive(MVT::FIRST_CHERI_CAPABILITY_VALUETYPE,
MVT::LAST_CHERI_CAPABILITY_VALUETYPE,
force_iteration_on_noniterable_enum);
}
/// @}
};

Expand Down
46 changes: 43 additions & 3 deletions llvm/utils/TableGen/Common/CodeGenDAGPatterns.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -335,6 +335,8 @@ bool TypeSetByHwMode::intersect(SetType &Out, const SetType &In) {
using WildPartT = std::pair<MVT, std::function<bool(MVT)>>;
static const WildPartT WildParts[] = {
{MVT::iPTR, [](MVT T) { return T.isScalarInteger() || T == MVT::iPTR; }},
{MVT::cPTR,
[](MVT T) { return T.isCheriCapability() || T == MVT::cPTR; }},
};

bool Changed = false;
Expand Down Expand Up @@ -816,6 +818,11 @@ void TypeInfer::expandOverloads(TypeSetByHwMode::SetType &Out,
if (Out.count(MVT::pAny)) {
Out.erase(MVT::pAny);
Out.insert(MVT::iPTR);
for (MVT T : MVT::cheri_capability_valuetypes()) {
if (Legal.count(T)) {
Out.insert(MVT::cPTR);
}
}
} else if (Out.count(MVT::iAny)) {
Out.erase(MVT::iAny);
for (MVT T : MVT::integer_valuetypes())
Expand Down Expand Up @@ -1647,9 +1654,11 @@ bool SDTypeConstraint::ApplyTypeConstraint(TreePatternNode &N,
case SDTCisVT:
// Operand must be a particular type.
return NodeToApply.UpdateNodeType(ResNo, VVT, TP);
case SDTCisPtrTy:
// Operand must be same as target pointer type.
return NodeToApply.UpdateNodeType(ResNo, MVT::iPTR, TP);
case SDTCisPtrTy: {
// Operand must be a legal pointer (iPTR, or possibly cPTR) type.
const auto &PtrTys = TP.getDAGPatterns().getLegalPtrTypes();
return NodeToApply.UpdateNodeType(ResNo, PtrTys, TP);
}
case SDTCisInt:
// Require it to be one of the legal integer VTs.
return TI.EnforceInteger(NodeToApply.getExtType(ResNo));
Expand Down Expand Up @@ -3260,6 +3269,7 @@ CodeGenDAGPatterns::CodeGenDAGPatterns(const RecordKeeper &R,
PatternRewriterFn PatternRewriter)
: Records(R), Target(R), Intrinsics(R),
LegalVTS(Target.getLegalValueTypes()),
LegalPtrVTS(ComputeLegalPtrTypes()),
PatternRewriter(std::move(PatternRewriter)) {
ParseNodeInfo();
ParseNodeTransforms();
Expand Down Expand Up @@ -3295,6 +3305,36 @@ const Record *CodeGenDAGPatterns::getSDNodeNamed(StringRef Name) const {
return N;
}

// Compute the subset of iPTR and cPTR legal for each mode, coalescing into the
// default mode where possible to avoid predicate explosion.
TypeSetByHwMode CodeGenDAGPatterns::ComputeLegalPtrTypes() const {
auto LegalPtrsForSet = [](const MachineValueTypeSet &In) {
MachineValueTypeSet Out;
Out.insert(MVT::iPTR);
for (MVT T : MVT::cheri_capability_valuetypes()) {
if (In.count(T)) {
Out.insert(MVT::cPTR);
break;
}
}
return Out;
};

const auto &LegalTypes = getLegalTypes();
MachineValueTypeSet LegalPtrsDefault =
LegalPtrsForSet(LegalTypes.get(DefaultMode));

TypeSetByHwMode LegalPtrTypes;
for (const auto &I : LegalTypes) {
MachineValueTypeSet S = LegalPtrsForSet(I.second);
if (I.first != DefaultMode && S == LegalPtrsDefault)
continue;
LegalPtrTypes.getOrCreate(I.first).insert(S);
}

return LegalPtrTypes;
}

// Parse all of the SDNode definitions for the target, populating SDNodes.
void CodeGenDAGPatterns::ParseNodeInfo() {
const CodeGenHwModes &CGH = getTargetInfo().getHwModes();
Expand Down
3 changes: 3 additions & 0 deletions llvm/utils/TableGen/Common/CodeGenDAGPatterns.h
Original file line number Diff line number Diff line change
Expand Up @@ -1135,6 +1135,7 @@ class CodeGenDAGPatterns {
std::vector<PatternToMatch> PatternsToMatch;

TypeSetByHwMode LegalVTS;
TypeSetByHwMode LegalPtrVTS;

using PatternRewriterFn = std::function<void(TreePattern *)>;
PatternRewriterFn PatternRewriter;
Expand All @@ -1148,6 +1149,7 @@ class CodeGenDAGPatterns {
CodeGenTarget &getTargetInfo() { return Target; }
const CodeGenTarget &getTargetInfo() const { return Target; }
const TypeSetByHwMode &getLegalTypes() const { return LegalVTS; }
const TypeSetByHwMode &getLegalPtrTypes() const { return LegalPtrVTS; }

const Record *getSDNodeNamed(StringRef Name) const;

Expand Down Expand Up @@ -1249,6 +1251,7 @@ class CodeGenDAGPatterns {
}

private:
TypeSetByHwMode ComputeLegalPtrTypes() const;
void ParseNodeInfo();
void ParseNodeTransforms();
void ParseComplexPatterns();
Expand Down
16 changes: 15 additions & 1 deletion llvm/utils/TableGen/Common/DAGISelMatcher.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -328,6 +328,14 @@ static bool TypesAreContradictory(MVT::SimpleValueType T1,
if (T1 == T2)
return false;

if (T1 == MVT::pAny)
return TypesAreContradictory(MVT::iPTR, T2) &&
TypesAreContradictory(MVT::cPTR, T2);

if (T2 == MVT::pAny)
return TypesAreContradictory(T1, MVT::iPTR) &&
TypesAreContradictory(T1, MVT::cPTR);

// If either type is about iPtr, then they don't conflict unless the other
// one is not a scalar integer type.
if (T1 == MVT::iPTR)
Expand All @@ -336,7 +344,13 @@ static bool TypesAreContradictory(MVT::SimpleValueType T1,
if (T2 == MVT::iPTR)
return !MVT(T1).isInteger() || MVT(T1).isVector();

// Otherwise, they are two different non-iPTR types, they conflict.
if (T1 == MVT::cPTR)
return !MVT(T2).isCheriCapability() || MVT(T2).isVector();

if (T2 == MVT::cPTR)
return !MVT(T1).isCheriCapability() || MVT(T1).isVector();

// Otherwise, they are two different non-iPTR/cPTR types, they conflict.
return true;
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1426,7 +1426,9 @@ Error OperandMatcher::addTypeCheckPredicate(const TypeSetByHwMode &VTy,
if (!VTy.isMachineValueType())
return failUnsupported("unsupported typeset");

if (VTy.getMachineValueType() == MVT::iPTR && OperandIsAPointer) {
if ((VTy.getMachineValueType() == MVT::iPTR ||
VTy.getMachineValueType() == MVT::cPTR) &&
OperandIsAPointer) {
addPredicate<PointerToAnyOperandMatcher>(0);
return Error::success();
}
Expand Down
6 changes: 3 additions & 3 deletions llvm/utils/TableGen/DAGISelMatcherOpt.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -519,9 +519,9 @@ static void FactorScope(std::unique_ptr<Matcher> &MatcherPtr) {
CheckTypeMatcher *CTM = cast_or_null<CheckTypeMatcher>(
FindNodeWithKind(Optn, Matcher::CheckType));
if (!CTM ||
// iPTR checks could alias any other case without us knowing, don't
// bother with them.
CTM->getType() == MVT::iPTR ||
// iPTR/cPTR checks could alias any other case without us knowing,
// don't bother with them.
CTM->getType() == MVT::iPTR || CTM->getType() == MVT::cPTR ||
// SwitchType only works for result #0.
CTM->getResNo() != 0 ||
// If the CheckType isn't at the start of the list, see if we can move
Expand Down