Skip to content

Commit bec4d37

Browse files
authored
fix: raise stopIteration when get_next is exhausted (#70)
Raise `StopAsyncIteration` when the async iterator is exhausted and no default value is provided. Introduce a sentinel class for missing values to improve clarity in the code.
1 parent 63825d0 commit bec4d37

File tree

2 files changed

+17
-11
lines changed

2 files changed

+17
-11
lines changed

sqlspec/utils/sync_tools.py

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,13 @@
2929
T = TypeVar("T")
3030

3131

32+
class NoValue:
33+
"""Sentinel class for missing values."""
34+
35+
36+
NO_VALUE = NoValue()
37+
38+
3239
class CapacityLimiter:
3340
"""Limits the number of concurrent operations using a semaphore."""
3441

@@ -240,11 +247,7 @@ def with_ensure_async_(
240247
return obj
241248

242249

243-
class NoValue:
244-
"""Sentinel class for missing values."""
245-
246-
247-
async def get_next(iterable: Any, default: Any = NoValue, *args: Any) -> Any: # pragma: no cover
250+
async def get_next(iterable: Any, default: Any = NO_VALUE, *args: Any) -> Any: # pragma: no cover
248251
"""Return the next item from an async iterator.
249252
250253
Args:

tests/unit/test_utils/test_sync_tools.py

Lines changed: 9 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -320,6 +320,9 @@ async def test_get_next_with_default() -> None:
320320
"""Test get_next with default value when iterator is exhausted."""
321321

322322
class EmptyAsyncIterator:
323+
def __aiter__(self) -> "EmptyAsyncIterator":
324+
return self
325+
323326
async def __anext__(self) -> int:
324327
raise StopAsyncIteration
325328

@@ -334,17 +337,17 @@ async def test_get_next_no_default_behavior() -> None:
334337
"""Test get_next behavior when iterator is exhausted without default."""
335338

336339
class EmptyAsyncIterator:
340+
def __aiter__(self) -> "EmptyAsyncIterator":
341+
return self
342+
337343
async def __anext__(self) -> int:
338344
raise StopAsyncIteration
339345

340346
iterator = EmptyAsyncIterator()
341347

342-
try:
343-
result = await get_next(iterator)
344-
345-
assert isinstance(result, type(NoValue))
346-
except StopAsyncIteration:
347-
pass
348+
# Should raise StopAsyncIteration when no default is provided
349+
with pytest.raises(StopAsyncIteration):
350+
await get_next(iterator)
348351

349352

350353
def test_no_value_class() -> None:

0 commit comments

Comments
 (0)