Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 13 additions & 10 deletions xls/codegen/vast/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -161,14 +161,24 @@ cc_library(
hdrs = ["dslx_type_fixer.h"],
deps = [
"//xls/common:casts",
"//xls/common/status:ret_check",
"//xls/common/status:status_macros",
"//xls/dslx:errors",
"//xls/dslx:import_data",
"//xls/dslx/frontend:ast",
"//xls/dslx/frontend:ast_cloner",
"//xls/dslx/frontend:ast_node",
"//xls/dslx/frontend:module",
"//xls/dslx/frontend:pos",
"//xls/dslx/type_system:type",
"//xls/dslx/type_system:type_info",
"//xls/dslx/type_system_v2:inference_table",
"//xls/dslx/type_system_v2:type_annotation_utils",
"//xls/dslx/type_system_v2:type_inference_error_handler",
"@com_google_absl//absl/container:flat_hash_map",
"@com_google_absl//absl/container:flat_hash_set",
"@com_google_absl//absl/status:statusor",
"@com_google_absl//absl/types:span",
],
)

Expand All @@ -187,7 +197,6 @@ cc_library(
"//xls/dslx:create_import_data",
"//xls/dslx:import_data",
"//xls/dslx:import_routines",
"//xls/dslx:interp_bindings",
"//xls/dslx:interp_value",
"//xls/dslx:parse_and_typecheck",
"//xls/dslx:virtualizable_file_system",
Expand All @@ -205,13 +214,10 @@ cc_library(
"//xls/dslx/frontend:parser",
"//xls/dslx/frontend:pos",
"//xls/dslx/frontend:scanner",
"//xls/dslx/type_system:deduce",
"//xls/dslx/type_system:deduce_ctx",
"//xls/dslx/type_system:type",
"//xls/dslx/ir_convert:convert_options",
"//xls/dslx/type_system:type_info",
"//xls/dslx/type_system:typecheck_function",
"//xls/dslx/type_system:typecheck_invocation",
"//xls/dslx/type_system:typecheck_module",
"//xls/dslx/type_system_v2:type_inference_error_handler",
"//xls/dslx/type_system_v2:typecheck_module_v2",
"//xls/ir:bits",
"//xls/ir:bits_ops",
"//xls/ir:format_preference",
Expand Down Expand Up @@ -247,10 +253,7 @@ cc_library(
"//xls/dslx/frontend:module",
"//xls/dslx/frontend:pos",
"//xls/dslx/frontend:token",
"//xls/dslx/type_system:deduce_ctx",
"//xls/dslx/type_system:type",
"//xls/dslx/type_system:type_info",
"//xls/dslx/type_system:unwrap_meta_type",
"//xls/ir:source_location",
"@com_google_absl//absl/container:flat_hash_map",
"@com_google_absl//absl/container:flat_hash_set",
Expand Down
138 changes: 58 additions & 80 deletions xls/codegen/vast/dslx_builder.cc
Original file line number Diff line number Diff line change
Expand Up @@ -55,14 +55,11 @@
#include "xls/dslx/import_data.h"
#include "xls/dslx/import_routines.h"
#include "xls/dslx/interp_value.h"
#include "xls/dslx/ir_convert/convert_options.h"
#include "xls/dslx/parse_and_typecheck.h"
#include "xls/dslx/type_system/deduce.h"
#include "xls/dslx/type_system/deduce_ctx.h"
#include "xls/dslx/type_system/type.h"
#include "xls/dslx/type_system/type_info.h"
#include "xls/dslx/type_system/typecheck_function.h"
#include "xls/dslx/type_system/typecheck_invocation.h"
#include "xls/dslx/type_system/typecheck_module.h"
#include "xls/dslx/type_system_v2/type_inference_error_handler.h"
#include "xls/dslx/type_system_v2/typecheck_module_v2.h"
#include "xls/dslx/virtualizable_file_system.h"
#include "xls/dslx/warning_collector.h"
#include "xls/dslx/warning_kind.h"
Expand Down Expand Up @@ -293,13 +290,7 @@ DslxBuilder::DslxBuilder(
resolver_(resolver),
warnings_(warnings),
type_info_(GetTypeInfoOrDie(import_data_, &module_)),
deduce_ctx_(type_info_, &module_, dslx::Deduce, &dslx::TypecheckFunction,
/*typecheck_module=*/nullptr, dslx::TypecheckInvocation,
&import_data_, &warnings_,
/*parent=*/nullptr),
vast_type_map_(vast_type_map) {
deduce_ctx_.fn_stack().push_back(dslx::FnStackEntry::MakeTop(&module_));
}
vast_type_map_(vast_type_map) {}

absl::StatusOr<dslx::Expr*> DslxBuilder::MakeNameRefAndCast(
verilog::Expression* expr, const dslx::Span& span, std::string_view name,
Expand Down Expand Up @@ -343,31 +334,23 @@ absl::StatusOr<dslx::ConstantDef*> DslxBuilder::HandleConstantDecl(
verilog::Parameter* parameter, std::string_view name, dslx::Expr* expr) {
auto* name_def =
resolver_->MakeNameDef(*this, span, name, parameter, vast_module);
auto* constant_def = module().Make<dslx::ConstantDef>(
span, name_def, /*type_annotation=*/nullptr, expr,
/*is_public=*/true);

dslx::TypeAnnotation* type_annotation = nullptr;
if (parameter->def() && parameter->def()->data_type() &&
parameter->def()->data_type()->IsUserDefined()) {
XLS_ASSIGN_OR_RETURN(
type_annotation,
VastTypeToDslxTypeForCast(dslx::Span(), parameter->def()->data_type()));
}

auto* constant_def =
module().Make<dslx::ConstantDef>(span, name_def, type_annotation, expr,
/*is_public=*/true);
name_def->set_definer(constant_def);

XLS_RETURN_IF_ERROR(
module().AddTop(constant_def, /*make_collision_error=*/nullptr));

// Note: historically, the propagation of errors was incorrectly dropped by
// the caller, hiding the fact that the logic from here down is a nice-to-have
// and historically fails e.g. in ConcatOfEnums.
auto deduced_type = deduce_ctx().Deduce(expr);
absl::StatusOr<dslx::InterpValue> value =
InterpretExpr(import_data(), type_info(), expr);
if (deduced_type.ok() && value.ok()) {
bindings_.AddValue(name_def->identifier(), *value);
import_data_.GetOrCreateTopLevelBindings(&module()).AddValue(
name_def->identifier(), *value);
}

auto def_deduced_type = deduce_ctx().Deduce(constant_def);
if (!def_deduced_type.ok()) {
VLOG(2) << "Failed to deduce constant def type: "
<< def_deduced_type.status();
}
// Add a comment with the value, if it is not obvious and can be folded.
if (parameter->rhs() != nullptr && !parameter->rhs()->IsLiteral()) {
absl::StatusOr<int64_t> folded_value =
Expand Down Expand Up @@ -503,9 +486,9 @@ absl::StatusOr<dslx::Import*> DslxBuilder::GetOrImportModule(
dslx::DoImport(
[this](std::unique_ptr<dslx::Module> module,
std::filesystem::path path) {
return dslx::TypecheckModule(std::move(module),
path, &import_data_,
&warnings_);
return dslx::TypecheckModuleV2(
std::move(module), path, &import_data_,
&warnings_, nullptr, nullptr);
},
import_tokens, &import_data_, dslx::Span::Fake(),
import_data_.vfs()));
Expand All @@ -518,11 +501,9 @@ absl::StatusOr<dslx::Import*> DslxBuilder::GetOrImportModule(
name_def->set_definer(import);
XLS_RETURN_IF_ERROR(
module().AddTop(import, /*make_collision_error=*/nullptr));
deduce_ctx().type_info()->AddImport(import, &mod_info->module(),
mod_info->type_info());
type_info_->AddImport(import, &mod_info->module(), mod_info->type_info());
import_data_.GetOrCreateTopLevelBindings(&module()).AddModule(
tail, &mod_info->module());
bindings_.AddModule(tail, &mod_info->module());
}

std::optional<dslx::ModuleMember*> std_or = module().FindMemberWithName(tail);
Expand All @@ -541,53 +522,35 @@ absl::StatusOr<dslx::Import*> DslxBuilder::GetOrImportModule(

absl::StatusOr<dslx::Expr*> DslxBuilder::CastToInferredVastType(
verilog::Expression* vast_expr, dslx::Expr* expr,
bool cast_enum_to_builtin) {
bool force_cast_user_defined) {
XLS_ASSIGN_OR_RETURN(verilog::DataType * vast_type,
GetVastDataType(vast_expr));
return Cast(vast_type, expr, cast_enum_to_builtin);
return Cast(vast_type, expr, force_cast_user_defined);
}

absl::StatusOr<dslx::Expr*> DslxBuilder::Cast(verilog::DataType* vast_type,
dslx::Expr* expr,
bool cast_enum_to_builtin) {
bool force_cast_user_defined) {
if (!vast_type->FlatBitCountAsInt64().ok()) {
VLOG(2) << "Warning: cannot insert a cast of expr: " << expr->ToString()
<< " to type: " << vast_type->Emit(nullptr)
<< " because the VAST type does not have a bit count; this is "
"generally OK if the width is not statically computed.";
return expr;
}
absl::StatusOr<std::unique_ptr<dslx::Type>> deduced_dslx_type =
deduce_ctx().Deduce(expr);
if (!deduced_dslx_type.ok()) {
VLOG(2) << "Warning: Pessimistically inserting a cast of expr: "
<< expr->ToString() << " to type: " << vast_type->Emit(nullptr)
<< " because the DSLX type cannot be deduced. This may happen if a "
"parameter value has a system function call.";
return CastInternal(vast_type, expr);
}
if ((*deduced_dslx_type)->HasEnum()) {
if (cast_enum_to_builtin || !vast_type->IsUserDefined()) {
// DSLX considers enum and values to mismatch operands of an equivalent
// built-in type. VAST type inference will say in that case that they are
// both the generic type, and we need the cast to make DSLX comply.
return CastInternal(vast_type, expr, cast_enum_to_builtin);
}
XLS_ASSIGN_OR_RETURN(dslx::TypeDim deduced_dslx_dim,
(*deduced_dslx_type)->GetTotalBitCount());
XLS_ASSIGN_OR_RETURN(int64_t deduced_dslx_bit_count,
deduced_dslx_dim.GetAsInt64());
XLS_ASSIGN_OR_RETURN(int64_t verilog_bit_count,
vast_type->FlatBitCountAsInt64());
if (deduced_dslx_bit_count != verilog_bit_count) {
return Cast(vast_type, expr);
if (vast_type->IsUserDefined()) {
const auto* typedef_type =
dynamic_cast<const verilog::TypedefType*>(vast_type);
// Integer type aliases should really be casted, but enums should not.
if (typedef_type &&
!typedef_type->type_def()->data_type()->IsUserDefined()) {
return CastInternal(vast_type, expr);
}
absl::StatusOr<bool> deduced_dslx_signed =
dslx::IsSigned(**deduced_dslx_type);
if (!deduced_dslx_signed.ok() ||
*deduced_dslx_signed != vast_type->is_signed()) {
return Cast(vast_type, expr);

if (force_cast_user_defined) {
return CastInternal(vast_type, expr);
}

return expr;
}

Expand Down Expand Up @@ -687,14 +650,18 @@ absl::StatusOr<verilog::DataType*> DslxBuilder::GetVastDataType(

absl::StatusOr<dslx::TypecheckedModule> DslxBuilder::RoundTrip(
const dslx::Module& module, std::string_view path,
dslx::ImportData& import_data) {
dslx::ImportData& import_data,
dslx::TypeInferenceErrorHandler error_handler) {
const std::string text = module.ToString();
dslx::Fileno fileno = import_data.file_table().GetOrCreate(path);
dslx::Scanner scanner(import_data.file_table(), fileno, text);
dslx::Parser parser(module_.name(), &scanner);
XLS_ASSIGN_OR_RETURN(
dslx::TypecheckedModule parsed_module,
ParseAndTypecheck(text, path, module_.name(), &import_data),
ParseAndTypecheck(text, path, module_.name(), &import_data,
/*comments=*/nullptr,
/*force_version=*/dslx::TypeInferenceVersion::kVersion2,
/*options=*/dslx::ConvertOptions{}, error_handler),
_ << "Failed to parse and typecheck module:\n"
<< text);
return parsed_module;
Expand Down Expand Up @@ -726,22 +693,33 @@ absl::StatusOr<std::string> DslxBuilder::FormatModule() {
// inference, like removal of dead casts.
dslx::ImportData initial_import_data = CreateImportData();
const std::string file_name = module_.name() + ".x";
std::unique_ptr<DslxTypeFixer> fixer =
CreateDslxTypeFixer(module_, import_data_);
XLS_ASSIGN_OR_RETURN(dslx::TypecheckedModule initial_module,
RoundTrip(module_, file_name, initial_import_data));
std::unique_ptr<DslxTypeFixer> fixer = CreateDslxTypeFixer();
RoundTrip(module_, file_name, initial_import_data,
fixer->GetErrorHandler()));
XLS_ASSIGN_OR_RETURN(
std::unique_ptr<dslx::Module> module_with_stripped_casts,
std::unique_ptr<dslx::Module> module_with_errors_fixed,
CloneModule(*initial_module.module,
fixer->GetReplacer(initial_module.type_info)));
fixer->GetErrorFixReplacer(initial_module.type_info)));

dslx::ImportData fix_import_data = CreateImportData();
XLS_ASSIGN_OR_RETURN(
dslx::TypecheckedModule fixed_and_typechecked_module,
RoundTrip(*module_with_errors_fixed, file_name, fix_import_data));
XLS_ASSIGN_OR_RETURN(
std::unique_ptr<dslx::Module> simplified_module,
CloneModule(
*fixed_and_typechecked_module.module,
fixer->GetSimplifyReplacer(fixed_and_typechecked_module.type_info)));

// We now need to round-trip the module to text and back to AST, without the
// comments, in order for the nodes to get spans accurately representing the
// DSLX as opposed to the source Verilog. We then position the comments
// relative to the appropriate spans.
dslx::ImportData import_data = CreateImportData();
XLS_ASSIGN_OR_RETURN(
dslx::TypecheckedModule parsed_module,
RoundTrip(*module_with_stripped_casts, file_name, import_data));
XLS_ASSIGN_OR_RETURN(dslx::TypecheckedModule parsed_module,
RoundTrip(*simplified_module, file_name, import_data));

std::vector<dslx::CommentData> comment_data;
for (const auto& [type_name, comment] : type_def_comments_) {
Expand Down
15 changes: 5 additions & 10 deletions xls/codegen/vast/dslx_builder.h
Original file line number Diff line number Diff line change
Expand Up @@ -28,11 +28,10 @@
#include "xls/dslx/frontend/module.h"
#include "xls/dslx/frontend/pos.h"
#include "xls/dslx/import_data.h"
#include "xls/dslx/interp_bindings.h"
#include "xls/dslx/interp_value.h"
#include "xls/dslx/parse_and_typecheck.h"
#include "xls/dslx/type_system/deduce_ctx.h"
#include "xls/dslx/type_system/type_info.h"
#include "xls/dslx/type_system_v2/type_inference_error_handler.h"
#include "xls/dslx/warning_collector.h"
#include "xls/ir/bits.h"
#include "xls/ir/format_preference.h"
Expand Down Expand Up @@ -143,14 +142,14 @@ class DslxBuilder {
// operands).
absl::StatusOr<dslx::Expr*> CastToInferredVastType(
verilog::Expression* vast_expr, dslx::Expr* expr,
bool cast_enum_to_builtin = false);
bool force_cast_user_defined = false);

// Returns `expr` casted to the equivalent of the specified `vast_type`. If
// `cast_enum_to_builtin` is true, then the corresponding DSLX built-in type
// will be used for any VAST enum type.
absl::StatusOr<dslx::Expr*> Cast(verilog::DataType* vast_type,
dslx::Expr* expr,
bool cast_enum_to_builtin = false);
bool force_cast_user_defined = false);

dslx::Unop* HandleUnaryOperator(const dslx::Span& span,
dslx::UnopKind unop_kind, dslx::Expr* arg);
Expand Down Expand Up @@ -183,8 +182,6 @@ class DslxBuilder {
absl::StatusOr<std::string> FormatModule();

dslx::ImportData& import_data() { return import_data_; }
dslx::DeduceCtx& deduce_ctx() { return deduce_ctx_; }
dslx::TypeInfo& type_info() { return *type_info_; }
dslx::Module& module() { return module_; }
dslx::FileTable& file_table() { return import_data_.file_table(); }

Expand All @@ -206,7 +203,8 @@ class DslxBuilder {

absl::StatusOr<dslx::TypecheckedModule> RoundTrip(
const dslx::Module& module, std::string_view path,
dslx::ImportData& import_data);
dslx::ImportData& import_data,
dslx::TypeInferenceErrorHandler error_handler = nullptr);

dslx::ImportData CreateImportData();

Expand All @@ -219,9 +217,6 @@ class DslxBuilder {
dslx::WarningCollector warnings_;
dslx::TypeInfo* const type_info_;

dslx::DeduceCtx deduce_ctx_;
dslx::InterpBindings bindings_;

const absl::flat_hash_map<verilog::Expression*, verilog::DataType*>&
vast_type_map_;

Expand Down
Loading
Loading