Skip to content

Commit

Permalink
Enable referencing subarrays of an N-D channel array, i.e. regions wi…
Browse files Browse the repository at this point in the history
…th up to N-1 dims fixed.

The primary use case for this is dealing out subarrays in proc config, as in `spawn Foo(arr[0])` where `arr` is a `chan<u32>[M][N]` and `Foo` requires a 1D channel array.

Doing this in next() will continue to be impossible, but with a clear error message.

PiperOrigin-RevId: 687365179
  • Loading branch information
richmckeever authored and copybara-github committed Oct 18, 2024
1 parent b03f2f4 commit aa6cca3
Show file tree
Hide file tree
Showing 9 changed files with 498 additions and 78 deletions.
3 changes: 3 additions & 0 deletions xls/dslx/ir_convert/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@ cc_library(
"//xls/dslx:interp_value",
"//xls/dslx:interp_value_utils",
"//xls/dslx/frontend:ast",
"//xls/dslx/frontend:proc_id",
"//xls/dslx/type_system:parametric_env",
"//xls/dslx/type_system:type",
"//xls/dslx/type_system:type_info",
Expand Down Expand Up @@ -79,6 +80,8 @@ cc_test(
"//xls/dslx:import_data",
"//xls/dslx/frontend:ast",
"//xls/dslx/frontend:pos",
"//xls/dslx/frontend:proc_id",
"//xls/dslx/frontend:proc_test_utils",
"//xls/dslx/type_system:parametric_env",
"//xls/dslx/type_system:type",
"//xls/dslx/type_system:type_info",
Expand Down
110 changes: 82 additions & 28 deletions xls/dslx/ir_convert/channel_scope.cc
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
#include "absl/log/check.h"
#include "absl/log/log.h"
#include "absl/status/status.h"
#include "absl/strings/match.h"
#include "absl/strings/str_cat.h"
#include "absl/strings/str_format.h"
#include "absl/types/span.h"
Expand All @@ -36,6 +37,7 @@
#include "xls/dslx/channel_direction.h"
#include "xls/dslx/constexpr_evaluator.h"
#include "xls/dslx/frontend/ast.h"
#include "xls/dslx/frontend/proc_id.h"
#include "xls/dslx/import_data.h"
#include "xls/dslx/interp_value.h"
#include "xls/dslx/interp_value_utils.h"
Expand All @@ -52,13 +54,19 @@
#include "xls/ir/xls_ir_interface.pb.h"

namespace xls::dslx {
namespace {

constexpr std::string_view kNameAndDimsSeparator = "__";
constexpr std::string_view kBetweenDimsSeparator = "_";

} // namespace

ChannelScope::ChannelScope(PackageConversionData* conversion_info,
ImportData* import_data,
std::optional<FifoConfig> default_fifo_config)
: conversion_info_(conversion_info),
import_data_(import_data),
channel_name_uniquer_(/*separator=*/"__"),
channel_name_uniquer_(kNameAndDimsSeparator),
default_fifo_config_(default_fifo_config) {
// Populate channel name uniquer with pre-existing channel names.
for (Channel* channel : conversion_info_->package->channels()) {
Expand Down Expand Up @@ -106,7 +114,8 @@ absl::StatusOr<ChannelOrArray> ChannelScope::DefineChannelOrArrayInternal(
XLS_ASSIGN_OR_RETURN(std::vector<std::string> suffixes,
CreateAllArrayElementSuffixes(*dims));
for (const std::string& suffix : suffixes) {
std::string channel_name = absl::StrCat(base_channel_name, "__", suffix);
std::string channel_name =
absl::StrCat(base_channel_name, kNameAndDimsSeparator, suffix);
XLS_ASSIGN_OR_RETURN(Channel * channel,
CreateChannel(channel_name, ops, type, fifo_config));
array->AddChannel(channel_name, channel);
Expand Down Expand Up @@ -170,7 +179,8 @@ absl::Status ChannelScope::DefineProtoChannelOrArray(
}

absl::StatusOr<ChannelOrArray>
ChannelScope::AssociateWithExistingChannelOrArray(const NameDef* name_def,
ChannelScope::AssociateWithExistingChannelOrArray(const ProcId& proc_id,
const NameDef* name_def,
const ChannelDecl* decl) {
VLOG(4) << "ChannelScope::AssociateWithExistingChannelOrArray : "
<< name_def->ToString() << " -> " << decl->ToString();
Expand All @@ -180,22 +190,38 @@ ChannelScope::AssociateWithExistingChannelOrArray(const NameDef* name_def,
}
ChannelOrArray channel_or_array = decl_to_channel_or_array_.at(decl);
XLS_RETURN_IF_ERROR(
AssociateWithExistingChannelOrArray(name_def, channel_or_array));
AssociateWithExistingChannelOrArray(proc_id, name_def, channel_or_array));
return channel_or_array;
}

absl::Status ChannelScope::AssociateWithExistingChannelOrArray(
const NameDef* name_def, ChannelOrArray channel_or_array) {
const ProcId& proc_id, const NameDef* name_def,
ChannelOrArray channel_or_array) {
VLOG(4) << "ChannelScope::AssociateWithExistingChannelOrArray : "
<< name_def->ToString() << " -> "
<< GetBaseNameForChannelOrArray(channel_or_array) << " (array: "
<< std::holds_alternative<ChannelArray*>(channel_or_array) << ")";
name_def_to_channel_or_array_[name_def] = channel_or_array;
name_def_to_channel_or_array_[std::make_pair(proc_id, name_def)] =
channel_or_array;
return absl::OkStatus();
}

absl::StatusOr<Channel*> ChannelScope::GetChannelForArrayIndex(
const Index* index) {
const ProcId& proc_id, const Index* index) {
XLS_ASSIGN_OR_RETURN(
ChannelOrArray result,
EvaluateIndex(proc_id, index, /*allow_subarray_reference=*/false));
CHECK(std::holds_alternative<Channel*>(result));
return std::get<Channel*>(result);
}

absl::StatusOr<ChannelOrArray> ChannelScope::GetChannelOrArrayForArrayIndex(
const ProcId& proc_id, const Index* index) {
return EvaluateIndex(proc_id, index, /*allow_subarray_reference=*/true);
}

absl::StatusOr<ChannelOrArray> ChannelScope::EvaluateIndex(
const ProcId& proc_id, const Index* index, bool allow_subarray_reference) {
VLOG(4) << "ChannelScope::GetChannelForArrayIndex : " << index->ToString();
CHECK(function_context_.has_value());
std::string suffix;
Expand All @@ -212,11 +238,13 @@ absl::StatusOr<Channel*> ChannelScope::GetChannelForArrayIndex(
std::get<Expr*>(index->rhs())));
XLS_ASSIGN_OR_RETURN(int64_t dim_value,
dim_interp_value.GetBitValueUnsigned());
suffix = suffix.empty() ? absl::StrCat(dim_value)
: absl::StrCat(dim_value, "_", suffix);
suffix = suffix.empty()
? absl::StrCat(dim_value)
: absl::StrCat(dim_value, kBetweenDimsSeparator, suffix);
if (const NameRef* name_ref = dynamic_cast<NameRef*>(index->lhs());
name_ref) {
return GetChannelArrayElement(name_ref, suffix);
return GetChannelArrayElement(proc_id, name_ref, suffix,
allow_subarray_reference);
}
Index* new_index = dynamic_cast<Index*>(index->lhs());
if (!new_index) {
Expand All @@ -228,16 +256,6 @@ absl::StatusOr<Channel*> ChannelScope::GetChannelForArrayIndex(
}
}

absl::StatusOr<std::string_view> ChannelScope::GetBaseNameForNameDef(
const NameDef* name_def) {
const auto it = name_def_to_channel_or_array_.find(name_def);
if (it == name_def_to_channel_or_array_.end()) {
return absl::NotFoundError(absl::StrCat(
"No channel or array associated with NameDef: ", name_def->ToString()));
}
return GetBaseNameForChannelOrArray(it->second);
}

std::string_view ChannelScope::GetBaseNameForChannelOrArray(
ChannelOrArray channel_or_array) {
return absl::visit(Visitor{[](Channel* channel) { return channel->name(); },
Expand Down Expand Up @@ -270,7 +288,8 @@ ChannelScope::CreateAllArrayElementSuffixes(const std::vector<Expr*>& dims) {
continue;
}
for (const std::string& next : strings) {
new_strings.push_back(absl::StrCat(next, "_", element_index));
new_strings.push_back(
absl::StrCat(next, kBetweenDimsSeparator, element_index));
}
}
strings = std::move(new_strings);
Expand All @@ -280,8 +299,8 @@ ChannelScope::CreateAllArrayElementSuffixes(const std::vector<Expr*>& dims) {

absl::StatusOr<std::string> ChannelScope::CreateBaseChannelName(
std::string_view short_name) {
return channel_name_uniquer_.GetSanitizedUniqueName(
absl::StrCat(conversion_info_->package->name(), "__", short_name));
return channel_name_uniquer_.GetSanitizedUniqueName(absl::StrCat(
conversion_info_->package->name(), kNameAndDimsSeparator, short_name));
}

absl::StatusOr<xls::Type*> ChannelScope::GetChannelType(
Expand Down Expand Up @@ -339,10 +358,12 @@ absl::StatusOr<Channel*> ChannelScope::CreateChannel(
/*fifo_config=*/fifo_config);
}

absl::StatusOr<Channel*> ChannelScope::GetChannelArrayElement(
const NameRef* name_ref, std::string_view flattened_name_suffix) {
absl::StatusOr<ChannelOrArray> ChannelScope::GetChannelArrayElement(
const ProcId& proc_id, const NameRef* name_ref,
std::string_view flattened_name_suffix, bool allow_subarray_reference) {
const auto* name_def = std::get<const NameDef*>(name_ref->name_def());
const auto it = name_def_to_channel_or_array_.find(name_def);
const auto it =
name_def_to_channel_or_array_.find(std::make_pair(proc_id, name_def));
if (it == name_def_to_channel_or_array_.end()) {
return absl::NotFoundError(
absl::StrCat("Not a channel or channel array: ", name_def->ToString()));
Expand All @@ -355,15 +376,48 @@ absl::StatusOr<Channel*> ChannelScope::GetChannelArrayElement(
std::get<Channel*>(channel_or_array)->name()));
}
ChannelArray* array = std::get<ChannelArray*>(channel_or_array);
std::string flattened_channel_name =
absl::StrCat(array->base_channel_name(), "__", flattened_name_suffix);
std::string flattened_channel_name = absl::StrCat(
array->base_channel_name(),
array->is_subarray() ? kBetweenDimsSeparator : kNameAndDimsSeparator,
flattened_name_suffix);
std::optional<Channel*> channel = array->FindChannel(flattened_channel_name);
if (channel.has_value()) {
VLOG(4) << "Found channel array element: " << (*channel)->name();
return *channel;
}
if (allow_subarray_reference) {
return GetOrDefineSubarray(array, flattened_channel_name);
}
return absl::NotFoundError(absl::StrCat(
"No array element with flattened name: ", flattened_channel_name));
}

absl::StatusOr<ChannelArray*> ChannelScope::GetOrDefineSubarray(
ChannelArray* array, std::string_view subarray_name) {
const auto it = subarrays_.find(subarray_name);
if (it != subarrays_.end()) {
VLOG(5) << "Found subarray " << subarray_name;
return it->second;
}
ChannelArray* subarray =
&arrays_.emplace_back(ChannelArray(subarray_name, /*subarray=*/true));
subarrays_.emplace_hint(it, subarray_name, subarray);
std::string subarray_prefix =
absl::StrCat(subarray_name, kBetweenDimsSeparator);
VLOG(5) << "Searching for subarray elements with prefix " << subarray_prefix;
for (const std::string& name : array->flattened_names_in_order()) {
if (absl::StartsWith(name, subarray_prefix)) {
Channel* channel = *array->FindChannel(name);
subarray->AddChannel(channel->name(), channel);
}
}
// If type checking has been done right etc., there should never be a request
// for a subarray prefix that matches zero channels, even when compiling
// erroneous DSLX code.
CHECK(!subarray->flattened_names_in_order().empty());
VLOG(5) << "Defined subarray " << subarray_name << " with "
<< subarray->flattened_names_in_order().size() << " elements.";
return subarray;
}

} // namespace xls::dslx
48 changes: 37 additions & 11 deletions xls/dslx/ir_convert/channel_scope.h
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
#include "absl/status/statusor.h"
#include "absl/types/span.h"
#include "xls/dslx/frontend/ast.h"
#include "xls/dslx/frontend/proc_id.h"
#include "xls/dslx/import_data.h"
#include "xls/dslx/ir_convert/conversion_info.h"
#include "xls/dslx/type_system/parametric_env.h"
Expand All @@ -49,10 +50,12 @@ class ChannelArray {
std::string ToString() const { return base_channel_name_; }

private:
explicit ChannelArray(std::string_view base_channel_name)
: base_channel_name_(base_channel_name) {}
explicit ChannelArray(std::string_view base_channel_name,
bool subarray = false)
: base_channel_name_(base_channel_name), subarray_(subarray) {}

std::string_view base_channel_name() const { return base_channel_name_; }
bool is_subarray() const { return subarray_; }

absl::Span<const std::string> flattened_names_in_order() const {
return flattened_names_in_order_;
Expand All @@ -75,6 +78,14 @@ class ChannelArray {
// in the DSLX source code.
const std::string base_channel_name_;

// Whether this array represents part of a larger N-D array, with up to N-1
// dims fixed. In that case, it will contain some of the same channel pointers
// that are in the `ChannelArray` object representing the overall array.
// `ChannelArray` objects for subarrays are fabricated by a `ChannelScope` on
// an as-needed basis, when references to them are encountered (in the form of
// `Index` ops).
const bool subarray_;

// The flattened names in order of addition. The scope adds channels in
// ascending index order, and in some situations wants to enumerate them in
// that order.
Expand Down Expand Up @@ -118,13 +129,14 @@ class ChannelScope {
// should be used, for example, when a channel is passed into `spawn` and the
// receiving proc associates it with a local argument name.
absl::StatusOr<ChannelOrArray> AssociateWithExistingChannelOrArray(
const NameDef* name_def, const ChannelDecl* decl);
const ProcId& proc_id, const NameDef* name_def, const ChannelDecl* decl);

// Variant of `AssociateWithExistingChannelOrArray`, to be used when the
// caller has the channel or array returned by `DefineChannelOrArray` on hand,
// rather than the `decl` it was made from.
absl::Status AssociateWithExistingChannelOrArray(
const NameDef* name_def, ChannelOrArray channel_or_array);
const ProcId& proc_id, const NameDef* name_def,
ChannelOrArray channel_or_array);

// Retrieves the individual `Channel` that is referred to by the given `index`
// operation. In order for this to succeed, `index` must meet the following
Expand All @@ -138,7 +150,15 @@ class ChannelScope {
// be constexpr evaluatable.
// A not-found error is the guaranteed result in cases where `index` is not
// a channel array index operation at all.
absl::StatusOr<Channel*> GetChannelForArrayIndex(const Index* index);
absl::StatusOr<Channel*> GetChannelForArrayIndex(const ProcId& proc_id,
const Index* index);

// Retrieves the subarray or individual `Channel` that is referred to by the
// given `index operation. The `index` must conform to the criteria described
// for `GetChannelForArrayIndex()`, but it may lead part way into a
// multidimensional channel array.
absl::StatusOr<ChannelOrArray> GetChannelOrArrayForArrayIndex(
const ProcId& proc_id, const Index* index);

private:
absl::StatusOr<ChannelOrArray> DefineChannelOrArrayInternal(
Expand All @@ -150,9 +170,6 @@ class ChannelScope {
ChannelOrArray array, dslx::ChannelTypeAnnotation* type_annot,
xls::Type* ir_type, TypeInfo* type_info);

absl::StatusOr<std::string_view> GetBaseNameForNameDef(
const NameDef* name_def);

std::string_view GetBaseNameForChannelOrArray(
ChannelOrArray channel_or_array);

Expand All @@ -171,8 +188,16 @@ class ChannelScope {
xls::Type* type,
std::optional<FifoConfig> fifo_config);

absl::StatusOr<Channel*> GetChannelArrayElement(
const NameRef* name_ref, std::string_view flattened_name_suffix);
absl::StatusOr<ChannelOrArray> EvaluateIndex(const ProcId& proc_id,
const Index* index,
bool allow_subarray_reference);

absl::StatusOr<ChannelOrArray> GetChannelArrayElement(
const ProcId& proc_id, const NameRef* name_ref,
std::string_view flattened_name_suffix, bool allow_subarray_reference);

absl::StatusOr<ChannelArray*> GetOrDefineSubarray(
ChannelArray* array, std::string_view subarray_name);

PackageConversionData* const conversion_info_;
ImportData* const import_data_;
Expand Down Expand Up @@ -200,8 +225,9 @@ class ChannelScope {

absl::flat_hash_map<const ChannelDecl*, ChannelOrArray>
decl_to_channel_or_array_;
absl::flat_hash_map<const NameDef*, ChannelOrArray>
absl::flat_hash_map<std::pair<ProcId, const NameDef*>, ChannelOrArray>
name_def_to_channel_or_array_;
absl::flat_hash_map<std::string, ChannelArray*> subarrays_;
};

} // namespace xls::dslx
Expand Down
Loading

0 comments on commit aa6cca3

Please sign in to comment.