-
Notifications
You must be signed in to change notification settings - Fork 14.7k
[InstCombine] Canonicalize complex boolean expressions into ~((y | z) ^ x) via 3-input truth table #149530
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
[InstCombine] Canonicalize complex boolean expressions into ~((y | z) ^ x) via 3-input truth table #149530
Changes from all commits
cf2f9db
3a55b19
02807e3
af90743
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change | ||||||||||||||
---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
|
@@ -19,6 +19,8 @@ | |||||||||||||||
#include "llvm/IR/PatternMatch.h" | ||||||||||||||||
#include "llvm/Transforms/InstCombine/InstCombiner.h" | ||||||||||||||||
#include "llvm/Transforms/Utils/Local.h" | ||||||||||||||||
#include <bitset> | ||||||||||||||||
#include <map> | ||||||||||||||||
|
||||||||||||||||
using namespace llvm; | ||||||||||||||||
using namespace PatternMatch; | ||||||||||||||||
|
@@ -47,6 +49,202 @@ static Value *getFCmpValue(unsigned Code, Value *LHS, Value *RHS, | |||||||||||||||
return Builder.CreateFCmpFMF(NewPred, LHS, RHS, FMF); | ||||||||||||||||
} | ||||||||||||||||
|
||||||||||||||||
/// Build a compact instruction sequence for a 3-variable bitwise function
/// that is described by its 8-entry truth table.
///
/// Bit I of \p Table is the value of the function when
/// (Op0, Op1, Op2) == ((I >> 2) & 1, (I >> 1) & 1, I & 1).
///
/// Only the tables needed for issue 97044 (plus the two constant tables) are
/// handled. More cases can be added with real-world justification; all 256
/// tables could be covered if desired.
///
/// \param Table     Truth table of the expression rooted at \p Root.
/// \param Op0       First leaf variable (most significant table index bit).
/// \param Op1       Second leaf variable.
/// \param Op2       Third leaf variable (least significant table index bit).
/// \param Root      Root of the original expression; used for the "already
///                  simple" check and to pick the result type.
/// \param Builder   Builder used to emit the replacement instructions.
/// \param HasOneUse Whether the caller established that \p Root has a single
///                  use. Non-constant replacements are only emitted when set.
/// \returns The replacement value, or nullptr if the table is not handled or
///          the expression should be left alone.
static Value *createLogicFromTable3Var(const std::bitset<8> &Table, Value *Op0,
                                       Value *Op1, Value *Op2, Value *Root,
                                       IRBuilderBase &Builder, bool HasOneUse) {
  // Skip the transformation if the expression is already shallow: a
  // single-use binary operator with at least one operand that is not itself a
  // binary operator.
  if (Root->hasOneUse())
    if (auto *BO = dyn_cast<BinaryOperator>(Root))
      if (!isa<BinaryOperator>(BO->getOperand(0)) ||
          !isa<BinaryOperator>(BO->getOperand(1)))
        return nullptr;

  // The table is evaluated per bit, so "always true" means every bit set and
  // "always false" means every bit clear in the expression's own type. An i1
  // constant would have the wrong type for wider integers and their vectors.
  Type *Ty = Root->getType();
  const uint8_t TruthValue = static_cast<uint8_t>(Table.to_ulong());

  switch (TruthValue) {
  default:
    return nullptr;

  case 0x00: // Always FALSE
    return Constant::getNullValue(Ty);

  case 0xFF: // Always TRUE
    return Constant::getAllOnesValue(Ty);

  case 0xE1: { // ~((Op1 | Op2) ^ Op0)
    if (!HasOneUse)
      return nullptr;
    Value *Or = Builder.CreateOr(Op1, Op2);
    Value *Xor = Builder.CreateXor(Or, Op0);
    return Builder.CreateNot(Xor);
  }

  case 0x60: { // Op0 & (Op1 ^ Op2)
    if (!HasOneUse)
      return nullptr;
    Value *Xor = Builder.CreateXor(Op1, Op2);
    return Builder.CreateAnd(Op0, Xor);
  }

  case 0xD2: { // ((Op1 | Op2) ^ Op0) ^ Op1
    if (!HasOneUse)
      return nullptr;
    Value *Or = Builder.CreateOr(Op1, Op2);
    Value *Xor = Builder.CreateXor(Or, Op0);
    return Builder.CreateXor(Xor, Op1);
  }
  }
}
|
||||||||||||||||
static std::tuple<Value *, Value *, Value *> | ||||||||||||||||
extractThreeVariables(Value *Root) { | ||||||||||||||||
std::set<Value *> Variables; | ||||||||||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Use There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Thank you for the suggestion. This has also been implemented now in the latest commits |
||||||||||||||||
unsigned NodeCount = 0; | ||||||||||||||||
const unsigned MaxNodes = | ||||||||||||||||
50; // To prevent exponential blowup (see bitwise-hang.ll) | ||||||||||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Thanks, this was a typo. Was meaning to reference |
||||||||||||||||
|
||||||||||||||||
std::function<void(Value *)> Collect = [&](Value *V) { | ||||||||||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Avoid recursion by using a worklist-based traversal. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Thanks for the suggestion, this has been implemented now |
||||||||||||||||
if (++NodeCount > MaxNodes) | ||||||||||||||||
return; | ||||||||||||||||
|
||||||||||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. As we don't have a cost-based heuristic for now, we can ensure that all non-root nodes are not used by some nodes that do not belong to the expression tree. |
||||||||||||||||
Value *NotV; | ||||||||||||||||
if (match(V, m_Not(m_Value(NotV)))) { | ||||||||||||||||
Collect(NotV); | ||||||||||||||||
return; | ||||||||||||||||
} | ||||||||||||||||
if (auto *BO = dyn_cast<BinaryOperator>(V)) { | ||||||||||||||||
Collect(BO->getOperand(0)); | ||||||||||||||||
Collect(BO->getOperand(1)); | ||||||||||||||||
} else if (isa<Argument>(V) || isa<Instruction>(V)) { | ||||||||||||||||
if (!isa<Constant>(V) && V != Root) { | ||||||||||||||||
Variables.insert(V); | ||||||||||||||||
} | ||||||||||||||||
} | ||||||||||||||||
}; | ||||||||||||||||
|
||||||||||||||||
Collect(Root); | ||||||||||||||||
|
||||||||||||||||
// Bail if we hit the node limit | ||||||||||||||||
if (NodeCount > MaxNodes) | ||||||||||||||||
return {nullptr, nullptr, nullptr}; | ||||||||||||||||
|
||||||||||||||||
if (Variables.size() == 3) { | ||||||||||||||||
auto It = Variables.begin(); | ||||||||||||||||
Value *Op0 = *It++; | ||||||||||||||||
Value *Op1 = *It++; | ||||||||||||||||
Value *Op2 = *It; | ||||||||||||||||
return {Op0, Op1, Op2}; | ||||||||||||||||
} | ||||||||||||||||
return {nullptr, nullptr, nullptr}; | ||||||||||||||||
} | ||||||||||||||||
|
||||||||||||||||
/// Evaluate a boolean expression with concrete variable values. | ||||||||||||||||
static std::optional<bool> | ||||||||||||||||
evaluateBooleanExpression(Value *Expr, const std::map<Value *, bool> &Values) { | ||||||||||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Use There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. You're correct, this would be more efficient for us here. This has been done, now along with the other llvm ADT changes |
||||||||||||||||
if (auto It = Values.find(Expr); It != Values.end()) { | ||||||||||||||||
return It->second; | ||||||||||||||||
} | ||||||||||||||||
Value *NotExpr; | ||||||||||||||||
if (match(Expr, m_Not(m_Value(NotExpr)))) { | ||||||||||||||||
auto Operand = evaluateBooleanExpression(NotExpr, Values); | ||||||||||||||||
if (Operand) | ||||||||||||||||
return !*Operand; | ||||||||||||||||
return std::nullopt; | ||||||||||||||||
} | ||||||||||||||||
if (auto *BO = dyn_cast<BinaryOperator>(Expr)) { | ||||||||||||||||
auto LHS = evaluateBooleanExpression(BO->getOperand(0), Values); | ||||||||||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Avoid recursion by evaluating the values of subexpressions in the topological order. You can compute the topological order of instructions by sorting them with There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Thanks for the suggestion. I have implemented a version of this that removes recursion and sorts them topologically with However, although I initially considered using the dominator tree I went with the approach in commit 4c86e54 since afaiu, the I am happy to go with the dominator tree approach though if my current understanding of the matter is incorrect |
||||||||||||||||
auto RHS = evaluateBooleanExpression(BO->getOperand(1), Values); | ||||||||||||||||
if (!LHS || !RHS) | ||||||||||||||||
return std::nullopt; | ||||||||||||||||
|
||||||||||||||||
switch (BO->getOpcode()) { | ||||||||||||||||
case Instruction::And: | ||||||||||||||||
return *LHS && *RHS; | ||||||||||||||||
case Instruction::Or: | ||||||||||||||||
return *LHS || *RHS; | ||||||||||||||||
case Instruction::Xor: | ||||||||||||||||
return *LHS != *RHS; | ||||||||||||||||
default: | ||||||||||||||||
return std::nullopt; | ||||||||||||||||
} | ||||||||||||||||
} | ||||||||||||||||
return std::nullopt; | ||||||||||||||||
} | ||||||||||||||||
|
||||||||||||||||
/// Extracts the truth table from a 3-variable boolean expression. | ||||||||||||||||
/// The truth table is a 8-bit integer where each bit corresponds to a possible | ||||||||||||||||
/// combination of the three variables. | ||||||||||||||||
/// The bits are ordered as follows: | ||||||||||||||||
/// 000, 001, 010, 011, 100, 101, 110, 111 | ||||||||||||||||
/// The result is a bitset where the i-th bit is set if the expression is true | ||||||||||||||||
/// for the i-th combination of the variables. | ||||||||||||||||
static std::optional<std::bitset<8>> | ||||||||||||||||
extractThreeBitTruthTable(Value *Expr, Value *Op0, Value *Op1, Value *Op2) { | ||||||||||||||||
std::bitset<8> Table; | ||||||||||||||||
for (int I = 0; I < 8; I++) { | ||||||||||||||||
bool Val0 = (I >> 2) & 1; | ||||||||||||||||
bool Val1 = (I >> 1) & 1; | ||||||||||||||||
bool Val2 = I & 1; | ||||||||||||||||
std::map<Value *, bool> Values = {{Op0, Val0}, {Op1, Val1}, {Op2, Val2}}; | ||||||||||||||||
auto Result = evaluateBooleanExpression(Expr, Values); | ||||||||||||||||
if (!Result) | ||||||||||||||||
return std::nullopt; | ||||||||||||||||
Table[I] = *Result; | ||||||||||||||||
} | ||||||||||||||||
return Table; | ||||||||||||||||
} | ||||||||||||||||
|
||||||||||||||||
/// Try to canonicalize a 3-variable boolean expression via truth-table lookup.
///
/// \param Root    Root of the candidate expression. It must be a
///                BinaryOperator with a single use, otherwise nothing is done.
/// \param Builder Builder used to emit the replacement instructions.
/// \returns The replacement value, or nullptr if the expression does not have
///          exactly three leaf variables, cannot be evaluated, or has a truth
///          table with no known canonical form.
static Value *foldThreeVarBoolExpr(Value *Root,
                                   InstCombiner::BuilderTy &Builder) {
  // Run the cheap bail-outs before the traversal and the eight evaluations.
  // Only single-use expressions are transformed, to avoid code growth.
  if (!isa<BinaryOperator>(Root) || !Root->hasOneUse())
    return nullptr;

  auto [Op0, Op1, Op2] = extractThreeVariables(Root);
  if (!Op0 || !Op1 || !Op2)
    return nullptr;

  const std::optional<std::bitset<8>> Table =
      extractThreeBitTruthTable(Root, Op0, Op1, Op2);
  if (!Table)
    return nullptr;

  return createLogicFromTable3Var(*Table, Op0, Op1, Op2, Root, Builder,
                                  /*HasOneUse=*/true);
}
|
||||||||||||||||
/// Emit a computation of: (V >= Lo && V < Hi) if Inside is true, otherwise | ||||||||||||||||
/// (V < Lo || V >= Hi). This method expects that Lo < Hi. IsSigned indicates | ||||||||||||||||
/// whether to treat V, Lo, and Hi as signed or not. | ||||||||||||||||
|
@@ -3776,6 +3974,10 @@ Instruction *InstCombinerImpl::visitOr(BinaryOperator &I) { | |||||||||||||||
return replaceInstUsesWith(I, V); | ||||||||||||||||
|
||||||||||||||||
Value *Op0 = I.getOperand(0), *Op1 = I.getOperand(1); | ||||||||||||||||
|
||||||||||||||||
if (Value *Canonical = foldThreeVarBoolExpr(&I, Builder)) | ||||||||||||||||
return replaceInstUsesWith(I, Canonical); | ||||||||||||||||
|
||||||||||||||||
Type *Ty = I.getType(); | ||||||||||||||||
if (Ty->isIntOrIntVectorTy(1)) { | ||||||||||||||||
if (auto *SI0 = dyn_cast<SelectInst>(Op0)) { | ||||||||||||||||
|
@@ -5182,6 +5384,9 @@ Instruction *InstCombinerImpl::visitXor(BinaryOperator &I) { | |||||||||||||||
} | ||||||||||||||||
} | ||||||||||||||||
|
||||||||||||||||
if (Value *Canonical = foldThreeVarBoolExpr(&I, Builder)) | ||||||||||||||||
return replaceInstUsesWith(I, Canonical); | ||||||||||||||||
|
||||||||||||||||
if (auto *LHS = dyn_cast<ICmpInst>(I.getOperand(0))) | ||||||||||||||||
if (auto *RHS = dyn_cast<ICmpInst>(I.getOperand(1))) | ||||||||||||||||
if (Value *V = foldXorOfICmps(LHS, RHS, I)) | ||||||||||||||||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,86 @@ | ||
; NOTE: Assertions have been autogenerated by utils/update_test_checks.py | ||
; RUN: opt < %s -passes=instcombine -S | FileCheck %s | ||
; Tests for GitHub issue #97044 - Boolean expression canonicalization | ||
; Input is an OR of four AND terms over %x, %y, %z: | ||
; (y & ~z & x) | (x & ~y & z) | (~x & ~y & ~z) | (x & y & z). | ||
; The CHECK lines expect it to fold to ~((y | z) ^ x). | ||
define i32 @test0_4way_or(i32 %x, i32 %y, i32 %z) { | ||
; CHECK-LABEL: @test0_4way_or( | ||
; CHECK-NEXT: [[TMP1:%.*]] = or i32 [[Y:%.*]], [[Z:%.*]] | ||
; CHECK-NEXT: [[TMP2:%.*]] = xor i32 [[TMP1]], [[X:%.*]] | ||
; CHECK-NEXT: [[OR13:%.*]] = xor i32 [[TMP2]], -1 | ||
; CHECK-NEXT: ret i32 [[OR13]] | ||
; | ||
%not = xor i32 %z, -1 | ||
%and = and i32 %y, %not | ||
%and1 = and i32 %and, %x | ||
%not2 = xor i32 %y, -1 | ||
%and3 = and i32 %x, %not2 | ||
%and4 = and i32 %and3, %z | ||
%or = or i32 %and1, %and4 | ||
%not5 = xor i32 %x, -1 | ||
%not6 = xor i32 %y, -1 | ||
%and7 = and i32 %not5, %not6 | ||
%not8 = xor i32 %z, -1 | ||
%and9 = and i32 %and7, %not8 | ||
%or10 = or i32 %or, %and9 | ||
%and11 = and i32 %x, %y | ||
%and12 = and i32 %and11, %z | ||
%or13 = or i32 %or10, %and12 | ||
ret i32 %or13 | ||
} | ||
; Input is an XOR of two AND terms: | ||
; (~z & ((x & y) | (~x & ~y))) ^ (z & ((x & y) | (x & ~y))). | ||
; The CHECK lines expect it to fold to ~((y | z) ^ x). | ||
define i32 @test1_xor_pattern(i32 %x, i32 %y, i32 %z) { | ||
; CHECK-LABEL: @test1_xor_pattern( | ||
; CHECK-NEXT: [[TMP1:%.*]] = or i32 [[Y:%.*]], [[Z:%.*]] | ||
; CHECK-NEXT: [[TMP2:%.*]] = xor i32 [[TMP1]], [[X:%.*]] | ||
; CHECK-NEXT: [[XOR:%.*]] = xor i32 [[TMP2]], -1 | ||
; CHECK-NEXT: ret i32 [[XOR]] | ||
; | ||
%not = xor i32 %z, -1 | ||
%and = and i32 %x, %y | ||
%not1 = xor i32 %x, -1 | ||
%not2 = xor i32 %y, -1 | ||
%and3 = and i32 %not1, %not2 | ||
%or = or i32 %and, %and3 | ||
%and4 = and i32 %not, %or | ||
%and5 = and i32 %x, %y | ||
%and6 = and i32 %x, %not2 | ||
%or7 = or i32 %and5, %and6 | ||
%and8 = and i32 %z, %or7 | ||
%xor = xor i32 %and4, %and8 | ||
ret i32 %xor | ||
} | ||
; Input is a deeper mix of and/or/xor over %x, %y, %z ending in a negation. | ||
; The CHECK lines expect it to fold to ((y | z) ^ x) ^ y. | ||
define i32 @test2_nested_xor(i32 %x, i32 %y, i32 %z) { | ||
; CHECK-LABEL: @test2_nested_xor( | ||
; CHECK-NEXT: [[TMP1:%.*]] = or i32 [[Y:%.*]], [[Z:%.*]] | ||
; CHECK-NEXT: [[TMP2:%.*]] = xor i32 [[TMP1]], [[X:%.*]] | ||
; CHECK-NEXT: [[TMP3:%.*]] = xor i32 [[TMP2]], [[Y]] | ||
; CHECK-NEXT: ret i32 [[TMP3]] | ||
; | ||
%and = and i32 %x, %y | ||
%not = xor i32 %x, -1 | ||
%not1 = xor i32 %y, -1 | ||
%and2 = and i32 %not, %not1 | ||
%or = or i32 %and, %and2 | ||
%and3 = and i32 %x, %y | ||
%not4 = xor i32 %y, -1 | ||
%and5 = and i32 %x, %not4 | ||
%or6 = or i32 %and3, %and5 | ||
%xor = xor i32 %or, %or6 | ||
%not7 = xor i32 %y, -1 | ||
%and8 = and i32 %z, %not7 | ||
%and9 = and i32 %xor, %and8 | ||
%xor10 = xor i32 %or, %and9 | ||
%xor11 = xor i32 %xor10, %y | ||
%xor12 = xor i32 %xor11, -1 | ||
ret i32 %xor12 | ||
} | ||
; Input is already ~((y | z) ^ x). | ||
; The CHECK lines expect the same or / xor / not sequence in the output. | ||
define i32 @test3_already_optimal(i32 %x, i32 %y, i32 %z) { | ||
; CHECK-LABEL: @test3_already_optimal( | ||
; CHECK-NEXT: [[OR:%.*]] = or i32 [[Y:%.*]], [[Z:%.*]] | ||
; CHECK-NEXT: [[XOR:%.*]] = xor i32 [[OR]], [[X:%.*]] | ||
; CHECK-NEXT: [[NOT:%.*]] = xor i32 [[XOR]], -1 | ||
; CHECK-NEXT: ret i32 [[NOT]] | ||
; | ||
%or = or i32 %y, %z | ||
%xor = xor i32 %or, %x | ||
%not = xor i32 %xor, -1 | ||
ret i32 %not | ||
} |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This logic should be moved into the caller.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thank you for the suggestion, this has been done now. You are correct, the caller would be a better place to have this