Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions xls/ir/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -931,6 +931,7 @@ cc_test(
"@com_google_absl//absl/base",
"@com_google_absl//absl/status",
"@com_google_absl//absl/strings",
"@com_google_absl//absl/types:span",
"@googletest//:gtest",
],
)
Expand Down
8 changes: 5 additions & 3 deletions xls/ir/ir_parser.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1895,9 +1895,11 @@ absl::StatusOr<Parser::BodyResult> Parser::ParseBody(
Proc * source_proc,
ParseProc(package, /*outer_attributes=*/{}, &source));
for (StateElement* element : source_proc->StateElements()) {
name_to_value->emplace(
element->name(),
bb->SourceNode(source_proc->GetStateReadByStateElement(element)));
absl::Span<StateRead* const> reads =
source_proc->GetStateReadsByStateElement(element);
XLS_RET_CHECK_EQ(reads.size(), 1);
name_to_value->emplace(element->name(),
bb->SourceNode(reads.front()));
}
} else {
return absl::InvalidArgumentError(absl::StrFormat(
Expand Down
18 changes: 13 additions & 5 deletions xls/ir/ir_parser_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
#include "absl/strings/ascii.h"
#include "absl/strings/str_cat.h"
#include "absl/strings/substitute.h"
#include "absl/types/span.h"
#include "xls/common/source_location.h"
#include "xls/common/status/matchers.h"
#include "xls/ir/bits.h"
Expand Down Expand Up @@ -607,7 +608,9 @@ proc foo( x: bits[32], y: (), z: bits[32], init={42, (), 123}) {
XLS_ASSERT_OK_AND_ASSIGN(Proc * proc, package->GetProc("foo"));
EXPECT_EQ(proc->GetStateElementCount(), 3);
XLS_ASSERT_OK_AND_ASSIGN(StateElement * x, proc->GetStateElementByName("x"));
EXPECT_THAT(proc->GetStateReadByStateElement(x)->predicate(), std::nullopt);
absl::Span<StateRead* const> reads = proc->GetStateReadsByStateElement(x);
ASSERT_EQ(reads.size(), 1);
EXPECT_THAT(reads.front()->predicate(), std::nullopt);
}

TEST(IrParserTest, ProcWithPredicatedStateRead) {
Expand All @@ -626,17 +629,22 @@ proc foo( x: bits[32], y: bits[1], z: bits[32], init={42, 1, 123}) {
EXPECT_EQ(proc->GetStateElementCount(), 3);

XLS_ASSERT_OK_AND_ASSIGN(StateElement * x, proc->GetStateElementByName("x"));
std::optional<Node*> x_predicate =
proc->GetStateReadByStateElement(x)->predicate();
absl::Span<StateRead* const> reads_x = proc->GetStateReadsByStateElement(x);
ASSERT_EQ(reads_x.size(), 1);
std::optional<Node*> x_predicate = reads_x.front()->predicate();
ASSERT_TRUE(x_predicate.has_value());
ASSERT_EQ((*x_predicate)->op(), Op::kStateRead);
EXPECT_EQ((*x_predicate)->As<StateRead>()->state_element()->name(), "y");

XLS_ASSERT_OK_AND_ASSIGN(StateElement * y, proc->GetStateElementByName("y"));
ASSERT_FALSE(proc->GetStateReadByStateElement(y)->predicate().has_value());
absl::Span<StateRead* const> reads_y = proc->GetStateReadsByStateElement(y);
ASSERT_EQ(reads_y.size(), 1);
ASSERT_FALSE(reads_y.front()->predicate().has_value());

XLS_ASSERT_OK_AND_ASSIGN(StateElement * z, proc->GetStateElementByName("z"));
ASSERT_FALSE(proc->GetStateReadByStateElement(z)->predicate().has_value());
absl::Span<StateRead* const> reads_z = proc->GetStateReadsByStateElement(z);
ASSERT_EQ(reads_z.size(), 1);
ASSERT_FALSE(reads_z.front()->predicate().has_value());
}

TEST(IrParserTest, ParseSendReceiveChannel) {
Expand Down
4 changes: 2 additions & 2 deletions xls/ir/node_util_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -381,7 +381,7 @@ TEST_F(NodeUtilTest, ChannelNodes) {

EXPECT_THAT(GetChannelUsedByNode(rcv.node()), IsOkAndHolds(ch0));
EXPECT_THAT(GetChannelUsedByNode(send.node()), IsOkAndHolds(ch1));
EXPECT_THAT(GetChannelUsedByNode(proc->GetStateRead(0)),
EXPECT_THAT(GetChannelUsedByNode(proc->GetStateReads(0).front()),
StatusIs(absl::StatusCode::kNotFound,
HasSubstr("No channel associated with node")));
}
Expand Down Expand Up @@ -435,7 +435,7 @@ TEST_F(NodeUtilTest, ReplaceTupleIndicesWorksWithToken) {
// works, we'd need to make an after_all and add the receive's output token to
// it after calling ReplaceTupleElementsWith().
XLS_EXPECT_OK(ReplaceTupleElementsWith(
receive_node, {{0, proc->GetStateRead(0)}, {1, lit0}}));
receive_node, {{0, proc->GetStateReads(0).front()}, {1, lit0}}));

ExpectIr(proc->DumpIr(), TestName());
}
Expand Down
68 changes: 43 additions & 25 deletions xls/ir/proc.cc
Original file line number Diff line number Diff line change
Expand Up @@ -214,13 +214,15 @@ absl::Status Proc::RemoveStateElement(int64_t index) {
StateElement* old_state_element = GetStateElement(index);
auto old_state_read_it = state_reads_.find(old_state_element);
XLS_RET_CHECK(old_state_read_it != state_reads_.end());
if (!old_state_read_it->second->users().empty()) {
return absl::InvalidArgumentError(absl::StrFormat(
"Cannot remove state element %d of proc %s, existing "
"state read %s has uses",
index, name(), old_state_read_it->second->GetNameView()));
for (StateRead* read : old_state_read_it->second) {
if (!read->users().empty()) {
return absl::InvalidArgumentError(
absl::StrFormat("Cannot remove state element %d of proc %s, existing "
"state read %s has uses",
index, name(), read->GetNameView()));
}
XLS_RETURN_IF_ERROR(RemoveNode(read));
}
XLS_RETURN_IF_ERROR(RemoveNode(old_state_read_it->second));
// TODO(allight): This should ideally not need to be done manually.
state_reads_.erase(old_state_read_it);

Expand All @@ -232,11 +234,14 @@ absl::Status Proc::RemoveStateElement(int64_t index) {
absl::Status Proc::RemoveAllStateElements() {
// TODO(allight): This relies on side tables being valid. For now just let it
// go.
for (const auto& [elem, read] : state_reads_) {
if (read != nullptr) {
XLS_RETURN_IF_ERROR(RemoveNode(read))
<< "Cannot remove " << elem->ToString() << " of proc " << name()
<< " because read '" << read->ToString() << "' could not be removed.";
for (const auto& [elem, reads] : state_reads_) {
for (StateRead* read : reads) {
if (read != nullptr) {
XLS_RETURN_IF_ERROR(RemoveNode(read))
<< "Cannot remove " << elem->ToString() << " of proc " << name()
<< " because read '" << read->ToString()
<< "' could not be removed.";
}
}
XLS_RETURN_IF_ERROR(state_name_uniquer_.ReleaseIdentifier(elem->name()))
<< "Cannot release name of " << elem->ToString();
Expand Down Expand Up @@ -278,7 +283,7 @@ absl::StatusOr<StateRead*> Proc::InsertStateElement(
MakeNodeWithName<StateRead>(
loc, state_element, read_predicate,
/*label=*/std::nullopt, state_element->name()));
state_reads_[state_element] = state_read;
state_reads_[state_element].push_back(state_read);

if (next_state.has_value()) {
if (!ValueConformsToType(init_value, next_state.value()->GetType())) {
Expand Down Expand Up @@ -351,14 +356,13 @@ absl::StatusOr<Proc*> Proc::Clone(
return mapping.at(orig);
};
for (StateElement* state_element : StateElements()) {
StateRead* state_read = state_reads_.at(state_element);
XLS_ASSIGN_OR_RETURN(
StateRead * cloned_state_read,
cloned_proc->AppendStateElement(
remap_name(state_name_remapping, state_element->name()),
state_element->initial_value(), state_read->predicate(),
/*next_state=*/std::nullopt));
original_to_clone[state_read] = cloned_state_read;
XLS_RETURN_IF_ERROR(
cloned_proc
->InsertUnreadStateElement(
cloned_proc->GetStateElementCount(),
remap_name(state_name_remapping, state_element->name()),
state_element->initial_value())
.status());
}
if (is_new_style_proc()) {
absl::flat_hash_map<ChannelInterface*, ChannelInterface*> channel_map;
Expand Down Expand Up @@ -445,7 +449,23 @@ absl::StatusOr<Proc*> Proc::Clone(

switch (node->op()) {
case Op::kStateRead: {
continue;
StateRead* src = node->As<StateRead>();
StateElement* src_elem = src->state_element();
XLS_ASSIGN_OR_RETURN(int64_t idx, GetStateElementIndex(src_elem));
StateElement* cloned_elem = cloned_proc->GetStateElement(idx);

std::optional<Node*> cloned_predicate;
if (src->predicate().has_value()) {
cloned_predicate = original_to_clone.at(src->predicate().value());
}

XLS_ASSIGN_OR_RETURN(StateRead * cloned_state_read,
cloned_proc->MakeNodeWithName<StateRead>(
src->loc(), cloned_elem, cloned_predicate,
/*label=*/std::nullopt, cloned_elem->name()));
cloned_proc->state_reads_[cloned_elem].push_back(cloned_state_read);
original_to_clone[node] = cloned_state_read;
break;
}
case Op::kReceive: {
Receive* src = node->As<Receive>();
Expand Down Expand Up @@ -1000,10 +1020,8 @@ absl::Status Proc::InternalRebuildSideTables() {
state_reads_.clear();
for (Node* n : nodes()) {
if (n->Is<StateRead>()) {
XLS_RET_CHECK(!state_reads_.contains(n->As<StateRead>()->state_element()))
<< "Duplicate state element read: "
<< n->As<StateRead>()->state_element();
state_reads_[n->As<StateRead>()->state_element()] = n->As<StateRead>();
state_reads_[n->As<StateRead>()->state_element()].push_back(
n->As<StateRead>());
} else if (n->Is<Next>()) {
next_values_.push_back(n->As<Next>());
next_values_by_state_element_[n->As<Next>()->state_element()].insert(
Expand Down
19 changes: 16 additions & 3 deletions xls/ir/proc.h
Original file line number Diff line number Diff line change
Expand Up @@ -95,10 +95,23 @@ class Proc : public FunctionBase {
return state_elements_.contains(name);
}

// Remove legacy getters after all downstream passes migrate logic.
StateRead* GetStateRead(int64_t index) const {
return state_reads_.at(GetStateElement(index));
return GetStateReads(index).front();
}

StateRead* GetStateReadByStateElement(StateElement* state_element) const {
return GetStateReadsByStateElement(state_element).front();
}

// Get state reads for a state element at the given index.
absl::Span<StateRead* const> GetStateReads(int64_t index) const {
return state_reads_.at(GetStateElement(index));
}

// Get state reads for a state element.
absl::Span<StateRead* const> GetStateReadsByStateElement(
StateElement* state_element) const {
return state_reads_.at(state_element);
}

Expand Down Expand Up @@ -403,8 +416,8 @@ class Proc : public FunctionBase {
absl::flat_hash_map<std::string, std::unique_ptr<StateElement>>
state_elements_;

// Map of the unique StateRead node for each state element.
absl::flat_hash_map<StateElement*, StateRead*> state_reads_;
// Map of StateRead nodes for each state element.
absl::flat_hash_map<StateElement*, std::vector<StateRead*>> state_reads_;

// Vector of state element pointers. Kept in sync with the state_elements_
// map. Enables easy, stable iteration over state elements. With this vector,
Expand Down
69 changes: 54 additions & 15 deletions xls/ir/proc_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -205,6 +205,43 @@ TEST_F(ProcTest, StatelessProc) {
EXPECT_EQ(proc->DumpIr(), "proc p() {\n}\n");
}

TEST_F(ProcTest, MultipleStateReads) {
auto p = CreatePackage();
ProcBuilder pb("p", p.get());
BValue tkn = pb.StateElement("tkn", Value::Token());
BValue state = pb.StateElement("x", Value(UBits(42, 32)));
XLS_ASSERT_OK_AND_ASSIGN(Proc * proc, pb.Build({tkn, state}));

StateElement* state_elem = proc->GetStateElement(1);

EXPECT_EQ(proc->GetStateReads(1).size(), 1);
StateRead* read1 = proc->GetStateReads(1).front();

// Second Read
XLS_ASSERT_OK_AND_ASSIGN(
StateRead * read2,
proc->MakeNodeWithName<StateRead>(SourceInfo(), state_elem,
/*predicate=*/std::nullopt,
/*label=*/std::nullopt, "x_read2"));
XLS_ASSERT_OK(proc->RebuildSideTables());

EXPECT_EQ(proc->GetStateReads(1).size(), 2);
EXPECT_THAT(proc->GetStateReads(1), ElementsAre(read1, read2));

EXPECT_EQ(proc->GetStateReadsByStateElement(state_elem).size(), 2);
EXPECT_THAT(proc->GetStateReadsByStateElement(state_elem),
ElementsAre(read1, read2));

// Remove the second read.
std::string read2_name = read2->GetName();
XLS_ASSERT_OK(proc->RemoveNode(read2));
XLS_ASSERT_OK(proc->RebuildSideTables());

// Now we should have 1 read again.
EXPECT_EQ(proc->GetStateReads(1).size(), 1);
EXPECT_EQ(proc->GetStateReads(1).front(), read1);
}

TEST_F(ProcTest, RemoveStateThatStillHasUse) {
// Don't call CreatePackage which creates a VerifiedPackage because we
// intentionally create a malformed proc.
Expand Down Expand Up @@ -254,10 +291,10 @@ TEST_F(ProcTest, Clone) {
EXPECT_EQ(clone->DumpIr(),
R"(proc cloned(tkn: token, state: bits[32], init={token, 42}) {
tkn: token = state_read(state_element=tkn, id=12)
literal.14: bits[32] = literal(value=1, id=14)
state: bits[32] = state_read(state_element=state, id=13)
literal.13: bits[32] = literal(value=1, id=13)
state: bits[32] = state_read(state_element=state, id=14)
receive_3: (token, bits[32]) = receive(tkn, channel=cloned_chan, id=15)
add.16: bits[32] = add(literal.14, state, id=16)
add.16: bits[32] = add(literal.13, state, id=16)
tuple_index.17: bits[32] = tuple_index(receive_3, index=1, id=17)
tuple_index.18: token = tuple_index(receive_3, index=0, id=18)
add.19: bits[32] = add(add.16, tuple_index.17, id=19)
Expand Down Expand Up @@ -304,10 +341,10 @@ proc cloned<input_chan: bits[32] in, chan: bits[32] out>(tkn: token, state: bits
chan_interface input_chan(direction=receive, kind=streaming, strictness=proven_mutually_exclusive, flow_control=ready_valid, flop_kind=none)
chan_interface chan(direction=send, kind=streaming, strictness=proven_mutually_exclusive, flow_control=ready_valid, flop_kind=none)
tkn: token = state_read(state_element=tkn, id=1)
literal.3: bits[32] = literal(value=1, id=3)
state: bits[32] = state_read(state_element=state, id=2)
literal.2: bits[32] = literal(value=1, id=2)
state: bits[32] = state_read(state_element=state, id=3)
receive_3: (token, bits[32]) = receive(tkn, channel=input_chan, id=4)
add.5: bits[32] = add(literal.3, state, id=5)
add.5: bits[32] = add(literal.2, state, id=5)
tuple_index.6: bits[32] = tuple_index(receive_3, index=1, id=6)
tuple_index.7: token = tuple_index(receive_3, index=0, id=7)
add.8: bits[32] = add(add.5, tuple_index.6, id=8)
Expand Down Expand Up @@ -355,15 +392,15 @@ TEST_F(ProcTest, CloneNewStyle) {
chan baz(bits[32], id=0, kind=streaming, ops=send_receive, flow_control=ready_valid, strictness=proven_mutually_exclusive)
chan_interface baz(direction=send, kind=streaming, strictness=proven_mutually_exclusive, flow_control=none, flop_kind=none)
chan_interface baz(direction=receive, kind=streaming, strictness=proven_mutually_exclusive, flow_control=none, flop_kind=none)
tkn: token = literal(value=token, id=14)
receive_3: (token, bits[32]) = receive(tkn, channel=foo, id=15)
tuple_index.16: token = tuple_index(receive_3, index=0, id=16)
receive_6: (token, bits[32]) = receive(tuple_index.16, channel=baz, id=17)
tuple_index.18: token = tuple_index(receive_6, index=0, id=18)
state: bits[32] = state_read(state_element=state, id=13)
tkn: token = literal(value=token, id=13)
receive_3: (token, bits[32]) = receive(tkn, channel=foo, id=14)
tuple_index.15: token = tuple_index(receive_3, index=0, id=15)
receive_6: (token, bits[32]) = receive(tuple_index.15, channel=baz, id=16)
tuple_index.17: token = tuple_index(receive_6, index=0, id=17)
state: bits[32] = state_read(state_element=state, id=18)
tuple_index.19: bits[32] = tuple_index(receive_3, index=1, id=19)
tuple_index.20: bits[32] = tuple_index(receive_6, index=1, id=20)
send_9: token = send(tuple_index.18, state, channel=bar, id=21)
send_9: token = send(tuple_index.17, state, channel=bar, id=21)
add.22: bits[32] = add(tuple_index.19, tuple_index.20, id=22)
send_10: token = send(send_9, state, channel=baz, id=23)
next_value.24: () = next_value(param=state, value=add.22, id=24)
Expand Down Expand Up @@ -556,7 +593,8 @@ TEST_F(ScheduledProcTest, StageAddAndClear) {
proc->ClearStages();
EXPECT_TRUE(proc->stages().empty());
// Re-stage the state element to satisfy the verifier.
XLS_ASSERT_OK(proc->AddNodeToStage(0, proc->GetStateRead(0)).status());
XLS_ASSERT_OK(
proc->AddNodeToStage(0, proc->GetStateReads(0).front()).status());
}

TEST_F(ScheduledProcTest, AddEmptyStages) {
Expand Down Expand Up @@ -596,7 +634,8 @@ TEST_F(ScheduledProcTest, GetStageIndex) {
EXPECT_THAT(proc->GetStageIndex(x), IsOkAndHolds(1));
EXPECT_THAT(proc->GetStageIndex(y), IsOkAndHolds(2));
EXPECT_THAT(proc->GetStageIndex(add), StatusIs(absl::StatusCode::kNotFound));
EXPECT_THAT(proc->GetStageIndex(proc->GetStateRead(0)), IsOkAndHolds(0));
EXPECT_THAT(proc->GetStageIndex(proc->GetStateReads(0).front()),
IsOkAndHolds(0));

// The verifier requires that every node be in a stage before we finish.
ASSERT_THAT(proc->AddNodeToStage(2, add), IsOkAndHolds(true));
Expand Down
24 changes: 15 additions & 9 deletions xls/ir/proc_testutils.cc
Original file line number Diff line number Diff line change
Expand Up @@ -244,11 +244,14 @@ absl::StatusOr<std::vector<BValue>> GetStateValuesBeforeActivation(
absl::flat_hash_map<NodeActivation, BValue>& values) {
std::vector<BValue> states;
for (StateElement* state_element : p->StateElements()) {
StateRead* state_read = p->GetStateReadByStateElement(state_element);
absl::Span<StateRead* const> reads =
p->GetStateReadsByStateElement(state_element);
XLS_RET_CHECK(!reads.empty()) << "No reads for " << state_element;

BValue state_val;
if (activation == 0) {
values[{state_read, 0}] =
fb.Literal(state_element->initial_value(), SourceInfo(),
absl::StrFormat("%s_initial_value", p->name()));
state_val = fb.Literal(state_element->initial_value(), SourceInfo(),
absl::StrFormat("%s_initial_value", p->name()));
} else {
std::vector<BValue> cases;
std::vector<BValue> selectors;
Expand All @@ -261,23 +264,26 @@ absl::StatusOr<std::vector<BValue>> GetStateValuesBeforeActivation(
}
if (selectors.empty()) {
XLS_RET_CHECK_EQ(cases.size(), 1) << "no cases for " << state_element;
values[{state_read, activation}] = cases.front();
state_val = cases.front();
} else if (cases.front().GetType()->IsBits() &&
cases.front().GetType()->GetFlatBitCount() == 0) {
// Special case to avoid creating non-trivial uses of zero-len bit
// vectors.
values[{state_read, activation}] = fb.Literal(UBits(0, 0));
state_val = fb.Literal(UBits(0, 0));
} else {
XLS_RET_CHECK_EQ(cases.size(), selectors.size());
// materialize the next values into a select.
// Need to reverse to keep the LSB is case 0 etc.
absl::c_reverse(selectors);
values[{state_read, activation}] = fb.PrioritySelect(
state_val = fb.PrioritySelect(
fb.Concat(selectors), cases,
/*default_value=*/values[{state_read, activation - 1}]);
/*default_value=*/values[{reads.front(), activation - 1}]);
}
}
states.push_back(values[{state_read, activation}]);
for (StateRead* read : reads) {
values[{read, activation}] = state_val;
}
states.push_back(state_val);
}
return states;
}
Expand Down
Loading
Loading