Skip to content

Commit 205f527

Browse files
committed
TST: Test iterators
Test iterators. Even if docstring tests cover the core fetures, exceptions remain untested. This patch set adds unit tests that do test the exceptions and the different use cases with the keyword arguments that iterators accept. Take advantage of the commit to define the keyword arguments and exception messages as global variables so that the exceptions raised by the iterators get checked.
1 parent 431ce51 commit 205f527

File tree

2 files changed

+188
-13
lines changed

2 files changed

+188
-13
lines changed

src/nifreeze/utils/iterators.py

Lines changed: 22 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,15 @@
2626
from itertools import chain, zip_longest
2727
from typing import Iterator
2828

29+
ITERATOR_SIZE_ERROR_MSG = "Cannot build iterator without size."
30+
"""Iterator size argument error message."""
31+
KWARG_ERROR_MSG = "Keyword argument {kwarg} is required."
32+
"""Iterator keyword argument error message."""
33+
BVALS_KWARG = "bvals"
34+
"""b-vals keyword argument name."""
35+
UPTAKE_KWARG = "uptake"
36+
"""Uptake keyword argument name."""
37+
2938

3039
def linear_iterator(size: int | None = None, **kwargs) -> Iterator[int]:
3140
"""
@@ -55,10 +64,10 @@ def linear_iterator(size: int | None = None, **kwargs) -> Iterator[int]:
5564
[0, 1, 2, 3, 4, 5, 6, 7, 8, 9]
5665
5766
"""
58-
if size is None and "bvals" in kwargs:
59-
size = len(kwargs["bvals"])
67+
if size is None and BVALS_KWARG in kwargs:
68+
size = len(kwargs[BVALS_KWARG])
6069
if size is None:
61-
raise TypeError("Cannot build iterator without size")
70+
raise TypeError(ITERATOR_SIZE_ERROR_MSG)
6271

6372
return (s for s in range(size))
6473

@@ -109,10 +118,10 @@ def random_iterator(size: int | None = None, **kwargs) -> Iterator[int]:
109118
110119
"""
111120

112-
if size is None and "bvals" in kwargs:
113-
size = len(kwargs["bvals"])
121+
if size is None and BVALS_KWARG in kwargs:
122+
size = len(kwargs[BVALS_KWARG])
114123
if size is None:
115-
raise TypeError("Cannot build iterator without size")
124+
raise TypeError(ITERATOR_SIZE_ERROR_MSG)
116125

117126
_seed = kwargs.get("seed", None)
118127
_seed = 20210324 if _seed is True else _seed
@@ -184,9 +193,9 @@ def bvalue_iterator(*_, **kwargs) -> Iterator[int]:
184193
[0, 1, 8, 4, 5, 2, 3, 6, 7]
185194
186195
"""
187-
bvals = kwargs.get("bvals", None)
196+
bvals = kwargs.get(BVALS_KWARG, None)
188197
if bvals is None:
189-
raise TypeError("Keyword argument bvals is required")
198+
raise TypeError(KWARG_ERROR_MSG.format(kwarg=BVALS_KWARG))
190199
return _value_iterator(bvals, round_decimals=2, ascending=True)
191200

192201

@@ -216,9 +225,9 @@ def uptake_iterator(*_, **kwargs) -> Iterator[int]:
216225
[3, 7, 1, 8, 2, 5, 6, 0, 4]
217226
218227
"""
219-
uptake = kwargs.get("uptake", None)
228+
uptake = kwargs.get(UPTAKE_KWARG, None)
220229
if uptake is None:
221-
raise TypeError("Keyword argument uptake is required")
230+
raise TypeError(KWARG_ERROR_MSG.format(kwarg=UPTAKE_KWARG))
222231
return _value_iterator(uptake, round_decimals=2, ascending=False)
223232

224233

@@ -252,10 +261,10 @@ def centralsym_iterator(size: int | None = None, **kwargs) -> Iterator[int]:
252261
[5, 4, 6, 3, 7, 2, 8, 1, 9, 0, 10]
253262
254263
"""
255-
if size is None and "bvals" in kwargs:
256-
size = len(kwargs["bvals"])
264+
if size is None and BVALS_KWARG in kwargs:
265+
size = len(kwargs[BVALS_KWARG])
257266
if size is None:
258-
raise TypeError("Cannot build iterator without size")
267+
raise TypeError(ITERATOR_SIZE_ERROR_MSG)
259268
linear = list(range(size))
260269
return (
261270
x

test/test_iterators.py

Lines changed: 166 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,166 @@
1+
# emacs: -*- mode: python; py-indent-offset: 4; indent-tabs-mode: nil -*-
2+
# vi: set ft=python sts=4 ts=4 sw=4 et:
3+
#
4+
# Copyright The NiPreps Developers <[email protected]>
5+
#
6+
# Licensed under the Apache License, Version 2.0 (the "License");
7+
# you may not use this file except in compliance with the License.
8+
# You may obtain a copy of the License at
9+
#
10+
# http://www.apache.org/licenses/LICENSE-2.0
11+
#
12+
# Unless required by applicable law or agreed to in writing, software
13+
# distributed under the License is distributed on an "AS IS" BASIS,
14+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15+
# See the License for the specific language governing permissions and
16+
# limitations under the License.
17+
#
18+
# We support and encourage derived works from this project, please read
19+
# about our expectations at
20+
#
21+
# https://www.nipreps.org/community/licensing/
22+
#
23+
24+
import pytest
25+
26+
from nifreeze.utils.iterators import (
27+
BVALS_KWARG,
28+
ITERATOR_SIZE_ERROR_MSG,
29+
KWARG_ERROR_MSG,
30+
UPTAKE_KWARG,
31+
_value_iterator,
32+
bvalue_iterator,
33+
centralsym_iterator,
34+
linear_iterator,
35+
random_iterator,
36+
uptake_iterator,
37+
)
38+
39+
40+
@pytest.mark.parametrize(
41+
"values, ascending, expected",
42+
[
43+
# Simple integers
44+
([1, 2, 3], True, [0, 1, 2]),
45+
([1, 2, 3], False, [2, 1, 0]),
46+
# Repeated values
47+
([2, 1, 2, 1], True, [1, 3, 0, 2]),
48+
([2, 1, 2, 1], False, [2, 0, 3, 1]), # Ties are reversed due to reverse=True
49+
# Floats
50+
([1.01, 1.02, 0.99], True, [2, 0, 1]),
51+
([1.01, 1.02, 0.99], False, [1, 0, 2]),
52+
# Floats with rounding
53+
(
54+
[1.001, 1.002, 0.999],
55+
True,
56+
[0, 1, 2],
57+
), # All round to 1.00 (round_decimals=2), so original order
58+
(
59+
[1.001, 1.002, 0.999],
60+
False,
61+
[2, 1, 0],
62+
), # All round to 1.00 (round_decimals=2), ties are reversed due to reverse=True
63+
# Negative and positive
64+
([-1.2, 0.0, 3.4, -1.2], True, [0, 3, 1, 2]),
65+
([-1.2, 0.0, 3.4, -1.2], False, [2, 1, 3, 0]), # Ties are reversed due to reverse=True
66+
],
67+
)
68+
def test_value_iterator(values, ascending, expected):
69+
result = list(_value_iterator(values, ascending=ascending))
70+
assert result == expected
71+
72+
73+
def test_linear_iterator_error():
74+
with pytest.raises(TypeError, match=ITERATOR_SIZE_ERROR_MSG):
75+
list(linear_iterator())
76+
77+
78+
@pytest.mark.parametrize(
79+
"kwargs, expected",
80+
[
81+
({"size": 4}, [0, 1, 2, 3]),
82+
({"bvals": [0, 1000, 2000, 3000]}, [0, 1, 2, 3]),
83+
],
84+
)
85+
def test_linear_iterator(kwargs, expected):
86+
assert list(linear_iterator(**kwargs)) == expected
87+
88+
89+
def test_random_iterator_error():
90+
with pytest.raises(TypeError, match=ITERATOR_SIZE_ERROR_MSG):
91+
list(random_iterator())
92+
93+
94+
@pytest.mark.parametrize(
95+
"kwargs, expected",
96+
[
97+
({"size": 5, "seed": 1234}, [1, 2, 4, 0, 3]),
98+
({"bvals": [0, 1000, 2000, 3000], "seed": 42}, [2, 1, 3, 0]),
99+
],
100+
)
101+
def test_random_iterator(kwargs, expected):
102+
obtained = list(random_iterator(**kwargs))
103+
assert obtained == expected
104+
# Determinism check
105+
assert obtained == list(random_iterator(**kwargs))
106+
107+
108+
def test_centralsym_iterator_error():
109+
with pytest.raises(TypeError, match=ITERATOR_SIZE_ERROR_MSG):
110+
list(random_iterator())
111+
112+
113+
@pytest.mark.parametrize(
114+
"kwargs, expected",
115+
[
116+
({"size": 6}, [3, 2, 4, 1, 5, 0]),
117+
({"bvals": [1000] * 6}, [3, 2, 4, 1, 5, 0]),
118+
({"bvals": [0, 700, 1000, 2000, 3000]}, [2, 1, 3, 0, 4]),
119+
({"bvals": [0, 1000, 700, 2000, 3000]}, [2, 1, 3, 0, 4]),
120+
],
121+
)
122+
def test_centralsym_iterator(kwargs, expected):
123+
# The centralsym_iterator's output order depends only on the length
124+
assert list(centralsym_iterator(**kwargs)) == expected
125+
126+
127+
def test_bvalue_iterator_error():
128+
with pytest.raises(TypeError, match=KWARG_ERROR_MSG.format(kwarg=BVALS_KWARG)):
129+
list(bvalue_iterator())
130+
131+
132+
@pytest.mark.parametrize(
133+
"bvals, expected",
134+
[
135+
([0, 700, 1200], [0, 1, 2]),
136+
([0, 0, 1000, 700], [0, 1, 3, 2]),
137+
([0, 1000, 1500, 700, 2000], [0, 3, 1, 2, 4]),
138+
],
139+
)
140+
def test_bvalue_iterator(bvals, expected):
141+
obtained = list(bvalue_iterator(bvals=bvals))
142+
assert set(obtained) == set(range(len(bvals)))
143+
# Should be ordered by increasing bvalue
144+
sorted_bvals = [bvals[i] for i in obtained]
145+
assert sorted_bvals == sorted(bvals)
146+
147+
148+
def test_uptake_iterator_error():
149+
with pytest.raises(TypeError, match=KWARG_ERROR_MSG.format(kwarg=UPTAKE_KWARG)):
150+
list(uptake_iterator())
151+
152+
153+
@pytest.mark.parametrize(
154+
"uptake, expected",
155+
[
156+
([0.3, 0.2, 0.1], [0, 1, 2]),
157+
([0.2, 0.1, 0.3], [2, 1, 0]),
158+
([-1.02, 1.16, -0.56, 0.43], [1, 3, 2, 0]),
159+
],
160+
)
161+
def test_uptake_iterator_valid(uptake, expected):
162+
obtained = list(uptake_iterator(uptake=uptake))
163+
assert set(obtained) == set(range(len(uptake)))
164+
# Should be ordered by decreasing uptake
165+
sorted_uptake = [uptake[i] for i in obtained]
166+
assert sorted_uptake == sorted(uptake, reverse=True)

0 commit comments

Comments
 (0)