@@ -23,14 +23,13 @@ import (
23
23
"io"
24
24
"net"
25
25
"sync"
26
-
27
- "google.golang.org/grpc/codes"
28
- "google.golang.org/grpc/status"
29
26
)
30
27
31
28
const (
32
- messageHeaderLength = 10
33
- messageLengthMax = 4 << 20
29
+ messageHeaderLength = 10
30
+ MinMessageLengthLimit = 4 << 10
31
+ MaxMessageLengthLimit = 4 << 22
32
+ DefaultMessageLengthLimit = 4 << 20
34
33
)
35
34
36
35
type messageType uint8
@@ -96,18 +95,23 @@ func writeMessageHeader(w io.Writer, p []byte, mh messageHeader) error {
96
95
var buffers sync.Pool
97
96
98
97
type channel struct {
99
- conn net.Conn
100
- bw * bufio.Writer
101
- br * bufio.Reader
102
- hrbuf [messageHeaderLength ]byte // avoid alloc when reading header
103
- hwbuf [messageHeaderLength ]byte
98
+ conn net.Conn
99
+ bw * bufio.Writer
100
+ br * bufio.Reader
101
+ hrbuf [messageHeaderLength ]byte // avoid alloc when reading header
102
+ hwbuf [messageHeaderLength ]byte
103
+ maxMsgLen int
104
104
}
105
105
106
- func newChannel (conn net.Conn ) * channel {
106
+ func newChannel (conn net.Conn , maxMsgLen int ) * channel {
107
+ if maxMsgLen == 0 {
108
+ maxMsgLen = DefaultMessageLengthLimit
109
+ }
107
110
return & channel {
108
- conn : conn ,
109
- bw : bufio .NewWriter (conn ),
110
- br : bufio .NewReader (conn ),
111
+ conn : conn ,
112
+ bw : bufio .NewWriter (conn ),
113
+ br : bufio .NewReader (conn ),
114
+ maxMsgLen : maxMsgLen ,
111
115
}
112
116
}
113
117
@@ -123,12 +127,12 @@ func (ch *channel) recv() (messageHeader, []byte, error) {
123
127
return messageHeader {}, nil , err
124
128
}
125
129
126
- if mh .Length > uint32 (messageLengthMax ) {
130
+ if maxMsgLen := ch . maxMsgLimit ( true ); mh .Length > uint32 (maxMsgLen ) {
127
131
if _ , err := ch .br .Discard (int (mh .Length )); err != nil {
128
132
return mh , nil , fmt .Errorf ("failed to discard after receiving oversized message: %w" , err )
129
133
}
130
134
131
- return mh , nil , status . Errorf ( codes . ResourceExhausted , "message length %v exceed maximum message size of %v" , mh .Length , messageLengthMax )
135
+ return mh , nil , OversizedMessageError ( int ( mh .Length ), maxMsgLen )
132
136
}
133
137
134
138
var p []byte
@@ -143,8 +147,10 @@ func (ch *channel) recv() (messageHeader, []byte, error) {
143
147
}
144
148
145
149
func (ch * channel ) send (streamID uint32 , t messageType , flags uint8 , p []byte ) error {
146
- if len (p ) > messageLengthMax {
147
- return OversizedMessageError (len (p ))
150
+ if maxMsgLen := ch .maxMsgLimit (false ); maxMsgLen != 0 {
151
+ if len (p ) > maxMsgLen {
152
+ return OversizedMessageError (len (p ), maxMsgLen )
153
+ }
148
154
}
149
155
150
156
if err := writeMessageHeader (ch .bw , ch .hwbuf [:], messageHeader {Length : uint32 (len (p )), StreamID : streamID , Type : t , Flags : flags }); err != nil {
@@ -180,3 +186,22 @@ func (ch *channel) getmbuf(size int) []byte {
180
186
func (ch * channel ) putmbuf (p []byte ) {
181
187
buffers .Put (& p )
182
188
}
189
+
190
+ func (ch * channel ) maxMsgLimit (recv bool ) int {
191
+ if ch .maxMsgLen == 0 && recv {
192
+ return DefaultMessageLengthLimit
193
+ }
194
+ return ch .maxMsgLen
195
+ }
196
+
197
+ func clampWireMessageLimit (maxMsgLen int ) int {
198
+ switch {
199
+ case maxMsgLen == 0 :
200
+ return 0
201
+ case maxMsgLen < MinMessageLengthLimit :
202
+ return MinMessageLengthLimit
203
+ case maxMsgLen > MaxMessageLengthLimit :
204
+ return MaxMessageLengthLimit
205
+ }
206
+ return maxMsgLen
207
+ }
0 commit comments