Skip to content

Commit

Permalink
Add state index to selector for "select before loop" in merge states …
Browse files Browse the repository at this point in the history
…mode.

This gives the downstream compiler tools the information needed to
discard false data dependencies in later states.

Also debugging improvements for FSMs
- Better node naming
- Optional, and false positive prone, analysis for these cases with DebugIrTraceFlags_PrevStateIOReferences

PiperOrigin-RevId: 622461336
  • Loading branch information
Sean Purser-Haskell authored and copybara-github committed Apr 6, 2024
1 parent 69c01ae commit 541c53b
Show file tree
Hide file tree
Showing 9 changed files with 399 additions and 132 deletions.
288 changes: 213 additions & 75 deletions xls/contrib/xlscc/translate_block.cc

Large diffs are not rendered by default.

47 changes: 34 additions & 13 deletions xls/contrib/xlscc/translate_loops.cc
Original file line number Diff line number Diff line change
Expand Up @@ -427,7 +427,8 @@ absl::Status Translator::GenerateIR_PipelinedLoop(
}
}

lvalue_conditions_tuple = context().fb->Tuple(lvalue_conditions, loc);
lvalue_conditions_tuple = context().fb->Tuple(lvalue_conditions, loc,
/*name=*/"lvalue_conditions");
std::vector<std::shared_ptr<CType>> lvalue_conds_tuple_fields;
lvalue_conds_tuple_fields.resize(lvalue_conditions.size(),
std::make_shared<CBoolType>());
Expand Down Expand Up @@ -547,7 +548,8 @@ absl::Status Translator::GenerateIR_PipelinedLoop(
// Must match if(uses_on_reset) below
context_tuple_out = CValue(
context().fb->Tuple({outer_on_reset_value, context_struct_out.rvalue(),
lvalue_conditions_tuple}),
lvalue_conditions_tuple},
loc, /*name=*/"context_out_tuple_inner"),
context_tuple_type);
}

Expand Down Expand Up @@ -649,8 +651,13 @@ absl::Status Translator::GenerateIR_PipelinedLoop(
context_tuple_out = CValue(
context().fb->Tuple(
{on_reset_cval.rvalue(),
context().fb->TupleIndex(context_tuple_out.rvalue(), 1, loc),
context().fb->TupleIndex(context_tuple_out.rvalue(), 2, loc)}),
context().fb->TupleIndex(context_tuple_out.rvalue(), 1, loc,
/*name=*/"context_out_outer_struct"),
context().fb->TupleIndex(
context_tuple_out.rvalue(), 2, loc,
/*name=*/"context_out_outer_lvalue_conditions")},
loc,
/*name=*/"context_out_tuple_outer"),
context_tuple_out.type());
}

Expand All @@ -661,7 +668,8 @@ absl::Status Translator::GenerateIR_PipelinedLoop(
op.op = OpType::kSend;
std::vector<xls::BValue> sp = {context_tuple_out.rvalue(),
context().full_condition_bval(loc)};
op.ret_value = context().fb->Tuple(sp, loc);
op.ret_value =
context().fb->Tuple(sp, loc, /*name=*/"context_out_send_tup");
XLS_ASSIGN_OR_RETURN(ctx_out_op_ptr,
AddOpToChannel(op, context_out_channel, loc));
}
Expand Down Expand Up @@ -1055,9 +1063,14 @@ absl::Status Translator::GenerateIR_PipelinedLoopProc(

xls::BValue receive =
pb.ReceiveIf(context_out_channel->generated.value(), token,
/*pred=*/placeholder_cond, loc);
token = pb.TupleIndex(receive, 0);
xls::BValue received_context_tuple = pb.TupleIndex(receive, 1);
/*pred=*/placeholder_cond, loc,
/*name=*/absl::StrFormat("%s_receive_context", name_prefix));
token = pb.TupleIndex(
receive, 0, loc,
/*name=*/absl::StrFormat("%s_receive_context_token", name_prefix));
xls::BValue received_context_tuple = pb.TupleIndex(
receive, 1, loc,
/*name=*/absl::StrFormat("%s_receive_context_tup", name_prefix));

XLS_ASSIGN_OR_RETURN(
PipelinedLoopContentsReturn contents_ret,
Expand Down Expand Up @@ -1183,11 +1196,16 @@ Translator::GenerateIR_PipelinedLoopContents(

xls::BValue token = token_in;

xls::BValue received_on_reset = pb.TupleIndex(received_context_tuple, 0, loc);
xls::BValue received_context = pb.TupleIndex(received_context_tuple, 1, loc);
xls::BValue received_on_reset = pb.TupleIndex(
received_context_tuple, 0, loc,
/*name=*/absl::StrFormat("%s_receive_on_reset", name_prefix));
xls::BValue received_context = pb.TupleIndex(
received_context_tuple, 1, loc,
/*name=*/absl::StrFormat("%s_receive_context_data", name_prefix));

xls::BValue received_lvalue_conds =
pb.TupleIndex(received_context_tuple, 2, loc);
xls::BValue received_lvalue_conds = pb.TupleIndex(
received_context_tuple, 2, loc,
/*name=*/absl::StrFormat("%s_receive_context_lvalues", name_prefix));

xls::BValue use_context_in = last_iter_broke_in;

Expand Down Expand Up @@ -1401,7 +1419,10 @@ Translator::GenerateIR_PipelinedLoopContents(

xls::BValue ret_next =
pb.TupleIndex(fsm_ret.return_value,
prepared.return_index_for_static.at(namedecl), loc);
prepared.return_index_for_static.at(namedecl), loc,
/*name=*/
absl::StrFormat("%s_fsm_ret_static_%s", name_prefix,
namedecl->getNameAsString()));

xls::BValue state_elem_bval(
prepared.state_element_for_variable.at(namedecl), &pb);
Expand Down
58 changes: 58 additions & 0 deletions xls/contrib/xlscc/translator.cc
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@
#include "absl/status/statusor.h"
#include "absl/strings/match.h"
#include "absl/strings/str_format.h"
#include "absl/strings/str_join.h"
#include "absl/strings/str_replace.h"
#include "absl/types/span.h"
#include "clang/include/clang/AST/APValue.h"
Expand Down Expand Up @@ -5202,6 +5203,14 @@ std::string Debug_NodeToInfix(const xls::Node* node, int64_t& n_printed) {
return absl::StrFormat("%s(%i)", Debug_NodeToInfix(tup, n_printed),
ti->index());
}
if (node->Is<xls::Tuple>()) {
const xls::Tuple* tp = node->As<xls::Tuple>();
std::vector<std::string> operand_strings;
for (const xls::Node* op : tp->operands()) {
operand_strings.push_back(Debug_NodeToInfix(op, n_printed));
}
return std::string("(") + absl::StrJoin(operand_strings, ", ") + ")";
}
if (node->Is<xls::UnOp>()) {
const xls::UnOp* op = node->As<xls::UnOp>();
if (op->op() == xls::Op::kNot) {
Expand Down Expand Up @@ -5259,6 +5268,37 @@ std::string Debug_NodeToInfix(const xls::Node* node, int64_t& n_printed) {
typeid(*node).name());
}

std::string Debug_OpName(const IOOp& op) {
if (op.op == OpType::kTrace) {
return "trace";
}
if (op.channel != nullptr) {
std::string op_type_name;
switch (op.op) {
case OpType::kSend:
op_type_name = "send";
break;
case OpType::kRecv:
op_type_name = "recv";
break;
case OpType::kRead:
op_type_name = "read";
break;
case OpType::kWrite:
op_type_name = "write";
break;
default:
CHECK_EQ("Op type doesn't make sense here", nullptr);
}
return absl::StrFormat("%s_%s", op.channel->unique_name, op_type_name);
}
if (!op.final_param_name.empty()) {
return op.final_param_name;
}
CHECK_EQ("Unable to form name for op", nullptr);
return "TODO_OpName";
}

std::string Debug_VariablesChangedBetween(const TranslationContext& before,
const TranslationContext& after) {
std::ostringstream ostr;
Expand All @@ -5284,6 +5324,24 @@ std::string Debug_VariablesChangedBetween(const TranslationContext& before,
return ostr.str();
}

std::optional<std::list<const xls::Node*>> Debug_DeeplyCheckOperandsFromPrev(
const xls::Node* node,
const absl::flat_hash_set<const xls::Node*>& prev_state_io_nodes) {
for (const xls::Node* op : node->operands()) {
if (prev_state_io_nodes.contains(op)) {
return std::list<const xls::Node*>({op});
}
std::optional<std::list<const xls::Node*>> opt_path =
Debug_DeeplyCheckOperandsFromPrev(op, prev_state_io_nodes);
if (opt_path.has_value()) {
std::list<const xls::Node*> path = opt_path.value();
path.push_front(op);
return path;
}
}
return std::nullopt;
}

absl::StatusOr<Z3_lbool> Translator::CheckAssumptions(
absl::Span<xls::Node*> positive_nodes,
absl::Span<xls::Node*> negative_nodes, Z3_solver& solver,
Expand Down
60 changes: 32 additions & 28 deletions xls/contrib/xlscc/translator.h
Original file line number Diff line number Diff line change
Expand Up @@ -111,8 +111,7 @@ class CVoidType : public CType {
absl::Status GetMetadata(Translator& translator, xlscc_metadata::Type* output,
absl::flat_hash_set<const clang::NamedDecl*>&
aliases_used) const override;
absl::Status GetMetadataValue(Translator& translator,
ConstValue const_value,
absl::Status GetMetadataValue(Translator& translator, ConstValue const_value,
xlscc_metadata::Value* output) const override;

bool operator==(const CType& o) const override;
Expand All @@ -129,8 +128,7 @@ class CBitsType : public CType {
absl::Status GetMetadata(Translator& translator, xlscc_metadata::Type* output,
absl::flat_hash_set<const clang::NamedDecl*>&
aliases_used) const override;
absl::Status GetMetadataValue(Translator& translator,
ConstValue const_value,
absl::Status GetMetadataValue(Translator& translator, ConstValue const_value,
xlscc_metadata::Value* output) const override;

bool operator==(const CType& o) const override;
Expand All @@ -151,8 +149,7 @@ class CIntType : public CType {
absl::Status GetMetadata(Translator& translator, xlscc_metadata::Type* output,
absl::flat_hash_set<const clang::NamedDecl*>&
aliases_used) const override;
absl::Status GetMetadataValue(Translator& translator,
ConstValue const_value,
absl::Status GetMetadataValue(Translator& translator, ConstValue const_value,
xlscc_metadata::Value* output) const override;

bool operator==(const CType& o) const override;
Expand Down Expand Up @@ -213,8 +210,7 @@ class CEnumType : public CIntType {
absl::Status GetMetadata(Translator& translator, xlscc_metadata::Type* output,
absl::flat_hash_set<const clang::NamedDecl*>&
aliases_used) const override;
absl::Status GetMetadataValue(Translator& translator,
ConstValue const_value,
absl::Status GetMetadataValue(Translator& translator, ConstValue const_value,
xlscc_metadata::Value* output) const override;

bool operator==(const CType& o) const override;
Expand All @@ -241,8 +237,7 @@ class CBoolType : public CType {
absl::Status GetMetadata(Translator& translator, xlscc_metadata::Type* output,
absl::flat_hash_set<const clang::NamedDecl*>&
aliases_used) const override;
absl::Status GetMetadataValue(Translator& translator,
ConstValue const_value,
absl::Status GetMetadataValue(Translator& translator, ConstValue const_value,
xlscc_metadata::Value* output) const override;
bool operator==(const CType& o) const override;
bool StoredAsXLSBits() const override;
Expand Down Expand Up @@ -279,8 +274,7 @@ class CStructType : public CType {
absl::Status GetMetadata(Translator& translator, xlscc_metadata::Type* output,
absl::flat_hash_set<const clang::NamedDecl*>&
aliases_used) const override;
absl::Status GetMetadataValue(Translator& translator,
ConstValue const_value,
absl::Status GetMetadataValue(Translator& translator, ConstValue const_value,
xlscc_metadata::Value* output) const override;
bool operator==(const CType& o) const override;
absl::StatusOr<bool> ContainsLValues(Translator& translator) const override;
Expand Down Expand Up @@ -314,8 +308,7 @@ class CInternalTuple : public CType {
absl::Status GetMetadata(Translator& translator, xlscc_metadata::Type* output,
absl::flat_hash_set<const clang::NamedDecl*>&
aliases_used) const override;
absl::Status GetMetadataValue(Translator& translator,
ConstValue const_value,
absl::Status GetMetadataValue(Translator& translator, ConstValue const_value,
xlscc_metadata::Value* output) const override;
bool operator==(const CType& o) const override;

Expand Down Expand Up @@ -349,8 +342,7 @@ class CInstantiableTypeAlias : public CType {
absl::Status GetMetadata(Translator& translator, xlscc_metadata::Type* output,
absl::flat_hash_set<const clang::NamedDecl*>&
aliases_used) const override;
absl::Status GetMetadataValue(Translator& translator,
ConstValue const_value,
absl::Status GetMetadataValue(Translator& translator, ConstValue const_value,
xlscc_metadata::Value* output) const override;
explicit operator std::string() const override;
int GetBitWidth() const override;
Expand All @@ -370,8 +362,7 @@ class CArrayType : public CType {
absl::Status GetMetadata(Translator& translator, xlscc_metadata::Type* output,
absl::flat_hash_set<const clang::NamedDecl*>&
aliases_used) const override;
absl::Status GetMetadataValue(Translator& translator,
ConstValue const_value,
absl::Status GetMetadataValue(Translator& translator, ConstValue const_value,
xlscc_metadata::Value* output) const override;
absl::StatusOr<bool> ContainsLValues(Translator& translator) const override;

Expand All @@ -393,8 +384,7 @@ class CPointerType : public CType {
absl::Status GetMetadata(Translator& translator, xlscc_metadata::Type* output,
absl::flat_hash_set<const clang::NamedDecl*>&
aliases_used) const override;
absl::Status GetMetadataValue(Translator& translator,
ConstValue const_value,
absl::Status GetMetadataValue(Translator& translator, ConstValue const_value,
xlscc_metadata::Value* output) const override;
absl::StatusOr<bool> ContainsLValues(Translator& translator) const override;

Expand All @@ -414,8 +404,7 @@ class CReferenceType : public CType {
absl::Status GetMetadata(Translator& translator, xlscc_metadata::Type* output,
absl::flat_hash_set<const clang::NamedDecl*>&
aliases_used) const override;
absl::Status GetMetadataValue(Translator& translator,
ConstValue const_value,
absl::Status GetMetadataValue(Translator& translator, ConstValue const_value,
xlscc_metadata::Value* output) const override;
absl::StatusOr<bool> ContainsLValues(Translator& translator) const override;

Expand All @@ -442,8 +431,7 @@ class CChannelType : public CType {
absl::Status GetMetadata(Translator& translator, xlscc_metadata::Type* output,
absl::flat_hash_set<const clang::NamedDecl*>&
aliases_used) const override;
absl::Status GetMetadataValue(Translator& translator,
ConstValue const_value,
absl::Status GetMetadataValue(Translator& translator, ConstValue const_value,
xlscc_metadata::Value* output) const override;

std::shared_ptr<CType> GetItemType() const;
Expand Down Expand Up @@ -923,6 +911,7 @@ int Debug_CountNodes(const xls::Node* node,
std::set<const xls::Node*>& visited);
std::string Debug_NodeToInfix(xls::BValue bval);
std::string Debug_NodeToInfix(const xls::Node* node, int64_t& n_printed);
std::string Debug_OpName(const IOOp& op);

// Encapsulates a context for translating Clang AST to XLS IR.
// This is roughly equivalent to a "scope" in C++. There will typically
Expand Down Expand Up @@ -1086,6 +1075,10 @@ struct TranslationContext {
std::string Debug_VariablesChangedBetween(const TranslationContext& before,
const TranslationContext& after);

std::optional<std::list<const xls::Node*>> Debug_DeeplyCheckOperandsFromPrev(
const xls::Node* node,
const absl::flat_hash_set<const xls::Node*>& prev_state_io_nodes);

enum IOOpOrdering {
kNone = 0,
kChannelWise = 1,
Expand All @@ -1102,7 +1095,8 @@ enum DebugIrTraceFlags {
DebugIrTraceFlags_None = 0,
DebugIrTraceFlags_LoopContext = 1,
DebugIrTraceFlags_LoopControl = 2,
DebugIrTraceFlags_FSMStates = 4
DebugIrTraceFlags_FSMStates = 4,
DebugIrTraceFlags_PrevStateIOReferences = 8
};

class Translator {
Expand Down Expand Up @@ -1706,6 +1700,16 @@ class Translator {
PreparedBlock& prepared, xls::ProcBuilder& pb, int nesting_level,
const xls::SourceInfo& body_loc);

struct LayoutFSMStatesReturn {
absl::flat_hash_map<const IOOp*, const State*> state_by_io_op;
std::vector<std::unique_ptr<State>> states;
bool has_pipelined_loop = false;
};

absl::StatusOr<LayoutFSMStatesReturn> LayoutFSMStates(
PreparedBlock& prepared, xls::ProcBuilder& pb,
const xls::SourceInfo& body_loc);

std::set<ChannelBundle> GetChannelsUsedByOp(
const IOOp& op, const PipelinedLoopSubProc* sub_procp,
const xls::SourceInfo& loc);
Expand Down Expand Up @@ -1760,9 +1764,9 @@ class Translator {
};
// Checks if an expression is an IO op, and if so, generates the value
// to replace it in IR generation.
absl::StatusOr<IOOpReturn> InterceptIOOp(
const clang::Expr* expr, const xls::SourceInfo& loc,
CValue assignment_value = CValue());
absl::StatusOr<IOOpReturn> InterceptIOOp(const clang::Expr* expr,
const xls::SourceInfo& loc,
CValue assignment_value = CValue());

// IOOp must have io_call, and op members filled in
// This will add a parameter for IO input if needed,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,12 +15,14 @@ module foo_proc(
wire [31:0] in2_select;
wire [31:0] in1_select;
wire p0_all_active_inputs_valid;
wire [31:0] out_send_value;
assign in1_op0_ret_io_value = dir == 32'h0000_0000;
assign in2_op0_ret_io_value = ~in1_op0_ret_io_value;
assign in2_select = in2_op0_ret_io_value ? in2 : 32'h0000_0000;
assign in1_select = in1_op0_ret_io_value ? in1 : 32'h0000_0000;
assign p0_all_active_inputs_valid = (~in1_op0_ret_io_value | in1_vld) & (~in2_op0_ret_io_value | in2_vld);
assign out = in1_op0_ret_io_value ? in1_select : in2_select;
assign out_send_value = in1_op0_ret_io_value ? in1_select : in2_select;
assign out = out_send_value;
assign out_vld = p0_all_active_inputs_valid & 1'h1 & 1'h1;
assign in1_rdy = in1_op0_ret_io_value & out_rdy;
assign in2_rdy = in2_op0_ret_io_value & out_rdy;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,12 +15,14 @@ module foo_proc(
wire [31:0] in2_select;
wire [31:0] in1_select;
wire p0_all_active_inputs_valid;
wire [31:0] out_send_value;
assign in1_op0_ret_io_value = dir == 32'h0000_0000;
assign in2_op0_ret_io_value = ~in1_op0_ret_io_value;
assign in2_select = in2_op0_ret_io_value ? in2 : 32'h0000_0000;
assign in1_select = in1_op0_ret_io_value ? in1 : 32'h0000_0000;
assign p0_all_active_inputs_valid = (~in1_op0_ret_io_value | in1_vld) & (~in2_op0_ret_io_value | in2_vld);
assign out = in1_op0_ret_io_value ? in1_select : in2_select;
assign out_send_value = in1_op0_ret_io_value ? in1_select : in2_select;
assign out = out_send_value;
assign out_vld = p0_all_active_inputs_valid & 1'h1 & 1'h1;
assign in1_rdy = in1_op0_ret_io_value & out_rdy;
assign in2_rdy = in2_op0_ret_io_value & out_rdy;
Expand Down
Loading

0 comments on commit 541c53b

Please sign in to comment.