@@ -31,6 +31,7 @@ import (
31
31
const (
32
32
messageHeaderLength = 10
33
33
messageLengthMax = 4 << 20
34
+ messageLengthMin = 4 << 10
34
35
)
35
36
36
37
type messageType uint8
@@ -96,18 +97,23 @@ func writeMessageHeader(w io.Writer, p []byte, mh messageHeader) error {
96
97
var buffers sync.Pool
97
98
98
99
type channel struct {
99
- conn net.Conn
100
- bw * bufio.Writer
101
- br * bufio.Reader
102
- hrbuf [messageHeaderLength ]byte // avoid alloc when reading header
103
- hwbuf [messageHeaderLength ]byte
100
+ conn net.Conn
101
+ bw * bufio.Writer
102
+ br * bufio.Reader
103
+ hrbuf [messageHeaderLength ]byte // avoid alloc when reading header
104
+ hwbuf [messageHeaderLength ]byte
105
+ maxMsgLen int
104
106
}
105
107
106
- func newChannel (conn net.Conn ) * channel {
108
+ func newChannel (conn net.Conn , maxMsgLen int ) * channel {
109
+ if maxMsgLen == 0 {
110
+ maxMsgLen = messageLengthMax
111
+ }
107
112
return & channel {
108
- conn : conn ,
109
- bw : bufio .NewWriter (conn ),
110
- br : bufio .NewReader (conn ),
113
+ conn : conn ,
114
+ bw : bufio .NewWriter (conn ),
115
+ br : bufio .NewReader (conn ),
116
+ maxMsgLen : maxMsgLen ,
111
117
}
112
118
}
113
119
@@ -123,12 +129,12 @@ func (ch *channel) recv() (messageHeader, []byte, error) {
123
129
return messageHeader {}, nil , err
124
130
}
125
131
126
- if mh .Length > uint32 (messageLengthMax ) {
132
+ if mh .Length > uint32 (ch . maxMsgLen ) {
127
133
if _ , err := ch .br .Discard (int (mh .Length )); err != nil {
128
134
return mh , nil , fmt .Errorf ("failed to discard after receiving oversized message: %w" , err )
129
135
}
130
136
131
- return mh , nil , status .Errorf (codes .ResourceExhausted , "message length %v exceed maximum message size of %v" , mh .Length , messageLengthMax )
137
+ return mh , nil , status .Errorf (codes .ResourceExhausted , "message length %v exceed maximum message size of %v" , mh .Length , ch . maxMsgLen )
132
138
}
133
139
134
140
var p []byte
@@ -147,6 +153,7 @@ func (ch *channel) send(streamID uint32, t messageType, flags uint8, p []byte) e
147
153
//if len(p) > messageLengthMax {
148
154
// return status.Errorf(codes.InvalidArgument, "refusing to send, message length %v exceed maximum message size of %v", len(p), messageLengthMax)
149
155
//}
156
+
150
157
if err := writeMessageHeader (ch .bw , ch .hwbuf [:], messageHeader {Length : uint32 (len (p )), StreamID : streamID , Type : t , Flags : flags }); err != nil {
151
158
return err
152
159
}
0 commit comments