 import tarfile
 import traceback
 from abc import ABC, abstractmethod
-from typing import Dict
+from functools import lru_cache
+from typing import Dict, List
 
 import fsspec
 import nbconvert
 import nbformat
 from nbconvert.preprocessors import CellExecutionError, ExecutePreprocessor
+from prefect import flow, task
+from prefect.futures import as_completed
+from prefect_dask.task_runners import DaskTaskRunner
 
 from jupyter_scheduler.models import DescribeJob, JobFeature, Status
-from jupyter_scheduler.orm import Job, create_session
+from jupyter_scheduler.orm import Job, Workflow, create_session
 from jupyter_scheduler.parameterize import add_parameters
 from jupyter_scheduler.utils import get_utc_timestamp
+from jupyter_scheduler.workflows import DescribeWorkflow
 
 
 class ExecutionManager(ABC):
@@ -29,14 +34,29 @@ class ExecutionManager(ABC):
     _model = None
     _db_session = None
 
-    def __init__(self, job_id: str, root_dir: str, db_url: str, staging_paths: Dict[str, str]):
+    def __init__(
+        self,
+        job_id: str,
+        workflow_id: str,
+        root_dir: str,
+        db_url: str,
+        staging_paths: Dict[str, str],
+    ):
         self.job_id = job_id
+        self.workflow_id = workflow_id
         self.staging_paths = staging_paths
         self.root_dir = root_dir
         self.db_url = db_url
 
     @property
     def model(self):
+        if self.workflow_id:
+            with self.db_session() as session:
+                workflow = (
+                    session.query(Workflow).filter(Workflow.workflow_id == self.workflow_id).first()
+                )
+                self._model = DescribeWorkflow.from_orm(workflow)
+                return self._model
         if self._model is None:
             with self.db_session() as session:
                 job = session.query(Job).filter(Job.job_id == self.job_id).first()
@@ -65,6 +85,18 @@ def process(self):
         else:
             self.on_complete()
 
+    def process_workflow(self):
+        """Runs the workflow, updating its status on start, failure, and completion"""
+        self.before_start_workflow()
+        try:
+            self.execute_workflow()
+        except CellExecutionError as e:
+            self.on_failure_workflow(e)
+        except Exception as e:
+            self.on_failure_workflow(e)
+        else:
+            self.on_complete_workflow()
+
     @abstractmethod
     def execute(self):
         """Performs notebook execution,
@@ -74,6 +106,11 @@ def execute(self):
74106 """
75107 pass
76108
109+ @abstractmethod
110+ def execute_workflow (self ):
111+ """Performs workflow execution"""
112+ pass
113+
77114 @classmethod
78115 @abstractmethod
79116 def supported_features (cls ) -> Dict [JobFeature , bool ]:
@@ -98,6 +135,15 @@ def before_start(self):
             )
             session.commit()
 
+    def before_start_workflow(self):
+        """Called before start of execute_workflow"""
+        workflow = self.model
+        with self.db_session() as session:
+            session.query(Workflow).filter(Workflow.workflow_id == workflow.workflow_id).update(
+                {"status": Status.IN_PROGRESS}
+            )
+            session.commit()
+
     def on_failure(self, e: Exception):
         """Called after failure of execute"""
         job = self.model
@@ -109,6 +155,17 @@ def on_failure(self, e: Exception):
 
         traceback.print_exc()
 
+    def on_failure_workflow(self, e: Exception):
+        """Called after failure of execute_workflow"""
+        workflow = self.model
+        with self.db_session() as session:
+            session.query(Workflow).filter(Workflow.workflow_id == workflow.workflow_id).update(
+                {"status": Status.FAILED, "status_message": str(e)}
+            )
+            session.commit()
+
+        traceback.print_exc()
+
     def on_complete(self):
         """Called after job is completed"""
         job = self.model
@@ -118,10 +175,40 @@ def on_complete(self):
             )
             session.commit()
 
+    def on_complete_workflow(self):
+        """Called after workflow is completed"""
+        workflow = self.model
+        with self.db_session() as session:
+            session.query(Workflow).filter(Workflow.workflow_id == workflow.workflow_id).update(
+                {"status": Status.COMPLETED}
+            )
+            session.commit()
+
 
 class DefaultExecutionManager(ExecutionManager):
     """Default execution manager that executes notebooks"""
 
+    @task(task_run_name="{task_id}")
+    def execute_task(task_id: str):
+        print(f"Task {task_id} executed")
+        return task_id
+
+    @flow(task_runner=DaskTaskRunner())
+    def execute_workflow(self):
+        workflow: DescribeWorkflow = self.model
+        tasks = {task["id"]: task for task in workflow.tasks}
+
+        # Create Prefect tasks recursively; lru_cache memoizes them so each task is
+        # submitted only once, even when several downstream tasks list it in wait_for
+        @lru_cache(maxsize=None)
+        def make_task(task_id, execute_task):
+            deps = tasks[task_id]["dependsOn"]
+            return execute_task.submit(
+                task_id, wait_for=[make_task(dep_id, execute_task) for dep_id in deps]
+            )
+
+        final_tasks = [make_task(task_id, self.execute_task) for task_id in tasks]
+        for future in as_completed(final_tasks):
+            print(future.result())
+
     def execute(self):
         job = self.model
 
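
For reference, below is a minimal standalone sketch of the dependency-resolution pattern that the execute_workflow flow above relies on: lru_cache memoizes task submission so each task is submitted exactly once, and wait_for wires up the DAG edges. The run_task/run_workflow names and the example DAG are hypothetical, and the sketch uses Prefect's default task runner instead of DaskTaskRunner to stay self-contained; it is an illustration of the pattern, not the PR's code.

# Standalone sketch of the memoized DAG-submission pattern (names and DAG are hypothetical).
from functools import lru_cache

from prefect import flow, task
from prefect.futures import as_completed


@task(task_run_name="{task_id}")
def run_task(task_id: str) -> str:
    print(f"Task {task_id} executed")
    return task_id


@flow
def run_workflow(tasks: dict):
    # Memoize submission so each task is submitted exactly once, even when
    # several downstream tasks list it as a dependency.
    @lru_cache(maxsize=None)
    def submit(task_id: str):
        deps = tasks[task_id]["dependsOn"]
        return run_task.submit(task_id, wait_for=[submit(dep) for dep in deps])

    futures = [submit(task_id) for task_id in tasks]
    for future in as_completed(futures):
        print(future.result())


if __name__ == "__main__":
    # Hypothetical three-task DAG: "c" waits on both "a" and "b".
    run_workflow(
        {"a": {"dependsOn": []}, "b": {"dependsOn": ["a"]}, "c": {"dependsOn": ["a", "b"]}}
    )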