Skip to content

Commit bb524c5

Browse files
committed
Use ConstantFoldLoadFromConst, and get i16 tables working as a result
1 parent bc33a14 commit bb524c5

File tree

2 files changed

+102
-38
lines changed

2 files changed

+102
-38
lines changed

llvm/lib/Transforms/AggressiveInstCombine/AggressiveInstCombine.cpp

Lines changed: 22 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -457,30 +457,20 @@ static bool foldSqrt(CallInst *Call, LibFunc Func, TargetTransformInfo &TTI,
457457

458458
// Check if this array of constants represents a cttz table.
459459
// Iterate over the elements from \p Table by trying to find/match all
460-
// the numbers from 0 to \p InputBits that should represent cttz results.
461-
static bool isCTTZTable(const ConstantDataArray &Table, uint64_t Mul,
462-
uint64_t Shift, uint64_t InputBits) {
463-
unsigned Length = Table.getNumElements();
464-
if (Length < InputBits || Length > InputBits * 2)
465-
return false;
466-
467-
APInt Mask = APInt::getBitsSetFrom(InputBits, Shift);
468-
unsigned Matched = 0;
469-
470-
for (unsigned i = 0; i < Length; i++) {
471-
uint64_t Element = Table.getElementAsInteger(i);
472-
if (Element >= InputBits)
473-
continue;
474-
475-
// Check if \p Element matches a concrete answer. It could fail for some
476-
// elements that are never accessed, so we keep iterating over each element
477-
// from the table. The number of matched elements should be equal to the
478-
// number of potential right answers which is \p InputBits actually.
479-
if ((((Mul << Element) & Mask.getZExtValue()) >> Shift) == i)
480-
Matched++;
460+
// the numbers from 0 to \p InputTy->getSizeInBits() that should represent cttz
461+
// results.
462+
static bool isCTTZTable(Constant *Table, uint64_t Mul, uint64_t Shift,
463+
Type *AccessTy, unsigned InputBits,
464+
unsigned GEPIdxFactor, const DataLayout &DL) {
465+
for (unsigned Idx = 0; Idx < InputBits; Idx++) {
466+
APInt Index = (APInt(InputBits, 1ull << Idx) * Mul).lshr(Shift);
467+
ConstantInt *C = dyn_cast_or_null<ConstantInt>(
468+
ConstantFoldLoadFromConst(Table, AccessTy, Index * GEPIdxFactor, DL));
469+
if (!C || C->getZExtValue() != Idx)
470+
return false;
481471
}
482472

483-
return Matched == InputBits;
473+
return true;
484474
}
485475

486476
// Try to recognize table-based ctz implementation.
@@ -537,7 +527,7 @@ static bool isCTTZTable(const ConstantDataArray &Table, uint64_t Mul,
537527
// %0 = load i8, i8* %arrayidx, align 1, !tbaa !8
538528
//
539529
// All this can be lowered to @llvm.cttz.i32/64 intrinsic.
540-
static bool tryToRecognizeTableBasedCttz(Instruction &I) {
530+
static bool tryToRecognizeTableBasedCttz(Instruction &I, const DataLayout &DL) {
541531
LoadInst *LI = dyn_cast<LoadInst>(&I);
542532
if (!LI)
543533
return false;
@@ -567,11 +557,6 @@ static bool tryToRecognizeTableBasedCttz(Instruction &I) {
567557
if (!GVTable || !GVTable->hasInitializer() || !GVTable->isConstant())
568558
return false;
569559

570-
ConstantDataArray *ConstData =
571-
dyn_cast<ConstantDataArray>(GVTable->getInitializer());
572-
if (!ConstData || ConstData->getElementType() != GEPSrcEltTy)
573-
return false;
574-
575560
Value *X1;
576561
uint64_t MulConst, ShiftConst;
577562
// FIXME: 64-bit targets have `i64` type for the GEP index, so this match will
@@ -583,19 +568,21 @@ static bool tryToRecognizeTableBasedCttz(Instruction &I) {
583568
return false;
584569

585570
unsigned InputBits = X1->getType()->getScalarSizeInBits();
586-
if (InputBits != 32 && InputBits != 64)
571+
if (InputBits != 16 && InputBits != 32 && InputBits != 64)
587572
return false;
588573

589-
// Shift should extract top 5..7 bits.
574+
// Shift should extract top 4..7 bits.
590575
if (InputBits - Log2_32(InputBits) != ShiftConst &&
591576
InputBits - Log2_32(InputBits) - 1 != ShiftConst)
592577
return false;
593578

594-
if (!isCTTZTable(*ConstData, MulConst, ShiftConst, InputBits))
579+
if (!isCTTZTable(GVTable->getInitializer(), MulConst, ShiftConst, AccessType,
580+
InputBits, GEPSrcEltTy->getScalarSizeInBits() / 8, DL))
595581
return false;
596582

597-
auto ZeroTableElem = ConstData->getElementAsInteger(0);
598-
bool DefinedForZero = ZeroTableElem == InputBits;
583+
ConstantInt *ZeroTableElem = cast<ConstantInt>(
584+
ConstantFoldLoadFromConst(GVTable->getInitializer(), AccessType, DL));
585+
bool DefinedForZero = ZeroTableElem->getZExtValue() == InputBits;
599586

600587
IRBuilder<> B(LI);
601588
ConstantInt *BoolConst = B.getInt1(!DefinedForZero);
@@ -609,8 +596,7 @@ static bool tryToRecognizeTableBasedCttz(Instruction &I) {
609596
// If the value in elem 0 isn't the same as InputBits, we still want to
610597
// produce the value from the table.
611598
auto Cmp = B.CreateICmpEQ(X1, ConstantInt::get(XType, 0));
612-
auto Select =
613-
B.CreateSelect(Cmp, ConstantInt::get(XType, ZeroTableElem), Cttz);
599+
auto Select = B.CreateSelect(Cmp, B.CreateZExt(ZeroTableElem, XType), Cttz);
614600

615601
// NOTE: If the table[0] is 0, but the cttz(0) is defined by the Target
616602
// it should be handled as: `cttz(x) & (typeSize - 1)`.
@@ -1479,7 +1465,7 @@ static bool foldUnusualPatterns(Function &F, DominatorTree &DT,
14791465
MadeChange |= foldGuardedFunnelShift(I, DT);
14801466
MadeChange |= tryToRecognizePopCount(I);
14811467
MadeChange |= tryToFPToSat(I, TTI);
1482-
MadeChange |= tryToRecognizeTableBasedCttz(I);
1468+
MadeChange |= tryToRecognizeTableBasedCttz(I, DL);
14831469
MadeChange |= foldConsecutiveLoads(I, DL, TTI, AA, DT);
14841470
MadeChange |= foldPatternedLoads(I, DL);
14851471
MadeChange |= foldICmpOrChain(I, DL, TTI, AA, DT);

llvm/test/Transforms/AggressiveInstCombine/lower-table-based-cttz-basics.ll

Lines changed: 80 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -299,6 +299,7 @@ entry:
299299
ret i32 %conv
300300
}
301301

302+
; This is the same a ctz2 (i16 table) with an i8 gep making the indices invalid
302303
define i32 @ctz2_with_i8_gep(i32 %x) {
303304
; CHECK-LABEL: @ctz2_with_i8_gep(
304305
; CHECK-NEXT: entry:
@@ -308,7 +309,7 @@ define i32 @ctz2_with_i8_gep(i32 %x) {
308309
; CHECK-NEXT: [[SHR:%.*]] = lshr i32 [[MUL]], 26
309310
; CHECK-NEXT: [[IDXPROM:%.*]] = zext i32 [[SHR]] to i64
310311
; CHECK-NEXT: [[ARRAYIDX:%.*]] = getelementptr inbounds [64 x i8], ptr @ctz2.table, i64 0, i64 [[IDXPROM]]
311-
; CHECK-NEXT: [[TMP0:%.*]] = load i16, ptr [[ARRAYIDX]], align 2
312+
; CHECK-NEXT: [[TMP0:%.*]] = load i16, ptr [[ARRAYIDX]], align 1
312313
; CHECK-NEXT: [[CONV:%.*]] = sext i16 [[TMP0]] to i32
313314
; CHECK-NEXT: ret i32 [[CONV]]
314315
;
@@ -319,7 +320,84 @@ entry:
319320
%shr = lshr i32 %mul, 26
320321
%idxprom = zext i32 %shr to i64
321322
%arrayidx = getelementptr inbounds [64 x i8], ptr @ctz2.table, i64 0, i64 %idxprom
322-
%0 = load i16, ptr %arrayidx, align 2
323+
%0 = load i16, ptr %arrayidx, align 1
323324
%conv = sext i16 %0 to i32
324325
ret i32 %conv
325326
}
327+
328+
; This is the same a ctz2_with_i8_gep but with the gep index multiplied by 2.
329+
define i32 @ctz2_with_i8_gep_fixed(i32 %x) {
330+
; CHECK-LABEL: @ctz2_with_i8_gep_fixed(
331+
; CHECK-NEXT: [[SUB:%.*]] = sub i32 0, [[X:%.*]]
332+
; CHECK-NEXT: [[AND:%.*]] = and i32 [[X]], [[SUB]]
333+
; CHECK-NEXT: [[MUL:%.*]] = mul i32 [[AND]], 72416175
334+
; CHECK-NEXT: [[SHR:%.*]] = lshr i32 [[MUL]], 25
335+
; CHECK-NEXT: [[SHR2:%.*]] = and i32 [[SHR]], 126
336+
; CHECK-NEXT: [[TMP1:%.*]] = zext nneg i32 [[SHR2]] to i64
337+
; CHECK-NEXT: [[ARRAYIDX:%.*]] = getelementptr inbounds nuw i8, ptr @ctz2.table, i64 [[TMP1]]
338+
; CHECK-NEXT: [[TMP2:%.*]] = load i16, ptr [[ARRAYIDX]], align 2
339+
; CHECK-NEXT: [[CONV:%.*]] = sext i16 [[TMP2]] to i32
340+
; CHECK-NEXT: ret i32 [[CONV]]
341+
;
342+
%sub = sub i32 0, %x
343+
%and = and i32 %x, %sub
344+
%mul = mul i32 %and, 72416175
345+
%shr = lshr i32 %mul, 25
346+
%shr2 = and i32 %shr, 126
347+
%1 = zext nneg i32 %shr2 to i64
348+
%arrayidx = getelementptr inbounds nuw i8, ptr @ctz2.table, i64 %1
349+
%2 = load i16, ptr %arrayidx, align 2
350+
%conv = sext i16 %2 to i32
351+
ret i32 %conv
352+
}
353+
354+
; This is a i16 input with the debruijn table stored in a single i128.
355+
@tablei128 = internal unnamed_addr constant i128 16018378897745984667142067713738932480, align 16
356+
define i32 @cttz_i16_via_i128(i16 noundef %x) {
357+
; CHECK-LABEL: @cttz_i16_via_i128(
358+
; CHECK-NEXT: entry:
359+
; CHECK-NEXT: [[TMP0:%.*]] = call i16 @llvm.cttz.i16(i16 [[X:%.*]], i1 true)
360+
; CHECK-NEXT: [[TMP3:%.*]] = icmp eq i16 [[X]], 0
361+
; CHECK-NEXT: [[TMP2:%.*]] = select i1 [[TMP3]], i16 0, i16 [[TMP0]]
362+
; CHECK-NEXT: [[TMP1:%.*]] = trunc i16 [[TMP2]] to i8
363+
; CHECK-NEXT: [[CONV6:%.*]] = zext i8 [[TMP1]] to i32
364+
; CHECK-NEXT: ret i32 [[CONV6]]
365+
;
366+
entry:
367+
%sub = sub i16 0, %x
368+
%and = and i16 %x, %sub
369+
%mul = mul i16 %and, 2479
370+
%0 = lshr i16 %mul, 12
371+
%idxprom = zext nneg i16 %0 to i64
372+
%arrayidx = getelementptr inbounds nuw i8, ptr @tablei128, i64 %idxprom
373+
%1 = load i8, ptr %arrayidx, align 1
374+
%conv6 = zext i8 %1 to i32
375+
ret i32 %conv6
376+
}
377+
378+
; Same as above but the table is a little off
379+
@tablei128b = internal unnamed_addr constant i128 16018378897745984667142068813250560256, align 16
380+
define i32 @cttz_i16_via_i128_incorrecttable(i16 noundef %x) {
381+
; CHECK-LABEL: @cttz_i16_via_i128_incorrecttable(
382+
; CHECK-NEXT: entry:
383+
; CHECK-NEXT: [[SUB:%.*]] = sub i16 0, [[X:%.*]]
384+
; CHECK-NEXT: [[AND:%.*]] = and i16 [[X]], [[SUB]]
385+
; CHECK-NEXT: [[MUL:%.*]] = mul i16 [[AND]], 2479
386+
; CHECK-NEXT: [[TMP0:%.*]] = lshr i16 [[MUL]], 12
387+
; CHECK-NEXT: [[IDXPROM:%.*]] = zext nneg i16 [[TMP0]] to i64
388+
; CHECK-NEXT: [[ARRAYIDX:%.*]] = getelementptr inbounds nuw i8, ptr @tablei128b, i64 [[IDXPROM]]
389+
; CHECK-NEXT: [[TMP3:%.*]] = load i8, ptr [[ARRAYIDX]], align 1
390+
; CHECK-NEXT: [[CONV6:%.*]] = zext i8 [[TMP3]] to i32
391+
; CHECK-NEXT: ret i32 [[CONV6]]
392+
;
393+
entry:
394+
%sub = sub i16 0, %x
395+
%and = and i16 %x, %sub
396+
%mul = mul i16 %and, 2479
397+
%0 = lshr i16 %mul, 12
398+
%idxprom = zext nneg i16 %0 to i64
399+
%arrayidx = getelementptr inbounds nuw i8, ptr @tablei128b, i64 %idxprom
400+
%1 = load i8, ptr %arrayidx, align 1
401+
%conv6 = zext i8 %1 to i32
402+
ret i32 %conv6
403+
}

0 commit comments

Comments
 (0)