@@ -10,13 +10,24 @@ import (
1010 "context"
1111 "errors"
1212 "fmt"
13+ "io"
1314 "net/http"
1415 "net/http/httptest"
16+ "strings"
1517 "testing"
1618
1719 "golang.org/x/oauth2"
1820)
1921
22+ // A basicReader is an io.Reader to be used as a non-rereadable request body.
23+ //
24+ // net/http has special handling for strings.Reader that we want to avoid.
25+ type basicReader struct {
26+ r * strings.Reader
27+ }
28+
29+ func (r * basicReader ) Read (p []byte ) (n int , err error ) { return r .r .Read (p ) }
30+
2031// TestHTTPTransport validates the OAuth HTTPTransport.
2132func TestHTTPTransport (t * testing.T ) {
2233 const testToken = "test-token-123"
@@ -27,6 +38,20 @@ func TestHTTPTransport(t *testing.T) {
2738
2839 // authServer simulates a resource that requires OAuth.
2940 authServer := httptest .NewServer (http .HandlerFunc (func (w http.ResponseWriter , r * http.Request ) {
41+ if r .Method == http .MethodPost {
42+ // Ensure that the body was properly cloned, by reading it completely.
43+ // If the body is not cloned, reading it the second time may yield no
44+ // bytes.
45+ body , err := io .ReadAll (r .Body )
46+ if err != nil {
47+ http .Error (w , err .Error (), http .StatusInternalServerError )
48+ return
49+ }
50+ if len (body ) == 0 {
51+ http .Error (w , "empty body" , http .StatusBadRequest )
52+ return
53+ }
54+ }
3055 authHeader := r .Header .Get ("Authorization" )
3156 if authHeader == fmt .Sprintf ("Bearer %s" , testToken ) {
3257 w .WriteHeader (http .StatusOK )
@@ -82,6 +107,32 @@ func TestHTTPTransport(t *testing.T) {
82107 }
83108 })
84109
110+ t .Run ("request body is cloned" , func (t * testing.T ) {
111+ handler := func (ctx context.Context , args OAuthHandlerArgs ) (oauth2.TokenSource , error ) {
112+ if args .ResourceMetadataURL != "http://metadata.example.com" {
113+ t .Errorf ("handler got metadata URL %q, want %q" , args .ResourceMetadataURL , "http://metadata.example.com" )
114+ }
115+ return fakeTokenSource , nil
116+ }
117+
118+ transport , err := NewHTTPTransport (handler , nil )
119+ if err != nil {
120+ t .Fatalf ("NewHTTPTransport() failed: %v" , err )
121+ }
122+ client := & http.Client {Transport : transport }
123+
124+ resp , err := client .Post (authServer .URL , "application/json" , & basicReader {strings .NewReader ("{}" )})
125+ // resp, err := client.Post(authServer.URL, "application/json", strings.NewReader("{}"))
126+ if err != nil {
127+ t .Fatalf ("client.Post() failed: %v" , err )
128+ }
129+ defer resp .Body .Close ()
130+
131+ if resp .StatusCode != http .StatusOK {
132+ t .Errorf ("got status %d, want %d" , resp .StatusCode , http .StatusOK )
133+ }
134+ })
135+
85136 t .Run ("handler returns error" , func (t * testing.T ) {
86137 handlerErr := errors .New ("user rejected auth" )
87138 handler := func (ctx context.Context , args OAuthHandlerArgs ) (oauth2.TokenSource , error ) {
0 commit comments