Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
59 commits
Select commit Hold shift + click to select a range
9027754
Redo
Jun 30, 2025
bf18abf
Tests
Jun 30, 2025
a40af0c
Remove print
Jun 30, 2025
38f5be6
Simplify
Jun 30, 2025
4763170
Feedback
Jul 1, 2025
6a95dc5
Feedback
Jul 1, 2025
ee8fafd
Feedback
Jul 1, 2025
9af9940
Cleanup
Jul 1, 2025
1cd1311
Update context.go
smaye81 Jul 1, 2025
93cacf8
Feedback
Jul 1, 2025
adc81b1
Interceptors
Jul 1, 2025
8397865
Interceptor tests
Jul 1, 2025
bc250b0
Feedback
Jul 1, 2025
0a44db9
Feedback
Jul 1, 2025
94dbb48
Update header setting
Jul 1, 2025
57e8698
Fix responseWrapper docs
Jul 1, 2025
e422ba2
Fix again
Jul 1, 2025
3cdb5e1
Update tests
Jul 1, 2025
6a3ed80
Style
Jul 1, 2025
e14c0d7
Move func
Jul 2, 2025
a2e3f4e
Fix server stream tests
Jul 3, 2025
b83ca03
Rename context methods and always create a new call info when using e…
Jul 16, 2025
7d37a59
Interceptor tests
Jul 21, 2025
3cadcc1
Side quest tests
Jul 21, 2025
b2d9bce
Extensive testing for simple and generic APIs using callinfo
Jul 21, 2025
153acb7
Implement simple for client streaming on handler
Jul 17, 2025
021a0cb
Implement simple for client streaming on client
Jul 17, 2025
7995c00
Implement simple for bidi streaming on client
Jul 17, 2025
d12c6c9
Make client/bidi stream fallible for simple
Jul 17, 2025
e4026f2
Fix benchmark/example test
Jul 17, 2025
e34c8c8
Redo
Jun 30, 2025
194fb35
Tests
Jun 30, 2025
39380da
Remove print
Jun 30, 2025
36bc4bc
Simplify
Jun 30, 2025
c4b878f
Feedback
Jul 1, 2025
6c5253b
Feedback
Jul 1, 2025
3f509d8
Feedback
Jul 1, 2025
c78eb94
Cleanup
Jul 1, 2025
279b452
Update context.go
smaye81 Jul 1, 2025
7baf7be
Feedback
Jul 1, 2025
8d92bae
Interceptors
Jul 1, 2025
329c9e5
Interceptor tests
Jul 1, 2025
c674148
Feedback
Jul 1, 2025
3935be9
Feedback
Jul 1, 2025
a295d10
Update header setting
Jul 1, 2025
1d30ca6
Fix responseWrapper docs
Jul 1, 2025
a2e8814
Fix again
Jul 1, 2025
d8102c0
Update tests
Jul 1, 2025
8de8390
Style
Jul 1, 2025
32520e7
Move func
Jul 2, 2025
d5ccf16
Fix server stream tests
Jul 3, 2025
cd7dc92
Rename context methods and always create a new call info when using e…
Jul 16, 2025
45ee1ce
Interceptor tests
Jul 21, 2025
1011799
Side quest tests
Jul 21, 2025
9143ba0
Extensive testing for simple and generic APIs using callinfo
Jul 21, 2025
fa4e176
Feedback
Jul 23, 2025
8e8dc71
Merge branch 'simple' into simple
smaye81 Jul 23, 2025
0b48c6a
Merge pull request #1 from jchadwick-buf/simple
smaye81 Jul 23, 2025
f131a46
Add full host of tests for all RPC types and simple vs. generics API.
Jul 24, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 9 additions & 3 deletions bench_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -112,7 +112,10 @@ func BenchmarkConnect(b *testing.B) {
upTo = 1
expect = 1
)
stream := client.Sum(ctx)
stream, err := client.Sum(ctx)
if err != nil {
b.Error(err)
}
for number := int64(1); number <= upTo; number++ {
if err := stream.Send(&pingv1.SumRequest{Number: number}); err != nil {
b.Error(err)
Expand All @@ -121,7 +124,7 @@ func BenchmarkConnect(b *testing.B) {
response, err := stream.CloseAndReceive()
if err != nil {
b.Error(err)
} else if got := response.Msg.GetSum(); got != expect {
} else if got := response.GetSum(); got != expect {
b.Errorf("expected %d, got %d", expect, got)
}
}
Expand Down Expand Up @@ -159,7 +162,10 @@ func BenchmarkConnect(b *testing.B) {
const (
upTo = 1
)
stream := client.CumSum(ctx)
stream, err := client.CumSum(ctx)
if err != nil {
b.Error(err)
}
number := int64(1)
for ; number <= upTo; number++ {
if err := stream.Send(&pingv1.CumSumRequest{Number: number}); err != nil {
Expand Down
105 changes: 81 additions & 24 deletions client.go
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,10 @@ func NewClient[Req, Res any](httpClient HTTPClient, url string, options ...Clien
conn := client.protocolClient.NewConn(ctx, unarySpec, request.Header())
conn.onRequestSend(func(r *http.Request) {
request.setRequestMethod(r.Method)
callInfo, ok := clientCallInfoFromContext(ctx)
if ok {
callInfo.method = r.Method
}
})
// Send always returns an io.EOF unless the error is from the client-side.
// We want the user to continue to call Receive in those cases to get the
Expand All @@ -100,6 +104,7 @@ func NewClient[Req, Res any](httpClient HTTPClient, url string, options ...Clien
return response, conn.CloseResponse()
})
if interceptor := config.Interceptor; interceptor != nil {
// interceptor is the full chain of all interceptors provided
unaryFunc = interceptor.WrapUnary(unaryFunc)
}
client.callUnary = func(ctx context.Context, request *Request[Req]) (*Response[Res], error) {
Expand All @@ -109,6 +114,23 @@ func NewClient[Req, Res any](httpClient HTTPClient, url string, options ...Clien
request.spec = unarySpec
request.peer = client.protocolClient.Peer()
protocolClient.WriteRequestHeader(StreamTypeUnary, request.Header())

// Also set them in the context if there's a call info present
callInfo, callInfoOk := clientCallInfoFromContext(ctx)
if callInfoOk {
callInfo.peer = request.Peer()
callInfo.spec = request.Spec()
// A client could have set request headers in the call info OR the request wrapper
// So if a callInfo exists in context, merge any headers from there into the request wrapper
// so that all headers are sent in the request
mergeHeaders(request.Header(), callInfo.requestHeader)

// Copy the call info into a sentinel value. This is so we can compare
// the sentinel value against the call info in context. If they're different,
// we can stop the request. This protects against changing the context in interceptors.
ctx = context.WithValue(ctx, sentinelContextKey{}, callInfo)
}

response, err := unaryFunc(ctx, request)
if err != nil {
return nil, err
Expand All @@ -117,6 +139,12 @@ func NewClient[Req, Res any](httpClient HTTPClient, url string, options ...Clien
if !ok {
return nil, errorf(CodeInternal, "unexpected client response type %T", response)
}
if callInfoOk {
// Wrap the response and set it into the context callinfo
callInfo.responseSource = &responseWrapper[Res]{
response: typed,
}
}
return typed, nil
}
return client
Expand All @@ -130,19 +158,6 @@ func (c *Client[Req, Res]) CallUnary(ctx context.Context, request *Request[Req])
return c.callUnary(ctx, request)
}

// CallUnarySimple calls a request-response procedure using the function signature
// associated with the "simple" generation option.
//
// This option eliminates the [Request] and [Response] wrappers, and instead uses the
// context.Context to propagate information such as headers.
func (c *Client[Req, Res]) CallUnarySimple(ctx context.Context, requestMsg *Req) (*Res, error) {
response, err := c.CallUnary(ctx, requestFromContext(ctx, requestMsg))
if response != nil {
return response.Msg, err
}
return nil, err
}

// CallClientStream calls a client streaming procedure.
func (c *Client[Req, Res]) CallClientStream(ctx context.Context) *ClientStreamForClient[Req, Res] {
if c.err != nil {
Expand All @@ -154,6 +169,22 @@ func (c *Client[Req, Res]) CallClientStream(ctx context.Context) *ClientStreamFo
}
}

// CallClientStream calls a client streaming procedure in simple mode.
func (c *Client[Req, Res]) CallClientStreamSimple(ctx context.Context) (*ClientStreamForClientSimple[Req, Res], error) {
if c.err != nil {
return &ClientStreamForClientSimple[Req, Res]{err: c.err}, c.err
}

stream := &ClientStreamForClientSimple[Req, Res]{
conn: c.newConn(ctx, StreamTypeClient, nil),
initializer: c.config.Initializer,
}
if err := stream.Send(nil); err != nil {
return nil, err
}
return stream, nil
}

// CallServerStream calls a server streaming procedure.
func (c *Client[Req, Res]) CallServerStream(ctx context.Context, request *Request[Req]) (*ServerStreamForClient[Res], error) {
if c.err != nil {
Expand All @@ -162,9 +193,11 @@ func (c *Client[Req, Res]) CallServerStream(ctx context.Context, request *Reques
conn := c.newConn(ctx, StreamTypeServer, func(r *http.Request) {
request.method = r.Method
})
request.spec = conn.Spec()
request.peer = conn.Peer()
request.spec = conn.Spec()

mergeHeaders(conn.RequestHeader(), request.header)

// Send always returns an io.EOF unless the error is from the client-side.
// We want the user to continue to call Receive in those cases to get the
// full error from the server-side.
Expand All @@ -182,15 +215,6 @@ func (c *Client[Req, Res]) CallServerStream(ctx context.Context, request *Reques
}, nil
}

// CallServerStreamSimple calls a server streaming procedure using the function signature
// associated with the "simple" generation option.
//
// This option eliminates the [Request] wrapper, and instead uses the context.Context to
// propagate information such as headers.
func (c *Client[Req, Res]) CallServerStreamSimple(ctx context.Context, requestMsg *Req) (*ServerStreamForClient[Res], error) {
return c.CallServerStream(ctx, requestFromContext(ctx, requestMsg))
}

// CallBidiStream calls a bidirectional streaming procedure.
func (c *Client[Req, Res]) CallBidiStream(ctx context.Context) *BidiStreamForClient[Req, Res] {
if c.err != nil {
Expand All @@ -202,7 +226,27 @@ func (c *Client[Req, Res]) CallBidiStream(ctx context.Context) *BidiStreamForCli
}
}

// CallBidiStreamSimple calls a bidirectional streaming procedure in simple mode.
func (c *Client[Req, Res]) CallBidiStreamSimple(ctx context.Context) (*BidiStreamForClient[Req, Res], error) {
stream := c.CallBidiStream(ctx)
if stream.err != nil {
return nil, stream.err
}
if err := stream.Send(nil); err != nil {
return nil, err
}
return stream, nil
}

func (c *Client[Req, Res]) newConn(ctx context.Context, streamType StreamType, onRequestSend func(r *http.Request)) StreamingClientConn {
callInfo, callInfoOk := clientCallInfoFromContext(ctx)
// Set values in the context if there's a call info present
if callInfoOk {
// Copy the call info into a sentinel value. This is so we can compare
// the sentinel value against the call info in context. If they're different,
// we can stop the request. This protects against changing the context in interceptors.
ctx = context.WithValue(ctx, sentinelContextKey{}, callInfo)
}
newConn := func(ctx context.Context, spec Spec) StreamingClientConn {
header := make(http.Header, 8) // arbitrary power of two, prevent immediate resizing
c.protocolClient.WriteRequestHeader(streamType, header)
Expand All @@ -213,7 +257,20 @@ func (c *Client[Req, Res]) newConn(ctx context.Context, streamType StreamType, o
if interceptor := c.config.Interceptor; interceptor != nil {
newConn = interceptor.WrapStreamingClient(newConn)
}
return newConn(ctx, c.config.newSpec(streamType))
conn := newConn(ctx, c.config.newSpec(streamType))

// Set values in the context if there's a call info present
if callInfoOk {
callInfo.peer = conn.Peer()
callInfo.spec = conn.Spec()
callInfo.responseSource = conn

// Merge any callInfo request headers first, then do the request.
// so that context headers show first in the list of headers
mergeHeaders(conn.RequestHeader(), callInfo.RequestHeader())
}

return conn
}

type clientConfig struct {
Expand Down
87 changes: 87 additions & 0 deletions client_ext_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,10 +18,12 @@ import (
"bytes"
"context"
"crypto/rand"
"crypto/tls"
"errors"
"fmt"
"io"
"log"
"net"
"net/http"
"net/http/httptest"
"runtime"
Expand All @@ -36,6 +38,7 @@ import (
pingv1 "connectrpc.com/connect/internal/gen/connect/ping/v1"
"connectrpc.com/connect/internal/gen/generics/connect/ping/v1/pingv1connect"
"connectrpc.com/connect/internal/memhttp/memhttptest"
"golang.org/x/net/http2"
"google.golang.org/protobuf/encoding/protowire"
"google.golang.org/protobuf/proto"
"google.golang.org/protobuf/reflect/protoreflect"
Expand Down Expand Up @@ -389,6 +392,48 @@ func TestDynamicClient(t *testing.T) {
got := rsp.Msg.Get(methodDesc.Output().Fields().ByName("sum")).Int()
assert.Equal(t, got, 42*2)
})
t.Run("clientStreamSimple", func(t *testing.T) {
t.Parallel()
desc, err := protoregistry.GlobalFiles.FindDescriptorByName("connect.ping.v1.PingService.Sum")
assert.Nil(t, err)
methodDesc, ok := desc.(protoreflect.MethodDescriptor)
assert.True(t, ok)
connected := make(chan struct{})
transport := &http2.Transport{
DialTLSContext: func(ctx context.Context, network, addr string, cfg *tls.Config) (net.Conn, error) {
close(connected)
return server.Transport().DialTLSContext(ctx, network, addr, cfg)
},
AllowHTTP: true,
}
client := connect.NewClient[dynamicpb.Message, dynamicpb.Message](
&http.Client{Transport: transport},
server.URL()+"/connect.ping.v1.PingService/Sum",
connect.WithSchema(methodDesc),
connect.WithResponseInitializer(initializer),
)
stream, err := client.CallClientStreamSimple(ctx)
assert.Nil(t, err)
select {
case <-connected:
break
case <-time.After(time.Second):
t.Error("CallClientStreamSimple did not eagerly send headers")
}
msg := dynamicpb.NewMessage(methodDesc.Input())
msg.Set(
methodDesc.Input().Fields().ByName("number"),
protoreflect.ValueOfInt64(42),
)
assert.Nil(t, stream.Send(msg))
assert.Nil(t, stream.Send(msg))
rsp, err := stream.CloseAndReceive()
if !assert.Nil(t, err) {
return
}
got := rsp.Get(methodDesc.Output().Fields().ByName("sum")).Int()
assert.Equal(t, got, 42*2)
})
t.Run("serverStream", func(t *testing.T) {
t.Parallel()
desc, err := protoregistry.GlobalFiles.FindDescriptorByName("connect.ping.v1.PingService.CountUp")
Expand Down Expand Up @@ -445,6 +490,48 @@ func TestDynamicClient(t *testing.T) {
got := out.Get(methodDesc.Output().Fields().ByName("number")).Int()
assert.Equal(t, got, 42)
})
t.Run("bidiSimple", func(t *testing.T) {
t.Parallel()
desc, err := protoregistry.GlobalFiles.FindDescriptorByName("connect.ping.v1.PingService.CumSum")
assert.Nil(t, err)
methodDesc, ok := desc.(protoreflect.MethodDescriptor)
assert.True(t, ok)
connected := make(chan struct{})
transport := &http2.Transport{
DialTLSContext: func(ctx context.Context, network, addr string, cfg *tls.Config) (net.Conn, error) {
close(connected)
return server.Transport().DialTLSContext(ctx, network, addr, cfg)
},
AllowHTTP: true,
}
client := connect.NewClient[dynamicpb.Message, dynamicpb.Message](
&http.Client{Transport: transport},
server.URL()+"/connect.ping.v1.PingService/CumSum",
connect.WithSchema(methodDesc),
connect.WithResponseInitializer(initializer),
)
stream, err := client.CallBidiStreamSimple(ctx)
assert.Nil(t, err)
select {
case <-connected:
break
case <-time.After(time.Second):
t.Error("CallBidiStreamSimple did not eagerly send headers")
}
msg := dynamicpb.NewMessage(methodDesc.Input())
msg.Set(
methodDesc.Input().Fields().ByName("number"),
protoreflect.ValueOfInt64(42),
)
assert.Nil(t, stream.Send(msg))
assert.Nil(t, stream.CloseRequest())
out, err := stream.Receive()
if assert.Nil(t, err) {
return
}
got := out.Get(methodDesc.Output().Fields().ByName("number")).Int()
assert.Equal(t, got, 42)
})
t.Run("option", func(t *testing.T) {
t.Parallel()
desc, err := protoregistry.GlobalFiles.FindDescriptorByName("connect.ping.v1.PingService.Ping")
Expand Down
62 changes: 62 additions & 0 deletions client_stream.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,68 @@ import (
"net/http"
)

// ClientStreamForClientsimple is the client's view of a client streaming RPC.
// for the simple API.
//
// It's returned from [Client].CallClientStreamSimple, but doesn't currently have an
// exported constructor function.
type ClientStreamForClientSimple[Req, Res any] struct {
conn StreamingClientConn
initializer maybeInitializer
// Error from client construction. If non-nil, return for all calls.
err error
}

// Spec returns the specification for the RPC.
func (c *ClientStreamForClientSimple[_, _]) Spec() Spec {
return c.conn.Spec()
}

// Peer describes the server for the RPC.
func (c *ClientStreamForClientSimple[_, _]) Peer() Peer {
return c.conn.Peer()
}

// Send a message to the server. The first call to Send also sends the request
// headers.
//
// If the server returns an error, Send returns an error that wraps [io.EOF].
// Clients should check for case using the standard library's [errors.Is] and
// unmarshal the error using CloseAndReceive.
func (c *ClientStreamForClientSimple[Req, Res]) Send(request *Req) error {
if c.err != nil {
return c.err
}
if request == nil {
return c.conn.Send(nil)
}
return c.conn.Send(request)
}

// CloseAndReceive closes the send side of the stream and waits for the
// response.
func (c *ClientStreamForClientSimple[Req, Res]) CloseAndReceive() (*Res, error) {
if c.err != nil {
return nil, c.err
}
if err := c.conn.CloseRequest(); err != nil {
_ = c.conn.CloseResponse()
return nil, err
}
response, err := receiveUnaryResponse[Res](c.conn, c.initializer)
if err != nil {
_ = c.conn.CloseResponse()
return nil, err
}
return response.Msg, c.conn.CloseResponse()
}

// Conn exposes the underlying StreamingClientConn. This may be useful if
// you'd prefer to wrap the connection in a different high-level API.
func (c *ClientStreamForClientSimple[Req, Res]) Conn() (StreamingClientConn, error) {
return c.conn, c.err
}

// ClientStreamForClient is the client's view of a client streaming RPC.
//
// It's returned from [Client].CallClientStream, but doesn't currently have an
Expand Down
Loading