@@ -205,6 +205,43 @@ TEST_F(ProcTest, StatelessProc) {
205205 EXPECT_EQ (proc->DumpIr (), " proc p() {\n }\n " );
206206}
207207
208+ TEST_F (ProcTest, MultipleStateReads) {
209+ auto p = CreatePackage ();
210+ ProcBuilder pb (" p" , p.get ());
211+ BValue tkn = pb.StateElement (" tkn" , Value::Token ());
212+ BValue state = pb.StateElement (" x" , Value (UBits (42 , 32 )));
213+ XLS_ASSERT_OK_AND_ASSIGN (Proc * proc, pb.Build ({tkn, state}));
214+
215+ StateElement* state_elem = proc->GetStateElement (1 );
216+
217+ EXPECT_EQ (proc->GetStateReads (1 ).size (), 1 );
218+ StateRead* read1 = proc->GetStateReads (1 ).front ();
219+
220+ // Second Read
221+ XLS_ASSERT_OK_AND_ASSIGN (
222+ StateRead * read2,
223+ proc->MakeNodeWithName <StateRead>(SourceInfo (), state_elem,
224+ /* predicate=*/ std::nullopt ,
225+ /* label=*/ std::nullopt , " x_read2" ));
226+ XLS_ASSERT_OK (proc->RebuildSideTables ());
227+
228+ EXPECT_EQ (proc->GetStateReads (1 ).size (), 2 );
229+ EXPECT_THAT (proc->GetStateReads (1 ), ElementsAre (read1, read2));
230+
231+ EXPECT_EQ (proc->GetStateReadsByStateElement (state_elem).size (), 2 );
232+ EXPECT_THAT (proc->GetStateReadsByStateElement (state_elem),
233+ ElementsAre (read1, read2));
234+
235+ // Remove the second read.
236+ std::string read2_name = read2->GetName ();
237+ XLS_ASSERT_OK (proc->RemoveNode (read2));
238+ XLS_ASSERT_OK (proc->RebuildSideTables ());
239+
240+ // Now we should have 1 read again.
241+ EXPECT_EQ (proc->GetStateReads (1 ).size (), 1 );
242+ EXPECT_EQ (proc->GetStateReads (1 ).front (), read1);
243+ }
244+
208245TEST_F (ProcTest, RemoveStateThatStillHasUse) {
209246 // Don't call CreatePackage which creates a VerifiedPackage because we
210247 // intentionally create a malformed proc.
@@ -254,10 +291,10 @@ TEST_F(ProcTest, Clone) {
254291 EXPECT_EQ (clone->DumpIr (),
255292 R"( proc cloned(tkn: token, state: bits[32], init={token, 42}) {
256293 tkn: token = state_read(state_element=tkn, id=12)
257- literal.14 : bits[32] = literal(value=1, id=14 )
258- state: bits[32] = state_read(state_element=state, id=13 )
294+ literal.13 : bits[32] = literal(value=1, id=13 )
295+ state: bits[32] = state_read(state_element=state, id=14 )
259296 receive_3: (token, bits[32]) = receive(tkn, channel=cloned_chan, id=15)
260- add.16: bits[32] = add(literal.14 , state, id=16)
297+ add.16: bits[32] = add(literal.13 , state, id=16)
261298 tuple_index.17: bits[32] = tuple_index(receive_3, index=1, id=17)
262299 tuple_index.18: token = tuple_index(receive_3, index=0, id=18)
263300 add.19: bits[32] = add(add.16, tuple_index.17, id=19)
@@ -304,10 +341,10 @@ proc cloned<input_chan: bits[32] in, chan: bits[32] out>(tkn: token, state: bits
304341 chan_interface input_chan(direction=receive, kind=streaming, strictness=proven_mutually_exclusive, flow_control=ready_valid, flop_kind=none)
305342 chan_interface chan(direction=send, kind=streaming, strictness=proven_mutually_exclusive, flow_control=ready_valid, flop_kind=none)
306343 tkn: token = state_read(state_element=tkn, id=1)
307- literal.3 : bits[32] = literal(value=1, id=3 )
308- state: bits[32] = state_read(state_element=state, id=2 )
344+ literal.2 : bits[32] = literal(value=1, id=2 )
345+ state: bits[32] = state_read(state_element=state, id=3 )
309346 receive_3: (token, bits[32]) = receive(tkn, channel=input_chan, id=4)
310- add.5: bits[32] = add(literal.3 , state, id=5)
347+ add.5: bits[32] = add(literal.2 , state, id=5)
311348 tuple_index.6: bits[32] = tuple_index(receive_3, index=1, id=6)
312349 tuple_index.7: token = tuple_index(receive_3, index=0, id=7)
313350 add.8: bits[32] = add(add.5, tuple_index.6, id=8)
@@ -355,15 +392,15 @@ TEST_F(ProcTest, CloneNewStyle) {
355392 chan baz(bits[32], id=0, kind=streaming, ops=send_receive, flow_control=ready_valid, strictness=proven_mutually_exclusive)
356393 chan_interface baz(direction=send, kind=streaming, strictness=proven_mutually_exclusive, flow_control=none, flop_kind=none)
357394 chan_interface baz(direction=receive, kind=streaming, strictness=proven_mutually_exclusive, flow_control=none, flop_kind=none)
358- tkn: token = literal(value=token, id=14 )
359- receive_3: (token, bits[32]) = receive(tkn, channel=foo, id=15 )
360- tuple_index.16 : token = tuple_index(receive_3, index=0, id=16 )
361- receive_6: (token, bits[32]) = receive(tuple_index.16 , channel=baz, id=17 )
362- tuple_index.18 : token = tuple_index(receive_6, index=0, id=18 )
363- state: bits[32] = state_read(state_element=state, id=13 )
395+ tkn: token = literal(value=token, id=13 )
396+ receive_3: (token, bits[32]) = receive(tkn, channel=foo, id=14 )
397+ tuple_index.15 : token = tuple_index(receive_3, index=0, id=15 )
398+ receive_6: (token, bits[32]) = receive(tuple_index.15 , channel=baz, id=16 )
399+ tuple_index.17 : token = tuple_index(receive_6, index=0, id=17 )
400+ state: bits[32] = state_read(state_element=state, id=18 )
364401 tuple_index.19: bits[32] = tuple_index(receive_3, index=1, id=19)
365402 tuple_index.20: bits[32] = tuple_index(receive_6, index=1, id=20)
366- send_9: token = send(tuple_index.18 , state, channel=bar, id=21)
403+ send_9: token = send(tuple_index.17 , state, channel=bar, id=21)
367404 add.22: bits[32] = add(tuple_index.19, tuple_index.20, id=22)
368405 send_10: token = send(send_9, state, channel=baz, id=23)
369406 next_value.24: () = next_value(param=state, value=add.22, id=24)
@@ -556,7 +593,8 @@ TEST_F(ScheduledProcTest, StageAddAndClear) {
556593 proc->ClearStages ();
557594 EXPECT_TRUE (proc->stages ().empty ());
558595 // Re-stage the state element to satisfy the verifier.
559- XLS_ASSERT_OK (proc->AddNodeToStage (0 , proc->GetStateRead (0 )).status ());
596+ XLS_ASSERT_OK (
597+ proc->AddNodeToStage (0 , proc->GetStateReads (0 ).front ()).status ());
560598}
561599
562600TEST_F (ScheduledProcTest, AddEmptyStages) {
@@ -596,7 +634,8 @@ TEST_F(ScheduledProcTest, GetStageIndex) {
596634 EXPECT_THAT (proc->GetStageIndex (x), IsOkAndHolds (1 ));
597635 EXPECT_THAT (proc->GetStageIndex (y), IsOkAndHolds (2 ));
598636 EXPECT_THAT (proc->GetStageIndex (add), StatusIs (absl::StatusCode::kNotFound ));
599- EXPECT_THAT (proc->GetStageIndex (proc->GetStateRead (0 )), IsOkAndHolds (0 ));
637+ EXPECT_THAT (proc->GetStageIndex (proc->GetStateReads (0 ).front ()),
638+ IsOkAndHolds (0 ));
600639
601640 // The verifier requires that every node be in a stage before we finish.
602641 ASSERT_THAT (proc->AddNodeToStage (2 , add), IsOkAndHolds (true ));
0 commit comments