Skip to content

Commit c9cf4bc

Browse files
author
Steve Ayers
authored
Implement CallInfo usage in context and add integration tests (#856)
This adds the usage of CallInfo in context for issuing requests with the new simple API. This builds on top of the #851 which implements the simple flag for unary and server-streaming In addition, it adds integration tests for the simple and generics API. It also implements the simple approach for client and bidi streams. --------- Signed-off-by: Steve Ayers <[email protected]> Signed-off-by: Joshua Humphries <[email protected]> Signed-off-by: John Chadwick <[email protected]>
1 parent ccd966f commit c9cf4bc

File tree

19 files changed

+2386
-367
lines changed

19 files changed

+2386
-367
lines changed

bench_test.go

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -112,7 +112,10 @@ func BenchmarkConnect(b *testing.B) {
112112
upTo = 1
113113
expect = 1
114114
)
115-
stream := client.Sum(ctx)
115+
stream, err := client.Sum(ctx)
116+
if err != nil {
117+
b.Error(err)
118+
}
116119
for number := int64(1); number <= upTo; number++ {
117120
if err := stream.Send(&pingv1.SumRequest{Number: number}); err != nil {
118121
b.Error(err)
@@ -121,7 +124,7 @@ func BenchmarkConnect(b *testing.B) {
121124
response, err := stream.CloseAndReceive()
122125
if err != nil {
123126
b.Error(err)
124-
} else if got := response.Msg.GetSum(); got != expect {
127+
} else if got := response.GetSum(); got != expect {
125128
b.Errorf("expected %d, got %d", expect, got)
126129
}
127130
}
@@ -159,7 +162,10 @@ func BenchmarkConnect(b *testing.B) {
159162
const (
160163
upTo = 1
161164
)
162-
stream := client.CumSum(ctx)
165+
stream, err := client.CumSum(ctx)
166+
if err != nil {
167+
b.Error(err)
168+
}
163169
number := int64(1)
164170
for ; number <= upTo; number++ {
165171
if err := stream.Send(&pingv1.CumSumRequest{Number: number}); err != nil {

client.go

Lines changed: 81 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -79,6 +79,10 @@ func NewClient[Req, Res any](httpClient HTTPClient, url string, options ...Clien
7979
conn := client.protocolClient.NewConn(ctx, unarySpec, request.Header())
8080
conn.onRequestSend(func(r *http.Request) {
8181
request.setRequestMethod(r.Method)
82+
callInfo, ok := clientCallInfoFromContext(ctx)
83+
if ok {
84+
callInfo.method = r.Method
85+
}
8286
})
8387
// Send always returns an io.EOF unless the error is from the client-side.
8488
// We want the user to continue to call Receive in those cases to get the
@@ -100,6 +104,7 @@ func NewClient[Req, Res any](httpClient HTTPClient, url string, options ...Clien
100104
return response, conn.CloseResponse()
101105
})
102106
if interceptor := config.Interceptor; interceptor != nil {
107+
// interceptor is the full chain of all interceptors provided
103108
unaryFunc = interceptor.WrapUnary(unaryFunc)
104109
}
105110
client.callUnary = func(ctx context.Context, request *Request[Req]) (*Response[Res], error) {
@@ -109,6 +114,23 @@ func NewClient[Req, Res any](httpClient HTTPClient, url string, options ...Clien
109114
request.spec = unarySpec
110115
request.peer = client.protocolClient.Peer()
111116
protocolClient.WriteRequestHeader(StreamTypeUnary, request.Header())
117+
118+
// Also set them in the context if there's a call info present
119+
callInfo, callInfoOk := clientCallInfoFromContext(ctx)
120+
if callInfoOk {
121+
callInfo.peer = request.Peer()
122+
callInfo.spec = request.Spec()
123+
// A client could have set request headers in the call info OR the request wrapper
124+
// So if a callInfo exists in context, merge any headers from there into the request wrapper
125+
// so that all headers are sent in the request
126+
mergeHeaders(request.Header(), callInfo.requestHeader)
127+
128+
// Copy the call info into a sentinel value. This is so we can compare
129+
// the sentinel value against the call info in context. If they're different,
130+
// we can stop the request. This protects against changing the context in interceptors.
131+
ctx = context.WithValue(ctx, sentinelContextKey{}, callInfo)
132+
}
133+
112134
response, err := unaryFunc(ctx, request)
113135
if err != nil {
114136
return nil, err
@@ -117,6 +139,12 @@ func NewClient[Req, Res any](httpClient HTTPClient, url string, options ...Clien
117139
if !ok {
118140
return nil, errorf(CodeInternal, "unexpected client response type %T", response)
119141
}
142+
if callInfoOk {
143+
// Wrap the response and set it into the context callinfo
144+
callInfo.responseSource = &responseWrapper[Res]{
145+
response: typed,
146+
}
147+
}
120148
return typed, nil
121149
}
122150
return client
@@ -130,19 +158,6 @@ func (c *Client[Req, Res]) CallUnary(ctx context.Context, request *Request[Req])
130158
return c.callUnary(ctx, request)
131159
}
132160

133-
// CallUnarySimple calls a request-response procedure using the function signature
134-
// associated with the "simple" generation option.
135-
//
136-
// This option eliminates the [Request] and [Response] wrappers, and instead uses the
137-
// context.Context to propagate information such as headers.
138-
func (c *Client[Req, Res]) CallUnarySimple(ctx context.Context, requestMsg *Req) (*Res, error) {
139-
response, err := c.CallUnary(ctx, requestFromContext(ctx, requestMsg))
140-
if response != nil {
141-
return response.Msg, err
142-
}
143-
return nil, err
144-
}
145-
146161
// CallClientStream calls a client streaming procedure.
147162
func (c *Client[Req, Res]) CallClientStream(ctx context.Context) *ClientStreamForClient[Req, Res] {
148163
if c.err != nil {
@@ -154,6 +169,22 @@ func (c *Client[Req, Res]) CallClientStream(ctx context.Context) *ClientStreamFo
154169
}
155170
}
156171

172+
// CallClientStream calls a client streaming procedure in simple mode.
173+
func (c *Client[Req, Res]) CallClientStreamSimple(ctx context.Context) (*ClientStreamForClientSimple[Req, Res], error) {
174+
if c.err != nil {
175+
return &ClientStreamForClientSimple[Req, Res]{err: c.err}, c.err
176+
}
177+
178+
stream := &ClientStreamForClientSimple[Req, Res]{
179+
conn: c.newConn(ctx, StreamTypeClient, nil),
180+
initializer: c.config.Initializer,
181+
}
182+
if err := stream.Send(nil); err != nil {
183+
return nil, err
184+
}
185+
return stream, nil
186+
}
187+
157188
// CallServerStream calls a server streaming procedure.
158189
func (c *Client[Req, Res]) CallServerStream(ctx context.Context, request *Request[Req]) (*ServerStreamForClient[Res], error) {
159190
if c.err != nil {
@@ -162,9 +193,11 @@ func (c *Client[Req, Res]) CallServerStream(ctx context.Context, request *Reques
162193
conn := c.newConn(ctx, StreamTypeServer, func(r *http.Request) {
163194
request.method = r.Method
164195
})
165-
request.spec = conn.Spec()
166196
request.peer = conn.Peer()
197+
request.spec = conn.Spec()
198+
167199
mergeHeaders(conn.RequestHeader(), request.header)
200+
168201
// Send always returns an io.EOF unless the error is from the client-side.
169202
// We want the user to continue to call Receive in those cases to get the
170203
// full error from the server-side.
@@ -182,15 +215,6 @@ func (c *Client[Req, Res]) CallServerStream(ctx context.Context, request *Reques
182215
}, nil
183216
}
184217

185-
// CallServerStreamSimple calls a server streaming procedure using the function signature
186-
// associated with the "simple" generation option.
187-
//
188-
// This option eliminates the [Request] wrapper, and instead uses the context.Context to
189-
// propagate information such as headers.
190-
func (c *Client[Req, Res]) CallServerStreamSimple(ctx context.Context, requestMsg *Req) (*ServerStreamForClient[Res], error) {
191-
return c.CallServerStream(ctx, requestFromContext(ctx, requestMsg))
192-
}
193-
194218
// CallBidiStream calls a bidirectional streaming procedure.
195219
func (c *Client[Req, Res]) CallBidiStream(ctx context.Context) *BidiStreamForClient[Req, Res] {
196220
if c.err != nil {
@@ -202,7 +226,27 @@ func (c *Client[Req, Res]) CallBidiStream(ctx context.Context) *BidiStreamForCli
202226
}
203227
}
204228

229+
// CallBidiStreamSimple calls a bidirectional streaming procedure in simple mode.
230+
func (c *Client[Req, Res]) CallBidiStreamSimple(ctx context.Context) (*BidiStreamForClient[Req, Res], error) {
231+
stream := c.CallBidiStream(ctx)
232+
if stream.err != nil {
233+
return nil, stream.err
234+
}
235+
if err := stream.Send(nil); err != nil {
236+
return nil, err
237+
}
238+
return stream, nil
239+
}
240+
205241
func (c *Client[Req, Res]) newConn(ctx context.Context, streamType StreamType, onRequestSend func(r *http.Request)) StreamingClientConn {
242+
callInfo, callInfoOk := clientCallInfoFromContext(ctx)
243+
// Set values in the context if there's a call info present
244+
if callInfoOk {
245+
// Copy the call info into a sentinel value. This is so we can compare
246+
// the sentinel value against the call info in context. If they're different,
247+
// we can stop the request. This protects against changing the context in interceptors.
248+
ctx = context.WithValue(ctx, sentinelContextKey{}, callInfo)
249+
}
206250
newConn := func(ctx context.Context, spec Spec) StreamingClientConn {
207251
header := make(http.Header, 8) // arbitrary power of two, prevent immediate resizing
208252
c.protocolClient.WriteRequestHeader(streamType, header)
@@ -213,7 +257,20 @@ func (c *Client[Req, Res]) newConn(ctx context.Context, streamType StreamType, o
213257
if interceptor := c.config.Interceptor; interceptor != nil {
214258
newConn = interceptor.WrapStreamingClient(newConn)
215259
}
216-
return newConn(ctx, c.config.newSpec(streamType))
260+
conn := newConn(ctx, c.config.newSpec(streamType))
261+
262+
// Set values in the context if there's a call info present
263+
if callInfoOk {
264+
callInfo.peer = conn.Peer()
265+
callInfo.spec = conn.Spec()
266+
callInfo.responseSource = conn
267+
268+
// Merge any callInfo request headers first, then do the request.
269+
// so that context headers show first in the list of headers
270+
mergeHeaders(conn.RequestHeader(), callInfo.RequestHeader())
271+
}
272+
273+
return conn
217274
}
218275

219276
type clientConfig struct {

client_ext_test.go

Lines changed: 87 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,10 +18,12 @@ import (
1818
"bytes"
1919
"context"
2020
"crypto/rand"
21+
"crypto/tls"
2122
"errors"
2223
"fmt"
2324
"io"
2425
"log"
26+
"net"
2527
"net/http"
2628
"net/http/httptest"
2729
"runtime"
@@ -36,6 +38,7 @@ import (
3638
pingv1 "connectrpc.com/connect/internal/gen/connect/ping/v1"
3739
"connectrpc.com/connect/internal/gen/generics/connect/ping/v1/pingv1connect"
3840
"connectrpc.com/connect/internal/memhttp/memhttptest"
41+
"golang.org/x/net/http2"
3942
"google.golang.org/protobuf/encoding/protowire"
4043
"google.golang.org/protobuf/proto"
4144
"google.golang.org/protobuf/reflect/protoreflect"
@@ -389,6 +392,48 @@ func TestDynamicClient(t *testing.T) {
389392
got := rsp.Msg.Get(methodDesc.Output().Fields().ByName("sum")).Int()
390393
assert.Equal(t, got, 42*2)
391394
})
395+
t.Run("clientStreamSimple", func(t *testing.T) {
396+
t.Parallel()
397+
desc, err := protoregistry.GlobalFiles.FindDescriptorByName("connect.ping.v1.PingService.Sum")
398+
assert.Nil(t, err)
399+
methodDesc, ok := desc.(protoreflect.MethodDescriptor)
400+
assert.True(t, ok)
401+
connected := make(chan struct{})
402+
transport := &http2.Transport{
403+
DialTLSContext: func(ctx context.Context, network, addr string, cfg *tls.Config) (net.Conn, error) {
404+
close(connected)
405+
return server.Transport().DialTLSContext(ctx, network, addr, cfg)
406+
},
407+
AllowHTTP: true,
408+
}
409+
client := connect.NewClient[dynamicpb.Message, dynamicpb.Message](
410+
&http.Client{Transport: transport},
411+
server.URL()+"/connect.ping.v1.PingService/Sum",
412+
connect.WithSchema(methodDesc),
413+
connect.WithResponseInitializer(initializer),
414+
)
415+
stream, err := client.CallClientStreamSimple(ctx)
416+
assert.Nil(t, err)
417+
select {
418+
case <-connected:
419+
break
420+
case <-time.After(time.Second):
421+
t.Error("CallClientStreamSimple did not eagerly send headers")
422+
}
423+
msg := dynamicpb.NewMessage(methodDesc.Input())
424+
msg.Set(
425+
methodDesc.Input().Fields().ByName("number"),
426+
protoreflect.ValueOfInt64(42),
427+
)
428+
assert.Nil(t, stream.Send(msg))
429+
assert.Nil(t, stream.Send(msg))
430+
rsp, err := stream.CloseAndReceive()
431+
if !assert.Nil(t, err) {
432+
return
433+
}
434+
got := rsp.Get(methodDesc.Output().Fields().ByName("sum")).Int()
435+
assert.Equal(t, got, 42*2)
436+
})
392437
t.Run("serverStream", func(t *testing.T) {
393438
t.Parallel()
394439
desc, err := protoregistry.GlobalFiles.FindDescriptorByName("connect.ping.v1.PingService.CountUp")
@@ -445,6 +490,48 @@ func TestDynamicClient(t *testing.T) {
445490
got := out.Get(methodDesc.Output().Fields().ByName("number")).Int()
446491
assert.Equal(t, got, 42)
447492
})
493+
t.Run("bidiSimple", func(t *testing.T) {
494+
t.Parallel()
495+
desc, err := protoregistry.GlobalFiles.FindDescriptorByName("connect.ping.v1.PingService.CumSum")
496+
assert.Nil(t, err)
497+
methodDesc, ok := desc.(protoreflect.MethodDescriptor)
498+
assert.True(t, ok)
499+
connected := make(chan struct{})
500+
transport := &http2.Transport{
501+
DialTLSContext: func(ctx context.Context, network, addr string, cfg *tls.Config) (net.Conn, error) {
502+
close(connected)
503+
return server.Transport().DialTLSContext(ctx, network, addr, cfg)
504+
},
505+
AllowHTTP: true,
506+
}
507+
client := connect.NewClient[dynamicpb.Message, dynamicpb.Message](
508+
&http.Client{Transport: transport},
509+
server.URL()+"/connect.ping.v1.PingService/CumSum",
510+
connect.WithSchema(methodDesc),
511+
connect.WithResponseInitializer(initializer),
512+
)
513+
stream, err := client.CallBidiStreamSimple(ctx)
514+
assert.Nil(t, err)
515+
select {
516+
case <-connected:
517+
break
518+
case <-time.After(time.Second):
519+
t.Error("CallBidiStreamSimple did not eagerly send headers")
520+
}
521+
msg := dynamicpb.NewMessage(methodDesc.Input())
522+
msg.Set(
523+
methodDesc.Input().Fields().ByName("number"),
524+
protoreflect.ValueOfInt64(42),
525+
)
526+
assert.Nil(t, stream.Send(msg))
527+
assert.Nil(t, stream.CloseRequest())
528+
out, err := stream.Receive()
529+
if assert.Nil(t, err) {
530+
return
531+
}
532+
got := out.Get(methodDesc.Output().Fields().ByName("number")).Int()
533+
assert.Equal(t, got, 42)
534+
})
448535
t.Run("option", func(t *testing.T) {
449536
t.Parallel()
450537
desc, err := protoregistry.GlobalFiles.FindDescriptorByName("connect.ping.v1.PingService.Ping")

client_stream.go

Lines changed: 62 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,68 @@ import (
2020
"net/http"
2121
)
2222

23+
// ClientStreamForClientsimple is the client's view of a client streaming RPC.
24+
// for the simple API.
25+
//
26+
// It's returned from [Client].CallClientStreamSimple, but doesn't currently have an
27+
// exported constructor function.
28+
type ClientStreamForClientSimple[Req, Res any] struct {
29+
conn StreamingClientConn
30+
initializer maybeInitializer
31+
// Error from client construction. If non-nil, return for all calls.
32+
err error
33+
}
34+
35+
// Spec returns the specification for the RPC.
36+
func (c *ClientStreamForClientSimple[_, _]) Spec() Spec {
37+
return c.conn.Spec()
38+
}
39+
40+
// Peer describes the server for the RPC.
41+
func (c *ClientStreamForClientSimple[_, _]) Peer() Peer {
42+
return c.conn.Peer()
43+
}
44+
45+
// Send a message to the server. The first call to Send also sends the request
46+
// headers.
47+
//
48+
// If the server returns an error, Send returns an error that wraps [io.EOF].
49+
// Clients should check for case using the standard library's [errors.Is] and
50+
// unmarshal the error using CloseAndReceive.
51+
func (c *ClientStreamForClientSimple[Req, Res]) Send(request *Req) error {
52+
if c.err != nil {
53+
return c.err
54+
}
55+
if request == nil {
56+
return c.conn.Send(nil)
57+
}
58+
return c.conn.Send(request)
59+
}
60+
61+
// CloseAndReceive closes the send side of the stream and waits for the
62+
// response.
63+
func (c *ClientStreamForClientSimple[Req, Res]) CloseAndReceive() (*Res, error) {
64+
if c.err != nil {
65+
return nil, c.err
66+
}
67+
if err := c.conn.CloseRequest(); err != nil {
68+
_ = c.conn.CloseResponse()
69+
return nil, err
70+
}
71+
response, err := receiveUnaryResponse[Res](c.conn, c.initializer)
72+
if err != nil {
73+
_ = c.conn.CloseResponse()
74+
return nil, err
75+
}
76+
return response.Msg, c.conn.CloseResponse()
77+
}
78+
79+
// Conn exposes the underlying StreamingClientConn. This may be useful if
80+
// you'd prefer to wrap the connection in a different high-level API.
81+
func (c *ClientStreamForClientSimple[Req, Res]) Conn() (StreamingClientConn, error) {
82+
return c.conn, c.err
83+
}
84+
2385
// ClientStreamForClient is the client's view of a client streaming RPC.
2486
//
2587
// It's returned from [Client].CallClientStream, but doesn't currently have an

0 commit comments

Comments
 (0)