Skip to content

Commit 9d6ef18

Browse files
committed
[IA][RISCV] Recognize deinterleaved loads that could lower to strided segmented loads
1 parent c722014 commit 9d6ef18

File tree

11 files changed

+139
-486
lines changed

11 files changed

+139
-486
lines changed

llvm/include/llvm/CodeGen/TargetLowering.h

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3209,10 +3209,12 @@ class LLVM_ABI TargetLoweringBase {
32093209
/// \p Shuffles is the shufflevector list to DE-interleave the loaded vector.
32103210
/// \p Indices is the corresponding indices for each shufflevector.
32113211
/// \p Factor is the interleave factor.
3212+
/// \p MaskFactor is the effective interleave factor once the mask is taken
3213+
/// into account; it can be smaller than \p Factor.
32123214
virtual bool lowerInterleavedLoad(Instruction *Load, Value *Mask,
32133215
ArrayRef<ShuffleVectorInst *> Shuffles,
3214-
ArrayRef<unsigned> Indices,
3215-
unsigned Factor) const {
3216+
ArrayRef<unsigned> Indices, unsigned Factor,
3217+
unsigned MaskFactor) const {
32163218
return false;
32173219
}
32183220

llvm/lib/CodeGen/InterleavedAccessPass.cpp

Lines changed: 58 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -268,13 +268,19 @@ static Value *getMaskOperand(IntrinsicInst *II) {
268268
}
269269
}
270270

271-
// Return the corresponded deinterleaved mask, or nullptr if there is no valid
272-
// mask.
273-
static Value *getMask(Value *WideMask, unsigned Factor,
274-
ElementCount LeafValueEC);
275-
276-
static Value *getMask(Value *WideMask, unsigned Factor,
277-
VectorType *LeafValueTy) {
271+
// Return a pair of
272+
// (1) The corresponding deinterleaved mask, or nullptr if there is no valid
273+
// mask.
274+
// (2) Some masks effectively skip certain fields; this element contains
275+
// the factor after taking such contraction into consideration. Note that
276+
// currently we only support skipping trailing fields. So if the "nominal"
277+
// factor is 5, you cannot skip only fields 1 and 2, but you can skip fields 3
278+
// and 4.
279+
static std::pair<Value *, unsigned> getMask(Value *WideMask, unsigned Factor,
280+
ElementCount LeafValueEC);
281+
282+
static std::pair<Value *, unsigned> getMask(Value *WideMask, unsigned Factor,
283+
VectorType *LeafValueTy) {
278284
return getMask(WideMask, Factor, LeafValueTy->getElementCount());
279285
}
280286

@@ -379,22 +385,25 @@ bool InterleavedAccessImpl::lowerInterleavedLoad(
379385
replaceBinOpShuffles(BinOpShuffles.getArrayRef(), Shuffles, Load);
380386

381387
Value *Mask = nullptr;
388+
unsigned MaskFactor = Factor;
382389
if (LI) {
383390
LLVM_DEBUG(dbgs() << "IA: Found an interleaved load: " << *Load << "\n");
384391
} else {
385392
// Check mask operand. Handle both all-true/false and interleaved mask.
386-
Mask = getMask(getMaskOperand(II), Factor, VecTy);
393+
std::tie(Mask, MaskFactor) = getMask(getMaskOperand(II), Factor, VecTy);
387394
if (!Mask)
388395
return false;
389396

390397
LLVM_DEBUG(dbgs() << "IA: Found an interleaved vp.load or masked.load: "
391398
<< *Load << "\n");
399+
LLVM_DEBUG(dbgs() << "IA: With nominal factor " << Factor
400+
<< " and mask factor " << MaskFactor << "\n");
392401
}
393402

394403
// Try to create target specific intrinsics to replace the load and
395404
// shuffles.
396405
if (!TLI->lowerInterleavedLoad(cast<Instruction>(Load), Mask, Shuffles,
397-
Indices, Factor))
406+
Indices, Factor, MaskFactor))
398407
// If Extracts is not empty, tryReplaceExtracts made changes earlier.
399408
return !Extracts.empty() || BinOpShuffleChanged;
400409

@@ -536,8 +545,8 @@ bool InterleavedAccessImpl::lowerInterleavedStore(
536545
} else {
537546
// Check mask operand. Handle both all-true/false and interleaved mask.
538547
unsigned LaneMaskLen = NumStoredElements / Factor;
539-
Mask = getMask(getMaskOperand(II), Factor,
540-
ElementCount::getFixed(LaneMaskLen));
548+
std::tie(Mask, std::ignore) = getMask(getMaskOperand(II), Factor,
549+
ElementCount::getFixed(LaneMaskLen));
541550
if (!Mask)
542551
return false;
543552

@@ -556,34 +565,57 @@ bool InterleavedAccessImpl::lowerInterleavedStore(
556565
return true;
557566
}
558567

559-
static Value *getMask(Value *WideMask, unsigned Factor,
560-
ElementCount LeafValueEC) {
568+
static std::pair<Value *, unsigned> getMask(Value *WideMask, unsigned Factor,
569+
ElementCount LeafValueEC) {
561570
if (auto *IMI = dyn_cast<IntrinsicInst>(WideMask)) {
562571
if (unsigned F = getInterleaveIntrinsicFactor(IMI->getIntrinsicID());
563572
F && F == Factor && llvm::all_equal(IMI->args())) {
564-
return IMI->getArgOperand(0);
573+
return {IMI->getArgOperand(0), Factor};
565574
}
566575
}
567576

568577
if (auto *ConstMask = dyn_cast<Constant>(WideMask)) {
569578
if (auto *Splat = ConstMask->getSplatValue())
570579
// All-ones or all-zeros mask.
571-
return ConstantVector::getSplat(LeafValueEC, Splat);
580+
return {ConstantVector::getSplat(LeafValueEC, Splat), Factor};
572581

573582
if (LeafValueEC.isFixed()) {
574583
unsigned LeafMaskLen = LeafValueEC.getFixedValue();
584+
// First, check if the mask completely skips some of the factors / fields.
585+
APInt FactorMask(Factor, 0);
586+
FactorMask.setAllBits();
587+
for (unsigned F = 0U; F < Factor; ++F) {
588+
unsigned Idx;
589+
for (Idx = 0U; Idx < LeafMaskLen; ++Idx) {
590+
Constant *C = ConstMask->getAggregateElement(F + Idx * Factor);
591+
if (!C->isZeroValue())
592+
break;
593+
}
594+
// All mask bits on this field are zero, skipping it.
595+
if (Idx >= LeafMaskLen)
596+
FactorMask.clearBit(F);
597+
}
598+
// We currently only support skipping "trailing" factors / fields. So
599+
// given the original factor being 4, we can skip fields 2 and 3, but we
600+
// cannot skip only fields 1 and 2. If FactorMask does not match such a
601+
// pattern, reset it.
602+
if (!FactorMask.isMask())
603+
FactorMask.setAllBits();
604+
575605
SmallVector<Constant *, 8> LeafMask(LeafMaskLen, nullptr);
576606
// If this is a fixed-length constant mask, each lane / leaf has to
577607
// use the same mask. This is done by checking if every group with Factor
578608
// number of elements in the interleaved mask has homogeneous values.
579609
for (unsigned Idx = 0U; Idx < LeafMaskLen * Factor; ++Idx) {
610+
if (!FactorMask[Idx % Factor])
611+
continue;
580612
Constant *C = ConstMask->getAggregateElement(Idx);
581613
if (LeafMask[Idx / Factor] && LeafMask[Idx / Factor] != C)
582-
return nullptr;
614+
return {nullptr, Factor};
583615
LeafMask[Idx / Factor] = C;
584616
}
585617

586-
return ConstantVector::get(LeafMask);
618+
return {ConstantVector::get(LeafMask), FactorMask.popcount()};
587619
}
588620
}
589621

@@ -603,12 +635,13 @@ static Value *getMask(Value *WideMask, unsigned Factor,
603635
auto *LeafMaskTy =
604636
VectorType::get(Type::getInt1Ty(SVI->getContext()), LeafValueEC);
605637
IRBuilder<> Builder(SVI);
606-
return Builder.CreateExtractVector(LeafMaskTy, SVI->getOperand(0),
607-
uint64_t(0));
638+
return {Builder.CreateExtractVector(LeafMaskTy, SVI->getOperand(0),
639+
uint64_t(0)),
640+
Factor};
608641
}
609642
}
610643

611-
return nullptr;
644+
return {nullptr, Factor};
612645
}
613646

614647
bool InterleavedAccessImpl::lowerDeinterleaveIntrinsic(
@@ -639,7 +672,8 @@ bool InterleavedAccessImpl::lowerDeinterleaveIntrinsic(
639672
return false;
640673

641674
// Check mask operand. Handle both all-true/false and interleaved mask.
642-
Mask = getMask(getMaskOperand(II), Factor, getDeinterleavedVectorType(DI));
675+
std::tie(Mask, std::ignore) =
676+
getMask(getMaskOperand(II), Factor, getDeinterleavedVectorType(DI));
643677
if (!Mask)
644678
return false;
645679

@@ -680,8 +714,9 @@ bool InterleavedAccessImpl::lowerInterleaveIntrinsic(
680714
II->getIntrinsicID() != Intrinsic::vp_store)
681715
return false;
682716
// Check mask operand. Handle both all-true/false and interleaved mask.
683-
Mask = getMask(getMaskOperand(II), Factor,
684-
cast<VectorType>(InterleaveValues[0]->getType()));
717+
std::tie(Mask, std::ignore) =
718+
getMask(getMaskOperand(II), Factor,
719+
cast<VectorType>(InterleaveValues[0]->getType()));
685720
if (!Mask)
686721
return false;
687722

llvm/lib/Target/AArch64/AArch64ISelLowering.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17254,7 +17254,7 @@ static Function *getStructuredStoreFunction(Module *M, unsigned Factor,
1725417254
/// %vec1 = extractelement { <4 x i32>, <4 x i32> } %ld2, i32 1
1725517255
bool AArch64TargetLowering::lowerInterleavedLoad(
1725617256
Instruction *Load, Value *Mask, ArrayRef<ShuffleVectorInst *> Shuffles,
17257-
ArrayRef<unsigned> Indices, unsigned Factor) const {
17257+
ArrayRef<unsigned> Indices, unsigned Factor, unsigned MaskFactor) const {
1725817258
assert(Factor >= 2 && Factor <= getMaxSupportedInterleaveFactor() &&
1725917259
"Invalid interleave factor");
1726017260
assert(!Shuffles.empty() && "Empty shufflevector input");

llvm/lib/Target/AArch64/AArch64ISelLowering.h

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -220,8 +220,8 @@ class AArch64TargetLowering : public TargetLowering {
220220

221221
bool lowerInterleavedLoad(Instruction *Load, Value *Mask,
222222
ArrayRef<ShuffleVectorInst *> Shuffles,
223-
ArrayRef<unsigned> Indices,
224-
unsigned Factor) const override;
223+
ArrayRef<unsigned> Indices, unsigned Factor,
224+
unsigned MaskFactor) const override;
225225
bool lowerInterleavedStore(Instruction *Store, Value *Mask,
226226
ShuffleVectorInst *SVI,
227227
unsigned Factor) const override;

llvm/lib/Target/ARM/ARMISelLowering.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21599,7 +21599,7 @@ unsigned ARMTargetLowering::getMaxSupportedInterleaveFactor() const {
2159921599
/// %vec1 = extractelement { <4 x i32>, <4 x i32> } %vld2, i32 1
2160021600
bool ARMTargetLowering::lowerInterleavedLoad(
2160121601
Instruction *Load, Value *Mask, ArrayRef<ShuffleVectorInst *> Shuffles,
21602-
ArrayRef<unsigned> Indices, unsigned Factor) const {
21602+
ArrayRef<unsigned> Indices, unsigned Factor, unsigned MaskFactor) const {
2160321603
assert(Factor >= 2 && Factor <= getMaxSupportedInterleaveFactor() &&
2160421604
"Invalid interleave factor");
2160521605
assert(!Shuffles.empty() && "Empty shufflevector input");

llvm/lib/Target/ARM/ARMISelLowering.h

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -683,8 +683,8 @@ class VectorType;
683683

684684
bool lowerInterleavedLoad(Instruction *Load, Value *Mask,
685685
ArrayRef<ShuffleVectorInst *> Shuffles,
686-
ArrayRef<unsigned> Indices,
687-
unsigned Factor) const override;
686+
ArrayRef<unsigned> Indices, unsigned Factor,
687+
unsigned MaskFactor) const override;
688688
bool lowerInterleavedStore(Instruction *Store, Value *Mask,
689689
ShuffleVectorInst *SVI,
690690
unsigned Factor) const override;

llvm/lib/Target/RISCV/RISCVISelLowering.h

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -431,8 +431,8 @@ class RISCVTargetLowering : public TargetLowering {
431431

432432
bool lowerInterleavedLoad(Instruction *Load, Value *Mask,
433433
ArrayRef<ShuffleVectorInst *> Shuffles,
434-
ArrayRef<unsigned> Indices,
435-
unsigned Factor) const override;
434+
ArrayRef<unsigned> Indices, unsigned Factor,
435+
unsigned MaskFactor) const override;
436436

437437
bool lowerInterleavedStore(Instruction *Store, Value *Mask,
438438
ShuffleVectorInst *SVI,

llvm/lib/Target/RISCV/RISCVInterleavedAccess.cpp

Lines changed: 34 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -63,6 +63,12 @@ static const Intrinsic::ID FixedVlsegIntrIds[] = {
6363
Intrinsic::riscv_seg6_load_mask, Intrinsic::riscv_seg7_load_mask,
6464
Intrinsic::riscv_seg8_load_mask};
6565

66+
static const Intrinsic::ID FixedVlssegIntrIds[] = {
67+
Intrinsic::riscv_sseg2_load_mask, Intrinsic::riscv_sseg3_load_mask,
68+
Intrinsic::riscv_sseg4_load_mask, Intrinsic::riscv_sseg5_load_mask,
69+
Intrinsic::riscv_sseg6_load_mask, Intrinsic::riscv_sseg7_load_mask,
70+
Intrinsic::riscv_sseg8_load_mask};
71+
6672
static const Intrinsic::ID ScalableVlsegIntrIds[] = {
6773
Intrinsic::riscv_vlseg2_mask, Intrinsic::riscv_vlseg3_mask,
6874
Intrinsic::riscv_vlseg4_mask, Intrinsic::riscv_vlseg5_mask,
@@ -197,9 +203,13 @@ static bool getMemOperands(unsigned Factor, VectorType *VTy, Type *XLenTy,
197203
/// %vec1 = extractelement { <4 x i32>, <4 x i32> } %ld2, i32 1
198204
bool RISCVTargetLowering::lowerInterleavedLoad(
199205
Instruction *Load, Value *Mask, ArrayRef<ShuffleVectorInst *> Shuffles,
200-
ArrayRef<unsigned> Indices, unsigned Factor) const {
206+
ArrayRef<unsigned> Indices, unsigned Factor, unsigned MaskFactor) const {
201207
assert(Indices.size() == Shuffles.size());
208+
assert(MaskFactor <= Factor);
202209

210+
// TODO: Lower to strided load when MaskFactor = 1.
211+
if (MaskFactor < 2)
212+
return false;
203213
IRBuilder<> Builder(Load);
204214

205215
const DataLayout &DL = Load->getDataLayout();
@@ -208,20 +218,37 @@ bool RISCVTargetLowering::lowerInterleavedLoad(
208218

209219
Value *Ptr, *VL;
210220
Align Alignment;
211-
if (!getMemOperands(Factor, VTy, XLenTy, Load, Ptr, Mask, VL, Alignment))
221+
if (!getMemOperands(MaskFactor, VTy, XLenTy, Load, Ptr, Mask, VL, Alignment))
212222
return false;
213223

214224
Type *PtrTy = Ptr->getType();
215225
unsigned AS = PtrTy->getPointerAddressSpace();
216-
if (!isLegalInterleavedAccessType(VTy, Factor, Alignment, AS, DL))
226+
if (!isLegalInterleavedAccessType(VTy, MaskFactor, Alignment, AS, DL))
217227
return false;
218228

219-
CallInst *VlsegN = Builder.CreateIntrinsic(
220-
FixedVlsegIntrIds[Factor - 2], {VTy, PtrTy, XLenTy}, {Ptr, Mask, VL});
229+
CallInst *SegLoad = nullptr;
230+
if (MaskFactor < Factor) {
231+
// Lower to strided segmented load.
232+
unsigned ScalarSizeInBytes = DL.getTypeStoreSize(VTy->getElementType());
233+
Value *Stride = ConstantInt::get(XLenTy, Factor * ScalarSizeInBytes);
234+
SegLoad = Builder.CreateIntrinsic(FixedVlssegIntrIds[MaskFactor - 2],
235+
{VTy, PtrTy, XLenTy, XLenTy},
236+
{Ptr, Stride, Mask, VL});
237+
} else {
238+
// Lower to normal segmented load.
239+
SegLoad = Builder.CreateIntrinsic(FixedVlsegIntrIds[Factor - 2],
240+
{VTy, PtrTy, XLenTy}, {Ptr, Mask, VL});
241+
}
221242

222243
for (unsigned i = 0; i < Shuffles.size(); i++) {
223-
Value *SubVec = Builder.CreateExtractValue(VlsegN, Indices[i]);
224-
Shuffles[i]->replaceAllUsesWith(SubVec);
244+
unsigned FactorIdx = Indices[i];
245+
if (FactorIdx >= MaskFactor) {
246+
// Replace masked-off factors (that are still extracted) with poison.
247+
Shuffles[i]->replaceAllUsesWith(PoisonValue::get(VTy));
248+
} else {
249+
Value *SubVec = Builder.CreateExtractValue(SegLoad, FactorIdx);
250+
Shuffles[i]->replaceAllUsesWith(SubVec);
251+
}
225252
}
226253

227254
return true;

llvm/lib/Target/X86/X86ISelLowering.h

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1663,8 +1663,8 @@ namespace llvm {
16631663
/// instructions/intrinsics.
16641664
bool lowerInterleavedLoad(Instruction *Load, Value *Mask,
16651665
ArrayRef<ShuffleVectorInst *> Shuffles,
1666-
ArrayRef<unsigned> Indices,
1667-
unsigned Factor) const override;
1666+
ArrayRef<unsigned> Indices, unsigned Factor,
1667+
unsigned MaskFactor) const override;
16681668

16691669
/// Lower interleaved store(s) into target specific
16701670
/// instructions/intrinsics.

llvm/lib/Target/X86/X86InterleavedAccess.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -802,7 +802,7 @@ bool X86InterleavedAccessGroup::lowerIntoOptimizedSequence() {
802802
// Currently, lowering is supported for 4x64 bits with Factor = 4 on AVX.
803803
bool X86TargetLowering::lowerInterleavedLoad(
804804
Instruction *Load, Value *Mask, ArrayRef<ShuffleVectorInst *> Shuffles,
805-
ArrayRef<unsigned> Indices, unsigned Factor) const {
805+
ArrayRef<unsigned> Indices, unsigned Factor, unsigned MaskFactor) const {
806806
assert(Factor >= 2 && Factor <= getMaxSupportedInterleaveFactor() &&
807807
"Invalid interleave factor");
808808
assert(!Shuffles.empty() && "Empty shufflevector input");

0 commit comments

Comments
 (0)