Skip to content

Commit 301a137

Browse files
committed
Make when_all and when_any tasks composable #29
Signed-off-by: evhen14 <[email protected]>
1 parent 3854b18 commit 301a137

File tree

2 files changed

+47
-0
lines changed

2 files changed

+47
-0
lines changed

durabletask/task.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -322,6 +322,8 @@ def on_child_completed(self, task: Task[T]):
322322
# The order of the result MUST match the order of the tasks provided to the constructor.
323323
self._result = [task.get_result() for task in self._tasks]
324324
self._is_complete = True
325+
if self._parent is not None:
326+
self._parent.on_child_completed(self)
325327

326328
def get_completed_tasks(self) -> int:
327329
return self._completed_tasks
@@ -423,6 +425,8 @@ def on_child_completed(self, task: Task):
423425
if not self.is_complete:
424426
self._is_complete = True
425427
self._result = task
428+
if self._parent is not None:
429+
self._parent.on_child_completed(self)
426430

427431

428432
def when_all(tasks: list[Task[T]]) -> WhenAllTask[T]:

tests/durabletask/test_task.py

Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,49 @@ def test_when_all_happy_path_returns_ordered_results_and_completes_last():
4646
assert all_task.get_result() == ["one", "two", "three"]
4747

4848

49+
def test_when_all_is_composable_with_when_any():
50+
c1 = task.CompletableTask()
51+
c2 = task.CompletableTask()
52+
53+
any_task = task.when_any([c1, c2])
54+
all_task = task.when_all([any_task])
55+
56+
assert not any_task.is_complete
57+
assert not all_task.is_complete
58+
59+
c2.complete("two")
60+
61+
assert any_task.is_complete
62+
assert all_task.is_complete
63+
64+
assert all_task.is_complete
65+
66+
assert all_task.get_result() == [c2]
67+
68+
69+
def test_when_any_is_composable_with_when_all():
70+
c1 = task.CompletableTask()
71+
c2 = task.CompletableTask()
72+
c3 = task.CompletableTask()
73+
74+
all_task1 = task.when_all([c1, c2])
75+
all_task2 = task.when_all([c3])
76+
any_task = task.when_any([all_task1, all_task2])
77+
78+
assert not any_task.is_complete
79+
assert not all_task1.is_complete
80+
assert not all_task2.is_complete
81+
82+
c1.complete("one")
83+
c2.complete("two")
84+
85+
assert any_task.is_complete
86+
assert all_task1.is_complete
87+
assert not all_task2.is_complete
88+
89+
assert any_task.get_result() == all_task1
90+
91+
4992
def test_when_any_happy_path_returns_winner_task_and_completes_on_first():
5093
a = task.CompletableTask()
5194
b = task.CompletableTask()

0 commit comments

Comments
 (0)