Skip to content

Commit 4fa7460

Browse files
committed
memory improvement for ReadMessageAsync compression
1 parent d170b24 commit 4fa7460

File tree

1 file changed

+110
-37
lines changed

1 file changed

+110
-37
lines changed

src/Grpc.Net.Client/Internal/StreamExtensions.cs

Lines changed: 110 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818

1919
using System.Buffers;
2020
using System.Buffers.Binary;
21+
using System.Collections;
2122
using System.Diagnostics;
2223
using System.Diagnostics.CodeAnalysis;
2324
using System.Runtime.InteropServices;
@@ -99,27 +100,15 @@ private static Status CreateUnknownMessageEncodingMessageStatus(string unsupport
99100
var compressed = ReadCompressedFlag(buffer[0]);
100101
var length = ReadMessageLength(buffer.AsSpan(1, 4));
101102

102-
if (length > 0)
103+
if (length > call.Channel.ReceiveMaxMessageSize)
103104
{
104-
if (length > call.Channel.ReceiveMaxMessageSize)
105-
{
106-
throw call.CreateRpcException(ReceivedMessageExceedsLimitStatus);
107-
}
108-
109-
// Replace buffer if the message doesn't fit
110-
if (buffer.Length < length)
111-
{
112-
ArrayPool<byte>.Shared.Return(buffer);
113-
buffer = ArrayPool<byte>.Shared.Rent(length);
114-
}
115-
116-
await ReadMessageContentAsync(responseStream, buffer, length, cancellationToken).ConfigureAwait(false);
105+
throw call.CreateRpcException(ReceivedMessageExceedsLimitStatus);
117106
}
118107

119108
cancellationToken.ThrowIfCancellationRequested();
120109

121-
ReadOnlySequence<byte> payload;
122-
if (compressed)
110+
TResponse message;
111+
if (compressed && length > 0)
123112
{
124113
if (grpcEncoding == null)
125114
{
@@ -130,28 +119,57 @@ private static Status CreateUnknownMessageEncodingMessageStatus(string unsupport
130119
throw call.CreateRpcException(IdentityMessageEncodingMessageStatus);
131120
}
132121

133-
// Performance improvement would be to decompress without converting to an intermediary byte array
134-
if (!TryDecompressMessage(call.Logger, grpcEncoding, call.Channel.CompressionProviders, buffer, length, out var decompressedMessage))
122+
if (call.Channel.CompressionProviders.TryGetValue(grpcEncoding, out var compressionProvider))
123+
{
124+
GrpcCallLog.DecompressingMessage(call.Logger, compressionProvider.EncodingName);
125+
int lastLength;
126+
var moreBuffers = new List<byte[]>();
127+
try
128+
{
129+
using (var compressionStream = compressionProvider.CreateDecompressionStream(new LimitedLengthStream(responseStream, length)))
130+
{
131+
lastLength = await ReadStreamToBuffers(compressionStream, buffer, moreBuffers, length, cancellationToken).ConfigureAwait(false);
132+
}
133+
call.DeserializationContext.SetPayload(BuffersToReadOnlySequence(buffer, moreBuffers, lastLength));
134+
message = deserializer(call.DeserializationContext);
135+
}
136+
finally
137+
{
138+
ArrayPool<byte>.Shared.Return(buffer);
139+
buffer = null;
140+
foreach (var byteArray in moreBuffers)
141+
{
142+
ArrayPool<byte>.Shared.Return(byteArray);
143+
}
144+
}
145+
}
146+
else
135147
{
136-
var supportedEncodings = new List<string>();
137-
supportedEncodings.Add(GrpcProtocolConstants.IdentityGrpcEncoding);
148+
var supportedEncodings = new List<string>(call.Channel.CompressionProviders.Count + 1) { GrpcProtocolConstants.IdentityGrpcEncoding };
138149
supportedEncodings.AddRange(call.Channel.CompressionProviders.Select(c => c.Key));
139150
throw call.CreateRpcException(CreateUnknownMessageEncodingMessageStatus(grpcEncoding, supportedEncodings));
140151
}
141-
142-
payload = decompressedMessage;
143152
}
144153
else
145154
{
146-
payload = new ReadOnlySequence<byte>(buffer, 0, length);
155+
if (length > 0)
156+
{
157+
// Replace buffer if the message doesn't fit
158+
if (buffer.Length < length)
159+
{
160+
ArrayPool<byte>.Shared.Return(buffer);
161+
buffer = ArrayPool<byte>.Shared.Rent(length);
162+
}
163+
await ReadMessageContentAsync(responseStream, buffer, length, cancellationToken).ConfigureAwait(false);
164+
}
165+
var payload = new ReadOnlySequence<byte>(buffer, 0, length);
166+
call.DeserializationContext.SetPayload(payload);
167+
message = deserializer(call.DeserializationContext);
147168
}
169+
call.DeserializationContext.SetPayload(null);
148170

149171
GrpcCallLog.DeserializingMessage(call.Logger, length, typeof(TResponse));
150172

151-
call.DeserializationContext.SetPayload(payload);
152-
var message = deserializer(call.DeserializationContext);
153-
call.DeserializationContext.SetPayload(null);
154-
155173
if (singleMessage)
156174
{
157175
// Check that there is no additional content in the stream for a single message
@@ -251,24 +269,79 @@ private static async Task ReadMessageContentAsync(Stream responseStream, Memory<
251269
}
252270
}
253271

254-
private static bool TryDecompressMessage(ILogger logger, string compressionEncoding, Dictionary<string, ICompressionProvider> compressionProviders, byte[] messageData, int length, out ReadOnlySequence<byte> result)
272+
private sealed class LimitedLengthStream(Stream stream, int length) : Stream
255273
{
256-
if (compressionProviders.TryGetValue(compressionEncoding, out var compressionProvider))
274+
private int _bytesRead;
275+
public override int Read(byte[] buffer, int offset, int count)
257276
{
258-
GrpcCallLog.DecompressingMessage(logger, compressionProvider.EncodingName);
277+
var bytesToRead = Math.Min(count, length - _bytesRead);
278+
if (bytesToRead <= 0)
279+
{
280+
return 0;
281+
}
282+
var bytesRead = stream.Read(buffer, offset, bytesToRead);
283+
_bytesRead += bytesRead;
284+
return bytesRead;
285+
}
286+
public override bool CanRead => true;
287+
public override bool CanSeek => false;
288+
public override bool CanWrite => false;
289+
public override long Length => length;
290+
public override long Position { get => _bytesRead; set => throw new NotSupportedException(); }
291+
public override void Flush() => throw new NotSupportedException();
292+
public override long Seek(long offset, SeekOrigin origin) => throw new NotSupportedException();
293+
public override void SetLength(long value) => throw new NotSupportedException();
294+
public override void Write(byte[] buffer, int offset, int count) => throw new NotSupportedException();
295+
}
259296

260-
var output = new MemoryStream();
261-
using (var compressionStream = compressionProvider.CreateDecompressionStream(new MemoryStream(messageData, 0, length, writable: true, publiclyVisible: true)))
297+
private static async Task<int> ReadStreamToBuffers(Stream stream, byte[] buffer, List<byte[]> moreBuffers, int moreLength, CancellationToken cancellationToken)
298+
{
299+
while (true)
300+
{
301+
var offset = 0;
302+
while (offset < buffer.Length)
262303
{
263-
compressionStream.CopyTo(output);
304+
var read = await stream.ReadAsync(buffer.AsMemory(offset), cancellationToken).ConfigureAwait(false);
305+
if (read == 0)
306+
{
307+
return offset;
308+
}
309+
offset += read;
264310
}
311+
moreBuffers.Add(buffer = ArrayPool<byte>.Shared.Rent(moreLength));
312+
}
313+
}
265314

266-
result = new ReadOnlySequence<byte>(output.GetBuffer(), 0, (int)output.Length);
267-
return true;
315+
private static ReadOnlySequence<byte> BuffersToReadOnlySequence(byte[] buffer, List<byte[]> moreBuffers, int lastLength)
316+
{
317+
if (moreBuffers.Count == 0)
318+
{
319+
return new ReadOnlySequence<byte>(buffer, 0, lastLength);
268320
}
321+
var runningIndex = buffer.Length;
322+
for (var i = moreBuffers.Count - 2; i >= 0; i--)
323+
{
324+
runningIndex += moreBuffers[i].Length;
325+
}
326+
var endSegment = new ReadOnlySequenceSegmentByte(moreBuffers[moreBuffers.Count - 1].AsMemory(0, lastLength), null, runningIndex);
327+
var startSegment = endSegment;
328+
for (var i = moreBuffers.Count - 2; i >= 0; i--)
329+
{
330+
var bytes = moreBuffers[i];
331+
startSegment = new ReadOnlySequenceSegmentByte(bytes, startSegment, runningIndex -= bytes.Length);
332+
}
333+
startSegment = new ReadOnlySequenceSegmentByte(buffer, startSegment, 0);
334+
return new ReadOnlySequence<byte>(startSegment, 0, endSegment, lastLength);
335+
}
269336

270-
result = default;
271-
return false;
337+
private sealed class ReadOnlySequenceSegmentByte : ReadOnlySequenceSegment<byte>
338+
{
339+
public ReadOnlySequenceSegmentByte(ReadOnlyMemory<byte> memory, ReadOnlySequenceSegmentByte? next, int runningIndex)
340+
{
341+
Memory = memory;
342+
Next = next;
343+
RunningIndex = runningIndex;
344+
}
272345
}
273346

274347
private static bool ReadCompressedFlag(byte flag)

0 commit comments

Comments
 (0)