Commit f129a3c

[SelectionDAG][X86] Split <2 x T> vector types for atomic load
Vector types of 2 elements that aren't widened are split so that they can be vectorized within SelectionDAG. This change uses the load vectorization infrastructure to regroup the split elements, which enables SelectionDAG to translate vectors with the bfloat and half element types. commit-id:3a045357
1 parent e07e225 commit f129a3c
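
For illustration, the simplest case this change enables is the <2 x half> atomic load added to llvm/test/CodeGen/X86/atomic-load-store.ll: the two half elements are split during type legalization, loaded with a single 32-bit atomic load, and regrouped into a vector (on x86-64 this lowers to a single movss, as the new CHECK lines verify).

; Reproduced from the new test: a 2-element half vector loaded atomically.
define <2 x half> @atomic_vec2_half(ptr %x) {
  %ret = load atomic <2 x half>, ptr %x acquire, align 4
  ret <2 x half> %ret
}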

File tree

8 files changed: 147 additions, 35 deletions

  llvm/include/llvm/CodeGen/SelectionDAG.h
  llvm/lib/CodeGen/SelectionDAG/LegalizeTypes.h
  llvm/lib/CodeGen/SelectionDAG/LegalizeVectorTypes.cpp
  llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp
  llvm/lib/CodeGen/SelectionDAG/SelectionDAGAddressAnalysis.cpp
  llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp
  llvm/lib/Target/X86/X86ISelLowering.cpp
  llvm/test/CodeGen/X86/atomic-load-store.ll

llvm/include/llvm/CodeGen/SelectionDAG.h

Lines changed: 3 additions & 1 deletion
@@ -1837,7 +1837,7 @@ class SelectionDAG {
   /// chain to the token factor. This ensures that the new memory node will have
   /// the same relative memory dependency position as the old load. Returns the
   /// new merged load chain.
-  SDValue makeEquivalentMemoryOrdering(LoadSDNode *OldLoad, SDValue NewMemOp);
+  SDValue makeEquivalentMemoryOrdering(MemSDNode *OldLoad, SDValue NewMemOp);
 
   /// Topological-sort the AllNodes list and a
   /// assign a unique node id for each node in the DAG based on their
@@ -2263,6 +2263,8 @@ class SelectionDAG {
   /// location that the 'Base' load is loading from.
   bool areNonVolatileConsecutiveLoads(LoadSDNode *LD, LoadSDNode *Base,
                                       unsigned Bytes, int Dist) const;
+  bool areNonVolatileConsecutiveLoads(AtomicSDNode *LD, AtomicSDNode *Base,
+                                      unsigned Bytes, int Dist) const;
 
   /// Infer alignment of a load / store address. Return std::nullopt if it
   /// cannot be inferred.

llvm/lib/CodeGen/SelectionDAG/LegalizeTypes.h

Lines changed: 1 addition & 0 deletions
@@ -946,6 +946,7 @@ class LLVM_LIBRARY_VISIBILITY DAGTypeLegalizer {
   void SplitVecRes_FPOp_MultiType(SDNode *N, SDValue &Lo, SDValue &Hi);
   void SplitVecRes_IS_FPCLASS(SDNode *N, SDValue &Lo, SDValue &Hi);
   void SplitVecRes_INSERT_VECTOR_ELT(SDNode *N, SDValue &Lo, SDValue &Hi);
+  void SplitVecRes_ATOMIC_LOAD(AtomicSDNode *LD);
   void SplitVecRes_LOAD(LoadSDNode *LD, SDValue &Lo, SDValue &Hi);
   void SplitVecRes_VP_LOAD(VPLoadSDNode *LD, SDValue &Lo, SDValue &Hi);
   void SplitVecRes_VP_STRIDED_LOAD(VPStridedLoadSDNode *SLD, SDValue &Lo,

llvm/lib/CodeGen/SelectionDAG/LegalizeVectorTypes.cpp

Lines changed: 31 additions & 0 deletions
@@ -1152,6 +1152,9 @@ void DAGTypeLegalizer::SplitVectorResult(SDNode *N, unsigned ResNo) {
     SplitVecRes_STEP_VECTOR(N, Lo, Hi);
     break;
   case ISD::SIGN_EXTEND_INREG: SplitVecRes_InregOp(N, Lo, Hi); break;
+  case ISD::ATOMIC_LOAD:
+    SplitVecRes_ATOMIC_LOAD(cast<AtomicSDNode>(N));
+    break;
   case ISD::LOAD:
     SplitVecRes_LOAD(cast<LoadSDNode>(N), Lo, Hi);
    break;
@@ -1395,6 +1398,34 @@ void DAGTypeLegalizer::SplitVectorResult(SDNode *N, unsigned ResNo) {
   SetSplitVector(SDValue(N, ResNo), Lo, Hi);
 }
 
+void DAGTypeLegalizer::SplitVecRes_ATOMIC_LOAD(AtomicSDNode *LD) {
+  SDLoc dl(LD);
+
+  EVT MemoryVT = LD->getMemoryVT();
+  unsigned NumElts = MemoryVT.getVectorMinNumElements();
+
+  EVT IntMemoryVT = EVT::getVectorVT(*DAG.getContext(), MVT::i16, NumElts);
+  EVT ElemVT = EVT::getVectorVT(*DAG.getContext(),
+                                MemoryVT.getVectorElementType(), 1);
+
+  // Create a single atomic to load all the elements at once.
+  SDValue Atomic = DAG.getAtomic(ISD::ATOMIC_LOAD, dl, IntMemoryVT, IntMemoryVT,
+                                 LD->getChain(), LD->getBasePtr(),
+                                 LD->getMemOperand());
+
+  // Instead of splitting, put all the elements back into a vector.
+  SmallVector<SDValue, 4> Ops;
+  for (unsigned i = 0; i < NumElts; ++i) {
+    SDValue Elt = DAG.getNode(ISD::EXTRACT_VECTOR_ELT, dl, MVT::i16, Atomic,
+                              DAG.getVectorIdxConstant(i, dl));
+    Elt = DAG.getBitcast(ElemVT, Elt);
+    Ops.push_back(Elt);
+  }
+  SDValue Concat = DAG.getNode(ISD::CONCAT_VECTORS, dl, MemoryVT, Ops);
+
+  ReplaceValueWith(SDValue(LD, 0), Concat);
+}
+
 void DAGTypeLegalizer::IncrementPointer(MemSDNode *N, EVT MemVT,
                                         MachinePointerInfo &MPI, SDValue &Ptr,
                                         uint64_t *ScaledOffset) {

llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp

Lines changed: 22 additions & 2 deletions
@@ -12161,7 +12161,7 @@ SDValue SelectionDAG::makeEquivalentMemoryOrdering(SDValue OldChain,
   return TokenFactor;
 }
 
-SDValue SelectionDAG::makeEquivalentMemoryOrdering(LoadSDNode *OldLoad,
+SDValue SelectionDAG::makeEquivalentMemoryOrdering(MemSDNode *OldLoad,
                                                    SDValue NewMemOp) {
   assert(isa<MemSDNode>(NewMemOp.getNode()) && "Expected a memop node");
   SDValue OldChain = SDValue(OldLoad, 1);
@@ -12873,13 +12873,33 @@ std::pair<SDValue, SDValue> SelectionDAG::UnrollVectorOverflowOp(
                   getBuildVector(NewOvVT, dl, OvScalars));
 }
 
+bool SelectionDAG::areNonVolatileConsecutiveLoads(AtomicSDNode *LD,
+                                                  AtomicSDNode *Base,
+                                                  unsigned Bytes,
+                                                  int Dist) const {
+  if (LD->isVolatile() || Base->isVolatile())
+    return false;
+  if (LD->getChain() != Base->getChain())
+    return false;
+  EVT VT = LD->getMemoryVT();
+  if (VT.getSizeInBits() / 8 != Bytes)
+    return false;
+
+  auto BaseLocDecomp = BaseIndexOffset::match(Base, *this);
+  auto LocDecomp = BaseIndexOffset::match(LD, *this);
+
+  int64_t Offset = 0;
+  if (BaseLocDecomp.equalBaseIndex(LocDecomp, *this, Offset))
+    return (Dist * (int64_t)Bytes == Offset);
+  return false;
+}
+
 bool SelectionDAG::areNonVolatileConsecutiveLoads(LoadSDNode *LD,
                                                   LoadSDNode *Base,
                                                   unsigned Bytes,
                                                   int Dist) const {
   if (LD->isVolatile() || Base->isVolatile())
     return false;
-  // TODO: probably too restrictive for atomics, revisit
   if (!LD->isSimple())
     return false;
   if (LD->isIndexed() || Base->isIndexed())

llvm/lib/CodeGen/SelectionDAG/SelectionDAGAddressAnalysis.cpp

Lines changed: 17 additions & 13 deletions
@@ -194,8 +194,8 @@ bool BaseIndexOffset::contains(const SelectionDAG &DAG, int64_t BitSize,
   return false;
 }
 
-/// Parses tree in Ptr for base, index, offset addresses.
-static BaseIndexOffset matchLSNode(const LSBaseSDNode *N,
+template <typename T>
+static BaseIndexOffset matchSDNode(const T *N,
                                    const SelectionDAG &DAG) {
   SDValue Ptr = N->getBasePtr();
 
@@ -206,16 +206,18 @@ static BaseIndexOffset matchLSNode(const LSBaseSDNode *N,
   bool IsIndexSignExt = false;
 
   // pre-inc/pre-dec ops are components of EA.
-  if (N->getAddressingMode() == ISD::PRE_INC) {
-    if (auto *C = dyn_cast<ConstantSDNode>(N->getOffset()))
-      Offset += C->getSExtValue();
-    else // If unknown, give up now.
-      return BaseIndexOffset(SDValue(), SDValue(), 0, false);
-  } else if (N->getAddressingMode() == ISD::PRE_DEC) {
-    if (auto *C = dyn_cast<ConstantSDNode>(N->getOffset()))
-      Offset -= C->getSExtValue();
-    else // If unknown, give up now.
-      return BaseIndexOffset(SDValue(), SDValue(), 0, false);
+  if constexpr (std::is_same_v<T, LSBaseSDNode>) {
+    if (N->getAddressingMode() == ISD::PRE_INC) {
+      if (auto *C = dyn_cast<ConstantSDNode>(N->getOffset()))
+        Offset += C->getSExtValue();
+      else // If unknown, give up now.
+        return BaseIndexOffset(SDValue(), SDValue(), 0, false);
+    } else if (N->getAddressingMode() == ISD::PRE_DEC) {
+      if (auto *C = dyn_cast<ConstantSDNode>(N->getOffset()))
+        Offset -= C->getSExtValue();
+      else // If unknown, give up now.
+        return BaseIndexOffset(SDValue(), SDValue(), 0, false);
+    }
   }
 
   // Consume constant adds & ors with appropriate masking.
@@ -300,8 +302,10 @@ static BaseIndexOffset matchLSNode(const LSBaseSDNode *N,
 
 BaseIndexOffset BaseIndexOffset::match(const SDNode *N,
                                        const SelectionDAG &DAG) {
+  if (const auto *AN = dyn_cast<AtomicSDNode>(N))
+    return matchSDNode(AN, DAG);
   if (const auto *LS0 = dyn_cast<LSBaseSDNode>(N))
-    return matchLSNode(LS0, DAG);
+    return matchSDNode(LS0, DAG);
   if (const auto *LN = dyn_cast<LifetimeSDNode>(N)) {
     if (LN->hasOffset())
       return BaseIndexOffset(LN->getOperand(1), SDValue(), LN->getOffset(),

llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp

Lines changed: 5 additions & 1 deletion
@@ -5218,7 +5218,11 @@ void SelectionDAGBuilder::visitAtomicLoad(const LoadInst &I) {
   L = DAG.getPtrExtOrTrunc(L, dl, VT);
 
   setValue(&I, L);
-  DAG.setRoot(OutChain);
+
+  if (VT.isVector())
+    DAG.setRoot(InChain);
+  else
+    DAG.setRoot(OutChain);
 }
 
 void SelectionDAGBuilder::visitAtomicStore(const StoreInst &I) {

llvm/lib/Target/X86/X86ISelLowering.cpp

Lines changed: 32 additions & 18 deletions
@@ -7061,14 +7061,23 @@ static SDValue LowerAsSplatVectorLoad(SDValue SrcOp, MVT VT, const SDLoc &dl,
 }
 
 // Recurse to find a LoadSDNode source and the accumulated ByteOffest.
-static bool findEltLoadSrc(SDValue Elt, LoadSDNode *&Ld, int64_t &ByteOffset) {
-  if (ISD::isNON_EXTLoad(Elt.getNode())) {
-    auto *BaseLd = cast<LoadSDNode>(Elt);
-    if (!BaseLd->isSimple())
-      return false;
-    Ld = BaseLd;
-    ByteOffset = 0;
-    return true;
+template <typename T>
+static bool findEltLoadSrc(SDValue Elt, T *&Ld, int64_t &ByteOffset) {
+  if constexpr (std::is_same_v<T, AtomicSDNode>) {
+    if (auto *BaseLd = dyn_cast<AtomicSDNode>(Elt)) {
+      Ld = BaseLd;
+      ByteOffset = 0;
+      return true;
+    }
+  } else if constexpr (std::is_same_v<T, LoadSDNode>) {
+    if (ISD::isNON_EXTLoad(Elt.getNode())) {
+      auto *BaseLd = cast<LoadSDNode>(Elt);
+      if (!BaseLd->isSimple())
+        return false;
+      Ld = BaseLd;
+      ByteOffset = 0;
+      return true;
+    }
   }
 
   switch (Elt.getOpcode()) {
@@ -7108,6 +7117,7 @@ static bool findEltLoadSrc(SDValue Elt, LoadSDNode *&Ld, int64_t &ByteOffset) {
 /// a build_vector or insert_subvector whose loaded operands are 'Elts'.
 ///
 /// Example: <load i32 *a, load i32 *a+4, zero, undef> -> zextload a
+template <typename T>
 static SDValue EltsFromConsecutiveLoads(EVT VT, ArrayRef<SDValue> Elts,
                                         const SDLoc &DL, SelectionDAG &DAG,
                                         const X86Subtarget &Subtarget,
@@ -7122,7 +7132,7 @@ static SDValue EltsFromConsecutiveLoads(EVT VT, ArrayRef<SDValue> Elts,
   APInt ZeroMask = APInt::getZero(NumElems);
   APInt UndefMask = APInt::getZero(NumElems);
 
-  SmallVector<LoadSDNode*, 8> Loads(NumElems, nullptr);
+  SmallVector<T*, 8> Loads(NumElems, nullptr);
   SmallVector<int64_t, 8> ByteOffsets(NumElems, 0);
 
   // For each element in the initializer, see if we've found a load, zero or an
@@ -7172,7 +7182,7 @@ static SDValue EltsFromConsecutiveLoads(EVT VT, ArrayRef<SDValue> Elts,
   EVT EltBaseVT = EltBase.getValueType();
   assert(EltBaseVT.getSizeInBits() == EltBaseVT.getStoreSizeInBits() &&
          "Register/Memory size mismatch");
-  LoadSDNode *LDBase = Loads[FirstLoadedElt];
+  T *LDBase = Loads[FirstLoadedElt];
   assert(LDBase && "Did not find base load for merging consecutive loads");
   unsigned BaseSizeInBits = EltBaseVT.getStoreSizeInBits();
   unsigned BaseSizeInBytes = BaseSizeInBits / 8;
@@ -7186,8 +7196,8 @@ static SDValue EltsFromConsecutiveLoads(EVT VT, ArrayRef<SDValue> Elts,
 
   // Check to see if the element's load is consecutive to the base load
   // or offset from a previous (already checked) load.
-  auto CheckConsecutiveLoad = [&](LoadSDNode *Base, int EltIdx) {
-    LoadSDNode *Ld = Loads[EltIdx];
+  auto CheckConsecutiveLoad = [&](T *Base, int EltIdx) {
+    T *Ld = Loads[EltIdx];
     int64_t ByteOffset = ByteOffsets[EltIdx];
     if (ByteOffset && (ByteOffset % BaseSizeInBytes) == 0) {
       int64_t BaseIdx = EltIdx - (ByteOffset / BaseSizeInBytes);
@@ -7215,7 +7225,7 @@ static SDValue EltsFromConsecutiveLoads(EVT VT, ArrayRef<SDValue> Elts,
     }
   }
 
-  auto CreateLoad = [&DAG, &DL, &Loads](EVT VT, LoadSDNode *LDBase) {
+  auto CreateLoad = [&DAG, &DL, &Loads](EVT VT, T *LDBase) {
     auto MMOFlags = LDBase->getMemOperand()->getFlags();
     assert(LDBase->isSimple() &&
            "Cannot merge volatile or atomic loads.");
@@ -7285,7 +7295,7 @@ static SDValue EltsFromConsecutiveLoads(EVT VT, ArrayRef<SDValue> Elts,
     EVT HalfVT =
         EVT::getVectorVT(*DAG.getContext(), VT.getScalarType(), HalfNumElems);
     SDValue HalfLD =
-        EltsFromConsecutiveLoads(HalfVT, Elts.drop_back(HalfNumElems), DL,
+        EltsFromConsecutiveLoads<T>(HalfVT, Elts.drop_back(HalfNumElems), DL,
                                  DAG, Subtarget, IsAfterLegalize);
     if (HalfLD)
       return DAG.getNode(ISD::INSERT_SUBVECTOR, DL, VT, DAG.getUNDEF(VT),
@@ -7362,7 +7372,7 @@ static SDValue EltsFromConsecutiveLoads(EVT VT, ArrayRef<SDValue> Elts,
         EVT::getVectorVT(*DAG.getContext(), RepeatVT.getScalarType(),
                          VT.getSizeInBits() / ScalarSize);
     if (TLI.isTypeLegal(BroadcastVT)) {
-      if (SDValue RepeatLoad = EltsFromConsecutiveLoads(
+      if (SDValue RepeatLoad = EltsFromConsecutiveLoads<T>(
               RepeatVT, RepeatedLoads, DL, DAG, Subtarget, IsAfterLegalize)) {
         SDValue Broadcast = RepeatLoad;
         if (RepeatSize > ScalarSize) {
@@ -7403,7 +7413,7 @@ static SDValue combineToConsecutiveLoads(EVT VT, SDValue Op, const SDLoc &DL,
     return SDValue();
   }
   assert(Elts.size() == VT.getVectorNumElements());
-  return EltsFromConsecutiveLoads(VT, Elts, DL, DAG, Subtarget,
+  return EltsFromConsecutiveLoads<LoadSDNode>(VT, Elts, DL, DAG, Subtarget,
                                   IsAfterLegalize);
 }
 
@@ -9258,8 +9268,12 @@ X86TargetLowering::LowerBUILD_VECTOR(SDValue Op, SelectionDAG &DAG) const {
   {
     SmallVector<SDValue, 64> Ops(Op->op_begin(), Op->op_begin() + NumElems);
     if (SDValue LD =
-            EltsFromConsecutiveLoads(VT, Ops, dl, DAG, Subtarget, false))
+            EltsFromConsecutiveLoads<LoadSDNode>(VT, Ops, dl, DAG, Subtarget, false)) {
       return LD;
+    } else if (SDValue LD =
+            EltsFromConsecutiveLoads<AtomicSDNode>(VT, Ops, dl, DAG, Subtarget, false)) {
+      return LD;
+    }
   }
 
   // If this is a splat of pairs of 32-bit elements, we can use a narrower
@@ -57979,7 +57993,7 @@ static SDValue combineConcatVectorOps(const SDLoc &DL, MVT VT,
                                     *FirstLd->getMemOperand(), &Fast) &&
         Fast) {
       if (SDValue Ld =
-              EltsFromConsecutiveLoads(VT, Ops, DL, DAG, Subtarget, false))
+              EltsFromConsecutiveLoads<LoadSDNode>(VT, Ops, DL, DAG, Subtarget, false))
         return Ld;
     }
   }

llvm/test/CodeGen/X86/atomic-load-store.ll

Lines changed: 36 additions & 0 deletions
@@ -195,6 +195,24 @@ define <2 x float> @atomic_vec2_float_align(ptr %x) {
   ret <2 x float> %ret
 }
 
+define <2 x half> @atomic_vec2_half(ptr %x) {
+; CHECK-LABEL: atomic_vec2_half:
+; CHECK:       ## %bb.0:
+; CHECK-NEXT:    movss {{.*#+}} xmm0 = mem[0],zero,zero,zero
+; CHECK-NEXT:    retq
+  %ret = load atomic <2 x half>, ptr %x acquire, align 4
+  ret <2 x half> %ret
+}
+
+define <2 x bfloat> @atomic_vec2_bfloat(ptr %x) {
+; CHECK-LABEL: atomic_vec2_bfloat:
+; CHECK:       ## %bb.0:
+; CHECK-NEXT:    movss {{.*#+}} xmm0 = mem[0],zero,zero,zero
+; CHECK-NEXT:    retq
+  %ret = load atomic <2 x bfloat>, ptr %x acquire, align 4
+  ret <2 x bfloat> %ret
+}
+
 define <1 x ptr> @atomic_vec1_ptr(ptr %x) nounwind {
 ; CHECK3-LABEL: atomic_vec1_ptr:
 ; CHECK3:       ## %bb.0:
@@ -367,6 +385,24 @@ define <4 x i16> @atomic_vec4_i16(ptr %x) nounwind {
   ret <4 x i16> %ret
 }
 
+define <4 x half> @atomic_vec4_half(ptr %x) nounwind {
+; CHECK-LABEL: atomic_vec4_half:
+; CHECK:       ## %bb.0:
+; CHECK-NEXT:    movsd {{.*#+}} xmm0 = mem[0],zero
+; CHECK-NEXT:    retq
+  %ret = load atomic <4 x half>, ptr %x acquire, align 8
+  ret <4 x half> %ret
+}
+
+define <4 x bfloat> @atomic_vec4_bfloat(ptr %x) nounwind {
+; CHECK-LABEL: atomic_vec4_bfloat:
+; CHECK:       ## %bb.0:
+; CHECK-NEXT:    movsd {{.*#+}} xmm0 = mem[0],zero
+; CHECK-NEXT:    retq
+  %ret = load atomic <4 x bfloat>, ptr %x acquire, align 8
+  ret <4 x bfloat> %ret
+}
+
 define <4 x float> @atomic_vec4_float_align(ptr %x) nounwind {
 ; CHECK-LABEL: atomic_vec4_float_align:
 ; CHECK:       ## %bb.0:
