Skip to content
Open
Show file tree
Hide file tree
Changes from 4 commits
Commits
Show all changes
30 commits
Select commit Hold shift + click to select a range
cf2f9db
[InstCombine] Add pre-commit tests for boolean canonicalization (NFC)
yafet-a Jul 19, 2025
3a55b19
[InstCombine] Optimised expressions in issue #97044
yafet-a Jul 21, 2025
02807e3
3 input handled via truth table
yafet-a Jul 29, 2025
af90743
Merge branch 'main' into users/yafet-a/boolean-optimisation
yafet-a Jul 29, 2025
d066a85
Move simple expression check to caller
yafet-a Aug 7, 2025
4c86e54
removed recursion + smallptrset used
yafet-a Aug 7, 2025
7a2dc67
moved calls to consistent location in each visit function + added cal…
yafet-a Aug 7, 2025
6772db5
traverse only if node belongs to expr tree
yafet-a Aug 11, 2025
5405486
Refactor
yafet-a Aug 11, 2025
cb1e164
Merge branch 'main' into users/yafet-a/boolean-optimisation
yafet-a Aug 11, 2025
650a7ab
check for instructions being in the same bb to avoid comesBefore cros…
yafet-a Aug 13, 2025
18e576e
review (batch evaluation, refactors)
yafet-a Aug 14, 2025
28d4a0f
Add negative tests
yafet-a Aug 14, 2025
1fee55f
correctly checking for vars in same bb in extractThreeVariables()
yafet-a Aug 14, 2025
a39a3b4
reuse visited set in extractThreeVariables
yafet-a Aug 15, 2025
23feb15
multi-use tests + negative tests with and/or Var, Const nodes
yafet-a Aug 15, 2025
1a94bba
Computed Map validation in extractThreeVar
yafet-a Aug 15, 2025
d19190d
Pass instructions by reference instead of returning vectors
yafet-a Aug 19, 2025
9296d9b
early check for invalid num of variables
yafet-a Aug 20, 2025
48bd1ca
Improved sorting
yafet-a Aug 20, 2025
464d95e
treat non-bitwise ops as leaf nodes with use-count heuristic
yafet-a Aug 20, 2025
2a905fb
format
yafet-a Aug 20, 2025
fc2aac4
format-2
yafet-a Aug 20, 2025
d557827
validate no cross-BB instruction order comparison for computation ins…
yafet-a Aug 26, 2025
5c40046
(NFC: Styling) + Structural Similarity Check for loop
yafet-a Aug 29, 2025
7bf8caf
Traverse root operands to avoid treating them as leaf variables
yafet-a Aug 29, 2025
abd628d
NFC: negative test for treating root operands as leaf variables
yafet-a Aug 29, 2025
64fbafd
style: header comments
yafet-a Sep 1, 2025
ecd9669
[Tests] vector tests
yafet-a Sep 1, 2025
2b18e01
[NIT] improving header comments
yafet-a Sep 1, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
205 changes: 205 additions & 0 deletions llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,8 @@
#include "llvm/IR/PatternMatch.h"
#include "llvm/Transforms/InstCombine/InstCombiner.h"
#include "llvm/Transforms/Utils/Local.h"
#include <bitset>
#include <map>

using namespace llvm;
using namespace PatternMatch;
Expand Down Expand Up @@ -47,6 +49,202 @@ static Value *getFCmpValue(unsigned Code, Value *LHS, Value *RHS,
return Builder.CreateFCmpFMF(NewPred, LHS, RHS, FMF);
}

/// This is to create optimal 3-variable boolean logic from truth tables.
/// currently it supports the cases pertaining to the issue 97044. More cases
/// can be added based on real-world justification for specific 3 input cases
/// or with reviewer approval all 256 cases can be added (choose the
/// canonicalizations found
/// in x86InstCombine.cpp?)
static Value *createLogicFromTable3Var(const std::bitset<8> &Table, Value *Op0,
Value *Op1, Value *Op2, Value *Root,
IRBuilderBase &Builder, bool HasOneUse) {
uint8_t TruthValue = Table.to_ulong();

// Skip transformation if expression is already simple (at most 2 levels
// deep).
if (Root->hasOneUse() && isa<BinaryOperator>(Root)) {
if (auto *BO = dyn_cast<BinaryOperator>(Root)) {
bool IsSimple = !isa<BinaryOperator>(BO->getOperand(0)) ||
!isa<BinaryOperator>(BO->getOperand(1));
if (IsSimple)
return nullptr;
}
}

auto FoldConstant = [&](bool Val) {
Constant *Res = Val ? Builder.getTrue() : Builder.getFalse();
if (Op0->getType()->isVectorTy())
Res = ConstantVector::getSplat(
cast<VectorType>(Op0->getType())->getElementCount(), Res);
return Res;
};

Value *Result = nullptr;
switch (TruthValue) {
default:
return nullptr;

case 0x00: // Always FALSE
Result = FoldConstant(false);
break;

case 0xFF: // Always TRUE
Result = FoldConstant(true);
break;

case 0xE1: // ~((Op1 | Op2) ^ Op0)
if (!HasOneUse)
return nullptr;
{
Value *Or = Builder.CreateOr(Op1, Op2);
Value *Xor = Builder.CreateXor(Or, Op0);
Result = Builder.CreateNot(Xor);
}
break;

case 0x60: // Op0 & (Op1 ^ Op2)
if (!HasOneUse)
return nullptr;
{
Value *Xor = Builder.CreateXor(Op1, Op2);
Result = Builder.CreateAnd(Op0, Xor);
}
break;

case 0xD2: // ((Op1 | Op2) ^ Op0) ^ Op1
if (!HasOneUse)
return nullptr;
{
Value *Or = Builder.CreateOr(Op1, Op2);
Value *Xor1 = Builder.CreateXor(Or, Op0);
Result = Builder.CreateXor(Xor1, Op1);
}
break;
}

return Result;
}

static std::tuple<Value *, Value *, Value *>
extractThreeVariables(Value *Root) {
std::set<Value *> Variables;
unsigned NodeCount = 0;
const unsigned MaxNodes =
50; // To prevent exponential blowup (see bitwise-hang.ll)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

bitwise-hang.ll is missing.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks, this was a typo. Was meaning to reference bitreverse-hang.ll. Comment has been updated now


std::function<void(Value *)> Collect = [&](Value *V) {
if (++NodeCount > MaxNodes)
return;

Value *NotV;
if (match(V, m_Not(m_Value(NotV)))) {
Collect(NotV);
return;
}
if (auto *BO = dyn_cast<BinaryOperator>(V)) {
Collect(BO->getOperand(0));
Collect(BO->getOperand(1));
} else if (isa<Argument>(V) || isa<Instruction>(V)) {
if (!isa<Constant>(V) && V != Root) {
Variables.insert(V);
}
}
};

Collect(Root);

// Bail if we hit the node limit
if (NodeCount > MaxNodes)
return {nullptr, nullptr, nullptr};

if (Variables.size() == 3) {
auto It = Variables.begin();
Value *Op0 = *It++;
Value *Op1 = *It++;
Value *Op2 = *It;
return {Op0, Op1, Op2};
}
return {nullptr, nullptr, nullptr};
}

/// Evaluate a boolean expression with concrete variable values.
static std::optional<bool>
evaluateBooleanExpression(Value *Expr, const std::map<Value *, bool> &Values) {
if (auto It = Values.find(Expr); It != Values.end()) {
return It->second;
}
Value *NotExpr;
if (match(Expr, m_Not(m_Value(NotExpr)))) {
auto Operand = evaluateBooleanExpression(NotExpr, Values);
if (Operand)
return !*Operand;
return std::nullopt;
}
if (auto *BO = dyn_cast<BinaryOperator>(Expr)) {
auto LHS = evaluateBooleanExpression(BO->getOperand(0), Values);
auto RHS = evaluateBooleanExpression(BO->getOperand(1), Values);
if (!LHS || !RHS)
return std::nullopt;

switch (BO->getOpcode()) {
case Instruction::And:
return *LHS && *RHS;
case Instruction::Or:
return *LHS || *RHS;
case Instruction::Xor:
return *LHS != *RHS;
default:
return std::nullopt;
}
}
return std::nullopt;
}

/// Extracts the truth table from a 3-variable boolean expression.
/// The truth table is a 8-bit integer where each bit corresponds to a possible
/// combination of the three variables.
/// The bits are ordered as follows:
/// 000, 001, 010, 011, 100, 101, 110, 111
/// The result is a bitset where the i-th bit is set if the expression is true
/// for the i-th combination of the variables.
static std::optional<std::bitset<8>>
extractThreeBitTruthTable(Value *Expr, Value *Op0, Value *Op1, Value *Op2) {
std::bitset<8> Table;
for (int I = 0; I < 8; I++) {
bool Val0 = (I >> 2) & 1;
bool Val1 = (I >> 1) & 1;
bool Val2 = I & 1;
std::map<Value *, bool> Values = {{Op0, Val0}, {Op1, Val1}, {Op2, Val2}};
auto Result = evaluateBooleanExpression(Expr, Values);
if (!Result)
return std::nullopt;
Table[I] = *Result;
}
return Table;
}

/// Try to canonicalize 3-variable boolean expressions using truth table lookup.
static Value *foldThreeVarBoolExpr(Value *Root,
InstCombiner::BuilderTy &Builder) {
// Only proceed if this is a "complex" expression.
if (!isa<BinaryOperator>(Root))
return nullptr;

auto [Op0, Op1, Op2] = extractThreeVariables(Root);
if (!Op0 || !Op1 || !Op2)
return nullptr;

auto Table = extractThreeBitTruthTable(Root, Op0, Op1, Op2);
if (!Table)
return nullptr;

// Only transform expressions with single use to avoid code growth.
if (!Root->hasOneUse())
return nullptr;

return createLogicFromTable3Var(*Table, Op0, Op1, Op2, Root, Builder, true);
}

/// Emit a computation of: (V >= Lo && V < Hi) if Inside is true, otherwise
/// (V < Lo || V >= Hi). This method expects that Lo < Hi. IsSigned indicates
/// whether to treat V, Lo, and Hi as signed or not.
Expand Down Expand Up @@ -3776,6 +3974,10 @@ Instruction *InstCombinerImpl::visitOr(BinaryOperator &I) {
return replaceInstUsesWith(I, V);

Value *Op0 = I.getOperand(0), *Op1 = I.getOperand(1);

if (Value *Canonical = foldThreeVarBoolExpr(&I, Builder))
return replaceInstUsesWith(I, Canonical);

Type *Ty = I.getType();
if (Ty->isIntOrIntVectorTy(1)) {
if (auto *SI0 = dyn_cast<SelectInst>(Op0)) {
Expand Down Expand Up @@ -5182,6 +5384,9 @@ Instruction *InstCombinerImpl::visitXor(BinaryOperator &I) {
}
}

if (Value *Canonical = foldThreeVarBoolExpr(&I, Builder))
return replaceInstUsesWith(I, Canonical);

if (auto *LHS = dyn_cast<ICmpInst>(I.getOperand(0)))
if (auto *RHS = dyn_cast<ICmpInst>(I.getOperand(1)))
if (Value *V = foldXorOfICmps(LHS, RHS, I))
Expand Down
86 changes: 86 additions & 0 deletions llvm/test/Transforms/InstCombine/pr97044.ll
Original file line number Diff line number Diff line change
@@ -0,0 +1,86 @@
; NOTE: Assertions have been autogenerated by utils/update_test_checks.py
; RUN: opt < %s -passes=instcombine -S | FileCheck %s
; Tests for GitHub issue #97044 - Boolean expression canonicalization
define i32 @test0_4way_or(i32 %x, i32 %y, i32 %z) {
; CHECK-LABEL: @test0_4way_or(
; CHECK-NEXT: [[TMP1:%.*]] = or i32 [[Y:%.*]], [[Z:%.*]]
; CHECK-NEXT: [[TMP2:%.*]] = xor i32 [[TMP1]], [[X:%.*]]
; CHECK-NEXT: [[OR13:%.*]] = xor i32 [[TMP2]], -1
; CHECK-NEXT: ret i32 [[OR13]]
;
%not = xor i32 %z, -1
%and = and i32 %y, %not
%and1 = and i32 %and, %x
%not2 = xor i32 %y, -1
%and3 = and i32 %x, %not2
%and4 = and i32 %and3, %z
%or = or i32 %and1, %and4
%not5 = xor i32 %x, -1
%not6 = xor i32 %y, -1
%and7 = and i32 %not5, %not6
%not8 = xor i32 %z, -1
%and9 = and i32 %and7, %not8
%or10 = or i32 %or, %and9
%and11 = and i32 %x, %y
%and12 = and i32 %and11, %z
%or13 = or i32 %or10, %and12
ret i32 %or13
}
define i32 @test1_xor_pattern(i32 %x, i32 %y, i32 %z) {
; CHECK-LABEL: @test1_xor_pattern(
; CHECK-NEXT: [[TMP1:%.*]] = or i32 [[Y:%.*]], [[Z:%.*]]
; CHECK-NEXT: [[TMP2:%.*]] = xor i32 [[TMP1]], [[X:%.*]]
; CHECK-NEXT: [[XOR:%.*]] = xor i32 [[TMP2]], -1
; CHECK-NEXT: ret i32 [[XOR]]
;
%not = xor i32 %z, -1
%and = and i32 %x, %y
%not1 = xor i32 %x, -1
%not2 = xor i32 %y, -1
%and3 = and i32 %not1, %not2
%or = or i32 %and, %and3
%and4 = and i32 %not, %or
%and5 = and i32 %x, %y
%and6 = and i32 %x, %not2
%or7 = or i32 %and5, %and6
%and8 = and i32 %z, %or7
%xor = xor i32 %and4, %and8
ret i32 %xor
}
define i32 @test2_nested_xor(i32 %x, i32 %y, i32 %z) {
; CHECK-LABEL: @test2_nested_xor(
; CHECK-NEXT: [[TMP1:%.*]] = or i32 [[Y:%.*]], [[Z:%.*]]
; CHECK-NEXT: [[TMP2:%.*]] = xor i32 [[TMP1]], [[X:%.*]]
; CHECK-NEXT: [[TMP3:%.*]] = xor i32 [[TMP2]], [[Y]]
; CHECK-NEXT: ret i32 [[TMP3]]
;
%and = and i32 %x, %y
%not = xor i32 %x, -1
%not1 = xor i32 %y, -1
%and2 = and i32 %not, %not1
%or = or i32 %and, %and2
%and3 = and i32 %x, %y
%not4 = xor i32 %y, -1
%and5 = and i32 %x, %not4
%or6 = or i32 %and3, %and5
%xor = xor i32 %or, %or6
%not7 = xor i32 %y, -1
%and8 = and i32 %z, %not7
%and9 = and i32 %xor, %and8
%xor10 = xor i32 %or, %and9
%xor11 = xor i32 %xor10, %y
%xor12 = xor i32 %xor11, -1
ret i32 %xor12
}
define i32 @test3_already_optimal(i32 %x, i32 %y, i32 %z) {
; CHECK-LABEL: @test3_already_optimal(
; CHECK-NEXT: [[OR:%.*]] = or i32 [[Y:%.*]], [[Z:%.*]]
; CHECK-NEXT: [[XOR:%.*]] = xor i32 [[OR]], [[X:%.*]]
; CHECK-NEXT: [[NOT:%.*]] = xor i32 [[XOR]], -1
; CHECK-NEXT: ret i32 [[NOT]]
;
%or = or i32 %y, %z
%xor = xor i32 %or, %x
%not = xor i32 %xor, -1
ret i32 %not
}