Skip to content

Commit f7c0746

Browse files
committed
Address PR #176 review comments
- Set ClusterAware to false for probe_dns_local - Improve robust IPv6-aware address port handling - Use api.NewToolCallResultStructured for structured tool output
1 parent 804593c commit f7c0746

File tree

2 files changed

+61
-23
lines changed

2 files changed

+61
-23
lines changed

pkg/toolsets/netedge/probe_dns_local.go

Lines changed: 12 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -45,7 +45,8 @@ func initProbeDNSLocal() []api.ServerTool {
4545
OpenWorldHint: ptr.To(true),
4646
},
4747
},
48-
Handler: probeDNSLocalHandler,
48+
ClusterAware: ptr.To(false),
49+
Handler: probeDNSLocalHandler,
4950
},
5051
}
5152
}
@@ -77,12 +78,12 @@ type DNSResult struct {
7778
func probeDNSLocalHandler(params api.ToolHandlerParams) (*api.ToolCallResult, error) {
7879
serverParam, ok := params.GetArguments()["server"].(string)
7980
if !ok || serverParam == "" {
80-
return api.NewToolCallResult("", fmt.Errorf("server parameter is required")), nil
81+
return api.NewToolCallResultStructured(nil, fmt.Errorf("server parameter is required")), nil
8182
}
8283

8384
nameParam, ok := params.GetArguments()["name"].(string)
8485
if !ok || nameParam == "" {
85-
return api.NewToolCallResult("", fmt.Errorf("name parameter is required")), nil
86+
return api.NewToolCallResultStructured(nil, fmt.Errorf("name parameter is required")), nil
8687
}
8788

8889
typeParam, ok := params.GetArguments()["type"].(string)
@@ -95,16 +96,18 @@ func probeDNSLocalHandler(params api.ToolHandlerParams) (*api.ToolCallResult, er
9596

9697
// Ensure server parameter has a port
9798
if _, _, err := net.SplitHostPort(serverParam); err != nil {
98-
if strings.Contains(err.Error(), "missing port in address") {
99-
serverParam = net.JoinHostPort(serverParam, "53")
99+
// If port is missing, try adding default DNS port 53
100+
appended := net.JoinHostPort(serverParam, "53")
101+
if _, _, err2 := net.SplitHostPort(appended); err2 == nil {
102+
serverParam = appended
100103
} else {
101-
return api.NewToolCallResult("", fmt.Errorf("invalid server address format: %w", err)), nil
104+
return api.NewToolCallResultStructured(nil, fmt.Errorf("invalid server address format: %w", err)), nil
102105
}
103106
}
104107

105108
recordType, ok := dns.StringToType[strings.ToUpper(typeParam)]
106109
if !ok {
107-
return api.NewToolCallResult("", fmt.Errorf("invalid or unsupported DNS record type: %s", typeParam)), nil
110+
return api.NewToolCallResultStructured(nil, fmt.Errorf("invalid or unsupported DNS record type: %s", typeParam)), nil
108111
}
109112

110113
msg := new(dns.Msg)
@@ -115,7 +118,7 @@ func probeDNSLocalHandler(params api.ToolHandlerParams) (*api.ToolCallResult, er
115118

116119
if err != nil {
117120
// Log network level errors directly to the tool output so agent can interpret it
118-
return api.NewToolCallResult("", fmt.Errorf("DNS query failed: %w", err)), nil
121+
return api.NewToolCallResultStructured(nil, fmt.Errorf("DNS query failed: %w", err)), nil
119122
}
120123

121124
result := DNSResult{
@@ -130,10 +133,5 @@ func probeDNSLocalHandler(params api.ToolHandlerParams) (*api.ToolCallResult, er
130133
result.Answers = append(result.Answers, ans)
131134
}
132135

133-
jsonData, err := json.MarshalIndent(result, "", " ")
134-
if err != nil {
135-
return api.NewToolCallResult("", fmt.Errorf("failed to marshal DNS result: %w", err)), nil
136-
}
137-
138-
return api.NewToolCallResult(string(jsonData), nil), nil
136+
return api.NewToolCallResultStructured(result, nil), nil
139137
}

pkg/toolsets/netedge/probe_dns_local_test.go

Lines changed: 49 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -14,12 +14,14 @@ import (
1414
)
1515

1616
type mockDNSClient struct {
17-
msg *dns.Msg
18-
rtt time.Duration
19-
err error
17+
msg *dns.Msg
18+
rtt time.Duration
19+
err error
20+
lastServer string
2021
}
2122

2223
func (m *mockDNSClient) Exchange(msg *dns.Msg, server string) (*dns.Msg, time.Duration, error) {
24+
m.lastServer = server
2325
return m.msg, m.rtt, m.err
2426
}
2527

@@ -36,7 +38,7 @@ func TestProbeDNSLocalHandler(t *testing.T) {
3638
args map[string]interface{}
3739
mockClient *mockDNSClient
3840
expectedError string
39-
validate func(t *testing.T, result string)
41+
validate func(t *testing.T, result *api.ToolCallResult)
4042
}{
4143
{
4244
name: "success query A record",
@@ -50,14 +52,19 @@ func TestProbeDNSLocalHandler(t *testing.T) {
5052
rtt: 10 * time.Millisecond,
5153
err: nil,
5254
},
53-
validate: func(t *testing.T, content string) {
55+
validate: func(t *testing.T, result *api.ToolCallResult) {
5456
var res DNSResult
55-
err := json.Unmarshal([]byte(content), &res)
57+
err := json.Unmarshal([]byte(result.Content), &res)
5658
require.NoError(t, err)
5759
assert.Equal(t, "NOERROR", res.Rcode)
5860
assert.Equal(t, int64(10), res.LatencyMS)
5961
assert.Len(t, res.Answers, 1)
6062
assert.Contains(t, res.Answers[0], "93.184.216.34")
63+
64+
// Also check structured content
65+
structured, ok := result.StructuredContent.(DNSResult)
66+
require.True(t, ok)
67+
assert.Equal(t, "NOERROR", structured.Rcode)
6168
},
6269
},
6370
{
@@ -108,11 +115,44 @@ func TestProbeDNSLocalHandler(t *testing.T) {
108115
rtt: 5 * time.Millisecond,
109116
err: nil,
110117
},
111-
validate: func(t *testing.T, content string) {
118+
validate: func(t *testing.T, result *api.ToolCallResult) {
112119
var res DNSResult
113-
err := json.Unmarshal([]byte(content), &res)
120+
err := json.Unmarshal([]byte(result.Content), &res)
114121
require.NoError(t, err)
115122
assert.Equal(t, "NOERROR", res.Rcode)
123+
124+
// Also check structured content
125+
structured, ok := result.StructuredContent.(DNSResult)
126+
require.True(t, ok)
127+
assert.Equal(t, "NOERROR", structured.Rcode)
128+
},
129+
},
130+
{
131+
name: "ipv4 address appends default port",
132+
args: map[string]interface{}{
133+
"server": "1.1.1.1",
134+
"name": "example.com",
135+
},
136+
mockClient: &mockDNSClient{
137+
msg: successMsg,
138+
rtt: 5 * time.Millisecond,
139+
},
140+
validate: func(t *testing.T, result *api.ToolCallResult) {
141+
assert.Equal(t, "1.1.1.1:53", activeDNSClient.(*mockDNSClient).lastServer)
142+
},
143+
},
144+
{
145+
name: "ipv6 address appends default port",
146+
args: map[string]interface{}{
147+
"server": "2001:4860:4860::8888",
148+
"name": "example.com",
149+
},
150+
mockClient: &mockDNSClient{
151+
msg: successMsg,
152+
rtt: 5 * time.Millisecond,
153+
},
154+
validate: func(t *testing.T, result *api.ToolCallResult) {
155+
assert.Equal(t, "[2001:4860:4860::8888]:53", activeDNSClient.(*mockDNSClient).lastServer)
116156
},
117157
},
118158
}
@@ -149,7 +189,7 @@ func TestProbeDNSLocalHandler(t *testing.T) {
149189
require.NotNil(t, result)
150190
require.NoError(t, result.Error)
151191
if tt.validate != nil {
152-
tt.validate(t, result.Content)
192+
tt.validate(t, result)
153193
}
154194
}
155195
})

0 commit comments

Comments
 (0)