Skip to content

Commit 5a3805d

Browse files
authored
Merge pull request #28 from JuliaComputing/tan/misc
fix race condition in basic_consume
2 parents beb6f34 + f11a599 commit 5a3805d

File tree

2 files changed

+77
-56
lines changed

2 files changed

+77
-56
lines changed

src/protocol.jl

Lines changed: 37 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -210,8 +210,11 @@ mutable struct MessageConsumer
210210
callback::Function
211211
receiver::Task
212212

213-
function MessageConsumer(chan_id::TAMQPChannel, consumer_tag::String, callback::Function; buffer_size::Int=typemax(Int))
214-
c = new(chan_id, consumer_tag, Channel{Message}(buffer_size), callback)
213+
function MessageConsumer(chan_id::TAMQPChannel, consumer_tag::String, callback::Function;
214+
buffer_size::Int=typemax(Int),
215+
buffer::Channel{Message}=Channel{Message}(buffer_size))
216+
217+
c = new(chan_id, consumer_tag, buffer, callback)
215218
c.receiver = @async connection_processor(c, "Consumer $consumer_tag", channel_message_consumer)
216219
c
217220
end
@@ -232,14 +235,16 @@ mutable struct MessageChannel <: AbstractChannel
232235
partial_msgs::Vector{Message} # holds partial messages while they are getting read (message bodies arrive in sequence)
233236
chan_get::Channel{Union{Message, Nothing}} # channel used for received messages, in sync get call (TODO: maybe type more strongly?)
234237
consumers::Dict{String,MessageConsumer}
238+
pending_msgs::Dict{String,Channel{Message}} # holds messages received that do not have a consumer registered
239+
lck::ReentrantLock
235240

236241
closereason::Union{CloseReason, Nothing}
237242

238243
function MessageChannel(id, conn)
239244
new(id, conn, CONN_STATE_CLOSED, true,
240245
Channel{TAMQPGenericFrame}(CONN_MAX_QUEUED), nothing, Dict{Tuple,Tuple{Function,Any}}(),
241246
Message[], Channel{Union{Message, Nothing}}(1), Dict{String,MessageConsumer}(),
242-
nothing)
247+
Dict{String,Channel{Message}}(), ReentrantLock(), nothing)
243248
end
244249
end
245250

@@ -422,7 +427,7 @@ end
422427
clear_handlers(c::MessageChannel) = (empty!(c.callbacks); nothing)
423428
function handle(c::MessageChannel, classname::Symbol, methodname::Symbol, cb=nothing, ctx=nothing)
424429
cbkey = method_key(classname, methodname)
425-
if cb == nothing
430+
if cb === nothing
426431
delete!(c.callbacks, cbkey)
427432
else
428433
c.callbacks[cbkey] = (cb, ctx)
@@ -431,7 +436,7 @@ function handle(c::MessageChannel, classname::Symbol, methodname::Symbol, cb=not
431436
end
432437
function handle(c::MessageChannel, frame_type::Integer, cb=nothing, ctx=nothing)
433438
cbkey = frame_key(frame_type)
434-
if cb == nothing
439+
if cb === nothing
435440
delete!(c.callbacks, cbkey)
436441
else
437442
c.callbacks[cbkey] = (cb, ctx)
@@ -779,15 +784,28 @@ nowait: do not send a reply method
779784
"""
780785
function basic_consume(chan::MessageChannel, queue::String, consumer_fn::Function; consumer_tag::String="", no_local::Bool=false, no_ack::Bool=false,
781786
exclusive::Bool=false, nowait::Bool=false, arguments::Dict{String,Any}=Dict{String,Any}(), timeout::Int=DEFAULT_TIMEOUT, buffer_sz::Int=typemax(Int))
787+
788+
# register the consumer and get the consumer_tag
782789
result = _wait_resp(chan, (true, ""), nowait, on_basic_consume_ok, :Basic, :ConsumeOk, (false, ""), timeout) do
783790
send_basic_consume(chan, queue, consumer_tag, no_local, no_ack, exclusive, nowait, arguments)
784791
end
785792

786-
# setup a message consumer
793+
# start the message consumer
787794
if result[1]
788795
consumer_tag = result[2]
789-
chan.consumers[consumer_tag] = MessageConsumer(chan.id, consumer_tag, consumer_fn; buffer_size=buffer_sz)
796+
797+
# set up message buffer beforehand to store messages that the consumer may receive while we are still setting things up,
798+
# or get the buffer that was set up already because we received messages
799+
lock(chan.lck) do
800+
consumer_buffer = get!(chan.pending_msgs, consumer_tag) do
801+
Channel{Message}(buffer_sz)
802+
end
803+
consumer_buffer.sz_max = buffer_sz
804+
chan.consumers[consumer_tag] = MessageConsumer(chan.id, consumer_tag, consumer_fn; buffer=consumer_buffer)
805+
delete!(chan.pending_msgs, consumer_tag)
806+
end
790807
end
808+
791809
result
792810
end
793811

@@ -1177,15 +1195,15 @@ function on_basic_get_empty_or_ok(chan::MessageChannel, m::TAMQPMethodFrame, ctx
11771195
end
11781196

11791197
function on_channel_message_in(chan::MessageChannel, m::TAMQPContentHeaderFrame, ctx)
1180-
msg = chan.partial_msgs[1]
1198+
msg = last(chan.partial_msgs)
11811199
msg.properties = m.hdrpayload.proplist
11821200
msg.data = Vector{UInt8}(undef, m.hdrpayload.bodysize)
11831201
msg.filled = 0
11841202
nothing
11851203
end
11861204

11871205
function on_channel_message_in(chan::MessageChannel, m::TAMQPContentBodyFrame, ctx)
1188-
msg = chan.partial_msgs[1]
1206+
msg = last(chan.partial_msgs)
11891207
data = m.payload.data
11901208
startpos = msg.filled + 1
11911209
endpos = min(length(msg.data), msg.filled + length(data))
@@ -1195,11 +1213,16 @@ function on_channel_message_in(chan::MessageChannel, m::TAMQPContentBodyFrame, c
11951213
if msg.filled >= length(msg.data)
11961214
# got all data for msg
11971215
if isempty(msg.consumer_tag)
1198-
put!(chan.chan_get, popfirst!(chan.partial_msgs))
1199-
elseif msg.consumer_tag in keys(chan.consumers)
1200-
put!(chan.consumers[msg.consumer_tag].recvq, popfirst!(chan.partial_msgs))
1216+
put!(chan.chan_get, pop!(chan.partial_msgs))
12011217
else
1202-
@debug("discarding message, no consumer with tag", tag=msg.consumer_tag)
1218+
lock(chan.lck) do
1219+
if msg.consumer_tag in keys(chan.consumers)
1220+
put!(chan.consumers[msg.consumer_tag].recvq, pop!(chan.partial_msgs))
1221+
else
1222+
put!(get!(()->Channel{Message}(typemax(Int)), chan.pending_msgs, msg.consumer_tag), msg)
1223+
@debug("holding message, no consumer yet with tag", tag=msg.consumer_tag)
1224+
end
1225+
end
12031226
end
12041227
end
12051228

@@ -1210,6 +1233,4 @@ on_confirm_select_ok(chan::MessageChannel, m::TAMQPMethodFrame, ctx) = _on_ack(c
12101233

12111234
# ----------------------------------------
12121235
# send and recv for methods end
1213-
# ----------------------------------------
1214-
1215-
1236+
# ----------------------------------------

test/test_rpc.jl

Lines changed: 40 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -3,18 +3,17 @@ module AMQPTestRPC
33
using AMQPClient, Test, Random
44

55
const JULIA_HOME = Sys.BINDIR
6-
76
const QUEUE_RPC = "queue_rpc"
8-
9-
testlog(msg) = println(msg)
10-
reply_queue_id = 0
11-
server_id = 0
12-
137
const NRPC_MSGS = 100
148
const NRPC_CLNTS = 4
159
const NRPC_SRVRS = 4
10+
const server_lck = Ref(ReentrantLock())
11+
const queue_declared = Ref(false)
12+
const servers_done = Channel{Int}(NRPC_SRVRS)
13+
14+
testlog(msg) = println(msg)
1615

17-
function test_rpc_client(;virtualhost="/", host="localhost", port=AMQPClient.AMQP_DEFAULT_PORT, auth_params=AMQPClient.DEFAULT_AUTH_PARAMS)
16+
function test_rpc_client(reply_queue_id; virtualhost="/", host="localhost", port=AMQPClient.AMQP_DEFAULT_PORT, auth_params=AMQPClient.DEFAULT_AUTH_PARAMS)
1817
# open a connection
1918
testlog("client opening connection...")
2019
conn = connection(;virtualhost=virtualhost, host=host, port=port, auth_params=auth_params)
@@ -24,8 +23,6 @@ function test_rpc_client(;virtualhost="/", host="localhost", port=AMQPClient.AMQ
2423
chan1 = channel(conn, AMQPClient.UNUSED_CHANNEL, true)
2524

2625
# create a reply queue for a client
27-
global reply_queue_id
28-
reply_queue_id += 1
2926
queue_name = QUEUE_RPC * "_" * string(reply_queue_id) * "_" * string(getpid())
3027
testlog("client creating queue " * queue_name * "...")
3128
success, queue_name, message_count, consumer_count = queue_declare(chan1, queue_name; exclusive=true)
@@ -83,11 +80,7 @@ function test_rpc_client(;virtualhost="/", host="localhost", port=AMQPClient.AMQ
8380
testlog("client done.")
8481
end
8582

86-
function test_rpc_server(;virtualhost="/", host="localhost", port=AMQPClient.AMQP_DEFAULT_PORT, auth_params=AMQPClient.DEFAULT_AUTH_PARAMS)
87-
global server_id
88-
server_id += 1
89-
my_server_id = server_id
90-
83+
function test_rpc_server(my_server_id; virtualhost="/", host="localhost", port=AMQPClient.AMQP_DEFAULT_PORT, auth_params=AMQPClient.DEFAULT_AUTH_PARAMS)
9184
# open a connection
9285
testlog("server $my_server_id opening connection...")
9386
conn = connection(;virtualhost=virtualhost, host=host, port=port, auth_params=auth_params)
@@ -97,10 +90,15 @@ function test_rpc_server(;virtualhost="/", host="localhost", port=AMQPClient.AMQ
9790
chan1 = channel(conn, AMQPClient.UNUSED_CHANNEL, true)
9891

9992
# create queues (no need to bind if we are using the default exchange)
100-
testlog("server $my_server_id creating queues...")
101-
# this is the callback queue
102-
success, message_count, consumer_count = queue_declare(chan1, QUEUE_RPC)
103-
@test success
93+
lock(server_lck[]) do
94+
if !(queue_declared[])
95+
testlog("server $my_server_id creating queues...")
96+
# this is the callback queue
97+
success, message_count, consumer_count = queue_declare(chan1, QUEUE_RPC)
98+
@test success
99+
queue_declared[] = true
100+
end
101+
end
104102

105103
# test RPC
106104
testlog("server $my_server_id testing rpc...")
@@ -131,16 +129,24 @@ function test_rpc_server(;virtualhost="/", host="localhost", port=AMQPClient.AMQ
131129
end
132130

133131
testlog("server $my_server_id closing down...")
134-
success, message_count = queue_purge(chan1, QUEUE_RPC)
135-
@test success
136-
@test message_count == 0
137-
138132
@test basic_cancel(chan1, consumer_tag)
139-
140-
success, message_count = queue_delete(chan1, QUEUE_RPC)
141-
@test success
142-
@test message_count == 0
143-
testlog("server $my_server_id deleted rpc queue")
133+
testlog("server $my_server_id cancelled consumer...")
134+
135+
lock(server_lck[]) do
136+
take!(servers_done)
137+
# the last server to finish will purge and delete the queue
138+
if length(servers_done.data) == 0
139+
success, message_count = queue_purge(chan1, QUEUE_RPC)
140+
@test success
141+
@test message_count == 0
142+
testlog("server $my_server_id purged queue...")
143+
144+
success, message_count = queue_delete(chan1, QUEUE_RPC)
145+
@test success
146+
@test message_count == 0
147+
testlog("server $my_server_id deleted rpc queue")
148+
end
149+
end
144150

145151
# close channels and connection
146152
close(chan1)
@@ -157,29 +163,23 @@ end
157163

158164
function runtests()
159165
testlog("testing multiple client server rpc")
160-
clients = Vector{Task}()
161-
servers = Vector{Task}()
162166

163167
for idx in 1:NRPC_SRVRS
164-
push!(servers, @async test_rpc_server())
168+
put!(servers_done, idx)
165169
end
166170

167-
for idx in 1:NRPC_CLNTS
168-
push!(clients, @async test_rpc_client())
169-
end
170-
171-
tasks_active = NRPC_CLNTS + NRPC_SRVRS
172-
while tasks_active > 0
173-
tasks_active = 0
171+
@sync begin
174172
for idx in 1:NRPC_SRVRS
175-
istaskdone(servers[idx]) || (tasks_active += 1)
173+
@async test_rpc_server(idx)
176174
end
175+
177176
for idx in 1:NRPC_CLNTS
178-
istaskdone(clients[idx]) || (tasks_active += 1)
177+
@async test_rpc_client(idx)
179178
end
180-
sleep(5)
181179
end
180+
182181
testlog("done")
183182
end
183+
184184
end # module AMQPTestRPC
185185

0 commit comments

Comments
 (0)