[IA][RISCV] Recognize deinterleaved loads that could lower to strided segmented loads #151612
@@ -268,13 +268,19 @@ static Value *getMaskOperand(IntrinsicInst *II) {
   }
 }
 
-// Return the corresponded deinterleaved mask, or nullptr if there is no valid
-// mask.
-static Value *getMask(Value *WideMask, unsigned Factor,
-                      ElementCount LeafValueEC);
+// Return a pair of
+// (1) The corresponding deinterleaved mask, or nullptr if there is no valid
+//     mask.
+// (2) If some mask effectively skips certain fields, this element contains
+//     the factor after taking such contraction into account. Note that we
+//     currently only support skipping trailing fields, so if the "nominal"
+//     factor is 5, you cannot skip only fields 1 and 2, but you can skip
+//     fields 3 and 4.
+static std::pair<Value *, unsigned> getMask(Value *WideMask, unsigned Factor,
+                                            ElementCount LeafValueEC);
 
-static Value *getMask(Value *WideMask, unsigned Factor,
-                      VectorType *LeafValueTy) {
+static std::pair<Value *, unsigned> getMask(Value *WideMask, unsigned Factor,
+                                            VectorType *LeafValueTy) {
   return getMask(WideMask, Factor, LeafValueTy->getElementCount());
 }

Review comment: We talked about this offline, but I'm more and more coming to the view that we should have made these a set of utility routines (usable by each target) and simply passed the mask operand through (or maybe not even that). More of an aside for longer-term consideration than a comment on this review.
@@ -379,22 +385,25 @@ bool InterleavedAccessImpl::lowerInterleavedLoad(
   replaceBinOpShuffles(BinOpShuffles.getArrayRef(), Shuffles, Load);
 
   Value *Mask = nullptr;
+  unsigned GapMaskFactor = Factor;
   if (LI) {
     LLVM_DEBUG(dbgs() << "IA: Found an interleaved load: " << *Load << "\n");
   } else {
     // Check mask operand. Handle both all-true/false and interleaved mask.
-    Mask = getMask(getMaskOperand(II), Factor, VecTy);
+    std::tie(Mask, GapMaskFactor) = getMask(getMaskOperand(II), Factor, VecTy);
     if (!Mask)
       return false;
 
     LLVM_DEBUG(dbgs() << "IA: Found an interleaved vp.load or masked.load: "
                       << *Load << "\n");
+    LLVM_DEBUG(dbgs() << "IA: With nominal factor " << Factor
+                      << " and mask factor " << GapMaskFactor << "\n");
   }
 
   // Try to create target specific intrinsics to replace the load and
   // shuffles.
   if (!TLI->lowerInterleavedLoad(cast<Instruction>(Load), Mask, Shuffles,
-                                 Indices, Factor))
+                                 Indices, Factor, GapMaskFactor))
     // If Extracts is not empty, tryReplaceExtracts made changes earlier.
     return !Extracts.empty() || BinOpShuffleChanged;
@@ -531,15 +540,20 @@ bool InterleavedAccessImpl::lowerInterleavedStore(
          "number of stored element should be a multiple of Factor");
 
   Value *Mask = nullptr;
+  unsigned GapMaskFactor = Factor;
   if (SI) {
     LLVM_DEBUG(dbgs() << "IA: Found an interleaved store: " << *Store << "\n");
   } else {
     // Check mask operand. Handle both all-true/false and interleaved mask.
     unsigned LaneMaskLen = NumStoredElements / Factor;
-    Mask = getMask(getMaskOperand(II), Factor,
-                   ElementCount::getFixed(LaneMaskLen));
+    std::tie(Mask, GapMaskFactor) = getMask(
+        getMaskOperand(II), Factor, ElementCount::getFixed(LaneMaskLen));
     if (!Mask)
       return false;
+    // We shouldn't transform stores even if they have a gap mask. And since
+    // we might already have changed the IR, we return true here.
+    if (GapMaskFactor != Factor)
+      return true;
 
     LLVM_DEBUG(dbgs() << "IA: Found an interleaved vp.store or masked.store: "
                       << *Store << "\n");
@@ -556,34 +570,87 @@ bool InterleavedAccessImpl::lowerInterleavedStore(
   return true;
 }
 
-static Value *getMask(Value *WideMask, unsigned Factor,
-                      ElementCount LeafValueEC) {
+// A wide mask <1, 1, 0, 1, 1, 0, 1, 1, 0, 1, 1, 0> could be used to skip the
+// last field in a factor-of-three interleaved store or deinterleaved load (in
+// which case LeafMaskLen is 4). Such a (wide) mask is also known as a gap
+// mask. This helper function tries to detect that pattern and returns the
+// actual factor we're accessing, which is 2 in this example.
+static unsigned getGapMaskFactor(const Constant &MaskConst, unsigned Factor,
+                                 unsigned LeafMaskLen) {
+  APInt FactorMask(Factor, 0);
+  FactorMask.setAllBits();
+  for (unsigned F = 0U; F < Factor; ++F) {
+    bool AllZero = true;
+    for (unsigned Idx = 0U; Idx < LeafMaskLen; ++Idx) {
+      Constant *C = MaskConst.getAggregateElement(F + Idx * Factor);
+      if (!C->isZeroValue()) {
+        AllZero = false;
+        break;
+      }
+    }
+    // All mask bits for this field are zero; skip it.
+    if (AllZero)
+      FactorMask.clearBit(F);
+  }
+  // We currently only allow gaps in the "trailing" factors / fields. So
+  // given an original factor of 4, we can skip fields 2 and 3, but we cannot
+  // skip only fields 1 and 2. If FactorMask does not match such a pattern,
+  // reset it.
+  if (!FactorMask.isMask())
+    FactorMask.setAllBits();
+
+  return FactorMask.popcount();
+}
|
||
static std::pair<Value *, unsigned> getMask(Value *WideMask, unsigned Factor, | ||
ElementCount LeafValueEC) { | ||
using namespace PatternMatch; | ||
|
||
if (auto *IMI = dyn_cast<IntrinsicInst>(WideMask)) { | ||
if (unsigned F = getInterleaveIntrinsicFactor(IMI->getIntrinsicID()); | ||
F && F == Factor && llvm::all_equal(IMI->args())) { | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. You can handle the case where the tail elements in the interleave are zero. Might be easier to start with this one, as it's the minimum code change. (This combines with my macro comment.) |
||
return IMI->getArgOperand(0); | ||
return {IMI->getArgOperand(0), Factor}; | ||
} | ||
} | ||
|
||
// Try to match `and <interleaved mask>, <gap mask>`. The WideMask here is | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This can be dropped and done as a follow up commit. |
||
// expected to be a fixed vector and gap mask should be a constant mask. | ||
Value *AndMaskLHS; | ||
Constant *AndMaskRHS; | ||
if (match(WideMask, m_c_And(m_Value(AndMaskLHS), m_Constant(AndMaskRHS))) && | ||
LeafValueEC.isFixed()) { | ||
assert(!isa<Constant>(AndMaskLHS) && | ||
"expect constants to be folded already"); | ||
return {getMask(AndMaskLHS, Factor, LeafValueEC).first, | ||
getGapMaskFactor(*AndMaskRHS, Factor, LeafValueEC.getFixedValue())}; | ||
} | ||
|
||
if (auto *ConstMask = dyn_cast<Constant>(WideMask)) { | ||
if (auto *Splat = ConstMask->getSplatValue()) | ||
// All-ones or all-zeros mask. | ||
return ConstantVector::getSplat(LeafValueEC, Splat); | ||
return {ConstantVector::getSplat(LeafValueEC, Splat), Factor}; | ||
|
||
if (LeafValueEC.isFixed()) { | ||
unsigned LeafMaskLen = LeafValueEC.getFixedValue(); | ||
// First, check if we use a gap mask to skip some of the factors / fields. | ||
const unsigned GapMaskFactor = | ||
getGapMaskFactor(*ConstMask, Factor, LeafMaskLen); | ||
assert(GapMaskFactor <= Factor); | ||
|
||
SmallVector<Constant *, 8> LeafMask(LeafMaskLen, nullptr); | ||
// If this is a fixed-length constant mask, each lane / leaf has to | ||
// use the same mask. This is done by checking if every group with Factor | ||
// number of elements in the interleaved mask has homogeneous values. | ||
for (unsigned Idx = 0U; Idx < LeafMaskLen * Factor; ++Idx) { | ||
if (Idx % Factor >= GapMaskFactor) | ||
continue; | ||
Constant *C = ConstMask->getAggregateElement(Idx); | ||
if (LeafMask[Idx / Factor] && LeafMask[Idx / Factor] != C) | ||
return nullptr; | ||
return {nullptr, Factor}; | ||
LeafMask[Idx / Factor] = C; | ||
} | ||
|
||
return ConstantVector::get(LeafMask); | ||
return {ConstantVector::get(LeafMask), GapMaskFactor}; | ||
} | ||
} | ||
|
||
|
@@ -603,12 +670,13 @@ static Value *getMask(Value *WideMask, unsigned Factor,
     auto *LeafMaskTy =
         VectorType::get(Type::getInt1Ty(SVI->getContext()), LeafValueEC);
     IRBuilder<> Builder(SVI);
-    return Builder.CreateExtractVector(LeafMaskTy, SVI->getOperand(0),
-                                       uint64_t(0));
+    return {Builder.CreateExtractVector(LeafMaskTy, SVI->getOperand(0),
+                                        uint64_t(0)),
+            Factor};
     }
   }
 
-  return nullptr;
+  return {nullptr, Factor};
 }
 bool InterleavedAccessImpl::lowerDeinterleaveIntrinsic(

@@ -639,9 +707,12 @@ bool InterleavedAccessImpl::lowerDeinterleaveIntrinsic(
     return false;
 
   // Check mask operand. Handle both all-true/false and interleaved mask.
-  Mask = getMask(getMaskOperand(II), Factor, getDeinterleavedVectorType(DI));
+  unsigned GapMaskFactor;
+  std::tie(Mask, GapMaskFactor) =
+      getMask(getMaskOperand(II), Factor, getDeinterleavedVectorType(DI));
   if (!Mask)
     return false;
+  assert(GapMaskFactor == Factor);

Review comment: Took me a sec to figure out why this assert holds; add a `&& "why this is true"` string to it.

Reply: Originally I added this assertion because there is no way one can synthesize a gap mask for a scalable vector. But reading this again, I realized this part of the code (vp.load/masked.load + deinterleave intrinsic) could also handle fixed vectors, so I'm going to turn this into a check instead.
 
   LLVM_DEBUG(dbgs() << "IA: Found a vp.load or masked.load with deinterleave"
                     << " intrinsic " << *DI << " and factor = "
@@ -680,10 +751,13 @@ bool InterleavedAccessImpl::lowerInterleaveIntrinsic(
       II->getIntrinsicID() != Intrinsic::vp_store)
     return false;
   // Check mask operand. Handle both all-true/false and interleaved mask.
-  Mask = getMask(getMaskOperand(II), Factor,
-                 cast<VectorType>(InterleaveValues[0]->getType()));
+  unsigned GapMaskFactor;
+  std::tie(Mask, GapMaskFactor) =
+      getMask(getMaskOperand(II), Factor,
+              cast<VectorType>(InterleaveValues[0]->getType()));
   if (!Mask)
     return false;
+  assert(GapMaskFactor == Factor);
 
   LLVM_DEBUG(dbgs() << "IA: Found a vp.store or masked.store with interleave"
                     << " intrinsic " << *IntII << " and factor = "
@@ -17254,7 +17254,7 @@ static Function *getStructuredStoreFunction(Module *M, unsigned Factor,
 /// %vec1 = extractelement { <4 x i32>, <4 x i32> } %ld2, i32 1
 bool AArch64TargetLowering::lowerInterleavedLoad(
     Instruction *Load, Value *Mask, ArrayRef<ShuffleVectorInst *> Shuffles,
-    ArrayRef<unsigned> Indices, unsigned Factor) const {
+    ArrayRef<unsigned> Indices, unsigned Factor, unsigned MaskFactor) const {
   assert(Factor >= 2 && Factor <= getMaxSupportedInterleaveFactor() &&
          "Invalid interleave factor");
   assert(!Shuffles.empty() && "Empty shufflevector input");

@@ -17266,6 +17266,9 @@ bool AArch64TargetLowering::lowerInterleavedLoad(
     return false;
   assert(!Mask && "Unexpected mask on a load");
 
+  if (Factor != MaskFactor)
+    return false;

Review comment: This should be an assert (same for most targets), since a LoadInst isn't masked by definition.
 
   const DataLayout &DL = LI->getDataLayout();
 
   VectorType *VTy = Shuffles[0]->getType();
Review comment: Would it be a bit easier for the targets if we instead passed in the stride in bytes? That way they wouldn't have to worry about the difference between MaskFactor and Factor. Targets that don't support strided interleaved loads would check that `Stride == DL.getTypeStoreSize(VTy->getElementType())`.

Reply: I believe the stride is relative to the current start address, so in the case of skipping fields the stride will always be `Factor * DL.getTypeStoreSize(VTy->getElementType())` regardless of how many fields you want to skip. But I guess my more high-level question would be: for targets that don't support strided interleaved loads, what is the benefit of replacing a check between Factor and MaskFactor with another check on Stride?

Reply: Oh whoops, yes, that should be multiplied by the factor. To me, MaskFactor feels like a concept internal to InterleavedAccessPass that's leaking through. I'm not strongly opinionated about this, though; just thought I'd throw the idea out there, happy to go with what you prefer. I guess an alternative is that we could also add a separate "lowerStridedInterleaved" TTI hook, but maybe that will lead to hook explosion again.

Reply: (I believe you meant TLI hooks.) Yeah, I'm also worried about the fact that it would double the number of hooks, as all four of them could have a strided version.

Reply: An alternate suggestion: pass in GapMask as an APInt, then have the target filter out which set of gaps it can handle.