Skip to content

Commit d3e1040

Browse files
committed
fixing security comments and copilot issues
1 parent 2fd4db5 commit d3e1040

12 files changed

+759
-37
lines changed

cmd/terraform-mcp-server/init.go

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@ import (
1010
stdlog "log"
1111
"net/http"
1212
"os"
13+
"path"
1314
"strings"
1415
"time"
1516

e2e/cors_e2e_test.go

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -150,8 +150,8 @@ func runCORSTests(t *testing.T, mcpURL, mode, configuredOrigins string) {
150150

151151
// Define base test cases that apply to all modes
152152
baseTestCases := []testCase{
153-
{"GET with allowed origin", "GET", "https://example.com", 202, true},
154-
{"GET with no origin", "GET", "", 202, false},
153+
{"GET with allowed origin", "GET", "https://example.com", 200, true},
154+
{"GET with no origin", "GET", "", 200, false},
155155
{"OPTIONS preflight with allowed origin", "OPTIONS", "https://example.com", 200, true},
156156
}
157157

@@ -163,15 +163,15 @@ func runCORSTests(t *testing.T, mcpURL, mode, configuredOrigins string) {
163163
}
164164

165165
developmentModeTests := []testCase{
166-
{"GET with localhost origin", "GET", "http://localhost:3000", 202, true},
167-
{"GET with IPv4 localhost", "GET", "http://127.0.0.1:3000", 202, true},
168-
{"GET with IPv6 localhost", "GET", "http://[::1]:3000", 202, true},
166+
{"GET with localhost origin", "GET", "http://localhost:3000", 200, true},
167+
{"GET with IPv4 localhost", "GET", "http://127.0.0.1:3000", 200, true},
168+
{"GET with IPv6 localhost", "GET", "http://[::1]:3000", 200, true},
169169
{"GET with disallowed origin", "GET", "https://evil.com", 403, false},
170170
{"OPTIONS with localhost origin", "OPTIONS", "http://localhost:3000", 200, true},
171171
}
172172

173173
disabledModeTests := []testCase{
174-
{"GET with any origin", "GET", "https://any-site.com", 202, true},
174+
{"GET with any origin", "GET", "https://any-site.com", 200, true},
175175
{"OPTIONS with any origin", "OPTIONS", "https://any-site.com", 200, true},
176176
}
177177

@@ -195,7 +195,7 @@ func runCORSTests(t *testing.T, mcpURL, mode, configuredOrigins string) {
195195
var sessionID string
196196
if tc.method != "OPTIONS" {
197197
// Only try to initialize if we expect it to succeed
198-
if tc.expectedStatus == 202 {
198+
if tc.expectedStatus == 200 {
199199
sessionID = initializeMCPSession(t, mcpURL, tc.origin)
200200
require.NotEmpty(t, sessionID, "Expected to get a session ID for allowed origin")
201201
} else {

pkg/resources/resource_templates.go

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@ import (
77
"context"
88
"fmt"
99
"net/http"
10+
"path"
1011

1112
"github.com/hashicorp/terraform-mcp-server/pkg/client"
1213
"github.com/hashicorp/terraform-mcp-server/pkg/utils"
@@ -16,10 +17,16 @@ import (
1617
)
1718

1819
func RegisterResourceTemplates(hcServer *server.MCPServer, logger *log.Logger) {
19-
hcServer.AddResourceTemplate(ProviderResourceTemplate(path.Join(utils.PROVIDER_BASE_PATH, "{namespace}", "name", "{name}", "version", "{version}"), utils.PROVIDER_BASE_PATH), "Provider details", logger))
20+
hcServer.AddResourceTemplate(
21+
providerResourceTemplate(
22+
path.Join(utils.PROVIDER_BASE_PATH, "{namespace}", "name", "{name}", "version", "{version}"),
23+
"Provider details",
24+
logger,
25+
),
26+
)
2027
}
2128

22-
func ProviderResourceTemplate(resourceURI string, description string, logger *log.Logger) (mcp.ResourceTemplate, server.ResourceTemplateHandlerFunc) {
29+
func providerResourceTemplate(resourceURI string, description string, logger *log.Logger) (mcp.ResourceTemplate, server.ResourceTemplateHandlerFunc) {
2330
return mcp.NewResourceTemplate(
2431
resourceURI,
2532
description,

pkg/tools/dynamic_tool.go

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -123,12 +123,11 @@ func (r *DynamicToolRegistry) createDynamicTFETool(toolName string, toolFactory
123123
terraformClient := client.GetTerraformClient(sessionID)
124124
if terraformClient == nil || terraformClient.TfeClient == nil {
125125
r.logger.WithFields(log.Fields{
126-
"tool": toolName,
126+
"tool": toolName,
127127
}).Warn("TFE tool called but session has no valid TFE client")
128128

129129
return mcp.NewToolResultError("This tool is not available. This tool requires a valid Terraform Cloud/Enterprise token and configuration. Please ensure TFE_TOKEN and TFE_ADDRESS environment variables are properly set."), nil
130130
}
131-
132131
// If we found a valid client that wasn't registered, register it now
133132
r.RegisterSessionWithTFE(sessionID)
134133
}

pkg/tools/get_provider_docs.go

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@ import (
77
"context"
88
"encoding/json"
99
"fmt"
10+
"path"
1011

1112
"github.com/hashicorp/terraform-mcp-server/pkg/client"
1213
"github.com/hashicorp/terraform-mcp-server/pkg/utils"

pkg/tools/list_terraform_orgs.go

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -22,14 +22,15 @@ func ListTerraformOrgs(logger *log.Logger) server.ServerTool {
2222
mcp.WithDescription(`Fetches a list of all Terraform organizations.`),
2323
mcp.WithTitleAnnotation("List all Terraform organizations"),
2424
mcp.WithReadOnlyHintAnnotation(true),
25+
utils.WithPagination(),
2526
),
2627
Handler: func(ctx context.Context, req mcp.CallToolRequest) (*mcp.CallToolResult, error) {
27-
return listTerraformOrgsHandler(ctx, logger)
28+
return listTerraformOrgsHandler(ctx, req, logger)
2829
},
2930
}
3031
}
3132

32-
func listTerraformOrgsHandler(ctx context.Context, logger *log.Logger) (*mcp.CallToolResult, error) {
33+
func listTerraformOrgsHandler(ctx context.Context, request mcp.CallToolRequest, logger *log.Logger) (*mcp.CallToolResult, error) {
3334
// Get a Terraform client from context
3435
terraformClients, err := client.GetTerraformClientFromContext(ctx, logger)
3536
if err != nil {
@@ -41,9 +42,15 @@ func listTerraformOrgsHandler(ctx context.Context, logger *log.Logger) (*mcp.Cal
4142
return nil, utils.LogAndReturnError(logger, "TFE client is not available - please ensure TFE_TOKEN and TFE_ADDRESS are properly configured", nil)
4243
}
4344

45+
pagination, err := utils.OptionalPaginationParams(request)
46+
if err != nil {
47+
return mcp.NewToolResultError(err.Error()), nil
48+
}
49+
4450
orgs, err := tfeClient.Organizations.List(ctx, &tfe.OrganizationListOptions{
4551
ListOptions: tfe.ListOptions{
46-
PageSize: 100,
52+
PageNumber: pagination.Page,
53+
PageSize: pagination.PageSize,
4754
},
4855
})
4956

pkg/tools/list_terraform_projects.go

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@ func ListTerraformProjects(logger *log.Logger) server.ServerTool {
2626
mcp.Required(),
2727
mcp.Description("The name of the Terraform organization to list projects for."),
2828
),
29+
utils.WithPagination(),
2930
),
3031
Handler: func(ctx context.Context, req mcp.CallToolRequest) (*mcp.CallToolResult, error) {
3132
return listTerraformProjectsHandler(ctx, req, logger)
@@ -42,6 +43,11 @@ func listTerraformProjectsHandler(ctx context.Context, request mcp.CallToolReque
4243
return nil, utils.LogAndReturnError(logger, "terraform_org_name cannot be empty", nil)
4344
}
4445

46+
pagination, err := utils.OptionalPaginationParams(request)
47+
if err != nil {
48+
return mcp.NewToolResultError(err.Error()), nil
49+
}
50+
4551
// Get a Terraform client from context
4652
terraformClients, err := client.GetTerraformClientFromContext(ctx, logger)
4753
if err != nil {
@@ -55,7 +61,8 @@ func listTerraformProjectsHandler(ctx context.Context, request mcp.CallToolReque
5561
// Fetch the list of projects
5662
projects, err := tfeClient.Projects.List(ctx, terraformOrgName, &tfe.ProjectListOptions{
5763
ListOptions: tfe.ListOptions{
58-
PageSize: 100,
64+
PageNumber: pagination.Page,
65+
PageSize: pagination.PageSize,
5966
},
6067
})
6168

pkg/tools/policy_details.go

Lines changed: 45 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,9 @@ import (
77
"context"
88
"encoding/json"
99
"fmt"
10+
"net/url"
1011
"strings"
12+
"text/template"
1113

1214
"github.com/hashicorp/terraform-mcp-server/pkg/client"
1315
"github.com/hashicorp/terraform-mcp-server/pkg/utils"
@@ -69,11 +71,28 @@ func getPolicyDetailsHandler(ctx context.Context, request mcp.CallToolRequest, l
6971
moduleList := ""
7072
for _, policy := range policyDetails.Included {
7173
if policy.Type == "policy-modules" {
72-
moduleList += fmt.Sprintf(`
73-
module "%s" {
74-
source = "https://registry.terraform.io/v2%s/policy-module/%s.sentinel?checksum=sha256:%s"
74+
// Use text/template to safely build the module block
75+
var moduleBuilder strings.Builder
76+
tmpl := `
77+
module "{{.Name}}" {
78+
source = "https://registry.terraform.io/v2{{.PolicyID}}/policy-module/{{.Name}}.sentinel?checksum=sha256:{{.Shasum}}"
7579
}
76-
`, policy.Attributes.Name, terraformPolicyID, policy.Attributes.Name, policy.Attributes.Shasum)
80+
`
81+
type moduleData struct {
82+
Name string
83+
PolicyID string
84+
Shasum string
85+
}
86+
t := template.Must(template.New("module").Parse(tmpl))
87+
err := t.Execute(&moduleBuilder, moduleData{
88+
Name: policy.Attributes.Name,
89+
PolicyID: terraformPolicyID,
90+
Shasum: policy.Attributes.Shasum,
91+
})
92+
if err != nil {
93+
logger.WithError(err).Error("failed to render module template")
94+
}
95+
moduleList += moduleBuilder.String()
7796
}
7897

7998
if policy.Type == "policies" {
@@ -85,13 +104,30 @@ module "%s" {
85104
builder.WriteString("## Usage\n\n")
86105
builder.WriteString("Generate the content for a HashiCorp Configuration Language (HCL) file named policies.hcl. This file should define a set of policies. For each policy provided, create a distinct policy block using the following template.\n")
87106
builder.WriteString("\n```hcl\n")
88-
hclTemplate := fmt.Sprintf(`
89-
%s
107+
// Use text/template to safely build the HCL template for policies
108+
hclTmpl := `
109+
{{- if .ModuleList }}
110+
{{ .ModuleList }}
111+
{{- end }}
90112
policy "<<POLICY_NAME>>" {
91-
source = "https://registry.terraform.io/v2%s/policy/<<POLICY_NAME>>.sentinel?checksum=<<POLICY_CHECKSUM>>"
92-
enforcement_level = "advisory"
113+
source = "https://registry.terraform.io/v2{{ .TerraformPolicyID }}/policy/<<POLICY_NAME>>.sentinel?checksum=<<POLICY_CHECKSUM>>"
114+
enforcement_level = "advisory"
93115
}
94-
`, moduleList, terraformPolicyID)
116+
`
117+
type hclTemplateData struct {
118+
ModuleList string
119+
TerraformPolicyID string
120+
}
121+
var hclBuilder strings.Builder
122+
t := template.Must(template.New("hclPolicy").Parse(hclTmpl))
123+
err = t.Execute(&hclBuilder, hclTemplateData{
124+
ModuleList: moduleList,
125+
TerraformPolicyID: terraformPolicyID,
126+
})
127+
if err != nil {
128+
logger.WithError(err).Error("failed to render HCL policy template")
129+
}
130+
hclTemplate := hclBuilder.String()
95131
builder.WriteString(hclTemplate)
96132
builder.WriteString("\n```\n")
97133
builder.WriteString(fmt.Sprintf("Available policies with SHA for %s are: \n\n", terraformPolicyID))

pkg/tools/resolve_provider_doc_id.go

Lines changed: 12 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@ import (
88
"encoding/json"
99
"fmt"
1010
"net/http"
11+
"path"
1112
"strings"
1213

1314
"github.com/hashicorp/terraform-mcp-server/pkg/client"
@@ -126,7 +127,7 @@ func resolveProviderDocIDHandler(ctx context.Context, request mcp.CallToolReques
126127
cs_pn, err_pn := utils.ContainsSlug(fmt.Sprintf("%s_%s", providerDetail.ProviderName, doc.Slug), serviceSlug)
127128
if (cs || cs_pn) && err == nil && err_pn == nil {
128129
contentAvailable = true
129-
descriptionSnippet, err := getContentSnippet(registryClient, doc.ID, logger)
130+
descriptionSnippet, err := getContentSnippet(httpClient, doc.ID, logger)
130131
if err != nil {
131132
logger.Warnf("Error fetching content snippet for provider doc ID: %s: %v", doc.ID, err)
132133
}
@@ -143,7 +144,7 @@ func resolveProviderDocIDHandler(ctx context.Context, request mcp.CallToolReques
143144
return mcp.NewToolResultText(builder.String()), nil
144145
}
145146

146-
func resolveProviderDetails(request mcp.CallToolRequest, registryClient *http.Client, defaultErrorGuide string, logger *log.Logger) (client.ProviderDetail, error) {
147+
func resolveProviderDetails(request mcp.CallToolRequest, httpClient *http.Client, defaultErrorGuide string, logger *log.Logger) (client.ProviderDetail, error) {
147148
providerDetail := client.ProviderDetail{}
148149
providerName := request.GetString("provider_name", "")
149150
if providerName == "" {
@@ -164,7 +165,7 @@ func resolveProviderDetails(request mcp.CallToolRequest, registryClient *http.Cl
164165
if utils.IsValidProviderVersionFormat(providerVersion) {
165166
providerVersionValue = providerVersion
166167
} else {
167-
providerVersionValue, err = client.GetLatestProviderVersion(registryClient, providerNamespace, providerName, logger)
168+
providerVersionValue, err = client.GetLatestProviderVersion(httpClient, providerNamespace, providerName, logger)
168169
if err != nil {
169170
providerVersionValue = ""
170171
logger.Debugf("Error getting latest provider version in %s namespace: %v", providerNamespace, err)
@@ -174,7 +175,7 @@ func resolveProviderDetails(request mcp.CallToolRequest, registryClient *http.Cl
174175
// If the provider version doesn't exist, try the hashicorp namespace
175176
if providerVersionValue == "" {
176177
tryProviderNamespace := "hashicorp"
177-
providerVersionValue, err = client.GetLatestProviderVersion(registryClient, tryProviderNamespace, providerName, logger)
178+
providerVersionValue, err = client.GetLatestProviderVersion(httpClient, tryProviderNamespace, providerName, logger)
178179
if err != nil {
179180
// Just so we don't print the same namespace twice if they are the same
180181
if providerNamespace != tryProviderNamespace {
@@ -198,20 +199,20 @@ func resolveProviderDetails(request mcp.CallToolRequest, registryClient *http.Cl
198199
}
199200

200201
// get_provider_docsV2 retrieves a list of documentation items for a specific provider category using v2 API with support for pagination using page numbers
201-
func get_provider_docsV2(registryClient *http.Client, providerDetail client.ProviderDetail, logger *log.Logger) (string, error) {
202-
providerVersionID, err := client.GetProviderVersionID(registryClient, providerDetail.ProviderNamespace, providerDetail.ProviderName, providerDetail.ProviderVersion, logger)
202+
func get_provider_docsV2(httpClient *http.Client, providerDetail client.ProviderDetail, logger *log.Logger) (string, error) {
203+
providerVersionID, err := client.GetProviderVersionID(httpClient, providerDetail.ProviderNamespace, providerDetail.ProviderName, providerDetail.ProviderVersion, logger)
203204
if err != nil {
204205
return "", utils.LogAndReturnError(logger, "getting provider version ID", err)
205206
}
206207
category := providerDetail.ProviderDataType
207208
if category == "overview" {
208-
return client.GetProviderOverviewDocs(registryClient, providerVersionID, logger)
209+
return client.GetProviderOverviewDocs(httpClient, providerVersionID, logger)
209210
}
210211

211212
uriPrefix := fmt.Sprintf("provider-docs?filter[provider-version]=%s&filter[category]=%s&filter[language]=hcl",
212213
providerVersionID, category)
213214

214-
docs, err := client.SendPaginatedRegistryCall(registryClient, uriPrefix, logger)
215+
docs, err := client.SendPaginatedRegistryCall(httpClient, uriPrefix, logger)
215216
if err != nil {
216217
return "", utils.LogAndReturnError(logger, "getting provider documentation", err)
217218
}
@@ -225,7 +226,7 @@ func get_provider_docsV2(registryClient *http.Client, providerDetail client.Prov
225226
builder.WriteString("Each result includes:\n- providerDocID: tfprovider-compatible identifier\n- Title: Service or resource name\n- Category: Type of document\n- Description: Brief summary of the document\n")
226227
builder.WriteString("For best results, select libraries based on the service_slug match and category of information requested.\n\n---\n\n")
227228
for _, doc := range docs {
228-
descriptionSnippet, err := getContentSnippet(registryClient, doc.ID, logger)
229+
descriptionSnippet, err := getContentSnippet(httpClient, doc.ID, logger)
229230
if err != nil {
230231
logger.Warnf("Error fetching content snippet for provider doc ID: %s: %v", doc.ID, err)
231232
}
@@ -235,8 +236,8 @@ func get_provider_docsV2(registryClient *http.Client, providerDetail client.Prov
235236
return builder.String(), nil
236237
}
237238

238-
func getContentSnippet(registryClient *http.Client, docID string, logger *log.Logger) (string, error) {
239-
docContent, err := client.SendRegistryCall(registryClient, "GET", fmt.Sprintf("provider-docs/%s", docID), logger, "v2")
239+
func getContentSnippet(httpClient *http.Client, docID string, logger *log.Logger) (string, error) {
240+
docContent, err := client.SendRegistryCall(httpClient, "GET", fmt.Sprintf("provider-docs/%s", docID), logger, "v2")
240241
if err != nil {
241242
return "", utils.LogAndReturnError(logger, fmt.Sprintf("error fetching provider-docs/%s within getContentSnippet", docID), err)
242243
}

pkg/tools/search_policies.go

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@ import (
77
"context"
88
"encoding/json"
99
"fmt"
10+
"net/url"
1011
"strings"
1112

1213
"github.com/hashicorp/terraform-mcp-server/pkg/client"
@@ -62,8 +63,14 @@ func getSearchPoliciesHandler(ctx context.Context, request mcp.CallToolRequest,
6263
}
6364

6465
httpClient := terraformClients.HttpClient
65-
// static list of 100 is fine for now
66-
policyResp, err := client.SendRegistryCall(httpClient, "GET", (&url.URL{Path: "policies", RawQuery: url.Values{"page[size]": {"100"}, "include": {"latest-version"}}.Encode()}).String(), logger, "v2")
66+
uri := (&url.URL{
67+
Path: "policies",
68+
RawQuery: url.Values{
69+
"page[size]": {"100"}, // static list of 100 is fine for now
70+
"include": {"latest-version"},
71+
}.Encode(),
72+
}).String()
73+
policyResp, err := client.SendRegistryCall(httpClient, "GET", uri, logger, "v2")
6774
if err != nil {
6875
return nil, utils.LogAndReturnError(logger, "Failed to fetch policies: registry API did not return a successful response", err)
6976
}

0 commit comments

Comments
 (0)