Skip to content

Commit fbc5ba1

Browse files
committed
chore: oauth accepts APIHostResolver as argument
1 parent 412c6c9 commit fbc5ba1

File tree

3 files changed

+73
-9
lines changed

3 files changed

+73
-9
lines changed

pkg/http/oauth/oauth.go

Lines changed: 16 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,11 +3,13 @@
33
package oauth
44

55
import (
6+
"context"
67
"fmt"
78
"net/http"
89
"strings"
910

1011
"github.com/github/github-mcp-server/pkg/http/headers"
12+
"github.com/github/github-mcp-server/pkg/utils"
1113
"github.com/go-chi/chi/v5"
1214
"github.com/modelcontextprotocol/go-sdk/auth"
1315
"github.com/modelcontextprotocol/go-sdk/oauthex"
@@ -43,8 +45,13 @@ type Config struct {
4345
// This is used to construct the OAuth resource URL.
4446
BaseURL string
4547

48+
// APIHost is the GitHub API host resolver that provides OAuth URL.
49+
// If set, this takes precedence over AuthorizationServer.
50+
APIHost utils.APIHostResolver
51+
4652
// AuthorizationServer is the OAuth authorization server URL.
4753
// Defaults to GitHub's OAuth server if not specified.
54+
// This field is ignored if APIHost is set.
4855
AuthorizationServer string
4956

5057
// ResourcePath is the externally visible base path for the MCP server (e.g., "/mcp").
@@ -64,8 +71,15 @@ func NewAuthHandler(cfg *Config) (*AuthHandler, error) {
6471
cfg = &Config{}
6572
}
6673

67-
// Default authorization server to GitHub
68-
if cfg.AuthorizationServer == "" {
74+
// Resolve authorization server from APIHost if provided
75+
if cfg.APIHost != nil {
76+
oauthURL, err := cfg.APIHost.OAuthURL(context.Background())
77+
if err != nil {
78+
return nil, fmt.Errorf("failed to get OAuth URL from API host: %w", err)
79+
}
80+
cfg.AuthorizationServer = oauthURL.String()
81+
} else if cfg.AuthorizationServer == "" {
82+
// Default authorization server to GitHub if not provided
6983
cfg.AuthorizationServer = DefaultAuthorizationServer
7084
}
7185

pkg/http/oauth/oauth_test.go

Lines changed: 56 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,18 +1,49 @@
11
package oauth
22

33
import (
4+
"context"
45
"crypto/tls"
56
"encoding/json"
67
"net/http"
78
"net/http/httptest"
9+
"net/url"
810
"testing"
911

1012
"github.com/github/github-mcp-server/pkg/http/headers"
13+
"github.com/github/github-mcp-server/pkg/utils"
1114
"github.com/go-chi/chi/v5"
1215
"github.com/stretchr/testify/assert"
1316
"github.com/stretchr/testify/require"
1417
)
1518

19+
// mockAPIHostResolver is a test implementation of utils.APIHostResolver
20+
type mockAPIHostResolver struct {
21+
oauthURL string
22+
}
23+
24+
func (m mockAPIHostResolver) BaseRESTURL(_ context.Context) (*url.URL, error) {
25+
return nil, nil
26+
}
27+
28+
func (m mockAPIHostResolver) GraphqlURL(_ context.Context) (*url.URL, error) {
29+
return nil, nil
30+
}
31+
32+
func (m mockAPIHostResolver) UploadURL(_ context.Context) (*url.URL, error) {
33+
return nil, nil
34+
}
35+
36+
func (m mockAPIHostResolver) RawURL(_ context.Context) (*url.URL, error) {
37+
return nil, nil
38+
}
39+
40+
func (m mockAPIHostResolver) OAuthURL(_ context.Context) (*url.URL, error) {
41+
return url.Parse(m.oauthURL)
42+
}
43+
44+
// Ensure mockAPIHostResolver implements utils.APIHostResolver
45+
var _ utils.APIHostResolver = mockAPIHostResolver{}
46+
1647
func TestNewAuthHandler(t *testing.T) {
1748
t.Parallel()
1849

@@ -51,6 +82,31 @@ func TestNewAuthHandler(t *testing.T) {
5182
expectedAuthServer: DefaultAuthorizationServer,
5283
expectedResourcePath: "/mcp",
5384
},
85+
{
86+
name: "APIHost with HTTPS GHES",
87+
cfg: &Config{
88+
APIHost: mockAPIHostResolver{oauthURL: "https://ghes.example.com/login/oauth"},
89+
},
90+
expectedAuthServer: "https://ghes.example.com/login/oauth",
91+
expectedResourcePath: "",
92+
},
93+
{
94+
name: "APIHost with HTTP GHES",
95+
cfg: &Config{
96+
APIHost: mockAPIHostResolver{oauthURL: "http://ghes.local/login/oauth"},
97+
},
98+
expectedAuthServer: "http://ghes.local/login/oauth",
99+
expectedResourcePath: "",
100+
},
101+
{
102+
name: "APIHost takes precedence over AuthorizationServer",
103+
cfg: &Config{
104+
APIHost: mockAPIHostResolver{oauthURL: "https://ghes.example.com/login/oauth"},
105+
AuthorizationServer: "https://should-be-ignored.example.com/oauth",
106+
},
107+
expectedAuthServer: "https://ghes.example.com/login/oauth",
108+
expectedResourcePath: "",
109+
},
54110
}
55111

56112
for _, tc := range tests {

pkg/http/server.go

Lines changed: 1 addition & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -126,13 +126,7 @@ func RunHTTPServer(cfg ServerConfig) error {
126126
oauthCfg := &oauth.Config{
127127
BaseURL: cfg.BaseURL,
128128
ResourcePath: cfg.ResourcePath,
129-
}
130-
if cfg.Host != "" {
131-
oauthURL, err := apiHost.OAuthURL(ctx)
132-
if err != nil {
133-
return fmt.Errorf("failed to get OAuth URL: %w", err)
134-
}
135-
oauthCfg.AuthorizationServer = oauthURL.String()
129+
APIHost: apiHost,
136130
}
137131

138132
serverOptions := []HandlerOption{}

0 commit comments

Comments
 (0)