18
18
19
19
using System . Buffers ;
20
20
using System . Buffers . Binary ;
21
+ using System . Collections ;
21
22
using System . Diagnostics ;
22
23
using System . Diagnostics . CodeAnalysis ;
23
24
using System . Runtime . InteropServices ;
@@ -99,27 +100,15 @@ private static Status CreateUnknownMessageEncodingMessageStatus(string unsupport
99
100
var compressed = ReadCompressedFlag ( buffer [ 0 ] ) ;
100
101
var length = ReadMessageLength ( buffer . AsSpan ( 1 , 4 ) ) ;
101
102
102
- if ( length > 0 )
103
+ if ( length > call . Channel . ReceiveMaxMessageSize )
103
104
{
104
- if ( length > call . Channel . ReceiveMaxMessageSize )
105
- {
106
- throw call . CreateRpcException ( ReceivedMessageExceedsLimitStatus ) ;
107
- }
108
-
109
- // Replace buffer if the message doesn't fit
110
- if ( buffer . Length < length )
111
- {
112
- ArrayPool < byte > . Shared . Return ( buffer ) ;
113
- buffer = ArrayPool < byte > . Shared . Rent ( length ) ;
114
- }
115
-
116
- await ReadMessageContentAsync ( responseStream , buffer , length , cancellationToken ) . ConfigureAwait ( false ) ;
105
+ throw call . CreateRpcException ( ReceivedMessageExceedsLimitStatus ) ;
117
106
}
118
107
119
108
cancellationToken . ThrowIfCancellationRequested ( ) ;
120
109
121
- ReadOnlySequence < byte > payload ;
122
- if ( compressed )
110
+ TResponse message ;
111
+ if ( compressed && length > 0 )
123
112
{
124
113
if ( grpcEncoding == null )
125
114
{
@@ -130,28 +119,57 @@ private static Status CreateUnknownMessageEncodingMessageStatus(string unsupport
130
119
throw call . CreateRpcException ( IdentityMessageEncodingMessageStatus ) ;
131
120
}
132
121
133
- // Performance improvement would be to decompress without converting to an intermediary byte array
134
- if ( ! TryDecompressMessage ( call . Logger , grpcEncoding , call . Channel . CompressionProviders , buffer , length , out var decompressedMessage ) )
122
+ if ( call . Channel . CompressionProviders . TryGetValue ( grpcEncoding , out var compressionProvider ) )
123
+ {
124
+ GrpcCallLog . DecompressingMessage ( call . Logger , compressionProvider . EncodingName ) ;
125
+ int lastLength ;
126
+ var moreBuffers = new List < byte [ ] > ( ) ;
127
+ try
128
+ {
129
+ using ( var compressionStream = compressionProvider . CreateDecompressionStream ( new LimitedLengthStream ( responseStream , length ) ) )
130
+ {
131
+ lastLength = await ReadStreamToBuffers ( compressionStream , buffer , moreBuffers , length , cancellationToken ) . ConfigureAwait ( false ) ;
132
+ }
133
+ call . DeserializationContext . SetPayload ( BuffersToReadOnlySequence ( buffer , moreBuffers , lastLength ) ) ;
134
+ message = deserializer ( call . DeserializationContext ) ;
135
+ }
136
+ finally
137
+ {
138
+ ArrayPool < byte > . Shared . Return ( buffer ) ;
139
+ buffer = null ;
140
+ foreach ( var byteArray in moreBuffers )
141
+ {
142
+ ArrayPool < byte > . Shared . Return ( byteArray ) ;
143
+ }
144
+ }
145
+ }
146
+ else
135
147
{
136
- var supportedEncodings = new List < string > ( ) ;
137
- supportedEncodings . Add ( GrpcProtocolConstants . IdentityGrpcEncoding ) ;
148
+ var supportedEncodings = new List < string > ( call . Channel . CompressionProviders . Count + 1 ) { GrpcProtocolConstants . IdentityGrpcEncoding } ;
138
149
supportedEncodings . AddRange ( call . Channel . CompressionProviders . Select ( c => c . Key ) ) ;
139
150
throw call . CreateRpcException ( CreateUnknownMessageEncodingMessageStatus ( grpcEncoding , supportedEncodings ) ) ;
140
151
}
141
-
142
- payload = decompressedMessage ;
143
152
}
144
153
else
145
154
{
146
- payload = new ReadOnlySequence < byte > ( buffer , 0 , length ) ;
155
+ if ( length > 0 )
156
+ {
157
+ // Replace buffer if the message doesn't fit
158
+ if ( buffer . Length < length )
159
+ {
160
+ ArrayPool < byte > . Shared . Return ( buffer ) ;
161
+ buffer = ArrayPool < byte > . Shared . Rent ( length ) ;
162
+ }
163
+ await ReadMessageContentAsync ( responseStream , buffer , length , cancellationToken ) . ConfigureAwait ( false ) ;
164
+ }
165
+ var payload = new ReadOnlySequence < byte > ( buffer , 0 , length ) ;
166
+ call . DeserializationContext . SetPayload ( payload ) ;
167
+ message = deserializer ( call . DeserializationContext ) ;
147
168
}
169
+ call . DeserializationContext . SetPayload ( null ) ;
148
170
149
171
GrpcCallLog . DeserializingMessage ( call . Logger , length , typeof ( TResponse ) ) ;
150
172
151
- call . DeserializationContext . SetPayload ( payload ) ;
152
- var message = deserializer ( call . DeserializationContext ) ;
153
- call . DeserializationContext . SetPayload ( null ) ;
154
-
155
173
if ( singleMessage )
156
174
{
157
175
// Check that there is no additional content in the stream for a single message
@@ -251,24 +269,79 @@ private static async Task ReadMessageContentAsync(Stream responseStream, Memory<
251
269
}
252
270
}
253
271
254
- private static bool TryDecompressMessage ( ILogger logger , string compressionEncoding , Dictionary < string , ICompressionProvider > compressionProviders , byte [ ] messageData , int length , out ReadOnlySequence < byte > result )
272
+ private sealed class LimitedLengthStream ( Stream stream , int length ) : Stream
255
273
{
256
- if ( compressionProviders . TryGetValue ( compressionEncoding , out var compressionProvider ) )
274
+ private int _bytesRead ;
275
+ public override int Read ( byte [ ] buffer , int offset , int count )
257
276
{
258
- GrpcCallLog . DecompressingMessage ( logger , compressionProvider . EncodingName ) ;
277
+ var bytesToRead = Math . Min ( count , length - _bytesRead ) ;
278
+ if ( bytesToRead <= 0 )
279
+ {
280
+ return 0 ;
281
+ }
282
+ var bytesRead = stream . Read ( buffer , offset , bytesToRead ) ;
283
+ _bytesRead += bytesRead ;
284
+ return bytesRead ;
285
+ }
286
+ public override bool CanRead => true ;
287
+ public override bool CanSeek => false ;
288
+ public override bool CanWrite => false ;
289
+ public override long Length => length ;
290
+ public override long Position { get => _bytesRead ; set => throw new NotSupportedException ( ) ; }
291
+ public override void Flush ( ) => throw new NotSupportedException ( ) ;
292
+ public override long Seek ( long offset , SeekOrigin origin ) => throw new NotSupportedException ( ) ;
293
+ public override void SetLength ( long value ) => throw new NotSupportedException ( ) ;
294
+ public override void Write ( byte [ ] buffer , int offset , int count ) => throw new NotSupportedException ( ) ;
295
+ }
259
296
260
- var output = new MemoryStream ( ) ;
261
- using ( var compressionStream = compressionProvider . CreateDecompressionStream ( new MemoryStream ( messageData , 0 , length , writable : true , publiclyVisible : true ) ) )
297
+ private static async Task < int > ReadStreamToBuffers ( Stream stream , byte [ ] buffer , List < byte [ ] > moreBuffers , int moreLength , CancellationToken cancellationToken )
298
+ {
299
+ while ( true )
300
+ {
301
+ var offset = 0 ;
302
+ while ( offset < buffer . Length )
262
303
{
263
- compressionStream . CopyTo ( output ) ;
304
+ var read = await stream . ReadAsync ( buffer . AsMemory ( offset ) , cancellationToken ) . ConfigureAwait ( false ) ;
305
+ if ( read == 0 )
306
+ {
307
+ return offset ;
308
+ }
309
+ offset += read ;
264
310
}
311
+ moreBuffers . Add ( buffer = ArrayPool < byte > . Shared . Rent ( moreLength ) ) ;
312
+ }
313
+ }
265
314
266
- result = new ReadOnlySequence < byte > ( output . GetBuffer ( ) , 0 , ( int ) output . Length ) ;
267
- return true ;
315
+ private static ReadOnlySequence < byte > BuffersToReadOnlySequence ( byte [ ] buffer , List < byte [ ] > moreBuffers , int lastLength )
316
+ {
317
+ if ( moreBuffers . Count == 0 )
318
+ {
319
+ return new ReadOnlySequence < byte > ( buffer , 0 , lastLength ) ;
268
320
}
321
+ var runningIndex = buffer . Length ;
322
+ for ( var i = moreBuffers . Count - 2 ; i >= 0 ; i -- )
323
+ {
324
+ runningIndex += moreBuffers [ i ] . Length ;
325
+ }
326
+ var endSegment = new ReadOnlySequenceSegmentByte ( moreBuffers [ moreBuffers . Count - 1 ] . AsMemory ( 0 , lastLength ) , null , runningIndex ) ;
327
+ var startSegment = endSegment ;
328
+ for ( var i = moreBuffers . Count - 2 ; i >= 0 ; i -- )
329
+ {
330
+ var bytes = moreBuffers [ i ] ;
331
+ startSegment = new ReadOnlySequenceSegmentByte ( bytes , startSegment , runningIndex -= bytes . Length ) ;
332
+ }
333
+ startSegment = new ReadOnlySequenceSegmentByte ( buffer , startSegment , 0 ) ;
334
+ return new ReadOnlySequence < byte > ( startSegment , 0 , endSegment , lastLength ) ;
335
+ }
269
336
270
- result = default ;
271
- return false ;
337
+ private sealed class ReadOnlySequenceSegmentByte : ReadOnlySequenceSegment < byte >
338
+ {
339
+ public ReadOnlySequenceSegmentByte ( ReadOnlyMemory < byte > memory , ReadOnlySequenceSegmentByte ? next , int runningIndex )
340
+ {
341
+ Memory = memory ;
342
+ Next = next ;
343
+ RunningIndex = runningIndex ;
344
+ }
272
345
}
273
346
274
347
private static bool ReadCompressedFlag ( byte flag )
0 commit comments