Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
47 changes: 36 additions & 11 deletions ca/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -153,8 +153,13 @@ type RetryFunc func(code int) bool
// ClientOption is the type of options passed to the Client constructor.
type ClientOption func(o *clientOptions) error

// TransportDecorator is the type used to support customization of the HTTP
// transport.
type TransportDecorator func(http.RoundTripper) http.RoundTripper

type clientOptions struct {
transport http.RoundTripper
transportDecorator TransportDecorator
timeout time.Duration
rootSHA256 string
rootFilename string
Expand Down Expand Up @@ -272,7 +277,8 @@ func (o *clientOptions) getTransport(endpoint string) (tr http.RoundTripper, err
}
}

return tr, nil
// Wrap the transport using the decorator function if necessary
return decorateRoundTripper(tr, o.transportDecorator), nil
}

// WithTransport adds a custom transport to the Client. It will fail if a
Expand All @@ -287,6 +293,16 @@ func WithTransport(tr http.RoundTripper) ClientOption {
}
}

// WithTransportDecorator allows customization of the HTTP transport used by the
// client. The provided function receives the configured [http.RoundTripper] and
// can wrap it with additional functionality.
func WithTransportDecorator(fn TransportDecorator) ClientOption {
return func(o *clientOptions) error {
o.transportDecorator = fn
return nil
}
}

// WithInsecure adds a insecure transport that bypasses TLS verification.
func WithInsecure() ClientOption {
return func(o *clientOptions) error {
Expand Down Expand Up @@ -562,11 +578,12 @@ func WithProvisionerName(name string) ProvisionerOption {

// Client implements an HTTP client for the CA server.
type Client struct {
client *uaClient
endpoint *url.URL
retryFunc RetryFunc
timeout time.Duration
opts []ClientOption
client *uaClient
endpoint *url.URL
retryFunc RetryFunc
timeout time.Duration
opts []ClientOption
transportDecorator TransportDecorator
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Perhaps the client would benefit from having a reference to *clientOptions and not []ClientOption. This way you wouldn't need to copy transportDecorator around. Not sure how big of a change this is.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We use those options later, perhaps we could refactor that part, but there might be some unexpected implications.

}

// NewClient creates a new Client with the given endpoint and options.
Expand All @@ -587,11 +604,12 @@ func NewClient(endpoint string, opts ...ClientOption) (*Client, error) {
}

return &Client{
client: newClient(tr, o.timeout),
endpoint: u,
retryFunc: o.retryFunc,
timeout: o.timeout,
opts: opts,
client: newClient(tr, o.timeout),
endpoint: u,
retryFunc: o.retryFunc,
timeout: o.timeout,
opts: opts,
transportDecorator: o.transportDecorator,
}, nil
}

Expand Down Expand Up @@ -1583,3 +1601,10 @@ func clientError(err error) error {
}
return fmt.Errorf("client request failed: %w", err)
}

func decorateRoundTripper(tr http.RoundTripper, td TransportDecorator) http.RoundTripper {
if td != nil {
return td(tr)
}
return tr
}
37 changes: 37 additions & 0 deletions ca/client_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1045,6 +1045,43 @@ func TestClient_WithTimeout(t *testing.T) {
}
}

type decoratedRoundTripper func(*http.Request) (*http.Response, error)

func (rt decoratedRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) {
return rt(req)
}

func TestClient_WithTransportDecorator(t *testing.T) {
var srv *httptest.Server
srv = httptest.NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if strings.HasPrefix(r.RequestURI, "/root") {
render.JSONStatus(w, r, api.RootResponse{
RootPEM: api.NewCertificate(srv.Certificate()),
}, 200)
return
}

if s := r.Header.Get("X-Test-Header"); s != "" {
render.JSONStatus(w, r, api.HealthResponse{Status: s}, 200)
} else {
render.JSONStatus(w, r, api.HealthResponse{Status: "ok"}, 200)
}
}))
defer srv.Close()

fp := x509util.Fingerprint(srv.Certificate())
c, err := NewClient(srv.URL, WithRootSHA256(fp), WithTransportDecorator(func(tr http.RoundTripper) http.RoundTripper {
return decoratedRoundTripper(func(r *http.Request) (*http.Response, error) {
r.Header.Add("X-Test-Header", "some-data")
return tr.RoundTrip(r)
})
}))
require.NoError(t, err)
resp, err := c.Health()
require.NoError(t, err)
assert.Equal(t, "some-data", resp.Status)
}

func Test_enforceRequestID(t *testing.T) {
set := httptest.NewRequest(http.MethodGet, "https://example.com", http.NoBody)
set.Header.Set("X-Request-Id", "already-set")
Expand Down
27 changes: 18 additions & 9 deletions ca/tls.go
Original file line number Diff line number Diff line change
Expand Up @@ -110,7 +110,7 @@ func (c *Client) GetClientTLSConfig(ctx context.Context, sign *api.SignResponse,
return tlsConfig, nil
}

func (c *Client) getClientTLSConfig(ctx context.Context, sign *api.SignResponse, pk crypto.PrivateKey, options []TLSOption) (*tls.Config, *http.Transport, error) {
func (c *Client) getClientTLSConfig(ctx context.Context, sign *api.SignResponse, pk crypto.PrivateKey, options []TLSOption) (*tls.Config, http.RoundTripper, error) {
cert, err := TLSCertificate(sign, pk)
if err != nil {
return nil, nil, err
Expand All @@ -133,14 +133,18 @@ func (c *Client) getClientTLSConfig(ctx context.Context, sign *api.SignResponse,

tr := getDefaultTransport(tlsConfig)
tr.DialTLSContext = c.buildDialTLSContext(tlsCtx)
renewer.RenewCertificate = getRenewFunc(tlsCtx, c, tr, pk) //nolint:contextcheck // deeply nested context

// Add decorator if available, and use the resulting [http.RoundTripper]
// going forward
rt := decorateRoundTripper(tr, c.transportDecorator)
renewer.RenewCertificate = getRenewFunc(tlsCtx, c, rt, pk) //nolint:contextcheck // deeply nested context

// Update client transport
c.SetTransport(tr)
c.SetTransport(rt)

// Start renewer
renewer.RunContext(ctx)
return tlsConfig, tr, nil
return tlsConfig, rt, nil
}

// GetServerTLSConfig returns a tls.Config for server use configured with the
Expand Down Expand Up @@ -179,18 +183,23 @@ func (c *Client) GetServerTLSConfig(ctx context.Context, sign *api.SignResponse,
// Update renew function with transport
tr := getDefaultTransport(tlsConfig)
tr.DialTLSContext = c.buildDialTLSContext(tlsCtx)
renewer.RenewCertificate = getRenewFunc(tlsCtx, c, tr, pk) //nolint:contextcheck // deeply nested context

// Add decorator if available, and use the resulting [http.RoundTripper]
// going forward
rt := decorateRoundTripper(tr, c.transportDecorator)
renewer.RenewCertificate = getRenewFunc(tlsCtx, c, rt, pk) //nolint:contextcheck // deeply nested context

// Update client transport
c.SetTransport(tr)
c.SetTransport(rt)

// Start renewer
renewer.RunContext(ctx)
return tlsConfig, nil
}

// Transport returns an http.Transport configured to use the client certificate from the sign response.
func (c *Client) Transport(ctx context.Context, sign *api.SignResponse, pk crypto.PrivateKey, options ...TLSOption) (*http.Transport, error) {
// Transport returns an [http.RoundTripper] configured to use the client
// certificate from the sign response.
func (c *Client) Transport(ctx context.Context, sign *api.SignResponse, pk crypto.PrivateKey, options ...TLSOption) (http.RoundTripper, error) {
_, tr, err := c.getClientTLSConfig(ctx, sign, pk, options)
if err != nil {
return nil, err
Expand Down Expand Up @@ -365,7 +374,7 @@ func getPEM(i interface{}) ([]byte, error) {
return pem.EncodeToMemory(block), nil
}

func getRenewFunc(ctx *TLSOptionCtx, client *Client, tr *http.Transport, pk crypto.PrivateKey) RenewFunc {
func getRenewFunc(ctx *TLSOptionCtx, client *Client, tr http.RoundTripper, pk crypto.PrivateKey) RenewFunc {
return func() (*tls.Certificate, error) {
// Close connections in keep-alive state
defer client.CloseIdleConnections()
Expand Down
4 changes: 3 additions & 1 deletion ca/tls_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -266,8 +266,10 @@ func TestClient_GetServerTLSConfig_renew(t *testing.T) {

// Transport
client, sr, pk = signDuration(t, ca, clientDomain, 5*time.Second)
tr1, err := client.Transport(context.Background(), sr, pk)
tr, err := client.Transport(context.Background(), sr, pk)
require.NoError(t, err)
tr1, ok := tr.(*http.Transport)
require.True(t, ok)

// Transport with tlsConfig
client, sr, pk = signDuration(t, ca, clientDomain, 5*time.Second)
Expand Down