diff --git a/.husky/pre-commit b/.husky/pre-commit index 367adc4..35257ce 100755 --- a/.husky/pre-commit +++ b/.husky/pre-commit @@ -1,4 +1,4 @@ -#!/bin/sh +#!/bin/sh -l . "$(dirname "$0")/_/husky.sh" pnpm format && pnpm lint && pnpm build && pnpm test \ No newline at end of file diff --git a/src/index.ts b/src/index.ts index a13e8fe..4511206 100644 --- a/src/index.ts +++ b/src/index.ts @@ -2,7 +2,7 @@ import { IncomingMessage as Request, ServerResponse as Response } from 'http' import { vary } from 'es-vary' export interface AccessControlOptions { - origin?: string | boolean | ((req: Request, res: Response) => string) | Array | RegExp + origin?: string | boolean | ((req: Request, res: Response) => string) | Iterable | RegExp methods?: string[] allowedHeaders?: string[] exposedHeaders?: string[] @@ -12,6 +12,68 @@ export interface AccessControlOptions { preflightContinue?: boolean } +const isIterable = (obj: unknown): obj is Iterable => typeof obj[Symbol.iterator] === 'function' + +const failOriginParam = () => { + throw new TypeError('No other objects allowed. Allowed types is array of strings or RegExp') +} + +const getOriginHeaderHandler = (origin: unknown): ((req: Request, res: Response) => void) => { + if (typeof origin === 'boolean') { + return origin + ? (_, res) => { + res.setHeader('Access-Control-Allow-Origin', '*') + } + : () => undefined + } + + if (typeof origin === 'string') { + return (_, res) => { + res.setHeader('Access-Control-Allow-Origin', origin) + } + } + + if (typeof origin === 'function') { + return (req, res) => { + vary(res, 'Origin') + res.setHeader('Access-Control-Allow-Origin', origin(req, res)) + } + } + + if (typeof origin !== 'object') failOriginParam() + + if (isIterable(origin)) { + const originArray = Array.from(origin) + if (originArray.some((element) => typeof element !== 'string')) failOriginParam() + + const originSet = new Set(origin) + + if (originSet.has('*')) { + return (_, res) => { + res.setHeader('Access-Control-Allow-Origin', '*') + } + } + + return (req, res) => { + vary(res, 'Origin') + if (req.headers.origin === undefined) return + if (!originSet.has(req.headers.origin)) return + res.setHeader('Access-Control-Allow-Origin', req.headers.origin) + } + } + + if (origin instanceof RegExp) { + return (req, res) => { + vary(res, 'Origin') + if (req.headers.origin === undefined) return + if (!origin.test(req.headers.origin)) return + res.setHeader('Access-Control-Allow-Origin', req.headers.origin) + } + } + + failOriginParam() +} + /** * CORS Middleware */ @@ -26,24 +88,10 @@ export const cors = (opts: AccessControlOptions = {}) => { optionsSuccessStatus = 204, preflightContinue = false } = opts + const originHeaderHandler = getOriginHeaderHandler(origin) + return (req: Request, res: Response, next?: () => void) => { - // Checking the type of the origin property - if (typeof origin === 'boolean' && origin === true) { - res.setHeader('Access-Control-Allow-Origin', '*') - } else if (typeof origin === 'string') { - res.setHeader('Access-Control-Allow-Origin', origin) - } else if (typeof origin === 'function') { - res.setHeader('Access-Control-Allow-Origin', origin(req, res)) - } else if (typeof origin === 'object') { - if (Array.isArray(origin) && (origin.indexOf(req.headers.origin) !== -1 || origin.indexOf('*') !== -1)) { - res.setHeader('Access-Control-Allow-Origin', req.headers.origin) - } else if (origin instanceof RegExp && origin.test(req.headers.origin)) { - res.setHeader('Access-Control-Allow-Origin', req.headers.origin) - } else { - throw new TypeError('No other objects allowed. Allowed types is array of strings or RegExp') - } - } - if ((typeof origin === 'string' && origin !== '*') || typeof origin === 'function') vary(res, 'Origin') + originHeaderHandler(req, res) // Setting the Access-Control-Allow-Methods header from the methods array res.setHeader('Access-Control-Allow-Methods', methods.join(', ').toUpperCase()) diff --git a/tests/index.test.ts b/tests/index.test.ts index 0a888c1..aa1f671 100644 --- a/tests/index.test.ts +++ b/tests/index.test.ts @@ -50,20 +50,44 @@ describe('CORS headers tests', (it) => { 'http://example.com' ) }) - it('should set origin if it is an array', async () => { - const app = createServer(cors({ origin: ['http://example.com', 'example.com', 'https://example.com'] })) + describe('when origin is an array of strings', (it) => { + it('should set origin when origin header is included in request and whitelisted', async () => { + const app = createServer(cors({ origin: ['http://example.com', 'example.com', 'https://example.com'] })) - const fetch = makeFetch(app) + const fetch = makeFetch(app) - await fetch('/', { headers: { Origin: 'http://example.com' } }).expect( - 'Access-Control-Allow-Origin', - 'http://example.com' - ) + await fetch('/', { headers: { Origin: 'http://example.com' } }).expect( + 'Access-Control-Allow-Origin', + 'http://example.com' + ) + }) + it('should not set origin when origin header is included in request but not whitelisted', async () => { + const app = createServer(cors({ origin: ['http://example.com', 'example.com', 'https://example.com'] })) + + const fetch = makeFetch(app) + + await fetch('/', { headers: { Origin: 'http://not-example.com' } }).expect('Access-Control-Allow-Origin', null) + }) + it('should not set origin when origin header is excluded from request', async () => { + const app = createServer(cors({ origin: ['http://example.com', 'example.com', 'https://example.com'] })) + + const fetch = makeFetch(app) + + await fetch('/').expect('Access-Control-Allow-Origin', null) + }) + }) + it('should send an error if origin is an iterable containing a non-string', async () => { + try { + // @ts-ignore + const middleware = cors({ origin: [{}, 3, 'abc'] }) + } catch (e) { + assert.strictEqual(e.message, 'No other objects allowed. Allowed types is array of strings or RegExp') + } }) it('should send an error if it is other object types', () => { try { // @ts-ignore - const app = createServer(cors({ origin: { site: 'http://example.com' } })) + const middleware = cors({ origin: { site: 'http://example.com' } }) } catch (e) { assert.strictEqual(e.message, 'No other objects allowed. Allowed types is array of strings or RegExp') }