Skip to content

Commit 11b94a7

Browse files
Merge pull request #5 from Lordfirespeed/refactor-origin-handler
Refactor origin handler
2 parents 57cf95f + 78645fc commit 11b94a7

File tree

3 files changed

+99
-27
lines changed

3 files changed

+99
-27
lines changed

.husky/pre-commit

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
#!/bin/sh
1+
#!/bin/sh -l
22
. "$(dirname "$0")/_/husky.sh"
33

44
pnpm format && pnpm lint && pnpm build && pnpm test

src/index.ts

Lines changed: 66 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@ import { IncomingMessage as Request, ServerResponse as Response } from 'http'
22
import { vary } from 'es-vary'
33

44
export interface AccessControlOptions {
5-
origin?: string | boolean | ((req: Request, res: Response) => string) | Array<string> | RegExp
5+
origin?: string | boolean | ((req: Request, res: Response) => string) | Iterable<string> | RegExp
66
methods?: string[]
77
allowedHeaders?: string[]
88
exposedHeaders?: string[]
@@ -12,6 +12,68 @@ export interface AccessControlOptions {
1212
preflightContinue?: boolean
1313
}
1414

15+
const isIterable = (obj: unknown): obj is Iterable<unknown> => typeof obj[Symbol.iterator] === 'function'
16+
17+
const failOriginParam = () => {
18+
throw new TypeError('No other objects allowed. Allowed types is array of strings or RegExp')
19+
}
20+
21+
const getOriginHeaderHandler = (origin: unknown): ((req: Request, res: Response) => void) => {
22+
if (typeof origin === 'boolean') {
23+
return origin
24+
? (_, res) => {
25+
res.setHeader('Access-Control-Allow-Origin', '*')
26+
}
27+
: () => undefined
28+
}
29+
30+
if (typeof origin === 'string') {
31+
return (_, res) => {
32+
res.setHeader('Access-Control-Allow-Origin', origin)
33+
}
34+
}
35+
36+
if (typeof origin === 'function') {
37+
return (req, res) => {
38+
vary(res, 'Origin')
39+
res.setHeader('Access-Control-Allow-Origin', origin(req, res))
40+
}
41+
}
42+
43+
if (typeof origin !== 'object') failOriginParam()
44+
45+
if (isIterable(origin)) {
46+
const originArray = Array.from(origin)
47+
if (originArray.some((element) => typeof element !== 'string')) failOriginParam()
48+
49+
const originSet = new Set(origin)
50+
51+
if (originSet.has('*')) {
52+
return (_, res) => {
53+
res.setHeader('Access-Control-Allow-Origin', '*')
54+
}
55+
}
56+
57+
return (req, res) => {
58+
vary(res, 'Origin')
59+
if (req.headers.origin === undefined) return
60+
if (!originSet.has(req.headers.origin)) return
61+
res.setHeader('Access-Control-Allow-Origin', req.headers.origin)
62+
}
63+
}
64+
65+
if (origin instanceof RegExp) {
66+
return (req, res) => {
67+
vary(res, 'Origin')
68+
if (req.headers.origin === undefined) return
69+
if (!origin.test(req.headers.origin)) return
70+
res.setHeader('Access-Control-Allow-Origin', req.headers.origin)
71+
}
72+
}
73+
74+
failOriginParam()
75+
}
76+
1577
/**
1678
* CORS Middleware
1779
*/
@@ -26,24 +88,10 @@ export const cors = (opts: AccessControlOptions = {}) => {
2688
optionsSuccessStatus = 204,
2789
preflightContinue = false
2890
} = opts
91+
const originHeaderHandler = getOriginHeaderHandler(origin)
92+
2993
return (req: Request, res: Response, next?: () => void) => {
30-
// Checking the type of the origin property
31-
if (typeof origin === 'boolean' && origin === true) {
32-
res.setHeader('Access-Control-Allow-Origin', '*')
33-
} else if (typeof origin === 'string') {
34-
res.setHeader('Access-Control-Allow-Origin', origin)
35-
} else if (typeof origin === 'function') {
36-
res.setHeader('Access-Control-Allow-Origin', origin(req, res))
37-
} else if (typeof origin === 'object') {
38-
if (Array.isArray(origin) && (origin.indexOf(req.headers.origin) !== -1 || origin.indexOf('*') !== -1)) {
39-
res.setHeader('Access-Control-Allow-Origin', req.headers.origin)
40-
} else if (origin instanceof RegExp && origin.test(req.headers.origin)) {
41-
res.setHeader('Access-Control-Allow-Origin', req.headers.origin)
42-
} else {
43-
throw new TypeError('No other objects allowed. Allowed types is array of strings or RegExp')
44-
}
45-
}
46-
if ((typeof origin === 'string' && origin !== '*') || typeof origin === 'function') vary(res, 'Origin')
94+
originHeaderHandler(req, res)
4795

4896
// Setting the Access-Control-Allow-Methods header from the methods array
4997
res.setHeader('Access-Control-Allow-Methods', methods.join(', ').toUpperCase())

tests/index.test.ts

Lines changed: 32 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -50,20 +50,44 @@ describe('CORS headers tests', (it) => {
5050
'http://example.com'
5151
)
5252
})
53-
it('should set origin if it is an array', async () => {
54-
const app = createServer(cors({ origin: ['http://example.com', 'example.com', 'https://example.com'] }))
53+
describe('when origin is an array of strings', (it) => {
54+
it('should set origin when origin header is included in request and whitelisted', async () => {
55+
const app = createServer(cors({ origin: ['http://example.com', 'example.com', 'https://example.com'] }))
5556

56-
const fetch = makeFetch(app)
57+
const fetch = makeFetch(app)
5758

58-
await fetch('/', { headers: { Origin: 'http://example.com' } }).expect(
59-
'Access-Control-Allow-Origin',
60-
'http://example.com'
61-
)
59+
await fetch('/', { headers: { Origin: 'http://example.com' } }).expect(
60+
'Access-Control-Allow-Origin',
61+
'http://example.com'
62+
)
63+
})
64+
it('should not set origin when origin header is included in request but not whitelisted', async () => {
65+
const app = createServer(cors({ origin: ['http://example.com', 'example.com', 'https://example.com'] }))
66+
67+
const fetch = makeFetch(app)
68+
69+
await fetch('/', { headers: { Origin: 'http://not-example.com' } }).expect('Access-Control-Allow-Origin', null)
70+
})
71+
it('should not set origin when origin header is excluded from request', async () => {
72+
const app = createServer(cors({ origin: ['http://example.com', 'example.com', 'https://example.com'] }))
73+
74+
const fetch = makeFetch(app)
75+
76+
await fetch('/').expect('Access-Control-Allow-Origin', null)
77+
})
78+
})
79+
it('should send an error if origin is an iterable containing a non-string', async () => {
80+
try {
81+
// @ts-ignore
82+
const middleware = cors({ origin: [{}, 3, 'abc'] })
83+
} catch (e) {
84+
assert.strictEqual(e.message, 'No other objects allowed. Allowed types is array of strings or RegExp')
85+
}
6286
})
6387
it('should send an error if it is other object types', () => {
6488
try {
6589
// @ts-ignore
66-
const app = createServer(cors({ origin: { site: 'http://example.com' } }))
90+
const middleware = cors({ origin: { site: 'http://example.com' } })
6791
} catch (e) {
6892
assert.strictEqual(e.message, 'No other objects allowed. Allowed types is array of strings or RegExp')
6993
}

0 commit comments

Comments
 (0)