diff --git a/.gitignore b/.gitignore index e64a7689..60e2afb7 100644 --- a/.gitignore +++ b/.gitignore @@ -151,3 +151,7 @@ $RECYCLE.BIN/ _NCrunch* glide-logs/ + +# Test results and coverage reports +testresults/ +reports/ diff --git a/docs/configuration-architecture-analysis.md b/docs/configuration-architecture-analysis.md new file mode 100644 index 00000000..6431cfed --- /dev/null +++ b/docs/configuration-architecture-analysis.md @@ -0,0 +1,170 @@ +# Configuration Architecture Analysis + +## Overview + +This document analyzes the configuration architecture in the Valkey.Glide C# client, focusing on the relationship between `ConnectionConfiguration` and `ConfigurationOptions`, and how configuration changes can be made through the `ConnectionMultiplexer`. + +## Configuration Classes Relationship + +### ConfigurationOptions +- **Purpose**: External API configuration class that follows StackExchange.Redis compatibility patterns +- **Location**: `sources/Valkey.Glide/Abstract/ConfigurationOptions.cs` +- **Role**: User-facing configuration interface + +### ConnectionConfiguration +- **Purpose**: Internal configuration classes that map to the underlying FFI layer +- **Location**: `sources/Valkey.Glide/ConnectionConfiguration.cs` +- **Role**: Internal configuration representation and builder pattern implementation + +## Configuration Flow + +``` +ConfigurationOptions → ClientConfigurationBuilder → ConnectionConfig → FFI.ConnectionConfig +``` + +1. **User Input**: `ConfigurationOptions` (external API) +2. **Translation**: `ConnectionMultiplexer.CreateClientConfigBuilder()` method +3. **Building**: `ClientConfigurationBuilder` (internal) +4. **Internal Config**: `ConnectionConfig` record +5. **FFI Layer**: `FFI.ConnectionConfig` + +## Key Components Analysis + +### ConnectionMultiplexer Configuration Mapping + +The `ConnectionMultiplexer.CreateClientConfigBuilder()` method at line 174 performs the critical translation: + +```csharp +internal static T CreateClientConfigBuilder(ConfigurationOptions configuration) + where T : ClientConfigurationBuilder, new() +{ + T config = new(); + foreach (EndPoint ep in configuration.EndPoints) + { + config.Addresses += Utils.SplitEndpoint(ep); + } + config.UseTls = configuration.Ssl; + // ... other mappings + _ = configuration.ReadFrom.HasValue ? config.ReadFrom = configuration.ReadFrom.Value : new(); + return config; +} +``` + +### Configuration Builders + +The builder pattern is implemented through: +- `StandaloneClientConfigurationBuilder` (line 525) +- `ClusterClientConfigurationBuilder` (line 550) + +Both inherit from `ClientConfigurationBuilder` which provides: +- Fluent API methods (`WithXxx()`) +- Property setters +- Internal `ConnectionConfig Build()` method + +## Configuration Mutability Analysis + +### Current State: Immutable After Connection + +**Connection Creation**: Configuration is set once during `ConnectionMultiplexer.ConnectAsync()`: + +```csharp +public static async Task ConnectAsync(ConfigurationOptions configuration, TextWriter? log = null) +{ + // Configuration is translated and used to create the client + StandaloneClientConfiguration standaloneConfig = CreateClientConfigBuilder(configuration).Build(); + // ... connection establishment + return new(configuration, await Database.Create(config)); +} +``` + +**Storage**: The original `ConfigurationOptions` is stored in `RawConfig` property (line 156): + +```csharp +internal ConfigurationOptions RawConfig { private set; get; } +``` + +### Limitations for Runtime Configuration Changes + +1. **No Reconfiguration API**: `ConnectionMultiplexer` doesn't expose methods to change configuration after connection +2. **Immutable Builder Chain**: Once built, the configuration flows to FFI layer and cannot be modified +3. **Connection Recreation Required**: Any configuration change requires creating a new `ConnectionMultiplexer` instance + +## Potential Configuration Change Approaches + +### 1. Connection Recreation (Current Pattern) +```csharp +// Current approach - requires new connection +var newConfig = oldConfig.Clone(); +newConfig.ReadFrom = new ReadFrom(ReadFromStrategy.AzAffinity, "us-west-2"); +var newMultiplexer = await ConnectionMultiplexer.ConnectAsync(newConfig); +``` + +### 2. Potential Runtime Reconfiguration (Not Currently Implemented) + +To enable runtime configuration changes, the following would need to be implemented: + +```csharp +// Hypothetical API +public async Task ReconfigureAsync(Action configure) +{ + var newConfig = RawConfig.Clone(); + configure(newConfig); + + // Would need to: + // 1. Validate configuration changes + // 2. Update underlying client configuration + // 3. Potentially recreate connections + // 4. Update RawConfig +} +``` + +### 3. Builder Pattern Extension + +A potential approach could extend the builder pattern to support updates: + +```csharp +// Hypothetical API +public async Task TryUpdateConfigurationAsync(Action configure) + where T : ClientConfigurationBuilder, new() +{ + // Create new builder from current configuration + // Apply changes + // Validate and apply if possible +} +``` + +## ReadFrom Configuration Specifics + +### Current Implementation +- `ReadFrom` is a struct (line 74) with `ReadFromStrategy` enum and optional AZ string +- Mapped in `CreateClientConfigBuilder()` at line 199 +- Flows through to FFI layer via `ConnectionConfig.ToFfi()` method + +### ReadFrom Change Requirements +To change `ReadFrom` configuration at runtime would require: +1. **API Design**: Method to accept new `ReadFrom` configuration +2. **Validation**: Ensure new configuration is compatible with current connection type +3. **FFI Updates**: Update the underlying client configuration +4. **Connection Management**: Handle any required connection reestablishment + +## Recommendations + +### Short Term +1. **Document Current Limitations**: Clearly document that configuration changes require connection recreation +2. **Helper Methods**: Provide utility methods for common reconfiguration scenarios: + ```csharp + public static async Task RecreateWithReadFromAsync( + ConnectionMultiplexer current, + ReadFrom newReadFrom) + ``` + +### Long Term +1. **Runtime Reconfiguration API**: Implement selective runtime configuration updates for non-disruptive changes +2. **Configuration Validation**: Add validation to determine which changes require reconnection vs. runtime updates +3. **Connection Pool Management**: Consider connection pooling to minimize disruption during reconfiguration + +## Conclusion + +Currently, the `ConnectionMultiplexer` does not support runtime configuration changes. The architecture is designed around immutable configuration set at connection time. Any configuration changes, including `ReadFrom` strategy modifications, require creating a new `ConnectionMultiplexer` instance. + +The relationship between `ConfigurationOptions` and `ConnectionConfiguration` is a translation layer where the external API (`ConfigurationOptions`) is converted to internal configuration structures (`ConnectionConfiguration`) that interface with the FFI layer. diff --git a/monitor-valkey.sh b/monitor-valkey.sh new file mode 100755 index 00000000..09aba107 --- /dev/null +++ b/monitor-valkey.sh @@ -0,0 +1,19 @@ +#!/bin/zsh +# Streams all commands from Valkey and logs them to a file with timestamps in ./log directory. + +LOGDIR="./log" +LOGFILE="$LOGDIR/valkey-monitor.log" + +# Ensure directory exists +if [ ! -d "$LOGDIR" ]; then + mkdir -p "$LOGDIR" +fi +# Ensure log file exists +if [ ! -f "$LOGFILE" ]; then + touch "$LOGFILE" +fi + +# Run MONITOR and prepend timestamps using date +valkey-cli MONITOR | while read -r line; do + echo "$(date '+[%Y-%m-%d %H:%M:%S]') $line" +done >> "$LOGFILE" diff --git a/rust/src/ffi.rs b/rust/src/ffi.rs index 1ea81a44..1b208cd2 100644 --- a/rust/src/ffi.rs +++ b/rust/src/ffi.rs @@ -71,16 +71,95 @@ pub struct ConnectionConfig { pub protocol: redis::ProtocolVersion, /// zero pointer is valid, means no client name is given (`None`) pub client_name: *const c_char, + pub has_pubsub_config: bool, + pub pubsub_config: PubSubConfigInfo, /* TODO below pub periodic_checks: Option, - pub pubsub_subscriptions: Option, pub inflight_requests_limit: Option, pub otel_endpoint: Option, pub otel_flush_interval_ms: Option, */ } +#[repr(C)] +#[derive(Debug, Clone, Copy)] +pub struct PubSubConfigInfo { + pub channels_ptr: *const *const c_char, + pub channel_count: u32, + pub patterns_ptr: *const *const c_char, + pub pattern_count: u32, + pub sharded_channels_ptr: *const *const c_char, + pub sharded_channel_count: u32, +} + +/// Convert a C string array to a Vec of Vec +/// +/// # Safety +/// +/// * `ptr` must point to an array of `count` valid C string pointers +/// * Each C string pointer must be valid and null-terminated +unsafe fn convert_string_array(ptr: *const *const c_char, count: u32) -> Vec> { + if ptr.is_null() || count == 0 { + return Vec::new(); + } + + let slice = unsafe { std::slice::from_raw_parts(ptr, count as usize) }; + slice + .iter() + .map(|&str_ptr| { + let c_str = unsafe { CStr::from_ptr(str_ptr) }; + c_str.to_bytes().to_vec() + }) + .collect() +} + +/// Convert PubSubConfigInfo to the format expected by glide-core +/// +/// # Safety +/// +/// * All pointers in `config` must be valid or null +/// * String arrays must contain valid C strings +unsafe fn convert_pubsub_config( + config: &PubSubConfigInfo, +) -> std::collections::HashMap>> { + use redis::PubSubSubscriptionKind; + use std::collections::{HashMap, HashSet}; + + let mut subscriptions = HashMap::new(); + + // Convert exact channels + if config.channel_count > 0 { + let channels = unsafe { convert_string_array(config.channels_ptr, config.channel_count) }; + subscriptions.insert( + PubSubSubscriptionKind::Exact, + channels.into_iter().collect::>(), + ); + } + + // Convert patterns + if config.pattern_count > 0 { + let patterns = unsafe { convert_string_array(config.patterns_ptr, config.pattern_count) }; + subscriptions.insert( + PubSubSubscriptionKind::Pattern, + patterns.into_iter().collect::>(), + ); + } + + // Convert sharded channels + if config.sharded_channel_count > 0 { + let sharded = unsafe { + convert_string_array(config.sharded_channels_ptr, config.sharded_channel_count) + }; + subscriptions.insert( + PubSubSubscriptionKind::Sharded, + sharded.into_iter().collect::>(), + ); + } + + subscriptions +} + /// Convert connection configuration to a corresponding object. /// /// # Safety @@ -147,9 +226,18 @@ pub(crate) unsafe fn create_connection_request( } else { None }, + pubsub_subscriptions: if config.has_pubsub_config { + let subscriptions = unsafe { convert_pubsub_config(&config.pubsub_config) }; + if subscriptions.is_empty() { + None + } else { + Some(subscriptions) + } + } else { + None + }, // TODO below periodic_checks: None, - pubsub_subscriptions: None, inflight_requests_limit: None, lazy_connect: false, } @@ -593,3 +681,29 @@ pub(crate) unsafe fn get_pipeline_options( PipelineRetryStrategy::new(info.retry_server_error, info.retry_connection_error), ) } + +/// FFI callback function type for PubSub messages. +/// This callback is invoked by Rust when a PubSub message is received. +/// The callback signature matches the C# expectations for marshaling PubSub data. +/// +/// # Parameters +/// * `push_kind` - The type of push notification (message, pmessage, smessage, etc.) +/// * `message_ptr` - Pointer to the raw message bytes +/// * `message_len` - Length of the message data in bytes +/// * `channel_ptr` - Pointer to the raw channel name bytes +/// * `channel_len` - Length of the channel name in bytes +/// * `pattern_ptr` - Pointer to the raw pattern bytes (null if no pattern) +/// * `pattern_len` - Length of the pattern in bytes (0 if no pattern) +pub type PubSubCallback = unsafe extern "C" fn( + push_kind: u32, + message_ptr: *const u8, + message_len: i64, + channel_ptr: *const u8, + channel_len: i64, + pattern_ptr: *const u8, + pattern_len: i64, +); + +// PubSub callback functions removed - using instance-based callbacks instead. +// The pubsub_callback parameter in create_client will be used to configure glide-core's +// PubSub message handler when full integration is implemented. diff --git a/rust/src/lib.rs b/rust/src/lib.rs index 9b857a75..12ca4616 100644 --- a/rust/src/lib.rs +++ b/rust/src/lib.rs @@ -2,8 +2,9 @@ mod ffi; use ffi::{ - BatchInfo, BatchOptionsInfo, CmdInfo, ConnectionConfig, ResponseValue, RouteInfo, create_cmd, - create_connection_request, create_pipeline, create_route, get_pipeline_options, + BatchInfo, BatchOptionsInfo, CmdInfo, ConnectionConfig, PubSubCallback, ResponseValue, + RouteInfo, create_cmd, create_connection_request, create_pipeline, create_route, + get_pipeline_options, }; use glide_core::{ client::Client as GlideClient, @@ -28,6 +29,8 @@ pub enum Level { pub struct Client { runtime: Runtime, core: Arc, + pubsub_shutdown: std::sync::Mutex>>, + pubsub_task: std::sync::Mutex>>, } /// Success callback that is called when a command succeeds. @@ -120,12 +123,15 @@ impl Drop for PanicGuard { /// * `config` must be a valid [`ConnectionConfig`] pointer. See the safety documentation of [`create_connection_request`]. /// * `success_callback` and `failure_callback` must be valid pointers to the corresponding FFI functions. /// See the safety documentation of [`SuccessCallback`] and [`FailureCallback`]. +/// * `pubsub_callback` is an optional callback. When provided, it must be a valid function pointer. +/// See the safety documentation in the FFI module for PubSubCallback. #[allow(rustdoc::private_intra_doc_links)] #[unsafe(no_mangle)] pub unsafe extern "C-unwind" fn create_client( config: *const ConnectionConfig, success_callback: SuccessCallback, failure_callback: FailureCallback, + #[allow(unused_variables)] pubsub_callback: Option, ) { let mut panic_guard = PanicGuard { panicked: true, @@ -134,6 +140,7 @@ pub unsafe extern "C-unwind" fn create_client( }; let request = unsafe { create_connection_request(config) }; + let runtime = Builder::new_multi_thread() .enable_all() .worker_threads(10) @@ -142,7 +149,15 @@ pub unsafe extern "C-unwind" fn create_client( .unwrap(); let _runtime_handle = runtime.enter(); - let res = runtime.block_on(GlideClient::new(request, None)); + + // Set up push notification channel if PubSub subscriptions are configured + // The callback is optional - users can use queue-based message retrieval instead + let is_subscriber = request.pubsub_subscriptions.is_some(); + + let (push_tx, mut push_rx) = tokio::sync::mpsc::unbounded_channel(); + let tx = if is_subscriber { Some(push_tx) } else { None }; + + let res = runtime.block_on(GlideClient::new(request, tx)); match res { Ok(client) => { let core = Arc::new(CommandExecutionCore { @@ -151,7 +166,56 @@ pub unsafe extern "C-unwind" fn create_client( client, }); - let client_ptr = Arc::into_raw(Arc::new(Client { runtime, core })); + // Set up graceful shutdown coordination for PubSub task + // Only spawn the callback task if a callback is provided + let (pubsub_shutdown, pubsub_task) = if is_subscriber && pubsub_callback.is_some() { + let callback = pubsub_callback.unwrap(); + let (shutdown_tx, mut shutdown_rx) = tokio::sync::oneshot::channel(); + + let task_handle = runtime.spawn(async move { + logger_core::log(logger_core::Level::Info, "pubsub", "PubSub task started"); + + loop { + tokio::select! { + Some(push_msg) = push_rx.recv() => { + unsafe { + process_push_notification(push_msg, callback); + } + } + _ = &mut shutdown_rx => { + logger_core::log( + logger_core::Level::Info, + "pubsub", + "PubSub task received shutdown signal", + ); + break; + } + } + } + + logger_core::log( + logger_core::Level::Info, + "pubsub", + "PubSub task completed gracefully", + ); + }); + + ( + std::sync::Mutex::new(Some(shutdown_tx)), + std::sync::Mutex::new(Some(task_handle)), + ) + } else { + (std::sync::Mutex::new(None), std::sync::Mutex::new(None)) + }; + + let client_adapter = Arc::new(Client { + runtime, + core, + pubsub_shutdown, + pubsub_task, + }); + let client_ptr = Arc::into_raw(client_adapter.clone()); + unsafe { success_callback(0, client_ptr as *const ResponseValue) }; } Err(err) => { @@ -170,10 +234,144 @@ pub unsafe extern "C-unwind" fn create_client( drop(panic_guard); } +/// Processes a push notification message and calls the provided callback function. +/// +/// This function extracts the message data from the PushInfo and invokes the C# callback +/// with the appropriate parameters using scoped lifetime management to prevent memory leaks. +/// +/// # Parameters +/// - `push_msg`: The push notification message to process. +/// - `pubsub_callback`: The callback function to invoke with the processed notification. +/// +/// # Safety +/// This function is unsafe because it: +/// - Calls an FFI function (`pubsub_callback`) that may have undefined behavior +/// - Assumes push_msg.data contains valid BulkString values +/// +/// The caller must ensure: +/// - `pubsub_callback` is a valid function pointer to a properly implemented callback +/// - The callback copies data synchronously before returning +/// +/// # Memory Safety +/// This implementation uses scoped lifetime management instead of `std::mem::forget()`. +/// Vec instances are kept alive during callback execution and automatically cleaned up +/// when the function exits, preventing memory leaks. +unsafe fn process_push_notification(push_msg: redis::PushInfo, pubsub_callback: PubSubCallback) { + use redis::Value; + + // Convert all values to Vec, handling both BulkString and Int types + let strings: Vec> = push_msg + .data + .into_iter() + .map(|value| match value { + Value::BulkString(bytes) => bytes, + Value::Int(num) => num.to_string().into_bytes(), + Value::SimpleString(s) => s.into_bytes(), + _ => { + logger_core::log( + logger_core::Level::Warn, + "pubsub", + &format!("Unexpected value type in PubSub message: {:?}", value), + ); + Vec::new() + } + }) + .collect(); + + // Store the kind to avoid move issues + let push_kind = push_msg.kind.clone(); + + // Validate message structure based on PushKind and convert to FFI kind + // These values MUST match the C# PushKind enum in FFI.structs.cs + let (pattern, channel, message, kind) = match (push_kind.clone(), strings.len()) { + (redis::PushKind::Message, 2) => { + // Regular message: [channel, message] -> PushMessage = 3 + (None, &strings[0], &strings[1], 3u32) + } + (redis::PushKind::PMessage, 3) => { + // Pattern message: [pattern, channel, message] -> PushPMessage = 4 + (Some(&strings[0]), &strings[1], &strings[2], 4u32) + } + (redis::PushKind::SMessage, 2) => { + // Sharded message: [channel, message] -> PushSMessage = 5 + (None, &strings[0], &strings[1], 5u32) + } + (redis::PushKind::Subscribe, 2) => { + // Subscribe confirmation: [channel, count] -> PushSubscribe = 9 + (None, &strings[0], &strings[1], 9u32) + } + (redis::PushKind::PSubscribe, 3) => { + // Pattern subscribe confirmation: [pattern, channel, count] -> PushPSubscribe = 10 + (Some(&strings[0]), &strings[1], &strings[2], 10u32) + } + (redis::PushKind::SSubscribe, 2) => { + // Sharded subscribe confirmation: [channel, count] -> PushSSubscribe = 11 + (None, &strings[0], &strings[1], 11u32) + } + (redis::PushKind::Unsubscribe, 2) => { + // Unsubscribe confirmation: [channel, count] -> PushUnsubscribe = 6 + (None, &strings[0], &strings[1], 6u32) + } + (redis::PushKind::PUnsubscribe, 3) => { + // Pattern unsubscribe confirmation: [pattern, channel, count] -> PushPUnsubscribe = 7 + (Some(&strings[0]), &strings[1], &strings[2], 7u32) + } + (redis::PushKind::SUnsubscribe, 2) => { + // Sharded unsubscribe confirmation: [channel, count] -> PushSUnsubscribe = 8 + (None, &strings[0], &strings[1], 8u32) + } + (redis::PushKind::Disconnection, _) => { + logger_core::log( + logger_core::Level::Info, + "pubsub", + "PubSub disconnection received", + ); + return; + } + (kind, len) => { + logger_core::log( + logger_core::Level::Error, + "pubsub", + &format!( + "Invalid PubSub message structure: kind={:?}, len={}", + kind, len + ), + ); + return; + } + }; + + // Prepare pointers while keeping strings alive + let pattern_ptr = pattern.map(|p| p.as_ptr()).unwrap_or(std::ptr::null()); + let pattern_len = pattern.map(|p| p.len() as i64).unwrap_or(0); + let channel_ptr = channel.as_ptr(); + let channel_len = channel.len() as i64; + let message_ptr = message.as_ptr(); + let message_len = message.len() as i64; + + // Call callback while strings are still alive + unsafe { + pubsub_callback( + kind, + message_ptr, + message_len, + channel_ptr, + channel_len, + pattern_ptr, + pattern_len, + ); + } + + // Vec instances are automatically cleaned up here + // No memory leak, no use-after-free +} + /// Closes the given client, deallocating it from the heap. /// This function should only be called once per pointer created by [`create_client`]. /// After calling this function the `client_ptr` is not in a valid state. /// +/// Implements graceful shutdown coordination for PubSub tasks with timeout. +/// /// # Safety /// /// * `client_ptr` must not be `null`. @@ -181,6 +379,71 @@ pub unsafe extern "C-unwind" fn create_client( #[unsafe(no_mangle)] pub extern "C" fn close_client(client_ptr: *const c_void) { assert!(!client_ptr.is_null()); + + // Get a reference to the client to access shutdown coordination + let client = unsafe { &*(client_ptr as *const Client) }; + + // Take ownership of shutdown sender and signal graceful shutdown + if let Ok(mut guard) = client.pubsub_shutdown.lock() { + if let Some(shutdown_tx) = guard.take() { + logger_core::log( + logger_core::Level::Debug, + "pubsub", + "Signaling PubSub task to shutdown", + ); + + // Send shutdown signal (ignore error if receiver already dropped) + let _ = shutdown_tx.send(()); + } + } + + // Take ownership of task handle and wait for completion with timeout + if let Ok(mut guard) = client.pubsub_task.lock() { + if let Some(task_handle) = guard.take() { + let timeout = std::time::Duration::from_secs(5); + + logger_core::log( + logger_core::Level::Debug, + "pubsub", + &format!( + "Waiting for PubSub task to complete (timeout: {:?})", + timeout + ), + ); + + let result = client + .runtime + .block_on(async { tokio::time::timeout(timeout, task_handle).await }); + + match result { + Ok(Ok(())) => { + logger_core::log( + logger_core::Level::Info, + "pubsub", + "PubSub task completed successfully", + ); + } + Ok(Err(e)) => { + logger_core::log( + logger_core::Level::Warn, + "pubsub", + &format!("PubSub task completed with error: {:?}", e), + ); + } + Err(_) => { + logger_core::log( + logger_core::Level::Warn, + "pubsub", + &format!( + "PubSub task did not complete within timeout ({:?})", + timeout + ), + ); + } + } + } + } + // This will bring the strong count down to 0 once all client requests are done. unsafe { Arc::decrement_strong_count(client_ptr as *const Client) }; } diff --git a/sources/Valkey.Glide/BaseClient.cs b/sources/Valkey.Glide/BaseClient.cs index 1bd68a45..8439b8ec 100644 --- a/sources/Valkey.Glide/BaseClient.cs +++ b/sources/Valkey.Glide/BaseClient.cs @@ -1,6 +1,7 @@ // Copyright Valkey GLIDE Project Contributors - SPDX Identifier: Apache-2.0 using System.Runtime.InteropServices; +using System.Threading.Channels; using Valkey.Glide.Internals; using Valkey.Glide.Pipeline; @@ -26,6 +27,10 @@ public void Dispose() { return; } + + // Clean up PubSub resources + CleanupPubSubResources(); + _messageContainer.DisposeWithError(null); CloseClientFfi(_clientPointer); _clientPointer = IntPtr.Zero; @@ -38,6 +43,28 @@ public void Dispose() public override int GetHashCode() => (int)_clientPointer; + /// + /// Get the PubSub message queue for manual message retrieval. + /// Returns null if no PubSub subscriptions are configured. + /// Uses thread-safe access to prevent race conditions. + /// + public PubSubMessageQueue? PubSubQueue + { + get + { + lock (_pubSubLock) + { + return _pubSubHandler?.GetQueue(); + } + } + } + + /// + /// Indicates whether this client has PubSub subscriptions configured. + /// Uses volatile read for thread-safe access without locking. + /// + public bool HasPubSubSubscriptions => _pubSubHandler != null; + #endregion public methods #region protected methods @@ -48,13 +75,23 @@ protected static async Task CreateClient(BaseClientConfiguration config, F nint successCallbackPointer = Marshal.GetFunctionPointerForDelegate(client._successCallbackDelegate); nint failureCallbackPointer = Marshal.GetFunctionPointerForDelegate(client._failureCallbackDelegate); + // Get PubSub callback pointer if PubSub subscriptions are configured + nint pubsubCallbackPointer = IntPtr.Zero; + if (config.Request.PubSubSubscriptions != null) + { + pubsubCallbackPointer = Marshal.GetFunctionPointerForDelegate(client._pubsubCallbackDelegate); + } + using FFI.ConnectionConfig request = config.Request.ToFfi(); Message message = client._messageContainer.GetMessageForCall(); - CreateClientFfi(request.ToPtr(), successCallbackPointer, failureCallbackPointer); + CreateClientFfi(request.ToPtr(), successCallbackPointer, failureCallbackPointer, pubsubCallbackPointer); client._clientPointer = await message; // This will throw an error thru failure callback if any if (client._clientPointer != IntPtr.Zero) { + // Initialize PubSub handler if subscriptions are configured + client.InitializePubSubHandler(config.Request.PubSubSubscriptions); + // Initialize server version after successful connection await client.InitializeServerVersionAsync(); return client; @@ -67,6 +104,7 @@ protected BaseClient() { _successCallbackDelegate = SuccessCallback; _failureCallbackDelegate = FailureCallback; + _pubsubCallbackDelegate = PubSubCallback; _messageContainer = new(this); } @@ -141,16 +179,297 @@ private void FailureCallback(ulong index, IntPtr strPtr, RequestErrorType errTyp _ = Task.Run(() => _messageContainer.GetMessage((int)index).SetException(Create(errType, str))); } + private void PubSubCallback( + uint pushKind, + IntPtr messagePtr, + long messageLen, + IntPtr channelPtr, + long channelLen, + IntPtr patternPtr, + long patternLen) + { + try + { + // Only process actual message notifications, ignore subscription confirmations + if (!IsMessageNotification((PushKind)pushKind)) + { + Logger.Log(Level.Debug, "PubSubCallback", $"PubSub notification received: {(PushKind)pushKind}"); + return; + } + + // Marshal the message from FFI callback parameters + PubSubMessage message = MarshalPubSubMessage( + (PushKind)pushKind, + messagePtr, + messageLen, + channelPtr, + channelLen, + patternPtr, + patternLen); + + // Write to channel (non-blocking with backpressure) + Channel? channel = _messageChannel; + if (channel != null) + { + if (!channel.Writer.TryWrite(message)) + { + Logger.Log(Level.Warn, "PubSubCallback", + $"PubSub message channel full, message dropped for channel {message.Channel}"); + } + } + } + catch (Exception ex) + { + Logger.Log(Level.Error, "PubSubCallback", + $"Error in PubSub callback: {ex.Message}", ex); + } + } + + private static bool IsMessageNotification(PushKind pushKind) => + pushKind switch + { + PushKind.PushMessage => true, // Regular channel message + PushKind.PushPMessage => true, // Pattern-based message + PushKind.PushSMessage => true, // Sharded channel message + PushKind.PushDisconnection => false, + PushKind.PushOther => false, + PushKind.PushInvalidate => false, + PushKind.PushUnsubscribe => false, + PushKind.PushPUnsubscribe => false, + PushKind.PushSUnsubscribe => false, + PushKind.PushSubscribe => false, + PushKind.PushPSubscribe => false, + PushKind.PushSSubscribe => false, + _ => false + }; + + private static PubSubMessage MarshalPubSubMessage( + PushKind pushKind, + IntPtr messagePtr, + long messageLen, + IntPtr channelPtr, + long channelLen, + IntPtr patternPtr, + long patternLen) + { + // Marshal the raw byte pointers to byte arrays + byte[] messageBytes = new byte[messageLen]; + Marshal.Copy(messagePtr, messageBytes, 0, (int)messageLen); + + byte[] channelBytes = new byte[channelLen]; + Marshal.Copy(channelPtr, channelBytes, 0, (int)channelLen); + + byte[]? patternBytes = null; + if (patternPtr != IntPtr.Zero && patternLen > 0) + { + patternBytes = new byte[patternLen]; + Marshal.Copy(patternPtr, patternBytes, 0, (int)patternLen); + } + + // Convert to strings (assuming UTF-8 encoding) + string message = System.Text.Encoding.UTF8.GetString(messageBytes); + string channel = System.Text.Encoding.UTF8.GetString(channelBytes); + string? pattern = patternBytes != null ? System.Text.Encoding.UTF8.GetString(patternBytes) : null; + + // Create the appropriate PubSubMessage based on whether pattern is present + return pattern != null + ? new PubSubMessage(message, channel, pattern) + : new PubSubMessage(message, channel); + } + ~BaseClient() => Dispose(); internal void SetInfo(string info) => _clientInfo = info; protected abstract Task InitializeServerVersionAsync(); + /// + /// Initializes PubSub message handling if PubSub subscriptions are configured. + /// Uses thread-safe initialization to ensure proper visibility across threads. + /// + /// The PubSub subscription configuration. + private void InitializePubSubHandler(BasePubSubSubscriptionConfig? config) + { + if (config == null) + { + return; + } + + lock (_pubSubLock) + { + // Get performance configuration or use defaults + PubSubPerformanceConfig perfConfig = config.PerformanceConfig ?? new(); + + // Store shutdown timeout for use during disposal + _shutdownTimeout = perfConfig.ShutdownTimeout; + + // Create bounded channel with configurable capacity and backpressure strategy + BoundedChannelOptions channelOptions = new(perfConfig.ChannelCapacity) + { + FullMode = perfConfig.FullMode, + SingleReader = true, // Optimization: only one processor task + SingleWriter = false // Multiple FFI callbacks may write + }; + + _messageChannel = Channel.CreateBounded(channelOptions); + _processingCancellation = new CancellationTokenSource(); + + // Create message handler + _pubSubHandler = new PubSubMessageHandler(config.Callback, config.Context); + + // Start dedicated processing task with graceful shutdown support + _messageProcessingTask = Task.Run(async () => + { + try + { + Logger.Log(Level.Debug, "BaseClient", "PubSub processing task started"); + + await foreach (PubSubMessage message in _messageChannel.Reader.ReadAllAsync(_processingCancellation.Token)) + { + try + { + // Thread-safe access to handler + PubSubMessageHandler? handler = _pubSubHandler; + if (handler != null && !_processingCancellation.Token.IsCancellationRequested) + { + handler.HandleMessage(message); + } + } + catch (Exception ex) + { + Logger.Log(Level.Error, "BaseClient", + $"Error processing PubSub message: {ex.Message}", ex); + } + } + + Logger.Log(Level.Debug, "BaseClient", "PubSub processing task completing normally"); + } + catch (OperationCanceledException) + { + Logger.Log(Level.Info, "BaseClient", "PubSub processing cancelled gracefully"); + } + catch (Exception ex) + { + Logger.Log(Level.Error, "BaseClient", + $"PubSub processing task failed: {ex.Message}", ex); + } + }, _processingCancellation.Token); + } + } + + /// + /// Handles incoming PubSub messages from the FFI layer. + /// This method is called directly by the FFI callback and uses thread-safe access to the handler. + /// + /// The PubSub message to handle. + internal virtual void HandlePubSubMessage(PubSubMessage message) + { + // Thread-safe access to handler - use local copy to avoid race conditions + PubSubMessageHandler? handler; + lock (_pubSubLock) + { + handler = _pubSubHandler; + } + + if (handler != null) + { + try + { + handler.HandleMessage(message); + } + catch (Exception ex) + { + // Log the error but don't let exceptions escape + Logger.Log(Level.Error, "BaseClient", $"Error handling PubSub message: {ex.Message}", ex); + } + } + } + + /// + /// Cleans up PubSub resources during client disposal with proper synchronization. + /// Uses locking to coordinate safe disposal and prevent conflicts with concurrent message processing. + /// Implements graceful shutdown with configurable timeout. + /// + private void CleanupPubSubResources() + { + PubSubMessageHandler? handler = null; + Channel? channel = null; + Task? processingTask = null; + CancellationTokenSource? cancellation = null; + TimeSpan shutdownTimeout = _shutdownTimeout; + + // Acquire lock and capture references, then set to null + lock (_pubSubLock) + { + handler = _pubSubHandler; + channel = _messageChannel; + processingTask = _messageProcessingTask; + cancellation = _processingCancellation; + + _pubSubHandler = null; + _messageChannel = null; + _messageProcessingTask = null; + _processingCancellation = null; + } + + // Cleanup outside of lock to prevent deadlocks + try + { + Logger.Log(Level.Debug, "BaseClient", "Initiating graceful PubSub shutdown"); + + // Signal shutdown to processing task + cancellation?.Cancel(); + + // Complete channel to stop message processing + // This will cause the ReadAllAsync to complete after processing remaining messages + channel?.Writer.Complete(); + + // Wait for processing task to complete (with timeout) + if (processingTask != null) + { + bool completed = processingTask.Wait(shutdownTimeout); + if (completed) + { + Logger.Log(Level.Info, "BaseClient", "PubSub processing task completed gracefully"); + } + else + { + Logger.Log(Level.Warn, "BaseClient", + $"PubSub processing task did not complete within timeout ({shutdownTimeout.TotalSeconds}s)"); + } + } + + // Dispose resources + handler?.Dispose(); + cancellation?.Dispose(); + + Logger.Log(Level.Debug, "BaseClient", "PubSub cleanup completed"); + } + catch (AggregateException ex) + { + Logger.Log(Level.Warn, "BaseClient", + $"Error during PubSub cleanup: {ex.InnerException?.Message ?? ex.Message}", ex); + } + catch (Exception ex) + { + Logger.Log(Level.Warn, "BaseClient", + $"Error during PubSub cleanup: {ex.Message}", ex); + } + } + [UnmanagedFunctionPointer(CallingConvention.Cdecl)] private delegate void SuccessAction(ulong index, IntPtr ptr); [UnmanagedFunctionPointer(CallingConvention.Cdecl)] private delegate void FailureAction(ulong index, IntPtr strPtr, RequestErrorType err); + [UnmanagedFunctionPointer(CallingConvention.Cdecl)] + private delegate void PubSubAction( + uint pushKind, + IntPtr messagePtr, + long messageLen, + IntPtr channelPtr, + long channelLen, + IntPtr patternPtr, + long patternLen); #endregion private methods #region private fields @@ -163,6 +482,10 @@ private void FailureCallback(ulong index, IntPtr strPtr, RequestErrorType errTyp /// and held in order to prevent the cost of marshalling on each function call. private readonly SuccessAction _successCallbackDelegate; + /// Held as a measure to prevent the delegate being garbage collected. These are delegated once + /// and held in order to prevent the cost of marshalling on each function call. + private readonly PubSubAction _pubsubCallbackDelegate; + /// Raw pointer to the underlying native client. private IntPtr _clientPointer; private readonly MessageContainer _messageContainer; @@ -170,5 +493,24 @@ private void FailureCallback(ulong index, IntPtr strPtr, RequestErrorType errTyp private string _clientInfo = ""; // used to distinguish and identify clients during tests protected Version? _serverVersion; // cached server version + /// PubSub message handler for routing messages to callbacks or queues. + /// Uses volatile to ensure visibility across threads without locking on every read. + private volatile PubSubMessageHandler? _pubSubHandler; + + /// Lock object for coordinating PubSub handler access and disposal. + private readonly object _pubSubLock = new(); + + /// Channel for bounded message queuing with backpressure support. + private Channel? _messageChannel; + + /// Dedicated background task for processing PubSub messages. + private Task? _messageProcessingTask; + + /// Cancellation token source for graceful shutdown of message processing. + private CancellationTokenSource? _processingCancellation; + + /// Timeout for graceful shutdown of PubSub processing. + private TimeSpan _shutdownTimeout = TimeSpan.FromSeconds(PubSubPerformanceConfig.DefaultShutdownTimeoutSeconds); + #endregion private fields } diff --git a/sources/Valkey.Glide/ConnectionConfiguration.cs b/sources/Valkey.Glide/ConnectionConfiguration.cs index 0d28623f..9e0dce0d 100644 --- a/sources/Valkey.Glide/ConnectionConfiguration.cs +++ b/sources/Valkey.Glide/ConnectionConfiguration.cs @@ -24,9 +24,10 @@ internal record ConnectionConfig public uint DatabaseId; public Protocol? Protocol; public string? ClientName; + public BasePubSubSubscriptionConfig? PubSubSubscriptions; internal FFI.ConnectionConfig ToFfi() => - new(Addresses, TlsMode, ClusterMode, (uint?)RequestTimeout?.TotalMilliseconds, (uint?)ConnectionTimeout?.TotalMilliseconds, ReadFrom, RetryStrategy, AuthenticationInfo, DatabaseId, Protocol, ClientName); + new(Addresses, TlsMode, ClusterMode, (uint?)RequestTimeout?.TotalMilliseconds, (uint?)ConnectionTimeout?.TotalMilliseconds, ReadFrom, RetryStrategy, AuthenticationInfo, DatabaseId, Protocol, ClientName, PubSubSubscriptions); } /// @@ -549,6 +550,23 @@ public StandaloneClientConfigurationBuilder() : base(false) { } /// Complete the configuration with given settings. /// public new StandaloneClientConfiguration Build() => new() { Request = base.Build() }; + + #region PubSub Subscriptions + /// + /// Configure PubSub subscriptions for the standalone client. + /// + /// The PubSub subscription configuration. + /// This configuration builder instance for method chaining. + /// Thrown when config is null. + /// Thrown when config is invalid. + public StandaloneClientConfigurationBuilder WithPubSubSubscriptions(StandalonePubSubSubscriptionConfig config) + { + ArgumentNullException.ThrowIfNull(config); + config.Validate(); + Config.PubSubSubscriptions = config; + return this; + } + #endregion } /// @@ -563,5 +581,22 @@ public ClusterClientConfigurationBuilder() : base(true) { } /// Complete the configuration with given settings. /// public new ClusterClientConfiguration Build() => new() { Request = base.Build() }; + + #region PubSub Subscriptions + /// + /// Configure PubSub subscriptions for the cluster client. + /// + /// The PubSub subscription configuration. + /// This configuration builder instance for method chaining. + /// Thrown when config is null. + /// Thrown when config is invalid. + public ClusterClientConfigurationBuilder WithPubSubSubscriptions(ClusterPubSubSubscriptionConfig config) + { + ArgumentNullException.ThrowIfNull(config); + config.Validate(); + Config.PubSubSubscriptions = config; + return this; + } + #endregion } } diff --git a/sources/Valkey.Glide/Internals/FFI.methods.cs b/sources/Valkey.Glide/Internals/FFI.methods.cs index 0d8d7de9..c8c2f639 100644 --- a/sources/Valkey.Glide/Internals/FFI.methods.cs +++ b/sources/Valkey.Glide/Internals/FFI.methods.cs @@ -10,6 +10,25 @@ namespace Valkey.Glide.Internals; internal partial class FFI { + /// + /// FFI callback delegate for PubSub message reception matching the Rust FFI signature. + /// + /// The type of push notification received. + /// Pointer to the raw message bytes. + /// The length of the message data in bytes. + /// Pointer to the raw channel name bytes. + /// The length of the channel name in bytes. + /// Pointer to the raw pattern bytes (null if no pattern). + /// The length of the pattern in bytes (0 if no pattern). + [UnmanagedFunctionPointer(CallingConvention.Cdecl)] + internal delegate void PubSubMessageCallback( + uint pushKind, + IntPtr messagePtr, + long messageLen, + IntPtr channelPtr, + long channelLen, + IntPtr patternPtr, + long patternLen); #if NET8_0_OR_GREATER [LibraryImport("libglide_rs", EntryPoint = "command")] [UnmanagedCallConv(CallConvs = [typeof(CallConvCdecl)])] @@ -25,11 +44,13 @@ internal partial class FFI [LibraryImport("libglide_rs", EntryPoint = "create_client")] [UnmanagedCallConv(CallConvs = [typeof(CallConvCdecl)])] - public static partial void CreateClientFfi(IntPtr config, IntPtr successCallback, IntPtr failureCallback); + public static partial void CreateClientFfi(IntPtr config, IntPtr successCallback, IntPtr failureCallback, IntPtr pubsubCallback); [LibraryImport("libglide_rs", EntryPoint = "close_client")] [UnmanagedCallConv(CallConvs = [typeof(CallConvCdecl)])] public static partial void CloseClientFfi(IntPtr client); + + #else [DllImport("libglide_rs", CallingConvention = CallingConvention.Cdecl, EntryPoint = "command")] public static extern void CommandFfi(IntPtr client, ulong index, IntPtr cmdInfo, IntPtr routeInfo); @@ -41,9 +62,11 @@ internal partial class FFI public static extern void FreeResponse(IntPtr response); [DllImport("libglide_rs", CallingConvention = CallingConvention.Cdecl, EntryPoint = "create_client")] - public static extern void CreateClientFfi(IntPtr config, IntPtr successCallback, IntPtr failureCallback); + public static extern void CreateClientFfi(IntPtr config, IntPtr successCallback, IntPtr failureCallback, IntPtr pubsubCallback); [DllImport("libglide_rs", CallingConvention = CallingConvention.Cdecl, EntryPoint = "close_client")] public static extern void CloseClientFfi(IntPtr client); + + #endif } diff --git a/sources/Valkey.Glide/Internals/FFI.structs.cs b/sources/Valkey.Glide/Internals/FFI.structs.cs index 4a8a4037..f7649d66 100644 --- a/sources/Valkey.Glide/Internals/FFI.structs.cs +++ b/sources/Valkey.Glide/Internals/FFI.structs.cs @@ -200,6 +200,10 @@ internal class ConnectionConfig : Marshallable { private ConnectionRequest _request; private readonly List _addresses; + private readonly BasePubSubSubscriptionConfig? _pubSubConfig; + private IntPtr _pubSubChannelsPtr = IntPtr.Zero; + private IntPtr _pubSubPatternsPtr = IntPtr.Zero; + private IntPtr _pubSubShardedChannelsPtr = IntPtr.Zero; public ConnectionConfig( List addresses, @@ -212,9 +216,11 @@ public ConnectionConfig( AuthenticationInfo? authenticationInfo, uint databaseId, ConnectionConfiguration.Protocol? protocol, - string? clientName) + string? clientName, + BasePubSubSubscriptionConfig? pubSubSubscriptions) { _addresses = addresses; + _pubSubConfig = pubSubSubscriptions; _request = new() { AddressCount = (nuint)addresses.Count, @@ -235,10 +241,47 @@ public ConnectionConfig( HasProtocol = protocol.HasValue, Protocol = protocol ?? default, ClientName = clientName, + HasPubSubConfig = pubSubSubscriptions != null, + PubSubConfig = new PubSubConfigInfo() }; } - protected override void FreeMemory() => Marshal.FreeHGlobal(_request.Addresses); + protected override void FreeMemory() + { + Marshal.FreeHGlobal(_request.Addresses); + + if (_pubSubConfig != null) + { + int channelCount = _pubSubConfig.Subscriptions.TryGetValue(0, out List? channels) ? channels.Count : 0; + int patternCount = _pubSubConfig.Subscriptions.TryGetValue(1, out List? patterns) ? patterns.Count : 0; + int shardedChannelCount = _pubSubConfig.Subscriptions.TryGetValue(2, out List? shardedChannels) ? shardedChannels.Count : 0; + + FreeStringArray(_pubSubChannelsPtr, channelCount); + FreeStringArray(_pubSubPatternsPtr, patternCount); + FreeStringArray(_pubSubShardedChannelsPtr, shardedChannelCount); + } + } + + private static void FreeStringArray(IntPtr arrayPtr, int count) + { + if (arrayPtr == IntPtr.Zero || count == 0) + { + return; + } + + // Free each string in the array + for (int i = 0; i < count; i++) + { + IntPtr stringPtr = Marshal.ReadIntPtr(arrayPtr, i * IntPtr.Size); + if (stringPtr != IntPtr.Zero) + { + Marshal.FreeHGlobal(stringPtr); + } + } + + // Free the array itself + Marshal.FreeHGlobal(arrayPtr); + } protected override IntPtr AllocateAndCopy() { @@ -248,8 +291,68 @@ protected override IntPtr AllocateAndCopy() { Marshal.StructureToPtr(_addresses[i], _request.Addresses + (i * addressSize), false); } + + // Marshal PubSub configuration if present + if (_pubSubConfig != null) + { + _request.PubSubConfig = MarshalPubSubConfig(_pubSubConfig); + } + return StructToPtr(_request); } + + + + private PubSubConfigInfo MarshalPubSubConfig(BasePubSubSubscriptionConfig config) + { + var pubSubInfo = new PubSubConfigInfo(); + + // Marshal exact channels (mode 0) + if (config.Subscriptions.TryGetValue(0, out List? channels) && channels.Count > 0) + { + _pubSubChannelsPtr = MarshalStringArray(channels); + pubSubInfo.ChannelsPtr = _pubSubChannelsPtr; + pubSubInfo.ChannelCount = (uint)channels.Count; + } + + // Marshal patterns (mode 1) + if (config.Subscriptions.TryGetValue(1, out List? patterns) && patterns.Count > 0) + { + _pubSubPatternsPtr = MarshalStringArray(patterns); + pubSubInfo.PatternsPtr = _pubSubPatternsPtr; + pubSubInfo.PatternCount = (uint)patterns.Count; + } + + // Marshal sharded channels (mode 2) - only for cluster clients + if (config.Subscriptions.TryGetValue(2, out List? shardedChannels) && shardedChannels.Count > 0) + { + _pubSubShardedChannelsPtr = MarshalStringArray(shardedChannels); + pubSubInfo.ShardedChannelsPtr = _pubSubShardedChannelsPtr; + pubSubInfo.ShardedChannelCount = (uint)shardedChannels.Count; + } + + return pubSubInfo; + } + + private static IntPtr MarshalStringArray(List strings) + { + if (strings.Count == 0) + { + return IntPtr.Zero; + } + + // Allocate array of string pointers + IntPtr arrayPtr = Marshal.AllocHGlobal(IntPtr.Size * strings.Count); + + for (int i = 0; i < strings.Count; i++) + { + // Allocate and copy each string + IntPtr stringPtr = Marshal.StringToHGlobalAnsi(strings[i]); + Marshal.WriteIntPtr(arrayPtr, i * IntPtr.Size, stringPtr); + } + + return arrayPtr; + } } private static IntPtr StructToPtr(T @struct) where T : struct @@ -265,6 +368,98 @@ private static IntPtr StructToPtr(T @struct) where T : struct private static void PoolReturn(T[] arr) => ArrayPool.Shared.Return(arr); + /// + /// Marshals raw byte arrays from FFI callback parameters to a managed PubSubMessage object. + /// + /// The type of push notification. + /// Pointer to the raw message bytes. + /// The length of the message data in bytes. + /// Pointer to the raw channel name bytes. + /// The length of the channel name in bytes. + /// Pointer to the raw pattern bytes (null if no pattern). + /// The length of the pattern in bytes (0 if no pattern). + /// A managed PubSubMessage object. + /// Thrown when the parameters are invalid or marshaling fails. + internal static PubSubMessage MarshalPubSubMessage( + PushKind pushKind, + IntPtr messagePtr, + long messageLen, + IntPtr channelPtr, + long channelLen, + IntPtr patternPtr, + long patternLen) + { + try + { + // Validate input parameters + if (messagePtr == IntPtr.Zero) + { + throw new ArgumentException("Invalid message data: pointer is null"); + } + + if (channelPtr == IntPtr.Zero) + { + throw new ArgumentException("Invalid channel data: pointer is null"); + } + + if (messageLen < 0) + { + throw new ArgumentException("Invalid message data: length cannot be negative"); + } + + if (channelLen <= 0) + { + throw new ArgumentException("Invalid channel data: pointer is null or length is zero"); + } + + // Marshal message bytes to string + byte[] messageBytes = new byte[messageLen]; + if (messageLen > 0) + { + Marshal.Copy(messagePtr, messageBytes, 0, (int)messageLen); + } + string message = System.Text.Encoding.UTF8.GetString(messageBytes); + + if (string.IsNullOrEmpty(message)) + { + throw new ArgumentException("PubSub message content cannot be null or empty after marshaling"); + } + + // Marshal channel bytes to string + byte[] channelBytes = new byte[channelLen]; + Marshal.Copy(channelPtr, channelBytes, 0, (int)channelLen); + string channel = System.Text.Encoding.UTF8.GetString(channelBytes); + + if (string.IsNullOrEmpty(channel)) + { + throw new ArgumentException("PubSub channel name cannot be null or empty after marshaling"); + } + + // Marshal pattern bytes to string if present + string? pattern = null; + if (patternPtr != IntPtr.Zero && patternLen > 0) + { + byte[] patternBytes = new byte[patternLen]; + Marshal.Copy(patternPtr, patternBytes, 0, (int)patternLen); + pattern = System.Text.Encoding.UTF8.GetString(patternBytes); + + if (string.IsNullOrEmpty(pattern)) + { + throw new ArgumentException("PubSub pattern cannot be empty when pattern pointer is provided"); + } + } + + // Create PubSubMessage based on whether pattern is present + return pattern == null + ? new PubSubMessage(message, channel) + : new PubSubMessage(message, channel, pattern); + } + catch (Exception ex) when (ex is not ArgumentException) + { + throw new ArgumentException($"Failed to marshal PubSub message from FFI callback parameters: {ex.Message}", ex); + } + } + [StructLayout(LayoutKind.Sequential)] private struct CmdInfo { @@ -770,9 +965,23 @@ private struct ConnectionRequest public ConnectionConfiguration.Protocol Protocol; [MarshalAs(UnmanagedType.LPStr)] public string? ClientName; + [MarshalAs(UnmanagedType.U1)] + public bool HasPubSubConfig; + public PubSubConfigInfo PubSubConfig; // TODO more config params, see ffi.rs } + [StructLayout(LayoutKind.Sequential)] + private struct PubSubConfigInfo + { + public IntPtr ChannelsPtr; + public uint ChannelCount; + public IntPtr PatternsPtr; + public uint PatternCount; + public IntPtr ShardedChannelsPtr; + public uint ShardedChannelCount; + } + [StructLayout(LayoutKind.Sequential, CharSet = CharSet.Ansi)] internal struct NodeAddress { @@ -790,9 +999,57 @@ internal struct AuthenticationInfo(string? username, string password) public string Password = password; } + /// + /// FFI structure for PubSub message data received from native code. + /// + [StructLayout(LayoutKind.Sequential, CharSet = CharSet.Ansi)] + public struct PubSubMessageInfo + { + [MarshalAs(UnmanagedType.LPStr)] + public string Message; + [MarshalAs(UnmanagedType.LPStr)] + public string Channel; + [MarshalAs(UnmanagedType.LPStr)] + public string? Pattern; + } + internal enum TlsMode : uint { NoTls = 0, SecureTls = 2, } + + /// + /// Enum representing the type of push notification received from the server. + /// This matches the PushKind enum in the Rust FFI layer. + /// + internal enum PushKind + { + /// Disconnection notification. + PushDisconnection = 0, + /// Other/unknown push notification type. + PushOther = 1, + /// Cache invalidation notification. + PushInvalidate = 2, + /// Regular channel message (SUBSCRIBE). + PushMessage = 3, + /// Pattern-based message (PSUBSCRIBE). + PushPMessage = 4, + /// Sharded channel message (SSUBSCRIBE). + PushSMessage = 5, + /// Unsubscribe confirmation. + PushUnsubscribe = 6, + /// Pattern unsubscribe confirmation. + PushPUnsubscribe = 7, + /// Sharded unsubscribe confirmation. + PushSUnsubscribe = 8, + /// Subscribe confirmation. + PushSubscribe = 9, + /// Pattern subscribe confirmation. + PushPSubscribe = 10, + /// Sharded subscribe confirmation. + PushSSubscribe = 11, + } + + } diff --git a/sources/Valkey.Glide/PubSubMessage.cs b/sources/Valkey.Glide/PubSubMessage.cs new file mode 100644 index 00000000..0198f02e --- /dev/null +++ b/sources/Valkey.Glide/PubSubMessage.cs @@ -0,0 +1,141 @@ +// Copyright Valkey GLIDE Project Contributors - SPDX Identifier: Apache-2.0 + +using System.Text.Json; + +namespace Valkey.Glide; + +/// +/// Represents a message received through PubSub subscription. +/// +public sealed class PubSubMessage +{ + /// + /// The message content. + /// + public string Message { get; } + + /// + /// The channel on which the message was received. + /// + public string Channel { get; } + + /// + /// The pattern that matched the channel (null for exact channel subscriptions). + /// + public string? Pattern { get; } + + /// + /// Initializes a new instance of the class for exact channel subscriptions. + /// + /// The message content. + /// The channel on which the message was received. + /// Thrown when message or channel is null. + /// Thrown when message or channel is empty. + public PubSubMessage(string message, string channel) + { + if (message == null) + { + throw new ArgumentNullException(nameof(message)); + } + + if (channel == null) + { + throw new ArgumentNullException(nameof(channel)); + } + + if (string.IsNullOrEmpty(message)) + { + throw new ArgumentException("Message cannot be empty", nameof(message)); + } + + if (string.IsNullOrEmpty(channel)) + { + throw new ArgumentException("Channel cannot be empty", nameof(channel)); + } + + Message = message; + Channel = channel; + Pattern = null; + } + + /// + /// Initializes a new instance of the class for pattern-based subscriptions. + /// + /// The message content. + /// The channel on which the message was received. + /// The pattern that matched the channel. + /// Thrown when message, channel, or pattern is null. + /// Thrown when message, channel, or pattern is empty. + public PubSubMessage(string message, string channel, string pattern) + { + if (message == null) + { + throw new ArgumentNullException(nameof(message)); + } + + if (channel == null) + { + throw new ArgumentNullException(nameof(channel)); + } + + if (pattern == null) + { + throw new ArgumentNullException(nameof(pattern)); + } + + if (string.IsNullOrEmpty(message)) + { + throw new ArgumentException("Message cannot be empty", nameof(message)); + } + + if (string.IsNullOrEmpty(channel)) + { + throw new ArgumentException("Channel cannot be empty", nameof(channel)); + } + + if (string.IsNullOrEmpty(pattern)) + { + throw new ArgumentException("Pattern cannot be empty", nameof(pattern)); + } + + Message = message; + Channel = channel; + Pattern = pattern; + } + + /// + /// Returns a JSON string representation of the PubSub message for debugging purposes. + /// + /// A JSON representation of the message. + public override string ToString() + { + var messageObject = new + { + Message, + Channel, + Pattern + }; + + return JsonSerializer.Serialize(messageObject, new JsonSerializerOptions + { + WriteIndented = false + }); + } + + /// + /// Determines whether the specified object is equal to the current PubSubMessage. + /// + /// The object to compare with the current PubSubMessage. + /// true if the specified object is equal to the current PubSubMessage; otherwise, false. + public override bool Equals(object? obj) => + obj is PubSubMessage other && + Message == other.Message && + Channel == other.Channel && + Pattern == other.Pattern; + + /// + /// Returns the hash code for this PubSubMessage. + /// + /// A hash code for the current PubSubMessage. + public override int GetHashCode() => HashCode.Combine(Message, Channel, Pattern); +} diff --git a/sources/Valkey.Glide/PubSubMessageHandler.cs b/sources/Valkey.Glide/PubSubMessageHandler.cs new file mode 100644 index 00000000..e581eac6 --- /dev/null +++ b/sources/Valkey.Glide/PubSubMessageHandler.cs @@ -0,0 +1,148 @@ +// Copyright Valkey GLIDE Project Contributors - SPDX Identifier: Apache-2.0 + +namespace Valkey.Glide; + +/// +/// Delegate for PubSub message callbacks. +/// +/// The received PubSub message. +/// User-provided context object. +public delegate void MessageCallback(PubSubMessage message, object? context); + +/// +/// Handles routing of PubSub messages to callbacks or queues. +/// Provides error handling and recovery for callback exceptions. +/// +internal sealed class PubSubMessageHandler : IDisposable +{ + private readonly MessageCallback? _callback; + private readonly object? _context; + private readonly PubSubMessageQueue _queue; + private readonly object _lock = new(); + private volatile bool _disposed; + + /// + /// Initializes a new instance of the class. + /// + /// Optional callback to invoke when messages are received. If null, messages will be queued. + /// Optional context object to pass to the callback. + internal PubSubMessageHandler(MessageCallback? callback, object? context) + { + _callback = callback; + _context = context; + _queue = new PubSubMessageQueue(); + } + + /// + /// Process an incoming PubSub message by routing it to callback or queue. + /// + /// The message to process. + /// Thrown when message is null. + /// Thrown when the handler has been disposed. + internal void HandleMessage(PubSubMessage message) + { + ArgumentNullException.ThrowIfNull(message); + ThrowIfDisposed(); + + if (_callback != null) + { + // Route to callback with error handling + InvokeCallbackSafely(message); + } + else + { + // Route to queue + try + { + _queue.EnqueueMessage(message); + } + catch (ObjectDisposedException) + { + // Queue was disposed, ignore the message + Logger.Log(Level.Warn, "PubSubMessageHandler", $"Attempted to enqueue message to disposed queue for channel {message.Channel}"); + } + } + } + + /// + /// Get the message queue for manual message retrieval. + /// + /// The message queue instance. + /// Thrown when the handler has been disposed. + /// Thrown when a callback is configured. + internal PubSubMessageQueue GetQueue() + { + ThrowIfDisposed(); + + return _callback != null + ? throw new InvalidOperationException("Cannot access message queue when callback is configured. Use callback mode or queue mode, not both.") + : _queue; + } + + /// + /// Safely invoke the callback with proper error handling and recovery. + /// + /// The message to pass to the callback. + private void InvokeCallbackSafely(PubSubMessage message) + { + try + { + // Check if disposed before invoking callback to avoid race conditions + if (_disposed) + { + return; + } + + _callback!(message, _context); + } + catch (Exception ex) + { + // Log the error and continue processing subsequent messages + // This ensures that callback exceptions don't break the message processing pipeline + Logger.Log(Level.Error, "PubSubMessageHandler", $"Error in PubSub message callback for channel {message.Channel}. Message processing will continue.", ex); + } + } + + /// + /// Releases all resources used by the . + /// + public void Dispose() + { + if (_disposed) + { + return; + } + + lock (_lock) + { + if (_disposed) + { + return; + } + + _disposed = true; + } + + // Dispose the message queue + try + { + _queue.Dispose(); + } + catch (Exception ex) + { + // Log disposal errors but don't throw to ensure cleanup completes + Logger.Log(Level.Warn, "PubSubMessageHandler", "Error during PubSub message queue disposal", ex); + } + } + + /// + /// Throws an ObjectDisposedException if the handler has been disposed. + /// + private void ThrowIfDisposed() + { + if (_disposed) + { + throw new ObjectDisposedException(nameof(PubSubMessageHandler)); + } + } +} diff --git a/sources/Valkey.Glide/PubSubMessageQueue.cs b/sources/Valkey.Glide/PubSubMessageQueue.cs new file mode 100644 index 00000000..0c0d8484 --- /dev/null +++ b/sources/Valkey.Glide/PubSubMessageQueue.cs @@ -0,0 +1,173 @@ +// Copyright Valkey GLIDE Project Contributors - SPDX Identifier: Apache-2.0 + +using System.Collections.Concurrent; +using System.Runtime.CompilerServices; + +namespace Valkey.Glide; + +/// +/// Thread-safe queue for PubSub messages with async support. +/// Provides both blocking and non-blocking message retrieval methods. +/// +public sealed class PubSubMessageQueue : IDisposable +{ + private readonly ConcurrentQueue _messages; + private readonly SemaphoreSlim _messageAvailable; + private readonly object _lock = new(); + private volatile bool _disposed; + + /// + /// Initializes a new instance of the class. + /// + public PubSubMessageQueue() + { + _messages = new ConcurrentQueue(); + _messageAvailable = new SemaphoreSlim(0); + } + + /// + /// Gets the current number of queued messages. + /// + public int Count => _messages.Count; + + /// + /// Try to get a message from the queue without blocking. + /// + /// The retrieved message, or null if no message is available. + /// true if a message was retrieved; otherwise, false. + /// Thrown when the queue has been disposed. + public bool TryGetMessage(out PubSubMessage? message) + { + ThrowIfDisposed(); + + if (_messages.TryDequeue(out message)) + { + // Consume one semaphore count since we dequeued a message + _ = _messageAvailable.Wait(0); + return true; + } + + message = null; + return false; + } + + /// + /// Asynchronously wait for and retrieve a message from the queue. + /// + /// Token to cancel the operation. + /// A task that represents the asynchronous operation. The task result contains the retrieved message. + /// Thrown when the queue has been disposed. + /// Thrown when the operation is cancelled. + public async Task GetMessageAsync(CancellationToken cancellationToken = default) + { + ThrowIfDisposed(); + + // Wait for a message to be available + await _messageAvailable.WaitAsync(cancellationToken).ConfigureAwait(false); + + // Check if disposed after waiting + ThrowIfDisposed(); + + // Try to dequeue the message + if (_messages.TryDequeue(out PubSubMessage? message)) + { + return message; + } + + // This should not happen under normal circumstances, but handle it gracefully + throw new InvalidOperationException("Message queue is in an inconsistent state"); + } + + /// + /// Get an async enumerable for continuous message processing. + /// + /// Token to cancel the enumeration. + /// An async enumerable that yields messages as they become available. + /// Thrown when the queue has been disposed. + public async IAsyncEnumerable GetMessagesAsync([EnumeratorCancellation] CancellationToken cancellationToken = default) + { + while (!_disposed && !cancellationToken.IsCancellationRequested) + { + PubSubMessage message; + try + { + message = await GetMessageAsync(cancellationToken).ConfigureAwait(false); + } + catch (ObjectDisposedException) + { + // Queue was disposed, exit enumeration + yield break; + } + catch (OperationCanceledException) + { + // Operation was cancelled, exit enumeration + yield break; + } + + yield return message; + } + } + + /// + /// Enqueue a message to the queue. + /// This method is intended for internal use by the PubSub message handler. + /// + /// The message to enqueue. + /// Thrown when message is null. + /// Thrown when the queue has been disposed. + internal void EnqueueMessage(PubSubMessage message) + { + ArgumentNullException.ThrowIfNull(message); + ThrowIfDisposed(); + + _messages.Enqueue(message); + _ = _messageAvailable.Release(); + } + + /// + /// Releases all resources used by the . + /// + public void Dispose() + { + if (_disposed) + { + return; + } + + lock (_lock) + { + if (_disposed) + { + return; + } + + _disposed = true; + } + + // Release all waiting threads + try + { + // Release as many times as there are potentially waiting threads + // This ensures all waiting GetMessageAsync calls will complete + int releaseCount = Math.Max(1, _messageAvailable.CurrentCount + 10); + _ = _messageAvailable.Release(releaseCount); + } + catch (SemaphoreFullException) + { + // Ignore if semaphore is already at maximum + } + + _messageAvailable.Dispose(); + } + + /// + /// Throws an ObjectDisposedException if the queue has been disposed. + /// + private void ThrowIfDisposed() + { + if (_disposed) + { + throw new ObjectDisposedException(nameof(PubSubMessageQueue)); + } + } +} diff --git a/sources/Valkey.Glide/PubSubPerformanceConfig.cs b/sources/Valkey.Glide/PubSubPerformanceConfig.cs new file mode 100644 index 00000000..6bb904c6 --- /dev/null +++ b/sources/Valkey.Glide/PubSubPerformanceConfig.cs @@ -0,0 +1,61 @@ +// Copyright Valkey GLIDE Project Contributors - SPDX Identifier: Apache-2.0 + +using System.Threading.Channels; + +namespace Valkey.Glide; + +/// +/// Configuration options for PubSub performance tuning. +/// +public sealed class PubSubPerformanceConfig +{ + /// + /// Default channel capacity for message queuing. + /// + public const int DefaultChannelCapacity = 1000; + + /// + /// Default shutdown timeout in seconds. + /// + public const int DefaultShutdownTimeoutSeconds = 5; + + /// + /// Maximum number of messages to queue before applying backpressure. + /// Default: 1000 + /// + public int ChannelCapacity { get; set; } = DefaultChannelCapacity; + + /// + /// Strategy to use when the message channel is full. + /// Default: Wait (apply backpressure) + /// + public BoundedChannelFullMode FullMode { get; set; } = BoundedChannelFullMode.Wait; + + /// + /// Timeout for graceful shutdown of PubSub processing. + /// Default: 5 seconds + /// + public TimeSpan ShutdownTimeout { get; set; } = TimeSpan.FromSeconds(DefaultShutdownTimeoutSeconds); + + /// + /// Validates the configuration. + /// + /// Thrown when configuration values are invalid. + internal void Validate() + { + if (ChannelCapacity <= 0) + { + throw new ArgumentOutOfRangeException(nameof(ChannelCapacity), "Channel capacity must be greater than zero"); + } + + if (ShutdownTimeout <= TimeSpan.Zero) + { + throw new ArgumentOutOfRangeException(nameof(ShutdownTimeout), "Shutdown timeout must be greater than zero"); + } + + if (!Enum.IsDefined(typeof(BoundedChannelFullMode), FullMode)) + { + throw new ArgumentOutOfRangeException(nameof(FullMode), "Invalid BoundedChannelFullMode value"); + } + } +} diff --git a/sources/Valkey.Glide/PubSubSubscriptionConfig.cs b/sources/Valkey.Glide/PubSubSubscriptionConfig.cs new file mode 100644 index 00000000..5e1d8b10 --- /dev/null +++ b/sources/Valkey.Glide/PubSubSubscriptionConfig.cs @@ -0,0 +1,252 @@ +// Copyright Valkey GLIDE Project Contributors - SPDX Identifier: Apache-2.0 + +namespace Valkey.Glide; + +/// +/// PubSub subscription modes for standalone clients. +/// +public enum PubSubChannelMode +{ + /// Exact channel name subscription. + Exact = 0, + /// Pattern-based subscription. + Pattern = 1 +} + +/// +/// PubSub subscription modes for cluster clients. +/// +public enum PubSubClusterChannelMode +{ + /// Exact channel name subscription. + Exact = 0, + /// Pattern-based subscription. + Pattern = 1, + /// Sharded channel subscription (cluster-specific). + Sharded = 2 +} + +/// +/// Base configuration for PubSub subscriptions. +/// +public abstract class BasePubSubSubscriptionConfig +{ + internal MessageCallback? Callback { get; set; } + internal object? Context { get; set; } + internal Dictionary> Subscriptions { get; set; } = []; + internal PubSubPerformanceConfig? PerformanceConfig { get; set; } + + /// + /// Configure a message callback to be invoked when messages are received. + /// + /// The callback function to invoke for received messages. + /// Optional context object to pass to the callback. + /// This configuration instance for method chaining. + /// Thrown when callback is null. + public T WithCallback(MessageCallback callback, object? context = null) where T : BasePubSubSubscriptionConfig + { + Callback = callback ?? throw new ArgumentNullException(nameof(callback), "Callback cannot be null"); + Context = context; + return (T)this; + } + + /// + /// Validates the subscription configuration. + /// + /// Thrown when configuration is invalid. + internal virtual void Validate() + { + if (Subscriptions.Count == 0) + { + throw new ArgumentException("At least one subscription must be configured"); + } + + foreach (KeyValuePair> kvp in Subscriptions) + { + if (kvp.Value == null || kvp.Value.Count == 0) + { + throw new ArgumentException($"Subscription mode {kvp.Key} has no channels or patterns configured"); + } + + foreach (string channelOrPattern in kvp.Value) + { + if (string.IsNullOrWhiteSpace(channelOrPattern)) + { + throw new ArgumentException("Channel name or pattern cannot be null, empty, or whitespace"); + } + } + } + } +} + +/// +/// PubSub subscription configuration for standalone clients. +/// +public sealed class StandalonePubSubSubscriptionConfig : BasePubSubSubscriptionConfig +{ + /// + /// Initializes a new instance of the class. + /// + public StandalonePubSubSubscriptionConfig() + { + } + + /// + /// Add a channel or pattern subscription. + /// + /// The subscription mode (Exact or Pattern). + /// The channel name or pattern to subscribe to. + /// This configuration instance for method chaining. + /// Thrown when channelOrPattern is null, empty, or whitespace. + /// Thrown when mode is not valid for standalone clients. + public StandalonePubSubSubscriptionConfig WithSubscription(PubSubChannelMode mode, string channelOrPattern) + { + if (string.IsNullOrWhiteSpace(channelOrPattern)) + { + throw new ArgumentException("Channel name or pattern cannot be null, empty, or whitespace", nameof(channelOrPattern)); + } + + if (!Enum.IsDefined(typeof(PubSubChannelMode), mode)) + { + throw new ArgumentOutOfRangeException(nameof(mode), "Invalid PubSub channel mode for standalone client"); + } + + uint modeValue = (uint)mode; + if (!Subscriptions.ContainsKey(modeValue)) + { + Subscriptions[modeValue] = []; + } + + if (!Subscriptions[modeValue].Contains(channelOrPattern)) + { + Subscriptions[modeValue].Add(channelOrPattern); + } + + return this; + } + + /// + /// Add an exact channel subscription. + /// + /// The channel name to subscribe to. + /// /// Thistion instance for method chaining. + /// Thrown when channel is null, empty, or whitespace. + public StandalonePubSubSubscriptionConfig WithChannel(string channel) => WithSubscription(PubSubChannelMode.Exact, channel); + + /// + /// Add a pattern subscription. + /// + /// The pattern to subscribe to. + /// This configuration instance for method chaining. + /// Thrown when pattern is null, empty, or whitespace. + public StandalonePubSubSubscriptionConfig WithPattern(string pattern) => WithSubscription(PubSubChannelMode.Pattern, pattern); + + /// + /// Validates the standalone subscription configuration. + /// + /// Thrown when configuration is invalid. + internal override void Validate() + { + base.Validate(); + + // Ensure only valid modes for standalone clients are used + foreach (uint mode in Subscriptions.Keys) + { + if (mode is not ((uint)PubSubChannelMode.Exact) and not ((uint)PubSubChannelMode.Pattern)) + { + throw new ArgumentException($"Subscription mode {mode} is not valid for standalone clients"); + } + } + } +} + +/// +/// PubSub subscription configuration for cluster clients. +/// +public sealed class ClusterPubSubSubscriptionConfig : BasePubSubSubscriptionConfig +{ + /// + /// /// Initializes a ne of the class. + /// + public ClusterPubSubSubscriptionConfig() + { + } + + /// + /// Add a channel, pattern, or sharded subscription. + /// + /// The subscription mode (Exact, Pattern, or Sharded). + /// The channel name or pattern to subscribe to. + /// This configuration instance for method chaining. + /// Thrown when channelOrPattern is null, empty, or whitespace. + /// Thrown when mode is not valid for cluster clients. + public ClusterPubSubSubscriptionConfig WithSubscription(PubSubClusterChannelMode mode, string channelOrPattern) + { + if (string.IsNullOrWhiteSpace(channelOrPattern)) + { + throw new ArgumentException("Channel name or pattern cannot be null, empty, or whitespace", nameof(channelOrPattern)); + } + + if (!Enum.IsDefined(typeof(PubSubClusterChannelMode), mode)) + { + throw new ArgumentOutOfRangeException(nameof(mode), "Invalid PubSub channel mode for cluster client"); + } + + uint modeValue = (uint)mode; + if (!Subscriptions.ContainsKey(modeValue)) + { + Subscriptions[modeValue] = []; + } + + if (!Subscriptions[modeValue].Contains(channelOrPattern)) + { + Subscriptions[modeValue].Add(channelOrPattern); + } + + return this; + } + + /// + /// Add an exact channel subscription. + /// + /// The channel name to subscribe to. + /// This configuration instance for method chaining. + /// Thrown when channel is null, empty, or whitespace. + public ClusterPubSubSubscriptionConfig WithChannel(string channel) => WithSubscription(PubSubClusterChannelMode.Exact, channel); + + /// + /// Add a pattern subscription. + /// + /// The pattern to subscribe to. + /// This configuration instance for method chaining. + /// Thrown when pattern is null, empty, or whitespace. + public ClusterPubSubSubscriptionConfig WithPattern(string pattern) => WithSubscription(PubSubClusterChannelMode.Pattern, pattern); + + /// + /// Add a sharded channel subscription. + /// + /// The sharded channel name to subscribe to. + /// This configuration instance for method chaining. + /// Thrown when channel is null, empty, or whitespace. + public ClusterPubSubSubscriptionConfig WithShardedChannel(string channel) => WithSubscription(PubSubClusterChannelMode.Sharded, channel); + + /// + /// Validates the cluster subscription configuration. + /// + /// Thrown when configuration is invalid. + internal override void Validate() + { + base.Validate(); + + // Ensure only valid modes for cluster clients are used + foreach (uint mode in Subscriptions.Keys) + { + if (mode is not ((uint)PubSubClusterChannelMode.Exact) and + not ((uint)PubSubClusterChannelMode.Pattern) and + not ((uint)PubSubClusterChannelMode.Sharded)) + { + throw new ArgumentException($"Subscription mode {mode} is not valid for cluster clients"); + } + } + } +} diff --git a/tests/Valkey.Glide.IntegrationTests/PubSubCallbackIntegrationTests.cs b/tests/Valkey.Glide.IntegrationTests/PubSubCallbackIntegrationTests.cs new file mode 100644 index 00000000..ff7560e6 --- /dev/null +++ b/tests/Valkey.Glide.IntegrationTests/PubSubCallbackIntegrationTests.cs @@ -0,0 +1,671 @@ +// Copyright Valkey GLIDE Project Contributors - SPDX Identifier: Apache-2.0 + +using System.Diagnostics; + +namespace Valkey.Glide.IntegrationTests; + +/// +/// End-to-end integration tests for PubSub functionality. +/// These tests verify the complete message flow from PUBLISH commands through the server, +/// Rust core, FFI boundary, and into C# callbacks. +/// Uses CustomCommand for PUBLISH operations to test the full stack. +/// +public class PubSubCallbackIntegrationTests : IDisposable +{ + private readonly List _testClients = []; + private readonly ManualResetEventSlim _messageReceivedEvent = new(false); + private readonly object _lockObject = new(); + + public void Dispose() + { + // Clean up all test clients + foreach (BaseClient client in _testClients) + { + try + { + client.Dispose(); + } + catch + { + // Ignore disposal errors in tests + } + } + _testClients.Clear(); + _messageReceivedEvent.Dispose(); + } + + [Fact] + public async Task EndToEndMessageFlow_WithStandaloneClient_ProcessesMessagesCorrectly() + { + // Arrange + string testChannel = $"test-channel-{Guid.NewGuid()}"; + string testMessage = "Hello from integration test!"; + bool messageReceived = false; + PubSubMessage? receivedMessage = null; + + StandalonePubSubSubscriptionConfig pubsubConfig = new StandalonePubSubSubscriptionConfig() + .WithChannel(testChannel) + .WithCallback((message, context) => + { + receivedMessage = message; + messageReceived = true; + _messageReceivedEvent.Set(); + }); + + var subscriberConfig = TestConfiguration.DefaultClientConfig() + .WithPubSubSubscriptions(pubsubConfig) + .Build(); + + var publisherConfig = TestConfiguration.DefaultClientConfig().Build(); + + // Act + GlideClient subscriberClient = await GlideClient.CreateClient(subscriberConfig); + _testClients.Add(subscriberClient); + + GlideClient publisherClient = await GlideClient.CreateClient(publisherConfig); + _testClients.Add(publisherClient); + + // Wait for subscription to be established + await Task.Delay(1000); + + // Publish message through the server (true E2E) + object? publishResult = await publisherClient.CustomCommand(["PUBLISH", testChannel, testMessage]); + long numReceivers = Convert.ToInt64(publishResult); + Assert.Equal(1L, numReceivers); // Should have 1 subscriber + + // Wait for message to be received via callback + bool received = _messageReceivedEvent.Wait(TimeSpan.FromSeconds(5)); + + // Assert + Assert.True(received, "Message should have been received within timeout"); + Assert.True(messageReceived, "Callback should have been invoked"); + Assert.NotNull(receivedMessage); + Assert.Equal(testMessage, receivedMessage.Message); + Assert.Equal(testChannel, receivedMessage.Channel); + Assert.Null(receivedMessage.Pattern); + } + + [Fact] + public async Task EndToEndMessageFlow_WithClusterClient_ProcessesMessagesCorrectly() + { + // Skip if no cluster hosts available + if (TestConfiguration.CLUSTER_HOSTS.Count == 0) + { + return; + } + + // Arrange + string testChannel = $"test-cluster-channel-{Guid.NewGuid()}"; + string testMessage = "Hello from cluster integration test!"; + bool messageReceived = false; + PubSubMessage? receivedMessage = null; + + ClusterPubSubSubscriptionConfig pubsubConfig = new ClusterPubSubSubscriptionConfig() + .WithChannel(testChannel) + .WithCallback((message, context) => + { + receivedMessage = message; + messageReceived = true; + _messageReceivedEvent.Set(); + }); + + var subscriberConfig = TestConfiguration.DefaultClusterClientConfig() + .WithPubSubSubscriptions(pubsubConfig) + .Build(); + + var publisherConfig = TestConfiguration.DefaultClusterClientConfig().Build(); + + // Act + GlideClusterClient subscriberClient = await GlideClusterClient.CreateClient(subscriberConfig); + _testClients.Add(subscriberClient); + + GlideClusterClient publisherClient = await GlideClusterClient.CreateClient(publisherConfig); + _testClients.Add(publisherClient); + + // Wait for subscription to be established + await Task.Delay(1000); + + // Publish message through the server (true E2E) + ClusterValue publishResult = await publisherClient.CustomCommand(["PUBLISH", testChannel, testMessage]); + long numReceivers = Convert.ToInt64(publishResult.SingleValue); + Assert.Equal(1L, numReceivers); // Should have 1 subscriber + + // Wait for message to be received via callback + bool received = _messageReceivedEvent.Wait(TimeSpan.FromSeconds(5)); + + // Assert + Assert.True(received, "Message should have been received within timeout"); + Assert.True(messageReceived, "Callback should have been invoked"); + Assert.NotNull(receivedMessage); + Assert.Equal(testMessage, receivedMessage.Message); + Assert.Equal(testChannel, receivedMessage.Channel); + Assert.Null(receivedMessage.Pattern); + } + + [Fact] + public async Task PatternSubscription_WithServerPublish_ProcessesPatternMessagesCorrectly() + { + // Arrange + string testPattern = "news.*"; + string testChannel = $"news.sports.{Guid.NewGuid()}"; + string testMessage = "Breaking sports news!"; + bool messageReceived = false; + PubSubMessage? receivedMessage = null; + + StandalonePubSubSubscriptionConfig pubsubConfig = new StandalonePubSubSubscriptionConfig() + .WithPattern(testPattern) + .WithCallback((message, context) => + { + receivedMessage = message; + messageReceived = true; + _messageReceivedEvent.Set(); + }); + + var subscriberConfig = TestConfiguration.DefaultClientConfig() + .WithPubSubSubscriptions(pubsubConfig) + .Build(); + + var publisherConfig = TestConfiguration.DefaultClientConfig().Build(); + + // Act + GlideClient subscriberClient = await GlideClient.CreateClient(subscriberConfig); + _testClients.Add(subscriberClient); + + GlideClient publisherClient = await GlideClient.CreateClient(publisherConfig); + _testClients.Add(publisherClient); + + // Wait for subscription to be established + await Task.Delay(1000); + + // Publish message to channel matching pattern (true E2E) + object? publishResult = await publisherClient.CustomCommand(["PUBLISH", testChannel, testMessage]); + long numReceivers = Convert.ToInt64(publishResult); + Assert.Equal(1L, numReceivers); // Should have 1 pattern subscriber + + // Wait for message to be received via callback + bool received = _messageReceivedEvent.Wait(TimeSpan.FromSeconds(5)); + + // Assert + Assert.True(received, "Pattern message should have been received within timeout"); + Assert.True(messageReceived, "Callback should have been invoked for pattern message"); + Assert.NotNull(receivedMessage); + Assert.Equal(testMessage, receivedMessage.Message); + Assert.Equal(testChannel, receivedMessage.Channel); + Assert.Equal(testPattern, receivedMessage.Pattern); + } + + [Fact] + public async Task CallbackErrorHandling_WithExceptionInCallback_IsolatesErrorsAndContinuesProcessing() + { + // Arrange + string testChannel = $"error-test-{Guid.NewGuid()}"; + int callbackInvocations = 0; + int successfulMessages = 0; + + StandalonePubSubSubscriptionConfig pubsubConfig = new StandalonePubSubSubscriptionConfig() + .WithChannel(testChannel) + .WithCallback((message, context) => + { + int invocation = Interlocked.Increment(ref callbackInvocations); + + // Throw exception on first message, succeed on subsequent messages + if (invocation == 1) + { + throw new InvalidOperationException("Test exception in callback"); + } + + _ = Interlocked.Increment(ref successfulMessages); + if (successfulMessages >= 2) + { + _messageReceivedEvent.Set(); + } + }); + + var subscriberConfig = TestConfiguration.DefaultClientConfig() + .WithPubSubSubscriptions(pubsubConfig) + .Build(); + + var publisherConfig = TestConfiguration.DefaultClientConfig().Build(); + + // Act + GlideClient subscriberClient = await GlideClient.CreateClient(subscriberConfig); + _testClients.Add(subscriberClient); + + GlideClient publisherClient = await GlideClient.CreateClient(publisherConfig); + _testClients.Add(publisherClient); + + // Wait for subscription to be established + await Task.Delay(1000); + + // Publish multiple messages through server + await publisherClient.CustomCommand(["PUBLISH", testChannel, "Message 1 - should cause exception"]); + await Task.Delay(100); // Allow first message to be processed + + await publisherClient.CustomCommand(["PUBLISH", testChannel, "Message 2 - should succeed"]); + await publisherClient.CustomCommand(["PUBLISH", testChannel, "Message 3 - should succeed"]); + + // Wait for successful messages to be processed + bool received = _messageReceivedEvent.Wait(TimeSpan.FromSeconds(5)); + + // Assert + Assert.True(received, "Should have received successful messages despite callback exception"); + Assert.True(callbackInvocations >= 3, "Callback should have been invoked for all messages"); + Assert.True(successfulMessages >= 2, "Should have processed messages successfully after exception"); + } + + [Fact] + public async Task AsyncMessageProcessing_WithServerPublish_CompletesQuicklyWithoutBlockingFFI() + { + // Arrange + string testChannel = $"async-test-{Guid.NewGuid()}"; + List callbackDurations = []; + List processingDurations = []; + int messagesProcessed = 0; + + StandalonePubSubSubscriptionConfig pubsubConfig = new StandalonePubSubSubscriptionConfig() + .WithChannel(testChannel) + .WithCallback((message, context) => + { + Stopwatch sw = Stopwatch.StartNew(); + + // Simulate some processing work + Thread.Sleep(50); // 50ms processing time + + sw.Stop(); + lock (_lockObject) + { + processingDurations.Add(sw.Elapsed); + } + + int processed = Interlocked.Increment(ref messagesProcessed); + if (processed >= 5) + { + _messageReceivedEvent.Set(); + } + }); + + var subscriberConfig = TestConfiguration.DefaultClientConfig() + .WithPubSubSubscriptions(pubsubConfig) + .Build(); + + var publisherConfig = TestConfiguration.DefaultClientConfig().Build(); + + // Act + GlideClient subscriberClient = await GlideClient.CreateClient(subscriberConfig); + _testClients.Add(subscriberClient); + + GlideClient publisherClient = await GlideClient.CreateClient(publisherConfig); + _testClients.Add(publisherClient); + + // Wait for subscription to be established + await Task.Delay(1000); + + // Measure time to publish multiple messages rapidly + Stopwatch publishStopwatch = Stopwatch.StartNew(); + + for (int i = 0; i < 5; i++) + { + await publisherClient.CustomCommand(["PUBLISH", testChannel, $"Async test message {i}"]); + } + + publishStopwatch.Stop(); + + // Wait for all messages to be processed + bool allProcessed = _messageReceivedEvent.Wait(TimeSpan.FromSeconds(10)); + + // Assert + Assert.True(allProcessed, "All messages should have been processed"); + Assert.Equal(5, messagesProcessed); + + // Publishing should complete quickly (shouldn't block on callback processing) + Assert.True(publishStopwatch.ElapsedMilliseconds < 2000, + $"Publishing should complete quickly, took {publishStopwatch.ElapsedMilliseconds}ms"); + + // Processing durations should reflect the actual work done + Assert.True(processingDurations.Count >= 5, "Should have recorded processing durations"); + Assert.All(processingDurations, duration => + Assert.True(duration.TotalMilliseconds >= 40, + $"Processing should take at least 40ms, took {duration.TotalMilliseconds}ms")); + } + + [Fact] + public async Task MemoryManagement_WithMarshaledData_HandlesCleanupCorrectly() + { + // Arrange + string testChannel = $"memory-test-{Guid.NewGuid()}"; + List receivedMessages = []; + int messageCount = 0; + + StandalonePubSubSubscriptionConfig pubsubConfig = new StandalonePubSubSubscriptionConfig() + .WithChannel(testChannel) + .WithCallback((message, context) => + { + lock (_lockObject) + { + receivedMessages.Add(message.Message); + } + + int count = Interlocked.Increment(ref messageCount); + if (count >= 10) + { + _messageReceivedEvent.Set(); + } + }); + + var subscriberConfig = TestConfiguration.DefaultClientConfig() + .WithPubSubSubscriptions(pubsubConfig) + .Build(); + + var publisherConfig = TestConfiguration.DefaultClientConfig().Build(); + + // Act + GlideClient subscriberClient = await GlideClient.CreateClient(subscriberConfig); + _testClients.Add(subscriberClient); + + GlideClient publisherClient = await GlideClient.CreateClient(publisherConfig); + _testClients.Add(publisherClient); + + // Wait for subscription to be established + await Task.Delay(1000); + + // Publish messages with various content to test marshaling through FFI + string[] testMessages = [ + "Simple message", + "Message with special chars: !@#$%^&*()", + "Unicode message: 你好世界 🌍", + "Long message: " + new string('A', 1000), + "Empty content: ", + "Numbers: 1234567890", + "JSON: {\"key\": \"value\", \"number\": 42}", + "XML: value", + "Base64: SGVsbG8gV29ybGQ=", + "Final message" + ]; + + foreach (string message in testMessages) + { + await publisherClient.CustomCommand(["PUBLISH", testChannel, message]); + await Task.Delay(10); // Small delay between messages + } + + // Wait for all messages to be processed + bool allReceived = _messageReceivedEvent.Wait(TimeSpan.FromSeconds(10)); + + // Assert + Assert.True(allReceived, "All messages should have been received"); + Assert.Equal(10, messageCount); + + lock (_lockObject) + { + Assert.Equal(10, receivedMessages.Count); + + // Verify message content integrity (marshaling worked correctly through FFI) + for (int i = 0; i < testMessages.Length; i++) + { + Assert.Contains(testMessages[i], receivedMessages); + } + } + } + + [Fact] + public async Task ErrorIsolation_WithMessageHandlerExceptions_DoesNotCrashProcess() + { + // Arrange + string testChannel = $"isolation-test-{Guid.NewGuid()}"; + int callbackCount = 0; + bool systemStillWorking = false; + + StandalonePubSubSubscriptionConfig pubsubConfig = new StandalonePubSubSubscriptionConfig() + .WithChannel(testChannel) + .WithCallback((message, context) => + { + int count = Interlocked.Increment(ref callbackCount); + + if (count == 1) + { + // First message: throw a severe exception + throw new OutOfMemoryException("Simulated severe exception"); + } + else if (count == 2) + { + // Second message: throw a different exception + throw new InvalidOperationException("Another test exception"); + } + else + { + // Third message: succeed + systemStillWorking = true; + _messageReceivedEvent.Set(); + } + }); + + var subscriberConfig = TestConfiguration.DefaultClientConfig() + .WithPubSubSubscriptions(pubsubConfig) + .Build(); + + var publisherConfig = TestConfiguration.DefaultClientConfig().Build(); + + // Act + GlideClient subscriberClient = await GlideClient.CreateClient(subscriberConfig); + _testClients.Add(subscriberClient); + + GlideClient publisherClient = await GlideClient.CreateClient(publisherConfig); + _testClients.Add(publisherClient); + + // Wait for subscription to be established + await Task.Delay(1000); + + // Publish messages that will cause exceptions in callbacks + await publisherClient.CustomCommand(["PUBLISH", testChannel, "Message 1 - OutOfMemoryException"]); + await Task.Delay(100); + + await publisherClient.CustomCommand(["PUBLISH", testChannel, "Message 2 - InvalidOperationException"]); + await Task.Delay(100); + + await publisherClient.CustomCommand(["PUBLISH", testChannel, "Message 3 - Should succeed"]); + + // Wait for the successful message + bool received = _messageReceivedEvent.Wait(TimeSpan.FromSeconds(5)); + + // Assert + Assert.True(received, "System should continue working after callback exceptions"); + Assert.True(systemStillWorking, "System should process subsequent messages successfully"); + Assert.True(callbackCount >= 3, "All callbacks should have been invoked despite exceptions"); + + // Process should still be running (not crashed) + Assert.False(Environment.HasShutdownStarted, "Process should not have initiated shutdown"); + } + + [Fact] + public async Task MultipleSubscribers_ToSameChannel_AllReceiveMessages() + { + // Arrange + string testChannel = $"multi-sub-{Guid.NewGuid()}"; + string testMessage = "Broadcast message"; + int subscriber1Received = 0; + int subscriber2Received = 0; + ManualResetEventSlim allReceivedEvent = new ManualResetEventSlim(false); + + StandalonePubSubSubscriptionConfig pubsubConfig1 = new StandalonePubSubSubscriptionConfig() + .WithChannel(testChannel) + .WithCallback((message, context) => + { + int count = Interlocked.Increment(ref subscriber1Received); + if (subscriber1Received >= 1 && subscriber2Received >= 1) + { + allReceivedEvent.Set(); + } + }); + + StandalonePubSubSubscriptionConfig pubsubConfig2 = new StandalonePubSubSubscriptionConfig() + .WithChannel(testChannel) + .WithCallback((message, context) => + { + int count = Interlocked.Increment(ref subscriber2Received); + if (subscriber1Received >= 1 && subscriber2Received >= 1) + { + allReceivedEvent.Set(); + } + }); + + var subscriberConfig1 = TestConfiguration.DefaultClientConfig() + .WithPubSubSubscriptions(pubsubConfig1) + .Build(); + + var subscriberConfig2 = TestConfiguration.DefaultClientConfig() + .WithPubSubSubscriptions(pubsubConfig2) + .Build(); + + var publisherConfig = TestConfiguration.DefaultClientConfig().Build(); + + // Act + GlideClient subscriber1 = await GlideClient.CreateClient(subscriberConfig1); + _testClients.Add(subscriber1); + + GlideClient subscriber2 = await GlideClient.CreateClient(subscriberConfig2); + _testClients.Add(subscriber2); + + GlideClient publisher = await GlideClient.CreateClient(publisherConfig); + _testClients.Add(publisher); + + // Wait for subscriptions to be established + await Task.Delay(1000); + + // Publish message - should reach both subscribers + object? publishResult = await publisher.CustomCommand(["PUBLISH", testChannel, testMessage]); + long numReceivers = Convert.ToInt64(publishResult); + Assert.Equal(2L, numReceivers); // Should have 2 subscribers + + // Wait for both subscribers to receive + bool allReceived = allReceivedEvent.Wait(TimeSpan.FromSeconds(5)); + + // Assert + Assert.True(allReceived, "Both subscribers should receive the message"); + Assert.Equal(1, subscriber1Received); + Assert.Equal(1, subscriber2Received); + + allReceivedEvent.Dispose(); + } + + [Fact] + public async Task MessageOrdering_WithMultipleMessages_PreservesOrder() + { + // Arrange + string testChannel = $"order-test-{Guid.NewGuid()}"; + List receivedMessages = []; + int expectedCount = 20; + ManualResetEventSlim allReceivedEvent = new ManualResetEventSlim(false); + + StandalonePubSubSubscriptionConfig pubsubConfig = new StandalonePubSubSubscriptionConfig() + .WithChannel(testChannel) + .WithCallback((message, context) => + { + lock (receivedMessages) + { + receivedMessages.Add(message.Message); + if (receivedMessages.Count >= expectedCount) + { + allReceivedEvent.Set(); + } + } + }); + + var subscriberConfig = TestConfiguration.DefaultClientConfig() + .WithPubSubSubscriptions(pubsubConfig) + .Build(); + + var publisherConfig = TestConfiguration.DefaultClientConfig().Build(); + + // Act + GlideClient subscriber = await GlideClient.CreateClient(subscriberConfig); + _testClients.Add(subscriber); + + GlideClient publisher = await GlideClient.CreateClient(publisherConfig); + _testClients.Add(publisher); + + // Wait for subscription to be established + await Task.Delay(1000); + + // Publish messages in order + for (int i = 0; i < expectedCount; i++) + { + await publisher.CustomCommand(["PUBLISH", testChannel, $"Message-{i:D3}"]); + } + + // Wait for all messages + bool allReceived = allReceivedEvent.Wait(TimeSpan.FromSeconds(10)); + + // Assert + Assert.True(allReceived, "All messages should be received"); + Assert.Equal(expectedCount, receivedMessages.Count); + + // Verify order is preserved + for (int i = 0; i < expectedCount; i++) + { + Assert.Equal($"Message-{i:D3}", receivedMessages[i]); + } + + allReceivedEvent.Dispose(); + } + + [Fact] + public async Task ClusterPatternSubscription_WithServerPublish_ReceivesMatchingMessages() + { + // Skip if no cluster hosts available + if (TestConfiguration.CLUSTER_HOSTS.Count == 0) + { + return; + } + + // Arrange + string testPattern = "events.*"; + string testChannel = $"events.user.{Guid.NewGuid()}"; + string testMessage = "User event occurred"; + bool messageReceived = false; + PubSubMessage? receivedMessage = null; + + ClusterPubSubSubscriptionConfig pubsubConfig = new ClusterPubSubSubscriptionConfig() + .WithPattern(testPattern) + .WithCallback((message, context) => + { + receivedMessage = message; + messageReceived = true; + _messageReceivedEvent.Set(); + }); + + var subscriberConfig = TestConfiguration.DefaultClusterClientConfig() + .WithPubSubSubscriptions(pubsubConfig) + .Build(); + + var publisherConfig = TestConfiguration.DefaultClusterClientConfig().Build(); + + // Act + GlideClusterClient subscriber = await GlideClusterClient.CreateClient(subscriberConfig); + _testClients.Add(subscriber); + + GlideClusterClient publisher = await GlideClusterClient.CreateClient(publisherConfig); + _testClients.Add(publisher); + + // Wait for subscription to be established - pattern subscriptions in cluster mode may need more time + await Task.Delay(5000); + + // Publish to matching channel + ClusterValue publishResult = await publisher.CustomCommand(["PUBLISH", testChannel, testMessage]); + long numReceivers = Convert.ToInt64(publishResult.SingleValue); + + // Note: In cluster mode, PUBLISH returns the number of clients that received the message on the node + // where the channel is hashed. Pattern subscriptions may not always report correctly via PUBLISH return value + // in cluster mode because the subscription might be on a different node. We verify message delivery via callback. + + // Wait for message + bool received = _messageReceivedEvent.Wait(TimeSpan.FromSeconds(5)); + + // Assert + Assert.True(received, "Pattern message should be received"); + Assert.True(messageReceived); + Assert.NotNull(receivedMessage); + Assert.Equal(testMessage, receivedMessage.Message); + Assert.Equal(testChannel, receivedMessage.Channel); + Assert.Equal(testPattern, receivedMessage.Pattern); + } +} diff --git a/tests/Valkey.Glide.IntegrationTests/PubSubQueueIntegrationTests.cs b/tests/Valkey.Glide.IntegrationTests/PubSubQueueIntegrationTests.cs new file mode 100644 index 00000000..5643059b --- /dev/null +++ b/tests/Valkey.Glide.IntegrationTests/PubSubQueueIntegrationTests.cs @@ -0,0 +1,510 @@ +// Copyright Valkey GLIDE Project Contributors - SPDX Identifier: Apache-2.0 + +namespace Valkey.Glide.IntegrationTests; + +/// +/// End-to-end integration tests for PubSub queue-based message retrieval. +/// These tests verify the complete message flow from PUBLISH commands through the server, +/// Rust core, FFI boundary, and into C# message queues (without callbacks). +/// Tests the alternative PubSub usage pattern where users manually poll for messages. +/// +public class PubSubQueueIntegrationTests : IDisposable +{ + private readonly List _testClients = []; + + public void Dispose() + { + GC.SuppressFinalize(this); + foreach (BaseClient client in _testClients) + { + try + { + client.Dispose(); + } + catch + { + // Ignore disposal errors in tests + } + } + _testClients.Clear(); + } + + [Fact] + public async Task QueueBasedRetrieval_WithStandaloneClient_ReceivesMessages() + { + // Arrange + string testChannel = $"queue-test-{Guid.NewGuid()}"; + string testMessage = "Hello from queue test!"; + + // Create subscription config WITHOUT callback - messages go to queue + StandalonePubSubSubscriptionConfig pubsubConfig = new StandalonePubSubSubscriptionConfig() + .WithChannel(testChannel); + // Note: No .WithCallback() - this enables queue mode + + var subscriberConfig = TestConfiguration.DefaultClientConfig() + .WithPubSubSubscriptions(pubsubConfig) + .Build(); + + var publisherConfig = TestConfiguration.DefaultClientConfig().Build(); + + // Act + GlideClient subscriberClient = await GlideClient.CreateClient(subscriberConfig); + _testClients.Add(subscriberClient); + + GlideClient publisherClient = await GlideClient.CreateClient(publisherConfig); + _testClients.Add(publisherClient); + + // Wait for subscription to be established + await Task.Delay(1000); + + // Publish message through the server + object? publishResult = await publisherClient.CustomCommand(["PUBLISH", testChannel, testMessage]); + long numReceivers = Convert.ToInt64(publishResult); + Assert.Equal(1L, numReceivers); + + // Get the message queue and retrieve message + PubSubMessageQueue? queue = subscriberClient.PubSubQueue; + Assert.NotNull(queue); + + // Wait for message to arrive in queue + await Task.Delay(500); + + // Assert + Assert.True(queue.Count > 0, "Queue should contain at least one message"); + bool hasMessage = queue.TryGetMessage(out PubSubMessage? receivedMessage); + Assert.True(hasMessage, "Should successfully retrieve message from queue"); + Assert.NotNull(receivedMessage); + Assert.Equal(testMessage, receivedMessage.Message); + Assert.Equal(testChannel, receivedMessage.Channel); + Assert.Null(receivedMessage.Pattern); + } + + [Fact] + public async Task QueueBasedRetrieval_WithClusterClient_ReceivesMessages() + { + // Skip if no cluster hosts available + if (TestConfiguration.CLUSTER_HOSTS.Count == 0) + { + return; + } + + // Arrange + string testChannel = $"cluster-queue-{Guid.NewGuid()}"; + string testMessage = "Cluster queue message"; + + ClusterPubSubSubscriptionConfig pubsubConfig = new ClusterPubSubSubscriptionConfig() + .WithChannel(testChannel); + // No callback - queue mode + + var subscriberConfig = TestConfiguration.DefaultClusterClientConfig() + .WithPubSubSubscriptions(pubsubConfig) + .Build(); + + var publisherConfig = TestConfiguration.DefaultClusterClientConfig().Build(); + + // Act + GlideClusterClient subscriberClient = await GlideClusterClient.CreateClient(subscriberConfig); + _testClients.Add(subscriberClient); + + GlideClusterClient publisherClient = await GlideClusterClient.CreateClient(publisherConfig); + _testClients.Add(publisherClient); + + await Task.Delay(1000); + + ClusterValue publishResult = await publisherClient.CustomCommand(["PUBLISH", testChannel, testMessage]); + long numReceivers = Convert.ToInt64(publishResult.SingleValue); + Assert.Equal(1L, numReceivers); + + PubSubMessageQueue? queue = subscriberClient.PubSubQueue; + Assert.NotNull(queue); + await Task.Delay(500); + + // Assert + Assert.True(queue.Count > 0); + bool hasMessage = queue.TryGetMessage(out PubSubMessage? receivedMessage); + Assert.True(hasMessage); + Assert.NotNull(receivedMessage); + Assert.Equal(testMessage, receivedMessage.Message); + Assert.Equal(testChannel, receivedMessage.Channel); + } + + [Fact] + public async Task QueueBasedRetrieval_WithPatternSubscription_ReceivesMatchingMessages() + { + // Arrange + string testPattern = "queue.pattern.*"; + string testChannel = $"queue.pattern.{Guid.NewGuid()}"; + string testMessage = "Pattern queue message"; + + StandalonePubSubSubscriptionConfig pubsubConfig = new StandalonePubSubSubscriptionConfig() + .WithPattern(testPattern); + + var subscriberConfig = TestConfiguration.DefaultClientConfig() + .WithPubSubSubscriptions(pubsubConfig) + .Build(); + + var publisherConfig = TestConfiguration.DefaultClientConfig().Build(); + + // Act + GlideClient subscriberClient = await GlideClient.CreateClient(subscriberConfig); + _testClients.Add(subscriberClient); + + GlideClient publisherClient = await GlideClient.CreateClient(publisherConfig); + _testClients.Add(publisherClient); + + await Task.Delay(1000); + + await publisherClient.CustomCommand(["PUBLISH", testChannel, testMessage]); + await Task.Delay(500); + + PubSubMessageQueue? queue = subscriberClient.PubSubQueue; + Assert.NotNull(queue); + + // Assert + Assert.True(queue.Count > 0); + bool hasMessage = queue.TryGetMessage(out PubSubMessage? receivedMessage); + Assert.True(hasMessage); + Assert.NotNull(receivedMessage); + Assert.Equal(testMessage, receivedMessage.Message); + Assert.Equal(testChannel, receivedMessage.Channel); + Assert.Equal(testPattern, receivedMessage.Pattern); + } + + [Fact] + public async Task QueueBasedRetrieval_WithMultipleMessages_PreservesOrder() + { + // Arrange + string testChannel = $"queue-order-{Guid.NewGuid()}"; + int messageCount = 10; + + StandalonePubSubSubscriptionConfig pubsubConfig = new StandalonePubSubSubscriptionConfig() + .WithChannel(testChannel); + + var subscriberConfig = TestConfiguration.DefaultClientConfig() + .WithPubSubSubscriptions(pubsubConfig) + .Build(); + + var publisherConfig = TestConfiguration.DefaultClientConfig().Build(); + + // Act + GlideClient subscriberClient = await GlideClient.CreateClient(subscriberConfig); + _testClients.Add(subscriberClient); + + GlideClient publisherClient = await GlideClient.CreateClient(publisherConfig); + _testClients.Add(publisherClient); + + await Task.Delay(1000); + + // Publish multiple messages in order + for (int i = 0; i < messageCount; i++) + { + await publisherClient.CustomCommand(["PUBLISH", testChannel, $"Message-{i:D3}"]); + } + + // Wait for all messages to arrive + await Task.Delay(1000); + + PubSubMessageQueue? queue = subscriberClient.PubSubQueue; + Assert.NotNull(queue); + + // Assert + Assert.True(queue.Count >= messageCount, $"Queue should contain at least {messageCount} messages"); + + List receivedMessages = []; + for (int i = 0; i < messageCount; i++) + { + bool hasMessage = queue.TryGetMessage(out PubSubMessage? message); + Assert.True(hasMessage, $"Should retrieve message {i}"); + Assert.NotNull(message); + receivedMessages.Add(message.Message); + } + + // Verify order is preserved + for (int i = 0; i < messageCount; i++) + { + Assert.Equal($"Message-{i:D3}", receivedMessages[i]); + } + } + + [Fact] + public async Task QueueBasedRetrieval_GetMessageAsync_BlocksAndSupportsCancellation() + { + // Arrange + string testChannel = $"queue-async-{Guid.NewGuid()}"; + string testMessage = "Async message"; + + StandalonePubSubSubscriptionConfig pubsubConfig = new StandalonePubSubSubscriptionConfig() + .WithChannel(testChannel); + + var subscriberConfig = TestConfiguration.DefaultClientConfig() + .WithPubSubSubscriptions(pubsubConfig) + .Build(); + + var publisherConfig = TestConfiguration.DefaultClientConfig().Build(); + + // Act + GlideClient subscriberClient = await GlideClient.CreateClient(subscriberConfig); + _testClients.Add(subscriberClient); + + GlideClient publisherClient = await GlideClient.CreateClient(publisherConfig); + _testClients.Add(publisherClient); + + await Task.Delay(1000); + + PubSubMessageQueue? queue = subscriberClient.PubSubQueue; + Assert.NotNull(queue); + + // Test 1: Verify blocking behavior + Task getMessageTask = Task.Run(async () => + { + using CancellationTokenSource cts = new(TimeSpan.FromSeconds(10)); + return await queue.GetMessageAsync(cts.Token); + }); + + await Task.Delay(100); + Assert.False(getMessageTask.IsCompleted, "GetMessageAsync should be waiting for message"); + + await publisherClient.CustomCommand(["PUBLISH", testChannel, testMessage]); + PubSubMessage receivedMessage = await getMessageTask; + + Assert.NotNull(receivedMessage); + Assert.Equal(testMessage, receivedMessage.Message); + Assert.Equal(testChannel, receivedMessage.Channel); + + // Test 2: Verify cancellation support + using CancellationTokenSource cts2 = new(TimeSpan.FromMilliseconds(500)); + await Assert.ThrowsAsync(async () => + { + await queue.GetMessageAsync(cts2.Token); + }); + } + + [Fact] + public async Task QueueBasedRetrieval_TryGetMessage_ReturnsFalseWhenEmpty() + { + // Arrange + string testChannel = $"queue-empty-{Guid.NewGuid()}"; + + StandalonePubSubSubscriptionConfig pubsubConfig = new StandalonePubSubSubscriptionConfig() + .WithChannel(testChannel); + + var subscriberConfig = TestConfiguration.DefaultClientConfig() + .WithPubSubSubscriptions(pubsubConfig) + .Build(); + + // Act + GlideClient subscriberClient = await GlideClient.CreateClient(subscriberConfig); + _testClients.Add(subscriberClient); + + await Task.Delay(1000); + + PubSubMessageQueue? queue = subscriberClient.PubSubQueue; + Assert.NotNull(queue); + + // Assert + Assert.Equal(0, queue.Count); + bool hasMessage = queue.TryGetMessage(out PubSubMessage? message); + Assert.False(hasMessage); + Assert.Null(message); + } + + [Fact] + public async Task QueueBasedRetrieval_WithMultipleChannels_ReceivesAllMessages() + { + // Arrange + string channel1 = $"queue-multi-1-{Guid.NewGuid()}"; + string channel2 = $"queue-multi-2-{Guid.NewGuid()}"; + string message1 = "Message from channel 1"; + string message2 = "Message from channel 2"; + + StandalonePubSubSubscriptionConfig pubsubConfig = new StandalonePubSubSubscriptionConfig() + .WithChannel(channel1) + .WithChannel(channel2); + + var subscriberConfig = TestConfiguration.DefaultClientConfig() + .WithPubSubSubscriptions(pubsubConfig) + .Build(); + + var publisherConfig = TestConfiguration.DefaultClientConfig().Build(); + + // Act + GlideClient subscriberClient = await GlideClient.CreateClient(subscriberConfig); + _testClients.Add(subscriberClient); + + GlideClient publisherClient = await GlideClient.CreateClient(publisherConfig); + _testClients.Add(publisherClient); + + await Task.Delay(1000); + + // Publish to both channels + await publisherClient.CustomCommand(["PUBLISH", channel1, message1]); + await publisherClient.CustomCommand(["PUBLISH", channel2, message2]); + + await Task.Delay(500); + + PubSubMessageQueue? queue = subscriberClient.PubSubQueue; + Assert.NotNull(queue); + + // Assert + Assert.True(queue.Count >= 2, "Queue should contain messages from both channels"); + + HashSet receivedChannels = []; + HashSet receivedMessages = []; + + for (int i = 0; i < 2; i++) + { + bool hasMessage = queue.TryGetMessage(out PubSubMessage? message); + Assert.True(hasMessage); + Assert.NotNull(message); + receivedChannels.Add(message.Channel); + receivedMessages.Add(message.Message); + } + + Assert.Contains(channel1, receivedChannels); + Assert.Contains(channel2, receivedChannels); + Assert.Contains(message1, receivedMessages); + Assert.Contains(message2, receivedMessages); + } + + [Fact] + public async Task QueueBasedRetrieval_WithHighVolume_HandlesAllMessages() + { + // Arrange + string testChannel = $"queue-volume-{Guid.NewGuid()}"; + int messageCount = 100; + + StandalonePubSubSubscriptionConfig pubsubConfig = new StandalonePubSubSubscriptionConfig() + .WithChannel(testChannel); + + var subscriberConfig = TestConfiguration.DefaultClientConfig() + .WithPubSubSubscriptions(pubsubConfig) + .Build(); + + var publisherConfig = TestConfiguration.DefaultClientConfig().Build(); + + // Act + GlideClient subscriberClient = await GlideClient.CreateClient(subscriberConfig); + _testClients.Add(subscriberClient); + + GlideClient publisherClient = await GlideClient.CreateClient(publisherConfig); + _testClients.Add(publisherClient); + + await Task.Delay(1000); + + // Publish many messages rapidly + for (int i = 0; i < messageCount; i++) + { + await publisherClient.CustomCommand(["PUBLISH", testChannel, $"Volume-{i}"]); + } + + // Wait for messages to arrive + await Task.Delay(2000); + + PubSubMessageQueue? queue = subscriberClient.PubSubQueue; + Assert.NotNull(queue); + + // Assert + Assert.True(queue.Count >= messageCount, $"Queue should contain all {messageCount} messages"); + + HashSet receivedMessages = []; + for (int i = 0; i < messageCount; i++) + { + bool hasMessage = queue.TryGetMessage(out PubSubMessage? message); + Assert.True(hasMessage, $"Should retrieve message {i}"); + Assert.NotNull(message); + receivedMessages.Add(message.Message); + } + + Assert.Equal(messageCount, receivedMessages.Count); + } + + [Fact] + public async Task QueueBasedRetrieval_MixedCallbackAndQueue_ThrowsInvalidOperationException() + { + // Arrange + string testChannel = $"queue-mixed-{Guid.NewGuid()}"; + + // Create config WITH callback + StandalonePubSubSubscriptionConfig pubsubConfig = new StandalonePubSubSubscriptionConfig() + .WithChannel(testChannel) + .WithCallback((message, context) => + { + // Callback mode + }); + + var subscriberConfig = TestConfiguration.DefaultClientConfig() + .WithPubSubSubscriptions(pubsubConfig) + .Build(); + + // Act + GlideClient subscriberClient = await GlideClient.CreateClient(subscriberConfig); + _testClients.Add(subscriberClient); + + await Task.Delay(1000); + + // Assert - Should throw when trying to get queue in callback mode + Assert.Throws(() => + { + PubSubMessageQueue? queue = subscriberClient.PubSubQueue; + }); + } + + [Fact] + public async Task QueueBasedRetrieval_WithUnicodeAndSpecialCharacters_PreservesContent() + { + // Arrange + string testChannel = $"queue-unicode-{Guid.NewGuid()}"; + string[] testMessages = [ + "Simple ASCII", + "Unicode: 你好世界 🌍", + "Special chars: !@#$%^&*()", + "Emoji: 🎉🚀💻", + "Mixed: Hello世界!🌟" + ]; + + StandalonePubSubSubscriptionConfig pubsubConfig = new StandalonePubSubSubscriptionConfig() + .WithChannel(testChannel); + + var subscriberConfig = TestConfiguration.DefaultClientConfig() + .WithPubSubSubscriptions(pubsubConfig) + .Build(); + + var publisherConfig = TestConfiguration.DefaultClientConfig().Build(); + + // Act + GlideClient subscriberClient = await GlideClient.CreateClient(subscriberConfig); + _testClients.Add(subscriberClient); + + GlideClient publisherClient = await GlideClient.CreateClient(publisherConfig); + _testClients.Add(publisherClient); + + await Task.Delay(1000); + + foreach (string message in testMessages) + { + await publisherClient.CustomCommand(["PUBLISH", testChannel, message]); + } + + await Task.Delay(1000); + + PubSubMessageQueue? queue = subscriberClient.PubSubQueue; + Assert.NotNull(queue); + + // Assert + Assert.True(queue.Count >= testMessages.Length); + + List receivedMessages = []; + for (int i = 0; i < testMessages.Length; i++) + { + bool hasMessage = queue.TryGetMessage(out PubSubMessage? message); + Assert.True(hasMessage); + Assert.NotNull(message); + receivedMessages.Add(message.Message); + } + + foreach (string expectedMessage in testMessages) + { + Assert.Contains(expectedMessage, receivedMessages); + } + } +} diff --git a/tests/Valkey.Glide.UnitTests/GlobalSuppressions.cs b/tests/Valkey.Glide.UnitTests/GlobalSuppressions.cs new file mode 100644 index 00000000..021af516 --- /dev/null +++ b/tests/Valkey.Glide.UnitTests/GlobalSuppressions.cs @@ -0,0 +1,5 @@ +// Copyright Valkey GLIDE Project Contributors - SPDX Identifier: Apache-2.0 + +using System.Diagnostics.CodeAnalysis; + +[assembly: SuppressMessage("Style", "IDE0130:Namespace does not match folder structure", Justification = "", Scope = "namespace", Target = "~N:Valkey.Glide.UnitTests")] diff --git a/tests/Valkey.Glide.UnitTests/PubSubConfigurationTests.cs b/tests/Valkey.Glide.UnitTests/PubSubConfigurationTests.cs new file mode 100644 index 00000000..31b8aae0 --- /dev/null +++ b/tests/Valkey.Glide.UnitTests/PubSubConfigurationTests.cs @@ -0,0 +1,308 @@ +// Copyright Valkey GLIDE Project Contributors - SPDX Identifier: Apache-2.0 + +using System; +using System.Collections.Generic; + +using Xunit; + +using static Valkey.Glide.ConnectionConfiguration; + +namespace Valkey.Glide.UnitTests; + +/// +/// Unit tests for PubSub configuration extensions to ConnectionConfiguration builders. +/// These tests verify that PubSub subscription configuration flows correctly through +/// the configuration builders and validation works as expected. +/// +public class PubSubConfigurationTests +{ + #region StandaloneClientConfigurationBuilder Tests + + [Fact] + public void StandaloneClientConfigurationBuilder_WithPubSubSubscriptions_ValidConfig_SetsConfiguration() + { + // Arrange + var pubSubConfig = new StandalonePubSubSubscriptionConfig() + .WithChannel("test-channel") + .WithPattern("test-*"); + + // Act + var builder = new StandaloneClientConfigurationBuilder() + .WithAddress("localhost", 6379) + .WithPubSubSubscriptions(pubSubConfig); + + var config = builder.Build(); + + // Assert + Assert.NotNull(config.Request.PubSubSubscriptions); + Assert.Same(pubSubConfig, config.Request.PubSubSubscriptions); + } + + [Fact] + public void StandaloneClientConfigurationBuilder_WithPubSubSubscriptions_NullConfig_ThrowsArgumentNullException() + { + // Arrange + var builder = new StandaloneClientConfigurationBuilder() + .WithAddress("localhost", 6379); + + // Act & Assert + Assert.Throws(() => builder.WithPubSubSubscriptions(null!)); + } + + [Fact] + public void StandaloneClientConfigurationBuilder_WithPubSubSubscriptions_InvalidConfig_ThrowsArgumentException() + { + // Arrange + var pubSubConfig = new StandalonePubSubSubscriptionConfig(); // Empty config - invalid + + var builder = new StandaloneClientConfigurationBuilder() + .WithAddress("localhost", 6379); + + // Act & Assert + Assert.Throws(() => builder.WithPubSubSubscriptions(pubSubConfig)); + } + + [Fact] + public void StandaloneClientConfigurationBuilder_WithPubSubSubscriptions_WithCallback_SetsCallbackAndContext() + { + // Arrange + var context = new { TestData = "test" }; + MessageCallback callback = (message, ctx) => { /* test callback */ }; + + var pubSubConfig = new StandalonePubSubSubscriptionConfig() + .WithChannel("test-channel") + .WithCallback(callback, context); + + // Act + var builder = new StandaloneClientConfigurationBuilder() + .WithAddress("localhost", 6379) + .WithPubSubSubscriptions(pubSubConfig); + + var config = builder.Build(); + + // Assert + Assert.NotNull(config.Request.PubSubSubscriptions); + var storedConfig = config.Request.PubSubSubscriptions as StandalonePubSubSubscriptionConfig; + Assert.NotNull(storedConfig); + Assert.Same(callback, storedConfig.Callback); + Assert.Same(context, storedConfig.Context); + } + + [Fact] + public void StandaloneClientConfigurationBuilder_WithPubSubSubscriptions_MultipleChannelsAndPatterns_SetsAllSubscriptions() + { + // Arrange + var pubSubConfig = new StandalonePubSubSubscriptionConfig() + .WithChannel("channel1") + .WithChannel("channel2") + .WithPattern("pattern1*") + .WithPattern("pattern2*"); + + // Act + var builder = new StandaloneClientConfigurationBuilder() + .WithAddress("localhost", 6379) + .WithPubSubSubscriptions(pubSubConfig); + + var config = builder.Build(); + + // Assert + Assert.NotNull(config.Request.PubSubSubscriptions); + var storedConfig = config.Request.PubSubSubscriptions as StandalonePubSubSubscriptionConfig; + Assert.NotNull(storedConfig); + + // Check exact channels (mode 0) + Assert.True(storedConfig.Subscriptions.ContainsKey(0)); + Assert.Contains("channel1", storedConfig.Subscriptions[0]); + Assert.Contains("channel2", storedConfig.Subscriptions[0]); + + // Check patterns (mode 1) + Assert.True(storedConfig.Subscriptions.ContainsKey(1)); + Assert.Contains("pattern1*", storedConfig.Subscriptions[1]); + Assert.Contains("pattern2*", storedConfig.Subscriptions[1]); + } + + #endregion + + #region ClusterClientConfigurationBuilder Tests + + [Fact] + public void ClusterClientConfigurationBuilder_WithPubSubSubscriptions_ValidConfig_SetsConfiguration() + { + // Arrange + var pubSubConfig = new ClusterPubSubSubscriptionConfig() + .WithChannel("test-channel") + .WithPattern("test-*") + .WithShardedChannel("shard-channel"); + + // Act + var builder = new ClusterClientConfigurationBuilder() + .WithAddress("localhost", 6379) + .WithPubSubSubscriptions(pubSubConfig); + + var config = builder.Build(); + + // Assert + Assert.NotNull(config.Request.PubSubSubscriptions); + Assert.Same(pubSubConfig, config.Request.PubSubSubscriptions); + } + + [Fact] + public void ClusterClientConfigurationBuilder_WithPubSubSubscriptions_NullConfig_ThrowsArgumentNullException() + { + // Arrange + var builder = new ClusterClientConfigurationBuilder() + .WithAddress("localhost", 6379); + + // Act & Assert + Assert.Throws(() => builder.WithPubSubSubscriptions(null!)); + } + + [Fact] + public void ClusterClientConfigurationBuilder_WithPubSubSubscriptions_InvalidConfig_ThrowsArgumentException() + { + // Arrange + var pubSubConfig = new ClusterPubSubSubscriptionConfig(); // Empty config - invalid + + var builder = new ClusterClientConfigurationBuilder() + .WithAddress("localhost", 6379); + + // Act & Assert + Assert.Throws(() => builder.WithPubSubSubscriptions(pubSubConfig)); + } + + [Fact] + public void ClusterClientConfigurationBuilder_WithPubSubSubscriptions_WithCallback_SetsCallbackAndContext() + { + // Arrange + var context = new { TestData = "test" }; + MessageCallback callback = (message, ctx) => { /* test callback */ }; + + var pubSubConfig = new ClusterPubSubSubscriptionConfig() + .WithChannel("test-channel") + .WithCallback(callback, context); + + // Act + var builder = new ClusterClientConfigurationBuilder() + .WithAddress("localhost", 6379) + .WithPubSubSubscriptions(pubSubConfig); + + var config = builder.Build(); + + // Assert + Assert.NotNull(config.Request.PubSubSubscriptions); + var storedConfig = config.Request.PubSubSubscriptions as ClusterPubSubSubscriptionConfig; + Assert.NotNull(storedConfig); + Assert.Same(callback, storedConfig.Callback); + Assert.Same(context, storedConfig.Context); + } + + [Fact] + public void ClusterClientConfigurationBuilder_WithPubSubSubscriptions_AllSubscriptionTypes_SetsAllSubscriptions() + { + // Arrange + var pubSubConfig = new ClusterPubSubSubscriptionConfig() + .WithChannel("channel1") + .WithChannel("channel2") + .WithPattern("pattern1*") + .WithPattern("pattern2*") + .WithShardedChannel("shard1") + .WithShardedChannel("shard2"); + + // Act + var builder = new ClusterClientConfigurationBuilder() + .WithAddress("localhost", 6379) + .WithPubSubSubscriptions(pubSubConfig); + + var config = builder.Build(); + + // Assert + Assert.NotNull(config.Request.PubSubSubscriptions); + var storedConfig = config.Request.PubSubSubscriptions as ClusterPubSubSubscriptionConfig; + Assert.NotNull(storedConfig); + + // Check exact channels (mode 0) + Assert.True(storedConfig.Subscriptions.ContainsKey(0)); + Assert.Contains("channel1", storedConfig.Subscriptions[0]); + Assert.Contains("channel2", storedConfig.Subscriptions[0]); + + // Check patterns (mode 1) + Assert.True(storedConfig.Subscriptions.ContainsKey(1)); + Assert.Contains("pattern1*", storedConfig.Subscriptions[1]); + Assert.Contains("pattern2*", storedConfig.Subscriptions[1]); + + // Check sharded channels (mode 2) + Assert.True(storedConfig.Subscriptions.ContainsKey(2)); + Assert.Contains("shard1", storedConfig.Subscriptions[2]); + Assert.Contains("shard2", storedConfig.Subscriptions[2]); + } + + #endregion + + #region Validation Tests + + [Fact] + public void StandaloneClientConfigurationBuilder_WithPubSubSubscriptions_ValidatesConfigurationDuringBuild() + { + // Arrange + var pubSubConfig = new StandalonePubSubSubscriptionConfig() + .WithChannel("test-channel"); + + // Act & Assert - Should not throw + var builder = new StandaloneClientConfigurationBuilder() + .WithAddress("localhost", 6379) + .WithPubSubSubscriptions(pubSubConfig); + + var config = builder.Build(); // This should succeed + Assert.NotNull(config); + } + + [Fact] + public void ClusterClientConfigurationBuilder_WithPubSubSubscriptions_ValidatesConfigurationDuringBuild() + { + // Arrange + var pubSubConfig = new ClusterPubSubSubscriptionConfig() + .WithShardedChannel("shard-channel"); + + // Act & Assert - Should not throw + var builder = new ClusterClientConfigurationBuilder() + .WithAddress("localhost", 6379) + .WithPubSubSubscriptions(pubSubConfig); + + var config = builder.Build(); // This should succeed + Assert.NotNull(config); + } + + [Fact] + public void StandaloneClientConfigurationBuilder_WithPubSubSubscriptions_EmptyChannelName_ThrowsArgumentException() + { + // Arrange + var builder = new StandaloneClientConfigurationBuilder() + .WithAddress("localhost", 6379); + + // Act & Assert + Assert.Throws(() => + { + var pubSubConfig = new StandalonePubSubSubscriptionConfig() + .WithChannel(""); // Empty channel name should be invalid + builder.WithPubSubSubscriptions(pubSubConfig); + }); + } + + [Fact] + public void ClusterClientConfigurationBuilder_WithPubSubSubscriptions_EmptyShardedChannelName_ThrowsArgumentException() + { + // Arrange + var builder = new ClusterClientConfigurationBuilder() + .WithAddress("localhost", 6379); + + // Act & Assert + Assert.Throws(() => + { + var pubSubConfig = new ClusterPubSubSubscriptionConfig() + .WithShardedChannel(""); // Empty sharded channel name should be invalid + builder.WithPubSubSubscriptions(pubSubConfig); + }); + } + + #endregion +} diff --git a/tests/Valkey.Glide.UnitTests/PubSubFFIIntegrationTests.cs b/tests/Valkey.Glide.UnitTests/PubSubFFIIntegrationTests.cs new file mode 100644 index 00000000..3eeb3496 --- /dev/null +++ b/tests/Valkey.Glide.UnitTests/PubSubFFIIntegrationTests.cs @@ -0,0 +1,221 @@ +// Copyright Valkey GLIDE Project Contributors - SPDX Identifier: Apache-2.0 + +using System; +using System.Runtime.InteropServices; + +using Valkey.Glide.Internals; + +using Xunit; + +namespace Valkey.Glide.UnitTests; + +public class PubSubFFIIntegrationTests +{ + [Fact] + public void MarshalPubSubMessage_WithValidExactChannelMessage_ReturnsCorrectMessage() + { + // Arrange + string message = "test message"; + string channel = "test-channel"; + + IntPtr messagePtr = Marshal.StringToHGlobalAnsi(message); + IntPtr channelPtr = Marshal.StringToHGlobalAnsi(channel); + + try + { + // Act + PubSubMessage result = FFI.MarshalPubSubMessage( + FFI.PushKind.PushMessage, + messagePtr, + message.Length, + channelPtr, + channel.Length, + IntPtr.Zero, + 0); + + // Assert + Assert.Equal("test message", result.Message); + Assert.Equal("test-channel", result.Channel); + Assert.Null(result.Pattern); + } + finally + { + Marshal.FreeHGlobal(messagePtr); + Marshal.FreeHGlobal(channelPtr); + } + } + + [Fact] + public void MarshalPubSubMessage_WithValidPatternMessage_ReturnsCorrectMessage() + { + // Arrange + string message = "pattern message"; + string channel = "news.sports"; + string pattern = "news.*"; + + IntPtr messagePtr = Marshal.StringToHGlobalAnsi(message); + IntPtr channelPtr = Marshal.StringToHGlobalAnsi(channel); + IntPtr patternPtr = Marshal.StringToHGlobalAnsi(pattern); + + try + { + // Act + PubSubMessage result = FFI.MarshalPubSubMessage( + FFI.PushKind.PushPMessage, + messagePtr, + message.Length, + channelPtr, + channel.Length, + patternPtr, + pattern.Length); + + // Assert + Assert.Equal("pattern message", result.Message); + Assert.Equal("news.sports", result.Channel); + Assert.Equal("news.*", result.Pattern); + } + finally + { + Marshal.FreeHGlobal(messagePtr); + Marshal.FreeHGlobal(channelPtr); + Marshal.FreeHGlobal(patternPtr); + } + } + + [Fact] + public void MarshalPubSubMessage_WithNullMessagePointer_ThrowsArgumentException() + { + // Arrange + string channel = "test-channel"; + IntPtr channelPtr = Marshal.StringToHGlobalAnsi(channel); + + try + { + // Act & Assert + ArgumentException ex = Assert.Throws(() => + FFI.MarshalPubSubMessage( + FFI.PushKind.PushMessage, + IntPtr.Zero, + 0, + channelPtr, + channel.Length, + IntPtr.Zero, + 0)); + Assert.Contains("Invalid message data", ex.Message); + } + finally + { + Marshal.FreeHGlobal(channelPtr); + } + } + + [Fact] + public void MarshalPubSubMessage_WithEmptyMessage_ThrowsArgumentException() + { + // Arrange + string message = ""; + string channel = "test-channel"; + + IntPtr messagePtr = Marshal.StringToHGlobalAnsi(message); + IntPtr channelPtr = Marshal.StringToHGlobalAnsi(channel); + + try + { + // Act & Assert + ArgumentException ex = Assert.Throws(() => + FFI.MarshalPubSubMessage( + FFI.PushKind.PushMessage, + messagePtr, + message.Length, + channelPtr, + channel.Length, + IntPtr.Zero, + 0)); + Assert.Contains("PubSub message content cannot be null or empty after marshaling", ex.Message); + } + finally + { + Marshal.FreeHGlobal(messagePtr); + Marshal.FreeHGlobal(channelPtr); + } + } + + [Fact] + public void MarshalPubSubMessage_WithEmptyChannel_ThrowsArgumentException() + { + // Arrange + string message = "test message"; + string channel = ""; + + IntPtr messagePtr = Marshal.StringToHGlobalAnsi(message); + IntPtr channelPtr = Marshal.StringToHGlobalAnsi(channel); + + try + { + // Act & Assert + ArgumentException ex = Assert.Throws(() => + FFI.MarshalPubSubMessage( + FFI.PushKind.PushMessage, + messagePtr, + message.Length, + channelPtr, + channel.Length, + IntPtr.Zero, + 0)); + Assert.Contains("Invalid channel data: pointer is null or length is zero", ex.Message); + } + finally + { + Marshal.FreeHGlobal(messagePtr); + Marshal.FreeHGlobal(channelPtr); + } + } + + [Fact] + public void MarshalPubSubMessage_WithShardedMessage_ReturnsCorrectMessage() + { + // Arrange + string message = "sharded message"; + string channel = "shard-channel"; + + IntPtr messagePtr = Marshal.StringToHGlobalAnsi(message); + IntPtr channelPtr = Marshal.StringToHGlobalAnsi(channel); + + try + { + // Act + PubSubMessage result = FFI.MarshalPubSubMessage( + FFI.PushKind.PushSMessage, + messagePtr, + message.Length, + channelPtr, + channel.Length, + IntPtr.Zero, + 0); + + // Assert + Assert.Equal("sharded message", result.Message); + Assert.Equal("shard-channel", result.Channel); + Assert.Null(result.Pattern); + } + finally + { + Marshal.FreeHGlobal(messagePtr); + Marshal.FreeHGlobal(channelPtr); + } + } + + // Mock class for testing + private class MockBaseClient : BaseClient + { + protected override Task InitializeServerVersionAsync() + { + return Task.CompletedTask; + } + + internal override void HandlePubSubMessage(PubSubMessage message) + { + // Mock implementation + } + } +} diff --git a/tests/Valkey.Glide.UnitTests/PubSubFFIMemoryLeakTests.cs b/tests/Valkey.Glide.UnitTests/PubSubFFIMemoryLeakTests.cs new file mode 100644 index 00000000..c2e2cbf7 --- /dev/null +++ b/tests/Valkey.Glide.UnitTests/PubSubFFIMemoryLeakTests.cs @@ -0,0 +1,392 @@ +// Copyright Valkey GLIDE Project Contributors - SPDX Identifier: Apache-2.0 + +using System.Diagnostics; +using System.Runtime.InteropServices; + +namespace Valkey.Glide.UnitTests; + +/// +/// Tests for detecting memory leaks in the FFI layer during PubSub message processing. +/// These tests are designed to detect the critical memory leak issue where Rust-allocated +/// memory is not properly freed after C# marshaling. +/// +public class PubSubFFIMemoryLeakTests +{ + [Fact] + public void ProcessLargeVolumeMessages_NoMemoryLeak_MemoryUsageRemainsBounded() + { + // Arrange + const int messageCount = 100_000; + const long maxMemoryGrowthBytes = 50_000_000; // 50MB max growth allowed + + // Force initial GC to get baseline + GC.Collect(); + GC.WaitForPendingFinalizers(); + GC.Collect(); + + long initialMemory = GC.GetTotalMemory(false); + Console.WriteLine($"Initial memory: {initialMemory:N0} bytes"); + + // Act: Process large volume of messages + for (int i = 0; i < messageCount; i++) + { + string message = $"test-message-{i}"; + string channel = $"test-channel-{i % 100}"; // Vary channels + string? pattern = i % 3 == 0 ? $"pattern-{i % 10}" : null; // Some with patterns + + ProcessSingleMessage(message, channel, pattern); + + // Periodic GC to detect leaks early + if (i % 10_000 == 0 && i > 0) + { + GC.Collect(); + GC.WaitForPendingFinalizers(); + GC.Collect(); + + long currentMemory = GC.GetTotalMemory(false); + long memoryGrowth = currentMemory - initialMemory; + + Console.WriteLine($"Processed {i:N0} messages, memory growth: {memoryGrowth:N0} bytes"); + + // Early detection of memory leaks + if (memoryGrowth > maxMemoryGrowthBytes) + { + Assert.True(false, + $"Memory leak detected after {i:N0} messages. " + + $"Memory grew by {memoryGrowth:N0} bytes, exceeding limit of {maxMemoryGrowthBytes:N0} bytes."); + } + } + } + + // Final memory check + GC.Collect(); + GC.WaitForPendingFinalizers(); + GC.Collect(); + + long finalMemory = GC.GetTotalMemory(false); + long totalMemoryGrowth = finalMemory - initialMemory; + + Console.WriteLine($"Final memory: {finalMemory:N0} bytes"); + Console.WriteLine($"Total memory growth: {totalMemoryGrowth:N0} bytes"); + Console.WriteLine($"Processed {messageCount:N0} messages successfully"); + + // Assert: Memory growth should be bounded + Assert.True(totalMemoryGrowth < maxMemoryGrowthBytes, + $"Memory leak detected. Total memory growth: {totalMemoryGrowth:N0} bytes, " + + $"limit: {maxMemoryGrowthBytes:N0} bytes"); + } + + [Fact] + public void ProcessVariousMessageSizes_NoMemoryLeak_ConsistentBehavior() + { + // Arrange + const int iterationsPerSize = 1_000; + int[] messageSizes = [10, 100, 1_000, 10_000, 100_000]; // Various sizes + + GC.Collect(); + GC.WaitForPendingFinalizers(); + GC.Collect(); + + long initialMemory = GC.GetTotalMemory(false); + Console.WriteLine($"Initial memory: {initialMemory:N0} bytes"); + + // Act: Test different message sizes + foreach (int messageSize in messageSizes) + { + Console.WriteLine($"Testing message size: {messageSize} bytes"); + + string largeMessage = new('X', messageSize); + string channel = "test-channel"; + + long beforeSizeTest = GC.GetTotalMemory(true); + + for (int i = 0; i < iterationsPerSize; i++) + { + ProcessSingleMessage(largeMessage, channel, null); + } + + GC.Collect(); + GC.WaitForPendingFinalizers(); + GC.Collect(); + + long afterSizeTest = GC.GetTotalMemory(false); + long sizeTestGrowth = afterSizeTest - beforeSizeTest; + + Console.WriteLine($"Memory growth for {messageSize}-byte messages: {sizeTestGrowth:N0} bytes"); + + // Memory growth should be reasonable for the message size + long expectedMaxGrowth = messageSize * iterationsPerSize * 2; // Allow 2x overhead + Assert.True(sizeTestGrowth < expectedMaxGrowth, + $"Excessive memory growth for {messageSize}-byte messages: {sizeTestGrowth:N0} bytes"); + } + + // Final check + long finalMemory = GC.GetTotalMemory(true); + long totalGrowth = finalMemory - initialMemory; + + Console.WriteLine($"Total memory growth across all sizes: {totalGrowth:N0} bytes"); + + // Should not have significant permanent growth + Assert.True(totalGrowth < 10_000_000, // 10MB max + $"Permanent memory growth detected: {totalGrowth:N0} bytes"); + } + + [Fact] + public void ProcessMessagesUnderGCPressure_NoMemoryLeak_StableUnderPressure() + { + // Arrange + const int messageCount = 50_000; + const int gcInterval = 1_000; // Force GC every 1000 messages + + GC.Collect(); + GC.WaitForPendingFinalizers(); + GC.Collect(); + + long initialMemory = GC.GetTotalMemory(false); + Console.WriteLine($"Initial memory under GC pressure test: {initialMemory:N0} bytes"); + + // Act: Process messages with frequent GC pressure + for (int i = 0; i < messageCount; i++) + { + string message = $"gc-pressure-message-{i}"; + string channel = "gc-test-channel"; + + ProcessSingleMessage(message, channel, null); + + // Apply GC pressure frequently + if (i % gcInterval == 0) + { + // Create some temporary objects to increase GC pressure + object[] tempObjects = new object[1000]; + for (int j = 0; j < tempObjects.Length; j++) + { + tempObjects[j] = new byte[1024]; // 1KB objects + } + + GC.Collect(); + GC.WaitForPendingFinalizers(); + GC.Collect(); + + // Clear temp objects + Array.Clear(tempObjects, 0, tempObjects.Length); + } + } + + // Final memory check under GC pressure + GC.Collect(); + GC.WaitForPendingFinalizers(); + GC.Collect(); + + long finalMemory = GC.GetTotalMemory(false); + long memoryGrowth = finalMemory - initialMemory; + + Console.WriteLine($"Memory growth under GC pressure: {memoryGrowth:N0} bytes"); + + // Assert: Should remain stable even under GC pressure + Assert.True(memoryGrowth < 20_000_000, // 20MB max under pressure + $"Memory leak detected under GC pressure: {memoryGrowth:N0} bytes"); + } + + [Fact] + public void ProcessConcurrentMessages_NoMemoryLeak_ThreadSafeMemoryManagement() + { + // Arrange + const int threadsCount = 10; + const int messagesPerThread = 10_000; + + GC.Collect(); + GC.WaitForPendingFinalizers(); + GC.Collect(); + + long initialMemory = GC.GetTotalMemory(false); + Console.WriteLine($"Initial memory for concurrent test: {initialMemory:N0} bytes"); + + // Act: Process messages concurrently from multiple threads + Task[] tasks = new Task[threadsCount]; + Exception?[] exceptions = new Exception?[threadsCount]; + + for (int threadIndex = 0; threadIndex < threadsCount; threadIndex++) + { + int capturedIndex = threadIndex; + tasks[threadIndex] = Task.Run(() => + { + try + { + for (int i = 0; i < messagesPerThread; i++) + { + string message = $"concurrent-message-{capturedIndex}-{i}"; + string channel = $"concurrent-channel-{capturedIndex}"; + string? pattern = i % 2 == 0 ? $"pattern-{capturedIndex}" : null; + + ProcessSingleMessage(message, channel, pattern); + } + } + catch (Exception ex) + { + exceptions[capturedIndex] = ex; + } + }); + } + + // Wait for all tasks to complete + Task.WaitAll(tasks); + + // Check for exceptions + for (int i = 0; i < exceptions.Length; i++) + { + if (exceptions[i] != null) + { + throw new AggregateException($"Thread {i} failed", exceptions[i]); + } + } + + // Final memory check + GC.Collect(); + GC.WaitForPendingFinalizers(); + GC.Collect(); + + long finalMemory = GC.GetTotalMemory(false); + long memoryGrowth = finalMemory - initialMemory; + + Console.WriteLine($"Memory growth after concurrent processing: {memoryGrowth:N0} bytes"); + Console.WriteLine($"Processed {threadsCount * messagesPerThread:N0} messages concurrently"); + + // Assert: Memory should remain bounded even with concurrent access + Assert.True(memoryGrowth < 30_000_000, // 30MB max for concurrent test + $"Memory leak detected in concurrent processing: {memoryGrowth:N0} bytes"); + } + + [Fact] + public void ProcessExtendedDuration_NoMemoryLeak_StableOverTime() + { + // Arrange + const int durationSeconds = 30; // Run for 30 seconds + const int messagesPerSecond = 1000; + + GC.Collect(); + GC.WaitForPendingFinalizers(); + GC.Collect(); + + long initialMemory = GC.GetTotalMemory(false); + Console.WriteLine($"Starting extended duration test for {durationSeconds} seconds"); + Console.WriteLine($"Initial memory: {initialMemory:N0} bytes"); + + Stopwatch stopwatch = Stopwatch.StartNew(); + int messageCount = 0; + List<(TimeSpan Time, long Memory)> memorySnapshots = []; + + // Act: Process messages for extended duration + while (stopwatch.Elapsed.TotalSeconds < durationSeconds) + { + for (int i = 0; i < messagesPerSecond && stopwatch.Elapsed.TotalSeconds < durationSeconds; i++) + { + string message = $"duration-test-{messageCount}"; + string channel = $"duration-channel-{messageCount % 10}"; + + ProcessSingleMessage(message, channel, null); + messageCount++; + } + + // Take memory snapshot every 5 seconds + if (stopwatch.Elapsed.TotalSeconds % 5 < 0.1) + { + GC.Collect(); + GC.WaitForPendingFinalizers(); + GC.Collect(); + + long currentMemory = GC.GetTotalMemory(false); + memorySnapshots.Add((stopwatch.Elapsed, currentMemory)); + + Console.WriteLine($"Time: {stopwatch.Elapsed.TotalSeconds:F1}s, " + + $"Messages: {messageCount:N0}, " + + $"Memory: {currentMemory:N0} bytes"); + } + + Thread.Sleep(1); // Small delay to prevent tight loop + } + + stopwatch.Stop(); + + // Final memory check + GC.Collect(); + GC.WaitForPendingFinalizers(); + GC.Collect(); + + long finalMemory = GC.GetTotalMemory(false); + long totalGrowth = finalMemory - initialMemory; + + Console.WriteLine($"Extended test completed:"); + Console.WriteLine($"Duration: {stopwatch.Elapsed.TotalSeconds:F1} seconds"); + Console.WriteLine($"Messages processed: {messageCount:N0}"); + Console.WriteLine($"Final memory: {finalMemory:N0} bytes"); + Console.WriteLine($"Total memory growth: {totalGrowth:N0} bytes"); + + // Check for memory growth trend + if (memorySnapshots.Count >= 2) + { + long firstSnapshot = memorySnapshots[0].Memory; + long lastSnapshot = memorySnapshots[^1].Memory; + long trendGrowth = lastSnapshot - firstSnapshot; + + Console.WriteLine($"Memory trend growth: {trendGrowth:N0} bytes"); + + // Memory should not continuously grow over time + Assert.True(trendGrowth < 25_000_000, // 25MB max trend growth + $"Continuous memory growth detected: {trendGrowth:N0} bytes over time"); + } + + // Assert: Total memory growth should be reasonable + Assert.True(totalGrowth < 40_000_000, // 40MB max for extended test + $"Excessive memory growth over extended duration: {totalGrowth:N0} bytes"); + } + + /// + /// Helper method to simulate processing a single PubSub message through the FFI marshaling layer. + /// This simulates the memory allocation and marshaling that occurs in the real FFI callback. + /// + private static void ProcessSingleMessage(string message, string channel, string? pattern) + { + // Simulate the FFI marshaling process that occurs in the real callback + IntPtr messagePtr = Marshal.StringToHGlobalAnsi(message); + IntPtr channelPtr = Marshal.StringToHGlobalAnsi(channel); + IntPtr patternPtr = pattern != null ? Marshal.StringToHGlobalAnsi(pattern) : IntPtr.Zero; + + try + { + // Simulate marshaling by creating byte arrays (as the real FFI callback does) + byte[] messageBytes = new byte[message.Length]; + Marshal.Copy(messagePtr, messageBytes, 0, message.Length); + + byte[] channelBytes = new byte[channel.Length]; + Marshal.Copy(channelPtr, channelBytes, 0, channel.Length); + + byte[]? patternBytes = null; + if (pattern != null && patternPtr != IntPtr.Zero) + { + patternBytes = new byte[pattern.Length]; + Marshal.Copy(patternPtr, patternBytes, 0, pattern.Length); + } + + // Create PubSubMessage (simulating what the real callback does) + PubSubMessage result = pattern != null + ? new PubSubMessage(message, channel, pattern) + : new PubSubMessage(message, channel); + + // Verify the message was created correctly + if (result.Message != message || result.Channel != channel || result.Pattern != pattern) + { + throw new InvalidOperationException("Message marshaling failed"); + } + } + finally + { + // Clean up allocated memory (this is what C# should do) + Marshal.FreeHGlobal(messagePtr); + Marshal.FreeHGlobal(channelPtr); + if (patternPtr != IntPtr.Zero) + { + Marshal.FreeHGlobal(patternPtr); + } + } + } +} diff --git a/tests/Valkey.Glide.UnitTests/PubSubGracefulShutdownTests.cs b/tests/Valkey.Glide.UnitTests/PubSubGracefulShutdownTests.cs new file mode 100644 index 00000000..409bc60a --- /dev/null +++ b/tests/Valkey.Glide.UnitTests/PubSubGracefulShutdownTests.cs @@ -0,0 +1,93 @@ +// Copyright Valkey GLIDE Project Contributors - SPDX Identifier: Apache-2.0 + +using System.Threading.Channels; + +namespace Valkey.Glide.UnitTests; + +/// +/// Tests for graceful shutdown coordination in PubSub processing. +/// Validates Requirements: 5.1, 5.2, 5.3, 5.4, 5.5, 5.6, 9.2 +/// +public class PubSubGracefulShutdownTests +{ + [Fact] + public void PubSubPerformanceConfig_DefaultShutdownTimeout_IsCorrect() + { + PubSubPerformanceConfig config = new(); + Assert.Equal(TimeSpan.FromSeconds(5), config.ShutdownTimeout); + } + + [Fact] + public void PubSubPerformanceConfig_CustomShutdownTimeout_CanBeSet() + { + TimeSpan customTimeout = TimeSpan.FromSeconds(10); + PubSubPerformanceConfig config = new() { ShutdownTimeout = customTimeout }; + Assert.Equal(customTimeout, config.ShutdownTimeout); + } + + [Fact] + public void PubSubPerformanceConfig_InvalidShutdownTimeout_ThrowsException() + { + PubSubPerformanceConfig config = new() { ShutdownTimeout = TimeSpan.FromSeconds(-1) }; + Assert.Throws(() => config.Validate()); + } + + [Fact] + public async Task ChannelBasedProcessing_CancellationToken_IsRespected() + { + Channel channel = Channel.CreateBounded(10); + CancellationTokenSource cts = new(); + int messagesProcessed = 0; + + Task processingTask = Task.Run(async () => + { + try + { + await foreach (int message in channel.Reader.ReadAllAsync(cts.Token)) + { + _ = Interlocked.Increment(ref messagesProcessed); + } + } + catch (OperationCanceledException) + { + // Expected when cancelled + } + }); + + await channel.Writer.WriteAsync(1); + await channel.Writer.WriteAsync(2); + await Task.Delay(50); + + cts.Cancel(); + channel.Writer.Complete(); + await processingTask; + + Assert.True(messagesProcessed >= 0); + } + + [Fact] + public async Task ChannelCompletion_StopsProcessing_Gracefully() + { + Channel channel = Channel.CreateBounded(10); + int messagesProcessed = 0; + bool processingCompleted = false; + + Task processingTask = Task.Run(async () => + { + await foreach (int message in channel.Reader.ReadAllAsync()) + { + _ = Interlocked.Increment(ref messagesProcessed); + } + processingCompleted = true; + }); + + await channel.Writer.WriteAsync(1); + await channel.Writer.WriteAsync(2); + await channel.Writer.WriteAsync(3); + channel.Writer.Complete(); + await processingTask; + + Assert.Equal(3, messagesProcessed); + Assert.True(processingCompleted); + } +} diff --git a/tests/Valkey.Glide.UnitTests/PubSubMemoryLeakFixValidationTests.cs b/tests/Valkey.Glide.UnitTests/PubSubMemoryLeakFixValidationTests.cs new file mode 100644 index 00000000..09e4972f --- /dev/null +++ b/tests/Valkey.Glide.UnitTests/PubSubMemoryLeakFixValidationTests.cs @@ -0,0 +1,162 @@ +// Copyright Valkey GLIDE Project Contributors - SPDX Identifier: Apache-2.0 + +namespace Valkey.Glide.UnitTests; + +/// +/// Simple validation tests to verify the memory leak fix in FFI message processing. +/// These tests validate that the MarshalPubSubMessage function works correctly +/// without causing memory leaks. +/// +public class PubSubMemoryLeakFixValidationTests +{ + [Fact] + public void MarshalPubSubMessage_ProcessMultipleMessages_NoMemoryLeak() + { + // Arrange + const int messageCount = 10_000; + + // Force initial GC to get baseline + GC.Collect(); + GC.WaitForPendingFinalizers(); + GC.Collect(); + + long initialMemory = GC.GetTotalMemory(false); + + // Act: Process multiple messages + for (int i = 0; i < messageCount; i++) + { + string message = $"test-message-{i}"; + string channel = $"test-channel-{i % 10}"; + + ProcessSingleMessage(message, channel, null); + } + + // Force GC to clean up any leaked memory + GC.Collect(); + GC.WaitForPendingFinalizers(); + GC.Collect(); + + long finalMemory = GC.GetTotalMemory(false); + long memoryGrowth = finalMemory - initialMemory; + + Console.WriteLine($"Processed {messageCount:N0} messages"); + Console.WriteLine($"Memory growth: {memoryGrowth:N0} bytes"); + + // Assert: Memory growth should be reasonable (less than 5MB for 10k messages) + Assert.True(memoryGrowth < 5_000_000, + $"Excessive memory growth detected: {memoryGrowth:N0} bytes for {messageCount} messages"); + } + + [Fact] + public void MarshalPubSubMessage_WithPatternMessages_HandlesCorrectly() + { + // Arrange + string message = "pattern test message"; + string channel = "news.sports"; + string pattern = "news.*"; + + IntPtr messagePtr = Marshal.StringToHGlobalAnsi(message); + IntPtr channelPtr = Marshal.StringToHGlobalAnsi(channel); + IntPtr patternPtr = Marshal.StringToHGlobalAnsi(pattern); + + try + { + // Act + PubSubMessage result = FFI.MarshalPubSubMessage( + FFI.PushKind.PushPMessage, + messagePtr, + message.Length, + channelPtr, + channel.Length, + patternPtr, + pattern.Length); + + // Assert + Assert.Equal(message, result.Message); + Assert.Equal(channel, result.Channel); + Assert.Equal(pattern, result.Pattern); + } + finally + { + // Clean up allocated memory + Marshal.FreeHGlobal(messagePtr); + Marshal.FreeHGlobal(channelPtr); + Marshal.FreeHGlobal(patternPtr); + } + } + + [Fact] + public void MarshalPubSubMessage_WithShardedMessages_HandlesCorrectly() + { + // Arrange + string message = "sharded test message"; + string channel = "shard-channel"; + + IntPtr messagePtr = Marshal.StringToHGlobalAnsi(message); + IntPtr channelPtr = Marshal.StringToHGlobalAnsi(channel); + + try + { + // Act + PubSubMessage result = FFI.MarshalPubSubMessage( + FFI.PushKind.PushSMessage, + messagePtr, + message.Length, + channelPtr, + channel.Length, + IntPtr.Zero, + 0); + + // Assert + Assert.Equal(message, result.Message); + Assert.Equal(channel, result.Channel); + Assert.Null(result.Pattern); + } + finally + { + // Clean up allocated memory + Marshal.FreeHGlobal(messagePtr); + Marshal.FreeHGlobal(channelPtr); + } + } + + /// + /// Helper method to simulate processing a single PubSub message through the FFI marshaling layer. + /// + private static void ProcessSingleMessage(string message, string channel, string? pattern) + { + IntPtr messagePtr = Marshal.StringToHGlobalAnsi(message); + IntPtr channelPtr = Marshal.StringToHGlobalAnsi(channel); + IntPtr patternPtr = pattern != null ? Marshal.StringToHGlobalAnsi(pattern) : IntPtr.Zero; + + try + { + FFI.PushKind pushKind = pattern != null ? FFI.PushKind.PushPMessage : FFI.PushKind.PushMessage; + + PubSubMessage result = FFI.MarshalPubSubMessage( + pushKind, + messagePtr, + message.Length, + channelPtr, + channel.Length, + patternPtr, + pattern?.Length ?? 0); + + // Verify the message was marshaled correctly + if (result.Message != message || result.Channel != channel || result.Pattern != pattern) + { + throw new InvalidOperationException("Message marshaling failed"); + } + } + finally + { + // Clean up allocated memory + Marshal.FreeHGlobal(messagePtr); + Marshal.FreeHGlobal(channelPtr); + if (patternPtr != IntPtr.Zero) + { + Marshal.FreeHGlobal(patternPtr); + } + } + } +} diff --git a/tests/Valkey.Glide.UnitTests/PubSubMessageHandlerTests.cs b/tests/Valkey.Glide.UnitTests/PubSubMessageHandlerTests.cs new file mode 100644 index 00000000..cc8d5521 --- /dev/null +++ b/tests/Valkey.Glide.UnitTests/PubSubMessageHandlerTests.cs @@ -0,0 +1,296 @@ +// Copyright Valkey GLIDE Project Contributors - SPDX Identifier: Apache-2.0 + +using System.Threading; +using System.Threading.Tasks; + +using Xunit; + +namespace Valkey.Glide.UnitTests; + +public class PubSubMessageHandlerTests +{ + [Fact] + public void Constructor_WithCallback_InitializesCorrectly() + { + // Arrange + MessageCallback callback = new MessageCallback((msg, ctx) => { }); + object context = new object(); + + // Act + using PubSubMessageHandler handler = new PubSubMessageHandler(callback, context); + + // Assert - GetQueue should throw when callback is configured + Assert.Throws(() => handler.GetQueue()); + } + + [Fact] + public void Constructor_WithoutCallback_InitializesCorrectly() + { + // Act + using PubSubMessageHandler handler = new PubSubMessageHandler(null, null); + + // Assert + Assert.NotNull(handler.GetQueue()); + } + + [Fact] + public void HandleMessage_WithCallback_InvokesCallback() + { + // Arrange + bool callbackInvoked = false; + PubSubMessage? receivedMessage = null; + object? receivedContext = null; + object context = new object(); + + MessageCallback callback = new MessageCallback((msg, ctx) => + { + callbackInvoked = true; + receivedMessage = msg; + receivedContext = ctx; + }); + + using PubSubMessageHandler handler = new PubSubMessageHandler(callback, context); + PubSubMessage message = new PubSubMessage("test-message", "test-channel"); + + // Act + handler.HandleMessage(message); + + // Assert + Assert.True(callbackInvoked); + Assert.Equal(message, receivedMessage); + Assert.Equal(context, receivedContext); + } + + [Fact] + public void HandleMessage_WithoutCallback_QueuesMessage() + { + // Arrange + using PubSubMessageHandler handler = new PubSubMessageHandler(null, null); + PubSubMessage message = new PubSubMessage("test-message", "test-channel"); + + // Act + handler.HandleMessage(message); + + // Assert + PubSubMessageQueue queue = handler.GetQueue(); + Assert.Equal(1, queue.Count); + Assert.True(queue.TryGetMessage(out PubSubMessage? queuedMessage)); + Assert.Equal(message, queuedMessage); + } + + [Fact] + public void HandleMessage_CallbackThrowsException_DoesNotPropagate() + { + // Arrange + bool exceptionThrown = false; + + MessageCallback callback = new MessageCallback((msg, ctx) => + { + exceptionThrown = true; + throw new InvalidOperationException("Test exception"); + }); + + using PubSubMessageHandler handler = new PubSubMessageHandler(callback, null); + PubSubMessage message = new PubSubMessage("test-message", "test-channel"); + + // Act & Assert - Exception should be caught and not propagate + handler.HandleMessage(message); + + Assert.True(exceptionThrown); + } + + [Fact] + public void HandleMessage_MultipleMessages_InvokesCallbackInOrder() + { + // Arrange + List receivedMessages = new List(); + MessageCallback callback = new MessageCallback((msg, ctx) => receivedMessages.Add(msg)); + + using PubSubMessageHandler handler = new PubSubMessageHandler(callback, null); + PubSubMessage message1 = new PubSubMessage("message1", "channel1"); + PubSubMessage message2 = new PubSubMessage("message2", "channel2"); + PubSubMessage message3 = new PubSubMessage("message3", "channel3"); + + // Act + handler.HandleMessage(message1); + handler.HandleMessage(message2); + handler.HandleMessage(message3); + + // Assert + Assert.Equal(3, receivedMessages.Count); + Assert.Equal(message1, receivedMessages[0]); + Assert.Equal(message2, receivedMessages[1]); + Assert.Equal(message3, receivedMessages[2]); + } + + [Fact] + public void HandleMessage_PatternMessage_InvokesCallbackCorrectly() + { + // Arrange + PubSubMessage? receivedMessage = null; + MessageCallback callback = new MessageCallback((msg, ctx) => receivedMessage = msg); + + using PubSubMessageHandler handler = new PubSubMessageHandler(callback, null); + PubSubMessage message = new PubSubMessage("test-message", "test-channel", "test-pattern"); + + // Act + handler.HandleMessage(message); + + // Assert + Assert.NotNull(receivedMessage); + Assert.Equal("test-message", receivedMessage.Message); + Assert.Equal("test-channel", receivedMessage.Channel); + Assert.Equal("test-pattern", receivedMessage.Pattern); + } + + [Fact] + public void HandleMessage_NullMessage_ThrowsArgumentNullException() + { + // Arrange + using PubSubMessageHandler handler = new PubSubMessageHandler(null, null); + + // Act & Assert + Assert.Throws(() => handler.HandleMessage(null!)); + } + + [Fact] + public void HandleMessage_DisposedHandler_ThrowsObjectDisposedException() + { + // Arrange + PubSubMessageHandler handler = new PubSubMessageHandler(null, null); + handler.Dispose(); + PubSubMessage message = new PubSubMessage("test-message", "test-channel"); + + // Act & Assert + Assert.Throws(() => handler.HandleMessage(message)); + } + + [Fact] + public void GetQueue_ReturnsValidQueue() + { + // Arrange + using PubSubMessageHandler handler = new PubSubMessageHandler(null, null); + + // Act + PubSubMessageQueue queue = handler.GetQueue(); + + // Assert + Assert.NotNull(queue); + Assert.Equal(0, queue.Count); + } + + [Fact] + public void GetQueue_DisposedHandler_ThrowsObjectDisposedException() + { + // Arrange + PubSubMessageHandler handler = new PubSubMessageHandler(null, null); + handler.Dispose(); + + // Act & Assert + Assert.Throws(() => handler.GetQueue()); + } + + [Fact] + public void Dispose_MultipleCalls_DoesNotThrow() + { + // Arrange + PubSubMessageHandler handler = new PubSubMessageHandler(null, null); + + // Act & Assert - Should not throw + handler.Dispose(); + handler.Dispose(); + handler.Dispose(); + } + + [Fact] + public void HandleMessage_CallbackWithNullContext_WorksCorrectly() + { + // Arrange + bool callbackInvoked = false; + object? receivedContext = new object(); // Initialize with non-null to verify it gets set to null + + MessageCallback callback = new MessageCallback((msg, ctx) => + { + callbackInvoked = true; + receivedContext = ctx; + }); + + using PubSubMessageHandler handler = new PubSubMessageHandler(callback, null); + PubSubMessage message = new PubSubMessage("test-message", "test-channel"); + + // Act + handler.HandleMessage(message); + + // Assert + Assert.True(callbackInvoked); + Assert.Null(receivedContext); + } + + [Fact] + public void HandleMessage_ConcurrentAccess_HandlesCorrectly() + { + // Arrange + List receivedMessages = new List(); + object lockObject = new object(); + MessageCallback callback = new MessageCallback((msg, ctx) => + { + lock (lockObject) + { + receivedMessages.Add(msg); + } + }); + + using PubSubMessageHandler handler = new PubSubMessageHandler(callback, null); + PubSubMessage[] messages = new[] + { + new PubSubMessage("message1", "channel1"), + new PubSubMessage("message2", "channel2"), + new PubSubMessage("message3", "channel3") + }; + + // Act + Task[] tasks = messages.Select(msg => Task.Run(() => handler.HandleMessage(msg))).ToArray(); + Task.WaitAll(tasks); + + // Assert + Assert.Equal(3, receivedMessages.Count); + Assert.Contains(messages[0], receivedMessages); + Assert.Contains(messages[1], receivedMessages); + Assert.Contains(messages[2], receivedMessages); + } + + [Fact] + public void HandleMessage_DisposedDuringCallback_HandlesGracefully() + { + // Arrange + ManualResetEventSlim callbackStarted = new ManualResetEventSlim(false); + ManualResetEventSlim disposeStarted = new ManualResetEventSlim(false); + bool callbackCompleted = false; + + MessageCallback callback = new MessageCallback((msg, ctx) => + { + callbackStarted.Set(); + disposeStarted.Wait(TimeSpan.FromSeconds(5)); // Wait for dispose to start + Thread.Sleep(100); // Simulate some work + callbackCompleted = true; + }); + + PubSubMessageHandler handler = new PubSubMessageHandler(callback, null); + PubSubMessage message = new PubSubMessage("test-message", "test-channel"); + + // Act + Task handleTask = Task.Run(() => handler.HandleMessage(message)); + callbackStarted.Wait(TimeSpan.FromSeconds(5)); + + Task disposeTask = Task.Run(() => + { + disposeStarted.Set(); + handler.Dispose(); + }); + + Task.WaitAll(handleTask, disposeTask); + + // Assert + Assert.True(callbackCompleted); + } +} diff --git a/tests/Valkey.Glide.UnitTests/PubSubMessageQueueTests.cs b/tests/Valkey.Glide.UnitTests/PubSubMessageQueueTests.cs new file mode 100644 index 00000000..c13a3878 --- /dev/null +++ b/tests/Valkey.Glide.UnitTests/PubSubMessageQueueTests.cs @@ -0,0 +1,406 @@ +// Copyright Valkey GLIDE Project Contributors - SPDX Identifier: Apache-2.0 + +using System.Collections.Concurrent; + +using Valkey.Glide; + +using Xunit; + +namespace Valkey.Glide.UnitTests; + +public class PubSubMessageQueueTests +{ + [Fact] + public void Constructor_InitializesEmptyQueue() + { + // Arrange & Act + using var queue = new PubSubMessageQueue(); + + // Assert + Assert.Equal(0, queue.Count); + } + + [Fact] + public void TryGetMessage_EmptyQueue_ReturnsFalse() + { + // Arrange + using var queue = new PubSubMessageQueue(); + + // Act + bool result = queue.TryGetMessage(out PubSubMessage? message); + + // Assert + Assert.False(result); + Assert.Null(message); + } + + [Fact] + public void TryGetMessage_WithMessage_ReturnsTrue() + { + // Arrange + using var queue = new PubSubMessageQueue(); + var testMessage = new PubSubMessage("test-message", "test-channel"); + queue.EnqueueMessage(testMessage); + + // Act + bool result = queue.TryGetMessage(out PubSubMessage? message); + + // Assert + Assert.True(result); + Assert.NotNull(message); + Assert.Equal("test-message", message.Message); + Assert.Equal("test-channel", message.Channel); + Assert.Equal(0, queue.Count); + } + + [Fact] + public void EnqueueMessage_NullMessage_ThrowsArgumentNullException() + { + // Arrange + using var queue = new PubSubMessageQueue(); + + // Act & Assert + Assert.Throws(() => queue.EnqueueMessage(null!)); + } + + [Fact] + public void EnqueueMessage_ValidMessage_IncreasesCount() + { + // Arrange + using var queue = new PubSubMessageQueue(); + var testMessage = new PubSubMessage("test-message", "test-channel"); + + // Act + queue.EnqueueMessage(testMessage); + + // Assert + Assert.Equal(1, queue.Count); + } + + [Fact] + public void EnqueueMessage_MultipleMessages_MaintainsOrder() + { + // Arrange + using var queue = new PubSubMessageQueue(); + var message1 = new PubSubMessage("message1", "channel1"); + var message2 = new PubSubMessage("message2", "channel2"); + var message3 = new PubSubMessage("message3", "channel3"); + + // Act + queue.EnqueueMessage(message1); + queue.EnqueueMessage(message2); + queue.EnqueueMessage(message3); + + // Assert + Assert.Equal(3, queue.Count); + + Assert.True(queue.TryGetMessage(out PubSubMessage? retrievedMessage1)); + Assert.Equal("message1", retrievedMessage1!.Message); + + Assert.True(queue.TryGetMessage(out PubSubMessage? retrievedMessage2)); + Assert.Equal("message2", retrievedMessage2!.Message); + + Assert.True(queue.TryGetMessage(out PubSubMessage? retrievedMessage3)); + Assert.Equal("message3", retrievedMessage3!.Message); + + Assert.Equal(0, queue.Count); + } + + [Fact] + public async Task GetMessageAsync_WithMessage_ReturnsImmediately() + { + // Arrange + using var queue = new PubSubMessageQueue(); + var testMessage = new PubSubMessage("test-message", "test-channel"); + queue.EnqueueMessage(testMessage); + + // Act + PubSubMessage result = await queue.GetMessageAsync(); + + // Assert + Assert.Equal("test-message", result.Message); + Assert.Equal("test-channel", result.Channel); + Assert.Equal(0, queue.Count); + } + + [Fact] + public async Task GetMessageAsync_EmptyQueue_WaitsForMessage() + { + // Arrange + using var queue = new PubSubMessageQueue(); + var testMessage = new PubSubMessage("test-message", "test-channel"); + + // Act + Task getMessageTask = queue.GetMessageAsync(); + + // Ensure the task is waiting + await Task.Delay(50); + Assert.False(getMessageTask.IsCompleted); + + // Enqueue a message + queue.EnqueueMessage(testMessage); + + // Wait for the task to complete + PubSubMessage result = await getMessageTask; + + // Assert + Assert.Equal("test-message", result.Message); + Assert.Equal("test-channel", result.Channel); + } + + [Fact] + public async Task GetMessageAsync_WithCancellation_ThrowsOperationCanceledException() + { + // Arrange + using var queue = new PubSubMessageQueue(); + using CancellationTokenSource cts = new(); + + // Act + Task getMessageTask = queue.GetMessageAsync(cts.Token); + + // Cancel after a short delay + cts.CancelAfter(50); + + // Assert + await Assert.ThrowsAsync(() => getMessageTask); + } + + [Fact] + public async Task GetMessagesAsync_YieldsMessages() + { + // Arrange + using var queue = new PubSubMessageQueue(); + var message1 = new PubSubMessage("message1", "channel1"); + var message2 = new PubSubMessage("message2", "channel2"); + var message3 = new PubSubMessage("message3", "channel3"); + + queue.EnqueueMessage(message1); + queue.EnqueueMessage(message2); + queue.EnqueueMessage(message3); + + // Act + List messages = []; + using CancellationTokenSource cts = new(); + + await foreach (PubSubMessage message in queue.GetMessagesAsync(cts.Token)) + { + messages.Add(message); + if (messages.Count == 3) + { + cts.Cancel(); // Stop enumeration after 3 messages + } + } + + // Assert + Assert.Equal(3, messages.Count); + Assert.Equal("message1", messages[0].Message); + Assert.Equal("message2", messages[1].Message); + Assert.Equal("message3", messages[2].Message); + } + + [Fact] + public async Task GetMessagesAsync_WithCancellation_StopsEnumeration() + { + // Arrange + using var queue = new PubSubMessageQueue(); + var message1 = new PubSubMessage("message1", "channel1"); + queue.EnqueueMessage(message1); + + using CancellationTokenSource cts = new(); + cts.CancelAfter(100); // Cancel after 100ms + + // Act + List messages = []; + await foreach (PubSubMessage message in queue.GetMessagesAsync(cts.Token)) + { + messages.Add(message); + // Don't add more messages, so enumeration will wait and then be cancelled + } + + // Assert + Assert.Single(messages); + Assert.Equal("message1", messages[0].Message); + } + + [Fact] + public void TryGetMessage_AfterDispose_ThrowsObjectDisposedException() + { + // Arrange + var queue = new PubSubMessageQueue(); + queue.Dispose(); + + // Act & Assert + Assert.Throws(() => queue.TryGetMessage(out _)); + } + + [Fact] + public async Task GetMessageAsync_AfterDispose_ThrowsObjectDisposedException() + { + // Arrange + var queue = new PubSubMessageQueue(); + queue.Dispose(); + + // Act & Assert + await Assert.ThrowsAsync(() => queue.GetMessageAsync()); + } + + [Fact] + public void EnqueueMessage_AfterDispose_ThrowsObjectDisposedException() + { + // Arrange + var queue = new PubSubMessageQueue(); + var testMessage = new PubSubMessage("test-message", "test-channel"); + queue.Dispose(); + + // Act & Assert + Assert.Throws(() => queue.EnqueueMessage(testMessage)); + } + + [Fact] + public async Task GetMessagesAsync_AfterDispose_StopsEnumeration() + { + // Arrange + using var queue = new PubSubMessageQueue(); + var message1 = new PubSubMessage("message1", "channel1"); + queue.EnqueueMessage(message1); + + // Act + List messages = []; + await foreach (PubSubMessage message in queue.GetMessagesAsync()) + { + messages.Add(message); + queue.Dispose(); // Dispose after getting first message + } + + // Assert + Assert.Single(messages); + Assert.Equal("message1", messages[0].Message); + } + + [Fact] + public async Task ConcurrentAccess_MultipleThreads_ThreadSafe() + { + // Arrange + using var queue = new PubSubMessageQueue(); + const int messageCount = 1000; + const int producerThreads = 5; + const int consumerThreads = 5; + + ConcurrentBag producedMessages = []; + ConcurrentBag consumedMessages = []; + List tasks = []; + + // Producer tasks + for (int i = 0; i < producerThreads; i++) + { + int threadId = i; + tasks.Add(Task.Run(() => + { + for (int j = 0; j < messageCount / producerThreads; j++) + { + string messageContent = $"thread-{threadId}-message-{j}"; + var message = new PubSubMessage(messageContent, $"channel-{threadId}"); + queue.EnqueueMessage(message); + producedMessages.Add(messageContent); + } + })); + } + + // Consumer tasks + for (int i = 0; i < consumerThreads; i++) + { + tasks.Add(Task.Run(async () => + { + int messagesConsumed = 0; + while (messagesConsumed < messageCount / consumerThreads) + { + try + { + PubSubMessage message = await queue.GetMessageAsync(); + consumedMessages.Add(message.Message); + messagesConsumed++; + } + catch (ObjectDisposedException) + { + // Queue was disposed, exit + break; + } + } + })); + } + + // Act + await Task.WhenAll(tasks); + + // Assert + Assert.Equal(messageCount, producedMessages.Count); + Assert.Equal(messageCount, consumedMessages.Count); + Assert.Equal(0, queue.Count); + + // Verify all produced messages were consumed + HashSet producedSet = [.. producedMessages]; + HashSet consumedSet = [.. consumedMessages]; + Assert.Equal(producedSet, consumedSet); + } + + [Fact] + public async Task ConcurrentTryGetMessage_MultipleThreads_ThreadSafe() + { + // Arrange + using var queue = new PubSubMessageQueue(); + const int messageCount = 100; + const int consumerThreads = 10; + + // Enqueue messages + for (int i = 0; i < messageCount; i++) + { + var message = new PubSubMessage($"message-{i}", $"channel-{i}"); + queue.EnqueueMessage(message); + } + + ConcurrentBag consumedMessages = []; + List tasks = []; + + // Consumer tasks using TryGetMessage + for (int i = 0; i < consumerThreads; i++) + { + tasks.Add(Task.Run(() => + { + while (queue.TryGetMessage(out PubSubMessage? message)) + { + if (message != null) + { + consumedMessages.Add(message.Message); + } + } + })); + } + + // Act + await Task.WhenAll(tasks); + + // Assert + Assert.Equal(messageCount, consumedMessages.Count); + Assert.Equal(0, queue.Count); + } + + [Fact] + public async Task DisposeDuringAsyncOperation_CancelsWaitingOperations() + { + // Arrange + using var queue = new PubSubMessageQueue(); + + // Start a task that will wait for a message + Task waitingTask = queue.GetMessageAsync(); + + // Ensure the task is waiting + await Task.Delay(50); + Assert.False(waitingTask.IsCompleted); + + // Act + queue.Dispose(); + + // Assert + await Assert.ThrowsAsync(() => waitingTask); + } +} diff --git a/tests/Valkey.Glide.UnitTests/PubSubMessageTests.cs b/tests/Valkey.Glide.UnitTests/PubSubMessageTests.cs new file mode 100644 index 00000000..b38f4845 --- /dev/null +++ b/tests/Valkey.Glide.UnitTests/PubSubMessageTests.cs @@ -0,0 +1,188 @@ +// Copyright Valkey GLIDE Project Contributors - SPDX Identifier: Apache-2.0 + +using System.Text.Json; + +namespace Valkey.Glide.UnitTests; + +public class PubSubMessageTests +{ + [Fact] + public void PubSubMessage_ExactChannelConstructor_SetsPropertiesCorrectly() + { + // Arrange + const string message = "test message"; + const string channel = "test-channel"; + + // Act + var pubSubMessage = new PubSubMessage(message, channel); + + // Assert + Assert.Equal(message, pubSubMessage.Message); + Assert.Equal(channel, pubSubMessage.Channel); + Assert.Null(pubSubMessage.Pattern); + } + + [Fact] + public void PubSubMessage_PatternConstructor_SetsPropertiesCorrectly() + { + // Arrange + const string message = "test message"; + const string channel = "test-channel"; + const string pattern = "test-*"; + + // Act + var pubSubMessage = new PubSubMessage(message, channel, pattern); + + // Assert + Assert.Equal(message, pubSubMessage.Message); + Assert.Equal(channel, pubSubMessage.Channel); + Assert.Equal(pattern, pubSubMessage.Pattern); + } + + [Theory] + [InlineData(null, "channel")] + [InlineData("", "channel")] + [InlineData("message", null)] + [InlineData("message", "")] + public void PubSubMessage_ExactChannelConstructor_ThrowsOnInvalidInput(string? message, string? channel) + { + // Act & Assert + if (message == null || channel == null) + { + Assert.Throws(() => new PubSubMessage(message!, channel!)); + } + else + { + Assert.Throws(() => new PubSubMessage(message, channel)); + } + } + + [Theory] + [InlineData(null, "channel", "pattern")] + [InlineData("", "channel", "pattern")] + [InlineData("message", null, "pattern")] + [InlineData("message", "", "pattern")] + [InlineData("message", "channel", null)] + [InlineData("message", "channel", "")] + public void PubSubMessage_PatternConstructor_ThrowsOnInvalidInput(string? message, string? channel, string? pattern) + { + // Act & Assert + if (message == null || channel == null || pattern == null) + { + Assert.Throws(() => new PubSubMessage(message!, channel!, pattern!)); + } + else + { + Assert.Throws(() => new PubSubMessage(message, channel, pattern)); + } + } + + [Fact] + public void PubSubMessage_ToString_ReturnsValidJson() + { + // Arrange + const string message = "test message"; + const string channel = "test-channel"; + const string pattern = "test-*"; + var pubSubMessage = new PubSubMessage(message, channel, pattern); + + // Act + var jsonString = pubSubMessage.ToString(); + + // Assert + Assert.NotNull(jsonString); + Assert.NotEmpty(jsonString); + + // Verify it's valid JSON by deserializing + JsonElement deserializedObject = JsonSerializer.Deserialize(jsonString); + Assert.Equal(message, deserializedObject.GetProperty("Message").GetString()); + Assert.Equal(channel, deserializedObject.GetProperty("Channel").GetString()); + Assert.Equal(pattern, deserializedObject.GetProperty("Pattern").GetString()); + } + + [Fact] + public void PubSubMessage_ToString_ExactChannel_ReturnsValidJsonWithNullPattern() + { + // Arrange + const string message = "test message"; + const string channel = "test-channel"; + var pubSubMessage = new PubSubMessage(message, channel); + + // Act + var jsonString = pubSubMessage.ToString(); + + // Assert + Assert.NotNull(jsonString); + Assert.NotEmpty(jsonString); + + // Verify it's valid JSON by deserializing + JsonElement deserializedObject = JsonSerializer.Deserialize(jsonString); + Assert.Equal(message, deserializedObject.GetProperty("Message").GetString()); + Assert.Equal(channel, deserializedObject.GetProperty("Channel").GetString()); + Assert.Equal(JsonValueKind.Null, deserializedObject.GetProperty("Pattern").ValueKind); + } + + [Fact] + public void PubSubMessage_Equals_ReturnsTrueForEqualMessages() + { + // Arrange + const string message = "test message"; + const string channel = "test-channel"; + const string pattern = "test-*"; + var message1 = new PubSubMessage(message, channel, pattern); + var message2 = new PubSubMessage(message, channel, pattern); + + // Act & Assert + Assert.Equal(message1, message2); + Assert.True(message1.Equals(message2)); + Assert.True(message2.Equals(message1)); + } + + [Fact] + public void PubSubMessage_Equals_ReturnsFalseForDifferentMessages() + { + // Arrange + var message1 = new PubSubMessage("message1", "channel1", "pattern1"); + var message2 = new PubSubMessage("message2", "channel2", "pattern2"); + + // Act & Assert + Assert.NotEqual(message1, message2); + Assert.False(message1.Equals(message2)); + Assert.False(message2.Equals(message1)); + } + + [Fact] + public void PubSubMessage_Equals_ReturnsFalseForNull() + { + // Arrange + var message = new PubSubMessage("test", "channel"); + + // Act & Assert + Assert.False(message.Equals(null)); + } + + [Fact] + public void PubSubMessage_GetHashCode_SameForEqualMessages() + { + // Arrange + const string message = "test message"; + const string channel = "test-channel"; + const string pattern = "test-*"; + var message1 = new PubSubMessage(message, channel, pattern); + var message2 = new PubSubMessage(message, channel, pattern); + + // Act & Assert + Assert.Equal(message1.GetHashCode(), message2.GetHashCode()); + } + + [Fact] + public void PubSubMessage_GetHashCode_DifferentForDifferentMessages() + { + // Arrange + var message1 = new PubSubMessage("message1", "channel1", "pattern1"); + var message2 = new PubSubMessage("message2", "channel2", "pattern2"); + + // Act & Assert + Assert.NotEqual(message1.GetHashCode(), message2.GetHashCode()); + } +} diff --git a/tests/Valkey.Glide.UnitTests/PubSubPerformanceTests.cs b/tests/Valkey.Glide.UnitTests/PubSubPerformanceTests.cs new file mode 100644 index 00000000..5aa7e3a6 --- /dev/null +++ b/tests/Valkey.Glide.UnitTests/PubSubPerformanceTests.cs @@ -0,0 +1,276 @@ +// Copyright Valkey GLIDE Project Contributors - SPDX Identifier: Apache-2.0 + +using System; +using System.Collections.Concurrent; +using System.Diagnostics; +using System.Linq; +using System.Threading; +using System.Threading.Tasks; + +using Xunit; + +namespace Valkey.Glide.UnitTests; + +/// +/// Performance validation tests for channel-based PubSub message processing. +/// These tests verify that the channel-based approach provides better performance +/// than the previous Task.Run per message approach. +/// +public class PubSubPerformanceTests +{ + + [Fact] + public void ChannelBasedProcessing_HighThroughput_HandlesMessagesEfficiently() + { + // Arrange + const int messageCount = 10_000; + var messagesReceived = 0; + + var config = new StandalonePubSubSubscriptionConfig() + .WithChannel("perf-test") + .WithCallback((msg, ctx) => + { + Interlocked.Increment(ref messagesReceived); + }, null); + + // Act - Simulate high-volume message processing + var stopwatch = Stopwatch.StartNew(); + + // Simulate messages being processed through the channel + for (int i = 0; i < messageCount; i++) + { + var message = new PubSubMessage($"message-{i}", "perf-test"); + config.Callback!(message, null); + } + + stopwatch.Stop(); + + // Assert + Assert.Equal(messageCount, messagesReceived); + + var throughput = messageCount / stopwatch.Elapsed.TotalSeconds; + + // Verify high throughput (should handle at least 10,000 msg/sec) + Assert.True(throughput >= 10_000, + $"Throughput {throughput:F0} msg/sec is below target of 10,000 msg/sec. Processed {messageCount} messages in {stopwatch.Elapsed.TotalMilliseconds:F2}ms"); + } + + [Fact] + public void ChannelBasedProcessing_ReducedAllocationPressure_MinimizesGCImpact() + { + // Arrange + const int messageCount = 50_000; + var messagesReceived = 0; + + var config = new StandalonePubSubSubscriptionConfig() + .WithChannel("gc-test") + .WithCallback((msg, ctx) => + { + Interlocked.Increment(ref messagesReceived); + }, null); + + // Force GC before test + GC.Collect(); + GC.WaitForPendingFinalizers(); + GC.Collect(); + + var initialMemory = GC.GetTotalMemory(false); + var initialGen0 = GC.CollectionCount(0); + var initialGen1 = GC.CollectionCount(1); + var initialGen2 = GC.CollectionCount(2); + + // Act - Process many messages + for (int i = 0; i < messageCount; i++) + { + var message = new PubSubMessage($"message-{i}", "gc-test"); + config.Callback!(message, null); + } + + // Wait a bit for any pending operations + Thread.Sleep(100); + + // Force GC to measure impact + GC.Collect(); + GC.WaitForPendingFinalizers(); + GC.Collect(); + + var finalMemory = GC.GetTotalMemory(false); + var finalGen0 = GC.CollectionCount(0); + var finalGen1 = GC.CollectionCount(1); + var finalGen2 = GC.CollectionCount(2); + + // Assert + Assert.Equal(messageCount, messagesReceived); + + var memoryGrowth = finalMemory - initialMemory; + var gen0Collections = finalGen0 - initialGen0; + var gen1Collections = finalGen1 - initialGen1; + var gen2Collections = finalGen2 - initialGen2; + + // Verify reasonable memory growth (should be less than 10MB for 50k messages) + Assert.True(memoryGrowth < 10_000_000, + $"Memory grew by {memoryGrowth:N0} bytes ({memoryGrowth / 1024.0 / 1024.0:F2} MB) - excessive allocation pressure. Gen0: {gen0Collections}, Gen1: {gen1Collections}, Gen2: {gen2Collections}"); + + // Verify minimal Gen2 collections (should be reasonable for 50k messages) + // Note: GC behavior can vary based on system load and other factors + Assert.True(gen2Collections <= 100, + $"Too many Gen2 collections ({gen2Collections}) - indicates allocation pressure. Gen0: {gen0Collections}, Gen1: {gen1Collections}"); + } + + [Fact] + public void ChannelBasedProcessing_ConcurrentMessages_MaintainsPerformance() + { + // Arrange + const int threadCount = 10; + const int messagesPerThread = 1_000; + var totalMessages = threadCount * messagesPerThread; + var messagesReceived = 0; + + var config = new StandalonePubSubSubscriptionConfig() + .WithChannel("concurrent-test") + .WithCallback((msg, ctx) => + { + Interlocked.Increment(ref messagesReceived); + }, null); + + // Act - Simulate concurrent message arrival from multiple threads + var stopwatch = Stopwatch.StartNew(); + var tasks = new Task[threadCount]; + + for (int t = 0; t < threadCount; t++) + { + var threadId = t; + tasks[t] = Task.Run(() => + { + for (int i = 0; i < messagesPerThread; i++) + { + var message = new PubSubMessage($"thread-{threadId}-msg-{i}", "concurrent-test"); + config.Callback!(message, null); + } + }); + } + + Task.WaitAll(tasks); + stopwatch.Stop(); + + // Assert + Assert.Equal(totalMessages, messagesReceived); + + var throughput = totalMessages / stopwatch.Elapsed.TotalSeconds; + + // Verify high throughput even with concurrent access + Assert.True(throughput >= 5_000, + $"Concurrent throughput {throughput:F0} msg/sec is below target of 5,000 msg/sec. Processed {totalMessages} messages from {threadCount} threads in {stopwatch.Elapsed.TotalMilliseconds:F2}ms"); + } + + [Fact] + public void ChannelBasedProcessing_BurstTraffic_HandlesSpikesEfficiently() + { + // Arrange + const int burstSize = 5_000; + const int burstCount = 5; + var messagesReceived = 0; + var burstTimes = new ConcurrentBag(); + + var config = new StandalonePubSubSubscriptionConfig() + .WithChannel("burst-test") + .WithCallback((msg, ctx) => + { + Interlocked.Increment(ref messagesReceived); + }, null); + + // Act - Simulate burst traffic patterns + var totalStopwatch = Stopwatch.StartNew(); + + for (int burst = 0; burst < burstCount; burst++) + { + var burstStopwatch = Stopwatch.StartNew(); + + // Send burst of messages + for (int i = 0; i < burstSize; i++) + { + var message = new PubSubMessage($"burst-{burst}-msg-{i}", "burst-test"); + config.Callback!(message, null); + } + + burstStopwatch.Stop(); + burstTimes.Add(burstStopwatch.Elapsed); + + // Small delay between bursts + Thread.Sleep(10); + } + + totalStopwatch.Stop(); + + // Assert + Assert.Equal(burstSize * burstCount, messagesReceived); + + var avgBurstTime = burstTimes.Select(t => t.TotalMilliseconds).Average(); + var maxBurstTime = burstTimes.Select(t => t.TotalMilliseconds).Max(); + + // Verify burst handling is efficient + Assert.True(avgBurstTime < 1000, + $"Average burst time {avgBurstTime:F2}ms exceeds 1 second threshold. Processed {burstCount} bursts of {burstSize} messages. Max burst time: {maxBurstTime:F2}ms, Total time: {totalStopwatch.Elapsed.TotalMilliseconds:F2}ms"); + } + + [Fact] + public void ChannelBasedProcessing_LongRunning_MaintainsStablePerformance() + { + // Arrange + const int duration = 5; // seconds + const int targetRate = 1_000; // messages per second + var messagesReceived = 0; + var throughputSamples = new ConcurrentBag(); + + var config = new StandalonePubSubSubscriptionConfig() + .WithChannel("long-running-test") + .WithCallback((msg, ctx) => + { + Interlocked.Increment(ref messagesReceived); + }, null); + + // Act - Sustained message processing + var stopwatch = Stopwatch.StartNew(); + var sampleInterval = TimeSpan.FromSeconds(1); + var nextSampleTime = sampleInterval; + var lastSampleCount = 0; + + while (stopwatch.Elapsed < TimeSpan.FromSeconds(duration)) + { + // Send messages at target rate + for (int i = 0; i < targetRate / 10; i++) + { + var message = new PubSubMessage($"msg-{messagesReceived}", "long-running-test"); + config.Callback!(message, null); + } + + Thread.Sleep(100); // 100ms intervals + + // Sample throughput every second + if (stopwatch.Elapsed >= nextSampleTime) + { + var currentCount = messagesReceived; + var sampleThroughput = (currentCount - lastSampleCount) / sampleInterval.TotalSeconds; + throughputSamples.Add(sampleThroughput); + lastSampleCount = currentCount; + nextSampleTime += sampleInterval; + } + } + + stopwatch.Stop(); + + // Assert + var avgThroughput = throughputSamples.Average(); + var minThroughput = throughputSamples.Min(); + var maxThroughput = throughputSamples.Max(); + var throughputStdDev = Math.Sqrt(throughputSamples.Select(t => Math.Pow(t - avgThroughput, 2)).Average()); + + // Verify stable performance over time + Assert.True(avgThroughput >= targetRate * 0.8, + $"Average throughput {avgThroughput:F0} is below 80% of target rate {targetRate}. Processed {messagesReceived} messages over {duration} seconds. Min: {minThroughput:F0}, Max: {maxThroughput:F0}, StdDev: {throughputStdDev:F0}"); + + // Verify throughput stability (std dev should be less than 20% of average) + Assert.True(throughputStdDev < avgThroughput * 0.2, + $"Throughput std dev {throughputStdDev:F0} indicates unstable performance. Average: {avgThroughput:F0}, Min: {minThroughput:F0}, Max: {maxThroughput:F0}"); + } +} diff --git a/tests/Valkey.Glide.UnitTests/PubSubSubscriptionConfigTests.cs b/tests/Valkey.Glide.UnitTests/PubSubSubscriptionConfigTests.cs new file mode 100644 index 00000000..209d6bc7 --- /dev/null +++ b/tests/Valkey.Glide.UnitTests/PubSubSubscriptionConfigTests.cs @@ -0,0 +1,429 @@ +// Copyright Valkey GLIDE Project Contributors - SPDX Identifier: Apache-2.0 + +namespace Valkey.Glide.UnitTests; + +public class PubSubSubscriptionConfigTests +{ + #region StandalonePubSubSubscriptionConfig Tests + + [Fact] + public void StandaloneConfig_WithChannel_AddsExactChannelSubscription() + { + // Arrange + var config = new StandalonePubSubSubscriptionConfig(); + + // Act + var result = config.WithChannel("test-channel"); + + // Assert + Assert.Same(config, result); // Should return same instance for chaining + Assert.True(config.Subscriptions.ContainsKey((uint)PubSubChannelMode.Exact)); + Assert.Contains("test-channel", config.Subscriptions[(uint)PubSubChannelMode.Exact]); + } + + [Fact] + public void StandaloneConfig_WithPattern_AddsPatternSubscription() + { + // Arrange + var config = new StandalonePubSubSubscriptionConfig(); + + // Act + var result = config.WithPattern("test-*"); + + // Assert + Assert.Same(config, result); // Should return same instance for chaining + Assert.True(config.Subscriptions.ContainsKey((uint)PubSubChannelMode.Pattern)); + Assert.Contains("test-*", config.Subscriptions[(uint)PubSubChannelMode.Pattern]); + } + + [Fact] + public void StandaloneConfig_WithSubscription_AddsCorrectSubscription() + { + // Arrange + var config = new StandalonePubSubSubscriptionConfig(); + + // Act + var result = config.WithSubscription(PubSubChannelMode.Exact, "exact-channel"); + + // Assert + Assert.Same(config, result); + Assert.True(config.Subscriptions.ContainsKey((uint)PubSubChannelMode.Exact)); + Assert.Contains("exact-channel", config.Subscriptions[(uint)PubSubChannelMode.Exact]); + } + + [Fact] + public void StandaloneConfig_WithSubscription_NullOrEmptyChannel_ThrowsArgumentException() + { + // Arrange + var config = new StandalonePubSubSubscriptionConfig(); + + // Act & Assert + Assert.Throws(() => config.WithSubscription(PubSubChannelMode.Exact, null!)); + Assert.Throws(() => config.WithSubscription(PubSubChannelMode.Exact, "")); + Assert.Throws(() => config.WithSubscription(PubSubChannelMode.Exact, " ")); + } + + [Fact] + public void StandaloneConfig_WithSubscription_InvalidMode_ThrowsArgumentOutOfRangeException() + { + // Arrange + var config = new StandalonePubSubSubscriptionConfig(); + + // Act & Assert + Assert.Throws(() => config.WithSubscription((PubSubChannelMode)999, "test")); + } + + [Fact] + public void StandaloneConfig_WithSubscription_DuplicateChannel_DoesNotAddDuplicate() + { + // Arrange + var config = new StandalonePubSubSubscriptionConfig(); + + // Act + config.WithChannel("test-channel"); + config.WithChannel("test-channel"); // Add same channel again + + // Assert + Assert.Single(config.Subscriptions[(uint)PubSubChannelMode.Exact]); + Assert.Contains("test-channel", config.Subscriptions[(uint)PubSubChannelMode.Exact]); + } + + [Fact] + public void StandaloneConfig_WithCallback_SetsCallbackAndContext() + { + // Arrange + var config = new StandalonePubSubSubscriptionConfig(); + var context = new { TestData = "test" }; + MessageCallback callback = (message, ctx) => { }; + + // Act + var result = config.WithCallback(callback, context); + + // Assert + Assert.Same(config, result); + Assert.Same(callback, config.Callback); + Assert.Same(context, config.Context); + } + + [Fact] + public void StandaloneConfig_WithCallback_NullCallback_ThrowsArgumentNullException() + { + // Arrange + var config = new StandalonePubSubSubscriptionConfig(); + + // Act & Assert + Assert.Throws(() => config.WithCallback(null!)); + } + + [Fact] + public void StandaloneConfig_Validate_NoSubscriptions_ThrowsArgumentException() + { + // Arrange + var config = new StandalonePubSubSubscriptionConfig(); + + // Act & Assert + var exception = Assert.Throws(() => config.Validate()); + Assert.Contains("At least one subscription must be configured", exception.Message); + } + + [Fact] + public void StandaloneConfig_Validate_ValidConfiguration_DoesNotThrow() + { + // Arrange + var config = new StandalonePubSubSubscriptionConfig() + .WithChannel("test-channel"); + + // Act & Assert + config.Validate(); // Should not throw + } + + [Fact] + public void StandaloneConfig_BuilderPattern_SupportsMethodChaining() + { + // Arrange & Act + var config = new StandalonePubSubSubscriptionConfig() + .WithChannel("channel1") + .WithPattern("pattern*") + .WithCallback((msg, ctx) => { }, "context"); + + // Assert + Assert.True(config.Subscriptions.ContainsKey((uint)PubSubChannelMode.Exact)); + Assert.True(config.Subscriptions.ContainsKey((uint)PubSubChannelMode.Pattern)); + Assert.NotNull(config.Callback); + Assert.Equal("context", config.Context); + } + + #endregion + + #region ClusterPubSubSubscriptionConfig Tests + + [Fact] + public void ClusterConfig_WithChannel_AddsExactChannelSubscription() + { + // Arrange + var config = new ClusterPubSubSubscriptionConfig(); + + // Act + var result = config.WithChannel("test-channel"); + + // Assert + Assert.Same(config, result); + Assert.True(config.Subscriptions.ContainsKey((uint)PubSubClusterChannelMode.Exact)); + Assert.Contains("test-channel", config.Subscriptions[(uint)PubSubClusterChannelMode.Exact]); + } + + [Fact] + public void ClusterConfig_WithPattern_AddsPatternSubscription() + { + // Arrange + var config = new ClusterPubSubSubscriptionConfig(); + + // Act + var result = config.WithPattern("test-*"); + + // Assert + Assert.Same(config, result); + Assert.True(config.Subscriptions.ContainsKey((uint)PubSubClusterChannelMode.Pattern)); + Assert.Contains("test-*", config.Subscriptions[(uint)PubSubClusterChannelMode.Pattern]); + } + + [Fact] + public void ClusterConfig_WithShardedChannel_AddsShardedSubscription() + { + // Arrange + var config = new ClusterPubSubSubscriptionConfig(); + + // Act + var result = config.WithShardedChannel("sharded-channel"); + + // Assert + Assert.Same(config, result); + Assert.True(config.Subscriptions.ContainsKey((uint)PubSubClusterChannelMode.Sharded)); + Assert.Contains("sharded-channel", config.Subscriptions[(uint)PubSubClusterChannelMode.Sharded]); + } + + [Fact] + public void ClusterConfig_WithSubscription_AddsCorrectSubscription() + { + // Arrange + var config = new ClusterPubSubSubscriptionConfig(); + + // Act + var result = config.WithSubscription(PubSubClusterChannelMode.Sharded, "sharded-channel"); + + // Assert + Assert.Same(config, result); + Assert.True(config.Subscriptions.ContainsKey((uint)PubSubClusterChannelMode.Sharded)); + Assert.Contains("sharded-channel", config.Subscriptions[(uint)PubSubClusterChannelMode.Sharded]); + } + + [Fact] + public void ClusterConfig_WithSubscription_NullOrEmptyChannel_ThrowsArgumentException() + { + // Arrange + var config = new ClusterPubSubSubscriptionConfig(); + + // Act & Assert + Assert.Throws(() => config.WithSubscription(PubSubClusterChannelMode.Exact, null!)); + Assert.Throws(() => config.WithSubscription(PubSubClusterChannelMode.Exact, "")); + Assert.Throws(() => config.WithSubscription(PubSubClusterChannelMode.Exact, " ")); + } + + [Fact] + public void ClusterConfig_WithSubscription_InvalidMode_ThrowsArgumentOutOfRangeException() + { + // Arrange + var config = new ClusterPubSubSubscriptionConfig(); + + // Act & Assert + Assert.Throws(() => config.WithSubscription((PubSubClusterChannelMode)999, "test")); + } + + [Fact] + public void ClusterConfig_WithSubscription_DuplicateChannel_DoesNotAddDuplicate() + { + // Arrange + var config = new ClusterPubSubSubscriptionConfig(); + + // Act + config.WithShardedChannel("sharded-channel"); + config.WithShardedChannel("sharded-channel"); // Add same channel again + + // Assert + Assert.Single(config.Subscriptions[(uint)PubSubClusterChannelMode.Sharded]); + Assert.Contains("sharded-channel", config.Subscriptions[(uint)PubSubClusterChannelMode.Sharded]); + } + + [Fact] + public void ClusterConfig_WithCallback_SetsCallbackAndContext() + { + // Arrange + var config = new ClusterPubSubSubscriptionConfig(); + var context = new { TestData = "test" }; + MessageCallback callback = (message, ctx) => { }; + + // Act + var result = config.WithCallback(callback, context); + + // Assert + Assert.Same(config, result); + Assert.Same(callback, config.Callback); + Assert.Same(context, config.Context); + } + + [Fact] + public void ClusterConfig_WithCallback_NullCallback_ThrowsArgumentNullException() + { + // Arrange + var config = new ClusterPubSubSubscriptionConfig(); + + // Act & Assert + Assert.Throws(() => config.WithCallback(null!)); + } + + [Fact] + public void ClusterConfig_Validate_NoSubscriptions_ThrowsArgumentException() + { + // Arrange + var config = new ClusterPubSubSubscriptionConfig(); + + // Act & Assert + var exception = Assert.Throws(() => config.Validate()); + Assert.Contains("At least one subscription must be configured", exception.Message); + } + + [Fact] + public void ClusterConfig_Validate_ValidConfiguration_DoesNotThrow() + { + // Arrange + var config = new ClusterPubSubSubscriptionConfig() + .WithShardedChannel("test-channel"); + + // Act & Assert + config.Validate(); // Should not throw + } + + [Fact] + public void ClusterConfig_BuilderPattern_SupportsMethodChaining() + { + // Arrange & Act + var config = new ClusterPubSubSubscriptionConfig() + .WithChannel("channel1") + .WithPattern("pattern*") + .WithShardedChannel("sharded1") + .WithCallback((msg, ctx) => { }, "context"); + + // Assert + Assert.True(config.Subscriptions.ContainsKey((uint)PubSubClusterChannelMode.Exact)); + Assert.True(config.Subscriptions.ContainsKey((uint)PubSubClusterChannelMode.Pattern)); + Assert.True(config.Subscriptions.ContainsKey((uint)PubSubClusterChannelMode.Sharded)); + Assert.NotNull(config.Callback); + Assert.Equal("context", config.Context); + } + + #endregion + + #region MessageCallback Tests + + [Fact] + public void MessageCallback_CanBeInvoked() + { + // Arrange + bool messageReceived = false; + PubSubMessage? receivedMessage = null; + object? receivedContext = null; + + MessageCallback callback = (message, context) => + { + messageReceived = true; + receivedMessage = message; + receivedContext = context; + }; + + var testMessage = new PubSubMessage("test-message", "test-channel"); + string testContext = "test-context"; + + // Act + callback(testMessage, testContext); + + // Assert + Assert.True(messageReceived); + Assert.Same(testMessage, receivedMessage); + Assert.Same(testContext, receivedContext); + } + + #endregion + + #region Validation Tests + + [Fact] + public void BasePubSubSubscriptionConfig_Validate_EmptyChannelList_ThrowsArgumentException() + { + // Arrange + var config = new StandalonePubSubSubscriptionConfig(); + config.Subscriptions[(uint)PubSubChannelMode.Exact] = new List(); + + // Act & Assert + var exception = Assert.Throws(() => config.Validate()); + Assert.Contains("has no channels or patterns configured", exception.Message); + } + + [Fact] + public void BasePubSubSubscriptionConfig_Validate_NullChannelInList_ThrowsArgumentException() + { + // Arrange + var config = new StandalonePubSubSubscriptionConfig(); + config.Subscriptions[(uint)PubSubChannelMode.Exact] = new List { null! }; + + // Act & Assert + var exception = Assert.Throws(() => config.Validate()); + Assert.Contains("Channel name or pattern cannot be null, empty, or whitespace", exception.Message); + } + + [Fact] + public void BasePubSubSubscriptionConfig_Validate_EmptyChannelInList_ThrowsArgumentException() + { + // Arrange + var config = new StandalonePubSubSubscriptionConfig(); + config.Subscriptions[(uint)PubSubChannelMode.Exact] = new List { "" }; + + // Act & Assert + var exception = Assert.Throws(() => config.Validate()); + Assert.Contains("Channel name or pattern cannot be null, empty, or whitespace", exception.Message); + } + + [Fact] + public void BasePubSubSubscriptionConfig_Validate_WhitespaceChannelInList_ThrowsArgumentException() + { + // Arrange + var config = new StandalonePubSubSubscriptionConfig(); + config.Subscriptions[(uint)PubSubChannelMode.Exact] = new List { " " }; + + // Act & Assert + var exception = Assert.Throws(() => config.Validate()); + Assert.Contains("Channel name or pattern cannot be null, empty, or whitespace", exception.Message); + } + + #endregion + + #region Enum Tests + + [Fact] + public void PubSubChannelMode_HasCorrectValues() + { + // Assert + Assert.Equal(0, (int)PubSubChannelMode.Exact); + Assert.Equal(1, (int)PubSubChannelMode.Pattern); + } + + [Fact] + public void PubSubClusterChannelMode_HasCorrectValues() + { + // Assert + Assert.Equal(0, (int)PubSubClusterChannelMode.Exact); + Assert.Equal(1, (int)PubSubClusterChannelMode.Pattern); + Assert.Equal(2, (int)PubSubClusterChannelMode.Sharded); + } + + #endregion +} diff --git a/tests/Valkey.Glide.UnitTests/PubSubThreadSafetyTests.cs b/tests/Valkey.Glide.UnitTests/PubSubThreadSafetyTests.cs new file mode 100644 index 00000000..42cf7638 --- /dev/null +++ b/tests/Valkey.Glide.UnitTests/PubSubThreadSafetyTests.cs @@ -0,0 +1,260 @@ +// Copyright Valkey GLIDE Project Contributors - SPDX Identifier: Apache-2.0 + +using System; +using System.Collections.Concurrent; +using System.Collections.Generic; +using System.Linq; +using System.Threading; +using System.Threading.Tasks; + +using Xunit; + +namespace Valkey.Glide.UnitTests; + +/// +/// Comprehensive thread safety tests for PubSub handler access in BaseClient. +/// Tests concurrent access, disposal during message processing, and race condition scenarios. +/// +public class PubSubThreadSafetyTests +{ + [Fact] + public async Task PubSubHandler_ConcurrentMessageProcessing_NoRaceConditions() + { + // Arrange + var messagesReceived = new ConcurrentBag(); + var config = new StandalonePubSubSubscriptionConfig() + .WithChannel("test-channel") + .WithCallback((msg, ctx) => + { + messagesReceived.Add(msg); + Thread.Sleep(1); // Simulate some processing + }, null); + + var client = CreateMockClientWithPubSub(config); + + // Act - Process 100 messages concurrently from multiple threads + var tasks = Enumerable.Range(0, 100) + .Select(i => Task.Run(() => + { + var message = new PubSubMessage($"message-{i}", "test-channel"); + client.HandlePubSubMessage(message); + })) + .ToArray(); + + await Task.WhenAll(tasks); + + // Wait for all messages to be processed + await Task.Delay(500); + + // Assert + Assert.Equal(100, messagesReceived.Count); + Assert.Equal(100, messagesReceived.Distinct().Count()); // All messages should be unique + } + + [Fact] + public async Task PubSubHandler_DisposalDuringMessageProcessing_NoNullReferenceException() + { + // Arrange + var processingStarted = new ManualResetEventSlim(false); + var continueProcessing = new ManualResetEventSlim(false); + var exceptions = new ConcurrentBag(); + + var config = new StandalonePubSubSubscriptionConfig() + .WithChannel("test-channel") + .WithCallback((msg, ctx) => + { + processingStarted.Set(); + continueProcessing.Wait(TimeSpan.FromSeconds(5)); + Thread.Sleep(50); // Simulate processing + }, null); + + var client = CreateMockClientWithPubSub(config); + + // Act - Start message processing + var messageTask = Task.Run(() => + { + try + { + var message = new PubSubMessage("test-message", "test-channel"); + client.HandlePubSubMessage(message); + } + catch (Exception ex) + { + exceptions.Add(ex); + } + }); + + // Wait for processing to start + processingStarted.Wait(TimeSpan.FromSeconds(5)); + + // Dispose client while message is being processed + var disposeTask = Task.Run(() => + { + try + { + client.Dispose(); + } + catch (Exception ex) + { + exceptions.Add(ex); + } + }); + + // Allow message processing to continue + continueProcessing.Set(); + + await Task.WhenAll(messageTask, disposeTask); + + // Assert - No exceptions should occur + Assert.Empty(exceptions); + } + + [Fact] + public async Task PubSubQueue_ConcurrentAccess_ThreadSafe() + { + // Arrange + var config = new StandalonePubSubSubscriptionConfig() + .WithChannel("test-channel"); + + var client = CreateMockClientWithPubSub(config); + + // Act - Access PubSubQueue from multiple threads concurrently + var tasks = Enumerable.Range(0, 50) + .Select(_ => Task.Run(() => + { + var queue = client.PubSubQueue; + Assert.NotNull(queue); + })) + .ToArray(); + + await Task.WhenAll(tasks); + + // Assert - No exceptions should occur + Assert.NotNull(client.PubSubQueue); + } + + [Fact] + public async Task HasPubSubSubscriptions_ConcurrentAccess_ThreadSafe() + { + // Arrange + var config = new StandalonePubSubSubscriptionConfig() + .WithChannel("test-channel"); + + var client = CreateMockClientWithPubSub(config); + + // Act - Access HasPubSubSubscriptions from multiple threads + var tasks = Enumerable.Range(0, 100) + .Select(_ => Task.Run(() => + { + var hasSubscriptions = client.HasPubSubSubscriptions; + Assert.True(hasSubscriptions); + })) + .ToArray(); + + await Task.WhenAll(tasks); + + // Assert + Assert.True(client.HasPubSubSubscriptions); + } + + [Fact] + public async Task PubSubHandler_RapidCreateAndDispose_NoMemoryLeaks() + { + // Arrange & Act - Create and dispose clients rapidly + for (int i = 0; i < 50; i++) + { + var config = new StandalonePubSubSubscriptionConfig() + .WithChannel("test-channel") + .WithCallback((msg, ctx) => { }, null); + + var client = CreateMockClientWithPubSub(config); + + // Send a few messages + for (int j = 0; j < 5; j++) + { + var message = new PubSubMessage($"message-{j}", "test-channel"); + client.HandlePubSubMessage(message); + } + + // Dispose immediately + client.Dispose(); + } + + // Force GC to detect any memory leaks + GC.Collect(); + GC.WaitForPendingFinalizers(); + GC.Collect(); + + // Assert - Test passes if no exceptions occur + Assert.True(true); + await Task.CompletedTask; + } + + [Fact] + public async Task PubSubHandler_DisposalDuringCallback_CompletesWithoutHanging() + { + // Arrange - Create handler with slow callback + var disposeStarted = new ManualResetEventSlim(false); + var config = new StandalonePubSubSubscriptionConfig() + .WithChannel("test-channel") + .WithCallback((msg, ctx) => + { + // This callback will block during disposal + disposeStarted.Wait(TimeSpan.FromSeconds(10)); + }, null); + + var client = CreateMockClientWithPubSub(config); + + // Start a long-running message processing + var messageTask = Task.Run(() => + { + var message = new PubSubMessage("test-message", "test-channel"); + client.HandlePubSubMessage(message); + }); + + await Task.Delay(100); // Let message processing start + + // Act - Dispose should complete without hanging + var disposeTask = Task.Run(() => client.Dispose()); + + // Allow disposal to proceed after a short delay + await Task.Delay(100); + disposeStarted.Set(); + + await Task.WhenAll(messageTask, disposeTask); + + // Assert - Test passes if disposal completes + Assert.True(true); + } + + /// + /// Helper method to create a mock client with PubSub configuration for testing. + /// This simulates client creation without requiring actual server connection. + /// + private static TestableBaseClient CreateMockClientWithPubSub(BasePubSubSubscriptionConfig? config) + { + var client = new TestableBaseClient(); + client.InitializePubSubHandlerForTest(config); + return client; + } + + /// + /// Testable version of BaseClient that exposes internal methods for testing. + /// + private class TestableBaseClient : BaseClient + { + public void InitializePubSubHandlerForTest(BasePubSubSubscriptionConfig? config) + { + // Use reflection to call private InitializePubSubHandler method + var method = typeof(BaseClient).GetMethod("InitializePubSubHandler", + System.Reflection.BindingFlags.NonPublic | System.Reflection.BindingFlags.Instance); + method?.Invoke(this, new object?[] { config }); + } + + protected override Task InitializeServerVersionAsync() + { + _serverVersion = new Version(7, 2, 0); + return Task.CompletedTask; + } + } +} diff --git a/tests/Valkey.Glide.UnitTests/Valkey.Glide.UnitTests.csproj b/tests/Valkey.Glide.UnitTests/Valkey.Glide.UnitTests.csproj index 6172a794..88dc60cc 100644 --- a/tests/Valkey.Glide.UnitTests/Valkey.Glide.UnitTests.csproj +++ b/tests/Valkey.Glide.UnitTests/Valkey.Glide.UnitTests.csproj @@ -34,7 +34,7 @@ true - $(NoWarn);CS1591;CS1573;CS1587 + $(NoWarn);CS1591;CS1573;CS1587;IDE0130 @@ -64,6 +64,8 @@ gs + +