diff --git a/xls/dslx/ir_convert/BUILD b/xls/dslx/ir_convert/BUILD index 7e65b19c30..68de3a6439 100644 --- a/xls/dslx/ir_convert/BUILD +++ b/xls/dslx/ir_convert/BUILD @@ -43,6 +43,7 @@ cc_library( "//xls/dslx:interp_value", "//xls/dslx:interp_value_utils", "//xls/dslx/frontend:ast", + "//xls/dslx/frontend:proc_id", "//xls/dslx/type_system:parametric_env", "//xls/dslx/type_system:type", "//xls/dslx/type_system:type_info", @@ -79,6 +80,8 @@ cc_test( "//xls/dslx:import_data", "//xls/dslx/frontend:ast", "//xls/dslx/frontend:pos", + "//xls/dslx/frontend:proc_id", + "//xls/dslx/frontend:proc_test_utils", "//xls/dslx/type_system:parametric_env", "//xls/dslx/type_system:type", "//xls/dslx/type_system:type_info", diff --git a/xls/dslx/ir_convert/channel_scope.cc b/xls/dslx/ir_convert/channel_scope.cc index 6dcec957a8..1db474f52b 100644 --- a/xls/dslx/ir_convert/channel_scope.cc +++ b/xls/dslx/ir_convert/channel_scope.cc @@ -26,6 +26,7 @@ #include "absl/log/check.h" #include "absl/log/log.h" #include "absl/status/status.h" +#include "absl/strings/match.h" #include "absl/strings/str_cat.h" #include "absl/strings/str_format.h" #include "absl/types/span.h" @@ -36,6 +37,7 @@ #include "xls/dslx/channel_direction.h" #include "xls/dslx/constexpr_evaluator.h" #include "xls/dslx/frontend/ast.h" +#include "xls/dslx/frontend/proc_id.h" #include "xls/dslx/import_data.h" #include "xls/dslx/interp_value.h" #include "xls/dslx/interp_value_utils.h" @@ -52,13 +54,19 @@ #include "xls/ir/xls_ir_interface.pb.h" namespace xls::dslx { +namespace { + +constexpr std::string_view kNameAndDimsSeparator = "__"; +constexpr std::string_view kBetweenDimsSeparator = "_"; + +} // namespace ChannelScope::ChannelScope(PackageConversionData* conversion_info, ImportData* import_data, std::optional default_fifo_config) : conversion_info_(conversion_info), import_data_(import_data), - channel_name_uniquer_(/*separator=*/"__"), + channel_name_uniquer_(kNameAndDimsSeparator), default_fifo_config_(default_fifo_config) { // Populate channel name uniquer with pre-existing channel names. for (Channel* channel : conversion_info_->package->channels()) { @@ -106,7 +114,8 @@ absl::StatusOr ChannelScope::DefineChannelOrArrayInternal( XLS_ASSIGN_OR_RETURN(std::vector suffixes, CreateAllArrayElementSuffixes(*dims)); for (const std::string& suffix : suffixes) { - std::string channel_name = absl::StrCat(base_channel_name, "__", suffix); + std::string channel_name = + absl::StrCat(base_channel_name, kNameAndDimsSeparator, suffix); XLS_ASSIGN_OR_RETURN(Channel * channel, CreateChannel(channel_name, ops, type, fifo_config)); array->AddChannel(channel_name, channel); @@ -170,7 +179,8 @@ absl::Status ChannelScope::DefineProtoChannelOrArray( } absl::StatusOr -ChannelScope::AssociateWithExistingChannelOrArray(const NameDef* name_def, +ChannelScope::AssociateWithExistingChannelOrArray(const ProcId& proc_id, + const NameDef* name_def, const ChannelDecl* decl) { VLOG(4) << "ChannelScope::AssociateWithExistingChannelOrArray : " << name_def->ToString() << " -> " << decl->ToString(); @@ -180,22 +190,38 @@ ChannelScope::AssociateWithExistingChannelOrArray(const NameDef* name_def, } ChannelOrArray channel_or_array = decl_to_channel_or_array_.at(decl); XLS_RETURN_IF_ERROR( - AssociateWithExistingChannelOrArray(name_def, channel_or_array)); + AssociateWithExistingChannelOrArray(proc_id, name_def, channel_or_array)); return channel_or_array; } absl::Status ChannelScope::AssociateWithExistingChannelOrArray( - const NameDef* name_def, ChannelOrArray channel_or_array) { + const ProcId& proc_id, const NameDef* name_def, + ChannelOrArray channel_or_array) { VLOG(4) << "ChannelScope::AssociateWithExistingChannelOrArray : " << name_def->ToString() << " -> " << GetBaseNameForChannelOrArray(channel_or_array) << " (array: " << std::holds_alternative(channel_or_array) << ")"; - name_def_to_channel_or_array_[name_def] = channel_or_array; + name_def_to_channel_or_array_[std::make_pair(proc_id, name_def)] = + channel_or_array; return absl::OkStatus(); } absl::StatusOr ChannelScope::GetChannelForArrayIndex( - const Index* index) { + const ProcId& proc_id, const Index* index) { + XLS_ASSIGN_OR_RETURN( + ChannelOrArray result, + EvaluateIndex(proc_id, index, /*allow_subarray_reference=*/false)); + CHECK(std::holds_alternative(result)); + return std::get(result); +} + +absl::StatusOr ChannelScope::GetChannelOrArrayForArrayIndex( + const ProcId& proc_id, const Index* index) { + return EvaluateIndex(proc_id, index, /*allow_subarray_reference=*/true); +} + +absl::StatusOr ChannelScope::EvaluateIndex( + const ProcId& proc_id, const Index* index, bool allow_subarray_reference) { VLOG(4) << "ChannelScope::GetChannelForArrayIndex : " << index->ToString(); CHECK(function_context_.has_value()); std::string suffix; @@ -212,11 +238,13 @@ absl::StatusOr ChannelScope::GetChannelForArrayIndex( std::get(index->rhs()))); XLS_ASSIGN_OR_RETURN(int64_t dim_value, dim_interp_value.GetBitValueUnsigned()); - suffix = suffix.empty() ? absl::StrCat(dim_value) - : absl::StrCat(dim_value, "_", suffix); + suffix = suffix.empty() + ? absl::StrCat(dim_value) + : absl::StrCat(dim_value, kBetweenDimsSeparator, suffix); if (const NameRef* name_ref = dynamic_cast(index->lhs()); name_ref) { - return GetChannelArrayElement(name_ref, suffix); + return GetChannelArrayElement(proc_id, name_ref, suffix, + allow_subarray_reference); } Index* new_index = dynamic_cast(index->lhs()); if (!new_index) { @@ -228,16 +256,6 @@ absl::StatusOr ChannelScope::GetChannelForArrayIndex( } } -absl::StatusOr ChannelScope::GetBaseNameForNameDef( - const NameDef* name_def) { - const auto it = name_def_to_channel_or_array_.find(name_def); - if (it == name_def_to_channel_or_array_.end()) { - return absl::NotFoundError(absl::StrCat( - "No channel or array associated with NameDef: ", name_def->ToString())); - } - return GetBaseNameForChannelOrArray(it->second); -} - std::string_view ChannelScope::GetBaseNameForChannelOrArray( ChannelOrArray channel_or_array) { return absl::visit(Visitor{[](Channel* channel) { return channel->name(); }, @@ -270,7 +288,8 @@ ChannelScope::CreateAllArrayElementSuffixes(const std::vector& dims) { continue; } for (const std::string& next : strings) { - new_strings.push_back(absl::StrCat(next, "_", element_index)); + new_strings.push_back( + absl::StrCat(next, kBetweenDimsSeparator, element_index)); } } strings = std::move(new_strings); @@ -280,8 +299,8 @@ ChannelScope::CreateAllArrayElementSuffixes(const std::vector& dims) { absl::StatusOr ChannelScope::CreateBaseChannelName( std::string_view short_name) { - return channel_name_uniquer_.GetSanitizedUniqueName( - absl::StrCat(conversion_info_->package->name(), "__", short_name)); + return channel_name_uniquer_.GetSanitizedUniqueName(absl::StrCat( + conversion_info_->package->name(), kNameAndDimsSeparator, short_name)); } absl::StatusOr ChannelScope::GetChannelType( @@ -339,10 +358,12 @@ absl::StatusOr ChannelScope::CreateChannel( /*fifo_config=*/fifo_config); } -absl::StatusOr ChannelScope::GetChannelArrayElement( - const NameRef* name_ref, std::string_view flattened_name_suffix) { +absl::StatusOr ChannelScope::GetChannelArrayElement( + const ProcId& proc_id, const NameRef* name_ref, + std::string_view flattened_name_suffix, bool allow_subarray_reference) { const auto* name_def = std::get(name_ref->name_def()); - const auto it = name_def_to_channel_or_array_.find(name_def); + const auto it = + name_def_to_channel_or_array_.find(std::make_pair(proc_id, name_def)); if (it == name_def_to_channel_or_array_.end()) { return absl::NotFoundError( absl::StrCat("Not a channel or channel array: ", name_def->ToString())); @@ -355,15 +376,48 @@ absl::StatusOr ChannelScope::GetChannelArrayElement( std::get(channel_or_array)->name())); } ChannelArray* array = std::get(channel_or_array); - std::string flattened_channel_name = - absl::StrCat(array->base_channel_name(), "__", flattened_name_suffix); + std::string flattened_channel_name = absl::StrCat( + array->base_channel_name(), + array->is_subarray() ? kBetweenDimsSeparator : kNameAndDimsSeparator, + flattened_name_suffix); std::optional channel = array->FindChannel(flattened_channel_name); if (channel.has_value()) { VLOG(4) << "Found channel array element: " << (*channel)->name(); return *channel; } + if (allow_subarray_reference) { + return GetOrDefineSubarray(array, flattened_channel_name); + } return absl::NotFoundError(absl::StrCat( "No array element with flattened name: ", flattened_channel_name)); } +absl::StatusOr ChannelScope::GetOrDefineSubarray( + ChannelArray* array, std::string_view subarray_name) { + const auto it = subarrays_.find(subarray_name); + if (it != subarrays_.end()) { + VLOG(5) << "Found subarray " << subarray_name; + return it->second; + } + ChannelArray* subarray = + &arrays_.emplace_back(ChannelArray(subarray_name, /*subarray=*/true)); + subarrays_.emplace_hint(it, subarray_name, subarray); + std::string subarray_prefix = + absl::StrCat(subarray_name, kBetweenDimsSeparator); + VLOG(5) << "Searching for subarray elements with prefix " << subarray_prefix; + for (const std::string& name : array->flattened_names_in_order()) { + if (absl::StartsWith(name, subarray_prefix)) { + Channel* channel = *array->FindChannel(name); + subarray->AddChannel(channel->name(), channel); + } + } + // If type checking has been done right etc., there should never be a request + // for a subarray prefix that matches zero channels, even when compiling + // erroneous DSLX code. + CHECK(!subarray->flattened_names_in_order().empty()); + VLOG(5) << "Defined subarray " << subarray_name << " with " + << subarray->flattened_names_in_order().size() << " elements."; + return subarray; +} + } // namespace xls::dslx diff --git a/xls/dslx/ir_convert/channel_scope.h b/xls/dslx/ir_convert/channel_scope.h index 105085e08e..4f4d26d9d9 100644 --- a/xls/dslx/ir_convert/channel_scope.h +++ b/xls/dslx/ir_convert/channel_scope.h @@ -26,6 +26,7 @@ #include "absl/status/statusor.h" #include "absl/types/span.h" #include "xls/dslx/frontend/ast.h" +#include "xls/dslx/frontend/proc_id.h" #include "xls/dslx/import_data.h" #include "xls/dslx/ir_convert/conversion_info.h" #include "xls/dslx/type_system/parametric_env.h" @@ -49,10 +50,12 @@ class ChannelArray { std::string ToString() const { return base_channel_name_; } private: - explicit ChannelArray(std::string_view base_channel_name) - : base_channel_name_(base_channel_name) {} + explicit ChannelArray(std::string_view base_channel_name, + bool subarray = false) + : base_channel_name_(base_channel_name), subarray_(subarray) {} std::string_view base_channel_name() const { return base_channel_name_; } + bool is_subarray() const { return subarray_; } absl::Span flattened_names_in_order() const { return flattened_names_in_order_; @@ -75,6 +78,14 @@ class ChannelArray { // in the DSLX source code. const std::string base_channel_name_; + // Whether this array represents part of a larger N-D array, with up to N-1 + // dims fixed. In that case, it will contain some of the same channel pointers + // that are in the `ChannelArray` object representing the overall array. + // `ChannelArray` objects for subarrays are fabricated by a `ChannelScope` on + // an as-needed basis, when references to them are encountered (in the form of + // `Index` ops). + const bool subarray_; + // The flattened names in order of addition. The scope adds channels in // ascending index order, and in some situations wants to enumerate them in // that order. @@ -118,13 +129,14 @@ class ChannelScope { // should be used, for example, when a channel is passed into `spawn` and the // receiving proc associates it with a local argument name. absl::StatusOr AssociateWithExistingChannelOrArray( - const NameDef* name_def, const ChannelDecl* decl); + const ProcId& proc_id, const NameDef* name_def, const ChannelDecl* decl); // Variant of `AssociateWithExistingChannelOrArray`, to be used when the // caller has the channel or array returned by `DefineChannelOrArray` on hand, // rather than the `decl` it was made from. absl::Status AssociateWithExistingChannelOrArray( - const NameDef* name_def, ChannelOrArray channel_or_array); + const ProcId& proc_id, const NameDef* name_def, + ChannelOrArray channel_or_array); // Retrieves the individual `Channel` that is referred to by the given `index` // operation. In order for this to succeed, `index` must meet the following @@ -138,7 +150,15 @@ class ChannelScope { // be constexpr evaluatable. // A not-found error is the guaranteed result in cases where `index` is not // a channel array index operation at all. - absl::StatusOr GetChannelForArrayIndex(const Index* index); + absl::StatusOr GetChannelForArrayIndex(const ProcId& proc_id, + const Index* index); + + // Retrieves the subarray or individual `Channel` that is referred to by the + // given `index operation. The `index` must conform to the criteria described + // for `GetChannelForArrayIndex()`, but it may lead part way into a + // multidimensional channel array. + absl::StatusOr GetChannelOrArrayForArrayIndex( + const ProcId& proc_id, const Index* index); private: absl::StatusOr DefineChannelOrArrayInternal( @@ -150,9 +170,6 @@ class ChannelScope { ChannelOrArray array, dslx::ChannelTypeAnnotation* type_annot, xls::Type* ir_type, TypeInfo* type_info); - absl::StatusOr GetBaseNameForNameDef( - const NameDef* name_def); - std::string_view GetBaseNameForChannelOrArray( ChannelOrArray channel_or_array); @@ -171,8 +188,16 @@ class ChannelScope { xls::Type* type, std::optional fifo_config); - absl::StatusOr GetChannelArrayElement( - const NameRef* name_ref, std::string_view flattened_name_suffix); + absl::StatusOr EvaluateIndex(const ProcId& proc_id, + const Index* index, + bool allow_subarray_reference); + + absl::StatusOr GetChannelArrayElement( + const ProcId& proc_id, const NameRef* name_ref, + std::string_view flattened_name_suffix, bool allow_subarray_reference); + + absl::StatusOr GetOrDefineSubarray( + ChannelArray* array, std::string_view subarray_name); PackageConversionData* const conversion_info_; ImportData* const import_data_; @@ -200,8 +225,9 @@ class ChannelScope { absl::flat_hash_map decl_to_channel_or_array_; - absl::flat_hash_map + absl::flat_hash_map, ChannelOrArray> name_def_to_channel_or_array_; + absl::flat_hash_map subarrays_; }; } // namespace xls::dslx diff --git a/xls/dslx/ir_convert/channel_scope_test.cc b/xls/dslx/ir_convert/channel_scope_test.cc index 14c82ce613..b2d3bffaf2 100644 --- a/xls/dslx/ir_convert/channel_scope_test.cc +++ b/xls/dslx/ir_convert/channel_scope_test.cc @@ -32,6 +32,8 @@ #include "xls/dslx/create_import_data.h" #include "xls/dslx/frontend/ast.h" #include "xls/dslx/frontend/pos.h" +#include "xls/dslx/frontend/proc_id.h" +#include "xls/dslx/frontend/proc_test_utils.h" #include "xls/dslx/import_data.h" #include "xls/dslx/ir_convert/conversion_info.h" #include "xls/dslx/type_system/parametric_env.h" @@ -99,22 +101,26 @@ class ChannelScopeTest : public ::testing::Test { NumberKind::kOther, GetU32TypeAnnotation()); } - Index* CreateIndexOp(ChannelDecl* decl, + Index* CreateIndexOp(NameRef* name_ref, const std::vector& indices) { - NameDef* fake_array = module_->Make(Span::Fake(), "arr", nullptr); - absl::StatusOr channel_or_array = - scope_->AssociateWithExistingChannelOrArray(fake_array, decl); - XLS_EXPECT_OK(channel_or_array); - NameRef* fake_array_ref = - module_->Make(Span::Fake(), "arr", fake_array); Index* index = - module_->Make(Span::Fake(), fake_array_ref, MakeU32(indices[0])); + module_->Make(Span::Fake(), name_ref, MakeU32(indices[0])); for (int i = 1; i < indices.size(); i++) { index = module_->Make(Span::Fake(), index, MakeU32(indices[i])); } return index; } + Index* CreateIndexOp(ChannelDecl* decl, + const std::vector& indices) { + NameDef* fake_array = module_->Make(Span::Fake(), "arr", nullptr); + absl::StatusOr channel_or_array = + scope_->AssociateWithExistingChannelOrArray(ProcId{}, fake_array, decl); + XLS_EXPECT_OK(channel_or_array); + return CreateIndexOp( + module_->Make(Span::Fake(), "arr", fake_array), indices); + } + std::unique_ptr import_data_; ParametricEnv bindings_; PackageConversionData conv_; @@ -204,7 +210,7 @@ TEST_F(ChannelScopeTest, AssociateWithExistingChannelDecl) { EXPECT_TRUE(std::holds_alternative(result)); NameDef* name_def = module_->Make(Span::Fake(), "ch", nullptr); XLS_EXPECT_OK_AND_EQ( - scope_->AssociateWithExistingChannelOrArray(name_def, decl), + scope_->AssociateWithExistingChannelOrArray(ProcId{}, name_def, decl), std::get(result)); } @@ -216,15 +222,16 @@ TEST_F(ChannelScopeTest, AssociateWithExistingChannelArrayDecl) { EXPECT_TRUE(std::holds_alternative(result)); NameDef* name_def = module_->Make(Span::Fake(), "ch", nullptr); XLS_EXPECT_OK_AND_EQ( - scope_->AssociateWithExistingChannelOrArray(name_def, decl), + scope_->AssociateWithExistingChannelOrArray(ProcId{}, name_def, decl), std::get(result)); } TEST_F(ChannelScopeTest, AssociateWithExistingChannelOrArrayNonexistent) { ChannelDecl* decl = MakeU32ChannelDecl("the_channel"); NameDef* name_def = module_->Make(Span::Fake(), "ch", nullptr); - EXPECT_THAT(scope_->AssociateWithExistingChannelOrArray(name_def, decl), - StatusIs(absl::StatusCode::kNotFound)); + EXPECT_THAT( + scope_->AssociateWithExistingChannelOrArray(ProcId{}, name_def, decl), + StatusIs(absl::StatusCode::kNotFound)); } TEST_F(ChannelScopeTest, AssociateWithExistingChannel) { @@ -233,7 +240,8 @@ TEST_F(ChannelScopeTest, AssociateWithExistingChannel) { scope_->DefineChannelOrArray(decl)); EXPECT_TRUE(std::holds_alternative(result)); NameDef* name_def = module_->Make(Span::Fake(), "ch", nullptr); - XLS_EXPECT_OK(scope_->AssociateWithExistingChannelOrArray(name_def, result)); + XLS_EXPECT_OK( + scope_->AssociateWithExistingChannelOrArray(ProcId{}, name_def, result)); } TEST_F(ChannelScopeTest, AssociateWithExistingChannelArray) { @@ -243,7 +251,45 @@ TEST_F(ChannelScopeTest, AssociateWithExistingChannelArray) { scope_->DefineChannelOrArray(decl)); EXPECT_TRUE(std::holds_alternative(result)); NameDef* name_def = module_->Make(Span::Fake(), "ch", nullptr); - XLS_EXPECT_OK(scope_->AssociateWithExistingChannelOrArray(name_def, result)); + XLS_EXPECT_OK( + scope_->AssociateWithExistingChannelOrArray(ProcId{}, name_def, result)); +} + +TEST_F(ChannelScopeTest, AssociateWithExistingChannelArrayDifferentProcIds) { + std::vector dims = {MakeU32("5")}; + ChannelDecl* arr1_decl = MakeU32ChannelDecl("arr1", dims); + XLS_ASSERT_OK_AND_ASSIGN(ChannelOrArray arr1, + scope_->DefineChannelOrArray(arr1_decl)); + ChannelDecl* arr2_decl = MakeU32ChannelDecl("arr2", dims); + XLS_ASSERT_OK_AND_ASSIGN(ChannelOrArray arr2, + scope_->DefineChannelOrArray(arr2_decl)); + EXPECT_TRUE(std::holds_alternative(arr1)); + EXPECT_TRUE(std::holds_alternative(arr2)); + + NameDef* ch_def = module_->Make(Span::Fake(), "ch", nullptr); + NameRef* ch_ref = module_->Make(Span::Fake(), "ch", ch_def); + FileTable file_table; + auto [proc_a_module, proc_a] = CreateEmptyProc(file_table, "A"); + auto [proc_b_module, proc_b] = CreateEmptyProc(file_table, "B"); + // Simulate two spawns of B from A, the first passing `arr1` for `ch` and the + // second passing `arr2` for `ch`. + ProcId proc_id1{.proc_instance_stack = {{proc_a, 0}, {proc_b, 0}}}; + ProcId proc_id2{.proc_instance_stack = {{proc_a, 0}, {proc_b, 1}}}; + XLS_EXPECT_OK( + scope_->AssociateWithExistingChannelOrArray(proc_id1, ch_def, arr1_decl)); + XLS_EXPECT_OK( + scope_->AssociateWithExistingChannelOrArray(proc_id2, ch_def, arr2_decl)); + + // Trying to evaluate `ch[some_index]` now should give us a different object + // depending on the proc ID. + XLS_ASSERT_OK_AND_ASSIGN( + Channel * test_channel1, + scope_->GetChannelForArrayIndex(proc_id1, CreateIndexOp(ch_ref, {"2"}))); + EXPECT_EQ(test_channel1->name(), "the_package__arr1__2"); + XLS_ASSERT_OK_AND_ASSIGN( + Channel * test_channel2, + scope_->GetChannelForArrayIndex(proc_id2, CreateIndexOp(ch_ref, {"2"}))); + EXPECT_EQ(test_channel2->name(), "the_package__arr2__2"); } TEST_F(ChannelScopeTest, HandleChannelIndex1DValid) { @@ -252,8 +298,9 @@ TEST_F(ChannelScopeTest, HandleChannelIndex1DValid) { XLS_ASSERT_OK_AND_ASSIGN(ChannelOrArray result, scope_->DefineChannelOrArray(decl)); EXPECT_TRUE(std::holds_alternative(result)); - XLS_ASSERT_OK_AND_ASSIGN(Channel * channel, scope_->GetChannelForArrayIndex( - CreateIndexOp(decl, {"2"}))); + XLS_ASSERT_OK_AND_ASSIGN( + Channel * channel, + scope_->GetChannelForArrayIndex(ProcId{}, CreateIndexOp(decl, {"2"}))); EXPECT_EQ(channel->name(), "the_package__the_channel__2"); } @@ -263,9 +310,9 @@ TEST_F(ChannelScopeTest, HandleChannelIndex2DValid) { XLS_ASSERT_OK_AND_ASSIGN(ChannelOrArray result, scope_->DefineChannelOrArray(decl)); EXPECT_TRUE(std::holds_alternative(result)); - XLS_ASSERT_OK_AND_ASSIGN( - Channel * channel, - scope_->GetChannelForArrayIndex(CreateIndexOp(decl, {"4", "1"}))); + XLS_ASSERT_OK_AND_ASSIGN(Channel * channel, + scope_->GetChannelForArrayIndex( + ProcId{}, CreateIndexOp(decl, {"4", "1"}))); EXPECT_EQ(channel->name(), "the_package__the_channel__4_1"); } @@ -274,29 +321,56 @@ TEST_F(ChannelScopeTest, HandleChannelIndexWithNonArray) { XLS_ASSERT_OK_AND_ASSIGN(ChannelOrArray result, scope_->DefineChannelOrArray(decl)); EXPECT_TRUE(std::holds_alternative(result)); - EXPECT_THAT(scope_->GetChannelForArrayIndex(CreateIndexOp(decl, {"4"})), - StatusIs(absl::StatusCode::kInvalidArgument)); + EXPECT_THAT( + scope_->GetChannelForArrayIndex(ProcId{}, CreateIndexOp(decl, {"4"})), + StatusIs(absl::StatusCode::kInvalidArgument)); } TEST_F(ChannelScopeTest, HandleChannelIndexWithTooManyIndices) { + std::vector dims = {MakeU32("2"), MakeU32("5")}; + ChannelDecl* decl = MakeU32ChannelDecl("the_channel", dims); + XLS_ASSERT_OK_AND_ASSIGN(ChannelOrArray result, + scope_->DefineChannelOrArray(decl)); + EXPECT_TRUE(std::holds_alternative(result)); + EXPECT_THAT(scope_->GetChannelForArrayIndex( + ProcId{}, CreateIndexOp(decl, {"4", "1", "1"})), + StatusIs(absl::StatusCode::kNotFound)); +} + +TEST_F(ChannelScopeTest, HandleChannelIndexWithInsufficientIndices) { std::vector dims = {MakeU32("2"), MakeU32("5")}; ChannelDecl* decl = MakeU32ChannelDecl("the_channel", dims); XLS_ASSERT_OK_AND_ASSIGN(ChannelOrArray result, scope_->DefineChannelOrArray(decl)); EXPECT_TRUE(std::holds_alternative(result)); EXPECT_THAT( - scope_->GetChannelForArrayIndex(CreateIndexOp(decl, {"4", "1", "1"})), + scope_->GetChannelForArrayIndex(ProcId{}, CreateIndexOp(decl, {"4"})), StatusIs(absl::StatusCode::kNotFound)); } -TEST_F(ChannelScopeTest, HandleChannelIndexWithInsufficientIndices) { +TEST_F(ChannelScopeTest, HandleSubarrayIndex) { std::vector dims = {MakeU32("2"), MakeU32("5")}; ChannelDecl* decl = MakeU32ChannelDecl("the_channel", dims); XLS_ASSERT_OK_AND_ASSIGN(ChannelOrArray result, scope_->DefineChannelOrArray(decl)); EXPECT_TRUE(std::holds_alternative(result)); - EXPECT_THAT(scope_->GetChannelForArrayIndex(CreateIndexOp(decl, {"4"})), - StatusIs(absl::StatusCode::kNotFound)); + + // Get a subarray of "the_channel" and assign a `NameDef` to that. + XLS_ASSERT_OK_AND_ASSIGN(ChannelOrArray subarray, + scope_->GetChannelOrArrayForArrayIndex( + ProcId{}, CreateIndexOp(decl, {"4"}))); + ASSERT_TRUE(std::holds_alternative(subarray)); + NameDef* subarray_def = module_->Make(Span::Fake(), "ch", nullptr); + NameRef* subarray_ref = + module_->Make(Span::Fake(), "ch", subarray_def); + XLS_EXPECT_OK(scope_->AssociateWithExistingChannelOrArray( + ProcId{}, subarray_def, subarray)); + + // Now index into the subarray. + XLS_ASSERT_OK_AND_ASSIGN(Channel * channel, + scope_->GetChannelForArrayIndex( + ProcId{}, CreateIndexOp(subarray_ref, {"1"}))); + EXPECT_EQ(channel->name(), "the_package__the_channel__4_1"); } TEST_F(ChannelScopeTest, HandleChannelIndexWithOutOfRangeIndices) { @@ -305,7 +379,8 @@ TEST_F(ChannelScopeTest, HandleChannelIndexWithOutOfRangeIndices) { XLS_ASSERT_OK_AND_ASSIGN(ChannelOrArray result, scope_->DefineChannelOrArray(decl)); EXPECT_TRUE(std::holds_alternative(result)); - EXPECT_THAT(scope_->GetChannelForArrayIndex(CreateIndexOp(decl, {"5", "0"})), + EXPECT_THAT(scope_->GetChannelForArrayIndex(ProcId{}, + CreateIndexOp(decl, {"5", "0"})), StatusIs(absl::StatusCode::kNotFound)); } diff --git a/xls/dslx/ir_convert/function_converter.cc b/xls/dslx/ir_convert/function_converter.cc index e06b71e5ca..a6aad9e280 100644 --- a/xls/dslx/ir_convert/function_converter.cc +++ b/xls/dslx/ir_convert/function_converter.cc @@ -1679,11 +1679,24 @@ absl::StatusOr FunctionConverter::HandleMap(const Invocation* node) { } absl::Status FunctionConverter::HandleIndex(const Index* node) { - absl::StatusOr channel = - channel_scope_->GetChannelForArrayIndex(node); - if (channel.ok()) { - node_to_ir_[node] = *channel; - return absl::OkStatus(); + if (proc_id_.has_value()) { + absl::StatusOr channel_or_array = + channel_scope_->GetChannelOrArrayForArrayIndex(*proc_id_, node); + if (channel_or_array.ok()) { + if (std::holds_alternative(*channel_or_array)) { + // We don't allow referencing subarrays outside of config(), and the + // ones that occur in config() are dealt with in `ProcConfigIrConverter` + // rather than ending up here. The reason for disallowing them in next() + // is low utility and higher difficulty of implementation against + // non-lowered arrays, which can't be an `IrValue`, `BValue`, etc. + return absl::InvalidArgumentError(absl::StrFormat( + "Invalid channel subarray use: `%s` at %s; channel subarrays can " + "only be used in proc config functions.", + node->ToString(), node->span().ToString(file_table()))); + } + node_to_ir_[node] = std::get(*channel_or_array); + return absl::OkStatus(); + } } XLS_RETURN_IF_ERROR(Visit(node->lhs())); XLS_ASSIGN_OR_RETURN(BValue lhs, Use(node->lhs())); diff --git a/xls/dslx/ir_convert/ir_converter_test.cc b/xls/dslx/ir_convert/ir_converter_test.cc index 8494294dc4..43f9d53259 100644 --- a/xls/dslx/ir_convert/ir_converter_test.cc +++ b/xls/dslx/ir_convert/ir_converter_test.cc @@ -2540,6 +2540,138 @@ TEST(IrConverterTest, ReceiveFromBoundaryChannelArrayElement) { ExpectIr(converted, TestName()); } +TEST(IrConverterTest, DealOutChannelSubarray) { + constexpr std::string_view kProgram = R"( + proc B { + outs: chan[2] out; + ins: chan[2] in; + + init {} + + config(outs: chan[2] out, ins: chan[2] in) { + (outs, ins) + } + + next(state: ()) { + unroll_for!(j, tok) : (u32, token) in u32:0..u32:2 { + let tok = send(tok, outs[j], j); + let(tok, _) = recv(tok, ins[j]); + tok + }(join()); + } + } + + proc A { + init {} + + config() { + let (outs, ins) = chan[2][2]("the_channel"); + unroll_for!(i, _) : (u32, ()) in u32:0..u32:2 { + spawn B(outs[i], ins[i]); + }(()); + } + + next(state: ()) { state } + } + )"; + + ConvertOptions options; + options.emit_fail_as_assert = false; + options.emit_positions = false; + options.verify_ir = false; + auto import_data = CreateImportDataForTest(); + XLS_ASSERT_OK_AND_ASSIGN( + std::string converted, + ConvertOneFunctionForTest(kProgram, "A", import_data, options)); + ExpectIr(converted, TestName()); +} + +TEST(IrConverterTest, LetChannelSubarrayInConfig) { + constexpr std::string_view kProgram = R"( + proc B { + outs: chan[2] out; + ins: chan[2] in; + + init {} + + config(outs: chan[2] out, ins: chan[2] in) { + (outs, ins) + } + + next(state: ()) { + unroll_for!(j, tok) : (u32, token) in u32:0..u32:2 { + let tok = send(tok, outs[j], j); + let(tok, _) = recv(tok, ins[j]); + tok + }(join()); + } + } + + proc A { + init {} + + config() { + let (outs, ins) = chan[2][2]("the_channel"); + let outs0 = outs[0]; + let ins0 = ins[0]; + let outs1 = outs[1]; + let ins1 = ins[1]; + spawn B(outs0, ins0); + spawn B(outs1, ins1); + } + + next(state: ()) { state } + } + )"; + + ConvertOptions options; + options.emit_fail_as_assert = false; + options.emit_positions = false; + options.verify_ir = false; + auto import_data = CreateImportDataForTest(); + XLS_ASSERT_OK_AND_ASSIGN( + std::string converted, + ConvertOneFunctionForTest(kProgram, "A", import_data, options)); + ExpectIr(converted, TestName()); +} + +TEST(IrConverterTest, LetChannelSubarrayInNext) { + constexpr std::string_view kProgram = R"( + proc A { + outs: chan[2][2] out; + ins: chan[2][2] in; + + init {} + + config() { + let (outs, ins) = chan[2][2]("the_channel"); + (outs, ins) + } + + next(state: ()) { + let (outs0, ins0) = (outs[0], ins[0]); + let (outs1, ins1) = (outs[1], ins[1]); + unroll_for!(j, tok) : (u32, token) in u32:0..u32:2 { + let tok = send(tok, outs0[j], j); + let(tok, _) = recv(tok, ins0[j]); + let tok = send(tok, outs1[j], j); + let(tok, _) = recv(tok, ins1[j]); + tok + }(join()); + } + } + )"; + + ConvertOptions options; + options.emit_fail_as_assert = false; + options.emit_positions = false; + options.verify_ir = false; + auto import_data = CreateImportDataForTest(); + EXPECT_THAT(ConvertOneFunctionForTest(kProgram, "A", import_data, options), + StatusIs(absl::StatusCode::kInvalidArgument, + HasSubstr("Invalid channel subarray use"))); +} + TEST(IrConverterTest, TopProcWithState) { constexpr std::string_view kProgram = R"( proc main { diff --git a/xls/dslx/ir_convert/proc_config_ir_converter.cc b/xls/dslx/ir_convert/proc_config_ir_converter.cc index 8318817bfa..614c39a555 100644 --- a/xls/dslx/ir_convert/proc_config_ir_converter.cc +++ b/xls/dslx/ir_convert/proc_config_ir_converter.cc @@ -90,7 +90,7 @@ absl::Status ProcConfigIrConverter::Finalize() { ProcConfigValueToChannelOrArray(value); if (channel_or_array.has_value()) { XLS_RETURN_IF_ERROR(channel_scope_->AssociateWithExistingChannelOrArray( - member->name_def(), *channel_or_array)); + proc_id_, member->name_def(), *channel_or_array)); } } @@ -138,9 +138,10 @@ absl::Status ProcConfigIrConverter::HandleFunction(const Function* node) { absl::Status ProcConfigIrConverter::HandleIndex(const Index* node) { VLOG(4) << "ProcConfigIrConverter::HandleIndex: " << node->ToString(); - XLS_ASSIGN_OR_RETURN(Channel * channel, - channel_scope_->GetChannelForArrayIndex(node)); - node_to_ir_[node] = channel; + XLS_ASSIGN_OR_RETURN( + ChannelOrArray channel_or_array, + channel_scope_->GetChannelOrArrayForArrayIndex(proc_id_, node)); + node_to_ir_[node] = ChannelOrArrayToProcConfigValue(channel_or_array); return absl::OkStatus(); } @@ -165,8 +166,8 @@ absl::Status ProcConfigIrConverter::HandleLet(const Let* node) { NameDef* name_def = std::get(leaves[i]); XLS_ASSIGN_OR_RETURN( ChannelOrArray target, - channel_scope_->AssociateWithExistingChannelOrArray(name_def, - decl)); + channel_scope_->AssociateWithExistingChannelOrArray( + proc_id_, name_def, decl)); node_to_ir_[name_def] = ChannelOrArrayToProcConfigValue(target); continue; } @@ -192,7 +193,7 @@ absl::Status ProcConfigIrConverter::HandleLet(const Let* node) { ProcConfigValueToChannelOrArray(value); if (channel_or_array.has_value()) { XLS_RETURN_IF_ERROR(channel_scope_->AssociateWithExistingChannelOrArray( - def, *channel_or_array)); + proc_id_, def, *channel_or_array)); } node_to_ir_[def] = value; } @@ -244,7 +245,7 @@ absl::Status ProcConfigIrConverter::HandleParam(const Param* node) { ProcConfigValueToChannelOrArray(value); if (channel_or_array.has_value()) { XLS_RETURN_IF_ERROR(channel_scope_->AssociateWithExistingChannelOrArray( - node->name_def(), *channel_or_array)); + proc_id_, node->name_def(), *channel_or_array)); } node_to_ir_[node->name_def()] = value; return absl::OkStatus(); diff --git a/xls/dslx/ir_convert/testdata/ir_converter_test_DealOutChannelSubarray.ir b/xls/dslx/ir_convert/testdata/ir_converter_test_DealOutChannelSubarray.ir new file mode 100644 index 0000000000..406366ec3d --- /dev/null +++ b/xls/dslx/ir_convert/testdata/ir_converter_test_DealOutChannelSubarray.ir @@ -0,0 +1,58 @@ +package test_module + +file_number 0 "test_module.x" + +chan test_module__the_channel__0_0(bits[32], id=0, kind=streaming, ops=send_receive, flow_control=ready_valid, strictness=proven_mutually_exclusive, metadata="""""") +chan test_module__the_channel__1_0(bits[32], id=1, kind=streaming, ops=send_receive, flow_control=ready_valid, strictness=proven_mutually_exclusive, metadata="""""") +chan test_module__the_channel__0_1(bits[32], id=2, kind=streaming, ops=send_receive, flow_control=ready_valid, strictness=proven_mutually_exclusive, metadata="""""") +chan test_module__the_channel__1_1(bits[32], id=3, kind=streaming, ops=send_receive, flow_control=ready_valid, strictness=proven_mutually_exclusive, metadata="""""") + +fn __test_module__B.init() -> () { + ret tuple.1: () = tuple(id=1) +} + +top proc __test_module__A_0_next(__state: (), init={()}) { + __token: token = literal(value=token, id=2) + literal.4: bits[1] = literal(value=1, id=4) + next (__state) +} + +proc __test_module__A__B_0_next(__state: (), init={()}) { + tok: token = after_all(id=8) + literal.9: bits[32] = literal(value=0, id=9) + literal.7: bits[1] = literal(value=1, id=7) + tok__1: token = send(tok, literal.9, predicate=literal.7, channel=test_module__the_channel__0_0, id=10) + receive.11: (token, bits[32]) = receive(tok__1, predicate=literal.7, channel=test_module__the_channel__0_0, id=11) + tok__3: token = tuple_index(receive.11, index=0, id=13) + literal.15: bits[32] = literal(value=1, id=15) + tok__4: token = send(tok__3, literal.15, predicate=literal.7, channel=test_module__the_channel__0_1, id=16) + receive.17: (token, bits[32]) = receive(tok__4, predicate=literal.7, channel=test_module__the_channel__0_1, id=17) + __token: token = literal(value=token, id=5) + tuple_index.12: token = tuple_index(receive.11, index=0, id=12) + tuple_index.14: bits[32] = tuple_index(receive.11, index=1, id=14) + tuple_index.18: token = tuple_index(receive.17, index=0, id=18) + tok__5: token = tuple_index(receive.17, index=0, id=19) + tuple_index.20: bits[32] = tuple_index(receive.17, index=1, id=20) + tuple.21: () = tuple(id=21) + next (tuple.21) +} + +proc __test_module__A__B_1_next(__state: (), init={()}) { + tok: token = after_all(id=25) + literal.26: bits[32] = literal(value=0, id=26) + literal.24: bits[1] = literal(value=1, id=24) + tok__1: token = send(tok, literal.26, predicate=literal.24, channel=test_module__the_channel__1_0, id=27) + receive.28: (token, bits[32]) = receive(tok__1, predicate=literal.24, channel=test_module__the_channel__1_0, id=28) + tok__3: token = tuple_index(receive.28, index=0, id=30) + literal.32: bits[32] = literal(value=1, id=32) + tok__4: token = send(tok__3, literal.32, predicate=literal.24, channel=test_module__the_channel__1_1, id=33) + receive.34: (token, bits[32]) = receive(tok__4, predicate=literal.24, channel=test_module__the_channel__1_1, id=34) + __token: token = literal(value=token, id=22) + tuple_index.29: token = tuple_index(receive.28, index=0, id=29) + tuple_index.31: bits[32] = tuple_index(receive.28, index=1, id=31) + tuple_index.35: token = tuple_index(receive.34, index=0, id=35) + tok__5: token = tuple_index(receive.34, index=0, id=36) + tuple_index.37: bits[32] = tuple_index(receive.34, index=1, id=37) + tuple.38: () = tuple(id=38) + next (tuple.38) +} diff --git a/xls/dslx/ir_convert/testdata/ir_converter_test_LetChannelSubarrayInConfig.ir b/xls/dslx/ir_convert/testdata/ir_converter_test_LetChannelSubarrayInConfig.ir new file mode 100644 index 0000000000..406366ec3d --- /dev/null +++ b/xls/dslx/ir_convert/testdata/ir_converter_test_LetChannelSubarrayInConfig.ir @@ -0,0 +1,58 @@ +package test_module + +file_number 0 "test_module.x" + +chan test_module__the_channel__0_0(bits[32], id=0, kind=streaming, ops=send_receive, flow_control=ready_valid, strictness=proven_mutually_exclusive, metadata="""""") +chan test_module__the_channel__1_0(bits[32], id=1, kind=streaming, ops=send_receive, flow_control=ready_valid, strictness=proven_mutually_exclusive, metadata="""""") +chan test_module__the_channel__0_1(bits[32], id=2, kind=streaming, ops=send_receive, flow_control=ready_valid, strictness=proven_mutually_exclusive, metadata="""""") +chan test_module__the_channel__1_1(bits[32], id=3, kind=streaming, ops=send_receive, flow_control=ready_valid, strictness=proven_mutually_exclusive, metadata="""""") + +fn __test_module__B.init() -> () { + ret tuple.1: () = tuple(id=1) +} + +top proc __test_module__A_0_next(__state: (), init={()}) { + __token: token = literal(value=token, id=2) + literal.4: bits[1] = literal(value=1, id=4) + next (__state) +} + +proc __test_module__A__B_0_next(__state: (), init={()}) { + tok: token = after_all(id=8) + literal.9: bits[32] = literal(value=0, id=9) + literal.7: bits[1] = literal(value=1, id=7) + tok__1: token = send(tok, literal.9, predicate=literal.7, channel=test_module__the_channel__0_0, id=10) + receive.11: (token, bits[32]) = receive(tok__1, predicate=literal.7, channel=test_module__the_channel__0_0, id=11) + tok__3: token = tuple_index(receive.11, index=0, id=13) + literal.15: bits[32] = literal(value=1, id=15) + tok__4: token = send(tok__3, literal.15, predicate=literal.7, channel=test_module__the_channel__0_1, id=16) + receive.17: (token, bits[32]) = receive(tok__4, predicate=literal.7, channel=test_module__the_channel__0_1, id=17) + __token: token = literal(value=token, id=5) + tuple_index.12: token = tuple_index(receive.11, index=0, id=12) + tuple_index.14: bits[32] = tuple_index(receive.11, index=1, id=14) + tuple_index.18: token = tuple_index(receive.17, index=0, id=18) + tok__5: token = tuple_index(receive.17, index=0, id=19) + tuple_index.20: bits[32] = tuple_index(receive.17, index=1, id=20) + tuple.21: () = tuple(id=21) + next (tuple.21) +} + +proc __test_module__A__B_1_next(__state: (), init={()}) { + tok: token = after_all(id=25) + literal.26: bits[32] = literal(value=0, id=26) + literal.24: bits[1] = literal(value=1, id=24) + tok__1: token = send(tok, literal.26, predicate=literal.24, channel=test_module__the_channel__1_0, id=27) + receive.28: (token, bits[32]) = receive(tok__1, predicate=literal.24, channel=test_module__the_channel__1_0, id=28) + tok__3: token = tuple_index(receive.28, index=0, id=30) + literal.32: bits[32] = literal(value=1, id=32) + tok__4: token = send(tok__3, literal.32, predicate=literal.24, channel=test_module__the_channel__1_1, id=33) + receive.34: (token, bits[32]) = receive(tok__4, predicate=literal.24, channel=test_module__the_channel__1_1, id=34) + __token: token = literal(value=token, id=22) + tuple_index.29: token = tuple_index(receive.28, index=0, id=29) + tuple_index.31: bits[32] = tuple_index(receive.28, index=1, id=31) + tuple_index.35: token = tuple_index(receive.34, index=0, id=35) + tok__5: token = tuple_index(receive.34, index=0, id=36) + tuple_index.37: bits[32] = tuple_index(receive.34, index=1, id=37) + tuple.38: () = tuple(id=38) + next (tuple.38) +}