Skip to content

Commit

Permalink
[DSLX:parser] Refactor to consolidate on AST cloner.
Browse files Browse the repository at this point in the history
Types needed slightly shallower cloning to share the cloner infrastructure, if
we recursed all the way into the type definitions (beyond the references) then
we'd break the nominal type comparisons we do by AST pointer. This adds a
cloning option that doesn't traverse into type definitions, which is what we
need for return type clones as the parser currently does them.

PiperOrigin-RevId: 555515229
  • Loading branch information
cdleary authored and copybara-github committed Aug 10, 2023
1 parent 1943d63 commit eeb9b3e
Show file tree
Hide file tree
Showing 5 changed files with 71 additions and 83 deletions.
2 changes: 2 additions & 0 deletions xls/dslx/frontend/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ cc_library(
"@com_google_absl//absl/types:variant",
"//xls/common:casts",
"//xls/common:visitor",
"//xls/common/status:ret_check",
"//xls/common/status:status_macros",
],
)
Expand Down Expand Up @@ -94,6 +95,7 @@ cc_library(
"@com_google_absl//absl/strings:str_format",
"@com_google_absl//absl/types:span",
"@com_google_absl//absl/types:variant",
"//xls/common:casts",
"//xls/common:strong_int",
"//xls/common:visitor",
"//xls/common/status:ret_check",
Expand Down
82 changes: 51 additions & 31 deletions xls/dslx/frontend/ast_cloner.cc
Original file line number Diff line number Diff line change
Expand Up @@ -28,16 +28,19 @@
#include "absl/types/span.h"
#include "absl/types/variant.h"
#include "xls/common/casts.h"
#include "xls/common/status/ret_check.h"
#include "xls/common/status/status_macros.h"
#include "xls/common/visitor.h"
#include "xls/dslx/frontend/ast.h"
#include "xls/dslx/frontend/ast_utils.h"

namespace xls::dslx {
namespace {

class AstCloner : public AstNodeVisitor {
public:
explicit AstCloner(Module* module) : module_(module) {}
explicit AstCloner(Module* module, bool clone_type_definitions = true)
: module_(module), clone_type_definitions_(clone_type_definitions) {}

absl::Status HandleArray(const Array* n) override {
XLS_RETURN_IF_ERROR(VisitChildren(n));
Expand Down Expand Up @@ -756,36 +759,38 @@ class AstCloner : public AstNodeVisitor {
}

absl::Status HandleTypeRef(const TypeRef* n) override {
TypeDefinition new_type_definition;
TypeDefinition new_type_definition = n->type_definition();

// A TypeRef doesn't own its referenced type definition, so we have to
// explicitly visit it.
XLS_RETURN_IF_ERROR(absl::visit(
Visitor{[&](ColonRef* colon_ref) -> absl::Status {
XLS_RETURN_IF_ERROR(colon_ref->Accept(this));
new_type_definition =
down_cast<ColonRef*>(old_to_new_.at(colon_ref));
return absl::OkStatus();
},
[&](EnumDef* enum_def) -> absl::Status {
XLS_RETURN_IF_ERROR(enum_def->Accept(this));
new_type_definition =
down_cast<EnumDef*>(old_to_new_.at(enum_def));
return absl::OkStatus();
},
[&](StructDef* struct_def) -> absl::Status {
XLS_RETURN_IF_ERROR(struct_def->Accept(this));
new_type_definition =
down_cast<StructDef*>(old_to_new_.at(struct_def));
return absl::OkStatus();
},
[&](TypeAlias* type_alias) -> absl::Status {
XLS_RETURN_IF_ERROR(type_alias->Accept(this));
new_type_definition =
down_cast<TypeAlias*>(old_to_new_.at(type_alias));
return absl::OkStatus();
}},
n->type_definition()));
if (clone_type_definitions_) {
// A TypeRef doesn't own its referenced type definition, so we have to
// explicitly visit it.
XLS_RETURN_IF_ERROR(absl::visit(
Visitor{[&](ColonRef* colon_ref) -> absl::Status {
XLS_RETURN_IF_ERROR(colon_ref->Accept(this));
new_type_definition =
down_cast<ColonRef*>(old_to_new_.at(colon_ref));
return absl::OkStatus();
},
[&](EnumDef* enum_def) -> absl::Status {
XLS_RETURN_IF_ERROR(enum_def->Accept(this));
new_type_definition =
down_cast<EnumDef*>(old_to_new_.at(enum_def));
return absl::OkStatus();
},
[&](StructDef* struct_def) -> absl::Status {
XLS_RETURN_IF_ERROR(struct_def->Accept(this));
new_type_definition =
down_cast<StructDef*>(old_to_new_.at(struct_def));
return absl::OkStatus();
},
[&](TypeAlias* type_alias) -> absl::Status {
XLS_RETURN_IF_ERROR(type_alias->Accept(this));
new_type_definition =
down_cast<TypeAlias*>(old_to_new_.at(type_alias));
return absl::OkStatus();
}},
n->type_definition()));
}

old_to_new_[n] = module_->Make<TypeRef>(n->span(), new_type_definition);
return absl::OkStatus();
Expand Down Expand Up @@ -873,10 +878,13 @@ class AstCloner : public AstNodeVisitor {
return absl::OkStatus();
}

Module* module_;
Module* const module_;
const bool clone_type_definitions_;
absl::flat_hash_map<const AstNode*, AstNode*> old_to_new_;
};

} // namespace

absl::StatusOr<AstNode*> CloneAst(AstNode* root) {
if (dynamic_cast<Module*>(root) != nullptr) {
return absl::InvalidArgumentError("Clone a module via 'CloneModule'.");
Expand All @@ -886,6 +894,18 @@ absl::StatusOr<AstNode*> CloneAst(AstNode* root) {
return cloner.old_to_new().at(root);
}

absl::StatusOr<AstNode*> CloneAstSansTypeDefinitions(AstNode* root) {
XLS_RET_CHECK(root != nullptr);
if (dynamic_cast<Module*>(root) != nullptr) {
return absl::InvalidArgumentError("Clone a module via 'CloneModule'.");
}

XLS_RET_CHECK(root->owner() != nullptr);
AstCloner cloner(root->owner(), /*clone_type_definitions=*/false);
XLS_RETURN_IF_ERROR(root->Accept(&cloner));
return cloner.old_to_new().at(root);
}

absl::StatusOr<std::unique_ptr<Module>> CloneModule(Module* module) {
auto new_module = std::make_unique<Module>(module->name(), module->fs_path());
AstCloner cloner(new_module.get());
Expand Down
6 changes: 6 additions & 0 deletions xls/dslx/frontend/ast_cloner.h
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,12 @@ namespace xls::dslx {
// Creates a deep copy of the given AST node (inside the same module). All nodes
// in the tree are duplicated.
absl::StatusOr<AstNode*> CloneAst(AstNode* root);

// Variant of CloneAst that won't clone type definitions -- this is useful for
// e.g. cloning return types without recursing into cloned definitions which
// would change nominal types.
absl::StatusOr<AstNode*> CloneAstSansTypeDefinitions(AstNode* root);

absl::StatusOr<std::unique_ptr<Module>> CloneModule(Module* module);

// Verifies that the AST node tree rooted at `new_root` does not contain any of
Expand Down
49 changes: 3 additions & 46 deletions xls/dslx/frontend/parser.cc
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@
#include "absl/strings/str_format.h"
#include "absl/types/span.h"
#include "absl/types/variant.h"
#include "xls/common/casts.h"
#include "xls/common/status/ret_check.h"
#include "xls/common/status/status_builder.h"
#include "xls/common/status/status_macros.h"
Expand Down Expand Up @@ -2104,52 +2105,8 @@ absl::StatusOr<Function*> Parser::ParseProcInit(
}

absl::StatusOr<TypeAnnotation*> Parser::CloneReturnType(TypeAnnotation* input) {
if (auto* builtin_type = dynamic_cast<BuiltinTypeAnnotation*>(input);
builtin_type != nullptr) {
return module_->Make<BuiltinTypeAnnotation>(
input->span(), builtin_type->builtin_type(),
builtin_type->builtin_name_def());
}

if (auto* array_type = dynamic_cast<ArrayTypeAnnotation*>(input);
array_type != nullptr) {
XLS_ASSIGN_OR_RETURN(TypeAnnotation * clone_element,
CloneReturnType(array_type->element_type()));
XLS_ASSIGN_OR_RETURN(AstNode * new_dim, CloneAst(array_type->dim()));
return module_->Make<ArrayTypeAnnotation>(input->span(), clone_element,
down_cast<Expr*>(new_dim));
}

if (auto* tuple_type = dynamic_cast<TupleTypeAnnotation*>(input);
tuple_type != nullptr) {
std::vector<TypeAnnotation*> members;
members.reserve(tuple_type->members().size());
for (auto* member : tuple_type->members()) {
XLS_ASSIGN_OR_RETURN(TypeAnnotation * new_member,
CloneReturnType(member));
members.push_back(new_member);
}
return module_->Make<TupleTypeAnnotation>(tuple_type->span(), members);
}

if (auto* typeref_type = dynamic_cast<TypeRefTypeAnnotation*>(input);
typeref_type != nullptr) {
TypeRef* old_ref = typeref_type->type_ref();
TypeRef* new_ref =
module_->Make<TypeRef>(old_ref->span(), old_ref->type_definition());

std::vector<ExprOrType> new_parametrics;
for (const ExprOrType& parametric : typeref_type->parametrics()) {
XLS_ASSIGN_OR_RETURN(AstNode * new_parametric,
CloneAst(ToAstNode(parametric)));
new_parametrics.push_back(ToExprOrType(new_parametric));
}
return module_->Make<TypeRefTypeAnnotation>(typeref_type->span(), new_ref,
new_parametrics);
}

XLS_RET_CHECK_NE(dynamic_cast<ChannelTypeAnnotation*>(input), nullptr);
return absl::UnimplementedError("Cannot clone channel type annotations.");
XLS_ASSIGN_OR_RETURN(AstNode * cloned, CloneAstSansTypeDefinitions(input));
return down_cast<TypeAnnotation*>(cloned);
}

absl::StatusOr<Proc*> Parser::ParseProc(bool is_public,
Expand Down
15 changes: 9 additions & 6 deletions xls/dslx/type_system/typecheck.cc
Original file line number Diff line number Diff line change
Expand Up @@ -808,6 +808,8 @@ absl::Status CheckFunction(Function* f, DeduceCtx* ctx) {
"Types cannot be returned from functions");
}
if (*return_type != *body_type) {
XLS_VLOG(5) << "return type: " << return_type->ToString()
<< " body type: " << body_type->ToString();
if (f->tag() == Function::Tag::kProcInit) {
return ctx->TypeMismatchError(
f->body()->span(), f->body(), *body_type, f->return_type(),
Expand Down Expand Up @@ -1028,23 +1030,24 @@ absl::StatusOr<TypeAndParametricEnv> CheckInvocation(
if (annotated_return_type != *resolved_body_type) {
XLS_VLOG(5) << "annotated_return_type: " << annotated_return_type
<< " resolved_body_type: " << resolved_body_type->ToString();

if (callee_fn->tag() == Function::Tag::kProcInit) {
return ctx->TypeMismatchError(
callee_fn->body()->span(), callee_fn->body(), *body_type, nullptr,
*callee_tab.type,
callee_fn->body()->span(), callee_fn->body(), *resolved_body_type,
nullptr, annotated_return_type,
absl::StrFormat("'next' state param and 'init' types differ."));
}

if (callee_fn->tag() == Function::Tag::kProcNext) {
return ctx->TypeMismatchError(
callee_fn->body()->span(), callee_fn->body(), *body_type, nullptr,
*callee_tab.type,
callee_fn->body()->span(), callee_fn->body(), *resolved_body_type,
nullptr, annotated_return_type,
absl::StrFormat("'next' input and output state types differ."));
}

return ctx->TypeMismatchError(
callee_fn->body()->span(), callee_fn->body(), *body_type, nullptr,
*callee_tab.type,
callee_fn->body()->span(), callee_fn->body(), *resolved_body_type,
nullptr, annotated_return_type,
absl::StrFormat("Return type of function body for '%s' did not match "
"the annotated return type.",
callee_fn->identifier()));
Expand Down

0 comments on commit eeb9b3e

Please sign in to comment.