22from concurrent .futures import ThreadPoolExecutor
33from collections .abc import Iterator
44
5+ import grpc
56from google .protobuf import empty_pb2 as _empty_pb2
6- from pynumaflow .shared .server import exit_on_error
7+ from pynumaflow .shared .server import update_context_err
78from pynumaflow ._metadata import _user_and_system_metadata_from_proto
89
9- from pynumaflow ._constants import NUM_THREADS_DEFAULT , STREAM_EOF , _LOGGER , ERR_UDF_EXCEPTION_STRING
10+ from pynumaflow ._constants import NUM_THREADS_DEFAULT , _LOGGER , ERR_UDF_EXCEPTION_STRING
1011from pynumaflow .mapper ._dtypes import MapSyncCallable , Datum , MapError
1112from pynumaflow .proto .mapper import map_pb2 , map_pb2_grpc
1213from pynumaflow .shared .synciter import SyncIterator
@@ -26,6 +27,8 @@ def __init__(self, handler: MapSyncCallable, multiproc: bool = False):
2627 self .multiproc = multiproc
2728 # create a thread pool for executing UDF code
2829 self .executor = ThreadPoolExecutor (max_workers = NUM_THREADS_DEFAULT )
30+ self .shutdown_event : threading .Event = threading .Event ()
31+ self .error : BaseException | None = None
2932
3033 def MapFn (
3134 self ,
@@ -36,6 +39,7 @@ def MapFn(
3639 Applies a function to each datum element.
3740 The pascal case function name comes from the proto map_pb2_grpc.py file.
3841 """
42+ result_queue = None
3943 try :
4044 # The first message to be received should be a valid handshake
4145 req = next (request_iterator )
@@ -57,10 +61,13 @@ def MapFn(
5761 for res in result_queue .read_iterator ():
5862 # if error handler accordingly
5963 if isinstance (res , BaseException ):
60- # Terminate the current server process due to exception
61- exit_on_error (
62- context , f"{ ERR_UDF_EXCEPTION_STRING } : { repr (res )} " , parent = self .multiproc
63- )
64+ err_msg = f"{ ERR_UDF_EXCEPTION_STRING } : { repr (res )} "
65+ _LOGGER .critical (err_msg , exc_info = True )
66+ update_context_err (context , res , err_msg )
67+ # Unblock the reader thread if it is waiting on queue.put()
68+ result_queue .close ()
69+ self .error = res
70+ self .shutdown_event .set ()
6471 return
6572 # return the result
6673 yield res
@@ -69,12 +76,22 @@ def MapFn(
6976 reader_thread .join ()
7077 self .executor .shutdown (cancel_futures = True )
7178
79+ except grpc .RpcError :
80+ _LOGGER .warning ("gRPC stream closed, shutting down the server." )
81+ if result_queue is not None :
82+ result_queue .close ()
83+ self .shutdown_event .set ()
84+ return
85+
7286 except BaseException as err :
73- _LOGGER .critical ("UDFError, re-raising the error" , exc_info = True )
74- # Terminate the current server process due to exception
75- exit_on_error (
76- context , f"{ ERR_UDF_EXCEPTION_STRING } : { repr (err )} " , parent = self .multiproc
77- )
87+ err_msg = f"UDFError, { ERR_UDF_EXCEPTION_STRING } : { repr (err )} "
88+ _LOGGER .critical (err_msg , exc_info = True )
89+ update_context_err (context , err , err_msg )
90+ # Unblock the reader thread if it is waiting on queue.put()
91+ if result_queue is not None :
92+ result_queue .close ()
93+ self .error = err
94+ self .shutdown_event .set ()
7895 return
7996
8097 def _process_requests (
@@ -91,9 +108,20 @@ def _process_requests(
91108 # wait for all tasks to finish after all requests exhausted
92109 self .executor .shutdown (wait = True )
93110 # Indicate to the result queue that no more messages left to process
94- result_queue .put (STREAM_EOF )
111+ result_queue .close ()
112+ except grpc .RpcError :
113+ # The only error that can occur here is the gRPC stream closing
114+ # (e.g. client disconnected). UDF exceptions are caught inside _invoke_map
115+ # and never propagate here.
116+ _LOGGER .warning ("gRPC stream closed in reader thread, shutting down the server." )
117+ # Let already-submitted UDF tasks finish within the graceful shutdown period
118+ self .executor .shutdown (wait = True )
119+ # Signal MapFn's read_iterator() loop to exit cleanly
120+ result_queue .close ()
121+ # Trigger server shutdown (not a UDF error, so self.error is not set)
122+ self .shutdown_event .set ()
95123 except BaseException as e :
96- _LOGGER .critical ("MapFn Error, re-raising the error " , exc_info = True )
124+ _LOGGER .critical ("MapFn Error while reading requests from gRPC stream " , exc_info = True )
97125 # Surface the error to the consumer; MapFn will handle and exit
98126 result_queue .put (e )
99127
0 commit comments