diff --git a/octue/cloud/events/handler.py b/octue/cloud/events/handler.py index 581461f3b..e412c7346 100644 --- a/octue/cloud/events/handler.py +++ b/octue/cloud/events/handler.py @@ -10,6 +10,7 @@ from octue.definitions import GOOGLE_COMPUTE_PROVIDERS from octue.log_handlers import COLOUR_PALETTE from octue.resources.manifest import Manifest +from octue.utils.exceptions import convert_exception_event_to_exception logger = logging.getLogger(__name__) @@ -37,6 +38,7 @@ class AbstractEventHandler: :param str|None exclude_logs_containing: if provided, skip handling log messages containing this string :param bool only_handle_result: if `True`, skip handling non-result events and only handle the "result" event when received (turning this on speeds up event handling) :param bool validate_events: if `True`, validate events before attempting to handle them (turning this off speeds up event handling) + :param bool raise_errors: if `True`, raise any exceptions received; otherwise, just log them (just logging them allows a partial result event to be received afterwards and handled) :return None: """ @@ -50,6 +52,7 @@ def __init__( exclude_logs_containing=None, only_handle_result=False, validate_events=True, + raise_errors=True, ): self.handle_monitor_message = handle_monitor_message self.record_events = record_events @@ -58,6 +61,7 @@ def __init__( self.exclude_logs_containing = exclude_logs_containing self.only_handle_result = only_handle_result self.validate_events = validate_events + self.raise_errors = raise_errors self.handled_events = [] self._start_time = None @@ -221,29 +225,19 @@ def _handle_log_message(self, event, attributes): logger.handle(record) def _handle_exception(self, event, attributes): - """Raise the exception from the child. + """Raise or log the exception from the child. :param dict event: - :param dict attributes: the event's attributes + :param octue.cloud.events.attributes.ResponseAttributes attributes: the event's attributes :raise Exception: :return None: """ - exception_message = "\n\n".join( - ( - event["exception_message"], - f"The following traceback was captured from the remote service {attributes.sender!r}:", - "".join(event["exception_traceback"]), - ) - ) - - try: - exception_type = EXCEPTIONS_MAPPING[event["exception_type"]] + error = convert_exception_event_to_exception(event, attributes.sender, EXCEPTIONS_MAPPING) - # Allow unknown exception types to still be raised. - except KeyError: - exception_type = type(event["exception_type"], (Exception,), {}) + if self.raise_errors: + raise error - raise exception_type(exception_message) + logger.error("", exc_info=error) def _handle_result(self, event, attributes): """Extract any output values and output manifest from the result, deserialising the manifest if present. 
@@ -259,4 +253,13 @@ def _handle_result(self, event, attributes): else: output_manifest = None - return {"output_values": event.get("output_values"), "output_manifest": output_manifest} + result = { + "output_values": event.get("output_values"), + "output_manifest": output_manifest, + "success": event["success"], + } + + if event.get("exception"): + result["exception"] = event["exception"] + + return result diff --git a/octue/cloud/events/replayer.py b/octue/cloud/events/replayer.py index ab57f6003..9e5c606a0 100644 --- a/octue/cloud/events/replayer.py +++ b/octue/cloud/events/replayer.py @@ -19,6 +19,7 @@ class EventReplayer(AbstractEventHandler): :param str|None exclude_logs_containing: if provided, skip handling log messages containing this string :param bool only_handle_result: if `True`, skip non-result events and only handle the "result" event if present (turning this on speeds up event handling) :param bool validate_events: if `True`, validate events before attempting to handle them (this is off by default to speed up event handling) + :param bool raise_errors: if `True`, raise any exceptions received; otherwise, just log them (just logging them allows a partial result event to be received afterwards and handled) :return None: """ @@ -32,6 +33,7 @@ def __init__( exclude_logs_containing=None, only_handle_result=False, validate_events=False, + raise_errors=True, ): event_handlers = event_handlers or { "question": self._handle_question, @@ -52,6 +54,7 @@ def __init__( exclude_logs_containing=exclude_logs_containing, only_handle_result=only_handle_result, validate_events=validate_events, + raise_errors=raise_errors, ) def handle_events(self, events): diff --git a/octue/cloud/pub_sub/events.py b/octue/cloud/pub_sub/events.py index e33c8a3b1..0369a9b92 100644 --- a/octue/cloud/pub_sub/events.py +++ b/octue/cloud/pub_sub/events.py @@ -46,6 +46,7 @@ class GoogleCloudPubSubEventHandler(AbstractEventHandler): :param str|None exclude_logs_containing: if provided, skip handling log messages containing this string :param bool only_handle_result: if `True`, skip non-result events and only handle the "result" event if present (turning this on speeds up event handling) :param bool validate_events: if `True`, validate events before attempting to handle them (turn this off to speed up event handling at risk of failure if an invalid event is received) + :param bool raise_errors: if `True`, raise any exceptions received; otherwise, just log them (just logging them allows a partial result event to be received afterwards and handled) :return None: """ @@ -60,6 +61,7 @@ def __init__( exclude_logs_containing=None, only_handle_result=False, validate_events=True, + raise_errors=True, ): self.subscription = subscription @@ -72,6 +74,7 @@ def __init__( exclude_logs_containing=exclude_logs_containing, only_handle_result=only_handle_result, validate_events=validate_events, + raise_errors=raise_errors, ) self._heartbeat_checker = None diff --git a/octue/cloud/pub_sub/service.py b/octue/cloud/pub_sub/service.py index 8643fbf1d..242f04da8 100644 --- a/octue/cloud/pub_sub/service.py +++ b/octue/cloud/pub_sub/service.py @@ -26,6 +26,7 @@ from octue.compatibility import warn_if_incompatible from octue.definitions import DEFAULT_MAXIMUM_HEARTBEAT_INTERVAL, LOCAL_SDK_VERSION import octue.exceptions +from octue.resources import Analysis from octue.utils.dictionaries import make_minimal_dictionary from octue.utils.encoders import OctueJSONEncoder from octue.utils.exceptions import convert_exception_to_primitives @@ 
-195,12 +196,16 @@ def answer(self, question, heartbeat_interval=120, timeout=30):
         :raise Exception: if any exception arises during running analysis and sending its results
         :return dict: the result event
         """
+        heartbeater = None
+
+        # Instantiate analysis here so outputs can be accessed even in the event of an exception in the run function.
+        analysis = Analysis()
+
         try:
             question, question_attributes = self._parse_question(question)
         except jsonschema.ValidationError:
             return
 
-        heartbeater = None
         response_attributes = ResponseAttributes.from_question_attributes(question_attributes)
 
         try:
@@ -225,7 +230,8 @@ def answer(self, question, heartbeat_interval=120, timeout=30):
 
             handle_monitor_message = functools.partial(self._send_monitor_message, attributes=response_attributes)
 
-            analysis = self.run_function(
+            self.run_function(
+                analysis=analysis,
                 analysis_id=question_attributes.question_uuid,
                 input_values=question.get("input_values"),
                 input_manifest=question.get("input_manifest"),
@@ -237,7 +243,7 @@ def answer(self, question, heartbeat_interval=120, timeout=30):
                 originator=question_attributes.originator,
             )
 
-            result = self._send_result(analysis, response_attributes)
+            result = self._send_result(analysis, response_attributes, success=True)
             heartbeater.cancel()
             logger.info("%r answered question %r.", self, question_attributes.question_uuid)
             return result
@@ -251,7 +257,8 @@ def answer(self, question, heartbeat_interval=120, timeout=30):
                 sender_sdk_version=question_attributes.sender_sdk_version,
             )
 
-            self.send_exception(attributes=response_attributes, timeout=timeout)
+            self._send_exception(attributes=response_attributes, timeout=timeout)
+            self._send_result(analysis, response_attributes, success=False, exception=self._serialise_exception())
             raise error
 
     def ask(
@@ -381,6 +388,7 @@ def wait_for_answer(
         record_events=True,
         timeout=60,
         maximum_heartbeat_interval=DEFAULT_MAXIMUM_HEARTBEAT_INTERVAL,
+        raise_errors=True,
     ):
         """Wait for an answer to a question on the given subscription, deleting the subscription and its topic once
         the answer is received.
@@ -390,8 +398,9 @@
         :param bool record_events: if `True`, record messages received from the child in the `received_events` attribute
         :param float|None timeout: how long in seconds to wait for an answer before raising a `TimeoutError`
         :param float|int maximum_heartbeat_interval: the maximum amount of time (in seconds) allowed between child heartbeats before an error is raised
+        :param bool raise_errors: if `True`, raise any exceptions received; otherwise, just log them (just logging them allows a partial result event to be received afterwards and handled)
         :raise TimeoutError: if the timeout is exceeded
-        :return dict: dictionary containing the keys "output_values" and "output_manifest"
+        :return dict: dictionary containing the keys "output_values", "output_manifest", "success", and, for a failed analysis, "exception"
         """
         if subscription.is_push_subscription:
             raise octue.exceptions.NotAPullSubscription(
@@ -403,6 +412,7 @@
             subscription=subscription,
             handle_monitor_message=handle_monitor_message,
             record_events=record_events,
+            raise_errors=raise_errors,
         )
 
         try:
@@ -444,26 +454,15 @@
         # self._emit_event({"kind": "cancellation"}, attributes=question_attributes, timeout=timeout)
         # logger.info("Cancellation of question %r requested.", question_uuid)
 
-    def send_exception(self, attributes, timeout=30):
+    def _send_exception(self, attributes, timeout=30):
         """Serialise and send the exception being handled to the parent.
 
         :param octue.cloud.events.attributes.ResponseAttributes attributes: the attributes to use for the exception event
         :param float|None timeout: time in seconds to keep retrying sending of the exception
         :return None:
         """
-        exception = convert_exception_to_primitives()
-        exception_message = f"Error in {self!r}: {exception['message']}"
-
-        self._emit_event(
-            {
-                "kind": "exception",
-                "exception_type": exception["type"],
-                "exception_message": exception_message,
-                "exception_traceback": exception["traceback"],
-            },
-            attributes=attributes,
-            timeout=timeout,
-        )
+        event = self._serialise_exception()
+        self._emit_event(event, attributes=attributes, timeout=timeout)
 
     def _emit_event(self, event, attributes, wait=True, timeout=30):
         """Emit a JSON-serialised event as a Pub/Sub message to the services topic with optional message attributes.
@@ -592,15 +591,22 @@ def _send_monitor_message(self, data, attributes, timeout=30):
         self._emit_event({"kind": "monitor_message", "data": data}, attributes=attributes, timeout=timeout, wait=False)
         logger.debug("Monitor message sent by %r.", self)
 
-    def _send_result(self, analysis, attributes, timeout=30):
+    def _send_result(self, analysis, attributes, success, exception=None, timeout=30):
         """Send the result to the parent.
 
         :param octue.resources.analysis.Analysis analysis: the analysis object containing the output values and/or output manifest
         :param octue.cloud.events.attributes.ResponseAttributes attributes: the attributes to use for the result event
+        :param bool success: `True` if the analysis ran successfully
+        :param dict|None exception: the serialised exception event if the analysis failed, otherwise `None`
         :param float timeout: time in seconds to retry sending the message
         :return dict: the result
         """
-        result = make_minimal_dictionary(kind="result", output_values=analysis.output_values)
+        result = make_minimal_dictionary(
+            kind="result",
+            output_values=analysis.output_values,
+            success=success,
+            exception=exception,
+        )
 
         if analysis.output_manifest is not None:
             result["output_manifest"] = analysis.output_manifest.to_primitive()
@@ -608,6 +614,21 @@ def _handle_result(self, event, attributes):
         self._emit_event(event=result, attributes=attributes, timeout=timeout)
         return result
 
+    def _serialise_exception(self):
+        """Serialise the exception currently being handled into an exception event.
+
+        :return dict: the exception event
+        """
+        exception = convert_exception_to_primitives()
+        exception_message = f"Error in {self!r}: {exception['message']}"
+
+        return {
+            "kind": "exception",
+            "exception_type": exception["type"],
+            "exception_message": exception_message,
+            "exception_traceback": exception["traceback"],
+        }
+
     def _parse_question(self, question):
         """Parse a question in dictionary format or direct Google Pub/Sub format.
 
diff --git a/octue/resources/analysis.py b/octue/resources/analysis.py
index d79f281af..7b0a7c789 100644
--- a/octue/resources/analysis.py
+++ b/octue/resources/analysis.py
@@ -3,7 +3,6 @@
 
 import coolname
 
-import twined.exceptions
 from octue.cloud import storage
 from octue.exceptions import InvalidMonitorMessage
 from octue.mixins import Hashable, Identifiable, Labelable, Serialisable, Taggable
@@ -11,7 +10,7 @@
 from octue.utils.encoders import OctueJSONEncoder
 from octue.utils.threads import RepeatingTimer
 from twined import ALL_STRANDS, Twine
-
+import twined.exceptions
 
 logger = logging.getLogger(__name__)
 
@@ -48,7 +47,7 @@ class Analysis(Identifiable, Serialisable, Labelable, Taggable):
     If a strand is ``None``, so will its corresponding hash attribute be. The hash of a datafile is the hash of its
     file, while the hash of a manifest or dataset is the cumulative hash of the files it refers to.
- :param twined.Twine|dict|str twine: the twine, dictionary defining a twine, or path to "twine.json" file defining the service's data interface + :param twined.Twine|dict|str|None twine: the twine, dictionary defining a twine, or path to "twine.json" file defining the service's data interface :param callable|None handle_monitor_message: an optional function for sending monitor messages to the parent that requested the analysis :param any configuration_values: the configuration values for the analysis - this can be expressed as a python primitive (e.g. dict), a path to a JSON file, or a JSON string. :param octue.resources.manifest.Manifest configuration_manifest: a manifest of configuration datasets for the analysis if required @@ -61,16 +60,45 @@ class Analysis(Identifiable, Serialisable, Labelable, Taggable): :return None: """ - def __init__(self, twine, handle_monitor_message=None, **kwargs): - if isinstance(twine, Twine): + def __init__(self, twine=None, handle_monitor_message=None, **kwargs): + strand_kwargs = {name: kwargs.pop(name, None) for name in ALL_STRANDS} + output_location = kwargs.pop("output_location", None) + use_signed_urls_for_output_datasets = kwargs.pop("use_signed_urls_for_output_datasets", False) + + self.prepare( + twine=twine, + handle_monitor_message=handle_monitor_message, + output_location=output_location, + use_signed_urls_for_output_datasets=use_signed_urls_for_output_datasets, + **strand_kwargs, + ) + + super().__init__(**kwargs) + + @property + def finalised(self): + """Check whether the analysis has been finalised (i.e. whether its outputs have been validated and, if an output + manifest is produced, its datasets uploaded). + + :return bool: + """ + return self._finalised + + def prepare( + self, + twine=None, + handle_monitor_message=None, + output_location=None, + use_signed_urls_for_output_datasets=None, + **strand_kwargs, + ): + if twine is None or isinstance(twine, Twine): self.twine = twine else: self.twine = Twine(source=twine) self._handle_monitor_message = handle_monitor_message - strand_kwargs = {name: kwargs.pop(name, None) for name in ALL_STRANDS} - # Values strands. self.configuration_values = strand_kwargs.get("configuration_values", None) self.input_values = strand_kwargs.get("input_values", None) @@ -85,22 +113,12 @@ def __init__(self, twine, handle_monitor_message=None, **kwargs): self.children = strand_kwargs.get("children", None) # Non-strands. - self.output_location = kwargs.pop("output_location", None) - self.use_signed_urls_for_output_datasets = kwargs.pop("use_signed_urls_for_output_datasets", False) + self.output_location = output_location + self.use_signed_urls_for_output_datasets = use_signed_urls_for_output_datasets self._calculate_strand_hashes(strands=strand_kwargs) self._periodic_monitor_message_sender_threads = [] self._finalised = False - super().__init__(**kwargs) - - @property - def finalised(self): - """Check whether the analysis has been finalised (i.e. whether its outputs have been validated and, if an output - manifest is produced, its datasets uploaded). - - :return bool: - """ - return self._finalised def send_monitor_message(self, data): """Send a monitor message to the parent that requested the analysis. 
diff --git a/octue/resources/child.py b/octue/resources/child.py index 6e2234ed1..9e551a849 100644 --- a/octue/resources/child.py +++ b/octue/resources/child.py @@ -3,9 +3,11 @@ import logging import os +from octue.cloud import EXCEPTIONS_MAPPING from octue.cloud.pub_sub.service import Service from octue.definitions import DEFAULT_MAXIMUM_HEARTBEAT_INTERVAL from octue.resources import service_backends +from octue.utils.exceptions import convert_exception_event_to_exception logger = logging.getLogger(__name__) @@ -111,7 +113,7 @@ def ask( :param float|int maximum_heartbeat_interval: the maximum amount of time (in seconds) allowed between child heartbeats before an error is raised :raise TimeoutError: if the timeout is exceeded while waiting for an answer :raise Exception: if the question raises an error and `raise_errors=True` - :return dict|octue.cloud.pub_sub.subscription.Subscription|Exception|None, str: for a synchronous question, a dictionary containing the keys "output_values" and "output_manifest" from the result (or just an exception if the question fails), and the question UUID; for a question with a push endpoint, the push subscription and the question UUID; for an asynchronous question, `None` and the question UUID + :return dict|octue.cloud.pub_sub.subscription.Subscription|Exception|None, str: for a synchronous question, a dictionary containing the keys "output_values", "output_manifest", and "success" from the result (or just an exception if the question fails), and the question UUID; for a question with a push endpoint, the push subscription and the question UUID; for an asynchronous question, `None` and the question UUID """ prevent_retries_when = prevent_retries_when or [] @@ -142,52 +144,57 @@ def ask( logger.info("Waiting for question to be accepted...") - try: - answer = self._service.wait_for_answer( - subscription=subscription, - handle_monitor_message=handle_monitor_message, - record_events=record_events, - timeout=timeout, - maximum_heartbeat_interval=maximum_heartbeat_interval, - ) + answer = self._service.wait_for_answer( + subscription=subscription, + handle_monitor_message=handle_monitor_message, + record_events=record_events, + timeout=timeout, + maximum_heartbeat_interval=maximum_heartbeat_interval, + raise_errors=False, + ) + if answer["success"]: return answer, question_uuid - except Exception as e: - logger.error( - "Question %r failed. Run 'octue question diagnostics gs:///%s " - "--download-datasets' to get the crash diagnostics.", - question_uuid, - question_uuid, - ) + logger.error( + "Question %r failed. 
Run 'octue question diagnostics gs:///%s " + "--download-datasets' to get the crash diagnostics.", + question_uuid, + question_uuid, + ) - if raise_errors: - raise e + e = convert_exception_event_to_exception(answer["exception"], self.id, EXCEPTIONS_MAPPING) - if type(e) in prevent_retries_when: - logger.info("Skipping retries for exceptions of type %r.", type(e)) - return e, question_uuid + if type(e) in prevent_retries_when: + logger.info("Skipping retries for exceptions of type %r.", type(e)) + return e, question_uuid - for retry in range(max_retries): - logger.info("Retrying question %r %d of %d times.", question_uuid, retry + 1, max_retries) + for retry in range(max_retries): + logger.info("Retrying question %r %d of %d times.", question_uuid, retry + 1, max_retries) - inputs["retry_count"] += 1 - answer, question_uuid = self.ask(**inputs, raise_errors=False, log_errors=False) + inputs["retry_count"] += 1 + answer, question_uuid = self.ask(**inputs, raise_errors=False, log_errors=False) - if not isinstance(answer, Exception) or type(answer) in prevent_retries_when: - return answer, question_uuid + if answer["success"]: + return answer, question_uuid - e = answer + e = convert_exception_event_to_exception(answer["exception"], self.id, EXCEPTIONS_MAPPING) - if log_errors: - logger.error( - "Question %r failed after %d retries (see below for error).", - question_uuid, - max_retries, - exc_info=e, - ) + if type(e) in prevent_retries_when: + return e, question_uuid - return e, question_uuid + if raise_errors: + raise e + + if log_errors: + logger.error( + "Question %r failed after %d retries (see below for error).", + question_uuid, + max_retries, + exc_info=e, + ) + + return e, question_uuid def ask_multiple( self, diff --git a/octue/runner.py b/octue/runner.py index 64589ae18..d1c577000 100644 --- a/octue/runner.py +++ b/octue/runner.py @@ -6,12 +6,12 @@ import re import uuid -import google.api_core.exceptions from google import auth +import google.api_core.exceptions from google.cloud import secretmanager -from jsonschema import ValidationError, validate as jsonschema_validate +from jsonschema import ValidationError +from jsonschema import validate as jsonschema_validate -import twined.exceptions from octue import exceptions from octue.app_loading import AppFrom from octue.diagnostics import Diagnostics @@ -21,7 +21,7 @@ from octue.resources.datafile import downloaded_files from octue.utils.files import registered_temporary_directories from twined import Twine - +import twined.exceptions SAVE_DIAGNOSTICS_OFF = "SAVE_DIAGNOSTICS_OFF" SAVE_DIAGNOSTICS_ON_CRASH = "SAVE_DIAGNOSTICS_ON_CRASH" @@ -145,6 +145,7 @@ def __repr__(self): def run( self, + analysis=None, analysis_id=None, input_values=None, input_manifest=None, @@ -158,6 +159,7 @@ def run( ): """Run an analysis. + :param octue.resources.analysis.Analysis|None analysis: :param str|None analysis_id: UUID of analysis :param str|dict|None input_values: the input_values strand data. Can be expressed as a string path of a *.json file (relative or absolute), as an open file-like object (containing json data), as a string of json data or as an already-parsed dict. :param str|dict|octue.resources.manifest.Manifest|None input_manifest: The input_manifest strand data. Can be expressed as a string path of a *.json file (relative or absolute), as an open file-like object (containing json data), as a string of json data or as an already-parsed dict. 
@@ -241,16 +243,30 @@ def run(
             analysis_log_level=analysis_log_level,
             extra_log_handlers=extra_log_handlers,
         ):
-            analysis = Analysis(
-                id=analysis_id,
-                twine=self.twine,
-                handle_monitor_message=handle_monitor_message,
-                output_location=self.output_location,
-                use_signed_urls_for_output_datasets=self.use_signed_urls_for_output_datasets,
-                **self.configuration,
-                **inputs,
-                **outputs_and_monitors,
-            )
+            if analysis:
+                analysis._set_id(analysis_id)
+
+                analysis.prepare(
+                    twine=self.twine,
+                    handle_monitor_message=handle_monitor_message,
+                    output_location=self.output_location,
+                    use_signed_urls_for_output_datasets=self.use_signed_urls_for_output_datasets,
+                    **self.configuration,
+                    **inputs,
+                    **outputs_and_monitors,
+                )
+
+            else:
+                analysis = Analysis(
+                    id=analysis_id,
+                    twine=self.twine,
+                    handle_monitor_message=handle_monitor_message,
+                    output_location=self.output_location,
+                    use_signed_urls_for_output_datasets=self.use_signed_urls_for_output_datasets,
+                    **self.configuration,
+                    **inputs,
+                    **outputs_and_monitors,
+                )
 
             try:
                 self._load_and_run_app(analysis)
diff --git a/octue/utils/exceptions.py b/octue/utils/exceptions.py
index e1de2f878..3f9257838 100644
--- a/octue/utils/exceptions.py
+++ b/octue/utils/exceptions.py
@@ -45,3 +45,29 @@ def convert_exception_to_primitives(exception=None):
         "message": f"{exception_info[1]}",
         "traceback": tb.format_list(tb.extract_tb(exception_info[2])),
     }
+
+
+def convert_exception_event_to_exception(event, sender, exceptions_mapping):
+    """Convert a serialised exception event into an exception instance, including the remote traceback in its message.
+
+    :param dict event: the exception event
+    :param str sender: the identifier of the service that sent the exception event
+    :param dict exceptions_mapping: a mapping of exception type names to exception classes
+    :return Exception: the exception to raise or log
+    """
+    exception_message = "\n\n".join(
+        (
+            event["exception_message"],
+            f"The following traceback was captured from the remote service {sender!r}:",
+            "".join(event["exception_traceback"]),
+        )
+    )
+
+    try:
+        exception_type = exceptions_mapping[event["exception_type"]]
+
+    # Allow unknown exception types to still be raised.
+ except KeyError: + exception_type = type(event["exception_type"], (Exception,), {}) + + return exception_type(exception_message) diff --git a/tests/cloud/pub_sub/test_events.py b/tests/cloud/pub_sub/test_events.py index 76a0054ae..d9b935968 100644 --- a/tests/cloud/pub_sub/test_events.py +++ b/tests/cloud/pub_sub/test_events.py @@ -136,18 +136,15 @@ def test_delivery_acknowledgement(self): child = MockService(backend=GCPPubSubBackend(project_id=TEST_PROJECT_ID)) events = [ - {"event": {"kind": "delivery_acknowledgement", "order": 0}}, - {"event": {"kind": "result", "order": 1}}, + {"event": {"kind": "delivery_acknowledgement"}}, + {"event": {"kind": "result", "success": True}}, ] for event in events: - child._emit_event( - event=event["event"], - attributes=self.attributes, - ) + child._emit_event(event=event["event"], attributes=self.attributes) result = event_handler.handle_events() - self.assertEqual(result, {"output_values": None, "output_manifest": None}) + self.assertEqual(result, {"output_values": None, "output_manifest": None, "success": True}) def test_error_raised_if_heartbeat_not_received_before_checked(self): """Test that an error is raised if a heartbeat isn't received before a heartbeat is first checked for.""" diff --git a/tests/cloud/pub_sub/test_service.py b/tests/cloud/pub_sub/test_service.py index 7fbc9f22c..68b46b321 100644 --- a/tests/cloud/pub_sub/test_service.py +++ b/tests/cloud/pub_sub/test_service.py @@ -23,10 +23,11 @@ from octue.cloud.emulators.service import ServicePatcher from octue.cloud.pub_sub.service import Service from octue.exceptions import InvalidMonitorMessage -from octue.resources import Analysis, Datafile, Dataset, Manifest +from octue.resources import Datafile, Dataset, Manifest from octue.resources.service_backends import GCPPubSubBackend from tests import MOCK_SERVICE_REVISION_TAG, TEST_BUCKET_NAME, TEST_PROJECT_ID from tests.base import BaseTestCase +from twined import Twine import twined.exceptions logger = logging.getLogger(__name__) @@ -257,7 +258,11 @@ def test_ask_with_real_run_function_with_no_log_message_forwarding(self): self.assertEqual( answer, - {"output_values": MockAnalysis().output_values, "output_manifest": MockAnalysis().output_manifest}, + { + "output_values": MockAnalysis().output_values, + "output_manifest": MockAnalysis().output_manifest, + "success": True, + }, ) self.assertTrue(all("[truly/madly:deeply" not in message for message in logging_context.output)) @@ -282,7 +287,11 @@ def test_ask_with_real_run_function_with_log_message_forwarding(self): self.assertEqual( answer, - {"output_values": MockAnalysis().output_values, "output_manifest": MockAnalysis().output_manifest}, + { + "output_values": MockAnalysis().output_values, + "output_manifest": MockAnalysis().output_manifest, + "success": True, + }, ) # Check that the two expected remote log messages were logged consecutively in the right order with the service @@ -415,10 +424,11 @@ def test_ask_with_non_json_python_primitive_input_values(self): """ input_values = {"my_set": {1, 2, 3}, "my_datetime": datetime.datetime.now()} - def run_function(analysis_id, input_values, *args, **kwargs): - return MockAnalysis(output_values=input_values) + child = MockService( + backend=BACKEND, + run_function=self.make_run_function(run_function_returnee=MockAnalysis(output_values=input_values)), + ) - child = MockService(backend=BACKEND, run_function=lambda *args, **kwargs: run_function(*args, **kwargs)) parent = MockService(backend=BACKEND, children={child.id: child}) child.serve() @@ 
-457,7 +467,11 @@ def test_ask_with_input_manifest(self): self.assertEqual( answer, - {"output_values": MockAnalysis().output_values, "output_manifest": MockAnalysis().output_manifest}, + { + "output_values": MockAnalysis().output_values, + "output_manifest": MockAnalysis().output_manifest, + "success": True, + }, ) def test_ask_with_input_manifest_and_no_input_values(self): @@ -485,7 +499,11 @@ def test_ask_with_input_manifest_and_no_input_values(self): self.assertEqual( answer, - {"output_values": MockAnalysis().output_values, "output_manifest": MockAnalysis().output_manifest}, + { + "output_values": MockAnalysis().output_values, + "output_manifest": MockAnalysis().output_manifest, + "success": True, + }, ) def test_ask_with_input_manifest_with_local_paths_raises_error(self): @@ -518,11 +536,12 @@ def test_ask_with_input_manifest_with_local_paths_works_if_allowed_and_child_has manifest = Manifest(datasets={"my-local-dataset": Dataset(name="my-local-dataset", files={local_file})}) # Get the child to open the local file itself and return the contents as output. - def run_function(*args, **kwargs): - with open(temporary_local_path) as f: - return MockAnalysis(output_values=f.read()) + with open(temporary_local_path) as f: + child = MockService( + backend=BACKEND, + run_function=self.make_run_function(run_function_returnee=MockAnalysis(output_values=f.read())), + ) - child = MockService(backend=BACKEND, run_function=run_function) parent = MockService(backend=BACKEND, children={child.id: child}) child.serve() @@ -563,7 +582,11 @@ def test_service_can_ask_multiple_questions_to_child(self): for answer in answers: self.assertEqual( answer, - {"output_values": MockAnalysis().output_values, "output_manifest": MockAnalysis().output_manifest}, + { + "output_values": MockAnalysis().output_values, + "output_manifest": MockAnalysis().output_manifest, + "success": True, + }, ) def test_service_can_ask_questions_to_multiple_children(self): @@ -583,7 +606,11 @@ def test_service_can_ask_questions_to_multiple_children(self): self.assertEqual( answer_1, - {"output_values": MockAnalysis().output_values, "output_manifest": MockAnalysis().output_manifest}, + { + "output_values": MockAnalysis().output_values, + "output_manifest": MockAnalysis().output_manifest, + "success": True, + }, ) self.assertEqual( @@ -591,15 +618,18 @@ def test_service_can_ask_questions_to_multiple_children(self): { "output_values": DifferentMockAnalysis.output_values, "output_manifest": DifferentMockAnalysis.output_manifest, + "success": True, }, ) def test_child_can_ask_its_own_child_questions(self): """Test that a child can contact its own child while answering a question from a parent.""" - def child_run_function(analysis_id, input_values, *args, **kwargs): + def child_run_function(analysis, analysis_id, input_values, *args, **kwargs): subscription, _ = child.ask(service_id=child_of_child.id, input_values=input_values) - return MockAnalysis(output_values={input_values["question"]: child.wait_for_answer(subscription)}) + + analysis.output_values = {input_values["question"]: child.wait_for_answer(subscription)} + return MockAnalysis(output_values=analysis.output_values) child_of_child = self.make_new_child(BACKEND, run_function_returnee=DifferentMockAnalysis()) @@ -628,25 +658,27 @@ def child_run_function(analysis_id, input_values, *args, **kwargs): "What does the child of the child say?": { "output_values": DifferentMockAnalysis.output_values, "output_manifest": DifferentMockAnalysis.output_manifest, + "success": True, } }, 
"output_manifest": None, + "success": True, }, ) def test_child_can_ask_its_own_children_questions(self): """Test that a child can contact more than one of its own children while answering a question from a parent.""" - def child_run_function(analysis_id, input_values, *args, **kwargs): + def child_run_function(analysis, analysis_id, input_values, *args, **kwargs): subscription_1, _ = child.ask(service_id=first_child_of_child.id, input_values=input_values) subscription_2, _ = child.ask(service_id=second_child_of_child.id, input_values=input_values) - return MockAnalysis( - output_values={ - "first_child_of_child": child.wait_for_answer(subscription_1), - "second_child_of_child": child.wait_for_answer(subscription_2), - } - ) + analysis.output_values = { + "first_child_of_child": child.wait_for_answer(subscription_1), + "second_child_of_child": child.wait_for_answer(subscription_2), + } + + return MockAnalysis(output_values=analysis.output_values) first_child_of_child = self.make_new_child(BACKEND, run_function_returnee=DifferentMockAnalysis()) second_child_of_child = self.make_new_child(BACKEND, run_function_returnee=MockAnalysis()) @@ -680,13 +712,16 @@ def child_run_function(analysis_id, input_values, *args, **kwargs): "first_child_of_child": { "output_values": DifferentMockAnalysis.output_values, "output_manifest": DifferentMockAnalysis.output_manifest, + "success": True, }, "second_child_of_child": { "output_values": MockAnalysis().output_values, "output_manifest": MockAnalysis().output_manifest, + "success": True, }, }, "output_manifest": None, + "success": True, }, ) @@ -705,7 +740,10 @@ def test_child_messages_can_be_recorded_by_parent(self): for i in range(1, 6): self.assertEqual(parent.received_events[i]["event"]["kind"], "log_record") - self.assertEqual(parent.received_events[6]["event"], {"kind": "result", "output_values": "Hello! It worked!"}) + self.assertEqual( + parent.received_events[6]["event"], + {"kind": "result", "output_values": "Hello! It worked!", "success": True}, + ) def test_child_exception_message_can_be_recorded_by_parent(self): """Test that the parent can record exceptions raised by the child.""" @@ -726,9 +764,15 @@ def test_child_sends_heartbeat_messages_at_expected_regular_intervals(self): """Test that children send heartbeat messages at the expected regular intervals.""" expected_interval = 0.05 - def run_function(*args, **kwargs): + def run_function(analysis=None, *args, **kwargs): time.sleep(0.3) - return MockAnalysis() + mock_analysis = MockAnalysis() + + if analysis: + analysis.output_values = mock_analysis.output_values + analysis.output_manifest = mock_analysis.output_manifest + + return mock_analysis child = MockService(backend=BACKEND, run_function=lambda *args, **kwargs: run_function()) parent = MockService(backend=BACKEND, children={child.id: child}) @@ -764,15 +808,12 @@ def test_send_monitor_messages_periodically(self): message thread doesn't stop the result from being received (i.e. message sending is thread-safe). 
""" - def run_function(*args, **kwargs): - analysis = Analysis( - twine={"monitor_message_schema": {"type": "number"}}, - handle_monitor_message=kwargs["handle_monitor_message"], - ) - + def run_function(analysis, *args, **kwargs): + analysis.twine = Twine(source={"monitor_message_schema": {"type": "number"}}) + analysis._handle_monitor_message = kwargs["handle_monitor_message"] analysis.set_up_periodic_monitor_message(create_monitor_message=random.random, period=0.05) - time.sleep(1) analysis.output_values = {"tada": True} + time.sleep(1) return analysis child = MockService(backend=BACKEND, run_function=run_function) @@ -875,7 +916,15 @@ def mock_child_app(analysis): self.assertEqual(answer["output_values"], "I am the dynamic child.") @staticmethod - def make_new_child(backend, run_function_returnee, service_id=None): + def make_run_function(run_function_returnee): + def _run_function(analysis, *args, **kwargs): + analysis.output_values = run_function_returnee.output_values + analysis.output_manifest = run_function_returnee.output_manifest + return run_function_returnee + + return _run_function + + def make_new_child(self, backend, run_function_returnee, service_id=None): """Make and return a new child service that returns the given run function returnee when its run function is executed. @@ -887,7 +936,7 @@ def make_new_child(backend, run_function_returnee, service_id=None): return MockService( backend=backend, service_id=service_id, - run_function=lambda *args, **kwargs: run_function_returnee, + run_function=self.make_run_function(run_function_returnee), ) def make_new_child_with_error(self, exception_to_raise): diff --git a/tests/resources/test_child.py b/tests/resources/test_child.py index 06549ac94..048267960 100644 --- a/tests/resources/test_child.py +++ b/tests/resources/test_child.py @@ -1,20 +1,19 @@ import functools import logging +from multiprocessing import Value import os import random import threading import time -from multiprocessing import Value from unittest.mock import patch -from octue.cloud.emulators._pub_sub import MockAnalysis, MockService +from octue.cloud.emulators._pub_sub import MockService from octue.cloud.emulators.service import ServicePatcher from octue.resources.child import Child from octue.resources.service_backends import GCPPubSubBackend from tests import MOCK_SERVICE_REVISION_TAG from tests.base import BaseTestCase - lock = threading.Lock() @@ -23,7 +22,7 @@ def mock_run_function_that_fails(analysis_id, input_values, *args, **kwargs): raise ValueError("Deliberately raised for `Child.ask` test.") -def mock_run_function_that_fails_every_other_time(analysis_id, input_values, *args, **kwargs): +def mock_run_function_that_fails_every_other_time(analysis, analysis_id, input_values, *args, **kwargs): """A run function that always fails every other time, starting with the first time.""" with lock: # Every other question will fail. 
@@ -32,7 +31,8 @@ def mock_run_function_that_fails_every_other_time(analysis_id, input_values, *ar raise ValueError("Deliberately raised for `Child.ask` test.") time.sleep(random.random() * 0.1) - return MockAnalysis(output_values=input_values) + analysis.output_values = input_values + return analysis class TestChild(BaseTestCase): @@ -77,8 +77,9 @@ def test_instantiating_child_without_credentials(self): def test_child_can_be_asked_multiple_questions_in_serial(self): """Test that a child can be asked multiple questions in serial.""" - def mock_run_function(analysis_id, input_values, *args, **kwargs): - return MockAnalysis(output_values=input_values) + def mock_run_function(analysis, analysis_id, input_values, *args, **kwargs): + analysis.output_values = input_values + return analysis responding_service = MockService(backend=GCPPubSubBackend(project_id="blah"), run_function=mock_run_function) @@ -144,7 +145,7 @@ def test_with_failed_question_retry(self): answer, _ = child.ask(input_values=[1, 2, 3, 4], raise_errors=False, max_retries=1) # Check that the question succeeds. - self.assertEqual(answer, {"output_manifest": None, "output_values": [1, 2, 3, 4]}) + self.assertEqual(answer, {"output_manifest": None, "output_values": [1, 2, 3, 4], "success": True}) def test_errors_logged_when_not_raised(self): """Test that errors from a question still failing after retries are exhausted are logged by default.""" @@ -163,8 +164,8 @@ def test_errors_logged_when_not_raised(self): with self.assertLogs(level=logging.ERROR) as logging_context: child.ask(input_values=[1, 2, 3, 4], raise_errors=False, max_retries=0) - self.assertIn("failed after 0 retries (see below for error).", logging_context.output[2]) - self.assertIn('raise ValueError("Deliberately raised for `Child.ask` test.")', logging_context.output[2]) + self.assertIn("failed after 0 retries (see below for error).", logging_context.output[3]) + self.assertIn('raise ValueError("Deliberately raised for `Child.ask` test.")', logging_context.output[3]) def test_with_prevented_retries(self): """Test that retries can be prevented for specified exception types.""" @@ -216,9 +217,10 @@ def tearDownClass(cls): def test_ask_multiple(self): """Test that a child can be asked multiple questions in parallel and return the answers in the correct order.""" - def mock_run_function(analysis_id, input_values, *args, **kwargs): + def mock_run_function(analysis, analysis_id, input_values, *args, **kwargs): time.sleep(random.randint(0, 2)) - return MockAnalysis(output_values=input_values) + analysis.output_values = input_values + return analysis responding_service = MockService(backend=GCPPubSubBackend(project_id="blah"), run_function=mock_run_function) @@ -238,8 +240,8 @@ def mock_run_function(analysis_id, input_values, *args, **kwargs): self.assertEqual( [answer[0] for answer in answers], [ - {"output_values": [1, 2, 3, 4], "output_manifest": None}, - {"output_values": [5, 6, 7, 8], "output_manifest": None}, + {"output_values": [1, 2, 3, 4], "output_manifest": None, "success": True}, + {"output_values": [5, 6, 7, 8], "output_manifest": None, "success": True}, ], ) @@ -280,9 +282,9 @@ def test_with_multiple_failed_question_retries(self): self.assertEqual( [answer[0] for answer in answers], [ - {"output_manifest": None, "output_values": [1, 2, 3, 4]}, - {"output_manifest": None, "output_values": [5, 6, 7, 8]}, - {"output_manifest": None, "output_values": [9, 10, 11, 12]}, - {"output_manifest": None, "output_values": [13, 14, 15, 16]}, + {"output_manifest": None, 
"output_values": [1, 2, 3, 4], "success": True}, + {"output_manifest": None, "output_values": [5, 6, 7, 8], "success": True}, + {"output_manifest": None, "output_values": [9, 10, 11, 12], "success": True}, + {"output_manifest": None, "output_values": [13, 14, 15, 16], "success": True}, ], )