diff --git a/.gitignore b/.gitignore index e64a7689..13e7108c 100644 --- a/.gitignore +++ b/.gitignore @@ -143,10 +143,11 @@ $RECYCLE.BIN/ # Mac desktop service store files .DS_Store -# IDE generaged files +# IDE generated files .vs .vscode .kiro/ +.amazonq/ _NCrunch* diff --git a/rust/src/ffi.rs b/rust/src/ffi.rs index 4a4569e2..d72f4514 100644 --- a/rust/src/ffi.rs +++ b/rust/src/ffi.rs @@ -317,7 +317,7 @@ pub(crate) unsafe fn create_route( /// * `data`, `data_len` and also each pointer stored in `data` must be able to be safely casted to a valid to a slice of the corresponding type via [`from_raw_parts`]. /// See the safety documentation of [`from_raw_parts`]. /// * The caller is responsible of freeing the allocated memory. -pub(crate) unsafe fn convert_double_pointer_to_vec<'a>( +pub(crate) unsafe fn convert_string_pointer_array_to_vector<'a>( data: *const *const u8, len: usize, data_len: *const usize, @@ -530,11 +530,11 @@ pub struct BatchOptionsInfo { /// * `cmd_ptr` must be able to be safely casted to a valid [`CmdInfo`] /// * `args` and `args_len` in a referred [`CmdInfo`] structure must not be `null`. /// * `data` in a referred [`CmdInfo`] structure must point to `arg_count` consecutive string pointers. -/// * `args_len` in a referred [`CmdInfo`] structure must point to `arg_count` consecutive string lengths. See the safety documentation of [`convert_double_pointer_to_vec`]. +/// * `args_len` in a referred [`CmdInfo`] structure must point to `arg_count` consecutive string lengths. See the safety documentation of [`convert_string_pointer_array_to_vector`]. pub(crate) unsafe fn create_cmd(ptr: *const CmdInfo) -> Result { let info = unsafe { *ptr }; let arg_vec = - unsafe { convert_double_pointer_to_vec(info.args, info.arg_count, info.args_len) }; + unsafe { convert_string_pointer_array_to_vector(info.args, info.arg_count, info.args_len) }; let Some(mut cmd) = info.request_type.get_command() else { return Err("Couldn't fetch command type".into()); diff --git a/rust/src/lib.rs b/rust/src/lib.rs index 9b857a75..3515d526 100644 --- a/rust/src/lib.rs +++ b/rust/src/lib.rs @@ -11,6 +11,7 @@ use glide_core::{ }; use std::{ ffi::{CStr, CString, c_char, c_void}, + slice::from_raw_parts, sync::Arc, }; use tokio::runtime::{Builder, Runtime}; @@ -438,3 +439,305 @@ pub unsafe extern "C" fn init(level: Option, file_name: *const c_char) -> let logger_level = logger_core::init(level.map(|level| level.into()), file_name_as_str); logger_level.into() } + +/// Execute a cluster scan request. +/// +/// # Safety +/// * `client_ptr` must be a valid Client pointer from create_client +/// * `cursor` must be "0" for initial scan or a valid cursor ID from previous scan +/// * `args` and `arg_lengths` must be valid arrays of length `arg_count` +/// * `args` format: [b"MATCH", pattern_arg, b"COUNT", count, b"TYPE", type] (all optional) +#[unsafe(no_mangle)] +pub unsafe extern "C-unwind" fn request_cluster_scan( + client_ptr: *const c_void, + callback_index: usize, + cursor: *const c_char, + arg_count: u64, + args: *const usize, + arg_lengths: *const u64, +) { + // Build client and add panic guard. + let client = unsafe { + Arc::increment_strong_count(client_ptr); + Arc::from_raw(client_ptr as *mut Client) + }; + let core = client.core.clone(); + + let mut panic_guard = PanicGuard { + panicked: true, + failure_callback: core.failure_callback, + callback_index, + }; + + // Get the cluster scan state. + let cursor_id = unsafe { CStr::from_ptr(cursor) } + .to_str() + .unwrap_or("0") + .to_owned(); + + let scan_state_cursor = if cursor_id == "0" { + redis::ScanStateRC::new() + } else { + match glide_core::cluster_scan_container::get_cluster_scan_cursor(cursor_id.clone()) { + Ok(existing_cursor) => existing_cursor, + Err(_error) => { + unsafe { + (core.failure_callback)( + callback_index, + format!("Invalid cursor ID: {}", cursor_id).as_ptr() as *const c_char, + RequestErrorType::Unspecified, + ); + } + return; + } + } + }; + + // Build cluster scan arguments. + let cluster_scan_args = match unsafe { + build_cluster_scan_args( + arg_count, + args, + arg_lengths, + core.failure_callback, + callback_index, + ) + } { + Some(args) => args, + None => return, + }; + + // Run cluster scan. + client.runtime.spawn(async move { + let mut async_panic_guard = PanicGuard { + panicked: true, + failure_callback: core.failure_callback, + callback_index, + }; + + let result = core + .client + .clone() + .cluster_scan(&scan_state_cursor, cluster_scan_args) + .await; + match result { + Ok(value) => { + let ptr = Box::into_raw(Box::new(ResponseValue::from_value(value))); + unsafe { (core.success_callback)(callback_index, ptr) }; + } + Err(err) => unsafe { + report_error( + core.failure_callback, + callback_index, + glide_core::errors::error_message(&err), + glide_core::errors::error_type(&err), + ); + }, + }; + + async_panic_guard.panicked = false; + }); + + panic_guard.panicked = false; +} + +/// Remove a cluster scan cursor from the Rust core container. +/// +/// This should be called when the C# ClusterScanCursor is disposed or finalized +/// to clean up resources allocated by the Rust core for cluster scan operations. +/// +/// # Safety +/// * `cursor_id` must be a valid C string or null +#[unsafe(no_mangle)] +pub unsafe extern "C" fn remove_cluster_scan_cursor(cursor_id: *const c_char) { + if cursor_id.is_null() { + return; + } + + if let Ok(cursor_str) = unsafe { CStr::from_ptr(cursor_id).to_str() } { + glide_core::cluster_scan_container::remove_scan_state_cursor(cursor_str.to_string()); + } +} + +/// Build cluster scan arguments from C-style arrays. +/// +/// # Arguments +/// +/// * `arg_count` - The number of arguments in the arrays +/// * `args` - Pointer to an array of pointers to argument data +/// * `arg_lengths` - Pointer to an array of argument lengths +/// * `failure_callback` - Callback function to invoke on error +/// * `callback_index` - Index to pass to the callback function +/// +/// # Safety +/// * `args` and `arg_lengths` must be valid arrays of length `arg_count` +/// * Each pointer in `args` must point to valid memory of the corresponding length +unsafe fn build_cluster_scan_args( + arg_count: u64, + args: *const usize, + arg_lengths: *const u64, + failure_callback: FailureCallback, + callback_index: usize, +) -> Option { + if arg_count == 0 { + return Some(redis::ClusterScanArgs::builder().build()); + } + + let arg_vec = unsafe { convert_string_pointer_array_to_vector(args, arg_count, arg_lengths) }; + + // Parse arguments from vector. + let mut pattern_arg: &[u8] = &[]; + let mut type_arg: &[u8] = &[]; + let mut count_arg: &[u8] = &[]; + + let mut iter = arg_vec.iter().peekable(); + while let Some(arg) = iter.next() { + match *arg { + b"MATCH" => match iter.next() { + Some(p) => pattern_arg = p, + None => { + unsafe { + report_error( + failure_callback, + callback_index, + "No argument following MATCH.".into(), + RequestErrorType::Unspecified, + ); + } + return None; + } + }, + b"TYPE" => match iter.next() { + Some(t) => type_arg = t, + None => { + unsafe { + report_error( + failure_callback, + callback_index, + "No argument following TYPE.".into(), + RequestErrorType::Unspecified, + ); + } + return None; + } + }, + b"COUNT" => match iter.next() { + Some(c) => count_arg = c, + None => { + unsafe { + report_error( + failure_callback, + callback_index, + "No argument following COUNT.".into(), + RequestErrorType::Unspecified, + ); + } + return None; + } + }, + _ => { + unsafe { + report_error( + failure_callback, + callback_index, + "Unknown cluster scan argument".into(), + RequestErrorType::Unspecified, + ); + } + return None; + } + } + } + + // Build cluster scan arguments. + let mut cluster_scan_args_builder = redis::ClusterScanArgs::builder(); + + if !pattern_arg.is_empty() { + cluster_scan_args_builder = cluster_scan_args_builder.with_match_pattern(pattern_arg); + } + + if !type_arg.is_empty() { + let converted_type = match std::str::from_utf8(type_arg) { + Ok(t) => redis::ObjectType::from(t.to_string()), + Err(_) => { + unsafe { + report_error( + failure_callback, + callback_index, + "Invalid UTF-8 in TYPE argument".into(), + RequestErrorType::Unspecified, + ); + } + return None; + } + }; + + cluster_scan_args_builder = cluster_scan_args_builder.with_object_type(converted_type); + } + + if !count_arg.is_empty() { + let count_str = match std::str::from_utf8(count_arg) { + Ok(c) => c, + Err(_) => { + unsafe { + report_error( + failure_callback, + callback_index, + "Invalid UTF-8 in COUNT argument".into(), + RequestErrorType::Unspecified, + ); + } + return None; + } + }; + + let converted_count = match count_str.parse::() { + Ok(c) => c, + Err(_) => { + unsafe { + report_error( + failure_callback, + callback_index, + "Invalid COUNT value".into(), + RequestErrorType::Unspecified, + ); + } + return None; + } + }; + + cluster_scan_args_builder = cluster_scan_args_builder.with_count(converted_count); + } + + Some(cluster_scan_args_builder.build()) +} + +/// Converts an array of pointers to strings to a vector of strings. +/// +/// # Arguments +/// +/// * `data` - Pointer to an array of pointers to string data +/// * `len` - The number of strings in the array +/// * `data_len` - Pointer to an array of string lengths +/// +/// # Safety +/// +/// `convert_string_pointer_array_to_vector` returns a `Vec` of u8 slice which holds pointers of C +/// strings. The returned `Vec<&'a [u8]>` is meant to be copied into Rust code. Storing them +/// for later use will cause the program to crash as the pointers will be freed by the caller. +unsafe fn convert_string_pointer_array_to_vector<'a>( + data: *const usize, + len: u64, + data_len: *const u64, +) -> Vec<&'a [u8]> { + let string_ptrs = unsafe { from_raw_parts(data, len as usize) }; + let string_lengths = unsafe { from_raw_parts(data_len, len as usize) }; + + let mut result = Vec::<&[u8]>::with_capacity(string_ptrs.len()); + for (i, &str_ptr) in string_ptrs.iter().enumerate() { + let slice = unsafe { from_raw_parts(str_ptr as *const u8, string_lengths[i] as usize) }; + result.push(slice); + } + + result +} diff --git a/sources/Valkey.Glide/BaseClient.cs b/sources/Valkey.Glide/BaseClient.cs index 41ccbe90..4573d74e 100644 --- a/sources/Valkey.Glide/BaseClient.cs +++ b/sources/Valkey.Glide/BaseClient.cs @@ -22,21 +22,21 @@ public void Dispose() GC.SuppressFinalize(this); lock (_lock) { - if (_clientPointer == IntPtr.Zero) + if (ClientPointer == IntPtr.Zero) { return; } - _messageContainer.DisposeWithError(null); - CloseClientFfi(_clientPointer); - _clientPointer = IntPtr.Zero; + MessageContainer.DisposeWithError(null); + CloseClientFfi(ClientPointer); + ClientPointer = IntPtr.Zero; } } public async ValueTask DisposeAsync() => await Task.Run(Dispose); - public override string ToString() => $"{GetType().Name} {{ 0x{_clientPointer:X} {_clientInfo} }}"; + public override string ToString() => $"{GetType().Name} {{ 0x{ClientPointer:X} {_clientInfo} }}"; - public override int GetHashCode() => (int)_clientPointer; + public override int GetHashCode() => (int)ClientPointer; #endregion public methods @@ -49,11 +49,11 @@ protected static async Task CreateClient(BaseClientConfiguration config, F nint failureCallbackPointer = Marshal.GetFunctionPointerForDelegate(client._failureCallbackDelegate); using FFI.ConnectionConfig request = config.Request.ToFfi(); - Message message = client._messageContainer.GetMessageForCall(); + Message message = client.MessageContainer.GetMessageForCall(); CreateClientFfi(request.ToPtr(), successCallbackPointer, failureCallbackPointer); - client._clientPointer = await message; // This will throw an error thru failure callback if any + client.ClientPointer = await message; // This will throw an error thru failure callback if any - if (client._clientPointer == IntPtr.Zero) + if (client.ClientPointer == IntPtr.Zero) { throw new ConnectionException("Failed creating a client"); } @@ -65,7 +65,7 @@ protected BaseClient() { _successCallbackDelegate = SuccessCallback; _failureCallbackDelegate = FailureCallback; - _messageContainer = new(this); + MessageContainer = new(this); } protected internal delegate T ResponseHandler(IntPtr response); @@ -83,8 +83,8 @@ internal virtual async Task Command(Cmd command, Route? route = n using FFI.Route? ffiRoute = route?.ToFfi(); // 3. Sumbit request to the rust part - Message message = _messageContainer.GetMessageForCall(); - CommandFfi(_clientPointer, (ulong)message.Index, cmd.ToPtr(), ffiRoute?.ToPtr() ?? IntPtr.Zero); + Message message = MessageContainer.GetMessageForCall(); + CommandFfi(ClientPointer, (ulong)message.Index, cmd.ToPtr(), ffiRoute?.ToPtr() ?? IntPtr.Zero); // 4. Get a response and Handle it IntPtr response = await message; @@ -109,8 +109,8 @@ internal virtual async Task Command(Cmd command, Route? route = n using FFI.BatchOptions? ffiOptions = options?.ToFfi(); // 3. Sumbit request to the rust part - Message message = _messageContainer.GetMessageForCall(); - BatchFfi(_clientPointer, (ulong)message.Index, ffiBatch.ToPtr(), raiseOnError, ffiOptions?.ToPtr() ?? IntPtr.Zero); + Message message = MessageContainer.GetMessageForCall(); + BatchFfi(ClientPointer, (ulong)message.Index, ffiBatch.ToPtr(), raiseOnError, ffiOptions?.ToPtr() ?? IntPtr.Zero); // 4. Get a response and Handle it IntPtr response = await message; @@ -132,16 +132,27 @@ internal virtual async Task Command(Cmd command, Route? route = n } #endregion protected methods + #region protected fields + protected Version? _serverVersion; // cached server version + protected static readonly Version DefaultServerVersion = new(8, 0, 0); + #endregion protected fields + + #region internal fields + /// Raw pointer to the underlying native client. + internal IntPtr ClientPointer; + internal readonly MessageContainer MessageContainer; + #endregion internal fields + #region private methods private void SuccessCallback(ulong index, IntPtr ptr) => // Work needs to be offloaded from the calling thread, because otherwise we might starve the client's thread pool. - Task.Run(() => _messageContainer.GetMessage((int)index).SetResult(ptr)); + Task.Run(() => MessageContainer.GetMessage((int)index).SetResult(ptr)); private void FailureCallback(ulong index, IntPtr strPtr, RequestErrorType errType) { string str = Marshal.PtrToStringAnsi(strPtr)!; // Work needs to be offloaded from the calling thread, because otherwise we might starve the client's thread pool. - _ = Task.Run(() => _messageContainer.GetMessage((int)index).SetException(Create(errType, str))); + _ = Task.Run(() => MessageContainer.GetMessage((int)index).SetException(Create(errType, str))); } ~BaseClient() => Dispose(); @@ -166,13 +177,8 @@ private void FailureCallback(ulong index, IntPtr strPtr, RequestErrorType errTyp /// and held in order to prevent the cost of marshalling on each function call. private readonly SuccessAction _successCallbackDelegate; - /// Raw pointer to the underlying native client. - private IntPtr _clientPointer; - private readonly MessageContainer _messageContainer; private readonly object _lock = new(); private string _clientInfo = ""; // used to distinguish and identify clients during tests - protected Version? _serverVersion; // cached server version - protected static readonly Version DefaultServerVersion = new(8, 0, 0); #endregion private fields } diff --git a/sources/Valkey.Glide/Commands/Constants/Constants.cs b/sources/Valkey.Glide/Commands/Constants/Constants.cs index 35d4e2cc..6c413883 100644 --- a/sources/Valkey.Glide/Commands/Constants/Constants.cs +++ b/sources/Valkey.Glide/Commands/Constants/Constants.cs @@ -19,6 +19,7 @@ public static class Constants public const string ByScoreKeyword = "BYSCORE"; public const string MatchKeyword = "MATCH"; public const string CountKeyword = "COUNT"; + public const string TypeKeyword = "TYPE"; public const string LeftKeyword = "LEFT"; public const string RightKeyword = "RIGHT"; public const string BeforeKeyword = "BEFORE"; diff --git a/sources/Valkey.Glide/Commands/IGenericClusterCommands.cs b/sources/Valkey.Glide/Commands/IGenericClusterCommands.cs index c83e18e3..e48b0251 100644 --- a/sources/Valkey.Glide/Commands/IGenericClusterCommands.cs +++ b/sources/Valkey.Glide/Commands/IGenericClusterCommands.cs @@ -1,5 +1,6 @@ // Copyright Valkey GLIDE Project Contributors - SPDX Identifier: Apache-2.0 +using Valkey.Glide.Commands.Options; using Valkey.Glide.Pipeline; using static Valkey.Glide.Errors; @@ -251,4 +252,32 @@ public interface IGenericClusterCommands /// or if a transaction failed due to a WATCH command. /// Task Exec(ClusterBatch batch, bool raiseOnError, ClusterBatchOptions options); + + /// + /// Incrementally iterates over the matching keys in the cluster. + /// + /// The cluster SCAN command is a cursor-based iterator. An iteration starts when the cursor is set to + /// . At every call of the command, the server returns + /// an updated cursor that the user needs to use as the cursor argument in the next call. The iteration + /// terminates when returns true. + /// + /// + /// The cursor for iteration. + /// Optional scan options for filtering results. + /// The next cursor and an array of matching keys. + /// + /// + /// var allKeys = new List<ValkeyKey>(); + /// var cursor = ClusterScanCursor.InitialCursor(); + /// + /// while (!cursor.IsFinished) + /// { + /// (cursor, var keys) = await client.ScanAsync(cursor); + /// allKeys.AddRange(keys); + /// } + /// + /// + /// SCAN command + /// Cluster Scan + Task<(ClusterScanCursor cursor, ValkeyKey[] keys)> ScanAsync(ClusterScanCursor cursor, ScanOptions? options = null); } diff --git a/sources/Valkey.Glide/Commands/IGenericCommands.cs b/sources/Valkey.Glide/Commands/IGenericCommands.cs index b230c415..4dc1c8a2 100644 --- a/sources/Valkey.Glide/Commands/IGenericCommands.cs +++ b/sources/Valkey.Glide/Commands/IGenericCommands.cs @@ -1,5 +1,6 @@ // Copyright Valkey GLIDE Project Contributors - SPDX Identifier: Apache-2.0 +using Valkey.Glide.Commands.Options; using Valkey.Glide.Pipeline; using static Valkey.Glide.Errors; @@ -167,7 +168,30 @@ public interface IGenericCommands /// Task Exec(Batch batch, bool raiseOnError, BatchOptions options); - - - + /// + /// Incrementally iterates over the matching keys in the database. + /// + /// The SCAN command is a cursor-based iterator. An iteration starts when the cursor + /// is set to "0". At every call of the command, the + /// server returns an updated cursor that the user needs to use as the cursor argument in the next + /// call. The iteration terminates when the cursor is "0". + /// + /// + /// The cursor for iteration. + /// Optional scan options for filtering results. + /// The next cursor and an array of matching keys. + /// + /// + /// var allKeys = new List<ValkeyKey>(); + /// string cursor = "0"; + /// + /// do + /// { + /// (cursor, var keys) = await client.ScanAsync(cursor); + /// allKeys.AddRange(keys); + /// } while (cursor != "0"); + /// + /// + /// SCAN command + Task<(string cursor, ValkeyKey[] keys)> ScanAsync(string cursor, ScanOptions? options = null); } diff --git a/sources/Valkey.Glide/Commands/Options/ScanOptions.cs b/sources/Valkey.Glide/Commands/Options/ScanOptions.cs new file mode 100644 index 00000000..444d61b3 --- /dev/null +++ b/sources/Valkey.Glide/Commands/Options/ScanOptions.cs @@ -0,0 +1,66 @@ +// Copyright Valkey GLIDE Project Contributors - SPDX Identifier: Apache-2.0 + +using static Valkey.Glide.Commands.Constants.Constants; + +namespace Valkey.Glide.Commands.Options; + +/// +/// Options for the scan commands. +/// +public class ScanOptions +{ + /// + /// Pattern to filter keys against. + /// + public string? MatchPattern { get; set; } + + /// + /// Hint for the number of keys to return per iteration. + /// + public long? Count { get; set; } + + /// + /// Type to filter keys against. + /// + public ValkeyType? Type { get; set; } + + /// + /// Converts the options to an array of string arguments for scan commands. + /// + /// Array of string arguments. + internal string[] ToArgs() + { + List args = []; + + if (MatchPattern != null) + { + args.Add(MatchKeyword); + args.Add(MatchPattern); + } + + if (Count.HasValue) + { + args.Add(CountKeyword); + args.Add(Count.Value.ToString()); + } + + if (Type.HasValue) + { + args.Add(TypeKeyword); + args.Add(MapValkeyTypeToString(Type.Value)); + } + + return [.. args]; + } + + private static string MapValkeyTypeToString(ValkeyType type) => type switch + { + ValkeyType.String => "string", + ValkeyType.List => "list", + ValkeyType.Set => "set", + ValkeyType.SortedSet => "zset", + ValkeyType.Hash => "hash", + ValkeyType.Stream => "stream", + ValkeyType.Unknown or ValkeyType.None or _ => throw new ArgumentException($"Unsupported ValkeyType for SCAN: {type}") + }; +} diff --git a/sources/Valkey.Glide/GlideClient.cs b/sources/Valkey.Glide/GlideClient.cs index d15f7865..7b25adc7 100644 --- a/sources/Valkey.Glide/GlideClient.cs +++ b/sources/Valkey.Glide/GlideClient.cs @@ -91,10 +91,6 @@ public async Task PingAsync(ValkeyValue message, CommandFlags flags = return await Command(Request.Ping(message)); } - - - - public async Task[]> ConfigGetAsync(ValkeyValue pattern = default, CommandFlags flags = CommandFlags.None) { Utils.Requires(flags == CommandFlags.None, "Command flags are not supported by GLIDE"); @@ -177,25 +173,35 @@ public async IAsyncEnumerable KeysAsync(int database = -1, ValkeyValu { Utils.Requires(flags == CommandFlags.None, "Command flags are not supported by GLIDE"); - long currentCursor = cursor; + var options = new ScanOptions(); + if (!pattern.IsNull) options.MatchPattern = pattern.ToString(); + if (pageSize > 0) options.Count = pageSize; + + string currentCursor = cursor.ToString(); + ValkeyKey[] keys; int currentOffset = pageOffset; do { - (long nextCursor, ValkeyKey[] keys) = await Command(Request.ScanAsync(currentCursor, pattern, pageSize)); + (currentCursor, keys) = await ScanAsync(currentCursor, options); - IEnumerable keysToYield = currentOffset > 0 ? keys.Skip(currentOffset) : keys; + if (currentOffset > 0) + { + keys = [.. keys.Skip(currentOffset)]; + currentOffset = 0; + } - foreach (ValkeyKey key in keysToYield) + foreach (ValkeyKey key in keys) { yield return key; } - currentCursor = nextCursor; - currentOffset = 0; - } while (currentCursor != 0); + } while (currentCursor != "0"); } + public async Task<(string cursor, ValkeyKey[] keys)> ScanAsync(string cursor, ScanOptions? options = null) + => await Command(Request.ScanAsync(cursor, options)); + protected override async Task GetServerVersionAsync() { if (_serverVersion == null) diff --git a/sources/Valkey.Glide/GlideClusterClient.cs b/sources/Valkey.Glide/GlideClusterClient.cs index 2fb4ec03..e866a58c 100644 --- a/sources/Valkey.Glide/GlideClusterClient.cs +++ b/sources/Valkey.Glide/GlideClusterClient.cs @@ -1,5 +1,7 @@ // Copyright Valkey GLIDE Project Contributors - SPDX Identifier: Apache-2.0 +using System.Runtime.InteropServices; + using Valkey.Glide.Commands; using Valkey.Glide.Commands.Options; using Valkey.Glide.Internals; @@ -7,6 +9,8 @@ using static Valkey.Glide.ConnectionConfiguration; using static Valkey.Glide.Errors; +using static Valkey.Glide.Internals.FFI; +using static Valkey.Glide.Internals.ResponseHandler; using static Valkey.Glide.Pipeline.Options; using static Valkey.Glide.Route; @@ -318,4 +322,96 @@ protected override async Task GetServerVersionAsync() return _serverVersion; } + + /// + /// Iterates incrementally over keys in the cluster. + /// + /// The cursor to use for this iteration. + /// Optional scan options to filter results. + /// A tuple containing the next cursor and the keys found in this iteration. + /// + /// + public async Task<(ClusterScanCursor cursor, ValkeyKey[] keys)> ScanAsync(ClusterScanCursor cursor, ScanOptions? options = null) + { + string[] args = options?.ToArgs() ?? []; + var (nextCursorId, keys) = await ClusterScanCommand(cursor.CursorId, args); + return (new ClusterScanCursor(nextCursorId), keys); + } + + /// + /// Executes a cluster scan command with the given cursor and arguments. + /// + /// The cursor for the scan iteration. + /// Additional arguments for the scan command. + /// A tuple containing the next cursor and the keys found in this iteration. + private async Task<(string cursor, ValkeyKey[] keys)> ClusterScanCommand(string cursor, string[] args) + { + var message = MessageContainer.GetMessageForCall(); + IntPtr cursorPtr = Marshal.StringToHGlobalAnsi(cursor); + + IntPtr[]? argPtrs = null; + IntPtr argsPtr = IntPtr.Zero; + IntPtr argLengthsPtr = IntPtr.Zero; + + try + { + if (args.Length > 0) + { + // 1. Get a pointer to the array of argument string pointers. + // Example: if args = ["MATCH", "key*"], then argPtrs[0] points + // to "MATCH", argPtrs[1] points to "key*", and argsPtr points + // to the argsPtrs array. + argPtrs = [.. args.Select(Marshal.StringToHGlobalAnsi)]; + argsPtr = Marshal.AllocHGlobal(IntPtr.Size * args.Length); + Marshal.Copy(argPtrs, 0, argsPtr, args.Length); + + // 2. Get a pointer to an array of argument string lengths. + // Example: if args = ["MATCH", "key*"], then argLengths[0] = 5 + // (length of "MATCH"), argLengths[1] = 4 (length of "key*"), + // and argLengthsPtr points to the argLengths array. + var argLengths = args.Select(arg => (ulong)arg.Length).ToArray(); + argLengthsPtr = Marshal.AllocHGlobal(sizeof(ulong) * args.Length); + Marshal.Copy(argLengths.Select(l => (long)l).ToArray(), 0, argLengthsPtr, args.Length); + } + + // Submit request to Rust and wait for response. + RequestClusterScanFfi(ClientPointer, (ulong)message.Index, cursorPtr, (ulong)args.Length, argsPtr, argLengthsPtr); + IntPtr response = await message; + + try + { + var result = HandleResponse(response); + var array = (object[])result!; + var nextCursor = array[0]!.ToString()!; + var keys = ((object[])array[1]!).Select(k => new ValkeyKey(k!.ToString())).ToArray(); + return (nextCursor, keys); + } + finally + { + FreeResponse(response); + } + } + finally + { + // Clean up args memory + if (argLengthsPtr != IntPtr.Zero) + { + Marshal.FreeHGlobal(argLengthsPtr); + } + + if (argsPtr != IntPtr.Zero) + { + Marshal.FreeHGlobal(argsPtr); + } + + if (argPtrs != null) + { + Array.ForEach(argPtrs, Marshal.FreeHGlobal); + } + + // Clean up cursor in Rust + RemoveClusterScanCursorFfi(cursorPtr); + Marshal.FreeHGlobal(cursorPtr); + } + } } diff --git a/sources/Valkey.Glide/Internals/FFI.methods.cs b/sources/Valkey.Glide/Internals/FFI.methods.cs index 0d8d7de9..6bd49dbd 100644 --- a/sources/Valkey.Glide/Internals/FFI.methods.cs +++ b/sources/Valkey.Glide/Internals/FFI.methods.cs @@ -30,6 +30,14 @@ internal partial class FFI [LibraryImport("libglide_rs", EntryPoint = "close_client")] [UnmanagedCallConv(CallConvs = [typeof(CallConvCdecl)])] public static partial void CloseClientFfi(IntPtr client); + + [LibraryImport("libglide_rs", EntryPoint = "request_cluster_scan")] + [UnmanagedCallConv(CallConvs = [typeof(CallConvCdecl)])] + public static partial void RequestClusterScanFfi(IntPtr client, ulong index, IntPtr cursor, ulong argCount, IntPtr args, IntPtr argLengths); + + [LibraryImport("libglide_rs", EntryPoint = "remove_cluster_scan_cursor")] + [UnmanagedCallConv(CallConvs = [typeof(CallConvCdecl)])] + public static partial void RemoveClusterScanCursorFfi(IntPtr cursorId); #else [DllImport("libglide_rs", CallingConvention = CallingConvention.Cdecl, EntryPoint = "command")] public static extern void CommandFfi(IntPtr client, ulong index, IntPtr cmdInfo, IntPtr routeInfo); @@ -45,5 +53,11 @@ internal partial class FFI [DllImport("libglide_rs", CallingConvention = CallingConvention.Cdecl, EntryPoint = "close_client")] public static extern void CloseClientFfi(IntPtr client); + + [DllImport("libglide_rs", CallingConvention = CallingConvention.Cdecl, EntryPoint = "request_cluster_scan")] + public static extern void RequestClusterScanFfi(IntPtr client, ulong index, IntPtr cursor, ulong argCount, IntPtr args, IntPtr argLengths); + + [DllImport("libglide_rs", CallingConvention = CallingConvention.Cdecl, EntryPoint = "remove_cluster_scan_cursor")] + public static extern void RemoveClusterScanCursorFfi(IntPtr cursorId); #endif } diff --git a/sources/Valkey.Glide/Internals/Request.GenericCommands.cs b/sources/Valkey.Glide/Internals/Request.GenericCommands.cs index 0ec98816..062f2986 100644 --- a/sources/Valkey.Glide/Internals/Request.GenericCommands.cs +++ b/sources/Valkey.Glide/Internals/Request.GenericCommands.cs @@ -318,25 +318,19 @@ public static Cmd SortAndStoreAsync(ValkeyKey destination, ValkeyKey public static Cmd KeyMoveAsync(ValkeyKey key, int database) => Simple(RequestType.Move, [key.ToGlideString(), database.ToGlideString()]); - public static Cmd ScanAsync(long cursor, ValkeyValue pattern = default, long pageSize = 0) + public static Cmd ScanAsync(string cursor, ScanOptions? options = null) { List args = [cursor.ToGlideString()]; - if (!pattern.IsNull) + if (options != null) { - args.AddRange([Constants.MatchKeyword.ToGlideString(), pattern.ToGlideString()]); - } - - if (pageSize > 0) - { - args.AddRange([Constants.CountKeyword.ToGlideString(), pageSize.ToGlideString()]); + args.AddRange(options.ToArgs().Select(arg => arg.ToGlideString())); } return new(RequestType.Scan, [.. args], false, arr => { - object[] scanArray = arr; - long nextCursor = scanArray[0] is long l ? l : long.Parse(scanArray[0].ToString() ?? "0"); - ValkeyKey[] keys = [.. ((object[])scanArray[1]).Cast().Select(gs => new ValkeyKey(gs))]; + string nextCursor = arr[0].ToString() ?? "0"; + ValkeyKey[] keys = [.. ((object[])arr[1]).Select(item => new ValkeyKey(item.ToString()))]; return (nextCursor, keys); }); } diff --git a/sources/Valkey.Glide/abstract_APITypes/ClusterScanCursor.cs b/sources/Valkey.Glide/abstract_APITypes/ClusterScanCursor.cs new file mode 100644 index 00000000..5743fd54 --- /dev/null +++ b/sources/Valkey.Glide/abstract_APITypes/ClusterScanCursor.cs @@ -0,0 +1,42 @@ +// Copyright Valkey GLIDE Project Contributors - SPDX Identifier: Apache-2.0 + +using System.Runtime.InteropServices; + +using Valkey.Glide.Internals; + +namespace Valkey.Glide; + +/// +/// A cursor used to iterate through data returned by cluster scan requests. +/// +/// +public class ClusterScanCursor +{ + private const string InitialCursorId = "0"; + private const string FinishedCursorId = "finished"; + + /// + /// The cursor ID for this scan iteration. + /// + public string CursorId { get; } + + /// + /// Indicates whether this cursor represents the end of the scan. + /// + public bool IsFinished => CursorId == FinishedCursorId; + + /// + /// Creates a cursor to start a new cluster scan. + /// + /// A cursor to start a new cluster scan. + public static ClusterScanCursor InitialCursor() => new(InitialCursorId); + + /// + /// Creates a new cursor with the specified cursor ID. + /// + /// The cursor ID. + internal ClusterScanCursor(string cursorId) + { + CursorId = cursorId; + } +} diff --git a/tests/Valkey.Glide.IntegrationTests/ScanTests.cs b/tests/Valkey.Glide.IntegrationTests/ScanTests.cs new file mode 100644 index 00000000..b69ba0e3 --- /dev/null +++ b/tests/Valkey.Glide.IntegrationTests/ScanTests.cs @@ -0,0 +1,145 @@ +// Copyright Valkey GLIDE Project Contributors - SPDX Identifier: Apache-2.0 + +using Valkey.Glide.Commands.Options; + +namespace Valkey.Glide.IntegrationTests; + +public class ScanTests(TestConfiguration config) +{ + public TestConfiguration Config { get; } = config; + + [Theory(DisableDiscoveryEnumeration = true)] + [MemberData(nameof(Config.TestClients), MemberType = typeof(TestConfiguration))] + public async Task TestScanAsync_PrefixFiltering(BaseClient client) + { + // Add keys. + string prefix = Guid.NewGuid().ToString(); + var key1 = new ValkeyKey($"{prefix}:key1"); + var key2 = new ValkeyKey($"{prefix}:key2"); + + await client.StringSetAsync(key1, "value1"); + await client.StringSetAsync(key2, "value2"); + + // Get all keys with matching prefix. + var options = new ScanOptions { MatchPattern = $"{prefix}:*" }; + var matchingKeys = await ExecuteScanAsync(client, options); + + Assert.Equivalent(new[] { key1, key2 }, matchingKeys); + + // Get all keys with non-existent prefix. + options = new ScanOptions { MatchPattern = $"nonexistent:*" }; + matchingKeys = await ExecuteScanAsync(client, options); + + Assert.Empty(matchingKeys); + + // Remove keys. + await client.KeyDeleteAsync([key1, key2]); + } + + [Theory(DisableDiscoveryEnumeration = true)] + [MemberData(nameof(Config.TestClients), MemberType = typeof(TestConfiguration))] + public async Task TestScanAsync_TypeFiltering(BaseClient client) + { + // Add keys with different types. + string prefix = Guid.NewGuid().ToString(); + var stringKey = new ValkeyKey($"{prefix}:string"); + var listKey = new ValkeyKey($"{prefix}:list"); + var setKey = new ValkeyKey($"{prefix}:set"); + + await client.StringSetAsync(stringKey, "value"); + await client.ListLeftPushAsync(listKey, "item"); + await client.SetAddAsync(setKey, "member"); + + // Get all keys with string type. + var options = new ScanOptions { MatchPattern = $"{prefix}:*", Type = ValkeyType.String }; + var matchingKeys = await ExecuteScanAsync(client, options); + + Assert.Equivalent(new[] { stringKey }, matchingKeys); + + // Get all keys with set type. + options = new ScanOptions { MatchPattern = $"{prefix}:*", Type = ValkeyType.Set }; + matchingKeys = await ExecuteScanAsync(client, options); + + Assert.Equivalent(new[] { setKey }, matchingKeys); + + // Get all keys with non-existent type. + options = new ScanOptions { MatchPattern = $"{prefix}:*", Type = ValkeyType.Hash }; + matchingKeys = await ExecuteScanAsync(client, options); + + Assert.Empty(matchingKeys); + + // Remove keys. + await client.KeyDeleteAsync([stringKey, listKey, setKey]); + } + + [Theory(DisableDiscoveryEnumeration = true)] + [MemberData(nameof(Config.TestClients), MemberType = typeof(TestConfiguration))] + public async Task TestScanAsync_CombinedOptions(BaseClient client) + { + // Add keys with different prefixes and types. + string prefix = Guid.NewGuid().ToString(); + var matchStringKey = new ValkeyKey($"{prefix}:match:string"); + var matchListKey = new ValkeyKey($"{prefix}:match:list"); + var otherStringKey = new ValkeyKey($"{prefix}:other:string"); + + await client.StringSetAsync(matchStringKey, "value"); + await client.ListLeftPushAsync(matchListKey, "item"); + await client.StringSetAsync(otherStringKey, "value"); + + // Get all keys with matching type and prefix. + var options = new ScanOptions + { + MatchPattern = $"{prefix}:match:*", + Count = 10, + Type = ValkeyType.String + }; + var matchingKeys = await ExecuteScanAsync(client, options); + + Assert.Equivalent(new[] { matchStringKey }, matchingKeys); + + // Remove keys. + await client.KeyDeleteAsync([matchStringKey, matchListKey, otherStringKey]); + } + + [Fact] + public async Task TestScanAsync_InvalidCursorId() + { + var standaloneClient = TestConfiguration.DefaultStandaloneClient(); + var exception = await Assert.ThrowsAsync(async () => + { + await standaloneClient.ScanAsync("invalid"); + }); + + var clusterClient = TestConfiguration.DefaultClusterClient(); + exception = await Assert.ThrowsAsync(async () => + { + await clusterClient.ScanAsync(new ClusterScanCursor("invalid")); + }); + } + + private static async Task ExecuteScanAsync(BaseClient client, ScanOptions? options = null) + { + var allKeys = new List(); + + if (client is GlideClient) + { + string cursor = "0"; + do + { + (cursor, var keys) = await ((GlideClient)client).ScanAsync(cursor, options); + allKeys.AddRange(keys); + } while (cursor != "0"); + } + else + { + var cursor = ClusterScanCursor.InitialCursor(); + while (!cursor.IsFinished) + { + (cursor, var keys) = await ((GlideClusterClient)client).ScanAsync(cursor, options); + allKeys.AddRange(keys); + } + } + + return [.. allKeys]; + } +} diff --git a/tests/Valkey.Glide.UnitTests/CommandTests.cs b/tests/Valkey.Glide.UnitTests/CommandTests.cs index 99a7815b..b0c3c3b4 100644 --- a/tests/Valkey.Glide.UnitTests/CommandTests.cs +++ b/tests/Valkey.Glide.UnitTests/CommandTests.cs @@ -188,14 +188,19 @@ public void ValidateCommandArgs() () => Assert.Equal(["MOVE", "key", "1"], Request.KeyMoveAsync("key", 1).GetArgs()), // SCAN Commands - () => Assert.Equal(["SCAN", "0"], Request.ScanAsync(0).GetArgs()), - () => Assert.Equal(["SCAN", "10"], Request.ScanAsync(10).GetArgs()), - () => Assert.Equal(["SCAN", "0", "MATCH", "pattern*"], Request.ScanAsync(0, "pattern*").GetArgs()), - () => Assert.Equal(["SCAN", "5", "MATCH", "test*"], Request.ScanAsync(5, "test*").GetArgs()), - () => Assert.Equal(["SCAN", "0", "COUNT", "10"], Request.ScanAsync(0, pageSize: 10).GetArgs()), - () => Assert.Equal(["SCAN", "5", "COUNT", "20"], Request.ScanAsync(5, pageSize: 20).GetArgs()), - () => Assert.Equal(["SCAN", "0", "MATCH", "pattern*", "COUNT", "10"], Request.ScanAsync(0, "pattern*", 10).GetArgs()), - () => Assert.Equal(["SCAN", "10", "MATCH", "*suffix", "COUNT", "5"], Request.ScanAsync(10, "*suffix", 5).GetArgs()), + () => Assert.Equal(["SCAN", "0"], Request.ScanAsync("0").GetArgs()), + () => Assert.Equal(["SCAN", "10"], Request.ScanAsync("10").GetArgs()), + () => Assert.Equal(["SCAN", "0", "MATCH", "pattern*"], Request.ScanAsync("0", new ScanOptions { MatchPattern = "pattern*" }).GetArgs()), + () => Assert.Equal(["SCAN", "5", "MATCH", "test*"], Request.ScanAsync("5", new ScanOptions { MatchPattern = "test*" }).GetArgs()), + () => Assert.Equal(["SCAN", "0", "COUNT", "10"], Request.ScanAsync("0", new ScanOptions { Count = 10 }).GetArgs()), + () => Assert.Equal(["SCAN", "5", "COUNT", "20"], Request.ScanAsync("5", new ScanOptions { Count = 20 }).GetArgs()), + () => Assert.Equal(["SCAN", "0", "TYPE", "string"], Request.ScanAsync("0", new ScanOptions { Type = ValkeyType.String }).GetArgs()), + () => Assert.Equal(["SCAN", "5", "TYPE", "list"], Request.ScanAsync("5", new ScanOptions { Type = ValkeyType.List }).GetArgs()), + () => Assert.Equal(["SCAN", "0", "TYPE", "set"], Request.ScanAsync("0", new ScanOptions { Type = ValkeyType.Set }).GetArgs()), + () => Assert.Equal(["SCAN", "0", "TYPE", "zset"], Request.ScanAsync("0", new ScanOptions { Type = ValkeyType.SortedSet }).GetArgs()), + () => Assert.Equal(["SCAN", "0", "TYPE", "hash"], Request.ScanAsync("0", new ScanOptions { Type = ValkeyType.Hash }).GetArgs()), + () => Assert.Equal(["SCAN", "0", "TYPE", "stream"], Request.ScanAsync("0", new ScanOptions { Type = ValkeyType.Stream }).GetArgs()), + () => Assert.Equal(["SCAN", "10", "MATCH", "key*", "COUNT", "20", "TYPE", "string"], Request.ScanAsync("10", new ScanOptions { MatchPattern = "key*", Count = 20, Type = ValkeyType.String }).GetArgs()), // WAIT Commands () => Assert.Equal(["WAIT", "1", "1000"], Request.WaitAsync(1, 1000).GetArgs()), @@ -483,20 +488,20 @@ public void ValidateCommandConverters() // SCAN Commands Converters () => { - var result = Request.ScanAsync(0).Converter([0L, new object[] { (gs)"key1", (gs)"key2" }]); - Assert.Equal(0L, result.Item1); - Assert.Equal(["key1", "key2"], result.Item2.Select(k => k.ToString())); + var result = Request.ScanAsync("0").Converter(["0", new object[] { (gs)"key1", (gs)"key2" }]); + Assert.Equal("0", result.Item1); + Assert.Equal([new ValkeyKey("key1"), new ValkeyKey("key2")], result.Item2); }, () => { - var result = Request.ScanAsync(10).Converter([5L, new object[] { (gs)"test" }]); - Assert.Equal(5L, result.Item1); - Assert.Equal(["test"], result.Item2.Select(k => k.ToString())); + var result = Request.ScanAsync("10").Converter(["5", new object[] { (gs)"test" }]); + Assert.Equal("5", result.Item1); + Assert.Equal([new ValkeyKey("test")], result.Item2); }, () => { - var result = Request.ScanAsync(0).Converter([0L, Array.Empty()]); - Assert.Equal(0L, result.Item1); + var result = Request.ScanAsync("0").Converter(["0", Array.Empty()]); + Assert.Equal("0", result.Item1); Assert.Empty(result.Item2); },