@@ -8,32 +8,28 @@ import (
88 "testing"
99
1010 "github.com/github/github-mcp-server/pkg/http/headers"
11+ "github.com/github/github-mcp-server/pkg/utils"
1112 "github.com/go-chi/chi/v5"
1213 "github.com/stretchr/testify/assert"
1314 "github.com/stretchr/testify/require"
1415)
1516
17+ var (
18+ defaultAuthorizationServer = "https://github.com/login/oauth"
19+ )
20+
1621func TestNewAuthHandler (t * testing.T ) {
1722 t .Parallel ()
1823
24+ dotcomHost , err := utils .NewAPIHost ("https://api.github.com" )
25+ require .NoError (t , err )
26+
1927 tests := []struct {
2028 name string
2129 cfg * Config
2230 expectedAuthServer string
2331 expectedResourcePath string
2432 }{
25- {
26- name : "nil config uses defaults" ,
27- cfg : nil ,
28- expectedAuthServer : DefaultAuthorizationServer ,
29- expectedResourcePath : "" ,
30- },
31- {
32- name : "empty config uses defaults" ,
33- cfg : & Config {},
34- expectedAuthServer : DefaultAuthorizationServer ,
35- expectedResourcePath : "" ,
36- },
3733 {
3834 name : "custom authorization server" ,
3935 cfg : & Config {
@@ -48,7 +44,7 @@ func TestNewAuthHandler(t *testing.T) {
4844 BaseURL : "https://example.com" ,
4945 ResourcePath : "/mcp" ,
5046 },
51- expectedAuthServer : DefaultAuthorizationServer ,
47+ expectedAuthServer : "" ,
5248 expectedResourcePath : "/mcp" ,
5349 },
5450 }
@@ -57,11 +53,12 @@ func TestNewAuthHandler(t *testing.T) {
5753 t .Run (tc .name , func (t * testing.T ) {
5854 t .Parallel ()
5955
60- handler , err := NewAuthHandler (tc .cfg )
56+ handler , err := NewAuthHandler (tc .cfg , dotcomHost )
6157 require .NoError (t , err )
6258 require .NotNil (t , handler )
6359
6460 assert .Equal (t , tc .expectedAuthServer , handler .cfg .AuthorizationServer )
61+ assert .Equal (t , tc .expectedResourcePath , handler .cfg .ResourcePath )
6562 })
6663 }
6764}
@@ -372,7 +369,7 @@ func TestHandleProtectedResource(t *testing.T) {
372369 authServers , ok := body ["authorization_servers" ].([]any )
373370 require .True (t , ok )
374371 require .Len (t , authServers , 1 )
375- assert .Equal (t , DefaultAuthorizationServer , authServers [0 ])
372+ assert .Equal (t , defaultAuthorizationServer , authServers [0 ])
376373 },
377374 },
378375 {
@@ -451,7 +448,10 @@ func TestHandleProtectedResource(t *testing.T) {
451448 t .Run (tc .name , func (t * testing.T ) {
452449 t .Parallel ()
453450
454- handler , err := NewAuthHandler (tc .cfg )
451+ dotcomHost , err := utils .NewAPIHost ("https://api.github.com" )
452+ require .NoError (t , err )
453+
454+ handler , err := NewAuthHandler (tc .cfg , dotcomHost )
455455 require .NoError (t , err )
456456
457457 router := chi .NewRouter ()
@@ -493,9 +493,12 @@ func TestHandleProtectedResource(t *testing.T) {
493493func TestRegisterRoutes (t * testing.T ) {
494494 t .Parallel ()
495495
496+ dotcomHost , err := utils .NewAPIHost ("https://api.github.com" )
497+ require .NoError (t , err )
498+
496499 handler , err := NewAuthHandler (& Config {
497500 BaseURL : "https://api.example.com" ,
498- })
501+ }, dotcomHost )
499502 require .NoError (t , err )
500503
501504 router := chi .NewRouter ()
@@ -559,9 +562,12 @@ func TestSupportedScopes(t *testing.T) {
559562func TestProtectedResourceResponseFormat (t * testing.T ) {
560563 t .Parallel ()
561564
565+ dotcomHost , err := utils .NewAPIHost ("https://api.github.com" )
566+ require .NoError (t , err )
567+
562568 handler , err := NewAuthHandler (& Config {
563569 BaseURL : "https://api.example.com" ,
564- })
570+ }, dotcomHost )
565571 require .NoError (t , err )
566572
567573 router := chi .NewRouter ()
@@ -598,7 +604,7 @@ func TestProtectedResourceResponseFormat(t *testing.T) {
598604 authServers , ok := response ["authorization_servers" ].([]any )
599605 require .True (t , ok )
600606 assert .Len (t , authServers , 1 )
601- assert .Equal (t , DefaultAuthorizationServer , authServers [0 ])
607+ assert .Equal (t , defaultAuthorizationServer , authServers [0 ])
602608}
603609
604610func TestOAuthProtectedResourcePrefix (t * testing.T ) {
@@ -611,5 +617,121 @@ func TestOAuthProtectedResourcePrefix(t *testing.T) {
611617func TestDefaultAuthorizationServer (t * testing.T ) {
612618 t .Parallel ()
613619
614- assert .Equal (t , "https://github.com/login/oauth" , DefaultAuthorizationServer )
620+ assert .Equal (t , "https://github.com/login/oauth" , defaultAuthorizationServer )
621+ }
622+
623+ func TestAPIHostResolver_AuthorizationServerURL (t * testing.T ) {
624+ t .Parallel ()
625+
626+ tests := []struct {
627+ name string
628+ host string
629+ oauthConfig * Config
630+ expectedURL string
631+ expectedError bool
632+ expectedStatusCode int
633+ errorContains string
634+ }{
635+ {
636+ name : "valid host returns authorization server URL" ,
637+ host : "https://github.com" ,
638+ expectedURL : "https://github.com/login/oauth" ,
639+ expectedStatusCode : http .StatusOK ,
640+ },
641+ {
642+ name : "invalid host returns error" ,
643+ host : "://invalid-url" ,
644+ expectedURL : "" ,
645+ expectedError : true ,
646+ errorContains : "could not parse host as URL" ,
647+ },
648+ {
649+ name : "host without scheme returns error" ,
650+ host : "github.com" ,
651+ expectedURL : "" ,
652+ expectedError : true ,
653+ errorContains : "host must have a scheme" ,
654+ },
655+ {
656+ name : "GHEC host returns correct authorization server URL" ,
657+ host : "https://test.ghe.com" ,
658+ expectedURL : "https://test.ghe.com/login/oauth" ,
659+ expectedStatusCode : http .StatusOK ,
660+ },
661+ {
662+ name : "GHES host returns correct authorization server URL" ,
663+ host : "https://ghe.example.com" ,
664+ expectedURL : "https://ghe.example.com/login/oauth" ,
665+ expectedStatusCode : http .StatusOK ,
666+ },
667+ {
668+ name : "GHES with http scheme returns the correct authorization server URL" ,
669+ host : "http://ghe.example.com" ,
670+ expectedURL : "http://ghe.example.com/login/oauth" ,
671+ expectedStatusCode : http .StatusOK ,
672+ },
673+ {
674+ name : "custom authorization server in config takes precedence" ,
675+ host : "https://github.com" ,
676+ oauthConfig : & Config {
677+ AuthorizationServer : "https://custom.auth.example.com/oauth" ,
678+ },
679+ expectedURL : "https://custom.auth.example.com/oauth" ,
680+ expectedStatusCode : http .StatusOK ,
681+ },
682+ }
683+
684+ for _ , tc := range tests {
685+ t .Run (tc .name , func (t * testing.T ) {
686+ t .Parallel ()
687+
688+ apiHost , err := utils .NewAPIHost (tc .host )
689+ if tc .expectedError {
690+ require .Error (t , err )
691+ if tc .errorContains != "" {
692+ assert .Contains (t , err .Error (), tc .errorContains )
693+ }
694+ return
695+ }
696+ require .NoError (t , err )
697+
698+ config := tc .oauthConfig
699+ if config == nil {
700+ config = & Config {}
701+ }
702+ config .BaseURL = tc .host
703+
704+ handler , err := NewAuthHandler (config , apiHost )
705+ require .NoError (t , err )
706+
707+ router := chi .NewRouter ()
708+ handler .RegisterRoutes (router )
709+
710+ req := httptest .NewRequest (http .MethodGet , OAuthProtectedResourcePrefix , nil )
711+ req .Host = "api.example.com"
712+
713+ rec := httptest .NewRecorder ()
714+ router .ServeHTTP (rec , req )
715+
716+ require .Equal (t , http .StatusOK , rec .Code )
717+
718+ var response map [string ]any
719+ err = json .Unmarshal (rec .Body .Bytes (), & response )
720+ require .NoError (t , err )
721+
722+ assert .Contains (t , response , "authorization_servers" )
723+ if tc .expectedStatusCode != http .StatusOK {
724+ require .Equal (t , tc .expectedStatusCode , rec .Code )
725+ if tc .errorContains != "" {
726+ assert .Contains (t , rec .Body .String (), tc .errorContains )
727+ }
728+ return
729+ }
730+
731+ responseAuthServers , ok := response ["authorization_servers" ].([]any )
732+ require .True (t , ok )
733+ require .Len (t , responseAuthServers , 1 )
734+ assert .Equal (t , tc .expectedURL , responseAuthServers [0 ])
735+ })
736+ }
615737}
0 commit comments