From dc60a1391d8fa9a25c7e5c272e771d92f4e71947 Mon Sep 17 00:00:00 2001 From: Anthony Lloyd Date: Tue, 22 Apr 2025 21:45:19 +0100 Subject: [PATCH] memory improvement for ReadMessageAsync compression --- .../Internal/StreamExtensions.cs | 149 +++++++++++++----- 1 file changed, 112 insertions(+), 37 deletions(-) diff --git a/src/Grpc.Net.Client/Internal/StreamExtensions.cs b/src/Grpc.Net.Client/Internal/StreamExtensions.cs index 7048a6fba..19e3cc6c4 100644 --- a/src/Grpc.Net.Client/Internal/StreamExtensions.cs +++ b/src/Grpc.Net.Client/Internal/StreamExtensions.cs @@ -18,6 +18,7 @@ using System.Buffers; using System.Buffers.Binary; +using System.Collections; using System.Diagnostics; using System.Diagnostics.CodeAnalysis; using System.Runtime.InteropServices; @@ -99,27 +100,15 @@ private static Status CreateUnknownMessageEncodingMessageStatus(string unsupport var compressed = ReadCompressedFlag(buffer[0]); var length = ReadMessageLength(buffer.AsSpan(1, 4)); - if (length > 0) + if (length > call.Channel.ReceiveMaxMessageSize) { - if (length > call.Channel.ReceiveMaxMessageSize) - { - throw call.CreateRpcException(ReceivedMessageExceedsLimitStatus); - } - - // Replace buffer if the message doesn't fit - if (buffer.Length < length) - { - ArrayPool.Shared.Return(buffer); - buffer = ArrayPool.Shared.Rent(length); - } - - await ReadMessageContentAsync(responseStream, buffer, length, cancellationToken).ConfigureAwait(false); + throw call.CreateRpcException(ReceivedMessageExceedsLimitStatus); } cancellationToken.ThrowIfCancellationRequested(); - ReadOnlySequence payload; - if (compressed) + TResponse message; + if (compressed && length > 0) { if (grpcEncoding == null) { @@ -130,28 +119,55 @@ private static Status CreateUnknownMessageEncodingMessageStatus(string unsupport throw call.CreateRpcException(IdentityMessageEncodingMessageStatus); } - // Performance improvement would be to decompress without converting to an intermediary byte array - if (!TryDecompressMessage(call.Logger, grpcEncoding, call.Channel.CompressionProviders, buffer, length, out var decompressedMessage)) + if (call.Channel.CompressionProviders.TryGetValue(grpcEncoding, out var compressionProvider)) + { + GrpcCallLog.DecompressingMessage(call.Logger, compressionProvider.EncodingName); + var moreBuffers = new List(); + try + { + int lastLength; + using (var compressionStream = compressionProvider.CreateDecompressionStream(new FixedLengthStream(responseStream, length))) + { + var underLohLength = Math.Min(Math.Max(4096, length), 65536); + lastLength = await ReadStreamToBuffers(compressionStream, buffer, moreBuffers, underLohLength, cancellationToken).ConfigureAwait(false); + } + call.DeserializationContext.SetPayload(BuffersToReadOnlySequence(buffer, moreBuffers, lastLength)); + message = deserializer(call.DeserializationContext); + } + finally + { + foreach (var byteArray in moreBuffers) + { + ArrayPool.Shared.Return(byteArray); + } + } + } + else { - var supportedEncodings = new List(); - supportedEncodings.Add(GrpcProtocolConstants.IdentityGrpcEncoding); + var supportedEncodings = new List(call.Channel.CompressionProviders.Count + 1) { GrpcProtocolConstants.IdentityGrpcEncoding }; supportedEncodings.AddRange(call.Channel.CompressionProviders.Select(c => c.Key)); throw call.CreateRpcException(CreateUnknownMessageEncodingMessageStatus(grpcEncoding, supportedEncodings)); } - - payload = decompressedMessage; } else { - payload = new ReadOnlySequence(buffer, 0, length); + if (length > 0) + { + // Replace buffer if the message doesn't fit + if (buffer.Length < length) + { + ArrayPool.Shared.Return(buffer); + buffer = ArrayPool.Shared.Rent(length); + } + await ReadMessageContentAsync(responseStream, buffer, length, cancellationToken).ConfigureAwait(false); + } + call.DeserializationContext.SetPayload(new ReadOnlySequence(buffer, 0, length)); + message = deserializer(call.DeserializationContext); } + call.DeserializationContext.SetPayload(null); GrpcCallLog.DeserializingMessage(call.Logger, length, typeof(TResponse)); - call.DeserializationContext.SetPayload(payload); - var message = deserializer(call.DeserializationContext); - call.DeserializationContext.SetPayload(null); - if (singleMessage) { // Check that there is no additional content in the stream for a single message @@ -251,24 +267,83 @@ private static async Task ReadMessageContentAsync(Stream responseStream, Memory< } } - private static bool TryDecompressMessage(ILogger logger, string compressionEncoding, Dictionary compressionProviders, byte[] messageData, int length, out ReadOnlySequence result) + private sealed class FixedLengthStream(Stream stream, int length) : Stream { - if (compressionProviders.TryGetValue(compressionEncoding, out var compressionProvider)) + private int _bytesRead; + public override int Read(byte[] buffer, int offset, int count) { - GrpcCallLog.DecompressingMessage(logger, compressionProvider.EncodingName); + var bytesToRead = Math.Min(count, length - _bytesRead); + if (bytesToRead <= 0) + { + return 0; + } + var bytesRead = stream.Read(buffer, offset, bytesToRead); + if (bytesRead == 0) + { + throw new InvalidDataException("Unexpected end of content while reading the message content."); + } + _bytesRead += bytesRead; + return bytesRead; + } + public override bool CanRead => true; + public override bool CanSeek => false; + public override bool CanWrite => false; + public override long Length => length; + public override long Position { get => _bytesRead; set => throw new NotSupportedException(); } + public override void Flush() => throw new NotSupportedException(); + public override long Seek(long offset, SeekOrigin origin) => throw new NotSupportedException(); + public override void SetLength(long value) => throw new NotSupportedException(); + public override void Write(byte[] buffer, int offset, int count) => throw new NotSupportedException(); + } - var output = new MemoryStream(); - using (var compressionStream = compressionProvider.CreateDecompressionStream(new MemoryStream(messageData, 0, length, writable: true, publiclyVisible: true))) + private static async Task ReadStreamToBuffers(Stream stream, byte[] buffer, List moreBuffers, int moreLength, CancellationToken cancellationToken) + { + while (true) + { + var offset = 0; + while (offset < buffer.Length) { - compressionStream.CopyTo(output); + var read = await stream.ReadAsync(buffer.AsMemory(offset), cancellationToken).ConfigureAwait(false); + if (read == 0) + { + return offset; + } + offset += read; } + moreBuffers.Add(buffer = ArrayPool.Shared.Rent(moreLength)); + } + } - result = new ReadOnlySequence(output.GetBuffer(), 0, (int)output.Length); - return true; + private static ReadOnlySequence BuffersToReadOnlySequence(byte[] buffer, List moreBuffers, int lastLength) + { + if (moreBuffers.Count == 0) + { + return new ReadOnlySequence(buffer, 0, lastLength); } + var runningIndex = buffer.Length; + for (var i = moreBuffers.Count - 2; i >= 0; i--) + { + runningIndex += moreBuffers[i].Length; + } + var endSegment = new ReadOnlySequenceSegmentByte(moreBuffers[moreBuffers.Count - 1].AsMemory(0, lastLength), null, runningIndex); + var startSegment = endSegment; + for (var i = moreBuffers.Count - 2; i >= 0; i--) + { + var bytes = moreBuffers[i]; + startSegment = new ReadOnlySequenceSegmentByte(bytes, startSegment, runningIndex -= bytes.Length); + } + startSegment = new ReadOnlySequenceSegmentByte(buffer, startSegment, 0); + return new ReadOnlySequence(startSegment, 0, endSegment, lastLength); + } - result = default; - return false; + private sealed class ReadOnlySequenceSegmentByte : ReadOnlySequenceSegment + { + public ReadOnlySequenceSegmentByte(ReadOnlyMemory memory, ReadOnlySequenceSegmentByte? next, int runningIndex) + { + Memory = memory; + Next = next; + RunningIndex = runningIndex; + } } private static bool ReadCompressedFlag(byte flag)