|  | 
| 1 | 1 | """Contains code relevant to the execution.""" | 
| 2 | 2 | import sys | 
| 3 | 3 | import time | 
|  | 4 | +from typing import Any | 
|  | 5 | +from typing import Tuple | 
| 4 | 6 | 
 | 
| 5 | 7 | import cloudpickle | 
| 6 | 8 | from _pytask.config import hookimpl | 
|  | 9 | +from _pytask.console import console | 
| 7 | 10 | from _pytask.report import ExecutionReport | 
|  | 11 | +from _pytask.traceback import remove_internal_traceback_frames_from_exc_info | 
| 8 | 12 | from pytask_parallel.backends import PARALLEL_BACKENDS | 
|  | 13 | +from rich.console import ConsoleOptions | 
|  | 14 | +from rich.traceback import Traceback | 
| 9 | 15 | 
 | 
| 10 | 16 | 
 | 
| 11 | 17 | @hookimpl | 
| 12 | 18 | def pytask_post_parse(config): | 
| 13 | 19 |     """Register the parallel backend.""" | 
| 14 |  | -    if config["parallel_backend"] == "processes": | 
|  | 20 | +    if config["parallel_backend"] in ["loky", "processes"]: | 
| 15 | 21 |         config["pm"].register(ProcessesNameSpace) | 
| 16 |  | -    elif config["parallel_backend"] in ["threads", "loky"]: | 
|  | 22 | +    elif config["parallel_backend"] in ["threads"]: | 
| 17 | 23 |         config["pm"].register(DefaultBackendNameSpace) | 
| 18 | 24 | 
 | 
| 19 | 25 | 
 | 
| @@ -72,13 +78,23 @@ def pytask_execute_build(session): | 
| 72 | 78 | 
 | 
| 73 | 79 |                     for task_name in list(running_tasks): | 
| 74 | 80 |                         future = running_tasks[task_name] | 
| 75 |  | -                        if future.done() and future.exception() is not None: | 
|  | 81 | +                        if future.done() and ( | 
|  | 82 | +                            future.exception() is not None | 
|  | 83 | +                            or future.result() is not None | 
|  | 84 | +                        ): | 
| 76 | 85 |                             task = session.dag.nodes[task_name]["task"] | 
| 77 |  | -                            exception = future.exception() | 
| 78 |  | -                            newly_collected_reports.append( | 
| 79 |  | -                                ExecutionReport.from_task_and_exception( | 
| 80 |  | -                                    task, (type(exception), exception, None) | 
|  | 86 | +                            if future.exception() is not None: | 
|  | 87 | +                                exception = future.exception() | 
|  | 88 | +                                exc_info = ( | 
|  | 89 | +                                    type(exception), | 
|  | 90 | +                                    exception, | 
|  | 91 | +                                    exception.__traceback__, | 
| 81 | 92 |                                 ) | 
|  | 93 | +                            else: | 
|  | 94 | +                                exc_info = future.result() | 
|  | 95 | + | 
|  | 96 | +                            newly_collected_reports.append( | 
|  | 97 | +                                ExecutionReport.from_task_and_exception(task, exc_info) | 
| 82 | 98 |                             ) | 
| 83 | 99 |                             running_tasks.pop(task_name) | 
| 84 | 100 |                             session.scheduler.done(task_name) | 
| @@ -132,18 +148,41 @@ def pytask_execute_task(session, task):  # noqa: N805 | 
| 132 | 148 |         """ | 
| 133 | 149 |         if session.config["n_workers"] > 1: | 
| 134 | 150 |             bytes_ = cloudpickle.dumps(task) | 
| 135 |  | -            return session.executor.submit(unserialize_and_execute_task, bytes_) | 
|  | 151 | +            return session.executor.submit( | 
|  | 152 | +                _unserialize_and_execute_task, | 
|  | 153 | +                bytes_=bytes_, | 
|  | 154 | +                show_locals=session.config["show_locals"], | 
|  | 155 | +                console_options=console.options, | 
|  | 156 | +            ) | 
| 136 | 157 | 
 | 
| 137 | 158 | 
 | 
| 138 |  | -def unserialize_and_execute_task(bytes_): | 
|  | 159 | +def _unserialize_and_execute_task(bytes_, show_locals, console_options): | 
| 139 | 160 |     """Unserialize and execute task. | 
| 140 | 161 | 
 | 
| 141 | 162 |     This function receives bytes and unpickles them to a task which is them execute | 
| 142 | 163 |     in a spawned process or thread. | 
| 143 | 164 | 
 | 
| 144 | 165 |     """ | 
|  | 166 | +    __tracebackhide__ = True | 
|  | 167 | + | 
| 145 | 168 |     task = cloudpickle.loads(bytes_) | 
| 146 |  | -    task.execute() | 
|  | 169 | + | 
|  | 170 | +    try: | 
|  | 171 | +        task.execute() | 
|  | 172 | +    except Exception: | 
|  | 173 | +        exc_info = sys.exc_info() | 
|  | 174 | +        processed_exc_info = _process_exception(exc_info, show_locals, console_options) | 
|  | 175 | +        return processed_exc_info | 
|  | 176 | + | 
|  | 177 | + | 
|  | 178 | +def _process_exception( | 
|  | 179 | +    exc_info: Tuple[Any], show_locals: bool, console_options: ConsoleOptions | 
|  | 180 | +) -> Tuple[Any]: | 
|  | 181 | +    exc_info = remove_internal_traceback_frames_from_exc_info(exc_info) | 
|  | 182 | +    traceback = Traceback.from_exception(*exc_info, show_locals=show_locals) | 
|  | 183 | +    segments = console.render(traceback, options=console_options) | 
|  | 184 | +    text = "".join(segment.text for segment in segments) | 
|  | 185 | +    return (*exc_info[:2], text) | 
| 147 | 186 | 
 | 
| 148 | 187 | 
 | 
| 149 | 188 | class DefaultBackendNameSpace: | 
|  | 
0 commit comments