@@ -16,6 +16,7 @@ import (
16
16
"time"
17
17
18
18
"github.com/mark3labs/mcp-go/mcp"
19
+ "github.com/mark3labs/mcp-go/util"
19
20
)
20
21
21
22
// SSE implements the transport layer of the MCP protocol using Server-Sent Events (SSE).
@@ -33,6 +34,7 @@ type SSE struct {
33
34
endpointChan chan struct {}
34
35
headers map [string ]string
35
36
headerFunc HTTPHeaderFunc
37
+ logger util.Logger
36
38
37
39
started atomic.Bool
38
40
closed atomic.Bool
@@ -47,6 +49,13 @@ type SSE struct {
47
49
48
50
type ClientOption func (* SSE )
49
51
52
+ // WithSSELogger sets a custom logger for the SSE client.
53
+ func WithSSELogger (logger util.Logger ) ClientOption {
54
+ return func (sc * SSE ) {
55
+ sc .logger = logger
56
+ }
57
+ }
58
+
50
59
func WithHeaders (headers map [string ]string ) ClientOption {
51
60
return func (sc * SSE ) {
52
61
sc .headers = headers
@@ -85,6 +94,7 @@ func NewSSE(baseURL string, options ...ClientOption) (*SSE, error) {
85
94
responses : make (map [string ]chan * JSONRPCResponse ),
86
95
endpointChan : make (chan struct {}),
87
96
headers : make (map [string ]string ),
97
+ logger : util .DefaultLogger (),
88
98
}
89
99
90
100
for _ , opt := range options {
@@ -104,7 +114,6 @@ func NewSSE(baseURL string, options ...ClientOption) (*SSE, error) {
104
114
// Start initiates the SSE connection to the server and waits for the endpoint information.
105
115
// Returns an error if the connection fails or times out waiting for the endpoint.
106
116
func (c * SSE ) Start (ctx context.Context ) error {
107
-
108
117
if c .started .Load () {
109
118
return fmt .Errorf ("has already started" )
110
119
}
@@ -113,7 +122,6 @@ func (c *SSE) Start(ctx context.Context) error {
113
122
c .cancelSSEStream = cancel
114
123
115
124
req , err := http .NewRequestWithContext (ctx , "GET" , c .baseURL .String (), nil )
116
-
117
125
if err != nil {
118
126
return fmt .Errorf ("failed to create request: %w" , err )
119
127
}
@@ -220,7 +228,7 @@ func (c *SSE) readSSE(reader io.ReadCloser) {
220
228
}
221
229
}
222
230
if ! c .closed .Load () {
223
- fmt . Printf ("SSE stream error: %v\n " , err )
231
+ c . logger . Errorf ("SSE stream error: %v" , err )
224
232
}
225
233
return
226
234
}
@@ -256,11 +264,11 @@ func (c *SSE) handleSSEEvent(event, data string) {
256
264
case "endpoint" :
257
265
endpoint , err := c .baseURL .Parse (data )
258
266
if err != nil {
259
- fmt . Printf ("Error parsing endpoint URL: %v\n " , err )
267
+ c . logger . Errorf ("Error parsing endpoint URL: %v" , err )
260
268
return
261
269
}
262
270
if endpoint .Host != c .baseURL .Host {
263
- fmt . Printf ("Endpoint origin does not match connection origin\n " )
271
+ c . logger . Errorf ("Endpoint origin does not match connection origin" )
264
272
return
265
273
}
266
274
c .endpoint = endpoint
@@ -269,7 +277,7 @@ func (c *SSE) handleSSEEvent(event, data string) {
269
277
case "message" :
270
278
var baseMessage JSONRPCResponse
271
279
if err := json .Unmarshal ([]byte (data ), & baseMessage ); err != nil {
272
- fmt . Printf ("Error unmarshaling message: %v\n " , err )
280
+ c . logger . Errorf ("Error unmarshaling message: %v" , err )
273
281
return
274
282
}
275
283
@@ -321,7 +329,6 @@ func (c *SSE) SendRequest(
321
329
ctx context.Context ,
322
330
request JSONRPCRequest ,
323
331
) (* JSONRPCResponse , error ) {
324
-
325
332
if ! c .started .Load () {
326
333
return nil , fmt .Errorf ("transport not started yet" )
327
334
}
0 commit comments