diff --git a/client/transport/stdio.go b/client/transport/stdio.go index 74ce81046..0fcd8f69c 100644 --- a/client/transport/stdio.go +++ b/client/transport/stdio.go @@ -123,6 +123,8 @@ func (c *Stdio) Start(ctx context.Context) error { // If an (optional) cmdFunc custom command factory function was configured, it will be used to construct the subprocess; // otherwise, the default behavior uses exec.CommandContext with the merged environment. // Initializes stdin, stdout, and stderr pipes for JSON-RPC communication. +// A background goroutine is also started to wait for the subprocess to exit, +// ensuring that the done channel is closed automatically if the process terminates unexpectedly. func (c *Stdio) spawnCommand(ctx context.Context) error { if c.command == "" { return nil @@ -163,6 +165,16 @@ func (c *Stdio) spawnCommand(ctx context.Context) error { return fmt.Errorf("failed to start command: %w", err) } + go func() { + _ = cmd.Wait() + select { + case <-c.done: + // Already closed explicitly (via Close), do nothing + default: + close(c.done) // Automatically signal subprocess exit + } + }() + return nil } @@ -191,6 +203,16 @@ func (c *Stdio) Close() error { return nil } +// IsClosed reports whether the subprocess has exited and the transport is no longer usable. +func (c *Stdio) IsClosed() bool { + select { + case <-c.done: + return true + default: + return false + } +} + // GetSessionId returns the session ID of the transport. // Since stdio does not maintain a session ID, it returns an empty string. func (c *Stdio) GetSessionId() string { @@ -293,6 +315,11 @@ func (c *Stdio) SendRequest( c.mu.Unlock() } + if c.IsClosed() { + deleteResponseChan() + return nil, fmt.Errorf("cannot send request: subprocess is closed") + } + // Send request if _, err := c.stdin.Write(requestBytes); err != nil { deleteResponseChan() @@ -303,6 +330,9 @@ func (c *Stdio) SendRequest( case <-ctx.Done(): deleteResponseChan() return nil, ctx.Err() + case <-c.done: + deleteResponseChan() + return nil, fmt.Errorf("subprocess exited while waiting for response") case response := <-responseChan: return response, nil } diff --git a/client/transport/stdio_test.go b/client/transport/stdio_test.go index 0b312ace3..fded435ac 100644 --- a/client/transport/stdio_test.go +++ b/client/transport/stdio_test.go @@ -384,6 +384,31 @@ func TestStdio(t *testing.T) { t.Errorf("Expected array with 3 items, got %v", result.Params["array"]) } }) + + t.Run("SendRequestFailsIfSubprocessExited", func(t *testing.T) { + // Start a subprocess that exits immediately + ctx := context.Background() + stdio := NewStdio("sh", nil, "-c", "exit 0") + + err := stdio.Start(ctx) + require.NoError(t, err) + + // Wait for subprocess to exit + require.Eventually(t, func() bool { + return stdio.IsClosed() + }, time.Second, 10*time.Millisecond) + + // Try to send a request + _, err = stdio.SendRequest(ctx, JSONRPCRequest{ + JSONRPC: "2.0", + ID: mcp.NewRequestId("dead"), + Method: "noop", + }) + + require.Error(t, err) + require.Contains(t, err.Error(), "subprocess") + }) + } func TestStdioErrors(t *testing.T) { @@ -609,6 +634,32 @@ func TestStdio_SpawnCommand_UsesCommandFunc_Error(t *testing.T) { require.EqualError(t, err, "test error") } +func TestStdio_DoneClosedWhenSubcommandExits(t *testing.T) { + ctx := context.Background() + + stdio := NewStdioWithOptions( + "sh", + nil, + []string{"-c", "exit 0"}, + ) + + require.NotNil(t, stdio) + + err := stdio.spawnCommand(ctx) + require.NoError(t, err) + + t.Cleanup(func() { + if stdio.cmd.Process != nil { + _ = stdio.cmd.Process.Kill() + } + }) + + // Wait up to 200ms for the done channel to close + require.Eventually(t, func() bool { + return stdio.IsClosed() + }, 200*time.Millisecond, 10*time.Millisecond, "expected done to be closed after subprocess exited") +} + func TestStdio_NewStdioWithOptions_AppliesOptions(t *testing.T) { configured := false @@ -620,3 +671,28 @@ func TestStdio_NewStdioWithOptions_AppliesOptions(t *testing.T) { require.NotNil(t, stdio) require.True(t, configured, "option was not applied") } + +func TestStdio_IsClosed(t *testing.T) { + t.Run("returns false before Start", func(t *testing.T) { + stdio := NewStdio("sh", nil, "-c", "sleep 1") + require.False(t, stdio.IsClosed(), "expected IsClosed to be false before Start") + }) + + t.Run("returns false after Start", func(t *testing.T) { + stdio := NewStdio("sh", nil, "-c", "sleep 1") + err := stdio.Start(context.Background()) + require.NoError(t, err) + defer stdio.Close() + require.False(t, stdio.IsClosed(), "expected IsClosed to be false right after Start") + }) + + t.Run("returns true after subprocess exits", func(t *testing.T) { + stdio := NewStdio("sh", nil, "-c", "exit 0") + err := stdio.Start(context.Background()) + require.NoError(t, err) + + require.Eventually(t, func() bool { + return stdio.IsClosed() + }, 200*time.Millisecond, 10*time.Millisecond, "expected IsClosed to return true after subprocess exits") + }) +}