diff --git a/Sources/MCP/Base/Transports/HTTPClientTransport.swift b/Sources/MCP/Base/Transports/HTTPClientTransport.swift index 50f2f8c..11a4455 100644 --- a/Sources/MCP/Base/Transports/HTTPClientTransport.swift +++ b/Sources/MCP/Base/Transports/HTTPClientTransport.swift @@ -32,9 +32,14 @@ import Logging /// ```swift /// import MCP /// -/// // Create a streaming HTTP transport +/// // Create a streaming HTTP transport with bearer token authentication /// let transport = HTTPClientTransport( -/// endpoint: URL(string: "http://localhost:8080")!, +/// endpoint: URL(string: "https://api.example.com/mcp")!, +/// requestModifier: { request in +/// var modifiedRequest = request +/// modifiedRequest.addValue("Bearer your-token-here", forHTTPHeaderField: "Authorization") +/// return modifiedRequest +/// } /// ) /// /// // Initialize the client with streaming transport @@ -60,6 +65,9 @@ public actor HTTPClientTransport: Transport { /// Maximum time to wait for a session ID before proceeding with SSE connection public let sseInitializationTimeout: TimeInterval + /// Closure to modify requests before they are sent + private let requestModifier: (URLRequest) -> URLRequest + private var isConnected = false private let messageStream: AsyncThrowingStream private let messageContinuation: AsyncThrowingStream.Continuation @@ -74,12 +82,14 @@ public actor HTTPClientTransport: Transport { /// - configuration: URLSession configuration to use for HTTP requests /// - streaming: Whether to enable SSE streaming mode (default: true) /// - sseInitializationTimeout: Maximum time to wait for session ID before proceeding with SSE (default: 10 seconds) + /// - requestModifier: Optional closure to customize requests before they are sent (default: no modification) /// - logger: Optional logger instance for transport events public init( endpoint: URL, configuration: URLSessionConfiguration = .default, streaming: Bool = true, sseInitializationTimeout: TimeInterval = 10, + requestModifier: @escaping (URLRequest) -> URLRequest = { $0 }, logger: Logger? = nil ) { self.init( @@ -87,6 +97,7 @@ public actor HTTPClientTransport: Transport { session: URLSession(configuration: configuration), streaming: streaming, sseInitializationTimeout: sseInitializationTimeout, + requestModifier: requestModifier, logger: logger ) } @@ -96,12 +107,14 @@ public actor HTTPClientTransport: Transport { session: URLSession, streaming: Bool = false, sseInitializationTimeout: TimeInterval = 10, + requestModifier: @escaping (URLRequest) -> URLRequest = { $0 }, logger: Logger? = nil ) { self.endpoint = endpoint self.session = session self.streaming = streaming self.sseInitializationTimeout = sseInitializationTimeout + self.requestModifier = requestModifier // Create message stream var continuation: AsyncThrowingStream.Continuation! @@ -211,6 +224,9 @@ public actor HTTPClientTransport: Transport { request.addValue(sessionID, forHTTPHeaderField: "Mcp-Session-Id") } + // Apply request modifier + request = requestModifier(request) + #if os(Linux) // Linux implementation using data(for:) instead of bytes(for:) let (responseData, response) = try await session.data(for: request) @@ -480,6 +496,9 @@ public actor HTTPClientTransport: Transport { request.addValue(sessionID, forHTTPHeaderField: "Mcp-Session-Id") } + // Apply request modifier + request = requestModifier(request) + logger.debug("Starting SSE connection") // Create URLSession task for SSE diff --git a/Tests/MCPTests/HTTPClientTransportTests.swift b/Tests/MCPTests/HTTPClientTransportTests.swift index 9056a1b..cf2a25d 100644 --- a/Tests/MCPTests/HTTPClientTransportTests.swift +++ b/Tests/MCPTests/HTTPClientTransportTests.swift @@ -279,16 +279,9 @@ import Testing let configuration = URLSessionConfiguration.ephemeral configuration.protocolClasses = [MockURLProtocol.self] - let transport = HTTPClientTransport( - endpoint: testEndpoint, - configuration: configuration, - streaming: false, - logger: nil - ) - try await transport.connect() - let messageData = #"{"jsonrpc":"2.0","method":"test","id":3}"#.data(using: .utf8)! + // Set up the handler BEFORE creating the transport await MockURLProtocol.requestHandlerStorage.setHandler { [testEndpoint] (request: URLRequest) in let response = HTTPURLResponse( @@ -296,6 +289,14 @@ import Testing return (response, Data("Not Found".utf8)) } + let transport = HTTPClientTransport( + endpoint: testEndpoint, + configuration: configuration, + streaming: false, + logger: nil + ) + try await transport.connect() + do { try await transport.send(messageData) Issue.record("Expected send to throw an error for 404") @@ -316,16 +317,9 @@ import Testing let configuration = URLSessionConfiguration.ephemeral configuration.protocolClasses = [MockURLProtocol.self] - let transport = HTTPClientTransport( - endpoint: testEndpoint, - configuration: configuration, - streaming: false, - logger: nil - ) - try await transport.connect() - let messageData = #"{"jsonrpc":"2.0","method":"test","id":4}"#.data(using: .utf8)! + // Set up the handler BEFORE creating the transport await MockURLProtocol.requestHandlerStorage.setHandler { [testEndpoint] (request: URLRequest) in let response = HTTPURLResponse( @@ -333,6 +327,14 @@ import Testing return (response, Data("Server Error".utf8)) } + let transport = HTTPClientTransport( + endpoint: testEndpoint, + configuration: configuration, + streaming: false, + logger: nil + ) + try await transport.connect() + do { try await transport.send(messageData) Issue.record("Expected send to throw an error for 500") @@ -353,22 +355,15 @@ import Testing let configuration = URLSessionConfiguration.ephemeral configuration.protocolClasses = [MockURLProtocol.self] - let transport = HTTPClientTransport( - endpoint: testEndpoint, - configuration: configuration, - streaming: false, - logger: nil - ) - try await transport.connect() - let initialSessionID = "expired-session-xyz" let firstMessageData = #"{"jsonrpc":"2.0","method":"initialize","id":1}"#.data( using: .utf8)! let secondMessageData = #"{"jsonrpc":"2.0","method":"ping","id":2}"#.data( using: .utf8)! + // Set up the first handler BEFORE creating the transport await MockURLProtocol.requestHandlerStorage.setHandler { - [testEndpoint] (request: URLRequest) in + [testEndpoint, initialSessionID] (request: URLRequest) in let response = HTTPURLResponse( url: testEndpoint, statusCode: 200, httpVersion: "HTTP/1.1", headerFields: [ @@ -377,11 +372,21 @@ import Testing ])! return (response, Data()) } + + let transport = HTTPClientTransport( + endpoint: testEndpoint, + configuration: configuration, + streaming: false, + logger: nil + ) + try await transport.connect() + try await transport.send(firstMessageData) #expect(await transport.sessionID == initialSessionID) + // Set up the second handler for the 404 response await MockURLProtocol.requestHandlerStorage.setHandler { - [testEndpoint] (request: URLRequest) in + [testEndpoint, initialSessionID] (request: URLRequest) in #expect(request.value(forHTTPHeaderField: "Mcp-Session-Id") == initialSessionID) let response = HTTPURLResponse( url: testEndpoint, statusCode: 404, httpVersion: "HTTP/1.1", headerFields: nil)! @@ -528,150 +533,197 @@ import Testing await transport.disconnect() } - #endif // !canImport(FoundationNetworking) - @Test( - "Client with HTTP Transport complete flow", .httpClientTransportSetup, - .timeLimit(.minutes(1))) - func testClientFlow() async throws { - let configuration = URLSessionConfiguration.ephemeral - configuration.protocolClasses = [MockURLProtocol.self] + @Test( + "Client with HTTP Transport complete flow", .httpClientTransportSetup, + .timeLimit(.minutes(1))) + func testClientFlow() async throws { + let configuration = URLSessionConfiguration.ephemeral + configuration.protocolClasses = [MockURLProtocol.self] - let transport = HTTPClientTransport( - endpoint: testEndpoint, - configuration: configuration, - streaming: false, - logger: nil - ) + let transport = HTTPClientTransport( + endpoint: testEndpoint, + configuration: configuration, + streaming: false, + logger: nil + ) - let client = Client(name: "TestClient", version: "1.0.0") + let client = Client(name: "TestClient", version: "1.0.0") - // Use an actor to track request sequence - actor RequestTracker { - enum RequestType { - case initialize - case callTool - } + // Use an actor to track request sequence + actor RequestTracker { + enum RequestType { + case initialize + case callTool + } - private(set) var lastRequest: RequestType? + private(set) var lastRequest: RequestType? - func setRequest(_ type: RequestType) { - lastRequest = type - } + func setRequest(_ type: RequestType) { + lastRequest = type + } - func getLastRequest() -> RequestType? { - return lastRequest + func getLastRequest() -> RequestType? { + return lastRequest + } } - } - let tracker = RequestTracker() + let tracker = RequestTracker() - // Setup mock responses - await MockURLProtocol.requestHandlerStorage.setHandler { - [testEndpoint, tracker] (request: URLRequest) in - switch request.httpMethod { - case "GET": - #expect( - request.allHTTPHeaderFields?["Accept"]?.contains("text/event-stream") - == true) - case "POST": - #expect( - request.allHTTPHeaderFields?["Accept"]?.contains("application/json") == true - ) - default: - Issue.record( - "Unsupported HTTP method \(String(describing: request.httpMethod))") - } + // Setup mock responses + await MockURLProtocol.requestHandlerStorage.setHandler { + [testEndpoint, tracker] (request: URLRequest) in + switch request.httpMethod { + case "GET": + #expect( + request.allHTTPHeaderFields?["Accept"]?.contains("text/event-stream") + == true) + case "POST": + #expect( + request.allHTTPHeaderFields?["Accept"]?.contains("application/json") + == true + ) + default: + Issue.record( + "Unsupported HTTP method \(String(describing: request.httpMethod))") + } - #expect(request.url == testEndpoint) + #expect(request.url == testEndpoint) + + let bodyData = request.readBody() + + guard let bodyData = bodyData, + let json = try JSONSerialization.jsonObject(with: bodyData) + as? [String: Any], + let method = json["method"] as? String + else { + throw NSError( + domain: "MockURLProtocolError", code: 0, + userInfo: [ + NSLocalizedDescriptionKey: + "Invalid JSON-RPC message \(#file):\(#line)" + ]) + } - let bodyData = request.readBody() - - guard let bodyData = bodyData, - let json = try JSONSerialization.jsonObject(with: bodyData) as? [String: Any], - let method = json["method"] as? String - else { - throw NSError( - domain: "MockURLProtocolError", code: 0, - userInfo: [ - NSLocalizedDescriptionKey: "Invalid JSON-RPC message \(#file):\(#line)" - ]) + if method == "initialize" { + await tracker.setRequest(.initialize) + + let requestID = json["id"] as! String + let result = Initialize.Result( + protocolVersion: Version.latest, + capabilities: .init(tools: .init()), + serverInfo: .init(name: "Mock Server", version: "0.0.1"), + instructions: nil + ) + let response = Initialize.response(id: .string(requestID), result: result) + let responseData = try JSONEncoder().encode(response) + + let httpResponse = HTTPURLResponse( + url: testEndpoint, statusCode: 200, httpVersion: "HTTP/1.1", + headerFields: ["Content-Type": "application/json"])! + return (httpResponse, responseData) + } else if method == "tools/call" { + // Verify initialize was called first + if let lastRequest = await tracker.getLastRequest(), + lastRequest != .initialize + { + #expect(Bool(false), "Initialize should be called before callTool") + } + + await tracker.setRequest(.callTool) + + let params = json["params"] as? [String: Any] + let toolName = params?["name"] as? String + #expect(toolName == "calculator") + + let requestID = json["id"] as! String + let result = CallTool.Result(content: [.text("42")]) + let response = CallTool.response(id: .string(requestID), result: result) + let responseData = try JSONEncoder().encode(response) + + let httpResponse = HTTPURLResponse( + url: testEndpoint, statusCode: 200, httpVersion: "HTTP/1.1", + headerFields: ["Content-Type": "application/json"])! + return (httpResponse, responseData) + } else if method == "notifications/initialized" { + // Ignore initialized notifications + let httpResponse = HTTPURLResponse( + url: testEndpoint, statusCode: 200, httpVersion: "HTTP/1.1", + headerFields: ["Content-Type": "application/json"])! + return (httpResponse, Data()) + } else { + throw NSError( + domain: "MockURLProtocolError", code: 0, + userInfo: [ + NSLocalizedDescriptionKey: + "Unexpected request method: \(method) \(#file):\(#line)" + ]) + } } - if method == "initialize" { - await tracker.setRequest(.initialize) + // Step 1: Initialize client + let initResult = try await client.connect(transport: transport) + #expect(initResult.protocolVersion == Version.latest) + #expect(initResult.capabilities.tools != nil) - let requestID = json["id"] as! String - let result = Initialize.Result( - protocolVersion: Version.latest, - capabilities: .init(tools: .init()), - serverInfo: .init(name: "Mock Server", version: "0.0.1"), - instructions: nil - ) - let response = Initialize.response(id: .string(requestID), result: result) - let responseData = try JSONEncoder().encode(response) + // Step 2: Call a tool + let toolResult = try await client.callTool(name: "calculator") + #expect(toolResult.content.count == 1) + if case let .text(text) = toolResult.content[0] { + #expect(text == "42") + } else { + #expect(Bool(false), "Expected text content") + } - let httpResponse = HTTPURLResponse( - url: testEndpoint, statusCode: 200, httpVersion: "HTTP/1.1", - headerFields: ["Content-Type": "application/json"])! - return (httpResponse, responseData) - } else if method == "tools/call" { - // Verify initialize was called first - if let lastRequest = await tracker.getLastRequest(), lastRequest != .initialize - { - #expect(Bool(false), "Initialize should be called before callTool") - } + // Step 3: Verify request sequence + #expect(await tracker.getLastRequest() == .callTool) - await tracker.setRequest(.callTool) + // Step 4: Disconnect + await client.disconnect() + } - let params = json["params"] as? [String: Any] - let toolName = params?["name"] as? String - #expect(toolName == "calculator") + @Test("Request modifier functionality", .httpClientTransportSetup) + func testRequestModifier() async throws { + let testEndpoint = URL(string: "https://api.example.com/mcp")! + let testToken = "test-bearer-token-12345" - let requestID = json["id"] as! String - let result = CallTool.Result(content: [.text("42")]) - let response = CallTool.response(id: .string(requestID), result: result) - let responseData = try JSONEncoder().encode(response) + let configuration = URLSessionConfiguration.ephemeral + configuration.protocolClasses = [MockURLProtocol.self] - let httpResponse = HTTPURLResponse( - url: testEndpoint, statusCode: 200, httpVersion: "HTTP/1.1", - headerFields: ["Content-Type": "application/json"])! - return (httpResponse, responseData) - } else if method == "notifications/initialized" { - // Ignore initialized notifications - let httpResponse = HTTPURLResponse( + await MockURLProtocol.requestHandlerStorage.setHandler { + [testEndpoint, testToken] (request: URLRequest) in + // Verify the Authorization header was added by the requestModifier + #expect( + request.value(forHTTPHeaderField: "Authorization") == "Bearer \(testToken)") + + // Return a successful response + let response = HTTPURLResponse( url: testEndpoint, statusCode: 200, httpVersion: "HTTP/1.1", headerFields: ["Content-Type": "application/json"])! - return (httpResponse, Data()) - } else { - throw NSError( - domain: "MockURLProtocolError", code: 0, - userInfo: [ - NSLocalizedDescriptionKey: - "Unexpected request method: \(method) \(#file):\(#line)" - ]) + return (response, Data()) } - } - // Step 1: Initialize client - let initResult = try await client.connect(transport: transport) - #expect(initResult.protocolVersion == Version.latest) - #expect(initResult.capabilities.tools != nil) - - // Step 2: Call a tool - let toolResult = try await client.callTool(name: "calculator") - #expect(toolResult.content.count == 1) - if case let .text(text) = toolResult.content[0] { - #expect(text == "42") - } else { - #expect(Bool(false), "Expected text content") - } + // Create transport with requestModifier that adds Authorization header + let transport = HTTPClientTransport( + endpoint: testEndpoint, + configuration: configuration, + streaming: false, + requestModifier: { request in + var modifiedRequest = request + modifiedRequest.addValue( + "Bearer \(testToken)", forHTTPHeaderField: "Authorization") + return modifiedRequest + }, + logger: nil + ) - // Step 3: Verify request sequence - #expect(await tracker.getLastRequest() == .callTool) + try await transport.connect() - // Step 4: Disconnect - await client.disconnect() - } + let messageData = #"{"jsonrpc":"2.0","method":"test","id":5}"#.data(using: .utf8)! + + try await transport.send(messageData) + await transport.disconnect() + } + #endif // !canImport(FoundationNetworking) } #endif // swift(>=6.1)