Skip to content

Commit c5deaa3

Browse files
committed
feat: add ActionsClient.wait_for function
This function allows the users to wait for multiple actions in an efficient way. All actions are queried using a single call, which reduce the potential for running into rate limits.
1 parent bcbdd7d commit c5deaa3

File tree

6 files changed

+306
-7
lines changed

6 files changed

+306
-7
lines changed

hcloud/_utils.py

Lines changed: 60 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,60 @@
1+
from __future__ import annotations
2+
3+
import time
4+
from collections.abc import Iterable, Iterator
5+
from itertools import islice
6+
from typing import Callable, TypeVar
7+
8+
T = TypeVar("T")
9+
10+
11+
def batched(iterable: Iterable[T], size: int) -> Iterator[tuple[T, ...]]:
12+
"""
13+
Returns a batch of the provided size from the provided iterable.
14+
"""
15+
iterator = iter(iterable)
16+
while True:
17+
batch = tuple(islice(iterator, size))
18+
if not batch:
19+
break
20+
yield batch
21+
22+
23+
def waiter(timeout: float | None = None) -> Callable[[float], bool]:
24+
"""
25+
Waiter returns a wait function that sleeps the specified amount of seconds, and
26+
handles timeouts.
27+
28+
The wait function returns True if the timeout was reached, False otherwise.
29+
30+
:param timeout: Timeout in seconds, defaults to None.
31+
:return: Wait function.
32+
"""
33+
34+
if timeout:
35+
deadline = time.time() + timeout
36+
37+
def wait(seconds: float) -> bool:
38+
now = time.time()
39+
40+
# Timeout if the deadline exceeded.
41+
if deadline < now:
42+
return True
43+
44+
# The deadline is not exceeded after the sleep time.
45+
if now + seconds < deadline:
46+
time.sleep(seconds)
47+
return False
48+
49+
# The deadline is exceeded after the sleep time, clamp sleep time to
50+
# deadline, and allow one last attempt until next wait call.
51+
time.sleep(deadline - now)
52+
return False
53+
54+
else:
55+
56+
def wait(seconds: float) -> bool:
57+
time.sleep(seconds)
58+
return False
59+
60+
return wait

hcloud/actions/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
Action,
1111
ActionException,
1212
ActionFailedException,
13+
ActionGroupException,
1314
ActionTimeoutException,
1415
)
1516

@@ -18,6 +19,7 @@
1819
"ActionException",
1920
"ActionFailedException",
2021
"ActionTimeoutException",
22+
"ActionGroupException",
2123
"ActionsClient",
2224
"ActionsPageResult",
2325
"BoundAction",

hcloud/actions/client.py

Lines changed: 127 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,16 @@
11
from __future__ import annotations
22

3-
import time
43
import warnings
5-
from typing import TYPE_CHECKING, Any, NamedTuple
4+
from typing import TYPE_CHECKING, Any, Callable, NamedTuple
65

6+
from .._utils import batched, waiter
77
from ..core import BoundModelBase, ClientEntityBase, Meta
8-
from .domain import Action, ActionFailedException, ActionTimeoutException
8+
from .domain import (
9+
Action,
10+
ActionFailedException,
11+
ActionGroupException,
12+
ActionTimeoutException,
13+
)
914

1015
if TYPE_CHECKING:
1116
from .._client import Client
@@ -16,18 +21,24 @@ class BoundAction(BoundModelBase, Action):
1621

1722
model = Action
1823

19-
def wait_until_finished(self, max_retries: int | None = None) -> None:
24+
def wait_until_finished(
25+
self,
26+
max_retries: int | None = None,
27+
*,
28+
timeout: float | None = None,
29+
) -> None:
2030
"""Wait until the specific action has status=finished.
2131
2232
:param max_retries: int Specify how many retries will be performed before an ActionTimeoutException will be raised.
2333
:raises: ActionFailedException when action is finished with status==error
24-
:raises: ActionTimeoutException when Action is still in status==running after max_retries is reached.
34+
:raises: ActionTimeoutException when Action is still in status==running after max_retries or timeout is reached.
2535
"""
2636
if max_retries is None:
2737
# pylint: disable=protected-access
2838
max_retries = self._client._client._poll_max_retries
2939

3040
retries = 0
41+
wait = waiter(timeout)
3142
while True:
3243
self.reload()
3344
if self.status != Action.STATUS_RUNNING:
@@ -36,8 +47,8 @@ def wait_until_finished(self, max_retries: int | None = None) -> None:
3647
retries += 1
3748
if retries < max_retries:
3849
# pylint: disable=protected-access
39-
time.sleep(self._client._client._poll_interval_func(retries))
40-
continue
50+
if not wait(self._client._client._poll_interval_func(retries)):
51+
continue
4152

4253
raise ActionTimeoutException(action=self)
4354

@@ -129,6 +140,115 @@ class ActionsClient(ResourceActionsClient):
129140
def __init__(self, client: Client):
130141
super().__init__(client, None)
131142

143+
# TODO: Consider making public?
144+
def _get_list_by_ids(self, ids: list[int]) -> list[BoundAction]:
145+
"""
146+
Get a list of Actions by their IDs.
147+
148+
:param ids: List of Action IDs to get.
149+
:raises ValueError: Raise when Action IDs were not found.
150+
:return: List of Actions.
151+
"""
152+
actions: list[BoundAction] = []
153+
154+
for ids_batch in batched(ids, 25):
155+
params: dict[str, Any] = {
156+
"id": ids_batch,
157+
}
158+
159+
response = self._client.request(
160+
method="GET",
161+
url="/actions",
162+
params=params,
163+
)
164+
165+
actions.extend(
166+
BoundAction(self._client.actions, action_data)
167+
for action_data in response["actions"]
168+
)
169+
170+
# TODO: Should this be moved to the the wait function?
171+
if len(ids) != len(actions):
172+
found_ids = [a.id for a in actions]
173+
not_found_ids = list(set(ids) - set(found_ids))
174+
175+
raise ValueError(
176+
f"actions not found: {', '.join(str(o) for o in not_found_ids)}"
177+
)
178+
179+
return actions
180+
181+
def wait_for_function(
182+
self,
183+
handle_update: Callable[[BoundAction], None],
184+
actions: list[Action | BoundAction],
185+
*,
186+
timeout: float | None = None,
187+
) -> list[BoundAction]:
188+
"""
189+
Waits until all Actions succeed by polling the API at the interval defined by
190+
the client's poll interval and function. An Action is considered as complete
191+
when its status is either "success" or "error".
192+
193+
The handle_update callback is called every time an Action is updated.
194+
195+
:param handle_update: Function called every time an Action is updated.
196+
:param actions: List of Actions to wait for.
197+
:param timeout: Timeout in seconds.
198+
:raises: ActionFailedException when an Action failed.
199+
:return: List of succeeded Actions.
200+
"""
201+
running: list[BoundAction] = list(actions)
202+
completed: list[BoundAction] = []
203+
204+
retries = 0
205+
wait = waiter(timeout)
206+
while len(running) > 0:
207+
# pylint: disable=protected-access
208+
if wait(self._client._poll_interval_func(retries)):
209+
raise ActionGroupException(
210+
[ActionTimeoutException(action=action) for action in running]
211+
)
212+
213+
retries += 1
214+
215+
running = self._get_list_by_ids([a.id for a in running])
216+
217+
for update in running:
218+
if update.status != Action.STATUS_RUNNING:
219+
running.remove(update)
220+
completed.append(update)
221+
222+
handle_update(update)
223+
224+
return completed
225+
226+
def wait_for(
227+
self,
228+
actions: list[Action | BoundAction],
229+
*,
230+
timeout: float | None = None,
231+
) -> list[BoundAction]:
232+
"""
233+
Waits until all Actions succeed by polling the API at the interval defined by
234+
the client's poll interval and function. An Action is considered as complete
235+
when its status is either "success" or "error".
236+
237+
If a single Action fails, the function will stop waiting and raise ActionFailedException.
238+
239+
:param actions: List of Actions to wait for.
240+
:param timeout: Timeout in seconds.
241+
:raises: ActionFailedException when an Action failed.
242+
:raises: TimeoutError when the Actions did not succeed before timeout.
243+
:return: List of succeeded Actions.
244+
"""
245+
246+
def handle_update(update: BoundAction) -> None:
247+
if update.status == Action.STATUS_ERROR:
248+
raise ActionFailedException(action=update)
249+
250+
return self.wait_for_function(handle_update, actions, timeout=timeout)
251+
132252
def get_list(
133253
self,
134254
status: list[str] | None = None,

hcloud/actions/domain.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -98,3 +98,11 @@ class ActionFailedException(ActionException):
9898

9999
class ActionTimeoutException(ActionException):
100100
"""The pending action timed out"""
101+
102+
103+
class ActionGroupException(HCloudException):
104+
"""An exception for a group of actions"""
105+
106+
def __init__(self, exceptions: list[ActionException]):
107+
super().__init__("Multiple pending actions failed")
108+
self.exceptions = exceptions

tests/unit/actions/test_client.py

Lines changed: 88 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
from hcloud.actions import (
88
Action,
99
ActionFailedException,
10+
ActionGroupException,
1011
ActionsClient,
1112
ActionTimeoutException,
1213
BoundAction,
@@ -197,3 +198,90 @@ def test_get_all(self, actions_client, generic_action_list, params):
197198
assert action2._client == actions_client._client.actions
198199
assert action2.id == 2
199200
assert action2.command == "stop_server"
201+
202+
def test_wait_for(self, actions_client: ActionsClient):
203+
actions = [Action(id=1), Action(id=2)]
204+
205+
# Speed up test by not really waiting
206+
actions_client._client._poll_interval_func = mock.MagicMock()
207+
actions_client._client._poll_interval_func.return_value = 0.1
208+
209+
actions_client._client.request.side_effect = [
210+
{
211+
"actions": [
212+
{"id": 1, "status": "running"},
213+
{"id": 2, "status": "success"},
214+
]
215+
},
216+
{
217+
"actions": [
218+
{"id": 1, "status": "success"},
219+
]
220+
},
221+
]
222+
223+
actions = actions_client.wait_for(actions)
224+
225+
actions_client._client.request.assert_has_calls(
226+
[
227+
mock.call(method="GET", url="/actions", params={"id": (1, 2)}),
228+
mock.call(method="GET", url="/actions", params={"id": (1,)}),
229+
]
230+
)
231+
232+
assert len(actions) == 2
233+
234+
def test_wait_for_error(self, actions_client: ActionsClient):
235+
actions = [Action(id=1), Action(id=2)]
236+
237+
# Speed up test by not really waiting
238+
actions_client._client._poll_interval_func = mock.MagicMock()
239+
actions_client._client._poll_interval_func.return_value = 0.1
240+
241+
actions_client._client.request.side_effect = [
242+
{
243+
"actions": [
244+
{"id": 1, "status": "running"},
245+
{
246+
"id": 2,
247+
"status": "error",
248+
"error": {"code": "failed", "message": "Action failed"},
249+
},
250+
]
251+
},
252+
]
253+
254+
with pytest.raises(ActionFailedException):
255+
actions_client.wait_for(actions)
256+
257+
actions_client._client.request.assert_has_calls(
258+
[
259+
mock.call(method="GET", url="/actions", params={"id": (1, 2)}),
260+
]
261+
)
262+
263+
def test_wait_for_timeout(self, actions_client: ActionsClient):
264+
actions = [
265+
Action(id=1, status="running", command="create_server"),
266+
Action(id=2, status="running", command="start_server"),
267+
]
268+
269+
# Speed up test by not really waiting
270+
actions_client._client._poll_interval_func = mock.MagicMock()
271+
actions_client._client._poll_interval_func.return_value = 0.1
272+
273+
actions_client._client.request.return_value = {
274+
"actions": [
275+
{"id": 1, "status": "running", "command": "create_server"},
276+
{"id": 2, "status": "running", "command": "start_server"},
277+
]
278+
}
279+
280+
with pytest.raises(ActionGroupException):
281+
actions_client.wait_for(actions, timeout=0.2)
282+
283+
actions_client._client.request.assert_has_calls(
284+
[
285+
mock.call(method="GET", url="/actions", params={"id": (1, 2)}),
286+
]
287+
)

tests/unit/test_utils.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,21 @@
1+
from __future__ import annotations
2+
3+
import time
4+
5+
from hcloud._utils import batched, waiter
6+
7+
8+
def test_batched():
9+
assert list(o for o in batched([1, 2, 3, 4, 5], 2)) == [(1, 2), (3, 4), (5,)]
10+
11+
12+
def test_waiter():
13+
wait = waiter(timeout=0.2)
14+
assert wait(0.1) is False
15+
time.sleep(0.2)
16+
assert wait(1) is True
17+
18+
# Clamp sleep to deadline
19+
wait = waiter(timeout=0.2)
20+
assert wait(0.3) is False
21+
assert wait(1) is True

0 commit comments

Comments
 (0)