Skip to content

Commit 4c86e54

Browse files
committed
removed recursion + smallptrset used
1 parent d066a85 commit 4c86e54

File tree

1 file changed

+79
-45
lines changed

1 file changed

+79
-45
lines changed

llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp

Lines changed: 79 additions & 45 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,6 @@
2020
#include "llvm/Transforms/InstCombine/InstCombiner.h"
2121
#include "llvm/Transforms/Utils/Local.h"
2222
#include <bitset>
23-
#include <map>
2423

2524
using namespace llvm;
2625
using namespace PatternMatch;
@@ -115,77 +114,109 @@ static Value *createLogicFromTable3Var(const std::bitset<8> &Table, Value *Op0,
115114

116115
static std::tuple<Value *, Value *, Value *>
117116
extractThreeVariables(Value *Root) {
118-
std::set<Value *> Variables;
117+
SmallPtrSet<Value *, 3> Variables;
119118
unsigned NodeCount = 0;
120-
const unsigned MaxNodes =
121-
50; // To prevent exponential blowup (see bitwise-hang.ll)
119+
const unsigned MaxNodes = 50; // To prevent exponential blowup with loop
120+
// unrolling(see bitreverse-hang.ll)
122121

123-
std::function<void(Value *)> Collect = [&](Value *V) {
124-
if (++NodeCount > MaxNodes)
125-
return;
122+
SmallVector<Value *> Worklist;
123+
Worklist.push_back(Root);
124+
125+
while (!Worklist.empty() && NodeCount <= MaxNodes) {
126+
Value *V = Worklist.pop_back_val();
127+
++NodeCount;
128+
129+
if (NodeCount > MaxNodes)
130+
break;
126131

127132
Value *NotV;
128133
if (match(V, m_Not(m_Value(NotV)))) {
129-
Collect(NotV);
130-
return;
134+
Worklist.push_back(NotV);
135+
continue;
131136
}
132137
if (auto *BO = dyn_cast<BinaryOperator>(V)) {
133-
Collect(BO->getOperand(0));
134-
Collect(BO->getOperand(1));
138+
Worklist.push_back(BO->getOperand(0));
139+
Worklist.push_back(BO->getOperand(1));
135140
} else if (isa<Argument>(V) || isa<Instruction>(V)) {
136141
if (!isa<Constant>(V) && V != Root) {
137142
Variables.insert(V);
138143
}
139144
}
140-
};
141-
142-
Collect(Root);
145+
}
143146

144147
// Bail if we hit the node limit
145148
if (NodeCount > MaxNodes)
146149
return {nullptr, nullptr, nullptr};
147150

148151
if (Variables.size() == 3) {
149-
auto It = Variables.begin();
150-
Value *Op0 = *It++;
151-
Value *Op1 = *It++;
152-
Value *Op2 = *It;
153-
return {Op0, Op1, Op2};
152+
// Sort variables by pointer value to ensure deterministic ordering
153+
SmallVector<Value *, 3> SortedVars(Variables.begin(), Variables.end());
154+
llvm::sort(SortedVars, [](Value *A, Value *B) { return A < B; });
155+
return {SortedVars[0], SortedVars[1], SortedVars[2]};
154156
}
155157
return {nullptr, nullptr, nullptr};
156158
}
157159

158160
/// Evaluate a boolean expression with concrete variable values.
159161
static std::optional<bool>
160-
evaluateBooleanExpression(Value *Expr, const std::map<Value *, bool> &Values) {
161-
if (auto It = Values.find(Expr); It != Values.end()) {
162-
return It->second;
163-
}
164-
Value *NotExpr;
165-
if (match(Expr, m_Not(m_Value(NotExpr)))) {
166-
auto Operand = evaluateBooleanExpression(NotExpr, Values);
167-
if (Operand)
168-
return !*Operand;
169-
return std::nullopt;
162+
evaluateBooleanExpression(Value *Expr,
163+
const SmallMapVector<Value *, bool, 4> &Values) {
164+
165+
// Post-order traversal of the expression tree
166+
SmallVector<Instruction *> Instructions;
167+
SmallVector<Value *> ToVisit;
168+
SmallPtrSet<Instruction *, 8> Seen;
169+
170+
ToVisit.push_back(Expr);
171+
while (!ToVisit.empty()) {
172+
Value *V = ToVisit.pop_back_val();
173+
if (auto *I = dyn_cast<Instruction>(V)) {
174+
if (Seen.insert(I).second) {
175+
Instructions.push_back(I);
176+
for (Value *Op : I->operands()) {
177+
ToVisit.push_back(Op);
178+
}
179+
}
180+
}
170181
}
171-
if (auto *BO = dyn_cast<BinaryOperator>(Expr)) {
172-
auto LHS = evaluateBooleanExpression(BO->getOperand(0), Values);
173-
auto RHS = evaluateBooleanExpression(BO->getOperand(1), Values);
174-
if (!LHS || !RHS)
175-
return std::nullopt;
176182

177-
switch (BO->getOpcode()) {
178-
case Instruction::And:
179-
return *LHS && *RHS;
180-
case Instruction::Or:
181-
return *LHS || *RHS;
182-
case Instruction::Xor:
183-
return *LHS != *RHS;
184-
default:
185-
return std::nullopt;
183+
llvm::sort(Instructions,
184+
[](Instruction *A, Instruction *B) { return A->comesBefore(B); });
185+
186+
// Now in topological order we can evaluate the expression
187+
SmallDenseMap<Value *, bool> Computed(Values.begin(), Values.end());
188+
189+
for (Instruction *I : Instructions) {
190+
Value *NotV;
191+
if (match(I, m_Not(m_Value(NotV)))) {
192+
auto It = Computed.find(NotV);
193+
if (It == Computed.end())
194+
return std::nullopt;
195+
Computed[I] = !It->second;
196+
} else if (auto *BO = dyn_cast<BinaryOperator>(I)) {
197+
auto LHSIt = Computed.find(BO->getOperand(0));
198+
auto RHSIt = Computed.find(BO->getOperand(1));
199+
if (LHSIt == Computed.end() || RHSIt == Computed.end())
200+
return std::nullopt;
201+
202+
switch (BO->getOpcode()) {
203+
case Instruction::And:
204+
Computed[I] = LHSIt->second && RHSIt->second;
205+
break;
206+
case Instruction::Or:
207+
Computed[I] = LHSIt->second || RHSIt->second;
208+
break;
209+
case Instruction::Xor:
210+
Computed[I] = LHSIt->second != RHSIt->second;
211+
break;
212+
default:
213+
return std::nullopt;
214+
}
186215
}
187216
}
188-
return std::nullopt;
217+
218+
auto It = Computed.find(Expr);
219+
return It != Computed.end() ? std::optional<bool>(It->second) : std::nullopt;
189220
}
190221

191222
/// Extracts the truth table from a 3-variable boolean expression.
@@ -202,7 +233,10 @@ extractThreeBitTruthTable(Value *Expr, Value *Op0, Value *Op1, Value *Op2) {
202233
bool Val0 = (I >> 2) & 1;
203234
bool Val1 = (I >> 1) & 1;
204235
bool Val2 = I & 1;
205-
std::map<Value *, bool> Values = {{Op0, Val0}, {Op1, Val1}, {Op2, Val2}};
236+
SmallMapVector<Value *, bool, 4> Values;
237+
Values[Op0] = Val0;
238+
Values[Op1] = Val1;
239+
Values[Op2] = Val2;
206240
auto Result = evaluateBooleanExpression(Expr, Values);
207241
if (!Result)
208242
return std::nullopt;

0 commit comments

Comments
 (0)