Skip to content

Commit dc60a13

Browse files
committed
memory improvement for ReadMessageAsync compression
1 parent d170b24 commit dc60a13

File tree

1 file changed

+112
-37
lines changed

1 file changed

+112
-37
lines changed

src/Grpc.Net.Client/Internal/StreamExtensions.cs

Lines changed: 112 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818

1919
using System.Buffers;
2020
using System.Buffers.Binary;
21+
using System.Collections;
2122
using System.Diagnostics;
2223
using System.Diagnostics.CodeAnalysis;
2324
using System.Runtime.InteropServices;
@@ -99,27 +100,15 @@ private static Status CreateUnknownMessageEncodingMessageStatus(string unsupport
99100
var compressed = ReadCompressedFlag(buffer[0]);
100101
var length = ReadMessageLength(buffer.AsSpan(1, 4));
101102

102-
if (length > 0)
103+
if (length > call.Channel.ReceiveMaxMessageSize)
103104
{
104-
if (length > call.Channel.ReceiveMaxMessageSize)
105-
{
106-
throw call.CreateRpcException(ReceivedMessageExceedsLimitStatus);
107-
}
108-
109-
// Replace buffer if the message doesn't fit
110-
if (buffer.Length < length)
111-
{
112-
ArrayPool<byte>.Shared.Return(buffer);
113-
buffer = ArrayPool<byte>.Shared.Rent(length);
114-
}
115-
116-
await ReadMessageContentAsync(responseStream, buffer, length, cancellationToken).ConfigureAwait(false);
105+
throw call.CreateRpcException(ReceivedMessageExceedsLimitStatus);
117106
}
118107

119108
cancellationToken.ThrowIfCancellationRequested();
120109

121-
ReadOnlySequence<byte> payload;
122-
if (compressed)
110+
TResponse message;
111+
if (compressed && length > 0)
123112
{
124113
if (grpcEncoding == null)
125114
{
@@ -130,28 +119,55 @@ private static Status CreateUnknownMessageEncodingMessageStatus(string unsupport
130119
throw call.CreateRpcException(IdentityMessageEncodingMessageStatus);
131120
}
132121

133-
// Performance improvement would be to decompress without converting to an intermediary byte array
134-
if (!TryDecompressMessage(call.Logger, grpcEncoding, call.Channel.CompressionProviders, buffer, length, out var decompressedMessage))
122+
if (call.Channel.CompressionProviders.TryGetValue(grpcEncoding, out var compressionProvider))
123+
{
124+
GrpcCallLog.DecompressingMessage(call.Logger, compressionProvider.EncodingName);
125+
var moreBuffers = new List<byte[]>();
126+
try
127+
{
128+
int lastLength;
129+
using (var compressionStream = compressionProvider.CreateDecompressionStream(new FixedLengthStream(responseStream, length)))
130+
{
131+
var underLohLength = Math.Min(Math.Max(4096, length), 65536);
132+
lastLength = await ReadStreamToBuffers(compressionStream, buffer, moreBuffers, underLohLength, cancellationToken).ConfigureAwait(false);
133+
}
134+
call.DeserializationContext.SetPayload(BuffersToReadOnlySequence(buffer, moreBuffers, lastLength));
135+
message = deserializer(call.DeserializationContext);
136+
}
137+
finally
138+
{
139+
foreach (var byteArray in moreBuffers)
140+
{
141+
ArrayPool<byte>.Shared.Return(byteArray);
142+
}
143+
}
144+
}
145+
else
135146
{
136-
var supportedEncodings = new List<string>();
137-
supportedEncodings.Add(GrpcProtocolConstants.IdentityGrpcEncoding);
147+
var supportedEncodings = new List<string>(call.Channel.CompressionProviders.Count + 1) { GrpcProtocolConstants.IdentityGrpcEncoding };
138148
supportedEncodings.AddRange(call.Channel.CompressionProviders.Select(c => c.Key));
139149
throw call.CreateRpcException(CreateUnknownMessageEncodingMessageStatus(grpcEncoding, supportedEncodings));
140150
}
141-
142-
payload = decompressedMessage;
143151
}
144152
else
145153
{
146-
payload = new ReadOnlySequence<byte>(buffer, 0, length);
154+
if (length > 0)
155+
{
156+
// Replace buffer if the message doesn't fit
157+
if (buffer.Length < length)
158+
{
159+
ArrayPool<byte>.Shared.Return(buffer);
160+
buffer = ArrayPool<byte>.Shared.Rent(length);
161+
}
162+
await ReadMessageContentAsync(responseStream, buffer, length, cancellationToken).ConfigureAwait(false);
163+
}
164+
call.DeserializationContext.SetPayload(new ReadOnlySequence<byte>(buffer, 0, length));
165+
message = deserializer(call.DeserializationContext);
147166
}
167+
call.DeserializationContext.SetPayload(null);
148168

149169
GrpcCallLog.DeserializingMessage(call.Logger, length, typeof(TResponse));
150170

151-
call.DeserializationContext.SetPayload(payload);
152-
var message = deserializer(call.DeserializationContext);
153-
call.DeserializationContext.SetPayload(null);
154-
155171
if (singleMessage)
156172
{
157173
// Check that there is no additional content in the stream for a single message
@@ -251,24 +267,83 @@ private static async Task ReadMessageContentAsync(Stream responseStream, Memory<
251267
}
252268
}
253269

254-
private static bool TryDecompressMessage(ILogger logger, string compressionEncoding, Dictionary<string, ICompressionProvider> compressionProviders, byte[] messageData, int length, out ReadOnlySequence<byte> result)
270+
private sealed class FixedLengthStream(Stream stream, int length) : Stream
255271
{
256-
if (compressionProviders.TryGetValue(compressionEncoding, out var compressionProvider))
272+
private int _bytesRead;
273+
public override int Read(byte[] buffer, int offset, int count)
257274
{
258-
GrpcCallLog.DecompressingMessage(logger, compressionProvider.EncodingName);
275+
var bytesToRead = Math.Min(count, length - _bytesRead);
276+
if (bytesToRead <= 0)
277+
{
278+
return 0;
279+
}
280+
var bytesRead = stream.Read(buffer, offset, bytesToRead);
281+
if (bytesRead == 0)
282+
{
283+
throw new InvalidDataException("Unexpected end of content while reading the message content.");
284+
}
285+
_bytesRead += bytesRead;
286+
return bytesRead;
287+
}
288+
public override bool CanRead => true;
289+
public override bool CanSeek => false;
290+
public override bool CanWrite => false;
291+
public override long Length => length;
292+
public override long Position { get => _bytesRead; set => throw new NotSupportedException(); }
293+
public override void Flush() => throw new NotSupportedException();
294+
public override long Seek(long offset, SeekOrigin origin) => throw new NotSupportedException();
295+
public override void SetLength(long value) => throw new NotSupportedException();
296+
public override void Write(byte[] buffer, int offset, int count) => throw new NotSupportedException();
297+
}
259298

260-
var output = new MemoryStream();
261-
using (var compressionStream = compressionProvider.CreateDecompressionStream(new MemoryStream(messageData, 0, length, writable: true, publiclyVisible: true)))
299+
private static async Task<int> ReadStreamToBuffers(Stream stream, byte[] buffer, List<byte[]> moreBuffers, int moreLength, CancellationToken cancellationToken)
300+
{
301+
while (true)
302+
{
303+
var offset = 0;
304+
while (offset < buffer.Length)
262305
{
263-
compressionStream.CopyTo(output);
306+
var read = await stream.ReadAsync(buffer.AsMemory(offset), cancellationToken).ConfigureAwait(false);
307+
if (read == 0)
308+
{
309+
return offset;
310+
}
311+
offset += read;
264312
}
313+
moreBuffers.Add(buffer = ArrayPool<byte>.Shared.Rent(moreLength));
314+
}
315+
}
265316

266-
result = new ReadOnlySequence<byte>(output.GetBuffer(), 0, (int)output.Length);
267-
return true;
317+
private static ReadOnlySequence<byte> BuffersToReadOnlySequence(byte[] buffer, List<byte[]> moreBuffers, int lastLength)
318+
{
319+
if (moreBuffers.Count == 0)
320+
{
321+
return new ReadOnlySequence<byte>(buffer, 0, lastLength);
268322
}
323+
var runningIndex = buffer.Length;
324+
for (var i = moreBuffers.Count - 2; i >= 0; i--)
325+
{
326+
runningIndex += moreBuffers[i].Length;
327+
}
328+
var endSegment = new ReadOnlySequenceSegmentByte(moreBuffers[moreBuffers.Count - 1].AsMemory(0, lastLength), null, runningIndex);
329+
var startSegment = endSegment;
330+
for (var i = moreBuffers.Count - 2; i >= 0; i--)
331+
{
332+
var bytes = moreBuffers[i];
333+
startSegment = new ReadOnlySequenceSegmentByte(bytes, startSegment, runningIndex -= bytes.Length);
334+
}
335+
startSegment = new ReadOnlySequenceSegmentByte(buffer, startSegment, 0);
336+
return new ReadOnlySequence<byte>(startSegment, 0, endSegment, lastLength);
337+
}
269338

270-
result = default;
271-
return false;
339+
private sealed class ReadOnlySequenceSegmentByte : ReadOnlySequenceSegment<byte>
340+
{
341+
public ReadOnlySequenceSegmentByte(ReadOnlyMemory<byte> memory, ReadOnlySequenceSegmentByte? next, int runningIndex)
342+
{
343+
Memory = memory;
344+
Next = next;
345+
RunningIndex = runningIndex;
346+
}
272347
}
273348

274349
private static bool ReadCompressedFlag(byte flag)

0 commit comments

Comments
 (0)