Skip to content

Commit 56e7391

Browse files
committed
HTTP client factory for per-request clients
1 parent 493aa4c commit 56e7391

File tree

1 file changed

+66
-21
lines changed

1 file changed

+66
-21
lines changed

client.go

Lines changed: 66 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -368,11 +368,24 @@ type Backoff func(min, max time.Duration, attemptNum int, resp *http.Response) t
368368
// attempted. If overriding this, be sure to close the body if needed.
369369
type ErrorHandler func(resp *http.Response, err error, numTries int) (*http.Response, error)
370370

371+
type HTTPClient interface {
372+
// Do performs an HTTP request and returns an HTTP response.
373+
Do(*http.Request) (*http.Response, error)
374+
// Done is called when the client is no longer needed.
375+
Done()
376+
}
377+
378+
type HTTPClientFactory interface {
379+
// New returns an HTTP client to use for a request, including retries.
380+
New() HTTPClient
381+
}
382+
371383
// Client is used to make HTTP requests. It adds additional functionality
372384
// like automatic retries to tolerate minor outages.
373385
type Client struct {
374-
HTTPClient *http.Client // Internal HTTP client.
375-
Logger interface{} // Customer logger instance. Can be either Logger or LeveledLogger
386+
HTTPClient *http.Client // Internal HTTP client. This field is used if set, otherwise HTTPClientFactory is used.
387+
HTTPClientFactory HTTPClientFactory
388+
Logger interface{} // Customer logger instance. Can be either Logger or LeveledLogger
376389

377390
RetryWaitMin time.Duration // Minimum time to wait
378391
RetryWaitMax time.Duration // Maximum time to wait
@@ -397,19 +410,18 @@ type Client struct {
397410
ErrorHandler ErrorHandler
398411

399412
loggerInit sync.Once
400-
clientInit sync.Once
401413
}
402414

403415
// NewClient creates a new Client with default settings.
404416
func NewClient() *Client {
405417
return &Client{
406-
HTTPClient: cleanhttp.DefaultPooledClient(),
407-
Logger: defaultLogger,
408-
RetryWaitMin: defaultRetryWaitMin,
409-
RetryWaitMax: defaultRetryWaitMax,
410-
RetryMax: defaultRetryMax,
411-
CheckRetry: DefaultRetryPolicy,
412-
Backoff: DefaultBackoff,
418+
HTTPClientFactory: &CleanPooledClientFactory{},
419+
Logger: defaultLogger,
420+
RetryWaitMin: defaultRetryWaitMin,
421+
RetryWaitMax: defaultRetryWaitMax,
422+
RetryMax: defaultRetryMax,
423+
CheckRetry: DefaultRetryPolicy,
424+
Backoff: DefaultBackoff,
413425
}
414426
}
415427

@@ -573,12 +585,6 @@ func PassthroughErrorHandler(resp *http.Response, err error, _ int) (*http.Respo
573585

574586
// Do wraps calling an HTTP method with retries.
575587
func (c *Client) Do(req *Request) (*http.Response, error) {
576-
c.clientInit.Do(func() {
577-
if c.HTTPClient == nil {
578-
c.HTTPClient = cleanhttp.DefaultPooledClient()
579-
}
580-
})
581-
582588
logger := c.logger()
583589

584590
if logger != nil {
@@ -590,6 +596,9 @@ func (c *Client) Do(req *Request) (*http.Response, error) {
590596
}
591597
}
592598

599+
httpClient := c.getHTTPClient()
600+
defer httpClient.Done()
601+
593602
var resp *http.Response
594603
var attempt int
595604
var shouldRetry bool
@@ -603,7 +612,6 @@ func (c *Client) Do(req *Request) (*http.Response, error) {
603612
if req.body != nil {
604613
body, err := req.body()
605614
if err != nil {
606-
c.HTTPClient.CloseIdleConnections()
607615
return resp, err
608616
}
609617
if c, ok := body.(io.ReadCloser); ok {
@@ -625,7 +633,8 @@ func (c *Client) Do(req *Request) (*http.Response, error) {
625633
}
626634

627635
// Attempt the request
628-
resp, doErr = c.HTTPClient.Do(req.Request)
636+
637+
resp, doErr = httpClient.Do(req.Request)
629638

630639
// Check if we should continue with retries.
631640
shouldRetry, checkErr = c.CheckRetry(req.Context(), resp, doErr)
@@ -694,7 +703,6 @@ func (c *Client) Do(req *Request) (*http.Response, error) {
694703
select {
695704
case <-req.Context().Done():
696705
timer.Stop()
697-
c.HTTPClient.CloseIdleConnections()
698706
return nil, req.Context().Err()
699707
case <-timer.C:
700708
}
@@ -710,8 +718,6 @@ func (c *Client) Do(req *Request) (*http.Response, error) {
710718
return resp, nil
711719
}
712720

713-
defer c.HTTPClient.CloseIdleConnections()
714-
715721
var err error
716722
if checkErr != nil {
717723
err = checkErr
@@ -758,6 +764,19 @@ func (c *Client) drainBody(body io.ReadCloser) {
758764
}
759765
}
760766

767+
func (c *Client) getHTTPClient() HTTPClient {
768+
if c.HTTPClient != nil {
769+
return &idleConnectionsClosingClient{
770+
httpClient: c.HTTPClient,
771+
}
772+
}
773+
clientFactory := c.HTTPClientFactory
774+
if clientFactory == nil {
775+
clientFactory = &CleanPooledClientFactory{}
776+
}
777+
return clientFactory.New()
778+
}
779+
761780
// Get is a shortcut for doing a GET request without making a new client.
762781
func Get(url string) (*http.Response, error) {
763782
return defaultClient.Get(url)
@@ -820,3 +839,29 @@ func (c *Client) StandardClient() *http.Client {
820839
Transport: &RoundTripper{Client: c},
821840
}
822841
}
842+
843+
var (
844+
_ HTTPClientFactory = &CleanPooledClientFactory{}
845+
_ HTTPClient = &idleConnectionsClosingClient{}
846+
)
847+
848+
type CleanPooledClientFactory struct {
849+
}
850+
851+
func (f *CleanPooledClientFactory) New() HTTPClient {
852+
return &idleConnectionsClosingClient{
853+
httpClient: cleanhttp.DefaultPooledClient(),
854+
}
855+
}
856+
857+
type idleConnectionsClosingClient struct {
858+
httpClient *http.Client
859+
}
860+
861+
func (c *idleConnectionsClosingClient) Do(req *http.Request) (*http.Response, error) {
862+
return c.httpClient.Do(req)
863+
}
864+
865+
func (c *idleConnectionsClosingClient) Done() {
866+
c.httpClient.CloseIdleConnections()
867+
}

0 commit comments

Comments
 (0)