Skip to content

Commit 02807e3

Browse files
committed
3 input handled via truth table
1 parent 3a55b19 commit 02807e3

File tree

1 file changed

+202
-53
lines changed

1 file changed

+202
-53
lines changed

llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp

Lines changed: 202 additions & 53 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,8 @@
1919
#include "llvm/IR/PatternMatch.h"
2020
#include "llvm/Transforms/InstCombine/InstCombiner.h"
2121
#include "llvm/Transforms/Utils/Local.h"
22+
#include <bitset>
23+
#include <map>
2224

2325
using namespace llvm;
2426
using namespace PatternMatch;
@@ -47,6 +49,202 @@ static Value *getFCmpValue(unsigned Code, Value *LHS, Value *RHS,
4749
return Builder.CreateFCmpFMF(NewPred, LHS, RHS, FMF);
4850
}
4951

52+
/// This is to create optimal 3-variable boolean logic from truth tables.
53+
/// currently it supports the cases pertaining to the issue 97044. More cases
54+
/// can be added based on real-world justification for specific 3 input cases
55+
/// or with reviewer approval all 256 cases can be added (choose the
56+
/// canonicalizations found
57+
/// in x86InstCombine.cpp?)
58+
static Value *createLogicFromTable3Var(const std::bitset<8> &Table, Value *Op0,
59+
Value *Op1, Value *Op2, Value *Root,
60+
IRBuilderBase &Builder, bool HasOneUse) {
61+
uint8_t TruthValue = Table.to_ulong();
62+
63+
// Skip transformation if expression is already simple (at most 2 levels
64+
// deep).
65+
if (Root->hasOneUse() && isa<BinaryOperator>(Root)) {
66+
if (auto *BO = dyn_cast<BinaryOperator>(Root)) {
67+
bool IsSimple = !isa<BinaryOperator>(BO->getOperand(0)) ||
68+
!isa<BinaryOperator>(BO->getOperand(1));
69+
if (IsSimple)
70+
return nullptr;
71+
}
72+
}
73+
74+
auto FoldConstant = [&](bool Val) {
75+
Constant *Res = Val ? Builder.getTrue() : Builder.getFalse();
76+
if (Op0->getType()->isVectorTy())
77+
Res = ConstantVector::getSplat(
78+
cast<VectorType>(Op0->getType())->getElementCount(), Res);
79+
return Res;
80+
};
81+
82+
Value *Result = nullptr;
83+
switch (TruthValue) {
84+
default:
85+
return nullptr;
86+
87+
case 0x00: // Always FALSE
88+
Result = FoldConstant(false);
89+
break;
90+
91+
case 0xFF: // Always TRUE
92+
Result = FoldConstant(true);
93+
break;
94+
95+
case 0xE1: // ~((Op1 | Op2) ^ Op0)
96+
if (!HasOneUse)
97+
return nullptr;
98+
{
99+
Value *Or = Builder.CreateOr(Op1, Op2);
100+
Value *Xor = Builder.CreateXor(Or, Op0);
101+
Result = Builder.CreateNot(Xor);
102+
}
103+
break;
104+
105+
case 0x60: // Op0 & (Op1 ^ Op2)
106+
if (!HasOneUse)
107+
return nullptr;
108+
{
109+
Value *Xor = Builder.CreateXor(Op1, Op2);
110+
Result = Builder.CreateAnd(Op0, Xor);
111+
}
112+
break;
113+
114+
case 0xD2: // ((Op1 | Op2) ^ Op0) ^ Op1
115+
if (!HasOneUse)
116+
return nullptr;
117+
{
118+
Value *Or = Builder.CreateOr(Op1, Op2);
119+
Value *Xor1 = Builder.CreateXor(Or, Op0);
120+
Result = Builder.CreateXor(Xor1, Op1);
121+
}
122+
break;
123+
}
124+
125+
return Result;
126+
}
127+
128+
static std::tuple<Value *, Value *, Value *>
129+
extractThreeVariables(Value *Root) {
130+
std::set<Value *> Variables;
131+
unsigned NodeCount = 0;
132+
const unsigned MaxNodes =
133+
50; // To prevent exponential blowup (see bitwise-hang.ll)
134+
135+
std::function<void(Value *)> Collect = [&](Value *V) {
136+
if (++NodeCount > MaxNodes)
137+
return;
138+
139+
Value *NotV;
140+
if (match(V, m_Not(m_Value(NotV)))) {
141+
Collect(NotV);
142+
return;
143+
}
144+
if (auto *BO = dyn_cast<BinaryOperator>(V)) {
145+
Collect(BO->getOperand(0));
146+
Collect(BO->getOperand(1));
147+
} else if (isa<Argument>(V) || isa<Instruction>(V)) {
148+
if (!isa<Constant>(V) && V != Root) {
149+
Variables.insert(V);
150+
}
151+
}
152+
};
153+
154+
Collect(Root);
155+
156+
// Bail if we hit the node limit
157+
if (NodeCount > MaxNodes)
158+
return {nullptr, nullptr, nullptr};
159+
160+
if (Variables.size() == 3) {
161+
auto It = Variables.begin();
162+
Value *Op0 = *It++;
163+
Value *Op1 = *It++;
164+
Value *Op2 = *It;
165+
return {Op0, Op1, Op2};
166+
}
167+
return {nullptr, nullptr, nullptr};
168+
}
169+
170+
/// Evaluate a boolean expression with concrete variable values.
171+
static std::optional<bool>
172+
evaluateBooleanExpression(Value *Expr, const std::map<Value *, bool> &Values) {
173+
if (auto It = Values.find(Expr); It != Values.end()) {
174+
return It->second;
175+
}
176+
Value *NotExpr;
177+
if (match(Expr, m_Not(m_Value(NotExpr)))) {
178+
auto Operand = evaluateBooleanExpression(NotExpr, Values);
179+
if (Operand)
180+
return !*Operand;
181+
return std::nullopt;
182+
}
183+
if (auto *BO = dyn_cast<BinaryOperator>(Expr)) {
184+
auto LHS = evaluateBooleanExpression(BO->getOperand(0), Values);
185+
auto RHS = evaluateBooleanExpression(BO->getOperand(1), Values);
186+
if (!LHS || !RHS)
187+
return std::nullopt;
188+
189+
switch (BO->getOpcode()) {
190+
case Instruction::And:
191+
return *LHS && *RHS;
192+
case Instruction::Or:
193+
return *LHS || *RHS;
194+
case Instruction::Xor:
195+
return *LHS != *RHS;
196+
default:
197+
return std::nullopt;
198+
}
199+
}
200+
return std::nullopt;
201+
}
202+
203+
/// Extracts the truth table from a 3-variable boolean expression.
204+
/// The truth table is a 8-bit integer where each bit corresponds to a possible
205+
/// combination of the three variables.
206+
/// The bits are ordered as follows:
207+
/// 000, 001, 010, 011, 100, 101, 110, 111
208+
/// The result is a bitset where the i-th bit is set if the expression is true
209+
/// for the i-th combination of the variables.
210+
static std::optional<std::bitset<8>>
211+
extractThreeBitTruthTable(Value *Expr, Value *Op0, Value *Op1, Value *Op2) {
212+
std::bitset<8> Table;
213+
for (int I = 0; I < 8; I++) {
214+
bool Val0 = (I >> 2) & 1;
215+
bool Val1 = (I >> 1) & 1;
216+
bool Val2 = I & 1;
217+
std::map<Value *, bool> Values = {{Op0, Val0}, {Op1, Val1}, {Op2, Val2}};
218+
auto Result = evaluateBooleanExpression(Expr, Values);
219+
if (!Result)
220+
return std::nullopt;
221+
Table[I] = *Result;
222+
}
223+
return Table;
224+
}
225+
226+
/// Try to canonicalize 3-variable boolean expressions using truth table lookup.
227+
static Value *foldThreeVarBoolExpr(Value *Root,
228+
InstCombiner::BuilderTy &Builder) {
229+
// Only proceed if this is a "complex" expression.
230+
if (!isa<BinaryOperator>(Root))
231+
return nullptr;
232+
233+
auto [Op0, Op1, Op2] = extractThreeVariables(Root);
234+
if (!Op0 || !Op1 || !Op2)
235+
return nullptr;
236+
237+
auto Table = extractThreeBitTruthTable(Root, Op0, Op1, Op2);
238+
if (!Table)
239+
return nullptr;
240+
241+
// Only transform expressions with single use to avoid code growth.
242+
if (!Root->hasOneUse())
243+
return nullptr;
244+
245+
return createLogicFromTable3Var(*Table, Op0, Op1, Op2, Root, Builder, true);
246+
}
247+
50248
/// Emit a computation of: (V >= Lo && V < Hi) if Inside is true, otherwise
51249
/// (V < Lo || V >= Hi). This method expects that Lo < Hi. IsSigned indicates
52250
/// whether to treat V, Lo, and Hi as signed or not.
@@ -3777,41 +3975,8 @@ Instruction *InstCombinerImpl::visitOr(BinaryOperator &I) {
37773975

37783976
Value *Op0 = I.getOperand(0), *Op1 = I.getOperand(1);
37793977

3780-
// ((X & Y & ~Z) | (X & ~Y & Z) | (~X & ~Y &~Z) | (X & Y &Z)) -> ~((Y | Z) ^
3781-
// X)
3782-
{
3783-
Value *X, *Y, *Z;
3784-
Value *Term1, *Term2, *XAndYAndZ;
3785-
if (match(&I,
3786-
m_Or(m_Or(m_Value(Term1), m_Value(Term2)), m_Value(XAndYAndZ))) &&
3787-
match(XAndYAndZ, m_And(m_And(m_Value(X), m_Value(Y)), m_Value(Z)))) {
3788-
Value *YOrZ = Builder.CreateOr(Y, Z);
3789-
Value *YOrZXorX = Builder.CreateXor(YOrZ, X);
3790-
return BinaryOperator::CreateNot(YOrZXorX);
3791-
}
3792-
}
3793-
3794-
// (Z & X) | ~((Y ^ X) | Z) -> ~((Y | Z) ^ X)
3795-
{
3796-
Value *X, *Y, *Z;
3797-
Value *ZAndX, *NotPattern;
3798-
3799-
if (match(&I, m_c_Or(m_Value(ZAndX), m_Value(NotPattern))) &&
3800-
match(ZAndX, m_c_And(m_Value(Z), m_Value(X)))) {
3801-
3802-
Value *YXorXOrZ;
3803-
if (match(NotPattern, m_Not(m_Value(YXorXOrZ)))) {
3804-
Value *YXorX;
3805-
if (match(YXorXOrZ, m_c_Or(m_Value(YXorX), m_Specific(Z))) &&
3806-
match(YXorX, m_c_Xor(m_Value(Y), m_Specific(X)))) {
3807-
3808-
Value *YOrZ = Builder.CreateOr(Y, Z);
3809-
Value *YOrZXorX = Builder.CreateXor(YOrZ, X);
3810-
return BinaryOperator::CreateNot(YOrZXorX);
3811-
}
3812-
}
3813-
}
3814-
}
3978+
if (Value *Canonical = foldThreeVarBoolExpr(&I, Builder))
3979+
return replaceInstUsesWith(I, Canonical);
38153980

38163981
Type *Ty = I.getType();
38173982
if (Ty->isIntOrIntVectorTy(1)) {
@@ -5219,24 +5384,8 @@ Instruction *InstCombinerImpl::visitXor(BinaryOperator &I) {
52195384
}
52205385
}
52215386

5222-
// ((X & Y) | (~X & ~Y)) ^ (Z & (((X & Y) | (~X & ~Y)) ^ ((X & Y) | (X &
5223-
// ~Y)))) -> ~((Y | Z) ^ X)
5224-
if (match(Op1, m_AllOnes())) {
5225-
Value *X, *Y, *Z;
5226-
Value *XorWithY;
5227-
if (match(Op0, m_Xor(m_Value(XorWithY), m_Value(Y)))) {
5228-
Value *ZAndNotY;
5229-
if (match(XorWithY, m_Xor(m_Value(X), m_Value(ZAndNotY)))) {
5230-
Value *NotY;
5231-
if (match(ZAndNotY, m_And(m_Value(Z), m_Value(NotY))) &&
5232-
match(NotY, m_Not(m_Specific(Y)))) {
5233-
Value *YOrZ = Builder.CreateOr(Y, Z);
5234-
Value *YOrZXorX = Builder.CreateXor(YOrZ, X);
5235-
return BinaryOperator::CreateNot(YOrZXorX);
5236-
}
5237-
}
5238-
}
5239-
}
5387+
if (Value *Canonical = foldThreeVarBoolExpr(&I, Builder))
5388+
return replaceInstUsesWith(I, Canonical);
52405389

52415390
if (auto *LHS = dyn_cast<ICmpInst>(I.getOperand(0)))
52425391
if (auto *RHS = dyn_cast<ICmpInst>(I.getOperand(1)))

0 commit comments

Comments
 (0)