diff --git a/libs/api/include/rtbot/Program.h b/libs/api/include/rtbot/Program.h index 7dd39b19..85e5a783 100644 --- a/libs/api/include/rtbot/Program.h +++ b/libs/api/include/rtbot/Program.h @@ -114,14 +114,14 @@ class Program { // Message processing ProgramMsgBatch receive(const Message& msg, const std::string& port_id = "i1") { - send_to_entry(msg, port_id); + send_to_entry(msg, port_id, false); ProgramMsgBatch result = collect_outputs(false); clear_all_outputs(); return result; } ProgramMsgBatch receive_debug(const Message& msg, const std::string& port_id = "i1") { - send_to_entry(msg, port_id); + send_to_entry(msg, port_id, true); ProgramMsgBatch result = collect_outputs(true); clear_all_outputs(); return result; @@ -229,10 +229,10 @@ class Program { throw runtime_error("Could not resolve operator ID: " + id); } - void send_to_entry(const Message& msg, const std::string& port_id) { + void send_to_entry(const Message& msg, const std::string& port_id, bool debug=false) { auto port_info = OperatorJson::parse_port_name(port_id); operators_[entry_operator_id_]->receive_data(create_message(msg.time, msg.data), port_info.index); - operators_[entry_operator_id_]->execute(); + operators_[entry_operator_id_]->execute(debug); } ProgramMsgBatch collect_outputs(bool debug_mode = false) { @@ -255,7 +255,7 @@ class Program { // In debug mode, collect all ports if (debug_mode) { for (size_t i = 0; i < op->num_output_ports(); i++) { - const auto& queue = op->get_output_queue(i); + const auto& queue = op->get_debug_output_queue(i); if (!queue.empty()) { PortMsgBatch port_msgs; for (const auto& msg : queue) { diff --git a/libs/core/include/rtbot/Buffer.h b/libs/core/include/rtbot/Buffer.h index a08d9d76..7f3c3d8b 100644 --- a/libs/core/include/rtbot/Buffer.h +++ b/libs/core/include/rtbot/Buffer.h @@ -72,11 +72,38 @@ class Buffer : public Operator { return std::sqrt(variance()); } - Bytes collect() override { - Bytes bytes = Operator::collect(); + bool equals(const Buffer& other) const { + + if (window_size_ != other.window_size_) return false; + + if (buffer_.size() != other.buffer_.size()) return false; + + auto it1 = buffer_.begin(); + auto it2 = other.buffer_.begin(); + + for (; it1 != buffer_.end() && it2 != other.buffer_.end(); ++it1, ++it2) { + const auto& msg1 = *it1; + const auto& msg2 = *it2; - bytes.insert(bytes.end(), reinterpret_cast(&window_size_), - reinterpret_cast(&window_size_) + sizeof(window_size_)); + if (msg1 && msg2) { + if (msg1->time != msg2->time) return false; + if (msg1->hash() != msg2->hash()) return false; + } else return false; + } + + if constexpr (Features::TRACK_SUM) { + if (StateSerializer::hash_double(sum_) != StateSerializer::hash_double(other.sum_)) return false; + } + + if constexpr (Features::TRACK_VARIANCE) { + if (StateSerializer::hash_double(M2_) != StateSerializer::hash_double(other.M2_)) return false; + } + + return Operator::equals(other); + } + + Bytes collect() override { + Bytes bytes = Operator::collect(); size_t buffer_size = buffer_.size(); bytes.insert(bytes.end(), reinterpret_cast(&buffer_size), @@ -90,65 +117,68 @@ class Buffer : public Operator { bytes.insert(bytes.end(), msg_bytes.begin(), msg_bytes.end()); } - if constexpr (Features::TRACK_SUM) { - bytes.insert(bytes.end(), reinterpret_cast(&sum_), - reinterpret_cast(&sum_) + sizeof(sum_)); - } - - if constexpr (Features::TRACK_VARIANCE) { - bytes.insert(bytes.end(), reinterpret_cast(&M2_), - reinterpret_cast(&M2_) + sizeof(M2_)); - } - return bytes; } void restore(Bytes::const_iterator& it) override { - Operator::restore(it); - - window_size_ = *reinterpret_cast(&(*it)); - it += sizeof(size_t); + // Call base restore first + Operator::restore(it); - size_t buffer_size = *reinterpret_cast(&(*it)); - it += sizeof(size_t); + // ---- Read buffer_size safely ---- + size_t buffer_size; + std::memcpy(&buffer_size, &(*it), sizeof(buffer_size)); + it += sizeof(buffer_size); + // ---- Deserialize buffer ---- buffer_.clear(); for (size_t i = 0; i < buffer_size; ++i) { - size_t msg_size = *reinterpret_cast(&(*it)); - it += sizeof(size_t); - - Bytes msg_bytes(it, it + msg_size); - buffer_.push_back( - std::unique_ptr>(dynamic_cast*>(BaseMessage::deserialize(msg_bytes).release()))); - it += msg_size; + // Read size of each message + size_t msg_size; + std::memcpy(&msg_size, &(*it), sizeof(msg_size)); + it += sizeof(msg_size); + + // Extract message bytes + Bytes msg_bytes(it, it + msg_size); + + // Deserialize message and cast to derived type + buffer_.push_back( + std::unique_ptr>( + dynamic_cast*>(BaseMessage::deserialize(msg_bytes).release()) + ) + ); + + it += msg_size; } + // ---- Optional statistics ---- if constexpr (Features::TRACK_SUM) { - sum_ = *reinterpret_cast(&(*it)); - it += sizeof(double); + sum_ = 0.0; + if (!buffer_.empty()) { + // First pass: compute sum + for (const auto& msg : buffer_) { + sum_ += msg->data.value; + } + } } if constexpr (Features::TRACK_VARIANCE) { - M2_ = *reinterpret_cast(&(*it)); - it += sizeof(double); - - // Recompute statistics from buffer to ensure consistency - sum_ = 0.0; - M2_ = 0.0; - - if (!buffer_.empty()) { - // First pass: compute mean - for (const auto& msg : buffer_) { - sum_ += msg->data.value; + // Recompute statistics from buffer to ensure consistency + sum_ = 0.0; + M2_ = 0.0; + + if (!buffer_.empty()) { + // First pass: compute sum + for (const auto& msg : buffer_) { + sum_ += msg->data.value; + } + + // Second pass: compute M2 + double mean = sum_ / buffer_.size(); + for (const auto& msg : buffer_) { + double delta = msg->data.value - mean; + M2_ += delta * delta; + } } - - // Second pass: compute M2 - double mean = sum_ / buffer_.size(); - for (const auto& msg : buffer_) { - double delta = msg->data.value - mean; - M2_ += delta * delta; - } - } } } diff --git a/libs/core/include/rtbot/Demultiplexer.h b/libs/core/include/rtbot/Demultiplexer.h index 51bd6116..c008daa5 100644 --- a/libs/core/include/rtbot/Demultiplexer.h +++ b/libs/core/include/rtbot/Demultiplexer.h @@ -20,13 +20,11 @@ class Demultiplexer : public Operator { } // Add single data input port with type T - add_data_port(); - data_time_tracker_ = std::set(); + add_data_port(); // Add corresponding control ports (always boolean) for (size_t i = 0; i < num_ports; ++i) { - add_control_port(); - control_time_tracker_[i] = std::map(); + add_control_port(); } // Add output ports (same type as input) @@ -37,143 +35,62 @@ class Demultiplexer : public Operator { std::string type_name() const override { return "Demultiplexer"; } - size_t get_num_ports() const { return control_time_tracker_.size(); } + size_t get_num_ports() const { return num_control_ports(); } - Bytes collect() override { - Bytes bytes = Operator::collect(); // First collect base state - - // Serialize data time tracker - StateSerializer::serialize_timestamp_set(bytes, data_time_tracker_); - - // Serialize control time tracker - StateSerializer::serialize_port_control_map(bytes, control_time_tracker_); - - return bytes; + bool equals(const Demultiplexer& other) const { + return Operator::equals(other); } - - void restore(Bytes::const_iterator& it) override { - // First restore base state - Operator::restore(it); - - // Clear current state - data_time_tracker_.clear(); - control_time_tracker_.clear(); - - // Restore data time tracker - StateSerializer::deserialize_timestamp_set(it, data_time_tracker_); - - // Restore control time tracker - StateSerializer::deserialize_port_control_map(it, control_time_tracker_); - - // Validate control port count - StateSerializer::validate_port_count(control_time_tracker_.size(), num_control_ports(), "Control"); + + bool operator==(const Demultiplexer& other) const { + return equals(other); } - void reset() override { - Operator::reset(); - data_time_tracker_.clear(); - control_time_tracker_.clear(); + bool operator!=(const Demultiplexer& other) const { + return !(*this == other); } - void receive_data(std::unique_ptr msg, size_t port_index) override { - auto time = msg->time; - Operator::receive_data(std::move(msg), port_index); - - data_time_tracker_.insert(time); - } + protected: - void receive_control(std::unique_ptr msg, size_t port_index) override { - if (port_index >= num_control_ports()) { - throw std::runtime_error("Invalid control port index"); - } - - auto* ctrl_msg = dynamic_cast*>(msg.get()); - if (!ctrl_msg) { - throw std::runtime_error("Invalid control message type"); - } - - // Update control tracker - control_time_tracker_[port_index][ctrl_msg->time] = ctrl_msg->data.value; - - // Add message to queue - get_control_queue(port_index).push_back(std::move(msg)); - control_ports_with_new_data_.insert(port_index); - } - - protected: void process_data() override { - while (true) { - // Find oldest common control timestamp - auto common_control_time = TimestampTracker::find_oldest_common_time(control_time_tracker_); - if (!common_control_time) { - break; - } - - // Clean up any old input data messages - auto& data_queue = get_data_queue(0); - while (!data_queue.empty()) { - auto* msg = dynamic_cast*>(data_queue.front().get()); - if (msg && msg->time < *common_control_time) { - data_time_tracker_.erase(msg->time); - data_queue.pop_front(); - } else { - break; - } - } - - // Look for matching data message - bool message_found = false; - if (!data_queue.empty()) { - auto* msg = dynamic_cast*>(data_queue.front().get()); - if (msg && msg->time == *common_control_time) { - // Get active control ports - std::vector active_ports; - for (size_t i = 0; i < num_control_ports(); ++i) { - if (control_time_tracker_[i].at(*common_control_time)) { - active_ports.push_back(i); - } + while(true) { + + bool is_any_control_empty; + bool are_controls_sync; + do { + is_any_control_empty = false; + are_controls_sync = sync_control_inputs(); + for (int i=0; i < num_control_ports(); i++) { + if (get_control_queue(i).empty()) { + is_any_control_empty = true; + break; } + } + } while (!are_controls_sync && !is_any_control_empty ); - // Route message to all active ports - for (size_t port : active_ports) { - get_output_queue(port).push_back(data_queue.front()->clone()); - } - - data_time_tracker_.erase(msg->time); - data_queue.pop_front(); - message_found = true; - } - } - - clean_up_control_messages(*common_control_time); + if (!are_controls_sync) return; - if (!message_found) { - break; - } - } - } - - private: - void clean_up_control_messages(timestamp_t time) { - for (auto& [port, tracker] : control_time_tracker_) { - tracker.erase(time); - } - - for (size_t port = 0; port < num_control_ports(); ++port) { - auto& queue = get_control_queue(port); - while (!queue.empty()) { - auto* msg = dynamic_cast*>(queue.front().get()); - if (msg && msg->time <= time) { - queue.pop_front(); - } else { - break; + auto& data_queue = get_data_queue(0); + if (data_queue.empty()) return; + auto* msg = dynamic_cast*>(data_queue.front().get()); + auto* ctrl_msg = dynamic_cast*>(get_control_queue(0).front().get()); + if (msg && ctrl_msg && msg->time == ctrl_msg->time) { + for (int i = 0; i < num_control_ports(); i++) { + ctrl_msg = dynamic_cast*>(get_control_queue(i).front().get()); + if (ctrl_msg->data.value) { + get_output_queue(i).push_back(data_queue.front()->clone()); + } + get_control_queue(i).pop_front(); } + data_queue.pop_front(); + } else if (msg && ctrl_msg && msg->time < ctrl_msg->time) { + data_queue.pop_front(); + } else if (msg && ctrl_msg && ctrl_msg->time < msg->time) { + for (int i = 0; i < num_control_ports(); i++) + get_control_queue(i).pop_front(); + } } } - - std::set data_time_tracker_; - std::map> control_time_tracker_; }; // Factory functions for common configurations using PortType diff --git a/libs/core/include/rtbot/FilterByValue.h b/libs/core/include/rtbot/FilterByValue.h index f693c77a..25df3e70 100644 --- a/libs/core/include/rtbot/FilterByValue.h +++ b/libs/core/include/rtbot/FilterByValue.h @@ -21,6 +21,10 @@ class FilterByValue : public Operator { add_output_port(); } + bool equals(const FilterByValue& other) const { + return Operator::equals(other); + } + protected: void process_data() override { auto& input_queue = get_data_queue(0); diff --git a/libs/core/include/rtbot/Input.h b/libs/core/include/rtbot/Input.h index 8607cf01..c75d666c 100644 --- a/libs/core/include/rtbot/Input.h +++ b/libs/core/include/rtbot/Input.h @@ -20,8 +20,7 @@ class Input : public Operator { if (!PortType::is_valid_port_type(type)) { throw std::runtime_error("Unknown port type: " + type); } - PortType::add_port(*this, type, true, true); - last_sent_times_.push_back(0); + PortType::add_port(*this, type, true, false ,true); port_type_names_.push_back(type); } } @@ -31,68 +30,17 @@ class Input : public Operator { // Get port configuration const std::vector& get_port_types() const { return port_type_names_; } - // Query port state - bool has_sent(size_t port_index) const { - validate_port_index(port_index); - return last_sent_times_[port_index] > 0; + bool equals(const Input& other) const { + if (port_type_names_ != other.port_type_names_) return false; + return Operator::equals(other); } - - timestamp_t get_last_sent_time(size_t port_index) const { - validate_port_index(port_index); - return last_sent_times_[port_index]; - } - - // State serialization - Bytes collect() override { - // First collect base state - Bytes bytes = Operator::collect(); - - // Serialize last sent times - size_t num_ports = last_sent_times_.size(); - bytes.insert(bytes.end(), reinterpret_cast(&num_ports), - reinterpret_cast(&num_ports) + sizeof(num_ports)); - - for (const auto& time : last_sent_times_) { - bytes.insert(bytes.end(), reinterpret_cast(&time), - reinterpret_cast(&time) + sizeof(time)); - } - - // Serialize port type names - StateSerializer::serialize_string_vector(bytes, port_type_names_); - - return bytes; - } - - void restore(Bytes::const_iterator& it) override { - // First restore base state - Operator::restore(it); - - // Restore last sent times - size_t num_ports = *reinterpret_cast(&(*it)); - it += sizeof(size_t); - - StateSerializer::validate_port_count(num_ports, num_data_ports(), "Data"); - - last_sent_times_.clear(); - last_sent_times_.reserve(num_ports); - for (size_t i = 0; i < num_ports; ++i) { - timestamp_t time = *reinterpret_cast(&(*it)); - it += sizeof(timestamp_t); - last_sent_times_.push_back(time); - } - - // Restore port type names - StateSerializer::deserialize_string_vector(it, port_type_names_); - - // Validate port types match - if (port_type_names_.size() != num_data_ports()) { - throw std::runtime_error("Port type count mismatch during restore"); - } + + bool operator==(const Input& other) const { + return equals(other); } - void reset() override { - Operator::reset(); - last_sent_times_.assign(last_sent_times_.size(), 0); + bool operator!=(const Input& other) const { + return !(*this == other); } // Do not throw exceptions in receive_data @@ -107,19 +55,15 @@ class Input : public Operator { protected: void process_data() override { // Process each port independently to allow concurrent timestamps - for (const auto& port_index : data_ports_with_new_data_) { + for (int port_index = 0; port_index < num_data_ports(); port_index++) { const auto& input_queue = get_data_queue(port_index); if (input_queue.empty()) continue; auto& output_queue = get_output_queue(port_index); // Process all messages in input queue - for (const auto& msg : input_queue) { - // Only forward if timestamp is increasing for this specific port - if (!has_sent(port_index) || msg->time > last_sent_times_[port_index]) { - output_queue.push_back(std::move(msg->clone())); - last_sent_times_[port_index] = msg->time; - } + for (const auto& msg : input_queue) { + output_queue.push_back(std::move(msg->clone())); } // Clear processed messages @@ -127,7 +71,6 @@ class Input : public Operator { } } - void process_control() override {} // No control processing needed private: void validate_port_index(size_t port_index) const { @@ -135,8 +78,6 @@ class Input : public Operator { throw std::runtime_error("Invalid port index: " + std::to_string(port_index)); } } - - std::vector last_sent_times_; std::vector port_type_names_; }; diff --git a/libs/core/include/rtbot/Join.h b/libs/core/include/rtbot/Join.h index 39d90f13..53282a2e 100644 --- a/libs/core/include/rtbot/Join.h +++ b/libs/core/include/rtbot/Join.h @@ -30,8 +30,7 @@ class Join : public Operator { throw std::runtime_error("Unknown port type: " + type); } - PortType::add_port(*this, type, true, false); // input only - data_time_tracker_[num_data_ports() - 1] = std::set(); + PortType::add_port(*this, type, true, false, false); // input only port_type_names_.push_back(type); } @@ -40,7 +39,7 @@ class Join : public Operator { if (!PortType::is_valid_port_type(type)) { throw std::runtime_error("Unknown port type: " + type); } - PortType::add_port(*this, type, false, true); // output only + PortType::add_port(*this, type,false ,false, true); // output only } } @@ -55,8 +54,7 @@ class Join : public Operator { throw std::runtime_error("Unknown port type: " + type); } - PortType::add_port(*this, type, true, true); - data_time_tracker_[num_data_ports() - 1] = std::set(); + PortType::add_port(*this, type, true, false ,true); port_type_names_.push_back(type); } } @@ -70,8 +68,7 @@ class Join : public Operator { std::string port_type = PortType::get_port_type(); for (size_t i = 0; i < num_ports; ++i) { - PortType::add_port(*this, port_type, true, true); - data_time_tracker_[i] = std::set(); + PortType::add_port(*this, port_type, true, false ,true); port_type_names_.push_back(port_type); } } @@ -81,177 +78,51 @@ class Join : public Operator { // Get port configuration const std::vector& get_port_types() const { return port_type_names_; } - Bytes collect() override { - // First collect base state - Bytes bytes = Operator::collect(); - - // Serialize data time tracker - StateSerializer::serialize_port_timestamp_set_map(bytes, data_time_tracker_); - - // Serialize port type names - StateSerializer::serialize_string_vector(bytes, port_type_names_); - - // Serialize synchronized_data - size_t sync_size = synchronized_data.size(); - bytes.insert(bytes.end(), reinterpret_cast(&sync_size), - reinterpret_cast(&sync_size) + sizeof(sync_size)); - - for (const auto& [time, messages] : synchronized_data) { - bytes.insert(bytes.end(), reinterpret_cast(&time), - reinterpret_cast(&time) + sizeof(time)); - - size_t msg_count = messages.size(); - bytes.insert(bytes.end(), reinterpret_cast(&msg_count), - reinterpret_cast(&msg_count) + sizeof(msg_count)); - - for (const auto& msg : messages) { - Bytes msg_bytes = msg->serialize(); - size_t msg_size = msg_bytes.size(); - bytes.insert(bytes.end(), reinterpret_cast(&msg_size), - reinterpret_cast(&msg_size) + sizeof(msg_size)); - bytes.insert(bytes.end(), msg_bytes.begin(), msg_bytes.end()); - } - } - - return bytes; - } - - void restore(Bytes::const_iterator& it) override { - // First restore base state - Operator::restore(it); - - // Clear current state - data_time_tracker_.clear(); - - // Restore data time tracker - StateSerializer::deserialize_port_timestamp_set_map(it, data_time_tracker_); - - // Validate port count - StateSerializer::validate_port_count(data_time_tracker_.size(), num_data_ports(), "Data"); - - // Restore port type names - StateSerializer::deserialize_string_vector(it, port_type_names_); - - // Restore synchronized_data - synchronized_data.clear(); - size_t sync_size = *reinterpret_cast(&(*it)); - it += sizeof(size_t); - - for (size_t i = 0; i < sync_size; ++i) { - timestamp_t time = *reinterpret_cast(&(*it)); - it += sizeof(timestamp_t); - - size_t msg_count = *reinterpret_cast(&(*it)); - it += sizeof(size_t); - - std::vector> messages; - for (size_t j = 0; j < msg_count; ++j) { - size_t msg_size = *reinterpret_cast(&(*it)); - it += sizeof(size_t); - - Bytes msg_bytes(it, it + msg_size); - messages.push_back(BaseMessage::deserialize(msg_bytes)); - it += msg_size; - } - synchronized_data[time] = std::move(messages); - } - - // Validate port types match - if (port_type_names_.size() != num_data_ports()) { - throw std::runtime_error("Port type count mismatch during restore"); - } + bool equals(const Join& other) const { + return (port_type_names_ == other.port_type_names_ && Operator::equals(other)); } - void reset() override { - Operator::reset(); - for (auto& [_, tracker] : data_time_tracker_) { - tracker.clear(); - } - synchronized_data.clear(); + bool operator==(const Join& other) const { + return equals(other); } - void receive_data(std::unique_ptr msg, size_t port_index) override { - auto time = msg->time; - Operator::receive_data(std::move(msg), port_index); - - // Track timestamp - data_time_tracker_[port_index].insert(time); + bool operator!=(const Join& other) const { + return !(*this == other); } protected: // Performs synchronization of input messages - void sync() { - while (true) { - // Find oldest common timestamp across all ports - auto common_time = TimestampTracker::find_oldest_common_time(data_time_tracker_); - if (!common_time) { - break; - } - // Process messages with common timestamp - std::vector> synced_messages; - - for (size_t port = 0; port < num_data_ports(); ++port) { - auto& queue = get_data_queue(port); - bool message_found = false; - - // Find message with matching timestamp - while (!queue.empty()) { - const auto& front_msg = queue.front(); - if (front_msg->time < *common_time) { - data_time_tracker_[port].erase(front_msg->time); - queue.pop_front(); - } else if (front_msg->time == *common_time) { - synced_messages.push_back(front_msg->clone()); - data_time_tracker_[port].erase(front_msg->time); - queue.pop_front(); - message_found = true; - break; - } else { + void process_data() override { + + while(true) { + + bool is_any_empty; + bool is_sync; + do { + is_any_empty = false; + is_sync = sync_data_inputs(); + for (int i=0; i < num_data_ports(); i++) { + if (get_data_queue(i).empty()) { + is_any_empty = true; break; } - } + } + } while (!is_sync && !is_any_empty ); - if (!message_found) { - throw std::runtime_error("Implementation error in Join: TimestampTracker found common timestamp " + - std::to_string(*common_time) + " but message not found in queue for port " + - std::to_string(port)); - } - } + if (!is_sync) return; - // Store synchronized messages - synchronized_data[*common_time] = std::move(synced_messages); - } - } - - void process_data() override { - // First perform synchronization - sync(); - - // Then process synchronized data - for (const auto& [time, messages] : synchronized_data) { - for (size_t i = 0; i < messages.size(); ++i) { - get_output_queue(i).push_back(messages[i]->clone()); + for (int i=0; i < num_data_ports(); i++) { + get_output_queue(i).push_back(get_data_queue(i).front()->clone()); + get_data_queue(i).pop_front(); } - } - - // Clear synchronized data after processing - synchronized_data.clear(); - } - void process_control() override {} // Join has no control ports - - // Access to synchronized data for derived classes - const std::map>>& get_synchronized_data() const { - return synchronized_data; + } } - void clear_synchronized_data() { synchronized_data.clear(); } private: std::vector port_type_names_; - std::map> data_time_tracker_; - std::map>> synchronized_data; }; // Factory functions remain unchanged diff --git a/libs/core/include/rtbot/Message.h b/libs/core/include/rtbot/Message.h index 4923e68a..cdfdacb9 100644 --- a/libs/core/include/rtbot/Message.h +++ b/libs/core/include/rtbot/Message.h @@ -37,10 +37,31 @@ struct NumberData { static NumberData deserialize(Bytes::const_iterator& it) { NumberData data; - data.value = *reinterpret_cast(&(*it)); + + // Read double safely (avoid unaligned reinterpret_cast) + double value; + std::memcpy(&value, &(*it), sizeof(double)); + data.value = value; + it += sizeof(double); return data; } + + bool operator==(const NumberData& other) const noexcept { + return hash() == other.hash(); + } + + bool operator!=(const NumberData& other) const noexcept { + return !(*this == other); + } + + + uint64_t hash() const { + uint64_t u; + uint64_t quantize = static_cast(value * 1e9); + std::memcpy(&u, &quantize, sizeof(uint64_t)); + return u; + } }; struct BooleanData { @@ -55,10 +76,28 @@ struct BooleanData { static BooleanData deserialize(Bytes::const_iterator& it) { BooleanData data; - data.value = *reinterpret_cast(&(*it)); + + bool value; + std::memcpy(&value, &(*it), sizeof(bool)); + data.value = value; + it += sizeof(bool); return data; } + + bool operator==(const BooleanData& other) const noexcept { + return hash() == other.hash(); + } + + bool operator!=(const BooleanData& other) const noexcept { + return !(*this == other); + } + + uint64_t hash() const { + if (value) return 1; + else return 0; + } + }; struct VectorNumberData { @@ -79,17 +118,43 @@ struct VectorNumberData { static VectorNumberData deserialize(Bytes::const_iterator& it) { VectorNumberData data; - size_t size = *reinterpret_cast(&(*it)); + + // ---- read vector size ---- + size_t size; + std::memcpy(&size, &(*it), sizeof(size)); it += sizeof(size_t); data.values.reserve(size); + + // ---- read each double ---- for (size_t i = 0; i < size; ++i) { - double value = *reinterpret_cast(&(*it)); - it += sizeof(double); - data.values.push_back(value); + double value; + std::memcpy(&value, &(*it), sizeof(double)); + it += sizeof(double); + data.values.push_back(value); } return data; } + + bool operator==(const BooleanData& other) const noexcept { + return hash() == other.hash(); + } + + bool operator!=(const BooleanData& other) const noexcept { + return !(*this == other); + } + + uint64_t hash() const { + uint64_t result = 0; + for (int i=0; i < values.size(); i++) { + uint64_t u; + double value = values.at(i); + uint64_t quantize = static_cast(value * 1e9); + std::memcpy(&u, &quantize, sizeof(uint64_t)); + result = result + u; + } + return result; + } }; struct VectorBooleanData { @@ -117,22 +182,38 @@ struct VectorBooleanData { static VectorBooleanData deserialize(Bytes::const_iterator& it) { VectorBooleanData data; - size_t size = *reinterpret_cast(&(*it)); + + // ---- read vector size ---- + size_t size; + std::memcpy(&size, &(*it), sizeof(size)); it += sizeof(size_t); - // Read packed boolean values + // Number of bytes in the packed bitfield size_t byte_count = (size + 7) / 8; data.values.reserve(size); + // ---- decode packed bits ---- for (size_t i = 0; i < size; ++i) { - uint8_t byte = *(it + (i / 8)); - bool value = (byte & (1 << (i % 8))) != 0; - data.values.push_back(value); + uint8_t byte = *(it + (i / 8)); + bool value = (byte & (uint8_t(1) << (i % 8))) != 0; + data.values.push_back(value); } + // Advance iterator past the packed bytes it += byte_count; return data; } + + uint64_t hash() const { + uint64_t result = 0; + for (int i=0; i < values.size(); i++) { + uint64_t u; + if (values.at(i)) u = 1; + else u = 0; + result = result + u; + } + return result; + } }; // Base message class @@ -145,6 +226,7 @@ class BaseMessage { virtual std::type_index type() const = 0; virtual std::unique_ptr clone() const = 0; virtual std::string to_string() const = 0; + virtual std::uint64_t hash() const = 0; virtual Bytes serialize() const { Bytes bytes; @@ -224,6 +306,10 @@ class Message : public BaseMessage { return std::make_unique>(time, data); } + std::uint64_t hash() const override { + return data.hash(); + } + T data; }; @@ -231,28 +317,35 @@ class Message : public BaseMessage { inline std::unique_ptr BaseMessage::deserialize(const Bytes& bytes) { auto it = bytes.begin(); - // Read timestamp - timestamp_t time = *reinterpret_cast(&(*it)); + // ---- Read timestamp ---- + timestamp_t time; + std::memcpy(&time, &(*it), sizeof(time)); it += sizeof(timestamp_t); - // Read type information - size_t type_length = *reinterpret_cast(&(*it)); + // ---- Read type string length ---- + size_t type_length; + std::memcpy(&type_length, &(*it), sizeof(type_length)); it += sizeof(size_t); + + // ---- Read type name ---- std::string type_name(it, it + type_length); it += type_length; - // Create appropriate message type based on type_name + // ---- Dispatch based on type name ---- if (type_name == typeid(NumberData).name()) { - return Message::deserialize_data(time, it, bytes.end()); - } else if (type_name == typeid(BooleanData).name()) { - return Message::deserialize_data(time, it, bytes.end()); - } else if (type_name == typeid(VectorNumberData).name()) { - return Message::deserialize_data(time, it, bytes.end()); - } else if (type_name == typeid(VectorBooleanData).name()) { - return Message::deserialize_data(time, it, bytes.end()); - } else { - throw std::runtime_error("Unknown message type: " + type_name); + return Message::deserialize_data(time, it, bytes.end()); } + else if (type_name == typeid(BooleanData).name()) { + return Message::deserialize_data(time, it, bytes.end()); + } + else if (type_name == typeid(VectorNumberData).name()) { + return Message::deserialize_data(time, it, bytes.end()); + } + else if (type_name == typeid(VectorBooleanData).name()) { + return Message::deserialize_data(time, it, bytes.end()); + } + + throw std::runtime_error("Unknown message type: " + type_name); } template diff --git a/libs/core/include/rtbot/Multiplexer.h b/libs/core/include/rtbot/Multiplexer.h index 9785cd3c..ff16cc3e 100644 --- a/libs/core/include/rtbot/Multiplexer.h +++ b/libs/core/include/rtbot/Multiplexer.h @@ -25,66 +25,31 @@ class Multiplexer : public Operator { // Add data ports for (size_t i = 0; i < num_ports; ++i) { add_data_port(); - data_time_tracker_[i] = std::set(); } // Add corresponding control ports for (size_t i = 0; i < num_ports; ++i) { add_control_port(); - control_time_tracker_[i] = std::map(); } // Single output port add_output_port(); - } - - void reset() override { - Operator::reset(); - data_time_tracker_.clear(); - control_time_tracker_.clear(); - } + } - size_t get_num_ports() const { return data_time_tracker_.size(); } + size_t get_num_ports() const { return data_ports_.size(); } std::string type_name() const override { return "Multiplexer"; } - // State serialization - Bytes collect() override { - // First collect base state - Bytes bytes = Operator::collect(); - - // Serialize data time tracker - StateSerializer::serialize_port_timestamp_set_map(bytes, data_time_tracker_); - - // Serialize control time tracker - StateSerializer::serialize_port_control_map(bytes, control_time_tracker_); - - return bytes; + bool equals(const Multiplexer& other) const { + return Operator::equals(other); } - - void restore(Bytes::const_iterator& it) override { - // First restore base state - Operator::restore(it); - - // Clear current state - data_time_tracker_.clear(); - control_time_tracker_.clear(); - - // Restore data time tracker - StateSerializer::deserialize_port_timestamp_set_map(it, data_time_tracker_); - StateSerializer::validate_port_count(data_time_tracker_.size(), num_data_ports(), "Data"); - - // Restore control time tracker - StateSerializer::deserialize_port_control_map(it, control_time_tracker_); - StateSerializer::validate_port_count(control_time_tracker_.size(), num_control_ports(), "Control"); + + bool operator==(const Multiplexer& other) const { + return equals(other); } - void receive_data(std::unique_ptr msg, size_t port_index) override { - auto time = msg->time; - Operator::receive_data(std::move(msg), port_index); - - // Add timestamp to the tracker - data_time_tracker_[port_index].insert(time); + bool operator!=(const Multiplexer& other) const { + return !(*this == other); } void receive_control(std::unique_ptr msg, size_t port_index) override { @@ -96,138 +61,107 @@ class Multiplexer : public Operator { if (!ctrl_msg) { throw std::runtime_error("Invalid control message type"); } + + // Update last timestamp + control_ports_[port_index].last_timestamp = msg->time; - // Update control time tracker - control_time_tracker_[port_index][ctrl_msg->time] = ctrl_msg->data.value; + if (get_control_queue(port_index).size() == max_size_per_port_) { + get_control_queue(port_index).pop_front(); + } + // Add message to queue get_control_queue(port_index).push_back(std::move(msg)); - data_ports_with_new_data_.insert(port_index); + } - protected: + protected: + void process_data() override { while (true) { - // Find the next timestamp that exists in all data queues - auto common_data_time = TimestampTracker::find_oldest_common_time(data_time_tracker_); - - // Clean up old control data if we have a common data time - if (common_data_time) { - clean_old_control_timestamps(*common_data_time); - } - - // Look for matching control messages - auto common_control_time = TimestampTracker::find_oldest_common_time( - control_time_tracker_, common_data_time.value_or(std::numeric_limits::min())); - if (!common_control_time) { - break; - } - - auto port_to_emit = find_port_to_emit(*common_control_time); - if (!port_to_emit) { - clean_up_control_messages(*common_control_time); - continue; + + int num_empty_data_ports = 0; + for (int i=0; i < num_data_ports(); i++) { + if (get_data_queue(i).empty()) { + num_empty_data_ports++; + } } - - auto& input_queue = get_data_queue(*port_to_emit); - bool message_found = false; - - if (!input_queue.empty()) { - message_found = false; - for (const auto& msg_ptr : input_queue) { - if (auto* msg = dynamic_cast*>(msg_ptr.get())) { - if (msg->time == *common_control_time) { - auto& output = get_output_queue(0); - output.push_back(create_message(msg->time, msg->data)); + + if (num_empty_data_ports == num_data_ports()) return; + + bool is_any_control_empty; + bool are_control_inputs_sync; + do { + is_any_control_empty = false; + are_control_inputs_sync = sync_control_inputs(); + for (int i=0; i < num_control_ports(); i++) { + if (get_control_queue(i).empty()) { + is_any_control_empty = true; + break; + } + } + } while (!are_control_inputs_sync && !is_any_control_empty ); + + if (!are_control_inputs_sync) return; + + auto* ctrl_msg = dynamic_cast*>(get_control_queue(0).front().get()); + + int64_t port_to_emit = find_port_to_emit(ctrl_msg->time); + if (port_to_emit >= 0) { + bool message_found = false; + for (int i = 0; i < num_data_ports(); i++) { + if (!get_data_queue(i).empty()) { + auto* msg = dynamic_cast*>(get_data_queue(i).front().get()); + if (i == port_to_emit && msg->time == ctrl_msg->time) { + get_output_queue(0).push_back(create_message(msg->time, msg->data)); + get_data_queue(i).pop_front(); message_found = true; - break; // Found our message for this timestamp - } else if (msg->time > *common_control_time) { - break; // No need to check further as messages are ordered by time + } else if (i == port_to_emit && ctrl_msg->time < msg->time) { + message_found = true; + } else if (msg->time <= ctrl_msg->time) { + get_data_queue(i).pop_front(); } } } - } - - clean_up_control_messages(*common_control_time); - - if (message_found) { - clean_up_data_messages(*common_control_time); + if (message_found) { + for (int i = 0; i < num_control_ports(); i++) { + get_control_queue(i).pop_front(); + } + } } else { - break; + clean_data_input_queue_fronts(ctrl_msg->time); + for (int i = 0; i < num_control_ports(); i++) { + get_control_queue(i).pop_front(); + } } } } private: - void clean_old_control_timestamps(timestamp_t current_time) { - // Clean up control trackers - for (auto& [port, times] : control_time_tracker_) { - auto it = times.begin(); - while (it != times.end() && it->first < current_time) { - it = times.erase(it); + + void clean_data_input_queue_fronts(timestamp_t time) { + for (int i = 0; i < num_data_ports(); i++) { + if (!get_data_queue(i).empty()) { + auto* msg = dynamic_cast*>(get_data_queue(i).front().get()); + if (msg && msg->time <= time) get_data_queue(i).pop_front(); } } } - std::optional find_port_to_emit(timestamp_t time) { + int64_t find_port_to_emit(timestamp_t time) { size_t active_count = 0; - std::optional selected_port; + int64_t selected_port = -1; - for (size_t port = 0; port < num_control_ports(); ++port) { - auto it = control_time_tracker_[port].find(time); - if (it != control_time_tracker_[port].end() && it->second) { + for (size_t i = 0; i < num_control_ports(); i++) { + auto* ctrl_msg = dynamic_cast*>(get_control_queue(i).front().get()); + if ((ctrl_msg->time == time) && ctrl_msg->data.value) { active_count++; - selected_port = port; + selected_port = static_cast(i); } } - return (active_count == 1) ? selected_port : std::nullopt; - } - - void clean_up_data_messages(timestamp_t time) { - for (size_t port = 0; port < num_data_ports(); ++port) { - auto& data_queue = get_data_queue(port); - while (!data_queue.empty()) { - auto* msg = dynamic_cast*>(data_queue.front().get()); - if (msg && msg->time <= time) { - data_time_tracker_[port].erase(msg->time); - data_queue.pop_front(); - } else { - break; - } - } - } + return (active_count == 1) ? selected_port : -1; } - - void clean_up_control_messages(timestamp_t time) { - // Clean up control queues and trackers - for (size_t port = 0; port < num_control_ports(); ++port) { - // Clear all control messages up to and including current time - auto& control_queue = get_control_queue(port); - while (!control_queue.empty()) { - if (auto* msg = dynamic_cast*>(control_queue.front().get())) { - if (msg->time <= time) { - control_queue.pop_front(); - } else { - break; - } - } - } - - // Clean up the tracker for this timestamp - auto& port_tracker = control_time_tracker_[port]; - auto it = port_tracker.begin(); - while (it != port_tracker.end() && it->first <= time) { - it = port_tracker.erase(it); - } - } - } - - // Track all available timestamps for each data port - std::map> data_time_tracker_; - - // Track control values for each port and timestamp - std::map> control_time_tracker_; }; // Factory function for creating a Multiplexer operator diff --git a/libs/core/include/rtbot/Operator.h b/libs/core/include/rtbot/Operator.h index 83b17734..ac146773 100644 --- a/libs/core/include/rtbot/Operator.h +++ b/libs/core/include/rtbot/Operator.h @@ -1,13 +1,15 @@ #ifndef OPERATOR_H #define OPERATOR_H +#define MAX_SIZE_PER_PORT 17280 + #include #include #include #include #include #include -#include +#include #include #include #include @@ -35,7 +37,7 @@ enum class PortKind { DATA, CONTROL }; // Base operator class class Operator { public: - Operator(std::string id) : id_(std::move(id)) {} + Operator(std::string id) : id_(std::move(id)), max_size_per_port_(MAX_SIZE_PER_PORT) {} virtual ~Operator() = default; virtual std::string type_name() const = 0; @@ -55,86 +57,58 @@ class Operator { reinterpret_cast(&control_ports_count) + sizeof(control_ports_count)); bytes.insert(bytes.end(), reinterpret_cast(&output_ports_count), reinterpret_cast(&output_ports_count) + sizeof(output_ports_count)); - - // Serialize port types - for (const auto& port : data_ports_) { - StateSerializer::serialize_type_index(bytes, port.type); - } - for (const auto& port : control_ports_) { - StateSerializer::serialize_type_index(bytes, port.type); - } - for (const auto& port : output_ports_) { - StateSerializer::serialize_type_index(bytes, port.type); - } - // Serialize message queues for (const auto& port : data_ports_) { StateSerializer::serialize_message_queue(bytes, port.queue); + bytes.insert(bytes.end(), reinterpret_cast(&port.last_timestamp), + reinterpret_cast(&port.last_timestamp) + sizeof(port.last_timestamp)); } for (const auto& port : control_ports_) { StateSerializer::serialize_message_queue(bytes, port.queue); + bytes.insert(bytes.end(), reinterpret_cast(&port.last_timestamp), + reinterpret_cast(&port.last_timestamp) + sizeof(port.last_timestamp)); } for (const auto& port : output_ports_) { StateSerializer::serialize_message_queue(bytes, port.queue); - } - // Serialize connection stateful data - for (const auto& conn : connections_) { - bytes.insert(bytes.end(), reinterpret_cast(&conn.last_propagated_index), - reinterpret_cast(&conn.last_propagated_index) + sizeof(conn.last_propagated_index)); - } - - // Serialize ports with new data sets - StateSerializer::serialize_index_set(bytes, data_ports_with_new_data_); - StateSerializer::serialize_index_set(bytes, control_ports_with_new_data_); + } return bytes; } virtual void restore(Bytes::const_iterator& it) { - // Read port counts - size_t data_ports_count = *reinterpret_cast(&(*it)); + // ---- Read port counts safely ---- + size_t data_ports_count; + std::memcpy(&data_ports_count, &(*it), sizeof(data_ports_count)); it += sizeof(size_t); - size_t control_ports_count = *reinterpret_cast(&(*it)); + + size_t control_ports_count; + std::memcpy(&control_ports_count, &(*it), sizeof(control_ports_count)); it += sizeof(size_t); - size_t output_ports_count = *reinterpret_cast(&(*it)); + + size_t output_ports_count; + std::memcpy(&output_ports_count, &(*it), sizeof(output_ports_count)); it += sizeof(size_t); - // Validate port counts match current configuration + // ---- Validate counts ---- StateSerializer::validate_port_count(data_ports_count, data_ports_.size(), "Data"); StateSerializer::validate_port_count(control_ports_count, control_ports_.size(), "Control"); StateSerializer::validate_port_count(output_ports_count, output_ports_.size(), "Output"); - // Validate and restore port types + // ---- Restore message queues ---- for (auto& port : data_ports_) { - StateSerializer::validate_and_restore_type(it, port.type); + StateSerializer::deserialize_message_queue(it, port.queue); + std::memcpy(&port.last_timestamp, &(*it), sizeof(port.last_timestamp)); + it += sizeof(timestamp_t); } for (auto& port : control_ports_) { - StateSerializer::validate_and_restore_type(it, port.type); + StateSerializer::deserialize_message_queue(it, port.queue); + std::memcpy(&port.last_timestamp, &(*it), sizeof(port.last_timestamp)); + it += sizeof(timestamp_t); } for (auto& port : output_ports_) { - StateSerializer::validate_and_restore_type(it, port.type); - } - - // Restore message queues - for (auto& port : data_ports_) { - StateSerializer::deserialize_message_queue(it, port.queue); - } - for (auto& port : control_ports_) { - StateSerializer::deserialize_message_queue(it, port.queue); - } - for (auto& port : output_ports_) { - StateSerializer::deserialize_message_queue(it, port.queue); - } - - // Restore connections (excluding child pointers) - for (size_t i = 0; i < connections_.size(); ++i) { - connections_[i].last_propagated_index = *reinterpret_cast(&(*it)); - it += sizeof(size_t); + StateSerializer::deserialize_message_queue(it, port.queue); } - // Restore ports with new data sets - StateSerializer::deserialize_index_set(it, data_ports_with_new_data_); - StateSerializer::deserialize_index_set(it, control_ports_with_new_data_); } // Dynamic port management with type information @@ -156,6 +130,7 @@ class Operator { size_t num_data_ports() const { return data_ports_.size(); } size_t num_control_ports() const { return control_ports_.size(); } size_t num_output_ports() const { return output_ports_.size(); } + size_t max_size_per_port() const { return max_size_per_port_; } // Runtime port access for data with type checking virtual void receive_data(std::unique_ptr msg, size_t port_index) { @@ -182,9 +157,13 @@ class Operator { #ifdef RTBOT_INSTRUMENTATION RTBOT_RECORD_MESSAGE(id_, type_name(), std::move(msg->clone())); #endif + + + if (data_ports_[port_index].queue.size() == max_size_per_port_) { + data_ports_[port_index].queue.pop_front(); + } data_ports_[port_index].queue.push_back(std::move(msg)); - data_ports_with_new_data_.insert(port_index); } virtual void reset() { @@ -200,8 +179,9 @@ class Operator { for (auto& port : output_ports_) { port.queue.clear(); } - data_ports_with_new_data_.clear(); - control_ports_with_new_data_.clear(); + for (auto& queue : debug_output_queues_) { + queue.clear(); + } } // This should be called by the runtime to clear all output ports before executing @@ -210,30 +190,26 @@ class Operator { for (auto& port : output_ports_) { port.queue.clear(); } - // also clear all connections - for (auto& conn : connections_) { - conn.last_propagated_index = 0; + for (auto& queue : debug_output_queues_) { + queue.clear(); } } - void execute() { + void execute(bool debug=false) { SpanScope span_scope{"operator_execute"}; - RTBOT_ADD_ATTRIBUTE("operator.id", id_); - - if (data_ports_with_new_data_.empty() && control_ports_with_new_data_.empty()) { - return; - } + RTBOT_ADD_ATTRIBUTE("operator.id", id_); // Process control messages first - if (!control_ports_with_new_data_.empty()) { + if (num_control_ports() > 0) { SpanScope control_scope{"process_control"}; - process_control(); - control_ports_with_new_data_.clear(); + process_control(); } // Then process data - process_data(); - data_ports_with_new_data_.clear(); + if (num_data_ports() > 0) { + SpanScope data_scope{"process_data"}; + process_data(); + } #ifdef RTBOT_INSTRUMENTATION for (size_t i = 0; i < output_ports_.size(); i++) { @@ -243,7 +219,7 @@ class Operator { } #endif - propagate_outputs(); + propagate_outputs(debug); } // Runtime port access for control messages with type checking @@ -269,8 +245,13 @@ class Operator { // Update last timestamp control_ports_[port_index].last_timestamp = msg->time; + + if (control_ports_[port_index].queue.size() == max_size_per_port_) { + control_ports_[port_index].queue.pop_front(); + } + control_ports_[port_index].queue.push_back(std::move(msg)); - control_ports_with_new_data_.insert(port_index); + } std::shared_ptr connect(std::shared_ptr child, size_t output_port = 0, @@ -308,59 +289,122 @@ class Operator { } } - connections_.push_back({child, output_port, child_port_index, 0, child_port_kind}); + connections_.push_back({child, output_port, child_port_index, child_port_kind}); return child; } // Get port type std::type_index get_data_port_type(size_t port_index) const { if (port_index >= data_ports_.size()) { - throw std::runtime_error("Invalid data port index"); + throw std::runtime_error("Invalid data port index for data queue"); } return data_ports_[port_index].type; } std::type_index get_control_port_type(size_t port_index) const { if (port_index >= control_ports_.size()) { - throw std::runtime_error("Invalid control port index"); + throw std::runtime_error("Invalid control port index for control queue"); } return control_ports_[port_index].type; } std::type_index get_output_port_type(size_t port_index) const { if (port_index >= output_ports_.size()) { - throw std::runtime_error("Invalid output port index"); + throw std::runtime_error("Invalid output port index for output queue"); } return output_ports_[port_index].type; } const std::string& id() const { return id_; } + bool equals(const Operator& other) const { + // Compare IDs + if (id_ != other.id_) return false; + if (type_name() != other.type_name()) return false; + + // Compare number of ports + if (data_ports_.size() != other.data_ports_.size()) return false; + if (control_ports_.size() != other.control_ports_.size()) return false; + if (output_ports_.size() != other.output_ports_.size()) return false; + + // Compare data port types and last timestamps + for (size_t i = 0; i < data_ports_.size(); ++i) { + if (data_ports_[i].type != other.data_ports_[i].type) return false; + if (data_ports_[i].last_timestamp != other.data_ports_[i].last_timestamp) return false; + + // Optional: compare queues + if (data_ports_[i].queue.size() != other.data_ports_[i].queue.size()) return false; + for (size_t j = 0; j < data_ports_[i].queue.size(); ++j) { + if (data_ports_[i].queue[j]->hash() != other.data_ports_[i].queue[j]->hash()) return false; + if (data_ports_[i].queue[j]->time != other.data_ports_[i].queue[j]->time) return false; + } + } + + // Compare control port types and last timestamps + for (size_t i = 0; i < control_ports_.size(); ++i) { + if (control_ports_[i].type != other.control_ports_[i].type) return false; + if (control_ports_[i].last_timestamp != other.control_ports_[i].last_timestamp) return false; + + if (control_ports_[i].queue.size() != other.control_ports_[i].queue.size()) return false; + for (size_t j = 0; j < control_ports_[i].queue.size(); ++j) { + if (control_ports_[i].queue[j]->hash() != other.control_ports_[i].queue[j]->hash()) return false; + if (control_ports_[i].queue[j]->time != other.control_ports_[i].queue[j]->time) return false; + } + } + + // Compare output ports queues (types can be ignored if always match) + for (size_t i = 0; i < output_ports_.size(); ++i) { + if (output_ports_[i].queue.size() != other.output_ports_[i].queue.size()) return false; + for (size_t j = 0; j < output_ports_[i].queue.size(); ++j) { + if (output_ports_[i].queue[j]->hash() != other.output_ports_[i].queue[j]->hash()) return false; + if (output_ports_[i].queue[j]->time != other.output_ports_[i].queue[j]->time) return false; + } + } + + return true; + } + + bool operator==(const Operator& other) const { + return equals(other); + } + + bool operator!=(const Operator& other) const { + return !(*this == other); + } + // Access to port queues for derived classes MessageQueue& get_data_queue(size_t port_index) { if (port_index >= data_ports_.size()) { - throw std::runtime_error("Invalid data port index"); + throw std::runtime_error("Invalid data port index for data queue"); } return data_ports_[port_index].queue; } MessageQueue& get_control_queue(size_t port_index) { if (port_index >= control_ports_.size()) { - throw std::runtime_error("Invalid control port index"); + throw std::runtime_error("Invalid control port index for control queue"); } return control_ports_[port_index].queue; } MessageQueue& get_output_queue(size_t port_index) { if (port_index >= output_ports_.size()) { - throw std::runtime_error("Invalid output port index"); + throw std::runtime_error("Invalid output port index for output queue"); } return output_ports_[port_index].queue; } + MessageQueue& get_debug_output_queue(size_t port_index) { + static MessageQueue empty; + if (port_index >= debug_output_queues_.size()) { + return empty; + } + return debug_output_queues_[port_index]; + } + const MessageQueue& get_output_queue(size_t port_index) const { if (port_index >= output_ports_.size()) { - throw std::runtime_error("Invalid output port index"); + throw std::runtime_error("Invalid output port index for output queue"); } return output_ports_[port_index].queue; } @@ -369,17 +413,110 @@ class Operator { virtual void process_data() = 0; virtual void process_control() {} - void propagate_outputs() { - // First send the messages to the connected operators - for (auto& conn : connections_) { - auto& output_queue = output_ports_[conn.output_port].queue; - size_t last_propagated_index = conn.last_propagated_index; + bool sync_data_inputs() { + + if (data_ports_.empty()) return false; + + while (true) { + // If any queue is empty, sync not possible + for (auto& port : data_ports_) { + if (port.queue.empty()) + return false; + } + + // Find min and max front timestamps + timestamp_t min_time = data_ports_.front().queue.front()->time; + timestamp_t max_time = min_time; + + for (auto& port : data_ports_) { + timestamp_t t = port.queue.front()->time; + if (t < min_time) min_time = t; + if (t > max_time) max_time = t; + } + + // All equal → synchronized + if (min_time == max_time) + return true; + + // Pop all queues that have the oldest front timestamp + for (auto& port : data_ports_) { + if (!port.queue.empty() && port.queue.front()->time == min_time) + port.queue.pop_front(); + } + + // If any queue now empty → cannot sync + for (auto& port : data_ports_) { + if (port.queue.empty()) + return false; + } + } + return false; + } + + bool sync_control_inputs() { + + if (control_ports_.empty()) return false; + while (true) { + // If any queue is empty, sync not possible + for (auto& port : control_ports_) { + if (port.queue.empty()) + return false; + } + + // Find min and max front timestamps + timestamp_t min_time = control_ports_.front().queue.front()->time; + timestamp_t max_time = min_time; + + for (auto& port : control_ports_) { + timestamp_t t = port.queue.front()->time; + if (t < min_time) min_time = t; + if (t > max_time) max_time = t; + } + + // All equal → synchronized + if (min_time == max_time) + return true; + + // Pop all queues that have the oldest front timestamp + for (auto& port : control_ports_) { + if (!port.queue.empty() && port.queue.front()->time == min_time) + port.queue.pop_front(); + } + + // If any queue now empty → cannot sync + for (auto& port : control_ports_) { + if (port.queue.empty()) + return false; + } + } + return false; + } + + void propagate_outputs(bool debug=false) { + + std::unordered_set propagated_outputs; + if (debug) { + + if (debug_output_queues_.size() != num_output_ports()) { + debug_output_queues_.clear(); + for (int i = 0; i < num_output_ports(); i++) { + debug_output_queues_.push_back(MessageQueue()); + } + } + + } else if (!debug && debug_output_queues_.size() > 0) { + debug_output_queues_.clear(); + } + + // Send the messages to the connected operators + for (auto& conn : connections_) { + auto& output_queue = output_ports_[conn.output_port].queue; if (output_queue.empty()) { continue; } - for (size_t i = last_propagated_index; i < output_queue.size(); i++) { + for (size_t i = 0; i < output_queue.size(); i++) { auto msg_copy = output_queue[i]->clone(); #ifdef RTBOT_INSTRUMENTATION RTBOT_RECORD_MESSAGE_SENT(id_, type_name(), std::to_string(i), conn.child->id(), conn.child->type_name(), @@ -392,21 +529,37 @@ class Operator { } else { conn.child->receive_control(std::move(msg_copy), conn.child_input_port); } + propagated_outputs.insert(conn.output_port); } + } + + if (debug) { + for (size_t i = 0; i < num_output_ports(); i++) { + auto& queue = output_ports_[i].queue; + for (size_t j = 0; j < queue.size(); j++) { + auto msg_copy = queue[j]->clone(); + debug_output_queues_[i].push_back(std::move(msg_copy)); + } + } + } - conn.last_propagated_index = output_queue.size(); + for (const size_t& value : propagated_outputs) { + get_output_queue(value).clear(); } + + // Then execute connected operators - for (auto& conn : connections_) { - conn.child->execute(); + for (auto& conn : connections_) { + if (conn.child != nullptr && propagated_outputs.find(conn.output_port) != propagated_outputs.end()) + conn.child->execute(debug); } + } struct Connection { std::shared_ptr child; size_t output_port; - size_t child_input_port; - size_t last_propagated_index{0}; // Track last propagated message per connection + size_t child_input_port; PortKind child_port_kind{PortKind::DATA}; }; @@ -414,9 +567,9 @@ class Operator { std::vector data_ports_; std::vector control_ports_; std::vector output_ports_; - std::vector connections_; - std::set data_ports_with_new_data_; - std::set control_ports_with_new_data_; + std::vector debug_output_queues_; + std::vector connections_; + std::size_t max_size_per_port_; }; } // namespace rtbot diff --git a/libs/core/include/rtbot/Output.h b/libs/core/include/rtbot/Output.h index d2363b46..84119c4e 100644 --- a/libs/core/include/rtbot/Output.h +++ b/libs/core/include/rtbot/Output.h @@ -26,7 +26,7 @@ class Output : public Operator { port_type_names_.push_back(type); // Add input port and matching output port - PortType::add_port(*this, type, true, true); // input port + PortType::add_port(*this, type, true, false ,true); // input port } } @@ -35,6 +35,18 @@ class Output : public Operator { // Get port configuration const std::vector& get_port_types() const { return port_type_names_; } + bool equals(const Output& other) const { + return (port_type_names_ == other.port_type_names_ && Operator::equals(other)); + } + + bool operator==(const Output& other) const { + return equals(other); + } + + bool operator!=(const Output& other) const { + return !(*this == other); + } + protected: void process_data() override { // Forward all messages from inputs to corresponding outputs diff --git a/libs/core/include/rtbot/Pipeline.h b/libs/core/include/rtbot/Pipeline.h index fdc0d2c7..70bf9426 100644 --- a/libs/core/include/rtbot/Pipeline.h +++ b/libs/core/include/rtbot/Pipeline.h @@ -32,7 +32,7 @@ class Pipeline : public Operator { throw std::runtime_error("Unknown input port type: " + type); } // Add data input port - PortType::add_port(*this, type, true, false); + PortType::add_port(*this, type, true, false ,false); input_port_types_.push_back(type); } @@ -42,7 +42,7 @@ class Pipeline : public Operator { throw std::runtime_error("Unknown output port type: " + type); } // Add output port - PortType::add_port(*this, type, false, true); + PortType::add_port(*this, type, false, false ,true); output_port_types_.push_back(type); } } @@ -113,6 +113,33 @@ class Pipeline : public Operator { std::string type_name() const override { return "Pipeline"; } + bool equals(const Pipeline& other) const { + if (input_port_types_ != other.input_port_types_) return false; + if (output_port_types_!= other.output_port_types_) return false; + if (output_mappings_ != other.output_mappings_) return false; + if (entry_operator_ != other.entry_operator_) return false; + if (entry_port_ != other.entry_port_) return false; + if (operators_.size() != other.operators_.size()) return false; + + for (const auto& [key, op1] : operators_) { + auto it = other.operators_.find(key); + if (it == other.operators_.end()) return false; + const auto& op2 = it->second; + if (!op1 || !op2) return false; + else if (*op1 != *op2) return false; + } + + return Operator::equals(other); + } + + bool operator==(const Pipeline& other) const { + return equals(other); + } + + bool operator!=(const Pipeline& other) const { + return !(*this == other); + } + protected: void process_data() override { // Check if we have an entry point configured diff --git a/libs/core/include/rtbot/PortType.h b/libs/core/include/rtbot/PortType.h index 55adeafe..51b6bbc1 100644 --- a/libs/core/include/rtbot/PortType.h +++ b/libs/core/include/rtbot/PortType.h @@ -70,30 +70,22 @@ class PortType { // Helper method to add port of specified type to an operator template - static void add_port(OperatorType& op, const std::string& port_type, bool is_data = true, bool add_output = false) { + static void add_port(OperatorType& op, const std::string& port_type, bool add_data ,bool add_control ,bool add_output) { if (port_type == NUMBER) { - if (is_data) - op.template add_data_port(); - else - op.template add_control_port(); + if (add_data) op.template add_data_port(); + if (add_control) op.template add_control_port(); if (add_output) op.template add_output_port(); } else if (port_type == BOOLEAN) { - if (is_data) - op.template add_data_port(); - else - op.template add_control_port(); + if (add_data) op.template add_data_port(); + if (add_control) op.template add_control_port(); if (add_output) op.template add_output_port(); } else if (port_type == VECTOR_NUMBER) { - if (is_data) - op.template add_data_port(); - else - op.template add_control_port(); + if (add_data) op.template add_data_port(); + if (add_control) op.template add_control_port(); if (add_output) op.template add_output_port(); } else if (port_type == VECTOR_BOOLEAN) { - if (is_data) - op.template add_data_port(); - else - op.template add_control_port(); + if (add_data) op.template add_data_port(); + if (add_control) op.template add_control_port(); if (add_output) op.template add_output_port(); } else { throw std::runtime_error("Unknown port type: " + port_type); diff --git a/libs/core/include/rtbot/ReduceJoin.h b/libs/core/include/rtbot/ReduceJoin.h index ed7ef623..5a094519 100644 --- a/libs/core/include/rtbot/ReduceJoin.h +++ b/libs/core/include/rtbot/ReduceJoin.h @@ -36,25 +36,41 @@ class ReduceJoin : public Join { // Pure virtual function to define the reduction operation virtual std::optional combine(const T& accumulator, const T& next_value) const = 0; + bool equals(const ReduceJoin& other) const { + return (initial_value_ == other.initial_value_ && Operator::equals(other)); + } + protected: void process_data() override { - // First perform synchronization - sync(); + + while(true) { - // Get synchronized data - const auto& synced_data = get_synchronized_data(); + bool is_any_empty; + bool is_sync; + do { + is_any_empty = false; + is_sync = sync_data_inputs(); + for (int i=0; i < num_data_ports(); i++) { + if (get_data_queue(i).empty()) { + is_any_empty = true; + break; + } + } + } while (!is_sync && !is_any_empty ); - // Process each synchronized set of messages - for (const auto& [time, messages] : synced_data) { - std::vector*> typed_messages; - typed_messages.reserve(messages.size()); + if (!is_sync) return; - for (const auto& msg : messages) { - const auto* typed_msg = dynamic_cast*>(msg.get()); + std::vector*> typed_messages; + timestamp_t time = 0; + // Process each synchronized set of messages + for (int i=0; i < num_data_ports(); i++) { + typed_messages.reserve(num_data_ports()); + const auto* typed_msg = dynamic_cast*>(get_data_queue(i).front().get()); if (!typed_msg) { throw std::runtime_error("Invalid message type in ReduceJoin"); } typed_messages.push_back(typed_msg); + time = typed_msg->time; } std::optional result; @@ -72,13 +88,13 @@ class ReduceJoin : public Join { } } + for (int i = 0; i < num_data_ports(); i++) + get_data_queue(i).pop_front(); + if (result.has_value()) { get_output_queue(0).push_back(create_message(time, *result)); } } - - // Clear synchronized data after processing - clear_synchronized_data(); } private: diff --git a/libs/core/include/rtbot/StateSerializer.h b/libs/core/include/rtbot/StateSerializer.h index fc37eccc..a6211f2c 100644 --- a/libs/core/include/rtbot/StateSerializer.h +++ b/libs/core/include/rtbot/StateSerializer.h @@ -20,6 +20,13 @@ using MessageQueue = std::deque>; class StateSerializer { public: // Core serialization methods + static uint64_t fnv1a(const std::string& s); + + static uint64_t hash_double(double value); + + static void serialize_checksum(Bytes& bytes, std::uint64_t checksum); + static std::uint64_t deserialize_checksum(Bytes::const_iterator& it); + static void serialize_timestamp_set(Bytes& bytes, const std::set& times); static void deserialize_timestamp_set(Bytes::const_iterator& it, std::set& times); @@ -64,6 +71,34 @@ class StateSerializer { StateSerializer() = default; }; +inline uint64_t StateSerializer::fnv1a(const std::string& s) { + uint64_t hash = 1469598103934665603ULL; // FNV offset basis + for (unsigned char c : s) { + hash ^= c; + hash *= 1099511628211ULL; // FNV prime + } + return hash; +} + +inline uint64_t StateSerializer::hash_double(double value) { + uint64_t u; + uint64_t quantized = static_cast(value * 1e9); + std::memcpy(&u, &quantized, sizeof(uint64_t)); + return u; +} + +inline void StateSerializer::serialize_checksum(Bytes& bytes, std::uint64_t checksum) { + bytes.insert(bytes.end(), reinterpret_cast(&checksum), + reinterpret_cast(&checksum) + sizeof(checksum)); +} + +inline std::uint64_t StateSerializer::deserialize_checksum(Bytes::const_iterator& it) { + std::uint64_t checksum = 0; + std::memcpy(&checksum, &(*it), sizeof(uint64_t)); + it += sizeof(uint64_t); + return checksum; +} + inline void StateSerializer::serialize_timestamp_set(Bytes& bytes, const std::set& times) { size_t size = times.size(); bytes.insert(bytes.end(), reinterpret_cast(&size), @@ -237,16 +272,27 @@ inline void StateSerializer::serialize_message_queue(Bytes& bytes, const Message inline void StateSerializer::deserialize_message_queue(Bytes::const_iterator& it, MessageQueue& queue) { queue.clear(); - size_t queue_size = *reinterpret_cast(&(*it)); - it += sizeof(size_t); + // ---- read queue_size ---- + size_t queue_size; + std::memcpy(&queue_size, &(*it), sizeof(queue_size)); + it += sizeof(queue_size); + + // ---- read each message ---- for (size_t i = 0; i < queue_size; ++i) { - size_t msg_size = *reinterpret_cast(&(*it)); - it += sizeof(size_t); + // read the size of the message + size_t msg_size; + std::memcpy(&msg_size, &(*it), sizeof(msg_size)); + it += sizeof(msg_size); + + // extract message bytes + Bytes msg_bytes(it, it + msg_size); + + // deserialize the message + queue.push_back(BaseMessage::deserialize(msg_bytes)); - Bytes msg_bytes(it, it + msg_size); - queue.push_back(BaseMessage::deserialize(msg_bytes)); - it += msg_size; + // advance iterator + it += msg_size; } } diff --git a/libs/core/test/BUILD.bazel b/libs/core/test/BUILD.bazel index 3fb787f4..0678bdf0 100644 --- a/libs/core/test/BUILD.bazel +++ b/libs/core/test/BUILD.bazel @@ -18,4 +18,5 @@ cc_test( ], "//conditions:default": [], }), + timeout = "eternal", # or "long" ) diff --git a/libs/core/test/test_demultiplexer.cpp b/libs/core/test/test_demultiplexer.cpp index dade2c7b..177f1071 100644 --- a/libs/core/test/test_demultiplexer.cpp +++ b/libs/core/test/test_demultiplexer.cpp @@ -47,6 +47,32 @@ SCENARIO("Demultiplexer routes messages based on control signals", "[demultiplex REQUIRE(second_output.empty()); } } + + WHEN("Multiple control ports are active and messages exceeds the max_size_per_port()") { + + for (int i = 0; i < demux->max_size_per_port() + 2; i++) { + demux->receive_control(create_message(i, BooleanData{true}), 0); + demux->receive_control(create_message(i, BooleanData{true}), 1); + demux->receive_data(create_message(i, NumberData{i * 2.0}), 0); + } + demux->execute(); + + THEN("Message is routed to both ports") { + const auto& first_output = demux->get_output_queue(0); + const auto& second_output = demux->get_output_queue(1); + + REQUIRE(first_output.size() == demux->max_size_per_port()); + REQUIRE(second_output.size() == demux->max_size_per_port()); + + auto* msg1 = dynamic_cast*>(first_output[0].get()); + auto* msg2 = dynamic_cast*>(second_output[0].get()); + + REQUIRE(msg1->time == 2); + REQUIRE(msg1->data.value == 4.0); + REQUIRE(msg2->time == 2); + REQUIRE(msg2->data.value == 4.0); + } + } } } @@ -131,7 +157,38 @@ SCENARIO("Demultiplexer handles timing and cleanup", "[demultiplexer]") { } } -SCENARIO("Demultiplexer handles state serialization", "[demultiplexer]") { +SCENARIO("Demultiplexer fires exception when invalid data is sent to controls", "[demultiplexer]") { + GIVEN("A demultiplexer with two ports") { + auto demux = std::make_unique>("demux", 2); + + WHEN("Data arrives but controls arrives wit bad data") { + demux->receive_data(create_message(100, NumberData{10.0}), 0); + demux->execute(); + + + THEN("No output is produced and exception is fired") { + REQUIRE_THROWS_AS(demux->receive_control(create_message(100, NumberData{10.0}), 0),std::runtime_error); + const auto& output0 = demux->get_output_queue(0); + const auto& output1 = demux->get_output_queue(1); + + REQUIRE(output0.empty()); + REQUIRE(output1.empty()); + + AND_THEN("recieved proper control and data is produced") { + demux->receive_control(create_message(100, BooleanData{false}), 0); + demux->receive_control(create_message(100, BooleanData{true}), 1); + demux->execute(); + const auto& input = demux->get_data_queue(0); + REQUIRE(input.empty()); + REQUIRE(output0.empty()); + REQUIRE(!output1.empty()); + } + } + } + } +} + +SCENARIO("Demultiplexer handles state serialization", "[demultiplexer][State]") { GIVEN("A demultiplexer with active state") { auto demux = std::make_unique>("demux", 2); @@ -150,6 +207,8 @@ SCENARIO("Demultiplexer handles state serialization", "[demultiplexer]") { demux->execute(); restored->execute(); + REQUIRE(*demux == *restored); + const auto& orig_output0 = demux->get_output_queue(0); const auto& orig_output1 = demux->get_output_queue(1); const auto& rest_output0 = restored->get_output_queue(0); diff --git a/libs/core/test/test_input.cpp b/libs/core/test/test_input.cpp index 5ffcdf1d..7b69c557 100644 --- a/libs/core/test/test_input.cpp +++ b/libs/core/test/test_input.cpp @@ -27,13 +27,27 @@ SCENARIO("Input operator handles single number port", "[input]") { } } + WHEN("Receiving a max_size_per_port() + 1 messages, only max_size_per_port() are forwarded") { + for (int i = 0; i < input->max_size_per_port() + 1; i++) { + input->receive_data(create_message(i, NumberData{i * 2.0}), 0); + } + input->execute(); + + THEN("only 11000 are forwarded") { + const auto& output = input->get_output_queue(0); + REQUIRE(output.size() == input->max_size_per_port()); + const auto* msg = dynamic_cast*>(output.front().get()); + REQUIRE(msg->time == 1); + REQUIRE(msg->data.value == 2.0); + } + } + WHEN("Receiving messages with decreasing timestamps") { input->receive_data(create_message(2, NumberData{42.0}), 0); input->receive_data(create_message(1, NumberData{24.0}), 0); input->execute(); - THEN("Only the first message is forwarded") { - REQUIRE(input->get_last_sent_time(0) == 2); + THEN("Only the first message is forwarded") { const auto& output = input->get_output_queue(0); REQUIRE(output.size() == 1); const auto* msg = dynamic_cast*>(output.front().get()); @@ -129,7 +143,7 @@ SCENARIO("Input operator factory functions work correctly", "[input]") { } } -SCENARIO("Input operator handles state serialization", "[input]") { +SCENARIO("Input operator handles state serialization", "[input][State]") { GIVEN("An input operator with multiple types and processed messages") { auto input = std::make_unique("mixed_input", std::vector{PortType::NUMBER, PortType::BOOLEAN}); @@ -150,10 +164,7 @@ SCENARIO("Input operator handles state serialization", "[input]") { restored->restore(it); THEN("State is correctly preserved") { - REQUIRE(restored->get_last_sent_time(0) == 1); - REQUIRE(restored->get_last_sent_time(1) == 2); - REQUIRE(restored->has_sent(0)); - REQUIRE(restored->has_sent(1)); + REQUIRE(*restored == *input); } AND_WHEN("New messages are received") { @@ -176,11 +187,12 @@ SCENARIO("Input operator handles state serialization", "[input]") { // Create new operator with different configuration auto mismatched = std::make_unique( - "mixed_input", std::vector{PortType::BOOLEAN, PortType::NUMBER}); // Wrong order + "mixed_input", std::vector{PortType::BOOLEAN, PortType::NUMBER}); THEN("Type mismatch is detected") { auto it = state.cbegin(); - REQUIRE_THROWS_AS(mismatched->restore(it), std::runtime_error); + mismatched->restore(it); + REQUIRE(*input != *mismatched); } } } diff --git a/libs/core/test/test_join.cpp b/libs/core/test/test_join.cpp index 60b50cf2..fd1d1b10 100644 --- a/libs/core/test/test_join.cpp +++ b/libs/core/test/test_join.cpp @@ -64,6 +64,45 @@ SCENARIO("Join operator handles basic synchronization", "[join]") { REQUIRE(msg1->data.value == 24.0); } } + + WHEN("Receiving messages above the limit max_size_per_port()") { + + for (int i = 0; i < join->max_size_per_port() + 2; i++) { + join->receive_data(create_message(i, NumberData{2.0 * i}), 1); + join->receive_data(create_message(i, NumberData{3.0 * i}), 0); + } + + join->execute(); + + THEN("max_size_per_port() messages are synchronized correctly, first and second message are dropped") { + auto& output0 = join->get_output_queue(0); + auto& output1 = join->get_output_queue(1); + + const auto& input0 = join->get_data_queue(0); + const auto& input1 = join->get_data_queue(1); + + + REQUIRE(output0.size() == join->max_size_per_port()); + REQUIRE(output1.size() == join->max_size_per_port()); + + REQUIRE(input0.size() == 0); + REQUIRE(input1.size() == 0); + + const Message* msg0; + const Message* msg1; + + for (int i = 0; i < join->max_size_per_port(); i++) { + msg0 = dynamic_cast*>(output0.front().get()); + msg1 = dynamic_cast*>(output1.front().get()); + REQUIRE(msg0->time == i+2); + REQUIRE(msg0->data.value == (i+2.0) * 3); + REQUIRE(msg1->time == i+2); + REQUIRE(msg1->data.value == (i+2.0) * 2); + output0.pop_front(); + output1.pop_front(); + } + } + } } } @@ -108,15 +147,16 @@ SCENARIO("Join operator handles multiple types", "[join]") { } } - WHEN("Sending message to wrong port type") { + WHEN("Sending message to wrong port type and port indexes") { THEN("Type mismatch is detected") { REQUIRE_THROWS_AS(join->receive_data(create_message(1, BooleanData{true}), 0), std::runtime_error); + REQUIRE_THROWS_AS(join->receive_control(create_message(1, BooleanData{true}), 0), std::runtime_error); } } } } -SCENARIO("Join operator handles state serialization", "[join]") { +SCENARIO("Join operator handles state serialization", "[join][State]") { GIVEN("A join with processed messages") { auto join = std::make_unique("join1", std::vector{PortType::NUMBER, PortType::NUMBER}); @@ -136,6 +176,7 @@ SCENARIO("Join operator handles state serialization", "[join]") { restored->restore(it); AND_WHEN("New synchronized messages are received") { + REQUIRE(*restored == *join); restored->receive_data(create_message(3, NumberData{84.0}), 0); restored->receive_data(create_message(3, NumberData{48.0}), 1); restored->execute(); diff --git a/libs/core/test/test_multiplexer.cpp b/libs/core/test/test_multiplexer.cpp index 1be43dab..6e8b3e5c 100644 --- a/libs/core/test/test_multiplexer.cpp +++ b/libs/core/test/test_multiplexer.cpp @@ -76,6 +76,27 @@ SCENARIO("Multiplexer routes messages based on control signals", "[multiplexer]" } } + WHEN("Receiving multiple messages in sequence and exceeds max_size_per_port()") { + + for (int i = 0; i < mult->max_size_per_port() + 5; i ++) { + mult->receive_control(create_message(i, BooleanData{i % 2 == 0}), 0); + mult->receive_control(create_message(i, BooleanData{i % 2 == 1}), 1); + mult->receive_data(create_message(i, NumberData{i * 2.0}), 0); + mult->receive_data(create_message(i, NumberData{i * 3.0}), 1); + } + mult->execute(); + + THEN("It forwards data from the correct ports in sequence, it drops 5 messages") { + const auto& output = mult->get_output_queue(0); + REQUIRE(output.size() == mult->max_size_per_port()); + + auto* msg1 = dynamic_cast*>(output[0].get()); + REQUIRE(msg1 != nullptr); + REQUIRE(msg1->time == 5); + REQUIRE(msg1->data.value == 15.0); + } + } + WHEN("Receiving control signals with no active port") { mult->receive_control(create_message(1, BooleanData{false}), 0); mult->receive_control(create_message(1, BooleanData{false}), 1); @@ -104,7 +125,7 @@ SCENARIO("Multiplexer routes messages based on control signals", "[multiplexer]" } } -SCENARIO("Multiplexer state serialization", "[multiplexer]") { +SCENARIO("Multiplexer state serialization", "[multiplexer][State]") { GIVEN("A multiplexer with active state") { auto mult = std::make_unique>("mult", 2); @@ -133,6 +154,8 @@ SCENARIO("Multiplexer state serialization", "[multiplexer]") { send_data(mult); send_data(restored); + REQUIRE(*mult == *restored); + // Compare outputs const auto& orig_output = mult->get_output_queue(0); const auto& rest_output = restored->get_output_queue(0); diff --git a/libs/core/test/test_output.cpp b/libs/core/test/test_output.cpp index b086014b..f8e43bcf 100644 --- a/libs/core/test/test_output.cpp +++ b/libs/core/test/test_output.cpp @@ -159,4 +159,58 @@ SCENARIO("Output operator handles type mismatches", "[output]") { } } } +} + +SCENARIO("Output operator handles state serialization", "[output][State]") { + GIVEN("An output operator with multiple types and processed messages") { + auto output = std::make_unique("mixed_output", std::vector{PortType::NUMBER, PortType::BOOLEAN}); + + output->receive_data(create_message(1, NumberData{42.0}), 0); + output->receive_data(create_message(2, BooleanData{true}), 1); + output->execute(); + + WHEN("State is serialized and restored") { + // Serialize state + Bytes state = output->collect(); + + // Create new operator with same configuration + auto restored = std::make_unique("mixed_output", std::vector{PortType::NUMBER, PortType::BOOLEAN}); + + // Restore state + auto it = state.cbegin(); + restored->restore(it); + + THEN("State is correctly preserved") { + REQUIRE(*restored == *output); + } + + AND_WHEN("New messages are received") { + restored->clear_all_output_ports(); + restored->receive_data(create_message(3, NumberData{84.0}), 0); + restored->execute(); + + THEN("Messages are processed based on restored state") { + const auto& output_queue = restored->get_output_queue(0); + REQUIRE(output_queue.size() == 1); + const auto* msg = dynamic_cast*>(output_queue.front().get()); + REQUIRE(msg->time == 3); + REQUIRE(msg->data.value == 84.0); + } + } + } + + WHEN("Restoring with mismatched configuration") { + Bytes state = output->collect(); + + // Create new operator with different configuration + auto mismatched = std::make_unique( + "mixed_input", std::vector{PortType::BOOLEAN, PortType::NUMBER}); + + THEN("Type mismatch is detected") { + auto it = state.cbegin(); + mismatched->restore(it); + REQUIRE(*output != *mismatched); + } + } + } } \ No newline at end of file diff --git a/libs/std/include/rtbot/std/ArithmeticScalar.h b/libs/std/include/rtbot/std/ArithmeticScalar.h index 41869ce4..b24a1c2a 100644 --- a/libs/std/include/rtbot/std/ArithmeticScalar.h +++ b/libs/std/include/rtbot/std/ArithmeticScalar.h @@ -24,6 +24,10 @@ class ArithmeticScalar : public Operator { // Pure virtual method that derived classes must implement virtual double apply(double value) const = 0; + bool equals(const ArithmeticScalar& other) const { + return Operator::equals(other); + } + protected: void process_data() override { auto& input_queue = get_data_queue(0); @@ -50,6 +54,18 @@ class Add : public ArithmeticScalar { double apply(double x) const override { return x + value_; } double get_value() const { return value_; } + bool equals(const Add& other) const { + return (StateSerializer::hash_double(value_) == StateSerializer::hash_double(other.value_) && ArithmeticScalar::equals(other)); + } + + bool operator==(const Add& other) const { + return equals(other); + } + + bool operator!=(const Add& other) const { + return !(*this == other); + } + private: double value_; }; @@ -61,6 +77,18 @@ class Scale : public ArithmeticScalar { double apply(double x) const override { return x * value_; } double get_value() const { return value_; } + bool equals(const Scale& other) const { + return (StateSerializer::hash_double(value_) == StateSerializer::hash_double(other.value_) && ArithmeticScalar::equals(other)); + } + + bool operator==(const Scale& other) const { + return equals(other); + } + + bool operator!=(const Scale& other) const { + return !(*this == other); + } + private: double value_; }; @@ -72,6 +100,18 @@ class Power : public ArithmeticScalar { double apply(double x) const override { return std::pow(x, value_); } double get_value() const { return value_; } + bool equals(const Power& other) const { + return (StateSerializer::hash_double(value_) == StateSerializer::hash_double(other.value_) && ArithmeticScalar::equals(other)); + } + + bool operator==(const Power& other) const { + return equals(other); + } + + bool operator!=(const Power& other) const { + return !(*this == other); + } + private: double value_; }; @@ -81,6 +121,19 @@ class Sin : public ArithmeticScalar { public: Sin(std::string id) : ArithmeticScalar(std::move(id)) {} std::string type_name() const override { return "Sin"; } + + bool equals(const Sin& other) const { + return ArithmeticScalar::equals(other); + } + + bool operator==(const Sin& other) const { + return equals(other); + } + + bool operator!=(const Sin& other) const { + return !(*this == other); + } + double apply(double x) const override { return std::sin(x); } }; @@ -88,6 +141,19 @@ class Cos : public ArithmeticScalar { public: Cos(std::string id) : ArithmeticScalar(std::move(id)) {} std::string type_name() const override { return "Cos"; } + + bool equals(const Cos& other) const { + return ArithmeticScalar::equals(other); + } + + bool operator==(const Cos& other) const { + return equals(other); + } + + bool operator!=(const Cos& other) const { + return !(*this == other); + } + double apply(double x) const override { return std::cos(x); } }; @@ -95,6 +161,19 @@ class Tan : public ArithmeticScalar { public: Tan(std::string id) : ArithmeticScalar(std::move(id)) {} std::string type_name() const override { return "Tan"; } + + bool equals(const Tan& other) const { + return ArithmeticScalar::equals(other); + } + + bool operator==(const Tan& other) const { + return equals(other); + } + + bool operator!=(const Tan& other) const { + return !(*this == other); + } + double apply(double x) const override { return std::tan(x); } }; @@ -103,6 +182,19 @@ class Exp : public ArithmeticScalar { public: Exp(std::string id) : ArithmeticScalar(std::move(id)) {} std::string type_name() const override { return "Exp"; } + + bool equals(const Exp& other) const { + return ArithmeticScalar::equals(other); + } + + bool operator==(const Exp& other) const { + return equals(other); + } + + bool operator!=(const Exp& other) const { + return !(*this == other); + } + double apply(double x) const override { return std::exp(x); } }; @@ -110,6 +202,19 @@ class Log : public ArithmeticScalar { public: Log(std::string id) : ArithmeticScalar(std::move(id)) {} std::string type_name() const override { return "Log"; } + + bool equals(const Log& other) const { + return ArithmeticScalar::equals(other); + } + + bool operator==(const Log& other) const { + return equals(other); + } + + bool operator!=(const Log& other) const { + return !(*this == other); + } + double apply(double x) const override { return std::log(x); } }; @@ -117,6 +222,19 @@ class Log10 : public ArithmeticScalar { public: Log10(std::string id) : ArithmeticScalar(std::move(id)) {} std::string type_name() const override { return "Log10"; } + + bool equals(const Log10& other) const { + return ArithmeticScalar::equals(other); + } + + bool operator==(const Log10& other) const { + return equals(other); + } + + bool operator!=(const Log10& other) const { + return !(*this == other); + } + double apply(double x) const override { return std::log10(x); } }; @@ -125,6 +243,19 @@ class Abs : public ArithmeticScalar { public: Abs(std::string id) : ArithmeticScalar(std::move(id)) {} std::string type_name() const override { return "Abs"; } + + bool equals(const Abs& other) const { + return ArithmeticScalar::equals(other); + } + + bool operator==(const Abs& other) const { + return equals(other); + } + + bool operator!=(const Abs& other) const { + return !(*this == other); + } + double apply(double x) const override { return std::abs(x); } }; @@ -132,6 +263,19 @@ class Sign : public ArithmeticScalar { public: Sign(std::string id) : ArithmeticScalar(std::move(id)) {} std::string type_name() const override { return "Sign"; } + + bool equals(const Sign& other) const { + return ArithmeticScalar::equals(other); + } + + bool operator==(const Sign& other) const { + return equals(other); + } + + bool operator!=(const Sign& other) const { + return !(*this == other); + } + double apply(double x) const override { return x > 0 ? 1.0 : (x < 0 ? -1.0 : 0.0); } }; @@ -140,6 +284,19 @@ class Floor : public ArithmeticScalar { public: Floor(std::string id) : ArithmeticScalar(std::move(id)) {} std::string type_name() const override { return "Floor"; } + + bool equals(const Floor& other) const { + return ArithmeticScalar::equals(other); + } + + bool operator==(const Floor& other) const { + return equals(other); + } + + bool operator!=(const Floor& other) const { + return !(*this == other); + } + double apply(double x) const override { return std::floor(x); } }; @@ -147,6 +304,19 @@ class Ceil : public ArithmeticScalar { public: Ceil(std::string id) : ArithmeticScalar(std::move(id)) {} std::string type_name() const override { return "Ceil"; } + + bool equals(const Ceil& other) const { + return ArithmeticScalar::equals(other); + } + + bool operator==(const Ceil& other) const { + return equals(other); + } + + bool operator!=(const Ceil& other) const { + return !(*this == other); + } + double apply(double x) const override { return std::ceil(x); } }; @@ -154,6 +324,19 @@ class Round : public ArithmeticScalar { public: Round(std::string id) : ArithmeticScalar(std::move(id)) {} std::string type_name() const override { return "Round"; } + + bool equals(const Round& other) const { + return ArithmeticScalar::equals(other); + } + + bool operator==(const Round& other) const { + return equals(other); + } + + bool operator!=(const Round& other) const { + return !(*this == other); + } + double apply(double x) const override { return std::round(x); } }; diff --git a/libs/std/include/rtbot/std/ArithmeticSync.h b/libs/std/include/rtbot/std/ArithmeticSync.h index dd6c2dd2..8d00a720 100644 --- a/libs/std/include/rtbot/std/ArithmeticSync.h +++ b/libs/std/include/rtbot/std/ArithmeticSync.h @@ -17,6 +17,11 @@ class ArithmeticSync : public ReduceJoin { : ReduceJoin(std::move(id), num_ports, init_value) {} std::string type_name() const override = 0; + + bool equals(const ArithmeticSync& other) const { + return ReduceJoin::equals(other); + } + }; class Addition : public ArithmeticSync { @@ -26,6 +31,18 @@ class Addition : public ArithmeticSync { std::string type_name() const override { return "Addition"; } + bool equals(const Addition& other) const { + return ArithmeticSync::equals(other); + } + + bool operator==(const Addition& other) const { + return equals(other); + } + + bool operator!=(const Addition& other) const { + return !(*this == other); + } + protected: std::optional combine(const NumberData& acc, const NumberData& next) const override { return NumberData{acc.value + next.value}; @@ -38,6 +55,18 @@ class Subtraction : public ArithmeticSync { std::string type_name() const override { return "Subtraction"; } + bool equals(const Subtraction& other) const { + return ArithmeticSync::equals(other); + } + + bool operator==(const Subtraction& other) const { + return equals(other); + } + + bool operator!=(const Subtraction& other) const { + return !(*this == other); + } + protected: std::optional combine(const NumberData& acc, const NumberData& next) const override { // For 2 inputs, maintain original behavior @@ -56,6 +85,18 @@ class Multiplication : public ArithmeticSync { std::string type_name() const override { return "Multiplication"; } + bool equals(const Multiplication& other) const { + return ArithmeticSync::equals(other); + } + + bool operator==(const Multiplication& other) const { + return equals(other); + } + + bool operator!=(const Multiplication& other) const { + return !(*this == other); + } + protected: std::optional combine(const NumberData& acc, const NumberData& next) const override { return NumberData{acc.value * next.value}; @@ -68,6 +109,18 @@ class Division : public ArithmeticSync { std::string type_name() const override { return "Division"; } + bool equals(const Division& other) const { + return ArithmeticSync::equals(other); + } + + bool operator==(const Division& other) const { + return equals(other); + } + + bool operator!=(const Division& other) const { + return !(*this == other); + } + protected: std::optional combine(const NumberData& acc, const NumberData& next) const override { if (next.value == 0) { diff --git a/libs/std/include/rtbot/std/BooleanSync.h b/libs/std/include/rtbot/std/BooleanSync.h index 7dbd6fa2..103e3c38 100644 --- a/libs/std/include/rtbot/std/BooleanSync.h +++ b/libs/std/include/rtbot/std/BooleanSync.h @@ -20,6 +20,10 @@ class BooleanSync : public ReduceJoin { size_t get_num_ports() const { return num_ports_; } std::string type_name() const override = 0; + bool equals(const BooleanSync& other) const { + return ReduceJoin::equals(other); + } + protected: size_t num_ports_; }; @@ -30,6 +34,18 @@ class LogicalAnd : public BooleanSync { std::string type_name() const override { return "LogicalAnd"; } + bool equals(const LogicalAnd& other) const { + return BooleanSync::equals(other); + } + + bool operator==(const LogicalAnd& other) const { + return equals(other); + } + + bool operator!=(const LogicalAnd& other) const { + return !(*this == other); + } + protected: std::optional combine(const BooleanData& acc, const BooleanData& next) const override { return BooleanData{acc.value && next.value}; @@ -42,6 +58,18 @@ class LogicalOr : public BooleanSync { std::string type_name() const override { return "LogicalOr"; } + bool equals(const LogicalOr& other) const { + return BooleanSync::equals(other); + } + + bool operator==(const LogicalOr& other) const { + return equals(other); + } + + bool operator!=(const LogicalOr& other) const { + return !(*this == other); + } + protected: std::optional combine(const BooleanData& acc, const BooleanData& next) const override { return BooleanData{acc.value || next.value}; @@ -54,6 +82,18 @@ class LogicalXor : public BooleanSync { std::string type_name() const override { return "LogicalXor"; } + bool equals(const LogicalXor& other) const { + return BooleanSync::equals(other); + } + + bool operator==(const LogicalXor& other) const { + return equals(other); + } + + bool operator!=(const LogicalXor& other) const { + return !(*this == other); + } + protected: std::optional combine(const BooleanData& acc, const BooleanData& next) const override { return BooleanData{acc.value != next.value}; @@ -66,6 +106,18 @@ class LogicalNand : public BooleanSync { std::string type_name() const override { return "LogicalNand"; } + bool equals(const LogicalNand& other) const { + return BooleanSync::equals(other); + } + + bool operator==(const LogicalNand& other) const { + return equals(other); + } + + bool operator!=(const LogicalNand& other) const { + return !(*this == other); + } + protected: std::optional combine(const BooleanData& acc, const BooleanData& next) const override { auto temp = acc.value && next.value; @@ -79,6 +131,18 @@ class LogicalNor : public BooleanSync { std::string type_name() const override { return "LogicalNor"; } + bool equals(const LogicalNor& other) const { + return BooleanSync::equals(other); + } + + bool operator==(const LogicalNor& other) const { + return equals(other); + } + + bool operator!=(const LogicalNor& other) const { + return !(*this == other); + } + protected: std::optional combine(const BooleanData& acc, const BooleanData& next) const override { return BooleanData{!(acc.value || next.value)}; @@ -91,6 +155,18 @@ class LogicalXnor : public BooleanSync { std::string type_name() const override { return "LogicalXnor"; } + bool equals(const LogicalXnor& other) const { + return BooleanSync::equals(other); + } + + bool operator==(const LogicalXnor& other) const { + return equals(other); + } + + bool operator!=(const LogicalXnor& other) const { + return !(*this == other); + } + protected: std::optional combine(const BooleanData& acc, const BooleanData& next) const override { return BooleanData{acc.value == next.value}; @@ -103,6 +179,18 @@ class LogicalImplication : public BooleanSync { std::string type_name() const override { return "LogicalImplication"; } + bool equals(const LogicalImplication& other) const { + return BooleanSync::equals(other); + } + + bool operator==(const LogicalImplication& other) const { + return equals(other); + } + + bool operator!=(const LogicalImplication& other) const { + return !(*this == other); + } + protected: std::optional combine(const BooleanData& acc, const BooleanData& next) const override { return BooleanData{!acc.value || next.value}; diff --git a/libs/std/include/rtbot/std/Composite.h b/libs/std/include/rtbot/std/Composite.h deleted file mode 100644 index fa87596c..00000000 --- a/libs/std/include/rtbot/std/Composite.h +++ /dev/null @@ -1,365 +0,0 @@ -#ifndef COMPOSITE_H -#define COMPOSITE_H - -// TODO: migrate to new Operator -/* -#include -#include - -#include "rtbot/Demultiplexer.h" -#include "rtbot/Input.h" -#include "rtbot/Join.h" -#include "rtbot/Operator.h" -#include "rtbot/Output.h" -#include "rtbot/finance/RelativeStrengthIndex.h" -#include "rtbot/std/Add.h" -#include "rtbot/std/AutoRegressive.h" -#include "rtbot/std/Constant.h" -#include "rtbot/std/CosineResampler.h" -#include "rtbot/std/Count.h" -#include "rtbot/std/CumulativeSum.h" -#include "rtbot/std/Difference.h" -#include "rtbot/std/Division.h" -#include "rtbot/std/EqualTo.h" -#include "rtbot/std/FiniteImpulseResponse.h" -#include "rtbot/std/GreaterThan.h" -#include "rtbot/std/HermiteResampler.h" -#include "rtbot/std/Identity.h" -#include "rtbot/std/LessThan.h" -#include "rtbot/std/Linear.h" -#include "rtbot/std/Minus.h" -#include "rtbot/std/MovingAverage.h" -#include "rtbot/std/PeakDetector.h" -#include "rtbot/std/Power.h" -#include "rtbot/std/Scale.h" -#include "rtbot/std/StandardDeviation.h" -#include "rtbot/std/TimeShift.h" -#include "rtbot/std/Variable.h" - -namespace rtbot { - -using namespace std; - -template -struct Composite : public Operator // TODO: improve from chain to graph -{ - Composite() = default; - - Composite(string const &id) : Operator(id) { this->input == nullptr; } - - string typeName() const override { return "Composite"; } - - virtual ~Composite() = default; - - string createInput(string id, size_t numPorts = 1) { - if (this->ops.count(id) == 0 && this->input == nullptr) { - this->input = make_shared>(id, numPorts); - vector dataI = this->input->getDataInputs(); - vector controlI = this->input->getControlInputs(); - for (int i = 0; i < dataI.size(); i++) this->addDataInput(dataI.at(i), 1); - for (int i = 0; i < controlI.size(); i++) this->addControlInput(controlI.at(i), 1); - this->ops.emplace(id, this->input); - } else if (this->ops.count(id) > 0) - throw std::runtime_error(typeName() + ": unique id is required"); - else if (this->input != nullptr) - throw std::runtime_error(typeName() + ": input operator already setup"); - return id; - } - - string createAdd(string id, V value) { - if (this->ops.count(id) == 0) { - this->ops.emplace(id, make_shared>(id, value)); - return id; - } else - throw std::runtime_error(typeName() + ": unique id is required"); - } - - string createAutoregressive(string id, vector const &coeff) { - if (this->ops.count(id) == 0) { - this->ops.emplace(id, make_shared>(id, coeff)); - return id; - } else - throw std::runtime_error(typeName() + ": unique id is required"); - } - - string createConstant(string id, V value) { - if (this->ops.count(id) == 0) { - this->ops.emplace(id, make_shared>(id, value)); - return id; - } else - throw std::runtime_error(typeName() + ": unique id is required"); - } - - string createCosineResampler(string id, T dt) { - if (this->ops.count(id) == 0) { - this->ops.emplace(id, make_shared>(id, dt)); - return id; - } else - throw std::runtime_error(typeName() + ": unique id is required"); - } - - string createCount(string id) { - if (this->ops.count(id) == 0) { - this->ops.emplace(id, make_shared>(id)); - return id; - } else - throw std::runtime_error(typeName() + ": unique id is required"); - } - - string createCumulativeSum(string id) { - if (this->ops.count(id) == 0) { - this->ops.emplace(id, make_shared>(id)); - return id; - } else - throw std::runtime_error(typeName() + ": unique id is required"); - } - - string createDifference(string id, bool useOldestTime = true) { - if (this->ops.count(id) == 0) { - this->ops.emplace(id, make_shared>(id, useOldestTime)); - return id; - } else - throw std::runtime_error(typeName() + ": unique id is required"); - } - - string createJoin(string id, size_t numPorts) { - if (this->ops.count(id) == 0) { - this->ops.emplace(id, make_shared>(id, numPorts)); - return id; - } else - throw std::runtime_error(typeName() + ": unique id is required"); - } - - string createDivision(string id) { - if (this->ops.count(id) == 0) { - this->ops.emplace(id, make_shared>(id)); - return id; - } else - throw std::runtime_error(typeName() + ": unique id is required"); - } - - string createMinus(string id) { - if (this->ops.count(id) == 0) { - this->ops.emplace(id, make_shared>(id)); - return id; - } else - throw std::runtime_error(typeName() + ": unique id is required"); - } - - string createLinear(string id, vector const &coeff) { - if (this->ops.count(id) == 0) { - this->ops.emplace(id, make_shared>(id, coeff)); - return id; - } else - throw std::runtime_error(typeName() + ": unique id is required"); - } - - string createEqualTo(string id, V value) { - if (this->ops.count(id) == 0) { - this->ops.emplace(id, make_shared>(id, value)); - return id; - } else - throw std::runtime_error(typeName() + ": unique id is required"); - } - - string createGreaterThan(string id, V value) { - if (this->ops.count(id) == 0) { - this->ops.emplace(id, make_shared>(id, value)); - return id; - } else - throw std::runtime_error(typeName() + ": unique id is required"); - } - - string createLessThan(string id, V value) { - if (this->ops.count(id) == 0) { - this->ops.emplace(id, make_shared>(id, value)); - return id; - } else - throw std::runtime_error(typeName() + ": unique id is required"); - } - - string createHermiteResampler(string id, T dt) { - if (this->ops.count(id) == 0) { - this->ops.emplace(id, make_shared>(id, dt)); - return id; - } else - throw std::runtime_error(typeName() + ": unique id is required"); - } - - string createIdentity(string id) { - if (this->ops.count(id) == 0) { - this->ops.emplace(id, make_shared>(id)); - return id; - } else - throw std::runtime_error(typeName() + ": unique id is required"); - } - - string createMovingAverage(string id, size_t n) { - if (this->ops.count(id) == 0) { - this->ops.emplace(id, make_shared>(id, n)); - return id; - } else - throw std::runtime_error(typeName() + ": unique id is required"); - } - - string createFiniteImpulseResponse(string id, vector coeff) { - if (this->ops.count(id) == 0) { - this->ops.emplace(id, make_shared>(id, coeff)); - return id; - } else - throw std::runtime_error(typeName() + ": unique id is required"); - } - - string createVariable(string id, V value = 0) { - if (this->ops.count(id) == 0) { - this->ops.emplace(id, make_shared>(id, value)); - return id; - } else - throw std::runtime_error(typeName() + ": unique id is required"); - } - - string createOutput(string id, size_t numPorts = 1) { - if (this->ops.count(id) == 0 && this->output == nullptr) { - this->output = make_shared>(id, numPorts); - vector outs = this->output->getOutputs(); - for (int i = 0; i < outs.size(); i++) this->addOutput(outs.at(i)); - this->ops.emplace(id, this->output); - } else if (this->ops.count(id) > 0) - throw std::runtime_error(typeName() + ": unique id is required"); - else if (this->output != nullptr) - throw std::runtime_error(typeName() + ": output operator already setup"); - - return id; - } - - string createPeakDetector(string id, size_t n) { - if (this->ops.count(id) == 0) { - this->ops.emplace(id, make_shared>(id, n)); - return id; - } else - throw std::runtime_error(typeName() + ": unique id is required"); - } - - string createStandardDeviation(string id, size_t n) { - if (this->ops.count(id) == 0) { - this->ops.emplace(id, make_shared>(id, n)); - return id; - } else - throw std::runtime_error(typeName() + ": unique id is required"); - } - - string createDemultiplexer(string id, size_t numOutputPorts = 2) { - if (this->ops.count(id) == 0) { - this->ops.emplace(id, make_shared>(id, numOutputPorts)); - return id; - } else - throw std::runtime_error(typeName() + ": unique id is required"); - } - - string createPower(string id, V value) { - if (this->ops.count(id) == 0) { - this->ops.emplace(id, make_shared>(id, value)); - return id; - } else - throw std::runtime_error(typeName() + ": unique id is required"); - } - - string createScale(string id, V value) { - if (this->ops.count(id) == 0) { - this->ops.emplace(id, make_shared>(id, value)); - return id; - } else - throw std::runtime_error(typeName() + ": unique id is required"); - } - - string createTimeShift(string id, T dt = 1, int times = 1) { - if (this->ops.count(id) == 0) { - this->ops.emplace(id, make_shared>(id, dt, times)); - return id; - } else - throw std::runtime_error(typeName() + ": unique id is required"); - } - - shared_ptr> getOperator(string id) { - if (this->ops.count(id) == 1) { - return this->ops.find(id)->second; - } else - throw std::runtime_error(typeName() + ": operator now found"); - } - - Operator *createInternalConnection(string operatorId, string childId, string outputPort = "", - string inputPort = "") { - if (this->ops.count(operatorId) == 0) - throw std::runtime_error(typeName() + ": operator " + operatorId + - " has not been added to the operator list, please use add[Operator] first"); - if (this->ops.count(childId) == 0) - throw std::runtime_error(typeName() + ": operator " + childId + - " has not been added to the operator list, please use add[Operator] first"); - auto op = this->ops.find(operatorId)->second; - auto child = this->ops.find(childId)->second; - auto childptr = op.get()->connect(child.get(), outputPort, inputPort); - if (childptr != nullptr) - return this; - else - throw std::runtime_error(typeName() + ": connection was not successful"); - } - - virtual void receiveData(Message msg, string inputPort = "") override { - if (this->input != nullptr) - this->input.get()->receiveData(msg, inputPort); - else - throw std::runtime_error(typeName() + ": the input operator have not been found"); - } - - virtual ProgramMessage executeData() override { - if (this->input != nullptr) - return this->input.get()->executeData(); - else - throw std::runtime_error(typeName() + ": the input operator have not been found"); - } - - virtual void receiveControl(Message msg, string inputPort = "") override { - if (this->input != nullptr) - this->input.get()->receiveControl(msg, inputPort); - else - throw std::runtime_error(typeName() + ": the input operator have not been found"); - } - - virtual ProgramMessage executeControl() override { - if (this->input != nullptr) - return this->input.get()->executeControl(); - else - throw std::runtime_error(typeName() + ": the input operator have not been found"); - } - - virtual OperatorMessage processData() override { return {}; } - - virtual OperatorMessage processControl() override { return {}; } - - virtual Operator *connect(Operator &child, string outputPort = "", string inputPort = "") override { - if (this->output != nullptr) - return this->output->connect(child, outputPort, inputPort); - else if (this->output == nullptr) - throw std::runtime_error(typeName() + ": output operator have not been found"); - return nullptr; - } - - virtual Operator *connect(Operator *child, string outputPort = "", string inputPort = "") override { - if (this->output != nullptr) - return this->output->connect(child, outputPort, inputPort); - else if (this->output == nullptr) - throw std::runtime_error(typeName() + ": output operator have not been found"); - return nullptr; - } - - private: - map>> ops; - shared_ptr> input; - shared_ptr> output; -}; - -} // namespace rtbot - -#endif // COMPOSITE_H - -*/ \ No newline at end of file diff --git a/libs/std/include/rtbot/std/Constant.h b/libs/std/include/rtbot/std/Constant.h index 655fe771..1097246f 100644 --- a/libs/std/include/rtbot/std/Constant.h +++ b/libs/std/include/rtbot/std/Constant.h @@ -21,6 +21,18 @@ class Constant : public Operator { // Accessor for the constant value const OutputT& get_value() const { return value_; } + bool equals(const Constant& other) const { + return (value_ == other.value_ && Operator::equals(other)); + } + + bool operator==(const Constant& other) const { + return equals(other); + } + + bool operator!=(const Constant& other) const { + return !(*this == other); + } + protected: void process_data() override { auto& input_queue = get_data_queue(0); diff --git a/libs/std/include/rtbot/std/Count.h b/libs/std/include/rtbot/std/Count.h index 64e5aee7..f907acb8 100644 --- a/libs/std/include/rtbot/std/Count.h +++ b/libs/std/include/rtbot/std/Count.h @@ -24,6 +24,18 @@ class Count : public Operator { std::string type_name() const override { return "Count"; } + bool equals(const Count& other) const { + return (count_ == other.count_ && Operator::equals(other)); + } + + bool operator==(const Count& other) const { + return equals(other); + } + + bool operator!=(const Count& other) const { + return !(*this == other); + } + // Serialize count_ since it's our only state Bytes collect() override { Bytes bytes = Operator::collect(); @@ -33,8 +45,12 @@ class Count : public Operator { } void restore(Bytes::const_iterator& it) override { + // Restore parent first Operator::restore(it); - count_ = *reinterpret_cast(&(*it)); + + // Safely read count_ + count_ = 0; + std::memcpy(&count_, &(*it), sizeof(count_)); it += sizeof(count_); } diff --git a/libs/std/include/rtbot/std/CumulativeSum.h b/libs/std/include/rtbot/std/CumulativeSum.h index 58afe2f6..e3a760f9 100644 --- a/libs/std/include/rtbot/std/CumulativeSum.h +++ b/libs/std/include/rtbot/std/CumulativeSum.h @@ -25,6 +25,18 @@ class CumulativeSum : public Operator { // Access current sum double get_sum() const { return sum_; } + + bool equals(const CumulativeSum& other) const { + return (StateSerializer::hash_double(sum_) == StateSerializer::hash_double(other.sum_) && Operator::equals(other)); + } + + bool operator==(const CumulativeSum& other) const { + return equals(other); // still check base class + } + + bool operator!=(const CumulativeSum& other) const { + return !(*this == other); + } // State serialization Bytes collect() override { @@ -35,9 +47,13 @@ class CumulativeSum : public Operator { } void restore(Bytes::const_iterator& it) override { + // Restore base state first Operator::restore(it); - sum_ = *reinterpret_cast(&(*it)); - it += sizeof(double); + + // Safely read a double value + std::memcpy(&sum_, &(*it), sizeof(sum_)); + it += sizeof(sum_); + } protected: diff --git a/libs/std/include/rtbot/std/Difference.h b/libs/std/include/rtbot/std/Difference.h index 0ef22fb6..d299c266 100644 --- a/libs/std/include/rtbot/std/Difference.h +++ b/libs/std/include/rtbot/std/Difference.h @@ -22,6 +22,18 @@ class Difference : public Buffer { bool get_use_oldest_time() const { return use_oldest_time_; } + bool equals(const Difference& other) const { + return (use_oldest_time_ == other.use_oldest_time_ && Buffer::equals(other)); + } + + bool operator==(const Difference& other) const { + return equals(other); + } + + bool operator!=(const Difference& other) const { + return !(*this == other); + } + protected: std::vector>> process_message(const Message* msg) override { std::vector>> output; diff --git a/libs/std/include/rtbot/std/FilterScalar.h b/libs/std/include/rtbot/std/FilterScalar.h index 45765d8c..4b660e98 100644 --- a/libs/std/include/rtbot/std/FilterScalar.h +++ b/libs/std/include/rtbot/std/FilterScalar.h @@ -23,6 +23,10 @@ class FilterScalar : public Operator { // Pure virtual method that derived classes must implement virtual bool evaluate(double value) const = 0; + bool equals(const FilterScalar& other) const { + return Operator::equals(other); + } + protected: void process_data() override { auto& input_queue = get_data_queue(0); @@ -52,6 +56,18 @@ class LessThan : public FilterScalar { bool evaluate(double x) const override { return x < threshold_; } double get_threshold() const { return threshold_; } + bool equals(const LessThan& other) const { + return StateSerializer::hash_double(threshold_) == StateSerializer::hash_double(other.threshold_) && FilterScalar::equals(other); + } + + bool operator==(const LessThan& other) const { + return equals(other); + } + + bool operator!=(const LessThan& other) const { + return !(*this == other); + } + private: double threshold_; }; @@ -63,6 +79,18 @@ class GreaterThan : public FilterScalar { bool evaluate(double x) const override { return x > threshold_; } double get_threshold() const { return threshold_; } + bool equals(const GreaterThan& other) const { + return StateSerializer::hash_double(threshold_) == StateSerializer::hash_double(other.threshold_) && FilterScalar::equals(other); + } + + bool operator==(const GreaterThan& other) const { + return equals(other); + } + + bool operator!=(const GreaterThan& other) const { + return !(*this == other); + } + private: double threshold_; }; @@ -77,6 +105,20 @@ class EqualTo : public FilterScalar { double get_value() const { return value_; } double get_epsilon() const { return epsilon_; } + bool equals(const EqualTo& other) const { + return StateSerializer::hash_double(value_) == StateSerializer::hash_double(other.value_) + && StateSerializer::hash_double(epsilon_) == StateSerializer::hash_double(other.epsilon_) + && FilterScalar::equals(other); + } + + bool operator==(const EqualTo& other) const { + return equals(other); + } + + bool operator!=(const EqualTo& other) const { + return !(*this == other); + } + private: double value_; double epsilon_; // Tolerance for floating-point comparison @@ -92,6 +134,20 @@ class NotEqualTo : public FilterScalar { double get_value() const { return value_; } double get_epsilon() const { return epsilon_; } + bool equals(const NotEqualTo& other) const { + return StateSerializer::hash_double(value_) == StateSerializer::hash_double(other.value_) + && StateSerializer::hash_double(epsilon_) == StateSerializer::hash_double(other.epsilon_) + && FilterScalar::equals(other); + } + + bool operator==(const NotEqualTo& other) const { + return equals(other); + } + + bool operator!=(const NotEqualTo& other) const { + return !(*this == other); + } + private: double value_; double epsilon_; // Tolerance for floating-point comparison diff --git a/libs/std/include/rtbot/std/FiniteImpulseResponse.h b/libs/std/include/rtbot/std/FiniteImpulseResponse.h index e9e9af34..7cedddfc 100644 --- a/libs/std/include/rtbot/std/FiniteImpulseResponse.h +++ b/libs/std/include/rtbot/std/FiniteImpulseResponse.h @@ -28,6 +28,18 @@ class FiniteImpulseResponse : public Buffer { const std::vector& get_coefficients() const { return coeffs_; } + bool equals(const FiniteImpulseResponse& other) const { + return coeffs_ == other.coeffs_ && Buffer::equals(other); + } + + bool operator==(const FiniteImpulseResponse& other) const { + return equals(other); + } + + bool operator!=(const FiniteImpulseResponse& other) const { + return !(*this == other); + } + protected: std::vector>> process_message(const Message* msg) override { std::vector>> output; diff --git a/libs/std/include/rtbot/std/Function.h b/libs/std/include/rtbot/std/Function.h index 5c5e2a59..b3bfade1 100644 --- a/libs/std/include/rtbot/std/Function.h +++ b/libs/std/include/rtbot/std/Function.h @@ -35,6 +35,18 @@ class Function : public Operator { const std::vector>& get_points() const { return points_; } InterpolationType get_interpolation_type() const { return type_; } + bool equals(const Function& other) const { + return (points_ == other.points_ && type_ == other.type_ && Operator::equals(other)); + } + + bool operator==(const Function& other) const { + return equals(other); + } + + bool operator!=(const Function& other) const { + return !(*this == other); + } + protected: void process_data() override { auto& input_queue = get_data_queue(0); diff --git a/libs/std/include/rtbot/std/Identity.h b/libs/std/include/rtbot/std/Identity.h index 716b6b4c..edd0b866 100644 --- a/libs/std/include/rtbot/std/Identity.h +++ b/libs/std/include/rtbot/std/Identity.h @@ -17,6 +17,18 @@ class Identity : public Operator { std::string type_name() const override { return "Identity"; } + bool equals(const Identity& other) const { + return Operator::equals(other); + } + + bool operator==(const Identity& other) const { + return equals(other); + } + + bool operator!=(const Identity& other) const { + return !(*this == other); + } + protected: void process_data() override { auto& input_queue = get_data_queue(0); diff --git a/libs/std/include/rtbot/std/InfiniteImpulseResponse.h b/libs/std/include/rtbot/std/InfiniteImpulseResponse.h index dab272fd..5435da7e 100644 --- a/libs/std/include/rtbot/std/InfiniteImpulseResponse.h +++ b/libs/std/include/rtbot/std/InfiniteImpulseResponse.h @@ -33,6 +33,28 @@ class InfiniteImpulseResponse : public Operator { std::vector get_a_coeffs() const { return a_; } std::vector get_b_coeffs() const { return b_; } + bool equals(const InfiniteImpulseResponse& other) const { + if (!Operator::equals(other)) return false; + + // Compare coefficients + if (b_ != other.b_) return false; + if (a_ != other.a_) return false; + + // Compare input/output buffers + if (x_ != other.x_) return false; + if (y_ != other.y_) return false; + + return true; + } + + bool operator==(const InfiniteImpulseResponse& other) const { + return equals(other); + } + + bool operator!=(const InfiniteImpulseResponse& other) const { + return !(*this == other); + } + Bytes collect() override { Bytes bytes = Operator::collect(); @@ -56,26 +78,35 @@ class InfiniteImpulseResponse : public Operator { } void restore(Bytes::const_iterator& it) override { + // Restore base state Operator::restore(it); - // Restore x buffer - size_t x_size = *reinterpret_cast(&(*it)); - it += sizeof(size_t); - x_.clear(); + // ---- Restore x buffer ---- + size_t x_size = 0; + std::memcpy(&x_size, &(*it), sizeof(x_size)); + it += sizeof(x_size); + + x_.clear(); + for (size_t i = 0; i < x_size; ++i) { - double value = *reinterpret_cast(&(*it)); - it += sizeof(double); - x_.push_back(value); + double value = 0.0; + std::memcpy(&value, &(*it), sizeof(value)); + it += sizeof(value); + x_.push_back(value); } - // Restore y buffer - size_t y_size = *reinterpret_cast(&(*it)); - it += sizeof(size_t); - y_.clear(); + // ---- Restore y buffer ---- + size_t y_size = 0; + std::memcpy(&y_size, &(*it), sizeof(y_size)); + it += sizeof(y_size); + + y_.clear(); + for (size_t i = 0; i < y_size; ++i) { - double value = *reinterpret_cast(&(*it)); - it += sizeof(double); - y_.push_back(value); + double value = 0.0; + std::memcpy(&value, &(*it), sizeof(value)); + it += sizeof(value); + y_.push_back(value); } } diff --git a/libs/std/include/rtbot/std/Linear.h b/libs/std/include/rtbot/std/Linear.h index fb55b551..cd4635d8 100644 --- a/libs/std/include/rtbot/std/Linear.h +++ b/libs/std/include/rtbot/std/Linear.h @@ -22,21 +22,49 @@ class Linear : public Join { std::string type_name() const override { return "Linear"; } const std::vector& get_coefficients() const { return coeffs_; } + bool equals(const Linear& other) const { + return (coeffs_ == other.coeffs_ && Join::equals(other)); + } + + bool operator==(const Linear& other) const { + return equals(other); + } + + bool operator!=(const Linear& other) const { + return !(*this == other); + } + protected: void process_data() override { - sync(); - const auto& synced_data = get_synchronized_data(); + + while(true) { - for (const auto& [time, messages] : synced_data) { - std::vector*> typed_messages; - typed_messages.reserve(messages.size()); + bool is_any_empty; + bool is_sync; + do { + is_any_empty = false; + is_sync = sync_data_inputs(); + for (int i=0; i < num_data_ports(); i++) { + if (get_data_queue(i).empty()) { + is_any_empty = true; + break; + } + } + } while (!is_sync && !is_any_empty ); + + if (!is_sync) return; - for (const auto& msg : messages) { - const auto* typed_msg = dynamic_cast*>(msg.get()); + std::vector*> typed_messages; + timestamp_t time = 0; + // Process each synchronized set of messages + for (int i=0; i < num_data_ports(); i++) { + typed_messages.reserve(num_data_ports()); + const auto* typed_msg = dynamic_cast*>(get_data_queue(i).front().get()); if (!typed_msg) { throw std::runtime_error("Invalid message type in Linear"); } typed_messages.push_back(typed_msg); + time = typed_msg->time; } double result = 0.0; @@ -44,11 +72,11 @@ class Linear : public Join { result += coeffs_[i] * typed_messages[i]->data.value; } + for (int i = 0; i < num_data_ports(); i++) + get_data_queue(i).pop_front(); + get_output_queue(0).push_back(create_message(time, NumberData{result})); } - - // Clear synchronized data after processing - clear_synchronized_data(); } private: diff --git a/libs/std/include/rtbot/std/MovingAverage.h b/libs/std/include/rtbot/std/MovingAverage.h index 201a92d5..466fc850 100644 --- a/libs/std/include/rtbot/std/MovingAverage.h +++ b/libs/std/include/rtbot/std/MovingAverage.h @@ -23,6 +23,18 @@ class MovingAverage : public Buffer { std::string type_name() const override { return "MovingAverage"; } + bool equals(const MovingAverage& other) const { + return (StateSerializer::hash_double(mean()) == StateSerializer::hash_double(other.mean()) && Buffer::equals(other)); + } + + bool operator==(const MovingAverage& other) const { + return equals(other); + } + + bool operator!=(const MovingAverage& other) const { + return !(*this == other); + } + protected: std::vector>> process_message(const Message *msg) override { // Only emit messages when the buffer is full to ensure diff --git a/libs/std/include/rtbot/std/MovingSum.h b/libs/std/include/rtbot/std/MovingSum.h index 271616db..727774ab 100644 --- a/libs/std/include/rtbot/std/MovingSum.h +++ b/libs/std/include/rtbot/std/MovingSum.h @@ -22,6 +22,18 @@ class MovingSum : public Buffer { std::string type_name() const override { return "MovingSum"; } + bool equals(const MovingSum& other) const { + return (StateSerializer::hash_double(sum()) == StateSerializer::hash_double(other.sum()) && Buffer::equals(other)); + } + + bool operator==(const MovingSum& other) const { + return equals(other); + } + + bool operator!=(const MovingSum& other) const { + return !(*this == other); + } + protected: std::vector>> process_message(const Message *msg) override { // Only emit messages when the buffer is full to ensure diff --git a/libs/std/include/rtbot/std/PeakDetector.h b/libs/std/include/rtbot/std/PeakDetector.h index 9ae1da81..5f6d828d 100644 --- a/libs/std/include/rtbot/std/PeakDetector.h +++ b/libs/std/include/rtbot/std/PeakDetector.h @@ -30,6 +30,18 @@ class PeakDetector : public Buffer { std::string type_name() const override { return "PeakDetector"; } + bool equals(const PeakDetector& other) const { + return Buffer::equals(other); + } + + bool operator==(const PeakDetector& other) const { + return equals(other); + } + + bool operator!=(const PeakDetector& other) const { + return !(*this == other); + } + protected: std::vector>> process_message(const Message* msg) override { std::vector>> output; diff --git a/libs/std/include/rtbot/std/Replace.h b/libs/std/include/rtbot/std/Replace.h index 4204d430..3f82697a 100644 --- a/libs/std/include/rtbot/std/Replace.h +++ b/libs/std/include/rtbot/std/Replace.h @@ -23,6 +23,10 @@ class Replace : public Operator { // Pure virtual method that derived classes must implement virtual double replace(double value) const = 0; + bool equals(const Replace& other) const { + return Operator::equals(other); + } + protected: void process_data() override { auto& input_queue = get_data_queue(0); @@ -56,6 +60,24 @@ class LessThanOrEqualToReplace : public Replace { double get_replace_by() const { return replaceBy_; } double replace(double x) const override { return (x <= threshold_) ? replaceBy_ : x; } + bool equals(const LessThanOrEqualToReplace& other) const { + + if (StateSerializer::hash_double(threshold_) != StateSerializer::hash_double(other.threshold_)) return false; + if (StateSerializer::hash_double(replaceBy_) != StateSerializer::hash_double(other.replaceBy_)) return false; + + if (!Replace::equals(other)) return false; + + return true; + } + + bool operator==(const LessThanOrEqualToReplace& other) const { + return equals(other); + } + + bool operator!=(const LessThanOrEqualToReplace& other) const { + return !(*this == other); + } + private: double threshold_; double replaceBy_; diff --git a/libs/std/include/rtbot/std/ResamplerConstant.h b/libs/std/include/rtbot/std/ResamplerConstant.h index 4812ce48..917fe054 100644 --- a/libs/std/include/rtbot/std/ResamplerConstant.h +++ b/libs/std/include/rtbot/std/ResamplerConstant.h @@ -18,6 +18,7 @@ class ResamplerConstant : public Operator { add_data_port(); add_output_port(); + last_value_ = T{}; } void reset() override { @@ -29,6 +30,27 @@ class ResamplerConstant : public Operator { std::string type_name() const override { return "ResamplerConstant"; } + bool equals(const ResamplerConstant& other) const { + + if (dt_ != other.dt_) return false; + if (t0_ != other.t0_) return false; + if (initialized_ != other.initialized_) return false; + if (next_emit_ != other.next_emit_) return false; + if (last_value_ != other.last_value_) return false; + + if (!Operator::equals(other)) return false; + + return true; + } + + bool operator==(const ResamplerConstant& other) const { + return equals(other); + } + + bool operator!=(const ResamplerConstant& other) const { + return !(*this == other); + } + Bytes collect() override { // First collect base state Bytes bytes = Operator::collect(); @@ -41,20 +63,28 @@ class ResamplerConstant : public Operator { bytes.insert(bytes.end(), reinterpret_cast(&initialized_), reinterpret_cast(&initialized_) + sizeof(initialized_)); + // Serialize last value + bytes.insert(bytes.end(), reinterpret_cast(&last_value_), + reinterpret_cast(&last_value_) + sizeof(last_value_)); + return bytes; } void restore(Bytes::const_iterator& it) override { - // First restore base state + // ---- Restore base state ---- Operator::restore(it); - // Restore next emission time - next_emit_ = *reinterpret_cast(&(*it)); - it += sizeof(timestamp_t); + // ---- Restore next_emit_ safely ---- + std::memcpy(&next_emit_, &(*it), sizeof(next_emit_)); + it += sizeof(next_emit_); + + // ---- Restore initialized_ safely ---- + std::memcpy(&initialized_, &(*it), sizeof(initialized_)); + it += sizeof(initialized_); - // Restore initialization state - initialized_ = *reinterpret_cast(&(*it)); - it += sizeof(bool); + // ---- Restore last_value_ safely ---- + std::memcpy(&last_value_, &(*it), sizeof(last_value_)); + it += sizeof(last_value_); } timestamp_t get_interval() const { return dt_; } diff --git a/libs/std/include/rtbot/std/ResamplerHermite.h b/libs/std/include/rtbot/std/ResamplerHermite.h index d1fe2a8c..4ccd06b6 100644 --- a/libs/std/include/rtbot/std/ResamplerHermite.h +++ b/libs/std/include/rtbot/std/ResamplerHermite.h @@ -33,6 +33,27 @@ class ResamplerHermite : public Buffer { } std::string type_name() const override { return "ResamplerHermite"; } + + bool equals(const ResamplerHermite& other) const { + + if (dt_ != other.dt_) return false; + if (t0_ != other.t0_) return false; + if (initialized_ != other.initialized_) return false; + if (next_emit_ != other.next_emit_) return false; + + if (!Buffer::equals(other)) return false; + + return true; + } + + bool operator==(const ResamplerHermite& other) const { + return equals(other); + } + + bool operator!=(const ResamplerHermite& other) const { + return !(*this == other); + } + Bytes collect() override { // First collect base state Bytes bytes = Buffer::collect(); @@ -49,16 +70,16 @@ class ResamplerHermite : public Buffer { } void restore(Bytes::const_iterator& it) override { - // First restore base state + // ---- Restore base state ---- Buffer::restore(it); - // Restore next emission time - next_emit_ = *reinterpret_cast(&(*it)); - it += sizeof(timestamp_t); + // ---- Restore next_emit_ safely ---- + std::memcpy(&next_emit_, &(*it), sizeof(next_emit_)); + it += sizeof(next_emit_); - // Restore initialization state - initialized_ = *reinterpret_cast(&(*it)); - it += sizeof(bool); + // ---- Restore initialized_ safely ---- + std::memcpy(&initialized_, &(*it), sizeof(initialized_)); + it += sizeof(initialized_); } timestamp_t get_interval() const { return dt_; } diff --git a/libs/std/include/rtbot/std/StandardDeviation.h b/libs/std/include/rtbot/std/StandardDeviation.h index f9fffafe..b4dee960 100644 --- a/libs/std/include/rtbot/std/StandardDeviation.h +++ b/libs/std/include/rtbot/std/StandardDeviation.h @@ -20,6 +20,18 @@ class StandardDeviation : public Buffer { std::string type_name() const override { return "StandardDeviation"; } + bool equals(const StandardDeviation& other) const { + return (StateSerializer::hash_double(standard_deviation()) == StateSerializer::hash_double(other.standard_deviation()) && Buffer::equals(other)); + } + + bool operator==(const StandardDeviation& other) const { + return equals(other); + } + + bool operator!=(const StandardDeviation& other) const { + return !(*this == other); + } + protected: std::vector>> process_message(const Message *msg) override { // Only emit messages when the buffer is full to ensure diff --git a/libs/std/include/rtbot/std/TimeShift.h b/libs/std/include/rtbot/std/TimeShift.h index 0daf2c7f..b36334fa 100644 --- a/libs/std/include/rtbot/std/TimeShift.h +++ b/libs/std/include/rtbot/std/TimeShift.h @@ -16,6 +16,18 @@ class TimeShift : public Operator { std::string type_name() const override { return "TimeShift"; } timestamp_t get_shift() const { return shift_; } + bool equals(const TimeShift& other) const { + return (shift_ == other.shift_ && Operator::equals(other)); + } + + bool operator==(const TimeShift& other) const { + return equals(other); + } + + bool operator!=(const TimeShift& other) const { + return !(*this == other); + } + protected: void process_data() override { auto& input_queue = get_data_queue(0); diff --git a/libs/std/include/rtbot/std/Variable.h b/libs/std/include/rtbot/std/Variable.h index 93482bcc..75ae1e9c 100644 --- a/libs/std/include/rtbot/std/Variable.h +++ b/libs/std/include/rtbot/std/Variable.h @@ -13,121 +13,54 @@ namespace rtbot { class Variable : public Operator { - private: - struct TimeValue { - timestamp_t time; - double value; - }; - + public: Variable(std::string id, double default_value = 0.0) : Operator(std::move(id)), default_value_(default_value) { add_data_port(); // For value updates add_control_port(); // For queries - add_output_port(); // For responses - - values_.push_back({0, default_value_}); - } - - void reset() override { - Operator::reset(); - values_.clear(); - values_.push_back({0, default_value_}); - pending_queries_.clear(); - } + add_output_port(); // For responses + } std::string type_name() const override { return "Variable"; } double get_default_value() const { return default_value_; } - Bytes collect() override { - Bytes bytes = Operator::collect(); - - // Serialize default value - bytes.insert(bytes.end(), reinterpret_cast(&default_value_), - reinterpret_cast(&default_value_) + sizeof(default_value_)); - - // Serialize time-value pairs - size_t values_count = values_.size(); - bytes.insert(bytes.end(), reinterpret_cast(&values_count), - reinterpret_cast(&values_count) + sizeof(values_count)); - - for (const auto& tv : values_) { - bytes.insert(bytes.end(), reinterpret_cast(&tv.time), - reinterpret_cast(&tv.time) + sizeof(tv.time)); - bytes.insert(bytes.end(), reinterpret_cast(&tv.value), - reinterpret_cast(&tv.value) + sizeof(tv.value)); - } - - // Serialize pending query timestamps - size_t queries_count = pending_queries_.size(); - bytes.insert(bytes.end(), reinterpret_cast(&queries_count), - reinterpret_cast(&queries_count) + sizeof(queries_count)); - - for (const auto& time : pending_queries_) { - bytes.insert(bytes.end(), reinterpret_cast(&time), - reinterpret_cast(&time) + sizeof(time)); - } - - return bytes; + bool equals(const Variable& other) const { + return (StateSerializer::hash_double(default_value_) == StateSerializer::hash_double(other.default_value_) && Operator::equals(other)); } - void restore(Bytes::const_iterator& it) override { - Operator::restore(it); - - // Restore default value - default_value_ = *reinterpret_cast(&(*it)); - it += sizeof(double); - - // Restore time-value pairs - size_t values_count = *reinterpret_cast(&(*it)); - it += sizeof(size_t); - - values_.clear(); - values_.reserve(values_count); - for (size_t i = 0; i < values_count; ++i) { - timestamp_t time = *reinterpret_cast(&(*it)); - it += sizeof(timestamp_t); - double value = *reinterpret_cast(&(*it)); - it += sizeof(double); - values_.push_back({time, value}); - } - - // Restore pending query timestamps - size_t queries_count = *reinterpret_cast(&(*it)); - it += sizeof(size_t); + bool operator==(const Variable& other) const { + return equals(other); + } - pending_queries_.clear(); - pending_queries_.reserve(queries_count); - for (size_t i = 0; i < queries_count; ++i) { - timestamp_t time = *reinterpret_cast(&(*it)); - it += sizeof(timestamp_t); - pending_queries_.push_back(time); - } + bool operator!=(const Variable& other) const { + return !(*this == other); } protected: - void process_data() override { - auto& data_queue = get_data_queue(0); - bool values_updated = false; - - while (!data_queue.empty()) { - const auto* msg = dynamic_cast*>(data_queue.front().get()); - if (!msg) { - throw std::runtime_error("Invalid message type in Variable"); - } - - values_.push_back({msg->time, msg->data.value}); - values_updated = true; - data_queue.pop_front(); - } + void process_data() override { - if (values_updated) { - process_pending_queries(); - } + if (!get_data_queue(0).empty() && !get_control_queue(0).empty()) + process_pending_queries(); + } void process_control() override { + + if (!get_data_queue(0).empty() && !get_control_queue(0).empty()) + process_pending_queries(); + } + + private: + double default_value_; + + void process_pending_queries() { + + auto& data_queue = get_data_queue(0); auto& control_queue = get_control_queue(0); + auto& output_queue = get_output_queue(0); + + if (data_queue.empty() || control_queue.empty()) return; while (!control_queue.empty()) { const auto* query = dynamic_cast*>(control_queue.front().get()); @@ -135,65 +68,80 @@ class Variable : public Operator { throw std::runtime_error("Invalid control message type in Variable"); } - pending_queries_.push_back(query->time); - control_queue.pop_front(); - } - - process_pending_queries(); - } + const timestamp_t query_time = query->time; + + + if (data_queue.empty()) break; + + bool found = false; + double result_value = 0.0; + timestamp_t match_on = 0; + timestamp_t result_time = 0; + + if (data_queue.size() == 1) { + const auto* msg = dynamic_cast*>(data_queue.front().get()); + if (msg->time == query_time) { + result_value = msg->data.value; + result_time = query_time; + match_on = query_time; + found = true; + } + } else { + auto prev = data_queue.begin(); + auto next = std::next(prev); + while (next != data_queue.end()) { + const auto* prev_msg = dynamic_cast*>((*prev).get()); + const auto* next_msg = dynamic_cast*>((*next).get()); + if (!prev_msg || !next_msg) + throw std::runtime_error("Invalid data message type in Variable"); + + if (query_time == next_msg->time) { + result_value = next_msg->data.value; + result_time = query_time; + match_on = query_time; + found = true; + break; + } else if (query_time >=prev_msg->time && query_time < next_msg->time) { + result_value = prev_msg->data.value; + result_time = query_time; + match_on = prev_msg->time; + found = true; + break; + } - private: - double default_value_; - std::vector values_; - std::vector pending_queries_; + ++prev; + ++next; + } + } - std::optional query_value(timestamp_t time) { - if (values_.size() == 1 && time == values_[0].time) { - return values_[0].value; - } + // Stop if query time is before the first or after the last + const auto* first_msg = dynamic_cast*>(data_queue.front().get()); + const auto* last_msg = dynamic_cast*>(data_queue.back().get()); - if (values_.size() > 1) { - // Find the last value that occurred at or before the query time - for (size_t i = 1; i < values_.size(); i++) { - if (time == values_[i].time) { - // Found exact match - clean up all ranges before this one - double value = values_[i].value; - for (size_t j = 0; j < i; j++) { - values_.erase(values_.begin()); + if (!found) { + if (query_time < first_msg->time) { + output_queue.push_back(create_message(query_time, NumberData{default_value_})); + control_queue.pop_front(); + } + else if (query_time > last_msg->time) { + while (data_queue.size() > 1) { + data_queue.pop_front(); } - return value; - } else if (values_[i - 1].time <= time && time < values_[i].time) { - // Found containing range - clean up all ranges before the starting range - double value = values_[i - 1].value; - for (size_t j = 0; j < i - 1; j++) { - values_.erase(values_.begin()); + break; + } + } else { + output_queue.push_back(create_message(result_time, NumberData{result_value})); + control_queue.pop_front(); + while (!data_queue.empty()) { + const auto* msg = dynamic_cast*>(data_queue.front().get()); + if (msg->time < match_on) { + data_queue.pop_front(); + } else { + break; } - return value; } } - } - - return std::nullopt; - } - - void process_pending_queries() { - auto& output_queue = get_output_queue(0); - size_t processed = 0; - - for (size_t i = 0; i < pending_queries_.size(); i++) { - timestamp_t query_time = pending_queries_[i]; - auto value = query_value(query_time); - if (!value) { - break; // Stop at first uncertain value - } - - output_queue.push_back(create_message(query_time, NumberData{*value})); - processed++; - } - // Remove processed queries - if (processed > 0) { - pending_queries_.erase(pending_queries_.begin(), pending_queries_.begin() + processed); } } }; diff --git a/libs/std/test/BUILD.bazel b/libs/std/test/BUILD.bazel index 3b089f8d..5588166b 100644 --- a/libs/std/test/BUILD.bazel +++ b/libs/std/test/BUILD.bazel @@ -15,6 +15,15 @@ cc_test( "//examples/data:ppg.csv", ], #defines = ["CATCH_CONFIG_MAIN"], + copts = [ + "-DCATCH_CONFIG_POSIX_SIGNALS", + "-g", + "-fsanitize=address", + "-Wno-macro-redefined" + ], + linkopts = [ + "-fsanitize=address", + ], deps = [ "//libs/std:rtbot-std", "@catch2", diff --git a/libs/std/test/integration_test_ppg.cpp b/libs/std/test/integration_test_ppg.cpp index 88c018c8..5a69046a 100644 --- a/libs/std/test/integration_test_ppg.cpp +++ b/libs/std/test/integration_test_ppg.cpp @@ -59,11 +59,11 @@ SCENARIO("PPG Pipeline propagates messages correctly", "[PPG][Integration]") { WHEN("Processing a single data point") { pipeline.input->receive_data(create_message(s.ti[0], NumberData{s.ppg[0]}), 0); - pipeline.input->execute(); + pipeline.input->execute(true); THEN("Messages propagate through moving averages") { - REQUIRE(pipeline.ma_short->get_output_queue(0).size() <= 1); - REQUIRE(pipeline.ma_long->get_output_queue(0).size() <= 1); + REQUIRE(pipeline.ma_short->get_debug_output_queue(0).size() <= 1); + REQUIRE(pipeline.ma_long->get_debug_output_queue(0).size() <= 1); } AND_THEN("Moving average buffers start filling") { @@ -75,12 +75,12 @@ SCENARIO("PPG Pipeline propagates messages correctly", "[PPG][Integration]") { WHEN("Processing enough points to fill short window") { for (int i = 1; i < short_window + 1; i++) { pipeline.input->receive_data(create_message(s.ti[i], NumberData{s.ppg[i]}), 0); - pipeline.input->execute(); + pipeline.input->execute(true); } THEN("Short moving average starts producing output") { REQUIRE(pipeline.ma_short->buffer_full()); - REQUIRE(pipeline.ma_short->get_output_queue(0).size() == 1); + REQUIRE(pipeline.ma_short->get_debug_output_queue(0).size() == 1); } AND_THEN("Long moving average still filling") { @@ -112,9 +112,9 @@ SCENARIO("PPG Pipeline propagates messages correctly", "[PPG][Integration]") { for (int i = 0; i < test_window; i++) { pipeline.input->receive_data(create_message(s.ti[i], NumberData{s.ppg[i]}), 0); - pipeline.input->execute(); + pipeline.input->execute(true); - const auto& peak_output = pipeline.peak->get_output_queue(0); + const auto& peak_output = pipeline.peak->get_debug_output_queue(0); for (const auto& msg : peak_output) { peak_times.push_back(msg->time); } diff --git a/libs/std/test/test_arithmetic_scalar.cpp b/libs/std/test/test_arithmetic_scalar.cpp index bc71be46..1e23d3d9 100644 --- a/libs/std/test/test_arithmetic_scalar.cpp +++ b/libs/std/test/test_arithmetic_scalar.cpp @@ -29,19 +29,38 @@ SCENARIO("ArithmeticScalar derived classes handle basic operations", "[math_scal std::vector> expected = {{2, 1.0}, {4, 2.0}, {5, 12.0}}; + add->clear_all_output_ports(); + for (const auto& input : inputs) { add->receive_data(create_message(input.first, NumberData{input.second}), 0); add->execute(); } - /*output = add->get_output_queue(0); - REQUIRE(output.size() == inputs.size()); + auto& output_queue = add->get_output_queue(0); + REQUIRE(output_queue.size() == inputs.size()); - for (size_t i = 0; i < output.size(); ++i) { - auto* msg = dynamic_cast*>(output[i].get()); + for (size_t i = 0; i < output_queue.size(); ++i) { + auto* msg = dynamic_cast*>(output_queue[i].get()); REQUIRE(msg->time == expected[i].first); REQUIRE(msg->data.value == expected[i].second); - }*/ + } + + WHEN("State is serialized and restored") { + // Serialize state + Bytes state = add->collect(); + + REQUIRE(add->get_output_queue(0).size() == 3); + + auto restored = make_add("add1", 2.0); + + // Restore state + auto it = state.cbegin(); + restored->restore(it); + + THEN("The operators match") { + REQUIRE(*restored == *add); + } + } } SECTION("Scale operator") { @@ -67,6 +86,23 @@ SCENARIO("ArithmeticScalar derived classes handle basic operations", "[math_scal REQUIRE(msg->time == expected[i].first); REQUIRE(msg->data.value == expected[i].second); } + + WHEN("State is serialized and restored") { + // Serialize state + Bytes state = scale->collect(); + + REQUIRE(scale->get_output_queue(0).size() == 3); + + auto restored = make_scale("scale1", 2.0); + + // Restore state + auto it = state.cbegin(); + restored->restore(it); + + THEN("The operators match") { + REQUIRE(*restored == *scale); + } + } } SECTION("Power operator") { @@ -92,6 +128,23 @@ SCENARIO("ArithmeticScalar derived classes handle basic operations", "[math_scal REQUIRE(msg->time == expected[i].first); REQUIRE(msg->data.value == Approx(expected[i].second)); } + + WHEN("State is serialized and restored") { + // Serialize state + Bytes state = power->collect(); + + REQUIRE(power->get_output_queue(0).size() == 3); + + auto restored = make_power("pow1", 2.0); + + // Restore state + auto it = state.cbegin(); + restored->restore(it); + + THEN("The operators match") { + REQUIRE(*restored == *power); + } + } } } @@ -116,8 +169,23 @@ SCENARIO("ArithmeticScalar handles trigonometric functions", "[math_scalar_op]") for (size_t i = 0; i < output.size(); ++i) { auto* msg = dynamic_cast*>(output[i].get()); REQUIRE(msg->time == expected[i].first); - // TODO: Fails with 0.0 == Approx( 0.0 ), which is correct - // REQUIRE(msg->data.value == Approx(expected[i].second).epsilon(1e-3)); + } + + WHEN("State is serialized and restored") { + // Serialize state + Bytes state = sin->collect(); + + REQUIRE(sin->get_output_queue(0).size() == 3); + + auto restored = make_sin("sin1"); + + // Restore state + auto it = state.cbegin(); + restored->restore(it); + + THEN("The operators match") { + REQUIRE(*restored == *sin); + } } } } @@ -145,6 +213,23 @@ SCENARIO("ArithmeticScalar handles exponential and logarithmic functions", "[mat REQUIRE(msg->time == expected[i].first); REQUIRE(msg->data.value == Approx(expected[i].second)); } + + WHEN("State is serialized and restored") { + // Serialize state + Bytes state = exp->collect(); + + REQUIRE(exp->get_output_queue(0).size() == 3); + + auto restored = make_exp("exp1"); + + // Restore state + auto it = state.cbegin(); + restored->restore(it); + + THEN("The operators match") { + REQUIRE(*restored == *exp); + } + } } // Similar tests for Log and Log10... @@ -173,12 +258,30 @@ SCENARIO("ArithmeticScalar handles rounding functions", "[math_scalar_op]") { REQUIRE(msg->time == expected[i].first); REQUIRE(msg->data.value == expected[i].second); } + + WHEN("State is serialized and restored") { + // Serialize state + Bytes state = round->collect(); + + REQUIRE(round->get_output_queue(0).size() == 4); + + auto restored = make_round("round1"); + + // Restore state + auto it = state.cbegin(); + restored->restore(it); + + THEN("The operators match") { + REQUIRE(*restored == *round); + } + } + } // Similar tests for Floor and Ceil... } -SCENARIO("ArithmeticScalar handles serialization", "[math_scalar_op]") { +SCENARIO("ArithmeticScalar handles serialization", "[math_scalar_op][State]") { SECTION("Add operator serialization") { auto add = make_add("add1", 2.0); @@ -200,6 +303,7 @@ SCENARIO("ArithmeticScalar handles serialization", "[math_scalar_op]") { // Verify restored state REQUIRE(restored->type_name() == add->type_name()); REQUIRE(dynamic_cast(restored.get())->get_value() == dynamic_cast(add.get())->get_value()); + REQUIRE(*restored == *add); // Process new data and verify behavior restored->clear_all_output_ports(); diff --git a/libs/std/test/test_arithmetic_sync.cpp b/libs/std/test/test_arithmetic_sync.cpp index 92e6e131..06763b2d 100644 --- a/libs/std/test/test_arithmetic_sync.cpp +++ b/libs/std/test/test_arithmetic_sync.cpp @@ -45,6 +45,91 @@ SCENARIO("ArithmeticSync operators handle basic synchronization", "[math_sync_bi check_output(div, 2.0); // 10 / 5 } } + + WHEN("Additon State is serialized and restored") { + // Serialize state + add->receive_data(create_message(1, NumberData{10.0}), 0); + add->receive_data(create_message(1, NumberData{5.0}), 1); + add->execute(); + Bytes state = add->collect(); + + REQUIRE(add->get_output_queue(0).size() == 1); + + auto restored = make_addition("add1"); + + // Restore state + auto it = state.cbegin(); + restored->restore(it); + + THEN("The operators match") { + REQUIRE(restored->get_output_queue(0).size() == 1); + REQUIRE(*restored == *add); + } + } + + WHEN("Subtraction State is serialized and restored") { + // Serialize state + sub->receive_data(create_message(1, NumberData{10.0}), 0); + sub->receive_data(create_message(1, NumberData{5.0}), 1); + sub->execute(); + Bytes state = sub->collect(); + + REQUIRE(sub->get_output_queue(0).size() == 1); + + auto restored = make_subtraction("sub1"); + + // Restore state + auto it = state.cbegin(); + restored->restore(it); + + THEN("The operators match") { + REQUIRE(restored->get_output_queue(0).size() == 1); + REQUIRE(*restored == *sub); + } + } + + WHEN("Multiplication State is serialized and restored") { + // Serialize state + mul->receive_data(create_message(1, NumberData{10.0}), 0); + mul->receive_data(create_message(1, NumberData{5.0}), 1); + mul->execute(); + Bytes state = mul->collect(); + + REQUIRE(mul->get_output_queue(0).size() == 1); + + auto restored = make_multiplication("mul1"); + + // Restore state + auto it = state.cbegin(); + restored->restore(it); + + THEN("The operators match") { + REQUIRE(restored->get_output_queue(0).size() == 1); + REQUIRE(*restored == *mul); + } + } + + WHEN("Division State is serialized and restored") { + // Serialize state + div->receive_data(create_message(1, NumberData{10.0}), 0); + div->receive_data(create_message(1, NumberData{5.0}), 1); + div->execute(); + Bytes state = div->collect(); + + REQUIRE(div->get_output_queue(0).size() == 1); + + auto restored = make_division("div1"); + + // Restore state + auto it = state.cbegin(); + restored->restore(it); + + THEN("The operators match") { + REQUIRE(restored->get_output_queue(0).size() == 1); + REQUIRE(*restored == *div); + } + } + } } @@ -77,7 +162,7 @@ SCENARIO("Division operator handles division by zero", "[math_sync_binary_op]") } }*/ -SCENARIO("ArithmeticSync operators handle state serialization", "[math_sync_binary_op]") { +SCENARIO("ArithmeticSync operators handle state serialization", "[math_sync_binary_op][State]") { GIVEN("An operator with buffered messages") { auto add = make_addition("add1"); @@ -103,6 +188,7 @@ SCENARIO("ArithmeticSync operators handle state serialization", "[math_sync_bina restored->execute(); THEN("Both operators produce identical results") { + REQUIRE(*add == *restored); const auto& orig_output = add->get_output_queue(0); const auto& rest_output = restored->get_output_queue(0); diff --git a/libs/std/test/test_boolean_sync.cpp b/libs/std/test/test_boolean_sync.cpp index 0f41ee00..7fa5af7d 100644 --- a/libs/std/test/test_boolean_sync.cpp +++ b/libs/std/test/test_boolean_sync.cpp @@ -96,9 +96,15 @@ SCENARIO("BooleanSync operators handle unsynchronized messages", "[boolean_sync_ } } -SCENARIO("BooleanSync operators handle state serialization", "[boolean_sync_binary_op]") { +SCENARIO("BooleanSync operators handle state serialization", "[boolean_sync_binary_op][State]") { GIVEN("An operator with buffered messages") { auto and_op = make_logical_and("and1"); + auto or_op = make_logical_or("or1"); + auto xor_op = make_logical_xor("xor1"); + auto nand_op = make_logical_nand("nand1"); + auto nor_op = make_logical_nor("nor1"); + auto xnor_op = make_logical_xnor("xnor1"); + auto impl_op = make_logical_implication("impl1"); // Add some messages to buffer and_op->receive_data(create_message(1, BooleanData{true}), 0); @@ -106,7 +112,37 @@ SCENARIO("BooleanSync operators handle state serialization", "[boolean_sync_bina and_op->receive_data(create_message(1, BooleanData{true}), 1); and_op->receive_data(create_message(2, BooleanData{true}), 1); - WHEN("State is serialized and restored") { + or_op->receive_data(create_message(1, BooleanData{true}), 0); + or_op->receive_data(create_message(2, BooleanData{false}), 0); + or_op->receive_data(create_message(1, BooleanData{true}), 1); + or_op->receive_data(create_message(2, BooleanData{true}), 1); + + xor_op->receive_data(create_message(1, BooleanData{true}), 0); + xor_op->receive_data(create_message(2, BooleanData{false}), 0); + xor_op->receive_data(create_message(1, BooleanData{true}), 1); + xor_op->receive_data(create_message(2, BooleanData{true}), 1); + + nand_op->receive_data(create_message(1, BooleanData{true}), 0); + nand_op->receive_data(create_message(2, BooleanData{false}), 0); + nand_op->receive_data(create_message(1, BooleanData{true}), 1); + nand_op->receive_data(create_message(2, BooleanData{true}), 1); + + nor_op->receive_data(create_message(1, BooleanData{true}), 0); + nor_op->receive_data(create_message(2, BooleanData{false}), 0); + nor_op->receive_data(create_message(1, BooleanData{true}), 1); + nor_op->receive_data(create_message(2, BooleanData{true}), 1); + + xnor_op->receive_data(create_message(1, BooleanData{true}), 0); + xnor_op->receive_data(create_message(2, BooleanData{false}), 0); + xnor_op->receive_data(create_message(1, BooleanData{true}), 1); + xnor_op->receive_data(create_message(2, BooleanData{true}), 1); + + impl_op->receive_data(create_message(1, BooleanData{true}), 0); + impl_op->receive_data(create_message(2, BooleanData{false}), 0); + impl_op->receive_data(create_message(1, BooleanData{true}), 1); + impl_op->receive_data(create_message(2, BooleanData{true}), 1); + + WHEN("and is serialized and restored") { // Serialize state Bytes state = and_op->collect(); @@ -122,6 +158,7 @@ SCENARIO("BooleanSync operators handle state serialization", "[boolean_sync_bina restored->execute(); THEN("Both operators produce identical results") { + REQUIRE(*restored == *and_op); const auto& orig_output = and_op->get_output_queue(0); const auto& rest_output = restored->get_output_queue(0); @@ -136,6 +173,198 @@ SCENARIO("BooleanSync operators handle state serialization", "[boolean_sync_bina } } } + + WHEN("or is serialized and restored") { + // Serialize state + Bytes state = or_op->collect(); + + // Create new operator + auto restored = make_logical_or("or1"); + + // Restore state + Bytes::const_iterator it = state.begin(); + restored->restore(it); + + // Execute both operators + or_op->execute(); + restored->execute(); + + THEN("Both operators produce identical results") { + REQUIRE(*restored == *or_op); + const auto& orig_output = or_op->get_output_queue(0); + const auto& rest_output = restored->get_output_queue(0); + + REQUIRE(orig_output.size() == rest_output.size()); + + for (size_t i = 0; i < orig_output.size(); i++) { + const auto* orig_msg = dynamic_cast*>(orig_output[i].get()); + const auto* rest_msg = dynamic_cast*>(rest_output[i].get()); + + REQUIRE(orig_msg->time == rest_msg->time); + REQUIRE(orig_msg->data.value == rest_msg->data.value); + } + } + } + + WHEN("xor is serialized and restored") { + // Serialize state + Bytes state = xor_op->collect(); + + // Create new operator + auto restored = make_logical_xor("xor1"); + + // Restore state + Bytes::const_iterator it = state.begin(); + restored->restore(it); + + // Execute both operators + xor_op->execute(); + restored->execute(); + + THEN("Both operators produce identical results") { + REQUIRE(*restored == *xor_op); + const auto& orig_output = xor_op->get_output_queue(0); + const auto& rest_output = restored->get_output_queue(0); + + REQUIRE(orig_output.size() == rest_output.size()); + + for (size_t i = 0; i < orig_output.size(); i++) { + const auto* orig_msg = dynamic_cast*>(orig_output[i].get()); + const auto* rest_msg = dynamic_cast*>(rest_output[i].get()); + + REQUIRE(orig_msg->time == rest_msg->time); + REQUIRE(orig_msg->data.value == rest_msg->data.value); + } + } + } + + WHEN("nand is serialized and restored") { + // Serialize state + Bytes state = nand_op->collect(); + + // Create new operator + auto restored = make_logical_nand("nand1"); + + // Restore state + Bytes::const_iterator it = state.begin(); + restored->restore(it); + + // Execute both operators + nand_op->execute(); + restored->execute(); + + THEN("Both operators produce identical results") { + REQUIRE(*restored == *nand_op); + const auto& orig_output = nand_op->get_output_queue(0); + const auto& rest_output = restored->get_output_queue(0); + + REQUIRE(orig_output.size() == rest_output.size()); + + for (size_t i = 0; i < orig_output.size(); i++) { + const auto* orig_msg = dynamic_cast*>(orig_output[i].get()); + const auto* rest_msg = dynamic_cast*>(rest_output[i].get()); + + REQUIRE(orig_msg->time == rest_msg->time); + REQUIRE(orig_msg->data.value == rest_msg->data.value); + } + } + } + + WHEN("nor is serialized and restored") { + // Serialize state + Bytes state = nor_op->collect(); + + // Create new operator + auto restored = make_logical_nor("nor1"); + + // Restore state + Bytes::const_iterator it = state.begin(); + restored->restore(it); + + // Execute both operators + nor_op->execute(); + restored->execute(); + + THEN("Both operators produce identical results") { + REQUIRE(*restored == *nor_op); + const auto& orig_output = nor_op->get_output_queue(0); + const auto& rest_output = restored->get_output_queue(0); + + REQUIRE(orig_output.size() == rest_output.size()); + + for (size_t i = 0; i < orig_output.size(); i++) { + const auto* orig_msg = dynamic_cast*>(orig_output[i].get()); + const auto* rest_msg = dynamic_cast*>(rest_output[i].get()); + + REQUIRE(orig_msg->time == rest_msg->time); + REQUIRE(orig_msg->data.value == rest_msg->data.value); + } + } + } + + WHEN("xnor is serialized and restored") { + // Serialize state + Bytes state = xnor_op->collect(); + + // Create new operator + auto restored = make_logical_xnor("xnor1"); + + // Restore state + Bytes::const_iterator it = state.begin(); + restored->restore(it); + + // Execute both operators + xnor_op->execute(); + restored->execute(); + + THEN("Both operators produce identical results") { + REQUIRE(*restored == *xnor_op); + const auto& orig_output = xnor_op->get_output_queue(0); + const auto& rest_output = restored->get_output_queue(0); + + REQUIRE(orig_output.size() == rest_output.size()); + + for (size_t i = 0; i < orig_output.size(); i++) { + const auto* orig_msg = dynamic_cast*>(orig_output[i].get()); + const auto* rest_msg = dynamic_cast*>(rest_output[i].get()); + + REQUIRE(orig_msg->time == rest_msg->time); + REQUIRE(orig_msg->data.value == rest_msg->data.value); + } + } + } + + WHEN("impl is serialized and restored") { + // Serialize state + Bytes state = impl_op->collect(); + + // Create new operator + auto restored = make_logical_implication("impl1"); + + // Restore state + Bytes::const_iterator it = state.begin(); + restored->restore(it); + + // Execute both operators + impl_op->execute(); + restored->execute(); + + THEN("Both operators produce identical results") { + REQUIRE(*restored == *impl_op); + const auto& orig_output = impl_op->get_output_queue(0); + const auto& rest_output = restored->get_output_queue(0); + + REQUIRE(orig_output.size() == rest_output.size()); + + for (size_t i = 0; i < orig_output.size(); i++) { + const auto* orig_msg = dynamic_cast*>(orig_output[i].get()); + const auto* rest_msg = dynamic_cast*>(rest_output[i].get()); + + REQUIRE(orig_msg->time == rest_msg->time); + REQUIRE(orig_msg->data.value == rest_msg->data.value); + } + } + } } } diff --git a/libs/std/test/test_constant.cpp b/libs/std/test/test_constant.cpp index 24c9ed5f..90883d41 100644 --- a/libs/std/test/test_constant.cpp +++ b/libs/std/test/test_constant.cpp @@ -81,4 +81,35 @@ SCENARIO("Constant operator handles error cases", "[constant]") { } } } +} + +SCENARIO("Constant operator handles state serialization", "[Constant][State]") { + GIVEN("A Constant operator with some history") { + auto constant = make_constant_boolean("const2", true); + + // Process initial messages + constant->receive_data(create_message(1, BooleanData{true}), 0); + constant->receive_data(create_message(2, BooleanData{false}), 0); + constant->execute(); + + WHEN("State is serialized and restored") { + Bytes state = constant->collect(); + auto restored = make_constant_boolean("const2", true); + auto it = state.cbegin(); + restored->restore(it); + + THEN("Continues counting from previous state") { + REQUIRE(*constant == *restored); + restored->clear_all_output_ports(); + restored->receive_data(create_message(3, BooleanData{false}), 0); + restored->execute(); + + const auto& output = restored->get_output_queue(0); + REQUIRE(output.size() == 1); + const auto* msg = dynamic_cast*>(output.front().get()); + REQUIRE(msg->time == 3); + REQUIRE(msg->data.value == true); + } + } + } } \ No newline at end of file diff --git a/libs/std/test/test_count.cpp b/libs/std/test/test_count.cpp index bdeb1474..3e4cfa46 100644 --- a/libs/std/test/test_count.cpp +++ b/libs/std/test/test_count.cpp @@ -44,6 +44,7 @@ SCENARIO("Count operator handles state serialization", "[Count][State]") { restored->restore(it); THEN("Continues counting from previous state") { + REQUIRE(*counter == *restored); restored->clear_all_output_ports(); restored->receive_data(create_message(3, BooleanData{true}), 0); restored->execute(); diff --git a/libs/std/test/test_cumulative_sum.cpp b/libs/std/test/test_cumulative_sum.cpp index ec362d9c..0c33301d 100644 --- a/libs/std/test/test_cumulative_sum.cpp +++ b/libs/std/test/test_cumulative_sum.cpp @@ -80,6 +80,8 @@ SCENARIO("CumulativeSum handles state serialization", "[CumulativeSum]") { THEN("State is preserved correctly") { REQUIRE(restored->get_sum() == 30.0); + REQUIRE(restored->get_sum() == sum->get_sum()); + REQUIRE(*restored == *sum); AND_WHEN("New messages are processed") { restored->receive_data(create_message(3, NumberData{40.0}), 0); diff --git a/libs/std/test/test_difference.cpp b/libs/std/test/test_difference.cpp index fb40b600..3d499f49 100644 --- a/libs/std/test/test_difference.cpp +++ b/libs/std/test/test_difference.cpp @@ -80,6 +80,7 @@ SCENARIO("Difference operator handles state serialization", "[Difference][State] restored->restore(it); THEN("Buffer content matches") { + REQUIRE(*diff == *restored); REQUIRE(restored->buffer_size() == diff->buffer_size()); REQUIRE(restored->get_use_oldest_time() == diff->get_use_oldest_time()); } diff --git a/libs/std/test/test_filter_scalar.cpp b/libs/std/test/test_filter_scalar.cpp index 0e74b5e6..b18ca855 100644 --- a/libs/std/test/test_filter_scalar.cpp +++ b/libs/std/test/test_filter_scalar.cpp @@ -178,37 +178,138 @@ SCENARIO("FilterScalarOp handles error cases", "[filter_scalar_op]") { } } -SCENARIO("FilterScalarOp handles serialization", "[filter_scalar_op]") { - SECTION("LessThan operator serialization") { +SCENARIO("FilterScalarOp handles serialization", "[filter_scalar_op][State]") { + SECTION("operators serialization") { auto lt = make_less_than("lt1", 3.0); + auto gt = make_greater_than("gt1", 3.0); + auto et = make_equal_to("et1", 3.0); + auto net = make_not_equal_to("net1", 3.0); // Fill with some data lt->receive_data(create_message(1, NumberData{1.0}), 0); lt->receive_data(create_message(2, NumberData{4.0}), 0); lt->execute(); - // Serialize state - Bytes state = lt->collect(); - - // Create new operator and restore state - auto restored = make_less_than("lt1", 3.0); - auto it = state.cbegin(); - restored->restore(it); - - // Verify restored state - REQUIRE(restored->type_name() == lt->type_name()); - REQUIRE(dynamic_cast(restored.get())->get_threshold() == - dynamic_cast(lt.get())->get_threshold()); - - // Process new data and verify behavior - restored->clear_all_output_ports(); - restored->receive_data(create_message(3, NumberData{2.0}), 0); - restored->execute(); - - auto& output = restored->get_output_queue(0); - REQUIRE(!output.empty()); - auto* msg = dynamic_cast*>(output[0].get()); - REQUIRE(msg->time == 3); - REQUIRE(msg->data.value == 2.0); + gt->receive_data(create_message(1, NumberData{1.0}), 0); + gt->receive_data(create_message(2, NumberData{4.0}), 0); + gt->execute(); + + et->receive_data(create_message(1, NumberData{1.0}), 0); + et->receive_data(create_message(2, NumberData{4.0}), 0); + et->execute(); + + net->receive_data(create_message(1, NumberData{1.0}), 0); + net->receive_data(create_message(2, NumberData{4.0}), 0); + net->execute(); + + WHEN("LessThan is serialized and restored") { + + // Serialize state + Bytes state = lt->collect(); + + // Create new operator and restore state + auto restored = make_less_than("lt1", 3.0); + auto it = state.cbegin(); + restored->restore(it); + + // Verify restored state + REQUIRE(*restored == *lt); + REQUIRE(restored->type_name() == lt->type_name()); + REQUIRE(dynamic_cast(restored.get())->get_threshold() == + dynamic_cast(lt.get())->get_threshold()); + + // Process new data and verify behavior + restored->clear_all_output_ports(); + restored->receive_data(create_message(3, NumberData{2.0}), 0); + restored->execute(); + + auto& output = restored->get_output_queue(0); + REQUIRE(!output.empty()); + auto* msg = dynamic_cast*>(output[0].get()); + REQUIRE(msg->time == 3); + REQUIRE(msg->data.value == 2.0); + } + + WHEN("GreaterThan is serialized and restored") { + + // Serialize state + Bytes state = gt->collect(); + + // Create new operator and restore state + auto restored = make_greater_than("gt1", 3.0); + auto it = state.cbegin(); + restored->restore(it); + + // Verify restored state + REQUIRE(*restored == *gt); + REQUIRE(restored->type_name() == gt->type_name()); + REQUIRE(dynamic_cast(restored.get())->get_threshold() == + dynamic_cast(gt.get())->get_threshold()); + + // Process new data and verify behavior + restored->clear_all_output_ports(); + restored->receive_data(create_message(3, NumberData{2.0}), 0); + restored->execute(); + + auto& output = restored->get_output_queue(0); + REQUIRE(output.empty()); + } + + WHEN("EqualTo is serialized and restored") { + + // Serialize state + Bytes state = et->collect(); + + // Create new operator and restore state + auto restored = make_equal_to("et1", 3.0); + auto it = state.cbegin(); + restored->restore(it); + + // Verify restored state + REQUIRE(*restored == *et); + REQUIRE(restored->type_name() == et->type_name()); + REQUIRE(dynamic_cast(restored.get())->get_value() == + dynamic_cast(et.get())->get_value()); + REQUIRE(dynamic_cast(restored.get())->get_epsilon() == + dynamic_cast(et.get())->get_epsilon()); + + // Process new data and verify behavior + restored->clear_all_output_ports(); + restored->receive_data(create_message(3, NumberData{3.0}), 0); + restored->execute(); + + auto& output = restored->get_output_queue(0); + REQUIRE(!output.empty()); + auto* msg = dynamic_cast*>(output[0].get()); + REQUIRE(msg->time == 3); + REQUIRE(msg->data.value == 3.0); + } + + WHEN("NotEqualTo is serialized and restored") { + + // Serialize state + Bytes state = net->collect(); + + // Create new operator and restore state + auto restored = make_not_equal_to("net1", 3.0); + auto it = state.cbegin(); + restored->restore(it); + + // Verify restored state + REQUIRE(*restored == *net); + REQUIRE(restored->type_name() == net->type_name()); + REQUIRE(dynamic_cast(restored.get())->get_value() == + dynamic_cast(net.get())->get_value()); + REQUIRE(dynamic_cast(restored.get())->get_epsilon() == + dynamic_cast(net.get())->get_epsilon()); + + // Process new data and verify behavior + restored->clear_all_output_ports(); + restored->receive_data(create_message(3, NumberData{3.0}), 0); + restored->execute(); + + auto& output = restored->get_output_queue(0); + REQUIRE(output.empty()); + } } } \ No newline at end of file diff --git a/libs/std/test/test_function.cpp b/libs/std/test/test_function.cpp index ac79a6ce..2b90700c 100644 --- a/libs/std/test/test_function.cpp +++ b/libs/std/test/test_function.cpp @@ -73,6 +73,25 @@ SCENARIO("Function operator handles edge cases", "[function]") { } } +SCENARIO("Function operator handles state serialization", "[function][State]") { + GIVEN("a valid function") { + std::vector> points = {{0.0, 0.0},{1.0, 1.0},{2.0, 2.0}}; + auto function = std::make_shared("func", points); + function->receive_data(create_message(1, NumberData{1.0}), 0); + function->execute(); + function->receive_data(create_message(2, NumberData{2.0}), 0); + + Bytes state = function->collect(); + auto restored = std::make_shared("func", points); + auto it = state.cbegin(); + restored->restore(it); + + SECTION("verifying deserialization") { + REQUIRE(*function == *restored); + } + } +} + SCENARIO("Function operator processes messages in sequence", "[function]") { GIVEN("A linear function with multiple points") { std::vector> points = {{0.0, 0.0}, {1.0, 2.0}, {2.0, 4.0}}; diff --git a/libs/std/test/test_identity.cpp b/libs/std/test/test_identity.cpp index d0ccabb6..510d2830 100644 --- a/libs/std/test/test_identity.cpp +++ b/libs/std/test/test_identity.cpp @@ -47,7 +47,7 @@ SCENARIO("Identity operator handles basic message forwarding", "[identity]") { } } -SCENARIO("Identity operator handles state serialization", "[identity]") { +SCENARIO("Identity operator handles state serialization", "[identity][State]") { GIVEN("An identity operator with processed messages") { auto identity = make_identity("id1"); @@ -68,6 +68,7 @@ SCENARIO("Identity operator handles state serialization", "[identity]") { restored->restore(it); THEN("Behavior is preserved") { + REQUIRE(*restored == *identity); restored->clear_all_output_ports(); restored->receive_data(create_message(5, NumberData{50.0}), 0); restored->execute(); diff --git a/libs/std/test/test_infinite_impulse_response.cpp b/libs/std/test/test_infinite_impulse_response.cpp index d95b5f04..cd67e730 100644 --- a/libs/std/test/test_infinite_impulse_response.cpp +++ b/libs/std/test/test_infinite_impulse_response.cpp @@ -156,6 +156,7 @@ SCENARIO("InfiniteImpulseResponse operator handles serialization", "[iir]") { const auto& rest_output = restored->get_output_queue(0); REQUIRE(orig_output.size() == rest_output.size()); + REQUIRE(*iir == *restored); auto* orig_msg = dynamic_cast*>(orig_output[0].get()); auto* rest_msg = dynamic_cast*>(rest_output[0].get()); diff --git a/libs/std/test/test_linear.cpp b/libs/std/test/test_linear.cpp index d1e87de6..046af528 100644 --- a/libs/std/test/test_linear.cpp +++ b/libs/std/test/test_linear.cpp @@ -87,6 +87,29 @@ SCENARIO("Linear operator validates configuration", "[linear]") { } } +SCENARIO("Linear operator handles serialization", "[linear][serialization]") { + GIVEN("A linear join" ) { + auto linear = std::make_shared("linear", std::vector{1.0, 2.0, 3.0}); + linear->receive_data(create_message(1, NumberData{1.0}), 0); + linear->receive_data(create_message(1, NumberData{1.0}), 1); + linear->receive_data(create_message(1, NumberData{1.0}), 2); + linear->execute(); + linear->receive_data(create_message(2, NumberData{1.0}), 0); + linear->receive_data(create_message(2, NumberData{1.0}), 1); + linear->receive_data(create_message(2, NumberData{1.0}), 2); + + Bytes state = linear->collect(); + auto restored = std::make_shared("linear", std::vector{1.0, 2.0, 3.0}); + auto it = state.cbegin(); + restored->restore(it); + + + SECTION("verifying deserialization") { + REQUIRE(*restored == *linear); + } + } +} + SCENARIO("Linear operator handles numerical stability", "[linear]") { GIVEN("A Linear operator with large coefficients") { std::vector coeffs = {1e6, -1e6}; // Large opposing coefficients diff --git a/libs/std/test/test_moving_average.cpp b/libs/std/test/test_moving_average.cpp index f69c4f03..5893ce7e 100644 --- a/libs/std/test/test_moving_average.cpp +++ b/libs/std/test/test_moving_average.cpp @@ -93,6 +93,7 @@ SCENARIO("MovingAverage operator handles state serialization", "[moving_average] THEN("State is correctly preserved") { REQUIRE(restored.mean() == ma.mean()); + REQUIRE(restored == ma); AND_WHEN("New data is added to both") { ma.receive_data(create_message(3, NumberData{6.0}), 0); diff --git a/libs/std/test/test_moving_sum.cpp b/libs/std/test/test_moving_sum.cpp index 04b30053..388ca479 100644 --- a/libs/std/test/test_moving_sum.cpp +++ b/libs/std/test/test_moving_sum.cpp @@ -69,7 +69,7 @@ SCENARIO("MovingSum operator handles basic calculations", "[moving_sum]") { } } -SCENARIO("MovingSum operator handles state serialization", "[moving_sum]") { +SCENARIO("MovingSum operator handles state serialization", "[moving_sum][State]") { GIVEN("A MovingSum operator with some data") { auto ms = MovingSum("test_ms", 3); @@ -77,6 +77,7 @@ SCENARIO("MovingSum operator handles state serialization", "[moving_sum]") { ms.receive_data(create_message(1, NumberData{2.0}), 0); ms.execute(); ms.receive_data(create_message(2, NumberData{4.0}), 0); + ms.receive_data(create_message(3, NumberData{6.0}), 0); ms.execute(); WHEN("State is serialized and restored") { @@ -92,26 +93,29 @@ SCENARIO("MovingSum operator handles state serialization", "[moving_sum]") { THEN("State is correctly preserved") { REQUIRE(restored.sum() == ms.sum()); + REQUIRE(restored == ms); AND_WHEN("New data is added to both") { - ms.receive_data(create_message(3, NumberData{6.0}), 0); - restored.receive_data(create_message(3, NumberData{6.0}), 0); + ms.receive_data(create_message(4, NumberData{8.0}), 0); + restored.receive_data(create_message(4, NumberData{8.0}), 0); ms.execute(); restored.execute(); THEN("Both produce identical output") { - const auto& orig_output = ms.get_output_queue(0); - const auto& rest_output = restored.get_output_queue(0); + auto& orig_output = ms.get_output_queue(0); + auto& rest_output = restored.get_output_queue(0); REQUIRE(orig_output.size() == rest_output.size()); - if (!orig_output.empty()) { + while (!orig_output.empty()) { const auto* orig_msg = dynamic_cast*>(orig_output.front().get()); const auto* rest_msg = dynamic_cast*>(rest_output.front().get()); REQUIRE(orig_msg->time == rest_msg->time); REQUIRE(orig_msg->data.value == rest_msg->data.value); + orig_output.pop_front(); + rest_output.pop_front(); } } } diff --git a/libs/std/test/test_peak_detector.cpp b/libs/std/test/test_peak_detector.cpp index 1933151d..9d5783d3 100644 --- a/libs/std/test/test_peak_detector.cpp +++ b/libs/std/test/test_peak_detector.cpp @@ -75,7 +75,7 @@ SCENARIO("PeakDetector handles edge cases", "[PeakDetector]") { } } -SCENARIO("PeakDetector handles state serialization", "[PeakDetector]") { +SCENARIO("PeakDetector handles state serialization", "[PeakDetector][State]") { GIVEN("A PeakDetector with processed data") { auto detector = std::make_unique("test", 3); @@ -99,6 +99,7 @@ SCENARIO("PeakDetector handles state serialization", "[PeakDetector]") { const auto& rest_buf = restored->buffer(); REQUIRE(orig_buf.size() == rest_buf.size()); + REQUIRE(*detector == *restored); for (size_t i = 0; i < orig_buf.size(); ++i) { REQUIRE(orig_buf[i]->time == rest_buf[i]->time); REQUIRE(orig_buf[i]->data.value == rest_buf[i]->data.value); @@ -164,8 +165,8 @@ SCENARIO("PeakDetector works in a PPG analysis pipeline", "[PeakDetector][Integr for (size_t i = 0; i < s.ti.size(); i++) { input->receive_data(create_message(s.ti[i], NumberData{s.ppg[i]}), 0); // std::cout << "Processing PPG data at " << s.ti[i] << std::endl; - input->execute(); - const auto& output_queue = join->get_output_queue(0); + input->execute(true); + const auto& output_queue = join->get_debug_output_queue(0); for (const auto& msg : output_queue) { const auto* data = dynamic_cast*>(msg.get()); diff --git a/libs/std/test/test_pipeline.cpp b/libs/std/test/test_pipeline.cpp index ff1a0db3..53ef974b 100644 --- a/libs/std/test/test_pipeline.cpp +++ b/libs/std/test/test_pipeline.cpp @@ -155,7 +155,7 @@ SCENARIO("Pipeline handles type checking", "[pipeline]") { } } -SCENARIO("Pipeline handles state serialization correctly", "[pipeline][serialization]") { +SCENARIO("Pipeline handles state serialization correctly", "[pipeline][State]") { GIVEN("A pipeline with base operator state") { // Create pipeline and verify base operator serialization auto pipeline = std::make_unique("serial_pipe", std::vector{PortType::NUMBER}, @@ -176,6 +176,7 @@ SCENARIO("Pipeline handles state serialization correctly", "[pipeline][serializa REQUIRE(restored->id() == pipeline->id()); REQUIRE(restored->num_data_ports() == pipeline->num_data_ports()); REQUIRE(restored->num_output_ports() == pipeline->num_output_ports()); + REQUIRE(*restored==*pipeline); } } } diff --git a/libs/std/test/test_replace.cpp b/libs/std/test/test_replace.cpp index 5b1c476c..13f856bd 100644 --- a/libs/std/test/test_replace.cpp +++ b/libs/std/test/test_replace.cpp @@ -76,7 +76,7 @@ SCENARIO("LessThanOrEqualToReplace handles error cases", "[replace_op]") { } } -SCENARIO("ReplaceOp handles serialization", "[replace_op]") { +SCENARIO("ReplaceOp handles serialization", "[replace_op][State]") { SECTION("LessThanOrEqualToReplace operator serialization") { auto ltR = make_less_than_or_equal_to_replace("ltR", 3.0, 2.0); @@ -94,6 +94,7 @@ SCENARIO("ReplaceOp handles serialization", "[replace_op]") { restored->restore(it); // Verify restored state + REQUIRE(*restored == *ltR); REQUIRE(restored->type_name() == ltR->type_name()); REQUIRE(dynamic_cast(restored.get())->get_threshold() == dynamic_cast(ltR.get())->get_threshold()); diff --git a/libs/std/test/test_resampler_constant.cpp b/libs/std/test/test_resampler_constant.cpp index 2ba29795..4c2fedf1 100644 --- a/libs/std/test/test_resampler_constant.cpp +++ b/libs/std/test/test_resampler_constant.cpp @@ -75,6 +75,30 @@ SCENARIO("ResamplerConstant upsampling without t0", "[ResamplerConstant]") { } } +SCENARIO("ResamplerConstant operator handles serialization", "[ResamplerConstant][State]") { + GIVEN("A linear join" ) { + auto rc = ResamplerConstant("resampler_constant", 1); + rc.receive_data(create_message(1, NumberData{1.0}), 0); + rc.receive_data(create_message(3, NumberData{3.0}), 0); + rc.receive_data(create_message(5, NumberData{5.0}), 0); + rc.receive_data(create_message(7, NumberData{7.0}), 0); + rc.receive_data(create_message(9, NumberData{9.0}), 0); + rc.execute(); + rc.receive_data(create_message(11, NumberData{11.0}), 0); + + + Bytes state = rc.collect(); + auto restored = ResamplerConstant("resampler_constant", 1); + auto it = state.cbegin(); + restored.restore(it); + + + SECTION("verifying deserialization") { + REQUIRE(restored == rc); + } + } +} + SCENARIO("ResamplerConstant with fixed t0", "[ResamplerConstant]") { auto resampler = ResamplerConstant("test", 10, 5); // Grid: 5,15,25,... diff --git a/libs/std/test/test_resampler_hermite.cpp b/libs/std/test/test_resampler_hermite.cpp index 0c8e3577..0f33eed8 100644 --- a/libs/std/test/test_resampler_hermite.cpp +++ b/libs/std/test/test_resampler_hermite.cpp @@ -131,7 +131,7 @@ SCENARIO("ResamplerHermite handles edge cases", "[resampler][hermite]") { } } -SCENARIO("ResamplerHermite maintains state correctly", "[resampler][hermite]") { +SCENARIO("ResamplerHermite maintains state correctly", "[resampler][hermite][State]") { GIVEN("A hermite resampler with interval 4") { auto resampler = make_resampler_hermite("test", 4); @@ -155,6 +155,7 @@ SCENARIO("ResamplerHermite maintains state correctly", "[resampler][hermite]") { REQUIRE(restored->get_interval() == resampler->get_interval()); REQUIRE(restored->get_next_emission_time() == resampler->get_next_emission_time()); REQUIRE(restored->buffer_size() == resampler->buffer_size()); + REQUIRE(*restored == *resampler); } AND_WHEN("New data is processed") { diff --git a/libs/std/test/test_standard_deviation.cpp b/libs/std/test/test_standard_deviation.cpp index 7a73be93..0b7d5100 100644 --- a/libs/std/test/test_standard_deviation.cpp +++ b/libs/std/test/test_standard_deviation.cpp @@ -97,7 +97,7 @@ SCENARIO("StandardDeviation operator handles edge cases", "[StandardDeviation]") } } -SCENARIO("StandardDeviation operator handles state serialization", "[StandardDeviation]") { +SCENARIO("StandardDeviation operator handles state serialization", "[StandardDeviation][State]") { GIVEN("A StandardDeviation operator with processed messages") { auto sd = StandardDeviation("sd1", 3); @@ -121,17 +121,20 @@ SCENARIO("StandardDeviation operator handles state serialization", "[StandardDev restored.restore(it); THEN("Statistical calculations match") { - const auto& orig_output = sd.get_output_queue(0); - const auto& rest_output = restored.get_output_queue(0); + auto& orig_output = sd.get_output_queue(0); + auto& rest_output = restored.get_output_queue(0); REQUIRE(orig_output.size() == rest_output.size()); + REQUIRE(sd == restored); - if (!orig_output.empty()) { - auto* orig_msg = dynamic_cast*>(orig_output[0].get()); - auto* rest_msg = dynamic_cast*>(rest_output[0].get()); + while (!orig_output.empty()) { + auto* orig_msg = dynamic_cast*>(orig_output.front().get()); + auto* rest_msg = dynamic_cast*>(rest_output.front().get()); REQUIRE(orig_msg->time == rest_msg->time); REQUIRE(orig_msg->data.value == rest_msg->data.value); + orig_output.pop_front(); + rest_output.pop_front(); } } diff --git a/libs/std/test/test_time_shift.cpp b/libs/std/test/test_time_shift.cpp index 8f22ece0..2fbfb69f 100644 --- a/libs/std/test/test_time_shift.cpp +++ b/libs/std/test/test_time_shift.cpp @@ -103,6 +103,7 @@ SCENARIO("TimeShift handles state serialization", "[TimeShift]") { const auto& rest_output = restored->get_output_queue(0); REQUIRE(orig_output.size() == rest_output.size()); + REQUIRE(*time_shift == *restored); for (size_t i = 0; i < orig_output.size(); ++i) { const auto* orig_msg = dynamic_cast*>(orig_output[i].get()); const auto* rest_msg = dynamic_cast*>(rest_output[i].get()); diff --git a/libs/std/test/test_variable.cpp b/libs/std/test/test_variable.cpp index 01854834..9dfc8ecb 100644 --- a/libs/std/test/test_variable.cpp +++ b/libs/std/test/test_variable.cpp @@ -57,7 +57,7 @@ SCENARIO("Variable operator handles basic operations", "[variable]") { } } -SCENARIO("Variable operator handles state serialization", "[variable]") { +SCENARIO("Variable operator handles state serialization", "[variable][State]") { GIVEN("A Variable with non-trivial state") { auto var = make_variable("var1", 42.0); @@ -84,6 +84,7 @@ SCENARIO("Variable operator handles state serialization", "[variable]") { for (auto t : {5, 10, 15, 20, 25}) { var->receive_control(create_message(t, NumberData{0.0}), 0); restored->receive_control(create_message(t, NumberData{0.0}), 0); + REQUIRE(*var == *restored); } var->execute(); @@ -94,6 +95,7 @@ SCENARIO("Variable operator handles state serialization", "[variable]") { const auto& rest_output = restored->get_output_queue(0); REQUIRE(orig_output.size() == rest_output.size()); + REQUIRE(*var == *restored); for (size_t i = 0; i < orig_output.size(); i++) { const auto* orig_msg = dynamic_cast*>(orig_output[i].get());