Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
108 changes: 108 additions & 0 deletions libs/api/test/test_program.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -333,6 +333,114 @@ SCENARIO("Program handles Pipeline operators", "[program][pipeline]") {
}
}

SCENARIO("Program handles Pipeline operators and resets", "[program][pipeline]") {
GIVEN("A program with a Pipeline") {
std::string program_json = R"({
"operators": [
{"type": "Input", "id": "input", "portTypes": ["number"]},
{
"type": "Pipeline",
"id": "pipeline",
"input_port_types": ["number"],
"output_port_types": ["number"],
"operators": [
{"type": "Input", "id": "pinput", "portTypes": ["number"]},
{"type": "MovingAverage", "id": "ma", "window_size": 3}

],
"connections": [
{"from": "pinput", "to": "ma", "fromPort": "o1", "toPort": "i1"}
],
"entryOperator": "pinput",
"outputMappings": {
"ma": {"o1": "o1"}
}
},
{"type": "Output", "id": "output", "portTypes": ["number"]}
],
"connections": [
{"from": "input", "to": "pipeline", "fromPort": "o1", "toPort": "i1"},
{"from": "pipeline", "to": "output", "fromPort": "o1", "toPort": "i1"}
],
"entryOperator": "input",
"output": {
"output": ["o1"]
}
})";

Program program(program_json);

WHEN("Processing messages") {
// Need 5 messages:
// First 3 messages to fill MA(3) buffer
// Message 4 and 5 to emit MA values that will fill STD(3) buffer and produce output
std::vector<Message<NumberData>> messages = {
{1, NumberData{1.0}}, // MA collecting
{2, NumberData{2.0}}, // MA collecting
{3, NumberData{3.0}}, // MA emits value (3.0) -> pipeline resets
{4, NumberData{4.0}}, // MA collecting
{5, NumberData{5.0}}, // MA collecting
{6, NumberData{6.0}}, // MA emits value (5.0) -> pipeline resets
{7, NumberData{0.0}}, // MA collecting
{8, NumberData{0.0}}, // MA collecting
{9, NumberData{0.0}}, // MA emits value (0.0) -> pipeline resets
};

ProgramMsgBatch final_batch;

program.receive(messages.at(0));
program.receive(messages.at(1));
final_batch = program.receive(messages.at(2));

THEN("Pipeline processes messages correctly and resets") {
REQUIRE(final_batch.size() == 1);
REQUIRE(final_batch["output"].count("o1") == 1);
const auto* out_msg = dynamic_cast<const Message<NumberData>*>(final_batch["output"]["o1"].back().get());
REQUIRE(out_msg != nullptr);
REQUIRE(out_msg->time == 3);
REQUIRE(out_msg->data.value == 2.0);

final_batch = program.receive(messages.at(3));

AND_THEN("Pipeline start all over, ma collects") {
REQUIRE(final_batch.size() == 0);

final_batch = program.receive(messages.at(4));

AND_THEN("ma collects") {
REQUIRE(final_batch.size() == 0);

final_batch = program.receive(messages.at(5));

AND_THEN("pipeline emits") {
REQUIRE(final_batch.size() == 1);
REQUIRE(final_batch["output"].count("o1") == 1);
const auto* out_msg = dynamic_cast<const Message<NumberData>*>(final_batch["output"]["o1"].back().get());
REQUIRE(out_msg != nullptr);
REQUIRE(out_msg->time == 6);
REQUIRE(out_msg->data.value == 5.0);

program.receive(messages.at(6));
program.receive(messages.at(7));
final_batch = program.receive(messages.at(8));

AND_THEN("pipeline emits") {
REQUIRE(final_batch.size() == 1);
REQUIRE(final_batch["output"].count("o1") == 1);
const auto* out_msg =
dynamic_cast<const Message<NumberData>*>(final_batch["output"]["o1"].back().get());
REQUIRE(out_msg != nullptr);
REQUIRE(out_msg->time == 9);
REQUIRE(out_msg->data.value == 0.0);
}
}
}
}
}
}
}
}

SCENARIO("Program handles Pipeline serialization", "[program][pipeline]") {
GIVEN("A program with a stateful Pipeline") {
std::string program_json = R"({
Expand Down
61 changes: 27 additions & 34 deletions libs/core/include/rtbot/Pipeline.h
Original file line number Diff line number Diff line change
Expand Up @@ -98,30 +98,13 @@ class Pipeline : public Operator {

void reset() override {
RTBOT_LOG_DEBUG("Resetting pipeline");
// First reset our own state
Operator::reset();

// Then reset all internal operators
for (auto& [_, op] : operators_) {
op->reset();
}
}

void clear_all_output_ports() override {
// Check if we produced any output
bool has_output = false;
for (size_t i = 0; i < num_output_ports(); ++i) {
if (!get_output_queue(i).empty()) {
has_output = true;
break;
}
}

// If we produced output, reset the pipeline for next iteration
if (has_output) {
reset();
}

Operator::clear_all_output_ports();
for (auto& [_, op] : operators_) {
op->clear_all_output_ports();
Expand All @@ -145,26 +128,36 @@ class Pipeline : public Operator {
entry_operator_->receive_data(msg->clone(), i);
entry_operator_->execute();
input_queue.pop_front();
}
}

// Process output mappings
for (const auto& [op_id, mappings] : output_mappings_) {
auto it = operators_.find(op_id);
if (it != operators_.end()) {
auto& op = it->second;
for (const auto& [operator_port, pipeline_port] : mappings) {
if (operator_port < op->num_output_ports() && pipeline_port < num_output_ports()) {
const auto& source_queue = op->get_output_queue(operator_port);
// Only forward if source operator has produced output on the mapped port
if (!source_queue.empty()) {
auto& target_queue = get_output_queue(pipeline_port);
for (const auto& msg : source_queue) {
RTBOT_LOG_DEBUG("Forwarding message ", msg->to_string(), " from ", op_id, " -> ", pipeline_port);
target_queue.push_back(msg->clone());
// Process output mappings
bool was_reseted = false;
for (const auto& [op_id, mappings] : output_mappings_) {
auto it = operators_.find(op_id);
if (it != operators_.end()) {
auto& op = it->second;
for (const auto& [operator_port, pipeline_port] : mappings) {
if (operator_port < op->num_output_ports() && pipeline_port < num_output_ports()) {
const auto& source_queue = op->get_output_queue(operator_port);
// Only forward if source operator has produced output on the mapped port
if (!source_queue.empty()) {
was_reseted = false;
auto& target_queue = get_output_queue(pipeline_port);
for (const auto& msg : source_queue) {
RTBOT_LOG_DEBUG("Forwarding message ", msg->to_string(), " from ", op_id, " -> ", pipeline_port);
target_queue.push_back(msg->clone());
reset();
was_reseted = true;
break;
}
}
}
if (was_reseted) {
break;
}
}
}
if (was_reseted) {
break;
}
}
}
}
Expand Down
28 changes: 28 additions & 0 deletions libs/std/include/rtbot/std/ResamplerConstant.h
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,34 @@ class ResamplerConstant : public Operator {

std::string type_name() const override { return "ResamplerConstant"; }

Bytes collect() override {
// First collect base state
Bytes bytes = Operator::collect();

// Serialize next emission time
bytes.insert(bytes.end(), reinterpret_cast<const uint8_t*>(&next_emit_),
reinterpret_cast<const uint8_t*>(&next_emit_) + sizeof(next_emit_));

// Serialize initialization state
bytes.insert(bytes.end(), reinterpret_cast<const uint8_t*>(&initialized_),
reinterpret_cast<const uint8_t*>(&initialized_) + sizeof(initialized_));

return bytes;
}

void restore(Bytes::const_iterator& it) override {
// First restore base state
Operator::restore(it);

// Restore next emission time
next_emit_ = *reinterpret_cast<const timestamp_t*>(&(*it));
it += sizeof(timestamp_t);

// Restore initialization state
initialized_ = *reinterpret_cast<const bool*>(&(*it));
it += sizeof(bool);
}

timestamp_t get_interval() const { return dt_; }
timestamp_t get_next_emission_time() const { return next_emit_; }
std::optional<timestamp_t> get_t0() const { return t0_; }
Expand Down
10 changes: 4 additions & 6 deletions libs/std/include/rtbot/std/ResamplerHermite.h
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,6 @@ class ResamplerHermite : public Buffer<NumberData, ResamplerFeatures> {
Buffer<NumberData, ResamplerFeatures>::reset();
initialized_ = false;
next_emit_ = 0;
pending_emissions_.clear();
}

std::string type_name() const override { return "ResamplerHermite"; }
Expand Down Expand Up @@ -125,11 +124,10 @@ class ResamplerHermite : public Buffer<NumberData, ResamplerFeatures> {
return h00 * y1 + h10 * m0 + h01 * y2 + h11 * m1;
}

timestamp_t dt_; // Resampling interval
std::optional<timestamp_t> t0_; // Optional start time
timestamp_t next_emit_; // Next time to emit a sample
bool initialized_; // Whether we've initialized next_emit_
std::vector<std::unique_ptr<Message<NumberData>>> pending_emissions_; // Queue of pending emissions
timestamp_t dt_; // Resampling interval
std::optional<timestamp_t> t0_; // Optional start time
timestamp_t next_emit_; // Next time to emit a sample
bool initialized_; // Whether we've initialized next_emit_
};

inline std::shared_ptr<ResamplerHermite> make_resampler_hermite(std::string id, timestamp_t interval,
Expand Down
31 changes: 31 additions & 0 deletions libs/std/test/test_filter_scalar.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,37 @@ SCENARIO("FilterScalarOp derived classes handle basic filtering", "[filter_scala
}
}

SECTION("GreaterThan operator small value") {
auto gt = make_greater_than("gt1", 0.5);

REQUIRE(gt->type_name() == "GreaterThan");
REQUIRE(dynamic_cast<GreaterThan*>(gt.get())->get_threshold() == 0.5);

std::vector<std::pair<timestamp_t, double>> inputs = {
{0, 0.3}, // Should be filtered
{1, 1.0}, // Should pass
{2, 4.0}, // Should pass
{4, 0.2}, // Should be filtered
{5, 0.5} // Should be filtered (not strictly greater than)
};

std::vector<std::pair<timestamp_t, double>> expected = {{1, 1.0}, {2, 4.0}};

for (const auto& input : inputs) {
gt->receive_data(create_message<NumberData>(input.first, NumberData{input.second}), 0);
}
gt->execute();

auto& output = gt->get_output_queue(0);
REQUIRE(output.size() == expected.size());

for (size_t i = 0; i < output.size(); ++i) {
auto* msg = dynamic_cast<const Message<NumberData>*>(output[i].get());
REQUIRE(msg->time == expected[i].first);
REQUIRE(msg->data.value == expected[i].second);
}
}

SECTION("EqualTo operator") {
auto eq = make_equal_to("eq1", 3.0, 0.1);

Expand Down
Loading