18
18
19
19
using System . Buffers ;
20
20
using System . Buffers . Binary ;
21
+ using System . Collections ;
21
22
using System . Diagnostics ;
22
23
using System . Diagnostics . CodeAnalysis ;
23
24
using System . Runtime . InteropServices ;
@@ -99,27 +100,15 @@ private static Status CreateUnknownMessageEncodingMessageStatus(string unsupport
99
100
var compressed = ReadCompressedFlag ( buffer [ 0 ] ) ;
100
101
var length = ReadMessageLength ( buffer . AsSpan ( 1 , 4 ) ) ;
101
102
102
- if ( length > 0 )
103
+ if ( length > call . Channel . ReceiveMaxMessageSize )
103
104
{
104
- if ( length > call . Channel . ReceiveMaxMessageSize )
105
- {
106
- throw call . CreateRpcException ( ReceivedMessageExceedsLimitStatus ) ;
107
- }
108
-
109
- // Replace buffer if the message doesn't fit
110
- if ( buffer . Length < length )
111
- {
112
- ArrayPool < byte > . Shared . Return ( buffer ) ;
113
- buffer = ArrayPool < byte > . Shared . Rent ( length ) ;
114
- }
115
-
116
- await ReadMessageContentAsync ( responseStream , buffer , length , cancellationToken ) . ConfigureAwait ( false ) ;
105
+ throw call . CreateRpcException ( ReceivedMessageExceedsLimitStatus ) ;
117
106
}
118
107
119
108
cancellationToken . ThrowIfCancellationRequested ( ) ;
120
109
121
- ReadOnlySequence < byte > payload ;
122
- if ( compressed )
110
+ TResponse message ;
111
+ if ( compressed && length > 0 )
123
112
{
124
113
if ( grpcEncoding == null )
125
114
{
@@ -130,28 +119,55 @@ private static Status CreateUnknownMessageEncodingMessageStatus(string unsupport
130
119
throw call . CreateRpcException ( IdentityMessageEncodingMessageStatus ) ;
131
120
}
132
121
133
- // Performance improvement would be to decompress without converting to an intermediary byte array
134
- if ( ! TryDecompressMessage ( call . Logger , grpcEncoding , call . Channel . CompressionProviders , buffer , length , out var decompressedMessage ) )
122
+ if ( call . Channel . CompressionProviders . TryGetValue ( grpcEncoding , out var compressionProvider ) )
123
+ {
124
+ GrpcCallLog . DecompressingMessage ( call . Logger , compressionProvider . EncodingName ) ;
125
+ var moreBuffers = new List < byte [ ] > ( ) ;
126
+ try
127
+ {
128
+ int lastLength ;
129
+ using ( var compressionStream = compressionProvider . CreateDecompressionStream ( new FixedLengthStream ( responseStream , length ) ) )
130
+ {
131
+ var underLohLength = Math . Min ( Math . Max ( 4096 , length ) , 65536 ) ;
132
+ lastLength = await ReadStreamToBuffers ( compressionStream , buffer , moreBuffers , underLohLength , cancellationToken ) . ConfigureAwait ( false ) ;
133
+ }
134
+ call . DeserializationContext . SetPayload ( BuffersToReadOnlySequence ( buffer , moreBuffers , lastLength ) ) ;
135
+ message = deserializer ( call . DeserializationContext ) ;
136
+ }
137
+ finally
138
+ {
139
+ foreach ( var byteArray in moreBuffers )
140
+ {
141
+ ArrayPool < byte > . Shared . Return ( byteArray ) ;
142
+ }
143
+ }
144
+ }
145
+ else
135
146
{
136
- var supportedEncodings = new List < string > ( ) ;
137
- supportedEncodings . Add ( GrpcProtocolConstants . IdentityGrpcEncoding ) ;
147
+ var supportedEncodings = new List < string > ( call . Channel . CompressionProviders . Count + 1 ) { GrpcProtocolConstants . IdentityGrpcEncoding } ;
138
148
supportedEncodings . AddRange ( call . Channel . CompressionProviders . Select ( c => c . Key ) ) ;
139
149
throw call . CreateRpcException ( CreateUnknownMessageEncodingMessageStatus ( grpcEncoding , supportedEncodings ) ) ;
140
150
}
141
-
142
- payload = decompressedMessage ;
143
151
}
144
152
else
145
153
{
146
- payload = new ReadOnlySequence < byte > ( buffer , 0 , length ) ;
154
+ if ( length > 0 )
155
+ {
156
+ // Replace buffer if the message doesn't fit
157
+ if ( buffer . Length < length )
158
+ {
159
+ ArrayPool < byte > . Shared . Return ( buffer ) ;
160
+ buffer = ArrayPool < byte > . Shared . Rent ( length ) ;
161
+ }
162
+ await ReadMessageContentAsync ( responseStream , buffer , length , cancellationToken ) . ConfigureAwait ( false ) ;
163
+ }
164
+ call . DeserializationContext . SetPayload ( new ReadOnlySequence < byte > ( buffer , 0 , length ) ) ;
165
+ message = deserializer ( call . DeserializationContext ) ;
147
166
}
167
+ call . DeserializationContext . SetPayload ( null ) ;
148
168
149
169
GrpcCallLog . DeserializingMessage ( call . Logger , length , typeof ( TResponse ) ) ;
150
170
151
- call . DeserializationContext . SetPayload ( payload ) ;
152
- var message = deserializer ( call . DeserializationContext ) ;
153
- call . DeserializationContext . SetPayload ( null ) ;
154
-
155
171
if ( singleMessage )
156
172
{
157
173
// Check that there is no additional content in the stream for a single message
@@ -251,24 +267,83 @@ private static async Task ReadMessageContentAsync(Stream responseStream, Memory<
251
267
}
252
268
}
253
269
254
- private static bool TryDecompressMessage ( ILogger logger , string compressionEncoding , Dictionary < string , ICompressionProvider > compressionProviders , byte [ ] messageData , int length , out ReadOnlySequence < byte > result )
270
+ private sealed class FixedLengthStream ( Stream stream , int length ) : Stream
255
271
{
256
- if ( compressionProviders . TryGetValue ( compressionEncoding , out var compressionProvider ) )
272
+ private int _bytesRead ;
273
+ public override int Read ( byte [ ] buffer , int offset , int count )
257
274
{
258
- GrpcCallLog . DecompressingMessage ( logger , compressionProvider . EncodingName ) ;
275
+ var bytesToRead = Math . Min ( count , length - _bytesRead ) ;
276
+ if ( bytesToRead <= 0 )
277
+ {
278
+ return 0 ;
279
+ }
280
+ var bytesRead = stream . Read ( buffer , offset , bytesToRead ) ;
281
+ if ( bytesRead == 0 )
282
+ {
283
+ throw new InvalidDataException ( "Unexpected end of content while reading the message content." ) ;
284
+ }
285
+ _bytesRead += bytesRead ;
286
+ return bytesRead ;
287
+ }
288
+ public override bool CanRead => true ;
289
+ public override bool CanSeek => false ;
290
+ public override bool CanWrite => false ;
291
+ public override long Length => length ;
292
+ public override long Position { get => _bytesRead ; set => throw new NotSupportedException ( ) ; }
293
+ public override void Flush ( ) => throw new NotSupportedException ( ) ;
294
+ public override long Seek ( long offset , SeekOrigin origin ) => throw new NotSupportedException ( ) ;
295
+ public override void SetLength ( long value ) => throw new NotSupportedException ( ) ;
296
+ public override void Write ( byte [ ] buffer , int offset , int count ) => throw new NotSupportedException ( ) ;
297
+ }
259
298
260
- var output = new MemoryStream ( ) ;
261
- using ( var compressionStream = compressionProvider . CreateDecompressionStream ( new MemoryStream ( messageData , 0 , length , writable : true , publiclyVisible : true ) ) )
299
+ private static async Task < int > ReadStreamToBuffers ( Stream stream , byte [ ] buffer , List < byte [ ] > moreBuffers , int moreLength , CancellationToken cancellationToken )
300
+ {
301
+ while ( true )
302
+ {
303
+ var offset = 0 ;
304
+ while ( offset < buffer . Length )
262
305
{
263
- compressionStream . CopyTo ( output ) ;
306
+ var read = await stream . ReadAsync ( buffer . AsMemory ( offset ) , cancellationToken ) . ConfigureAwait ( false ) ;
307
+ if ( read == 0 )
308
+ {
309
+ return offset ;
310
+ }
311
+ offset += read ;
264
312
}
313
+ moreBuffers . Add ( buffer = ArrayPool < byte > . Shared . Rent ( moreLength ) ) ;
314
+ }
315
+ }
265
316
266
- result = new ReadOnlySequence < byte > ( output . GetBuffer ( ) , 0 , ( int ) output . Length ) ;
267
- return true ;
317
+ private static ReadOnlySequence < byte > BuffersToReadOnlySequence ( byte [ ] buffer , List < byte [ ] > moreBuffers , int lastLength )
318
+ {
319
+ if ( moreBuffers . Count == 0 )
320
+ {
321
+ return new ReadOnlySequence < byte > ( buffer , 0 , lastLength ) ;
268
322
}
323
+ var runningIndex = buffer . Length ;
324
+ for ( var i = moreBuffers . Count - 2 ; i >= 0 ; i -- )
325
+ {
326
+ runningIndex += moreBuffers [ i ] . Length ;
327
+ }
328
+ var endSegment = new ReadOnlySequenceSegmentByte ( moreBuffers [ moreBuffers . Count - 1 ] . AsMemory ( 0 , lastLength ) , null , runningIndex ) ;
329
+ var startSegment = endSegment ;
330
+ for ( var i = moreBuffers . Count - 2 ; i >= 0 ; i -- )
331
+ {
332
+ var bytes = moreBuffers [ i ] ;
333
+ startSegment = new ReadOnlySequenceSegmentByte ( bytes , startSegment , runningIndex -= bytes . Length ) ;
334
+ }
335
+ startSegment = new ReadOnlySequenceSegmentByte ( buffer , startSegment , 0 ) ;
336
+ return new ReadOnlySequence < byte > ( startSegment , 0 , endSegment , lastLength ) ;
337
+ }
269
338
270
- result = default ;
271
- return false ;
339
+ private sealed class ReadOnlySequenceSegmentByte : ReadOnlySequenceSegment < byte >
340
+ {
341
+ public ReadOnlySequenceSegmentByte ( ReadOnlyMemory < byte > memory , ReadOnlySequenceSegmentByte ? next , int runningIndex )
342
+ {
343
+ Memory = memory ;
344
+ Next = next ;
345
+ RunningIndex = runningIndex ;
346
+ }
272
347
}
273
348
274
349
private static bool ReadCompressedFlag ( byte flag )
0 commit comments