diff --git a/auth/client.go b/auth/client.go new file mode 100644 index 0000000..51daf33 --- /dev/null +++ b/auth/client.go @@ -0,0 +1,483 @@ +package auth + +import ( + "context" + "crypto/aes" + "crypto/cipher" + "crypto/rand" + "encoding/base64" + "encoding/json" + "errors" + "fmt" + "io" + "net/http" + "net/url" + "strings" + "sync" + "time" +) + +// OAuthClient handles OAuth flow with third-party providers +type OAuthClient struct { + config *ThirdPartyOAuthConfig + mcpTokenGen *TokenGenerator + store Store + stateStore map[string]*OAuthState + httpClient *http.Client + encryptionKey []byte + stateMutex sync.RWMutex + encryptMutex sync.RWMutex +} + +// OAuthState stores temporary OAuth flow state +type OAuthState struct { + State string + CodeVerifier string + CodeChallenge string + RedirectURI string + OriginalRequest string + CreatedAt time.Time + ExpiresAt time.Time + + // MCP Client parameters + MCPClientID string + MCPRedirectURI string + MCPScopes []string + MCPResources []string + MCPState string + MCPCodeChallenge string + MCPCodeChallengeMethod string +} + +// NewOAuthClient creates a new OAuth client for third-party auth +func NewOAuthClient( + config *ThirdPartyOAuthConfig, + mcpTokenGen *TokenGenerator, + store Store, +) (*OAuthClient, error) { + // Generate encryption key for sensitive data + encryptionKey := make([]byte, 32) + if _, err := rand.Read(encryptionKey); err != nil { + return nil, fmt.Errorf("failed to generate encryption key: %w", err) + } + + return &OAuthClient{ + config: config, + mcpTokenGen: mcpTokenGen, + store: store, + stateStore: make(map[string]*OAuthState), + httpClient: &http.Client{Timeout: 30 * time.Second}, + encryptionKey: encryptionKey, + }, nil +} + +// InitiateOAuthFlow handles the initial OAuth request from MCP Client +func (c *OAuthClient) InitiateOAuthFlow(w http.ResponseWriter, r *http.Request) { + // Extract MCP client parameters + mcpClientID := r.URL.Query().Get("client_id") + mcpRedirectURI := r.URL.Query().Get("redirect_uri") + mcpState := r.URL.Query().Get("state") + mcpScope := r.URL.Query().Get("scope") + mcpCodeChallenge := r.URL.Query().Get("code_challenge") + mcpCodeChallengeMethod := r.URL.Query().Get("code_challenge_method") + mcpResources := r.URL.Query()["resource"] + + // Validate required MCP parameters + if mcpClientID == "" { + http.Error(w, "Missing client_id parameter", http.StatusBadRequest) + return + } + if mcpRedirectURI == "" { + http.Error(w, "Missing redirect_uri parameter", http.StatusBadRequest) + return + } + + // Validate MCP client exists + client, err := c.store.GetClient(r.Context(), mcpClientID) + if err != nil { + http.Error(w, "Invalid client_id", http.StatusBadRequest) + return + } + + // Validate redirect URI (exact match per MCP spec) + validRedirect := false + for _, allowed := range client.RedirectURIs { + if mcpRedirectURI == allowed { + validRedirect = true + break + } + } + if !validRedirect { + http.Error(w, "Invalid redirect_uri", http.StatusBadRequest) + return + } + + // Parse scopes + mcpScopes := parseScopes(mcpScope) + + // Generate state for CSRF protection + state, err := generateRandomState() + if err != nil { + http.Error(w, "Failed to generate state", http.StatusInternalServerError) + return + } + + // Generate PKCE challenge for third-party OAuth if enabled + var codeVerifier, codeChallenge string + if c.config.UsePKCE { + codeVerifier, codeChallenge, err = GenerateCodeVerifierAndChallenge() + if err != nil { + http.Error(w, "Failed to generate PKCE", http.StatusInternalServerError) + return + } + } + + // Store state with MCP client parameters + oauthState := &OAuthState{ + State: state, + CodeVerifier: codeVerifier, + CodeChallenge: codeChallenge, + RedirectURI: c.config.RedirectURI, + CreatedAt: time.Now(), + ExpiresAt: time.Now().Add(10 * time.Minute), + + // MCP Client parameters + MCPClientID: mcpClientID, + MCPRedirectURI: mcpRedirectURI, + MCPScopes: mcpScopes, + MCPResources: mcpResources, + MCPState: mcpState, + MCPCodeChallenge: mcpCodeChallenge, + MCPCodeChallengeMethod: mcpCodeChallengeMethod, + } + + c.stateMutex.Lock() + c.stateStore[state] = oauthState + c.stateMutex.Unlock() + + // Build authorization URL for third-party provider + authURL := c.buildAuthorizationURL(state, codeChallenge) + + // Redirect user to third-party authorization page + http.Redirect(w, r, authURL, http.StatusFound) +} + +// HandleCallback handles the callback from third-party OAuth provider +func (c *OAuthClient) HandleCallback(w http.ResponseWriter, r *http.Request) { + // Extract authorization code and state + code := r.URL.Query().Get("code") + state := r.URL.Query().Get("state") + errorCode := r.URL.Query().Get("error") + + // Check for errors from third-party + if errorCode != "" { + errorDesc := r.URL.Query().Get("error_description") + http.Error(w, fmt.Sprintf("OAuth error: %s - %s", errorCode, errorDesc), http.StatusBadRequest) + return + } + + if code == "" || state == "" { + http.Error(w, "Missing code or state parameter", http.StatusBadRequest) + return + } + + // Validate state + c.stateMutex.RLock() + oauthState, ok := c.stateStore[state] + c.stateMutex.RUnlock() + + if !ok { + http.Error(w, "Invalid state parameter", http.StatusBadRequest) + return + } + + // Check state expiration + if time.Now().After(oauthState.ExpiresAt) { + c.stateMutex.Lock() + delete(c.stateStore, state) + c.stateMutex.Unlock() + http.Error(w, "State expired", http.StatusBadRequest) + return + } + + // Clean up used state + c.stateMutex.Lock() + delete(c.stateStore, state) + c.stateMutex.Unlock() + + // Exchange authorization code for access token from third-party + thirdPartyToken, err := c.exchangeCodeForToken(code, oauthState.CodeVerifier) + if err != nil { + http.Error(w, "Failed to exchange code", http.StatusInternalServerError) + return + } + + // Get user info from third-party + userInfo, err := c.getUserInfo(thirdPartyToken.AccessToken) + if err != nil { + http.Error(w, "Failed to get user info", http.StatusInternalServerError) + return + } + + // Generate MCP authorization code with proper parameters + mcpAuthCode, err := c.generateMCPAuthorizationCode(r.Context(), oauthState, userInfo, thirdPartyToken) + if err != nil { + http.Error(w, "Failed to complete authorization", http.StatusInternalServerError) + return + } + + // Redirect back to MCP client with MCP authorization code + redirectURL := c.buildMCPClientRedirect(oauthState.MCPRedirectURI, mcpAuthCode, oauthState.MCPState) + http.Redirect(w, r, redirectURL, http.StatusFound) +} + +// generateMCPAuthorizationCode generates a proper MCP authorization code +func (c *OAuthClient) generateMCPAuthorizationCode( + ctx context.Context, + oauthState *OAuthState, + userInfo *ThirdPartyUserInfo, + thirdPartyToken *ThirdPartyTokenResponse, +) (string, error) { + code, err := GenerateAuthorizationCode() + if err != nil { + return "", err + } + + // Set default PKCE method if challenge provided + challengeMethod := oauthState.MCPCodeChallengeMethod + if challengeMethod == "" && oauthState.MCPCodeChallenge != "" { + challengeMethod = CodeChallengeMethodPlain + } + + // Encrypt third-party tokens before storing + encryptedAccessToken, err := c.encryptToken(thirdPartyToken.AccessToken) + if err != nil { + return "", fmt.Errorf("failed to encrypt access token: %w", err) + } + + encryptedRefreshToken := "" + if thirdPartyToken.RefreshToken != "" { + encryptedRefreshToken, err = c.encryptToken(thirdPartyToken.RefreshToken) + if err != nil { + return "", fmt.Errorf("failed to encrypt refresh token: %w", err) + } + } + + // Create authorization code with proper MCP client binding + authCode := &AuthorizationCode{ + Code: code, + ClientID: oauthState.MCPClientID, + UserID: userInfo.Subject, + RedirectURI: oauthState.MCPRedirectURI, + Scopes: oauthState.MCPScopes, + Resources: oauthState.MCPResources, + CodeChallenge: oauthState.MCPCodeChallenge, + CodeChallengeMethod: challengeMethod, + ExpiresAt: time.Now().Add(5 * time.Minute), + CreatedAt: time.Now(), + Used: false, + Metadata: map[string]interface{}{ + "third_party_provider": c.config.ProviderName, + "third_party_token_encrypted": encryptedAccessToken, + "third_party_refresh_encrypted": encryptedRefreshToken, + "third_party_expires": time.Now().Add(time.Duration(thirdPartyToken.ExpiresIn) * time.Second).Unix(), + "user_email": userInfo.Email, + "user_name": userInfo.Name, + }, + } + + if err := c.store.SaveAuthorizationCode(ctx, authCode); err != nil { + return "", err + } + + return code, nil +} + +// encryptToken encrypts sensitive tokens using AES-GCM +func (c *OAuthClient) encryptToken(plaintext string) (string, error) { + if plaintext == "" { + return "", nil + } + + c.encryptMutex.RLock() + defer c.encryptMutex.RUnlock() + + block, err := aes.NewCipher(c.encryptionKey) + if err != nil { + return "", err + } + + gcm, err := cipher.NewGCM(block) + if err != nil { + return "", err + } + + nonce := make([]byte, gcm.NonceSize()) + if _, err := io.ReadFull(rand.Reader, nonce); err != nil { + return "", err + } + + ciphertext := gcm.Seal(nonce, nonce, []byte(plaintext), nil) + return base64.StdEncoding.EncodeToString(ciphertext), nil +} + +// decryptToken decrypts sensitive tokens +func (c *OAuthClient) decryptToken(encrypted string) (string, error) { + if encrypted == "" { + return "", nil + } + + c.encryptMutex.RLock() + defer c.encryptMutex.RUnlock() + + ciphertext, err := base64.StdEncoding.DecodeString(encrypted) + if err != nil { + return "", err + } + + block, err := aes.NewCipher(c.encryptionKey) + if err != nil { + return "", err + } + + gcm, err := cipher.NewGCM(block) + if err != nil { + return "", err + } + + if len(ciphertext) < gcm.NonceSize() { + return "", errors.New("ciphertext too short") + } + + nonce, ciphertext := ciphertext[:gcm.NonceSize()], ciphertext[gcm.NonceSize():] + plaintext, err := gcm.Open(nil, nonce, ciphertext, nil) + if err != nil { + return "", err + } + + return string(plaintext), nil +} + +// exchangeCodeForToken exchanges authorization code for access token +func (c *OAuthClient) exchangeCodeForToken(code, codeVerifier string) (*ThirdPartyTokenResponse, error) { + data := url.Values{} + data.Set("grant_type", "authorization_code") + data.Set("code", code) + data.Set("redirect_uri", c.config.RedirectURI) + data.Set("client_id", c.config.ClientID) + data.Set("client_secret", c.config.ClientSecret) + + if c.config.UsePKCE && codeVerifier != "" { + data.Set("code_verifier", codeVerifier) + } + + req, err := http.NewRequest(http.MethodPost, c.config.TokenURL, strings.NewReader(data.Encode())) + if err != nil { + return nil, err + } + + req.Header.Set("Content-Type", "application/x-www-form-urlencoded") + req.Header.Set("Accept", "application/json") + + resp, err := c.httpClient.Do(req) + if err != nil { + return nil, err + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK { + body, _ := io.ReadAll(resp.Body) + return nil, fmt.Errorf("token request failed: %d - %s", resp.StatusCode, string(body)) + } + + var tokenResp ThirdPartyTokenResponse + if err := json.NewDecoder(resp.Body).Decode(&tokenResp); err != nil { + return nil, err + } + + return &tokenResp, nil +} + +// getUserInfo fetches user information from third-party +func (c *OAuthClient) getUserInfo(accessToken string) (*ThirdPartyUserInfo, error) { + if c.config.UserInfoURL == "" { + return &ThirdPartyUserInfo{ + Subject: "unknown", + Claims: make(map[string]interface{}), + }, nil + } + + req, err := http.NewRequest(http.MethodGet, c.config.UserInfoURL, nil) + if err != nil { + return nil, err + } + + req.Header.Set("Authorization", "Bearer "+accessToken) + req.Header.Set("Accept", "application/json") + + resp, err := c.httpClient.Do(req) + if err != nil { + return nil, err + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK { + body, _ := io.ReadAll(resp.Body) + return nil, fmt.Errorf("userinfo request failed: %d - %s", resp.StatusCode, string(body)) + } + + var userInfo ThirdPartyUserInfo + body, err := io.ReadAll(resp.Body) + if err != nil { + return nil, err + } + + if err := json.Unmarshal(body, &userInfo); err != nil { + return nil, err + } + + var claims map[string]interface{} + if err := json.Unmarshal(body, &claims); err == nil { + userInfo.Claims = claims + } + + return &userInfo, nil +} + +// buildAuthorizationURL builds the third-party authorization URL +func (c *OAuthClient) buildAuthorizationURL(state, codeChallenge string) string { + params := url.Values{} + params.Set("response_type", "code") + params.Set("client_id", c.config.ClientID) + params.Set("redirect_uri", c.config.RedirectURI) + params.Set("state", state) + params.Set("scope", strings.Join(c.config.Scopes, " ")) + + if c.config.UsePKCE && codeChallenge != "" { + params.Set("code_challenge", codeChallenge) + params.Set("code_challenge_method", "S256") + } + + return c.config.AuthorizationURL + "?" + params.Encode() +} + +// buildMCPClientRedirect builds redirect URL back to MCP client +func (c *OAuthClient) buildMCPClientRedirect(mcpClientCallback, authCode, mcpState string) string { + redirectURL, _ := url.Parse(mcpClientCallback) + query := redirectURL.Query() + query.Set("code", authCode) + if mcpState != "" { + query.Set("state", mcpState) + } + redirectURL.RawQuery = query.Encode() + return redirectURL.String() +} + +// generateRandomState generates a cryptographically secure random state +func generateRandomState() (string, error) { + b := make([]byte, 32) + if _, err := rand.Read(b); err != nil { + return "", err + } + return base64.RawURLEncoding.EncodeToString(b), nil +} diff --git a/auth/dynamic_register.go b/auth/dynamic_register.go new file mode 100644 index 0000000..5f6e59a --- /dev/null +++ b/auth/dynamic_register.go @@ -0,0 +1,685 @@ +package auth + +import ( + "context" + "crypto/rand" + "encoding/base64" + "encoding/json" + "errors" + "fmt" + "net/http" + "strings" + "time" +) + +// RFC 7591: OAuth 2.0 Dynamic Client Registration Protocol + +// ClientRegistrationRequest represents a dynamic client registration request +type ClientRegistrationRequest struct { + // OPTIONAL: Array of redirection URIs + RedirectURIs []string `json:"redirect_uris,omitempty"` + + // OPTIONAL: Array of OAuth 2.0 grant types + GrantTypes []string `json:"grant_types,omitempty"` + + // OPTIONAL: Kind of application (e.g., "web", "native") + ApplicationType string `json:"application_type,omitempty"` + + // OPTIONAL: Array of email addresses + Contacts []string `json:"contacts,omitempty"` + + // OPTIONAL: Human-readable client name + ClientName string `json:"client_name,omitempty"` + + // OPTIONAL: URL of the client logo + LogoURI string `json:"logo_uri,omitempty"` + + // OPTIONAL: URL of the client's home page + ClientURI string `json:"client_uri,omitempty"` + + // OPTIONAL: URL of policy document + PolicyURI string `json:"policy_uri,omitempty"` + + // OPTIONAL: URL of terms of service + TosURI string `json:"tos_uri,omitempty"` + + // OPTIONAL: URL for the client's JSON Web Key Set + JWKSURI string `json:"jwks_uri,omitempty"` + + // OPTIONAL: Client's JSON Web Key Set + JWKS json.RawMessage `json:"jwks,omitempty"` + + // OPTIONAL: URL referencing the software statement + SoftwareID string `json:"software_id,omitempty"` + + // OPTIONAL: Version identifier for the software + SoftwareVersion string `json:"software_version,omitempty"` + + // OPTIONAL: Requested authentication method for the token endpoint + TokenEndpointAuthMethod string `json:"token_endpoint_auth_method,omitempty"` + + // OPTIONAL: Requested scope values + Scope string `json:"scope,omitempty"` + + // RFC 8707: Resource indicators + ResourceIndicators []string `json:"resource_indicators,omitempty"` + + // MCP-specific extensions + MCPVersion string `json:"mcp_version,omitempty"` + MCPTransportsSupported []string `json:"mcp_transports_supported,omitempty"` +} + +// ClientRegistrationResponse represents a successful registration response +type ClientRegistrationResponse struct { + // REQUIRED: Unique client identifier + ClientID string `json:"client_id"` + + // OPTIONAL: Client secret (only for confidential clients) + ClientSecret string `json:"client_secret,omitempty"` + + // OPTIONAL: Time at which the client secret will expire (Unix timestamp) + ClientSecretExpiresAt int64 `json:"client_secret_expires_at,omitempty"` + + // OPTIONAL: Registration access token for managing this client + RegistrationAccessToken string `json:"registration_access_token,omitempty"` + + // OPTIONAL: URL for managing this client registration + RegistrationClientURI string `json:"registration_client_uri,omitempty"` + + // Echo back the registered parameters + RedirectURIs []string `json:"redirect_uris,omitempty"` + GrantTypes []string `json:"grant_types,omitempty"` + ApplicationType string `json:"application_type,omitempty"` + Contacts []string `json:"contacts,omitempty"` + ClientName string `json:"client_name,omitempty"` + LogoURI string `json:"logo_uri,omitempty"` + ClientURI string `json:"client_uri,omitempty"` + PolicyURI string `json:"policy_uri,omitempty"` + TosURI string `json:"tos_uri,omitempty"` + JWKSURI string `json:"jwks_uri,omitempty"` + JWKS json.RawMessage `json:"jwks,omitempty"` + SoftwareID string `json:"software_id,omitempty"` + SoftwareVersion string `json:"software_version,omitempty"` + TokenEndpointAuthMethod string `json:"token_endpoint_auth_method,omitempty"` + Scope string `json:"scope,omitempty"` + + // RFC 8707: Resource indicators + ResourceIndicators []string `json:"resource_indicators,omitempty"` + + // MCP-specific + MCPVersion string `json:"mcp_version,omitempty"` + MCPTransportsSupported []string `json:"mcp_transports_supported,omitempty"` + + // Metadata + ClientIDIssuedAt int64 `json:"client_id_issued_at,omitempty"` +} + +// ClientUpdateRequest represents a client update request +type ClientUpdateRequest struct { + ClientRegistrationRequest + ClientID string `json:"client_id"` +} + +// DynamicRegistrationConfig configures dynamic client registration behavior +type DynamicRegistrationConfig struct { + // Enable or disable dynamic registration + Enabled bool + + // Require initial access token for registration + RequireInitialAccessToken bool + + // Initial access token (if required) + InitialAccessToken string + + // Allow public clients (no client_secret) + AllowPublicClients bool + + // Default grant types if not specified + DefaultGrantTypes []GrantType + + // Default token endpoint auth method + DefaultTokenEndpointAuthMethod string + + // Maximum number of redirect URIs + MaxRedirectURIs int + + // Client secret expiration time (0 = never expires) + ClientSecretTTL time.Duration + + // Generate registration access token for client management + EnableRegistrationAccessToken bool + + // Allowed application types + AllowedApplicationTypes []string + + // Validate redirect URIs + RedirectURIValidator func(string, string) error + + // Custom client ID generator + ClientIDGenerator func() (string, error) + + // Custom client secret generator + ClientSecretGenerator func() (string, error) +} + +// DefaultDynamicRegistrationConfig returns default configuration +func DefaultDynamicRegistrationConfig() *DynamicRegistrationConfig { + return &DynamicRegistrationConfig{ + Enabled: true, + RequireInitialAccessToken: false, + AllowPublicClients: true, + DefaultGrantTypes: []GrantType{GrantTypeAuthorizationCode, GrantTypeRefreshToken}, + DefaultTokenEndpointAuthMethod: "client_secret_basic", + MaxRedirectURIs: 10, + ClientSecretTTL: 0, // Never expires + EnableRegistrationAccessToken: true, + AllowedApplicationTypes: []string{"web", "native"}, + RedirectURIValidator: DefaultRedirectURIValidator, + ClientIDGenerator: GenerateClientID, + ClientSecretGenerator: GenerateClientSecret, + } +} + +// DynamicRegistrationHandler handles dynamic client registration +type DynamicRegistrationHandler struct { + server *Server + config *DynamicRegistrationConfig + store Store +} + +// NewDynamicRegistrationHandler creates a new registration handler +func NewDynamicRegistrationHandler(server *Server, config *DynamicRegistrationConfig) *DynamicRegistrationHandler { + if config == nil { + config = DefaultDynamicRegistrationConfig() + } + + return &DynamicRegistrationHandler{ + server: server, + config: config, + store: server.store, + } +} + +// HandleRegister handles POST /register +func (h *DynamicRegistrationHandler) HandleRegister(w http.ResponseWriter, r *http.Request) { + if !h.config.Enabled { + h.writeError(w, http.StatusNotFound, ErrInvalidRequest, "Dynamic registration is not enabled") + return + } + + if r.Method != http.MethodPost { + h.writeError(w, http.StatusMethodNotAllowed, ErrInvalidRequest, "Method not allowed") + return + } + + // Validate initial access token if required + if h.config.RequireInitialAccessToken { + token := extractBearerToken(r) + if token != h.config.InitialAccessToken { + h.writeError(w, http.StatusUnauthorized, ErrInvalidToken, "Invalid or missing initial access token") + return + } + } + + // Parse registration request + var req ClientRegistrationRequest + if err := json.NewDecoder(r.Body).Decode(&req); err != nil { + h.writeError(w, http.StatusBadRequest, ErrInvalidRequest, fmt.Sprintf("Invalid JSON: %v", err)) + return + } + + // Register the client + resp, err := h.registerClient(r.Context(), &req) + if err != nil { + h.writeErrorFromErr(w, err) + return + } + + // Write response + h.writeJSON(w, http.StatusCreated, resp) +} + +// HandleClientConfiguration handles GET/PUT/DELETE /register/:client_id +func (h *DynamicRegistrationHandler) HandleClientConfiguration(w http.ResponseWriter, r *http.Request) { + if !h.config.Enabled { + h.writeError(w, http.StatusNotFound, ErrInvalidRequest, "Dynamic registration is not enabled") + return + } + + // Extract client ID from path + clientID := extractClientIDFromPath(r.URL.Path) + if clientID == "" { + h.writeError(w, http.StatusBadRequest, ErrInvalidRequest, "Missing client ID") + return + } + + // Validate registration access token + token := extractBearerToken(r) + if token == "" { + h.writeError(w, http.StatusUnauthorized, ErrInvalidToken, "Missing registration access token") + return + } + + // Verify token matches client + if err := h.validateRegistrationAccessToken(r.Context(), clientID, token); err != nil { + h.writeError(w, http.StatusUnauthorized, ErrInvalidToken, "Invalid registration access token") + return + } + + switch r.Method { + case http.MethodGet: + h.handleGetClient(w, r, clientID) + case http.MethodPut: + h.handleUpdateClient(w, r, clientID) + case http.MethodDelete: + h.handleDeleteClient(w, r, clientID) + default: + h.writeError(w, http.StatusMethodNotAllowed, ErrInvalidRequest, "Method not allowed") + } +} + +// registerClient registers a new client +func (h *DynamicRegistrationHandler) registerClient(ctx context.Context, req *ClientRegistrationRequest) (*ClientRegistrationResponse, error) { + // Validate redirect URIs + if len(req.RedirectURIs) == 0 { + return nil, fmt.Errorf("%s: at least one redirect_uri is required", ErrInvalidRequest) + } + + if len(req.RedirectURIs) > h.config.MaxRedirectURIs { + return nil, fmt.Errorf("%s: too many redirect URIs (max: %d)", ErrInvalidRequest, h.config.MaxRedirectURIs) + } + + // Validate application type + applicationType := req.ApplicationType + if applicationType == "" { + applicationType = "web" + } + + if !contains(h.config.AllowedApplicationTypes, applicationType) { + return nil, fmt.Errorf("%s: unsupported application type: %s", ErrInvalidRequest, applicationType) + } + + // Validate redirect URIs based on application type + for _, uri := range req.RedirectURIs { + if err := h.config.RedirectURIValidator(uri, applicationType); err != nil { + return nil, fmt.Errorf("%s: %w", ErrInvalidRequest, err) + } + } + + // Determine grant types + grantTypes := req.GrantTypes + if len(grantTypes) == 0 { + grantTypes = make([]string, len(h.config.DefaultGrantTypes)) + for i, gt := range h.config.DefaultGrantTypes { + grantTypes[i] = string(gt) + } + } + + // Validate grant types + for _, gt := range grantTypes { + if !h.server.isGrantTypeSupported(GrantType(gt)) { + return nil, fmt.Errorf("%s: unsupported grant type: %s", ErrInvalidRequest, gt) + } + } + + // Determine if client is public or confidential + isPublic := applicationType == "native" || req.TokenEndpointAuthMethod == "none" + + if isPublic && !h.config.AllowPublicClients { + return nil, fmt.Errorf("%s: public clients are not allowed", ErrInvalidRequest) + } + + // Generate client ID + clientID, err := h.config.ClientIDGenerator() + if err != nil { + return nil, fmt.Errorf("%s: failed to generate client ID", ErrServerError) + } + + // Generate client secret for confidential clients + var clientSecret string + var clientSecretExpiresAt int64 + if !isPublic { + clientSecret, err = h.config.ClientSecretGenerator() + if err != nil { + return nil, fmt.Errorf("%s: failed to generate client secret", ErrServerError) + } + + if h.config.ClientSecretTTL > 0 { + clientSecretExpiresAt = time.Now().Add(h.config.ClientSecretTTL).Unix() + } + } + + // Parse scopes + scopes := parseScopes(req.Scope) + if len(scopes) == 0 { + // Default scopes for MCP + scopes = []string{"read", "write"} + } + + // Parse resources (RFC 8707) + resources := req.ResourceIndicators + if len(resources) == 0 && len(h.server.config.SupportedResources) > 0 { + resources = h.server.config.SupportedResources + } + + // Create client + now := time.Now() + client := &Client{ + ID: clientID, + Secret: clientSecret, + Name: req.ClientName, + RedirectURIs: req.RedirectURIs, + GrantTypes: grantTypes, + Scopes: scopes, + Resources: resources, + IsPublic: isPublic, + CreatedAt: now, + UpdatedAt: now, + Metadata: map[string]interface{}{ + "application_type": applicationType, + "contacts": req.Contacts, + "logo_uri": req.LogoURI, + "client_uri": req.ClientURI, + "policy_uri": req.PolicyURI, + "tos_uri": req.TosURI, + "jwks_uri": req.JWKSURI, + "software_id": req.SoftwareID, + "software_version": req.SoftwareVersion, + "token_endpoint_auth_method": req.TokenEndpointAuthMethod, + "mcp_version": req.MCPVersion, + "mcp_transports_supported": req.MCPTransportsSupported, + "client_secret_expires_at": clientSecretExpiresAt, + }, + } + + // Store client + if err := h.store.CreateClient(ctx, client); err != nil { + return nil, fmt.Errorf("%s: %w", ErrServerError, err) + } + + // Generate registration access token if enabled + var registrationAccessToken string + var registrationClientURI string + if h.config.EnableRegistrationAccessToken { + registrationAccessToken, err = GenerateRegistrationAccessToken() + if err != nil { + return nil, fmt.Errorf("%s: failed to generate registration access token", ErrServerError) + } + + // Store the token (in metadata for now, in production use a separate store) + client.Metadata["registration_access_token"] = registrationAccessToken + if err := h.store.UpdateClient(ctx, client); err != nil { + return nil, fmt.Errorf("%s: %w", ErrServerError, err) + } + + registrationClientURI = h.server.config.GetBaseURL() + "/register/" + clientID + } + + // Build response + resp := &ClientRegistrationResponse{ + ClientID: clientID, + ClientSecret: clientSecret, + ClientSecretExpiresAt: clientSecretExpiresAt, + RegistrationAccessToken: registrationAccessToken, + RegistrationClientURI: registrationClientURI, + RedirectURIs: req.RedirectURIs, + GrantTypes: grantTypes, + ApplicationType: applicationType, + Contacts: req.Contacts, + ClientName: req.ClientName, + LogoURI: req.LogoURI, + ClientURI: req.ClientURI, + PolicyURI: req.PolicyURI, + TosURI: req.TosURI, + JWKSURI: req.JWKSURI, + JWKS: req.JWKS, + SoftwareID: req.SoftwareID, + SoftwareVersion: req.SoftwareVersion, + TokenEndpointAuthMethod: req.TokenEndpointAuthMethod, + Scope: strings.Join(scopes, " "), + ResourceIndicators: resources, + MCPVersion: req.MCPVersion, + MCPTransportsSupported: req.MCPTransportsSupported, + ClientIDIssuedAt: now.Unix(), + } + + return resp, nil +} + +// handleGetClient retrieves client configuration +func (h *DynamicRegistrationHandler) handleGetClient(w http.ResponseWriter, r *http.Request, clientID string) { + client, err := h.store.GetClient(r.Context(), clientID) + if err != nil { + if errors.Is(err, ErrNotFound) { + h.writeError(w, http.StatusNotFound, ErrInvalidRequest, "Client not found") + return + } + h.writeError(w, http.StatusInternalServerError, ErrServerError, "Failed to retrieve client") + return + } + + // Build response (don't include client_secret) + resp := h.buildClientInfoResponse(client) + h.writeJSON(w, http.StatusOK, resp) +} + +// handleUpdateClient updates client configuration +func (h *DynamicRegistrationHandler) handleUpdateClient(w http.ResponseWriter, r *http.Request, clientID string) { + var req ClientUpdateRequest + if err := json.NewDecoder(r.Body).Decode(&req); err != nil { + h.writeError(w, http.StatusBadRequest, ErrInvalidRequest, fmt.Sprintf("Invalid JSON: %v", err)) + return + } + + req.ClientID = clientID + + // Retrieve existing client + client, err := h.store.GetClient(r.Context(), clientID) + if err != nil { + if errors.Is(err, ErrNotFound) { + h.writeError(w, http.StatusNotFound, ErrInvalidRequest, "Client not found") + return + } + h.writeError(w, http.StatusInternalServerError, ErrServerError, "Failed to retrieve client") + return + } + + // Update fields + if len(req.RedirectURIs) > 0 { + client.RedirectURIs = req.RedirectURIs + } + if len(req.GrantTypes) > 0 { + client.GrantTypes = req.GrantTypes + } + if req.ClientName != "" { + client.Name = req.ClientName + } + if req.Scope != "" { + client.Scopes = parseScopes(req.Scope) + } + + client.UpdatedAt = time.Now() + + if err := h.store.UpdateClient(r.Context(), client); err != nil { + h.writeError(w, http.StatusInternalServerError, ErrServerError, "Failed to update client") + return + } + + resp := h.buildClientInfoResponse(client) + h.writeJSON(w, http.StatusOK, resp) +} + +// handleDeleteClient deletes a client +func (h *DynamicRegistrationHandler) handleDeleteClient(w http.ResponseWriter, r *http.Request, clientID string) { + if err := h.store.DeleteClient(r.Context(), clientID); err != nil { + if errors.Is(err, ErrNotFound) { + h.writeError(w, http.StatusNotFound, ErrInvalidRequest, "Client not found") + return + } + h.writeError(w, http.StatusInternalServerError, ErrServerError, "Failed to delete client") + return + } + + w.WriteHeader(http.StatusNoContent) +} + +// buildClientInfoResponse builds a client info response +func (h *DynamicRegistrationHandler) buildClientInfoResponse(client *Client) *ClientRegistrationResponse { + applicationType, _ := client.Metadata["application_type"].(string) + tokenEndpointAuthMethod, _ := client.Metadata["token_endpoint_auth_method"].(string) + clientSecretExpiresAt, _ := client.Metadata["client_secret_expires_at"].(int64) + + return &ClientRegistrationResponse{ + ClientID: client.ID, + ClientSecretExpiresAt: clientSecretExpiresAt, + RedirectURIs: client.RedirectURIs, + GrantTypes: client.GrantTypes, + ApplicationType: applicationType, + ClientName: client.Name, + TokenEndpointAuthMethod: tokenEndpointAuthMethod, + Scope: strings.Join(client.Scopes, " "), + ResourceIndicators: client.Resources, + ClientIDIssuedAt: client.CreatedAt.Unix(), + } +} + +// validateRegistrationAccessToken validates the registration access token +func (h *DynamicRegistrationHandler) validateRegistrationAccessToken(ctx context.Context, clientID, token string) error { + client, err := h.store.GetClient(ctx, clientID) + if err != nil { + return err + } + + storedToken, ok := client.Metadata["registration_access_token"].(string) + if !ok || storedToken != token { + return errors.New("invalid registration access token") + } + + return nil +} + +// Helper functions + +// GenerateClientID generates a unique client ID +func GenerateClientID() (string, error) { + return GenerateRandomString(32) +} + +// GenerateClientSecret generates a secure client secret +func GenerateClientSecret() (string, error) { + return GenerateRandomString(64) +} + +// GenerateRegistrationAccessToken generates a registration access token +func GenerateRegistrationAccessToken() (string, error) { + return GenerateRandomString(48) +} + +// GenerateRandomString generates a cryptographically secure random string +func GenerateRandomString(length int) (string, error) { + b := make([]byte, length) + if _, err := rand.Read(b); err != nil { + return "", err + } + return base64.RawURLEncoding.EncodeToString(b)[:length], nil +} + +// DefaultRedirectURIValidator validates redirect URIs +func DefaultRedirectURIValidator(uri, applicationType string) error { + // Web clients must use HTTPS + if applicationType == "web" && !strings.HasPrefix(uri, "https://") { + // Allow localhost for development + if !strings.HasPrefix(uri, "http://localhost") && !strings.HasPrefix(uri, "http://127.0.0.1") { + return fmt.Errorf("web clients must use HTTPS redirect URIs") + } + } + + // Native clients can use custom schemes + if applicationType == "native" { + if strings.HasPrefix(uri, "http://") && !strings.HasPrefix(uri, "http://localhost") && !strings.HasPrefix(uri, "http://127.0.0.1") { + return fmt.Errorf("native clients cannot use http:// URIs except localhost") + } + } + + // Reject fragment components + if strings.Contains(uri, "#") { + return fmt.Errorf("redirect URIs must not contain fragment components") + } + + return nil +} + +// extractBearerToken extracts Bearer token from Authorization header +func extractBearerToken(r *http.Request) string { + authHeader := r.Header.Get("Authorization") + if authHeader == "" { + return "" + } + + parts := strings.SplitN(authHeader, " ", 2) + if len(parts) != 2 || !strings.EqualFold(parts[0], "Bearer") { + return "" + } + + return parts[1] +} + +// extractClientIDFromPath extracts client ID from URL path +func extractClientIDFromPath(path string) string { + parts := strings.Split(strings.Trim(path, "/"), "/") + if len(parts) >= 2 && parts[0] == "register" { + return parts[1] + } + return "" +} + +// contains checks if a string slice contains a value +func contains(slice []string, value string) bool { + for _, item := range slice { + if item == value { + return true + } + } + return false +} + +// HTTP response helpers + +func (h *DynamicRegistrationHandler) writeJSON(w http.ResponseWriter, status int, data interface{}) { + w.Header().Set("Content-Type", "application/json") + w.Header().Set("Cache-Control", "no-store") + w.Header().Set("Pragma", "no-cache") + w.WriteHeader(status) + json.NewEncoder(w).Encode(data) +} + +func (h *DynamicRegistrationHandler) writeError(w http.ResponseWriter, status int, errorCode, description string) { + w.Header().Set("Content-Type", "application/json") + w.Header().Set("Cache-Control", "no-store") + w.Header().Set("Pragma", "no-cache") + w.WriteHeader(status) + + errorResp := &ErrorResponse{ + Error: errorCode, + ErrorDescription: description, + } + + json.NewEncoder(w).Encode(errorResp) +} + +func (h *DynamicRegistrationHandler) writeErrorFromErr(w http.ResponseWriter, err error) { + status := http.StatusBadRequest + errorCode := ErrInvalidRequest + description := err.Error() + + errStr := err.Error() + if strings.Contains(errStr, ErrInvalidClient) { + status = http.StatusUnauthorized + errorCode = ErrInvalidClient + } else if strings.Contains(errStr, ErrServerError) { + status = http.StatusInternalServerError + errorCode = ErrServerError + } + + h.writeError(w, status, errorCode, description) +} diff --git a/auth/handler.go b/auth/handler.go new file mode 100644 index 0000000..af5281c --- /dev/null +++ b/auth/handler.go @@ -0,0 +1,475 @@ +package auth + +import ( + "context" + "encoding/json" + "fmt" + "net/http" + "net/url" + "strings" +) + +// Handler provides HTTP handlers for OAuth 2.1 endpoints +type Handler struct { + server *Server + oauthClient *OAuthClient + dynamicRegistrationHandler *DynamicRegistrationHandler +} + +// NewHandler creates a new OAuth handler +func NewHandler(server *Server) *Handler { + return &Handler{ + server: server, + dynamicRegistrationHandler: NewDynamicRegistrationHandler(server, nil), + } +} + +func (h *Handler) SetDynamicRegistrationConfig(config *DynamicRegistrationConfig) { + h.dynamicRegistrationHandler = NewDynamicRegistrationHandler(h.server, config) +} + +// HandleAuthorization handles GET /oauth/authorize +func (h *Handler) HandleAuthorization(w http.ResponseWriter, r *http.Request) { + if r.Method != http.MethodGet { + h.writeError(w, http.StatusMethodNotAllowed, ErrInvalidRequest, "method not allowed") + return + } + + // MCP spec: Check MCP-Protocol-Version header (optional but recommended) + mcpVersion := r.Header.Get(MCPProtocolVersionHeader) + if mcpVersion != "" && mcpVersion != CurrentMCPVersion { + h.server.config.Logger.Warnf("Client using MCP version %s, server supports %s", mcpVersion, CurrentMCPVersion) + // Continue - version mismatch is a warning, not an error + } + + // Parse authorization request + req := &AuthorizationRequest{ + ResponseType: r.URL.Query().Get("response_type"), + ClientID: r.URL.Query().Get("client_id"), + RedirectURI: r.URL.Query().Get("redirect_uri"), + Scope: r.URL.Query().Get("scope"), + State: r.URL.Query().Get("state"), + CodeChallenge: r.URL.Query().Get("code_challenge"), + CodeChallengeMethod: r.URL.Query().Get("code_challenge_method"), + Resource: r.URL.Query()["resource"], // RFC 8707 + } + + // State parameter is required for CSRF protection + if err := validateState(req.State); err != nil { + h.writeError(w, http.StatusBadRequest, ErrInvalidRequest, err.Error()) + return + } + + // Validate client ID + if req.ClientID == "" { + h.writeError(w, http.StatusBadRequest, ErrInvalidRequest, "client_id is required") + return + } + + // Get client + client, err := h.server.store.GetClient(r.Context(), req.ClientID) + if err != nil { + h.writeError(w, http.StatusBadRequest, ErrInvalidClient, "invalid client_id") + return + } + + // Validate redirect URI with exact matching + if err := validateRedirectURI(client.RedirectURIs, req.RedirectURI); err != nil { + h.writeError(w, http.StatusBadRequest, ErrInvalidRequest, err.Error()) + return + } + + // MCP spec: Enforce HTTPS or localhost for redirect URIs + if err := validateEndpointSecurity(req.RedirectURI); err != nil { + h.redirectWithError(w, r, req.RedirectURI, req.State, err) + return + } + + // Validate response type + if req.ResponseType != "code" { + h.redirectWithError(w, r, req.RedirectURI, req.State, + fmt.Errorf("%s: only 'code' response_type is supported", ErrUnsupportedResponseType)) + return + } + + // PKCE is required for public clients + if client.IsPublic && req.CodeChallenge == "" { + h.redirectWithError(w, r, req.RedirectURI, req.State, + fmt.Errorf("%s: PKCE is required for public clients", ErrInvalidRequest)) + return + } + + // Validate PKCE if provided + if req.CodeChallenge != "" { + pkceValidator := NewPKCEValidator() + if err := pkceValidator.ValidateCodeChallenge(req.CodeChallenge, req.CodeChallengeMethod, client.IsPublic); err != nil { + h.redirectWithError(w, r, req.RedirectURI, req.State, err) + return + } + } + + // Validate scopes + requestedScopes := parseScopes(req.Scope) + if err := validateScope(requestedScopes, client.Scopes); err != nil { + h.redirectWithError(w, r, req.RedirectURI, req.State, err) + return + } + + // Get user ID (in production, this would come from user authentication) + userID := r.Header.Get("X-User-ID") + if userID == "" { + h.redirectWithError(w, r, req.RedirectURI, req.State, + fmt.Errorf("%s: user not authenticated", ErrAccessDenied)) + return + } + + // Handle authorization request + redirectURL, err := h.server.HandleAuthorizationRequest(r.Context(), req, userID) + if err != nil { + h.redirectWithError(w, r, req.RedirectURI, req.State, err) + return + } + + // Redirect to client with authorization code + http.Redirect(w, r, redirectURL, http.StatusFound) +} + +// HandleToken handles POST /oauth/token +func (h *Handler) HandleToken(w http.ResponseWriter, r *http.Request) { + if r.Method != http.MethodPost { + h.writeError(w, http.StatusMethodNotAllowed, ErrInvalidRequest, "method not allowed") + return + } + + // MCP spec: Check MCP-Protocol-Version header + mcpVersion := r.Header.Get(MCPProtocolVersionHeader) + if mcpVersion != "" && mcpVersion != CurrentMCPVersion { + h.server.config.Logger.Warnf("Client using MCP version %s for token request", mcpVersion) + } + + // Parse form data + if err := r.ParseForm(); err != nil { + h.writeError(w, http.StatusBadRequest, ErrInvalidRequest, "invalid form data") + return + } + + // Extract token request + req := &TokenRequest{ + GrantType: r.FormValue("grant_type"), + Code: r.FormValue("code"), + RedirectURI: r.FormValue("redirect_uri"), + ClientID: r.FormValue("client_id"), + ClientSecret: r.FormValue("client_secret"), + RefreshToken: r.FormValue("refresh_token"), + Scope: r.FormValue("scope"), + Resource: r.Form["resource"], // RFC 8707 + CodeVerifier: r.FormValue("code_verifier"), + } + + // Try Basic Auth for client credentials + if req.ClientID == "" || req.ClientSecret == "" { + clientID, clientSecret, ok := r.BasicAuth() + if ok { + req.ClientID = clientID + req.ClientSecret = clientSecret + } + } + + // Handle token request + tokenResp, err := h.server.HandleTokenRequest(r.Context(), req) + if err != nil { + h.writeErrorFromErr(w, err) + return + } + + // Write successful response + h.writeJSON(w, http.StatusOK, tokenResp) +} + +// HandleIntrospection handles POST /oauth/introspect +func (h *Handler) HandleIntrospection(w http.ResponseWriter, r *http.Request) { + if r.Method != http.MethodPost { + h.writeError(w, http.StatusMethodNotAllowed, ErrInvalidRequest, "method not allowed") + return + } + + // Parse form data + if err := r.ParseForm(); err != nil { + h.writeError(w, http.StatusBadRequest, ErrInvalidRequest, "invalid form data") + return + } + + token := r.FormValue("token") + if token == "" { + h.writeError(w, http.StatusBadRequest, ErrInvalidRequest, "missing token parameter") + return + } + + // Authenticate client (introspection requires authentication) + clientID, clientSecret, ok := r.BasicAuth() + if !ok { + clientID = r.FormValue("client_id") + clientSecret = r.FormValue("client_secret") + } + + if clientID == "" { + h.writeError(w, http.StatusUnauthorized, ErrInvalidClient, "client authentication required") + return + } + + // Validate client + if err := h.server.store.ValidateClientSecret(r.Context(), clientID, clientSecret); err != nil { + h.writeError(w, http.StatusUnauthorized, ErrInvalidClient, "invalid client credentials") + return + } + + // Introspect token + resp := h.introspectToken(r.Context(), token) + h.writeJSON(w, http.StatusOK, resp) +} + +// HandleRevocation handles POST /oauth/revoke +func (h *Handler) HandleRevocation(w http.ResponseWriter, r *http.Request) { + if r.Method != http.MethodPost { + h.writeError(w, http.StatusMethodNotAllowed, ErrInvalidRequest, "method not allowed") + return + } + + // Parse form data + if err := r.ParseForm(); err != nil { + h.writeError(w, http.StatusBadRequest, ErrInvalidRequest, "invalid form data") + return + } + + token := r.FormValue("token") + tokenTypeHint := r.FormValue("token_type_hint") + + if token == "" { + h.writeError(w, http.StatusBadRequest, ErrInvalidRequest, "missing token parameter") + return + } + + // Authenticate client + clientID, clientSecret, ok := r.BasicAuth() + if !ok { + clientID = r.FormValue("client_id") + clientSecret = r.FormValue("client_secret") + } + + if clientID == "" { + h.writeError(w, http.StatusUnauthorized, ErrInvalidClient, "client authentication required") + return + } + + // Validate client + if err := h.server.store.ValidateClientSecret(r.Context(), clientID, clientSecret); err != nil { + h.writeError(w, http.StatusUnauthorized, ErrInvalidClient, "invalid client credentials") + return + } + + // Revoke token + ctx := r.Context() + if tokenTypeHint == "refresh_token" { + _ = h.server.store.RevokeRefreshToken(ctx, token) + } else { + _ = h.server.store.RevokeAccessToken(ctx, token) + _ = h.server.store.RevokeRefreshToken(ctx, token) + } + + // RFC 7009: The authorization server responds with HTTP status code 200 + w.WriteHeader(http.StatusOK) +} + +// introspectToken introspects an access token +func (h *Handler) introspectToken(ctx context.Context, token string) *IntrospectionResponse { + // Try to validate JWT token + claims, err := h.server.tokenGenerator.ValidateAccessToken(token) + if err != nil { + return &IntrospectionResponse{Active: false} + } + + // Verify token exists in store + accessToken, err := h.server.store.GetAccessToken(ctx, token) + if err != nil { + return &IntrospectionResponse{Active: false} + } + + // Build introspection response with audience (RFC 8707) + return &IntrospectionResponse{ + Active: true, + Scope: strings.Join(accessToken.Scopes, " "), + ClientID: accessToken.ClientID, + Username: accessToken.UserID, + TokenType: string(accessToken.TokenType), + ExpiresAt: accessToken.ExpiresAt.Unix(), + IssuedAt: accessToken.CreatedAt.Unix(), + Subject: claims.Subject, + Audience: accessToken.Resources, // RFC 8707 + } +} + +// Helper methods for HTTP responses + +func (h *Handler) writeJSON(w http.ResponseWriter, status int, data interface{}) { + w.Header().Set("Content-Type", "application/json") + w.Header().Set("Cache-Control", "no-store") + w.Header().Set("Pragma", "no-cache") + w.WriteHeader(status) + json.NewEncoder(w).Encode(data) +} + +func (h *Handler) writeError(w http.ResponseWriter, status int, errorCode, description string) { + w.Header().Set("Content-Type", "application/json") + w.Header().Set("Cache-Control", "no-store") + w.Header().Set("Pragma", "no-cache") + w.WriteHeader(status) + + errorResp := &ErrorResponse{ + Error: errorCode, + ErrorDescription: description, + } + + json.NewEncoder(w).Encode(errorResp) +} + +func (h *Handler) writeErrorFromErr(w http.ResponseWriter, err error) { + status := http.StatusBadRequest + errorCode := ErrInvalidRequest + description := err.Error() + + // Parse error to determine status code + errStr := err.Error() + if strings.Contains(errStr, ErrInvalidClient) { + status = http.StatusUnauthorized + errorCode = ErrInvalidClient + } else if strings.Contains(errStr, ErrInvalidGrant) { + errorCode = ErrInvalidGrant + } else if strings.Contains(errStr, ErrUnauthorizedClient) { + errorCode = ErrUnauthorizedClient + } else if strings.Contains(errStr, ErrUnsupportedGrantType) { + errorCode = ErrUnsupportedGrantType + } else if strings.Contains(errStr, ErrInvalidScope) { + errorCode = ErrInvalidScope + } else if strings.Contains(errStr, ErrInvalidTarget) { + errorCode = ErrInvalidTarget + } else if strings.Contains(errStr, ErrServerError) { + status = http.StatusInternalServerError + errorCode = ErrServerError + } + + h.writeError(w, status, errorCode, description) +} + +func (h *Handler) redirectWithError(w http.ResponseWriter, r *http.Request, redirectURI, state string, err error) { + if redirectURI == "" { + h.writeErrorFromErr(w, err) + return + } + + errorCode := ErrServerError + description := err.Error() + + // Determine error code + errStr := err.Error() + if strings.Contains(errStr, ErrInvalidRequest) { + errorCode = ErrInvalidRequest + } else if strings.Contains(errStr, ErrUnauthorizedClient) { + errorCode = ErrUnauthorizedClient + } else if strings.Contains(errStr, ErrAccessDenied) { + errorCode = ErrAccessDenied + } else if strings.Contains(errStr, ErrUnsupportedResponseType) { + errorCode = ErrUnsupportedResponseType + } else if strings.Contains(errStr, ErrInvalidScope) { + errorCode = ErrInvalidScope + } else if strings.Contains(errStr, ErrInvalidTarget) { + errorCode = ErrInvalidTarget + } + + // Build error redirect URL + redirectURL, _ := url.Parse(redirectURI) + query := redirectURL.Query() + query.Set("error", errorCode) + query.Set("error_description", description) + if state != "" { + query.Set("state", state) + } + redirectURL.RawQuery = query.Encode() + + http.Redirect(w, r, redirectURL.String(), http.StatusFound) +} + +// RegisterRoutes registers OAuth endpoints on a mux +func (h *Handler) RegisterRoutes(mux *http.ServeMux) { + mux.HandleFunc("/oauth/authorize", h.HandleAuthorization) + mux.HandleFunc("/oauth/token", h.HandleToken) + mux.HandleFunc("/oauth/introspect", h.HandleIntrospection) + mux.HandleFunc("/oauth/revoke", h.HandleRevocation) + + // RFC 7591: Dynamic Client Registration + mux.HandleFunc("/register", h.dynamicRegistrationHandler.HandleRegister) + mux.HandleFunc("/register/", h.dynamicRegistrationHandler.HandleClientConfiguration) + + // RFC 8414: Authorization Server Metadata + mux.Handle("/.well-known/oauth-authorization-server", h.server.GetMetadataProvider()) + + // RFC 9728: Protected Resource Metadata (optional) + mux.HandleFunc("/.well-known/oauth-protected-resource", h.server.GetMetadataProvider().ServeProtectedResourceMetadata) + + // JWKS endpoint + mux.Handle("/.well-known/jwks.json", h.server.GetJWKSProvider()) +} + +// RegisterRoutesWithPrefix registers OAuth endpoints with a custom prefix +func (h *Handler) RegisterRoutesWithPrefix(mux *http.ServeMux, prefix string) { + if !strings.HasPrefix(prefix, "/") { + prefix = "/" + prefix + } + if strings.HasSuffix(prefix, "/") { + prefix = strings.TrimSuffix(prefix, "/") + } + + mux.HandleFunc(prefix+"/authorize", h.HandleAuthorization) + mux.HandleFunc(prefix+"/token", h.HandleToken) + mux.HandleFunc(prefix+"/introspect", h.HandleIntrospection) + mux.HandleFunc(prefix+"/revoke", h.HandleRevocation) + + // RFC 7591: Dynamic Client Registration + mux.HandleFunc("/register", h.dynamicRegistrationHandler.HandleRegister) + mux.HandleFunc("/register/", h.dynamicRegistrationHandler.HandleClientConfiguration) + + // Metadata and JWKS at well-known locations (not prefixed per spec) + mux.Handle("/.well-known/oauth-authorization-server", h.server.GetMetadataProvider()) + mux.HandleFunc("/.well-known/oauth-protected-resource", h.server.GetMetadataProvider().ServeProtectedResourceMetadata) + mux.Handle("/.well-known/jwks.json", h.server.GetJWKSProvider()) +} + +func (h *Handler) SetOAuthClient(client *OAuthClient) { + h.oauthClient = client +} + +func (h *Handler) RegisterRoutesWithOAuth(mux *http.ServeMux) { + // Standard OAuth Endpoints + mux.HandleFunc("/oauth/authorize", h.HandleAuthorization) + mux.HandleFunc("/oauth/token", h.HandleToken) + mux.HandleFunc("/oauth/introspect", h.HandleIntrospection) + mux.HandleFunc("/oauth/revoke", h.HandleRevocation) + + // RFC 7591: Dynamic Client Registration + mux.HandleFunc("/register", h.dynamicRegistrationHandler.HandleRegister) + mux.HandleFunc("/register/", h.dynamicRegistrationHandler.HandleClientConfiguration) + + // Metadata and JWKS + mux.Handle("/.well-known/oauth-authorization-server", h.server.GetMetadataProvider()) + mux.HandleFunc("/.well-known/oauth-protected-resource", h.server.GetMetadataProvider().ServeProtectedResourceMetadata) + mux.Handle("/.well-known/jwks.json", h.server.GetJWKSProvider()) + + // Third-party OAuth flow endpoints + if h.oauthClient != nil { + mux.HandleFunc("/oauth/login", h.oauthClient.InitiateOAuthFlow) + mux.HandleFunc("/oauth/callback", h.oauthClient.HandleCallback) + } +} + +func (h *Handler) GetDynamicRegistrationHandler() *DynamicRegistrationHandler { + return h.dynamicRegistrationHandler +} diff --git a/auth/jwks.go b/auth/jwks.go new file mode 100644 index 0000000..995b9c1 --- /dev/null +++ b/auth/jwks.go @@ -0,0 +1,149 @@ +package auth + +import ( + "crypto/hmac" + "crypto/rand" + "crypto/rsa" + "crypto/sha256" + "encoding/base64" + "encoding/json" + "math/big" + "net/http" +) + +// JWK represents a JSON Web Key +type JWK struct { + Kty string `json:"kty"` // Key Type + Use string `json:"use,omitempty"` // Public Key Use + Kid string `json:"kid,omitempty"` // Key ID + Alg string `json:"alg,omitempty"` // Algorithm + N string `json:"n,omitempty"` // RSA Modulus (for RSA keys) + E string `json:"e,omitempty"` // RSA Exponent (for RSA keys) + K string `json:"k,omitempty"` // Symmetric key (for oct keys) +} + +// JWKSet represents a JSON Web Key Set +type JWKSet struct { + Keys []JWK `json:"keys"` +} + +// JWKSProvider manages JWKS endpoint +type JWKSProvider struct { + privateKey *rsa.PrivateKey // For RS256 + publicKey *rsa.PublicKey // For RS256 + signingKey []byte // For HS256 (backward compatibility) + keyID string + jwks *JWKSet + algorithm string // "RS256" or "HS256" +} + +// NewJWKSProvider creates a new JWKS provider +// For backward compatibility, it uses HS256 by default +func NewJWKSProvider(signingKey []byte, keyID string) *JWKSProvider { + provider := &JWKSProvider{ + signingKey: signingKey, + keyID: keyID, + algorithm: "HS256", + } + provider.generateJWKS() + return provider +} + +// NewJWKSProviderWithRSA creates a JWKS provider with RSA support +func NewJWKSProviderWithRSA(keyID string) (*JWKSProvider, error) { + // Generate RSA key pair + privateKey, err := rsa.GenerateKey(rand.Reader, 2048) + if err != nil { + return nil, err + } + + provider := &JWKSProvider{ + privateKey: privateKey, + publicKey: &privateKey.PublicKey, + keyID: keyID, + algorithm: "RS256", + } + provider.generateJWKS() + return provider, nil +} + +// NewJWKSProviderWithRSAKey creates a JWKS provider with existing RSA key +func NewJWKSProviderWithRSAKey(privateKey *rsa.PrivateKey, keyID string) *JWKSProvider { + provider := &JWKSProvider{ + privateKey: privateKey, + publicKey: &privateKey.PublicKey, + keyID: keyID, + algorithm: "RS256", + } + provider.generateJWKS() + return provider +} + +// generateJWKS creates the JWKS document +func (jp *JWKSProvider) generateJWKS() { + if jp.algorithm == "RS256" && jp.publicKey != nil { + jp.jwks = &JWKSet{ + Keys: []JWK{ + { + Kty: "RSA", + Use: "sig", + Kid: jp.keyID, + Alg: "RS256", + N: base64.RawURLEncoding.EncodeToString(jp.publicKey.N.Bytes()), + E: base64.RawURLEncoding.EncodeToString(big.NewInt(int64(jp.publicKey.E)).Bytes()), + }, + }, + } + } else { + jp.jwks = &JWKSet{ + Keys: []JWK{}, // empty for HS256 + } + } +} + +// GetJWKS returns the JWKS document +func (jp *JWKSProvider) GetJWKS() *JWKSet { + return jp.jwks +} + +// GetPrivateKey returns the RSA private key (for signing) +func (jp *JWKSProvider) GetPrivateKey() *rsa.PrivateKey { + return jp.privateKey +} + +// GetAlgorithm returns the signing algorithm +func (jp *JWKSProvider) GetAlgorithm() string { + return jp.algorithm +} + +// ServeHTTP handles the JWKS endpoint +func (jp *JWKSProvider) ServeHTTP(w http.ResponseWriter, r *http.Request) { + if r.Method != http.MethodGet { + http.Error(w, "Method not allowed", http.StatusMethodNotAllowed) + return + } + + if jp.algorithm == "HS256" { + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusNotImplemented) + json.NewEncoder(w).Encode(map[string]string{ + "error": "jwks_not_available", + "error_description": "JWKS endpoint is not available for HMAC-based tokens. Use token introspection endpoint instead.", + }) + return + } + + w.Header().Set("Content-Type", "application/json") + w.Header().Set("Cache-Control", "public, max-age=3600") + w.Header().Set("Access-Control-Allow-Origin", "*") + + json.NewEncoder(w).Encode(jp.jwks) +} + +// GenerateKeyID creates a key ID from the signing key +func GenerateKeyID(signingKey []byte) string { + h := hmac.New(sha256.New, []byte("key-id-salt")) + h.Write(signingKey) + hash := h.Sum(nil) + return base64.RawURLEncoding.EncodeToString(hash[:16]) +} diff --git a/auth/memory/store.go b/auth/memory/store.go new file mode 100644 index 0000000..42a727c --- /dev/null +++ b/auth/memory/store.go @@ -0,0 +1,283 @@ +package memory + +import ( + "context" + "crypto/sha256" + "crypto/subtle" + "encoding/base64" + "errors" + "sync" + "time" + + "github.com/ThinkInAIXYZ/go-mcp/auth" +) + +// Store implements an in-memory OAuth store for development and testing +type Store struct { + clients sync.Map // clientID -> *auth.Client + authCodes sync.Map // code -> *auth.AuthorizationCode + accessTokens sync.Map // token -> *auth.AccessToken + refreshTokens sync.Map // token -> *auth.RefreshToken + clientSecrets sync.Map // clientID -> hashedSecret +} + +// NewStore creates a new in-memory store +func NewStore() *Store { + return &Store{} +} + +// Client operations + +func (s *Store) GetClient(_ context.Context, clientID string) (*auth.Client, error) { + value, ok := s.clients.Load(clientID) + if !ok { + return nil, auth.ErrNotFound + } + + client := value.(*auth.Client) + return client, nil +} + +func (s *Store) CreateClient(_ context.Context, client *auth.Client) error { + if client.ID == "" { + return auth.ErrInvalidStore + } + + _, loaded := s.clients.LoadOrStore(client.ID, client) + if loaded { + return auth.ErrAlreadyExists + } + + // Hash and store client secret if present + if client.Secret != "" { + hashedSecret := hashSecret(client.Secret) + s.clientSecrets.Store(client.ID, hashedSecret) + client.Secret = "" + } + + return nil +} + +func (s *Store) UpdateClient(_ context.Context, client *auth.Client) error { + if client.ID == "" { + return auth.ErrInvalidStore + } + + _, ok := s.clients.Load(client.ID) + if !ok { + return auth.ErrNotFound + } + + // Update secret if provided + if client.Secret != "" { + hashedSecret := hashSecret(client.Secret) + s.clientSecrets.Store(client.ID, hashedSecret) + client.Secret = "" + } + + client.UpdatedAt = time.Now() + s.clients.Store(client.ID, client) + + return nil +} + +func (s *Store) DeleteClient(_ context.Context, clientID string) error { + s.clients.Delete(clientID) + s.clientSecrets.Delete(clientID) + return nil +} + +func (s *Store) ValidateClientSecret(_ context.Context, clientID, secret string) error { + value, ok := s.clientSecrets.Load(clientID) + if !ok { + return auth.ErrNotFound + } + + hashedSecret := value.(string) + providedHash := hashSecret(secret) + + // Use constant-time comparison to prevent timing attacks + if subtle.ConstantTimeCompare([]byte(hashedSecret), []byte(providedHash)) != 1 { + return errors.New(auth.ErrInvalidClient) + } + + return nil +} + +// Authorization code operations + +func (s *Store) SaveAuthorizationCode(_ context.Context, code *auth.AuthorizationCode) error { + if code.Code == "" { + return auth.ErrInvalidStore + } + + s.authCodes.Store(code.Code, code) + return nil +} + +func (s *Store) GetAuthorizationCode(_ context.Context, code string) (*auth.AuthorizationCode, error) { + value, ok := s.authCodes.Load(code) + if !ok { + return nil, auth.ErrNotFound + } + + authCode := value.(*auth.AuthorizationCode) + + // Check expiration + if time.Now().After(authCode.ExpiresAt) { + return nil, auth.ErrExpired + } + + // Check if already used + if authCode.Used { + return nil, auth.ErrRevoked + } + + return authCode, nil +} + +func (s *Store) InvalidateAuthorizationCode(_ context.Context, code string) error { + value, ok := s.authCodes.Load(code) + if !ok { + return auth.ErrNotFound + } + + authCode := value.(*auth.AuthorizationCode) + authCode.Used = true + s.authCodes.Store(code, authCode) + + return nil +} + +func (s *Store) CleanupExpiredAuthorizationCodes(_ context.Context) error { + now := time.Now() + + s.authCodes.Range(func(key, value interface{}) bool { + code := value.(*auth.AuthorizationCode) + if now.After(code.ExpiresAt) || code.Used { + s.authCodes.Delete(key) + } + return true + }) + + return nil +} + +// Access token operations + +func (s *Store) SaveAccessToken(_ context.Context, token *auth.AccessToken) error { + if token.Token == "" { + return auth.ErrInvalidStore + } + + s.accessTokens.Store(token.Token, token) + return nil +} + +func (s *Store) GetAccessToken(_ context.Context, token string) (*auth.AccessToken, error) { + value, ok := s.accessTokens.Load(token) + if !ok { + return nil, auth.ErrNotFound + } + + accessToken := value.(*auth.AccessToken) + + // Check expiration + if time.Now().After(accessToken.ExpiresAt) { + return nil, auth.ErrExpired + } + + return accessToken, nil +} + +func (s *Store) RevokeAccessToken(_ context.Context, token string) error { + s.accessTokens.Delete(token) + return nil +} + +// Refresh token operations + +func (s *Store) SaveRefreshToken(_ context.Context, token *auth.RefreshToken) error { + if token.Token == "" { + return auth.ErrInvalidStore + } + + s.refreshTokens.Store(token.Token, token) + return nil +} + +func (s *Store) GetRefreshToken(_ context.Context, token string) (*auth.RefreshToken, error) { + value, ok := s.refreshTokens.Load(token) + if !ok { + return nil, auth.ErrNotFound + } + + refreshToken := value.(*auth.RefreshToken) + + // Check if revoked + if refreshToken.Revoked { + return nil, auth.ErrRevoked + } + + // Check expiration + if time.Now().After(refreshToken.ExpiresAt) { + return nil, auth.ErrExpired + } + + return refreshToken, nil +} + +func (s *Store) RevokeRefreshToken(_ context.Context, token string) error { + value, ok := s.refreshTokens.Load(token) + if !ok { + return auth.ErrNotFound + } + + refreshToken := value.(*auth.RefreshToken) + refreshToken.Revoked = true + s.refreshTokens.Store(token, refreshToken) + + return nil +} + +func (s *Store) RotateRefreshToken(ctx context.Context, oldToken string, newToken *auth.RefreshToken) error { + // Revoke old token + if err := s.RevokeRefreshToken(ctx, oldToken); err != nil { + return err + } + + // Save new token + return s.SaveRefreshToken(ctx, newToken) +} + +// Cleanup operations + +func (s *Store) CleanupExpiredTokens(_ context.Context) error { + now := time.Now() + + // Cleanup access tokens + s.accessTokens.Range(func(key, value interface{}) bool { + token := value.(*auth.AccessToken) + if now.After(token.ExpiresAt) { + s.accessTokens.Delete(key) + } + return true + }) + + // Cleanup refresh tokens + s.refreshTokens.Range(func(key, value interface{}) bool { + token := value.(*auth.RefreshToken) + if now.After(token.ExpiresAt) || token.Revoked { + s.refreshTokens.Delete(key) + } + return true + }) + + return nil +} + +// hashSecret creates a SHA256 hash of the secret +func hashSecret(secret string) string { + hash := sha256.Sum256([]byte(secret)) + return base64.StdEncoding.EncodeToString(hash[:]) +} diff --git a/auth/metadata.go b/auth/metadata.go new file mode 100644 index 0000000..5874ee1 --- /dev/null +++ b/auth/metadata.go @@ -0,0 +1,387 @@ +package auth + +import ( + "encoding/json" + "errors" + "net/http" + "net/url" +) + +// AuthorizationServerMetadata represents OAuth 2.0 Authorization Server Metadata (RFC 8414) +type AuthorizationServerMetadata struct { + // REQUIRED: The authorization server's issuer identifier + Issuer string `json:"issuer"` + + // REQUIRED: URL of the authorization endpoint + AuthorizationEndpoint string `json:"authorization_endpoint"` + + // REQUIRED: URL of the token endpoint + TokenEndpoint string `json:"token_endpoint"` + + // OPTIONAL: URL of the JWK Set document + JWKSURI string `json:"jwks_uri,omitempty"` + + // OPTIONAL: URL of the Dynamic Client Registration endpoint + RegistrationEndpoint string `json:"registration_endpoint,omitempty"` + + // OPTIONAL: JSON array of scope values supported + ScopesSupported []string `json:"scopes_supported,omitempty"` + + // REQUIRED: JSON array of response_type values supported + ResponseTypesSupported []string `json:"response_types_supported"` + + // OPTIONAL: JSON array of response_mode values supported + ResponseModesSupported []string `json:"response_modes_supported,omitempty"` + + // OPTIONAL: JSON array of grant types supported + GrantTypesSupported []string `json:"grant_types_supported,omitempty"` + + // OPTIONAL: JSON array of client authentication methods supported + TokenEndpointAuthMethodsSupported []string `json:"token_endpoint_auth_methods_supported,omitempty"` + + // OPTIONAL: JSON array of signing algorithms supported + TokenEndpointAuthSigningAlgValuesSupported []string `json:"token_endpoint_auth_signing_alg_values_supported,omitempty"` + + // OPTIONAL: URL of the revocation endpoint + RevocationEndpoint string `json:"revocation_endpoint,omitempty"` + + // OPTIONAL: Client authentication methods supported by revocation endpoint + RevocationEndpointAuthMethodsSupported []string `json:"revocation_endpoint_auth_methods_supported,omitempty"` + + // OPTIONAL: URL of the introspection endpoint + IntrospectionEndpoint string `json:"introspection_endpoint,omitempty"` + + // OPTIONAL: Client authentication methods supported by introspection endpoint + IntrospectionEndpointAuthMethodsSupported []string `json:"introspection_endpoint_auth_methods_supported,omitempty"` + + // OPTIONAL: JSON array of PKCE code challenge methods supported + CodeChallengeMethodsSupported []string `json:"code_challenge_methods_supported,omitempty"` + + // RFC 8707: Resource Indicators + // OPTIONAL: Boolean indicating support for resource indicators + ResourceIndicatorsSupported bool `json:"resource_indicators_supported,omitempty"` + + // MCP-specific extensions + MCPVersion string `json:"mcp_version,omitempty"` + MCPTransportsSupported []string `json:"mcp_transports_supported,omitempty"` +} + +// ProtectedResourceMetadata represents OAuth 2.0 Protected Resource Metadata (RFC 9728) +// This metadata helps MCP clients discover how to access protected resources +// NOTE: RFC 9728 is NOT required by MCP spec but included for completeness +type ProtectedResourceMetadata struct { + // REQUIRED: The protected resource identifier (URI) + Resource string `json:"resource"` + + // REQUIRED: Array of authorization server identifiers that can issue tokens for this resource + AuthorizationServers []string `json:"authorization_servers"` + + // OPTIONAL: Array of bearer token usage methods supported (e.g., "header", "body", "query") + BearerMethodsSupported []string `json:"bearer_methods_supported,omitempty"` + + // OPTIONAL: Array of signing algorithms for JWS access tokens + ResourceSigningAlgValuesSupported []string `json:"resource_signing_alg_values_supported,omitempty"` + + // OPTIONAL: Array of encryption algorithms for JWE access tokens + ResourceEncryptionAlgValuesSupported []string `json:"resource_encryption_alg_values_supported,omitempty"` + + // OPTIONAL: Array of encryption encoding algorithms for JWE access tokens + ResourceEncryptionEncValuesSupported []string `json:"resource_encryption_enc_values_supported,omitempty"` + + // OPTIONAL: Scopes supported by this resource + ScopesSupported []string `json:"scopes_supported,omitempty"` + + // OPTIONAL: URL of the resource's documentation + ResourceDocumentation string `json:"resource_documentation,omitempty"` + + // MCP-specific extensions + MCPCapabilities *MCPResourceCapabilities `json:"mcp_capabilities,omitempty"` +} + +// MCPResourceCapabilities describes MCP-specific capabilities for a protected resource +type MCPResourceCapabilities struct { + Tools *MCPToolsCapability `json:"tools,omitempty"` + Prompts *MCPPromptsCapability `json:"prompts,omitempty"` + Resources *MCPResourcesCapability `json:"resources,omitempty"` +} + +type MCPToolsCapability struct { + ListChanged bool `json:"listChanged,omitempty"` +} + +type MCPPromptsCapability struct { + ListChanged bool `json:"listChanged,omitempty"` +} + +type MCPResourcesCapability struct { + Subscribe bool `json:"subscribe,omitempty"` + ListChanged bool `json:"listChanged,omitempty"` +} + +// MetadataProvider generates authorization server and protected resource metadata +type MetadataProvider struct { + config *ServerConfig + authorizationBaseURL string // Derived from MCP server URL per spec + authServerMetadata *AuthorizationServerMetadata + protectedResourceMetadata *ProtectedResourceMetadata + enableProtectedResourceMD bool + logger Logger +} + +// NewMetadataProvider creates a new metadata provider +// Automatically derives authorization base URL from MCP server URL per MCP spec +func NewMetadataProvider(config *ServerConfig) (*MetadataProvider, error) { + mp := &MetadataProvider{ + config: config, + enableProtectedResourceMD: true, // Enable by default for MCP compliance + logger: config.Logger, + } + + // Derive authorization base URL from MCP server URL + // MCP spec: "The authorization base URL MUST be determined from the MCP server URL + // by discarding any existing path component" + baseURL, err := deriveAuthorizationBaseURL(config.MCPServerURL) + if err != nil { + return nil, err + } + mp.authorizationBaseURL = baseURL + + mp.generateAuthServerMetadata() + mp.generateProtectedResourceMetadata() + return mp, nil +} + +// deriveAuthorizationBaseURL implements MCP spec requirement: +// "discarding any existing path component" +func deriveAuthorizationBaseURL(mcpServerURL string) (string, error) { + if mcpServerURL == "" { + return "", ErrInvalidServerURL + } + + parsedURL, err := url.Parse(mcpServerURL) + if err != nil { + return "", err + } + + // Discard path, query, and fragment per MCP spec + parsedURL.Path = "" + parsedURL.RawQuery = "" + parsedURL.Fragment = "" + + return parsedURL.String(), nil +} + +// SetProtectedResourceMetadataEnabled enables or disables protected resource metadata +func (mp *MetadataProvider) SetProtectedResourceMetadataEnabled(enabled bool) { + mp.enableProtectedResourceMD = enabled +} + +// generateAuthServerMetadata builds the authorization server metadata document +func (mp *MetadataProvider) generateAuthServerMetadata() { + grantTypes := make([]string, len(mp.config.SupportedGrantTypes)) + for i, gt := range mp.config.SupportedGrantTypes { + grantTypes[i] = string(gt) + } + + mp.authServerMetadata = &AuthorizationServerMetadata{ + Issuer: mp.config.Issuer, + AuthorizationEndpoint: mp.authorizationBaseURL + "/oauth/authorize", + TokenEndpoint: mp.authorizationBaseURL + "/oauth/token", + RevocationEndpoint: mp.authorizationBaseURL + "/oauth/revoke", + IntrospectionEndpoint: mp.authorizationBaseURL + "/oauth/introspect", + JWKSURI: mp.authorizationBaseURL + "/.well-known/jwks.json", + RegistrationEndpoint: mp.authorizationBaseURL + "/register", + + ResponseTypesSupported: []string{"code"}, + ResponseModesSupported: []string{"query"}, + GrantTypesSupported: grantTypes, + + TokenEndpointAuthMethodsSupported: []string{ + "client_secret_basic", + "none", // For public clients + }, + + RevocationEndpointAuthMethodsSupported: []string{ + "client_secret_basic", + }, + + IntrospectionEndpointAuthMethodsSupported: []string{ + "client_secret_basic", + }, + + CodeChallengeMethodsSupported: []string{"S256"}, + + // RFC 8707 support + ResourceIndicatorsSupported: true, + + // MCP-specific + MCPVersion: CurrentMCPVersion, + MCPTransportsSupported: []string{"sse", "stdio"}, + + ScopesSupported: []string{ + "read", + "write", + "tools:list", + "tools:execute", + "prompts:list", + "prompts:execute", + "resources:list", + "resources:read", + }, + } +} + +// generateProtectedResourceMetadata builds the protected resource metadata (RFC 9728) +func (mp *MetadataProvider) generateProtectedResourceMetadata() { + // Determine resource identifier (typically the MCP server's base URI) + resourceURI := mp.config.Issuer + if len(mp.config.SupportedResources) > 0 { + resourceURI = mp.config.SupportedResources[0] + } + + mp.protectedResourceMetadata = &ProtectedResourceMetadata{ + Resource: resourceURI, + AuthorizationServers: []string{ + mp.config.Issuer, // This MCP server acts as its own authorization server + }, + BearerMethodsSupported: []string{ + "header", // Authorization: Bearer (MCP spec requirement) + }, + ResourceSigningAlgValuesSupported: []string{ + "RS256", // RSA with SHA-256 + "HS256", // HMAC with SHA-256 + }, + ScopesSupported: []string{ + "read", + "write", + "tools:list", + "tools:execute", + "prompts:list", + "prompts:execute", + "resources:list", + "resources:read", + }, + ResourceDocumentation: mp.authorizationBaseURL + "/docs", + MCPCapabilities: &MCPResourceCapabilities{ + Tools: &MCPToolsCapability{ + ListChanged: true, + }, + Prompts: &MCPPromptsCapability{ + ListChanged: true, + }, + Resources: &MCPResourcesCapability{ + Subscribe: true, + ListChanged: true, + }, + }, + } +} + +// GetAuthServerMetadata returns the authorization server metadata document +func (mp *MetadataProvider) GetAuthServerMetadata() *AuthorizationServerMetadata { + return mp.authServerMetadata +} + +// GetProtectedResourceMetadata returns the protected resource metadata document +func (mp *MetadataProvider) GetProtectedResourceMetadata() *ProtectedResourceMetadata { + return mp.protectedResourceMetadata +} + +// GetAuthorizationBaseURL returns the derived authorization base URL +func (mp *MetadataProvider) GetAuthorizationBaseURL() string { + return mp.authorizationBaseURL +} + +// ServeHTTP handles the authorization server metadata endpoint +// Endpoint: /.well-known/oauth-authorization-server +// MCP spec: "MCP clients MUST follow the OAuth 2.0 Authorization Server Metadata protocol" +func (mp *MetadataProvider) ServeHTTP(w http.ResponseWriter, r *http.Request) { + if r.Method != http.MethodGet { + http.Error(w, "Method not allowed", http.StatusMethodNotAllowed) + return + } + + // MCP spec: Check MCP-Protocol-Version header + mcpVersion := r.Header.Get(MCPProtocolVersionHeader) + if mcpVersion != "" { + if !mp.isCompatibleMCPVersion(mcpVersion) { + mp.logger.Warnf("Client requested MCP version %s, server supports %s", mcpVersion, CurrentMCPVersion) + // Continue serving - version mismatch is a warning, not an error + // Clients should handle version differences gracefully + } + } else { + mp.logger.Infof("Client did not provide MCP-Protocol-Version header") + } + + w.Header().Set("Content-Type", "application/json") + w.Header().Set("Cache-Control", "public, max-age=3600") // Cache for 1 hour + w.Header().Set("Access-Control-Allow-Origin", "*") + + json.NewEncoder(w).Encode(mp.authServerMetadata) +} + +// ServeProtectedResourceMetadata handles the protected resource metadata endpoint +// Endpoint: /.well-known/oauth-protected-resource +// NOTE: This is RFC 9728, not required by MCP spec but useful for advanced scenarios +func (mp *MetadataProvider) ServeProtectedResourceMetadata(w http.ResponseWriter, r *http.Request) { + if !mp.enableProtectedResourceMD { + http.Error(w, "Not Found", http.StatusNotFound) + return + } + + if r.Method != http.MethodGet { + http.Error(w, "Method not allowed", http.StatusMethodNotAllowed) + return + } + + // Check MCP version + mcpVersion := r.Header.Get(MCPProtocolVersionHeader) + if mcpVersion != "" && !mp.isCompatibleMCPVersion(mcpVersion) { + mp.logger.Warnf("Client requested MCP version %s for protected resource metadata", mcpVersion) + } + + w.Header().Set("Content-Type", "application/json") + w.Header().Set("Cache-Control", "public, max-age=3600") // Cache for 1 hour + w.Header().Set("Access-Control-Allow-Origin", "*") + + json.NewEncoder(w).Encode(mp.protectedResourceMetadata) +} + +// isCompatibleMCPVersion checks if the requested MCP version is compatible +func (mp *MetadataProvider) isCompatibleMCPVersion(requestedVersion string) bool { + // For now, exact match required + // In production, implement semantic versioning compatibility + return requestedVersion == CurrentMCPVersion +} + +// CustomizeProtectedResourceMetadata allows customization of protected resource metadata +func (mp *MetadataProvider) CustomizeProtectedResourceMetadata( + customize func(*ProtectedResourceMetadata), +) { + if mp.protectedResourceMetadata != nil { + customize(mp.protectedResourceMetadata) + } +} + +// AddAuthorizationServer adds an additional authorization server to the protected resource metadata +// This is useful when the MCP server accepts tokens from external authorization servers +func (mp *MetadataProvider) AddAuthorizationServer(authServerURI string) { + if mp.protectedResourceMetadata == nil { + return + } + + // Check if already exists + for _, existing := range mp.protectedResourceMetadata.AuthorizationServers { + if existing == authServerURI { + return + } + } + + mp.protectedResourceMetadata.AuthorizationServers = append( + mp.protectedResourceMetadata.AuthorizationServers, + authServerURI, + ) +} + +// Additional error for metadata provider +var ErrInvalidServerURL = errors.New("auth: invalid MCP server URL") diff --git a/auth/middleware.go b/auth/middleware.go new file mode 100644 index 0000000..73076a4 --- /dev/null +++ b/auth/middleware.go @@ -0,0 +1,307 @@ +package auth + +import ( + "context" + "fmt" + "net/http" + "strings" + + "github.com/ThinkInAIXYZ/go-mcp/protocol" +) + +// Middleware provides OAuth 2.1 authentication for MCP server +type Middleware struct { + server *Server + tokenGenerator *TokenGenerator + store Store + realm string // RFC 6750: realm for WWW-Authenticate + + // Optional custom validators + scopeValidator func(context.Context, []string, string) bool + userIDExtractor func(context.Context) (string, error) +} + +// MiddlewareConfig configures OAuth middleware +type MiddlewareConfig struct { + // Required scopes for all requests + RequiredScopes []string + // Optional scope validator + ScopeValidator func(context.Context, []string, string) bool + // Optional user ID extractor from context + UserIDExtractor func(context.Context) (string, error) + // Realm for WWW-Authenticate header + Realm string +} + +// NewMiddleware creates a new OAuth middleware +func NewMiddleware(server *Server) *Middleware { + return &Middleware{ + server: server, + tokenGenerator: server.tokenGenerator, + store: server.store, + realm: "MCP Server", + } +} + +// SetRealm sets the realm for WWW-Authenticate header +func (m *Middleware) SetRealm(realm string) { + m.realm = realm +} + +// HTTPMiddleware returns an HTTP middleware that validates OAuth tokens +func (m *Middleware) HTTPMiddleware(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + ctx := r.Context() + + // Extract token from Authorization header + token := m.extractTokenFromHeader(r) + if token == "" { + m.writeUnauthorizedError(w, "invalid_request", "Missing or invalid authorization header") + return + } + + // Validate token + claims, err := m.tokenGenerator.ValidateAccessToken(token) + if err != nil { + m.writeUnauthorizedError(w, "invalid_token", "The access token is invalid or expired") + return + } + + // Verify token exists in store (not revoked) + accessToken, err := m.store.GetAccessToken(ctx, token) + if err != nil { + m.writeUnauthorizedError(w, "invalid_token", "Token not found or has been revoked") + return + } + + // RFC 8707: Validate resource if specified in request + requestedResource := r.Header.Get("X-Resource-Indicator") + if requestedResource != "" { + if err := m.tokenGenerator.ValidateTokenAudience(claims, requestedResource); err != nil { + m.writeInsufficientScopeError(w, "The access token does not have permission for the requested resource") + return + } + } + + // Add OAuth context to request + ctx = WithOAuthContext(ctx, &OAuthContext{ + ClientID: claims.ClientID, + UserID: claims.Subject, + Scopes: claims.Scopes, + AccessToken: accessToken, + }) + + next.ServeHTTP(w, r.WithContext(ctx)) + }) +} + +// MCPToolMiddleware is a generic function that accepts any type with the specified underlying function signature. +func MCPToolMiddleware[H ~func(context.Context, *protocol.CallToolRequest) (*protocol.CallToolResult, error)]( + m *Middleware, config *MiddlewareConfig, +) func(H) H { + return func(next H) H { + return func(ctx context.Context, req *protocol.CallToolRequest) (*protocol.CallToolResult, error) { + oauthCtx := GetOAuthContext(ctx) + if oauthCtx == nil { + return nil, fmt.Errorf("unauthorized: no OAuth context") + } + + if config != nil && len(config.RequiredScopes) > 0 { + if !m.hasRequiredScopes(oauthCtx.Scopes, config.RequiredScopes) { + return nil, fmt.Errorf("insufficient_scope: required scopes %v", config.RequiredScopes) + } + } + + if config != nil && config.ScopeValidator != nil { + if !config.ScopeValidator(ctx, oauthCtx.Scopes, req.Name) { + return nil, fmt.Errorf("insufficient_scope: scope validation failed for tool: %s", req.Name) + } + } + + return next(ctx, req) + } + } +} + +// ScopeBasedToolMiddleware is a generic function that accepts any type with the specified underlying function signature. +func ScopeBasedToolMiddleware[H ~func(context.Context, *protocol.CallToolRequest) (*protocol.CallToolResult, error)]( + m *Middleware, toolScopes map[string][]string, +) func(H) H { + return func(next H) H { + return func(ctx context.Context, req *protocol.CallToolRequest) (*protocol.CallToolResult, error) { + oauthCtx := GetOAuthContext(ctx) + if oauthCtx == nil { + return nil, fmt.Errorf("unauthorized: no OAuth context") + } + + requiredScopes, ok := toolScopes[req.Name] + if ok && !m.hasRequiredScopes(oauthCtx.Scopes, requiredScopes) { + return nil, fmt.Errorf("insufficient_scope: tool '%s' requires scopes %v, have %v", + req.Name, requiredScopes, oauthCtx.Scopes) + } + + return next(ctx, req) + } + } +} + +// extractTokenFromHeader extracts Bearer token from Authorization header +func (m *Middleware) extractTokenFromHeader(r *http.Request) string { + authHeader := r.Header.Get("Authorization") + if authHeader == "" { + return "" + } + + parts := strings.SplitN(authHeader, " ", 2) + if len(parts) != 2 || !strings.EqualFold(parts[0], "Bearer") { + return "" + } + + return parts[1] +} + +// hasRequiredScopes checks if user has all required scopes +func (m *Middleware) hasRequiredScopes(userScopes, requiredScopes []string) bool { + scopeMap := make(map[string]bool) + for _, scope := range userScopes { + scopeMap[scope] = true + } + + for _, required := range requiredScopes { + if !scopeMap[required] { + return false + } + } + + return true +} + +// writeUnauthorizedError writes a 401 Unauthorized response with WWW-Authenticate header (RFC 6750) +func (m *Middleware) writeUnauthorizedError(w http.ResponseWriter, errorCode, description string) { + // RFC 6750 Section 3: WWW-Authenticate header format + wwwAuth := fmt.Sprintf(`Bearer realm="%s"`, m.realm) + if errorCode != "" { + wwwAuth += fmt.Sprintf(`, error="%s"`, errorCode) + } + if description != "" { + wwwAuth += fmt.Sprintf(`, error_description="%s"`, escapeQuotes(description)) + } + + w.Header().Set("WWW-Authenticate", wwwAuth) + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusUnauthorized) + + fmt.Fprintf(w, `{"error":"%s","error_description":"%s"}`, + errorCode, escapeQuotes(description)) +} + +// writeInsufficientScopeError writes a 403 Forbidden response with WWW-Authenticate header +func (m *Middleware) writeInsufficientScopeError(w http.ResponseWriter, description string) { + wwwAuth := fmt.Sprintf(`Bearer realm="%s", error="insufficient_scope"`, m.realm) + if description != "" { + wwwAuth += fmt.Sprintf(`, error_description="%s"`, escapeQuotes(description)) + } + + w.Header().Set("WWW-Authenticate", wwwAuth) + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusForbidden) + + fmt.Fprintf(w, `{"error":"insufficient_scope","error_description":"%s"}`, + escapeQuotes(description)) +} + +// escapeQuotes escapes double quotes for use in header values +func escapeQuotes(s string) string { + return strings.ReplaceAll(s, `"`, `\"`) +} + +// OAuthContext holds OAuth information in context +type OAuthContext struct { + ClientID string + UserID string + Scopes []string + AccessToken *AccessToken +} + +type oauthContextKey struct{} + +// WithOAuthContext adds OAuth context to the context +func WithOAuthContext(ctx context.Context, oauthCtx *OAuthContext) context.Context { + return context.WithValue(ctx, oauthContextKey{}, oauthCtx) +} + +// GetOAuthContext retrieves OAuth context from context +func GetOAuthContext(ctx context.Context) *OAuthContext { + value := ctx.Value(oauthContextKey{}) + if value == nil { + return nil + } + return value.(*OAuthContext) +} + +// GetClientID is a helper to get client ID from context +func GetClientID(ctx context.Context) string { + if oauthCtx := GetOAuthContext(ctx); oauthCtx != nil { + return oauthCtx.ClientID + } + return "" +} + +// GetUserID is a helper to get user ID from context +func GetUserID(ctx context.Context) string { + if oauthCtx := GetOAuthContext(ctx); oauthCtx != nil { + return oauthCtx.UserID + } + return "" +} + +// GetScopes is a helper to get scopes from context +func GetScopes(ctx context.Context) []string { + if oauthCtx := GetOAuthContext(ctx); oauthCtx != nil { + return oauthCtx.Scopes + } + return nil +} + +// HasScope checks if context has a specific scope +func HasScope(ctx context.Context, scope string) bool { + scopes := GetScopes(ctx) + for _, s := range scopes { + if s == scope { + return true + } + } + return false +} + +// HasAnyScope checks if context has any of the specified scopes +func HasAnyScope(ctx context.Context, requiredScopes ...string) bool { + userScopes := GetScopes(ctx) + scopeMap := make(map[string]bool) + for _, scope := range userScopes { + scopeMap[scope] = true + } + + for _, required := range requiredScopes { + if scopeMap[required] { + return true + } + } + return false +} + +// HasAllScopes checks if context has all of the specified scopes +func HasAllScopes(ctx context.Context, requiredScopes ...string) bool { + userScopes := GetScopes(ctx) + scopeMap := make(map[string]bool) + for _, scope := range userScopes { + scopeMap[scope] = true + } + + for _, required := range requiredScopes { + if !scopeMap[required] { + return false + } + } + return true +} diff --git a/auth/pkce.go b/auth/pkce.go new file mode 100644 index 0000000..cdbba9e --- /dev/null +++ b/auth/pkce.go @@ -0,0 +1,170 @@ +package auth + +import ( + "crypto/rand" + "crypto/sha256" + "encoding/base64" + "errors" + "fmt" + "regexp" +) + +// PKCE (Proof Key for Code Exchange) support as required by OAuth 2.1 + +const ( + // CodeChallengeMethodPlain is not recommended but kept for compatibility + CodeChallengeMethodPlain = "plain" + // CodeChallengeMethodS256 is the recommended method + CodeChallengeMethodS256 = "S256" + + // Code verifier requirements (RFC 7636) + MinCodeVerifierLength = 43 + MaxCodeVerifierLength = 128 +) + +var ( + ErrInvalidCodeVerifier = errors.New("auth: invalid code verifier") + ErrInvalidCodeChallenge = errors.New("auth: invalid code challenge") + ErrInvalidCodeChallengeMethod = errors.New("auth: invalid code challenge method") + ErrCodeChallengeMismatch = errors.New("auth: code challenge verification failed") + ErrPKCERequired = errors.New("auth: PKCE is required for this client") +) + +// codeVerifierRegex validates code verifier format (RFC 7636) +var codeVerifierRegex = regexp.MustCompile(`^[A-Za-z0-9\-._~]{43,128}$`) + +// PKCEValidator handles PKCE validation +type PKCEValidator struct { + // RequirePKCE forces all clients to use PKCE + RequirePKCE bool + // RequireS256 forces S256 method (recommended for OAuth 2.1) + RequireS256 bool +} + +// NewPKCEValidator creates a new PKCE validator with OAuth 2.1 defaults +func NewPKCEValidator() *PKCEValidator { + return &PKCEValidator{ + RequirePKCE: true, // OAuth 2.1 requires PKCE + RequireS256: true, // OAuth 2.1 recommends S256 + } +} + +// GenerateCodeVerifier generates a cryptographically random code verifier +func GenerateCodeVerifier() (string, error) { + // Generate 32 random bytes (256 bits) + b := make([]byte, 32) + if _, err := rand.Read(b); err != nil { + return "", fmt.Errorf("failed to generate code verifier: %w", err) + } + + // Base64 URL encode without padding + verifier := base64.RawURLEncoding.EncodeToString(b) + return verifier, nil +} + +// GenerateCodeChallenge generates a code challenge from a verifier +func GenerateCodeChallenge(verifier string, method string) (string, error) { + if err := ValidateCodeVerifier(verifier); err != nil { + return "", err + } + + switch method { + case CodeChallengeMethodPlain: + return verifier, nil + case CodeChallengeMethodS256: + h := sha256.New() + h.Write([]byte(verifier)) + challenge := base64.RawURLEncoding.EncodeToString(h.Sum(nil)) + return challenge, nil + default: + return "", ErrInvalidCodeChallengeMethod + } +} + +// ValidateCodeVerifier validates the format of a code verifier +func ValidateCodeVerifier(verifier string) error { + if len(verifier) < MinCodeVerifierLength || len(verifier) > MaxCodeVerifierLength { + return fmt.Errorf("%w: length must be between %d and %d", + ErrInvalidCodeVerifier, MinCodeVerifierLength, MaxCodeVerifierLength) + } + + if !codeVerifierRegex.MatchString(verifier) { + return fmt.Errorf("%w: invalid characters", ErrInvalidCodeVerifier) + } + + return nil +} + +// ValidateCodeChallenge validates PKCE parameters during authorization +func (v *PKCEValidator) ValidateCodeChallenge(challenge, method string, isPublicClient bool) error { + // OAuth 2.1 requires PKCE for public clients + if isPublicClient && challenge == "" { + return ErrPKCERequired + } + + // If PKCE is required for all clients + if v.RequirePKCE && challenge == "" { + return ErrPKCERequired + } + + // If no challenge provided and it's not required, skip validation + if challenge == "" { + return nil + } + + // Validate challenge method + if method == "" { + method = CodeChallengeMethodPlain + } + + if method != CodeChallengeMethodPlain && method != CodeChallengeMethodS256 { + return ErrInvalidCodeChallengeMethod + } + + // OAuth 2.1 recommends S256 + if v.RequireS256 && method != CodeChallengeMethodS256 { + return fmt.Errorf("%w: S256 method required", ErrInvalidCodeChallengeMethod) + } + + return nil +} + +// VerifyCodeChallenge verifies the code verifier against the stored challenge +func VerifyCodeChallenge(verifier, challenge, method string) error { + if err := ValidateCodeVerifier(verifier); err != nil { + return err + } + + if challenge == "" { + // No challenge was stored, PKCE not used + return nil + } + + // Generate challenge from verifier + computedChallenge, err := GenerateCodeChallenge(verifier, method) + if err != nil { + return err + } + + // Compare challenges + if computedChallenge != challenge { + return ErrCodeChallengeMismatch + } + + return nil +} + +// GenerateCodeVerifierAndChallenge is a helper that generates both verifier and challenge +func GenerateCodeVerifierAndChallenge() (verifier, challenge string, err error) { + verifier, err = GenerateCodeVerifier() + if err != nil { + return "", "", err + } + + challenge, err = GenerateCodeChallenge(verifier, CodeChallengeMethodS256) + if err != nil { + return "", "", err + } + + return verifier, challenge, nil +} diff --git a/auth/server.go b/auth/server.go new file mode 100644 index 0000000..bdc56ce --- /dev/null +++ b/auth/server.go @@ -0,0 +1,560 @@ +package auth + +import ( + "context" + "errors" + "fmt" + "net/url" + "time" +) + +// Server implements OAuth 2.1 authorization server +type Server struct { + store Store + tokenGenerator *TokenGenerator + pkceValidator *PKCEValidator + config *ServerConfig + metadataProvider *MetadataProvider + jwksProvider *JWKSProvider +} + +// ServerConfig holds OAuth server configuration +type ServerConfig struct { + // Issuer identifier + Issuer string + + // MCP Server URL (NEW - required by MCP spec) + // The authorization base URL will be automatically derived from this + // by discarding the path component + MCPServerURL string + + // Supported grant types + SupportedGrantTypes []GrantType + + // Authorization code TTL (short-lived, OAuth 2.1 recommends < 10 minutes) + AuthorizationCodeTTL time.Duration + + // Require PKCE for all clients + RequirePKCE bool + + // Require S256 challenge method + RequireS256 bool + + // Enable refresh token rotation + RefreshTokenRotation bool + + // Maximum number of redirect URIs per client + MaxRedirectURIs int + + // RFC 8707: Resource indicators configuration + SupportedResources []string + + // Logger for server operations + Logger Logger + + // BaseURL is the derived authorization base URL + // This is computed from MCPServerURL by discarding the path component + baseURL string +} + +// Logger interface +type Logger interface { + Warnf(format string, args ...interface{}) + Infof(format string, args ...interface{}) + Errorf(format string, args ...interface{}) +} + +// DefaultLogger is a no-op logger +type DefaultLogger struct{} + +func (d DefaultLogger) Warnf(format string, args ...interface{}) {} +func (d DefaultLogger) Infof(format string, args ...interface{}) {} +func (d DefaultLogger) Errorf(format string, args ...interface{}) {} + +// DefaultServerConfig returns default OAuth 2.1 configuration +func DefaultServerConfig(issuer, mcpServerURL string) *ServerConfig { + return &ServerConfig{ + Issuer: issuer, + MCPServerURL: mcpServerURL, + SupportedGrantTypes: []GrantType{ + GrantTypeAuthorizationCode, + GrantTypeClientCredentials, + GrantTypeRefreshToken, + }, + AuthorizationCodeTTL: 5 * time.Minute, + RequirePKCE: true, // OAuth 2.1 requirement + RequireS256: true, // OAuth 2.1 best practice + RefreshTokenRotation: true, // OAuth 2.1 best practice + MaxRedirectURIs: 10, + SupportedResources: DefaultMCPResources, // MCP resources + Logger: DefaultLogger{}, + } +} + +// NewServer creates a new OAuth 2.1 server +func NewServer(store Store, signingKey []byte, config *ServerConfig) (*Server, error) { + if store == nil { + return nil, errors.New("store cannot be nil") + } + + if len(signingKey) < 32 { + return nil, errors.New("signing key must be at least 32 bytes") + } + + if config == nil { + return nil, errors.New("config cannot be nil") + } + + // Validate MCP Server URL (required for base URL derivation) + if config.MCPServerURL == "" { + return nil, errors.New("MCPServerURL is required for MCP compliance") + } + + // Validate MCP Server URL can be parsed + if _, err := url.Parse(config.MCPServerURL); err != nil { + return nil, fmt.Errorf("invalid MCPServerURL: %w", err) + } + + // Set default logger if not provided + if config.Logger == nil { + config.Logger = DefaultLogger{} + } + + tokenGenerator := NewTokenGenerator(signingKey, config.Issuer) + + pkceValidator := NewPKCEValidator() + pkceValidator.RequirePKCE = config.RequirePKCE + pkceValidator.RequireS256 = config.RequireS256 + + // Create metadata provider - it will automatically derive authorization base URL + metadataProvider, err := NewMetadataProvider(config) + if err != nil { + return nil, fmt.Errorf("failed to create metadata provider: %w", err) + } + + config.baseURL = metadataProvider.GetAuthorizationBaseURL() + + jwksProvider := NewJWKSProvider(signingKey, tokenGenerator.KeyID) + + config.Logger.Infof("OAuth server initialized with authorization base URL: %s (derived from MCP server URL: %s)", + metadataProvider.GetAuthorizationBaseURL(), config.MCPServerURL) + + return &Server{ + store: store, + tokenGenerator: tokenGenerator, + pkceValidator: pkceValidator, + config: config, + metadataProvider: metadataProvider, + jwksProvider: jwksProvider, + }, nil +} + +// GetMetadataProvider returns the metadata provider +func (s *Server) GetMetadataProvider() *MetadataProvider { + return s.metadataProvider +} + +// GetJWKSProvider returns the JWKS provider +func (s *Server) GetJWKSProvider() *JWKSProvider { + return s.jwksProvider +} + +// GetAuthorizationBaseURL returns the automatically derived authorization base URL +// This is derived from MCPServerURL per MCP spec requirements +func (s *Server) GetAuthorizationBaseURL() string { + return s.metadataProvider.GetAuthorizationBaseURL() +} + +// GetBaseURL returns the base URL from server config +func (c *ServerConfig) GetBaseURL() string { + return c.baseURL +} + +// HandleAuthorizationRequest handles OAuth 2.1 authorization requests +func (s *Server) HandleAuthorizationRequest(ctx context.Context, req *AuthorizationRequest, userID string) (string, error) { + // Validate response type (OAuth 2.1 only supports code) + if req.ResponseType != "code" { + return "", fmt.Errorf("%s: only 'code' response type is supported", ErrUnsupportedResponseType) + } + + // Validate client + client, err := s.store.GetClient(ctx, req.ClientID) + if err != nil { + return "", fmt.Errorf("%s: %w", ErrInvalidClient, err) + } + + // Validate redirect URI + if !s.validateRedirectURI(req.RedirectURI, client.RedirectURIs) { + return "", fmt.Errorf("%s: invalid redirect_uri", ErrInvalidRequest) + } + + // Parse and validate scopes + requestedScopes := parseScopes(req.Scope) + if !s.validateScopes(requestedScopes, client.Scopes) { + return "", fmt.Errorf("%s: invalid scopes", ErrInvalidScope) + } + + // RFC 8707: Validate resource indicators + if len(req.Resource) > 0 { + if err := s.validateResources(req.Resource, client.Resources); err != nil { + return "", fmt.Errorf("%s: %w", ErrInvalidTarget, err) + } + } + + // Validate PKCE (required for public clients in OAuth 2.1) + if err := s.pkceValidator.ValidateCodeChallenge(req.CodeChallenge, req.CodeChallengeMethod, client.IsPublic); err != nil { + return "", fmt.Errorf("%s: %w", ErrInvalidRequest, err) + } + + // Set default challenge method if not specified + challengeMethod := req.CodeChallengeMethod + if challengeMethod == "" && req.CodeChallenge != "" { + challengeMethod = CodeChallengeMethodPlain + } + + // Generate authorization code + code, err := GenerateAuthorizationCode() + if err != nil { + return "", fmt.Errorf("%s: failed to generate code", ErrServerError) + } + + // Store authorization code + authCode := &AuthorizationCode{ + Code: code, + ClientID: req.ClientID, + UserID: userID, + RedirectURI: req.RedirectURI, + Scopes: requestedScopes, + Resources: req.Resource, // RFC 8707 + CodeChallenge: req.CodeChallenge, + CodeChallengeMethod: challengeMethod, + ExpiresAt: time.Now().Add(s.config.AuthorizationCodeTTL), + CreatedAt: time.Now(), + Used: false, + } + + if err := s.store.SaveAuthorizationCode(ctx, authCode); err != nil { + return "", fmt.Errorf("%s: failed to save authorization code", ErrServerError) + } + + // Build redirect URI with code and state + redirectURL, err := url.Parse(req.RedirectURI) + if err != nil { + return "", fmt.Errorf("%s: invalid redirect URI", ErrInvalidRequest) + } + + query := redirectURL.Query() + query.Set("code", code) + if req.State != "" { + query.Set("state", req.State) + } + redirectURL.RawQuery = query.Encode() + + return redirectURL.String(), nil +} + +// HandleTokenRequest handles OAuth 2.1 token requests +func (s *Server) HandleTokenRequest(ctx context.Context, req *TokenRequest) (*TokenResponse, error) { + // Validate grant type + grantType := GrantType(req.GrantType) + if !s.isGrantTypeSupported(grantType) { + return nil, fmt.Errorf("%s: %s", ErrUnsupportedGrantType, req.GrantType) + } + + // Route to appropriate handler + switch grantType { + case GrantTypeAuthorizationCode: + return s.handleAuthorizationCodeGrant(ctx, req) + case GrantTypeClientCredentials: + return s.handleClientCredentialsGrant(ctx, req) + case GrantTypeRefreshToken: + return s.handleRefreshTokenGrant(ctx, req) + default: + return nil, fmt.Errorf("%s: %s", ErrUnsupportedGrantType, req.GrantType) + } +} + +// handleAuthorizationCodeGrant handles authorization code flow +func (s *Server) handleAuthorizationCodeGrant(ctx context.Context, req *TokenRequest) (*TokenResponse, error) { + // Validate required parameters + if req.Code == "" || req.RedirectURI == "" { + return nil, fmt.Errorf("%s: missing required parameters", ErrInvalidRequest) + } + + // Authenticate client + client, err := s.authenticateClient(ctx, req.ClientID, req.ClientSecret) + if err != nil { + return nil, err + } + + // Retrieve authorization code + authCode, err := s.store.GetAuthorizationCode(ctx, req.Code) + if err != nil { + if errors.Is(err, ErrNotFound) { + return nil, fmt.Errorf("%s: invalid authorization code", ErrInvalidGrant) + } + return nil, fmt.Errorf("%s: %w", ErrServerError, err) + } + + // Validate authorization code + if authCode.Used { + return nil, fmt.Errorf("%s: authorization code already used", ErrInvalidGrant) + } + + if time.Now().After(authCode.ExpiresAt) { + return nil, fmt.Errorf("%s: authorization code expired", ErrInvalidGrant) + } + + if authCode.ClientID != client.ID { + return nil, fmt.Errorf("%s: client mismatch", ErrInvalidGrant) + } + + if authCode.RedirectURI != req.RedirectURI { + return nil, fmt.Errorf("%s: redirect_uri mismatch", ErrInvalidGrant) + } + + // Verify PKCE if code challenge was used + if authCode.CodeChallenge != "" { + if req.CodeVerifier == "" { + return nil, fmt.Errorf("%s: code_verifier required", ErrInvalidRequest) + } + + if err := VerifyCodeChallenge(req.CodeVerifier, authCode.CodeChallenge, authCode.CodeChallengeMethod); err != nil { + return nil, fmt.Errorf("%s: %w", ErrInvalidGrant, err) + } + } + + // Invalidate authorization code (one-time use) + if err := s.store.InvalidateAuthorizationCode(ctx, req.Code); err != nil { + return nil, fmt.Errorf("%s: %w", ErrServerError, err) + } + + // Use resources from authorization code + resources := authCode.Resources + + // Generate tokens + return s.generateTokenPair(ctx, client.ID, authCode.UserID, authCode.Scopes, resources, true) +} + +// handleClientCredentialsGrant handles client credentials flow +func (s *Server) handleClientCredentialsGrant(ctx context.Context, req *TokenRequest) (*TokenResponse, error) { + // Authenticate client + client, err := s.authenticateClient(ctx, req.ClientID, req.ClientSecret) + if err != nil { + return nil, err + } + + // Parse requested scopes + requestedScopes := parseScopes(req.Scope) + if !s.validateScopes(requestedScopes, client.Scopes) { + return nil, fmt.Errorf("%s: invalid scopes", ErrInvalidScope) + } + + // RFC 8707: Validate resource indicators + resources := req.Resource + if len(resources) > 0 { + if err := s.validateResources(resources, client.Resources); err != nil { + return nil, fmt.Errorf("%s: %w", ErrInvalidTarget, err) + } + } + + // Generate access token only (no refresh token for client credentials) + return s.generateTokenPair(ctx, client.ID, "", requestedScopes, resources, false) +} + +// handleRefreshTokenGrant handles refresh token flow +func (s *Server) handleRefreshTokenGrant(ctx context.Context, req *TokenRequest) (*TokenResponse, error) { + if req.RefreshToken == "" { + return nil, fmt.Errorf("%s: missing refresh_token", ErrInvalidRequest) + } + + // Authenticate client + client, err := s.authenticateClient(ctx, req.ClientID, req.ClientSecret) + if err != nil { + return nil, err + } + + // Retrieve refresh token + refreshToken, err := s.store.GetRefreshToken(ctx, req.RefreshToken) + if err != nil { + if errors.Is(err, ErrNotFound) { + return nil, fmt.Errorf("%s: invalid refresh token", ErrInvalidGrant) + } + return nil, fmt.Errorf("%s: %w", ErrServerError, err) + } + + // Validate refresh token + if refreshToken.Revoked { + return nil, fmt.Errorf("%s: refresh token revoked", ErrInvalidGrant) + } + + if time.Now().After(refreshToken.ExpiresAt) { + return nil, fmt.Errorf("%s: refresh token expired", ErrInvalidGrant) + } + + if refreshToken.ClientID != client.ID { + return nil, fmt.Errorf("%s: client mismatch", ErrInvalidGrant) + } + + // Handle token rotation (OAuth 2.1 best practice) + if s.config.RefreshTokenRotation { + // Generate new refresh token + newRefreshToken, err := s.tokenGenerator.GenerateRefreshToken( + client.ID, + refreshToken.UserID, + refreshToken.Scopes, + refreshToken.Resources, + ) + if err != nil { + return nil, fmt.Errorf("%s: failed to generate refresh token", ErrServerError) + } + + newRefreshToken.RotationCount = refreshToken.RotationCount + 1 + + // Rotate tokens atomically + if err := s.store.RotateRefreshToken(ctx, req.RefreshToken, newRefreshToken); err != nil { + return nil, fmt.Errorf("%s: %w", ErrServerError, err) + } + + refreshToken = newRefreshToken + } + + // Generate new access token with same resources + accessToken, err := s.tokenGenerator.GenerateAccessToken( + client.ID, + refreshToken.UserID, + refreshToken.Scopes, + refreshToken.Resources, + ) + if err != nil { + return nil, fmt.Errorf("%s: failed to generate access token", ErrServerError) + } + + // Store access token + if err := s.store.SaveAccessToken(ctx, accessToken); err != nil { + return nil, fmt.Errorf("%s: %w", ErrServerError, err) + } + + return CreateTokenResponse(accessToken, refreshToken), nil +} + +// generateTokenPair generates access and optionally refresh token +func (s *Server) generateTokenPair( + ctx context.Context, + clientID, userID string, + scopes, resources []string, + includeRefreshToken bool, +) (*TokenResponse, error) { + accessToken, err := s.tokenGenerator.GenerateAccessToken(clientID, userID, scopes, resources) + if err != nil { + return nil, fmt.Errorf("%s: failed to generate access token", ErrServerError) + } + + // Store access token + if err := s.store.SaveAccessToken(ctx, accessToken); err != nil { + return nil, fmt.Errorf("%s: %w", ErrServerError, err) + } + + var refreshToken *RefreshToken + if includeRefreshToken { + refreshToken, err = s.tokenGenerator.GenerateRefreshToken(clientID, userID, scopes, resources) + if err != nil { + return nil, fmt.Errorf("%s: failed to generate refresh token", ErrServerError) + } + + // Store refresh token + if err := s.store.SaveRefreshToken(ctx, refreshToken); err != nil { + return nil, fmt.Errorf("%s: %w", ErrServerError, err) + } + } + + return CreateTokenResponse(accessToken, refreshToken), nil +} + +// authenticateClient authenticates the client +func (s *Server) authenticateClient(ctx context.Context, clientID, clientSecret string) (*Client, error) { + if clientID == "" { + return nil, fmt.Errorf("%s: missing client_id", ErrInvalidClient) + } + + client, err := s.store.GetClient(ctx, clientID) + if err != nil { + if errors.Is(err, ErrNotFound) { + return nil, fmt.Errorf("%s: client not found", ErrInvalidClient) + } + return nil, fmt.Errorf("%s: %w", ErrServerError, err) + } + + // Public clients don't require secret + if client.IsPublic { + return client, nil + } + + // Validate client secret for confidential clients + if err := s.store.ValidateClientSecret(ctx, clientID, clientSecret); err != nil { + return nil, fmt.Errorf("%s: invalid credentials", ErrInvalidClient) + } + + return client, nil +} + +// Helper functions + +func (s *Server) validateRedirectURI(redirectURI string, allowedURIs []string) bool { + for _, allowed := range allowedURIs { + if redirectURI == allowed { + return true + } + } + return false +} + +func (s *Server) validateScopes(requested, allowed []string) bool { + allowedMap := make(map[string]bool) + for _, scope := range allowed { + allowedMap[scope] = true + } + + for _, scope := range requested { + if !allowedMap[scope] { + return false + } + } + + return true +} + +// validateResources validates resource indicators (RFC 8707) +func (s *Server) validateResources(requested, allowed []string) error { + if len(allowed) == 0 { + // If client has no resource restrictions, check against server supported resources + allowed = s.config.SupportedResources + } + + if len(allowed) == 0 { + // No restrictions + return nil + } + + allowedMap := make(map[string]bool) + for _, res := range allowed { + allowedMap[res] = true + } + + for _, res := range requested { + if !allowedMap[res] { + return fmt.Errorf("resource not allowed: %s", res) + } + } + + return nil +} + +func (s *Server) isGrantTypeSupported(grantType GrantType) bool { + for _, supported := range s.config.SupportedGrantTypes { + if grantType == supported { + return true + } + } + return false +} diff --git a/auth/store.go b/auth/store.go new file mode 100644 index 0000000..65ef54a --- /dev/null +++ b/auth/store.go @@ -0,0 +1,81 @@ +package auth + +import ( + "context" + "errors" +) + +var ( + ErrNotFound = errors.New("auth: entity not found") + ErrAlreadyExists = errors.New("auth: entity already exists") + ErrExpired = errors.New("auth: entity expired") + ErrRevoked = errors.New("auth: entity revoked") + ErrInvalidStore = errors.New("auth: invalid store operation") +) + +// Store defines the interface for OAuth 2.1 data persistence +type Store interface { + ClientStore + AuthorizationCodeStore + TokenStore +} + +// ClientStore manages OAuth clients +type ClientStore interface { + // GetClient retrieves a client by ID + GetClient(ctx context.Context, clientID string) (*Client, error) + + // CreateClient creates a new client + CreateClient(ctx context.Context, client *Client) error + + // UpdateClient updates an existing client + UpdateClient(ctx context.Context, client *Client) error + + // DeleteClient deletes a client by ID + DeleteClient(ctx context.Context, clientID string) error + + // ValidateClientSecret validates client credentials + ValidateClientSecret(ctx context.Context, clientID, secret string) error +} + +// AuthorizationCodeStore manages authorization codes +type AuthorizationCodeStore interface { + // SaveAuthorizationCode saves an authorization code + SaveAuthorizationCode(ctx context.Context, code *AuthorizationCode) error + + // GetAuthorizationCode retrieves an authorization code + GetAuthorizationCode(ctx context.Context, code string) (*AuthorizationCode, error) + + // InvalidateAuthorizationCode marks a code as used + InvalidateAuthorizationCode(ctx context.Context, code string) error + + // CleanupExpiredAuthorizationCodes removes expired codes + CleanupExpiredAuthorizationCodes(ctx context.Context) error +} + +// TokenStore manages access and refresh tokens +type TokenStore interface { + // SaveAccessToken saves an access token + SaveAccessToken(ctx context.Context, token *AccessToken) error + + // GetAccessToken retrieves an access token + GetAccessToken(ctx context.Context, token string) (*AccessToken, error) + + // RevokeAccessToken revokes an access token + RevokeAccessToken(ctx context.Context, token string) error + + // SaveRefreshToken saves a refresh token + SaveRefreshToken(ctx context.Context, token *RefreshToken) error + + // GetRefreshToken retrieves a refresh token + GetRefreshToken(ctx context.Context, token string) (*RefreshToken, error) + + // RevokeRefreshToken revokes a refresh token + RevokeRefreshToken(ctx context.Context, token string) error + + // RotateRefreshToken creates a new refresh token and revokes the old one + RotateRefreshToken(ctx context.Context, oldToken string, newToken *RefreshToken) error + + // CleanupExpiredTokens removes expired tokens + CleanupExpiredTokens(ctx context.Context) error +} diff --git a/auth/third_party.go b/auth/third_party.go new file mode 100644 index 0000000..c6ccd45 --- /dev/null +++ b/auth/third_party.go @@ -0,0 +1,44 @@ +package auth + +// ThirdPartyOAuthConfig holds configuration for a third-party OAuth provider +type ThirdPartyOAuthConfig struct { + // Provider information + ProviderName string + + // OAuth 2.1 endpoints + AuthorizationURL string // Third-party /authorize endpoint + TokenURL string // Third-party /token endpoint + UserInfoURL string // Optional: to get user info + + // Client credentials (registered with third-party) + ClientID string + ClientSecret string + + // Redirect URI (your MCP server's callback) + RedirectURI string + + // Scopes to request from third-party + Scopes []string + + // Optional: PKCE support + UsePKCE bool +} + +// ThirdPartyTokenResponse represents token response from third-party +type ThirdPartyTokenResponse struct { + AccessToken string `json:"access_token"` + TokenType string `json:"token_type"` + ExpiresIn int `json:"expires_in"` + RefreshToken string `json:"refresh_token,omitempty"` + Scope string `json:"scope,omitempty"` + IDToken string `json:"id_token,omitempty"` // For OIDC +} + +// ThirdPartyUserInfo represents user info from third-party +type ThirdPartyUserInfo struct { + Subject string `json:"sub"` + Email string `json:"email,omitempty"` + Name string `json:"name,omitempty"` + Username string `json:"preferred_username,omitempty"` + Claims map[string]interface{} `json:"-"` // Additional claims +} diff --git a/auth/token.go b/auth/token.go new file mode 100644 index 0000000..42f1268 --- /dev/null +++ b/auth/token.go @@ -0,0 +1,256 @@ +package auth + +import ( + "crypto/rand" + "crypto/rsa" + "encoding/base64" + "fmt" + "math/big" + "strings" + "time" + + "github.com/golang-jwt/jwt/v4" +) + +// TokenGenerator handles token generation and validation +type TokenGenerator struct { + // JWT signing + SigningKey []byte + PrivateKey *rsa.PrivateKey // For RS256 + Algorithm string // "HS256" or "RS256" + + // Issuer identifier + Issuer string + // Key ID for JWKS + KeyID string + // Token expiration times + AccessTokenTTL time.Duration + RefreshTokenTTL time.Duration + // Token rotation settings + RotateRefreshTokens bool +} + +// TokenClaims represents JWT claims for access tokens +type TokenClaims struct { + jwt.RegisteredClaims + ClientID string `json:"client_id,omitempty"` + Scopes []string `json:"scopes,omitempty"` + TokenID string `json:"jti,omitempty"` +} + +// NewTokenGenerator creates a new token generator with default settings +func NewTokenGenerator(signingKey []byte, issuer string) *TokenGenerator { + return &TokenGenerator{ + SigningKey: signingKey, + Issuer: issuer, + KeyID: GenerateKeyID(signingKey), + Algorithm: "HS256", + AccessTokenTTL: 15 * time.Minute, + RefreshTokenTTL: 30 * 24 * time.Hour, + RotateRefreshTokens: true, + } +} + +// NewTokenGeneratorWithRSA creates a token generator with RS256 +func NewTokenGeneratorWithRSA(privateKey *rsa.PrivateKey, issuer string) *TokenGenerator { + return &TokenGenerator{ + PrivateKey: privateKey, + Issuer: issuer, + KeyID: GenerateRSAKeyID(&privateKey.PublicKey), + Algorithm: "RS256", + AccessTokenTTL: 15 * time.Minute, + RefreshTokenTTL: 30 * 24 * time.Hour, + RotateRefreshTokens: true, + } +} + +// GenerateRandomToken generates a cryptographically secure random token +func GenerateRandomToken(length int) (string, error) { + b := make([]byte, length) + if _, err := rand.Read(b); err != nil { + return "", fmt.Errorf("failed to generate random token: %w", err) + } + return base64.RawURLEncoding.EncodeToString(b), nil +} + +// GenerateAccessToken generates a new JWT access token with resource indicators +func (g *TokenGenerator) GenerateAccessToken(clientID, userID string, scopes []string, resources ...[]string) (*AccessToken, error) { + now := time.Now() + expiresAt := now.Add(g.AccessTokenTTL) + + // Generate token ID for tracking + tokenID, err := GenerateRandomToken(16) + if err != nil { + return nil, fmt.Errorf("failed to generate token ID: %w", err) + } + + // Flatten resources parameter + var resourceList []string + if len(resources) > 0 { + resourceList = resources[0] + } + + // RFC 8707: Set audience to resource indicators + audience := jwt.ClaimStrings{clientID} + if len(resourceList) > 0 { + audience = make(jwt.ClaimStrings, len(resourceList)) + copy(audience, resourceList) + } + + // Create JWT claims + claims := TokenClaims{ + RegisteredClaims: jwt.RegisteredClaims{ + Issuer: g.Issuer, + Subject: userID, + Audience: audience, + ExpiresAt: jwt.NewNumericDate(expiresAt), + IssuedAt: jwt.NewNumericDate(now), + NotBefore: jwt.NewNumericDate(now), + ID: tokenID, + }, + ClientID: clientID, + Scopes: scopes, + TokenID: tokenID, + } + + // Create and sign token based on algorithm + var token *jwt.Token + var tokenString string + + if g.Algorithm == "RS256" { + token = jwt.NewWithClaims(jwt.SigningMethodRS256, claims) + token.Header["kid"] = g.KeyID + tokenString, err = token.SignedString(g.PrivateKey) + } else { + token = jwt.NewWithClaims(jwt.SigningMethodHS256, claims) + token.Header["kid"] = g.KeyID + tokenString, err = token.SignedString(g.SigningKey) + } + + if err != nil { + return nil, fmt.Errorf("failed to sign token: %w", err) + } + + return &AccessToken{ + Token: tokenString, + TokenType: TokenTypeBearer, + ClientID: clientID, + UserID: userID, + Scopes: scopes, + Resources: resourceList, + ExpiresAt: expiresAt, + CreatedAt: now, + }, nil +} + +// GenerateRefreshToken generates a new refresh token +func (g *TokenGenerator) GenerateRefreshToken(clientID, userID string, scopes []string, resources ...[]string) (*RefreshToken, error) { + now := time.Now() + expiresAt := now.Add(g.RefreshTokenTTL) + + // Flatten resources parameter + var resourceList []string + if len(resources) > 0 { + resourceList = resources[0] + } + + // Generate opaque token + tokenString, err := GenerateRandomToken(32) + if err != nil { + return nil, fmt.Errorf("failed to generate refresh token: %w", err) + } + + return &RefreshToken{ + Token: tokenString, + ClientID: clientID, + UserID: userID, + Scopes: scopes, + Resources: resourceList, + ExpiresAt: expiresAt, + CreatedAt: now, + RotationCount: 0, + Revoked: false, + }, nil +} + +// GenerateAuthorizationCode generates a new authorization code +func GenerateAuthorizationCode() (string, error) { + return GenerateRandomToken(32) +} + +// ValidateAccessToken validates and parses a JWT access token +func (g *TokenGenerator) ValidateAccessToken(tokenString string) (*TokenClaims, error) { + token, err := jwt.ParseWithClaims(tokenString, &TokenClaims{}, func(token *jwt.Token) (interface{}, error) { + // Verify signing method + if g.Algorithm == "RS256" { + if _, ok := token.Method.(*jwt.SigningMethodRSA); !ok { + return nil, fmt.Errorf("unexpected signing method: %v", token.Header["alg"]) + } + // Return public key for verification + return &g.PrivateKey.PublicKey, nil + } else { + if _, ok := token.Method.(*jwt.SigningMethodHMAC); !ok { + return nil, fmt.Errorf("unexpected signing method: %v", token.Header["alg"]) + } + return g.SigningKey, nil + } + }) + + if err != nil { + return nil, fmt.Errorf("failed to parse token: %w", err) + } + + claims, ok := token.Claims.(*TokenClaims) + if !ok || !token.Valid { + return nil, fmt.Errorf("invalid token claims") + } + + // Verify issuer + if claims.Issuer != g.Issuer { + return nil, fmt.Errorf("invalid issuer") + } + + return claims, nil +} + +// ValidateTokenAudience validates that the token's audience matches the requested resource +func (g *TokenGenerator) ValidateTokenAudience(claims *TokenClaims, resource string) error { + if resource == "" { + return nil // No specific resource requested + } + + for _, aud := range claims.Audience { + if aud == resource { + return nil + } + } + + return fmt.Errorf("token audience does not include requested resource: %s", resource) +} + +// CreateTokenResponse creates a token response for the client +func CreateTokenResponse(accessToken *AccessToken, refreshToken *RefreshToken) *TokenResponse { + response := &TokenResponse{ + AccessToken: accessToken.Token, + TokenType: string(accessToken.TokenType), + ExpiresIn: int64(time.Until(accessToken.ExpiresAt).Seconds()), + } + + if refreshToken != nil { + response.RefreshToken = refreshToken.Token + } + + if len(accessToken.Scopes) > 0 { + response.Scope = strings.Join(accessToken.Scopes, " ") + } + + return response +} + +// GenerateRSAKeyID creates a key ID from RSA public key +func GenerateRSAKeyID(publicKey *rsa.PublicKey) string { + n := publicKey.N.Bytes() + e := big.NewInt(int64(publicKey.E)).Bytes() + combined := append(n, e...) + return GenerateKeyID(combined) +} diff --git a/auth/types.go b/auth/types.go new file mode 100644 index 0000000..182146c --- /dev/null +++ b/auth/types.go @@ -0,0 +1,182 @@ +package auth + +import ( + "time" +) + +// GrantType represents OAuth 2.1 grant types +type GrantType string + +const ( + // OAuth 2.1 compliant grant types + GrantTypeAuthorizationCode GrantType = "authorization_code" + GrantTypeClientCredentials GrantType = "client_credentials" + GrantTypeRefreshToken GrantType = "refresh_token" + GrantTypeDeviceCode GrantType = "urn:ietf:params:oauth:grant-type:device_code" +) + +// TokenType represents the type of token +type TokenType string + +const ( + TokenTypeBearer TokenType = "Bearer" +) + +// Client represents an OAuth 2.1 client application +type Client struct { + ID string `json:"client_id"` + Secret string `json:"client_secret,omitempty"` // Hashed + Name string `json:"name"` + RedirectURIs []string `json:"redirect_uris"` + GrantTypes []string `json:"grant_types"` + Scopes []string `json:"scopes"` + Resources []string `json:"resources,omitempty"` // RFC 8707: Allowed resource indicators + IsPublic bool `json:"is_public"` // Public clients must use PKCE + CreatedAt time.Time `json:"created_at"` + UpdatedAt time.Time `json:"updated_at"` + Metadata map[string]interface{} `json:"metadata,omitempty"` +} + +// AuthorizationCode represents an OAuth 2.1 authorization code +type AuthorizationCode struct { + Code string `json:"code"` + ClientID string `json:"client_id"` + UserID string `json:"user_id"` + RedirectURI string `json:"redirect_uri"` + Scopes []string `json:"scopes"` + Resources []string `json:"resources,omitempty"` // RFC 8707 + CodeChallenge string `json:"code_challenge"` + CodeChallengeMethod string `json:"code_challenge_method"` + ExpiresAt time.Time `json:"expires_at"` + CreatedAt time.Time `json:"created_at"` + Used bool `json:"used"` + Metadata map[string]interface{} `json:"metadata,omitempty"` +} + +// AccessToken represents an OAuth 2.1 access token +type AccessToken struct { + Token string `json:"token"` + TokenType TokenType `json:"token_type"` + ClientID string `json:"client_id"` + UserID string `json:"user_id,omitempty"` + Scopes []string `json:"scopes"` + Resources []string `json:"resources,omitempty"` // RFC 8707: Token audience + ExpiresAt time.Time `json:"expires_at"` + CreatedAt time.Time `json:"created_at"` + RefreshToken string `json:"refresh_token,omitempty"` + Metadata map[string]interface{} `json:"metadata,omitempty"` +} + +// RefreshToken represents an OAuth 2.1 refresh token +type RefreshToken struct { + Token string `json:"token"` + ClientID string `json:"client_id"` + UserID string `json:"user_id"` + Scopes []string `json:"scopes"` + Resources []string `json:"resources,omitempty"` // RFC 8707 + ExpiresAt time.Time `json:"expires_at"` + CreatedAt time.Time `json:"created_at"` + LastUsedAt *time.Time `json:"last_used_at,omitempty"` + RotationCount int `json:"rotation_count"` + Revoked bool `json:"revoked"` + Metadata map[string]interface{} `json:"metadata,omitempty"` +} + +// TokenResponse represents the OAuth 2.1 token endpoint response +type TokenResponse struct { + AccessToken string `json:"access_token"` + TokenType string `json:"token_type"` + ExpiresIn int64 `json:"expires_in"` + RefreshToken string `json:"refresh_token,omitempty"` + Scope string `json:"scope,omitempty"` +} + +// TokenRequest represents the OAuth 2.1 token endpoint request +type TokenRequest struct { + GrantType string `json:"grant_type" form:"grant_type"` + Code string `json:"code,omitempty" form:"code"` + RedirectURI string `json:"redirect_uri,omitempty" form:"redirect_uri"` + ClientID string `json:"client_id,omitempty" form:"client_id"` + ClientSecret string `json:"client_secret,omitempty" form:"client_secret"` + RefreshToken string `json:"refresh_token,omitempty" form:"refresh_token"` + Scope string `json:"scope,omitempty" form:"scope"` + Resource []string `json:"resource,omitempty" form:"resource"` // RFC 8707: Can be repeated + CodeVerifier string `json:"code_verifier,omitempty" form:"code_verifier"` + Username string `json:"username,omitempty" form:"username"` // Not in OAuth 2.1, kept for compatibility + Password string `json:"password,omitempty" form:"password"` // Not in OAuth 2.1, kept for compatibility +} + +// AuthorizationRequest represents the OAuth 2.1 authorization endpoint request +type AuthorizationRequest struct { + ResponseType string `json:"response_type" form:"response_type"` + ClientID string `json:"client_id" form:"client_id"` + RedirectURI string `json:"redirect_uri" form:"redirect_uri"` + Scope string `json:"scope,omitempty" form:"scope"` + State string `json:"state,omitempty" form:"state"` + Resource []string `json:"resource,omitempty" form:"resource"` // RFC 8707 + CodeChallenge string `json:"code_challenge,omitempty" form:"code_challenge"` + CodeChallengeMethod string `json:"code_challenge_method,omitempty" form:"code_challenge_method"` +} + +// IntrospectionRequest represents the token introspection request +type IntrospectionRequest struct { + Token string `json:"token" form:"token"` + TokenTypeHint string `json:"token_type_hint,omitempty" form:"token_type_hint"` +} + +// IntrospectionResponse represents the token introspection response +type IntrospectionResponse struct { + Active bool `json:"active"` + Scope string `json:"scope,omitempty"` + ClientID string `json:"client_id,omitempty"` + Username string `json:"username,omitempty"` + TokenType string `json:"token_type,omitempty"` + ExpiresAt int64 `json:"exp,omitempty"` + IssuedAt int64 `json:"iat,omitempty"` + Subject string `json:"sub,omitempty"` + Audience []string `json:"aud,omitempty"` // RFC 8707: Resource indicators +} + +// ErrorResponse represents an OAuth 2.1 error response +type ErrorResponse struct { + Error string `json:"error"` + ErrorDescription string `json:"error_description,omitempty"` + ErrorURI string `json:"error_uri,omitempty"` +} + +// OAuth 2.1 error codes +const ( + ErrInvalidRequest = "invalid_request" + ErrInvalidClient = "invalid_client" + ErrInvalidGrant = "invalid_grant" + ErrUnauthorizedClient = "unauthorized_client" + ErrUnsupportedGrantType = "unsupported_grant_type" + ErrInvalidScope = "invalid_scope" + ErrInvalidTarget = "invalid_target" // RFC 8707 + ErrAccessDenied = "access_denied" + ErrUnsupportedResponseType = "unsupported_response_type" + ErrServerError = "server_error" + ErrTemporarilyUnavailable = "temporarily_unavailable" + ErrInvalidToken = "invalid_token" +) + +// MCP Protocol Constants +const ( + // CurrentMCPVersion is the MCP protocol version this implementation supports + CurrentMCPVersion = "2024-11-05" + + // MCPProtocolVersionHeader is the header name for MCP protocol version + MCPProtocolVersionHeader = "MCP-Protocol-Version" + + // MCP Resource URI schemes (RFC 8707) + MCPResourceSchemeTools = "mcp://tools" + MCPResourceSchemePrompts = "mcp://prompts" + MCPResourceSchemeResources = "mcp://resources" +) + +// DefaultMCPResources indicators +var DefaultMCPResources = []string{ + MCPResourceSchemeTools, + MCPResourceSchemePrompts, + MCPResourceSchemeResources, +} diff --git a/auth/validation.go b/auth/validation.go new file mode 100644 index 0000000..67da189 --- /dev/null +++ b/auth/validation.go @@ -0,0 +1,197 @@ +package auth + +import ( + "fmt" + "net/url" + "strings" +) + +// validateRedirectURI validates that the redirect URI exactly matches one of the registered URIs +// OAuth 2.1 requires exact matching, no wildcards allowed +func validateRedirectURI(allowedURIs []string, requestedURI string) error { + if requestedURI == "" { + return fmt.Errorf("redirect_uri is required") + } + + for _, allowed := range allowedURIs { + if allowed == requestedURI { + return nil + } + } + + return fmt.Errorf("%s: redirect_uri not registered for this client", ErrInvalidRequest) +} + +// validateEndpointSecurity ensures HTTPS is used (except for localhost) +// OAuth 2.1 requirement: All endpoints must use HTTPS except localhost for development +// MCP spec: "Redirect URIs MUST be either localhost URLs or HTTPS URLs" +func validateEndpointSecurity(uri string) error { + if uri == "" { + return fmt.Errorf("URI cannot be empty") + } + + parsedURL, err := url.Parse(uri) + if err != nil { + return fmt.Errorf("invalid URI: %w", err) + } + + // MCP spec: HTTPS or localhost required + if parsedURL.Scheme == "https" { + // HTTPS always allowed + return nil + } + + if parsedURL.Scheme == "http" { + // HTTP only allowed for localhost + if isLocalhost(parsedURL.Hostname()) { + return nil + } + return fmt.Errorf("%s: HTTPS required for non-localhost endpoints (MCP spec requirement)", ErrInvalidRequest) + } + + // Other schemes (e.g., custom schemes for native apps) may be allowed + // but HTTPS/localhost validation is MCP requirement for web endpoints + + // Fragment components are not allowed in redirect URIs + if parsedURL.Fragment != "" { + return fmt.Errorf("%s: redirect URIs must not contain fragment components", ErrInvalidRequest) + } + + return nil +} + +// isLocalhost checks if a host is localhost or 127.0.0.1 or ::1 (IPv6) +// Updated to properly handle IPv6 localhost per MCP spec +func isLocalhost(host string) bool { + if host == "" { + return false + } + + // Normalize host by removing brackets for IPv6 + normalizedHost := strings.Trim(host, "[]") + + // Check common localhost representations + switch normalizedHost { + case "localhost", + "127.0.0.1", // IPv4 loopback + "::1", // IPv6 loopback (short form) + "0:0:0:0:0:0:0:1": // IPv6 loopback (long form) + return true + } + + // Check IPv6 loopback variations + if strings.HasPrefix(normalizedHost, "::ffff:127.") { + // IPv4-mapped IPv6 addresses (::ffff:127.0.0.1) + return true + } + + return false +} + +// validateState ensures state parameter is present (CSRF protection) +// OAuth 2.1 strongly recommends state parameter for all authorization requests +func validateState(state string) error { + if state == "" { + return fmt.Errorf("%s: state parameter is required for CSRF protection", ErrInvalidRequest) + } + + // State should be at least 8 characters for adequate entropy + if len(state) < 8 { + return fmt.Errorf("%s: state parameter too short (minimum 8 characters)", ErrInvalidRequest) + } + + return nil +} + +// validateScope validates requested scopes +func validateScope(requestedScopes []string, allowedScopes []string) error { + if len(requestedScopes) == 0 { + return nil // Empty scope is allowed + } + + allowedMap := make(map[string]bool) + for _, scope := range allowedScopes { + allowedMap[scope] = true + } + + for _, requested := range requestedScopes { + if !allowedMap[requested] { + return fmt.Errorf("%s: scope '%s' is not allowed", ErrInvalidScope, requested) + } + } + + return nil +} + +// validateClientRedirectURIs validates all redirect URIs for a client during registration +func validateClientRedirectURIs(redirectURIs []string, applicationType string) error { + if len(redirectURIs) == 0 { + return fmt.Errorf("%s: at least one redirect_uri is required", ErrInvalidRequest) + } + + for _, uri := range redirectURIs { + if err := validateRedirectURIFormat(uri, applicationType); err != nil { + return err + } + } + + return nil +} + +// validateRedirectURIFormat validates the format of a redirect URI based on application type +// Updated to fully comply with MCP spec: "Redirect URIs MUST be either localhost URLs or HTTPS URLs" +func validateRedirectURIFormat(uri string, applicationType string) error { + parsedURL, err := url.Parse(uri) + if err != nil { + return fmt.Errorf("%s: invalid redirect URI format: %w", ErrInvalidRequest, err) + } + + // Fragment components are forbidden + if parsedURL.Fragment != "" { + return fmt.Errorf("%s: redirect URIs must not contain fragment components", ErrInvalidRequest) + } + + // MCP spec: "Redirect URIs MUST be either localhost URLs or HTTPS URLs" + switch parsedURL.Scheme { + case "https": + // HTTPS always allowed per MCP spec + return nil + + case "http": + // HTTP only allowed for localhost per MCP spec + hostname := parsedURL.Hostname() + if !isLocalhost(hostname) { + return fmt.Errorf("%s: HTTP redirect URIs only allowed for localhost (MCP spec requirement)", ErrInvalidRequest) + } + return nil + + default: + // For native apps, custom schemes may be allowed + if applicationType == "native" { + // Custom schemes like myapp:// are allowed for native apps + // but MCP spec primarily targets web/server scenarios + return nil + } + + return fmt.Errorf("%s: redirect URI must use HTTPS or localhost HTTP (MCP spec requirement)", ErrInvalidRequest) + } +} + +// parseScopes parses a space-separated scope string into a slice +func parseScopes(scopeString string) []string { + if scopeString == "" { + return []string{} + } + + scopes := strings.Split(scopeString, " ") + result := make([]string, 0, len(scopes)) + + for _, scope := range scopes { + trimmed := strings.TrimSpace(scope) + if trimmed != "" { + result = append(result, trimmed) + } + } + + return result +} diff --git a/client/client.go b/client/client.go index 7ab56e8..9f4affe 100644 --- a/client/client.go +++ b/client/client.go @@ -45,6 +45,19 @@ func WithLogger(logger pkg.Logger) Option { } } +func WithAuthAndRefresher( + accessToken, refreshToken string, + expiry time.Time, + refresher func(string) (string, string, time.Time, error), +) Option { + return func(c *Client) { + c.accessToken = accessToken + c.refreshToken = refreshToken + c.tokenExpiry = expiry + c.tokenRefresher = refresher + } +} + type Client struct { transport transport.ClientTransport @@ -74,6 +87,12 @@ type Client struct { closed chan struct{} logger pkg.Logger + + accessToken string + refreshToken string + tokenExpiry time.Time + tokenMutex sync.RWMutex + tokenRefresher func(refreshToken string) (accessToken, newRefreshToken string, expiry time.Time, err error) } func NewClient(t transport.ClientTransport, opts ...Option) (*Client, error) { @@ -94,6 +113,13 @@ func NewClient(t transport.ClientTransport, opts ...Option) (*Client, error) { opt(client) } + if tp, ok := t.(interface{ SetTokenProvider(func() string) }); ok { + tp.SetTokenProvider(func() string { + token, _ := client.GetAccessToken() + return token + }) + } + if client.notifyHandler == nil { h := NewBaseNotifyHandler() h.Logger = client.logger @@ -160,3 +186,38 @@ func (client *Client) sessionDetection() { client.logger.Warnf("mcp client ping server fail: %v", err) } } + +func (c *Client) GetAccessToken() (string, bool) { + c.tokenMutex.RLock() + + if time.Now().Before(c.tokenExpiry) { + token := c.accessToken + c.tokenMutex.RUnlock() + return token, true + } + + c.tokenMutex.RUnlock() + + if c.tokenRefresher == nil { + return "", false + } + + c.tokenMutex.Lock() + defer c.tokenMutex.Unlock() + + if time.Now().Before(c.tokenExpiry) { + return c.accessToken, true + } + + newAccessToken, newRefreshToken, newExpiry, err := c.tokenRefresher(c.refreshToken) + if err != nil { + c.logger.Errorf("Failed to refresh token: %v", err) + return "", false + } + + c.accessToken = newAccessToken + c.refreshToken = newRefreshToken + c.tokenExpiry = newExpiry + + return newAccessToken, true +} diff --git a/examples/oauth_example/client/main.go b/examples/oauth_example/client/main.go new file mode 100644 index 0000000..d63a2d5 --- /dev/null +++ b/examples/oauth_example/client/main.go @@ -0,0 +1,277 @@ +package main + +import ( + "context" + "crypto/rand" + "crypto/sha256" + "encoding/base64" + "encoding/json" + "fmt" + "io" + "log" + "net/http" + "net/url" + "strings" + "time" + + "github.com/ThinkInAIXYZ/go-mcp/client" + "github.com/ThinkInAIXYZ/go-mcp/protocol" + "github.com/ThinkInAIXYZ/go-mcp/transport" +) + +func main() { + log.Println("šŸš€ OAuth Client Demo") + + // Step 1: OAuth Authorization Flow + log.Println("\nšŸ“‹ Step 1: Getting OAuth token...") + codeVerifier, codeChallenge := generatePKCE() + state := randomString(16) + + // Start callback server + codeChan := make(chan string, 1) + go startCallback(codeChan) + time.Sleep(500 * time.Millisecond) + + // Build authorization URL + authURL := fmt.Sprintf( + "http://localhost:8080/oauth/authorize?response_type=code&client_id=demo-client&redirect_uri=http://localhost:9999/callback&scope=read&state=%s&code_challenge=%s&code_challenge_method=S256", + state, codeChallenge, + ) + + log.Printf("šŸ“‹ Authorization URL:\n %s\n", authURL) + log.Println("ā³ Waiting for authorization...") + + // Simulate a browser visit + go func() { + time.Sleep(100 * time.Millisecond) + resp, err := http.Get(authURL) + if err != nil { + log.Printf("āŒ Authorization request failed: %v", err) + return + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK && resp.StatusCode != http.StatusFound { + body, _ := io.ReadAll(resp.Body) + log.Printf("āŒ Authorization failed (%d): %s", resp.StatusCode, string(body)) + } + }() + + // Add timeout protection + select { + case code := <-codeChan: + if code == "" { + log.Fatal("āŒ Received empty authorization code") + } + log.Printf("āœ… Got authorization code: %s...\n", code[:min(10, len(code))]) + + // Exchange authorization code for tokens + accessToken, refreshToken := exchangeToken(code, codeVerifier) + if accessToken == "" { + log.Fatal("āŒ Failed to get access token") + } + log.Printf("āœ… Got access token: %s...\n", accessToken[:min(20, len(accessToken))]) + + // Step 2: Connect to MCP with token + log.Println("\nšŸ“‹ Step 2: Connecting to MCP with token...") + + sseTransport, err := transport.NewSSEClientTransport( + "http://localhost:8080/mcp", + transport.WithSSEClientTransportAuthToken(func() string { + return accessToken + }), + ) + if err != nil { + log.Fatalf("āŒ Failed to create transport: %v", err) + } + + mcpClient, err := client.NewClient(sseTransport, + client.WithAuthAndRefresher( + accessToken, + refreshToken, + time.Now().Add(15*time.Minute), + func(rt string) (string, string, time.Time, error) { + log.Println("šŸ”„ Refreshing token...") + newAccess, newRefresh := refreshToken_(rt) + return newAccess, newRefresh, time.Now().Add(15 * time.Minute), nil + }, + ), + ) + if err != nil { + log.Fatalf("āŒ Failed to create MCP client: %v", err) + } + defer mcpClient.Close() + + log.Println("āœ… Connected to MCP Server") + + // Step 3: Call a registered tool + log.Println("\nšŸ“‹ Step 3: Calling tool...") + + ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) + defer cancel() + + result, err := mcpClient.CallTool(ctx, &protocol.CallToolRequest{ + Name: "echo", + Arguments: map[string]interface{}{ + "message": "Hello OAuth!", + }, + }) + + if err != nil { + log.Fatalf("āŒ Tool call failed: %v", err) + } + + if len(result.Content) > 0 { + if textContent, ok := result.Content[0].(*protocol.TextContent); ok { + log.Printf("āœ… Result: %s\n", textContent.Text) + } + } + + log.Println("\nāœ… Demo completed!") + log.Println(" Press Ctrl+C to exit") + select {} + + case <-time.After(10 * time.Second): + log.Fatal("āŒ Timeout waiting for authorization code") + } +} + +// ========== Helper Functions ========== + +// min returns the smaller of two integers. +func min(a, b int) int { + if a < b { + return a + } + return b +} + +// generatePKCE creates a PKCE code verifier and its corresponding challenge. +func generatePKCE() (verifier, challenge string) { + b := make([]byte, 32) + rand.Read(b) + verifier = base64.RawURLEncoding.EncodeToString(b) + + h := sha256.Sum256([]byte(verifier)) + challenge = base64.RawURLEncoding.EncodeToString(h[:]) + return +} + +// randomString generates a random URL-safe string of length n. +func randomString(n int) string { + b := make([]byte, n) + rand.Read(b) + result := base64.RawURLEncoding.EncodeToString(b) + if len(result) > n { + return result[:n] + } + return result +} + +// startCallback starts a local HTTP server to handle OAuth redirect responses. +func startCallback(codeChan chan string) { + http.HandleFunc("/callback", func(w http.ResponseWriter, r *http.Request) { + code := r.URL.Query().Get("code") + errorCode := r.URL.Query().Get("error") + + if errorCode != "" { + errorDesc := r.URL.Query().Get("error_description") + log.Printf("āŒ Authorization error: %s - %s", errorCode, errorDesc) + fmt.Fprintf(w, "āŒ Authorization failed: %s", errorDesc) + codeChan <- "" + return + } + + if code == "" { + log.Println("āŒ No authorization code received") + fmt.Fprintf(w, "āŒ No authorization code received") + codeChan <- "" + return + } + + log.Println("āœ… Callback received") + fmt.Fprintf(w, "āœ… Authorization successful! You can close this window.") + codeChan <- code + }) + + log.Println("🌐 Callback server started at http://localhost:9999") + if err := http.ListenAndServe(":9999", nil); err != nil { + log.Printf("āŒ Callback server error: %v", err) + } +} + +// exchangeToken exchanges the authorization code for an access and refresh token. +func exchangeToken(code, verifier string) (accessToken, refreshToken string) { + data := url.Values{} + data.Set("grant_type", "authorization_code") + data.Set("code", code) + data.Set("redirect_uri", "http://localhost:9999/callback") + data.Set("client_id", "demo-client") + data.Set("code_verifier", verifier) + + resp, err := http.Post( + "http://localhost:8080/oauth/token", + "application/x-www-form-urlencoded", + strings.NewReader(data.Encode()), + ) + if err != nil { + log.Printf("āŒ Token request failed: %v", err) + return "", "" + } + defer resp.Body.Close() + + body, _ := io.ReadAll(resp.Body) + + if resp.StatusCode != http.StatusOK { + log.Printf("āŒ Token endpoint error (%d): %s", resp.StatusCode, string(body)) + return "", "" + } + + var tokenResp struct { + AccessToken string `json:"access_token"` + RefreshToken string `json:"refresh_token"` + Error string `json:"error"` + ErrorDesc string `json:"error_description"` + } + + if err := json.Unmarshal(body, &tokenResp); err != nil { + log.Printf("āŒ Failed to parse token response: %v", err) + return "", "" + } + + if tokenResp.Error != "" { + log.Printf("āŒ Token error: %s - %s", tokenResp.Error, tokenResp.ErrorDesc) + return "", "" + } + + return tokenResp.AccessToken, tokenResp.RefreshToken +} + +// refreshToken_ refreshes the access token using a refresh token. +func refreshToken_(refreshToken string) (accessToken, newRefreshToken string) { + data := url.Values{} + data.Set("grant_type", "refresh_token") + data.Set("refresh_token", refreshToken) + data.Set("client_id", "demo-client") + + resp, err := http.Post( + "http://localhost:8080/oauth/token", + "application/x-www-form-urlencoded", + strings.NewReader(data.Encode()), + ) + if err != nil { + log.Printf("āŒ Token refresh failed: %v", err) + return "", "" + } + defer resp.Body.Close() + + body, _ := io.ReadAll(resp.Body) + + var tokenResp struct { + AccessToken string `json:"access_token"` + RefreshToken string `json:"refresh_token"` + } + json.Unmarshal(body, &tokenResp) + + return tokenResp.AccessToken, tokenResp.RefreshToken +} diff --git a/examples/oauth_example/server/main.go b/examples/oauth_example/server/main.go new file mode 100644 index 0000000..7ba7e7d --- /dev/null +++ b/examples/oauth_example/server/main.go @@ -0,0 +1,145 @@ +package main + +import ( + "context" + "fmt" + "log" + "net/http" + "time" + + "github.com/ThinkInAIXYZ/go-mcp/auth" + "github.com/ThinkInAIXYZ/go-mcp/auth/memory" + "github.com/ThinkInAIXYZ/go-mcp/protocol" + "github.com/ThinkInAIXYZ/go-mcp/server" + "github.com/ThinkInAIXYZ/go-mcp/transport" +) + +func main() { + // 1. Create the OAuth server + store := memory.NewStore() + signingKey := []byte("my-secret-key-must-be-32-bytes!!") + + authServer, _ := auth.NewServer( + store, + signingKey, + auth.DefaultServerConfig( + "http://localhost:8080", + "http://localhost:8080/mcp", + ), + ) + + // 2. Register a demo client + store.CreateClient(context.Background(), &auth.Client{ + ID: "demo-client", + Name: "Demo Client", + RedirectURIs: []string{"http://localhost:9999/callback"}, + GrantTypes: []string{"authorization_code", "refresh_token"}, + Scopes: []string{"read", "write"}, + IsPublic: true, + CreatedAt: time.Now(), + UpdatedAt: time.Now(), + }) + + // 3. Create SSE transport using NewSSEServerTransportAndHandler + sseTransport, sseHandler, _ := transport.NewSSEServerTransportAndHandler( + "http://localhost:8080/mcp/message", + ) + + // 4. Create the MCP server (integrated with OAuth) + mcpServer, _ := server.NewServer(sseTransport, + server.WithAuth(authServer, map[string][]string{ + "echo": {"read"}, + }), + ) + + // 5. Register a simple test tool + echoTool, _ := protocol.NewTool("echo", "Echo back your message", struct { + Message string `json:"message" jsonschema:"required,description=Your message"` + }{}) + + mcpServer.RegisterTool(echoTool, func(ctx context.Context, req *protocol.CallToolRequest) (*protocol.CallToolResult, error) { + msg := req.Arguments["message"].(string) + userID := auth.GetUserID(ctx) + + return &protocol.CallToolResult{ + Content: []protocol.Content{ + &protocol.TextContent{ + Type: "text", + Text: fmt.Sprintf("Echo: %s (user: %s)", msg, userID), + }, + }, + }, nil + }) + + // 6. Manually create HTTP routes (including auto-approval for OAuth authorization) + mux := http.NewServeMux() + + // āœ… MCP SSE endpoints + mux.Handle("/mcp", sseHandler.HandleSSE()) + mux.Handle("/mcp/message", sseHandler.HandleMessage()) + + // āœ… OAuth endpoints (custom authorization handler - auto-approve mode) + oauthHandler := mcpServer.GetOAuthHandler() + + // Register the auto-approval authorization endpoint + mux.HandleFunc("/oauth/authorize", func(w http.ResponseWriter, r *http.Request) { + handleAutoApproveAuthorize(w, r, authServer) + }) + + // Other standard OAuth endpoints + mux.HandleFunc("/oauth/token", oauthHandler.HandleToken) + mux.HandleFunc("/oauth/introspect", oauthHandler.HandleIntrospection) + mux.HandleFunc("/oauth/revoke", oauthHandler.HandleRevocation) + + // Metadata endpoints + mux.Handle("/.well-known/oauth-authorization-server", authServer.GetMetadataProvider()) + mux.Handle("/.well-known/jwks.json", authServer.GetJWKSProvider()) + + // 7. Start the HTTP server and MCP server + go mcpServer.Run() + + log.Println("šŸš€ Server started at http://localhost:8080") + log.Println(" MCP SSE: http://localhost:8080/mcp") + log.Println(" OAuth: http://localhost:8080/oauth/authorize") + log.Println("") + log.Println("šŸ‘‰ Run client: go run client/main.go") + + if err := http.ListenAndServe(":8080", mux); err != nil { + log.Fatalf("Server failed: %v", err) + } +} + +// āœ… Auto-approve authorization requests (for demo/testing purposes) +func handleAutoApproveAuthorize(w http.ResponseWriter, r *http.Request, authServer *auth.Server) { + // Parse authorization request + req := &auth.AuthorizationRequest{ + ResponseType: r.URL.Query().Get("response_type"), + ClientID: r.URL.Query().Get("client_id"), + RedirectURI: r.URL.Query().Get("redirect_uri"), + Scope: r.URL.Query().Get("scope"), + State: r.URL.Query().Get("state"), + CodeChallenge: r.URL.Query().Get("code_challenge"), + CodeChallengeMethod: r.URL.Query().Get("code_challenge_method"), + Resource: r.URL.Query()["resource"], + } + + log.Printf("šŸ“ Authorization request: client=%s, scope=%s", req.ClientID, req.Scope) + + // āœ… Auto-approve: use client ID as the user ID + // In production, you should: + // 1. Authenticate the user (check session/cookie) + // 2. Show a consent screen to ask for user approval + // 3. Generate an authorization code only after user consent + userID := req.ClientID + "-user" + + // Generate authorization code and redirect + redirectURL, err := authServer.HandleAuthorizationRequest(r.Context(), req, userID) + if err != nil { + log.Printf("āŒ Authorization failed: %v", err) + http.Error(w, err.Error(), http.StatusBadRequest) + return + } + + log.Printf("āœ… Auto-approved for user: %s", userID) + http.Redirect(w, r, redirectURL, http.StatusFound) +} diff --git a/go.mod b/go.mod index 8df79ba..38679ec 100644 --- a/go.mod +++ b/go.mod @@ -3,6 +3,7 @@ module github.com/ThinkInAIXYZ/go-mcp go 1.18 require ( + github.com/golang-jwt/jwt/v4 v4.5.2 github.com/google/uuid v1.6.0 github.com/orcaman/concurrent-map/v2 v2.0.1 github.com/tidwall/gjson v1.18.0 diff --git a/go.sum b/go.sum index e3a61aa..bfb01e4 100644 --- a/go.sum +++ b/go.sum @@ -1,3 +1,5 @@ +github.com/golang-jwt/jwt/v4 v4.5.2 h1:YtQM7lnr8iZ+j5q71MGKkNw9Mn7AjHM68uc9g5fXeUI= +github.com/golang-jwt/jwt/v4 v4.5.2/go.mod h1:m21LjoU+eqJr34lmDMbreY2eSTRJ1cv77w39/MY0Ch0= github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0= github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= github.com/orcaman/concurrent-map/v2 v2.0.1 h1:jOJ5Pg2w1oeB6PeDurIYf6k9PQ+aTITr/6lP/L/zp6c= @@ -6,7 +8,11 @@ github.com/tidwall/gjson v1.18.0 h1:FIDeeyB800efLX89e5a8Y0BNH+LOngJyGrIWxG2FKQY= github.com/tidwall/gjson v1.18.0/go.mod h1:/wbyibRr2FHMks5tjHJ5F8dMZh3AcwJEMf5vlfC0lxk= github.com/tidwall/match v1.1.1 h1:+Ho715JplO36QYgwN9PGYNhgZvoUSc9X2c80KVTi+GA= github.com/tidwall/match v1.1.1/go.mod h1:eRSPERbgtNPcGhD8UCthc6PmLEQXEWd3PRB5JTxsfmM= +github.com/tidwall/match v1.2.0 h1:0pt8FlkOwjN2fPt4bIl4BoNxb98gGHN2ObFEDkrfZnM= +github.com/tidwall/match v1.2.0/go.mod h1:eRSPERbgtNPcGhD8UCthc6PmLEQXEWd3PRB5JTxsfmM= github.com/tidwall/pretty v1.2.0 h1:RWIZEg2iJ8/g6fDDYzMpobmaoGh5OLl4AXtGUGPcqCs= github.com/tidwall/pretty v1.2.0/go.mod h1:ITEVvHYasfjBbM0u2Pg8T2nJnzm8xPwvNhhsoaGGjNU= +github.com/tidwall/pretty v1.2.1 h1:qjsOFOWWQl+N3RsoF5/ssm1pHmJJwhjlSbZ51I6wMl4= +github.com/tidwall/pretty v1.2.1/go.mod h1:ITEVvHYasfjBbM0u2Pg8T2nJnzm8xPwvNhhsoaGGjNU= github.com/yosida95/uritemplate/v3 v3.0.2 h1:Ed3Oyj9yrmi9087+NczuL5BwkIc4wvTb5zIM+UJPGz4= github.com/yosida95/uritemplate/v3 v3.0.2/go.mod h1:ILOh0sOhIJR3+L/8afwt/kE++YT040gmv5BQTMR2HP4= diff --git a/server/server.go b/server/server.go index e56aa69..054806b 100644 --- a/server/server.go +++ b/server/server.go @@ -3,11 +3,13 @@ package server import ( "context" "fmt" + "net/http" "sync" "time" "github.com/google/uuid" + "github.com/ThinkInAIXYZ/go-mcp/auth" "github.com/ThinkInAIXYZ/go-mcp/pkg" "github.com/ThinkInAIXYZ/go-mcp/protocol" "github.com/ThinkInAIXYZ/go-mcp/server/session" @@ -46,8 +48,8 @@ func WithLogger(logger pkg.Logger) Option { } } -// ToolMiddleware defines the middleware type of the tool handler -// Allow ToolHandlerFunc to be wrapped like a chain call +// ToolHandlerFunc and ToolMiddleware are defined and stay in the server package. +type ToolHandlerFunc func(context.Context, *protocol.CallToolRequest) (*protocol.CallToolResult, error) type ToolMiddleware func(ToolHandlerFunc) ToolHandlerFunc // RateLimitMiddleware Return a rate-limiting middleware @@ -74,6 +76,32 @@ func WithGenSessionIDFunc(genSessionID func(context.Context) string) Option { } } +func WithAuth(authServer *auth.Server, toolScopes map[string][]string) Option { + return func(s *Server) { + if authServer == nil { + return + } + + middleware := auth.NewMiddleware(authServer) + s.authMiddleware = middleware + + s.authHandler = auth.NewHandler(authServer) + s.authServer = authServer + + if len(toolScopes) > 0 { + // Call the generic function with the specific type from this package. + mw := auth.ScopeBasedToolMiddleware[ToolHandlerFunc](middleware, toolScopes) + s.Use(mw) + } else { + // Call the generic function with the specific type from this package. + mw := auth.MCPToolMiddleware[ToolHandlerFunc](middleware, &auth.MiddlewareConfig{ + RequiredScopes: []string{"read"}, + }) + s.Use(mw) + } + } +} + type ToolFilter func(context.Context, []*protocol.Tool) []*protocol.Tool type Server struct { @@ -102,6 +130,10 @@ type Server struct { globalMiddlewares []ToolMiddleware toolFilters ToolFilter + + authMiddleware *auth.Middleware + authServer *auth.Server + authHandler *auth.Handler } func NewServer(t transport.ServerTransport, opts ...Option) (*Server, error) { @@ -126,6 +158,10 @@ func NewServer(t transport.ServerTransport, opts ...Option) (*Server, error) { opt(server) } + if server.authMiddleware != nil { + t.ApplyAuthMiddleware(server.authMiddleware.HTTPMiddleware) + } + server.sessionManager.SetLogger(server.logger) t.SetSessionManager(server.sessionManager) @@ -155,8 +191,6 @@ type toolEntry struct { handler ToolHandlerFunc } -type ToolHandlerFunc func(context.Context, *protocol.CallToolRequest) (*protocol.CallToolResult, error) - func (server *Server) RegisterTool(tool *protocol.Tool, toolHandler ToolHandlerFunc, middlewares ...ToolMiddleware) { for i := len(middlewares) - 1; i >= 0; i-- { toolHandler = middlewares[i](toolHandler) @@ -313,3 +347,51 @@ func (server *Server) sessionDetection(ctx context.Context, sessionID string) er } return nil } + +func (s *Server) WrapWithAuth(handler http.Handler) http.Handler { + if s.authMiddleware == nil { + return handler + } + return s.authMiddleware.HTTPMiddleware(handler) +} + +func (server *Server) GetOAuthHandler() *auth.Handler { + return server.authHandler +} + +// RegisterOAuthRoutes automatically registers OAuth routes with transport +// Only effective if transport implements the HTTPRouteRegistrar interface +func (server *Server) RegisterOAuthRoutes() error { + if server.authHandler == nil { + return fmt.Errorf("OAuth not configured, use WithAuth option") + } + + // å°čÆ•čŽ·å– HTTPRouteRegistrar ęŽ„å£ + registrar, ok := server.transport.(interface { + RegisterHandler(pattern string, handler http.Handler) error + }) + + if !ok { + return fmt.Errorf("transport does not support HTTP route registration") + } + + // ę³Øå†Œ OAuth č·Æē”± + routes := map[string]http.Handler{ + "/oauth/authorize": http.HandlerFunc(server.authHandler.HandleAuthorization), + "/oauth/token": http.HandlerFunc(server.authHandler.HandleToken), + "/oauth/introspect": http.HandlerFunc(server.authHandler.HandleIntrospection), + "/oauth/revoke": http.HandlerFunc(server.authHandler.HandleRevocation), + "/register": http.HandlerFunc(server.authHandler.GetDynamicRegistrationHandler().HandleRegister), + "/.well-known/oauth-authorization-server": server.authServer.GetMetadataProvider(), + "/.well-known/oauth-protected-resource": http.HandlerFunc(server.authServer.GetMetadataProvider().ServeProtectedResourceMetadata), + "/.well-known/jwks.json": server.authServer.GetJWKSProvider(), + } + + for pattern, handler := range routes { + if err := registrar.RegisterHandler(pattern, handler); err != nil { + return fmt.Errorf("failed to register %s: %w", pattern, err) + } + } + + return nil +} diff --git a/transport/mock_server.go b/transport/mock_server.go index a637797..ce15fb4 100644 --- a/transport/mock_server.go +++ b/transport/mock_server.go @@ -7,6 +7,7 @@ import ( "errors" "fmt" "io" + "net/http" "github.com/ThinkInAIXYZ/go-mcp/pkg" ) @@ -63,6 +64,11 @@ func (t *mockServerTransport) SetSessionManager(m sessionManager) { t.sessionManager = m } +// ApplyAuthMiddleware is a no-op for mock server (non-HTTP) +func (t *mockServerTransport) ApplyAuthMiddleware(middleware func(http.Handler) http.Handler) { + // no-op +} + func (t *mockServerTransport) Shutdown(userCtx context.Context, serverCtx context.Context) error { t.cancel() diff --git a/transport/sse_client.go b/transport/sse_client.go index dd9a1a1..94e3d31 100644 --- a/transport/sse_client.go +++ b/transport/sse_client.go @@ -41,6 +41,12 @@ func WithRetryFunc(retry func(func() error)) SSEClientTransportOption { } } +func WithSSEClientTransportAuthToken(tokenProvider func() string) SSEClientTransportOption { + return func(t *sseClientTransport) { + t.tokenProvider = tokenProvider + } +} + type sseClientTransport struct { ctx context.Context cancel context.CancelFunc @@ -59,6 +65,8 @@ type sseClientTransport struct { retry func(func() error) sseConnectClose chan struct{} + + tokenProvider func() string } func NewSSEClientTransport(serverURL string, opts ...SSEClientTransportOption) (ClientTransport, error) { @@ -145,6 +153,12 @@ func (t *sseClientTransport) startSSE() error { req.Header.Set("Cache-Control", "no-cache") req.Header.Set("Connection", "keep-alive") + if t.tokenProvider != nil { + if token := t.tokenProvider(); token != "" { + req.Header.Set("Authorization", "Bearer "+token) + } + } + resp, err := t.client.Do(req) //nolint:bodyclose if err != nil { return fmt.Errorf("failed to connect to SSE stream: %w", err) @@ -247,6 +261,12 @@ func (t *sseClientTransport) Send(ctx context.Context, msg Message) error { req.Header.Set("Content-Type", "application/json") + if t.tokenProvider != nil { + if token := t.tokenProvider(); token != "" { + req.Header.Set("Authorization", "Bearer "+token) + } + } + if resp, err = t.client.Do(req); err != nil { return fmt.Errorf("failed to send message: %w", err) } diff --git a/transport/sse_server.go b/transport/sse_server.go index 9f5ee48..205ea17 100644 --- a/transport/sse_server.go +++ b/transport/sse_server.go @@ -65,6 +65,8 @@ type sseServerTransport struct { httpSvr *http.Server + mux *http.ServeMux + messageEndpointURL string // Auto-generated inFlySend sync.WaitGroup @@ -73,6 +75,8 @@ type sseServerTransport struct { sessionManager sessionManager + authMiddleware func(http.Handler) http.Handler + // options logger pkg.Logger ssePath string @@ -100,6 +104,10 @@ func (h *SSEHandler) HandleMessage() http.Handler { }) } +func (t *sseServerTransport) ApplyAuthMiddleware(middleware func(http.Handler) http.Handler) { + t.authMiddleware = middleware +} + // NewSSEServerTransport returns transport that will start an HTTP server func NewSSEServerTransport(addr string, opts ...SSEServerTransportOption) (ServerTransport, error) { ctx, cancel := context.WithCancel(context.Background()) @@ -129,6 +137,7 @@ func NewSSEServerTransport(addr string, opts ...SSEServerTransportOption) (Serve mux := http.NewServeMux() mux.HandleFunc(t.ssePath, t.handleSSE) mux.HandleFunc(t.messagePath, t.handleMessage) + t.mux = mux t.httpSvr = &http.Server{ Addr: addr, @@ -205,8 +214,33 @@ func (t *sseServerTransport) SetSessionManager(manager sessionManager) { t.sessionManager = manager } -// handleSSE handles incoming SSE connections from clients and sends messages to them. +// RegisterHandler allows registering custom routes on the HTTP server of the SSE Server +// Only valid when creating a transport using NewSSEServerTransport +func (t *sseServerTransport) RegisterHandler(pattern string, handler http.Handler) error { + if t.mux == nil { + return fmt.Errorf("mux is not available, use NewSSEServerTransport to create transport") + } + t.mux.Handle(pattern, handler) + return nil +} + +// RegisterHandlerFunc is a convenience method of RegisterHandler that accepts HandlerFunc +func (t *sseServerTransport) RegisterHandlerFunc(pattern string, handler func(http.ResponseWriter, *http.Request)) error { + return t.RegisterHandler(pattern, http.HandlerFunc(handler)) +} + func (t *sseServerTransport) handleSSE(w http.ResponseWriter, r *http.Request) { + if t.authMiddleware != nil { + t.authMiddleware(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + t.handleSSECore(w, r) + })).ServeHTTP(w, r) + return + } + t.handleSSECore(w, r) +} + +// handleSSECore handles incoming SSE connections from clients and sends messages to them. +func (t *sseServerTransport) handleSSECore(w http.ResponseWriter, r *http.Request) { defer pkg.RecoverWithFunc(func(_ any) { t.writeError(w, http.StatusInternalServerError, "Internal server error") }) @@ -269,6 +303,17 @@ func (t *sseServerTransport) handleSSE(w http.ResponseWriter, r *http.Request) { // handleMessage processes incoming JSON-RPC messages from clients and sends responses // back through both the SSE connection and HTTP response. func (t *sseServerTransport) handleMessage(w http.ResponseWriter, r *http.Request) { + if t.authMiddleware != nil { + t.authMiddleware(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + t.handleMessageCore(w, r) + })).ServeHTTP(w, r) + return + } + + t.handleMessageCore(w, r) +} + +func (t *sseServerTransport) handleMessageCore(w http.ResponseWriter, r *http.Request) { defer pkg.RecoverWithFunc(func(_ any) { t.writeError(w, http.StatusInternalServerError, "Internal server error") }) diff --git a/transport/stdio_server.go b/transport/stdio_server.go index ef0c7d6..8a179d7 100644 --- a/transport/stdio_server.go +++ b/transport/stdio_server.go @@ -7,6 +7,7 @@ import ( "errors" "fmt" "io" + "net/http" "os" "github.com/ThinkInAIXYZ/go-mcp/pkg" @@ -76,6 +77,11 @@ func (t *stdioServerTransport) SetSessionManager(m sessionManager) { t.sessionManager = m } +// ApplyAuthMiddleware is a no-op for stdio transport (non-HTTP) +func (t *stdioServerTransport) ApplyAuthMiddleware(middleware func(http.Handler) http.Handler) { + // no-op +} + func (t *stdioServerTransport) Shutdown(userCtx context.Context, serverCtx context.Context) error { t.cancel() diff --git a/transport/streamable_http_server.go b/transport/streamable_http_server.go index 5f50a9f..d04a7b3 100644 --- a/transport/streamable_http_server.go +++ b/transport/streamable_http_server.go @@ -77,6 +77,8 @@ type streamableHTTPServerTransport struct { sessionManager sessionManager + authMiddleware func(http.Handler) http.Handler + // options logger pkg.Logger mcpEndpoint string // The single MCP endpoint path @@ -146,6 +148,10 @@ func NewStreamableHTTPServerTransport(addr string, opts ...StreamableHTTPServerT return t } +func (t *streamableHTTPServerTransport) ApplyAuthMiddleware(middleware func(http.Handler) http.Handler) { + t.authMiddleware = middleware +} + func (t *streamableHTTPServerTransport) Run() error { if t.httpSvr == nil { <-t.ctx.Done() @@ -181,6 +187,16 @@ func (t *streamableHTTPServerTransport) SetSessionManager(manager sessionManager } func (t *streamableHTTPServerTransport) handleMCPEndpoint(w http.ResponseWriter, r *http.Request) { + if t.authMiddleware != nil { + t.authMiddleware(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + t.handleMCPEndpointCore(w, r) + })).ServeHTTP(w, r) + return + } + t.handleMCPEndpointCore(w, r) +} + +func (t *streamableHTTPServerTransport) handleMCPEndpointCore(w http.ResponseWriter, r *http.Request) { defer pkg.RecoverWithFunc(func(_ any) { t.writeError(w, http.StatusInternalServerError, "Internal server error") }) diff --git a/transport/transport.go b/transport/transport.go index 63c7aa2..44c14b5 100644 --- a/transport/transport.go +++ b/transport/transport.go @@ -2,6 +2,7 @@ package transport import ( "context" + "net/http" "github.com/ThinkInAIXYZ/go-mcp/pkg" ) @@ -70,6 +71,14 @@ type ServerTransport interface { SetSessionManager(manager sessionManager) + // ApplyAuthMiddleware applies an HTTP authentication middleware to the transport layer. + // This method allows the server to inject authentication logic at the HTTP level, + // ensuring that all incoming requests are authenticated before reaching the business logic. + // The middleware parameter is a function that wraps an http.Handler with authentication logic. + // Not all transports may support this (e.g., stdio transport), so implementations + // should handle this gracefully (e.g., no-op for non-HTTP transports). + ApplyAuthMiddleware(middleware func(http.Handler) http.Handler) + // Shutdown gracefully closes, the internal implementation needs to stop receiving messages first, // then wait for serverCtx to be canceled, while using userCtx to control timeout. // userCtx is used to control the timeout of the server shutdown.