Skip to content

Commit f9aa617

Browse files
authored
Fix for deadlock-y condition on error + close call (#20)
The root problem is that in our `eof` call on an `SSLStream`, we're currently grabbing the `SSLStream` lock for basically the entire `eof` call, _including_ when we do a potentially long-blocking call to `eof` on the underlying socket! This is bad because it basically prevents force-closing the `SSLStream` until the socket `eof` call returns, and we can finally get the `SSLStream` lock in our `close` call. The fix proposed here is introducing a 2nd `eoflock` and renaming `lock` to `closelock`. `closelock` is used (hopefully obviously) just for protecting the close operation. `eoflock` is used to bundle the calling of `eof` on the underlying socket and calling `SSL_peek` to process bytes, so they're observed by concurrent tasks as a single operation.
1 parent 0e5a1e2 commit f9aa617

File tree

2 files changed

+115
-71
lines changed

2 files changed

+115
-71
lines changed

src/ssl.jl

Lines changed: 101 additions & 66 deletions
Original file line numberDiff line numberDiff line change
@@ -387,6 +387,17 @@ macro atomicset(ex)
387387
end
388388
end
389389

390+
macro atomiccas(ex, cmp, val)
391+
@static if VERSION < v"1.7"
392+
return esc(quote
393+
_ret = Threads.atomic_cas!($ex, $cmp, $val)
394+
(; success=(_ret === $cmp))
395+
end)
396+
else
397+
return esc(:(@atomicreplace $ex $cmp => $val))
398+
end
399+
end
400+
390401
"""
391402
SSLStream.
392403
"""
@@ -396,9 +407,14 @@ mutable struct SSLStream <: IO
396407
rbio::BIO
397408
wbio::BIO
398409
io::TCPSocket
399-
lock::ReentrantLock
410+
# used in `eof` where we want the call to `eof` on the underlying
411+
# socket and the SSL_peek call that processes bytes to be seen
412+
# as one "operation"
413+
eoflock::ReentrantLock
400414
readbytes::Base.RefValue{Csize_t}
401415
writebytes::Base.RefValue{Csize_t}
416+
peekbuf::Base.RefValue{UInt8}
417+
peekbytes::Base.RefValue{Csize_t}
402418
@static if VERSION < v"1.7"
403419
close_notify_received::Threads.Atomic{Bool}
404420
closed::Threads.Atomic{Bool}
@@ -414,9 +430,9 @@ end
414430
ssl = SSL(ssl_context, bio_read, bio_write)
415431

416432
@static if VERSION < v"1.7"
417-
return new(ssl, ssl_context, bio_read, bio_write, io, ReentrantLock(), Ref{Csize_t}(0), Ref{Csize_t}(0), Threads.Atomic{Bool}(false), Threads.Atomic{Bool}(false))
433+
return new(ssl, ssl_context, bio_read, bio_write, io, ReentrantLock(), Ref{Csize_t}(0), Ref{Csize_t}(0), Ref{UInt8}(0x00), Ref{Csize_t}(0), Threads.Atomic{Bool}(false), Threads.Atomic{Bool}(false))
418434
else
419-
return new(ssl, ssl_context, bio_read, bio_write, io, ReentrantLock(), Ref{Csize_t}(0), Ref{Csize_t}(0), false, false)
435+
return new(ssl, ssl_context, bio_read, bio_write, io, ReentrantLock(), Ref{Csize_t}(0), Ref{Csize_t}(0), Ref{UInt8}(0x00), Ref{Csize_t}(0), false, false)
420436
end
421437
end
422438
end
@@ -430,39 +446,54 @@ SSLStream(tcp::TCPSocket) = SSLStream(SSLContext(OpenSSL.TLSClientMethod()), tcp
430446
Base.isreadable(ssl::SSLStream)::Bool = !(@atomicget(ssl.close_notify_received))
431447
Base.isopen(ssl::SSLStream)::Bool = !@atomicget(ssl.closed)
432448
Base.iswritable(ssl::SSLStream)::Bool = isopen(ssl) && isopen(ssl.io)
433-
check_isopen(ssl::SSLStream, op) = isopen(ssl) || throw(Base.IOError("$op requires ssl to be open", 0))
449+
@noinline throwio(op) = throw(Base.IOError("$op requires ssl to be open", 0))
434450

435-
macro geterror(ssl, expr)
436-
quote
451+
# this is a macro, but should be a function, but closures are stupid slow
452+
# we use this to standardize the error handling for all of the SSL_*_ex functions
453+
macro geterror(ssl, op, expr)
454+
esc(quote
455+
# clear the current error queue before openssl ccall
437456
clear_errors!()
438-
ret = $(esc(expr))
457+
# last check that SSL is still open before ccall
458+
isopen($ssl) || throwio($op)
459+
# do the ccall
460+
ret = $expr
461+
# we want to return one of our SSL return codes, regardless of error
462+
# SSL_peek_ex, SSL_write_ex, SSL_connect, and SSL_read_ex all return 1 on success
439463
if ret == 1
440464
ret = SSL_ERROR_NONE
441465
else
442-
ssl = $(esc(ssl))
443-
err = get_error(ssl.ssl, ret)
466+
err = get_error($ssl.ssl, ret)
444467
if err == SSL_ERROR_ZERO_RETURN
445-
@atomicset ssl.close_notify_received = true
468+
# the peer sent a close_notify, so no more reading is possible
469+
@atomicset $ssl.close_notify_received = true
446470
elseif err == SSL_ERROR_NONE
447-
# pass
471+
ret = SSL_ERROR_NONE
448472
elseif err == SSL_ERROR_WANT_READ
473+
# we need to read more data from the underlying socket
449474
ret = SSL_ERROR_WANT_READ
450475
elseif err == SSL_ERROR_WANT_WRITE
476+
# we need to write more data to the underlying socket
477+
# we don't expect to ever see this since we set up our SSL
478+
# to do auto TLS (re)negotiation
451479
ret = SSL_ERROR_WANT_WRITE
452480
else
453-
close(ssl, false)
481+
# this is usually some other kind of error, like a protocol error
482+
# or OS-level IO error, just close the SSL connection and throw
483+
# notably, the openssl docs say we should *not* call ssl_disconnect
484+
# in this case, hence the `false` arg to close
485+
close($ssl, false)
454486
throw(Base.IOError(OpenSSLError(err).msg, 0))
455487
end
456488
end
457489
ret
458-
end
490+
end)
459491
end
460492

461493
function Base.unsafe_write(ssl::SSLStream, in_buffer::Ptr{UInt8}, in_length::UInt)
462-
check_isopen(ssl, "unsafe_write")
463494
nwritten = 0
464495
while nwritten < in_length
465-
ret = @geterror ssl ccall(
496+
ret = @geterror ssl :unsafe_write ccall(
466497
(:SSL_write_ex, libssl),
467498
Cint,
468499
(SSL, Ptr{Cvoid}, Cint, Ptr{Csize_t}),
@@ -480,14 +511,13 @@ end
480511

481512
function Sockets.connect(ssl::SSLStream; require_ssl_verification::Bool=true)
482513
while true
483-
check_isopen(ssl, "connect")
484-
ret = @geterror ssl ssl_connect(ssl.ssl)
514+
ret = @geterror ssl :connect ssl_connect(ssl.ssl)
485515
if ret == SSL_ERROR_NONE
486516
break
487517
elseif ret == SSL_ERROR_WANT_READ
488-
if eof(ssl.io)
489-
throw(EOFError())
490-
end
518+
# this means connect is waiting for more data from the underlying socket
519+
# so call eof on the socket to wait for more bytes to come in
520+
eof(ssl.io) && throw(EOFError())
491521
else
492522
throw(Base.IOError(OpenSSLError(ret).msg, 0))
493523
end
@@ -511,7 +541,9 @@ function Sockets.connect(ssl::SSLStream; require_ssl_verification::Bool=true)
511541
cert === nothing && throw(OpenSSLError("No peer certificate"))
512542
end
513543

514-
# set read ahead
544+
# set read ahead; this is a recommended optimization when we can guarantee
545+
# that an SSL connection will only ever be read from sequentially, which we do
546+
# by not doing any internal buffering
515547
ccall(
516548
(:SSL_set_read_ahead, libssl),
517549
Cvoid,
@@ -525,7 +557,6 @@ const SSL_CTRL_SET_TLSEXT_HOSTNAME = 55
525557
const TLSEXT_NAMETYPE_host_name = 0
526558

527559
function hostname!(ssl::SSLStream, host)
528-
# SSL_set_tlsext_host_name
529560
if (ret = ccall(
530561
(:SSL_ctrl, libssl),
531562
Cint,
@@ -545,12 +576,9 @@ end
545576
"""
546577
function Base.unsafe_read(ssl::SSLStream, buf::Ptr{UInt8}, nbytes::UInt)
547578
nread = 0
579+
readbytes = ssl.readbytes
548580
while nread < nbytes
549-
# If open, optimistically call `SSL_read_ex` to try to save an `eof` call;
550-
# if that returns `SSL_WANT_READ` we will call `eof` anyway afterwards.
551-
!isopen(ssl) && throw(EOFError())
552-
readbytes = ssl.readbytes
553-
ret = @geterror ssl ccall(
581+
ret = @geterror ssl :unsafe_read ccall(
554582
(:SSL_read_ex, libssl),
555583
Cint,
556584
(SSL, Ptr{UInt8}, Csize_t, Ptr{Csize_t}),
@@ -571,37 +599,52 @@ end
571599
function Base.readavailable(ssl::SSLStream)
572600
N = bytesavailable(ssl)
573601
buf = Vector{UInt8}(undef, N)
574-
n = unsafe_read(ssl, pointer(buf), N)
602+
n = GC.@preserve buf unsafe_read(ssl, pointer(buf), N)
575603
return resize!(buf, n)
576604
end
577605

606+
# returns the # of bytes that can be read immediately via unsafe_read
607+
# i.e. # of processes, decrypted bytes available
578608
function Base.bytesavailable(ssl::SSLStream)::Cint
579609
isopen(ssl) || return 0
580-
pending_count = ccall(
610+
return Int(ccall(
581611
(:SSL_pending, libssl),
582612
Cint,
583613
(SSL,),
584-
ssl.ssl)
585-
return pending_count
614+
ssl.ssl))
586615
end
587616

588-
function haspending(s::SSLStream)
589-
isopen(s) || return false
590-
has_pending = ccall(
617+
# returns whether there are _any_ bytes buffered, processed
618+
# or unprocessed, in the SSL stream
619+
function haspending(ssl::SSLStream)
620+
isopen(ssl) || return false
621+
return 1 == ccall(
591622
(:SSL_has_pending, libssl),
592623
Cint,
593624
(SSL,),
594-
s.ssl)
595-
return has_pending == 1
625+
ssl.ssl)
596626
end
597627

598628
function Base.eof(ssl::SSLStream)::Bool
629+
isopen(ssl) || return true
599630
bytesavailable(ssl) > 0 && return false
600-
Base.@lock ssl.lock begin
601-
# check if we're open inside the lock in case ssl got closed
602-
# in `close` while we were waiting for the lock
603-
isopen(ssl) || return true
604-
while isreadable(ssl) && bytesavailable(ssl) <= 0
631+
peekbuf = ssl.peekbuf
632+
peekbytes = ssl.peekbytes
633+
while isreadable(ssl)
634+
# note that care needs to be taken here to avoid a potential bad
635+
# race condition; for SSLStream, we have to manage the state of
636+
# the underlying socket having available bytes *and* whether they've
637+
# been processed in the ssl layer, so we want to treat the receiving and processing
638+
# of bytes as a single operation; in other words, bytesavailable returns
639+
# > 0 when bytes have been received *and* processed and we don't want
640+
# racing tasks to get stuck in between. We also don't really care whether
641+
# tasks are blocked calling eof on the socket or waiting on eoflock, so
642+
# we avoid the races and keep things orderly by only allowing one task
643+
# to make the eof call and kick off byte processing at a time.
644+
Base.@lock ssl.eoflock begin
645+
# check condition now that we have eoflock since another task may have
646+
# succeeded in getting bytes processed
647+
bytesavailable(ssl) > 0 && return false
605648
# no processed bytes available, check if there are unprocessed bytes
606649
if !haspending(ssl)
607650
# no unprocessed bytes, call eof to get more unprocessed
@@ -610,20 +653,19 @@ function Base.eof(ssl::SSLStream)::Bool
610653
return true
611654
end
612655
end
613-
# if we're here, we know there are unprocessed bytes,
614-
# so we call peek to force processing
615-
byte = Ref{UInt8}(0x00)
616-
ptr = Base.unsafe_convert(Ptr{UInt8}, byte)
617-
GC.@preserve byte begin
618-
ret = @geterror ssl ccall(
619-
(:SSL_peek, libssl),
620-
Cint,
621-
(SSL, Ptr{UInt8}, Cint),
622-
ssl.ssl,
623-
ptr,
624-
1
625-
)
626-
end
656+
# at this point, we know there are at least unprocessed bytes
657+
# buffered, so we call SSL_peek to get the next record processed,
658+
# which still might not result in bytesavailable > 0
659+
ret = @geterror ssl :peek ccall(
660+
(:SSL_peek_ex, libssl),
661+
Cint,
662+
(SSL, Ptr{UInt8}, Cint, Ptr{Csize_t}),
663+
ssl.ssl,
664+
peekbuf,
665+
1,
666+
peekbytes
667+
)
668+
ret == SSL_ERROR_NONE && return false
627669
# if we get WANT_READ back, that means there were pending bytes
628670
# to be processed, but not a full record, so we need to wait
629671
# for additional bytes to come in before we can process
@@ -638,23 +680,16 @@ end
638680
Close SSL stream.
639681
"""
640682
function Base.close(ssl::SSLStream, shutdown::Bool=true)
641-
# eager unconditional closed set so other concurrent operations see it immediately
642-
@atomicset ssl.closed = true
643-
# if we've already finalized, no further action needed
644-
ssl.ssl.ssl == C_NULL && return
645-
# close operations
646-
Base.@lock ssl.lock begin
647-
# we do an additional check once inside the lock in case
648-
# it was closed while we were waiting on the lock
649-
isopen(ssl) || return
650-
# Ignore the disconnect result.
651-
shutdown && ssl_disconnect(ssl.ssl)
683+
if @atomiccas(ssl.closed, false, true).success
684+
# we won the race to close the ssl
652685
# close underlying io
653686
try
654687
Base.close(ssl.io)
655688
catch e
656689
e isa Base.IOError || rethrow()
657690
end
691+
# Ignore the disconnect result.
692+
shutdown && ssl_disconnect(ssl.ssl)
658693
free(ssl.ssl)
659694
end
660695
return

test/runtests.jl

Lines changed: 14 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -175,23 +175,22 @@ end
175175
end
176176

177177
@testset "HttpsConnect" begin
178-
tcp_stream = connect("www.nghttp2.org", 443)
178+
tcp_stream = connect("httpbingo.julialang.org", 443)
179179

180180
ssl_ctx = OpenSSL.SSLContext(OpenSSL.TLSClientMethod())
181181
result = OpenSSL.ssl_set_options(ssl_ctx, OpenSSL.SSL_OP_NO_COMPRESSION)
182182

183183
# Create SSL stream.
184184
ssl = SSLStream(ssl_ctx, tcp_stream, tcp_stream)
185185

186-
#TODO expose connect
187186
OpenSSL.connect(ssl)
188187

189188
x509_server_cert = OpenSSL.get_peer_certificate(ssl)
190189

191190
@test String(x509_server_cert.issuer_name) == "/C=US/O=Let's Encrypt/CN=R3"
192-
@test String(x509_server_cert.subject_name) == "/CN=nghttp2.org"
191+
@test String(x509_server_cert.subject_name) == "/CN=httpbingo.julialang.org"
193192

194-
request_str = "GET / HTTP/1.1\r\nHost: www.nghttp2.org\r\nUser-Agent: curl\r\nAccept: */*\r\n\r\n"
193+
request_str = "GET /status/200 HTTP/1.1\r\nHost: httpbingo.julialang.org\r\nUser-Agent: curl\r\nAccept: */*\r\n\r\n"
195194

196195
written = write(ssl, request_str)
197196

@@ -201,8 +200,18 @@ end
201200
write(io, readavailable(ssl))
202201
response = String(take!(io))
203202
@test startswith(response, "HTTP/1.1 200 OK\r\n")
204-
close(ssl)
203+
sleep(2)
204+
@test isempty(readavailable(ssl))
205+
# start a bunch of tasks all racing to call eof
206+
tasks = [@async(eof(ssl)) for _ = 1:100]
207+
yield()
208+
@test all(t -> !istaskdone(t), tasks)
209+
closetasks = [@async(close(ssl)) for _ = 1:100]
210+
yield()
211+
sleep(2)
205212
finalize(ssl_ctx)
213+
@test all(t -> istaskdone(t), tasks)
214+
@test all(t -> istaskdone(t), closetasks)
206215
end
207216

208217
@testset "ClosedStream" begin

0 commit comments

Comments
 (0)