Skip to content

[RFC][LV] Add support for speculative loads in loops that may fault #151300

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft
wants to merge 9 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
57 changes: 57 additions & 0 deletions llvm/docs/LangRef.rst
Original file line number Diff line number Diff line change
Expand Up @@ -24243,6 +24243,63 @@ Examples:
%also.r = call <8 x i8> @llvm.masked.load.v8i8.p0(ptr %ptr, i32 2, <8 x i1> %mask, <8 x i8> poison)


.. _int_vp_ff_load:

'``llvm.vp.ff.load``' Intrinsic
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

Syntax:
"""""""
This is an overloaded intrinsic.

::

declare {<4 x float>, i32} @llvm.vp.load.ff.v4f32.p0(ptr %ptr, <4 x i1> %mask, i32 %evl)
declare {<vscale x 2 x i16>, i32} @llvm.vp.load.ff.nxv2i16.p0(ptr %ptr, <vscale x 2 x i1> %mask, i32 %evl)
declare {<8 x float>, i32} @llvm.vp.load.ff.v8f32.p1(ptr addrspace(1) %ptr, <8 x i1> %mask, i32 %evl)
declare {<vscale x 1 x i64>, i32} @llvm.vp.load.ff.nxv1i64.p6(ptr addrspace(6) %ptr, <vscale x 1 x i1> %mask, i32 %evl)

Overview:
"""""""""

The '``llvm.vp.load.ff.*``' intrinsic is similar to '``llvm.vp.load.*``', but
will not trap if there are not ``evl`` readable elements at the pointer.

Arguments:
""""""""""

The first argument is the base pointer for the load. The second argument is a
vector of boolean values with the same number of elements as the first return
type. The third is the explicit vector length of the operation. The first
return type and underlying type of the base pointer are the same vector types.

The :ref:`align <attr_align>` parameter attribute can be provided for the first
argument.

Semantics:
""""""""""

The '``llvm.vp.load.ff``' intrinsic reads a vector from memory similar to
'``llvm.vp.load``, but will only trap if the first lane is unreadable. If
any other lane is unreadable, the number of successfully read lanes will
be returned in the second return value. The result in the first return value
for the lanes that were not successfully read is
:ref:`poison value <poisonvalues>`. If ``evl`` is 0, no read occurs and thus no
trap can occur for the first lane. If ``mask`` is 0 for the first lane, no
trap occurs. This intrinsic is allowed to read fewer than ``evl`` lanes even
if no trap would occur. If ``evl`` is non-zero, the result in the second result
must be at least 1 even if the first lane is disabled by ``mask``.

The default alignment is taken as the ABI alignment of the first return
type as specified by the :ref:`datalayout string<langref_datalayout>`.

Examples:
"""""""""

.. code-block:: text

%r = call {<8 x i8>, i32} @llvm.vp.load.ff.v8i8.p0(ptr align 2 %ptr, <8 x i1> %mask, i32 %evl)

.. _int_vp_store:

'``llvm.vp.store``' Intrinsic
Expand Down
8 changes: 8 additions & 0 deletions llvm/include/llvm/Analysis/Loads.h
Original file line number Diff line number Diff line change
Expand Up @@ -91,6 +91,14 @@ LLVM_ABI bool isDereferenceableReadOnlyLoop(
Loop *L, ScalarEvolution *SE, DominatorTree *DT, AssumptionCache *AC,
SmallVectorImpl<const SCEVPredicate *> *Predicates = nullptr);

/// Return true if the loop \p L cannot fault on any iteration and only
/// contains read-only memory accesses. Also collect loads that are not
/// guaranteed to be dereferenceable.
LLVM_ABI bool isReadOnlyLoopWithSafeOrSpeculativeLoads(
Loop *L, ScalarEvolution *SE, DominatorTree *DT, AssumptionCache *AC,
SmallVectorImpl<LoadInst *> *SpeculativeLoads,
SmallVectorImpl<const SCEVPredicate *> *Predicates = nullptr);

/// Return true if we know that executing a load from this value cannot trap.
///
/// If DT and ScanFrom are specified this method performs context-sensitive
Expand Down
3 changes: 3 additions & 0 deletions llvm/include/llvm/Analysis/TargetTransformInfo.h
Original file line number Diff line number Diff line change
Expand Up @@ -1857,6 +1857,9 @@ class TargetTransformInfo {
/// \returns True if the target supports scalable vectors.
LLVM_ABI bool supportsScalableVectors() const;

/// \returns True if the target supports speculative loads.
LLVM_ABI bool supportsSpeculativeLoads() const;

/// \return true when scalable vectorization is preferred.
LLVM_ABI bool enableScalableVectorization() const;

Expand Down
2 changes: 2 additions & 0 deletions llvm/include/llvm/Analysis/TargetTransformInfoImpl.h
Original file line number Diff line number Diff line change
Expand Up @@ -1106,6 +1106,8 @@ class TargetTransformInfoImplBase {

virtual bool supportsScalableVectors() const { return false; }

virtual bool supportsSpeculativeLoads() const { return false; }

virtual bool enableScalableVectorization() const { return false; }

virtual bool hasActiveVectorLength() const { return false; }
Expand Down
3 changes: 3 additions & 0 deletions llvm/include/llvm/CodeGen/SelectionDAG.h
Original file line number Diff line number Diff line change
Expand Up @@ -1668,6 +1668,9 @@ class SelectionDAG {
ArrayRef<SDValue> Ops,
MachineMemOperand *MMO,
ISD::MemIndexType IndexType);
LLVM_ABI SDValue getLoadFFVP(EVT VT, const SDLoc &DL, SDValue Chain,
SDValue Ptr, SDValue Mask, SDValue EVL,
MachineMemOperand *MMO);

LLVM_ABI SDValue getGetFPEnv(SDValue Chain, const SDLoc &dl, SDValue Ptr,
EVT MemVT, MachineMemOperand *MMO);
Expand Down
17 changes: 17 additions & 0 deletions llvm/include/llvm/CodeGen/SelectionDAGNodes.h
Original file line number Diff line number Diff line change
Expand Up @@ -3099,6 +3099,23 @@ class MaskedHistogramSDNode : public MaskedGatherScatterSDNode {
}
};

class VPLoadFFSDNode : public MemSDNode {
public:
friend class SelectionDAG;

VPLoadFFSDNode(unsigned Order, const DebugLoc &dl, SDVTList VTs, EVT MemVT,
MachineMemOperand *MMO)
: MemSDNode(ISD::VP_LOAD_FF, Order, dl, VTs, MemVT, MMO) {}

const SDValue &getBasePtr() const { return getOperand(1); }
const SDValue &getMask() const { return getOperand(2); }
const SDValue &getVectorLength() const { return getOperand(3); }

static bool classof(const SDNode *N) {
return N->getOpcode() == ISD::VP_LOAD_FF;
}
};

class FPStateAccessSDNode : public MemSDNode {
public:
friend class SelectionDAG;
Expand Down
8 changes: 8 additions & 0 deletions llvm/include/llvm/IR/Intrinsics.td
Original file line number Diff line number Diff line change
Expand Up @@ -1932,6 +1932,14 @@ def int_vp_load : DefaultAttrsIntrinsic<[ llvm_anyvector_ty],
llvm_i32_ty],
[ NoCapture<ArgIndex<0>>, IntrReadMem, IntrArgMemOnly ]>;

def int_vp_load_ff
: DefaultAttrsIntrinsic<[llvm_anyvector_ty, llvm_i32_ty],
[llvm_anyptr_ty,
LLVMScalarOrSameVectorWidth<0, llvm_i1_ty>,
llvm_i32_ty],
[NoCapture<ArgIndex<0>>, IntrNoSync, IntrReadMem,
IntrWillReturn, IntrArgMemOnly]>;

def int_vp_gather: DefaultAttrsIntrinsic<[ llvm_anyvector_ty],
[ LLVMVectorOfAnyPointersToElt<0>,
LLVMScalarOrSameVectorWidth<0, llvm_i1_ty>,
Expand Down
6 changes: 6 additions & 0 deletions llvm/include/llvm/IR/VPIntrinsics.def
Original file line number Diff line number Diff line change
Expand Up @@ -587,6 +587,12 @@ VP_PROPERTY_FUNCTIONAL_OPC(Load)
VP_PROPERTY_FUNCTIONAL_INTRINSIC(masked_load)
END_REGISTER_VP(vp_load, VP_LOAD)

BEGIN_REGISTER_VP_INTRINSIC(vp_load_ff, 1, 2)
// val,chain = VP_LOAD_FF chain,base,mask,evl
BEGIN_REGISTER_VP_SDNODE(VP_LOAD_FF, -1, vp_load_ff, 2, 3)
HELPER_MAP_VPID_TO_VPSD(vp_load_ff, VP_LOAD_FF)
VP_PROPERTY_NO_FUNCTIONAL
END_REGISTER_VP(vp_load_ff, VP_LOAD_FF)
// llvm.experimental.vp.strided.load(ptr,stride,mask,vlen)
BEGIN_REGISTER_VP_INTRINSIC(experimental_vp_strided_load, 2, 3)
// chain = EXPERIMENTAL_VP_STRIDED_LOAD chain,base,offset,stride,mask,evl
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -453,6 +453,10 @@ class LoopVectorizationLegality {
/// Returns a list of all known histogram operations in the loop.
bool hasHistograms() const { return !Histograms.empty(); }

const SmallPtrSetImpl<const Instruction *> &getSpeculativeLoads() const {
return SpeculativeLoads;
}

PredicatedScalarEvolution *getPredicatedScalarEvolution() const {
return &PSE;
}
Expand Down Expand Up @@ -645,6 +649,9 @@ class LoopVectorizationLegality {
/// may work on the same memory location.
SmallVector<HistogramInfo, 1> Histograms;

/// Hold all loads that need to be speculative.
SmallPtrSet<const Instruction *, 4> SpeculativeLoads;

/// BFI and PSI are used to check for profile guided size optimizations.
BlockFrequencyInfo *BFI;
ProfileSummaryInfo *PSI;
Expand Down
16 changes: 16 additions & 0 deletions llvm/lib/Analysis/Loads.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -862,3 +862,19 @@ bool llvm::isDereferenceableReadOnlyLoop(
}
return true;
}

bool llvm::isReadOnlyLoopWithSafeOrSpeculativeLoads(
Loop *L, ScalarEvolution *SE, DominatorTree *DT, AssumptionCache *AC,
SmallVectorImpl<LoadInst *> *SpeculativeLoads,
SmallVectorImpl<const SCEVPredicate *> *Predicates) {
for (BasicBlock *BB : L->blocks()) {
for (Instruction &I : *BB) {
if (auto *LI = dyn_cast<LoadInst>(&I)) {
if (!isDereferenceableAndAlignedInLoop(LI, L, *SE, *DT, AC, Predicates))
SpeculativeLoads->push_back(LI);
} else if (I.mayReadFromMemory() || I.mayWriteToMemory() || I.mayThrow())
return false;
}
}
return true;
}
4 changes: 4 additions & 0 deletions llvm/lib/Analysis/TargetTransformInfo.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1457,6 +1457,10 @@ bool TargetTransformInfo::supportsScalableVectors() const {
return TTIImpl->supportsScalableVectors();
}

bool TargetTransformInfo::supportsSpeculativeLoads() const {
return TTIImpl->supportsSpeculativeLoads();
}

bool TargetTransformInfo::enableScalableVectorization() const {
return TTIImpl->enableScalableVectorization();
}
Expand Down
2 changes: 2 additions & 0 deletions llvm/lib/CodeGen/SelectionDAG/LegalizeTypes.h
Original file line number Diff line number Diff line change
Expand Up @@ -971,6 +971,7 @@ class LLVM_LIBRARY_VISIBILITY DAGTypeLegalizer {
void SplitVecRes_INSERT_VECTOR_ELT(SDNode *N, SDValue &Lo, SDValue &Hi);
void SplitVecRes_LOAD(LoadSDNode *LD, SDValue &Lo, SDValue &Hi);
void SplitVecRes_VP_LOAD(VPLoadSDNode *LD, SDValue &Lo, SDValue &Hi);
void SplitVecRes_VP_LOAD_FF(VPLoadFFSDNode *LD, SDValue &Lo, SDValue &Hi);
void SplitVecRes_VP_STRIDED_LOAD(VPStridedLoadSDNode *SLD, SDValue &Lo,
SDValue &Hi);
void SplitVecRes_MLOAD(MaskedLoadSDNode *MLD, SDValue &Lo, SDValue &Hi);
Expand Down Expand Up @@ -1075,6 +1076,7 @@ class LLVM_LIBRARY_VISIBILITY DAGTypeLegalizer {
SDValue WidenVecRes_INSERT_VECTOR_ELT(SDNode* N);
SDValue WidenVecRes_LOAD(SDNode* N);
SDValue WidenVecRes_VP_LOAD(VPLoadSDNode *N);
SDValue WidenVecRes_VP_LOAD_FF(VPLoadFFSDNode *N);
SDValue WidenVecRes_VP_STRIDED_LOAD(VPStridedLoadSDNode *N);
SDValue WidenVecRes_VECTOR_COMPRESS(SDNode *N);
SDValue WidenVecRes_MLOAD(MaskedLoadSDNode* N);
Expand Down
74 changes: 74 additions & 0 deletions llvm/lib/CodeGen/SelectionDAG/LegalizeVectorTypes.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1152,6 +1152,9 @@ void DAGTypeLegalizer::SplitVectorResult(SDNode *N, unsigned ResNo) {
case ISD::VP_LOAD:
SplitVecRes_VP_LOAD(cast<VPLoadSDNode>(N), Lo, Hi);
break;
case ISD::VP_LOAD_FF:
SplitVecRes_VP_LOAD_FF(cast<VPLoadFFSDNode>(N), Lo, Hi);
break;
case ISD::EXPERIMENTAL_VP_STRIDED_LOAD:
SplitVecRes_VP_STRIDED_LOAD(cast<VPStridedLoadSDNode>(N), Lo, Hi);
break;
Expand Down Expand Up @@ -2227,6 +2230,51 @@ void DAGTypeLegalizer::SplitVecRes_VP_LOAD(VPLoadSDNode *LD, SDValue &Lo,
ReplaceValueWith(SDValue(LD, 1), Ch);
}

void DAGTypeLegalizer::SplitVecRes_VP_LOAD_FF(VPLoadFFSDNode *LD, SDValue &Lo,
SDValue &Hi) {
EVT LoVT, HiVT;
SDLoc dl(LD);
std::tie(LoVT, HiVT) = DAG.GetSplitDestVTs(LD->getValueType(0));

SDValue Ch = LD->getChain();
SDValue Ptr = LD->getBasePtr();
Align Alignment = LD->getBaseAlign();
SDValue Mask = LD->getMask();
SDValue EVL = LD->getVectorLength();
EVT MemoryVT = LD->getMemoryVT();

bool HiIsEmpty = false;
auto [LoMemVT, HiMemVT] =
DAG.GetDependentSplitDestVTs(MemoryVT, LoVT, &HiIsEmpty);

// Split Mask operand
SDValue MaskLo, MaskHi;
if (Mask.getOpcode() == ISD::SETCC) {
SplitVecRes_SETCC(Mask.getNode(), MaskLo, MaskHi);
} else {
if (getTypeAction(Mask.getValueType()) == TargetLowering::TypeSplitVector)
GetSplitVector(Mask, MaskLo, MaskHi);
else
std::tie(MaskLo, MaskHi) = DAG.SplitVector(Mask, dl);
}

// Split EVL operand
auto [EVLLo, EVLHi] = DAG.SplitEVL(EVL, LD->getValueType(0), dl);

MachineMemOperand *MMO = DAG.getMachineFunction().getMachineMemOperand(
LD->getPointerInfo(), MachineMemOperand::MOLoad,
LocationSize::beforeOrAfterPointer(), Alignment, LD->getAAInfo(),
LD->getRanges());

Lo = DAG.getLoadFFVP(LoVT, dl, Ch, Ptr, MaskLo, EVLLo, MMO);

// Fill the upper half with poison.
Hi = DAG.getUNDEF(HiVT);

ReplaceValueWith(SDValue(LD, 1), Lo.getValue(1));
ReplaceValueWith(SDValue(LD, 2), Lo.getValue(2));
}

void DAGTypeLegalizer::SplitVecRes_VP_STRIDED_LOAD(VPStridedLoadSDNode *SLD,
SDValue &Lo, SDValue &Hi) {
assert(SLD->isUnindexed() &&
Expand Down Expand Up @@ -4707,6 +4755,9 @@ void DAGTypeLegalizer::WidenVectorResult(SDNode *N, unsigned ResNo) {
case ISD::VP_LOAD:
Res = WidenVecRes_VP_LOAD(cast<VPLoadSDNode>(N));
break;
case ISD::VP_LOAD_FF:
Res = WidenVecRes_VP_LOAD_FF(cast<VPLoadFFSDNode>(N));
break;
case ISD::EXPERIMENTAL_VP_STRIDED_LOAD:
Res = WidenVecRes_VP_STRIDED_LOAD(cast<VPStridedLoadSDNode>(N));
break;
Expand Down Expand Up @@ -6163,6 +6214,29 @@ SDValue DAGTypeLegalizer::WidenVecRes_VP_LOAD(VPLoadSDNode *N) {
return Res;
}

SDValue DAGTypeLegalizer::WidenVecRes_VP_LOAD_FF(VPLoadFFSDNode *N) {
EVT WidenVT = TLI.getTypeToTransformTo(*DAG.getContext(), N->getValueType(0));
SDValue Mask = N->getMask();
SDValue EVL = N->getVectorLength();
SDLoc dl(N);

// The mask should be widened as well
assert(getTypeAction(Mask.getValueType()) ==
TargetLowering::TypeWidenVector &&
"Unable to widen binary VP op");
Mask = GetWidenedVector(Mask);
assert(Mask.getValueType().getVectorElementCount() ==
TLI.getTypeToTransformTo(*DAG.getContext(), Mask.getValueType())
.getVectorElementCount() &&
"Unable to widen vector load");

SDValue Res = DAG.getLoadFFVP(WidenVT, dl, N->getChain(), N->getBasePtr(),
Mask, EVL, N->getMemOperand());
ReplaceValueWith(SDValue(N, 1), Res.getValue(1));
ReplaceValueWith(SDValue(N, 2), Res.getValue(2));
return Res;
}

SDValue DAGTypeLegalizer::WidenVecRes_VP_STRIDED_LOAD(VPStridedLoadSDNode *N) {
SDLoc DL(N);

Expand Down
36 changes: 36 additions & 0 deletions llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -837,6 +837,14 @@ static void AddNodeIDCustom(FoldingSetNodeID &ID, const SDNode *N) {
ID.AddInteger(ELD->getMemOperand()->getFlags());
break;
}
case ISD::VP_LOAD_FF: {
const auto *LD = cast<VPLoadFFSDNode>(N);
ID.AddInteger(LD->getMemoryVT().getRawBits());
ID.AddInteger(LD->getRawSubclassData());
ID.AddInteger(LD->getPointerInfo().getAddrSpace());
ID.AddInteger(LD->getMemOperand()->getFlags());
break;
}
case ISD::VP_STORE: {
const VPStoreSDNode *EST = cast<VPStoreSDNode>(N);
ID.AddInteger(EST->getMemoryVT().getRawBits());
Expand Down Expand Up @@ -10393,6 +10401,34 @@ SDValue SelectionDAG::getMaskedHistogram(SDVTList VTs, EVT MemVT,
return V;
}

SDValue SelectionDAG::getLoadFFVP(EVT VT, const SDLoc &dl, SDValue Chain,
SDValue Ptr, SDValue Mask, SDValue EVL,
MachineMemOperand *MMO) {
SDVTList VTs = getVTList(VT, EVL.getValueType(), MVT::Other);
SDValue Ops[] = {Chain, Ptr, Mask, EVL};
FoldingSetNodeID ID;
AddNodeIDNode(ID, ISD::VP_LOAD_FF, VTs, Ops);
ID.AddInteger(VT.getRawBits());
ID.AddInteger(getSyntheticNodeSubclassData<VPLoadFFSDNode>(dl.getIROrder(),
VTs, VT, MMO));
ID.AddInteger(MMO->getPointerInfo().getAddrSpace());
ID.AddInteger(MMO->getFlags());
void *IP = nullptr;
if (SDNode *E = FindNodeOrInsertPos(ID, dl, IP)) {
cast<VPLoadFFSDNode>(E)->refineAlignment(MMO);
return SDValue(E, 0);
}
auto *N = newSDNode<VPLoadFFSDNode>(dl.getIROrder(), dl.getDebugLoc(), VTs,
VT, MMO);
createOperands(N, Ops);

CSEMap.InsertNode(N, IP);
InsertNode(N);
SDValue V(N, 0);
NewSDValueDbgMsg(V, "Creating new node: ", this);
return V;
}

SDValue SelectionDAG::getGetFPEnv(SDValue Chain, const SDLoc &dl, SDValue Ptr,
EVT MemVT, MachineMemOperand *MMO) {
assert(Chain.getValueType() == MVT::Other && "Invalid chain type");
Expand Down
Loading