Skip to content

Commit 2f70b4a

Browse files
committed
http/websocket.lua: Fix behaviour of read_frame on error
- Fixes incorrect return values on error (missing error message #107) - Now retry safe
1 parent 072db57 commit 2f70b4a

File tree

2 files changed

+97
-35
lines changed

2 files changed

+97
-35
lines changed

http/websocket.lua

Lines changed: 73 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,7 @@ local http_patts = require "lpeg_patterns.http"
3535
local rand = require "openssl.rand"
3636
local digest = require "openssl.digest"
3737
local bit = require "http.bit"
38+
local onerror = require "http.connection_common".onerror
3839
local new_headers = require "http.headers".new
3940
local http_request = require "http.request"
4041

@@ -177,13 +178,18 @@ local function build_close(code, message, mask)
177178
end
178179

179180
local function read_frame(sock, deadline)
180-
local frame do
181-
local first_2, err, errno = sock:xread(2, "b", deadline and (deadline-monotime()))
181+
local frame, first_2 do
182+
local err, errno
183+
first_2, err, errno = sock:xread(2, "b", deadline and (deadline-monotime()))
182184
if not first_2 then
183185
return nil, err, errno
184186
elseif #first_2 ~= 2 then
185187
sock:seterror("r", ce.EILSEQ)
186-
return nil, ce.strerror(ce.EILSEQ), ce.EILSEQ
188+
local ok, errno2 = sock:unget(first_2)
189+
if not ok then
190+
return nil, onerror(sock, "unget", errno2)
191+
end
192+
return nil, onerror(sock, "read_frame", ce.EILSEQ)
187193
end
188194
local byte1, byte2 = first_2:byte(1, 2)
189195
frame = {
@@ -200,50 +206,80 @@ local function read_frame(sock, deadline)
200206
}
201207
end
202208

203-
if frame.length == 126 then
204-
local length, err, errno = sock:xread(2, "b", deadline and (deadline-monotime()))
205-
if not length or #length ~= 2 then
206-
if err == nil then
209+
local fill_length = frame.length
210+
if fill_length == 126 then
211+
fill_length = 2
212+
elseif fill_length == 127 then
213+
fill_length = 8
214+
end
215+
if frame.MASK then
216+
fill_length = fill_length + 4
217+
end
218+
do
219+
local ok, err, errno = sock:fill(fill_length, 0)
220+
if not ok then
221+
local unget_ok1, unget_errno1 = sock:unget(first_2)
222+
if not unget_ok1 then
223+
return nil, onerror(sock, "unget", unget_errno1)
224+
end
225+
if errno == ce.ETIMEDOUT then
226+
local timeout = deadline and deadline-monotime()
227+
if cqueues.poll(sock, timeout) ~= timeout then
228+
-- retry
229+
return read_frame(sock, deadline)
230+
end
231+
elseif err == nil then
207232
sock:seterror("r", ce.EILSEQ)
208-
return nil, ce.strerror(ce.EILSEQ), ce.EILSEQ
233+
return nil, onerror(sock, "read_frame", ce.EILSEQ)
209234
end
210235
return nil, err, errno
211236
end
212-
frame.length = sunpack(">I2", length)
237+
end
238+
239+
-- if `fill` succeeded these shouldn't be able to fail
240+
local extra_fill_unget
241+
if frame.length == 126 then
242+
extra_fill_unget = assert(sock:xread(2, "b", 0))
243+
frame.length = sunpack(">I2", extra_fill_unget)
244+
fill_length = fill_length - 2
213245
elseif frame.length == 127 then
214-
local length, err, errno = sock:xread(8, "b", deadline and (deadline-monotime()))
215-
if not length or #length ~= 8 then
216-
if err == nil then
246+
extra_fill_unget = assert(sock:xread(8, "b", 0))
247+
frame.length = sunpack(">I8", extra_fill_unget)
248+
fill_length = fill_length - 8 + frame.length
249+
end
250+
251+
if extra_fill_unget then
252+
local ok, err, errno = sock:fill(fill_length, 0)
253+
if not ok then
254+
local unget_ok1, unget_errno1 = sock:unget(extra_fill_unget)
255+
if not unget_ok1 then
256+
return nil, onerror(sock, "unget", unget_errno1)
257+
end
258+
local unget_ok2, unget_errno2 = sock:unget(first_2)
259+
if not unget_ok2 then
260+
return nil, onerror(sock, "unget", unget_errno2)
261+
end
262+
if errno == ce.ETIMEDOUT then
263+
local timeout = deadline and deadline-monotime()
264+
if cqueues.poll(sock, timeout) ~= timeout then
265+
-- retry
266+
return read_frame(sock, deadline)
267+
end
268+
elseif err == nil then
217269
sock:seterror("r", ce.EILSEQ)
218-
return nil, ce.strerror(ce.EILSEQ), ce.EILSEQ
270+
return nil, onerror(sock, "read_frame", ce.EILSEQ)
219271
end
220272
return nil, err, errno
221273
end
222-
frame.length = sunpack(">I8", length)
223274
end
224275

225276
if frame.MASK then
226-
local key, err, errno = sock:xread(4, "b", deadline and (deadline-monotime()))
227-
if not key or #key ~= 4 then
228-
if err == nil then
229-
sock:seterror("r", ce.EILSEQ)
230-
return nil, ce.strerror(ce.EILSEQ), ce.EILSEQ
231-
end
232-
return nil, err, errno
233-
end
277+
local key = assert(sock:xread(4, "b", 0))
234278
frame.key = { key:byte(1, 4) }
235279
end
236280

237281
do
238-
local data, err, errno = sock:xread(frame.length, "b", deadline and (deadline-monotime()))
239-
if data == nil or #data ~= frame.length then
240-
if err == nil then
241-
sock:seterror("r", ce.EILSEQ)
242-
return nil, ce.strerror(ce.EILSEQ), ce.EILSEQ
243-
end
244-
return nil, err, errno
245-
end
246-
282+
local data = assert(sock:xread(frame.length, "b", 0))
247283
if frame.MASK then
248284
frame.data = apply_mask(data, frame.key)
249285
else
@@ -267,9 +303,9 @@ end
267303

268304
function websocket_methods:send_frame(frame, timeout)
269305
if self.readyState < 1 then
270-
return nil, ce.strerror(ce.ENOTCONN), ce.ENOTCONN
306+
return nil, onerror(self.socket, "send_frame", ce.ENOTCONN)
271307
elseif self.readyState > 2 then
272-
return nil, ce.strerror(ce.EPIPE), ce.EPIPE
308+
return nil, onerror(self.socket, "send_frame", ce.EPIPE)
273309
end
274310
local ok, err, errno = self.socket:xwrite(build_frame(frame), "bn", timeout)
275311
if not ok then
@@ -349,9 +385,9 @@ end
349385

350386
function websocket_methods:receive(timeout)
351387
if self.readyState < 1 then
352-
return nil, ce.strerror(ce.ENOTCONN), ce.ENOTCONN
388+
return nil, onerror(self.socket, "receive", ce.ENOTCONN)
353389
elseif self.readyState > 2 then
354-
return nil, ce.strerror(ce.EPIPE), ce.EPIPE
390+
return nil, onerror(self.socket, "receive", ce.EPIPE)
355391
end
356392
local deadline = timeout and (monotime()+timeout)
357393
while true do
@@ -638,6 +674,7 @@ local function handle_websocket_response(self, headers, stream)
638674
-- Success!
639675
assert(self.socket == nil, "websocket:connect called twice")
640676
self.socket = assert(stream.connection:take_socket())
677+
self.socket:onerror(onerror)
641678
self.request = nil
642679
self.headers = headers
643680
self.readyState = 1
@@ -776,6 +813,7 @@ function websocket_methods:accept(options, timeout)
776813
end
777814

778815
self.socket = assert(self.stream.connection:take_socket())
816+
self.socket:onerror(onerror)
779817
self.stream = nil
780818
self.readyState = 1
781819
self.protocol = chosen_protocol

spec/websocket_spec.lua

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -169,14 +169,18 @@ describe("http.websocket", function()
169169
end)
170170
end)
171171
describe("http.websocket module two sided tests", function()
172+
local onerror = require "http.connection_common".onerror
172173
local server = require "http.server"
173174
local util = require "http.util"
174175
local websocket = require "http.websocket"
175176
local cqueues = require "cqueues"
176177
local ca = require "cqueues.auxlib"
178+
local ce = require "cqueues.errno"
177179
local cs = require "cqueues.socket"
178180
local function new_pair()
179181
local s, c = ca.assert(cs.pair())
182+
s:onerror(onerror)
183+
c:onerror(onerror)
180184
local ws_server = websocket.new("server")
181185
ws_server.socket = s
182186
ws_server.readyState = 1
@@ -201,6 +205,26 @@ describe("http.websocket module two sided tests", function()
201205
assert_loop(cq, TEST_TIMEOUT)
202206
assert.truthy(cq:empty())
203207
end)
208+
it("timeouts return nil, err, errno", function()
209+
local cq = cqueues.new()
210+
local c, s = new_pair()
211+
local ok, _, errno = c:receive(0)
212+
assert.same(nil, ok)
213+
assert.same(ce.ETIMEDOUT, errno)
214+
-- Check it still works afterwards
215+
cq:wrap(function()
216+
assert(c:send("hello"))
217+
assert.same("world", c:receive())
218+
assert(c:close())
219+
end)
220+
cq:wrap(function()
221+
assert.same("hello", s:receive())
222+
assert(s:send("world"))
223+
assert(s:close())
224+
end)
225+
assert_loop(cq, TEST_TIMEOUT)
226+
assert.truthy(cq:empty())
227+
end)
204228
it("doesn't fail when data contains a \\r\\n", function()
205229
local cq = cqueues.new()
206230
local c, s = new_pair()

0 commit comments

Comments
 (0)