1616runtime users need to take care to not assume a static rank or world size.
1717"""
1818
19+ import atexit
1920import logging
2021import threading
2122from contextlib import contextmanager , nullcontext
7576logger : logging .Logger = logging .getLogger (__name__ )
7677
7778# TODO: use non strings which are cheaper
78- _QUEUE_CLOSE = "queue_close "
79+ _PIPE_CLOSE = "pipe_close "
7980_FUTURE_RESULT = "fut_result"
8081_FUTURE_EXCEPTION = "fut_exception"
8182
@@ -940,36 +941,67 @@ def __init__(self, timeout: Union[float, timedelta] = 60.0) -> None:
940941
941942 self ._timeout : float = timeout
942943
944+ # Register the shutdown method to be called at exit
945+ atexit .register (self .shutdown )
946+
def shutdown(self) -> None:
    """
    Shutdown the process group. This will stop the underlying worker
    process (killing it if it does not exit on its own) and close all pipes.

    Teardown order:
      1. send the close sentinel on the future pipe and close it,
      2. join the future thread,
      3. send the close sentinel on the request pipe and close it,
      4. join the worker process, killing it if it ignores the sentinel.

    Every step is attempted even if an earlier one fails, so a dead peer or
    a stuck thread cannot leak the remaining pipes or the worker process.

    This is a no-op if the process group is already shutdown.

    ProcessGroup can be reconfigured after shutdown.

    Raises:
        RuntimeError: if the future thread or the worker process is still
            alive once teardown has finished. The handle that failed to
            stop is kept on ``self`` so a later call can retry it.
    """
    join_timeout = 10.0
    failures: list[str] = []

    # Detach the pipe from self *before* touching it so that a failure
    # below can never leave a half-closed handle behind for the next call.
    future_pipe, self._future_pipe = self._future_pipe, None
    if future_pipe is not None:
        try:
            # Tell the future thread to leave its receive loop.
            future_pipe.send((-1, _PIPE_CLOSE, None, None))
        except OSError:
            # The peer is already gone (e.g. BrokenPipeError); there is
            # nobody left to notify, so carry on with the teardown.
            logger.warning(
                "failed to send close sentinel on future pipe", exc_info=True
            )
        finally:
            future_pipe.close()

    # Join the future thread after closing its pipe.
    future_thread = self._future_thread
    if future_thread is not None:
        future_thread.join(timeout=join_timeout)
        if future_thread.is_alive():
            # Defer the error: the worker below must still be reaped.
            failures.append("Future thread did not exit")
        else:
            self._future_thread = None

    # Close the request pipe to signal the worker process to exit.
    req_pipe, self._pipe = self._pipe, None
    if req_pipe is not None:
        try:
            req_pipe.send((_PIPE_CLOSE,))
        except OSError:
            logger.warning(
                "failed to send close sentinel on request pipe", exc_info=True
            )
        finally:
            req_pipe.close()

    # Reap the worker process; fall back to kill() if it ignores the
    # sentinel so that shutdown never leaves an orphaned process around.
    worker = self._p
    if worker is not None:
        worker.join(timeout=join_timeout)
        if worker.is_alive():
            logger.warning(
                "worker process did not exit within %.1fs, killing it",
                join_timeout,
            )
            worker.kill()
            worker.join(timeout=join_timeout)
        if worker.is_alive():
            failures.append("Worker process did not exit")
        else:
            self._p = None

    if failures:
        raise RuntimeError("; ".join(failures))
971983
972984 def configure (self , store_addr : str , rank : int , world_size : int ) -> None :
985+ """
986+ Structure
987+ +-------------------+
988+ | |
989+ | Main Process | (updates futures)
990+ | | <---------------
991+ +-------------------+ |
992+ | Pipe 1 |
993+ v |
994+ +-------------------+ +-------------------+
995+ | | | |
996+ | Worker Process | -> | Future Thread |
997+ | | Pipe 2 | |
998+ +-------------------+ +-------------------+
999+
1000+ Main Process: Maintains self._futures
1001+ Worker Process: Handles tasks, communicates with Future Thread.
1002+ Future Thread: Manages asynchronous tasks, updates self._futures.
1003+ """
1004+
9731005 self ._world_size = world_size
9741006
9751007 self .shutdown ()
@@ -990,7 +1022,7 @@ def configure(self, store_addr: str, rank: int, world_size: int) -> None:
9901022 rank ,
9911023 world_size ,
9921024 req_remote ,
993- future_remote ,
1025+ future_local ,
9941026 curr_device ,
9951027 ),
9961028 daemon = True ,
@@ -1003,7 +1035,7 @@ def configure(self, store_addr: str, rank: int, world_size: int) -> None:
10031035 self ._futures = {}
10041036 self ._future_thread = threading .Thread (
10051037 target = self ._future_handler ,
1006- args = (future_local ,),
1038+ args = (_MonitoredPipe ( future_remote ) ,),
10071039 daemon = True ,
10081040 )
10091041 self ._future_thread .start ()
@@ -1049,6 +1081,8 @@ def _worker(
10491081 while True :
10501082 op = cast (list [object ], req_pipe .recv ())
10511083 cmd = op [0 ]
1084+ if cmd == _PIPE_CLOSE :
1085+ break
10521086 if cmd == "func" :
10531087 op_id : int
10541088 op_id , func_name , args , kwargs , stream_device , stream_id , event = (
@@ -1172,6 +1206,8 @@ def _future_handler(self, future_pipe: _MonitoredPipe) -> None:
11721206 op_id , mode , data , event = cast (
11731207 Tuple [int , str , object , Optional [torch .cuda .Event ]], cmd
11741208 )
1209+ if mode == _PIPE_CLOSE :
1210+ break
11751211 with self ._futures_lock :
11761212 fut = self ._futures [op_id ]
11771213 del self ._futures [op_id ]
0 commit comments