Skip to content

Commit b63fc3a

Browse files
committed
test: cover oauth authorize flow
1 parent 4015bc4 commit b63fc3a

File tree

4 files changed

+331
-11
lines changed

4 files changed

+331
-11
lines changed

internal/googleauth/oauth_flow.go

Lines changed: 12 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -28,19 +28,26 @@ type AuthorizeOptions struct {
2828
Timeout time.Duration
2929
}
3030

31+
var (
32+
readClientCredentials = config.ReadClientCredentials
33+
openBrowserFn = openBrowser
34+
oauthEndpoint = google.Endpoint
35+
randomStateFn = randomState
36+
)
37+
3138
func Authorize(ctx context.Context, opts AuthorizeOptions) (string, error) {
3239
if opts.Timeout <= 0 {
3340
opts.Timeout = 2 * time.Minute
3441
}
3542
if len(opts.Scopes) == 0 {
3643
return "", errors.New("missing scopes")
3744
}
38-
creds, err := config.ReadClientCredentials()
45+
creds, err := readClientCredentials()
3946
if err != nil {
4047
return "", err
4148
}
4249

43-
state, err := randomState()
50+
state, err := randomStateFn()
4451
if err != nil {
4552
return "", err
4653
}
@@ -53,7 +60,7 @@ func Authorize(ctx context.Context, opts AuthorizeOptions) (string, error) {
5360
cfg := oauth2.Config{
5461
ClientID: creds.ClientID,
5562
ClientSecret: creds.ClientSecret,
56-
Endpoint: google.Endpoint,
63+
Endpoint: oauthEndpoint,
5764
RedirectURL: redirectURI,
5865
Scopes: opts.Scopes,
5966
}
@@ -100,7 +107,7 @@ func Authorize(ctx context.Context, opts AuthorizeOptions) (string, error) {
100107
cfg := oauth2.Config{
101108
ClientID: creds.ClientID,
102109
ClientSecret: creds.ClientSecret,
103-
Endpoint: google.Endpoint,
110+
Endpoint: oauthEndpoint,
104111
RedirectURL: redirectURI,
105112
Scopes: opts.Scopes,
106113
}
@@ -170,7 +177,7 @@ func Authorize(ctx context.Context, opts AuthorizeOptions) (string, error) {
170177
fmt.Fprintln(os.Stderr, "Opening browser for authorization…")
171178
fmt.Fprintln(os.Stderr, "If the browser doesn't open, visit this URL:")
172179
fmt.Fprintln(os.Stderr, authURL)
173-
_ = openBrowser(authURL)
180+
_ = openBrowserFn(authURL)
174181

175182
select {
176183
case code := <-codeCh:
Lines changed: 263 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,263 @@
1+
package googleauth
2+
3+
import (
4+
"context"
5+
"encoding/json"
6+
"errors"
7+
"net/http"
8+
"net/http/httptest"
9+
"net/url"
10+
"os"
11+
"strings"
12+
"testing"
13+
"time"
14+
15+
"github.com/steipete/gogcli/internal/config"
16+
"golang.org/x/oauth2"
17+
)
18+
19+
func newTokenServer(t *testing.T, refreshToken string) *httptest.Server {
20+
t.Helper()
21+
22+
return httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
23+
if r.URL.Path != "/token" {
24+
http.NotFound(w, r)
25+
return
26+
}
27+
if err := r.ParseForm(); err != nil {
28+
http.Error(w, "bad form", http.StatusBadRequest)
29+
return
30+
}
31+
if r.Form.Get("grant_type") != "authorization_code" {
32+
http.Error(w, "bad grant_type", http.StatusBadRequest)
33+
return
34+
}
35+
if r.Form.Get("code") == "" {
36+
http.Error(w, "missing code", http.StatusBadRequest)
37+
return
38+
}
39+
w.Header().Set("Content-Type", "application/json")
40+
_ = json.NewEncoder(w).Encode(map[string]any{
41+
"access_token": "at",
42+
"refresh_token": refreshToken,
43+
"token_type": "Bearer",
44+
"expires_in": 3600,
45+
})
46+
}))
47+
}
48+
49+
func TestAuthorize_MissingScopes(t *testing.T) {
50+
_, err := Authorize(context.Background(), AuthorizeOptions{})
51+
if err == nil || !strings.Contains(err.Error(), "missing scopes") {
52+
t.Fatalf("expected missing scopes error, got: %v", err)
53+
}
54+
}
55+
56+
func TestAuthorize_Manual_Success(t *testing.T) {
57+
origRead := readClientCredentials
58+
origEndpoint := oauthEndpoint
59+
origState := randomStateFn
60+
t.Cleanup(func() {
61+
readClientCredentials = origRead
62+
oauthEndpoint = origEndpoint
63+
randomStateFn = origState
64+
})
65+
66+
readClientCredentials = func() (config.ClientCredentials, error) {
67+
return config.ClientCredentials{ClientID: "id", ClientSecret: "secret"}, nil
68+
}
69+
randomStateFn = func() (string, error) { return "state123", nil }
70+
71+
tokenSrv := newTokenServer(t, "rt")
72+
defer tokenSrv.Close()
73+
oauthEndpoint = oauth2EndpointForTest(tokenSrv.URL)
74+
75+
origStdin := os.Stdin
76+
t.Cleanup(func() { os.Stdin = origStdin })
77+
r, w, err := os.Pipe()
78+
if err != nil {
79+
t.Fatalf("pipe: %v", err)
80+
}
81+
os.Stdin = r
82+
_, _ = w.WriteString("http://localhost:1/?code=abc&state=state123\n")
83+
_ = w.Close()
84+
85+
rt, err := Authorize(context.Background(), AuthorizeOptions{
86+
Scopes: []string{"s1"},
87+
Manual: true,
88+
Timeout: 2 * time.Second,
89+
})
90+
if err != nil {
91+
t.Fatalf("Authorize: %v", err)
92+
}
93+
if rt != "rt" {
94+
t.Fatalf("unexpected refresh token: %q", rt)
95+
}
96+
}
97+
98+
func TestAuthorize_Manual_StateMismatch(t *testing.T) {
99+
origRead := readClientCredentials
100+
origEndpoint := oauthEndpoint
101+
origState := randomStateFn
102+
t.Cleanup(func() {
103+
readClientCredentials = origRead
104+
oauthEndpoint = origEndpoint
105+
randomStateFn = origState
106+
})
107+
108+
readClientCredentials = func() (config.ClientCredentials, error) {
109+
return config.ClientCredentials{ClientID: "id", ClientSecret: "secret"}, nil
110+
}
111+
randomStateFn = func() (string, error) { return "state123", nil }
112+
113+
tokenSrv := newTokenServer(t, "rt")
114+
defer tokenSrv.Close()
115+
oauthEndpoint = oauth2EndpointForTest(tokenSrv.URL)
116+
117+
origStdin := os.Stdin
118+
t.Cleanup(func() { os.Stdin = origStdin })
119+
r, w, err := os.Pipe()
120+
if err != nil {
121+
t.Fatalf("pipe: %v", err)
122+
}
123+
os.Stdin = r
124+
_, _ = w.WriteString("http://localhost:1/?code=abc&state=DIFFERENT\n")
125+
_ = w.Close()
126+
127+
_, err = Authorize(context.Background(), AuthorizeOptions{
128+
Scopes: []string{"s1"},
129+
Manual: true,
130+
Timeout: 2 * time.Second,
131+
})
132+
if err == nil || !strings.Contains(err.Error(), "state mismatch") {
133+
t.Fatalf("expected state mismatch, got: %v", err)
134+
}
135+
}
136+
137+
func TestAuthorize_ServerFlow_Success(t *testing.T) {
138+
origRead := readClientCredentials
139+
origEndpoint := oauthEndpoint
140+
origOpen := openBrowserFn
141+
t.Cleanup(func() {
142+
readClientCredentials = origRead
143+
oauthEndpoint = origEndpoint
144+
openBrowserFn = origOpen
145+
})
146+
147+
readClientCredentials = func() (config.ClientCredentials, error) {
148+
return config.ClientCredentials{ClientID: "id", ClientSecret: "secret"}, nil
149+
}
150+
151+
tokenSrv := newTokenServer(t, "rt")
152+
defer tokenSrv.Close()
153+
oauthEndpoint = oauth2EndpointForTest(tokenSrv.URL)
154+
155+
openBrowserFn = func(authURL string) error {
156+
u, err := url.Parse(authURL)
157+
if err != nil {
158+
return err
159+
}
160+
q := u.Query()
161+
redirect := q.Get("redirect_uri")
162+
state := q.Get("state")
163+
if redirect == "" || state == "" {
164+
return errors.New("missing redirect/state")
165+
}
166+
cb := redirect + "?code=abc&state=" + url.QueryEscape(state)
167+
resp, err := http.Get(cb)
168+
if err != nil {
169+
return err
170+
}
171+
_ = resp.Body.Close()
172+
return nil
173+
}
174+
175+
rt, err := Authorize(context.Background(), AuthorizeOptions{
176+
Scopes: []string{"s1"},
177+
Timeout: 2 * time.Second,
178+
})
179+
if err != nil {
180+
t.Fatalf("Authorize: %v", err)
181+
}
182+
if rt != "rt" {
183+
t.Fatalf("unexpected refresh token: %q", rt)
184+
}
185+
}
186+
187+
func TestAuthorize_ServerFlow_CallbackErrors(t *testing.T) {
188+
tests := []struct {
189+
name string
190+
query string
191+
wantText string
192+
}{
193+
{name: "missing_code", query: "state=%s", wantText: "missing code"},
194+
{name: "state_mismatch", query: "code=abc&state=WRONG", wantText: "state mismatch"},
195+
{name: "oauth_error", query: "error=access_denied&state=%s", wantText: "authorization error"},
196+
}
197+
198+
for _, tt := range tests {
199+
t.Run(tt.name, func(t *testing.T) {
200+
origRead := readClientCredentials
201+
origEndpoint := oauthEndpoint
202+
origOpen := openBrowserFn
203+
t.Cleanup(func() {
204+
readClientCredentials = origRead
205+
oauthEndpoint = origEndpoint
206+
openBrowserFn = origOpen
207+
})
208+
209+
readClientCredentials = func() (config.ClientCredentials, error) {
210+
return config.ClientCredentials{ClientID: "id", ClientSecret: "secret"}, nil
211+
}
212+
213+
tokenSrv := newTokenServer(t, "rt")
214+
defer tokenSrv.Close()
215+
oauthEndpoint = oauth2EndpointForTest(tokenSrv.URL)
216+
217+
openBrowserFn = func(authURL string) error {
218+
u, err := url.Parse(authURL)
219+
if err != nil {
220+
return err
221+
}
222+
q := u.Query()
223+
redirect := q.Get("redirect_uri")
224+
state := q.Get("state")
225+
if redirect == "" || state == "" {
226+
return errors.New("missing redirect/state")
227+
}
228+
query := tt.query
229+
if strings.Contains(query, "%s") {
230+
query = fmtSprintf(query, url.QueryEscape(state))
231+
}
232+
cb := redirect + "?" + query
233+
resp, err := http.Get(cb)
234+
if err != nil {
235+
return err
236+
}
237+
_ = resp.Body.Close()
238+
return nil
239+
}
240+
241+
_, err := Authorize(context.Background(), AuthorizeOptions{
242+
Scopes: []string{"s1"},
243+
Timeout: 2 * time.Second,
244+
})
245+
if err == nil || !strings.Contains(err.Error(), tt.wantText) {
246+
t.Fatalf("expected %q error, got: %v", tt.wantText, err)
247+
}
248+
})
249+
}
250+
}
251+
252+
// oauth2.Endpoint is a plain struct; keep construction centralized.
253+
func oauth2EndpointForTest(base string) oauth2.Endpoint {
254+
return oauth2.Endpoint{
255+
AuthURL: base + "/auth",
256+
TokenURL: base + "/token",
257+
}
258+
}
259+
260+
// Minimal sprintf to avoid importing fmt just for one small helper in tests.
261+
func fmtSprintf(format string, v string) string {
262+
return strings.ReplaceAll(format, "%s", v)
263+
}

internal/googleauth/open_browser.go

Lines changed: 13 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -5,15 +5,22 @@ import (
55
"runtime"
66
)
77

8+
var startCommand = func(name string, args ...string) error {
9+
return exec.Command(name, args...).Start()
10+
}
11+
812
func openBrowser(u string) error {
9-
var cmd *exec.Cmd
10-
switch runtime.GOOS {
13+
name, args := openBrowserCommand(u, runtime.GOOS)
14+
return startCommand(name, args...)
15+
}
16+
17+
func openBrowserCommand(u string, goos string) (name string, args []string) {
18+
switch goos {
1119
case "darwin":
12-
cmd = exec.Command("open", u)
20+
return "open", []string{u}
1321
case "windows":
14-
cmd = exec.Command("rundll32", "url.dll,FileProtocolHandler", u)
22+
return "rundll32", []string{"url.dll,FileProtocolHandler", u}
1523
default:
16-
cmd = exec.Command("xdg-open", u)
24+
return "xdg-open", []string{u}
1725
}
18-
return cmd.Start()
1926
}
Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,43 @@
1+
package googleauth
2+
3+
import "testing"
4+
5+
func TestOpenBrowserCommand(t *testing.T) {
6+
name, args := openBrowserCommand("https://example.com", "darwin")
7+
if name != "open" || len(args) != 1 || args[0] != "https://example.com" {
8+
t.Fatalf("darwin: %q %#v", name, args)
9+
}
10+
11+
name, args = openBrowserCommand("https://example.com", "windows")
12+
if name != "rundll32" || len(args) != 2 || args[1] != "https://example.com" {
13+
t.Fatalf("windows: %q %#v", name, args)
14+
}
15+
16+
name, args = openBrowserCommand("https://example.com", "linux")
17+
if name != "xdg-open" || len(args) != 1 || args[0] != "https://example.com" {
18+
t.Fatalf("linux: %q %#v", name, args)
19+
}
20+
}
21+
22+
func TestOpenBrowser_UsesStartCommand(t *testing.T) {
23+
orig := startCommand
24+
t.Cleanup(func() { startCommand = orig })
25+
26+
var gotName string
27+
var gotArgs []string
28+
startCommand = func(name string, args ...string) error {
29+
gotName = name
30+
gotArgs = append([]string(nil), args...)
31+
return nil
32+
}
33+
34+
if err := openBrowser("https://example.com"); err != nil {
35+
t.Fatalf("openBrowser: %v", err)
36+
}
37+
if gotName == "" || len(gotArgs) == 0 {
38+
t.Fatalf("expected startCommand to be called")
39+
}
40+
if gotArgs[len(gotArgs)-1] != "https://example.com" {
41+
t.Fatalf("unexpected args: %q %#v", gotName, gotArgs)
42+
}
43+
}

0 commit comments

Comments
 (0)