Skip to content

Commit e1f03b3

Browse files
committed
client,server: configurable wire message size limits.
Implement configurable limits for the maximum accepted message size of the wire protocol. The default limit can be overridden using the WithClientWireMessageLimit() option for clients and using the WithServerWireMessageLimit() option for servers. Add exported constants for the minimum, maximum and default limits. Signed-off-by: Krisztian Litkey <[email protected]>
1 parent 525ddce commit e1f03b3

File tree

5 files changed

+132
-29
lines changed

5 files changed

+132
-29
lines changed

channel.go

Lines changed: 43 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -23,14 +23,13 @@ import (
2323
"io"
2424
"net"
2525
"sync"
26-
27-
"google.golang.org/grpc/codes"
28-
"google.golang.org/grpc/status"
2926
)
3027

3128
const (
32-
messageHeaderLength = 10
33-
messageLengthMax = 4 << 20
29+
messageHeaderLength = 10
30+
MinMessageLengthLimit = 4 << 10
31+
MaxMessageLengthLimit = 4 << 22
32+
DefaultMessageLengthLimit = 4 << 20
3433
)
3534

3635
type messageType uint8
@@ -96,18 +95,23 @@ func writeMessageHeader(w io.Writer, p []byte, mh messageHeader) error {
9695
var buffers sync.Pool
9796

9897
type channel struct {
99-
conn net.Conn
100-
bw *bufio.Writer
101-
br *bufio.Reader
102-
hrbuf [messageHeaderLength]byte // avoid alloc when reading header
103-
hwbuf [messageHeaderLength]byte
98+
conn net.Conn
99+
bw *bufio.Writer
100+
br *bufio.Reader
101+
hrbuf [messageHeaderLength]byte // avoid alloc when reading header
102+
hwbuf [messageHeaderLength]byte
103+
maxMsgLen int
104104
}
105105

106-
func newChannel(conn net.Conn) *channel {
106+
func newChannel(conn net.Conn, maxMsgLen int) *channel {
107+
if maxMsgLen == 0 {
108+
maxMsgLen = DefaultMessageLengthLimit
109+
}
107110
return &channel{
108-
conn: conn,
109-
bw: bufio.NewWriter(conn),
110-
br: bufio.NewReader(conn),
111+
conn: conn,
112+
bw: bufio.NewWriter(conn),
113+
br: bufio.NewReader(conn),
114+
maxMsgLen: maxMsgLen,
111115
}
112116
}
113117

@@ -123,12 +127,12 @@ func (ch *channel) recv() (messageHeader, []byte, error) {
123127
return messageHeader{}, nil, err
124128
}
125129

126-
if mh.Length > uint32(messageLengthMax) {
130+
if maxMsgLen := ch.maxMsgLimit(true); mh.Length > uint32(maxMsgLen) {
127131
if _, err := ch.br.Discard(int(mh.Length)); err != nil {
128132
return mh, nil, fmt.Errorf("failed to discard after receiving oversized message: %w", err)
129133
}
130134

131-
return mh, nil, status.Errorf(codes.ResourceExhausted, "message length %v exceed maximum message size of %v", mh.Length, messageLengthMax)
135+
return mh, nil, OversizedMessageError(int(mh.Length), maxMsgLen)
132136
}
133137

134138
var p []byte
@@ -143,8 +147,10 @@ func (ch *channel) recv() (messageHeader, []byte, error) {
143147
}
144148

145149
func (ch *channel) send(streamID uint32, t messageType, flags uint8, p []byte) error {
146-
if len(p) > messageLengthMax {
147-
return OversizedMessageError(len(p))
150+
if maxMsgLen := ch.maxMsgLimit(false); maxMsgLen != 0 {
151+
if len(p) > maxMsgLen {
152+
return OversizedMessageError(len(p), maxMsgLen)
153+
}
148154
}
149155

150156
if err := writeMessageHeader(ch.bw, ch.hwbuf[:], messageHeader{Length: uint32(len(p)), StreamID: streamID, Type: t, Flags: flags}); err != nil {
@@ -180,3 +186,22 @@ func (ch *channel) getmbuf(size int) []byte {
180186
func (ch *channel) putmbuf(p []byte) {
181187
buffers.Put(&p)
182188
}
189+
190+
func (ch *channel) maxMsgLimit(recv bool) int {
191+
if ch.maxMsgLen == 0 && recv {
192+
return DefaultMessageLengthLimit
193+
}
194+
return ch.maxMsgLen
195+
}
196+
197+
func clampWireMessageLimit(maxMsgLen int) int {
198+
switch {
199+
case maxMsgLen == 0:
200+
return 0
201+
case maxMsgLen < MinMessageLengthLimit:
202+
return MinMessageLengthLimit
203+
case maxMsgLen > MaxMessageLengthLimit:
204+
return MaxMessageLengthLimit
205+
}
206+
return maxMsgLen
207+
}

client.go

Lines changed: 14 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -35,9 +35,10 @@ import (
3535

3636
// Client for a ttrpc server
3737
type Client struct {
38-
codec codec
39-
conn net.Conn
40-
channel *channel
38+
codec codec
39+
conn net.Conn
40+
channel *channel
41+
maxMsgLen int
4142

4243
streamLock sync.RWMutex
4344
streams map[streamID]*stream
@@ -107,14 +108,20 @@ func chainUnaryInterceptors(interceptors []UnaryClientInterceptor, final Invoker
107108
}
108109
}
109110

111+
// WithClientWireMessageLimit sets the maximum allowed message length on the wire for the client.
112+
func WithClientWireMessageLimit(maxMsgLen int) ClientOpts {
113+
maxMsgLen = clampWireMessageLimit(maxMsgLen)
114+
return func(c *Client) {
115+
c.maxMsgLen = maxMsgLen
116+
}
117+
}
118+
110119
// NewClient creates a new ttrpc client using the given connection
111120
func NewClient(conn net.Conn, opts ...ClientOpts) *Client {
112121
ctx, cancel := context.WithCancel(context.Background())
113-
channel := newChannel(conn)
114122
c := &Client{
115123
codec: codec{},
116124
conn: conn,
117-
channel: channel,
118125
streams: make(map[streamID]*stream),
119126
nextStreamID: 1,
120127
closed: cancel,
@@ -127,6 +134,8 @@ func NewClient(conn net.Conn, opts ...ClientOpts) *Client {
127134
o(c)
128135
}
129136

137+
c.channel = newChannel(conn, c.maxMsgLen)
138+
130139
if c.interceptor == nil {
131140
c.interceptor = defaultClientInterceptor
132141
}

config.go

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@ import (
2424
type serverConfig struct {
2525
handshaker Handshaker
2626
interceptor UnaryServerInterceptor
27+
maxMsgLen int
2728
}
2829

2930
// ServerOpt for configuring a ttrpc server
@@ -84,3 +85,12 @@ func chainUnaryServerInterceptors(info *UnaryServerInfo, method Method, intercep
8485
chainUnaryServerInterceptors(info, method, interceptors[1:]))
8586
}
8687
}
88+
89+
// WithServerWireMessageLimit sets the maximum allowed message length on the wire for the server.
90+
func WithServerWireMessageLimit(maxMsgLen int) ServerOpt {
91+
maxMsgLen = clampWireMessageLimit(maxMsgLen)
92+
return func(c *serverConfig) error {
93+
c.maxMsgLen = maxMsgLen
94+
return nil
95+
}
96+
}

errors.go

Lines changed: 45 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@ package ttrpc
1818

1919
import (
2020
"errors"
21+
"fmt"
2122

2223
"google.golang.org/grpc/codes"
2324
"google.golang.org/grpc/status"
@@ -43,20 +44,59 @@ var (
4344
// length.
4445
type OversizedMessageErr struct {
4546
messageLength int
47+
maxLength int
4648
err error
4749
}
4850

51+
var (
52+
oversizedMsgFmt = "message length %d exceeds maximum message size of %d"
53+
oversizedMsgScanFmt = fmt.Sprintf("%v", status.New(codes.ResourceExhausted, oversizedMsgFmt))
54+
)
55+
4956
// OversizedMessageError returns an OversizedMessageErr error for the given message
5057
// length if it exceeds the allowed maximum. Otherwise a nil error is returned.
51-
func OversizedMessageError(messageLength int) error {
52-
if messageLength <= messageLengthMax {
58+
func OversizedMessageError(messageLength, maxLength int) error {
59+
if messageLength <= maxLength {
5360
return nil
5461
}
5562

5663
return &OversizedMessageErr{
5764
messageLength: messageLength,
58-
err: status.Errorf(codes.ResourceExhausted, "message length %v exceed maximum message size of %v", messageLength, messageLengthMax),
65+
maxLength: maxLength,
66+
err: OversizedMessageStatus(messageLength, maxLength).Err(),
67+
}
68+
}
69+
70+
// OversizedMessageStatus returns a Status for an oversized message error.
71+
func OversizedMessageStatus(messageLength, maxLength int) *status.Status {
72+
return status.Newf(codes.ResourceExhausted, oversizedMsgFmt, messageLength, maxLength)
73+
}
74+
75+
// OversizedMessageFromError reconstructs an OversizedMessageErr from a Status.
76+
func OversizedMessageFromError(err error) (*OversizedMessageErr, bool) {
77+
var (
78+
messageLength int
79+
maxLength int
80+
)
81+
82+
st, ok := status.FromError(err)
83+
if !ok || st.Code() != codes.ResourceExhausted {
84+
return nil, false
5985
}
86+
87+
// TODO(klihub): might be too ugly to recover an error this way... An
88+
// alternative would be to define our custom status detail proto type,
89+
// then use status.WithDetails() and status.Details().
90+
91+
n, _ := fmt.Sscanf(st.Message(), oversizedMsgScanFmt, &messageLength, &maxLength)
92+
if n != 2 {
93+
n, _ = fmt.Sscanf(st.Message(), oversizedMsgFmt, &messageLength, &maxLength)
94+
}
95+
if n != 2 {
96+
return nil, false
97+
}
98+
99+
return OversizedMessageError(messageLength, maxLength).(*OversizedMessageErr), true
60100
}
61101

62102
// Error returns the error message for the corresponding grpc Status for the error.
@@ -75,6 +115,6 @@ func (e *OversizedMessageErr) RejectedLength() int {
75115
}
76116

77117
// MaximumLength retrieves the maximum allowed message length that triggered the error.
78-
func (*OversizedMessageErr) MaximumLength() int {
79-
return messageLengthMax
118+
func (e *OversizedMessageErr) MaximumLength() int {
119+
return e.maxLength
80120
}

server.go

Lines changed: 20 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -339,7 +339,7 @@ func (c *serverConn) run(sctx context.Context) {
339339
)
340340

341341
var (
342-
ch = newChannel(c.conn)
342+
ch = newChannel(c.conn, c.server.config.maxMsgLen)
343343
ctx, cancel = context.WithCancel(sctx)
344344
state connState = connStateIdle
345345
responses = make(chan response)
@@ -373,6 +373,14 @@ func (c *serverConn) run(sctx context.Context) {
373373
}
374374
}
375375

376+
isResourceExhaustedError := func(err error) (*status.Status, bool) {
377+
st, ok := status.FromError(err)
378+
if !ok || st.Code() != codes.ResourceExhausted {
379+
return nil, false
380+
}
381+
return st, true
382+
}
383+
376384
go func(recvErr chan error) {
377385
defer close(recvErr)
378386
for {
@@ -525,6 +533,17 @@ func (c *serverConn) run(sctx context.Context) {
525533
}
526534

527535
if err := ch.send(response.id, messageTypeResponse, 0, p); err != nil {
536+
if st, ok := isResourceExhaustedError(err); ok {
537+
p, err = c.server.codec.Marshal(&Response{
538+
Status: st.Proto(),
539+
})
540+
if err != nil {
541+
log.G(ctx).WithError(err).Error("failed marshaling error response")
542+
return
543+
}
544+
ch.send(response.id, messageTypeResponse, 0, p)
545+
return
546+
}
528547
log.G(ctx).WithError(err).Error("failed sending message on channel")
529548
return
530549
}

0 commit comments

Comments
 (0)