Skip to content
136 changes: 123 additions & 13 deletions src/client/auth.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -414,6 +414,22 @@ describe("OAuth Authorization", () => {
});

describe("exchangeAuthorization", () => {
const mockProvider: OAuthClientProvider = {
get redirectUrl() { return "http://localhost:3000/callback"; },
get clientMetadata() {
return {
redirect_uris: ["http://localhost:3000/callback"],
client_name: "Test Client",
};
},
clientInformation: jest.fn(),
tokens: jest.fn(),
saveTokens: jest.fn(),
redirectToAuthorization: jest.fn(),
saveCodeVerifier: jest.fn(),
codeVerifier: jest.fn(),
};

const validTokens = {
access_token: "access123",
token_type: "Bearer",
Expand All @@ -435,7 +451,7 @@ describe("OAuth Authorization", () => {
json: async () => validTokens,
});

const tokens = await exchangeAuthorization("https://auth.example.com", {
const tokens = await exchangeAuthorization("https://auth.example.com", mockProvider, {
clientInformation: validClientInfo,
authorizationCode: "code123",
codeVerifier: "verifier123",
Expand All @@ -449,12 +465,11 @@ describe("OAuth Authorization", () => {
}),
expect.objectContaining({
method: "POST",
headers: {
"Content-Type": "application/x-www-form-urlencoded",
},
})
);

const headers = mockFetch.mock.calls[0][1].headers as Headers;
expect(headers.get("Content-Type")).toBe("application/x-www-form-urlencoded");
const body = mockFetch.mock.calls[0][1].body as URLSearchParams;
expect(body.get("grant_type")).toBe("authorization_code");
expect(body.get("code")).toBe("code123");
Expand All @@ -464,6 +479,48 @@ describe("OAuth Authorization", () => {
expect(body.get("redirect_uri")).toBe("http://localhost:3000/callback");
});

it("exchanges code for tokens with auth", async () => {
mockProvider.authToTokenEndpoint = function(headers: Headers, params: URLSearchParams) {
headers.set("Authorization", "Basic " + btoa(validClientInfo.client_id + ":" + validClientInfo.client_secret));
params.set("example_param", "example_value")
};

mockFetch.mockResolvedValueOnce({
ok: true,
status: 200,
json: async () => validTokens,
});

const tokens = await exchangeAuthorization("https://auth.example.com", mockProvider, {
clientInformation: validClientInfo,
authorizationCode: "code123",
codeVerifier: "verifier123",
redirectUri: "http://localhost:3000/callback",
});

expect(tokens).toEqual(validTokens);
expect(mockFetch).toHaveBeenCalledWith(
expect.objectContaining({
href: "https://auth.example.com/token",
}),
expect.objectContaining({
method: "POST",
})
);

const headers = mockFetch.mock.calls[0][1].headers as Headers;
expect(headers.get("Content-Type")).toBe("application/x-www-form-urlencoded");
expect(headers.get("Authorization")).toBe("Basic Y2xpZW50MTIzOnNlY3JldDEyMw==");
const body = mockFetch.mock.calls[0][1].body as URLSearchParams;
expect(body.get("grant_type")).toBe("authorization_code");
expect(body.get("code")).toBe("code123");
expect(body.get("code_verifier")).toBe("verifier123");
expect(body.get("client_id")).toBe("client123");
expect(body.get("redirect_uri")).toBe("http://localhost:3000/callback");
expect(body.get("example_param")).toBe("example_value");
expect(body.get("client_secret")).toBeUndefined;
});

it("validates token response schema", async () => {
mockFetch.mockResolvedValueOnce({
ok: true,
Expand All @@ -475,7 +532,7 @@ describe("OAuth Authorization", () => {
});

await expect(
exchangeAuthorization("https://auth.example.com", {
exchangeAuthorization("https://auth.example.com", mockProvider, {
clientInformation: validClientInfo,
authorizationCode: "code123",
codeVerifier: "verifier123",
Expand All @@ -491,7 +548,7 @@ describe("OAuth Authorization", () => {
});

await expect(
exchangeAuthorization("https://auth.example.com", {
exchangeAuthorization("https://auth.example.com", mockProvider, {
clientInformation: validClientInfo,
authorizationCode: "code123",
codeVerifier: "verifier123",
Expand All @@ -502,6 +559,22 @@ describe("OAuth Authorization", () => {
});

describe("refreshAuthorization", () => {
const mockProvider: OAuthClientProvider = {
get redirectUrl() { return "http://localhost:3000/callback"; },
get clientMetadata() {
return {
redirect_uris: ["http://localhost:3000/callback"],
client_name: "Test Client",
};
},
clientInformation: jest.fn(),
tokens: jest.fn(),
saveTokens: jest.fn(),
redirectToAuthorization: jest.fn(),
saveCodeVerifier: jest.fn(),
codeVerifier: jest.fn(),
};

const validTokens = {
access_token: "newaccess123",
token_type: "Bearer",
Expand All @@ -526,7 +599,7 @@ describe("OAuth Authorization", () => {
json: async () => validTokensWithNewRefreshToken,
});

const tokens = await refreshAuthorization("https://auth.example.com", {
const tokens = await refreshAuthorization("https://auth.example.com", mockProvider, {
clientInformation: validClientInfo,
refreshToken: "refresh123",
});
Expand All @@ -538,19 +611,56 @@ describe("OAuth Authorization", () => {
}),
expect.objectContaining({
method: "POST",
headers: {
"Content-Type": "application/x-www-form-urlencoded",
},
})
);

const headers = mockFetch.mock.calls[0][1].headers as Headers;
expect(headers.get("Content-Type")).toBe("application/x-www-form-urlencoded");
const body = mockFetch.mock.calls[0][1].body as URLSearchParams;
expect(body.get("grant_type")).toBe("refresh_token");
expect(body.get("refresh_token")).toBe("refresh123");
expect(body.get("client_id")).toBe("client123");
expect(body.get("client_secret")).toBe("secret123");
});

it("exchanges refresh token for new tokens with auth", async () => {
mockProvider.authToTokenEndpoint = function(headers: Headers, params: URLSearchParams) {
headers.set("Authorization", "Basic " + btoa(validClientInfo.client_id + ":" + validClientInfo.client_secret));
params.set("example_param", "example_value")
};

mockFetch.mockResolvedValueOnce({
ok: true,
status: 200,
json: async () => validTokensWithNewRefreshToken,
});

const tokens = await refreshAuthorization("https://auth.example.com", mockProvider, {
clientInformation: validClientInfo,
refreshToken: "refresh123",
});

expect(tokens).toEqual(validTokensWithNewRefreshToken);
expect(mockFetch).toHaveBeenCalledWith(
expect.objectContaining({
href: "https://auth.example.com/token",
}),
expect.objectContaining({
method: "POST",
})
);

const headers = mockFetch.mock.calls[0][1].headers as Headers;
expect(headers.get("Content-Type")).toBe("application/x-www-form-urlencoded");
expect(headers.get("Authorization")).toBe("Basic Y2xpZW50MTIzOnNlY3JldDEyMw==");
const body = mockFetch.mock.calls[0][1].body as URLSearchParams;
expect(body.get("grant_type")).toBe("refresh_token");
expect(body.get("refresh_token")).toBe("refresh123");
expect(body.get("client_id")).toBe("client123");
expect(body.get("example_param")).toBe("example_value");
expect(body.get("client_secret")).toBeUndefined;
});

it("exchanges refresh token for new tokens and keep existing refresh token if none is returned", async () => {
mockFetch.mockResolvedValueOnce({
ok: true,
Expand All @@ -559,7 +669,7 @@ describe("OAuth Authorization", () => {
});

const refreshToken = "refresh123";
const tokens = await refreshAuthorization("https://auth.example.com", {
const tokens = await refreshAuthorization("https://auth.example.com", mockProvider, {
clientInformation: validClientInfo,
refreshToken,
});
Expand All @@ -578,7 +688,7 @@ describe("OAuth Authorization", () => {
});

await expect(
refreshAuthorization("https://auth.example.com", {
refreshAuthorization("https://auth.example.com", mockProvider, {
clientInformation: validClientInfo,
refreshToken: "refresh123",
})
Expand All @@ -592,7 +702,7 @@ describe("OAuth Authorization", () => {
});

await expect(
refreshAuthorization("https://auth.example.com", {
refreshAuthorization("https://auth.example.com", mockProvider, {
clientInformation: validClientInfo,
refreshToken: "refresh123",
})
Expand Down
30 changes: 20 additions & 10 deletions src/client/auth.ts
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,8 @@ export interface OAuthClientProvider {
* the authorization result.
*/
codeVerifier(): string | Promise<string>;

authToTokenEndpoint?(headers: Headers, params: URLSearchParams): void | Promise<void>;
}

export type AuthResult = "AUTHORIZED" | "REDIRECT";
Expand Down Expand Up @@ -131,7 +133,7 @@ export async function auth(
// Exchange authorization code for tokens
if (authorizationCode !== undefined) {
const codeVerifier = await provider.codeVerifier();
const tokens = await exchangeAuthorization(authorizationServerUrl, {
const tokens = await exchangeAuthorization(authorizationServerUrl, provider, {
metadata,
clientInformation,
authorizationCode,
Expand All @@ -149,7 +151,7 @@ export async function auth(
if (tokens?.refresh_token) {
try {
// Attempt to refresh the token
const newTokens = await refreshAuthorization(authorizationServerUrl, {
const newTokens = await refreshAuthorization(authorizationServerUrl, provider, {
metadata,
clientInformation,
refreshToken: tokens.refresh_token,
Expand Down Expand Up @@ -359,6 +361,7 @@ export async function startAuthorization(
*/
export async function exchangeAuthorization(
authorizationServerUrl: string | URL,
provider: OAuthClientProvider,
{
metadata,
clientInformation,
Expand Down Expand Up @@ -392,6 +395,9 @@ export async function exchangeAuthorization(
}

// Exchange code for tokens
const headers = new Headers({
"Content-Type": "application/x-www-form-urlencoded",
});
const params = new URLSearchParams({
grant_type: grantType,
client_id: clientInformation.client_id,
Expand All @@ -400,15 +406,15 @@ export async function exchangeAuthorization(
redirect_uri: String(redirectUri),
});

if (clientInformation.client_secret) {
if (provider.authToTokenEndpoint) {
provider.authToTokenEndpoint(headers, params);
} else if (clientInformation.client_secret) {
params.set("client_secret", clientInformation.client_secret);
}

const response = await fetch(tokenUrl, {
method: "POST",
headers: {
"Content-Type": "application/x-www-form-urlencoded",
},
headers: headers,
body: params,
});

Expand All @@ -424,6 +430,7 @@ export async function exchangeAuthorization(
*/
export async function refreshAuthorization(
authorizationServerUrl: string | URL,
provider: OAuthClientProvider,
{
metadata,
clientInformation,
Expand Down Expand Up @@ -453,21 +460,24 @@ export async function refreshAuthorization(
}

// Exchange refresh token
const headers = new Headers({
"Content-Type": "application/x-www-form-urlencoded",
});
const params = new URLSearchParams({
grant_type: grantType,
client_id: clientInformation.client_id,
refresh_token: refreshToken,
});

if (clientInformation.client_secret) {
if (provider.authToTokenEndpoint) {
provider.authToTokenEndpoint(headers, params);
} else if (clientInformation.client_secret) {
params.set("client_secret", clientInformation.client_secret);
}

const response = await fetch(tokenUrl, {
method: "POST",
headers: {
"Content-Type": "application/x-www-form-urlencoded",
},
headers: headers,
body: params,
});
if (!response.ok) {
Expand Down