Skip to content

Commit

Permalink
Add the ability to replace nodes while cloning an AST.
Browse files Browse the repository at this point in the history
This will be useful for implementing unroll_for! and replacing the loop variable refs in the body.

Retrofit the preservation of type definitions to use this mechanism.

PiperOrigin-RevId: 663935817
  • Loading branch information
richmckeever authored and copybara-github committed Aug 17, 2024
1 parent b507caf commit 6ca2c4c
Show file tree
Hide file tree
Showing 7 changed files with 178 additions and 78 deletions.
3 changes: 2 additions & 1 deletion xls/dslx/frontend/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -36,14 +36,14 @@ cc_library(
":proc",
"@com_google_absl//absl/container:flat_hash_map",
"@com_google_absl//absl/container:flat_hash_set",
"@com_google_absl//absl/functional:any_invocable",
"@com_google_absl//absl/status",
"@com_google_absl//absl/status:statusor",
"@com_google_absl//absl/strings:str_format",
"@com_google_absl//absl/types:span",
"@com_google_absl//absl/types:variant",
"//xls/common:casts",
"//xls/common:visitor",
"//xls/common/status:ret_check",
"//xls/common/status:status_macros",
"//xls/ir:format_strings",
],
Expand All @@ -56,6 +56,7 @@ cc_test(
":ast",
":ast_cloner",
":module",
":pos",
"@com_google_absl//absl/status:statusor",
"//xls/common:casts",
"//xls/common:xls_gunit_main",
Expand Down
2 changes: 1 addition & 1 deletion xls/dslx/frontend/ast.cc
Original file line number Diff line number Diff line change
Expand Up @@ -173,7 +173,7 @@ std::string_view PrecedenceToString(Precedence p) {
constexpr int64_t kTargetLineChars = 80;

ExprOrType ToExprOrType(AstNode* n) {
if (Expr* e = down_cast<Expr*>(n)) {
if (Expr* e = dynamic_cast<Expr*>(n)) {
return e;
}
auto* type = down_cast<TypeAnnotation*>(n);
Expand Down
100 changes: 52 additions & 48 deletions xls/dslx/frontend/ast_cloner.cc
Original file line number Diff line number Diff line change
Expand Up @@ -22,13 +22,13 @@

#include "absl/container/flat_hash_map.h"
#include "absl/container/flat_hash_set.h"
#include "absl/functional/any_invocable.h"
#include "absl/status/status.h"
#include "absl/status/statusor.h"
#include "absl/strings/str_format.h"
#include "absl/types/span.h"
#include "absl/types/variant.h"
#include "xls/common/casts.h"
#include "xls/common/status/ret_check.h"
#include "xls/common/status/status_macros.h"
#include "xls/common/visitor.h"
#include "xls/dslx/frontend/ast.h"
Expand All @@ -43,8 +43,8 @@ namespace {

class AstCloner : public AstNodeVisitor {
public:
explicit AstCloner(Module* module, bool clone_type_definitions = true)
: module_(module), clone_type_definitions_(clone_type_definitions) {}
explicit AstCloner(Module* module, CloneReplacer replacer)
: module_(module), replacer_(std::move(replacer)) {}

absl::Status HandleArray(const Array* n) override {
XLS_RETURN_IF_ERROR(VisitChildren(n));
Expand Down Expand Up @@ -757,36 +757,34 @@ class AstCloner : public AstNodeVisitor {
absl::Status HandleTypeRef(const TypeRef* n) override {
TypeDefinition new_type_definition = n->type_definition();

if (clone_type_definitions_) {
// A TypeRef doesn't own its referenced type definition, so we have to
// explicitly visit it.
XLS_RETURN_IF_ERROR(absl::visit(
Visitor{[&](ColonRef* colon_ref) -> absl::Status {
XLS_RETURN_IF_ERROR(colon_ref->Accept(this));
new_type_definition =
down_cast<ColonRef*>(old_to_new_.at(colon_ref));
return absl::OkStatus();
},
[&](EnumDef* enum_def) -> absl::Status {
XLS_RETURN_IF_ERROR(enum_def->Accept(this));
new_type_definition =
down_cast<EnumDef*>(old_to_new_.at(enum_def));
return absl::OkStatus();
},
[&](StructDef* struct_def) -> absl::Status {
XLS_RETURN_IF_ERROR(struct_def->Accept(this));
new_type_definition =
down_cast<StructDef*>(old_to_new_.at(struct_def));
return absl::OkStatus();
},
[&](TypeAlias* type_alias) -> absl::Status {
XLS_RETURN_IF_ERROR(type_alias->Accept(this));
new_type_definition =
down_cast<TypeAlias*>(old_to_new_.at(type_alias));
return absl::OkStatus();
}},
n->type_definition()));
}
// A TypeRef doesn't own its referenced type definition, so we have to
// explicitly visit it.
XLS_RETURN_IF_ERROR(absl::visit(
Visitor{[&](ColonRef* colon_ref) -> absl::Status {
XLS_RETURN_IF_ERROR(ReplaceOrVisit(colon_ref));
new_type_definition =
down_cast<ColonRef*>(old_to_new_.at(colon_ref));
return absl::OkStatus();
},
[&](EnumDef* enum_def) -> absl::Status {
XLS_RETURN_IF_ERROR(ReplaceOrVisit(enum_def));
new_type_definition =
down_cast<EnumDef*>(old_to_new_.at(enum_def));
return absl::OkStatus();
},
[&](StructDef* struct_def) -> absl::Status {
XLS_RETURN_IF_ERROR(ReplaceOrVisit(struct_def));
new_type_definition =
down_cast<StructDef*>(old_to_new_.at(struct_def));
return absl::OkStatus();
},
[&](TypeAlias* type_alias) -> absl::Status {
XLS_RETURN_IF_ERROR(ReplaceOrVisit(type_alias));
new_type_definition =
down_cast<TypeAlias*>(old_to_new_.at(type_alias));
return absl::OkStatus();
}},
n->type_definition()));

old_to_new_[n] = module_->Make<TypeRef>(n->span(), new_type_definition);
return absl::OkStatus();
Expand Down Expand Up @@ -873,43 +871,49 @@ class AstCloner : public AstNodeVisitor {
absl::Status VisitChildren(const AstNode* node) {
for (const auto& child : node->GetChildren(/*want_types=*/true)) {
if (!old_to_new_.contains(child)) {
XLS_RETURN_IF_ERROR(child->Accept(this));
XLS_RETURN_IF_ERROR(ReplaceOrVisit(child));
}
}
return absl::OkStatus();
}

absl::Status ReplaceOrVisit(const AstNode* node) {
XLS_ASSIGN_OR_RETURN(std::optional<AstNode*> replacement, replacer_(node));
if (replacement.has_value()) {
old_to_new_[node] = *replacement;
return absl::OkStatus();
}
return node->Accept(this);
}

Module* const module_;
const bool clone_type_definitions_;
CloneReplacer replacer_;
absl::flat_hash_map<const AstNode*, AstNode*> old_to_new_;
};

} // namespace

absl::StatusOr<AstNode*> CloneAst(AstNode* root) {
if (dynamic_cast<Module*>(root) != nullptr) {
return absl::InvalidArgumentError("Clone a module via 'CloneModule'.");
std::optional<AstNode*> PreserveTypeDefinitionsReplacer(const AstNode* node) {
if (const auto* type_ref = dynamic_cast<const TypeRef*>(node); type_ref) {
return node->owner()->Make<TypeRef>(type_ref->span(),
type_ref->type_definition());
}
AstCloner cloner(root->owner());
XLS_RETURN_IF_ERROR(root->Accept(&cloner));
return cloner.old_to_new().at(root);
return std::nullopt;
}

absl::StatusOr<AstNode*> CloneAstSansTypeDefinitions(AstNode* root) {
XLS_RET_CHECK(root != nullptr);
absl::StatusOr<AstNode*> CloneAst(AstNode* root, CloneReplacer replacer) {
if (dynamic_cast<Module*>(root) != nullptr) {
return absl::InvalidArgumentError("Clone a module via 'CloneModule'.");
}

XLS_RET_CHECK(root->owner() != nullptr);
AstCloner cloner(root->owner(), /*clone_type_definitions=*/false);
AstCloner cloner(root->owner(), std::move(replacer));
XLS_RETURN_IF_ERROR(root->Accept(&cloner));
return cloner.old_to_new().at(root);
}

absl::StatusOr<std::unique_ptr<Module>> CloneModule(Module* module) {
absl::StatusOr<std::unique_ptr<Module>> CloneModule(Module* module,
CloneReplacer replacer) {
auto new_module = std::make_unique<Module>(module->name(), module->fs_path());
AstCloner cloner(new_module.get());
AstCloner cloner(new_module.get(), std::move(replacer));
for (const ModuleMember member : module->top()) {
ModuleMember new_member;
XLS_RETURN_IF_ERROR(absl::visit(
Expand Down
57 changes: 38 additions & 19 deletions xls/dslx/frontend/ast_cloner.h
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,11 @@
#define XLS_DSLX_FRONTEND_AST_CLONER_H_

#include <memory>
#include <optional>
#include <utility>
#include <vector>

#include "absl/functional/any_invocable.h"
#include "absl/status/status.h"
#include "absl/status/statusor.h"
#include "absl/types/span.h"
Expand All @@ -25,24 +28,38 @@

namespace xls::dslx {

// Creates a deep copy of the given AST node (inside the same module). All nodes
// in the tree are duplicated.
absl::StatusOr<AstNode*> CloneAst(AstNode* root);
// A function that can be used to override the cloning behavior for certain
// nodes during a `CloneAst` operations. A replacer can be used to replace
// targeted nodes with something else entirely, or it can just "clone" those
// nodes differently than the default logic.
using CloneReplacer =
absl::AnyInvocable<absl::StatusOr<std::optional<AstNode*>>(const AstNode*)>;

// Variant of CloneAst that won't clone type definitions -- this is useful for
// e.g. cloning return types without recursing into cloned definitions which
// would change nominal types.
absl::StatusOr<AstNode*> CloneAstSansTypeDefinitions(AstNode* root);

// Helper wrapper for the above that downcasts the result to the given
// apparent type (derived type of AstNode).
template <typename T>
inline absl::StatusOr<T*> CloneNodeSansTypeDefinitions(T* root) {
XLS_ASSIGN_OR_RETURN(AstNode * cloned, CloneAstSansTypeDefinitions(root));
return down_cast<T*>(cloned);
// This function is directly usable as the `replacer` argument for `CloneAst`
// when a direct clone with no replacements is desired.
inline std::optional<AstNode*> NoopCloneReplacer(const AstNode* original_node) {
return std::nullopt;
}

absl::StatusOr<std::unique_ptr<Module>> CloneModule(Module* module);
// A replacer function that performs shallow clones of `TypeRef` nodes, pointing
// the clone to the original `TypeDefinition` object. This is useful for e.g.
// cloning return types without recursing into cloned definitions which would
// change nominal types.
std::optional<AstNode*> PreserveTypeDefinitionsReplacer(
const AstNode* original_node);

// Creates a deep copy of the given AST node (inside the same module), generally
// duplicating all nodes. The given `replacer` may override whether and how a
// given node gets duplicated. The `replacer` is invoked for each original node
// about to be cloned. If it returns `nullopt` (which is the default), then
// cloning proceeds as normal. If it returns an `AstNode*`, then that pointer is
// used as a wholesale replacement subtree, and cloning does not delve into the
// children of the original node.
absl::StatusOr<AstNode*> CloneAst(AstNode* root,
CloneReplacer replacer = &NoopCloneReplacer);

absl::StatusOr<std::unique_ptr<Module>> CloneModule(
Module* module, CloneReplacer replacer = &NoopCloneReplacer);

// Verifies that the AST node tree rooted at `new_root` does not contain any of
// the AST nodes in the tree rooted at `old_root`. In practice, this will verify
Expand All @@ -52,18 +69,20 @@ absl::Status VerifyClone(const AstNode* old_root, const AstNode* new_root);
// Helper for CloneAst that uses the apparent (derived) type given by the
// parameter as the return type. (This helps encapsulate casts to be safer.)
template <typename T>
inline absl::StatusOr<T*> CloneNode(T* node) {
XLS_ASSIGN_OR_RETURN(AstNode * cloned, CloneAst(node));
inline absl::StatusOr<T*> CloneNode(
T* node, CloneReplacer replacer = &NoopCloneReplacer) {
XLS_ASSIGN_OR_RETURN(AstNode * cloned, CloneAst(node, std::move(replacer)));
return down_cast<T*>(cloned);
}

// Helper that vectorizes the CloneNode routine.
template <typename T>
inline absl::StatusOr<std::vector<T*>> CloneNodes(absl::Span<T* const> nodes) {
inline absl::StatusOr<std::vector<T*>> CloneNodes(
absl::Span<T* const> nodes, CloneReplacer replacer = &NoopCloneReplacer) {
std::vector<T*> results;
results.reserve(nodes.size());
for (T* n : nodes) {
XLS_ASSIGN_OR_RETURN(T * cloned, CloneNode<T>(n));
XLS_ASSIGN_OR_RETURN(T * cloned, CloneNode<T>(n, replacer));
results.push_back(cloned);
}
return results;
Expand Down
75 changes: 75 additions & 0 deletions xls/dslx/frontend/ast_cloner_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
#include "xls/dslx/frontend/ast_cloner.h"

#include <memory>
#include <optional>
#include <string>
#include <string_view>
#include <variant>
Expand All @@ -26,11 +27,25 @@
#include "xls/dslx/command_line_utils.h"
#include "xls/dslx/frontend/ast.h"
#include "xls/dslx/frontend/module.h"
#include "xls/dslx/frontend/pos.h"
#include "xls/dslx/parse_and_typecheck.h"

namespace xls::dslx {
namespace {

std::optional<TypeRef*> FindFirstTypeRef(AstNode* node) {
if (auto type_ref = dynamic_cast<TypeRef*>(node); type_ref) {
return type_ref;
}
for (AstNode* child : node->GetChildren(true)) {
std::optional type_ref = FindFirstTypeRef(child);
if (type_ref.has_value()) {
return type_ref;
}
}
return std::nullopt;
}

TEST(AstClonerTest, BasicOperation) {
constexpr std::string_view kProgram = R"(
fn main() -> u32 {
Expand Down Expand Up @@ -77,6 +92,43 @@ fn main() -> u32 {
XLS_ASSERT_OK(VerifyClone(body_expr, clone));
}

TEST(AstClonerTest, ReplaceOneOfTwoNameRefs) {
constexpr std::string_view kProgram = R"(
fn main() -> u32 {
let a = u32:0;
let b = a + 2;
b
})";

constexpr std::string_view kExpected = R"({
let a = u32:0;
let b = 3 + 2;
b
})";

XLS_ASSERT_OK_AND_ASSIGN(auto module,
ParseModule(kProgram, "fake_path.x", "the_module"));
XLS_ASSERT_OK_AND_ASSIGN(Function * f,
module->GetMemberOrError<Function>("main"));
Number* a_replacement =
module->Make<Number>(Span::Fake(), "3", NumberKind::kOther,
/*type=*/nullptr);
StatementBlock* body_expr = f->body();
XLS_ASSERT_OK_AND_ASSIGN(
AstNode * clone,
CloneAst(body_expr,
[&](const AstNode* original_node) -> std::optional<AstNode*> {
if (const auto* name_ref =
dynamic_cast<const NameRef*>(original_node);
name_ref && name_ref->identifier() == "a") {
return a_replacement;
}
return std::nullopt;
}));
EXPECT_EQ(kExpected, clone->ToString());
XLS_ASSERT_OK(VerifyClone(body_expr, clone));
}

TEST(AstClonerTest, XlsTuple) {
constexpr std::string_view kProgram = R"(
fn main() -> (u32, u32) {
Expand Down Expand Up @@ -336,6 +388,29 @@ TEST(AstClonerTest, TypeAlias) {
XLS_ASSERT_OK(VerifyClone(type_alias, clone));
}

TEST(AstClonerTest, PreserveTypeDefinitionsReplacer) {
constexpr std::string_view kProgram =
R"(
type my_type = u32;
fn foo() -> u32 {
zero!<my_type>()
}
)";

XLS_ASSERT_OK_AND_ASSIGN(auto module,
ParseModule(kProgram, "fake_path.x", "the_module"));
XLS_ASSERT_OK_AND_ASSIGN(Function * foo,
module->GetMemberOrError<Function>("foo"));
XLS_ASSERT_OK_AND_ASSIGN(AstNode * clone,
CloneAst(foo, &PreserveTypeDefinitionsReplacer));
std::optional<TypeRef*> type_ref = FindFirstTypeRef(foo);
ASSERT_TRUE(type_ref.has_value());
std::optional<TypeRef*> cloned_type_ref = FindFirstTypeRef(clone);
ASSERT_TRUE(cloned_type_ref.has_value());
EXPECT_EQ((*cloned_type_ref)->type_definition(),
(*type_ref)->type_definition());
}

TEST(AstClonerTest, QuickCheck) {
constexpr std::string_view kProgram = R"(#[quickcheck(test_count=1000)]
fn my_quickcheck(a: u32, b: u64, c: sN[128]) {
Expand Down
Loading

0 comments on commit 6ca2c4c

Please sign in to comment.