Skip to content

Commit

Permalink
Extract IntervalSet -> basic Ternary code to common location
Browse files Browse the repository at this point in the history
This is/will be useful in other places. Put it in a dedicated interval_ops:: namespace.

PiperOrigin-RevId: 553612501
  • Loading branch information
allight authored and copybara-github committed Aug 3, 2023
1 parent f5621b3 commit 3214556
Show file tree
Hide file tree
Showing 6 changed files with 221 additions and 15 deletions.
32 changes: 32 additions & 0 deletions xls/ir/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -176,6 +176,38 @@ cc_test(
],
)

cc_library(
name = "interval_ops",
srcs = ["interval_ops.cc"],
hdrs = ["interval_ops.h"],
deps = [
":bits",
":bits_ops",
":interval_set",
":ir",
":ternary",
],
)

cc_test(
name = "interval_ops_test",
size = "small",
srcs = [
"interval_ops_test.cc",
],
deps = [
":bits",
":interval",
":interval_ops",
":interval_set",
":ternary",
"//xls/common:xls_gunit",
"//xls/common:xls_gunit_main",
"@com_google_absl//absl/algorithm:container",
"@com_google_absl//absl/types:span",
],
)

cc_library(
name = "interval_set",
srcs = ["interval_set.cc"],
Expand Down
50 changes: 50 additions & 0 deletions xls/ir/interval_ops.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
// Copyright 2023 The XLS Authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "xls/ir/interval_ops.h"

#include <cstdint>
#include <optional>

#include "xls/ir/bits.h"
#include "xls/ir/bits_ops.h"
#include "xls/ir/node.h"
#include "xls/ir/ternary.h"

namespace xls::interval_ops {

TernaryVector ExtractTernaryVector(const IntervalSet& intervals,
std::optional<Node*> source) {
KnownBits bits = ExtractKnownBits(intervals, source);
return ternary_ops::FromKnownBits(bits.known_bits, bits.known_bit_values);
}

KnownBits ExtractKnownBits(const IntervalSet& intervals,
std::optional<Node*> source) {
XLS_CHECK(intervals.IsNormalized())
<< (source.has_value() ? source.value()->ToString() : "");
XLS_CHECK(!intervals.Intervals().empty())
<< (source.has_value() ? source.value()->ToString() : "");
Bits lcp = bits_ops::LongestCommonPrefixMSB(
{intervals.Intervals().front().LowerBound(),
intervals.Intervals().back().UpperBound()});
int64_t size = intervals.BitCount();
Bits remainder = Bits(size - lcp.bit_count());
return KnownBits{
.known_bits =
bits_ops::Concat({Bits::AllOnes(lcp.bit_count()), remainder}),
.known_bit_values = bits_ops::Concat({lcp, remainder}),
};
}
} // namespace xls::interval_ops
44 changes: 44 additions & 0 deletions xls/ir/interval_ops.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
// Copyright 2023 The XLS Authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#ifndef XLS_IR_INTERVAL_OPS_H_
#define XLS_IR_INTERVAL_OPS_H_

#include <optional>

#include "xls/ir/bits.h"
#include "xls/ir/interval_set.h"
#include "xls/ir/node.h"
#include "xls/ir/ternary.h"

namespace xls::interval_ops {

// Extract the ternary vector embedded in the interval-sets.
// TODO(allight): Currently this only searches for the longest common MSB
// prefix. More complex analysis is possible though of questionable usefulness
// given they can be extracted by other analyses.
TernaryVector ExtractTernaryVector(const IntervalSet& intervals,
std::optional<Node*> source = std::nullopt);

struct KnownBits {
Bits known_bits;
Bits known_bit_values;
};

KnownBits ExtractKnownBits(const IntervalSet& intervals,
std::optional<Node*> source = std::nullopt);

} // namespace xls::interval_ops

#endif // XLS_IR_INTERVAL_OPS_H_
84 changes: 84 additions & 0 deletions xls/ir/interval_ops_test.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,84 @@
// Copyright 2023 The XLS Authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "xls/ir/interval_ops.h"

#include <vector>

#include "gmock/gmock.h"
#include "gtest/gtest.h"
#include "absl/algorithm/container.h"
#include "absl/types/span.h"
#include "xls/ir/bits.h"
#include "xls/ir/interval.h"
#include "xls/ir/interval_set.h"
#include "xls/ir/ternary.h"

namespace xls::interval_ops {

namespace {

IntervalSet SetOf(absl::Span<const Interval> intervals) {
IntervalSet is(intervals.front().BitCount());
absl::c_for_each(intervals, [&](auto v) { is.AddInterval(v); });
is.Normalize();
return is;
}
TEST(IntervalOpsTest, BitsPrecise) {
IntervalSet is = SetOf({Interval::Precise(UBits(21, 8))});
auto known = ExtractKnownBits(is);
EXPECT_EQ(known.known_bits, Bits::AllOnes(8));
EXPECT_EQ(known.known_bit_values, UBits(21, 8));
}
TEST(IntervalOpsTest, BitsMaximal) {
IntervalSet is = SetOf({Interval::Maximal(8)});
auto known = ExtractKnownBits(is);
EXPECT_EQ(known.known_bits, Bits(8));
EXPECT_EQ(known.known_bit_values, Bits(8));
}
TEST(IntervalOpsTest, BitsHalfFull) {
IntervalSet is = SetOf({Interval::Maximal(4).ZeroExtend(8)});
auto known = ExtractKnownBits(is);
EXPECT_EQ(known.known_bits, UBits(0xf0, 8));
EXPECT_EQ(known.known_bit_values, Bits(8));
}
TEST(IntervalOpsTest, MiddleOut) {
IntervalSet is = SetOf({Interval(UBits(0, 8), UBits(0x4, 8)),
Interval(UBits(0x10, 8), UBits(0x14, 8))});
auto known = ExtractKnownBits(is);
EXPECT_EQ(known.known_bits, UBits(0xe0, 8));
EXPECT_EQ(known.known_bit_values, Bits(8));
}
TEST(IntervalOpsTest, MiddleOutHigh) {
IntervalSet is = SetOf({Interval(UBits(0xe0, 8), UBits(0xe4, 8)),
Interval(UBits(0xf0, 8), UBits(0xf4, 8))});
auto known = ExtractKnownBits(is);
EXPECT_EQ(known.known_bits, UBits(0xe0, 8));
EXPECT_EQ(known.known_bit_values, UBits(0xe0, 8));
}
TEST(IntervalOpsTest, MiddleOutTernary) {
IntervalSet is = SetOf({Interval(UBits(0, 8), UBits(0x4, 8)),
Interval(UBits(0x10, 8), UBits(0x14, 8))});
auto known = ExtractTernaryVector(is);
TernaryVector expected{
TernaryValue::kUnknown, TernaryValue::kUnknown,
TernaryValue::kUnknown, TernaryValue::kUnknown,
TernaryValue::kUnknown, TernaryValue::kKnownZero,
TernaryValue::kKnownZero, TernaryValue::kKnownZero,
};
EXPECT_EQ(known, expected);
}

} // namespace
} // namespace xls::interval_ops
1 change: 1 addition & 0 deletions xls/passes/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -533,6 +533,7 @@ cc_library(
"//xls/ir:bits",
"//xls/ir:bits_ops",
"//xls/ir:interval",
"//xls/ir:interval_ops",
"//xls/ir:interval_set",
"//xls/ir:value_helpers",
],
Expand Down
25 changes: 10 additions & 15 deletions xls/passes/range_query_engine.cc
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,8 @@
#include "xls/ir/abstract_node_evaluator.h"
#include "xls/ir/bits_ops.h"
#include "xls/ir/dfs_visitor.h"
#include "xls/ir/function_base.h"
#include "xls/ir/interval_ops.h"
#include "xls/ir/node_iterator.h"
#include "xls/ir/nodes.h"
#include "xls/ir/value_helpers.h"
Expand Down Expand Up @@ -159,7 +161,7 @@ class RangeQueryVisitor : public DfsVisitor {
// If the given interval sets are disjoint, returns `false`.
// In all other cases, returns `std::nullopt`.
static std::optional<bool> AnalyzeEq(const IntervalSet& lhs,
const IntervalSet& rhs);
const IntervalSet& rhs);

// Analyze whether elements of the two given interval sets must be less than,
// must not be less than, or may be either.
Expand All @@ -170,7 +172,7 @@ class RangeQueryVisitor : public DfsVisitor {
// returns `false`.
// In all other cases, returns `std::nullopt`.
static std::optional<bool> AnalyzeLt(const IntervalSet& lhs,
const IntervalSet& rhs);
const IntervalSet& rhs);

// An interval set covering exactly the binary representation of `false`.
static IntervalSet FalseIntervalSet();
Expand Down Expand Up @@ -401,18 +403,11 @@ void RangeQueryEngine::SetIntervalSetTree(
IntervalSetTree new_ist =
LeafTypeTree<IntervalSet>::Zip<IntervalSet, IntervalSet>(
IntervalSet::Intersect, old_ist, interval_sets);
int64_t size = node->GetType()->GetFlatBitCount();
if (node->GetType()->IsBits()) {
IntervalSet interval_set = new_ist.Get({});
XLS_CHECK(interval_set.IsNormalized());
XLS_CHECK(!interval_set.Intervals().empty()) << node->ToString();
Bits lcs = bits_ops::LongestCommonPrefixMSB(
{interval_set.Intervals().front().LowerBound(),
interval_set.Intervals().back().UpperBound()});
known_bits_[node] = bits_ops::Concat(
{Bits::AllOnes(lcs.bit_count()), Bits(size - lcs.bit_count())});
known_bit_values_[node] =
bits_ops::Concat({lcs, Bits(size - lcs.bit_count())});
interval_ops::KnownBits bits =
interval_ops::ExtractKnownBits(new_ist.Get({}), /*source=*/node);
known_bits_[node] = bits.known_bits;
known_bit_values_[node] = bits.known_bit_values;
}
interval_sets_[node] = new_ist;
}
Expand Down Expand Up @@ -561,7 +556,7 @@ absl::Status RangeQueryVisitor::HandleMonotoneAntitoneBinOp(
}

std::optional<bool> RangeQueryVisitor::AnalyzeEq(const IntervalSet& lhs,
const IntervalSet& rhs) {
const IntervalSet& rhs) {
XLS_CHECK(lhs.IsNormalized());
XLS_CHECK(rhs.IsNormalized());

Expand All @@ -586,7 +581,7 @@ std::optional<bool> RangeQueryVisitor::AnalyzeEq(const IntervalSet& lhs,
}

std::optional<bool> RangeQueryVisitor::AnalyzeLt(const IntervalSet& lhs,
const IntervalSet& rhs) {
const IntervalSet& rhs) {
if (std::optional<Interval> lhs_hull = lhs.ConvexHull()) {
if (std::optional<Interval> rhs_hull = rhs.ConvexHull()) {
if (Interval::Disjoint(*lhs_hull, *rhs_hull)) {
Expand Down

0 comments on commit 3214556

Please sign in to comment.