Skip to content

Commit 412c6c9

Browse files
committed
fix: extended api host to have a oauth url
1 parent 43e207d commit 412c6c9

File tree

4 files changed

+171
-6
lines changed

4 files changed

+171
-6
lines changed

pkg/http/server.go

Lines changed: 4 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,6 @@ import (
66
"io"
77
"log/slog"
88
"net/http"
9-
"net/url"
109
"os"
1110
"os/signal"
1211
"slices"
@@ -129,12 +128,11 @@ func RunHTTPServer(cfg ServerConfig) error {
129128
ResourcePath: cfg.ResourcePath,
130129
}
131130
if cfg.Host != "" {
132-
u := &url.URL{
133-
Scheme: "https",
134-
Host: cfg.Host,
135-
Path: "/login/oauth",
131+
oauthURL, err := apiHost.OAuthURL(ctx)
132+
if err != nil {
133+
return fmt.Errorf("failed to get OAuth URL: %w", err)
136134
}
137-
oauthCfg.AuthorizationServer = u.String()
135+
oauthCfg.AuthorizationServer = oauthURL.String()
138136
}
139137

140138
serverOptions := []HandlerOption{}

pkg/scopes/fetcher_test.go

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,9 @@ func (t testAPIHostResolver) UploadURL(_ context.Context) (*url.URL, error) {
2828
func (t testAPIHostResolver) RawURL(_ context.Context) (*url.URL, error) {
2929
return nil, nil
3030
}
31+
func (t testAPIHostResolver) OAuthURL(_ context.Context) (*url.URL, error) {
32+
return nil, nil
33+
}
3134

3235
func TestParseScopeHeader(t *testing.T) {
3336
tests := []struct {

pkg/utils/api.go

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,13 +14,15 @@ type APIHostResolver interface {
1414
GraphqlURL(ctx context.Context) (*url.URL, error)
1515
UploadURL(ctx context.Context) (*url.URL, error)
1616
RawURL(ctx context.Context) (*url.URL, error)
17+
OAuthURL(ctx context.Context) (*url.URL, error)
1718
}
1819

1920
type APIHost struct {
2021
restURL *url.URL
2122
gqlURL *url.URL
2223
uploadURL *url.URL
2324
rawURL *url.URL
25+
oauthURL *url.URL
2426
}
2527

2628
var _ APIHostResolver = APIHost{}
@@ -52,6 +54,10 @@ func (a APIHost) RawURL(_ context.Context) (*url.URL, error) {
5254
return a.rawURL, nil
5355
}
5456

57+
func (a APIHost) OAuthURL(_ context.Context) (*url.URL, error) {
58+
return a.oauthURL, nil
59+
}
60+
5561
func newDotcomHost() (APIHost, error) {
5662
baseRestURL, err := url.Parse("https://api.github.com/")
5763
if err != nil {
@@ -73,11 +79,17 @@ func newDotcomHost() (APIHost, error) {
7379
return APIHost{}, fmt.Errorf("failed to parse dotcom Raw URL: %w", err)
7480
}
7581

82+
oauthURL, err := url.Parse("https://github.com/login/oauth")
83+
if err != nil {
84+
return APIHost{}, fmt.Errorf("failed to parse dotcom OAuth URL: %w", err)
85+
}
86+
7687
return APIHost{
7788
restURL: baseRestURL,
7889
gqlURL: gqlURL,
7990
uploadURL: uploadURL,
8091
rawURL: rawURL,
92+
oauthURL: oauthURL,
8193
}, nil
8294
}
8395

@@ -112,11 +124,17 @@ func newGHECHost(hostname string) (APIHost, error) {
112124
return APIHost{}, fmt.Errorf("failed to parse GHEC Raw URL: %w", err)
113125
}
114126

127+
oauthURL, err := url.Parse(fmt.Sprintf("https://%s/login/oauth", u.Hostname()))
128+
if err != nil {
129+
return APIHost{}, fmt.Errorf("failed to parse GHEC OAuth URL: %w", err)
130+
}
131+
115132
return APIHost{
116133
restURL: restURL,
117134
gqlURL: gqlURL,
118135
uploadURL: uploadURL,
119136
rawURL: rawURL,
137+
oauthURL: oauthURL,
120138
}, nil
121139
}
122140

@@ -164,11 +182,17 @@ func newGHESHost(hostname string) (APIHost, error) {
164182
return APIHost{}, fmt.Errorf("failed to parse GHES Raw URL: %w", err)
165183
}
166184

185+
oauthURL, err := url.Parse(fmt.Sprintf("%s://%s/login/oauth", u.Scheme, u.Hostname()))
186+
if err != nil {
187+
return APIHost{}, fmt.Errorf("failed to parse GHES OAuth URL: %w", err)
188+
}
189+
167190
return APIHost{
168191
restURL: restURL,
169192
gqlURL: gqlURL,
170193
uploadURL: uploadURL,
171194
rawURL: rawURL,
195+
oauthURL: oauthURL,
172196
}, nil
173197
}
174198

pkg/utils/api_test.go

Lines changed: 140 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,140 @@
1+
package utils //nolint:revive //TODO: figure out a better name for this package
2+
3+
import (
4+
"context"
5+
"testing"
6+
7+
"github.com/stretchr/testify/assert"
8+
"github.com/stretchr/testify/require"
9+
)
10+
11+
func TestOAuthURL(t *testing.T) {
12+
ctx := context.Background()
13+
14+
tests := []struct {
15+
name string
16+
host string
17+
expectedOAuth string
18+
expectError bool
19+
errorSubstring string
20+
}{
21+
{
22+
name: "dotcom (empty host)",
23+
host: "",
24+
expectedOAuth: "https://github.com/login/oauth",
25+
},
26+
{
27+
name: "dotcom (explicit github.com)",
28+
host: "https://github.com",
29+
expectedOAuth: "https://github.com/login/oauth",
30+
},
31+
{
32+
name: "GHEC with HTTPS",
33+
host: "https://acme.ghe.com",
34+
expectedOAuth: "https://acme.ghe.com/login/oauth",
35+
},
36+
{
37+
name: "GHEC with HTTP (should error)",
38+
host: "http://acme.ghe.com",
39+
expectError: true,
40+
errorSubstring: "GHEC URL must be HTTPS",
41+
},
42+
{
43+
name: "GHES with HTTPS",
44+
host: "https://ghes.example.com",
45+
expectedOAuth: "https://ghes.example.com/login/oauth",
46+
},
47+
{
48+
name: "GHES with HTTP",
49+
host: "http://ghes.example.com",
50+
expectedOAuth: "http://ghes.example.com/login/oauth",
51+
},
52+
{
53+
name: "GHES with HTTP and custom port (port stripped - not supported yet)",
54+
host: "http://ghes.local:8080",
55+
expectedOAuth: "http://ghes.local/login/oauth", // Port is stripped ref: ln222 api.go comment
56+
},
57+
{
58+
name: "GHES with HTTPS and custom port (port stripped - not supported yet)",
59+
host: "https://ghes.local:8443",
60+
expectedOAuth: "https://ghes.local/login/oauth", // Port is stripped ref: ln222 api.go comment
61+
},
62+
{
63+
name: "host without scheme (should error)",
64+
host: "ghes.example.com",
65+
expectError: true,
66+
errorSubstring: "host must have a scheme",
67+
},
68+
}
69+
70+
for _, tt := range tests {
71+
t.Run(tt.name, func(t *testing.T) {
72+
apiHost, err := NewAPIHost(tt.host)
73+
74+
if tt.expectError {
75+
require.Error(t, err)
76+
if tt.errorSubstring != "" {
77+
assert.Contains(t, err.Error(), tt.errorSubstring)
78+
}
79+
return
80+
}
81+
82+
require.NoError(t, err)
83+
require.NotNil(t, apiHost)
84+
85+
oauthURL, err := apiHost.OAuthURL(ctx)
86+
require.NoError(t, err)
87+
require.NotNil(t, oauthURL)
88+
89+
assert.Equal(t, tt.expectedOAuth, oauthURL.String())
90+
})
91+
}
92+
}
93+
94+
func TestAPIHost_AllURLsHaveConsistentScheme(t *testing.T) {
95+
ctx := context.Background()
96+
97+
tests := []struct {
98+
name string
99+
host string
100+
expectedScheme string
101+
}{
102+
{
103+
name: "GHES with HTTPS",
104+
host: "https://ghes.example.com",
105+
expectedScheme: "https",
106+
},
107+
{
108+
name: "GHES with HTTP",
109+
host: "http://ghes.example.com",
110+
expectedScheme: "http",
111+
},
112+
}
113+
114+
for _, tt := range tests {
115+
t.Run(tt.name, func(t *testing.T) {
116+
apiHost, err := NewAPIHost(tt.host)
117+
require.NoError(t, err)
118+
119+
restURL, err := apiHost.BaseRESTURL(ctx)
120+
require.NoError(t, err)
121+
assert.Equal(t, tt.expectedScheme, restURL.Scheme, "REST URL scheme should match")
122+
123+
gqlURL, err := apiHost.GraphqlURL(ctx)
124+
require.NoError(t, err)
125+
assert.Equal(t, tt.expectedScheme, gqlURL.Scheme, "GraphQL URL scheme should match")
126+
127+
uploadURL, err := apiHost.UploadURL(ctx)
128+
require.NoError(t, err)
129+
assert.Equal(t, tt.expectedScheme, uploadURL.Scheme, "Upload URL scheme should match")
130+
131+
rawURL, err := apiHost.RawURL(ctx)
132+
require.NoError(t, err)
133+
assert.Equal(t, tt.expectedScheme, rawURL.Scheme, "Raw URL scheme should match")
134+
135+
oauthURL, err := apiHost.OAuthURL(ctx)
136+
require.NoError(t, err)
137+
assert.Equal(t, tt.expectedScheme, oauthURL.Scheme, "OAuth URL scheme should match")
138+
})
139+
}
140+
}

0 commit comments

Comments
 (0)