|
| 1 | +import asyncio |
1 | 2 | import json |
2 | 3 | import re |
3 | 4 | from typing import Any, Optional |
@@ -55,32 +56,66 @@ async def create_mcp_clients( |
55 | 56 | mnemonic: Optional[str] = None, |
56 | 57 | ) -> list[MCPClient]: |
57 | 58 | mcp_clients: list[MCPClient] = [] |
58 | | - # Initialize SSE connections |
| 59 | + |
| 60 | + # Create connection tasks for parallel execution |
| 61 | + connection_tasks = [] |
| 62 | + clients = [] |
| 63 | + |
59 | 64 | for name, mcp_config in dict_mcp_config.items(): |
60 | 65 | logger.info( |
61 | 66 | f'Initializing MCP {name} agent for {mcp_config.url} with {mcp_config.mode} connection...' |
62 | 67 | ) |
63 | 68 | # check if the name in a search engine config |
64 | 69 | if f'search_engine_{name}' in dict_mcp_config: |
65 | 70 | continue |
| 71 | + |
66 | 72 | client = MCPClient(name=name) |
67 | | - try: |
68 | | - await client.connect_sse(mcp_config.url, sid, mnemonic) |
69 | | - # Only add the client to the list after a successful connection |
70 | | - mcp_clients.append(client) |
71 | | - logger.info(f'Connected to MCP server {mcp_config.url} via SSE') |
72 | | - except Exception as e: |
73 | | - logger.error(f'Failed to connect to {mcp_config.url}: {str(e)}') |
| 73 | + clients.append((client, name, mcp_config)) |
| 74 | + |
| 75 | + # Create connection task |
| 76 | + async def connect_client(client, mcp_config): |
74 | 77 | try: |
75 | | - await client.disconnect() |
76 | | - except Exception as disconnect_error: |
77 | | - logger.error( |
78 | | - f'Error during disconnect after failed connection: {str(disconnect_error)}' |
79 | | - ) |
| 78 | + await client.connect_sse(mcp_config.url, sid, mnemonic) |
| 79 | + return client, None # Success: return client and no error |
| 80 | + except Exception as e: |
| 81 | + return None, e # Failure: return no client and the error |
| 82 | + |
| 83 | + connection_tasks.append(connect_client(client, mcp_config)) |
| 84 | + |
| 85 | + if connection_tasks: |
| 86 | + results = await asyncio.gather(*connection_tasks, return_exceptions=True) |
| 87 | + |
| 88 | + for i, result in enumerate(results): |
| 89 | + client, name, mcp_config = clients[i] |
| 90 | + |
| 91 | + if isinstance(result, Exception): |
| 92 | + logger.error(f'Failed to connect to {mcp_config.url}: {str(result)}') |
| 93 | + await _safe_disconnect(client) |
| 94 | + elif isinstance(result, tuple) and len(result) == 2: |
| 95 | + client_result, error = result |
| 96 | + if error is None: # Success case |
| 97 | + mcp_clients.append(client_result) |
| 98 | + logger.info(f'Connected to MCP server {mcp_config.url} via SSE') |
| 99 | + else: # Error case |
| 100 | + logger.error(f'Failed to connect to {mcp_config.url}: {str(error)}') |
| 101 | + await _safe_disconnect(client) |
| 102 | + else: |
| 103 | + logger.error(f'Unexpected result format for {mcp_config.url}: {result}') |
| 104 | + await _safe_disconnect(client) |
80 | 105 |
|
81 | 106 | return mcp_clients |
82 | 107 |
|
83 | 108 |
|
| 109 | +async def _safe_disconnect(client: MCPClient): |
| 110 | + """Safely disconnect a client with error handling.""" |
| 111 | + try: |
| 112 | + await client.disconnect() |
| 113 | + except Exception as disconnect_error: |
| 114 | + logger.error( |
| 115 | + f'Error during disconnect after failed connection: {str(disconnect_error)}' |
| 116 | + ) |
| 117 | + |
| 118 | + |
84 | 119 | async def fetch_mcp_tools_from_config( |
85 | 120 | dict_mcp_config: dict[str, MCPConfig], |
86 | 121 | sid: Optional[str] = None, |
|
0 commit comments