Skip to content

Avoid race that can cause Kestrel's RequestAborted to not fire #62385

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 4 commits into from
Aug 15, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 8 additions & 9 deletions src/Servers/Kestrel/Core/src/Internal/Http/HttpProtocol.cs
Original file line number Diff line number Diff line change
Expand Up @@ -407,21 +407,20 @@ public void Reset()

_manuallySetRequestAbortToken = null;

// Lock to prevent CancelRequestAbortedToken from attempting to cancel a disposed CTS.
CancellationTokenSource? localAbortCts = null;

lock (_abortLock)
{
_preventRequestAbortedCancellation = false;
if (_abortedCts?.TryReset() == false)

// If the connection has already been aborted, allow that to be observed during the next request.
if (!_connectionAborted && _abortedCts is not null)
{
localAbortCts = _abortedCts;
_abortedCts = null;
// _connectionAborted is terminal and only set inside the _abortLock, so if it isn't set here,
// _abortedCts has not been canceled yet.
var resetSuccess = _abortedCts.TryReset();
Debug.Assert(resetSuccess);
}
}

localAbortCts?.Dispose();

Output?.Reset();

_requestHeadersParsed = 0;
Expand Down Expand Up @@ -760,7 +759,7 @@ private async Task ProcessRequests<TContext>(IHttpApplication<TContext> applicat
}
else if (!HasResponseStarted)
{
// If the request was aborted and no response was sent, we use status code 499 for logging
// If the request was aborted and no response was sent, we use status code 499 for logging
StatusCode = StatusCodes.Status499ClientClosedRequest;
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -129,7 +129,6 @@ public bool ReceivedEmptyRequestBody
protected override void OnReset()
{
_keepAlive = true;
_connectionAborted = false;
_userTrailers = null;

// Reset Http2 Features
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -602,15 +602,16 @@ private async Task CreateHttp3Stream<TContext>(ConnectionContext streamContext,

// Check whether there is an existing HTTP/3 stream on the transport stream.
// A stream will only be cached if the transport stream itself is reused.
if (!persistentStateFeature.State.TryGetValue(StreamPersistentStateKey, out var s))
if (!persistentStateFeature.State.TryGetValue(StreamPersistentStateKey, out var s) ||
s is not Http3Stream<TContext> { CanReuse: true } reusableStream)
{
stream = new Http3Stream<TContext>(application, CreateHttpStreamContext(streamContext));
persistentStateFeature.State.Add(StreamPersistentStateKey, stream);
persistentStateFeature.State[StreamPersistentStateKey] = stream;
}
else
{
stream = (Http3Stream<TContext>)s!;
stream.InitializeWithExistingContext(streamContext.Transport);
stream = reusableStream;
reusableStream.InitializeWithExistingContext(streamContext.Transport);
}

_streamLifetimeHandler.OnStreamCreated(stream);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,8 @@ internal abstract partial class Http3Stream : HttpProtocol, IHttp3Stream, IHttpS
private bool IsAbortedRead => (_completionState & StreamCompletionFlags.AbortedRead) == StreamCompletionFlags.AbortedRead;
public bool IsCompleted => (_completionState & StreamCompletionFlags.Completed) == StreamCompletionFlags.Completed;

public bool CanReuse => !_connectionAborted && HasResponseCompleted;

public bool ReceivedEmptyRequestBody
{
get
Expand Down Expand Up @@ -957,7 +959,6 @@ private Task ProcessDataFrameAsync(in ReadOnlySequence<byte> payload)
protected override void OnReset()
{
_keepAlive = true;
_connectionAborted = false;
_userTrailers = null;
_isWebTransportSessionAccepted = false;
_isMethodConnect = false;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -879,6 +879,160 @@ public async Task GET_MultipleRequestsInSequence_ReusedState()
}
}

[ConditionalFact]
[MsQuicSupported]
public async Task GET_RequestAbortedByClient_StateNotReused()
{
// Arrange
object persistedState = null;
var requestCount = 0;
var abortedTcs = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously);
var requestStartedTcs = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously);

var builder = CreateHostBuilder(async context =>
{
requestCount++;
var persistentStateCollection = context.Features.Get<IPersistentStateFeature>().State;
if (persistentStateCollection.TryGetValue("Counter", out var value))
{
persistedState = value;
}
persistentStateCollection["Counter"] = requestCount;

if (requestCount == 1)
{
// For the first request, wait for RequestAborted to fire before returning
context.RequestAborted.Register(() =>
{
Logger.LogInformation("Server received cancellation");
abortedTcs.SetResult();
});

// Signal that the request has started and is ready to be cancelled
requestStartedTcs.SetResult();

// Wait for the request to be aborted
await abortedTcs.Task;
}
});

using (var host = builder.Build())
using (var client = HttpHelpers.CreateClient())
{
await host.StartAsync();

// Act - Send first request and cancel it
var cts1 = new CancellationTokenSource();
var request1 = new HttpRequestMessage(HttpMethod.Get, $"https://127.0.0.1:{host.GetPort()}/");
request1.Version = HttpVersion.Version30;
request1.VersionPolicy = HttpVersionPolicy.RequestVersionExact;

var responseTask1 = client.SendAsync(request1, cts1.Token);

// Wait for the server to start processing the request
await requestStartedTcs.Task.DefaultTimeout();

// Cancel the first request
cts1.Cancel();
await Assert.ThrowsAnyAsync<OperationCanceledException>(() => responseTask1).DefaultTimeout();

// Wait for the server to process the abort
await abortedTcs.Task.DefaultTimeout();

// Store the state from the first (aborted) request
var firstRequestState = persistedState;

// Delay to ensure the stream has enough time to return to pool
await Task.Delay(100);

// Send second request (should not reuse state from aborted request)
var request2 = new HttpRequestMessage(HttpMethod.Get, $"https://127.0.0.1:{host.GetPort()}/");
request2.Version = HttpVersion.Version30;
request2.VersionPolicy = HttpVersionPolicy.RequestVersionExact;

var response2 = await client.SendAsync(request2, CancellationToken.None);
response2.EnsureSuccessStatusCode();
var secondRequestState = persistedState;

// Assert
// First request has no persisted state (it was aborted)
Assert.Null(firstRequestState);

// Second request should also have no persisted state since the first request was aborted
// and state should not be reused from aborted requests
Assert.Null(secondRequestState);

await host.StopAsync();
}
}

[ConditionalFact]
[MsQuicSupported]
public async Task GET_RequestAbortedByServer_StateNotReused()
{
// Arrange
object persistedState = null;
var requestCount = 0;

var builder = CreateHostBuilder(context =>
{
requestCount++;
var persistentStateCollection = context.Features.Get<IPersistentStateFeature>().State;
if (persistentStateCollection.TryGetValue("Counter", out var value))
{
persistedState = value;
}
persistentStateCollection["Counter"] = requestCount;

if (requestCount == 1)
{
context.Abort();
}

return Task.CompletedTask;
});

using (var host = builder.Build())
using (var client = HttpHelpers.CreateClient())
{
await host.StartAsync();

var request1 = new HttpRequestMessage(HttpMethod.Get, $"https://127.0.0.1:{host.GetPort()}/");
request1.Version = HttpVersion.Version30;
request1.VersionPolicy = HttpVersionPolicy.RequestVersionExact;

var responseTask1 = client.SendAsync(request1, CancellationToken.None);
var ex = await Assert.ThrowsAnyAsync<HttpRequestException>(() => responseTask1).DefaultTimeout();
var innerEx = Assert.IsType<HttpProtocolException>(ex.InnerException);
Assert.Equal(Http3ErrorCode.InternalError, (Http3ErrorCode)innerEx.ErrorCode);

// Store the state from the first (aborted) request
var firstRequestState = persistedState;

// Delay to ensure the stream has enough time to return to pool
await Task.Delay(100);

// Send second request (should not reuse state from aborted request)
var request2 = new HttpRequestMessage(HttpMethod.Get, $"https://127.0.0.1:{host.GetPort()}/");
request2.Version = HttpVersion.Version30;
request2.VersionPolicy = HttpVersionPolicy.RequestVersionExact;

var response2 = await client.SendAsync(request2, CancellationToken.None);
response2.EnsureSuccessStatusCode();
var secondRequestState = persistedState;

// Assert
// First request has no persisted state (it was aborted)
Assert.Null(firstRequestState);

// Second request should also have no persisted state since the first request was aborted
// and state should not be reused from aborted requests
Assert.Null(secondRequestState);

await host.StopAsync();
}
}

[ConditionalFact]
[MsQuicSupported]
public async Task GET_MultipleRequests_RequestVersionOrHigher_UpgradeToHttp3()
Expand Down
Loading