From 3e270ba447297340072b28b7ec80202a844cee36 Mon Sep 17 00:00:00 2001 From: Gautam Date: Thu, 31 Jul 2025 02:38:22 -0700 Subject: [PATCH 01/19] implementing auth, dynamic tool registry, updating mcp-go library --- cmd/terraform-mcp-server/init.go | 208 +++++++++-- cmd/terraform-mcp-server/main.go | 194 +--------- go.mod | 24 +- go.sum | 51 ++- pkg/client/client.go | 180 ++++++++++ pkg/client/common.go | 18 +- pkg/client/{security.go => middleware.go} | 53 +++ .../{security_test.go => middleware_test.go} | 333 ++++++++++++++++++ pkg/client/registry.go | 45 +++ pkg/resources/resource_templates.go | 25 +- pkg/resources/resources.go | 26 +- pkg/tools/dynamic_tool.go | 145 ++++++++ pkg/tools/dynamic_tool_test.go | 164 +++++++++ pkg/tools/get_provider_docs.go | 19 +- pkg/tools/list_terraform_orgs.go | 66 ++++ pkg/tools/list_terraform_projects.go | 77 ++++ pkg/tools/module_details.go | 22 +- pkg/tools/policy_details.go | 24 +- pkg/tools/resolve_provider_doc_id.go | 28 +- pkg/tools/search_modules.go | 18 +- pkg/tools/search_policies.go | 18 +- pkg/tools/tools.go | 24 +- 22 files changed, 1460 insertions(+), 302 deletions(-) create mode 100644 pkg/client/client.go rename pkg/client/{security.go => middleware.go} (61%) rename pkg/client/{security_test.go => middleware_test.go} (50%) create mode 100644 pkg/tools/dynamic_tool.go create mode 100644 pkg/tools/dynamic_tool_test.go create mode 100644 pkg/tools/list_terraform_orgs.go create mode 100644 pkg/tools/list_terraform_projects.go diff --git a/cmd/terraform-mcp-server/init.go b/cmd/terraform-mcp-server/init.go index 7c15769..4da9a06 100644 --- a/cmd/terraform-mcp-server/init.go +++ b/cmd/terraform-mcp-server/init.go @@ -10,55 +10,94 @@ import ( stdlog "log" "net/http" "os" - "strconv" + "strings" "time" - "github.com/hashicorp/go-cleanhttp" - "github.com/hashicorp/go-retryablehttp" + "github.com/hashicorp/terraform-mcp-server/pkg/client" "github.com/hashicorp/terraform-mcp-server/pkg/resources" "github.com/hashicorp/terraform-mcp-server/pkg/tools" - + "github.com/hashicorp/terraform-mcp-server/version" "github.com/mark3labs/mcp-go/server" log "github.com/sirupsen/logrus" "github.com/spf13/cobra" "github.com/spf13/viper" ) -func InitRegistryClient(logger *log.Logger) *http.Client { - retryClient := retryablehttp.NewClient() - retryClient.Logger = logger +var ( + rootCmd = &cobra.Command{ + Use: "terraform-mcp-server", + Short: "Terraform MCP Server", + Long: `A Terraform MCP server that handles various tools and resources.`, + Version: fmt.Sprintf("Version: %s\nCommit: %s\nBuild Date: %s", version.GetHumanVersion(), version.GitCommit, version.BuildDate), + Run: runDefaultCommand, + } - transport := cleanhttp.DefaultPooledTransport() - transport.Proxy = http.ProxyFromEnvironment + stdioCmd = &cobra.Command{ + Use: "stdio", + Short: "Start stdio server", + Long: `Start a server that communicates via standard input/output streams using JSON-RPC messages.`, + Run: func(_ *cobra.Command, _ []string) { + logFile, err := rootCmd.PersistentFlags().GetString("log-file") + if err != nil { + stdlog.Fatal("Failed to get log file:", err) + } + logger, err := initLogger(logFile) + if err != nil { + stdlog.Fatal("Failed to initialize logger:", err) + } - retryClient.HTTPClient = cleanhttp.DefaultClient() - retryClient.HTTPClient.Timeout = 10 * time.Second - retryClient.HTTPClient.Transport = transport - retryClient.RetryMax = 3 + if err := runStdioServer(logger); err != nil { + stdlog.Fatal("failed to run stdio server:", err) + } + }, + } - retryClient.Backoff = func(min, max time.Duration, attemptNum int, resp *http.Response) time.Duration { - if resp != nil && resp.StatusCode == http.StatusTooManyRequests { - resetAfter := resp.Header.Get("x-ratelimit-reset") - resetAfterInt, err := strconv.ParseInt(resetAfter, 10, 64) + streamableHTTPCmd = &cobra.Command{ + Use: "streamable-http", + Short: "Start StreamableHTTP server", + Long: `Start a server that communicates via StreamableHTTP transport on port 8080 at /mcp endpoint.`, + Run: func(cmd *cobra.Command, _ []string) { + logFile, err := rootCmd.PersistentFlags().GetString("log-file") if err != nil { - return 0 + stdlog.Fatal("Failed to get log file:", err) + } + logger, err := initLogger(logFile) + if err != nil { + stdlog.Fatal("Failed to initialize logger:", err) } - resetAfterTime := time.Unix(resetAfterInt, 0) - return time.Until(resetAfterTime) - } - return 0 - } - retryClient.CheckRetry = func(ctx context.Context, resp *http.Response, err error) (bool, error) { - if resp != nil && resp.StatusCode == http.StatusTooManyRequests { - resetAfter := resp.Header.Get("x-ratelimit-reset") - return resetAfter != "", nil - } - return false, nil + port, err := cmd.Flags().GetString("transport-port") + if err != nil { + stdlog.Fatal("Failed to get streamableHTTP port:", err) + } + host, err := cmd.Flags().GetString("transport-host") + if err != nil { + stdlog.Fatal("Failed to get streamableHTTP host:", err) + } + + endpointPath, err := cmd.Flags().GetString("mcp-endpoint") + if err != nil { + stdlog.Fatal("Failed to get endpoint path:", err) + } + + if err := runHTTPServer(logger, host, port, endpointPath); err != nil { + stdlog.Fatal("failed to run streamableHTTP server:", err) + } + }, } - return retryClient.StandardClient() -} + // Create an alias for backward compatibility + httpCmdAlias = &cobra.Command{ + Use: "http", + Short: "Start StreamableHTTP server (deprecated, use 'streamable-http' instead)", + Long: `This command is deprecated. Please use 'streamable-http' instead.`, + Deprecated: "Use 'streamable-http' instead", + Run: func(cmd *cobra.Command, args []string) { + // Forward to the new command + streamableHTTPCmd.Run(cmd, args) + }, + } +) func init() { cobra.OnInitialize(initConfig) @@ -101,11 +140,11 @@ func initLogger(outPath string) (*log.Logger, error) { return logger, nil } -func registryInit(hcServer *server.MCPServer, logger *log.Logger) { - registryClient := InitRegistryClient(logger) - tools.InitTools(hcServer, registryClient, logger) - resources.RegisterResources(hcServer, registryClient, logger) - resources.RegisterResourceTemplates(hcServer, registryClient, logger) +// registerToolsAndResources registers tools and resources with the MCP server +func registerToolsAndResources(hcServer *server.MCPServer, logger *log.Logger) { + tools.RegisterTools(hcServer, logger) + resources.RegisterResources(hcServer, logger) + resources.RegisterResourceTemplates(hcServer, logger) } func serverInit(ctx context.Context, hcServer *server.MCPServer, logger *log.Logger) error { @@ -134,3 +173,100 @@ func serverInit(ctx context.Context, hcServer *server.MCPServer, logger *log.Log return nil } + +func streamableHTTPServerInit(ctx context.Context, hcServer *server.MCPServer, logger *log.Logger, host string, port string, endpointPath string) error { + // Check if stateless mode is enabled + isStateless := shouldUseStatelessMode() + + // Ensure endpoint path starts with / + if !strings.HasPrefix(endpointPath, "/") { + endpointPath = "/" + endpointPath + } + // Create StreamableHTTP server which implements the new streamable-http transport + // This is the modern MCP transport that supports both direct HTTP responses and SSE streams + opts := []server.StreamableHTTPOption{ + server.WithEndpointPath(endpointPath), // Default MCP endpoint path + server.WithLogger(logger), + } + + // Log the endpoint path being used + logger.Infof("Using endpoint path: %s", endpointPath) + + // Only add the WithStateLess option if stateless mode is enabled + // TODO: fix this in mcp-go ver 0.33.0 or higher + if isStateless { + opts = append(opts, server.WithStateLess(true)) + logger.Infof("Running in stateless mode") + } else { + logger.Infof("Running in stateful mode (default)") + } + + baseStreamableServer := server.NewStreamableHTTPServer(hcServer, opts...) + + // Load CORS configuration + corsConfig := client.LoadCORSConfigFromEnv() + + // Log CORS configuration + logger.Infof("CORS Mode: %s", corsConfig.Mode) + if len(corsConfig.AllowedOrigins) > 0 { + logger.Infof("Allowed Origins: %s", strings.Join(corsConfig.AllowedOrigins, ", ")) + } else if corsConfig.Mode == "strict" { + logger.Warnf("No allowed origins configured in strict mode. All cross-origin requests will be rejected.") + } else if corsConfig.Mode == "development" { + logger.Infof("Development mode: localhost origins are automatically allowed") + } else if corsConfig.Mode == "disabled" { + logger.Warnf("CORS validation is disabled. This is not recommended for production.") + } + + // Create a security wrapper around the streamable server + streamableServer := client.NewSecurityHandler(baseStreamableServer, corsConfig.AllowedOrigins, corsConfig.Mode, logger) + + mux := http.NewServeMux() + + // Apply middleware + streamableServer = client.TerraformContextMiddleware(logger)(streamableServer) + + // Handle the /mcp endpoint with the streamable server (with security wrapper) + mux.Handle(endpointPath, streamableServer) + mux.Handle(endpointPath+"/", streamableServer) + + // Add health check endpoint + mux.HandleFunc("/health", func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusOK) + response := fmt.Sprintf(`{"status":"ok","service":"terraform-mcp-server","transport":"streamable-http","endpoint":"%s"}`, endpointPath) + w.Write([]byte(response)) + }) + + addr := fmt.Sprintf("%s:%s", host, port) + httpServer := &http.Server{ + Addr: addr, + Handler: mux, + ReadTimeout: 30 * time.Second, + ReadHeaderTimeout: 30 * time.Second, + WriteTimeout: 30 * time.Second, + IdleTimeout: 60 * time.Second, + } + + // Start server in goroutine + errC := make(chan error, 1) + go func() { + logger.Infof("Starting StreamableHTTP server on %s%s", addr, endpointPath) + errC <- httpServer.ListenAndServe() + }() + + // Wait for shutdown signal + select { + case <-ctx.Done(): + logger.Infof("Shutting down StreamableHTTP server...") + shutdownCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + return httpServer.Shutdown(shutdownCtx) + case err := <-errC: + if err != nil && err != http.ErrServerClosed { + return fmt.Errorf("StreamableHTTP server error: %w", err) + } + } + + return nil +} diff --git a/cmd/terraform-mcp-server/main.go b/cmd/terraform-mcp-server/main.go index 18b1f9d..eec25a9 100644 --- a/cmd/terraform-mcp-server/main.go +++ b/cmd/terraform-mcp-server/main.go @@ -7,12 +7,10 @@ import ( "context" "fmt" stdlog "log" - "net/http" "os" "os/signal" "strings" "syscall" - "time" "github.com/hashicorp/terraform-mcp-server/pkg/client" "github.com/hashicorp/terraform-mcp-server/version" @@ -22,197 +20,27 @@ import ( "github.com/spf13/cobra" ) -var ( - rootCmd = &cobra.Command{ - Use: "terraform-mcp-server", - Short: "Terraform MCP Server", - Long: `A Terraform MCP server that handles various tools and resources.`, - Version: fmt.Sprintf("Version: %s\nCommit: %s\nBuild Date: %s", version.GetHumanVersion(), version.GitCommit, version.BuildDate), - Run: runDefaultCommand, - } - - stdioCmd = &cobra.Command{ - Use: "stdio", - Short: "Start stdio server", - Long: `Start a server that communicates via standard input/output streams using JSON-RPC messages.`, - Run: func(_ *cobra.Command, _ []string) { - logFile, err := rootCmd.PersistentFlags().GetString("log-file") - if err != nil { - stdlog.Fatal("Failed to get log file:", err) - } - logger, err := initLogger(logFile) - if err != nil { - stdlog.Fatal("Failed to initialize logger:", err) - } - - if err := runStdioServer(logger); err != nil { - stdlog.Fatal("failed to run stdio server:", err) - } - }, - } - - streamableHTTPCmd = &cobra.Command{ - Use: "streamable-http", - Short: "Start StreamableHTTP server", - Long: `Start a server that communicates via StreamableHTTP transport on port 8080 at /mcp endpoint.`, - Run: func(cmd *cobra.Command, _ []string) { - logFile, err := rootCmd.PersistentFlags().GetString("log-file") - if err != nil { - stdlog.Fatal("Failed to get log file:", err) - } - logger, err := initLogger(logFile) - if err != nil { - stdlog.Fatal("Failed to initialize logger:", err) - } - - port, err := cmd.Flags().GetString("transport-port") - if err != nil { - stdlog.Fatal("Failed to get streamableHTTP port:", err) - } - host, err := cmd.Flags().GetString("transport-host") - if err != nil { - stdlog.Fatal("Failed to get streamableHTTP host:", err) - } - - endpointPath, err := cmd.Flags().GetString("mcp-endpoint") - if err != nil { - stdlog.Fatal("Failed to get endpoint path:", err) - } - - if err := runHTTPServer(logger, host, port, endpointPath); err != nil { - stdlog.Fatal("failed to run streamableHTTP server:", err) - } - }, - } - - // Create an alias for backward compatibility - httpCmdAlias = &cobra.Command{ - Use: "http", - Short: "Start StreamableHTTP server (deprecated, use 'streamable-http' instead)", - Long: `This command is deprecated. Please use 'streamable-http' instead.`, - Deprecated: "Use 'streamable-http' instead", - Run: func(cmd *cobra.Command, args []string) { - // Forward to the new command - streamableHTTPCmd.Run(cmd, args) - }, - } -) - func runHTTPServer(logger *log.Logger, host string, port string, endpointPath string) error { ctx, stop := signal.NotifyContext(context.Background(), os.Interrupt, syscall.SIGTERM) defer stop() - hcServer := NewServer(version.Version) - registryInit(hcServer, logger) + hcServer := NewServer(version.Version, logger) + registerToolsAndResources(hcServer, logger) return streamableHTTPServerInit(ctx, hcServer, logger, host, port, endpointPath) } -func streamableHTTPServerInit(ctx context.Context, hcServer *server.MCPServer, logger *log.Logger, host string, port string, endpointPath string) error { - // Check if stateless mode is enabled - isStateless := shouldUseStatelessMode() - - // Ensure endpoint path starts with / - if !strings.HasPrefix(endpointPath, "/") { - endpointPath = "/" + endpointPath - } - // Create StreamableHTTP server which implements the new streamable-http transport - // This is the modern MCP transport that supports both direct HTTP responses and SSE streams - opts := []server.StreamableHTTPOption{ - server.WithEndpointPath(endpointPath), // Default MCP endpoint path - server.WithLogger(logger), - } - - // Log the endpoint path being used - logger.Infof("Using endpoint path: %s", endpointPath) - - // Only add the WithStateLess option if stateless mode is enabled - // TODO: fix this in mcp-go ver 0.33.0 or higher - if isStateless { - opts = append(opts, server.WithStateLess(true)) - logger.Infof("Running in stateless mode") - } else { - logger.Infof("Running in stateful mode (default)") - } - - baseStreamableServer := server.NewStreamableHTTPServer(hcServer, opts...) - - // Load CORS configuration - corsConfig := client.LoadCORSConfigFromEnv() - - // Log CORS configuration - logger.Infof("CORS Mode: %s", corsConfig.Mode) - if len(corsConfig.AllowedOrigins) > 0 { - logger.Infof("Allowed Origins: %s", strings.Join(corsConfig.AllowedOrigins, ", ")) - } else if corsConfig.Mode == "strict" { - logger.Warnf("No allowed origins configured in strict mode. All cross-origin requests will be rejected.") - } else if corsConfig.Mode == "development" { - logger.Infof("Development mode: localhost origins are automatically allowed") - } else if corsConfig.Mode == "disabled" { - logger.Warnf("CORS validation is disabled. This is not recommended for production.") - } - - // Create a security wrapper around the streamable server - streamableServer := client.NewSecurityHandler(baseStreamableServer, corsConfig.AllowedOrigins, corsConfig.Mode, logger) - - mux := http.NewServeMux() - - // Handle the /mcp endpoint with the streamable server (with security wrapper) - mux.Handle(endpointPath, streamableServer) - mux.Handle(endpointPath+"/", streamableServer) - - // Add health check endpoint - mux.HandleFunc("/health", func(w http.ResponseWriter, r *http.Request) { - w.Header().Set("Content-Type", "application/json") - w.WriteHeader(http.StatusOK) - response := fmt.Sprintf(`{"status":"ok","service":"terraform-mcp-server","transport":"streamable-http","endpoint":"%s"}`, endpointPath) - w.Write([]byte(response)) - }) - - addr := fmt.Sprintf("%s:%s", host, port) - httpServer := &http.Server{ - Addr: addr, - Handler: mux, - ReadTimeout: 30 * time.Second, - ReadHeaderTimeout: 30 * time.Second, - WriteTimeout: 30 * time.Second, - IdleTimeout: 60 * time.Second, - } - - // Start server in goroutine - errC := make(chan error, 1) - go func() { - logger.Infof("Starting StreamableHTTP server on %s%s", addr, endpointPath) - errC <- httpServer.ListenAndServe() - }() - - // Wait for shutdown signal - select { - case <-ctx.Done(): - logger.Infof("Shutting down StreamableHTTP server...") - shutdownCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second) - defer cancel() - return httpServer.Shutdown(shutdownCtx) - case err := <-errC: - if err != nil && err != http.ErrServerClosed { - return fmt.Errorf("StreamableHTTP server error: %w", err) - } - } - - return nil -} - func runStdioServer(logger *log.Logger) error { ctx, stop := signal.NotifyContext(context.Background(), os.Interrupt, syscall.SIGTERM) defer stop() - hcServer := NewServer(version.Version) - registryInit(hcServer, logger) + hcServer := NewServer(version.Version, logger) + registerToolsAndResources(hcServer, logger) return serverInit(ctx, hcServer, logger) } -func NewServer(version string, opts ...server.ServerOption) *server.MCPServer { +func NewServer(version string, logger *log.Logger, opts ...server.ServerOption) *server.MCPServer { // Add default options defaultOpts := []server.ServerOption{ server.WithToolCapabilities(true), @@ -220,6 +48,18 @@ func NewServer(version string, opts ...server.ServerOption) *server.MCPServer { } opts = append(defaultOpts, opts...) + // Create hooks for session management + hooks := &server.Hooks{} + hooks.AddOnRegisterSession(func(ctx context.Context, session server.ClientSession) { + client.NewSessionHandler(ctx, session, logger) + }) + hooks.AddOnUnregisterSession(func(ctx context.Context, session server.ClientSession) { + client.EndSessionHandler(ctx, session, logger) + }) + + // Add hooks to options + opts = append(opts, server.WithHooks(hooks)) + // Create a new MCP server s := server.NewMCPServer( "terraform-mcp-server", diff --git a/go.mod b/go.mod index 2dec5a5..9e55ed5 100644 --- a/go.mod +++ b/go.mod @@ -5,7 +5,8 @@ go 1.24 require ( github.com/hashicorp/go-cleanhttp v0.5.2 github.com/hashicorp/go-retryablehttp v0.7.8 - github.com/mark3labs/mcp-go v0.32.0 + github.com/hashicorp/go-tfe v1.87.0 + github.com/mark3labs/mcp-go v0.36.0 github.com/sirupsen/logrus v1.9.3 github.com/spf13/cobra v1.9.1 github.com/spf13/viper v1.20.1 @@ -13,23 +14,34 @@ require ( ) require ( + github.com/bahlo/generic-list-go v0.2.0 // indirect + github.com/buger/jsonparser v1.1.1 // indirect github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc // indirect github.com/fsnotify/fsnotify v1.9.0 // indirect - github.com/go-viper/mapstructure/v2 v2.3.0 // indirect + github.com/go-viper/mapstructure/v2 v2.4.0 // indirect github.com/google/go-cmp v0.7.0 // indirect + github.com/google/go-querystring v1.1.0 // indirect github.com/google/uuid v1.6.0 // indirect + github.com/hashicorp/go-slug v0.16.7 // indirect + github.com/hashicorp/go-version v1.7.0 // indirect + github.com/hashicorp/jsonapi v1.5.0 // indirect github.com/inconshreveable/mousetrap v1.1.0 // indirect + github.com/invopop/jsonschema v0.13.0 // indirect + github.com/mailru/easyjson v0.9.0 // indirect github.com/pelletier/go-toml/v2 v2.2.4 // indirect github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2 // indirect github.com/sagikazarmark/locafero v0.9.0 // indirect github.com/sourcegraph/conc v0.3.0 // indirect github.com/spf13/afero v1.14.0 // indirect - github.com/spf13/cast v1.8.0 // indirect - github.com/spf13/pflag v1.0.6 // indirect + github.com/spf13/cast v1.9.2 // indirect + github.com/spf13/pflag v1.0.7 // indirect github.com/subosito/gotenv v1.6.0 // indirect + github.com/wk8/go-ordered-map/v2 v2.1.8 // indirect github.com/yosida95/uritemplate/v3 v3.0.2 // indirect go.uber.org/multierr v1.11.0 // indirect - golang.org/x/sys v0.33.0 // indirect - golang.org/x/text v0.25.0 // indirect + golang.org/x/sync v0.16.0 // indirect + golang.org/x/sys v0.34.0 // indirect + golang.org/x/text v0.27.0 // indirect + golang.org/x/time v0.12.0 // indirect gopkg.in/yaml.v3 v3.0.1 // indirect ) diff --git a/go.sum b/go.sum index 0a3c6b5..f41f95a 100644 --- a/go.sum +++ b/go.sum @@ -1,3 +1,7 @@ +github.com/bahlo/generic-list-go v0.2.0 h1:5sz/EEAK+ls5wF+NeqDpk5+iNdMDXrh3z3nPnH1Wvgk= +github.com/bahlo/generic-list-go v0.2.0/go.mod h1:2KvAjgMlE5NNynlg/5iLrrCCZ2+5xWbdbCW3pNTGyYg= +github.com/buger/jsonparser v1.1.1 h1:2PnMjfWD7wBILjqQbt530v576A/cAbQvEW9gGIpYMUs= +github.com/buger/jsonparser v1.1.1/go.mod h1:6RYKKt7H4d4+iWqouImQ9R2FZql3VbhNgx27UK13J/0= github.com/cpuguy83/go-md2man/v2 v2.0.6/go.mod h1:oOW0eioCTA6cOiMLiUPZOpcVxMig6NIQQ7OS05n1F4g= github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= @@ -9,10 +13,13 @@ github.com/frankban/quicktest v1.14.6 h1:7Xjx+VpznH+oBnejlPUj8oUpdxnVs4f8XU8WnHk github.com/frankban/quicktest v1.14.6/go.mod h1:4ptaffx2x8+WTWXmUCuVU6aPUX1/Mz7zb5vbUoiM6w0= github.com/fsnotify/fsnotify v1.9.0 h1:2Ml+OJNzbYCTzsxtv8vKSFD9PbJjmhYF14k/jKC7S9k= github.com/fsnotify/fsnotify v1.9.0/go.mod h1:8jBTzvmWwFyi3Pb8djgCCO5IBqzKJ/Jwo8TRcHyHii0= -github.com/go-viper/mapstructure/v2 v2.3.0 h1:27XbWsHIqhbdR5TIC911OfYvgSaW93HM+dX7970Q7jk= -github.com/go-viper/mapstructure/v2 v2.3.0/go.mod h1:oJDH3BJKyqBA2TXFhDsKDGDTlndYOZ6rGS0BRZIxGhM= +github.com/go-viper/mapstructure/v2 v2.4.0 h1:EBsztssimR/CONLSZZ04E8qAkxNYq4Qp9LvH92wZUgs= +github.com/go-viper/mapstructure/v2 v2.4.0/go.mod h1:oJDH3BJKyqBA2TXFhDsKDGDTlndYOZ6rGS0BRZIxGhM= +github.com/google/go-cmp v0.5.2/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= github.com/google/go-cmp v0.7.0 h1:wk8382ETsv4JYUZwIsn6YpYiWiBsYLSJiTsyBybVuN8= github.com/google/go-cmp v0.7.0/go.mod h1:pXiqmnSA92OHEEa9HXL2W4E7lf9JzCmGVUdgjX3N/iU= +github.com/google/go-querystring v1.1.0 h1:AnCroh3fv4ZBgVIf1Iwtovgjaw/GiKJo8M8yD/fhyJ8= +github.com/google/go-querystring v1.1.0/go.mod h1:Kcdr2DB4koayq7X8pmAG4sNG59So17icRSOU623lUBU= github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0= github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= github.com/hashicorp/go-cleanhttp v0.5.2 h1:035FKYIWjmULyFRBKPs8TBQoi0x6d9G4xc9neXJWAZQ= @@ -21,14 +28,28 @@ github.com/hashicorp/go-hclog v1.6.3 h1:Qr2kF+eVWjTiYmU7Y31tYlP1h0q/X3Nl3tPGdaB1 github.com/hashicorp/go-hclog v1.6.3/go.mod h1:W4Qnvbt70Wk/zYJryRzDRU/4r0kIg0PVHBcfoyhpF5M= github.com/hashicorp/go-retryablehttp v0.7.8 h1:ylXZWnqa7Lhqpk0L1P1LzDtGcCR0rPVUrx/c8Unxc48= github.com/hashicorp/go-retryablehttp v0.7.8/go.mod h1:rjiScheydd+CxvumBsIrFKlx3iS0jrZ7LvzFGFmuKbw= +github.com/hashicorp/go-slug v0.16.7 h1:sBW8y1sX+JKOZKu9a+DQZuWDVaX+U9KFnk6+VDQvKcw= +github.com/hashicorp/go-slug v0.16.7/go.mod h1:X5fm++dL59cDOX8j48CqHr4KARTQau7isGh0ZVxJB5I= +github.com/hashicorp/go-tfe v1.87.0 h1:0ejo3SegLoQ/Uj/2U0ECGppm3E/VZfSu+KscvzxvRNs= +github.com/hashicorp/go-tfe v1.87.0/go.mod h1:6dUFMBKh0jkxlRsrw7bYD2mby0efdwE4dtlAuTogIzA= +github.com/hashicorp/go-uuid v1.0.3 h1:2gKiV6YVmrJ1i2CKKa9obLvRieoRGviZFL26PcT/Co8= +github.com/hashicorp/go-uuid v1.0.3/go.mod h1:6SBZvOh/SIDV7/2o3Jml5SYk/TvGqwFJ/bN7x4byOro= +github.com/hashicorp/go-version v1.7.0 h1:5tqGy27NaOTB8yJKUZELlFAS/LTKJkrmONwQKeRZfjY= +github.com/hashicorp/go-version v1.7.0/go.mod h1:fltr4n8CU8Ke44wwGCBoEymUuxUHl09ZGVZPK5anwXA= +github.com/hashicorp/jsonapi v1.5.0 h1:toO1EpzVl1b3xTjC/Tw4XMIlHgJreeTnyb1a1sHnlPk= +github.com/hashicorp/jsonapi v1.5.0/go.mod h1:kWfdn49yCjQvbpnvY1dxxAuAFzISwrrMDQOcu6NsFoM= github.com/inconshreveable/mousetrap v1.1.0 h1:wN+x4NVGpMsO7ErUn/mUI3vEoE6Jt13X2s0bqwp9tc8= github.com/inconshreveable/mousetrap v1.1.0/go.mod h1:vpF70FUmC8bwa3OWnCshd2FqLfsEA9PFc4w1p2J65bw= +github.com/invopop/jsonschema v0.13.0 h1:KvpoAJWEjR3uD9Kbm2HWJmqsEaHt8lBUpd0qHcIi21E= +github.com/invopop/jsonschema v0.13.0/go.mod h1:ffZ5Km5SWWRAIN6wbDXItl95euhFz2uON45H2qjYt+0= github.com/kr/pretty v0.3.1 h1:flRD4NNwYAUpkphVc1HcthR4KEIFJ65n8Mw5qdRn3LE= github.com/kr/pretty v0.3.1/go.mod h1:hoEshYVHaxMs3cyo3Yncou5ZscifuDolrwPKZanG3xk= github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY= github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE= -github.com/mark3labs/mcp-go v0.32.0 h1:fgwmbfL2gbd67obg57OfV2Dnrhs1HtSdlY/i5fn7MU8= -github.com/mark3labs/mcp-go v0.32.0/go.mod h1:rXqOudj/djTORU/ThxYx8fqEVj/5pvTuuebQ2RC7uk4= +github.com/mailru/easyjson v0.9.0 h1:PrnmzHw7262yW8sTBwxi1PdJA3Iw/EKBa8psRf7d9a4= +github.com/mailru/easyjson v0.9.0/go.mod h1:1+xMtQp2MRNVL/V1bOzuP3aP8VNwRW55fQUto+XFtTU= +github.com/mark3labs/mcp-go v0.36.0 h1:rIZaijrRYPeSbJG8/qNDe0hWlGrCJ7FWHNMz2SQpTis= +github.com/mark3labs/mcp-go v0.36.0/go.mod h1:T7tUa2jO6MavG+3P25Oy/jR7iCeJPHImCZHRymCn39g= github.com/mattn/go-colorable v0.1.13 h1:fFA4WZxdEF4tXPZVKMLwD8oUnCTTo08duU7wxecdEvA= github.com/mattn/go-colorable v0.1.13/go.mod h1:7S9/ev0klgBDR4GtXTXX8a3vIGJpMovkB8vQcUbaXHg= github.com/mattn/go-isatty v0.0.20 h1:xfD0iDuEKnDkl03q4limB+vH+GxLEtL/jb4xVJSWWEY= @@ -49,12 +70,13 @@ github.com/sourcegraph/conc v0.3.0 h1:OQTbbt6P72L20UqAkXXuLOj79LfEanQ+YQFNpLA9yS github.com/sourcegraph/conc v0.3.0/go.mod h1:Sdozi7LEKbFPqYX2/J+iBAM6HpqSLTASQIKqDmF7Mt0= github.com/spf13/afero v1.14.0 h1:9tH6MapGnn/j0eb0yIXiLjERO8RB6xIVZRDCX7PtqWA= github.com/spf13/afero v1.14.0/go.mod h1:acJQ8t0ohCGuMN3O+Pv0V0hgMxNYDlvdk+VTfyZmbYo= -github.com/spf13/cast v1.8.0 h1:gEN9K4b8Xws4EX0+a0reLmhq8moKn7ntRlQYgjPeCDk= -github.com/spf13/cast v1.8.0/go.mod h1:ancEpBxwJDODSW/UG4rDrAqiKolqNNh2DX3mk86cAdo= +github.com/spf13/cast v1.9.2 h1:SsGfm7M8QOFtEzumm7UZrZdLLquNdzFYfIbEXntcFbE= +github.com/spf13/cast v1.9.2/go.mod h1:jNfB8QC9IA6ZuY2ZjDp0KtFO2LZZlg4S/7bzP6qqeHo= github.com/spf13/cobra v1.9.1 h1:CXSaggrXdbHK9CF+8ywj8Amf7PBRmPCOJugH954Nnlo= github.com/spf13/cobra v1.9.1/go.mod h1:nDyEzZ8ogv936Cinf6g1RU9MRY64Ir93oCnqb9wxYW0= -github.com/spf13/pflag v1.0.6 h1:jFzHGLGAlb3ruxLB8MhbI6A8+AQX/2eW4qeyNZXNp2o= github.com/spf13/pflag v1.0.6/go.mod h1:McXfInJRrz4CZXVZOBLb0bTZqETkiAhM9Iw0y3An2Bg= +github.com/spf13/pflag v1.0.7 h1:vN6T9TfwStFPFM5XzjsvmzZkLuaLX+HS+0SeFLRgU6M= +github.com/spf13/pflag v1.0.7/go.mod h1:McXfInJRrz4CZXVZOBLb0bTZqETkiAhM9Iw0y3An2Bg= github.com/spf13/viper v1.20.1 h1:ZMi+z/lvLyPSCoNtFCpqjy0S4kPbirhpTMwl8BkW9X4= github.com/spf13/viper v1.20.1/go.mod h1:P9Mdzt1zoHIG8m2eZQinpiBjo6kCmZSKBClNNqjJvu4= github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= @@ -63,15 +85,22 @@ github.com/stretchr/testify v1.10.0 h1:Xv5erBjTwe/5IxqUQTdXv5kgmIvbHo3QQyRwhJsOf github.com/stretchr/testify v1.10.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY= github.com/subosito/gotenv v1.6.0 h1:9NlTDc1FTs4qu0DDq7AEtTPNw6SVm7uBMsUCUjABIf8= github.com/subosito/gotenv v1.6.0/go.mod h1:Dk4QP5c2W3ibzajGcXpNraDfq2IrhjMIvMSWPKKo0FU= +github.com/wk8/go-ordered-map/v2 v2.1.8 h1:5h/BUHu93oj4gIdvHHHGsScSTMijfx5PeYkE/fJgbpc= +github.com/wk8/go-ordered-map/v2 v2.1.8/go.mod h1:5nJHM5DyteebpVlHnWMV0rPz6Zp7+xBAnxjb1X5vnTw= github.com/yosida95/uritemplate/v3 v3.0.2 h1:Ed3Oyj9yrmi9087+NczuL5BwkIc4wvTb5zIM+UJPGz4= github.com/yosida95/uritemplate/v3 v3.0.2/go.mod h1:ILOh0sOhIJR3+L/8afwt/kE++YT040gmv5BQTMR2HP4= go.uber.org/multierr v1.11.0 h1:blXXJkSxSSfBVBlC76pxqeO+LN3aDfLQo+309xJstO0= go.uber.org/multierr v1.11.0/go.mod h1:20+QtiLqy0Nd6FdQB9TLXag12DsQkrbs3htMFfDN80Y= +golang.org/x/sync v0.16.0 h1:ycBJEhp9p4vXvUZNszeOq0kGTPghopOL8q0fq3vstxw= +golang.org/x/sync v0.16.0/go.mod h1:1dzgHSNfp02xaA81J2MS99Qcpr2w7fw1gpm99rleRqA= golang.org/x/sys v0.0.0-20220715151400-c0bba94af5f8/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= -golang.org/x/sys v0.33.0 h1:q3i8TbbEz+JRD9ywIRlyRAQbM0qF7hu24q3teo2hbuw= -golang.org/x/sys v0.33.0/go.mod h1:BJP2sWEmIv4KK5OTEluFJCKSidICx8ciO85XgH3Ak8k= -golang.org/x/text v0.25.0 h1:qVyWApTSYLk/drJRO5mDlNYskwQznZmkpV2c8q9zls4= -golang.org/x/text v0.25.0/go.mod h1:WEdwpYrmk1qmdHvhkSTNPm3app7v4rsT8F2UD6+VHIA= +golang.org/x/sys v0.34.0 h1:H5Y5sJ2L2JRdyv7ROF1he/lPdvFsd0mJHFw2ThKHxLA= +golang.org/x/sys v0.34.0/go.mod h1:BJP2sWEmIv4KK5OTEluFJCKSidICx8ciO85XgH3Ak8k= +golang.org/x/text v0.27.0 h1:4fGWRpyh641NLlecmyl4LOe6yDdfaYNrGb2zdfo4JV4= +golang.org/x/text v0.27.0/go.mod h1:1D28KMCvyooCX9hBiosv5Tz/+YLxj0j7XhWjpSUF7CU= +golang.org/x/time v0.12.0 h1:ScB/8o8olJvc+CQPWrK3fPZNfh7qgwCrY0zJmoEQLSE= +golang.org/x/time v0.12.0/go.mod h1:CDIdPxbZBQxdj6cxyCIdrNogrJKMJ7pr37NYpMcMDSg= +golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/check.v1 v1.0.0-20190902080502-41f04d3bba15 h1:YR8cESwS4TdDjEe65xsg0ogRM/Nc3DYOhEAlW+xobZo= gopkg.in/check.v1 v1.0.0-20190902080502-41f04d3bba15/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= diff --git a/pkg/client/client.go b/pkg/client/client.go new file mode 100644 index 0000000..e6767ce --- /dev/null +++ b/pkg/client/client.go @@ -0,0 +1,180 @@ +// Copyright (c) HashiCorp, Inc. +// SPDX-License-Identifier: MPL-2.0 + +package client + +import ( + "context" + "fmt" + "net/http" + "os" + "strconv" + "sync" + + "github.com/hashicorp/go-tfe" + "github.com/mark3labs/mcp-go/server" + log "github.com/sirupsen/logrus" +) + +var ( + activeClients sync.Map +) + +const ( + TerraformAddress = "TFE_ADDRESS" + TerraformToken = "TFE_TOKEN" + TerraformSkipTLSVerify = "TFE_SKIP_VERIFY" +) + +const DefaultTerraformAddress = "https://app.terraform.io" + +type terraformClients struct { + TfeClient *tfe.Client + HttpClient *http.Client +} + +// contextKey is a type alias to avoid lint warnings while maintaining compatibility +type contextKey string + +// getEnv retrieves the value of an environment variable or returns a fallback value if not set +func getEnv(key, fallback string) string { + if value, ok := os.LookupEnv(key); ok { + return value + } + return fallback +} + +// NewTerraformClient creates a new Terraform client for the given session +func NewTerraformClient(sessionId string, terraformAddress string, terraformSkipTLSVerify bool, terraformToken string, logger *log.Logger) *terraformClients { + // Initialize Terraform client + config := &tfe.Config{ + Address: terraformAddress, + Token: terraformToken, + RetryServerErrors: true, + } + + config.HTTPClient = createHTTPClient(terraformSkipTLSVerify, logger) + terraformClients := &terraformClients{ + TfeClient: nil, + HttpClient: config.HTTPClient, + } + + client, err := tfe.NewClient(config) + if err != nil { + logger.Warnf("Failed to create a Terraform Cloud/Enterprise client: %s, %v", sessionId, err) + return terraformClients + } + + terraformClients.TfeClient = client + activeClients.Store(sessionId, terraformClients) + return terraformClients +} + +// GetTerraformClient retrieves the Terraform client for the given session +func GetTerraformClient(sessionId string) *terraformClients { + if value, ok := activeClients.Load(sessionId); ok { + return value.(*terraformClients) + } + return nil +} + +// DeleteTerraformClient removes the Terraform client for the given session +func DeleteTerraformClient(sessionId string) { + activeClients.Delete(sessionId) +} + +// GetTerraformClientFromContext extracts Terraform client from the MCP context +func GetTerraformClientFromContext(ctx context.Context, logger *log.Logger) (*terraformClients, error) { + session := server.ClientSessionFromContext(ctx) + if session == nil { + return nil, fmt.Errorf("no active session") + } + + // Log the session ID for debugging + logger.WithField("session_id", session.SessionID()).Debug("Retrieving Terraform client for session") + + // Try to get existing client + client := GetTerraformClient(session.SessionID()) + if client != nil { + return client, nil + } + + logger.WithField("session_id", session.SessionID()).Warn("Terraform client not found, creating a new one") + return CreateTerraformClientForSession(ctx, session, logger) +} + +func CreateTerraformClientForSession(ctx context.Context, session server.ClientSession, logger *log.Logger) (*terraformClients, error) { + // Initialize a new Terraform client for this session + terraformAddress, ok := ctx.Value(contextKey(TerraformAddress)).(string) + if !ok || terraformAddress == "" { + terraformAddress = getEnv(TerraformAddress, DefaultTerraformAddress) + } + + terraformToken, ok := ctx.Value(contextKey(TerraformToken)).(string) + if !ok || terraformToken == "" { + terraformToken = getEnv(TerraformToken, "") + } + + terraformSkipTLSVerifyStr, ok := ctx.Value(contextKey(TerraformSkipTLSVerify)).(string) + terraformSkipTLSVerify := false + if ok && terraformSkipTLSVerifyStr != "" { + var err error + terraformSkipTLSVerify, err = strconv.ParseBool(terraformSkipTLSVerifyStr) + if err != nil { + terraformSkipTLSVerify = false + } + } + + newClient := NewTerraformClient(session.SessionID(), terraformAddress, terraformSkipTLSVerify, terraformToken, logger) + return newClient, nil +} + +// NewSessionHandler initializes a new Terraform client for the session +func NewSessionHandler(ctx context.Context, session server.ClientSession, logger *log.Logger) { + terraformClient, err := CreateTerraformClientForSession(ctx, session, logger) + if err != nil { + logger.WithError(err).Error("NewSessionHandler failed to create Terraform client") + return + } + + // Check if the session has a valid TFE client and register with dynamic tool registry + if terraformClient.TfeClient != nil { + // Import the tools package to access the registry + // We need to avoid circular imports, so we'll use a callback approach + if registryCallback := getToolRegistryCallback(); registryCallback != nil { + registryCallback.RegisterSessionWithTFE(session.SessionID()) + } + logger.WithField("session_id", session.SessionID()).Info("Session has valid TFE client - registered with tool registry") + } else { + logger.WithField("session_id", session.SessionID()).Info("Session has no valid TFE client - TFE tools will not be available") + } +} + +// EndSessionHandler cleans up the Terraform client when the session ends +func EndSessionHandler(_ context.Context, session server.ClientSession, logger *log.Logger) { + // Unregister from tool registry if it was registered + if registryCallback := getToolRegistryCallback(); registryCallback != nil { + registryCallback.UnregisterSessionWithTFE(session.SessionID()) + } + + DeleteTerraformClient(session.SessionID()) + logger.WithField("session_id", session.SessionID()).Info("Cleaned up Terraform client for session") +} + +// ToolRegistryCallback defines the interface for interacting with the tool registry +type ToolRegistryCallback interface { + RegisterSessionWithTFE(sessionID string) + UnregisterSessionWithTFE(sessionID string) +} + +var toolRegistryCallback ToolRegistryCallback + +// SetToolRegistryCallback sets the callback for tool registry operations +func SetToolRegistryCallback(callback ToolRegistryCallback) { + toolRegistryCallback = callback +} + +// getToolRegistryCallback returns the current tool registry callback +func getToolRegistryCallback() ToolRegistryCallback { + return toolRegistryCallback +} diff --git a/pkg/client/common.go b/pkg/client/common.go index 3b34953..5da842a 100644 --- a/pkg/client/common.go +++ b/pkg/client/common.go @@ -12,9 +12,9 @@ import ( log "github.com/sirupsen/logrus" ) -func GetLatestProviderVersion(providerClient *http.Client, providerNamespace, providerName interface{}, logger *log.Logger) (string, error) { +func GetLatestProviderVersion(httpClient *http.Client, providerNamespace, providerName interface{}, logger *log.Logger) (string, error) { uri := fmt.Sprintf("providers/%s/%s", providerNamespace, providerName) - jsonData, err := SendRegistryCall(providerClient, "GET", uri, logger, "v1") + jsonData, err := SendRegistryCall(httpClient, "GET", uri, logger, "v1") if err != nil { return "", utils.LogAndReturnError(logger, "latest provider version API request", err) } @@ -30,9 +30,9 @@ func GetLatestProviderVersion(providerClient *http.Client, providerNamespace, pr // Every provider version has a unique ID, which is used to identify the provider version in the registry and its specific documentation // https://registry.terraform.io/v2/providers/hashicorp/aws?include=provider-versions -func GetProviderVersionID(registryClient *http.Client, namespace string, name string, version string, logger *log.Logger) (string, error) { +func GetProviderVersionID(httpClient *http.Client, namespace string, name string, version string, logger *log.Logger) (string, error) { uri := fmt.Sprintf("providers/%s/%s?include=provider-versions", namespace, name) - response, err := SendRegistryCall(registryClient, "GET", uri, logger, "v2") + response, err := SendRegistryCall(httpClient, "GET", uri, logger, "v2") if err != nil { return "", utils.LogAndReturnError(logger, "provider version ID request", err) } @@ -48,10 +48,10 @@ func GetProviderVersionID(registryClient *http.Client, namespace string, name st return "", fmt.Errorf("provider version %s not found", version) } -func GetProviderOverviewDocs(registryClient *http.Client, providerVersionID string, logger *log.Logger) (string, error) { +func GetProviderOverviewDocs(httpClient *http.Client, providerVersionID string, logger *log.Logger) (string, error) { // https://registry.terraform.io/v2/provider-docs?filter[provider-version]=21818&filter[category]=overview&filter[slug]=index uri := fmt.Sprintf("provider-docs?filter[provider-version]=%s&filter[category]=overview&filter[slug]=index", providerVersionID) - response, err := SendRegistryCall(registryClient, "GET", uri, logger, "v2") + response, err := SendRegistryCall(httpClient, "GET", uri, logger, "v2") if err != nil { return "", utils.LogAndReturnError(logger, "getting provider docs overview", err) } @@ -62,7 +62,7 @@ func GetProviderOverviewDocs(registryClient *http.Client, providerVersionID stri resourceContent := "" for _, providerOverviewPage := range providerOverview.Data { - resourceContentNew, err := GetProviderResourceDocs(registryClient, providerOverviewPage.ID, logger) + resourceContentNew, err := GetProviderResourceDocs(httpClient, providerOverviewPage.ID, logger) resourceContent += resourceContentNew if err != nil { return "", utils.LogAndReturnError(logger, "getting provider resource docs looping", err) @@ -72,10 +72,10 @@ func GetProviderOverviewDocs(registryClient *http.Client, providerVersionID stri return resourceContent, nil } -func GetProviderResourceDocs(registryClient *http.Client, providerDocsID string, logger *log.Logger) (string, error) { +func GetProviderResourceDocs(httpClient *http.Client, providerDocsID string, logger *log.Logger) (string, error) { // https://registry.terraform.io/v2/provider-docs/8862001 uri := fmt.Sprintf("provider-docs/%s", providerDocsID) - response, err := SendRegistryCall(registryClient, "GET", uri, logger, "v2") + response, err := SendRegistryCall(httpClient, "GET", uri, logger, "v2") if err != nil { return "", utils.LogAndReturnError(logger, "Error getting provider resource docs ", err) } diff --git a/pkg/client/security.go b/pkg/client/middleware.go similarity index 61% rename from pkg/client/security.go rename to pkg/client/middleware.go index 9e57896..baa8d94 100644 --- a/pkg/client/security.go +++ b/pkg/client/middleware.go @@ -4,7 +4,10 @@ package client import ( + "context" + "fmt" "net/http" + "net/textproto" "os" "strings" @@ -120,3 +123,53 @@ func NewSecurityHandler(handler http.Handler, allowedOrigins []string, corsMode logger: logger, } } + +// TerraformContextMiddleware adds Terraform-related header values to the request context +// This middleware extracts Terraform configuration from HTTP headers, query parameters, +// or environment variables and adds them to the request context for use by MCP tools +func TerraformContextMiddleware(logger *log.Logger) func(http.Handler) http.Handler { + return func(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + requiredHeaders := []string{TerraformAddress, TerraformToken, TerraformSkipTLSVerify} + ctx := r.Context() + /* + if !r.URL.Query().Has("Authorization") || r.Header.Get("Authorization") == "" { + http.Error(w, "Unauthorized: Please provide valid credentials", http.StatusUnauthorized) + return + } + */ + for _, header := range requiredHeaders { + // Priority order: HTTP header -> Query parameter -> Environment variable + headerValue := r.Header.Get(textproto.CanonicalMIMEHeaderKey(header)) + + if headerValue == "" { + headerValue = r.URL.Query().Get(header) + + // Explicitly disallow TerraformToken in query parameters for security reasons + if header == TerraformToken && headerValue != "" { + logger.Info(fmt.Sprintf("Terraform token was provided in query parameters by client %v, termiating request", r.RemoteAddr)) + http.Error(w, "Terraform token should not be provided in query parameters for security reasons, use the terraform_token header", http.StatusBadRequest) + return + } + } + + if headerValue == "" { + headerValue = getEnv(header, "") + } + + // Add to context using the header name as key + ctx = context.WithValue(ctx, contextKey(header), headerValue) + + // Log the source of the configuration (without exposing sensitive values) + if header == TerraformToken && headerValue != "" { + logger.Debug("Terraform token provided via request context") + } else if header == TerraformAddress && headerValue != "" { + logger.Debug("Terraform address configured via request context") + } + } + + // Call the next handler with the enriched context + next.ServeHTTP(w, r.WithContext(ctx)) + }) + } +} diff --git a/pkg/client/security_test.go b/pkg/client/middleware_test.go similarity index 50% rename from pkg/client/security_test.go rename to pkg/client/middleware_test.go index f995e28..41a2df4 100644 --- a/pkg/client/security_test.go +++ b/pkg/client/middleware_test.go @@ -7,6 +7,7 @@ import ( "net/http" "net/http/httptest" "os" + "strings" "testing" log "github.com/sirupsen/logrus" @@ -313,3 +314,335 @@ func TestOptionsRequest(t *testing.T) { assert.Equal(t, "https://example.com", rr.Header().Get("Access-Control-Allow-Origin")) assert.NotEmpty(t, rr.Header().Get("Access-Control-Allow-Methods")) } + +// TestTerraformContextMiddleware tests the middleware that extracts Terraform configuration +// from HTTP headers, query parameters, and environment variables and adds them to the request context +func TestTerraformContextMiddleware(t *testing.T) { + logger := log.New() + logger.SetLevel(log.ErrorLevel) // Reduce noise in tests + + // Save original env vars to restore later + origAddress := os.Getenv(TerraformAddress) + origToken := os.Getenv(TerraformToken) + origSkipTLS := os.Getenv(TerraformSkipTLSVerify) + defer func() { + os.Setenv(TerraformAddress, origAddress) + os.Setenv(TerraformToken, origToken) + os.Setenv(TerraformSkipTLSVerify, origSkipTLS) + }() + + // Clear environment variables for clean test state + os.Unsetenv(TerraformAddress) + os.Unsetenv(TerraformToken) + os.Unsetenv(TerraformSkipTLSVerify) + + tests := []struct { + name string + headers map[string]string + queryParams map[string]string + envVars map[string]string + expectedStatus int + expectedContextVals map[string]string + expectError bool + errorMessage string + }{ + { + name: "headers take priority over query params and env vars", + headers: map[string]string{ + TerraformAddress: "https://header.terraform.io", + TerraformToken: "header-token", + TerraformSkipTLSVerify: "true", + }, + queryParams: map[string]string{ + TerraformAddress: "https://query.terraform.io", + TerraformSkipTLSVerify: "false", + }, + envVars: map[string]string{ + TerraformAddress: "https://env.terraform.io", + TerraformToken: "env-token", + TerraformSkipTLSVerify: "false", + }, + expectedStatus: http.StatusOK, + expectedContextVals: map[string]string{ + TerraformAddress: "https://header.terraform.io", + TerraformToken: "header-token", + TerraformSkipTLSVerify: "true", + }, + }, + { + name: "query params take priority over env vars (except token)", + headers: map[string]string{}, + queryParams: map[string]string{ + TerraformAddress: "https://query.terraform.io", + TerraformSkipTLSVerify: "true", + }, + envVars: map[string]string{ + TerraformAddress: "https://env.terraform.io", + TerraformToken: "env-token", + TerraformSkipTLSVerify: "false", + }, + expectedStatus: http.StatusOK, + expectedContextVals: map[string]string{ + TerraformAddress: "https://query.terraform.io", + TerraformToken: "env-token", // From env since not in query + TerraformSkipTLSVerify: "true", + }, + }, + { + name: "env vars used as fallback", + headers: map[string]string{}, + queryParams: map[string]string{}, + envVars: map[string]string{ + TerraformAddress: "https://env.terraform.io", + TerraformToken: "env-token", + TerraformSkipTLSVerify: "true", + }, + expectedStatus: http.StatusOK, + expectedContextVals: map[string]string{ + TerraformAddress: "https://env.terraform.io", + TerraformToken: "env-token", + TerraformSkipTLSVerify: "true", + }, + }, + { + name: "empty values result in empty context values", + headers: map[string]string{}, + queryParams: map[string]string{ + TerraformAddress: "", // Empty value + }, + envVars: map[string]string{}, + expectedStatus: http.StatusOK, + expectedContextVals: map[string]string{ + TerraformAddress: "", + TerraformToken: "", + TerraformSkipTLSVerify: "", + }, + }, + { + name: "token in query params is rejected for security", + headers: map[string]string{}, + queryParams: map[string]string{ + TerraformAddress: "https://query.terraform.io", + TerraformToken: "query-token", // This should cause an error + }, + envVars: map[string]string{}, + expectedStatus: http.StatusBadRequest, + expectError: true, + errorMessage: "Terraform token should not be provided in query parameters for security reasons, use the terraform_token header", + }, + { + name: "canonical header names are handled correctly", + headers: map[string]string{ + "tfe_address": "https://canonical.terraform.io", // lowercase + "TFE_TOKEN": "canonical-token", // uppercase + "Tfe_Skip_Verify": "true", // mixed case + }, + queryParams: map[string]string{}, + envVars: map[string]string{}, + expectedStatus: http.StatusOK, + expectedContextVals: map[string]string{ + TerraformAddress: "https://canonical.terraform.io", + TerraformToken: "canonical-token", + TerraformSkipTLSVerify: "true", + }, + }, + { + name: "mixed sources - headers override query params, query params override env", + headers: map[string]string{ + TerraformAddress: "https://header.terraform.io", // Header wins + }, + queryParams: map[string]string{ + TerraformSkipTLSVerify: "true", // Query param wins over env + }, + envVars: map[string]string{ + TerraformAddress: "https://env.terraform.io", // Overridden by header + TerraformToken: "env-token", // Used since not in header/query + TerraformSkipTLSVerify: "false", // Overridden by query param + }, + expectedStatus: http.StatusOK, + expectedContextVals: map[string]string{ + TerraformAddress: "https://header.terraform.io", + TerraformToken: "env-token", + TerraformSkipTLSVerify: "true", + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + // Set up environment variables for this test + for key, value := range tt.envVars { + os.Setenv(key, value) + } + defer func() { + for key := range tt.envVars { + os.Unsetenv(key) + } + }() + + // Create a mock handler that captures the context values + var capturedContext map[string]string + mockHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + capturedContext = make(map[string]string) + ctx := r.Context() + + // Extract all terraform-related context values + for _, key := range []string{TerraformAddress, TerraformToken, TerraformSkipTLSVerify} { + if val := ctx.Value(contextKey(key)); val != nil { + if strVal, ok := val.(string); ok { + capturedContext[key] = strVal + } + } else { + capturedContext[key] = "" // Explicitly track nil/missing values as empty strings + } + } + + w.WriteHeader(http.StatusOK) + w.Write([]byte("success")) + }) + + // Create the middleware + middleware := TerraformContextMiddleware(logger) + handler := middleware(mockHandler) + + // Create request with headers and query parameters + req := httptest.NewRequest("GET", "/mcp", nil) + + // Set headers + for key, value := range tt.headers { + req.Header.Set(key, value) + } + + // Set query parameters + q := req.URL.Query() + for key, value := range tt.queryParams { + q.Set(key, value) + } + req.URL.RawQuery = q.Encode() + + // Execute request + rr := httptest.NewRecorder() + handler.ServeHTTP(rr, req) + + // Verify response status + assert.Equal(t, tt.expectedStatus, rr.Code) + + if tt.expectError { + // Verify error message is in response body + assert.Contains(t, rr.Body.String(), tt.errorMessage) + } else { + // Verify context values were set correctly + assert.NotNil(t, capturedContext, "Context should have been captured") + for key, expectedValue := range tt.expectedContextVals { + actualValue, exists := capturedContext[key] + assert.True(t, exists, "Context should contain key %s", key) + assert.Equal(t, expectedValue, actualValue, "Context value for %s should match", key) + } + } + }) + } +} + +// TestTerraformContextMiddleware_SecurityLogging tests that the middleware properly logs +// security-related events without exposing sensitive information +func TestTerraformContextMiddleware_SecurityLogging(t *testing.T) { + // Create a custom logger that captures log output + logger := log.New() + logger.SetLevel(log.DebugLevel) + + // Create a mock handler + mockHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + }) + + middleware := TerraformContextMiddleware(logger) + handler := middleware(mockHandler) + + t.Run("token provided via header is logged without exposing value", func(t *testing.T) { + req := httptest.NewRequest("GET", "/mcp", nil) + req.Header.Set(TerraformToken, "secret-token") + + rr := httptest.NewRecorder() + handler.ServeHTTP(rr, req) + + assert.Equal(t, http.StatusOK, rr.Code) + // Note: In a real test, you'd capture the log output and verify it contains + // "Terraform token provided via request context" but doesn't contain "secret-token" + }) + + t.Run("address provided via header is logged", func(t *testing.T) { + req := httptest.NewRequest("GET", "/mcp", nil) + req.Header.Set(TerraformAddress, "https://custom.terraform.io") + + rr := httptest.NewRecorder() + handler.ServeHTTP(rr, req) + + assert.Equal(t, http.StatusOK, rr.Code) + // Note: In a real test, you'd capture the log output and verify it contains + // "Terraform address configured via request context" + }) +} + +// TestTerraformContextMiddleware_EdgeCases tests edge cases and error conditions +func TestTerraformContextMiddleware_EdgeCases(t *testing.T) { + logger := log.New() + logger.SetLevel(log.ErrorLevel) + + t.Run("nil logger should not panic", func(t *testing.T) { + // This tests that the middleware handles a nil logger gracefully + mockHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + }) + + // Create middleware with nil logger - this should not panic + assert.NotPanics(t, func() { + middleware := TerraformContextMiddleware(nil) + handler := middleware(mockHandler) + + req := httptest.NewRequest("GET", "/mcp", nil) + rr := httptest.NewRecorder() + handler.ServeHTTP(rr, req) + }) + }) + + t.Run("malformed query parameters are handled gracefully", func(t *testing.T) { + mockHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + }) + + middleware := TerraformContextMiddleware(logger) + handler := middleware(mockHandler) + + // Create request with malformed query string + req := httptest.NewRequest("GET", "/mcp?%invalid", nil) + + rr := httptest.NewRecorder() + // This should not panic even with malformed query parameters + assert.NotPanics(t, func() { + handler.ServeHTTP(rr, req) + }) + }) + + t.Run("very long header values are handled", func(t *testing.T) { + mockHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + ctx := r.Context() + val := ctx.Value(contextKey(TerraformAddress)) + assert.NotNil(t, val) + w.WriteHeader(http.StatusOK) + }) + + middleware := TerraformContextMiddleware(logger) + handler := middleware(mockHandler) + + // Create a very long address value + longAddress := "https://" + strings.Repeat("a", 1000) + ".terraform.io" + + req := httptest.NewRequest("GET", "/mcp", nil) + req.Header.Set(TerraformAddress, longAddress) + + rr := httptest.NewRecorder() + handler.ServeHTTP(rr, req) + + assert.Equal(t, http.StatusOK, rr.Code) + }) +} diff --git a/pkg/client/registry.go b/pkg/client/registry.go index 5dd2736..02dff35 100644 --- a/pkg/client/registry.go +++ b/pkg/client/registry.go @@ -4,16 +4,61 @@ package client import ( + "context" + "crypto/tls" "encoding/json" "fmt" "io" "net/http" "net/url" + "strconv" + "time" + "github.com/hashicorp/go-cleanhttp" + "github.com/hashicorp/go-retryablehttp" "github.com/hashicorp/terraform-mcp-server/pkg/utils" log "github.com/sirupsen/logrus" ) +// createHTTPClient initializes a retryable HTTP client +func createHTTPClient(insecureSkipVerify bool, logger *log.Logger) *http.Client { + retryClient := retryablehttp.NewClient() + retryClient.Logger = logger + + transport := &http.Transport{ + TLSClientConfig: &tls.Config{InsecureSkipVerify: insecureSkipVerify}, + } + transport.Proxy = http.ProxyFromEnvironment + + retryClient.HTTPClient = cleanhttp.DefaultClient() + retryClient.HTTPClient.Timeout = 10 * time.Second + retryClient.HTTPClient.Transport = transport + retryClient.RetryMax = 3 + + retryClient.Backoff = func(min, max time.Duration, attemptNum int, resp *http.Response) time.Duration { + if resp != nil && resp.StatusCode == http.StatusTooManyRequests { + resetAfter := resp.Header.Get("x-ratelimit-reset") + resetAfterInt, err := strconv.ParseInt(resetAfter, 10, 64) + if err != nil { + return 0 + } + resetAfterTime := time.Unix(resetAfterInt, 0) + return time.Until(resetAfterTime) + } + return 0 + } + + retryClient.CheckRetry = func(ctx context.Context, resp *http.Response, err error) (bool, error) { + if resp != nil && resp.StatusCode == http.StatusTooManyRequests { + resetAfter := resp.Header.Get("x-ratelimit-reset") + return resetAfter != "", nil + } + return false, nil + } + + return retryClient.StandardClient() +} + func SendRegistryCall(client *http.Client, method string, uri string, logger *log.Logger, callOptions ...string) ([]byte, error) { version := "v1" if len(callOptions) > 0 { diff --git a/pkg/resources/resource_templates.go b/pkg/resources/resource_templates.go index 75ea783..663068f 100644 --- a/pkg/resources/resource_templates.go +++ b/pkg/resources/resource_templates.go @@ -15,11 +15,11 @@ import ( log "github.com/sirupsen/logrus" ) -func RegisterResourceTemplates(hcServer *server.MCPServer, registryClient *http.Client, logger *log.Logger) { - hcServer.AddResourceTemplate(ProviderResourceTemplate(registryClient, fmt.Sprintf("%s/{namespace}/name/{name}/version/{version}", utils.PROVIDER_BASE_PATH), "Provider details", logger)) +func RegisterResourceTemplates(hcServer *server.MCPServer, logger *log.Logger) { + hcServer.AddResourceTemplate(ProviderResourceTemplate(fmt.Sprintf("%s/{namespace}/name/{name}/version/{version}", utils.PROVIDER_BASE_PATH), "Provider details", logger)) } -func ProviderResourceTemplate(registryClient *http.Client, resourceURI string, description string, logger *log.Logger) (mcp.ResourceTemplate, server.ResourceTemplateHandlerFunc) { +func ProviderResourceTemplate(resourceURI string, description string, logger *log.Logger) (mcp.ResourceTemplate, server.ResourceTemplateHandlerFunc) { return mcp.NewResourceTemplate( resourceURI, description, @@ -32,7 +32,15 @@ func ProviderResourceTemplate(registryClient *http.Client, resourceURI string, d ), func(ctx context.Context, request mcp.ReadResourceRequest) ([]mcp.ResourceContents, error) { logger.Debugf("Provider resource template - resourceURI: %s", request.Params.URI) - providerDocs, err := ProviderResourceTemplateHandler(registryClient, request.Params.URI, logger) + + // Get a simple http client to access the public Terraform registry from context + terraformClients, err := client.GetTerraformClientFromContext(ctx, logger) + if err != nil { + return nil, utils.LogAndReturnError(logger, "failed to get http client for public Terraform registry", err) + } + + httpClient := terraformClients.HttpClient + providerDocs, err := providerResourceTemplateHelper(httpClient, request.Params.URI, logger) if err != nil { return nil, utils.LogAndReturnError(logger, "Provider Resource: error getting provider details", err) } @@ -46,13 +54,14 @@ func ProviderResourceTemplate(registryClient *http.Client, resourceURI string, d } } -func ProviderResourceTemplateHandler(registryClient *http.Client, resourceURI string, logger *log.Logger) (string, error) { +// providerResourceTemplateHelper fetches the provider details based on the resource URI +func providerResourceTemplateHelper(httpClient *http.Client, resourceURI string, logger *log.Logger) (string, error) { namespace, name, version := utils.ExtractProviderNameAndVersion(resourceURI) logger.Debugf("Extracted namespace: %s, name: %s, version: %s", namespace, name, version) var err error if version == "" || version == "latest" || !utils.IsValidProviderVersionFormat(version) { - version, err = client.GetLatestProviderVersion(registryClient, namespace, name, logger) + version, err = client.GetLatestProviderVersion(httpClient, namespace, name, logger) if err != nil { return "", utils.LogAndReturnError(logger, fmt.Sprintf("Provider Resource: error getting %s/%s latest provider version", namespace, name), err) } @@ -64,14 +73,14 @@ func ProviderResourceTemplateHandler(registryClient *http.Client, resourceURI st } // Get the provider-version-id for the specified provider version - providerVersionID, err := client.GetProviderVersionID(registryClient, namespace, name, version, logger) + providerVersionID, err := client.GetProviderVersionID(httpClient, namespace, name, version, logger) logger.Debugf("Provider resource template - Provider version id providerVersionID: %s, providerVersionUri: %s", providerVersionID, providerVersionUri) if err != nil { return "", utils.LogAndReturnError(logger, "getting provider details", err) } // Get all the docs based on provider version id - providerDocs, err := client.GetProviderOverviewDocs(registryClient, providerVersionID, logger) + providerDocs, err := client.GetProviderOverviewDocs(httpClient, providerVersionID, logger) logger.Debugf("Provider resource template - Provider docs providerVersionID: %s", providerVersionID) if err != nil { return "", utils.LogAndReturnError(logger, "getting provider details", err) diff --git a/pkg/resources/resources.go b/pkg/resources/resources.go index 95229b8..4bb47c1 100644 --- a/pkg/resources/resources.go +++ b/pkg/resources/resources.go @@ -9,6 +9,7 @@ import ( "io" "net/http" + "github.com/hashicorp/terraform-mcp-server/pkg/client" "github.com/hashicorp/terraform-mcp-server/pkg/utils" "github.com/mark3labs/mcp-go/mcp" "github.com/mark3labs/mcp-go/server" @@ -19,13 +20,13 @@ import ( const terraformGuideRawURL = "https://raw.githubusercontent.com/hashicorp/web-unified-docs/main/content/terraform/v1.12.x/docs/language" // RegisterResources adds the new resource -func RegisterResources(hcServer *server.MCPServer, registryClient *http.Client, logger *log.Logger) { - hcServer.AddResource(TerraformStyleGuideResource(registryClient, logger)) - hcServer.AddResource(TerraformModuleDevGuideResource(registryClient, logger)) +func RegisterResources(hcServer *server.MCPServer, logger *log.Logger) { + hcServer.AddResource(TerraformStyleGuideResource(logger)) + hcServer.AddResource(TerraformModuleDevGuideResource(logger)) } // TerraformStyleGuideResource returns the resource and handler for the style guide -func TerraformStyleGuideResource(httpClient *http.Client, logger *log.Logger) (mcp.Resource, server.ResourceHandlerFunc) { +func TerraformStyleGuideResource(logger *log.Logger) (mcp.Resource, server.ResourceHandlerFunc) { resourceURI := "/terraform/style-guide" description := "Terraform Style Guide" @@ -36,6 +37,14 @@ func TerraformStyleGuideResource(httpClient *http.Client, logger *log.Logger) (m mcp.WithResourceDescription(description), ), func(ctx context.Context, request mcp.ReadResourceRequest) ([]mcp.ResourceContents, error) { + + // Get a simple http client to access the public Terraform registry from context + terraformClients, err := client.GetTerraformClientFromContext(ctx, logger) + if err != nil { + return nil, utils.LogAndReturnError(logger, "failed to get http client for public Terraform registry", err) + } + + httpClient := terraformClients.HttpClient resp, err := httpClient.Get(fmt.Sprintf("%s/style.mdx", terraformGuideRawURL)) if err != nil { return nil, utils.LogAndReturnError(logger, "Error fetching Terraform Style Guide markdown", err) @@ -59,7 +68,7 @@ func TerraformStyleGuideResource(httpClient *http.Client, logger *log.Logger) (m } // TerraformModuleDevGuideResource returns a resource and handler for the Terraform Module Development Guide markdown files -func TerraformModuleDevGuideResource(httpClient *http.Client, logger *log.Logger) (mcp.Resource, server.ResourceHandlerFunc) { +func TerraformModuleDevGuideResource(logger *log.Logger) (mcp.Resource, server.ResourceHandlerFunc) { resourceURI := "/terraform/module-development" description := "Terraform Module Development Guide" @@ -82,6 +91,13 @@ func TerraformModuleDevGuideResource(httpClient *http.Client, logger *log.Logger mcp.WithResourceDescription(description), ), func(ctx context.Context, request mcp.ReadResourceRequest) ([]mcp.ResourceContents, error) { + // Get a simple http client to access the public Terraform registry from context + terraformClients, err := client.GetTerraformClientFromContext(ctx, logger) + if err != nil { + return nil, utils.LogAndReturnError(logger, "failed to get http client for public Terraform registry", err) + } + httpClient := terraformClients.HttpClient + var contents []mcp.ResourceContents for _, u := range urls { resp, err := httpClient.Get(u.URL) diff --git a/pkg/tools/dynamic_tool.go b/pkg/tools/dynamic_tool.go new file mode 100644 index 0000000..8a9886a --- /dev/null +++ b/pkg/tools/dynamic_tool.go @@ -0,0 +1,145 @@ +// Copyright (c) HashiCorp, Inc. +// SPDX-License-Identifier: MPL-2.0 + +package tools + +import ( + "context" + "sync" + + "github.com/hashicorp/terraform-mcp-server/pkg/client" + "github.com/mark3labs/mcp-go/mcp" + "github.com/mark3labs/mcp-go/server" + log "github.com/sirupsen/logrus" +) + +// DynamicToolRegistry manages the availability of tools based on session state +type DynamicToolRegistry struct { + mu sync.RWMutex + sessionsWithTFE map[string]bool // sessionID -> hasTFEClient + tfeToolsRegistered bool + mcpServer *server.MCPServer + logger *log.Logger +} + +var globalToolRegistry *DynamicToolRegistry + +// registerDynamicTools registers the global tool registry +func registerDynamicTools(mcpServer *server.MCPServer, logger *log.Logger) { + globalToolRegistry = &DynamicToolRegistry{ + sessionsWithTFE: make(map[string]bool), + tfeToolsRegistered: false, + mcpServer: mcpServer, + logger: logger, + } + + // Set the callback in the client package to avoid circular imports + client.SetToolRegistryCallback(globalToolRegistry) +} + +// GetDynamicToolRegistry returns the global tool registry instance +func GetDynamicToolRegistry() *DynamicToolRegistry { + return globalToolRegistry +} + +// RegisterSessionWithTFE marks a session as having a valid TFE client +func (r *DynamicToolRegistry) RegisterSessionWithTFE(sessionID string) { + r.mu.Lock() + defer r.mu.Unlock() + + r.sessionsWithTFE[sessionID] = true + r.logger.WithField("session_id", sessionID).Info("Session registered with TFE client") + + // If this is the first session with TFE, register the tools + if !r.tfeToolsRegistered { + r.registerTFETools() + } +} + +// UnregisterSessionWithTFE removes a session from the TFE registry +func (r *DynamicToolRegistry) UnregisterSessionWithTFE(sessionID string) { + r.mu.Lock() + defer r.mu.Unlock() + + delete(r.sessionsWithTFE, sessionID) + r.logger.WithField("session_id", sessionID).Info("Session unregistered from TFE client") + + // If no sessions have TFE clients, we could unregister tools + // but since MCP doesn't support tool removal, we keep them registered + // and rely on runtime checks +} + +// HasSessionWithTFE checks if a specific session has a TFE client +func (r *DynamicToolRegistry) HasSessionWithTFE(sessionID string) bool { + r.mu.RLock() + defer r.mu.RUnlock() + + return r.sessionsWithTFE[sessionID] +} + +// HasAnySessionWithTFE checks if any session has a TFE client +func (r *DynamicToolRegistry) HasAnySessionWithTFE() bool { + r.mu.RLock() + defer r.mu.RUnlock() + + return len(r.sessionsWithTFE) > 0 +} + +// registerTFETools registers TFE tools with the MCP server +func (r *DynamicToolRegistry) registerTFETools() { + if r.tfeToolsRegistered { + return + } + + r.logger.Info("Registering TFE tools - first session with valid TFE client detected") + + // Create TFE tools with dynamic availability checking + listTerraformOrgsTool := r.createDynamicTFETool("list_terraform_orgs", ListTerraformOrgs) + r.mcpServer.AddTool(listTerraformOrgsTool.Tool, listTerraformOrgsTool.Handler) + + listTerraformProjectsTool := r.createDynamicTFETool("list_terraform_projects", ListTerraformProjects) + r.mcpServer.AddTool(listTerraformProjectsTool.Tool, listTerraformProjectsTool.Handler) + + r.tfeToolsRegistered = true +} + +// createDynamicTFETool creates a TFE tool with dynamic availability checking +func (r *DynamicToolRegistry) createDynamicTFETool(toolName string, toolFactory func(*log.Logger) server.ServerTool) server.ServerTool { + originalTool := toolFactory(r.logger) + + // Wrap the handler with dynamic availability checking + wrappedHandler := func(ctx context.Context, req mcp.CallToolRequest) (*mcp.CallToolResult, error) { + // Get session from context + session := server.ClientSessionFromContext(ctx) + if session == nil { + r.logger.WithField("tool", toolName).Warn("TFE tool called without session context") + return mcp.NewToolResultError("This tool requires an active session with valid Terraform Cloud/Enterprise configuration."), nil + } + + // Check if this session has a valid TFE client + sessionID := session.SessionID() + if !r.HasSessionWithTFE(sessionID) { + // Double-check by looking at the actual client state + terraformClient := client.GetTerraformClient(sessionID) + if terraformClient == nil || terraformClient.TfeClient == nil { + r.logger.WithFields(log.Fields{ + "tool": toolName, + "session_id": sessionID, + }).Warn("TFE tool called but session has no valid TFE client") + + return mcp.NewToolResultError("This tool is not available. This tool requires a valid Terraform Cloud/Enterprise token and configuration. Please ensure TFE_TOKEN and TFE_ADDRESS environment variables are properly set."), nil + } + + // If we found a valid client that wasn't registered, register it now + r.RegisterSessionWithTFE(sessionID) + } + + // Tool is available, proceed with original handler + return originalTool.Handler(ctx, req) + } + + return server.ServerTool{ + Tool: originalTool.Tool, + Handler: wrappedHandler, + } +} diff --git a/pkg/tools/dynamic_tool_test.go b/pkg/tools/dynamic_tool_test.go new file mode 100644 index 0000000..da0c4a2 --- /dev/null +++ b/pkg/tools/dynamic_tool_test.go @@ -0,0 +1,164 @@ +// Copyright (c) HashiCorp, Inc. +// SPDX-License-Identifier: MPL-2.0 + +package tools + +import ( + "fmt" + "testing" + + log "github.com/sirupsen/logrus" +) + +func TestDynamicToolRegistry_SessionManagement(t *testing.T) { + logger := log.New() + logger.SetLevel(log.ErrorLevel) // Reduce noise in tests + + // Create a registry without initializing the MCP server + registry := &DynamicToolRegistry{ + sessionsWithTFE: make(map[string]bool), + tfeToolsRegistered: false, + mcpServer: nil, // We'll skip actual tool registration + logger: logger, + } + + // Initially no sessions should have TFE + if registry.HasAnySessionWithTFE() { + t.Error("Expected no sessions with TFE initially") + } + + sessionID1 := "test-session-1" + sessionID2 := "test-session-2" + + // Check specific sessions + if registry.HasSessionWithTFE(sessionID1) { + t.Error("Expected session1 to not have TFE initially") + } + + // Manually register sessions (without triggering tool registration) + registry.mu.Lock() + registry.sessionsWithTFE[sessionID1] = true + registry.mu.Unlock() + + if !registry.HasSessionWithTFE(sessionID1) { + t.Error("Expected session1 to have TFE after registration") + } + + if !registry.HasAnySessionWithTFE() { + t.Error("Expected at least one session with TFE") + } + + if registry.HasSessionWithTFE(sessionID2) { + t.Error("Expected session2 to not have TFE") + } + + // Register second session + registry.mu.Lock() + registry.sessionsWithTFE[sessionID2] = true + registry.mu.Unlock() + + if !registry.HasSessionWithTFE(sessionID2) { + t.Error("Expected session2 to have TFE after registration") + } + + // Unregister first session + registry.UnregisterSessionWithTFE(sessionID1) + + if registry.HasSessionWithTFE(sessionID1) { + t.Error("Expected session1 to not have TFE after unregistration") + } + + if !registry.HasSessionWithTFE(sessionID2) { + t.Error("Expected session2 to still have TFE") + } + + if !registry.HasAnySessionWithTFE() { + t.Error("Expected session2 to still provide TFE availability") + } + + // Unregister second session + registry.UnregisterSessionWithTFE(sessionID2) + + if registry.HasSessionWithTFE(sessionID2) { + t.Error("Expected session2 to not have TFE after unregistration") + } + + if registry.HasAnySessionWithTFE() { + t.Error("Expected no sessions with TFE after all unregistered") + } +} + +func TestDynamicToolRegistry_ToolRegistrationState(t *testing.T) { + logger := log.New() + logger.SetLevel(log.ErrorLevel) // Reduce noise in tests + + // Create a registry without MCP server to test state management + registry := &DynamicToolRegistry{ + sessionsWithTFE: make(map[string]bool), + tfeToolsRegistered: false, + mcpServer: nil, + logger: logger, + } + + // Initially tools should not be registered + if registry.tfeToolsRegistered { + t.Error("Expected TFE tools to not be registered initially") + } + + // Manually set tools as registered (simulating what would happen) + registry.mu.Lock() + registry.tfeToolsRegistered = true + registry.mu.Unlock() + + // Now tools should be registered + if !registry.tfeToolsRegistered { + t.Error("Expected TFE tools to be registered") + } +} + +func TestDynamicToolRegistry_ConcurrentAccess(t *testing.T) { + logger := log.New() + logger.SetLevel(log.ErrorLevel) // Reduce noise in tests + + // Create a registry for concurrent testing + registry := &DynamicToolRegistry{ + sessionsWithTFE: make(map[string]bool), + tfeToolsRegistered: false, + mcpServer: nil, + logger: logger, + } + + // Test concurrent registration and unregistration + done := make(chan bool, 10) + + // Start multiple goroutines registering sessions + for i := 0; i < 5; i++ { + go func(id int) { + sessionID := fmt.Sprintf("session-%d", id) + // Manually register/unregister to avoid MCP server calls + registry.mu.Lock() + registry.sessionsWithTFE[sessionID] = true + registry.mu.Unlock() + + registry.UnregisterSessionWithTFE(sessionID) + done <- true + }(i) + } + + // Start multiple goroutines checking state + for i := 0; i < 5; i++ { + go func(id int) { + sessionID := fmt.Sprintf("session-%d", id) + registry.HasSessionWithTFE(sessionID) + registry.HasAnySessionWithTFE() + done <- true + }(i) + } + + // Wait for all goroutines to complete + for i := 0; i < 10; i++ { + <-done + } + + // Test should complete without deadlocks or panics +} diff --git a/pkg/tools/get_provider_docs.go b/pkg/tools/get_provider_docs.go index 031d029..c4df71e 100644 --- a/pkg/tools/get_provider_docs.go +++ b/pkg/tools/get_provider_docs.go @@ -7,7 +7,6 @@ import ( "context" "encoding/json" "fmt" - "net/http" "github.com/hashicorp/terraform-mcp-server/pkg/client" "github.com/hashicorp/terraform-mcp-server/pkg/utils" @@ -17,24 +16,25 @@ import ( ) // GetProviderDocs creates a tool to get provider docs for a specific service from registry. -func GetProviderDocs(registryClient *http.Client, logger *log.Logger) server.ServerTool { +func GetProviderDocs(logger *log.Logger) server.ServerTool { return server.ServerTool{ Tool: mcp.NewTool("get_provider_docs", mcp.WithDescription(`Fetches up-to-date documentation for a specific service from a Terraform provider. You must call 'resolve_provider_doc_id' tool first to obtain the exact tfprovider-compatible provider_doc_id required to use this tool.`), mcp.WithTitleAnnotation("Fetch detailed Terraform provider documentation using a document ID"), mcp.WithOpenWorldHintAnnotation(true), + mcp.WithReadOnlyHintAnnotation(true), mcp.WithString("provider_doc_id", mcp.Required(), mcp.Description("Exact tfprovider-compatible provider_doc_id, (e.g., '8894603', '8906901') retrieved from 'resolve_provider_doc_id'")), ), Handler: func(ctx context.Context, req mcp.CallToolRequest) (*mcp.CallToolResult, error) { - return getProviderDocsHandler(registryClient, req, logger) + return getProviderDocsHandler(ctx, req, logger) }, } } -func getProviderDocsHandler(registryClient *http.Client, request mcp.CallToolRequest, logger *log.Logger) (*mcp.CallToolResult, error) { +func getProviderDocsHandler(ctx context.Context, request mcp.CallToolRequest, logger *log.Logger) (*mcp.CallToolResult, error) { providerDocID, err := request.RequireString("provider_doc_id") if err != nil { return nil, utils.LogAndReturnError(logger, "provider_doc_id is required", err) @@ -43,7 +43,16 @@ func getProviderDocsHandler(registryClient *http.Client, request mcp.CallToolReq return nil, utils.LogAndReturnError(logger, "provider_doc_id cannot be empty", nil) } - detailResp, err := client.SendRegistryCall(registryClient, "GET", fmt.Sprintf("provider-docs/%s", providerDocID), logger, "v2") + // Get a simple http client to access the public Terraform registry from context + terraformClients, err := client.GetTerraformClientFromContext(ctx, logger) + if err != nil { + logger.WithError(err).Error("failed to get http client for public Terraform registry") + return mcp.NewToolResultError(fmt.Sprintf("failed to get http client for public Terraform registry: %v", err)), nil + } + + httpClient := terraformClients.HttpClient + + detailResp, err := client.SendRegistryCall(httpClient, "GET", fmt.Sprintf("provider-docs/%s", providerDocID), logger, "v2") if err != nil { return nil, utils.LogAndReturnError(logger, fmt.Sprintf("Error fetching provider-docs/%s, please make sure provider_doc_id is valid and the resolve_provider_doc_id tool has run prior", providerDocID), err) } diff --git a/pkg/tools/list_terraform_orgs.go b/pkg/tools/list_terraform_orgs.go new file mode 100644 index 0000000..416a728 --- /dev/null +++ b/pkg/tools/list_terraform_orgs.go @@ -0,0 +1,66 @@ +// Copyright (c) HashiCorp, Inc. +// SPDX-License-Identifier: MPL-2.0 + +package tools + +import ( + "context" + "encoding/json" + + "github.com/hashicorp/go-tfe" + "github.com/hashicorp/terraform-mcp-server/pkg/client" + "github.com/hashicorp/terraform-mcp-server/pkg/utils" + "github.com/mark3labs/mcp-go/mcp" + "github.com/mark3labs/mcp-go/server" + log "github.com/sirupsen/logrus" +) + +// ListTerraformOrgs creates a tool to get terraform organizations. +func ListTerraformOrgs(logger *log.Logger) server.ServerTool { + return server.ServerTool{ + Tool: mcp.NewTool("list_terraform_orgs", + mcp.WithDescription(`Fetches a list of all Terraform organizations.`), + mcp.WithTitleAnnotation("List all Terraform organizations"), + mcp.WithReadOnlyHintAnnotation(true), + ), + Handler: func(ctx context.Context, req mcp.CallToolRequest) (*mcp.CallToolResult, error) { + return listTerraformOrgsHandler(ctx, logger) + }, + } +} + +func listTerraformOrgsHandler(ctx context.Context, logger *log.Logger) (*mcp.CallToolResult, error) { + // Get a Terraform client from context + terraformClients, err := client.GetTerraformClientFromContext(ctx, logger) + if err != nil { + return nil, utils.LogAndReturnError(logger, "failed to get Terraform client", err) + } + + tfeClient := terraformClients.TfeClient + if tfeClient == nil { + return nil, utils.LogAndReturnError(logger, "TFE client is not available - please ensure TFE_TOKEN and TFE_ADDRESS are properly configured", nil) + } + + orgs, err := tfeClient.Organizations.List(ctx, &tfe.OrganizationListOptions{ + ListOptions: tfe.ListOptions{ + PageSize: 100, + }, + }) + + if err != nil { + logger.WithError(err).Error("failed to list Terraform organizations") + return nil, utils.LogAndReturnError(logger, "failed to list Terraform organizations", err) + } + + orgNames := make([]string, 0, len(orgs.Items)) + for _, org := range orgs.Items { + orgNames = append(orgNames, org.Name) + } + + orgsJSON, err := json.Marshal(orgNames) + if err != nil { + return nil, utils.LogAndReturnError(logger, "failed to marshal organization names", err) + } + + return mcp.NewToolResultText(string(orgsJSON)), nil +} diff --git a/pkg/tools/list_terraform_projects.go b/pkg/tools/list_terraform_projects.go new file mode 100644 index 0000000..b78ef66 --- /dev/null +++ b/pkg/tools/list_terraform_projects.go @@ -0,0 +1,77 @@ +// Copyright (c) HashiCorp, Inc. +// SPDX-License-Identifier: MPL-2.0 + +package tools + +import ( + "context" + "encoding/json" + + "github.com/hashicorp/go-tfe" + "github.com/hashicorp/terraform-mcp-server/pkg/client" + "github.com/hashicorp/terraform-mcp-server/pkg/utils" + "github.com/mark3labs/mcp-go/mcp" + "github.com/mark3labs/mcp-go/server" + log "github.com/sirupsen/logrus" +) + +// ListTerraformProjects creates a tool to get terraform projects. +func ListTerraformProjects(logger *log.Logger) server.ServerTool { + return server.ServerTool{ + Tool: mcp.NewTool("list_terraform_projects", + mcp.WithDescription(`Fetches a list of all Terraform projects.`), + mcp.WithTitleAnnotation("List all Terraform projects"), + mcp.WithReadOnlyHintAnnotation(true), + mcp.WithString("terraform_org_name", + mcp.Required(), + mcp.Description("The name of the Terraform organization to list projects for."), + ), + ), + Handler: func(ctx context.Context, req mcp.CallToolRequest) (*mcp.CallToolResult, error) { + return listTerraformProjectsHandler(ctx, req, logger) + }, + } +} + +func listTerraformProjectsHandler(ctx context.Context, request mcp.CallToolRequest, logger *log.Logger) (*mcp.CallToolResult, error) { + terraformOrgName, err := request.RequireString("terraform_org_name") + if err != nil { + return nil, utils.LogAndReturnError(logger, "terraform_org_name is required", err) + } + if terraformOrgName == "" { + return nil, utils.LogAndReturnError(logger, "terraform_org_name cannot be empty", nil) + } + + // Get a Terraform client from context + terraformClients, err := client.GetTerraformClientFromContext(ctx, logger) + if err != nil { + return nil, utils.LogAndReturnError(logger, "failed to get Terraform client", err) + } + tfeClient := terraformClients.TfeClient + if tfeClient == nil { + return nil, utils.LogAndReturnError(logger, "TFE client is not available - please ensure TFE_TOKEN and TFE_ADDRESS are properly configured", nil) + } + + // Fetch the list of projects + projects, err := tfeClient.Projects.List(ctx, terraformOrgName, &tfe.ProjectListOptions{ + ListOptions: tfe.ListOptions{ + PageSize: 100, + }, + }) + + if err != nil { + return nil, utils.LogAndReturnError(logger, "failed to list Terraform projects, check if the organization exists and you have access", err) + } + + projectNames := make([]string, 0, len(projects.Items)) + for _, project := range projects.Items { + projectNames = append(projectNames, project.Name) + } + + projectJSON, err := json.Marshal(projectNames) + if err != nil { + return nil, utils.LogAndReturnError(logger, "failed to marshal project names", err) + } + + return mcp.NewToolResultText(string(projectJSON)), nil +} diff --git a/pkg/tools/module_details.go b/pkg/tools/module_details.go index d7c8236..778b259 100644 --- a/pkg/tools/module_details.go +++ b/pkg/tools/module_details.go @@ -20,24 +20,25 @@ import ( const MODULE_BASE_PATH = "registry://modules" -func ModuleDetails(registryClient *http.Client, logger *log.Logger) server.ServerTool { +func ModuleDetails(logger *log.Logger) server.ServerTool { return server.ServerTool{ Tool: mcp.NewTool("module_details", mcp.WithDescription(`Fetches up-to-date documentation on how to use a Terraform module. You must call 'search_modules' first to obtain the exact valid and compatible module_id required to use this tool.`), mcp.WithTitleAnnotation("Retrieve documentation for a specific Terraform module"), mcp.WithOpenWorldHintAnnotation(true), + mcp.WithReadOnlyHintAnnotation(true), mcp.WithString("module_id", mcp.Required(), mcp.Description("Exact valid and compatible module_id retrieved from search_modules (e.g., 'squareops/terraform-kubernetes-mongodb/mongodb/2.1.1', 'GoogleCloudPlatform/vertex-ai/google/0.2.0')"), ), ), Handler: func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { - return getModuleDetailsHandler(registryClient, request, logger) + return getModuleDetailsHandler(ctx, request, logger) }, } } -func getModuleDetailsHandler(registryClient *http.Client, request mcp.CallToolRequest, logger *log.Logger) (*mcp.CallToolResult, error) { +func getModuleDetailsHandler(ctx context.Context, request mcp.CallToolRequest, logger *log.Logger) (*mcp.CallToolResult, error) { moduleID, err := request.RequireString("module_id") if err != nil { return nil, utils.LogAndReturnError(logger, "module_id is required", err) @@ -46,8 +47,17 @@ func getModuleDetailsHandler(registryClient *http.Client, request mcp.CallToolRe return nil, utils.LogAndReturnError(logger, "module_id cannot be empty", nil) } + // Get a simple http client to access the public Terraform registry from context + terraformClients, err := client.GetTerraformClientFromContext(ctx, logger) + if err != nil { + logger.WithError(err).Error("failed to get http client for public Terraform registry") + return mcp.NewToolResultError(fmt.Sprintf("failed to get http client for public Terraform registry: %v", err)), nil + } + + httpClient := terraformClients.HttpClient + var errMsg string - response, err := getModuleDetails(registryClient, moduleID, 0, logger) + response, err := getModuleDetails(httpClient, moduleID, 0, logger) if err != nil { errMsg = fmt.Sprintf("no module(s) found for %v,", moduleID) return nil, utils.LogAndReturnError(logger, errMsg, nil) @@ -63,14 +73,14 @@ func getModuleDetailsHandler(registryClient *http.Client, request mcp.CallToolRe return mcp.NewToolResultText(moduleData), nil } -func getModuleDetails(providerClient *http.Client, moduleID string, currentOffset int, logger *log.Logger) ([]byte, error) { +func getModuleDetails(httpClient *http.Client, moduleID string, currentOffset int, logger *log.Logger) ([]byte, error) { uri := "modules" if moduleID != "" { uri = fmt.Sprintf("modules/%s", moduleID) } uri = fmt.Sprintf("%s?offset=%v", uri, currentOffset) - response, err := client.SendRegistryCall(providerClient, "GET", uri, logger) + response, err := client.SendRegistryCall(httpClient, "GET", uri, logger) if err != nil { // We shouldn't log the error here because we might hit a namespace that doesn't exist, it's better to let the caller handle it. return nil, fmt.Errorf("getting module(s) for: %v, please provide a different provider name like aws, azurerm or google etc", moduleID) diff --git a/pkg/tools/policy_details.go b/pkg/tools/policy_details.go index 07b85df..93eead2 100644 --- a/pkg/tools/policy_details.go +++ b/pkg/tools/policy_details.go @@ -7,7 +7,6 @@ import ( "context" "encoding/json" "fmt" - "net/http" "strings" "github.com/hashicorp/terraform-mcp-server/pkg/client" @@ -18,24 +17,25 @@ import ( "github.com/mark3labs/mcp-go/server" ) -func PolicyDetails(registryClient *http.Client, logger *log.Logger) server.ServerTool { +func PolicyDetails(logger *log.Logger) server.ServerTool { return server.ServerTool{ Tool: mcp.NewTool("policy_details", mcp.WithDescription(`Fetches up-to-date documentation for a specific policy from the Terraform registry. You must call 'search_policies' first to obtain the exact terraform_policy_id required to use this tool.`), mcp.WithTitleAnnotation("Fetch detailed Terraform policy documentation using a terraform_policy_id"), mcp.WithOpenWorldHintAnnotation(true), + mcp.WithReadOnlyHintAnnotation(true), mcp.WithString("terraform_policy_id", mcp.Required(), mcp.Description("Matching terraform_policy_id retrieved from the 'search_policies' tool (e.g., 'policies/hashicorp/CIS-Policy-Set-for-AWS-Terraform/1.0.1')"), ), ), Handler: func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { - return getPolicyDetailsHandler(registryClient, request, logger) + return getPolicyDetailsHandler(ctx, request, logger) }, } } -func getPolicyDetailsHandler(registryClient *http.Client, request mcp.CallToolRequest, logger *log.Logger) (*mcp.CallToolResult, error) { +func getPolicyDetailsHandler(ctx context.Context, request mcp.CallToolRequest, logger *log.Logger) (*mcp.CallToolResult, error) { terraformPolicyID, err := request.RequireString("terraform_policy_id") if err != nil { return nil, utils.LogAndReturnError(logger, "terraform_policy_id is required and must be a string, it is fetched by running the search_policies tool", err) @@ -44,7 +44,15 @@ func getPolicyDetailsHandler(registryClient *http.Client, request mcp.CallToolRe return nil, utils.LogAndReturnError(logger, "terraform_policy_id cannot be empty, it is fetched by running the search_policies tool", nil) } - policyResp, err := client.SendRegistryCall(registryClient, "GET", fmt.Sprintf("%s?include=policies,policy-modules,policy-library", terraformPolicyID), logger, "v2") + // Get a simple http client to access the public Terraform registry from context + terraformClients, err := client.GetTerraformClientFromContext(ctx, logger) + if err != nil { + logger.WithError(err).Error("failed to get http client for public Terraform registry") + return mcp.NewToolResultError(fmt.Sprintf("failed to get http client for public Terraform registry: %v", err)), nil + } + + httpClient := terraformClients.HttpClient + policyResp, err := client.SendRegistryCall(httpClient, "GET", fmt.Sprintf("%s?include=policies,policy-modules,policy-library", terraformPolicyID), logger, "v2") if err != nil { return nil, utils.LogAndReturnError(logger, "Failed to fetch policy details: registry API did not return a successful response", err) } @@ -63,7 +71,7 @@ func getPolicyDetailsHandler(registryClient *http.Client, request mcp.CallToolRe if policy.Type == "policy-modules" { moduleList += fmt.Sprintf(` module "%s" { -source = "https://registry.terraform.io/v2%s/policy-module/%s.sentinel?checksum=sha256:%s" + source = "https://registry.terraform.io/v2%s/policy-module/%s.sentinel?checksum=sha256:%s" } `, policy.Attributes.Name, terraformPolicyID, policy.Attributes.Name, policy.Attributes.Shasum) } @@ -80,8 +88,8 @@ source = "https://registry.terraform.io/v2%s/policy-module/%s.sentinel?checksum= hclTemplate := fmt.Sprintf(` %s policy "<>" { -source = "https://registry.terraform.io/v2%s/policy/<>.sentinel?checksum=<>" -enforcement_level = "advisory" + source = "https://registry.terraform.io/v2%s/policy/<>.sentinel?checksum=<>" + enforcement_level = "advisory" } `, moduleList, terraformPolicyID) builder.WriteString(hclTemplate) diff --git a/pkg/tools/resolve_provider_doc_id.go b/pkg/tools/resolve_provider_doc_id.go index b66aaaa..06a6901 100644 --- a/pkg/tools/resolve_provider_doc_id.go +++ b/pkg/tools/resolve_provider_doc_id.go @@ -19,7 +19,7 @@ import ( ) // ResolveProviderDocID creates a tool to get provider details from registry. -func ResolveProviderDocID(registryClient *http.Client, logger *log.Logger) server.ServerTool { +func ResolveProviderDocID(logger *log.Logger) server.ServerTool { return server.ServerTool{ Tool: mcp.NewTool("resolve_provider_doc_id", mcp.WithDescription(`This tool retrieves a list of potential documents based on the service_slug and provider_data_type provided. @@ -31,6 +31,7 @@ When selecting the best match, consider the following: Return the selected provider_doc_id and explain your choice. If there are multiple good matches, mention this but proceed with the most relevant one.`), mcp.WithTitleAnnotation("Identify the most relevant provider document ID for a Terraform service"), + mcp.WithOpenWorldHintAnnotation(true), mcp.WithReadOnlyHintAnnotation(true), mcp.WithString("provider_name", mcp.Required(), @@ -53,15 +54,24 @@ If there are multiple good matches, mention this but proceed with the most relev mcp.Description("The version of the Terraform provider to retrieve in the format 'x.y.z', or 'latest' to get the latest version")), ), Handler: func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { - return resolveProviderDocIDHandler(registryClient, request, logger) + return resolveProviderDocIDHandler(ctx, request, logger) }, } } -func resolveProviderDocIDHandler(registryClient *http.Client, request mcp.CallToolRequest, logger *log.Logger) (*mcp.CallToolResult, error) { +func resolveProviderDocIDHandler(ctx context.Context, request mcp.CallToolRequest, logger *log.Logger) (*mcp.CallToolResult, error) { // For typical provider and namespace hallucinations defaultErrorGuide := "please check the provider name, provider namespace or the provider version you're looking for, perhaps the provider is published under a different namespace or company name" - providerDetail, err := resolveProviderDetails(request, registryClient, defaultErrorGuide, logger) + + // Get a simple http client to access the public Terraform registry from context + terraformClients, err := client.GetTerraformClientFromContext(ctx, logger) + if err != nil { + logger.WithError(err).Error("failed to get http client for public Terraform registry") + return mcp.NewToolResultError(fmt.Sprintf("failed to get http client for public Terraform registry: %v", err)), nil + } + + httpClient := terraformClients.HttpClient + providerDetail, err := resolveProviderDetails(request, httpClient, defaultErrorGuide, logger) if err != nil { return nil, err } @@ -79,7 +89,7 @@ func resolveProviderDocIDHandler(registryClient *http.Client, request mcp.CallTo // Check if we need to use v2 API for guides, functions, or overview if utils.IsV2ProviderDataType(providerDetail.ProviderDataType) { - content, err := get_provider_docsV2(registryClient, providerDetail, logger) + content, err := get_provider_docsV2(httpClient, providerDetail, logger) if err != nil { errMessage := fmt.Sprintf(`No %s documentation found for provider '%s' in the '%s' namespace, %s`, providerDetail.ProviderDataType, providerDetail.ProviderName, providerDetail.ProviderNamespace, defaultErrorGuide) @@ -94,10 +104,9 @@ func resolveProviderDocIDHandler(registryClient *http.Client, request mcp.CallTo // For resources/data-sources, use the v1 API for better performance (single response) uri := fmt.Sprintf("providers/%s/%s/%s", providerDetail.ProviderNamespace, providerDetail.ProviderName, providerDetail.ProviderVersion) - response, err := client.SendRegistryCall(registryClient, "GET", uri, logger) + response, err := client.SendRegistryCall(httpClient, "GET", uri, logger) if err != nil { - return nil, utils.LogAndReturnError(logger, fmt.Sprintf(`Error getting the "%s" provider, - with version "%s" in the %s namespace, %s`, providerDetail.ProviderName, providerDetail.ProviderVersion, providerDetail.ProviderNamespace, defaultErrorGuide), nil) + return nil, utils.LogAndReturnError(logger, fmt.Sprintf(`Error getting the "%s" provider, with version "%s" in the %s namespace, %s`, providerDetail.ProviderName, providerDetail.ProviderVersion, providerDetail.ProviderNamespace, defaultErrorGuide), nil) } var providerDocs client.ProviderDocs @@ -167,8 +176,7 @@ func resolveProviderDetails(request mcp.CallToolRequest, registryClient *http.Cl if providerNamespace != tryProviderNamespace { tryProviderNamespace = fmt.Sprintf(`"%s" or the "%s"`, providerNamespace, tryProviderNamespace) } - return providerDetail, utils.LogAndReturnError(logger, fmt.Sprintf(`Error getting the "%s" provider, - with version "%s" in the %s namespace, %s`, providerName, providerVersion, tryProviderNamespace, defaultErrorGuide), nil) + return providerDetail, utils.LogAndReturnError(logger, fmt.Sprintf(`Error getting the "%s" provider, with version "%s" in the %s namespace, %s`, providerName, providerVersion, tryProviderNamespace, defaultErrorGuide), nil) } providerNamespace = tryProviderNamespace // Update the namespace to hashicorp, if successful } diff --git a/pkg/tools/search_modules.go b/pkg/tools/search_modules.go index 0912c19..b491cf1 100644 --- a/pkg/tools/search_modules.go +++ b/pkg/tools/search_modules.go @@ -20,7 +20,7 @@ import ( "github.com/mark3labs/mcp-go/server" ) -func SearchModules(registryClient *http.Client, logger *log.Logger) server.ServerTool { +func SearchModules(logger *log.Logger) server.ServerTool { return server.ServerTool{ Tool: mcp.NewTool("search_modules", mcp.WithDescription(`Resolves a Terraform module name to obtain a compatible module_id for the module_details tool and returns a list of matching Terraform modules. @@ -34,6 +34,7 @@ Return the selected module_id and explain your choice. If there are multiple goo If no modules were found, reattempt the search with a new moduleName query.`), mcp.WithTitleAnnotation("Search and match Terraform modules based on name and relevance"), mcp.WithOpenWorldHintAnnotation(true), + mcp.WithReadOnlyHintAnnotation(true), mcp.WithString("module_query", mcp.Required(), mcp.Description("The query to search for Terraform modules."), @@ -45,20 +46,29 @@ If no modules were found, reattempt the search with a new moduleName query.`), ), ), Handler: func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { - return getSearchModulesHandler(registryClient, request, logger) + return getSearchModulesHandler(ctx, request, logger) }, } } -func getSearchModulesHandler(registryClient *http.Client, request mcp.CallToolRequest, logger *log.Logger) (*mcp.CallToolResult, error) { +func getSearchModulesHandler(ctx context.Context, request mcp.CallToolRequest, logger *log.Logger) (*mcp.CallToolResult, error) { moduleQuery, err := request.RequireString("module_query") if err != nil { return nil, utils.LogAndReturnError(logger, "module_query is required", err) } currentOffsetValue := request.GetInt("current_offset", 0) + // Get a simple http client to access the public Terraform registry from context + terraformClients, err := client.GetTerraformClientFromContext(ctx, logger) + if err != nil { + logger.WithError(err).Error("failed to get http client for public Terraform registry") + return mcp.NewToolResultError(fmt.Sprintf("failed to get http client for public Terraform registry: %v", err)), nil + } + + httpClient := terraformClients.HttpClient + var modulesData, errMsg string - response, err := sendSearchModulesCall(registryClient, moduleQuery, currentOffsetValue, logger) + response, err := sendSearchModulesCall(httpClient, moduleQuery, currentOffsetValue, logger) if err != nil { return nil, utils.LogAndReturnError(logger, fmt.Sprintf("no module(s) found for moduleName: %s", moduleQuery), err) } else { diff --git a/pkg/tools/search_policies.go b/pkg/tools/search_policies.go index 0c45624..1b6a666 100644 --- a/pkg/tools/search_policies.go +++ b/pkg/tools/search_policies.go @@ -7,7 +7,6 @@ import ( "context" "encoding/json" "fmt" - "net/http" "strings" "github.com/hashicorp/terraform-mcp-server/pkg/client" @@ -18,7 +17,7 @@ import ( "github.com/mark3labs/mcp-go/server" ) -func SearchPolicies(registryClient *http.Client, logger *log.Logger) server.ServerTool { +func SearchPolicies(logger *log.Logger) server.ServerTool { return server.ServerTool{ Tool: mcp.NewTool("search_policies", mcp.WithDescription(`Searches for Terraform policies based on a query string. @@ -33,18 +32,19 @@ Return the selected policyID and explain your choice. If there are multiple good If no policies were found, reattempt the search with a new policy_query.`), mcp.WithTitleAnnotation("Search and match Terraform policies based on name and relevance"), mcp.WithOpenWorldHintAnnotation(true), + mcp.WithReadOnlyHintAnnotation(true), mcp.WithString("policy_query", mcp.Required(), mcp.Description("The query to search for Terraform modules."), ), ), Handler: func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { - return getSearchPoliciesHandler(registryClient, request, logger) + return getSearchPoliciesHandler(ctx, request, logger) }, } } -func getSearchPoliciesHandler(registryClient *http.Client, request mcp.CallToolRequest, logger *log.Logger) (*mcp.CallToolResult, error) { +func getSearchPoliciesHandler(ctx context.Context, request mcp.CallToolRequest, logger *log.Logger) (*mcp.CallToolResult, error) { var terraformPolicies client.TerraformPolicyList pq, err := request.RequireString("policy_query") if err != nil { @@ -54,8 +54,16 @@ func getSearchPoliciesHandler(registryClient *http.Client, request mcp.CallToolR return nil, utils.LogAndReturnError(logger, "policy_query cannot be empty", nil) } + // Get a simple http client to access the public Terraform registry from context + terraformClients, err := client.GetTerraformClientFromContext(ctx, logger) + if err != nil { + logger.WithError(err).Error("failed to get http client for public Terraform registry") + return mcp.NewToolResultError(fmt.Sprintf("failed to get http client for public Terraform registry: %v", err)), nil + } + + httpClient := terraformClients.HttpClient // static list of 100 is fine for now - policyResp, err := client.SendRegistryCall(registryClient, "GET", "policies?page%5Bsize%5D=100&include=latest-version", logger, "v2") + policyResp, err := client.SendRegistryCall(httpClient, "GET", "policies?page%5Bsize%5D=100&include=latest-version", logger, "v2") if err != nil { return nil, utils.LogAndReturnError(logger, "Failed to fetch policies: registry API did not return a successful response", err) } diff --git a/pkg/tools/tools.go b/pkg/tools/tools.go index 653a5a4..d465ea3 100644 --- a/pkg/tools/tools.go +++ b/pkg/tools/tools.go @@ -4,32 +4,32 @@ package tools import ( - "net/http" - "github.com/mark3labs/mcp-go/server" log "github.com/sirupsen/logrus" ) -func InitTools(hcServer *server.MCPServer, registryClient *http.Client, logger *log.Logger) { +func RegisterTools(hcServer *server.MCPServer, logger *log.Logger) { + // Register the dynamic tool + registerDynamicTools(hcServer, logger) - // Provider tools - getResolveProviderDocIDTool := ResolveProviderDocID(registryClient, logger) + // Provider tools (always available) + getResolveProviderDocIDTool := ResolveProviderDocID(logger) hcServer.AddTool(getResolveProviderDocIDTool.Tool, getResolveProviderDocIDTool.Handler) - getProviderDocsTool := GetProviderDocs(registryClient, logger) + getProviderDocsTool := GetProviderDocs(logger) hcServer.AddTool(getProviderDocsTool.Tool, getProviderDocsTool.Handler) - // Module tools - getSearchModulesTool := SearchModules(registryClient, logger) + // Module tools (always available) + getSearchModulesTool := SearchModules(logger) hcServer.AddTool(getSearchModulesTool.Tool, getSearchModulesTool.Handler) - getModuleDetailsTool := ModuleDetails(registryClient, logger) + getModuleDetailsTool := ModuleDetails(logger) hcServer.AddTool(getModuleDetailsTool.Tool, getModuleDetailsTool.Handler) - // Policy tools - getSearchPoliciesTool := SearchPolicies(registryClient, logger) + // Policy tools (always available) + getSearchPoliciesTool := SearchPolicies(logger) hcServer.AddTool(getSearchPoliciesTool.Tool, getSearchPoliciesTool.Handler) - getPolicyDetailsTool := PolicyDetails(registryClient, logger) + getPolicyDetailsTool := PolicyDetails(logger) hcServer.AddTool(getPolicyDetailsTool.Tool, getPolicyDetailsTool.Handler) } From 4c8a7757f4547dea7746db7e4edfa8a2dff2c4e2 Mon Sep 17 00:00:00 2001 From: Gautam Date: Thu, 31 Jul 2025 13:57:21 -0700 Subject: [PATCH 02/19] Update pkg/client/middleware.go Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> --- pkg/client/middleware.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pkg/client/middleware.go b/pkg/client/middleware.go index baa8d94..27505da 100644 --- a/pkg/client/middleware.go +++ b/pkg/client/middleware.go @@ -147,7 +147,7 @@ func TerraformContextMiddleware(logger *log.Logger) func(http.Handler) http.Hand // Explicitly disallow TerraformToken in query parameters for security reasons if header == TerraformToken && headerValue != "" { - logger.Info(fmt.Sprintf("Terraform token was provided in query parameters by client %v, termiating request", r.RemoteAddr)) + logger.Info(fmt.Sprintf("Terraform token was provided in query parameters by client %v, terminating request", r.RemoteAddr)) http.Error(w, "Terraform token should not be provided in query parameters for security reasons, use the terraform_token header", http.StatusBadRequest) return } From e82de50cff05a893df3b9a8b5f198c6ad623e7df Mon Sep 17 00:00:00 2001 From: Gautam Date: Thu, 31 Jul 2025 14:01:27 -0700 Subject: [PATCH 03/19] removing unnecessary comment --- pkg/client/middleware.go | 6 ------ 1 file changed, 6 deletions(-) diff --git a/pkg/client/middleware.go b/pkg/client/middleware.go index 27505da..ba86c53 100644 --- a/pkg/client/middleware.go +++ b/pkg/client/middleware.go @@ -132,12 +132,6 @@ func TerraformContextMiddleware(logger *log.Logger) func(http.Handler) http.Hand return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { requiredHeaders := []string{TerraformAddress, TerraformToken, TerraformSkipTLSVerify} ctx := r.Context() - /* - if !r.URL.Query().Has("Authorization") || r.Header.Get("Authorization") == "" { - http.Error(w, "Unauthorized: Please provide valid credentials", http.StatusUnauthorized) - return - } - */ for _, header := range requiredHeaders { // Priority order: HTTP header -> Query parameter -> Environment variable headerValue := r.Header.Get(textproto.CanonicalMIMEHeaderKey(header)) From 39f57edf537328c3dc0146cec90700f50edbc986 Mon Sep 17 00:00:00 2001 From: Gautam Date: Thu, 31 Jul 2025 14:02:10 -0700 Subject: [PATCH 04/19] Update cmd/terraform-mcp-server/init.go Co-authored-by: Deniz Onur Duzgun <59659739+dduzgun-security@users.noreply.github.com> --- cmd/terraform-mcp-server/init.go | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/cmd/terraform-mcp-server/init.go b/cmd/terraform-mcp-server/init.go index 4da9a06..9f63f2c 100644 --- a/cmd/terraform-mcp-server/init.go +++ b/cmd/terraform-mcp-server/init.go @@ -179,9 +179,7 @@ func streamableHTTPServerInit(ctx context.Context, hcServer *server.MCPServer, l isStateless := shouldUseStatelessMode() // Ensure endpoint path starts with / - if !strings.HasPrefix(endpointPath, "/") { - endpointPath = "/" + endpointPath - } + endpointPath = path.Join("/", endpointPath) // Create StreamableHTTP server which implements the new streamable-http transport // This is the modern MCP transport that supports both direct HTTP responses and SSE streams opts := []server.StreamableHTTPOption{ From 9586b7e21044804ef55061dbcfc4d8878d287fe7 Mon Sep 17 00:00:00 2001 From: Gautam Date: Thu, 31 Jul 2025 14:02:58 -0700 Subject: [PATCH 05/19] Update pkg/client/client.go Co-authored-by: Deniz Onur Duzgun <59659739+dduzgun-security@users.noreply.github.com> --- pkg/client/client.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pkg/client/client.go b/pkg/client/client.go index e6767ce..5640341 100644 --- a/pkg/client/client.go +++ b/pkg/client/client.go @@ -61,7 +61,7 @@ func NewTerraformClient(sessionId string, terraformAddress string, terraformSkip client, err := tfe.NewClient(config) if err != nil { - logger.Warnf("Failed to create a Terraform Cloud/Enterprise client: %s, %v", sessionId, err) + logger.Warnf("Failed to create a Terraform Cloud/Enterprise client: %v", err) return terraformClients } From 2ac2692b82e196e2666b5f51ead9ba6593e498a6 Mon Sep 17 00:00:00 2001 From: Gautam Date: Thu, 31 Jul 2025 14:03:20 -0700 Subject: [PATCH 06/19] Update pkg/client/client.go Co-authored-by: Deniz Onur Duzgun <59659739+dduzgun-security@users.noreply.github.com> --- pkg/client/client.go | 2 -- 1 file changed, 2 deletions(-) diff --git a/pkg/client/client.go b/pkg/client/client.go index 5640341..ff417d3 100644 --- a/pkg/client/client.go +++ b/pkg/client/client.go @@ -90,8 +90,6 @@ func GetTerraformClientFromContext(ctx context.Context, logger *log.Logger) (*te return nil, fmt.Errorf("no active session") } - // Log the session ID for debugging - logger.WithField("session_id", session.SessionID()).Debug("Retrieving Terraform client for session") // Try to get existing client client := GetTerraformClient(session.SessionID()) From 2d95e85bac072544600a9d756aa6ac2478b11a27 Mon Sep 17 00:00:00 2001 From: Gautam Date: Thu, 31 Jul 2025 14:03:58 -0700 Subject: [PATCH 07/19] Update pkg/client/client.go Co-authored-by: Deniz Onur Duzgun <59659739+dduzgun-security@users.noreply.github.com> --- pkg/client/client.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pkg/client/client.go b/pkg/client/client.go index ff417d3..3f01993 100644 --- a/pkg/client/client.go +++ b/pkg/client/client.go @@ -97,7 +97,7 @@ func GetTerraformClientFromContext(ctx context.Context, logger *log.Logger) (*te return client, nil } - logger.WithField("session_id", session.SessionID()).Warn("Terraform client not found, creating a new one") + logger.Warnf("Terraform client not found, creating a new one") return CreateTerraformClientForSession(ctx, session, logger) } From ce4ecc07f084cd213e40308c4acca0bead5fd236 Mon Sep 17 00:00:00 2001 From: Gautam Date: Thu, 31 Jul 2025 14:04:29 -0700 Subject: [PATCH 08/19] Update pkg/tools/search_policies.go Co-authored-by: Deniz Onur Duzgun <59659739+dduzgun-security@users.noreply.github.com> --- pkg/tools/search_policies.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pkg/tools/search_policies.go b/pkg/tools/search_policies.go index 1b6a666..cac9bab 100644 --- a/pkg/tools/search_policies.go +++ b/pkg/tools/search_policies.go @@ -63,7 +63,7 @@ func getSearchPoliciesHandler(ctx context.Context, request mcp.CallToolRequest, httpClient := terraformClients.HttpClient // static list of 100 is fine for now - policyResp, err := client.SendRegistryCall(httpClient, "GET", "policies?page%5Bsize%5D=100&include=latest-version", logger, "v2") + policyResp, err := client.SendRegistryCall(httpClient, "GET", (&url.URL{Path: "policies", RawQuery: url.Values{"page[size]": {"100"}, "include": {"latest-version"}}.Encode()}).String(), logger, "v2") if err != nil { return nil, utils.LogAndReturnError(logger, "Failed to fetch policies: registry API did not return a successful response", err) } From f6df99c98827b72c77f11a11ca3c61c01aeef014 Mon Sep 17 00:00:00 2001 From: Gautam Date: Thu, 31 Jul 2025 14:04:41 -0700 Subject: [PATCH 09/19] Update pkg/tools/resolve_provider_doc_id.go Co-authored-by: Deniz Onur Duzgun <59659739+dduzgun-security@users.noreply.github.com> --- pkg/tools/resolve_provider_doc_id.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pkg/tools/resolve_provider_doc_id.go b/pkg/tools/resolve_provider_doc_id.go index 06a6901..e8e00b0 100644 --- a/pkg/tools/resolve_provider_doc_id.go +++ b/pkg/tools/resolve_provider_doc_id.go @@ -103,7 +103,7 @@ func resolveProviderDocIDHandler(ctx context.Context, request mcp.CallToolReques } // For resources/data-sources, use the v1 API for better performance (single response) - uri := fmt.Sprintf("providers/%s/%s/%s", providerDetail.ProviderNamespace, providerDetail.ProviderName, providerDetail.ProviderVersion) + uri := path.Join("providers", providerDetail.ProviderNamespace, providerDetail.ProviderName, providerDetail.ProviderVersion) response, err := client.SendRegistryCall(httpClient, "GET", uri, logger) if err != nil { return nil, utils.LogAndReturnError(logger, fmt.Sprintf(`Error getting the "%s" provider, with version "%s" in the %s namespace, %s`, providerDetail.ProviderName, providerDetail.ProviderVersion, providerDetail.ProviderNamespace, defaultErrorGuide), nil) From ee6d810632792196ddc45a7b1a0d4663f62e8ca6 Mon Sep 17 00:00:00 2001 From: Gautam Date: Thu, 31 Jul 2025 14:04:55 -0700 Subject: [PATCH 10/19] Update pkg/tools/policy_details.go Co-authored-by: Deniz Onur Duzgun <59659739+dduzgun-security@users.noreply.github.com> --- pkg/tools/policy_details.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pkg/tools/policy_details.go b/pkg/tools/policy_details.go index 93eead2..d716d5f 100644 --- a/pkg/tools/policy_details.go +++ b/pkg/tools/policy_details.go @@ -52,7 +52,7 @@ func getPolicyDetailsHandler(ctx context.Context, request mcp.CallToolRequest, l } httpClient := terraformClients.HttpClient - policyResp, err := client.SendRegistryCall(httpClient, "GET", fmt.Sprintf("%s?include=policies,policy-modules,policy-library", terraformPolicyID), logger, "v2") + policyResp, err := client.SendRegistryCall(httpClient, "GET", (&url.URL{Path: terraformPolicyID, RawQuery: url.Values{"include": {"policies,policy-modules,policy-library"}}.Encode()}).String(), logger, "v2") if err != nil { return nil, utils.LogAndReturnError(logger, "Failed to fetch policy details: registry API did not return a successful response", err) } From 144673d6b9c91782102a728304e52c46589de39d Mon Sep 17 00:00:00 2001 From: Gautam Date: Thu, 31 Jul 2025 14:05:36 -0700 Subject: [PATCH 11/19] Update pkg/tools/get_provider_docs.go Co-authored-by: Deniz Onur Duzgun <59659739+dduzgun-security@users.noreply.github.com> --- pkg/tools/get_provider_docs.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pkg/tools/get_provider_docs.go b/pkg/tools/get_provider_docs.go index c4df71e..c91ab99 100644 --- a/pkg/tools/get_provider_docs.go +++ b/pkg/tools/get_provider_docs.go @@ -52,7 +52,7 @@ func getProviderDocsHandler(ctx context.Context, request mcp.CallToolRequest, lo httpClient := terraformClients.HttpClient - detailResp, err := client.SendRegistryCall(httpClient, "GET", fmt.Sprintf("provider-docs/%s", providerDocID), logger, "v2") + detailResp, err := client.SendRegistryCall(httpClient, "GET", path.Join("provider-docs", providerDocID), logger, "v2") if err != nil { return nil, utils.LogAndReturnError(logger, fmt.Sprintf("Error fetching provider-docs/%s, please make sure provider_doc_id is valid and the resolve_provider_doc_id tool has run prior", providerDocID), err) } From b25173fcbf84f51fc0fc7e5bcdc1be0bfc8562df Mon Sep 17 00:00:00 2001 From: Gautam Date: Thu, 31 Jul 2025 14:05:53 -0700 Subject: [PATCH 12/19] Update pkg/client/client.go Co-authored-by: Deniz Onur Duzgun <59659739+dduzgun-security@users.noreply.github.com> --- pkg/client/client.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pkg/client/client.go b/pkg/client/client.go index 3f01993..1a09463 100644 --- a/pkg/client/client.go +++ b/pkg/client/client.go @@ -142,7 +142,7 @@ func NewSessionHandler(ctx context.Context, session server.ClientSession, logger if registryCallback := getToolRegistryCallback(); registryCallback != nil { registryCallback.RegisterSessionWithTFE(session.SessionID()) } - logger.WithField("session_id", session.SessionID()).Info("Session has valid TFE client - registered with tool registry") + logger.Info("Session has valid TFE client - registered with tool registry") } else { logger.WithField("session_id", session.SessionID()).Info("Session has no valid TFE client - TFE tools will not be available") } From a6b3dbbacadc7525ae4f26c9c6195e02cbe4b09d Mon Sep 17 00:00:00 2001 From: Gautam Date: Thu, 31 Jul 2025 14:06:08 -0700 Subject: [PATCH 13/19] Update pkg/client/client.go Co-authored-by: Deniz Onur Duzgun <59659739+dduzgun-security@users.noreply.github.com> --- pkg/client/client.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pkg/client/client.go b/pkg/client/client.go index 1a09463..3b4ed8c 100644 --- a/pkg/client/client.go +++ b/pkg/client/client.go @@ -144,7 +144,7 @@ func NewSessionHandler(ctx context.Context, session server.ClientSession, logger } logger.Info("Session has valid TFE client - registered with tool registry") } else { - logger.WithField("session_id", session.SessionID()).Info("Session has no valid TFE client - TFE tools will not be available") + logger.Info("Session has no valid TFE client - TFE tools will not be available") } } From d3073d2d27aaa2135400abb348fe2acf6473119a Mon Sep 17 00:00:00 2001 From: Gautam Date: Thu, 31 Jul 2025 14:06:27 -0700 Subject: [PATCH 14/19] Update pkg/resources/resource_templates.go Co-authored-by: Deniz Onur Duzgun <59659739+dduzgun-security@users.noreply.github.com> --- pkg/resources/resource_templates.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pkg/resources/resource_templates.go b/pkg/resources/resource_templates.go index 663068f..efc3566 100644 --- a/pkg/resources/resource_templates.go +++ b/pkg/resources/resource_templates.go @@ -16,7 +16,7 @@ import ( ) func RegisterResourceTemplates(hcServer *server.MCPServer, logger *log.Logger) { - hcServer.AddResourceTemplate(ProviderResourceTemplate(fmt.Sprintf("%s/{namespace}/name/{name}/version/{version}", utils.PROVIDER_BASE_PATH), "Provider details", logger)) + hcServer.AddResourceTemplate(ProviderResourceTemplate(path.Join(utils.PROVIDER_BASE_PATH, "{namespace}", "name", "{name}", "version", "{version}"), utils.PROVIDER_BASE_PATH), "Provider details", logger)) } func ProviderResourceTemplate(resourceURI string, description string, logger *log.Logger) (mcp.ResourceTemplate, server.ResourceTemplateHandlerFunc) { From cd5150d8310438e5b8e8e4a1b81d3174e4a23400 Mon Sep 17 00:00:00 2001 From: Gautam Date: Thu, 31 Jul 2025 14:06:51 -0700 Subject: [PATCH 15/19] Update pkg/tools/dynamic_tool.go Co-authored-by: Deniz Onur Duzgun <59659739+dduzgun-security@users.noreply.github.com> --- pkg/tools/dynamic_tool.go | 1 - 1 file changed, 1 deletion(-) diff --git a/pkg/tools/dynamic_tool.go b/pkg/tools/dynamic_tool.go index 8a9886a..8827161 100644 --- a/pkg/tools/dynamic_tool.go +++ b/pkg/tools/dynamic_tool.go @@ -124,7 +124,6 @@ func (r *DynamicToolRegistry) createDynamicTFETool(toolName string, toolFactory if terraformClient == nil || terraformClient.TfeClient == nil { r.logger.WithFields(log.Fields{ "tool": toolName, - "session_id": sessionID, }).Warn("TFE tool called but session has no valid TFE client") return mcp.NewToolResultError("This tool is not available. This tool requires a valid Terraform Cloud/Enterprise token and configuration. Please ensure TFE_TOKEN and TFE_ADDRESS environment variables are properly set."), nil From b375622a1f225941fff5c225ec6f69542073e6ae Mon Sep 17 00:00:00 2001 From: Gautam Date: Thu, 31 Jul 2025 14:07:08 -0700 Subject: [PATCH 16/19] Update pkg/tools/dynamic_tool.go Co-authored-by: Deniz Onur Duzgun <59659739+dduzgun-security@users.noreply.github.com> --- pkg/tools/dynamic_tool.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pkg/tools/dynamic_tool.go b/pkg/tools/dynamic_tool.go index 8827161..b970c92 100644 --- a/pkg/tools/dynamic_tool.go +++ b/pkg/tools/dynamic_tool.go @@ -48,7 +48,7 @@ func (r *DynamicToolRegistry) RegisterSessionWithTFE(sessionID string) { defer r.mu.Unlock() r.sessionsWithTFE[sessionID] = true - r.logger.WithField("session_id", sessionID).Info("Session registered with TFE client") + r.logger.Info("Session registered with TFE client") // If this is the first session with TFE, register the tools if !r.tfeToolsRegistered { From 36eb20df455eea7d91a013cf9669844e25cddcd4 Mon Sep 17 00:00:00 2001 From: Gautam Date: Thu, 31 Jul 2025 14:07:25 -0700 Subject: [PATCH 17/19] Update pkg/tools/dynamic_tool.go Co-authored-by: Deniz Onur Duzgun <59659739+dduzgun-security@users.noreply.github.com> --- pkg/tools/dynamic_tool.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pkg/tools/dynamic_tool.go b/pkg/tools/dynamic_tool.go index b970c92..6e3e060 100644 --- a/pkg/tools/dynamic_tool.go +++ b/pkg/tools/dynamic_tool.go @@ -62,7 +62,7 @@ func (r *DynamicToolRegistry) UnregisterSessionWithTFE(sessionID string) { defer r.mu.Unlock() delete(r.sessionsWithTFE, sessionID) - r.logger.WithField("session_id", sessionID).Info("Session unregistered from TFE client") + r.logger.Info("Session unregistered from TFE client") // If no sessions have TFE clients, we could unregister tools // but since MCP doesn't support tool removal, we keep them registered From d3e1040cf769b1d11be9532e2d3a40fbcc3eefb1 Mon Sep 17 00:00:00 2001 From: Gautam Date: Thu, 31 Jul 2025 17:39:50 -0700 Subject: [PATCH 18/19] fixing security comments and copilot issues --- cmd/terraform-mcp-server/init.go | 1 + e2e/cors_e2e_test.go | 14 +- pkg/resources/resource_templates.go | 11 +- pkg/tools/dynamic_tool.go | 3 +- pkg/tools/get_provider_docs.go | 1 + pkg/tools/list_terraform_orgs.go | 13 +- pkg/tools/list_terraform_projects.go | 9 +- pkg/tools/policy_details.go | 54 ++- pkg/tools/resolve_provider_doc_id.go | 23 +- pkg/tools/search_policies.go | 11 +- pkg/utils/pagination.go | 96 +++++ pkg/utils/pagination_test.go | 560 +++++++++++++++++++++++++++ 12 files changed, 759 insertions(+), 37 deletions(-) create mode 100644 pkg/utils/pagination.go create mode 100644 pkg/utils/pagination_test.go diff --git a/cmd/terraform-mcp-server/init.go b/cmd/terraform-mcp-server/init.go index 9f63f2c..317d6df 100644 --- a/cmd/terraform-mcp-server/init.go +++ b/cmd/terraform-mcp-server/init.go @@ -10,6 +10,7 @@ import ( stdlog "log" "net/http" "os" + "path" "strings" "time" diff --git a/e2e/cors_e2e_test.go b/e2e/cors_e2e_test.go index dc06bce..771b37a 100644 --- a/e2e/cors_e2e_test.go +++ b/e2e/cors_e2e_test.go @@ -150,8 +150,8 @@ func runCORSTests(t *testing.T, mcpURL, mode, configuredOrigins string) { // Define base test cases that apply to all modes baseTestCases := []testCase{ - {"GET with allowed origin", "GET", "https://example.com", 202, true}, - {"GET with no origin", "GET", "", 202, false}, + {"GET with allowed origin", "GET", "https://example.com", 200, true}, + {"GET with no origin", "GET", "", 200, false}, {"OPTIONS preflight with allowed origin", "OPTIONS", "https://example.com", 200, true}, } @@ -163,15 +163,15 @@ func runCORSTests(t *testing.T, mcpURL, mode, configuredOrigins string) { } developmentModeTests := []testCase{ - {"GET with localhost origin", "GET", "http://localhost:3000", 202, true}, - {"GET with IPv4 localhost", "GET", "http://127.0.0.1:3000", 202, true}, - {"GET with IPv6 localhost", "GET", "http://[::1]:3000", 202, true}, + {"GET with localhost origin", "GET", "http://localhost:3000", 200, true}, + {"GET with IPv4 localhost", "GET", "http://127.0.0.1:3000", 200, true}, + {"GET with IPv6 localhost", "GET", "http://[::1]:3000", 200, true}, {"GET with disallowed origin", "GET", "https://evil.com", 403, false}, {"OPTIONS with localhost origin", "OPTIONS", "http://localhost:3000", 200, true}, } disabledModeTests := []testCase{ - {"GET with any origin", "GET", "https://any-site.com", 202, true}, + {"GET with any origin", "GET", "https://any-site.com", 200, true}, {"OPTIONS with any origin", "OPTIONS", "https://any-site.com", 200, true}, } @@ -195,7 +195,7 @@ func runCORSTests(t *testing.T, mcpURL, mode, configuredOrigins string) { var sessionID string if tc.method != "OPTIONS" { // Only try to initialize if we expect it to succeed - if tc.expectedStatus == 202 { + if tc.expectedStatus == 200 { sessionID = initializeMCPSession(t, mcpURL, tc.origin) require.NotEmpty(t, sessionID, "Expected to get a session ID for allowed origin") } else { diff --git a/pkg/resources/resource_templates.go b/pkg/resources/resource_templates.go index efc3566..0ca9b63 100644 --- a/pkg/resources/resource_templates.go +++ b/pkg/resources/resource_templates.go @@ -7,6 +7,7 @@ import ( "context" "fmt" "net/http" + "path" "github.com/hashicorp/terraform-mcp-server/pkg/client" "github.com/hashicorp/terraform-mcp-server/pkg/utils" @@ -16,10 +17,16 @@ import ( ) func RegisterResourceTemplates(hcServer *server.MCPServer, logger *log.Logger) { - hcServer.AddResourceTemplate(ProviderResourceTemplate(path.Join(utils.PROVIDER_BASE_PATH, "{namespace}", "name", "{name}", "version", "{version}"), utils.PROVIDER_BASE_PATH), "Provider details", logger)) + hcServer.AddResourceTemplate( + providerResourceTemplate( + path.Join(utils.PROVIDER_BASE_PATH, "{namespace}", "name", "{name}", "version", "{version}"), + "Provider details", + logger, + ), + ) } -func ProviderResourceTemplate(resourceURI string, description string, logger *log.Logger) (mcp.ResourceTemplate, server.ResourceTemplateHandlerFunc) { +func providerResourceTemplate(resourceURI string, description string, logger *log.Logger) (mcp.ResourceTemplate, server.ResourceTemplateHandlerFunc) { return mcp.NewResourceTemplate( resourceURI, description, diff --git a/pkg/tools/dynamic_tool.go b/pkg/tools/dynamic_tool.go index 6e3e060..5fb8bbb 100644 --- a/pkg/tools/dynamic_tool.go +++ b/pkg/tools/dynamic_tool.go @@ -123,12 +123,11 @@ func (r *DynamicToolRegistry) createDynamicTFETool(toolName string, toolFactory terraformClient := client.GetTerraformClient(sessionID) if terraformClient == nil || terraformClient.TfeClient == nil { r.logger.WithFields(log.Fields{ - "tool": toolName, + "tool": toolName, }).Warn("TFE tool called but session has no valid TFE client") return mcp.NewToolResultError("This tool is not available. This tool requires a valid Terraform Cloud/Enterprise token and configuration. Please ensure TFE_TOKEN and TFE_ADDRESS environment variables are properly set."), nil } - // If we found a valid client that wasn't registered, register it now r.RegisterSessionWithTFE(sessionID) } diff --git a/pkg/tools/get_provider_docs.go b/pkg/tools/get_provider_docs.go index c91ab99..a751709 100644 --- a/pkg/tools/get_provider_docs.go +++ b/pkg/tools/get_provider_docs.go @@ -7,6 +7,7 @@ import ( "context" "encoding/json" "fmt" + "path" "github.com/hashicorp/terraform-mcp-server/pkg/client" "github.com/hashicorp/terraform-mcp-server/pkg/utils" diff --git a/pkg/tools/list_terraform_orgs.go b/pkg/tools/list_terraform_orgs.go index 416a728..aad2562 100644 --- a/pkg/tools/list_terraform_orgs.go +++ b/pkg/tools/list_terraform_orgs.go @@ -22,14 +22,15 @@ func ListTerraformOrgs(logger *log.Logger) server.ServerTool { mcp.WithDescription(`Fetches a list of all Terraform organizations.`), mcp.WithTitleAnnotation("List all Terraform organizations"), mcp.WithReadOnlyHintAnnotation(true), + utils.WithPagination(), ), Handler: func(ctx context.Context, req mcp.CallToolRequest) (*mcp.CallToolResult, error) { - return listTerraformOrgsHandler(ctx, logger) + return listTerraformOrgsHandler(ctx, req, logger) }, } } -func listTerraformOrgsHandler(ctx context.Context, logger *log.Logger) (*mcp.CallToolResult, error) { +func listTerraformOrgsHandler(ctx context.Context, request mcp.CallToolRequest, logger *log.Logger) (*mcp.CallToolResult, error) { // Get a Terraform client from context terraformClients, err := client.GetTerraformClientFromContext(ctx, logger) if err != nil { @@ -41,9 +42,15 @@ func listTerraformOrgsHandler(ctx context.Context, logger *log.Logger) (*mcp.Cal return nil, utils.LogAndReturnError(logger, "TFE client is not available - please ensure TFE_TOKEN and TFE_ADDRESS are properly configured", nil) } + pagination, err := utils.OptionalPaginationParams(request) + if err != nil { + return mcp.NewToolResultError(err.Error()), nil + } + orgs, err := tfeClient.Organizations.List(ctx, &tfe.OrganizationListOptions{ ListOptions: tfe.ListOptions{ - PageSize: 100, + PageNumber: pagination.Page, + PageSize: pagination.PageSize, }, }) diff --git a/pkg/tools/list_terraform_projects.go b/pkg/tools/list_terraform_projects.go index b78ef66..7f86bc3 100644 --- a/pkg/tools/list_terraform_projects.go +++ b/pkg/tools/list_terraform_projects.go @@ -26,6 +26,7 @@ func ListTerraformProjects(logger *log.Logger) server.ServerTool { mcp.Required(), mcp.Description("The name of the Terraform organization to list projects for."), ), + utils.WithPagination(), ), Handler: func(ctx context.Context, req mcp.CallToolRequest) (*mcp.CallToolResult, error) { return listTerraformProjectsHandler(ctx, req, logger) @@ -42,6 +43,11 @@ func listTerraformProjectsHandler(ctx context.Context, request mcp.CallToolReque return nil, utils.LogAndReturnError(logger, "terraform_org_name cannot be empty", nil) } + pagination, err := utils.OptionalPaginationParams(request) + if err != nil { + return mcp.NewToolResultError(err.Error()), nil + } + // Get a Terraform client from context terraformClients, err := client.GetTerraformClientFromContext(ctx, logger) if err != nil { @@ -55,7 +61,8 @@ func listTerraformProjectsHandler(ctx context.Context, request mcp.CallToolReque // Fetch the list of projects projects, err := tfeClient.Projects.List(ctx, terraformOrgName, &tfe.ProjectListOptions{ ListOptions: tfe.ListOptions{ - PageSize: 100, + PageNumber: pagination.Page, + PageSize: pagination.PageSize, }, }) diff --git a/pkg/tools/policy_details.go b/pkg/tools/policy_details.go index d716d5f..8dd40ec 100644 --- a/pkg/tools/policy_details.go +++ b/pkg/tools/policy_details.go @@ -7,7 +7,9 @@ import ( "context" "encoding/json" "fmt" + "net/url" "strings" + "text/template" "github.com/hashicorp/terraform-mcp-server/pkg/client" "github.com/hashicorp/terraform-mcp-server/pkg/utils" @@ -69,11 +71,28 @@ func getPolicyDetailsHandler(ctx context.Context, request mcp.CallToolRequest, l moduleList := "" for _, policy := range policyDetails.Included { if policy.Type == "policy-modules" { - moduleList += fmt.Sprintf(` -module "%s" { - source = "https://registry.terraform.io/v2%s/policy-module/%s.sentinel?checksum=sha256:%s" + // Use text/template to safely build the module block + var moduleBuilder strings.Builder + tmpl := ` +module "{{.Name}}" { + source = "https://registry.terraform.io/v2{{.PolicyID}}/policy-module/{{.Name}}.sentinel?checksum=sha256:{{.Shasum}}" } -`, policy.Attributes.Name, terraformPolicyID, policy.Attributes.Name, policy.Attributes.Shasum) +` + type moduleData struct { + Name string + PolicyID string + Shasum string + } + t := template.Must(template.New("module").Parse(tmpl)) + err := t.Execute(&moduleBuilder, moduleData{ + Name: policy.Attributes.Name, + PolicyID: terraformPolicyID, + Shasum: policy.Attributes.Shasum, + }) + if err != nil { + logger.WithError(err).Error("failed to render module template") + } + moduleList += moduleBuilder.String() } if policy.Type == "policies" { @@ -85,13 +104,30 @@ module "%s" { builder.WriteString("## Usage\n\n") builder.WriteString("Generate the content for a HashiCorp Configuration Language (HCL) file named policies.hcl. This file should define a set of policies. For each policy provided, create a distinct policy block using the following template.\n") builder.WriteString("\n```hcl\n") - hclTemplate := fmt.Sprintf(` -%s + // Use text/template to safely build the HCL template for policies + hclTmpl := ` +{{- if .ModuleList }} +{{ .ModuleList }} +{{- end }} policy "<>" { - source = "https://registry.terraform.io/v2%s/policy/<>.sentinel?checksum=<>" - enforcement_level = "advisory" + source = "https://registry.terraform.io/v2{{ .TerraformPolicyID }}/policy/<>.sentinel?checksum=<>" + enforcement_level = "advisory" } -`, moduleList, terraformPolicyID) +` + type hclTemplateData struct { + ModuleList string + TerraformPolicyID string + } + var hclBuilder strings.Builder + t := template.Must(template.New("hclPolicy").Parse(hclTmpl)) + err = t.Execute(&hclBuilder, hclTemplateData{ + ModuleList: moduleList, + TerraformPolicyID: terraformPolicyID, + }) + if err != nil { + logger.WithError(err).Error("failed to render HCL policy template") + } + hclTemplate := hclBuilder.String() builder.WriteString(hclTemplate) builder.WriteString("\n```\n") builder.WriteString(fmt.Sprintf("Available policies with SHA for %s are: \n\n", terraformPolicyID)) diff --git a/pkg/tools/resolve_provider_doc_id.go b/pkg/tools/resolve_provider_doc_id.go index a9ac409..813025a 100644 --- a/pkg/tools/resolve_provider_doc_id.go +++ b/pkg/tools/resolve_provider_doc_id.go @@ -8,6 +8,7 @@ import ( "encoding/json" "fmt" "net/http" + "path" "strings" "github.com/hashicorp/terraform-mcp-server/pkg/client" @@ -126,7 +127,7 @@ func resolveProviderDocIDHandler(ctx context.Context, request mcp.CallToolReques cs_pn, err_pn := utils.ContainsSlug(fmt.Sprintf("%s_%s", providerDetail.ProviderName, doc.Slug), serviceSlug) if (cs || cs_pn) && err == nil && err_pn == nil { contentAvailable = true - descriptionSnippet, err := getContentSnippet(registryClient, doc.ID, logger) + descriptionSnippet, err := getContentSnippet(httpClient, doc.ID, logger) if err != nil { logger.Warnf("Error fetching content snippet for provider doc ID: %s: %v", doc.ID, err) } @@ -143,7 +144,7 @@ func resolveProviderDocIDHandler(ctx context.Context, request mcp.CallToolReques return mcp.NewToolResultText(builder.String()), nil } -func resolveProviderDetails(request mcp.CallToolRequest, registryClient *http.Client, defaultErrorGuide string, logger *log.Logger) (client.ProviderDetail, error) { +func resolveProviderDetails(request mcp.CallToolRequest, httpClient *http.Client, defaultErrorGuide string, logger *log.Logger) (client.ProviderDetail, error) { providerDetail := client.ProviderDetail{} providerName := request.GetString("provider_name", "") if providerName == "" { @@ -164,7 +165,7 @@ func resolveProviderDetails(request mcp.CallToolRequest, registryClient *http.Cl if utils.IsValidProviderVersionFormat(providerVersion) { providerVersionValue = providerVersion } else { - providerVersionValue, err = client.GetLatestProviderVersion(registryClient, providerNamespace, providerName, logger) + providerVersionValue, err = client.GetLatestProviderVersion(httpClient, providerNamespace, providerName, logger) if err != nil { providerVersionValue = "" logger.Debugf("Error getting latest provider version in %s namespace: %v", providerNamespace, err) @@ -174,7 +175,7 @@ func resolveProviderDetails(request mcp.CallToolRequest, registryClient *http.Cl // If the provider version doesn't exist, try the hashicorp namespace if providerVersionValue == "" { tryProviderNamespace := "hashicorp" - providerVersionValue, err = client.GetLatestProviderVersion(registryClient, tryProviderNamespace, providerName, logger) + providerVersionValue, err = client.GetLatestProviderVersion(httpClient, tryProviderNamespace, providerName, logger) if err != nil { // Just so we don't print the same namespace twice if they are the same if providerNamespace != tryProviderNamespace { @@ -198,20 +199,20 @@ func resolveProviderDetails(request mcp.CallToolRequest, registryClient *http.Cl } // get_provider_docsV2 retrieves a list of documentation items for a specific provider category using v2 API with support for pagination using page numbers -func get_provider_docsV2(registryClient *http.Client, providerDetail client.ProviderDetail, logger *log.Logger) (string, error) { - providerVersionID, err := client.GetProviderVersionID(registryClient, providerDetail.ProviderNamespace, providerDetail.ProviderName, providerDetail.ProviderVersion, logger) +func get_provider_docsV2(httpClient *http.Client, providerDetail client.ProviderDetail, logger *log.Logger) (string, error) { + providerVersionID, err := client.GetProviderVersionID(httpClient, providerDetail.ProviderNamespace, providerDetail.ProviderName, providerDetail.ProviderVersion, logger) if err != nil { return "", utils.LogAndReturnError(logger, "getting provider version ID", err) } category := providerDetail.ProviderDataType if category == "overview" { - return client.GetProviderOverviewDocs(registryClient, providerVersionID, logger) + return client.GetProviderOverviewDocs(httpClient, providerVersionID, logger) } uriPrefix := fmt.Sprintf("provider-docs?filter[provider-version]=%s&filter[category]=%s&filter[language]=hcl", providerVersionID, category) - docs, err := client.SendPaginatedRegistryCall(registryClient, uriPrefix, logger) + docs, err := client.SendPaginatedRegistryCall(httpClient, uriPrefix, logger) if err != nil { return "", utils.LogAndReturnError(logger, "getting provider documentation", err) } @@ -225,7 +226,7 @@ func get_provider_docsV2(registryClient *http.Client, providerDetail client.Prov builder.WriteString("Each result includes:\n- providerDocID: tfprovider-compatible identifier\n- Title: Service or resource name\n- Category: Type of document\n- Description: Brief summary of the document\n") builder.WriteString("For best results, select libraries based on the service_slug match and category of information requested.\n\n---\n\n") for _, doc := range docs { - descriptionSnippet, err := getContentSnippet(registryClient, doc.ID, logger) + descriptionSnippet, err := getContentSnippet(httpClient, doc.ID, logger) if err != nil { logger.Warnf("Error fetching content snippet for provider doc ID: %s: %v", doc.ID, err) } @@ -235,8 +236,8 @@ func get_provider_docsV2(registryClient *http.Client, providerDetail client.Prov return builder.String(), nil } -func getContentSnippet(registryClient *http.Client, docID string, logger *log.Logger) (string, error) { - docContent, err := client.SendRegistryCall(registryClient, "GET", fmt.Sprintf("provider-docs/%s", docID), logger, "v2") +func getContentSnippet(httpClient *http.Client, docID string, logger *log.Logger) (string, error) { + docContent, err := client.SendRegistryCall(httpClient, "GET", fmt.Sprintf("provider-docs/%s", docID), logger, "v2") if err != nil { return "", utils.LogAndReturnError(logger, fmt.Sprintf("error fetching provider-docs/%s within getContentSnippet", docID), err) } diff --git a/pkg/tools/search_policies.go b/pkg/tools/search_policies.go index cac9bab..29dc159 100644 --- a/pkg/tools/search_policies.go +++ b/pkg/tools/search_policies.go @@ -7,6 +7,7 @@ import ( "context" "encoding/json" "fmt" + "net/url" "strings" "github.com/hashicorp/terraform-mcp-server/pkg/client" @@ -62,8 +63,14 @@ func getSearchPoliciesHandler(ctx context.Context, request mcp.CallToolRequest, } httpClient := terraformClients.HttpClient - // static list of 100 is fine for now - policyResp, err := client.SendRegistryCall(httpClient, "GET", (&url.URL{Path: "policies", RawQuery: url.Values{"page[size]": {"100"}, "include": {"latest-version"}}.Encode()}).String(), logger, "v2") + uri := (&url.URL{ + Path: "policies", + RawQuery: url.Values{ + "page[size]": {"100"}, // static list of 100 is fine for now + "include": {"latest-version"}, + }.Encode(), + }).String() + policyResp, err := client.SendRegistryCall(httpClient, "GET", uri, logger, "v2") if err != nil { return nil, utils.LogAndReturnError(logger, "Failed to fetch policies: registry API did not return a successful response", err) } diff --git a/pkg/utils/pagination.go b/pkg/utils/pagination.go new file mode 100644 index 0000000..e62e070 --- /dev/null +++ b/pkg/utils/pagination.go @@ -0,0 +1,96 @@ +// Copyright (c) HashiCorp, Inc. +// SPDX-License-Identifier: MPL-2.0 + +package utils + +import ( + "fmt" + + "github.com/mark3labs/mcp-go/mcp" +) + +type PaginationParams struct { + Page int + PageSize int + After string +} + +// OptionalParam is a helper function to retrieve an optional parameter from the request. +// It returns the value as type T and an error if the parameter is not present or cannot be converted. +func OptionalParam[T any](r mcp.CallToolRequest, p string) (T, error) { + var zero T + + // Check if the parameter exists in the request + if _, ok := r.GetArguments()[p]; !ok { + return zero, nil + } + + // Check if the parameter can be converted to type T + if _, ok := r.GetArguments()[p].(T); !ok { + return zero, fmt.Errorf("parameter %s is not of type %T, is %T", p, zero, r.GetArguments()[p]) + } + + return r.GetArguments()[p].(T), nil +} + +// OptionalIntParam is a helper function to retrieve an optional integer parameter from the request. +// It returns the value as an int and an error if the parameter is not present or cannot be converted. +func OptionalIntParam(r mcp.CallToolRequest, p string) (int, error) { + v, err := OptionalParam[float64](r, p) + if err != nil { + return 0, err + } + return int(v), nil +} + +// OptionalIntParamWithDefault retrieves an optional integer parameter from the request. +// If the parameter is not present or is zero, it returns the default value. +func OptionalIntParamWithDefault(r mcp.CallToolRequest, p string, d int) (int, error) { + v, err := OptionalIntParam(r, p) + if err != nil { + return 0, err + } + if v == 0 { + return d, nil + } + return v, nil +} + +// OptionalPaginationParams returns pagination parameters from the request. +// It retrieves "page", "pageSize", and "after" parameters, providing defaults where necessary. +func OptionalPaginationParams(r mcp.CallToolRequest) (PaginationParams, error) { + page, err := OptionalIntParamWithDefault(r, "page", 1) + if err != nil { + return PaginationParams{}, err + } + pageSize, err := OptionalIntParamWithDefault(r, "pageSize", 30) + if err != nil { + return PaginationParams{}, err + } + after, err := OptionalParam[string](r, "after") + if err != nil { + return PaginationParams{}, err + } + return PaginationParams{ + Page: page, + PageSize: pageSize, + After: after, + }, nil +} + +// WithPagination adds pagination parameters to a tool. +// It adds "page", "pageSize", and "after" parameters with appropriate descriptions and defaults. +func WithPagination() mcp.ToolOption { + return func(tool *mcp.Tool) { + mcp.WithNumber("page", + mcp.Description("Page number for pagination (min 1)"), + mcp.Min(1), + )(tool) + + mcp.WithNumber("pageSize", + mcp.Description("Results per page for pagination (min 1, max 100)"), + mcp.Min(1), + mcp.Max(100), + )(tool) + } +} diff --git a/pkg/utils/pagination_test.go b/pkg/utils/pagination_test.go new file mode 100644 index 0000000..00d9e63 --- /dev/null +++ b/pkg/utils/pagination_test.go @@ -0,0 +1,560 @@ +// Copyright (c) HashiCorp, Inc. +// SPDX-License-Identifier: MPL-2.0 + +package utils + +import ( + "testing" + + "github.com/mark3labs/mcp-go/mcp" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +// mockCallToolRequest creates a mock CallToolRequest with the given arguments +func mockCallToolRequest(args map[string]interface{}) mcp.CallToolRequest { + req := mcp.CallToolRequest{} + req.Params.Arguments = args + return req +} + +func TestOptionalParam(t *testing.T) { + tests := []struct { + name string + args map[string]interface{} + param string + expectValue interface{} + expectError bool + errorMsg string + }{ + { + name: "string parameter exists", + args: map[string]interface{}{"test": "value"}, + param: "test", + expectValue: "value", + expectError: false, + }, + { + name: "parameter does not exist", + args: map[string]interface{}{"other": "value"}, + param: "missing", + expectValue: "", + expectError: false, + }, + { + name: "parameter exists but wrong type", + args: map[string]interface{}{"test": 123}, + param: "test", + expectValue: "", + expectError: true, + errorMsg: "parameter test is not of type string, is int", + }, + { + name: "empty args map", + args: map[string]interface{}{}, + param: "test", + expectValue: "", + expectError: false, + }, + { + name: "nil value in args", + args: map[string]interface{}{"test": nil}, + param: "test", + expectValue: "", + expectError: true, + errorMsg: "parameter test is not of type string, is ", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + req := mockCallToolRequest(tt.args) + + // Test with string type + result, err := OptionalParam[string](req, tt.param) + + if tt.expectError { + require.Error(t, err) + if tt.errorMsg != "" { + assert.Contains(t, err.Error(), tt.errorMsg) + } + } else { + require.NoError(t, err) + if tt.expectValue != nil { + assert.Equal(t, tt.expectValue, result) + } else { + assert.Equal(t, "", result) // zero value for string + } + } + }) + } +} + +func TestOptionalParam_DifferentTypes(t *testing.T) { + tests := []struct { + name string + args map[string]interface{} + param string + testType string + expectValue interface{} + expectError bool + }{ + { + name: "int type", + args: map[string]interface{}{"value": 42}, + param: "value", + testType: "int", + expectValue: 42, + expectError: false, + }, + { + name: "float64 type", + args: map[string]interface{}{"value": 3.14}, + param: "value", + testType: "float64", + expectValue: 3.14, + expectError: false, + }, + { + name: "bool type", + args: map[string]interface{}{"value": true}, + param: "value", + testType: "bool", + expectValue: true, + expectError: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + req := mockCallToolRequest(tt.args) + + switch tt.testType { + case "int": + result, err := OptionalParam[int](req, tt.param) + if tt.expectError { + require.Error(t, err) + } else { + require.NoError(t, err) + assert.Equal(t, tt.expectValue, result) + } + case "float64": + result, err := OptionalParam[float64](req, tt.param) + if tt.expectError { + require.Error(t, err) + } else { + require.NoError(t, err) + assert.Equal(t, tt.expectValue, result) + } + case "bool": + result, err := OptionalParam[bool](req, tt.param) + if tt.expectError { + require.Error(t, err) + } else { + require.NoError(t, err) + assert.Equal(t, tt.expectValue, result) + } + } + }) + } +} + +func TestOptionalIntParam(t *testing.T) { + tests := []struct { + name string + args map[string]interface{} + param string + expectValue int + expectError bool + errorMsg string + }{ + { + name: "valid float64 converts to int", + args: map[string]interface{}{"count": 42.0}, + param: "count", + expectValue: 42, + expectError: false, + }, + { + name: "valid float64 with decimal converts to int", + args: map[string]interface{}{"count": 42.7}, + param: "count", + expectValue: 42, + expectError: false, + }, + { + name: "parameter does not exist", + args: map[string]interface{}{"other": 123.0}, + param: "missing", + expectValue: 0, + expectError: false, + }, + { + name: "parameter exists but wrong type", + args: map[string]interface{}{"count": "not a number"}, + param: "count", + expectValue: 0, + expectError: true, + errorMsg: "parameter count is not of type float64, is string", + }, + { + name: "zero value", + args: map[string]interface{}{"count": 0.0}, + param: "count", + expectValue: 0, + expectError: false, + }, + { + name: "negative value", + args: map[string]interface{}{"count": -5.0}, + param: "count", + expectValue: -5, + expectError: false, + }, + { + name: "large value", + args: map[string]interface{}{"count": 999999.0}, + param: "count", + expectValue: 999999, + expectError: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + req := mockCallToolRequest(tt.args) + + result, err := OptionalIntParam(req, tt.param) + + if tt.expectError { + require.Error(t, err) + if tt.errorMsg != "" { + assert.Contains(t, err.Error(), tt.errorMsg) + } + } else { + require.NoError(t, err) + assert.Equal(t, tt.expectValue, result) + } + }) + } +} + +func TestOptionalIntParamWithDefault(t *testing.T) { + tests := []struct { + name string + args map[string]interface{} + param string + defaultValue int + expectValue int + expectError bool + errorMsg string + }{ + { + name: "valid value returns value", + args: map[string]interface{}{"count": 42.0}, + param: "count", + defaultValue: 10, + expectValue: 42, + expectError: false, + }, + { + name: "parameter does not exist returns default", + args: map[string]interface{}{"other": 123.0}, + param: "missing", + defaultValue: 10, + expectValue: 10, + expectError: false, + }, + { + name: "zero value returns default", + args: map[string]interface{}{"count": 0.0}, + param: "count", + defaultValue: 10, + expectValue: 10, + expectError: false, + }, + { + name: "parameter exists but wrong type", + args: map[string]interface{}{"count": "not a number"}, + param: "count", + defaultValue: 10, + expectValue: 0, + expectError: true, + errorMsg: "parameter count is not of type float64, is string", + }, + { + name: "negative value returns value (not default)", + args: map[string]interface{}{"count": -5.0}, + param: "count", + defaultValue: 10, + expectValue: -5, + expectError: false, + }, + { + name: "default value is zero", + args: map[string]interface{}{"other": 123.0}, + param: "missing", + defaultValue: 0, + expectValue: 0, + expectError: false, + }, + { + name: "default value is negative", + args: map[string]interface{}{"other": 123.0}, + param: "missing", + defaultValue: -1, + expectValue: -1, + expectError: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + req := mockCallToolRequest(tt.args) + + result, err := OptionalIntParamWithDefault(req, tt.param, tt.defaultValue) + + if tt.expectError { + require.Error(t, err) + if tt.errorMsg != "" { + assert.Contains(t, err.Error(), tt.errorMsg) + } + } else { + require.NoError(t, err) + assert.Equal(t, tt.expectValue, result) + } + }) + } +} + +func TestOptionalPaginationParams(t *testing.T) { + tests := []struct { + name string + args map[string]interface{} + expectParams PaginationParams + expectError bool + errorMsg string + }{ + { + name: "all parameters provided", + args: map[string]interface{}{ + "page": 5.0, + "pageSize": 20.0, + "after": "cursor123", + }, + expectParams: PaginationParams{ + Page: 5, + PageSize: 20, + After: "cursor123", + }, + expectError: false, + }, + { + name: "no parameters provided - uses defaults", + args: map[string]interface{}{}, + expectParams: PaginationParams{ + Page: 1, + PageSize: 30, + After: "", + }, + expectError: false, + }, + { + name: "only page provided", + args: map[string]interface{}{ + "page": 3.0, + }, + expectParams: PaginationParams{ + Page: 3, + PageSize: 30, + After: "", + }, + expectError: false, + }, + { + name: "only pageSize provided", + args: map[string]interface{}{ + "pageSize": 50.0, + }, + expectParams: PaginationParams{ + Page: 1, + PageSize: 50, + After: "", + }, + expectError: false, + }, + { + name: "only after provided", + args: map[string]interface{}{ + "after": "token456", + }, + expectParams: PaginationParams{ + Page: 1, + PageSize: 30, + After: "token456", + }, + expectError: false, + }, + { + name: "zero values use defaults", + args: map[string]interface{}{ + "page": 0.0, + "pageSize": 0.0, + "after": "", + }, + expectParams: PaginationParams{ + Page: 1, + PageSize: 30, + After: "", + }, + expectError: false, + }, + { + name: "invalid page type", + args: map[string]interface{}{ + "page": "not a number", + }, + expectParams: PaginationParams{}, + expectError: true, + errorMsg: "parameter page is not of type float64, is string", + }, + { + name: "invalid pageSize type", + args: map[string]interface{}{ + "pageSize": "not a number", + }, + expectParams: PaginationParams{}, + expectError: true, + errorMsg: "parameter pageSize is not of type float64, is string", + }, + { + name: "invalid after type", + args: map[string]interface{}{ + "after": 123, + }, + expectParams: PaginationParams{}, + expectError: true, + errorMsg: "parameter after is not of type string, is int", + }, + { + name: "negative page value", + args: map[string]interface{}{ + "page": -1.0, + }, + expectParams: PaginationParams{ + Page: -1, + PageSize: 30, + After: "", + }, + expectError: false, + }, + { + name: "large values", + args: map[string]interface{}{ + "page": 999.0, + "pageSize": 100.0, + "after": "very-long-cursor-token-with-special-chars-!@#$%^&*()", + }, + expectParams: PaginationParams{ + Page: 999, + PageSize: 100, + After: "very-long-cursor-token-with-special-chars-!@#$%^&*()", + }, + expectError: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + req := mockCallToolRequest(tt.args) + + result, err := OptionalPaginationParams(req) + + if tt.expectError { + require.Error(t, err) + if tt.errorMsg != "" { + assert.Contains(t, err.Error(), tt.errorMsg) + } + } else { + require.NoError(t, err) + assert.Equal(t, tt.expectParams, result) + } + }) + } +} + +func TestWithPagination(t *testing.T) { + // Test that WithPagination returns a valid ToolOption + option := WithPagination() + assert.NotNil(t, option) + + // Create a properly initialized tool to test the option + tool := mcp.NewTool("test-tool", mcp.WithDescription("Test tool")) + + // Apply the pagination option + option(&tool) + + // Verify that the tool has been modified and doesn't panic + assert.NotNil(t, tool) + + // The function should not panic when applied to a valid tool + // Since we can't easily inspect the internal structure of mcp.Tool, + // we verify that the option can be applied without errors +} + +func TestPaginationParams_Struct(t *testing.T) { + // Test that PaginationParams struct can be created and accessed + params := PaginationParams{ + Page: 5, + PageSize: 25, + After: "cursor123", + } + + assert.Equal(t, 5, params.Page) + assert.Equal(t, 25, params.PageSize) + assert.Equal(t, "cursor123", params.After) + + // Test zero values + zeroParams := PaginationParams{} + assert.Equal(t, 0, zeroParams.Page) + assert.Equal(t, 0, zeroParams.PageSize) + assert.Equal(t, "", zeroParams.After) +} + +// Benchmark tests for performance +func BenchmarkOptionalParam(b *testing.B) { + req := mockCallToolRequest(map[string]interface{}{ + "test": "value", + "count": 42.0, + "enabled": true, + }) + + b.ResetTimer() + for i := 0; i < b.N; i++ { + _, _ = OptionalParam[string](req, "test") + } +} + +func BenchmarkOptionalIntParam(b *testing.B) { + req := mockCallToolRequest(map[string]interface{}{ + "count": 42.0, + }) + + b.ResetTimer() + for i := 0; i < b.N; i++ { + _, _ = OptionalIntParam(req, "count") + } +} + +func BenchmarkOptionalPaginationParams(b *testing.B) { + req := mockCallToolRequest(map[string]interface{}{ + "page": 5.0, + "pageSize": 20.0, + "after": "cursor123", + }) + + b.ResetTimer() + for i := 0; i < b.N; i++ { + _, _ = OptionalPaginationParams(req) + } +} From 8352060b10f595a4c6a726ba44560b46650f586c Mon Sep 17 00:00:00 2001 From: Gautam Date: Tue, 12 Aug 2025 13:45:47 -0700 Subject: [PATCH 19/19] updating changelog, go libs and minor fixes --- CHANGELOG.md | 7 ++++ go.mod | 13 +++---- go.sum | 26 ++++++------- pkg/client/client.go | 3 +- pkg/tools/dynamic_tool_test.go | 68 +++++++++++++++++----------------- 5 files changed, 60 insertions(+), 57 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 5994c5a..9dbac27 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,5 +1,12 @@ ## 0.3.0 +* Authentication for HCP Terraform & TFE and restructure the repo. See [#121](https://github.com/hashicorp/terraform-mcp-server/pull/121) +* Implement dynamic tool registration. See [#121](https://github.com/hashicorp/terraform-mcp-server/pull/121) +* Adding 2 new HCP TF/TFE tools. List Terraform organizations & projects. See [#121](https://github.com/hashicorp/terraform-mcp-server/pull/121) +* Changes to tool names to be more consistent. See [#121](https://github.com/hashicorp/terraform-mcp-server/pull/121) +* Implement pagination utility. See [#121](https://github.com/hashicorp/terraform-mcp-server/pull/121) +* Updating `mark3labs/mcp-go` and `hashicorp/tfe-go` versions. See [#121](https://github.com/hashicorp/terraform-mcp-server/pull/121) + ## 0.2.2 (Aug 5, 2025) FEATURES diff --git a/go.mod b/go.mod index 9e55ed5..699533d 100644 --- a/go.mod +++ b/go.mod @@ -5,8 +5,8 @@ go 1.24 require ( github.com/hashicorp/go-cleanhttp v0.5.2 github.com/hashicorp/go-retryablehttp v0.7.8 - github.com/hashicorp/go-tfe v1.87.0 - github.com/mark3labs/mcp-go v0.36.0 + github.com/hashicorp/go-tfe v1.89.0 + github.com/mark3labs/mcp-go v0.37.0 github.com/sirupsen/logrus v1.9.3 github.com/spf13/cobra v1.9.1 github.com/spf13/viper v1.20.1 @@ -30,18 +30,17 @@ require ( github.com/mailru/easyjson v0.9.0 // indirect github.com/pelletier/go-toml/v2 v2.2.4 // indirect github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2 // indirect - github.com/sagikazarmark/locafero v0.9.0 // indirect - github.com/sourcegraph/conc v0.3.0 // indirect + github.com/sagikazarmark/locafero v0.10.0 // indirect + github.com/sourcegraph/conc v0.3.1-0.20240121214520-5f936abd7ae8 // indirect github.com/spf13/afero v1.14.0 // indirect github.com/spf13/cast v1.9.2 // indirect github.com/spf13/pflag v1.0.7 // indirect github.com/subosito/gotenv v1.6.0 // indirect github.com/wk8/go-ordered-map/v2 v2.1.8 // indirect github.com/yosida95/uritemplate/v3 v3.0.2 // indirect - go.uber.org/multierr v1.11.0 // indirect golang.org/x/sync v0.16.0 // indirect - golang.org/x/sys v0.34.0 // indirect - golang.org/x/text v0.27.0 // indirect + golang.org/x/sys v0.35.0 // indirect + golang.org/x/text v0.28.0 // indirect golang.org/x/time v0.12.0 // indirect gopkg.in/yaml.v3 v3.0.1 // indirect ) diff --git a/go.sum b/go.sum index f41f95a..089ef38 100644 --- a/go.sum +++ b/go.sum @@ -30,8 +30,8 @@ github.com/hashicorp/go-retryablehttp v0.7.8 h1:ylXZWnqa7Lhqpk0L1P1LzDtGcCR0rPVU github.com/hashicorp/go-retryablehttp v0.7.8/go.mod h1:rjiScheydd+CxvumBsIrFKlx3iS0jrZ7LvzFGFmuKbw= github.com/hashicorp/go-slug v0.16.7 h1:sBW8y1sX+JKOZKu9a+DQZuWDVaX+U9KFnk6+VDQvKcw= github.com/hashicorp/go-slug v0.16.7/go.mod h1:X5fm++dL59cDOX8j48CqHr4KARTQau7isGh0ZVxJB5I= -github.com/hashicorp/go-tfe v1.87.0 h1:0ejo3SegLoQ/Uj/2U0ECGppm3E/VZfSu+KscvzxvRNs= -github.com/hashicorp/go-tfe v1.87.0/go.mod h1:6dUFMBKh0jkxlRsrw7bYD2mby0efdwE4dtlAuTogIzA= +github.com/hashicorp/go-tfe v1.89.0 h1:2eNxW6LnNzqvFgPzWuyVa0ieUZnb6+2GBJ2KYdwzTx8= +github.com/hashicorp/go-tfe v1.89.0/go.mod h1:6dUFMBKh0jkxlRsrw7bYD2mby0efdwE4dtlAuTogIzA= github.com/hashicorp/go-uuid v1.0.3 h1:2gKiV6YVmrJ1i2CKKa9obLvRieoRGviZFL26PcT/Co8= github.com/hashicorp/go-uuid v1.0.3/go.mod h1:6SBZvOh/SIDV7/2o3Jml5SYk/TvGqwFJ/bN7x4byOro= github.com/hashicorp/go-version v1.7.0 h1:5tqGy27NaOTB8yJKUZELlFAS/LTKJkrmONwQKeRZfjY= @@ -48,8 +48,8 @@ github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY= github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE= github.com/mailru/easyjson v0.9.0 h1:PrnmzHw7262yW8sTBwxi1PdJA3Iw/EKBa8psRf7d9a4= github.com/mailru/easyjson v0.9.0/go.mod h1:1+xMtQp2MRNVL/V1bOzuP3aP8VNwRW55fQUto+XFtTU= -github.com/mark3labs/mcp-go v0.36.0 h1:rIZaijrRYPeSbJG8/qNDe0hWlGrCJ7FWHNMz2SQpTis= -github.com/mark3labs/mcp-go v0.36.0/go.mod h1:T7tUa2jO6MavG+3P25Oy/jR7iCeJPHImCZHRymCn39g= +github.com/mark3labs/mcp-go v0.37.0 h1:BywvZLPRT6Zx6mMG/MJfxLSZQkTGIcJSEGKsvr4DsoQ= +github.com/mark3labs/mcp-go v0.37.0/go.mod h1:T7tUa2jO6MavG+3P25Oy/jR7iCeJPHImCZHRymCn39g= github.com/mattn/go-colorable v0.1.13 h1:fFA4WZxdEF4tXPZVKMLwD8oUnCTTo08duU7wxecdEvA= github.com/mattn/go-colorable v0.1.13/go.mod h1:7S9/ev0klgBDR4GtXTXX8a3vIGJpMovkB8vQcUbaXHg= github.com/mattn/go-isatty v0.0.20 h1:xfD0iDuEKnDkl03q4limB+vH+GxLEtL/jb4xVJSWWEY= @@ -62,12 +62,12 @@ github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2/go.mod h1:iKH github.com/rogpeppe/go-internal v1.9.0 h1:73kH8U+JUqXU8lRuOHeVHaa/SZPifC7BkcraZVejAe8= github.com/rogpeppe/go-internal v1.9.0/go.mod h1:WtVeX8xhTBvf0smdhujwtBcq4Qrzq/fJaraNFVN+nFs= github.com/russross/blackfriday/v2 v2.1.0/go.mod h1:+Rmxgy9KzJVeS9/2gXHxylqXiyQDYRxCVz55jmeOWTM= -github.com/sagikazarmark/locafero v0.9.0 h1:GbgQGNtTrEmddYDSAH9QLRyfAHY12md+8YFTqyMTC9k= -github.com/sagikazarmark/locafero v0.9.0/go.mod h1:UBUyz37V+EdMS3hDF3QWIiVr/2dPrx49OMO0Bn0hJqk= +github.com/sagikazarmark/locafero v0.10.0 h1:FM8Cv6j2KqIhM2ZK7HZjm4mpj9NBktLgowT1aN9q5Cc= +github.com/sagikazarmark/locafero v0.10.0/go.mod h1:Ieo3EUsjifvQu4NZwV5sPd4dwvu0OCgEQV7vjc9yDjw= github.com/sirupsen/logrus v1.9.3 h1:dueUQJ1C2q9oE3F7wvmSGAaVtTmUizReu6fjN8uqzbQ= github.com/sirupsen/logrus v1.9.3/go.mod h1:naHLuLoDiP4jHNo9R0sCBMtWGeIprob74mVsIT4qYEQ= -github.com/sourcegraph/conc v0.3.0 h1:OQTbbt6P72L20UqAkXXuLOj79LfEanQ+YQFNpLA9ySo= -github.com/sourcegraph/conc v0.3.0/go.mod h1:Sdozi7LEKbFPqYX2/J+iBAM6HpqSLTASQIKqDmF7Mt0= +github.com/sourcegraph/conc v0.3.1-0.20240121214520-5f936abd7ae8 h1:+jumHNA0Wrelhe64i8F6HNlS8pkoyMv5sreGx2Ry5Rw= +github.com/sourcegraph/conc v0.3.1-0.20240121214520-5f936abd7ae8/go.mod h1:3n1Cwaq1E1/1lhQhtRK2ts/ZwZEhjcQeJQ1RuC6Q/8U= github.com/spf13/afero v1.14.0 h1:9tH6MapGnn/j0eb0yIXiLjERO8RB6xIVZRDCX7PtqWA= github.com/spf13/afero v1.14.0/go.mod h1:acJQ8t0ohCGuMN3O+Pv0V0hgMxNYDlvdk+VTfyZmbYo= github.com/spf13/cast v1.9.2 h1:SsGfm7M8QOFtEzumm7UZrZdLLquNdzFYfIbEXntcFbE= @@ -89,15 +89,13 @@ github.com/wk8/go-ordered-map/v2 v2.1.8 h1:5h/BUHu93oj4gIdvHHHGsScSTMijfx5PeYkE/ github.com/wk8/go-ordered-map/v2 v2.1.8/go.mod h1:5nJHM5DyteebpVlHnWMV0rPz6Zp7+xBAnxjb1X5vnTw= github.com/yosida95/uritemplate/v3 v3.0.2 h1:Ed3Oyj9yrmi9087+NczuL5BwkIc4wvTb5zIM+UJPGz4= github.com/yosida95/uritemplate/v3 v3.0.2/go.mod h1:ILOh0sOhIJR3+L/8afwt/kE++YT040gmv5BQTMR2HP4= -go.uber.org/multierr v1.11.0 h1:blXXJkSxSSfBVBlC76pxqeO+LN3aDfLQo+309xJstO0= -go.uber.org/multierr v1.11.0/go.mod h1:20+QtiLqy0Nd6FdQB9TLXag12DsQkrbs3htMFfDN80Y= golang.org/x/sync v0.16.0 h1:ycBJEhp9p4vXvUZNszeOq0kGTPghopOL8q0fq3vstxw= golang.org/x/sync v0.16.0/go.mod h1:1dzgHSNfp02xaA81J2MS99Qcpr2w7fw1gpm99rleRqA= golang.org/x/sys v0.0.0-20220715151400-c0bba94af5f8/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= -golang.org/x/sys v0.34.0 h1:H5Y5sJ2L2JRdyv7ROF1he/lPdvFsd0mJHFw2ThKHxLA= -golang.org/x/sys v0.34.0/go.mod h1:BJP2sWEmIv4KK5OTEluFJCKSidICx8ciO85XgH3Ak8k= -golang.org/x/text v0.27.0 h1:4fGWRpyh641NLlecmyl4LOe6yDdfaYNrGb2zdfo4JV4= -golang.org/x/text v0.27.0/go.mod h1:1D28KMCvyooCX9hBiosv5Tz/+YLxj0j7XhWjpSUF7CU= +golang.org/x/sys v0.35.0 h1:vz1N37gP5bs89s7He8XuIYXpyY0+QlsKmzipCbUtyxI= +golang.org/x/sys v0.35.0/go.mod h1:BJP2sWEmIv4KK5OTEluFJCKSidICx8ciO85XgH3Ak8k= +golang.org/x/text v0.28.0 h1:rhazDwis8INMIwQ4tpjLDzUhx6RlXqZNPEM0huQojng= +golang.org/x/text v0.28.0/go.mod h1:U8nCwOR8jO/marOQ0QbDiOngZVEBB7MAiitBuMjXiNU= golang.org/x/time v0.12.0 h1:ScB/8o8olJvc+CQPWrK3fPZNfh7qgwCrY0zJmoEQLSE= golang.org/x/time v0.12.0/go.mod h1:CDIdPxbZBQxdj6cxyCIdrNogrJKMJ7pr37NYpMcMDSg= golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= diff --git a/pkg/client/client.go b/pkg/client/client.go index 3b4ed8c..9500d03 100644 --- a/pkg/client/client.go +++ b/pkg/client/client.go @@ -90,7 +90,6 @@ func GetTerraformClientFromContext(ctx context.Context, logger *log.Logger) (*te return nil, fmt.Errorf("no active session") } - // Try to get existing client client := GetTerraformClient(session.SessionID()) if client != nil { @@ -144,7 +143,7 @@ func NewSessionHandler(ctx context.Context, session server.ClientSession, logger } logger.Info("Session has valid TFE client - registered with tool registry") } else { - logger.Info("Session has no valid TFE client - TFE tools will not be available") + logger.Warn("Session has no valid TFE client - TFE tools will not be available") } } diff --git a/pkg/tools/dynamic_tool_test.go b/pkg/tools/dynamic_tool_test.go index da0c4a2..b317a59 100644 --- a/pkg/tools/dynamic_tool_test.go +++ b/pkg/tools/dynamic_tool_test.go @@ -13,76 +13,76 @@ import ( func TestDynamicToolRegistry_SessionManagement(t *testing.T) { logger := log.New() logger.SetLevel(log.ErrorLevel) // Reduce noise in tests - + // Create a registry without initializing the MCP server registry := &DynamicToolRegistry{ sessionsWithTFE: make(map[string]bool), tfeToolsRegistered: false, - mcpServer: nil, // We'll skip actual tool registration - logger: logger, + mcpServer: nil, // We'll skip actual tool registration + logger: logger, } - + // Initially no sessions should have TFE if registry.HasAnySessionWithTFE() { t.Error("Expected no sessions with TFE initially") } - + sessionID1 := "test-session-1" sessionID2 := "test-session-2" - + // Check specific sessions if registry.HasSessionWithTFE(sessionID1) { t.Error("Expected session1 to not have TFE initially") } - + // Manually register sessions (without triggering tool registration) registry.mu.Lock() registry.sessionsWithTFE[sessionID1] = true registry.mu.Unlock() - + if !registry.HasSessionWithTFE(sessionID1) { t.Error("Expected session1 to have TFE after registration") } - + if !registry.HasAnySessionWithTFE() { t.Error("Expected at least one session with TFE") } - + if registry.HasSessionWithTFE(sessionID2) { t.Error("Expected session2 to not have TFE") } - + // Register second session registry.mu.Lock() registry.sessionsWithTFE[sessionID2] = true registry.mu.Unlock() - + if !registry.HasSessionWithTFE(sessionID2) { t.Error("Expected session2 to have TFE after registration") } - + // Unregister first session registry.UnregisterSessionWithTFE(sessionID1) - + if registry.HasSessionWithTFE(sessionID1) { t.Error("Expected session1 to not have TFE after unregistration") } - + if !registry.HasSessionWithTFE(sessionID2) { t.Error("Expected session2 to still have TFE") } - + if !registry.HasAnySessionWithTFE() { t.Error("Expected session2 to still provide TFE availability") } - + // Unregister second session registry.UnregisterSessionWithTFE(sessionID2) - + if registry.HasSessionWithTFE(sessionID2) { t.Error("Expected session2 to not have TFE after unregistration") } - + if registry.HasAnySessionWithTFE() { t.Error("Expected no sessions with TFE after all unregistered") } @@ -91,25 +91,25 @@ func TestDynamicToolRegistry_SessionManagement(t *testing.T) { func TestDynamicToolRegistry_ToolRegistrationState(t *testing.T) { logger := log.New() logger.SetLevel(log.ErrorLevel) // Reduce noise in tests - + // Create a registry without MCP server to test state management registry := &DynamicToolRegistry{ sessionsWithTFE: make(map[string]bool), tfeToolsRegistered: false, - mcpServer: nil, - logger: logger, + mcpServer: nil, + logger: logger, } - + // Initially tools should not be registered if registry.tfeToolsRegistered { t.Error("Expected TFE tools to not be registered initially") } - + // Manually set tools as registered (simulating what would happen) registry.mu.Lock() registry.tfeToolsRegistered = true registry.mu.Unlock() - + // Now tools should be registered if !registry.tfeToolsRegistered { t.Error("Expected TFE tools to be registered") @@ -119,18 +119,18 @@ func TestDynamicToolRegistry_ToolRegistrationState(t *testing.T) { func TestDynamicToolRegistry_ConcurrentAccess(t *testing.T) { logger := log.New() logger.SetLevel(log.ErrorLevel) // Reduce noise in tests - + // Create a registry for concurrent testing registry := &DynamicToolRegistry{ sessionsWithTFE: make(map[string]bool), tfeToolsRegistered: false, - mcpServer: nil, - logger: logger, + mcpServer: nil, + logger: logger, } - + // Test concurrent registration and unregistration done := make(chan bool, 10) - + // Start multiple goroutines registering sessions for i := 0; i < 5; i++ { go func(id int) { @@ -139,12 +139,12 @@ func TestDynamicToolRegistry_ConcurrentAccess(t *testing.T) { registry.mu.Lock() registry.sessionsWithTFE[sessionID] = true registry.mu.Unlock() - + registry.UnregisterSessionWithTFE(sessionID) done <- true }(i) } - + // Start multiple goroutines checking state for i := 0; i < 5; i++ { go func(id int) { @@ -154,11 +154,11 @@ func TestDynamicToolRegistry_ConcurrentAccess(t *testing.T) { done <- true }(i) } - + // Wait for all goroutines to complete for i := 0; i < 10; i++ { <-done } - + // Test should complete without deadlocks or panics }