Skip to content

Commit c6ce4e0

Browse files
author
Sebastien Stormacq
committed
Merge branch 'sebsto/fix_584' into sebsto/multiple_continuations
2 parents 7979b4c + 0cd73da commit c6ce4e0

File tree

2 files changed

+75
-33
lines changed

2 files changed

+75
-33
lines changed

Sources/AWSLambdaRuntime/Lambda+LocalServer.swift

Lines changed: 47 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -393,35 +393,50 @@ internal struct LambdaHTTPServer {
393393
self.invocationPool.push(LocalServerInvocation(requestId: requestId, request: body))
394394

395395
// wait for the lambda function to process the request
396-
for try await response in self.responsePool {
397-
logger[metadataKey: "response requestId"] = "\(response.requestId ?? "nil")"
398-
logger.trace("Received response to return to client")
399-
if response.requestId == requestId {
400-
logger.trace("/invoke requestId is valid, sending the response")
401-
// send the response to the client
402-
// if the response is final, we can send it and return
403-
// if the response is not final, we can send it and wait for the next response
404-
try await self.sendResponse(response, outbound: outbound, logger: logger)
405-
if response.final == true {
406-
logger.trace("/invoke returning")
407-
return // if the response is final, we can return and close the connection
396+
// when POST /invoke is called multiple times before a response is processed,
397+
// the `for try await ... in` loop will throw an error and we will return a 400 error to the client
398+
do {
399+
for try await response in self.responsePool {
400+
logger[metadataKey: "response requestId"] = "\(response.requestId ?? "nil")"
401+
logger.trace("Received response to return to client")
402+
if response.requestId == requestId {
403+
logger.trace("/invoke requestId is valid, sending the response")
404+
// send the response to the client
405+
// if the response is final, we can send it and return
406+
// if the response is not final, we can send it and wait for the next response
407+
try await self.sendResponse(response, outbound: outbound, logger: logger)
408+
if response.final == true {
409+
logger.trace("/invoke returning")
410+
return // if the response is final, we can return and close the connection
411+
}
412+
} else {
413+
logger.error(
414+
"Received response for a different requestId",
415+
metadata: ["response requestId": "\(response.requestId ?? "")"]
416+
)
417+
let response = LocalServerResponse(
418+
id: requestId,
419+
status: .badRequest,
420+
body: ByteBuffer(string: "The responseId is not equal to the requestId.")
421+
)
422+
try await self.sendResponse(response, outbound: outbound, logger: logger)
408423
}
409-
} else {
410-
logger.error(
411-
"Received response for a different requestId",
412-
metadata: ["response requestId": "\(response.requestId ?? "")"]
413-
)
414-
let response = LocalServerResponse(
415-
id: requestId,
416-
status: .badRequest,
417-
body: ByteBuffer(string: "The responseId is not equal to the requestId.")
418-
)
419-
try await self.sendResponse(response, outbound: outbound, logger: logger)
420424
}
425+
// What todo when there is no more responses to process?
426+
// This should not happen as the async iterator blocks until there is a response to process
427+
fatalError("No more responses to process - the async for loop should not return")
428+
} catch is LambdaHTTPServer.Pool<LambdaHTTPServer.LocalServerResponse>.PoolError {
429+
// detect concurrent invocations of POST and gently decline the requests while we're processing one.
430+
let response = LocalServerResponse(
431+
id: requestId,
432+
status: .badRequest,
433+
body: ByteBuffer(
434+
string:
435+
"It is not allowed to invoke multiple Lambda function executions in parallel. (The Lambda runtime environment on AWS will never do that)"
436+
)
437+
)
438+
try await self.sendResponse(response, outbound: outbound, logger: logger)
421439
}
422-
// What todo when there is no more responses to process?
423-
// This should not happen as the async iterator blocks until there is a response to process
424-
fatalError("No more responses to process - the async for loop should not return")
425440

426441
// client uses incorrect HTTP method
427442
case (_, let url) where url.hasSuffix(self.invocationEndpoint):
@@ -606,19 +621,18 @@ internal struct LambdaHTTPServer {
606621
if let nextAction = state.actionQueue.popFirst() {
607622
return (nextAction, continuation)
608623
} else {
609-
// there is no continuation and no action waiting,
610-
// enqueue the continuation for later usage
611-
state.continuationQueue.append(continuation)
612-
return (nil, continuation)
624+
state = .continuation(continuation)
625+
return nil
613626
}
627+
628+
case .continuation(_):
629+
fatalError("\(self.poolName) : Concurrent invocations to next(). This is not allowed.")
614630
}
615631
}
616632

617-
// there is no next action, ignore
618633
guard let nextAction else { return }
619634

620-
// we have a next action and a next continuation, resume it
621-
nextContinuation.resume(returning: nextAction)
635+
continuation.resume(returning: nextAction)
622636
}
623637
} onCancel: {
624638
self.lock.withLock { state in

Tests/AWSLambdaRuntimeTests/PoolTests.swift

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -158,4 +158,32 @@ struct PoolTests {
158158
#expect(receivedValues.count == producerCount * messagesPerProducer)
159159
#expect(Set(receivedValues).count == producerCount * messagesPerProducer)
160160
}
161+
162+
@Test
163+
@available(LambdaSwift 2.0, *)
164+
func testConcurrentNext() async throws {
165+
let pool = LambdaHTTPServer.Pool<String>()
166+
167+
// Create two tasks that will both wait for elements to be available
168+
await #expect(throws: LambdaHTTPServer.Pool<Swift.String>.PoolError.self) {
169+
try await withThrowingTaskGroup(of: Void.self) { group in
170+
171+
// one of the two task will throw a PoolError
172+
173+
group.addTask {
174+
for try await _ in pool {
175+
}
176+
Issue.record("Loop 1 should not complete")
177+
}
178+
179+
group.addTask {
180+
for try await _ in pool {
181+
}
182+
Issue.record("Loop 2 should not complete")
183+
}
184+
try await group.waitForAll()
185+
}
186+
}
187+
}
188+
161189
}

0 commit comments

Comments
 (0)