Skip to content

Commit 230e352

Browse files
committed
auth: clone the client request body before roundtripping
RoundTrippers may read and close the body, so be careful to clone before roundtripping during client oauth, as the request may be issued multiple times. Fixes #590
1 parent 1b00937 commit 230e352

File tree

2 files changed

+89
-2
lines changed

2 files changed

+89
-2
lines changed

auth/client.go

Lines changed: 38 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,8 +7,10 @@
77
package auth
88

99
import (
10+
"bytes"
1011
"context"
1112
"errors"
13+
"io"
1214
"net/http"
1315
"sync"
1416

@@ -67,7 +69,32 @@ func (t *HTTPTransport) RoundTrip(req *http.Request) (*http.Response, error) {
6769
base := t.opts.Base
6870
t.mu.Unlock()
6971

70-
resp, err := base.RoundTrip(req)
72+
// req1 is our first request in the authorization flow.
73+
//
74+
// If we mutate its body, we must clone it first.
75+
req1 := req
76+
var (
77+
// If haveBody is set, the request has a nontrivial body, and need avoid
78+
// reading (or closing) it multiple times. In that case, bodyBytes is its
79+
// content.
80+
haveBody bool
81+
bodyBytes []byte
82+
)
83+
if req.Body != nil && req.Body != http.NoBody {
84+
req1 = req.Clone(req.Context())
85+
haveBody = true
86+
var err error
87+
bodyBytes, err = io.ReadAll(req.Body)
88+
if err != nil {
89+
return nil, err
90+
}
91+
// Now that we've read the request body, http.RoundTripper requires that we
92+
// close it.
93+
req.Body.Close() // ignore error
94+
req1.Body = io.NopCloser(bytes.NewReader(bodyBytes))
95+
}
96+
97+
resp, err := base.RoundTrip(req1)
7198
if err != nil {
7299
return nil, err
73100
}
@@ -97,7 +124,16 @@ func (t *HTTPTransport) RoundTrip(req *http.Request) (*http.Response, error) {
97124
}
98125
t.opts.Base = &oauth2.Transport{Base: t.opts.Base, Source: ts}
99126
}
100-
return t.opts.Base.RoundTrip(req.Clone(req.Context()))
127+
128+
// If we don't have a body, the request is reusable, though it will be cloned
129+
// by the base. However, if we've had to read the body, we must clone.
130+
req2 := req
131+
if haveBody {
132+
req2 = req.Clone(req.Context())
133+
req2.Body = io.NopCloser(bytes.NewReader(bodyBytes))
134+
}
135+
136+
return t.opts.Base.RoundTrip(req2)
101137
}
102138

103139
func extractResourceMetadataURL(authHeaders []string) string {

auth/client_test.go

Lines changed: 51 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,13 +10,24 @@ import (
1010
"context"
1111
"errors"
1212
"fmt"
13+
"io"
1314
"net/http"
1415
"net/http/httptest"
16+
"strings"
1517
"testing"
1618

1719
"golang.org/x/oauth2"
1820
)
1921

22+
// A basicReader is an io.Reader to be used as a non-rereadable request body.
23+
//
24+
// net/http has special handling for strings.Reader that we want to avoid.
25+
type basicReader struct {
26+
r *strings.Reader
27+
}
28+
29+
func (r *basicReader) Read(p []byte) (n int, err error) { return r.r.Read(p) }
30+
2031
// TestHTTPTransport validates the OAuth HTTPTransport.
2132
func TestHTTPTransport(t *testing.T) {
2233
const testToken = "test-token-123"
@@ -27,6 +38,20 @@ func TestHTTPTransport(t *testing.T) {
2738

2839
// authServer simulates a resource that requires OAuth.
2940
authServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
41+
if r.Method == http.MethodPost {
42+
// Ensure that the body was properly cloned, by reading it completely.
43+
// If the body is not cloned, reading it the second time may yield no
44+
// bytes.
45+
body, err := io.ReadAll(r.Body)
46+
if err != nil {
47+
http.Error(w, err.Error(), http.StatusInternalServerError)
48+
return
49+
}
50+
if len(body) == 0 {
51+
http.Error(w, "empty body", http.StatusBadRequest)
52+
return
53+
}
54+
}
3055
authHeader := r.Header.Get("Authorization")
3156
if authHeader == fmt.Sprintf("Bearer %s", testToken) {
3257
w.WriteHeader(http.StatusOK)
@@ -82,6 +107,32 @@ func TestHTTPTransport(t *testing.T) {
82107
}
83108
})
84109

110+
t.Run("request body is cloned", func(t *testing.T) {
111+
handler := func(ctx context.Context, args OAuthHandlerArgs) (oauth2.TokenSource, error) {
112+
if args.ResourceMetadataURL != "http://metadata.example.com" {
113+
t.Errorf("handler got metadata URL %q, want %q", args.ResourceMetadataURL, "http://metadata.example.com")
114+
}
115+
return fakeTokenSource, nil
116+
}
117+
118+
transport, err := NewHTTPTransport(handler, nil)
119+
if err != nil {
120+
t.Fatalf("NewHTTPTransport() failed: %v", err)
121+
}
122+
client := &http.Client{Transport: transport}
123+
124+
resp, err := client.Post(authServer.URL, "application/json", &basicReader{strings.NewReader("{}")})
125+
// resp, err := client.Post(authServer.URL, "application/json", strings.NewReader("{}"))
126+
if err != nil {
127+
t.Fatalf("client.Post() failed: %v", err)
128+
}
129+
defer resp.Body.Close()
130+
131+
if resp.StatusCode != http.StatusOK {
132+
t.Errorf("got status %d, want %d", resp.StatusCode, http.StatusOK)
133+
}
134+
})
135+
85136
t.Run("handler returns error", func(t *testing.T) {
86137
handlerErr := errors.New("user rejected auth")
87138
handler := func(ctx context.Context, args OAuthHandlerArgs) (oauth2.TokenSource, error) {

0 commit comments

Comments
 (0)