Skip to content

Commit a1310e6

Browse files
jhlegarretaoesteban
andcommitted
ENH: Generalize iterators to accept bvals or uptake as kwargs
Generalize iterators to accept `bvals` or `uptake` as keyword arguments. The previous implementation of the linear, random and centralsym iterators was only accepting `bvals`. This patch set allows these iterators to work with PET data through the `uptake` argument. Transition to keyword argument-only style. Adapt the doctests accordingly. Document the functions by explicitly assigning the docstring to the `__doc__` property of each function so that the `SIZE_KEYS_DOC` can be reused and to allow the examples be run by the doctring tests. Co-authored-by: Oscar Esteban <[email protected]>
1 parent aa0fcc8 commit a1310e6

File tree

3 files changed

+140
-106
lines changed

3 files changed

+140
-106
lines changed

src/nifreeze/estimator.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -125,7 +125,7 @@ def run(self, dataset: DatasetT, **kwargs) -> Self:
125125

126126
# Prepare iterator
127127
iterfunc = getattr(iterators, f"{self._strategy}_iterator")
128-
index_iter = iterfunc(len(dataset), seed=kwargs.get("seed", None))
128+
index_iter = iterfunc(size=len(dataset), seed=kwargs.get("seed", None))
129129

130130
# Initialize model
131131
if isinstance(self._model, str):

src/nifreeze/utils/iterators.py

Lines changed: 130 additions & 102 deletions
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,24 @@
2626
from itertools import chain, zip_longest
2727
from typing import Iterator
2828

29-
ITERATOR_SIZE_ERROR_MSG = "Cannot build iterator without size."
29+
SIZE_KEYS = ("size", "bvals", "uptake")
30+
"""Keys that may be used to infer the number of volumes in a dataset. When the
31+
size of the structure to iterate over is not given explicitly, these keys
32+
correspond to properties that distinguish one imaging modality from another, and
33+
are part of the 4th axis (e.g. diffusion gradients in DWI or update in PET)."""
34+
35+
SIZE_KEYS_DOC = """
36+
size : obj:`int`, optional
37+
Size of the structure to iterate over.
38+
bvals : :obj:`list`, optional
39+
List of b-values corresponding to all orientations of a DWI dataset.
40+
uptake : :obj:`list`, optional
41+
List of uptake values corresponding to all volumes of the dataset.
42+
"""
43+
44+
ITERATOR_SIZE_ERROR_MSG = (
45+
f"None of {SIZE_KEYS} were provided to infer size: cannot build iterator without size."
46+
)
3047
"""Iterator size argument error message."""
3148
KWARG_ERROR_MSG = "Keyword argument {kwarg} is required."
3249
"""Iterator keyword argument error message."""
@@ -36,92 +53,65 @@
3653
"""Uptake keyword argument name."""
3754

3855

39-
def linear_iterator(size: int | None = None, **kwargs) -> Iterator[int]:
40-
"""
41-
Traverse the dataset volumes in ascending order.
56+
def _get_size_from_kwargs(kwargs: dict) -> int:
57+
"""Extract the size from kwargs, ensuring only one key is used.
4258
4359
Parameters
4460
----------
45-
size : :obj:`int` or ``None``, optional
46-
Number of volumes in the dataset.
47-
If ``None``, ``size`` will be inferred from the ``bvals`` keyword argument.
48-
49-
Other Parameters
50-
----------------
51-
bvals : :obj:`list`
52-
List of b-values corresponding to all orientations of a DWI dataset.
53-
If ``size`` is provided, this argument will be ignored.
54-
Otherwise, ``size`` will be inferred from the length of ``bvals``.
61+
kwargs : :obj:`dict`
62+
The keyword arguments passed to the iterator function.
5563
56-
Yields
57-
------
64+
Returns
65+
-------
5866
:obj:`int`
59-
The next index.
60-
61-
Examples
62-
--------
63-
>>> list(linear_iterator(10))
64-
[0, 1, 2, 3, 4, 5, 6, 7, 8, 9]
67+
The inferred size.
6568
69+
Raises
70+
------
71+
:exc:`ValueError`
72+
If size could not be extracted.
6673
"""
67-
if size is None and BVALS_KWARG in kwargs:
68-
size = len(kwargs[BVALS_KWARG])
69-
if size is None:
70-
raise TypeError(ITERATOR_SIZE_ERROR_MSG)
74+
candidates = [kwargs[k] for k in SIZE_KEYS if k in kwargs]
75+
if candidates:
76+
return candidates[0] if isinstance(candidates[0], int) else len(candidates[0])
77+
raise ValueError(ITERATOR_SIZE_ERROR_MSG)
78+
7179

80+
def linear_iterator(**kwargs) -> Iterator[int]:
81+
size = _get_size_from_kwargs(kwargs)
7282
return (s for s in range(size))
7383

7484

75-
def random_iterator(size: int | None = None, **kwargs) -> Iterator[int]:
76-
"""
77-
Traverse the dataset volumes randomly.
85+
linear_iterator.__doc__ = f"""
86+
Traverse the dataset volumes in ascending order.
7887
79-
If the ``seed`` key is present in the keyword arguments, initializes the seed
80-
of Python's ``random`` pseudo-random number generator library with the given
81-
value. Specifically, if ``False``, ``None`` is used as the seed; if ``True``, a
82-
default seed value is used.
88+
Other Parameters
89+
----------------
90+
{SIZE_KEYS_DOC}
8391
84-
Parameters
85-
----------
86-
size : :obj:`int` or ``None``, optional
87-
Number of volumes in the dataset.
88-
If ``None``, ``size`` will be inferred from the ``bvals`` keyword argument.
92+
Notes
93+
-----
94+
Only one of the above keyword arguments may be provided at a time. If ``size``
95+
is given, all other size-related keyword arguments will be ignored. If ``size``
96+
is not provided, the function will attempt to infer the number of volumes from
97+
the length or value of the provided keyword argument. If more than one such
98+
keyword is provided, a :exc:`ValueError` will be raised.
8999
90-
Other Parameters
91-
----------------
92-
bvals : :obj:`list`
93-
List of b-values corresponding to all orientations of a DWI dataset.
94-
If ``size`` is provided, this argument will be ignored.
95-
Otherwise, ``size`` will be inferred from the length of ``bvals``.
96-
seed : :obj:`int`, :obj:`bool`, :obj:`str`, or ``None``
97-
If :obj:`int` or :obj:`str` or ``None``, initializes the seed of Python's random generator
98-
with the given value.
99-
If ``False``, the random generator is passed ``None``.
100-
If ``True``, a default seed value is set.
100+
Yields
101+
------
102+
:obj:`int`
103+
The next index.
101104
102-
Yields
103-
------
104-
:obj:`int`
105-
The next index.
105+
Examples
106+
--------
107+
>>> list(linear_iterator(size=10))
108+
[0, 1, 2, 3, 4, 5, 6, 7, 8, 9]
106109
107-
Examples
108-
--------
109-
>>> list(random_iterator(15, seed=0)) # seed is 0
110-
[1, 10, 9, 5, 11, 2, 3, 7, 8, 4, 0, 14, 12, 6, 13]
111-
>>> # seed is True -> the default value 20210324 is set
112-
>>> list(random_iterator(15, seed=True))
113-
[1, 12, 14, 5, 0, 11, 10, 9, 7, 8, 3, 13, 2, 6, 4]
114-
>>> list(random_iterator(15, seed=20210324))
115-
[1, 12, 14, 5, 0, 11, 10, 9, 7, 8, 3, 13, 2, 6, 4]
116-
>>> list(random_iterator(15, seed=42)) # seed is 42
117-
[8, 13, 7, 6, 14, 12, 5, 2, 9, 3, 4, 11, 0, 1, 10]
110+
"""
118111

119-
"""
120112

121-
if size is None and BVALS_KWARG in kwargs:
122-
size = len(kwargs[BVALS_KWARG])
123-
if size is None:
124-
raise TypeError(ITERATOR_SIZE_ERROR_MSG)
113+
def random_iterator(**kwargs) -> Iterator[int]:
114+
size = _get_size_from_kwargs(kwargs)
125115

126116
_seed = kwargs.get("seed", None)
127117
_seed = 20210324 if _seed is True else _seed
@@ -133,6 +123,51 @@ def random_iterator(size: int | None = None, **kwargs) -> Iterator[int]:
133123
return (x for x in index_order)
134124

135125

126+
random_iterator.__doc__ = f"""
127+
Traverse the dataset volumes randomly.
128+
129+
If the ``seed`` key is present in the keyword arguments, initializes the seed
130+
of Python's ``random`` pseudo-random number generator library with the given
131+
value. Specifically, if ``False``, ``None`` is used as the seed; if ``True``, a
132+
default seed value is used.
133+
134+
Other Parameters
135+
----------------
136+
seed : :obj:`int`, :obj:`bool`, :obj:`str`, or ``None``
137+
If :obj:`int` or :obj:`str` or ``None``, initializes the seed of Python's random generator
138+
with the given value. If ``False``, the random generator is passed ``None``.
139+
If ``True``, a default seed value is set.
140+
141+
{SIZE_KEYS_DOC}
142+
143+
Notes
144+
-----
145+
Only one of the above keyword arguments may be provided at a time. If ``size``
146+
is given, all other size-related keyword arguments will be ignored. If ``size``
147+
is not provided, the function will attempt to infer the number of volumes from
148+
the length or value of the provided keyword argument. If more than one such
149+
keyword is provided, a :exc:`ValueError` will be raised.
150+
151+
Yields
152+
------
153+
:obj:`int`
154+
The next index.
155+
156+
Examples
157+
--------
158+
>>> list(random_iterator(size=15, seed=0)) # seed is 0
159+
[1, 10, 9, 5, 11, 2, 3, 7, 8, 4, 0, 14, 12, 6, 13]
160+
>>> # seed is True -> the default value 20210324 is set
161+
>>> list(random_iterator(size=15, seed=True))
162+
[1, 12, 14, 5, 0, 11, 10, 9, 7, 8, 3, 13, 2, 6, 4]
163+
>>> list(random_iterator(size=15, seed=20210324))
164+
[1, 12, 14, 5, 0, 11, 10, 9, 7, 8, 3, 13, 2, 6, 4]
165+
>>> list(random_iterator(size=15, seed=42)) # seed is 42
166+
[8, 13, 7, 6, 14, 12, 5, 2, 9, 3, 4, 11, 0, 1, 10]
167+
168+
"""
169+
170+
136171
def _value_iterator(values: list, ascending: bool, round_decimals: int = 2) -> Iterator[int]:
137172
"""
138173
Traverse the given values in ascending or descenting order.
@@ -231,40 +266,9 @@ def uptake_iterator(*_, **kwargs) -> Iterator[int]:
231266
return _value_iterator(uptake, ascending=False, **kwargs)
232267

233268

234-
def centralsym_iterator(size: int | None = None, **kwargs) -> Iterator[int]:
235-
"""
236-
Traverse the dataset starting from the center and alternatingly progressing to the sides.
237-
238-
Parameters
239-
----------
240-
size : :obj:`int` or ``None``, optional
241-
Number of volumes in the dataset.
242-
If ``None``, ``size`` will be inferred from the ``bvals`` keyword argument.
243-
244-
Other Parameters
245-
----------------
246-
bvals : :obj:`list`
247-
List of b-values corresponding to all orientations of the dataset.
248-
If ``size`` is provided, this argument will be ignored.
249-
Otherwise, ``size`` will be inferred from the length of ``bvals``.
250-
251-
Yields
252-
------
253-
:obj:`~int`
254-
The next index.
255-
256-
Examples
257-
--------
258-
>>> list(centralsym_iterator(10))
259-
[5, 4, 6, 3, 7, 2, 8, 1, 9, 0]
260-
>>> list(centralsym_iterator(11))
261-
[5, 4, 6, 3, 7, 2, 8, 1, 9, 0, 10]
269+
def centralsym_iterator(**kwargs) -> Iterator[int]:
270+
size = _get_size_from_kwargs(kwargs)
262271

263-
"""
264-
if size is None and BVALS_KWARG in kwargs:
265-
size = len(kwargs[BVALS_KWARG])
266-
if size is None:
267-
raise TypeError(ITERATOR_SIZE_ERROR_MSG)
268272
linear = list(range(size))
269273
return (
270274
x
@@ -276,3 +280,27 @@ def centralsym_iterator(size: int | None = None, **kwargs) -> Iterator[int]:
276280
)
277281
if x is not None
278282
)
283+
284+
285+
centralsym_iterator.__doc__ = f"""
286+
Traverse the dataset starting from the center and alternatingly progressing to the sides.
287+
288+
Other Parameters
289+
----------------
290+
{SIZE_KEYS_DOC}
291+
292+
Notes
293+
-----
294+
Only one of the above keyword arguments may be provided at a time. If ``size``
295+
is given, all other size-related keyword arguments will be ignored. If ``size``
296+
is not provided, the function will attempt to infer the number of volumes from
297+
the length or value of the provided keyword argument. If more than one such
298+
keyword is provided, a :exc:`ValueError` will be raised.
299+
300+
Examples
301+
--------
302+
>>> list(centralsym_iterator(size=10))
303+
[5, 4, 6, 3, 7, 2, 8, 1, 9, 0]
304+
>>> list(centralsym_iterator(size=11))
305+
[5, 4, 6, 3, 7, 2, 8, 1, 9, 0, 10]
306+
"""

test/test_iterators.py

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,8 @@
2121
# https://www.nipreps.org/community/licensing/
2222
#
2323

24+
import re
25+
2426
import pytest
2527

2628
from nifreeze.utils.iterators import (
@@ -85,7 +87,7 @@ def test_value_iterator(values, ascending, round_decimals, expected):
8587

8688

8789
def test_linear_iterator_error():
88-
with pytest.raises(TypeError, match=ITERATOR_SIZE_ERROR_MSG):
90+
with pytest.raises(ValueError, match=re.escape(ITERATOR_SIZE_ERROR_MSG)):
8991
list(linear_iterator())
9092

9193

@@ -94,14 +96,15 @@ def test_linear_iterator_error():
9496
[
9597
({"size": 4}, [0, 1, 2, 3]),
9698
({"bvals": [0, 1000, 2000, 3000]}, [0, 1, 2, 3]),
99+
({"uptake": [-1.02, -0.56, 0.43, 1.16]}, [0, 1, 2, 3]),
97100
],
98101
)
99102
def test_linear_iterator(kwargs, expected):
100103
assert list(linear_iterator(**kwargs)) == expected
101104

102105

103106
def test_random_iterator_error():
104-
with pytest.raises(TypeError, match=ITERATOR_SIZE_ERROR_MSG):
107+
with pytest.raises(ValueError, match=re.escape(ITERATOR_SIZE_ERROR_MSG)):
105108
list(random_iterator())
106109

107110

@@ -110,6 +113,7 @@ def test_random_iterator_error():
110113
[
111114
({"size": 5, "seed": 1234}, [1, 2, 4, 0, 3]),
112115
({"bvals": [0, 1000, 2000, 3000], "seed": 42}, [2, 1, 3, 0]),
116+
({"uptake": [-1.02, -0.56, 0.43, 1.16], "seed": True}, [3, 0, 1, 2]),
113117
],
114118
)
115119
def test_random_iterator(kwargs, expected):
@@ -120,7 +124,7 @@ def test_random_iterator(kwargs, expected):
120124

121125

122126
def test_centralsym_iterator_error():
123-
with pytest.raises(TypeError, match=ITERATOR_SIZE_ERROR_MSG):
127+
with pytest.raises(ValueError, match=re.escape(ITERATOR_SIZE_ERROR_MSG)):
124128
list(random_iterator())
125129

126130

@@ -131,6 +135,8 @@ def test_centralsym_iterator_error():
131135
({"bvals": [1000] * 6}, [3, 2, 4, 1, 5, 0]),
132136
({"bvals": [0, 700, 1000, 2000, 3000]}, [2, 1, 3, 0, 4]),
133137
({"bvals": [0, 1000, 700, 2000, 3000]}, [2, 1, 3, 0, 4]),
138+
({"uptake": [0.32, 0.27, -0.12]}, [1, 0, 2]),
139+
({"uptake": [-1.02, -0.56, 0.43, 0.89, 1.16]}, [2, 1, 3, 0, 4]),
134140
],
135141
)
136142
def test_centralsym_iterator(kwargs, expected):

0 commit comments

Comments
 (0)