diff --git a/src/Servers/Kestrel/Core/src/Internal/Http/HttpProtocol.cs b/src/Servers/Kestrel/Core/src/Internal/Http/HttpProtocol.cs index 664e466ccefb..094ba0058b87 100644 --- a/src/Servers/Kestrel/Core/src/Internal/Http/HttpProtocol.cs +++ b/src/Servers/Kestrel/Core/src/Internal/Http/HttpProtocol.cs @@ -407,21 +407,20 @@ public void Reset() _manuallySetRequestAbortToken = null; - // Lock to prevent CancelRequestAbortedToken from attempting to cancel a disposed CTS. - CancellationTokenSource? localAbortCts = null; - lock (_abortLock) { _preventRequestAbortedCancellation = false; - if (_abortedCts?.TryReset() == false) + + // If the connection has already been aborted, allow that to be observed during the next request. + if (!_connectionAborted && _abortedCts is not null) { - localAbortCts = _abortedCts; - _abortedCts = null; + // _connectionAborted is terminal and only set inside the _abortLock, so if it isn't set here, + // _abortedCts has not been canceled yet. + var resetSuccess = _abortedCts.TryReset(); + Debug.Assert(resetSuccess); } } - localAbortCts?.Dispose(); - Output?.Reset(); _requestHeadersParsed = 0; @@ -760,7 +759,7 @@ private async Task ProcessRequests(IHttpApplication applicat } else if (!HasResponseStarted) { - // If the request was aborted and no response was sent, we use status code 499 for logging + // If the request was aborted and no response was sent, we use status code 499 for logging StatusCode = StatusCodes.Status499ClientClosedRequest; } } diff --git a/src/Servers/Kestrel/Core/src/Internal/Http2/Http2Stream.cs b/src/Servers/Kestrel/Core/src/Internal/Http2/Http2Stream.cs index da2b547c03a7..8524b35f3947 100644 --- a/src/Servers/Kestrel/Core/src/Internal/Http2/Http2Stream.cs +++ b/src/Servers/Kestrel/Core/src/Internal/Http2/Http2Stream.cs @@ -129,7 +129,6 @@ public bool ReceivedEmptyRequestBody protected override void OnReset() { _keepAlive = true; - _connectionAborted = false; _userTrailers = null; // Reset Http2 Features diff --git a/src/Servers/Kestrel/Core/src/Internal/Http3/Http3Connection.cs b/src/Servers/Kestrel/Core/src/Internal/Http3/Http3Connection.cs index 20ec37eb3cc7..c2db22acd8bd 100644 --- a/src/Servers/Kestrel/Core/src/Internal/Http3/Http3Connection.cs +++ b/src/Servers/Kestrel/Core/src/Internal/Http3/Http3Connection.cs @@ -602,15 +602,16 @@ private async Task CreateHttp3Stream(ConnectionContext streamContext, // Check whether there is an existing HTTP/3 stream on the transport stream. // A stream will only be cached if the transport stream itself is reused. - if (!persistentStateFeature.State.TryGetValue(StreamPersistentStateKey, out var s)) + if (!persistentStateFeature.State.TryGetValue(StreamPersistentStateKey, out var s) || + s is not Http3Stream { CanReuse: true } reusableStream) { stream = new Http3Stream(application, CreateHttpStreamContext(streamContext)); - persistentStateFeature.State.Add(StreamPersistentStateKey, stream); + persistentStateFeature.State[StreamPersistentStateKey] = stream; } else { - stream = (Http3Stream)s!; - stream.InitializeWithExistingContext(streamContext.Transport); + stream = reusableStream; + reusableStream.InitializeWithExistingContext(streamContext.Transport); } _streamLifetimeHandler.OnStreamCreated(stream); diff --git a/src/Servers/Kestrel/Core/src/Internal/Http3/Http3Stream.cs b/src/Servers/Kestrel/Core/src/Internal/Http3/Http3Stream.cs index 672c566b7538..7ada6890f867 100644 --- a/src/Servers/Kestrel/Core/src/Internal/Http3/Http3Stream.cs +++ b/src/Servers/Kestrel/Core/src/Internal/Http3/Http3Stream.cs @@ -66,6 +66,8 @@ internal abstract partial class Http3Stream : HttpProtocol, IHttp3Stream, IHttpS private bool IsAbortedRead => (_completionState & StreamCompletionFlags.AbortedRead) == StreamCompletionFlags.AbortedRead; public bool IsCompleted => (_completionState & StreamCompletionFlags.Completed) == StreamCompletionFlags.Completed; + public bool CanReuse => !_connectionAborted && HasResponseCompleted; + public bool ReceivedEmptyRequestBody { get @@ -957,7 +959,6 @@ private Task ProcessDataFrameAsync(in ReadOnlySequence payload) protected override void OnReset() { _keepAlive = true; - _connectionAborted = false; _userTrailers = null; _isWebTransportSessionAccepted = false; _isMethodConnect = false; diff --git a/src/Servers/Kestrel/test/Interop.FunctionalTests/Http3/Http3RequestTests.cs b/src/Servers/Kestrel/test/Interop.FunctionalTests/Http3/Http3RequestTests.cs index f76cd141bb89..e0840b525f6d 100644 --- a/src/Servers/Kestrel/test/Interop.FunctionalTests/Http3/Http3RequestTests.cs +++ b/src/Servers/Kestrel/test/Interop.FunctionalTests/Http3/Http3RequestTests.cs @@ -879,6 +879,160 @@ public async Task GET_MultipleRequestsInSequence_ReusedState() } } + [ConditionalFact] + [MsQuicSupported] + public async Task GET_RequestAbortedByClient_StateNotReused() + { + // Arrange + object persistedState = null; + var requestCount = 0; + var abortedTcs = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously); + var requestStartedTcs = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously); + + var builder = CreateHostBuilder(async context => + { + requestCount++; + var persistentStateCollection = context.Features.Get().State; + if (persistentStateCollection.TryGetValue("Counter", out var value)) + { + persistedState = value; + } + persistentStateCollection["Counter"] = requestCount; + + if (requestCount == 1) + { + // For the first request, wait for RequestAborted to fire before returning + context.RequestAborted.Register(() => + { + Logger.LogInformation("Server received cancellation"); + abortedTcs.SetResult(); + }); + + // Signal that the request has started and is ready to be cancelled + requestStartedTcs.SetResult(); + + // Wait for the request to be aborted + await abortedTcs.Task; + } + }); + + using (var host = builder.Build()) + using (var client = HttpHelpers.CreateClient()) + { + await host.StartAsync(); + + // Act - Send first request and cancel it + var cts1 = new CancellationTokenSource(); + var request1 = new HttpRequestMessage(HttpMethod.Get, $"https://127.0.0.1:{host.GetPort()}/"); + request1.Version = HttpVersion.Version30; + request1.VersionPolicy = HttpVersionPolicy.RequestVersionExact; + + var responseTask1 = client.SendAsync(request1, cts1.Token); + + // Wait for the server to start processing the request + await requestStartedTcs.Task.DefaultTimeout(); + + // Cancel the first request + cts1.Cancel(); + await Assert.ThrowsAnyAsync(() => responseTask1).DefaultTimeout(); + + // Wait for the server to process the abort + await abortedTcs.Task.DefaultTimeout(); + + // Store the state from the first (aborted) request + var firstRequestState = persistedState; + + // Delay to ensure the stream has enough time to return to pool + await Task.Delay(100); + + // Send second request (should not reuse state from aborted request) + var request2 = new HttpRequestMessage(HttpMethod.Get, $"https://127.0.0.1:{host.GetPort()}/"); + request2.Version = HttpVersion.Version30; + request2.VersionPolicy = HttpVersionPolicy.RequestVersionExact; + + var response2 = await client.SendAsync(request2, CancellationToken.None); + response2.EnsureSuccessStatusCode(); + var secondRequestState = persistedState; + + // Assert + // First request has no persisted state (it was aborted) + Assert.Null(firstRequestState); + + // Second request should also have no persisted state since the first request was aborted + // and state should not be reused from aborted requests + Assert.Null(secondRequestState); + + await host.StopAsync(); + } + } + + [ConditionalFact] + [MsQuicSupported] + public async Task GET_RequestAbortedByServer_StateNotReused() + { + // Arrange + object persistedState = null; + var requestCount = 0; + + var builder = CreateHostBuilder(context => + { + requestCount++; + var persistentStateCollection = context.Features.Get().State; + if (persistentStateCollection.TryGetValue("Counter", out var value)) + { + persistedState = value; + } + persistentStateCollection["Counter"] = requestCount; + + if (requestCount == 1) + { + context.Abort(); + } + + return Task.CompletedTask; + }); + + using (var host = builder.Build()) + using (var client = HttpHelpers.CreateClient()) + { + await host.StartAsync(); + + var request1 = new HttpRequestMessage(HttpMethod.Get, $"https://127.0.0.1:{host.GetPort()}/"); + request1.Version = HttpVersion.Version30; + request1.VersionPolicy = HttpVersionPolicy.RequestVersionExact; + + var responseTask1 = client.SendAsync(request1, CancellationToken.None); + var ex = await Assert.ThrowsAnyAsync(() => responseTask1).DefaultTimeout(); + var innerEx = Assert.IsType(ex.InnerException); + Assert.Equal(Http3ErrorCode.InternalError, (Http3ErrorCode)innerEx.ErrorCode); + + // Store the state from the first (aborted) request + var firstRequestState = persistedState; + + // Delay to ensure the stream has enough time to return to pool + await Task.Delay(100); + + // Send second request (should not reuse state from aborted request) + var request2 = new HttpRequestMessage(HttpMethod.Get, $"https://127.0.0.1:{host.GetPort()}/"); + request2.Version = HttpVersion.Version30; + request2.VersionPolicy = HttpVersionPolicy.RequestVersionExact; + + var response2 = await client.SendAsync(request2, CancellationToken.None); + response2.EnsureSuccessStatusCode(); + var secondRequestState = persistedState; + + // Assert + // First request has no persisted state (it was aborted) + Assert.Null(firstRequestState); + + // Second request should also have no persisted state since the first request was aborted + // and state should not be reused from aborted requests + Assert.Null(secondRequestState); + + await host.StopAsync(); + } + } + [ConditionalFact] [MsQuicSupported] public async Task GET_MultipleRequests_RequestVersionOrHigher_UpgradeToHttp3()