Skip to content

Commit 2677101

Browse files
committed
Added CheckRedirect support to Client.
1 parent 991b9d0 commit 2677101

File tree

2 files changed

+64
-1
lines changed

2 files changed

+64
-1
lines changed

client.go

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -769,6 +769,7 @@ func (c *Client) PostForm(url string, data url.Values) (*http.Response, error) {
769769
// shims in a *retryablehttp.Client for added retries.
770770
func (c *Client) StandardClient() *http.Client {
771771
return &http.Client{
772-
Transport: &RoundTripper{Client: c},
772+
Transport: &RoundTripper{Client: c},
773+
CheckRedirect: c.HTTPClient.CheckRedirect,
773774
}
774775
}

client_test.go

Lines changed: 62 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -648,6 +648,68 @@ func TestClient_CheckRetryStop(t *testing.T) {
648648
}
649649
}
650650

651+
func TestClient_CheckRedirects(t *testing.T) {
652+
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
653+
w.Header().Add("Location", "/new/path")
654+
655+
switch r.URL.Path {
656+
case "/301":
657+
w.WriteHeader(301)
658+
case "/302":
659+
w.WriteHeader(302)
660+
default:
661+
w.WriteHeader(500)
662+
}
663+
}))
664+
defer ts.Close()
665+
666+
client := NewClient()
667+
client.HTTPClient.CheckRedirect = func(req *http.Request, via []*http.Request) error {
668+
return http.ErrUseLastResponse
669+
}
670+
stdClient := client.StandardClient()
671+
672+
tests := []int{301, 302}
673+
674+
// Check that we get 301 and 302 responses.
675+
for _, test := range tests {
676+
resp, err := client.Get(fmt.Sprintf("%s/%d", ts.URL, test))
677+
if err != nil {
678+
t.Fatalf("unexpected error testing check redirect. %s", err.Error())
679+
}
680+
if resp.StatusCode != test {
681+
t.Fatalf("expected status code %d but got %d", test, resp.StatusCode)
682+
}
683+
684+
// Check with standard client as well.
685+
resp, err = stdClient.Get(fmt.Sprintf("%s/%d", ts.URL, test))
686+
if err != nil {
687+
t.Fatalf("unexpected error testing check redirect. %s", err.Error())
688+
}
689+
if resp.StatusCode != test {
690+
t.Fatalf("expected status code %d but got %d", test, resp.StatusCode)
691+
}
692+
}
693+
694+
// Check that we get errors when using default check redirect policy.
695+
client = NewClient()
696+
client.RetryMax = 0
697+
stdClient = client.StandardClient()
698+
699+
for _, test := range tests {
700+
_, err := client.Get(fmt.Sprintf("%s/%d", ts.URL, test))
701+
if err == nil {
702+
t.Fatalf("expected none nil error when testing default redirect behavior")
703+
}
704+
705+
// Check with standard client as well.
706+
_, err = stdClient.Get(fmt.Sprintf("%s/%d", ts.URL, test))
707+
if err == nil {
708+
t.Fatalf("expected none nil error when testing default redirect behavior")
709+
}
710+
}
711+
}
712+
651713
func TestClient_Head(t *testing.T) {
652714
// Mock server which always responds 200.
653715
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {

0 commit comments

Comments
 (0)