Skip to content

Commit 977f787

Browse files
committed
fix: Clean shutdown for multithreaded unary Map
Signed-off-by: Sreekanth <prsreekanth920@gmail.com>
1 parent a66cdda commit 977f787

File tree

2 files changed

+48
-13
lines changed

2 files changed

+48
-13
lines changed

packages/pynumaflow/pynumaflow/mapper/_servicer/_sync_servicer.py

Lines changed: 41 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -2,11 +2,12 @@
22
from concurrent.futures import ThreadPoolExecutor
33
from collections.abc import Iterator
44

5+
import grpc
56
from google.protobuf import empty_pb2 as _empty_pb2
6-
from pynumaflow.shared.server import exit_on_error
7+
from pynumaflow.shared.server import update_context_err
78
from pynumaflow._metadata import _user_and_system_metadata_from_proto
89

9-
from pynumaflow._constants import NUM_THREADS_DEFAULT, STREAM_EOF, _LOGGER, ERR_UDF_EXCEPTION_STRING
10+
from pynumaflow._constants import NUM_THREADS_DEFAULT, _LOGGER, ERR_UDF_EXCEPTION_STRING
1011
from pynumaflow.mapper._dtypes import MapSyncCallable, Datum, MapError
1112
from pynumaflow.proto.mapper import map_pb2, map_pb2_grpc
1213
from pynumaflow.shared.synciter import SyncIterator
@@ -26,6 +27,8 @@ def __init__(self, handler: MapSyncCallable, multiproc: bool = False):
2627
self.multiproc = multiproc
2728
# create a thread pool for executing UDF code
2829
self.executor = ThreadPoolExecutor(max_workers=NUM_THREADS_DEFAULT)
30+
self.shutdown_event: threading.Event = threading.Event()
31+
self.error: BaseException | None = None
2932

3033
def MapFn(
3134
self,
@@ -36,6 +39,7 @@ def MapFn(
3639
Applies a function to each datum element.
3740
The pascal case function name comes from the proto map_pb2_grpc.py file.
3841
"""
42+
result_queue = None
3943
try:
4044
# The first message to be received should be a valid handshake
4145
req = next(request_iterator)
@@ -57,10 +61,13 @@ def MapFn(
5761
for res in result_queue.read_iterator():
5862
# if error handler accordingly
5963
if isinstance(res, BaseException):
60-
# Terminate the current server process due to exception
61-
exit_on_error(
62-
context, f"{ERR_UDF_EXCEPTION_STRING}: {repr(res)}", parent=self.multiproc
63-
)
64+
err_msg = f"{ERR_UDF_EXCEPTION_STRING}: {repr(res)}"
65+
_LOGGER.critical(err_msg, exc_info=True)
66+
update_context_err(context, res, err_msg)
67+
# Unblock the reader thread if it is waiting on queue.put()
68+
result_queue.close()
69+
self.error = res
70+
self.shutdown_event.set()
6471
return
6572
# return the result
6673
yield res
@@ -69,12 +76,22 @@ def MapFn(
6976
reader_thread.join()
7077
self.executor.shutdown(cancel_futures=True)
7178

79+
except grpc.RpcError:
80+
_LOGGER.warning("gRPC stream closed, shutting down the server.")
81+
if result_queue is not None:
82+
result_queue.close()
83+
self.shutdown_event.set()
84+
return
85+
7286
except BaseException as err:
73-
_LOGGER.critical("UDFError, re-raising the error", exc_info=True)
74-
# Terminate the current server process due to exception
75-
exit_on_error(
76-
context, f"{ERR_UDF_EXCEPTION_STRING}: {repr(err)}", parent=self.multiproc
77-
)
87+
err_msg = f"UDFError, {ERR_UDF_EXCEPTION_STRING}: {repr(err)}"
88+
_LOGGER.critical(err_msg, exc_info=True)
89+
update_context_err(context, err, err_msg)
90+
# Unblock the reader thread if it is waiting on queue.put()
91+
if result_queue is not None:
92+
result_queue.close()
93+
self.error = err
94+
self.shutdown_event.set()
7895
return
7996

8097
def _process_requests(
@@ -91,9 +108,20 @@ def _process_requests(
91108
# wait for all tasks to finish after all requests exhausted
92109
self.executor.shutdown(wait=True)
93110
# Indicate to the result queue that no more messages left to process
94-
result_queue.put(STREAM_EOF)
111+
result_queue.close()
112+
except grpc.RpcError:
113+
# The only error that can occur here is the gRPC stream closing
114+
# (e.g. client disconnected). UDF exceptions are caught inside _invoke_map
115+
# and never propagate here.
116+
_LOGGER.warning("gRPC stream closed in reader thread, shutting down the server.")
117+
# Let already-submitted UDF tasks finish within the graceful shutdown period
118+
self.executor.shutdown(wait=True)
119+
# Signal MapFn's read_iterator() loop to exit cleanly
120+
result_queue.close()
121+
# Trigger server shutdown (not a UDF error, so self.error is not set)
122+
self.shutdown_event.set()
95123
except BaseException as e:
96-
_LOGGER.critical("MapFn Error, re-raising the error", exc_info=True)
124+
_LOGGER.critical("MapFn Error while reading requests from gRPC stream", exc_info=True)
97125
# Surface the error to the consumer; MapFn will handle and exit
98126
result_queue.put(e)
99127

packages/pynumaflow/pynumaflow/mapper/sync_server.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
import sys
2+
13
from pynumaflow.info.types import (
24
ServerInfo,
35
MAP_MODE_KEY,
@@ -112,4 +114,9 @@ def start(self) -> None:
112114
server_options=self._server_options,
113115
udf_type=UDFType.Map,
114116
server_info=serv_info,
117+
shutdown_event=self.servicer.shutdown_event,
115118
)
119+
120+
if self.servicer.error:
121+
_LOGGER.critical("Server exiting due to UDF error: %s", self.servicer.error)
122+
sys.exit(1)

0 commit comments

Comments
 (0)