Skip to content

Commit

Permalink
Fix bugs in Interval::IsTrueWhenAndWith
Browse files Browse the repository at this point in the history
This function's definition didn't quite meet its specification; in particular, it used an invalid mask. The result was incorrect behavior of all types (false positives & false negatives).

This could have caused some incorrect optimization analyses for `PrioritySelect` in the RangeQueryEngine.

Making an efficient implementation of this required the new `interval_ops::CoversTernary` function, which tests whether a ternary pattern intersects an `Interval` in O(bits) time. Since this function turned out to be quite complicated, I ended up writing a corresponding proof in the inline comments, and added fuzz testing (which quickly caught several bugs in my first implementation). I also exhaustively tested it for all combinations of 7-bit `Interval`s and 7-bit ternary patterns. (The exhaustive test took several minutes to run on my local machine, so I haven't checked that part in.)

PiperOrigin-RevId: 627946453
  • Loading branch information
ericastor authored and copybara-github committed Apr 25, 2024
1 parent dc8f6b0 commit d8b09c6
Show file tree
Hide file tree
Showing 12 changed files with 250 additions and 38 deletions.
2 changes: 2 additions & 0 deletions xls/ir/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -208,6 +208,7 @@ cc_test(
deps = [
":bits",
":bits_ops",
":bits_test_utils",
":interval",
":interval_test_utils",
"//xls/common:xls_gunit",
Expand Down Expand Up @@ -249,6 +250,7 @@ cc_test(
":interval",
":interval_ops",
":interval_set",
":interval_test_utils",
":ir",
":ir_test_base",
":ternary",
Expand Down
13 changes: 7 additions & 6 deletions xls/ir/interval.cc
Original file line number Diff line number Diff line change
Expand Up @@ -292,12 +292,13 @@ bool Interval::IsMaximal() const {

bool Interval::IsTrueWhenAndWith(const Bits& value) const {
CHECK_EQ(value.bit_count(), BitCount());
int64_t right_index = std::min(LowerBound().CountTrailingZeros(),
UpperBound().CountTrailingZeros());
int64_t left_index = BitCount() - UpperBound().CountLeadingZeros();
Bits interval_mask_value(BitCount());
interval_mask_value.SetRange(right_index, left_index);
return !bits_ops::And(interval_mask_value, value).IsZero();
BitsRope interval_mask_value(BitCount());
Bits common_prefix =
bits_ops::LongestCommonPrefixMSB({LowerBound(), UpperBound()});
interval_mask_value.push_back(
Bits::AllOnes(BitCount() - common_prefix.bit_count()));
interval_mask_value.push_back(common_prefix);
return !bits_ops::And(interval_mask_value.Build(), value).IsZero();
}

bool Interval::Covers(const Bits& point) const {
Expand Down
5 changes: 5 additions & 0 deletions xls/ir/interval.h
Original file line number Diff line number Diff line change
Expand Up @@ -222,6 +222,11 @@ class Interval {
interval.upper_bound_);
}

template <typename Sink>
friend void AbslStringify(Sink& sink, const Interval& interval) {
absl::Format(&sink, "%s", interval.ToString());
}

private:
void EnsureValid() const { CHECK(is_valid_); }

Expand Down
126 changes: 124 additions & 2 deletions xls/ir/interval_ops.cc
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,6 @@
#include <memory>
#include <optional>
#include <utility>
#include <vector>

#include "absl/algorithm/container.h"
#include "absl/log/check.h"
Expand All @@ -30,7 +29,6 @@
#include "xls/ir/bits_ops.h"
#include "xls/ir/interval.h"
#include "xls/ir/interval_set.h"
#include "xls/ir/lsb_or_msb.h"
#include "xls/ir/node.h"
#include "xls/ir/ternary.h"
#include "xls/passes/ternary_evaluator.h"
Expand Down Expand Up @@ -119,6 +117,130 @@ IntervalSet FromTernary(TernarySpan tern, int64_t max_interval_bits) {
return is;
}

bool CoversTernary(const Interval& interval, TernarySpan ternary) {
if (interval.BitCount() != ternary.size()) {
return false;
}
if (ternary_ops::IsFullyKnown(ternary)) {
return interval.Covers(ternary_ops::ToKnownBitsValues(ternary));
}
if (interval.IsPrecise()) {
return ternary_ops::IsCompatible(ternary, interval.LowerBound());
}

Bits lcp = bits_ops::LongestCommonPrefixMSB(
{interval.LowerBound(), interval.UpperBound()});

// We know the next bit of the bounds of `interval` differs, and the interval
// is proper iff the upper bound has a 1 there.
const bool proper = interval.UpperBound().GetFromMsb(lcp.bit_count());

TernarySpan prefix = ternary.subspan(ternary.size() - lcp.bit_count());

// If the interval is proper, then the interval only contains things with
// this least-common prefix.
if (proper && !ternary_ops::IsCompatible(prefix, lcp)) {
return false;
}

// If the interval is improper, then it contains everything that doesn't share
// this prefix. Therefore, unless `prefix` is fully-known and matches the
// least-common prefix, `ternary` can definitely represent something in the
// interval.
if (!proper && !(ternary_ops::IsFullyKnown(prefix) &&
ternary_ops::ToKnownBitsValues(prefix) == lcp)) {
return true;
}

// Take the leading value in `ternary`.
TernaryValue x = ternary[ternary.size() - lcp.bit_count() - 1];

// Drop all the bits we've already confirmed match, plus one more.
Bits L = interval.LowerBound().Slice(0, ternary.size() - lcp.bit_count() - 1);
Bits U = interval.UpperBound().Slice(0, ternary.size() - lcp.bit_count() - 1);
TernarySpan t = ternary.subspan(0, ternary.size() - lcp.bit_count() - 1);

auto could_be_le = [](TernarySpan t, const Bits& L) {
for (int64_t i = t.size() - 1; i >= 0; --i) {
if (L.Get(i)) {
if (t[i] != TernaryValue::kKnownOne) {
// If this bit is zero, it will make t < L.
return true;
}
} else if (t[i] == TernaryValue::kKnownOne) {
// We know t > L.
return false;
}
}
return true;
};
auto could_be_ge = [](TernarySpan t, const Bits& U) {
for (int64_t i = t.size() - 1; i >= 0; --i) {
if (U.Get(i)) {
if (t[i] == TernaryValue::kKnownZero) {
// We know t < U.
return false;
}
} else if (t[i] != TernaryValue::kKnownZero) {
// If this bit is one, it will make t > L.
return true;
}
}
return true;
};

// NOTE: At this point, we want to know:
//
// if improper, whether it's possible to have:
// xt <= 0U || 1L <= xt, which is true iff
// (x == 0 && t <= U) || (x == 1 && L <= t).
//
// if proper, whether it's possible to have:
// 0L <= xt && xt <= 1U, which is true iff
// (x == 1 || L <= t) && (x == 0 || t <= U).
//
// If x is known, then this is easy:
// if x == 0 && proper: check if it's possible to have L <= t.
// if x == 1 && improper: check if it's possible to have L <= t.
// if x == 0 && improper: check if it's possible to have t <= U.
// if x == 1 && proper: check if it's possible to have t <= U.
// In other words:
// if (x == 0) == proper, check if it's possible to have L <= t.
// Otherwise, check if it's possible to have t <= U.
if (ternary_ops::IsKnown(x)) {
if ((x == TernaryValue::kKnownZero) == proper) {
return could_be_ge(t, L);
}
return could_be_le(t, U);
}

// If x is unknown, then we can choose whichever value we want. Therefore, we
// just need to know:
// if improper, whether it's possible to have... well.
// if we take x == 0, then we just need to check if we can have t <= U.
// if we take x == 1, then we just need to check if we can have L <= t.
// Therefore, we just need to check whether it's possible to have:
// t <= U || L <= t.
// if proper, whether it's possible to have... well.
// If we take x == 1, then we just need to check if we can have t <= U.
// If we take x == 0, then we just need to check if we can have L <= t.
// Therefore, we just need to check whether it's possible to have:
// t <= U || L <= t.
// The conclusion is the same whether the interval is proper or improper, so
// we check this and we're done.
return could_be_le(t, U) || could_be_ge(t, L);
}

bool CoversTernary(const IntervalSet& intervals, TernarySpan ternary) {
if (intervals.BitCount() != ternary.size()) {
return false;
}
return absl::c_any_of(intervals.Intervals(),
[&ternary](const Interval& interval) {
return CoversTernary(interval, ternary);
});
}

namespace {
enum class Tonicity : bool { Monotone, Antitone };
// What sort of behavior the argument exhibits
Expand Down
8 changes: 8 additions & 0 deletions xls/ir/interval_ops.h
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,11 @@ IntervalSet FromTernary(TernarySpan ternary, int64_t max_interval_bits = 4);
TernaryVector ExtractTernaryVector(const IntervalSet& intervals,
std::optional<Node*> source = std::nullopt);

// Determine whether the given `intervals` include any element that matches the
// given `ternary` span.
bool CoversTernary(const Interval& interval, TernarySpan ternary);
bool CoversTernary(const IntervalSet& intervals, TernarySpan ternary);

struct KnownBits {
Bits known_bits;
Bits known_bit_values;
Expand Down Expand Up @@ -82,6 +87,9 @@ IntervalSet UMul(const IntervalSet& a, const IntervalSet& b,
int64_t output_bitwidth);
IntervalSet UDiv(const IntervalSet& a, const IntervalSet& b);

// Encode/decode
IntervalSet Decode(const IntervalSet& a, int64_t width);

// Bit ops.
IntervalSet Not(const IntervalSet& a);
IntervalSet And(const IntervalSet& a, const IntervalSet& b);
Expand Down
23 changes: 23 additions & 0 deletions xls/ir/interval_ops_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@
#include "xls/ir/function_builder.h"
#include "xls/ir/interval.h"
#include "xls/ir/interval_set.h"
#include "xls/ir/interval_test_utils.h"
#include "xls/ir/ir_test_base.h"
#include "xls/ir/nodes.h"
#include "xls/ir/package.h"
Expand Down Expand Up @@ -868,5 +869,27 @@ FUZZ_TEST(MinimizeIntervalsTest, MinimizeIntervalsGeneratesSuperset)
fuzztest::NonNegative<int64_t>())),
fuzztest::InRange<int64_t>(1, 256));

void CoversTernaryWorksForIntervals(const Interval& interval,
TernarySpan ternary) {
EXPECT_EQ(interval_ops::CoversTernary(interval, ternary),
interval.ForEachElement([&](const Bits& element) {
return ternary ==
ternary_ops::Intersection(
ternary_ops::BitsToTernary(element), ternary);
}))
<< "interval: "
<< absl::StrFormat("[%s, %s]", interval.LowerBound().ToDebugString(),
interval.UpperBound().ToDebugString())
<< ", ternary: " << ToString(ternary);
}
FUZZ_TEST(IntervalOpsFuzzTest, CoversTernaryWorksForIntervals)
.WithDomains(ArbitraryInterval(8),
fuzztest::VectorOf(fuzztest::ElementOf({
TernaryValue::kKnownZero,
TernaryValue::kKnownOne,
TernaryValue::kUnknown,
}))
.WithSize(8));

} // namespace
} // namespace xls::interval_ops
21 changes: 12 additions & 9 deletions xls/ir/interval_set_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -272,16 +272,19 @@ TEST(IntervalTest, Size) {
}

TEST(IntervalTest, IsTrueWhenMaskWith) {
IntervalSet example(3);
example.AddInterval(MakeInterval(0, 0, 3));
for (int64_t value = 0; value < 8; ++value) {
EXPECT_FALSE(example.IsTrueWhenMaskWith(UBits(value, 3)));
IntervalSet example(4);
example.AddInterval(MakeInterval(0, 0, 4));
for (int64_t value = 0; value < 16; ++value) {
EXPECT_FALSE(example.IsTrueWhenMaskWith(UBits(value, 4)));
}
example.AddInterval(MakeInterval(2, 4, 4));
EXPECT_FALSE(example.IsTrueWhenMaskWith(UBits(0, 4)));
for (int64_t value = 1; value < 8; ++value) {
EXPECT_TRUE(example.IsTrueWhenMaskWith(UBits(value, 4)));
}
example.AddInterval(MakeInterval(2, 4, 3));
EXPECT_FALSE(example.IsTrueWhenMaskWith(UBits(0, 3)));
EXPECT_FALSE(example.IsTrueWhenMaskWith(UBits(1, 3)));
for (int64_t value = 2; value < 8; ++value) {
EXPECT_TRUE(example.IsTrueWhenMaskWith(UBits(value, 3)));
EXPECT_FALSE(example.IsTrueWhenMaskWith(UBits(8, 4)));
for (int64_t value = 9; value < 16; ++value) {
EXPECT_TRUE(example.IsTrueWhenMaskWith(UBits(value, 4)));
}
}

Expand Down
47 changes: 30 additions & 17 deletions xls/ir/interval_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
#include "xls/common/status/matchers.h"
#include "xls/ir/bits.h"
#include "xls/ir/bits_ops.h"
#include "xls/ir/bits_test_utils.h"
#include "xls/ir/interval_test_utils.h"

using ::testing::ElementsAre;
Expand Down Expand Up @@ -373,24 +374,15 @@ TEST(IntervalTest, NonZeroStartingValueIsTrueWhenMaskWith) {

// Test the IsTrueWhenMaskWith with an interval starting at zero.
TEST(IntervalTest, ZeroStartingValueIsTrueWhenMaskWith) {
Interval interval(UBits(0, 3), UBits(4, 3));
Interval interval(UBits(0, 4), UBits(4, 4));

EXPECT_FALSE(interval.IsTrueWhenAndWith(UBits(0, 3)));
EXPECT_FALSE(interval.IsTrueWhenAndWith(UBits(1, 3)));
EXPECT_FALSE(interval.IsTrueWhenAndWith(UBits(2, 3)));
EXPECT_FALSE(interval.IsTrueWhenAndWith(UBits(3, 3)));
EXPECT_TRUE(interval.IsTrueWhenAndWith(UBits(4, 3)));
EXPECT_TRUE(interval.IsTrueWhenAndWith(UBits(5, 3)));
EXPECT_TRUE(interval.IsTrueWhenAndWith(UBits(6, 3)));
EXPECT_TRUE(interval.IsTrueWhenAndWith(UBits(7, 3)));
}

// Test the IsTrueWhenMaskWith with an interval that does not overlap.
TEST(IntervalTest, NoOverlappingIntervalIsTrueWhenMaskWith) {
Interval interval(UBits(4, 3), UBits(0, 3));

for (int64_t value = 0; value < 8; ++value) {
EXPECT_FALSE(interval.IsTrueWhenAndWith(UBits(value, 3)));
EXPECT_FALSE(interval.IsTrueWhenAndWith(UBits(0, 4)));
for (int64_t value = 1; value < 7; ++value) {
EXPECT_TRUE(interval.IsTrueWhenAndWith(UBits(value, 4)));
}
EXPECT_FALSE(interval.IsTrueWhenAndWith(UBits(8, 4)));
for (int64_t value = 9; value < 16; ++value) {
EXPECT_TRUE(interval.IsTrueWhenAndWith(UBits(value, 4)));
}
}

Expand All @@ -405,6 +397,27 @@ TEST(IntervalTest, OverlappingBitsIsTrueWhenMaskWith) {
}
}

// Test the IsTrueWhenMaskWith with an interval containing overlapping bits, but
// not overlapping for every bit at the ends of the interval.
TEST(IntervalTest, OverlappingBitsIsTrueWhenMaskWith2) {
Interval interval(UBits(2, 3), UBits(6, 3));

EXPECT_FALSE(interval.IsTrueWhenAndWith(UBits(0, 3)));
for (int64_t value = 1; value < 8; ++value) {
EXPECT_TRUE(interval.IsTrueWhenAndWith(UBits(value, 3)))
<< "didn't match with value: " << value;
}
}

void IsTrueWhenAndWith(const Interval& interval, const Bits& value) {
EXPECT_EQ(interval.IsTrueWhenAndWith(value),
interval.ForEachElement([&](const Bits& bits) -> bool {
return bits_ops::OrReduce(bits_ops::And(bits, value)).IsAllOnes();
}));
}
FUZZ_TEST(IntervalFuzzTest, IsTrueWhenAndWith)
.WithDomains(ProperInterval(12), ArbitraryBits(12));

TEST(IntervalTest, Covers) {
Bits thirty_two = Bits::PowerOfTwo(5, 12);
Bits sixty_four = Bits::PowerOfTwo(6, 12);
Expand Down
16 changes: 16 additions & 0 deletions xls/ir/ternary.cc
Original file line number Diff line number Diff line change
Expand Up @@ -221,6 +221,22 @@ TernaryVector Intersection(TernarySpan lhs, TernarySpan rhs) {
return result;
}

bool IsCompatible(TernarySpan pattern, const Bits& bits) {
if (pattern.size() != bits.bit_count()) {
return false;
}

for (int64_t i = 0; i < pattern.size(); ++i) {
if (pattern[i] == TernaryValue::kUnknown) {
continue;
}
if (bits.Get(i) != (pattern[i] == TernaryValue::kKnownOne)) {
return false;
}
}
return true;
}

void UpdateWithIntersection(TernaryVector& lhs, TernarySpan rhs) {
CHECK_EQ(lhs.size(), rhs.size());

Expand Down
3 changes: 3 additions & 0 deletions xls/ir/ternary.h
Original file line number Diff line number Diff line change
Expand Up @@ -102,6 +102,9 @@ absl::Status UpdateWithUnion(TernaryVector& lhs, TernarySpan rhs);
// lengths.
TernaryVector Intersection(TernarySpan lhs, TernarySpan rhs);

// Returns true if `bits` is a possible value for `pattern`.
bool IsCompatible(TernarySpan pattern, const Bits& bits);

// Updates `lhs`, turning it into a vector of bits known to have the same value
// in both `lhs` and `rhs`. CHECK fails if `lhs` and `rhs` have different
// lengths.
Expand Down
Loading

0 comments on commit d8b09c6

Please sign in to comment.