diff --git a/llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp b/llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp index b231c04319106..563cc25b5463a 100644 --- a/llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp +++ b/llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp @@ -19,6 +19,8 @@ #include "llvm/IR/PatternMatch.h" #include "llvm/Transforms/InstCombine/InstCombiner.h" #include "llvm/Transforms/Utils/Local.h" +#include +#include using namespace llvm; using namespace PatternMatch; @@ -47,6 +49,202 @@ static Value *getFCmpValue(unsigned Code, Value *LHS, Value *RHS, return Builder.CreateFCmpFMF(NewPred, LHS, RHS, FMF); } +/// This is to create optimal 3-variable boolean logic from truth tables. +/// currently it supports the cases pertaining to the issue 97044. More cases +/// can be added based on real-world justification for specific 3 input cases +/// or with reviewer approval all 256 cases can be added (choose the +/// canonicalizations found +/// in x86InstCombine.cpp?) +static Value *createLogicFromTable3Var(const std::bitset<8> &Table, Value *Op0, + Value *Op1, Value *Op2, Value *Root, + IRBuilderBase &Builder, bool HasOneUse) { + uint8_t TruthValue = Table.to_ulong(); + + // Skip transformation if expression is already simple (at most 2 levels + // deep). + if (Root->hasOneUse() && isa(Root)) { + if (auto *BO = dyn_cast(Root)) { + bool IsSimple = !isa(BO->getOperand(0)) || + !isa(BO->getOperand(1)); + if (IsSimple) + return nullptr; + } + } + + auto FoldConstant = [&](bool Val) { + Constant *Res = Val ? Builder.getTrue() : Builder.getFalse(); + if (Op0->getType()->isVectorTy()) + Res = ConstantVector::getSplat( + cast(Op0->getType())->getElementCount(), Res); + return Res; + }; + + Value *Result = nullptr; + switch (TruthValue) { + default: + return nullptr; + + case 0x00: // Always FALSE + Result = FoldConstant(false); + break; + + case 0xFF: // Always TRUE + Result = FoldConstant(true); + break; + + case 0xE1: // ~((Op1 | Op2) ^ Op0) + if (!HasOneUse) + return nullptr; + { + Value *Or = Builder.CreateOr(Op1, Op2); + Value *Xor = Builder.CreateXor(Or, Op0); + Result = Builder.CreateNot(Xor); + } + break; + + case 0x60: // Op0 & (Op1 ^ Op2) + if (!HasOneUse) + return nullptr; + { + Value *Xor = Builder.CreateXor(Op1, Op2); + Result = Builder.CreateAnd(Op0, Xor); + } + break; + + case 0xD2: // ((Op1 | Op2) ^ Op0) ^ Op1 + if (!HasOneUse) + return nullptr; + { + Value *Or = Builder.CreateOr(Op1, Op2); + Value *Xor1 = Builder.CreateXor(Or, Op0); + Result = Builder.CreateXor(Xor1, Op1); + } + break; + } + + return Result; +} + +static std::tuple +extractThreeVariables(Value *Root) { + std::set Variables; + unsigned NodeCount = 0; + const unsigned MaxNodes = + 50; // To prevent exponential blowup (see bitwise-hang.ll) + + std::function Collect = [&](Value *V) { + if (++NodeCount > MaxNodes) + return; + + Value *NotV; + if (match(V, m_Not(m_Value(NotV)))) { + Collect(NotV); + return; + } + if (auto *BO = dyn_cast(V)) { + Collect(BO->getOperand(0)); + Collect(BO->getOperand(1)); + } else if (isa(V) || isa(V)) { + if (!isa(V) && V != Root) { + Variables.insert(V); + } + } + }; + + Collect(Root); + + // Bail if we hit the node limit + if (NodeCount > MaxNodes) + return {nullptr, nullptr, nullptr}; + + if (Variables.size() == 3) { + auto It = Variables.begin(); + Value *Op0 = *It++; + Value *Op1 = *It++; + Value *Op2 = *It; + return {Op0, Op1, Op2}; + } + return {nullptr, nullptr, nullptr}; +} + +/// Evaluate a boolean expression with concrete variable values. +static std::optional +evaluateBooleanExpression(Value *Expr, const std::map &Values) { + if (auto It = Values.find(Expr); It != Values.end()) { + return It->second; + } + Value *NotExpr; + if (match(Expr, m_Not(m_Value(NotExpr)))) { + auto Operand = evaluateBooleanExpression(NotExpr, Values); + if (Operand) + return !*Operand; + return std::nullopt; + } + if (auto *BO = dyn_cast(Expr)) { + auto LHS = evaluateBooleanExpression(BO->getOperand(0), Values); + auto RHS = evaluateBooleanExpression(BO->getOperand(1), Values); + if (!LHS || !RHS) + return std::nullopt; + + switch (BO->getOpcode()) { + case Instruction::And: + return *LHS && *RHS; + case Instruction::Or: + return *LHS || *RHS; + case Instruction::Xor: + return *LHS != *RHS; + default: + return std::nullopt; + } + } + return std::nullopt; +} + +/// Extracts the truth table from a 3-variable boolean expression. +/// The truth table is a 8-bit integer where each bit corresponds to a possible +/// combination of the three variables. +/// The bits are ordered as follows: +/// 000, 001, 010, 011, 100, 101, 110, 111 +/// The result is a bitset where the i-th bit is set if the expression is true +/// for the i-th combination of the variables. +static std::optional> +extractThreeBitTruthTable(Value *Expr, Value *Op0, Value *Op1, Value *Op2) { + std::bitset<8> Table; + for (int I = 0; I < 8; I++) { + bool Val0 = (I >> 2) & 1; + bool Val1 = (I >> 1) & 1; + bool Val2 = I & 1; + std::map Values = {{Op0, Val0}, {Op1, Val1}, {Op2, Val2}}; + auto Result = evaluateBooleanExpression(Expr, Values); + if (!Result) + return std::nullopt; + Table[I] = *Result; + } + return Table; +} + +/// Try to canonicalize 3-variable boolean expressions using truth table lookup. +static Value *foldThreeVarBoolExpr(Value *Root, + InstCombiner::BuilderTy &Builder) { + // Only proceed if this is a "complex" expression. + if (!isa(Root)) + return nullptr; + + auto [Op0, Op1, Op2] = extractThreeVariables(Root); + if (!Op0 || !Op1 || !Op2) + return nullptr; + + auto Table = extractThreeBitTruthTable(Root, Op0, Op1, Op2); + if (!Table) + return nullptr; + + // Only transform expressions with single use to avoid code growth. + if (!Root->hasOneUse()) + return nullptr; + + return createLogicFromTable3Var(*Table, Op0, Op1, Op2, Root, Builder, true); +} + /// Emit a computation of: (V >= Lo && V < Hi) if Inside is true, otherwise /// (V < Lo || V >= Hi). This method expects that Lo < Hi. IsSigned indicates /// whether to treat V, Lo, and Hi as signed or not. @@ -3776,6 +3974,10 @@ Instruction *InstCombinerImpl::visitOr(BinaryOperator &I) { return replaceInstUsesWith(I, V); Value *Op0 = I.getOperand(0), *Op1 = I.getOperand(1); + + if (Value *Canonical = foldThreeVarBoolExpr(&I, Builder)) + return replaceInstUsesWith(I, Canonical); + Type *Ty = I.getType(); if (Ty->isIntOrIntVectorTy(1)) { if (auto *SI0 = dyn_cast(Op0)) { @@ -5182,6 +5384,9 @@ Instruction *InstCombinerImpl::visitXor(BinaryOperator &I) { } } + if (Value *Canonical = foldThreeVarBoolExpr(&I, Builder)) + return replaceInstUsesWith(I, Canonical); + if (auto *LHS = dyn_cast(I.getOperand(0))) if (auto *RHS = dyn_cast(I.getOperand(1))) if (Value *V = foldXorOfICmps(LHS, RHS, I)) diff --git a/llvm/test/Transforms/InstCombine/pr97044.ll b/llvm/test/Transforms/InstCombine/pr97044.ll new file mode 100644 index 0000000000000..9c9bf9aface25 --- /dev/null +++ b/llvm/test/Transforms/InstCombine/pr97044.ll @@ -0,0 +1,86 @@ +; NOTE: Assertions have been autogenerated by utils/update_test_checks.py +; RUN: opt < %s -passes=instcombine -S | FileCheck %s +; Tests for GitHub issue #97044 - Boolean expression canonicalization +define i32 @test0_4way_or(i32 %x, i32 %y, i32 %z) { +; CHECK-LABEL: @test0_4way_or( +; CHECK-NEXT: [[TMP1:%.*]] = or i32 [[Y:%.*]], [[Z:%.*]] +; CHECK-NEXT: [[TMP2:%.*]] = xor i32 [[TMP1]], [[X:%.*]] +; CHECK-NEXT: [[OR13:%.*]] = xor i32 [[TMP2]], -1 +; CHECK-NEXT: ret i32 [[OR13]] +; + %not = xor i32 %z, -1 + %and = and i32 %y, %not + %and1 = and i32 %and, %x + %not2 = xor i32 %y, -1 + %and3 = and i32 %x, %not2 + %and4 = and i32 %and3, %z + %or = or i32 %and1, %and4 + %not5 = xor i32 %x, -1 + %not6 = xor i32 %y, -1 + %and7 = and i32 %not5, %not6 + %not8 = xor i32 %z, -1 + %and9 = and i32 %and7, %not8 + %or10 = or i32 %or, %and9 + %and11 = and i32 %x, %y + %and12 = and i32 %and11, %z + %or13 = or i32 %or10, %and12 + ret i32 %or13 +} +define i32 @test1_xor_pattern(i32 %x, i32 %y, i32 %z) { +; CHECK-LABEL: @test1_xor_pattern( +; CHECK-NEXT: [[TMP1:%.*]] = or i32 [[Y:%.*]], [[Z:%.*]] +; CHECK-NEXT: [[TMP2:%.*]] = xor i32 [[TMP1]], [[X:%.*]] +; CHECK-NEXT: [[XOR:%.*]] = xor i32 [[TMP2]], -1 +; CHECK-NEXT: ret i32 [[XOR]] +; + %not = xor i32 %z, -1 + %and = and i32 %x, %y + %not1 = xor i32 %x, -1 + %not2 = xor i32 %y, -1 + %and3 = and i32 %not1, %not2 + %or = or i32 %and, %and3 + %and4 = and i32 %not, %or + %and5 = and i32 %x, %y + %and6 = and i32 %x, %not2 + %or7 = or i32 %and5, %and6 + %and8 = and i32 %z, %or7 + %xor = xor i32 %and4, %and8 + ret i32 %xor +} +define i32 @test2_nested_xor(i32 %x, i32 %y, i32 %z) { +; CHECK-LABEL: @test2_nested_xor( +; CHECK-NEXT: [[TMP1:%.*]] = or i32 [[Y:%.*]], [[Z:%.*]] +; CHECK-NEXT: [[TMP2:%.*]] = xor i32 [[TMP1]], [[X:%.*]] +; CHECK-NEXT: [[TMP3:%.*]] = xor i32 [[TMP2]], [[Y]] +; CHECK-NEXT: ret i32 [[TMP3]] +; + %and = and i32 %x, %y + %not = xor i32 %x, -1 + %not1 = xor i32 %y, -1 + %and2 = and i32 %not, %not1 + %or = or i32 %and, %and2 + %and3 = and i32 %x, %y + %not4 = xor i32 %y, -1 + %and5 = and i32 %x, %not4 + %or6 = or i32 %and3, %and5 + %xor = xor i32 %or, %or6 + %not7 = xor i32 %y, -1 + %and8 = and i32 %z, %not7 + %and9 = and i32 %xor, %and8 + %xor10 = xor i32 %or, %and9 + %xor11 = xor i32 %xor10, %y + %xor12 = xor i32 %xor11, -1 + ret i32 %xor12 +} +define i32 @test3_already_optimal(i32 %x, i32 %y, i32 %z) { +; CHECK-LABEL: @test3_already_optimal( +; CHECK-NEXT: [[OR:%.*]] = or i32 [[Y:%.*]], [[Z:%.*]] +; CHECK-NEXT: [[XOR:%.*]] = xor i32 [[OR]], [[X:%.*]] +; CHECK-NEXT: [[NOT:%.*]] = xor i32 [[XOR]], -1 +; CHECK-NEXT: ret i32 [[NOT]] +; + %or = or i32 %y, %z + %xor = xor i32 %or, %x + %not = xor i32 %xor, -1 + ret i32 %not +}