Skip to content

Commit 48a2a05

Browse files
atharva1051omgitsadsCopilot
authored
Use configured --gh-host as oauth authorization server (#2046)
When configured with a `--gh-host` argument, construct the OAuth Authorization Server URL from this host, rather than defaulting to `https://github.com/login/oauth` Co-authored-by: Adam Holt < 4619+omgitsads@users.noreply.github.com> Co-authored-by: atharva1051 <53966412+atharva1051@users.noreply.github.com> Co-authored-by: copilot-swe-agent[bot] <198982749+Copilot@users.noreply.github.com>
1 parent cdf84bc commit 48a2a05

File tree

5 files changed

+213
-47
lines changed

5 files changed

+213
-47
lines changed

pkg/http/oauth/oauth.go

Lines changed: 26 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@ import (
88
"strings"
99

1010
"github.com/github/github-mcp-server/pkg/http/headers"
11+
"github.com/github/github-mcp-server/pkg/utils"
1112
"github.com/go-chi/chi/v5"
1213
"github.com/modelcontextprotocol/go-sdk/auth"
1314
"github.com/modelcontextprotocol/go-sdk/oauthex"
@@ -16,9 +17,6 @@ import (
1617
const (
1718
// OAuthProtectedResourcePrefix is the well-known path prefix for OAuth protected resource metadata.
1819
OAuthProtectedResourcePrefix = "/.well-known/oauth-protected-resource"
19-
20-
// DefaultAuthorizationServer is GitHub's OAuth authorization server.
21-
DefaultAuthorizationServer = "https://github.com/login/oauth"
2220
)
2321

2422
// SupportedScopes lists all OAuth scopes that may be required by MCP tools.
@@ -55,22 +53,27 @@ type Config struct {
5553

5654
// AuthHandler handles OAuth-related HTTP endpoints.
5755
type AuthHandler struct {
58-
cfg *Config
56+
cfg *Config
57+
apiHost utils.APIHostResolver
5958
}
6059

6160
// NewAuthHandler creates a new OAuth auth handler.
62-
func NewAuthHandler(cfg *Config) (*AuthHandler, error) {
61+
func NewAuthHandler(cfg *Config, apiHost utils.APIHostResolver) (*AuthHandler, error) {
6362
if cfg == nil {
6463
cfg = &Config{}
6564
}
6665

67-
// Default authorization server to GitHub
68-
if cfg.AuthorizationServer == "" {
69-
cfg.AuthorizationServer = DefaultAuthorizationServer
66+
if apiHost == nil {
67+
var err error
68+
apiHost, err = utils.NewAPIHost("https://api.github.com")
69+
if err != nil {
70+
return nil, fmt.Errorf("failed to create default API host: %w", err)
71+
}
7072
}
7173

7274
return &AuthHandler{
73-
cfg: cfg,
75+
cfg: cfg,
76+
apiHost: apiHost,
7477
}, nil
7578
}
7679

@@ -95,15 +98,28 @@ func (h *AuthHandler) RegisterRoutes(r chi.Router) {
9598

9699
func (h *AuthHandler) metadataHandler() http.Handler {
97100
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
101+
ctx := r.Context()
98102
resourcePath := resolveResourcePath(
99103
strings.TrimPrefix(r.URL.Path, OAuthProtectedResourcePrefix),
100104
h.cfg.ResourcePath,
101105
)
102106
resourceURL := h.buildResourceURL(r, resourcePath)
103107

108+
var authorizationServerURL string
109+
if h.cfg.AuthorizationServer != "" {
110+
authorizationServerURL = h.cfg.AuthorizationServer
111+
} else {
112+
authURL, err := h.apiHost.AuthorizationServerURL(ctx)
113+
if err != nil {
114+
http.Error(w, fmt.Sprintf("failed to resolve authorization server URL: %v", err), http.StatusInternalServerError)
115+
return
116+
}
117+
authorizationServerURL = authURL.String()
118+
}
119+
104120
metadata := &oauthex.ProtectedResourceMetadata{
105121
Resource: resourceURL,
106-
AuthorizationServers: []string{h.cfg.AuthorizationServer},
122+
AuthorizationServers: []string{authorizationServerURL},
107123
ResourceName: "GitHub MCP Server",
108124
ScopesSupported: SupportedScopes,
109125
BearerMethodsSupported: []string{"header"},

pkg/http/oauth/oauth_test.go

Lines changed: 142 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -8,32 +8,28 @@ import (
88
"testing"
99

1010
"github.com/github/github-mcp-server/pkg/http/headers"
11+
"github.com/github/github-mcp-server/pkg/utils"
1112
"github.com/go-chi/chi/v5"
1213
"github.com/stretchr/testify/assert"
1314
"github.com/stretchr/testify/require"
1415
)
1516

17+
var (
18+
defaultAuthorizationServer = "https://github.com/login/oauth"
19+
)
20+
1621
func TestNewAuthHandler(t *testing.T) {
1722
t.Parallel()
1823

24+
dotcomHost, err := utils.NewAPIHost("https://api.github.com")
25+
require.NoError(t, err)
26+
1927
tests := []struct {
2028
name string
2129
cfg *Config
2230
expectedAuthServer string
2331
expectedResourcePath string
2432
}{
25-
{
26-
name: "nil config uses defaults",
27-
cfg: nil,
28-
expectedAuthServer: DefaultAuthorizationServer,
29-
expectedResourcePath: "",
30-
},
31-
{
32-
name: "empty config uses defaults",
33-
cfg: &Config{},
34-
expectedAuthServer: DefaultAuthorizationServer,
35-
expectedResourcePath: "",
36-
},
3733
{
3834
name: "custom authorization server",
3935
cfg: &Config{
@@ -48,7 +44,7 @@ func TestNewAuthHandler(t *testing.T) {
4844
BaseURL: "https://example.com",
4945
ResourcePath: "/mcp",
5046
},
51-
expectedAuthServer: DefaultAuthorizationServer,
47+
expectedAuthServer: "",
5248
expectedResourcePath: "/mcp",
5349
},
5450
}
@@ -57,11 +53,12 @@ func TestNewAuthHandler(t *testing.T) {
5753
t.Run(tc.name, func(t *testing.T) {
5854
t.Parallel()
5955

60-
handler, err := NewAuthHandler(tc.cfg)
56+
handler, err := NewAuthHandler(tc.cfg, dotcomHost)
6157
require.NoError(t, err)
6258
require.NotNil(t, handler)
6359

6460
assert.Equal(t, tc.expectedAuthServer, handler.cfg.AuthorizationServer)
61+
assert.Equal(t, tc.expectedResourcePath, handler.cfg.ResourcePath)
6562
})
6663
}
6764
}
@@ -372,7 +369,7 @@ func TestHandleProtectedResource(t *testing.T) {
372369
authServers, ok := body["authorization_servers"].([]any)
373370
require.True(t, ok)
374371
require.Len(t, authServers, 1)
375-
assert.Equal(t, DefaultAuthorizationServer, authServers[0])
372+
assert.Equal(t, defaultAuthorizationServer, authServers[0])
376373
},
377374
},
378375
{
@@ -451,7 +448,10 @@ func TestHandleProtectedResource(t *testing.T) {
451448
t.Run(tc.name, func(t *testing.T) {
452449
t.Parallel()
453450

454-
handler, err := NewAuthHandler(tc.cfg)
451+
dotcomHost, err := utils.NewAPIHost("https://api.github.com")
452+
require.NoError(t, err)
453+
454+
handler, err := NewAuthHandler(tc.cfg, dotcomHost)
455455
require.NoError(t, err)
456456

457457
router := chi.NewRouter()
@@ -493,9 +493,12 @@ func TestHandleProtectedResource(t *testing.T) {
493493
func TestRegisterRoutes(t *testing.T) {
494494
t.Parallel()
495495

496+
dotcomHost, err := utils.NewAPIHost("https://api.github.com")
497+
require.NoError(t, err)
498+
496499
handler, err := NewAuthHandler(&Config{
497500
BaseURL: "https://api.example.com",
498-
})
501+
}, dotcomHost)
499502
require.NoError(t, err)
500503

501504
router := chi.NewRouter()
@@ -559,9 +562,12 @@ func TestSupportedScopes(t *testing.T) {
559562
func TestProtectedResourceResponseFormat(t *testing.T) {
560563
t.Parallel()
561564

565+
dotcomHost, err := utils.NewAPIHost("https://api.github.com")
566+
require.NoError(t, err)
567+
562568
handler, err := NewAuthHandler(&Config{
563569
BaseURL: "https://api.example.com",
564-
})
570+
}, dotcomHost)
565571
require.NoError(t, err)
566572

567573
router := chi.NewRouter()
@@ -598,7 +604,7 @@ func TestProtectedResourceResponseFormat(t *testing.T) {
598604
authServers, ok := response["authorization_servers"].([]any)
599605
require.True(t, ok)
600606
assert.Len(t, authServers, 1)
601-
assert.Equal(t, DefaultAuthorizationServer, authServers[0])
607+
assert.Equal(t, defaultAuthorizationServer, authServers[0])
602608
}
603609

604610
func TestOAuthProtectedResourcePrefix(t *testing.T) {
@@ -611,5 +617,121 @@ func TestOAuthProtectedResourcePrefix(t *testing.T) {
611617
func TestDefaultAuthorizationServer(t *testing.T) {
612618
t.Parallel()
613619

614-
assert.Equal(t, "https://github.com/login/oauth", DefaultAuthorizationServer)
620+
assert.Equal(t, "https://github.com/login/oauth", defaultAuthorizationServer)
621+
}
622+
623+
func TestAPIHostResolver_AuthorizationServerURL(t *testing.T) {
624+
t.Parallel()
625+
626+
tests := []struct {
627+
name string
628+
host string
629+
oauthConfig *Config
630+
expectedURL string
631+
expectedError bool
632+
expectedStatusCode int
633+
errorContains string
634+
}{
635+
{
636+
name: "valid host returns authorization server URL",
637+
host: "https://github.com",
638+
expectedURL: "https://github.com/login/oauth",
639+
expectedStatusCode: http.StatusOK,
640+
},
641+
{
642+
name: "invalid host returns error",
643+
host: "://invalid-url",
644+
expectedURL: "",
645+
expectedError: true,
646+
errorContains: "could not parse host as URL",
647+
},
648+
{
649+
name: "host without scheme returns error",
650+
host: "github.com",
651+
expectedURL: "",
652+
expectedError: true,
653+
errorContains: "host must have a scheme",
654+
},
655+
{
656+
name: "GHEC host returns correct authorization server URL",
657+
host: "https://test.ghe.com",
658+
expectedURL: "https://test.ghe.com/login/oauth",
659+
expectedStatusCode: http.StatusOK,
660+
},
661+
{
662+
name: "GHES host returns correct authorization server URL",
663+
host: "https://ghe.example.com",
664+
expectedURL: "https://ghe.example.com/login/oauth",
665+
expectedStatusCode: http.StatusOK,
666+
},
667+
{
668+
name: "GHES with http scheme returns the correct authorization server URL",
669+
host: "http://ghe.example.com",
670+
expectedURL: "http://ghe.example.com/login/oauth",
671+
expectedStatusCode: http.StatusOK,
672+
},
673+
{
674+
name: "custom authorization server in config takes precedence",
675+
host: "https://github.com",
676+
oauthConfig: &Config{
677+
AuthorizationServer: "https://custom.auth.example.com/oauth",
678+
},
679+
expectedURL: "https://custom.auth.example.com/oauth",
680+
expectedStatusCode: http.StatusOK,
681+
},
682+
}
683+
684+
for _, tc := range tests {
685+
t.Run(tc.name, func(t *testing.T) {
686+
t.Parallel()
687+
688+
apiHost, err := utils.NewAPIHost(tc.host)
689+
if tc.expectedError {
690+
require.Error(t, err)
691+
if tc.errorContains != "" {
692+
assert.Contains(t, err.Error(), tc.errorContains)
693+
}
694+
return
695+
}
696+
require.NoError(t, err)
697+
698+
config := tc.oauthConfig
699+
if config == nil {
700+
config = &Config{}
701+
}
702+
config.BaseURL = tc.host
703+
704+
handler, err := NewAuthHandler(config, apiHost)
705+
require.NoError(t, err)
706+
707+
router := chi.NewRouter()
708+
handler.RegisterRoutes(router)
709+
710+
req := httptest.NewRequest(http.MethodGet, OAuthProtectedResourcePrefix, nil)
711+
req.Host = "api.example.com"
712+
713+
rec := httptest.NewRecorder()
714+
router.ServeHTTP(rec, req)
715+
716+
require.Equal(t, http.StatusOK, rec.Code)
717+
718+
var response map[string]any
719+
err = json.Unmarshal(rec.Body.Bytes(), &response)
720+
require.NoError(t, err)
721+
722+
assert.Contains(t, response, "authorization_servers")
723+
if tc.expectedStatusCode != http.StatusOK {
724+
require.Equal(t, tc.expectedStatusCode, rec.Code)
725+
if tc.errorContains != "" {
726+
assert.Contains(t, rec.Body.String(), tc.errorContains)
727+
}
728+
return
729+
}
730+
731+
responseAuthServers, ok := response["authorization_servers"].([]any)
732+
require.True(t, ok)
733+
require.Len(t, responseAuthServers, 1)
734+
assert.Equal(t, tc.expectedURL, responseAuthServers[0])
735+
})
736+
}
615737
}

pkg/http/server.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -136,7 +136,7 @@ func RunHTTPServer(cfg ServerConfig) error {
136136

137137
r := chi.NewRouter()
138138
handler := NewHTTPMcpHandler(ctx, &cfg, deps, t, logger, apiHost, append(serverOptions, WithFeatureChecker(featureChecker), WithOAuthConfig(oauthCfg))...)
139-
oauthHandler, err := oauth.NewAuthHandler(oauthCfg)
139+
oauthHandler, err := oauth.NewAuthHandler(oauthCfg, apiHost)
140140
if err != nil {
141141
return fmt.Errorf("failed to create OAuth handler: %w", err)
142142
}

pkg/scopes/fetcher_test.go

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,9 @@ func (t testAPIHostResolver) UploadURL(_ context.Context) (*url.URL, error) {
2828
func (t testAPIHostResolver) RawURL(_ context.Context) (*url.URL, error) {
2929
return nil, nil
3030
}
31+
func (t testAPIHostResolver) AuthorizationServerURL(_ context.Context) (*url.URL, error) {
32+
return nil, nil
33+
}
3134

3235
func TestParseScopeHeader(t *testing.T) {
3336
tests := []struct {

0 commit comments

Comments
 (0)