Skip to content

Commit 6da5cd1

Browse files
authored
feat: allow to set a custom logger in the SSE and STDIO clients (#525)
* feat: allow to set a custom logger in the SSE client So it logs to a logger instead of stdout. Signed-off-by: Carlos Alexandro Becker <[email protected]> * fix: tests, docs, stdio logger * chore: lint --------- Signed-off-by: Carlos Alexandro Becker <[email protected]>
1 parent 9259d32 commit 6da5cd1

File tree

8 files changed

+314
-98
lines changed

8 files changed

+314
-98
lines changed

client/sse.go

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -23,12 +23,10 @@ func WithHTTPClient(httpClient *http.Client) transport.ClientOption {
2323
// NewSSEMCPClient creates a new SSE-based MCP client with the given base URL.
2424
// Returns an error if the URL is invalid.
2525
func NewSSEMCPClient(baseURL string, options ...transport.ClientOption) (*Client, error) {
26-
2726
sseTransport, err := transport.NewSSE(baseURL, options...)
2827
if err != nil {
2928
return nil, fmt.Errorf("failed to create SSE transport: %w", err)
3029
}
31-
3230
return NewClient(sseTransport), nil
3331
}
3432

client/transport/sse.go

Lines changed: 14 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@ import (
1616
"time"
1717

1818
"github.com/mark3labs/mcp-go/mcp"
19+
"github.com/mark3labs/mcp-go/util"
1920
)
2021

2122
// SSE implements the transport layer of the MCP protocol using Server-Sent Events (SSE).
@@ -33,6 +34,7 @@ type SSE struct {
3334
endpointChan chan struct{}
3435
headers map[string]string
3536
headerFunc HTTPHeaderFunc
37+
logger util.Logger
3638

3739
started atomic.Bool
3840
closed atomic.Bool
@@ -47,6 +49,13 @@ type SSE struct {
4749

4850
type ClientOption func(*SSE)
4951

52+
// WithSSELogger sets a custom logger for the SSE client.
53+
func WithSSELogger(logger util.Logger) ClientOption {
54+
return func(sc *SSE) {
55+
sc.logger = logger
56+
}
57+
}
58+
5059
func WithHeaders(headers map[string]string) ClientOption {
5160
return func(sc *SSE) {
5261
sc.headers = headers
@@ -85,6 +94,7 @@ func NewSSE(baseURL string, options ...ClientOption) (*SSE, error) {
8594
responses: make(map[string]chan *JSONRPCResponse),
8695
endpointChan: make(chan struct{}),
8796
headers: make(map[string]string),
97+
logger: util.DefaultLogger(),
8898
}
8999

90100
for _, opt := range options {
@@ -104,7 +114,6 @@ func NewSSE(baseURL string, options ...ClientOption) (*SSE, error) {
104114
// Start initiates the SSE connection to the server and waits for the endpoint information.
105115
// Returns an error if the connection fails or times out waiting for the endpoint.
106116
func (c *SSE) Start(ctx context.Context) error {
107-
108117
if c.started.Load() {
109118
return fmt.Errorf("has already started")
110119
}
@@ -113,7 +122,6 @@ func (c *SSE) Start(ctx context.Context) error {
113122
c.cancelSSEStream = cancel
114123

115124
req, err := http.NewRequestWithContext(ctx, "GET", c.baseURL.String(), nil)
116-
117125
if err != nil {
118126
return fmt.Errorf("failed to create request: %w", err)
119127
}
@@ -220,7 +228,7 @@ func (c *SSE) readSSE(reader io.ReadCloser) {
220228
}
221229
}
222230
if !c.closed.Load() {
223-
fmt.Printf("SSE stream error: %v\n", err)
231+
c.logger.Errorf("SSE stream error: %v", err)
224232
}
225233
return
226234
}
@@ -256,11 +264,11 @@ func (c *SSE) handleSSEEvent(event, data string) {
256264
case "endpoint":
257265
endpoint, err := c.baseURL.Parse(data)
258266
if err != nil {
259-
fmt.Printf("Error parsing endpoint URL: %v\n", err)
267+
c.logger.Errorf("Error parsing endpoint URL: %v", err)
260268
return
261269
}
262270
if endpoint.Host != c.baseURL.Host {
263-
fmt.Printf("Endpoint origin does not match connection origin\n")
271+
c.logger.Errorf("Endpoint origin does not match connection origin")
264272
return
265273
}
266274
c.endpoint = endpoint
@@ -269,7 +277,7 @@ func (c *SSE) handleSSEEvent(event, data string) {
269277
case "message":
270278
var baseMessage JSONRPCResponse
271279
if err := json.Unmarshal([]byte(data), &baseMessage); err != nil {
272-
fmt.Printf("Error unmarshaling message: %v\n", err)
280+
c.logger.Errorf("Error unmarshaling message: %v", err)
273281
return
274282
}
275283

@@ -321,7 +329,6 @@ func (c *SSE) SendRequest(
321329
ctx context.Context,
322330
request JSONRPCRequest,
323331
) (*JSONRPCResponse, error) {
324-
325332
if !c.started.Load() {
326333
return nil, fmt.Errorf("transport not started yet")
327334
}

0 commit comments

Comments
 (0)