Skip to content

Commit b59de98

Browse files
committed
add new deterministic functions, non-retryable errors, and shutdown helpers
Signed-off-by: Filinto Duran <[email protected]>
1 parent 3854b18 commit b59de98

File tree

11 files changed

+849
-124
lines changed

11 files changed

+849
-124
lines changed

dev-requirements.txt

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1 +0,0 @@
1-
grpcio-tools==1.62.3 # 1.62.X is the latest version before protobuf 1.26.X is used which has breaking changes for Python # supports protobuf 6.x and aligns with generated code

durabletask/client.py

Lines changed: 72 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -127,9 +127,28 @@ def __init__(
127127
interceptors=interceptors,
128128
options=channel_options,
129129
)
130+
self._channel = channel
130131
self._stub = stubs.TaskHubSidecarServiceStub(channel)
131132
self._logger = shared.get_logger("client", log_handler, log_formatter)
132133

134+
def __enter__(self):
135+
return self
136+
137+
def __exit__(self, exc_type, exc, tb):
138+
try:
139+
self.close()
140+
finally:
141+
return False
142+
143+
def close(self) -> None:
144+
"""Close the underlying gRPC channel."""
145+
try:
146+
# grpc.Channel.close() is idempotent
147+
self._channel.close()
148+
except Exception:
149+
# Best-effort cleanup
150+
pass
151+
133152
def schedule_new_orchestration(
134153
self,
135154
orchestrator: Union[task.Orchestrator[TInput, TOutput], str],
@@ -188,10 +207,59 @@ def wait_for_orchestration_completion(
188207
) -> Optional[OrchestrationState]:
189208
req = pb.GetInstanceRequest(instanceId=instance_id, getInputsAndOutputs=fetch_payloads)
190209
try:
191-
grpc_timeout = None if timeout == 0 else timeout
192-
self._logger.info(
193-
f"Waiting {'indefinitely' if timeout == 0 else f'up to {timeout}s'} for instance '{instance_id}' to complete."
194-
)
210+
# gRPC timeout mapping (pytest unit tests may pass None explicitly)
211+
grpc_timeout = None if (timeout is None or timeout == 0) else timeout
212+
213+
# If timeout is None or 0, skip pre-checks/polling and call server-side wait directly
214+
if timeout is None or timeout == 0:
215+
self._logger.info(
216+
f"Waiting {'indefinitely' if not timeout else f'up to {timeout}s'} for instance '{instance_id}' to complete."
217+
)
218+
res: pb.GetInstanceResponse = self._stub.WaitForInstanceCompletion(
219+
req, timeout=grpc_timeout
220+
)
221+
state = new_orchestration_state(req.instanceId, res)
222+
return state
223+
224+
# For positive timeout, best-effort pre-check and short polling to avoid long server waits
225+
try:
226+
# First check if the orchestration is already completed
227+
current_state = self.get_orchestration_state(
228+
instance_id, fetch_payloads=fetch_payloads
229+
)
230+
if current_state and current_state.runtime_status in [
231+
OrchestrationStatus.COMPLETED,
232+
OrchestrationStatus.FAILED,
233+
OrchestrationStatus.TERMINATED,
234+
]:
235+
return current_state
236+
237+
# Poll for completion with exponential backoff to handle eventual consistency
238+
import time
239+
240+
poll_timeout = min(timeout, 10)
241+
poll_start = time.time()
242+
poll_interval = 0.1
243+
244+
while time.time() - poll_start < poll_timeout:
245+
current_state = self.get_orchestration_state(
246+
instance_id, fetch_payloads=fetch_payloads
247+
)
248+
249+
if current_state and current_state.runtime_status in [
250+
OrchestrationStatus.COMPLETED,
251+
OrchestrationStatus.FAILED,
252+
OrchestrationStatus.TERMINATED,
253+
]:
254+
return current_state
255+
256+
time.sleep(poll_interval)
257+
poll_interval = min(poll_interval * 1.5, 1.0) # Exponential backoff, max 1s
258+
except Exception:
259+
# Ignore pre-check/poll issues (e.g., mocked stubs in unit tests) and fall back
260+
pass
261+
262+
self._logger.info(f"Waiting up to {timeout}s for instance '{instance_id}' to complete.")
195263
res: pb.GetInstanceResponse = self._stub.WaitForInstanceCompletion(
196264
req, timeout=grpc_timeout
197265
)

durabletask/deterministic.py

Lines changed: 134 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,134 @@
1+
# Copyright (c) Microsoft Corporation.
2+
# Licensed under the MIT License.
3+
4+
"""
5+
Deterministic utilities for Durable Task workflows (async and generator).
6+
7+
This module provides deterministic alternatives to non-deterministic Python
8+
functions, ensuring workflow replay consistency across different executions.
9+
It is shared by both the asyncio authoring model and the generator-based model.
10+
"""
11+
12+
from __future__ import annotations
13+
14+
import hashlib
15+
import random
16+
import string as _string
17+
import uuid
18+
from collections.abc import Sequence
19+
from dataclasses import dataclass
20+
from datetime import datetime
21+
from typing import Optional, Protocol, TypeVar, runtime_checkable
22+
23+
24+
@dataclass
25+
class DeterminismSeed:
26+
"""Seed data for deterministic operations."""
27+
28+
instance_id: str
29+
orchestration_unix_ts: int
30+
31+
def to_int(self) -> int:
32+
"""Convert seed to integer for PRNG initialization."""
33+
combined = f"{self.instance_id}:{self.orchestration_unix_ts}"
34+
hash_bytes = hashlib.sha256(combined.encode("utf-8")).digest()
35+
return int.from_bytes(hash_bytes[:8], byteorder="big")
36+
37+
38+
def derive_seed(instance_id: str, orchestration_time: datetime) -> int:
39+
"""
40+
Derive a deterministic seed from instance ID and orchestration time.
41+
"""
42+
ts = int(orchestration_time.timestamp())
43+
return DeterminismSeed(instance_id=instance_id, orchestration_unix_ts=ts).to_int()
44+
45+
46+
def deterministic_random(instance_id: str, orchestration_time: datetime) -> random.Random:
47+
"""
48+
Create a deterministic random number generator.
49+
"""
50+
seed = derive_seed(instance_id, orchestration_time)
51+
return random.Random(seed)
52+
53+
54+
def deterministic_uuid4(rnd: random.Random) -> uuid.UUID:
55+
"""Generate a deterministic UUID4 using the provided random generator."""
56+
bytes_ = bytes(rnd.randrange(0, 256) for _ in range(16))
57+
bytes_list = list(bytes_)
58+
bytes_list[6] = (bytes_list[6] & 0x0F) | 0x40 # Version 4
59+
bytes_list[8] = (bytes_list[8] & 0x3F) | 0x80 # Variant bits
60+
return uuid.UUID(bytes=bytes(bytes_list))
61+
62+
63+
@runtime_checkable
64+
class DeterministicContextProtocol(Protocol):
65+
"""Protocol for contexts that provide deterministic operations."""
66+
67+
@property
68+
def instance_id(self) -> str: ...
69+
70+
@property
71+
def current_utc_datetime(self) -> datetime: ...
72+
73+
74+
class DeterministicContextMixin:
75+
"""
76+
Mixin providing deterministic helpers for workflow contexts.
77+
78+
Assumes the inheriting class exposes `instance_id` and `current_utc_datetime` attributes.
79+
"""
80+
81+
def now(self) -> datetime:
82+
"""Return orchestration time (deterministic UTC)."""
83+
value = self.current_utc_datetime # type: ignore[attr-defined]
84+
assert isinstance(value, datetime)
85+
return value
86+
87+
def random(self) -> random.Random:
88+
"""Return a PRNG seeded deterministically from instance id and orchestration time."""
89+
rnd = deterministic_random(
90+
self.instance_id, # type: ignore[attr-defined]
91+
self.current_utc_datetime, # type: ignore[attr-defined]
92+
)
93+
# Mark as deterministic for sandbox detector whitelisting of bound methods
94+
try:
95+
rnd._dt_deterministic = True
96+
except Exception:
97+
pass
98+
return rnd
99+
100+
def uuid4(self) -> uuid.UUID:
101+
"""Return a deterministically generated UUID using the deterministic PRNG."""
102+
rnd = self.random()
103+
return deterministic_uuid4(rnd)
104+
105+
def new_guid(self) -> uuid.UUID:
106+
"""Alias for uuid4 for API parity with other SDKs."""
107+
return self.uuid4()
108+
109+
def random_string(self, length: int, *, alphabet: Optional[str] = None) -> str:
110+
"""Return a deterministically generated random string of the given length."""
111+
if length < 0:
112+
raise ValueError("length must be non-negative")
113+
chars = alphabet if alphabet is not None else (_string.ascii_letters + _string.digits)
114+
if not chars:
115+
raise ValueError("alphabet must not be empty")
116+
rnd = self.random()
117+
size = len(chars)
118+
return "".join(chars[rnd.randrange(0, size)] for _ in range(length))
119+
120+
def random_int(self, min_value: int = 0, max_value: int = 2**31 - 1) -> int:
121+
"""Return a deterministic random integer in the specified range."""
122+
if min_value > max_value:
123+
raise ValueError("min_value must be <= max_value")
124+
rnd = self.random()
125+
return rnd.randint(min_value, max_value)
126+
127+
T = TypeVar("T")
128+
129+
def random_choice(self, sequence: Sequence[T]) -> T:
130+
"""Return a deterministic random element from a non-empty sequence."""
131+
if not sequence:
132+
raise IndexError("Cannot choose from empty sequence")
133+
rnd = self.random()
134+
return rnd.choice(sequence)

durabletask/internal/shared.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -102,7 +102,7 @@ def get_logger(
102102
# Add a default log handler if none is provided
103103
if log_handler is None:
104104
log_handler = logging.StreamHandler()
105-
log_handler.setLevel(logging.INFO)
105+
log_handler.setLevel(logging.DEBUG)
106106
logger.handlers.append(log_handler)
107107

108108
# Set a default log formatter to our handler if none is provided

durabletask/task.py

Lines changed: 37 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
import math
88
from abc import ABC, abstractmethod
99
from datetime import datetime, timedelta
10-
from typing import Any, Callable, Generator, Generic, Optional, TypeVar, Union
10+
from typing import Any, Callable, Generic, Generator, Optional, TypeVar, Union, cast
1111

1212
import durabletask.internal.helpers as pbh
1313
import durabletask.internal.orchestrator_service_pb2 as pb
@@ -233,6 +233,16 @@ class OrchestrationStateError(Exception):
233233
pass
234234

235235

236+
class NonRetryableError(Exception):
237+
"""Exception indicating the operation should not be retried.
238+
239+
If an activity or sub-orchestration raises this exception, retry logic will be
240+
bypassed and the failure will be returned immediately to the orchestrator.
241+
"""
242+
243+
pass
244+
245+
236246
class Task(ABC, Generic[T]):
237247
"""Abstract base class for asynchronous tasks in a durable orchestration."""
238248

@@ -395,7 +405,7 @@ def compute_next_delay(self) -> Optional[timedelta]:
395405
next_delay_f = min(
396406
next_delay_f, self._retry_policy.max_retry_interval.total_seconds()
397407
)
398-
return timedelta(seconds=next_delay_f)
408+
return timedelta(seconds=next_delay_f)
399409

400410
return None
401411

@@ -486,6 +496,7 @@ def __init__(
486496
backoff_coefficient: Optional[float] = 1.0,
487497
max_retry_interval: Optional[timedelta] = None,
488498
retry_timeout: Optional[timedelta] = None,
499+
non_retryable_error_types: Optional[list[Union[str, type]]] = None,
489500
):
490501
"""Creates a new RetryPolicy instance.
491502
@@ -501,6 +512,11 @@ def __init__(
501512
The maximum retry interval to use for any retry attempt.
502513
retry_timeout : Optional[timedelta]
503514
The maximum amount of time to spend retrying the operation.
515+
non_retryable_error_types : Optional[list[Union[str, type]]]
516+
A list of exception type names or classes that should not be retried.
517+
If a failure's error type matches any of these, the task fails immediately.
518+
The built-in NonRetryableError is always treated as non-retryable regardless
519+
of this setting.
504520
"""
505521
# validate inputs
506522
if first_retry_interval < timedelta(seconds=0):
@@ -519,6 +535,17 @@ def __init__(
519535
self._backoff_coefficient = backoff_coefficient
520536
self._max_retry_interval = max_retry_interval
521537
self._retry_timeout = retry_timeout
538+
# Normalize non-retryable error type names to a set of strings
539+
names: Optional[set[str]] = None
540+
if non_retryable_error_types:
541+
names = set()
542+
for t in non_retryable_error_types:
543+
if isinstance(t, str):
544+
if t:
545+
names.add(t)
546+
elif isinstance(t, type):
547+
names.add(t.__name__)
548+
self._non_retryable_error_types = names
522549

523550
@property
524551
def first_retry_interval(self) -> timedelta:
@@ -545,6 +572,14 @@ def retry_timeout(self) -> Optional[timedelta]:
545572
"""The maximum amount of time to spend retrying the operation."""
546573
return self._retry_timeout
547574

575+
@property
576+
def non_retryable_error_types(self) -> Optional[set[str]]:
577+
"""Set of error type names that should not be retried.
578+
579+
Comparison is performed against the errorType string provided by the
580+
backend (typically the exception class name).
581+
"""
582+
return self._non_retryable_error_types
548583

549584
def get_name(fn: Callable) -> str:
550585
"""Returns the name of the provided function"""

0 commit comments

Comments
 (0)