Skip to content

Commit db43e0a

Browse files
committed
3 input handled via truth table
1 parent 3a55b19 commit db43e0a

File tree

1 file changed

+201
-54
lines changed

1 file changed

+201
-54
lines changed

llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp

Lines changed: 201 additions & 54 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,8 @@
1919
#include "llvm/IR/PatternMatch.h"
2020
#include "llvm/Transforms/InstCombine/InstCombiner.h"
2121
#include "llvm/Transforms/Utils/Local.h"
22+
#include <bitset>
23+
#include <map>
2224

2325
using namespace llvm;
2426
using namespace PatternMatch;
@@ -47,6 +49,200 @@ static Value *getFCmpValue(unsigned Code, Value *LHS, Value *RHS,
4749
return Builder.CreateFCmpFMF(NewPred, LHS, RHS, FMF);
4850
}
4951

52+
/// This is to create optimal 3-variable boolean logic from truth tables.
53+
/// currently it supports the cases pertaining to the issue 97044. More cases can be added
54+
/// based on real-world justification for specific 3 input cases
55+
/// or with reviewer approval all 256 cases can be added (choose the canonicalizations found
56+
/// in x86InstCombine.cpp?)
57+
static Value *createLogicFromTable3Var(const std::bitset<8> &Table, Value *Op0,
58+
Value *Op1, Value *Op2, Value *Root,
59+
IRBuilderBase &Builder, bool HasOneUse) {
60+
uint8_t TruthValue = Table.to_ulong();
61+
62+
// Skip transformation if expression is already simple (at most 2 levels
63+
// deep).
64+
if (Root->hasOneUse() && isa<BinaryOperator>(Root)) {
65+
if (auto *BO = dyn_cast<BinaryOperator>(Root)) {
66+
bool IsSimple = !isa<BinaryOperator>(BO->getOperand(0)) ||
67+
!isa<BinaryOperator>(BO->getOperand(1));
68+
if (IsSimple)
69+
return nullptr;
70+
}
71+
}
72+
73+
auto FoldConstant = [&](bool Val) {
74+
Constant *Res = Val ? Builder.getTrue() : Builder.getFalse();
75+
if (Op0->getType()->isVectorTy())
76+
Res = ConstantVector::getSplat(
77+
cast<VectorType>(Op0->getType())->getElementCount(), Res);
78+
return Res;
79+
};
80+
81+
Value *Result = nullptr;
82+
switch (TruthValue) {
83+
default:
84+
return nullptr;
85+
86+
case 0x00: // Always FALSE
87+
Result = FoldConstant(false);
88+
break;
89+
90+
case 0xFF: // Always TRUE
91+
Result = FoldConstant(true);
92+
break;
93+
94+
case 0xE1: // ~((Op1 | Op2) ^ Op0)
95+
if (!HasOneUse)
96+
return nullptr;
97+
{
98+
Value *Or = Builder.CreateOr(Op1, Op2);
99+
Value *Xor = Builder.CreateXor(Or, Op0);
100+
Result = Builder.CreateNot(Xor);
101+
}
102+
break;
103+
104+
case 0x60: // Op0 & (Op1 ^ Op2)
105+
if (!HasOneUse)
106+
return nullptr;
107+
{
108+
Value *Xor = Builder.CreateXor(Op1, Op2);
109+
Result = Builder.CreateAnd(Op0, Xor);
110+
}
111+
break;
112+
113+
case 0xD2: // ((Op1 | Op2) ^ Op0) ^ Op1
114+
if (!HasOneUse)
115+
return nullptr;
116+
{
117+
Value *Or = Builder.CreateOr(Op1, Op2);
118+
Value *Xor1 = Builder.CreateXor(Or, Op0);
119+
Result = Builder.CreateXor(Xor1, Op1);
120+
}
121+
break;
122+
}
123+
124+
return Result;
125+
}
126+
127+
static std::tuple<Value *, Value *, Value *>
128+
extractThreeVariables(Value *Root) {
129+
std::set<Value *> Variables;
130+
unsigned NodeCount = 0;
131+
const unsigned MaxNodes = 50; // To prevent exponential blowup (see bitwise-hang.ll)
132+
133+
std::function<void(Value *)> Collect = [&](Value *V) {
134+
if (++NodeCount > MaxNodes)
135+
return;
136+
137+
Value *NotV;
138+
if (match(V, m_Not(m_Value(NotV)))) {
139+
Collect(NotV);
140+
return;
141+
}
142+
if (auto *BO = dyn_cast<BinaryOperator>(V)) {
143+
Collect(BO->getOperand(0));
144+
Collect(BO->getOperand(1));
145+
} else if (isa<Argument>(V) || isa<Instruction>(V)) {
146+
if (!isa<Constant>(V) && V != Root) {
147+
Variables.insert(V);
148+
}
149+
}
150+
};
151+
152+
Collect(Root);
153+
154+
// Bail if we hit the node limit
155+
if (NodeCount > MaxNodes)
156+
return {nullptr, nullptr, nullptr};
157+
158+
if (Variables.size() == 3) {
159+
auto It = Variables.begin();
160+
Value *Op0 = *It++;
161+
Value *Op1 = *It++;
162+
Value *Op2 = *It;
163+
return {Op0, Op1, Op2};
164+
}
165+
return {nullptr, nullptr, nullptr};
166+
}
167+
168+
/// Evaluate a boolean expression with concrete variable values.
169+
static std::optional<bool>
170+
evaluateBooleanExpression(Value *Expr, const std::map<Value *, bool> &Values) {
171+
if (auto It = Values.find(Expr); It != Values.end()) {
172+
return It->second;
173+
}
174+
Value *NotExpr;
175+
if (match(Expr, m_Not(m_Value(NotExpr)))) {
176+
auto Operand = evaluateBooleanExpression(NotExpr, Values);
177+
if (Operand)
178+
return !*Operand;
179+
return std::nullopt;
180+
}
181+
if (auto *BO = dyn_cast<BinaryOperator>(Expr)) {
182+
auto LHS = evaluateBooleanExpression(BO->getOperand(0), Values);
183+
auto RHS = evaluateBooleanExpression(BO->getOperand(1), Values);
184+
if (!LHS || !RHS)
185+
return std::nullopt;
186+
187+
switch (BO->getOpcode()) {
188+
case Instruction::And:
189+
return *LHS && *RHS;
190+
case Instruction::Or:
191+
return *LHS || *RHS;
192+
case Instruction::Xor:
193+
return *LHS != *RHS;
194+
default:
195+
return std::nullopt;
196+
}
197+
}
198+
return std::nullopt;
199+
}
200+
201+
/// Extracts the truth table from a 3-variable boolean expression.
202+
/// The truth table is a 8-bit integer where each bit corresponds to a possible
203+
/// combination of the three variables.
204+
/// The bits are ordered as follows:
205+
/// 000, 001, 010, 011, 100, 101, 110, 111
206+
/// The result is a bitset where the i-th bit is set if the expression is true
207+
/// for the i-th combination of the variables.
208+
static std::optional<std::bitset<8>> extractThreeBitTruthTable(Value *Expr, Value *Op0,
209+
Value *Op1, Value *Op2) {
210+
std::bitset<8> Table;
211+
for (int I = 0; I < 8; I++) {
212+
bool Val0 = (I >> 2) & 1;
213+
bool Val1 = (I >> 1) & 1;
214+
bool Val2 = I & 1;
215+
std::map<Value *, bool> Values = {{Op0, Val0}, {Op1, Val1}, {Op2, Val2}};
216+
auto Result = evaluateBooleanExpression(Expr, Values);
217+
if (!Result)
218+
return std::nullopt;
219+
Table[I] = *Result;
220+
}
221+
return Table;
222+
}
223+
224+
/// Try to canonicalize 3-variable boolean expressions using truth table lookup.
225+
static Value *foldThreeVarBoolExpr(Value *Root,
226+
InstCombiner::BuilderTy &Builder) {
227+
// Only proceed if this is a "complex" expression.
228+
if (!isa<BinaryOperator>(Root))
229+
return nullptr;
230+
231+
auto [Op0, Op1, Op2] = extractThreeVariables(Root);
232+
if (!Op0 || !Op1 || !Op2)
233+
return nullptr;
234+
235+
auto Table = extractThreeBitTruthTable(Root, Op0, Op1, Op2);
236+
if (!Table)
237+
return nullptr;
238+
239+
// Only transform expressions with single use to avoid code growth.
240+
if (!Root->hasOneUse())
241+
return nullptr;
242+
243+
return createLogicFromTable3Var(*Table, Op0, Op1, Op2, Root, Builder, true);
244+
}
245+
50246
/// Emit a computation of: (V >= Lo && V < Hi) if Inside is true, otherwise
51247
/// (V < Lo || V >= Hi). This method expects that Lo < Hi. IsSigned indicates
52248
/// whether to treat V, Lo, and Hi as signed or not.
@@ -3777,41 +3973,8 @@ Instruction *InstCombinerImpl::visitOr(BinaryOperator &I) {
37773973

37783974
Value *Op0 = I.getOperand(0), *Op1 = I.getOperand(1);
37793975

3780-
// ((X & Y & ~Z) | (X & ~Y & Z) | (~X & ~Y &~Z) | (X & Y &Z)) -> ~((Y | Z) ^
3781-
// X)
3782-
{
3783-
Value *X, *Y, *Z;
3784-
Value *Term1, *Term2, *XAndYAndZ;
3785-
if (match(&I,
3786-
m_Or(m_Or(m_Value(Term1), m_Value(Term2)), m_Value(XAndYAndZ))) &&
3787-
match(XAndYAndZ, m_And(m_And(m_Value(X), m_Value(Y)), m_Value(Z)))) {
3788-
Value *YOrZ = Builder.CreateOr(Y, Z);
3789-
Value *YOrZXorX = Builder.CreateXor(YOrZ, X);
3790-
return BinaryOperator::CreateNot(YOrZXorX);
3791-
}
3792-
}
3793-
3794-
// (Z & X) | ~((Y ^ X) | Z) -> ~((Y | Z) ^ X)
3795-
{
3796-
Value *X, *Y, *Z;
3797-
Value *ZAndX, *NotPattern;
3798-
3799-
if (match(&I, m_c_Or(m_Value(ZAndX), m_Value(NotPattern))) &&
3800-
match(ZAndX, m_c_And(m_Value(Z), m_Value(X)))) {
3801-
3802-
Value *YXorXOrZ;
3803-
if (match(NotPattern, m_Not(m_Value(YXorXOrZ)))) {
3804-
Value *YXorX;
3805-
if (match(YXorXOrZ, m_c_Or(m_Value(YXorX), m_Specific(Z))) &&
3806-
match(YXorX, m_c_Xor(m_Value(Y), m_Specific(X)))) {
3807-
3808-
Value *YOrZ = Builder.CreateOr(Y, Z);
3809-
Value *YOrZXorX = Builder.CreateXor(YOrZ, X);
3810-
return BinaryOperator::CreateNot(YOrZXorX);
3811-
}
3812-
}
3813-
}
3814-
}
3976+
if (Value *Canonical = foldThreeVarBoolExpr(&I, Builder))
3977+
return replaceInstUsesWith(I, Canonical);
38153978

38163979
Type *Ty = I.getType();
38173980
if (Ty->isIntOrIntVectorTy(1)) {
@@ -5218,25 +5381,9 @@ Instruction *InstCombinerImpl::visitXor(BinaryOperator &I) {
52185381
return SelectInst::Create(A, NotB, C);
52195382
}
52205383
}
5221-
5222-
// ((X & Y) | (~X & ~Y)) ^ (Z & (((X & Y) | (~X & ~Y)) ^ ((X & Y) | (X &
5223-
// ~Y)))) -> ~((Y | Z) ^ X)
5224-
if (match(Op1, m_AllOnes())) {
5225-
Value *X, *Y, *Z;
5226-
Value *XorWithY;
5227-
if (match(Op0, m_Xor(m_Value(XorWithY), m_Value(Y)))) {
5228-
Value *ZAndNotY;
5229-
if (match(XorWithY, m_Xor(m_Value(X), m_Value(ZAndNotY)))) {
5230-
Value *NotY;
5231-
if (match(ZAndNotY, m_And(m_Value(Z), m_Value(NotY))) &&
5232-
match(NotY, m_Not(m_Specific(Y)))) {
5233-
Value *YOrZ = Builder.CreateOr(Y, Z);
5234-
Value *YOrZXorX = Builder.CreateXor(YOrZ, X);
5235-
return BinaryOperator::CreateNot(YOrZXorX);
5236-
}
5237-
}
5238-
}
5239-
}
5384+
5385+
if (Value *Canonical = foldThreeVarBoolExpr(&I, Builder))
5386+
return replaceInstUsesWith(I, Canonical);
52405387

52415388
if (auto *LHS = dyn_cast<ICmpInst>(I.getOperand(0)))
52425389
if (auto *RHS = dyn_cast<ICmpInst>(I.getOperand(1)))

0 commit comments

Comments
 (0)