diff --git a/.cargo/config.toml b/.cargo/config.toml index 32bd6f25..1a40b6d7 100644 --- a/.cargo/config.toml +++ b/.cargo/config.toml @@ -5,4 +5,4 @@ linker = "x86_64-linux-musl-gcc" linker = "aarch64-linux-gnu-gcc" [target.aarch64-unknown-linux-musl] -linker = "aarch64-linux-musl-gcc" +linker = "aarch64-linux-musl-gcc" \ No newline at end of file diff --git a/Cargo.lock b/Cargo.lock index 6f45c824..dd64ee2c 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -782,6 +782,26 @@ version = "1.8.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "55248b47b0caf0546f7988906588779981c43bb1bc9d0c44087278f80cdb44ba" +[[package]] +name = "bincode" +version = "2.0.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "36eaf5d7b090263e8150820482d5d93cd964a81e4019913c972f4edcc6edb740" +dependencies = [ + "bincode_derive", + "serde", + "unty", +] + +[[package]] +name = "bincode_derive" +version = "2.0.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "bf95709a440f45e986983918d0e8a1f30a9b1df04918fc828670606804ac3c09" +dependencies = [ + "virtue", +] + [[package]] name = "bindgen" version = "0.72.1" @@ -3460,6 +3480,7 @@ dependencies = [ "axum", "axum-extra", "axum-server", + "bincode", "bytes", "chrono", "clap", @@ -3501,6 +3522,7 @@ dependencies = [ "url", "urlencoding", "uuid", + "xxhash-rust", ] [[package]] @@ -4204,6 +4226,12 @@ version = "0.9.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "8ecb6da28b8a351d773b68d5825ac39017e680750f980f3a1a85cd8dd28a47c1" +[[package]] +name = "unty" +version = "0.0.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6d49784317cd0d1ee7ec5c716dd598ec5b4483ea832a2dced265471cc0f690ae" + [[package]] name = "url" version = "2.5.7" @@ -4282,6 +4310,12 @@ version = "0.9.5" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "0b928f33d975fc6ad9f86c8f283853ad26bdd5b10b7f1542aa2fa15e2289105a" +[[package]] +name = "virtue" +version = "0.0.18" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "051eb1abcf10076295e815102942cc58f9d5e3b4560e46e53c21e8ff6f3af7b1" + [[package]] name = "vsimd" version = "0.8.0" @@ -4866,6 +4900,12 @@ version = "0.13.6" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "66fee0b777b0f5ac1c69bb06d361268faafa61cd4682ae064a171c16c433e9e4" +[[package]] +name = "xxhash-rust" +version = "0.8.15" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "fdd20c5420375476fbd4394763288da7eb0cc0b8c11deed431a91562af7335d3" + [[package]] name = "yoke" version = "0.8.0" diff --git a/Cargo.toml b/Cargo.toml index f95ab68c..01754dda 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -63,6 +63,8 @@ mockall = "0.13.1" rustls = { version = "0.23.29", features = ["ring"] } tower = "0.5.2" mimalloc = { version = "0.1.48", features = ["v3"] } +bincode = "2.0.1" +xxhash-rust = { version = "0.8", features = ["xxh3"] } [dependencies.uuid] diff --git a/src/adapter/binary_protocol.rs b/src/adapter/binary_protocol.rs new file mode 100644 index 00000000..7ad080a9 --- /dev/null +++ b/src/adapter/binary_protocol.rs @@ -0,0 +1,429 @@ +use crate::adapter::horizontal_adapter::{ + BroadcastMessage, RequestBody, RequestType, ResponseBody, +}; +use crate::channel::PresenceMemberInfo; +use crate::error::{Error, Result}; +use bincode::{Decode, Encode}; +use std::collections::{HashMap, HashSet}; + +/// Protocol version for backward compatibility during rolling 
upgrades +pub const BINARY_PROTOCOL_VERSION: u8 = 1; + +/// Maximum message size (10MB) to prevent DoS attacks +pub const MAX_MESSAGE_SIZE: u64 = 10 * 1024 * 1024; + +/// Get the bincode configuration for consistent serialization +/// Uses standard config with variable-length encoding for optimal size +pub fn bincode_config() -> impl bincode::config::Config { + bincode::config::standard() + .with_little_endian() + .with_variable_int_encoding() + .with_limit::<{ MAX_MESSAGE_SIZE as usize }>() +} + +/// Binary envelope for broadcast messages +/// Wraps the original JSON client payload without re-parsing +#[derive(Debug, Clone, Encode, Decode)] +pub struct BinaryBroadcastMessage { + /// Protocol version for backward compatibility + pub version: u8, + + /// Hash of channel name for fast routing (xxh3_64) + pub channel_hash: u64, + + /// Channel name (kept for routing, but we also use hash) + pub channel: String, + + /// Node ID as fixed-size bytes for efficiency + pub node_id_bytes: [u8; 16], // UUID is 16 bytes + + /// App ID (keeping as string since it's user-defined) + pub app_id: String, + + /// Raw client JSON message bytes (not re-parsed) + /// This is the exact JSON that will be sent to clients + pub raw_client_json: Vec, + + /// Optional socket ID to exclude from broadcast + pub except_socket_id: Option, + + /// Timestamp in milliseconds since epoch (microsecond precision as f64) + pub timestamp_ms: Option, +} + +/// Binary envelope for request messages +#[derive(Debug, Clone, Encode, Decode)] +pub struct BinaryRequestBody { + /// Protocol version + pub version: u8, + + /// Request ID as fixed-size bytes + pub request_id_bytes: [u8; 16], + + /// Node ID as fixed-size bytes + pub node_id_bytes: [u8; 16], + + /// App ID + pub app_id: String, + + /// Request type (enum needs custom encoding - we'll use u8) + pub request_type_discriminant: u8, + + /// Channel name (optional) + pub channel: Option, + + /// Channel hash (optional, for fast routing) + pub channel_hash: Option, + + /// Socket ID (optional) + pub socket_id: Option, + + /// User ID (optional) + pub user_id: Option, + + /// Serialized user info (as JSON bytes) + pub user_info_bytes: Option>, + + /// Timestamp for heartbeat + pub timestamp: Option, + + /// Dead node ID (optional) + pub dead_node_id: Option, + + /// Target node ID (optional) + pub target_node_id: Option, +} + +/// Binary envelope for response messages +#[derive(Debug, Clone, Encode, Decode)] +pub struct BinaryResponseBody { + /// Protocol version + pub version: u8, + + /// Request ID as fixed-size bytes + pub request_id_bytes: [u8; 16], + + /// Node ID as fixed-size bytes + pub node_id_bytes: [u8; 16], + + /// App ID + pub app_id: String, + + /// Serialized members map (as JSON bytes, contains serde_json::Value) + pub members_bytes: Option>, + + /// Serialized channels with socket count (as bincode bytes) + pub channels_with_sockets_count: HashMap, + + /// Socket IDs + pub socket_ids: Vec, + + /// Sockets count + pub sockets_count: usize, + + /// Exists flag + pub exists: bool, + + /// Channels set + pub channels: HashSet, + + /// Members count + pub members_count: usize, +} + +/// Convert UUID string to fixed-size bytes +fn uuid_to_bytes(uuid_str: &str) -> Result<[u8; 16]> { + uuid::Uuid::parse_str(uuid_str) + .map(|u| *u.as_bytes()) + .map_err(|e| Error::Other(format!("Failed to parse UUID: {}", e))) +} + +/// Convert fixed-size bytes back to UUID string +fn bytes_to_uuid(bytes: &[u8; 16]) -> String { + uuid::Uuid::from_bytes(*bytes).to_string() +} + +/// 
Calculate xxh3 hash of a string for fast routing +pub fn calculate_channel_hash(channel: &str) -> u64 { + xxhash_rust::xxh3::xxh3_64(channel.as_bytes()) +} + +/// Convert RequestType to u8 discriminant for efficient binary encoding +fn request_type_to_u8(request_type: &RequestType) -> u8 { + match request_type { + RequestType::ChannelMembers => 0, + RequestType::ChannelSockets => 1, + RequestType::ChannelSocketsCount => 2, + RequestType::SocketExistsInChannel => 3, + RequestType::TerminateUserConnections => 4, + RequestType::ChannelsWithSocketsCount => 5, + RequestType::Sockets => 6, + RequestType::Channels => 7, + RequestType::SocketsCount => 8, + RequestType::ChannelMembersCount => 9, + RequestType::CountUserConnectionsInChannel => 10, + RequestType::PresenceMemberJoined => 11, + RequestType::PresenceMemberLeft => 12, + RequestType::Heartbeat => 13, + RequestType::NodeDead => 14, + RequestType::PresenceStateSync => 15, + } +} + +/// Convert u8 discriminant back to RequestType +fn u8_to_request_type(discriminant: u8) -> Result { + match discriminant { + 0 => Ok(RequestType::ChannelMembers), + 1 => Ok(RequestType::ChannelSockets), + 2 => Ok(RequestType::ChannelSocketsCount), + 3 => Ok(RequestType::SocketExistsInChannel), + 4 => Ok(RequestType::TerminateUserConnections), + 5 => Ok(RequestType::ChannelsWithSocketsCount), + 6 => Ok(RequestType::Sockets), + 7 => Ok(RequestType::Channels), + 8 => Ok(RequestType::SocketsCount), + 9 => Ok(RequestType::ChannelMembersCount), + 10 => Ok(RequestType::CountUserConnectionsInChannel), + 11 => Ok(RequestType::PresenceMemberJoined), + 12 => Ok(RequestType::PresenceMemberLeft), + 13 => Ok(RequestType::Heartbeat), + 14 => Ok(RequestType::NodeDead), + 15 => Ok(RequestType::PresenceStateSync), + _ => Err(Error::Other(format!( + "Unknown request type discriminant: {}", + discriminant + ))), + } +} + +impl From for BinaryBroadcastMessage { + fn from(msg: BroadcastMessage) -> Self { + let node_id_bytes = uuid_to_bytes(&msg.node_id).unwrap_or([0u8; 16]); + let channel_hash = calculate_channel_hash(&msg.channel); + + // The message field contains the JSON string that should be sent to clients + let raw_client_json = msg.message.into_bytes(); + + Self { + version: BINARY_PROTOCOL_VERSION, + channel_hash, + channel: msg.channel, + node_id_bytes, + app_id: msg.app_id, + raw_client_json, + except_socket_id: msg.except_socket_id, + timestamp_ms: msg.timestamp_ms, + } + } +} + +impl From for BroadcastMessage { + fn from(binary: BinaryBroadcastMessage) -> Self { + Self { + node_id: bytes_to_uuid(&binary.node_id_bytes), + app_id: binary.app_id, + channel: binary.channel, + message: String::from_utf8_lossy(&binary.raw_client_json).to_string(), + except_socket_id: binary.except_socket_id, + timestamp_ms: binary.timestamp_ms, + } + } +} + +impl TryFrom for BinaryRequestBody { + type Error = Error; + + fn try_from(req: RequestBody) -> Result { + let request_id_bytes = uuid_to_bytes(&req.request_id)?; + let node_id_bytes = uuid_to_bytes(&req.node_id)?; + let channel_hash = req.channel.as_ref().map(|c| calculate_channel_hash(c)); + let request_type_discriminant = request_type_to_u8(&req.request_type); + + // Serialize user_info to JSON bytes + let user_info_bytes = req + .user_info + .map(|v| serde_json::to_vec(&v)) + .transpose() + .map_err(|e| Error::Other(format!("Failed to serialize user_info: {}", e)))?; + + Ok(Self { + version: BINARY_PROTOCOL_VERSION, + request_id_bytes, + node_id_bytes, + app_id: req.app_id, + request_type_discriminant, + channel: req.channel, + 
channel_hash, + socket_id: req.socket_id, + user_id: req.user_id, + user_info_bytes, + timestamp: req.timestamp, + dead_node_id: req.dead_node_id, + target_node_id: req.target_node_id, + }) + } +} + +impl TryFrom for RequestBody { + type Error = Error; + + fn try_from(binary: BinaryRequestBody) -> Result { + let request_id = bytes_to_uuid(&binary.request_id_bytes); + let node_id = bytes_to_uuid(&binary.node_id_bytes); + let request_type = u8_to_request_type(binary.request_type_discriminant)?; + + // Deserialize user_info from JSON bytes if present + let user_info = binary + .user_info_bytes + .map(|bytes| serde_json::from_slice(&bytes)) + .transpose() + .map_err(|e| Error::Other(format!("Failed to deserialize user_info: {}", e)))?; + + Ok(Self { + request_id, + node_id, + app_id: binary.app_id, + request_type, + channel: binary.channel, + socket_id: binary.socket_id, + user_id: binary.user_id, + user_info, + timestamp: binary.timestamp, + dead_node_id: binary.dead_node_id, + target_node_id: binary.target_node_id, + }) + } +} + +impl TryFrom for BinaryResponseBody { + type Error = Error; + + fn try_from(resp: ResponseBody) -> Result { + let request_id_bytes = uuid_to_bytes(&resp.request_id)?; + let node_id_bytes = uuid_to_bytes(&resp.node_id)?; + + // Serialize members using JSON (because PresenceMemberInfo contains serde_json::Value) + let members_bytes = if !resp.members.is_empty() { + Some( + serde_json::to_vec(&resp.members) + .map_err(|e| Error::Other(format!("Failed to serialize members: {}", e)))?, + ) + } else { + None + }; + + Ok(Self { + version: BINARY_PROTOCOL_VERSION, + request_id_bytes, + node_id_bytes, + app_id: resp.app_id, + members_bytes, + channels_with_sockets_count: resp.channels_with_sockets_count, + socket_ids: resp.socket_ids, + sockets_count: resp.sockets_count, + exists: resp.exists, + channels: resp.channels, + members_count: resp.members_count, + }) + } +} + +impl TryFrom for ResponseBody { + type Error = Error; + + fn try_from(binary: BinaryResponseBody) -> Result { + let request_id = bytes_to_uuid(&binary.request_id_bytes); + let node_id = bytes_to_uuid(&binary.node_id_bytes); + + // Deserialize members from JSON bytes if present + let members = binary + .members_bytes + .map(|bytes| serde_json::from_slice::>(&bytes)) + .transpose() + .map_err(|e| Error::Other(format!("Failed to deserialize members: {}", e)))? 
+ .unwrap_or_default(); + + Ok(Self { + request_id, + node_id, + app_id: binary.app_id, + members, + channels_with_sockets_count: binary.channels_with_sockets_count, + socket_ids: binary.socket_ids, + sockets_count: binary.sockets_count, + exists: binary.exists, + channels: binary.channels, + members_count: binary.members_count, + }) + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_uuid_conversion() { + let uuid_str = "550e8400-e29b-41d4-a716-446655440000"; + let bytes = uuid_to_bytes(uuid_str).unwrap(); + let recovered = bytes_to_uuid(&bytes); + assert_eq!(uuid_str, recovered); + } + + #[test] + fn test_channel_hash() { + let channel1 = "test-channel"; + let channel2 = "test-channel"; + let channel3 = "different-channel"; + + assert_eq!( + calculate_channel_hash(channel1), + calculate_channel_hash(channel2) + ); + assert_ne!( + calculate_channel_hash(channel1), + calculate_channel_hash(channel3) + ); + } + + #[test] + fn test_broadcast_message_conversion() { + let original = BroadcastMessage { + node_id: "550e8400-e29b-41d4-a716-446655440000".to_string(), + app_id: "test-app".to_string(), + channel: "test-channel".to_string(), + message: r#"{"event":"test","data":"payload"}"#.to_string(), + except_socket_id: Some("socket-123".to_string()), + timestamp_ms: Some(1234567890.123), + }; + + let binary: BinaryBroadcastMessage = original.clone().into(); + assert_eq!(binary.version, BINARY_PROTOCOL_VERSION); + assert_eq!(binary.app_id, original.app_id); + assert_eq!(binary.channel, original.channel); + + let recovered: BroadcastMessage = binary.into(); + assert_eq!(recovered.node_id, original.node_id); + assert_eq!(recovered.app_id, original.app_id); + assert_eq!(recovered.channel, original.channel); + assert_eq!(recovered.message, original.message); + } + + #[test] + fn test_bincode_serialization_size() { + let msg = BroadcastMessage { + node_id: "550e8400-e29b-41d4-a716-446655440000".to_string(), + app_id: "test-app".to_string(), + channel: "test-channel".to_string(), + message: r#"{"event":"test","data":"payload"}"#.to_string(), + except_socket_id: None, + timestamp_ms: Some(1234567890.123), + }; + + let _json_size = serde_json::to_vec(&msg).unwrap().len(); + let binary: BinaryBroadcastMessage = msg.into(); + let _binary_size = bincode::encode_to_vec(&binary, bincode_config()) + .unwrap() + .len(); + } +} diff --git a/src/adapter/mod.rs b/src/adapter/mod.rs index e28b1375..f8a515be 100644 --- a/src/adapter/mod.rs +++ b/src/adapter/mod.rs @@ -1,3 +1,4 @@ +pub mod binary_protocol; pub mod connection_manager; pub mod factory; pub mod handler; diff --git a/src/adapter/transports/nats_transport.rs b/src/adapter/transports/nats_transport.rs index 80dc605b..2394d1d0 100644 --- a/src/adapter/transports/nats_transport.rs +++ b/src/adapter/transports/nats_transport.rs @@ -1,3 +1,6 @@ +use crate::adapter::binary_protocol::{ + BinaryBroadcastMessage, BinaryRequestBody, BinaryResponseBody, bincode_config, +}; use crate::adapter::horizontal_adapter::{BroadcastMessage, RequestBody, ResponseBody}; use crate::adapter::horizontal_transport::{ HorizontalTransport, TransportConfig, TransportHandlers, @@ -8,7 +11,7 @@ use async_nats::{Client as NatsClient, ConnectOptions as NatsOptions, Subject}; use async_trait::async_trait; use futures::StreamExt; use std::time::Duration; -use tracing::{debug, info}; +use tracing::{debug, info, warn}; /// NATS transport implementation #[derive(Clone)] @@ -82,8 +85,10 @@ impl HorizontalTransport for NatsTransport { } async fn publish_broadcast(&self, 
message: &BroadcastMessage) -> Result<()> { - let message_data = serde_json::to_vec(message) - .map_err(|e| Error::Other(format!("Failed to serialize broadcast message: {e}")))?; + // Convert to binary format + let binary_msg: BinaryBroadcastMessage = message.clone().into(); + let message_data = bincode::encode_to_vec(&binary_msg, bincode_config()) + .map_err(|e| Error::Other(format!("Failed to serialize broadcast: {}", e)))?; self.client .publish( @@ -98,8 +103,10 @@ impl HorizontalTransport for NatsTransport { } async fn publish_request(&self, request: &RequestBody) -> Result<()> { - let request_data = serde_json::to_vec(request) - .map_err(|e| Error::Other(format!("Failed to serialize request: {e}")))?; + // Convert to binary format + let binary_req: BinaryRequestBody = request.clone().try_into()?; + let request_data = bincode::encode_to_vec(&binary_req, bincode_config()) + .map_err(|e| Error::Other(format!("Failed to serialize request: {}", e)))?; self.client .publish( @@ -114,8 +121,10 @@ impl HorizontalTransport for NatsTransport { } async fn publish_response(&self, response: &ResponseBody) -> Result<()> { - let response_data = serde_json::to_vec(response) - .map_err(|e| Error::Other(format!("Failed to serialize response: {e}")))?; + // Convert to binary format + let binary_resp: BinaryResponseBody = response.clone().try_into()?; + let response_data = bincode::encode_to_vec(&binary_resp, bincode_config()) + .map_err(|e| Error::Other(format!("Failed to serialize response: {}", e)))?; self.client .publish( @@ -167,7 +176,11 @@ impl HorizontalTransport for NatsTransport { let broadcast_handler = handlers.on_broadcast.clone(); tokio::spawn(async move { while let Some(msg) = broadcast_subscription.next().await { - if let Ok(broadcast) = serde_json::from_slice::(&msg.payload) { + if let Ok((binary_msg, _)) = bincode::decode_from_slice::( + &msg.payload, + bincode_config(), + ) { + let broadcast: BroadcastMessage = binary_msg.into(); broadcast_handler(broadcast).await; } } @@ -177,18 +190,26 @@ impl HorizontalTransport for NatsTransport { let request_handler = handlers.on_request.clone(); tokio::spawn(async move { while let Some(msg) = request_subscription.next().await { - if let Ok(request) = serde_json::from_slice::(&msg.payload) { + if let Ok((binary_req, _)) = bincode::decode_from_slice::( + &msg.payload, + bincode_config(), + ) && let Ok(request) = RequestBody::try_from(binary_req) + { let response_result = request_handler(request).await; - if let Ok(response) = response_result - && let Ok(response_data) = serde_json::to_vec(&response) - { - let _ = response_client - .publish( - Subject::from(response_subject.clone()), - response_data.into(), - ) - .await; + if let Ok(response) = response_result { + // Serialize response to binary + if let Ok(binary_resp) = BinaryResponseBody::try_from(response) + && let Ok(response_data) = + bincode::encode_to_vec(&binary_resp, bincode_config()) + { + let _ = response_client + .publish( + Subject::from(response_subject.clone()), + response_data.into(), + ) + .await; + } } } } @@ -198,8 +219,15 @@ impl HorizontalTransport for NatsTransport { let response_handler = handlers.on_response.clone(); tokio::spawn(async move { while let Some(msg) = response_subscription.next().await { - if let Ok(response) = serde_json::from_slice::(&msg.payload) { - response_handler(response).await; + if let Ok((binary_resp, _)) = bincode::decode_from_slice::( + &msg.payload, + bincode_config(), + ) { + if let Ok(response) = ResponseBody::try_from(binary_resp) { + 
response_handler(response).await; + } + } else { + warn!("Failed to parse binary response message"); } } }); diff --git a/src/adapter/transports/redis_cluster_transport.rs b/src/adapter/transports/redis_cluster_transport.rs index cb115c53..c406b899 100644 --- a/src/adapter/transports/redis_cluster_transport.rs +++ b/src/adapter/transports/redis_cluster_transport.rs @@ -1,3 +1,6 @@ +use crate::adapter::binary_protocol::{ + BinaryBroadcastMessage, BinaryRequestBody, BinaryResponseBody, bincode_config, +}; use crate::adapter::horizontal_adapter::{BroadcastMessage, RequestBody, ResponseBody}; use crate::adapter::horizontal_transport::{ HorizontalTransport, TransportConfig, TransportHandlers, @@ -7,7 +10,7 @@ use crate::options::RedisClusterAdapterConfig; use async_trait::async_trait; use redis::AsyncCommands; use redis::cluster::{ClusterClient, ClusterClientBuilder}; -use tracing::{debug, error}; +use tracing::{debug, error, warn}; /// Helper function to convert redis::Value to String fn value_to_string(v: &redis::Value) -> Option { @@ -19,6 +22,14 @@ fn value_to_string(v: &redis::Value) -> Option { } } +/// Helper function to convert redis::Value to bytes (for binary data) +fn value_to_bytes(v: &redis::Value) -> Option> { + match v { + redis::Value::BulkString(bytes) => Some(bytes.clone()), + _ => None, + } +} + impl TransportConfig for RedisClusterAdapterConfig { fn request_timeout_ms(&self) -> u64 { self.request_timeout_ms @@ -64,7 +75,10 @@ impl HorizontalTransport for RedisClusterTransport { } async fn publish_broadcast(&self, message: &BroadcastMessage) -> Result<()> { - let broadcast_json = serde_json::to_string(message)?; + // Convert to binary format + let binary_msg: BinaryBroadcastMessage = message.clone().into(); + let broadcast_bytes = bincode::encode_to_vec(&binary_msg, bincode_config()) + .map_err(|e| Error::Other(format!("Failed to serialize broadcast: {}", e)))?; // Use client's internal connection pooling - this is efficient let mut conn = self.client.get_async_connection().await.map_err(|e| { @@ -73,7 +87,7 @@ impl HorizontalTransport for RedisClusterTransport { )) })?; - conn.publish::<_, _, ()>(&self.broadcast_channel, broadcast_json) + conn.publish::<_, _, ()>(&self.broadcast_channel, broadcast_bytes.as_slice()) .await .map_err(|e| Error::Redis(format!("Failed to publish broadcast: {e}")))?; @@ -81,8 +95,10 @@ impl HorizontalTransport for RedisClusterTransport { } async fn publish_request(&self, request: &RequestBody) -> Result<()> { - let request_json = serde_json::to_string(request) - .map_err(|e| Error::Other(format!("Failed to serialize request: {e}")))?; + // Convert to binary format + let binary_req: BinaryRequestBody = request.clone().try_into()?; + let request_bytes = bincode::encode_to_vec(&binary_req, bincode_config()) + .map_err(|e| Error::Other(format!("Failed to serialize request: {}", e)))?; // Use client's internal connection pooling - this is efficient for cluster let mut conn = self.client.get_async_connection().await.map_err(|e| { @@ -90,7 +106,7 @@ impl HorizontalTransport for RedisClusterTransport { })?; let subscriber_count: i32 = conn - .publish(&self.request_channel, &request_json) + .publish(&self.request_channel, request_bytes.as_slice()) .await .map_err(|e| Error::Redis(format!("Failed to publish request: {e}")))?; @@ -103,8 +119,10 @@ impl HorizontalTransport for RedisClusterTransport { } async fn publish_response(&self, response: &ResponseBody) -> Result<()> { - let response_json = serde_json::to_string(response) - .map_err(|e| 
Error::Other(format!("Failed to serialize response: {e}")))?; + // Convert to binary format + let binary_resp: BinaryResponseBody = response.clone().try_into()?; + let response_bytes = bincode::encode_to_vec(&binary_resp, bincode_config()) + .map_err(|e| Error::Other(format!("Failed to serialize response: {}", e)))?; // Use client's internal connection pooling - this is efficient for cluster let mut conn = self.client.get_async_connection().await.map_err(|e| { @@ -113,7 +131,7 @@ impl HorizontalTransport for RedisClusterTransport { )) })?; - conn.publish::<_, _, ()>(&self.response_channel, response_json) + conn.publish::<_, _, ()>(&self.response_channel, response_bytes.as_slice()) .await .map_err(|e| Error::Redis(format!("Failed to publish response: {e}")))?; @@ -184,8 +202,8 @@ impl HorizontalTransport for RedisClusterTransport { } }; - let payload = match value_to_string(&push_info.data[1]) { - Some(s) => s, + let payload_bytes = match value_to_bytes(&push_info.data[1]) { + Some(bytes) => bytes, None => { error!("Failed to parse payload: {:?}", push_info.data[1]); continue; @@ -203,30 +221,59 @@ impl HorizontalTransport for RedisClusterTransport { tokio::spawn(async move { if channel == broadcast_channel_clone { - // Handle broadcast message - if let Ok(broadcast) = serde_json::from_str::(&payload) { + // Handle broadcast message - deserialize from binary + if let Ok((binary_msg, _)) = + bincode::decode_from_slice::( + &payload_bytes, + bincode_config(), + ) + { + let broadcast: BroadcastMessage = binary_msg.into(); broadcast_handler(broadcast).await; } } else if channel == request_channel_clone { - // Handle request message - if let Ok(request) = serde_json::from_str::(&payload) { + // Handle request message - deserialize from binary + if let Ok((binary_req, _)) = + bincode::decode_from_slice::( + &payload_bytes, + bincode_config(), + ) + && let Ok(request) = RequestBody::try_from(binary_req) + { let response_result = request_handler(request).await; - if let Ok(response) = response_result - && let Ok(response_json) = serde_json::to_string(&response) - { - // Use client's connection pooling for response publishing - if let Ok(mut conn) = client_clone.get_async_connection().await { - let _ = conn - .publish::<_, _, ()>(&response_channel_clone, response_json) - .await; + if let Ok(response) = response_result { + // Serialize response to binary + if let Ok(binary_resp) = BinaryResponseBody::try_from(response) + && let Ok(response_bytes) = + bincode::encode_to_vec(&binary_resp, bincode_config()) + { + // Use client's connection pooling for response publishing + if let Ok(mut conn) = client_clone.get_async_connection().await + { + let _ = conn + .publish::<_, _, ()>( + &response_channel_clone, + response_bytes.as_slice(), + ) + .await; + } } } } } else if channel == response_channel_clone { - // Handle response message - if let Ok(response) = serde_json::from_str::(&payload) { - response_handler(response).await; + // Handle response message - deserialize from binary + if let Ok((binary_resp, _)) = + bincode::decode_from_slice::( + &payload_bytes, + bincode_config(), + ) + { + if let Ok(response) = ResponseBody::try_from(binary_resp) { + response_handler(response).await; + } + } else { + warn!("Failed to parse binary response message"); } } }); diff --git a/src/adapter/transports/redis_transport.rs b/src/adapter/transports/redis_transport.rs index cd40dfcd..87d59967 100644 --- a/src/adapter/transports/redis_transport.rs +++ b/src/adapter/transports/redis_transport.rs @@ -1,3 +1,6 @@ +use 
crate::adapter::binary_protocol::{ + BinaryBroadcastMessage, BinaryRequestBody, BinaryResponseBody, bincode_config, +}; use crate::adapter::horizontal_adapter::{BroadcastMessage, RequestBody, ResponseBody}; use crate::adapter::horizontal_transport::{ HorizontalTransport, TransportConfig, TransportHandlers, @@ -94,7 +97,10 @@ impl HorizontalTransport for RedisTransport { } async fn publish_broadcast(&self, message: &BroadcastMessage) -> Result<()> { - let broadcast_json = serde_json::to_string(message)?; + // Convert to binary format + let binary_msg: BinaryBroadcastMessage = message.clone().into(); + let broadcast_bytes = bincode::encode_to_vec(&binary_msg, bincode_config()) + .map_err(|e| Error::Other(format!("Failed to serialize broadcast: {}", e)))?; // Retry broadcast with exponential backoff to handle connection recovery let mut retry_delay = 100u64; // Start with 100ms @@ -104,7 +110,7 @@ impl HorizontalTransport for RedisTransport { for attempt in 0..=MAX_RETRIES { let mut conn = self.events_connection.clone(); match conn - .publish::<_, _, i32>(&self.broadcast_channel, &broadcast_json) + .publish::<_, _, i32>(&self.broadcast_channel, broadcast_bytes.as_slice()) .await { Ok(_subscriber_count) => { @@ -141,12 +147,14 @@ impl HorizontalTransport for RedisTransport { } async fn publish_request(&self, request: &RequestBody) -> Result<()> { - let request_json = serde_json::to_string(request) - .map_err(|e| Error::Other(format!("Failed to serialize request: {e}")))?; + // Convert to binary format + let binary_req: BinaryRequestBody = request.clone().try_into()?; + let request_bytes = bincode::encode_to_vec(&binary_req, bincode_config()) + .map_err(|e| Error::Other(format!("Failed to serialize request: {}", e)))?; let mut conn = self.connection.clone(); let subscriber_count: i32 = conn - .publish(&self.request_channel, &request_json) + .publish(&self.request_channel, request_bytes.as_slice()) .await .map_err(|e| Error::Redis(format!("Failed to publish request: {e}")))?; @@ -158,12 +166,14 @@ impl HorizontalTransport for RedisTransport { } async fn publish_response(&self, response: &ResponseBody) -> Result<()> { - let response_json = serde_json::to_string(response) - .map_err(|e| Error::Other(format!("Failed to serialize response: {e}")))?; + // Convert to binary format + let binary_resp: BinaryResponseBody = response.clone().try_into()?; + let response_bytes = bincode::encode_to_vec(&binary_resp, bincode_config()) + .map_err(|e| Error::Other(format!("Failed to serialize response: {}", e)))?; let mut conn = self.connection.clone(); let _: () = conn - .publish(&self.response_channel, response_json) + .publish(&self.response_channel, response_bytes.as_slice()) .await .map_err(|e| Error::Redis(format!("Failed to publish response: {e}")))?; @@ -224,7 +234,7 @@ impl HorizontalTransport for RedisTransport { while let Some(msg) = message_stream.next().await { let channel: String = msg.get_channel_name().to_string(); - let payload_result: redis::RedisResult = msg.get_payload(); + let payload_result: redis::RedisResult> = msg.get_payload(); if let Ok(payload) = payload_result { let broadcast_handler = handlers.on_broadcast.clone(); @@ -237,36 +247,59 @@ impl HorizontalTransport for RedisTransport { tokio::spawn(async move { if channel == broadcast_channel_clone { - // Handle broadcast message - if let Ok(broadcast) = - serde_json::from_str::(&payload) + // Handle broadcast message - deserialize from binary + if let Ok((binary_msg, _)) = + bincode::decode_from_slice::( + &payload, + bincode_config(), 
+ ) { + let broadcast: BroadcastMessage = binary_msg.into(); broadcast_handler(broadcast).await; } } else if channel == request_channel_clone { - // Handle request message - if let Ok(request) = serde_json::from_str::(&payload) { + // Handle request message - deserialize from binary + if let Ok((binary_req, _)) = + bincode::decode_from_slice::( + &payload, + bincode_config(), + ) + && let Ok(request) = RequestBody::try_from(binary_req) + { let response_result = request_handler(request).await; - if let Ok(response) = response_result - && let Ok(response_json) = serde_json::to_string(&response) - { - let mut conn = pub_connection_clone.clone(); - let _ = conn - .publish::<_, _, ()>( - &response_channel_clone, - response_json, + if let Ok(response) = response_result { + // Serialize response to binary + if let Ok(binary_resp) = + BinaryResponseBody::try_from(response) + && let Ok(response_bytes) = bincode::encode_to_vec( + &binary_resp, + bincode_config(), ) - .await; + { + let mut conn = pub_connection_clone.clone(); + let _ = conn + .publish::<_, _, ()>( + &response_channel_clone, + response_bytes.as_slice(), + ) + .await; + } } } } else if channel == response_channel_clone { - // Handle response message - if let Ok(response) = serde_json::from_str::(&payload) + // Handle response message - deserialize from binary + if let Ok((binary_resp, _)) = + bincode::decode_from_slice::( + &payload, + bincode_config(), + ) { - response_handler(response).await; + if let Ok(response) = ResponseBody::try_from(binary_resp) { + response_handler(response).await; + } } else { - warn!("Failed to parse response message: {}", payload); + warn!("Failed to parse binary response message"); } } }); diff --git a/tests/binary_protocol_test.rs b/tests/binary_protocol_test.rs new file mode 100644 index 00000000..9d17efe1 --- /dev/null +++ b/tests/binary_protocol_test.rs @@ -0,0 +1,230 @@ +use sockudo::adapter::binary_protocol::*; +use sockudo::adapter::horizontal_adapter::{ + BroadcastMessage, RequestBody, RequestType, ResponseBody, +}; +use std::collections::{HashMap, HashSet}; + +#[test] +fn test_broadcast_message_binary_conversion() { + let original = BroadcastMessage { + node_id: "550e8400-e29b-41d4-a716-446655440000".to_string(), + app_id: "test-app".to_string(), + channel: "test-channel".to_string(), + message: r#"{"event":"pusher:test","data":{"message":"hello"}}"#.to_string(), + except_socket_id: Some("socket-123".to_string()), + timestamp_ms: Some(1234567890.123456), + }; + + // Convert to binary + let binary: BinaryBroadcastMessage = original.clone().into(); + + // Verify binary structure + assert_eq!(binary.version, BINARY_PROTOCOL_VERSION); + assert_eq!(binary.app_id, original.app_id); + assert_eq!(binary.channel, original.channel); + assert_eq!(binary.except_socket_id, original.except_socket_id); + assert_eq!(binary.timestamp_ms, original.timestamp_ms); + + // Verify raw JSON is preserved + assert_eq!( + String::from_utf8_lossy(&binary.raw_client_json), + original.message + ); + + // Convert back + let recovered: BroadcastMessage = binary.into(); + assert_eq!(recovered.node_id, original.node_id); + assert_eq!(recovered.app_id, original.app_id); + assert_eq!(recovered.channel, original.channel); + assert_eq!(recovered.message, original.message); + assert_eq!(recovered.except_socket_id, original.except_socket_id); + assert_eq!(recovered.timestamp_ms, original.timestamp_ms); +} + +#[test] +fn test_request_body_binary_conversion() { + let original = RequestBody { + request_id: 
"123e4567-e89b-12d3-a456-426614174000".to_string(), + node_id: "550e8400-e29b-41d4-a716-446655440000".to_string(), + app_id: "test-app".to_string(), + request_type: RequestType::ChannelMembers, + channel: Some("presence-room".to_string()), + socket_id: Some("socket-456".to_string()), + user_id: Some("user-789".to_string()), + user_info: Some(serde_json::json!({"name": "Test User"})), + timestamp: Some(1234567890), + dead_node_id: None, + target_node_id: None, + }; + + // Convert to binary + let binary: BinaryRequestBody = original.clone().try_into().unwrap(); + + // Verify binary structure + assert_eq!(binary.version, BINARY_PROTOCOL_VERSION); + assert_eq!(binary.app_id, original.app_id); + // Note: request_type is now stored as request_type_discriminant (u8) + assert_eq!(binary.channel, original.channel); + assert!(binary.channel_hash.is_some()); + + // Convert back + let recovered: RequestBody = binary.try_into().unwrap(); + assert_eq!(recovered.request_id, original.request_id); + assert_eq!(recovered.node_id, original.node_id); + assert_eq!(recovered.app_id, original.app_id); + assert_eq!(recovered.channel, original.channel); + assert_eq!(recovered.user_id, original.user_id); +} + +#[test] +fn test_response_body_binary_conversion() { + let mut members = HashMap::new(); + members.insert( + "user-1".to_string(), + sockudo::channel::PresenceMemberInfo { + user_id: "user-1".to_string(), + user_info: Option::from(serde_json::json!({"name": "User 1"})), + }, + ); + + let mut channels = HashSet::new(); + channels.insert("channel-1".to_string()); + channels.insert("channel-2".to_string()); + + let original = ResponseBody { + request_id: "123e4567-e89b-12d3-a456-426614174000".to_string(), + node_id: "550e8400-e29b-41d4-a716-446655440000".to_string(), + app_id: "test-app".to_string(), + members: members.clone(), + channels_with_sockets_count: HashMap::new(), + socket_ids: vec!["socket-1".to_string(), "socket-2".to_string()], + sockets_count: 2, + exists: true, + channels: channels.clone(), + members_count: 1, + }; + + // Convert to binary + let binary: BinaryResponseBody = original.clone().try_into().unwrap(); + + // Verify binary structure + assert_eq!(binary.version, BINARY_PROTOCOL_VERSION); + assert_eq!(binary.app_id, original.app_id); + assert_eq!(binary.sockets_count, original.sockets_count); + assert_eq!(binary.exists, original.exists); + + // Convert back + let recovered: ResponseBody = binary.try_into().unwrap(); + assert_eq!(recovered.request_id, original.request_id); + assert_eq!(recovered.node_id, original.node_id); + assert_eq!(recovered.app_id, original.app_id); + assert_eq!(recovered.socket_ids, original.socket_ids); + assert_eq!(recovered.sockets_count, original.sockets_count); + assert_eq!(recovered.exists, original.exists); + assert_eq!(recovered.channels, original.channels); + assert_eq!(recovered.members.len(), original.members.len()); +} + +#[test] +fn test_bincode_size_reduction() { + let message = BroadcastMessage { + node_id: "550e8400-e29b-41d4-a716-446655440000".to_string(), + app_id: "test-app".to_string(), + channel: "test-channel".to_string(), + message: r#"{"event":"pusher:test","data":"small payload"}"#.to_string(), + except_socket_id: None, + timestamp_ms: Some(1234567890.123), + }; + + // Serialize with JSON + let json_bytes = serde_json::to_vec(&message).unwrap(); + + // Serialize with bincode + let binary_msg: BinaryBroadcastMessage = message.into(); + let binary_bytes = bincode::encode_to_vec(&binary_msg, bincode::config::standard()).unwrap(); + + // Print 
sizes for comparison
+    println!("JSON size: {} bytes", json_bytes.len());
+    println!("Binary size: {} bytes", binary_bytes.len());
+
+    // Binary should be competitive or better
+    // Note: For very small messages, overhead might make binary slightly larger,
+    // but for typical messages with longer payloads, binary will be smaller
+}
+
+#[test]
+fn test_channel_hash_consistency() {
+    let channel1 = "private-user-123";
+    let channel2 = "private-user-123";
+    let channel3 = "private-user-456";
+
+    let hash1 = calculate_channel_hash(channel1);
+    let hash2 = calculate_channel_hash(channel2);
+    let hash3 = calculate_channel_hash(channel3);
+
+    // Same channel should produce same hash
+    assert_eq!(hash1, hash2);
+
+    // Different channels should produce different hashes
+    assert_ne!(hash1, hash3);
+}
+
+#[test]
+fn test_large_message_handling() {
+    // Create a large message with a big JSON payload
+    let large_payload = format!(
+        r#"{{"event":"pusher:test","data":"{}"}}"#,
+        "x".repeat(1024 * 100) // 100KB payload
+    );
+
+    let message = BroadcastMessage {
+        node_id: "550e8400-e29b-41d4-a716-446655440000".to_string(),
+        app_id: "test-app".to_string(),
+        channel: "test-channel".to_string(),
+        message: large_payload.clone(),
+        except_socket_id: None,
+        timestamp_ms: Some(1234567890.123),
+    };
+
+    // Convert to binary
+    let binary: BinaryBroadcastMessage = message.clone().into();
+    let serialized = bincode::encode_to_vec(&binary, bincode::config::standard()).unwrap();
+
+    // Verify it's under the size limit
+    assert!(serialized.len() < MAX_MESSAGE_SIZE as usize);
+
+    // Deserialize and verify
+    let (deserialized, _): (BinaryBroadcastMessage, usize) =
+        bincode::decode_from_slice(&serialized, bincode::config::standard()).unwrap();
+    let recovered: BroadcastMessage = deserialized.into();
+
+    assert_eq!(recovered.message, message.message);
+}
+
+#[test]
+fn test_no_double_serialization() {
+    // This test verifies that we're not re-parsing the client JSON
+    let client_json = r#"{"event":"pusher:test","data":{"nested":{"deep":"value"}}}"#;
+
+    let message = BroadcastMessage {
+        node_id: "550e8400-e29b-41d4-a716-446655440000".to_string(),
+        app_id: "test-app".to_string(),
+        channel: "test-channel".to_string(),
+        message: client_json.to_string(),
+        except_socket_id: None,
+        timestamp_ms: Some(1234567890.123),
+    };
+
+    // Convert to binary
+    let binary: BinaryBroadcastMessage = message.clone().into();
+
+    // The raw_client_json should be exactly the input bytes
+    assert_eq!(
+        String::from_utf8(binary.raw_client_json.clone()).unwrap(),
+        client_json
+    );
+
+    // When we recover it, we should get the exact same JSON
+    let recovered: BroadcastMessage = binary.into();
+    assert_eq!(recovered.message, client_json);
+}
diff --git a/tests/message_data_parse_test.rs b/tests/message_data_parse_test.rs
new file mode 100644
index 00000000..91767936
--- /dev/null
+++ b/tests/message_data_parse_test.rs
@@ -0,0 +1,70 @@
+use sockudo::protocol::messages::PusherMessage;
+
+#[test]
+fn test_subscribe_message_parsing() {
+    // Test case 1: Simple subscription
+    let json1 = r#"{"event":"pusher:subscribe","data":{"channel":"test-channel"}}"#;
+    let result1 = serde_json::from_str::<PusherMessage>(json1);
+    assert!(
+        result1.is_ok(),
+        "Failed to parse simple subscription: {:?}",
+        result1.err()
+    );
+
+    // Test case 2: Subscription with auth
+    let json2 = r#"{"event":"pusher:subscribe","data":{"channel":"private-channel","auth":"app-key:signature"}}"#;
+    let result2 = serde_json::from_str::<PusherMessage>(json2);
+    assert!(
+        result2.is_ok(),
+        "Failed to parse subscription with auth: {:?}",
+        result2.err()
+    );
+
+    // Test case 3: Subscription with auth and channel_data
+    let json3 = r#"{"event":"pusher:subscribe","data":{"channel":"presence-channel","auth":"app-key:signature","channel_data":"{\"user_id\":\"123\"}"}}"#;
+    let result3 = serde_json::from_str::<PusherMessage>(json3);
+    assert!(
+        result3.is_ok(),
+        "Failed to parse subscription with channel_data: {:?}",
+        result3.err()
+    );
+
+    // Test case 4: Long auth string (similar to error message)
+    let json4 = r#"{"event":"pusher:subscribe","data":{"channel":"private-channel","auth":"app-key:very-long-hmac-signature-string-to-reach-approximately-column-146-position-with-secret"}}"#;
+    let result4 = serde_json::from_str::<PusherMessage>(json4);
+    assert!(
+        result4.is_ok(),
+        "Failed to parse subscription with long auth: {:?}",
+        result4.err()
+    );
+}
+
+#[test]
+fn test_message_data_variants() {
+    // Test string data
+    let json1 = r#"{"event":"test","data":"simple string"}"#;
+    let result1 = serde_json::from_str::<PusherMessage>(json1);
+    assert!(
+        result1.is_ok(),
+        "Failed to parse string data: {:?}",
+        result1.err()
+    );
+
+    // Test JSON object data
+    let json2 = r#"{"event":"test","data":{"key":"value"}}"#;
+    let result2 = serde_json::from_str::<PusherMessage>(json2);
+    assert!(
+        result2.is_ok(),
+        "Failed to parse JSON object data: {:?}",
+        result2.err()
+    );
+
+    // Test nested JSON data
+    let json3 = r#"{"event":"test","data":{"nested":{"deep":"value"}}}"#;
+    let result3 = serde_json::from_str::<PusherMessage>(json3);
+    assert!(
+        result3.is_ok(),
+        "Failed to parse nested JSON data: {:?}",
+        result3.err()
+    );
+}
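Illustrative sketch (not part of the patch): the round trip a transport now performs with the binary envelope, using only items this patch introduces (BinaryBroadcastMessage, bincode_config, BINARY_PROTOCOL_VERSION) and the import paths used in tests/binary_protocol_test.rs. The standalone main and the field values are invented for demonstration.

use sockudo::adapter::binary_protocol::{
    BINARY_PROTOCOL_VERSION, BinaryBroadcastMessage, bincode_config,
};
use sockudo::adapter::horizontal_adapter::BroadcastMessage;

fn main() {
    let original = BroadcastMessage {
        node_id: "550e8400-e29b-41d4-a716-446655440000".to_string(),
        app_id: "demo-app".to_string(),
        channel: "demo-channel".to_string(),
        message: r#"{"event":"pusher:test","data":"payload"}"#.to_string(),
        except_socket_id: None,
        timestamp_ms: None,
    };

    // Publishing side: wrap the message in the versioned binary envelope and encode it.
    let envelope: BinaryBroadcastMessage = original.clone().into();
    let wire_bytes = bincode::encode_to_vec(&envelope, bincode_config())
        .expect("encode broadcast envelope");

    // Subscribing side: decode the envelope, check the version, convert back.
    // The raw client JSON is carried through byte-for-byte, never re-parsed.
    let (decoded, _read): (BinaryBroadcastMessage, usize) =
        bincode::decode_from_slice(&wire_bytes, bincode_config())
            .expect("decode broadcast envelope");
    // Hypothetical policy for rolling upgrades: a consumer could reject unknown versions here.
    assert_eq!(decoded.version, BINARY_PROTOCOL_VERSION);
    let recovered: BroadcastMessage = decoded.into();
    assert_eq!(recovered.message, original.message);
}

The same pattern applies to BinaryRequestBody and BinaryResponseBody, except those conversions go through TryFrom because UUID parsing and user_info/members JSON serialization can fail.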