Skip to content

Commit ee60032

Browse files
Merge pull request #381 from supertokens/tests-fix
refactor: Check for status in github validate access token
2 parents a8829da + e2318f3 commit ee60032

File tree

4 files changed

+14
-12
lines changed

4 files changed

+14
-12
lines changed

recipe/thirdparty/providers/github.go

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -42,14 +42,14 @@ func Github(input tpmodels.ProviderInput) *tpmodels.TypeProvider {
4242
basicAuthToken := base64.StdEncoding.EncodeToString([]byte(clientConfig.ClientID + ":" + clientConfig.ClientSecret))
4343
wrongClientIdError := errors.New("Access token does not belong to your application")
4444

45-
resp, err := doPostRequest("https://api.github.com/applications/"+clientConfig.ClientID+"/token", map[string]interface{}{
45+
resp, status, err := doPostRequest("https://api.github.com/applications/"+clientConfig.ClientID+"/token", map[string]interface{}{
4646
"access_token": accessToken,
4747
}, map[string]interface{}{
4848
"Authorization": "Basic " + basicAuthToken,
4949
"Content-Type": "application/json",
5050
})
5151

52-
if err != nil {
52+
if err != nil || status != 200 {
5353
return errors.New("Invalid access token")
5454
}
5555

recipe/thirdparty/providers/oauth2_impl.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -106,7 +106,7 @@ func oauth2_ExchangeAuthCodeForOAuthTokens(config tpmodels.ProviderConfigForClie
106106
}
107107
/* Transformation needed for dev keys END */
108108

109-
oAuthTokens, err := doPostRequest(tokenAPIURL, accessTokenAPIParams, nil)
109+
oAuthTokens, _, err := doPostRequest(tokenAPIURL, accessTokenAPIParams, nil)
110110
if err != nil {
111111
return nil, err
112112
}

recipe/thirdparty/providers/twitter.go

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -87,9 +87,11 @@ func Twitter(input tpmodels.ProviderInput) *tpmodels.TypeProvider {
8787
twitterOauthParams["redirect_uri"] = redirectUri
8888
twitterOauthParams["code"] = redirectURIInfo.RedirectURIQueryParams["code"]
8989

90-
return doPostRequest(originalImplementation.Config.TokenEndpoint, twitterOauthParams, map[string]interface{}{
90+
resp, _, err := doPostRequest(originalImplementation.Config.TokenEndpoint, twitterOauthParams, map[string]interface{}{
9191
"Authorization": "Basic " + basicAuthToken,
9292
})
93+
94+
return resp, err
9395
}
9496

9597
if oOverride != nil {

recipe/thirdparty/providers/utils.go

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -90,16 +90,16 @@ func doGetRequest(url string, queryParams map[string]interface{}, headers map[st
9090
return result, nil
9191
}
9292

93-
func doPostRequest(url string, params map[string]interface{}, headers map[string]interface{}) (map[string]interface{}, error) {
93+
func doPostRequest(url string, params map[string]interface{}, headers map[string]interface{}) (map[string]interface{}, int, error) {
9494
supertokens.LogDebugMessage(fmt.Sprintf("POST request to %s, with form fields %v and headers %v", url, params, headers))
9595

9696
postBody, err := qs.Marshal(params)
9797
if err != nil {
98-
return nil, err
98+
return nil, -1, err
9999
}
100100
req, err := http.NewRequest("POST", url, bytes.NewBuffer([]byte(postBody)))
101101
if err != nil {
102-
return nil, err
102+
return nil, -1, err
103103
}
104104
for key, value := range headers {
105105
req.Header.Set(key, value.(string))
@@ -110,28 +110,28 @@ func doPostRequest(url string, params map[string]interface{}, headers map[string
110110
client := &http.Client{}
111111
resp, err := client.Do(req)
112112
if err != nil {
113-
return nil, err
113+
return nil, resp.StatusCode, err
114114
}
115115
defer resp.Body.Close()
116116

117117
body, err := ioutil.ReadAll(resp.Body)
118118
if err != nil {
119-
return nil, err
119+
return nil, resp.StatusCode, err
120120
}
121121

122122
supertokens.LogDebugMessage(fmt.Sprintf("Received response with status %d and body %s", resp.StatusCode, string(body)))
123123

124124
var result map[string]interface{}
125125
err = json.Unmarshal(body, &result)
126126
if err != nil {
127-
return nil, err
127+
return nil, resp.StatusCode, err
128128
}
129129

130130
if resp.StatusCode >= 300 {
131-
return nil, fmt.Errorf("POST request to %s resulted in %d status with body %s", url, resp.StatusCode, string(body))
131+
return nil, resp.StatusCode, fmt.Errorf("POST request to %s resulted in %d status with body %s", url, resp.StatusCode, string(body))
132132
}
133133

134-
return result, nil
134+
return result, resp.StatusCode, nil
135135
}
136136

137137
// JWKS utils

0 commit comments

Comments
 (0)