diff --git a/docs/pages/setup/manual.mdx b/docs/pages/setup/manual.mdx index 58e4251..840eef5 100644 --- a/docs/pages/setup/manual.mdx +++ b/docs/pages/setup/manual.mdx @@ -61,7 +61,8 @@ const keys = await generateKeyPair("RS256", { }); const privateKey = await exportPKCS8(keys.privateKey); const publicKey = await exportJWK(keys.publicKey); -const jwks = JSON.stringify({ keys: [{ use: "sig", ...publicKey }] }); +const kid = crypto.randomUUID(); +const jwks = JSON.stringify({ keys: [{ use: "sig", kid, ...publicKey }] }); process.stdout.write( `JWT_PRIVATE_KEY="${privateKey.trimEnd().replace(/\n/g, " ")}"`, @@ -127,5 +128,5 @@ export default http; -Continue to [Step 3](/setup#add-authentication-tables-to-your-schema) in -the Setup guide. +Continue to [Step 3](/setup#add-authentication-tables-to-your-schema) in the +Setup guide. diff --git a/src/cli/generateKeys.ts b/src/cli/generateKeys.ts index bad4358..559a6a7 100644 --- a/src/cli/generateKeys.ts +++ b/src/cli/generateKeys.ts @@ -5,7 +5,8 @@ export async function generateKeys() { const keys = await generateKeyPair("RS256"); const privateKey = await exportPKCS8(keys.privateKey); const publicKey = await exportJWK(keys.publicKey); - const jwks = JSON.stringify({ keys: [{ use: "sig", ...publicKey }] }); + const kid = crypto.randomUUID(); + const jwks = JSON.stringify({ keys: [{ use: "sig", kid, ...publicKey }] }); return { JWT_PRIVATE_KEY: `${privateKey.trimEnd().replace(/\n/g, " ")}`, JWKS: jwks, diff --git a/src/server/implementation/tokens.ts b/src/server/implementation/tokens.ts index 3701f24..847f3d0 100644 --- a/src/server/implementation/tokens.ts +++ b/src/server/implementation/tokens.ts @@ -1,6 +1,6 @@ import { GenericId } from "convex/values"; import { ConvexAuthConfig } from "../index.js"; -import { SignJWT, importPKCS8 } from "jose"; +import { SignJWT, importPKCS8, type JWK } from "jose"; import { requireEnv } from "../utils.js"; import { TOKEN_SUB_CLAIM_DIVIDER } from "./utils.js"; @@ -17,13 +17,29 @@ export async function generateToken( const expirationTime = new Date( Date.now() + (config.jwt?.durationMs ?? DEFAULT_JWT_DURATION_MS), ); + const latestJwk = getLatestJwk(); return await new SignJWT({ sub: args.userId + TOKEN_SUB_CLAIM_DIVIDER + args.sessionId, }) - .setProtectedHeader({ alg: "RS256" }) + .setProtectedHeader({ alg: "RS256", kid: latestJwk.kid, typ: "JWT" }) .setIssuedAt() .setIssuer(requireEnv("CONVEX_SITE_URL")) .setAudience("convex") .setExpirationTime(expirationTime) .sign(privateKey); } + +function getLatestJwk() { + try { + const jwksString = requireEnv("JWKS"); + const jwks = JSON.parse(jwksString); + // assume the latest JWK is the first one + const latestJwk = jwks["keys"][0]; + if (!latestJwk) { + throw new Error("No JWK found"); + } + return latestJwk as JWK; + } catch (error) { + throw new Error("Error getting latest JWK", { cause: error }); + } +}