Skip to content
Open
Changes from 6 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
75 changes: 62 additions & 13 deletions mmdb/init.lua
Original file line number Diff line number Diff line change
Expand Up @@ -14,11 +14,25 @@ local geodb_mt = {
local data_types = {}
local getters = {}

local function open_db(filename)
local fd = assert(io.open(filename, "rb"))
local function fail(safe, ...)
if safe then
return nil, ...
end
error(..., 2)
end

local function open_db(safe, filename)
local fd, err = io.open(filename, "rb")
if not fd then
return fail(safe, err)
end

local contents, err = fd:read("*a")
fd:close()
assert(contents, err)

if not contents then
return fail(safe, err)
end

local start_metadata do
-- Find data section seperator; at most it's 128kb from the end
Expand All @@ -29,11 +43,12 @@ local function open_db(filename)
start_metadata = e + 1
end
if start_metadata == nil then
error("Invalid MaxMind Database")
return fail(safe, "Invalid MaxMind Database")
end
end

local self = setmetatable({
safe = not not safe;
contents = contents;
start_metadata = start_metadata;
data = nil;
Expand All @@ -47,7 +62,7 @@ local function open_db(filename)

local getter = getters[data.record_size]
if getter == nil then
error("Unsupported record size: " .. data.record_size)
return fail(self.safe, "Unsupported record size: " .. tostring(data.record_size))
end
self.left, self.right, self.record_length = getter.left, getter.right, getter.record_length

Expand Down Expand Up @@ -368,12 +383,17 @@ end

local function ipv4_to_bit_array(str)
local o1, o2, o3, o4 = str:match("(%d%d?%d?)%.(%d%d?%d?)%.(%d%d?%d?)%.(%d%d?%d?)")
assert(o1, "invalid IPv4 address")
if not o1 then
return nil, "invalid IPv4 address"
end
o1 = tonumber(o1, 10)
o2 = tonumber(o2, 10)
o3 = tonumber(o3, 10)
o4 = tonumber(o4, 10)
assert(o1 <= 255 and o2 <= 255 and o3 <= 255 and o4 <= 255, "invalid IPv4 address")
if not (o1 <= 255 and o2 <= 255 and o3 <= 255 and o4 <= 255) then
return nil, "invalid IPv4 address"
end

return {
math.floor(o1 / 128) % 2 == 1;
math.floor(o1 / 64) % 2 == 1;
Expand Down Expand Up @@ -411,7 +431,11 @@ local function ipv4_to_bit_array(str)
end

function geodb_methods:search_ipv4(str)
return select(2, self:search(ipv4_to_bit_array(str), self.ipv4_start))
local bits, err = ipv4_to_bit_array(str)
if not bits then
return fail(self.safe, err)
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

No need to call fail here: can return bits, err to the caller.

end
return select(2, self:search(bits, self.ipv4_start))
end

local function ipv6_split(str)
Expand All @@ -420,7 +444,9 @@ local function ipv6_split(str)
for u16 in str:gmatch("(%x%x?%x?%x?):?") do
n = n + 1
u16 = tonumber(u16, 16)
assert(u16, "invalid IPv6 address")
if not u16 then
return nil, "invalid IPv6 address"
end
components[n] = u16
end
return components, n
Expand All @@ -429,17 +455,27 @@ end
local function ipv6_to_bit_array(str)
local a, b = str:match("^([%x:]-)::([%x:]*)$")
local components, n = ipv6_split(a or str)
if not components then
return nil, n
end
if a ~= nil then
local end_components, m = ipv6_split(b)
assert(m+n <= 7, "invalid IPv6 address")
if not end_components then
return nil, m
end
if m+n > 7 then
return nil, "invalid IPv6 address"
end
for i = n+1, 8-m do
components[i] = 0
end
for i = 8-m+1, 8 do
components[i] = end_components[i-8+m]
end
else
assert(n == 8, "invalid IPv6 address")
if n ~= 8 then
return nil, "invalid IPv6 address"
end
end
-- Now components is an array of 16bit components
local bits = {}
Expand All @@ -453,9 +489,22 @@ local function ipv6_to_bit_array(str)
end

function geodb_methods:search_ipv6(str)
return select(2, self:search(ipv6_to_bit_array(str)))
local bits, err = ipv6_to_bit_array(str)
if not bits then
return fail(self.safe, err)
end
return select(2, self:search(bits))
end

local function open(...)
return open_db(false, ...)
end

local function open_safe(...)
return open_db(true, ...)
end

return {
open = open_db;
open = open;
open_safe = open_safe;
}