diff --git a/invoke/tasks.py b/invoke/tasks.py index cd3075e9..03888f19 100644 --- a/invoke/tasks.py +++ b/invoke/tasks.py @@ -32,6 +32,14 @@ T = TypeVar("T", bound=Callable) +class PreTaskDesc: + def __get__(self, obj, _type): + return obj._pre + + def __set__(self, obj, value): + for p in value: + p._is_pre_of = obj + obj._pre = value class Task(Generic[T]): """ @@ -48,6 +56,7 @@ class Task(Generic[T]): .. versionadded:: 1.0 """ + pre = PreTaskDesc() # TODO: store these kwarg defaults central, refer to those values both here # and in @task. @@ -395,6 +404,11 @@ def __init__( Keyword arguments to call with, if any. Default: ``None``. """ self.task = task + if hasattr(task, "_is_pre_of"): + self.make_context = _copy_attrs_to_return_val( + task, + "_is_pre_of" + )(self.make_context) self.called_as = called_as self.args = args or tuple() self.kwargs = kwargs or dict() @@ -517,3 +531,18 @@ def clean_build(c): .. versionadded:: 1.0 """ return Call(task, args=args, kwargs=kwargs) + +def _copy_attrs_to_return_val(source, *attrs): + """ + Copy attributes from a source to the return value of the decorated func + """ + def _wrapper(func): + def _inner(*args, **kwargs): + target = func(*args, **kwargs) + for name in attrs: + value = getattr(source, name) + if value: + setattr(target, name, value) + return target + return _inner + return _wrapper diff --git a/tests/task.py b/tests/task.py index d60d9123..8bff9bce 100644 --- a/tests/task.py +++ b/tests/task.py @@ -92,6 +92,33 @@ def func(c): assert func.pre == [whatever] + def task_and_pre_tasks_binding(self): + + @task + def pre_task(c): + pass + + @task(pre=[pre_task]) + def my_task(c): + pass + + assert all([hasattr(p, "_is_pre_of") for p in my_task.pre]) + assert pre_task._is_pre_of == my_task + + def create_call_and_context_form_pre_task_has_access_to_parent_task(self): + + @task + def pre_task(c): + pass + + @task(pre=[pre_task]) + def my_task(c): + pass + + call = Call(pre_task) + c = call.make_context(Config(defaults={})) + assert getattr(c, "_is_pre_of") == my_task + def allows_star_args_as_shortcut_for_pre(self): @task def pre1(c):