diff --git a/xls/dslx/BUILD b/xls/dslx/BUILD index 52077436a3..9680a08be6 100644 --- a/xls/dslx/BUILD +++ b/xls/dslx/BUILD @@ -432,6 +432,7 @@ cc_library( "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", "@com_google_absl//absl/strings:str_format", + "@com_googlesource_code_re2//:re2", ], ) diff --git a/xls/dslx/error_printer.cc b/xls/dslx/error_printer.cc index 188c971cee..a543116443 100644 --- a/xls/dslx/error_printer.cc +++ b/xls/dslx/error_printer.cc @@ -35,6 +35,7 @@ #include "xls/common/status/status_macros.h" #include "xls/dslx/frontend/pos.h" #include "xls/dslx/warning_collector.h" +#include "re2/re2.h" namespace xls::dslx { @@ -78,6 +79,17 @@ absl::Status PrintPositionalError( int64_t last_line_printed = std::min(limit.lineno() + line_count_each_side, file_limit.lineno()); + // Strip ANSI escape codes from the error message opportunistically if we know + // we shouldn't be emitting colors. + // + // This is a bit of a layering violation but makes life easier as it allows + // "downstream" error message creation to insert ANSI codes without needing to + // be told explicitly if color error messages are ok. + std::string msg{error_message}; + if (!isatty(fileno(stderr)) || color == PositionalErrorColor::kNoColor) { + RE2::GlobalReplace(&msg, "\33\\[\\d+m", ""); + } + std::string_view pos_color_leader; std::string_view msg_color_leader; std::string_view color_reset; @@ -134,14 +146,14 @@ absl::Status PrintPositionalError( dashes_and_arrow = std::string(width - 1, '-') + "^"; } os << absl::StreamFormat("%s%s^%s %s%s\n", msg_color_leader, squiggles, - dashes_and_arrow, error_message, color_reset); + dashes_and_arrow, msg, color_reset); } } else if (i == limit.lineno()) { // Emit arrow pointing to the end of the multi-line error. std::string spaces(std::string_view("0000: ").size(), ' '); std::string underscores(std::max(int64_t{0}, limit.colno()), '_'); os << absl::StreamFormat("%s%s|%s^ %s%s\n", msg_color_leader, spaces, - underscores, error_message, color_reset); + underscores, msg, color_reset); // We're done drawing the multiline arrows; put down the crayon. bar = bar_off; } diff --git a/xls/dslx/interpreter_test.py b/xls/dslx/interpreter_test.py index 9f857829ce..389e54552a 100644 --- a/xls/dslx/interpreter_test.py +++ b/xls/dslx/interpreter_test.py @@ -388,6 +388,7 @@ def test_cast_array_to_wrong_bit_count(self): 0003: let x = u2[2]:[2, 3]; 0004: assert_eq(u3:0, x as u3) ~~~~~~~~~~~~~~~~~~~~~~~~^-----^ XlsTypeError: Cannot cast from expression type uN[2][2] to uN[3]. + Type mismatch: uN[2][2] vs uN[3] 0005: } diff --git a/xls/dslx/tests/errors/error_modules_test.py b/xls/dslx/tests/errors/error_modules_test.py index 90c64f2b53..8e9815ee89 100644 --- a/xls/dslx/tests/errors/error_modules_test.py +++ b/xls/dslx/tests/errors/error_modules_test.py @@ -1097,7 +1097,12 @@ def test_deeply_nested_type_mismatch(self): ) self.assertIn('XlsTypeError:', stderr) self.assertIn( - '(uN[8], uN[16], uN[32], uN[64])\nvs (uN[8], uN[16], uN[33], uN[64])', + 'Mismatched elements within type:\n' + ' uN[32]\n' + 'vs uN[33]\n' + 'Overall type mismatch:\n' + ' (uN[8], uN[16], uN[32], uN[64])\n' + 'vs (uN[8], uN[16], uN[33], uN[64])', stderr, ) self.assertIn( diff --git a/xls/dslx/type_system/BUILD b/xls/dslx/type_system/BUILD index 1184f5912a..d110d274a6 100644 --- a/xls/dslx/type_system/BUILD +++ b/xls/dslx/type_system/BUILD @@ -261,12 +261,13 @@ cc_library( srcs = ["maybe_explain_error.cc"], hdrs = ["maybe_explain_error.h"], deps = [ + ":format_type_mismatch", ":type", ":type_mismatch_error_data", "@com_google_absl//absl/log", "@com_google_absl//absl/status", "@com_google_absl//absl/strings:str_format", - "//xls/dslx:errors", + "//xls/common/status:status_macros", "//xls/dslx/frontend:ast", "//xls/dslx/frontend:ast_node", "//xls/dslx/frontend:pos", @@ -865,3 +866,33 @@ cc_test( "//xls/common/status:matchers", ], ) + +cc_library( + name = "format_type_mismatch", + srcs = ["format_type_mismatch.cc"], + hdrs = ["format_type_mismatch.h"], + deps = [ + ":type", + ":zip_types", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/strings:str_format", + "@com_google_absl//absl/types:variant", + "//xls/common:visitor", + "//xls/common/status:status_macros", + ], +) + +cc_test( + name = "format_type_mismatch_test", + srcs = ["format_type_mismatch_test.cc"], + deps = [ + ":format_type_mismatch", + ":type", + "//xls/common:xls_gunit", + "//xls/common:xls_gunit_main", + "//xls/common/status:matchers", + "//xls/common/status:status_macros", + ], +) diff --git a/xls/dslx/type_system/format_type_mismatch.cc b/xls/dslx/type_system/format_type_mismatch.cc new file mode 100644 index 0000000000..f71953da58 --- /dev/null +++ b/xls/dslx/type_system/format_type_mismatch.cc @@ -0,0 +1,229 @@ +// Copyright 2024 The XLS Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "xls/dslx/type_system/format_type_mismatch.h" + +#include +#include +#include +#include +#include + +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/str_format.h" +#include "absl/strings/str_join.h" +#include "absl/types/variant.h" +#include "xls/common/status/status_macros.h" +#include "xls/common/visitor.h" +#include "xls/dslx/type_system/type.h" +#include "xls/dslx/type_system/zip_types.h" + +namespace xls::dslx { +namespace { + +constexpr std::string_view kAnsiReset = "\33[0m"; +constexpr std::string_view kAnsiRed = "\33[31m"; +constexpr std::string_view kAnsiBoldOn = "\33[1m"; +constexpr std::string_view kAnsiBoldOff = "\33[22m"; + +// Populates the ref given as `mismatches` with the mismatches. +// +// Note: we could have this use the auto-formatting pretty printer to get more +// readable line wrapping for very long types, but we hope that highlighting the +// subtype mismatches inside the broader type might suffice for now. +class Callbacks : public ZipTypesCallbacks { + public: + explicit Callbacks( + std::vector>& mismatches) + : mismatches_(mismatches) {} + + absl::Status NoteAggregateStart(const AggregatePair& aggregates) override { + return absl::visit( + Visitor{ + [&](std::pair) { + AddMatchedBoth("("); + return absl::OkStatus(); + }, + [&](std::pair p) { + AddMatchedBoth( + absl::StrCat(p.first->nominal_type().identifier(), "{")); + return absl::OkStatus(); + }, + [&](std::pair p) { + /* goes at the end */ + return absl::OkStatus(); + }, + [&](std::pair p) { + AddMatchedBoth("chan("); + return absl::OkStatus(); + }, + [&](std::pair p) { + return absl::UnimplementedError( + "Cannot print diffs of function types."); + }, + [&](std::pair p) { + AddMatchedBoth("typeof("); + return absl::OkStatus(); + }, + }, + aggregates); + } + + absl::Status NoteAggregateEnd(const AggregatePair& aggregates) override { + return absl::visit( + Visitor{ + [&](std::pair) { + AddMatchedBoth(")"); + return absl::OkStatus(); + }, + [&](std::pair) { + AddMatchedBoth("}"); + return absl::OkStatus(); + }, + [&](std::pair p) { + AddMatched(absl::StrCat("[", p.first->size().ToString(), "]"), + &colorized_lhs_); + AddMatched(absl::StrCat("[", p.second->size().ToString(), "]"), + &colorized_rhs_); + return absl::OkStatus(); + }, + [&](std::pair p) { + AddMatchedBoth(")"); + return absl::OkStatus(); + }, + [&](std::pair p) { + return absl::UnimplementedError( + "Cannot print diffs of function types."); + }, + [&](std::pair p) { + AddMatchedBoth(")"); + return absl::OkStatus(); + }, + }, + aggregates); + } + + absl::Status NoteMatchedLeafType(const Type& lhs, const Type* lhs_parent, + const Type& rhs, + const Type* rhs_parent) override { + match_count_++; + BeforeType(lhs, lhs_parent, rhs, rhs_parent); + AddMatched(lhs.ToString(), &colorized_lhs_); + AddMatched(rhs.ToString(), &colorized_rhs_); + AfterType(lhs, lhs_parent, rhs, rhs_parent); + return absl::OkStatus(); + } + + absl::Status NoteTypeMismatch(const Type& lhs, const Type* lhs_parent, + const Type& rhs, + const Type* rhs_parent) override { + mismatches_.push_back({lhs.ToString(), rhs.ToString()}); + BeforeType(lhs, lhs_parent, rhs, rhs_parent); + AddMismatched(lhs.ToString(), rhs.ToString()); + AfterType(lhs, lhs_parent, rhs, rhs_parent); + return absl::OkStatus(); + } + + std::string_view colorized_lhs() const { return colorized_lhs_; } + std::string_view colorized_rhs() const { return colorized_rhs_; } + + int64_t match_count() const { return match_count_; } + + private: + // Adds a struct field before the RHS. + void BeforeType(const Type& lhs, const Type* lhs_parent, const Type& rhs, + const Type* rhs_parent) { + if (lhs_parent == nullptr) { + return; + } + if (auto* parent_struct = dynamic_cast(lhs_parent); + parent_struct != nullptr) { + int64_t index = parent_struct->IndexOf(lhs).value(); + AddMatchedBoth(absl::StrCat(parent_struct->GetMemberName(index), ": ")); + } + } + + void AfterType(const Type& lhs, const Type* lhs_parent, const Type& rhs, + const Type* rhs_parent) { + if (lhs_parent == nullptr) { + return; + } + if (auto* parent_struct = dynamic_cast(lhs_parent); + parent_struct != nullptr && + parent_struct->IndexOf(lhs).value() + 1 != parent_struct->size()) { + AddMatchedBoth(", "); + } + if (auto* parent_tuple = dynamic_cast(lhs_parent); + parent_tuple != nullptr && + parent_tuple->IndexOf(lhs).value() + 1 != parent_tuple->size()) { + AddMatchedBoth(", "); + } + } + + void AddMismatched(std::string_view lhs, std::string_view rhs) { + absl::StrAppend(&colorized_lhs_, kAnsiRed, lhs, kAnsiReset); + absl::StrAppend(&colorized_rhs_, kAnsiRed, rhs, kAnsiReset); + } + + void AddMatched(std::string_view matched_text, std::string* out) { + absl::StrAppend(out, matched_text); + } + // Helper that adds the matched text to both the LHS and RHS. + void AddMatchedBoth(std::string_view matched_text) { + AddMatched(matched_text, &colorized_lhs_); + AddMatched(matched_text, &colorized_rhs_); + } + + // We start the string off with an ANSI reset since we have our own coloring + // we do inside. + std::string colorized_lhs_; + std::string colorized_rhs_; + std::vector>& mismatches_; + int64_t match_count_ = 0; +}; + +} // namespace + +absl::StatusOr FormatTypeMismatch(const Type& lhs, + const Type& rhs) { + std::vector> mismatches; + + Callbacks callbacks(mismatches); + + XLS_RETURN_IF_ERROR(ZipTypes(lhs, rhs, callbacks)); + + std::vector lines; + if (callbacks.match_count() == 0) { + lines.push_back("Type mismatch:"); + lines.push_back(absl::StrFormat(" %s", lhs.ToString())); + lines.push_back(absl::StrFormat("vs %s", rhs.ToString())); + } else { + lines.push_back(absl::StrFormat("%sMismatched elements %swithin%s type:", + kAnsiReset, kAnsiBoldOn, kAnsiBoldOff)); + for (const auto& [lhs_mismatch, rhs_mismatch] : mismatches) { + lines.push_back(absl::StrFormat(" %s", lhs_mismatch)); + lines.push_back(absl::StrFormat("vs %s", rhs_mismatch)); + } + lines.push_back(absl::StrFormat("%sOverall%s type mismatch:", kAnsiBoldOn, + kAnsiBoldOff)); + lines.push_back( + absl::StrFormat("%s %s", kAnsiReset, callbacks.colorized_lhs())); + lines.push_back(absl::StrFormat("vs %s", callbacks.colorized_rhs())); + } + return absl::StrJoin(lines, "\n"); +} + +} // namespace xls::dslx diff --git a/xls/dslx/type_system/format_type_mismatch.h b/xls/dslx/type_system/format_type_mismatch.h new file mode 100644 index 0000000000..0b3f1a8cc5 --- /dev/null +++ b/xls/dslx/type_system/format_type_mismatch.h @@ -0,0 +1,44 @@ +// Copyright 2024 The XLS Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef XLS_DSLX_TYPE_SYSTEM_FORMAT_TYPE_MISMATCH_H_ +#define XLS_DSLX_TYPE_SYSTEM_FORMAT_TYPE_MISMATCH_H_ + +#include + +#include "absl/status/statusor.h" +#include "xls/dslx/type_system/type.h" + +namespace xls::dslx { + +// Returns a string that displays the mismatch between the "lhs" and "rhs" types +// in more detail in an attempt to be helpful to a DSLX user. +// +// This may include lines that highlight specific differences between the types +// where the structure was the same between lhs and rhs; e.g. for +// +// lhs: (u32, u64) +// rhs: (u32, s64) +// +// The returned string will highlight that the discrepancy is between the u64 / +// s64 within the tuples. +// +// The returned string should be assumed to be multi-line but not +// newline-terminated. +absl::StatusOr FormatTypeMismatch(const Type& lhs, + const Type& rhs); + +} // namespace xls::dslx + +#endif // XLS_DSLX_TYPE_SYSTEM_FORMAT_TYPE_MISMATCH_H_ diff --git a/xls/dslx/type_system/format_type_mismatch_test.cc b/xls/dslx/type_system/format_type_mismatch_test.cc new file mode 100644 index 0000000000..dc1bc096ac --- /dev/null +++ b/xls/dslx/type_system/format_type_mismatch_test.cc @@ -0,0 +1,89 @@ +// Copyright 2024 The XLS Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "xls/dslx/type_system/format_type_mismatch.h" + +#include +#include + +#include "gtest/gtest.h" +#include "xls/common/status/matchers.h" +#include "xls/dslx/type_system/type.h" + +namespace xls::dslx { +namespace { + +// Macro definitions so we can use C style string concatenation where we can +// write/see the escapes more easily. +#define ANSI_RESET "\33[0m" +#define ANSI_RED "\33[31m" +#define ANSI_BOLD "\33[1m" +#define ANSI_UNBOLD "\33[22m" + +TEST(FormatTypeMismatchTest, ElementInTuple) { + auto t0 = TupleType::Create3(BitsType::MakeU8(), + std::make_unique(false, 16), + BitsType::MakeU32()); + auto t1 = TupleType::Create3(BitsType::MakeU8(), + std::make_unique(true, 16), + BitsType::MakeU32()); + + XLS_ASSERT_OK_AND_ASSIGN(std::string got, FormatTypeMismatch(*t0, *t1)); + + EXPECT_EQ(got, + ANSI_RESET "Mismatched elements " ANSI_BOLD "within" ANSI_UNBOLD + " type:\n" + " uN[16]\n" + "vs sN[16]\n" ANSI_BOLD "Overall" ANSI_UNBOLD + " type mismatch:\n" ANSI_RESET " (uN[8], " ANSI_RED + "uN[16]" ANSI_RESET + ", uN[32])\n" + "vs (uN[8], " ANSI_RED "sN[16]" ANSI_RESET ", uN[32])"); +} + +TEST(FormatTypeMismatchTest, ElementTypeInArrayInTuple) { + auto t0 = TupleType::Create2( + BitsType::MakeU1(), + std::make_unique(BitsType::MakeU32(), TypeDim::CreateU32(4))); + auto t1 = TupleType::Create2( + BitsType::MakeU1(), + std::make_unique(BitsType::MakeS32(), TypeDim::CreateU32(4))); + + XLS_ASSERT_OK_AND_ASSIGN(std::string got, FormatTypeMismatch(*t0, *t1)); + + EXPECT_EQ(got, + ANSI_RESET "Mismatched elements " ANSI_BOLD "within" ANSI_UNBOLD + " type:\n" + " uN[32]\n" + "vs sN[32]\n" ANSI_BOLD "Overall" ANSI_UNBOLD + " type mismatch:\n" ANSI_RESET " (uN[1], " ANSI_RED + "uN[32]" ANSI_RESET + "[4])\n" + "vs (uN[1], " ANSI_RED "sN[32]" ANSI_RESET "[4])"); +} + +TEST(FormatTypeMismatchTest, TotallyDifferentTuples) { + auto t0 = TupleType::Create2(BitsType::MakeU8(), BitsType::MakeU32()); + auto t1 = TupleType::Create2(BitsType::MakeU1(), BitsType::MakeU64()); + + XLS_ASSERT_OK_AND_ASSIGN(std::string got, FormatTypeMismatch(*t0, *t1)); + + EXPECT_EQ(got, + "Type mismatch:\n" + " (uN[8], uN[32])\n" + "vs (uN[1], uN[64])"); +} + +} // namespace +} // namespace xls::dslx diff --git a/xls/dslx/type_system/maybe_explain_error.cc b/xls/dslx/type_system/maybe_explain_error.cc index 7f3657b735..16d4705fab 100644 --- a/xls/dslx/type_system/maybe_explain_error.cc +++ b/xls/dslx/type_system/maybe_explain_error.cc @@ -21,9 +21,11 @@ #include "absl/log/log.h" #include "absl/status/status.h" #include "absl/strings/str_format.h" +#include "xls/common/status/status_macros.h" #include "xls/dslx/frontend/ast.h" #include "xls/dslx/frontend/ast_node.h" #include "xls/dslx/frontend/pos.h" +#include "xls/dslx/type_system/format_type_mismatch.h" #include "xls/dslx/type_system/type.h" #include "xls/dslx/type_system/type_mismatch_error_data.h" @@ -34,17 +36,18 @@ namespace { absl::Status XlsTypeErrorStatus(const Span& span, const Type& lhs, const Type& rhs, std::string_view message) { if (lhs.IsAggregate() || rhs.IsAggregate()) { - return absl::InvalidArgumentError(absl::StrFormat( - "XlsTypeError: %s %s\n" - " %s\n" - "vs %s", - span.ToString(), message, lhs.ToErrorString(), rhs.ToErrorString())); + XLS_ASSIGN_OR_RETURN(std::string type_diff, FormatTypeMismatch(lhs, rhs)); + return absl::InvalidArgumentError( + absl::StrFormat("XlsTypeError: %s %s\n" + "%s", + span.ToString(), message, type_diff)); } return absl::InvalidArgumentError( absl::StrFormat("XlsTypeError: %s %s vs %s: %s", span.ToString(), lhs.ToErrorString(), rhs.ToErrorString(), message)); } +// Creates an XlsTypeErrorStatus using the data within the type mismatch struct. absl::Status MakeTypeError(const TypeMismatchErrorData& data) { return XlsTypeErrorStatus(data.error_span, *data.lhs, *data.rhs, data.message); diff --git a/xls/dslx/type_system/type.h b/xls/dslx/type_system/type.h index fc291551b5..0c0a6e6687 100644 --- a/xls/dslx/type_system/type.h +++ b/xls/dslx/type_system/type.h @@ -378,6 +378,15 @@ class StructType : public Type { return std::make_unique(CloneSpan(members_), struct_def_); } + std::optional IndexOf(const Type& e) const { + for (int64_t i = 0; i < size(); ++i) { + if (&GetMemberType(i) == &e) { + return i; + } + } + return std::nullopt; + } + // For user-level error reporting, we also note the name of the struct // definition if one is available. std::string ToErrorString() const override; @@ -453,6 +462,15 @@ class TupleType : public Type { // where we want to say "void" similar to a scalar. bool IsAggregate() const override { return !empty(); } + std::optional IndexOf(const Type& e) const { + for (int64_t i = 0; i < size(); ++i) { + if (&GetMemberType(i) == &e) { + return i; + } + } + return std::nullopt; + } + std::unique_ptr CloneToUnique() const override; bool empty() const; diff --git a/xls/dslx/type_system/typecheck_module_test.cc b/xls/dslx/type_system/typecheck_module_test.cc index 26b3f8c18b..baced58a5e 100644 --- a/xls/dslx/type_system/typecheck_module_test.cc +++ b/xls/dslx/type_system/typecheck_module_test.cc @@ -595,9 +595,11 @@ fn p(x: (u32, u64)[N]) -> u32 { x[0].0 } fn main() -> u32 { p(u32[1][1]:[[u32:0]]) })"), - StatusIs(absl::StatusCode::kInvalidArgument, - HasSubstr("expected argument kind 'array' to match parameter " - "kind 'tuple'\n (uN[32], uN[64])\nvs uN[32][1]"))); + StatusIs( + absl::StatusCode::kInvalidArgument, + AllOf(HasSubstr("expected argument kind 'array' to match parameter " + "kind 'tuple'"), + HasSubstr("(uN[32], uN[64])\nvs uN[32][1]")))); } TEST(TypecheckTest, ForBuiltinInBody) { @@ -629,7 +631,7 @@ fn f(x: u32) -> (u32, u8) { })"), StatusIs( absl::StatusCode::kInvalidArgument, - AllOf(HasSubstr("(uN[32], uN[8])\nvs (uN[32], (uN[32], uN[8])"), + AllOf(HasSubstr("uN[8]\nvs (uN[32], uN[8])"), HasSubstr( "For-loop annotated type did not match inferred type.")))); } @@ -1489,9 +1491,9 @@ fn f() -> MyEnum { MyEnum::C } HasSubstr("Name 'C' is not defined by the enum MyEnum"))); } +// Nominal typing not structural, e.g. OtherPoint cannot be passed where we want +// a Point, even though their members are the same. TEST(TypecheckTest, NominalTyping) { - // Nominal typing not structural, e.g. OtherPoint cannot be passed where we - // want a Point, even though their members are the same. EXPECT_THAT(Typecheck(R"( struct Point { x: s8, y: u32 } struct OtherPoint { x: s8, y: u32 } @@ -1502,8 +1504,8 @@ fn g() -> Point { } )"), StatusIs(absl::StatusCode::kInvalidArgument, - HasSubstr("struct 'Point' structure: Point { x: sN[8], " - "y: uN[32] }\nvs struct 'OtherPoint'"))); + HasSubstr("Point { x: sN[8], y: uN[32] }\nvs OtherPoint " + "{ x: sN[8], y: uN[32] }"))); } TEST(TypecheckTest, ParametricWithConstantArrayEllipsis) { @@ -1688,7 +1690,7 @@ fn g() -> (s8, u32) { } )"), StatusIs(absl::StatusCode::kInvalidArgument, - HasSubstr("(sN[8], uN[32])\nvs struct 'Point' structure"))); + HasSubstr("(sN[8], uN[32])\nvs Point { x: sN[8], y: uN[32] }"))); } TEST(TypecheckStructInstanceTest, SplatWithDuplicate) { @@ -1783,14 +1785,11 @@ fn main() { } TEST(TypecheckParametricStructInstanceTest, BadReturnType) { - EXPECT_THAT( - TypecheckParametricStructInstance( - "fn f() -> Point<5, 10> { Point { x: u32:5, y: u64:255 } }"), - StatusIs( - absl::StatusCode::kInvalidArgument, - HasSubstr( - " struct 'Point' structure: Point { x: uN[32], y: uN[64] }\nvs " - "struct 'Point' structure: Point { x: uN[5], y: uN[10] }"))); + EXPECT_THAT(TypecheckParametricStructInstance( + "fn f() -> Point<5, 10> { Point { x: u32:5, y: u64:255 } }"), + StatusIs(absl::StatusCode::kInvalidArgument, + HasSubstr("Point { x: uN[32], y: uN[64] }\nvs Point { " + "x: uN[5], y: uN[10] }"))); } // Bad struct type-parametric instantiation in parametric function. diff --git a/xls/dslx/type_system/zip_types.cc b/xls/dslx/type_system/zip_types.cc index 9669225ebe..5e16876e57 100644 --- a/xls/dslx/type_system/zip_types.cc +++ b/xls/dslx/type_system/zip_types.cc @@ -24,14 +24,23 @@ namespace xls::dslx { namespace { +// Forward decl. +absl::Status ZipTypesWithParents(const Type& lhs, const Type& rhs, + const Type* lhs_parent, const Type* rhs_parent, + ZipTypesCallbacks& callbacks); + // This is an implementation detail in traversing types and then recursively // calling ZipTypes -- we inherit TypeVisitor because we need to learn the // actual type of the generic `Type` on the left hand side and then compare that // to what we see on the right hand side at each step. class ZipTypeVisitor : public TypeVisitor { public: - explicit ZipTypeVisitor(const Type& rhs, ZipTypesCallbacks& callbacks) - : rhs_(rhs), callbacks_(callbacks) {} + explicit ZipTypeVisitor(const Type& rhs, const Type* lhs_parent, + const Type* rhs_parent, ZipTypesCallbacks& callbacks) + : rhs_(rhs), + lhs_parent_(lhs_parent), + rhs_parent_(rhs_parent), + callbacks_(callbacks) {} ~ZipTypeVisitor() override = default; @@ -56,13 +65,15 @@ class ZipTypeVisitor : public TypeVisitor { if (auto* rhs = dynamic_cast(&rhs_)) { return HandleTupleLike(lhs, *rhs); } - return callbacks_.NoteTypeMismatch(lhs, rhs_); + return callbacks_.NoteTypeMismatch(lhs, lhs_parent_, rhs_, rhs_parent_); } absl::Status HandleStruct(const StructType& lhs) override { if (auto* rhs = dynamic_cast(&rhs_)) { - return HandleTupleLike(lhs, *rhs); + if (&lhs.nominal_type() == &rhs->nominal_type()) { + return HandleTupleLike(lhs, *rhs); + } } - return callbacks_.NoteTypeMismatch(lhs, rhs_); + return callbacks_.NoteTypeMismatch(lhs, lhs_parent_, rhs_, rhs_parent_); } absl::Status HandleArray(const ArrayType& lhs) override { if (auto* rhs = dynamic_cast(&rhs_)) { @@ -73,7 +84,7 @@ class ZipTypeVisitor : public TypeVisitor { XLS_RETURN_IF_ERROR(ZipTypes(lhs_elem, rhs_elem, callbacks_)); return callbacks_.NoteAggregateEnd(aggregates); } - return callbacks_.NoteTypeMismatch(lhs, rhs_); + return callbacks_.NoteTypeMismatch(lhs, lhs_parent_, rhs_, rhs_parent_); } absl::Status HandleChannel(const ChannelType& lhs) override { if (auto* rhs = dynamic_cast(&rhs_)) { @@ -83,7 +94,7 @@ class ZipTypeVisitor : public TypeVisitor { ZipTypes(lhs.payload_type(), rhs->payload_type(), callbacks_)); return callbacks_.NoteAggregateEnd(aggregates); } - return callbacks_.NoteTypeMismatch(lhs, rhs_); + return callbacks_.NoteTypeMismatch(lhs, lhs_parent_, rhs_, rhs_parent_); } absl::Status HandleFunction(const FunctionType& lhs) override { if (auto* rhs = dynamic_cast(&rhs_)) { @@ -97,7 +108,7 @@ class ZipTypeVisitor : public TypeVisitor { ZipTypes(lhs.return_type(), rhs->return_type(), callbacks_)); return callbacks_.NoteAggregateEnd(aggregates); } - return callbacks_.NoteTypeMismatch(lhs, rhs_); + return callbacks_.NoteTypeMismatch(lhs, lhs_parent_, rhs_, rhs_parent_); } absl::Status HandleMeta(const MetaType& lhs) override { if (auto* rhs = dynamic_cast(&rhs_)) { @@ -107,7 +118,7 @@ class ZipTypeVisitor : public TypeVisitor { ZipTypes(*lhs.wrapped(), *rhs->wrapped(), callbacks_)); return callbacks_.NoteAggregateEnd(aggregates); } - return callbacks_.NoteTypeMismatch(lhs, rhs_); + return callbacks_.NoteTypeMismatch(lhs, lhs_parent_, rhs_, rhs_parent_); } private: @@ -116,14 +127,15 @@ class ZipTypeVisitor : public TypeVisitor { absl::Status HandleTupleLike(const T& lhs, const T& rhs) { bool structurally_compatible = lhs.size() == rhs.size(); if (!structurally_compatible) { - return callbacks_.NoteTypeMismatch(lhs, rhs); + return callbacks_.NoteTypeMismatch(lhs, lhs_parent_, rhs, rhs_parent_); } AggregatePair aggregates = std::make_pair(&lhs, &rhs); XLS_RETURN_IF_ERROR(callbacks_.NoteAggregateStart(aggregates)); for (int64_t i = 0; i < lhs.size(); ++i) { const Type& lhs_elem = lhs.GetMemberType(i); const Type& rhs_elem = rhs.GetMemberType(i); - XLS_RETURN_IF_ERROR(ZipTypes(lhs_elem, rhs_elem, callbacks_)); + XLS_RETURN_IF_ERROR( + ZipTypesWithParents(lhs_elem, rhs_elem, &lhs, &rhs, callbacks_)); } XLS_RETURN_IF_ERROR(callbacks_.NoteAggregateEnd(aggregates)); return absl::OkStatus(); @@ -131,21 +143,30 @@ class ZipTypeVisitor : public TypeVisitor { absl::Status HandleNonAggregate(const Type& lhs) { if (lhs.CompatibleWith(rhs_)) { - return callbacks_.NoteMatchedLeafType(lhs, rhs_); + return callbacks_.NoteMatchedLeafType(lhs, lhs_parent_, rhs_, + rhs_parent_); } - return callbacks_.NoteTypeMismatch(lhs, rhs_); + return callbacks_.NoteTypeMismatch(lhs, lhs_parent_, rhs_, rhs_parent_); } const Type& rhs_; + const Type* lhs_parent_; + const Type* rhs_parent_; ZipTypesCallbacks& callbacks_; }; +absl::Status ZipTypesWithParents(const Type& lhs, const Type& rhs, + const Type* lhs_parent, const Type* rhs_parent, + ZipTypesCallbacks& callbacks) { + ZipTypeVisitor visitor(rhs, lhs_parent, rhs_parent, callbacks); + return lhs.Accept(visitor); +} + } // namespace absl::Status ZipTypes(const Type& lhs, const Type& rhs, ZipTypesCallbacks& callbacks) { - ZipTypeVisitor visitor(rhs, callbacks); - return lhs.Accept(visitor); + return ZipTypesWithParents(lhs, rhs, nullptr, nullptr, callbacks); } } // namespace xls::dslx diff --git a/xls/dslx/type_system/zip_types.h b/xls/dslx/type_system/zip_types.h index b0eb3d3f05..cbe40ed202 100644 --- a/xls/dslx/type_system/zip_types.h +++ b/xls/dslx/type_system/zip_types.h @@ -48,12 +48,16 @@ class ZipTypesCallbacks { // Called when there is a leaf type (non aggregate) where the types are // type-compatible. virtual absl::Status NoteMatchedLeafType(const Type& lhs, - const Type& rhs) = 0; + const Type* lhs_parent, + const Type& rhs, + const Type* rhs_parent) = 0; // Called when there is a type (could be leaf or aggregate) where the types // are not type-compatible -- we do not recurse into these as they likely do // not have a common internal structure given that they mismatch. - virtual absl::Status NoteTypeMismatch(const Type& lhs, const Type& rhs) = 0; + virtual absl::Status NoteTypeMismatch(const Type& lhs, const Type* lhs_parent, + const Type& rhs, + const Type* rhs_parent) = 0; }; // Zips the /common structure/ of "lhs" type and "rhs" type, invoking "f" at all diff --git a/xls/dslx/type_system/zip_types_test.cc b/xls/dslx/type_system/zip_types_test.cc index 65d71a76e2..d92a93c917 100644 --- a/xls/dslx/type_system/zip_types_test.cc +++ b/xls/dslx/type_system/zip_types_test.cc @@ -46,7 +46,9 @@ enum class CallbackKind : uint8_t { struct CallbackData { CallbackKind kind; const Type* lhs = nullptr; + const Type* lhs_parent = nullptr; const Type* rhs = nullptr; + const Type* rhs_parent = nullptr; std::optional aggregates; }; @@ -101,14 +103,24 @@ class ZipTypesCallbacksCollector : public ZipTypesCallbacks { CallbackData{.kind = CallbackKind::kAggregateEnd, .aggregates = pair}); return absl::OkStatus(); } - absl::Status NoteMatchedLeafType(const Type& lhs, const Type& rhs) override { - data_.push_back(CallbackData{ - .kind = CallbackKind::kMatchedLeaf, .lhs = &lhs, .rhs = &rhs}); + absl::Status NoteMatchedLeafType(const Type& lhs, const Type* lhs_parent, + const Type& rhs, + const Type* rhs_parent) override { + data_.push_back(CallbackData{.kind = CallbackKind::kMatchedLeaf, + .lhs = &lhs, + .lhs_parent = lhs_parent, + .rhs = &rhs, + .rhs_parent = rhs_parent}); return absl::OkStatus(); } - absl::Status NoteTypeMismatch(const Type& lhs, const Type& rhs) override { - data_.push_back(CallbackData{ - .kind = CallbackKind::kMismatch, .lhs = &lhs, .rhs = &rhs}); + absl::Status NoteTypeMismatch(const Type& lhs, const Type* lhs_parent, + const Type& rhs, + const Type* rhs_parent) override { + data_.push_back(CallbackData{.kind = CallbackKind::kMismatch, + .lhs = &lhs, + .lhs_parent = lhs_parent, + .rhs = &rhs, + .rhs_parent = rhs_parent}); return absl::OkStatus(); } @@ -147,9 +159,10 @@ TEST(ZipTypesTest, BitsConstructorVsBitsType) { ZipTypesCallbacksCollector collector; XLS_ASSERT_OK(ZipTypes(*lhs, *rhs, collector)); - EXPECT_THAT(collector.data(), - ElementsAre(FieldsAre(CallbackKind::kMatchedLeaf, lhs.get(), - rhs.get(), std::nullopt))); + EXPECT_THAT( + collector.data(), + ElementsAre(FieldsAre(CallbackKind::kMatchedLeaf, lhs.get(), nullptr, + rhs.get(), nullptr, std::nullopt))); } TEST(ZipTypesTest, TupleWithOneDifferingElement) { @@ -167,16 +180,18 @@ TEST(ZipTypesTest, TupleWithOneDifferingElement) { std::make_pair(lhs.get(), rhs.get()); EXPECT_THAT(collector.data()[0], FieldsAre(CallbackKind::kAggregateStart, nullptr, nullptr, - AggregatePair{aggregates})); - EXPECT_THAT(collector.data()[1], - FieldsAre(CallbackKind::kMatchedLeaf, &lhs->GetMemberType(0), - &rhs->GetMemberType(0), std::nullopt)); - EXPECT_THAT(collector.data()[2], - FieldsAre(CallbackKind::kMismatch, &lhs->GetMemberType(1), - &rhs->GetMemberType(1), std::nullopt)); + nullptr, nullptr, AggregatePair{aggregates})); + EXPECT_THAT( + collector.data()[1], + FieldsAre(CallbackKind::kMatchedLeaf, &lhs->GetMemberType(0), lhs.get(), + &rhs->GetMemberType(0), rhs.get(), std::nullopt)); + EXPECT_THAT( + collector.data()[2], + FieldsAre(CallbackKind::kMismatch, &lhs->GetMemberType(1), lhs.get(), + &rhs->GetMemberType(1), rhs.get(), std::nullopt)); EXPECT_THAT(collector.data()[3], - FieldsAre(CallbackKind::kAggregateEnd, nullptr, nullptr, - AggregatePair{aggregates})); + FieldsAre(CallbackKind::kAggregateEnd, nullptr, nullptr, nullptr, + nullptr, AggregatePair{aggregates})); } } // namespace