Skip to content
23 changes: 23 additions & 0 deletions src/client/sse.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -382,6 +382,29 @@ describe("SSEClientTransport", () => {
expect(mockAuthProvider.tokens).toHaveBeenCalled();
});

it("attaches custom header from provider on initial SSE connection", async () => {
mockAuthProvider.tokens.mockResolvedValue({
access_token: "test-token",
token_type: "Bearer"
});
const customHeaders = {
"X-Custom-Header": "custom-value",
};

transport = new SSEClientTransport(resourceBaseUrl, {
authProvider: mockAuthProvider,
requestInit: {
headers: customHeaders,
},
});

await transport.start();

expect(lastServerRequest.headers.authorization).toBe("Bearer test-token");
expect(lastServerRequest.headers["x-custom-header"]).toBe("custom-value");
expect(mockAuthProvider.tokens).toHaveBeenCalled();
});

it("attaches auth header from provider on POST requests", async () => {
mockAuthProvider.tokens.mockResolvedValue({
access_token: "test-token",
Expand Down
23 changes: 10 additions & 13 deletions src/client/sse.ts
Original file line number Diff line number Diff line change
Expand Up @@ -106,10 +106,8 @@ export class SSEClientTransport implements Transport {
return await this._startOrAuth();
}

private async _commonHeaders(): Promise<HeadersInit> {
const headers = {
...this._requestInit?.headers,
} as HeadersInit & Record<string, string>;
private async _commonHeaders(): Promise<Headers> {
const headers: HeadersInit = {};
if (this._authProvider) {
const tokens = await this._authProvider.tokens();
if (tokens) {
Expand All @@ -120,24 +118,24 @@ export class SSEClientTransport implements Transport {
headers["mcp-protocol-version"] = this._protocolVersion;
}

return headers;
return new Headers(
{ ...headers, ...this._requestInit?.headers }
);
}

private _startOrAuth(): Promise<void> {
const fetchImpl = (this?._eventSourceInit?.fetch ?? this._fetch ?? fetch) as typeof fetch
const fetchImpl = (this?._eventSourceInit?.fetch ?? this._fetch ?? fetch) as typeof fetch
return new Promise((resolve, reject) => {
this._eventSource = new EventSource(
this._url.href,
{
...this._eventSourceInit,
fetch: async (url, init) => {
const headers = await this._commonHeaders()
const headers = await this._commonHeaders();
headers.set("Accept", "text/event-stream");
const response = await fetchImpl(url, {
...init,
headers: new Headers({
...headers,
Accept: "text/event-stream"
})
headers,
})

if (response.status === 401 && response.headers.has('www-authenticate')) {
Expand Down Expand Up @@ -238,8 +236,7 @@ const fetchImpl = (this?._eventSourceInit?.fetch ?? this._fetch ?? fetch) as typ
}

try {
const commonHeaders = await this._commonHeaders();
const headers = new Headers(commonHeaders);
const headers = await this._commonHeaders();
headers.set("content-type", "application/json");
const init = {
...this._requestInit,
Expand Down