diff --git a/libs/api/test/test_program.cpp b/libs/api/test/test_program.cpp index 86557843..3f0df777 100644 --- a/libs/api/test/test_program.cpp +++ b/libs/api/test/test_program.cpp @@ -333,6 +333,114 @@ SCENARIO("Program handles Pipeline operators", "[program][pipeline]") { } } +SCENARIO("Program handles Pipeline operators and resets", "[program][pipeline]") { + GIVEN("A program with a Pipeline") { + std::string program_json = R"({ + "operators": [ + {"type": "Input", "id": "input", "portTypes": ["number"]}, + { + "type": "Pipeline", + "id": "pipeline", + "input_port_types": ["number"], + "output_port_types": ["number"], + "operators": [ + {"type": "Input", "id": "pinput", "portTypes": ["number"]}, + {"type": "MovingAverage", "id": "ma", "window_size": 3} + + ], + "connections": [ + {"from": "pinput", "to": "ma", "fromPort": "o1", "toPort": "i1"} + ], + "entryOperator": "pinput", + "outputMappings": { + "ma": {"o1": "o1"} + } + }, + {"type": "Output", "id": "output", "portTypes": ["number"]} + ], + "connections": [ + {"from": "input", "to": "pipeline", "fromPort": "o1", "toPort": "i1"}, + {"from": "pipeline", "to": "output", "fromPort": "o1", "toPort": "i1"} + ], + "entryOperator": "input", + "output": { + "output": ["o1"] + } + })"; + + Program program(program_json); + + WHEN("Processing messages") { + // Need 5 messages: + // First 3 messages to fill MA(3) buffer + // Message 4 and 5 to emit MA values that will fill STD(3) buffer and produce output + std::vector> messages = { + {1, NumberData{1.0}}, // MA collecting + {2, NumberData{2.0}}, // MA collecting + {3, NumberData{3.0}}, // MA emits value (3.0) -> pipeline resets + {4, NumberData{4.0}}, // MA collecting + {5, NumberData{5.0}}, // MA collecting + {6, NumberData{6.0}}, // MA emits value (5.0) -> pipeline resets + {7, NumberData{0.0}}, // MA collecting + {8, NumberData{0.0}}, // MA collecting + {9, NumberData{0.0}}, // MA emits value (0.0) -> pipeline resets + }; + + ProgramMsgBatch final_batch; + + program.receive(messages.at(0)); + program.receive(messages.at(1)); + final_batch = program.receive(messages.at(2)); + + THEN("Pipeline processes messages correctly and resets") { + REQUIRE(final_batch.size() == 1); + REQUIRE(final_batch["output"].count("o1") == 1); + const auto* out_msg = dynamic_cast*>(final_batch["output"]["o1"].back().get()); + REQUIRE(out_msg != nullptr); + REQUIRE(out_msg->time == 3); + REQUIRE(out_msg->data.value == 2.0); + + final_batch = program.receive(messages.at(3)); + + AND_THEN("Pipeline start all over, ma collects") { + REQUIRE(final_batch.size() == 0); + + final_batch = program.receive(messages.at(4)); + + AND_THEN("ma collects") { + REQUIRE(final_batch.size() == 0); + + final_batch = program.receive(messages.at(5)); + + AND_THEN("pipeline emits") { + REQUIRE(final_batch.size() == 1); + REQUIRE(final_batch["output"].count("o1") == 1); + const auto* out_msg = dynamic_cast*>(final_batch["output"]["o1"].back().get()); + REQUIRE(out_msg != nullptr); + REQUIRE(out_msg->time == 6); + REQUIRE(out_msg->data.value == 5.0); + + program.receive(messages.at(6)); + program.receive(messages.at(7)); + final_batch = program.receive(messages.at(8)); + + AND_THEN("pipeline emits") { + REQUIRE(final_batch.size() == 1); + REQUIRE(final_batch["output"].count("o1") == 1); + const auto* out_msg = + dynamic_cast*>(final_batch["output"]["o1"].back().get()); + REQUIRE(out_msg != nullptr); + REQUIRE(out_msg->time == 9); + REQUIRE(out_msg->data.value == 0.0); + } + } + } + } + } + } + } +} + SCENARIO("Program handles Pipeline serialization", "[program][pipeline]") { GIVEN("A program with a stateful Pipeline") { std::string program_json = R"({ diff --git a/libs/core/include/rtbot/Pipeline.h b/libs/core/include/rtbot/Pipeline.h index 934cd195..404a6418 100644 --- a/libs/core/include/rtbot/Pipeline.h +++ b/libs/core/include/rtbot/Pipeline.h @@ -98,9 +98,6 @@ class Pipeline : public Operator { void reset() override { RTBOT_LOG_DEBUG("Resetting pipeline"); - // First reset our own state - Operator::reset(); - // Then reset all internal operators for (auto& [_, op] : operators_) { op->reset(); @@ -108,20 +105,6 @@ class Pipeline : public Operator { } void clear_all_output_ports() override { - // Check if we produced any output - bool has_output = false; - for (size_t i = 0; i < num_output_ports(); ++i) { - if (!get_output_queue(i).empty()) { - has_output = true; - break; - } - } - - // If we produced output, reset the pipeline for next iteration - if (has_output) { - reset(); - } - Operator::clear_all_output_ports(); for (auto& [_, op] : operators_) { op->clear_all_output_ports(); @@ -145,26 +128,36 @@ class Pipeline : public Operator { entry_operator_->receive_data(msg->clone(), i); entry_operator_->execute(); input_queue.pop_front(); - } - } - - // Process output mappings - for (const auto& [op_id, mappings] : output_mappings_) { - auto it = operators_.find(op_id); - if (it != operators_.end()) { - auto& op = it->second; - for (const auto& [operator_port, pipeline_port] : mappings) { - if (operator_port < op->num_output_ports() && pipeline_port < num_output_ports()) { - const auto& source_queue = op->get_output_queue(operator_port); - // Only forward if source operator has produced output on the mapped port - if (!source_queue.empty()) { - auto& target_queue = get_output_queue(pipeline_port); - for (const auto& msg : source_queue) { - RTBOT_LOG_DEBUG("Forwarding message ", msg->to_string(), " from ", op_id, " -> ", pipeline_port); - target_queue.push_back(msg->clone()); + // Process output mappings + bool was_reseted = false; + for (const auto& [op_id, mappings] : output_mappings_) { + auto it = operators_.find(op_id); + if (it != operators_.end()) { + auto& op = it->second; + for (const auto& [operator_port, pipeline_port] : mappings) { + if (operator_port < op->num_output_ports() && pipeline_port < num_output_ports()) { + const auto& source_queue = op->get_output_queue(operator_port); + // Only forward if source operator has produced output on the mapped port + if (!source_queue.empty()) { + was_reseted = false; + auto& target_queue = get_output_queue(pipeline_port); + for (const auto& msg : source_queue) { + RTBOT_LOG_DEBUG("Forwarding message ", msg->to_string(), " from ", op_id, " -> ", pipeline_port); + target_queue.push_back(msg->clone()); + reset(); + was_reseted = true; + break; + } + } + } + if (was_reseted) { + break; } } } + if (was_reseted) { + break; + } } } } diff --git a/libs/std/include/rtbot/std/ResamplerConstant.h b/libs/std/include/rtbot/std/ResamplerConstant.h index db6e81d4..4812ce48 100644 --- a/libs/std/include/rtbot/std/ResamplerConstant.h +++ b/libs/std/include/rtbot/std/ResamplerConstant.h @@ -29,6 +29,34 @@ class ResamplerConstant : public Operator { std::string type_name() const override { return "ResamplerConstant"; } + Bytes collect() override { + // First collect base state + Bytes bytes = Operator::collect(); + + // Serialize next emission time + bytes.insert(bytes.end(), reinterpret_cast(&next_emit_), + reinterpret_cast(&next_emit_) + sizeof(next_emit_)); + + // Serialize initialization state + bytes.insert(bytes.end(), reinterpret_cast(&initialized_), + reinterpret_cast(&initialized_) + sizeof(initialized_)); + + return bytes; + } + + void restore(Bytes::const_iterator& it) override { + // First restore base state + Operator::restore(it); + + // Restore next emission time + next_emit_ = *reinterpret_cast(&(*it)); + it += sizeof(timestamp_t); + + // Restore initialization state + initialized_ = *reinterpret_cast(&(*it)); + it += sizeof(bool); + } + timestamp_t get_interval() const { return dt_; } timestamp_t get_next_emission_time() const { return next_emit_; } std::optional get_t0() const { return t0_; } diff --git a/libs/std/include/rtbot/std/ResamplerHermite.h b/libs/std/include/rtbot/std/ResamplerHermite.h index 598a1783..d1fe2a8c 100644 --- a/libs/std/include/rtbot/std/ResamplerHermite.h +++ b/libs/std/include/rtbot/std/ResamplerHermite.h @@ -30,7 +30,6 @@ class ResamplerHermite : public Buffer { Buffer::reset(); initialized_ = false; next_emit_ = 0; - pending_emissions_.clear(); } std::string type_name() const override { return "ResamplerHermite"; } @@ -125,11 +124,10 @@ class ResamplerHermite : public Buffer { return h00 * y1 + h10 * m0 + h01 * y2 + h11 * m1; } - timestamp_t dt_; // Resampling interval - std::optional t0_; // Optional start time - timestamp_t next_emit_; // Next time to emit a sample - bool initialized_; // Whether we've initialized next_emit_ - std::vector>> pending_emissions_; // Queue of pending emissions + timestamp_t dt_; // Resampling interval + std::optional t0_; // Optional start time + timestamp_t next_emit_; // Next time to emit a sample + bool initialized_; // Whether we've initialized next_emit_ }; inline std::shared_ptr make_resampler_hermite(std::string id, timestamp_t interval, diff --git a/libs/std/test/test_filter_scalar.cpp b/libs/std/test/test_filter_scalar.cpp index f565c7ae..0e74b5e6 100644 --- a/libs/std/test/test_filter_scalar.cpp +++ b/libs/std/test/test_filter_scalar.cpp @@ -68,6 +68,37 @@ SCENARIO("FilterScalarOp derived classes handle basic filtering", "[filter_scala } } + SECTION("GreaterThan operator small value") { + auto gt = make_greater_than("gt1", 0.5); + + REQUIRE(gt->type_name() == "GreaterThan"); + REQUIRE(dynamic_cast(gt.get())->get_threshold() == 0.5); + + std::vector> inputs = { + {0, 0.3}, // Should be filtered + {1, 1.0}, // Should pass + {2, 4.0}, // Should pass + {4, 0.2}, // Should be filtered + {5, 0.5} // Should be filtered (not strictly greater than) + }; + + std::vector> expected = {{1, 1.0}, {2, 4.0}}; + + for (const auto& input : inputs) { + gt->receive_data(create_message(input.first, NumberData{input.second}), 0); + } + gt->execute(); + + auto& output = gt->get_output_queue(0); + REQUIRE(output.size() == expected.size()); + + for (size_t i = 0; i < output.size(); ++i) { + auto* msg = dynamic_cast*>(output[i].get()); + REQUIRE(msg->time == expected[i].first); + REQUIRE(msg->data.value == expected[i].second); + } + } + SECTION("EqualTo operator") { auto eq = make_equal_to("eq1", 3.0, 0.1);