Skip to content

[Branch Hinting] Add branch hint handling in RemoveUnusedBrs #7706

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 5 commits into from
Jul 9, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
126 changes: 126 additions & 0 deletions src/ir/branch-hints.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,126 @@
/*
* Copyright 2025 WebAssembly Community Group participants
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#ifndef wasm_ir_branch_hint_h
#define wasm_ir_branch_hint_h

#include "wasm.h"

//
// Branch hint utilities to get them, set, flip, etc.
//

namespace wasm::BranchHints {

// Get the branch hint for an expression.
inline std::optional<bool> get(Expression* expr, Function* func) {
auto iter = func->codeAnnotations.find(expr);
if (iter == func->codeAnnotations.end()) {
// No annotations at all.
return {};
}
return iter->second.branchLikely;
}

// Set the branch hint for an expression, trampling anything existing before.
inline void set(Expression* expr, std::optional<bool> likely, Function* func) {
// When we are writing an empty hint, do not create an empty annotation if one
// did not exist.
if (!likely && !func->codeAnnotations.count(expr)) {
return;
}
func->codeAnnotations[expr].branchLikely = likely;
}

// Clear the branch hint for an expression.
inline void clear(Expression* expr, Function* func) {
// Do not create an empty annotation if one did not exist.
auto iter = func->codeAnnotations.find(expr);
if (iter == func->codeAnnotations.end()) {
return;
}
iter->second.branchLikely = {};
}

// Copy the branch hint for an expression to another, trampling anything
// existing before for the latter.
inline void copyTo(Expression* from, Expression* to, Function* func) {
auto fromLikely = get(from, func);
set(to, fromLikely, func);
}

// Flip the branch hint for an expression (if it exists).
inline void flip(Expression* expr, Function* func) {
if (auto likely = get(expr, func)) {
set(expr, !*likely, func);
}
}

// Copy the branch hint for an expression to another, flipping it while we do
// so.
inline void copyFlippedTo(Expression* from, Expression* to, Function* func) {
copyTo(from, to, func);
flip(to, func);
}

// Given two expressions to read from, apply the AND hint to a target. That is,
// the target will be true when both inputs are true. |to| may be equal to
// |from1| or |from2|. The hint of |to| is trampled.
inline void applyAndTo(Expression* from1,
Expression* from2,
Expression* to,
Function* func) {
// If from1 and from2 are both likely, then from1 && from2 is slightly less
// likely, but we assume our hints are nearly certain, so we apply it. And,
// converse, if from1 and from2 and both unlikely, then from1 && from2 is even
// less likely, so we can once more apply a hint. If the hints differ, then
// one is unlikely or unknown, and we can't say anything about from1 && from2.
auto from1Hint = BranchHints::get(from1, func);
auto from2Hint = BranchHints::get(from2, func);
if (from1Hint == from2Hint) {
set(to, from1Hint, func);
} else {
// The hints do not even match.
BranchHints::clear(to, func);
}
}

// As |applyAndTo|, but now the condition on |to| the OR of |from1| and |from2|.
inline void applyOrTo(Expression* from1,
Expression* from2,
Expression* to,
Function* func) {
Comment on lines +101 to +105
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

From the name and doc comment, I would have expected this to set to as likely if from1 is likely or from2 is likely. I would not have expected it to reason about branch probabilities at all. Can we move this reasoning up into the callers so we don't have such a misleading method name?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Moving the logic into the callers would require duplicating the large internal comment - it is not trivial how to handle this, so a helper seems natural?

How about applyCombinedOrConditionTo?

// If one is likely then so is the from1 || from2. If both are unlikely then
// from1 || from2 is slightly more likely, but we assume our hints are nearly
// certain, so we apply it.
auto from1Hint = BranchHints::get(from1, func);
auto from2Hint = BranchHints::get(from2, func);
if ((from1Hint && *from1Hint) || (from2Hint && *from2Hint)) {
set(to, true, func);
} else if (from1Hint && from2Hint) {
// We ruled out that either one is present and true, so if both are present,
// both must be false.
assert(!*from1Hint && !*from2Hint);
set(to, false, func);
} else {
// We don't know.
BranchHints::clear(to, func);
}
}

} // namespace wasm::BranchHints

#endif // wasm_ir_branch_hint_h
36 changes: 30 additions & 6 deletions src/passes/RemoveUnusedBrs.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
// Removes branches for which we go to where they go anyhow
//

#include "ir/branch-hints.h"
#include "ir/branch-utils.h"
#include "ir/cost.h"
#include "ir/drop.h"
Expand Down Expand Up @@ -396,6 +397,7 @@ struct RemoveUnusedBrs : public WalkerPass<PostWalker<RemoveUnusedBrs>> {
curr->condition, br->value, getPassOptions(), *getModule())) {
if (!br->condition) {
br->condition = curr->condition;
BranchHints::copyTo(curr, br, getFunction());
} else {
// In this case we can replace
// if (condition1) br_if (condition2)
Expand Down Expand Up @@ -427,6 +429,7 @@ struct RemoveUnusedBrs : public WalkerPass<PostWalker<RemoveUnusedBrs>> {
// That keeps the order of the two conditions as it was originally.
br->condition =
builder.makeSelect(br->condition, curr->condition, zero);
BranchHints::applyAndTo(curr, br, br, getFunction());
}
br->finalize();
replaceCurrent(Builder(*getModule()).dropIfConcretelyTyped(br));
Expand Down Expand Up @@ -459,6 +462,7 @@ struct RemoveUnusedBrs : public WalkerPass<PostWalker<RemoveUnusedBrs>> {
Builder builder(*getModule());
curr->condition = builder.makeSelect(
child->condition, curr->condition, builder.makeConst(int32_t(0)));
BranchHints::applyAndTo(curr, child, curr, getFunction());
curr->ifTrue = child->ifTrue;
}
}
Expand Down Expand Up @@ -689,6 +693,7 @@ struct RemoveUnusedBrs : public WalkerPass<PostWalker<RemoveUnusedBrs>> {
brIf->condition = builder.makeUnary(EqZInt32, brIf->condition);
last->name = brIf->name;
brIf->name = loop->name;
BranchHints::flip(brIf, getFunction());
return true;
} else {
// there are elements in the middle,
Expand All @@ -709,6 +714,7 @@ struct RemoveUnusedBrs : public WalkerPass<PostWalker<RemoveUnusedBrs>> {
builder.makeIf(brIf->condition,
builder.makeBreak(brIf->name),
stealSlice(builder, block, i + 1, list.size()));
BranchHints::copyTo(brIf, list[i], getFunction());
block->finalize();
return true;
}
Expand Down Expand Up @@ -1210,6 +1216,7 @@ struct RemoveUnusedBrs : public WalkerPass<PostWalker<RemoveUnusedBrs>> {
// we are an if-else where the ifTrue is a break without a
// condition, so we can do this
ifTrueBreak->condition = iff->condition;
BranchHints::copyTo(iff, ifTrueBreak, getFunction());
ifTrueBreak->finalize();
list[i] = Builder(*getModule()).dropIfConcretelyTyped(ifTrueBreak);
ExpressionManipulator::spliceIntoBlock(curr, i + 1, iff->ifFalse);
Expand All @@ -1224,6 +1231,7 @@ struct RemoveUnusedBrs : public WalkerPass<PostWalker<RemoveUnusedBrs>> {
*getModule())) {
ifFalseBreak->condition =
Builder(*getModule()).makeUnary(EqZInt32, iff->condition);
BranchHints::copyFlippedTo(iff, ifFalseBreak, getFunction());
ifFalseBreak->finalize();
list[i] = Builder(*getModule()).dropIfConcretelyTyped(ifFalseBreak);
ExpressionManipulator::spliceIntoBlock(curr, i + 1, iff->ifTrue);
Expand Down Expand Up @@ -1256,7 +1264,9 @@ struct RemoveUnusedBrs : public WalkerPass<PostWalker<RemoveUnusedBrs>> {
Builder builder(*getModule());
br1->condition =
builder.makeBinary(OrInt32, br1->condition, br2->condition);
BranchHints::applyOrTo(br1, br2, br1, getFunction());
ExpressionManipulator::nop(br2);
BranchHints::clear(br2, getFunction());
}
}
} else {
Expand Down Expand Up @@ -1396,9 +1406,12 @@ struct RemoveUnusedBrs : public WalkerPass<PostWalker<RemoveUnusedBrs>> {
// no other breaks to that name, so we can do this
if (!drop) {
assert(!br->value);
replaceCurrent(builder.makeIf(
builder.makeUnary(EqZInt32, br->condition), curr));
auto* iff = builder.makeIf(
builder.makeUnary(EqZInt32, br->condition), curr);
replaceCurrent(iff);
BranchHints::copyFlippedTo(br, iff, getFunction());
ExpressionManipulator::nop(br);
BranchHints::clear(br, getFunction());
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What happens if we are left with a branch hint on a nop?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The binary reader and writer ignore such things, but the text printer is simpler and will print it. This just avoids it being printed out.

curr->finalize(curr->type);
} else {
// To use an if, the value must have no side effects, as in the
Expand All @@ -1409,8 +1422,9 @@ struct RemoveUnusedBrs : public WalkerPass<PostWalker<RemoveUnusedBrs>> {
if (EffectAnalyzer::canReorder(
passOptions, *getModule(), br->condition, br->value)) {
ExpressionManipulator::nop(list[0]);
replaceCurrent(
builder.makeIf(br->condition, br->value, curr));
auto* iff = builder.makeIf(br->condition, br->value, curr);
BranchHints::copyTo(br, iff, getFunction());
replaceCurrent(iff);
}
} else {
// The value has side effects, so it must always execute. We
Expand Down Expand Up @@ -1529,6 +1543,14 @@ struct RemoveUnusedBrs : public WalkerPass<PostWalker<RemoveUnusedBrs>> {
optimizeSetIf(getCurrentPointer());
}

// Flip an if's condition with an eqz, and flip its arms.
void flip(If* iff) {
std::swap(iff->ifTrue, iff->ifFalse);
iff->condition =
Builder(*getModule()).makeUnary(EqZInt32, iff->condition);
BranchHints::flip(iff, getFunction());
}

void optimizeSetIf(Expression** currp) {
if (optimizeSetIfWithBrArm(currp)) {
return;
Expand Down Expand Up @@ -1570,9 +1592,10 @@ struct RemoveUnusedBrs : public WalkerPass<PostWalker<RemoveUnusedBrs>> {
// Wonderful, do it!
Builder builder(*getModule());
if (flipCondition) {
builder.flip(iff);
flip(iff);
}
br->condition = iff->condition;
BranchHints::copyTo(iff, br, getFunction());
br->finalize();
set->value = two;
auto* block = builder.makeSequence(br, set);
Expand Down Expand Up @@ -1640,7 +1663,7 @@ struct RemoveUnusedBrs : public WalkerPass<PostWalker<RemoveUnusedBrs>> {
Builder builder(*getModule());
LocalGet* get = iff->ifTrue->dynCast<LocalGet>();
if (get && get->index == set->index) {
builder.flip(iff);
flip(iff);
} else {
get = iff->ifFalse->dynCast<LocalGet>();
if (get && get->index != set->index) {
Expand Down Expand Up @@ -1901,6 +1924,7 @@ struct RemoveUnusedBrs : public WalkerPass<PostWalker<RemoveUnusedBrs>> {
curr->type = Type::unreachable;
block->list.push_back(curr);
block->finalize();
BranchHints::clear(curr, getFunction());
// The type changed, so refinalize.
refinalize = true;
} else {
Expand Down
5 changes: 0 additions & 5 deletions src/wasm-builder.h
Original file line number Diff line number Diff line change
Expand Up @@ -1481,11 +1481,6 @@ class Builder {
return makeDrop(curr);
}

void flip(If* iff) {
std::swap(iff->ifTrue, iff->ifFalse);
iff->condition = makeUnary(EqZInt32, iff->condition);
}

// Returns a replacement with the precise same type, and with minimal contents
// as best we can. As a replacement, this may reuse the input node.
template<typename T> Expression* replaceWithIdenticalType(T* curr) {
Expand Down
Loading
Loading