From 32145566fbe79bda13c0a849750813e0303bc97f Mon Sep 17 00:00:00 2001 From: Alex Light Date: Thu, 3 Aug 2023 15:46:15 -0700 Subject: [PATCH] Extract IntervalSet -> basic Ternary code to common location This is/will be useful in other places. Put it in a dedicated interval_ops:: namespace. PiperOrigin-RevId: 553612501 --- xls/ir/BUILD | 32 ++++++++++++ xls/ir/interval_ops.cc | 50 +++++++++++++++++++ xls/ir/interval_ops.h | 44 +++++++++++++++++ xls/ir/interval_ops_test.cc | 84 ++++++++++++++++++++++++++++++++ xls/passes/BUILD | 1 + xls/passes/range_query_engine.cc | 25 ++++------ 6 files changed, 221 insertions(+), 15 deletions(-) create mode 100644 xls/ir/interval_ops.cc create mode 100644 xls/ir/interval_ops.h create mode 100644 xls/ir/interval_ops_test.cc diff --git a/xls/ir/BUILD b/xls/ir/BUILD index e544817bb9..65b5e70660 100644 --- a/xls/ir/BUILD +++ b/xls/ir/BUILD @@ -176,6 +176,38 @@ cc_test( ], ) +cc_library( + name = "interval_ops", + srcs = ["interval_ops.cc"], + hdrs = ["interval_ops.h"], + deps = [ + ":bits", + ":bits_ops", + ":interval_set", + ":ir", + ":ternary", + ], +) + +cc_test( + name = "interval_ops_test", + size = "small", + srcs = [ + "interval_ops_test.cc", + ], + deps = [ + ":bits", + ":interval", + ":interval_ops", + ":interval_set", + ":ternary", + "//xls/common:xls_gunit", + "//xls/common:xls_gunit_main", + "@com_google_absl//absl/algorithm:container", + "@com_google_absl//absl/types:span", + ], +) + cc_library( name = "interval_set", srcs = ["interval_set.cc"], diff --git a/xls/ir/interval_ops.cc b/xls/ir/interval_ops.cc new file mode 100644 index 0000000000..cd6a0247b4 --- /dev/null +++ b/xls/ir/interval_ops.cc @@ -0,0 +1,50 @@ +// Copyright 2023 The XLS Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "xls/ir/interval_ops.h" + +#include +#include + +#include "xls/ir/bits.h" +#include "xls/ir/bits_ops.h" +#include "xls/ir/node.h" +#include "xls/ir/ternary.h" + +namespace xls::interval_ops { + +TernaryVector ExtractTernaryVector(const IntervalSet& intervals, + std::optional source) { + KnownBits bits = ExtractKnownBits(intervals, source); + return ternary_ops::FromKnownBits(bits.known_bits, bits.known_bit_values); +} + +KnownBits ExtractKnownBits(const IntervalSet& intervals, + std::optional source) { + XLS_CHECK(intervals.IsNormalized()) + << (source.has_value() ? source.value()->ToString() : ""); + XLS_CHECK(!intervals.Intervals().empty()) + << (source.has_value() ? source.value()->ToString() : ""); + Bits lcp = bits_ops::LongestCommonPrefixMSB( + {intervals.Intervals().front().LowerBound(), + intervals.Intervals().back().UpperBound()}); + int64_t size = intervals.BitCount(); + Bits remainder = Bits(size - lcp.bit_count()); + return KnownBits{ + .known_bits = + bits_ops::Concat({Bits::AllOnes(lcp.bit_count()), remainder}), + .known_bit_values = bits_ops::Concat({lcp, remainder}), + }; +} +} // namespace xls::interval_ops diff --git a/xls/ir/interval_ops.h b/xls/ir/interval_ops.h new file mode 100644 index 0000000000..d6cc7c65be --- /dev/null +++ b/xls/ir/interval_ops.h @@ -0,0 +1,44 @@ +// Copyright 2023 The XLS Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef XLS_IR_INTERVAL_OPS_H_ +#define XLS_IR_INTERVAL_OPS_H_ + +#include + +#include "xls/ir/bits.h" +#include "xls/ir/interval_set.h" +#include "xls/ir/node.h" +#include "xls/ir/ternary.h" + +namespace xls::interval_ops { + +// Extract the ternary vector embedded in the interval-sets. +// TODO(allight): Currently this only searches for the longest common MSB +// prefix. More complex analysis is possible though of questionable usefulness +// given they can be extracted by other analyses. +TernaryVector ExtractTernaryVector(const IntervalSet& intervals, + std::optional source = std::nullopt); + +struct KnownBits { + Bits known_bits; + Bits known_bit_values; +}; + +KnownBits ExtractKnownBits(const IntervalSet& intervals, + std::optional source = std::nullopt); + +} // namespace xls::interval_ops + +#endif // XLS_IR_INTERVAL_OPS_H_ diff --git a/xls/ir/interval_ops_test.cc b/xls/ir/interval_ops_test.cc new file mode 100644 index 0000000000..42af830dc9 --- /dev/null +++ b/xls/ir/interval_ops_test.cc @@ -0,0 +1,84 @@ +// Copyright 2023 The XLS Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "xls/ir/interval_ops.h" + +#include + +#include "gmock/gmock.h" +#include "gtest/gtest.h" +#include "absl/algorithm/container.h" +#include "absl/types/span.h" +#include "xls/ir/bits.h" +#include "xls/ir/interval.h" +#include "xls/ir/interval_set.h" +#include "xls/ir/ternary.h" + +namespace xls::interval_ops { + +namespace { + +IntervalSet SetOf(absl::Span intervals) { + IntervalSet is(intervals.front().BitCount()); + absl::c_for_each(intervals, [&](auto v) { is.AddInterval(v); }); + is.Normalize(); + return is; +} +TEST(IntervalOpsTest, BitsPrecise) { + IntervalSet is = SetOf({Interval::Precise(UBits(21, 8))}); + auto known = ExtractKnownBits(is); + EXPECT_EQ(known.known_bits, Bits::AllOnes(8)); + EXPECT_EQ(known.known_bit_values, UBits(21, 8)); +} +TEST(IntervalOpsTest, BitsMaximal) { + IntervalSet is = SetOf({Interval::Maximal(8)}); + auto known = ExtractKnownBits(is); + EXPECT_EQ(known.known_bits, Bits(8)); + EXPECT_EQ(known.known_bit_values, Bits(8)); +} +TEST(IntervalOpsTest, BitsHalfFull) { + IntervalSet is = SetOf({Interval::Maximal(4).ZeroExtend(8)}); + auto known = ExtractKnownBits(is); + EXPECT_EQ(known.known_bits, UBits(0xf0, 8)); + EXPECT_EQ(known.known_bit_values, Bits(8)); +} +TEST(IntervalOpsTest, MiddleOut) { + IntervalSet is = SetOf({Interval(UBits(0, 8), UBits(0x4, 8)), + Interval(UBits(0x10, 8), UBits(0x14, 8))}); + auto known = ExtractKnownBits(is); + EXPECT_EQ(known.known_bits, UBits(0xe0, 8)); + EXPECT_EQ(known.known_bit_values, Bits(8)); +} +TEST(IntervalOpsTest, MiddleOutHigh) { + IntervalSet is = SetOf({Interval(UBits(0xe0, 8), UBits(0xe4, 8)), + Interval(UBits(0xf0, 8), UBits(0xf4, 8))}); + auto known = ExtractKnownBits(is); + EXPECT_EQ(known.known_bits, UBits(0xe0, 8)); + EXPECT_EQ(known.known_bit_values, UBits(0xe0, 8)); +} +TEST(IntervalOpsTest, MiddleOutTernary) { + IntervalSet is = SetOf({Interval(UBits(0, 8), UBits(0x4, 8)), + Interval(UBits(0x10, 8), UBits(0x14, 8))}); + auto known = ExtractTernaryVector(is); + TernaryVector expected{ + TernaryValue::kUnknown, TernaryValue::kUnknown, + TernaryValue::kUnknown, TernaryValue::kUnknown, + TernaryValue::kUnknown, TernaryValue::kKnownZero, + TernaryValue::kKnownZero, TernaryValue::kKnownZero, + }; + EXPECT_EQ(known, expected); +} + +} // namespace +} // namespace xls::interval_ops diff --git a/xls/passes/BUILD b/xls/passes/BUILD index 829f5a0f49..5c473f4b08 100644 --- a/xls/passes/BUILD +++ b/xls/passes/BUILD @@ -533,6 +533,7 @@ cc_library( "//xls/ir:bits", "//xls/ir:bits_ops", "//xls/ir:interval", + "//xls/ir:interval_ops", "//xls/ir:interval_set", "//xls/ir:value_helpers", ], diff --git a/xls/passes/range_query_engine.cc b/xls/passes/range_query_engine.cc index 20b3f09c81..55daa45b5a 100644 --- a/xls/passes/range_query_engine.cc +++ b/xls/passes/range_query_engine.cc @@ -25,6 +25,8 @@ #include "xls/ir/abstract_node_evaluator.h" #include "xls/ir/bits_ops.h" #include "xls/ir/dfs_visitor.h" +#include "xls/ir/function_base.h" +#include "xls/ir/interval_ops.h" #include "xls/ir/node_iterator.h" #include "xls/ir/nodes.h" #include "xls/ir/value_helpers.h" @@ -159,7 +161,7 @@ class RangeQueryVisitor : public DfsVisitor { // If the given interval sets are disjoint, returns `false`. // In all other cases, returns `std::nullopt`. static std::optional AnalyzeEq(const IntervalSet& lhs, - const IntervalSet& rhs); + const IntervalSet& rhs); // Analyze whether elements of the two given interval sets must be less than, // must not be less than, or may be either. @@ -170,7 +172,7 @@ class RangeQueryVisitor : public DfsVisitor { // returns `false`. // In all other cases, returns `std::nullopt`. static std::optional AnalyzeLt(const IntervalSet& lhs, - const IntervalSet& rhs); + const IntervalSet& rhs); // An interval set covering exactly the binary representation of `false`. static IntervalSet FalseIntervalSet(); @@ -401,18 +403,11 @@ void RangeQueryEngine::SetIntervalSetTree( IntervalSetTree new_ist = LeafTypeTree::Zip( IntervalSet::Intersect, old_ist, interval_sets); - int64_t size = node->GetType()->GetFlatBitCount(); if (node->GetType()->IsBits()) { - IntervalSet interval_set = new_ist.Get({}); - XLS_CHECK(interval_set.IsNormalized()); - XLS_CHECK(!interval_set.Intervals().empty()) << node->ToString(); - Bits lcs = bits_ops::LongestCommonPrefixMSB( - {interval_set.Intervals().front().LowerBound(), - interval_set.Intervals().back().UpperBound()}); - known_bits_[node] = bits_ops::Concat( - {Bits::AllOnes(lcs.bit_count()), Bits(size - lcs.bit_count())}); - known_bit_values_[node] = - bits_ops::Concat({lcs, Bits(size - lcs.bit_count())}); + interval_ops::KnownBits bits = + interval_ops::ExtractKnownBits(new_ist.Get({}), /*source=*/node); + known_bits_[node] = bits.known_bits; + known_bit_values_[node] = bits.known_bit_values; } interval_sets_[node] = new_ist; } @@ -561,7 +556,7 @@ absl::Status RangeQueryVisitor::HandleMonotoneAntitoneBinOp( } std::optional RangeQueryVisitor::AnalyzeEq(const IntervalSet& lhs, - const IntervalSet& rhs) { + const IntervalSet& rhs) { XLS_CHECK(lhs.IsNormalized()); XLS_CHECK(rhs.IsNormalized()); @@ -586,7 +581,7 @@ std::optional RangeQueryVisitor::AnalyzeEq(const IntervalSet& lhs, } std::optional RangeQueryVisitor::AnalyzeLt(const IntervalSet& lhs, - const IntervalSet& rhs) { + const IntervalSet& rhs) { if (std::optional lhs_hull = lhs.ConvexHull()) { if (std::optional rhs_hull = rhs.ConvexHull()) { if (Interval::Disjoint(*lhs_hull, *rhs_hull)) {