|
19 | 19 | #include "llvm/IR/PatternMatch.h"
|
20 | 20 | #include "llvm/Transforms/InstCombine/InstCombiner.h"
|
21 | 21 | #include "llvm/Transforms/Utils/Local.h"
|
| 22 | +#include <bitset> |
| 23 | +#include <map> |
22 | 24 |
|
23 | 25 | using namespace llvm;
|
24 | 26 | using namespace PatternMatch;
|
@@ -47,6 +49,200 @@ static Value *getFCmpValue(unsigned Code, Value *LHS, Value *RHS,
|
47 | 49 | return Builder.CreateFCmpFMF(NewPred, LHS, RHS, FMF);
|
48 | 50 | }
|
49 | 51 |
|
| 52 | +/// This is to create optimal 3-variable boolean logic from truth tables. |
| 53 | +/// currently it supports the cases pertaining to the issue 97044. More cases can be added |
| 54 | +/// based on real-world justification for specific 3 input cases |
| 55 | +/// or with reviewer approval all 256 cases can be added (choose the canonicalizations found |
| 56 | +/// in x86InstCombine.cpp?) |
| 57 | +static Value *createLogicFromTable3Var(const std::bitset<8> &Table, Value *Op0, |
| 58 | + Value *Op1, Value *Op2, Value *Root, |
| 59 | + IRBuilderBase &Builder, bool HasOneUse) { |
| 60 | + uint8_t TruthValue = Table.to_ulong(); |
| 61 | + |
| 62 | + // Skip transformation if expression is already simple (at most 2 levels |
| 63 | + // deep). |
| 64 | + if (Root->hasOneUse() && isa<BinaryOperator>(Root)) { |
| 65 | + if (auto *BO = dyn_cast<BinaryOperator>(Root)) { |
| 66 | + bool IsSimple = !isa<BinaryOperator>(BO->getOperand(0)) || |
| 67 | + !isa<BinaryOperator>(BO->getOperand(1)); |
| 68 | + if (IsSimple) |
| 69 | + return nullptr; |
| 70 | + } |
| 71 | + } |
| 72 | + |
| 73 | + auto FoldConstant = [&](bool Val) { |
| 74 | + Constant *Res = Val ? Builder.getTrue() : Builder.getFalse(); |
| 75 | + if (Op0->getType()->isVectorTy()) |
| 76 | + Res = ConstantVector::getSplat( |
| 77 | + cast<VectorType>(Op0->getType())->getElementCount(), Res); |
| 78 | + return Res; |
| 79 | + }; |
| 80 | + |
| 81 | + Value *Result = nullptr; |
| 82 | + switch (TruthValue) { |
| 83 | + default: |
| 84 | + return nullptr; |
| 85 | + |
| 86 | + case 0x00: // Always FALSE |
| 87 | + Result = FoldConstant(false); |
| 88 | + break; |
| 89 | + |
| 90 | + case 0xFF: // Always TRUE |
| 91 | + Result = FoldConstant(true); |
| 92 | + break; |
| 93 | + |
| 94 | + case 0xE1: // ~((Op1 | Op2) ^ Op0) |
| 95 | + if (!HasOneUse) |
| 96 | + return nullptr; |
| 97 | + { |
| 98 | + Value *Or = Builder.CreateOr(Op1, Op2); |
| 99 | + Value *Xor = Builder.CreateXor(Or, Op0); |
| 100 | + Result = Builder.CreateNot(Xor); |
| 101 | + } |
| 102 | + break; |
| 103 | + |
| 104 | + case 0x60: // Op0 & (Op1 ^ Op2) |
| 105 | + if (!HasOneUse) |
| 106 | + return nullptr; |
| 107 | + { |
| 108 | + Value *Xor = Builder.CreateXor(Op1, Op2); |
| 109 | + Result = Builder.CreateAnd(Op0, Xor); |
| 110 | + } |
| 111 | + break; |
| 112 | + |
| 113 | + case 0xD2: // ((Op1 | Op2) ^ Op0) ^ Op1 |
| 114 | + if (!HasOneUse) |
| 115 | + return nullptr; |
| 116 | + { |
| 117 | + Value *Or = Builder.CreateOr(Op1, Op2); |
| 118 | + Value *Xor1 = Builder.CreateXor(Or, Op0); |
| 119 | + Result = Builder.CreateXor(Xor1, Op1); |
| 120 | + } |
| 121 | + break; |
| 122 | + } |
| 123 | + |
| 124 | + return Result; |
| 125 | +} |
| 126 | + |
| 127 | +static std::tuple<Value *, Value *, Value *> |
| 128 | +extractThreeVariables(Value *Root) { |
| 129 | + std::set<Value *> Variables; |
| 130 | + unsigned NodeCount = 0; |
| 131 | + const unsigned MaxNodes = 50; // To prevent exponential blowup (see bitwise-hang.ll) |
| 132 | + |
| 133 | + std::function<void(Value *)> Collect = [&](Value *V) { |
| 134 | + if (++NodeCount > MaxNodes) |
| 135 | + return; |
| 136 | + |
| 137 | + Value *NotV; |
| 138 | + if (match(V, m_Not(m_Value(NotV)))) { |
| 139 | + Collect(NotV); |
| 140 | + return; |
| 141 | + } |
| 142 | + if (auto *BO = dyn_cast<BinaryOperator>(V)) { |
| 143 | + Collect(BO->getOperand(0)); |
| 144 | + Collect(BO->getOperand(1)); |
| 145 | + } else if (isa<Argument>(V) || isa<Instruction>(V)) { |
| 146 | + if (!isa<Constant>(V) && V != Root) { |
| 147 | + Variables.insert(V); |
| 148 | + } |
| 149 | + } |
| 150 | + }; |
| 151 | + |
| 152 | + Collect(Root); |
| 153 | + |
| 154 | + // Bail if we hit the node limit |
| 155 | + if (NodeCount > MaxNodes) |
| 156 | + return {nullptr, nullptr, nullptr}; |
| 157 | + |
| 158 | + if (Variables.size() == 3) { |
| 159 | + auto It = Variables.begin(); |
| 160 | + Value *Op0 = *It++; |
| 161 | + Value *Op1 = *It++; |
| 162 | + Value *Op2 = *It; |
| 163 | + return {Op0, Op1, Op2}; |
| 164 | + } |
| 165 | + return {nullptr, nullptr, nullptr}; |
| 166 | +} |
| 167 | + |
| 168 | +/// Evaluate a boolean expression with concrete variable values. |
| 169 | +static std::optional<bool> |
| 170 | +evaluateBooleanExpression(Value *Expr, const std::map<Value *, bool> &Values) { |
| 171 | + if (auto It = Values.find(Expr); It != Values.end()) { |
| 172 | + return It->second; |
| 173 | + } |
| 174 | + Value *NotExpr; |
| 175 | + if (match(Expr, m_Not(m_Value(NotExpr)))) { |
| 176 | + auto Operand = evaluateBooleanExpression(NotExpr, Values); |
| 177 | + if (Operand) |
| 178 | + return !*Operand; |
| 179 | + return std::nullopt; |
| 180 | + } |
| 181 | + if (auto *BO = dyn_cast<BinaryOperator>(Expr)) { |
| 182 | + auto LHS = evaluateBooleanExpression(BO->getOperand(0), Values); |
| 183 | + auto RHS = evaluateBooleanExpression(BO->getOperand(1), Values); |
| 184 | + if (!LHS || !RHS) |
| 185 | + return std::nullopt; |
| 186 | + |
| 187 | + switch (BO->getOpcode()) { |
| 188 | + case Instruction::And: |
| 189 | + return *LHS && *RHS; |
| 190 | + case Instruction::Or: |
| 191 | + return *LHS || *RHS; |
| 192 | + case Instruction::Xor: |
| 193 | + return *LHS != *RHS; |
| 194 | + default: |
| 195 | + return std::nullopt; |
| 196 | + } |
| 197 | + } |
| 198 | + return std::nullopt; |
| 199 | +} |
| 200 | + |
| 201 | +/// Extracts the truth table from a 3-variable boolean expression. |
| 202 | +/// The truth table is a 8-bit integer where each bit corresponds to a possible |
| 203 | +/// combination of the three variables. |
| 204 | +/// The bits are ordered as follows: |
| 205 | +/// 000, 001, 010, 011, 100, 101, 110, 111 |
| 206 | +/// The result is a bitset where the i-th bit is set if the expression is true |
| 207 | +/// for the i-th combination of the variables. |
| 208 | +static std::optional<std::bitset<8>> extractThreeBitTruthTable(Value *Expr, Value *Op0, |
| 209 | + Value *Op1, Value *Op2) { |
| 210 | + std::bitset<8> Table; |
| 211 | + for (int I = 0; I < 8; I++) { |
| 212 | + bool Val0 = (I >> 2) & 1; |
| 213 | + bool Val1 = (I >> 1) & 1; |
| 214 | + bool Val2 = I & 1; |
| 215 | + std::map<Value *, bool> Values = {{Op0, Val0}, {Op1, Val1}, {Op2, Val2}}; |
| 216 | + auto Result = evaluateBooleanExpression(Expr, Values); |
| 217 | + if (!Result) |
| 218 | + return std::nullopt; |
| 219 | + Table[I] = *Result; |
| 220 | + } |
| 221 | + return Table; |
| 222 | +} |
| 223 | + |
| 224 | +/// Try to canonicalize 3-variable boolean expressions using truth table lookup. |
| 225 | +static Value *foldThreeVarBoolExpr(Value *Root, |
| 226 | + InstCombiner::BuilderTy &Builder) { |
| 227 | + // Only proceed if this is a "complex" expression. |
| 228 | + if (!isa<BinaryOperator>(Root)) |
| 229 | + return nullptr; |
| 230 | + |
| 231 | + auto [Op0, Op1, Op2] = extractThreeVariables(Root); |
| 232 | + if (!Op0 || !Op1 || !Op2) |
| 233 | + return nullptr; |
| 234 | + |
| 235 | + auto Table = extractThreeBitTruthTable(Root, Op0, Op1, Op2); |
| 236 | + if (!Table) |
| 237 | + return nullptr; |
| 238 | + |
| 239 | + // Only transform expressions with single use to avoid code growth. |
| 240 | + if (!Root->hasOneUse()) |
| 241 | + return nullptr; |
| 242 | + |
| 243 | + return createLogicFromTable3Var(*Table, Op0, Op1, Op2, Root, Builder, true); |
| 244 | +} |
| 245 | + |
50 | 246 | /// Emit a computation of: (V >= Lo && V < Hi) if Inside is true, otherwise
|
51 | 247 | /// (V < Lo || V >= Hi). This method expects that Lo < Hi. IsSigned indicates
|
52 | 248 | /// whether to treat V, Lo, and Hi as signed or not.
|
@@ -3777,41 +3973,8 @@ Instruction *InstCombinerImpl::visitOr(BinaryOperator &I) {
|
3777 | 3973 |
|
3778 | 3974 | Value *Op0 = I.getOperand(0), *Op1 = I.getOperand(1);
|
3779 | 3975 |
|
3780 |
| - // ((X & Y & ~Z) | (X & ~Y & Z) | (~X & ~Y &~Z) | (X & Y &Z)) -> ~((Y | Z) ^ |
3781 |
| - // X) |
3782 |
| - { |
3783 |
| - Value *X, *Y, *Z; |
3784 |
| - Value *Term1, *Term2, *XAndYAndZ; |
3785 |
| - if (match(&I, |
3786 |
| - m_Or(m_Or(m_Value(Term1), m_Value(Term2)), m_Value(XAndYAndZ))) && |
3787 |
| - match(XAndYAndZ, m_And(m_And(m_Value(X), m_Value(Y)), m_Value(Z)))) { |
3788 |
| - Value *YOrZ = Builder.CreateOr(Y, Z); |
3789 |
| - Value *YOrZXorX = Builder.CreateXor(YOrZ, X); |
3790 |
| - return BinaryOperator::CreateNot(YOrZXorX); |
3791 |
| - } |
3792 |
| - } |
3793 |
| - |
3794 |
| - // (Z & X) | ~((Y ^ X) | Z) -> ~((Y | Z) ^ X) |
3795 |
| - { |
3796 |
| - Value *X, *Y, *Z; |
3797 |
| - Value *ZAndX, *NotPattern; |
3798 |
| - |
3799 |
| - if (match(&I, m_c_Or(m_Value(ZAndX), m_Value(NotPattern))) && |
3800 |
| - match(ZAndX, m_c_And(m_Value(Z), m_Value(X)))) { |
3801 |
| - |
3802 |
| - Value *YXorXOrZ; |
3803 |
| - if (match(NotPattern, m_Not(m_Value(YXorXOrZ)))) { |
3804 |
| - Value *YXorX; |
3805 |
| - if (match(YXorXOrZ, m_c_Or(m_Value(YXorX), m_Specific(Z))) && |
3806 |
| - match(YXorX, m_c_Xor(m_Value(Y), m_Specific(X)))) { |
3807 |
| - |
3808 |
| - Value *YOrZ = Builder.CreateOr(Y, Z); |
3809 |
| - Value *YOrZXorX = Builder.CreateXor(YOrZ, X); |
3810 |
| - return BinaryOperator::CreateNot(YOrZXorX); |
3811 |
| - } |
3812 |
| - } |
3813 |
| - } |
3814 |
| - } |
| 3976 | + if (Value *Canonical = foldThreeVarBoolExpr(&I, Builder)) |
| 3977 | + return replaceInstUsesWith(I, Canonical); |
3815 | 3978 |
|
3816 | 3979 | Type *Ty = I.getType();
|
3817 | 3980 | if (Ty->isIntOrIntVectorTy(1)) {
|
@@ -5218,25 +5381,9 @@ Instruction *InstCombinerImpl::visitXor(BinaryOperator &I) {
|
5218 | 5381 | return SelectInst::Create(A, NotB, C);
|
5219 | 5382 | }
|
5220 | 5383 | }
|
5221 |
| - |
5222 |
| - // ((X & Y) | (~X & ~Y)) ^ (Z & (((X & Y) | (~X & ~Y)) ^ ((X & Y) | (X & |
5223 |
| - // ~Y)))) -> ~((Y | Z) ^ X) |
5224 |
| - if (match(Op1, m_AllOnes())) { |
5225 |
| - Value *X, *Y, *Z; |
5226 |
| - Value *XorWithY; |
5227 |
| - if (match(Op0, m_Xor(m_Value(XorWithY), m_Value(Y)))) { |
5228 |
| - Value *ZAndNotY; |
5229 |
| - if (match(XorWithY, m_Xor(m_Value(X), m_Value(ZAndNotY)))) { |
5230 |
| - Value *NotY; |
5231 |
| - if (match(ZAndNotY, m_And(m_Value(Z), m_Value(NotY))) && |
5232 |
| - match(NotY, m_Not(m_Specific(Y)))) { |
5233 |
| - Value *YOrZ = Builder.CreateOr(Y, Z); |
5234 |
| - Value *YOrZXorX = Builder.CreateXor(YOrZ, X); |
5235 |
| - return BinaryOperator::CreateNot(YOrZXorX); |
5236 |
| - } |
5237 |
| - } |
5238 |
| - } |
5239 |
| - } |
| 5384 | + |
| 5385 | + if (Value *Canonical = foldThreeVarBoolExpr(&I, Builder)) |
| 5386 | + return replaceInstUsesWith(I, Canonical); |
5240 | 5387 |
|
5241 | 5388 | if (auto *LHS = dyn_cast<ICmpInst>(I.getOperand(0)))
|
5242 | 5389 | if (auto *RHS = dyn_cast<ICmpInst>(I.getOperand(1)))
|
|
0 commit comments