@@ -98,23 +98,12 @@ def _default_staging_path(self):
9898 )
9999
100100 def __init__ (
101- self ,
102- root_dir : str ,
103- environments_manager : Type [EnvironmentManager ],
104- config = None ,
105- ** kwargs ,
101+ self , root_dir : str , environments_manager : Type [EnvironmentManager ], config = None , ** kwargs
106102 ):
107103 super ().__init__ (config = config , ** kwargs )
108104 self .root_dir = root_dir
109105 self .environments_manager = environments_manager
110106
111- loop = asyncio .get_event_loop ()
112- self .dask_client_future : Awaitable [DaskClient ] = loop .create_task (self ._get_dask_client ())
113-
114- async def _get_dask_client (self ):
115- """Creates and configures a Dask client."""
116- return DaskClient (processes = False , asynchronous = True )
117-
118107 def create_job (self , model : CreateJob ) -> str :
119108 """Creates a new job record, may trigger execution of the job.
120109 In case a task runner is actually handling execution of the jobs,
@@ -394,6 +383,12 @@ def get_local_output_path(
394383 else :
395384 return os .path .join (self .root_dir , self .output_directory , output_dir_name )
396385
386+ async def stop_extension (self ):
387+ """
388+ Placeholder method for a cleanup code to run when the server is stopping.
389+ """
390+ pass
391+
397392
398393class Scheduler (BaseScheduler ):
399394 _db_session = None
@@ -427,6 +422,13 @@ def __init__(
427422 if self .task_runner_class :
428423 self .task_runner = self .task_runner_class (scheduler = self , config = config )
429424
425+ loop = asyncio .get_event_loop ()
426+ self .dask_client_future : Awaitable [DaskClient ] = loop .create_task (self ._get_dask_client ())
427+
428+ async def _get_dask_client (self ):
429+ """Creates and configures a Dask client."""
430+ return DaskClient (processes = False , asynchronous = True )
431+
430432 @property
431433 def db_session (self ):
432434 if not self ._db_session :
@@ -775,6 +777,14 @@ def get_staging_paths(self, model: Union[DescribeJob, DescribeJobDefinition]) ->
775777
776778 return staging_paths
777779
780+ async def stop_extension (self ):
781+ """
782+ Cleanup code to run when the server is stopping.
783+ """
784+ if self .dask_client_future :
785+ dask_client : DaskClient = await self .dask_client_future
786+ await dask_client .close ()
787+
778788
779789class ArchivingScheduler (Scheduler ):
780790 """Scheduler that captures all files in output directory in an archive."""
0 commit comments