Skip to content

Commit d882695

Browse files
committed
feat(core): Add RunFunctionWaitStrategy
This is usefull for converting old wait_container_is_ready which use container.exec() to check for a condition to become True.
1 parent 5c1504c commit d882695

File tree

2 files changed

+94
-0
lines changed

2 files changed

+94
-0
lines changed

core/testcontainers/core/wait_strategies.py

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -703,6 +703,44 @@ def _get_status_compose_container(container: DockerCompose) -> str:
703703
raise NotImplementedError
704704

705705

706+
class RunFunctionWaitStrategy(WaitStrategy):
707+
"""Runs a functions and waits until it succeeds.
708+
709+
The function must take a single argument, the WaitStrategyTarget (= DockerContainer)
710+
(use a lambda to capture other arguments) and must return a Boolean or raise an Exception.
711+
712+
Args:
713+
func: The function to run. It must return True when the wait is over.
714+
"""
715+
716+
def __init__(
717+
self,
718+
func: Callable[[WaitStrategyTarget], bool],
719+
):
720+
super().__init__()
721+
self.func = func
722+
723+
def wait_until_ready(self, container: WaitStrategyTarget) -> Any:
724+
start_time = time.time()
725+
last_exception = None
726+
while True:
727+
try:
728+
result = self.func(container)
729+
if result:
730+
return result
731+
except tuple(self._transient_exceptions) as e:
732+
logger.debug(f"Check attempt failed: {e!s}")
733+
last_exception = str(e)
734+
if time.time() - start_time > self._startup_timeout:
735+
raise TimeoutError(
736+
f"Wait time ({self._startup_timeout}s) exceeded for {self.func.__name__}"
737+
f"Exception: {last_exception}. "
738+
f"Hint: Check if the container is ready, "
739+
f"and the expected conditions are met for the function to succeed."
740+
)
741+
time.sleep(self._poll_interval)
742+
743+
706744
class CompositeWaitStrategy(WaitStrategy):
707745
"""
708746
Wait for multiple conditions to be satisfied in sequence.
@@ -787,6 +825,7 @@ def wait_until_ready(self, container: WaitStrategyTarget) -> None:
787825
"HttpWaitStrategy",
788826
"LogMessageWaitStrategy",
789827
"PortWaitStrategy",
828+
"RunFunctionWaitStrategy",
790829
"WaitStrategy",
791830
"WaitStrategyTarget",
792831
]

core/tests/test_wait_strategies.py

Lines changed: 55 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
LogMessageWaitStrategy,
1616
PortWaitStrategy,
1717
WaitStrategy,
18+
RunFunctionWaitStrategy,
1819
)
1920

2021

@@ -550,6 +551,60 @@ def test_wait_until_ready(self, mock_sleep, mock_time, mock_is_file, file_exists
550551
strategy.wait_until_ready(mock_container)
551552

552553

554+
class TestRunFunctionWaitStrategy:
555+
"""Test the RunFunctionWaitStrategy class."""
556+
557+
def test_run_function_wait_strategy_initialization(self):
558+
func = lambda x: True
559+
strategy = RunFunctionWaitStrategy(func)
560+
assert strategy.func == func
561+
562+
def test_run_function_wait_strategy_wait_until_ready(self):
563+
returns = [False, False, True]
564+
mock_container = object()
565+
566+
def func(target) -> bool:
567+
assert target is mock_container
568+
return returns.pop(0)
569+
570+
strategy = RunFunctionWaitStrategy(func).with_poll_interval(0)
571+
strategy.wait_until_ready(mock_container) # type: ignore[arg-type]
572+
573+
def test_run_function_wait_strategy_wait_until_ready_with_unknown_exception(self):
574+
mock_container = object()
575+
576+
def func(target) -> bool:
577+
assert target is mock_container
578+
raise RuntimeError("Unknown error, abort!")
579+
580+
strategy = RunFunctionWaitStrategy(func).with_poll_interval(0)
581+
with pytest.raises(RuntimeError, match="Unknown error, abort!"):
582+
strategy.wait_until_ready(mock_container) # type: ignore[arg-type]
583+
584+
@pytest.mark.parametrize("transient_exception", [ConnectionError, NotImplementedError])
585+
def test_run_function_wait_strategy_wait_until_ready_with_transient_exception(self, transient_exception):
586+
mock_container = object()
587+
returns = [False, False, True]
588+
589+
def func(target) -> bool:
590+
assert target is mock_container
591+
if returns.pop(0):
592+
return True
593+
raise transient_exception("Go on")
594+
595+
# ConnectionError should be in the default transient exceptions, but NotImplementedError ist not
596+
strategy = (
597+
RunFunctionWaitStrategy(func).with_poll_interval(0.001).with_transient_exceptions(NotImplementedError)
598+
)
599+
strategy.wait_until_ready(mock_container) # type: ignore[arg-type]
600+
601+
def test_run_function_wait_strategy_wait_until_ready_with_timeout(self):
602+
mock_container = object()
603+
strategy = RunFunctionWaitStrategy(lambda x: False).with_poll_interval(0).with_startup_timeout(0)
604+
with pytest.raises(TimeoutError, match=r"Wait time (.*) exceeded for"):
605+
strategy.wait_until_ready(mock_container) # type: ignore[arg-type]
606+
607+
553608
class TestCompositeWaitStrategy:
554609
"""Test the CompositeWaitStrategy class."""
555610

0 commit comments

Comments
 (0)