Skip to content

Commit 0753662

Browse files
committed
client,server: configurable wire message size limits.
Implement configurable limits for the maximum accepted message size of the wire protocol. The default limit can be overridden using the WithClientWireMessageLimit() option for clients and using the WithServerWireMessageLimit() option for servers. Signed-off-by: Krisztian Litkey <[email protected]>
1 parent 3f02183 commit 0753662

File tree

4 files changed

+53
-17
lines changed

4 files changed

+53
-17
lines changed

channel.go

Lines changed: 18 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@ import (
3131
const (
3232
messageHeaderLength = 10
3333
messageLengthMax = 4 << 20
34+
messageLengthMin = 4 << 10
3435
)
3536

3637
type messageType uint8
@@ -96,18 +97,23 @@ func writeMessageHeader(w io.Writer, p []byte, mh messageHeader) error {
9697
var buffers sync.Pool
9798

9899
type channel struct {
99-
conn net.Conn
100-
bw *bufio.Writer
101-
br *bufio.Reader
102-
hrbuf [messageHeaderLength]byte // avoid alloc when reading header
103-
hwbuf [messageHeaderLength]byte
100+
conn net.Conn
101+
bw *bufio.Writer
102+
br *bufio.Reader
103+
hrbuf [messageHeaderLength]byte // avoid alloc when reading header
104+
hwbuf [messageHeaderLength]byte
105+
maxMsgLen int
104106
}
105107

106-
func newChannel(conn net.Conn) *channel {
108+
func newChannel(conn net.Conn, maxMsgLen int) *channel {
109+
if maxMsgLen == 0 {
110+
maxMsgLen = messageLengthMax
111+
}
107112
return &channel{
108-
conn: conn,
109-
bw: bufio.NewWriter(conn),
110-
br: bufio.NewReader(conn),
113+
conn: conn,
114+
bw: bufio.NewWriter(conn),
115+
br: bufio.NewReader(conn),
116+
maxMsgLen: maxMsgLen,
111117
}
112118
}
113119

@@ -123,12 +129,12 @@ func (ch *channel) recv() (messageHeader, []byte, error) {
123129
return messageHeader{}, nil, err
124130
}
125131

126-
if mh.Length > uint32(messageLengthMax) {
132+
if mh.Length > uint32(ch.maxMsgLen) {
127133
if _, err := ch.br.Discard(int(mh.Length)); err != nil {
128134
return mh, nil, fmt.Errorf("failed to discard after receiving oversized message: %w", err)
129135
}
130136

131-
return mh, nil, status.Errorf(codes.ResourceExhausted, "message length %v exceed maximum message size of %v", mh.Length, messageLengthMax)
137+
return mh, nil, status.Errorf(codes.ResourceExhausted, "message length %v exceed maximum message size of %v", mh.Length, ch.maxMsgLen)
132138
}
133139

134140
var p []byte
@@ -147,6 +153,7 @@ func (ch *channel) send(streamID uint32, t messageType, flags uint8, p []byte) e
147153
//if len(p) > messageLengthMax {
148154
// return status.Errorf(codes.InvalidArgument, "refusing to send, message length %v exceed maximum message size of %v", len(p), messageLengthMax)
149155
//}
156+
150157
if err := writeMessageHeader(ch.bw, ch.hwbuf[:], messageHeader{Length: uint32(len(p)), StreamID: streamID, Type: t, Flags: flags}); err != nil {
151158
return err
152159
}

client.go

Lines changed: 19 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -35,9 +35,10 @@ import (
3535

3636
// Client for a ttrpc server
3737
type Client struct {
38-
codec codec
39-
conn net.Conn
40-
channel *channel
38+
codec codec
39+
conn net.Conn
40+
channel *channel
41+
maxMsgLen int
4142

4243
streamLock sync.RWMutex
4344
streams map[streamID]*stream
@@ -107,14 +108,25 @@ func chainUnaryInterceptors(interceptors []UnaryClientInterceptor, final Invoker
107108
}
108109
}
109110

111+
// WithClientWireMessageLimit sets the maximum allowed message length on the wire for the client.
112+
func WithClientWireMessageLimit(maxMsgLen int) ClientOpts {
113+
switch {
114+
case maxMsgLen == 0:
115+
maxMsgLen = messageLengthMax
116+
case maxMsgLen < messageLengthMin:
117+
maxMsgLen = messageLengthMin
118+
}
119+
return func(c *Client) {
120+
c.maxMsgLen = maxMsgLen
121+
}
122+
}
123+
110124
// NewClient creates a new ttrpc client using the given connection
111125
func NewClient(conn net.Conn, opts ...ClientOpts) *Client {
112126
ctx, cancel := context.WithCancel(context.Background())
113-
channel := newChannel(conn)
114127
c := &Client{
115128
codec: codec{},
116129
conn: conn,
117-
channel: channel,
118130
streams: make(map[streamID]*stream),
119131
nextStreamID: 1,
120132
closed: cancel,
@@ -127,6 +139,8 @@ func NewClient(conn net.Conn, opts ...ClientOpts) *Client {
127139
o(c)
128140
}
129141

142+
c.channel = newChannel(conn, c.maxMsgLen)
143+
130144
if c.interceptor == nil {
131145
c.interceptor = defaultClientInterceptor
132146
}

config.go

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@ import (
2424
type serverConfig struct {
2525
handshaker Handshaker
2626
interceptor UnaryServerInterceptor
27+
maxMsgLen int
2728
}
2829

2930
// ServerOpt for configuring a ttrpc server
@@ -84,3 +85,17 @@ func chainUnaryServerInterceptors(info *UnaryServerInfo, method Method, intercep
8485
chainUnaryServerInterceptors(info, method, interceptors[1:]))
8586
}
8687
}
88+
89+
// WithServerWireMessageLimit sets the maximum allowed message length on the wire for the server.
90+
func WithServerWireMessageLimit(maxMsgLen int) ServerOpt {
91+
switch {
92+
case maxMsgLen == 0:
93+
maxMsgLen = messageLengthMax
94+
case maxMsgLen < messageLengthMin:
95+
maxMsgLen = messageLengthMin
96+
}
97+
return func(c *serverConfig) error {
98+
c.maxMsgLen = maxMsgLen
99+
return nil
100+
}
101+
}

server.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -332,7 +332,7 @@ func (c *serverConn) run(sctx context.Context) {
332332
)
333333

334334
var (
335-
ch = newChannel(c.conn)
335+
ch = newChannel(c.conn, c.server.config.maxMsgLen)
336336
ctx, cancel = context.WithCancel(sctx)
337337
state connState = connStateIdle
338338
responses = make(chan response)

0 commit comments

Comments
 (0)