Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions xls/dslx/ir_convert/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -508,6 +508,7 @@ cc_library(
"//xls/dslx/frontend:module",
"//xls/dslx/frontend:proc_id",
"//xls/dslx/type_system:parametric_env",
"//xls/dslx/type_system:type",
"//xls/dslx/type_system:type_info",
"//xls/public:status_macros",
"@com_google_absl//absl/log",
Expand Down
9 changes: 5 additions & 4 deletions xls/dslx/ir_convert/conversion_record.cc
Original file line number Diff line number Diff line change
Expand Up @@ -102,10 +102,11 @@ std::string ConversionRecord::ToString() const {
config = config_record_->ToString();
}
return absl::StrFormat(
"ConversionRecord{m=%s, f=%s, top=%s, pid=%s, parametric_env=%s, "
"type_info=%p, config=%s}",
module_->name(), f_->identifier(), is_top_ ? "true" : "false", proc_id,
parametric_env_.ToString(), type_info_, config);
"ConversionRecord{m=%s, invocation=%p, f=%s, top=%s, pid=%s, "
"parametric_env=%s, type_info=%p, config=%s}",
module_->name(), invocation_, f_->identifier(),
is_top_ ? "true" : "false", proc_id, parametric_env_.ToString(),
type_info_, config);
}

} // namespace xls::dslx
114 changes: 100 additions & 14 deletions xls/dslx/ir_convert/get_conversion_records.cc
Original file line number Diff line number Diff line change
Expand Up @@ -70,12 +70,6 @@ class ConversionRecordVisitor : public AstNodeVisitorWithDefault {
} else {
VLOG(5) << "Processing fn " << f->ToString();
}
// TODO: davidplass - change this to gather invocations from *spawns*
// instead of *functions*. This will allow function_converter to emit
// multiple IR procs with the same parametrics but different config
// parameters. Then, HandleFunction would only have to handle procs
// that are not spawned explicitly, like test or top procs.

// Note, it's possible there is no config invocation if it's a
// top proc or some other reason.
std::unique_ptr<ConversionRecord> config_record;
Expand Down Expand Up @@ -119,21 +113,60 @@ class ConversionRecordVisitor : public AstNodeVisitorWithDefault {
return cr;
}

// TODO: davidplass - whenever a parametric proc spawns another parametric
// proc, for every unique invocation of the parent proc we have to add the
// child proc to the list as well, recursively.
absl::Status HandleSpawn(const Spawn* spawn) override {
Invocation* invocation = spawn->config();
auto root_invocation_data = type_info_->GetRootInvocationData(invocation);
XLS_RET_CHECK(root_invocation_data.has_value());
const InvocationData* invocation_data = *root_invocation_data;
const Function* config_fn = invocation_data->callee();
XLS_RET_CHECK(config_fn->proc().has_value());
Proc* proc = config_fn->proc().value();
const Function* next_fn = &proc->next();

std::optional<ProcId> proc_id = proc_id_factory_.CreateProcId(
/*parent=*/std::nullopt, proc,
/*count_as_new_instance=*/false);

std::vector<InvocationCalleeData> calls =
type_info_->GetUniqueInvocationCalleeData(next_fn);
// Look at these calls and find the one with the
// (caller) parametric env that matches the invocation_datum.
for (auto& callee_data : calls) {
for (auto& [caller_bindings, invocation_datum] :
invocation_data->env_to_callee_data()) {
if (callee_data.caller_bindings == caller_bindings) {
XLS_ASSIGN_OR_RETURN(
ConversionRecord cr,
InvocationToConversionRecord(
next_fn, callee_data.invocation,
callee_data.derived_type_info,
invocation_datum.callee_bindings,
invocation_datum.caller_bindings,
// Since this proc is being spawned, it's certainly not top.
/* is_top= */ false, proc_id));
records_.push_back(std::move(cr));
break;
}
}
}
return absl::OkStatus();
}

absl::Status AddFunction(const Function* f) {
std::optional<ProcId> proc_id;
if (f->proc().has_value()) {
proc_id = proc_id_factory_.CreateProcId(
/*parent=*/std::nullopt, f->proc().value(),
// TODO: davidplass - For parametric procs we have to decide if this
// is a new instance if it has been called with the same parametrics
// before. Otherwise it needs a new procid.
/*count_as_new_instance=*/false);
}
std::vector<InvocationCalleeData> calls =
type_info_->GetUniqueInvocationCalleeData(f);
if (f->IsParametric() && calls.empty()) {
VLOG(5) << "No calls to parametric proc " << f->name_def()->ToString();
return absl::OkStatus();
return DefaultHandler(f);
}
for (auto& callee_data : calls) {
XLS_ASSIGN_OR_RETURN(
Expand Down Expand Up @@ -184,7 +217,19 @@ class ConversionRecordVisitor : public AstNodeVisitorWithDefault {
}

absl::Status HandleProc(const Proc* p) override {
return AddFunction(&p->next());
// Process config so it can use the spawns to identify the dependent procs
// to convert.
XLS_RETURN_IF_ERROR(DefaultHandler(&p->config()));

if (top_ == &p->next() || !p->IsParametric()) {
// "top" procs won't have spawns referencing them so they won't
// otherwise be added to the list, so we have to manually do it here.

// Similarly, if a proc is not parametric, while it might not have any
// spawns, we still want to convert it.
return AddFunction(&p->next());
}
return absl::OkStatus();
}

absl::Status HandleTestProc(const TestProc* tp) override {
Expand Down Expand Up @@ -215,6 +260,39 @@ class ConversionRecordVisitor : public AstNodeVisitorWithDefault {

} // namespace

// This function removes duplicate conversion records from a list.
// The input list is modified.
void RemoveFunctionDuplicates(std::vector<ConversionRecord>& ready) {
for (auto iter_func = ready.begin(); iter_func != ready.end(); iter_func++) {
const ConversionRecord& function_cr = *iter_func;
for (auto iter_subject = iter_func + 1; iter_subject != ready.end();) {
const ConversionRecord& subject_cr = *iter_subject;

if (function_cr.f() == subject_cr.f()) {
bool either_is_parametric =
function_cr.f()->IsParametric() || subject_cr.f()->IsParametric();
// If neither are parametric, then function identity comparison is
// a sufficient test to eliminate detected duplicates.
if (!either_is_parametric) {
iter_subject = ready.erase(iter_subject);
continue;
}

// If the functions are the same and they have the same parametric
// environment, eliminate any duplicates.
bool both_are_parametric =
function_cr.f()->IsParametric() && subject_cr.f()->IsParametric();
if (both_are_parametric &&
function_cr.parametric_env() == subject_cr.parametric_env()) {
iter_subject = ready.erase(iter_subject);
continue;
}
}
iter_subject++;
}
}
}

absl::StatusOr<std::vector<ConversionRecord>> GetConversionRecords(
Module* module, TypeInfo* type_info, bool include_tests) {
ProcIdFactory proc_id_factory;
Expand All @@ -224,7 +302,9 @@ absl::StatusOr<std::vector<ConversionRecord>> GetConversionRecords(
proc_id_factory, /*top=*/nullptr);
XLS_RETURN_IF_ERROR(module->Accept(&visitor));

return visitor.records();
std::vector<ConversionRecord> records = visitor.records();
RemoveFunctionDuplicates(records);
return records;
}

absl::StatusOr<std::vector<ConversionRecord>> GetConversionRecordsForEntry(
Expand All @@ -238,7 +318,10 @@ absl::StatusOr<std::vector<ConversionRecord>> GetConversionRecordsForEntry(
ConversionRecordVisitor visitor(m, type_info, /*include_tests=*/true,
proc_id_factory, f);
XLS_RETURN_IF_ERROR(m->Accept(&visitor));
return visitor.records();

std::vector<ConversionRecord> records = visitor.records();
RemoveFunctionDuplicates(records);
return records;
}

Proc* p = std::get<Proc*>(entry);
Expand All @@ -250,6 +333,9 @@ absl::StatusOr<std::vector<ConversionRecord>> GetConversionRecordsForEntry(
ConversionRecordVisitor visitor(m, new_ti, /*include_tests=*/true,
proc_id_factory, &p->next());
XLS_RETURN_IF_ERROR(m->Accept(&visitor));
return visitor.records();

std::vector<ConversionRecord> records = visitor.records();
RemoveFunctionDuplicates(records);
return records;
}
} // namespace xls::dslx
26 changes: 10 additions & 16 deletions xls/dslx/ir_convert/get_conversion_records_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -164,17 +164,13 @@ proc top {
GetConversionRecords(tm.module, tm.type_info, false));
ASSERT_EQ(3, order.size());
EXPECT_EQ(order[0].f()->identifier(), "P.next");
EXPECT_EQ(order[0].invocation()->ToString(),
"P.next<u32:2>(P.init<u32:2>())");
const ConversionRecord* config_record = order[0].config_record();
EXPECT_NE(config_record, nullptr);
EXPECT_EQ(config_record->invocation()->ToString(), "P.config<u32:2>(u2:1)");
EXPECT_EQ(order[0].parametric_env(),
ParametricEnv(absl::flat_hash_map<std::string, InterpValue>{
{"N", InterpValue::MakeUBits(/*bit_count=*/32, /*value=*/2)}}));
EXPECT_EQ(order[1].f()->identifier(), "P.next");
EXPECT_EQ(order[1].invocation()->ToString(),
"P.next<u32:4>(P.init<u32:4>())");
config_record = order[1].config_record();
EXPECT_NE(config_record, nullptr);
EXPECT_EQ(config_record->invocation()->ToString(), "P.config<u32:4>(u4:2)");
Expand Down Expand Up @@ -388,8 +384,9 @@ fn my_test() -> bool { f<u32:8>(u8:1) == u32:8 }
XLS_ASSERT_OK_AND_ASSIGN(
TypecheckedModule tm,
ParseAndTypecheck(kProgram, "test.x", "test", &import_data));
XLS_ASSERT_OK_AND_ASSIGN(std::vector<ConversionRecord> order,
GetConversionRecords(tm.module, tm.type_info, true));
XLS_ASSERT_OK_AND_ASSIGN(
std::vector<ConversionRecord> order,
GetConversionRecords(tm.module, tm.type_info, /*include_tests=*/true));
ASSERT_EQ(2, order.size());
EXPECT_EQ(order[0].f()->identifier(), "f");
EXPECT_EQ(order[0].parametric_env(),
Expand Down Expand Up @@ -440,13 +437,15 @@ proc test {
next(x: ()) { () }
}
)";

TEST(GetConversionRecordsTest, TestProc) {
auto import_data = CreateImportDataForTest();
XLS_ASSERT_OK_AND_ASSIGN(
TypecheckedModule tm,
ParseAndTypecheck(kTestProc, "test.x", "test", &import_data));
XLS_ASSERT_OK_AND_ASSIGN(std::vector<ConversionRecord> order,
GetConversionRecords(tm.module, tm.type_info, true));
XLS_ASSERT_OK_AND_ASSIGN(
std::vector<ConversionRecord> order,
GetConversionRecords(tm.module, tm.type_info, /*include_tests=*/true));
ASSERT_EQ(2, order.size());
EXPECT_EQ(order[0].f()->identifier(), "P.next");
EXPECT_EQ(order[0].parametric_env(),
Expand All @@ -462,14 +461,9 @@ TEST(GetConversionRecordsTest, TestProcSkipped) {
ParseAndTypecheck(kTestProc, "test.x", "test", &import_data));
XLS_ASSERT_OK_AND_ASSIGN(
std::vector<ConversionRecord> order,
GetConversionRecords(tm.module, tm.type_info, false));
// It still converts the parametric proc because there is still a spawn,
// in the test proc.
ASSERT_EQ(1, order.size());
EXPECT_EQ(order[0].f()->identifier(), "P.next");
EXPECT_EQ(order[0].parametric_env(),
ParametricEnv(absl::flat_hash_map<std::string, InterpValue>{
{"N", InterpValue::MakeUBits(/*bit_count=*/32, /*value=*/4)}}));
GetConversionRecords(tm.module, tm.type_info, /*include_tests=*/false));
// Without processing the test proc, the subject proc isn't found.
ASSERT_EQ(0, order.size());
}

TEST(GetConversionRecordsTest, Quickcheck) {
Expand Down
Loading