diff --git a/xls/codegen/block_conversion_test.cc b/xls/codegen/block_conversion_test.cc index 494fb168d2..eac6aec70d 100644 --- a/xls/codegen/block_conversion_test.cc +++ b/xls/codegen/block_conversion_test.cc @@ -5629,16 +5629,18 @@ proc alternating_counter(counter0: bits[32], counter1: bits[32], index: bits[1], package->GetProc("alternating_counter")); AddPredicate only_on_0(*proc->GetNode("index_is_0")); AddPredicate only_on_1(*proc->GetNode("index_is_1")); - XLS_ASSERT_OK( - proc->TransformStateElement(proc->GetStateReadByStateElement( - *proc->GetStateElementByName("counter0")), - Value(UBits(0, 32)), only_on_0) - .status()); - XLS_ASSERT_OK( - proc->TransformStateElement(proc->GetStateReadByStateElement( - *proc->GetStateElementByName("counter1")), - Value(UBits(5, 32)), only_on_1) - .status()); + XLS_ASSERT_OK(proc->TransformStateElement( + proc->GetStateReadsByStateElement( + *proc->GetStateElementByName("counter0")) + .front(), + Value(UBits(0, 32)), only_on_0) + .status()); + XLS_ASSERT_OK(proc->TransformStateElement( + proc->GetStateReadsByStateElement( + *proc->GetStateElementByName("counter1")) + .front(), + Value(UBits(5, 32)), only_on_1) + .status()); ASSERT_THAT(proc->next_values(*proc->GetStateElementByName("counter0")), SizeIs(1)); @@ -5746,16 +5748,18 @@ proc alternating_counter(counter0: bits[32], counter1: bits[32], index: bits[1], package->GetProc("alternating_counter")); AddPredicate only_on_0(*proc->GetNode("index_is_0")); AddPredicate only_on_1(*proc->GetNode("index_is_1")); - XLS_ASSERT_OK( - proc->TransformStateElement(proc->GetStateReadByStateElement( - *proc->GetStateElementByName("counter0")), - Value(UBits(0, 32)), only_on_0) - .status()); - XLS_ASSERT_OK( - proc->TransformStateElement(proc->GetStateReadByStateElement( - *proc->GetStateElementByName("counter1")), - Value(UBits(5, 32)), only_on_1) - .status()); + XLS_ASSERT_OK(proc->TransformStateElement( + proc->GetStateReadsByStateElement( + *proc->GetStateElementByName("counter0")) + .front(), + Value(UBits(0, 32)), only_on_0) + .status()); + XLS_ASSERT_OK(proc->TransformStateElement( + proc->GetStateReadsByStateElement( + *proc->GetStateElementByName("counter1")) + .front(), + Value(UBits(5, 32)), only_on_1) + .status()); SchedulingOptions scheduling_options = SchedulingOptions() diff --git a/xls/codegen_v_1_5/block_conversion_pass_pipeline_test.cc b/xls/codegen_v_1_5/block_conversion_pass_pipeline_test.cc index b16c5efd76..eaf5a2f90b 100644 --- a/xls/codegen_v_1_5/block_conversion_pass_pipeline_test.cc +++ b/xls/codegen_v_1_5/block_conversion_pass_pipeline_test.cc @@ -5857,16 +5857,18 @@ proc alternating_counter(counter0: bits[32], counter1: bits[32], index: bits[1], package->GetProc("alternating_counter")); AddPredicate only_on_0(*proc->GetNode("index_is_0")); AddPredicate only_on_1(*proc->GetNode("index_is_1")); - XLS_ASSERT_OK( - proc->TransformStateElement(proc->GetStateReadByStateElement( - *proc->GetStateElementByName("counter0")), - Value(UBits(0, 32)), only_on_0) - .status()); - XLS_ASSERT_OK( - proc->TransformStateElement(proc->GetStateReadByStateElement( - *proc->GetStateElementByName("counter1")), - Value(UBits(5, 32)), only_on_1) - .status()); + XLS_ASSERT_OK(proc->TransformStateElement( + proc->GetStateReadsByStateElement( + *proc->GetStateElementByName("counter0")) + .front(), + Value(UBits(0, 32)), only_on_0) + .status()); + XLS_ASSERT_OK(proc->TransformStateElement( + proc->GetStateReadsByStateElement( + *proc->GetStateElementByName("counter1")) + .front(), + Value(UBits(5, 32)), only_on_1) + .status()); ASSERT_THAT(proc->next_values(*proc->GetStateElementByName("counter0")), SizeIs(1)); @@ -5970,16 +5972,18 @@ proc alternating_counter(counter0: bits[32], counter1: bits[32], index: bits[1], package->GetProc("alternating_counter")); AddPredicate only_on_0(*proc->GetNode("index_is_0")); AddPredicate only_on_1(*proc->GetNode("index_is_1")); - XLS_ASSERT_OK( - proc->TransformStateElement(proc->GetStateReadByStateElement( - *proc->GetStateElementByName("counter0")), - Value(UBits(0, 32)), only_on_0) - .status()); - XLS_ASSERT_OK( - proc->TransformStateElement(proc->GetStateReadByStateElement( - *proc->GetStateElementByName("counter1")), - Value(UBits(5, 32)), only_on_1) - .status()); + XLS_ASSERT_OK(proc->TransformStateElement( + proc->GetStateReadsByStateElement( + *proc->GetStateElementByName("counter0")) + .front(), + Value(UBits(0, 32)), only_on_0) + .status()); + XLS_ASSERT_OK(proc->TransformStateElement( + proc->GetStateReadsByStateElement( + *proc->GetStateElementByName("counter1")) + .front(), + Value(UBits(5, 32)), only_on_1) + .status()); SchedulingOptions scheduling_options = SchedulingOptions() diff --git a/xls/contrib/xlscc/generate_fsm.cc b/xls/contrib/xlscc/generate_fsm.cc index 626d366d87..440b2a4968 100644 --- a/xls/contrib/xlscc/generate_fsm.cc +++ b/xls/contrib/xlscc/generate_fsm.cc @@ -837,8 +837,12 @@ NewFSMGenerator::GenerateNewFSMInvocation( xls_state_element = pb.StateElement( state_element.name, xls::ZeroOfType(state_element.type), body_loc); } else { - xls::StateRead* state_read = pb.proc()->GetStateReadByStateElement( - state_element.existing_state_element); + absl::Span reads = + pb.proc()->GetStateReadsByStateElement( + state_element.existing_state_element); + XLSCC_CHECK_LE(reads.size(), 1, body_loc); + XLSCC_CHECK(!reads.empty(), body_loc); + xls::StateRead* state_read = reads.front(); xls_state_element = TrackedBValue(state_read, &pb); } diff --git a/xls/contrib/xlscc/translate_block.cc b/xls/contrib/xlscc/translate_block.cc index af70666b4d..6529d71b81 100644 --- a/xls/contrib/xlscc/translate_block.cc +++ b/xls/contrib/xlscc/translate_block.cc @@ -35,6 +35,7 @@ #include "absl/strings/str_cat.h" #include "absl/strings/str_format.h" #include "absl/strings/str_join.h" +#include "absl/types/span.h" #include "clang/include/clang/AST/Attr.h" #include "clang/include/clang/AST/Attrs.inc" #include "clang/include/clang/AST/Decl.h" @@ -258,8 +259,11 @@ absl::StatusOr ComposeStaticValueInput( if (!generate_new_fsm || !TypeIsDecomposable(xls_type)) { xls::StateElement* state_element = state_element_for_static.at( DeclLeaf{.decl = namedecl, .leaf_index = -1}); - return TrackedBValue(pb.proc()->GetStateReadByStateElement(state_element), - &pb); + absl::Span reads = + pb.proc()->GetStateReadsByStateElement(state_element); + CHECK(!reads.empty()); + CHECK_LE(reads.size(), 1); + return TrackedBValue(reads.front(), &pb); } absl::InlinedVector decomposed_types = DecomposeTupleTypes(xls_type); @@ -268,7 +272,11 @@ absl::StatusOr ComposeStaticValueInput( for (int64_t i = 0; i < decomposed_types.size(); ++i) { xls::StateElement* decomposed_element = state_element_for_static.at( DeclLeaf{.decl = namedecl, .leaf_index = i}); - nodes.push_back(pb.proc()->GetStateReadByStateElement(decomposed_element)); + absl::Span reads = + pb.proc()->GetStateReadsByStateElement(decomposed_element); + CHECK(!reads.empty()); + CHECK_LE(reads.size(), 1); + nodes.push_back(reads.front()); } XLS_ASSIGN_OR_RETURN(xls::Node * node, @@ -662,8 +670,11 @@ absl::StatusOr Translator::GenerateIR_Block( next_state_value.value = TrackedBValue(decomposed_next_val, &pb); } else { XLSCC_CHECK_EQ(decomposed_elems.size(), 1, body_loc); - xls::StateRead* state_read = - pb.proc()->GetStateReadByStateElement(decomposed_elem); + absl::Span reads = + pb.proc()->GetStateReadsByStateElement(decomposed_elem); + XLSCC_CHECK(!reads.empty(), body_loc); + XLSCC_CHECK_LE(reads.size(), 1, body_loc); + xls::StateRead* state_read = reads.front(); TrackedBValue prev_val(state_read, &pb); next_state_value.value = pb.And(prev_val, @@ -2268,7 +2279,11 @@ absl::StatusOr Translator::BuildWithNextStateValueMap( return absl::InternalError( absl::StrFormat("No next values for state element %s", elem->name())); } - xls::StateRead* state_read = pb.proc()->GetStateReadByStateElement(elem); + absl::Span reads = + pb.proc()->GetStateReadsByStateElement(elem); + XLSCC_CHECK(!reads.empty(), loc); + XLSCC_CHECK_LE(reads.size(), 1, loc); + xls::StateRead* state_read = reads.front(); TrackedBValue read_bval(state_read, &pb); if (values_for_elem == 1) { const NextStateValue& next_state_value = diff --git a/xls/contrib/xlscc/translate_loops.cc b/xls/contrib/xlscc/translate_loops.cc index da099e01b6..3b3a8e53ea 100644 --- a/xls/contrib/xlscc/translate_loops.cc +++ b/xls/contrib/xlscc/translate_loops.cc @@ -32,6 +32,7 @@ #include "absl/status/statusor.h" #include "absl/strings/str_format.h" #include "absl/time/time.h" +#include "absl/types/span.h" #include "clang/include/clang/AST/Attr.h" #include "clang/include/clang/AST/Decl.h" #include "clang/include/clang/AST/Expr.h" @@ -1374,8 +1375,11 @@ Translator::GenerateIR_PipelinedLoopContents( } else { xls::StateElement* state_elem = prepared.state_element_for_variable.at(DeclLeaf{.decl = decl}); - state_reads_by_decl[decl] = - TrackedBValue(pb.proc()->GetStateReadByStateElement(state_elem), &pb); + absl::Span reads = + pb.proc()->GetStateReadsByStateElement(state_elem); + XLSCC_CHECK(!reads.empty(), loc); + XLSCC_CHECK_LE(reads.size(), 1); + state_reads_by_decl[decl] = TrackedBValue(reads.front(), &pb); } } diff --git a/xls/contrib/xlscc/unit_tests/unit_test.cc b/xls/contrib/xlscc/unit_tests/unit_test.cc index 89461b4ec0..47bea5dbe4 100644 --- a/xls/contrib/xlscc/unit_tests/unit_test.cc +++ b/xls/contrib/xlscc/unit_tests/unit_test.cc @@ -938,7 +938,11 @@ XlsccTestBase::GetStatesByIONodeForFSMProc(std::string_view func_name) { CHECK_EQ(found_proc_with_fsm, nullptr); found_proc_with_fsm = proc.get(); - fsm_state_read = proc->GetStateReadByStateElement(state_element); + absl::Span reads = + proc->GetStateReadsByStateElement(state_element); + CHECK(!reads.empty()); + CHECK_LE(reads.size(), 1); + fsm_state_read = reads.front(); CHECK_NE(found_proc_with_fsm, nullptr); CHECK_NE(fsm_state_read, nullptr); diff --git a/xls/dev_tools/ir_minimizer_main.cc b/xls/dev_tools/ir_minimizer_main.cc index 5b8551336c..8e346104eb 100644 --- a/xls/dev_tools/ir_minimizer_main.cc +++ b/xls/dev_tools/ir_minimizer_main.cc @@ -450,14 +450,16 @@ absl::StatusOr RemoveDeadParameters(FunctionBase* f) { // Replace all uses of invariant state elements (i.e.: ones where // next[i] == param[i]) with a literal of the initial value. Value init_value = invariant->initial_value(); - Node* state_read = p->GetStateReadByStateElement(invariant); absl::btree_set next_values = p->next_values(invariant); for (Next* next : next_values) { XLS_RETURN_IF_ERROR(p->RemoveNode(next)); } - XLS_RETURN_IF_ERROR( - state_read->ReplaceUsesWithNew(init_value).status()); + for (xls::StateRead* state_read : + p->GetStateReadsByStateElement(invariant)) { + XLS_RETURN_IF_ERROR( + state_read->ReplaceUsesWithNew(init_value).status()); + } dead_state_elements.insert(invariant); } diff --git a/xls/ir/BUILD b/xls/ir/BUILD index 8f33c0c9a4..a24296d4cb 100644 --- a/xls/ir/BUILD +++ b/xls/ir/BUILD @@ -931,6 +931,7 @@ cc_test( "@com_google_absl//absl/base", "@com_google_absl//absl/status", "@com_google_absl//absl/strings", + "@com_google_absl//absl/types:span", "@googletest//:gtest", ], ) diff --git a/xls/ir/ir_parser.cc b/xls/ir/ir_parser.cc index eb3f3b3a8e..c67b2917bf 100644 --- a/xls/ir/ir_parser.cc +++ b/xls/ir/ir_parser.cc @@ -1895,9 +1895,11 @@ absl::StatusOr Parser::ParseBody( Proc * source_proc, ParseProc(package, /*outer_attributes=*/{}, &source)); for (StateElement* element : source_proc->StateElements()) { - name_to_value->emplace( - element->name(), - bb->SourceNode(source_proc->GetStateReadByStateElement(element))); + absl::Span reads = + source_proc->GetStateReadsByStateElement(element); + XLS_RET_CHECK_EQ(reads.size(), 1); + name_to_value->emplace(element->name(), + bb->SourceNode(reads.front())); } } else { return absl::InvalidArgumentError(absl::StrFormat( diff --git a/xls/ir/ir_parser_test.cc b/xls/ir/ir_parser_test.cc index 3b4af676b3..9049b833df 100644 --- a/xls/ir/ir_parser_test.cc +++ b/xls/ir/ir_parser_test.cc @@ -27,6 +27,7 @@ #include "absl/strings/ascii.h" #include "absl/strings/str_cat.h" #include "absl/strings/substitute.h" +#include "absl/types/span.h" #include "xls/common/source_location.h" #include "xls/common/status/matchers.h" #include "xls/ir/bits.h" @@ -607,7 +608,9 @@ proc foo( x: bits[32], y: (), z: bits[32], init={42, (), 123}) { XLS_ASSERT_OK_AND_ASSIGN(Proc * proc, package->GetProc("foo")); EXPECT_EQ(proc->GetStateElementCount(), 3); XLS_ASSERT_OK_AND_ASSIGN(StateElement * x, proc->GetStateElementByName("x")); - EXPECT_THAT(proc->GetStateReadByStateElement(x)->predicate(), std::nullopt); + absl::Span reads = proc->GetStateReadsByStateElement(x); + ASSERT_EQ(reads.size(), 1); + EXPECT_THAT(reads.front()->predicate(), std::nullopt); } TEST(IrParserTest, ProcWithPredicatedStateRead) { @@ -626,17 +629,22 @@ proc foo( x: bits[32], y: bits[1], z: bits[32], init={42, 1, 123}) { EXPECT_EQ(proc->GetStateElementCount(), 3); XLS_ASSERT_OK_AND_ASSIGN(StateElement * x, proc->GetStateElementByName("x")); - std::optional x_predicate = - proc->GetStateReadByStateElement(x)->predicate(); + absl::Span reads_x = proc->GetStateReadsByStateElement(x); + ASSERT_EQ(reads_x.size(), 1); + std::optional x_predicate = reads_x.front()->predicate(); ASSERT_TRUE(x_predicate.has_value()); ASSERT_EQ((*x_predicate)->op(), Op::kStateRead); EXPECT_EQ((*x_predicate)->As()->state_element()->name(), "y"); XLS_ASSERT_OK_AND_ASSIGN(StateElement * y, proc->GetStateElementByName("y")); - ASSERT_FALSE(proc->GetStateReadByStateElement(y)->predicate().has_value()); + absl::Span reads_y = proc->GetStateReadsByStateElement(y); + ASSERT_EQ(reads_y.size(), 1); + ASSERT_FALSE(reads_y.front()->predicate().has_value()); XLS_ASSERT_OK_AND_ASSIGN(StateElement * z, proc->GetStateElementByName("z")); - ASSERT_FALSE(proc->GetStateReadByStateElement(z)->predicate().has_value()); + absl::Span reads_z = proc->GetStateReadsByStateElement(z); + ASSERT_EQ(reads_z.size(), 1); + ASSERT_FALSE(reads_z.front()->predicate().has_value()); } TEST(IrParserTest, ParseSendReceiveChannel) { diff --git a/xls/ir/node_util_test.cc b/xls/ir/node_util_test.cc index d6f8e9466f..36628188a6 100644 --- a/xls/ir/node_util_test.cc +++ b/xls/ir/node_util_test.cc @@ -381,7 +381,7 @@ TEST_F(NodeUtilTest, ChannelNodes) { EXPECT_THAT(GetChannelUsedByNode(rcv.node()), IsOkAndHolds(ch0)); EXPECT_THAT(GetChannelUsedByNode(send.node()), IsOkAndHolds(ch1)); - EXPECT_THAT(GetChannelUsedByNode(proc->GetStateRead(0)), + EXPECT_THAT(GetChannelUsedByNode(proc->GetStateReads(0).front()), StatusIs(absl::StatusCode::kNotFound, HasSubstr("No channel associated with node"))); } @@ -435,7 +435,7 @@ TEST_F(NodeUtilTest, ReplaceTupleIndicesWorksWithToken) { // works, we'd need to make an after_all and add the receive's output token to // it after calling ReplaceTupleElementsWith(). XLS_EXPECT_OK(ReplaceTupleElementsWith( - receive_node, {{0, proc->GetStateRead(0)}, {1, lit0}})); + receive_node, {{0, proc->GetStateReads(0).front()}, {1, lit0}})); ExpectIr(proc->DumpIr(), TestName()); } diff --git a/xls/ir/proc.cc b/xls/ir/proc.cc index 3c9eeba358..5f76831b2c 100644 --- a/xls/ir/proc.cc +++ b/xls/ir/proc.cc @@ -214,13 +214,15 @@ absl::Status Proc::RemoveStateElement(int64_t index) { StateElement* old_state_element = GetStateElement(index); auto old_state_read_it = state_reads_.find(old_state_element); XLS_RET_CHECK(old_state_read_it != state_reads_.end()); - if (!old_state_read_it->second->users().empty()) { - return absl::InvalidArgumentError(absl::StrFormat( - "Cannot remove state element %d of proc %s, existing " - "state read %s has uses", - index, name(), old_state_read_it->second->GetNameView())); + for (StateRead* read : old_state_read_it->second) { + if (!read->users().empty()) { + return absl::InvalidArgumentError( + absl::StrFormat("Cannot remove state element %d of proc %s, existing " + "state read %s has uses", + index, name(), read->GetNameView())); + } + XLS_RETURN_IF_ERROR(RemoveNode(read)); } - XLS_RETURN_IF_ERROR(RemoveNode(old_state_read_it->second)); // TODO(allight): This should ideally not need to be done manually. state_reads_.erase(old_state_read_it); @@ -232,11 +234,14 @@ absl::Status Proc::RemoveStateElement(int64_t index) { absl::Status Proc::RemoveAllStateElements() { // TODO(allight): This relies on side tables being valid. For now just let it // go. - for (const auto& [elem, read] : state_reads_) { - if (read != nullptr) { - XLS_RETURN_IF_ERROR(RemoveNode(read)) - << "Cannot remove " << elem->ToString() << " of proc " << name() - << " because read '" << read->ToString() << "' could not be removed."; + for (const auto& [elem, reads] : state_reads_) { + for (StateRead* read : reads) { + if (read != nullptr) { + XLS_RETURN_IF_ERROR(RemoveNode(read)) + << "Cannot remove " << elem->ToString() << " of proc " << name() + << " because read '" << read->ToString() + << "' could not be removed."; + } } XLS_RETURN_IF_ERROR(state_name_uniquer_.ReleaseIdentifier(elem->name())) << "Cannot release name of " << elem->ToString(); @@ -278,7 +283,7 @@ absl::StatusOr Proc::InsertStateElement( MakeNodeWithName( loc, state_element, read_predicate, /*label=*/std::nullopt, state_element->name())); - state_reads_[state_element] = state_read; + state_reads_[state_element].push_back(state_read); if (next_state.has_value()) { if (!ValueConformsToType(init_value, next_state.value()->GetType())) { @@ -351,14 +356,13 @@ absl::StatusOr Proc::Clone( return mapping.at(orig); }; for (StateElement* state_element : StateElements()) { - StateRead* state_read = state_reads_.at(state_element); - XLS_ASSIGN_OR_RETURN( - StateRead * cloned_state_read, - cloned_proc->AppendStateElement( - remap_name(state_name_remapping, state_element->name()), - state_element->initial_value(), state_read->predicate(), - /*next_state=*/std::nullopt)); - original_to_clone[state_read] = cloned_state_read; + XLS_RETURN_IF_ERROR( + cloned_proc + ->InsertUnreadStateElement( + cloned_proc->GetStateElementCount(), + remap_name(state_name_remapping, state_element->name()), + state_element->initial_value()) + .status()); } if (is_new_style_proc()) { absl::flat_hash_map channel_map; @@ -445,7 +449,23 @@ absl::StatusOr Proc::Clone( switch (node->op()) { case Op::kStateRead: { - continue; + StateRead* src = node->As(); + StateElement* src_elem = src->state_element(); + XLS_ASSIGN_OR_RETURN(int64_t idx, GetStateElementIndex(src_elem)); + StateElement* cloned_elem = cloned_proc->GetStateElement(idx); + + std::optional cloned_predicate; + if (src->predicate().has_value()) { + cloned_predicate = original_to_clone.at(src->predicate().value()); + } + + XLS_ASSIGN_OR_RETURN(StateRead * cloned_state_read, + cloned_proc->MakeNodeWithName( + src->loc(), cloned_elem, cloned_predicate, + /*label=*/std::nullopt, cloned_elem->name())); + cloned_proc->state_reads_[cloned_elem].push_back(cloned_state_read); + original_to_clone[node] = cloned_state_read; + break; } case Op::kReceive: { Receive* src = node->As(); @@ -1000,10 +1020,8 @@ absl::Status Proc::InternalRebuildSideTables() { state_reads_.clear(); for (Node* n : nodes()) { if (n->Is()) { - XLS_RET_CHECK(!state_reads_.contains(n->As()->state_element())) - << "Duplicate state element read: " - << n->As()->state_element(); - state_reads_[n->As()->state_element()] = n->As(); + state_reads_[n->As()->state_element()].push_back( + n->As()); } else if (n->Is()) { next_values_.push_back(n->As()); next_values_by_state_element_[n->As()->state_element()].insert( diff --git a/xls/ir/proc.h b/xls/ir/proc.h index 6a3384bf63..371ddb72c2 100644 --- a/xls/ir/proc.h +++ b/xls/ir/proc.h @@ -95,10 +95,23 @@ class Proc : public FunctionBase { return state_elements_.contains(name); } + // Remove legacy getters after all downstream passes migrate logic. StateRead* GetStateRead(int64_t index) const { - return state_reads_.at(GetStateElement(index)); + return GetStateReads(index).front(); } + StateRead* GetStateReadByStateElement(StateElement* state_element) const { + return GetStateReadsByStateElement(state_element).front(); + } + + // Get state reads for a state element at the given index. + absl::Span GetStateReads(int64_t index) const { + return state_reads_.at(GetStateElement(index)); + } + + // Get state reads for a state element. + absl::Span GetStateReadsByStateElement( + StateElement* state_element) const { return state_reads_.at(state_element); } @@ -403,8 +416,8 @@ class Proc : public FunctionBase { absl::flat_hash_map> state_elements_; - // Map of the unique StateRead node for each state element. - absl::flat_hash_map state_reads_; + // Map of StateRead nodes for each state element. + absl::flat_hash_map> state_reads_; // Vector of state element pointers. Kept in sync with the state_elements_ // map. Enables easy, stable iteration over state elements. With this vector, diff --git a/xls/ir/proc_test.cc b/xls/ir/proc_test.cc index dd1e401dff..c9ae241ccd 100644 --- a/xls/ir/proc_test.cc +++ b/xls/ir/proc_test.cc @@ -205,6 +205,43 @@ TEST_F(ProcTest, StatelessProc) { EXPECT_EQ(proc->DumpIr(), "proc p() {\n}\n"); } +TEST_F(ProcTest, MultipleStateReads) { + auto p = CreatePackage(); + ProcBuilder pb("p", p.get()); + BValue tkn = pb.StateElement("tkn", Value::Token()); + BValue state = pb.StateElement("x", Value(UBits(42, 32))); + XLS_ASSERT_OK_AND_ASSIGN(Proc * proc, pb.Build({tkn, state})); + + StateElement* state_elem = proc->GetStateElement(1); + + EXPECT_EQ(proc->GetStateReads(1).size(), 1); + StateRead* read1 = proc->GetStateReads(1).front(); + + // Second Read + XLS_ASSERT_OK_AND_ASSIGN( + StateRead * read2, + proc->MakeNodeWithName(SourceInfo(), state_elem, + /*predicate=*/std::nullopt, + /*label=*/std::nullopt, "x_read2")); + XLS_ASSERT_OK(proc->RebuildSideTables()); + + EXPECT_EQ(proc->GetStateReads(1).size(), 2); + EXPECT_THAT(proc->GetStateReads(1), ElementsAre(read1, read2)); + + EXPECT_EQ(proc->GetStateReadsByStateElement(state_elem).size(), 2); + EXPECT_THAT(proc->GetStateReadsByStateElement(state_elem), + ElementsAre(read1, read2)); + + // Remove the second read. + std::string read2_name = read2->GetName(); + XLS_ASSERT_OK(proc->RemoveNode(read2)); + XLS_ASSERT_OK(proc->RebuildSideTables()); + + // Now we should have 1 read again. + EXPECT_EQ(proc->GetStateReads(1).size(), 1); + EXPECT_EQ(proc->GetStateReads(1).front(), read1); +} + TEST_F(ProcTest, RemoveStateThatStillHasUse) { // Don't call CreatePackage which creates a VerifiedPackage because we // intentionally create a malformed proc. @@ -254,10 +291,10 @@ TEST_F(ProcTest, Clone) { EXPECT_EQ(clone->DumpIr(), R"(proc cloned(tkn: token, state: bits[32], init={token, 42}) { tkn: token = state_read(state_element=tkn, id=12) - literal.14: bits[32] = literal(value=1, id=14) - state: bits[32] = state_read(state_element=state, id=13) + literal.13: bits[32] = literal(value=1, id=13) + state: bits[32] = state_read(state_element=state, id=14) receive_3: (token, bits[32]) = receive(tkn, channel=cloned_chan, id=15) - add.16: bits[32] = add(literal.14, state, id=16) + add.16: bits[32] = add(literal.13, state, id=16) tuple_index.17: bits[32] = tuple_index(receive_3, index=1, id=17) tuple_index.18: token = tuple_index(receive_3, index=0, id=18) add.19: bits[32] = add(add.16, tuple_index.17, id=19) @@ -304,10 +341,10 @@ proc cloned(tkn: token, state: bits chan_interface input_chan(direction=receive, kind=streaming, strictness=proven_mutually_exclusive, flow_control=ready_valid, flop_kind=none) chan_interface chan(direction=send, kind=streaming, strictness=proven_mutually_exclusive, flow_control=ready_valid, flop_kind=none) tkn: token = state_read(state_element=tkn, id=1) - literal.3: bits[32] = literal(value=1, id=3) - state: bits[32] = state_read(state_element=state, id=2) + literal.2: bits[32] = literal(value=1, id=2) + state: bits[32] = state_read(state_element=state, id=3) receive_3: (token, bits[32]) = receive(tkn, channel=input_chan, id=4) - add.5: bits[32] = add(literal.3, state, id=5) + add.5: bits[32] = add(literal.2, state, id=5) tuple_index.6: bits[32] = tuple_index(receive_3, index=1, id=6) tuple_index.7: token = tuple_index(receive_3, index=0, id=7) add.8: bits[32] = add(add.5, tuple_index.6, id=8) @@ -355,15 +392,15 @@ TEST_F(ProcTest, CloneNewStyle) { chan baz(bits[32], id=0, kind=streaming, ops=send_receive, flow_control=ready_valid, strictness=proven_mutually_exclusive) chan_interface baz(direction=send, kind=streaming, strictness=proven_mutually_exclusive, flow_control=none, flop_kind=none) chan_interface baz(direction=receive, kind=streaming, strictness=proven_mutually_exclusive, flow_control=none, flop_kind=none) - tkn: token = literal(value=token, id=14) - receive_3: (token, bits[32]) = receive(tkn, channel=foo, id=15) - tuple_index.16: token = tuple_index(receive_3, index=0, id=16) - receive_6: (token, bits[32]) = receive(tuple_index.16, channel=baz, id=17) - tuple_index.18: token = tuple_index(receive_6, index=0, id=18) - state: bits[32] = state_read(state_element=state, id=13) + tkn: token = literal(value=token, id=13) + receive_3: (token, bits[32]) = receive(tkn, channel=foo, id=14) + tuple_index.15: token = tuple_index(receive_3, index=0, id=15) + receive_6: (token, bits[32]) = receive(tuple_index.15, channel=baz, id=16) + tuple_index.17: token = tuple_index(receive_6, index=0, id=17) + state: bits[32] = state_read(state_element=state, id=18) tuple_index.19: bits[32] = tuple_index(receive_3, index=1, id=19) tuple_index.20: bits[32] = tuple_index(receive_6, index=1, id=20) - send_9: token = send(tuple_index.18, state, channel=bar, id=21) + send_9: token = send(tuple_index.17, state, channel=bar, id=21) add.22: bits[32] = add(tuple_index.19, tuple_index.20, id=22) send_10: token = send(send_9, state, channel=baz, id=23) next_value.24: () = next_value(param=state, value=add.22, id=24) @@ -556,7 +593,8 @@ TEST_F(ScheduledProcTest, StageAddAndClear) { proc->ClearStages(); EXPECT_TRUE(proc->stages().empty()); // Re-stage the state element to satisfy the verifier. - XLS_ASSERT_OK(proc->AddNodeToStage(0, proc->GetStateRead(0)).status()); + XLS_ASSERT_OK( + proc->AddNodeToStage(0, proc->GetStateReads(0).front()).status()); } TEST_F(ScheduledProcTest, AddEmptyStages) { @@ -596,7 +634,8 @@ TEST_F(ScheduledProcTest, GetStageIndex) { EXPECT_THAT(proc->GetStageIndex(x), IsOkAndHolds(1)); EXPECT_THAT(proc->GetStageIndex(y), IsOkAndHolds(2)); EXPECT_THAT(proc->GetStageIndex(add), StatusIs(absl::StatusCode::kNotFound)); - EXPECT_THAT(proc->GetStageIndex(proc->GetStateRead(0)), IsOkAndHolds(0)); + EXPECT_THAT(proc->GetStageIndex(proc->GetStateReads(0).front()), + IsOkAndHolds(0)); // The verifier requires that every node be in a stage before we finish. ASSERT_THAT(proc->AddNodeToStage(2, add), IsOkAndHolds(true)); diff --git a/xls/ir/proc_testutils.cc b/xls/ir/proc_testutils.cc index 1227d27849..71f6a8f3eb 100644 --- a/xls/ir/proc_testutils.cc +++ b/xls/ir/proc_testutils.cc @@ -244,11 +244,14 @@ absl::StatusOr> GetStateValuesBeforeActivation( absl::flat_hash_map& values) { std::vector states; for (StateElement* state_element : p->StateElements()) { - StateRead* state_read = p->GetStateReadByStateElement(state_element); + absl::Span reads = + p->GetStateReadsByStateElement(state_element); + XLS_RET_CHECK(!reads.empty()) << "No reads for " << state_element; + + BValue state_val; if (activation == 0) { - values[{state_read, 0}] = - fb.Literal(state_element->initial_value(), SourceInfo(), - absl::StrFormat("%s_initial_value", p->name())); + state_val = fb.Literal(state_element->initial_value(), SourceInfo(), + absl::StrFormat("%s_initial_value", p->name())); } else { std::vector cases; std::vector selectors; @@ -261,23 +264,26 @@ absl::StatusOr> GetStateValuesBeforeActivation( } if (selectors.empty()) { XLS_RET_CHECK_EQ(cases.size(), 1) << "no cases for " << state_element; - values[{state_read, activation}] = cases.front(); + state_val = cases.front(); } else if (cases.front().GetType()->IsBits() && cases.front().GetType()->GetFlatBitCount() == 0) { // Special case to avoid creating non-trivial uses of zero-len bit // vectors. - values[{state_read, activation}] = fb.Literal(UBits(0, 0)); + state_val = fb.Literal(UBits(0, 0)); } else { XLS_RET_CHECK_EQ(cases.size(), selectors.size()); // materialize the next values into a select. // Need to reverse to keep the LSB is case 0 etc. absl::c_reverse(selectors); - values[{state_read, activation}] = fb.PrioritySelect( + state_val = fb.PrioritySelect( fb.Concat(selectors), cases, - /*default_value=*/values[{state_read, activation - 1}]); + /*default_value=*/values[{reads.front(), activation - 1}]); } } - states.push_back(values[{state_read, activation}]); + for (StateRead* read : reads) { + values[{read, activation}] = state_val; + } + states.push_back(state_val); } return states; } diff --git a/xls/jit/function_base_jit.cc b/xls/jit/function_base_jit.cc index 4a37e41879..d1f677be57 100644 --- a/xls/jit/function_base_jit.cc +++ b/xls/jit/function_base_jit.cc @@ -118,6 +118,12 @@ llvm::Value* LoadPointerFromPointerArray(int64_t index, // functions, procs, etc, as well as for the partition functions called from // within the jitted functions. // +// Information about an input/output slot in the JIT function signature. +struct JitInput { + std::vector nodes; // All nodes mapping to this slot. + Type* type; // The type of the slot. +}; + // `input_args` are the Nodes whose values are passed in the `inputs` function // argument. `output_args` are Nodes whose values are written out to buffers // indicated by the `outputs` function argument. @@ -132,8 +138,8 @@ class LlvmFunctionWrapper final : public JitCompilationMetadata { // specified will be added to the function signature after all other // args. static LlvmFunctionWrapper Create( - std::string_view name, absl::Span input_args, - absl::Span output_args, llvm::Type* return_type, + std::string_view name, absl::Span input_args, + absl::Span output_args, llvm::Type* return_type, const JitBuilderContext& jit_context, std::optional extra_arg = std::nullopt) { llvm::Type* ptr_type = llvm::PointerType::get(jit_context.context(), 0); @@ -262,16 +268,18 @@ class LlvmFunctionWrapper final : public JitCompilationMetadata { } private: - LlvmFunctionWrapper(absl::Span input_args, - absl::Span output_args) - : input_args_(input_args.begin(), input_args.end()), - output_args_(output_args.begin(), output_args.end()) { + LlvmFunctionWrapper(absl::Span input_args, + absl::Span output_args) { for (int64_t i = 0; i < input_args.size(); ++i) { - CHECK(!input_indices_.contains(input_args[i])); - input_indices_[input_args[i]] = i; + for (Node* node : input_args[i].nodes) { + CHECK(!input_indices_.contains(node)); + input_indices_[node] = i; + } } for (int64_t i = 0; i < output_args.size(); ++i) { - output_indices_[output_args[i]].push_back(i); + for (Node* node : output_args[i].nodes) { + output_indices_[node].push_back(i); + } } } @@ -279,8 +287,6 @@ class LlvmFunctionWrapper final : public JitCompilationMetadata { llvm::FunctionType* fn_type_; std::unique_ptr> entry_builder_; - std::vector input_args_; - std::vector output_args_; absl::flat_hash_map input_indices_; absl::flat_hash_map> output_indices_; }; @@ -554,8 +560,8 @@ Type* OutputType(const Node* node) { // buffers are passed in via the `input`/`output` arguments of the function. absl::StatusOr BuildPartitionFunction( std::string_view name, const Partition& partition, - absl::Span global_input_nodes, - absl::Span global_output_nodes, + absl::Span global_input_nodes, + absl::Span global_output_nodes, const BufferAllocator& allocator, JitBuilderContext& jit_context) { LlvmFunctionWrapper wrapper = LlvmFunctionWrapper::Create( name, global_input_nodes, global_output_nodes, @@ -723,37 +729,48 @@ absl::Status AllocateBuffers(absl::Span partitions, // Returns the nodes which comprise the inputs to a jitted function implementing // `function_base`. These nodes are passed in via the `inputs` argument. -std::vector GetJittedFunctionInputs(FunctionBase* function_base) { +std::vector GetJittedFunctionInputs(FunctionBase* function_base) { if (function_base->IsBlock()) { Block* block = function_base->AsBlockOrDie(); - std::vector out; + std::vector out; out.reserve(block->GetInputPorts().size() + block->GetRegisters().size()); - absl::c_copy(block->GetInputPorts(), std::back_inserter(out)); - absl::c_transform( - block->GetRegisters(), std::back_inserter(out), - [&](Register* r) -> Node* { return *block->GetRegisterRead(r); }); + for (Node* p : block->GetInputPorts()) { + out.push_back(JitInput{{p}, p->GetType()}); + } + for (Register* r : block->GetRegisters()) { + Node* rr = *block->GetRegisterRead(r); + out.push_back(JitInput{{rr}, rr->GetType()}); + } return out; } if (function_base->IsProc()) { Proc* proc = function_base->AsProcOrDie(); - std::vector out; - absl::c_transform( - proc->StateElements(), std::back_inserter(out), - [&](StateElement* st) { return proc->GetStateReadByStateElement(st); }); + std::vector out; + out.reserve(proc->StateElements().size()); + for (StateElement* st : proc->StateElements()) { + std::vector reads; + for (StateRead* sr : proc->GetStateReadsByStateElement(st)) { + reads.push_back(sr); + } + out.push_back(JitInput{reads, st->type()}); + } return out; } - std::vector inputs(function_base->params().begin(), - function_base->params().end()); - return inputs; + std::vector out; + out.reserve(function_base->params().size()); + for (Node* param : function_base->params()) { + out.push_back(JitInput{{param}, param->GetType()}); + } + return out; } // Returns the nodes whose values are passed out of a jitted function. Buffers // to hold these node values are passed in via the `outputs` argument. -std::vector GetJittedFunctionOutputs(FunctionBase* function_base) { +std::vector GetJittedFunctionOutputs(FunctionBase* function_base) { if (function_base->IsFunction()) { // The output of a function is its return value. Function* f = function_base->AsFunctionOrDie(); - return {f->return_value()}; + return {JitInput{{f->return_value()}, f->return_value()->GetType()}}; } if (function_base->IsBlock()) { // Order of block outputs is: @@ -762,15 +779,18 @@ std::vector GetJittedFunctionOutputs(FunctionBase* function_base) { // (3) Second, and later RegisterWrites of each register (if any). // Multiple RegisterWrites are reconciled at the end of each cycle. Block* block = function_base->AsBlockOrDie(); - std::vector out; + std::vector out; out.reserve(block->GetOutputPorts().size() + block->GetRegisters().size()); - absl::c_copy(block->GetOutputPorts(), std::back_inserter(out)); + for (Node* p : block->GetOutputPorts()) { + out.push_back(JitInput{{p}, p->GetType()}); + } for (Register* reg : block->GetRegisters()) { - out.push_back(block->GetRegisterWrites(reg)->front()); + out.push_back( + JitInput{{block->GetRegisterWrites(reg)->front()}, reg->type()}); } for (Register* reg : block->GetRegisters()) { for (RegisterWrite* rw : block->GetRegisterWrites(reg)->subspan(1)) { - out.push_back(rw); + out.push_back(JitInput{{rw}, reg->type()}); } } return out; @@ -778,11 +798,15 @@ std::vector GetJittedFunctionOutputs(FunctionBase* function_base) { // The outputs of a proc are the next state values - which will be stored in // the memory locations for the state reads. Proc* proc = function_base->AsProcOrDie(); - std::vector outputs; + std::vector outputs; outputs.reserve(proc->StateElements().size()); - absl::c_transform( - proc->StateElements(), std::back_inserter(outputs), - [&](StateElement* st) { return proc->GetStateReadByStateElement(st); }); + for (StateElement* st : proc->StateElements()) { + std::vector reads; + for (StateRead* sr : proc->GetStateReadsByStateElement(st)) { + reads.push_back(sr); + } + outputs.push_back(JitInput{reads, st->type()}); + } return outputs; } @@ -846,8 +870,8 @@ absl::StatusOr BuildFunctionInternal( // have each function assign its own tmp buffer starting from 0 and make the // overall tmp-buffer the topo sort. std::string base_name = jit_context.MangleFunctionName(xls_function); - std::vector inputs = GetJittedFunctionInputs(xls_function); - std::vector outputs = GetJittedFunctionOutputs(xls_function); + std::vector inputs = GetJittedFunctionInputs(xls_function); + std::vector outputs = GetJittedFunctionOutputs(xls_function); LlvmFunctionWrapper wrapper = LlvmFunctionWrapper::Create( base_name, inputs, outputs, llvm::Type::getInt64Ty(jit_context.context()), jit_context, @@ -1174,14 +1198,12 @@ absl::StatusOr BuildPackedWrapper( FunctionBase* xls_function, llvm::Function* callee, JitBuilderContext& jit_context) { llvm::LLVMContext* context = &jit_context.context(); - std::vector inputs = GetJittedFunctionInputs(xls_function); - std::vector outputs = GetJittedFunctionOutputs(xls_function); + std::vector inputs = GetJittedFunctionInputs(xls_function); + std::vector outputs = GetJittedFunctionOutputs(xls_function); LlvmFunctionWrapper wrapper = LlvmFunctionWrapper::Create( absl::StrFormat("%s_packed", jit_context.MangleFunctionName(xls_function)), - GetJittedFunctionInputs(xls_function), - GetJittedFunctionOutputs(xls_function), llvm::Type::getInt64Ty(*context), - jit_context, + inputs, outputs, llvm::Type::getInt64Ty(*context), jit_context, LlvmFunctionWrapper::FunctionArg{ .name = "continuation_point", .type = llvm::Type::getInt64Ty(*context)}); @@ -1203,9 +1225,9 @@ absl::StatusOr BuildPackedWrapper( llvm::Type* pointer_array_type = llvm::ArrayType::get(llvm::PointerType::getUnqual(*context), 0); for (int64_t i = 0; i < inputs.size(); ++i) { - Node* input = inputs[i]; + const JitInput& input = inputs[i]; llvm::Value* input_buffer = wrapper.entry_builder().CreateAlloca( - jit_context.type_converter().ConvertToLlvmType(input->GetType())); + jit_context.type_converter().ConvertToLlvmType(input.type)); llvm::Value* gep = wrapper.entry_builder().CreateGEP( pointer_array_type, input_arg_array, { @@ -1214,12 +1236,12 @@ absl::StatusOr BuildPackedWrapper( }); wrapper.entry_builder().CreateStore(input_buffer, gep); - if (input->GetType()->GetFlatBitCount() > 0) { + if (input.type->GetFlatBitCount() > 0) { llvm::Value* packed_buffer = LoadPointerFromPointerArray( i, wrapper.GetInputsArg(), &wrapper.entry_builder()); - XLS_RETURN_IF_ERROR(UnpackValue( - packed_buffer, input_buffer, input->GetType(), /*bit_offset=*/0, - jit_context.type_converter(), &wrapper.entry_builder())); + XLS_RETURN_IF_ERROR( + UnpackValue(packed_buffer, input_buffer, input.type, /*bit_offset=*/0, + jit_context.type_converter(), &wrapper.entry_builder())); } } @@ -1227,9 +1249,9 @@ absl::StatusOr BuildPackedWrapper( wrapper.entry_builder().CreateAlloca(llvm::ArrayType::get( llvm::PointerType::get(*context, 0), outputs.size())); for (int64_t i = 0; i < outputs.size(); ++i) { - Node* output = outputs[i]; + const JitInput& output = outputs[i]; llvm::Value* output_buffer = wrapper.entry_builder().CreateAlloca( - jit_context.type_converter().ConvertToLlvmType(OutputType(output))); + jit_context.type_converter().ConvertToLlvmType(output.type)); llvm::Value* gep = wrapper.entry_builder().CreateGEP( pointer_array_type, output_arg_array, { @@ -1253,7 +1275,7 @@ absl::StatusOr BuildPackedWrapper( // After returning, pack the value into the return value buffer. for (int64_t i = 0; i < outputs.size(); ++i) { - Node* output = outputs[i]; + const JitInput& output = outputs[i]; // Declare the return argument as an iX, and pack the actual data as such // an integer. @@ -1262,10 +1284,10 @@ absl::StatusOr BuildPackedWrapper( llvm::Value* packed_output_buffer = LoadPointerFromPointerArray( i, wrapper.GetOutputsArg(), &wrapper.entry_builder()); - XLS_RETURN_IF_ERROR(PackValue( - unpacked_output_buffer, packed_output_buffer, OutputType(output), - /*bit_offset=*/0, jit_context.type_converter(), - &wrapper.entry_builder())); + XLS_RETURN_IF_ERROR( + PackValue(unpacked_output_buffer, packed_output_buffer, output.type, + /*bit_offset=*/0, jit_context.type_converter(), + &wrapper.entry_builder())); } // Return value of zero means that the FunctionBase completed execution. @@ -1382,15 +1404,13 @@ absl::StatusOr JittedFunctionBase::BuildInternal( } } - for (const Node* input : GetJittedFunctionInputs(xls_function)) { - Type* input_type = InputType(input); + for (const JitInput& input : GetJittedFunctionInputs(xls_function)) { jitted_function.input_buffer_metadata_.push_back( - jit_context.type_converter().GetTypeBufferMetadata(input_type)); + jit_context.type_converter().GetTypeBufferMetadata(input.type)); } - for (const Node* output : GetJittedFunctionOutputs(xls_function)) { - Type* output_type = OutputType(output); + for (const JitInput& output : GetJittedFunctionOutputs(xls_function)) { jitted_function.output_buffer_metadata_.push_back( - jit_context.type_converter().GetTypeBufferMetadata(output_type)); + jit_context.type_converter().GetTypeBufferMetadata(output.type)); } jitted_function.temp_buffer_size_ = allocator.size(); jitted_function.temp_buffer_alignment_ = allocator.alignment(); diff --git a/xls/passes/BUILD b/xls/passes/BUILD index b2654fa65c..42dc97e2fa 100644 --- a/xls/passes/BUILD +++ b/xls/passes/BUILD @@ -2555,6 +2555,7 @@ xls_pass( "@com_google_absl//absl/log", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", + "@com_google_absl//absl/types:span", ], ) diff --git a/xls/passes/array_untuple_pass.cc b/xls/passes/array_untuple_pass.cc index b07c45cfda..76e095c707 100644 --- a/xls/passes/array_untuple_pass.cc +++ b/xls/passes/array_untuple_pass.cc @@ -118,17 +118,26 @@ absl::StatusOr> FindExternalGroups( // Don't mess with params that are only used in identity updates. Would // infinite loop otherwise since we don't remove these very often. for (StateElement* state_element : f->AsProcOrDie()->StateElements()) { - StateRead* state_read = - f->AsProcOrDie()->GetStateReadByStateElement(state_element); - if (absl::c_all_of(state_read->users(), [&](Node* n) -> bool { - if (n->Is()) { - Next* nxt = n->As(); - return nxt->state_read() == nxt->value() && - nxt->state_read() == state_read; - } - return false; - })) { - excluded.insert(groups.Find(state_read)); + absl::Span state_reads = + f->AsProcOrDie()->GetStateReadsByStateElement(state_element); + bool all_reads_identity = true; + for (StateRead* state_read : state_reads) { + if (!absl::c_all_of(state_read->users(), [&](Node* n) -> bool { + if (n->Is()) { + Next* nxt = n->As(); + return nxt->state_read() == nxt->value() && + nxt->state_read() == state_read; + } + return false; + })) { + all_reads_identity = false; + break; + } + } + if (all_reads_identity) { + for (StateRead* state_read : state_reads) { + excluded.insert(groups.Find(state_read)); + } } } } diff --git a/xls/passes/canonicalization_pass_test.cc b/xls/passes/canonicalization_pass_test.cc index bc6ce425eb..a1544ba1f8 100644 --- a/xls/passes/canonicalization_pass_test.cc +++ b/xls/passes/canonicalization_pass_test.cc @@ -348,10 +348,11 @@ TEST_F(CanonicalizePassTest, StateReadWithAlwaysTruePredicate) { /*read_predicate=*/pb.Literal(UBits(1, 1))); pb.Next(x, pb.Literal(UBits(1, 32))); XLS_ASSERT_OK_AND_ASSIGN(Proc * proc, pb.Build()); - EXPECT_THAT(proc->GetStateRead(0)->predicate(), Optional(m::Literal(1))); + EXPECT_THAT(proc->GetStateReads(0).front()->predicate(), + Optional(m::Literal(1))); EXPECT_THAT(Run(p.get()), IsOkAndHolds(true)); - EXPECT_EQ(proc->GetStateRead(0)->predicate(), std::nullopt); + EXPECT_EQ(proc->GetStateReads(0).front()->predicate(), std::nullopt); } void IrFuzzCanonicalization(FuzzPackageWithArgs fuzz_package_with_args) { diff --git a/xls/passes/conditional_specialization_pass_test.cc b/xls/passes/conditional_specialization_pass_test.cc index 73ac38a809..7c9f7b19a2 100644 --- a/xls/passes/conditional_specialization_pass_test.cc +++ b/xls/passes/conditional_specialization_pass_test.cc @@ -1266,14 +1266,16 @@ TEST_F(ConditionalSpecializationPassTest, StateReadSpecialization) { Run(proc, /*use_bdd=*/true, /*optimize_for_best_case_throughput=*/true), IsOkAndHolds(true)); - EXPECT_THAT( - proc->GetStateReadByStateElement(*proc->GetStateElementByName("counter0")) - ->predicate(), - Optional(m::Not(m::StateRead("index")))); - EXPECT_THAT( - proc->GetStateReadByStateElement(*proc->GetStateElementByName("counter1")) - ->predicate(), - Optional(m::StateRead("index"))); + EXPECT_THAT(proc->GetStateReadsByStateElement( + *proc->GetStateElementByName("counter0")) + .front() + ->predicate(), + Optional(m::Not(m::StateRead("index")))); + EXPECT_THAT(proc->GetStateReadsByStateElement( + *proc->GetStateElementByName("counter1")) + .front() + ->predicate(), + Optional(m::StateRead("index"))); } TEST_F(ConditionalSpecializationPassTest, HarderStateReadSpecialization) { @@ -1303,16 +1305,20 @@ TEST_F(ConditionalSpecializationPassTest, HarderStateReadSpecialization) { Run(proc, /*use_bdd=*/true, /*optimize_for_best_case_throughput=*/true), IsOkAndHolds(true)); + EXPECT_THAT(proc->GetStateReadsByStateElement( + *proc->GetStateElementByName("counter0")) + .front() + ->predicate(), + Optional(m::Eq(m::StateRead("index"), m::Literal(0)))); + EXPECT_THAT(proc->GetStateReadsByStateElement( + *proc->GetStateElementByName("counter1")) + .front() + ->predicate(), + Optional(m::Eq(m::StateRead("index"), m::Literal(1)))); EXPECT_THAT( - proc->GetStateReadByStateElement(*proc->GetStateElementByName("counter0")) - ->predicate(), - Optional(m::Eq(m::StateRead("index"), m::Literal(0)))); - EXPECT_THAT( - proc->GetStateReadByStateElement(*proc->GetStateElementByName("counter1")) - ->predicate(), - Optional(m::Eq(m::StateRead("index"), m::Literal(1)))); - EXPECT_THAT( - proc->GetStateReadByStateElement(*proc->GetStateElementByName("counter2")) + proc->GetStateReadsByStateElement( + *proc->GetStateElementByName("counter2")) + .front() ->predicate(), Optional(m::And( // High bit of index is set diff --git a/xls/passes/proc_state_array_flattening_pass.cc b/xls/passes/proc_state_array_flattening_pass.cc index 01af175c31..f68f2272e1 100644 --- a/xls/passes/proc_state_array_flattening_pass.cc +++ b/xls/passes/proc_state_array_flattening_pass.cc @@ -20,6 +20,7 @@ #include "absl/log/log.h" #include "absl/status/statusor.h" #include "absl/strings/str_cat.h" +#include "absl/types/span.h" #include "xls/common/math_util.h" #include "xls/common/status/ret_check.h" #include "xls/common/status/status_macros.h" @@ -131,10 +132,19 @@ absl::StatusOr SimplifyProcState(Proc* proc, Value new_init_value = Value::Tuple(old_init_value.elements()); ArrayToTupleStateTransformer transformer; - XLS_RETURN_IF_ERROR(proc->TransformStateElement( - proc->GetStateReadByStateElement(state_element), - new_init_value, transformer) - .status()); + absl::Span state_reads = + proc->GetStateReadsByStateElement(state_element); + XLS_RET_CHECK_LE(state_reads.size(), 1) + << "ProcStateArrayFlatteningPass only supports at most one StateRead per " + "StateElement for now."; + if (state_reads.empty()) { + VLOG(3) << "Not flattening proc state with no reads: " + << state_element->ToString(); + return false; + } + XLS_RETURN_IF_ERROR( + proc->TransformStateElement(state_reads[0], new_init_value, transformer) + .status()); std::vector old_next_values(proc->next_values(state_element).begin(), proc->next_values(state_element).end()); diff --git a/xls/passes/proc_state_bits_shattering_pass.cc b/xls/passes/proc_state_bits_shattering_pass.cc index 3e0405aed5..d691a299a2 100644 --- a/xls/passes/proc_state_bits_shattering_pass.cc +++ b/xls/passes/proc_state_bits_shattering_pass.cc @@ -116,7 +116,15 @@ absl::StatusOr MaybeSplitStateElements( // to use STL set intersection algorithms. std::vector split_ends; bool could_benefit_from_splitting = false; - StateRead* state_read = proc->GetStateReadByStateElement(state_element); + absl::Span state_reads = + proc->GetStateReadsByStateElement(state_element); + XLS_RET_CHECK_LE(state_reads.size(), 1) + << "ProcStateBitsShatteringPass only supports at most one StateRead " + "per StateElement for now."; + if (state_reads.empty()) { + continue; + } + StateRead* state_read = state_reads[0]; for (Next* next : proc->next_values(state_element)) { if (next->value() == state_read) { // This is a no-op next-value; it doesn't affect whether or not it's diff --git a/xls/passes/proc_state_narrowing_pass.cc b/xls/passes/proc_state_narrowing_pass.cc index 931b413725..5196685996 100644 --- a/xls/passes/proc_state_narrowing_pass.cc +++ b/xls/passes/proc_state_narrowing_pass.cc @@ -161,7 +161,15 @@ absl::StatusOr ProcStateNarrowingPass::RunOnProcInternal( << state_element->type()->ToString(); continue; } - StateRead* state_read = proc->GetStateReadByStateElement(state_element); + absl::Span state_reads = + proc->GetStateReadsByStateElement(state_element); + XLS_RET_CHECK_LE(state_reads.size(), 1) + << "ProcStateNarrowingPass only supports at most one StateRead per " + "StateElement for now."; + if (state_reads.empty()) { + continue; + } + StateRead* state_read = state_reads[0]; std::optional> ternary = qe.GetTernary(state_read); if (!ternary) { diff --git a/xls/passes/proc_state_optimization_pass.cc b/xls/passes/proc_state_optimization_pass.cc index 75666a6ece..9bd352e056 100644 --- a/xls/passes/proc_state_optimization_pass.cc +++ b/xls/passes/proc_state_optimization_pass.cc @@ -79,7 +79,6 @@ absl::StatusOr RemoveZeroWidthStateElements(Proc* proc) { StateElement* state_element = proc->GetStateElement(i); VLOG(2) << "Removing zero-width state element: " << proc->GetStateElement(i)->name(); - StateRead* state_read = proc->GetStateReadByStateElement(state_element); std::vector next_values(proc->next_values(state_element).begin(), proc->next_values(state_element).end()); for (Next* next : next_values) { @@ -87,9 +86,13 @@ absl::StatusOr RemoveZeroWidthStateElements(Proc* proc) { next->ReplaceUsesWithNew(Value::Tuple({})).status()); XLS_RETURN_IF_ERROR(proc->RemoveNode(next)); } - XLS_RETURN_IF_ERROR( - state_read->ReplaceUsesWithNew(state_element->initial_value()) - .status()); + for (StateRead* state_read : + proc->GetStateReadsByStateElement(state_element)) { + XLS_RETURN_IF_ERROR( + state_read + ->ReplaceUsesWithNew(state_element->initial_value()) + .status()); + } VLOG(4) << "Removing state element " << proc->StateElements()[i] << " for being zero width."; XLS_RETURN_IF_ERROR(proc->RemoveStateElement(i)); @@ -102,12 +105,20 @@ absl::StatusOr RemoveConstantStateElements(Proc* proc, std::vector to_remove; for (int64_t i = proc->GetStateElementCount() - 1; i >= 0; --i) { StateElement* state_element = proc->GetStateElement(i); - StateRead* state_read = proc->GetStateReadByStateElement(state_element); + absl::Span reads = + proc->GetStateReadsByStateElement(state_element); const Value& initial_value = state_element->initial_value(); bool never_changes = true; for (Next* next : proc->next_values(state_element)) { - if (next->value() == state_read) { + bool is_read = false; + for (StateRead* read : reads) { + if (next->value() == read) { + is_read = true; + break; + } + } + if (is_read) { continue; } std::optional next_value = query_engine.KnownValue(next->value()); @@ -128,7 +139,6 @@ absl::StatusOr RemoveConstantStateElements(Proc* proc, Value value = state_element->initial_value(); VLOG(2) << "Removing constant state element: " << state_element->name() << " (value: " << value.ToString() << ")"; - StateRead* state_read = proc->GetStateReadByStateElement(state_element); std::vector next_values(proc->next_values(state_element).begin(), proc->next_values(state_element).end()); for (Next* next : next_values) { @@ -136,8 +146,11 @@ absl::StatusOr RemoveConstantStateElements(Proc* proc, next->ReplaceUsesWithNew(Value::Tuple({})).status()); XLS_RETURN_IF_ERROR(proc->RemoveNode(next)); } - XLS_RETURN_IF_ERROR( - state_read->ReplaceUsesWithNew(value).status()); + for (StateRead* state_read : + proc->GetStateReadsByStateElement(state_element)) { + XLS_RETURN_IF_ERROR( + state_read->ReplaceUsesWithNew(value).status()); + } VLOG(4) << "Removing state element " << proc->StateElements()[i] << " for being constant."; XLS_RETURN_IF_ERROR(proc->RemoveStateElement(i)); @@ -250,7 +263,7 @@ ComputeStateDependencies(Proc* proc, OptimizationContext& context) { std::vector dependent_elements; for (int64_t i = 0; i < proc->GetStateElementCount(); ++i) { if (state_dependencies.at(node).Get(i)) { - dependent_elements.push_back(proc->GetStateRead(i)->GetName()); + dependent_elements.push_back(proc->GetStateElement(i)->name()); } } VLOG(5) << absl::StrFormat(" %s : {%s}%s", node->GetName(), @@ -327,7 +340,6 @@ absl::StatusOr RemoveUnobservableStateElements( // Replace uses of to-be-removed state elements with a zero-valued literal, // and remove their next_value nodes. for (int64_t i : to_remove) { - StateRead* state_read = proc->GetStateRead(i); absl::btree_set next_values = proc->next_values(proc->GetStateElement(i)); for (Next* next : next_values) { @@ -335,11 +347,13 @@ absl::StatusOr RemoveUnobservableStateElements( next->ReplaceUsesWithNew(Value::Tuple({})).status()); XLS_RETURN_IF_ERROR(proc->RemoveNode(next)); } - if (!state_read->IsDead()) { - XLS_RETURN_IF_ERROR( - state_read - ->ReplaceUsesWithNew(ZeroOfType(state_read->GetType())) - .status()); + for (StateRead* state_read : proc->GetStateReads(i)) { + if (!state_read->IsDead()) { + XLS_RETURN_IF_ERROR( + state_read + ->ReplaceUsesWithNew(ZeroOfType(state_read->GetType())) + .status()); + } } } @@ -448,10 +462,12 @@ absl::Status ConstantChainToStateMachine(Proc* proc, next->ReplaceUsesWithNew(Value::Tuple({})).status()); XLS_RETURN_IF_ERROR(proc->RemoveNode(next)); } - XLS_RETURN_IF_ERROR(proc->GetStateRead(state_index) - ->ReplaceUsesWithNew(state_machine_read, + cases, chain_literal) + .status()); + } indices_to_remove.insert(state_index); } for (int64_t state_index : indices_to_remove) { diff --git a/xls/passes/proc_state_provenance_narrowing_pass.cc b/xls/passes/proc_state_provenance_narrowing_pass.cc index c342030193..9f88528586 100644 --- a/xls/passes/proc_state_provenance_narrowing_pass.cc +++ b/xls/passes/proc_state_provenance_narrowing_pass.cc @@ -193,12 +193,12 @@ class NarrowTransform final : public Proc::StateElementTransformer { std::vector segments_; }; -absl::StatusOr UnchangedBits(Proc* proc, StateElement* state_element, +absl::StatusOr UnchangedBits(Proc* proc, StateRead* state_read, const Bits& initial_bits, const QueryEngine& query_engine, BitProvenanceAnalysis& provenance) { + StateElement* state_element = state_read->state_element(); Bits unchanged_bits = Bits::AllOnes(initial_bits.bit_count()); - StateRead* state_read = proc->GetStateReadByStateElement(state_element); for (Next* next : proc->next_values(state_element)) { if (next->value() == state_read) { // Pass-through nexts are trivially unaffecting. @@ -256,7 +256,7 @@ absl::StatusOr ProcStateProvenanceNarrowingPass::RunOnProcInternal( BitProvenanceAnalysis provenance; bool made_changes = false; - std::vector> transforms; + std::vector> transforms; for (StateElement* state_element : proc->StateElements()) { if (!state_element->type()->IsBits()) { @@ -264,12 +264,23 @@ absl::StatusOr ProcStateProvenanceNarrowingPass::RunOnProcInternal( // worthwhile. continue; } + absl::Span state_reads = + proc->GetStateReadsByStateElement(state_element); + XLS_RET_CHECK_LE(state_reads.size(), 1) + << "ProcStateProvenanceNarrowingPass only supports at most one " + "StateRead per StateElement for now."; + if (state_reads.empty()) { + VLOG(3) << "Skipping state element with no reads: " + << state_element->name(); + continue; + } + StateRead* state_read = state_reads[0]; Value init = state_element->initial_value(); XLS_RET_CHECK(init.IsBits()); const Bits& initial_bits = init.bits(); XLS_ASSIGN_OR_RETURN( Bits unchanged_bits, - UnchangedBits(proc, state_element, initial_bits, *qe, provenance)); + UnchangedBits(proc, state_read, initial_bits, *qe, provenance)); // Do the actual splitting if (unchanged_bits.IsZero()) { VLOG(3) << "Unable to narrow " << state_element->name() @@ -285,15 +296,14 @@ absl::StatusOr ProcStateProvenanceNarrowingPass::RunOnProcInternal( << (unchanged_bits.bit_count() - unchanged_bits.PopCount()); Bits narrowed_init = NarrowValue(initial_bits, segments); transforms.push_back( - {state_element, NarrowTransform(std::move(segments)), narrowed_init}); + {state_read, NarrowTransform(std::move(segments)), narrowed_init}); } - for (auto& [state_element, transform, narrowed_init] : transforms) { + for (auto& [state_read, transform, narrowed_init] : transforms) { made_changes = true; - XLS_RETURN_IF_ERROR(proc->TransformStateElement( - proc->GetStateReadByStateElement(state_element), - Value(narrowed_init), transform) - .status()); + XLS_RETURN_IF_ERROR( + proc->TransformStateElement(state_read, Value(narrowed_init), transform) + .status()); } return made_changes; diff --git a/xls/passes/proc_state_range_query_engine.cc b/xls/passes/proc_state_range_query_engine.cc index 60e8efbf54..1001b2ff5d 100644 --- a/xls/passes/proc_state_range_query_engine.cc +++ b/xls/passes/proc_state_range_query_engine.cc @@ -587,7 +587,17 @@ absl::StatusOr> NarrowUsingSegments( CHECK(remaining_intervals.contains(Interval::Precise(init_value.bits()))) << "Initial value not included in constant values."; remaining_intervals.erase(Interval::Precise(init_value.bits())); - StateRead* state_read = proc->GetStateReadByStateElement(state_element); + absl::Span state_reads = + proc->GetStateReadsByStateElement(state_element); + XLS_RET_CHECK_LE(state_reads.size(), 1) + << "ProcStateRangeQueryEngine only supports at most one StateRead per " + "StateElement for now."; + if (state_reads.empty()) { + VLOG(3) << "Cannot narrow proc state with no reads: " + << state_element->ToString(); + return std::nullopt; + } + StateRead* state_read = state_reads[0]; XLS_ASSIGN_OR_RETURN( SegmentRangeData limiter, SegmentRangeData::Create(nda, ground_truth, state_read, topo_sort)); @@ -854,8 +864,12 @@ absl::StatusOr ProcStateRangeQueryEngine ::Populate( absl::flat_hash_map state_read_intervals; state_read_intervals.reserve(final_range_data.size()); for (const auto& [state_element, range] : final_range_data) { - state_read_intervals[proc->GetStateReadByStateElement(state_element)] = - range.interval_set.Get({}); + absl::Span state_reads = + proc->GetStateReadsByStateElement(state_element); + XLS_RET_CHECK_LE(state_reads.size(), 1); + if (!state_reads.empty()) { + state_read_intervals[state_reads[0]] = range.interval_set.Get({}); + } } ProcStateGivens givens(proc, std::move(state_read_intervals)); XLS_RETURN_IF_ERROR(spec_ternary.PopulateWithGivens(proc, givens).status()); diff --git a/xls/passes/proc_state_tuple_flattening_pass.cc b/xls/passes/proc_state_tuple_flattening_pass.cc index 82df10745e..7dddecb7c9 100644 --- a/xls/passes/proc_state_tuple_flattening_pass.cc +++ b/xls/passes/proc_state_tuple_flattening_pass.cc @@ -196,7 +196,13 @@ absl::Status FlattenState(Proc* proc) { for (int64_t state_index = 0; state_index < proc->GetStateElementCount(); ++state_index) { StateElement* state_element = proc->GetStateElement(state_index); - StateRead* state_read = proc->GetStateReadByStateElement(state_element); + absl::Span state_reads = + proc->GetStateReadsByStateElement(state_element); + XLS_RET_CHECK_LE(state_reads.size(), 1) + << "ProcStateTupleFlatteningPass only supports at most one StateRead " + "per " + "StateElement for now."; + StateRead* state_read = state_reads.empty() ? nullptr : state_reads[0]; // Gather the flattened initial values and next state elements. std::vector init_values = @@ -220,8 +226,10 @@ absl::Status FlattenState(Proc* proc) { element.initial_value = init_values[i]; XLS_ASSIGN_OR_RETURN( element.placeholder, - proc->MakeNode(state_read->loc(), init_values[i])); - element.read_predicate = state_read->predicate(); + proc->MakeNode(state_read ? state_read->loc() : SourceInfo(), + init_values[i])); + element.read_predicate = + state_read ? state_read->predicate() : std::nullopt; placeholders.push_back(element.placeholder); elements.push_back(std::move(element)); @@ -281,29 +289,33 @@ absl::Status FlattenState(Proc* proc) { // ensure that any users that are also being replaced have an identity node // put in place; this prevents us from accidentally referencing the // placeholder after it's intended to be replaced with the new state param. - Node* state_read_identity = nullptr; - // Copy the users of the state read to avoid invalidating the iterator. - std::vector users(state_read->users().begin(), - state_read->users().end()); - for (Node* user : users) { - if (!user->OpIn({Op::kStateRead, Op::kNext})) { - continue; - } - if (state_read_identity == nullptr) { - XLS_ASSIGN_OR_RETURN( - state_read_identity, - proc->MakeNode(state_read->loc(), state_read, Op::kIdentity)); - identities.push_back(state_read_identity); + if (state_read != nullptr) { + Node* state_read_identity = nullptr; + // Copy the users of the state read to avoid invalidating the iterator. + std::vector users(state_read->users().begin(), + state_read->users().end()); + for (Node* user : users) { + if (!user->OpIn({Op::kStateRead, Op::kNext})) { + continue; + } + if (state_read_identity == nullptr) { + XLS_ASSIGN_OR_RETURN(state_read_identity, + proc->MakeNode(state_read->loc(), + state_read, Op::kIdentity)); + identities.push_back(state_read_identity); + } + user->ReplaceOperand(state_read, state_read_identity); } - user->ReplaceOperand(state_read, state_read_identity); } // Create a node of the same type as the old state param but constructed // from the new (decomposed) state params placeholders. XLS_ASSIGN_OR_RETURN( Node * old_param_replacement, - ComposeNode(state_read->GetType(), placeholders, proc)); - XLS_RETURN_IF_ERROR(state_read->ReplaceUsesWith(old_param_replacement)); + ComposeNode(state_element->type(), placeholders, proc)); + if (state_read != nullptr) { + XLS_RETURN_IF_ERROR(state_read->ReplaceUsesWith(old_param_replacement)); + } } XLS_RETURN_IF_ERROR(ReplaceProcState(proc, elements)); diff --git a/xls/passes/proc_state_tuple_flattening_pass_test.cc b/xls/passes/proc_state_tuple_flattening_pass_test.cc index 66fb892a2b..5120005252 100644 --- a/xls/passes/proc_state_tuple_flattening_pass_test.cc +++ b/xls/passes/proc_state_tuple_flattening_pass_test.cc @@ -195,12 +195,12 @@ TEST_P(ProcStateFlatteningPassTest, EmptyTupleAndBitsState) { // The name uniquer is told the names "y" and "q" are already released. So the // new state params get the same name. - EXPECT_EQ(proc->GetStateRead(0)->GetName(), "y"); + EXPECT_EQ(proc->GetStateReads(0).front()->GetName(), "y"); EXPECT_EQ(proc->GetStateElement(0)->initial_value(), Value(UBits(0, 32))); EXPECT_THAT(proc->next_values(proc->GetStateElement(0)), ElementsAre(m::Next(m::StateRead("y"), m::StateRead("y")))); - EXPECT_EQ(proc->GetStateRead(1)->GetName(), "q"); + EXPECT_EQ(proc->GetStateReads(1).front()->GetName(), "q"); EXPECT_EQ(proc->GetStateElement(1)->initial_value(), Value(UBits(0, 64))); EXPECT_THAT( proc->next_values(proc->GetStateElement(1)), @@ -220,7 +220,7 @@ TEST_P(ProcStateFlatteningPassTest, TrivialTupleState) { EXPECT_EQ(proc->GetStateElementCount(), 1); - EXPECT_EQ(proc->GetStateRead(0)->GetName(), "x"); + EXPECT_EQ(proc->GetStateReads(0).front()->GetName(), "x"); EXPECT_EQ(proc->GetStateElement(0)->initial_value(), Value(UBits(42, 32))); EXPECT_THAT( proc->next_values(proc->GetStateElement(0)), @@ -240,7 +240,7 @@ TEST_P(ProcStateFlatteningPassTest, TrivialTupleStateWithNextExpression) { EXPECT_EQ(proc->GetStateElementCount(), 1); - EXPECT_EQ(proc->GetStateRead(0)->GetName(), "x"); + EXPECT_EQ(proc->GetStateReads(0).front()->GetName(), "x"); EXPECT_EQ(proc->GetStateElement(0)->initial_value(), Value(UBits(42, 32))); EXPECT_THAT(proc->next_values(proc->GetStateElement(0)), UnorderedElementsAre( @@ -270,32 +270,32 @@ TEST_P(ProcStateFlatteningPassTest, ComplicatedState) { EXPECT_EQ(proc->GetStateElementCount(), 6); - EXPECT_EQ(proc->GetStateRead(0)->GetName(), "a_0"); + EXPECT_EQ(proc->GetStateReads(0).front()->GetName(), "a_0"); EXPECT_THAT( proc->next_values(proc->GetStateElement(0)), UnorderedElementsAre(m::Next(m::StateRead("a_0"), m::StateRead("b")))); - EXPECT_EQ(proc->GetStateRead(1)->GetName(), "a_1"); + EXPECT_EQ(proc->GetStateReads(1).front()->GetName(), "a_1"); EXPECT_THAT( proc->next_values(proc->GetStateElement(1)), UnorderedElementsAre(m::Next(m::StateRead("a_1"), m::StateRead("c_0")))); - EXPECT_EQ(proc->GetStateRead(2)->GetName(), "a_2"); + EXPECT_EQ(proc->GetStateReads(2).front()->GetName(), "a_2"); EXPECT_THAT( proc->next_values(proc->GetStateElement(2)), UnorderedElementsAre(m::Next(m::StateRead("a_2"), m::StateRead("c_1")))); - EXPECT_EQ(proc->GetStateRead(3)->GetName(), "b"); + EXPECT_EQ(proc->GetStateReads(3).front()->GetName(), "b"); EXPECT_THAT( proc->next_values(proc->GetStateElement(3)), UnorderedElementsAre(m::Next(m::StateRead("b"), m::StateRead("a_0")))); - EXPECT_EQ(proc->GetStateRead(4)->GetName(), "c_0"); + EXPECT_EQ(proc->GetStateReads(4).front()->GetName(), "c_0"); EXPECT_THAT( proc->next_values(proc->GetStateElement(4)), UnorderedElementsAre(m::Next(m::StateRead("c_0"), m::StateRead("a_1")))); - EXPECT_EQ(proc->GetStateRead(5)->GetName(), "c_1"); + EXPECT_EQ(proc->GetStateReads(5).front()->GetName(), "c_1"); EXPECT_THAT( proc->next_values(proc->GetStateElement(5)), UnorderedElementsAre(m::Next(m::StateRead("c_1"), m::StateRead("a_2")))); diff --git a/xls/passes/receive_default_value_simplification_pass_test.cc b/xls/passes/receive_default_value_simplification_pass_test.cc index 758884a196..becee32b74 100644 --- a/xls/passes/receive_default_value_simplification_pass_test.cc +++ b/xls/passes/receive_default_value_simplification_pass_test.cc @@ -73,14 +73,14 @@ TEST_F(ReceiveDefaultValueSimplificationPassTest, EXPECT_THAT(proc->next_values(proc->GetStateElement(1)), ElementsAre(m::Next( - proc->GetStateRead(1), + proc->GetStateReads(1).front(), m::Select(m::StateRead("pred"), {m::Literal(0), m::TupleIndex(m::Receive(), 1)})))); EXPECT_THAT(Run(proc), IsOkAndHolds(true)); EXPECT_THAT(proc->next_values(proc->GetStateElement(1)), - ElementsAre(m::Next(proc->GetStateRead(1), + ElementsAre(m::Next(proc->GetStateReads(1).front(), m::TupleIndex(m::Receive(), 1)))); } @@ -102,14 +102,14 @@ TEST_F(ReceiveDefaultValueSimplificationPassTest, EXPECT_THAT( proc->next_values(proc->GetStateElement(1)), ElementsAre(m::Next( - proc->GetStateRead(1), + proc->GetStateReads(1).front(), m::PrioritySelect(m::StateRead("pred"), {m::TupleIndex(m::Receive(), 1)}, m::Literal(0))))); EXPECT_THAT(Run(proc), IsOkAndHolds(true)); EXPECT_THAT(proc->next_values(proc->GetStateElement(1)), - ElementsAre(m::Next(proc->GetStateRead(1), + ElementsAre(m::Next(proc->GetStateReads(1).front(), m::TupleIndex(m::Receive(), 1)))); } @@ -132,14 +132,14 @@ TEST_F(ReceiveDefaultValueSimplificationPassTest, EXPECT_THAT(proc->next_values(proc->GetStateElement(1)), ElementsAre(m::Next( - proc->GetStateRead(1), + proc->GetStateReads(1).front(), m::Select(m::StateRead("pred"), {m::Literal(), m::TupleIndex(m::Receive(), 1)})))); EXPECT_THAT(Run(proc), IsOkAndHolds(true)); EXPECT_THAT(proc->next_values(proc->GetStateElement(1)), - ElementsAre(m::Next(proc->GetStateRead(1), + ElementsAre(m::Next(proc->GetStateReads(1).front(), m::TupleIndex(m::Receive(), 1)))); } @@ -159,14 +159,14 @@ TEST_F(ReceiveDefaultValueSimplificationPassTest, EXPECT_THAT(proc->next_values(proc->GetStateElement(0)), ElementsAre(m::Next( - proc->GetStateRead(0), + proc->GetStateReads(0).front(), m::Select(m::TupleIndex(), {m::Literal(0), m::TupleIndex(m::Receive(), 1)})))); EXPECT_THAT(Run(proc), IsOkAndHolds(true)); EXPECT_THAT(proc->next_values(proc->GetStateElement(0)), - ElementsAre(m::Next(proc->GetStateRead(0), + ElementsAre(m::Next(proc->GetStateReads(0).front(), m::TupleIndex(m::Receive(), 1)))); } @@ -187,14 +187,14 @@ TEST_F(ReceiveDefaultValueSimplificationPassTest, EXPECT_THAT(proc->next_values(proc->GetStateElement(1)), ElementsAre(m::Next( - proc->GetStateRead(1), + proc->GetStateReads(1).front(), m::Select(m::TupleIndex(), {m::Literal(0), m::TupleIndex(m::Receive(), 1)})))); EXPECT_THAT(Run(proc), IsOkAndHolds(true)); EXPECT_THAT(proc->next_values(proc->GetStateElement(1)), - ElementsAre(m::Next(proc->GetStateRead(1), + ElementsAre(m::Next(proc->GetStateReads(1).front(), m::TupleIndex(m::Receive(), 1)))); } diff --git a/xls/passes/token_simplification_pass_test.cc b/xls/passes/token_simplification_pass_test.cc index 74c6c868c4..8185b3ef2c 100644 --- a/xls/passes/token_simplification_pass_test.cc +++ b/xls/passes/token_simplification_pass_test.cc @@ -14,7 +14,6 @@ #include "xls/passes/token_simplification_pass.h" -#include #include #include @@ -68,7 +67,8 @@ TEST_F(TokenSimplificationPassTest, SingleArgument) { XLS_ASSERT_OK_AND_ASSIGN(Proc * proc, p->GetTopAsProc()); EXPECT_THAT(Run(proc), IsOkAndHolds(true)); EXPECT_THAT(proc->next_values(proc->GetStateElement(0)), - ElementsAre(m::Next(proc->GetStateRead(0), m::StateRead("tok")))); + ElementsAre(m::Next(proc->GetStateReads(0).front(), + m::StateRead("tok")))); } TEST_F(TokenSimplificationPassTest, DuplicatedArgument) { @@ -84,7 +84,8 @@ TEST_F(TokenSimplificationPassTest, DuplicatedArgument) { XLS_ASSERT_OK_AND_ASSIGN(Proc * proc, p->GetTopAsProc()); EXPECT_THAT(Run(proc), IsOkAndHolds(true)); EXPECT_THAT(proc->next_values(proc->GetStateElement(0)), - ElementsAre(m::Next(proc->GetStateRead(0), m::StateRead("tok")))); + ElementsAre(m::Next(proc->GetStateReads(0).front(), + m::StateRead("tok")))); } TEST_F(TokenSimplificationPassTest, NestedAfterAll) { @@ -101,7 +102,8 @@ TEST_F(TokenSimplificationPassTest, NestedAfterAll) { XLS_ASSERT_OK_AND_ASSIGN(Proc * proc, p->GetTopAsProc()); EXPECT_THAT(Run(proc), IsOkAndHolds(true)); EXPECT_THAT(proc->next_values(proc->GetStateElement(0)), - ElementsAre(m::Next(proc->GetStateRead(0), m::StateRead("tok")))); + ElementsAre(m::Next(proc->GetStateReads(0).front(), + m::StateRead("tok")))); } TEST_F(TokenSimplificationPassTest, DelayZero) { @@ -117,7 +119,8 @@ TEST_F(TokenSimplificationPassTest, DelayZero) { XLS_ASSERT_OK_AND_ASSIGN(Proc * proc, p->GetTopAsProc()); EXPECT_THAT(Run(proc), IsOkAndHolds(true)); EXPECT_THAT(proc->next_values(proc->GetStateElement(0)), - ElementsAre(m::Next(proc->GetStateRead(0), m::StateRead("tok")))); + ElementsAre(m::Next(proc->GetStateReads(0).front(), + m::StateRead("tok")))); } TEST_F(TokenSimplificationPassTest, NestedDelay) { @@ -135,7 +138,7 @@ TEST_F(TokenSimplificationPassTest, NestedDelay) { EXPECT_THAT(Run(proc), IsOkAndHolds(true)); EXPECT_THAT( proc->next_values(proc->GetStateElement(0)), - ElementsAre(m::Next(proc->GetStateRead(0), + ElementsAre(m::Next(proc->GetStateReads(0).front(), m::MinDelay(m::StateRead("tok"), /*delay=*/3)))); } @@ -156,7 +159,7 @@ TEST_F(TokenSimplificationPassTest, AfterAllWithCommonDelay) { EXPECT_THAT(Run(proc), IsOkAndHolds(true)); EXPECT_THAT( proc->next_values(proc->GetStateElement(0)), - ElementsAre(m::Next(proc->GetStateRead(0), + ElementsAre(m::Next(proc->GetStateReads(0).front(), m::MinDelay(m::StateRead("tok"), /*delay=*/6)))); } @@ -183,7 +186,7 @@ TEST_F(TokenSimplificationPassTest, DuplicatedArgument2) { EXPECT_THAT( proc->next_values(proc->GetStateElement(0)), ElementsAre(m::Next( - proc->GetStateRead(0), + proc->GetStateReads(0).front(), m::AfterAll( m::Send(m::Send(m::StateRead("tok"), m::Literal()), m::Literal()), m::Send(m::StateRead("tok"), m::Literal()))))); @@ -211,7 +214,7 @@ TEST_F(TokenSimplificationPassTest, UnrelatedArguments) { EXPECT_THAT(Run(proc), IsOkAndHolds(false)); EXPECT_THAT(proc->next_values(proc->GetStateElement(0)), ElementsAre(m::Next( - proc->GetStateRead(0), + proc->GetStateReads(0).front(), m::AfterAll(m::Send(m::StateRead("tok"), m::Literal()), m::Send(m::StateRead("tok"), m::Literal()), m::Send(m::StateRead("tok"), m::Literal()))))); @@ -237,7 +240,7 @@ TEST_F(TokenSimplificationPassTest, ArgumentsWithDependencies) { XLS_ASSERT_OK_AND_ASSIGN(Proc * proc, p->GetTopAsProc()); EXPECT_THAT(Run(proc), IsOkAndHolds(true)); EXPECT_THAT(proc->next_values(proc->GetStateElement(0)), - ElementsAre(m::Next(proc->GetStateRead(0), m::Send()))); + ElementsAre(m::Next(proc->GetStateReads(0).front(), m::Send()))); } TEST_F(TokenSimplificationPassTest, DoNotRelyOnInvokeForDependencies) { @@ -266,7 +269,7 @@ TEST_F(TokenSimplificationPassTest, DoNotRelyOnInvokeForDependencies) { EXPECT_THAT(Run(proc), IsOkAndHolds(true)); EXPECT_THAT( proc->next_values(proc->GetStateElement(0)), - ElementsAre(m::Next(proc->GetStateRead(0), + ElementsAre(m::Next(proc->GetStateReads(0).front(), m::AfterAll(m::Send(), m::Send(), m::Invoke())))); } diff --git a/xls/passes/useless_io_removal_pass_test.cc b/xls/passes/useless_io_removal_pass_test.cc index ea6b597a9e..302854cc36 100644 --- a/xls/passes/useless_io_removal_pass_test.cc +++ b/xls/passes/useless_io_removal_pass_test.cc @@ -111,10 +111,10 @@ TEST_F(UselessIORemovalPassTest, RemoveSendIfLiteralFalse) { int64_t original_node_count = proc->node_count(); EXPECT_THAT(Run(p.get()), IsOkAndHolds(true)); EXPECT_EQ(proc->node_count(), original_node_count - 3); - EXPECT_THAT( - proc->next_values(proc->GetStateElement(0)), - ElementsAre(m::Next(proc->GetStateRead(0), - m::Send(proc->GetStateRead(0), m::Literal(1))))); + EXPECT_THAT(proc->next_values(proc->GetStateElement(0)), + ElementsAre(m::Next( + proc->GetStateReads(0).front(), + m::Send(proc->GetStateReads(0).front(), m::Literal(1))))); } TEST_F(UselessIORemovalPassTest, RemoveSendIfLiteralFalseNewStyle) { @@ -179,12 +179,12 @@ TEST_F(UselessIORemovalPassTest, RemoveReceiveNonBlockingIfLiteralFalse) { m::TupleIndex(m::Receive(m::StateRead("tkn"), m::Channel("test_channel")), 0), m::Literal(0), m::Literal(0)); - EXPECT_THAT( - proc->next_values(proc->GetStateElement(0)), - ElementsAre(m::Next(proc->GetStateRead(0), m::TupleIndex(tuple, 0)))); - EXPECT_THAT( - proc->next_values(proc->GetStateElement(1)), - ElementsAre(m::Next(proc->GetStateRead(1), m::TupleIndex(tuple, 1)))); + EXPECT_THAT(proc->next_values(proc->GetStateElement(0)), + ElementsAre(m::Next(proc->GetStateReads(0).front(), + m::TupleIndex(tuple, 0)))); + EXPECT_THAT(proc->next_values(proc->GetStateElement(1)), + ElementsAre(m::Next(proc->GetStateReads(1).front(), + m::TupleIndex(tuple, 1)))); } TEST_F(UselessIORemovalPassTest, RemoveReceiveIfLiteralFalse) { @@ -209,12 +209,12 @@ TEST_F(UselessIORemovalPassTest, RemoveReceiveIfLiteralFalse) { m::TupleIndex(m::Receive(m::StateRead("tkn"), m::Channel("test_channel")), 0), m::Literal(0)); - EXPECT_THAT( - proc->next_values(proc->GetStateElement(0)), - ElementsAre(m::Next(proc->GetStateRead(0), m::TupleIndex(tuple, 0)))); - EXPECT_THAT( - proc->next_values(proc->GetStateElement(1)), - ElementsAre(m::Next(proc->GetStateRead(1), m::TupleIndex(tuple, 1)))); + EXPECT_THAT(proc->next_values(proc->GetStateElement(0)), + ElementsAre(m::Next(proc->GetStateReads(0).front(), + m::TupleIndex(tuple, 0)))); + EXPECT_THAT(proc->next_values(proc->GetStateElement(1)), + ElementsAre(m::Next(proc->GetStateReads(1).front(), + m::TupleIndex(tuple, 1)))); } TEST_F(UselessIORemovalPassTest, RemoveSendPredIfLiteralTrue) { @@ -235,10 +235,11 @@ TEST_F(UselessIORemovalPassTest, RemoveSendPredIfLiteralTrue) { EXPECT_EQ(proc->node_count(), original_node_count - 1); EXPECT_THAT( proc->next_values(proc->GetStateElement(0)), - ElementsAre(m::Next(proc->GetStateRead(0), + ElementsAre(m::Next(proc->GetStateReads(0).front(), m::Send(m::StateRead("tkn"), m::Literal(1))))); - EXPECT_THAT(proc->next_values(proc->GetStateElement(1)), - ElementsAre(m::Next(proc->GetStateRead(1), m::Literal(0)))); + EXPECT_THAT( + proc->next_values(proc->GetStateElement(1)), + ElementsAre(m::Next(proc->GetStateReads(1).front(), m::Literal(0)))); } TEST_F(UselessIORemovalPassTest, RemoveReceivePredIfLiteralTrue) { @@ -259,12 +260,12 @@ TEST_F(UselessIORemovalPassTest, RemoveReceivePredIfLiteralTrue) { EXPECT_THAT(Run(p.get()), IsOkAndHolds(true)); EXPECT_EQ(proc->node_count(), original_node_count - 1); auto tuple = m::Receive(m::StateRead("tkn"), m::Channel("test_channel")); - EXPECT_THAT( - proc->next_values(proc->GetStateElement(0)), - ElementsAre(m::Next(proc->GetStateRead(0), m::TupleIndex(tuple, 0)))); - EXPECT_THAT( - proc->next_values(proc->GetStateElement(1)), - ElementsAre(m::Next(proc->GetStateRead(1), m::TupleIndex(tuple, 1)))); + EXPECT_THAT(proc->next_values(proc->GetStateElement(0)), + ElementsAre(m::Next(proc->GetStateReads(0).front(), + m::TupleIndex(tuple, 0)))); + EXPECT_THAT(proc->next_values(proc->GetStateElement(1)), + ElementsAre(m::Next(proc->GetStateReads(1).front(), + m::TupleIndex(tuple, 1)))); } TEST_F(UselessIORemovalPassTest, DontRemoveLastSendIfOnSendOnlyChannel) { diff --git a/xls/scheduling/proc_state_legalization_pass.cc b/xls/scheduling/proc_state_legalization_pass.cc index 857249f155..db4bf34abd 100644 --- a/xls/scheduling/proc_state_legalization_pass.cc +++ b/xls/scheduling/proc_state_legalization_pass.cc @@ -53,7 +53,13 @@ namespace { absl::StatusOr LegalizeStateReadPredicate( Proc* proc, StateElement* state_element, const SchedulingPassOptions& options) { - StateRead* state_read = proc->GetStateReadByStateElement(state_element); + absl::Span reads = + proc->GetStateReadsByStateElement(state_element); + XLS_RET_CHECK_LE(reads.size(), 1); + if (reads.empty()) { + return false; + } + StateRead* state_read = reads[0]; const absl::btree_set& next_values = proc->next_values(state_element); if (!state_read->predicate().has_value() || next_values.empty()) { @@ -215,7 +221,13 @@ absl::StatusOr AddMutualExclusionAsserts( absl::StatusOr AddWriteWithoutReadAsserts( Proc* proc, StateElement* state_element, const SchedulingPassOptions& options) { - StateRead* state_read = proc->GetStateReadByStateElement(state_element); + absl::Span reads = + proc->GetStateReadsByStateElement(state_element); + XLS_RET_CHECK_LE(reads.size(), 1); + if (reads.empty()) { + return false; + } + StateRead* state_read = reads[0]; if (!state_read->predicate().has_value()) { return false; } @@ -291,7 +303,13 @@ absl::StatusOr AddDefaultNextValue(Proc* proc, StateElement* state_element, const SchedulingPassOptions& options) { absl::btree_set predicates; - StateRead* state_read = proc->GetStateReadByStateElement(state_element); + absl::Span reads = + proc->GetStateReadsByStateElement(state_element); + XLS_RET_CHECK_LE(reads.size(), 1); + if (reads.empty()) { + return false; + } + StateRead* state_read = reads[0]; for (Next* next : proc->next_values(state_element)) { if (next->predicate().has_value()) { predicates.insert(*next->predicate()); diff --git a/xls/scheduling/proc_state_legalization_pass_test.cc b/xls/scheduling/proc_state_legalization_pass_test.cc index 668a585827..b47776b363 100644 --- a/xls/scheduling/proc_state_legalization_pass_test.cc +++ b/xls/scheduling/proc_state_legalization_pass_test.cc @@ -384,11 +384,13 @@ TEST_P(ProcStateLegalizationPassTest, ProcWithPredicatedStateRead) { ScopedRecordIr sri(p.get()); ASSERT_THAT(Run(proc), IsOkAndHolds(true)); - EXPECT_EQ(proc->GetStateReadByStateElement(*proc->GetStateElementByName("x")) + EXPECT_EQ(proc->GetStateReadsByStateElement(*proc->GetStateElementByName("x")) + .front() ->predicate(), std::nullopt); EXPECT_THAT( - proc->GetStateReadByStateElement(*proc->GetStateElementByName("y")) + proc->GetStateReadsByStateElement(*proc->GetStateElementByName("y")) + .front() ->predicate(), Optional(m::Or( m::Eq(m::UMod(m::StateRead("x"), m::Literal(2)), m::Literal(0)), @@ -443,11 +445,13 @@ TEST_P(ProcStateLegalizationPassTest, ScopedRecordIr sri(p.get()); ASSERT_THAT(Run(proc), IsOkAndHolds(true)); - EXPECT_EQ(proc->GetStateReadByStateElement(*proc->GetStateElementByName("x")) + EXPECT_EQ(proc->GetStateReadsByStateElement(*proc->GetStateElementByName("x")) + .front() ->predicate(), std::nullopt); EXPECT_THAT( - proc->GetStateReadByStateElement(*proc->GetStateElementByName("y")) + proc->GetStateReadsByStateElement(*proc->GetStateElementByName("y")) + .front() ->predicate(), Optional(m::Or( m::Eq(m::UMod(m::StateRead("x"), m::Literal(2)), m::Literal(0)), @@ -517,7 +521,8 @@ TEST_P(ProcStateLegalizationPassTest, m::Eq(m::UMod(m::StateRead("x"), m::Literal(3)), m::Literal(0)), m::Not(m::Eq(m::UMod(m::StateRead("x"), m::Literal(3)), m::Literal(0)))); EXPECT_THAT( - proc->GetStateReadByStateElement(*proc->GetStateElementByName("y")) + proc->GetStateReadsByStateElement(*proc->GetStateElementByName("y")) + .front() ->predicate(), Optional(expected_read_predicate)); EXPECT_THAT( diff --git a/xls/scheduling/schedule_util.cc b/xls/scheduling/schedule_util.cc index 21155d5ef4..80aa70e201 100644 --- a/xls/scheduling/schedule_util.cc +++ b/xls/scheduling/schedule_util.cc @@ -83,11 +83,12 @@ absl::StatusOr> GetDeadAfterSynthesisNodes( Proc* proc = f->AsProcOrDie(); for (StateElement* state_element : proc->StateElements()) { VLOG(2) << "Considering state element: " << state_element->name(); - if (live_after_synthesis.contains( - proc->GetStateReadByStateElement(state_element))) { - for (Next* next : - proc->GetStateReadByStateElement(state_element)->GetNextValues()) { - mark_live(next); + for (StateRead* state_read : + proc->GetStateReadsByStateElement(state_element)) { + if (live_after_synthesis.contains(state_read)) { + for (Next* next : state_read->GetNextValues()) { + mark_live(next); + } } } } diff --git a/xls/tools/delay_info_printer.cc b/xls/tools/delay_info_printer.cc index c9e72d0b7c..aec75d5c22 100644 --- a/xls/tools/delay_info_printer.cc +++ b/xls/tools/delay_info_printer.cc @@ -29,6 +29,7 @@ #include "absl/status/status.h" #include "absl/status/statusor.h" #include "absl/strings/str_format.h" +#include "absl/types/span.h" #include "xls/common/file/filesystem.h" #include "xls/common/status/status_macros.h" #include "xls/estimators/delay_model/analyze_critical_path.h" @@ -255,8 +256,14 @@ class DelayInfoPrinterImpl : public DelayInfoPrinter { return true; } if (node->Is()) { - return node->As()->state_read() == - proc->GetStateReadByStateElement(state_element); + Node* next_sr = node->As()->state_read(); + for (StateRead* sr : + proc->GetStateReadsByStateElement(state_element)) { + if (next_sr == sr) { + return true; + } + } + return false; } return false; }));