Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
111 changes: 111 additions & 0 deletions xls/ir/interval_ops.cc
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,31 @@ TernaryVector ExtractTernaryInterval(const Interval& interval) {
}
return result;
}
// Returns the maximum popcount of any value less than or equal to `limit`.
int64_t MaxPopCountLessThanOrEqualTo(const Bits& limit) {
if (limit.IsZero()) {
return 0;
}
if (limit.IsAllOnes()) {
return limit.bit_count();
}

// Find the highest set bit to determine the "active" range seen by the limit.
int64_t highest_set_bit = limit.bit_count() - 1 - limit.CountLeadingZeros();
int64_t significant_bits = highest_set_bit + 1;

// If the limit is of the form 2^significant_bits - 1 (i.e. all bits up to
// the highest set bit are 1), then the limit itself is the maximum.
if (limit.Slice(0, significant_bits).IsAllOnes()) {
return significant_bits;
}

// Otherwise, we can flip the highest set bit to 0 and set all lower bits
// to 1. This gives a number strictly less than 'limit' with popcount =
// (significant_bits - 1).
return significant_bits - 1;
}

} // namespace

TernaryVector ExtractTernaryVector(const IntervalSet& intervals,
Expand Down Expand Up @@ -259,6 +284,92 @@ bool CoversTernary(const IntervalSet& intervals, TernarySpan ternary) {
});
}

int64_t MaxPopCount(const Interval& interval) {
if (interval.IsPrecise()) {
return interval.LowerBound().PopCount();
}
Bits lower = interval.LowerBound();
Bits upper = interval.UpperBound();

// Find the longest common prefix shared by both bounds (from MSB).
Bits prefix = bits_ops::LongestCommonPrefixMSB({lower, upper});
int64_t prefix_pop = prefix.PopCount();

// The first bit where they differ is just below the common prefix.
int64_t divergence_bit_index = lower.bit_count() - 1 - prefix.bit_count();

// We split the interval into two branches at divergence_bit_index:
// Branch 1: The bit at divergence_bit_index is 0.
// Branch 2: The bit at divergence_bit_index is 1.

// Branch 1: The divergence bit is 0.
// We can set all bits strictly below the divergence bit to 1 without
// exceeding `upper`.
int64_t candidate_lower_branch_pop = prefix_pop + divergence_bit_index;

// Branch 2: The divergence bit is 1.
// We must ensure the number remains <= upper. This reduces to finding the
// maximum popcount of a number <= upper_remainder (the lower
// 'divergence_bit_index' bits of 'upper').
int64_t candidate_upper_branch_pop = 0;
if (divergence_bit_index == 0) {
candidate_upper_branch_pop =
prefix_pop + 1; // Divergence is at bit 0, no remainder.
} else {
Bits upper_remainder = upper.Slice(0, divergence_bit_index);
candidate_upper_branch_pop =
prefix_pop + 1 + MaxPopCountLessThanOrEqualTo(upper_remainder);
}

return std::max(candidate_lower_branch_pop, candidate_upper_branch_pop);
}

int64_t MaxPopCount(const IntervalSet& intervals) {
int64_t max_pop_count = 0;
for (const Interval& interval : intervals.Intervals()) {
max_pop_count = std::max(max_pop_count, MaxPopCount(interval));
}
return max_pop_count;
}

int64_t MinPopCount(const Interval& interval) {
if (interval.IsPrecise()) {
return interval.LowerBound().PopCount();
}
Bits lower = interval.LowerBound();
Bits upper = interval.UpperBound();

// Find the longest common prefix shared by both bounds (from MSB).
Bits prefix = bits_ops::LongestCommonPrefixMSB({lower, upper});

// The first bit where they differ is just below the common prefix.
int64_t divergence_bit_index = lower.bit_count() - 1 - prefix.bit_count();

// If there are no bits below the divergence bit, or if the lower bits of
// `lower` below the divergence bit are all zero, we can just use the prefix
// padded with zeros.
if (divergence_bit_index == 0 ||
lower.Slice(0, divergence_bit_index).IsZero()) {
return prefix.PopCount();
}

// Otherwise, we can achieve `prefix.PopCount() + 1` by setting the divergence
// bit to 1 and all lower bits to 0. This value is guaranteed to be in the
// interval.
return prefix.PopCount() + 1;
}

int64_t MinPopCount(const IntervalSet& intervals) {
if (intervals.IsEmpty()) {
return 0;
}
int64_t min_pop_count = intervals.BitCount();
for (const Interval& interval : intervals.Intervals()) {
min_pop_count = std::min(min_pop_count, MinPopCount(interval));
}
return min_pop_count;
}

namespace {

enum class Tonicity : bool { Monotone, Antitone };
Expand Down
8 changes: 8 additions & 0 deletions xls/ir/interval_ops.h
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,14 @@ TernaryVector ExtractTernaryVector(const IntervalSet& intervals,
bool CoversTernary(const Interval& interval, TernarySpan ternary);
bool CoversTernary(const IntervalSet& intervals, TernarySpan ternary);

// Returns the maximum popcount of any value in the given `intervals`.
int64_t MaxPopCount(const Interval& interval);
int64_t MaxPopCount(const IntervalSet& intervals);

// Returns the minimum popcount of any value in the given `intervals`.
int64_t MinPopCount(const Interval& interval);
int64_t MinPopCount(const IntervalSet& intervals);

struct KnownBits {
Bits known_bits;
Bits known_bit_values;
Expand Down
14 changes: 14 additions & 0 deletions xls/ir/interval_ops_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1129,6 +1129,20 @@ TEST(IntervalOpsTest, MinimumSignedBitCount) {
7);
}

TEST(IntervalOpsTest, MaxPopCount) {
EXPECT_EQ(MaxPopCount(Interval::Precise(UBits(0, 8))), 0);
EXPECT_EQ(MaxPopCount(Interval::Precise(UBits(255, 8))), 8);
EXPECT_EQ(MaxPopCount(Interval::Closed(UBits(2, 8), UBits(4, 8))), 2);
EXPECT_EQ(MaxPopCount(Interval::Closed(UBits(7, 8), UBits(11, 8))), 3);
}

TEST(IntervalOpsTest, MinPopCount) {
EXPECT_EQ(MinPopCount(Interval::Precise(UBits(0, 8))), 0);
EXPECT_EQ(MinPopCount(Interval::Precise(UBits(255, 8))), 8);
EXPECT_EQ(MinPopCount(Interval::Closed(UBits(2, 8), UBits(4, 8))), 1);
EXPECT_EQ(MinPopCount(Interval::Closed(UBits(7, 8), UBits(11, 8))), 1);
}

TEST(MinimizeIntervalsTest, PrefersEarlyIntervals) {
// All 32 6-bit [0, 63] even numbers.
IntervalSet even_numbers =
Expand Down
39 changes: 39 additions & 0 deletions xls/ir/partial_information.cc
Original file line number Diff line number Diff line change
Expand Up @@ -204,6 +204,45 @@ int64_t PartialInformation::KnownLeadingSignBits() const {
return 1 + bit_count_ - interval_ops::MinimumSignedBitCount(*range_);
}

int64_t PartialInformation::MaxPopCount() const {
if (IsImpossible()) {
return 0;
}
if (IsUnconstrained()) {
return bit_count_;
}
int64_t tern_count = 0;
if (ternary_) {
tern_count = absl::c_count_if(*ternary_, [](TernaryValue v) {
return v != TernaryValue::kKnownZero;
});
}
int64_t range_count = 0;
if (range_) {
range_count = interval_ops::MaxPopCount(*range_);
}
return std::min(range_count, tern_count);
}

int64_t PartialInformation::MinPopCount() const {
if (IsImpossible()) {
return 0;
}
if (IsUnconstrained()) {
return 0;
}
int64_t tern_count = 0;
if (ternary_) {
tern_count = absl::c_count_if(
*ternary_, [](TernaryValue v) { return v != TernaryValue::kKnownOne; });
}
int64_t range_count = 0;
if (range_) {
range_count = interval_ops::MinPopCount(*range_);
}
return std::max(range_count, tern_count);
}

std::string PartialInformation::ToDebugString() const {
if (IsUnconstrained()) {
return "unconstrained";
Expand Down
8 changes: 8 additions & 0 deletions xls/ir/partial_information.h
Original file line number Diff line number Diff line change
Expand Up @@ -140,6 +140,14 @@ class PartialInformation {
// Gets the number of leading sign bits known.
int64_t KnownLeadingSignBits() const;

// Returns an upper bound on the popcount of any value that can satisfy this
// PartialInformation.
int64_t MaxPopCount() const;

// Returns a lower bound on the popcount of any value that can satisfy this
// PartialInformation.
int64_t MinPopCount() const;

std::string ToString() const;
std::string ToDebugString() const;

Expand Down
32 changes: 32 additions & 0 deletions xls/passes/context_sensitive_range_query_engine.cc
Original file line number Diff line number Diff line change
Expand Up @@ -480,6 +480,16 @@ class ProxyContextQueryEngine final : public QueryEngine {
return a_value != b_value;
}

bool AtMostOneBitTrue(Node* node) const override {
return MostSpecific(node).AtMostOneBitTrue(node);
}
bool AtLeastOneBitTrue(Node* node) const override {
return MostSpecific(node).AtLeastOneBitTrue(node);
}
bool ExactlyOneBitTrue(Node* node) const override {
return MostSpecific(node).ExactlyOneBitTrue(node);
}

bool Covers(Node* node, const Bits& value) const override {
return MostSpecific(node).Covers(node, value);
}
Expand Down Expand Up @@ -657,6 +667,28 @@ ContextSensitiveRangeQueryEngine::GetTernary(Node* node) const {
node->GetType(), *select_ranges_.at(node).ternary)
.AsShared();
}
bool ContextSensitiveRangeQueryEngine::AtMostOneBitTrue(Node* node) const {
if (!node->OpIn({Op::kSel}) || !select_ranges_.contains(node)) {
return base_case_ranges_.AtMostOneBitTrue(node);
}
return interval_ops::MaxPopCount(
select_ranges_.at(node).interval_set.Get({})) <= 1;
}
bool ContextSensitiveRangeQueryEngine::AtLeastOneBitTrue(Node* node) const {
if (!node->OpIn({Op::kSel}) || !select_ranges_.contains(node)) {
return base_case_ranges_.AtLeastOneBitTrue(node);
}
return interval_ops::MinPopCount(
select_ranges_.at(node).interval_set.Get({})) >= 1;
}
bool ContextSensitiveRangeQueryEngine::ExactlyOneBitTrue(Node* node) const {
if (!node->OpIn({Op::kSel}) || !select_ranges_.contains(node)) {
return base_case_ranges_.ExactlyOneBitTrue(node);
}
IntervalSet interval_set = select_ranges_.at(node).interval_set.Get({});
return interval_ops::MinPopCount(interval_set) == 1 &&
interval_ops::MaxPopCount(interval_set) == 1;
}
bool ContextSensitiveRangeQueryEngine::Covers(Node* node,
const Bits& value) const {
if (!node->OpIn({Op::kSel}) || !select_ranges_.contains(node)) {
Expand Down
4 changes: 4 additions & 0 deletions xls/passes/context_sensitive_range_query_engine.h
Original file line number Diff line number Diff line change
Expand Up @@ -115,6 +115,10 @@ class ContextSensitiveRangeQueryEngine final : public QueryEngine {
std::unique_ptr<QueryEngine> SpecializeGivenPredicate(
const absl::btree_set<PredicateState>& state) const override;

bool AtMostOneBitTrue(Node* node) const override;
bool AtLeastOneBitTrue(Node* node) const override;
bool ExactlyOneBitTrue(Node* node) const override;

bool Covers(Node* node, const Bits& value) const override;

Bits MaxUnsignedValue(Node* node) const override;
Expand Down
Loading
Loading