@@ -343,11 +343,23 @@ type Backoff func(min, max time.Duration, attemptNum int, resp *http.Response) t
343343// attempted. If overriding this, be sure to close the body if needed.
344344type ErrorHandler func (resp * http.Response , err error , numTries int ) (* http.Response , error )
345345
346+ type HTTPClient interface {
347+ // Do performs an HTTP request and returns an HTTP response.
348+ Do (* http.Request ) (* http.Response , error )
349+ // Done is called when the client is no longer needed.
350+ Done ()
351+ }
352+
353+ type HTTPClientFactory interface {
354+ // New returns an HTTP client to use for a request, including retries.
355+ New () HTTPClient
356+ }
357+
346358// Client is used to make HTTP requests. It adds additional functionality
347359// like automatic retries to tolerate minor outages.
348360type Client struct {
349- HTTPClient * http. Client // Internal HTTP client.
350- Logger interface {} // Customer logger instance. Can be either Logger or LeveledLogger
361+ HTTPClientFactory HTTPClientFactory
362+ Logger interface {} // Customer logger instance. Can be either Logger or LeveledLogger
351363
352364 RetryWaitMin time.Duration // Minimum time to wait
353365 RetryWaitMax time.Duration // Maximum time to wait
@@ -372,19 +384,18 @@ type Client struct {
372384 ErrorHandler ErrorHandler
373385
374386 loggerInit sync.Once
375- clientInit sync.Once
376387}
377388
378389// NewClient creates a new Client with default settings.
379390func NewClient () * Client {
380391 return & Client {
381- HTTPClient : cleanhttp . DefaultPooledClient () ,
382- Logger : defaultLogger ,
383- RetryWaitMin : defaultRetryWaitMin ,
384- RetryWaitMax : defaultRetryWaitMax ,
385- RetryMax : defaultRetryMax ,
386- CheckRetry : DefaultRetryPolicy ,
387- Backoff : DefaultBackoff ,
392+ HTTPClientFactory : & CleanPooledClientFactory {} ,
393+ Logger : defaultLogger ,
394+ RetryWaitMin : defaultRetryWaitMin ,
395+ RetryWaitMax : defaultRetryWaitMax ,
396+ RetryMax : defaultRetryMax ,
397+ CheckRetry : DefaultRetryPolicy ,
398+ Backoff : DefaultBackoff ,
388399 }
389400}
390401
@@ -545,12 +556,6 @@ func PassthroughErrorHandler(resp *http.Response, err error, _ int) (*http.Respo
545556
546557// Do wraps calling an HTTP method with retries.
547558func (c * Client ) Do (req * Request ) (* http.Response , error ) {
548- c .clientInit .Do (func () {
549- if c .HTTPClient == nil {
550- c .HTTPClient = cleanhttp .DefaultPooledClient ()
551- }
552- })
553-
554559 logger := c .logger ()
555560
556561 if logger != nil {
@@ -562,6 +567,13 @@ func (c *Client) Do(req *Request) (*http.Response, error) {
562567 }
563568 }
564569
570+ clientFactory := c .HTTPClientFactory
571+ if clientFactory == nil {
572+ clientFactory = & CleanPooledClientFactory {}
573+ }
574+ httpClient := clientFactory .New ()
575+ defer httpClient .Done ()
576+
565577 var resp * http.Response
566578 var attempt int
567579 var shouldRetry bool
@@ -574,7 +586,6 @@ func (c *Client) Do(req *Request) (*http.Response, error) {
574586 if req .body != nil {
575587 body , err := req .body ()
576588 if err != nil {
577- c .HTTPClient .CloseIdleConnections ()
578589 return resp , err
579590 }
580591 if c , ok := body .(io.ReadCloser ); ok {
@@ -596,7 +607,8 @@ func (c *Client) Do(req *Request) (*http.Response, error) {
596607 }
597608
598609 // Attempt the request
599- resp , doErr = c .HTTPClient .Do (req .Request )
610+
611+ resp , doErr = httpClient .Do (req .Request )
600612
601613 // Check if we should continue with retries.
602614 shouldRetry , checkErr = c .CheckRetry (req .Context (), resp , doErr )
@@ -657,7 +669,6 @@ func (c *Client) Do(req *Request) (*http.Response, error) {
657669 select {
658670 case <- req .Context ().Done ():
659671 timer .Stop ()
660- c .HTTPClient .CloseIdleConnections ()
661672 return nil , req .Context ().Err ()
662673 case <- timer .C :
663674 }
@@ -673,8 +684,6 @@ func (c *Client) Do(req *Request) (*http.Response, error) {
673684 return resp , nil
674685 }
675686
676- defer c .HTTPClient .CloseIdleConnections ()
677-
678687 err := doErr
679688 if checkErr != nil {
680689 err = checkErr
@@ -779,3 +788,29 @@ func (c *Client) StandardClient() *http.Client {
779788 Transport : & RoundTripper {Client : c },
780789 }
781790}
791+
792+ var (
793+ _ HTTPClientFactory = & CleanPooledClientFactory {}
794+ _ HTTPClient = & cleanClient {}
795+ )
796+
797+ type CleanPooledClientFactory struct {
798+ }
799+
800+ func (f * CleanPooledClientFactory ) New () HTTPClient {
801+ return & cleanClient {
802+ httpClient : cleanhttp .DefaultPooledClient (),
803+ }
804+ }
805+
806+ type cleanClient struct {
807+ httpClient * http.Client
808+ }
809+
810+ func (c * cleanClient ) Do (req * http.Request ) (* http.Response , error ) {
811+ return c .httpClient .Do (req )
812+ }
813+
814+ func (c * cleanClient ) Done () {
815+ c .httpClient .CloseIdleConnections ()
816+ }
0 commit comments