Skip to content

Commit 38e321a

Browse files
Haibo Yanbillyean
authored andcommitted
Bugfix: Closing PostgreSQL connection deadlocks listen(_:)
1 parent db1eae1 commit 38e321a

File tree

5 files changed

+116
-2
lines changed

5 files changed

+116
-2
lines changed

Sources/PostgresNIO/Connection/PostgresConnection.swift

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -455,7 +455,14 @@ extension PostgresConnection {
455455

456456
let task = HandlerTask.startListening(listener)
457457

458-
self.channel.write(task, promise: nil)
458+
let promise = self.channel.eventLoop.makePromise(of: Void.self)
459+
promise.futureResult.whenFailure { error in
460+
self.logger.debug("Channel error in listen()",
461+
metadata: [.error: "\(error)"])
462+
listener.failed(PSQLError(code: .listenFailed))
463+
}
464+
465+
self.channel.write(task, promise: promise)
459466
}
460467
} onCancel: {
461468
let task = HandlerTask.cancelListening(channel, id)

Sources/PostgresNIO/New/Connection State Machine/ListenStateMachine.swift

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -224,7 +224,7 @@ extension ListenStateMachine {
224224
mutating func fail(_ error: Error) -> FailAction {
225225
switch self.state {
226226
case .initialized:
227-
fatalError("Invalid state: \(self.state)")
227+
return .none
228228

229229
case .starting(let listeners), .listening(let listeners), .stopping(let listeners):
230230
self.state = .failed(error)

Sources/PostgresNIO/New/NotificationListener.swift

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,17 @@ final class NotificationListener: @unchecked Sendable {
1717
case done
1818
}
1919

20+
deinit {
21+
switch self.state {
22+
case .streamInitialized:
23+
preconditionFailure("Notification continuation had not been used")
24+
case .closure:
25+
preconditionFailure("Notification closure had not been used")
26+
case .streamListening, .done:
27+
break
28+
}
29+
}
30+
2031
init(
2132
channel: String,
2233
id: Int,

Sources/PostgresNIO/New/PostgresChannelHandler.swift

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -209,6 +209,7 @@ final class PostgresChannelHandler: ChannelDuplexHandler {
209209
psqlTask = .extendedQuery(query)
210210

211211
case .startListening(let listener):
212+
defer { promise?.succeed(()) }
212213
switch self.listenState.startListening(listener) {
213214
case .startListening(let channel):
214215
psqlTask = self.makeStartListeningQuery(channel: channel, context: context)

Tests/IntegrationTests/AsyncTests.swift

Lines changed: 95 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -284,6 +284,101 @@ final class AsyncPostgresConnectionTests: XCTestCase {
284284
}
285285
}
286286

287+
func testListenTwiceChannel() async throws {
288+
let eventLoopGroup = MultiThreadedEventLoopGroup(numberOfThreads: 1)
289+
defer { XCTAssertNoThrow(try eventLoopGroup.syncShutdownGracefully()) }
290+
let eventLoop = eventLoopGroup.next()
291+
292+
try await self.withTestConnection(on: eventLoop) { connection in
293+
// Concurrently listen on a channel that is initially closed
294+
async let stream1later = connection.listen("same-channel")
295+
async let stream2later = connection.listen("same-channel")
296+
let (stream1, stream2) = try await (stream1later, stream2later)
297+
298+
try await self.withTestConnection(on: eventLoop) { other in
299+
try await other.query(#"NOTIFY "\#(unescaped: "same-channel")";"#, logger: .psqlTest)
300+
}
301+
302+
var stream1EventReceived = false
303+
var stream2EventReceived = false
304+
305+
for try await _ in stream1 {
306+
stream1EventReceived = true
307+
break
308+
}
309+
310+
for try await _ in stream2 {
311+
stream2EventReceived = true
312+
break
313+
}
314+
315+
XCTAssertTrue(stream1EventReceived)
316+
XCTAssertTrue(stream2EventReceived)
317+
}
318+
}
319+
320+
func testListenOnClosedChannel() async throws {
321+
let eventLoopGroup = MultiThreadedEventLoopGroup(numberOfThreads: 1)
322+
defer { XCTAssertNoThrow(try eventLoopGroup.syncShutdownGracefully()) }
323+
let eventLoop = eventLoopGroup.next()
324+
325+
try await self.withTestConnection(on: eventLoop) { connection in
326+
try await connection.close()
327+
do {
328+
_ = try await connection.listen("futile")
329+
XCTFail("Expected not to get any events")
330+
} catch let error as PSQLError where error.code == .listenFailed {
331+
// Expected
332+
}
333+
}
334+
}
335+
336+
func testListenThenCloseChannel() async throws {
337+
let eventLoopGroup = MultiThreadedEventLoopGroup(numberOfThreads: 1)
338+
defer { XCTAssertNoThrow(try eventLoopGroup.syncShutdownGracefully()) }
339+
let eventLoop = eventLoopGroup.next()
340+
341+
try await self.withTestConnection(on: eventLoop) { connection in
342+
let stream = try await connection.listen("hopeful")
343+
try await connection.close()
344+
do {
345+
for try await _ in stream {
346+
XCTFail("Expected not to get any events")
347+
}
348+
XCTFail("Expected not to have reached the end of stream")
349+
} catch is PSQLError {
350+
// Expected
351+
}
352+
}
353+
}
354+
355+
func testListenThenClosingChannel() async throws {
356+
let eventLoopGroup = MultiThreadedEventLoopGroup(numberOfThreads: 1)
357+
defer { XCTAssertNoThrow(try eventLoopGroup.syncShutdownGracefully()) }
358+
let eventLoop = eventLoopGroup.next()
359+
360+
try await self.withTestConnection(on: eventLoop) { connection in
361+
_ = try await connection.listen("initial")
362+
async let asyncClose: () = connection.close()
363+
let stream: PostgresNotificationSequence
364+
do {
365+
stream = try await connection.listen("hopeful")
366+
} catch let error as PSQLError where error.code == .listenFailed {
367+
// Expected
368+
return
369+
}
370+
try await asyncClose
371+
do {
372+
for try await _ in stream {
373+
XCTFail("Expected not to get any events")
374+
}
375+
XCTFail("Expected not to have reached the end of stream")
376+
} catch is PSQLError {
377+
// Expected
378+
}
379+
}
380+
}
381+
287382
#if canImport(Network)
288383
func testSelect10kRowsNetworkFramework() async throws {
289384
let eventLoopGroup = NIOTSEventLoopGroup()

0 commit comments

Comments
 (0)