From 9d96ece855b02464d8d0aac5f36af955ff939892 Mon Sep 17 00:00:00 2001 From: Eric Astor Date: Mon, 10 Jul 2023 11:43:44 -0700 Subject: [PATCH] Add utility methods to create N-ary ops when needed, and simpler nodes when not Supports AND and NOR to start, special-casing the empty and single-operand cases. PiperOrigin-RevId: 546936766 --- xls/codegen/block_conversion.cc | 95 ++++--------------- xls/ir/BUILD | 6 ++ xls/ir/node_util.cc | 47 +++++++++ xls/ir/node_util.h | 20 ++++ xls/ir/node_util_test.cc | 78 +++++++++++++++ ..._main_test__test_add_idle_output_proc.vtxt | 2 +- 6 files changed, 169 insertions(+), 79 deletions(-) diff --git a/xls/codegen/block_conversion.cc b/xls/codegen/block_conversion.cc index 0764f81df4..6fbde97115 100644 --- a/xls/codegen/block_conversion.cc +++ b/xls/codegen/block_conversion.cc @@ -889,25 +889,10 @@ static absl::StatusOr> MakeInputReadyPortsForOutputChannels( // And reduce all the active ready signals. This signal is true iff all // active outputs are ready. - std::string all_active_outputs_ready_name = - PipelineSignalName("all_active_outputs_ready", stage); - Node* all_active_outputs_ready; - if (active_readys.empty()) { - XLS_ASSIGN_OR_RETURN( - all_active_outputs_ready, - block->MakeNodeWithName( - SourceInfo(), Value(UBits(1, 1)), all_active_outputs_ready_name)); - } else if (active_readys.size() == 1) { - // Don't make a new named node if there is only one active_valid signal, - // just use the signal directly. - all_active_outputs_ready = active_readys[0]; - } else { - XLS_ASSIGN_OR_RETURN( - all_active_outputs_ready, - block->MakeNodeWithName(SourceInfo(), active_readys, Op::kAnd, - all_active_outputs_ready_name)); - } - + XLS_ASSIGN_OR_RETURN( + Node * all_active_outputs_ready, + NaryAndIfNeeded(block, active_readys, + PipelineSignalName("all_active_outputs_ready", stage))); result.push_back(all_active_outputs_ready); } @@ -973,25 +958,10 @@ static absl::StatusOr> MakeInputValidPortsForInputChannels( // And reduce all the active valid signals. This signal is true iff all // active inputs are valid. - Node* all_active_inputs_valid; - std::string all_active_inputs_valid_name = - PipelineSignalName("all_active_inputs_valid", stage); - if (active_valids.empty()) { - XLS_ASSIGN_OR_RETURN( - all_active_inputs_valid, - block->MakeNodeWithName( - SourceInfo(), Value(UBits(1, 1)), all_active_inputs_valid_name)); - } else if (active_valids.size() == 1) { - // Don't make a new named node if there is only one active_valid signal, - // just use the signal directly. - all_active_inputs_valid = active_valids[0]; - } else { - XLS_ASSIGN_OR_RETURN( - all_active_inputs_valid, - block->MakeNodeWithName(SourceInfo(), active_valids, Op::kAnd, - all_active_inputs_valid_name)); - } - + XLS_ASSIGN_OR_RETURN( + Node * all_active_inputs_valid, + NaryAndIfNeeded(block, active_valids, + PipelineSignalName("all_active_inputs_valid", stage))); result.push_back(all_active_inputs_valid); } @@ -1088,24 +1058,10 @@ static absl::StatusOr> MakeValidNodesForInputStates( // And reduce all the active valid signals. This signal is true iff all // active states are valid. - Node* all_active_states_valid; - std::string all_active_states_valid_name = - PipelineSignalName("all_active_states_valid", stage); - if (active_valids.empty()) { - XLS_ASSIGN_OR_RETURN( - all_active_states_valid, - block->MakeNodeWithName( - SourceInfo(), Value(UBits(1, 1)), all_active_states_valid_name)); - } else if (active_valids.size() == 1) { - // Don't make a new named node if there is only one active_valid signal, - // just use the signal directly. - all_active_states_valid = active_valids[0]; - } else { - XLS_ASSIGN_OR_RETURN( - all_active_states_valid, - block->MakeNodeWithName(SourceInfo(), active_valids, Op::kAnd, - all_active_states_valid_name)); - } + XLS_ASSIGN_OR_RETURN( + Node * all_active_states_valid, + NaryAndIfNeeded(block, active_valids, + PipelineSignalName("all_active_states_valid", stage))); result.push_back(all_active_states_valid); } @@ -1147,25 +1103,10 @@ static absl::StatusOr> MakeReadyNodesForOutputStates( // And reduce all the active ready signals. This signal is true iff all // active states are ready. - std::string all_active_states_ready_name = - PipelineSignalName("all_active_states_ready", stage); - Node* all_active_states_ready; - if (active_readys.empty()) { - XLS_ASSIGN_OR_RETURN( - all_active_states_ready, - block->MakeNodeWithName( - SourceInfo(), Value(UBits(1, 1)), all_active_states_ready_name)); - } else if (active_readys.size() == 1) { - // Don't make a new named node if there is only one active_valid signal, - // just use the signal directly. - all_active_states_ready = active_readys[0]; - } else { - XLS_ASSIGN_OR_RETURN( - all_active_states_ready, - block->MakeNodeWithName(SourceInfo(), active_readys, Op::kAnd, - all_active_states_ready_name)); - } - + XLS_ASSIGN_OR_RETURN( + Node * all_active_states_ready, + NaryAndIfNeeded(block, active_readys, + PipelineSignalName("all_active_states_ready", stage))); result.push_back(all_active_states_ready); } @@ -1794,9 +1735,7 @@ static absl::Status AddIdleOutput(std::vector valid_nodes, } } - XLS_ASSIGN_OR_RETURN( - Node * idle_signal, - block->MakeNode(SourceInfo(), valid_nodes, Op::kNor)); + XLS_ASSIGN_OR_RETURN(Node * idle_signal, NaryNorIfNeeded(block, valid_nodes)); XLS_ASSIGN_OR_RETURN(streaming_io.idle_port, block->AddOutputPort("idle", idle_signal)); diff --git a/xls/ir/BUILD b/xls/ir/BUILD index 03baff43ce..3ddfafa41d 100644 --- a/xls/ir/BUILD +++ b/xls/ir/BUILD @@ -545,6 +545,7 @@ cc_test( srcs = ["node_util_test.cc"], data = glob(["testdata/*.ir"]), deps = [ + ":bits", ":function_builder", ":ir", ":ir_matcher", @@ -938,13 +939,18 @@ cc_library( srcs = ["node_util.cc"], hdrs = ["node_util.h"], deps = [ + ":bits", ":channel", ":ir", + ":op", + ":source_location", + ":value", ":value_helpers", "//xls/common/logging", "//xls/common/status:ret_check", "//xls/common/status:status_macros", "@com_google_absl//absl/algorithm:container", + "@com_google_absl//absl/container:btree", "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/container:flat_hash_set", "@com_google_absl//absl/status", diff --git a/xls/ir/node_util.cc b/xls/ir/node_util.cc index 23dd6aab0c..25e4b93b20 100644 --- a/xls/ir/node_util.cc +++ b/xls/ir/node_util.cc @@ -13,16 +13,28 @@ // limitations under the License. #include "xls/ir/node_util.h" + #include #include +#include +#include +#include #include "absl/algorithm/container.h" +#include "absl/container/btree_set.h" +#include "absl/container/flat_hash_set.h" #include "absl/status/status.h" #include "absl/status/statusor.h" +#include "absl/types/span.h" #include "xls/common/logging/logging.h" #include "xls/common/status/ret_check.h" #include "xls/common/status/status_macros.h" +#include "xls/ir/bits.h" #include "xls/ir/function_base.h" +#include "xls/ir/node.h" +#include "xls/ir/op.h" +#include "xls/ir/source_location.h" +#include "xls/ir/value.h" #include "xls/ir/value_helpers.h" namespace xls { @@ -123,6 +135,41 @@ absl::StatusOr OrReduceLeading(Node* node, int64_t bit_count) { return f->MakeNode(node->loc(), bits, Op::kOr); } +absl::StatusOr NaryAndIfNeeded(FunctionBase* f, + absl::Span operands, + std::string_view name, + const SourceInfo& source_info) { + if (operands.empty()) { + return f->MakeNodeWithName(source_info, Value(UBits(1, 1)), name); + } + + absl::btree_set unique_operands(operands.begin(), + operands.end()); + if (unique_operands.size() == 1) { + return operands[0]; + } + return f->MakeNodeWithName( + source_info, + std::vector(unique_operands.begin(), unique_operands.end()), + Op::kAnd, name); +} + +absl::StatusOr NaryNorIfNeeded(FunctionBase* f, + absl::Span operands, + std::string_view name, + const SourceInfo& source_info) { + XLS_RET_CHECK(!operands.empty()); + absl::btree_set unique_operands(operands.begin(), + operands.end()); + if (unique_operands.size() == 1) { + return f->MakeNodeWithName(source_info, operands[0], Op::kNot, name); + } + return f->MakeNodeWithName( + source_info, + std::vector(unique_operands.begin(), unique_operands.end()), + Op::kNor, name); +} + bool IsUnsignedCompare(Node* node) { switch (node->op()) { case Op::kULe: diff --git a/xls/ir/node_util.h b/xls/ir/node_util.h index 00640e90ea..6576782a96 100644 --- a/xls/ir/node_util.h +++ b/xls/ir/node_util.h @@ -17,12 +17,16 @@ #ifndef XLS_IR_NODE_UTIL_H_ #define XLS_IR_NODE_UTIL_H_ +#include + #include "absl/container/flat_hash_map.h" #include "absl/container/flat_hash_set.h" #include "absl/status/statusor.h" #include "absl/types/span.h" #include "xls/ir/channel.h" #include "xls/ir/nodes.h" +#include "xls/ir/source_location.h" +#include "xls/ir/value.h" namespace xls { @@ -130,6 +134,22 @@ absl::StatusOr AndReduceTrailing(Node* node, int64_t bit_count); // values. absl::StatusOr OrReduceLeading(Node* node, int64_t bit_count); +// And-reduce the given operands if needed. If there are 2+ operands, returns an +// N-ary AND of them; if there is 1 operand, returns that operand; and if there +// are no operands, returns a literal 1. +absl::StatusOr NaryAndIfNeeded(FunctionBase* f, + absl::Span operands, + std::string_view name = "", + const SourceInfo& source_info = {}); + +// Nor-reduce the given operands if needed. If there are 2+ operands, returns an +// N-ary NOR of them; if there is 1 operand, returns a negation of it; and if +// there are no operands, fails. +absl::StatusOr NaryNorIfNeeded(FunctionBase* f, + absl::Span operands, + std::string_view name = "", + const SourceInfo& source_info = {}); + // Returns whether the given node is a signed/unsigned comparison operation (for // example, ULe or SGt). bool IsUnsignedCompare(Node* node); diff --git a/xls/ir/node_util_test.cc b/xls/ir/node_util_test.cc index 80b1b812b4..a10955d428 100644 --- a/xls/ir/node_util_test.cc +++ b/xls/ir/node_util_test.cc @@ -16,6 +16,7 @@ #include #include +#include #include "gmock/gmock.h" #include "gtest/gtest.h" @@ -25,6 +26,7 @@ #include "xls/common/status/matchers.h" #include "xls/common/status/ret_check.h" #include "xls/common/status/status_macros.h" +#include "xls/ir/bits.h" #include "xls/ir/function.h" #include "xls/ir/function_builder.h" #include "xls/ir/ir_matcher.h" @@ -37,9 +39,11 @@ namespace m = ::xls::op_matchers; namespace xls { namespace { +using status_testing::IsOk; using status_testing::IsOkAndHolds; using status_testing::StatusIs; using ::testing::HasSubstr; +using ::testing::Not; class Result { public: @@ -299,5 +303,79 @@ TEST_F(NodeUtilTest, ReplaceTupleIndicesFailsWithDependentReplacement) { HasSubstr("Replacement index 1 (lhs) depends on"))); } +TEST_F(NodeUtilTest, NaryAndWithNoInputs) { + Package p("my_package"); + FunctionBuilder b(TestName(), &p); + b.Param("a", p.GetBitsType(1)); + XLS_ASSERT_OK_AND_ASSIGN(Function * f, b.Build()); + + EXPECT_THAT(NaryAndIfNeeded(f, {}), IsOkAndHolds(m::Literal(UBits(1, 1)))); +} + +TEST_F(NodeUtilTest, NaryAndWithOneInput) { + Package p("my_package"); + FunctionBuilder b(TestName(), &p); + BValue a = b.Param("a", p.GetBitsType(1)); + XLS_ASSERT_OK_AND_ASSIGN(Function * f, b.Build()); + + EXPECT_THAT(NaryAndIfNeeded(f, std::vector{a.node(), a.node()}), + IsOkAndHolds(m::Param("a"))); +} + +TEST_F(NodeUtilTest, NaryAndWithMultipleInputs) { + Package p("my_package"); + FunctionBuilder b(TestName(), &p); + BValue a0 = b.Param("a0", p.GetBitsType(1)); + BValue a1 = b.Param("a1", p.GetBitsType(1)); + BValue a2 = b.Param("a2", p.GetBitsType(1)); + BValue a3 = b.Param("a3", p.GetBitsType(1)); + BValue a4 = b.Param("a4", p.GetBitsType(1)); + XLS_ASSERT_OK_AND_ASSIGN(Function * f, b.Build()); + + EXPECT_THAT( + NaryAndIfNeeded( + f, std::vector{a0.node(), a1.node(), a2.node(), a3.node(), + a4.node(), a1.node(), a3.node(), a1.node()}), + IsOkAndHolds(m::And(m::Param("a0"), m::Param("a1"), m::Param("a2"), + m::Param("a3"), m::Param("a4")))); +} + +TEST_F(NodeUtilTest, NaryNorWithNoInputs) { + Package p("my_package"); + FunctionBuilder b(TestName(), &p); + b.Param("a", p.GetBitsType(1)); + XLS_ASSERT_OK_AND_ASSIGN(Function * f, b.Build()); + + EXPECT_THAT(NaryNorIfNeeded(f, {}), Not(IsOk())); +} + +TEST_F(NodeUtilTest, NaryNorWithOneInput) { + Package p("my_package"); + FunctionBuilder b(TestName(), &p); + BValue a = b.Param("a", p.GetBitsType(1)); + XLS_ASSERT_OK_AND_ASSIGN(Function * f, b.Build()); + + EXPECT_THAT(NaryNorIfNeeded(f, std::vector{a.node(), a.node()}), + IsOkAndHolds(m::Not(m::Param("a")))); +} + +TEST_F(NodeUtilTest, NaryNorWithMultipleInputs) { + Package p("my_package"); + FunctionBuilder b(TestName(), &p); + BValue a0 = b.Param("a0", p.GetBitsType(1)); + BValue a1 = b.Param("a1", p.GetBitsType(1)); + BValue a2 = b.Param("a2", p.GetBitsType(1)); + BValue a3 = b.Param("a3", p.GetBitsType(1)); + BValue a4 = b.Param("a4", p.GetBitsType(1)); + XLS_ASSERT_OK_AND_ASSIGN(Function * f, b.Build()); + + EXPECT_THAT( + NaryNorIfNeeded( + f, std::vector{a0.node(), a3.node(), a2.node(), a1.node(), + a4.node(), a1.node(), a3.node(), a1.node()}), + IsOkAndHolds(m::Nor(m::Param("a0"), m::Param("a1"), m::Param("a2"), + m::Param("a3"), m::Param("a4")))); +} + } // namespace } // namespace xls diff --git a/xls/tools/testdata/codegen_main_test__test_add_idle_output_proc.vtxt b/xls/tools/testdata/codegen_main_test__test_add_idle_output_proc.vtxt index 5fb45088be..022f18b850 100644 --- a/xls/tools/testdata/codegen_main_test__test_add_idle_output_proc.vtxt +++ b/xls/tools/testdata/codegen_main_test__test_add_idle_output_proc.vtxt @@ -53,5 +53,5 @@ module neg_proc( assign out = __out_reg; assign out_vld = __out_valid_reg; assign in_rdy = in_load_en; - assign idle = ~(__in_valid_reg | __out_valid_reg | in_vld | __out_valid_reg); + assign idle = ~(in_vld | __in_valid_reg | __out_valid_reg); endmodule