diff --git a/xls/codegen/vast/BUILD b/xls/codegen/vast/BUILD index c7687be1c7..a3e477c509 100644 --- a/xls/codegen/vast/BUILD +++ b/xls/codegen/vast/BUILD @@ -161,14 +161,24 @@ cc_library( hdrs = ["dslx_type_fixer.h"], deps = [ "//xls/common:casts", + "//xls/common/status:ret_check", "//xls/common/status:status_macros", + "//xls/dslx:errors", + "//xls/dslx:import_data", "//xls/dslx/frontend:ast", "//xls/dslx/frontend:ast_cloner", "//xls/dslx/frontend:ast_node", + "//xls/dslx/frontend:module", + "//xls/dslx/frontend:pos", "//xls/dslx/type_system:type", "//xls/dslx/type_system:type_info", + "//xls/dslx/type_system_v2:inference_table", + "//xls/dslx/type_system_v2:type_annotation_utils", + "//xls/dslx/type_system_v2:type_inference_error_handler", "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/container:flat_hash_set", "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/types:span", ], ) @@ -187,7 +197,6 @@ cc_library( "//xls/dslx:create_import_data", "//xls/dslx:import_data", "//xls/dslx:import_routines", - "//xls/dslx:interp_bindings", "//xls/dslx:interp_value", "//xls/dslx:parse_and_typecheck", "//xls/dslx:virtualizable_file_system", @@ -205,13 +214,10 @@ cc_library( "//xls/dslx/frontend:parser", "//xls/dslx/frontend:pos", "//xls/dslx/frontend:scanner", - "//xls/dslx/type_system:deduce", - "//xls/dslx/type_system:deduce_ctx", - "//xls/dslx/type_system:type", + "//xls/dslx/ir_convert:convert_options", "//xls/dslx/type_system:type_info", - "//xls/dslx/type_system:typecheck_function", - "//xls/dslx/type_system:typecheck_invocation", - "//xls/dslx/type_system:typecheck_module", + "//xls/dslx/type_system_v2:type_inference_error_handler", + "//xls/dslx/type_system_v2:typecheck_module_v2", "//xls/ir:bits", "//xls/ir:bits_ops", "//xls/ir:format_preference", @@ -247,10 +253,7 @@ cc_library( "//xls/dslx/frontend:module", "//xls/dslx/frontend:pos", "//xls/dslx/frontend:token", - "//xls/dslx/type_system:deduce_ctx", "//xls/dslx/type_system:type", - "//xls/dslx/type_system:type_info", - "//xls/dslx/type_system:unwrap_meta_type", "//xls/ir:source_location", "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/container:flat_hash_set", diff --git a/xls/codegen/vast/dslx_builder.cc b/xls/codegen/vast/dslx_builder.cc index 280419ce8a..adb5a8574c 100644 --- a/xls/codegen/vast/dslx_builder.cc +++ b/xls/codegen/vast/dslx_builder.cc @@ -55,14 +55,11 @@ #include "xls/dslx/import_data.h" #include "xls/dslx/import_routines.h" #include "xls/dslx/interp_value.h" +#include "xls/dslx/ir_convert/convert_options.h" #include "xls/dslx/parse_and_typecheck.h" -#include "xls/dslx/type_system/deduce.h" -#include "xls/dslx/type_system/deduce_ctx.h" -#include "xls/dslx/type_system/type.h" #include "xls/dslx/type_system/type_info.h" -#include "xls/dslx/type_system/typecheck_function.h" -#include "xls/dslx/type_system/typecheck_invocation.h" -#include "xls/dslx/type_system/typecheck_module.h" +#include "xls/dslx/type_system_v2/type_inference_error_handler.h" +#include "xls/dslx/type_system_v2/typecheck_module_v2.h" #include "xls/dslx/virtualizable_file_system.h" #include "xls/dslx/warning_collector.h" #include "xls/dslx/warning_kind.h" @@ -293,13 +290,7 @@ DslxBuilder::DslxBuilder( resolver_(resolver), warnings_(warnings), type_info_(GetTypeInfoOrDie(import_data_, &module_)), - deduce_ctx_(type_info_, &module_, dslx::Deduce, &dslx::TypecheckFunction, - /*typecheck_module=*/nullptr, dslx::TypecheckInvocation, - &import_data_, &warnings_, - /*parent=*/nullptr), - vast_type_map_(vast_type_map) { - deduce_ctx_.fn_stack().push_back(dslx::FnStackEntry::MakeTop(&module_)); -} + vast_type_map_(vast_type_map) {} absl::StatusOr DslxBuilder::MakeNameRefAndCast( verilog::Expression* expr, const dslx::Span& span, std::string_view name, @@ -343,31 +334,23 @@ absl::StatusOr DslxBuilder::HandleConstantDecl( verilog::Parameter* parameter, std::string_view name, dslx::Expr* expr) { auto* name_def = resolver_->MakeNameDef(*this, span, name, parameter, vast_module); - auto* constant_def = module().Make( - span, name_def, /*type_annotation=*/nullptr, expr, - /*is_public=*/true); + + dslx::TypeAnnotation* type_annotation = nullptr; + if (parameter->def() && parameter->def()->data_type() && + parameter->def()->data_type()->IsUserDefined()) { + XLS_ASSIGN_OR_RETURN( + type_annotation, + VastTypeToDslxTypeForCast(dslx::Span(), parameter->def()->data_type())); + } + + auto* constant_def = + module().Make(span, name_def, type_annotation, expr, + /*is_public=*/true); name_def->set_definer(constant_def); XLS_RETURN_IF_ERROR( module().AddTop(constant_def, /*make_collision_error=*/nullptr)); - // Note: historically, the propagation of errors was incorrectly dropped by - // the caller, hiding the fact that the logic from here down is a nice-to-have - // and historically fails e.g. in ConcatOfEnums. - auto deduced_type = deduce_ctx().Deduce(expr); - absl::StatusOr value = - InterpretExpr(import_data(), type_info(), expr); - if (deduced_type.ok() && value.ok()) { - bindings_.AddValue(name_def->identifier(), *value); - import_data_.GetOrCreateTopLevelBindings(&module()).AddValue( - name_def->identifier(), *value); - } - - auto def_deduced_type = deduce_ctx().Deduce(constant_def); - if (!def_deduced_type.ok()) { - VLOG(2) << "Failed to deduce constant def type: " - << def_deduced_type.status(); - } // Add a comment with the value, if it is not obvious and can be folded. if (parameter->rhs() != nullptr && !parameter->rhs()->IsLiteral()) { absl::StatusOr folded_value = @@ -503,9 +486,9 @@ absl::StatusOr DslxBuilder::GetOrImportModule( dslx::DoImport( [this](std::unique_ptr module, std::filesystem::path path) { - return dslx::TypecheckModule(std::move(module), - path, &import_data_, - &warnings_); + return dslx::TypecheckModuleV2( + std::move(module), path, &import_data_, + &warnings_, nullptr, nullptr); }, import_tokens, &import_data_, dslx::Span::Fake(), import_data_.vfs())); @@ -518,11 +501,9 @@ absl::StatusOr DslxBuilder::GetOrImportModule( name_def->set_definer(import); XLS_RETURN_IF_ERROR( module().AddTop(import, /*make_collision_error=*/nullptr)); - deduce_ctx().type_info()->AddImport(import, &mod_info->module(), - mod_info->type_info()); + type_info_->AddImport(import, &mod_info->module(), mod_info->type_info()); import_data_.GetOrCreateTopLevelBindings(&module()).AddModule( tail, &mod_info->module()); - bindings_.AddModule(tail, &mod_info->module()); } std::optional std_or = module().FindMemberWithName(tail); @@ -541,15 +522,15 @@ absl::StatusOr DslxBuilder::GetOrImportModule( absl::StatusOr DslxBuilder::CastToInferredVastType( verilog::Expression* vast_expr, dslx::Expr* expr, - bool cast_enum_to_builtin) { + bool force_cast_user_defined) { XLS_ASSIGN_OR_RETURN(verilog::DataType * vast_type, GetVastDataType(vast_expr)); - return Cast(vast_type, expr, cast_enum_to_builtin); + return Cast(vast_type, expr, force_cast_user_defined); } absl::StatusOr DslxBuilder::Cast(verilog::DataType* vast_type, dslx::Expr* expr, - bool cast_enum_to_builtin) { + bool force_cast_user_defined) { if (!vast_type->FlatBitCountAsInt64().ok()) { VLOG(2) << "Warning: cannot insert a cast of expr: " << expr->ToString() << " to type: " << vast_type->Emit(nullptr) @@ -557,37 +538,19 @@ absl::StatusOr DslxBuilder::Cast(verilog::DataType* vast_type, "generally OK if the width is not statically computed."; return expr; } - absl::StatusOr> deduced_dslx_type = - deduce_ctx().Deduce(expr); - if (!deduced_dslx_type.ok()) { - VLOG(2) << "Warning: Pessimistically inserting a cast of expr: " - << expr->ToString() << " to type: " << vast_type->Emit(nullptr) - << " because the DSLX type cannot be deduced. This may happen if a " - "parameter value has a system function call."; - return CastInternal(vast_type, expr); - } - if ((*deduced_dslx_type)->HasEnum()) { - if (cast_enum_to_builtin || !vast_type->IsUserDefined()) { - // DSLX considers enum and values to mismatch operands of an equivalent - // built-in type. VAST type inference will say in that case that they are - // both the generic type, and we need the cast to make DSLX comply. - return CastInternal(vast_type, expr, cast_enum_to_builtin); - } - XLS_ASSIGN_OR_RETURN(dslx::TypeDim deduced_dslx_dim, - (*deduced_dslx_type)->GetTotalBitCount()); - XLS_ASSIGN_OR_RETURN(int64_t deduced_dslx_bit_count, - deduced_dslx_dim.GetAsInt64()); - XLS_ASSIGN_OR_RETURN(int64_t verilog_bit_count, - vast_type->FlatBitCountAsInt64()); - if (deduced_dslx_bit_count != verilog_bit_count) { - return Cast(vast_type, expr); + if (vast_type->IsUserDefined()) { + const auto* typedef_type = + dynamic_cast(vast_type); + // Integer type aliases should really be casted, but enums should not. + if (typedef_type && + !typedef_type->type_def()->data_type()->IsUserDefined()) { + return CastInternal(vast_type, expr); } - absl::StatusOr deduced_dslx_signed = - dslx::IsSigned(**deduced_dslx_type); - if (!deduced_dslx_signed.ok() || - *deduced_dslx_signed != vast_type->is_signed()) { - return Cast(vast_type, expr); + + if (force_cast_user_defined) { + return CastInternal(vast_type, expr); } + return expr; } @@ -687,14 +650,18 @@ absl::StatusOr DslxBuilder::GetVastDataType( absl::StatusOr DslxBuilder::RoundTrip( const dslx::Module& module, std::string_view path, - dslx::ImportData& import_data) { + dslx::ImportData& import_data, + dslx::TypeInferenceErrorHandler error_handler) { const std::string text = module.ToString(); dslx::Fileno fileno = import_data.file_table().GetOrCreate(path); dslx::Scanner scanner(import_data.file_table(), fileno, text); dslx::Parser parser(module_.name(), &scanner); XLS_ASSIGN_OR_RETURN( dslx::TypecheckedModule parsed_module, - ParseAndTypecheck(text, path, module_.name(), &import_data), + ParseAndTypecheck(text, path, module_.name(), &import_data, + /*comments=*/nullptr, + /*force_version=*/dslx::TypeInferenceVersion::kVersion2, + /*options=*/dslx::ConvertOptions{}, error_handler), _ << "Failed to parse and typecheck module:\n" << text); return parsed_module; @@ -726,22 +693,33 @@ absl::StatusOr DslxBuilder::FormatModule() { // inference, like removal of dead casts. dslx::ImportData initial_import_data = CreateImportData(); const std::string file_name = module_.name() + ".x"; + std::unique_ptr fixer = + CreateDslxTypeFixer(module_, import_data_); XLS_ASSIGN_OR_RETURN(dslx::TypecheckedModule initial_module, - RoundTrip(module_, file_name, initial_import_data)); - std::unique_ptr fixer = CreateDslxTypeFixer(); + RoundTrip(module_, file_name, initial_import_data, + fixer->GetErrorHandler())); XLS_ASSIGN_OR_RETURN( - std::unique_ptr module_with_stripped_casts, + std::unique_ptr module_with_errors_fixed, CloneModule(*initial_module.module, - fixer->GetReplacer(initial_module.type_info))); + fixer->GetErrorFixReplacer(initial_module.type_info))); + + dslx::ImportData fix_import_data = CreateImportData(); + XLS_ASSIGN_OR_RETURN( + dslx::TypecheckedModule fixed_and_typechecked_module, + RoundTrip(*module_with_errors_fixed, file_name, fix_import_data)); + XLS_ASSIGN_OR_RETURN( + std::unique_ptr simplified_module, + CloneModule( + *fixed_and_typechecked_module.module, + fixer->GetSimplifyReplacer(fixed_and_typechecked_module.type_info))); // We now need to round-trip the module to text and back to AST, without the // comments, in order for the nodes to get spans accurately representing the // DSLX as opposed to the source Verilog. We then position the comments // relative to the appropriate spans. dslx::ImportData import_data = CreateImportData(); - XLS_ASSIGN_OR_RETURN( - dslx::TypecheckedModule parsed_module, - RoundTrip(*module_with_stripped_casts, file_name, import_data)); + XLS_ASSIGN_OR_RETURN(dslx::TypecheckedModule parsed_module, + RoundTrip(*simplified_module, file_name, import_data)); std::vector comment_data; for (const auto& [type_name, comment] : type_def_comments_) { diff --git a/xls/codegen/vast/dslx_builder.h b/xls/codegen/vast/dslx_builder.h index e1613969c5..9f74b95bbf 100644 --- a/xls/codegen/vast/dslx_builder.h +++ b/xls/codegen/vast/dslx_builder.h @@ -28,11 +28,10 @@ #include "xls/dslx/frontend/module.h" #include "xls/dslx/frontend/pos.h" #include "xls/dslx/import_data.h" -#include "xls/dslx/interp_bindings.h" #include "xls/dslx/interp_value.h" #include "xls/dslx/parse_and_typecheck.h" -#include "xls/dslx/type_system/deduce_ctx.h" #include "xls/dslx/type_system/type_info.h" +#include "xls/dslx/type_system_v2/type_inference_error_handler.h" #include "xls/dslx/warning_collector.h" #include "xls/ir/bits.h" #include "xls/ir/format_preference.h" @@ -143,14 +142,14 @@ class DslxBuilder { // operands). absl::StatusOr CastToInferredVastType( verilog::Expression* vast_expr, dslx::Expr* expr, - bool cast_enum_to_builtin = false); + bool force_cast_user_defined = false); // Returns `expr` casted to the equivalent of the specified `vast_type`. If // `cast_enum_to_builtin` is true, then the corresponding DSLX built-in type // will be used for any VAST enum type. absl::StatusOr Cast(verilog::DataType* vast_type, dslx::Expr* expr, - bool cast_enum_to_builtin = false); + bool force_cast_user_defined = false); dslx::Unop* HandleUnaryOperator(const dslx::Span& span, dslx::UnopKind unop_kind, dslx::Expr* arg); @@ -183,8 +182,6 @@ class DslxBuilder { absl::StatusOr FormatModule(); dslx::ImportData& import_data() { return import_data_; } - dslx::DeduceCtx& deduce_ctx() { return deduce_ctx_; } - dslx::TypeInfo& type_info() { return *type_info_; } dslx::Module& module() { return module_; } dslx::FileTable& file_table() { return import_data_.file_table(); } @@ -206,7 +203,8 @@ class DslxBuilder { absl::StatusOr RoundTrip( const dslx::Module& module, std::string_view path, - dslx::ImportData& import_data); + dslx::ImportData& import_data, + dslx::TypeInferenceErrorHandler error_handler = nullptr); dslx::ImportData CreateImportData(); @@ -219,9 +217,6 @@ class DslxBuilder { dslx::WarningCollector warnings_; dslx::TypeInfo* const type_info_; - dslx::DeduceCtx deduce_ctx_; - dslx::InterpBindings bindings_; - const absl::flat_hash_map& vast_type_map_; diff --git a/xls/codegen/vast/dslx_type_fixer.cc b/xls/codegen/vast/dslx_type_fixer.cc index 21d8e0ea38..75fd4095a7 100644 --- a/xls/codegen/vast/dslx_type_fixer.cc +++ b/xls/codegen/vast/dslx_type_fixer.cc @@ -14,41 +14,217 @@ #include "xls/codegen/vast/dslx_type_fixer.h" +#include #include #include #include "absl/container/flat_hash_map.h" +#include "absl/container/flat_hash_set.h" #include "absl/status/statusor.h" +#include "absl/types/span.h" #include "xls/common/casts.h" +#include "xls/common/status/ret_check.h" #include "xls/common/status/status_macros.h" +#include "xls/dslx/errors.h" #include "xls/dslx/frontend/ast.h" #include "xls/dslx/frontend/ast_cloner.h" #include "xls/dslx/frontend/ast_node.h" +#include "xls/dslx/frontend/pos.h" +#include "xls/dslx/import_data.h" #include "xls/dslx/type_system/type.h" #include "xls/dslx/type_system/type_info.h" +#include "xls/dslx/type_system_v2/inference_table.h" +#include "xls/dslx/type_system_v2/type_annotation_utils.h" +#include "xls/dslx/type_system_v2/type_inference_error_handler.h" namespace xls { namespace dslx { class DslxTypeFixerImpl : public DslxTypeFixer { public: - CloneReplacer GetReplacer(const TypeInfo* ti) final { - return [this, ti](const AstNode* node, Module* target_module, - const absl::flat_hash_map&) + DslxTypeFixerImpl(Module& original_module) + : original_module_(original_module) {} + + CloneReplacer GetErrorFixReplacer(const TypeInfo* ti) final { + return GetErrorFixReplacerInternal(ti, /*root_node_to_preserve=*/nullptr); + } + + CloneReplacer GetSimplifyReplacer(const TypeInfo* ti) final { + return GetSimplifyReplacerInternal(ti, /*root_node_to_preserve=*/nullptr); + } + + TypeInferenceErrorHandler GetErrorHandler() { + return [this](const AstNode* node, absl::Span types) { + return HandleError(node, types); + }; + } + + private: + absl::StatusOr HandleError( + const AstNode* node, absl::Span types) { + // Verilog allows a signed index value but DSLX does not. We drop the + // signed annotation to fix this. + if (node->kind() == AstNodeKind::kNumber) { + const auto* number = down_cast(node); + if (number->parent() && number->type_annotation() && + dynamic_cast(node->parent())) { + literals_with_dropped_annotations_.insert(number); + return CreateU32Annotation(original_module_, number->span()); + } + } + + // Being asked to fix a function type means we are calling a builtin with + // the wrong type, in which case the callee type is the first thing we hit. + for (const CandidateType& candidate : types) { + if (candidate.type->IsFunction() && + candidate.flags.HasFlag(TypeInferenceFlag::kFormalFunctionType)) { + XLS_RET_CHECK(node->parent() != nullptr); + XLS_RET_CHECK(node->parent()->kind() == AstNodeKind::kInvocation); + XLS_RET_CHECK(candidate.annotation->annotation_kind() == + TypeAnnotationKind::kFunction); + invocations_with_fixed_callees_.emplace( + node->parent(), + down_cast(candidate.annotation)); + return candidate.annotation; + } + } + + // An argument motivating the fixing of a function type above will be + // encountered in a later invocation of this fixer. We detect such arguments + // here and insert casts for them. + if (node->parent() != nullptr && + node->parent()->kind() == AstNodeKind::kInvocation) { + const auto fixed_callee = + invocations_with_fixed_callees_.find(node->parent()); + if (fixed_callee != invocations_with_fixed_callees_.end()) { + absl::Span args = + down_cast(node->parent())->args(); + for (int i = 0; i < args.size(); i++) { + if (args[i] == node) { + XLS_RET_CHECK(i < fixed_callee->second->param_types().size()); + const TypeAnnotation* annotation = + fixed_callee->second->param_types()[i]; + added_casts_.emplace(node, annotation); + return annotation; + } + } + } + } + + // Verilog allows you to mix bits and enums, or different enums of + // equivalent size, in an expr. In DSLX this requires casting. What follows + // is a custom version of bits-like unification that takes this into + // account. + std::optional result; + int64_t result_size = 0; + bool result_signedness = false; + for (const CandidateType& candidate : types) { + XLS_ASSIGN_OR_RETURN(TypeDim next_size_dim, + candidate.type->GetTotalBitCount()); + XLS_ASSIGN_OR_RETURN(int64_t next_size, next_size_dim.GetAsInt64()); + XLS_ASSIGN_OR_RETURN(bool next_signedness, IsSigned(*candidate.type)); + bool formal_member_type = + candidate.flags.HasFlag(TypeInferenceFlag::kFormalMemberType); + if (result && result_signedness != next_signedness) { + if (formal_member_type) { + result = candidate; + result_size = next_size; + } else { + return absl::InvalidArgumentError( + absl::Substitute("Signed vs. unsigned mismatch: $0 vs. $1", + types[0].annotation->ToString(), + candidate.annotation->ToString())); + } + } else if (!result || next_size > result_size) { + result = candidate; + result_size = next_size; + result_signedness = next_signedness; + } + if (formal_member_type) { + break; + } + } + if (dynamic_cast(node)) { + added_casts_.emplace(node, result->annotation); + } + + return result->annotation; + } + + CloneReplacer GetErrorFixReplacerInternal( + const TypeInfo* ti, const AstNode* root_node_to_preserve) { + return [this, ti, root_node_to_preserve]( + const AstNode* node, Module* target_module, + const absl::flat_hash_map&) + -> absl::StatusOr> { + if (node == root_node_to_preserve) { + return std::nullopt; + } + + // Drop literal annotations that the error handler said we don't want. + if (literals_with_dropped_annotations_.contains(node)) { + const auto* number = down_cast(node); + return target_module->Make(number->span(), number->text(), + NumberKind::kOther, + /*type_annotation=*/nullptr); + } + + // Insert casts that the error handler said we needed. + const auto cast = added_casts_.find(node); + if (cast != added_casts_.end()) { + XLS_ASSIGN_OR_RETURN( + AstNode * clone, + CloneAst(node, GetErrorFixReplacerInternal(ti, node))); + return target_module->Make( + Span::None(), + const_cast(down_cast(clone)), + const_cast(cast->second)); + } + + return std::nullopt; + }; + } + + CloneReplacer GetSimplifyReplacerInternal( + const TypeInfo* ti, const AstNode* root_node_to_preserve) { + return [this, ti, root_node_to_preserve]( + const AstNode* node, Module* target_module, + const absl::flat_hash_map&) -> absl::StatusOr> { + if (node == root_node_to_preserve) { + return std::nullopt; + } + + // Drop the LHS annotation on a constant def if unnecessary. + if (node->kind() == AstNodeKind::kConstantDef) { + const std::optional lhs_type = + ti->GetItem(down_cast(node)->name_def()); + const std::optional rhs_type = + ti->GetItem(down_cast(node)->value()); + if (TypeEq(**lhs_type, **rhs_type)) { + XLS_ASSIGN_OR_RETURN( + AstNode * clone, + CloneAst(node, GetSimplifyReplacerInternal(ti, node))); + down_cast(clone)->set_type_annotation(nullptr); + return clone; + } + } + + // Drop dead casts. std::optional unwrapped = UnwrapDeadCast(ti, node); if (unwrapped.has_value()) { - XLS_ASSIGN_OR_RETURN(AstNode * clone_of_unwrapped, - CloneAst(*unwrapped, GetReplacer(ti))); + XLS_ASSIGN_OR_RETURN( + AstNode * clone_of_unwrapped, + CloneAst(*unwrapped, GetSimplifyReplacerInternal(ti, node))); Expr* result = down_cast(clone_of_unwrapped); result->set_in_parens(false); return result; } + return std::nullopt; }; } - private: std::optional UnwrapDeadCast(const TypeInfo* ti, const AstNode* node) { const std::optional casted = ti->GetItem(node); @@ -61,16 +237,8 @@ class DslxTypeFixerImpl : public DslxTypeFixer { const std::optional uncasted = ti->GetItem(expr); // See if we are OK to drop all layers of casts up to here. - if (**casted == **uncasted) { + if (TypeEq(**casted, **uncasted)) { unwrapped = expr; - } else { - std::optional casted_bits = GetBitsLike(**casted); - std::optional uncasted_bits = - GetBitsLike(**uncasted); - if (casted_bits.has_value() && uncasted_bits.has_value() && - *casted_bits == *uncasted_bits) { - unwrapped = expr; - } } node = expr; @@ -78,12 +246,28 @@ class DslxTypeFixerImpl : public DslxTypeFixer { return unwrapped; } + + bool TypeEq(const Type& a, const Type& b) { + if (a == b) { + return true; + } + std::optional a_bits = GetBitsLike(a); + std::optional b_bits = GetBitsLike(b); + return a_bits.has_value() && b_bits.has_value() && *a_bits == *b_bits; + } + + Module& original_module_; + absl::flat_hash_map added_casts_; + absl::flat_hash_set literals_with_dropped_annotations_; + absl::flat_hash_map + invocations_with_fixed_callees_; }; } // namespace dslx -std::unique_ptr CreateDslxTypeFixer() { - return std::make_unique(); +std::unique_ptr CreateDslxTypeFixer(dslx::Module& module, + const dslx::ImportData&) { + return std::make_unique(module); } } // namespace xls diff --git a/xls/codegen/vast/dslx_type_fixer.h b/xls/codegen/vast/dslx_type_fixer.h index 72096ec127..d05635a4f4 100644 --- a/xls/codegen/vast/dslx_type_fixer.h +++ b/xls/codegen/vast/dslx_type_fixer.h @@ -18,7 +18,10 @@ #include #include "xls/dslx/frontend/ast_cloner.h" +#include "xls/dslx/frontend/module.h" +#include "xls/dslx/import_data.h" #include "xls/dslx/type_system/type_info.h" +#include "xls/dslx/type_system_v2/type_inference_error_handler.h" namespace xls { @@ -31,13 +34,24 @@ class DslxTypeFixer { public: virtual ~DslxTypeFixer() = default; + // Returns an error handler that can fix type errors in the draft of the + // generated DSLX. After running type inference with this handler, the + // replacer returned by GetReplacer() injects any needed fixes that were + // discovered by the error handler. + virtual dslx::TypeInferenceErrorHandler GetErrorHandler() = 0; + + // Deals out a `CloneReplacer` that can be used to transform the draft AST + // into a fixed AST. + virtual dslx::CloneReplacer GetErrorFixReplacer(const dslx::TypeInfo* ti) = 0; + // Deals out a `CloneReplacer` that can be used to transform the draft AST // into a fixed AST. - virtual dslx::CloneReplacer GetReplacer(const dslx::TypeInfo* ti) = 0; + virtual dslx::CloneReplacer GetSimplifyReplacer(const dslx::TypeInfo* ti) = 0; }; // Deals out a type fixer. -std::unique_ptr CreateDslxTypeFixer(); +std::unique_ptr CreateDslxTypeFixer( + dslx::Module& original_module, const dslx::ImportData& import_data); } // namespace xls diff --git a/xls/codegen/vast/translate_vast_to_dslx.cc b/xls/codegen/vast/translate_vast_to_dslx.cc index b70a2f5b38..e86d8c0446 100644 --- a/xls/codegen/vast/translate_vast_to_dslx.cc +++ b/xls/codegen/vast/translate_vast_to_dslx.cc @@ -47,10 +47,7 @@ #include "xls/dslx/frontend/token.h" #include "xls/dslx/import_data.h" #include "xls/dslx/trait_visitor.h" -#include "xls/dslx/type_system/deduce_ctx.h" #include "xls/dslx/type_system/type.h" -#include "xls/dslx/type_system/type_info.h" -#include "xls/dslx/type_system/unwrap_meta_type.h" #include "xls/dslx/warning_collector.h" #include "xls/dslx/warning_kind.h" #include "xls/ir/source_location.h" @@ -145,9 +142,6 @@ class VastToDslxTranslator { TranslateType(def->data_type())); auto* param = module().Make(name_def, annot); name_def->set_definer(param); - XLS_ASSIGN_OR_RETURN(std::unique_ptr type, - deduce_ctx().Deduce(param)); - deduce_ctx().type_info()->SetItem(param->name_def(), *type); return param; } @@ -199,20 +193,6 @@ class VastToDslxTranslator { } if (!result) { - XLS_ASSIGN_OR_RETURN(auto lhs_type, deduce_ctx().Deduce(lhs)); - XLS_ASSIGN_OR_RETURN(auto rhs_type, deduce_ctx().Deduce(rhs)); - - if (*lhs_type != *rhs_type) { - // Generally, compatible but not identical types (like an enum value vs. - // generic value) will be coerced to be identical via VAST type - // inference and the insertion of casts based on that during translation - // of the LHS and RHS. In the event that this somehow gets it wrong or - // fails to appease DSLX deduction, we can't keep the expr. - return absl::InvalidArgumentError(absl::StrFormat( - "Cannot translate binop \"%s\": arguments have different types: " - "%s vs %s", - op->Emit(nullptr), lhs_type->ToString(), rhs_type->ToString())); - } // Note it uses the same span for the whole node and the operand; // the only time the operand span is used is for formatting, and // this node won't be used for formatting. @@ -227,48 +207,10 @@ class VastToDslxTranslator { } dslx::Expr* result = nullptr; for (verilog::Expression* next : concat->args()) { - // DSLX wants concat operands to be both arrays or both bits. To satisfy - // DSLX: - // - We must coerce enums to bits here. - // - Typedefs to bits will work fine with normal VAST-DSLX casting rules. - // - Structs will fail because you can't even cast them to bits. - // See https://github.com/google/xls/issues/1498. XLS_ASSIGN_OR_RETURN(dslx::Expr * dslx_expr, TranslateExpression(next)); XLS_ASSIGN_OR_RETURN(dslx_expr, dslx_builder_->CastToInferredVastType( - next, dslx_expr, /*cast_enum_to_builtin=*/true)); - XLS_ASSIGN_OR_RETURN(auto expr_type, deduce_ctx().Deduce(dslx_expr)); - bool is_signed = false; - int64_t size = 0; - auto* enum_type = dynamic_cast(expr_type.get()); - if (enum_type) { - is_signed = enum_type->is_signed(); - XLS_ASSIGN_OR_RETURN(size, enum_type->GetTotalBitCount()->GetAsInt64()); - } else { - dslx::BitsType* bits_type = - dynamic_cast(expr_type.get()); - if (bits_type == nullptr) { - return absl::InvalidArgumentError( - absl::StrCat("Cannot translate concat \"%s\": all arguments must " - "be bits-typed.", - concat->Emit(nullptr))); - } - is_signed = bits_type->is_signed(); - XLS_ASSIGN_OR_RETURN(size, bits_type->size().GetAsInt64()); - } - if (is_signed) { - dslx::Span span = dslx_expr->span(); - dslx::TypeAnnotation* annot = - module().Make( - span, dslx::BuiltinType::kBits, - module().GetOrCreateBuiltinNameDef(dslx::BuiltinType::kBits)); - auto* dim = module().Make( - span, absl::StrCat(size), dslx::NumberKind::kOther, nullptr); - annot = module().Make(dslx_expr->span(), - annot, dim); - dslx_expr = - module().Make(dslx_expr->span(), dslx_expr, annot); - } + next, dslx_expr, /*force_cast_enum=*/true)); dslx::Span span = CreateNodeSpan(concat); // Note it uses the same span for the whole node and the operand; // the only time the operand span is used is for formatting, and @@ -325,7 +267,6 @@ class VastToDslxTranslator { enum_def->set_extern_type_name(vast_name); XLS_RETURN_IF_ERROR( module().AddTop(enum_def, /*make_collision_error=*/nullptr)); - XLS_RETURN_IF_ERROR(deduce_ctx().Deduce(enum_def).status()); return enum_def; } @@ -483,25 +424,6 @@ class VastToDslxTranslator { number_value->SetTypeAnnotation(enum_type_annotation); } auto* expr = down_cast(constant_value); - XLS_ASSIGN_OR_RETURN(std::unique_ptr member_expr_type, - deduce_ctx().Deduce(expr)); - XLS_ASSIGN_OR_RETURN(std::unique_ptr enum_type, - deduce_ctx().Deduce(enum_type_annotation)); - XLS_ASSIGN_OR_RETURN( - enum_type, - UnwrapMetaType(std::move(enum_type), enum_type_annotation->span(), - "enum type", file_table())); - - XLS_ASSIGN_OR_RETURN(bool member_signedness, - IsSigned(*member_expr_type)); - XLS_ASSIGN_OR_RETURN(bool enum_signedness, IsSigned(*enum_type)); - - // We need a cast here, for example, if the member is an unsigned - // expression (e.g. a concatenation) but the enum type is signed. - if (member_signedness != enum_signedness) { - expr = module().Make(CreateNodeSpan(member), expr, - enum_type_annotation); - } members.push_back({name_def, expr}); } else { members.push_back( @@ -614,23 +536,7 @@ class VastToDslxTranslator { dslx::Import * std, dslx_builder_->GetOrImportModule(dslx::ImportTokens({"std"}))); - XLS_ASSIGN_OR_RETURN(std::unique_ptr ct, - deduce_ctx().Deduce(args[0])); - dslx::Span span = CreateNodeSpan(vast_call); - dslx::BitsType* bt = down_cast(ct.get()); - auto* ubits_type = module().Make( - span, dslx::BuiltinType::kUN, - module().GetOrCreateBuiltinNameDef(dslx::BuiltinType::kUN)); - - XLS_ASSIGN_OR_RETURN(int64_t bit_width, bt->size().GetAsInt64()); - auto* bits_size = module().Make( - span, absl::StrCat(bit_width), dslx::NumberKind::kOther, nullptr); - if (bt->is_signed()) { - auto* unsigned_type = module().Make( - span, ubits_type, bits_size); - args[0] = module().Make(span, args[0], unsigned_type); - } auto* name_ref = module().Make(span, "std", &std->name_def()); auto* fn_ref = module().Make(span, name_ref, "clog2"); @@ -638,7 +544,6 @@ class VastToDslxTranslator { XLS_ASSIGN_OR_RETURN( result, dslx_builder_->CastToInferredVastType(vast_call, result)); - XLS_RETURN_IF_ERROR(deduce_ctx().Deduce(result).status()); return result; } return absl::InvalidArgumentError( @@ -771,8 +676,6 @@ class VastToDslxTranslator { // need to use in any depth. dslx::Module& module() { return dslx_builder_->module(); } - dslx::DeduceCtx& deduce_ctx() { return dslx_builder_->deduce_ctx(); } - dslx::TypeInfo& type_info() { return dslx_builder_->type_info(); } dslx::ImportData& import_data() { return dslx_builder_->import_data(); } dslx::FileTable& file_table() { return import_data().file_table(); } diff --git a/xls/codegen/vast/translate_vast_to_dslx_test.cc b/xls/codegen/vast/translate_vast_to_dslx_test.cc index e902a9d84b..605c023fe2 100644 --- a/xls/codegen/vast/translate_vast_to_dslx_test.cc +++ b/xls/codegen/vast/translate_vast_to_dslx_test.cc @@ -863,9 +863,9 @@ pub const var_2 = u16:0xbeef; import std; -pub const var_3 = std::clog2(s32:1024 as uN[32]) as s32; // s32:10 -pub const var_4 = std::clog2(s32:1025 as uN[32]) as s32; // s32:11 -pub const var_5 = std::clog2(s32:2048 as uN[32]) as s32; // s32:11 +pub const var_3 = std::clog2(s32:1024 as bits[32]) as s32; // s32:10 +pub const var_4 = std::clog2(s32:1025 as bits[32]) as s32; // s32:11 +pub const var_5 = std::clog2(s32:2048 as bits[32]) as s32; // s32:11 pub const var_6 = u16:0xbeef; pub const var_7 = -s32:0xbeef as u16; )"; diff --git a/xls/dslx/frontend/ast.h b/xls/dslx/frontend/ast.h index 039a886f8a..9b7ce1fc4c 100644 --- a/xls/dslx/frontend/ast.h +++ b/xls/dslx/frontend/ast.h @@ -3910,6 +3910,10 @@ class ConstantDef : public AstNode { const std::string& identifier() const { return name_def_->identifier(); } NameDef* name_def() const { return name_def_; } TypeAnnotation* type_annotation() const { return type_annotation_; } + void set_type_annotation(TypeAnnotation* annotation) { + type_annotation_ = annotation; + } + Expr* value() const { return value_; } const Span& span() const { return span_; } std::optional GetSpan() const override { return span_; } diff --git a/xls/dslx/type_system_v2/inference_table.cc b/xls/dslx/type_system_v2/inference_table.cc index b612a0a648..6a295f5d1e 100644 --- a/xls/dslx/type_system_v2/inference_table.cc +++ b/xls/dslx/type_system_v2/inference_table.cc @@ -68,6 +68,8 @@ const TypeInferenceFlag TypeInferenceFlag::kBitsLikeType(1 << 4, "bits-like-type"); const TypeInferenceFlag TypeInferenceFlag::kFormalMemberType( 1 << 5, "formal-member-type"); +const TypeInferenceFlag TypeInferenceFlag::kFormalFunctionType( + 1 << 6, "formal-function-type"); namespace { diff --git a/xls/dslx/type_system_v2/inference_table.h b/xls/dslx/type_system_v2/inference_table.h index c112ac7d8a..f79a204789 100644 --- a/xls/dslx/type_system_v2/inference_table.h +++ b/xls/dslx/type_system_v2/inference_table.h @@ -96,6 +96,9 @@ class TypeInferenceFlag { // want to solve for). static const TypeInferenceFlag kFormalMemberType; + // Indicates the formal type of a function. + static const TypeInferenceFlag kFormalFunctionType; + bool HasFlag(const TypeInferenceFlag& value) const { if (value.flags_ == kNone.flags_) { return flags_ == kNone.flags_; @@ -110,8 +113,9 @@ class TypeInferenceFlag { // 1. Zero or one flag is set (except those we specifically allow combining // with others). // 2. Both kMinSize and kHasPrefix are set. - const uint8_t combo_allowed_flags = - kSliceContainerSize.flags_ | kFormalMemberType.flags_; + const uint8_t combo_allowed_flags = kSliceContainerSize.flags_ | + kFormalMemberType.flags_ | + kFormalFunctionType.flags_; CHECK((flags_ & (flags_ - 1) & ~combo_allowed_flags) == 0 || flags_ == (kMinSize.flags_ | kHasPrefix.flags_)); } diff --git a/xls/dslx/type_system_v2/inference_table_converter_impl.cc b/xls/dslx/type_system_v2/inference_table_converter_impl.cc index 2a2e9c1f46..5a8598cf96 100644 --- a/xls/dslx/type_system_v2/inference_table_converter_impl.cc +++ b/xls/dslx/type_system_v2/inference_table_converter_impl.cc @@ -529,6 +529,8 @@ class InferenceTableConverterImpl : public InferenceTableConverter, ft_annotation = ExpandVarargs(module_, ft_annotation, actual_args.size()); } + table_.SetAnnotationFlag(ft_annotation, + TypeInferenceFlag::kFormalFunctionType); XLS_RETURN_IF_ERROR( table_.SetTypeAnnotation(invocation->callee(), ft_annotation)); } @@ -715,6 +717,8 @@ class InferenceTableConverterImpl : public InferenceTableConverter, CreateFunctionTypeAnnotation(module_, *function), value_exprs, invocation_context->self_type(), /*clone_if_no_parametrics=*/true)); + table_.SetAnnotationFlag(parametric_free_type, + TypeInferenceFlag::kFormalFunctionType); XLS_RETURN_IF_ERROR( ConvertSubtree(parametric_free_type, caller, caller_context)); @@ -739,7 +743,8 @@ class InferenceTableConverterImpl : public InferenceTableConverter, parametric_free_function_type = down_cast(parametric_free_type); - + table_.SetAnnotationFlag(parametric_free_function_type, + TypeInferenceFlag::kFormalFunctionType); invocation_context->SetParametricFreeFunctionType( parametric_free_function_type); } diff --git a/xls/dslx/type_system_v2/populate_table_visitor.cc b/xls/dslx/type_system_v2/populate_table_visitor.cc index e9da5755ba..a724dc3c12 100644 --- a/xls/dslx/type_system_v2/populate_table_visitor.cc +++ b/xls/dslx/type_system_v2/populate_table_visitor.cc @@ -1588,6 +1588,8 @@ class PopulateInferenceTableVisitor : public PopulateTableVisitor, std::get(node->explicit_parametrics().front()); auto* fn_type = module_.Make( /*param_types=*/std::vector{}, ret_type); + table_.SetAnnotationFlag(fn_type, + TypeInferenceFlag::kFormalFunctionType); XLS_RETURN_IF_ERROR(table_.SetTypeAnnotation(node, fn_type)); return DefaultHandler(node); } @@ -1654,6 +1656,8 @@ class PopulateInferenceTableVisitor : public PopulateTableVisitor, if (node->type_annotation()) { XLS_RETURN_IF_ERROR(table_.SetTypeAnnotation(node->values()[0].value, node->type_annotation())); + table_.SetAnnotationFlag(node->type_annotation(), + TypeInferenceFlag::kFormalMemberType); } }