@@ -10,6 +10,7 @@
 import pandas as pd
 import requests
 import shapely.wkt
+from requests.adapters import HTTPAdapter, Retry
 
 from openeo import BatchJob, Connection
 from openeo.rest import OpenEoApiError
@@ -23,6 +24,8 @@
 _Backend = collections.namedtuple("_Backend", ["get_connection", "parallel_jobs"])
 
 
+MAX_RETRIES = 5
+
 class MultiBackendJobManager:
     """
     Tracker for multiple jobs on multiple backends.
@@ -67,6 +70,7 @@ def __init__(
         """
         self.backends: Dict[str, _Backend] = {}
         self.poll_sleep = poll_sleep
+        self._connections: Dict[str, Connection] = {}
 
         # An explicit None or "" should also default to "."
         self._root_dir = Path(root_dir or ".")
@@ -87,6 +91,10 @@ def add_backend(
         :param parallel_jobs:
             Maximum number of jobs to allow in parallel on a backend.
         """
+
+        # TODO: Code might become simpler if we turn _Backend into a class and move this logic there.
+        #       We would need to keep add_backend here as part of the public API though,
+        #       but the amount of unrelated "stuff to manage" would be smaller (better cohesion).
         if isinstance(connection, Connection):
             c = connection
             connection = lambda: c
@@ -95,6 +103,53 @@ def add_backend(
             get_connection=connection, parallel_jobs=parallel_jobs
         )
 
+    def _get_connection(self, backend_name: str, resilient: bool = True) -> Connection:
+        """Get a connection for the backend and optionally make it resilient (adds retry behavior).
+
+        The default is to get a resilient connection, but if necessary you can
+        turn this off with resilient=False.
+        """
+
+        # TODO: Code could be simplified if _Backend is a class and this method is moved there.
+        # TODO: Is it better to make this a public method?
+
+        # Reuse the connection if we can, in order to avoid modifying the same connection
+        # several times, i.e. to avoid adding the retry HTTPAdapter more than once.
+        # Remember that the get_connection callable on _Backend may return the same
+        # Connection object on every call (e.g. when a Connection instance was passed
+        # to add_backend), so we can't assume it is a fresh connection without the retry adapter.
+        if backend_name in self._connections:
+            return self._connections[backend_name]
+
+        connection = self.backends[backend_name].get_connection()
+        # If we really need to, we can skip making it resilient, but by default it should be resilient.
+        if resilient:
+            self._make_resilient(connection)
+
+        self._connections[backend_name] = connection
+        return connection
+
+    def _make_resilient(self, connection):
+        """Add an HTTPAdapter that retries a request when it fails.
+
+        Retry for the following HTTP 50x statuses:
+            502 Bad Gateway
+            503 Service Unavailable
+            504 Gateway Timeout
+        """
+        status_forcelist = [502, 503, 504]
+        retries = Retry(
+            total=MAX_RETRIES,
+            read=MAX_RETRIES,
+            other=MAX_RETRIES,
+            status=MAX_RETRIES,
+            backoff_factor=0.1,
+            status_forcelist=status_forcelist,
+            allowed_methods=["HEAD", "GET", "OPTIONS", "POST"],
+        )
+        connection.session.mount("https://", HTTPAdapter(max_retries=retries))
+        connection.session.mount("http://", HTTPAdapter(max_retries=retries))
+
     def _normalize_df(self, df: pd.DataFrame) -> pd.DataFrame:
         """Ensure we have the required columns and the expected type for the geometry column.
 
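
For reference, the retry setup in `_make_resilient` above is standard `requests`/`urllib3`
usage: build a `Retry` policy and mount an `HTTPAdapter` carrying it on the session. A minimal
standalone sketch (the URL is a placeholder; note that `allowed_methods` requires
urllib3 >= 1.26, older versions call it `method_whitelist`):

    import requests
    from requests.adapters import HTTPAdapter, Retry

    session = requests.Session()
    retries = Retry(
        total=5,  # give up after 5 attempts
        backoff_factor=0.1,  # exponential backoff between attempts
        status_forcelist=[502, 503, 504],  # also retry on these HTTP statuses
        allowed_methods=["HEAD", "GET", "OPTIONS", "POST"],
    )
    adapter = HTTPAdapter(max_retries=retries)
    session.mount("https://", adapter)  # applies to every https:// request on this session
    session.mount("http://", adapter)

    # Requests through this session now retry transparently on 502/503/504.
    response = session.get("https://example.org/health")
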
@@ -121,20 +176,50 @@ def _normalize_df(self, df: pd.DataFrame) -> pd.DataFrame:
         df["geometry"] = df["geometry"].apply(shapely.wkt.loads)
         return df
 
-    # TODO: long method with deep nesting. Refactor it to make it more readable.
+    def _persists(self, df, output_file):
+        df.to_csv(output_file, index=False)
+        _log.info(f"Wrote job metadata to {output_file.absolute()}")
+
     def run_jobs(
         self, df: pd.DataFrame, start_job: Callable[[], BatchJob], output_file: Path
     ):
         """Runs jobs, specified in a dataframe, and tracks parameters.
 
         :param df:
             DataFrame that specifies the jobs, and tracks the jobs' statuses.
+
         :param start_job:
             A callback which will be invoked with the row of the dataframe for which a job should be started.
             This callable should return a :py:class:`openeo.rest.job.BatchJob` object.
+
+            The run_jobs method passes the following parameters to the start_job callback.
+            You do not have to define all of these parameters, but if you leave any of
+            them out, remember to accept *args and **kwargs as well; otherwise calling
+            start_job with the unknown parameters will raise a TypeError.
+
+            row:
+                The row in the pandas dataframe that stores the job's state and other tracked data.
+
+            connection_provider:
+                Like the connection parameter of add_backend:
+                - either a Connection to the backend,
+                - or a callable that creates a backend connection.
+                Typically you need either `connection_provider` or `connection`,
+                but you will likely not need both.
+
+            connection:
+                The Connection itself, already created.
+                Typically you need either `connection_provider` or `connection`,
+                but you will likely not need both.
+
+            provider:
+                The name of the backend that will run the job.
+
         :param output_file:
             Path to output file (CSV) containing the status and metadata of the jobs.
         """
+        # TODO: Defining start_job as a Protocol might make its usage more clear, and avoid
+        #       complicated docstrings, but Protocols are only supported in Python 3.8 and higher.
         # TODO: this resume functionality better fits outside of this function
         # (e.g. if `output_file` exists: `df` is fully discarded)
 
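
To make the `start_job` contract above concrete, a hypothetical callback could look like
the sketch below. The collection id, extents and job title are illustrative assumptions,
not part of this module; parameters the callback does not use are absorbed by **kwargs.

    def start_job(row, connection, **kwargs):
        # `row` is the pandas row tracking this job; `connection` is already resilient.
        west, south, east, north = row["geometry"].bounds
        cube = connection.load_collection(
            "SENTINEL2_L2A",  # assumed collection id, for illustration only
            spatial_extent={"west": west, "south": south, "east": east, "north": north},
            temporal_extent=["2023-01-01", "2023-02-01"],
        )
        # Return the job without starting it: run_jobs starts "created" jobs itself.
        return cube.create_job(title=f"Job for row {row.name}")
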
@@ -147,10 +232,6 @@ def run_jobs(
 
         df = self._normalize_df(df)
 
-        def persists(df, output_file):
-            df.to_csv(output_file, index=False)
-            _log.info(f"Wrote job metadata to {output_file.absolute()}")
-
         while (
             df[
                 (df.status != "finished")
@@ -163,7 +244,7 @@ def persists(df, output_file):
             self._update_statuses(df)
             status_histogram = df.groupby("status").size().to_dict()
             _log.info(f"Status histogram: {status_histogram}")
-            persists(df, output_file)
+            self._persists(df, output_file)
 
             if len(df[df.status == "not_started"]) > 0:
                 # Check number of jobs running at each backend
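
As a side note, the status bookkeeping above is plain pandas; on a toy frame:

    import pandas as pd

    df = pd.DataFrame({"status": ["finished", "error", "not_started", "finished"]})
    print(df.groupby("status").size().to_dict())
    # {'error': 1, 'finished': 2, 'not_started': 1}
    print(len(df[df.status == "not_started"]))  # 1
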
@@ -182,52 +263,68 @@ def persists(df, output_file):
                         )
                         to_launch = df[df.status == "not_started"].iloc[0:to_add]
                         for i in to_launch.index:
-                            df.loc[i, "backend_name"] = backend_name
-                            row = df.loc[i]
-                            try:
-                                _log.info(
-                                    f"Starting job on backend {backend_name} for {row.to_dict()}"
-                                )
-                                job = start_job(
-                                    row=row,
-                                    connection_provider=self.backends[
-                                        backend_name
-                                    ].get_connection,
-                                    connection=self.backends[
-                                        backend_name
-                                    ].get_connection(),
-                                    provider=backend_name,
-                                )
-                            except requests.exceptions.ConnectionError as e:
-                                _log.warning(
-                                    f"Failed to start job for {row.to_dict()}",
-                                    exc_info=True,
-                                )
-                                df.loc[i, "status"] = "start_failed"
-                            else:
-                                df.loc[
-                                    i, "start_time"
-                                ] = datetime.datetime.now().isoformat()
-                                if job:
-                                    df.loc[i, "id"] = job.job_id
-                                    with ignore_connection_errors(context="get status"):
-                                        status = job.status()
-                                        df.loc[i, "status"] = status
-                                        if status == "created":
-                                            # start job if not yet done by callback
-                                            try:
-                                                job.start_job()
-                                                df.loc[i, "status"] = job.status()
-                                            except OpenEoApiError as e:
-                                                _log.error(e)
-                                                df.loc[i, "status"] = "start_failed"
-                                else:
-                                    df.loc[i, "status"] = "skipped"
-
-                            persists(df, output_file)
+                            self._launch_job(start_job, df, i, backend_name)
+                            self._persists(df, output_file)
 
             time.sleep(self.poll_sleep)
 
+    def _launch_job(self, start_job, df, i, backend_name):
+        """Helper method for launching jobs.
+
+        :param start_job:
+            A callback which will be invoked with the row of the dataframe for which a job should be started.
+            This callable should return a :py:class:`openeo.rest.job.BatchJob` object.
+
+            See also `MultiBackendJobManager.run_jobs` for the parameters and return
+            type of this callable.
+
+            Even though it is called here in `_launch_job`, and this is where the constraints
+            really come from, the public method `run_jobs` needs to document `start_job` anyway,
+            so let's avoid duplication in the docstrings.
+
+        :param df:
+            DataFrame that specifies the jobs, and tracks the jobs' statuses.
+
+        :param i:
+            Index of the job's row in the dataframe df.
+
+        :param backend_name:
+            Name of the backend that will execute the job.
+        """
+
+        df.loc[i, "backend_name"] = backend_name
+        row = df.loc[i]
+        try:
+            _log.info(f"Starting job on backend {backend_name} for {row.to_dict()}")
+            connection = self._get_connection(backend_name, resilient=True)
+
+            job = start_job(
+                row=row,
+                connection_provider=self._get_connection,
+                connection=connection,
+                provider=backend_name,
+            )
+        except requests.exceptions.ConnectionError as e:
+            _log.warning(f"Failed to start job for {row.to_dict()}", exc_info=True)
+            df.loc[i, "status"] = "start_failed"
+        else:
+            df.loc[i, "start_time"] = datetime.datetime.now().isoformat()
+            if job:
+                df.loc[i, "id"] = job.job_id
+                with ignore_connection_errors(context="get status"):
+                    status = job.status()
+                    df.loc[i, "status"] = status
+                    if status == "created":
+                        # start job if not yet done by callback
+                        try:
+                            job.start_job()
+                            df.loc[i, "status"] = job.status()
+                        except OpenEoApiError as e:
+                            _log.error(e)
+                            df.loc[i, "status"] = "start_failed"
+            else:
+                df.loc[i, "status"] = "skipped"
+
     def on_job_done(self, job: BatchJob, row):
         """
         Handles jobs that have finished. Can be overridden to provide custom behaviour.
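
`_launch_job` above guards the status check with `ignore_connection_errors`, a context
manager defined elsewhere in this module. A minimal sketch of the idea (the actual
implementation may differ in details):

    import contextlib
    import logging

    import requests

    _log = logging.getLogger(__name__)

    @contextlib.contextmanager
    def ignore_connection_errors(context: str = None):
        """Log connection errors instead of letting them propagate."""
        try:
            yield
        except requests.exceptions.ConnectionError as e:
            _log.warning(f"Ignoring connection error (context: {context}): {e}")
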
@@ -251,7 +348,7 @@ def on_job_done(self, job: BatchJob, row):
 
     def on_job_error(self, job: BatchJob, row):
         """
-        Handles jobs that stopped with errors.  Can be overridden to provide custom behaviour.
+        Handles jobs that stopped with errors. Can be overridden to provide custom behaviour.
 
         Default implementation writes the error logs to a JSON file.
 
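
Since `on_job_done` and `on_job_error` are meant to be overridden, a subclass can hook in
custom post-processing. A hypothetical sketch (the download location is an assumption):

    class MyJobManager(MultiBackendJobManager):
        def on_job_done(self, job, row):
            # Custom behaviour: download all result assets next to the job metadata.
            target_dir = self._root_dir / f"job_{job.job_id}"
            job.get_results().download_files(target=target_dir)
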
@@ -298,7 +395,7 @@ def _update_statuses(self, df: pd.DataFrame):
             backend_name = df.loc[i, "backend_name"]
 
             try:
-                con = self.backends[backend_name].get_connection()
+                con = self._get_connection(backend_name)
                 the_job = con.job(job_id)
                 job_metadata = the_job.describe_job()
                 _log.info(
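
Putting the pieces of this diff together, driving the manager would typically look like the
sketch below. The backend name and URL, the dataframe contents, and `start_job` (see the
earlier sketch) are placeholders:

    from pathlib import Path

    import openeo
    import pandas as pd

    manager = MultiBackendJobManager(poll_sleep=60, root_dir="jobs")
    manager.add_backend(
        "vito",
        connection=openeo.connect("https://openeo.vito.be").authenticate_oidc(),
        parallel_jobs=2,
    )

    df = pd.DataFrame({"geometry": ["POINT (5.0 51.2)"]})  # WKT; _normalize_df parses it
    manager.run_jobs(df=df, start_job=start_job, output_file=Path("jobs/status.csv"))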