Skip to content

Commit 9393526

Browse files
ezynda3 and opencode authored
fix: resolve stdio transport race condition for concurrent tool calls (#529)
* fix: resolve stdio transport race condition for concurrent tool calls

Fixes #528 by replacing the goroutine-per-request pattern with a worker pool to prevent multiple goroutines from accessing the non-thread-safe bufio.Reader.

- Add worker pool pattern with configurable size (default: 5 workers)
- Add buffered queue for tool calls (default: 100 capacity)
- Protect concurrent writes with mutex
- Maintain backward compatibility while fixing the race condition
- Add comprehensive test for concurrent tool calls

🤖 Generated with [opencode](https://opencode.ai)

Co-Authored-By: opencode <[email protected]>

* fix: address golangci-lint errcheck warnings

- Add error checking for Write operations in the concurrent tool calls test
- Ensure all error return values are properly handled

🤖 Generated with [opencode](https://opencode.ai)

Co-Authored-By: opencode <[email protected]>

* refactor: remove redundant mutex around thread-safe sync.Map

sync.Map is already thread-safe and doesn't require additional mutex protection. This simplifies the test code while maintaining the same functionality.

🤖 Generated with [opencode](https://opencode.ai)

Co-Authored-By: opencode <[email protected]>

* refactor: improve test control flow and add configuration bounds validation

- Replace goto statements with labeled break for more idiomatic Go
- Reduce test timeout from 5s to 2s for faster test execution
- Add upper-bounds validation for worker pool size (max 100) and queue size (max 10000)
- Add tests to verify configuration bounds are respected
- Log warnings when configuration values exceed maximum limits

Addresses remaining CodeRabbit review feedback.

🤖 Generated with [opencode](https://opencode.ai)

Co-Authored-By: opencode <[email protected]>

* fix: use maximum values instead of defaults when limits exceeded

When WithWorkerPoolSize or WithQueueSize receive values exceeding their maximum limits, use the maximum allowed value instead of falling back to the default. This better aligns with user expectations when they request large pool/queue sizes.

🤖 Generated with [opencode](https://opencode.ai)

Co-Authored-By: opencode <[email protected]>

* test: update stdio tests to expect maximum values when limits exceeded

Update test expectations to match the new behavior where WithWorkerPoolSize and WithQueueSize use maximum values instead of defaults when the requested size exceeds limits.

🤖 Generated with [opencode](https://opencode.ai)

Co-Authored-By: opencode <[email protected]>

--------

Co-authored-by: opencode <[email protected]>
1 parent 6da5cd1 commit 9393526

File tree

2 files changed

+320
-8
lines changed

2 files changed

+320
-8
lines changed

server/stdio.go

Lines changed: 104 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,20 @@ type StdioServer struct {
2929
server *MCPServer
3030
errLogger *log.Logger
3131
contextFunc StdioContextFunc
32+
33+
// Thread-safe tool call processing
34+
toolCallQueue chan *toolCallWork
35+
workerWg sync.WaitGroup
36+
workerPoolSize int
37+
queueSize int
38+
writeMu sync.Mutex // Protects concurrent writes
39+
}
40+
41+
// toolCallWork represents a queued tool call request
42+
type toolCallWork struct {
43+
ctx context.Context
44+
message json.RawMessage
45+
writer io.Writer
3246
}
3347

3448
// StdioOption defines a function type for configuring StdioServer
@@ -50,6 +64,32 @@ func WithStdioContextFunc(fn StdioContextFunc) StdioOption {
5064
}
5165
}
5266

67+
// WithWorkerPoolSize sets the number of workers for processing tool calls
68+
func WithWorkerPoolSize(size int) StdioOption {
69+
return func(s *StdioServer) {
70+
const maxWorkerPoolSize = 100
71+
if size > 0 && size <= maxWorkerPoolSize {
72+
s.workerPoolSize = size
73+
} else if size > maxWorkerPoolSize {
74+
s.errLogger.Printf("Worker pool size %d exceeds maximum (%d), using maximum", size, maxWorkerPoolSize)
75+
s.workerPoolSize = maxWorkerPoolSize
76+
}
77+
}
78+
}
79+
80+
// WithQueueSize sets the size of the tool call queue
81+
func WithQueueSize(size int) StdioOption {
82+
return func(s *StdioServer) {
83+
const maxQueueSize = 10000
84+
if size > 0 && size <= maxQueueSize {
85+
s.queueSize = size
86+
} else if size > maxQueueSize {
87+
s.errLogger.Printf("Queue size %d exceeds maximum (%d), using maximum", size, maxQueueSize)
88+
s.queueSize = maxQueueSize
89+
}
90+
}
91+
}
92+
5393
// stdioSession is a static client session, since stdio has only one client.
5494
type stdioSession struct {
5595
notifications chan mcp.JSONRPCNotification
@@ -218,6 +258,8 @@ func NewStdioServer(server *MCPServer) *StdioServer {
218258
"",
219259
log.LstdFlags,
220260
), // Default to discarding logs
261+
workerPoolSize: 5, // Default worker pool size
262+
queueSize: 100, // Default queue size
221263
}
222264
}
223265

@@ -281,6 +323,30 @@ func (s *StdioServer) processInputStream(ctx context.Context, reader *bufio.Read
281323
}
282324
}
283325

326+
// toolCallWorker processes tool calls from the queue
327+
func (s *StdioServer) toolCallWorker(ctx context.Context) {
328+
defer s.workerWg.Done()
329+
330+
for {
331+
select {
332+
case work, ok := <-s.toolCallQueue:
333+
if !ok {
334+
// Channel closed, exit worker
335+
return
336+
}
337+
// Process the tool call
338+
response := s.server.HandleMessage(work.ctx, work.message)
339+
if response != nil {
340+
if err := s.writeResponse(response, work.writer); err != nil {
341+
s.errLogger.Printf("Error writing tool response: %v", err)
342+
}
343+
}
344+
case <-ctx.Done():
345+
return
346+
}
347+
}
348+
}
349+
284350
// readNextLine reads a single line from the input reader in a context-aware manner.
285351
// It uses channels to make the read operation cancellable via context.
286352
// Returns the read line and any error encountered. If the context is cancelled,
@@ -315,6 +381,9 @@ func (s *StdioServer) Listen(
315381
stdin io.Reader,
316382
stdout io.Writer,
317383
) error {
384+
// Initialize the tool call queue
385+
s.toolCallQueue = make(chan *toolCallWork, s.queueSize)
386+
318387
// Set a static client context since stdio only has one client
319388
if err := s.server.RegisterSession(ctx, &stdioSessionInstance); err != nil {
320389
return fmt.Errorf("register session: %w", err)
@@ -332,9 +401,23 @@ func (s *StdioServer) Listen(
332401

333402
reader := bufio.NewReader(stdin)
334403

404+
// Start worker pool for tool calls
405+
for i := 0; i < s.workerPoolSize; i++ {
406+
s.workerWg.Add(1)
407+
go s.toolCallWorker(ctx)
408+
}
409+
335410
// Start notification handler
336411
go s.handleNotifications(ctx, stdout)
337-
return s.processInputStream(ctx, reader, stdout)
412+
413+
// Process input stream
414+
err := s.processInputStream(ctx, reader, stdout)
415+
416+
// Shutdown workers gracefully
417+
close(s.toolCallQueue)
418+
s.workerWg.Wait()
419+
420+
return err
338421
}
339422

340423
// processMessage handles a single JSON-RPC message and writes the response.
@@ -367,16 +450,25 @@ func (s *StdioServer) processMessage(
367450
Method string `json:"method"`
368451
}
369452
if json.Unmarshal(rawMessage, &baseMessage) == nil && baseMessage.Method == "tools/call" {
370-
// Process tool calls concurrently to avoid blocking on sampling requests
371-
go func() {
453+
// Queue tool calls for processing by workers
454+
select {
455+
case s.toolCallQueue <- &toolCallWork{
456+
ctx: ctx,
457+
message: rawMessage,
458+
writer: writer,
459+
}:
460+
return nil
461+
case <-ctx.Done():
462+
return ctx.Err()
463+
default:
464+
// Queue is full, process synchronously as fallback
465+
s.errLogger.Printf("Tool call queue full, processing synchronously")
372466
response := s.server.HandleMessage(ctx, rawMessage)
373467
if response != nil {
374-
if err := s.writeResponse(response, writer); err != nil {
375-
s.errLogger.Printf("Error writing tool response: %v", err)
376-
}
468+
return s.writeResponse(response, writer)
377469
}
378-
}()
379-
return nil
470+
return nil
471+
}
380472
}
381473

382474
// Handle other messages synchronously
@@ -462,6 +554,10 @@ func (s *StdioServer) writeResponse(
462554
return err
463555
}
464556

557+
// Protect concurrent writes
558+
s.writeMu.Lock()
559+
defer s.writeMu.Unlock()
560+
465561
// Write response followed by newline
466562
if _, err := fmt.Fprintf(writer, "%s\n", responseBytes); err != nil {
467563
return err

0 commit comments

Comments (0)