Skip to content

Remote server support #1423

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft
wants to merge 9 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions cmd/thv/app/proxy.go
Original file line number Diff line number Diff line change
Expand Up @@ -418,6 +418,7 @@ func performOAuthFlow(ctx context.Context, issuer, clientID, clientSecret string
scopes,
true, // Enable PKCE by default for security
remoteAuthCallbackPort,
nil, // No OAuth params for proxy command
)
} else {
// Fall back to OIDC discovery
Expand Down
9 changes: 7 additions & 2 deletions cmd/thv/app/registry.go
Original file line number Diff line number Diff line change
Expand Up @@ -159,12 +159,17 @@ func printTextServers(servers []registry.ServerMetadata) {
}
}

const (
serverTypeRemote = "remote"
serverTypeContainer = "container"
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

perhaps we should make these public

)

// getServerType returns the type of server (container or remote)
func getServerType(server registry.ServerMetadata) string {
if server.IsRemote() {
return "remote"
return serverTypeRemote
}
return "container"
return serverTypeContainer
}

// printTextServerInfo prints detailed information about a server in text format
Expand Down
32 changes: 30 additions & 2 deletions cmd/thv/app/run.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ import (
"errors"
"fmt"
"net"
"net/url"
"os"
"os/signal"
"syscall"
Expand All @@ -26,7 +27,7 @@ var runCmd = &cobra.Command{
Short: "Run an MCP server",
Long: `Run an MCP server with the specified name, image, or protocol scheme.

ToolHive supports four ways to run an MCP server:
ToolHive supports five ways to run an MCP server:

1. From the registry:

Expand Down Expand Up @@ -59,13 +60,25 @@ ToolHive supports four ways to run an MCP server:

Runs an MCP server using a previously exported configuration file.

5. Remote MCP server:

$ thv run --remote <URL> [--name <name>]

Runs a remote MCP server as a workload, proxying requests to the specified URL.
This allows remote MCP servers to be managed like local workloads with full
support for client configuration, tool filtering, import/export, etc.

The container will be started with the specified transport mode and
permission profile. Additional configuration can be provided via flags.`,
Args: func(cmd *cobra.Command, args []string) error {
// If --from-config is provided, no args are required
if runFlags.FromConfig != "" {
return nil
}
// If --remote is provided, no args are required
if runFlags.RemoteURL != "" {
return nil
}
// Otherwise, require at least 1 argument
return cobra.MinimumNArgs(1)(cmd, args)
},
Expand Down Expand Up @@ -124,12 +137,27 @@ func runCmdFunc(cmd *cobra.Command, args []string) error {

// Get the name of the MCP server to run.
// This may be a server name from the registry, a container image, or a protocol scheme.
// When using --from-config, no args are required
// When using --from-config or --remote, no args are required
var serverOrImage string
if len(args) > 0 {
serverOrImage = args[0]
}

// If --remote is provided but no name is given, generate a name from the URL
if runFlags.RemoteURL != "" && runFlags.Name == "" {
// Extract a name from the remote URL
parsedURL, err := url.Parse(runFlags.RemoteURL)
if err != nil {
return fmt.Errorf("invalid remote URL: %v", err)
}
// Use the hostname as the base name
hostname := parsedURL.Hostname()
if hostname == "" {
hostname = "remote"
}
runFlags.Name = fmt.Sprintf("%s-remote", hostname)
}

// Process command arguments using os.Args to find everything after --
cmdArgs := parseCommandArguments(os.Args)

Expand Down
191 changes: 177 additions & 14 deletions cmd/thv/app/run_flags.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@ package app
import (
"context"
"fmt"
"strings"
"time"

"github.com/spf13/cobra"

Expand All @@ -20,6 +22,10 @@ import (
"github.com/stacklok/toolhive/pkg/transport/types"
)

const (
defaultTransportType = "streamable-http"
)

// RunFlags holds the configuration for running MCP servers
type RunFlags struct {
// Transport and proxy settings
Expand All @@ -38,6 +44,9 @@ type RunFlags struct {
Volumes []string
Secrets []string

// Remote MCP server support
RemoteURL string

// Security and audit
AuthzConfig string
AuditConfig string
Expand Down Expand Up @@ -83,11 +92,33 @@ type RunFlags struct {
// Ignore functionality
IgnoreGlobally bool
PrintOverlays bool

// Remote authentication
EnableRemoteAuth bool
RemoteAuthClientID string
RemoteAuthClientSecret string
RemoteAuthClientSecretFile string
RemoteAuthScopes []string
RemoteAuthSkipBrowser bool
RemoteAuthTimeout time.Duration
RemoteAuthCallbackPort int
RemoteAuthBearerToken string

// Additional remote auth fields for registry-based servers
RemoteAuthIssuer string
RemoteAuthAuthorizeURL string
RemoteAuthTokenURL string

// Environment variables for remote servers
EnvVars map[string]string

// OAuth parameters for remote servers
OAuthParams map[string]string
}

// AddRunFlags adds all the run flags to a command
func AddRunFlags(cmd *cobra.Command, config *RunFlags) {
cmd.Flags().StringVar(&config.Transport, "transport", "", "Transport mode (sse, streamable-http or stdio)")
cmd.Flags().StringVar(&config.Transport, "transport", "", "Transport type to use (stdio, sse, streamable-http)")
cmd.Flags().StringVar(&config.ProxyMode, "proxy-mode", "sse", "Proxy mode for stdio transport (sse or streamable-http)")
cmd.Flags().StringVar(&config.Name, "name", "", "Name of the MCP server (auto-generated from image if not provided)")
// TODO: Re-enable when group functionality is complete
Expand Down Expand Up @@ -128,6 +159,7 @@ func AddRunFlags(cmd *cobra.Command, config *RunFlags) {
[]string{},
"Specify a secret to be fetched from the secrets manager and set as an environment variable (format: NAME,target=TARGET)",
)
cmd.Flags().StringVar(&config.RemoteURL, "remote", "", "URL of remote MCP server to run as a workload")
cmd.Flags().StringVar(&config.AuthzConfig, "authz-config", "", "Path to the authorization configuration file")
cmd.Flags().StringVar(&config.AuditConfig, "audit-config", "", "Path to the audit configuration file")
cmd.Flags().BoolVar(&config.EnableAudit, "enable-audit", false, "Enable audit logging with default configuration")
Expand All @@ -144,6 +176,26 @@ func AddRunFlags(cmd *cobra.Command, config *RunFlags) {
cmd.Flags().BoolVar(&config.JWKSAllowPrivateIP, "jwks-allow-private-ip", false,
"Allow JWKS/OIDC endpoints on private IP addresses (use with caution)")

// Remote authentication flags
cmd.Flags().BoolVar(&config.EnableRemoteAuth, "remote-auth", false,
"Enable automatic OAuth authentication for remote MCP servers")
cmd.Flags().StringVar(&config.RemoteAuthClientID, "remote-auth-client-id", "",
"OAuth client ID for remote server authentication")
cmd.Flags().StringVar(&config.RemoteAuthClientSecret, "remote-auth-client-secret", "",
"OAuth client secret for remote server authentication")
cmd.Flags().StringVar(&config.RemoteAuthClientSecretFile, "remote-auth-client-secret-file", "",
"Path to file containing OAuth client secret")
cmd.Flags().StringSliceVar(&config.RemoteAuthScopes, "remote-auth-scopes", []string{},
"OAuth scopes to request for remote server authentication")
cmd.Flags().BoolVar(&config.RemoteAuthSkipBrowser, "remote-auth-skip-browser", false,
"Skip opening browser for remote server OAuth flow")
cmd.Flags().DurationVar(&config.RemoteAuthTimeout, "remote-auth-timeout", 5*time.Minute,
"Timeout for OAuth authentication flow")
cmd.Flags().IntVar(&config.RemoteAuthCallbackPort, "remote-auth-callback-port", 8666,
"Port for OAuth callback server during remote authentication")
cmd.Flags().StringVar(&config.RemoteAuthBearerToken, "remote-auth-bearer-token", "",
"Bearer token for remote server authentication (alternative to OAuth)")

// OAuth discovery configuration
cmd.Flags().StringVar(&config.ResourceURL, "resource-url", "",
"Explicit resource URL for OAuth discovery endpoint (RFC 9728)")
Expand Down Expand Up @@ -183,6 +235,7 @@ func AddRunFlags(cmd *cobra.Command, config *RunFlags) {
}

// BuildRunnerConfig creates a runner.RunConfig from the configuration
// nolint:gocyclo // This function handles multiple configuration scenarios and is complex by design
func BuildRunnerConfig(
ctx context.Context,
runFlags *RunFlags,
Expand Down Expand Up @@ -232,19 +285,122 @@ func BuildRunnerConfig(
envVarValidator = &runner.CLIEnvVarValidator{}
}

// Image retrieval
// Handle remote MCP server
var imageMetadata *registry.ImageMetadata
imageURL := serverOrImage
// Only pull image if we are not running in Kubernetes mode.
// This split will go away if we implement a separate command or binary
// for running MCP servers in Kubernetes.
if !runtime.IsKubernetesRuntime() {
// Take the MCP server we were supplied and either fetch the image, or
// build it from a protocol scheme. If the server URI refers to an image
// in our trusted registry, we will also fetch the image metadata.
imageURL, imageMetadata, err = retriever.GetMCPServer(ctx, serverOrImage, runFlags.CACertPath, runFlags.VerifyImage)
if err != nil {
return nil, fmt.Errorf("failed to find or create the MCP server %s: %v", serverOrImage, err)
var remoteServerMetadata *registry.RemoteServerMetadata
transportType := runFlags.Transport

// If --remote flag is provided, use it as the serverOrImage
if runFlags.RemoteURL != "" {
serverOrImage = runFlags.RemoteURL
}

// Try to get server from registry (container or remote) or direct URL
imageURL, imageMetadata, remoteServerMetadata, err := retriever.GetMCPServerOrRemote(
ctx, serverOrImage, runFlags.CACertPath, runFlags.VerifyImage)
if err != nil {
return nil, fmt.Errorf("failed to find or create the MCP server %s: %v", serverOrImage, err)
}

if remoteServerMetadata != nil {
// Handle registry-based remote server
runFlags.RemoteURL = remoteServerMetadata.URL
if transportType == "" {
transportType = remoteServerMetadata.Transport
}

// Set up OAuth config if provided
if remoteServerMetadata.OAuthConfig != nil {
runFlags.EnableRemoteAuth = true
// Only set ClientID from registry if not provided via command line
if runFlags.RemoteAuthClientID == "" {
runFlags.RemoteAuthClientID = remoteServerMetadata.OAuthConfig.ClientID
}
runFlags.RemoteAuthIssuer = remoteServerMetadata.OAuthConfig.Issuer
runFlags.RemoteAuthAuthorizeURL = remoteServerMetadata.OAuthConfig.AuthorizeURL
runFlags.RemoteAuthTokenURL = remoteServerMetadata.OAuthConfig.TokenURL
runFlags.RemoteAuthScopes = remoteServerMetadata.OAuthConfig.Scopes

// Set OAuth parameters and callback port from registry
if remoteServerMetadata.OAuthConfig.OAuthParams != nil {
runFlags.OAuthParams = remoteServerMetadata.OAuthConfig.OAuthParams
}
if remoteServerMetadata.OAuthConfig.CallbackPort != 0 {
runFlags.RemoteAuthCallbackPort = remoteServerMetadata.OAuthConfig.CallbackPort
}
}

// Set up headers if provided
for _, header := range remoteServerMetadata.Headers {
if header.Secret {
runFlags.Secrets = append(runFlags.Secrets, fmt.Sprintf("%s,target=%s", header.Name, header.Name))
} else {
if runFlags.EnvVars == nil {
runFlags.EnvVars = make(map[string]string)
}
runFlags.EnvVars[header.Name] = header.Default
}
}

// Set up environment variables if provided
for _, envVar := range remoteServerMetadata.EnvVars {
if envVar.Secret {
// Only add secrets if no authentication method is provided
hasAuth := runFlags.RemoteAuthBearerToken != "" ||
runFlags.RemoteAuthClientID != "" ||
remoteServerMetadata.OAuthConfig != nil
if !hasAuth {
runFlags.Secrets = append(runFlags.Secrets, fmt.Sprintf("%s,target=%s", envVar.Name, envVar.Name))
}
} else {
if runFlags.EnvVars == nil {
runFlags.EnvVars = make(map[string]string)
}
runFlags.EnvVars[envVar.Name] = envVar.Default
}
}
} else if isURL(imageURL) {
// Handle direct URL approach
runFlags.RemoteURL = imageURL
if transportType == "" {
transportType = defaultTransportType // Default for direct URLs
}
} else {
// Handle container server (existing logic)
if transportType == "" {
transportType = defaultTransportType // Default for remote servers
}
// Only pull image if we are not running in Kubernetes mode.
// This split will go away if we implement a separate command or binary
// for running MCP servers in Kubernetes.
if !runtime.IsKubernetesRuntime() {
// Take the MCP server we were supplied and either fetch the image, or
// build it from a protocol scheme. If the server URI refers to an image
// in our trusted registry, we will also fetch the image metadata.
imageURL, imageMetadata, err = retriever.GetMCPServer(ctx, serverOrImage, runFlags.CACertPath, runFlags.VerifyImage)
if err != nil {
return nil, fmt.Errorf("failed to find or create the MCP server %s: %v", serverOrImage, err)
}
}
}

// Build remote auth config if enabled
var remoteAuthConfig *runner.RemoteAuthConfig
if runFlags.EnableRemoteAuth || runFlags.RemoteAuthClientID != "" || runFlags.RemoteAuthBearerToken != "" {
remoteAuthConfig = &runner.RemoteAuthConfig{
EnableRemoteAuth: runFlags.EnableRemoteAuth,
ClientID: runFlags.RemoteAuthClientID,
ClientSecret: runFlags.RemoteAuthClientSecret,
ClientSecretFile: runFlags.RemoteAuthClientSecretFile,
Scopes: runFlags.RemoteAuthScopes,
SkipBrowser: runFlags.RemoteAuthSkipBrowser,
Timeout: runFlags.RemoteAuthTimeout,
CallbackPort: runFlags.RemoteAuthCallbackPort,
BearerToken: runFlags.RemoteAuthBearerToken,
Issuer: runFlags.RemoteAuthIssuer,
AuthorizeURL: runFlags.RemoteAuthAuthorizeURL,
TokenURL: runFlags.RemoteAuthTokenURL,
OAuthParams: runFlags.OAuthParams,
}
}

Expand All @@ -269,6 +425,7 @@ func BuildRunnerConfig(
WithCmdArgs(cmdArgs).
WithName(runFlags.Name).
WithImage(imageURL).
WithRemoteURL(runFlags.RemoteURL).
WithHost(validatedHost).
WithTargetHost(runFlags.TargetHost).
WithDebug(debugMode).
Expand All @@ -280,7 +437,7 @@ func BuildRunnerConfig(
WithNetworkIsolation(runFlags.IsolateNetwork).
WithK8sPodPatch(runFlags.K8sPodPatch).
WithProxyMode(types.ProxyMode(runFlags.ProxyMode)).
WithTransportAndPorts(runFlags.Transport, runFlags.ProxyPort, runFlags.TargetPort).
WithTransportAndPorts(transportType, runFlags.ProxyPort, runFlags.TargetPort).
WithAuditEnabled(runFlags.EnableAudit, runFlags.AuditConfig).
WithLabels(runFlags.Labels).
WithGroup(runFlags.Group).
Expand All @@ -293,6 +450,7 @@ func BuildRunnerConfig(
LoadGlobal: runFlags.IgnoreGlobally,
PrintOverlays: runFlags.PrintOverlays,
}).
WithRemoteAuth(remoteAuthConfig).
Build(ctx, imageMetadata, envVars, envVarValidator)
}

Expand Down Expand Up @@ -329,3 +487,8 @@ func getTelemetryFromFlags(cmd *cobra.Command, config *cfg.Config, otelEndpoint

return finalOtelEndpoint, finalOtelSamplingRate, finalOtelEnvironmentVariables
}

// isURL checks if the input is a URL
func isURL(input string) bool {
return strings.HasPrefix(input, "http://") || strings.HasPrefix(input, "https://")
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

}
Loading
Loading