Skip to content

Commit

Permalink
Record register-read/writes to the node-to-stage map
Browse files Browse the repository at this point in the history
This will allow us to identify which stages standard pipeline registers belong to potentially allowing us to remove them if they are redundant.

PiperOrigin-RevId: 609541610
  • Loading branch information
allight authored and copybara-github committed Feb 23, 2024
1 parent 2f5b2e1 commit 863aae4
Show file tree
Hide file tree
Showing 2 changed files with 99 additions and 30 deletions.
82 changes: 52 additions & 30 deletions xls/codegen/block_conversion.cc
Original file line number Diff line number Diff line change
Expand Up @@ -407,7 +407,8 @@ static absl::StatusOr<BubbleFlowControl> UpdatePipelineWithBubbleFlowControl(
absl::Span<std::optional<Node*> const> pipeline_valid_nodes,
absl::Span<std::optional<Node*> const> pipeline_done_nodes,
absl::Span<PipelineStageRegisters> pipeline_data_registers,
absl::Span<std::optional<StateRegister>> state_registers, Block* block) {
absl::Span<std::optional<StateRegister>> state_registers,
absl::flat_hash_map<Node*, Stage>& node_to_stage_map, Block* block) {
// Create enable signals for each pipeline stage.
// - The enable signal for stage N is true either
// a. The next stage is empty/not valid
Expand Down Expand Up @@ -518,8 +519,12 @@ static absl::StatusOr<BubbleFlowControl> UpdatePipelineWithBubbleFlowControl(
/*loc=*/SourceInfo(), pipeline_reg.reg_write->data(),
/*load_enable=*/result.data_load_enable.at(stage),
/*reset=*/pipeline_reg.reg_write->reset(), pipeline_reg.reg));
XLS_RET_CHECK(node_to_stage_map.contains(pipeline_reg.reg_write));
Stage s = node_to_stage_map[pipeline_reg.reg_write];
node_to_stage_map.erase(pipeline_reg.reg_write);
XLS_RETURN_IF_ERROR(block->RemoveNode(pipeline_reg.reg_write));
pipeline_reg.reg_write = new_reg_write;
node_to_stage_map[new_reg_write] = s;
}
}
}
Expand Down Expand Up @@ -595,6 +600,8 @@ static absl::StatusOr<BubbleFlowControl> UpdatePipelineWithBubbleFlowControl(
/*load_enable=*/load_enable,
/*reset=*/state_register->reg_write->reset(),
/*reg=*/state_register->reg_write->GetRegister()));
XLS_RET_CHECK(node_to_stage_map.contains(state_register->reg_write));
node_to_stage_map.erase(state_register->reg_write);
XLS_RETURN_IF_ERROR(block->RemoveNode(state_register->reg_write));
state_register->reg_write = new_reg_write;

Expand Down Expand Up @@ -667,7 +674,8 @@ static absl::StatusOr<BubbleFlowControl> UpdatePipelineWithBubbleFlowControl(
//
static absl::StatusOr<Node*> UpdateSingleStagePipelineWithFlowControl(
Node* all_active_outputs_ready, Node* all_active_inputs_valid,
absl::Span<std::optional<StateRegister>> state_registers, Block* block) {
absl::Span<std::optional<StateRegister>> state_registers,
absl::flat_hash_map<Node*, Stage>& node_to_stage_map, Block* block) {
std::vector<Node*> operands = {all_active_outputs_ready,
all_active_inputs_valid};
XLS_ASSIGN_OR_RETURN(
Expand Down Expand Up @@ -705,6 +713,7 @@ static absl::StatusOr<Node*> UpdateSingleStagePipelineWithFlowControl(
/*reset=*/state_register->reg_write->reset(),
/*reg=*/state_register->reg_write->GetRegister()));
XLS_RETURN_IF_ERROR(block->RemoveNode(state_register->reg_write));
XLS_CHECK(!node_to_stage_map.contains(state_register->reg_write));
state_register->reg_write = new_reg_write;
}
}
Expand All @@ -715,7 +724,8 @@ static absl::StatusOr<Node*> UpdateSingleStagePipelineWithFlowControl(
static absl::StatusOr<ValidPorts> AddValidSignal(
absl::Span<PipelineStageRegisters> pipeline_registers,
const CodegenOptions& options, Block* block,
std::vector<std::optional<Node*>>& pipelined_valids) {
std::vector<std::optional<Node*>>& pipelined_valids,
absl::flat_hash_map<Node*, Stage>& node_to_stage_map) {
// Add valid input port.
XLS_RET_CHECK(options.valid_control().has_value());
if (options.valid_control()->input_name().empty()) {
Expand Down Expand Up @@ -757,6 +767,11 @@ static absl::StatusOr<ValidPorts> AddValidSignal(
/*loc=*/SourceInfo(), pipeline_reg.reg_write->data(),
/*load_enable=*/load_enable,
/*reset=*/std::nullopt, pipeline_reg.reg));
if (node_to_stage_map.contains(pipeline_reg.reg_write)) {
Stage s = node_to_stage_map[pipeline_reg.reg_write];
node_to_stage_map.erase(pipeline_reg.reg_write);
node_to_stage_map[new_write] = s;
}
XLS_RETURN_IF_ERROR(block->RemoveNode(pipeline_reg.reg_write));
pipeline_reg.reg_write = new_write;
}
Expand Down Expand Up @@ -1878,14 +1893,14 @@ static absl::StatusOr<std::vector<std::optional<Node*>>> AddBubbleFlowControl(
XLS_VLOG_LINES(3, block->DumpIr());
}

XLS_ASSIGN_OR_RETURN(
BubbleFlowControl bubble_flow_control,
UpdatePipelineWithBubbleFlowControl(
absl::MakeSpan(all_active_outputs_ready), options,
absl::MakeSpan(streaming_io.pipeline_valid),
absl::MakeSpan(streaming_io.stage_done),
absl::MakeSpan(streaming_io.pipeline_registers),
absl::MakeSpan(streaming_io.state_registers), block));
XLS_ASSIGN_OR_RETURN(BubbleFlowControl bubble_flow_control,
UpdatePipelineWithBubbleFlowControl(
absl::MakeSpan(all_active_outputs_ready), options,
absl::MakeSpan(streaming_io.pipeline_valid),
absl::MakeSpan(streaming_io.stage_done),
absl::MakeSpan(streaming_io.pipeline_registers),
absl::MakeSpan(streaming_io.state_registers),
streaming_io.node_to_stage_map, block));

XLS_VLOG(3) << "After Bubble Flow Control (pipeline)";
XLS_VLOG_LINES(3, block->DumpIr());
Expand All @@ -1909,12 +1924,12 @@ static absl::StatusOr<std::vector<std::optional<Node*>>> AddBubbleFlowControl(

// Handle flow control for the single pipeline stage case.
if (streaming_io.pipeline_registers.empty()) {
XLS_ASSIGN_OR_RETURN(
Node * input_stage_enable,
UpdateSingleStagePipelineWithFlowControl(
bubble_flow_control.data_load_enable.front(),
*streaming_io.stage_done.at(0),
absl::MakeSpan(streaming_io.state_registers), block));
XLS_ASSIGN_OR_RETURN(Node * input_stage_enable,
UpdateSingleStagePipelineWithFlowControl(
bubble_flow_control.data_load_enable.front(),
*streaming_io.stage_done.at(0),
absl::MakeSpan(streaming_io.state_registers),
streaming_io.node_to_stage_map, block));
bubble_flow_control.data_load_enable = {input_stage_enable};
}

Expand Down Expand Up @@ -2230,10 +2245,11 @@ class CloneNodesIntoBlockHandler {
if (schedule.IsLiveOutOfCycle(function_base_node, stage)) {
Node* node = node_map_.at(function_base_node);

XLS_ASSIGN_OR_RETURN(Node * node_after_stage,
CreatePipelineRegistersForNode(
PipelineSignalName(node->GetName(), stage),
node, result_.pipeline_registers.at(stage)));
XLS_ASSIGN_OR_RETURN(
Node * node_after_stage,
CreatePipelineRegistersForNode(
PipelineSignalName(node->GetName(), stage), node, stage,
result_.pipeline_registers.at(stage)));

node_map_[function_base_node] = node_after_stage;
}
Expand Down Expand Up @@ -2312,9 +2328,10 @@ class CloneNodesIntoBlockHandler {

XLS_ASSIGN_OR_RETURN(
RegisterRead * reg_read,

block()->MakeNodeWithName<RegisterRead>(node->loc(), reg,
/*name=*/reg->name()));

result_.node_to_stage_map[reg_read] = stage;
// The register write will be created later in HandleNextValue.
result_.state_registers[index] =
StateRegister{.name = std::string(param->name()),
Expand Down Expand Up @@ -2387,6 +2404,7 @@ class CloneNodesIntoBlockHandler {
/*load_enable=*/std::nullopt,
/*reset=*/std::nullopt, state_register.reg));
result_.output_states[stage].push_back(index);
result_.node_to_stage_map[state_register.reg_write] = stage;

// If the next state can be determined in a later cycle than the param
// access, we have a non-trivial backedge between initiations (II>1); use a
Expand Down Expand Up @@ -2743,7 +2761,8 @@ class CloneNodesIntoBlockHandler {
// Returns a PipelineRegister whose reg_read field can be used
// to chain dependent ops to.
absl::StatusOr<PipelineRegister> CreatePipelineRegister(std::string_view name,
Node* node) {
Node* node,
Stage stage_write) {
XLS_ASSIGN_OR_RETURN(Register * reg,
block()->AddRegister(name, node->GetType()));
XLS_ASSIGN_OR_RETURN(
Expand All @@ -2755,6 +2774,8 @@ class CloneNodesIntoBlockHandler {
RegisterRead * reg_read,
block()->MakeNodeWithName<RegisterRead>(node->loc(), reg,
/*name=*/reg->name()));
result_.node_to_stage_map[reg_write] = stage_write;
result_.node_to_stage_map[reg_read] = stage_write + 1;
return PipelineRegister{reg, reg_write, reg_read};
}

Expand All @@ -2781,7 +2802,7 @@ class CloneNodesIntoBlockHandler {
// the same type as the input node is returned.
//
absl::StatusOr<Node*> CreatePipelineRegistersForNode(
std::string_view base_name, Node* node,
std::string_view base_name, Node* node, Stage stage,
std::vector<PipelineRegister>& pipeline_registers_list) {
// As a special case, check if the node is a tuple
// containing types that are of zero-width. If so, separate them out so
Expand All @@ -2801,10 +2822,10 @@ class CloneNodesIntoBlockHandler {
XLS_ASSIGN_OR_RETURN(Node * split_node, block()->MakeNode<TupleIndex>(
node->loc(), node, i));

XLS_ASSIGN_OR_RETURN(
PipelineRegister pipe_reg,
CreatePipelineRegister(
absl::StrFormat("%s_index%d", base_name, i), split_node));
XLS_ASSIGN_OR_RETURN(PipelineRegister pipe_reg,
CreatePipelineRegister(
absl::StrFormat("%s_index%d", base_name, i),
split_node, stage));

split_registers.at(i) = pipe_reg.reg_read;
pipeline_registers_list.push_back(pipe_reg);
Expand All @@ -2821,7 +2842,7 @@ class CloneNodesIntoBlockHandler {

// Create a single register to store the node
XLS_ASSIGN_OR_RETURN(PipelineRegister pipe_reg,
CreatePipelineRegister(base_name, node));
CreatePipelineRegister(base_name, node, stage));

pipeline_registers_list.push_back(pipe_reg);
return pipe_reg.reg_read;
Expand Down Expand Up @@ -3107,7 +3128,8 @@ absl::StatusOr<CodegenPassUnit> FunctionToPipelinedBlock(
AddValidSignal(
absl::MakeSpan(unit.streaming_io_and_pipeline.pipeline_registers),
options, unit.block,
unit.streaming_io_and_pipeline.pipeline_valid));
unit.streaming_io_and_pipeline.pipeline_valid,
unit.streaming_io_and_pipeline.node_to_stage_map));
}

// Reorder the ports of the block to the following:
Expand Down
47 changes: 47 additions & 0 deletions xls/codegen/block_conversion_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -5417,6 +5417,53 @@ TEST_F(BlockConversionTest, CoveringRegions) {
EXPECT_TRUE(unit.concurrent_stages->IsConcurrent(3, 4));
}

TEST_F(BlockConversionTest, PipelineRegisterStagesKnown) {
auto p = CreatePackage();
XLS_ASSERT_OK_AND_ASSIGN(
Channel * x_out, p->CreateStreamingChannel("x_out", ChannelOps::kSendOnly,
p->GetBitsType(2)));
TokenlessProcBuilder pb(TestName(), "tok", p.get());
auto a = pb.StateElement("a_val", UBits(0, 2));
auto na = pb.Not(a, SourceInfo(), "not_a");
auto lit_one = pb.Literal(UBits(1, 2));
auto na_plus_one = pb.Add(na, lit_one, SourceInfo(), "na_plus_one");
auto send = pb.Send(x_out, na_plus_one);
auto next = pb.Next(a, na);
XLS_ASSERT_OK_AND_ASSIGN(auto proc, pb.Build());
PipelineSchedule ps(proc, {{pb.GetTokenParam().node(), 0},
{a.node(), 0},
{na.node(), 1},
{lit_one.node(), 2},
{na_plus_one.node(), 2},
{next.node(), 5},
{send.node(), 6}});
XLS_ASSERT_OK_AND_ASSIGN(
CodegenPassUnit unit,
ProcToPipelinedBlock(
ps, CodegenOptions().reset("foo", false, false, false), proc));

RecordProperty("blk", unit.block->DumpIr());
RecordProperty("map", testing::PrintToString(
unit.streaming_io_and_pipeline.node_to_stage_map));
auto read_at = [](BValue inst, int64_t stage) -> auto {
return testing::Contains(testing::Pair(
m::RegisterRead(testing::ContainsRegex(inst.GetName())), stage));
};
auto write_at = [](BValue inst, int64_t stage) -> auto {
return testing::Contains(testing::Pair(
m::RegisterWrite(testing::ContainsRegex(inst.GetName())), stage));
};
EXPECT_THAT(
unit.streaming_io_and_pipeline.node_to_stage_map,
testing::AllOf(read_at(na_plus_one, 3), read_at(na_plus_one, 4),
read_at(na_plus_one, 5), read_at(na_plus_one, 6),
read_at(a, 1), read_at(na, 2), read_at(na, 3),
read_at(na, 4), read_at(na, 5), write_at(na_plus_one, 2),
write_at(na_plus_one, 3), write_at(na_plus_one, 4),
write_at(na_plus_one, 5), write_at(a, 0), write_at(na, 1),
write_at(na, 2), write_at(na, 3), write_at(na, 4)));
}

} // namespace
} // namespace verilog
} // namespace xls

0 comments on commit 863aae4

Please sign in to comment.