Skip to content

Commit 259d031

Browse files
richmckeevercopybara-github
authored andcommitted
Support IR conversion of proc aliases with proc-scoped channels flagged on.
PiperOrigin-RevId: 820795684
1 parent b64c970 commit 259d031

8 files changed

+180
-21
lines changed

xls/dslx/ir_convert/function_converter.cc

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -3586,10 +3586,6 @@ absl::Status FunctionConverter::HandleProcNextFunction(
35863586
ScopedTypeInfoSwap stis(this, config_type_info);
35873587

35883588
Function& config_fn = proc->config();
3589-
// This probably was checked already but might as well double-check it.
3590-
XLS_RET_CHECK(invocation != nullptr || !f->IsParametric())
3591-
<< "Cannot lower a parametric proc without an invocation";
3592-
35933589
proc_scoped_channel_scope = std::make_unique<ProcScopedChannelScope>(
35943590
package_data_.conversion_info, import_data, options_, builder_ptr);
35953591
proc_scoped_channel_scope->EnterFunctionContext(current_type_info_,

xls/dslx/ir_convert/get_conversion_records.cc

Lines changed: 40 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -55,12 +55,14 @@ class ConversionRecordVisitor : public AstNodeVisitorWithDefault {
5555
public:
5656
ConversionRecordVisitor(Module* module, TypeInfo* type_info,
5757
bool include_tests, ProcIdFactory proc_id_factory,
58-
AstNode* top)
58+
AstNode* top,
59+
std::optional<ResolvedProcAlias> resolved_proc_alias)
5960
: module_(module),
6061
type_info_(type_info),
6162
include_tests_(include_tests),
6263
proc_id_factory_(proc_id_factory),
63-
top_(top) {}
64+
top_(top),
65+
resolved_proc_alias_(resolved_proc_alias) {}
6466

6567
absl::StatusOr<ConversionRecord> InvocationToConversionRecord(
6668
const Function* f, const Invocation* invocation,
@@ -235,6 +237,30 @@ class ConversionRecordVisitor : public AstNodeVisitorWithDefault {
235237
VLOG(5) << "HandleProc " << p->ToString();
236238
const Function* next_fn = &p->next();
237239

240+
if (top_ == next_fn && resolved_proc_alias_.has_value()) {
241+
ProcId proc_id = proc_id_factory_.CreateProcId(
242+
/*parent=*/std::nullopt, const_cast<Proc*>(p),
243+
/*count_as_new_instance=*/false);
244+
proc_id.alias_name = resolved_proc_alias_->name;
245+
XLS_ASSIGN_OR_RETURN(
246+
ConversionRecord config_record,
247+
MakeConversionRecord(
248+
const_cast<Function*>(&p->config()), top_->owner(),
249+
resolved_proc_alias_->config_type_info, resolved_proc_alias_->env,
250+
proc_id, /*invocation=*/nullptr,
251+
/*is_top=*/false));
252+
XLS_ASSIGN_OR_RETURN(
253+
ConversionRecord next_record,
254+
MakeConversionRecord(
255+
const_cast<Function*>(&p->next()), top_->owner(),
256+
resolved_proc_alias_->next_type_info, resolved_proc_alias_->env,
257+
proc_id, /*invocation=*/nullptr,
258+
/*is_top=*/true,
259+
std::make_unique<ConversionRecord>(std::move(config_record))));
260+
records_.push_back(std::move(next_record));
261+
return absl::OkStatus();
262+
}
263+
238264
// This is required in order to process cross-module spawns; otherwise it
239265
// will never add procs from imported modules to the list of functions to
240266
// convert.
@@ -294,6 +320,9 @@ class ConversionRecordVisitor : public AstNodeVisitorWithDefault {
294320
ProcIdFactory proc_id_factory_;
295321
AstNode* top_;
296322

323+
// The proc alias that was used to specify the top proc, if any.
324+
std::optional<ResolvedProcAlias> resolved_proc_alias_;
325+
297326
std::vector<ConversionRecord> records_;
298327
};
299328

@@ -319,23 +348,27 @@ absl::StatusOr<std::vector<ConversionRecord>> GetConversionRecords(
319348
// TODO: https://github.com/google/xls/issues/2078 - properly set
320349
// top instead of setting to nullptr.
321350
ConversionRecordVisitor visitor(module, type_info, include_tests,
322-
proc_id_factory, /*top=*/nullptr);
351+
proc_id_factory, /*top=*/nullptr,
352+
/*resolved_proc_alias=*/std::nullopt);
323353
XLS_RETURN_IF_ERROR(module->Accept(&visitor));
324354

325355
std::vector<ConversionRecord> records = visitor.records();
326356
return RemoveFunctionDuplicates(records);
327357
}
328358

329359
absl::StatusOr<std::vector<ConversionRecord>> GetConversionRecordsForEntry(
330-
std::variant<Proc*, Function*> entry, TypeInfo* type_info) {
360+
std::variant<Proc*, Function*> entry, TypeInfo* type_info,
361+
std::optional<ResolvedProcAlias> resolved_proc_alias) {
331362
ProcIdFactory proc_id_factory;
332363
if (std::holds_alternative<Function*>(entry)) {
364+
XLS_RET_CHECK(!resolved_proc_alias.has_value());
333365
Function* f = std::get<Function*>(entry);
334366
Module* m = f->owner();
335367
// We are only ever called for tests, so we set include_tests to
336368
// true, and make sure that this function is top.
337369
ConversionRecordVisitor visitor(m, type_info, /*include_tests=*/true,
338-
proc_id_factory, f);
370+
proc_id_factory, f,
371+
/*resolved_proc_alias=*/std::nullopt);
339372
XLS_RETURN_IF_ERROR(m->Accept(&visitor));
340373

341374
std::vector<ConversionRecord> records = visitor.records();
@@ -349,7 +382,8 @@ absl::StatusOr<std::vector<ConversionRecord>> GetConversionRecordsForEntry(
349382
// We are only ever called for tests, so we set include_tests to true,
350383
// and make sure that this proc's next function is top.
351384
ConversionRecordVisitor visitor(m, new_ti, /*include_tests=*/true,
352-
proc_id_factory, &p->next());
385+
proc_id_factory, &p->next(),
386+
resolved_proc_alias);
353387
XLS_RETURN_IF_ERROR(m->Accept(&visitor));
354388

355389
std::vector<ConversionRecord> records = visitor.records();

xls/dslx/ir_convert/get_conversion_records.h

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
#ifndef XLS_DSLX_IR_CONVERT_GET_CONVERSION_RECORDS_H_
1616
#define XLS_DSLX_IR_CONVERT_GET_CONVERSION_RECORDS_H_
1717

18+
#include <optional>
1819
#include <variant>
1920
#include <vector>
2021

@@ -39,7 +40,8 @@ absl::StatusOr<std::vector<ConversionRecord>> GetConversionRecords(
3940
// entry: Proc or Function to start from (the top)
4041
// type_info: Mapping from node to type.
4142
absl::StatusOr<std::vector<ConversionRecord>> GetConversionRecordsForEntry(
42-
std::variant<Proc*, Function*> entry, TypeInfo* type_info);
43+
std::variant<Proc*, Function*> entry, TypeInfo* type_info,
44+
std::optional<ResolvedProcAlias> resolved_proc_alias);
4345

4446
} // namespace xls::dslx
4547

xls/dslx/ir_convert/ir_converter.cc

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -403,7 +403,7 @@ absl::StatusOr<std::vector<ConversionRecord>> GetConversionRecords(
403403
// lower_to_proc_scoped_channels is turned on everywhere, and call
404404
// GetConversionRecordsForEntry unconditionally.
405405
if (options.lower_to_proc_scoped_channels) {
406-
return GetConversionRecordsForEntry(block, type_info);
406+
return GetConversionRecordsForEntry(block, type_info, resolved_proc_alias);
407407
}
408408
return GetOrderForEntry(block, type_info, resolved_proc_alias);
409409
}

xls/dslx/ir_convert/ir_converter_test.cc

Lines changed: 83 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -5877,10 +5877,7 @@ pub proc main {
58775877
XLS_ASSERT_OK_AND_ASSIGN(
58785878
std::string converted,
58795879
ConvertOneFunctionForTest(kProgram, "main", import_data,
5880-
ConvertOptions{
5881-
.emit_positions = false,
5882-
.lower_to_proc_scoped_channels = true,
5883-
}));
5880+
kProcScopedChannelOptions));
58845881
ExpectIr(converted);
58855882
}
58865883

@@ -5911,11 +5908,88 @@ fn f() -> u32 {
59115908

59125909
XLS_ASSERT_OK_AND_ASSIGN(
59135910
std::string converted,
5914-
ConvertModuleForTest(
5915-
importer_program,
5916-
ConvertOptions{.emit_positions = false,
5917-
.lower_to_proc_scoped_channels = true},
5918-
&import_data));
5911+
ConvertModuleForTest(importer_program, kProcScopedChannelOptions,
5912+
&import_data));
5913+
ExpectIr(converted);
5914+
}
5915+
5916+
TEST_P(ProcScopedChannelsIrConverterTest, ProcScopedBasicProcAlias) {
5917+
constexpr std::string_view program = R"(
5918+
proc Foo {
5919+
c: chan<u32> out;
5920+
init { u32:1 }
5921+
config(output_c: chan<u32> out) {
5922+
(output_c,)
5923+
}
5924+
next(i: u32) {
5925+
let tok = send(join(), c, i);
5926+
i + u32:2
5927+
}
5928+
}
5929+
5930+
pub proc FooAlias = Foo;
5931+
)";
5932+
5933+
auto import_data = CreateImportDataForTest();
5934+
XLS_ASSERT_OK_AND_ASSIGN(
5935+
std::string converted,
5936+
ConvertOneFunctionForTest(program, "FooAlias", import_data,
5937+
kProcScopedChannelOptions));
5938+
ExpectIr(converted);
5939+
}
5940+
5941+
TEST_P(ProcScopedChannelsIrConverterTest, ProcScopedParametricProcAlias) {
5942+
constexpr std::string_view program = R"(
5943+
proc Foo<N: u32> {
5944+
c: chan<uN[N]> out;
5945+
init { uN[N]:1 }
5946+
config(output_c: chan<uN[N]> out) {
5947+
(output_c,)
5948+
}
5949+
next(i: uN[N]) {
5950+
let tok = send(join(), c, i);
5951+
i + uN[N]:2
5952+
}
5953+
}
5954+
5955+
pub proc FooAlias = Foo<16>;
5956+
)";
5957+
5958+
auto import_data = CreateImportDataForTest();
5959+
XLS_ASSERT_OK_AND_ASSIGN(
5960+
std::string converted,
5961+
ConvertOneFunctionForTest(program, "FooAlias", import_data,
5962+
kProcScopedChannelOptions));
5963+
ExpectIr(converted);
5964+
}
5965+
5966+
TEST_P(ProcScopedChannelsIrConverterTest, ProcScopedProcAliasToImportedProc) {
5967+
ImportData import_data = CreateImportDataForTest();
5968+
5969+
constexpr std::string_view imported = R"(
5970+
pub proc Foo<N: u32> {
5971+
c: chan<uN[N]> out;
5972+
init { uN[N]:1 }
5973+
config(output_c: chan<uN[N]> out) {
5974+
(output_c,)
5975+
}
5976+
next(i: uN[N]) {
5977+
let tok = send(join(), c, i);
5978+
i + uN[N]:2
5979+
}
5980+
}
5981+
)";
5982+
XLS_EXPECT_OK(
5983+
ParseAndTypecheck(imported, "imported.x", "imported", &import_data));
5984+
5985+
constexpr std::string_view program = R"(
5986+
import imported;
5987+
pub proc FooAlias = imported::Foo<16>;
5988+
)";
5989+
XLS_ASSERT_OK_AND_ASSIGN(
5990+
std::string converted,
5991+
ConvertOneFunctionForTest(program, "FooAlias", import_data,
5992+
kProcScopedChannelOptions));
59195993
ExpectIr(converted);
59205994
}
59215995

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,17 @@
1+
package test_module
2+
3+
file_number 0 "test_module.x"
4+
5+
top proc __test_module__FooAlias_next<output_c: bits[32] out>(__state: bits[32], init={1}) {
6+
chan_interface output_c(direction=send, kind=streaming, strictness=proven_mutually_exclusive, flow_control=ready_valid, flop_kind=none)
7+
__state: bits[32] = state_read(state_element=__state, id=2)
8+
literal.8: bits[32] = literal(value=2, id=8)
9+
output_c: bits[32] = send_channel_end(id=4)
10+
after_all.6: token = after_all(id=6)
11+
literal.3: bits[1] = literal(value=1, id=3)
12+
add.9: bits[32] = add(__state, literal.8, id=9)
13+
__token: token = literal(value=token, id=1)
14+
tuple.5: (bits[32]) = tuple(output_c, id=5)
15+
tok: token = send(after_all.6, __state, predicate=literal.3, channel=output_c, id=7)
16+
next_value.10: () = next_value(param=__state, value=add.9, id=10)
17+
}
Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,18 @@
1+
package test_module
2+
3+
file_number 0 "test_module.x"
4+
5+
top proc __test_module__FooAlias_next<output_c: bits[16] out>(__state: bits[16], init={1}) {
6+
chan_interface output_c(direction=send, kind=streaming, strictness=proven_mutually_exclusive, flow_control=ready_valid, flop_kind=none)
7+
__state: bits[16] = state_read(state_element=__state, id=2)
8+
literal.9: bits[16] = literal(value=2, id=9)
9+
output_c: bits[16] = send_channel_end(id=5)
10+
after_all.7: token = after_all(id=7)
11+
literal.3: bits[1] = literal(value=1, id=3)
12+
add.10: bits[16] = add(__state, literal.9, id=10)
13+
__token: token = literal(value=token, id=1)
14+
N: bits[32] = literal(value=16, id=4)
15+
tuple.6: (bits[16]) = tuple(output_c, id=6)
16+
tok: token = send(after_all.7, __state, predicate=literal.3, channel=output_c, id=8)
17+
next_value.11: () = next_value(param=__state, value=add.10, id=11)
18+
}
Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,18 @@
1+
package test_module
2+
3+
file_number 0 "imported.x"
4+
5+
top proc __imported__FooAlias_next<output_c: bits[16] out>(__state: bits[16], init={1}) {
6+
chan_interface output_c(direction=send, kind=streaming, strictness=proven_mutually_exclusive, flow_control=ready_valid, flop_kind=none)
7+
__state: bits[16] = state_read(state_element=__state, id=2)
8+
literal.9: bits[16] = literal(value=2, id=9)
9+
output_c: bits[16] = send_channel_end(id=5)
10+
after_all.7: token = after_all(id=7)
11+
literal.3: bits[1] = literal(value=1, id=3)
12+
add.10: bits[16] = add(__state, literal.9, id=10)
13+
__token: token = literal(value=token, id=1)
14+
N: bits[32] = literal(value=16, id=4)
15+
tuple.6: (bits[16]) = tuple(output_c, id=6)
16+
tok: token = send(after_all.7, __state, predicate=literal.3, channel=output_c, id=8)
17+
next_value.11: () = next_value(param=__state, value=add.10, id=11)
18+
}

0 commit comments

Comments
 (0)