Skip to content

Commit 98169fe

Browse files
authored
Merge pull request #161 from hashicorp/retry-extra-handling
Retry extra handling
2 parents ff6d014 + 5bd1a6f commit 98169fe

File tree

4 files changed

+171
-18
lines changed

4 files changed

+171
-18
lines changed

README.md

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,25 @@ The returned response object is an `*http.Response`, the same thing you would
4545
usually get from `net/http`. Had the request failed one or more times, the above
4646
call would block and retry with exponential backoff.
4747

48+
## Retrying cases that fail after a seeming success
49+
50+
It's possible for a request to succeed in the sense that the expected response headers are received, but then to encounter network-level errors while reading the response body. In go-retryablehttp's most basic usage, this error would not be retryable, due to the out-of-band handling of the response body. In some cases it may be desirable to handle the response body as part of the retryable operation.
51+
52+
A toy example (which will retry the full request and succeed on the second attempt) is shown below:
53+
54+
```go
55+
c := retryablehttp.NewClient()
56+
r := retryablehttp.NewRequest("GET", "://foo", nil)
57+
handlerShouldRetry := true
58+
r.SetResponseHandler(func(*http.Response) error {
59+
if !handlerShouldRetry {
60+
return nil
61+
}
62+
handlerShouldRetry = false
63+
return errors.New("retryable error")
64+
})
65+
```
66+
4867
## Getting a stdlib `*http.Client` with retries
4968

5069
It's possible to convert a `*retryablehttp.Client` directly to a `*http.Client`.

client.go

Lines changed: 44 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -69,11 +69,21 @@ var (
6969
// scheme specified in the URL is invalid. This error isn't typed
7070
// specifically so we resort to matching on the error string.
7171
schemeErrorRe = regexp.MustCompile(`unsupported protocol scheme`)
72+
73+
// A regular expression to match the error returned by net/http when the
74+
// TLS certificate is not trusted. This error isn't typed
75+
// specifically so we resort to matching on the error string.
76+
notTrustedErrorRe = regexp.MustCompile(`certificate is not trusted`)
7277
)
7378

7479
// ReaderFunc is the type of function that can be given natively to NewRequest
7580
type ReaderFunc func() (io.Reader, error)
7681

82+
// ResponseHandlerFunc is a type of function that takes in a Response, and does something with it.
83+
// It only runs if the initial part of the request was successful.
84+
// If an error is returned, the client's retry policy will be used to determine whether to retry the whole request.
85+
type ResponseHandlerFunc func(*http.Response) error
86+
7787
// LenReader is an interface implemented by many in-memory io.Reader's. Used
7888
// for automatically sending the right Content-Length header when possible.
7989
type LenReader interface {
@@ -86,6 +96,8 @@ type Request struct {
8696
// used to rewind the request data in between retries.
8797
body ReaderFunc
8898

99+
responseHandler ResponseHandlerFunc
100+
89101
// Embed an HTTP request directly. This makes a *Request act exactly
90102
// like an *http.Request so that all meta methods are supported.
91103
*http.Request
@@ -95,11 +107,17 @@ type Request struct {
95107
// with its context changed to ctx. The provided ctx must be non-nil.
96108
func (r *Request) WithContext(ctx context.Context) *Request {
97109
return &Request{
98-
body: r.body,
99-
Request: r.Request.WithContext(ctx),
110+
body: r.body,
111+
responseHandler: r.responseHandler,
112+
Request: r.Request.WithContext(ctx),
100113
}
101114
}
102115

116+
// SetResponseHandler allows setting the response handler.
117+
func (r *Request) SetResponseHandler(fn ResponseHandlerFunc) {
118+
r.responseHandler = fn
119+
}
120+
103121
// BodyBytes allows accessing the request body. It is an analogue to
104122
// http.Request's Body variable, but it returns a copy of the underlying data
105123
// rather than consuming it.
@@ -254,7 +272,7 @@ func FromRequest(r *http.Request) (*Request, error) {
254272
return nil, err
255273
}
256274
// Could assert contentLength == r.ContentLength
257-
return &Request{bodyReader, r}, nil
275+
return &Request{body: bodyReader, Request: r}, nil
258276
}
259277

260278
// NewRequest creates a new wrapped request.
@@ -278,7 +296,7 @@ func NewRequestWithContext(ctx context.Context, method, url string, rawBody inte
278296
}
279297
httpReq.ContentLength = contentLength
280298

281-
return &Request{bodyReader, httpReq}, nil
299+
return &Request{body: bodyReader, Request: httpReq}, nil
282300
}
283301

284302
// Logger interface allows to use other loggers than
@@ -445,6 +463,9 @@ func baseRetryPolicy(resp *http.Response, err error) (bool, error) {
445463
}
446464

447465
// Don't retry if the error was due to TLS cert verification failure.
466+
if notTrustedErrorRe.MatchString(v.Error()) {
467+
return false, v
468+
}
448469
if _, ok := v.Err.(x509.UnknownAuthorityError); ok {
449470
return false, v
450471
}
@@ -565,9 +586,10 @@ func (c *Client) Do(req *Request) (*http.Response, error) {
565586
var resp *http.Response
566587
var attempt int
567588
var shouldRetry bool
568-
var doErr, checkErr error
589+
var doErr, respErr, checkErr error
569590

570591
for i := 0; ; i++ {
592+
doErr, respErr = nil, nil
571593
attempt++
572594

573595
// Always rewind the request body when non-nil.
@@ -600,13 +622,21 @@ func (c *Client) Do(req *Request) (*http.Response, error) {
600622

601623
// Check if we should continue with retries.
602624
shouldRetry, checkErr = c.CheckRetry(req.Context(), resp, doErr)
625+
if !shouldRetry && doErr == nil && req.responseHandler != nil {
626+
respErr = req.responseHandler(resp)
627+
shouldRetry, checkErr = c.CheckRetry(req.Context(), resp, respErr)
628+
}
603629

604-
if doErr != nil {
630+
err := doErr
631+
if respErr != nil {
632+
err = respErr
633+
}
634+
if err != nil {
605635
switch v := logger.(type) {
606636
case LeveledLogger:
607-
v.Error("request failed", "error", doErr, "method", req.Method, "url", req.URL)
637+
v.Error("request failed", "error", err, "method", req.Method, "url", req.URL)
608638
case Logger:
609-
v.Printf("[ERR] %s %s request failed: %v", req.Method, req.URL, doErr)
639+
v.Printf("[ERR] %s %s request failed: %v", req.Method, req.URL, err)
610640
}
611641
} else {
612642
// Call this here to maintain the behavior of logging all requests,
@@ -669,15 +699,19 @@ func (c *Client) Do(req *Request) (*http.Response, error) {
669699
}
670700

671701
// this is the closest we have to success criteria
672-
if doErr == nil && checkErr == nil && !shouldRetry {
702+
if doErr == nil && respErr == nil && checkErr == nil && !shouldRetry {
673703
return resp, nil
674704
}
675705

676706
defer c.HTTPClient.CloseIdleConnections()
677707

678-
err := doErr
708+
var err error
679709
if checkErr != nil {
680710
err = checkErr
711+
} else if respErr != nil {
712+
err = respErr
713+
} else {
714+
err = doErr
681715
}
682716

683717
if c.ErrorHandler != nil {

client_test.go

Lines changed: 104 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -167,13 +167,13 @@ func testClientDo(t *testing.T, body interface{}) {
167167
// Send the request
168168
var resp *http.Response
169169
doneCh := make(chan struct{})
170+
errCh := make(chan error, 1)
170171
go func() {
171172
defer close(doneCh)
173+
defer close(errCh)
172174
var err error
173175
resp, err = client.Do(req)
174-
if err != nil {
175-
t.Fatalf("err: %v", err)
176-
}
176+
errCh <- err
177177
}()
178178

179179
select {
@@ -247,6 +247,106 @@ func testClientDo(t *testing.T, body interface{}) {
247247
if retryCount < 0 {
248248
t.Fatal("request log hook was not called")
249249
}
250+
251+
err = <-errCh
252+
if err != nil {
253+
t.Fatalf("err: %v", err)
254+
}
255+
}
256+
257+
func TestClient_Do_WithResponseHandler(t *testing.T) {
258+
// Create the client. Use short retry windows so we fail faster.
259+
client := NewClient()
260+
client.RetryWaitMin = 10 * time.Millisecond
261+
client.RetryWaitMax = 10 * time.Millisecond
262+
client.RetryMax = 2
263+
264+
var checks int
265+
client.CheckRetry = func(_ context.Context, resp *http.Response, err error) (bool, error) {
266+
checks++
267+
if err != nil && strings.Contains(err.Error(), "nonretryable") {
268+
return false, nil
269+
}
270+
return DefaultRetryPolicy(context.TODO(), resp, err)
271+
}
272+
273+
// Mock server which always responds 200.
274+
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
275+
w.WriteHeader(200)
276+
}))
277+
defer ts.Close()
278+
279+
var shouldSucceed bool
280+
tests := []struct {
281+
name string
282+
handler ResponseHandlerFunc
283+
expectedChecks int // often 2x number of attempts since we check twice
284+
err string
285+
}{
286+
{
287+
name: "nil handler",
288+
handler: nil,
289+
expectedChecks: 1,
290+
},
291+
{
292+
name: "handler always succeeds",
293+
handler: func(*http.Response) error {
294+
return nil
295+
},
296+
expectedChecks: 2,
297+
},
298+
{
299+
name: "handler always fails in a retryable way",
300+
handler: func(*http.Response) error {
301+
return errors.New("retryable failure")
302+
},
303+
expectedChecks: 6,
304+
},
305+
{
306+
name: "handler always fails in a nonretryable way",
307+
handler: func(*http.Response) error {
308+
return errors.New("nonretryable failure")
309+
},
310+
expectedChecks: 2,
311+
},
312+
{
313+
name: "handler succeeds on second attempt",
314+
handler: func(*http.Response) error {
315+
if shouldSucceed {
316+
return nil
317+
}
318+
shouldSucceed = true
319+
return errors.New("retryable failure")
320+
},
321+
expectedChecks: 4,
322+
},
323+
}
324+
325+
for _, tt := range tests {
326+
t.Run(tt.name, func(t *testing.T) {
327+
checks = 0
328+
shouldSucceed = false
329+
// Create the request
330+
req, err := NewRequest("GET", ts.URL, nil)
331+
if err != nil {
332+
t.Fatalf("err: %v", err)
333+
}
334+
req.SetResponseHandler(tt.handler)
335+
336+
// Send the request.
337+
_, err = client.Do(req)
338+
if err != nil && !strings.Contains(err.Error(), tt.err) {
339+
t.Fatalf("error does not match expectation, expected: %s, got: %s", tt.err, err.Error())
340+
}
341+
if err == nil && tt.err != "" {
342+
t.Fatalf("no error, expected: %s", tt.err)
343+
}
344+
345+
if checks != tt.expectedChecks {
346+
t.Fatalf("expected %d attempts, got %d attempts", tt.expectedChecks, checks)
347+
}
348+
})
349+
}
250350
}
251351

252352
func TestClient_Do_fails(t *testing.T) {
@@ -598,7 +698,7 @@ func TestClient_DefaultRetryPolicy_TLS(t *testing.T) {
598698

599699
func TestClient_DefaultRetryPolicy_redirects(t *testing.T) {
600700
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
601-
http.Redirect(w, r, "/", 302)
701+
http.Redirect(w, r, "/", http.StatusFound)
602702
}))
603703
defer ts.Close()
604704

roundtripper_test.go

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -107,12 +107,12 @@ func TestRoundTripper_TransportFailureErrorHandling(t *testing.T) {
107107

108108
expectedError := &url.Error{
109109
Op: "Get",
110-
URL: "http://this-url-does-not-exist-ed2fb.com/",
110+
URL: "http://999.999.999.999:999/",
111111
Err: &net.OpError{
112112
Op: "dial",
113113
Net: "tcp",
114114
Err: &net.DNSError{
115-
Name: "this-url-does-not-exist-ed2fb.com",
115+
Name: "999.999.999.999",
116116
Err: "no such host",
117117
IsNotFound: true,
118118
},
@@ -121,10 +121,10 @@ func TestRoundTripper_TransportFailureErrorHandling(t *testing.T) {
121121

122122
// Get the standard client and execute the request.
123123
client := retryClient.StandardClient()
124-
_, err := client.Get("http://this-url-does-not-exist-ed2fb.com/")
124+
_, err := client.Get("http://999.999.999.999:999/")
125125

126126
// assert expectations
127-
if !reflect.DeepEqual(normalizeError(err), expectedError) {
127+
if !reflect.DeepEqual(expectedError, normalizeError(err)) {
128128
t.Fatalf("expected %q, got %q", expectedError, err)
129129
}
130130
}

0 commit comments

Comments
 (0)