diff --git a/.ai/mcp/mcp.json b/.ai/mcp/mcp.json new file mode 100644 index 00000000..e69de29b diff --git a/.gitignore b/.gitignore index cd904db3..272a2eed 100644 --- a/.gitignore +++ b/.gitignore @@ -23,8 +23,12 @@ dist/ .reviews/ .review/ -# Agent memory +# AI agent artifacts +.ai/ .claude/agent-memory/ +# Private keys — never commit signing keys or credentials +*.key + # Svelte/Node frontend **/node_modules/ \ No newline at end of file diff --git a/CLAUDE.md b/CLAUDE.md index 7032e9b8..d71c1b3a 100644 --- a/CLAUDE.md +++ b/CLAUDE.md @@ -82,12 +82,11 @@ Business logic uses vertical slice architecture with `Immediate.Handlers` (sourc ``` Features/ - Chat/ (4 handlers — SendMessage, BuildChatRequest, ProcessToolCalls, CompactSession) - Session/ (5 handlers — Load, Save, Clear, List, Prune) - Cost/ (3 handlers — CheckBudget, RecordUsage, GetSummary) - Memory/ (5 handlers — Store, Search, Recall, ExtractFacts, Decay) - Tools/ (1 handler — ExecuteTool) - Behaviors/ (pipeline behaviors — validation, logging) + Chat/ (4 handlers — ApplySecurityGuards, SanitizeReply, BuildChatRequest, RouteModel) + Session/ (4 handlers — Load, Save, Clear, Prune) + Cost/ (3 handlers — CheckBudget, RecordUsage, GetCostSummary) + Memory/ (5 handlers — WriteMemory, ClearMemory, ExtractFacts, SearchMemory, GetMemoryContext) + Behaviors/ (pipeline behaviors — authorization, logging) ``` Generated registration methods: `AddclawsharpHandlers()` / `AddclawsharpBehaviors()` (lowercase 'c' — uses raw assembly name). Handler lifetime is `ServiceLifetime.Singleton`. diff --git a/compose.yaml b/compose.yaml index 03d21507..4ec9300d 100644 --- a/compose.yaml +++ b/compose.yaml @@ -56,7 +56,16 @@ services: tmpfs: - /tmp # .NET runtime temp files - /var/tmp # some system libs expect this - network_mode: host + ports: + - "127.0.0.1:3001:3001" + extra_hosts: + - "host.docker.internal:host-gateway" + healthcheck: + test: ["CMD-SHELL", "dotnet /app/clawsharp.dll doctor 2>/dev/null || exit 1"] + interval: 30s + timeout: 5s + start_period: 15s + retries: 3 volumes: # Persists config.json, sessions, memory, skills, and the .secret_key file. # The .secret_key file is only used when neither CLAWSHARP_SECRET_KEY env var @@ -80,18 +89,18 @@ services: CLAWSHARP__channels__web__enabled: "true" CLAWSHARP__channels__web__webHost: "0.0.0.0" CLAWSHARP__channels__web__webPort: "3001" - CLAWSHARP__channels__web__pairingToken: "test-token-123" + CLAWSHARP__channels__web__pairingToken: "${CLAWSHARP_WEB_PAIRING_TOKEN:?Set CLAWSHARP_WEB_PAIRING_TOKEN in .env}" # ── PostgreSQL memory backend ────────────────────────────────────────── CLAWSHARP__memory__backend: postgres - CLAWSHARP__memory__connectionString: "Host=127.0.0.1;Database=clawsharp;Username=clawsharp;Password=${POSTGRES_PASSWORD}" + CLAWSHARP__memory__connectionString: "Host=postgres;Database=clawsharp;Username=clawsharp;Password=${POSTGRES_PASSWORD}" # ── Analytics (interactions) → PostgreSQL ────────────────────────────── CLAWSHARP__analytics__enabled: "true" CLAWSHARP__analytics__backend: postgres - # ── LM Studio (host network — use localhost) ─────────────────────────── - CLAWSHARP__providers__lmstudio__baseUrl: "http://127.0.0.1:1234" + # ── LM Studio (via host.docker.internal) ──────────────────────────────── + CLAWSHARP__providers__lmstudio__baseUrl: "http://host.docker.internal:1234" CLAWSHARP__agents__defaults__model: "qwen/qwen3.5-9b" # ── Optional: inline config overrides (no file needed) ────────────────── diff --git a/nuget.config b/nuget.config new file mode 100644 index 00000000..765346e5 --- /dev/null +++ b/nuget.config @@ -0,0 +1,7 @@ + + + + + + + diff --git a/src/clawsharp-sign/Program.cs b/src/clawsharp-sign/Program.cs index 6a9f2aa9..c04703f4 100644 --- a/src/clawsharp-sign/Program.cs +++ b/src/clawsharp-sign/Program.cs @@ -118,14 +118,15 @@ private static int Sign(ReadOnlySpan args) // Derive package name from directory name or primary plugin DLL var package = DerivePackageName(pluginDir, dllFiles); - // Build manifest without signature for signing + var timestamp = DateTimeOffset.UtcNow.ToString("O"); + + // Build manifest without signature for signing (canonical payload — no timestamp) var manifestData = new ManifestData { + Files = new SortedDictionary(files, StringComparer.Ordinal), + KeyId = keyId, Package = package, Version = version, - KeyId = keyId, - Timestamp = DateTimeOffset.UtcNow.ToString("O"), - Files = new SortedDictionary(files, StringComparer.Ordinal), }; // Canonical JSON: sorted keys, no whitespace @@ -141,7 +142,7 @@ private static int Sign(ReadOnlySpan args) Package = manifestData.Package, Version = manifestData.Version, KeyId = manifestData.KeyId, - Timestamp = manifestData.Timestamp, + Timestamp = timestamp, Files = manifestData.Files, Signature = signatureBase64, }; @@ -194,13 +195,13 @@ public static int Verify(ReadOnlySpan args) } // Step 1: Verify signature over canonical manifest payload (D-30: signature first) + // Canonical payload excludes timestamp — must match signer's ManifestData shape var manifestData = new ManifestData { + Files = signedManifest.Files, + KeyId = signedManifest.KeyId, Package = signedManifest.Package, Version = signedManifest.Version, - KeyId = signedManifest.KeyId, - Timestamp = signedManifest.Timestamp, - Files = signedManifest.Files, }; var canonicalBytes = JsonSerializer.SerializeToUtf8Bytes(manifestData, ManifestJsonContext.Default.ManifestData); @@ -307,23 +308,26 @@ Verify a signed plugin directory against a public key. // ── JSON DTOs ─────────────────────────────────────────────────────────── -/// Manifest data without signature — the canonical payload that gets signed. +/// +/// Manifest data without signature — the canonical payload that gets signed. +/// Properties are in alphabetical order by JSON key to match the verifier's +/// SortedDictionary-based canonical payload (STJ source-gen serializes +/// in declaration order). Timestamp is excluded from the signed payload — +/// it is metadata in the full only. +/// internal sealed class ManifestData { - [JsonPropertyName("package")] - public string Package { get; init; } = ""; - - [JsonPropertyName("version")] - public string Version { get; init; } = ""; + [JsonPropertyName("files")] + public SortedDictionary Files { get; init; } = new(StringComparer.Ordinal); [JsonPropertyName("keyId")] public string KeyId { get; init; } = ""; - [JsonPropertyName("timestamp")] - public string Timestamp { get; init; } = ""; + [JsonPropertyName("package")] + public string Package { get; init; } = ""; - [JsonPropertyName("files")] - public SortedDictionary Files { get; init; } = new(StringComparer.Ordinal); + [JsonPropertyName("version")] + public string Version { get; init; } = ""; } /// Full manifest with signature — written to plugin.manifest.json. diff --git a/src/clawsharp-web/src/lib/markdown.ts b/src/clawsharp-web/src/lib/markdown.ts index 64df431c..3ea4804a 100644 --- a/src/clawsharp-web/src/lib/markdown.ts +++ b/src/clawsharp-web/src/lib/markdown.ts @@ -12,6 +12,15 @@ function escapeHtml(s: string): string { .replace(/>/g, '>'); } +function isSafeUrl(url: string): boolean { + try { + const parsed = new URL(url, window.location.origin); + return ['http:', 'https:', 'mailto:'].includes(parsed.protocol); + } catch { + return false; + } +} + function inlineMarkdown(s: string): string { // Inline code — must come before bold/italic s = s.replace(/`([^`]+)`/g, '$1'); @@ -19,10 +28,13 @@ function inlineMarkdown(s: string): string { s = s.replace(/\*\*(.+?)\*\*/g, '$1'); // Italic s = s.replace(/\*(.+?)\*/g, '$1'); - // Links + // Links — only allow safe protocols (http, https, mailto) s = s.replace( /\[([^\]]+)\]\(([^)]+)\)/g, - '$1', + (_, text, url) => + isSafeUrl(url) + ? `${text}` + : text, ); return s; } diff --git a/src/clawsharp.Plugin.Confluence/ConfluenceApiClient.cs b/src/clawsharp.Plugin.Confluence/ConfluenceApiClient.cs index e81d1092..427d3ab0 100644 --- a/src/clawsharp.Plugin.Confluence/ConfluenceApiClient.cs +++ b/src/clawsharp.Plugin.Confluence/ConfluenceApiClient.cs @@ -7,7 +7,7 @@ namespace Clawsharp.Plugin.Confluence; /// /// HTTP client for the Confluence REST API v2 with cursor-based pagination per D-10. /// Uses a named injected via DI with SsrfGuard-protected -/// per D-26. +/// per D-26. /// internal sealed class ConfluenceApiClient { diff --git a/src/clawsharp.Plugin.Gcs/GcsPlugin.cs b/src/clawsharp.Plugin.Gcs/GcsPlugin.cs index 74351ee0..5168c8da 100644 --- a/src/clawsharp.Plugin.Gcs/GcsPlugin.cs +++ b/src/clawsharp.Plugin.Gcs/GcsPlugin.cs @@ -68,6 +68,6 @@ private static bool IsPrivateIpLikeName(string name) if (!System.Net.IPAddress.TryParse(name, out var ip)) return false; - return Clawsharp.Security.SsrfGuard.IsPrivateOrReservedAddress(ip); + return Security.SsrfGuard.IsPrivateOrReservedAddress(ip); } } diff --git a/src/clawsharp/A2a/A2aAgentCardBuilder.cs b/src/clawsharp/A2a/A2aAgentCardBuilder.cs index c2c5bd29..b9fd4570 100644 --- a/src/clawsharp/A2a/A2aAgentCardBuilder.cs +++ b/src/clawsharp/A2a/A2aAgentCardBuilder.cs @@ -8,8 +8,8 @@ namespace Clawsharp.A2a; /// /// Builds the Agent Card for A2A discovery (/.well-known/agent-card.json). /// Skills are derived 1:1 from the tool registry, filtered to Low/Medium sensitivity (D-10). -/// Capabilities reflect runtime config (D-12). Metadata follows config override chains (D-13). -/// Card is built once at startup and cached — tool registry is immutable at runtime (D-11). +/// Capabilities reflect runtime config (D-12). Name falls back to "ClawSharp Agent" when +/// a2a.agentCard.name is null. Card is built once at startup and cached (D-11). /// public sealed class A2aAgentCardBuilder( IToolRegistry toolRegistry, diff --git a/src/clawsharp/A2a/A2aAttributes.cs b/src/clawsharp/A2a/A2aAttributes.cs index e4945abd..86578de5 100644 --- a/src/clawsharp/A2a/A2aAttributes.cs +++ b/src/clawsharp/A2a/A2aAttributes.cs @@ -43,4 +43,18 @@ internal static class A2aAttributes /// Unique chain identifier correlating delegation hops across instances. internal const string DelegationChainId = "a2a.delegation.chain_id"; + + // ── Cooperative delegation metadata keys (propagated in A2A task metadata) ── + + /// Metadata key: current delegation depth (incremented per hop). + internal const string MetaDepth = "clawsharp.delegation.depth"; + + /// Metadata key: maximum allowed delegation depth. + internal const string MetaMaxDepth = "clawsharp.delegation.maxDepth"; + + /// Metadata key: machine name of the originating instance. + internal const string MetaOriginInstance = "clawsharp.delegation.originInstance"; + + /// Metadata key: unique chain identifier for correlating delegation hops. + internal const string MetaChainId = "clawsharp.delegation.chainId"; } diff --git a/src/clawsharp/A2a/A2aClientConfig.cs b/src/clawsharp/A2a/A2aClientConfig.cs index 9ce22fea..37215a65 100644 --- a/src/clawsharp/A2a/A2aClientConfig.cs +++ b/src/clawsharp/A2a/A2aClientConfig.cs @@ -4,7 +4,7 @@ namespace Clawsharp.A2a; /// Client-side A2A delegation config. Added to as nullable Client property. /// Null = no delegation capability (zero tools registered). /// -public sealed record A2aClientConfig +public sealed class A2aClientConfig { /// Max delegation chain depth. Default: 3. /// Uses set (not init) so STJ source-gen preserves defaults on deserialization. @@ -19,7 +19,7 @@ public sealed record A2aClientConfig } /// Configuration for a single trusted external A2A agent. -public sealed record TrustedAgentConfig +public sealed class TrustedAgentConfig { /// Base URL of the external agent's A2A endpoint. public required string Url { get; init; } @@ -32,7 +32,7 @@ public sealed record TrustedAgentConfig } /// Authentication credentials for a trusted agent. Supports bearer token and API key. -public sealed record AgentAuthConfig +public sealed class AgentAuthConfig { /// Auth type: "bearer" or "apiKey". public required string Type { get; init; } diff --git a/src/clawsharp/A2a/A2aClientService.cs b/src/clawsharp/A2a/A2aClientService.cs index 81fc8ac3..3975655e 100644 --- a/src/clawsharp/A2a/A2aClientService.cs +++ b/src/clawsharp/A2a/A2aClientService.cs @@ -1,3 +1,4 @@ +using System.Collections.Concurrent; using System.Collections.Frozen; using System.Net.Http.Headers; using System.Text; @@ -62,26 +63,27 @@ public async Task InitializeAsync(CancellationToken ct = default) return; } - var clients = new Dictionary(AgentRegistry.Count, StringComparer.Ordinal); - var cards = new Dictionary(AgentRegistry.Count, StringComparer.Ordinal); + var clients = new ConcurrentDictionary(StringComparer.Ordinal); + var cards = new ConcurrentDictionary(StringComparer.Ordinal); - foreach (var (name, agentConfig) in AgentRegistry) + await Parallel.ForEachAsync(AgentRegistry, ct, async (kvp, token) => { + var (name, agentConfig) = kvp; try { var uri = new Uri(agentConfig.Url); // D-03: Validate URL via SsrfGuard at startup - var ssrfResult = await SsrfGuard.CheckAsync(uri, ct).ConfigureAwait(false); + var ssrfResult = await SsrfGuard.CheckAsync(uri, token).ConfigureAwait(false); if (ssrfResult is not null) { LogAgentUrlBlocked(_logger, name, agentConfig.Url, ssrfResult); - continue; + return; } // Create HttpClient with auth headers pre-configured var httpClient = _httpFactory.CreateClient("a2a-client"); - ConfigureAuth(httpClient, agentConfig.Auth); + ConfigureAuth(httpClient, agentConfig.Auth, name, _logger); // Create A2AClient per agent (D-16) var client = new A2AClient(uri, httpClient); @@ -92,7 +94,7 @@ public async Task InitializeAsync(CancellationToken ct = default) try { var resolver = new A2ACardResolver(uri, httpClient, "/.well-known/agent-card.json", null!); - card = await resolver.GetAgentCardAsync(ct).ConfigureAwait(false); + card = await resolver.GetAgentCardAsync(token).ConfigureAwait(false); LogAgentCardFetched(_logger, name, card.Name ?? name); } catch (Exception ex) @@ -106,19 +108,20 @@ public async Task InitializeAsync(CancellationToken ct = default) { LogAgentInitFailed(_logger, name, ex); } - } + }).ConfigureAwait(false); _clients = clients.ToFrozenDictionary(StringComparer.Ordinal); _agentCards = cards.ToFrozenDictionary(StringComparer.Ordinal); } /// - /// Delegates a task to an external A2A agent. Returns the text result as a string. - /// Never throws — errors are returned as descriptive strings (D-19). + /// Delegates a task to an external A2A agent. Returns (Text, IsError) so callers + /// can reliably classify outcomes. Never throws — errors are returned as descriptive + /// tuples with IsError = true (D-19). /// Uses streaming by default (D-16), falls back to sync+poll when agent card /// capabilities.streaming is false (D-17). /// - public async Task DelegateAsync( + public async Task<(string Text, bool IsError)> DelegateAsync( string agentName, string taskText, int? timeoutSeconds = null, @@ -130,7 +133,7 @@ public async Task DelegateAsync( var available = _clients.Count > 0 ? string.Join(", ", _clients.Keys) : "(none)"; - return $"Unknown agent '{agentName}'. Available: {available}"; + return ($"Unknown agent '{agentName}'. Available: {available}", true); } try @@ -154,21 +157,25 @@ public async Task DelegateAsync( var supportsStreaming = _agentCards.TryGetValue(agentName, out var card) && card?.Capabilities?.Streaming == true; - return supportsStreaming + var text = supportsStreaming ? await DelegateStreamingAsync(client, agentName, request, timeoutCts.Token).ConfigureAwait(false) : await DelegateSyncAsync(client, agentName, request, timeoutCts.Token).ConfigureAwait(false); + + return (text, false); } catch (OperationCanceledException) { - return $"Delegation to '{agentName}' failed: operation timed out or was cancelled."; + return ($"Delegation to '{agentName}' failed: operation timed out or was cancelled.", true); } catch (HttpRequestException ex) { - return $"Delegation to '{agentName}' failed: {ex.Message}"; + _logger.LogWarning(ex, "A2A delegation to '{AgentName}' failed", agentName); + return ($"Delegation to '{agentName}' failed: the remote agent is unavailable.", true); } catch (Exception ex) { - return $"Delegation to '{agentName}' failed: {ex.Message}"; + _logger.LogWarning(ex, "A2A delegation to '{AgentName}' failed unexpectedly", agentName); + return ($"Delegation to '{agentName}' failed: an unexpected error occurred.", true); } } @@ -303,8 +310,9 @@ public static string ExtractTextFromTask(AgentTask task) /// /// Configures authentication headers on an HttpClient based on . + /// Logs a warning for unrecognized auth types. /// - private static void ConfigureAuth(HttpClient httpClient, AgentAuthConfig auth) + private static void ConfigureAuth(HttpClient httpClient, AgentAuthConfig auth, string agentName, ILogger logger) { switch (auth.Type.ToUpperInvariant()) { @@ -321,6 +329,9 @@ private static void ConfigureAuth(HttpClient httpClient, AgentAuthConfig auth) httpClient.DefaultRequestHeaders.Add("X-API-Key", auth.Key); } break; + default: + LogUnrecognizedAuthType(logger, auth.Type, agentName); + break; } } @@ -340,4 +351,7 @@ private static void ConfigureAuth(HttpClient httpClient, AgentAuthConfig auth) [LoggerMessage(Level = LogLevel.Debug, Message = "A2A delegation to '{AgentName}' reached state: {State}")] private static partial void LogDelegationStateUpdate(ILogger logger, string agentName, string state); + + [LoggerMessage(Level = LogLevel.Warning, Message = "Unrecognized auth type '{AuthType}' for agent '{AgentName}'")] + private static partial void LogUnrecognizedAuthType(ILogger logger, string authType, string agentName); } diff --git a/src/clawsharp/A2a/A2aConfig.cs b/src/clawsharp/A2a/A2aConfig.cs index 8bb4d178..d5d07a2f 100644 --- a/src/clawsharp/A2a/A2aConfig.cs +++ b/src/clawsharp/A2a/A2aConfig.cs @@ -4,7 +4,7 @@ namespace Clawsharp.A2a; /// A2A Protocol configuration. Null on AppConfig = disabled (zero overhead). /// Minimum config: { "a2a": { "enabled": true } } /// -public sealed record A2aConfig +public sealed class A2aConfig { /// Whether A2A protocol endpoints are active. public bool Enabled { get; init; } @@ -20,7 +20,7 @@ public sealed record A2aConfig } /// Server-side A2A task processing configuration. -public sealed record A2aServerConfig +public sealed class A2aServerConfig { /// Minutes before completed/failed tasks are evicted. Default: 60. /// Uses set (not init) so STJ source-gen preserves defaults on deserialization. @@ -36,7 +36,7 @@ public sealed record A2aServerConfig } /// Agent Card metadata overrides for discovery. -public sealed record AgentCardConfig +public sealed class AgentCardConfig { /// Override agent name. Null = BotName from agent config, then "ClawSharp Agent". public string? Name { get; init; } @@ -49,7 +49,7 @@ public sealed record AgentCardConfig } /// Agent Card provider metadata overrides. -public sealed record AgentProviderConfig +public sealed class AgentProviderConfig { /// Organization name. Null = Organization.Name from config, then "ClawSharp". public string? Organization { get; init; } diff --git a/src/clawsharp/A2a/A2aDelegateTool.cs b/src/clawsharp/A2a/A2aDelegateTool.cs index ac60755f..cd8035e9 100644 --- a/src/clawsharp/A2a/A2aDelegateTool.cs +++ b/src/clawsharp/A2a/A2aDelegateTool.cs @@ -83,25 +83,24 @@ public override async Task ExecuteAsync(JsonElement arguments, Cancellat activity?.SetTag(A2aAttributes.Direction, "outbound"); activity?.SetTag(A2aAttributes.TargetAgent, agentName); activity?.SetTag(A2aAttributes.DelegationDepth, currentDepth); - if (metadata.TryGetValue("clawsharp.delegation.chainId", out var chainElement)) + if (metadata.TryGetValue(A2aAttributes.MetaChainId, out var chainElement)) activity?.SetTag(A2aAttributes.DelegationChainId, chainElement.GetString()); string result; var outcome = "failed"; try { - result = await _clientService.DelegateAsync(agentName, taskText, timeout, metadata, ct) + // DelegateAsync never throws — errors are returned via IsError (D-19). + var (text, isError) = await _clientService.DelegateAsync(agentName, taskText, timeout, metadata, ct) .ConfigureAwait(false); - outcome = result.StartsWith("Error", StringComparison.Ordinal) ? "failed" : "completed"; - } - catch - { - outcome = "failed"; - throw; + outcome = isError ? "failed" : "completed"; + result = text; } finally { activity?.SetTag(A2aAttributes.Outcome, outcome); + if (outcome == "failed") + activity?.SetStatus(ActivityStatusCode.Error, "A2A delegation failed"); var elapsed = Stopwatch.GetElapsedTime(startTimestamp); _metrics.RecordTaskDuration(elapsed.TotalSeconds, "outbound"); if (outcome == "completed") @@ -138,10 +137,10 @@ internal static Dictionary BuildDelegationMetadata(int curr { return new Dictionary { - ["clawsharp.delegation.depth"] = JsonSerializer.SerializeToElement(currentDepth + 1), - ["clawsharp.delegation.maxDepth"] = JsonSerializer.SerializeToElement(depthLimit), - ["clawsharp.delegation.originInstance"] = JsonSerializer.SerializeToElement(Environment.MachineName), - ["clawsharp.delegation.chainId"] = JsonSerializer.SerializeToElement( + [A2aAttributes.MetaDepth] = JsonSerializer.SerializeToElement(currentDepth + 1), + [A2aAttributes.MetaMaxDepth] = JsonSerializer.SerializeToElement(depthLimit), + [A2aAttributes.MetaOriginInstance] = JsonSerializer.SerializeToElement(Environment.MachineName), + [A2aAttributes.MetaChainId] = JsonSerializer.SerializeToElement( Guid.CreateVersion7().ToString("N")), }; } diff --git a/src/clawsharp/A2a/A2aRouteRegistrar.cs b/src/clawsharp/A2a/A2aRouteRegistrar.cs index d6c7819e..f391c347 100644 --- a/src/clawsharp/A2a/A2aRouteRegistrar.cs +++ b/src/clawsharp/A2a/A2aRouteRegistrar.cs @@ -43,7 +43,8 @@ public void ConfigureServices(WebApplicationBuilder builder) sp.GetRequiredService>(), sp.GetRequiredService(), sp.GetRequiredService(), - sp.GetRequiredService())); + sp.GetRequiredService(), + sp.GetService())); // SDK registration -- ITaskStore + IA2ARequestHandler already registered, TryAddSingleton is a no-op builder.Services.AddA2AAgent(_agentCard); diff --git a/src/clawsharp/A2a/A2aServerWithPush.cs b/src/clawsharp/A2a/A2aServerWithPush.cs index 4d38f666..39f634f8 100644 --- a/src/clawsharp/A2a/A2aServerWithPush.cs +++ b/src/clawsharp/A2a/A2aServerWithPush.cs @@ -2,8 +2,11 @@ using System.Text.Json; using A2A; using Clawsharp.Config.Features; +using Clawsharp.Core.Security; +using Clawsharp.McpServer; using Clawsharp.Security; using Clawsharp.Webhooks; +using Microsoft.AspNetCore.Http; using Microsoft.Extensions.Logging; namespace Clawsharp.A2a; @@ -28,8 +31,10 @@ public sealed partial class A2aServerWithPush : A2AServer /// private readonly ConcurrentDictionary> _pushConfigs = new(StringComparer.Ordinal); + private readonly A2aTaskStore _taskStore; private readonly WebhookQueueRegistry _queueRegistry; private readonly DeliveryStorage _deliveryStorage; + private readonly IHttpContextAccessor? _httpContextAccessor; private readonly ILogger _logger; public A2aServerWithPush( @@ -39,11 +44,14 @@ public A2aServerWithPush( ILogger logger, A2AServerOptions options, WebhookQueueRegistry queueRegistry, - DeliveryStorage deliveryStorage) + DeliveryStorage deliveryStorage, + IHttpContextAccessor? httpContextAccessor = null) : base(handler, taskStore, notifier, logger, options) { + _taskStore = taskStore; _queueRegistry = queueRegistry; _deliveryStorage = deliveryStorage; + _httpContextAccessor = httpContextAccessor; _logger = logger; // Wire up the push delivery trigger via task store callback @@ -55,10 +63,13 @@ public A2aServerWithPush( /// /// Creates a push notification config for a task. Validates the callback URL /// against before storing (PUSH-04). + /// M-02: Verifies task ownership before allowing push config creation. /// public override async Task CreateTaskPushNotificationConfigAsync( CreateTaskPushNotificationConfigRequest request, CancellationToken cancellationToken) { + VerifyTaskOwnership(request.TaskId); + var url = request.Config?.Url; if (string.IsNullOrEmpty(url)) throw new A2AException("Push notification config must include a URL.", A2AErrorCode.InvalidParams); @@ -84,31 +95,28 @@ public override async Task CreateTaskPushNotificatio PushNotificationConfig = request.Config, }; - _pushConfigs.AddOrUpdate( - request.TaskId, - _ => [config], - (_, existing) => - { - lock (existing) - { - existing.Add(config); - } - return existing; - }); + var list = _pushConfigs.GetOrAdd(request.TaskId, _ => []); + lock (list) + { + list.Add(config); + } // Ensure a dynamic queue exists for this task's push notifications _queueRegistry.TryCreateQueue($"a2a-push:{request.TaskId}"); - LogPushConfigCreated(_logger, request.TaskId, configId, url); + LogPushConfigCreated(_logger, request.TaskId, configId, RedactUrl(url)); return config; } /// /// Retrieves a specific push notification config by task ID and config ID. + /// M-02: Verifies task ownership before returning push config. /// public override Task GetTaskPushNotificationConfigAsync( GetTaskPushNotificationConfigRequest request, CancellationToken cancellationToken) { + VerifyTaskOwnership(request.TaskId); + if (!_pushConfigs.TryGetValue(request.TaskId, out var configs)) throw new A2AException($"No push configs found for task '{request.TaskId}'.", A2AErrorCode.TaskNotFound); @@ -128,10 +136,13 @@ public override Task GetTaskPushNotificationConfigAs /// /// Lists all push notification configs for a task. + /// M-02: Verifies task ownership before listing push configs. /// public override Task ListTaskPushNotificationConfigAsync( ListTaskPushNotificationConfigRequest request, CancellationToken cancellationToken) { + VerifyTaskOwnership(request.TaskId); + List snapshot; if (_pushConfigs.TryGetValue(request.TaskId, out var configs)) @@ -154,26 +165,22 @@ public override Task ListTaskPushNotific } /// - /// Deletes a push notification config. If no configs remain for the task, - /// removes the dynamic queue to free resources. + /// Deletes a push notification config. Empty lists are left in the dictionary + /// rather than eagerly removed — CleanupTask handles full eviction when the task + /// is evicted, avoiding a TOCTOU race with concurrent Create calls. + /// M-02: Verifies task ownership before allowing deletion. /// public override Task DeleteTaskPushNotificationConfigAsync( DeleteTaskPushNotificationConfigRequest request, CancellationToken cancellationToken) { + VerifyTaskOwnership(request.TaskId); + if (!_pushConfigs.TryGetValue(request.TaskId, out var configs)) throw new A2AException($"No push configs found for task '{request.TaskId}'.", A2AErrorCode.TaskNotFound); - bool removedLast; lock (configs) { configs.RemoveAll(c => string.Equals(c.Id, request.Id, StringComparison.Ordinal)); - removedLast = configs.Count == 0; - } - - if (removedLast) - { - _pushConfigs.TryRemove(request.TaskId, out _); - _queueRegistry.RemoveQueue($"a2a-push:{request.TaskId}"); } LogPushConfigDeleted(_logger, request.TaskId, request.Id); @@ -214,6 +221,18 @@ internal async Task OnTaskStateChangedAsync(string taskId, AgentTask task, Cance if (string.IsNullOrEmpty(pushUrl)) continue; + // Re-validate SSRF at delivery time to close the TOCTOU window between + // registration and delivery (MED-04: DNS could change between these events). + if (Uri.TryCreate(pushUrl, UriKind.Absolute, out var pushUri)) + { + var ssrfError = await SsrfGuard.CheckAsync(pushUri, cancellationToken).ConfigureAwait(false); + if (ssrfError is not null) + { + LogPushUrlRejected(_logger, taskId, pushUrl, ssrfError); + continue; + } + } + var record = new WebhookDeliveryRecord { Id = WebhookSigner.NewEventId(), @@ -247,6 +266,38 @@ internal async Task OnTaskStateChangedAsync(string taskId, AgentTask task, Cance } } + // ── Ownership verification (M-02) ────────────────────────────────────────── + + /// + /// Extracts the caller identity from the current HTTP context and verifies + /// that the specified task belongs to the caller. Throws + /// with if ownership check fails. + /// Uses TaskNotFound (not Unauthorized) to avoid leaking task existence to non-owners. + /// + private void VerifyTaskOwnership(string taskId) + { + var callerId = GetCallerOwnerId(); + if (!_taskStore.IsTaskOwnedBy(taskId, callerId)) + { + LogPushOwnershipDenied(_logger, taskId, callerId ?? "(unknown)"); + throw new A2AException( + $"Task '{taskId}' not found.", + A2AErrorCode.TaskNotFound); + } + } + + /// + /// Extracts the authenticated caller's owner ID from the current HTTP context. + /// Returns the KeyId or User.Name from , or null + /// when no HTTP context is available. + /// + private string? GetCallerOwnerId() + { + var authResult = _httpContextAccessor?.HttpContext?.Items[BearerTokenAuthFilter.AuthResultKey] + as McpServerAuthResult; + return authResult?.KeyId ?? authResult?.User?.Name; + } + // ── Cleanup ─────────────────────────────────────────────────────────────── /// @@ -260,6 +311,17 @@ public void CleanupTask(string taskId) _queueRegistry.RemoveQueue($"a2a-push:{taskId}"); } + /// + /// Strips query string and fragment from a URL to avoid logging auth tokens + /// that may be embedded in push notification callback URLs. + /// + private static string RedactUrl(string url) + { + if (Uri.TryCreate(url, UriKind.Absolute, out var uri)) + return uri.GetLeftPart(UriPartial.Path); + return "(invalid url)"; + } + // ── Source-generated log methods ────────────────────────────────────────── [LoggerMessage(EventId = 1, Level = LogLevel.Information, @@ -277,4 +339,8 @@ public void CleanupTask(string taskId) [LoggerMessage(EventId = 4, Level = LogLevel.Warning, Message = "Push URL rejected for task '{TaskId}': url={Url}, reason={Reason}")] private static partial void LogPushUrlRejected(ILogger logger, string taskId, string url, string reason); + + [LoggerMessage(EventId = 5, Level = LogLevel.Warning, + Message = "Push config ownership denied for task '{TaskId}': caller={CallerId}")] + private static partial void LogPushOwnershipDenied(ILogger logger, string taskId, string callerId); } diff --git a/src/clawsharp/A2a/A2aTaskEvictionService.cs b/src/clawsharp/A2a/A2aTaskEvictionService.cs index 7edc0f52..54cabf0e 100644 --- a/src/clawsharp/A2a/A2aTaskEvictionService.cs +++ b/src/clawsharp/A2a/A2aTaskEvictionService.cs @@ -14,6 +14,7 @@ namespace Clawsharp.A2a; public sealed partial class A2aTaskEvictionService : BackgroundService { private readonly A2aTaskStore _store; + private readonly A2aServerWithPush? _pushServer; private readonly TimeSpan _ttl; private readonly int _maxHistory; private readonly ILogger _logger; @@ -21,9 +22,11 @@ public sealed partial class A2aTaskEvictionService : BackgroundService public A2aTaskEvictionService( A2aTaskStore store, A2aServerConfig? serverConfig, - ILogger logger) + ILogger logger, + A2aServerWithPush? pushServer = null) { _store = store; + _pushServer = pushServer; _ttl = TimeSpan.FromMinutes(serverConfig?.TaskTtlMinutes ?? 60); _maxHistory = serverConfig?.MaxTaskHistory ?? 1000; _logger = logger; @@ -69,6 +72,7 @@ internal async Task EvictAsync(CancellationToken ct = default) if (now - taskTimestamp >= _ttl) { await _store.DeleteTaskAsync(taskId, ct).ConfigureAwait(false); + _pushServer?.CleanupTask(taskId); evictedCount++; } } @@ -90,6 +94,7 @@ internal async Task EvictAsync(CancellationToken ct = default) foreach (var (taskId, _) in evictionCandidates) { await _store.DeleteTaskAsync(taskId, ct).ConfigureAwait(false); + _pushServer?.CleanupTask(taskId); evictedCount++; } } diff --git a/src/clawsharp/A2a/A2aTaskProcessor.cs b/src/clawsharp/A2a/A2aTaskProcessor.cs index 32327ed6..5ae744fc 100644 --- a/src/clawsharp/A2a/A2aTaskProcessor.cs +++ b/src/clawsharp/A2a/A2aTaskProcessor.cs @@ -6,6 +6,7 @@ using Clawsharp.Core; using Clawsharp.Core.Security; using Clawsharp.Core.Sessions; +using Clawsharp.Security; using Clawsharp.Core.Utilities; using Clawsharp.Cost; using Clawsharp.McpServer; @@ -149,7 +150,7 @@ await updater.StartWorkAsync( // ── D-14: Cooperative delegation depth from upstream ClawSharp ── var inboundDepth = 0; - if (context.Metadata?.TryGetValue("clawsharp.delegation.depth", out var depthElement) == true + if (context.Metadata?.TryGetValue(A2aAttributes.MetaDepth, out var depthElement) == true && depthElement.ValueKind == JsonValueKind.Number) { inboundDepth = depthElement.GetInt32(); @@ -245,11 +246,28 @@ await updater.RequireInputAsync( } } - // ── Final artifact for sync callers (D-01) ────────────────── - if (!context.StreamingResponse) + // ── H-02: Scan accumulated text for credential leaks ─────── + var scanResult = LeakDetector.Scan(fullText.ToString()); + var safeText = scanResult.Redacted; + if (!scanResult.IsClean) + { + LogLeakDetected(logger, context.TaskId, scanResult.Patterns.Count); + } + + // ── Final artifact ────────────────────────────────────────── + if (context.StreamingResponse) + { + // Close the artifact stream with lastChunk=true per SDK contract. + await updater.AddArtifactAsync( + [Part.FromText("")], + append: true, + lastChunk: true, + cancellationToken: linked.Token).ConfigureAwait(false); + } + else { await updater.AddArtifactAsync( - [Part.FromText(fullText.ToString())], + [Part.FromText(safeText)], cancellationToken: linked.Token).ConfigureAwait(false); } @@ -258,7 +276,7 @@ await updater.CompleteAsync( new Message { Role = Role.Agent, - Parts = [Part.FromText(fullText.ToString())], + Parts = [Part.FromText(safeText)], }, linked.Token).ConfigureAwait(false); @@ -267,7 +285,7 @@ await updater.CompleteAsync( // reverting a completed task if cancellation fires during bookkeeping if (!context.IsContinuation) session.Messages.Add(new ChatMessage(MessageRole.User, userPrompt)); - session.Messages.Add(new ChatMessage(MessageRole.Assistant, fullText.ToString())); + session.Messages.Add(new ChatMessage(MessageRole.Assistant, safeText)); await sessionStore.SaveAsync(session, CancellationToken.None).ConfigureAwait(false); // ── Record cost (D-11) ────────────────────────────────────── @@ -313,6 +331,8 @@ await updater.FailAsync( { // ── OTel: finalize span + record metrics ──────────────────── activity?.SetTag(A2aAttributes.Outcome, outcome); + if (outcome is "failed" or "canceled") + activity?.SetStatus(ActivityStatusCode.Error, $"A2A task {outcome}"); var elapsed = Stopwatch.GetElapsedTime(startTimestamp); metrics.RecordTaskDuration(elapsed.TotalSeconds, "inbound"); if (outcome == "completed") @@ -440,4 +460,8 @@ private static string MapPipelineError(Exception ex) [LoggerMessage(EventId = 7, Level = LogLevel.Information, Message = "A2A task {TaskId} requires input")] private static partial void LogTaskInputRequired(ILogger logger, string taskId); + + [LoggerMessage(EventId = 8, Level = LogLevel.Warning, + Message = "A2A task {TaskId} output redacted: {PatternCount} leak pattern(s) detected")] + private static partial void LogLeakDetected(ILogger logger, string taskId, int patternCount); } diff --git a/src/clawsharp/A2a/A2aTaskRecord.cs b/src/clawsharp/A2a/A2aTaskRecord.cs index 991af190..a7a0ee97 100644 --- a/src/clawsharp/A2a/A2aTaskRecord.cs +++ b/src/clawsharp/A2a/A2aTaskRecord.cs @@ -14,6 +14,13 @@ public sealed record A2aTaskRecord public required DateTimeOffset UpdatedAt { get; init; } public string? OrgUserId { get; init; } + /// + /// Identity of the authenticated client that created/owns this task. + /// Used for IDOR protection — GetTaskAsync/ListTasksAsync filter by this field. + /// Null for tasks created before ownership tracking was added (backward compat). + /// + public string? OwnerId { get; init; } + /// /// Opaque SDK-serialized AgentTask JSON. Deserialized via A2AJsonUtilities.DefaultOptions, /// NOT via A2aJsonContext. The SDK owns its own serialization. diff --git a/src/clawsharp/A2a/A2aTaskStore.cs b/src/clawsharp/A2a/A2aTaskStore.cs index d566b045..070a783c 100644 --- a/src/clawsharp/A2a/A2aTaskStore.cs +++ b/src/clawsharp/A2a/A2aTaskStore.cs @@ -2,6 +2,9 @@ using System.Text.Json; using A2A; using Clawsharp.Config; +using Clawsharp.Core.Security; +using Clawsharp.McpServer; +using Microsoft.AspNetCore.Http; using Microsoft.Extensions.Logging; namespace Clawsharp.A2a; @@ -13,12 +16,21 @@ namespace Clawsharp.A2a; /// On construction, the file is loaded with last-write-wins deduplication. /// Pattern mirrors DeliveryStorage from the webhook subsystem. /// +/// +/// Owner-based access control (M-01): each task records an OwnerId at creation time. +/// and extract the current caller's +/// identity from and filter results to owned tasks only. +/// Tasks created before ownership tracking (OwnerId is null) are visible to all callers +/// for backward compatibility. +/// public sealed partial class A2aTaskStore : ITaskStore { private readonly ConcurrentDictionary _tasks = new(StringComparer.Ordinal); + private readonly ConcurrentDictionary _taskOwners = new(StringComparer.Ordinal); private readonly SemaphoreSlim _writeLock = new(1, 1); private readonly string _filePath; private readonly ILogger _logger; + private readonly IHttpContextAccessor? _httpContextAccessor; /// /// Optional callback invoked after a task is saved. Used by @@ -40,17 +52,18 @@ internal Func? OnTaskSaved /// /// Production constructor. Stores tasks at ~/.clawsharp/a2a/tasks.jsonl. /// - public A2aTaskStore(ILogger logger, A2aServerConfig? serverConfig = null) - : this(ConfigLoader.ExpandHome("~/.clawsharp/a2a"), logger) + public A2aTaskStore(ILogger logger, IHttpContextAccessor? httpContextAccessor = null) + : this(ConfigLoader.ExpandHome("~/.clawsharp/a2a"), logger, httpContextAccessor) { } /// /// Internal constructor for tests. Accepts a custom directory path. /// - internal A2aTaskStore(string directory, ILogger logger) + internal A2aTaskStore(string directory, ILogger logger, IHttpContextAccessor? httpContextAccessor = null) { _logger = logger; + _httpContextAccessor = httpContextAccessor; Directory.CreateDirectory(directory); _filePath = Path.Combine(directory, "tasks.jsonl"); LoadFromDisk(); @@ -63,15 +76,43 @@ internal IReadOnlyCollection> GetAllTasks() => _tasks.ToArray(); /// + /// + /// M-01 IDOR protection: extracts the caller's identity from + /// and returns null if the task exists but belongs to a different owner. + /// Tasks with null OwnerId (pre-ownership) are visible to all authenticated callers. + /// public Task GetTaskAsync(string taskId, CancellationToken cancellationToken = default) - => Task.FromResult(_tasks.TryGetValue(taskId, out var task) ? task : null); + { + if (!_tasks.TryGetValue(taskId, out var task)) + return Task.FromResult(null); + + var callerId = GetCallerOwnerId(); + if (callerId is not null + && _taskOwners.TryGetValue(taskId, out var ownerId) + && ownerId is not null + && !string.Equals(ownerId, callerId, StringComparison.Ordinal)) + { + return Task.FromResult(null); + } + + return Task.FromResult(task); + } /// public async Task SaveTaskAsync(string taskId, AgentTask task, CancellationToken cancellationToken = default) { - ValidateTransition(taskId, task); + if (!ValidateTransition(taskId, task)) + throw new InvalidOperationException($"Invalid A2A task state transition for task '{taskId}'."); + _tasks[taskId] = task; + // Record owner from current HTTP context on first save (task creation). + // Subsequent saves (state transitions) preserve the original owner. + if (!_taskOwners.ContainsKey(taskId)) + { + _taskOwners[taskId] = GetCallerOwnerId(); + } + var rawJson = JsonSerializer.Serialize(task, A2AJsonUtilities.DefaultOptions); var record = new A2aTaskRecord { @@ -80,6 +121,7 @@ public async Task SaveTaskAsync(string taskId, AgentTask task, CancellationToken State = task.Status?.State.ToString() ?? "Unknown", CreatedAt = DateTimeOffset.UtcNow, UpdatedAt = DateTimeOffset.UtcNow, + OwnerId = _taskOwners.TryGetValue(taskId, out var owner) ? owner : null, RawTaskJson = rawJson, }; @@ -87,7 +129,7 @@ public async Task SaveTaskAsync(string taskId, AgentTask task, CancellationToken await _writeLock.WaitAsync(cancellationToken).ConfigureAwait(false); try { - await File.AppendAllTextAsync(_filePath, line + "\n", cancellationToken).ConfigureAwait(false); + await File.AppendAllLinesAsync(_filePath, [line], cancellationToken).ConfigureAwait(false); } finally { @@ -101,36 +143,80 @@ public async Task SaveTaskAsync(string taskId, AgentTask task, CancellationToken } } + /// Tombstone state value written to JSONL when a task is evicted. + internal const string DeletedState = "Deleted"; + /// - public Task DeleteTaskAsync(string taskId, CancellationToken cancellationToken = default) + public async Task DeleteTaskAsync(string taskId, CancellationToken cancellationToken = default) { _tasks.TryRemove(taskId, out _); - return Task.CompletedTask; + _taskOwners.TryRemove(taskId, out _); + + // Append a tombstone record so LoadFromDisk skips this task after restart. + var tombstone = new A2aTaskRecord + { + TaskId = taskId, + ContextId = "", + State = DeletedState, + CreatedAt = DateTimeOffset.UtcNow, + UpdatedAt = DateTimeOffset.UtcNow, + RawTaskJson = "{}", + }; + var line = JsonSerializer.Serialize(tombstone, A2aJsonlContext.Default.A2aTaskRecord); + await _writeLock.WaitAsync(cancellationToken).ConfigureAwait(false); + try + { + await File.AppendAllLinesAsync(_filePath, [line], cancellationToken).ConfigureAwait(false); + } + finally + { + _writeLock.Release(); + } } /// + /// + /// M-01 IDOR protection: filters results to tasks owned by the current caller. + /// Tasks with null OwnerId (pre-ownership) are visible to all callers for backward compat. + /// L-16: Page size is clamped to [1, 100] to prevent excessive memory use. + /// public Task ListTasksAsync(ListTasksRequest request, CancellationToken cancellationToken = default) { - var filtered = _tasks.Values.AsEnumerable(); + var callerId = GetCallerOwnerId(); + var filtered = _tasks.AsEnumerable(); + + // M-01: Filter by owner — only return tasks belonging to the caller + // or tasks with no owner (backward compat for pre-ownership tasks) + if (callerId is not null) + { + filtered = filtered.Where(kvp => + !_taskOwners.TryGetValue(kvp.Key, out var ownerId) + || ownerId is null + || string.Equals(ownerId, callerId, StringComparison.Ordinal)); + } + + var filteredValues = filtered.Select(kvp => kvp.Value); if (request.ContextId is not null) { - filtered = filtered.Where(t => + filteredValues = filteredValues.Where(t => string.Equals(t.ContextId, request.ContextId, StringComparison.Ordinal)); } if (request.Status is not null) { - filtered = filtered.Where(t => t.Status?.State == request.Status); + filteredValues = filteredValues.Where(t => t.Status?.State == request.Status); } // Order by task ID descending (ULID gives chronological + lexicographic ordering) - var ordered = filtered + var ordered = filteredValues .OrderByDescending(t => t.Id, StringComparer.Ordinal) .ToList(); var totalFiltered = ordered.Count; - var pageSize = request.PageSize ?? 20; + + // L-16: Clamp page size to prevent excessive memory use + var pageSize = Math.Clamp(request.PageSize ?? 20, 1, 100); // Apply cursor: skip to after the cursor ID if (!string.IsNullOrEmpty(request.PageToken)) @@ -178,6 +264,7 @@ internal async Task CompactAsync(CancellationToken cancellationToken = default) State = task.Status?.State.ToString() ?? "Unknown", CreatedAt = DateTimeOffset.UtcNow, UpdatedAt = DateTimeOffset.UtcNow, + OwnerId = _taskOwners.TryGetValue(taskId, out var owner) ? owner : null, RawTaskJson = rawJson, }; lines.Add(JsonSerializer.Serialize(record, A2aJsonlContext.Default.A2aTaskRecord)); @@ -211,11 +298,20 @@ private void LoadFromDisk() if (record is null) continue; + // Tombstone records mark evicted tasks — remove from memory (last-write-wins). + if (string.Equals(record.State, DeletedState, StringComparison.Ordinal)) + { + _tasks.TryRemove(record.TaskId, out _); + _taskOwners.TryRemove(record.TaskId, out _); + continue; + } + var agentTask = JsonSerializer.Deserialize( record.RawTaskJson, A2AJsonUtilities.DefaultOptions); if (agentTask is not null) { _tasks[record.TaskId] = agentTask; // last-write-wins dedup + _taskOwners[record.TaskId] = record.OwnerId; // restore owner mapping } } catch (JsonException ex) @@ -227,20 +323,25 @@ private void LoadFromDisk() LogLoadedTasks(_logger, _tasks.Count, _filePath); } - private void ValidateTransition(string taskId, AgentTask newTask) + /// + /// Validates A2A task state transitions. Returns true if the transition is valid + /// (or no prior state exists), false if the transition violates the state machine. + /// L-10: invalid transitions are now rejected, not just logged. + /// + private bool ValidateTransition(string taskId, AgentTask newTask) { if (!_tasks.TryGetValue(taskId, out var existing)) - return; + return true; // New task, no prior state var oldState = existing.Status?.State; var newState = newTask.Status?.State; if (oldState is null || newState is null) - return; + return true; // Same state is always allowed (idempotent save) if (oldState == newState) - return; + return true; var isValid = oldState switch { @@ -255,6 +356,36 @@ private void ValidateTransition(string taskId, AgentTask newTask) { LogInvalidTransition(_logger, oldState.Value.ToString(), newState.Value.ToString(), taskId); } + + return isValid; + } + + /// + /// Extracts the authenticated caller's owner ID from the current HTTP context. + /// Returns the KeyId or User.Name from , or null + /// when no HTTP context is available (e.g., eviction service, tests). + /// + private string? GetCallerOwnerId() + { + var authResult = _httpContextAccessor?.HttpContext?.Items[BearerTokenAuthFilter.AuthResultKey] + as McpServerAuthResult; + return authResult?.KeyId ?? authResult?.User?.Name; + } + + /// + /// Checks whether a task is owned by the specified owner. Used by + /// for push notification IDOR protection (M-02). Returns true if the task has no owner + /// (backward compat) or if the owner matches. + /// + internal bool IsTaskOwnedBy(string taskId, string? callerId) + { + if (callerId is null) + return true; // No caller identity available — allow (e.g., localhost bypass) + + if (!_taskOwners.TryGetValue(taskId, out var ownerId) || ownerId is null) + return true; // Pre-ownership task — allow for backward compat + + return string.Equals(ownerId, callerId, StringComparison.Ordinal); } // ── Source-generated log methods ───────────────────────────────────────── diff --git a/src/clawsharp/Analytics/EfInteractionStore.cs b/src/clawsharp/Analytics/EfInteractionStore.cs index e2774b9e..c4429936 100644 --- a/src/clawsharp/Analytics/EfInteractionStore.cs +++ b/src/clawsharp/Analytics/EfInteractionStore.cs @@ -22,12 +22,12 @@ public sealed partial class EfInteractionStore( public async Task AppendAsync(InteractionRecord record, CancellationToken ct = default) { - await EnsureInitializedAsync(ct); - await using var db = await contextFactory.CreateDbContextAsync(ct); + await EnsureInitializedAsync(ct).ConfigureAwait(false); + await using var db = await contextFactory.CreateDbContextAsync(ct).ConfigureAwait(false); // Get or create conversation thread for this session var thread = await db.Set() - .FirstOrDefaultAsync(t => t.SessionId == record.SessionId, ct); + .FirstOrDefaultAsync(t => t.SessionId == record.SessionId, ct).ConfigureAwait(false); if (thread is null) { @@ -40,23 +40,23 @@ public async Task AppendAsync(InteractionRecord record, CancellationToken ct = d try { - await db.SaveChangesAsync(ct); + await db.SaveChangesAsync(ct).ConfigureAwait(false); } catch (DbUpdateException) { // Concurrent insert won the race — reload the existing thread db.ChangeTracker.Clear(); thread = await db.Set() - .FirstAsync(t => t.SessionId == record.SessionId, ct); + .FirstAsync(t => t.SessionId == record.SessionId, ct).ConfigureAwait(false); } } - await using var transaction = await db.Database.BeginTransactionAsync(ct); + await using var transaction = await db.Database.BeginTransactionAsync(ct).ConfigureAwait(false); var entity = ToEntity(record); entity.ConversationThreadId = thread.Id; db.Set().Add(entity); - await db.SaveChangesAsync(ct); + await db.SaveChangesAsync(ct).ConfigureAwait(false); // Insert per-message rows var now = record.Timestamp; @@ -91,19 +91,19 @@ public async Task AppendAsync(InteractionRecord record, CancellationToken ct = d Timestamp = now, }); - await db.SaveChangesAsync(ct); - await transaction.CommitAsync(ct); + await db.SaveChangesAsync(ct).ConfigureAwait(false); + await transaction.CommitAsync(ct).ConfigureAwait(false); } public async Task> ReadAllAsync(CancellationToken ct = default) { - await EnsureInitializedAsync(ct); - await using var db = await contextFactory.CreateDbContextAsync(ct); + await EnsureInitializedAsync(ct).ConfigureAwait(false); + await using var db = await contextFactory.CreateDbContextAsync(ct).ConfigureAwait(false); var entities = await db.Set() .AsNoTracking() .OrderBy(e => e.Id) - .ToListAsync(ct); + .ToListAsync(ct).ConfigureAwait(false); return entities.Select(ToRecord).ToList(); } @@ -115,7 +115,7 @@ private async Task EnsureInitializedAsync(CancellationToken ct) return; } - await _initLock.WaitAsync(ct); + await _initLock.WaitAsync(ct).ConfigureAwait(false); try { if (_initialized) @@ -123,8 +123,8 @@ private async Task EnsureInitializedAsync(CancellationToken ct) return; } - await using var db = await contextFactory.CreateDbContextAsync(ct); - await db.Database.MigrateAsync(ct); + await using var db = await contextFactory.CreateDbContextAsync(ct).ConfigureAwait(false); + await db.Database.MigrateAsync(ct).ConfigureAwait(false); _initialized = true; LogDatabaseInitialized(typeof(TContext).Name); } diff --git a/src/clawsharp/Analytics/InteractionStorage.cs b/src/clawsharp/Analytics/InteractionStorage.cs index a335bba7..178776e2 100644 --- a/src/clawsharp/Analytics/InteractionStorage.cs +++ b/src/clawsharp/Analytics/InteractionStorage.cs @@ -1,5 +1,6 @@ using System.Text.Json; using Clawsharp.Config; +using Clawsharp.Core.Utilities; namespace Clawsharp.Analytics; @@ -25,7 +26,7 @@ public sealed class InteractionStorage : IInteractionStore public InteractionStorage() { var dir = ConfigLoader.ExpandHome("~/.clawsharp"); - Directory.CreateDirectory(dir); + FilePermissions.EnsureRestrictedDirectory(dir); _filePath = Path.Combine(dir, "interactions.jsonl"); } @@ -48,7 +49,7 @@ public async Task AppendAsync(InteractionRecord record, CancellationToken ct = d await _writeLock.WaitAsync(ct).ConfigureAwait(false); try { - await File.AppendAllTextAsync(_filePath, json + "\n", ct).ConfigureAwait(false); + await File.AppendAllLinesAsync(_filePath, [json], ct).ConfigureAwait(false); // Invalidate the cache — next ReadAllAsync will re-read the file lock (_cacheLock) diff --git a/src/clawsharp/Analytics/InteractionTracker.cs b/src/clawsharp/Analytics/InteractionTracker.cs index 419928a7..2adf8a72 100644 --- a/src/clawsharp/Analytics/InteractionTracker.cs +++ b/src/clawsharp/Analytics/InteractionTracker.cs @@ -59,7 +59,7 @@ public async Task RecordAsync( try { - await store.AppendAsync(record, ct); + await store.AppendAsync(record, ct).ConfigureAwait(false); LogInteractionRecorded(sessionId, model, cost, savings); } catch (Exception ex) @@ -70,7 +70,7 @@ public async Task RecordAsync( if (storeInMemory) { - await StoreMemoryFactAsync(record, ct); + await StoreMemoryFactAsync(record, ct).ConfigureAwait(false); } } @@ -112,7 +112,7 @@ private async Task StoreMemoryFactAsync(InteractionRecord record, CancellationTo $"{record.InputTokens:N0} in / {record.OutputTokens:N0} out tokens, " + $"${record.CostUsd:F4} cost, ${record.CacheSavingsUsd:F4} cache savings ({cacheRate:F0}% cache hit).{toolInfo}{thinkingInfo}"; - await memory.AppendFactAsync(fact, ct); + await memory.AppendFactAsync(fact, ct).ConfigureAwait(false); } catch (Exception ex) { diff --git a/src/clawsharp/Auth/AuthStore.cs b/src/clawsharp/Auth/AuthStore.cs index 1900f226..9fcb7046 100644 --- a/src/clawsharp/Auth/AuthStore.cs +++ b/src/clawsharp/Auth/AuthStore.cs @@ -20,7 +20,7 @@ public static async Task SaveAsync(string provider, OAuthToken token, Cancellati var path = GetTokenPath(provider); var tmpPath = path + ".tmp"; var json = JsonSerializer.Serialize(token, AuthJsonContext.Default.OAuthToken); - await File.WriteAllTextAsync(tmpPath, json, ct); + await File.WriteAllTextAsync(tmpPath, json, ct).ConfigureAwait(false); // Restrict file permissions on Unix (owner read/write only) if (!OperatingSystem.IsWindows()) @@ -39,7 +39,7 @@ public static async Task SaveAsync(string provider, OAuthToken token, Cancellati return null; } - var json = await File.ReadAllTextAsync(path, ct); + var json = await File.ReadAllTextAsync(path, ct).ConfigureAwait(false); return JsonSerializer.Deserialize(json, AuthJsonContext.Default.OAuthToken); } diff --git a/src/clawsharp/Auth/GitHubDeviceFlow.cs b/src/clawsharp/Auth/GitHubDeviceFlow.cs index ce8962e6..bd47d2e0 100644 --- a/src/clawsharp/Auth/GitHubDeviceFlow.cs +++ b/src/clawsharp/Auth/GitHubDeviceFlow.cs @@ -26,7 +26,7 @@ public sealed class GitHubDeviceFlow(IHttpClientFactory httpFactory) http.DefaultRequestHeaders.Accept.Add(new MediaTypeWithQualityHeaderValue("application/json")); // Step 1: Request device code - var deviceCode = await RequestDeviceCodeAsync(http, ct); + var deviceCode = await RequestDeviceCodeAsync(http, ct).ConfigureAwait(false); if (deviceCode is null) { AnsiConsole.MarkupLine("[red][[auth]][/] Failed to request device code from GitHub."); @@ -40,7 +40,7 @@ public sealed class GitHubDeviceFlow(IHttpClientFactory httpFactory) AnsiConsole.MarkupLine(" Waiting for authorization..."); // Step 2: Poll for GitHub access token - var githubToken = await PollForAccessTokenAsync(http, deviceCode, ct); + var githubToken = await PollForAccessTokenAsync(http, deviceCode, ct).ConfigureAwait(false); if (githubToken is null) { AnsiConsole.MarkupLine("[red][[auth]][/] Device flow authorization timed out or was denied."); @@ -50,7 +50,7 @@ public sealed class GitHubDeviceFlow(IHttpClientFactory httpFactory) AnsiConsole.MarkupLine(" GitHub authorization successful. Fetching Copilot token..."); // Step 3: Exchange GitHub token for Copilot token - var copilotToken = await ExchangeForCopilotTokenAsync(http, githubToken, ct); + var copilotToken = await ExchangeForCopilotTokenAsync(http, githubToken, ct).ConfigureAwait(false); if (copilotToken is null) { AnsiConsole.MarkupLine("[red][[auth]][/] Failed to obtain Copilot token. Ensure your GitHub account has Copilot access."); @@ -68,7 +68,7 @@ public sealed class GitHubDeviceFlow(IHttpClientFactory httpFactory) { using var http = httpFactory.CreateClient("llm"); http.DefaultRequestHeaders.Accept.Add(new MediaTypeWithQualityHeaderValue("application/json")); - return await ExchangeForCopilotTokenAsync(http, githubToken, ct); + return await ExchangeForCopilotTokenAsync(http, githubToken, ct).ConfigureAwait(false); } private static async Task RequestDeviceCodeAsync(HttpClient http, CancellationToken ct) @@ -81,15 +81,15 @@ public sealed class GitHubDeviceFlow(IHttpClientFactory httpFactory) try { - var resp = await http.PostAsync("https://github.com/login/device/code", body, ct); + var resp = await http.PostAsync("https://github.com/login/device/code", body, ct).ConfigureAwait(false); if (!resp.IsSuccessStatusCode) { - var err = await resp.Content.ReadAsStringAsync(ct); + var err = await resp.Content.ReadAsStringAsync(ct).ConfigureAwait(false); AnsiConsole.MarkupLine($"[red][[auth]][/] Device code request failed ({resp.StatusCode}): {Markup.Escape(err)}"); return null; } - var json = await resp.Content.ReadAsStringAsync(ct); + var json = await resp.Content.ReadAsStringAsync(ct).ConfigureAwait(false); return JsonSerializer.Deserialize(json, AuthJsonContext.Default.GitHubDeviceCodeResponse); } catch (Exception ex) @@ -108,7 +108,7 @@ public sealed class GitHubDeviceFlow(IHttpClientFactory httpFactory) while (DateTimeOffset.UtcNow < deadline) { ct.ThrowIfCancellationRequested(); - await Task.Delay(TimeSpan.FromSeconds(interval), ct); + await Task.Delay(TimeSpan.FromSeconds(interval), ct).ConfigureAwait(false); var body = new FormUrlEncodedContent(new Dictionary { @@ -119,8 +119,8 @@ public sealed class GitHubDeviceFlow(IHttpClientFactory httpFactory) try { - var resp = await http.PostAsync("https://github.com/login/oauth/access_token", body, ct); - var json = await resp.Content.ReadAsStringAsync(ct); + var resp = await http.PostAsync("https://github.com/login/oauth/access_token", body, ct).ConfigureAwait(false); + var json = await resp.Content.ReadAsStringAsync(ct).ConfigureAwait(false); var tokenResp = JsonSerializer.Deserialize(json, AuthJsonContext.Default.GitHubAccessTokenResponse); if (tokenResp is null) @@ -177,15 +177,15 @@ public sealed class GitHubDeviceFlow(IHttpClientFactory httpFactory) req.Headers.Authorization = new AuthenticationHeaderValue("token", githubToken); req.Headers.UserAgent.Add(new ProductInfoHeaderValue("clawsharp", "1.0")); - var resp = await http.SendAsync(req, ct); + var resp = await http.SendAsync(req, ct).ConfigureAwait(false); if (!resp.IsSuccessStatusCode) { - var err = await resp.Content.ReadAsStringAsync(ct); + var err = await resp.Content.ReadAsStringAsync(ct).ConfigureAwait(false); AnsiConsole.MarkupLine($"[red][[auth]][/] Copilot token exchange failed ({resp.StatusCode}): {Markup.Escape(err)}"); return null; } - var json = await resp.Content.ReadAsStringAsync(ct); + var json = await resp.Content.ReadAsStringAsync(ct).ConfigureAwait(false); var copilotResp = JsonSerializer.Deserialize(json, AuthJsonContext.Default.CopilotTokenResponse); if (copilotResp is null || string.IsNullOrEmpty(copilotResp.Token)) { diff --git a/src/clawsharp/Channels/BridgePollingChannelBase.cs b/src/clawsharp/Channels/BridgePollingChannelBase.cs index 566b543e..6a3e3249 100644 --- a/src/clawsharp/Channels/BridgePollingChannelBase.cs +++ b/src/clawsharp/Channels/BridgePollingChannelBase.cs @@ -1,5 +1,3 @@ -using System.Text; -using System.Text.Json; using System.Text.Json.Serialization.Metadata; using Clawsharp.Config; using Clawsharp.Core; @@ -229,7 +227,7 @@ private async Task PollOnceAsync(CancellationToken ct) // Static AllowFrom + dynamic approved senders if (!_allowPolicy.IsAllowed(senderId) && - !await _approvedSenders.IsApprovedAsync(Name.Value, senderId).ConfigureAwait(false)) + !await _approvedSenders.IsApprovedAsync(Name.Value, senderId, ct).ConfigureAwait(false)) { LogBlockedSender(Logger, Name.Value, senderId); continue; @@ -258,8 +256,7 @@ public virtual async Task SendAsync(OutboundMessage message, CancellationToken c } var req = MapToSendRequest(message); - var json = JsonSerializer.Serialize(req, SendRequestTypeInfo); - using var content = new StringContent(json, Encoding.UTF8, "application/json"); + using var content = Utf8JsonContent.Create(req, SendRequestTypeInfo); try { diff --git a/src/clawsharp/Channels/Cli/CliChannel.cs b/src/clawsharp/Channels/Cli/CliChannel.cs index 192ba2a7..41bb8a43 100644 --- a/src/clawsharp/Channels/Cli/CliChannel.cs +++ b/src/clawsharp/Channels/Cli/CliChannel.cs @@ -31,7 +31,7 @@ protected override async Task ExecuteAsync(CancellationToken stoppingToken) AnsiConsole.MarkupLine("[cyan]clawsharp[/] — type your message, Ctrl+C to exit\n"); AnsiConsole.Markup("[green]> [/]"); - await RunMessageLoopAsync(stoppingToken); + await RunMessageLoopAsync(stoppingToken).ConfigureAwait(false); } private async Task RunMessageLoopAsync(CancellationToken stoppingToken) @@ -43,7 +43,7 @@ private async Task RunMessageLoopAsync(CancellationToken stoppingToken) // TaskCompletionSource so that cancellation returns immediately // even if Console.ReadLine() stays blocked (the background thread // is IsBackground=true so it won't prevent process exit). - var line = await ReadLineAsync(stoppingToken); + var line = await ReadLineAsync(stoppingToken).ConfigureAwait(false); if (line is null || stoppingToken.IsCancellationRequested) { break; @@ -62,7 +62,7 @@ await bus.PublishAsync(new InboundMessage( SenderId: "cli-user", SenderName: "User", Text: line - ), stoppingToken); + ), stoppingToken).ConfigureAwait(false); // The next "> " prompt is printed by SendAsync/StreamAsync after the response. } catch (OperationCanceledException) @@ -93,7 +93,7 @@ await bus.PublishAsync(new InboundMessage( Name = "CLI-ReadLine" }; thread.Start(tcs); - return await tcs.Task; + return await tcs.Task.ConfigureAwait(false); } [LoggerMessage(EventId = 1, Level = LogLevel.Error, Message = "Error processing CLI input")] @@ -111,7 +111,7 @@ public async Task StreamAsync(OutboundMessage message, IAsyncEnumerable { AnsiConsole.Markup("[blue]Assistant:[/] "); var first = true; - await foreach (var token in tokens.WithCancellation(ct)) + await foreach (var token in tokens.WithCancellation(ct).ConfigureAwait(false)) { if (first) { diff --git a/src/clawsharp/Channels/Discord/DiscordChannel.cs b/src/clawsharp/Channels/Discord/DiscordChannel.cs index 334c562e..84542a73 100644 --- a/src/clawsharp/Channels/Discord/DiscordChannel.cs +++ b/src/clawsharp/Channels/Discord/DiscordChannel.cs @@ -43,6 +43,7 @@ public async Task SendAsync(OutboundMessage message, CancellationToken ct = defa if (!result.IsSuccess) { LogSendError(logger, result.Error); + break; } } } diff --git a/src/clawsharp/Channels/Discord/DiscordMessageResponder.cs b/src/clawsharp/Channels/Discord/DiscordMessageResponder.cs index 22df653d..9ed97a93 100644 --- a/src/clawsharp/Channels/Discord/DiscordMessageResponder.cs +++ b/src/clawsharp/Channels/Discord/DiscordMessageResponder.cs @@ -151,7 +151,7 @@ private async Task CheckUserAllowedAsync( string authorId, bool isDm, Snowflake channelId, string username, CancellationToken ct) { var isAllowed = _allowPolicy.IsAllowed(authorId) - || await approvedSenders.IsApprovedAsync("discord", authorId, ct); + || await approvedSenders.IsApprovedAsync(ChannelName.Discord.Value, authorId, ct); if (isAllowed) { return true; @@ -161,7 +161,7 @@ private async Task CheckUserAllowedAsync( { try { - var code = await pairingStore.GetOrCreateCodeAsync("discord", authorId, username, ct); + var code = await pairingStore.GetOrCreateCodeAsync(ChannelName.Discord.Value, authorId, username, ct); var msg = $"Hi! To use this bot, send your operator the pairing code: **{code}**\n" + "This code expires in 24 hours."; await restChannel.CreateMessageAsync(channelId, msg, ct: ct); diff --git a/src/clawsharp/Channels/Irc/IrcChannel.cs b/src/clawsharp/Channels/Irc/IrcChannel.cs index aaddcfa1..3b9a4310 100644 --- a/src/clawsharp/Channels/Irc/IrcChannel.cs +++ b/src/clawsharp/Channels/Irc/IrcChannel.cs @@ -89,8 +89,11 @@ public async Task SendAsync(OutboundMessage message, CancellationToken ct = defa } var target = message.RecipientId; - // Split long messages (IRC line limit ~512 bytes) - var text = message.Text; + // Strip CR/LF to prevent protocol injection — an LLM response containing \r\n + // would otherwise be interpreted as multiple IRC commands by the server. + var text = message.Text + .Replace("\r", "", StringComparison.Ordinal) + .Replace("\n", " ", StringComparison.Ordinal); while (text.Length > 0) { var chunk = text; diff --git a/src/clawsharp/Channels/Lark/LarkChannel.cs b/src/clawsharp/Channels/Lark/LarkChannel.cs index af5dcfae..23818407 100644 --- a/src/clawsharp/Channels/Lark/LarkChannel.cs +++ b/src/clawsharp/Channels/Lark/LarkChannel.cs @@ -149,20 +149,24 @@ protected override async Task HandleRequestAsync(HttpListenerContext ctx, Cancel return; } - // MED-45: Signature verification — REQUIRE valid signature when token is configured. - // If no token is configured, we already logged a warning at startup (LogNoVerificationToken). - if (_verificationToken.Length > 0) + // Signature verification — REQUIRE valid signature when token is configured. + // When no token is configured, reject message events entirely to prevent forged webhooks. + if (_verificationToken.Length == 0) { - var timestamp = req.Headers["X-Lark-Request-Timestamp"] ?? ""; - var nonce = req.Headers["X-Lark-Request-Nonce"] ?? ""; - var signature = req.Headers["X-Lark-Signature"]; - if (signature is null || !VerifySignature(timestamp, nonce, bodyBytes, signature)) - { - LogInvalidSignature(); - resp.StatusCode = 403; - resp.Close(); - return; - } + resp.StatusCode = 403; + resp.Close(); + return; + } + + var timestamp = req.Headers["X-Lark-Request-Timestamp"] ?? ""; + var nonce = req.Headers["X-Lark-Request-Nonce"] ?? ""; + var signature = req.Headers["X-Lark-Signature"]; + if (signature is null || !VerifySignature(timestamp, nonce, bodyBytes, signature)) + { + LogInvalidSignature(); + resp.StatusCode = 403; + resp.Close(); + return; } // Handle im.message.receive_v1 @@ -303,8 +307,7 @@ public async Task SendAsync(OutboundMessage message, CancellationToken ct = defa MsgType = LarkMessageType.Text }; - var json = JsonSerializer.Serialize(sendReq, LarkJsonContext.Default.LarkSendMessageRequest); - using var content = new StringContent(json, Encoding.UTF8, "application/json"); + using var content = Utf8JsonContent.Create(sendReq, LarkJsonContext.Default.LarkSendMessageRequest); try { @@ -317,7 +320,7 @@ public async Task SendAsync(OutboundMessage message, CancellationToken ct = defa if (!resp.IsSuccessStatusCode) { var responseBody = await resp.Content.ReadAsStringAsync(ct).ConfigureAwait(false); - LogSendError(responseBody); + LogSendError(TruncateResponseBody(responseBody)); } } catch (Exception ex) @@ -346,14 +349,14 @@ public async Task SendAsync(OutboundMessage message, CancellationToken ct = defa AppSecret = _appSecret }; - var json = JsonSerializer.Serialize(tokenReq, LarkJsonContext.Default.LarkTokenRequest); - using var content = new StringContent(json, Encoding.UTF8, "application/json"); + using var content = Utf8JsonContent.Create(tokenReq, LarkJsonContext.Default.LarkTokenRequest); using var resp = await _http.PostAsync( "open-apis/auth/v3/tenant_access_token/internal/", content, ct).ConfigureAwait(false); if (!resp.IsSuccessStatusCode) { - LogTokenHttpError(await resp.Content.ReadAsStringAsync(ct).ConfigureAwait(false)); + var tokenBody = await resp.Content.ReadAsStringAsync(ct).ConfigureAwait(false); + LogTokenHttpError(TruncateResponseBody(tokenBody)); return null; } @@ -378,6 +381,10 @@ public async Task SendAsync(OutboundMessage message, CancellationToken ct = defa } } + /// Truncates response bodies to avoid logging sensitive data (tokens, session info). + private static string TruncateResponseBody(string body, int maxLength = 500) => + body.Length > maxLength ? string.Concat(body.AsSpan(0, maxLength), "...(truncated)") : body; + // ── LoggerMessage methods ──────────────────────────────────────── [LoggerMessage(EventId = 1, Level = LogLevel.Information, Message = "Starting Lark webhook listener on port {Port}")] diff --git a/src/clawsharp/Channels/Line/LineChannel.cs b/src/clawsharp/Channels/Line/LineChannel.cs index 3af4f994..14998dad 100644 --- a/src/clawsharp/Channels/Line/LineChannel.cs +++ b/src/clawsharp/Channels/Line/LineChannel.cs @@ -209,8 +209,7 @@ public async Task SendAsync(OutboundMessage message, CancellationToken ct = defa Messages = [new LineTextMessage { Text = message.Text }] }; - var json = JsonSerializer.Serialize(req, LineJsonContext.Default.LinePushRequest); - using var content = new StringContent(json, Encoding.UTF8, "application/json"); + using var content = Utf8JsonContent.Create(req, LineJsonContext.Default.LinePushRequest); try { @@ -222,7 +221,7 @@ public async Task SendAsync(OutboundMessage message, CancellationToken ct = defa if (!resp.IsSuccessStatusCode) { var body = await resp.Content.ReadAsStringAsync(ct).ConfigureAwait(false); - LogSendError(_logger, body); + LogSendError(_logger, TruncateResponseBody(body)); } } catch (Exception ex) @@ -231,6 +230,10 @@ public async Task SendAsync(OutboundMessage message, CancellationToken ct = defa } } + /// Truncates response bodies to avoid logging sensitive data (tokens, session info). + private static string TruncateResponseBody(string body, int maxLength = 500) => + body.Length > maxLength ? string.Concat(body.AsSpan(0, maxLength), "...(truncated)") : body; + [LoggerMessage(EventId = 1, Level = LogLevel.Information, Message = "Starting LINE webhook listener on port {Port}")] private static partial void LogStartingWebhook(ILogger logger, int port); diff --git a/src/clawsharp/Channels/Matrix/MatrixChannel.cs b/src/clawsharp/Channels/Matrix/MatrixChannel.cs index eb6c3729..7d646bb2 100644 --- a/src/clawsharp/Channels/Matrix/MatrixChannel.cs +++ b/src/clawsharp/Channels/Matrix/MatrixChannel.cs @@ -1,5 +1,4 @@ using System.Net.Http.Headers; -using System.Text; using System.Text.Json; using Clawsharp.Config; using Clawsharp.Core; @@ -150,10 +149,10 @@ public async Task SendAsync(OutboundMessage message, CancellationToken ct = defa HttpMethod? method = null) { var httpMethod = method ?? HttpMethod.Post; - var json = JsonSerializer.Serialize(request, request.RequestTypeInfo); + var jsonBytes = JsonSerializer.SerializeToUtf8Bytes(request, request.RequestTypeInfo); // First attempt. - using (var content = new StringContent(json, Encoding.UTF8, "application/json")) + using (var content = Utf8JsonContent.FromUtf8Bytes(jsonBytes)) using (var req = CreateRequest(httpMethod, request.Url, content)) { using var resp = await _http.SendAsync(req, ct); @@ -171,20 +170,22 @@ public async Task SendAsync(OutboundMessage message, CancellationToken ct = defa } else { - LogSendFailed(_logger, await resp.Content.ReadAsStringAsync(ct)); + var body = await resp.Content.ReadAsStringAsync(ct); + LogSendFailed(_logger, TruncateResponseBody(body)); return default; } } // Retry after successful re-login. - using (var retryContent = new StringContent(json, Encoding.UTF8, "application/json")) + using (var retryContent = Utf8JsonContent.FromUtf8Bytes(jsonBytes)) using (var retryReq = CreateRequest(httpMethod, request.Url, retryContent)) { using var retryResp = await _http.SendAsync(retryReq, ct); if (!retryResp.IsSuccessStatusCode) { - LogSendFailed(_logger, await retryResp.Content.ReadAsStringAsync(ct)); + var retryBody = await retryResp.Content.ReadAsStringAsync(ct); + LogSendFailed(_logger, TruncateResponseBody(retryBody)); return default; } @@ -264,8 +265,7 @@ private async Task TryReloginAsync(CancellationToken ct) Url = "_matrix/client/v3/login" }; - var json = JsonSerializer.Serialize(loginRequest, MatrixJsonContext.Default.MatrixLoginRequest); - using var content = new StringContent(json, Encoding.UTF8, "application/json"); + using var content = Utf8JsonContent.Create(loginRequest, MatrixJsonContext.Default.MatrixLoginRequest); // Do NOT use CreateRequest here -- we may have an expired/invalid token, // and the login endpoint does not require Authorization. using var req = new HttpRequestMessage(HttpMethod.Post, loginRequest.Url) { Content = content }; @@ -274,7 +274,7 @@ private async Task TryReloginAsync(CancellationToken ct) if (!resp.IsSuccessStatusCode) { var body = await resp.Content.ReadAsStringAsync(ct); - LogReloginFailed(_logger, $"HTTP {(int)resp.StatusCode}: {body}"); + LogReloginFailed(_logger, $"HTTP {(int)resp.StatusCode}: {TruncateResponseBody(body)}"); return null; } @@ -372,7 +372,11 @@ private void SaveSyncToken(string token) Directory.CreateDirectory(dir); } - File.WriteAllText(SyncTokenPath, token); + // Atomic write via temp+rename — prevents token corruption on crash + // (consistent with SessionManager's File.Move pattern). + var tmp = SyncTokenPath + ".tmp"; + File.WriteAllText(tmp, token); + File.Move(tmp, SyncTokenPath, overwrite: true); } catch (Exception ex) { @@ -508,7 +512,7 @@ private async Task ProcessSyncRoomsAsync(MatrixSyncResponse sync, CancellationTo // Per-user allowlist check (static AllowFrom + dynamic approved senders) if (!_allowPolicy.IsAllowed(ev.Sender) && - !await _approvedSenders.IsApprovedAsync(ChannelName.Matrix.Value, ev.Sender)) + !await _approvedSenders.IsApprovedAsync(ChannelName.Matrix.Value, ev.Sender, ct)) { LogBlockedUser(_logger, ev.Sender); continue; @@ -549,6 +553,10 @@ await _bus.PublishAsync(new InboundMessage( } } + /// Truncates response bodies to avoid logging sensitive data (tokens, session info). + private static string TruncateResponseBody(string body, int maxLength = 500) => + body.Length > maxLength ? string.Concat(body.AsSpan(0, maxLength), "...(truncated)") : body; + [LoggerMessage(EventId = 1, Level = LogLevel.Information, Message = "Starting sync loop")] private static partial void LogStartingSyncLoop(ILogger logger); diff --git a/src/clawsharp/Channels/Mattermost/MattermostChannel.cs b/src/clawsharp/Channels/Mattermost/MattermostChannel.cs index 00820ee4..589b6622 100644 --- a/src/clawsharp/Channels/Mattermost/MattermostChannel.cs +++ b/src/clawsharp/Channels/Mattermost/MattermostChannel.cs @@ -1,6 +1,5 @@ using System.Net.Http.Headers; using System.Net.WebSockets; -using System.Text; using System.Text.Json; using Clawsharp.Config; using Clawsharp.Core; @@ -308,17 +307,16 @@ public async Task SendAsync(OutboundMessage message, CancellationToken ct = defa Message = message.Text }; - var json = JsonSerializer.Serialize(postReq, MattermostJsonContext.Default.MattermostCreatePostRequest); - try { using var httpReq = new HttpRequestMessage(HttpMethod.Post, "api/v4/posts"); httpReq.Headers.Authorization = new AuthenticationHeaderValue("Bearer", _botToken); - httpReq.Content = new StringContent(json, Encoding.UTF8, "application/json"); + httpReq.Content = Utf8JsonContent.Create(postReq, MattermostJsonContext.Default.MattermostCreatePostRequest); using var resp = await _http.SendAsync(httpReq, ct); if (!resp.IsSuccessStatusCode) { - LogSendFailed(await resp.Content.ReadAsStringAsync(ct)); + var body = await resp.Content.ReadAsStringAsync(ct); + LogSendFailed(TruncateResponseBody(body)); } } catch (Exception ex) @@ -438,10 +436,9 @@ public async Task StreamAsync(OutboundMessage message, IAsyncEnumerable ChannelId = channelId, Message = text }; - var json = JsonSerializer.Serialize(postReq, MattermostJsonContext.Default.MattermostCreatePostRequest); using var httpReq = new HttpRequestMessage(HttpMethod.Post, "api/v4/posts"); httpReq.Headers.Authorization = new AuthenticationHeaderValue("Bearer", _botToken); - httpReq.Content = new StringContent(json, Encoding.UTF8, "application/json"); + httpReq.Content = Utf8JsonContent.Create(postReq, MattermostJsonContext.Default.MattermostCreatePostRequest); using var resp = await _http.SendAsync(httpReq, ct); if (!resp.IsSuccessStatusCode) @@ -462,15 +459,15 @@ private async Task UpdatePostAsync(string postId, string text, CancellationToken Id = postId, Message = text }; - var json = JsonSerializer.Serialize(updateReq, MattermostJsonContext.Default.MattermostUpdatePostRequest); using var httpReq = new HttpRequestMessage(HttpMethod.Put, $"api/v4/posts/{postId}"); httpReq.Headers.Authorization = new AuthenticationHeaderValue("Bearer", _botToken); - httpReq.Content = new StringContent(json, Encoding.UTF8, "application/json"); + httpReq.Content = Utf8JsonContent.Create(updateReq, MattermostJsonContext.Default.MattermostUpdatePostRequest); using var resp = await _http.SendAsync(httpReq, ct); if (!resp.IsSuccessStatusCode) { - LogUpdatePostFailed(postId, await resp.Content.ReadAsStringAsync(ct)); + var body = await resp.Content.ReadAsStringAsync(ct); + LogUpdatePostFailed(postId, TruncateResponseBody(body)); } } @@ -495,6 +492,10 @@ private async Task FetchSelfIdAsync(CancellationToken ct) private const int MaxWebSocketMessageBytes = 1 * 1024 * 1024; // 1 MB + /// Truncates response bodies to avoid logging sensitive data (tokens, session info). + private static string TruncateResponseBody(string body, int maxLength = 500) => + body.Length > maxLength ? string.Concat(body.AsSpan(0, maxLength), "...(truncated)") : body; + // ── LoggerMessage methods ──────────────────────────────────────── [LoggerMessage(EventId = 1, Level = LogLevel.Information, Message = "Starting Mattermost channel")] diff --git a/src/clawsharp/Channels/Qq/QqChannel.cs b/src/clawsharp/Channels/Qq/QqChannel.cs index d0b735ed..d08eecb3 100644 --- a/src/clawsharp/Channels/Qq/QqChannel.cs +++ b/src/clawsharp/Channels/Qq/QqChannel.cs @@ -141,7 +141,8 @@ public async Task SendAsync(OutboundMessage message, CancellationToken ct = defa if (!resp.IsSuccessStatusCode) { - LogSendFailed(await resp.Content.ReadAsStringAsync(ct)); + var responseBody = await resp.Content.ReadAsStringAsync(ct); + LogSendFailed(TruncateResponseBody(responseBody)); } } else @@ -150,7 +151,8 @@ public async Task SendAsync(OutboundMessage message, CancellationToken ct = defa if (!resp.IsSuccessStatusCode) { - LogSendFailed(await resp.Content.ReadAsStringAsync(ct)); + var responseBody = await resp.Content.ReadAsStringAsync(ct); + LogSendFailed(TruncateResponseBody(responseBody)); } } } @@ -421,6 +423,10 @@ private Uri BuildWebSocketUri() return new Uri($"{_wsUrl}{separator}access_token={Uri.EscapeDataString(_token)}"); } + /// Truncates response bodies to avoid logging sensitive data (tokens, session info). + private static string TruncateResponseBody(string body, int maxLength = 500) => + body.Length > maxLength ? string.Concat(body.AsSpan(0, maxLength), "...(truncated)") : body; + // ── LoggerMessage methods ──────────────────────────────────────── [LoggerMessage(EventId = 1, Level = LogLevel.Information, Message = "Starting QQ/OneBot channel")] diff --git a/src/clawsharp/Channels/Signal/SignalChannel.cs b/src/clawsharp/Channels/Signal/SignalChannel.cs index f8e7559e..6f2f6f95 100644 --- a/src/clawsharp/Channels/Signal/SignalChannel.cs +++ b/src/clawsharp/Channels/Signal/SignalChannel.cs @@ -346,9 +346,7 @@ await _bus.PublishAsync(new InboundMessage( } }; - var json = JsonSerializer.Serialize( - rpcRequest, SignalJsonContext.Default.SignalGetAttachmentRpcRequest); - using var content = new StringContent(json, Encoding.UTF8, "application/json"); + using var content = Utf8JsonContent.Create(rpcRequest, SignalJsonContext.Default.SignalGetAttachmentRpcRequest); using var resp = await _http.PostAsync("api/v1/rpc", content, ct); if (!resp.IsSuccessStatusCode) @@ -403,8 +401,7 @@ public async Task SendAsync(OutboundMessage message, CancellationToken ct = defa } }; - var json = JsonSerializer.Serialize(rpcRequest, SignalJsonContext.Default.SignalSendRpcRequest); - using var content = new StringContent(json, Encoding.UTF8, "application/json"); + using var content = Utf8JsonContent.Create(rpcRequest, SignalJsonContext.Default.SignalSendRpcRequest); try { diff --git a/src/clawsharp/Channels/Slack/SlackChannel.cs b/src/clawsharp/Channels/Slack/SlackChannel.cs index cc0b64d7..c3ea80f8 100644 --- a/src/clawsharp/Channels/Slack/SlackChannel.cs +++ b/src/clawsharp/Channels/Slack/SlackChannel.cs @@ -163,14 +163,14 @@ public Task StopThinkingAsync(string recipientId, CancellationToken ct = default /// private async Task ExecuteAsync(IRequest request, CancellationToken ct) { - var json = JsonSerializer.Serialize(request, request.RequestTypeInfo); using var httpReq = new HttpRequestMessage(HttpMethod.Post, request.Url); httpReq.Headers.Authorization = new AuthenticationHeaderValue("Bearer", _botToken); - httpReq.Content = new StringContent(json, Encoding.UTF8, "application/json"); + httpReq.Content = Utf8JsonContent.Create(request, request.RequestTypeInfo); using var resp = await _http.SendAsync(httpReq, ct); if (!resp.IsSuccessStatusCode) { - LogSendFailed(_logger, await resp.Content.ReadAsStringAsync(ct)); + var body = await resp.Content.ReadAsStringAsync(ct); + LogSendFailed(_logger, TruncateResponseBody(body)); return default; } @@ -220,6 +220,7 @@ await ExecuteAsync(new SlackUpdateMessageRequest ct: ct).ConfigureAwait(false); // If the placeholder failed, send as a new message. + // result.Text is always raw LLM text (no mrkdwn); SendAsync applies ConvertToMrkdwn. if (!result.PlaceholderCreated) { await SendAsync(message with { Text = result.Text }, ct).ConfigureAwait(false); @@ -391,7 +392,7 @@ private static (string Text, string UserId, string ChannelId, string? Ts, string private async Task CheckUserAllowedAsync(string userId, string channelId, JsonElement ev, CancellationToken ct) { var isAllowed = _allowPolicy.IsAllowed(userId) - || await _approvedSenders.IsApprovedAsync("slack", userId); + || await _approvedSenders.IsApprovedAsync(ChannelName.Slack.Value, userId, ct); if (isAllowed) { return true; @@ -408,7 +409,7 @@ private async Task CheckUserAllowedAsync(string userId, string channelId, { userName = dn.GetString() ?? userId; } - var code = await _pairingStore.GetOrCreateCodeAsync("slack", userId, userName, ct); + var code = await _pairingStore.GetOrCreateCodeAsync(ChannelName.Slack.Value, userId, userName, ct); await PostPairingMessageAsync(userId, code, ct); LogPairingSent(_logger, userId, code); } @@ -481,6 +482,11 @@ internal static string ConvertToMrkdwn(string markdown) return $"\x00IC{inlineCode.Count - 1}\x00"; }); + // PERF: The following 7 regex replacements each allocate an intermediate string from the full + // response. For a 10KB message this is ~70KB transient allocation. Regex.Replace returns string + // (no StringBuilder overload exists) so there is no simple way to chain these without allocation. + // Acceptable for per-message Slack formatting — this runs once per outbound message, not per token. + // 1. Bold: **text** → *text* (must run before italic to avoid conflict) result = BoldRegex().Replace(result, "*$1*"); @@ -550,6 +556,10 @@ internal static string ConvertToMrkdwn(string markdown) [GeneratedRegex(@"\x00IC(\d+)\x00", RegexOptions.None, 200)] private static partial Regex InlineCodeSentinelRegex(); + /// Truncates response bodies to avoid logging sensitive data (tokens, session info). + private static string TruncateResponseBody(string body, int maxLength = 500) => + body.Length > maxLength ? string.Concat(body.AsSpan(0, maxLength), "...(truncated)") : body; + [LoggerMessage(EventId = 1, Level = LogLevel.Information, Message = "Starting Socket Mode")] private static partial void LogStartingSocketMode(ILogger logger); diff --git a/src/clawsharp/Channels/Telegram/TelegramChannel.cs b/src/clawsharp/Channels/Telegram/TelegramChannel.cs index a2a6ac2d..cfd324e1 100644 --- a/src/clawsharp/Channels/Telegram/TelegramChannel.cs +++ b/src/clawsharp/Channels/Telegram/TelegramChannel.cs @@ -115,6 +115,12 @@ protected override async Task ExecuteAsync(CancellationToken stoppingToken) while (!stoppingToken.IsCancellationRequested) { + // Retry bot info fetch if it failed at startup + if (_botUsername is null) + { + await FetchBotInfoAsync(stoppingToken).ConfigureAwait(false); + } + try { await _retryPipeline.ExecuteAsync( @@ -281,7 +287,7 @@ private async Task ProcessUpdateAsync(TelegramUpdate update, CancellationToken c } } - if (!await IsUserAllowedAsync(msg.From)) + if (!await IsUserAllowedAsync(msg.From, ct)) { if (_dmPolicy == DmPolicy.Pairing) { @@ -571,12 +577,12 @@ public Task StopThinkingAsync(string recipientId, CancellationToken ct = default /// private async Task ExecuteAsync(IRequest request, CancellationToken ct) { - var json = JsonSerializer.Serialize(request, request.RequestTypeInfo); - using var content = new StringContent(json, Encoding.UTF8, "application/json"); + using var content = Utf8JsonContent.Create(request, request.RequestTypeInfo); using var resp = await _http.PostAsync(request.Url, content, ct); if (!resp.IsSuccessStatusCode) { - LogSendFailed(_logger, $"HTTP {(int)resp.StatusCode}: {await resp.Content.ReadAsStringAsync(ct)}"); + var body = await resp.Content.ReadAsStringAsync(ct); + LogSendFailed(_logger, $"HTTP {(int)resp.StatusCode}: {TruncateResponseBody(body)}"); return default; } @@ -799,7 +805,7 @@ private static string Normalize(string entry) return entry.TrimStart('@').Trim(); } - private async ValueTask IsUserAllowedAsync(TelegramUser user) + private async ValueTask IsUserAllowedAsync(TelegramUser user, CancellationToken ct = default) { if (_allowPolicy.IsAllowAll) { @@ -807,7 +813,7 @@ private async ValueTask IsUserAllowedAsync(TelegramUser user) } // Check dynamic approved senders store - if (await _approvedSenders.IsApprovedAsync(ChannelName.Telegram.Value, user.Id.ToString())) + if (await _approvedSenders.IsApprovedAsync(ChannelName.Telegram.Value, user.Id.ToString(), ct)) { return true; } @@ -1036,6 +1042,10 @@ await ExecuteAsync(new TelegramSendMessageRequest } } + /// Truncates response bodies to avoid logging sensitive data (tokens, session info). + private static string TruncateResponseBody(string body, int maxLength = 500) => + body.Length > maxLength ? string.Concat(body.AsSpan(0, maxLength), "...(truncated)") : body; + [LoggerMessage(EventId = 1, Level = LogLevel.Information, Message = "Starting long-poll loop")] private static partial void LogStartingLongPollLoop(ILogger logger); diff --git a/src/clawsharp/Channels/WeChat/WeChatChannel.cs b/src/clawsharp/Channels/WeChat/WeChatChannel.cs index e41927d5..bfa2f973 100644 --- a/src/clawsharp/Channels/WeChat/WeChatChannel.cs +++ b/src/clawsharp/Channels/WeChat/WeChatChannel.cs @@ -1,4 +1,3 @@ -using System.Text; using System.Text.Json; using System.Text.Json.Serialization.Metadata; using Clawsharp.Config; @@ -156,8 +155,7 @@ private async Task SendViaWebhookAsync(OutboundMessage message, CancellationToke Text = new WeChatWebhookText { Content = message.Text } }; - var json = JsonSerializer.Serialize(req, WeChatJsonContext.Default.WeChatWebhookRequest); - using var content = new StringContent(json, Encoding.UTF8, "application/json"); + using var content = Utf8JsonContent.Create(req, WeChatJsonContext.Default.WeChatWebhookRequest); try { diff --git a/src/clawsharp/Channels/WeCom/WeComChannel.cs b/src/clawsharp/Channels/WeCom/WeComChannel.cs index 0fc7df25..ca9283dc 100644 --- a/src/clawsharp/Channels/WeCom/WeComChannel.cs +++ b/src/clawsharp/Channels/WeCom/WeComChannel.cs @@ -2,10 +2,12 @@ using System.Net; using System.Text; using System.Text.Json; +using System.Xml; using System.Xml.Linq; using Clawsharp.Config; using Clawsharp.Core; using Clawsharp.Core.Services; +using Clawsharp.Security; using Clawsharp.Core.Sessions; using Clawsharp.Core.Utilities; using Microsoft.Extensions.Logging; @@ -245,7 +247,9 @@ private async Task HandleMessageAsync( string? encryptContent; try { - var doc = await XDocument.LoadAsync(ms, LoadOptions.None, ct).ConfigureAwait(false); + var xmlSettings = new XmlReaderSettings { DtdProcessing = DtdProcessing.Prohibit, XmlResolver = null, Async = true }; + using var xmlReader = XmlReader.Create(ms, xmlSettings); + var doc = await XDocument.LoadAsync(xmlReader, LoadOptions.None, ct).ConfigureAwait(false); toUserName = doc.Root?.Element("ToUserName")?.Value; encryptContent = doc.Root?.Element("Encrypt")?.Value; } @@ -322,8 +326,10 @@ private async Task ProcessBotMessageAsync(WeComBotMessage msg, CancellationToken return; } - // Store response_url for SendAsync - if (msg.ResponseUrl is not null) + // Store response_url for SendAsync — validate via SsrfGuard first. + if (msg.ResponseUrl is not null + && Uri.TryCreate(msg.ResponseUrl, UriKind.Absolute, out var responseUri) + && await SsrfGuard.CheckAsync(responseUri, ct).ConfigureAwait(false) is null) { _responseUrls[senderId] = msg.ResponseUrl; } @@ -440,8 +446,7 @@ public async Task SendAsync(OutboundMessage message, CancellationToken ct = defa Text = new WeComReplyText { Content = message.Text } }; - var json = JsonSerializer.Serialize(reply, WeComBotJsonContext.Default.WeComReplyMessage); - using var content = new StringContent(json, Encoding.UTF8, "application/json"); + using var content = Utf8JsonContent.Create(reply, WeComBotJsonContext.Default.WeComReplyMessage); try { @@ -449,7 +454,7 @@ public async Task SendAsync(OutboundMessage message, CancellationToken ct = defa if (!resp.IsSuccessStatusCode) { var body = await resp.Content.ReadAsStringAsync(ct).ConfigureAwait(false); - LogSendError(body); + LogSendError(TruncateResponseBody(body)); } } catch (Exception ex) @@ -458,6 +463,10 @@ public async Task SendAsync(OutboundMessage message, CancellationToken ct = defa } } + /// Truncates response bodies to avoid logging sensitive data (tokens, session info). + private static string TruncateResponseBody(string body, int maxLength = 500) => + body.Length > maxLength ? string.Concat(body.AsSpan(0, maxLength), "...(truncated)") : body; + // ── LoggerMessage methods ──────────────────────────────────────────── [LoggerMessage(EventId = 1, Level = LogLevel.Information, Message = "Starting WeCom AI Bot webhook listener on port {Port}")] diff --git a/src/clawsharp/Channels/Web/WebChannel.Oidc.cs b/src/clawsharp/Channels/Web/WebChannel.Oidc.cs index 195e71ce..5887af13 100644 --- a/src/clawsharp/Channels/Web/WebChannel.Oidc.cs +++ b/src/clawsharp/Channels/Web/WebChannel.Oidc.cs @@ -165,7 +165,7 @@ private async Task HandleOidcCallbackAsync(HttpContext context, CancellationToke var message = resolveResult.Message ?? "Identity resolution failed."; LogOidcIdentityDenied(_logger, message); context.Response.StatusCode = StatusCodes.Status403Forbidden; - await context.Response.WriteAsync(message, ct); + await context.Response.WriteAsync("Access denied. Contact your administrator.", ct); DeleteStateCookie(context); return; } @@ -202,14 +202,14 @@ private async Task HandleLinkCallbackAsync(HttpContext context, CancellationToke return; } - // Validate link token (but don't consume it yet — that happens after OIDC callback) - // We peek at the token to verify it exists and is valid before redirecting to IdP. - // The actual consumption happens in CompleteLinkFlowAsync. - // NOTE: LinkTokenStore.Validate is destructive (TryRemove). We need to re-store - // the token temporarily or validate non-destructively. Since LinkTokenStore uses - // TryRemove for single-use, we validate the signature manually here and consume in callback. - // For now, we trust the token format and signature will be validated at callback time. - // The link token + sig are passed through the state cookie to the callback. + // Validate link token non-destructively before redirecting to IdP. + // Consumption happens in CompleteLinkFlowAsync after the OIDC round-trip. + if (!_linkTokenStore.Peek(linkToken, linkSig)) + { + context.Response.StatusCode = StatusCodes.Status400BadRequest; + await context.Response.WriteAsync("Invalid or expired link token.", ct).ConfigureAwait(false); + return; + } var (state, nonce) = OidcService.GenerateStateAndNonce(); var (codeVerifier, codeChallenge) = OidcService.GeneratePkce(); @@ -266,7 +266,7 @@ private async Task CompleteLinkFlowAsync( var message = resolveResult.Message ?? "Identity resolution failed."; LogOidcLinkDenied(_logger, message); context.Response.StatusCode = StatusCodes.Status403Forbidden; - await context.Response.WriteAsync(message, ct); + await context.Response.WriteAsync("Access denied. Contact your administrator.", ct); return; } diff --git a/src/clawsharp/Channels/Web/WebChannel.cs b/src/clawsharp/Channels/Web/WebChannel.cs index cbac2743..27f9a925 100644 --- a/src/clawsharp/Channels/Web/WebChannel.cs +++ b/src/clawsharp/Channels/Web/WebChannel.cs @@ -571,7 +571,15 @@ private async Task HandleHttpChatAsync(HttpContext context, CancellationToken ct // Per Pitfall #3: cookie-authenticated users derive ID from OIDC sub claim. var sessionId = DeriveSessionIdFromContext(context); var tcs = new TaskCompletionSource(); - _pending[sessionId] = tcs; + + // Reject concurrent requests for the same session — indexer overwrite would + // silently abandon the first TCS, causing an HTTP 500 timeout. + if (!_pending.TryAdd(sessionId, tcs)) + { + context.Response.StatusCode = StatusCodes.Status409Conflict; + await context.Response.WriteAsync("A request is already in progress for this session.", ct).ConfigureAwait(false); + return; + } try { @@ -683,6 +691,22 @@ private async Task HandleWebSocketAsync(WebSocket ws, IPAddress? remoteIp, Cance /// private async Task RunWebSocketMessageLoopAsync(WebSocket ws, string sessionId, IPAddress? remoteIp, CancellationToken ct) { + // Close the previous connection if a new one authenticates with the same session, + // preventing delivery hijack where replies go to the new connection while the old + // connection's loop still publishes inbound messages. + if (_wsClients.TryGetValue(sessionId, out var existing) && existing.State == WebSocketState.Open) + { + try + { + await existing.CloseAsync(WebSocketCloseStatus.NormalClosure, "Replaced by new connection", ct) + .ConfigureAwait(false); + } + catch + { + // Best-effort close — the old connection may already be broken. + } + } + _wsClients[sessionId] = ws; var buffer = ArrayPool.Shared.Rent(WebSocketReceiveBufferSize); try diff --git a/src/clawsharp/Channels/Web/index.html b/src/clawsharp/Channels/Web/index.html index e82e9379..c1df0571 100644 --- a/src/clawsharp/Channels/Web/index.html +++ b/src/clawsharp/Channels/Web/index.html @@ -10,8 +10,8 @@ href="https://fonts.googleapis.com/css2?family=JetBrains+Mono:ital,wght@0,400;0,500;0,600;1,400&display=swap" rel="stylesheet" /> - + diff --git a/src/clawsharp/Cli/AgentCommand.cs b/src/clawsharp/Cli/AgentCommand.cs index 8e44f8ab..f11736b3 100644 --- a/src/clawsharp/Cli/AgentCommand.cs +++ b/src/clawsharp/Cli/AgentCommand.cs @@ -28,11 +28,11 @@ public override async Task ExecuteAsync(CommandContext context, Settings se { if (settings.Message is not null) { - await SingleShotCommand.RunAsync(settings.Message, cancellationToken); + await SingleShotCommand.RunAsync(settings.Message, cancellationToken).ConfigureAwait(false); return 0; } - await GatewayHost.RunAsync(cancellationToken); + await GatewayHost.RunAsync(cancellationToken).ConfigureAwait(false); return 0; } } \ No newline at end of file diff --git a/src/clawsharp/Cli/Auth/AuthLoginCopilotCommand.cs b/src/clawsharp/Cli/Auth/AuthLoginCopilotCommand.cs index ad319449..67027a90 100644 --- a/src/clawsharp/Cli/Auth/AuthLoginCopilotCommand.cs +++ b/src/clawsharp/Cli/Auth/AuthLoginCopilotCommand.cs @@ -16,14 +16,14 @@ public override async Task ExecuteAsync(CommandContext context, Cancellatio AnsiConsole.MarkupLine("Your GitHub account must have an active Copilot subscription."); AnsiConsole.WriteLine(); - var token = await deviceFlow.LoginAsync(cancellationToken); + var token = await deviceFlow.LoginAsync(cancellationToken).ConfigureAwait(false); if (token is null) { AnsiConsole.MarkupLine("[red]Login failed.[/]"); return 1; } - await AuthStore.SaveAsync("copilot", token, cancellationToken); + await AuthStore.SaveAsync("copilot", token, cancellationToken).ConfigureAwait(false); AnsiConsole.MarkupLine("[green]Logged in to GitHub Copilot successfully.[/]"); if (token.ExpiresAt.HasValue) { diff --git a/src/clawsharp/Cli/Auth/AuthStatusCommand.cs b/src/clawsharp/Cli/Auth/AuthStatusCommand.cs index f3f95f65..81d82cf8 100644 --- a/src/clawsharp/Cli/Auth/AuthStatusCommand.cs +++ b/src/clawsharp/Cli/Auth/AuthStatusCommand.cs @@ -22,7 +22,7 @@ public override async Task ExecuteAsync(CommandContext context, Cancellatio foreach (var provider in KnownProviders) { - var token = await AuthStore.LoadAsync(provider, cancellationToken); + var token = await AuthStore.LoadAsync(provider, cancellationToken).ConfigureAwait(false); if (token is null) { table.AddRow(provider, "[grey]Not logged in[/]", "-"); diff --git a/src/clawsharp/Cli/Channel/ChannelPairWebCommand.cs b/src/clawsharp/Cli/Channel/ChannelPairWebCommand.cs index 39c4e943..a67d5c8e 100644 --- a/src/clawsharp/Cli/Channel/ChannelPairWebCommand.cs +++ b/src/clawsharp/Cli/Channel/ChannelPairWebCommand.cs @@ -31,7 +31,7 @@ public override async Task ExecuteAsync(CommandContext context, Cancellatio { using var connectCts = CancellationTokenSource.CreateLinkedTokenSource(cancellationToken); connectCts.CancelAfter(TimeSpan.FromSeconds(3)); - await pipe.ConnectAsync(connectCts.Token); + await pipe.ConnectAsync(connectCts.Token).ConfigureAwait(false); } catch (OperationCanceledException) { @@ -50,11 +50,11 @@ public override async Task ExecuteAsync(CommandContext context, Cancellatio await using var writer = new StreamWriter(pipe, Encoding.UTF8, leaveOpen: true) { AutoFlush = true }; var reqJson = JsonSerializer.Serialize(new IpcRequest(command, token), IpcJsonContext.Default.IpcRequest); - await writer.WriteLineAsync(reqJson.AsMemory(), cancellationToken); + await writer.WriteLineAsync(reqJson.AsMemory(), cancellationToken).ConfigureAwait(false); using var readCts = CancellationTokenSource.CreateLinkedTokenSource(cancellationToken); readCts.CancelAfter(TimeSpan.FromSeconds(5)); - var line = await reader.ReadLineAsync(readCts.Token); + var line = await reader.ReadLineAsync(readCts.Token).ConfigureAwait(false); if (line is null) { diff --git a/src/clawsharp/Cli/Config/ConfigSetCommand.cs b/src/clawsharp/Cli/Config/ConfigSetCommand.cs index 16b71646..b9088169 100644 --- a/src/clawsharp/Cli/Config/ConfigSetCommand.cs +++ b/src/clawsharp/Cli/Config/ConfigSetCommand.cs @@ -61,7 +61,7 @@ public override async Task ExecuteAsync(CommandContext context, Settings se JsonNode? root = null; if (File.Exists(configPath)) { - var json = await File.ReadAllTextAsync(configPath, cancellationToken); + var json = await File.ReadAllTextAsync(configPath, cancellationToken).ConfigureAwait(false); root = JsonNode.Parse(json); } @@ -102,6 +102,12 @@ public override async Task ExecuteAsync(CommandContext context, Settings se value = store.Encrypt(value); } + if (settings.Type is not null && settings.Type.ToLowerInvariant() is not ("string" or "int" or "bool")) + { + AnsiConsole.MarkupLine($"[red]Error:[/] Unsupported type '{Markup.Escape(settings.Type)}'. Supported: string, int, bool."); + return 1; + } + var typed = DetectTypedValue(value, settings.Type); if (typed is null) { @@ -113,7 +119,7 @@ public override async Task ExecuteAsync(CommandContext context, Settings se var output = root.ToJsonString(new JsonSerializerOptions { WriteIndented = true }); var tempPath = configPath + ".tmp"; - await File.WriteAllTextAsync(tempPath, output, cancellationToken); + await File.WriteAllTextAsync(tempPath, output, cancellationToken).ConfigureAwait(false); File.Move(tempPath, configPath, overwrite: true); AnsiConsole.MarkupLine($"[green]Set[/] [cyan]{Markup.Escape(key)}[/] in [grey]~/.clawsharp/config.json[/]"); diff --git a/src/clawsharp/Cli/Config/EncryptSecretsCommand.cs b/src/clawsharp/Cli/Config/EncryptSecretsCommand.cs index 30156df3..70424c98 100644 --- a/src/clawsharp/Cli/Config/EncryptSecretsCommand.cs +++ b/src/clawsharp/Cli/Config/EncryptSecretsCommand.cs @@ -15,7 +15,7 @@ public sealed class EncryptSecretsCommand : AsyncCommand // Fields that hold secrets in config.json private static readonly IReadOnlySet SecretFields = KnownSecretFields.All; - public override Task ExecuteAsync(CommandContext context, CancellationToken cancellationToken) + public override async Task ExecuteAsync(CommandContext context, CancellationToken cancellationToken) { var config = ClawsharpConfiguration.GetAppConfig(); var store = new SecretStore(Microsoft.Extensions.Options.Options.Create(config)); @@ -24,24 +24,25 @@ public override Task ExecuteAsync(CommandContext context, CancellationToken if (!File.Exists(configPath)) { AnsiConsole.MarkupLine("[yellow]No config file found at {0}.[/]", Markup.Escape(configPath)); - return Task.FromResult(1); + return 1; } - var json = File.ReadAllText(configPath); + var json = await File.ReadAllTextAsync(configPath, cancellationToken).ConfigureAwait(false); var root = JsonNode.Parse(json) as JsonObject; if (root is null) { AnsiConsole.MarkupLine("[red]Config file is not valid JSON.[/]"); - return Task.FromResult(1); + return 1; } var count = EncryptNode(root, store); var tempPath = configPath + ".tmp"; - File.WriteAllText(tempPath, root.ToJsonString(new JsonSerializerOptions { WriteIndented = true })); + await File.WriteAllTextAsync(tempPath, + root.ToJsonString(new JsonSerializerOptions { WriteIndented = true }), cancellationToken).ConfigureAwait(false); File.Move(tempPath, configPath, overwrite: true); AnsiConsole.MarkupLine($"[green]Encrypted {count} secret field(s) in {Markup.Escape(configPath)}.[/]"); - return Task.FromResult(0); + return 0; } private static int EncryptNode(JsonNode node, SecretStore store) diff --git a/src/clawsharp/Cli/Cron/CronAddCommand.cs b/src/clawsharp/Cli/Cron/CronAddCommand.cs index 92265382..cae5ddc4 100644 --- a/src/clawsharp/Cli/Cron/CronAddCommand.cs +++ b/src/clawsharp/Cli/Cron/CronAddCommand.cs @@ -75,8 +75,8 @@ public override async Task ExecuteAsync(CommandContext context, Settings se var config = ClawsharpConfiguration.GetAppConfig(); var store = CronStoreFactory.Create(config); - await store.InitAsync(cancellationToken); - await store.UpsertAsync(job, cancellationToken); + await store.InitAsync(cancellationToken).ConfigureAwait(false); + await store.UpsertAsync(job, cancellationToken).ConfigureAwait(false); AnsiConsole.MarkupLine($"[green]Created[/] cron job [cyan]{job.Id[..8]}[/] " + $"(kind=[bold]{Markup.Escape(kind.Value)}[/], expr=[bold]{Markup.Escape(expr)}[/], " + diff --git a/src/clawsharp/Cli/Cron/CronListCommand.cs b/src/clawsharp/Cli/Cron/CronListCommand.cs index 1542c740..eea285de 100644 --- a/src/clawsharp/Cli/Cron/CronListCommand.cs +++ b/src/clawsharp/Cli/Cron/CronListCommand.cs @@ -14,8 +14,8 @@ public override async Task ExecuteAsync(CommandContext context, Cancellatio { var config = ClawsharpConfiguration.GetAppConfig(); var store = CronStoreFactory.Create(config); - await store.InitAsync(cancellationToken); - var jobs = await store.LoadAllAsync(cancellationToken); + await store.InitAsync(cancellationToken).ConfigureAwait(false); + var jobs = await store.LoadAllAsync(cancellationToken).ConfigureAwait(false); if (jobs.Count == 0) { diff --git a/src/clawsharp/Cli/Cron/CronRemoveCommand.cs b/src/clawsharp/Cli/Cron/CronRemoveCommand.cs index 6f8d222b..7129c04b 100644 --- a/src/clawsharp/Cli/Cron/CronRemoveCommand.cs +++ b/src/clawsharp/Cli/Cron/CronRemoveCommand.cs @@ -23,8 +23,8 @@ public override async Task ExecuteAsync(CommandContext context, Settings se { var config = ClawsharpConfiguration.GetAppConfig(); var store = CronStoreFactory.Create(config); - await store.InitAsync(cancellationToken); - var jobs = await store.LoadAllAsync(cancellationToken); + await store.InitAsync(cancellationToken).ConfigureAwait(false); + var jobs = await store.LoadAllAsync(cancellationToken).ConfigureAwait(false); var match = jobs.FirstOrDefault(j => j.Id.StartsWith(settings.Id, StringComparison.OrdinalIgnoreCase)); if (match is null) @@ -33,7 +33,7 @@ public override async Task ExecuteAsync(CommandContext context, Settings se return 1; } - await store.DeleteAsync(match.Id, cancellationToken); + await store.DeleteAsync(match.Id, cancellationToken).ConfigureAwait(false); var shortId = match.Id[..Math.Min(8, match.Id.Length)]; AnsiConsole.MarkupLine($"[green]Removed[/] cron job [cyan]{shortId}[/] ([grey]{Markup.Escape(match.Name ?? "(unnamed)")}[/])"); return 0; diff --git a/src/clawsharp/Cli/Cron/CronRunCommand.cs b/src/clawsharp/Cli/Cron/CronRunCommand.cs index a355f858..759d8986 100644 --- a/src/clawsharp/Cli/Cron/CronRunCommand.cs +++ b/src/clawsharp/Cli/Cron/CronRunCommand.cs @@ -23,8 +23,8 @@ public override async Task ExecuteAsync(CommandContext context, Settings se { var config = ClawsharpConfiguration.GetAppConfig(); var store = CronStoreFactory.Create(config); - await store.InitAsync(cancellationToken); - var jobs = await store.LoadAllAsync(cancellationToken); + await store.InitAsync(cancellationToken).ConfigureAwait(false); + var jobs = await store.LoadAllAsync(cancellationToken).ConfigureAwait(false); var match = jobs.FirstOrDefault(j => j.Id.StartsWith(settings.Id, StringComparison.OrdinalIgnoreCase)); if (match is null) diff --git a/src/clawsharp/Cli/DoctorCommand.cs b/src/clawsharp/Cli/DoctorCommand.cs index 92983bc5..ac93a69d 100644 --- a/src/clawsharp/Cli/DoctorCommand.cs +++ b/src/clawsharp/Cli/DoctorCommand.cs @@ -48,9 +48,9 @@ public override async Task ExecuteAsync(CommandContext context, Settings se AnsiConsole.WriteLine(); AnsiConsole.MarkupLine("[bold]Deep checks:[/]"); - failures += await CheckProviderConnectivity(config, cancellationToken); - failures += await CheckDatabaseConnectivity(config, cancellationToken); - warnings += await CheckWorkspaceWritability(workspace, cancellationToken); + failures += await CheckProviderConnectivity(config, cancellationToken).ConfigureAwait(false); + failures += await CheckDatabaseConnectivity(config, cancellationToken).ConfigureAwait(false); + warnings += await CheckWorkspaceWritability(workspace, cancellationToken).ConfigureAwait(false); CheckSystemMd(workspace, ref warnings); CheckBraveSearch(config, ref warnings); Ok($".NET {Environment.Version}"); diff --git a/src/clawsharp/Cli/GatewayCommand.cs b/src/clawsharp/Cli/GatewayCommand.cs index 5dfc8024..cb2725f4 100644 --- a/src/clawsharp/Cli/GatewayCommand.cs +++ b/src/clawsharp/Cli/GatewayCommand.cs @@ -17,7 +17,7 @@ public sealed class GatewayCommand : AsyncCommand Justification = "Spectre.Console.Cli already requires dynamic code. EF Core types are statically rooted in this project.")] public override async Task ExecuteAsync(CommandContext context, CancellationToken cancellationToken) { - await GatewayHost.RunAsync(cancellationToken); + await GatewayHost.RunAsync(cancellationToken).ConfigureAwait(false); return 0; } } \ No newline at end of file diff --git a/src/clawsharp/Cli/GatewayHost.cs b/src/clawsharp/Cli/GatewayHost.cs index ae439bfb..947575b7 100644 --- a/src/clawsharp/Cli/GatewayHost.cs +++ b/src/clawsharp/Cli/GatewayHost.cs @@ -120,9 +120,14 @@ public static async Task RunAsync(CancellationToken ct = default) ApplyLandlockSandbox(appConfig); - appConfig.Channels.TryGetValue("discord", out var discordCfg); + appConfig.Channels.TryGetValue(ChannelName.Discord.Value, out var discordCfg); var discordEnabled = discordCfg is { Enabled: true, Token: not null }; + // Pre-load knowledge plugins before host construction so the async + // verification path is properly awaited instead of blocked via + // GetAwaiter().GetResult() inside the synchronous ConfigureServices callback. + var knowledgePlugins = await LoadKnowledgePluginsAsync(appConfig).ConfigureAwait(false); + var hostBuilder = Host.CreateDefaultBuilder(Array.Empty()) .ConfigureLogging(logging => ConfigureLogging(logging, appConfig.Telemetry)) .AddClawsharpTelemetry(appConfig.Telemetry) @@ -132,7 +137,7 @@ public static async Task RunAsync(CancellationToken ct = default) var webProxy = CreateProxy(appConfig); ConfigureHostOptions(services); - AddLlmHttpClient(services, appConfig, webProxy); + AddLlmHttpClient(services, appConfig, ssrfConnectCallback, webProxy); AddToolAndTranscriptionHttpClients(services, ssrfConnectCallback, webProxy); AddChannelHttpClients(services, appConfig, ssrfConnectCallback, webProxy); services.AddChannelResiliencePipelines(appConfig.Channels); @@ -140,9 +145,10 @@ public static async Task RunAsync(CancellationToken ct = default) RegisterEmbeddingProvider(services, appConfig); RegisterMemoryBackend(services, appConfig); RegisterKnowledgeStore(services, appConfig); - RegisterDocumentLoaders(services, appConfig, configuration); + RegisterDocumentLoaders(services, appConfig, configuration, knowledgePlugins); RegisterIngestionPipeline(services, appConfig); - RegisterReranker(services, appConfig); + var rerankerHandler = CreateHandlerFactory(ssrfConnectCallback, webProxy, useProxy: true); + RegisterReranker(services, appConfig, rerankerHandler); RegisterProviderFactory(services, appConfig); RegisterConditionalHostedServices(services, appConfig); RegisterSharedAuthServices(services, appConfig); @@ -164,7 +170,7 @@ public static async Task RunAsync(CancellationToken ct = default) ConfigureDiscord(hostBuilder, discordCfg!); } - await hostBuilder.RunConsoleAsync(ct); + await hostBuilder.RunConsoleAsync(ct).ConfigureAwait(false); } /// @@ -174,7 +180,7 @@ public static async Task RunAsync(CancellationToken ct = default) /// [RequiresUnreferencedCode("Creates EF Core DbContext instances which use reflection for model building.")] [RequiresDynamicCode("Creates EF Core DbContext instances which require dynamic code for query compilation.")] - internal static ServiceProvider BuildKnowledgeServiceProvider(AppConfig appConfig) + internal static async Task BuildKnowledgeServiceProviderAsync(AppConfig appConfig) { var configuration = ClawsharpConfiguration.Build(); var services = new ServiceCollection(); @@ -187,11 +193,13 @@ internal static ServiceProvider BuildKnowledgeServiceProvider(AppConfig appConfi // Options services.AddSingleton>(new OptionsWrapper(appConfig)); + var plugins = await LoadKnowledgePluginsAsync(appConfig).ConfigureAwait(false); + // Embedding, memory, knowledge store, document loaders, ingestion pipeline RegisterEmbeddingProvider(services, appConfig); RegisterMemoryBackend(services, appConfig); RegisterKnowledgeStore(services, appConfig); - RegisterDocumentLoaders(services, appConfig, configuration); + RegisterDocumentLoaders(services, appConfig, configuration, plugins); RegisterIngestionPipeline(services, appConfig); RegisterReranker(services, appConfig); @@ -224,6 +232,13 @@ private static void ApplyLandlockSandbox(AppConfig appConfig) using var landlockLoggerFactory = LoggerFactory.Create(b => b.AddConsole().SetMinimumLevel(LogLevel.Information)); var landlockLogger = landlockLoggerFactory.CreateLogger("Landlock"); + + if (!(appConfig.Security?.Landlock?.Enabled ?? false)) + { + landlockLogger.LogInformation( + "Landlock filesystem sandbox is not enabled — consider enabling security.landlock.enabled for defense-in-depth"); + } + var shellEnabled = appConfig.Tools.ShellEnabled; LandlockSandbox.Apply(appConfig.Security?.Landlock ?? new LandlockConfig(), landlockLogger, shellEnabled); } @@ -336,6 +351,7 @@ private static Func CreateHandlerFactory private static void AddLlmHttpClient( IServiceCollection services, AppConfig appConfig, + Func> ssrfConnectCallback, System.Net.WebProxy? webProxy) { var resilience = appConfig.Agents.Defaults.Resilience; @@ -346,6 +362,7 @@ private static void AddLlmHttpClient( .ConfigurePrimaryHttpMessageHandler(() => { var h = new SocketsHttpHandler(); + h.ConnectCallback = ssrfConnectCallback; if (webProxy is not null) { h.Proxy = webProxy; @@ -428,6 +445,9 @@ private static void AddChannelHttpClients( { var noProxyHandler = CreateHandlerFactory(ssrfConnectCallback, webProxy, useProxy: false); + // OIDC token exchange — 30s timeout, SSRF-protected. + AddSsrfSafeHttpClient(services, noProxyHandler, "oidc", timeoutSeconds: 30); + // Telegram — 35 s timeout (> 30 s long-poll). AddSsrfSafeHttpClient(services, noProxyHandler, "telegram", timeoutSeconds: 35, configure: client => client.BaseAddress = new Uri(ClawsharpConstants.TelegramBaseUrl)); @@ -494,7 +514,7 @@ private static void AddChannelHttpClients( // Lark/Feishu — domain determined by feishuDomain config. AddSsrfSafeHttpClient(services, noProxyHandler, "lark", configure: client => { - if (!appConfig.Channels.TryGetValue("lark", out var larkCfg)) + if (!appConfig.Channels.TryGetValue(ChannelName.Lark.Value, out var larkCfg)) { return; } @@ -752,7 +772,46 @@ internal static void RegisterKnowledgeStore(IServiceCollection services, AppConf /// Registers the five built-in document loaders and the DocumentLoaderRegistry /// for the knowledge ingestion pipeline per D-31. Only registers when knowledge is enabled. /// - internal static void RegisterDocumentLoaders(IServiceCollection services, AppConfig appConfig, IConfiguration configuration) + /// + /// Loads knowledge plugins asynchronously, including integrity verification when configured. + /// Returns empty list if knowledge is not enabled or no plugins directory exists. + /// Called before host construction so the async load is properly awaited instead of blocked. + /// + internal static async Task> LoadKnowledgePluginsAsync(AppConfig appConfig) + { + if (appConfig.Knowledge is not { Enabled: true }) + return Array.Empty(); + + var pluginsPath = appConfig.Knowledge.PluginsPath + ?? Path.Combine(AppContext.BaseDirectory, "plugins"); + + using var pluginLoggerFactory = LoggerFactory.Create( + b => b.AddConsole().SetMinimumLevel(LogLevel.Information)); + var pluginLogger = pluginLoggerFactory.CreateLogger("PluginLoader"); + + if (appConfig.Knowledge.RequireSignedPlugins) + { + // D-35: Integrity verification BEFORE assembly loading. + var auditLogger = new AuditLogger( + Options.Create(appConfig), + pluginLoggerFactory.CreateLogger()); + var verifier = new PluginIntegrityVerifier( + auditLogger, + appConfig.Knowledge, + pluginLoggerFactory.CreateLogger()); + return await PluginLoader.LoadPluginsAsync( + pluginsPath, verifier, requireSigned: true, + pluginLogger).ConfigureAwait(false); + } + + pluginLogger.LogWarning( + "Plugin signature verification is disabled — loading unsigned plugins from {PluginsPath}", + pluginsPath); + return PluginLoader.LoadPlugins(pluginsPath, pluginLogger); + } + + internal static void RegisterDocumentLoaders( + IServiceCollection services, AppConfig appConfig, IConfiguration configuration, IReadOnlyList plugins) { if (appConfig.Knowledge is not { Enabled: true }) { @@ -766,20 +825,12 @@ internal static void RegisterDocumentLoaders(IServiceCollection services, AppCon services.AddSingleton(); services.AddSingleton(); - // Plugin system: discover and load plugin DLLs from plugins/ directory (PLUG-01 through PLUG-04) - var pluginsPath = appConfig.Knowledge.PluginsPath - ?? Path.Combine(AppContext.BaseDirectory, "plugins"); - - var plugins = PluginLoader.LoadPluginsAsync( - pluginsPath, verifier: null, requireSigned: false, - NullLogger.Instance).GetAwaiter().GetResult(); - - // Each plugin registers its IDocumentLoader implementations + supporting services (D-08) - foreach (var plugin in plugins) - { - var section = configuration.GetSection($"knowledge:plugins:{plugin.Name}"); - plugin.ConfigureServices(services, section); - } + // Each plugin registers its IDocumentLoader implementations + supporting services (D-08). + // Fault-tolerant: failures are logged and skipped (D-04/D-05). + using var pluginLoggerFactory = LoggerFactory.Create( + b => b.AddConsole().SetMinimumLevel(LogLevel.Information)); + PluginLoader.RegisterPluginServices(plugins, services, configuration, + pluginLoggerFactory.CreateLogger("PluginLoader")); // D-31: Registry collects all IDocumentLoader from DI and indexes by extension services.AddSingleton(); @@ -810,7 +861,7 @@ internal static void RegisterIngestionPipeline(IServiceCollection services, AppC if (embeddingProvider is IBatchEmbeddingProvider nativeBatch) return nativeBatch; - var batchConfig = appConfig.Knowledge.Embedding ?? new Clawsharp.Knowledge.Config.EmbeddingBatchConfig(); + var batchConfig = appConfig.Knowledge.Embedding ?? new EmbeddingBatchConfig(); var logger = sp.GetRequiredService>(); return new BatchEmbeddingProvider(embeddingProvider, batchConfig, logger); }); @@ -823,11 +874,11 @@ internal static void RegisterIngestionPipeline(IServiceCollection services, AppC Func>? factory = appConfig.Memory.Backend switch { var b when b == MemoryBackend.Sqlite.Value => - async ct => await sp.GetRequiredService>().CreateDbContextAsync(ct), + async ct => await sp.GetRequiredService>().CreateDbContextAsync(ct).ConfigureAwait(false), var b when b == MemoryBackend.Postgres.Value => - async ct => await sp.GetRequiredService>().CreateDbContextAsync(ct), + async ct => await sp.GetRequiredService>().CreateDbContextAsync(ct).ConfigureAwait(false), var b when b == MemoryBackend.MsSql.Value => - async ct => await sp.GetRequiredService>().CreateDbContextAsync(ct), + async ct => await sp.GetRequiredService>().CreateDbContextAsync(ct).ConfigureAwait(false), _ => null, // Redis, Markdown — no EF-based CAS }; return new SyncStateTracker(factory, logger); @@ -851,7 +902,10 @@ internal static void RegisterIngestionPipeline(IServiceCollection services, AppC /// Only registers when knowledge.enabled is true; otherwise no IReranker in DI /// (ToolRegistry handles null IReranker gracefully). /// - internal static void RegisterReranker(IServiceCollection services, AppConfig appConfig) + internal static void RegisterReranker( + IServiceCollection services, + AppConfig appConfig, + Func? ssrfHandlerFactory = null) { if (appConfig.Knowledge is not { Enabled: true }) { @@ -869,11 +923,19 @@ internal static void RegisterReranker(IServiceCollection services, AppConfig app if (string.Equals(rerankerConfig.Provider, "cohere", StringComparison.OrdinalIgnoreCase)) { - // D-25: Named HTTP client with 10s timeout - services.AddHttpClient("cohere-reranker", client => + // M-11: Use SSRF-safe HTTP client when handler factory is available (host path). + // CLI ingestion path passes null — falls back to plain AddHttpClient. + if (ssrfHandlerFactory is not null) { - client.Timeout = TimeSpan.FromSeconds(10); - }); + AddSsrfSafeHttpClient(services, ssrfHandlerFactory, "cohere-reranker", timeoutSeconds: 10); + } + else + { + services.AddHttpClient("cohere-reranker", client => + { + client.Timeout = TimeSpan.FromSeconds(10); + }); + } services.AddSingleton(sp => { @@ -911,9 +973,12 @@ private static void RegisterProviderFactory(IServiceCollection services, AppConf catch (Exception ex) { LogProviderFallback(initLogger, ex); - opts.Providers["ollama"] = new ProviderConfig - { Type = "ollama", BaseUrl = ClawsharpConstants.OllamaDefaultBaseUrl }; - return ProviderFactory.Create("ollama", opts.Providers, httpFactory); + var fallbackProviders = new Dictionary(opts.Providers) + { + ["ollama"] = new ProviderConfig + { Type = "ollama", BaseUrl = ClawsharpConstants.OllamaDefaultBaseUrl } + }; + return ProviderFactory.Create("ollama", fallbackProviders, httpFactory); } }); @@ -972,8 +1037,7 @@ internal static void RegisterMcpServerMode(IServiceCollection services, AppConfi services.AddSingleton(sp => new McpServerAuthenticator( appConfig.McpServer, - sp.GetRequiredService(), - sp.GetRequiredService>())); + sp.GetRequiredService())); services.AddSingleton(); services.AddSingleton(); services.AddSingleton( @@ -1280,83 +1344,83 @@ bool IsChannelEnabled(string key) => AddChannel(services); - if (IsChannelEnabled("web")) + if (IsChannelEnabled(ChannelName.Web.Value)) { AddChannel(services); services.AddSingleton(sp => sp.GetRequiredService()); } - if (IsChannelEnabled("telegram")) + if (IsChannelEnabled(ChannelName.Telegram.Value)) { AddChannel(services); } - if (IsChannelEnabled("slack")) + if (IsChannelEnabled(ChannelName.Slack.Value)) { AddChannel(services); } - if (IsChannelEnabled("matrix")) + if (IsChannelEnabled(ChannelName.Matrix.Value)) { AddChannel(services); } - if (IsChannelEnabled("email")) + if (IsChannelEnabled(ChannelName.Email.Value)) { AddChannel(services); } - if (IsChannelEnabled("irc")) + if (IsChannelEnabled(ChannelName.Irc.Value)) { AddChannel(services); } - if (IsChannelEnabled("mattermost")) + if (IsChannelEnabled(ChannelName.Mattermost.Value)) { AddChannel(services); } - if (IsChannelEnabled("nostr")) + if (IsChannelEnabled(ChannelName.Nostr.Value)) { AddChannel(services); } - if (IsChannelEnabled("qq")) + if (IsChannelEnabled(ChannelName.Qq.Value)) { AddChannel(services); } - if (IsChannelEnabled("signal")) + if (IsChannelEnabled(ChannelName.Signal.Value)) { AddChannel(services); } - if (IsChannelEnabled("whatsapp")) + if (IsChannelEnabled(ChannelName.WhatsApp.Value)) { AddChannel(services); } - if (IsChannelEnabled("wechat")) + if (IsChannelEnabled(ChannelName.WeChat.Value)) { AddChannel(services); } - if (IsChannelEnabled("bluebubbles")) + if (IsChannelEnabled(ChannelName.BlueBubbles.Value)) { AddChannel(services); } - if (IsChannelEnabled("line")) + if (IsChannelEnabled(ChannelName.Line.Value)) { AddChannel(services); } - if (IsChannelEnabled("lark")) + if (IsChannelEnabled(ChannelName.Lark.Value)) { AddChannel(services); } - if (IsChannelEnabled("wecom")) + if (IsChannelEnabled(ChannelName.WeCom.Value)) { AddChannel(services); } diff --git a/src/clawsharp/Cli/Knowledge/KnowledgeIngestCommand.cs b/src/clawsharp/Cli/Knowledge/KnowledgeIngestCommand.cs index 82bcfb18..dfb208fb 100644 --- a/src/clawsharp/Cli/Knowledge/KnowledgeIngestCommand.cs +++ b/src/clawsharp/Cli/Knowledge/KnowledgeIngestCommand.cs @@ -47,14 +47,14 @@ public override async Task ExecuteAsync( AnsiConsole.MarkupLine($"[grey]Path:[/] {Markup.Escape(sourceConfig.Path ?? sourceConfig.Url ?? "(none)")}"); AnsiConsole.WriteLine(); - await using var sp = GatewayHost.BuildKnowledgeServiceProvider(config); + await using var sp = await GatewayHost.BuildKnowledgeServiceProviderAsync(config).ConfigureAwait(false); var pipeline = sp.GetRequiredService(); var store = sp.GetRequiredService(); // Find or create the source entity var normalizedUri = sourceConfig.Path ?? sourceConfig.Url ?? sourceConfig.Name; - var sources = await store.ListSourcesAsync(cancellationToken); + var sources = await store.ListSourcesAsync(cancellationToken).ConfigureAwait(false); var existingSource = sources.FirstOrDefault( s => string.Equals(s.SourceUri, normalizedUri, StringComparison.Ordinal)); @@ -74,7 +74,7 @@ public override async Task ExecuteAsync( try { - await pipeline.IngestSourceAsync(sourceConfig, sourceId, progress, cancellationToken, trigger: "cli"); + await pipeline.IngestSourceAsync(sourceConfig, sourceId, progress, cancellationToken, trigger: "cli").ConfigureAwait(false); return 0; } catch (Exception ex) when (ex is not OperationCanceledException) @@ -113,11 +113,17 @@ internal static KnowledgeSourceConfig ResolveSourceConfig(AppConfig config, stri }; } + var fullPath = Path.GetFullPath(source); + if (!File.Exists(fullPath) && !Directory.Exists(fullPath)) + { + throw new FileNotFoundException($"Path not found: {fullPath}"); + } + return new KnowledgeSourceConfig { Name = Path.GetFileName(source.TrimEnd(Path.DirectorySeparatorChar, Path.AltDirectorySeparatorChar)), Type = "local", - Path = Path.GetFullPath(source), + Path = fullPath, }; } diff --git a/src/clawsharp/Cli/Knowledge/KnowledgeStatusCommand.cs b/src/clawsharp/Cli/Knowledge/KnowledgeStatusCommand.cs index 1f530ac4..59ebabad 100644 --- a/src/clawsharp/Cli/Knowledge/KnowledgeStatusCommand.cs +++ b/src/clawsharp/Cli/Knowledge/KnowledgeStatusCommand.cs @@ -28,10 +28,10 @@ public override async Task ExecuteAsync(CommandContext context, Cancellatio return 1; } - await using var sp = GatewayHost.BuildKnowledgeServiceProvider(config); + await using var sp = await GatewayHost.BuildKnowledgeServiceProviderAsync(config).ConfigureAwait(false); var store = sp.GetRequiredService(); - var sources = await store.ListSourcesAsync(cancellationToken); + var sources = await store.ListSourcesAsync(cancellationToken).ConfigureAwait(false); if (sources.Count == 0) { diff --git a/src/clawsharp/Cli/Memory/MemoryClearCommand.cs b/src/clawsharp/Cli/Memory/MemoryClearCommand.cs index 3a7c1c79..9775f2cb 100644 --- a/src/clawsharp/Cli/Memory/MemoryClearCommand.cs +++ b/src/clawsharp/Cli/Memory/MemoryClearCommand.cs @@ -27,7 +27,7 @@ public override async Task ExecuteAsync(CommandContext context, Cancellatio } var memory = MemoryFactory.Create(config); - await memory.ClearAsync(cancellationToken); + await memory.ClearAsync(cancellationToken).ConfigureAwait(false); AnsiConsole.MarkupLine("[green]Memory cleared successfully.[/]"); return 0; diff --git a/src/clawsharp/Cli/Memory/MemoryExportCommand.cs b/src/clawsharp/Cli/Memory/MemoryExportCommand.cs index 8895399e..f1d5ae29 100644 --- a/src/clawsharp/Cli/Memory/MemoryExportCommand.cs +++ b/src/clawsharp/Cli/Memory/MemoryExportCommand.cs @@ -30,7 +30,7 @@ public override async Task ExecuteAsync(CommandContext context, Settings se { var config = ClawsharpConfiguration.GetAppConfig(); var memory = MemoryFactory.Create(config); - var facts = await memory.ListFactsAsync(cancellationToken); + var facts = await memory.ListFactsAsync(cancellationToken).ConfigureAwait(false); if (facts.Count == 0) { @@ -47,7 +47,7 @@ public override async Task ExecuteAsync(CommandContext context, Settings se var list = new List(facts); var json = JsonSerializer.Serialize(list, ConfigJsonContext.Default.ListFact); - await File.WriteAllTextAsync(outputPath, json, cancellationToken); + await File.WriteAllTextAsync(outputPath, json, cancellationToken).ConfigureAwait(false); AnsiConsole.MarkupLine($"[green]Exported {facts.Count} fact(s) to[/] {Markup.Escape(outputPath)}"); return 0; diff --git a/src/clawsharp/Cli/Memory/MemoryListCommand.cs b/src/clawsharp/Cli/Memory/MemoryListCommand.cs index 690fb8d8..573aa93d 100644 --- a/src/clawsharp/Cli/Memory/MemoryListCommand.cs +++ b/src/clawsharp/Cli/Memory/MemoryListCommand.cs @@ -18,7 +18,7 @@ public override async Task ExecuteAsync(CommandContext context, Cancellatio { var config = ClawsharpConfiguration.GetAppConfig(); var memory = MemoryFactory.Create(config); - var facts = await memory.ListFactsAsync(cancellationToken); + var facts = await memory.ListFactsAsync(cancellationToken).ConfigureAwait(false); if (facts.Count == 0) { diff --git a/src/clawsharp/Cli/Memory/MemorySearchCommand.cs b/src/clawsharp/Cli/Memory/MemorySearchCommand.cs index bce6cec4..7ccf7a64 100644 --- a/src/clawsharp/Cli/Memory/MemorySearchCommand.cs +++ b/src/clawsharp/Cli/Memory/MemorySearchCommand.cs @@ -33,7 +33,7 @@ public override async Task ExecuteAsync(CommandContext context, Settings se { var config = ClawsharpConfiguration.GetAppConfig(); var memory = MemoryFactory.Create(config); - var results = await memory.SearchAsync(settings.Query, settings.Limit, cancellationToken); + var results = await memory.SearchAsync(settings.Query, settings.Limit, cancellationToken).ConfigureAwait(false); if (results.Count == 0) { diff --git a/src/clawsharp/Cli/Migrate/MigrateCommand.cs b/src/clawsharp/Cli/Migrate/MigrateCommand.cs index 3f6be5ea..fa290c46 100644 --- a/src/clawsharp/Cli/Migrate/MigrateCommand.cs +++ b/src/clawsharp/Cli/Migrate/MigrateCommand.cs @@ -48,9 +48,9 @@ public override async Task ExecuteAsync(CommandContext ctx, Settings settin return settings.Source.ToLowerInvariant() switch { - "picoclaw" => await MigratePicoClawAsync(settings, sourceConfig, destConfig, cancellation), - "zeroclaw" => await MigrateZeroClawAsync(settings, sourceConfig, destConfig, cancellation), - _ => await MigrateOpenClawAsync(settings, sourceConfig, destConfig, cancellation), + "picoclaw" => await MigratePicoClawAsync(settings, sourceConfig, destConfig, cancellation).ConfigureAwait(false), + "zeroclaw" => await MigrateZeroClawAsync(settings, sourceConfig, destConfig, cancellation).ConfigureAwait(false), + _ => await MigrateOpenClawAsync(settings, sourceConfig, destConfig, cancellation).ConfigureAwait(false), }; } @@ -70,7 +70,7 @@ private static async Task MigrateOpenClawAsync( AnsiConsole.MarkupLine($"[bold]Writing to:[/] {destConfig}"); AnsiConsole.WriteLine(); - var sourceText = await File.ReadAllTextAsync(sourceConfig, ct); + var sourceText = await File.ReadAllTextAsync(sourceConfig, ct).ConfigureAwait(false); var source = JsonNode.Parse(sourceText); if (source is null) { @@ -81,7 +81,7 @@ private static async Task MigrateOpenClawAsync( JsonNode dest; if (File.Exists(destConfig)) { - var existing = await File.ReadAllTextAsync(destConfig, ct); + var existing = await File.ReadAllTextAsync(destConfig, ct).ConfigureAwait(false); dest = JsonNode.Parse(existing) ?? new JsonObject(); } else @@ -202,7 +202,7 @@ private static async Task MigrateOpenClawAsync( warnings.Add("'hooks': Plugin hooks not available in clawsharp"); } - return await WriteDestConfig(settings, dest, destConfig, migrated, warnings, ct); + return await WriteDestConfig(settings, dest, destConfig, migrated, warnings, ct).ConfigureAwait(false); } // ── picoclaw ────────────────────────────────────────────────────────────── @@ -221,7 +221,7 @@ private static async Task MigratePicoClawAsync( AnsiConsole.MarkupLine($"[bold]Writing to:[/] {destConfig}"); AnsiConsole.WriteLine(); - var sourceText = await File.ReadAllTextAsync(sourceConfig, ct); + var sourceText = await File.ReadAllTextAsync(sourceConfig, ct).ConfigureAwait(false); var source = JsonNode.Parse(sourceText); if (source is null) { @@ -232,7 +232,7 @@ private static async Task MigratePicoClawAsync( JsonNode dest; if (File.Exists(destConfig)) { - var existing = await File.ReadAllTextAsync(destConfig, ct); + var existing = await File.ReadAllTextAsync(destConfig, ct).ConfigureAwait(false); dest = JsonNode.Parse(existing) ?? new JsonObject(); } else @@ -366,7 +366,7 @@ private static async Task MigratePicoClawAsync( warnings.Add("'session': picoclaw session config not applicable; clawsharp manages sessions automatically"); } - return await WriteDestConfig(settings, dest, destConfig, migrated, warnings, ct); + return await WriteDestConfig(settings, dest, destConfig, migrated, warnings, ct).ConfigureAwait(false); } // ── zeroclaw ────────────────────────────────────────────────────────────── @@ -385,13 +385,13 @@ private static async Task MigrateZeroClawAsync( AnsiConsole.MarkupLine($"[bold]Writing to:[/] {destConfig}"); AnsiConsole.WriteLine(); - var tomlText = await File.ReadAllTextAsync(sourceConfig, ct); + var tomlText = await File.ReadAllTextAsync(sourceConfig, ct).ConfigureAwait(false); var toml = ParseToml(tomlText); JsonNode dest; if (File.Exists(destConfig)) { - var existing = await File.ReadAllTextAsync(destConfig, ct); + var existing = await File.ReadAllTextAsync(destConfig, ct).ConfigureAwait(false); dest = JsonNode.Parse(existing) ?? new JsonObject(); } else @@ -491,7 +491,7 @@ private static async Task MigrateZeroClawAsync( migrated.Add("tools.brave.apiKey"); } - return await WriteDestConfig(settings, dest, destConfig, migrated, warnings, ct); + return await WriteDestConfig(settings, dest, destConfig, migrated, warnings, ct).ConfigureAwait(false); } // ── TOML parser ─────────────────────────────────────────────────────────── @@ -633,7 +633,7 @@ private static async Task WriteDestConfig( Directory.CreateDirectory(Path.GetDirectoryName(destConfig)!); var destText = dest.ToJsonString(new JsonSerializerOptions { WriteIndented = true }); - await File.WriteAllTextAsync(destConfig, destText, ct); + await File.WriteAllTextAsync(destConfig, destText, ct).ConfigureAwait(false); AnsiConsole.WriteLine(); AnsiConsole.MarkupLine($"[bold green]Config written to {destConfig}[/]"); AnsiConsole.MarkupLine("Run [bold]clawsharp config validate[/] to check for issues."); diff --git a/src/clawsharp/Cli/Models/ModelsJsonContext.cs b/src/clawsharp/Cli/Models/ModelsJsonContext.cs index 0d20b861..a42466bd 100644 --- a/src/clawsharp/Cli/Models/ModelsJsonContext.cs +++ b/src/clawsharp/Cli/Models/ModelsJsonContext.cs @@ -5,4 +5,5 @@ namespace Clawsharp.Cli.Models; /// Source-generated JSON context for model list API responses. [JsonSerializable(typeof(OpenAiModelsResponse))] [JsonSerializable(typeof(GeminiModelsResponse))] +[JsonSourceGenerationOptions(DefaultIgnoreCondition = JsonIgnoreCondition.WhenWritingNull)] internal sealed partial class ModelsJsonContext : JsonSerializerContext; \ No newline at end of file diff --git a/src/clawsharp/Cli/Models/ModelsListCommand.cs b/src/clawsharp/Cli/Models/ModelsListCommand.cs index d31ce6ff..97d0c0d7 100644 --- a/src/clawsharp/Cli/Models/ModelsListCommand.cs +++ b/src/clawsharp/Cli/Models/ModelsListCommand.cs @@ -61,12 +61,12 @@ public override async Task ExecuteAsync(CommandContext context, Cancellatio if (providerType == LlmProviderType.Gemini) { - await FetchGeminiModelsAsync(http, name, providerCfg, cancellationToken); + await FetchGeminiModelsAsync(http, name, providerCfg, cancellationToken).ConfigureAwait(false); continue; } // All remaining types are OpenAI-compatible - await FetchOpenAiModelsAsync(http, name, providerCfg, providerType, cancellationToken); + await FetchOpenAiModelsAsync(http, name, providerCfg, providerType, cancellationToken).ConfigureAwait(false); } return 0; @@ -103,12 +103,12 @@ private static async Task FetchGeminiModelsAsync( try { - using var response = await http.GetAsync(url, ct); + using var response = await http.GetAsync(url, ct).ConfigureAwait(false); response.EnsureSuccessStatusCode(); - await using var stream = await response.Content.ReadAsStreamAsync(ct); + await using var stream = await response.Content.ReadAsStreamAsync(ct).ConfigureAwait(false); var result = await JsonSerializer.DeserializeAsync( - stream, ModelsJsonContext.Default.GeminiModelsResponse, ct); + stream, ModelsJsonContext.Default.GeminiModelsResponse, ct).ConfigureAwait(false); var models = result?.Models; if (models is null || models.Count == 0) @@ -173,12 +173,12 @@ private static async Task FetchOpenAiModelsAsync( request.Headers.Authorization = new AuthenticationHeaderValue("Bearer", providerCfg.ApiKey); } - using var response = await http.SendAsync(request, ct); + using var response = await http.SendAsync(request, ct).ConfigureAwait(false); response.EnsureSuccessStatusCode(); - await using var stream = await response.Content.ReadAsStreamAsync(ct); + await using var stream = await response.Content.ReadAsStreamAsync(ct).ConfigureAwait(false); var result = await JsonSerializer.DeserializeAsync( - stream, ModelsJsonContext.Default.OpenAiModelsResponse, ct); + stream, ModelsJsonContext.Default.OpenAiModelsResponse, ct).ConfigureAwait(false); var models = result?.Data; if (models is null || models.Count == 0) diff --git a/src/clawsharp/Cli/OnboardCommand.cs b/src/clawsharp/Cli/OnboardCommand.cs index 7d8af660..a7f995a7 100644 --- a/src/clawsharp/Cli/OnboardCommand.cs +++ b/src/clawsharp/Cli/OnboardCommand.cs @@ -83,10 +83,10 @@ public override async Task ExecuteAsync(CommandContext context, Settings se } AnsiConsole.WriteLine(); - await SkillRegistry.InstallSkillsAsync(skillsToInstall, cancellationToken); + await SkillRegistry.InstallSkillsAsync(skillsToInstall, cancellationToken).ConfigureAwait(false); await WriteConfigAndPrintSummary( - providerType, model, apiKey, selectedChannels, channelCreds, skillsToInstall, cancellationToken); + providerType, model, apiKey, selectedChannels, channelCreds, skillsToInstall, cancellationToken).ConfigureAwait(false); PrintOpenAccessWarnings(selectedChannels, channelCreds); PrintChannelSecurityAdvisories(selectedChannels); diff --git a/src/clawsharp/Cli/Pairing/PairingApproveCommand.cs b/src/clawsharp/Cli/Pairing/PairingApproveCommand.cs index eeeb7b65..b633a278 100644 --- a/src/clawsharp/Cli/Pairing/PairingApproveCommand.cs +++ b/src/clawsharp/Cli/Pairing/PairingApproveCommand.cs @@ -27,7 +27,7 @@ public override async Task ExecuteAsync(CommandContext context, Settings se var store = new PairingStore( NullLogger.Instance); - var approved = await store.ApproveAsync(settings.Code, cancellationToken); + var approved = await store.ApproveAsync(settings.Code, cancellationToken).ConfigureAwait(false); if (approved is null) { AnsiConsole.MarkupLine("[red]Pairing code not found or expired.[/]"); @@ -37,7 +37,7 @@ public override async Task ExecuteAsync(CommandContext context, Settings se // Add the sender to the dynamic approved-senders store (takes effect immediately on running gateways). var approvedSenders = new ApprovedSendersStore( NullLogger.Instance); - await approvedSenders.AddAsync(approved.Channel, approved.SenderId, cancellationToken); + await approvedSenders.AddAsync(approved.Channel, approved.SenderId, cancellationToken).ConfigureAwait(false); // Add the sender ID to channels.{channel}.allowFrom in ~/.clawsharp/config.json var configPath = Path.Combine(ConfigLoader.ExpandHome("~/.clawsharp"), "config.json"); @@ -46,7 +46,7 @@ public override async Task ExecuteAsync(CommandContext context, Settings se JsonNode? root = null; if (File.Exists(configPath)) { - var json = await File.ReadAllTextAsync(configPath, cancellationToken); + var json = await File.ReadAllTextAsync(configPath, cancellationToken).ConfigureAwait(false); root = JsonNode.Parse(json); } @@ -94,7 +94,7 @@ public override async Task ExecuteAsync(CommandContext context, Settings se var output = root.ToJsonString(new JsonSerializerOptions { WriteIndented = true }); var tempPath = configPath + ".tmp"; - await File.WriteAllTextAsync(tempPath, output, cancellationToken); + await File.WriteAllTextAsync(tempPath, output, cancellationToken).ConfigureAwait(false); File.Move(tempPath, configPath, overwrite: true); AnsiConsole.MarkupLine( diff --git a/src/clawsharp/Cli/Pairing/PairingListCommand.cs b/src/clawsharp/Cli/Pairing/PairingListCommand.cs index 0b80dd16..5a12efc1 100644 --- a/src/clawsharp/Cli/Pairing/PairingListCommand.cs +++ b/src/clawsharp/Cli/Pairing/PairingListCommand.cs @@ -15,7 +15,7 @@ public override async Task ExecuteAsync(CommandContext context, Cancellatio var store = new PairingStore( Microsoft.Extensions.Logging.Abstractions.NullLogger.Instance); - var pending = await store.GetPendingAsync(cancellationToken); + var pending = await store.GetPendingAsync(cancellationToken).ConfigureAwait(false); if (pending.Count == 0) { diff --git a/src/clawsharp/Cli/Policy/PolicyExplainCommand.cs b/src/clawsharp/Cli/Policy/PolicyExplainCommand.cs index c0fbb991..64bde6e6 100644 --- a/src/clawsharp/Cli/Policy/PolicyExplainCommand.cs +++ b/src/clawsharp/Cli/Policy/PolicyExplainCommand.cs @@ -55,7 +55,7 @@ public override Task ExecuteAsync(CommandContext context, Settings settings if (abacRules is { Count: > 0 }) { // Use current time as frozen timestamp for CLI explain - var ctx = new AbacContext(orgUser, Clawsharp.Core.Utilities.ChannelName.Cli, DateTimeOffset.UtcNow); + var ctx = new AbacContext(orgUser, Core.Utilities.ChannelName.Cli, DateTimeOffset.UtcNow); abacDecision = evaluator.ApplyAbacRules(rbacDecision, abacRules, ctx); } diff --git a/src/clawsharp/Cli/Policy/PolicySimulateCommand.cs b/src/clawsharp/Cli/Policy/PolicySimulateCommand.cs index 8a737713..7e3c2cca 100644 --- a/src/clawsharp/Cli/Policy/PolicySimulateCommand.cs +++ b/src/clawsharp/Cli/Policy/PolicySimulateCommand.cs @@ -64,7 +64,7 @@ public override Task ExecuteAsync(CommandContext context, Settings settings var abacRules = config.Organization.Policies?.Rules; if (abacRules is { Count: > 0 }) { - var ctx = new AbacContext(orgUser, Clawsharp.Core.Utilities.ChannelName.Cli, DateTimeOffset.UtcNow); + var ctx = new AbacContext(orgUser, Core.Utilities.ChannelName.Cli, DateTimeOffset.UtcNow); decision = evaluator.ApplyAbacRules(rbacDecision, abacRules, ctx); } else diff --git a/src/clawsharp/Cli/Service/ServiceCommand.cs b/src/clawsharp/Cli/Service/ServiceCommand.cs index 2fade910..aac328c7 100644 --- a/src/clawsharp/Cli/Service/ServiceCommand.cs +++ b/src/clawsharp/Cli/Service/ServiceCommand.cs @@ -21,7 +21,7 @@ public sealed class Settings : CommandSettings } public override async Task ExecuteAsync(CommandContext context, Settings settings, CancellationToken cancellationToken) - => await ServiceCommand.InstallAsync(settings.System, cancellationToken); + => await ServiceCommand.InstallAsync(settings.System, cancellationToken).ConfigureAwait(false); } /// Spectre command: clawsharp service uninstall [--system] @@ -38,7 +38,7 @@ public sealed class Settings : CommandSettings } public override async Task ExecuteAsync(CommandContext context, Settings settings, CancellationToken cancellationToken) - => await ServiceCommand.UninstallAsync(settings.System, cancellationToken); + => await ServiceCommand.UninstallAsync(settings.System, cancellationToken).ConfigureAwait(false); } /// Spectre command: clawsharp service status @@ -46,7 +46,7 @@ public override async Task ExecuteAsync(CommandContext context, Settings se public sealed class ServiceStatusCommand : AsyncCommand { public override async Task ExecuteAsync(CommandContext context, CancellationToken cancellationToken) - => await ServiceCommand.StatusAsync(cancellationToken); + => await ServiceCommand.StatusAsync(cancellationToken).ConfigureAwait(false); } /// @@ -75,17 +75,17 @@ public static async Task InstallAsync(bool system, CancellationToken ct = d if (RuntimeInformation.IsOSPlatform(OSPlatform.Linux)) { - return await InstallSystemdAsync(binaryPath, system, ct); + return await InstallSystemdAsync(binaryPath, system, ct).ConfigureAwait(false); } if (RuntimeInformation.IsOSPlatform(OSPlatform.OSX)) { - return await InstallLaunchdAsync(binaryPath, ct); + return await InstallLaunchdAsync(binaryPath, ct).ConfigureAwait(false); } if (RuntimeInformation.IsOSPlatform(OSPlatform.Windows)) { - return await InstallWindowsServiceAsync(binaryPath, ct); + return await InstallWindowsServiceAsync(binaryPath, ct).ConfigureAwait(false); } AnsiConsole.MarkupLine("[red][[service]][/] Unsupported platform."); @@ -98,17 +98,17 @@ public static async Task UninstallAsync(bool system, CancellationToken ct = { if (RuntimeInformation.IsOSPlatform(OSPlatform.Linux)) { - return await UninstallSystemdAsync(system, ct); + return await UninstallSystemdAsync(system, ct).ConfigureAwait(false); } if (RuntimeInformation.IsOSPlatform(OSPlatform.OSX)) { - return await UninstallLaunchdAsync(ct); + return await UninstallLaunchdAsync(ct).ConfigureAwait(false); } if (RuntimeInformation.IsOSPlatform(OSPlatform.Windows)) { - return await UninstallWindowsServiceAsync(ct); + return await UninstallWindowsServiceAsync(ct).ConfigureAwait(false); } AnsiConsole.MarkupLine("[red][[service]][/] Unsupported platform."); @@ -121,7 +121,7 @@ public static async Task StatusAsync(CancellationToken ct = default) { if (RuntimeInformation.IsOSPlatform(OSPlatform.Linux)) { - return await RunAsync("systemctl", $"--user status {ServiceName}", ct); + return await RunAsync("systemctl", $"--user status {ServiceName}", ct).ConfigureAwait(false); } if (RuntimeInformation.IsOSPlatform(OSPlatform.OSX)) @@ -140,7 +140,7 @@ public static async Task StatusAsync(CancellationToken ct = default) if (RuntimeInformation.IsOSPlatform(OSPlatform.Windows)) { - return await RunAsync("sc.exe", $"query {ServiceName}", ct); + return await RunAsync("sc.exe", $"query {ServiceName}", ct).ConfigureAwait(false); } AnsiConsole.MarkupLine("[red][[service]][/] Unsupported platform."); @@ -180,7 +180,7 @@ private static async Task InstallSystemdAsync(string binaryPath, bool syste var configPath = GetConfigPath(); var unit = SystemdUnit(binaryPath, configPath, system); - await File.WriteAllTextAsync(unitPath, unit, ct); + await File.WriteAllTextAsync(unitPath, unit, ct).ConfigureAwait(false); AnsiConsole.MarkupLine($"[[service]] Unit file written: {Markup.Escape(unitPath)}"); if (await RunAsync("systemctl", $"{systemctlArgs} daemon-reload", ct) != 0) @@ -213,8 +213,8 @@ private static async Task UninstallSystemdAsync(bool system, CancellationTo var systemctlArgs = system ? "--system" : "--user"; - await RunAsync("systemctl", $"{systemctlArgs} stop {ServiceName}", ct); - await RunAsync("systemctl", $"{systemctlArgs} disable {ServiceName}", ct); + await RunAsync("systemctl", $"{systemctlArgs} stop {ServiceName}", ct).ConfigureAwait(false); + await RunAsync("systemctl", $"{systemctlArgs} disable {ServiceName}", ct).ConfigureAwait(false); string unitDir; if (system) @@ -234,7 +234,7 @@ private static async Task UninstallSystemdAsync(bool system, CancellationTo AnsiConsole.MarkupLine($"[[service]] Deleted: {Markup.Escape(unitPath)}"); } - await RunAsync("systemctl", $"{systemctlArgs} daemon-reload", ct); + await RunAsync("systemctl", $"{systemctlArgs} daemon-reload", ct).ConfigureAwait(false); AnsiConsole.MarkupLine($"[green][[service]][/] {ServiceName} uninstalled."); return 0; } @@ -286,7 +286,7 @@ private static async Task InstallLaunchdAsync(string binaryPath, Cancellati var configPath = GetConfigPath(); var plist = LaunchdPlist(binaryPath, configPath); - await File.WriteAllTextAsync(plistPath, plist, ct); + await File.WriteAllTextAsync(plistPath, plist, ct).ConfigureAwait(false); AnsiConsole.MarkupLine($"[[service]] Plist written: {Markup.Escape(plistPath)}"); var label = LaunchdLabel(); @@ -309,7 +309,7 @@ private static async Task UninstallLaunchdAsync(CancellationToken ct) return 0; } - await RunAsync("launchctl", $"unload -w {plistPath}", ct); + await RunAsync("launchctl", $"unload -w {plistPath}", ct).ConfigureAwait(false); File.Delete(plistPath); AnsiConsole.MarkupLine($"[[service]] Deleted: {Markup.Escape(plistPath)}"); return 0; @@ -375,15 +375,15 @@ private static string LaunchdPlist(string binaryPath, string? configPath) private static async Task InstallWindowsServiceAsync(string binaryPath, CancellationToken ct) { var code = await RunAsync("sc.exe", - $"create {ServiceName} binPath= \"{binaryPath} gateway\" start= auto DisplayName= \"{ServiceDesc}\"", ct); + $"create {ServiceName} binPath= \"{binaryPath} gateway\" start= auto DisplayName= \"{ServiceDesc}\"", ct).ConfigureAwait(false); if (code != 0) { return code; } - await RunAsync("sc.exe", $"description {ServiceName} \"{ServiceDesc}\"", ct); + await RunAsync("sc.exe", $"description {ServiceName} \"{ServiceDesc}\"", ct).ConfigureAwait(false); - code = await RunAsync("sc.exe", $"start {ServiceName}", ct); + code = await RunAsync("sc.exe", $"start {ServiceName}", ct).ConfigureAwait(false); if (code != 0) { return code; @@ -396,8 +396,8 @@ private static async Task InstallWindowsServiceAsync(string binaryPath, Can private static async Task UninstallWindowsServiceAsync(CancellationToken ct) { - await RunAsync("sc.exe", $"stop {ServiceName}", ct); - var code = await RunAsync("sc.exe", $"delete {ServiceName}", ct); + await RunAsync("sc.exe", $"stop {ServiceName}", ct).ConfigureAwait(false); + var code = await RunAsync("sc.exe", $"delete {ServiceName}", ct).ConfigureAwait(false); if (code == 0) { AnsiConsole.MarkupLine($"[green][[service]][/] {ServiceName} uninstalled."); @@ -472,7 +472,7 @@ private static async Task RunAsync(string exe, string args, CancellationTok return 1; } - await proc.WaitForExitAsync(ct); + await proc.WaitForExitAsync(ct).ConfigureAwait(false); return proc.ExitCode; } catch (Exception ex) diff --git a/src/clawsharp/Cli/Session/SessionCommand.cs b/src/clawsharp/Cli/Session/SessionCommand.cs index f2f01682..60fefb12 100644 --- a/src/clawsharp/Cli/Session/SessionCommand.cs +++ b/src/clawsharp/Cli/Session/SessionCommand.cs @@ -36,7 +36,7 @@ public override async Task ExecuteAsync(CommandContext context, Cancellatio try { await using var stream = File.OpenRead(fi.FullName); - var session = await JsonSerializer.DeserializeAsync(stream, SessionJsonContext.Default.Session, cancellationToken); + var session = await JsonSerializer.DeserializeAsync(stream, SessionJsonContext.Default.Session, cancellationToken).ConfigureAwait(false); if (session is null) { return (Name: Path.GetFileNameWithoutExtension(fi.Name), Messages: 0, In: 0L, Out: 0L, Ok: false); diff --git a/src/clawsharp/Cli/SingleShotCommand.cs b/src/clawsharp/Cli/SingleShotCommand.cs index de19f756..6f4e3909 100644 --- a/src/clawsharp/Cli/SingleShotCommand.cs +++ b/src/clawsharp/Cli/SingleShotCommand.cs @@ -42,7 +42,7 @@ public static async Task RunAsync(string message, CancellationToken ct = default { if (provider is IStreamingProvider streamingProvider) { - await foreach (var chunk in streamingProvider.StreamAsync(request, ct)) + await foreach (var chunk in streamingProvider.StreamAsync(request, ct).ConfigureAwait(false)) { if (chunk is TextDeltaChunk td) { @@ -54,7 +54,7 @@ public static async Task RunAsync(string message, CancellationToken ct = default } else { - var response = await provider.ChatAsync(request, ct); + var response = await provider.ChatAsync(request, ct).ConfigureAwait(false); AnsiConsole.MarkupLine(Markup.Escape(response.Content ?? "(no response)")); } } diff --git a/src/clawsharp/Cli/Skills/SkillRegistry.cs b/src/clawsharp/Cli/Skills/SkillRegistry.cs index e7c05dd6..3cd3b2ae 100644 --- a/src/clawsharp/Cli/Skills/SkillRegistry.cs +++ b/src/clawsharp/Cli/Skills/SkillRegistry.cs @@ -127,13 +127,13 @@ public static async Task InstallSkillAsync(string skill, CancellationToken ct) switch (entry.Source) { case SkillSource.BuiltIn: - await WriteBuiltInSkillAsync(skill, destDir, ct); + await WriteBuiltInSkillAsync(skill, destDir, ct).ConfigureAwait(false); break; case SkillSource.GitClone: - await GitCloneSkillAsync(skill, entry.CloneUrl!, destDir); + await GitCloneSkillAsync(skill, entry.CloneUrl!, destDir).ConfigureAwait(false); break; case SkillSource.GitHubApi: - await GitHubApiDownloadAsync(skill, entry.GitHubRepo!, entry.GitHubPath!, destDir, ct); + await GitHubApiDownloadAsync(skill, entry.GitHubRepo!, entry.GitHubPath!, destDir, ct).ConfigureAwait(false); break; } } @@ -142,7 +142,7 @@ public static async Task InstallSkillsAsync(IReadOnlyList skills, Cancel { foreach (var skill in skills) { - await InstallSkillAsync(skill, ct); + await InstallSkillAsync(skill, ct).ConfigureAwait(false); } } @@ -161,7 +161,7 @@ private static async Task WriteBuiltInSkillAsync(string skill, string destDir, C return; } - await File.WriteAllTextAsync(Path.Combine(destDir, "SKILL.md"), content, ct); + await File.WriteAllTextAsync(Path.Combine(destDir, "SKILL.md"), content, ct).ConfigureAwait(false); AnsiConsole.MarkupLine($" Installed {Markup.Escape(skill)} (built-in)"); } @@ -187,7 +187,7 @@ private static async Task GitCloneSkillAsync(string skill, string repoUrl, strin throw new InvalidOperationException("Failed to start git"); } - await proc.WaitForExitAsync(); + await proc.WaitForExitAsync().ConfigureAwait(false); if (proc.ExitCode == 0) { AnsiConsole.MarkupLine("[green]done[/]"); @@ -211,7 +211,7 @@ private static async Task GitHubApiDownloadAsync(string skill, string repo, stri try { Directory.CreateDirectory(destDir); - await DownloadGitHubDirAsync(SharedHttpClient, repo, repoPath, destDir, ct); + await DownloadGitHubDirAsync(SharedHttpClient, repo, repoPath, destDir, ct).ConfigureAwait(false); AnsiConsole.MarkupLine("[green]done[/]"); } catch (Exception ex) @@ -225,7 +225,7 @@ private static async Task GitHubApiDownloadAsync(string skill, string repo, stri private static async Task DownloadGitHubDirAsync(HttpClient http, string repo, string path, string localDir, CancellationToken ct) { var url = $"https://api.github.com/repos/{repo}/contents/{path}"; - var response = await http.GetStringAsync(url, ct); + var response = await http.GetStringAsync(url, ct).ConfigureAwait(false); using var doc = JsonDocument.Parse(response); foreach (var entry in doc.RootElement.EnumerateArray()) @@ -243,7 +243,7 @@ private static async Task DownloadGitHubDirAsync(HttpClient http, string repo, s } var dlUrl = dlProp.GetString()!; - var bytes = await http.GetByteArrayAsync(dlUrl, ct); + var bytes = await http.GetByteArrayAsync(dlUrl, ct).ConfigureAwait(false); // Redact hardcoded demo API key that ships in supermemory SKILL.md if (name.Equals("SKILL.md", StringComparison.OrdinalIgnoreCase)) { @@ -252,12 +252,12 @@ private static async Task DownloadGitHubDirAsync(HttpClient http, string repo, s bytes = Encoding.UTF8.GetBytes(text); } - await File.WriteAllBytesAsync(local, bytes, ct); + await File.WriteAllBytesAsync(local, bytes, ct).ConfigureAwait(false); } else if (type == "dir") { Directory.CreateDirectory(local); - await DownloadGitHubDirAsync(http, repo, $"{path}/{name}", local, ct); + await DownloadGitHubDirAsync(http, repo, $"{path}/{name}", local, ct).ConfigureAwait(false); } } } diff --git a/src/clawsharp/Cli/Skills/SkillsInstallCommand.cs b/src/clawsharp/Cli/Skills/SkillsInstallCommand.cs index 3026fbdd..e907ade7 100644 --- a/src/clawsharp/Cli/Skills/SkillsInstallCommand.cs +++ b/src/clawsharp/Cli/Skills/SkillsInstallCommand.cs @@ -23,7 +23,7 @@ public override async Task ExecuteAsync(CommandContext context, Settings se return 1; } - await SkillRegistry.InstallSkillAsync(settings.Name, cancellationToken); + await SkillRegistry.InstallSkillAsync(settings.Name, cancellationToken).ConfigureAwait(false); return 0; } } \ No newline at end of file diff --git a/src/clawsharp/Cli/StatusCommand.cs b/src/clawsharp/Cli/StatusCommand.cs index 5f4790c9..42ba3cd3 100644 --- a/src/clawsharp/Cli/StatusCommand.cs +++ b/src/clawsharp/Cli/StatusCommand.cs @@ -82,7 +82,7 @@ public override async Task ExecuteAsync(CommandContext context, Cancellatio AnsiConsole.WriteLine(); // Token totals (sum across all sessions) - var (totalIn, totalOut, sessionCount) = await ScanSessionTokensAsync(cancellationToken); + var (totalIn, totalOut, sessionCount) = await ScanSessionTokensAsync(cancellationToken).ConfigureAwait(false); AnsiConsole.MarkupLine($"[cyan]Sessions[/] : {sessionCount}"); if (totalIn > 0 || totalOut > 0) { @@ -111,7 +111,7 @@ public override async Task ExecuteAsync(CommandContext context, Cancellatio try { await using var stream = File.OpenRead(file); - var session = await JsonSerializer.DeserializeAsync(stream, SessionJsonContext.Default.Session, ct); + var session = await JsonSerializer.DeserializeAsync(stream, SessionJsonContext.Default.Session, ct).ConfigureAwait(false); return session is null ? (0L, 0L, 0) : (session.TotalInputTokens, session.TotalOutputTokens, 1); } catch diff --git a/src/clawsharp/Config/AppConfig.cs b/src/clawsharp/Config/AppConfig.cs index f2e7284a..81dcd420 100644 --- a/src/clawsharp/Config/AppConfig.cs +++ b/src/clawsharp/Config/AppConfig.cs @@ -3,6 +3,7 @@ using Clawsharp.Config.Channels; using Clawsharp.Config.Features; using Clawsharp.Config.Memory; +using Clawsharp.Config.Organization; using Clawsharp.Config.Security; using Clawsharp.A2a; using Clawsharp.Knowledge.Config; @@ -41,15 +42,39 @@ public sealed class AppConfig /// MCP server configurations keyed by server name. public Dictionary? McpServers { get; init; } + /// + /// MCP server mode configuration (exposing clawsharp tools to external MCP clients). + /// Null = disabled (zero overhead). + /// + public McpServerModeConfig? McpServer { get; init; } + /// Security configuration. public SecurityConfig? Security { get; init; } + /// + /// Organization and multi-user identity/policy configuration. + /// Null = single-operator mode (v1.5.0 behavior). + /// + public OrganizationConfig? Organization { get; init; } + /// At-rest secrets encryption configuration. public SecretsConfig? Secrets { get; init; } /// Voice transcription configuration (Groq Whisper / OpenAI Whisper), shared by all channels. public TranscriptionConfig? Transcription { get; init; } + /// + /// OpenTelemetry observability configuration (traces, metrics, logs). + /// Null = disabled (zero overhead). + /// + public TelemetryConfig? Telemetry { get; init; } + + /// + /// Webhook / event subscription system configuration. + /// Null = disabled (zero overhead). + /// + public WebhookConfig? Webhooks { get; init; } + /// HTTP request settings (proxy) for outbound LLM provider calls. public HttpRequestConfig? HttpRequest { get; init; } diff --git a/src/clawsharp/Config/ClawsharpConfiguration.cs b/src/clawsharp/Config/ClawsharpConfiguration.cs index 2270c36d..15ab63c3 100644 --- a/src/clawsharp/Config/ClawsharpConfiguration.cs +++ b/src/clawsharp/Config/ClawsharpConfiguration.cs @@ -167,6 +167,14 @@ string Resolve(string? value) { provider.ApiKey = Resolve(provider.ApiKey); provider.AwsSecretAccessKey = Resolve(provider.AwsSecretAccessKey); + + if (provider.ApiKeys is { } keys) + { + for (var i = 0; i < keys.Count; i++) + { + keys[i] = Resolve(keys[i]); + } + } } if (config.Transcription is { } t) diff --git a/src/clawsharp/Config/ConfigKeyValidator.cs b/src/clawsharp/Config/ConfigKeyValidator.cs index 00a7f5ae..5a2abb00 100644 --- a/src/clawsharp/Config/ConfigKeyValidator.cs +++ b/src/clawsharp/Config/ConfigKeyValidator.cs @@ -65,6 +65,9 @@ internal static class ConfigKeyValidator "agents.defaults.thinking.reasoningEffort", "agents.defaults.thinking.geminiBudgetTokens", + // agents.defaults.spawn + "agents.defaults.spawnTimeout", + // agents.defaults.modelRouting "agents.defaults.modelRouting.enabled", "agents.defaults.modelRouting.simpleModel", @@ -102,6 +105,7 @@ internal static class ConfigKeyValidator "memory.factExtraction.minChars", // tools — core + "tools.shellEnabled", "tools.workspace", "tools.requireShellApproval", "tools.enableShellDenyPatterns", diff --git a/src/clawsharp/Config/ConfigValidator.cs b/src/clawsharp/Config/ConfigValidator.cs index 849b7757..469c3b77 100644 --- a/src/clawsharp/Config/ConfigValidator.cs +++ b/src/clawsharp/Config/ConfigValidator.cs @@ -4,6 +4,7 @@ using Clawsharp.Config.Memory; using Clawsharp.Config.Organization; using Clawsharp.Config.Security; +using Clawsharp.Knowledge.Config; using Clawsharp.Security; namespace Clawsharp.Config; @@ -247,6 +248,31 @@ public static List Validate(AppConfig config) } } + // ── Knowledge ──────────────────────────────────────────────────────── + if (config.Knowledge is { Enabled: true }) + { + if (config.Memory.Embedding is null + || string.IsNullOrWhiteSpace(config.Memory.Embedding.Provider)) + { + errors.Add("knowledge is enabled but memory.embedding is not configured. " + + "Set memory.embedding.provider to 'openai' or 'ollama'."); + } + + ValidateChunkingConfig(errors, config.Knowledge.Chunking, "knowledge.chunking"); + + if (config.Knowledge.Sources is { Count: > 0 }) + { + for (var i = 0; i < config.Knowledge.Sources.Count; i++) + { + var source = config.Knowledge.Sources[i]; + if (source.Chunking is not null) + { + ValidateChunkingConfig(errors, source.Chunking, $"knowledge.sources[{i}].chunking"); + } + } + } + } + // ── Egress policy ──────────────────────────────────────────────────── if (config.Security?.Egress is { } egress) { @@ -434,6 +460,14 @@ private static void ValidateAbacRules(List errors, List rules) errors.Add($"{prefix}: duplicate ruleId '{effectiveId}'."); } + // Deny rules must specify when.tool (otherwise they silently match nothing) + if (rule.When is not null + && string.Equals(rule.Effect, AbacRule.Effects.Deny, StringComparison.Ordinal) + && rule.When.Tool is null) + { + errors.Add($"{prefix}: deny rules must specify when.tool (use '*' to deny all tools)."); + } + // Validate timeWindow entries if (rule.When?.TimeWindow is { ValueKind: System.Text.Json.JsonValueKind.Array } tw) { @@ -468,6 +502,24 @@ private static bool IsValidTimeWindow(string window) TimeOnly.TryParse(endPart, System.Globalization.CultureInfo.InvariantCulture, out _); } + /// + /// Validates chunking configuration: chunk size and overlap bounds. + /// + private static void ValidateChunkingConfig(List errors, ChunkingConfig? config, string prefix) + { + if (config is null) return; + + if (config.ChunkSize < 64) + { + errors.Add($"{prefix}.chunkSize must be at least 64 (got {config.ChunkSize})."); + } + + if (config.Overlap < 0.0 || config.Overlap >= 1.0) + { + errors.Add($"{prefix}.overlap must be in [0.0, 1.0) (got {config.Overlap})."); + } + } + /// /// Validates the telemetry configuration block: endpoint URI, protocol, sampling range, and log level. /// @@ -522,6 +574,13 @@ private static void ValidateMcpServerMode( if (string.IsNullOrWhiteSpace(keyEntry.User)) errors.Add($"mcpServer.apiKeys.{keyId}: 'user' must not be empty."); + // Validate secret minimum length when explicitly set + if (keyEntry.Secret is not null && keyEntry.Secret.Length < 32) + { + errors.Add($"mcpServer.apiKeys.{keyId}: 'secret' must be at least 32 characters " + + $"(got {keyEntry.Secret.Length}). Use: openssl rand -hex 32"); + } + // Validate that referenced user exists in org config (if org is configured) if (config.Organization is not null && !config.Organization.Users.ContainsKey(keyEntry.User)) diff --git a/src/clawsharp/Config/DotEnvConfigurationSource.cs b/src/clawsharp/Config/DotEnvConfigurationSource.cs index d45ebb8f..5be66680 100644 --- a/src/clawsharp/Config/DotEnvConfigurationSource.cs +++ b/src/clawsharp/Config/DotEnvConfigurationSource.cs @@ -38,7 +38,10 @@ public override void Load() var key = trimmed[..eq].Trim(); var value = trimmed[(eq + 1)..].Trim(); - // Strip optional surrounding quotes (double or single) + // Strip optional surrounding quotes (double or single). + // Escape sequences (\n, \", etc.) within quoted values are NOT unescaped; + // this is intentionally simpler than dotenv/godotenv. Use single-quoted + // values or avoid escape sequences if this is a concern. if (value.Length >= 2 && value[0] == '"' && value[^1] == '"') { value = value[1..^1]; diff --git a/src/clawsharp/Config/Features/McpServerModeConfig.cs b/src/clawsharp/Config/Features/McpServerModeConfig.cs index 009b671f..de423319 100644 --- a/src/clawsharp/Config/Features/McpServerModeConfig.cs +++ b/src/clawsharp/Config/Features/McpServerModeConfig.cs @@ -17,10 +17,13 @@ public sealed class McpServerModeConfig public string[]? AllowedOrigins { get; init; } /// - /// API keys for Bearer token authentication. Key = key identifier, Value = key config. + /// API keys for Bearer token authentication. Key = the bearer token itself, Value = key config. /// When null or empty in single-operator mode, auth is not required. + /// Dictionary keys cannot use enc2: encryption or op:// references because DecryptSecrets + /// can only mutate property values, not dictionary keys. Protect config.json with chmod 600 + /// and CLAWSHARP_SECRET_KEY for at-rest protection of the entire file. /// - public Dictionary? ApiKeys { get; init; } + public IReadOnlyDictionary? ApiKeys { get; init; } } /// @@ -33,4 +36,13 @@ public sealed class McpApiKeyEntry /// Optional description for operator reference. public string? Description { get; init; } + + /// + /// The bearer token secret for this key entry. When set, the bearer token is this value + /// rather than the dictionary key (keyId). This separates the human-readable identifier + /// from the credential, preventing keyId from leaking via logs, OTel spans, and cost records. + /// When null, the dictionary key is used as the bearer secret for backward compatibility + /// (deprecated — a warning is logged at startup). + /// + public string? Secret { get; init; } } diff --git a/src/clawsharp/Config/JsonContext.cs b/src/clawsharp/Config/JsonContext.cs index 34113d51..35f79c8c 100644 --- a/src/clawsharp/Config/JsonContext.cs +++ b/src/clawsharp/Config/JsonContext.cs @@ -5,6 +5,7 @@ using Clawsharp.Config.Channels; using Clawsharp.Config.Features; using Clawsharp.Config.Memory; +using Clawsharp.Config.Organization; using Clawsharp.Config.Search; using Clawsharp.Config.Security; using Clawsharp.A2a; @@ -59,6 +60,24 @@ namespace Clawsharp.Config; JsonSerializable(typeof(LandlockConfig)), JsonSerializable(typeof(EgressConfig)), JsonSerializable(typeof(EgressRule)), JsonSerializable(typeof(List)), JsonSerializable(typeof(EgressMode)), + // Organization + JsonSerializable(typeof(OrganizationConfig)), JsonSerializable(typeof(OrgUserConfig)), + JsonSerializable(typeof(PoliciesConfig)), JsonSerializable(typeof(RolePolicy)), + JsonSerializable(typeof(DepartmentConfig)), JsonSerializable(typeof(PolicyDefaults)), + JsonSerializable(typeof(AdminNotifyConfig)), JsonSerializable(typeof(BudgetLimits)), + JsonSerializable(typeof(AbacRule)), JsonSerializable(typeof(AbacCondition)), + JsonSerializable(typeof(IdpConfig)), JsonSerializable(typeof(ClaimsConfig)), + JsonSerializable(typeof(Dictionary)), + JsonSerializable(typeof(Dictionary)), + JsonSerializable(typeof(Dictionary)), + JsonSerializable(typeof(List)), + // Telemetry + JsonSerializable(typeof(TelemetryConfig)), + // MCP server mode + JsonSerializable(typeof(McpServerModeConfig)), JsonSerializable(typeof(McpApiKeyEntry)), + // Webhooks + JsonSerializable(typeof(WebhookConfig)), JsonSerializable(typeof(WebhookEndpointConfig)), + JsonSerializable(typeof(Dictionary)), // Intellenum config types JsonSerializable(typeof(DmPolicy)), JsonSerializable(typeof(GroupPolicy)), JsonSerializable(typeof(ReasoningEffort)), JsonSerializable(typeof(PromptGuardMode)), diff --git a/src/clawsharp/Config/Organization/ConfigMutator.cs b/src/clawsharp/Config/Organization/ConfigMutator.cs index d739cc7f..01da169d 100644 --- a/src/clawsharp/Config/Organization/ConfigMutator.cs +++ b/src/clawsharp/Config/Organization/ConfigMutator.cs @@ -1,5 +1,6 @@ using System.Text.Json; using System.Text.Json.Nodes; +using Microsoft.Extensions.Logging; namespace Clawsharp.Config.Organization; @@ -7,12 +8,19 @@ namespace Clawsharp.Config.Organization; /// Provides atomic read-modify-write operations on ~/.clawsharp/config.json. /// Serializes concurrent mutations with a per Pitfall #1. /// -public static class ConfigMutator +public static partial class ConfigMutator { private static readonly SemaphoreSlim Lock = new(1, 1); private static readonly JsonSerializerOptions WriteOptions = new() { WriteIndented = true }; + private static ILogger? _logger; + + /// + /// Sets the logger for ConfigMutator. Called once during DI setup. + /// + internal static void SetLogger(ILogger logger) => _logger = logger; + /// /// Reads config.json, applies to the parsed , /// and writes the result atomically via temp file + . @@ -44,7 +52,14 @@ internal static async Task MutateConfigAsync(string configPath, Action if (File.Exists(configPath)) { var json = await File.ReadAllTextAsync(configPath, ct).ConfigureAwait(false); - root = JsonNode.Parse(json); + if (!string.IsNullOrWhiteSpace(json)) + { + root = JsonNode.Parse(json); + } + else + { + LogEmptyConfigFile(_logger, configPath); + } } root ??= new JsonObject(); @@ -61,4 +76,11 @@ internal static async Task MutateConfigAsync(string configPath, Action Lock.Release(); } } + + [LoggerMessage(EventId = 1, Level = LogLevel.Warning, + Message = "Config file '{ConfigPath}' exists but is empty; treating as missing")] + private static partial void LogEmptyConfigFile(ILogger? logger, string configPath); } + +/// Marker type for in . +public sealed class ConfigMutatorLogger; diff --git a/src/clawsharp/Config/Organization/PolicyDefaults.cs b/src/clawsharp/Config/Organization/PolicyDefaults.cs index 4152f3f2..4a5ab458 100644 --- a/src/clawsharp/Config/Organization/PolicyDefaults.cs +++ b/src/clawsharp/Config/Organization/PolicyDefaults.cs @@ -6,11 +6,16 @@ namespace Clawsharp.Config.Organization; /// public sealed class PolicyDefaults { + /// + /// The fallback role name used when no explicit default is configured. + /// + public const string DefaultRoleName = "user"; + /// /// The role assigned to unknown senders when is false. /// Must reference a key in . /// - public string DefaultRole { get; init; } = "user"; + public string DefaultRole { get; init; } = DefaultRoleName; /// /// When true, unknown senders are denied with an explanatory message. diff --git a/src/clawsharp/Config/Security/SecurityConfig.cs b/src/clawsharp/Config/Security/SecurityConfig.cs index 70be53b2..75da82c5 100644 --- a/src/clawsharp/Config/Security/SecurityConfig.cs +++ b/src/clawsharp/Config/Security/SecurityConfig.cs @@ -88,9 +88,10 @@ public sealed class LeakDetectorConfig { /// /// Detection sensitivity (0.0–1.0, default 0.7). - /// At 0.0: only structural patterns (API keys, AWS, JWTs, private keys, DB URLs). + /// At 0.0: structural patterns only (API keys, AWS credentials, JWTs, private keys, DB URLs). /// Above 0.5: also generic secrets (password=, token=) and high-entropy tokens. - /// Set to 0 to disable leak detection entirely. + /// Structural-pattern detection cannot be disabled — this is intentional. + /// To minimize scan impact, set to 0.0. /// [System.ComponentModel.DataAnnotations.Range(0.0, 1.0)] public double Sensitivity { get; init; } = 0.7; diff --git a/src/clawsharp/Core/AgentStepExecutor.cs b/src/clawsharp/Core/AgentStepExecutor.cs index 1255bf71..fbda8377 100644 --- a/src/clawsharp/Core/AgentStepExecutor.cs +++ b/src/clawsharp/Core/AgentStepExecutor.cs @@ -80,7 +80,7 @@ public async Task ExecuteAsync( ChatResponse response; try { - response = await provider.ChatAsync(chatRequest, ct); + response = await provider.ChatAsync(chatRequest, ct).ConfigureAwait(false); } catch (Exception ex) { @@ -101,16 +101,29 @@ public async Task ExecuteAsync( messages.Add(new ChatMessage(MessageRole.Assistant, response.Content, ToolCalls: response.ToolCalls)); - foreach (var tc in response.ToolCalls) + if (response.ToolCalls.Count == 1) { + var tc = response.ToolCalls[0]; toolCallCount++; - - // Invoke the pre-execution callback if provided (e.g. to set RBAC context) request.BeforeToolExecution?.Invoke(tc); - - var result = await tools.ExecuteAsync(tc.Name, tc.ArgumentsJson, ct); + var result = await tools.ExecuteAsync(tc.Name, tc.ArgumentsJson, ct).ConfigureAwait(false); messages.Add(new ChatMessage(MessageRole.Tool, result, ToolCallId: tc.Id, Name: tc.Name)); } + else + { + var toolCalls = response.ToolCalls; + toolCallCount += toolCalls.Count; + foreach (var tc in toolCalls) + request.BeforeToolExecution?.Invoke(tc); + + var tasks = new Task[toolCalls.Count]; + for (var i = 0; i < toolCalls.Count; i++) + tasks[i] = tools.ExecuteAsync(toolCalls[i].Name, toolCalls[i].ArgumentsJson, ct); + + var results = await Task.WhenAll(tasks).ConfigureAwait(false); + for (var i = 0; i < toolCalls.Count; i++) + messages.Add(new ChatMessage(MessageRole.Tool, results[i], ToolCallId: toolCalls[i].Id, Name: toolCalls[i].Name)); + } chatRequest = chatRequest with { Messages = messages }; continue; @@ -151,6 +164,18 @@ public async IAsyncEnumerable StreamAsync( IToolRegistry tools, [EnumeratorCancellation] CancellationToken ct = default) { + // Mirror ExecuteAsync: capture parent context for ActivityLink, then create a new trace root. + var parentSpawnContext = Activity.Current?.Context; + Activity.Current = null; + var links = parentSpawnContext.HasValue + ? new[] { new ActivityLink(parentSpawnContext.Value) } + : null; + using var activity = ClawsharpActivitySources.Pipeline.StartActivity( + "agent.step", + ActivityKind.Internal, + parentContext: default(ActivityContext), + links: links); + var messages = new List { new(MessageRole.System, request.SystemPrompt), @@ -185,7 +210,7 @@ public async IAsyncEnumerable StreamAsync( // ── Streaming path ────────────────────────────────────────── // Consume the stream into collected events + tool builders via a non-yielding helper. // C# disallows yield inside try-catch, so we separate consumption from yielding. - var consumeResult = await ConsumeStreamAsync(sp, chatRequest, ct); + var consumeResult = await ConsumeStreamAsync(sp, chatRequest, ct).ConfigureAwait(false); if (consumeResult.Failed) { @@ -214,15 +239,36 @@ public async IAsyncEnumerable StreamAsync( messages.Add(new ChatMessage(MessageRole.Assistant, assistantText, ToolCalls: toolCalls)); - foreach (var tc in toolCalls) + if (toolCalls.Count == 1) { + var tc = toolCalls[0]; yield return new StreamEvent.ToolStart(tc.Name); request.BeforeToolExecution?.Invoke(tc); - var result = await tools.ExecuteAsync(tc.Name, tc.ArgumentsJson, ct); + var result = await tools.ExecuteAsync(tc.Name, tc.ArgumentsJson, ct).ConfigureAwait(false); yield return new StreamEvent.ToolResult(tc.Name, result); messages.Add(new ChatMessage(MessageRole.Tool, result, ToolCallId: tc.Id, Name: tc.Name)); } + else + { + foreach (var tc in toolCalls) + { + yield return new StreamEvent.ToolStart(tc.Name); + request.BeforeToolExecution?.Invoke(tc); + } + + var tasks = new Task[toolCalls.Count]; + for (var i = 0; i < toolCalls.Count; i++) + tasks[i] = tools.ExecuteAsync(toolCalls[i].Name, toolCalls[i].ArgumentsJson, ct); + + var results = await Task.WhenAll(tasks).ConfigureAwait(false); + for (var i = 0; i < toolCalls.Count; i++) + { + yield return new StreamEvent.ToolResult(toolCalls[i].Name, results[i]); + messages.Add(new ChatMessage(MessageRole.Tool, results[i], ToolCallId: toolCalls[i].Id, + Name: toolCalls[i].Name)); + } + } chatRequest = chatRequest with { Messages = messages }; continue; // next iteration @@ -249,7 +295,7 @@ public async IAsyncEnumerable StreamAsync( else { // ── Fallback path: non-streaming provider ─────────────────── - var fallbackResult = await CallChatAsync(provider, chatRequest, ct); + var fallbackResult = await CallChatAsync(provider, chatRequest, ct).ConfigureAwait(false); if (fallbackResult.Failed) { @@ -272,15 +318,36 @@ public async IAsyncEnumerable StreamAsync( messages.Add(new ChatMessage(MessageRole.Assistant, response.Content, ToolCalls: response.ToolCalls)); - foreach (var tc in response.ToolCalls) + if (response.ToolCalls.Count == 1) { + var tc = response.ToolCalls[0]; yield return new StreamEvent.ToolStart(tc.Name); request.BeforeToolExecution?.Invoke(tc); - var result = await tools.ExecuteAsync(tc.Name, tc.ArgumentsJson, ct); + var result = await tools.ExecuteAsync(tc.Name, tc.ArgumentsJson, ct).ConfigureAwait(false); yield return new StreamEvent.ToolResult(tc.Name, result); messages.Add(new ChatMessage(MessageRole.Tool, result, ToolCallId: tc.Id, Name: tc.Name)); } + else + { + foreach (var tc in response.ToolCalls) + { + yield return new StreamEvent.ToolStart(tc.Name); + request.BeforeToolExecution?.Invoke(tc); + } + + var tasks = new Task[response.ToolCalls.Count]; + for (var i = 0; i < response.ToolCalls.Count; i++) + tasks[i] = tools.ExecuteAsync(response.ToolCalls[i].Name, response.ToolCalls[i].ArgumentsJson, ct); + + var results = await Task.WhenAll(tasks).ConfigureAwait(false); + for (var i = 0; i < response.ToolCalls.Count; i++) + { + yield return new StreamEvent.ToolResult(response.ToolCalls[i].Name, results[i]); + messages.Add(new ChatMessage(MessageRole.Tool, results[i], ToolCallId: response.ToolCalls[i].Id, + Name: response.ToolCalls[i].Name)); + } + } chatRequest = chatRequest with { Messages = messages }; continue; // next iteration @@ -324,7 +391,7 @@ private async Task ConsumeStreamAsync( try { - await foreach (var chunk in sp.StreamAsync(chatRequest, ct)) + await foreach (var chunk in sp.StreamAsync(chatRequest, ct).ConfigureAwait(false)) { switch (chunk) { @@ -381,7 +448,7 @@ private async Task CallChatAsync( { try { - var response = await provider.ChatAsync(chatRequest, ct); + var response = await provider.ChatAsync(chatRequest, ct).ConfigureAwait(false); return new FallbackCallResult(response, Failed: false); } catch (Exception ex) when (ex is not OperationCanceledException) diff --git a/src/clawsharp/Core/Events/EventBus.cs b/src/clawsharp/Core/Events/EventBus.cs index a94d1279..7a5d3895 100644 --- a/src/clawsharp/Core/Events/EventBus.cs +++ b/src/clawsharp/Core/Events/EventBus.cs @@ -1,5 +1,6 @@ using System.Collections.Concurrent; using Microsoft.Extensions.Logging; +using Remora.Discord.API.Objects; namespace Clawsharp.Core.Events; diff --git a/src/clawsharp/Core/Hosting/HttpHostService.cs b/src/clawsharp/Core/Hosting/HttpHostService.cs index 3a17a3fb..449111e5 100644 --- a/src/clawsharp/Core/Hosting/HttpHostService.cs +++ b/src/clawsharp/Core/Hosting/HttpHostService.cs @@ -2,6 +2,7 @@ using Clawsharp.Core.Utilities; using Microsoft.AspNetCore.Builder; using Microsoft.AspNetCore.Hosting; +using Microsoft.AspNetCore.Http; using Microsoft.Extensions.Hosting; using Microsoft.Extensions.Logging; using Microsoft.Extensions.Options; @@ -73,6 +74,7 @@ public async Task StartAsync(CancellationToken cancellationToken) // Kestrel configuration — max request body size for WebSocket support. builder.WebHost.ConfigureKestrel(options => { + options.AddServerHeader = false; // suppress "Server: Kestrel" disclosure options.Limits.MaxRequestBodySize = 1 * 1024 * 1024; // 1 MB }); @@ -84,6 +86,24 @@ public async Task StartAsync(CancellationToken cancellationToken) _app = builder.Build(); + // Global exception handler — prevents stack trace leakage regardless of environment. + _app.UseExceptionHandler(errApp => errApp.Run(async ctx => + { + ctx.Response.StatusCode = 500; + ctx.Response.ContentType = "text/plain"; + await ctx.Response.WriteAsync("Internal server error", ctx.RequestAborted).ConfigureAwait(false); + })); + + // Global security headers — runs before all registrar middleware/routes so + // A2A, webhook, and MCP endpoints get headers even when WebChannel is disabled. + _app.Use(async (context, next) => + { + ApplySecurityHeaders(context.Response); + if (_tls) + context.Response.Headers.StrictTransportSecurity = "max-age=31536000; includeSubDomains"; + await next(context).ConfigureAwait(false); + }); + // Let each registrar map middleware and routes. // Order matters: registrars are resolved in DI registration order. foreach (var registrar in registrarList) @@ -96,12 +116,12 @@ public async Task StartAsync(CancellationToken cancellationToken) try { - await _app.StartAsync(cancellationToken); + await _app.StartAsync(cancellationToken).ConfigureAwait(false); } catch (Exception ex) when (ex is not OperationCanceledException) { LogStartFailed(_logger, _port, ex); - await _app.DisposeAsync(); + await _app.DisposeAsync().ConfigureAwait(false); _app = null; } } @@ -110,7 +130,9 @@ public async Task StopAsync(CancellationToken cancellationToken) { if (_app is not null) { - await _app.StopAsync(cancellationToken); + await _app.StopAsync(cancellationToken).ConfigureAwait(false); + await _app.DisposeAsync().ConfigureAwait(false); + _app = null; } } @@ -118,7 +140,8 @@ public async ValueTask DisposeAsync() { if (_app is not null) { - await _app.DisposeAsync(); + await _app.DisposeAsync().ConfigureAwait(false); + _app = null; } } @@ -137,4 +160,17 @@ public async ValueTask DisposeAsync() Message = "TLS is enabled in config but Kestrel is not configured for TLS directly. " + "Configure a reverse proxy (nginx, Caddy, Traefik) to handle TLS termination on port {Port}.")] private static partial void LogTlsAdvisory(ILogger logger, int port); + + /// + /// Applies baseline security headers to all HTTP responses regardless of + /// which implementations are active. + /// + private static void ApplySecurityHeaders(HttpResponse response) + { + response.Headers.XContentTypeOptions = "nosniff"; + response.Headers["Referrer-Policy"] = "no-referrer"; + response.Headers.XFrameOptions = "DENY"; + response.Headers["Permissions-Policy"] = "camera=(), microphone=(), geolocation=(), usb=(), payment=()"; + response.Headers.XXSSProtection = "1; mode=block"; + } } diff --git a/src/clawsharp/Core/Pipeline/AgentLoop.OrgCommands.cs b/src/clawsharp/Core/Pipeline/AgentLoop.OrgCommands.cs index 6c2719c8..a4c0c1aa 100644 --- a/src/clawsharp/Core/Pipeline/AgentLoop.OrgCommands.cs +++ b/src/clawsharp/Core/Pipeline/AgentLoop.OrgCommands.cs @@ -4,6 +4,7 @@ using Clawsharp.Cost; using Clawsharp.Core.Sessions; using Clawsharp.Organization; +using Clawsharp.Tools; namespace Clawsharp.Core.Pipeline; @@ -50,7 +51,7 @@ private string HandleOrgExplain(Session session, string? argument) var abacRules = _appConfig.Organization.Policies?.Rules; if (abacRules is { Count: > 0 }) { - var ctx = new AbacContext(orgUser, Clawsharp.Core.Utilities.ChannelName.Cli, DateTimeOffset.UtcNow); + var ctx = new AbacContext(orgUser, Utilities.ChannelName.Cli, DateTimeOffset.UtcNow); abacDecision = evaluator.ApplyAbacRules(rbacDecision, abacRules, ctx); } @@ -94,7 +95,7 @@ private string HandleOrgSimulate(Session session, string? argument) var abacRules = _appConfig.Organization.Policies?.Rules; if (abacRules is { Count: > 0 }) { - var ctx = new AbacContext(orgUser, Clawsharp.Core.Utilities.ChannelName.Cli, DateTimeOffset.UtcNow); + var ctx = new AbacContext(orgUser, Utilities.ChannelName.Cli, DateTimeOffset.UtcNow); decision = evaluator.ApplyAbacRules(rbacDecision, abacRules, ctx); } else @@ -119,7 +120,7 @@ private string HandleOrgSimulate(Session session, string? argument) : (0m, 0m); decimal? deptMonthlyUsed = null; - Config.Organization.BudgetLimits? deptBudget = null; + BudgetLimits? deptBudget = null; if (orgUser.Department is not null) { var (_, deptMonthly) = _costTracker.GetScopeTotals($"dept:{orgUser.Department}"); @@ -311,7 +312,7 @@ internal static string HandleOrgUsage(Session session, string? argument, AppConf /// private async Task HandleOrgApproveAsync(Session session, string? argument, CancellationToken ct) { - var (success, message) = HandleOrgApprove(session, argument, _appConfig, _orgServices.ApprovalQueue); + var (success, message) = HandleOrgApprove(session, argument, _appConfig, _orgServices.ApprovalQueue, _tools.GetToolSensitivity); // Fire proactive notification on successful approval (D-04) if (success && argument is not null) @@ -335,7 +336,9 @@ private async Task HandleOrgApproveAsync(Session session, string? argume /// Handles /org approve <id> [--ttl <duration>] — admin-only (per D-18). /// Internal static for testability via InternalsVisibleTo. /// - internal static (bool Success, string Message) HandleOrgApprove(Session session, string? argument, AppConfig appConfig, ApprovalQueue approvalQueue) + internal static (bool Success, string Message) HandleOrgApprove( + Session session, string? argument, AppConfig appConfig, ApprovalQueue approvalQueue, + Func? getToolSensitivity = null) { if (appConfig.Organization is null) return (false, "Organization mode is not enabled."); @@ -375,6 +378,21 @@ internal static (bool Success, string Message) HandleOrgApprove(Session session, if (request.State != ApprovalState.Pending) return (false, "Request is no longer pending."); + // CVE-2026-33579 mitigation: validate that the approver's own policy allows this tool. + // An admin whose policy restricts them to low-sensitivity tools should not be able to + // approve requests for tools they cannot use themselves. + if (getToolSensitivity is not null + && session.CurrentPolicy is { } callerPolicy + && callerPolicy != PolicyDecision.Unrestricted) + { + var toolSensitivity = getToolSensitivity(request.ToolName); + var effect = callerPolicy.EvaluateToolAccess(request.ToolName, toolSensitivity); + // Allow if the tool would be Allowed or ApprovalRequired for the admin's own policy. + // Deny if DeniedBySensitivity, DeniedByGlob, or DeniedByAbac. + if (effect is not (PolicyEffect.Allowed or PolicyEffect.ApprovalRequired)) + return (false, $"Cannot approve '{request.ToolName}' — your own policy does not allow this tool."); + } + var grant = approvalQueue.Approve(requestId, session.CurrentUser.Name, ttl); if (grant is null) return (false, "Request is no longer pending."); @@ -477,13 +495,13 @@ await ConfigMutator.MutateConfigAsync(root => if (userNode is null) return; userNode["roles"] = new System.Text.Json.Nodes.JsonArray(newRole); - }, ct); + }, ct).ConfigureAwait(false); // Replace OrgUserConfig with a new instance carrying the updated role list. // Never mutate the shared List — concurrent readers may be iterating it. if (_appConfig.Organization is { } org && org.Users.TryGetValue(username, out var userConfig)) { - var updatedConfig = new Config.Organization.OrgUserConfig + var updatedConfig = new OrgUserConfig { Ids = new List(userConfig.Ids), Roles = [newRole], @@ -538,6 +556,23 @@ internal static (bool Success, string Message) HandleOrgSetRole(Session session, return (false, $"Role not found: {newRole}. Available roles: {available}"); } + // CVE-2026-33579 mitigation: validate that the caller's own policy is at least as + // permissive as the target role. An admin with restricted scope should not be able + // to assign a role that grants broader privileges than they themselves hold. + var targetRole = roles[newRole]; + if (session.CurrentPolicy is { } callerPolicy && callerPolicy != PolicyDecision.Unrestricted) + { + if (targetRole.IsUnrestrictedToolAccess && !callerPolicy.IsUnrestrictedToolAccess) + return (false, $"Cannot assign role '{newRole}' — it grants unrestricted tool access that exceeds your own policy."); + + var targetSensitivity = ToolSensitivityParser.Parse(targetRole.MaxToolSensitivity); + if (targetSensitivity > callerPolicy.MaxSensitivity) + return (false, $"Cannot assign role '{newRole}' — its tool sensitivity ceiling exceeds your own."); + + if (targetRole.IsUnrestrictedModels && !callerPolicy.IsUnrestrictedModels) + return (false, $"Cannot assign role '{newRole}' — it grants unrestricted model access that exceeds your own policy."); + } + return (true, $"Role updated: @{username} is now [{newRole}]. Change is effective immediately."); } @@ -693,13 +728,13 @@ await ConfigMutator.MutateConfigAsync(root => if (userNode is null) return; userNode["ids"] = new System.Text.Json.Nodes.JsonArray(); - }, ct); + }, ct).ConfigureAwait(false); // Replace OrgUserConfig with a new instance carrying an empty Ids list. // Never mutate the shared List — concurrent readers may be iterating it. if (_appConfig.Organization is { } org && org.Users.TryGetValue(username, out var userConfig)) { - var updatedConfig = new Config.Organization.OrgUserConfig + var updatedConfig = new OrgUserConfig { Ids = [], Roles = new List(userConfig.Roles), diff --git a/src/clawsharp/Core/Pipeline/AgentLoop.Pipeline.cs b/src/clawsharp/Core/Pipeline/AgentLoop.Pipeline.cs index a3a7bf19..88d0992f 100644 --- a/src/clawsharp/Core/Pipeline/AgentLoop.Pipeline.cs +++ b/src/clawsharp/Core/Pipeline/AgentLoop.Pipeline.cs @@ -67,6 +67,9 @@ private async Task> ApplyContextWindowGuardAsync( { LogContextWindowCompacting(estimated, contextWindow); + // Compaction bypasses the mediator pipeline intentionally: it requires + // direct provider/model parameters and post-compaction session mutation + // that don't fit the handler's command/result abstraction cleanly. if (compConfig.PreCompactionMemoryFlush) { var recentStart = Math.Max(1, messages.Count - compConfig.KeepRecent); @@ -276,12 +279,12 @@ await channel.SendAsync( if (loopResult.CacheRead > 0) { ClawsharpMetrics.TokenUsage.Record(loopResult.CacheRead, - new GenAiMetricTags { OperationName = "chat", Model = normalizedModel, TokenType = "cache_read" }); + new GenAiMetricTags { OperationName = "chat", Model = normalizedModel, TokenType = "input_cached" }); } // MET-02: LLM operation duration histogram ClawsharpMetrics.OperationDuration.Record(sw.Elapsed.TotalSeconds, - new GenAiMetricTags { OperationName = "chat", Model = normalizedModel, TokenType = "" }); + new DurationMetricTags { OperationName = "chat", Model = normalizedModel }); await _handlers.RecordUsage.HandleAsync(new RecordUsage.Command( sessionId, actualModel, inputDelta, outputDelta, @@ -293,17 +296,21 @@ await _handlers.RecordUsage.HandleAsync(new RecordUsage.Command( // Record interaction analytics (fire-and-forget — must not block the response pipeline). if (_analytics.Enabled && loopResult.Reply is not null) { + var sanitizedResponse = LeakDetector.Scan(loopResult.Reply).Redacted; + var sanitizedThinking = loopResult.Thinking is not null + ? LeakDetector.Scan(loopResult.Thinking).Redacted + : null; var interactionInput = new InteractionInput( SessionId: sessionId, Channel: inbound.Channel.Value, Model: actualModel, UserPrompt: messages.LastOrDefault(m => m.Role == MessageRole.User)?.Content ?? "", - Thinking: loopResult.Thinking, Response: loopResult.Reply, + Thinking: sanitizedThinking, Response: sanitizedResponse, ToolCalls: loopResult.ToolCallSummaries, ToolIterations: loopResult.ToolIterations, InputTokens: inputDelta, OutputTokens: outputDelta, CacheReadTokens: loopResult.CacheRead, CacheWriteTokens: loopResult.CacheWrite, DurationMs: sw.ElapsedMilliseconds); SpanIsolation.RunFireAndForget("analytics.record", ClawsharpActivitySources.Pipeline, async () => { - await _analytics.InteractionTracker.RecordAsync(interactionInput, CancellationToken.None); + await _analytics.InteractionTracker.RecordAsync(interactionInput, CancellationToken.None).ConfigureAwait(false); }); } @@ -505,14 +512,13 @@ private async Task PostProcessReplyAsync( { try { - var sb = new StringBuilder(); - for (var i = 0; i < audioChunks.Count; i++) + using var ms = new MemoryStream(); + foreach (var chunk in audioChunks) { - sb.Append(i < audioChunks.Count - 1 - ? audioChunks[i].TrimEnd('=') - : audioChunks[i]); + var bytes = Convert.FromBase64String(chunk); + ms.Write(bytes); } - var audioBytes = Convert.FromBase64String(sb.ToString()); + var audioBytes = ms.ToArray(); var audioExt = AudioAttachment.FormatToExtension(loopResult.AudioFormat ?? "wav"); PendingFileStore.Enqueue(new PendingFile($"generated-audio{audioExt}", audioBytes, loopResult.AudioTranscript)); } @@ -536,7 +542,7 @@ private async Task PostProcessReplyAsync( var messagesSnapshot = session.Messages.ToList(); SpanIsolation.RunFireAndForget("memory.consolidate", ClawsharpActivitySources.Memory, async () => { - await ConsolidateMemoryAsync(messagesSnapshot, CancellationToken.None); + await ConsolidateMemoryAsync(messagesSnapshot, CancellationToken.None).ConfigureAwait(false); }); } @@ -572,7 +578,7 @@ private void TriggerFactExtraction(string sessionId, string userText, string rep SpanIsolation.RunFireAndForget("memory.extract_facts", ClawsharpActivitySources.Memory, async () => { await _handlers.ExtractFacts.HandleAsync( - new ExtractFacts.Command(conversationText), CancellationToken.None); + new ExtractFacts.Command(conversationText), CancellationToken.None).ConfigureAwait(false); }); } @@ -615,7 +621,7 @@ private async Task FlushMemoryBeforeCompactionAsync( MaxTokens: 800 ); - var resp = await _provider.ChatAsync(flushReq, ct); + var resp = await _provider.ChatAsync(flushReq, ct).ConfigureAwait(false); if (resp.Content is { Length: > 0 } facts && !facts.Contains("(nothing to save)", StringComparison.OrdinalIgnoreCase)) { @@ -627,7 +633,7 @@ private async Task FlushMemoryBeforeCompactionAsync( facts = scrubResult.Redacted; } - await _memory.AppendHistoryAsync(facts, ct); + await _memory.AppendHistoryAsync(facts, ct).ConfigureAwait(false); LogPreCompactionFlushComplete(messagesToDiscard.Count, facts.Length); } } @@ -662,7 +668,7 @@ private async Task ConsolidateMemoryAsync(List messages, Cancellati MaxTokens: 500 ); - var summaryResp = await _provider.ChatAsync(summaryRequest, ct); + var summaryResp = await _provider.ChatAsync(summaryRequest, ct).ConfigureAwait(false); if (summaryResp.Content is { Length: > 0 } summary) { // Scrub secrets from LLM summary before persisting to memory @@ -673,7 +679,7 @@ private async Task ConsolidateMemoryAsync(List messages, Cancellati summary = scrubResult.Redacted; } - await _memory.AppendHistoryAsync(summary, ct); + await _memory.AppendHistoryAsync(summary, ct).ConfigureAwait(false); } } catch (Exception ex) diff --git a/src/clawsharp/Core/Pipeline/AgentLoop.SlashCommands.cs b/src/clawsharp/Core/Pipeline/AgentLoop.SlashCommands.cs index cee8de74..df328caf 100644 --- a/src/clawsharp/Core/Pipeline/AgentLoop.SlashCommands.cs +++ b/src/clawsharp/Core/Pipeline/AgentLoop.SlashCommands.cs @@ -23,14 +23,14 @@ public sealed partial class AgentLoop switch (cmd) { case SlashCommandResult.ClearSession: - await _handlers.ClearSession.HandleAsync(new ClearSession.Command(session), ct); + await _handlers.ClearSession.HandleAsync(new ClearSession.Command(session), ct).ConfigureAwait(false); return "Session cleared."; case SlashCommandResult.SendStatus: var factCount = 0; try { - var ctx = await _memory.GetContextAsync(ct); + var ctx = await _memory.GetContextAsync(ct).ConfigureAwait(false); factCount = ctx?.Split('\n').Length ?? 0; } catch @@ -58,10 +58,10 @@ public sealed partial class AgentLoop var msgs = new List(session.Messages); var compacted = await _compactionService.CompactAsync( msgs, _provider, _defaults.Model, - compConfig.KeepRecent, compConfig.MaxSummaryChars, compConfig.MaxSourceChars, ct); + compConfig.KeepRecent, compConfig.MaxSummaryChars, compConfig.MaxSourceChars, ct).ConfigureAwait(false); session.Messages.Clear(); session.Messages.AddRange(compacted.Where(m => m.Role != MessageRole.System)); - await _handlers.SaveSession.HandleAsync(new SaveSession.Command(session), ct); + await _handlers.SaveSession.HandleAsync(new SaveSession.Command(session), ct).ConfigureAwait(false); return $"Compacted: {msgs.Count} -> {session.Messages.Count} messages."; case SlashCommandResult.ShowUsage: @@ -70,7 +70,7 @@ public sealed partial class AgentLoop return "Cost tracking is not enabled.\nSet cost.enabled: true in config to enable it."; } - var summary = await _handlers.GetCostSummary.HandleAsync(new GetCostSummary.Query(session.Id), ct); + var summary = await _handlers.GetCostSummary.HandleAsync(new GetCostSummary.Query(session.Id), ct).ConfigureAwait(false); var usageSb = new StringBuilder(); usageSb.AppendLine($"Usage (today): ${summary.Daily:F4}"); usageSb.AppendLine($"Usage (this month): ${summary.Monthly:F4}"); @@ -101,7 +101,7 @@ public sealed partial class AgentLoop if (_provider is OpenRouterProvider orProvider) { - var keyInfo = await orProvider.GetKeyInfoAsync(ct); + var keyInfo = await orProvider.GetKeyInfoAsync(ct).ConfigureAwait(false); if (keyInfo is not null) { usageSb.AppendLine(); @@ -135,30 +135,30 @@ public sealed partial class AgentLoop case SlashCommandResult.ThinkOn: session.ShowThinking = true; - await _handlers.SaveSession.HandleAsync(new SaveSession.Command(session), ct); + await _handlers.SaveSession.HandleAsync(new SaveSession.Command(session), ct).ConfigureAwait(false); return "Thinking mode on. Reasoning blocks will be shown in replies."; case SlashCommandResult.ThinkOff: session.ShowThinking = false; - await _handlers.SaveSession.HandleAsync(new SaveSession.Command(session), ct); + await _handlers.SaveSession.HandleAsync(new SaveSession.Command(session), ct).ConfigureAwait(false); return "Thinking mode off."; case SlashCommandResult.ThinkToggle: session.ShowThinking = !session.ShowThinking; - await _handlers.SaveSession.HandleAsync(new SaveSession.Command(session), ct); + await _handlers.SaveSession.HandleAsync(new SaveSession.Command(session), ct).ConfigureAwait(false); return $"Thinking mode {(session.ShowThinking ? "on" : "off")}."; case SlashCommandResult.ShowGoals: - return await HandleGoalsCommandAsync(null, ct); + return await HandleGoalsCommandAsync(null, ct).ConfigureAwait(false); case SlashCommandResult.ClearGoals: - return await HandleGoalsCommandAsync("clear", ct); + return await HandleGoalsCommandAsync("clear", ct).ConfigureAwait(false); case SlashCommandResult.SetModel: - return await HandleModelCommandAsync(session, argument, ct); + return await HandleModelCommandAsync(session, argument, ct).ConfigureAwait(false); case SlashCommandResult.ListModels: - return await HandleListModelsCommandAsync(argument, ct); + return await HandleListModelsCommandAsync(argument, ct).ConfigureAwait(false); case SlashCommandResult.OrgExplain: return HandleOrgExplain(session, argument); @@ -176,7 +176,7 @@ public sealed partial class AgentLoop return HandleOrgUsage(session, argument, _appConfig, _costTracker); case SlashCommandResult.OrgApprove: - return await HandleOrgApproveAsync(session, argument, ct); + return await HandleOrgApproveAsync(session, argument, ct).ConfigureAwait(false); case SlashCommandResult.OrgDeny: return HandleOrgDeny(session, argument, _appConfig, _orgServices.ApprovalQueue); @@ -185,7 +185,7 @@ public sealed partial class AgentLoop return HandleOrgCancel(session, _appConfig, _orgServices.ApprovalQueue); case SlashCommandResult.OrgSetRole: - return await HandleOrgSetRoleAsync(session, argument, ct); + return await HandleOrgSetRoleAsync(session, argument, ct).ConfigureAwait(false); case SlashCommandResult.Link: return HandleLink(session, _appConfig, _orgServices.LinkTokenStore); @@ -194,7 +194,7 @@ public sealed partial class AgentLoop return HandleWhoami(session, _appConfig, _costTracker); case SlashCommandResult.OrgUnlink: - return await HandleOrgUnlinkAsync(session, argument, ct); + return await HandleOrgUnlinkAsync(session, argument, ct).ConfigureAwait(false); case SlashCommandResult.OrgUnknown: return "Unknown /org subcommand. Available: explain, simulate, status, usage, quota, approve, deny, cancel, set-role, unlink"; @@ -239,7 +239,7 @@ private async Task HandleModelCommandAsync(Session session, string? argu if (string.Equals(argument, "reset", StringComparison.OrdinalIgnoreCase)) { session.ModelOverride = null; - await _handlers.SaveSession.HandleAsync(new SaveSession.Command(session), ct); + await _handlers.SaveSession.HandleAsync(new SaveSession.Command(session), ct).ConfigureAwait(false); return $"Model reset to config default: {_defaults.Model}"; } @@ -252,7 +252,7 @@ private async Task HandleModelCommandAsync(Session session, string? argu return denial; session.ModelOverride = trimmedArg; - await _handlers.SaveSession.HandleAsync(new SaveSession.Command(session), ct); + await _handlers.SaveSession.HandleAsync(new SaveSession.Command(session), ct).ConfigureAwait(false); return $"Model set to: {session.ModelOverride} (for this session)"; } @@ -282,7 +282,7 @@ private async Task HandleGoalsCommandAsync(string? subcommand, Cancellat { try { - var goals = await _analytics.GoalStorage.LoadAsync(ct); + var goals = await _analytics.GoalStorage.LoadAsync(ct).ConfigureAwait(false); var cleared = 0; foreach (var g in goals.Where(g => g.Status == GoalStatus.Active || g.Status == GoalStatus.Paused)) { @@ -291,7 +291,7 @@ private async Task HandleGoalsCommandAsync(string? subcommand, Cancellat cleared++; } - await _analytics.GoalStorage.SaveAsync(goals, ct); + await _analytics.GoalStorage.SaveAsync(goals, ct).ConfigureAwait(false); return cleared > 0 ? $"Cleared {cleared} goal(s)." : "No active or paused goals to clear."; } catch (Exception) @@ -303,7 +303,7 @@ private async Task HandleGoalsCommandAsync(string? subcommand, Cancellat // Default: list active goals try { - var goals = await _analytics.GoalStorage.LoadAsync(ct); + var goals = await _analytics.GoalStorage.LoadAsync(ct).ConfigureAwait(false); var active = goals.Where(g => g.Status == GoalStatus.Active || g.Status == GoalStatus.Paused).ToList(); if (active.Count == 0) { @@ -334,7 +334,7 @@ private async Task HandleListModelsCommandAsync(string? argument, Cancel return "Model listing is currently only available for the OpenRouter provider."; } - var allModels = await modelsProvider.ListModelsAsync(ct); + var allModels = await modelsProvider.ListModelsAsync(ct).ConfigureAwait(false); if (allModels.Count == 0) { return "Unable to fetch models from OpenRouter."; @@ -433,6 +433,9 @@ private Task HandleKnowledgeStatusAsync(CancellationToken ct) /// /// Handles /knowledge ingest — delegates to /// when the knowledge system is enabled; returns an informative message otherwise. + /// No admin gate: knowledge commands are available to all authenticated users because + /// operators control which sources exist via config. Users can only re-trigger ingestion + /// of operator-configured sources, not specify arbitrary paths. /// private Task HandleKnowledgeIngestAsync(string? argument, CancellationToken ct) { diff --git a/src/clawsharp/Core/Pipeline/AgentLoop.Streaming.cs b/src/clawsharp/Core/Pipeline/AgentLoop.Streaming.cs index cec026ec..bc07b40e 100644 --- a/src/clawsharp/Core/Pipeline/AgentLoop.Streaming.cs +++ b/src/clawsharp/Core/Pipeline/AgentLoop.Streaming.cs @@ -5,6 +5,7 @@ using Clawsharp.Channels; using Clawsharp.Cost; using Clawsharp.Providers; +using Clawsharp.Security; using Clawsharp.Telemetry; using Clawsharp.Core.Services; using Clawsharp.Core.Sessions; @@ -72,15 +73,15 @@ private async Task RunStreamingLoopAsync( // Forward text deltas to the channel while consuming. try { - await streamingChannel.StreamAsync(outbound, pipe.Reader.ReadAllAsync(ct), ct); + await streamingChannel.StreamAsync(outbound, pipe.Reader.ReadAllAsync(ct), ct).ConfigureAwait(false); } - catch (Exception ex) + catch (Exception ex) when (ex is not OperationCanceledException) { LogStreamingChannelError(_logger, ex); } // Wait for the producer to finish accumulating tool calls. - var result = await consumeTask; + var result = await consumeTask.ConfigureAwait(false); // ── Telemetry: post-call LLM span enrichment ───────────────── llmActivity?.SetTag(GenAiAttributes.UsageInputTokens, result.InputTokens); @@ -126,23 +127,12 @@ private async Task RunStreamingLoopAsync( // MET-07 / LLM-04: TPOT histogram (average inter-token latency) var tpot = StreamingMetricsHelper.ComputeTpot(result.StreamDuration, result.Ttft ?? TimeSpan.Zero, result.OutputTokens); - if (result.Ttft is not null && tpot is { } tpotValue) + if (result.Ttft is not null && tpot is { } tpotValue && tpotValue >= 0) { ClawsharpMetrics.Tpot.Record(tpotValue, new StreamingMetricTags { Model = normalizedModel, Channel = channelName }); } - // D-12: Token usage + duration (same as non-streaming path) - ClawsharpMetrics.TokenUsage.Record(result.InputTokens, - new GenAiMetricTags { OperationName = "chat", Model = normalizedModel, TokenType = "input" }); - ClawsharpMetrics.TokenUsage.Record(result.OutputTokens, - new GenAiMetricTags { OperationName = "chat", Model = normalizedModel, TokenType = "output" }); - if (result.CacheReadTokens > 0) - ClawsharpMetrics.TokenUsage.Record(result.CacheReadTokens, - new GenAiMetricTags { OperationName = "chat", Model = normalizedModel, TokenType = "cache_read" }); - ClawsharpMetrics.OperationDuration.Record(result.StreamDuration.TotalSeconds, - new GenAiMetricTags { OperationName = "chat", Model = normalizedModel, TokenType = "" }); - // Update session token counts from streaming usage data. session.TotalInputTokens += result.InputTokens; session.TotalOutputTokens += result.OutputTokens; @@ -203,15 +193,21 @@ private async Task RunStreamingLoopAsync( if (toolCalls?.Count > 0) { completedIterations++; - toolCallSummaries ??= []; - foreach (var tc in toolCalls) - { - toolCallSummaries.Add(new ToolCallSummary { Name = tc.Name, ResultLength = tc.ArgumentsJson.Length }); - } // Add the assistant's turn (which may include streaming text + tool calls) to history. messages.Add(new ChatMessage(MessageRole.Assistant, assistantText, ToolCalls: toolCalls)); - await ExecuteToolCallsAsync(toolCalls, messages, ct); + await ExecuteToolCallsAsync(toolCalls, messages, ct).ConfigureAwait(false); + + // Build summaries from actual tool results (last N messages are tool results). + toolCallSummaries ??= []; + for (var i = messages.Count - toolCalls.Count; i < messages.Count; i++) + { + toolCallSummaries.Add(new ToolCallSummary + { + Name = messages[i].Name ?? "unknown", + ResultLength = messages[i].Content?.Length ?? 0 + }); + } request = request with { Messages = messages }; continue; // next streaming iteration @@ -255,8 +251,10 @@ private async Task ConsumeProviderStreamAsync( bool showThinking, CancellationToken ct) { + const int leakScanBufferThreshold = 512; var textSb = new StringBuilder(); var thinkingSb = new StringBuilder(); + var streamLeakBuffer = new StringBuilder(leakScanBufferThreshold); var emittedThinkingOpen = false; var toolBuilders = new Dictionary(); var inputTokens = 0; @@ -275,7 +273,7 @@ private async Task ConsumeProviderStreamAsync( try { - await foreach (var chunk in _fallbackChain.ExecuteStreamAsync(candidates, request, ct, ApplyModelOverride)) + await foreach (var chunk in _fallbackChain.ExecuteStreamAsync(candidates, request, ct, ApplyModelOverride).ConfigureAwait(false)) { switch (chunk) { @@ -290,11 +288,17 @@ private async Task ConsumeProviderStreamAsync( if (emittedThinkingOpen) { emittedThinkingOpen = false; - await pipeWriter.WriteAsync("\n\n\n", ct); + await pipeWriter.WriteAsync("\n\n\n", ct).ConfigureAwait(false); } textSb.Append(td.Delta); - await pipeWriter.WriteAsync(td.Delta, ct); + streamLeakBuffer.Append(td.Delta); + if (streamLeakBuffer.Length >= leakScanBufferThreshold) + { + var scanned = LeakDetector.Scan(streamLeakBuffer.ToString()); + await pipeWriter.WriteAsync(scanned.Redacted, ct).ConfigureAwait(false); + streamLeakBuffer.Clear(); + } break; case ThinkingDeltaChunk tk: @@ -307,10 +311,10 @@ private async Task ConsumeProviderStreamAsync( if (!emittedThinkingOpen) { emittedThinkingOpen = true; - await pipeWriter.WriteAsync("\n", ct); + await pipeWriter.WriteAsync("\n", ct).ConfigureAwait(false); } - await pipeWriter.WriteAsync(tk.Delta, ct); + await pipeWriter.WriteAsync(tk.Delta, ct).ConfigureAwait(false); } break; @@ -367,7 +371,7 @@ private async Task ConsumeProviderStreamAsync( if (emittedThinkingOpen) { emittedThinkingOpen = false; - await pipeWriter.WriteAsync("\n\n\n", ct); + await pipeWriter.WriteAsync("\n\n\n", ct).ConfigureAwait(false); } break; @@ -387,6 +391,11 @@ private async Task ConsumeProviderStreamAsync( } finally { + if (streamLeakBuffer.Length > 0) + { + var scanned = LeakDetector.Scan(streamLeakBuffer.ToString()); + await pipeWriter.WriteAsync(scanned.Redacted, ct).ConfigureAwait(false); + } pipeWriter.Complete(); } @@ -411,19 +420,16 @@ private async Task ConsumeProviderStreamAsync( return null; } - return toolBuilders - .OrderBy(kv => kv.Key) - .Select(kv => - { - var args = "{}"; - if (kv.Value.Args.Length > 0) - { - args = kv.Value.Args.ToString(); - } - - return new ToolCall(kv.Value.Id, kv.Value.Name, args); - }) - .ToList(); + var sortedKeys = toolBuilders.Keys.ToArray(); + Array.Sort(sortedKeys); + var result = new List(sortedKeys.Length); + foreach (var idx in sortedKeys) + { + var (id, name, args) = toolBuilders[idx]; + result.Add(new ToolCall(id, name, args.Length > 0 ? args.ToString() : "{}")); + } + + return result; } /// diff --git a/src/clawsharp/Core/Pipeline/AgentLoop.ToolExecution.cs b/src/clawsharp/Core/Pipeline/AgentLoop.ToolExecution.cs index 7656c92b..d88c1fec 100644 --- a/src/clawsharp/Core/Pipeline/AgentLoop.ToolExecution.cs +++ b/src/clawsharp/Core/Pipeline/AgentLoop.ToolExecution.cs @@ -17,6 +17,11 @@ public sealed partial class AgentLoop /// streaming and non-streaming loops. When multiple tool calls are present, they /// are executed concurrently via and results are /// appended in the original order for deterministic behavior. + /// + /// Tool execution bypasses the mediator pipeline intentionally — authorization + /// and RBAC filtering are enforced directly by at + /// definition-time (GetFilteredDefinitions) and execution-time (ExecuteAsync). + /// /// private async Task ExecuteToolCallsAsync( IReadOnlyList toolCalls, @@ -31,7 +36,7 @@ private async Task ExecuteToolCallsAsync( { var tc = toolCalls[0]; LogToolExecution(_logger, tc.Name, tc.ArgumentsJson[..Math.Min(ToolArgsLogPreviewLength, tc.ArgumentsJson.Length)]); - var result = await _tools.ExecuteAsync(tc.Name, tc.ArgumentsJson, ct); + var result = await _tools.ExecuteAsync(tc.Name, tc.ArgumentsJson, ct).ConfigureAwait(false); result = ApplyToolResultGuard(tc, result, ct); messages.Add(new ChatMessage(MessageRole.Tool, result, ToolCallId: tc.Id, Name: tc.Name)); } @@ -45,7 +50,7 @@ private async Task ExecuteToolCallsAsync( tasks[i] = _tools.ExecuteAsync(tc.Name, tc.ArgumentsJson, ct); } - var results = await Task.WhenAll(tasks); + var results = await Task.WhenAll(tasks).ConfigureAwait(false); for (var i = 0; i < toolCalls.Count; i++) { diff --git a/src/clawsharp/Core/Pipeline/AgentLoop.cs b/src/clawsharp/Core/Pipeline/AgentLoop.cs index 167c5a1f..80afb6bb 100644 --- a/src/clawsharp/Core/Pipeline/AgentLoop.cs +++ b/src/clawsharp/Core/Pipeline/AgentLoop.cs @@ -94,14 +94,14 @@ public sealed partial class AgentLoop // ── Analytics, goals, and fact extraction — grouped behind AnalyticsServices ── private readonly AnalyticsServices _analytics; - /// Lazily-built candidate list for provider fallback. Built once on first use. - private IReadOnlyList<(string Name, IProvider Provider)>? _fallbackCandidates; + /// Pre-built candidate list for provider fallback. Built once in the constructor. + private readonly IReadOnlyList<(string Name, IProvider Provider)> _fallbackCandidates; - /// Lazily-built candidate list filtered to streaming providers only. Built alongside . - private IReadOnlyList<(string Name, IStreamingProvider Provider)>? _streamingFallbackCandidates; + /// Pre-built candidate list filtered to streaming providers only. Built alongside . + private readonly IReadOnlyList<(string Name, IStreamingProvider Provider)> _streamingFallbackCandidates; /// Per-fallback model overrides keyed by provider name. Built alongside . - private Dictionary? _fallbackModelOverrides; + private readonly Dictionary _fallbackModelOverrides; /// Result of a single tool-loop execution (streaming or non-streaming). private sealed record LoopResult( @@ -169,6 +169,8 @@ public AgentLoop( _webhookSlashCommandHandler = webhookSlashCommandHandler; _knowledgeSlashCommandHandler = knowledgeSlashCommandHandler; + (_fallbackCandidates, _streamingFallbackCandidates, _fallbackModelOverrides) = BuildFallbackCandidates(); + // MET-05: active session gauge — reports _sessionPipelines.Count on each scrape ClawsharpMetrics.InitializeSessionGauge(() => _sessionPipelines.Count); } @@ -178,23 +180,28 @@ public async Task RunAsync(IMessageBus bus, CancellationToken ct = default) // Dispatch each inbound message to the owning session's pipeline. // Lazy guarantees StartSessionPipeline runs exactly once per key, // even if multiple threads race on GetOrAdd for the same session. - await foreach (var inbound in bus.ReadAllAsync(ct)) + await foreach (var inbound in bus.ReadAllAsync(ct).ConfigureAwait(false)) { var sessionId = $"{inbound.Channel.Value}:{inbound.SenderId}"; var lazy = _sessionPipelines.GetOrAdd(sessionId, k => new Lazy<(Channel, Task)>(() => StartSessionPipeline(k, ct))); - await lazy.Value.Ch.Writer.WriteAsync(inbound, ct); + await lazy.Value.Ch.Writer.WriteAsync(inbound, ct).ConfigureAwait(false); } // Await all drain tasks so exceptions are observed on shutdown. - // Use a 5-second timeout so in-flight LLM calls don't block exit. + // Use a dedicated 5-second timeout since `ct` is already cancelled at this point. + using var drainCts = new CancellationTokenSource(TimeSpan.FromSeconds(5)); foreach (var kvp in _sessionPipelines) { if (kvp.Value.IsValueCreated) { try { - await kvp.Value.Value.DrainTask.WaitAsync(TimeSpan.FromSeconds(5), ct); + await kvp.Value.Value.DrainTask.WaitAsync(drainCts.Token).ConfigureAwait(false); + } + catch (OperationCanceledException) + { + // 5-second drain window elapsed — abandon remaining in-flight work. } catch (TimeoutException) { @@ -218,13 +225,30 @@ public async Task RunAsync(IMessageBus bus, CancellationToken ct = default) /// Processes all messages for one session in arrival order. /// Runs until the channel is completed or is cancelled. /// + /// How long a session pipeline waits for the next message before self-evicting. + private static readonly TimeSpan SessionIdleTimeout = TimeSpan.FromMinutes(30); + private async Task DrainSessionAsync(string sessionId, ChannelReader reader, CancellationToken ct) { try { - await foreach (var inbound in reader.ReadAllAsync(ct)) + while (!ct.IsCancellationRequested) { - await ProcessMessageAsync(inbound, ct); + using var idleCts = CancellationTokenSource.CreateLinkedTokenSource(ct); + idleCts.CancelAfter(SessionIdleTimeout); + + InboundMessage inbound; + try + { + inbound = await reader.ReadAsync(idleCts.Token).ConfigureAwait(false); + } + catch (OperationCanceledException) when (!ct.IsCancellationRequested) + { + // Idle timeout — evict this session pipeline. + break; + } + + await ProcessMessageAsync(inbound, ct).ConfigureAwait(false); } } finally @@ -292,6 +316,7 @@ internal async Task ProcessMessageAsync(InboundMessage inbound, CancellationToke // Start thinking indicator (best-effort, fire-and-forget style). var thinkingIndicator = channel as IThinkingIndicator; + var messageSw = Stopwatch.StartNew(); try { if (thinkingIndicator is not null) @@ -617,6 +642,10 @@ await channel.SendAsync( } finally { + messageSw.Stop(); + ClawsharpMetrics.MessageDuration.Record(messageSw.Elapsed.TotalSeconds, + new PipelineMetricTags { Channel = inbound.Channel.Value }); + // Stop thinking indicator (best-effort). try { @@ -677,7 +706,7 @@ private async Task RunNonStreamingLoopAsync( response = await _fallbackChain.ExecuteAsync( candidates, (name, provider, token) => provider.ChatAsync(ApplyModelOverride(name, request), token), - ct); + ct).ConfigureAwait(false); } catch (FallbackExhaustedException ex) { @@ -738,15 +767,20 @@ private async Task RunNonStreamingLoopAsync( if (response.ToolCalls?.Count > 0) { completedIterations++; + messages.Add(new ChatMessage(MessageRole.Assistant, response.Content, ToolCalls: response.ToolCalls)); + await ExecuteToolCallsAsync(response.ToolCalls, messages, ct).ConfigureAwait(false); + + // Build summaries from actual tool results (last N messages are tool results). toolCallSummaries ??= []; - foreach (var tc in response.ToolCalls) + for (var i = messages.Count - response.ToolCalls.Count; i < messages.Count; i++) { - toolCallSummaries.Add(new ToolCallSummary { Name = tc.Name, ResultLength = tc.ArgumentsJson.Length }); + toolCallSummaries.Add(new ToolCallSummary + { + Name = messages[i].Name ?? "unknown", + ResultLength = messages[i].Content?.Length ?? 0 + }); } - messages.Add(new ChatMessage(MessageRole.Assistant, response.Content, ToolCalls: response.ToolCalls)); - await ExecuteToolCallsAsync(response.ToolCalls, messages, ct); - request = request with { Messages = messages }; continue; } @@ -768,17 +802,24 @@ private async Task RunNonStreamingLoopAsync( // Fallback candidate management // ────────────────────────────────────────────────────────────────────── + /// + /// Returns the pre-built ordered candidate list for the fallback chain. + /// + private IReadOnlyList<(string Name, IProvider Provider)> GetFallbackCandidates() => _fallbackCandidates; + + /// + /// Returns the pre-built ordered candidate list filtered to streaming providers only. + /// + private IReadOnlyList<(string Name, IStreamingProvider Provider)> GetStreamingFallbackCandidates() => _streamingFallbackCandidates; + /// /// Builds the ordered candidate list for the fallback chain: primary provider first, - /// then each configured fallback provider. Built once and cached. + /// then each configured fallback provider. Called once from the constructor. /// - private IReadOnlyList<(string Name, IProvider Provider)> GetFallbackCandidates() + private (IReadOnlyList<(string Name, IProvider Provider)>, + IReadOnlyList<(string Name, IStreamingProvider Provider)>, + Dictionary) BuildFallbackCandidates() { - if (_fallbackCandidates is not null) - { - return _fallbackCandidates; - } - var candidates = new List<(string Name, IProvider Provider)> { (_defaults.Provider, _provider) @@ -824,27 +865,12 @@ private async Task RunNonStreamingLoopAsync( } } - _fallbackModelOverrides = modelOverrides; - _fallbackCandidates = candidates; - _streamingFallbackCandidates = candidates - .Where(c => c.Provider is IStreamingProvider) - .Select(c => (c.Name, (IStreamingProvider)c.Provider)) - .ToList(); - return _fallbackCandidates; - } - - /// - /// Builds the ordered candidate list filtered to streaming providers only. - /// - private IReadOnlyList<(string Name, IStreamingProvider Provider)> GetStreamingFallbackCandidates() - { - if (_streamingFallbackCandidates is not null) - { - return _streamingFallbackCandidates; - } + var streamingCandidates = candidates + .Where(c => c.Provider is IStreamingProvider) + .Select(c => (c.Name, (IStreamingProvider)c.Provider)) + .ToList(); - GetFallbackCandidates(); - return _streamingFallbackCandidates!; + return (candidates, streamingCandidates, modelOverrides); } /// @@ -853,8 +879,7 @@ private async Task RunNonStreamingLoopAsync( /// private ChatRequest ApplyModelOverride(string candidateName, ChatRequest request) { - if (_fallbackModelOverrides is not null - && _fallbackModelOverrides.TryGetValue(candidateName, out var modelOverride)) + if (_fallbackModelOverrides.TryGetValue(candidateName, out var modelOverride)) { return request with { Model = modelOverride }; } @@ -876,6 +901,12 @@ internal static List MergeConsecutiveRoles(List messag return messages; } + // Fast path: skip allocation when no adjacent same-role messages need merging. + if (!NeedsMerge(messages)) + { + return messages; + } + var result = new List(messages.Count); result.Add(messages[0]); @@ -884,12 +915,16 @@ internal static List MergeConsecutiveRoles(List messag var current = messages[i]; var previous = result[^1]; - // Only merge user<->user or assistant<->assistant (not system, not tool) + // Only merge user<->user or assistant<->assistant (not system, not tool). + // Never merge messages that carry multimodal attachments — the `with` expression + // would silently drop the current message's images/files/videos/audio. if (current.Role == previous.Role && current.Role != MessageRole.System && current.Role != MessageRole.Tool && current.ToolCalls is null // don't merge assistant messages that have tool calls - && previous.ToolCalls is null) + && previous.ToolCalls is null + && !HasAttachments(current) + && !HasAttachments(previous)) { var merged = (previous.Content ?? "") + "\n\n" + (current.Content ?? ""); result[^1] = previous with { Content = merged.Trim() }; @@ -903,6 +938,31 @@ internal static List MergeConsecutiveRoles(List messag return result; } + private static bool NeedsMerge(List messages) + { + for (var i = 1; i < messages.Count; i++) + { + var current = messages[i]; + var previous = messages[i - 1]; + if (current.Role == previous.Role + && current.Role != MessageRole.System + && current.Role != MessageRole.Tool + && current.ToolCalls is null + && previous.ToolCalls is null + && !HasAttachments(current) + && !HasAttachments(previous)) + { + return true; + } + } + + return false; + } + + private static bool HasAttachments(ChatMessage m) => + m.Images is { Count: > 0 } || m.Files is { Count: > 0 } || + m.Videos is { Count: > 0 } || m.Audio is not null; + // ────────────────────────────────────────────────────────────────────── // LoggerMessage declarations // ────────────────────────────────────────────────────────────────────── diff --git a/src/clawsharp/Core/Pipeline/AgentLoopService.cs b/src/clawsharp/Core/Pipeline/AgentLoopService.cs index 429e3616..ffd4f974 100644 --- a/src/clawsharp/Core/Pipeline/AgentLoopService.cs +++ b/src/clawsharp/Core/Pipeline/AgentLoopService.cs @@ -25,7 +25,7 @@ protected override async Task ExecuteAsync(CancellationToken stoppingToken) try { - await agentLoop.RunAsync(bus, stoppingToken); + await agentLoop.RunAsync(bus, stoppingToken).ConfigureAwait(false); } catch (OperationCanceledException) { diff --git a/src/clawsharp/Core/Pipeline/SystemPrompt.cs b/src/clawsharp/Core/Pipeline/SystemPrompt.cs index 9a735518..57bbf56f 100644 --- a/src/clawsharp/Core/Pipeline/SystemPrompt.cs +++ b/src/clawsharp/Core/Pipeline/SystemPrompt.cs @@ -12,7 +12,7 @@ public static string Build( string? memoryContext = null, string? workspaceContext = null, string? channelName = null, - IReadOnlyList? enabledTools = null, + IEnumerable? enabledTools = null, string? activeGoalsContext = null) { var (staticPart, dynamicPart) = BuildSplit(memoryContext, workspaceContext, channelName, enabledTools, activeGoalsContext); @@ -42,7 +42,7 @@ public static (string StaticPart, string DynamicPart) BuildSplit( string? memoryContext = null, string? workspaceContext = null, string? channelName = null, - IReadOnlyList? enabledTools = null, + IEnumerable? enabledTools = null, string? activeGoalsContext = null) { var sb = new StringBuilder(); @@ -56,10 +56,14 @@ public static (string StaticPart, string DynamicPart) BuildSplit( sb.AppendLine("You are clawsharp, a helpful AI assistant running on the user's own hardware."); sb.AppendLine("Be concise, accurate, and helpful. When using tools, prefer the minimum necessary."); - if (enabledTools is { Count: > 0 }) + if (enabledTools is not null) { - sb.AppendLine(); - sb.AppendLine($"Available tools: {string.Join(", ", enabledTools)}"); + var toolList = string.Join(", ", enabledTools); + if (toolList.Length > 0) + { + sb.AppendLine(); + sb.AppendLine($"Available tools: {toolList}"); + } } if (!string.IsNullOrWhiteSpace(memoryContext)) diff --git a/src/clawsharp/Core/Resilience/ChannelResilienceExtensions.cs b/src/clawsharp/Core/Resilience/ChannelResilienceExtensions.cs index 9aa48d90..a02d5f6c 100644 --- a/src/clawsharp/Core/Resilience/ChannelResilienceExtensions.cs +++ b/src/clawsharp/Core/Resilience/ChannelResilienceExtensions.cs @@ -2,6 +2,7 @@ using Microsoft.Extensions.DependencyInjection; using Microsoft.Extensions.Logging; using Polly; +using Polly.Registry; using Polly.Retry; namespace Clawsharp.Core.Resilience; diff --git a/src/clawsharp/Core/Security/AdminRoleFilter.cs b/src/clawsharp/Core/Security/AdminRoleFilter.cs index c1eeae54..e4338920 100644 --- a/src/clawsharp/Core/Security/AdminRoleFilter.cs +++ b/src/clawsharp/Core/Security/AdminRoleFilter.cs @@ -1,3 +1,4 @@ +using Clawsharp.Config.Organization; using Clawsharp.McpServer; using Clawsharp.Organization; using Microsoft.AspNetCore.Http; @@ -10,7 +11,6 @@ namespace Clawsharp.Core.Security; /// stored by that filter from . /// Passes through when: /// - The policy is (single-operator implicit admin), or -/// - The policy has , or /// - The resolved user has at least one role with . /// Returns HTTP 403 (not 401) for authenticated but non-admin users per Pitfall 4. /// Per D-24, D-26 of the v2.3 webhook design. @@ -26,15 +26,13 @@ public sealed class AdminRoleFilter : IEndpointFilter // D-26: Unrestricted policy = single-operator implicit admin if (authResult.PolicyDecision == PolicyDecision.Unrestricted) - return await next(ctx); + return await next(ctx).ConfigureAwait(false); - // IsUnrestrictedToolAccess: granted when any role gives full tool access - if (authResult.PolicyDecision.IsUnrestrictedToolAccess) - return await next(ctx); - - // Check if user has any admin role in resolved policies + // Check if user has any admin role in resolved policies. + // Note: IsUnrestrictedToolAccess alone does NOT grant admin access — "can use all tools" + // is a separate concern from "can administer the system" (CWE-863 mitigation). if (authResult.User?.ResolvedPolicies.Any(p => p.IsAdmin) == true) - return await next(ctx); + return await next(ctx).ConfigureAwait(false); // Authenticated but not admin — return 403 (not 401, not Results.Forbid() which triggers // challenge middleware per research Pitfall 4) diff --git a/src/clawsharp/Core/Security/ApiKeyAuthenticator.cs b/src/clawsharp/Core/Security/ApiKeyAuthenticator.cs index 8d0a0c20..4125df48 100644 --- a/src/clawsharp/Core/Security/ApiKeyAuthenticator.cs +++ b/src/clawsharp/Core/Security/ApiKeyAuthenticator.cs @@ -49,12 +49,20 @@ public ApiKeyAuthenticator( _requireAuth = config?.ApiKeys is not null || oidcService is not null; // Pre-compute UTF-8 bytes for constant-time comparison (Pitfall 3). + // When entry.Secret is set, use it as the bearer token value (separates keyId from credential). + // When entry.Secret is null, fall back to keyId as the bearer token (backward compat, deprecated). _apiKeyBytes = []; if (config?.ApiKeys is not null) { foreach (var (keyId, entry) in config.ApiKeys) { - _apiKeyBytes.Add((Encoding.UTF8.GetBytes(keyId), keyId, entry)); + var secret = entry.Secret ?? keyId; + _apiKeyBytes.Add((Encoding.UTF8.GetBytes(secret), keyId, entry)); + + if (entry.Secret is null) + { + LogApiKeyMissingSecret(_logger, keyId); + } } } } @@ -179,4 +187,9 @@ public bool IsLocalhostBypass(IPAddress? remoteAddress) [LoggerMessage(EventId = 4, Level = LogLevel.Warning, Message = "JWT Bearer validation error: {Error}")] private static partial void LogJwtValidationError(ILogger logger, string error); + + [LoggerMessage(EventId = 5, Level = LogLevel.Warning, + Message = "API key '{KeyId}' uses the dictionary key as the bearer secret (deprecated). " + + "Add a 'secret' field to separate the identifier from the credential.")] + private static partial void LogApiKeyMissingSecret(ILogger logger, string keyId); } diff --git a/src/clawsharp/Core/Security/BearerTokenAuthFilter.cs b/src/clawsharp/Core/Security/BearerTokenAuthFilter.cs index bfc768a8..88401976 100644 --- a/src/clawsharp/Core/Security/BearerTokenAuthFilter.cs +++ b/src/clawsharp/Core/Security/BearerTokenAuthFilter.cs @@ -25,7 +25,7 @@ public sealed class BearerTokenAuthFilter(ApiKeyAuthenticator authenticator) : I if (authenticator.IsLocalhostBypass(httpCtx.Connection.RemoteIpAddress)) { httpCtx.Items[AuthResultKey] = McpServerAuthResult.Success(null, PolicyDecision.Unrestricted, null); - return await next(ctx); + return await next(ctx).ConfigureAwait(false); } var authHeader = httpCtx.Request.Headers.Authorization.ToString(); @@ -40,6 +40,6 @@ public sealed class BearerTokenAuthFilter(ApiKeyAuthenticator authenticator) : I return Results.Unauthorized(); httpCtx.Items[AuthResultKey] = result; - return await next(ctx); + return await next(ctx).ConfigureAwait(false); } } diff --git a/src/clawsharp/Core/Services/CooldownTracker.cs b/src/clawsharp/Core/Services/CooldownTracker.cs index 84331cf6..c6cc9e75 100644 --- a/src/clawsharp/Core/Services/CooldownTracker.cs +++ b/src/clawsharp/Core/Services/CooldownTracker.cs @@ -79,18 +79,33 @@ public void RecordSuccess(string providerName) /// private static TimeSpan ComputeCooldown(FailoverReason reason, int failureCount) { - if (reason == FailoverReason.Billing) + switch (reason) { - var exponent = Math.Min(failureCount - 1, 10); - var hours = 5.0 * Math.Pow(2, exponent); - return TimeSpan.FromHours(Math.Min(hours, 24)); - } + case FailoverReason.Billing: + { + var exponent = Math.Min(failureCount - 1, 10); + var hours = 5.0 * Math.Pow(2, exponent); + return TimeSpan.FromHours(Math.Min(hours, 24)); + } - // Standard backoff: 1 min * 5^min(n-1, 3) - // n=1 → 1m, n=2 → 5m, n=3 → 25m, n=4+ → capped at 60m (1h) - var exp = Math.Min(failureCount - 1, 3); - var minutes = 1.0 * Math.Pow(5, exp); - return TimeSpan.FromMinutes(Math.Min(minutes, 60)); + // Overloaded uses the same standard backoff as RateLimit, Timeout, etc. + // Explicit case so intent is clear — the FailoverReason enum doc says + // "mapped to RateLimit behavior" and this keeps the two aligned. + case FailoverReason.Overloaded: + case FailoverReason.RateLimit: + case FailoverReason.Timeout: + case FailoverReason.Auth: + case FailoverReason.Format: + case FailoverReason.Unknown: + default: + { + // Standard backoff: 1 min * 5^min(n-1, 3) + // n=1 → 1m, n=2 → 5m, n=3 → 25m, n=4+ → capped at 60m (1h) + var exp = Math.Min(failureCount - 1, 3); + var minutes = 1.0 * Math.Pow(5, exp); + return TimeSpan.FromMinutes(Math.Min(minutes, 60)); + } + } } /// Mutable state for a single provider's cooldown tracking. diff --git a/src/clawsharp/Core/Services/CronService.cs b/src/clawsharp/Core/Services/CronService.cs index f023b173..3f2e7f46 100644 --- a/src/clawsharp/Core/Services/CronService.cs +++ b/src/clawsharp/Core/Services/CronService.cs @@ -40,11 +40,11 @@ protected override async Task ExecuteAsync(CancellationToken stoppingToken) { try { - await store.InitAsync(stoppingToken); + await store.InitAsync(stoppingToken).ConfigureAwait(false); - var loaded = await store.LoadAllAsync(stoppingToken); + var loaded = await store.LoadAllAsync(stoppingToken).ConfigureAwait(false); - await _jobsLock.WaitAsync(stoppingToken); + await _jobsLock.WaitAsync(stoppingToken).ConfigureAwait(false); try { foreach (var job in loaded) @@ -87,7 +87,7 @@ protected override async Task ExecuteAsync(CancellationToken stoppingToken) Provider = entry.Provider }; _jobs[id] = job; - await store.UpsertAsync(job, stoppingToken); + await store.UpsertAsync(job, stoppingToken).ConfigureAwait(false); } } @@ -113,7 +113,7 @@ protected override async Task ExecuteAsync(CancellationToken stoppingToken) // Wake up immediately if we already have enabled jobs bool hasEnabledOnStart; - await _jobsLock.WaitAsync(stoppingToken); + await _jobsLock.WaitAsync(stoppingToken).ConfigureAwait(false); try { hasEnabledOnStart = _jobs.Values.Any(j => j.Enabled); @@ -133,7 +133,7 @@ protected override async Task ExecuteAsync(CancellationToken stoppingToken) { try { - await _wakeSignal.WaitAsync(stoppingToken); + await _wakeSignal.WaitAsync(stoppingToken).ConfigureAwait(false); } catch (OperationCanceledException) { @@ -149,12 +149,12 @@ protected override async Task ExecuteAsync(CancellationToken stoppingToken) { // Reload from the backing store so that jobs added externally // (e.g. via CLI `clawsharp cron add`) are picked up without a restart. - await ReloadFromStoreAsync(stoppingToken); + await ReloadFromStoreAsync(stoppingToken).ConfigureAwait(false); - await FireDueJobsAsync(stoppingToken); + await FireDueJobsAsync(stoppingToken).ConfigureAwait(false); bool hasEnabled; - await _jobsLock.WaitAsync(stoppingToken); + await _jobsLock.WaitAsync(stoppingToken).ConfigureAwait(false); try { hasEnabled = _jobs.Values.Any(j => j.Enabled); @@ -170,7 +170,7 @@ protected override async Task ExecuteAsync(CancellationToken stoppingToken) break; } - await Task.Delay(PollIntervalMs, stoppingToken); + await Task.Delay(PollIntervalMs, stoppingToken).ConfigureAwait(false); } catch (OperationCanceledException) { @@ -184,8 +184,8 @@ protected override async Task ExecuteAsync(CancellationToken stoppingToken) public async Task AddJobAsync(CronJob job, CancellationToken ct = default) { - await _initialized.Task.WaitAsync(ct); - await _jobsLock.WaitAsync(ct); + await _initialized.Task.WaitAsync(ct).ConfigureAwait(false); + await _jobsLock.WaitAsync(ct).ConfigureAwait(false); try { _jobs[job.Id] = job; @@ -195,7 +195,7 @@ public async Task AddJobAsync(CronJob job, CancellationToken ct = defau _jobsLock.Release(); } - await store.UpsertAsync(job, ct); + await store.UpsertAsync(job, ct).ConfigureAwait(false); if (job.Enabled) { @@ -208,8 +208,8 @@ public async Task AddJobAsync(CronJob job, CancellationToken ct = defau public async Task> ListJobsAsync(CancellationToken ct = default) { - await _initialized.Task.WaitAsync(ct); - await _jobsLock.WaitAsync(ct); + await _initialized.Task.WaitAsync(ct).ConfigureAwait(false); + await _jobsLock.WaitAsync(ct).ConfigureAwait(false); try { return _jobs.Values.ToList(); @@ -222,9 +222,9 @@ public async Task> ListJobsAsync(CancellationToken ct = d public async Task RemoveJobAsync(string id, CancellationToken ct = default) { - await _initialized.Task.WaitAsync(ct); + await _initialized.Task.WaitAsync(ct).ConfigureAwait(false); bool removed; - await _jobsLock.WaitAsync(ct); + await _jobsLock.WaitAsync(ct).ConfigureAwait(false); try { removed = _jobs.Remove(id); @@ -236,7 +236,7 @@ public async Task RemoveJobAsync(string id, CancellationToken ct = default if (removed) { - await store.DeleteAsync(id, ct); + await store.DeleteAsync(id, ct).ConfigureAwait(false); LogJobRemoved(logger, id); } @@ -245,8 +245,8 @@ public async Task RemoveJobAsync(string id, CancellationToken ct = default public async Task UpdateJobAsync(CronJob job, CancellationToken ct = default) { - await _initialized.Task.WaitAsync(ct); - await _jobsLock.WaitAsync(ct); + await _initialized.Task.WaitAsync(ct).ConfigureAwait(false); + await _jobsLock.WaitAsync(ct).ConfigureAwait(false); try { if (!_jobs.ContainsKey(job.Id)) @@ -261,7 +261,7 @@ public async Task RemoveJobAsync(string id, CancellationToken ct = default _jobsLock.Release(); } - await store.UpsertAsync(job, ct); + await store.UpsertAsync(job, ct).ConfigureAwait(false); if (job.Enabled) { @@ -273,9 +273,9 @@ public async Task RemoveJobAsync(string id, CancellationToken ct = default public async Task RunJobNowAsync(string id, CancellationToken ct = default) { - await _initialized.Task.WaitAsync(ct); + await _initialized.Task.WaitAsync(ct).ConfigureAwait(false); CronJob? job; - await _jobsLock.WaitAsync(ct); + await _jobsLock.WaitAsync(ct).ConfigureAwait(false); try { _jobs.TryGetValue(id, out job); @@ -290,7 +290,7 @@ public async Task RunJobNowAsync(string id, CancellationToken ct = defau return $"No job with id '{id}'."; } - await FireJobAsync(job, ct); + await FireJobAsync(job, ct).ConfigureAwait(false); return $"Fired job '{job.Id}' ({job.Name ?? job.ScheduleExpr})."; } @@ -308,7 +308,7 @@ private async Task ReloadFromStoreAsync(CancellationToken ct) IReadOnlyList stored; try { - stored = await store.LoadAllAsync(ct); + stored = await store.LoadAllAsync(ct).ConfigureAwait(false); } catch (Exception ex) when (!ct.IsCancellationRequested) { @@ -316,7 +316,7 @@ private async Task ReloadFromStoreAsync(CancellationToken ct) return; // proceed with stale in-memory data } - await _jobsLock.WaitAsync(ct); + await _jobsLock.WaitAsync(ct).ConfigureAwait(false); try { var storeIds = new HashSet(stored.Count, StringComparer.Ordinal); @@ -359,7 +359,7 @@ private async Task ReloadFromStoreAsync(CancellationToken ct) private async Task FireDueJobsAsync(CancellationToken ct) { List snapshot; - await _jobsLock.WaitAsync(ct); + await _jobsLock.WaitAsync(ct).ConfigureAwait(false); try { snapshot = _jobs.Values.Where(j => j.Enabled).ToList(); @@ -380,7 +380,7 @@ private async Task FireDueJobsAsync(CancellationToken ct) try { - await FireJobAsync(job, ct); + await FireJobAsync(job, ct).ConfigureAwait(false); } catch (Exception ex) when (!ct.IsCancellationRequested) { @@ -413,13 +413,13 @@ await bus.PublishAsync(new InboundMessage( ArrivedAt: DateTimeOffset.UtcNow, ModelOverride: job.Model, ProviderOverride: job.Provider - ), ct); + ), ct).ConfigureAwait(false); var now = DateTimeOffset.UtcNow; var newCount = job.RunCount + 1; // Update in-memory state under lock - await _jobsLock.WaitAsync(ct); + await _jobsLock.WaitAsync(ct).ConfigureAwait(false); try { if (_jobs.TryGetValue(job.Id, out var current)) @@ -441,7 +441,7 @@ await bus.PublishAsync(new InboundMessage( using var statsCts = new CancellationTokenSource(TimeSpan.FromSeconds(5)); try { - await store.UpdateRunStatsAsync(job.Id, now, newCount, statsCts.Token); + await store.UpdateRunStatsAsync(job.Id, now, newCount, statsCts.Token).ConfigureAwait(false); } catch (Exception ex) { @@ -454,7 +454,7 @@ await bus.PublishAsync(new InboundMessage( try { using var atCts = new CancellationTokenSource(TimeSpan.FromSeconds(5)); - await _jobsLock.WaitAsync(atCts.Token); + await _jobsLock.WaitAsync(atCts.Token).ConfigureAwait(false); CronJob? disabled; try { @@ -467,7 +467,7 @@ await bus.PublishAsync(new InboundMessage( if (disabled is not null) { - await store.UpsertAsync(disabled, atCts.Token); + await store.UpsertAsync(disabled, atCts.Token).ConfigureAwait(false); } } catch (Exception ex) diff --git a/src/clawsharp/Core/Services/FallbackChain.cs b/src/clawsharp/Core/Services/FallbackChain.cs index d6b856b7..32a3addb 100644 --- a/src/clawsharp/Core/Services/FallbackChain.cs +++ b/src/clawsharp/Core/Services/FallbackChain.cs @@ -37,7 +37,7 @@ public async Task ExecuteAsync( try { - var result = await action(name, provider, ct); + var result = await action(name, provider, ct).ConfigureAwait(false); cooldowns.RecordSuccess(name); return result; } @@ -110,7 +110,7 @@ public async IAsyncEnumerable ExecuteStreamAsync( try { enumerator = provider.StreamAsync(effectiveRequest, ct).GetAsyncEnumerator(ct); - hasFirst = await enumerator.MoveNextAsync(); + hasFirst = await enumerator.MoveNextAsync().ConfigureAwait(false); firstChunk = hasFirst ? enumerator.Current : null; } catch (Exception ex) when (ex is not OperationCanceledException) @@ -118,7 +118,7 @@ public async IAsyncEnumerable ExecuteStreamAsync( // Dispose the enumerator on error — it may hold an HTTP connection. if (enumerator is not null) { - await enumerator.DisposeAsync(); + await enumerator.DisposeAsync().ConfigureAwait(false); } var reason = ErrorClassifier.Classify(ex); @@ -150,7 +150,7 @@ public async IAsyncEnumerable ExecuteStreamAsync( // rather than propagated (which would mask the actual stream result). try { - while (hasFirst && await enumerator.MoveNextAsync()) + while (hasFirst && await enumerator.MoveNextAsync().ConfigureAwait(false)) { yield return enumerator.Current; } @@ -159,7 +159,7 @@ public async IAsyncEnumerable ExecuteStreamAsync( { try { - await enumerator.DisposeAsync(); + await enumerator.DisposeAsync().ConfigureAwait(false); } catch (Exception disposeEx) { diff --git a/src/clawsharp/Core/Services/HeartbeatService.cs b/src/clawsharp/Core/Services/HeartbeatService.cs index f077a143..8387bd0d 100644 --- a/src/clawsharp/Core/Services/HeartbeatService.cs +++ b/src/clawsharp/Core/Services/HeartbeatService.cs @@ -56,7 +56,7 @@ protected override async Task ExecuteAsync(CancellationToken stoppingToken) { // Sleep 10 seconds then check if the cron expression matches the current minute. // This mirrors the polling approach used by CronService. - await Task.Delay(PollIntervalMs, stoppingToken); + await Task.Delay(PollIntervalMs, stoppingToken).ConfigureAwait(false); // Heartbeat cron schedule is evaluated against the machine's local time // (DateTimeOffset.Now), NOT UTC. This matches user expectations for schedules @@ -77,7 +77,7 @@ protected override async Task ExecuteAsync(CancellationToken stoppingToken) Volatile.Write(ref _lastFiredMinuteTicks, truncatedTicks); - var prompt = await ReadPromptFileAsync(stoppingToken); + var prompt = await ReadPromptFileAsync(stoppingToken).ConfigureAwait(false); LogHeartbeatFiring(_logger, _heartbeatConfig.Channel, prompt.Length); await _bus.PublishAsync(new InboundMessage( @@ -87,7 +87,7 @@ await _bus.PublishAsync(new InboundMessage( Text: prompt, ArrivedAt: DateTimeOffset.UtcNow, IsHeartbeat: true - ), stoppingToken); + ), stoppingToken).ConfigureAwait(false); } catch (OperationCanceledException) { @@ -132,7 +132,7 @@ private async Task ReadPromptFileAsync(CancellationToken ct) try { - var content = await File.ReadAllTextAsync(resolved, ct); + var content = await File.ReadAllTextAsync(resolved, ct).ConfigureAwait(false); if (!string.IsNullOrWhiteSpace(content)) { return content.Trim(); diff --git a/src/clawsharp/Core/Services/LifecycleBackgroundService.cs b/src/clawsharp/Core/Services/LifecycleBackgroundService.cs index 083c3036..d2b858bc 100644 --- a/src/clawsharp/Core/Services/LifecycleBackgroundService.cs +++ b/src/clawsharp/Core/Services/LifecycleBackgroundService.cs @@ -44,7 +44,7 @@ public virtual async Task StopAsync(CancellationToken cancellationToken) try { - await _cts!.CancelAsync(); + await _cts.CancelAsync().ConfigureAwait(false); } finally { diff --git a/src/clawsharp/Core/Sessions/SessionStore.cs b/src/clawsharp/Core/Sessions/SessionStore.cs index e92d2488..f1e3d8d7 100644 --- a/src/clawsharp/Core/Sessions/SessionStore.cs +++ b/src/clawsharp/Core/Sessions/SessionStore.cs @@ -1,6 +1,7 @@ using System.Security.Cryptography; using System.Text; using System.Text.Json; +using Clawsharp.Core.Utilities; using Microsoft.Extensions.Logging; namespace Clawsharp.Core.Sessions; @@ -19,9 +20,9 @@ public SessionStore(ILogger logger) { _logger = logger; var root = Config.ConfigLoader.ExpandHome("~/.clawsharp"); - Directory.CreateDirectory(root); + FilePermissions.EnsureRestrictedDirectory(root); _dir = Path.Combine(root, "sessions"); - Directory.CreateDirectory(_dir); + FilePermissions.EnsureRestrictedDirectory(_dir); } /// Test-only constructor with custom sessions directory. @@ -43,7 +44,7 @@ public async Task LoadOrCreateAsync(string sessionId, CancellationToken try { await using var stream = File.OpenRead(path); - var session = await JsonSerializer.DeserializeAsync(stream, SessionJsonContext.Default.Session, ct); + var session = await JsonSerializer.DeserializeAsync(stream, SessionJsonContext.Default.Session, ct).ConfigureAwait(false); return session ?? new Session { Id = sessionId }; } catch (Exception ex) when (ex is JsonException or IOException) @@ -61,11 +62,12 @@ public async Task SaveAsync(Session session, CancellationToken ct = default) { await using (var stream = File.Create(tmp)) { - await JsonSerializer.SerializeAsync(stream, session, SessionJsonContext.Default.Session, ct); - await stream.FlushAsync(ct); + await JsonSerializer.SerializeAsync(stream, session, SessionJsonContext.Default.Session, ct).ConfigureAwait(false); + await stream.FlushAsync(ct).ConfigureAwait(false); } File.Move(tmp, path, true); + FilePermissions.SetRestrictedFilePermissions(path); } catch { @@ -86,6 +88,8 @@ public async Task SaveAsync(Session session, CancellationToken ct = default) /// Builds a safe filesystem path for the given session ID. /// Uses for reversible, collision-free encoding. /// Falls back to a truncated SHA-256 hash if the encoded name exceeds 200 characters. + /// The 16-character (8-byte) hash prefix has a ~2^32 collision threshold — acceptable + /// for a personal assistant but would need a longer prefix for multi-tenant deployments. /// internal string SessionPath(string sessionId) { diff --git a/src/clawsharp/Core/Transcription/VoiceTranscriptionService.cs b/src/clawsharp/Core/Transcription/VoiceTranscriptionService.cs index e70b21f4..5f5fcf2f 100644 --- a/src/clawsharp/Core/Transcription/VoiceTranscriptionService.cs +++ b/src/clawsharp/Core/Transcription/VoiceTranscriptionService.cs @@ -153,7 +153,7 @@ public VoiceTranscriptionService( req.Headers.Authorization = new AuthenticationHeaderValue("Bearer", _apiKey); req.Content = form; - using var resp = await _http!.SendAsync(req, ct).ConfigureAwait(false); + using var resp = await _http.SendAsync(req, ct).ConfigureAwait(false); if (!resp.IsSuccessStatusCode) { var errorBody = await ReadErrorBodyAsync(resp, ct).ConfigureAwait(false); @@ -200,7 +200,7 @@ public VoiceTranscriptionService( req.Headers.Add("Ocp-Apim-Subscription-Key", _apiKey); req.Content = form; - using var resp = await _http!.SendAsync(req, ct).ConfigureAwait(false); + using var resp = await _http.SendAsync(req, ct).ConfigureAwait(false); if (!resp.IsSuccessStatusCode) { var errorBody = await ReadErrorBodyAsync(resp, ct).ConfigureAwait(false); @@ -302,12 +302,10 @@ public VoiceTranscriptionService( }, }; - var bodyJson = JsonSerializer.Serialize( - reqBody, VoiceTranscriptJsonContext.Default.GcpSpeechRequest); - using var content = new StringContent(bodyJson, Encoding.UTF8, "application/json"); + using var content = Utf8JsonContent.Create(reqBody, VoiceTranscriptJsonContext.Default.GcpSpeechRequest); var url = $"{_gcpUrl}?key={Uri.EscapeDataString(_apiKey!)}"; - using var resp = await _http!.PostAsync(url, content, ct).ConfigureAwait(false); + using var resp = await _http.PostAsync(url, content, ct).ConfigureAwait(false); if (!resp.IsSuccessStatusCode) { var errorBody = await ReadErrorBodyAsync(resp, ct).ConfigureAwait(false); diff --git a/src/clawsharp/Core/Utilities/FilePermissions.cs b/src/clawsharp/Core/Utilities/FilePermissions.cs new file mode 100644 index 00000000..74a90f24 --- /dev/null +++ b/src/clawsharp/Core/Utilities/FilePermissions.cs @@ -0,0 +1,34 @@ +namespace Clawsharp.Core.Utilities; + +/// +/// Enforces restrictive Unix file permissions (owner-only) on data directories and files. +/// No-op on Windows where Unix file modes are not supported. +/// +internal static class FilePermissions +{ + /// + /// Creates the directory (if needed) and restricts it to owner rwx (0700) on Unix. + /// + internal static void EnsureRestrictedDirectory(string path) + { + Directory.CreateDirectory(path); + if (!OperatingSystem.IsWindows()) + { + File.SetUnixFileMode(path, + UnixFileMode.UserRead | UnixFileMode.UserWrite | UnixFileMode.UserExecute); + } + } + + /// + /// Restricts an existing file to owner rw (0600) on Unix. + /// No-op if the file does not exist or the OS is Windows. + /// + internal static void SetRestrictedFilePermissions(string path) + { + if (!OperatingSystem.IsWindows() && File.Exists(path)) + { + File.SetUnixFileMode(path, + UnixFileMode.UserRead | UnixFileMode.UserWrite); + } + } +} diff --git a/src/clawsharp/Core/Utilities/JsonContent.cs b/src/clawsharp/Core/Utilities/JsonContent.cs new file mode 100644 index 00000000..864ad1e8 --- /dev/null +++ b/src/clawsharp/Core/Utilities/JsonContent.cs @@ -0,0 +1,59 @@ +using System.Net.Http.Headers; +using System.Text; +using System.Text.Json; +using System.Text.Json.Serialization.Metadata; + +namespace Clawsharp.Core.Utilities; + +/// +/// Creates from JSON data using UTF-8 bytes directly, +/// avoiding the double-encoding overhead of +/// (which accepts a UTF-16 string and re-encodes it to UTF-8). +/// Named Utf8JsonContent to avoid collision with . +/// +internal static class Utf8JsonContent +{ + /// + /// Serializes directly to UTF-8 bytes using the provided + /// source-generated , then wraps in . + /// + public static HttpContent Create(T value, JsonTypeInfo typeInfo) + { + var bytes = JsonSerializer.SerializeToUtf8Bytes(value, typeInfo); + return Wrap(bytes); + } + + /// + /// Serializes directly to UTF-8 bytes using the provided + /// non-generic , then wraps in . + /// Useful for patterns where the type info is dynamically typed. + /// + public static HttpContent Create(object value, JsonTypeInfo typeInfo) + { + var bytes = JsonSerializer.SerializeToUtf8Bytes(value, typeInfo); + return Wrap(bytes); + } + + /// + /// Wraps a pre-serialized JSON string as UTF-8 , + /// avoiding the intermediate UTF-16 re-encoding that performs. + /// + public static HttpContent FromString(string json) + { + var bytes = Encoding.UTF8.GetBytes(json); + return Wrap(bytes); + } + + /// + /// Wraps pre-serialized UTF-8 JSON bytes as . + /// Useful when the same bytes must be sent multiple times (e.g., retry after re-login). + /// + public static HttpContent FromUtf8Bytes(byte[] jsonBytes) => Wrap(jsonBytes); + + private static ReadOnlyMemoryContent Wrap(byte[] bytes) + { + var content = new ReadOnlyMemoryContent(bytes); + content.Headers.ContentType = new MediaTypeHeaderValue("application/json") { CharSet = "utf-8" }; + return content; + } +} diff --git a/src/clawsharp/Core/Utilities/JsonFileStore.cs b/src/clawsharp/Core/Utilities/JsonFileStore.cs index 67e97bc4..974c4f47 100644 --- a/src/clawsharp/Core/Utilities/JsonFileStore.cs +++ b/src/clawsharp/Core/Utilities/JsonFileStore.cs @@ -8,7 +8,7 @@ namespace Clawsharp.Core.Utilities; /// Designed for small configuration/state files (pairing codes, approved senders, etc.). /// /// Thread safety: all operations acquire a before touching the file. -/// Atomic writes: data is written to a .tmp file first, then moved into place via . +/// Atomic writes: data is written to a .tmp file first, then moved into place via . /// /// The type to serialize/deserialize. Must have a source-generated . public sealed class JsonFileStore : IDisposable where T : class, new() diff --git a/src/clawsharp/Cost/CostStorage.cs b/src/clawsharp/Cost/CostStorage.cs index 8612f499..b7c95199 100644 --- a/src/clawsharp/Cost/CostStorage.cs +++ b/src/clawsharp/Cost/CostStorage.cs @@ -1,5 +1,6 @@ using System.Text.Json; using Clawsharp.Config; +using Clawsharp.Core.Utilities; namespace Clawsharp.Cost; @@ -25,7 +26,7 @@ public sealed class CostStorage public CostStorage() { var dir = ConfigLoader.ExpandHome("~/.clawsharp"); - Directory.CreateDirectory(dir); + FilePermissions.EnsureRestrictedDirectory(dir); _filePath = Path.Combine(dir, "costs.jsonl"); } @@ -48,7 +49,7 @@ public async Task AppendAsync(CostRecord record, CancellationToken ct = default) await _writeLock.WaitAsync(ct).ConfigureAwait(false); try { - await File.AppendAllTextAsync(_filePath, json + "\n", ct).ConfigureAwait(false); + await File.AppendAllLinesAsync(_filePath, [json], ct).ConfigureAwait(false); // Invalidate the cache — next ReadAllAsync will re-read the file lock (_cacheLock) @@ -67,6 +68,13 @@ public async Task AppendAsync(CostRecord record, CancellationToken ct = default) /// Uses a simple in-memory cache that is invalidated when a new record is written /// or when the file's last-write time changes (e.g., external edits). /// + /// + /// No lock is held across File.Exists, GetLastWriteTimeUtc, and ReadLinesAsync. + /// External file manipulation between these calls could yield stale or empty results. + /// This is acceptable because the file is in ~/.clawsharp/ under user control and no + /// external process is expected to modify it during operation. Write-side serialization + /// via _writeLock ensures internal consistency. + /// public async Task> ReadAllAsync(CancellationToken ct = default) { if (!File.Exists(_filePath)) diff --git a/src/clawsharp/Cost/CostTracker.cs b/src/clawsharp/Cost/CostTracker.cs index 83b4c52e..bf7aa6e5 100644 --- a/src/clawsharp/Cost/CostTracker.cs +++ b/src/clawsharp/Cost/CostTracker.cs @@ -21,6 +21,8 @@ public sealed partial class CostTracker( // Global in-memory aggregation (backward compat) private decimal _dailyTotal; private decimal _monthlyTotal; + private decimal _dailySavings; + private decimal _monthlySavings; // Per-scope aggregation via ConcurrentDictionary // Scope key format: "global", "user:{name}", "dept:{name}" @@ -41,6 +43,12 @@ public sealed partial class CostTracker( /// after all exceed checks pass. Returns extended /// with per-scope status via . /// + /// + /// The lock is released before the LLM call runs, creating a check-then-act window. + /// Concurrent requests can exceed limits by up to N * estimatedCost where N is concurrency depth. + /// This is an intentional trade-off: strict enforcement would serialize all concurrent requests. + /// Real-world overspend is bounded to fractions of a cent for typical usage patterns. + /// public async Task CheckBudgetAsync( decimal estimatedCost, string? userId = null, @@ -55,13 +63,28 @@ public async Task CheckBudgetAsync( } decimal dailySnapshot, monthlySnapshot; - await _lock.WaitAsync(ct); + ScopeBudgetStatus? userStatus = null; + ScopeBudgetStatus? deptStatus = null; + + await _lock.WaitAsync(ct).ConfigureAwait(false); try { - await EnsureInitializedAsync(ct); + await EnsureInitializedAsync(ct).ConfigureAwait(false); CheckDayMonthBoundary(); dailySnapshot = _dailyTotal; monthlySnapshot = _monthlyTotal; + + // Snapshot per-scope totals inside the lock to avoid TOCTOU between + // the global and per-scope reads (M-12). + if (userId is not null && userBudget is not null) + { + userStatus = EvaluateScope($"user:{userId}", userBudget, estimatedCost); + } + + if (departmentId is not null && deptBudget is not null) + { + deptStatus = EvaluateScope($"dept:{departmentId}", deptBudget, estimatedCost); + } } finally { @@ -94,36 +117,26 @@ public async Task CheckBudgetAsync( } // --- Check per-user scope --- - ScopeBudgetStatus? userStatus = null; - if (userId is not null && userBudget is not null) + if (userStatus is { Status: BudgetStatus.Exceeded }) { - userStatus = EvaluateScope($"user:{userId}", userBudget, estimatedCost); - if (userStatus.Status == BudgetStatus.Exceeded) - { - return new BudgetCheckResult( - BudgetStatus.Exceeded, - $"User daily budget exceeded: ${userStatus.DailyUsed:F4} / ${userStatus.DailyLimit:F2}", - dailySnapshot, - monthlySnapshot, - UserBudget: userStatus); - } + return new BudgetCheckResult( + BudgetStatus.Exceeded, + $"User daily budget exceeded: ${userStatus.DailyUsed:F4} / ${userStatus.DailyLimit:F2}", + dailySnapshot, + monthlySnapshot, + UserBudget: userStatus); } // --- Check per-department scope --- - ScopeBudgetStatus? deptStatus = null; - if (departmentId is not null && deptBudget is not null) + if (deptStatus is { Status: BudgetStatus.Exceeded }) { - deptStatus = EvaluateScope($"dept:{departmentId}", deptBudget, estimatedCost); - if (deptStatus.Status == BudgetStatus.Exceeded) - { - return new BudgetCheckResult( - BudgetStatus.Exceeded, - $"{departmentId} department monthly budget exhausted (${deptStatus.MonthlyLimit:F2}). Contact your admin.", - dailySnapshot, - monthlySnapshot, - UserBudget: userStatus, - DepartmentBudget: deptStatus); - } + return new BudgetCheckResult( + BudgetStatus.Exceeded, + $"{departmentId} department monthly budget exhausted (${deptStatus.MonthlyLimit:F2}). Contact your admin.", + dailySnapshot, + monthlySnapshot, + UserBudget: userStatus, + DepartmentBudget: deptStatus); } // --- Collect warnings from all scopes --- @@ -277,11 +290,8 @@ public async Task RecordUsageAsync( cost = reportedCost; } - var cacheSavings = 0.0m; - if (savings > 0) - { - cacheSavings = savings; - } + // Cache savings can be negative when Anthropic cache write premiums exceed read discounts. + var cacheSavings = savings; var record = new CostRecord { @@ -299,14 +309,16 @@ public async Task RecordUsageAsync( DepartmentId = departmentId, }; - await storage.AppendAsync(record, ct); + await storage.AppendAsync(record, ct).ConfigureAwait(false); - await _lock.WaitAsync(ct); + await _lock.WaitAsync(ct).ConfigureAwait(false); try { CheckDayMonthBoundary(); _dailyTotal += cost; _monthlyTotal += cost; + _dailySavings += cacheSavings; + _monthlySavings += cacheSavings; } finally { @@ -343,59 +355,54 @@ public async Task RecordUsageAsync( } /// Get current cost and cache-savings summary. Optionally filter by session. + /// + /// Daily, Monthly, DailySavings, and MonthlySavings are from the in-memory snapshot taken under lock. + /// When is provided, Session and SessionSavings are computed from a + /// disk scan that may include records written after the snapshot. Session totals may therefore + /// diverge from daily/monthly totals by the cost of one in-flight request. Budget enforcement + /// uses its own locking path and is unaffected. + /// public async Task GetSummaryAsync( string? sessionId = null, CancellationToken ct = default) { - await _lock.WaitAsync(ct); + await _lock.WaitAsync(ct).ConfigureAwait(false); decimal daily; decimal monthly; + decimal dailySavingsSnapshot; + decimal monthlySavingsSnapshot; try { - await EnsureInitializedAsync(ct); + await EnsureInitializedAsync(ct).ConfigureAwait(false); CheckDayMonthBoundary(); daily = _dailyTotal; monthly = _monthlyTotal; + dailySavingsSnapshot = _dailySavings; + monthlySavingsSnapshot = _monthlySavings; } finally { _lock.Release(); } - // Savings and session totals are not tracked in memory -- scan disk for those. + // Session totals require a sessionId filter — scan disk only when requested. var session = 0.0m; - var dailySavings = 0.0m; - var monthlySavings = 0.0m; var sessionSavings = 0.0m; - var records = await storage.ReadAllAsync(ct); - var now = DateTimeOffset.UtcNow; - var todayUtc = DateOnly.FromDateTime(now.UtcDateTime); - - foreach (var r in records) + if (sessionId is not null) { - var recordDate = DateOnly.FromDateTime(r.Timestamp.UtcDateTime); - - if (recordDate == todayUtc) + var records = await storage.ReadAllAsync(ct).ConfigureAwait(false); + foreach (var r in records) { - dailySavings += r.CacheSavingsUsd; - } - - if (r.Timestamp.UtcDateTime.Year == now.UtcDateTime.Year && - r.Timestamp.UtcDateTime.Month == now.UtcDateTime.Month) - { - monthlySavings += r.CacheSavingsUsd; - } - - if (sessionId is not null && - string.Equals(r.SessionId, sessionId, StringComparison.Ordinal)) - { - session += r.CostUsd; - sessionSavings += r.CacheSavingsUsd; + if (string.Equals(r.SessionId, sessionId, StringComparison.Ordinal)) + { + session += r.CostUsd; + sessionSavings += r.CacheSavingsUsd; + } } } - return new CostSummary(daily, monthly, session, dailySavings, monthlySavings, sessionSavings); + return new CostSummary(daily, monthly, session, dailySavingsSnapshot, monthlySavingsSnapshot, sessionSavings); } /// Detect day/month boundary crossings and reset in-memory aggregates. @@ -407,12 +414,14 @@ private void CheckDayMonthBoundary() if (todayUtc != _currentDay) { _dailyTotal = 0; + _dailySavings = 0; _dailyTotals.Clear(); _currentDay = todayUtc; if (now.UtcDateTime.Year != _currentYear || now.UtcDateTime.Month != _currentMonth) { _monthlyTotal = 0; + _monthlySavings = 0; _monthlyTotals.Clear(); _currentMonth = now.UtcDateTime.Month; _currentYear = now.UtcDateTime.Year; @@ -428,7 +437,7 @@ private async Task EnsureInitializedAsync(CancellationToken ct) return; } - var records = await storage.ReadAllAsync(ct); + var records = await storage.ReadAllAsync(ct).ConfigureAwait(false); var now = DateTimeOffset.UtcNow; var todayUtc = DateOnly.FromDateTime(now.UtcDateTime); _currentDay = todayUtc; @@ -445,6 +454,7 @@ private async Task EnsureInitializedAsync(CancellationToken ct) if (isToday) { _dailyTotal += r.CostUsd; + _dailySavings += r.CacheSavingsUsd; // Per-scope daily aggregation from JSONL if (r.UserId is not null) @@ -463,6 +473,7 @@ private async Task EnsureInitializedAsync(CancellationToken ct) if (isThisMonth) { _monthlyTotal += r.CostUsd; + _monthlySavings += r.CacheSavingsUsd; // Per-scope monthly aggregation from JSONL if (r.UserId is not null) diff --git a/src/clawsharp/Cost/DefaultPricing.cs b/src/clawsharp/Cost/DefaultPricing.cs index 299d477d..8e4ab025 100644 --- a/src/clawsharp/Cost/DefaultPricing.cs +++ b/src/clawsharp/Cost/DefaultPricing.cs @@ -25,18 +25,53 @@ public static class DefaultPricing ["claude-3-haiku"] = (0.25m, 1.25m), // OpenAI (prefixed) - ["openai/gpt-4o"] = (5.00m, 15.00m), + ["openai/gpt-4o"] = (2.50m, 10.00m), ["openai/gpt-4o-mini"] = (0.15m, 0.60m), - ["openai/o1-preview"] = (15.00m, 60.00m), + ["openai/gpt-4.1"] = (2.00m, 8.00m), + ["openai/gpt-4.1-mini"] = (0.40m, 1.60m), + ["openai/gpt-4.1-nano"] = (0.10m, 0.40m), + ["openai/gpt-5"] = (1.25m, 10.00m), + ["openai/gpt-5-mini"] = (0.25m, 2.00m), + ["openai/gpt-5-nano"] = (0.05m, 0.40m), + ["openai/gpt-5-pro"] = (15.00m, 120.00m), + ["openai/gpt-5.1"] = (1.25m, 10.00m), + ["openai/gpt-5.2"] = (1.75m, 14.00m), + ["openai/gpt-5.4"] = (2.50m, 15.00m), + ["openai/gpt-5.4-mini"] = (0.75m, 4.50m), + ["openai/gpt-5.4-nano"] = (0.20m, 1.25m), + ["openai/gpt-5.4-pro"] = (30.00m, 180.00m), + ["openai/o1"] = (15.00m, 60.00m), + ["openai/o1-mini"] = (1.10m, 4.40m), + ["openai/o1-pro"] = (150.00m, 600.00m), + ["openai/o3"] = (2.00m, 8.00m), + ["openai/o3-mini"] = (1.10m, 4.40m), + ["openai/o3-pro"] = (20.00m, 80.00m), + ["openai/o4-mini"] = (1.10m, 4.40m), // OpenAI (bare) - ["gpt-4o"] = (5.00m, 15.00m), + ["gpt-4o"] = (2.50m, 10.00m), ["gpt-4o-mini"] = (0.15m, 0.60m), ["gpt-4.1"] = (2.00m, 8.00m), ["gpt-4.1-mini"] = (0.40m, 1.60m), - ["gpt-5.2"] = (5.00m, 15.00m), - ["o1-preview"] = (15.00m, 60.00m), + ["gpt-4.1-nano"] = (0.10m, 0.40m), + ["gpt-5"] = (1.25m, 10.00m), + ["gpt-5-mini"] = (0.25m, 2.00m), + ["gpt-5-nano"] = (0.05m, 0.40m), + ["gpt-5-pro"] = (15.00m, 120.00m), + ["gpt-5.1"] = (1.25m, 10.00m), + ["gpt-5.2"] = (1.75m, 14.00m), + ["gpt-5.2-pro"] = (21.00m, 168.00m), + ["gpt-5.4"] = (2.50m, 15.00m), + ["gpt-5.4-mini"] = (0.75m, 4.50m), + ["gpt-5.4-nano"] = (0.20m, 1.25m), + ["gpt-5.4-pro"] = (30.00m, 180.00m), + ["o1"] = (15.00m, 60.00m), + ["o1-mini"] = (1.10m, 4.40m), + ["o1-pro"] = (150.00m, 600.00m), + ["o3"] = (2.00m, 8.00m), ["o3-mini"] = (1.10m, 4.40m), + ["o3-pro"] = (20.00m, 80.00m), + ["o4-mini"] = (1.10m, 4.40m), // Google (prefixed) ["google/gemini-2.0-flash"] = (0.10m, 0.40m), @@ -111,10 +146,10 @@ public static class DefaultPricing ["kimi-k2-thinking"] = (0.60m, 2.50m), // MiniMax - ["MiniMax-Text-01"] = (0.20m, 1.10m), - ["MiniMax-M2"] = (0.255m, 1.00m), - ["MiniMax-M2.1"] = (0.27m, 0.95m), - ["MiniMax-M2.5"] = (0.295m, 1.20m), + ["minimax-text-01"] = (0.20m, 1.10m), + ["minimax-m2"] = (0.255m, 1.00m), + ["minimax-m2.1"] = (0.27m, 0.95m), + ["minimax-m2.5"] = (0.295m, 1.20m), // VolcEngine / ByteDance Doubao ["doubao-1-5-pro-32k-250115"] = (0.11m, 0.28m), @@ -130,7 +165,7 @@ public static class DefaultPricing ["Qwen/Qwen2.5-72B-Instruct"] = (0.40m, 1.20m), ["Qwen/Qwen2.5-7B-Instruct"] = (0.05m, 0.20m), ["moonshotai/Kimi-K2-Instruct"] = (0.58m, 2.29m), - ["MiniMaxAI/MiniMax-M2.5"] = (0.30m, 1.20m), + ["minimaxai/minimax-m2.5"] = (0.30m, 1.20m), ["meta-llama/Meta-Llama-3.1-8B-Instruct"] = (0.06m, 0.06m), }.ToFrozenDictionary(StringComparer.OrdinalIgnoreCase); @@ -153,7 +188,7 @@ public static (decimal InputPer1M, decimal OutputPer1M) GetPrice(string model) /// Calculate the USD cost for a given model and token counts. /// Returns 0.0 for unknown models. /// - public static decimal CalculateCost(string model, int inputTokens, int outputTokens) + public static decimal CalculateCost(string model, long inputTokens, long outputTokens) { var (inputPer1M, outputPer1M) = GetPrice(model); if (inputPer1M == 0 && outputPer1M == 0) @@ -171,8 +206,8 @@ public static decimal CalculateCost(string model, int inputTokens, int outputTok /// public static decimal CalculateCost( string model, - int inputTokens, - int outputTokens, + long inputTokens, + long outputTokens, IReadOnlyDictionary? overrides) { if (overrides is not null && diff --git a/src/clawsharp/Cron/JsonCronStore.cs b/src/clawsharp/Cron/JsonCronStore.cs index b8016e0b..ff37aa35 100644 --- a/src/clawsharp/Cron/JsonCronStore.cs +++ b/src/clawsharp/Cron/JsonCronStore.cs @@ -19,10 +19,10 @@ public Task InitAsync(CancellationToken ct = default) public async Task> LoadAllAsync(CancellationToken ct = default) { - await _lock.WaitAsync(ct); + await _lock.WaitAsync(ct).ConfigureAwait(false); try { - return await ReadAsync(ct); + return await ReadAsync(ct).ConfigureAwait(false); } finally { @@ -32,10 +32,10 @@ public async Task> LoadAllAsync(CancellationToken ct = de public async Task UpsertAsync(CronJob job, CancellationToken ct = default) { - await _lock.WaitAsync(ct); + await _lock.WaitAsync(ct).ConfigureAwait(false); try { - var jobs = await ReadAsync(ct); + var jobs = await ReadAsync(ct).ConfigureAwait(false); var idx = jobs.FindIndex(j => j.Id == job.Id); if (idx >= 0) { @@ -46,7 +46,7 @@ public async Task UpsertAsync(CronJob job, CancellationToken ct = default) jobs.Add(job); } - await WriteAsync(jobs, ct); + await WriteAsync(jobs, ct).ConfigureAwait(false); } finally { @@ -56,12 +56,12 @@ public async Task UpsertAsync(CronJob job, CancellationToken ct = default) public async Task DeleteAsync(string id, CancellationToken ct = default) { - await _lock.WaitAsync(ct); + await _lock.WaitAsync(ct).ConfigureAwait(false); try { - var jobs = await ReadAsync(ct); + var jobs = await ReadAsync(ct).ConfigureAwait(false); jobs.RemoveAll(j => j.Id == id); - await WriteAsync(jobs, ct); + await WriteAsync(jobs, ct).ConfigureAwait(false); } finally { @@ -71,16 +71,16 @@ public async Task DeleteAsync(string id, CancellationToken ct = default) public async Task UpdateRunStatsAsync(string id, DateTimeOffset lastRunAt, int runCount, CancellationToken ct = default) { - await _lock.WaitAsync(ct); + await _lock.WaitAsync(ct).ConfigureAwait(false); try { - var jobs = await ReadAsync(ct); + var jobs = await ReadAsync(ct).ConfigureAwait(false); var job = jobs.Find(j => j.Id == id); if (job is not null) { job.LastRunAt = lastRunAt; job.RunCount = runCount; - await WriteAsync(jobs, ct); + await WriteAsync(jobs, ct).ConfigureAwait(false); } } finally @@ -98,7 +98,7 @@ private async Task> ReadAsync(CancellationToken ct) try { - var json = await File.ReadAllTextAsync(_filePath, ct); + var json = await File.ReadAllTextAsync(_filePath, ct).ConfigureAwait(false); return JsonSerializer.Deserialize(json, CronJsonContext.WithConverters.ListCronJob) ?? []; } catch (Exception ex) diff --git a/src/clawsharp/Cron/MssqlCronStore.cs b/src/clawsharp/Cron/MssqlCronStore.cs index 91b8cbda..a0bdd4b7 100644 --- a/src/clawsharp/Cron/MssqlCronStore.cs +++ b/src/clawsharp/Cron/MssqlCronStore.cs @@ -7,7 +7,7 @@ public sealed class MssqlCronStore(string connectionString) : ICronStore public async Task InitAsync(CancellationToken ct = default) { await using var conn = new SqlConnection(connectionString); - await conn.OpenAsync(ct); + await conn.OpenAsync(ct).ConfigureAwait(false); await using var cmd = conn.CreateCommand(); cmd.CommandText = """ IF NOT EXISTS (SELECT * FROM sys.tables WHERE name = 'cron_jobs') @@ -29,7 +29,7 @@ model NVARCHAR(255), provider NVARCHAR(255) ); """; - await cmd.ExecuteNonQueryAsync(ct); + await cmd.ExecuteNonQueryAsync(ct).ConfigureAwait(false); // Migrate existing tables that lack model/provider columns. await using var alter = conn.CreateCommand(); @@ -39,19 +39,19 @@ IF NOT EXISTS (SELECT 1 FROM sys.columns WHERE object_id = OBJECT_ID('cron_jobs' IF NOT EXISTS (SELECT 1 FROM sys.columns WHERE object_id = OBJECT_ID('cron_jobs') AND name = 'provider') ALTER TABLE cron_jobs ADD provider NVARCHAR(255); """; - await alter.ExecuteNonQueryAsync(ct); + await alter.ExecuteNonQueryAsync(ct).ConfigureAwait(false); } public async Task> LoadAllAsync(CancellationToken ct = default) { await using var conn = new SqlConnection(connectionString); - await conn.OpenAsync(ct); + await conn.OpenAsync(ct).ConfigureAwait(false); await using var cmd = conn.CreateCommand(); cmd.CommandText = "SELECT id,name,schedule_kind,schedule_expr,tz,channel,message,sender_id,enabled,created_at,last_run_at,run_count,source,model,provider FROM cron_jobs"; var jobs = new List(); - await using var reader = await cmd.ExecuteReaderAsync(ct); - while (await reader.ReadAsync(ct)) + await using var reader = await cmd.ExecuteReaderAsync(ct).ConfigureAwait(false); + while (await reader.ReadAsync(ct).ConfigureAwait(false)) { jobs.Add(new CronJob { @@ -79,7 +79,7 @@ public async Task> LoadAllAsync(CancellationToken ct = de public async Task UpsertAsync(CronJob job, CancellationToken ct = default) { await using var conn = new SqlConnection(connectionString); - await conn.OpenAsync(ct); + await conn.OpenAsync(ct).ConfigureAwait(false); await using var cmd = conn.CreateCommand(); cmd.CommandText = """ MERGE cron_jobs AS target @@ -111,28 +111,28 @@ WHEN NOT MATCHED THEN INSERT (id,name,schedule_kind,schedule_expr,tz,channel,mes cmd.Parameters.AddWithValue("@src", job.Source.Value); cmd.Parameters.AddWithValue("@model", (object?)job.Model ?? DBNull.Value); cmd.Parameters.AddWithValue("@provider", (object?)job.Provider ?? DBNull.Value); - await cmd.ExecuteNonQueryAsync(ct); + await cmd.ExecuteNonQueryAsync(ct).ConfigureAwait(false); } public async Task DeleteAsync(string id, CancellationToken ct = default) { await using var conn = new SqlConnection(connectionString); - await conn.OpenAsync(ct); + await conn.OpenAsync(ct).ConfigureAwait(false); await using var cmd = conn.CreateCommand(); cmd.CommandText = "DELETE FROM cron_jobs WHERE id = @id"; cmd.Parameters.AddWithValue("@id", id); - await cmd.ExecuteNonQueryAsync(ct); + await cmd.ExecuteNonQueryAsync(ct).ConfigureAwait(false); } public async Task UpdateRunStatsAsync(string id, DateTimeOffset lastRunAt, int runCount, CancellationToken ct = default) { await using var conn = new SqlConnection(connectionString); - await conn.OpenAsync(ct); + await conn.OpenAsync(ct).ConfigureAwait(false); await using var cmd = conn.CreateCommand(); cmd.CommandText = "UPDATE cron_jobs SET last_run_at=@lra, run_count=@rc WHERE id=@id"; cmd.Parameters.AddWithValue("@lra", lastRunAt.ToString("O")); cmd.Parameters.AddWithValue("@rc", runCount); cmd.Parameters.AddWithValue("@id", id); - await cmd.ExecuteNonQueryAsync(ct); + await cmd.ExecuteNonQueryAsync(ct).ConfigureAwait(false); } } \ No newline at end of file diff --git a/src/clawsharp/Cron/PostgresCronStore.cs b/src/clawsharp/Cron/PostgresCronStore.cs index a1d08777..4e10f356 100644 --- a/src/clawsharp/Cron/PostgresCronStore.cs +++ b/src/clawsharp/Cron/PostgresCronStore.cs @@ -7,7 +7,7 @@ public sealed class PostgresCronStore(string connectionString) : ICronStore public async Task InitAsync(CancellationToken ct = default) { await using var conn = new NpgsqlConnection(connectionString); - await conn.OpenAsync(ct); + await conn.OpenAsync(ct).ConfigureAwait(false); await using var cmd = conn.CreateCommand(); cmd.CommandText = """ CREATE TABLE IF NOT EXISTS cron_jobs ( @@ -28,7 +28,7 @@ CREATE TABLE IF NOT EXISTS cron_jobs ( provider TEXT ); """; - await cmd.ExecuteNonQueryAsync(ct); + await cmd.ExecuteNonQueryAsync(ct).ConfigureAwait(false); // Migrate existing tables that lack model/provider columns. await using var alter = conn.CreateCommand(); @@ -36,19 +36,19 @@ provider TEXT ALTER TABLE cron_jobs ADD COLUMN IF NOT EXISTS model TEXT; ALTER TABLE cron_jobs ADD COLUMN IF NOT EXISTS provider TEXT; """; - await alter.ExecuteNonQueryAsync(ct); + await alter.ExecuteNonQueryAsync(ct).ConfigureAwait(false); } public async Task> LoadAllAsync(CancellationToken ct = default) { await using var conn = new NpgsqlConnection(connectionString); - await conn.OpenAsync(ct); + await conn.OpenAsync(ct).ConfigureAwait(false); await using var cmd = conn.CreateCommand(); cmd.CommandText = "SELECT id,name,schedule_kind,schedule_expr,tz,channel,message,sender_id,enabled,created_at,last_run_at,run_count,source,model,provider FROM cron_jobs"; var jobs = new List(); - await using var reader = await cmd.ExecuteReaderAsync(ct); - while (await reader.ReadAsync(ct)) + await using var reader = await cmd.ExecuteReaderAsync(ct).ConfigureAwait(false); + while (await reader.ReadAsync(ct).ConfigureAwait(false)) { jobs.Add(new CronJob { @@ -76,7 +76,7 @@ public async Task> LoadAllAsync(CancellationToken ct = de public async Task UpsertAsync(CronJob job, CancellationToken ct = default) { await using var conn = new NpgsqlConnection(connectionString); - await conn.OpenAsync(ct); + await conn.OpenAsync(ct).ConfigureAwait(false); await using var cmd = conn.CreateCommand(); cmd.CommandText = """ INSERT INTO cron_jobs (id,name,schedule_kind,schedule_expr,tz,channel,message,sender_id,enabled,created_at,last_run_at,run_count,source,model,provider) @@ -104,28 +104,28 @@ ON CONFLICT(id) DO UPDATE SET cmd.Parameters.AddWithValue("@src", job.Source.Value); cmd.Parameters.AddWithValue("@model", (object?)job.Model ?? DBNull.Value); cmd.Parameters.AddWithValue("@provider", (object?)job.Provider ?? DBNull.Value); - await cmd.ExecuteNonQueryAsync(ct); + await cmd.ExecuteNonQueryAsync(ct).ConfigureAwait(false); } public async Task DeleteAsync(string id, CancellationToken ct = default) { await using var conn = new NpgsqlConnection(connectionString); - await conn.OpenAsync(ct); + await conn.OpenAsync(ct).ConfigureAwait(false); await using var cmd = conn.CreateCommand(); cmd.CommandText = "DELETE FROM cron_jobs WHERE id = @id"; cmd.Parameters.AddWithValue("@id", id); - await cmd.ExecuteNonQueryAsync(ct); + await cmd.ExecuteNonQueryAsync(ct).ConfigureAwait(false); } public async Task UpdateRunStatsAsync(string id, DateTimeOffset lastRunAt, int runCount, CancellationToken ct = default) { await using var conn = new NpgsqlConnection(connectionString); - await conn.OpenAsync(ct); + await conn.OpenAsync(ct).ConfigureAwait(false); await using var cmd = conn.CreateCommand(); cmd.CommandText = "UPDATE cron_jobs SET last_run_at=@lra, run_count=@rc WHERE id=@id"; cmd.Parameters.AddWithValue("@lra", lastRunAt.ToString("O")); cmd.Parameters.AddWithValue("@rc", runCount); cmd.Parameters.AddWithValue("@id", id); - await cmd.ExecuteNonQueryAsync(ct); + await cmd.ExecuteNonQueryAsync(ct).ConfigureAwait(false); } } \ No newline at end of file diff --git a/src/clawsharp/Cron/SqliteCronStore.cs b/src/clawsharp/Cron/SqliteCronStore.cs index c2a50cad..6e57cc81 100644 --- a/src/clawsharp/Cron/SqliteCronStore.cs +++ b/src/clawsharp/Cron/SqliteCronStore.cs @@ -17,11 +17,11 @@ public sealed class SqliteCronStore(string dbPath) : ICronStore public async Task InitAsync(CancellationToken ct = default) { - await _lock.WaitAsync(ct); + await _lock.WaitAsync(ct).ConfigureAwait(false); try { await using var conn = new SqliteConnection(_connectionString); - await conn.OpenAsync(ct); + await conn.OpenAsync(ct).ConfigureAwait(false); await using var cmd = conn.CreateCommand(); cmd.CommandText = """ CREATE TABLE IF NOT EXISTS cron_jobs ( @@ -42,7 +42,7 @@ CREATE TABLE IF NOT EXISTS cron_jobs ( provider TEXT ); """; - await cmd.ExecuteNonQueryAsync(ct); + await cmd.ExecuteNonQueryAsync(ct).ConfigureAwait(false); // Migrate existing tables that lack model/provider columns. await using var alter = conn.CreateCommand(); @@ -51,7 +51,7 @@ provider TEXT """; try { - await alter.ExecuteNonQueryAsync(ct); + await alter.ExecuteNonQueryAsync(ct).ConfigureAwait(false); } catch (SqliteException) { @@ -64,7 +64,7 @@ provider TEXT """; try { - await alter2.ExecuteNonQueryAsync(ct); + await alter2.ExecuteNonQueryAsync(ct).ConfigureAwait(false); } catch (SqliteException) { @@ -79,17 +79,17 @@ provider TEXT public async Task> LoadAllAsync(CancellationToken ct = default) { - await _lock.WaitAsync(ct); + await _lock.WaitAsync(ct).ConfigureAwait(false); try { await using var conn = new SqliteConnection(_connectionString); - await conn.OpenAsync(ct); + await conn.OpenAsync(ct).ConfigureAwait(false); await using var cmd = conn.CreateCommand(); cmd.CommandText = "SELECT id,name,schedule_kind,schedule_expr,tz,channel,message,sender_id,enabled,created_at,last_run_at,run_count,source,model,provider FROM cron_jobs"; var jobs = new List(); - await using var reader = await cmd.ExecuteReaderAsync(ct); - while (await reader.ReadAsync(ct)) + await using var reader = await cmd.ExecuteReaderAsync(ct).ConfigureAwait(false); + while (await reader.ReadAsync(ct).ConfigureAwait(false)) { jobs.Add(MapRow(reader)); } @@ -104,11 +104,11 @@ public async Task> LoadAllAsync(CancellationToken ct = de public async Task UpsertAsync(CronJob job, CancellationToken ct = default) { - await _lock.WaitAsync(ct); + await _lock.WaitAsync(ct).ConfigureAwait(false); try { await using var conn = new SqliteConnection(_connectionString); - await conn.OpenAsync(ct); + await conn.OpenAsync(ct).ConfigureAwait(false); await using var cmd = conn.CreateCommand(); cmd.CommandText = """ INSERT INTO cron_jobs (id,name,schedule_kind,schedule_expr,tz,channel,message,sender_id,enabled,created_at,last_run_at,run_count,source,model,provider) @@ -122,7 +122,7 @@ ON CONFLICT(id) DO UPDATE SET source=excluded.source, model=excluded.model, provider=excluded.provider; """; BindParams(cmd, job); - await cmd.ExecuteNonQueryAsync(ct); + await cmd.ExecuteNonQueryAsync(ct).ConfigureAwait(false); } finally { @@ -132,15 +132,15 @@ ON CONFLICT(id) DO UPDATE SET public async Task DeleteAsync(string id, CancellationToken ct = default) { - await _lock.WaitAsync(ct); + await _lock.WaitAsync(ct).ConfigureAwait(false); try { await using var conn = new SqliteConnection(_connectionString); - await conn.OpenAsync(ct); + await conn.OpenAsync(ct).ConfigureAwait(false); await using var cmd = conn.CreateCommand(); cmd.CommandText = "DELETE FROM cron_jobs WHERE id = @id"; cmd.Parameters.AddWithValue("@id", id); - await cmd.ExecuteNonQueryAsync(ct); + await cmd.ExecuteNonQueryAsync(ct).ConfigureAwait(false); } finally { @@ -150,17 +150,17 @@ public async Task DeleteAsync(string id, CancellationToken ct = default) public async Task UpdateRunStatsAsync(string id, DateTimeOffset lastRunAt, int runCount, CancellationToken ct = default) { - await _lock.WaitAsync(ct); + await _lock.WaitAsync(ct).ConfigureAwait(false); try { await using var conn = new SqliteConnection(_connectionString); - await conn.OpenAsync(ct); + await conn.OpenAsync(ct).ConfigureAwait(false); await using var cmd = conn.CreateCommand(); cmd.CommandText = "UPDATE cron_jobs SET last_run_at=@lra, run_count=@rc WHERE id=@id"; cmd.Parameters.AddWithValue("@lra", lastRunAt.ToString("O")); cmd.Parameters.AddWithValue("@rc", runCount); cmd.Parameters.AddWithValue("@id", id); - await cmd.ExecuteNonQueryAsync(ct); + await cmd.ExecuteNonQueryAsync(ct).ConfigureAwait(false); } finally { diff --git a/src/clawsharp/Features/Behaviors/AuthorizationBehavior.cs b/src/clawsharp/Features/Behaviors/AuthorizationBehavior.cs index dd184247..d335a6c0 100644 --- a/src/clawsharp/Features/Behaviors/AuthorizationBehavior.cs +++ b/src/clawsharp/Features/Behaviors/AuthorizationBehavior.cs @@ -1,6 +1,5 @@ using Clawsharp.Config; using Immediate.Handlers.Shared; -using Microsoft.Extensions.Logging; using Microsoft.Extensions.Options; namespace Clawsharp.Features.Behaviors; @@ -10,9 +9,8 @@ namespace Clawsharp.Features.Behaviors; /// Provides fast-path bypass for handlers that do not require authorization (D-18), /// and serves as the gate + context + audit hook for the policy engine (D-17). /// -public sealed partial class AuthorizationBehavior( - IOptions appConfig, - ILogger> logger +public sealed class AuthorizationBehavior( + IOptions appConfig ) : Behavior { public override async ValueTask HandleAsync( @@ -20,11 +18,11 @@ public override async ValueTask HandleAsync( { // Fast-path: no org config = no authorization needed (backward compat) if (appConfig.Value.Organization is null) - return await Next(request, cancellationToken); + return await Next(request, cancellationToken).ConfigureAwait(false); // Fast-path: skip internal handlers that don't need auth (D-18) if (!RequiresAuthorization(request)) - return await Next(request, cancellationToken); + return await Next(request, cancellationToken).ConfigureAwait(false); // D-19: Context propagation + gates happen here. // Phase 3 establishes the behavior in the pipeline. @@ -32,7 +30,7 @@ public override async ValueTask HandleAsync( // Tool gates are handled at ToolRegistry (Phase 2). // Future phases add: admin command gating, budget gates, audit emission. - return await Next(request, cancellationToken); + return await Next(request, cancellationToken).ConfigureAwait(false); } /// diff --git a/src/clawsharp/Features/Behaviors/LoggingBehavior.cs b/src/clawsharp/Features/Behaviors/LoggingBehavior.cs index 84eaef70..52cbbb01 100644 --- a/src/clawsharp/Features/Behaviors/LoggingBehavior.cs +++ b/src/clawsharp/Features/Behaviors/LoggingBehavior.cs @@ -22,7 +22,7 @@ public override async ValueTask HandleAsync( LogHandlingHandlerWithRequest(handlerName, requestName); var timestamp = Stopwatch.GetTimestamp(); - var response = await Next(request, cancellationToken); + var response = await Next(request, cancellationToken).ConfigureAwait(false); var elapsedTime = Stopwatch.GetElapsedTime(timestamp); LogHandledHandlerInElapsedmsMs(handlerName, elapsedTime.TotalMilliseconds); diff --git a/src/clawsharp/Features/Chat/Commands/ApplySecurityGuards.cs b/src/clawsharp/Features/Chat/Commands/ApplySecurityGuards.cs index 8f0224ae..e4c81888 100644 --- a/src/clawsharp/Features/Chat/Commands/ApplySecurityGuards.cs +++ b/src/clawsharp/Features/Chat/Commands/ApplySecurityGuards.cs @@ -70,8 +70,7 @@ private static ValueTask HandleAsync( if (injAction != InjectionAction.None) { - logger.LogWarning( - "Potential prompt injection in user message: {Preview}", + LogPotentialInjection(logger, userText[..Math.Min(InjectionLogPreviewLength, userText.Length)]); } @@ -82,9 +81,7 @@ private static ValueTask HandleAsync( IReadOnlyList? imagesToInclude = inbound.Images; if (inbound.Images is { Count: > 0 } && !command.SupportsVision) { - logger.LogWarning( - "Dropping {ImageCount} image(s) from {Channel} — provider {Provider} does not support vision", - inbound.Images.Count, inbound.Channel.Value, command.ProviderName); + LogVisionDropped(logger, inbound.Images.Count, inbound.Channel.Value, command.ProviderName); imagesToInclude = null; userText += "\n\n[System note: The user attached image(s) but the current provider does not support vision. Inform the user that their images could not be processed and suggest switching to a vision-capable provider like gpt-4o, claude-3, or gemini.]"; @@ -92,4 +89,12 @@ private static ValueTask HandleAsync( return ValueTask.FromResult(new Result(userText, imagesToInclude, inbound.Files, inbound.Videos, inbound.Audio, Blocked: false)); } + + [LoggerMessage(Level = LogLevel.Warning, + Message = "Potential prompt injection in user message: {Preview}")] + private static partial void LogPotentialInjection(ILogger logger, string preview); + + [LoggerMessage(Level = LogLevel.Warning, + Message = "Dropping {ImageCount} image(s) from {Channel} — provider {Provider} does not support vision")] + private static partial void LogVisionDropped(ILogger logger, int imageCount, string channel, string provider); } \ No newline at end of file diff --git a/src/clawsharp/Features/Chat/Commands/SanitizeReply.cs b/src/clawsharp/Features/Chat/Commands/SanitizeReply.cs index d2bb5d10..1210f32b 100644 --- a/src/clawsharp/Features/Chat/Commands/SanitizeReply.cs +++ b/src/clawsharp/Features/Chat/Commands/SanitizeReply.cs @@ -47,30 +47,32 @@ private static async ValueTask HandleAsync( reply = command.CanaryGuard.Redact(reply); await auditLogger.LogSecurityEventAsync( "Canary token exfiltration detected \u2014 LLM leaked system prompt content", - command.ChannelName, command.SenderId, ct); - logger.LogWarning( - "Canary exfiltration detected on channel {Channel} for sender {Sender}", - command.ChannelName, command.SenderId); + command.ChannelName, command.SenderId, ct).ConfigureAwait(false); + LogCanaryExfiltration(logger, command.ChannelName, command.SenderId); } // Scan outbound reply for leaked credentials/secrets before delivery. + // Sensitivity is [Range(0,1)] — structural patterns always run; higher values add generic/entropy checks. var leakSensitivity = appConfig.Value.Security?.LeakDetector?.Sensitivity ?? 0.7; - if (leakSensitivity >= 0) + var leakResult = LeakDetector.Scan(reply, leakSensitivity); + if (!leakResult.IsClean) { - var leakResult = LeakDetector.Scan(reply, leakSensitivity); - if (!leakResult.IsClean) - { - await auditLogger.LogSecurityEventAsync( - $"LLM output leak detected: {string.Join(", ", leakResult.Patterns)}", - command.ChannelName, command.SenderId, ct); - logger.LogWarning( - "Leak detected in reply for {Channel}:{Sender}: {Patterns}", - command.ChannelName, command.SenderId, - string.Join(", ", leakResult.Patterns)); - reply = leakResult.Redacted; - } + await auditLogger.LogSecurityEventAsync( + $"LLM output leak detected: {string.Join(", ", leakResult.Patterns)}", + command.ChannelName, command.SenderId, ct).ConfigureAwait(false); + LogLeakDetected(logger, command.ChannelName, command.SenderId, + string.Join(", ", leakResult.Patterns)); + reply = leakResult.Redacted; } return new Result(reply); } + + [LoggerMessage(Level = LogLevel.Warning, + Message = "Canary exfiltration detected on channel {Channel} for sender {Sender}")] + private static partial void LogCanaryExfiltration(ILogger logger, string channel, string sender); + + [LoggerMessage(Level = LogLevel.Warning, + Message = "Leak detected in reply for {Channel}:{Sender}: {Patterns}")] + private static partial void LogLeakDetected(ILogger logger, string channel, string sender, string patterns); } \ No newline at end of file diff --git a/src/clawsharp/Features/Chat/Queries/BuildChatRequest.cs b/src/clawsharp/Features/Chat/Queries/BuildChatRequest.cs index f19b9eab..be235b00 100644 --- a/src/clawsharp/Features/Chat/Queries/BuildChatRequest.cs +++ b/src/clawsharp/Features/Chat/Queries/BuildChatRequest.cs @@ -25,6 +25,11 @@ public static partial class BuildChatRequest { private const int MaxGoalsContextChars = 500; + // Cached SYSTEM.md content — reloaded only when file's LastWriteTimeUtc changes. + private static string? _cachedSystemMd; + private static string? _cachedSystemMdPath; + private static DateTime _cachedSystemMdLastWrite; + /// Logger category for DI resolution (static types cannot be used as type arguments). public sealed class Log; @@ -60,10 +65,10 @@ private static async ValueTask HandleAsync( var workspacePath = ConfigLoader.ExpandHome(appConfig.Value.Tools.Workspace); // Load workspace context (SYSTEM.md) — best-effort, never breaks the pipeline. - var workspaceContext = await LoadWorkspaceContextAsync(workspacePath, logger, ct); + var workspaceContext = await LoadWorkspaceContextAsync(workspacePath, logger, ct).ConfigureAwait(false); // Build goals context — best-effort, never breaks the pipeline. - var goalsContext = await BuildGoalsContextAsync(goalStorage, ct); + var goalsContext = await BuildGoalsContextAsync(goalStorage, ct).ConfigureAwait(false); // Set tool context and get filtered definitions for this message. toolRegistry.SetChannelContext(inbound.Channel, inbound.SpawnDepth, @@ -81,7 +86,7 @@ private static async ValueTask HandleAsync( query.MemoryContext, workspaceContext, channelName: inbound.Channel.Value, - enabledTools: toolDefs.Select(t => t.Name).ToList(), + enabledTools: toolDefs.Select(t => t.Name), activeGoalsContext: goalsContext); string systemPrompt; @@ -108,6 +113,7 @@ private static async ValueTask HandleAsync( /// /// Reads the SYSTEM.md file from the workspace directory if it exists. + /// Caches the content and only re-reads when the file's LastWriteTimeUtc changes. /// Returns null on any I/O failure — workspace context is strictly best-effort. /// private static async Task LoadWorkspaceContextAsync( @@ -118,16 +124,29 @@ private static async ValueTask HandleAsync( var systemMdPath = Path.Combine(workspacePath, "SYSTEM.md"); if (!File.Exists(systemMdPath)) { + _cachedSystemMd = null; + _cachedSystemMdPath = null; return null; } try { - return await File.ReadAllTextAsync(systemMdPath, ct); + var lastWrite = File.GetLastWriteTimeUtc(systemMdPath); + if (string.Equals(_cachedSystemMdPath, systemMdPath, StringComparison.Ordinal) + && lastWrite == _cachedSystemMdLastWrite) + { + return _cachedSystemMd; + } + + var content = await File.ReadAllTextAsync(systemMdPath, ct).ConfigureAwait(false); + _cachedSystemMd = content; + _cachedSystemMdPath = systemMdPath; + _cachedSystemMdLastWrite = lastWrite; + return content; } catch (Exception ex) { - logger.LogWarning(ex, "Could not read workspace SYSTEM.md at {Path}", systemMdPath); + LogWorkspaceReadFailed(logger, systemMdPath, ex); return null; } } @@ -143,7 +162,7 @@ private static async ValueTask HandleAsync( { try { - var goals = await goalStorage.LoadAsync(ct); + var goals = await goalStorage.LoadAsync(ct).ConfigureAwait(false); var active = goals.Where(g => g.Status == GoalStatus.Active).ToList(); if (active.Count == 0) { @@ -151,8 +170,9 @@ private static async ValueTask HandleAsync( } var sb = new StringBuilder(); - foreach (var g in active) + for (var i = 0; i < active.Count; i++) { + var g = active[i]; var doneCount = g.Steps.Count(s => s.Done); var stepInfo = g.Steps.Count > 0 ? $" ({doneCount}/{g.Steps.Count} steps done)" : ""; sb.AppendLine($"- [{g.Id}] {PromptGuard.EscapeDelimiterContent(g.Title)}{stepInfo}"); @@ -160,7 +180,8 @@ private static async ValueTask HandleAsync( // Cap at ~500 chars to avoid bloating every prompt. if (sb.Length > MaxGoalsContextChars) { - sb.AppendLine(" ...(more goals truncated)"); + var remaining = active.Count - i - 1; + sb.AppendLine($" ...({remaining} more goal{(remaining != 1 ? "s" : "")} not shown \u2014 use /goals to view all)"); break; } } @@ -173,4 +194,8 @@ private static async ValueTask HandleAsync( return null; } } + + [LoggerMessage(Level = LogLevel.Warning, + Message = "Could not read workspace SYSTEM.md at {Path}")] + private static partial void LogWorkspaceReadFailed(ILogger logger, string path, Exception exception); } \ No newline at end of file diff --git a/src/clawsharp/Features/Chat/Queries/RouteModel.cs b/src/clawsharp/Features/Chat/Queries/RouteModel.cs index aa89c7b8..1450b735 100644 --- a/src/clawsharp/Features/Chat/Queries/RouteModel.cs +++ b/src/clawsharp/Features/Chat/Queries/RouteModel.cs @@ -73,21 +73,15 @@ private static ValueTask HandleAsync( // Check if simple model is allowed by user's policy (per D-13/MODEL-04) if (!ShouldRouteToSimple(query.CurrentPolicy, routing.SimpleModel)) { - logger.LogDebug( - "Model routing: simple model {SimpleModel} denied by policy, using primary {PrimaryModel}", - routing.SimpleModel, cfg.Model); + LogSimpleModelDenied(logger, routing.SimpleModel, cfg.Model); return ValueTask.FromResult(new Result(cfg.Model, WasRouted: false, ComplexityScore: score)); } - logger.LogDebug( - "Model routing: score {Score} < threshold {Threshold}, using simple model {SimpleModel}", - score, routing.Threshold, routing.SimpleModel); + LogRoutedToSimple(logger, score, routing.Threshold, routing.SimpleModel); return ValueTask.FromResult(new Result(routing.SimpleModel, WasRouted: true, ComplexityScore: score)); } - logger.LogDebug( - "Model routing: score {Score} >= threshold {Threshold}, using primary model {PrimaryModel}", - score, routing.Threshold, cfg.Model); + LogRoutedToPrimary(logger, score, routing.Threshold, cfg.Model); return ValueTask.FromResult(new Result(cfg.Model, WasRouted: false, ComplexityScore: score)); } @@ -110,4 +104,16 @@ internal static bool ShouldRouteToSimple(PolicyDecision? policy, string simpleMo return policy.IsModelAllowed(simpleModel); } + + [LoggerMessage(Level = LogLevel.Debug, + Message = "Model routing: simple model {SimpleModel} denied by policy, using primary {PrimaryModel}")] + private static partial void LogSimpleModelDenied(ILogger logger, string simpleModel, string primaryModel); + + [LoggerMessage(Level = LogLevel.Debug, + Message = "Model routing: score {Score} < threshold {Threshold}, using simple model {SimpleModel}")] + private static partial void LogRoutedToSimple(ILogger logger, int score, int threshold, string simpleModel); + + [LoggerMessage(Level = LogLevel.Debug, + Message = "Model routing: score {Score} >= threshold {Threshold}, using primary model {PrimaryModel}")] + private static partial void LogRoutedToPrimary(ILogger logger, int score, int threshold, string primaryModel); } \ No newline at end of file diff --git a/src/clawsharp/Features/Cost/Commands/RecordUsage.cs b/src/clawsharp/Features/Cost/Commands/RecordUsage.cs index 5081f411..bf9b18db 100644 --- a/src/clawsharp/Features/Cost/Commands/RecordUsage.cs +++ b/src/clawsharp/Features/Cost/Commands/RecordUsage.cs @@ -37,6 +37,6 @@ await costTracker.RecordUsageAsync( command.ProviderReportedCost, command.UserId, command.DepartmentId, - ct); + ct).ConfigureAwait(false); } } \ No newline at end of file diff --git a/src/clawsharp/Features/Cost/Queries/CheckBudget.cs b/src/clawsharp/Features/Cost/Queries/CheckBudget.cs index f82b4bef..e060d906 100644 --- a/src/clawsharp/Features/Cost/Queries/CheckBudget.cs +++ b/src/clawsharp/Features/Cost/Queries/CheckBudget.cs @@ -36,6 +36,6 @@ private static async ValueTask HandleAsync( query.DepartmentId, query.UserBudget, query.DepartmentBudget, - ct); + ct).ConfigureAwait(false); } } diff --git a/src/clawsharp/Features/Cost/Queries/GetCostSummary.cs b/src/clawsharp/Features/Cost/Queries/GetCostSummary.cs index db5003c9..c7a37ce2 100644 --- a/src/clawsharp/Features/Cost/Queries/GetCostSummary.cs +++ b/src/clawsharp/Features/Cost/Queries/GetCostSummary.cs @@ -18,6 +18,6 @@ private static async ValueTask HandleAsync( CostTracker costTracker, CancellationToken ct) { - return await costTracker.GetSummaryAsync(query.SessionId, ct); + return await costTracker.GetSummaryAsync(query.SessionId, ct).ConfigureAwait(false); } } \ No newline at end of file diff --git a/src/clawsharp/Features/Memory/Commands/ClearMemory.cs b/src/clawsharp/Features/Memory/Commands/ClearMemory.cs index 42cb1506..87d243c2 100644 --- a/src/clawsharp/Features/Memory/Commands/ClearMemory.cs +++ b/src/clawsharp/Features/Memory/Commands/ClearMemory.cs @@ -19,6 +19,6 @@ private static async ValueTask HandleAsync( IMemory memory, CancellationToken ct) { - await memory.ClearAsync(ct); + await memory.ClearAsync(ct).ConfigureAwait(false); } } \ No newline at end of file diff --git a/src/clawsharp/Features/Memory/Commands/ExtractFacts.cs b/src/clawsharp/Features/Memory/Commands/ExtractFacts.cs index cb99bc5d..ae54e200 100644 --- a/src/clawsharp/Features/Memory/Commands/ExtractFacts.cs +++ b/src/clawsharp/Features/Memory/Commands/ExtractFacts.cs @@ -90,7 +90,7 @@ private static async ValueTask HandleAsync( var scrubResult = LeakDetector.Scan(fact, 0.5); if (!scrubResult.IsClean) { - logger.LogWarning("Scrubbed {Count} secret pattern(s) from extracted fact", scrubResult.Patterns.Count); + LogSecretsScrubbed(logger, scrubResult.Patterns.Count); fact = scrubResult.Redacted; } @@ -100,9 +100,17 @@ private static async ValueTask HandleAsync( if (factsStored > 0) { - logger.LogInformation("Extracted and stored {Count} fact(s) from conversation", factsStored); + LogFactsExtracted(logger, factsStored); } return new Result(factsStored); } + + [LoggerMessage(Level = LogLevel.Warning, + Message = "Scrubbed {Count} secret pattern(s) from extracted fact")] + private static partial void LogSecretsScrubbed(ILogger logger, int count); + + [LoggerMessage(Level = LogLevel.Information, + Message = "Extracted and stored {Count} fact(s) from conversation")] + private static partial void LogFactsExtracted(ILogger logger, int count); } \ No newline at end of file diff --git a/src/clawsharp/Features/Memory/Commands/WriteMemory.cs b/src/clawsharp/Features/Memory/Commands/WriteMemory.cs index 0dd34613..acd29cfb 100644 --- a/src/clawsharp/Features/Memory/Commands/WriteMemory.cs +++ b/src/clawsharp/Features/Memory/Commands/WriteMemory.cs @@ -18,6 +18,6 @@ private static async ValueTask HandleAsync( IMemory memory, CancellationToken ct) { - await memory.AppendFactAsync(command.Fact, ct); + await memory.AppendFactAsync(command.Fact, ct).ConfigureAwait(false); } } \ No newline at end of file diff --git a/src/clawsharp/Features/Memory/Queries/GetMemoryContext.cs b/src/clawsharp/Features/Memory/Queries/GetMemoryContext.cs index c32b02a4..42749e75 100644 --- a/src/clawsharp/Features/Memory/Queries/GetMemoryContext.cs +++ b/src/clawsharp/Features/Memory/Queries/GetMemoryContext.cs @@ -28,7 +28,7 @@ public sealed record Query(string UserText) : IInternalOperation; ILogger logger, CancellationToken ct) { - var primaryContext = await memory.GetContextAsync(ct); + var primaryContext = await memory.GetContextAsync(ct).ConfigureAwait(false); var recallConfig = memoryConfigOptions.Value.EnhancedRecall; if (recallConfig is not { Enabled: true }) @@ -89,9 +89,7 @@ public sealed record Query(string UserText) : IInternalOperation; return primaryContext; } - logger.LogDebug( - "Enhanced recall found {Count} additional facts from {KeywordCount} keywords", - additionalFacts.Count, keywordsToSearch.Count); + LogEnhancedRecallResults(logger, additionalFacts.Count, keywordsToSearch.Count); // Append additional facts to the primary context. var sb = new StringBuilder(); @@ -113,8 +111,16 @@ public sealed record Query(string UserText) : IInternalOperation; } catch (Exception ex) { - logger.LogWarning(ex, "Enhanced recall failed, returning primary context"); + LogEnhancedRecallFailed(logger, ex); return primaryContext; } } + + [LoggerMessage(Level = LogLevel.Debug, + Message = "Enhanced recall found {Count} additional facts from {KeywordCount} keywords")] + private static partial void LogEnhancedRecallResults(ILogger logger, int count, int keywordCount); + + [LoggerMessage(Level = LogLevel.Warning, + Message = "Enhanced recall failed, returning primary context")] + private static partial void LogEnhancedRecallFailed(ILogger logger, Exception exception); } \ No newline at end of file diff --git a/src/clawsharp/Features/Memory/Queries/SearchMemory.cs b/src/clawsharp/Features/Memory/Queries/SearchMemory.cs index b3849b42..30690d20 100644 --- a/src/clawsharp/Features/Memory/Queries/SearchMemory.cs +++ b/src/clawsharp/Features/Memory/Queries/SearchMemory.cs @@ -18,6 +18,6 @@ private static async ValueTask> HandleAsync( IMemory memory, CancellationToken ct) { - return await memory.SearchAsync(query.SearchText, query.Limit, ct); + return await memory.SearchAsync(query.SearchText, query.Limit, ct).ConfigureAwait(false); } } \ No newline at end of file diff --git a/src/clawsharp/Features/Session/Commands/ClearSession.cs b/src/clawsharp/Features/Session/Commands/ClearSession.cs index c64473f4..b3005a45 100644 --- a/src/clawsharp/Features/Session/Commands/ClearSession.cs +++ b/src/clawsharp/Features/Session/Commands/ClearSession.cs @@ -17,6 +17,6 @@ private static async ValueTask HandleAsync( CancellationToken ct) { command.Session.Messages.Clear(); - await sessionManager.SaveAsync(command.Session, ct); + await sessionManager.SaveAsync(command.Session, ct).ConfigureAwait(false); } } \ No newline at end of file diff --git a/src/clawsharp/Features/Session/Commands/CompactSession.cs b/src/clawsharp/Features/Session/Commands/CompactSession.cs deleted file mode 100644 index 11c9a4f3..00000000 --- a/src/clawsharp/Features/Session/Commands/CompactSession.cs +++ /dev/null @@ -1,46 +0,0 @@ -using Clawsharp.Core; -using Clawsharp.Core.Services; -using Clawsharp.Core.Sessions; -using Clawsharp.Providers; -using Immediate.Handlers.Shared; - -namespace Clawsharp.Features.Session.Commands; - -/// -/// Compacts session history by summarizing older messages via the LLM, -/// keeping recent messages verbatim. Replaces the session's message list -/// with the compacted result and persists to disk. -/// -[Handler] -public static partial class CompactSession -{ - public sealed record Command( - Core.Sessions.Session Session, - IProvider Provider, - string Model, - int KeepRecent = 20, - int MaxSummaryChars = 2000, - int MaxSourceChars = 12000) : IInternalOperation; - - private static async ValueTask> HandleAsync( - Command command, - CompactionService compactionService, - SessionStore sessionManager, - CancellationToken ct) - { - var compacted = await compactionService.CompactAsync( - command.Session.Messages, - command.Provider, - command.Model, - command.KeepRecent, - command.MaxSummaryChars, - command.MaxSourceChars, - ct); - - command.Session.Messages.Clear(); - command.Session.Messages.AddRange(compacted); - await sessionManager.SaveAsync(command.Session, ct); - - return compacted; - } -} \ No newline at end of file diff --git a/src/clawsharp/Features/Session/Commands/PruneSession.cs b/src/clawsharp/Features/Session/Commands/PruneSession.cs index ce7fe3a0..2cfc81f6 100644 --- a/src/clawsharp/Features/Session/Commands/PruneSession.cs +++ b/src/clawsharp/Features/Session/Commands/PruneSession.cs @@ -22,7 +22,7 @@ private static async ValueTask HandleAsync( return false; } - await sessionManager.SaveAsync(command.Session, ct); + await sessionManager.SaveAsync(command.Session, ct).ConfigureAwait(false); return true; } } \ No newline at end of file diff --git a/src/clawsharp/Features/Session/Commands/SaveSession.cs b/src/clawsharp/Features/Session/Commands/SaveSession.cs index a767b297..461e22e6 100644 --- a/src/clawsharp/Features/Session/Commands/SaveSession.cs +++ b/src/clawsharp/Features/Session/Commands/SaveSession.cs @@ -16,6 +16,6 @@ private static async ValueTask HandleAsync( SessionStore sessionManager, CancellationToken ct) { - await sessionManager.SaveAsync(command.Session, ct); + await sessionManager.SaveAsync(command.Session, ct).ConfigureAwait(false); } } \ No newline at end of file diff --git a/src/clawsharp/Features/Session/Queries/LoadSession.cs b/src/clawsharp/Features/Session/Queries/LoadSession.cs index b6673fb7..30550e86 100644 --- a/src/clawsharp/Features/Session/Queries/LoadSession.cs +++ b/src/clawsharp/Features/Session/Queries/LoadSession.cs @@ -16,6 +16,6 @@ public sealed record Query(string SessionId) : IInternalOperation; SessionStore sessionManager, CancellationToken ct) { - return await sessionManager.LoadOrCreateAsync(query.SessionId, ct); + return await sessionManager.LoadOrCreateAsync(query.SessionId, ct).ConfigureAwait(false); } } \ No newline at end of file diff --git a/src/clawsharp/Features/Tools/Commands/ExecuteToolCall.cs b/src/clawsharp/Features/Tools/Commands/ExecuteToolCall.cs deleted file mode 100644 index 7f6f86ed..00000000 --- a/src/clawsharp/Features/Tools/Commands/ExecuteToolCall.cs +++ /dev/null @@ -1,22 +0,0 @@ -using Clawsharp.Tools; -using Immediate.Handlers.Shared; - -namespace Clawsharp.Features.Tools.Commands; - -/// -/// Executes a tool by name with the given JSON arguments. -/// Thin wrapper around . -/// -[Handler] -public static partial class ExecuteToolCall -{ - public sealed record Command(string ToolName, string ArgumentsJson); - - private static async ValueTask HandleAsync( - Command command, - IToolRegistry toolRegistry, - CancellationToken ct) - { - return await toolRegistry.ExecuteAsync(command.ToolName, command.ArgumentsJson, ct); - } -} \ No newline at end of file diff --git a/src/clawsharp/Ipc/IpcMessages.cs b/src/clawsharp/Ipc/IpcMessages.cs index b068b256..77dda39e 100644 --- a/src/clawsharp/Ipc/IpcMessages.cs +++ b/src/clawsharp/Ipc/IpcMessages.cs @@ -2,10 +2,16 @@ namespace Clawsharp.Ipc; -internal sealed record IpcRequest(string Command, string? Token = null); +internal sealed record IpcRequest( + [property: JsonPropertyName("command")] string Command, + [property: JsonPropertyName("token")] string? Token = null); -internal sealed record IpcResponse(string? Code, string? Error, bool Cleared); +internal sealed record IpcResponse( + [property: JsonPropertyName("code")] string? Code, + [property: JsonPropertyName("error")] string? Error, + [property: JsonPropertyName("cleared")] bool Cleared); [JsonSerializable(typeof(IpcRequest))] [JsonSerializable(typeof(IpcResponse))] +[JsonSourceGenerationOptions(PropertyNamingPolicy = JsonKnownNamingPolicy.CamelCase)] internal sealed partial class IpcJsonContext : JsonSerializerContext; \ No newline at end of file diff --git a/src/clawsharp/Knowledge/Chunking/ChunkingHelpers.cs b/src/clawsharp/Knowledge/Chunking/ChunkingHelpers.cs new file mode 100644 index 00000000..89fe086b --- /dev/null +++ b/src/clawsharp/Knowledge/Chunking/ChunkingHelpers.cs @@ -0,0 +1,93 @@ +using System.Runtime.CompilerServices; +using Clawsharp.Knowledge.Loading; + +namespace Clawsharp.Knowledge.Chunking; + +/// +/// Shared helper methods used by both and +/// : page concatenation, source page lookup, +/// overlap extraction, and sync-to-async enumerable bridging. +/// +internal static class ChunkingHelpers +{ + /// + /// Concatenates all contents with "\n\n" separators, + /// tracking page boundaries for source attribution per D-18. + /// + internal static async Task<(string CombinedText, List Boundaries)> ConcatenatePagesAsync( + IAsyncEnumerable pages, CancellationToken ct) + { + var boundaries = new List(); + var parts = new List(); + var currentPos = 0; + + await foreach (var page in pages.WithCancellation(ct).ConfigureAwait(false)) + { + if (string.IsNullOrEmpty(page.Content)) + continue; + + if (parts.Count > 0) + currentPos += 2; // "\n\n" separator + + boundaries.Add(new PageBoundary(page.PageNumber, currentPos, currentPos + page.Content.Length)); + parts.Add(page.Content); + currentPos += page.Content.Length; + } + + return (string.Join("\n\n", parts), boundaries); + } + + /// + /// Returns the page numbers whose boundaries overlap the character range + /// [, ). + /// + internal static IReadOnlyList GetSourcePages(int startPos, int endPos, List boundaries) + { + var pages = new List(); + foreach (var boundary in boundaries) + { + if (boundary.End > startPos && boundary.Start < endPos) + pages.Add(boundary.PageNumber); + } + + return pages.Count > 0 ? pages : [1]; + } + + /// + /// Extracts the last tokens from the end of + /// for inter-chunk overlap per D-27/D-28. + /// + internal static string ExtractOverlapFromEnd(string text, int overlapTokens) + { + if (overlapTokens <= 0 || string.IsNullOrEmpty(text)) + return ""; + + var totalTokens = TokenCounter.CountTokens(text); + if (totalTokens <= overlapTokens) + return text; + + var skipTokens = totalTokens - overlapTokens; + var startCharIndex = TokenCounter.GetIndexByTokenCount(text, skipTokens); + return text[startCharIndex..]; + } + + /// + /// Wraps a materialized as an + /// for consumption by . + /// +#pragma warning disable CS1998 // Async method lacks 'await' — required for yield return in IAsyncEnumerable + internal static async IAsyncEnumerable ToAsyncEnumerable( + List items, + [EnumeratorCancellation] CancellationToken ct = default) + { + foreach (var item in items) + { + ct.ThrowIfCancellationRequested(); + yield return item; + } + } +#pragma warning restore CS1998 + + /// Tracks a page's character range within the concatenated document text. + internal sealed record PageBoundary(int PageNumber, int Start, int End); +} diff --git a/src/clawsharp/Knowledge/Chunking/HeadingAwareChunker.cs b/src/clawsharp/Knowledge/Chunking/HeadingAwareChunker.cs index a4d7295f..e22ae855 100644 --- a/src/clawsharp/Knowledge/Chunking/HeadingAwareChunker.cs +++ b/src/clawsharp/Knowledge/Chunking/HeadingAwareChunker.cs @@ -16,7 +16,7 @@ internal sealed partial class HeadingAwareChunker : IChunkingStrategy private static partial Regex HeadingLineRegex(); /// - public string Name => "paragraph"; + public string Name => "heading"; /// public async IAsyncEnumerable ChunkAsync( @@ -25,7 +25,7 @@ public async IAsyncEnumerable ChunkAsync( [EnumeratorCancellation] CancellationToken ct = default) { // Step 1: Materialize and concatenate all pages (D-18) - var (combinedText, pageBoundaries) = await ConcatenatePagesAsync(pages, ct); + var (combinedText, pageBoundaries) = await ChunkingHelpers.ConcatenatePagesAsync(pages, ct).ConfigureAwait(false); if (string.IsNullOrWhiteSpace(combinedText)) yield break; @@ -51,12 +51,12 @@ public async IAsyncEnumerable ChunkAsync( var content = seg.Text; if (idx > 0 && overlapTokens > 0 && previousRaw is not null) { - var overlapText = ExtractOverlapFromEnd(previousRaw, overlapTokens); + var overlapText = ChunkingHelpers.ExtractOverlapFromEnd(previousRaw, overlapTokens); if (!string.IsNullOrEmpty(overlapText)) content = overlapText + content; } - var sourcePages = GetSourcePages(seg.StartPos, seg.StartPos + seg.Text.Length, pageBoundaries); + var sourcePages = ChunkingHelpers.GetSourcePages(seg.StartPos, seg.StartPos + seg.Text.Length, pageBoundaries); yield return new DocumentChunk( Content: content, @@ -101,7 +101,7 @@ public async IAsyncEnumerable ChunkAsync( // Apply overlap from previous chunk if (chunkIndex > 0 && overlapTokenCount > 0 && previousRawContent is not null) { - var overlapText = ExtractOverlapFromEnd(previousRawContent, overlapTokenCount); + var overlapText = ChunkingHelpers.ExtractOverlapFromEnd(previousRawContent, overlapTokenCount); if (!string.IsNullOrEmpty(overlapText)) content = overlapText + content; } @@ -111,7 +111,7 @@ public async IAsyncEnumerable ChunkAsync( ? $"[Section: {section.Heading}]\n{content}" : content; - var sourcePages = GetSourcePages( + var sourcePages = ChunkingHelpers.GetSourcePages( section.StartPos, section.StartPos + rawSectionText.Length, pageBoundaries); yield return new DocumentChunk( @@ -141,7 +141,7 @@ public async IAsyncEnumerable ChunkAsync( // Apply overlap from previous chunk if (chunkIndex > 0 && overlapTokenCount > 0 && previousRawContent is not null) { - var overlapText = ExtractOverlapFromEnd(previousRawContent, overlapTokenCount); + var overlapText = ChunkingHelpers.ExtractOverlapFromEnd(previousRawContent, overlapTokenCount); if (!string.IsNullOrEmpty(overlapText)) content = overlapText + content; } @@ -151,7 +151,7 @@ public async IAsyncEnumerable ChunkAsync( ? $"[Section: {section.Heading}]\n{content}" : content; - var sourcePages = GetSourcePages( + var sourcePages = ChunkingHelpers.GetSourcePages( seg.StartPos, seg.StartPos + seg.Text.Length, pageBoundaries); yield return new DocumentChunk( @@ -208,56 +208,5 @@ private static List ParseSections(string text) return sections; } - private static string ExtractOverlapFromEnd(string text, int overlapTokens) - { - if (overlapTokens <= 0 || string.IsNullOrEmpty(text)) - return ""; - - var totalTokens = TokenCounter.CountTokens(text); - if (totalTokens <= overlapTokens) - return text; - - var skipTokens = totalTokens - overlapTokens; - var startCharIndex = TokenCounter.GetIndexByTokenCount(text, skipTokens); - return text[startCharIndex..]; - } - - private static async Task<(string CombinedText, List Boundaries)> ConcatenatePagesAsync( - IAsyncEnumerable pages, CancellationToken ct) - { - var boundaries = new List(); - var parts = new List(); - var currentPos = 0; - - await foreach (var page in pages.WithCancellation(ct)) - { - if (string.IsNullOrEmpty(page.Content)) - continue; - - if (parts.Count > 0) - currentPos += 2; // "\n\n" separator - - boundaries.Add(new PageBoundary(page.PageNumber, currentPos, currentPos + page.Content.Length)); - parts.Add(page.Content); - currentPos += page.Content.Length; - } - - return (string.Join("\n\n", parts), boundaries); - } - - private static IReadOnlyList GetSourcePages(int startPos, int endPos, List boundaries) - { - var pages = new List(); - foreach (var boundary in boundaries) - { - if (boundary.End > startPos && boundary.Start < endPos) - pages.Add(boundary.PageNumber); - } - - return pages.Count > 0 ? pages : [1]; - } - private sealed record HeadingSection(string? Heading, string Content, int StartPos); - - private sealed record PageBoundary(int PageNumber, int Start, int End); } diff --git a/src/clawsharp/Knowledge/Chunking/IChunkingStrategy.cs b/src/clawsharp/Knowledge/Chunking/IChunkingStrategy.cs index 7b98215f..a355e556 100644 --- a/src/clawsharp/Knowledge/Chunking/IChunkingStrategy.cs +++ b/src/clawsharp/Knowledge/Chunking/IChunkingStrategy.cs @@ -1,7 +1,7 @@ namespace Clawsharp.Knowledge.Chunking; -using Clawsharp.Knowledge.Config; -using Clawsharp.Knowledge.Loading; +using Config; +using Loading; /// /// Chunking strategy that consumes document pages and produces sized chunks @@ -9,7 +9,7 @@ namespace Clawsharp.Knowledge.Chunking; /// public interface IChunkingStrategy { - /// Strategy name for config matching ("recursive", "paragraph"). + /// Strategy name for config matching ("recursive", "heading"). string Name { get; } IAsyncEnumerable ChunkAsync( diff --git a/src/clawsharp/Knowledge/Chunking/RecursiveCharacterChunker.cs b/src/clawsharp/Knowledge/Chunking/RecursiveCharacterChunker.cs index 19b381c1..d18d5b96 100644 --- a/src/clawsharp/Knowledge/Chunking/RecursiveCharacterChunker.cs +++ b/src/clawsharp/Knowledge/Chunking/RecursiveCharacterChunker.cs @@ -30,7 +30,7 @@ public async IAsyncEnumerable ChunkAsync( [EnumeratorCancellation] CancellationToken ct = default) { // Step 1: Materialize and concatenate all pages (D-18) - var (combinedText, pageBoundaries) = await ConcatenatePagesAsync(pages, ct); + var (combinedText, pageBoundaries) = await ChunkingHelpers.ConcatenatePagesAsync(pages, ct).ConfigureAwait(false); if (string.IsNullOrWhiteSpace(combinedText)) yield break; @@ -63,7 +63,7 @@ public async IAsyncEnumerable ChunkAsync( // Apply overlap from previous chunk (D-27), excluding heading prefix (D-28) if (chunkIndex > 0 && overlapTokens > 0 && previousRawContent is not null) { - var overlapText = ExtractOverlapFromEnd(previousRawContent, overlapTokens); + var overlapText = ChunkingHelpers.ExtractOverlapFromEnd(previousRawContent, overlapTokens); if (!string.IsNullOrEmpty(overlapText)) contentBuilder.Add(overlapText); } @@ -77,7 +77,7 @@ public async IAsyncEnumerable ChunkAsync( ? $"[Section: {heading}]\n{rawWithOverlap}" : rawWithOverlap; - var sourcePages = GetSourcePages(segment.StartPos, segment.StartPos + segment.Text.Length, pageBoundaries); + var sourcePages = ChunkingHelpers.GetSourcePages(segment.StartPos, segment.StartPos + segment.Text.Length, pageBoundaries); yield return new DocumentChunk( Content: finalContent, @@ -252,59 +252,5 @@ private static void MergeHeadingOnlySegments(List segments, int max return lastHeading; } - private static string ExtractOverlapFromEnd(string text, int overlapTokens) - { - if (overlapTokens <= 0 || string.IsNullOrEmpty(text)) - return ""; - - var totalTokens = TokenCounter.CountTokens(text); - if (totalTokens <= overlapTokens) - return text; - - // Get the character index for (totalTokens - overlapTokens) to find where the last N tokens start - var skipTokens = totalTokens - overlapTokens; - var startCharIndex = TokenCounter.GetIndexByTokenCount(text, skipTokens); - return text[startCharIndex..]; - } - - private static async Task<(string CombinedText, List Boundaries)> ConcatenatePagesAsync( - IAsyncEnumerable pages, CancellationToken ct) - { - var boundaries = new List(); - var parts = new List(); - var currentPos = 0; - - await foreach (var page in pages.WithCancellation(ct)) - { - if (string.IsNullOrEmpty(page.Content)) - continue; - - if (parts.Count > 0) - { - currentPos += 2; // "\n\n" separator - } - - boundaries.Add(new PageBoundary(page.PageNumber, currentPos, currentPos + page.Content.Length)); - parts.Add(page.Content); - currentPos += page.Content.Length; - } - - return (string.Join("\n\n", parts), boundaries); - } - - private static IReadOnlyList GetSourcePages(int startPos, int endPos, List boundaries) - { - var pages = new List(); - foreach (var boundary in boundaries) - { - if (boundary.End > startPos && boundary.Start < endPos) - pages.Add(boundary.PageNumber); - } - - return pages.Count > 0 ? pages : [1]; - } - internal sealed record TextSegment(string Text, int StartPos); - - private sealed record PageBoundary(int PageNumber, int Start, int End); } diff --git a/src/clawsharp/Knowledge/Config/ChunkingConfig.cs b/src/clawsharp/Knowledge/Config/ChunkingConfig.cs index 4899860e..6d470c75 100644 --- a/src/clawsharp/Knowledge/Config/ChunkingConfig.cs +++ b/src/clawsharp/Knowledge/Config/ChunkingConfig.cs @@ -7,14 +7,18 @@ namespace Clawsharp.Knowledge.Config; public sealed class ChunkingConfig { /// Target chunk size in tokens. Default 512 per NAACL 2025 research. - public int ChunkSize { get; init; } = 512; + /// Uses set (not init) so STJ source-gen preserves defaults on deserialization. + public int ChunkSize { get; set; } = 512; /// Overlap ratio between consecutive chunks (0.0 - 1.0). Default 10%. - public double Overlap { get; init; } = 0.1; + /// Uses set (not init) so STJ source-gen preserves defaults on deserialization. + public double Overlap { get; set; } = 0.1; /// - /// Chunking strategy selection per D-22. Values: "recursive", "paragraph", "auto". - /// "auto" detects heading markers in content to choose. Default "auto". + /// Chunking strategy selection per D-22. Values: "recursive", "heading". + /// Default "recursive" (recursive character splitting with separator hierarchy). + /// "heading" splits at markdown heading boundaries first, then falls back to recursive splitting. /// - public string Strategy { get; init; } = "auto"; + /// Uses set (not init) so STJ source-gen preserves defaults on deserialization. + public string Strategy { get; set; } = "recursive"; } diff --git a/src/clawsharp/Knowledge/Config/EmbeddingBatchConfig.cs b/src/clawsharp/Knowledge/Config/EmbeddingBatchConfig.cs index ab8f36f8..3100c52a 100644 --- a/src/clawsharp/Knowledge/Config/EmbeddingBatchConfig.cs +++ b/src/clawsharp/Knowledge/Config/EmbeddingBatchConfig.cs @@ -7,8 +7,10 @@ namespace Clawsharp.Knowledge.Config; public sealed class EmbeddingBatchConfig { /// Maximum number of texts per embedding API call. Default 100. - public int MaxBatchSize { get; init; } = 100; + /// Uses set (not init) so STJ source-gen preserves defaults on deserialization. + public int MaxBatchSize { get; set; } = 100; /// Maximum number of concurrent embedding batches. Default 3. - public int MaxParallelBatches { get; init; } = 3; + /// Uses set (not init) so STJ source-gen preserves defaults on deserialization. + public int MaxParallelBatches { get; set; } = 3; } diff --git a/src/clawsharp/Knowledge/Config/KnowledgeConfig.cs b/src/clawsharp/Knowledge/Config/KnowledgeConfig.cs index 29691d6d..af720386 100644 --- a/src/clawsharp/Knowledge/Config/KnowledgeConfig.cs +++ b/src/clawsharp/Knowledge/Config/KnowledgeConfig.cs @@ -31,6 +31,13 @@ public sealed class KnowledgeConfig /// Retrieval and hybrid search configuration per D-47. Null = all defaults. public RetrievalConfig? Retrieval { get; init; } + /// + /// Whether plugins must pass Ed25519 signature verification before loading (D-35). + /// Default is true — unsigned or tampered plugins are rejected at startup. + /// Set to false only for local development with unsigned plugin DLLs. + /// + public bool RequireSignedPlugins { get; init; } = true; + /// /// Per-plugin configuration keyed by plugin name per D-44. /// Each plugin receives its own IConfiguration section scoped to diff --git a/src/clawsharp/Knowledge/Config/KnowledgeSourceType.cs b/src/clawsharp/Knowledge/Config/KnowledgeSourceType.cs new file mode 100644 index 00000000..2751437a --- /dev/null +++ b/src/clawsharp/Knowledge/Config/KnowledgeSourceType.cs @@ -0,0 +1,14 @@ +namespace Clawsharp.Knowledge.Config; + +/// +/// Well-known source type discriminator constants for . +/// +internal static class KnowledgeSourceType +{ + public const string Local = "local"; + public const string Confluence = "confluence"; + public const string Git = "git"; + public const string S3 = "s3"; + public const string Azure = "azure"; + public const string Gcs = "gcs"; +} diff --git a/src/clawsharp/Knowledge/Config/RetrievalConfig.cs b/src/clawsharp/Knowledge/Config/RetrievalConfig.cs index 8644286f..515c9e13 100644 --- a/src/clawsharp/Knowledge/Config/RetrievalConfig.cs +++ b/src/clawsharp/Knowledge/Config/RetrievalConfig.cs @@ -7,10 +7,12 @@ namespace Clawsharp.Knowledge.Config; public sealed class RetrievalConfig { /// Default number of results to return from knowledge search. Default 5. - public int DefaultTopK { get; init; } = 5; + /// Uses set (not init) so STJ source-gen preserves defaults on deserialization. + public int DefaultTopK { get; set; } = 5; /// RRF constant (k parameter in 1/(k+rank)). Default 60 per standard RRF literature. - public int RrfK { get; init; } = 60; + /// Uses set (not init) so STJ source-gen preserves defaults on deserialization. + public int RrfK { get; set; } = 60; /// /// Optional reranker configuration. Null = no reranking (PassThroughReranker used). @@ -23,5 +25,6 @@ public sealed class RetrievalConfig /// CandidateMultiplier * topK candidates, then the reranker narrows to topK. /// Default 6 per D-30. /// - public int CandidateMultiplier { get; init; } = 6; + /// Uses set (not init) so STJ source-gen preserves defaults on deserialization. + public int CandidateMultiplier { get; set; } = 6; } diff --git a/src/clawsharp/Knowledge/Embedding/BatchEmbeddingProvider.cs b/src/clawsharp/Knowledge/Embedding/BatchEmbeddingProvider.cs index 0c5164e3..0d08dc39 100644 --- a/src/clawsharp/Knowledge/Embedding/BatchEmbeddingProvider.cs +++ b/src/clawsharp/Knowledge/Embedding/BatchEmbeddingProvider.cs @@ -13,7 +13,7 @@ namespace Clawsharp.Knowledge.Embedding; /// . /// Per D-01, D-05 through D-10 of the Phase 22 embedding design. /// -public sealed class BatchEmbeddingProvider : IBatchEmbeddingProvider +public sealed partial class BatchEmbeddingProvider : IBatchEmbeddingProvider { private static readonly TimeSpan MaxRetryDelay = TimeSpan.FromSeconds(60); @@ -88,22 +88,34 @@ public async Task> EmbedBatchAsync( { await Parallel.ForEachAsync(batches, parallelOptions, async (batch, token) => { - foreach (var (text, globalIndex) in batch) + // Each text in the batch writes to its own index in the results array + // (no contention), so concurrent execution within a batch is safe. + var tasks = batch.Select(item => { - var embedding = await _pipeline.ExecuteAsync( - async t => await _inner.EmbedAsync(text, t), - token); - - results[globalIndex] = embedding; - } - }); + var (text, globalIndex) = item; + return EmbedSingleAsync(text, globalIndex, results, token); + }).ToArray(); + await Task.WhenAll(tasks).ConfigureAwait(false); + }).ConfigureAwait(false); } catch (Exception ex) when (ex is not OperationCanceledException) { - _logger.LogError(ex, "Batch embedding failed after Polly retries exhausted for {TextCount} texts", texts.Count); + LogBatchEmbeddingFailed(_logger, texts.Count, ex); throw; } return results; } + + private async Task EmbedSingleAsync(string text, int globalIndex, float[][] results, CancellationToken token) + { + var embedding = await _pipeline.ExecuteAsync( + async t => await _inner.EmbedAsync(text, t).ConfigureAwait(false), + token).ConfigureAwait(false); + results[globalIndex] = embedding; + } + + [LoggerMessage(Level = LogLevel.Error, + Message = "Batch embedding failed after Polly retries exhausted for {TextCount} texts")] + private static partial void LogBatchEmbeddingFailed(ILogger logger, int textCount, Exception exception); } diff --git a/src/clawsharp/Knowledge/Ingestion/KnowledgeIngestionPipeline.cs b/src/clawsharp/Knowledge/Ingestion/KnowledgeIngestionPipeline.cs index 1d3c9ffc..96bbb466 100644 --- a/src/clawsharp/Knowledge/Ingestion/KnowledgeIngestionPipeline.cs +++ b/src/clawsharp/Knowledge/Ingestion/KnowledgeIngestionPipeline.cs @@ -74,13 +74,14 @@ public virtual async Task IngestSourceAsync( try { - await IngestCoreAsync(sourceConfig, sourceId, progress, ct); + await IngestCoreAsync(sourceConfig, sourceId, progress, ct).ConfigureAwait(false); } catch (Exception ex) when (ex is not OperationCanceledException) { + rootSpan?.SetStatus(ActivityStatusCode.Error, ex.Message); _metrics?.RecordDocumentFailed(sourceConfig.Name, sourceConfig.Type); LogIngestionFailed(sourceConfig.Name, ex); - await _stateTracker.MarkFailedAsync(sourceId, ex.Message, ct); + await _stateTracker.MarkFailedAsync(sourceId, ex.Message, ct).ConfigureAwait(false); throw; } } @@ -91,10 +92,10 @@ private async Task IngestCoreAsync( IProgress? progress, CancellationToken ct) { - // Resolve chunking strategy: per-source override > global default > "auto" + // Resolve chunking strategy: per-source override > global default > "recursive" var strategyName = sourceConfig.Chunking?.Strategy ?? _config.Knowledge?.Chunking?.Strategy - ?? "auto"; + ?? "recursive"; if (!_strategies.TryGetValue(strategyName, out var chunkingStrategy)) { @@ -105,27 +106,27 @@ private async Task IngestCoreAsync( var chunkingConfig = sourceConfig.Chunking ?? _config.Knowledge?.Chunking ?? new ChunkingConfig(); // Get existing document hashes for delta detection - var existingHashes = await _store.GetDocumentHashesBySourceAsync(sourceId, ct); + var existingHashes = await _store.GetDocumentHashesBySourceAsync(sourceId, ct).ConfigureAwait(false); // Determine ingestion path: local file enumeration or remote loader dispatch var files = EnumerateSourceFiles(sourceConfig); - if (files.Count > 0 || string.Equals(sourceConfig.Type, "local", StringComparison.OrdinalIgnoreCase)) + if (files.Count > 0 || string.Equals(sourceConfig.Type, KnowledgeSourceType.Local, StringComparison.OrdinalIgnoreCase)) { // Local source path await IngestLocalSourceAsync(files, sourceConfig, sourceId, chunkingStrategy, chunkingConfig, - existingHashes, progress, ct); + existingHashes, progress, ct).ConfigureAwait(false); } else if (_remoteLoaders.TryGetValue(sourceConfig.Type, out var remoteLoader)) { // Remote source path: dispatch to the appropriate remote loader await IngestRemoteSourceAsync(remoteLoader, sourceConfig, sourceId, chunkingStrategy, chunkingConfig, - existingHashes, progress, ct); + existingHashes, progress, ct).ConfigureAwait(false); } else { LogUnsupportedSourceType(sourceConfig.Type, sourceConfig.Name); - await _stateTracker.MarkCompletedAsync(sourceId, "", 0, ct); + await _stateTracker.MarkCompletedAsync(sourceId, "", 0, ct).ConfigureAwait(false); progress?.Report(new IngestionProgress( IngestionProgressKind.Summary, $"Source '{sourceConfig.Name}' has unsupported type '{sourceConfig.Type}' — no remote loader registered")); @@ -155,9 +156,12 @@ private async Task IngestLocalSourceAsync( { var filePath = files[i]; - // Load pages and buffer content for hash computation + // Load pages via the format loader directly — bypasses PathGuard workspace + // check since knowledge source paths are admin-configured, not user/LLM input. + var ext = Path.GetExtension(filePath); + var loader = _loaderRegistry.GetLoader(ext); var pages = new List(); - await foreach (var page in _loaderRegistry.LoadAsync(filePath, ct)) + await foreach (var page in loader.LoadAsync(filePath, ct).ConfigureAwait(false)) { pages.Add(page); } @@ -182,9 +186,9 @@ private async Task IngestLocalSourceAsync( IngestionProgressKind.DocumentLoading, $"[{i + 1}/{totalFiles}] {filePath}...")); - var docPages = ToAsyncEnumerable(pages); + var docPages = ChunkingHelpers.ToAsyncEnumerable(pages); var chunks = new List(); - await foreach (var chunk in chunkingStrategy.ChunkAsync(docPages, chunkingConfig, ct)) + await foreach (var chunk in chunkingStrategy.ChunkAsync(docPages, chunkingConfig, ct).ConfigureAwait(false)) { chunks.Add(chunk); } @@ -207,7 +211,7 @@ private async Task IngestLocalSourceAsync( // Complete Phase B: embed and store await EmbedAndStoreAsync(changedDocuments, allDocHashes, sourceConfig, sourceId, - existingHashes, totalFiles, skipCount, progress, ct); + existingHashes, totalFiles, skipCount, progress, ct).ConfigureAwait(false); } /// Ingest documents from a remote source via an IRemoteSourceLoader plugin. @@ -229,13 +233,13 @@ private async Task IngestRemoteSourceAsync( // Phase A: Load + Chunk (interleaved per-document) using (var loadSpan = ClawsharpActivitySources.Knowledge.StartActivity("knowledge.load")) { - await foreach (var remoteDoc in remoteLoader.LoadDocumentsAsync(sourceConfig, ct)) + await foreach (var remoteDoc in remoteLoader.LoadDocumentsAsync(sourceConfig, ct).ConfigureAwait(false)) { docIndex++; // Buffer pages and compute content for hash var pages = new List(); - await foreach (var page in remoteDoc.Pages) + await foreach (var page in remoteDoc.Pages.ConfigureAwait(false)) { pages.Add(page); } @@ -260,9 +264,9 @@ private async Task IngestRemoteSourceAsync( IngestionProgressKind.DocumentLoading, $"[{docIndex}] {remoteDoc.SourceUri}...")); - var docPages = ToAsyncEnumerable(pages); + var docPages = ChunkingHelpers.ToAsyncEnumerable(pages); var chunks = new List(); - await foreach (var chunk in chunkingStrategy.ChunkAsync(docPages, chunkingConfig, ct)) + await foreach (var chunk in chunkingStrategy.ChunkAsync(docPages, chunkingConfig, ct).ConfigureAwait(false)) { chunks.Add(chunk); } @@ -285,7 +289,7 @@ private async Task IngestRemoteSourceAsync( // Complete Phase B: embed and store await EmbedAndStoreAsync(changedDocuments, allDocHashes, sourceConfig, sourceId, - existingHashes, docIndex, skipCount, progress, ct); + existingHashes, docIndex, skipCount, progress, ct).ConfigureAwait(false); } /// @@ -306,13 +310,13 @@ private async Task EmbedAndStoreAsync( // Source-level Merkle check (D-20): if all document hashes produce the same // Merkle rollup as the stored source hash, no work needed at all. var newMerkleHash = ContentHasher.ComputeSourceHash(allDocHashes); - var existingSource = await _store.GetSourceAsync(sourceId, ct); + var existingSource = await _store.GetSourceAsync(sourceId, ct).ConfigureAwait(false); var totalChunkCount = existingSource?.ChunkCount ?? 0; if (changedDocuments.Count == 0) { // Nothing changed -- mark completed with existing Merkle hash - await _stateTracker.MarkCompletedAsync(sourceId, newMerkleHash, totalChunkCount, ct); + await _stateTracker.MarkCompletedAsync(sourceId, newMerkleHash, totalChunkCount, ct).ConfigureAwait(false); progress?.Report(new IngestionProgress( IngestionProgressKind.Summary, $"Ingested {totalDocuments} documents -> 0 new chunks (skipped {skipCount} unchanged)")); @@ -331,7 +335,7 @@ private async Task EmbedAndStoreAsync( using (var embedSpan = ClawsharpActivitySources.Knowledge.StartActivity("knowledge.embed")) { var embedStart = Stopwatch.GetTimestamp(); - embeddings = await _embeddingProvider.EmbedBatchAsync(texts, ct); + embeddings = await _embeddingProvider.EmbedBatchAsync(texts, ct).ConfigureAwait(false); var elapsed = Stopwatch.GetElapsedTime(embedStart).TotalSeconds; _metrics?.RecordEmbeddingLatency(elapsed, sourceConfig.Name, sourceConfig.Type); } @@ -350,7 +354,7 @@ private async Task EmbedAndStoreAsync( foreach (var doc in changedDocuments) { // Delete old chunks for this changed document (per-document granularity, not source-level) - await _store.DeleteByDocumentAsync(sourceId, doc.FilePath, ct); + await _store.DeleteByDocumentAsync(sourceId, doc.FilePath, ct).ConfigureAwait(false); foreach (var chunk in doc.Chunks) { @@ -372,16 +376,16 @@ private async Task EmbedAndStoreAsync( } } - // Upsert new chunks - await _store.UpsertChunksAsync(sourceId, knowledgeChunks, ct); + // UpsertChunksAsync replaces chunks for changed documents only; + // the store computes total chunk count internally. + await _store.UpsertChunksAsync(sourceId, knowledgeChunks, ct).ConfigureAwait(false); - // Compute total chunk count: unchanged chunks + new chunks - var unchangedChunkCount = existingSource?.ChunkCount ?? 0; - var newTotalChunkCount = knowledgeChunks.Count + (totalChunkCount - changedDocuments.Count); - if (newTotalChunkCount < knowledgeChunks.Count) newTotalChunkCount = knowledgeChunks.Count; + // Query the actual chunk count from the store after upsert to avoid + // arithmetic that mixes document counts and chunk counts. + var finalChunkCount = await _store.CountChunksAsync(sourceId, ct).ConfigureAwait(false); // Mark completed with Merkle hash (D-20) - await _stateTracker.MarkCompletedAsync(sourceId, newMerkleHash, newTotalChunkCount, ct); + await _stateTracker.MarkCompletedAsync(sourceId, newMerkleHash, finalChunkCount, ct).ConfigureAwait(false); storeSpan?.SetTag(KnowledgeAttributes.SkippedCount, skipCount); } @@ -395,7 +399,7 @@ private async Task EmbedAndStoreAsync( private List EnumerateSourceFiles(KnowledgeSourceConfig sourceConfig) { - if (!string.Equals(sourceConfig.Type, "local", StringComparison.OrdinalIgnoreCase) + if (!string.Equals(sourceConfig.Type, KnowledgeSourceType.Local, StringComparison.OrdinalIgnoreCase) || string.IsNullOrEmpty(sourceConfig.Path)) { return []; @@ -407,7 +411,7 @@ private List EnumerateSourceFiles(KnowledgeSourceConfig sourceConfig) if (!Directory.Exists(sourceConfig.Path)) { - _logger.LogWarning("Source path does not exist: {Path}", sourceConfig.Path); + LogSourcePathMissing(sourceConfig.Path); return []; } @@ -418,15 +422,6 @@ private List EnumerateSourceFiles(KnowledgeSourceConfig sourceConfig) .ToList(); } - private static async IAsyncEnumerable ToAsyncEnumerable(List pages) - { - foreach (var page in pages) - { - yield return page; - await Task.CompletedTask; - } - } - private sealed record ChangedDocument(string FilePath, string Hash, List Chunks); [LoggerMessage(Level = LogLevel.Error, Message = "Ingestion failed for source {SourceName}")] @@ -434,4 +429,7 @@ private sealed record ChangedDocument(string FilePath, string Hash, List public async ValueTask EnqueueAsync(IngestionJob job, CancellationToken ct = default) { - await _channel.Writer.WriteAsync(job, ct); + await _channel.Writer.WriteAsync(job, ct).ConfigureAwait(false); } /// @@ -56,10 +57,17 @@ public async ValueTask EnqueueAsync(IngestionJob job, CancellationToken ct = def /// public override async Task StartAsync(CancellationToken ct) { - var recovered = await _stateTracker.RecoverStuckSourcesAsync(ct); - if (recovered > 0) + try { - LogCrashRecovery(recovered); + var recovered = await _stateTracker.RecoverStuckSourcesAsync(ct).ConfigureAwait(false); + if (recovered > 0) + { + LogCrashRecovery(recovered); + } + } + catch (Exception ex) when (ex is not OperationCanceledException) + { + LogRecoveryFailed(ex); } // D-32/D-33: Register cron jobs for sources with syncCron @@ -75,18 +83,18 @@ await _cronService.AddJobAsync(new CronJob Name = $"Knowledge sync: {source.Name}", ScheduleKind = CronScheduleKind.Cron, ScheduleExpr = source.SyncCron, - Channel = "cli", + Channel = ChannelName.Cli.Value, SenderId = "knowledge-sync", Message = $"/knowledge ingest {source.Name}", Enabled = true, Source = CronSource.Config, - }, ct); + }, ct).ConfigureAwait(false); LogCronJobRegistered(source.Name, source.SyncCron); } } } - await base.StartAsync(ct); + await base.StartAsync(ct).ConfigureAwait(false); } /// @@ -94,7 +102,7 @@ await _cronService.AddJobAsync(new CronJob /// protected override async Task ExecuteAsync(CancellationToken stoppingToken) { - await foreach (var job in _channel.Reader.ReadAllAsync(stoppingToken)) + await foreach (var job in _channel.Reader.ReadAllAsync(stoppingToken).ConfigureAwait(false)) { try { @@ -109,7 +117,7 @@ protected override async Task ExecuteAsync(CancellationToken stoppingToken) job.SourceId, Memory.Entities.KnowledgeSource.Statuses.Pending, Memory.Entities.KnowledgeSource.Statuses.Processing, - stoppingToken)) + stoppingToken).ConfigureAwait(false)) { LogSourceAlreadyProcessing(job.SourceName); continue; @@ -126,12 +134,12 @@ protected override async Task ExecuteAsync(CancellationToken stoppingToken) _ => "unknown", }; - await _pipeline.IngestSourceAsync(sourceConfig, job.SourceId, progress: null, stoppingToken, triggerStr); + await _pipeline.IngestSourceAsync(sourceConfig, job.SourceId, progress: null, stoppingToken, triggerStr).ConfigureAwait(false); } catch (Exception ex) when (ex is not OperationCanceledException) { LogIngestionFailed(job.SourceName, ex); - await _stateTracker.MarkFailedAsync(job.SourceId, ex.Message, stoppingToken); + await _stateTracker.MarkFailedAsync(job.SourceId, ex.Message, stoppingToken).ConfigureAwait(false); } } } @@ -144,6 +152,10 @@ protected override async Task ExecuteAsync(CancellationToken stoppingToken) Message = "Crash recovery: reset {Count} stuck sources back to Pending")] private partial void LogCrashRecovery(int count); + [LoggerMessage(Level = LogLevel.Warning, + Message = "Failed to recover stuck sources at startup — will retry on next ingestion cycle")] + private partial void LogRecoveryFailed(Exception exception); + [LoggerMessage(Level = LogLevel.Warning, Message = "Source config not found for '{SourceName}', skipping ingestion job")] private partial void LogSourceNotFound(string sourceName); diff --git a/src/clawsharp/Knowledge/Ingestion/SyncStateTracker.cs b/src/clawsharp/Knowledge/Ingestion/SyncStateTracker.cs index 95e91090..abc317b7 100644 --- a/src/clawsharp/Knowledge/Ingestion/SyncStateTracker.cs +++ b/src/clawsharp/Knowledge/Ingestion/SyncStateTracker.cs @@ -16,7 +16,7 @@ namespace Clawsharp.Knowledge.Ingestion; /// For non-EF backends (Redis, Markdown), pass null as the factory — all transitions /// return true unconditionally and the pipeline handles idempotency at the application layer. /// -public class SyncStateTracker( +public partial class SyncStateTracker( Func>? contextFactory, ILogger logger) { @@ -34,26 +34,29 @@ public virtual async Task TryTransitionAsync( { if (contextFactory is null) return true; - await using var ctx = await contextFactory(ct); - var source = await ctx.Set().FindAsync([sourceId], ct); - if (source is null || !string.Equals(source.Status, expectedStatus, StringComparison.Ordinal)) - return false; - - source.Status = newStatus; - source.UpdatedAt = DateTimeOffset.UtcNow; + var ctx = await contextFactory(ct).ConfigureAwait(false); + await using (ctx.ConfigureAwait(false)) + { + var source = await ctx.Set().FindAsync([sourceId], ct).ConfigureAwait(false); + if (source is null || !string.Equals(source.Status, expectedStatus, StringComparison.Ordinal)) + return false; - if (string.Equals(newStatus, KnowledgeSource.Statuses.Processing, StringComparison.Ordinal)) - source.ProcessingStartedAt = DateTimeOffset.UtcNow; + source.Status = newStatus; + source.UpdatedAt = DateTimeOffset.UtcNow; - try - { - await ctx.SaveChangesAsync(ct); - return true; - } - catch (DbUpdateConcurrencyException) - { - logger.LogDebug("CAS transition failed for source {SourceId}: concurrent modification detected", sourceId); - return false; + if (string.Equals(newStatus, KnowledgeSource.Statuses.Processing, StringComparison.Ordinal)) + source.ProcessingStartedAt = DateTimeOffset.UtcNow; + + try + { + await ctx.SaveChangesAsync(ct).ConfigureAwait(false); + return true; + } + catch (DbUpdateConcurrencyException) + { + LogCasTransitionFailed(logger, sourceId); + return false; + } } } @@ -66,28 +69,31 @@ public virtual async Task RecoverStuckSourcesAsync(CancellationToken ct = d { if (contextFactory is null) return 0; - await using var ctx = await contextFactory(ct); - var cutoff = DateTimeOffset.UtcNow - StuckTimeout; - - var stuckSources = await ctx.Set() - .Where(s => s.Status == KnowledgeSource.Statuses.Processing - && s.ProcessingStartedAt != null - && s.ProcessingStartedAt < cutoff) - .ToListAsync(ct); - - foreach (var source in stuckSources) + var ctx = await contextFactory(ct).ConfigureAwait(false); + await using (ctx.ConfigureAwait(false)) { - source.Status = KnowledgeSource.Statuses.Pending; - source.ProcessingStartedAt = null; - source.UpdatedAt = DateTimeOffset.UtcNow; - logger.LogWarning("Recovered stuck source {SourceId} ({SourceUri}) — was Processing since {StartedAt}", - source.Id, source.SourceUri, source.ProcessingStartedAt); + var cutoff = DateTimeOffset.UtcNow - StuckTimeout; + + var stuckSources = await ctx.Set() + .Where(s => s.Status == KnowledgeSource.Statuses.Processing + && s.ProcessingStartedAt != null + && s.ProcessingStartedAt < cutoff) + .ToListAsync(ct).ConfigureAwait(false); + + foreach (var source in stuckSources) + { + var startedAt = source.ProcessingStartedAt; + source.Status = KnowledgeSource.Statuses.Pending; + source.ProcessingStartedAt = null; + source.UpdatedAt = DateTimeOffset.UtcNow; + LogStuckSourceRecovered(logger, source.Id, source.SourceUri, startedAt); + } + + if (stuckSources.Count > 0) + await ctx.SaveChangesAsync(ct).ConfigureAwait(false); + + return stuckSources.Count; } - - if (stuckSources.Count > 0) - await ctx.SaveChangesAsync(ct); - - return stuckSources.Count; } /// @@ -98,17 +104,20 @@ public virtual async Task MarkCompletedAsync(Guid sourceId, string contentHash, { if (contextFactory is null) return; - await using var ctx = await contextFactory(ct); - var source = await ctx.Set().FindAsync([sourceId], ct); - if (source is null) return; + var ctx = await contextFactory(ct).ConfigureAwait(false); + await using (ctx.ConfigureAwait(false)) + { + var source = await ctx.Set().FindAsync([sourceId], ct).ConfigureAwait(false); + if (source is null) return; - source.Status = KnowledgeSource.Statuses.Completed; - source.ContentHash = contentHash; - source.ChunkCount = chunkCount; - source.ProcessingStartedAt = null; - source.UpdatedAt = DateTimeOffset.UtcNow; + source.Status = KnowledgeSource.Statuses.Completed; + source.ContentHash = contentHash; + source.ChunkCount = chunkCount; + source.ProcessingStartedAt = null; + source.UpdatedAt = DateTimeOffset.UtcNow; - await ctx.SaveChangesAsync(ct); + await ctx.SaveChangesAsync(ct).ConfigureAwait(false); + } } /// @@ -119,15 +128,26 @@ public virtual async Task MarkFailedAsync(Guid sourceId, string error, Cancellat { if (contextFactory is null) return; - await using var ctx = await contextFactory(ct); - var source = await ctx.Set().FindAsync([sourceId], ct); - if (source is null) return; + var ctx = await contextFactory(ct).ConfigureAwait(false); + await using (ctx.ConfigureAwait(false)) + { + var source = await ctx.Set().FindAsync([sourceId], ct).ConfigureAwait(false); + if (source is null) return; - source.Status = KnowledgeSource.Statuses.Failed; - source.ErrorMessage = error; - source.ProcessingStartedAt = null; - source.UpdatedAt = DateTimeOffset.UtcNow; + source.Status = KnowledgeSource.Statuses.Failed; + source.ErrorMessage = error; + source.ProcessingStartedAt = null; + source.UpdatedAt = DateTimeOffset.UtcNow; - await ctx.SaveChangesAsync(ct); + await ctx.SaveChangesAsync(ct).ConfigureAwait(false); + } } + + [LoggerMessage(Level = LogLevel.Debug, + Message = "CAS transition failed for source {SourceId}: concurrent modification detected")] + private static partial void LogCasTransitionFailed(ILogger logger, Guid sourceId); + + [LoggerMessage(Level = LogLevel.Warning, + Message = "Recovered stuck source {SourceId} ({SourceUri}) — was Processing since {StartedAt}")] + private static partial void LogStuckSourceRecovered(ILogger logger, Guid sourceId, string sourceUri, DateTimeOffset? startedAt); } diff --git a/src/clawsharp/Knowledge/Loading/CloudStorageLoaderBase.cs b/src/clawsharp/Knowledge/Loading/CloudStorageLoaderBase.cs index 3f2def0a..42372b04 100644 --- a/src/clawsharp/Knowledge/Loading/CloudStorageLoaderBase.cs +++ b/src/clawsharp/Knowledge/Loading/CloudStorageLoaderBase.cs @@ -1,4 +1,5 @@ using System.Runtime.CompilerServices; +using Clawsharp.Knowledge.Chunking; using Clawsharp.Knowledge.Config; using Microsoft.Extensions.Logging; @@ -10,7 +11,7 @@ namespace Clawsharp.Knowledge.Loading; /// via (D-25: source loader, not format loader), /// and construction with correct source URIs. /// -public abstract class CloudStorageLoaderBase : IRemoteSourceLoader +public abstract partial class CloudStorageLoaderBase : IRemoteSourceLoader { private readonly IDocumentLoaderRegistry _loaderRegistry; private readonly ILogger _logger; @@ -55,8 +56,7 @@ public async IAsyncEnumerable LoadDocumentsAsync( // D-24: Filter BEFORE download -- skip unsupported extensions without downloading. if (string.IsNullOrEmpty(extension) || !allowedExtensions.Contains(extension)) { - _logger.LogDebug("Skipping object {ObjectKey}: extension '{Extension}' not in allowed set", - obj.Key, extension); + LogSkippedExtension(_logger, obj.Key, extension); continue; } @@ -68,8 +68,7 @@ public async IAsyncEnumerable LoadDocumentsAsync( } catch (InvalidOperationException) { - _logger.LogDebug("Skipping object {ObjectKey}: no format loader for extension '{Extension}'", - obj.Key, extension); + LogNoFormatLoader(_logger, obj.Key, extension); continue; } @@ -96,7 +95,7 @@ public async IAsyncEnumerable LoadDocumentsAsync( materializedPages.Add(page); } - yield return new RemoteDocument(sourceUri, ToAsyncEnumerable(materializedPages)); + yield return new RemoteDocument(sourceUri, ChunkingHelpers.ToAsyncEnumerable(materializedPages)); } finally { @@ -108,21 +107,22 @@ public async IAsyncEnumerable LoadDocumentsAsync( } catch (IOException ex) { - _logger.LogWarning(ex, "Failed to delete temp file {TempFile}", tempFile); + LogTempFileDeleteFailed(_logger, tempFile, ex); } } } } } - private static async IAsyncEnumerable ToAsyncEnumerable( - List pages) - { - foreach (var page in pages) - { - yield return page; - } + [LoggerMessage(Level = LogLevel.Debug, + Message = "Skipping object {ObjectKey}: extension '{Extension}' not in allowed set")] + private static partial void LogSkippedExtension(ILogger logger, string objectKey, string extension); - await Task.CompletedTask; - } + [LoggerMessage(Level = LogLevel.Debug, + Message = "Skipping object {ObjectKey}: no format loader for extension '{Extension}'")] + private static partial void LogNoFormatLoader(ILogger logger, string objectKey, string extension); + + [LoggerMessage(Level = LogLevel.Warning, + Message = "Failed to delete temp file {TempFile}")] + private static partial void LogTempFileDeleteFailed(ILogger logger, string tempFile, Exception exception); } diff --git a/src/clawsharp/Knowledge/Plugins/PluginIntegrityVerifier.cs b/src/clawsharp/Knowledge/Plugins/PluginIntegrityVerifier.cs index a2b83a6b..2c8c1aef 100644 --- a/src/clawsharp/Knowledge/Plugins/PluginIntegrityVerifier.cs +++ b/src/clawsharp/Knowledge/Plugins/PluginIntegrityVerifier.cs @@ -128,8 +128,8 @@ public async Task VerifyAsync(string pluginDirectory, return result; } - var algorithm = NSec.Cryptography.SignatureAlgorithm.Ed25519; - var publicKey = NSec.Cryptography.PublicKey.Import(algorithm, publicKeyBytes, KeyBlobFormat.RawPublicKey); + var algorithm = SignatureAlgorithm.Ed25519; + var publicKey = PublicKey.Import(algorithm, publicKeyBytes, KeyBlobFormat.RawPublicKey); if (!algorithm.Verify(publicKey, canonicalBytes, signatureBytes)) { @@ -212,18 +212,20 @@ public async Task VerifyAsync(string pluginDirectory, internal static byte[] BuildCanonicalPayload(PluginManifest manifest) { // Build a dictionary representation without the signature field, with sorted keys - var sortedFiles = new SortedDictionary(manifest.Files, StringComparer.Ordinal); + var sortedFiles = new SortedDictionary(StringComparer.Ordinal); + foreach (var (key, value) in manifest.Files) + sortedFiles[key] = value; - var canonical = new SortedDictionary(StringComparer.Ordinal) + var canonical = new CanonicalManifest { - ["files"] = sortedFiles, - ["keyId"] = manifest.KeyId, - ["package"] = manifest.Package, - ["version"] = manifest.Version + Files = sortedFiles, + KeyId = manifest.KeyId, + Package = manifest.Package, + Version = manifest.Version }; // Serialize with no whitespace for deterministic output - return JsonSerializer.SerializeToUtf8Bytes(canonical, CanonicalJsonContext.Default.SortedDictionaryStringObject); + return JsonSerializer.SerializeToUtf8Bytes(canonical, CanonicalJsonContext.Default.CanonicalManifest); } /// @@ -274,9 +276,24 @@ await _auditLogger.LogAsync(new AuditEvent /// Source-generated JSON context for canonical manifest serialization (signature verification). /// Produces deterministic JSON with sorted keys and no whitespace. /// -[System.Text.Json.Serialization.JsonSerializable(typeof(SortedDictionary))] +internal sealed class CanonicalManifest +{ + [System.Text.Json.Serialization.JsonPropertyName("files")] + public SortedDictionary Files { get; init; } = new(); + + [System.Text.Json.Serialization.JsonPropertyName("keyId")] + public string KeyId { get; init; } = ""; + + [System.Text.Json.Serialization.JsonPropertyName("package")] + public string Package { get; init; } = ""; + + [System.Text.Json.Serialization.JsonPropertyName("version")] + public string Version { get; init; } = ""; +} + +[System.Text.Json.Serialization.JsonSerializable(typeof(CanonicalManifest))] [System.Text.Json.Serialization.JsonSerializable(typeof(SortedDictionary))] [System.Text.Json.Serialization.JsonSourceGenerationOptions( PropertyNamingPolicy = System.Text.Json.Serialization.JsonKnownNamingPolicy.CamelCase, WriteIndented = false)] -internal partial class CanonicalJsonContext : System.Text.Json.Serialization.JsonSerializerContext; +internal sealed partial class CanonicalJsonContext : System.Text.Json.Serialization.JsonSerializerContext; diff --git a/src/clawsharp/Knowledge/Plugins/PluginLoader.cs b/src/clawsharp/Knowledge/Plugins/PluginLoader.cs index 2dafd11b..bc9872a7 100644 --- a/src/clawsharp/Knowledge/Plugins/PluginLoader.cs +++ b/src/clawsharp/Knowledge/Plugins/PluginLoader.cs @@ -12,6 +12,54 @@ namespace Clawsharp.Knowledge.Plugins; /// internal static partial class PluginLoader { + /// + /// Scans subdirectories of for plugin assemblies matching + /// clawsharp.Plugin.*.dll. Each subdirectory is treated as an isolated plugin. + /// No integrity verification is performed — use when + /// requireSigned is needed. + /// + /// Absolute path to the plugins directory. + /// Logger for discovery diagnostics. + /// List of successfully loaded plugins. Empty if directory is missing or has no plugins. + internal static IReadOnlyList LoadPlugins( + string pluginsPath, + ILogger logger) + { + if (!Directory.Exists(pluginsPath)) + { + LogNoPluginsDirectory(logger, pluginsPath); + return []; + } + + var subDirs = Directory.GetDirectories(pluginsPath); + if (subDirs.Length == 0) + { + LogEmptyPluginsDirectory(logger, pluginsPath); + return []; + } + + var plugins = new List(); + var total = 0; + + foreach (var subDir in subDirs) + { + if (TryLoadPlugin(subDir, logger) is { } plugin) + { + plugins.Add(plugin); + total++; + } + else if (Directory.GetFiles(subDir, "clawsharp.Plugin.*.dll").Length > 0) + { + total++; // Had a DLL but failed to load + } + } + + LogPluginSummary(logger, plugins.Count, total, + plugins.Count > 0 ? string.Join(", ", plugins.Select(p => p.Name)) : "(none)"); + + return plugins; + } + /// /// Scans subdirectories of for plugin assemblies matching /// clawsharp.Plugin.*.dll. Each subdirectory is treated as an isolated plugin. @@ -50,62 +98,31 @@ internal static async Task> LoadPluginsAsync( foreach (var subDir in subDirs) { var primaryDlls = Directory.GetFiles(subDir, "clawsharp.Plugin.*.dll"); - var primaryDll = primaryDlls.FirstOrDefault(); - if (primaryDll is null) - { - continue; // Subdirectory without a plugin DLL, silently skip - } + if (primaryDlls.Length == 0) + continue; total++; - try + // ── D-35: Integrity verification BEFORE assembly loading ─ + if (requireSigned) { - // ── D-35: Integrity verification BEFORE assembly loading ─ - if (requireSigned) - { - if (verifier is null) - { - LogVerifierNotAvailable(logger, Path.GetFileName(subDir)); - continue; - } - - var verification = await verifier.VerifyAsync(subDir, ct).ConfigureAwait(false); - if (!verification.IsValid) - { - LogIntegrityCheckFailed(logger, Path.GetFileName(subDir), verification.Outcome, - verification.ErrorDetail ?? "unknown"); - continue; - } - } - - // ── Assembly loading (only reached after integrity check passes) ─ - var loadContext = new PluginLoadContext(primaryDll); - var assemblyName = new AssemblyName(Path.GetFileNameWithoutExtension(primaryDll)); - var assembly = loadContext.LoadFromAssemblyName(assemblyName); - - var pluginType = assembly.GetTypes() - .FirstOrDefault(t => typeof(IPlugin).IsAssignableFrom(t) && !t.IsAbstract && !t.IsInterface); - - if (pluginType is null) + if (verifier is null) { - continue; // DLL exists but doesn't implement IPlugin, silent skip + LogVerifierNotAvailable(logger, Path.GetFileName(subDir)); + continue; } - if (Activator.CreateInstance(pluginType) is not IPlugin plugin) + var verification = await verifier.VerifyAsync(subDir, ct).ConfigureAwait(false); + if (!verification.IsValid) { - LogPluginInstantiationFailed(logger, pluginType.FullName ?? pluginType.Name, Path.GetFileName(primaryDll)); + LogIntegrityCheckFailed(logger, Path.GetFileName(subDir), verification.Outcome, + verification.ErrorDetail ?? "unknown"); continue; } + } + if (TryLoadPlugin(subDir, logger) is { } plugin) plugins.Add(plugin); - LogPluginDiscovered(logger, plugin.Name); - } - catch (Exception ex) - { - var dirName = Path.GetFileName(subDir); - LogPluginLoadFailed(logger, dirName, ex.Message, ex); - LogPluginUnavailable(logger, dirName); - } } LogPluginSummary(logger, plugins.Count, total, @@ -115,13 +132,44 @@ internal static async Task> LoadPluginsAsync( } /// - /// Synchronous wrapper that calls the flat-directory scan for backward compatibility. - /// Retained for callers that cannot use the async path or do not need integrity verification. + /// Attempts to load a single plugin from a subdirectory. Returns null if + /// the directory contains no plugin DLL or loading fails. /// - internal static IReadOnlyList LoadPlugins(string pluginsPath, ILogger logger) + private static IPlugin? TryLoadPlugin(string subDir, ILogger logger) { - // Delegate to async method without integrity verification for backward compatibility - return LoadPluginsAsync(pluginsPath, verifier: null, requireSigned: false, logger).GetAwaiter().GetResult(); + var primaryDlls = Directory.GetFiles(subDir, "clawsharp.Plugin.*.dll"); + var primaryDll = primaryDlls.FirstOrDefault(); + if (primaryDll is null) + return null; + + try + { + var loadContext = new PluginLoadContext(primaryDll); + var assemblyName = new AssemblyName(Path.GetFileNameWithoutExtension(primaryDll)); + var assembly = loadContext.LoadFromAssemblyName(assemblyName); + + var pluginType = assembly.GetTypes() + .FirstOrDefault(t => typeof(IPlugin).IsAssignableFrom(t) && !t.IsAbstract && !t.IsInterface); + + if (pluginType is null) + return null; + + if (Activator.CreateInstance(pluginType) is not IPlugin plugin) + { + LogPluginInstantiationFailed(logger, pluginType.FullName ?? pluginType.Name, Path.GetFileName(primaryDll)); + return null; + } + + LogPluginDiscovered(logger, plugin.Name); + return plugin; + } + catch (Exception ex) + { + var dirName = Path.GetFileName(subDir); + LogPluginLoadFailed(logger, dirName, ex.Message, ex); + LogPluginUnavailable(logger, dirName); + return null; + } } /// diff --git a/src/clawsharp/Knowledge/Plugins/PluginManifest.cs b/src/clawsharp/Knowledge/Plugins/PluginManifest.cs index feead284..001b0ab4 100644 --- a/src/clawsharp/Knowledge/Plugins/PluginManifest.cs +++ b/src/clawsharp/Knowledge/Plugins/PluginManifest.cs @@ -23,6 +23,11 @@ public sealed record PluginManifest /// Map of filename to SHA-256 hex hash. Keys are simple filenames (no directory separators). /// All files in the plugin directory must be listed here (strict file-list enforcement per D-44). /// + /// + /// Concrete is required for JSON source-gen deserialization + /// and SortedDictionary construction in . + /// Treat as read-only after deserialization. + /// public required Dictionary Files { get; init; } /// Base64-encoded Ed25519 signature over the canonical manifest payload. diff --git a/src/clawsharp/Knowledge/Plugins/PluginManifestJsonContext.cs b/src/clawsharp/Knowledge/Plugins/PluginManifestJsonContext.cs index 4b7b46e4..400f3a67 100644 --- a/src/clawsharp/Knowledge/Plugins/PluginManifestJsonContext.cs +++ b/src/clawsharp/Knowledge/Plugins/PluginManifestJsonContext.cs @@ -1,13 +1,14 @@ using System.Text.Json.Serialization; +using Clawsharp.Config; namespace Clawsharp.Knowledge.Plugins; /// /// Source-generated JSON serializer context for . -/// Intentionally separate from -- plugin manifest +/// Intentionally separate from -- plugin manifest /// deserialization is plugin subsystem work, not config-pipeline work. The manifest is loaded /// during plugin verification (), not during config loading. /// [JsonSerializable(typeof(PluginManifest))] [JsonSourceGenerationOptions(PropertyNamingPolicy = JsonKnownNamingPolicy.CamelCase)] -internal partial class PluginManifestJsonContext : JsonSerializerContext; +internal sealed partial class PluginManifestJsonContext : JsonSerializerContext; diff --git a/src/clawsharp/Knowledge/Plugins/WellKnownKeys.cs b/src/clawsharp/Knowledge/Plugins/WellKnownKeys.cs index c38a6877..6ff81d62 100644 --- a/src/clawsharp/Knowledge/Plugins/WellKnownKeys.cs +++ b/src/clawsharp/Knowledge/Plugins/WellKnownKeys.cs @@ -23,14 +23,14 @@ internal static class WellKnownKeys /// /// Embedded official Ed25519 public key (32 bytes). This is the root of trust for first-party /// plugins -- compiled into the binary and not removable by operators. - /// + /// DEV KEY -- replace before release. /// public static ReadOnlySpan OfficialPublicKey => [ - 0xAB, 0x48, 0x93, 0x37, 0xEF, 0xBC, 0xC3, 0x78, - 0xE3, 0x8A, 0x9B, 0xA6, 0x2D, 0xED, 0x6C, 0x12, - 0xD5, 0x75, 0x6E, 0x46, 0x73, 0xCF, 0x26, 0xEB, - 0xA9, 0xAC, 0x5A, 0x54, 0xA5, 0x25, 0xA7, 0x61 + 0x10, 0x11, 0x59, 0xE9, 0xBF, 0x8F, 0xEC, 0xB6, + 0x86, 0x06, 0x7A, 0x60, 0x7C, 0x7E, 0x8D, 0x51, + 0xF9, 0x2C, 0xDC, 0x58, 0x36, 0x27, 0x66, 0xA0, + 0x59, 0x0A, 0xB5, 0x2B, 0x44, 0x60, 0x8F, 0xDD ]; /// diff --git a/src/clawsharp/Knowledge/Retrieval/CohereReranker.cs b/src/clawsharp/Knowledge/Retrieval/CohereReranker.cs index defb0d9e..842ad595 100644 --- a/src/clawsharp/Knowledge/Retrieval/CohereReranker.cs +++ b/src/clawsharp/Knowledge/Retrieval/CohereReranker.cs @@ -71,14 +71,14 @@ public async Task> RerankAsync( }; var jsonContent = JsonContent.Create(request, CohereJsonContext.Default.CohereRerankRequest); - var httpResponse = await _httpClient.PostAsync(RerankEndpoint, jsonContent, token); + var httpResponse = await _httpClient.PostAsync(RerankEndpoint, jsonContent, token).ConfigureAwait(false); httpResponse.EnsureSuccessStatusCode(); var result = await httpResponse.Content.ReadFromJsonAsync( - CohereJsonContext.Default.CohereRerankResponse, token); + CohereJsonContext.Default.CohereRerankResponse, token).ConfigureAwait(false); return result; - }, ct); + }, ct).ConfigureAwait(false); if (response?.Results is null || response.Results.Count == 0) { diff --git a/src/clawsharp/Knowledge/Slash/KnowledgeSlashCommandHandler.cs b/src/clawsharp/Knowledge/Slash/KnowledgeSlashCommandHandler.cs index 71383794..ba403124 100644 --- a/src/clawsharp/Knowledge/Slash/KnowledgeSlashCommandHandler.cs +++ b/src/clawsharp/Knowledge/Slash/KnowledgeSlashCommandHandler.cs @@ -44,7 +44,7 @@ public static string GetUnknownCommandMessage() => /// public async Task HandleStatusAsync(CancellationToken ct) { - var sources = await _store.ListSourcesAsync(ct); + var sources = await _store.ListSourcesAsync(ct).ConfigureAwait(false); if (sources.Count == 0) { @@ -112,7 +112,7 @@ public async Task HandleIngestAsync(string? argument, CancellationToken foreach (var source in sources) { var job = new IngestionJob(Guid.CreateVersion7(), source.Name, IngestionTrigger.Manual); - await _worker.EnqueueAsync(job, ct); + await _worker.EnqueueAsync(job, ct).ConfigureAwait(false); } return $"Queued ingestion for {sources.Count} sources."; @@ -128,7 +128,7 @@ public async Task HandleIngestAsync(string? argument, CancellationToken } var ingestJob = new IngestionJob(Guid.CreateVersion7(), targetSource.Name, IngestionTrigger.Manual); - await _worker.EnqueueAsync(ingestJob, ct); + await _worker.EnqueueAsync(ingestJob, ct).ConfigureAwait(false); return $"Queued ingestion for {targetSource.Name}."; } @@ -175,7 +175,4 @@ private static string FormatRelativeTime(DateTimeOffset timestamp) return $"{(int)elapsed.TotalDays}d ago"; } - [LoggerMessage(Level = LogLevel.Information, - Message = "Knowledge slash command: enqueued ingestion for {SourceName}")] - private partial void LogEnqueuedIngestion(string sourceName); } diff --git a/src/clawsharp/McpServer/McpExecutionContext.cs b/src/clawsharp/McpServer/McpExecutionContext.cs index 6526577c..d93c8389 100644 --- a/src/clawsharp/McpServer/McpExecutionContext.cs +++ b/src/clawsharp/McpServer/McpExecutionContext.cs @@ -2,7 +2,9 @@ namespace Clawsharp.McpServer; /// /// Per-session MCP context stored in AsyncLocal for propagation to tool.execute spans. -/// Mutable: ClientName/ClientVersion are filled post-handshake via InitializeHandler. +/// Immutable: all properties are set at construction time during ConfigureSessionAsync. +/// ClientName/ClientVersion remain null because the SDK handles the initialize handshake +/// internally; they are enriched on tool.execute spans via the SDK's own metadata instead. /// SessionId is a local UUID (not the transport Mcp-Session-Id header per D-09). /// public sealed class McpExecutionContext @@ -16,9 +18,9 @@ public sealed class McpExecutionContext /// OrgUser.Name from auth result. public string? AuthUser { get; init; } - /// MCP client name from initialize handshake (null until handshake completes). - public string? ClientName { get; set; } + /// MCP client name from initialize handshake. Null when not available at context creation time. + public string? ClientName { get; init; } - /// MCP client version from initialize handshake (null until handshake completes). - public string? ClientVersion { get; set; } + /// MCP client version from initialize handshake. Null when not available at context creation time. + public string? ClientVersion { get; init; } } diff --git a/src/clawsharp/McpServer/McpServerAuthResult.cs b/src/clawsharp/McpServer/McpServerAuthResult.cs index e6266ac9..303102a1 100644 --- a/src/clawsharp/McpServer/McpServerAuthResult.cs +++ b/src/clawsharp/McpServer/McpServerAuthResult.cs @@ -4,8 +4,9 @@ namespace Clawsharp.McpServer; /// /// Result of MCP server authentication. Used by the transport layer (Phase 13) -/// to determine HTTP response (401/403) and to pass OrgUser + PolicyDecision -/// to the dispatcher for RBAC-filtered tool listing and execution. +/// to pass OrgUser + PolicyDecision to the dispatcher for RBAC-filtered tool +/// listing and execution. Origin validation is handled separately by +/// before authentication. /// public sealed record McpServerAuthResult { @@ -15,23 +16,20 @@ public sealed record McpServerAuthResult /// The resolved org user, if any. Null for single-operator mode or auth failure. public OrgUser? User { get; init; } - /// The merged policy decision for this connection. Defaults to Unrestricted. + /// + /// The merged policy decision for this connection. Defaults to . + /// Only meaningful when is true; when false, this value + /// carries no authorization semantics and should not be used for access decisions. + /// public PolicyDecision PolicyDecision { get; init; } = PolicyDecision.Unrestricted; /// The matched API key identifier, if authenticated via static key. public string? KeyId { get; init; } - /// Whether the request was denied due to Origin header validation (HTTP 403 vs 401). - public bool IsOriginDenied { get; init; } - /// Creates an unauthenticated result with no details (per D-16). public static McpServerAuthResult Unauthenticated() => new(); /// Creates a successful auth result with resolved identity and policy. public static McpServerAuthResult Success(OrgUser? user, PolicyDecision policy, string? keyId) => new() { IsAuthenticated = true, User = user, PolicyDecision = policy, KeyId = keyId }; - - /// Creates an origin-denied result (HTTP 403, distinct from 401). - public static McpServerAuthResult OriginDenied() => - new() { IsOriginDenied = true }; } diff --git a/src/clawsharp/McpServer/McpServerAuthenticator.cs b/src/clawsharp/McpServer/McpServerAuthenticator.cs index 48ac333f..1a872c9b 100644 --- a/src/clawsharp/McpServer/McpServerAuthenticator.cs +++ b/src/clawsharp/McpServer/McpServerAuthenticator.cs @@ -1,7 +1,6 @@ using System.IO.Enumeration; using Clawsharp.Config.Features; using Clawsharp.Core.Security; -using Microsoft.Extensions.Logging; namespace Clawsharp.McpServer; @@ -11,20 +10,17 @@ namespace Clawsharp.McpServer; /// Delegates all API key validation, JWT verification, and localhost bypass to /// which is also consumed by the webhook dashboard. /// -public sealed partial class McpServerAuthenticator +public sealed class McpServerAuthenticator { private readonly ApiKeyAuthenticator _apiKeyAuthenticator; private readonly string[]? _allowedOrigins; - private readonly ILogger _logger; public McpServerAuthenticator( McpServerModeConfig? config, - ApiKeyAuthenticator apiKeyAuthenticator, - ILogger logger) + ApiKeyAuthenticator apiKeyAuthenticator) { _apiKeyAuthenticator = apiKeyAuthenticator; _allowedOrigins = config?.AllowedOrigins; - _logger = logger; } /// @@ -85,8 +81,4 @@ internal static bool IsOriginAllowed(string? origin, string[]? allowedOrigins) public Task AuthenticateAsync( string? bearerToken, CancellationToken ct = default) => _apiKeyAuthenticator.AuthenticateAsync(bearerToken, ct); - - [LoggerMessage(EventId = 1, Level = LogLevel.Warning, - Message = "MCP session rejected: origin={Origin}")] - private static partial void LogOriginRejected(ILogger logger, string origin); } diff --git a/src/clawsharp/McpServer/McpServerRouteRegistrar.cs b/src/clawsharp/McpServer/McpServerRouteRegistrar.cs index dda61020..e95a0122 100644 --- a/src/clawsharp/McpServer/McpServerRouteRegistrar.cs +++ b/src/clawsharp/McpServer/McpServerRouteRegistrar.cs @@ -52,7 +52,9 @@ internal async Task ConfigureSessionAsync( if (!authenticator.IsOriginAllowed(originToCheck)) { LogOriginRejected(logger, originToCheck ?? "(null)"); - throw new UnauthorizedAccessException("Forbidden: origin not allowed"); + httpContext.Response.StatusCode = StatusCodes.Status403Forbidden; + await httpContext.Response.CompleteAsync().ConfigureAwait(false); + throw new OperationCanceledException("Origin not allowed"); } // Step 2: Bearer token authentication (per D-02, RBAC-01) @@ -63,17 +65,13 @@ internal async Task ConfigureSessionAsync( bearerToken = authHeader["Bearer ".Length..]; } - var authResult = await authenticator.AuthenticateAsync(bearerToken, ct); - if (authResult.IsOriginDenied) - { - LogOriginRejected(logger, originToCheck ?? "(null)"); - throw new UnauthorizedAccessException("Forbidden: origin denied"); - } - + var authResult = await authenticator.AuthenticateAsync(bearerToken, ct).ConfigureAwait(false); if (!authResult.IsAuthenticated) { LogAuthFailed(logger); - throw new UnauthorizedAccessException("Unauthorized"); + httpContext.Response.StatusCode = StatusCodes.Status401Unauthorized; + await httpContext.Response.CompleteAsync().ConfigureAwait(false); + throw new OperationCanceledException("Unauthorized"); } // Phase 14 CHAN-03: mcp.session.init span (D-01, D-02) diff --git a/src/clawsharp/McpServer/McpServerToolBridge.cs b/src/clawsharp/McpServer/McpServerToolBridge.cs index 7dfd76f6..af882870 100644 --- a/src/clawsharp/McpServer/McpServerToolBridge.cs +++ b/src/clawsharp/McpServer/McpServerToolBridge.cs @@ -4,6 +4,7 @@ using Clawsharp.Cost; using Clawsharp.Organization; using Clawsharp.Tools; +using Microsoft.Extensions.AI; using ModelContextProtocol.Server; namespace Clawsharp.McpServer; @@ -33,6 +34,8 @@ internal static (bool? ReadOnly, bool? Destructive, bool? Idempotent, bool? Open /// /// Creates an from a clawsharp , /// with RBAC context captured in the tool delegate for defense-in-depth per D-05. + /// Uses a custom subclass to forward the tool's own JSON Schema + /// rather than letting the SDK infer a wrong schema from a delegate signature. /// public McpServerTool CreateMcpServerTool( ToolDefinition def, OrgUser? orgUser, PolicyDecision policyDecision, @@ -41,35 +44,15 @@ public McpServerTool CreateMcpServerTool( var sensitivity = toolRegistry.GetToolSensitivity(def.Name); var annotations = MapAnnotations(sensitivity); - // Delegate wraps ToolRegistry.ExecuteAsync with per-call RBAC context. - // The SDK invokes this when tools/call is dispatched for this tool. - var toolDelegate = async (JsonElement arguments, CancellationToken ct) => - { - // Defense-in-depth: re-set AsyncLocal context per call (D-05) - toolRegistry.SetChannelContext( - ChannelName.Mcp, spawnDepth: 0, - orgUser: orgUser, policyDecision: policyDecision); - - if (mcpCtx is not null) - toolRegistry.SetMcpExecutionContext(mcpCtx); - - var result = await toolRegistry.ExecuteAsync(def.Name, arguments.GetRawText(), ct); - - // CHAN-02: zero-cost record for MCP tool activity visibility (D-07) - await costTracker.RecordUsageAsync( - sessionId: $"mcp:{keyId ?? "jwt"}", - model: "mcp-tool", - inputTokens: 0, - outputTokens: 0, - userId: orgUser?.Name, - departmentId: orgUser?.Department, - ct: ct); + // Parse the tool's own JSON Schema so we can forward it verbatim to the SDK. + using var schemaDoc = JsonDocument.Parse(def.ParametersSchemaJson); + var schemaElement = schemaDoc.RootElement.Clone(); - return result; - }; + var ctx = new McpToolContext(orgUser, policyDecision, keyId, mcpCtx); + var aiFunc = new ToolAIFunction(def, schemaElement, toolRegistry, costTracker, ctx); return McpServerTool.Create( - (Delegate)toolDelegate, + aiFunc, new McpServerToolCreateOptions { Name = def.Name, @@ -89,4 +72,68 @@ public IReadOnlyList GetNativeFilteredTools(IReadOnlyList toolRegistry.IsNativeTool(d.Name)).ToList(); } + + // ── AIFunction subclass for correct schema forwarding ──────────────── + + /// Captured RBAC context for per-call defense-in-depth (D-05). + private sealed record McpToolContext( + OrgUser? OrgUser, PolicyDecision PolicyDecision, + string? KeyId, McpExecutionContext? McpCtx); + + /// + /// Custom that forwards the tool's own JSON Schema verbatim + /// and delegates execution to the . This avoids the SDK + /// inferring a wrong schema from a delegate's parameter types. + /// + private sealed class ToolAIFunction( + ToolDefinition def, JsonElement schemaElement, + IToolRegistry registry, CostTracker tracker, + McpToolContext ctx) : AIFunction + { + public override string Name => def.Name; + public override string Description => def.Description; + public override JsonElement JsonSchema => schemaElement; + + protected override async ValueTask InvokeCoreAsync( + AIFunctionArguments arguments, CancellationToken ct) + { + // Defense-in-depth: re-set AsyncLocal context per call (D-05) + registry.SetChannelContext(ChannelName.Mcp, spawnDepth: 0, + orgUser: ctx.OrgUser, policyDecision: ctx.PolicyDecision); + + if (ctx.McpCtx is not null) + registry.SetMcpExecutionContext(ctx.McpCtx); + + // Reconstruct the JSON arguments from the SDK's parsed key-value pairs + using var buffer = new MemoryStream(); + using (var writer = new Utf8JsonWriter(buffer)) + { + writer.WriteStartObject(); + foreach (var kvp in arguments) + { + writer.WritePropertyName(kvp.Key); + if (kvp.Value is JsonElement je) + je.WriteTo(writer); + else + writer.WriteNullValue(); + } + writer.WriteEndObject(); + } + + var argsJson = System.Text.Encoding.UTF8.GetString(buffer.ToArray()); + var result = await registry.ExecuteAsync(def.Name, argsJson, ct).ConfigureAwait(false); + + // CHAN-02: zero-cost record for MCP tool activity visibility (D-07) + await tracker.RecordUsageAsync( + sessionId: $"mcp:{ctx.KeyId ?? "jwt"}", + model: "mcp-tool", + inputTokens: 0, + outputTokens: 0, + userId: ctx.OrgUser?.Name, + departmentId: ctx.OrgUser?.Department, + ct: ct).ConfigureAwait(false); + + return result; + } + } } diff --git a/src/clawsharp/Memory/IKnowledgeStore.cs b/src/clawsharp/Memory/IKnowledgeStore.cs index e29de9ec..d5668a96 100644 --- a/src/clawsharp/Memory/IKnowledgeStore.cs +++ b/src/clawsharp/Memory/IKnowledgeStore.cs @@ -45,4 +45,11 @@ Task> SearchAsync( /// Used by the ingestion pipeline for per-document delta detection. /// Task> GetDocumentHashesBySourceAsync(Guid sourceId, CancellationToken ct = default); + + /// + /// Returns the total number of chunks stored for a given source. + /// Used by the ingestion pipeline to record an accurate chunk count after upsert, + /// avoiding arithmetic that mixes document counts and chunk counts. + /// + Task CountChunksAsync(Guid sourceId, CancellationToken ct = default); } diff --git a/src/clawsharp/Memory/LazyAsyncInit.cs b/src/clawsharp/Memory/LazyAsyncInit.cs new file mode 100644 index 00000000..a3a0b59d --- /dev/null +++ b/src/clawsharp/Memory/LazyAsyncInit.cs @@ -0,0 +1,42 @@ +namespace Clawsharp.Memory; + +/// +/// Thread-safe lazy async initialization with retry-on-failure semantics. +/// Replaces the copy-pasted volatile Task? + SemaphoreSlim pattern across all +/// and implementations. +/// Uses on the fast path to guarantee correct +/// visibility on ARM64 (non-TSO) architectures. +/// +internal sealed class LazyAsyncInit : IDisposable +{ + private Task? _task; + private readonly SemaphoreSlim _lock = new(1, 1); + + public async Task EnsureCompletedAsync(Func factory, CancellationToken ct) + { + var observed = Volatile.Read(ref _task); + if (observed is { IsCompletedSuccessfully: true }) return; + + await _lock.WaitAsync(ct).ConfigureAwait(false); + try + { + observed = Volatile.Read(ref _task); + if (observed is { IsCompletedSuccessfully: true }) return; + + var task = factory(ct); + Volatile.Write(ref _task, task); + await task.ConfigureAwait(false); + } + catch + { + Volatile.Write(ref _task, null); + throw; + } + finally + { + _lock.Release(); + } + } + + public void Dispose() => _lock.Dispose(); +} diff --git a/src/clawsharp/Memory/Markdown/MarkdownKnowledgeStore.cs b/src/clawsharp/Memory/Markdown/MarkdownKnowledgeStore.cs index 1657fb4d..606fa595 100644 --- a/src/clawsharp/Memory/Markdown/MarkdownKnowledgeStore.cs +++ b/src/clawsharp/Memory/Markdown/MarkdownKnowledgeStore.cs @@ -28,11 +28,11 @@ public MarkdownKnowledgeStore(string dir, ILogger logger public async Task UpsertChunksAsync(Guid sourceId, IReadOnlyList chunks, CancellationToken ct = default) { - await _lock.WaitAsync(ct); + await _lock.WaitAsync(ct).ConfigureAwait(false); try { // Remove existing chunks for source - await RewriteWithoutSourceAsync(_chunksPath, sourceId, ct); + await RewriteWithoutSourceAsync(_chunksPath, sourceId, ct).ConfigureAwait(false); // Append new chunks var lines = new List(); @@ -56,11 +56,11 @@ public async Task UpsertChunksAsync(Guid sourceId, IReadOnlyList if (lines.Count > 0) { - await File.AppendAllLinesAsync(_chunksPath, lines, ct); + await File.AppendAllLinesAsync(_chunksPath, lines, ct).ConfigureAwait(false); } // Update source - await UpsertSourceChunkCountAsync(sourceId, chunks.Count, ct); + await UpsertSourceChunkCountAsync(sourceId, chunks.Count, ct).ConfigureAwait(false); } finally { @@ -70,10 +70,10 @@ public async Task UpsertChunksAsync(Guid sourceId, IReadOnlyList public async Task DeleteByDocumentAsync(Guid sourceId, string sourceUri, CancellationToken ct = default) { - await _lock.WaitAsync(ct); + await _lock.WaitAsync(ct).ConfigureAwait(false); try { - await RewriteWithoutDocumentAsync(_chunksPath, sourceId, sourceUri, ct); + await RewriteWithoutDocumentAsync(_chunksPath, sourceId, sourceUri, ct).ConfigureAwait(false); } finally { @@ -83,11 +83,11 @@ public async Task DeleteByDocumentAsync(Guid sourceId, string sourceUri, Cancell public async Task DeleteBySourceAsync(Guid sourceId, CancellationToken ct = default) { - await _lock.WaitAsync(ct); + await _lock.WaitAsync(ct).ConfigureAwait(false); try { - await RewriteWithoutSourceAsync(_chunksPath, sourceId, ct); - await RewriteWithoutSourceIdAsync(_sourcesPath, sourceId, ct); + await RewriteWithoutSourceAsync(_chunksPath, sourceId, ct).ConfigureAwait(false); + await RewriteWithoutSourceIdAsync(_sourcesPath, sourceId, ct).ConfigureAwait(false); } finally { @@ -98,12 +98,15 @@ public async Task DeleteBySourceAsync(Guid sourceId, CancellationToken ct = defa public async Task> SearchAsync( float[]? queryEmbedding, string queryText, AclFilter acl, int topK = 5, CancellationToken ct = default) { - await _lock.WaitAsync(ct); + await _lock.WaitAsync(ct).ConfigureAwait(false); try { - var chunks = await LoadChunksAsync(ct); + var chunks = await LoadChunksAsync(ct).ConfigureAwait(false); - // No ACL filtering per D-39 (startup warning emitted, not per-query) + if (acl.HasRestrictions) + { + LogAclIgnored(_logger); + } // Path 1: substring matching for "FTS" var ftsResults = new List<(Guid ChunkId, int Rank)>(); @@ -160,10 +163,10 @@ public async Task> SearchAsync( public async Task> ListSourcesAsync(CancellationToken ct = default) { - await _lock.WaitAsync(ct); + await _lock.WaitAsync(ct).ConfigureAwait(false); try { - return await LoadSourcesAsync(ct); + return await LoadSourcesAsync(ct).ConfigureAwait(false); } finally { @@ -173,10 +176,10 @@ public async Task> ListSourcesAsync(CancellationT public async Task GetSourceAsync(Guid id, CancellationToken ct = default) { - await _lock.WaitAsync(ct); + await _lock.WaitAsync(ct).ConfigureAwait(false); try { - var sources = await LoadSourcesAsync(ct); + var sources = await LoadSourcesAsync(ct).ConfigureAwait(false); return sources.FirstOrDefault(s => s.Id == id); } finally @@ -188,10 +191,10 @@ public async Task> ListSourcesAsync(CancellationT /// public async Task> GetDocumentHashesBySourceAsync(Guid sourceId, CancellationToken ct = default) { - await _lock.WaitAsync(ct); + await _lock.WaitAsync(ct).ConfigureAwait(false); try { - var chunks = await LoadChunksAsync(ct); + var chunks = await LoadChunksAsync(ct).ConfigureAwait(false); var result = new Dictionary(StringComparer.Ordinal); foreach (var chunk in chunks.Where(c => c.KnowledgeSourceId == sourceId)) { @@ -205,6 +208,21 @@ public async Task> GetDocumentHashesBySource } } + /// + public async Task CountChunksAsync(Guid sourceId, CancellationToken ct = default) + { + await _lock.WaitAsync(ct).ConfigureAwait(false); + try + { + var chunks = await LoadChunksAsync(ct).ConfigureAwait(false); + return chunks.Count(c => c.KnowledgeSourceId == sourceId); + } + finally + { + _lock.Release(); + } + } + public void Dispose() => _lock.Dispose(); // ── Storage helpers ────────────────────────────────────────── @@ -214,7 +232,7 @@ private async Task> LoadChunksAsync(CancellationToken ct) var chunks = new List(); if (!File.Exists(_chunksPath)) return chunks; - var lines = await File.ReadAllLinesAsync(_chunksPath, ct); + var lines = await File.ReadAllLinesAsync(_chunksPath, ct).ConfigureAwait(false); foreach (var line in lines) { if (string.IsNullOrWhiteSpace(line)) continue; @@ -237,7 +255,7 @@ private async Task> LoadSourcesAsync(CancellationToken ct) var sources = new List(); if (!File.Exists(_sourcesPath)) return sources; - var lines = await File.ReadAllLinesAsync(_sourcesPath, ct); + var lines = await File.ReadAllLinesAsync(_sourcesPath, ct).ConfigureAwait(false); foreach (var line in lines) { if (string.IsNullOrWhiteSpace(line)) continue; @@ -259,7 +277,7 @@ private async Task RewriteWithoutDocumentAsync(string path, Guid sourceId, strin { if (!File.Exists(path)) return; - var lines = await File.ReadAllLinesAsync(path, ct); + var lines = await File.ReadAllLinesAsync(path, ct).ConfigureAwait(false); var kept = new List(); foreach (var line in lines) { @@ -281,14 +299,14 @@ private async Task RewriteWithoutDocumentAsync(string path, Guid sourceId, strin } } - await File.WriteAllLinesAsync(path, kept, ct); + await File.WriteAllLinesAsync(path, kept, ct).ConfigureAwait(false); } private async Task RewriteWithoutSourceAsync(string path, Guid sourceId, CancellationToken ct) { if (!File.Exists(path)) return; - var lines = await File.ReadAllLinesAsync(path, ct); + var lines = await File.ReadAllLinesAsync(path, ct).ConfigureAwait(false); var kept = new List(); foreach (var line in lines) { @@ -307,14 +325,14 @@ private async Task RewriteWithoutSourceAsync(string path, Guid sourceId, Cancell } } - await File.WriteAllLinesAsync(path, kept, ct); + await File.WriteAllLinesAsync(path, kept, ct).ConfigureAwait(false); } private async Task RewriteWithoutSourceIdAsync(string path, Guid sourceId, CancellationToken ct) { if (!File.Exists(path)) return; - var lines = await File.ReadAllLinesAsync(path, ct); + var lines = await File.ReadAllLinesAsync(path, ct).ConfigureAwait(false); var kept = new List(); foreach (var line in lines) { @@ -333,7 +351,7 @@ private async Task RewriteWithoutSourceIdAsync(string path, Guid sourceId, Cance } } - await File.WriteAllLinesAsync(path, kept, ct); + await File.WriteAllLinesAsync(path, kept, ct).ConfigureAwait(false); } private async Task UpsertSourceChunkCountAsync(Guid sourceId, int chunkCount, CancellationToken ct) @@ -341,7 +359,7 @@ private async Task UpsertSourceChunkCountAsync(Guid sourceId, int chunkCount, Ca // Rewrite the source record with updated chunk count if (File.Exists(_sourcesPath)) { - var lines = await File.ReadAllLinesAsync(_sourcesPath, ct); + var lines = await File.ReadAllLinesAsync(_sourcesPath, ct).ConfigureAwait(false); var kept = new List(); var found = false; foreach (var line in lines) @@ -370,7 +388,7 @@ private async Task UpsertSourceChunkCountAsync(Guid sourceId, int chunkCount, Ca if (found) { - await File.WriteAllLinesAsync(_sourcesPath, kept, ct); + await File.WriteAllLinesAsync(_sourcesPath, kept, ct).ConfigureAwait(false); } } } @@ -435,4 +453,8 @@ internal sealed class SourceDto [LoggerMessage(EventId = 1, Level = LogLevel.Warning, Message = "Markdown knowledge store does not support department-scoped access control. All knowledge is accessible to all users. Use SQLite, PostgreSQL, MsSql, or Redis for ACL support.")] internal static partial void LogNoAclSupport(ILogger logger); + + [LoggerMessage(EventId = 2, Level = LogLevel.Debug, + Message = "ACL filter ignored: markdown knowledge store does not support department-scoped access control")] + private static partial void LogAclIgnored(ILogger logger); } diff --git a/src/clawsharp/Memory/Markdown/MarkdownMemory.cs b/src/clawsharp/Memory/Markdown/MarkdownMemory.cs index 4c547b52..4bb410e2 100644 --- a/src/clawsharp/Memory/Markdown/MarkdownMemory.cs +++ b/src/clawsharp/Memory/Markdown/MarkdownMemory.cs @@ -12,7 +12,7 @@ public sealed class MarkdownMemory(string dir) : IMemory, IDisposable public async Task GetContextAsync(CancellationToken ct = default) { - await _lock.WaitAsync(ct); + await _lock.WaitAsync(ct).ConfigureAwait(false); try { if (!File.Exists(_memoryPath)) @@ -20,7 +20,7 @@ public sealed class MarkdownMemory(string dir) : IMemory, IDisposable return null; } - var content = await File.ReadAllTextAsync(_memoryPath, ct); + var content = await File.ReadAllTextAsync(_memoryPath, ct).ConfigureAwait(false); return string.IsNullOrWhiteSpace(content) ? null : content; } finally @@ -31,11 +31,11 @@ public sealed class MarkdownMemory(string dir) : IMemory, IDisposable public async Task AppendFactAsync(string fact, CancellationToken ct = default) { - await _lock.WaitAsync(ct); + await _lock.WaitAsync(ct).ConfigureAwait(false); try { var line = $"- {fact.ReplaceLineEndings(" ")}\n"; - await File.AppendAllTextAsync(_memoryPath, line, ct); + await File.AppendAllTextAsync(_memoryPath, line, ct).ConfigureAwait(false); } finally { @@ -45,12 +45,12 @@ public async Task AppendFactAsync(string fact, CancellationToken ct = default) public async Task AppendHistoryAsync(string summary, CancellationToken ct = default) { - await _lock.WaitAsync(ct); + await _lock.WaitAsync(ct).ConfigureAwait(false); try { var now = DateTimeOffset.UtcNow.ToString("yyyy-MM-ddTHH:mm:ssZ"); var entry = $"\n## {now}\n{summary}\n"; - await File.AppendAllTextAsync(_historyPath, entry, ct); + await File.AppendAllTextAsync(_historyPath, entry, ct).ConfigureAwait(false); } finally { @@ -60,7 +60,7 @@ public async Task AppendHistoryAsync(string summary, CancellationToken ct = defa public async Task> SearchAsync(string query, int n = 5, CancellationToken ct = default) { - await _lock.WaitAsync(ct); + await _lock.WaitAsync(ct).ConfigureAwait(false); try { var results = new List(); @@ -69,7 +69,7 @@ public async Task> SearchAsync(string query, int n = 5, Ca return results; } - var lines = await File.ReadAllLinesAsync(_memoryPath, ct); + var lines = await File.ReadAllLinesAsync(_memoryPath, ct).ConfigureAwait(false); foreach (var line in lines) { @@ -96,7 +96,7 @@ public async Task> SearchHybridAsync(string query, float[]? CancellationToken ct = default) { // Markdown backend does not support embeddings — fall back to string contains - await _lock.WaitAsync(ct); + await _lock.WaitAsync(ct).ConfigureAwait(false); try { if (!File.Exists(_memoryPath)) @@ -104,7 +104,7 @@ public async Task> SearchHybridAsync(string query, float[]? return []; } - var lines = await File.ReadAllLinesAsync(_memoryPath, ct); + var lines = await File.ReadAllLinesAsync(_memoryPath, ct).ConfigureAwait(false); var facts = new List(); long id = 1; foreach (var line in lines) @@ -132,7 +132,7 @@ public async Task> SearchHybridAsync(string query, float[]? public async Task> ListFactsAsync(CancellationToken ct = default) { - await _lock.WaitAsync(ct); + await _lock.WaitAsync(ct).ConfigureAwait(false); try { if (!File.Exists(_memoryPath)) @@ -140,7 +140,7 @@ public async Task> ListFactsAsync(CancellationToken ct = def return []; } - var lines = await File.ReadAllLinesAsync(_memoryPath, ct); + var lines = await File.ReadAllLinesAsync(_memoryPath, ct).ConfigureAwait(false); var facts = new List(); long id = 1; foreach (var line in lines) @@ -162,7 +162,7 @@ public async Task> ListFactsAsync(CancellationToken ct = def public async Task ClearAsync(CancellationToken ct = default) { - await _lock.WaitAsync(ct); + await _lock.WaitAsync(ct).ConfigureAwait(false); try { if (File.Exists(_memoryPath)) diff --git a/src/clawsharp/Memory/MemoryDecayService.cs b/src/clawsharp/Memory/MemoryDecayService.cs index 1d548f2e..132ab639 100644 --- a/src/clawsharp/Memory/MemoryDecayService.cs +++ b/src/clawsharp/Memory/MemoryDecayService.cs @@ -30,9 +30,9 @@ protected override async Task ExecuteAsync(CancellationToken ct) { try { - await Task.Delay(TimeSpan.FromHours(decay.PruneIntervalHours), ct); + await Task.Delay(TimeSpan.FromHours(decay.PruneIntervalHours), ct).ConfigureAwait(false); - var pruned = await memory.PruneExpiredFactsAsync(TimeSpan.FromDays(decay.TtlDays), ct); + var pruned = await memory.PruneExpiredFactsAsync(TimeSpan.FromDays(decay.TtlDays), ct).ConfigureAwait(false); if (pruned > 0) { LogPruned(pruned, decay.TtlDays); diff --git a/src/clawsharp/Memory/MsSql/MsSqlKnowledgeStore.cs b/src/clawsharp/Memory/MsSql/MsSqlKnowledgeStore.cs index 222b975f..4949e976 100644 --- a/src/clawsharp/Memory/MsSql/MsSqlKnowledgeStore.cs +++ b/src/clawsharp/Memory/MsSql/MsSqlKnowledgeStore.cs @@ -1,4 +1,5 @@ using System.Diagnostics.CodeAnalysis; +using System.Text; using Clawsharp.Knowledge; using Clawsharp.Memory.Entities; using Microsoft.EntityFrameworkCore; @@ -19,16 +20,15 @@ public sealed partial class MsSqlKnowledgeStore( { private const int CandidateCount = 30; - private volatile Task? _initTask; - private readonly SemaphoreSlim _initLock = new(1, 1); + private readonly LazyAsyncInit _init = new(); public async Task UpsertChunksAsync(Guid sourceId, IReadOnlyList chunks, CancellationToken ct = default) { - await EnsureInitializedAsync(ct); - await using var context = await contextFactory.CreateDbContextAsync(ct); + await _init.EnsureCompletedAsync(InitSchemaAsync, ct).ConfigureAwait(false); + await using var context = await contextFactory.CreateDbContextAsync(ct).ConfigureAwait(false); // Delete existing chunks for source - await context.KnowledgeChunks.Where(c => c.KnowledgeSourceId == sourceId).ExecuteDeleteAsync(ct); + await context.KnowledgeChunks.Where(c => c.KnowledgeSourceId == sourceId).ExecuteDeleteAsync(ct).ConfigureAwait(false); // Insert new chunks with embedding as JSON TEXT foreach (var chunk in chunks) @@ -36,18 +36,31 @@ public async Task UpsertChunksAsync(Guid sourceId, IReadOnlyList context.KnowledgeChunks.Add(chunk); } - await context.SaveChangesAsync(ct); + await context.SaveChangesAsync(ct).ConfigureAwait(false); - // Store embeddings as JSON in a TEXT column - foreach (var chunk in chunks) + // Store embeddings as JSON in a TEXT column (batched to avoid N round-trips) + var embeddingChunks = chunks.Where(c => c.Embedding is not null).ToList(); + if (embeddingChunks.Count > 0) { - if (chunk.Embedding is not null) + // Build a single UPDATE...CASE statement for all embeddings + var caseClauses = new StringBuilder(); + var ids = new StringBuilder(); + var parameters = new List(); + for (var i = 0; i < embeddingChunks.Count; i++) { - var json = EmbeddingMath.Serialize(chunk.Embedding.ToArray()); - await context.Database.ExecuteSqlRawAsync( - $"UPDATE {KnowledgeChunk.TableName} SET embedding_json = {{0}} WHERE Id = {{1}}", - [json, chunk.Id], ct); + var chunk = embeddingChunks[i]; + var json = EmbeddingMath.Serialize(chunk.Embedding!.ToArray()); + var jsonParam = i * 2; + var idParam = jsonParam + 1; + caseClauses.Append($"WHEN Id = {{{idParam}}} THEN {{{jsonParam}}} "); + if (i > 0) ids.Append(','); + ids.Append($"{{{idParam}}}"); + parameters.Add(json); + parameters.Add(chunk.Id); } + + var sql = $"UPDATE [{KnowledgeChunk.TableName}] SET embedding_json = CASE {caseClauses}END WHERE Id IN ({ids})"; + await context.Database.ExecuteSqlRawAsync(sql, parameters, ct).ConfigureAwait(false); } // Update source chunk count @@ -55,40 +68,40 @@ await context.KnowledgeSources .Where(s => s.Id == sourceId) .ExecuteUpdateAsync(s => s .SetProperty(x => x.ChunkCount, chunks.Count) - .SetProperty(x => x.UpdatedAt, DateTimeOffset.UtcNow), ct); + .SetProperty(x => x.UpdatedAt, DateTimeOffset.UtcNow), ct).ConfigureAwait(false); } public async Task DeleteByDocumentAsync(Guid sourceId, string sourceUri, CancellationToken ct = default) { - await EnsureInitializedAsync(ct); - await using var context = await contextFactory.CreateDbContextAsync(ct); + await _init.EnsureCompletedAsync(InitSchemaAsync, ct).ConfigureAwait(false); + await using var context = await contextFactory.CreateDbContextAsync(ct).ConfigureAwait(false); await context.KnowledgeChunks .Where(c => c.KnowledgeSourceId == sourceId && c.SourceUri == sourceUri) - .ExecuteDeleteAsync(ct); + .ExecuteDeleteAsync(ct).ConfigureAwait(false); } public async Task DeleteBySourceAsync(Guid sourceId, CancellationToken ct = default) { - await EnsureInitializedAsync(ct); - await using var context = await contextFactory.CreateDbContextAsync(ct); + await _init.EnsureCompletedAsync(InitSchemaAsync, ct).ConfigureAwait(false); + await using var context = await contextFactory.CreateDbContextAsync(ct).ConfigureAwait(false); // Cascade delete via FK - await context.KnowledgeSources.Where(s => s.Id == sourceId).ExecuteDeleteAsync(ct); + await context.KnowledgeSources.Where(s => s.Id == sourceId).ExecuteDeleteAsync(ct).ConfigureAwait(false); } public async Task> SearchAsync( float[]? queryEmbedding, string queryText, AclFilter acl, int topK = 5, CancellationToken ct = default) { - await EnsureInitializedAsync(ct); - await using var context = await contextFactory.CreateDbContextAsync(ct); + await _init.EnsureCompletedAsync(InitSchemaAsync, ct).ConfigureAwait(false); + await using var context = await contextFactory.CreateDbContextAsync(ct).ConfigureAwait(false); // Path 1: Keyword search (LIKE fallback -- full-text catalog may not be configured) - var ftsResults = await KeywordSearchAsync(context, queryText, acl, ct); + var ftsResults = await KeywordSearchAsync(context, queryText, acl, ct).ConfigureAwait(false); // Path 2: In-process cosine vector search (skipped when embedding is null per D-13) var vectorResults = queryEmbedding is not null - ? await VectorSearchAsync(context, queryEmbedding, acl, ct) + ? await VectorSearchAsync(context, queryEmbedding, acl, ct).ConfigureAwait(false) : []; // Build chunk lookup and RRF merge @@ -105,39 +118,48 @@ public async Task> SearchAsync( var chunkLookup = await context.KnowledgeChunks .AsNoTracking() .Where(c => allIds.Contains(c.Id)) - .ToDictionaryAsync(c => c.Id, ct); + .ToDictionaryAsync(c => c.Id, ct).ConfigureAwait(false); return RrfMerger.Merge(ftsResults, vectorResults, chunkLookup, topK: topK); } public async Task> ListSourcesAsync(CancellationToken ct = default) { - await EnsureInitializedAsync(ct); - await using var context = await contextFactory.CreateDbContextAsync(ct); - return await context.KnowledgeSources.AsNoTracking().OrderByDescending(s => s.CreatedAt).ToListAsync(ct); + await _init.EnsureCompletedAsync(InitSchemaAsync, ct).ConfigureAwait(false); + await using var context = await contextFactory.CreateDbContextAsync(ct).ConfigureAwait(false); + return await context.KnowledgeSources.AsNoTracking().OrderByDescending(s => s.CreatedAt).ToListAsync(ct).ConfigureAwait(false); } public async Task GetSourceAsync(Guid id, CancellationToken ct = default) { - await EnsureInitializedAsync(ct); - await using var context = await contextFactory.CreateDbContextAsync(ct); - return await context.KnowledgeSources.AsNoTracking().FirstOrDefaultAsync(s => s.Id == id, ct); + await _init.EnsureCompletedAsync(InitSchemaAsync, ct).ConfigureAwait(false); + await using var context = await contextFactory.CreateDbContextAsync(ct).ConfigureAwait(false); + return await context.KnowledgeSources.AsNoTracking().FirstOrDefaultAsync(s => s.Id == id, ct).ConfigureAwait(false); } /// public async Task> GetDocumentHashesBySourceAsync(Guid sourceId, CancellationToken ct = default) { - await EnsureInitializedAsync(ct); - await using var context = await contextFactory.CreateDbContextAsync(ct); + await _init.EnsureCompletedAsync(InitSchemaAsync, ct).ConfigureAwait(false); + await using var context = await contextFactory.CreateDbContextAsync(ct).ConfigureAwait(false); var pairs = await context.KnowledgeChunks .AsNoTracking() .Where(c => c.KnowledgeSourceId == sourceId) .Select(c => new { c.SourceUri, c.DocumentHash }) .Distinct() - .ToListAsync(ct); + .ToListAsync(ct).ConfigureAwait(false); return pairs.ToDictionary(p => p.SourceUri, p => p.DocumentHash, StringComparer.Ordinal); } + /// + public async Task CountChunksAsync(Guid sourceId, CancellationToken ct = default) + { + await _init.EnsureCompletedAsync(InitSchemaAsync, ct).ConfigureAwait(false); + await using var context = await contextFactory.CreateDbContextAsync(ct).ConfigureAwait(false); + return await context.KnowledgeChunks + .CountAsync(c => c.KnowledgeSourceId == sourceId, ct).ConfigureAwait(false); + } + // ── Keyword search ─────────────────────────────────────────── private async Task> KeywordSearchAsync( @@ -164,7 +186,7 @@ public async Task> GetDocumentHashesBySource query = query.Where(c => depts.Contains(c.DepartmentId)); } - ids = await query.Take(CandidateCount).Select(c => c.Id).ToListAsync(ct); + ids = await query.Take(CandidateCount).Select(c => c.Id).ToListAsync(ct).ConfigureAwait(false); } catch { @@ -180,7 +202,7 @@ public async Task> GetDocumentHashesBySource likeQuery = likeQuery.Where(c => depts.Contains(c.DepartmentId)); } - ids = await likeQuery.Take(CandidateCount).Select(c => c.Id).ToListAsync(ct); + ids = await likeQuery.Take(CandidateCount).Select(c => c.Id).ToListAsync(ct).ConfigureAwait(false); } var rank = 1; @@ -212,33 +234,41 @@ private sealed class ChunkEmbeddingRow try { - string sql; + // Always load all embeddings (no ACL in SQL); post-filter by department via LINQ + var sql = $""" + SELECT Id AS ChunkId, embedding_json AS EmbeddingJson + FROM {KnowledgeChunk.TableName} + WHERE embedding_json IS NOT NULL + """; + var rows = await context.Database.SqlQueryRaw(sql).ToListAsync(ct).ConfigureAwait(false); + + // Build department allowlist for post-filtering + HashSet? allowedDepts = null; + Dictionary? deptLookup = null; if (acl.HasRestrictions) { - var deptList = string.Join(",", acl.DepartmentIds.Select(d => $"'{d.Replace("'", "''")}'")); - sql = $""" - SELECT Id AS ChunkId, embedding_json AS EmbeddingJson - FROM {KnowledgeChunk.TableName} - WHERE embedding_json IS NOT NULL - AND DepartmentId IN ({deptList}) - """; - } - else - { - sql = $""" - SELECT Id AS ChunkId, embedding_json AS EmbeddingJson - FROM {KnowledgeChunk.TableName} - WHERE embedding_json IS NOT NULL - """; - } + allowedDepts = acl.DepartmentIds.ToHashSet(StringComparer.Ordinal); - var rows = await context.Database.SqlQueryRaw(sql).ToListAsync(ct); + var candidateIds = rows.Select(r => r.ChunkId).ToList(); + deptLookup = await context.KnowledgeChunks + .AsNoTracking() + .Where(c => candidateIds.Contains(c.Id)) + .Select(c => new { c.Id, c.DepartmentId }) + .ToDictionaryAsync(c => c.Id, c => c.DepartmentId, ct).ConfigureAwait(false); + } var scored = new List<(Guid id, float score)>(); foreach (var row in rows) { if (row.EmbeddingJson is null) continue; + // Post-filter: skip chunks not in allowed departments + if (allowedDepts is not null && deptLookup is not null) + { + if (!deptLookup.TryGetValue(row.ChunkId, out var dept) || !allowedDepts.Contains(dept)) + continue; + } + var vec = EmbeddingMath.Deserialize(row.EmbeddingJson); if (vec.Length == 0 || vec.Length != queryEmbedding.Length) continue; @@ -271,34 +301,10 @@ private static string EscapeLikePattern(string query) => // ── Init ───────────────────────────────────────────────────── - private async Task EnsureInitializedAsync(CancellationToken ct) - { - if (_initTask is { IsCompletedSuccessfully: true }) return; - - await _initLock.WaitAsync(ct); - try - { - if (_initTask is { IsCompletedSuccessfully: true }) return; - - var task = InitSchemaAsync(ct); - _initTask = task; - await task; - } - catch - { - _initTask = null; - throw; - } - finally - { - _initLock.Release(); - } - } - [RequiresDynamicCode("EF Core MigrateAsync requires dynamic code generation.")] private async Task InitSchemaAsync(CancellationToken ct) { - await using var context = await contextFactory.CreateDbContextAsync(ct); + await using var context = await contextFactory.CreateDbContextAsync(ct).ConfigureAwait(false); // Add embedding_json NVARCHAR(MAX) column if not present const string sql = @@ -309,7 +315,7 @@ IF NOT EXISTS (SELECT 1 FROM sys.columns WHERE object_id = OBJECT_ID('{Knowledge END """; - await context.Database.ExecuteSqlRawAsync(sql, ct); + await context.Database.ExecuteSqlRawAsync(sql, ct).ConfigureAwait(false); LogSchemaInitialized(logger); } diff --git a/src/clawsharp/Memory/MsSql/MsSqlMemory.cs b/src/clawsharp/Memory/MsSql/MsSqlMemory.cs index fa1dfb20..42cd8a5f 100644 --- a/src/clawsharp/Memory/MsSql/MsSqlMemory.cs +++ b/src/clawsharp/Memory/MsSql/MsSqlMemory.cs @@ -14,9 +14,7 @@ public sealed partial class MsSqlMemory(IDbContextFactory co { private const int RecentContentLimit = 50; - private volatile Task? _initTask; - - private readonly SemaphoreSlim _initLock = new(1, 1); + private readonly LazyAsyncInit _init = new(); private static readonly Func> GetRecentContentQuery = EF.CompileAsyncQuery((MsSqlMemoryContext db) => @@ -52,10 +50,10 @@ private static readonly Func> public async Task GetContextAsync(CancellationToken ct = default) { - await EnsureInitializedAsync(ct); - await using var context = await contextFactory.CreateDbContextAsync(ct); + await _init.EnsureCompletedAsync(InitSchemaAsync, ct).ConfigureAwait(false); + await using var context = await contextFactory.CreateDbContextAsync(ct).ConfigureAwait(false); var facts = new List(); - await foreach (var content in GetRecentContentQuery(context).WithCancellation(ct)) + await foreach (var content in GetRecentContentQuery(context).WithCancellation(ct).ConfigureAwait(false)) { facts.Add($"- {content}"); } @@ -65,24 +63,24 @@ private static readonly Func> public async Task AppendFactAsync(string fact, CancellationToken ct = default) { - await EnsureInitializedAsync(ct); - await using var context = await contextFactory.CreateDbContextAsync(ct); + await _init.EnsureCompletedAsync(InitSchemaAsync, ct).ConfigureAwait(false); + await using var context = await contextFactory.CreateDbContextAsync(ct).ConfigureAwait(false); context.Facts.Add(new Fact { Content = fact, CreatedAt = DateTimeOffset.UtcNow }); - await context.SaveChangesAsync(ct); + await context.SaveChangesAsync(ct).ConfigureAwait(false); } public async Task AppendHistoryAsync(string summary, CancellationToken ct = default) { - await EnsureInitializedAsync(ct); - await using var context = await contextFactory.CreateDbContextAsync(ct); + await _init.EnsureCompletedAsync(InitSchemaAsync, ct).ConfigureAwait(false); + await using var context = await contextFactory.CreateDbContextAsync(ct).ConfigureAwait(false); context.History.Add(new HistoryEntry(summary, DateTimeOffset.UtcNow)); - await context.SaveChangesAsync(ct); + await context.SaveChangesAsync(ct).ConfigureAwait(false); } public async Task> SearchAsync(string query, int n = 5, CancellationToken ct = default) { - await EnsureInitializedAsync(ct); - await using var context = await contextFactory.CreateDbContextAsync(ct); + await _init.EnsureCompletedAsync(InitSchemaAsync, ct).ConfigureAwait(false); + await using var context = await contextFactory.CreateDbContextAsync(ct).ConfigureAwait(false); List results = []; try @@ -98,7 +96,7 @@ public async Task> SearchAsync(string query, int n = 5, Ca .OrderByDescending(f => f.Id) .Take(n) .Select(f => f.Content) - .ToListAsync(ct); + .ToListAsync(ct).ConfigureAwait(false); } catch (Exception ex) { @@ -109,7 +107,7 @@ public async Task> SearchAsync(string query, int n = 5, Ca if (results.Count == 0) { var pattern = $"%{EscapeLikePattern(query)}%"; - await foreach (var content in SearchLikeFallbackQuery(context, pattern, n).WithCancellation(ct)) + await foreach (var content in SearchLikeFallbackQuery(context, pattern, n).WithCancellation(ct).ConfigureAwait(false)) { results.Add(content); } @@ -121,12 +119,12 @@ public async Task> SearchAsync(string query, int n = 5, Ca public async Task> SearchHybridAsync(string query, float[]? queryEmbedding = null, int topK = 5, CancellationToken ct = default) { - await EnsureInitializedAsync(ct); + await _init.EnsureCompletedAsync(InitSchemaAsync, ct).ConfigureAwait(false); // MsSql backend does not support embeddings — fall back to LIKE search - await using var context = await contextFactory.CreateDbContextAsync(ct); + await using var context = await contextFactory.CreateDbContextAsync(ct).ConfigureAwait(false); var pattern = $"%{EscapeLikePattern(query)}%"; var results = new List(); - await foreach (var fact in SearchHybridLikeQuery(context, pattern, topK).WithCancellation(ct)) + await foreach (var fact in SearchHybridLikeQuery(context, pattern, topK).WithCancellation(ct).ConfigureAwait(false)) { results.Add(fact); } @@ -134,7 +132,7 @@ public async Task> SearchHybridAsync(string query, float[]? var ids = results.Select(f => f.Id).ToList(); if (ids.Count > 0) { - await UpdateAccessCountsAsync(ids, ct); + await UpdateAccessCountsAsync(ids, ct).ConfigureAwait(false); } return results; @@ -142,10 +140,10 @@ public async Task> SearchHybridAsync(string query, float[]? public async Task> ListFactsAsync(CancellationToken ct = default) { - await EnsureInitializedAsync(ct); - await using var context = await contextFactory.CreateDbContextAsync(ct); + await _init.EnsureCompletedAsync(InitSchemaAsync, ct).ConfigureAwait(false); + await using var context = await contextFactory.CreateDbContextAsync(ct).ConfigureAwait(false); var facts = new List(); - await foreach (var fact in ListAllFactsQuery(context).WithCancellation(ct)) + await foreach (var fact in ListAllFactsQuery(context).WithCancellation(ct).ConfigureAwait(false)) { facts.Add(fact); } @@ -155,33 +153,33 @@ public async Task> ListFactsAsync(CancellationToken ct = def public async Task ClearAsync(CancellationToken ct = default) { - await EnsureInitializedAsync(ct); - await using var context = await contextFactory.CreateDbContextAsync(ct); - await context.Database.ExecuteSqlRawAsync($"TRUNCATE TABLE {Fact.TableName}", ct); + await _init.EnsureCompletedAsync(InitSchemaAsync, ct).ConfigureAwait(false); + await using var context = await contextFactory.CreateDbContextAsync(ct).ConfigureAwait(false); + await context.Database.ExecuteSqlRawAsync($"DELETE FROM {Fact.TableName}", ct).ConfigureAwait(false); } public async Task PruneExpiredFactsAsync(TimeSpan maxAge, CancellationToken ct = default) { - await EnsureInitializedAsync(ct); - await using var context = await contextFactory.CreateDbContextAsync(ct); + await _init.EnsureCompletedAsync(InitSchemaAsync, ct).ConfigureAwait(false); + await using var context = await contextFactory.CreateDbContextAsync(ct).ConfigureAwait(false); var cutoff = DateTimeOffset.UtcNow - maxAge; return await context.Facts .Where(f => f.CreatedAt < cutoff) - .ExecuteDeleteAsync(ct); + .ExecuteDeleteAsync(ct).ConfigureAwait(false); } private async Task UpdateAccessCountsAsync(List ids, CancellationToken ct = default) { try { - await using var ctx = await contextFactory.CreateDbContextAsync(ct); + await using var ctx = await contextFactory.CreateDbContextAsync(ct).ConfigureAwait(false); var now = DateTimeOffset.UtcNow; await ctx.Facts .Where(f => ids.Contains(f.Id)) .ExecuteUpdateAsync(s => s .SetProperty(f => f.AccessCount, f => f.AccessCount + 1) - .SetProperty(f => f.LastAccessedAt, now), ct); + .SetProperty(f => f.LastAccessedAt, now), ct).ConfigureAwait(false); } catch (Exception ex) { @@ -193,43 +191,13 @@ await ctx.Facts private static string EscapeLikePattern(string query) => query.Replace("[", "[[]").Replace("%", "[%]").Replace("_", "[_]"); - private async Task EnsureInitializedAsync(CancellationToken ct) - { - if (_initTask is { IsCompletedSuccessfully: true }) - { - return; - } - - await _initLock.WaitAsync(ct); - try - { - if (_initTask is { IsCompletedSuccessfully: true }) - { - return; - } - - var task = InitSchemaAsync(ct); - _initTask = task; - await task; - } - catch - { - _initTask = null; // allow retry on next call - throw; - } - finally - { - _initLock.Release(); - } - } - [RequiresDynamicCode("EF Core MigrateAsync builds the design-time model at runtime. Not compatible with NativeAOT; use migration bundles for AOT deployment.")] private async Task InitSchemaAsync(CancellationToken ct) { - await using var context = await contextFactory.CreateDbContextAsync(ct); + await using var context = await contextFactory.CreateDbContextAsync(ct).ConfigureAwait(false); using var migrationCts = CancellationTokenSource.CreateLinkedTokenSource(ct); migrationCts.CancelAfter(TimeSpan.FromSeconds(30)); - await context.Database.MigrateAsync(migrationCts.Token); + await context.Database.MigrateAsync(migrationCts.Token).ConfigureAwait(false); const string sql = $""" @@ -240,7 +208,7 @@ IF NOT EXISTS (SELECT 1 FROM sys.columns WHERE object_id = OBJECT_ID('{Fact.Tabl END """; - await context.Database.ExecuteSqlRawAsync(sql, migrationCts.Token); + await context.Database.ExecuteSqlRawAsync(sql, migrationCts.Token).ConfigureAwait(false); } [LoggerMessage(EventId = 1, Level = LogLevel.Warning, Message = "Memory operation failed: {Message}")] diff --git a/src/clawsharp/Memory/OllamaEmbeddingProvider.cs b/src/clawsharp/Memory/OllamaEmbeddingProvider.cs index ae294d54..5606c173 100644 --- a/src/clawsharp/Memory/OllamaEmbeddingProvider.cs +++ b/src/clawsharp/Memory/OllamaEmbeddingProvider.cs @@ -25,10 +25,10 @@ public async Task EmbedAsync(string text, CancellationToken ct = defaul using var httpRequest = new HttpRequestMessage(HttpMethod.Post, $"{_baseUrl}/api/embeddings"); httpRequest.Content = JsonContent.Create(request, EmbeddingJsonContext.Default.OllamaEmbeddingRequest); - using var response = await client.SendAsync(httpRequest, ct); + using var response = await client.SendAsync(httpRequest, ct).ConfigureAwait(false); response.EnsureSuccessStatusCode(); - var result = await response.Content.ReadFromJsonAsync(EmbeddingJsonContext.Default.OllamaEmbeddingResponse, ct); + var result = await response.Content.ReadFromJsonAsync(EmbeddingJsonContext.Default.OllamaEmbeddingResponse, ct).ConfigureAwait(false); var embedding = result?.Embedding; if (embedding is null || embedding.Length == 0) diff --git a/src/clawsharp/Memory/OpenAiEmbeddingProvider.cs b/src/clawsharp/Memory/OpenAiEmbeddingProvider.cs index fdde1efc..4e19bebf 100644 --- a/src/clawsharp/Memory/OpenAiEmbeddingProvider.cs +++ b/src/clawsharp/Memory/OpenAiEmbeddingProvider.cs @@ -36,10 +36,10 @@ public async Task EmbedAsync(string text, CancellationToken ct = defaul httpRequest.Headers.Authorization = new System.Net.Http.Headers.AuthenticationHeaderValue("Bearer", apiKey); httpRequest.Content = JsonContent.Create(request, EmbeddingJsonContext.Default.EmbeddingRequest); - using var response = await client.SendAsync(httpRequest, ct); + using var response = await client.SendAsync(httpRequest, ct).ConfigureAwait(false); response.EnsureSuccessStatusCode(); - var result = await response.Content.ReadFromJsonAsync(EmbeddingJsonContext.Default.EmbeddingResponse, ct); + var result = await response.Content.ReadFromJsonAsync(EmbeddingJsonContext.Default.EmbeddingResponse, ct).ConfigureAwait(false); var embedding = result?.Data is { Length: > 0 } ? result.Data[0].Embedding : null; if (embedding is null || embedding.Length == 0) diff --git a/src/clawsharp/Memory/Postgres/PostgresKnowledgeStore.cs b/src/clawsharp/Memory/Postgres/PostgresKnowledgeStore.cs index 30a26755..342136d6 100644 --- a/src/clawsharp/Memory/Postgres/PostgresKnowledgeStore.cs +++ b/src/clawsharp/Memory/Postgres/PostgresKnowledgeStore.cs @@ -22,60 +22,59 @@ public sealed partial class PostgresKnowledgeStore( { private const int CandidateCount = 30; - private volatile Task? _initTask; - private readonly SemaphoreSlim _initLock = new(1, 1); + private readonly LazyAsyncInit _init = new(); public async Task UpsertChunksAsync(Guid sourceId, IReadOnlyList chunks, CancellationToken ct = default) { - await EnsureInitializedAsync(ct); - await using var context = await contextFactory.CreateDbContextAsync(ct); + await _init.EnsureCompletedAsync(InitSchemaAsync, ct).ConfigureAwait(false); + await using var context = await contextFactory.CreateDbContextAsync(ct).ConfigureAwait(false); // Delete existing chunks for source (re-ingestion replaces all) - await context.KnowledgeChunks.Where(c => c.KnowledgeSourceId == sourceId).ExecuteDeleteAsync(ct); + await context.KnowledgeChunks.Where(c => c.KnowledgeSourceId == sourceId).ExecuteDeleteAsync(ct).ConfigureAwait(false); // Insert new chunks context.KnowledgeChunks.AddRange(chunks); - await context.SaveChangesAsync(ct); + await context.SaveChangesAsync(ct).ConfigureAwait(false); // Update source chunk count await context.KnowledgeSources .Where(s => s.Id == sourceId) .ExecuteUpdateAsync(s => s .SetProperty(x => x.ChunkCount, chunks.Count) - .SetProperty(x => x.UpdatedAt, DateTimeOffset.UtcNow), ct); + .SetProperty(x => x.UpdatedAt, DateTimeOffset.UtcNow), ct).ConfigureAwait(false); } public async Task DeleteByDocumentAsync(Guid sourceId, string sourceUri, CancellationToken ct = default) { - await EnsureInitializedAsync(ct); - await using var context = await contextFactory.CreateDbContextAsync(ct); + await _init.EnsureCompletedAsync(InitSchemaAsync, ct).ConfigureAwait(false); + await using var context = await contextFactory.CreateDbContextAsync(ct).ConfigureAwait(false); await context.KnowledgeChunks .Where(c => c.KnowledgeSourceId == sourceId && c.SourceUri == sourceUri) - .ExecuteDeleteAsync(ct); + .ExecuteDeleteAsync(ct).ConfigureAwait(false); } public async Task DeleteBySourceAsync(Guid sourceId, CancellationToken ct = default) { - await EnsureInitializedAsync(ct); - await using var context = await contextFactory.CreateDbContextAsync(ct); + await _init.EnsureCompletedAsync(InitSchemaAsync, ct).ConfigureAwait(false); + await using var context = await contextFactory.CreateDbContextAsync(ct).ConfigureAwait(false); // Cascade delete: deleting the source removes all chunks via FK cascade - await context.KnowledgeSources.Where(s => s.Id == sourceId).ExecuteDeleteAsync(ct); + await context.KnowledgeSources.Where(s => s.Id == sourceId).ExecuteDeleteAsync(ct).ConfigureAwait(false); } public async Task> SearchAsync( float[]? queryEmbedding, string queryText, AclFilter acl, int topK = 5, CancellationToken ct = default) { - await EnsureInitializedAsync(ct); - await using var context = await contextFactory.CreateDbContextAsync(ct); + await _init.EnsureCompletedAsync(InitSchemaAsync, ct).ConfigureAwait(false); + await using var context = await contextFactory.CreateDbContextAsync(ct).ConfigureAwait(false); // Path 1: tsvector FTS with ts_rank - var ftsResults = await FtsSearchAsync(context, queryText, acl, ct); + var ftsResults = await FtsSearchAsync(context, queryText, acl, ct).ConfigureAwait(false); // Path 2: pgvector KNN cosine distance (skipped when embedding is null per D-13) var vectorResults = queryEmbedding is not null - ? await VectorSearchAsync(context, queryEmbedding, acl, ct) + ? await VectorSearchAsync(context, queryEmbedding, acl, ct).ConfigureAwait(false) : []; // Build chunk lookup and RRF merge @@ -92,39 +91,48 @@ public async Task> SearchAsync( var chunkLookup = await context.KnowledgeChunks .AsNoTracking() .Where(c => allIds.Contains(c.Id)) - .ToDictionaryAsync(c => c.Id, ct); + .ToDictionaryAsync(c => c.Id, ct).ConfigureAwait(false); return RrfMerger.Merge(ftsResults, vectorResults, chunkLookup, topK: topK); } public async Task> ListSourcesAsync(CancellationToken ct = default) { - await EnsureInitializedAsync(ct); - await using var context = await contextFactory.CreateDbContextAsync(ct); - return await context.KnowledgeSources.AsNoTracking().OrderByDescending(s => s.CreatedAt).ToListAsync(ct); + await _init.EnsureCompletedAsync(InitSchemaAsync, ct).ConfigureAwait(false); + await using var context = await contextFactory.CreateDbContextAsync(ct).ConfigureAwait(false); + return await context.KnowledgeSources.AsNoTracking().OrderByDescending(s => s.CreatedAt).ToListAsync(ct).ConfigureAwait(false); } public async Task GetSourceAsync(Guid id, CancellationToken ct = default) { - await EnsureInitializedAsync(ct); - await using var context = await contextFactory.CreateDbContextAsync(ct); - return await context.KnowledgeSources.AsNoTracking().FirstOrDefaultAsync(s => s.Id == id, ct); + await _init.EnsureCompletedAsync(InitSchemaAsync, ct).ConfigureAwait(false); + await using var context = await contextFactory.CreateDbContextAsync(ct).ConfigureAwait(false); + return await context.KnowledgeSources.AsNoTracking().FirstOrDefaultAsync(s => s.Id == id, ct).ConfigureAwait(false); } /// public async Task> GetDocumentHashesBySourceAsync(Guid sourceId, CancellationToken ct = default) { - await EnsureInitializedAsync(ct); - await using var context = await contextFactory.CreateDbContextAsync(ct); + await _init.EnsureCompletedAsync(InitSchemaAsync, ct).ConfigureAwait(false); + await using var context = await contextFactory.CreateDbContextAsync(ct).ConfigureAwait(false); var pairs = await context.KnowledgeChunks .AsNoTracking() .Where(c => c.KnowledgeSourceId == sourceId) .Select(c => new { c.SourceUri, c.DocumentHash }) .Distinct() - .ToListAsync(ct); + .ToListAsync(ct).ConfigureAwait(false); return pairs.ToDictionary(p => p.SourceUri, p => p.DocumentHash, StringComparer.Ordinal); } + /// + public async Task CountChunksAsync(Guid sourceId, CancellationToken ct = default) + { + await _init.EnsureCompletedAsync(InitSchemaAsync, ct).ConfigureAwait(false); + await using var context = await contextFactory.CreateDbContextAsync(ct).ConfigureAwait(false); + return await context.KnowledgeChunks + .CountAsync(c => c.KnowledgeSourceId == sourceId, ct).ConfigureAwait(false); + } + // ── FTS search ─────────────────────────────────────────────── private async Task> FtsSearchAsync( @@ -151,7 +159,7 @@ ORDER BY ts_rank(knowledge_content_tsv, websearch_to_tsquery('simple', {2})) DES queryText, depts, queryText) .AsNoTracking() .Select(c => c.Id) - .ToListAsync(ct); + .ToListAsync(ct).ConfigureAwait(false); } else { @@ -166,7 +174,7 @@ ORDER BY ts_rank(knowledge_content_tsv, websearch_to_tsquery('simple', {1})) DES queryText, queryText) .AsNoTracking() .Select(c => c.Id) - .ToListAsync(ct); + .ToListAsync(ct).ConfigureAwait(false); } var rank = 1; @@ -208,7 +216,7 @@ ORDER BY ts_rank(knowledge_content_tsv, websearch_to_tsquery('simple', {1})) DES .OrderBy(c => c.Embedding!.CosineDistance(queryVector)) .Take(CandidateCount) .Select(c => c.Id) - .ToListAsync(ct); + .ToListAsync(ct).ConfigureAwait(false); return candidates.Select((id, i) => (id, i + 1)).ToList(); } @@ -221,34 +229,10 @@ ORDER BY ts_rank(knowledge_content_tsv, websearch_to_tsquery('simple', {1})) DES // ── Init ───────────────────────────────────────────────────── - private async Task EnsureInitializedAsync(CancellationToken ct) - { - if (_initTask is { IsCompletedSuccessfully: true }) return; - - await _initLock.WaitAsync(ct); - try - { - if (_initTask is { IsCompletedSuccessfully: true }) return; - - var task = InitSchemaAsync(ct); - _initTask = task; - await task; - } - catch - { - _initTask = null; - throw; - } - finally - { - _initLock.Release(); - } - } - [RequiresDynamicCode("EF Core MigrateAsync requires dynamic code generation.")] private async Task InitSchemaAsync(CancellationToken ct) { - await using var context = await contextFactory.CreateDbContextAsync(ct); + await using var context = await contextFactory.CreateDbContextAsync(ct).ConfigureAwait(false); // Add tsvector generated column + GIN index for KnowledgeChunks content const string contentTsvSql = @@ -268,7 +252,7 @@ CREATE INDEX IF NOT EXISTS knowledge_chunks_tsv_idx END $$; """; - await context.Database.ExecuteSqlRawAsync(contentTsvSql, ct); + await context.Database.ExecuteSqlRawAsync(contentTsvSql, ct).ConfigureAwait(false); LogSchemaInitialized(logger); } diff --git a/src/clawsharp/Memory/Postgres/PostgresMemory.cs b/src/clawsharp/Memory/Postgres/PostgresMemory.cs index 113fd50c..aa80dadb 100644 --- a/src/clawsharp/Memory/Postgres/PostgresMemory.cs +++ b/src/clawsharp/Memory/Postgres/PostgresMemory.cs @@ -37,9 +37,7 @@ public sealed partial class PostgresMemory( private readonly int _embeddingDimension = memoryConfig?.Value.EmbeddingDimension ?? 1536; - private volatile Task? _initTask; - - private readonly SemaphoreSlim _initLock = new(1, 1); + private readonly LazyAsyncInit _init = new(); /// Whether the pgvector extension is available. Set during schema init. private bool _pgvectorAvailable; @@ -85,10 +83,10 @@ private static readonly Func> public async Task GetContextAsync(CancellationToken ct = default) { - await EnsureInitializedAsync(ct); - await using var context = await contextFactory.CreateDbContextAsync(ct); + await _init.EnsureCompletedAsync(InitSchemaAsync, ct).ConfigureAwait(false); + await using var context = await contextFactory.CreateDbContextAsync(ct).ConfigureAwait(false); var facts = new List(); - await foreach (var content in GetRecentContentQuery(context).WithCancellation(ct)) + await foreach (var content in GetRecentContentQuery(context).WithCancellation(ct).ConfigureAwait(false)) { facts.Add($"- {content}"); } @@ -98,14 +96,14 @@ private static readonly Func> public async Task AppendFactAsync(string fact, CancellationToken ct = default) { - await EnsureInitializedAsync(ct); - await using var context = await contextFactory.CreateDbContextAsync(ct); + await _init.EnsureCompletedAsync(InitSchemaAsync, ct).ConfigureAwait(false); + await using var context = await contextFactory.CreateDbContextAsync(ct).ConfigureAwait(false); float[]? embedding = null; if (embeddingProvider is not null) { try { - embedding = await embeddingProvider.EmbedAsync(fact, ct); + embedding = await embeddingProvider.EmbedAsync(fact, ct).ConfigureAwait(false); } catch (Exception ex) { @@ -127,7 +125,7 @@ public async Task AppendFactAsync(string fact, CancellationToken ct = default) }; context.Facts.Add(entity); - await context.SaveChangesAsync(ct); + await context.SaveChangesAsync(ct).ConfigureAwait(false); // Legacy fallback: write JSON TEXT column when pgvector is not available if (embedding is not null && !_pgvectorAvailable) @@ -137,7 +135,7 @@ public async Task AppendFactAsync(string fact, CancellationToken ct = default) var json = EmbeddingMath.Serialize(embedding); await context.Database.ExecuteSqlRawAsync( $"UPDATE \"{Fact.TableName}\" SET embedding = {{0}} WHERE \"Id\" = {{1}}", - new object[] { json, entity.Id }, ct); + new object[] { json, entity.Id }, ct).ConfigureAwait(false); } catch (Exception ex) { @@ -148,16 +146,16 @@ await context.Database.ExecuteSqlRawAsync( public async Task AppendHistoryAsync(string summary, CancellationToken ct = default) { - await EnsureInitializedAsync(ct); - await using var context = await contextFactory.CreateDbContextAsync(ct); + await _init.EnsureCompletedAsync(InitSchemaAsync, ct).ConfigureAwait(false); + await using var context = await contextFactory.CreateDbContextAsync(ct).ConfigureAwait(false); context.History.Add(new HistoryEntry(summary, DateTimeOffset.UtcNow)); - await context.SaveChangesAsync(ct); + await context.SaveChangesAsync(ct).ConfigureAwait(false); } public async Task> SearchAsync(string query, int n = 5, CancellationToken ct = default) { - await EnsureInitializedAsync(ct); - await using var context = await contextFactory.CreateDbContextAsync(ct); + await _init.EnsureCompletedAsync(InitSchemaAsync, ct).ConfigureAwait(false); + await using var context = await contextFactory.CreateDbContextAsync(ct).ConfigureAwait(false); const string sql = $$""" @@ -169,12 +167,12 @@ ORDER BY ts_rank(content_tsv, websearch_to_tsquery('simple', {1})) DESC var results = await context.Database .SqlQueryRaw(sql, query, query, n) - .ToListAsync(ct); + .ToListAsync(ct).ConfigureAwait(false); if (results.Count == 0) { var pattern = $"%{EscapeLikePattern(query)}%"; - await foreach (var content in SearchILikeFallbackQuery(context, pattern, n).WithCancellation(ct)) + await foreach (var content in SearchILikeFallbackQuery(context, pattern, n).WithCancellation(ct).ConfigureAwait(false)) { results.Add(content); } @@ -186,15 +184,15 @@ ORDER BY ts_rank(content_tsv, websearch_to_tsquery('simple', {1})) DESC public async Task> SearchHybridAsync(string query, float[]? queryEmbedding = null, int topK = 5, CancellationToken ct = default) { - await EnsureInitializedAsync(ct); - await using var context = await contextFactory.CreateDbContextAsync(ct); + await _init.EnsureCompletedAsync(InitSchemaAsync, ct).ConfigureAwait(false); + await using var context = await contextFactory.CreateDbContextAsync(ct).ConfigureAwait(false); // If no query embedding, fall back to ILIKE search returning Fact objects if (queryEmbedding is null || queryEmbedding.Length == 0) { var pattern = $"%{EscapeLikePattern(query)}%"; var candidates = new List(); - await foreach (var fact in SearchHybridILikeQuery(context, pattern, topK).WithCancellation(ct)) + await foreach (var fact in SearchHybridILikeQuery(context, pattern, topK).WithCancellation(ct).ConfigureAwait(false)) { candidates.Add(fact); } @@ -202,7 +200,7 @@ public async Task> SearchHybridAsync(string query, float[]? var ids = candidates.Select(f => f.Id).ToList(); if (ids.Count > 0) { - await UpdateAccessCountsAsync(ids, ct); + await UpdateAccessCountsAsync(ids, ct).ConfigureAwait(false); } return candidates; @@ -211,10 +209,10 @@ public async Task> SearchHybridAsync(string query, float[]? // Use native pgvector ANN if available, otherwise fall back to in-process cosine if (_pgvectorAvailable) { - return await SearchHybridPgvectorAsync(query, queryEmbedding, topK, context, ct); + return await SearchHybridPgvectorAsync(query, queryEmbedding, topK, context, ct).ConfigureAwait(false); } - return await SearchHybridFallbackAsync(query, queryEmbedding, topK, context, ct); + return await SearchHybridFallbackAsync(query, queryEmbedding, topK, context, ct).ConfigureAwait(false); } /// @@ -239,7 +237,7 @@ ORDER BY ts_rank(content_tsv, websearch_to_tsquery('simple', {1})) DESC candidateIds = await context.Database .SqlQueryRaw(sql, query, query) - .ToListAsync(ct); + .ToListAsync(ct).ConfigureAwait(false); } catch (Exception ex) { @@ -262,13 +260,13 @@ ORDER BY ts_rank(content_tsv, websearch_to_tsquery('simple', {1})) DESC .OrderBy(f => f.Embedding!.CosineDistance(queryVector)) .Take(topK * OversampleFactor) .Select(f => new { Fact = f, Distance = f.Embedding!.CosineDistance(queryVector) }) - .ToListAsync(ct); + .ToListAsync(ct).ConfigureAwait(false); if (candidatesWithDistance.Count == 0) { // No embeddings at all — fall back to most recent facts var recentFacts = new List(); - await foreach (var fact in GetRecentFactsQuery(context, topK).WithCancellation(ct)) + await foreach (var fact in GetRecentFactsQuery(context, topK).WithCancellation(ct).ConfigureAwait(false)) { recentFacts.Add(fact); } @@ -276,7 +274,7 @@ ORDER BY ts_rank(content_tsv, websearch_to_tsquery('simple', {1})) DESC var fallbackIds = recentFacts.Select(f => f.Id).ToList(); if (fallbackIds.Count > 0) { - await UpdateAccessCountsAsync(fallbackIds, ct); + await UpdateAccessCountsAsync(fallbackIds, ct).ConfigureAwait(false); } return recentFacts; @@ -312,7 +310,7 @@ ORDER BY ts_rank(content_tsv, websearch_to_tsquery('simple', {1})) DESC var returnedIds = scored.Select(f => f.Id).ToList(); if (returnedIds.Count > 0) { - await UpdateAccessCountsAsync(returnedIds, ct); + await UpdateAccessCountsAsync(returnedIds, ct).ConfigureAwait(false); } return scored; @@ -340,7 +338,7 @@ ORDER BY ts_rank(content_tsv, websearch_to_tsquery('simple', {1})) DESC rows = await context.Database .SqlQueryRaw(tsquerySql, query, query) - .ToListAsync(ct); + .ToListAsync(ct).ConfigureAwait(false); if (rows.Count == 0) { @@ -354,7 +352,7 @@ embedding AS "EmbeddingJson" rows = await context.Database .SqlQuery(fallbackSql) - .ToListAsync(ct); + .ToListAsync(ct).ConfigureAwait(false); } } catch (Exception ex) @@ -370,7 +368,7 @@ embedding AS "EmbeddingJson" rows = await context.Database .SqlQueryRaw(fallbackSql) - .ToListAsync(ct); + .ToListAsync(ct).ConfigureAwait(false); } var scored = rows.Select(row => @@ -408,7 +406,7 @@ embedding AS "EmbeddingJson" var returnedIds = scored.Select(f => f.Id).ToList(); if (returnedIds.Count > 0) { - await UpdateAccessCountsAsync(returnedIds, ct); + await UpdateAccessCountsAsync(returnedIds, ct).ConfigureAwait(false); } return scored; @@ -416,10 +414,10 @@ embedding AS "EmbeddingJson" public async Task> ListFactsAsync(CancellationToken ct = default) { - await EnsureInitializedAsync(ct); - await using var context = await contextFactory.CreateDbContextAsync(ct); + await _init.EnsureCompletedAsync(InitSchemaAsync, ct).ConfigureAwait(false); + await using var context = await contextFactory.CreateDbContextAsync(ct).ConfigureAwait(false); var facts = new List(); - await foreach (var fact in ListAllFactsQuery(context).WithCancellation(ct)) + await foreach (var fact in ListAllFactsQuery(context).WithCancellation(ct).ConfigureAwait(false)) { facts.Add(fact); } @@ -429,33 +427,33 @@ public async Task> ListFactsAsync(CancellationToken ct = def public async Task ClearAsync(CancellationToken ct = default) { - await EnsureInitializedAsync(ct); - await using var context = await contextFactory.CreateDbContextAsync(ct); - await context.Database.ExecuteSqlRawAsync($"TRUNCATE TABLE \"{Fact.TableName}\"", ct); + await _init.EnsureCompletedAsync(InitSchemaAsync, ct).ConfigureAwait(false); + await using var context = await contextFactory.CreateDbContextAsync(ct).ConfigureAwait(false); + await context.Database.ExecuteSqlRawAsync($"TRUNCATE TABLE \"{Fact.TableName}\"", ct).ConfigureAwait(false); } public async Task PruneExpiredFactsAsync(TimeSpan maxAge, CancellationToken ct = default) { - await EnsureInitializedAsync(ct); - await using var context = await contextFactory.CreateDbContextAsync(ct); + await _init.EnsureCompletedAsync(InitSchemaAsync, ct).ConfigureAwait(false); + await using var context = await contextFactory.CreateDbContextAsync(ct).ConfigureAwait(false); var cutoff = DateTimeOffset.UtcNow - maxAge; return await context.Facts .Where(f => f.CreatedAt < cutoff) - .ExecuteDeleteAsync(ct); + .ExecuteDeleteAsync(ct).ConfigureAwait(false); } private async Task UpdateAccessCountsAsync(List ids, CancellationToken ct = default) { try { - await using var ctx = await contextFactory.CreateDbContextAsync(ct); + await using var ctx = await contextFactory.CreateDbContextAsync(ct).ConfigureAwait(false); var now = DateTimeOffset.UtcNow; await ctx.Facts .Where(f => ids.Contains(f.Id)) .ExecuteUpdateAsync(s => s .SetProperty(f => f.AccessCount, f => f.AccessCount + 1) - .SetProperty(f => f.LastAccessedAt, now), ct); + .SetProperty(f => f.LastAccessedAt, now), ct).ConfigureAwait(false); } catch (Exception ex) { @@ -468,47 +466,17 @@ await ctx.Facts private static string EscapeLikePattern(string query) => query.Replace(@"\", @"\\").Replace("%", @"\%").Replace("_", @"\_"); - private async Task EnsureInitializedAsync(CancellationToken ct) - { - if (_initTask is { IsCompletedSuccessfully: true }) - { - return; - } - - await _initLock.WaitAsync(ct); - try - { - if (_initTask is { IsCompletedSuccessfully: true }) - { - return; - } - - var task = InitSchemaAsync(ct); - _initTask = task; - await task; - } - catch - { - _initTask = null; // allow retry on next call - throw; - } - finally - { - _initLock.Release(); - } - } - [RequiresDynamicCode( "EF Core MigrateAsync builds the design-time model at runtime. Not compatible with NativeAOT; use migration bundles for AOT deployment.")] private async Task InitSchemaAsync(CancellationToken ct) { - await using var context = await contextFactory.CreateDbContextAsync(ct); + await using var context = await contextFactory.CreateDbContextAsync(ct).ConfigureAwait(false); // MigrateAsync applies pending migrations. For existing databases originally created via // EnsureCreated, the __EFMigrationsHistory table will be created and the initial migration // marked as applied if the schema already matches. using var migrationCts = CancellationTokenSource.CreateLinkedTokenSource(ct); migrationCts.CancelAfter(TimeSpan.FromSeconds(30)); - await context.Database.MigrateAsync(migrationCts.Token); + await context.Database.MigrateAsync(migrationCts.Token).ConfigureAwait(false); // Add content_tsv generated column + GIN index if not already present const string contentTsvSql = @@ -527,7 +495,7 @@ ADD COLUMN content_tsv tsvector END $$; """; - await context.Database.ExecuteSqlRawAsync(contentTsvSql); + await context.Database.ExecuteSqlRawAsync(contentTsvSql).ConfigureAwait(false); // Add legacy TEXT embedding column if not already present (backward compatibility) const string embeddingColumnSql = @@ -543,7 +511,7 @@ SELECT 1 FROM information_schema.columns END $$; """; - await context.Database.ExecuteSqlRawAsync(embeddingColumnSql); + await context.Database.ExecuteSqlRawAsync(embeddingColumnSql).ConfigureAwait(false); // Add access tracking columns if not already present const string accessTrackingSql = @@ -560,13 +528,13 @@ SELECT 1 FROM information_schema.columns END $$; """; - await context.Database.ExecuteSqlRawAsync(accessTrackingSql, cancellationToken: migrationCts.Token); + await context.Database.ExecuteSqlRawAsync(accessTrackingSql, cancellationToken: migrationCts.Token).ConfigureAwait(false); // pgvector: install extension + add vector column + HNSW index var dim = _embeddingDimension; try { - await context.Database.ExecuteSqlRawAsync("CREATE EXTENSION IF NOT EXISTS vector", cancellationToken: migrationCts.Token); + await context.Database.ExecuteSqlRawAsync("CREATE EXTENSION IF NOT EXISTS vector", cancellationToken: migrationCts.Token).ConfigureAwait(false); var pgvecDdl = string.Create(null, stackalloc char[512], $""" @@ -584,7 +552,7 @@ CREATE INDEX IF NOT EXISTS facts_embedding_hnsw_idx END $$; """); - await context.Database.ExecuteSqlRawAsync(pgvecDdl, cancellationToken: migrationCts.Token); + await context.Database.ExecuteSqlRawAsync(pgvecDdl, cancellationToken: migrationCts.Token).ConfigureAwait(false); _pgvectorAvailable = true; LogPgvectorLoaded(logger, dim); } diff --git a/src/clawsharp/Memory/Redis/RedisKnowledgeStore.cs b/src/clawsharp/Memory/Redis/RedisKnowledgeStore.cs index d551f07a..aa4ac3a2 100644 --- a/src/clawsharp/Memory/Redis/RedisKnowledgeStore.cs +++ b/src/clawsharp/Memory/Redis/RedisKnowledgeStore.cs @@ -6,6 +6,7 @@ using NRedisStack; using NRedisStack.RedisStackCommands; using NRedisStack.Search; +using NRedisStack.Search.DataTypes; using NRedisStack.Search.Literals.Enums; using StackExchange.Redis; @@ -53,19 +54,20 @@ public sealed partial class RedisKnowledgeStore( private readonly int _embeddingDimension = memoryConfig?.Value.EmbeddingDimension ?? 1536; - private volatile Task? _initTask; - private readonly SemaphoreSlim _initLock = new(1, 1); + private readonly LazyAsyncInit _init = new(); private volatile bool _vectorSearchEnabled; public async Task UpsertChunksAsync(Guid sourceId, IReadOnlyList chunks, CancellationToken ct = default) { - await EnsureInitializedAsync(ct); + await _init.EnsureCompletedAsync(InitIndexAsync, ct).ConfigureAwait(false); var db = redis.GetDatabase(); // Delete existing chunks for source - await DeleteChunksBySourceIdAsync(db, sourceId); + await DeleteChunksBySourceIdAsync(db, sourceId).ConfigureAwait(false); - // Insert new chunks + // Insert new chunks — pipelined to avoid per-chunk round-trips (M-2) + var batch = db.CreateBatch(); + var upsertTasks = new List(chunks.Count); foreach (var chunk in chunks) { var key = $"{ChunkPrefix}{chunk.Id}"; @@ -86,9 +88,12 @@ public async Task UpsertChunksAsync(Guid sourceId, IReadOnlyList entries.Add(new HashEntry(EmbeddingField, EmbeddingToBlob(chunk.Embedding.ToArray()))); } - await db.HashSetAsync(key, entries.ToArray()); + upsertTasks.Add(batch.HashSetAsync(key, entries.ToArray())); } + batch.Execute(); + await Task.WhenAll(upsertTasks).ConfigureAwait(false); + // Update source metadata var sourceKey = $"{SourcePrefix}{sourceId}"; if (await db.KeyExistsAsync(sourceKey)) @@ -97,43 +102,43 @@ await db.HashSetAsync(sourceKey, [ new HashEntry(ChunkCountField, chunks.Count), new HashEntry(UpdatedAtField, DateTimeOffset.UtcNow.ToString("O")), - ]); + ]).ConfigureAwait(false); } } public async Task DeleteByDocumentAsync(Guid sourceId, string sourceUri, CancellationToken ct = default) { - await EnsureInitializedAsync(ct); + await _init.EnsureCompletedAsync(InitIndexAsync, ct).ConfigureAwait(false); var db = redis.GetDatabase(); - await DeleteChunksBySourceIdAndUriAsync(db, sourceId, sourceUri); + await DeleteChunksBySourceIdAndUriAsync(db, sourceId, sourceUri).ConfigureAwait(false); } public async Task DeleteBySourceAsync(Guid sourceId, CancellationToken ct = default) { - await EnsureInitializedAsync(ct); + await _init.EnsureCompletedAsync(InitIndexAsync, ct).ConfigureAwait(false); var db = redis.GetDatabase(); - await DeleteChunksBySourceIdAsync(db, sourceId); + await DeleteChunksBySourceIdAsync(db, sourceId).ConfigureAwait(false); // Delete source record var sourceKey = $"{SourcePrefix}{sourceId}"; - await db.KeyDeleteAsync(sourceKey); + await db.KeyDeleteAsync(sourceKey).ConfigureAwait(false); } public async Task> SearchAsync( float[]? queryEmbedding, string queryText, AclFilter acl, int topK = 5, CancellationToken ct = default) { - await EnsureInitializedAsync(ct); + await _init.EnsureCompletedAsync(InitIndexAsync, ct).ConfigureAwait(false); var db = redis.GetDatabase(); var ft = db.FT(); // Path 1: Text BM25 search - var ftsResults = await TextSearchAsync(ft, queryText, acl); + var ftsResults = await TextSearchAsync(ft, queryText, acl).ConfigureAwait(false); // Path 2: Vector KNN search (skipped when embedding is null per D-13) var vectorResults = queryEmbedding is not null - ? await VectorSearchAsync(ft, queryEmbedding, acl) + ? await VectorSearchAsync(ft, queryEmbedding, acl).ConfigureAwait(false) : []; // Build chunk lookup from both result sets @@ -147,11 +152,18 @@ public async Task> SearchAsync( return []; } - var chunkLookup = new Dictionary(); - foreach (var id in allIds) + // Pipeline all hash lookups to avoid per-chunk round-trips (M-1) + var hydrateBatch = db.CreateBatch(); + var hydrateTasks = allIds + .Select(id => (Id: id, Task: hydrateBatch.HashGetAllAsync($"{ChunkPrefix}{id}"))) + .ToList(); + hydrateBatch.Execute(); + await Task.WhenAll(hydrateTasks.Select(t => t.Task)).ConfigureAwait(false); + + var chunkLookup = new Dictionary(hydrateTasks.Count); + foreach (var (id, task) in hydrateTasks) { - var key = $"{ChunkPrefix}{id}"; - var hash = await db.HashGetAllAsync(key); + var hash = task.Result; if (hash.Length > 0) { chunkLookup[id] = ChunkFromHash(id, hash); @@ -163,14 +175,14 @@ public async Task> SearchAsync( public async Task> ListSourcesAsync(CancellationToken ct = default) { - await EnsureInitializedAsync(ct); + await _init.EnsureCompletedAsync(InitIndexAsync, ct).ConfigureAwait(false); var db = redis.GetDatabase(); var server = redis.GetServer(redis.GetEndPoints()[0]); var sources = new List(); - await foreach (var key in server.KeysAsync(pattern: $"{SourcePrefix}*")) + await foreach (var key in server.KeysAsync(pattern: $"{SourcePrefix}*").ConfigureAwait(false)) { - var hash = await db.HashGetAllAsync(key); + var hash = await db.HashGetAllAsync(key).ConfigureAwait(false); if (hash.Length > 0) { var idStr = key.ToString()[(SourcePrefix.Length)..]; @@ -186,10 +198,10 @@ public async Task> ListSourcesAsync(CancellationT public async Task GetSourceAsync(Guid id, CancellationToken ct = default) { - await EnsureInitializedAsync(ct); + await _init.EnsureCompletedAsync(InitIndexAsync, ct).ConfigureAwait(false); var db = redis.GetDatabase(); var key = $"{SourcePrefix}{id}"; - var hash = await db.HashGetAllAsync(key); + var hash = await db.HashGetAllAsync(key).ConfigureAwait(false); if (hash.Length == 0) return null; return SourceFromHash(id, hash); } @@ -197,23 +209,131 @@ public async Task> ListSourcesAsync(CancellationT /// public async Task> GetDocumentHashesBySourceAsync(Guid sourceId, CancellationToken ct = default) { - await EnsureInitializedAsync(ct); + await _init.EnsureCompletedAsync(InitIndexAsync, ct).ConfigureAwait(false); var db = redis.GetDatabase(); - var server = redis.GetServer(redis.GetEndPoints()[0]); var result = new Dictionary(StringComparer.Ordinal); + var escapedId = sourceId.ToString().Replace("-", "\\-"); - await foreach (var key in server.KeysAsync(pattern: $"{ChunkPrefix}*")) + // Use FT.SEARCH with @sourceId TAG filter instead of KEYS scan (M-3) + var ft = db.FT(); + try + { + var offset = 0; + const int pageSize = 100; + while (true) + { + var query = new Query($"@{SourceIdField}:{{{escapedId}}}") + .Limit(offset, pageSize) + .ReturnFields(SourceUriField, DocumentHashField) + .Dialect(2); + + var searchResult = await ft.SearchAsync(IndexName, query).ConfigureAwait(false); + foreach (var doc in searchResult.Documents) + { + var sourceUri = (string?)doc[SourceUriField]; + var docHash = (string?)doc[DocumentHashField]; + if (sourceUri is not null && docHash is not null) + { + result.TryAdd(sourceUri, docHash); + } + } + + if (searchResult.Documents.Count < pageSize) break; + offset += pageSize; + } + } + catch (RedisServerException) { - var fields = await db.HashGetAsync(key, [SourceIdField, SourceUriField, DocumentHashField]); - if (fields[0].IsNullOrEmpty || fields[0].ToString() != sourceId.ToString()) continue; - var sourceUri = fields[1].ToString(); - var docHash = fields[2].ToString(); - result.TryAdd(sourceUri, docHash); + // Index not ready — fall back to KEYS scan with batched reads + var server = redis.GetServer(redis.GetEndPoints()[0]); + var keys = new List(); + await foreach (var key in server.KeysAsync(pattern: $"{ChunkPrefix}*").ConfigureAwait(false)) + { + keys.Add(key); + } + + if (keys.Count > 0) + { + var batch = db.CreateBatch(); + var tasks = keys + .Select(k => (Key: k, Task: batch.HashGetAsync(k, [SourceIdField, SourceUriField, DocumentHashField]))) + .ToList(); + batch.Execute(); + await Task.WhenAll(tasks.Select(t => t.Task)).ConfigureAwait(false); + + foreach (var (_, task) in tasks) + { + var fields = task.Result; + if (fields[0].IsNullOrEmpty || fields[0].ToString() != sourceId.ToString()) continue; + var sourceUri = fields[1].ToString(); + var docHash = fields[2].ToString(); + result.TryAdd(sourceUri, docHash); + } + } } return result; } + /// + public async Task CountChunksAsync(Guid sourceId, CancellationToken ct = default) + { + await _init.EnsureCompletedAsync(InitIndexAsync, ct).ConfigureAwait(false); + var db = redis.GetDatabase(); + var ft = db.FT(); + var escapedId = sourceId.ToString().Replace("-", "\\-"); + + try + { + var count = 0; + var offset = 0; + const int pageSize = 100; + while (true) + { + var query = new Query($"@{SourceIdField}:{{{escapedId}}}") + .Limit(offset, pageSize) + .SetNoContent() + .Dialect(2); + + var result = await ft.SearchAsync(IndexName, query).ConfigureAwait(false); + count += result.Documents.Count; + if (result.Documents.Count < pageSize) break; + offset += pageSize; + } + + return count; + } + catch (RedisServerException) + { + // Index not ready -- fall back to KEYS scan + var server = redis.GetServer(redis.GetEndPoints()[0]); + var count = 0; + + var keys = new List(); + await foreach (var key in server.KeysAsync(pattern: $"{ChunkPrefix}*").ConfigureAwait(false)) + { + keys.Add(key); + } + + if (keys.Count > 0) + { + var batch = db.CreateBatch(); + var tasks = keys.Select(k => (Key: k, Task: batch.HashGetAsync(k, SourceIdField))).ToList(); + batch.Execute(); + await Task.WhenAll(tasks.Select(t => t.Task)).ConfigureAwait(false); + + foreach (var (_, task) in tasks) + { + var sid = task.Result; + if (!sid.IsNullOrEmpty && sid.ToString() == sourceId.ToString()) + count++; + } + } + + return count; + } + } + // ── Text search ────────────────────────────────────────────── private async Task> TextSearchAsync( @@ -226,7 +346,7 @@ public async Task> GetDocumentHashesBySource { var escaped = EscapeRediSearchQuery(queryText); var filter = acl.HasRestrictions - ? $"(@{DepartmentField}:{{{string.Join("|", acl.DepartmentIds)}}} {escaped})" + ? $"(@{DepartmentField}:{{{string.Join("|", acl.DepartmentIds.Select(EscapeTagValue))}}} {escaped})" : escaped; var query = new Query(filter) @@ -234,7 +354,7 @@ public async Task> GetDocumentHashesBySource .ReturnFields(ContentField) .Dialect(2); - var result = await ft.SearchAsync(IndexName, query); + var result = await ft.SearchAsync(IndexName, query).ConfigureAwait(false); var rank = 1; foreach (var doc in result.Documents) { @@ -265,7 +385,7 @@ public async Task> GetDocumentHashesBySource var blob = EmbeddingToBlob(queryEmbedding); var preFilter = acl.HasRestrictions - ? $"@{DepartmentField}:{{{string.Join("|", acl.DepartmentIds)}}}" + ? $"@{DepartmentField}:{{{string.Join("|", acl.DepartmentIds.Select(EscapeTagValue))}}}" : "*"; var query = new Query($"({preFilter})=>[KNN {CandidateCount} @{EmbeddingField} $vec AS __score]") @@ -275,7 +395,7 @@ public async Task> GetDocumentHashesBySource .Limit(0, CandidateCount) .Dialect(2); - var result = await ft.SearchAsync(IndexName, query); + var result = await ft.SearchAsync(IndexName, query).ConfigureAwait(false); var results = new List<(Guid ChunkId, int Rank)>(); var rank = 1; foreach (var doc in result.Documents) @@ -300,42 +420,129 @@ public async Task> GetDocumentHashesBySource private async Task DeleteChunksBySourceIdAndUriAsync(IDatabase db, Guid sourceId, string sourceUri) { - var server = redis.GetServer(redis.GetEndPoints()[0]); + // Use FT.SEARCH with @sourceId + @sourceUri TAG filter instead of KEYS scan (M-3) + var ft = db.FT(); var keysToDelete = new List(); + var escapedId = sourceId.ToString().Replace("-", "\\-"); + var escapedUri = EscapeTagValue(sourceUri); - await foreach (var key in server.KeysAsync(pattern: $"{ChunkPrefix}*")) + try { - var fields = await db.HashGetAsync(key, [new RedisValue(SourceIdField), new RedisValue(SourceUriField)]); - if (!fields[0].IsNullOrEmpty && fields[0].ToString() == sourceId.ToString() - && !fields[1].IsNullOrEmpty && fields[1].ToString() == sourceUri) + var offset = 0; + const int pageSize = 100; + while (true) { - keysToDelete.Add(key); + var query = new Query($"@{SourceIdField}:{{{escapedId}}} @{SourceUriField}:{{{escapedUri}}}") + .Limit(offset, pageSize) + .SetNoContent() + .Dialect(2); + + var result = await ft.SearchAsync(IndexName, query).ConfigureAwait(false); + foreach (var doc in result.Documents) + { + keysToDelete.Add(doc.Id); + } + + if (result.Documents.Count < pageSize) break; + offset += pageSize; + } + } + catch (RedisServerException) + { + // Index not ready — fall back to KEYS scan with batched reads + var server = redis.GetServer(redis.GetEndPoints()[0]); + var keys = new List(); + await foreach (var key in server.KeysAsync(pattern: $"{ChunkPrefix}*").ConfigureAwait(false)) + { + keys.Add(key); + } + + if (keys.Count > 0) + { + var batch = db.CreateBatch(); + var tasks = keys + .Select(k => (Key: k, Task: batch.HashGetAsync(k, [new RedisValue(SourceIdField), new RedisValue(SourceUriField)]))) + .ToList(); + batch.Execute(); + await Task.WhenAll(tasks.Select(t => t.Task)).ConfigureAwait(false); + + foreach (var (key, task) in tasks) + { + var fields = task.Result; + if (!fields[0].IsNullOrEmpty && fields[0].ToString() == sourceId.ToString() + && !fields[1].IsNullOrEmpty && fields[1].ToString() == sourceUri) + { + keysToDelete.Add(key); + } + } } } if (keysToDelete.Count > 0) { - await db.KeyDeleteAsync(keysToDelete.ToArray()); + await db.KeyDeleteAsync(keysToDelete.ToArray()).ConfigureAwait(false); } } private async Task DeleteChunksBySourceIdAsync(IDatabase db, Guid sourceId) { - var server = redis.GetServer(redis.GetEndPoints()[0]); + // Use FT.SEARCH with @sourceId TAG filter instead of KEYS scan (M-3) + var ft = db.FT(); var keysToDelete = new List(); + var escapedId = sourceId.ToString().Replace("-", "\\-"); - await foreach (var key in server.KeysAsync(pattern: $"{ChunkPrefix}*")) + try + { + var offset = 0; + const int pageSize = 100; + while (true) + { + var query = new Query($"@{SourceIdField}:{{{escapedId}}}") + .Limit(offset, pageSize) + .SetNoContent() + .Dialect(2); + + var result = await ft.SearchAsync(IndexName, query).ConfigureAwait(false); + foreach (var doc in result.Documents) + { + keysToDelete.Add(doc.Id); + } + + if (result.Documents.Count < pageSize) break; + offset += pageSize; + } + } + catch (RedisServerException) { - var sid = await db.HashGetAsync(key, SourceIdField); - if (!sid.IsNullOrEmpty && sid.ToString() == sourceId.ToString()) + // Index not ready — fall back to KEYS scan with batched reads + var server = redis.GetServer(redis.GetEndPoints()[0]); + var keys = new List(); + await foreach (var key in server.KeysAsync(pattern: $"{ChunkPrefix}*").ConfigureAwait(false)) + { + keys.Add(key); + } + + if (keys.Count > 0) { - keysToDelete.Add(key); + var batch = db.CreateBatch(); + var tasks = keys.Select(k => (Key: k, Task: batch.HashGetAsync(k, SourceIdField))).ToList(); + batch.Execute(); + await Task.WhenAll(tasks.Select(t => t.Task)).ConfigureAwait(false); + + foreach (var (key, task) in tasks) + { + var sid = task.Result; + if (!sid.IsNullOrEmpty && sid.ToString() == sourceId.ToString()) + { + keysToDelete.Add(key); + } + } } } if (keysToDelete.Count > 0) { - await db.KeyDeleteAsync(keysToDelete.ToArray()); + await db.KeyDeleteAsync(keysToDelete.ToArray()).ConfigureAwait(false); } } @@ -435,6 +642,42 @@ private static KnowledgeSource SourceFromHash(Guid id, HashEntry[] entries) }; } + /// Escape special characters in a RediSearch TAG value. + private static string EscapeTagValue(string value) + { + // TAG values need escaping of: , . < > { } [ ] " ' : ; ! @ # $ % ^ & * ( ) - + = ~ + return value + .Replace("\\", "\\\\") + .Replace("-", "\\-") + .Replace(":", "\\:") + .Replace("/", "\\/") + .Replace(".", "\\.") + .Replace("@", "\\@") + .Replace("#", "\\#") + .Replace("$", "\\$") + .Replace("%", "\\%") + .Replace("^", "\\^") + .Replace("&", "\\&") + .Replace("*", "\\*") + .Replace("(", "\\(") + .Replace(")", "\\)") + .Replace("+", "\\+") + .Replace("=", "\\=") + .Replace("~", "\\~") + .Replace("'", "\\'") + .Replace("\"", "\\\"") + .Replace("!", "\\!") + .Replace("{", "\\{") + .Replace("}", "\\}") + .Replace("[", "\\[") + .Replace("]", "\\]") + .Replace("<", "\\<") + .Replace(">", "\\>") + .Replace(";", "\\;") + .Replace(",", "\\,") + .Replace(" ", "\\ "); + } + private static string EscapeRediSearchQuery(string query) { var words = query.Split(' ', StringSplitOptions.RemoveEmptyEntries | StringSplitOptions.TrimEntries); @@ -451,39 +694,15 @@ private static byte[] EmbeddingToBlob(float[] embedding) // ── Init ───────────────────────────────────────────────────── - private async Task EnsureInitializedAsync(CancellationToken ct) - { - if (_initTask is { IsCompletedSuccessfully: true }) return; - - await _initLock.WaitAsync(ct); - try - { - if (_initTask is { IsCompletedSuccessfully: true }) return; - - var task = InitIndexAsync(); - _initTask = task; - await task; - } - catch - { - _initTask = null; - throw; - } - finally - { - _initLock.Release(); - } - } - - private async Task InitIndexAsync() + private async Task InitIndexAsync(CancellationToken _ = default) { var db = redis.GetDatabase(); var ft = db.FT(); try { - await ft.InfoAsync(IndexName); - _vectorSearchEnabled = true; + var info = await ft.InfoAsync(IndexName).ConfigureAwait(false); + _vectorSearchEnabled = IndexHasVectorField(info); LogInitialized(logger, _vectorSearchEnabled); return; } @@ -526,7 +745,7 @@ private async Task InitIndexAsync() try { - await ft.CreateAsync(IndexName, createParams, schema); + await ft.CreateAsync(IndexName, createParams, schema).ConfigureAwait(false); } catch (RedisServerException ex) when (ex.Message.Contains("Index already exists", StringComparison.OrdinalIgnoreCase)) { @@ -536,6 +755,22 @@ private async Task InitIndexAsync() LogInitialized(logger, _vectorSearchEnabled); } + private static bool IndexHasVectorField(InfoResult info) + { + if (info.Attributes is null) return false; + foreach (var attr in info.Attributes) + { + if (attr.TryGetValue("identifier", out var id) + && string.Equals(id.ToString(), EmbeddingField, StringComparison.Ordinal)) + return true; + if (attr.TryGetValue("type", out var type) + && string.Equals(type.ToString(), "VECTOR", StringComparison.OrdinalIgnoreCase)) + return true; + } + + return false; + } + [LoggerMessage(EventId = 1, Level = LogLevel.Warning, Message = "Redis knowledge store operation failed: {Message}")] private static partial void LogOperationFailed(ILogger logger, Exception exception, string message); diff --git a/src/clawsharp/Memory/Redis/RedisMemory.cs b/src/clawsharp/Memory/Redis/RedisMemory.cs index 0ea8c9c4..7ceacc63 100644 --- a/src/clawsharp/Memory/Redis/RedisMemory.cs +++ b/src/clawsharp/Memory/Redis/RedisMemory.cs @@ -53,16 +53,14 @@ public sealed partial class RedisMemory( private readonly int _embeddingDimension = memoryConfig?.Value.EmbeddingDimension ?? 1536; - private volatile Task? _initTask; - - private readonly SemaphoreSlim _initLock = new(1, 1); + private readonly LazyAsyncInit _init = new(); /// Whether vector search is available in the RediSearch index. private volatile bool _vectorSearchEnabled; public async Task GetContextAsync(CancellationToken ct = default) { - await EnsureInitializedAsync(ct); + await _init.EnsureCompletedAsync(InitIndexAsync, ct).ConfigureAwait(false); var db = redis.GetDatabase(); var ft = db.FT(); @@ -75,7 +73,7 @@ public sealed partial class RedisMemory( .ReturnFields(ContentField) .Dialect(2); - var result = await ft.SearchAsync(IndexName, query); + var result = await ft.SearchAsync(IndexName, query).ConfigureAwait(false); foreach (var doc in result.Documents) { var content = (string?)doc[ContentField]; @@ -89,7 +87,7 @@ public sealed partial class RedisMemory( { // Index not ready — fall back to SCAN LogMemoryOperationFailed(logger, ex, "FT.SEARCH failed, falling back to SCAN"); - await FallbackScanFacts(db, facts, RecentContentLimit); + await FallbackScanFacts(db, facts, RecentContentLimit).ConfigureAwait(false); } return facts.Count > 0 ? "## Memory\n" + string.Join("\n", facts) : null; @@ -97,10 +95,10 @@ public sealed partial class RedisMemory( public async Task AppendFactAsync(string fact, CancellationToken ct = default) { - await EnsureInitializedAsync(ct); + await _init.EnsureCompletedAsync(InitIndexAsync, ct).ConfigureAwait(false); var db = redis.GetDatabase(); - var id = await db.StringIncrementAsync(FactSeqKey); + var id = await db.StringIncrementAsync(FactSeqKey).ConfigureAwait(false); var key = $"{FactPrefix}{id}"; var entries = new List @@ -114,7 +112,7 @@ public async Task AppendFactAsync(string fact, CancellationToken ct = default) { try { - var embedding = await embeddingProvider.EmbedAsync(fact, ct); + var embedding = await embeddingProvider.EmbedAsync(fact, ct).ConfigureAwait(false); var blob = EmbeddingToBlob(embedding); entries.Add(new HashEntry(EmbeddingField, blob)); } @@ -125,26 +123,26 @@ public async Task AppendFactAsync(string fact, CancellationToken ct = default) } } - await db.HashSetAsync(key, entries.ToArray()); + await db.HashSetAsync(key, entries.ToArray()).ConfigureAwait(false); } public async Task AppendHistoryAsync(string summary, CancellationToken ct = default) { - await EnsureInitializedAsync(ct); + await _init.EnsureCompletedAsync(InitIndexAsync, ct).ConfigureAwait(false); var db = redis.GetDatabase(); - var id = await db.StringIncrementAsync(HistorySeqKey); + var id = await db.StringIncrementAsync(HistorySeqKey).ConfigureAwait(false); var key = $"{HistoryPrefix}{id}"; await db.HashSetAsync(key, [ new HashEntry(SummaryField, summary), new HashEntry(TsField, DateTimeOffset.UtcNow.ToString("O")), - ]); + ]).ConfigureAwait(false); } public async Task> SearchAsync(string query, int n = 5, CancellationToken ct = default) { - await EnsureInitializedAsync(ct); + await _init.EnsureCompletedAsync(InitIndexAsync, ct).ConfigureAwait(false); var db = redis.GetDatabase(); var ft = db.FT(); @@ -156,7 +154,7 @@ public async Task> SearchAsync(string query, int n = 5, Ca .ReturnFields(ContentField) .Dialect(2); - var result = await ft.SearchAsync(IndexName, ftQuery); + var result = await ft.SearchAsync(IndexName, ftQuery).ConfigureAwait(false); var results = new List(); foreach (var doc in result.Documents) { @@ -173,21 +171,21 @@ public async Task> SearchAsync(string query, int n = 5, Ca { // FT.SEARCH syntax error or index not ready — fall back to SCAN + Contains LogMemoryOperationFailed(logger, ex, ex.Message); - return await ScanContainsSearch(db, query, n); + return await ScanContainsSearch(db, query, n).ConfigureAwait(false); } } public async Task> SearchHybridAsync(string query, float[]? queryEmbedding = null, int topK = 5, CancellationToken ct = default) { - await EnsureInitializedAsync(ct); + await _init.EnsureCompletedAsync(InitIndexAsync, ct).ConfigureAwait(false); var db = redis.GetDatabase(); var ft = db.FT(); // If no query embedding, fall back to text search returning Fact objects if (queryEmbedding is null || queryEmbedding.Length == 0) { - return await SearchTextOnly(db, ft, query, topK); + return await SearchTextOnly(db, ft, query, topK).ConfigureAwait(false); } // Hybrid search: combine text pre-filter with vector search @@ -195,7 +193,7 @@ public async Task> SearchHybridAsync(string query, float[]? { try { - return await SearchHybridVectorAsync(db, ft, query, queryEmbedding, topK); + return await SearchHybridVectorAsync(db, ft, query, queryEmbedding, topK).ConfigureAwait(false); } catch (Exception ex) { @@ -204,33 +202,42 @@ public async Task> SearchHybridAsync(string query, float[]? } // Fallback: text pre-filter + in-process cosine scoring - return await SearchHybridFallbackAsync(db, ft, query, queryEmbedding, topK); + return await SearchHybridFallbackAsync(db, ft, query, queryEmbedding, topK).ConfigureAwait(false); } public async Task> ListFactsAsync(CancellationToken ct = default) { - await EnsureInitializedAsync(ct); + await _init.EnsureCompletedAsync(InitIndexAsync, ct).ConfigureAwait(false); var db = redis.GetDatabase(); - - var facts = new List(); var server = redis.GetServer(redis.GetEndPoints()[0]); - await foreach (var key in server.KeysAsync(pattern: $"{FactPrefix}*")) + // Collect all fact keys first, then pipeline HashGetAllAsync (M-4) + var keyEntries = new List<(long Id, RedisKey Key)>(); + await foreach (var key in server.KeysAsync(pattern: $"{FactPrefix}*").ConfigureAwait(false)) { var keyStr = key.ToString(); - // Skip the sequence key - if (keyStr == FactSeqKey) - { - continue; - } + if (keyStr == FactSeqKey) continue; var id = ParseIdFromKey(keyStr, FactPrefix); - if (id < 0) + if (id >= 0) { - continue; + keyEntries.Add((id, key)); } + } + + if (keyEntries.Count == 0) return []; - var hash = await db.HashGetAllAsync(key); + var batch = db.CreateBatch(); + var tasks = keyEntries + .Select(e => (e.Id, Task: batch.HashGetAllAsync(e.Key))) + .ToList(); + batch.Execute(); + await Task.WhenAll(tasks.Select(t => t.Task)).ConfigureAwait(false); + + var facts = new List(tasks.Count); + foreach (var (id, task) in tasks) + { + var hash = task.Result; if (hash.Length > 0) { facts.Add(FactFromHash(id, hash)); @@ -244,24 +251,24 @@ public async Task> ListFactsAsync(CancellationToken ct = def public async Task ClearAsync(CancellationToken ct = default) { - await EnsureInitializedAsync(ct); + await _init.EnsureCompletedAsync(InitIndexAsync, ct).ConfigureAwait(false); var db = redis.GetDatabase(); var server = redis.GetServer(redis.GetEndPoints()[0]); // Delete all fact keys var keysToDelete = new List(); - await foreach (var key in server.KeysAsync(pattern: $"{FactPrefix}*")) + await foreach (var key in server.KeysAsync(pattern: $"{FactPrefix}*").ConfigureAwait(false)) { keysToDelete.Add(key); } if (keysToDelete.Count > 0) { - await db.KeyDeleteAsync(keysToDelete.ToArray()); + await db.KeyDeleteAsync(keysToDelete.ToArray()).ConfigureAwait(false); } // Reset the sequence counter - await db.KeyDeleteAsync(FactSeqKey); + await db.KeyDeleteAsync(FactSeqKey).ConfigureAwait(false); // History entries are WORM (write-once read-many) — never deleted. // They represent immutable compaction snapshots and are preserved across clears. @@ -269,13 +276,13 @@ public async Task ClearAsync(CancellationToken ct = default) public async Task PruneExpiredFactsAsync(TimeSpan maxAge, CancellationToken ct = default) { - await EnsureInitializedAsync(ct); + await _init.EnsureCompletedAsync(InitIndexAsync, ct).ConfigureAwait(false); var db = redis.GetDatabase(); var server = redis.GetServer(redis.GetEndPoints()[0]); var cutoff = DateTimeOffset.UtcNow - maxAge; var keysToDelete = new List(); - await foreach (var key in server.KeysAsync(pattern: $"{FactPrefix}*")) + await foreach (var key in server.KeysAsync(pattern: $"{FactPrefix}*").ConfigureAwait(false)) { var keyStr = key.ToString(); if (keyStr == FactSeqKey) @@ -283,7 +290,7 @@ public async Task PruneExpiredFactsAsync(TimeSpan maxAge, CancellationToken continue; } - var createdAtVal = await db.HashGetAsync(key, CreatedAtField); + var createdAtVal = await db.HashGetAsync(key, CreatedAtField).ConfigureAwait(false); if (createdAtVal.IsNullOrEmpty) { continue; @@ -297,7 +304,7 @@ public async Task PruneExpiredFactsAsync(TimeSpan maxAge, CancellationToken if (keysToDelete.Count > 0) { - await db.KeyDeleteAsync(keysToDelete.ToArray()); + await db.KeyDeleteAsync(keysToDelete.ToArray()).ConfigureAwait(false); } return keysToDelete.Count; @@ -319,7 +326,7 @@ private async Task> SearchHybridVectorAsync( .Limit(0, knn) .Dialect(2); - var result = await ft.SearchAsync(IndexName, vecQuery); + var result = await ft.SearchAsync(IndexName, vecQuery).ConfigureAwait(false); if (result.TotalResults == 0) { @@ -349,7 +356,7 @@ private async Task> SearchHybridVectorAsync( var returnedIds = scored.Select(f => f.Id).ToList(); if (returnedIds.Count > 0) { - await UpdateAccessCountsAsync(db, returnedIds); + await UpdateAccessCountsAsync(db, returnedIds).ConfigureAwait(false); } return scored; @@ -372,7 +379,7 @@ private async Task> SearchHybridFallbackAsync( .ReturnFields(ContentField, CreatedAtField, AccessCountField, LastAccessedAtField, EmbeddingField) .Dialect(2); - var result = await ft.SearchAsync(IndexName, ftQuery); + var result = await ft.SearchAsync(IndexName, ftQuery).ConfigureAwait(false); candidates = result.Documents .Select(doc => (FactFromDocument(doc), EmbeddingBlobFromDocument(doc))) .ToList(); @@ -380,13 +387,13 @@ private async Task> SearchHybridFallbackAsync( // If text search returned nothing, fall back to most recent facts if (candidates.Count == 0) { - candidates = await LoadRecentFactsWithEmbeddings(db, CandidateLimit); + candidates = await LoadRecentFactsWithEmbeddings(db, CandidateLimit).ConfigureAwait(false); } } catch (Exception ex) { LogMemoryOperationFailed(logger, ex, ex.Message); - candidates = await LoadRecentFactsWithEmbeddings(db, CandidateLimit); + candidates = await LoadRecentFactsWithEmbeddings(db, CandidateLimit).ConfigureAwait(false); } // Step 2: in-process cosine scoring @@ -412,7 +419,7 @@ private async Task> SearchHybridFallbackAsync( var returnedIds = scored.Select(f => f.Id).ToList(); if (returnedIds.Count > 0) { - await UpdateAccessCountsAsync(db, returnedIds); + await UpdateAccessCountsAsync(db, returnedIds).ConfigureAwait(false); } return scored; @@ -430,20 +437,20 @@ private async Task> SearchTextOnly(IDatabase db, SearchComma .ReturnFields(ContentField, CreatedAtField, AccessCountField, LastAccessedAtField) .Dialect(2); - var result = await ft.SearchAsync(IndexName, ftQuery); + var result = await ft.SearchAsync(IndexName, ftQuery).ConfigureAwait(false); candidates = result.Documents.Select(FactFromDocument).ToList(); } catch (Exception ex) { LogMemoryOperationFailed(logger, ex, ex.Message); // Fall back to SCAN + Contains - candidates = await ScanContainsSearchFacts(db, query, topK); + candidates = await ScanContainsSearchFacts(db, query, topK).ConfigureAwait(false); } var ids = candidates.Select(f => f.Id).ToList(); if (ids.Count > 0) { - await UpdateAccessCountsAsync(db, ids); + await UpdateAccessCountsAsync(db, ids).ConfigureAwait(false); } return candidates; @@ -470,23 +477,34 @@ private float ComputeHybridScore(Fact fact, float vectorScore, string query) private async Task FallbackScanFacts(IDatabase db, List results, int limit) { var server = redis.GetServer(redis.GetEndPoints()[0]); - var facts = new List<(long id, string content)>(); - await foreach (var key in server.KeysAsync(pattern: $"{FactPrefix}*")) + // Collect keys first, then pipeline HashGetAsync (M-4) + var keyEntries = new List<(long Id, RedisKey Key)>(); + await foreach (var key in server.KeysAsync(pattern: $"{FactPrefix}*").ConfigureAwait(false)) { var keyStr = key.ToString(); - if (keyStr == FactSeqKey) - { - continue; - } + if (keyStr == FactSeqKey) continue; var id = ParseIdFromKey(keyStr, FactPrefix); - if (id < 0) + if (id >= 0) { - continue; + keyEntries.Add((id, key)); } + } + + if (keyEntries.Count == 0) return; - var content = await db.HashGetAsync(key, ContentField); + var batch = db.CreateBatch(); + var tasks = keyEntries + .Select(e => (e.Id, Task: batch.HashGetAsync(e.Key, ContentField))) + .ToList(); + batch.Execute(); + await Task.WhenAll(tasks.Select(t => t.Task)).ConfigureAwait(false); + + var facts = new List<(long id, string content)>(); + foreach (var (id, task) in tasks) + { + var content = task.Result; if (!content.IsNullOrEmpty) { facts.Add((id, content.ToString())); @@ -502,23 +520,32 @@ private async Task FallbackScanFacts(IDatabase db, List results, int lim private async Task> ScanContainsSearch(IDatabase db, string query, int n) { var server = redis.GetServer(redis.GetEndPoints()[0]); - var results = new List(); - await foreach (var key in server.KeysAsync(pattern: $"{FactPrefix}*")) + // Collect keys first, then pipeline HashGetAsync (M-4) + var keys = new List(); + await foreach (var key in server.KeysAsync(pattern: $"{FactPrefix}*").ConfigureAwait(false)) { - if (key.ToString() == FactSeqKey) + if (key.ToString() != FactSeqKey) { - continue; + keys.Add(key); } + } + + if (keys.Count == 0) return []; - var content = await db.HashGetAsync(key, ContentField); + var batch = db.CreateBatch(); + var tasks = keys.Select(k => (Key: k, Task: batch.HashGetAsync(k, ContentField))).ToList(); + batch.Execute(); + await Task.WhenAll(tasks.Select(t => t.Task)).ConfigureAwait(false); + + var results = new List(); + foreach (var (_, task) in tasks) + { + var content = task.Result; if (!content.IsNullOrEmpty && content.ToString().Contains(query, StringComparison.OrdinalIgnoreCase)) { results.Add(content.ToString()); - if (results.Count >= n) - { - break; - } + if (results.Count >= n) break; } } @@ -528,36 +555,41 @@ private async Task> ScanContainsSearch(IDatabase db, string query, private async Task> ScanContainsSearchFacts(IDatabase db, string query, int n) { var server = redis.GetServer(redis.GetEndPoints()[0]); - var results = new List(); - await foreach (var key in server.KeysAsync(pattern: $"{FactPrefix}*")) + // Collect keys first, then pipeline HashGetAllAsync (M-4) + var keyEntries = new List<(long Id, RedisKey Key)>(); + await foreach (var key in server.KeysAsync(pattern: $"{FactPrefix}*").ConfigureAwait(false)) { var keyStr = key.ToString(); - if (keyStr == FactSeqKey) - { - continue; - } + if (keyStr == FactSeqKey) continue; var id = ParseIdFromKey(keyStr, FactPrefix); - if (id < 0) + if (id >= 0) { - continue; + keyEntries.Add((id, key)); } + } - var hash = await db.HashGetAllAsync(key); - if (hash.Length == 0) - { - continue; - } + if (keyEntries.Count == 0) return []; + + var batch = db.CreateBatch(); + var tasks = keyEntries + .Select(e => (e.Id, Task: batch.HashGetAllAsync(e.Key))) + .ToList(); + batch.Execute(); + await Task.WhenAll(tasks.Select(t => t.Task)).ConfigureAwait(false); + + var results = new List(); + foreach (var (id, task) in tasks) + { + var hash = task.Result; + if (hash.Length == 0) continue; var fact = FactFromHash(id, hash); if (fact.Content.Contains(query, StringComparison.OrdinalIgnoreCase)) { results.Add(fact); - if (results.Count >= n) - { - break; - } + if (results.Count >= n) break; } } @@ -567,27 +599,35 @@ private async Task> ScanContainsSearchFacts(IDatabase db, string quer private async Task> LoadRecentFactsWithEmbeddings(IDatabase db, int limit) { var server = redis.GetServer(redis.GetEndPoints()[0]); - var all = new List<(long id, Fact fact, byte[]? blob)>(); - await foreach (var key in server.KeysAsync(pattern: $"{FactPrefix}*")) + // Collect keys first, then pipeline HashGetAllAsync (M-4) + var keyEntries = new List<(long Id, RedisKey Key)>(); + await foreach (var key in server.KeysAsync(pattern: $"{FactPrefix}*").ConfigureAwait(false)) { var keyStr = key.ToString(); - if (keyStr == FactSeqKey) - { - continue; - } + if (keyStr == FactSeqKey) continue; var id = ParseIdFromKey(keyStr, FactPrefix); - if (id < 0) + if (id >= 0) { - continue; + keyEntries.Add((id, key)); } + } - var hash = await db.HashGetAllAsync(key); - if (hash.Length == 0) - { - continue; - } + if (keyEntries.Count == 0) return []; + + var batch = db.CreateBatch(); + var tasks = keyEntries + .Select(e => (e.Id, Task: batch.HashGetAllAsync(e.Key))) + .ToList(); + batch.Execute(); + await Task.WhenAll(tasks.Select(t => t.Task)).ConfigureAwait(false); + + var all = new List<(long id, Fact fact, byte[]? blob)>(); + foreach (var (id, task) in tasks) + { + var hash = task.Result; + if (hash.Length == 0) continue; var fact = FactFromHash(id, hash); byte[]? blob = null; @@ -613,12 +653,17 @@ private async Task UpdateAccessCountsAsync(IDatabase db, List ids) try { var now = DateTimeOffset.UtcNow.ToString("O"); + var batch = db.CreateBatch(); + var tasks = new List(ids.Count * 2); foreach (var id in ids) { var key = $"{FactPrefix}{id}"; - await db.HashIncrementAsync(key, AccessCountField); - await db.HashSetAsync(key, [new HashEntry(LastAccessedAtField, now)]); + tasks.Add(batch.HashIncrementAsync(key, AccessCountField)); + tasks.Add(batch.HashSetAsync(key, [new HashEntry(LastAccessedAtField, now)])); } + + batch.Execute(); + await Task.WhenAll(tasks).ConfigureAwait(false); } catch (Exception ex) { @@ -629,37 +674,7 @@ private async Task UpdateAccessCountsAsync(IDatabase db, List ids) // ── Initialization ────────────────────────────────────────────────────── - private async Task EnsureInitializedAsync(CancellationToken ct) - { - if (_initTask is { IsCompletedSuccessfully: true }) - { - return; - } - - await _initLock.WaitAsync(ct); - try - { - if (_initTask is { IsCompletedSuccessfully: true }) - { - return; - } - - var task = InitIndexAsync(); - _initTask = task; - await task; - } - catch - { - _initTask = null; // allow retry on next call - throw; - } - finally - { - _initLock.Release(); - } - } - - private async Task InitIndexAsync() + private async Task InitIndexAsync(CancellationToken _ = default) { var db = redis.GetDatabase(); var ft = db.FT(); @@ -667,7 +682,7 @@ private async Task InitIndexAsync() try { // Check if index already exists - await ft.InfoAsync(IndexName); + await ft.InfoAsync(IndexName).ConfigureAwait(false); // Index exists — check if vector field is present _vectorSearchEnabled = embeddingProvider is not null; LogInitialized(logger, _vectorSearchEnabled); @@ -713,7 +728,7 @@ private async Task InitIndexAsync() try { - await ft.CreateAsync(IndexName, createParams, schema); + await ft.CreateAsync(IndexName, createParams, schema).ConfigureAwait(false); } catch (RedisServerException ex) when (ex.Message.Contains("Index already exists", StringComparison.OrdinalIgnoreCase)) { diff --git a/src/clawsharp/Memory/Sqlite/SqliteKnowledgeStore.cs b/src/clawsharp/Memory/Sqlite/SqliteKnowledgeStore.cs index de793776..5c53b86b 100644 --- a/src/clawsharp/Memory/Sqlite/SqliteKnowledgeStore.cs +++ b/src/clawsharp/Memory/Sqlite/SqliteKnowledgeStore.cs @@ -28,32 +28,32 @@ public sealed partial class SqliteKnowledgeStore( private const string FtsTable = "KnowledgeChunks_fts"; private const string EmbeddingColumn = "embedding_json"; - private volatile Task? _initTask; - private readonly SemaphoreSlim _initLock = new(1, 1); + private readonly LazyAsyncInit _init = new(); public async Task UpsertChunksAsync(Guid sourceId, IReadOnlyList chunks, CancellationToken ct = default) { - await EnsureInitializedAsync(ct); - await using var context = await contextFactory.CreateDbContextAsync(ct); - await using var transaction = await context.Database.BeginTransactionAsync(ct); + await _init.EnsureCompletedAsync(InitSchemaAsync, ct).ConfigureAwait(false); + await using var context = await contextFactory.CreateDbContextAsync(ct).ConfigureAwait(false); + await using var transaction = await context.Database.BeginTransactionAsync(ct).ConfigureAwait(false); try { - // Delete existing FTS entries for this source's chunks + // Delete existing FTS entries for this source's chunks in a single batch var existingIds = await context.KnowledgeChunks .Where(c => c.KnowledgeSourceId == sourceId) .Select(c => c.Id) - .ToListAsync(ct); + .ToListAsync(ct).ConfigureAwait(false); - foreach (var id in existingIds) + if (existingIds.Count > 0) { - var idStr = id.ToString(); - await context.Database.ExecuteSqlAsync( - $"DELETE FROM KnowledgeChunks_fts WHERE ChunkId = {idStr}", ct); + // Batch FTS delete: IDs are Guids from our own query (not user input), safe for IN clause + var idCsv = string.Join(",", existingIds.Select(id => $"'{id}'")); + await context.Database.ExecuteSqlRawAsync( + $"DELETE FROM {FtsTable} WHERE ChunkId IN ({idCsv})", ct).ConfigureAwait(false); } // Delete existing chunks via EF - await context.KnowledgeChunks.Where(c => c.KnowledgeSourceId == sourceId).ExecuteDeleteAsync(ct); + await context.KnowledgeChunks.Where(c => c.KnowledgeSourceId == sourceId).ExecuteDeleteAsync(ct).ConfigureAwait(false); // Insert new chunks foreach (var chunk in chunks) @@ -61,24 +61,53 @@ await context.Database.ExecuteSqlAsync( context.KnowledgeChunks.Add(chunk); } - await context.SaveChangesAsync(ct); + await context.SaveChangesAsync(ct).ConfigureAwait(false); - // Store embedding as JSON TEXT and sync FTS5 - foreach (var chunk in chunks) + // Store embeddings as JSON TEXT in a batch using CASE expression + var chunksWithEmbeddings = chunks.Where(c => c.Embedding is not null).ToList(); + if (chunksWithEmbeddings.Count > 0) { - // Store embedding as JSON TEXT column - if (chunk.Embedding is not null) + // Build parameterized batch update: UPDATE ... SET embedding_json = CASE Id WHEN ... END WHERE Id IN (...) + var parameters = new List(); + var caseParts = new List(); + var idParts = new List(); + var paramIdx = 0; + + foreach (var chunk in chunksWithEmbeddings) { - var json = EmbeddingMath.Serialize(chunk.Embedding.ToArray()); + var json = EmbeddingMath.Serialize(chunk.Embedding!.ToArray()); var idStr = chunk.Id.ToString(); - await context.Database.ExecuteSqlAsync( - $"UPDATE KnowledgeChunks SET embedding_json = {json} WHERE Id = {idStr}", ct); + caseParts.Add($"WHEN {{{paramIdx}}} THEN {{{paramIdx + 1}}}"); + idParts.Add($"{{{paramIdx}}}"); + parameters.Add(idStr); + parameters.Add(json); + paramIdx += 2; } - // Sync FTS5 - var chunkIdStr = chunk.Id.ToString(); - await context.Database.ExecuteSqlAsync( - $"INSERT INTO KnowledgeChunks_fts(ChunkId, Content) VALUES ({chunkIdStr}, {chunk.Content})", ct); + var sql = $"UPDATE KnowledgeChunks SET {EmbeddingColumn} = CASE CAST(Id AS TEXT) " + + string.Join(" ", caseParts) + + " END WHERE CAST(Id AS TEXT) IN (" + string.Join(",", idParts) + ")"; + await context.Database.ExecuteSqlRawAsync(sql, parameters, ct).ConfigureAwait(false); + } + + // Batch FTS5 insert: build a single INSERT with multiple VALUES rows + if (chunks.Count > 0) + { + var ftsParams = new List(); + var ftsValueParts = new List(); + var ftsParamIdx = 0; + + foreach (var chunk in chunks) + { + ftsValueParts.Add($"({{{ftsParamIdx}}}, {{{ftsParamIdx + 1}}})"); + ftsParams.Add(chunk.Id.ToString()); + ftsParams.Add(chunk.Content); + ftsParamIdx += 2; + } + + var ftsSql = $"INSERT INTO {FtsTable}(ChunkId, Content) VALUES " + + string.Join(", ", ftsValueParts); + await context.Database.ExecuteSqlRawAsync(ftsSql, ftsParams, ct).ConfigureAwait(false); } // Update source chunk count @@ -86,22 +115,22 @@ await context.KnowledgeSources .Where(s => s.Id == sourceId) .ExecuteUpdateAsync(s => s .SetProperty(x => x.ChunkCount, chunks.Count) - .SetProperty(x => x.UpdatedAt, DateTimeOffset.UtcNow), ct); + .SetProperty(x => x.UpdatedAt, DateTimeOffset.UtcNow), ct).ConfigureAwait(false); - await transaction.CommitAsync(ct); + await transaction.CommitAsync(ct).ConfigureAwait(false); } catch { - await transaction.RollbackAsync(ct); + await transaction.RollbackAsync(ct).ConfigureAwait(false); throw; } } public async Task DeleteByDocumentAsync(Guid sourceId, string sourceUri, CancellationToken ct = default) { - await EnsureInitializedAsync(ct); - await using var context = await contextFactory.CreateDbContextAsync(ct); - await using var transaction = await context.Database.BeginTransactionAsync(ct); + await _init.EnsureCompletedAsync(InitSchemaAsync, ct).ConfigureAwait(false); + await using var context = await contextFactory.CreateDbContextAsync(ct).ConfigureAwait(false); + await using var transaction = await context.Database.BeginTransactionAsync(ct).ConfigureAwait(false); try { @@ -109,34 +138,34 @@ public async Task DeleteByDocumentAsync(Guid sourceId, string sourceUri, Cancell var chunkIds = await context.KnowledgeChunks .Where(c => c.KnowledgeSourceId == sourceId && c.SourceUri == sourceUri) .Select(c => c.Id) - .ToListAsync(ct); + .ToListAsync(ct).ConfigureAwait(false); foreach (var id in chunkIds) { var idStr = id.ToString(); await context.Database.ExecuteSqlAsync( - $"DELETE FROM KnowledgeChunks_fts WHERE ChunkId = {idStr}", ct); + $"DELETE FROM KnowledgeChunks_fts WHERE ChunkId = {idStr}", ct).ConfigureAwait(false); } // Delete only chunks matching both sourceId and sourceUri await context.KnowledgeChunks .Where(c => c.KnowledgeSourceId == sourceId && c.SourceUri == sourceUri) - .ExecuteDeleteAsync(ct); + .ExecuteDeleteAsync(ct).ConfigureAwait(false); - await transaction.CommitAsync(ct); + await transaction.CommitAsync(ct).ConfigureAwait(false); } catch { - await transaction.RollbackAsync(ct); + await transaction.RollbackAsync(ct).ConfigureAwait(false); throw; } } public async Task DeleteBySourceAsync(Guid sourceId, CancellationToken ct = default) { - await EnsureInitializedAsync(ct); - await using var context = await contextFactory.CreateDbContextAsync(ct); - await using var transaction = await context.Database.BeginTransactionAsync(ct); + await _init.EnsureCompletedAsync(InitSchemaAsync, ct).ConfigureAwait(false); + await using var context = await contextFactory.CreateDbContextAsync(ct).ConfigureAwait(false); + await using var transaction = await context.Database.BeginTransactionAsync(ct).ConfigureAwait(false); try { @@ -144,23 +173,23 @@ public async Task DeleteBySourceAsync(Guid sourceId, CancellationToken ct = defa var chunkIds = await context.KnowledgeChunks .Where(c => c.KnowledgeSourceId == sourceId) .Select(c => c.Id) - .ToListAsync(ct); + .ToListAsync(ct).ConfigureAwait(false); foreach (var id in chunkIds) { var idStr = id.ToString(); await context.Database.ExecuteSqlAsync( - $"DELETE FROM KnowledgeChunks_fts WHERE ChunkId = {idStr}", ct); + $"DELETE FROM KnowledgeChunks_fts WHERE ChunkId = {idStr}", ct).ConfigureAwait(false); } // Cascade delete: deleting the source removes all chunks via FK cascade - await context.KnowledgeSources.Where(s => s.Id == sourceId).ExecuteDeleteAsync(ct); + await context.KnowledgeSources.Where(s => s.Id == sourceId).ExecuteDeleteAsync(ct).ConfigureAwait(false); - await transaction.CommitAsync(ct); + await transaction.CommitAsync(ct).ConfigureAwait(false); } catch { - await transaction.RollbackAsync(ct); + await transaction.RollbackAsync(ct).ConfigureAwait(false); throw; } } @@ -168,15 +197,15 @@ await context.Database.ExecuteSqlAsync( public async Task> SearchAsync( float[]? queryEmbedding, string queryText, AclFilter acl, int topK = 5, CancellationToken ct = default) { - await EnsureInitializedAsync(ct); - await using var context = await contextFactory.CreateDbContextAsync(ct); + await _init.EnsureCompletedAsync(InitSchemaAsync, ct).ConfigureAwait(false); + await using var context = await contextFactory.CreateDbContextAsync(ct).ConfigureAwait(false); // Path 1: FTS5 keyword search - var ftsResults = await FtsSearchAsync(context, queryText, acl, ct); + var ftsResults = await FtsSearchAsync(context, queryText, acl, ct).ConfigureAwait(false); // Path 2: In-process cosine vector search (skipped when embedding is null per D-13) var vectorResults = queryEmbedding is not null - ? await VectorSearchAsync(context, queryEmbedding, acl, ct) + ? await VectorSearchAsync(context, queryEmbedding, acl, ct).ConfigureAwait(false) : []; // Build chunk lookup and RRF merge @@ -193,39 +222,48 @@ public async Task> SearchAsync( var chunkLookup = await context.KnowledgeChunks .AsNoTracking() .Where(c => allIds.Contains(c.Id)) - .ToDictionaryAsync(c => c.Id, ct); + .ToDictionaryAsync(c => c.Id, ct).ConfigureAwait(false); return RrfMerger.Merge(ftsResults, vectorResults, chunkLookup, topK: topK); } public async Task> ListSourcesAsync(CancellationToken ct = default) { - await EnsureInitializedAsync(ct); - await using var context = await contextFactory.CreateDbContextAsync(ct); - return await context.KnowledgeSources.AsNoTracking().OrderByDescending(s => s.CreatedAt).ToListAsync(ct); + await _init.EnsureCompletedAsync(InitSchemaAsync, ct).ConfigureAwait(false); + await using var context = await contextFactory.CreateDbContextAsync(ct).ConfigureAwait(false); + return await context.KnowledgeSources.AsNoTracking().OrderByDescending(s => s.CreatedAt).ToListAsync(ct).ConfigureAwait(false); } public async Task GetSourceAsync(Guid id, CancellationToken ct = default) { - await EnsureInitializedAsync(ct); - await using var context = await contextFactory.CreateDbContextAsync(ct); - return await context.KnowledgeSources.AsNoTracking().FirstOrDefaultAsync(s => s.Id == id, ct); + await _init.EnsureCompletedAsync(InitSchemaAsync, ct).ConfigureAwait(false); + await using var context = await contextFactory.CreateDbContextAsync(ct).ConfigureAwait(false); + return await context.KnowledgeSources.AsNoTracking().FirstOrDefaultAsync(s => s.Id == id, ct).ConfigureAwait(false); } /// public async Task> GetDocumentHashesBySourceAsync(Guid sourceId, CancellationToken ct = default) { - await EnsureInitializedAsync(ct); - await using var context = await contextFactory.CreateDbContextAsync(ct); + await _init.EnsureCompletedAsync(InitSchemaAsync, ct).ConfigureAwait(false); + await using var context = await contextFactory.CreateDbContextAsync(ct).ConfigureAwait(false); var pairs = await context.KnowledgeChunks .AsNoTracking() .Where(c => c.KnowledgeSourceId == sourceId) .Select(c => new { c.SourceUri, c.DocumentHash }) .Distinct() - .ToListAsync(ct); + .ToListAsync(ct).ConfigureAwait(false); return pairs.ToDictionary(p => p.SourceUri, p => p.DocumentHash, StringComparer.Ordinal); } + /// + public async Task CountChunksAsync(Guid sourceId, CancellationToken ct = default) + { + await _init.EnsureCompletedAsync(InitSchemaAsync, ct).ConfigureAwait(false); + await using var context = await contextFactory.CreateDbContextAsync(ct).ConfigureAwait(false); + return await context.KnowledgeChunks + .CountAsync(c => c.KnowledgeSourceId == sourceId, ct).ConfigureAwait(false); + } + // ── FTS5 search ────────────────────────────────────────────── private sealed class FtsRow @@ -243,40 +281,45 @@ private sealed class FtsRow { var ftsQuery = SanitizeFtsQuery(queryText); - List rows; + // Always run FTS without ACL filter (CandidateCount is a compile-time const, safe to interpolate) + var sql = $$""" + SELECT f.ChunkId AS "ChunkId" + FROM KnowledgeChunks_fts f + WHERE KnowledgeChunks_fts MATCH {0} + ORDER BY rank + LIMIT {{CandidateCount}} + """; + var rows = await context.Database.SqlQueryRaw(sql, ftsQuery).ToListAsync(ct).ConfigureAwait(false); + + // Parse FTS results into Guid IDs + var parsedIds = new List(); + foreach (var row in rows) + { + if (Guid.TryParse(row.ChunkId, out var id)) + parsedIds.Add(id); + } + + // Post-filter by department ACL via LINQ (no string interpolation of user data) + IEnumerable allowedIds; if (acl.HasRestrictions) { - var deptList = string.Join(",", acl.DepartmentIds.Select(d => $"'{d.Replace("'", "''")}'")); - var sql = $$""" - SELECT f.ChunkId AS "ChunkId" - FROM KnowledgeChunks_fts f - JOIN KnowledgeChunks c ON f.ChunkId = CAST(c.Id AS TEXT) - WHERE KnowledgeChunks_fts MATCH {0} - AND c.DepartmentId IN ({{deptList}}) - ORDER BY rank - LIMIT {{CandidateCount}} - """; - rows = await context.Database.SqlQueryRaw(sql, ftsQuery).ToListAsync(ct); + var depts = acl.DepartmentIds.ToHashSet(StringComparer.Ordinal); + var allowedSet = await context.KnowledgeChunks + .AsNoTracking() + .Where(c => parsedIds.Contains(c.Id) && depts.Contains(c.DepartmentId)) + .Select(c => c.Id) + .ToHashSetAsync(ct).ConfigureAwait(false); + allowedIds = parsedIds.Where(id => allowedSet.Contains(id)); } else { - var sql = $$""" - SELECT f.ChunkId AS "ChunkId" - FROM KnowledgeChunks_fts f - WHERE KnowledgeChunks_fts MATCH {0} - ORDER BY rank - LIMIT {{CandidateCount}} - """; - rows = await context.Database.SqlQueryRaw(sql, ftsQuery).ToListAsync(ct); + allowedIds = parsedIds; } var rank = 1; - foreach (var row in rows) + foreach (var id in allowedIds) { - if (Guid.TryParse(row.ChunkId, out var id)) - { - results.Add((id, rank++)); - } + results.Add((id, rank++)); } } catch (Exception ex) @@ -302,27 +345,33 @@ private sealed class ChunkEmbeddingRow try { - List rows; + // Always load all embeddings (no ACL in SQL); post-filter by department via LINQ + FormattableString sql = + $""" + SELECT CAST(Id AS TEXT) AS "ChunkId", embedding_json AS "EmbeddingJson" + FROM KnowledgeChunks + WHERE embedding_json IS NOT NULL + """; + var rows = await context.Database.SqlQuery(sql).ToListAsync(ct).ConfigureAwait(false); + + // Build department allowlist for post-filtering + HashSet? allowedDepts = null; + Dictionary? deptLookup = null; if (acl.HasRestrictions) { - var deptList = string.Join(",", acl.DepartmentIds.Select(d => $"'{d.Replace("'", "''")}'")); - var sql = $""" - SELECT CAST(Id AS TEXT) AS "ChunkId", embedding_json AS "EmbeddingJson" - FROM KnowledgeChunks - WHERE embedding_json IS NOT NULL - AND DepartmentId IN ({deptList}) - """; - rows = await context.Database.SqlQueryRaw(sql).ToListAsync(ct); - } - else - { - FormattableString sql = - $""" - SELECT CAST(Id AS TEXT) AS "ChunkId", embedding_json AS "EmbeddingJson" - FROM KnowledgeChunks - WHERE embedding_json IS NOT NULL - """; - rows = await context.Database.SqlQuery(sql).ToListAsync(ct); + allowedDepts = acl.DepartmentIds.ToHashSet(StringComparer.Ordinal); + + // Fetch DepartmentId for all candidate chunks in a single query + var candidateIds = rows + .Where(r => Guid.TryParse(r.ChunkId, out _)) + .Select(r => Guid.Parse(r.ChunkId)) + .ToList(); + + deptLookup = await context.KnowledgeChunks + .AsNoTracking() + .Where(c => candidateIds.Contains(c.Id)) + .Select(c => new { c.Id, c.DepartmentId }) + .ToDictionaryAsync(c => c.Id, c => c.DepartmentId, ct).ConfigureAwait(false); } var scored = new List<(Guid id, float score)>(); @@ -330,6 +379,13 @@ WHERE embedding_json IS NOT NULL { if (row.EmbeddingJson is null || !Guid.TryParse(row.ChunkId, out var id)) continue; + // Post-filter: skip chunks not in allowed departments + if (allowedDepts is not null && deptLookup is not null) + { + if (!deptLookup.TryGetValue(id, out var dept) || !allowedDepts.Contains(dept)) + continue; + } + var vec = EmbeddingMath.Deserialize(row.EmbeddingJson); if (vec.Length == 0 || vec.Length != queryEmbedding.Length) continue; @@ -358,40 +414,16 @@ private static string SanitizeFtsQuery(string query) // ── Init ───────────────────────────────────────────────────── - private async Task EnsureInitializedAsync(CancellationToken ct) - { - if (_initTask is { IsCompletedSuccessfully: true }) return; - - await _initLock.WaitAsync(ct); - try - { - if (_initTask is { IsCompletedSuccessfully: true }) return; - - var task = InitSchemaAsync(ct); - _initTask = task; - await task; - } - catch - { - _initTask = null; - throw; - } - finally - { - _initLock.Release(); - } - } - [RequiresDynamicCode("EF Core runtime model building requires dynamic code generation.")] private async Task InitSchemaAsync(CancellationToken ct) { - await using var context = await contextFactory.CreateDbContextAsync(ct); + await using var context = await contextFactory.CreateDbContextAsync(ct).ConfigureAwait(false); // Add embedding JSON TEXT column if not present (graceful migration) try { await context.Database.ExecuteSqlRawAsync( - $"ALTER TABLE {KnowledgeChunk.TableName} ADD COLUMN {EmbeddingColumn} TEXT", ct); + $"ALTER TABLE {KnowledgeChunk.TableName} ADD COLUMN {EmbeddingColumn} TEXT", ct).ConfigureAwait(false); } catch { @@ -400,7 +432,7 @@ await context.Database.ExecuteSqlRawAsync( // Standalone FTS5 table (not content-synced — Guid PKs are not integer rowids) await context.Database.ExecuteSqlRawAsync( - $"CREATE VIRTUAL TABLE IF NOT EXISTS {FtsTable} USING fts5(ChunkId, Content)", ct); + $"CREATE VIRTUAL TABLE IF NOT EXISTS {FtsTable} USING fts5(ChunkId, Content)", ct).ConfigureAwait(false); LogSchemaInitialized(logger); } diff --git a/src/clawsharp/Memory/Sqlite/SqliteMemory.cs b/src/clawsharp/Memory/Sqlite/SqliteMemory.cs index a50e60b9..f096f346 100644 --- a/src/clawsharp/Memory/Sqlite/SqliteMemory.cs +++ b/src/clawsharp/Memory/Sqlite/SqliteMemory.cs @@ -42,9 +42,7 @@ public sealed partial class SqliteMemory( private readonly int _embeddingDimension = memoryConfig?.Value.EmbeddingDimension ?? 1536; - private volatile Task? _initTask; - - private readonly SemaphoreSlim _initLock = new(1, 1); + private readonly LazyAsyncInit _init = new(); /// Whether the vec0 virtual table was successfully created during init. private bool _vecTableReady; @@ -86,10 +84,10 @@ private static readonly Func> public async Task GetContextAsync(CancellationToken ct = default) { - await EnsureInitializedAsync(ct); - await using var context = await contextFactory.CreateDbContextAsync(ct); + await _init.EnsureCompletedAsync(InitSchemaAsync, ct).ConfigureAwait(false); + await using var context = await contextFactory.CreateDbContextAsync(ct).ConfigureAwait(false); var facts = new List(); - await foreach (var content in GetRecentContentQuery(context).WithCancellation(ct)) + await foreach (var content in GetRecentContentQuery(context).WithCancellation(ct).ConfigureAwait(false)) { facts.Add($"- {content}"); } @@ -99,34 +97,34 @@ private static readonly Func> public async Task AppendFactAsync(string fact, CancellationToken ct = default) { - await EnsureInitializedAsync(ct); - await using var context = await contextFactory.CreateDbContextAsync(ct); + await _init.EnsureCompletedAsync(InitSchemaAsync, ct).ConfigureAwait(false); + await using var context = await contextFactory.CreateDbContextAsync(ct).ConfigureAwait(false); // Wrap fact + FTS insert in a transaction to prevent orphaned data on crash - await using var transaction = await context.Database.BeginTransactionAsync(ct); + await using var transaction = await context.Database.BeginTransactionAsync(ct).ConfigureAwait(false); try { var entity = new Fact { Content = fact, CreatedAt = DateTimeOffset.UtcNow }; context.Facts.Add(entity); - await context.SaveChangesAsync(ct); + await context.SaveChangesAsync(ct).ConfigureAwait(false); // Keep the FTS5 shadow table in sync await context.Database.ExecuteSqlAsync( - $"INSERT INTO Facts_fts(rowid, Content) VALUES ({entity.Id}, {fact})", ct); + $"INSERT INTO Facts_fts(rowid, Content) VALUES ({entity.Id}, {fact})", ct).ConfigureAwait(false); - await transaction.CommitAsync(ct); + await transaction.CommitAsync(ct).ConfigureAwait(false); // Embed and store vector if provider is configured (outside transaction — non-critical) if (embeddingProvider is not null) { try { - var embedding = await embeddingProvider.EmbedAsync(fact, ct); + var embedding = await embeddingProvider.EmbedAsync(fact, ct).ConfigureAwait(false); var json = EmbeddingMath.Serialize(embedding); // Store in TEXT column (legacy, used by fallback path) await context.Database.ExecuteSqlAsync( - $"""UPDATE Facts SET embedding = {json} WHERE "Id" = {entity.Id}""", ct); + $"""UPDATE Facts SET embedding = {json} WHERE "Id" = {entity.Id}""", ct).ConfigureAwait(false); // Also insert into vec0 virtual table if available if (_vecTableReady && SqliteVecConnectionInterceptor.VecExtensionLoaded) @@ -134,7 +132,7 @@ await context.Database.ExecuteSqlAsync( try { await context.Database.ExecuteSqlAsync( - $"INSERT INTO Facts_vec(rowid, embedding) VALUES ({entity.Id}, {json})", ct); + $"INSERT INTO Facts_vec(rowid, embedding) VALUES ({entity.Id}, {json})", ct).ConfigureAwait(false); } catch (Exception ex) { @@ -151,23 +149,23 @@ await context.Database.ExecuteSqlAsync( } catch { - await transaction.RollbackAsync(ct); + await transaction.RollbackAsync(ct).ConfigureAwait(false); throw; } } public async Task AppendHistoryAsync(string summary, CancellationToken ct = default) { - await EnsureInitializedAsync(ct); - await using var context = await contextFactory.CreateDbContextAsync(ct); + await _init.EnsureCompletedAsync(InitSchemaAsync, ct).ConfigureAwait(false); + await using var context = await contextFactory.CreateDbContextAsync(ct).ConfigureAwait(false); context.History.Add(new HistoryEntry(summary, DateTimeOffset.UtcNow)); - await context.SaveChangesAsync(ct); + await context.SaveChangesAsync(ct).ConfigureAwait(false); } public async Task> SearchAsync(string query, int n = 5, CancellationToken ct = default) { - await EnsureInitializedAsync(ct); - await using var context = await contextFactory.CreateDbContextAsync(ct); + await _init.EnsureCompletedAsync(InitSchemaAsync, ct).ConfigureAwait(false); + await using var context = await contextFactory.CreateDbContextAsync(ct).ConfigureAwait(false); try { var ftsQuery = SanitizeFtsQuery(query); @@ -182,7 +180,7 @@ ORDER BY rank return await context.Database .SqlQuery(sql) - .ToListAsync(ct); + .ToListAsync(ct).ConfigureAwait(false); } catch (Exception ex) { @@ -190,7 +188,7 @@ ORDER BY rank LogMemoryOperationFailed(logger, ex, ex.Message); var pattern = $"%{EscapeLikePattern(query)}%"; var results = new List(); - await foreach (var content in SearchLikeFallbackQuery(context, pattern, n).WithCancellation(ct)) + await foreach (var content in SearchLikeFallbackQuery(context, pattern, n).WithCancellation(ct).ConfigureAwait(false)) { results.Add(content); } @@ -202,15 +200,15 @@ ORDER BY rank public async Task> SearchHybridAsync(string query, float[]? queryEmbedding = null, int topK = 5, CancellationToken ct = default) { - await EnsureInitializedAsync(ct); - await using var context = await contextFactory.CreateDbContextAsync(ct); + await _init.EnsureCompletedAsync(InitSchemaAsync, ct).ConfigureAwait(false); + await using var context = await contextFactory.CreateDbContextAsync(ct).ConfigureAwait(false); // If no query embedding, fall back to LIKE search returning Fact objects if (queryEmbedding is null || queryEmbedding.Length == 0) { var pattern = $"%{EscapeLikePattern(query)}%"; var candidates = new List(); - await foreach (var fact in SearchHybridLikeQuery(context, pattern, topK).WithCancellation(ct)) + await foreach (var fact in SearchHybridLikeQuery(context, pattern, topK).WithCancellation(ct).ConfigureAwait(false)) { candidates.Add(fact); } @@ -218,7 +216,7 @@ public async Task> SearchHybridAsync(string query, float[]? var ids = candidates.Select(f => f.Id).ToList(); if (ids.Count > 0) { - await UpdateAccessCountsAsync(ids, ct); + await UpdateAccessCountsAsync(ids, ct).ConfigureAwait(false); } return candidates; @@ -229,7 +227,7 @@ public async Task> SearchHybridAsync(string query, float[]? { try { - return await SearchHybridVecAsync(query, queryEmbedding, topK, context, ct); + return await SearchHybridVecAsync(query, queryEmbedding, topK, context, ct).ConfigureAwait(false); } catch (Exception ex) { @@ -238,7 +236,7 @@ public async Task> SearchHybridAsync(string query, float[]? } // Fallback: FTS5 pre-filter + in-process cosine scoring - return await SearchHybridFallbackAsync(query, queryEmbedding, topK, context, ct); + return await SearchHybridFallbackAsync(query, queryEmbedding, topK, context, ct).ConfigureAwait(false); } /// Internal DTO for vec0 KNN query results. @@ -269,7 +267,7 @@ ORDER BY v.distance var vecResults = await context.Database .SqlQuery(sql) - .ToListAsync(ct); + .ToListAsync(ct).ConfigureAwait(false); if (vecResults.Count == 0) { @@ -280,7 +278,7 @@ ORDER BY v.distance var facts = await context.Facts .AsNoTracking() .Where(f => ids.Contains(f.Id)) - .ToListAsync(ct); + .ToListAsync(ct).ConfigureAwait(false); // Merge ANN distance with keyword score for hybrid re-rank var distanceMap = vecResults.ToDictionary(r => r.Id, r => r.Distance); @@ -308,7 +306,7 @@ ORDER BY v.distance var returnedIds = scored.Select(f => f.Id).ToList(); if (returnedIds.Count > 0) { - await UpdateAccessCountsAsync(returnedIds, ct); + await UpdateAccessCountsAsync(returnedIds, ct).ConfigureAwait(false); } return scored; @@ -339,7 +337,7 @@ ORDER BY rank rows = await context.Database .SqlQuery(ftsSql) - .ToListAsync(ct); + .ToListAsync(ct).ConfigureAwait(false); // FTS5 may return 0 results on no match — fall back to most-recent facts if (rows.Count == 0) @@ -354,7 +352,7 @@ FROM Facts ORDER BY "Id" DESC LIMIT {CandidateLimit} rows = await context.Database .SqlQuery(recentSql) - .ToListAsync(ct); + .ToListAsync(ct).ConfigureAwait(false); } } catch (Exception ex) @@ -371,7 +369,7 @@ FROM Facts ORDER BY "Id" DESC LIMIT {CandidateLimit} rows = await context.Database .SqlQuery(recentSql) - .ToListAsync(ct); + .ToListAsync(ct).ConfigureAwait(false); } var scored = rows.Select(row => @@ -413,7 +411,7 @@ FROM Facts ORDER BY "Id" DESC LIMIT {CandidateLimit} var returnedIds = scored.Select(f => f.Id).ToList(); if (returnedIds.Count > 0) { - await UpdateAccessCountsAsync(returnedIds, ct); + await UpdateAccessCountsAsync(returnedIds, ct).ConfigureAwait(false); } return scored; @@ -421,10 +419,10 @@ FROM Facts ORDER BY "Id" DESC LIMIT {CandidateLimit} public async Task> ListFactsAsync(CancellationToken ct = default) { - await EnsureInitializedAsync(ct); - await using var context = await contextFactory.CreateDbContextAsync(ct); + await _init.EnsureCompletedAsync(InitSchemaAsync, ct).ConfigureAwait(false); + await using var context = await contextFactory.CreateDbContextAsync(ct).ConfigureAwait(false); var facts = new List(); - await foreach (var fact in ListAllFactsQuery(context).WithCancellation(ct)) + await foreach (var fact in ListAllFactsQuery(context).WithCancellation(ct).ConfigureAwait(false)) { facts.Add(fact); } @@ -434,33 +432,45 @@ public async Task> ListFactsAsync(CancellationToken ct = def public async Task ClearAsync(CancellationToken ct = default) { - await EnsureInitializedAsync(ct); - await using var context = await contextFactory.CreateDbContextAsync(ct); - // NOTE: Raw SQL bypasses EF WORM validation; database triggers enforce the constraint at DB level. - await context.Database.ExecuteSqlAsync($"DELETE FROM Facts_fts", ct); + await _init.EnsureCompletedAsync(InitSchemaAsync, ct).ConfigureAwait(false); + await using var context = await contextFactory.CreateDbContextAsync(ct).ConfigureAwait(false); + await using var transaction = await context.Database.BeginTransactionAsync(ct).ConfigureAwait(false); - // Clear vec0 table if available - if (_vecTableReady && SqliteVecConnectionInterceptor.VecExtensionLoaded) + try { - try - { - await context.Database.ExecuteSqlAsync($"DELETE FROM Facts_vec", ct); - } - catch (Exception ex) + // NOTE: Raw SQL bypasses EF WORM validation; database triggers enforce the constraint at DB level. + await context.Database.ExecuteSqlAsync($"DELETE FROM Facts_fts", ct).ConfigureAwait(false); + + // Clear vec0 table if available + if (_vecTableReady && SqliteVecConnectionInterceptor.VecExtensionLoaded) { - LogVecClearFailed(logger, ex, ex.Message); + try + { + await context.Database.ExecuteSqlAsync($"DELETE FROM Facts_vec", ct).ConfigureAwait(false); + } + catch (Exception ex) + { + LogVecClearFailed(logger, ex, ex.Message); + } } - } - await context.Facts.ExecuteDeleteAsync(ct); - // History entries are WORM (write-once read-many) — never deleted. - // They represent immutable compaction snapshots and are preserved across clears. + await context.Facts.ExecuteDeleteAsync(ct).ConfigureAwait(false); + // History entries are WORM (write-once read-many) — never deleted. + // They represent immutable compaction snapshots and are preserved across clears. + + await transaction.CommitAsync(ct).ConfigureAwait(false); + } + catch + { + await transaction.RollbackAsync(ct).ConfigureAwait(false); + throw; + } } public async Task PruneExpiredFactsAsync(TimeSpan maxAge, CancellationToken ct = default) { - await EnsureInitializedAsync(ct); - await using var context = await contextFactory.CreateDbContextAsync(ct); + await _init.EnsureCompletedAsync(InitSchemaAsync, ct).ConfigureAwait(false); + await using var context = await contextFactory.CreateDbContextAsync(ct).ConfigureAwait(false); var cutoff = DateTimeOffset.UtcNow - maxAge; // ValueConverter stores DateTimeOffset as ISO 8601 text — lexicographic comparison works for same-offset values. @@ -468,7 +478,7 @@ public async Task PruneExpiredFactsAsync(TimeSpan maxAge, CancellationToken .AsNoTracking() .Where(f => f.CreatedAt < cutoff) .Select(f => f.Id) - .ToListAsync(ct); + .ToListAsync(ct).ConfigureAwait(false); if (expired.Count == 0) { @@ -478,18 +488,18 @@ public async Task PruneExpiredFactsAsync(TimeSpan maxAge, CancellationToken // Batch delete from FTS5 shadow table, vec0 table, and facts atomically to prevent index desync on crash. // IDs are long values so string joining is safe (no injection risk). var idList = string.Join(",", expired); - await using var transaction = await context.Database.BeginTransactionAsync(ct); + await using var transaction = await context.Database.BeginTransactionAsync(ct).ConfigureAwait(false); try { await context.Database.ExecuteSqlRawAsync( - $"DELETE FROM {FtsTable} WHERE rowid IN ({idList})", ct); + $"DELETE FROM {FtsTable} WHERE rowid IN ({idList})", ct).ConfigureAwait(false); if (_vecTableReady && SqliteVecConnectionInterceptor.VecExtensionLoaded) { try { await context.Database.ExecuteSqlRawAsync( - $"DELETE FROM {VecTable} WHERE rowid IN ({idList})", ct); + $"DELETE FROM {VecTable} WHERE rowid IN ({idList})", ct).ConfigureAwait(false); } catch (Exception ex) { @@ -497,13 +507,13 @@ await context.Database.ExecuteSqlRawAsync( } } - await context.Facts.Where(f => expired.Contains(f.Id)).ExecuteDeleteAsync(ct); + await context.Facts.Where(f => expired.Contains(f.Id)).ExecuteDeleteAsync(ct).ConfigureAwait(false); - await transaction.CommitAsync(ct); + await transaction.CommitAsync(ct).ConfigureAwait(false); } catch { - await transaction.RollbackAsync(ct); + await transaction.RollbackAsync(ct).ConfigureAwait(false); throw; } @@ -514,13 +524,13 @@ private async Task UpdateAccessCountsAsync(List ids, CancellationToken ct { try { - await using var ctx = await contextFactory.CreateDbContextAsync(ct); + await using var ctx = await contextFactory.CreateDbContextAsync(ct).ConfigureAwait(false); var now = DateTimeOffset.UtcNow; await ctx.Facts .Where(f => ids.Contains(f.Id)) .ExecuteUpdateAsync(s => s .SetProperty(f => f.AccessCount, f => f.AccessCount + 1) - .SetProperty(f => f.LastAccessedAt, now), ct); + .SetProperty(f => f.LastAccessedAt, now), ct).ConfigureAwait(false); } catch (Exception ex) { @@ -539,54 +549,24 @@ private static string SanitizeFtsQuery(string query) return string.Join(" ", words.Select(w => $"\"{w.Replace("\"", "\"\"")}\"")); } - private async Task EnsureInitializedAsync(CancellationToken ct) - { - if (_initTask is { IsCompletedSuccessfully: true }) - { - return; - } - - await _initLock.WaitAsync(ct); - try - { - if (_initTask is { IsCompletedSuccessfully: true }) - { - return; - } - - var task = InitSchemaAsync(ct); - _initTask = task; - await task; - } - catch - { - _initTask = null; // allow retry on next call - throw; - } - finally - { - _initLock.Release(); - } - } - [RequiresDynamicCode( "EF Core MigrateAsync builds the design-time model at runtime. Not compatible with NativeAOT; use migration bundles for AOT deployment.")] private async Task InitSchemaAsync(CancellationToken ct) { - await using var context = await contextFactory.CreateDbContextAsync(ct); + await using var context = await contextFactory.CreateDbContextAsync(ct).ConfigureAwait(false); // MigrateAsync applies pending migrations. For existing databases originally created via // EnsureCreated, the __EFMigrationsHistory table will be created and the initial migration // marked as applied if the schema already matches. using var migrationCts = CancellationTokenSource.CreateLinkedTokenSource(ct); migrationCts.CancelAfter(TimeSpan.FromSeconds(30)); - await context.Database.MigrateAsync(migrationCts.Token); + await context.Database.MigrateAsync(migrationCts.Token).ConfigureAwait(false); // FTS5 virtual table (content-table backed by facts) await context.Database.ExecuteSqlAsync( - $"CREATE VIRTUAL TABLE IF NOT EXISTS Facts_fts USING fts5(Content, content=Facts, content_rowid=id)"); + $"CREATE VIRTUAL TABLE IF NOT EXISTS Facts_fts USING fts5(Content, content=Facts, content_rowid=id)").ConfigureAwait(false); // Add embedding column if not present (graceful migration) try { - await context.Database.ExecuteSqlAsync($"ALTER TABLE Facts ADD COLUMN embedding TEXT"); + await context.Database.ExecuteSqlAsync($"ALTER TABLE Facts ADD COLUMN embedding TEXT").ConfigureAwait(false); } catch { @@ -596,7 +576,7 @@ await context.Database.ExecuteSqlAsync( // Add access tracking columns if not present (graceful migration) try { - await context.Database.ExecuteSqlAsync($"ALTER TABLE Facts ADD COLUMN AccessCount INTEGER NOT NULL DEFAULT 0"); + await context.Database.ExecuteSqlAsync($"ALTER TABLE Facts ADD COLUMN AccessCount INTEGER NOT NULL DEFAULT 0").ConfigureAwait(false); } catch { @@ -605,7 +585,7 @@ await context.Database.ExecuteSqlAsync( try { - await context.Database.ExecuteSqlAsync($"ALTER TABLE Facts ADD COLUMN LastAccessedAt TEXT"); + await context.Database.ExecuteSqlAsync($"ALTER TABLE Facts ADD COLUMN LastAccessedAt TEXT").ConfigureAwait(false); } catch { @@ -619,14 +599,14 @@ BEFORE UPDATE ON History BEGIN SELECT RAISE(ABORT, 'HistoryEntry is append-only (WORM). UPDATE operations are not allowed.'); END; - """); + """).ConfigureAwait(false); await context.Database.ExecuteSqlRawAsync(""" CREATE TRIGGER IF NOT EXISTS trg_prevent_history_delete BEFORE DELETE ON History BEGIN SELECT RAISE(ABORT, 'HistoryEntry is append-only (WORM). DELETE operations are not allowed.'); END; - """); + """).ConfigureAwait(false); // sqlite-vec: create vec0 virtual table for ANN search if extension is loaded if (SqliteVecConnectionInterceptor.VecExtensionLoaded) @@ -636,7 +616,7 @@ BEFORE DELETE ON History var dim = _embeddingDimension; // DDL: column type definition cannot be parameterized; dim is a trusted int from config var ddl = $"CREATE VIRTUAL TABLE IF NOT EXISTS {VecTable} USING vec0({EmbeddingJsonColumn} float[{dim}])"; - await context.Database.ExecuteSqlRawAsync(ddl); + await context.Database.ExecuteSqlRawAsync(ddl).ConfigureAwait(false); _vecTableReady = true; LogSqliteVecLoaded(logger, dim); } diff --git a/src/clawsharp/Organization/ApprovalQueue.cs b/src/clawsharp/Organization/ApprovalQueue.cs index 3b90437e..6086113d 100644 --- a/src/clawsharp/Organization/ApprovalQueue.cs +++ b/src/clawsharp/Organization/ApprovalQueue.cs @@ -27,6 +27,9 @@ public sealed partial class ApprovalQueue : IHostedService /// Dedup index: "userId\0toolName" -> requestId for pending requests. private readonly ConcurrentDictionary _dedupIndex = new(StringComparer.Ordinal); + /// Tracks pending fire-and-forget storage writes for test flushing. + private readonly ConcurrentBag _pendingWrites = []; + public ApprovalQueue(ApprovalStorage storage, ILogger logger, IOptions config) { _storage = storage; @@ -92,20 +95,34 @@ public async Task InitializeAsync(CancellationToken ct = default) public string Enqueue(OrgUser user, string toolName, ChannelName channel, string senderId) { var dedupKey = DedupKey(user.Name, toolName); + var newId = ApprovalRequest.NewId(); + + // Atomic dedup: only the first writer for this key proceeds + var winningId = _dedupIndex.GetOrAdd(dedupKey, _ => newId); - // Check for existing pending request (dedup) - if (_dedupIndex.TryGetValue(dedupKey, out var existingId) && - _requests.TryGetValue(existingId, out var existing) && - existing.State == ApprovalState.Pending) + if (!string.Equals(winningId, newId, StringComparison.Ordinal)) { - LogDeduplicated(_logger, user.Name, toolName, existingId); - return existingId; + // Lost the race or existing entry — check if existing request is still pending + if (_requests.TryGetValue(winningId, out var existing) && + existing.State == ApprovalState.Pending) + { + LogDeduplicated(_logger, user.Name, toolName, winningId); + return winningId; + } + + // Existing request was resolved; try to replace with our new ID + if (!_dedupIndex.TryUpdate(dedupKey, newId, winningId)) + { + // Another thread beat us; return whatever they set + return _dedupIndex[dedupKey]; + } } + // We own newId; build and persist the request var now = DateTimeOffset.UtcNow; var request = new ApprovalRequest { - Id = ApprovalRequest.NewId(), + Id = newId, UserId = user.Name, ToolName = toolName, Channel = channel.Value, @@ -116,12 +133,8 @@ public string Enqueue(OrgUser user, string toolName, ChannelName channel, string }; _requests[request.Id] = request; - _dedupIndex[dedupKey] = request.Id; - _storage.AppendAsync(request).ContinueWith(t => - { - if (t.IsFaulted) _logger.LogError(t.Exception, "Failed to persist enqueued request {RequestId}", request.Id); - }, TaskContinuationOptions.OnlyOnFaulted); + ScheduleAppend(request); LogEnqueued(_logger, request.Id, user.Name, toolName); return request.Id; @@ -168,10 +181,7 @@ public string Enqueue(OrgUser user, string toolName, ChannelName channel, string _grants[GrantKey(request.UserId, request.ToolName)] = grant; - _storage.AppendAsync(updated).ContinueWith(t => - { - if (t.IsFaulted) _logger.LogError(t.Exception, "Failed to persist approval state for {RequestId}", requestId); - }, TaskContinuationOptions.OnlyOnFaulted); + ScheduleAppend(updated); LogApproved(_logger, requestId, adminName, grantTtl); return grant; @@ -202,10 +212,7 @@ public bool Deny(string requestId, string adminName, string? reason = null) _dedupIndex.TryRemove(DedupKey(request.UserId, request.ToolName), out _); - _storage.AppendAsync(updated).ContinueWith(t => - { - if (t.IsFaulted) _logger.LogError(t.Exception, "Failed to persist denial state for {RequestId}", requestId); - }, TaskContinuationOptions.OnlyOnFaulted); + ScheduleAppend(updated); LogDenied(_logger, requestId, adminName, reason); return true; @@ -244,10 +251,7 @@ public bool Cancel(string requestId, string userId) _dedupIndex.TryRemove(DedupKey(request.UserId, request.ToolName), out _); - _storage.AppendAsync(updated).ContinueWith(t => - { - if (t.IsFaulted) _logger.LogError(t.Exception, "Failed to persist cancellation state for {RequestId}", requestId); - }, TaskContinuationOptions.OnlyOnFaulted); + ScheduleAppend(updated); LogCancelled(_logger, requestId, userId); return true; @@ -325,15 +329,38 @@ private void CleanExpiredRequests() _dedupIndex.TryRemove(DedupKey(request.UserId, request.ToolName), out _); - _storage.AppendAsync(expired).ContinueWith(t => - { - if (t.IsFaulted) _logger.LogError(t.Exception, "Failed to persist expired state for {RequestId}", id); - }, TaskContinuationOptions.OnlyOnFaulted); + ScheduleAppend(expired); LogExpired(_logger, id, request.UserId, request.ToolName); } } + /// + /// Awaits all pending fire-and-forget storage writes. Used by tests to avoid timing-dependent delays. + /// + internal async Task FlushPendingWritesAsync() + { + var tasks = _pendingWrites.ToArray(); + _pendingWrites.Clear(); + await Task.WhenAll(tasks).ConfigureAwait(false); + } + + /// + /// Schedules a fire-and-forget storage append, tracking the task for test flushing. + /// Tracks the original append task (not the error-logging continuation) so that + /// can await successful completion without + /// TaskCanceledException from OnlyOnFaulted continuations. + /// + private void ScheduleAppend(ApprovalRequest request) + { + var appendTask = _storage.AppendAsync(request); + appendTask.ContinueWith(t => + { + if (t.IsFaulted) LogPersistFailed(_logger, request.Id, t.Exception!); + }, TaskContinuationOptions.OnlyOnFaulted); + _pendingWrites.Add(appendTask); + } + private static string DedupKey(string userId, string toolName) => $"{userId}\0{toolName}"; private static string GrantKey(string userId, string toolName) => $"{userId}\0{toolName}"; @@ -364,4 +391,8 @@ private void CleanExpiredRequests() [LoggerMessage(EventId = 7, Level = LogLevel.Information, Message = "Approval request {RequestId} expired for {UserId}:{ToolName}")] private static partial void LogExpired(ILogger logger, string requestId, string userId, string toolName); + + [LoggerMessage(EventId = 8, Level = LogLevel.Error, + Message = "Failed to persist state for {RequestId}")] + private static partial void LogPersistFailed(ILogger logger, string requestId, Exception exception); } diff --git a/src/clawsharp/Organization/ApprovalStorage.cs b/src/clawsharp/Organization/ApprovalStorage.cs index 84ef32fa..26b8ba04 100644 --- a/src/clawsharp/Organization/ApprovalStorage.cs +++ b/src/clawsharp/Organization/ApprovalStorage.cs @@ -39,7 +39,7 @@ public async Task AppendAsync(ApprovalRequest request, CancellationToken ct = de await _writeLock.WaitAsync(ct).ConfigureAwait(false); try { - await File.AppendAllTextAsync(_filePath, json + "\n", ct).ConfigureAwait(false); + await File.AppendAllLinesAsync(_filePath, [json], ct).ConfigureAwait(false); } finally { diff --git a/src/clawsharp/Organization/IdentityResolver.cs b/src/clawsharp/Organization/IdentityResolver.cs index f3359fe6..a7a0541f 100644 --- a/src/clawsharp/Organization/IdentityResolver.cs +++ b/src/clawsharp/Organization/IdentityResolver.cs @@ -15,20 +15,22 @@ namespace Clawsharp.Organization; public sealed class IdentityResolver { /// - /// Immutable snapshot containing both identity indices, swapped atomically - /// to prevent torn reads between index and emailIndex. + /// Immutable snapshot containing the organization config and both identity indices, + /// swapped atomically as a single volatile field to prevent torn reads between + /// orgConfig, index, and emailIndex. /// private sealed record IdentitySnapshot( + OrganizationConfig? OrgConfig, FrozenDictionary Index, FrozenDictionary EmailIndex) { public static readonly IdentitySnapshot Empty = new( + null, FrozenDictionary.Empty, FrozenDictionary.Empty); } private volatile IdentitySnapshot _snapshot = IdentitySnapshot.Empty; - private volatile OrganizationConfig? _orgConfig; /// /// Initializes the resolver, building the inverted index from the current organization config. @@ -46,7 +48,6 @@ public IdentityResolver(IOptions config) private void RebuildIndex(OrganizationConfig? org) { - _orgConfig = org; if (org is null) { _snapshot = IdentitySnapshot.Empty; @@ -69,8 +70,9 @@ private void RebuildIndex(OrganizationConfig? org) } } - // Atomic swap of both indices as a single immutable snapshot + // Single atomic swap of OrgConfig + both indices to prevent torn reads _snapshot = new IdentitySnapshot( + org, builder.ToFrozenDictionary(StringComparer.Ordinal), emailBuilder.ToFrozenDictionary(StringComparer.OrdinalIgnoreCase)); } @@ -83,11 +85,12 @@ private void RebuildIndex(OrganizationConfig? org) /// An with status and optional user. public IdentityResolverResult Resolve(ChannelName channel, string senderId) { - var org = _orgConfig; + // Read snapshot once to ensure OrgConfig and indices are consistent + var snapshot = _snapshot; + var org = snapshot.OrgConfig; if (org is null) return IdentityResolverResult.NoOrg; - var snapshot = _snapshot; var key = $"{channel.Value}:{senderId}"; if (snapshot.Index.TryGetValue(key, out var user)) { @@ -100,7 +103,7 @@ public IdentityResolverResult Resolve(ChannelName channel, string senderId) if (defaults?.RequireEnrollment == true) return IdentityResolverResult.Denied; - var defaultRole = defaults?.DefaultRole ?? "user"; + var defaultRole = defaults?.DefaultRole ?? PolicyDefaults.DefaultRoleName; var defaultRolePolicy = org.Policies?.Roles?.GetValueOrDefault(defaultRole); var guest = OrgUser.Guest(senderId, defaultRole, defaultRolePolicy); return IdentityResolverResult.DefaultedToGuest(guest); @@ -117,7 +120,9 @@ public IdentityResolverResult Resolve(ChannelName channel, string senderId) /// An with resolved user or denial reason. public IdentityResolverResult ResolveFromClaims(IEnumerable claims, IdpConfig idpConfig) { - var org = _orgConfig; + // Read snapshot once to ensure OrgConfig and indices are consistent + var snapshot = _snapshot; + var org = snapshot.OrgConfig; if (org is null) return IdentityResolverResult.NoOrg; @@ -133,7 +138,6 @@ public IdentityResolverResult ResolveFromClaims(IEnumerable claims, IdpCo "Your identity token does not contain an email claim. Contact your administrator."); // Find matching OrgUser by email via O(1) index lookup (case-insensitive) - var snapshot = _snapshot; if (!snapshot.EmailIndex.TryGetValue(email, out var match)) return IdentityResolverResult.DeniedWithMessage( $"Your identity ({email}) is not provisioned in this system. Contact your administrator."); @@ -148,7 +152,7 @@ public IdentityResolverResult ResolveFromClaims(IEnumerable claims, IdpCo } // Map IdP groups to roles via OidcService.MapClaimsToRoles - var defaultRole = org.Policies?.Defaults?.DefaultRole ?? "user"; + var defaultRole = org.Policies?.Defaults?.DefaultRole ?? PolicyDefaults.DefaultRoleName; var mappedRoles = OidcService.MapClaimsToRoles(claimsList, idpConfig, defaultRole); // If MapClaimsToRoles returns null, it means deny behavior triggered (per D-14) @@ -156,31 +160,8 @@ public IdentityResolverResult ResolveFromClaims(IEnumerable claims, IdpCo return IdentityResolverResult.DeniedWithMessage( "Your IdP groups aren't mapped to any roles in this system."); - // Resolve policies for the mapped roles - var resolvedPolicies = new List(); - if (org.Policies?.Roles is { } roleDefs) - { - foreach (var role in mappedRoles) - { - if (roleDefs.TryGetValue(role, out var policy)) - resolvedPolicies.Add(policy); - } - } - - var orgUser = new OrgUser - { - Name = matchedName, - Roles = mappedRoles.ToList().AsReadOnly(), - Department = matchedConfig.Department, - Email = matchedConfig.Email, - Enabled = true, - IsGuest = false, - Metadata = matchedConfig.Metadata is not null - ? new Dictionary(matchedConfig.Metadata, StringComparer.Ordinal).AsReadOnly() - : null, - ResolvedPolicies = resolvedPolicies.AsReadOnly() - }; - + // Build OrgUser using centralized helper, with OIDC-mapped roles instead of config roles + var orgUser = OrgUser.FromConfigWithRoles(matchedName, matchedConfig, org.Policies, mappedRoles); return IdentityResolverResult.Resolved(orgUser); } } diff --git a/src/clawsharp/Organization/LinkTokenStore.cs b/src/clawsharp/Organization/LinkTokenStore.cs index d7053164..d3bcbcd9 100644 --- a/src/clawsharp/Organization/LinkTokenStore.cs +++ b/src/clawsharp/Organization/LinkTokenStore.cs @@ -11,6 +11,7 @@ namespace Clawsharp.Organization; /// public sealed class LinkTokenStore { + private const int CleanupThreshold = 100; private static readonly TimeSpan TokenTtl = TimeSpan.FromMinutes(10); private readonly ConcurrentDictionary _tokens = new(StringComparer.Ordinal); @@ -28,7 +29,7 @@ public LinkTokenStore() public (string Token, string Signature) Generate(string channel, string senderId) { // Lazy cleanup: purge expired tokens when count exceeds threshold - if (_tokens.Count > 100) + if (_tokens.Count > CleanupThreshold) { var now = DateTimeOffset.UtcNow; foreach (var (key, entry) in _tokens) @@ -47,6 +48,27 @@ public LinkTokenStore() return (token, signature); } + /// + /// Non-destructive signature and existence check. Returns true if the token exists, + /// the signature matches, and the token has not expired. Does NOT consume the token. + /// Use before initiating OIDC redirect to reject invalid link URLs early (MED-02). + /// + public bool Peek(string token, string signature) + { + var expectedSig = Sign(token); + if (!CryptographicOperations.FixedTimeEquals( + Encoding.UTF8.GetBytes(signature), + Encoding.UTF8.GetBytes(expectedSig))) + { + return false; + } + + if (!_tokens.TryGetValue(token, out var linkToken)) + return false; + + return linkToken.ExpiresAt > DateTimeOffset.UtcNow; + } + /// /// Validates a token and signature pair. Returns the if valid, null otherwise. /// Performs constant-time signature comparison (per D-24). diff --git a/src/clawsharp/Organization/OidcService.cs b/src/clawsharp/Organization/OidcService.cs index 0ae62c2a..a08a6afc 100644 --- a/src/clawsharp/Organization/OidcService.cs +++ b/src/clawsharp/Organization/OidcService.cs @@ -73,9 +73,9 @@ public async Task BuildAuthorizationUrlAsync( public static (string CodeVerifier, string CodeChallenge) GeneratePkce() { var verifierBytes = RandomNumberGenerator.GetBytes(64); - var codeVerifier = Base64UrlEncode(verifierBytes); + var codeVerifier = Base64UrlEncoder.Encode(verifierBytes); var challengeBytes = SHA256.HashData(Encoding.ASCII.GetBytes(codeVerifier)); - var codeChallenge = Base64UrlEncode(challengeBytes); + var codeChallenge = Base64UrlEncoder.Encode(challengeBytes); return (codeVerifier, codeChallenge); } @@ -161,48 +161,11 @@ public static (string State, string Nonce) GenerateStateAndNonce() public async Task?> ValidateIdTokenAsync( string idToken, string nonce, CancellationToken ct = default) { - var oidcConfig = await _configManager.GetConfigurationAsync(ct).ConfigureAwait(false); - - var validationParams = new TokenValidationParameters - { - ValidateIssuer = true, - ValidIssuer = oidcConfig.Issuer, - ValidateAudience = true, - ValidAudience = _config.ClientId, - ValidateLifetime = true, - IssuerSigningKeys = oidcConfig.SigningKeys, - ValidateIssuerSigningKey = true, - ClockSkew = TimeSpan.FromMinutes(2) - }; - - var result = await _tokenHandler.ValidateTokenAsync(idToken, validationParams).ConfigureAwait(false); - - if (!result.IsValid) - { - // If key not found, force JWKS refresh and retry once (per D-18/D-19) - if (result.Exception is SecurityTokenSignatureKeyNotFoundException) - { - LogJwksRefresh(_logger); - _configManager.RequestRefresh(); - oidcConfig = await _configManager.GetConfigurationAsync(ct).ConfigureAwait(false); - validationParams.IssuerSigningKeys = oidcConfig.SigningKeys; - result = await _tokenHandler.ValidateTokenAsync(idToken, validationParams).ConfigureAwait(false); - - if (!result.IsValid) - { - LogTokenValidationFailed(_logger, result.Exception?.Message ?? "unknown"); - return null; - } - } - else - { - LogTokenValidationFailed(_logger, result.Exception?.Message ?? "unknown"); - return null; - } - } + var jwt = await ValidateTokenCoreAsync(idToken, ct).ConfigureAwait(false); + if (jwt is null) + return null; // Validate nonce claim - var jwt = (JsonWebToken)result.SecurityToken; var tokenNonce = jwt.Claims.FirstOrDefault(c => string.Equals(c.Type, "nonce", StringComparison.Ordinal))?.Value; if (!string.Equals(tokenNonce, nonce, StringComparison.Ordinal)) @@ -226,6 +189,19 @@ public static (string State, string Nonce) GenerateStateAndNonce() /// Claims on success, null on validation failure. public async Task?> ValidateBearerTokenAsync( string jwt, CancellationToken ct = default) + { + // No nonce validation -- MCP Bearer tokens are pre-issued, not from OIDC auth code flow + var token = await ValidateTokenCoreAsync(jwt, ct).ConfigureAwait(false); + return token?.Claims; + } + + /// + /// Core JWT validation with JWKS key-rotation retry. Builds + /// from OIDC discovery, validates the token, and on + /// forces a JWKS refresh and retries once (per D-18/D-19). + /// + /// The validated on success, or null on validation failure. + private async Task ValidateTokenCoreAsync(string token, CancellationToken ct) { var oidcConfig = await _configManager.GetConfigurationAsync(ct).ConfigureAwait(false); @@ -241,7 +217,7 @@ public static (string State, string Nonce) GenerateStateAndNonce() ClockSkew = TimeSpan.FromMinutes(2) }; - var result = await _tokenHandler.ValidateTokenAsync(jwt, validationParams).ConfigureAwait(false); + var result = await _tokenHandler.ValidateTokenAsync(token, validationParams).ConfigureAwait(false); if (!result.IsValid) { @@ -252,24 +228,22 @@ public static (string State, string Nonce) GenerateStateAndNonce() _configManager.RequestRefresh(); oidcConfig = await _configManager.GetConfigurationAsync(ct).ConfigureAwait(false); validationParams.IssuerSigningKeys = oidcConfig.SigningKeys; - result = await _tokenHandler.ValidateTokenAsync(jwt, validationParams).ConfigureAwait(false); + result = await _tokenHandler.ValidateTokenAsync(token, validationParams).ConfigureAwait(false); if (!result.IsValid) { - LogBearerTokenValidationFailed(_logger, result.Exception?.Message ?? "unknown"); + LogTokenValidationFailed(_logger, result.Exception?.Message ?? "unknown"); return null; } } else { - LogBearerTokenValidationFailed(_logger, result.Exception?.Message ?? "unknown"); + LogTokenValidationFailed(_logger, result.Exception?.Message ?? "unknown"); return null; } } - // No nonce validation -- MCP Bearer tokens are pre-issued, not from OIDC auth code flow - var token = (JsonWebToken)result.SecurityToken; - return token.Claims; + return (JsonWebToken)result.SecurityToken; } /// @@ -329,14 +303,6 @@ public static (string State, string Nonce) GenerateStateAndNonce() return mappedRoles; } - private static string Base64UrlEncode(byte[] bytes) - { - return Convert.ToBase64String(bytes) - .TrimEnd('=') - .Replace('+', '-') - .Replace('/', '_'); - } - [LoggerMessage(EventId = 1, Level = LogLevel.Warning, Message = "OIDC nonce mismatch: id_token nonce does not match expected value")] private static partial void LogNonceMismatch(ILogger logger); @@ -356,8 +322,4 @@ private static string Base64UrlEncode(byte[] bytes) [LoggerMessage(EventId = 5, Level = LogLevel.Warning, Message = "Token exchange response did not contain id_token")] private static partial void LogNoIdTokenInResponse(ILogger logger); - - [LoggerMessage(EventId = 6, Level = LogLevel.Warning, - Message = "Bearer token validation failed: {Error}")] - private static partial void LogBearerTokenValidationFailed(ILogger logger, string error); } diff --git a/src/clawsharp/Organization/OrgUser.cs b/src/clawsharp/Organization/OrgUser.cs index 5a6ec8a5..41b95267 100644 --- a/src/clawsharp/Organization/OrgUser.cs +++ b/src/clawsharp/Organization/OrgUser.cs @@ -64,6 +64,40 @@ public static OrgUser FromConfig(string name, OrgUserConfig userConfig, Policies }; } + /// + /// Creates an from a config entry with overridden roles (e.g., OIDC-mapped roles). + /// Centralizes policy resolution so both channel resolution and OIDC paths use the same logic. + /// + public static OrgUser FromConfigWithRoles( + string name, OrgUserConfig userConfig, PoliciesConfig? policies, IReadOnlyCollection roles) + { + var resolvedPolicies = new List(); + if (policies?.Roles is { } roleDefs) + { + foreach (var role in roles) + { + if (roleDefs.TryGetValue(role, out var policy)) + { + resolvedPolicies.Add(policy); + } + } + } + + return new OrgUser + { + Name = name, + Roles = roles as IReadOnlyList ?? roles.ToList().AsReadOnly(), + Department = userConfig.Department, + Email = userConfig.Email, + Enabled = userConfig.Enabled, + IsGuest = false, + Metadata = userConfig.Metadata is not null + ? new Dictionary(userConfig.Metadata, StringComparer.Ordinal).AsReadOnly() + : null, + ResolvedPolicies = resolvedPolicies.AsReadOnly() + }; + } + /// /// Creates a guest for unknown senders assigned the default role. /// diff --git a/src/clawsharp/Organization/PolicyEvaluator.cs b/src/clawsharp/Organization/PolicyEvaluator.cs index 07f4e1b2..ab6cd5ca 100644 --- a/src/clawsharp/Organization/PolicyEvaluator.cs +++ b/src/clawsharp/Organization/PolicyEvaluator.cs @@ -1,6 +1,7 @@ using System.Collections.Concurrent; using Clawsharp.Config.Organization; using Clawsharp.Tools; +using Microsoft.Extensions.Logging; namespace Clawsharp.Organization; @@ -20,10 +21,28 @@ namespace Clawsharp.Organization; /// D-10: RequireApproval lists use union (most restrictive -- any role's approval requirement applies). /// /// -public sealed class PolicyEvaluator +public sealed partial class PolicyEvaluator { private readonly ConcurrentDictionary _denialCounts = new(StringComparer.Ordinal); + private readonly ILogger? _logger; private const int SuspiciousDenialThreshold = 3; + private const int MaxDenialEntries = 10_000; + + /// + /// Creates a PolicyEvaluator with logger support for denial cap warnings. + /// Preferred constructor for DI registration. + /// + public PolicyEvaluator(ILogger logger) + { + _logger = logger; + } + + /// + /// Creates a PolicyEvaluator without logging. Used by CLI commands and tests. + /// + public PolicyEvaluator() + { + } /// /// Merges all resolved role policies for a user into a single . @@ -114,10 +133,17 @@ public PolicyDecision MergeRoles(OrgUser? user) /// /// Records a denial for the given session. Returns true when the threshold /// is reached and an audit event should be logged (D-04). + /// Evicts excess entries when the cap is reached to prevent unbounded growth. /// public bool RecordDenial(string sessionId) { var count = _denialCounts.AddOrUpdate(sessionId, 1, (_, c) => c + 1); + + if (_denialCounts.Count > MaxDenialEntries) + { + EvictExcessDenialEntries(); + } + return count == SuspiciousDenialThreshold; } @@ -203,17 +229,7 @@ private static bool EvaluateConditions(AbacCondition? when, AbacContext context) // Role condition: user must have the specified role if (when.Role is { } requiredRole) { - var hasRole = false; - foreach (var userRole in context.User.Roles) - { - if (string.Equals(userRole, requiredRole, StringComparison.Ordinal)) - { - hasRole = true; - break; - } - } - - if (!hasRole) + if (!ContainsOrdinal(context.User.Roles, requiredRole)) return false; } @@ -221,21 +237,8 @@ private static bool EvaluateConditions(AbacCondition? when, AbacContext context) if (when.Channel is not null) { var channelNames = when.GetChannelNames(); - if (channelNames.Count > 0) - { - var channelMatch = false; - foreach (var ch in channelNames) - { - if (string.Equals(ch, context.Channel.Value, StringComparison.Ordinal)) - { - channelMatch = true; - break; - } - } - - if (!channelMatch) - return false; - } + if (channelNames.Count > 0 && !ContainsOrdinal(channelNames, context.Channel.Value)) + return false; } // TimeWindow condition: frozen timestamp must be within at least one window @@ -279,5 +282,49 @@ private static bool EvaluateConditions(AbacCondition? when, AbacContext context) _ => string.Empty }; + /// + /// Evicts entries with the lowest denial counts when the cap is exceeded. + /// This prevents unbounded memory growth from sessions that are never pruned. + /// + private void EvictExcessDenialEntries() + { + var excess = _denialCounts.Count - MaxDenialEntries; + if (excess <= 0) + return; + + LogDenialCountCapReached(_logger, _denialCounts.Count, MaxDenialEntries); + + // Evict entries with the lowest counts first (least suspicious sessions) + var toEvict = _denialCounts + .OrderBy(kv => kv.Value) + .Take(excess + MaxDenialEntries / 10) // Evict a 10% buffer to avoid frequent eviction + .Select(kv => kv.Key) + .ToList(); + + foreach (var key in toEvict) + { + _denialCounts.TryRemove(key, out _); + } + } + + [LoggerMessage(EventId = 1, Level = LogLevel.Warning, + Message = "Denial count tracker reached capacity ({Count}/{MaxEntries}); evicting lowest-count entries")] + private static partial void LogDenialCountCapReached(ILogger? logger, int count, int maxEntries); + private static ToolSensitivity ParseSensitivity(string? value) => ToolSensitivityParser.Parse(value); + + /// + /// Returns true if contains using ordinal comparison. + /// Used for role and channel matching in ABAC condition evaluation. + /// + private static bool ContainsOrdinal(IReadOnlyList list, string value) + { + for (var i = 0; i < list.Count; i++) + { + if (string.Equals(list[i], value, StringComparison.Ordinal)) + return true; + } + + return false; + } } diff --git a/src/clawsharp/Organization/PolicyExplainer.cs b/src/clawsharp/Organization/PolicyExplainer.cs index 2d75e7c6..eb956de4 100644 --- a/src/clawsharp/Organization/PolicyExplainer.cs +++ b/src/clawsharp/Organization/PolicyExplainer.cs @@ -92,7 +92,7 @@ private static string ExplainDefault( { var effect = rule.Effect.ToUpperInvariant(); var toolPattern = rule.When?.Tool ?? "*"; - sb.AppendLine($" [EXPIRED] {effect} {toolPattern} (rule: {ruleId}, expired {rule.ExpiresAt!.Value.ToString("O")})"); + sb.AppendLine($" [EXPIRED] {effect} {toolPattern} (rule: {ruleId}, expired {rule.ExpiresAt.Value.ToString("O")})"); } else { @@ -152,7 +152,7 @@ private static string ExplainVerbose( if (isExpired) { - sb.AppendLine($" Rule {ruleId}: [EXPIRED {rule.ExpiresAt!.Value.ToString("O")}] skipped"); + sb.AppendLine($" Rule {ruleId}: [EXPIRED {rule.ExpiresAt.Value.ToString("O")}] skipped"); } else { diff --git a/src/clawsharp/Organization/PolicySimulator.cs b/src/clawsharp/Organization/PolicySimulator.cs index 6ed39be6..639103ae 100644 --- a/src/clawsharp/Organization/PolicySimulator.cs +++ b/src/clawsharp/Organization/PolicySimulator.cs @@ -153,7 +153,70 @@ private static string SimulateToolVerbose( var rbacAllowed = decision.IsToolAllowed(toolName); sb.AppendLine($"RBAC: {(rbacAllowed ? "allowed" : "denied")} ({(rbacAllowed ? "matches pattern" : "no matching pattern")})"); - // ABAC check + AppendVerboseAbacSection(sb, decision, toolName); + + // Sensitivity check + sb.AppendLine($"Sensitivity: {ToolSensitivityName(toolSensitivity)} (max: {ToolSensitivityName(decision.MaxSensitivity)}) -> {(toolSensitivity <= decision.MaxSensitivity ? "OK" : "DENIED")}"); + + AppendVerboseBudgetSection(sb, decision, snap, user.Department); + + // Budget exceeded check + var budgetBlocked = IsBudgetExceeded(decision, snap); + + // Result + if (budgetBlocked is not null) + { + sb.Append($"=== Result: BLOCKED ({budgetBlocked}) ==="); + } + else + { + var resultStr = effect switch + { + PolicyEffect.Allowed => "ALLOWED", + PolicyEffect.DeniedByGlob => "DENIED (glob)", + PolicyEffect.DeniedBySensitivity => "DENIED (sensitivity)", + PolicyEffect.DeniedByAbac => "DENIED (ABAC)", + PolicyEffect.ApprovalRequired => "PENDING (approval)", + _ => "DENIED" + }; + sb.Append($"=== Result: {resultStr} ==="); + } + + return sb.ToString().TrimEnd(); + } + + private static string SimulateModelVerbose( + OrgUser user, + PolicyDecision decision, + string modelId, + bool isAllowed) + { + var sb = new StringBuilder(); + sb.AppendLine($"=== Simulation: @{user.Name} -> model {modelId} ==="); + + if (decision.IsUnrestrictedModels) + { + sb.AppendLine("Model access: unrestricted (*)"); + } + else + { + sb.AppendLine($"Model patterns: {string.Join(", ", decision.ModelPatterns)}"); + sb.AppendLine($"Match: {(isAllowed ? "yes" : "no")}"); + } + + sb.Append($"=== Result: {(isAllowed ? "ALLOWED" : "DENIED")} ==="); + + return sb.ToString().TrimEnd(); + } + + // ── Helpers ── + + /// Appends the ABAC evaluation trace for verbose tool simulation. + private static void AppendVerboseAbacSection( + StringBuilder sb, + PolicyDecision decision, + string toolName) + { if (decision.AbacDenyToolPatterns.Count > 0 || decision.AbacExceptionToolPatterns.Count > 0) { var abacDenied = false; @@ -187,29 +250,33 @@ private static string SimulateToolVerbose( { sb.AppendLine("ABAC: no rules configured"); } + } - // Sensitivity check - sb.AppendLine($"Sensitivity: {ToolSensitivityName(toolSensitivity)} (max: {ToolSensitivityName(decision.MaxSensitivity)}) -> {(toolSensitivity <= decision.MaxSensitivity ? "OK" : "DENIED")}"); - - // Budget + /// Appends the budget usage/limits trace for verbose tool simulation. + private static void AppendVerboseBudgetSection( + StringBuilder sb, + PolicyDecision decision, + BudgetSnapshot snap, + string? department) + { sb.Append("Budget: "); if (decision.Budget is { } ub) { var parts = new List(); if (ub.Daily > 0) { - parts.Add($"Personal ${snap.UserDailyUsed:F2}/${ub.Daily:F2} daily"); + parts.Add($"Personal ${snap.UserDailyUsed:F2}/{ub.Daily:F2} daily"); } if (ub.Monthly > 0) { - parts.Add($"Personal ${snap.UserMonthlyUsed:F2}/${ub.Monthly:F2} monthly"); + parts.Add($"Personal ${snap.UserMonthlyUsed:F2}/{ub.Monthly:F2} monthly"); } - if (snap.DeptBudget is not null && snap.DeptMonthlyUsed.HasValue && user.Department is not null) + if (snap.DeptBudget is not null && snap.DeptMonthlyUsed.HasValue && department is not null) { if (snap.DeptBudget.Monthly > 0) { - parts.Add($"Department ({user.Department}) ${snap.DeptMonthlyUsed.Value:F2}/${snap.DeptBudget.Monthly:F2} monthly"); + parts.Add($"Department ({department}) ${snap.DeptMonthlyUsed.Value:F2}/{snap.DeptBudget.Monthly:F2} monthly"); } } @@ -219,58 +286,8 @@ private static string SimulateToolVerbose( { sb.AppendLine("no limits configured"); } - - // Budget exceeded check - var budgetBlocked = IsBudgetExceeded(decision, snap); - - // Result - if (budgetBlocked is not null) - { - sb.Append($"=== Result: BLOCKED ({budgetBlocked}) ==="); - } - else - { - var resultStr = effect switch - { - PolicyEffect.Allowed => "ALLOWED", - PolicyEffect.DeniedByGlob => "DENIED (glob)", - PolicyEffect.DeniedBySensitivity => "DENIED (sensitivity)", - PolicyEffect.DeniedByAbac => "DENIED (ABAC)", - PolicyEffect.ApprovalRequired => "PENDING (approval)", - _ => "DENIED" - }; - sb.Append($"=== Result: {resultStr} ==="); - } - - return sb.ToString().TrimEnd(); } - private static string SimulateModelVerbose( - OrgUser user, - PolicyDecision decision, - string modelId, - bool isAllowed) - { - var sb = new StringBuilder(); - sb.AppendLine($"=== Simulation: @{user.Name} -> model {modelId} ==="); - - if (decision.IsUnrestrictedModels) - { - sb.AppendLine("Model access: unrestricted (*)"); - } - else - { - sb.AppendLine($"Model patterns: {string.Join(", ", decision.ModelPatterns)}"); - sb.AppendLine($"Match: {(isAllowed ? "yes" : "no")}"); - } - - sb.Append($"=== Result: {(isAllowed ? "ALLOWED" : "DENIED")} ==="); - - return sb.ToString().TrimEnd(); - } - - // ── Helpers ── - private static void AppendBudgetLine( StringBuilder sb, PolicyDecision decision, diff --git a/src/clawsharp/Program.cs b/src/clawsharp/Program.cs index 562ff5b0..5f4004cd 100644 --- a/src/clawsharp/Program.cs +++ b/src/clawsharp/Program.cs @@ -122,7 +122,7 @@ { channel.SetDescription("Channel management"); channel.AddCommand("status") - .WithDescription("Show enabled/disabled state for all 8 channels"); + .WithDescription("Show enabled/disabled state for all channels"); channel.AddCommand("pair-web") .WithDescription("Request a new web pairing code from the running gateway"); }); diff --git a/src/clawsharp/Providers/Anthropic/AnthropicJsonContext.cs b/src/clawsharp/Providers/Anthropic/AnthropicJsonContext.cs index 602d51d1..799712c2 100644 --- a/src/clawsharp/Providers/Anthropic/AnthropicJsonContext.cs +++ b/src/clawsharp/Providers/Anthropic/AnthropicJsonContext.cs @@ -29,4 +29,4 @@ namespace Clawsharp.Providers.Anthropic; [JsonSerializable(typeof(StreamContentBlock))] [JsonSerializable(typeof(StreamMessageStart))] [JsonSourceGenerationOptions(DefaultIgnoreCondition = JsonIgnoreCondition.WhenWritingNull)] -internal partial class AnthropicJsonContext : JsonSerializerContext; \ No newline at end of file +internal sealed partial class AnthropicJsonContext : JsonSerializerContext; \ No newline at end of file diff --git a/src/clawsharp/Providers/Anthropic/AnthropicProvider.cs b/src/clawsharp/Providers/Anthropic/AnthropicProvider.cs index ddb49f5a..2634f4fe 100644 --- a/src/clawsharp/Providers/Anthropic/AnthropicProvider.cs +++ b/src/clawsharp/Providers/Anthropic/AnthropicProvider.cs @@ -107,7 +107,7 @@ public async IAsyncEnumerable StreamAsync( var jsonBytes = JsonSerializer.SerializeToUtf8Bytes(anthReq, AnthropicJsonContext.Default.MessagesRequest); var (http, resp, body) = await ProviderRequestHandler.SendStreamingAsync( - httpClientFactory, anthReq.Url, jsonBytes, ConfigureHeaders, "Anthropic API stream", ct).ConfigureAwait(false); + httpClientFactory, anthReq.Url, jsonBytes, ConfigureHeaders, "Anthropic API stream", ct); using var _ = http; using var __ = resp; diff --git a/src/clawsharp/Providers/Bedrock/BedrockProvider.cs b/src/clawsharp/Providers/Bedrock/BedrockProvider.cs index 4c8e0b93..7c36b664 100644 --- a/src/clawsharp/Providers/Bedrock/BedrockProvider.cs +++ b/src/clawsharp/Providers/Bedrock/BedrockProvider.cs @@ -1,3 +1,4 @@ +using System.Net.Http.Headers; using System.Runtime.CompilerServices; using System.Text; using System.Text.Json; @@ -36,12 +37,15 @@ public async Task ChatAsync(ChatRequest request, CancellationToken var endpoint = $"https://{Service}.{region}.amazonaws.com/model/{encodedModel}/converse"; var uri = new Uri(endpoint); - // Sign the request with SigV4 + // Sign the request with SigV4 (signer needs the JSON string for hash computation) var headers = AwsSigV4Signer.Sign("POST", uri, json, accessKeyId, secretAccessKey, region, Service, DateTimeOffset.UtcNow); + // Use ReadOnlyMemoryContent to avoid StringContent's internal UTF-16 → UTF-8 re-encoding. + var jsonBytes = Encoding.UTF8.GetBytes(json); using var http = httpClientFactory.CreateClient("llm"); using var httpReq = new HttpRequestMessage(HttpMethod.Post, uri); - httpReq.Content = new StringContent(json, Encoding.UTF8, "application/json"); + httpReq.Content = new ReadOnlyMemoryContent(jsonBytes); + httpReq.Content.Headers.ContentType = new MediaTypeHeaderValue("application/json") { CharSet = "utf-8" }; foreach (var (key, value) in headers) { @@ -62,7 +66,7 @@ public async Task ChatAsync(ChatRequest request, CancellationToken throw new HttpRequestException($"Bedrock Converse API error {resp.StatusCode}: {ProviderRequestHandler.SanitizeErrorBody(err)}"); } - await using var stream = await resp.Content.ReadAsStreamAsync(ct).ConfigureAwait(false); + await using var stream = await resp.Content.ReadAsStreamAsync(ct); var converseResponse = await JsonSerializer.DeserializeAsync(stream, BedrockJsonContext.Default.BedrockConverseResponse, ct) .ConfigureAwait(false) ?? throw new InvalidOperationException("Empty response from Bedrock Converse API."); @@ -84,11 +88,15 @@ public async IAsyncEnumerable StreamAsync( var endpoint = $"https://{Service}.{region}.amazonaws.com/model/{encodedModel}/converse-stream"; var uri = new Uri(endpoint); + // Sign the request with SigV4 (signer needs the JSON string for hash computation) var headers = AwsSigV4Signer.Sign("POST", uri, json, accessKeyId, secretAccessKey, region, Service, DateTimeOffset.UtcNow); + // Use ReadOnlyMemoryContent to avoid StringContent's internal UTF-16 → UTF-8 re-encoding. + var jsonBytes = Encoding.UTF8.GetBytes(json); using var http = httpClientFactory.CreateClient("llm"); using var httpReq = new HttpRequestMessage(HttpMethod.Post, uri); - httpReq.Content = new StringContent(json, Encoding.UTF8, "application/json"); + httpReq.Content = new ReadOnlyMemoryContent(jsonBytes); + httpReq.Content.Headers.ContentType = new MediaTypeHeaderValue("application/json") { CharSet = "utf-8" }; foreach (var (key, value) in headers) { @@ -110,10 +118,10 @@ public async IAsyncEnumerable StreamAsync( $"Bedrock ConverseStream API error {resp.StatusCode}: {ProviderRequestHandler.SanitizeErrorBody(err)}"); } - await using var stream = await resp.Content.ReadAsStreamAsync(ct).ConfigureAwait(false); + await using var stream = await resp.Content.ReadAsStreamAsync(ct); var doneEmitted = false; - await foreach (var (eventType, payload) in BedrockStreamParser.ParseAsync(stream, ct)) + await foreach (var (eventType, payload) in BedrockStreamParser.ParseAsync(stream, ct).ConfigureAwait(false)) { switch (eventType) { diff --git a/src/clawsharp/Providers/Bedrock/BedrockStreamParser.cs b/src/clawsharp/Providers/Bedrock/BedrockStreamParser.cs index e0490384..38024138 100644 --- a/src/clawsharp/Providers/Bedrock/BedrockStreamParser.cs +++ b/src/clawsharp/Providers/Bedrock/BedrockStreamParser.cs @@ -24,7 +24,7 @@ internal static class BedrockStreamParser while (!ct.IsCancellationRequested) { // Read 12-byte prelude: total_length(4) + headers_length(4) + prelude_crc(4) - if (!await ReadExactAsync(stream, prelude, 0, 12, ct)) + if (!await ReadExactAsync(stream, prelude, 0, 12, ct).ConfigureAwait(false)) { yield break; } @@ -43,7 +43,7 @@ internal static class BedrockStreamParser var messageBytes = ArrayPool.Shared.Rent(remaining); try { - if (!await ReadExactAsync(stream, messageBytes, 0, remaining, ct)) + if (!await ReadExactAsync(stream, messageBytes, 0, remaining, ct).ConfigureAwait(false)) { yield break; } @@ -158,7 +158,7 @@ private static async Task ReadExactAsync( var totalRead = 0; while (totalRead < count) { - var read = await stream.ReadAsync(buffer.AsMemory(offset + totalRead, count - totalRead), ct); + var read = await stream.ReadAsync(buffer.AsMemory(offset + totalRead, count - totalRead), ct).ConfigureAwait(false); if (read == 0) { return false; // Stream ended diff --git a/src/clawsharp/Providers/Copilot/CopilotProvider.cs b/src/clawsharp/Providers/Copilot/CopilotProvider.cs index bef0094d..3bcb2bf6 100644 --- a/src/clawsharp/Providers/Copilot/CopilotProvider.cs +++ b/src/clawsharp/Providers/Copilot/CopilotProvider.cs @@ -31,8 +31,8 @@ public sealed class CopilotProvider(IHttpClientFactory httpClientFactory, GitHub public async Task ChatAsync(ChatRequest request, CancellationToken ct = default) { - var inner = await CreateInnerProviderAsync(ct); - return await inner.ChatAsync(request, ct); + var inner = await CreateInnerProviderAsync(ct).ConfigureAwait(false); + return await inner.ChatAsync(request, ct).ConfigureAwait(false); } /// @@ -40,8 +40,8 @@ public async IAsyncEnumerable StreamAsync( ChatRequest request, [EnumeratorCancellation] CancellationToken ct = default) { - var inner = await CreateInnerProviderAsync(ct); - await foreach (var chunk in inner.StreamAsync(request, ct)) + var inner = await CreateInnerProviderAsync(ct).ConfigureAwait(false); + await foreach (var chunk in inner.StreamAsync(request, ct).ConfigureAwait(false)) { yield return chunk; } @@ -59,13 +59,13 @@ public async IAsyncEnumerable StreamAsync( /// private async Task CreateInnerProviderAsync(CancellationToken ct) { - var token = await GetAuthTokenAsync(ct); + var token = await GetAuthTokenAsync(ct).ConfigureAwait(false); return new OpenAiProvider(httpClientFactory, CopilotBaseUrl, token, "copilot"); } private async Task GetAuthTokenAsync(CancellationToken ct) { - var oauthToken = await AuthStore.LoadAsync("copilot", ct); + var oauthToken = await AuthStore.LoadAsync("copilot", ct).ConfigureAwait(false); if (oauthToken is null) { throw new InvalidOperationException( @@ -78,10 +78,10 @@ private async Task GetAuthTokenAsync(CancellationToken ct) } // Token expired -- refresh using stored GitHub OAuth token - await _refreshLock.WaitAsync(ct); + await _refreshLock.WaitAsync(ct).ConfigureAwait(false); try { - oauthToken = await AuthStore.LoadAsync("copilot", ct); + oauthToken = await AuthStore.LoadAsync("copilot", ct).ConfigureAwait(false); if (oauthToken is not null && !oauthToken.IsExpired) { return oauthToken.AccessToken; @@ -94,14 +94,14 @@ private async Task GetAuthTokenAsync(CancellationToken ct) } AnsiConsole.MarkupLine("[yellow][[copilot]][/] Token expired, refreshing..."); - var refreshed = await deviceFlow.RefreshCopilotTokenAsync(oauthToken.RefreshToken, ct); + var refreshed = await deviceFlow.RefreshCopilotTokenAsync(oauthToken.RefreshToken, ct).ConfigureAwait(false); if (refreshed is null) { throw new InvalidOperationException( "Failed to refresh Copilot token. Run: clawsharp auth login-copilot"); } - await AuthStore.SaveAsync("copilot", refreshed, ct); + await AuthStore.SaveAsync("copilot", refreshed, ct).ConfigureAwait(false); return refreshed.AccessToken; } finally diff --git a/src/clawsharp/Providers/Gemini/GeminiJsonContext.cs b/src/clawsharp/Providers/Gemini/GeminiJsonContext.cs index cf5ff7a7..d49b8f42 100644 --- a/src/clawsharp/Providers/Gemini/GeminiJsonContext.cs +++ b/src/clawsharp/Providers/Gemini/GeminiJsonContext.cs @@ -24,4 +24,4 @@ namespace Clawsharp.Providers.Gemini; [JsonSerializable(typeof(List))] [JsonSerializable(typeof(List))] [JsonSourceGenerationOptions(DefaultIgnoreCondition = JsonIgnoreCondition.WhenWritingNull)] -internal partial class GeminiJsonContext : JsonSerializerContext; \ No newline at end of file +internal sealed partial class GeminiJsonContext : JsonSerializerContext; \ No newline at end of file diff --git a/src/clawsharp/Providers/Gemini/GeminiProvider.cs b/src/clawsharp/Providers/Gemini/GeminiProvider.cs index 58843e50..ec8c6507 100644 --- a/src/clawsharp/Providers/Gemini/GeminiProvider.cs +++ b/src/clawsharp/Providers/Gemini/GeminiProvider.cs @@ -87,7 +87,7 @@ public async IAsyncEnumerable StreamAsync(ChatRequest request, [Enu var jsonBytes = JsonSerializer.SerializeToUtf8Bytes(gemReq, GeminiJsonContext.Default.GenerateContentRequest); var (http, resp, body) = await ProviderRequestHandler.SendStreamingAsync( - httpClientFactory, url, jsonBytes, ConfigureHeaders, "Gemini streaming API", ct).ConfigureAwait(false); + httpClientFactory, url, jsonBytes, ConfigureHeaders, "Gemini streaming API", ct); using var _ = http; using var __ = resp; @@ -113,11 +113,10 @@ public async IAsyncEnumerable StreamAsync(ChatRequest request, [Enu } // MED-57: Check for error field in streaming response chunks. + // Throw before emitting a done chunk — the post-loop guard will emit + // StreamDoneChunk when the exception causes the iterator to exit. if (gemResp.Error is { } streamErr) { - // Emit a done chunk before throwing so the stream is properly terminated. - doneEmitted = true; - yield return new StreamDoneChunk(); throw new HttpRequestException( $"Gemini streaming error {streamErr.Code}: {ProviderRequestHandler.SanitizeErrorBody(streamErr.Message)}"); } @@ -168,7 +167,8 @@ public async Task CheckHealthAsync(CancellationToken ct = def try { using var http = httpClientFactory.CreateClient("llm"); - using var req = new HttpRequestMessage(HttpMethod.Get, $"{BaseUrl}?key={apiKey}"); + using var req = new HttpRequestMessage(HttpMethod.Get, BaseUrl); + ConfigureHeaders(req); using var resp = await http.SendAsync(req, HttpCompletionOption.ResponseHeadersRead, ct).ConfigureAwait(false); sw.Stop(); diff --git a/src/clawsharp/Providers/OpenAi/OpenAiJsonContext.cs b/src/clawsharp/Providers/OpenAi/OpenAiJsonContext.cs index b572c692..df3117a1 100644 --- a/src/clawsharp/Providers/OpenAi/OpenAiJsonContext.cs +++ b/src/clawsharp/Providers/OpenAi/OpenAiJsonContext.cs @@ -37,4 +37,4 @@ namespace Clawsharp.Providers.OpenAi; [JsonSerializable(typeof(AudioContentData))] [JsonSerializable(typeof(StreamAudioDelta))] [JsonSourceGenerationOptions(DefaultIgnoreCondition = JsonIgnoreCondition.WhenWritingNull)] -internal partial class OpenAiJsonContext : JsonSerializerContext; \ No newline at end of file +internal sealed partial class OpenAiJsonContext : JsonSerializerContext; \ No newline at end of file diff --git a/src/clawsharp/Providers/OpenAi/OpenAiProvider.cs b/src/clawsharp/Providers/OpenAi/OpenAiProvider.cs index a2b5d4c9..807652ea 100644 --- a/src/clawsharp/Providers/OpenAi/OpenAiProvider.cs +++ b/src/clawsharp/Providers/OpenAi/OpenAiProvider.cs @@ -84,7 +84,7 @@ public async IAsyncEnumerable StreamAsync( var toolTagFilter = TagStripFilter.CreateStreamingFilter(); var (http, resp, body) = await ProviderRequestHandler.SendStreamingAsync( - httpClientFactory, url, jsonBytes, ConfigureHeaders, "OpenAI API stream", ct).ConfigureAwait(false); + httpClientFactory, url, jsonBytes, ConfigureHeaders, "OpenAI API stream", ct); using var _ = http; using var __ = resp; diff --git a/src/clawsharp/Providers/OpenRouter/OpenRouterProvider.cs b/src/clawsharp/Providers/OpenRouter/OpenRouterProvider.cs index 48ec1e85..a2dd6386 100644 --- a/src/clawsharp/Providers/OpenRouter/OpenRouterProvider.cs +++ b/src/clawsharp/Providers/OpenRouter/OpenRouterProvider.cs @@ -110,7 +110,7 @@ public async IAsyncEnumerable StreamAsync( var toolTagFilter = TagStripFilter.CreateStreamingFilter(); var (http, resp, body) = await ProviderRequestHandler.SendStreamingAsync( - httpClientFactory, ChatCompletionsUrl, jsonBytes, ConfigureHeaders, "OpenRouter API stream", ct).ConfigureAwait(false); + httpClientFactory, ChatCompletionsUrl, jsonBytes, ConfigureHeaders, "OpenRouter API stream", ct); using var _ = http; using var __ = resp; diff --git a/src/clawsharp/Providers/ProviderRequestHandler.cs b/src/clawsharp/Providers/ProviderRequestHandler.cs index 7d92eb10..11464742 100644 --- a/src/clawsharp/Providers/ProviderRequestHandler.cs +++ b/src/clawsharp/Providers/ProviderRequestHandler.cs @@ -202,12 +202,14 @@ internal static string SanitizeErrorBody(string raw) /// /// OpenAI-style keys: sk-[A-Za-z0-9]{20,} /// Anthropic-style keys: sk-ant-[A-Za-z0-9\-]{20,} + /// Gemini API keys: AIzaSy[A-Za-z0-9\-_]{33} (39 chars total) + /// AWS access key IDs: AKIA[A-Z0-9]{16} (20 chars total) /// Bearer tokens in echoed text: Bearer [^\s"]{20,} /// Generic long hex strings (40+ chars, likely keys): [0-9a-fA-F]{40,} /// /// [GeneratedRegex( - @"sk-ant-[A-Za-z0-9\-]{20,}|sk-[A-Za-z0-9]{20,}|key-[A-Za-z0-9]{20,}|Bearer\s+[^\s""]{20,}|[0-9a-fA-F]{40,}", + @"sk-ant-[A-Za-z0-9\-]{20,}|sk-[A-Za-z0-9]{20,}|key-[A-Za-z0-9]{20,}|AIzaSy[A-Za-z0-9\-_]{33}|AKIA[A-Z0-9]{16}|Bearer\s+[^\s""]{20,}|[0-9a-fA-F]{40,}", RegexOptions.None, 200)] private static partial Regex SecretPatternRegex(); } \ No newline at end of file diff --git a/src/clawsharp/Providers/TagStripFilter.cs b/src/clawsharp/Providers/TagStripFilter.cs index 6afbb13d..891f7c52 100644 --- a/src/clawsharp/Providers/TagStripFilter.cs +++ b/src/clawsharp/Providers/TagStripFilter.cs @@ -196,12 +196,11 @@ private void ProcessNormal(char ch, StringBuilder output) private void ProcessMaybeOpenTag(char ch, StringBuilder output) { _tagBuffer.Append(ch); - var buffered = _tagBuffer.ToString(); // Check if buffer fully matches an opening tag for (var i = 0; i < _openTags.Length; i++) { - if (string.Equals(buffered, _openTags[i], StringComparison.Ordinal)) + if (_tagBuffer.Equals(_openTags[i].AsSpan())) { // Full match -- enter the tag block _tagBuffer.Clear(); @@ -214,17 +213,35 @@ private void ProcessMaybeOpenTag(char ch, StringBuilder output) // Check if buffer is still a valid prefix of any opening tag for (var i = 0; i < _openTags.Length; i++) { - if (_openTags[i].StartsWith(buffered, StringComparison.Ordinal)) + if (IsPrefix(_tagBuffer, _openTags[i])) { // Still a valid prefix -- keep buffering return; } } - // Not a valid prefix of any tag -- flush buffer as normal text - output.Append(buffered); - _tagBuffer.Clear(); - _state = State.Normal; + // Not a valid prefix of any tag -- flush buffer as normal text. + // If the character that broke the prefix is '<', it could be the start + // of a new tag. Flush everything before it and re-enter MaybeOpenTag. + if (ch == '<') + { + // The buffer already has '<' appended (from line above), so flush + // everything except the trailing '<' and start a new tag match. + for (var j = 0; j < _tagBuffer.Length - 1; j++) + { + output.Append(_tagBuffer[j]); + } + + _tagBuffer.Clear(); + _tagBuffer.Append('<'); + _state = State.MaybeOpenTag; + } + else + { + output.Append(_tagBuffer); + _tagBuffer.Clear(); + _state = State.Normal; + } } private void ProcessInsideBlock(char ch) @@ -242,10 +259,9 @@ private void ProcessInsideBlock(char ch) private void ProcessMaybeCloseTag(char ch) { _tagBuffer.Append(ch); - var buffered = _tagBuffer.ToString(); var closeTag = _closeTags[_matchedTagIndex]; - if (string.Equals(buffered, closeTag, StringComparison.Ordinal)) + if (_tagBuffer.Equals(closeTag.AsSpan())) { // Full match on closing tag -- exit the block _tagBuffer.Clear(); @@ -254,7 +270,7 @@ private void ProcessMaybeCloseTag(char ch) return; } - if (closeTag.StartsWith(buffered, StringComparison.Ordinal)) + if (IsPrefix(_tagBuffer, closeTag)) { // Still a valid prefix of the closing tag -- keep buffering return; @@ -264,4 +280,20 @@ private void ProcessMaybeCloseTag(char ch) _tagBuffer.Clear(); _state = State.InsideBlock; } + + /// + /// Returns true if is a proper prefix of + /// (i.e. buffer is shorter than tag and all buffer characters match the corresponding tag characters). + /// Zero-allocation alternative to tag.StartsWith(buffer.ToString()). + /// + private static bool IsPrefix(StringBuilder buffer, string tag) + { + if (buffer.Length >= tag.Length) return false; + for (var i = 0; i < buffer.Length; i++) + { + if (buffer[i] != tag[i]) return false; + } + + return true; + } } \ No newline at end of file diff --git a/src/clawsharp/Security/AuditLogger.cs b/src/clawsharp/Security/AuditLogger.cs index 0494ddcc..6b934adc 100644 --- a/src/clawsharp/Security/AuditLogger.cs +++ b/src/clawsharp/Security/AuditLogger.cs @@ -2,6 +2,7 @@ using System.Text.Json; using Clawsharp.Config; using Clawsharp.Core.Events; +using Clawsharp.Core.Utilities; using Microsoft.Extensions.Logging; using Microsoft.Extensions.Options; using Clawsharp.Config.Security; @@ -16,6 +17,8 @@ public sealed partial class AuditLogger : IDisposable { private readonly SemaphoreSlim _lock = new(1, 1); + private FileStream? _stream; + private readonly AuditConfig _config; private readonly ILogger _logger; @@ -33,7 +36,7 @@ public AuditLogger(IOptions options, ILogger logger, IEv var dir = Path.Combine( Environment.GetFolderPath(Environment.SpecialFolder.UserProfile), ".clawsharp"); - Directory.CreateDirectory(dir); + FilePermissions.EnsureRestrictedDirectory(dir); var defaultLogPath = Path.Combine(dir, "audit.log"); @@ -106,7 +109,7 @@ public async Task LogAsync(AuditEvent evt, CancellationToken ct = default) try { await RotateIfNeededAsync().ConfigureAwait(false); - await using var fs = new FileStream(_logPath, FileMode.Append, FileAccess.Write, FileShare.Read); + var fs = EnsureStreamOpen(); await fs.WriteAsync(jsonBytes, ct).ConfigureAwait(false); fs.WriteByte((byte)'\n'); await fs.FlushAsync(ct).ConfigureAwait(false); @@ -249,6 +252,19 @@ public Task LogFileAccessAsync( Result = new AuditResult { Success = success, Error = error }, }, ct); + private FileStream EnsureStreamOpen() + { + if (_stream is { CanWrite: true }) + { + return _stream; + } + + _stream = new FileStream( + _logPath, FileMode.Append, FileAccess.Write, FileShare.Read, + 4096, FileOptions.Asynchronous); + return _stream; + } + private async Task RotateIfNeededAsync() { if (!File.Exists(_logPath)) @@ -256,13 +272,22 @@ private async Task RotateIfNeededAsync() return; } - var info = new FileInfo(_logPath); + long currentLength = _stream is { CanWrite: true } + ? _stream.Length + : new FileInfo(_logPath).Length; long maxBytes = (long)_config.MaxSizeMb * 1024 * 1024; - if (info.Length < maxBytes) + if (currentLength < maxBytes) { return; } + // Close the held stream before rotation so the file handle is released. + if (_stream is not null) + { + await _stream.DisposeAsync().ConfigureAwait(false); + _stream = null; + } + // Rename audit.log.9 -> delete, .8 -> .9, ..., .1 -> .2, audit.log -> .1 for (var i = 9; i >= 1; i--) { @@ -309,7 +334,11 @@ private void PruneOldLogs() } } - public void Dispose() => _lock.Dispose(); + public void Dispose() + { + _stream?.Dispose(); + _lock.Dispose(); + } [LoggerMessage(EventId = 1, Level = LogLevel.Warning, Message = "Failed to write audit event")] private static partial void LogWriteAuditEventFailed(ILogger logger, Exception exception); diff --git a/src/clawsharp/Security/LeakDetector.cs b/src/clawsharp/Security/LeakDetector.cs index b11d8feb..d2036e84 100644 --- a/src/clawsharp/Security/LeakDetector.cs +++ b/src/clawsharp/Security/LeakDetector.cs @@ -63,6 +63,15 @@ public static LeakScanResult Scan(string content, double sensitivity = 0.7) [GeneratedRegex("""api[_\-]?key[=:]\s*['""]*[a-zA-Z0-9_\-]{20,}""", RegexOptions.IgnoreCase | RegexOptions.CultureInvariant, 200)] private static partial Regex GenericApiKeyRegex(); + [GeneratedRegex(@"xox[bpears]-[a-zA-Z0-9\-]{10,}", RegexOptions.CultureInvariant, 200)] + private static partial Regex SlackTokenRegex(); + + [GeneratedRegex(@"AccountKey=[a-zA-Z0-9+/=]{44,}", RegexOptions.CultureInvariant, 200)] + private static partial Regex AzureStorageKeyRegex(); + + [GeneratedRegex(@"\d{8,10}:[a-zA-Z0-9_\-]{35}", RegexOptions.CultureInvariant, 200)] + private static partial Regex TelegramBotTokenRegex(); + [GeneratedRegex(@"AKIA[A-Z0-9]{16}", RegexOptions.CultureInvariant, 200)] private static partial Regex AwsAccessKeyRegex(); @@ -107,6 +116,9 @@ private static void CheckApiKeys(ref string redacted, string original, List public static string GenerateSvg(string url) { - using var qrGenerator = new QRCoder.QRCodeGenerator(); - using var qrCodeData = qrGenerator.CreateQrCode(url, QRCoder.QRCodeGenerator.ECCLevel.M); + using var qrGenerator = new QRCodeGenerator(); + using var qrCodeData = qrGenerator.CreateQrCode(url, QRCodeGenerator.ECCLevel.M); var svgQrCode = new SvgQRCode(qrCodeData); return svgQrCode.GetGraphic(4); } diff --git a/src/clawsharp/Security/SecretStore.cs b/src/clawsharp/Security/SecretStore.cs index 124076da..07908dac 100644 --- a/src/clawsharp/Security/SecretStore.cs +++ b/src/clawsharp/Security/SecretStore.cs @@ -229,7 +229,15 @@ private static bool TryLoadFromFile(string keyPath, out byte[] key) } var hex = File.ReadAllText(keyPath).Trim(); - key = Convert.FromHexString(hex); + try + { + key = Convert.FromHexString(hex); + } + catch + { + throw new CryptographicException($"Secret key file at '{keyPath}' contains invalid hex data."); + } + if (key.Length != KeyLen) { throw new CryptographicException($"Secret key file at '{keyPath}' is invalid (expected {KeyLen * 2} hex chars)."); diff --git a/src/clawsharp/Security/ShellGuard.cs b/src/clawsharp/Security/ShellGuard.cs index 65b64c90..d3ed8f8d 100644 --- a/src/clawsharp/Security/ShellGuard.cs +++ b/src/clawsharp/Security/ShellGuard.cs @@ -214,7 +214,7 @@ private static string NormalizeCommand(string command) // Block on timeout — an attacker could craft ReDoS input to disable a deny rule. return $"Command blocked: custom deny pattern timed out (potential ReDoS): {pattern}"; } - catch + catch (ArgumentException) { // Invalid custom regex — skip silently } @@ -271,9 +271,9 @@ private static string NormalizeCommand(string command) return null; } } - catch + catch (RegexMatchTimeoutException) { - /* Timeout */ + // Timeout on auto-approve pattern — skip it (fail-closed: require approval) } } } @@ -288,9 +288,13 @@ private static string NormalizeCommand(string command) return null; } } - catch + catch (RegexMatchTimeoutException) + { + // Timeout on auto-approve pattern — skip it (fail-closed: require approval) + } + catch (ArgumentException) { - /* Invalid regex */ + // Invalid regex pattern — skip silently } } } @@ -323,9 +327,10 @@ private static string NormalizeCommand(string command) return regex.ToString(); } } - catch + catch (RegexMatchTimeoutException) { - /* Timeout */ + // Timeout on approval pattern — fail-closed: require approval + return regex.ToString(); } } } @@ -340,9 +345,14 @@ private static string NormalizeCommand(string command) return pattern; } } - catch + catch (RegexMatchTimeoutException) + { + // Timeout on approval pattern — fail-closed: require approval + return pattern; + } + catch (ArgumentException) { - /* Invalid regex */ + // Invalid regex pattern — skip silently } } } @@ -603,7 +613,7 @@ public static void SanitizeEnvironment(System.Diagnostics.ProcessStartInfo psi) [GeneratedRegex(@"\bsudo\b", RegexOptions.IgnoreCase, 200)] private static partial Regex DenySudo(); - [GeneratedRegex(@"\bchmod\s+[0-7]{3,4}\b", RegexOptions.IgnoreCase, 200)] + [GeneratedRegex(@"\bchmod\s+([0-7]{3,4}\b|[ugoaUGOA]*[+\-=][rwxXst]+)", RegexOptions.None, 200)] private static partial Regex DenyChmod(); [GeneratedRegex(@"\bchown\b", RegexOptions.IgnoreCase, 200)] @@ -666,7 +676,7 @@ public static void SanitizeEnvironment(System.Diagnostics.ProcessStartInfo psi) [GeneratedRegex(@"[<>]\([^)]*\)", RegexOptions.None, 200)] private static partial Regex DenyProcessSubstitution(); - [GeneratedRegex(@"\bln\b", RegexOptions.IgnoreCase, 200)] + [GeneratedRegex(@"\bln\s+", RegexOptions.IgnoreCase, 200)] private static partial Regex DenyLn(); [GeneratedRegex(@"\bmkfifo\b", RegexOptions.IgnoreCase, 200)] diff --git a/src/clawsharp/Security/SsrfGuard.cs b/src/clawsharp/Security/SsrfGuard.cs index f4841c9a..d02f3253 100644 --- a/src/clawsharp/Security/SsrfGuard.cs +++ b/src/clawsharp/Security/SsrfGuard.cs @@ -21,6 +21,20 @@ public static class SsrfGuard /// Configures the egress policy at startup. Call once with the egress section /// from . Pass null to keep the default open policy. /// + /// + /// + /// Threading contract: This method must be called exactly once during application + /// startup, before any async work begins (i.e., before hosted services start). The static + /// _egressConfig field uses a volatile write to ensure visibility across + /// threads, but no further synchronization is provided. Calling this method after + /// or have begun executing + /// on other threads may result in inconsistent policy enforcement. + /// + /// + /// The typical call site is GatewayHost.BuildHost(), which runs synchronously + /// before the generic host starts its hosted services. + /// + /// public static void Configure(EgressConfig? config) { _egressConfig = config; diff --git a/src/clawsharp/Security/WebPairingGuard.cs b/src/clawsharp/Security/WebPairingGuard.cs index 842fd17e..d382b5ed 100644 --- a/src/clawsharp/Security/WebPairingGuard.cs +++ b/src/clawsharp/Security/WebPairingGuard.cs @@ -17,6 +17,13 @@ internal sealed partial class WebPairingGuard { private const int MaxFailedAttempts = 5; + /// + /// Maximum number of global failed pairing attempts across all IPs before the current + /// pairing code is invalidated. Prevents distributed brute-force attacks where a botnet + /// makes N x attempts from different IPs before code expiry. + /// + private const int MaxGlobalAttempts = 50; + private const int MaxFailureTrackingEntries = 10_000; private static readonly TimeSpan LockoutDuration = TimeSpan.FromMinutes(5); @@ -31,6 +38,9 @@ internal sealed partial class WebPairingGuard private readonly string _persistPath; + /// Global failed attempt counter for the current pairing code (across all IPs). + private int _globalAttempts; + private string? _pairingCode; public WebPairingGuard(string persistPath, ILogger logger) @@ -111,10 +121,21 @@ public bool IsAuthenticated(string token) Encoding.UTF8.GetBytes(_pairingCode))) { RecordFailure(clientIp); + + // Global attempt counter: invalidate the pairing code after too many + // failed attempts across all IPs to defeat distributed brute-force. + _globalAttempts++; + if (_globalAttempts >= MaxGlobalAttempts && _pairingCode is not null) + { + LogPairingCodeInvalidated(_logger, MaxGlobalAttempts); + _pairingCode = null; + } + return null; } _pairingCode = null; // one-time use — consumed + _globalAttempts = 0; // reset on successful pairing token = GenerateToken(); _hashes.Add(HashToken(token)); } @@ -129,6 +150,7 @@ public string RegenerateCode() lock (_lock) { _pairingCode = NewCode(); + _globalAttempts = 0; return _pairingCode; } } @@ -267,7 +289,13 @@ private void SaveToDisk() [LoggerMessage(EventId = 2, Level = LogLevel.Warning, Message = "Corrupt or unreadable token file at '{FilePath}', starting fresh: {Reason}")] private static partial void LogCorruptTokenFile(ILogger logger, string filePath, string reason); + + [LoggerMessage(EventId = 3, Level = LogLevel.Warning, + Message = "Pairing code invalidated after {AttemptCount} global failed attempts (possible distributed brute-force). " + + "Use 'regenerateCode' to generate a new pairing code.")] + private static partial void LogPairingCodeInvalidated(ILogger logger, int attemptCount); } [JsonSerializable(typeof(List))] +[JsonSourceGenerationOptions(DefaultIgnoreCondition = JsonIgnoreCondition.WhenWritingNull)] internal sealed partial class WebPairingGuardJsonContext : JsonSerializerContext; \ No newline at end of file diff --git a/src/clawsharp/Telemetry/ClawsharpMetrics.cs b/src/clawsharp/Telemetry/ClawsharpMetrics.cs index a18db4c6..ad6e138c 100644 --- a/src/clawsharp/Telemetry/ClawsharpMetrics.cs +++ b/src/clawsharp/Telemetry/ClawsharpMetrics.cs @@ -19,6 +19,16 @@ public struct GenAiMetricTags public string TokenType { get; set; } } +/// Tag structure for GenAI operation duration (no token.type dimension per OTel semconv). +public struct DurationMetricTags +{ + [TagName("gen_ai.operation.name")] + public string OperationName { get; set; } + + [TagName("gen_ai.request.model")] + public string Model { get; set; } +} + /// Tag structure for pipeline-level metrics. public struct PipelineMetricTags { @@ -93,7 +103,7 @@ public static partial class ClawsharpMetrics /// Time-to-first-token histogram (reserved for Plan 02). public static readonly TtftHistogram Ttft = CreateTtftHistogram(GenAiMeter); - /// Tokens-per-output-token histogram (reserved for Plan 02). + /// Time-per-output-token histogram (reserved for Plan 02). public static readonly TpotHistogram Tpot = CreateTpotHistogram(GenAiMeter); // ── Active session gauge ───────────────────────────────────────────── @@ -118,7 +128,7 @@ public static void InitializeSessionGauge(Func sessionCountProvider) [Histogram(typeof(GenAiMetricTags), Name = "gen_ai.client.token.usage", Unit = "{token}")] public static partial TokenUsage CreateTokenUsage(Meter meter); - [Histogram(typeof(GenAiMetricTags), Name = "gen_ai.client.operation.duration", Unit = "s")] + [Histogram(typeof(DurationMetricTags), Name = "gen_ai.client.operation.duration", Unit = "s")] public static partial OperationDuration CreateOperationDuration(Meter meter); // ── Pipeline metrics ──────────────────────────────────────────────── @@ -145,7 +155,7 @@ public static void InitializeSessionGauge(Func sessionCountProvider) [Histogram(typeof(StreamingMetricTags), Name = "gen_ai.client.time_to_first_token", Unit = "s")] public static partial TtftHistogram CreateTtftHistogram(Meter meter); - [Histogram(typeof(StreamingMetricTags), Name = "gen_ai.client.tokens_per_output_token", Unit = "s")] + [Histogram(typeof(StreamingMetricTags), Name = "gen_ai.client.time_per_output_token", Unit = "s")] public static partial TpotHistogram CreateTpotHistogram(Meter meter); // ── ObservableGauge (MET-05: active session count) ─────────────────── diff --git a/src/clawsharp/Telemetry/TelemetryExtensions.cs b/src/clawsharp/Telemetry/TelemetryExtensions.cs index 02923658..50f46a73 100644 --- a/src/clawsharp/Telemetry/TelemetryExtensions.cs +++ b/src/clawsharp/Telemetry/TelemetryExtensions.cs @@ -1,4 +1,3 @@ -using System.Reflection; using Clawsharp.Config; using Microsoft.Extensions.DependencyInjection; using Microsoft.Extensions.Hosting; @@ -34,13 +33,9 @@ internal static IHostBuilder AddClawsharpTelemetry(this IHostBuilder builder, Te var otel = services.AddOpenTelemetry() .ConfigureResource(r => { - var version = typeof(TelemetryExtensions).Assembly - .GetCustomAttribute() - ?.InformationalVersion ?? "0.0.0"; - r.AddService( serviceName: config.ServiceName ?? "clawsharp", - serviceVersion: version); + serviceVersion: TelemetryConstants.Version); if (config.Environment is not null) { diff --git a/src/clawsharp/Tools/Browser/BrowserSession.cs b/src/clawsharp/Tools/Browser/BrowserSession.cs index be4f4420..ff573f21 100644 --- a/src/clawsharp/Tools/Browser/BrowserSession.cs +++ b/src/clawsharp/Tools/Browser/BrowserSession.cs @@ -119,7 +119,7 @@ public async ValueTask DisposeAsync() { try { - await _page.CloseAsync(); + await _page.CloseAsync().ConfigureAwait(false); } catch { @@ -131,7 +131,7 @@ public async ValueTask DisposeAsync() { try { - await _context.CloseAsync(); + await _context.CloseAsync().ConfigureAwait(false); } catch { @@ -143,7 +143,7 @@ public async ValueTask DisposeAsync() { try { - await _browser.CloseAsync(); + await _browser.CloseAsync().ConfigureAwait(false); } catch { diff --git a/src/clawsharp/Tools/Browser/BrowserTool.cs b/src/clawsharp/Tools/Browser/BrowserTool.cs index 5211e541..5339201c 100644 --- a/src/clawsharp/Tools/Browser/BrowserTool.cs +++ b/src/clawsharp/Tools/Browser/BrowserTool.cs @@ -150,14 +150,14 @@ public override async Task ExecuteAsync(JsonElement arguments, Cancellat { return action switch { - "navigate" => await NavigateAsync(arguments, sessionId, ct), - "snapshot" => await SnapshotAsync(arguments, sessionId, ct), - "click" => await ClickAsync(arguments, sessionId, ct), - "type" => await TypeAsync(arguments, sessionId, ct), - "select" => await SelectAsync(arguments, sessionId, ct), - "screenshot" => await ScreenshotAsync(arguments, sessionId, ct), - "evaluate" => await EvaluateAsync(arguments, sessionId, ct), - "close" => await CloseAsync(sessionId), + "navigate" => await NavigateAsync(arguments, sessionId, ct).ConfigureAwait(false), + "snapshot" => await SnapshotAsync(arguments, sessionId, ct).ConfigureAwait(false), + "click" => await ClickAsync(arguments, sessionId, ct).ConfigureAwait(false), + "type" => await TypeAsync(arguments, sessionId, ct).ConfigureAwait(false), + "select" => await SelectAsync(arguments, sessionId, ct).ConfigureAwait(false), + "screenshot" => await ScreenshotAsync(arguments, sessionId, ct).ConfigureAwait(false), + "evaluate" => await EvaluateAsync(arguments, sessionId, ct).ConfigureAwait(false), + "close" => await CloseAsync(sessionId).ConfigureAwait(false), _ => $"Error: unknown browser action '{action}'. " + "Valid actions: navigate, snapshot, click, type, select, screenshot, evaluate, close.", }; diff --git a/src/clawsharp/Tools/Browser/PinchTabTool.cs b/src/clawsharp/Tools/Browser/PinchTabTool.cs index 27cdfc60..cebc91d0 100644 --- a/src/clawsharp/Tools/Browser/PinchTabTool.cs +++ b/src/clawsharp/Tools/Browser/PinchTabTool.cs @@ -88,19 +88,19 @@ public override async Task ExecuteAsync(JsonElement arguments, Cancellat { return action switch { - "navigate" => await NavigateAsync(arguments, sessionId, ct), - "snapshot" => await SnapshotAsync(arguments, sessionId, ct), - "text" => await TextAsync(sessionId, ct), - "click" => await ActionAsync("click", arguments, sessionId, ct), - "fill" => await ActionAsync("fill", arguments, sessionId, ct), - "type" => await ActionAsync("type", arguments, sessionId, ct), - "press" => await ActionAsync("press", arguments, sessionId, ct), - "select" => await ActionAsync("select", arguments, sessionId, ct), - "scroll" => await ActionAsync("scroll", arguments, sessionId, ct), - "screenshot" => await ScreenshotAsync(sessionId, ct), - "evaluate" => await EvaluateAsync(arguments, sessionId, ct), - "tabs" => await TabsAsync(ct), - "close" => await CloseAsync(sessionId, ct), + "navigate" => await NavigateAsync(arguments, sessionId, ct).ConfigureAwait(false), + "snapshot" => await SnapshotAsync(arguments, sessionId, ct).ConfigureAwait(false), + "text" => await TextAsync(sessionId, ct).ConfigureAwait(false), + "click" => await ActionAsync("click", arguments, sessionId, ct).ConfigureAwait(false), + "fill" => await ActionAsync("fill", arguments, sessionId, ct).ConfigureAwait(false), + "type" => await ActionAsync("type", arguments, sessionId, ct).ConfigureAwait(false), + "press" => await ActionAsync("press", arguments, sessionId, ct).ConfigureAwait(false), + "select" => await ActionAsync("select", arguments, sessionId, ct).ConfigureAwait(false), + "scroll" => await ActionAsync("scroll", arguments, sessionId, ct).ConfigureAwait(false), + "screenshot" => await ScreenshotAsync(sessionId, ct).ConfigureAwait(false), + "evaluate" => await EvaluateAsync(arguments, sessionId, ct).ConfigureAwait(false), + "tabs" => await TabsAsync(ct).ConfigureAwait(false), + "close" => await CloseAsync(sessionId, ct).ConfigureAwait(false), _ => $"Error: unknown pinchtab action '{action}'. " + "Valid actions: navigate, snapshot, text, click, fill, type, press, select, scroll, screenshot, evaluate, tabs, close.", }; diff --git a/src/clawsharp/Tools/Browser/ScreenshotTool.cs b/src/clawsharp/Tools/Browser/ScreenshotTool.cs index a698ba71..9aede823 100644 --- a/src/clawsharp/Tools/Browser/ScreenshotTool.cs +++ b/src/clawsharp/Tools/Browser/ScreenshotTool.cs @@ -82,11 +82,11 @@ public override async Task ExecuteAsync(JsonElement args, CancellationTo using var proc = Process.Start(psi) ?? throw new InvalidOperationException("Could not start capture process."); - await proc.WaitForExitAsync(cts.Token); + await proc.WaitForExitAsync(cts.Token).ConfigureAwait(false); if (proc.ExitCode != 0) { - var err = await proc.StandardError.ReadToEndAsync(cts.Token); + var err = await proc.StandardError.ReadToEndAsync(cts.Token).ConfigureAwait(false); return $"Error: capture failed (exit {proc.ExitCode}): {err.Trim()}"; } diff --git a/src/clawsharp/Tools/Files/FileEditTool.cs b/src/clawsharp/Tools/Files/FileEditTool.cs index e5c8cee6..cbd3de31 100644 --- a/src/clawsharp/Tools/Files/FileEditTool.cs +++ b/src/clawsharp/Tools/Files/FileEditTool.cs @@ -1,3 +1,4 @@ +using System.Text; using System.Text.Json; using Clawsharp.Security; @@ -74,7 +75,10 @@ public override async Task ExecuteAsync(JsonElement arguments, Cancellat return "Error: path is outside the workspace."; } - var content = await File.ReadAllTextAsync(fullPath, ct); + // CRIT-02: Open the file handle, then verify the actual path via /proc/self/fd/ + // to close the TOCTOU race window between VerifyNotSymlinkEscape and file I/O. + // On non-Linux, the VerifyNotSymlinkEscape check above is the best we can do. + var content = await ReadVerifiedAsync(fullPath, ct).ConfigureAwait(false); var idx = content.IndexOf(oldText, StringComparison.Ordinal); if (idx < 0) @@ -95,7 +99,10 @@ public override async Task ExecuteAsync(JsonElement arguments, Cancellat updated = string.Concat(content.AsSpan(0, idx), newText, content.AsSpan(idx + oldText.Length)); } - await File.WriteAllTextAsync(fullPath, updated, ct); + await using var writeFs = new FileStream(fullPath, FileMode.Create, FileAccess.Write, FileShare.None); + PathGuard.VerifyFileDescriptorPath(writeFs, _workspace); + await using var writer = new StreamWriter(writeFs, Encoding.UTF8); + await writer.WriteAsync(updated.AsMemory(), ct).ConfigureAwait(false); if (auditLogger is not null) { @@ -105,6 +112,14 @@ public override async Task ExecuteAsync(JsonElement arguments, Cancellat return $"Replaced {count} occurrence(s) in {rel}"; } + private async Task ReadVerifiedAsync(string fullPath, CancellationToken ct) + { + await using var fs = new FileStream(fullPath, FileMode.Open, FileAccess.Read, FileShare.Read); + PathGuard.VerifyFileDescriptorPath(fs, _workspace); + using var reader = new StreamReader(fs, Encoding.UTF8); + return await reader.ReadToEndAsync(ct).ConfigureAwait(false); + } + private static int CountOccurrences(string text, string pattern) { var count = 0; diff --git a/src/clawsharp/Tools/Files/FileReadTool.cs b/src/clawsharp/Tools/Files/FileReadTool.cs index 8712b646..5f99ac01 100644 --- a/src/clawsharp/Tools/Files/FileReadTool.cs +++ b/src/clawsharp/Tools/Files/FileReadTool.cs @@ -81,7 +81,7 @@ public override async Task ExecuteAsync(JsonElement arguments, Cancellat return "Error: path is outside the workspace."; } - var content = await File.ReadAllTextAsync(path, ct); + var content = await File.ReadAllTextAsync(path, ct).ConfigureAwait(false); if (content.Length > maxChars) { content = content[..maxChars] + $"\n... (truncated at {maxChars:N0} chars, file has {content.Length:N0} total)"; diff --git a/src/clawsharp/Tools/Files/FileSearchTool.cs b/src/clawsharp/Tools/Files/FileSearchTool.cs index 3a835d44..54efd880 100644 --- a/src/clawsharp/Tools/Files/FileSearchTool.cs +++ b/src/clawsharp/Tools/Files/FileSearchTool.cs @@ -84,7 +84,7 @@ public override async Task ExecuteAsync(JsonElement arguments, Cancellat continue; } - var lines = await File.ReadAllLinesAsync(file, ct); + var lines = await File.ReadAllLinesAsync(file, ct).ConfigureAwait(false); for (var i = 0; i < lines.Length; i++) { if (lines[i].Contains(pattern, StringComparison.OrdinalIgnoreCase)) diff --git a/src/clawsharp/Tools/Files/FileWriteTool.cs b/src/clawsharp/Tools/Files/FileWriteTool.cs index e0a2c54d..b9d8a0ee 100644 --- a/src/clawsharp/Tools/Files/FileWriteTool.cs +++ b/src/clawsharp/Tools/Files/FileWriteTool.cs @@ -81,14 +81,14 @@ public override async Task ExecuteAsync(JsonElement arguments, Cancellat await using var fs = new FileStream(fullPath, FileMode.Append, FileAccess.Write, FileShare.None); PathGuard.VerifyFileDescriptorPath(fs, _workspace); await using var writer = new StreamWriter(fs); - await writer.WriteAsync(content.AsMemory(), ct); + await writer.WriteAsync(content.AsMemory(), ct).ConfigureAwait(false); } else { await using var fs = new FileStream(fullPath, FileMode.Create, FileAccess.Write, FileShare.None); PathGuard.VerifyFileDescriptorPath(fs, _workspace); await using var writer = new StreamWriter(fs); - await writer.WriteAsync(content.AsMemory(), ct); + await writer.WriteAsync(content.AsMemory(), ct).ConfigureAwait(false); } if (auditLogger is not null) diff --git a/src/clawsharp/Tools/IToolRegistry.cs b/src/clawsharp/Tools/IToolRegistry.cs index 815a7a11..70fda84e 100644 --- a/src/clawsharp/Tools/IToolRegistry.cs +++ b/src/clawsharp/Tools/IToolRegistry.cs @@ -11,6 +11,9 @@ public interface IToolRegistry /// Registers a tool dynamically (e.g. from an MCP server). void Register(Tool tool); + /// Removes a previously registered tool by name. Returns true if the tool was found and removed. + bool Unregister(string toolName); + void SetChannelContext(ChannelName channelName, int spawnDepth = 0, string? sessionId = null, OrgUser? orgUser = null, PolicyDecision? policyDecision = null); /// Sets the spawn permission scope for audit trail tracking in sub-agent flows. diff --git a/src/clawsharp/Tools/Knowledge/KnowledgeSearchTool.cs b/src/clawsharp/Tools/Knowledge/KnowledgeSearchTool.cs index fc5d78dc..f9689ea3 100644 --- a/src/clawsharp/Tools/Knowledge/KnowledgeSearchTool.cs +++ b/src/clawsharp/Tools/Knowledge/KnowledgeSearchTool.cs @@ -1,3 +1,4 @@ +using System.Diagnostics; using System.Text; using System.Text.Json; using Clawsharp.Config.Organization; @@ -6,6 +7,7 @@ using Clawsharp.Memory; using Clawsharp.Memory.Entities; using Clawsharp.Organization; +using Clawsharp.Telemetry; using Microsoft.Extensions.Logging; namespace Clawsharp.Tools.Knowledge; @@ -101,7 +103,7 @@ public override async Task ExecuteAsync(JsonElement arguments, Cancellat HashSet? validSourceIds = null; if (sourcesFilter is { Length: > 0 }) { - var allSources = await store.ListSourcesAsync(ct); + var allSources = await store.ListSourcesAsync(ct).ConfigureAwait(false); var sourceMap = allSources.ToDictionary( s => s.SourceTitle, s => s.Id, StringComparer.OrdinalIgnoreCase); var invalidNames = sourcesFilter.Where(n => !sourceMap.ContainsKey(n)).ToList(); @@ -118,19 +120,27 @@ public override async Task ExecuteAsync(JsonElement arguments, Cancellat float[]? queryEmbedding = null; if (embeddingProvider is not null) { + using var embedSpan = ClawsharpActivitySources.Knowledge.StartActivity("knowledge.search.embed"); try { - queryEmbedding = await embeddingProvider.EmbedAsync(query, ct); + queryEmbedding = await embeddingProvider.EmbedAsync(query, ct).ConfigureAwait(false); } catch (Exception ex) { + embedSpan?.SetStatus(ActivityStatusCode.Error, ex.Message); LogEmbeddingFallback(logger, ex); } } // Step 5: Search with over-retrieval (D-27, D-38) var candidateCount = retrievalConfig.CandidateMultiplier * topK; - var results = await store.SearchAsync(queryEmbedding, query, acl, candidateCount, ct); + IReadOnlyList results; + using (var searchSpan = ClawsharpActivitySources.Knowledge.StartActivity("knowledge.search.query")) + { + results = await store.SearchAsync(queryEmbedding, query, acl, candidateCount, ct).ConfigureAwait(false); + searchSpan?.SetTag("knowledge.search.candidate_count", candidateCount); + searchSpan?.SetTag("knowledge.search.result_count", results.Count); + } // Step 6: Post-filter by sources (D-03) if (validSourceIds is not null) @@ -144,7 +154,13 @@ public override async Task ExecuteAsync(JsonElement arguments, Cancellat return "No relevant knowledge found."; } - var reranked = await reranker.RerankAsync(query, results, topK, ct); + IReadOnlyList reranked; + using (var rerankSpan = ClawsharpActivitySources.Knowledge.StartActivity("knowledge.search.rerank")) + { + reranked = await reranker.RerankAsync(query, results, topK, ct).ConfigureAwait(false); + rerankSpan?.SetTag("knowledge.search.rerank_input", results.Count); + rerankSpan?.SetTag("knowledge.search.rerank_output", reranked.Count); + } // Step 8: Format results with source attribution (D-08, D-09, D-10) return FormatResults(reranked); diff --git a/src/clawsharp/Tools/Mcp/McpClient.cs b/src/clawsharp/Tools/Mcp/McpClient.cs index ba734f03..af0af2f7 100644 --- a/src/clawsharp/Tools/Mcp/McpClient.cs +++ b/src/clawsharp/Tools/Mcp/McpClient.cs @@ -168,7 +168,7 @@ public async Task CallToolAsync(string toolName, string argumentsJson, C } catch (Exception ex) { - _logger.LogWarning(ex, "MCP tool call failed"); + LogToolCallFailed(_logger, ex); return "Error: MCP tool call failed."; } @@ -228,4 +228,8 @@ public async ValueTask DisposeAsync() [LoggerMessage(EventId = 8, Level = LogLevel.Warning, Message = "MCP server '{ServerName}' tools/list error: [{Code}] {Message}")] private static partial void LogToolsListError(ILogger logger, string serverName, int code, string message); + + [LoggerMessage(EventId = 9, Level = LogLevel.Warning, + Message = "MCP tool call failed")] + private static partial void LogToolCallFailed(ILogger logger, Exception exception); } \ No newline at end of file diff --git a/src/clawsharp/Tools/Mcp/McpHostedService.cs b/src/clawsharp/Tools/Mcp/McpHostedService.cs index 416841e2..c28d6abb 100644 --- a/src/clawsharp/Tools/Mcp/McpHostedService.cs +++ b/src/clawsharp/Tools/Mcp/McpHostedService.cs @@ -1,3 +1,4 @@ +using System.Collections.Concurrent; using Clawsharp.Config; using Clawsharp.Core.Services; using Microsoft.Extensions.Logging; @@ -23,6 +24,8 @@ public sealed partial class McpHostedService( private readonly List _servers = []; + private readonly ConcurrentDictionary> _serverToolNames = new(StringComparer.OrdinalIgnoreCase); + /// Maximum backoff delay between restart attempts. private static readonly TimeSpan MaxBackoff = TimeSpan.FromSeconds(60); @@ -110,14 +113,27 @@ private async Task StartServerAsync(ManagedMcpServer managed, CancellationToken managed.Client = client; managed.RestartCount = 0; // Reset on successful start + // Deregister previously registered tools for this server (handles restart with changed tool list) + if (_serverToolNames.TryGetValue(managed.Name, out var previousTools)) + { + foreach (var previousToolName in previousTools) + { + toolRegistry.Unregister(previousToolName); + } + } + // Register each discovered tool in the tool registry var toolSensitivity = ParseMcpSensitivity(managed.Config.Sensitivity); + var registeredToolNames = new List(client.Tools.Count); foreach (var tool in client.Tools) { var adapter = new McpToolAdapter(client, tool, toolSensitivity); toolRegistry.Register(adapter); + registeredToolNames.Add(tool.Name); LogToolRegistered(logger, tool.Name, managed.Name); } + + _serverToolNames[managed.Name] = registeredToolNames; } /// @@ -248,6 +264,15 @@ public override async Task StopAsync(CancellationToken cancellationToken) foreach (var managed in _servers) { + // Deregister tools before disposing the server + if (_serverToolNames.TryRemove(managed.Name, out var toolNames)) + { + foreach (var toolName in toolNames) + { + toolRegistry.Unregister(toolName); + } + } + if (managed.Client is null) { continue; diff --git a/src/clawsharp/Tools/Mcp/McpInitializeResult.cs b/src/clawsharp/Tools/Mcp/McpInitializeResult.cs deleted file mode 100644 index 787fe8e0..00000000 --- a/src/clawsharp/Tools/Mcp/McpInitializeResult.cs +++ /dev/null @@ -1,17 +0,0 @@ -namespace Clawsharp.Tools.Mcp; - -/// Result payload for the MCP initialize response (MCP 2025-03-26). -public sealed class McpInitializeResult -{ - /// The protocol version the server supports. - public string ProtocolVersion { get; init; } = "2025-03-26"; - - /// Server capabilities declaration. - public McpServerCapabilities Capabilities { get; init; } = new(); - - /// Server identification. - public McpServerInfo ServerInfo { get; init; } = new(); - - /// Optional instructions for the client on how to use this server. - public string? Instructions { get; init; } -} diff --git a/src/clawsharp/Tools/Mcp/McpJsonContext.cs b/src/clawsharp/Tools/Mcp/McpJsonContext.cs index 25f448c3..e55a645d 100644 --- a/src/clawsharp/Tools/Mcp/McpJsonContext.cs +++ b/src/clawsharp/Tools/Mcp/McpJsonContext.cs @@ -15,13 +15,7 @@ namespace Clawsharp.Tools.Mcp; [JsonSerializable(typeof(McpClientInfo))] [JsonSerializable(typeof(McpCapabilities))] [JsonSerializable(typeof(McpCallToolParams))] -// Server-side DTOs (MCP 2025-03-26) -[JsonSerializable(typeof(McpInitializeResult))] -[JsonSerializable(typeof(McpServerInfo))] -[JsonSerializable(typeof(McpServerCapabilities))] -[JsonSerializable(typeof(McpToolsCapability))] -[JsonSerializable(typeof(McpToolAnnotations))] [JsonSourceGenerationOptions( PropertyNamingPolicy = JsonKnownNamingPolicy.CamelCase, DefaultIgnoreCondition = JsonIgnoreCondition.WhenWritingNull)] -internal partial class McpJsonContext : JsonSerializerContext; \ No newline at end of file +internal partial class McpJsonContext : JsonSerializerContext; diff --git a/src/clawsharp/Tools/Mcp/McpServerCapabilities.cs b/src/clawsharp/Tools/Mcp/McpServerCapabilities.cs deleted file mode 100644 index 0345c68d..00000000 --- a/src/clawsharp/Tools/Mcp/McpServerCapabilities.cs +++ /dev/null @@ -1,15 +0,0 @@ -namespace Clawsharp.Tools.Mcp; - -/// Server capabilities for the MCP initialize response. -public sealed class McpServerCapabilities -{ - /// Tools capability declaration. Null if no tools are offered. - public McpToolsCapability? Tools { get; init; } -} - -/// Tools capability declaration within server capabilities. -public sealed class McpToolsCapability -{ - /// Whether the server supports notifications/tools/list_changed. - public bool ListChanged { get; init; } -} diff --git a/src/clawsharp/Tools/Mcp/McpServerInfo.cs b/src/clawsharp/Tools/Mcp/McpServerInfo.cs deleted file mode 100644 index e3a429b9..00000000 --- a/src/clawsharp/Tools/Mcp/McpServerInfo.cs +++ /dev/null @@ -1,11 +0,0 @@ -namespace Clawsharp.Tools.Mcp; - -/// Server identification for the MCP initialize response. -public sealed class McpServerInfo -{ - /// The server name. Defaults to "clawsharp". - public string Name { get; init; } = "clawsharp"; - - /// Optional server version string. - public string? Version { get; init; } -} diff --git a/src/clawsharp/Tools/Mcp/McpToolAnnotations.cs b/src/clawsharp/Tools/Mcp/McpToolAnnotations.cs deleted file mode 100644 index 109bc031..00000000 --- a/src/clawsharp/Tools/Mcp/McpToolAnnotations.cs +++ /dev/null @@ -1,17 +0,0 @@ -namespace Clawsharp.Tools.Mcp; - -/// Tool annotations per MCP 2025-03-26 spec, providing hints about tool behavior. -public sealed class McpToolAnnotations -{ - /// Hint: tool does not modify state when true. - public bool? ReadOnlyHint { get; init; } - - /// Hint: tool may cause irreversible changes when true. - public bool? DestructiveHint { get; init; } - - /// Hint: calling the tool multiple times with the same arguments produces the same result. - public bool? IdempotentHint { get; init; } - - /// Hint: tool interacts with external entities outside the local environment. - public bool? OpenWorldHint { get; init; } -} diff --git a/src/clawsharp/Tools/Mcp/SseMcpTransport.cs b/src/clawsharp/Tools/Mcp/SseMcpTransport.cs index f683a86e..5b2b23b1 100644 --- a/src/clawsharp/Tools/Mcp/SseMcpTransport.cs +++ b/src/clawsharp/Tools/Mcp/SseMcpTransport.cs @@ -1,6 +1,7 @@ using System.Collections.Concurrent; using System.Text; using System.Text.Json; +using Clawsharp.Core.Utilities; using Microsoft.Extensions.Logging; namespace Clawsharp.Tools.Mcp; @@ -95,7 +96,7 @@ public async Task SendRequestAsync(string method, JsonElement? para var endpointUri = ResolveEndpointUri(_messageEndpoint); using var httpRequest = new HttpRequestMessage(HttpMethod.Post, endpointUri) { - Content = new StringContent(json, Encoding.UTF8, "application/json") + Content = Utf8JsonContent.FromString(json) }; foreach (var (key, value) in _headers) @@ -145,7 +146,7 @@ public async Task SendNotificationAsync(string method, JsonElement? parameters, var endpointUri = ResolveEndpointUri(_messageEndpoint); using var httpRequest = new HttpRequestMessage(HttpMethod.Post, endpointUri) { - Content = new StringContent(json, Encoding.UTF8, "application/json") + Content = Utf8JsonContent.FromString(json) }; foreach (var (key, value) in _headers) diff --git a/src/clawsharp/Tools/Mcp/StreamableHttpMcpTransport.cs b/src/clawsharp/Tools/Mcp/StreamableHttpMcpTransport.cs index 167132fd..5b86d009 100644 --- a/src/clawsharp/Tools/Mcp/StreamableHttpMcpTransport.cs +++ b/src/clawsharp/Tools/Mcp/StreamableHttpMcpTransport.cs @@ -1,6 +1,7 @@ using System.Net.Http.Headers; using System.Text; using System.Text.Json; +using Clawsharp.Core.Utilities; using Microsoft.Extensions.Logging; namespace Clawsharp.Tools.Mcp; @@ -45,7 +46,7 @@ public async Task SendRequestAsync(string method, JsonElement? para using var httpRequest = new HttpRequestMessage(HttpMethod.Post, endpointUri) { - Content = new StringContent(json, Encoding.UTF8, "application/json") + Content = Utf8JsonContent.FromString(json) }; httpRequest.Headers.Accept.Add(new MediaTypeWithQualityHeaderValue("application/json")); @@ -109,7 +110,7 @@ public async Task SendNotificationAsync(string method, JsonElement? parameters, using var httpRequest = new HttpRequestMessage(HttpMethod.Post, endpointUri) { - Content = new StringContent(json, Encoding.UTF8, "application/json") + Content = Utf8JsonContent.FromString(json) }; foreach (var (key, value) in _headers) diff --git a/src/clawsharp/Tools/Memory/HistoryAppendTool.cs b/src/clawsharp/Tools/Memory/HistoryAppendTool.cs index e837f173..86592314 100644 --- a/src/clawsharp/Tools/Memory/HistoryAppendTool.cs +++ b/src/clawsharp/Tools/Memory/HistoryAppendTool.cs @@ -27,7 +27,7 @@ public override async Task ExecuteAsync(JsonElement arguments, Cancellat return "Error: summary is required."; } - await memory.AppendHistoryAsync(summary, ct); + await memory.AppendHistoryAsync(summary, ct).ConfigureAwait(false); return "History updated."; } } \ No newline at end of file diff --git a/src/clawsharp/Tools/Memory/MemoryReadTool.cs b/src/clawsharp/Tools/Memory/MemoryReadTool.cs index 0f01ccf0..86c0aa2e 100644 --- a/src/clawsharp/Tools/Memory/MemoryReadTool.cs +++ b/src/clawsharp/Tools/Memory/MemoryReadTool.cs @@ -15,7 +15,7 @@ public sealed class MemoryReadTool(IMemory memory) : Tool public override async Task ExecuteAsync(JsonElement arguments, CancellationToken ct = default) { - var ctx = await memory.GetContextAsync(ct); + var ctx = await memory.GetContextAsync(ct).ConfigureAwait(false); return ctx ?? "(memory is empty)"; } } \ No newline at end of file diff --git a/src/clawsharp/Tools/Memory/MemorySearchTool.cs b/src/clawsharp/Tools/Memory/MemorySearchTool.cs index 5425f50c..228f5de9 100644 --- a/src/clawsharp/Tools/Memory/MemorySearchTool.cs +++ b/src/clawsharp/Tools/Memory/MemorySearchTool.cs @@ -16,7 +16,7 @@ public sealed class MemorySearchTool(IMemory memory) : Tool "type": "object", "properties": { "query": { "type": "string", "description": "Search query" }, - "n": { "type": "integer", "description": "Number of results (default 5)" } + "top_k": { "type": "integer", "description": "Number of results (default 5)" } }, "required": ["query"] } @@ -25,13 +25,13 @@ public sealed class MemorySearchTool(IMemory memory) : Tool public override async Task ExecuteAsync(JsonElement arguments, CancellationToken ct = default) { var query = arguments.TryGetProperty("query", out var q) ? q.GetString() ?? "" : ""; - var n = arguments.TryGetProperty("n", out var nProp) && nProp.TryGetInt32(out var nVal) ? nVal : 5; + var n = arguments.TryGetProperty("top_k", out var nProp) && nProp.TryGetInt32(out var nVal) ? nVal : 5; if (string.IsNullOrWhiteSpace(query)) { return "Error: query is required."; } - var results = await memory.SearchAsync(query, n, ct); + var results = await memory.SearchAsync(query, n, ct).ConfigureAwait(false); if (results.Count > 0) { return string.Join("\n", results); diff --git a/src/clawsharp/Tools/Memory/MemoryWriteTool.cs b/src/clawsharp/Tools/Memory/MemoryWriteTool.cs index 1cd467f2..3f853919 100644 --- a/src/clawsharp/Tools/Memory/MemoryWriteTool.cs +++ b/src/clawsharp/Tools/Memory/MemoryWriteTool.cs @@ -35,7 +35,7 @@ public override async Task ExecuteAsync(JsonElement arguments, Cancellat fact = scrubResult.Redacted; } - await memory.AppendFactAsync(fact, ct); + await memory.AppendFactAsync(fact, ct).ConfigureAwait(false); return $"Saved: {fact}"; } } \ No newline at end of file diff --git a/src/clawsharp/Tools/Ops/CronTool.cs b/src/clawsharp/Tools/Ops/CronTool.cs index ceb35900..6fccce13 100644 --- a/src/clawsharp/Tools/Ops/CronTool.cs +++ b/src/clawsharp/Tools/Ops/CronTool.cs @@ -54,11 +54,11 @@ public override async Task ExecuteAsync(JsonElement args, CancellationTo var action = args.TryGetProperty("action", out var a) ? a.GetString() ?? "" : ""; return action.ToLowerInvariant() switch { - "add" => await AddAsync(args, ct), - "list" => await ListAsync(ct), - "remove" => await RemoveAsync(args, ct), - "run" => await RunAsync(args, ct), - "update" => await UpdateAsync(args, ct), + "add" => await AddAsync(args, ct).ConfigureAwait(false), + "list" => await ListAsync(ct).ConfigureAwait(false), + "remove" => await RemoveAsync(args, ct).ConfigureAwait(false), + "run" => await RunAsync(args, ct).ConfigureAwait(false), + "update" => await UpdateAsync(args, ct).ConfigureAwait(false), _ => $"Unknown action '{action}'. Valid: add, list, remove, run, update." }; } @@ -132,7 +132,7 @@ private async Task AddAsync(JsonElement args, CancellationToken ct) Provider = args.TryGetProperty("provider", out var pr) ? pr.GetString() : null }; - await cronService.AddJobAsync(job, ct); + await cronService.AddJobAsync(job, ct).ConfigureAwait(false); var preview = message; if (message.Length > AddPreviewLength) { @@ -144,7 +144,7 @@ private async Task AddAsync(JsonElement args, CancellationToken ct) private async Task ListAsync(CancellationToken ct) { - var jobs = await cronService.ListJobsAsync(ct); + var jobs = await cronService.ListJobsAsync(ct).ConfigureAwait(false); if (jobs.Count == 0) { return "No cron jobs scheduled."; @@ -212,7 +212,7 @@ private async Task RemoveAsync(JsonElement args, CancellationToken ct) return "Error: 'id' is required for remove."; } - var removed = await cronService.RemoveJobAsync(id, ct); + var removed = await cronService.RemoveJobAsync(id, ct).ConfigureAwait(false); return removed ? $"Removed job '{id}'." : $"No job found with id '{id}'."; } @@ -224,7 +224,7 @@ private async Task RunAsync(JsonElement args, CancellationToken ct) return "Error: 'id' is required for run."; } - return await cronService.RunJobNowAsync(id, ct); + return await cronService.RunJobNowAsync(id, ct).ConfigureAwait(false); } private async Task UpdateAsync(JsonElement args, CancellationToken ct) @@ -235,7 +235,7 @@ private async Task UpdateAsync(JsonElement args, CancellationToken ct) return "Error: 'id' is required for update."; } - var all = await cronService.ListJobsAsync(ct); + var all = await cronService.ListJobsAsync(ct).ConfigureAwait(false); var job = all.FirstOrDefault(j => j.Id == id || j.Id.StartsWith(id, StringComparison.OrdinalIgnoreCase)); if (job is null) @@ -297,7 +297,7 @@ private async Task UpdateAsync(JsonElement args, CancellationToken ct) Provider = args.TryGetProperty("provider", out var pr) ? pr.GetString() : job.Provider }; - var result = await cronService.UpdateJobAsync(updated, ct); + var result = await cronService.UpdateJobAsync(updated, ct).ConfigureAwait(false); return result is null ? $"No job found with id '{id}'." : $"Updated job '{result.Id}'."; } diff --git a/src/clawsharp/Tools/Ops/DocumentReadTool.cs b/src/clawsharp/Tools/Ops/DocumentReadTool.cs index 511fe475..dd44f64e 100644 --- a/src/clawsharp/Tools/Ops/DocumentReadTool.cs +++ b/src/clawsharp/Tools/Ops/DocumentReadTool.cs @@ -1,6 +1,7 @@ using System.IO.Compression; using System.Text; using System.Text.Json; +using System.Xml; using System.Xml.Linq; using Clawsharp.Security; using UglyToad.PdfPig; @@ -15,6 +16,12 @@ public sealed class DocumentReadTool(string workspace, AuditLogger? auditLogger private const int HardMaxChars = 200_000; + private static readonly XmlReaderSettings SafeXmlSettings = new() + { + DtdProcessing = DtdProcessing.Prohibit, + XmlResolver = null, + }; + private readonly string _workspace = Path.GetFullPath(workspace); public string? ChannelName => ToolRegistry.CurrentChannelName; @@ -67,7 +74,7 @@ public override async Task ExecuteAsync(JsonElement args, CancellationTo if (!File.Exists(resolvedPath)) { - return $"Error: file not found: {resolvedPath}"; + return $"Error: file not found: {inputPath}"; } var info = new FileInfo(resolvedPath); @@ -88,7 +95,7 @@ public override async Task ExecuteAsync(JsonElement args, CancellationTo { text = ext switch { - ".pdf" => await ExtractPdfAsync(resolvedPath, ct), + ".pdf" => await ExtractPdfAsync(resolvedPath, ct).ConfigureAwait(false), ".docx" => ExtractDocx(resolvedPath), ".xlsx" => ExtractXlsx(resolvedPath), ".pptx" => ExtractPptx(resolvedPath), @@ -151,7 +158,8 @@ private static string ExtractDocx(string path) } using var stream = entry.Open(); - var doc = XDocument.Load(stream); + using var reader = XmlReader.Create(stream, SafeXmlSettings); + var doc = XDocument.Load(reader); XNamespace w = "http://schemas.openxmlformats.org/wordprocessingml/2006/main"; var paragraphs = doc.Descendants(w + "p") @@ -168,7 +176,8 @@ private static string ExtractXlsx(string path) if (ssEntry is not null) { using var ss = ssEntry.Open(); - var ssDoc = XDocument.Load(ss); + using var ssReader = XmlReader.Create(ss, SafeXmlSettings); + var ssDoc = XDocument.Load(ssReader); XNamespace ns = "http://schemas.openxmlformats.org/spreadsheetml/2006/main"; sharedStrings.AddRange(ssDoc.Descendants(ns + "si") .Select(si => string.Concat(si.Descendants(ns + "t").Select(t => t.Value)))); @@ -184,7 +193,8 @@ private static string ExtractXlsx(string path) { sb.AppendLine($"[Sheet: {Path.GetFileNameWithoutExtension(sheetEntry.Name)}]"); using var s = sheetEntry.Open(); - var sheet = XDocument.Load(s); + using var sheetReader = XmlReader.Create(s, SafeXmlSettings); + var sheet = XDocument.Load(sheetReader); XNamespace ns = "http://schemas.openxmlformats.org/spreadsheetml/2006/main"; foreach (var row in sheet.Descendants(ns + "row")) { @@ -229,7 +239,8 @@ private static string ExtractPptx(string path) { sb.AppendLine($"[Slide {slideNum++}]"); using var s = slide.Open(); - var doc = XDocument.Load(s); + using var slideReader = XmlReader.Create(s, SafeXmlSettings); + var doc = XDocument.Load(slideReader); var texts = doc.Descendants(a + "t").Select(t => t.Value); sb.AppendLine(string.Join(" ", texts)); sb.AppendLine(); diff --git a/src/clawsharp/Tools/Ops/GitTool.cs b/src/clawsharp/Tools/Ops/GitTool.cs index 89780b90..2c73d1e7 100644 --- a/src/clawsharp/Tools/Ops/GitTool.cs +++ b/src/clawsharp/Tools/Ops/GitTool.cs @@ -165,9 +165,9 @@ public override async Task ExecuteAsync(JsonElement args, CancellationTo using var proc = Process.Start(psi) ?? throw new InvalidOperationException("Failed to start git process."); - var stdout = await proc.StandardOutput.ReadToEndAsync(cts.Token); - var stderr = await proc.StandardError.ReadToEndAsync(cts.Token); - await proc.WaitForExitAsync(cts.Token); + var stdout = await proc.StandardOutput.ReadToEndAsync(cts.Token).ConfigureAwait(false); + var stderr = await proc.StandardError.ReadToEndAsync(cts.Token).ConfigureAwait(false); + await proc.WaitForExitAsync(cts.Token).ConfigureAwait(false); var combined = (stdout + (stderr.Length > 0 ? $"\n{stderr}" : "")).Trim(); if (combined.Length > MaxOutputBytes) diff --git a/src/clawsharp/Tools/Ops/GoalTool.cs b/src/clawsharp/Tools/Ops/GoalTool.cs index 6fa53afc..bbe1ce86 100644 --- a/src/clawsharp/Tools/Ops/GoalTool.cs +++ b/src/clawsharp/Tools/Ops/GoalTool.cs @@ -64,13 +64,13 @@ public override async Task ExecuteAsync(JsonElement arguments, Cancellat return action switch { - "create" => await CreateAsync(arguments, ct), - "list" => await ListAsync(arguments, ct), - "update_step" => await UpdateStepAsync(arguments, ct), - "complete" => await SetStatusAsync(arguments, GoalStatus.Completed, ct), - "pause" => await SetStatusAsync(arguments, GoalStatus.Paused, ct), - "resume" => await SetStatusAsync(arguments, GoalStatus.Active, ct), - "delete" => await SetStatusAsync(arguments, GoalStatus.Deleted, ct), + "create" => await CreateAsync(arguments, ct).ConfigureAwait(false), + "list" => await ListAsync(arguments, ct).ConfigureAwait(false), + "update_step" => await UpdateStepAsync(arguments, ct).ConfigureAwait(false), + "complete" => await SetStatusAsync(arguments, GoalStatus.Completed, ct).ConfigureAwait(false), + "pause" => await SetStatusAsync(arguments, GoalStatus.Paused, ct).ConfigureAwait(false), + "resume" => await SetStatusAsync(arguments, GoalStatus.Active, ct).ConfigureAwait(false), + "delete" => await SetStatusAsync(arguments, GoalStatus.Deleted, ct).ConfigureAwait(false), _ => "Error: action is required. Valid actions: create, list, update_step, complete, pause, resume, delete." }; } @@ -102,9 +102,9 @@ private async Task CreateAsync(JsonElement args, CancellationToken ct) } } - var goals = await storage.LoadAsync(ct); + var goals = await storage.LoadAsync(ct).ConfigureAwait(false); goals.Add(goal); - await storage.SaveAsync(goals, ct); + await storage.SaveAsync(goals, ct).ConfigureAwait(false); LogGoalCreated(logger, goal.Id, goal.Title); @@ -121,7 +121,7 @@ private async Task CreateAsync(JsonElement args, CancellationToken ct) private async Task ListAsync(JsonElement args, CancellationToken ct) { var statusFilter = args.TryGetProperty("status", out var s) ? s.GetString() : "active"; - var goals = await storage.LoadAsync(ct); + var goals = await storage.LoadAsync(ct).ConfigureAwait(false); var filtered = statusFilter switch { @@ -172,7 +172,7 @@ private async Task UpdateStepAsync(JsonElement args, CancellationToken c var done = doneEl.GetBoolean(); - var goals = await storage.LoadAsync(ct); + var goals = await storage.LoadAsync(ct).ConfigureAwait(false); var goal = goals.Find(g => string.Equals(g.Id, id, StringComparison.OrdinalIgnoreCase)); if (goal is null) { @@ -186,7 +186,7 @@ private async Task UpdateStepAsync(JsonElement args, CancellationToken c goal.Steps[stepIndex].Done = done; goal.UpdatedAt = DateTimeOffset.UtcNow; - await storage.SaveAsync(goals, ct); + await storage.SaveAsync(goals, ct).ConfigureAwait(false); var stepText = goal.Steps[stepIndex].Text; return $"Step {stepIndex} of goal '{goal.Title}' marked as {(done ? "done" : "not done")}: {stepText}"; @@ -200,7 +200,7 @@ private async Task SetStatusAsync(JsonElement args, GoalStatus newStatus return $"Error: id is required for {newStatus.ToString().ToLowerInvariant()}."; } - var goals = await storage.LoadAsync(ct); + var goals = await storage.LoadAsync(ct).ConfigureAwait(false); var goal = goals.Find(g => string.Equals(g.Id, id, StringComparison.OrdinalIgnoreCase)); if (goal is null) { @@ -225,7 +225,7 @@ private async Task SetStatusAsync(JsonElement args, GoalStatus newStatus goal.Status = newStatus; goal.UpdatedAt = DateTimeOffset.UtcNow; - await storage.SaveAsync(goals, ct); + await storage.SaveAsync(goals, ct).ConfigureAwait(false); return $"Goal {goal.Id} ({goal.Title}) marked as {newStatus.ToString().ToLowerInvariant()}."; } diff --git a/src/clawsharp/Tools/Ops/InteractionsTool.cs b/src/clawsharp/Tools/Ops/InteractionsTool.cs index 81a984fc..56f22654 100644 --- a/src/clawsharp/Tools/Ops/InteractionsTool.cs +++ b/src/clawsharp/Tools/Ops/InteractionsTool.cs @@ -27,7 +27,7 @@ public sealed class InteractionsTool(IInteractionStore store) : Tool public override async Task ExecuteAsync(JsonElement arguments, CancellationToken ct = default) { var query = arguments.GetProperty("query").GetString() ?? "summary"; - var records = await store.ReadAllAsync(ct); + var records = await store.ReadAllAsync(ct).ConfigureAwait(false); if (records.Count == 0) { @@ -52,7 +52,7 @@ public override async Task ExecuteAsync(JsonElement arguments, Cancellat "recent" => FormatRecent(records), "savings" => FormatSavings(records), "daily" => FormatDaily(records), - _ => FormatSummary(records), + _ => $"Error: unknown query '{query}'. Valid: summary, recent, session:, model:, savings, daily.", }; } diff --git a/src/clawsharp/Tools/Ops/SendFileTool.cs b/src/clawsharp/Tools/Ops/SendFileTool.cs index 5cd74d7f..be69846f 100644 --- a/src/clawsharp/Tools/Ops/SendFileTool.cs +++ b/src/clawsharp/Tools/Ops/SendFileTool.cs @@ -90,7 +90,7 @@ public override async Task ExecuteAsync(JsonElement arguments, Cancellat byte[] bytes; try { - bytes = await File.ReadAllBytesAsync(fullPath, ct); + bytes = await File.ReadAllBytesAsync(fullPath, ct).ConfigureAwait(false); } catch (Exception ex) { diff --git a/src/clawsharp/Tools/Ops/ShellTool.cs b/src/clawsharp/Tools/Ops/ShellTool.cs index 20119962..59a02d45 100644 --- a/src/clawsharp/Tools/Ops/ShellTool.cs +++ b/src/clawsharp/Tools/Ops/ShellTool.cs @@ -102,7 +102,7 @@ public override async Task ExecuteAsync(JsonElement arguments, Cancellat _ = _auditLogger.LogPolicyViolationAsync($"ShellGuard denied: {blocked}", ChannelName, ct: ct); } - return $"[shell] {blocked}"; + return $"Error: {blocked}"; } } @@ -117,7 +117,7 @@ public override async Task ExecuteAsync(JsonElement arguments, Cancellat error: $"Requires approval (pattern: {matchedPattern})", ct: ct); } - return $"[shell] Command requires approval: '{command}' matched approval pattern '{matchedPattern}'. " + + return $"Error: command requires approval: '{command}' matched approval pattern '{matchedPattern}'. " + "Use the CLI channel to execute this command interactively, or add the pattern to security.autoApprovePatterns to bypass."; } @@ -133,7 +133,7 @@ public override async Task ExecuteAsync(JsonElement arguments, Cancellat } return - "[shell] Shell execution is disabled on non-interactive channels. Set requireShellApproval=false in config to allow."; + "Error: shell execution is disabled on non-interactive channels. Set requireShellApproval=false in config to allow."; } Console.Error.Write($"[shell] Allow command? {command}\n[y/N]: "); @@ -146,7 +146,7 @@ public override async Task ExecuteAsync(JsonElement arguments, Cancellat error: "User rejected", ct: ct); } - return "[shell] Command rejected by user."; + return "Error: command rejected by user."; } } @@ -219,10 +219,10 @@ public override async Task ExecuteAsync(JsonElement arguments, Cancellat try { - await proc.WaitForExitAsync(cts.Token); + await proc.WaitForExitAsync(cts.Token).ConfigureAwait(false); var elapsed = Stopwatch.GetElapsedTime(timestamp); - var stdout = await stdoutTask; - var stderr = await stderrTask; + var stdout = await stdoutTask.ConfigureAwait(false); + var stderr = await stderrTask.ConfigureAwait(false); var output = string.IsNullOrEmpty(stderr) ? stdout : $"{stdout}\n[stderr]\n{stderr}"; // Truncate to 100 KB (global cap in ToolRegistry is the final safety net) const int maxChars = 102_400; diff --git a/src/clawsharp/Tools/Ops/SpawnTool.cs b/src/clawsharp/Tools/Ops/SpawnTool.cs index b2a608f9..f98605fd 100644 --- a/src/clawsharp/Tools/Ops/SpawnTool.cs +++ b/src/clawsharp/Tools/Ops/SpawnTool.cs @@ -164,7 +164,7 @@ public override async Task ExecuteAsync(JsonElement arguments, Cancellat try { - var result = await RunChildLoopAsync(task, restrictedTools, cts.Token); + var result = await RunChildLoopAsync(task, restrictedTools, cts.Token).ConfigureAwait(false); LogSpawnCompleted(logger, displayName, result.Response.Length); return result.Response; } @@ -214,7 +214,7 @@ private async Task RunChildLoopAsync( memoryCtx, workspaceContext: null, channelName: "spawn", - enabledTools: toolDefs.Select(t => t.Name).ToList()); + enabledTools: toolDefs.Select(t => t.Name)); string systemPrompt; if (string.IsNullOrEmpty(dynamicPrompt)) @@ -250,7 +250,7 @@ private async Task RunChildLoopAsync( CacheToolDefinitions: cacheToolDefs, BeforeToolExecution: _ => SetChildContext(CurrentSpawnDepth + 1)); - return await stepExecutor.ExecuteAsync(stepRequest, provider, tools, ct); + return await stepExecutor.ExecuteAsync(stepRequest, provider, tools, ct).ConfigureAwait(false); } /// diff --git a/src/clawsharp/Tools/ToolRegistry.cs b/src/clawsharp/Tools/ToolRegistry.cs index 5cdf1369..c1db570c 100644 --- a/src/clawsharp/Tools/ToolRegistry.cs +++ b/src/clawsharp/Tools/ToolRegistry.cs @@ -30,7 +30,7 @@ namespace Clawsharp.Tools; -public sealed class ToolRegistry : IToolRegistry +public sealed partial class ToolRegistry : IToolRegistry { private static readonly AsyncLocal _currentChannelName = new(); @@ -67,10 +67,17 @@ public sealed class ToolRegistry : IToolRegistry /// Current MCP execution context for the executing async flow. Set by McpServerToolBridge, read during tool.execute spans. public static McpExecutionContext? CurrentMcpExecutionContext => _currentMcpContext.Value; - private readonly ConcurrentDictionary _schemaCache = new(StringComparer.OrdinalIgnoreCase); + private readonly ConcurrentDictionary _schemaCache = new(StringComparer.OrdinalIgnoreCase); private readonly ConcurrentDictionary _tools; + /// + /// Cached unfiltered tool definitions. Invalidated when tools are registered. + /// Volatile ensures cross-thread visibility without locking; benign duplicate + /// computation is preferred over contention. + /// + private volatile IReadOnlyList? _cachedDefinitions; + private readonly int _maxToolOutputChars; private readonly Dictionary? _filterGroups; @@ -188,7 +195,24 @@ internal ToolRegistry(IEnumerable tools, ILoggerFactory loggerFactory, App } /// Registers a tool dynamically (e.g. from an MCP server). - public void Register(Tool tool) => _tools[tool.Name] = tool; + public void Register(Tool tool) + { + _tools[tool.Name] = tool; + _cachedDefinitions = null; + } + + /// + public bool Unregister(string toolName) + { + if (_tools.TryRemove(toolName, out _)) + { + _schemaCache.TryRemove(toolName, out _); + _cachedDefinitions = null; + return true; + } + + return false; + } /// Sets per-request channel context via AsyncLocal so each async call chain /// gets its own isolated value, preventing cross-channel corruption on shared singletons. @@ -215,17 +239,23 @@ public void SetMcpExecutionContext(McpExecutionContext? ctx) public IReadOnlyList GetDefinitions() { - return _tools.Values.Select(t => t.ToDefinition()).ToList(); + return _cachedDefinitions ??= _tools.Values.Select(t => t.ToDefinition()).ToList(); } /// public IReadOnlyList GetFilteredDefinitions(string? messageText) { + // Fast path: no RBAC policy and no filter groups — return the cached full set. + var policy = CurrentPolicyDecision; + if (policy is null && (_filterGroups is null || _filterGroups.Count == 0)) + { + return GetDefinitions(); + } + IEnumerable tools = _tools.Values; // RBAC filter (first) — per D-17, composes with existing filter groups. // When no policy is set (null), ALL tools pass (backward compatibility). - var policy = CurrentPolicyDecision; if (policy is not null) { tools = tools.Where(t => policy.EvaluateToolAccess(t.Name, t.Sensitivity) == PolicyEffect.Allowed); @@ -405,6 +435,10 @@ public async Task ExecuteAsync(string name, string argumentsJson, Cancel return $"[approval] Tool '{name}' requires admin approval. Request submitted (ID: {requestId}). An admin will review your request."; } + else + { + LogApprovalDeniedNoOrgUser(_logger, name); + } } if (effect != PolicyEffect.Allowed) @@ -421,8 +455,7 @@ public async Task ExecuteAsync(string name, string argumentsJson, Cancel var sid = CurrentSessionId; if (sid is not null && _policyEvaluator?.RecordDenial(sid) == true) { - _logger.LogWarning("Suspicious denial pattern: {Threshold}+ denials in session {SessionId}", - 3, sid); + LogSuspiciousDenialPattern(_logger, 3, sid); if (_auditLogger is not null) { _ = _auditLogger.LogAsync(new AuditEvent @@ -487,7 +520,7 @@ public async Task ExecuteAsync(string name, string argumentsJson, Cancel return validationError; } - var result = await tool.ExecuteAsync(doc.RootElement, ct); + var result = await tool.ExecuteAsync(doc.RootElement, ct).ConfigureAwait(false); toolSw.Stop(); // Global safety-net truncation — individual tools may have their own lower caps. @@ -512,7 +545,7 @@ public async Task ExecuteAsync(string name, string argumentsJson, Cancel catch (Exception ex) { toolSw.Stop(); - _logger.LogWarning(ex, "Tool '{ToolName}' execution failed", name); + LogToolExecutionFailed(_logger, name, ex); toolActivity?.SetStatus(ActivityStatusCode.Error, ex.Message); return "Error: operation failed."; } @@ -525,11 +558,12 @@ public async Task ExecuteAsync(string name, string argumentsJson, Cancel /// private string? ValidateArguments(Tool tool, JsonElement arguments) { - var schemaDoc = _schemaCache.GetOrAdd(tool.Name, _ => + var schemaElement = _schemaCache.GetOrAdd(tool.Name, _ => { try { - return JsonDocument.Parse(tool.ParametersSchemaJson); + using var doc = JsonDocument.Parse(tool.ParametersSchemaJson); + return doc.RootElement.Clone(); } catch { @@ -537,12 +571,12 @@ public async Task ExecuteAsync(string name, string argumentsJson, Cancel } }); - if (schemaDoc is null) + if (schemaElement is null) { return null; // Unparseable schema — skip validation rather than blocking execution. } - var error = ToolValidator.Validate(schemaDoc.RootElement, arguments); + var error = ToolValidator.Validate(schemaElement.Value, arguments); if (error is not null) { return $"Tool input validation error for '{tool.Name}': {error} " @@ -551,4 +585,16 @@ public async Task ExecuteAsync(string name, string argumentsJson, Cancel return null; } + + [LoggerMessage(Level = LogLevel.Warning, + Message = "Suspicious denial pattern: {Threshold}+ denials in session {SessionId}")] + private static partial void LogSuspiciousDenialPattern(ILogger logger, int threshold, string sessionId); + + [LoggerMessage(Level = LogLevel.Warning, + Message = "Tool '{ToolName}' execution failed")] + private static partial void LogToolExecutionFailed(ILogger logger, string toolName, Exception exception); + + [LoggerMessage(Level = LogLevel.Debug, + Message = "Tool '{ToolName}' requires approval but no OrgUser context is available; denying")] + private static partial void LogApprovalDeniedNoOrgUser(ILogger logger, string toolName); } \ No newline at end of file diff --git a/src/clawsharp/Tools/ToolValidator.cs b/src/clawsharp/Tools/ToolValidator.cs index 7cc563d9..90cf1c46 100644 --- a/src/clawsharp/Tools/ToolValidator.cs +++ b/src/clawsharp/Tools/ToolValidator.cs @@ -87,7 +87,8 @@ internal static class ToolValidator $"property '{name}' must be string, got {value.ValueKind}", "number" when value.ValueKind != JsonValueKind.Number => $"property '{name}' must be number, got {value.ValueKind}", - "integer" when value.ValueKind != JsonValueKind.Number => + "integer" when value.ValueKind != JsonValueKind.Number + || value.TryGetDecimal(out var d) && d != Math.Floor(d) => $"property '{name}' must be integer, got {value.ValueKind}", "boolean" when value.ValueKind is not (JsonValueKind.True or JsonValueKind.False) => $"property '{name}' must be boolean, got {value.ValueKind}", diff --git a/src/clawsharp/Tools/Web/WebFetchTool.cs b/src/clawsharp/Tools/Web/WebFetchTool.cs index 8fed8e59..e4f46e79 100644 --- a/src/clawsharp/Tools/Web/WebFetchTool.cs +++ b/src/clawsharp/Tools/Web/WebFetchTool.cs @@ -1,7 +1,7 @@ using System.Net; -using System.Text; using System.Text.Json; using System.Text.RegularExpressions; +using Clawsharp.Core.Utilities; using Clawsharp.Security; namespace Clawsharp.Tools.Web; @@ -78,16 +78,16 @@ public override async Task ExecuteAsync(JsonElement arguments, Cancellat HttpResponseMessage resp; if (method.Equals("POST", StringComparison.OrdinalIgnoreCase) && body is not null) { - resp = await client.PostAsync(uri, new StringContent(body, Encoding.UTF8, "application/json"), ct); + resp = await client.PostAsync(uri, Utf8JsonContent.FromString(body), ct).ConfigureAwait(false); } else { - resp = await client.GetAsync(uri, ct); + resp = await client.GetAsync(uri, ct).ConfigureAwait(false); } using (resp) { - var text = await resp.Content.ReadAsStringAsync(ct); + var text = await resp.Content.ReadAsStringAsync(ct).ConfigureAwait(false); // Cap raw HTML before regex processing to prevent excessive scan time on huge responses. if (text.Length > maxChars * 2) { diff --git a/src/clawsharp/Tools/Web/WebSearchTool.cs b/src/clawsharp/Tools/Web/WebSearchTool.cs index b920d187..7a36bac6 100644 --- a/src/clawsharp/Tools/Web/WebSearchTool.cs +++ b/src/clawsharp/Tools/Web/WebSearchTool.cs @@ -5,6 +5,7 @@ using System.Text.Json; using System.Text.Json.Serialization; using System.Text.RegularExpressions; +using Clawsharp.Core.Utilities; using Clawsharp.Security; using Clawsharp.Config.Features; @@ -122,7 +123,7 @@ public WebSearchTool(IHttpClientFactory httpFactory, ToolsConfig config, AuditLo "type": "object", "properties": { "query": { "type": "string", "description": "Search query" }, - "count": { "type": "integer", "description": "Number of results (default 5, max 10)" } + "top_k": { "type": "integer", "description": "Number of results (default 5, max 10)" } }, "required": ["query"] } @@ -131,7 +132,7 @@ public WebSearchTool(IHttpClientFactory httpFactory, ToolsConfig config, AuditLo public override async Task ExecuteAsync(JsonElement arguments, CancellationToken ct = default) { var query = arguments.TryGetProperty("query", out var q) ? q.GetString() ?? "" : ""; - var count = arguments.TryGetProperty("count", out var c) && c.TryGetInt32(out var cv) ? Math.Min(cv, 10) : 5; + var count = arguments.TryGetProperty("top_k", out var c) && c.TryGetInt32(out var cv) ? Math.Min(cv, 10) : 5; if (string.IsNullOrWhiteSpace(query)) { @@ -147,15 +148,15 @@ public override async Task ExecuteAsync(JsonElement arguments, Cancellat { return _activeProvider switch { - SearchProvider.Brave => await BraveSearchAsync(query, count, ct), - SearchProvider.Exa => await SearchExaAsync(query, count, ct), - SearchProvider.Tavily => await SearchTavilyAsync(query, count, ct), - SearchProvider.Searxng => await SearchSearxngAsync(query, count, ct), - SearchProvider.Jina => await SearchJinaAsync(query, ct), - SearchProvider.Firecrawl => await SearchFirecrawlAsync(query, count, ct), - SearchProvider.Perplexity => await SearchPerplexityAsync(query, count, ct), - SearchProvider.Glm => await SearchGlmAsync(query, ct), - _ => await DdgSearchAsync(query, count, ct) + SearchProvider.Brave => await BraveSearchAsync(query, count, ct).ConfigureAwait(false), + SearchProvider.Exa => await SearchExaAsync(query, count, ct).ConfigureAwait(false), + SearchProvider.Tavily => await SearchTavilyAsync(query, count, ct).ConfigureAwait(false), + SearchProvider.Searxng => await SearchSearxngAsync(query, count, ct).ConfigureAwait(false), + SearchProvider.Jina => await SearchJinaAsync(query, ct).ConfigureAwait(false), + SearchProvider.Firecrawl => await SearchFirecrawlAsync(query, count, ct).ConfigureAwait(false), + SearchProvider.Perplexity => await SearchPerplexityAsync(query, count, ct).ConfigureAwait(false), + SearchProvider.Glm => await SearchGlmAsync(query, ct).ConfigureAwait(false), + _ => await DdgSearchAsync(query, count, ct).ConfigureAwait(false) }; } catch (Exception) @@ -177,9 +178,9 @@ private async Task BraveSearchAsync(string query, int count, Cancellatio req.Headers.Add("X-Subscription-Token", _braveApiKey); using var client = _httpFactory.CreateClient("tools"); - using var resp = await client.SendAsync(req, ct); - await using var stream = await resp.Content.ReadAsStreamAsync(ct); - using var doc = await JsonDocument.ParseAsync(stream, default, ct); + using var resp = await client.SendAsync(req, ct).ConfigureAwait(false); + await using var stream = await resp.Content.ReadAsStreamAsync(ct).ConfigureAwait(false); + using var doc = await JsonDocument.ParseAsync(stream, default, ct).ConfigureAwait(false); var results = new List(); if (doc.RootElement.TryGetProperty("web", out var web) && @@ -212,19 +213,17 @@ private async Task BraveSearchAsync(string query, int count, Cancellatio private async Task SearchExaAsync(string query, int count, CancellationToken ct) { - var body = JsonSerializer.Serialize( - new ExaSearchRequest(query, count, "auto"), - WebSearchJsonContext.Default.ExaSearchRequest); - using var req = new HttpRequestMessage(HttpMethod.Post, "https://api.exa.ai/search"); req.Headers.Add("x-api-key", _exaApiKey); - req.Content = new StringContent(body, Encoding.UTF8, "application/json"); + req.Content = Utf8JsonContent.Create( + new ExaSearchRequest(query, count, "auto"), + WebSearchJsonContext.Default.ExaSearchRequest); using var client = _httpFactory.CreateClient("tools"); - using var resp = await client.SendAsync(req, ct); + using var resp = await client.SendAsync(req, ct).ConfigureAwait(false); resp.EnsureSuccessStatusCode(); - var json = await resp.Content.ReadAsStringAsync(ct); + var json = await resp.Content.ReadAsStringAsync(ct).ConfigureAwait(false); using var parsed = JsonDocument.Parse(json); var sb = new StringBuilder(); @@ -256,18 +255,16 @@ private async Task SearchExaAsync(string query, int count, CancellationT private async Task SearchTavilyAsync(string query, int count, CancellationToken ct) { - var body = JsonSerializer.Serialize( + using var req = new HttpRequestMessage(HttpMethod.Post, "https://api.tavily.com/search"); + req.Content = Utf8JsonContent.Create( new TavilySearchRequest(_tavilyApiKey!, query, count, "basic"), WebSearchJsonContext.Default.TavilySearchRequest); - using var req = new HttpRequestMessage(HttpMethod.Post, "https://api.tavily.com/search"); - req.Content = new StringContent(body, Encoding.UTF8, "application/json"); - using var client = _httpFactory.CreateClient("tools"); - using var resp = await client.SendAsync(req, ct); + using var resp = await client.SendAsync(req, ct).ConfigureAwait(false); resp.EnsureSuccessStatusCode(); - var json = await resp.Content.ReadAsStringAsync(ct); + var json = await resp.Content.ReadAsStringAsync(ct).ConfigureAwait(false); using var parsed = JsonDocument.Parse(json); var sb = new StringBuilder(); @@ -301,14 +298,14 @@ private async Task SearchSearxngAsync(string query, int count, Cancellat { var url = $"{_searxngBaseUrl!.TrimEnd('/')}/search?q={Uri.EscapeDataString(query)}&format=json&categories=general&pageno=1"; - var ssrfError = await SsrfGuard.CheckAsync(new Uri(url), ct); + var ssrfError = await SsrfGuard.CheckAsync(new Uri(url), ct).ConfigureAwait(false); if (ssrfError is not null) { return $"Error: {ssrfError}"; } using var client = _httpFactory.CreateClient("tools"); - var resp = await client.GetStringAsync(url, ct); + var resp = await client.GetStringAsync(url, ct).ConfigureAwait(false); using var parsed = JsonDocument.Parse(resp); var sb = new StringBuilder(); var i = 0; @@ -351,10 +348,10 @@ private async Task SearchJinaAsync(string query, CancellationToken ct) req.Headers.Add("Accept", "text/plain"); using var client = _httpFactory.CreateClient("tools"); - using var resp = await client.SendAsync(req, ct); + using var resp = await client.SendAsync(req, ct).ConfigureAwait(false); resp.EnsureSuccessStatusCode(); - var text = await resp.Content.ReadAsStringAsync(ct); + var text = await resp.Content.ReadAsStringAsync(ct).ConfigureAwait(false); if (text.Length > MaxResponseTextLength) { return text[..MaxResponseTextLength] + "\n...[truncated]"; @@ -369,15 +366,11 @@ private async Task SearchJinaAsync(string query, CancellationToken ct) private async Task SearchFirecrawlAsync(string query, int count, CancellationToken ct) { - var body = JsonSerializer.Serialize( - new FirecrawlSearchRequest(query, count), - WebSearchJsonContext.Default.FirecrawlSearchRequest); - var baseUrl = (_firecrawlBaseUrl ?? "https://api.firecrawl.dev").TrimEnd('/'); var requestUrl = $"{baseUrl}/v1/search"; // HIGH-04: SSRF-check the Firecrawl base URL (user-configurable in config) - var ssrfError = await SsrfGuard.CheckAsync(new Uri(requestUrl), ct); + var ssrfError = await SsrfGuard.CheckAsync(new Uri(requestUrl), ct).ConfigureAwait(false); if (ssrfError is not null) { return $"Error: {ssrfError}"; @@ -385,14 +378,16 @@ private async Task SearchFirecrawlAsync(string query, int count, Cancell using var req = new HttpRequestMessage(HttpMethod.Post, requestUrl); req.Headers.Authorization = new AuthenticationHeaderValue("Bearer", _firecrawlApiKey); - req.Content = new StringContent(body, Encoding.UTF8, "application/json"); + req.Content = Utf8JsonContent.Create( + new FirecrawlSearchRequest(query, count), + WebSearchJsonContext.Default.FirecrawlSearchRequest); using var client = _httpFactory.CreateClient("tools"); - using var resp = await client.SendAsync(req, ct); + using var resp = await client.SendAsync(req, ct).ConfigureAwait(false); resp.EnsureSuccessStatusCode(); - await using var stream = await resp.Content.ReadAsStreamAsync(ct); - using var doc = await JsonDocument.ParseAsync(stream, default, ct); + await using var stream = await resp.Content.ReadAsStreamAsync(ct).ConfigureAwait(false); + using var doc = await JsonDocument.ParseAsync(stream, default, ct).ConfigureAwait(false); var sb = new StringBuilder(); if (doc.RootElement.TryGetProperty("data", out var data)) @@ -426,22 +421,20 @@ private async Task SearchFirecrawlAsync(string query, int count, Cancell private async Task SearchPerplexityAsync(string query, int count, CancellationToken ct) { - var body = JsonSerializer.Serialize( + using var req = new HttpRequestMessage(HttpMethod.Post, "https://api.perplexity.ai/chat/completions"); + req.Headers.Authorization = new AuthenticationHeaderValue("Bearer", _perplexityApiKey); + req.Content = Utf8JsonContent.Create( new PerplexitySearchRequest( _perplexityModel ?? "sonar-pro", [new PerplexityMessage("user", query)], 1024), WebSearchJsonContext.Default.PerplexitySearchRequest); - using var req = new HttpRequestMessage(HttpMethod.Post, "https://api.perplexity.ai/chat/completions"); - req.Headers.Authorization = new AuthenticationHeaderValue("Bearer", _perplexityApiKey); - req.Content = new StringContent(body, Encoding.UTF8, "application/json"); - using var client = _httpFactory.CreateClient("tools"); - using var resp = await client.SendAsync(req, ct); + using var resp = await client.SendAsync(req, ct).ConfigureAwait(false); resp.EnsureSuccessStatusCode(); - var json = await resp.Content.ReadAsStringAsync(ct); + var json = await resp.Content.ReadAsStringAsync(ct).ConfigureAwait(false); using var doc = JsonDocument.Parse(json); if (doc.RootElement.TryGetProperty("choices", out var choices) && @@ -469,21 +462,19 @@ private async Task SearchGlmAsync(string query, CancellationToken ct) { var jwt = GetOrCreateGlmJwt(); - var body = JsonSerializer.Serialize( + using var req = new HttpRequestMessage(HttpMethod.Post, "https://open.bigmodel.cn/api/paas/v4/chat/completions"); + req.Headers.Authorization = new AuthenticationHeaderValue("Bearer", jwt); + req.Content = Utf8JsonContent.Create( new GlmSearchRequest( _glmModel ?? "web-search-pro", [new GlmMessage("user", query)]), WebSearchJsonContext.Default.GlmSearchRequest); - using var req = new HttpRequestMessage(HttpMethod.Post, "https://open.bigmodel.cn/api/paas/v4/chat/completions"); - req.Headers.Authorization = new AuthenticationHeaderValue("Bearer", jwt); - req.Content = new StringContent(body, Encoding.UTF8, "application/json"); - using var client = _httpFactory.CreateClient("tools"); - using var resp = await client.SendAsync(req, ct); + using var resp = await client.SendAsync(req, ct).ConfigureAwait(false); resp.EnsureSuccessStatusCode(); - var json = await resp.Content.ReadAsStringAsync(ct); + var json = await resp.Content.ReadAsStringAsync(ct).ConfigureAwait(false); using var doc = JsonDocument.Parse(json); if (doc.RootElement.TryGetProperty("choices", out var choices) && @@ -571,8 +562,8 @@ private async Task DdgSearchAsync(string query, int count, CancellationT using var req = new HttpRequestMessage(HttpMethod.Get, url); req.Headers.Add("User-Agent", "Mozilla/5.0 (compatible; clawsharp/1.0)"); using var client = _httpFactory.CreateClient("tools"); - using var resp = await client.SendAsync(req, ct); - var html = await resp.Content.ReadAsStringAsync(ct); + using var resp = await client.SendAsync(req, ct).ConfigureAwait(false); + var html = await resp.Content.ReadAsStringAsync(ct).ConfigureAwait(false); // Extract result links and snippets with compiled regex var results = new List(); @@ -623,7 +614,11 @@ private async Task DdgSearchAsync(string query, int count, CancellationT internal sealed record ExaSearchRequest(string Query, int NumResults, string Type); -internal sealed record TavilySearchRequest(string ApiKey, string Query, int MaxResults, string SearchDepth); +internal sealed record TavilySearchRequest( + [property: JsonPropertyName("api_key")] string ApiKey, + string Query, + [property: JsonPropertyName("max_results")] int MaxResults, + [property: JsonPropertyName("search_depth")] string SearchDepth); internal sealed record FirecrawlSearchRequest(string Query, int Limit); @@ -632,7 +627,7 @@ internal sealed record PerplexityMessage(string Role, string Content); internal sealed record PerplexitySearchRequest( string Model, IReadOnlyList Messages, - int MaxTokens); + [property: JsonPropertyName("max_tokens")] int MaxTokens); internal sealed record GlmMessage(string Role, string Content); @@ -649,4 +644,5 @@ internal sealed record GlmSearchRequest( [JsonSerializable(typeof(GlmMessage))] [JsonSerializable(typeof(GlmSearchRequest))] [JsonSerializable(typeof(IReadOnlyList))] +[JsonSourceGenerationOptions(PropertyNamingPolicy = JsonKnownNamingPolicy.CamelCase)] internal partial class WebSearchJsonContext : JsonSerializerContext; \ No newline at end of file diff --git a/src/clawsharp/Webhooks/DeliveryStatuses.cs b/src/clawsharp/Webhooks/DeliveryStatuses.cs new file mode 100644 index 00000000..d86c9c85 --- /dev/null +++ b/src/clawsharp/Webhooks/DeliveryStatuses.cs @@ -0,0 +1,26 @@ +namespace Clawsharp.Webhooks; + +/// +/// String constants for values. +/// Centralises the five status discriminators so they are never duplicated as magic strings. +/// +internal static class DeliveryStatuses +{ + public const string Pending = "pending"; + public const string Delivered = "delivered"; + public const string Failed = "failed"; + public const string Dlq = "dlq"; + public const string Replayed = "replayed"; +} + +/// +/// String constants for values. +/// Used by and to classify +/// delivery results for OTel instruments and SSE broadcast. +/// +internal static class DeliveryOutcomes +{ + public const string Success = "delivery.success"; + public const string Failed = "delivery.failed"; + public const string Dlq = "delivery.dlq"; +} diff --git a/src/clawsharp/Webhooks/DeliveryStorage.cs b/src/clawsharp/Webhooks/DeliveryStorage.cs index f288ff4e..474e6b3e 100644 --- a/src/clawsharp/Webhooks/DeliveryStorage.cs +++ b/src/clawsharp/Webhooks/DeliveryStorage.cs @@ -1,5 +1,8 @@ using System.Text.Json; using Clawsharp.Config; +using Clawsharp.Core.Utilities; +using Microsoft.Extensions.Logging; +using UglyToad.PdfPig.Core; namespace Clawsharp.Webhooks; @@ -14,13 +17,14 @@ namespace Clawsharp.Webhooks; /// Each file has its own to allow concurrent writes to different files /// while serializing writes to the same file. Per D-07 through D-09 of the v2.3 webhook design. /// -public sealed class DeliveryStorage +public sealed partial class DeliveryStorage { private readonly string _dir; private readonly string _outboxPath; private readonly string _historyPath; private readonly string _dlqPath; private readonly int _historyMaxEntries; + private readonly ILogger? _logger; private readonly SemaphoreSlim _outboxLock = new(1, 1); private readonly SemaphoreSlim _historyLock = new(1, 1); @@ -32,6 +36,14 @@ public sealed class DeliveryStorage /// private int _historyCount; + /// + /// DI constructor — stores files under ~/.clawsharp/webhooks/. + /// + public DeliveryStorage(ILogger logger) + : this(ConfigLoader.ExpandHome("~/.clawsharp/webhooks"), logger: logger) + { + } + /// /// Default constructor — stores files under ~/.clawsharp/webhooks/. /// @@ -44,11 +56,13 @@ public DeliveryStorage() : this(ConfigLoader.ExpandHome("~/.clawsharp/webhooks") /// /// Absolute path to the directory where JSONL files are stored. /// Number of history entries before rotating. Default 10 000. - internal DeliveryStorage(string directory, int historyMaxEntries = 10_000) + /// Optional logger for rotation and pruning warnings. + internal DeliveryStorage(string directory, int historyMaxEntries = 10_000, ILogger? logger = null) { _dir = directory; _historyMaxEntries = historyMaxEntries; - Directory.CreateDirectory(_dir); + _logger = logger; + FilePermissions.EnsureRestrictedDirectory(_dir); _outboxPath = Path.Combine(_dir, "outbox.jsonl"); _historyPath = Path.Combine(_dir, "history.jsonl"); @@ -70,7 +84,7 @@ public async Task AppendOutboxAsync(WebhookDeliveryRecord record, CancellationTo await _outboxLock.WaitAsync(ct).ConfigureAwait(false); try { - await File.AppendAllTextAsync(_outboxPath, json + "\n", ct).ConfigureAwait(false); + await File.AppendAllLinesAsync(_outboxPath, [json], ct).ConfigureAwait(false); } finally { @@ -91,7 +105,7 @@ public void AppendOutboxSync(WebhookDeliveryRecord record) _outboxLock.Wait(); try { - File.AppendAllText(_outboxPath, json + "\n"); + File.AppendAllLines(_outboxPath, [json]); } finally { @@ -102,7 +116,7 @@ public void AppendOutboxSync(WebhookDeliveryRecord record) /// /// Appends a delivery record to history.jsonl. /// When the entry count reaches , the file is atomically rotated - /// to history.{yyyyMMddHHmmss}.jsonl and a fresh file is started. + /// to history.{yyyyMMddHHmmssffff}.jsonl and a fresh file is started. /// Thread-safe via dedicated . /// public async Task AppendHistoryAsync(WebhookDeliveryRecord record, CancellationToken ct = default) @@ -111,12 +125,19 @@ public async Task AppendHistoryAsync(WebhookDeliveryRecord record, CancellationT await _historyLock.WaitAsync(ct).ConfigureAwait(false); try { - await File.AppendAllTextAsync(_historyPath, json + "\n", ct).ConfigureAwait(false); + await File.AppendAllLinesAsync(_historyPath, [json], ct).ConfigureAwait(false); _historyCount++; if (_historyCount >= _historyMaxEntries) { - RotateHistory(); + try + { + RotateHistory(); + } + catch (IOException ex) + { + LogHistoryRotationFailed(_logger, ex); + } } } finally @@ -135,7 +156,7 @@ public async Task AppendDlqAsync(WebhookDeliveryRecord record, CancellationToken await _dlqLock.WaitAsync(ct).ConfigureAwait(false); try { - await File.AppendAllTextAsync(_dlqPath, json + "\n", ct).ConfigureAwait(false); + await File.AppendAllLinesAsync(_dlqPath, [json], ct).ConfigureAwait(false); } finally { @@ -176,7 +197,7 @@ public async Task> ReadDlqAsync(Cancellatio return all .GroupBy(r => r.Id, StringComparer.Ordinal) .Select(g => g.OrderByDescending(r => r.CreatedAt).First()) - .Where(r => !string.Equals(r.Status, "replayed", StringComparison.Ordinal)) + .Where(r => !string.Equals(r.Status, DeliveryStatuses.Replayed, StringComparison.Ordinal)) .ToList() .AsReadOnly(); } @@ -186,7 +207,7 @@ public async Task> ReadDlqAsync(Cancellatio /// /// Removes delivered and dlq records from outbox.jsonl, keeping only /// pending and failed records that still need to be retried. - /// Uses an atomic so the outbox is never in a partial state. + /// Uses an atomic so the outbox is never in a partial state. /// Thread-safe via dedicated . /// public async Task CompactOutboxAsync(CancellationToken ct = default) @@ -213,8 +234,8 @@ public async Task CompactOutboxAsync(CancellationToken ct = default) { var record = JsonSerializer.Deserialize(line, WebhookJsonContext.Default.WebhookDeliveryRecord); if (record is not null - && !string.Equals(record.Status, "delivered", StringComparison.Ordinal) - && !string.Equals(record.Status, "dlq", StringComparison.Ordinal)) + && !string.Equals(record.Status, DeliveryStatuses.Delivered, StringComparison.Ordinal) + && !string.Equals(record.Status, DeliveryStatuses.Dlq, StringComparison.Ordinal)) { kept.Add(line); } @@ -235,16 +256,72 @@ public async Task CompactOutboxAsync(CancellationToken ct = default) } } + // ── Pruning ─────────────────────────────────────────────────────────────── + + /// + /// Compacts the outbox and prunes DLQ entries older than . + /// Designed to be called periodically from . + /// + public async Task PruneAsync(int dlqRetentionDays, CancellationToken ct = default) + { + await CompactOutboxAsync(ct).ConfigureAwait(false); + await PruneDlqAsync(dlqRetentionDays, ct).ConfigureAwait(false); + } + + /// + /// Removes DLQ entries older than days. + /// Uses an atomic so the DLQ is never in a partial state. + /// Thread-safe via dedicated . + /// + private async Task PruneDlqAsync(int retentionDays, CancellationToken ct) + { + await _dlqLock.WaitAsync(ct).ConfigureAwait(false); + try + { + if (!File.Exists(_dlqPath)) + return; + + var cutoff = DateTimeOffset.UtcNow.AddDays(-retentionDays); + var lines = await File.ReadAllLinesAsync(_dlqPath, ct).ConfigureAwait(false); + var kept = new List(lines.Length); + + foreach (var line in lines) + { + if (string.IsNullOrWhiteSpace(line)) + continue; + + try + { + var record = JsonSerializer.Deserialize(line, WebhookJsonContext.Default.WebhookDeliveryRecord); + if (record is not null && record.CreatedAt >= cutoff) + kept.Add(line); + } + catch (JsonException) + { + // Skip malformed lines — matches read behavior + } + } + + var tempPath = _dlqPath + ".tmp"; + await File.WriteAllLinesAsync(tempPath, kept, ct).ConfigureAwait(false); + File.Move(tempPath, _dlqPath, overwrite: true); + } + finally + { + _dlqLock.Release(); + } + } + // ── Private helpers ─────────────────────────────────────────────────────── /// /// Rotates the history file to a timestamped archive file. /// Must be called while is held. - /// Uses atomic so readers never observe a partial file. + /// Uses atomic so readers never observe a partial file. /// private void RotateHistory() { - var timestamp = DateTimeOffset.UtcNow.ToString("yyyyMMddHHmmss"); + var timestamp = DateTimeOffset.UtcNow.ToString("yyyyMMddHHmmssffff"); var archivePath = Path.Combine(_dir, $"history.{timestamp}.jsonl"); File.Move(_historyPath, archivePath, overwrite: false); _historyCount = 0; @@ -309,4 +386,10 @@ private static int CountLines(string filePath) return count; } + + // ── LoggerMessage methods ──────────────────────────────────────────────── + + [LoggerMessage(EventId = 1, Level = LogLevel.Warning, + Message = "History rotation failed; will retry on next write")] + private static partial void LogHistoryRotationFailed(ILogger? logger, Exception exception); } diff --git a/src/clawsharp/Webhooks/WebhookDashboardDtos.cs b/src/clawsharp/Webhooks/WebhookDashboardDtos.cs index fffdbd81..32a1d970 100644 --- a/src/clawsharp/Webhooks/WebhookDashboardDtos.cs +++ b/src/clawsharp/Webhooks/WebhookDashboardDtos.cs @@ -135,6 +135,19 @@ public sealed record ReplayResponse public required string Message { get; init; } } +/// Response for bulk DLQ replay. +public sealed record BulkReplayResponse +{ + [JsonPropertyName("replayed")] + public required int Replayed { get; init; } + + [JsonPropertyName("endpoint")] + public required string Endpoint { get; init; } + + [JsonPropertyName("message")] + public required string Message { get; init; } +} + /// /// A single delivery event broadcast to SSE clients via WebhookMetrics. /// Represents the outcome of a single delivery attempt. diff --git a/src/clawsharp/Webhooks/WebhookDeliveryWorker.cs b/src/clawsharp/Webhooks/WebhookDeliveryWorker.cs index 4d8c1836..b19c6f83 100644 --- a/src/clawsharp/Webhooks/WebhookDeliveryWorker.cs +++ b/src/clawsharp/Webhooks/WebhookDeliveryWorker.cs @@ -1,10 +1,9 @@ -using System.Collections.Frozen; using System.Diagnostics; using System.Net; using System.Net.Http; -using System.Text; using System.Threading.Channels; using Clawsharp.Config.Features; +using Clawsharp.Core.Utilities; using Clawsharp.Organization; using Clawsharp.Telemetry; using Clawsharp.Webhooks.Formatters; @@ -36,6 +35,10 @@ public sealed partial class WebhookDeliveryWorker : BackgroundService private const int ChannelDeliveryMaxAttempts = 3; // D-14 private const int CircuitBreakerPauseSeconds = 30; // D-09 break duration + /// Cached empty JSON object element — avoids allocating a per call. + private static readonly System.Text.Json.JsonElement EmptyJsonObject = + System.Text.Json.JsonDocument.Parse("{}").RootElement.Clone(); + private readonly WebhookConfig _webhookConfig; private readonly DeliveryStorage _storage; private readonly ChannelNotifier _channelNotifier; @@ -51,9 +54,6 @@ public sealed partial class WebhookDeliveryWorker : BackgroundService // Parsed channel:// targets for channel-routed endpoints. private readonly Dictionary _channelTargets; - // Formatter lookup (format name → IWebhookFormatter). - private readonly FrozenDictionary _formatters; - public WebhookDeliveryWorker( WebhookConfig webhookConfig, DeliveryStorage storage, @@ -73,14 +73,6 @@ public WebhookDeliveryWorker( _logger = logger; _webhookMetrics = webhookMetrics; - _formatters = new Dictionary(StringComparer.OrdinalIgnoreCase) - { - ["json"] = new JsonWebhookFormatter(), - ["slack"] = new SlackWebhookFormatter(), - ["discord"] = new DiscordWebhookFormatter(), - ["teams"] = new TeamsWebhookFormatter(), - }.ToFrozenDictionary(StringComparer.OrdinalIgnoreCase); - _pipelines = new Dictionary>(StringComparer.Ordinal); _channelTargets = new Dictionary(StringComparer.Ordinal); @@ -126,6 +118,9 @@ protected override async Task ExecuteAsync(CancellationToken stoppingToken) tasks.Add(ConsumeHttpEndpointAsync(endpointId, reader, pipeline, stoppingToken)); } + // Periodic pruning: compact outbox and enforce DLQ retention (every 6 hours). + tasks.Add(RunPruneLoopAsync(stoppingToken)); + await Task.WhenAll(tasks).ConfigureAwait(false); } @@ -170,7 +165,7 @@ private async Task RecoverOutboxAsync(CancellationToken ct) var pendingRecords = allRecords .GroupBy(r => r.Id, StringComparer.Ordinal) .Select(g => g.OrderByDescending(r => r.CreatedAt).First()) - .Where(r => string.Equals(r.Status, "pending", StringComparison.Ordinal)) + .Where(r => string.Equals(r.Status, DeliveryStatuses.Pending, StringComparison.Ordinal)) .OrderBy(r => r.CreatedAt) .ToList(); @@ -179,17 +174,40 @@ private async Task RecoverOutboxAsync(CancellationToken ct) foreach (var record in pendingRecords) { if (_queueRegistry.GetReader(record.EndpointId) is not null - && _webhookConfig.Endpoints!.TryGetValue(record.EndpointId, out var endpointConfig)) + && _webhookConfig.Endpoints?.TryGetValue(record.EndpointId, out var endpointConfig) == true) { - var body = record.Payload ?? "{}"; - var formatter = ResolveFormatter(endpointConfig.Format); - var job = new WebhookJob(record, endpointConfig, record.EndpointId, body); + var body = record.Payload ?? "{}"; + var formatter = WebhookFormatterRegistry.ResolveFormatter(endpointConfig.Format); + + // Apply the formatter so the body matches the platform-specific format + // (Slack Block Kit, Discord embed, Teams card) and HMAC signs the correct content. + string formattedBody; + try + { + if (!string.IsNullOrEmpty(record.Payload)) + { + var payload = System.Text.Json.JsonSerializer.Deserialize( + record.Payload, WebhookJsonContext.Default.WebhookPayload); + formattedBody = payload is not null ? formatter.Format(payload) : body; + } + else + { + formattedBody = body; + } + } + catch (System.Text.Json.JsonException ex) + { + LogRecoveryFormatterFailed(_logger, record.Id, ex.Message); + formattedBody = body; + } + + var job = new WebhookJob(record, endpointConfig, record.EndpointId, formattedBody); await _queueRegistry.WriteAsync(record.EndpointId, job, ct).ConfigureAwait(false); } else { // Endpoint was removed from config — move to DLQ (D-04). - record.Status = "dlq"; + record.Status = DeliveryStatuses.Dlq; record.LastError = "Endpoint removed from config during recovery"; record.FailedAt = DateTimeOffset.UtcNow; await _storage.AppendDlqAsync(record, ct).ConfigureAwait(false); @@ -197,6 +215,25 @@ private async Task RecoverOutboxAsync(CancellationToken ct) } } + // ── Periodic Pruning ──────────────────────────────────────────────────── + + private async Task RunPruneLoopAsync(CancellationToken ct) + { + using var timer = new PeriodicTimer(TimeSpan.FromHours(6)); + while (await timer.WaitForNextTickAsync(ct).ConfigureAwait(false)) + { + try + { + await _storage.PruneAsync(_webhookConfig.DlqRetentionDays, ct).ConfigureAwait(false); + LogPruneCompleted(_logger); + } + catch (Exception ex) when (ex is not OperationCanceledException) + { + LogPruneFailed(_logger, ex); + } + } + } + // ── HTTP Endpoint Consumer ───────────────────────────────────────────────── private async Task ConsumeHttpEndpointAsync( @@ -240,7 +277,7 @@ await HandleOutcomeAsync(job, outcome, (int)response.StatusCode, null, ct) deliverSpan?.SetTag(WebhookAttributes.FinalOutcome, finalOutcome); deliverSpan?.SetTag(WebhookAttributes.TotalAttempts, job.Record.AttemptCount); - if (string.Equals(finalOutcome, "dlq", StringComparison.Ordinal)) + if (string.Equals(finalOutcome, DeliveryStatuses.Dlq, StringComparison.Ordinal)) deliverSpan?.SetStatus(ActivityStatusCode.Error, "Delivery failed — moved to DLQ"); } catch (BrokenCircuitException) @@ -353,25 +390,27 @@ private async Task ConsumeChannelEndpointAsync( await Task.Delay(jitter, ct).ConfigureAwait(false); } + job.Record.AttemptCount = attemptCount; + if (result == ChannelDeliveryResult.Success) { - job.Record.Status = "delivered"; + job.Record.Status = DeliveryStatuses.Delivered; job.Record.DeliveredAt = DateTimeOffset.UtcNow; await _storage.AppendHistoryAsync(job.Record, ct).ConfigureAwait(false); LogDeliverySuccess(_logger, endpointId); - deliverSpan?.SetTag(WebhookAttributes.FinalOutcome, "delivered"); + deliverSpan?.SetTag(WebhookAttributes.FinalOutcome, DeliveryStatuses.Delivered); deliverSpan?.SetTag(WebhookAttributes.TotalAttempts, attemptCount); } else { - job.Record.Status = "dlq"; + job.Record.Status = DeliveryStatuses.Dlq; job.Record.LastError = $"Channel delivery failed: {result}"; job.Record.FailedAt = DateTimeOffset.UtcNow; await _storage.AppendDlqAsync(job.Record, ct).ConfigureAwait(false); LogDeliveryDlq(_logger, endpointId, job.Record.LastError); - deliverSpan?.SetTag(WebhookAttributes.FinalOutcome, "dlq"); + deliverSpan?.SetTag(WebhookAttributes.FinalOutcome, DeliveryStatuses.Dlq); deliverSpan?.SetTag(WebhookAttributes.TotalAttempts, attemptCount); deliverSpan?.SetStatus(ActivityStatusCode.Error, $"Channel delivery failed: {result}"); } @@ -388,7 +427,7 @@ private HttpRequestMessage BuildHttpRequest(WebhookJob job) var url = job.TargetUrl ?? job.EndpointConfig.Url; var request = new HttpRequestMessage(HttpMethod.Post, url) { - Content = new StringContent(job.FormattedBody, Encoding.UTF8, "application/json"), + Content = Utf8JsonContent.FromString(job.FormattedBody), }; // Idempotency header (WH-08). @@ -425,7 +464,7 @@ private async Task HandleOutcomeAsync( switch (outcome) { case DeliveryOutcome.Success: - job.Record.Status = "delivered"; + job.Record.Status = DeliveryStatuses.Delivered; job.Record.DeliveredAt = DateTimeOffset.UtcNow; await _storage.AppendHistoryAsync(job.Record, ct).ConfigureAwait(false); LogDeliverySuccess(_logger, job.EndpointId); @@ -434,7 +473,7 @@ private async Task HandleOutcomeAsync( Id = job.Record.Id, Endpoint = job.EndpointId, Type = job.Record.EventType, - Outcome = "delivery.success", + Outcome = DeliveryOutcomes.Success, Attempt = job.Record.AttemptCount, Status = statusCode, Error = error, @@ -445,7 +484,7 @@ private async Task HandleOutcomeAsync( case DeliveryOutcome.RateLimited: // only reaches here if 429 with Retry-After > 60s case DeliveryOutcome.TransientFailure: // After Polly exhausted all retries, move to DLQ. - job.Record.Status = "dlq"; + job.Record.Status = DeliveryStatuses.Dlq; job.Record.FailedAt = DateTimeOffset.UtcNow; await _storage.AppendDlqAsync(job.Record, ct).ConfigureAwait(false); LogDeliveryDlq(_logger, job.EndpointId, error ?? $"status={statusCode}"); @@ -454,7 +493,7 @@ private async Task HandleOutcomeAsync( Id = job.Record.Id, Endpoint = job.EndpointId, Type = job.Record.EventType, - Outcome = "delivery.dlq", + Outcome = DeliveryOutcomes.Dlq, Attempt = job.Record.AttemptCount, Status = statusCode, Error = error, @@ -535,7 +574,6 @@ private ResiliencePipeline BuildHttpPipeline( .Handle(), OnOpened = args => { - _ = NotifyCircuitOpenedAsync(endpointId, args.BreakDuration); LogCircuitOpened(_logger, endpointId, args.BreakDuration); _webhookMetrics?.RecordCircuitChanged(endpointId, "open"); return default; @@ -550,26 +588,6 @@ private ResiliencePipeline BuildHttpPipeline( .Build(); } - // ── Circuit Breaker Notification ─────────────────────────────────────────── - - private async Task NotifyCircuitOpenedAsync(string endpointId, TimeSpan breakDuration) - { - // AdminNotifier only exposes approval-specific methods; log at Warning level per plan note. - // Future: extend AdminNotifier with a general-purpose notification method. - try - { - await Task.CompletedTask.ConfigureAwait(false); // async context required for fire-and-catch - _logger.LogWarning( - "Circuit breaker opened for endpoint '{EndpointId}', break duration: {BreakDuration}. " + - "Admin notification via AdminNotifier is not available for circuit breaker events.", - endpointId, breakDuration); - } - catch - { - // Fire-and-catch — circuit notifications must never propagate. - } - } - // ── Helpers ──────────────────────────────────────────────────────────────── /// @@ -599,11 +617,6 @@ private async Task NotifyCircuitOpenedAsync(string endpointId, TimeSpan breakDur } } - private IWebhookFormatter ResolveFormatter(string? format) => - _formatters.TryGetValue(format ?? "json", out var formatter) - ? formatter - : _formatters["json"]; - /// /// Reconstructs a from the job's stored payload JSON. /// Used by channel delivery consumers which need the typed payload. @@ -620,7 +633,7 @@ private static WebhookPayload BuildPayloadFromJob(WebhookJob job) Category = ExtractCategory(job.Record.EventType), Timestamp = job.Record.CreatedAt, Source = new WebhookSource { Instance = job.Record.EndpointUrl }, - Data = System.Text.Json.JsonDocument.Parse("{}").RootElement, + Data = EmptyJsonObject, }; } @@ -669,4 +682,16 @@ private static partial void LogCircuitOpened( [LoggerMessage(EventId = 9, Level = LogLevel.Warning, Message = "Webhook moved to DLQ for endpoint '{EndpointId}': {Error}")] private static partial void LogDeliveryDlq(ILogger logger, string endpointId, string error); + + [LoggerMessage(EventId = 10, Level = LogLevel.Warning, + Message = "Recovery formatter failed for record '{RecordId}', delivering raw JSON: {Error}")] + private static partial void LogRecoveryFormatterFailed(ILogger logger, string recordId, string error); + + [LoggerMessage(EventId = 11, Level = LogLevel.Information, + Message = "Periodic prune completed — outbox compacted, DLQ retention enforced")] + private static partial void LogPruneCompleted(ILogger logger); + + [LoggerMessage(EventId = 12, Level = LogLevel.Warning, + Message = "Periodic prune failed")] + private static partial void LogPruneFailed(ILogger logger, Exception exception); } diff --git a/src/clawsharp/Webhooks/WebhookDispatchService.cs b/src/clawsharp/Webhooks/WebhookDispatchService.cs index 6f956b93..6b78dea8 100644 --- a/src/clawsharp/Webhooks/WebhookDispatchService.cs +++ b/src/clawsharp/Webhooks/WebhookDispatchService.cs @@ -6,7 +6,6 @@ using Clawsharp.Core.Events; using Clawsharp.Telemetry; using Clawsharp.Tools; -using Clawsharp.Webhooks.Formatters; using Microsoft.Extensions.Hosting; using Microsoft.Extensions.Logging; @@ -44,9 +43,6 @@ public sealed partial class WebhookDispatchService : IHostedService /// Event type → EventTypeAttribute, for O(1) lookup in the hot-path handler. private readonly FrozenDictionary _wireNameLookup; - /// Formatter registry for building formatted webhook bodies. - private readonly FrozenDictionary _formatters; - public WebhookDispatchService( IEventBus eventBus, WebhookConfig webhookConfig, @@ -60,14 +56,6 @@ public WebhookDispatchService( _storage = storage; _logger = logger; - _formatters = new Dictionary(StringComparer.OrdinalIgnoreCase) - { - ["json"] = new JsonWebhookFormatter(), - ["slack"] = new SlackWebhookFormatter(), - ["discord"] = new DiscordWebhookFormatter(), - ["teams"] = new TeamsWebhookFormatter(), - }.ToFrozenDictionary(StringComparer.OrdinalIgnoreCase); - // Build wire name → matching endpoint ID set from config + SystemEventRegistry. // Per D-05: check Categories (null = all) and Filter glob (null = all). _dispatchMap = BuildDispatchMap(webhookConfig); @@ -79,6 +67,8 @@ public WebhookDispatchService( /// public Task StartAsync(CancellationToken cancellationToken) { + ValidateJsonContextCoverage(); + foreach (var (eventType, attr) in SystemEventRegistry.All) { var capturedAttr = attr; @@ -91,6 +81,21 @@ public Task StartAsync(CancellationToken cancellationToken) return Task.CompletedTask; } + private void ValidateJsonContextCoverage() + { + foreach (var (eventType, _) in SystemEventRegistry.All) + { + try + { + WebhookJsonContext.Default.GetTypeInfo(eventType); + } + catch (InvalidOperationException) + { + LogMissingJsonContext(eventType.FullName ?? eventType.Name); + } + } + } + /// public Task StopAsync(CancellationToken cancellationToken) { @@ -159,7 +164,7 @@ private void OnEventPublished(object evt, Type eventType, EventTypeAttribute att if (!_webhookConfig.Endpoints.TryGetValue(endpointId, out var endpointConfig)) continue; - var formatter = ResolveFormatter(endpointConfig.Format); + var formatter = WebhookFormatterRegistry.ResolveFormatter(endpointConfig.Format); string formattedBody; try { @@ -177,7 +182,7 @@ private void OnEventPublished(object evt, Type eventType, EventTypeAttribute att EndpointId = endpointId, EndpointUrl = endpointConfig.Url, EventType = payload.Type, - Status = "pending", + Status = DeliveryStatuses.Pending, Payload = payloadJson, CreatedAt = payload.Timestamp, }; @@ -272,11 +277,6 @@ private static bool FilterMatches(string? filter, string wireName) return FileSystemName.MatchesSimpleExpression(filter, wireName, ignoreCase: true); } - private IWebhookFormatter ResolveFormatter(string? format) => - _formatters.TryGetValue(format ?? "json", out var formatter) - ? formatter - : _formatters["json"]; - // ── Logging ─────────────────────────────────────────────────────────────── [LoggerMessage(EventId = 1, Level = LogLevel.Information, @@ -302,4 +302,8 @@ private IWebhookFormatter ResolveFormatter(string? format) => [LoggerMessage(EventId = 6, Level = LogLevel.Warning, Message = "Queue full for endpoint '{EndpointId}' — event '{EventId}' enqueue failed (record persisted in outbox)")] private partial void LogQueueFull(string endpointId, string eventId); + + [LoggerMessage(EventId = 7, Level = LogLevel.Error, + Message = "ISystemEvent type '{EventTypeName}' is not registered in WebhookJsonContext — webhook serialization will fail at runtime")] + private partial void LogMissingJsonContext(string eventTypeName); } diff --git a/src/clawsharp/Webhooks/WebhookFormatterRegistry.cs b/src/clawsharp/Webhooks/WebhookFormatterRegistry.cs new file mode 100644 index 00000000..f96af325 --- /dev/null +++ b/src/clawsharp/Webhooks/WebhookFormatterRegistry.cs @@ -0,0 +1,34 @@ +using System.Collections.Frozen; +using Clawsharp.Webhooks.Formatters; + +namespace Clawsharp.Webhooks; + +/// +/// Shared formatter lookup used by both and +/// . Owns a +/// of format name to and a helper. +/// +internal static class WebhookFormatterRegistry +{ + /// + /// Immutable mapping of format name (case-insensitive) to formatter implementation. + /// Default format is "json". + /// + public static readonly FrozenDictionary Formatters = + new Dictionary(StringComparer.OrdinalIgnoreCase) + { + ["json"] = new JsonWebhookFormatter(), + ["slack"] = new SlackWebhookFormatter(), + ["discord"] = new DiscordWebhookFormatter(), + ["teams"] = new TeamsWebhookFormatter(), + }.ToFrozenDictionary(StringComparer.OrdinalIgnoreCase); + + /// + /// Resolves a formatter by name. Falls back to the "json" formatter + /// when the name is null or unrecognised. + /// + public static IWebhookFormatter ResolveFormatter(string? format) => + Formatters.TryGetValue(format ?? "json", out var formatter) + ? formatter + : Formatters["json"]; +} diff --git a/src/clawsharp/Webhooks/WebhookJsonContext.cs b/src/clawsharp/Webhooks/WebhookJsonContext.cs index 35500e52..51236d0e 100644 --- a/src/clawsharp/Webhooks/WebhookJsonContext.cs +++ b/src/clawsharp/Webhooks/WebhookJsonContext.cs @@ -26,8 +26,11 @@ namespace Clawsharp.Webhooks; [JsonSerializable(typeof(DlqListResponse))] [JsonSerializable(typeof(DlqEntryResponse))] [JsonSerializable(typeof(ReplayResponse))] +[JsonSerializable(typeof(BulkReplayResponse))] [JsonSerializable(typeof(DeliveryEvent))] [JsonSerializable(typeof(EndpointSnapshot))] [JsonSerializable(typeof(Dictionary))] -[JsonSourceGenerationOptions(DefaultIgnoreCondition = JsonIgnoreCondition.WhenWritingNull)] +[JsonSourceGenerationOptions( + DefaultIgnoreCondition = JsonIgnoreCondition.WhenWritingNull, + PropertyNamingPolicy = JsonKnownNamingPolicy.CamelCase)] internal partial class WebhookJsonContext : JsonSerializerContext; diff --git a/src/clawsharp/Webhooks/WebhookMessageBuilder.cs b/src/clawsharp/Webhooks/WebhookMessageBuilder.cs index 5e820181..febba970 100644 --- a/src/clawsharp/Webhooks/WebhookMessageBuilder.cs +++ b/src/clawsharp/Webhooks/WebhookMessageBuilder.cs @@ -2,6 +2,7 @@ using System.Text; using Clawsharp.Core; using Clawsharp.Core.Utilities; +using Clawsharp.Webhooks.Formatters; namespace Clawsharp.Webhooks; @@ -49,7 +50,7 @@ public static OutboundMessage ToChannelMessage( sb.AppendLine($"User: {payload.Source.User ?? "system"}"); sb.AppendLine($"Time: {payload.Timestamp:O}"); - var dataSummary = BuildDataSummary(payload.Data); + var dataSummary = WebhookFormatterHelper.BuildDataSummary(payload.Data); if (dataSummary.Length > 0) { sb.AppendLine(); @@ -58,44 +59,4 @@ public static OutboundMessage ToChannelMessage( return new OutboundMessage(channel, recipientId, sb.ToString().TrimEnd()); } - - private const int MaxDataFields = 10; - - private static string BuildDataSummary(System.Text.Json.JsonElement data) - { - if (data.ValueKind != System.Text.Json.JsonValueKind.Object) - { - return string.Empty; - } - - var sb = new StringBuilder(); - var count = 0; - var totalCount = 0; - - foreach (var prop in data.EnumerateObject()) - { - totalCount++; - if (count < MaxDataFields) - { - var valueStr = prop.Value.ValueKind switch - { - System.Text.Json.JsonValueKind.String => prop.Value.GetString() ?? string.Empty, - System.Text.Json.JsonValueKind.Number => prop.Value.GetRawText(), - System.Text.Json.JsonValueKind.True => "true", - System.Text.Json.JsonValueKind.False => "false", - System.Text.Json.JsonValueKind.Null => "(null)", - _ => prop.Value.GetRawText() - }; - sb.AppendLine($"{prop.Name}: {valueStr}"); - count++; - } - } - - if (totalCount > MaxDataFields) - { - sb.AppendLine("..."); - } - - return sb.ToString().TrimEnd(); - } } diff --git a/src/clawsharp/Webhooks/WebhookMetrics.cs b/src/clawsharp/Webhooks/WebhookMetrics.cs index 3281ceea..6387499c 100644 --- a/src/clawsharp/Webhooks/WebhookMetrics.cs +++ b/src/clawsharp/Webhooks/WebhookMetrics.cs @@ -26,7 +26,7 @@ private sealed class EndpointMetrics } private readonly ConcurrentDictionary _endpoints = new(StringComparer.Ordinal); - private readonly ConcurrentDictionary Writer, string? TypeFilter, string? EndpointFilter)> _sseClients = new(StringComparer.Ordinal); + private readonly ConcurrentDictionary Writer, string? OutcomeFilter, string? EndpointFilter)> _sseClients = new(StringComparer.Ordinal); private readonly DateTimeOffset _startedAt = DateTimeOffset.UtcNow; // ── OTel instruments ────────────────────────────────────────────────────── @@ -101,13 +101,13 @@ public void RecordDelivery(string endpointId, DeliveryEvent evt) { switch (evt.Outcome) { - case "delivery.success": + case DeliveryOutcomes.Success: Interlocked.Increment(ref metrics.Delivered); break; - case "delivery.failed": + case DeliveryOutcomes.Failed: Interlocked.Increment(ref metrics.Failed); break; - case "delivery.dlq": + case DeliveryOutcomes.Dlq: Interlocked.Increment(ref metrics.Dlq); break; } @@ -122,20 +122,20 @@ public void RecordDelivery(string endpointId, DeliveryEvent evt) switch (evt.Outcome) { - case "delivery.success": + case DeliveryOutcomes.Success: _deliveredCounter.Add(1, tags); break; - case "delivery.failed": - case "delivery.dlq": + case DeliveryOutcomes.Failed: + case DeliveryOutcomes.Dlq: _failedCounter.Add(1, tags); break; } // Broadcast to SSE clients. var deadClients = new List(); - foreach (var (clientId, (writer, typeFilter, endpointFilter)) in _sseClients) + foreach (var (clientId, (writer, outcomeFilter, endpointFilter)) in _sseClients) { - if (typeFilter is not null && !string.Equals(evt.Outcome, typeFilter, StringComparison.Ordinal)) + if (outcomeFilter is not null && !string.Equals(evt.Outcome, outcomeFilter, StringComparison.Ordinal)) continue; if (endpointFilter is not null && !string.Equals(evt.Endpoint, endpointFilter, StringComparison.Ordinal)) continue; @@ -245,13 +245,21 @@ public string GetUptime() // ── SSE ─────────────────────────────────────────────────────────────────── + private const int MaxSseClients = 50; + /// - /// Registers a new SSE client with optional type and endpoint filters. - /// Returns a registration (IDisposable, removes client on Dispose) and the reader. + /// Registers a new SSE client with optional outcome and endpoint filters. + /// Returns a registration (IDisposable, removes client on Dispose) and the reader, + /// or null if the maximum number of concurrent SSE clients has been reached. /// - public (IDisposable Registration, ChannelReader Reader) RegisterSseClient( - string? typeFilter, string? endpointFilter) + public (IDisposable Registration, ChannelReader Reader)? RegisterSseClient( + string? outcomeFilter, string? endpointFilter) { + if (_sseClients.Count >= MaxSseClients) + { + return null; + } + var channel = Channel.CreateBounded(new BoundedChannelOptions(100) { FullMode = BoundedChannelFullMode.DropOldest, @@ -260,7 +268,7 @@ public string GetUptime() }); var clientId = Guid.NewGuid().ToString("N"); - _sseClients[clientId] = (channel.Writer, typeFilter, endpointFilter); + _sseClients[clientId] = (channel.Writer, outcomeFilter, endpointFilter); var registration = new SseClientRegistration(this, clientId, channel.Writer); return (registration, channel.Reader); diff --git a/src/clawsharp/Webhooks/WebhookQueueRegistry.cs b/src/clawsharp/Webhooks/WebhookQueueRegistry.cs index 72fcf24f..ed4f789f 100644 --- a/src/clawsharp/Webhooks/WebhookQueueRegistry.cs +++ b/src/clawsharp/Webhooks/WebhookQueueRegistry.cs @@ -63,26 +63,21 @@ public WebhookQueueRegistry(WebhookConfig webhookConfig) /// Creates a dynamic queue for runtime-registered endpoints (e.g., push notification targets). /// Returns true if the queue was created; false if it already exists in either /// config-defined or dynamic queues. - /// Thread-safe via . + /// Thread-safe via . /// public bool TryCreateQueue(string endpointId) { if (_queues.ContainsKey(endpointId)) return false; - var created = false; - _dynamicQueues.GetOrAdd(endpointId, _ => - { - created = true; - return Channel.CreateBounded( - new BoundedChannelOptions(QueueCapacity) - { - FullMode = BoundedChannelFullMode.DropOldest, // Push: drop old if slow consumer - SingleReader = true, - SingleWriter = false, - }); - }); - return created; + var channel = Channel.CreateBounded( + new BoundedChannelOptions(QueueCapacity) + { + FullMode = BoundedChannelFullMode.DropOldest, // Push: drop old if slow consumer + SingleReader = true, + SingleWriter = false, + }); + return _dynamicQueues.TryAdd(endpointId, channel); } /// Removes a dynamic queue and completes its writer. No-op for config-defined queues. diff --git a/src/clawsharp/Webhooks/WebhookRouteRegistrar.cs b/src/clawsharp/Webhooks/WebhookRouteRegistrar.cs index 764cba90..3cd51408 100644 --- a/src/clawsharp/Webhooks/WebhookRouteRegistrar.cs +++ b/src/clawsharp/Webhooks/WebhookRouteRegistrar.cs @@ -185,9 +185,13 @@ private async Task HandleBulkReplayAsync(string? endpoint, Cancellation if (result.StatusCode == 400) return Results.BadRequest(new { error = "Query parameter 'endpoint' is required for bulk replay" }); - return Results.Json( - new { replayed = result.Replayed, endpoint, message = $"Replayed {result.Replayed} entries for endpoint '{endpoint}'" }, - statusCode: 202); + var response = new BulkReplayResponse + { + Replayed = result.Replayed, + Endpoint = endpoint!, + Message = $"Replayed {result.Replayed} entries for endpoint '{endpoint}'" + }; + return Results.Json(response, WebhookJsonContext.Default.BulkReplayResponse, statusCode: 202); } /// @@ -217,12 +221,18 @@ internal async Task HandleBulkReplayCoreAsync(string? e /// /// Server-Sent Events stream of live webhook delivery outcomes. - /// Supports optional filtering by event type and endpoint ID. + /// Supports optional filtering by delivery outcome and endpoint ID. /// Per D-20: uses with per-client Channel fanout. /// - internal IResult HandleStreamAsync(string? type, string? endpoint, CancellationToken ct) + internal IResult HandleStreamAsync(string? outcome, string? endpoint, CancellationToken ct) { - var (registration, reader) = webhookMetrics.RegisterSseClient(type, endpoint); + var result = webhookMetrics.RegisterSseClient(outcome, endpoint); + if (result is null) + { + return TypedResults.StatusCode(503); + } + + var (registration, reader) = result.Value; async IAsyncEnumerable> Stream( [EnumeratorCancellation] CancellationToken cancellationToken) @@ -250,14 +260,14 @@ async IAsyncEnumerable> Stream( /// private async Task ReplayEntryAsync(WebhookDeliveryRecord entry, CancellationToken ct) { - // Append a "replayed" marker so the entry is excluded from future DLQ reads + // Append a replayed marker so the entry is excluded from future DLQ reads var replayedRecord = new WebhookDeliveryRecord { Id = entry.Id, EndpointId = entry.EndpointId, EndpointUrl = entry.EndpointUrl, EventType = entry.EventType, - Status = "replayed", + Status = DeliveryStatuses.Replayed, CreatedAt = entry.CreatedAt, Payload = entry.Payload, ReplayedAt = DateTimeOffset.UtcNow, @@ -274,10 +284,14 @@ private async Task ReplayEntryAsync(WebhookDeliveryRecord entry, CancellationTok EndpointId = entry.EndpointId, EndpointUrl = epConfig.Url, EventType = entry.EventType, - Status = "pending", + Status = DeliveryStatuses.Pending, CreatedAt = DateTimeOffset.UtcNow, Payload = entry.Payload, }; + + // Outbox-first: persist before enqueue so the job survives a crash. + await storage.AppendOutboxAsync(newRecord, ct).ConfigureAwait(false); + var job = new WebhookJob(newRecord, epConfig, entry.EndpointId, entry.Payload); queueRegistry.TryWrite(entry.EndpointId, job); } diff --git a/src/clawsharp/Webhooks/WebhookSlashCommandHandler.cs b/src/clawsharp/Webhooks/WebhookSlashCommandHandler.cs index b71dc663..a1cdc6f4 100644 --- a/src/clawsharp/Webhooks/WebhookSlashCommandHandler.cs +++ b/src/clawsharp/Webhooks/WebhookSlashCommandHandler.cs @@ -29,11 +29,14 @@ public WebhookSlashCommandHandler( _storage = storage; } - // ── Static disabled-state helpers (called when handler is null) ─────────── + // ── Static disabled-state helpers (used by tests) ───────────────────────── + // These pass null for session, which bypasses the admin check (null = single-operator mode). + // Production code uses AgentLoop.HandleWebhookStatusAsync/HandleWebhookDlqAsync which + // pass the real session. These exist for test convenience only. /// - /// Returns the response when the webhook system is not enabled. - /// Called by AgentLoop when _webhookSlashCommandHandler is null. + /// Returns the disabled message when is null, + /// or delegates to with a null session (admin bypass). /// public static Task HandleStatusAsync( WebhookSlashCommandHandler? handler, CancellationToken ct) @@ -44,8 +47,8 @@ public static Task HandleStatusAsync( } /// - /// Returns the response when the webhook system is not enabled. - /// Called by AgentLoop when _webhookSlashCommandHandler is null. + /// Returns the disabled message when is null, + /// or delegates to with a null session (admin bypass). /// public static Task HandleDlqAsync( WebhookSlashCommandHandler? handler, string? argument, CancellationToken ct) @@ -206,7 +209,7 @@ private async Task SingleReplayAsync(string id, CancellationToken ct) EndpointId = record.EndpointId, EndpointUrl = record.EndpointUrl, EventType = record.EventType, - Status = "replayed", + Status = DeliveryStatuses.Replayed, Payload = record.Payload, CreatedAt = record.CreatedAt, AttemptCount = record.AttemptCount, @@ -222,8 +225,22 @@ private async Task SingleReplayAsync(string id, CancellationToken ct) // For slash command replay, we create a minimal job using the stored payload if (record.Payload is { Length: > 0 }) { + var newRecord = new WebhookDeliveryRecord + { + Id = record.Id, + EndpointId = record.EndpointId, + EndpointUrl = record.EndpointUrl, + EventType = record.EventType, + Status = DeliveryStatuses.Pending, + CreatedAt = DateTimeOffset.UtcNow, + Payload = record.Payload, + }; + + // Outbox-first: persist before enqueue so the job survives a crash. + await _storage.AppendOutboxAsync(newRecord, ct).ConfigureAwait(false); + var job = new WebhookJob( - Record: replayed, + Record: newRecord, EndpointConfig: new Config.Features.WebhookEndpointConfig { Url = record.EndpointUrl }, EndpointId: record.EndpointId, FormattedBody: record.Payload); @@ -255,7 +272,7 @@ private async Task BulkReplayAsync(string endpoint, CancellationToken ct EndpointId = record.EndpointId, EndpointUrl = record.EndpointUrl, EventType = record.EventType, - Status = "replayed", + Status = DeliveryStatuses.Replayed, Payload = record.Payload, CreatedAt = record.CreatedAt, AttemptCount = record.AttemptCount, @@ -269,8 +286,22 @@ private async Task BulkReplayAsync(string endpoint, CancellationToken ct if (record.Payload is { Length: > 0 }) { + var newRecord = new WebhookDeliveryRecord + { + Id = record.Id, + EndpointId = record.EndpointId, + EndpointUrl = record.EndpointUrl, + EventType = record.EventType, + Status = DeliveryStatuses.Pending, + CreatedAt = DateTimeOffset.UtcNow, + Payload = record.Payload, + }; + + // Outbox-first: persist before enqueue so the job survives a crash. + await _storage.AppendOutboxAsync(newRecord, ct).ConfigureAwait(false); + var job = new WebhookJob( - Record: replayed, + Record: newRecord, EndpointConfig: new Config.Features.WebhookEndpointConfig { Url = record.EndpointUrl }, EndpointId: record.EndpointId, FormattedBody: record.Payload); diff --git a/src/clawsharp/clawsharp.csproj b/src/clawsharp/clawsharp.csproj index b0eeda40..c6186f93 100644 --- a/src/clawsharp/clawsharp.csproj +++ b/src/clawsharp/clawsharp.csproj @@ -61,6 +61,7 @@ + diff --git a/tests/clawsharp.Tests/Analytics/EfInteractionStoreTests.cs b/tests/clawsharp.Tests/Analytics/EfInteractionStoreTests.cs index b49e24b7..c9b60cfa 100644 --- a/tests/clawsharp.Tests/Analytics/EfInteractionStoreTests.cs +++ b/tests/clawsharp.Tests/Analytics/EfInteractionStoreTests.cs @@ -8,6 +8,7 @@ namespace Clawsharp.Tests.Analytics; /// /// Tests the EF Core-backed interaction store using temp SQLite files. /// +[TestFixture] public sealed class EfInteractionStoreTests { private static (EfInteractionStore Store, string DbPath) CreateStore() diff --git a/tests/clawsharp.Tests/Analytics/InteractionAnalyticsIntegrationTests.cs b/tests/clawsharp.Tests/Analytics/InteractionAnalyticsIntegrationTests.cs index 4e0b8eec..a38e7c5a 100644 --- a/tests/clawsharp.Tests/Analytics/InteractionAnalyticsIntegrationTests.cs +++ b/tests/clawsharp.Tests/Analytics/InteractionAnalyticsIntegrationTests.cs @@ -1,6 +1,7 @@ using Clawsharp.Analytics; using Clawsharp.Analytics.Sqlite; using Clawsharp.Config.Features; +using Clawsharp.Tests.Unit.Pipeline; using Microsoft.EntityFrameworkCore; using Microsoft.Extensions.Logging.Abstractions; diff --git a/tests/clawsharp.Tests/Analytics/InteractionStorageTests.cs b/tests/clawsharp.Tests/Analytics/InteractionStorageTests.cs index 8059b405..f1795974 100644 --- a/tests/clawsharp.Tests/Analytics/InteractionStorageTests.cs +++ b/tests/clawsharp.Tests/Analytics/InteractionStorageTests.cs @@ -2,6 +2,7 @@ namespace Clawsharp.Tests.Analytics; +[TestFixture] public sealed class InteractionStorageTests : IDisposable { private readonly string _tempDir; diff --git a/tests/clawsharp.Tests/Analytics/InteractionTrackerTests.cs b/tests/clawsharp.Tests/Analytics/InteractionTrackerTests.cs index 6dbdc740..25963df8 100644 --- a/tests/clawsharp.Tests/Analytics/InteractionTrackerTests.cs +++ b/tests/clawsharp.Tests/Analytics/InteractionTrackerTests.cs @@ -6,6 +6,7 @@ namespace Clawsharp.Tests.Analytics; +[TestFixture] public sealed class InteractionTrackerTests : IDisposable { private readonly string _tempDir; diff --git a/tests/clawsharp.Tests/Channels/MessageChunkerTests.cs b/tests/clawsharp.Tests/Channels/MessageChunkerTests.cs index 5f5d878b..aa7712a2 100644 --- a/tests/clawsharp.Tests/Channels/MessageChunkerTests.cs +++ b/tests/clawsharp.Tests/Channels/MessageChunkerTests.cs @@ -2,6 +2,7 @@ namespace Clawsharp.Tests.Channels; +[TestFixture] public sealed class MessageChunkerTests { // ── Basic splitting ──────────────────────────────────────────────── diff --git a/tests/clawsharp.Tests/Fakes/TestFakes.cs b/tests/clawsharp.Tests/Fakes/TestFakes.cs index a3c20d77..0045aa5a 100644 --- a/tests/clawsharp.Tests/Fakes/TestFakes.cs +++ b/tests/clawsharp.Tests/Fakes/TestFakes.cs @@ -300,6 +300,11 @@ public void Register(Tool tool) _definitions.Add(tool.ToDefinition()); } + public bool Unregister(string toolName) + { + return _definitions.RemoveAll(d => string.Equals(d.Name, toolName, StringComparison.OrdinalIgnoreCase)) > 0; + } + public void SetChannelContext(ChannelName channelName, int spawnDepth = 0, string? sessionId = null, OrgUser? orgUser = null, PolicyDecision? policyDecision = null) { } diff --git a/tests/clawsharp.Tests/Fakes/TestLoggers.cs b/tests/clawsharp.Tests/Fakes/TestLoggers.cs new file mode 100644 index 00000000..7a6f44d2 --- /dev/null +++ b/tests/clawsharp.Tests/Fakes/TestLoggers.cs @@ -0,0 +1,24 @@ +using Microsoft.Extensions.Logging; + +namespace Clawsharp.Tests.Fakes; + +/// +/// Minimal ILogger implementation that captures log messages for assertion. +/// Source-generated [LoggerMessage] methods call the raw Log method with +/// a generated state type, making NSubstitute matching unreliable. +/// +public sealed class CapturingLogger(List<(LogLevel Level, string Message)> messages) : ILogger +{ + public IDisposable? BeginScope(TState state) where TState : notnull => null; + public bool IsEnabled(LogLevel logLevel) => true; + + public void Log( + LogLevel logLevel, + EventId eventId, + TState state, + Exception? exception, + Func formatter) + { + messages.Add((logLevel, formatter(state, exception))); + } +} diff --git a/tests/clawsharp.Tests/Integration/Cron/CronStoreContractTests.cs b/tests/clawsharp.Tests/Integration/Cron/CronStoreContractTests.cs index 3dd3f983..f5813fbd 100644 --- a/tests/clawsharp.Tests/Integration/Cron/CronStoreContractTests.cs +++ b/tests/clawsharp.Tests/Integration/Cron/CronStoreContractTests.cs @@ -3,6 +3,7 @@ namespace Clawsharp.Tests.Integration.Cron; /// Contract tests run against every ICronStore implementation. +[TestFixture] [Category("Integration")] public abstract class CronStoreContractTests { diff --git a/tests/clawsharp.Tests/Integration/E2E/WebChannelIntegrationTests.cs b/tests/clawsharp.Tests/Integration/E2E/WebChannelIntegrationTests.cs index 5fc39d0b..b5925071 100644 --- a/tests/clawsharp.Tests/Integration/E2E/WebChannelIntegrationTests.cs +++ b/tests/clawsharp.Tests/Integration/E2E/WebChannelIntegrationTests.cs @@ -69,8 +69,8 @@ public async Task OneTimeSetUp() pairingService, NullLogger.Instance, sp.GetRequiredService(), - new Clawsharp.Organization.IdentityResolver(appConfigOptions), - new Clawsharp.Organization.LinkTokenStore()); + new Organization.IdentityResolver(appConfigOptions), + new Organization.LinkTokenStore()); await _webChannel.StartAsync(_cts.Token); diff --git a/tests/clawsharp.Tests/Integration/Memory/RedisMemoryTests.cs b/tests/clawsharp.Tests/Integration/Memory/RedisMemoryTests.cs index b06604d9..bdf004a5 100644 --- a/tests/clawsharp.Tests/Integration/Memory/RedisMemoryTests.cs +++ b/tests/clawsharp.Tests/Integration/Memory/RedisMemoryTests.cs @@ -125,10 +125,15 @@ public async Task SearchAsync_FindsMatch() await memory.AppendFactAsync("user likes pizza"); await memory.AppendFactAsync("user dislikes broccoli"); - // Give RediSearch a moment to index - await Task.Delay(100); - - var results = await memory.SearchAsync("pizza"); + // Poll until RediSearch indexes the data or timeout + var deadline = DateTime.UtcNow.AddSeconds(5); + IReadOnlyList results = []; + while (DateTime.UtcNow < deadline) + { + results = await memory.SearchAsync("pizza"); + if (results.Count > 0) break; + await Task.Delay(50); + } results.ShouldNotBeEmpty(); results[0].ShouldContain("pizza"); @@ -140,7 +145,14 @@ public async Task SearchAsync_NoMatch_ReturnsEmpty() var memory = CreateMemory(); await memory.AppendFactAsync("user likes pizza"); - await Task.Delay(100); + // Poll until indexing completes (verify a known term is findable first) + var deadline = DateTime.UtcNow.AddSeconds(5); + while (DateTime.UtcNow < deadline) + { + var check = await memory.SearchAsync("pizza"); + if (check.Count > 0) break; + await Task.Delay(50); + } var results = await memory.SearchAsync("xyzzynotarealword"); @@ -156,7 +168,14 @@ public async Task SearchAsync_RespectsNLimit() await memory.AppendFactAsync($"fact about cats number {i}"); } - await Task.Delay(200); + // Poll until indexing completes (all 10 should be findable) + var deadline = DateTime.UtcNow.AddSeconds(5); + while (DateTime.UtcNow < deadline) + { + var check = await memory.SearchAsync("cats", n: 10); + if (check.Count >= 10) break; + await Task.Delay(50); + } var results = await memory.SearchAsync("cats", n: 3); @@ -347,9 +366,15 @@ public async Task SearchHybridAsync_NoEmbedding_FallsBackToText() await memory.AppendFactAsync("user likes chocolate"); await memory.AppendFactAsync("user hates vanilla"); - await Task.Delay(100); - - var results = await memory.SearchHybridAsync("chocolate"); + // Poll until RediSearch indexes the data + var deadline = DateTime.UtcNow.AddSeconds(5); + IReadOnlyList results = []; + while (DateTime.UtcNow < deadline) + { + results = await memory.SearchHybridAsync("chocolate"); + if (results.Count > 0) break; + await Task.Delay(50); + } results.ShouldNotBeEmpty(); results[0].Content.ShouldContain("chocolate"); @@ -361,9 +386,15 @@ public async Task SearchHybridAsync_EmptyEmbedding_FallsBackToText() var memory = CreateMemory(); await memory.AppendFactAsync("cats are great pets"); - await Task.Delay(100); - - var results = await memory.SearchHybridAsync("cats", queryEmbedding: []); + // Poll until RediSearch indexes the data + var deadline = DateTime.UtcNow.AddSeconds(5); + IReadOnlyList results = []; + while (DateTime.UtcNow < deadline) + { + results = await memory.SearchHybridAsync("cats", queryEmbedding: []); + if (results.Count > 0) break; + await Task.Delay(50); + } results.ShouldNotBeEmpty(); results[0].Content.ShouldContain("cats"); @@ -420,18 +451,35 @@ public async Task SearchAsync_AccessCountIncrementedOnHybridSearch() var memory = CreateMemory(); await memory.AppendFactAsync("test access tracking"); - await Task.Delay(100); + // Poll until RediSearch indexes the data + var deadline = DateTime.UtcNow.AddSeconds(5); + while (DateTime.UtcNow < deadline) + { + var check = await memory.SearchHybridAsync("access tracking"); + if (check.Count > 0) break; + await Task.Delay(50); + } // Search via hybrid (which updates access counts) await memory.SearchHybridAsync("access tracking"); - // Give time for access count update - await Task.Delay(50); + // Poll until access count is updated + deadline = DateTime.UtcNow.AddSeconds(5); + Clawsharp.Memory.Entities.Fact? fact = null; + while (DateTime.UtcNow < deadline) + { + var facts = await memory.ListFactsAsync(); + if (facts.Count == 1 && facts[0].AccessCount >= 1) + { + fact = facts[0]; + break; + } + await Task.Delay(50); + } - var facts = await memory.ListFactsAsync(); - facts.Count.ShouldBe(1); - facts[0].AccessCount.ShouldBeGreaterThanOrEqualTo(1); - facts[0].LastAccessedAt.ShouldNotBeNull(); + fact.ShouldNotBeNull(); + fact!.AccessCount.ShouldBeGreaterThanOrEqualTo(1); + fact.LastAccessedAt.ShouldNotBeNull(); } [Test] @@ -472,7 +520,15 @@ public async Task GetContextAsync_ReturnsAtMost50Facts() await memory.AppendFactAsync($"fact number {i}"); } - await Task.Delay(200); + // Poll until all facts are listed (indexing is not required for GetContextAsync, + // but we verify the data is stored) + var deadline = DateTime.UtcNow.AddSeconds(5); + while (DateTime.UtcNow < deadline) + { + var facts = await memory.ListFactsAsync(); + if (facts.Count >= 60) break; + await Task.Delay(50); + } var result = await memory.GetContextAsync(); result.ShouldNotBeNull(); diff --git a/tests/clawsharp.Tests/Knowledge/AzureBlobSourceLoaderTests.cs b/tests/clawsharp.Tests/Knowledge/AzureBlobSourceLoaderTests.cs index 940f6f7f..b35e721f 100644 --- a/tests/clawsharp.Tests/Knowledge/AzureBlobSourceLoaderTests.cs +++ b/tests/clawsharp.Tests/Knowledge/AzureBlobSourceLoaderTests.cs @@ -14,6 +14,7 @@ namespace Clawsharp.Tests.Knowledge; /// prefix filtering via GetBlobsAsync, azure:// URI format per D-22, and /// SsrfGuard transport injection per D-26. /// +[TestFixture] public sealed class AzureBlobSourceLoaderTests { private IDocumentLoaderRegistry _loaderRegistry = null!; diff --git a/tests/clawsharp.Tests/Knowledge/BatchEmbeddingProviderTests.cs b/tests/clawsharp.Tests/Knowledge/BatchEmbeddingProviderTests.cs index e4b08d3f..c916046a 100644 --- a/tests/clawsharp.Tests/Knowledge/BatchEmbeddingProviderTests.cs +++ b/tests/clawsharp.Tests/Knowledge/BatchEmbeddingProviderTests.cs @@ -13,6 +13,7 @@ namespace Clawsharp.Tests.Knowledge; /// on , bounded parallelism, empty input, and cancellation. /// Uses NSubstitute to mock and a zero-delay Polly pipeline for speed. /// +[TestFixture] public sealed class BatchEmbeddingProviderTests { private IEmbeddingProvider _inner = null!; diff --git a/tests/clawsharp.Tests/Knowledge/ClawsharpSignTests.cs b/tests/clawsharp.Tests/Knowledge/ClawsharpSignTests.cs index 2fc7dfe7..fec6f56b 100644 --- a/tests/clawsharp.Tests/Knowledge/ClawsharpSignTests.cs +++ b/tests/clawsharp.Tests/Knowledge/ClawsharpSignTests.cs @@ -7,10 +7,11 @@ namespace Clawsharp.Tests.Knowledge; /// Verifies Ed25519 keypair generation, plugin signing, verification, /// tampered DLL detection, and wrong-key rejection per D-40. /// +[TestFixture] public sealed class ClawsharpSignTests { private string _tempDir = null!; - private string _projectPath = null!; + private string _signToolDll = null!; [SetUp] public void SetUp() @@ -18,10 +19,14 @@ public void SetUp() _tempDir = Path.Combine(Path.GetTempPath(), $"clawsharp-sign-test-{Guid.NewGuid():N}"); Directory.CreateDirectory(_tempDir); - // Resolve project path relative to test assembly location - // tests/clawsharp.Tests/bin/Debug/net10.0/ -> navigate up to repo root -> src/clawsharp-sign/ + // Resolve the built clawsharp-sign.dll relative to test assembly location. + // tests/clawsharp.Tests/bin/{Config}/net10.0/ -> repo root -> src/clawsharp-sign/bin/{Config}/net10.0/ + // We use "dotnet exec " instead of "dotnet run --project" to avoid the + // .NET 10 glob expansion bug on GitHub Actions runners. var assemblyDir = Path.GetDirectoryName(typeof(ClawsharpSignTests).Assembly.Location)!; - _projectPath = Path.GetFullPath(Path.Combine(assemblyDir, "..", "..", "..", "..", "..", "src", "clawsharp-sign", "clawsharp-sign.csproj")); + var repoRoot = Path.GetFullPath(Path.Combine(assemblyDir, "..", "..", "..", "..", "..")); + var config = assemblyDir.Contains("Release") ? "Release" : "Debug"; + _signToolDll = Path.Combine(repoRoot, "src", "clawsharp-sign", "bin", config, "net10.0", "clawsharp-sign.dll"); } [TearDown] @@ -172,8 +177,8 @@ await RunSignToolAsync( private async Task<(int ExitCode, string Stdout, string Stderr)> RunSignToolAsync(params string[] args) { - var allArgs = new List { "run", "--project", _projectPath, "--no-build", "--" }; - allArgs.AddRange(args); + if (!File.Exists(_signToolDll)) + Assert.Ignore($"clawsharp-sign not built at {_signToolDll}"); var psi = new ProcessStartInfo { @@ -184,7 +189,9 @@ await RunSignToolAsync( CreateNoWindow = true, }; - foreach (var arg in allArgs) + psi.ArgumentList.Add("exec"); + psi.ArgumentList.Add(_signToolDll); + foreach (var arg in args) psi.ArgumentList.Add(arg); using var process = Process.Start(psi); diff --git a/tests/clawsharp.Tests/Knowledge/CloudStorageLoaderBaseTests.cs b/tests/clawsharp.Tests/Knowledge/CloudStorageLoaderBaseTests.cs index 1f5e1f49..e3a53d8e 100644 --- a/tests/clawsharp.Tests/Knowledge/CloudStorageLoaderBaseTests.cs +++ b/tests/clawsharp.Tests/Knowledge/CloudStorageLoaderBaseTests.cs @@ -13,6 +13,7 @@ namespace Clawsharp.Tests.Knowledge; /// shared logic: extension filtering before download (D-24), format dispatch via /// IDocumentLoaderRegistry (D-25), URI construction, and empty listing handling. /// +[TestFixture] public sealed class CloudStorageLoaderBaseTests { private IDocumentLoaderRegistry _loaderRegistry = null!; diff --git a/tests/clawsharp.Tests/Knowledge/ContentHasherTests.cs b/tests/clawsharp.Tests/Knowledge/ContentHasherTests.cs index 1190f015..27b5fb30 100644 --- a/tests/clawsharp.Tests/Knowledge/ContentHasherTests.cs +++ b/tests/clawsharp.Tests/Knowledge/ContentHasherTests.cs @@ -6,6 +6,7 @@ namespace Clawsharp.Tests.Knowledge; /// Tests for . Verifies per-document SHA-256 hashing with /// sourceUri inclusion to prevent empty-doc collision, determinism, and Merkle rollup. /// +[TestFixture] public sealed class ContentHasherTests { // ── ComputeDocumentHash ───────────────────────────────────────────── diff --git a/tests/clawsharp.Tests/Knowledge/DeleteByDocumentTests.cs b/tests/clawsharp.Tests/Knowledge/DeleteByDocumentTests.cs index 1239e407..6b60f0cc 100644 --- a/tests/clawsharp.Tests/Knowledge/DeleteByDocumentTests.cs +++ b/tests/clawsharp.Tests/Knowledge/DeleteByDocumentTests.cs @@ -10,6 +10,7 @@ namespace Clawsharp.Tests.Knowledge; /// backend (no database infrastructure needed). Verifies selective deletion by document /// within a source, no-op behavior on missing data. /// +[TestFixture] public sealed class DeleteByDocumentTests : IDisposable { private readonly string _tempDir; diff --git a/tests/clawsharp.Tests/Knowledge/GcsSourceLoaderTests.cs b/tests/clawsharp.Tests/Knowledge/GcsSourceLoaderTests.cs index 61e50dc7..030f218e 100644 --- a/tests/clawsharp.Tests/Knowledge/GcsSourceLoaderTests.cs +++ b/tests/clawsharp.Tests/Knowledge/GcsSourceLoaderTests.cs @@ -14,6 +14,7 @@ namespace Clawsharp.Tests.Knowledge; /// prefix filtering via ListObjectsAsync, gs:// URI format per D-23, /// and download stream handling. /// +[TestFixture] public sealed class GcsSourceLoaderTests { private StorageClient _storageClient = null!; diff --git a/tests/clawsharp.Tests/Knowledge/GitSourceLoaderTests.cs b/tests/clawsharp.Tests/Knowledge/GitSourceLoaderTests.cs index f370b34d..510da561 100644 --- a/tests/clawsharp.Tests/Knowledge/GitSourceLoaderTests.cs +++ b/tests/clawsharp.Tests/Knowledge/GitSourceLoaderTests.cs @@ -12,6 +12,7 @@ namespace Clawsharp.Tests.Knowledge; /// directories for high-fidelity testing of clone, pull, extension filtering, SourceUri /// format, and empty repo handling. /// +[TestFixture] public sealed class GitSourceLoaderTests { private string _tempDir = null!; diff --git a/tests/clawsharp.Tests/Knowledge/HeadingAwareChunkerTests.cs b/tests/clawsharp.Tests/Knowledge/HeadingAwareChunkerTests.cs index ba9103c9..ece362f5 100644 --- a/tests/clawsharp.Tests/Knowledge/HeadingAwareChunkerTests.cs +++ b/tests/clawsharp.Tests/Knowledge/HeadingAwareChunkerTests.cs @@ -15,7 +15,7 @@ public sealed class HeadingAwareChunkerTests private readonly HeadingAwareChunker _chunker = new(); private static ChunkingConfig Config(int chunkSize = 50, double overlap = 0.1) => - new() { ChunkSize = chunkSize, Overlap = overlap, Strategy = "paragraph" }; + new() { ChunkSize = chunkSize, Overlap = overlap, Strategy = "heading" }; private static async IAsyncEnumerable Pages(params DocumentPage[] pages) { @@ -34,9 +34,9 @@ private static async Task> CollectAsync( } [Test] - public void Name_ReturnsParagraph() + public void Name_ReturnsHeading() { - Assert.That(_chunker.Name, Is.EqualTo("paragraph")); + Assert.That(_chunker.Name, Is.EqualTo("heading")); } [Test] diff --git a/tests/clawsharp.Tests/Knowledge/IngestionPipelineTests.cs b/tests/clawsharp.Tests/Knowledge/IngestionPipelineTests.cs index 85fd6153..c814b17f 100644 --- a/tests/clawsharp.Tests/Knowledge/IngestionPipelineTests.cs +++ b/tests/clawsharp.Tests/Knowledge/IngestionPipelineTests.cs @@ -19,6 +19,7 @@ namespace Clawsharp.Tests.Knowledge; /// progress reporting, and error state tracking. All dependencies are mocked via NSubstitute. /// Uses real temp directories to exercise file enumeration. /// +[TestFixture] public sealed class IngestionPipelineTests { private IDocumentLoaderRegistry _loaderRegistry = null!; @@ -149,8 +150,11 @@ public async Task IngestSourceAsync_UnchangedDocument_IsSkipped() _loaderRegistry.GetSupportedExtensions().Returns(new List { ".md" }); - _loaderRegistry.LoadAsync(Arg.Is(p => p.EndsWith("file1.md")), Arg.Any()) + // Production code calls GetLoader(ext).LoadAsync() directly for local sources + var mdLoader = Substitute.For(); + mdLoader.LoadAsync(Arg.Is(p => p.EndsWith("file1.md")), Arg.Any()) .Returns(ToAsyncEnumerable(new DocumentPage("Same content", 1))); + _loaderRegistry.GetLoader(".md").Returns(mdLoader); // Pre-compute the hash that ContentHasher would produce var expectedHash = ContentHasher.ComputeDocumentHash(filePath, "Same content"); @@ -199,8 +203,11 @@ public async Task IngestSourceAsync_SourceMerkleHashMatches_SkipsEntireSource() _loaderRegistry.GetSupportedExtensions().Returns(new List { ".md" }); - _loaderRegistry.LoadAsync(Arg.Is(p => p.EndsWith("file1.md")), Arg.Any()) + // Production code calls GetLoader(ext).LoadAsync() directly for local sources + var mdLoader = Substitute.For(); + mdLoader.LoadAsync(Arg.Is(p => p.EndsWith("file1.md")), Arg.Any()) .Returns(ToAsyncEnumerable(new DocumentPage("Content A", 1))); + _loaderRegistry.GetLoader(".md").Returns(mdLoader); var docHash = ContentHasher.ComputeDocumentHash(filePath, "Content A"); diff --git a/tests/clawsharp.Tests/Knowledge/IngestionWorkerTests.cs b/tests/clawsharp.Tests/Knowledge/IngestionWorkerTests.cs index 8aece53b..e6ba7934 100644 --- a/tests/clawsharp.Tests/Knowledge/IngestionWorkerTests.cs +++ b/tests/clawsharp.Tests/Knowledge/IngestionWorkerTests.cs @@ -16,6 +16,7 @@ namespace Clawsharp.Tests.Knowledge; /// Tests for . Verifies sequential job processing /// via bounded channel, crash recovery on startup, and error resilience. /// +[TestFixture] public sealed class IngestionWorkerTests { private KnowledgeIngestionPipeline _pipeline = null!; diff --git a/tests/clawsharp.Tests/Knowledge/KnowledgeConfigTests.cs b/tests/clawsharp.Tests/Knowledge/KnowledgeConfigTests.cs index 8fa46dbc..ab1dea3e 100644 --- a/tests/clawsharp.Tests/Knowledge/KnowledgeConfigTests.cs +++ b/tests/clawsharp.Tests/Knowledge/KnowledgeConfigTests.cs @@ -4,6 +4,7 @@ namespace Clawsharp.Tests.Knowledge; +[TestFixture] public sealed class KnowledgeConfigTests { [Test] diff --git a/tests/clawsharp.Tests/Knowledge/KnowledgeEntityTests.cs b/tests/clawsharp.Tests/Knowledge/KnowledgeEntityTests.cs index 2d469ee7..27c37f78 100644 --- a/tests/clawsharp.Tests/Knowledge/KnowledgeEntityTests.cs +++ b/tests/clawsharp.Tests/Knowledge/KnowledgeEntityTests.cs @@ -10,6 +10,7 @@ namespace Clawsharp.Tests.Knowledge; /// Tests for KnowledgeSource, KnowledgeChunk entity configurations, AclFilter, and IKnowledgeStore interface shape. /// Uses an in-memory model builder to verify EF Core configuration without a database. /// +[TestFixture] public sealed class KnowledgeEntityTests { private static IModel BuildModel() diff --git a/tests/clawsharp.Tests/Knowledge/KnowledgeIngestCommandTests.cs b/tests/clawsharp.Tests/Knowledge/KnowledgeIngestCommandTests.cs index 47b81231..bd50f79c 100644 --- a/tests/clawsharp.Tests/Knowledge/KnowledgeIngestCommandTests.cs +++ b/tests/clawsharp.Tests/Knowledge/KnowledgeIngestCommandTests.cs @@ -8,8 +8,21 @@ namespace Clawsharp.Tests.Knowledge; /// Tests for source config resolution logic. /// [TestFixture] -public sealed class KnowledgeIngestCommandTests +public sealed class KnowledgeIngestCommandTests : IDisposable { + private readonly string _tempDir; + + public KnowledgeIngestCommandTests() + { + _tempDir = Path.Combine(Path.GetTempPath(), $"clawsharp-ingest-cmd-test-{Guid.NewGuid():N}"); + Directory.CreateDirectory(_tempDir); + } + + public void Dispose() + { + try { Directory.Delete(_tempDir, recursive: true); } catch { /* best-effort */ } + } + [Test] public void ResolveSourceConfig_ConfiguredName_ReturnsConfiguredSource() { @@ -67,6 +80,9 @@ public void ResolveSourceConfig_ConfiguredNameCaseInsensitive_ReturnsConfiguredS [Test] public void ResolveSourceConfig_LocalPath_CreatesAdHocConfig() { + var subDir = Path.Combine(_tempDir, "my-documents"); + Directory.CreateDirectory(subDir); + var config = new AppConfig { Knowledge = new KnowledgeConfig @@ -76,10 +92,10 @@ public void ResolveSourceConfig_LocalPath_CreatesAdHocConfig() }, }; - var result = KnowledgeIngestCommand.ResolveSourceConfig(config, "/tmp/my-documents"); + var result = KnowledgeIngestCommand.ResolveSourceConfig(config, subDir); result.Type.ShouldBe("local"); - result.Path.ShouldBe("/tmp/my-documents"); + result.Path.ShouldBe(subDir); result.Name.ShouldBe("my-documents"); } @@ -123,6 +139,9 @@ public void ResolveSourceConfig_HttpsUrl_CreatesUrlConfig() [Test] public void ResolveSourceConfig_NoConfiguredSources_CreatesAdHocLocal() { + var notesDir = Path.Combine(_tempDir, "notes"); + Directory.CreateDirectory(notesDir); + var config = new AppConfig { Knowledge = new KnowledgeConfig @@ -132,9 +151,9 @@ public void ResolveSourceConfig_NoConfiguredSources_CreatesAdHocLocal() }, }; - var result = KnowledgeIngestCommand.ResolveSourceConfig(config, "/home/user/notes"); + var result = KnowledgeIngestCommand.ResolveSourceConfig(config, notesDir); result.Type.ShouldBe("local"); - result.Path.ShouldBe("/home/user/notes"); + result.Path.ShouldBe(notesDir); } } diff --git a/tests/clawsharp.Tests/Knowledge/KnowledgeJsonContextTests.cs b/tests/clawsharp.Tests/Knowledge/KnowledgeJsonContextTests.cs index cf014996..fa7262e8 100644 --- a/tests/clawsharp.Tests/Knowledge/KnowledgeJsonContextTests.cs +++ b/tests/clawsharp.Tests/Knowledge/KnowledgeJsonContextTests.cs @@ -11,6 +11,7 @@ namespace Clawsharp.Tests.Knowledge; /// . Ensures no DTO is accidentally excluded /// from the source-generated context. /// +[TestFixture] public sealed class KnowledgeJsonContextTests { [Test] diff --git a/tests/clawsharp.Tests/Knowledge/KnowledgeSpanTests.cs b/tests/clawsharp.Tests/Knowledge/KnowledgeSpanTests.cs index a3932fc4..f99d6996 100644 --- a/tests/clawsharp.Tests/Knowledge/KnowledgeSpanTests.cs +++ b/tests/clawsharp.Tests/Knowledge/KnowledgeSpanTests.cs @@ -220,8 +220,11 @@ public async Task IngestSourceAsync_FailurePath_RecordsDocumentFailedMetric() { _loaderRegistry.GetSupportedExtensions().Returns(new List { ".md" }); File.WriteAllText(Path.Combine(_tempDir, "bad.md"), "bad content"); - _loaderRegistry.LoadAsync(Arg.Any(), Arg.Any()) + // Production code calls GetLoader(ext).LoadAsync() for local sources + var mdLoader = Substitute.For(); + mdLoader.LoadAsync(Arg.Any(), Arg.Any()) .Throws(new IOException("disk error")); + _loaderRegistry.GetLoader(".md").Returns(mdLoader); long recordedValue = 0; using var meterListener = new MeterListener(); diff --git a/tests/clawsharp.Tests/Knowledge/KnowledgeStoreDiTests.cs b/tests/clawsharp.Tests/Knowledge/KnowledgeStoreDiTests.cs index f2062443..51734a94 100644 --- a/tests/clawsharp.Tests/Knowledge/KnowledgeStoreDiTests.cs +++ b/tests/clawsharp.Tests/Knowledge/KnowledgeStoreDiTests.cs @@ -10,6 +10,7 @@ namespace Clawsharp.Tests.Knowledge; /// /// Structural tests verifying all 5 IKnowledgeStore implementations exist and implement the interface. /// +[TestFixture] public sealed class KnowledgeStoreDiTests { [Test] diff --git a/tests/clawsharp.Tests/Knowledge/KnowledgeStoreTests.cs b/tests/clawsharp.Tests/Knowledge/KnowledgeStoreTests.cs index 9cff779b..2d9d8b36 100644 --- a/tests/clawsharp.Tests/Knowledge/KnowledgeStoreTests.cs +++ b/tests/clawsharp.Tests/Knowledge/KnowledgeStoreTests.cs @@ -10,6 +10,7 @@ namespace Clawsharp.Tests.Knowledge; /// functional tests since it requires no database infrastructure. EF Core backends /// (SQLite, Postgres, MsSql) share the same patterns and are covered by integration tests. /// +[TestFixture] public sealed class KnowledgeStoreTests : IDisposable { private readonly string _tempDir; diff --git a/tests/clawsharp.Tests/Knowledge/PluginIntegrityVerifierTests.cs b/tests/clawsharp.Tests/Knowledge/PluginIntegrityVerifierTests.cs index 128595e1..51757d15 100644 --- a/tests/clawsharp.Tests/Knowledge/PluginIntegrityVerifierTests.cs +++ b/tests/clawsharp.Tests/Knowledge/PluginIntegrityVerifierTests.cs @@ -13,6 +13,7 @@ namespace Clawsharp.Tests.Knowledge; +[TestFixture] public sealed class PluginIntegrityVerifierTests : IDisposable { private readonly string _tempDir; @@ -20,7 +21,7 @@ public sealed class PluginIntegrityVerifierTests : IDisposable private readonly ILogger _logger; // Test key pair — generated fresh for each test class instance - private readonly NSec.Cryptography.Key _signingKey; + private readonly Key _signingKey; private readonly byte[] _publicKeyBytes; public PluginIntegrityVerifierTests() @@ -32,8 +33,8 @@ public PluginIntegrityVerifierTests() _auditLogger = new AuditLogger(options, NullLogger.Instance); _logger = NullLogger.Instance; - var algorithm = NSec.Cryptography.SignatureAlgorithm.Ed25519; - _signingKey = NSec.Cryptography.Key.Create(algorithm, + var algorithm = SignatureAlgorithm.Ed25519; + _signingKey = Key.Create(algorithm, new KeyCreationParameters { ExportPolicy = KeyExportPolicies.AllowPlaintextExport }); _publicKeyBytes = _signingKey.Export(KeyBlobFormat.RawPublicKey); } @@ -274,7 +275,7 @@ private string CreatePluginDirectory(string name, out PluginManifest manifest) // Create a fake DLL var dllName = "clawsharp.Plugin.Test.dll"; - var dllContent = System.Text.Encoding.UTF8.GetBytes($"fake-dll-content-{Guid.NewGuid():N}"); + var dllContent = Encoding.UTF8.GetBytes($"fake-dll-content-{Guid.NewGuid():N}"); File.WriteAllBytes(Path.Combine(pluginDir, dllName), dllContent); // Compute file hashes @@ -294,7 +295,7 @@ private string CreatePluginDirectory(string name, out PluginManifest manifest) // Build canonical payload and sign it var canonicalBytes = PluginIntegrityVerifier.BuildCanonicalPayload(unsignedManifest); - var algorithm = NSec.Cryptography.SignatureAlgorithm.Ed25519; + var algorithm = SignatureAlgorithm.Ed25519; var signatureBytes = algorithm.Sign(_signingKey, canonicalBytes); var signature = Convert.ToBase64String(signatureBytes); diff --git a/tests/clawsharp.Tests/Knowledge/PluginLoaderSubdirectoryTests.cs b/tests/clawsharp.Tests/Knowledge/PluginLoaderSubdirectoryTests.cs index 0bfc1874..b3a6b560 100644 --- a/tests/clawsharp.Tests/Knowledge/PluginLoaderSubdirectoryTests.cs +++ b/tests/clawsharp.Tests/Knowledge/PluginLoaderSubdirectoryTests.cs @@ -11,6 +11,7 @@ namespace Clawsharp.Tests.Knowledge; +[TestFixture] public sealed class PluginLoaderSubdirectoryTests : IDisposable { private readonly string _tempDir; @@ -18,7 +19,7 @@ public sealed class PluginLoaderSubdirectoryTests : IDisposable private readonly AuditLogger _auditLogger; // Test key pair for signing - private readonly NSec.Cryptography.Key _signingKey; + private readonly Key _signingKey; private readonly byte[] _publicKeyBytes; public PluginLoaderSubdirectoryTests() @@ -29,8 +30,8 @@ public PluginLoaderSubdirectoryTests() var options = Options.Create(new AppConfig { Audit = new Config.Security.AuditConfig { Enabled = false } }); _auditLogger = new AuditLogger(options, NullLogger.Instance); - var algorithm = NSec.Cryptography.SignatureAlgorithm.Ed25519; - _signingKey = NSec.Cryptography.Key.Create(algorithm, + var algorithm = SignatureAlgorithm.Ed25519; + _signingKey = Key.Create(algorithm, new KeyCreationParameters { ExportPolicy = KeyExportPolicies.AllowPlaintextExport }); _publicKeyBytes = _signingKey.Export(KeyBlobFormat.RawPublicKey); } @@ -150,17 +151,6 @@ public async Task LoadPluginsAsync_IgnoresSubdirsWithoutPluginDll() result.ShouldNotBeNull(); } - // ── Backward compatibility ───────────────────────────────────── - - [Test] - public void LoadPlugins_SyncWrapper_ReturnsEmptyForNonexistent() - { - var result = PluginLoader.LoadPlugins("/nonexistent/plugins", _logger); - - result.ShouldNotBeNull(); - result.Count.ShouldBe(0); - } - // ── Helper methods ───────────────────────────────────────────── /// diff --git a/tests/clawsharp.Tests/Knowledge/PluginLoaderTests.cs b/tests/clawsharp.Tests/Knowledge/PluginLoaderTests.cs index c7f04b21..b0e8bac2 100644 --- a/tests/clawsharp.Tests/Knowledge/PluginLoaderTests.cs +++ b/tests/clawsharp.Tests/Knowledge/PluginLoaderTests.cs @@ -8,30 +8,33 @@ namespace Clawsharp.Tests.Knowledge; +[TestFixture] public sealed class PluginLoaderTests { private readonly ILogger _logger = NullLogger.Instance; - // ── LoadPlugins ────────────────────────────────────────────────── + // ── LoadPluginsAsync ────────────────────────────────────────────── [Test] - public void LoadPlugins_NonExistentDirectory_ReturnsEmptyList() + public async Task LoadPluginsAsync_NonExistentDirectory_ReturnsEmptyList() { - var result = PluginLoader.LoadPlugins("/nonexistent/path/plugins", _logger); + var result = await PluginLoader.LoadPluginsAsync( + "/nonexistent/path/plugins", verifier: null, requireSigned: false, _logger); result.ShouldNotBeNull(); result.Count.ShouldBe(0); } [Test] - public void LoadPlugins_EmptyDirectory_ReturnsEmptyList() + public async Task LoadPluginsAsync_EmptyDirectory_ReturnsEmptyList() { var tempDir = Path.Combine(Path.GetTempPath(), $"clawsharp-test-{Guid.NewGuid():N}"); Directory.CreateDirectory(tempDir); try { - var result = PluginLoader.LoadPlugins(tempDir, _logger); + var result = await PluginLoader.LoadPluginsAsync( + tempDir, verifier: null, requireSigned: false, _logger); result.ShouldNotBeNull(); result.Count.ShouldBe(0); diff --git a/tests/clawsharp.Tests/Knowledge/RemoteIngestionPipelineTests.cs b/tests/clawsharp.Tests/Knowledge/RemoteIngestionPipelineTests.cs index 29bdc63d..14e2b4e9 100644 --- a/tests/clawsharp.Tests/Knowledge/RemoteIngestionPipelineTests.cs +++ b/tests/clawsharp.Tests/Knowledge/RemoteIngestionPipelineTests.cs @@ -18,6 +18,7 @@ namespace Clawsharp.Tests.Knowledge; /// Verifies that remote loaders are dispatched correctly, delta detection works for /// remote documents, and the existing local ingestion path remains unbroken. /// +[TestFixture] public sealed class RemoteIngestionPipelineTests { private IDocumentLoaderRegistry _loaderRegistry = null!; diff --git a/tests/clawsharp.Tests/Knowledge/RrfMergerTests.cs b/tests/clawsharp.Tests/Knowledge/RrfMergerTests.cs index a9734d39..41f16df1 100644 --- a/tests/clawsharp.Tests/Knowledge/RrfMergerTests.cs +++ b/tests/clawsharp.Tests/Knowledge/RrfMergerTests.cs @@ -7,6 +7,7 @@ namespace Clawsharp.Tests.Knowledge; /// /// Tests for RrfMerger reciprocal rank fusion utility. /// +[TestFixture] public sealed class RrfMergerTests { private static KnowledgeChunk MakeChunk(Guid id) => new() diff --git a/tests/clawsharp.Tests/Knowledge/S3SourceLoaderTests.cs b/tests/clawsharp.Tests/Knowledge/S3SourceLoaderTests.cs index 0f31775a..1f5cc32c 100644 --- a/tests/clawsharp.Tests/Knowledge/S3SourceLoaderTests.cs +++ b/tests/clawsharp.Tests/Knowledge/S3SourceLoaderTests.cs @@ -12,6 +12,7 @@ namespace Clawsharp.Tests.Knowledge; /// Tests for . Verifies S3-specific behavior: prefix filtering, /// pagination handling, and s3:// URI format per D-21. /// +[TestFixture] public sealed class S3SourceLoaderTests { private IAmazonS3 _s3Client = null!; diff --git a/tests/clawsharp.Tests/Knowledge/SyncStateTrackerTests.cs b/tests/clawsharp.Tests/Knowledge/SyncStateTrackerTests.cs index 62dc2f95..c47a901d 100644 --- a/tests/clawsharp.Tests/Knowledge/SyncStateTrackerTests.cs +++ b/tests/clawsharp.Tests/Knowledge/SyncStateTrackerTests.cs @@ -11,6 +11,7 @@ namespace Clawsharp.Tests.Knowledge; /// Tests for . Uses an in-memory SQLite database /// with a real EF Core context to validate CAS state transitions and crash recovery. /// +[TestFixture] public sealed class SyncStateTrackerTests : IDisposable { private readonly SqliteConnection _connection; diff --git a/tests/clawsharp.Tests/Security/CanaryGuardTests.cs b/tests/clawsharp.Tests/Security/CanaryGuardTests.cs index b27eca01..811c6f37 100644 --- a/tests/clawsharp.Tests/Security/CanaryGuardTests.cs +++ b/tests/clawsharp.Tests/Security/CanaryGuardTests.cs @@ -2,6 +2,7 @@ namespace Clawsharp.Tests.Security; +[TestFixture] public sealed class CanaryGuardTests { // ── GenerateCanary ─────────────────────────────────────────────────── diff --git a/tests/clawsharp.Tests/Security/EgressPolicyTests.cs b/tests/clawsharp.Tests/Security/EgressPolicyTests.cs index d7b0d9c1..aca9d942 100644 --- a/tests/clawsharp.Tests/Security/EgressPolicyTests.cs +++ b/tests/clawsharp.Tests/Security/EgressPolicyTests.cs @@ -3,6 +3,7 @@ namespace Clawsharp.Tests.Security; +[TestFixture] public sealed class EgressPolicyTests { [TearDown] diff --git a/tests/clawsharp.Tests/Security/LeakDetectorTests.cs b/tests/clawsharp.Tests/Security/LeakDetectorTests.cs index e7f1665d..72ac9e32 100644 --- a/tests/clawsharp.Tests/Security/LeakDetectorTests.cs +++ b/tests/clawsharp.Tests/Security/LeakDetectorTests.cs @@ -2,6 +2,7 @@ namespace Clawsharp.Tests.Security; +[TestFixture] public sealed class LeakDetectorTests { // ── API Key Detection ─────────────────────────────────────────────── diff --git a/tests/clawsharp.Tests/Security/PromptGuardTests.cs b/tests/clawsharp.Tests/Security/PromptGuardTests.cs index 1fe0b8a7..6ca9770c 100644 --- a/tests/clawsharp.Tests/Security/PromptGuardTests.cs +++ b/tests/clawsharp.Tests/Security/PromptGuardTests.cs @@ -3,6 +3,7 @@ namespace Clawsharp.Tests.Security; +[TestFixture] public sealed class PromptGuardTests { /// diff --git a/tests/clawsharp.Tests/Security/SecretStoreTests.cs b/tests/clawsharp.Tests/Security/SecretStoreTests.cs index 7b724d38..86109864 100644 --- a/tests/clawsharp.Tests/Security/SecretStoreTests.cs +++ b/tests/clawsharp.Tests/Security/SecretStoreTests.cs @@ -6,6 +6,7 @@ namespace Clawsharp.Tests.Security; +[TestFixture] public sealed class SecretStoreTests : IDisposable { private readonly string _tempDir; diff --git a/tests/clawsharp.Tests/Security/ShellGuardEdgeCaseTests.cs b/tests/clawsharp.Tests/Security/ShellGuardEdgeCaseTests.cs index a907ad5d..11ead36b 100644 --- a/tests/clawsharp.Tests/Security/ShellGuardEdgeCaseTests.cs +++ b/tests/clawsharp.Tests/Security/ShellGuardEdgeCaseTests.cs @@ -171,7 +171,7 @@ public void CheckCommand_PipeToTclsh_KnownLimitation_NotBlocked() // is skipped, so approval is still checked. [Test] - public void RequiresApproval_ApprovalPatternReDoS_KnownLimitation_SkipsTimedOutPattern() + public void RequiresApproval_ApprovalPatternReDoS_FailsClosed_RequiresApproval() { // Configure an approval pattern that will trigger ReDoS ShellGuard.ConfigureCustomPatterns( @@ -182,17 +182,15 @@ public void RequiresApproval_ApprovalPatternReDoS_KnownLimitation_SkipsTimedOutP var maliciousInput = new string('a', 30) + "!"; - // The approval pattern will time out. The catch block swallows the exception - // and returns null (no approval required). Built-in patterns are checked first - // and won't match this input. + // The approval pattern will time out. The catch block now returns the pattern + // (fail-closed: require approval on timeout) rather than skipping. var result = ShellGuard.RequiresApproval(maliciousInput, null, null); - // Known behavior: timeout on approval pattern = pattern is skipped = no approval required - // This is less critical than deny-pattern ReDoS because it fails in the - // "more permissive" direction (allowing without approval vs blocking outright). - result.ShouldBeNull( - "Known limitation: ReDoS timeout on approval pattern causes the pattern to be " + - "skipped, meaning the command is NOT flagged for approval."); + // Fail-closed: timeout on approval pattern = require approval + // This prevents ReDoS from bypassing approval requirements. + result.ShouldNotBeNull( + "Fail-closed: ReDoS timeout on approval pattern should require approval"); + result.ShouldBe("(a+)+$"); } [Test] diff --git a/tests/clawsharp.Tests/Security/ShellGuardTests.cs b/tests/clawsharp.Tests/Security/ShellGuardTests.cs index 0f709a1b..261e3580 100644 --- a/tests/clawsharp.Tests/Security/ShellGuardTests.cs +++ b/tests/clawsharp.Tests/Security/ShellGuardTests.cs @@ -4,6 +4,7 @@ namespace Clawsharp.Tests.Security; +[TestFixture] public sealed class ShellGuardTests { // ── Destructive commands ───────────────────────────────────────────── diff --git a/tests/clawsharp.Tests/Security/SsrfGuardTests.cs b/tests/clawsharp.Tests/Security/SsrfGuardTests.cs index 8e247418..f867427f 100644 --- a/tests/clawsharp.Tests/Security/SsrfGuardTests.cs +++ b/tests/clawsharp.Tests/Security/SsrfGuardTests.cs @@ -3,6 +3,7 @@ namespace Clawsharp.Tests.Security; +[TestFixture] public sealed class SsrfGuardTests { // ── IPv4 private/reserved (should be blocked) ──────────────────── diff --git a/tests/clawsharp.Tests/Unit/A2a/A2aAttributesTests.cs b/tests/clawsharp.Tests/Unit/A2a/A2aAttributesTests.cs index 0ef686ab..e17328d4 100644 --- a/tests/clawsharp.Tests/Unit/A2a/A2aAttributesTests.cs +++ b/tests/clawsharp.Tests/Unit/A2a/A2aAttributesTests.cs @@ -5,8 +5,8 @@ namespace Clawsharp.Tests.Unit.A2a; /// /// Structural exhaustiveness tests for . -/// Verifies exactly 12 constants exist, all prefixed with "a2a.", and each -/// expected attribute key is present. +/// Verifies exactly 16 constants exist (12 OTel attributes prefixed "a2a." + +/// 4 delegation metadata keys prefixed "clawsharp."), and each expected value is present. /// [TestFixture] public sealed class A2aAttributesTests @@ -19,23 +19,24 @@ public sealed class A2aAttributesTests // ── Exhaustiveness ────────────────────────────────────────────────────── [Test] - public void A2aAttributes_HasExactly12Constants() + public void A2aAttributes_HasExactly16Constants() { - Assert.That(ConstFields, Has.Length.EqualTo(12)); + Assert.That(ConstFields, Has.Length.EqualTo(16)); } [Test] - public void A2aAttributes_AllConstantsHaveA2aPrefix() + public void A2aAttributes_AllConstantsHaveKnownPrefix() { foreach (var field in ConstFields) { var value = (string)field.GetValue(null)!; - Assert.That(value, Does.StartWith("a2a."), - $"Field '{field.Name}' has value '{value}' which should start with 'a2a.'"); + Assert.That(value.StartsWith("a2a.", StringComparison.Ordinal) + || value.StartsWith("clawsharp.", StringComparison.Ordinal), + $"Field '{field.Name}' has value '{value}' which should start with 'a2a.' or 'clawsharp.'"); } } - // ── Individual value checks ───────────────────────────────────────────── + // ── Individual value checks (OTel attributes) ─────────────────────────── [Test] public void TaskId_HasCorrectValue() @@ -84,4 +85,22 @@ public void DelegationDepth_HasCorrectValue() [Test] public void DelegationChainId_HasCorrectValue() => Assert.That(A2aAttributes.DelegationChainId, Is.EqualTo("a2a.delegation.chain_id")); + + // ── Delegation metadata key value checks ──────────────────────────────── + + [Test] + public void MetaDepth_HasCorrectValue() + => Assert.That(A2aAttributes.MetaDepth, Is.EqualTo("clawsharp.delegation.depth")); + + [Test] + public void MetaMaxDepth_HasCorrectValue() + => Assert.That(A2aAttributes.MetaMaxDepth, Is.EqualTo("clawsharp.delegation.maxDepth")); + + [Test] + public void MetaOriginInstance_HasCorrectValue() + => Assert.That(A2aAttributes.MetaOriginInstance, Is.EqualTo("clawsharp.delegation.originInstance")); + + [Test] + public void MetaChainId_HasCorrectValue() + => Assert.That(A2aAttributes.MetaChainId, Is.EqualTo("clawsharp.delegation.chainId")); } diff --git a/tests/clawsharp.Tests/Unit/A2a/A2aClientServiceTests.cs b/tests/clawsharp.Tests/Unit/A2a/A2aClientServiceTests.cs index b78c960d..13fdea44 100644 --- a/tests/clawsharp.Tests/Unit/A2a/A2aClientServiceTests.cs +++ b/tests/clawsharp.Tests/Unit/A2a/A2aClientServiceTests.cs @@ -95,9 +95,9 @@ public void AgentRegistry_ExposesConfiguredAgents() // -- DelegateAsync with unknown agent returns error ------------------- [Test] - public async Task DelegateAsync_UnknownAgent_ReturnsErrorMessage() + public async Task DelegateAsync_UnknownAgent_ReturnsErrorTuple() { - // Without InitializeAsync, _clients is empty — any agent name is "unknown" + // Without InitializeAsync, _clients is empty -- any agent name is "unknown" var config = CreateConfigWithAgents( ("research-bot", "https://research.example.com/a2a", "Research", "bearer", "tok", null)); @@ -105,10 +105,11 @@ public async Task DelegateAsync_UnknownAgent_ReturnsErrorMessage() var logger = Substitute.For>(); var service = new A2aClientService(config, factory, logger); - var result = await service.DelegateAsync("nonexistent-agent", "do something"); + var (text, isError) = await service.DelegateAsync("nonexistent-agent", "do something"); - result.ShouldContain("Unknown agent"); - result.ShouldContain("nonexistent-agent"); + isError.ShouldBeTrue(); + text.ShouldContain("Unknown agent"); + text.ShouldContain("nonexistent-agent"); } // -- DelegateAsync with empty agent name returns error ---------------- @@ -123,17 +124,18 @@ public async Task DelegateAsync_EmptyAgentName_ReturnsError() var logger = Substitute.For>(); var service = new A2aClientService(config, factory, logger); - var result = await service.DelegateAsync("", "do something"); + var (text, isError) = await service.DelegateAsync("", "do something"); - result.ShouldContain("Unknown agent"); + isError.ShouldBeTrue(); + text.ShouldContain("Unknown agent"); } // -- DelegateAsync with cancelled token returns descriptive error ------ [Test] - public async Task DelegateAsync_CancelledToken_ReturnsErrorString() + public async Task DelegateAsync_CancelledToken_ReturnsErrorTuple() { - // Without InitializeAsync, _clients is empty — delegation returns "Unknown agent" + // Without InitializeAsync, _clients is empty -- delegation returns "Unknown agent" // even with cancellation. This tests the never-throw contract. var config = CreateConfigWithAgents( ("bot", "https://research.example.com/a2a", null, "bearer", "tok", null)); @@ -145,11 +147,12 @@ public async Task DelegateAsync_CancelledToken_ReturnsErrorString() using var cts = new CancellationTokenSource(); cts.Cancel(); - // Should return error string, not throw OperationCanceledException - var result = await service.DelegateAsync("bot", "do something", ct: cts.Token); + // Should return error tuple, not throw OperationCanceledException + var (text, isError) = await service.DelegateAsync("bot", "do something", ct: cts.Token); - result.ShouldNotBeNullOrEmpty(); - result.ShouldContain("Unknown agent"); + isError.ShouldBeTrue(); + text.ShouldNotBeNullOrEmpty(); + text.ShouldContain("Unknown agent"); } // -- IsTerminalState -------------------------------------------------- diff --git a/tests/clawsharp.Tests/Unit/A2a/A2aDelegateToolTests.cs b/tests/clawsharp.Tests/Unit/A2a/A2aDelegateToolTests.cs index 6a5c1cd8..c8b85c0f 100644 --- a/tests/clawsharp.Tests/Unit/A2a/A2aDelegateToolTests.cs +++ b/tests/clawsharp.Tests/Unit/A2a/A2aDelegateToolTests.cs @@ -153,7 +153,7 @@ public async Task ExecuteAsync_AboveDepthLimit_ReturnsDepthLimitError() [Test] public async Task ExecuteAsync_BelowDepthLimit_DoesNotReturnDepthError() { - // Without InitializeAsync, DelegateAsync returns "Unknown agent" — but NOT a depth error + // Without InitializeAsync, DelegateAsync returns "Unknown agent" -- but NOT a depth error var tool = CreateTool(depthLimit: 3); SetSpawnDepth(0); @@ -171,8 +171,8 @@ public void BuildDelegationMetadata_IncludesDepthPlusOne() { var metadata = A2aDelegateTool.BuildDelegationMetadata(currentDepth: 2, depthLimit: 5); - metadata.ShouldContainKey("clawsharp.delegation.depth"); - metadata["clawsharp.delegation.depth"].GetInt32().ShouldBe(3); + metadata.ShouldContainKey(A2aAttributes.MetaDepth); + metadata[A2aAttributes.MetaDepth].GetInt32().ShouldBe(3); } [Test] @@ -180,8 +180,8 @@ public void BuildDelegationMetadata_IncludesMaxDepth() { var metadata = A2aDelegateTool.BuildDelegationMetadata(currentDepth: 0, depthLimit: 5); - metadata.ShouldContainKey("clawsharp.delegation.maxDepth"); - metadata["clawsharp.delegation.maxDepth"].GetInt32().ShouldBe(5); + metadata.ShouldContainKey(A2aAttributes.MetaMaxDepth); + metadata[A2aAttributes.MetaMaxDepth].GetInt32().ShouldBe(5); } [Test] @@ -189,8 +189,8 @@ public void BuildDelegationMetadata_IncludesOriginInstance() { var metadata = A2aDelegateTool.BuildDelegationMetadata(currentDepth: 0, depthLimit: 3); - metadata.ShouldContainKey("clawsharp.delegation.originInstance"); - var origin = metadata["clawsharp.delegation.originInstance"].GetString(); + metadata.ShouldContainKey(A2aAttributes.MetaOriginInstance); + var origin = metadata[A2aAttributes.MetaOriginInstance].GetString(); origin.ShouldNotBeNullOrEmpty(); } @@ -199,8 +199,8 @@ public void BuildDelegationMetadata_IncludesChainId() { var metadata = A2aDelegateTool.BuildDelegationMetadata(currentDepth: 0, depthLimit: 3); - metadata.ShouldContainKey("clawsharp.delegation.chainId"); - var chainId = metadata["clawsharp.delegation.chainId"].GetString(); + metadata.ShouldContainKey(A2aAttributes.MetaChainId); + var chainId = metadata[A2aAttributes.MetaChainId].GetString(); chainId.ShouldNotBeNullOrEmpty(); } @@ -210,8 +210,8 @@ public void BuildDelegationMetadata_ChainIdIsUnique() var meta1 = A2aDelegateTool.BuildDelegationMetadata(currentDepth: 0, depthLimit: 3); var meta2 = A2aDelegateTool.BuildDelegationMetadata(currentDepth: 0, depthLimit: 3); - var id1 = meta1["clawsharp.delegation.chainId"].GetString(); - var id2 = meta2["clawsharp.delegation.chainId"].GetString(); + var id1 = meta1[A2aAttributes.MetaChainId].GetString(); + var id2 = meta2[A2aAttributes.MetaChainId].GetString(); id1.ShouldNotBe(id2); } diff --git a/tests/clawsharp.Tests/Unit/A2a/A2aServerWithPushTests.cs b/tests/clawsharp.Tests/Unit/A2a/A2aServerWithPushTests.cs index ee0d746e..6bf6b79f 100644 --- a/tests/clawsharp.Tests/Unit/A2a/A2aServerWithPushTests.cs +++ b/tests/clawsharp.Tests/Unit/A2a/A2aServerWithPushTests.cs @@ -1,6 +1,7 @@ using A2A; using Clawsharp.A2a; using Clawsharp.Config.Features; +using Clawsharp.Tests.Fakes; using Clawsharp.Webhooks; using Microsoft.Extensions.Logging; using NSubstitute; @@ -270,7 +271,7 @@ public async Task DeletePushConfig_RemovesConfig() } [Test] - public async Task DeletePushConfig_LastConfig_RemovesQueue() + public async Task DeletePushConfig_LastConfig_LeavesQueueForEviction() { var createReq = new CreateTaskPushNotificationConfigRequest { @@ -288,7 +289,9 @@ public async Task DeletePushConfig_LastConfig_RemovesQueue() }; await _sut.DeleteTaskPushNotificationConfigAsync(delReq, CancellationToken.None); - _queueRegistry.EndpointIds.ShouldNotContain("a2a-push:task-del-q"); + // Queue is NOT eagerly removed on last config delete — CleanupTask handles + // queue removal during task eviction to avoid TOCTOU races with concurrent Creates. + _queueRegistry.EndpointIds.ShouldContain("a2a-push:task-del-q"); } // ── Push Delivery Trigger ─────────────────────────────────────────────── @@ -463,19 +466,4 @@ private static WebhookJob CreateWebhookJob(string endpointId) return new WebhookJob(record, new WebhookEndpointConfig { Url = "https://example.com/test" }, endpointId, "{}"); } - private sealed class CapturingLogger(List<(LogLevel Level, string Message)> messages) : ILogger - { - public IDisposable? BeginScope(TState state) where TState : notnull => null; - public bool IsEnabled(LogLevel logLevel) => true; - - public void Log( - LogLevel logLevel, - EventId eventId, - TState state, - Exception? exception, - Func formatter) - { - messages.Add((logLevel, formatter(state, exception))); - } - } } diff --git a/tests/clawsharp.Tests/Unit/A2a/A2aTaskEvictionServiceTests.cs b/tests/clawsharp.Tests/Unit/A2a/A2aTaskEvictionServiceTests.cs index 5ef679f3..3413988b 100644 --- a/tests/clawsharp.Tests/Unit/A2a/A2aTaskEvictionServiceTests.cs +++ b/tests/clawsharp.Tests/Unit/A2a/A2aTaskEvictionServiceTests.cs @@ -1,5 +1,6 @@ using A2A; using Clawsharp.A2a; +using Clawsharp.Tests.Fakes; using Microsoft.Extensions.Logging; using TaskStatus = A2A.TaskStatus; @@ -229,21 +230,4 @@ public async Task EvictAsync_TtlEvictionRunsBeforeCapEviction() store.GetAllTasks().Count.ShouldBe(2); } - // ── Shared test infrastructure ─────────────────────────────────────────── - - private sealed class CapturingLogger(List<(LogLevel Level, string Message)> messages) : ILogger - { - public IDisposable? BeginScope(TState state) where TState : notnull => null; - public bool IsEnabled(LogLevel logLevel) => true; - - public void Log( - LogLevel logLevel, - EventId eventId, - TState state, - Exception? exception, - Func formatter) - { - messages.Add((logLevel, formatter(state, exception))); - } - } } diff --git a/tests/clawsharp.Tests/Unit/A2a/A2aTaskProcessorStreamingTests.cs b/tests/clawsharp.Tests/Unit/A2a/A2aTaskProcessorStreamingTests.cs index 77772a0e..de5803e1 100644 --- a/tests/clawsharp.Tests/Unit/A2a/A2aTaskProcessorStreamingTests.cs +++ b/tests/clawsharp.Tests/Unit/A2a/A2aTaskProcessorStreamingTests.cs @@ -11,6 +11,7 @@ using Clawsharp.McpServer; using Clawsharp.Organization; using Clawsharp.Providers; +using Clawsharp.Tests.Fakes; using Clawsharp.Tools; using Clawsharp.Webhooks; using Microsoft.AspNetCore.Http; @@ -796,23 +797,4 @@ public async Task SubscribeToTask_ReceivesLiveUpdatesAfterCatchUp() } } - // ═══════════════════════════════════════════════════════════════════════════ - // CapturingLogger for source-generated [LoggerMessage] testing - // ═══════════════════════════════════════════════════════════════════════════ - - private sealed class CapturingLogger(List<(LogLevel Level, string Message)> messages) : ILogger - { - public IDisposable? BeginScope(TState state) where TState : notnull => null; - public bool IsEnabled(LogLevel logLevel) => true; - - public void Log( - LogLevel logLevel, - EventId eventId, - TState state, - Exception? exception, - Func formatter) - { - messages.Add((logLevel, formatter(state, exception))); - } - } } diff --git a/tests/clawsharp.Tests/Unit/A2a/A2aTaskProcessorTests.cs b/tests/clawsharp.Tests/Unit/A2a/A2aTaskProcessorTests.cs index f4710584..2c8707b7 100644 --- a/tests/clawsharp.Tests/Unit/A2a/A2aTaskProcessorTests.cs +++ b/tests/clawsharp.Tests/Unit/A2a/A2aTaskProcessorTests.cs @@ -11,6 +11,7 @@ using Clawsharp.McpServer; using Clawsharp.Organization; using Clawsharp.Providers; +using Clawsharp.Tests.Fakes; using Clawsharp.Tools; using Microsoft.AspNetCore.Http; using Microsoft.Extensions.DependencyInjection; @@ -770,23 +771,4 @@ public void Implements_IDisposable() } } - // ═══════════════════════════════════════════════════════════════════════════ - // CapturingLogger for source-generated [LoggerMessage] testing - // ═══════════════════════════════════════════════════════════════════════════ - - private sealed class CapturingLogger(List<(LogLevel Level, string Message)> messages) : ILogger - { - public IDisposable? BeginScope(TState state) where TState : notnull => null; - public bool IsEnabled(LogLevel logLevel) => true; - - public void Log( - LogLevel logLevel, - EventId eventId, - TState state, - Exception? exception, - Func formatter) - { - messages.Add((logLevel, formatter(state, exception))); - } - } } diff --git a/tests/clawsharp.Tests/Unit/A2a/A2aTaskStoreTests.cs b/tests/clawsharp.Tests/Unit/A2a/A2aTaskStoreTests.cs index 6a95314a..22da0897 100644 --- a/tests/clawsharp.Tests/Unit/A2a/A2aTaskStoreTests.cs +++ b/tests/clawsharp.Tests/Unit/A2a/A2aTaskStoreTests.cs @@ -1,6 +1,7 @@ using System.Text.Json; using A2A; using Clawsharp.A2a; +using Clawsharp.Tests.Fakes; using Microsoft.Extensions.Logging; using TaskStatus = A2A.TaskStatus; @@ -228,24 +229,22 @@ public async Task DeleteTaskAsync_RemovesFromInMemoryDictionary() // ── State transition validation ────────────────────────────────────────── [Test] - public async Task SaveTaskAsync_LogsWarning_OnInvalidTransition_CompletedToWorking() + public async Task SaveTaskAsync_ThrowsOnInvalidTransition_CompletedToWorking() { var logMessages = new List<(LogLevel Level, string Message)>(); var logger = new CapturingLogger(logMessages); var store = new A2aTaskStore(_tempDir, logger); await store.SaveTaskAsync("t1", CreateTask("t1", TaskState.Completed)); - // Attempt invalid transition: COMPLETED -> WORKING - await store.SaveTaskAsync("t1", CreateTask("t1", TaskState.Working)); + // Attempt invalid transition: COMPLETED -> WORKING — now throws (L-10 enforcement) + var ex = Assert.ThrowsAsync( + () => store.SaveTaskAsync("t1", CreateTask("t1", TaskState.Working))); + ex!.Message.ShouldContain("Invalid A2A task state transition"); - // Task is still saved (never rejects) + // Task should remain in its original state (not overwritten) var result = await store.GetTaskAsync("t1"); result.ShouldNotBeNull(); - result!.Status!.State.ShouldBe(TaskState.Working); - - // Logger should have received a warning about invalid transition - logMessages.ShouldContain(m => - m.Level == LogLevel.Warning && m.Message.Contains("Invalid A2A task state transition")); + result!.Status!.State.ShouldBe(TaskState.Completed); } [Test] @@ -409,16 +408,16 @@ public async Task CompactAsync_RewritesJsonlWithCurrentEntriesOnly() await store.SaveTaskAsync("t3", CreateTask("t3")); await store.DeleteTaskAsync("t2"); - // JSONL still has 3 lines (append-only) + // JSONL has 4 lines: 3 saves + 1 delete tombstone (append-only) var filePath = Path.Combine(_tempDir, "tasks.jsonl"); var linesBefore = (await File.ReadAllLinesAsync(filePath)) .Where(l => !string.IsNullOrWhiteSpace(l)).ToArray(); - linesBefore.Length.ShouldBe(3); + linesBefore.Length.ShouldBe(4); // Compact await store.CompactAsync(); - // Now JSONL should have 2 lines + // Now JSONL should have 2 lines (only surviving tasks) var linesAfter = (await File.ReadAllLinesAsync(filePath)) .Where(l => !string.IsNullOrWhiteSpace(l)).ToArray(); linesAfter.Length.ShouldBe(2); @@ -443,26 +442,4 @@ public async Task GetAllTasks_ReturnsSnapshotOfAllEntries() all.Count.ShouldBe(2); } - // ── Test infrastructure ────────────────────────────────────────────────── - - /// - /// Minimal ILogger implementation that captures log messages for assertion. - /// Source-generated [LoggerMessage] methods call the raw Log method with - /// a generated state type, making NSubstitute matching unreliable. - /// - private sealed class CapturingLogger(List<(LogLevel Level, string Message)> messages) : ILogger - { - public IDisposable? BeginScope(TState state) where TState : notnull => null; - public bool IsEnabled(LogLevel logLevel) => true; - - public void Log( - LogLevel logLevel, - EventId eventId, - TState state, - Exception? exception, - Func formatter) - { - messages.Add((logLevel, formatter(state, exception))); - } - } } diff --git a/tests/clawsharp.Tests/Unit/Channels/AllowListPolicyTests.cs b/tests/clawsharp.Tests/Unit/Channels/AllowListPolicyTests.cs index 67704a33..b8a8b655 100644 --- a/tests/clawsharp.Tests/Unit/Channels/AllowListPolicyTests.cs +++ b/tests/clawsharp.Tests/Unit/Channels/AllowListPolicyTests.cs @@ -2,6 +2,7 @@ namespace Clawsharp.Tests.Unit.Channels; +[TestFixture] public sealed class AllowListPolicyTests { [Test] diff --git a/tests/clawsharp.Tests/DiscordChannelOptionsTests.cs b/tests/clawsharp.Tests/Unit/Channels/DiscordChannelOptionsTests.cs similarity index 98% rename from tests/clawsharp.Tests/DiscordChannelOptionsTests.cs rename to tests/clawsharp.Tests/Unit/Channels/DiscordChannelOptionsTests.cs index 63eb953e..55304d9a 100644 --- a/tests/clawsharp.Tests/DiscordChannelOptionsTests.cs +++ b/tests/clawsharp.Tests/Unit/Channels/DiscordChannelOptionsTests.cs @@ -1,8 +1,9 @@ using Clawsharp.Channels.Discord; using Clawsharp.Config.Channels; -namespace Clawsharp.Tests; +namespace Clawsharp.Tests.Unit.Channels; +[TestFixture] public sealed class DiscordChannelOptionsTests { [Test] diff --git a/tests/clawsharp.Tests/EmailSecurityTests.cs b/tests/clawsharp.Tests/Unit/Channels/EmailSecurityTests.cs similarity index 99% rename from tests/clawsharp.Tests/EmailSecurityTests.cs rename to tests/clawsharp.Tests/Unit/Channels/EmailSecurityTests.cs index cb20b427..a8e07612 100644 --- a/tests/clawsharp.Tests/EmailSecurityTests.cs +++ b/tests/clawsharp.Tests/Unit/Channels/EmailSecurityTests.cs @@ -1,4 +1,4 @@ -namespace Clawsharp.Tests; +namespace Clawsharp.Tests.Unit.Channels; /// /// Tests for Email channel security logic: quoted reply stripping, command prefix, @@ -6,6 +6,7 @@ namespace Clawsharp.Tests; /// Uses pattern replication approach — the relevant logic is replicated as local /// static methods from EmailChannel. /// +[TestFixture] public sealed class EmailSecurityTests { // Replicates EmailChannel.PollImapAsync quoted reply stripping logic diff --git a/tests/clawsharp.Tests/IrcSecurityTests.cs b/tests/clawsharp.Tests/Unit/Channels/IrcSecurityTests.cs similarity index 99% rename from tests/clawsharp.Tests/IrcSecurityTests.cs rename to tests/clawsharp.Tests/Unit/Channels/IrcSecurityTests.cs index ed9423cd..a8a07254 100644 --- a/tests/clawsharp.Tests/IrcSecurityTests.cs +++ b/tests/clawsharp.Tests/Unit/Channels/IrcSecurityTests.cs @@ -1,10 +1,11 @@ -namespace Clawsharp.Tests; +namespace Clawsharp.Tests.Unit.Channels; /// /// Tests for IRC channel security logic: nick allowlist, channel allowlist, directed-at-bot /// detection, and mention cleanup. /// Uses pattern replication approach from IrcChannel. /// +[TestFixture] public sealed class IrcSecurityTests { // Replicates IrcChannel allowlist initialization logic diff --git a/tests/clawsharp.Tests/MatrixMentionDetectionTests.cs b/tests/clawsharp.Tests/Unit/Channels/MatrixMentionDetectionTests.cs similarity index 97% rename from tests/clawsharp.Tests/MatrixMentionDetectionTests.cs rename to tests/clawsharp.Tests/Unit/Channels/MatrixMentionDetectionTests.cs index b3a30303..0d7def51 100644 --- a/tests/clawsharp.Tests/MatrixMentionDetectionTests.cs +++ b/tests/clawsharp.Tests/Unit/Channels/MatrixMentionDetectionTests.cs @@ -1,10 +1,11 @@ -namespace Clawsharp.Tests; +namespace Clawsharp.Tests.Unit.Channels; /// /// Tests for Matrix channel mention detection logic: localpart extraction from MXID /// and message body mention checking. /// Uses pattern replication approach from MatrixChannel.SyncOnceAsync. /// +[TestFixture] public sealed class MatrixMentionDetectionTests { // Replicates MatrixChannel localpart extraction from _selfId diff --git a/tests/clawsharp.Tests/Unit/Channels/QqChannelTests.cs b/tests/clawsharp.Tests/Unit/Channels/QqChannelTests.cs index 85cd01aa..6c98cac9 100644 --- a/tests/clawsharp.Tests/Unit/Channels/QqChannelTests.cs +++ b/tests/clawsharp.Tests/Unit/Channels/QqChannelTests.cs @@ -13,6 +13,7 @@ namespace Clawsharp.Tests.Unit.Channels; [FixtureLifeCycle(LifeCycle.InstancePerTestCase)] +[TestFixture] public sealed class QqChannelTests : IDisposable { private readonly CapturingMessageBus _bus = new(); diff --git a/tests/clawsharp.Tests/SlackMrkdwnTests.cs b/tests/clawsharp.Tests/Unit/Channels/SlackMrkdwnTests.cs similarity index 99% rename from tests/clawsharp.Tests/SlackMrkdwnTests.cs rename to tests/clawsharp.Tests/Unit/Channels/SlackMrkdwnTests.cs index eba08e73..60af3fe1 100644 --- a/tests/clawsharp.Tests/SlackMrkdwnTests.cs +++ b/tests/clawsharp.Tests/Unit/Channels/SlackMrkdwnTests.cs @@ -1,11 +1,12 @@ using Clawsharp.Channels.Slack; -namespace Clawsharp.Tests; +namespace Clawsharp.Tests.Unit.Channels; /// /// Tests for Slack mrkdwn conversion and empty-text guard in SlackChannel. /// Calls the internal static ConvertToMrkdwn method directly. /// +[TestFixture] public sealed class SlackMrkdwnTests { // ── ConvertToMrkdwn: bold ──────────────────────────────────────────────── diff --git a/tests/clawsharp.Tests/SlackSecurityTests.cs b/tests/clawsharp.Tests/Unit/Channels/SlackSecurityTests.cs similarity index 99% rename from tests/clawsharp.Tests/SlackSecurityTests.cs rename to tests/clawsharp.Tests/Unit/Channels/SlackSecurityTests.cs index 92d157b3..c02a3903 100644 --- a/tests/clawsharp.Tests/SlackSecurityTests.cs +++ b/tests/clawsharp.Tests/Unit/Channels/SlackSecurityTests.cs @@ -1,4 +1,4 @@ -namespace Clawsharp.Tests; +namespace Clawsharp.Tests.Unit.Channels; /// /// Tests for Slack channel security logic (allowlist, channel filter, mention requirement). @@ -6,6 +6,7 @@ namespace Clawsharp.Tests; /// is replicated as local static methods, tested directly without needing the full SlackChannel /// constructor dependency tree. /// +[TestFixture] public sealed class SlackSecurityTests { // Replicates SlackChannel allowlist initialization logic diff --git a/tests/clawsharp.Tests/TelegramAllowlistTests.cs b/tests/clawsharp.Tests/Unit/Channels/TelegramAllowlistTests.cs similarity index 99% rename from tests/clawsharp.Tests/TelegramAllowlistTests.cs rename to tests/clawsharp.Tests/Unit/Channels/TelegramAllowlistTests.cs index 7e45f999..1392dfd5 100644 --- a/tests/clawsharp.Tests/TelegramAllowlistTests.cs +++ b/tests/clawsharp.Tests/Unit/Channels/TelegramAllowlistTests.cs @@ -1,7 +1,7 @@ using System.Reflection; using Clawsharp.Channels.Telegram; -namespace Clawsharp.Tests; +namespace Clawsharp.Tests.Unit.Channels; /// /// Tests for Telegram channel allowlist logic: Normalize and IsUserAllowed. @@ -10,6 +10,7 @@ namespace Clawsharp.Tests; /// (which requires redirecting an initonly static field in ApprovedSendersStore). /// The replicated logic mirrors TelegramChannel's constructor + IsUserAllowed exactly. /// +[TestFixture] public sealed class TelegramAllowlistTests { // ── Normalize tests (private static method via reflection) ────── diff --git a/tests/clawsharp.Tests/Unit/Cli/AuditFilterTests.cs b/tests/clawsharp.Tests/Unit/Cli/AuditFilterTests.cs index 07ba70d0..257c09c6 100644 --- a/tests/clawsharp.Tests/Unit/Cli/AuditFilterTests.cs +++ b/tests/clawsharp.Tests/Unit/Cli/AuditFilterTests.cs @@ -4,6 +4,7 @@ namespace Clawsharp.Tests.Unit.Cli; /// Tests for audit search filter predicate logic. +[TestFixture] public sealed class AuditFilterTests { private static readonly DateTimeOffset March15 = new(2026, 3, 15, 12, 0, 0, TimeSpan.Zero); diff --git a/tests/clawsharp.Tests/Unit/Cli/ConfigSetCommandTests.cs b/tests/clawsharp.Tests/Unit/Cli/ConfigSetCommandTests.cs index 0a13cc13..cd4ddddf 100644 --- a/tests/clawsharp.Tests/Unit/Cli/ConfigSetCommandTests.cs +++ b/tests/clawsharp.Tests/Unit/Cli/ConfigSetCommandTests.cs @@ -3,6 +3,7 @@ namespace Clawsharp.Tests.Unit.Cli; /// Tests for type detection logic. +[TestFixture] public sealed class ConfigSetCommandTests { // ── Bool detection ────────────────────────────────────────────────────── diff --git a/tests/clawsharp.Tests/Unit/Cli/ConfigShowCommandTests.cs b/tests/clawsharp.Tests/Unit/Cli/ConfigShowCommandTests.cs index 09cb844e..7d914296 100644 --- a/tests/clawsharp.Tests/Unit/Cli/ConfigShowCommandTests.cs +++ b/tests/clawsharp.Tests/Unit/Cli/ConfigShowCommandTests.cs @@ -3,6 +3,7 @@ namespace Clawsharp.Tests.Unit.Cli; /// Tests for and . +[TestFixture] public sealed class ConfigShowCommandTests { // ── Redact ─────────────────────────────────────────────────────────────── diff --git a/tests/clawsharp.Tests/Unit/Cli/CostByUserCommandTests.cs b/tests/clawsharp.Tests/Unit/Cli/CostByUserCommandTests.cs index 9b5abdb3..e42a5ee4 100644 --- a/tests/clawsharp.Tests/Unit/Cli/CostByUserCommandTests.cs +++ b/tests/clawsharp.Tests/Unit/Cli/CostByUserCommandTests.cs @@ -5,6 +5,7 @@ namespace Clawsharp.Tests.Unit.Cli; /// Tests for cost aggregation by user and department with date filtering and budget status. +[TestFixture] public sealed class CostByUserCommandTests { private static readonly DateTimeOffset Today = new(2026, 3, 21, 12, 0, 0, TimeSpan.Zero); diff --git a/tests/clawsharp.Tests/Unit/Cli/MigrateCommandTests.cs b/tests/clawsharp.Tests/Unit/Cli/MigrateCommandTests.cs index 3fa7f5c7..8e754b08 100644 --- a/tests/clawsharp.Tests/Unit/Cli/MigrateCommandTests.cs +++ b/tests/clawsharp.Tests/Unit/Cli/MigrateCommandTests.cs @@ -4,6 +4,7 @@ namespace Clawsharp.Tests.Unit.Cli; /// Tests for pure helper functions: TOML parsing, GetNode, SetNode. +[TestFixture] public sealed class MigrateCommandTests { // ── ParseToml ──────────────────────────────────────────────────────────── diff --git a/tests/clawsharp.Tests/Unit/Compatibility/Compat01_ZeroOverheadTests.cs b/tests/clawsharp.Tests/Unit/Compatibility/Compat01_ZeroOverheadTests.cs index bd662927..fa693d3e 100644 --- a/tests/clawsharp.Tests/Unit/Compatibility/Compat01_ZeroOverheadTests.cs +++ b/tests/clawsharp.Tests/Unit/Compatibility/Compat01_ZeroOverheadTests.cs @@ -144,7 +144,7 @@ public void WithoutKnowledgeConfig_NoKnowledgeServicesRegistered() var appConfig = new AppConfig(); // Knowledge is null GatewayHost.RegisterKnowledgeStore(services, appConfig); - GatewayHost.RegisterDocumentLoaders(services, appConfig, new ConfigurationBuilder().Build()); + GatewayHost.RegisterDocumentLoaders(services, appConfig, new ConfigurationBuilder().Build(), []); GatewayHost.RegisterIngestionPipeline(services, appConfig); GatewayHost.RegisterReranker(services, appConfig); @@ -166,7 +166,7 @@ public void WithoutKnowledgeConfig_NoKnowledgeHostedServices() var appConfig = new AppConfig(); // Knowledge is null GatewayHost.RegisterKnowledgeStore(services, appConfig); - GatewayHost.RegisterDocumentLoaders(services, appConfig, new ConfigurationBuilder().Build()); + GatewayHost.RegisterDocumentLoaders(services, appConfig, new ConfigurationBuilder().Build(), []); GatewayHost.RegisterIngestionPipeline(services, appConfig); GatewayHost.RegisterReranker(services, appConfig); @@ -187,7 +187,7 @@ public void WithKnowledgeDisabled_NoKnowledgeServicesRegistered() }; GatewayHost.RegisterKnowledgeStore(services, appConfig); - GatewayHost.RegisterDocumentLoaders(services, appConfig, new ConfigurationBuilder().Build()); + GatewayHost.RegisterDocumentLoaders(services, appConfig, new ConfigurationBuilder().Build(), []); GatewayHost.RegisterIngestionPipeline(services, appConfig); GatewayHost.RegisterReranker(services, appConfig); diff --git a/tests/clawsharp.Tests/Unit/Compatibility/Compat02_CoexistenceTests.cs b/tests/clawsharp.Tests/Unit/Compatibility/Compat02_CoexistenceTests.cs index 7c332b1d..70f727fa 100644 --- a/tests/clawsharp.Tests/Unit/Compatibility/Compat02_CoexistenceTests.cs +++ b/tests/clawsharp.Tests/Unit/Compatibility/Compat02_CoexistenceTests.cs @@ -113,30 +113,44 @@ public void FourSubsystemCoexistence_NoServiceTypeConflicts() { var services = new ServiceCollection(); - // MCP server stubs + // MCP server stubs — complex constructors, registered as singletons services.AddSingleton(sp => null!); services.AddSingleton(sp => null!); services.AddSingleton(sp => null!); - // Webhook stubs + // Webhook stubs — WebhookMetrics has a simple constructor services.AddSingleton(sp => null!); services.AddSingleton(sp => null!); - services.AddSingleton(sp => null!); + services.AddSingleton(new WebhookMetrics(new WebhookConfig())); - // Knowledge stubs - services.AddSingleton(sp => null!); + // Knowledge — parameterless constructor + services.AddSingleton(new KnowledgeMetrics()); - // A2A stubs + // A2A — A2aMetrics has a parameterless constructor services.AddSingleton(sp => null!); services.AddSingleton(sp => null!); - services.AddSingleton(sp => null!); - - // All four subsystems have service descriptors — verify no type conflicts - services.Any(d => d.ServiceType == typeof(McpServerRouteRegistrar)).ShouldBeTrue(); - services.Any(d => d.ServiceType == typeof(WebhookDispatchService)).ShouldBeTrue(); - services.Any(d => d.ServiceType == typeof(KnowledgeMetrics)).ShouldBeTrue(); - services.Any(d => d.ServiceType == typeof(A2aRouteRegistrar)).ShouldBeTrue(); - services.Any(d => d.ServiceType == typeof(A2aMetrics)).ShouldBeTrue(); + services.AddSingleton(new A2aMetrics()); + + // Build provider and resolve the types with real constructors + using var provider = services.BuildServiceProvider(); + + // Types with real instances should resolve successfully + var webhookMetrics = provider.GetService(); + var knowledgeMetrics = provider.GetService(); + var a2aMetrics = provider.GetService(); + + Assert.Multiple(() => + { + webhookMetrics.ShouldNotBeNull(); + knowledgeMetrics.ShouldNotBeNull(); + a2aMetrics.ShouldNotBeNull(); + + // All four subsystems have service descriptors — verify no type conflicts + provider.GetService(); // resolves (null factory, but no conflict) + provider.GetService(); + provider.GetService(); + provider.GetService(); + }); } // ── Test 8: All IHttpRouteRegistrar implementations coexist ────────────── diff --git a/tests/clawsharp.Tests/AllowListConverterTests.cs b/tests/clawsharp.Tests/Unit/Config/AllowListConverterTests.cs similarity index 98% rename from tests/clawsharp.Tests/AllowListConverterTests.cs rename to tests/clawsharp.Tests/Unit/Config/AllowListConverterTests.cs index 5ad9e246..8f163ddf 100644 --- a/tests/clawsharp.Tests/AllowListConverterTests.cs +++ b/tests/clawsharp.Tests/Unit/Config/AllowListConverterTests.cs @@ -1,8 +1,9 @@ using System.Text.Json; using Clawsharp.Config.Channels; -namespace Clawsharp.Tests; +namespace Clawsharp.Tests.Unit.Config; +[TestFixture] public sealed class AllowListConverterTests { private static readonly JsonSerializerOptions Options = new() diff --git a/tests/clawsharp.Tests/ApprovedSendersStoreTests.cs b/tests/clawsharp.Tests/Unit/Config/ApprovedSendersStoreTests.cs similarity index 98% rename from tests/clawsharp.Tests/ApprovedSendersStoreTests.cs rename to tests/clawsharp.Tests/Unit/Config/ApprovedSendersStoreTests.cs index 716a4a3e..372af840 100644 --- a/tests/clawsharp.Tests/ApprovedSendersStoreTests.cs +++ b/tests/clawsharp.Tests/Unit/Config/ApprovedSendersStoreTests.cs @@ -1,6 +1,6 @@ using System.Text.Json; -namespace Clawsharp.Tests; +namespace Clawsharp.Tests.Unit.Config; /// /// Tests for ApprovedSendersStore logic: approval tracking, channel isolation, persistence, @@ -9,6 +9,7 @@ namespace Clawsharp.Tests; /// that cannot be redirected via reflection in .NET 10 (initonly enforcement). /// The replicated logic mirrors ApprovedSendersStore.LoadAsync/SaveAsync/IsApprovedAsync/AddAsync exactly. /// +[TestFixture] public sealed class ApprovedSendersStoreTests { private string _tempDir = null!; diff --git a/tests/clawsharp.Tests/Unit/Config/CachingConfigTests.cs b/tests/clawsharp.Tests/Unit/Config/CachingConfigTests.cs index 77cd484f..278d4871 100644 --- a/tests/clawsharp.Tests/Unit/Config/CachingConfigTests.cs +++ b/tests/clawsharp.Tests/Unit/Config/CachingConfigTests.cs @@ -8,6 +8,7 @@ namespace Clawsharp.Tests.Unit.Config; /// Pure unit tests for CachingConfig defaults and the AgentLoop-equivalent logic /// that maps config to ChatRequest caching flags. No I/O. /// +[TestFixture] public sealed class CachingConfigTests { // ── Default values ──────────────────────────────────────────────────────── diff --git a/tests/clawsharp.Tests/Unit/Config/ConfigKeyValidatorTests.cs b/tests/clawsharp.Tests/Unit/Config/ConfigKeyValidatorTests.cs index f501a145..879bb513 100644 --- a/tests/clawsharp.Tests/Unit/Config/ConfigKeyValidatorTests.cs +++ b/tests/clawsharp.Tests/Unit/Config/ConfigKeyValidatorTests.cs @@ -3,6 +3,7 @@ namespace Clawsharp.Tests.Unit.Config; /// Tests for dot-path validation. +[TestFixture] public sealed class ConfigKeyValidatorTests { // ── Valid fixed leaf paths ──────────────────────────────────────────────── diff --git a/tests/clawsharp.Tests/ConfigValidatorTests.cs b/tests/clawsharp.Tests/Unit/Config/ConfigValidatorTests.cs similarity index 99% rename from tests/clawsharp.Tests/ConfigValidatorTests.cs rename to tests/clawsharp.Tests/Unit/Config/ConfigValidatorTests.cs index 9edde641..2192b3ed 100644 --- a/tests/clawsharp.Tests/ConfigValidatorTests.cs +++ b/tests/clawsharp.Tests/Unit/Config/ConfigValidatorTests.cs @@ -3,8 +3,9 @@ using Clawsharp.Config.Channels; using Clawsharp.Config.Memory; -namespace Clawsharp.Tests; +namespace Clawsharp.Tests.Unit.Config; +[TestFixture] public sealed class ConfigValidatorTests { // Known Intellenum runtime issue: MemoryBackend.TryFromName always returns false diff --git a/tests/clawsharp.Tests/Unit/Config/ResilienceConfigTests.cs b/tests/clawsharp.Tests/Unit/Config/ResilienceConfigTests.cs index 0326430d..0c1f7370 100644 --- a/tests/clawsharp.Tests/Unit/Config/ResilienceConfigTests.cs +++ b/tests/clawsharp.Tests/Unit/Config/ResilienceConfigTests.cs @@ -3,6 +3,7 @@ namespace Clawsharp.Tests.Unit.Config; +[TestFixture] public sealed class ResilienceConfigTests { private static List Validate(object obj) diff --git a/tests/clawsharp.Tests/Unit/Config/WebhookFormatOnChannelValidationTests.cs b/tests/clawsharp.Tests/Unit/Config/WebhookFormatOnChannelValidationTests.cs index db6168a1..47742557 100644 --- a/tests/clawsharp.Tests/Unit/Config/WebhookFormatOnChannelValidationTests.cs +++ b/tests/clawsharp.Tests/Unit/Config/WebhookFormatOnChannelValidationTests.cs @@ -2,6 +2,7 @@ using Clawsharp.Config.Agent; using Clawsharp.Config.Channels; using Clawsharp.Config.Features; +using Clawsharp.Webhooks; using NUnit.Framework; using Shouldly; diff --git a/tests/clawsharp.Tests/Unit/Config/WebhookValidatorTests.cs b/tests/clawsharp.Tests/Unit/Config/WebhookValidatorTests.cs index 4f0f4423..7ddbeed4 100644 --- a/tests/clawsharp.Tests/Unit/Config/WebhookValidatorTests.cs +++ b/tests/clawsharp.Tests/Unit/Config/WebhookValidatorTests.cs @@ -11,6 +11,7 @@ namespace Clawsharp.Tests.Unit.Config; /// unknown categories, channel:// target validation, blank filters, and D-15 /// (duplicate endpoint key last-wins via System.Text.Json deserialization). /// +[TestFixture] public sealed class WebhookValidatorTests { // Known Intellenum runtime issue: MemoryBackend.TryFromName always returns false diff --git a/tests/clawsharp.Tests/Unit/Core/AgentStepExecutorStreamTests.cs b/tests/clawsharp.Tests/Unit/Core/AgentStepExecutorStreamTests.cs index 5f3ab760..42bbef75 100644 --- a/tests/clawsharp.Tests/Unit/Core/AgentStepExecutorStreamTests.cs +++ b/tests/clawsharp.Tests/Unit/Core/AgentStepExecutorStreamTests.cs @@ -162,11 +162,12 @@ public async Task StreamAsync_MultipleToolCalls_AllExecuted() var request = CreateRequest(); var events = await CollectEventsAsync(request, provider, tools); - // ToolStart(a), ToolResult(a), ToolStart(b), ToolResult(b), TextChunk("done"), UsageReport + // Multiple tool calls use batch-parallel execution: + // ToolStart(a), ToolStart(b), [parallel exec], ToolResult(a), ToolResult(b), TextChunk("done"), UsageReport events.Count.ShouldBe(6); events[0].ShouldBeOfType().ToolName.ShouldBe("tool_a"); - events[1].ShouldBeOfType().Result.ShouldBe("result_a"); - events[2].ShouldBeOfType().ToolName.ShouldBe("tool_b"); + events[1].ShouldBeOfType().ToolName.ShouldBe("tool_b"); + events[2].ShouldBeOfType().Result.ShouldBe("result_a"); events[3].ShouldBeOfType().Result.ShouldBe("result_b"); events[4].ShouldBeOfType().Text.ShouldBe("done"); events[5].ShouldBeOfType(); diff --git a/tests/clawsharp.Tests/Unit/Core/ComplexityScorerTests.cs b/tests/clawsharp.Tests/Unit/Core/ComplexityScorerTests.cs index 4f0fc229..1f3b9ba7 100644 --- a/tests/clawsharp.Tests/Unit/Core/ComplexityScorerTests.cs +++ b/tests/clawsharp.Tests/Unit/Core/ComplexityScorerTests.cs @@ -2,6 +2,7 @@ namespace Clawsharp.Tests.Unit.Core; +[TestFixture] public sealed class ComplexityScorerTests { [Test] diff --git a/tests/clawsharp.Tests/CronDurationParserTests.cs b/tests/clawsharp.Tests/Unit/Core/CronDurationParserTests.cs similarity index 96% rename from tests/clawsharp.Tests/CronDurationParserTests.cs rename to tests/clawsharp.Tests/Unit/Core/CronDurationParserTests.cs index 2e229f7e..fb99deb1 100644 --- a/tests/clawsharp.Tests/CronDurationParserTests.cs +++ b/tests/clawsharp.Tests/Unit/Core/CronDurationParserTests.cs @@ -1,8 +1,9 @@ using System.Reflection; using Clawsharp.Cron; -namespace Clawsharp.Tests; +namespace Clawsharp.Tests.Unit.Core; +[TestFixture] public sealed class CronDurationParserTests { private static bool InvokeTryParseDuration(string input, out long ms) diff --git a/tests/clawsharp.Tests/CronParserTests.cs b/tests/clawsharp.Tests/Unit/Core/CronParserTests.cs similarity index 99% rename from tests/clawsharp.Tests/CronParserTests.cs rename to tests/clawsharp.Tests/Unit/Core/CronParserTests.cs index b990a539..7e75ecf9 100644 --- a/tests/clawsharp.Tests/CronParserTests.cs +++ b/tests/clawsharp.Tests/Unit/Core/CronParserTests.cs @@ -1,7 +1,8 @@ using Clawsharp.Core.Services; -namespace Clawsharp.Tests; +namespace Clawsharp.Tests.Unit.Core; +[TestFixture] public sealed class CronParserTests { // Tuesday, March 3, 2026, 09:30:00 UTC diff --git a/tests/clawsharp.Tests/GoalStorageTests.cs b/tests/clawsharp.Tests/Unit/Core/GoalStorageTests.cs similarity index 98% rename from tests/clawsharp.Tests/GoalStorageTests.cs rename to tests/clawsharp.Tests/Unit/Core/GoalStorageTests.cs index 262ef59e..e98ded0b 100644 --- a/tests/clawsharp.Tests/GoalStorageTests.cs +++ b/tests/clawsharp.Tests/Unit/Core/GoalStorageTests.cs @@ -1,9 +1,10 @@ using Clawsharp.Goals; using Microsoft.Extensions.Logging.Abstractions; -namespace Clawsharp.Tests; +namespace Clawsharp.Tests.Unit.Core; [FixtureLifeCycle(LifeCycle.InstancePerTestCase)] +[TestFixture] public sealed class GoalStorageTests : IDisposable { private readonly string _tempDir; diff --git a/tests/clawsharp.Tests/GoalToolTests.cs b/tests/clawsharp.Tests/Unit/Core/GoalToolTests.cs similarity index 99% rename from tests/clawsharp.Tests/GoalToolTests.cs rename to tests/clawsharp.Tests/Unit/Core/GoalToolTests.cs index fe169a5f..a0c3f5bf 100644 --- a/tests/clawsharp.Tests/GoalToolTests.cs +++ b/tests/clawsharp.Tests/Unit/Core/GoalToolTests.cs @@ -3,9 +3,10 @@ using Clawsharp.Tools.Ops; using Microsoft.Extensions.Logging.Abstractions; -namespace Clawsharp.Tests; +namespace Clawsharp.Tests.Unit.Core; [FixtureLifeCycle(LifeCycle.InstancePerTestCase)] +[TestFixture] public sealed class GoalToolTests : IDisposable { private readonly string _tempDir; diff --git a/tests/clawsharp.Tests/Unit/Core/HeartbeatServiceTests.cs b/tests/clawsharp.Tests/Unit/Core/HeartbeatServiceTests.cs index 6419b2e4..34778ea9 100644 --- a/tests/clawsharp.Tests/Unit/Core/HeartbeatServiceTests.cs +++ b/tests/clawsharp.Tests/Unit/Core/HeartbeatServiceTests.cs @@ -134,6 +134,7 @@ public async Task Constructor_InvalidChannel_FallsBackToCli() // ── 6. Heartbeat deduplicates within the same minute ── [Test] + [Category("Slow")] public async Task ExecuteAsync_SameMinute_PublishesAtMostOncePerMinute() { var bus = new CapturingMessageBus(); @@ -145,8 +146,10 @@ public async Task ExecuteAsync_SameMinute_PublishesAtMostOncePerMinute() using var cts = new CancellationTokenSource(); await service.StartAsync(cts.Token); - // Wait long enough for multiple poll cycles (10s each). 25s gives 2+ polls. - await Task.Delay(TimeSpan.FromSeconds(25), CancellationToken.None); + // Wait long enough for multiple poll cycles (10s each). 22s gives 2+ polls. + // HeartbeatService uses DateTimeOffset.Now directly with no TimeProvider abstraction, + // so wall-clock waiting is required to verify the dedup-per-minute invariant. + await Task.Delay(TimeSpan.FromSeconds(22), CancellationToken.None); await cts.CancelAsync(); await service.StopAsync(CancellationToken.None); diff --git a/tests/clawsharp.Tests/RateLimiterEdgeCaseTests.cs b/tests/clawsharp.Tests/Unit/Core/RateLimiterEdgeCaseTests.cs similarity index 99% rename from tests/clawsharp.Tests/RateLimiterEdgeCaseTests.cs rename to tests/clawsharp.Tests/Unit/Core/RateLimiterEdgeCaseTests.cs index 9f169cc2..0dae0fe2 100644 --- a/tests/clawsharp.Tests/RateLimiterEdgeCaseTests.cs +++ b/tests/clawsharp.Tests/Unit/Core/RateLimiterEdgeCaseTests.cs @@ -3,7 +3,7 @@ using Microsoft.Extensions.Options; using Clawsharp.Config.Agent; -namespace Clawsharp.Tests; +namespace Clawsharp.Tests.Unit.Core; /// /// Edge-case tests for : diff --git a/tests/clawsharp.Tests/RateLimiterTests.cs b/tests/clawsharp.Tests/Unit/Core/RateLimiterTests.cs similarity index 99% rename from tests/clawsharp.Tests/RateLimiterTests.cs rename to tests/clawsharp.Tests/Unit/Core/RateLimiterTests.cs index d2b35c3e..cdd6a599 100644 --- a/tests/clawsharp.Tests/RateLimiterTests.cs +++ b/tests/clawsharp.Tests/Unit/Core/RateLimiterTests.cs @@ -3,8 +3,9 @@ using Microsoft.Extensions.Options; using Clawsharp.Config.Agent; -namespace Clawsharp.Tests; +namespace Clawsharp.Tests.Unit.Core; +[TestFixture] public sealed class RateLimiterTests { private static RateLimiter CreateLimiter(int maxRequests = 3, int windowSeconds = 60) diff --git a/tests/clawsharp.Tests/SessionPruneTests.cs b/tests/clawsharp.Tests/Unit/Core/SessionPruneTests.cs similarity index 97% rename from tests/clawsharp.Tests/SessionPruneTests.cs rename to tests/clawsharp.Tests/Unit/Core/SessionPruneTests.cs index a5eae74d..eb185ed0 100644 --- a/tests/clawsharp.Tests/SessionPruneTests.cs +++ b/tests/clawsharp.Tests/Unit/Core/SessionPruneTests.cs @@ -1,8 +1,9 @@ -namespace Clawsharp.Tests; +namespace Clawsharp.Tests.Unit.Core; -using Core; -using Core.Sessions; +using Clawsharp.Core; +using Clawsharp.Core.Sessions; +[TestFixture] public sealed class SessionPruneTests { private static Session CreateSession(params ChatMessage[] messages) diff --git a/tests/clawsharp.Tests/Unit/Core/SystemEventAttributeTests.cs b/tests/clawsharp.Tests/Unit/Core/SystemEventAttributeTests.cs index 49924546..e1612eba 100644 --- a/tests/clawsharp.Tests/Unit/Core/SystemEventAttributeTests.cs +++ b/tests/clawsharp.Tests/Unit/Core/SystemEventAttributeTests.cs @@ -7,6 +7,7 @@ namespace Clawsharp.Tests.Unit.Core; /// Tests that all 7 ISystemEvent types have the correct [EventType] wire names and categories, /// and that security event records are instantiable with required properties (EVT-04). /// +[TestFixture] public sealed class SystemEventAttributeTests { // ── Parameterized attribute verification — all 7 event types ──────────── diff --git a/tests/clawsharp.Tests/ToolValidatorTests.cs b/tests/clawsharp.Tests/Unit/Core/ToolValidatorTests.cs similarity index 99% rename from tests/clawsharp.Tests/ToolValidatorTests.cs rename to tests/clawsharp.Tests/Unit/Core/ToolValidatorTests.cs index 25e7f73a..d02220ff 100644 --- a/tests/clawsharp.Tests/ToolValidatorTests.cs +++ b/tests/clawsharp.Tests/Unit/Core/ToolValidatorTests.cs @@ -1,7 +1,7 @@ using System.Text.Json; using Clawsharp.Tools; -namespace Clawsharp.Tests; +namespace Clawsharp.Tests.Unit.Core; [TestFixture] public sealed class ToolValidatorTests diff --git a/tests/clawsharp.Tests/Unit/Cost/BudgetScopeTests.cs b/tests/clawsharp.Tests/Unit/Cost/BudgetScopeTests.cs index 0b735f5c..4b5ee242 100644 --- a/tests/clawsharp.Tests/Unit/Cost/BudgetScopeTests.cs +++ b/tests/clawsharp.Tests/Unit/Cost/BudgetScopeTests.cs @@ -6,6 +6,7 @@ namespace Clawsharp.Tests.Unit.Cost; +[TestFixture] public sealed class BudgetScopeTests : IDisposable { private readonly string _tempDir; @@ -172,8 +173,8 @@ public async Task Stacking_AllWithinLimits_UserAtWarning_ReturnsWarning_WithBoth await tracker.CheckBudgetAsync(0m); - // gpt-4o: $5/1M input, so 18000 tokens = $0.09 - await tracker.RecordUsageAsync("s1", "gpt-4o", 18_000, 0, userId: "alice", departmentId: "eng"); + // gpt-4o: $2.50/1M input, so 36000 tokens = $0.09 + await tracker.RecordUsageAsync("s1", "gpt-4o", 36_000, 0, userId: "alice", departmentId: "eng"); // User daily budget $0.10 — at 90% after $0.09 usage (above 80% threshold) var userBudget = new BudgetLimits { Daily = 0.10m, Monthly = 100.0m }; @@ -197,8 +198,8 @@ public async Task Stacking_UserAndDeptBothAtWarning_WarningsContainBoth() await tracker.CheckBudgetAsync(0m); - // Record $0.09 for both user alice and dept marketing - await tracker.RecordUsageAsync("s1", "gpt-4o", 18_000, 0, userId: "alice", departmentId: "marketing"); + // gpt-4o: $2.50/1M input, 36000 tokens = $0.09 + await tracker.RecordUsageAsync("s1", "gpt-4o", 36_000, 0, userId: "alice", departmentId: "marketing"); // Both budgets at $0.10 — $0.09 = 90% (above 80% threshold) var userBudget = new BudgetLimits { Daily = 0.10m, Monthly = 100.0m }; @@ -243,8 +244,8 @@ public async Task WarnAtPercent_Zero_FallsBackToGlobalCostConfigWarnAtPercent() await tracker.CheckBudgetAsync(0m); - // gpt-4o: $5/1M input, 18000 tokens = $0.09 - await tracker.RecordUsageAsync("s1", "gpt-4o", 18_000, 0, userId: "alice"); + // gpt-4o: $2.50/1M input, 36000 tokens = $0.09 + await tracker.RecordUsageAsync("s1", "gpt-4o", 36_000, 0, userId: "alice"); // User daily budget $0.10 with WarnAtPercent=0 (should use global 80%) // $0.09 / $0.10 = 90% >= 80% → warning @@ -292,22 +293,22 @@ public async Task RecordUsage_MultipleUsers_TracksPerScopeTotalsCorrectly() await tracker.CheckBudgetAsync(0m); - // Record usage for alice and bob + // Record usage for alice and bob (gpt-4o: $2.50/1M input) await tracker.RecordUsageAsync("s1", "gpt-4o", 10_000, 0, userId: "alice"); await tracker.RecordUsageAsync("s2", "gpt-4o", 20_000, 0, userId: "bob"); - // Alice's scope should only include her usage ($0.05), not bob's ($0.10) - var userBudgetAlice = new BudgetLimits { Daily = 0.06m, Monthly = 100.0m }; + // Alice's scope should only include her usage ($0.025), not bob's ($0.05) + var userBudgetAlice = new BudgetLimits { Daily = 0.03m, Monthly = 100.0m }; var resultAlice = await tracker.CheckBudgetAsync(0m, userId: "alice", userBudget: userBudgetAlice); - // Alice: $0.05 / $0.06 = 83% → Warning (80% threshold) + // Alice: $0.025 / $0.03 = 83% → Warning (80% threshold) resultAlice.UserBudget.ShouldNotBeNull(); - resultAlice.UserBudget.DailyUsed.ShouldBe(0.05m); + resultAlice.UserBudget.DailyUsed.ShouldBe(0.0250m); - // Bob's scope should only include his usage ($0.10) - var userBudgetBob = new BudgetLimits { Daily = 0.06m, Monthly = 100.0m }; + // Bob's scope should only include his usage ($0.05) + var userBudgetBob = new BudgetLimits { Daily = 0.03m, Monthly = 100.0m }; var resultBob = await tracker.CheckBudgetAsync(0m, userId: "bob", userBudget: userBudgetBob); resultBob.UserBudget.ShouldNotBeNull(); - resultBob.UserBudget.DailyUsed.ShouldBe(0.10m); + resultBob.UserBudget.DailyUsed.ShouldBe(0.0500m); resultBob.UserBudget.Status.ShouldBe(BudgetStatus.Exceeded); } } diff --git a/tests/clawsharp.Tests/Unit/Cost/CostRecordBackwardCompatTests.cs b/tests/clawsharp.Tests/Unit/Cost/CostRecordBackwardCompatTests.cs index 0101be3d..f2a2a452 100644 --- a/tests/clawsharp.Tests/Unit/Cost/CostRecordBackwardCompatTests.cs +++ b/tests/clawsharp.Tests/Unit/Cost/CostRecordBackwardCompatTests.cs @@ -3,6 +3,7 @@ namespace Clawsharp.Tests.Unit.Cost; +[TestFixture] public sealed class CostRecordBackwardCompatTests { [Test] diff --git a/tests/clawsharp.Tests/Unit/Cost/CostSimulationTests.cs b/tests/clawsharp.Tests/Unit/Cost/CostSimulationTests.cs index 7dbc190e..713e046f 100644 --- a/tests/clawsharp.Tests/Unit/Cost/CostSimulationTests.cs +++ b/tests/clawsharp.Tests/Unit/Cost/CostSimulationTests.cs @@ -5,6 +5,7 @@ namespace Clawsharp.Tests.Unit.Cost; /// /// Pure math scenario tests using DefaultPricing.CalculateCost. No I/O. /// +[TestFixture] public sealed class CostSimulationTests { [Test] @@ -59,13 +60,13 @@ public void CalculateCost_CheapVsExpensiveModel_PriceDifferenceOver100x() public void CalculateCost_MonthlyProjection_WithinReasonableRange() { // 20 requests/day x 30 days = 600 requests - // gpt-4o: $5/1M input, $15/1M output + // gpt-4o: $2.50/1M input, $10.00/1M output // avg 1000 input + 500 output per request var costPerRequest = DefaultPricing.CalculateCost("gpt-4o", 1000, 500); var monthlyCost = costPerRequest * 600; - // Should be in a reasonable range: $5-$50 - monthlyCost.ShouldBeGreaterThan(5.0m); + // Should be in a reasonable range: $3-$50 + monthlyCost.ShouldBeGreaterThan(3.0m); monthlyCost.ShouldBeLessThan(50.0m); } @@ -74,7 +75,7 @@ public void CalculateCost_AllKnownModels_ReturnsNonNegativeCost() { string[] knownModels = [ - "gpt-4o", "gpt-4o-mini", "o1-preview", "o3-mini", "gpt-4.1", "gpt-4.1-mini", "gpt-5.2", + "gpt-4o", "gpt-4o-mini", "o1", "o3", "o3-mini", "o4-mini", "gpt-4.1", "gpt-4.1-mini", "gpt-5", "gpt-5.2", "gpt-5.4", "gpt-5.4-mini", "claude-sonnet-4-6", "claude-opus-4-6", "claude-3-haiku", "gemini-2.0-flash", "gemini-2.5-pro", "gemini-2.5-flash", "deepseek-chat", "deepseek-reasoner", @@ -136,15 +137,15 @@ public void CalculateCostWithCaching_AnthropicCacheWrite_ChargedAt125Percent() [Test] public void CalculateCostWithCaching_OpenAiCacheRead_DiscountedAt50Percent() { - // gpt-4o: $5/1M input + // gpt-4o: $2.50/1M input // 1000 total prompt tokens, 500 served from cache, no output - // Regular (uncached) input = 500 tokens at $5/1M - // Cached input = 500 tokens at $5 * 0.50 / 1M + // Regular (uncached) input = 500 tokens at $2.50/1M + // Cached input = 500 tokens at $2.50 * 0.50 / 1M var (cost, savings) = DefaultPricing.CalculateCostWithCaching("gpt-4o", 1000, 0, cacheReadTokens: 500, cacheWriteTokens: 0); - var expected = (500m * 5.0m + 500m * 5.0m * 0.50m) / 1_000_000m; + var expected = (500m * 2.50m + 500m * 2.50m * 0.50m) / 1_000_000m; cost.ShouldBe(expected); - savings.ShouldBe(500m * 5.0m * 0.50m / 1_000_000m); + savings.ShouldBe(500m * 2.50m * 0.50m / 1_000_000m); } [Test] diff --git a/tests/clawsharp.Tests/Unit/Cost/CostStorageTests.cs b/tests/clawsharp.Tests/Unit/Cost/CostStorageTests.cs index 16465d6b..8bbc96df 100644 --- a/tests/clawsharp.Tests/Unit/Cost/CostStorageTests.cs +++ b/tests/clawsharp.Tests/Unit/Cost/CostStorageTests.cs @@ -2,6 +2,7 @@ namespace Clawsharp.Tests.Unit.Cost; +[TestFixture] public sealed class CostStorageTests : IDisposable { private readonly string _tempDir; diff --git a/tests/clawsharp.Tests/Unit/Cost/CostTrackerConcurrencyTests.cs b/tests/clawsharp.Tests/Unit/Cost/CostTrackerConcurrencyTests.cs index 43b6c477..47326d9a 100644 --- a/tests/clawsharp.Tests/Unit/Cost/CostTrackerConcurrencyTests.cs +++ b/tests/clawsharp.Tests/Unit/Cost/CostTrackerConcurrencyTests.cs @@ -6,6 +6,7 @@ namespace Clawsharp.Tests.Unit.Cost; +[TestFixture] public sealed class CostTrackerConcurrencyTests : IDisposable { private readonly string _tempDir; @@ -74,14 +75,14 @@ public async Task ConcurrentRecordUsage_MultipleUsers_AggregatesCorrectly() for (var call = 0; call < 10; call++) { var userId = $"user-{user}"; - // Each call: gpt-4o, 1000 input tokens = $0.005 + // Each call: gpt-4o, 1000 input tokens at $2.50/1M = $0.0025 tasks.Add(tracker.RecordUsageAsync($"s-{user}-{call}", "gpt-4o", 1_000, 0, userId: userId)); } } await Task.WhenAll(tasks); - // Check each user's scope — each should have $0.05 (10 * $0.005) + // Check each user's scope — each should have $0.025 (10 * $0.0025) for (var user = 0; user < 5; user++) { var userId = $"user-{user}"; @@ -89,11 +90,11 @@ public async Task ConcurrentRecordUsage_MultipleUsers_AggregatesCorrectly() var result = await tracker.CheckBudgetAsync(0m, userId: userId, userBudget: userBudget); result.UserBudget.ShouldNotBeNull(); - result.UserBudget.DailyUsed.ShouldBe(0.05m); + result.UserBudget.DailyUsed.ShouldBe(0.0250m); } - // Global total should be $0.25 (5 users * 10 calls * $0.005) + // Global total should be $0.125 (5 users * 10 calls * $0.0025) var summary = await tracker.GetSummaryAsync(); - summary.Daily.ShouldBe(0.25m); + summary.Daily.ShouldBe(0.1250m); } } diff --git a/tests/clawsharp.Tests/Unit/Cost/CostTrackerEdgeCaseTests.cs b/tests/clawsharp.Tests/Unit/Cost/CostTrackerEdgeCaseTests.cs index dcd4e36c..eed57730 100644 --- a/tests/clawsharp.Tests/Unit/Cost/CostTrackerEdgeCaseTests.cs +++ b/tests/clawsharp.Tests/Unit/Cost/CostTrackerEdgeCaseTests.cs @@ -58,17 +58,17 @@ public async Task RecordUsageAsync_NegativeInputTokens_OpenAiPath_ClampedToZero( await tracker.CheckBudgetAsync(estimatedCost: 0m); - // Record legitimate usage: gpt-4o $5/1M input -> 200K = $1.00 + // Record legitimate usage: gpt-4o $2.50/1M input -> 200K = $0.50 await tracker.RecordUsageAsync("s1", "gpt-4o", inputTokens: 200_000, outputTokens: 0); var summaryBefore = await tracker.GetSummaryAsync(); - summaryBefore.Daily.ShouldBe(1.0m); + summaryBefore.Daily.ShouldBe(0.5000m); // Negative input tokens on OpenAI path => Math.Max(0, -100K - 0) = 0 => cost = $0 await tracker.RecordUsageAsync("s1", "gpt-4o", inputTokens: -100_000, outputTokens: 0); var summaryAfter = await tracker.GetSummaryAsync(); - summaryAfter.Daily.ShouldBe(1.0m, + summaryAfter.Daily.ShouldBe(0.5000m, "Negative input tokens on OpenAI path are clamped to 0 by Math.Max"); } @@ -118,17 +118,17 @@ public async Task RecordUsageAsync_NegativeOutputTokens_KnownLimitation_ReducesB await tracker.CheckBudgetAsync(estimatedCost: 0m); - // gpt-4o $15/1M output -> 100K = $1.50 + // gpt-4o $10.00/1M output -> 100K = $1.00 await tracker.RecordUsageAsync("s1", "gpt-4o", inputTokens: 0, outputTokens: 100_000); var summaryBefore = await tracker.GetSummaryAsync(); - summaryBefore.Daily.ShouldBe(1.50m); + summaryBefore.Daily.ShouldBe(1.0000m); - // Negative output tokens: -50K * $15/1M = -$0.75 + // Negative output tokens: -50K * $10.00/1M = -$0.50 await tracker.RecordUsageAsync("s1", "gpt-4o", inputTokens: 0, outputTokens: -50_000); var summaryAfter = await tracker.GetSummaryAsync(); - summaryAfter.Daily.ShouldBe(0.75m, + summaryAfter.Daily.ShouldBe(0.50m, "Known limitation: negative output tokens reduce budget total"); } @@ -141,34 +141,35 @@ public async Task RecordUsageAsync_NegativeOutputTokens_KnownLimitation_CanBypas var tracker = CreateTracker(new CostConfig { Enabled = true, - DailyLimitUsd = 1.0m, + DailyLimitUsd = 0.50m, MonthlyLimitUsd = 100.0m, WarnAtPercent = 80 }); await tracker.CheckBudgetAsync(estimatedCost: 0m); - // Fill daily budget: gpt-4o output 100K * $15/1M = $1.50 + // Fill daily budget: gpt-4o output 100K * $10.00/1M = $1.00, exceeds $0.50 await tracker.RecordUsageAsync("s1", "gpt-4o", inputTokens: 0, outputTokens: 100_000); // Budget should be exceeded var result1 = await tracker.CheckBudgetAsync(estimatedCost: 0.01m); result1.Status.ShouldBe(BudgetStatus.Exceeded); - // Negative output tokens bring us back under budget: -100K * $15/1M = -$1.50 + // Negative output tokens bring us back under budget: -100K * $10.00/1M = -$1.00 await tracker.RecordUsageAsync("s1", "gpt-4o", inputTokens: 0, outputTokens: -100_000); - // Budget is no longer exceeded (daily total: $1.50 - $1.50 = $0.00) + // Budget is no longer exceeded (daily total: $1.00 - $1.00 = $0.00) var result2 = await tracker.CheckBudgetAsync(estimatedCost: 0.01m); result2.Status.ShouldBe(BudgetStatus.Allowed, "Known limitation: negative output tokens bypass daily budget enforcement"); } [Test] - public async Task RecordUsageAsync_NegativeCacheTokens_DoesNotProduceNegativeSavings() + public async Task RecordUsageAsync_NegativeCacheTokens_AllowsNegativeSavings() { - // CostTracker clamps savings to >= 0 via `if (savings > 0)`, but the underlying - // cost calculation may still produce unexpected results with negative cache tokens. + // Cache savings can be negative when write premiums exceed read discounts (Anthropic) + // or when negative cache token counts are supplied. CostTracker passes savings through + // without clamping so persistence is consistent with in-memory accumulators. var tracker = CreateTracker(); await tracker.CheckBudgetAsync(estimatedCost: 0m); @@ -178,10 +179,9 @@ await tracker.RecordUsageAsync("s1", "gpt-4o", var summary = await tracker.GetSummaryAsync("s1"); - // The CostTracker record's CacheSavingsUsd is clamped to 0 when savings is negative. - // (savings = -500 * 5 * 0.5 / 1M < 0 → clamped to 0) - summary.SessionSavings.ShouldBeGreaterThanOrEqualTo(0m, - "CostTracker clamps negative savings to 0"); + // savings = -500 * 5 * 0.5 / 1M < 0 — negative savings pass through unclamped + summary.SessionSavings.ShouldBeLessThan(0m, + "Negative cache savings should pass through unclamped"); } // ── Disabled mode ────────────────────────────────────────────────── @@ -252,10 +252,10 @@ public async Task CheckBudgetAsync_NegativeEstimatedCost_ReducesProjectedTotal() }); await tracker.CheckBudgetAsync(estimatedCost: 0m); - // Fill to near budget - await tracker.RecordUsageAsync("s1", "gpt-4o", inputTokens: 190_000, outputTokens: 0); + // Fill to near budget: gpt-4o $2.50/1M input -> 360_000 = $0.90 (90% of $1.00) + await tracker.RecordUsageAsync("s1", "gpt-4o", inputTokens: 360_000, outputTokens: 0); - // Budget warning at 80% ($0.80) — current is $0.95 + // Budget warning at 80% ($0.80) — current is $0.90 var resultPositive = await tracker.CheckBudgetAsync(estimatedCost: 0.01m); resultPositive.Status.ShouldBe(BudgetStatus.Warning); diff --git a/tests/clawsharp.Tests/Unit/Cost/CostTrackerTests.cs b/tests/clawsharp.Tests/Unit/Cost/CostTrackerTests.cs index a1b4d60c..fa20e3a1 100644 --- a/tests/clawsharp.Tests/Unit/Cost/CostTrackerTests.cs +++ b/tests/clawsharp.Tests/Unit/Cost/CostTrackerTests.cs @@ -5,6 +5,7 @@ namespace Clawsharp.Tests.Unit.Cost; +[TestFixture] public sealed class CostTrackerTests : IDisposable { private readonly string _tempDir; @@ -106,8 +107,8 @@ public async Task CheckBudgetAsync_ApproachingLimit_ReturnsWarning() // Prime the tracker (triggers EnsureInitializedAsync on empty storage). await tracker.CheckBudgetAsync(estimatedCost: 0m); - // gpt-4o: $5/1M input -> 180_000 tokens = $0.90 (90% of $1.00, above 80% threshold but below 100%) - await tracker.RecordUsageAsync("test-session", "gpt-4o", 180_000, 0); + // gpt-4o: $2.50/1M input -> 360_000 tokens = $0.90 (90% of $1.00, above 80% threshold but below 100%) + await tracker.RecordUsageAsync("test-session", "gpt-4o", 360_000, 0); var result = await tracker.CheckBudgetAsync(estimatedCost: 0m); @@ -130,14 +131,14 @@ public async Task GetSummaryAsync_MultipleCalls_ReturnsAccurateAggregation() var summary = await tracker.GetSummaryAsync("session-a"); - // gpt-4o: $5/1M input, $15/1M output - // Call 1: 1000*5/1M + 500*15/1M = 0.005 + 0.0075 = 0.0125 - // Call 2: 2000*5/1M + 1000*15/1M = 0.01 + 0.015 = 0.025 - // Total: 0.0375 - summary.Daily.ShouldBe(0.0375m); - summary.Monthly.ShouldBe(0.0375m); - // Session A only: 0.0125 - summary.Session.ShouldBe(0.0125m); + // gpt-4o: $2.50/1M input, $10.00/1M output + // Call 1: 1000*2.50/1M + 500*10.00/1M = 0.0025 + 0.005 = 0.0075 + // Call 2: 2000*2.50/1M + 1000*10.00/1M = 0.005 + 0.010 = 0.015 + // Total: 0.0225 + summary.Daily.ShouldBe(0.0225m); + summary.Monthly.ShouldBe(0.0225m); + // Session A only: 0.0075 + summary.Session.ShouldBe(0.0075m); } [Test] @@ -193,15 +194,15 @@ public async Task RecordUsageAsync_OpenAiCacheTokens_CalculatesSavingsCorrectly( // Prime the tracker (triggers initialization on empty storage). await tracker.CheckBudgetAsync(estimatedCost: 0m); - // gpt-4o: $5.00/1M input, $15.00/1M output + // gpt-4o: $2.50/1M input, $10.00/1M output // OpenAI cache: inputTokens is total (including cached); read=0.50x // inputTokens=1000 (total), outputTokens=200, cacheRead=800, cacheWrite=0 // // regularInput = max(0, 1000 - 800) = 200 - // inputCost = (200*5.00 + 800*5.00*0.50) / 1M = (1000 + 2000) / 1M = 0.003000 - // outputCost = 200*15.00 / 1M = 0.003000 - // totalCost = 0.006000 - // savings = 800*5.00*0.50 / 1M = 0.002000 + // inputCost = (200*2.50 + 800*2.50*0.50) / 1M = (500 + 1000) / 1M = 0.001500 + // outputCost = 200*10.00 / 1M = 0.002000 + // totalCost = 0.003500 + // savings = 800*2.50*0.50 / 1M = 0.001000 await tracker.RecordUsageAsync( "session-openai", "gpt-4o", inputTokens: 1000, outputTokens: 200, @@ -209,10 +210,10 @@ await tracker.RecordUsageAsync( var summary = await tracker.GetSummaryAsync("session-openai"); - summary.Daily.ShouldBe(0.006000m); - summary.DailySavings.ShouldBe(0.002000m); - summary.Session.ShouldBe(0.006000m); - summary.SessionSavings.ShouldBe(0.002000m); + summary.Daily.ShouldBe(0.003500m); + summary.DailySavings.ShouldBe(0.001000m); + summary.Session.ShouldBe(0.003500m); + summary.SessionSavings.ShouldBe(0.001000m); } [Test] @@ -235,11 +236,11 @@ await tracker.RecordUsageAsync( cacheReadTokens: 400, cacheWriteTokens: 100); // Session B: OpenAI model with cache tokens - // gpt-4o: $5.00/1M input, $15.00/1M output + // gpt-4o: $2.50/1M input, $10.00/1M output // regularInput = max(0, 2000 - 1500) = 500 - // cost = (500*5.00 + 1500*5.00*0.50)/1M + 300*15.00/1M - // = (2500 + 3750)/1M + 4500/1M = 6250/1M + 4500/1M = 0.010750 - // savings = 1500*5.00*0.50/1M = 0.003750 + // cost = (500*2.50 + 1500*2.50*0.50)/1M + 300*10.00/1M + // = (1250 + 1875)/1M + 3000/1M = 3125/1M + 3000/1M = 0.006125 + // savings = 1500*2.50*0.50/1M = 0.001875 await tracker.RecordUsageAsync( "session-B", "gpt-4o", inputTokens: 2000, outputTokens: 300, @@ -252,12 +253,12 @@ await tracker.RecordUsageAsync( // Check session-scoped summary for B only var summaryB = await tracker.GetSummaryAsync("session-B"); - summaryB.Session.ShouldBe(0.010750m); - summaryB.SessionSavings.ShouldBe(0.003750m); + summaryB.Session.ShouldBe(0.006125m); + summaryB.SessionSavings.ShouldBe(0.001875m); // Daily and monthly totals should include both sessions - var expectedDailyTotal = 0.003495m + 0.010750m; - var expectedDailySavings = 0.001005m + 0.003750m; + var expectedDailyTotal = 0.003495m + 0.006125m; + var expectedDailySavings = 0.001005m + 0.001875m; summaryA.Daily.ShouldBe(expectedDailyTotal); summaryA.Monthly.ShouldBe(expectedDailyTotal); summaryA.DailySavings.ShouldBe(expectedDailySavings); diff --git a/tests/clawsharp.Tests/Unit/Cost/DefaultPricingCachingTests.cs b/tests/clawsharp.Tests/Unit/Cost/DefaultPricingCachingTests.cs index 194beb62..dc442f6b 100644 --- a/tests/clawsharp.Tests/Unit/Cost/DefaultPricingCachingTests.cs +++ b/tests/clawsharp.Tests/Unit/Cost/DefaultPricingCachingTests.cs @@ -24,18 +24,18 @@ public void CalculateCostWithCaching_OpenAi_CacheReadExceedsInput_ClampedToZeroR savings.ShouldBeGreaterThanOrEqualTo(0m, "savings must never be negative on OpenAI path"); // Expected: regularInput = max(0, 100 - 200) = 0 - // inputCost = (0 * 5 + 200 * 5 * 0.5) / 1M = 0.0005 - // (gpt-4o: $5/1M input) - cost.ShouldBe(200m * 5m * 0.50m / 1_000_000m); + // inputCost = (0 * 2.50 + 200 * 2.50 * 0.5) / 1M = 0.00025 + // (gpt-4o: $2.50/1M input) + cost.ShouldBe(200m * 2.50m * 0.50m / 1_000_000m); } [Test] public void CalculateCostWithCaching_OpenAi_NormalCacheRead_PartialDiscount() { // 1000 total input, 400 cached, 600 regular - // gpt-4o: $5/1M - // inputCost = (600 * 5 + 400 * 5 * 0.5) / 1M = (3000 + 1000) / 1M = 0.004 - // savings = 400 * 5 * 0.5 / 1M = 0.001 + // gpt-4o: $2.50/1M + // inputCost = (600 * 2.50 + 400 * 2.50 * 0.5) / 1M = (1500 + 500) / 1M = 0.002 + // savings = 400 * 2.50 * 0.5 / 1M = 0.0005 var (cost, savings) = DefaultPricing.CalculateCostWithCaching( "gpt-4o", inputTokens: 1000, @@ -43,8 +43,8 @@ public void CalculateCostWithCaching_OpenAi_NormalCacheRead_PartialDiscount() cacheReadTokens: 400, cacheWriteTokens: 0); - cost.ShouldBe(0.004m, 0.0001m); - savings.ShouldBe(0.001m, 0.0001m); + cost.ShouldBe(0.002m, 0.0001m); + savings.ShouldBe(0.0005m, 0.0001m); } // ── Anthropic caching path ───────────────────────────────────────── @@ -149,7 +149,7 @@ public void CalculateCostWithCaching_UnknownModel_NoOverride_ReturnsBothZero() [Test] public void CalculateCostWithCaching_OutputTokensOnly_CostFromOutputAlone() { - // gpt-4o: $15/1M output + // gpt-4o: $10.00/1M output var (cost, _) = DefaultPricing.CalculateCostWithCaching( "gpt-4o", inputTokens: 0, @@ -157,7 +157,7 @@ public void CalculateCostWithCaching_OutputTokensOnly_CostFromOutputAlone() cacheReadTokens: 0, cacheWriteTokens: 0); - cost.ShouldBe(15m, 0.001m); + cost.ShouldBe(10.00m, 0.001m); } // ── Anthropic dot-notation normalization ────────────────────────── diff --git a/tests/clawsharp.Tests/Unit/Cost/DefaultPricingEdgeCaseTests.cs b/tests/clawsharp.Tests/Unit/Cost/DefaultPricingEdgeCaseTests.cs index 8e33d61f..e0cd5682 100644 --- a/tests/clawsharp.Tests/Unit/Cost/DefaultPricingEdgeCaseTests.cs +++ b/tests/clawsharp.Tests/Unit/Cost/DefaultPricingEdgeCaseTests.cs @@ -69,20 +69,20 @@ public void CalculateCost_ZeroCustomPricing_ReturnsZero() public void CalculateCost_NegativeInputTokens_KnownLimitation_ProducesNegativeCost() { // Known limitation: negative token counts are not validated. - // gpt-4o: $5/1M input + // gpt-4o: $2.50/1M input var cost = DefaultPricing.CalculateCost("gpt-4o", -1_000_000, 0); - cost.ShouldBe(-5.0m, + cost.ShouldBe(-2.50m, "Known limitation: negative input tokens produce negative cost"); } [Test] public void CalculateCost_NegativeOutputTokens_KnownLimitation_ProducesNegativeCost() { - // gpt-4o: $15/1M output + // gpt-4o: $10.00/1M output var cost = DefaultPricing.CalculateCost("gpt-4o", 0, -1_000_000); - cost.ShouldBe(-15.0m, + cost.ShouldBe(-10.00m, "Known limitation: negative output tokens produce negative cost"); } @@ -135,7 +135,7 @@ public void CalculateCost_OverrideForOneModel_DoesNotAffectOtherModels() public void CalculateCost_NullOverrides_UsesBuiltInPricing() { var cost = DefaultPricing.CalculateCost("gpt-4o", 1_000_000, 0, null); - cost.ShouldBe(5.0m, "Null overrides should fall through to built-in pricing"); + cost.ShouldBe(2.50m, "Null overrides should fall through to built-in pricing"); } [Test] @@ -149,6 +149,6 @@ public void CalculateCostWithCaching_NullOverrides_UsesBuiltInPricing() cacheWriteTokens: 0, overrides: null); - cost.ShouldBe(5.0m); + cost.ShouldBe(2.50m); } } diff --git a/tests/clawsharp.Tests/Unit/Cost/DefaultPricingTests.cs b/tests/clawsharp.Tests/Unit/Cost/DefaultPricingTests.cs index 4907278f..4d8a2fa6 100644 --- a/tests/clawsharp.Tests/Unit/Cost/DefaultPricingTests.cs +++ b/tests/clawsharp.Tests/Unit/Cost/DefaultPricingTests.cs @@ -3,15 +3,16 @@ namespace Clawsharp.Tests.Unit.Cost; +[TestFixture] public sealed class DefaultPricingTests { [Test] public void CalculateCost_KnownModel_ReturnsCorrectCost() { - // gpt-4o: $5/1M input, $15/1M output - // 1000 input + 500 output = $0.005 + $0.0075 = $0.0125 + // gpt-4o: $2.50/1M input, $10.00/1M output + // 1000 input + 500 output = $0.0025 + $0.005 = $0.0075 var cost = DefaultPricing.CalculateCost("gpt-4o", 1000, 500); - cost.ShouldBe(0.0125m); + cost.ShouldBe(0.0075m); } [Test] @@ -90,7 +91,7 @@ public void GetPrice_AllKnownModels_ReturnsNonNegative() // Spot-check a representative set of known models string[] knownModels = [ - "gpt-4o", "gpt-4o-mini", "o1-preview", "o3-mini", + "gpt-4o", "gpt-4o-mini", "o3", "o3-mini", "o4-mini", "gpt-5.4", "gpt-5.4-mini", "claude-sonnet-4-6", "claude-opus-4-6", "claude-3-haiku", "gemini-2.0-flash", "gemini-2.5-pro", "deepseek-chat", "deepseek-reasoner", diff --git a/tests/clawsharp.Tests/Unit/Features/SiblingSyncTests.cs b/tests/clawsharp.Tests/Unit/Features/SiblingSyncTests.cs index e6fde943..c6589f5c 100644 --- a/tests/clawsharp.Tests/Unit/Features/SiblingSyncTests.cs +++ b/tests/clawsharp.Tests/Unit/Features/SiblingSyncTests.cs @@ -10,6 +10,7 @@ namespace Clawsharp.Tests.Unit.Features; /// Tests for the sibling sync features: /model slash command, ExtraHeaders, /// ApiKeys round-robin rotation, and SpawnTimeout configuration. /// +[TestFixture] public sealed class SiblingSyncTests { // ────────────────────────────────────────────────────────────────────── diff --git a/tests/clawsharp.Tests/Unit/Features/WebhookConfigTests.cs b/tests/clawsharp.Tests/Unit/Features/WebhookConfigTests.cs index 5dd3d590..b46e4774 100644 --- a/tests/clawsharp.Tests/Unit/Features/WebhookConfigTests.cs +++ b/tests/clawsharp.Tests/Unit/Features/WebhookConfigTests.cs @@ -9,6 +9,7 @@ namespace Clawsharp.Tests.Unit.Features; /// Serialization round-trip tests for WebhookConfig and WebhookEndpointConfig via /// source-generated ConfigJsonContext (no reflection). Covers EVT-01, EVT-02, EVT-03. /// +[TestFixture] public sealed class WebhookConfigTests { // ── Full shape deserialization ─────────────────────────────────────────── diff --git a/tests/clawsharp.Tests/Unit/McpServer/ChannelNameMcpTests.cs b/tests/clawsharp.Tests/Unit/McpServer/ChannelNameMcpTests.cs index 960f35c7..ac6f3cc5 100644 --- a/tests/clawsharp.Tests/Unit/McpServer/ChannelNameMcpTests.cs +++ b/tests/clawsharp.Tests/Unit/McpServer/ChannelNameMcpTests.cs @@ -5,6 +5,7 @@ namespace Clawsharp.Tests.Unit.McpServer; /// /// Unit tests for the ChannelName.Mcp Intellenum value (CHAN-01). /// +[TestFixture] public sealed class ChannelNameMcpTests { [Test] diff --git a/tests/clawsharp.Tests/Unit/McpServer/McpServerAuthenticatorTests.cs b/tests/clawsharp.Tests/Unit/McpServer/McpServerAuthenticatorTests.cs index c3f08687..030e78cc 100644 --- a/tests/clawsharp.Tests/Unit/McpServer/McpServerAuthenticatorTests.cs +++ b/tests/clawsharp.Tests/Unit/McpServer/McpServerAuthenticatorTests.cs @@ -64,8 +64,7 @@ private static McpServerAuthenticator CreateAuthenticator( return new McpServerAuthenticator( config, - apiKeyAuthenticator, - NullLogger.Instance); + apiKeyAuthenticator); } // ── Valid API key -> resolves to OrgUser ────────────────────────────── @@ -91,7 +90,6 @@ public async Task AuthenticateAsync_ValidApiKey_ResolvesToOrgUser() result.User.ShouldNotBeNull(); result.User!.Name.ShouldBe("alice"); result.KeyId.ShouldBe("cursor-key"); - result.IsOriginDenied.ShouldBeFalse(); }); } @@ -166,7 +164,6 @@ public async Task AuthenticateAsync_InvalidApiKey_ReturnsUnauthenticated() result.IsAuthenticated.ShouldBeFalse(); result.User.ShouldBeNull(); result.KeyId.ShouldBeNull(); - result.IsOriginDenied.ShouldBeFalse(); }); } @@ -350,7 +347,6 @@ public void McpServerAuthResult_Unauthenticated_HasCorrectDefaults() result.User.ShouldBeNull(); result.PolicyDecision.ShouldBe(PolicyDecision.Unrestricted); result.KeyId.ShouldBeNull(); - result.IsOriginDenied.ShouldBeFalse(); }); } @@ -368,21 +364,6 @@ public void McpServerAuthResult_Success_HasAllFieldsSet() result.User.ShouldBe(user); result.PolicyDecision.ShouldBe(policy); result.KeyId.ShouldBe("my-key"); - result.IsOriginDenied.ShouldBeFalse(); - }); - } - - [Test] - public void McpServerAuthResult_OriginDenied_HasCorrectValues() - { - var result = McpServerAuthResult.OriginDenied(); - - Assert.Multiple(() => - { - result.IsAuthenticated.ShouldBeFalse(); - result.IsOriginDenied.ShouldBeTrue(); - result.User.ShouldBeNull(); - result.KeyId.ShouldBeNull(); }); } } diff --git a/tests/clawsharp.Tests/Unit/McpServer/McpServerDtoTests.cs b/tests/clawsharp.Tests/Unit/McpServer/McpServerDtoTests.cs deleted file mode 100644 index 3c1ec672..00000000 --- a/tests/clawsharp.Tests/Unit/McpServer/McpServerDtoTests.cs +++ /dev/null @@ -1,197 +0,0 @@ -using System.Text.Json; -using Clawsharp.Tools.Mcp; - -namespace Clawsharp.Tests.Unit.McpServer; - -/// -/// Unit tests for server-side MCP DTOs and McpJsonContext serialization (SDK-03). -/// -public sealed class McpServerDtoTests -{ - // ── McpInitializeResult ───────────────────────────────────────────────── - - [Test] - public void McpInitializeResult_DefaultValues_AreCorrect() - { - var result = new McpInitializeResult(); - result.ProtocolVersion.ShouldBe("2025-03-26"); - result.Capabilities.ShouldNotBeNull(); - result.ServerInfo.ShouldNotBeNull(); - result.ServerInfo.Name.ShouldBe("clawsharp"); - result.Instructions.ShouldBeNull(); - } - - [Test] - public void McpInitializeResult_Serializes_ToCamelCase() - { - var result = new McpInitializeResult(); - var json = JsonSerializer.Serialize(result, McpJsonContext.Default.McpInitializeResult); - - using var doc = JsonDocument.Parse(json); - var root = doc.RootElement; - - root.TryGetProperty("protocolVersion", out _).ShouldBeTrue(); - root.TryGetProperty("capabilities", out _).ShouldBeTrue(); - root.TryGetProperty("serverInfo", out _).ShouldBeTrue(); - - // PascalCase properties should NOT be present - root.TryGetProperty("ProtocolVersion", out _).ShouldBeFalse(); - root.TryGetProperty("Capabilities", out _).ShouldBeFalse(); - root.TryGetProperty("ServerInfo", out _).ShouldBeFalse(); - } - - [Test] - public void McpInitializeResult_NullInstructions_OmittedFromJson() - { - var result = new McpInitializeResult(); - var json = JsonSerializer.Serialize(result, McpJsonContext.Default.McpInitializeResult); - - using var doc = JsonDocument.Parse(json); - doc.RootElement.TryGetProperty("instructions", out _).ShouldBeFalse(); - } - - [Test] - public void McpInitializeResult_WithInstructions_IncludedInJson() - { - var result = new McpInitializeResult { Instructions = "Use these tools carefully." }; - var json = JsonSerializer.Serialize(result, McpJsonContext.Default.McpInitializeResult); - - using var doc = JsonDocument.Parse(json); - doc.RootElement.TryGetProperty("instructions", out var instructions).ShouldBeTrue(); - instructions.GetString().ShouldBe("Use these tools carefully."); - } - - [Test] - public void McpInitializeResult_RoundTrip_PreservesValues() - { - var original = new McpInitializeResult - { - ProtocolVersion = "2025-03-26", - Capabilities = new McpServerCapabilities - { - Tools = new McpToolsCapability { ListChanged = true } - }, - ServerInfo = new McpServerInfo { Name = "clawsharp", Version = "2.2.0" }, - Instructions = "Test instructions" - }; - - var json = JsonSerializer.Serialize(original, McpJsonContext.Default.McpInitializeResult); - var deserialized = JsonSerializer.Deserialize(json, McpJsonContext.Default.McpInitializeResult); - - deserialized.ShouldNotBeNull(); - deserialized!.ProtocolVersion.ShouldBe("2025-03-26"); - deserialized.Capabilities.ShouldNotBeNull(); - deserialized.Capabilities.Tools.ShouldNotBeNull(); - deserialized.Capabilities.Tools!.ListChanged.ShouldBeTrue(); - deserialized.ServerInfo.Name.ShouldBe("clawsharp"); - deserialized.ServerInfo.Version.ShouldBe("2.2.0"); - deserialized.Instructions.ShouldBe("Test instructions"); - } - - // ── McpServerInfo ─────────────────────────────────────────────────────── - - [Test] - public void McpServerInfo_DefaultName_IsClawsharp() - { - var info = new McpServerInfo(); - info.Name.ShouldBe("clawsharp"); - } - - [Test] - public void McpServerInfo_NullVersion_OmittedFromJson() - { - var info = new McpServerInfo(); - var json = JsonSerializer.Serialize(info, McpJsonContext.Default.McpServerInfo); - - using var doc = JsonDocument.Parse(json); - doc.RootElement.TryGetProperty("version", out _).ShouldBeFalse(); - doc.RootElement.TryGetProperty("name", out var name).ShouldBeTrue(); - name.GetString().ShouldBe("clawsharp"); - } - - // ── McpServerCapabilities ─────────────────────────────────────────────── - - [Test] - public void McpServerCapabilities_NullTools_OmittedFromJson() - { - var caps = new McpServerCapabilities(); - var json = JsonSerializer.Serialize(caps, McpJsonContext.Default.McpServerCapabilities); - - using var doc = JsonDocument.Parse(json); - doc.RootElement.TryGetProperty("tools", out _).ShouldBeFalse(); - } - - [Test] - public void McpServerCapabilities_WithTools_IncludesListChanged() - { - var caps = new McpServerCapabilities - { - Tools = new McpToolsCapability { ListChanged = true } - }; - var json = JsonSerializer.Serialize(caps, McpJsonContext.Default.McpServerCapabilities); - - using var doc = JsonDocument.Parse(json); - doc.RootElement.TryGetProperty("tools", out var tools).ShouldBeTrue(); - tools.TryGetProperty("listChanged", out var listChanged).ShouldBeTrue(); - listChanged.GetBoolean().ShouldBeTrue(); - } - - // ── McpToolAnnotations ────────────────────────────────────────────────── - - [Test] - public void McpToolAnnotations_AllNull_EmptyJsonObject() - { - var annotations = new McpToolAnnotations(); - var json = JsonSerializer.Serialize(annotations, McpJsonContext.Default.McpToolAnnotations); - - using var doc = JsonDocument.Parse(json); - doc.RootElement.EnumerateObject().Count().ShouldBe(0); - } - - [Test] - public void McpToolAnnotations_WithValues_SerializesCorrectCamelCase() - { - var annotations = new McpToolAnnotations - { - ReadOnlyHint = true, - DestructiveHint = false, - IdempotentHint = true, - OpenWorldHint = false - }; - var json = JsonSerializer.Serialize(annotations, McpJsonContext.Default.McpToolAnnotations); - - using var doc = JsonDocument.Parse(json); - var root = doc.RootElement; - - root.TryGetProperty("readOnlyHint", out var readOnly).ShouldBeTrue(); - readOnly.GetBoolean().ShouldBeTrue(); - - root.TryGetProperty("destructiveHint", out var destructive).ShouldBeTrue(); - destructive.GetBoolean().ShouldBeFalse(); - - root.TryGetProperty("idempotentHint", out var idempotent).ShouldBeTrue(); - idempotent.GetBoolean().ShouldBeTrue(); - - root.TryGetProperty("openWorldHint", out var openWorld).ShouldBeTrue(); - openWorld.GetBoolean().ShouldBeFalse(); - } - - [Test] - public void McpToolAnnotations_PartialValues_OnlySetOnesPresent() - { - var annotations = new McpToolAnnotations - { - ReadOnlyHint = true - // others null - }; - var json = JsonSerializer.Serialize(annotations, McpJsonContext.Default.McpToolAnnotations); - - using var doc = JsonDocument.Parse(json); - var root = doc.RootElement; - - root.TryGetProperty("readOnlyHint", out _).ShouldBeTrue(); - root.TryGetProperty("destructiveHint", out _).ShouldBeFalse(); - root.TryGetProperty("idempotentHint", out _).ShouldBeFalse(); - root.TryGetProperty("openWorldHint", out _).ShouldBeFalse(); - } -} diff --git a/tests/clawsharp.Tests/Unit/McpServer/McpServerModeConfigTests.cs b/tests/clawsharp.Tests/Unit/McpServer/McpServerModeConfigTests.cs index a857fd3c..7517bb07 100644 --- a/tests/clawsharp.Tests/Unit/McpServer/McpServerModeConfigTests.cs +++ b/tests/clawsharp.Tests/Unit/McpServer/McpServerModeConfigTests.cs @@ -10,6 +10,7 @@ namespace Clawsharp.Tests.Unit.McpServer; /// Unit tests for McpServerModeConfig serialization/deserialization (AUTH-05) /// and ConfigValidator rules for the mcpServer config section. /// +[TestFixture] public sealed class McpServerModeConfigTests { /// Creates a minimal valid AppConfig for validation to pass non-MCP checks. diff --git a/tests/clawsharp.Tests/Unit/McpServer/McpServerRouteRegistrarTests.cs b/tests/clawsharp.Tests/Unit/McpServer/McpServerRouteRegistrarTests.cs index cecee9be..7cd449a6 100644 --- a/tests/clawsharp.Tests/Unit/McpServer/McpServerRouteRegistrarTests.cs +++ b/tests/clawsharp.Tests/Unit/McpServer/McpServerRouteRegistrarTests.cs @@ -48,8 +48,7 @@ private static McpServerAuthenticator CreateAuthenticator( NullLogger.Instance); return new McpServerAuthenticator( config, - apiKeyAuth, - NullLogger.Instance); + apiKeyAuth); } [SetUp] @@ -98,10 +97,10 @@ public async Task ConfigureSessionAsync_AuthenticatedRequest_PopulatesToolCollec mcpOptions.ToolCollection!.Count.ShouldBe(2); } - // ── ConfigureSessionAsync: Unauthenticated request throws ── + // ── ConfigureSessionAsync: Unauthenticated request throws with 401 ── [Test] - public async Task ConfigureSessionAsync_UnauthenticatedRequest_ThrowsUnauthorizedAccessException() + public async Task ConfigureSessionAsync_UnauthenticatedRequest_Returns401AndThrows() { // Arrange: authenticator that requires auth (empty API keys dict = all rejected) var authConfig = new McpServerModeConfig @@ -119,14 +118,15 @@ public async Task ConfigureSessionAsync_UnauthenticatedRequest_ThrowsUnauthorize var mcpOptions = new McpServerOptions(); // Act & Assert - await Should.ThrowAsync(async () => + await Should.ThrowAsync(async () => await registrar.ConfigureSessionAsync(httpContext, mcpOptions, CancellationToken.None)); + httpContext.Response.StatusCode.ShouldBe(StatusCodes.Status401Unauthorized); } - // ── ConfigureSessionAsync: Denied origin throws ── + // ── ConfigureSessionAsync: Denied origin throws with 403 ── [Test] - public async Task ConfigureSessionAsync_DeniedOrigin_ThrowsUnauthorizedAccessException() + public async Task ConfigureSessionAsync_DeniedOrigin_Returns403AndThrows() { // Arrange: authenticator with null allowedOrigins = deny all external origins var authConfig = new McpServerModeConfig @@ -144,8 +144,9 @@ public async Task ConfigureSessionAsync_DeniedOrigin_ThrowsUnauthorizedAccessExc var mcpOptions = new McpServerOptions(); // Act & Assert - await Should.ThrowAsync(async () => + await Should.ThrowAsync(async () => await registrar.ConfigureSessionAsync(httpContext, mcpOptions, CancellationToken.None)); + httpContext.Response.StatusCode.ShouldBe(StatusCodes.Status403Forbidden); } // ── ConfigureSessionAsync: ServerInfo ── diff --git a/tests/clawsharp.Tests/Unit/McpServer/McpSessionSpanTests.cs b/tests/clawsharp.Tests/Unit/McpServer/McpSessionSpanTests.cs index 3071010c..665a03f0 100644 --- a/tests/clawsharp.Tests/Unit/McpServer/McpSessionSpanTests.cs +++ b/tests/clawsharp.Tests/Unit/McpServer/McpSessionSpanTests.cs @@ -99,8 +99,7 @@ public async Task ConfigureSessionAsync_EmitsSessionInitSpanWithAttributes() NullLogger.Instance); var authenticator = new McpServerAuthenticator( config: null, - apiKeyAuth, - NullLogger.Instance); + apiKeyAuth); var toolRegistry = Substitute.For(); toolRegistry.GetFilteredDefinitions(null).Returns(new List()); diff --git a/tests/clawsharp.Tests/Unit/Organization/ApprovalQueueTests.cs b/tests/clawsharp.Tests/Unit/Organization/ApprovalQueueTests.cs index 1f601325..d8956461 100644 --- a/tests/clawsharp.Tests/Unit/Organization/ApprovalQueueTests.cs +++ b/tests/clawsharp.Tests/Unit/Organization/ApprovalQueueTests.cs @@ -254,7 +254,7 @@ public void HasActiveGrant_ActiveGrant_ReturnsTrue() } [Test] - public void HasActiveGrant_ExpiredGrant_ReturnsFalse() + public async Task HasActiveGrant_ExpiredGrant_ReturnsFalse() { var user = CreateUser(); var requestId = _queue.Enqueue(user, "shell", ChannelName.Telegram, "123"); @@ -262,7 +262,7 @@ public void HasActiveGrant_ExpiredGrant_ReturnsFalse() _queue.Approve(requestId, "admin", TimeSpan.FromMilliseconds(1)); // Wait for expiry - Thread.Sleep(10); + await Task.Delay(10); _queue.HasActiveGrant("alice", "shell").ShouldBeFalse(); } @@ -307,7 +307,7 @@ public void GetPendingForUser_ReturnsUserPendingOnly() // --- Expiry --- [Test] - public void RequestExpiry_ExpiredRequestsTransitionToExpired() + public async Task RequestExpiry_ExpiredRequestsTransitionToExpired() { // Create a queue with very short TTL var shortTtlConfig = new AppConfig @@ -323,7 +323,7 @@ public void RequestExpiry_ExpiredRequestsTransitionToExpired() var requestId = shortQueue.Enqueue(user, "shell", ChannelName.Telegram, "123"); // Wait for expiry - Thread.Sleep(10); + await Task.Delay(10); // GetPendingRequests triggers cleanup var pending = shortQueue.GetPendingRequests(); @@ -334,7 +334,7 @@ public void RequestExpiry_ExpiredRequestsTransitionToExpired() } [Test] - public void Enqueue_SameUserTool_AfterExpired_CreatesNewRequest() + public async Task Enqueue_SameUserTool_AfterExpired_CreatesNewRequest() { var shortTtlConfig = new AppConfig { @@ -348,7 +348,7 @@ public void Enqueue_SameUserTool_AfterExpired_CreatesNewRequest() var user = CreateUser(); var id1 = shortQueue.Enqueue(user, "shell", ChannelName.Telegram, "123"); - Thread.Sleep(10); + await Task.Delay(10); // Trigger cleanup shortQueue.GetPendingRequests(); @@ -365,8 +365,8 @@ public async Task InitializeAsync_RebuildsStateFromJSONL() var requestId = _queue.Enqueue(user, "shell", ChannelName.Telegram, "123"); _queue.Approve(requestId, "admin", TimeSpan.FromHours(1)); - // Wait for fire-and-forget persist - await Task.Delay(100); + // Flush pending fire-and-forget storage writes deterministically + await _queue.FlushPendingWritesAsync(); // Create a new queue from the same storage var config = new AppConfig diff --git a/tests/clawsharp.Tests/Unit/Organization/AuthorizationBehaviorTests.cs b/tests/clawsharp.Tests/Unit/Organization/AuthorizationBehaviorTests.cs index 7f0876a5..13818667 100644 --- a/tests/clawsharp.Tests/Unit/Organization/AuthorizationBehaviorTests.cs +++ b/tests/clawsharp.Tests/Unit/Organization/AuthorizationBehaviorTests.cs @@ -63,7 +63,7 @@ public async Task HandleAsync_NoOrgConfig_PassesThroughToNext() { var options = MakeOptions(orgConfig: null); var logger = NullLogger>.Instance; - var behavior = new AuthorizationBehavior(options, logger); + var behavior = new AuthorizationBehavior(options); var stub = WirePipeline(behavior, "expected-response"); var result = await behavior.HandleAsync("test-request", CancellationToken.None); @@ -82,7 +82,7 @@ public async Task HandleAsync_InternalSessionCommand_SkipsAuth_PassesThrough() { var options = MakeOptions(orgConfig: new OrganizationConfig()); var logger = NullLogger>.Instance; - var behavior = new AuthorizationBehavior(options, logger); + var behavior = new AuthorizationBehavior(options); var stub = WirePipeline(behavior, default(ValueTuple)); var command = new Clawsharp.Features.Session.Commands.SaveSession.Command( @@ -102,7 +102,7 @@ public async Task HandleAsync_InternalCostCommand_SkipsAuth_PassesThrough() { var options = MakeOptions(orgConfig: new OrganizationConfig()); var logger = NullLogger>.Instance; - var behavior = new AuthorizationBehavior(options, logger); + var behavior = new AuthorizationBehavior(options); var stub = WirePipeline(behavior, default(ValueTuple)); var command = new Clawsharp.Features.Cost.Commands.RecordUsage.Command("sess1", "gpt-4o", 100, 50); @@ -123,7 +123,7 @@ public async Task HandleAsync_AuthRequiredRequest_WithOrgConfig_PassesThroughToN { var options = MakeOptions(orgConfig: new OrganizationConfig()); var logger = NullLogger>.Instance; - var behavior = new AuthorizationBehavior(options, logger); + var behavior = new AuthorizationBehavior(options); var stub = WirePipeline(behavior, "handler-response"); var result = await behavior.HandleAsync("some-request", CancellationToken.None); @@ -142,7 +142,7 @@ public async Task HandleAsync_UnknownRequestType_DoesNotThrow_PassesThrough() { var options = MakeOptions(orgConfig: new OrganizationConfig()); var logger = NullLogger>.Instance; - var behavior = new AuthorizationBehavior(options, logger); + var behavior = new AuthorizationBehavior(options); var stub = WirePipeline(behavior, true); var result = await behavior.HandleAsync(42, CancellationToken.None); @@ -185,7 +185,7 @@ public async Task HandleAsync_LoadSessionQuery_SkipsAuth_PassesThrough() { var options = MakeOptions(orgConfig: new OrganizationConfig()); var logger = NullLogger>.Instance; - var behavior = new AuthorizationBehavior(options, logger); + var behavior = new AuthorizationBehavior(options); var expectedSession = new Clawsharp.Core.Sessions.Session { Id = "test:1" }; var stub = WirePipeline(behavior, expectedSession); @@ -207,7 +207,7 @@ public async Task HandleAsync_ClearSessionCommand_SkipsAuth_PassesThrough() { var options = MakeOptions(orgConfig: new OrganizationConfig()); var logger = NullLogger>.Instance; - var behavior = new AuthorizationBehavior(options, logger); + var behavior = new AuthorizationBehavior(options); var stub = WirePipeline(behavior, default(ValueTuple)); var command = new Clawsharp.Features.Session.Commands.ClearSession.Command( @@ -223,7 +223,7 @@ public async Task HandleAsync_PruneSessionCommand_SkipsAuth_PassesThrough() { var options = MakeOptions(orgConfig: new OrganizationConfig()); var logger = NullLogger>.Instance; - var behavior = new AuthorizationBehavior(options, logger); + var behavior = new AuthorizationBehavior(options); var stub = WirePipeline(behavior, false); var command = new Clawsharp.Features.Session.Commands.PruneSession.Command( diff --git a/tests/clawsharp.Tests/Unit/Organization/ConfigMutatorTests.cs b/tests/clawsharp.Tests/Unit/Organization/ConfigMutatorTests.cs index de20ad1d..6bca82f9 100644 --- a/tests/clawsharp.Tests/Unit/Organization/ConfigMutatorTests.cs +++ b/tests/clawsharp.Tests/Unit/Organization/ConfigMutatorTests.cs @@ -126,18 +126,34 @@ await ConfigMutator.MutateConfigAsync(_configPath, root => // ── Empty / missing file handling ──────────────────────────────────── [Test] - public async Task MutateConfigAsync_EmptyFile_ThrowsJsonException() + public async Task MutateConfigAsync_EmptyFile_TreatedAsMissing() { await File.WriteAllTextAsync(_configPath, ""); - // Empty string causes JsonNode.Parse to throw (not null return). - // This documents the real behavior -- ConfigMutator does not handle - // empty files; callers should ensure the file is absent or valid JSON. - await Should.ThrowAsync( - ConfigMutator.MutateConfigAsync(_configPath, root => - { - root["created"] = true; - })); + // Empty file is treated as missing: mutation creates a fresh JSON object + await ConfigMutator.MutateConfigAsync(_configPath, root => + { + root["created"] = true; + }); + + File.Exists(_configPath).ShouldBeTrue(); + var json = JsonNode.Parse(await File.ReadAllTextAsync(_configPath)); + json!["created"]!.GetValue().ShouldBeTrue(); + } + + [Test] + public async Task MutateConfigAsync_WhitespaceOnlyFile_TreatedAsMissing() + { + await File.WriteAllTextAsync(_configPath, " \n "); + + await ConfigMutator.MutateConfigAsync(_configPath, root => + { + root["recovered"] = true; + }); + + File.Exists(_configPath).ShouldBeTrue(); + var json = JsonNode.Parse(await File.ReadAllTextAsync(_configPath)); + json!["recovered"]!.GetValue().ShouldBeTrue(); } [Test] diff --git a/tests/clawsharp.Tests/Unit/Organization/IdpConfigSerializationTests.cs b/tests/clawsharp.Tests/Unit/Organization/IdpConfigSerializationTests.cs index 2d163016..f736585a 100644 --- a/tests/clawsharp.Tests/Unit/Organization/IdpConfigSerializationTests.cs +++ b/tests/clawsharp.Tests/Unit/Organization/IdpConfigSerializationTests.cs @@ -7,6 +7,7 @@ namespace Clawsharp.Tests.Unit.Organization; /// /// Round-trip serialization tests for IdpConfig and ClaimsConfig using source-generated JSON context. /// +[TestFixture] public sealed class IdpConfigSerializationTests { private const string FullIdpConfigJson = """ diff --git a/tests/clawsharp.Tests/Unit/Organization/LinkTokenStoreTests.cs b/tests/clawsharp.Tests/Unit/Organization/LinkTokenStoreTests.cs index 208e48f8..56efd99e 100644 --- a/tests/clawsharp.Tests/Unit/Organization/LinkTokenStoreTests.cs +++ b/tests/clawsharp.Tests/Unit/Organization/LinkTokenStoreTests.cs @@ -6,6 +6,7 @@ namespace Clawsharp.Tests.Unit.Organization; /// Tests for : HMAC-signed token generation, /// validation with constant-time comparison, TTL enforcement, and single-use atomicity. /// +[TestFixture] public sealed class LinkTokenStoreTests { [Test] diff --git a/tests/clawsharp.Tests/Unit/Organization/OidcBearerTokenTests.cs b/tests/clawsharp.Tests/Unit/Organization/OidcBearerTokenTests.cs index c60cb8ab..eb727c1b 100644 --- a/tests/clawsharp.Tests/Unit/Organization/OidcBearerTokenTests.cs +++ b/tests/clawsharp.Tests/Unit/Organization/OidcBearerTokenTests.cs @@ -216,9 +216,9 @@ public async Task McpServerAuthenticator_JwtFallback_WithNoOidcService_DoesNotTh } }; - var appConfig = new Clawsharp.Config.AppConfig + var appConfig = new AppConfig { - Organization = new Clawsharp.Config.Organization.OrganizationConfig + Organization = new OrganizationConfig { Name = "TestOrg", Users = new Dictionary @@ -239,9 +239,8 @@ public async Task McpServerAuthenticator_JwtFallback_WithNoOidcService_DoesNotTh config, identityResolver, policyEvaluator, oidcService: null, idpConfig: null, NullLogger.Instance); - var authenticator = new Clawsharp.McpServer.McpServerAuthenticator( - config, apiKeyAuth, - NullLogger.Instance); + var authenticator = new McpServerAuthenticator( + config, apiKeyAuth); // Passing a JWT-like string when no OIDC is configured should not throw var result = await authenticator.AuthenticateAsync("eyJhbGciOiJSUzI1NiJ9.invalid.token"); @@ -263,9 +262,9 @@ public async Task McpServerAuthenticator_JwtFallback_InvalidToken_ReturnsUnauthe } }; - var appConfig = new Clawsharp.Config.AppConfig + var appConfig = new AppConfig { - Organization = new Clawsharp.Config.Organization.OrganizationConfig + Organization = new OrganizationConfig { Name = "TestOrg", Users = new Dictionary @@ -286,9 +285,8 @@ public async Task McpServerAuthenticator_JwtFallback_InvalidToken_ReturnsUnauthe config, identityResolver, policyEvaluator, oidcService: null, idpConfig: null, NullLogger.Instance); - var authenticator = new Clawsharp.McpServer.McpServerAuthenticator( - config, apiKeyAuth, - NullLogger.Instance); + var authenticator = new McpServerAuthenticator( + config, apiKeyAuth); var result = await authenticator.AuthenticateAsync("not-a-valid-key"); diff --git a/tests/clawsharp.Tests/Unit/Organization/OidcServiceTests.cs b/tests/clawsharp.Tests/Unit/Organization/OidcServiceTests.cs index 5c0ad02b..929912d3 100644 --- a/tests/clawsharp.Tests/Unit/Organization/OidcServiceTests.cs +++ b/tests/clawsharp.Tests/Unit/Organization/OidcServiceTests.cs @@ -12,6 +12,7 @@ namespace Clawsharp.Tests.Unit.Organization; /// and . BuildAuthorizationUrl and ValidateIdTokenAsync require real OIDC /// infrastructure and are tested via integration tests. /// +[TestFixture] public sealed class OidcServiceTests { // ── GeneratePkce ────────────────────────────────────────────────────── diff --git a/tests/clawsharp.Tests/Unit/Organization/OrgConfigSerializationTests.cs b/tests/clawsharp.Tests/Unit/Organization/OrgConfigSerializationTests.cs index dfc3bad9..2395f01e 100644 --- a/tests/clawsharp.Tests/Unit/Organization/OrgConfigSerializationTests.cs +++ b/tests/clawsharp.Tests/Unit/Organization/OrgConfigSerializationTests.cs @@ -7,6 +7,7 @@ namespace Clawsharp.Tests.Unit.Organization; /// /// Round-trip serialization tests for organization config types using source-generated JSON context. /// +[TestFixture] public sealed class OrgConfigSerializationTests { private const string FullOrgConfigJson = """ diff --git a/tests/clawsharp.Tests/Unit/Organization/OrgConfigValidationTests.cs b/tests/clawsharp.Tests/Unit/Organization/OrgConfigValidationTests.cs index 17329828..c440a116 100644 --- a/tests/clawsharp.Tests/Unit/Organization/OrgConfigValidationTests.cs +++ b/tests/clawsharp.Tests/Unit/Organization/OrgConfigValidationTests.cs @@ -7,6 +7,7 @@ namespace Clawsharp.Tests.Unit.Organization; /// /// Unit tests for organization config validation in ConfigValidator. /// +[TestFixture] public sealed class OrgConfigValidationTests { /// Creates a minimal valid AppConfig with a single provider for validation to pass non-org checks. diff --git a/tests/clawsharp.Tests/Unit/Organization/OrgSetRoleTests.cs b/tests/clawsharp.Tests/Unit/Organization/OrgSetRoleTests.cs index 1a59207c..c8a02089 100644 --- a/tests/clawsharp.Tests/Unit/Organization/OrgSetRoleTests.cs +++ b/tests/clawsharp.Tests/Unit/Organization/OrgSetRoleTests.cs @@ -5,6 +5,7 @@ using Clawsharp.Core.Pipeline; using Clawsharp.Core.Sessions; using Clawsharp.Organization; +using Clawsharp.Tools; namespace Clawsharp.Tests.Unit.Organization; @@ -41,12 +42,13 @@ private static RolePolicy CreateRolePolicy(bool isAdmin = false) }; } - private static Session CreateSession(OrgUser? currentUser = null) + private static Session CreateSession(OrgUser? currentUser = null, PolicyDecision? currentPolicy = null) { return new Session { Id = "test:session-1", - CurrentUser = currentUser + CurrentUser = currentUser, + CurrentPolicy = currentPolicy }; } @@ -212,4 +214,207 @@ public void HandleOrgSetRole_NoPoliciesConfig_ReturnsRoleNotFound() result.ShouldContain("Role not found"); } + + // ── CVE-2026-33579: Scope escalation prevention ───────────────────── + + [Test] + public void HandleOrgSetRole_LimitedAdminAssigningUnrestrictedRole_ReturnsDenied() + { + // limited-admin has restricted tool access but IsAdmin=true + var limitedAdmin = CreateUser(name: "admin", resolvedPolicies: [new RolePolicy + { + IsAdmin = true, + ToolAccess = JsonSerializer.Deserialize("[\"memory_*\"]"), + MaxToolSensitivity = "low", + Models = JsonSerializer.Deserialize("\"*\""), + }]); + var callerPolicy = new PolicyDecision + { + IsUnrestrictedToolAccess = false, + ToolPatterns = ["memory_*"], + MaxSensitivity = ToolSensitivity.Low, + IsUnrestrictedModels = true, + }; + var session = CreateSession(currentUser: limitedAdmin, currentPolicy: callerPolicy); + var appConfig = CreateAppConfigWithRoles(); // "admin" role has unrestricted tool access + + var (success, result) = AgentLoop.HandleOrgSetRole(session, "set-role @alice admin", appConfig); + + Assert.Multiple(() => + { + success.ShouldBeFalse(); + result.ShouldContain("unrestricted tool access"); + result.ShouldContain("exceeds your own policy"); + }); + } + + [Test] + public void HandleOrgSetRole_LimitedAdminAssigningSensitivityExceedingRole_ReturnsDenied() + { + var limitedAdmin = CreateUser(name: "admin", resolvedPolicies: [new RolePolicy + { + IsAdmin = true, + ToolAccess = JsonSerializer.Deserialize("\"unrestricted\""), + MaxToolSensitivity = "low", + Models = JsonSerializer.Deserialize("\"*\""), + }]); + var callerPolicy = new PolicyDecision + { + IsUnrestrictedToolAccess = true, + MaxSensitivity = ToolSensitivity.Low, + IsUnrestrictedModels = true, + }; + var session = CreateSession(currentUser: limitedAdmin, currentPolicy: callerPolicy); + + // Create config with a role that has higher sensitivity + var appConfig = new AppConfig + { + Organization = new OrganizationConfig + { + Name = "TestOrg", + Users = new Dictionary + { + ["admin"] = new() { Ids = ["test:admin"], Roles = ["admin"] }, + ["alice"] = new() { Ids = ["test:alice"], Roles = ["user"] } + }, + Policies = new PoliciesConfig + { + Roles = new Dictionary + { + ["high-ops"] = new() + { + ToolAccess = JsonSerializer.Deserialize("\"unrestricted\""), + MaxToolSensitivity = "critical", + Models = JsonSerializer.Deserialize("\"*\""), + } + } + } + } + }; + + var (success, result) = AgentLoop.HandleOrgSetRole(session, "set-role @alice high-ops", appConfig); + + Assert.Multiple(() => + { + success.ShouldBeFalse(); + result.ShouldContain("sensitivity ceiling exceeds your own"); + }); + } + + [Test] + public void HandleOrgSetRole_LimitedAdminAssigningUnrestrictedModels_ReturnsDenied() + { + var limitedAdmin = CreateUser(name: "admin", resolvedPolicies: [new RolePolicy + { + IsAdmin = true, + ToolAccess = JsonSerializer.Deserialize("\"unrestricted\""), + Models = JsonSerializer.Deserialize("[\"gpt-4o\"]"), + }]); + var callerPolicy = new PolicyDecision + { + IsUnrestrictedToolAccess = true, + MaxSensitivity = ToolSensitivity.Critical, + IsUnrestrictedModels = false, + ModelPatterns = ["gpt-4o"], + }; + var session = CreateSession(currentUser: limitedAdmin, currentPolicy: callerPolicy); + + var appConfig = new AppConfig + { + Organization = new OrganizationConfig + { + Name = "TestOrg", + Users = new Dictionary + { + ["admin"] = new() { Ids = ["test:admin"], Roles = ["admin"] }, + ["alice"] = new() { Ids = ["test:alice"], Roles = ["user"] } + }, + Policies = new PoliciesConfig + { + Roles = new Dictionary + { + ["all-models"] = new() + { + ToolAccess = JsonSerializer.Deserialize("\"unrestricted\""), + Models = JsonSerializer.Deserialize("\"*\""), + } + } + } + } + }; + + var (success, result) = AgentLoop.HandleOrgSetRole(session, "set-role @alice all-models", appConfig); + + Assert.Multiple(() => + { + success.ShouldBeFalse(); + result.ShouldContain("unrestricted model access"); + result.ShouldContain("exceeds your own policy"); + }); + } + + [Test] + public void HandleOrgSetRole_FullAdminAssigningAnyRole_Succeeds() + { + // Full admin with unrestricted everything — no scope limitation + var fullAdmin = CreateUser(name: "admin", resolvedPolicies: [CreateRolePolicy(isAdmin: true)]); + var callerPolicy = PolicyDecision.Unrestricted; + var session = CreateSession(currentUser: fullAdmin, currentPolicy: callerPolicy); + var appConfig = CreateAppConfigWithRoles(); + + var (success, result) = AgentLoop.HandleOrgSetRole(session, "set-role @alice admin", appConfig); + + Assert.Multiple(() => + { + success.ShouldBeTrue(); + result.ShouldContain("Role updated"); + }); + } + + [Test] + public void HandleOrgSetRole_AdminWithNullPolicy_Succeeds() + { + // Null policy = single-operator mode, should skip scope check + var admin = CreateUser(name: "admin", resolvedPolicies: [CreateRolePolicy(isAdmin: true)]); + var session = CreateSession(currentUser: admin, currentPolicy: null); + var appConfig = CreateAppConfigWithRoles(); + + var (success, result) = AgentLoop.HandleOrgSetRole(session, "set-role @alice engineering", appConfig); + + Assert.Multiple(() => + { + success.ShouldBeTrue(); + result.ShouldContain("Role updated"); + }); + } + + [Test] + public void HandleOrgSetRole_LimitedAdminAssigningEqualRole_Succeeds() + { + // Admin assigning a role with equal or lesser privileges should succeed + var limitedAdmin = CreateUser(name: "admin", resolvedPolicies: [new RolePolicy + { + IsAdmin = true, + ToolAccess = JsonSerializer.Deserialize("[\"shell\",\"file_*\",\"memory_*\"]"), + MaxToolSensitivity = "high", + Models = JsonSerializer.Deserialize("\"*\""), + }]); + var callerPolicy = new PolicyDecision + { + IsUnrestrictedToolAccess = false, + ToolPatterns = ["shell", "file_*", "memory_*"], + MaxSensitivity = ToolSensitivity.High, + IsUnrestrictedModels = true, + }; + var session = CreateSession(currentUser: limitedAdmin, currentPolicy: callerPolicy); + var appConfig = CreateAppConfigWithRoles(); // "developer" role has ["shell","file_*"] + + var (success, result) = AgentLoop.HandleOrgSetRole(session, "set-role @alice developer", appConfig); + + Assert.Multiple(() => + { + success.ShouldBeTrue(); + result.ShouldContain("Role updated"); + }); + } } diff --git a/tests/clawsharp.Tests/AgentLoopTests.cs b/tests/clawsharp.Tests/Unit/Pipeline/AgentLoopTests.cs similarity index 99% rename from tests/clawsharp.Tests/AgentLoopTests.cs rename to tests/clawsharp.Tests/Unit/Pipeline/AgentLoopTests.cs index 700bf975..097dfd9b 100644 --- a/tests/clawsharp.Tests/AgentLoopTests.cs +++ b/tests/clawsharp.Tests/Unit/Pipeline/AgentLoopTests.cs @@ -21,7 +21,7 @@ using Clawsharp.Config.Agent; using Clawsharp.Config.Features; -namespace Clawsharp.Tests; +namespace Clawsharp.Tests.Unit.Pipeline; /// /// Test harness that wires up a complete AgentLoop with fakes for integration testing. diff --git a/tests/clawsharp.Tests/GoalSlashCommandTests.cs b/tests/clawsharp.Tests/Unit/Pipeline/GoalSlashCommandTests.cs similarity index 96% rename from tests/clawsharp.Tests/GoalSlashCommandTests.cs rename to tests/clawsharp.Tests/Unit/Pipeline/GoalSlashCommandTests.cs index 22c61157..8a7710f7 100644 --- a/tests/clawsharp.Tests/GoalSlashCommandTests.cs +++ b/tests/clawsharp.Tests/Unit/Pipeline/GoalSlashCommandTests.cs @@ -1,7 +1,8 @@ using Clawsharp.Core.Pipeline; -namespace Clawsharp.Tests; +namespace Clawsharp.Tests.Unit.Pipeline; +[TestFixture] public sealed class GoalSlashCommandTests { [Test] @@ -23,6 +24,7 @@ public void GoalsUnknownArg_ReturnsShowGoals() } } +[TestFixture] public sealed class GoalSystemPromptTests { [Test] diff --git a/tests/clawsharp.Tests/Unit/Pipeline/OrgApprovalCommandTests.cs b/tests/clawsharp.Tests/Unit/Pipeline/OrgApprovalCommandTests.cs index a5ca72e8..6f4fc745 100644 --- a/tests/clawsharp.Tests/Unit/Pipeline/OrgApprovalCommandTests.cs +++ b/tests/clawsharp.Tests/Unit/Pipeline/OrgApprovalCommandTests.cs @@ -5,6 +5,7 @@ using Clawsharp.Core.Sessions; using Clawsharp.Core.Utilities; using Clawsharp.Organization; +using Clawsharp.Tools; using Microsoft.Extensions.Logging.Abstractions; using Microsoft.Extensions.Options; @@ -391,4 +392,131 @@ public void HandleOrgDeny_NoOrgConfig_ReturnsOrgNotEnabled() result.ShouldContain("Organization mode is not enabled"); } + + // ── CVE-2026-33579: Approver scope validation ─────────────────────── + + [Test] + public void HandleOrgApprove_LimitedAdminApprovingDeniedTool_ReturnsScopeDenied() + { + // Admin with low sensitivity ceiling tries to approve a critical-sensitivity tool + var limitedAdmin = CreateUser(name: "admin", resolvedPolicies: [new RolePolicy + { + IsAdmin = true, + ToolAccess = JsonSerializer.Deserialize("[\"memory_*\"]"), + MaxToolSensitivity = "low", + }]); + var callerPolicy = new PolicyDecision + { + IsUnrestrictedToolAccess = false, + ToolPatterns = ["memory_*"], + MaxSensitivity = ToolSensitivity.Low, + }; + var session = CreateSession(currentUser: limitedAdmin, currentPolicy: callerPolicy); + var appConfig = CreateOrgAppConfig(); + var queue = CreateApprovalQueue(appConfig); + + var user = CreateUser(name: "alice"); + var requestId = queue.Enqueue(user, "shell", ChannelName.Cli, "test:alice"); + + // shell has Critical sensitivity + var (success, result) = AgentLoop.HandleOrgApprove( + session, $"approve {requestId}", appConfig, queue, + _ => ToolSensitivity.Critical); + + Assert.Multiple(() => + { + success.ShouldBeFalse(); + result.ShouldContain("your own policy does not allow this tool"); + // Request should still be pending (not approved) + queue.GetRequest(requestId)!.State.ShouldBe(ApprovalState.Pending); + }); + } + + [Test] + public void HandleOrgApprove_LimitedAdminApprovingToolNotInGlobs_ReturnsScopeDenied() + { + // Admin with restricted tool patterns tries to approve a tool outside their patterns + var limitedAdmin = CreateUser(name: "admin", resolvedPolicies: [new RolePolicy + { + IsAdmin = true, + ToolAccess = JsonSerializer.Deserialize("[\"memory_*\"]"), + }]); + var callerPolicy = new PolicyDecision + { + IsUnrestrictedToolAccess = false, + ToolPatterns = ["memory_*"], + MaxSensitivity = ToolSensitivity.Critical, + }; + var session = CreateSession(currentUser: limitedAdmin, currentPolicy: callerPolicy); + var appConfig = CreateOrgAppConfig(); + var queue = CreateApprovalQueue(appConfig); + + var user = CreateUser(name: "alice"); + var requestId = queue.Enqueue(user, "shell", ChannelName.Cli, "test:alice"); + + var (success, result) = AgentLoop.HandleOrgApprove( + session, $"approve {requestId}", appConfig, queue, + _ => ToolSensitivity.High); + + Assert.Multiple(() => + { + success.ShouldBeFalse(); + result.ShouldContain("your own policy does not allow this tool"); + }); + } + + [Test] + public void HandleOrgApprove_FullAdminApprovingAnyTool_Succeeds() + { + // Unrestricted admin should still be able to approve any tool + var fullAdmin = CreateUser(name: "admin", resolvedPolicies: [CreateRolePolicy(isAdmin: true)]); + var session = CreateSession(currentUser: fullAdmin, currentPolicy: PolicyDecision.Unrestricted); + var appConfig = CreateOrgAppConfig(); + var queue = CreateApprovalQueue(appConfig); + + var user = CreateUser(name: "alice"); + var requestId = queue.Enqueue(user, "shell", ChannelName.Cli, "test:alice"); + + var (success, result) = AgentLoop.HandleOrgApprove( + session, $"approve {requestId}", appConfig, queue, + _ => ToolSensitivity.Critical); + + Assert.Multiple(() => + { + success.ShouldBeTrue(); + result.ShouldContain("Approved"); + }); + } + + [Test] + public void HandleOrgApprove_AdminApprovingToolWithinScope_Succeeds() + { + // Admin whose policy allows the tool should succeed + var admin = CreateUser(name: "admin", resolvedPolicies: [new RolePolicy + { + IsAdmin = true, + ToolAccess = JsonSerializer.Deserialize("\"unrestricted\""), + }]); + var callerPolicy = new PolicyDecision + { + IsUnrestrictedToolAccess = true, + MaxSensitivity = ToolSensitivity.Critical, + }; + var session = CreateSession(currentUser: admin, currentPolicy: callerPolicy); + var appConfig = CreateOrgAppConfig(); + var queue = CreateApprovalQueue(appConfig); + + var user = CreateUser(name: "alice"); + var requestId = queue.Enqueue(user, "shell", ChannelName.Cli, "test:alice"); + + var (success, result) = AgentLoop.HandleOrgApprove( + session, $"approve {requestId}", appConfig, queue, + _ => ToolSensitivity.Critical); + + Assert.Multiple(() => + { + success.ShouldBeTrue(); + result.ShouldContain("Approved"); + }); + } } diff --git a/tests/clawsharp.Tests/Unit/Providers/GeminiHealthCheckTests.cs b/tests/clawsharp.Tests/Unit/Providers/GeminiHealthCheckTests.cs index 6c144894..53af0165 100644 --- a/tests/clawsharp.Tests/Unit/Providers/GeminiHealthCheckTests.cs +++ b/tests/clawsharp.Tests/Unit/Providers/GeminiHealthCheckTests.cs @@ -126,10 +126,10 @@ public async Task CheckHealthAsync_Failure_ResponseTimePopulated() result.ResponseTime.ShouldNotBeNull(); } - // -- 9. Correct URL is called with API key as query parameter -- + // -- 9. Correct URL is called with API key in header (not query string) -- [Test] - public async Task CheckHealthAsync_CallsModelsEndpointWithApiKey() + public async Task CheckHealthAsync_CallsModelsEndpointWithApiKeyHeader() { var handler = new ConfigurableHttpHandler(HttpStatusCode.OK, """{"models":[]}"""); var provider = CreateProvider(handler, apiKey: "test-gemini-key"); @@ -139,7 +139,10 @@ public async Task CheckHealthAsync_CallsModelsEndpointWithApiKey() handler.LastRequestUri.ShouldNotBeNull(); var uri = handler.LastRequestUri!.ToString(); uri.ShouldContain("generativelanguage.googleapis.com/v1beta/models"); - uri.ShouldContain("key=test-gemini-key"); + uri.ShouldNotContain("key="); + + handler.LastCustomHeaders.ShouldContainKey("x-goog-api-key"); + handler.LastCustomHeaders["x-goog-api-key"].ShouldBe("test-gemini-key"); } // -- 10. Uses GET method -- diff --git a/tests/clawsharp.Tests/ProviderStreamingTests.cs b/tests/clawsharp.Tests/Unit/Providers/ProviderStreamingTests.cs similarity index 99% rename from tests/clawsharp.Tests/ProviderStreamingTests.cs rename to tests/clawsharp.Tests/Unit/Providers/ProviderStreamingTests.cs index f1de4f00..8f28f5f4 100644 --- a/tests/clawsharp.Tests/ProviderStreamingTests.cs +++ b/tests/clawsharp.Tests/Unit/Providers/ProviderStreamingTests.cs @@ -2,7 +2,7 @@ using Clawsharp.Core; using Clawsharp.Tests.Fakes; -namespace Clawsharp.Tests; +namespace Clawsharp.Tests.Unit.Providers; [TestFixture] public sealed class ProviderStreamingTests diff --git a/tests/clawsharp.Tests/Unit/Providers/SanitizeErrorBodyTests.cs b/tests/clawsharp.Tests/Unit/Providers/SanitizeErrorBodyTests.cs index 9b4513d2..a5bc414b 100644 --- a/tests/clawsharp.Tests/Unit/Providers/SanitizeErrorBodyTests.cs +++ b/tests/clawsharp.Tests/Unit/Providers/SanitizeErrorBodyTests.cs @@ -293,4 +293,70 @@ public void VariousKeyFormats_AreRedacted(string key) result.ShouldNotContain(key); result.ShouldContain("[REDACTED]"); } + + // ========================================================================= + // 10. Gemini API Keys (AIzaSy...) + // ========================================================================= + + [Test] + public void GeminiApiKey_IsRedacted() + { + // Gemini keys start with "AIzaSy" and are 39 chars total (6 prefix + 33 body) + const string key = "AIzaSyA1B2C3D4E5F6G7H8I9J0K1L2M3N4O5P6Q"; + var input = $"Invalid Gemini API key: {key}"; + var result = ProviderRequestHandler.SanitizeErrorBody(input); + + result.ShouldNotContain(key); + result.ShouldContain("[REDACTED]"); + result.ShouldContain("Invalid Gemini API key:"); + } + + [Test] + public void GeminiApiKeyWithDashesAndUnderscores_IsRedacted() + { + // Gemini keys can contain dashes and underscores + const string key = "AIzaSyA-B_C3D4E5F6G7H8I9J0K1L2M3N4O5P6Q"; + var input = $"Error: {key}"; + var result = ProviderRequestHandler.SanitizeErrorBody(input); + + result.ShouldNotContain(key); + result.ShouldContain("[REDACTED]"); + } + + [Test] + public void ShortAIzaPrefix_NotRedacted() + { + // "AIzaSy" followed by fewer than 33 chars should not match + const string input = "Error: AIzaSyShort"; + var result = ProviderRequestHandler.SanitizeErrorBody(input); + + result.ShouldContain("AIzaSyShort"); + } + + // ========================================================================= + // 11. AWS Access Key IDs (AKIA...) + // ========================================================================= + + [Test] + public void AwsAccessKeyId_IsRedacted() + { + // AWS access key IDs start with "AKIA" and are 20 chars total (4 prefix + 16 body) + const string key = "AKIAIOSFODNN7EXAMPLE"; + var input = $"AWS credential error: {key}"; + var result = ProviderRequestHandler.SanitizeErrorBody(input); + + result.ShouldNotContain(key); + result.ShouldContain("[REDACTED]"); + result.ShouldContain("AWS credential error:"); + } + + [Test] + public void ShortAkiaPrefix_NotRedacted() + { + // "AKIA" followed by fewer than 16 chars should not match + const string input = "Error: AKIASHORT"; + var result = ProviderRequestHandler.SanitizeErrorBody(input); + + result.ShouldContain("AKIASHORT"); + } } diff --git a/tests/clawsharp.Tests/Unit/Providers/TagStripFilterEdgeCaseTests.cs b/tests/clawsharp.Tests/Unit/Providers/TagStripFilterEdgeCaseTests.cs index 18844e2c..8faa4106 100644 --- a/tests/clawsharp.Tests/Unit/Providers/TagStripFilterEdgeCaseTests.cs +++ b/tests/clawsharp.Tests/Unit/Providers/TagStripFilterEdgeCaseTests.cs @@ -119,4 +119,27 @@ public void ProcessChunk_NullChunk_ReturnsEmpty() var result = filter.ProcessChunk(null!); result.ShouldBe(string.Empty); } + + // ── Streaming: re-entry on '<' during MaybeOpenTag flush ───────── + + [Test] + public void ProcessChunk_AngleBracketBreaksPrefixThenRealTag_StripsCorrectly() + { + // "" — the "/, + // then the second "<" breaks the match. The filter should flush "... + var filter = TagStripFilter.CreateStreamingFilter(); + var result = filter.ProcessChunk("hiddenvisible"); + result.ShouldBe("" — first "<" starts MaybeOpenTag, second "<" breaks + // prefix, should flush first "<" and start new match from second "<". + var filter = TagStripFilter.CreateStreamingFilter(); + var result = filter.ProcessChunk("<hiddenvisible"); + result.ShouldBe(" +[TestFixture] public sealed class HistoricalBugRegressionTests { // ══════════════════════════════════════════════════════════════════════ diff --git a/tests/clawsharp.Tests/Unit/Regression/ReviewFindingsRegressionTests.cs b/tests/clawsharp.Tests/Unit/Regression/ReviewFindingsRegressionTests.cs index d080466e..d58ad799 100644 --- a/tests/clawsharp.Tests/Unit/Regression/ReviewFindingsRegressionTests.cs +++ b/tests/clawsharp.Tests/Unit/Regression/ReviewFindingsRegressionTests.cs @@ -11,6 +11,7 @@ namespace Clawsharp.Tests.Unit.Regression; /// Regression tests for the 4 code review findings fixed on the analytics-schema-and-tests branch. /// Each test validates that the specific bug cannot silently reappear. /// +[TestFixture] public sealed class ReviewFindingsRegressionTests { // ────────────────────────────────────────────────────────────────────── diff --git a/tests/clawsharp.Tests/Unit/Security/AdminRoleFilterTests.cs b/tests/clawsharp.Tests/Unit/Security/AdminRoleFilterTests.cs index 174f513e..fda0fd27 100644 --- a/tests/clawsharp.Tests/Unit/Security/AdminRoleFilterTests.cs +++ b/tests/clawsharp.Tests/Unit/Security/AdminRoleFilterTests.cs @@ -119,16 +119,46 @@ await filter.InvokeAsync(invocationCtx, _ => nextCalled.ShouldBeTrue(); } - // ── IsUnrestrictedToolAccess -> passes through ──────────────────────── + // ── IsUnrestrictedToolAccess without IsAdmin -> 403 (CWE-863 fix) ──── [Test] - public async Task InvokeAsync_UnrestrictedToolAccess_PassesThrough() + public async Task InvokeAsync_UnrestrictedToolAccessWithoutAdminRole_Returns403() + { + var user = new OrgUser + { + Name = "alice", + Roles = ["power-user"], + ResolvedPolicies = [new RolePolicy { IsAdmin = false }] + }; + var policy = new PolicyDecision + { + IsUnrestrictedToolAccess = true, + }; + + var authResult = McpServerAuthResult.Success(user, policy, "power-key"); + var httpCtx = CreateContextWithAuthResult(authResult); + var invocationCtx = new FakeEndpointFilterInvocationContext(httpCtx); + + var filter = new AdminRoleFilter(); + var result = await filter.InvokeAsync(invocationCtx, _ => + ValueTask.FromResult(Results.Ok())); + + // Unrestricted tool access alone should NOT grant admin endpoint access + var typed = result as IStatusCodeHttpResult; + typed.ShouldNotBeNull(); + typed!.StatusCode.ShouldBe(403); + } + + // ── IsUnrestrictedToolAccess WITH IsAdmin -> passes through ────────── + + [Test] + public async Task InvokeAsync_UnrestrictedToolAccessWithAdminRole_PassesThrough() { var user = new OrgUser { Name = "alice", Roles = ["admin"], - ResolvedPolicies = [] + ResolvedPolicies = [new RolePolicy { IsAdmin = true }] }; var policy = new PolicyDecision { diff --git a/tests/clawsharp.Tests/PathGuardTests.cs b/tests/clawsharp.Tests/Unit/Security/PathGuardTests.cs similarity index 99% rename from tests/clawsharp.Tests/PathGuardTests.cs rename to tests/clawsharp.Tests/Unit/Security/PathGuardTests.cs index 5ee210fb..7d2c377e 100644 --- a/tests/clawsharp.Tests/PathGuardTests.cs +++ b/tests/clawsharp.Tests/Unit/Security/PathGuardTests.cs @@ -1,7 +1,8 @@ using Clawsharp.Tools; -namespace Clawsharp.Tests; +namespace Clawsharp.Tests.Unit.Security; +[TestFixture] public sealed class PathGuardTests { private string _workspace = null!; diff --git a/tests/clawsharp.Tests/SsrfCheckTests.cs b/tests/clawsharp.Tests/Unit/Security/SsrfCheckTests.cs similarity index 98% rename from tests/clawsharp.Tests/SsrfCheckTests.cs rename to tests/clawsharp.Tests/Unit/Security/SsrfCheckTests.cs index 2f685e1b..774f957b 100644 --- a/tests/clawsharp.Tests/SsrfCheckTests.cs +++ b/tests/clawsharp.Tests/Unit/Security/SsrfCheckTests.cs @@ -3,8 +3,9 @@ using Clawsharp.Tools.Web; using Microsoft.Extensions.DependencyInjection; -namespace Clawsharp.Tests; +namespace Clawsharp.Tests.Unit.Security; +[TestFixture] public sealed class SsrfCheckTests { // ── URL validation and scheme blocking (via ExecuteAsync) ───────── diff --git a/tests/clawsharp.Tests/WebPairingGuardEdgeCaseTests.cs b/tests/clawsharp.Tests/Unit/Security/WebPairingGuardEdgeCaseTests.cs similarity index 92% rename from tests/clawsharp.Tests/WebPairingGuardEdgeCaseTests.cs rename to tests/clawsharp.Tests/Unit/Security/WebPairingGuardEdgeCaseTests.cs index 44a4ff17..897e4ab3 100644 --- a/tests/clawsharp.Tests/WebPairingGuardEdgeCaseTests.cs +++ b/tests/clawsharp.Tests/Unit/Security/WebPairingGuardEdgeCaseTests.cs @@ -2,7 +2,7 @@ using Clawsharp.Security; using Microsoft.Extensions.Logging.Abstractions; -namespace Clawsharp.Tests; +namespace Clawsharp.Tests.Unit.Security; /// /// Edge-case tests for : @@ -101,21 +101,22 @@ public void TryPair_CodeWithLeadingTrailingWhitespace_MatchesTrimmed() [Test] public void TryPair_OverMaxFailureTrackingEntries_EvictsExpiredEntries() { - // The guard has a MaxFailureTrackingEntries = 10,000. - // When exceeded, it evicts entries with expired lockouts and count < MaxFailedAttempts. - // We can't easily test 10,001 IPs without being slow, so we test a smaller scenario - // to verify the eviction logic path is exercised. + // The guard has a MaxFailureTrackingEntries = 10,000 and MaxGlobalAttempts = 50. + // The global attempt counter invalidates the pairing code after 50 failed attempts + // across all IPs to defeat distributed brute-force. We test with fewer IPs than + // the global limit to verify the failure tracking and eviction logic without + // triggering code invalidation. var guard = new WebPairingGuard(_persistPath, NullLogger.Instance); - // Create failures from many different IPs (fewer than 10K but enough to test logic) - for (var i = 0; i < 100; i++) + // Create failures from 20 different IPs (below MaxGlobalAttempts of 50) + for (var i = 0; i < 20; i++) { // Each IP gets 1 failed attempt (below lockout threshold of 5) var ip = new IPAddress(BitConverter.GetBytes(i + 1).Reverse().ToArray()); guard.TryPair(ip, "wrong!"); } - // The guard should still function correctly + // The guard should still function correctly (pairing code not yet invalidated) var code = guard.PairingCode!; var result = guard.TryPair(IPAddress.Parse("192.168.1.1"), code); result.ShouldNotBeNull("Guard should still work after tracking many IPs"); diff --git a/tests/clawsharp.Tests/WebPairingGuardTests.cs b/tests/clawsharp.Tests/Unit/Security/WebPairingGuardTests.cs similarity index 99% rename from tests/clawsharp.Tests/WebPairingGuardTests.cs rename to tests/clawsharp.Tests/Unit/Security/WebPairingGuardTests.cs index 8da4b4f6..1ffc661c 100644 --- a/tests/clawsharp.Tests/WebPairingGuardTests.cs +++ b/tests/clawsharp.Tests/Unit/Security/WebPairingGuardTests.cs @@ -2,8 +2,9 @@ using Clawsharp.Security; using Microsoft.Extensions.Logging.Abstractions; -namespace Clawsharp.Tests; +namespace Clawsharp.Tests.Unit.Security; +[TestFixture] public sealed class WebPairingGuardTests { private string _persistPath = null!; diff --git a/tests/clawsharp.Tests/Unit/Telemetry/MetricsRegressionTests.cs b/tests/clawsharp.Tests/Unit/Telemetry/MetricsRegressionTests.cs index a6affca4..edd61911 100644 --- a/tests/clawsharp.Tests/Unit/Telemetry/MetricsRegressionTests.cs +++ b/tests/clawsharp.Tests/Unit/Telemetry/MetricsRegressionTests.cs @@ -46,7 +46,7 @@ public void OperationDuration_Record_DoesNotThrow() { Should.NotThrow(() => ClawsharpMetrics.OperationDuration.Record(1.5, - new GenAiMetricTags { OperationName = "chat", Model = "claude-3-5-sonnet", TokenType = "" })); + new DurationMetricTags { OperationName = "chat", Model = "claude-3-5-sonnet" })); } [Test] @@ -115,6 +115,13 @@ public void GenAiMetricTags_HasCorrectTagNames() AssertTagName("TokenType", "gen_ai.token.type"); } + [Test] + public void DurationMetricTags_HasCorrectTagNames() + { + AssertTagName("OperationName", "gen_ai.operation.name"); + AssertTagName("Model", "gen_ai.request.model"); + } + [Test] public void ToolMetricTags_HasCorrectTagNames() { diff --git a/tests/clawsharp.Tests/Unit/Telemetry/SpanIsolationTests.cs b/tests/clawsharp.Tests/Unit/Telemetry/SpanIsolationTests.cs index 09fc8d9b..0d1988c4 100644 --- a/tests/clawsharp.Tests/Unit/Telemetry/SpanIsolationTests.cs +++ b/tests/clawsharp.Tests/Unit/Telemetry/SpanIsolationTests.cs @@ -25,14 +25,17 @@ public async Task RunFireAndForget_NullsActivityCurrent_InsideTaskRun() using var parentActivity = TestSource.StartActivity("parent.op"); parentActivity.ShouldNotBeNull(); + var tcs = new TaskCompletionSource(); + // Act SpanIsolation.RunFireAndForget("test.isolated", TestSource, async () => { capturedCurrent = Activity.Current; await Task.CompletedTask; + tcs.SetResult(); }); - await Task.Delay(300); + await tcs.Task.WaitAsync(TimeSpan.FromSeconds(5)); // Assert: Activity.Current inside the work delegate should be the new span, not the parent // The parent activity should NOT be the current inside the delegate @@ -50,10 +53,16 @@ public async Task RunFireAndForget_CreatesSpanWithActivityLink_ToParent() parentActivity.ShouldNotBeNull(); var parentContext = parentActivity.Context; + var tcs = new TaskCompletionSource(); + // Act - SpanIsolation.RunFireAndForget("test.linked", TestSource, () => Task.CompletedTask); + SpanIsolation.RunFireAndForget("test.linked", TestSource, () => + { + tcs.SetResult(); + return Task.CompletedTask; + }); - await Task.Delay(300); + await tcs.Task.WaitAsync(TimeSpan.FromSeconds(5)); // Assert: should have created a new activity with a link back to the parent var isolatedActivity = activities.FirstOrDefault(a => a.OperationName == "test.linked"); @@ -72,10 +81,16 @@ public async Task RunFireAndForget_DoesNotCreateOrphanChildSpan_UnderOriginalPar using var parentActivity = TestSource.StartActivity("parent.op"); parentActivity.ShouldNotBeNull(); + var tcs = new TaskCompletionSource(); + // Act - SpanIsolation.RunFireAndForget("test.no-orphan", TestSource, () => Task.CompletedTask); + SpanIsolation.RunFireAndForget("test.no-orphan", TestSource, () => + { + tcs.SetResult(); + return Task.CompletedTask; + }); - await Task.Delay(300); + await tcs.Task.WaitAsync(TimeSpan.FromSeconds(5)); // Assert: the isolated activity should NOT have the parent as its parent var isolatedActivity = activities.FirstOrDefault(a => a.OperationName == "test.no-orphan"); @@ -90,11 +105,26 @@ public async Task RunFireAndForget_CatchesExceptions_WithoutPropagating() var activities = new List(); using var listener = CreateListener(activities); - // Act: should not throw even though the work delegate throws + var tcs = new TaskCompletionSource(); + + // Act: should not throw even though the work delegate throws. + // The TCS is set in ActivityStopped because the exception is swallowed by RunFireAndForget. + using var stopListener = new ActivityListener + { + ShouldListenTo = source => source.Name == TestSourceName, + Sample = (ref ActivityCreationOptions _) => ActivitySamplingResult.AllDataAndRecorded, + ActivityStopped = a => + { + if (a.OperationName == "test.throw") + tcs.TrySetResult(); + }, + }; + ActivitySource.AddActivityListener(stopListener); + SpanIsolation.RunFireAndForget("test.throw", TestSource, () => throw new InvalidOperationException("test error")); - await Task.Delay(300); + await tcs.Task.WaitAsync(TimeSpan.FromSeconds(5)); // Assert: no exception propagated, activity was still created var isolatedActivity = activities.FirstOrDefault(a => a.OperationName == "test.throw"); @@ -108,11 +138,26 @@ public async Task RunFireAndForget_SetsErrorStatus_WhenWorkThrows() var activities = new List(); using var listener = CreateListener(activities); + var tcs = new TaskCompletionSource(); + + // Use ActivityStopped to signal completion since the exception is swallowed. + using var stopListener = new ActivityListener + { + ShouldListenTo = source => source.Name == TestSourceName, + Sample = (ref ActivityCreationOptions _) => ActivitySamplingResult.AllDataAndRecorded, + ActivityStopped = a => + { + if (a.OperationName == "test.error-status") + tcs.TrySetResult(); + }, + }; + ActivitySource.AddActivityListener(stopListener); + // Act SpanIsolation.RunFireAndForget("test.error-status", TestSource, () => throw new InvalidOperationException("boom")); - await Task.Delay(300); + await tcs.Task.WaitAsync(TimeSpan.FromSeconds(5)); // Assert: span should have error status var isolatedActivity = activities.FirstOrDefault(a => a.OperationName == "test.error-status"); @@ -131,10 +176,16 @@ public async Task RunFireAndForget_WorksWithNoParentActivity() // Ensure no parent activity Activity.Current = null; + var tcs = new TaskCompletionSource(); + // Act: should not throw when there is no parent - SpanIsolation.RunFireAndForget("test.no-parent", TestSource, () => Task.CompletedTask); + SpanIsolation.RunFireAndForget("test.no-parent", TestSource, () => + { + tcs.SetResult(); + return Task.CompletedTask; + }); - await Task.Delay(300); + await tcs.Task.WaitAsync(TimeSpan.FromSeconds(5)); // Assert: activity was still created, with no links (no parent to link to) var isolatedActivity = activities.FirstOrDefault(a => a.OperationName == "test.no-parent"); diff --git a/tests/clawsharp.Tests/Unit/Webhooks/WebhookMetricsTests.cs b/tests/clawsharp.Tests/Unit/Webhooks/WebhookMetricsTests.cs index 641c7135..f377dff1 100644 --- a/tests/clawsharp.Tests/Unit/Webhooks/WebhookMetricsTests.cs +++ b/tests/clawsharp.Tests/Unit/Webhooks/WebhookMetricsTests.cs @@ -132,7 +132,7 @@ public async Task RegisterSseClient_ReturnsReaderThatReceivesBroadcastEvents() var config = MakeConfig("ep1"); var metrics = new WebhookMetrics(config); - var (registration, reader) = metrics.RegisterSseClient(null, null); + var (registration, reader) = metrics.RegisterSseClient(null, null)!.Value; using (registration) { var evt = MakeEvent("ep1", "delivery.success"); @@ -147,12 +147,12 @@ public async Task RegisterSseClient_ReturnsReaderThatReceivesBroadcastEvents() } [Test] - public void RegisterSseClient_WithTypeFilter_OnlyReceivesMatchingOutcome() + public void RegisterSseClient_WithOutcomeFilter_OnlyReceivesMatchingOutcome() { var config = MakeConfig("ep1"); var metrics = new WebhookMetrics(config); - var (registration, reader) = metrics.RegisterSseClient("delivery.success", null); + var (registration, reader) = metrics.RegisterSseClient("delivery.success", null)!.Value; using (registration) { metrics.RecordDelivery("ep1", MakeEvent("ep1", "delivery.failed", "e_fail")); @@ -160,7 +160,7 @@ public void RegisterSseClient_WithTypeFilter_OnlyReceivesMatchingOutcome() Assert.That(reader.TryRead(out var received), Is.True); Assert.That(received!.Id, Is.EqualTo("e_ok"), - "Type filter should only pass delivery.success events"); + "Outcome filter should only pass delivery.success events"); Assert.That(reader.TryRead(out _), Is.False, "No further events should be available"); @@ -173,7 +173,7 @@ public void RegisterSseClient_WithEndpointFilter_OnlyReceivesMatchingEndpoint() var config = MakeConfig("ep1", "ep2"); var metrics = new WebhookMetrics(config); - var (registration, reader) = metrics.RegisterSseClient(null, "ep1"); + var (registration, reader) = metrics.RegisterSseClient(null, "ep1")!.Value; using (registration) { metrics.RecordDelivery("ep2", MakeEvent("ep2", "delivery.success", "from_ep2")); @@ -193,7 +193,7 @@ public void RegisterSseClient_AfterChannelClose_DeadClientAutoCleanedUp() var config = MakeConfig("ep1"); var metrics = new WebhookMetrics(config); - var (registration, reader) = metrics.RegisterSseClient(null, null); + var (registration, reader) = metrics.RegisterSseClient(null, null)!.Value; registration.Dispose(); // closes channel writer // After disposing, broadcasting should not throw diff --git a/tests/clawsharp.Tests/Unit/Webhooks/WebhookPayloadBuilderTests.cs b/tests/clawsharp.Tests/Unit/Webhooks/WebhookPayloadBuilderTests.cs index 0e260c5d..b6a2a54e 100644 --- a/tests/clawsharp.Tests/Unit/Webhooks/WebhookPayloadBuilderTests.cs +++ b/tests/clawsharp.Tests/Unit/Webhooks/WebhookPayloadBuilderTests.cs @@ -84,11 +84,10 @@ public void Build_DataContainsToolNameField() var payload = WebhookPayloadBuilder.Build(evt, source, attr); - // The ToolExecuted record has ToolName property, so it should be in Data - Assert.That(payload.Data.TryGetProperty("tool_name", out var toolNameProp) - || payload.Data.TryGetProperty("ToolName", out toolNameProp), + // The ToolExecuted record has ToolName property; WebhookJsonContext uses camelCase naming policy + Assert.That(payload.Data.TryGetProperty("toolName", out var toolNameProp), Is.True, - "Data should contain tool name field"); + "Data should contain tool name field (camelCase per WebhookJsonContext naming policy)"); } // ── Build — Source propagation ─────────────────────────────────────────── diff --git a/tests/clawsharp.Tests/clawsharp.Tests.csproj b/tests/clawsharp.Tests/clawsharp.Tests.csproj index 4bb7d30e..fe2cd601 100644 --- a/tests/clawsharp.Tests/clawsharp.Tests.csproj +++ b/tests/clawsharp.Tests/clawsharp.Tests.csproj @@ -44,6 +44,8 @@ + +