Skip to content

Commit 75bedc2

Browse files
NL02copybara-github
authored andcommitted
[Explicit State Access] Allow multiple StateRead nodes to map to a single state element.
Previously State Elements and State Reads had a 1:1 mapping. Explicit state access now permits multiple reads from the same state element across different code paths, such as within If/Else blocks. PiperOrigin-RevId: 895959816
1 parent 668c4fd commit 75bedc2

File tree

3 files changed

+113
-43
lines changed

3 files changed

+113
-43
lines changed

xls/ir/proc.cc

Lines changed: 43 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -214,13 +214,15 @@ absl::Status Proc::RemoveStateElement(int64_t index) {
214214
StateElement* old_state_element = GetStateElement(index);
215215
auto old_state_read_it = state_reads_.find(old_state_element);
216216
XLS_RET_CHECK(old_state_read_it != state_reads_.end());
217-
if (!old_state_read_it->second->users().empty()) {
218-
return absl::InvalidArgumentError(absl::StrFormat(
219-
"Cannot remove state element %d of proc %s, existing "
220-
"state read %s has uses",
221-
index, name(), old_state_read_it->second->GetNameView()));
217+
for (StateRead* read : old_state_read_it->second) {
218+
if (!read->users().empty()) {
219+
return absl::InvalidArgumentError(
220+
absl::StrFormat("Cannot remove state element %d of proc %s, existing "
221+
"state read %s has uses",
222+
index, name(), read->GetNameView()));
223+
}
224+
XLS_RETURN_IF_ERROR(RemoveNode(read));
222225
}
223-
XLS_RETURN_IF_ERROR(RemoveNode(old_state_read_it->second));
224226
// TODO(allight): This should ideally not need to be done manually.
225227
state_reads_.erase(old_state_read_it);
226228

@@ -232,11 +234,14 @@ absl::Status Proc::RemoveStateElement(int64_t index) {
232234
absl::Status Proc::RemoveAllStateElements() {
233235
// TODO(allight): This relies on side tables being valid. For now just let it
234236
// go.
235-
for (const auto& [elem, read] : state_reads_) {
236-
if (read != nullptr) {
237-
XLS_RETURN_IF_ERROR(RemoveNode(read))
238-
<< "Cannot remove " << elem->ToString() << " of proc " << name()
239-
<< " because read '" << read->ToString() << "' could not be removed.";
237+
for (const auto& [elem, reads] : state_reads_) {
238+
for (StateRead* read : reads) {
239+
if (read != nullptr) {
240+
XLS_RETURN_IF_ERROR(RemoveNode(read))
241+
<< "Cannot remove " << elem->ToString() << " of proc " << name()
242+
<< " because read '" << read->ToString()
243+
<< "' could not be removed.";
244+
}
240245
}
241246
XLS_RETURN_IF_ERROR(state_name_uniquer_.ReleaseIdentifier(elem->name()))
242247
<< "Cannot release name of " << elem->ToString();
@@ -278,7 +283,7 @@ absl::StatusOr<StateRead*> Proc::InsertStateElement(
278283
MakeNodeWithName<StateRead>(
279284
loc, state_element, read_predicate,
280285
/*label=*/std::nullopt, state_element->name()));
281-
state_reads_[state_element] = state_read;
286+
state_reads_[state_element].push_back(state_read);
282287

283288
if (next_state.has_value()) {
284289
if (!ValueConformsToType(init_value, next_state.value()->GetType())) {
@@ -351,14 +356,13 @@ absl::StatusOr<Proc*> Proc::Clone(
351356
return mapping.at(orig);
352357
};
353358
for (StateElement* state_element : StateElements()) {
354-
StateRead* state_read = state_reads_.at(state_element);
355-
XLS_ASSIGN_OR_RETURN(
356-
StateRead * cloned_state_read,
357-
cloned_proc->AppendStateElement(
358-
remap_name(state_name_remapping, state_element->name()),
359-
state_element->initial_value(), state_read->predicate(),
360-
/*next_state=*/std::nullopt));
361-
original_to_clone[state_read] = cloned_state_read;
359+
XLS_RETURN_IF_ERROR(
360+
cloned_proc
361+
->InsertUnreadStateElement(
362+
cloned_proc->GetStateElementCount(),
363+
remap_name(state_name_remapping, state_element->name()),
364+
state_element->initial_value())
365+
.status());
362366
}
363367
if (is_new_style_proc()) {
364368
absl::flat_hash_map<ChannelInterface*, ChannelInterface*> channel_map;
@@ -445,7 +449,23 @@ absl::StatusOr<Proc*> Proc::Clone(
445449

446450
switch (node->op()) {
447451
case Op::kStateRead: {
448-
continue;
452+
StateRead* src = node->As<StateRead>();
453+
StateElement* src_elem = src->state_element();
454+
XLS_ASSIGN_OR_RETURN(int64_t idx, GetStateElementIndex(src_elem));
455+
StateElement* cloned_elem = cloned_proc->GetStateElement(idx);
456+
457+
std::optional<Node*> cloned_predicate;
458+
if (src->predicate().has_value()) {
459+
cloned_predicate = original_to_clone.at(src->predicate().value());
460+
}
461+
462+
XLS_ASSIGN_OR_RETURN(StateRead * cloned_state_read,
463+
cloned_proc->MakeNodeWithName<StateRead>(
464+
src->loc(), cloned_elem, cloned_predicate,
465+
/*label=*/std::nullopt, cloned_elem->name()));
466+
cloned_proc->state_reads_[cloned_elem].push_back(cloned_state_read);
467+
original_to_clone[node] = cloned_state_read;
468+
break;
449469
}
450470
case Op::kReceive: {
451471
Receive* src = node->As<Receive>();
@@ -1000,10 +1020,8 @@ absl::Status Proc::InternalRebuildSideTables() {
10001020
state_reads_.clear();
10011021
for (Node* n : nodes()) {
10021022
if (n->Is<StateRead>()) {
1003-
XLS_RET_CHECK(!state_reads_.contains(n->As<StateRead>()->state_element()))
1004-
<< "Duplicate state element read: "
1005-
<< n->As<StateRead>()->state_element();
1006-
state_reads_[n->As<StateRead>()->state_element()] = n->As<StateRead>();
1023+
state_reads_[n->As<StateRead>()->state_element()].push_back(
1024+
n->As<StateRead>());
10071025
} else if (n->Is<Next>()) {
10081026
next_values_.push_back(n->As<Next>());
10091027
next_values_by_state_element_[n->As<Next>()->state_element()].insert(

xls/ir/proc.h

Lines changed: 16 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -95,10 +95,23 @@ class Proc : public FunctionBase {
9595
return state_elements_.contains(name);
9696
}
9797

98+
// Remove legacy getters after all downstream passes migrate logic.
9899
StateRead* GetStateRead(int64_t index) const {
99-
return state_reads_.at(GetStateElement(index));
100+
return GetStateReads(index).front();
100101
}
102+
101103
StateRead* GetStateReadByStateElement(StateElement* state_element) const {
104+
return GetStateReadsByStateElement(state_element).front();
105+
}
106+
107+
// Get state reads for a state element at the given index.
108+
absl::Span<StateRead* const> GetStateReads(int64_t index) const {
109+
return state_reads_.at(GetStateElement(index));
110+
}
111+
112+
// Get state reads for a state element.
113+
absl::Span<StateRead* const> GetStateReadsByStateElement(
114+
StateElement* state_element) const {
102115
return state_reads_.at(state_element);
103116
}
104117

@@ -403,8 +416,8 @@ class Proc : public FunctionBase {
403416
absl::flat_hash_map<std::string, std::unique_ptr<StateElement>>
404417
state_elements_;
405418

406-
// Map of the unique StateRead node for each state element.
407-
absl::flat_hash_map<StateElement*, StateRead*> state_reads_;
419+
// Map of StateRead nodes for each state element.
420+
absl::flat_hash_map<StateElement*, std::vector<StateRead*>> state_reads_;
408421

409422
// Vector of state element pointers. Kept in sync with the state_elements_
410423
// map. Enables easy, stable iteration over state elements. With this vector,

xls/ir/proc_test.cc

Lines changed: 54 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -205,6 +205,43 @@ TEST_F(ProcTest, StatelessProc) {
205205
EXPECT_EQ(proc->DumpIr(), "proc p() {\n}\n");
206206
}
207207

208+
TEST_F(ProcTest, MultipleStateReads) {
209+
auto p = CreatePackage();
210+
ProcBuilder pb("p", p.get());
211+
BValue tkn = pb.StateElement("tkn", Value::Token());
212+
BValue state = pb.StateElement("x", Value(UBits(42, 32)));
213+
XLS_ASSERT_OK_AND_ASSIGN(Proc * proc, pb.Build({tkn, state}));
214+
215+
StateElement* state_elem = proc->GetStateElement(1);
216+
217+
EXPECT_EQ(proc->GetStateReads(1).size(), 1);
218+
StateRead* read1 = proc->GetStateReads(1).front();
219+
220+
// Second Read
221+
XLS_ASSERT_OK_AND_ASSIGN(
222+
StateRead * read2,
223+
proc->MakeNodeWithName<StateRead>(SourceInfo(), state_elem,
224+
/*predicate=*/std::nullopt,
225+
/*label=*/std::nullopt, "x_read2"));
226+
XLS_ASSERT_OK(proc->RebuildSideTables());
227+
228+
EXPECT_EQ(proc->GetStateReads(1).size(), 2);
229+
EXPECT_THAT(proc->GetStateReads(1), ElementsAre(read1, read2));
230+
231+
EXPECT_EQ(proc->GetStateReadsByStateElement(state_elem).size(), 2);
232+
EXPECT_THAT(proc->GetStateReadsByStateElement(state_elem),
233+
ElementsAre(read1, read2));
234+
235+
// Remove the second read.
236+
std::string read2_name = read2->GetName();
237+
XLS_ASSERT_OK(proc->RemoveNode(read2));
238+
XLS_ASSERT_OK(proc->RebuildSideTables());
239+
240+
// Now we should have 1 read again.
241+
EXPECT_EQ(proc->GetStateReads(1).size(), 1);
242+
EXPECT_EQ(proc->GetStateReads(1).front(), read1);
243+
}
244+
208245
TEST_F(ProcTest, RemoveStateThatStillHasUse) {
209246
// Don't call CreatePackage which creates a VerifiedPackage because we
210247
// intentionally create a malformed proc.
@@ -254,10 +291,10 @@ TEST_F(ProcTest, Clone) {
254291
EXPECT_EQ(clone->DumpIr(),
255292
R"(proc cloned(tkn: token, state: bits[32], init={token, 42}) {
256293
tkn: token = state_read(state_element=tkn, id=12)
257-
literal.14: bits[32] = literal(value=1, id=14)
258-
state: bits[32] = state_read(state_element=state, id=13)
294+
literal.13: bits[32] = literal(value=1, id=13)
295+
state: bits[32] = state_read(state_element=state, id=14)
259296
receive_3: (token, bits[32]) = receive(tkn, channel=cloned_chan, id=15)
260-
add.16: bits[32] = add(literal.14, state, id=16)
297+
add.16: bits[32] = add(literal.13, state, id=16)
261298
tuple_index.17: bits[32] = tuple_index(receive_3, index=1, id=17)
262299
tuple_index.18: token = tuple_index(receive_3, index=0, id=18)
263300
add.19: bits[32] = add(add.16, tuple_index.17, id=19)
@@ -304,10 +341,10 @@ proc cloned<input_chan: bits[32] in, chan: bits[32] out>(tkn: token, state: bits
304341
chan_interface input_chan(direction=receive, kind=streaming, strictness=proven_mutually_exclusive, flow_control=ready_valid, flop_kind=none)
305342
chan_interface chan(direction=send, kind=streaming, strictness=proven_mutually_exclusive, flow_control=ready_valid, flop_kind=none)
306343
tkn: token = state_read(state_element=tkn, id=1)
307-
literal.3: bits[32] = literal(value=1, id=3)
308-
state: bits[32] = state_read(state_element=state, id=2)
344+
literal.2: bits[32] = literal(value=1, id=2)
345+
state: bits[32] = state_read(state_element=state, id=3)
309346
receive_3: (token, bits[32]) = receive(tkn, channel=input_chan, id=4)
310-
add.5: bits[32] = add(literal.3, state, id=5)
347+
add.5: bits[32] = add(literal.2, state, id=5)
311348
tuple_index.6: bits[32] = tuple_index(receive_3, index=1, id=6)
312349
tuple_index.7: token = tuple_index(receive_3, index=0, id=7)
313350
add.8: bits[32] = add(add.5, tuple_index.6, id=8)
@@ -355,15 +392,15 @@ TEST_F(ProcTest, CloneNewStyle) {
355392
chan baz(bits[32], id=0, kind=streaming, ops=send_receive, flow_control=ready_valid, strictness=proven_mutually_exclusive)
356393
chan_interface baz(direction=send, kind=streaming, strictness=proven_mutually_exclusive, flow_control=none, flop_kind=none)
357394
chan_interface baz(direction=receive, kind=streaming, strictness=proven_mutually_exclusive, flow_control=none, flop_kind=none)
358-
tkn: token = literal(value=token, id=14)
359-
receive_3: (token, bits[32]) = receive(tkn, channel=foo, id=15)
360-
tuple_index.16: token = tuple_index(receive_3, index=0, id=16)
361-
receive_6: (token, bits[32]) = receive(tuple_index.16, channel=baz, id=17)
362-
tuple_index.18: token = tuple_index(receive_6, index=0, id=18)
363-
state: bits[32] = state_read(state_element=state, id=13)
395+
tkn: token = literal(value=token, id=13)
396+
receive_3: (token, bits[32]) = receive(tkn, channel=foo, id=14)
397+
tuple_index.15: token = tuple_index(receive_3, index=0, id=15)
398+
receive_6: (token, bits[32]) = receive(tuple_index.15, channel=baz, id=16)
399+
tuple_index.17: token = tuple_index(receive_6, index=0, id=17)
400+
state: bits[32] = state_read(state_element=state, id=18)
364401
tuple_index.19: bits[32] = tuple_index(receive_3, index=1, id=19)
365402
tuple_index.20: bits[32] = tuple_index(receive_6, index=1, id=20)
366-
send_9: token = send(tuple_index.18, state, channel=bar, id=21)
403+
send_9: token = send(tuple_index.17, state, channel=bar, id=21)
367404
add.22: bits[32] = add(tuple_index.19, tuple_index.20, id=22)
368405
send_10: token = send(send_9, state, channel=baz, id=23)
369406
next_value.24: () = next_value(param=state, value=add.22, id=24)
@@ -556,7 +593,8 @@ TEST_F(ScheduledProcTest, StageAddAndClear) {
556593
proc->ClearStages();
557594
EXPECT_TRUE(proc->stages().empty());
558595
// Re-stage the state element to satisfy the verifier.
559-
XLS_ASSERT_OK(proc->AddNodeToStage(0, proc->GetStateRead(0)).status());
596+
XLS_ASSERT_OK(
597+
proc->AddNodeToStage(0, proc->GetStateReads(0).front()).status());
560598
}
561599

562600
TEST_F(ScheduledProcTest, AddEmptyStages) {
@@ -596,7 +634,8 @@ TEST_F(ScheduledProcTest, GetStageIndex) {
596634
EXPECT_THAT(proc->GetStageIndex(x), IsOkAndHolds(1));
597635
EXPECT_THAT(proc->GetStageIndex(y), IsOkAndHolds(2));
598636
EXPECT_THAT(proc->GetStageIndex(add), StatusIs(absl::StatusCode::kNotFound));
599-
EXPECT_THAT(proc->GetStageIndex(proc->GetStateRead(0)), IsOkAndHolds(0));
637+
EXPECT_THAT(proc->GetStageIndex(proc->GetStateReads(0).front()),
638+
IsOkAndHolds(0));
600639

601640
// The verifier requires that every node be in a stage before we finish.
602641
ASSERT_THAT(proc->AddNodeToStage(2, add), IsOkAndHolds(true));

0 commit comments

Comments
 (0)