Skip to content

Commit 5e22e27

Browse files
authored
Merge pull request #243 from jhlegarreta/tst/test-iterators
TST: Test iterators
2 parents 6357f73 + 45e4410 commit 5e22e27

File tree

2 files changed

+188
-13
lines changed

2 files changed

+188
-13
lines changed

src/nifreeze/utils/iterators.py

Lines changed: 22 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,15 @@
2626
from itertools import chain, zip_longest
2727
from typing import Iterator
2828

29+
ITERATOR_SIZE_ERROR_MSG = "Cannot build iterator without size."
30+
"""Iterator size argument error message."""
31+
KWARG_ERROR_MSG = "Keyword argument {kwarg} is required."
32+
"""Iterator keyword argument error message."""
33+
BVALS_KWARG = "bvals"
34+
"""b-vals keyword argument name."""
35+
UPTAKE_KWARG = "uptake"
36+
"""Uptake keyword argument name."""
37+
2938

3039
def linear_iterator(size: int | None = None, **kwargs) -> Iterator[int]:
3140
"""
@@ -55,10 +64,10 @@ def linear_iterator(size: int | None = None, **kwargs) -> Iterator[int]:
5564
[0, 1, 2, 3, 4, 5, 6, 7, 8, 9]
5665
5766
"""
58-
if size is None and "bvals" in kwargs:
59-
size = len(kwargs["bvals"])
67+
if size is None and BVALS_KWARG in kwargs:
68+
size = len(kwargs[BVALS_KWARG])
6069
if size is None:
61-
raise TypeError("Cannot build iterator without size")
70+
raise TypeError(ITERATOR_SIZE_ERROR_MSG)
6271

6372
return (s for s in range(size))
6473

@@ -109,10 +118,10 @@ def random_iterator(size: int | None = None, **kwargs) -> Iterator[int]:
109118
110119
"""
111120

112-
if size is None and "bvals" in kwargs:
113-
size = len(kwargs["bvals"])
121+
if size is None and BVALS_KWARG in kwargs:
122+
size = len(kwargs[BVALS_KWARG])
114123
if size is None:
115-
raise TypeError("Cannot build iterator without size")
124+
raise TypeError(ITERATOR_SIZE_ERROR_MSG)
116125

117126
_seed = kwargs.get("seed", None)
118127
_seed = 20210324 if _seed is True else _seed
@@ -184,9 +193,9 @@ def bvalue_iterator(*_, **kwargs) -> Iterator[int]:
184193
[0, 1, 8, 4, 5, 2, 3, 6, 7]
185194
186195
"""
187-
bvals = kwargs.get("bvals", None)
196+
bvals = kwargs.get(BVALS_KWARG, None)
188197
if bvals is None:
189-
raise TypeError("Keyword argument bvals is required")
198+
raise TypeError(KWARG_ERROR_MSG.format(kwarg=BVALS_KWARG))
190199
return _value_iterator(bvals, round_decimals=2, ascending=True)
191200

192201

@@ -216,9 +225,9 @@ def uptake_iterator(*_, **kwargs) -> Iterator[int]:
216225
[3, 7, 1, 8, 2, 5, 6, 0, 4]
217226
218227
"""
219-
uptake = kwargs.get("uptake", None)
228+
uptake = kwargs.get(UPTAKE_KWARG, None)
220229
if uptake is None:
221-
raise TypeError("Keyword argument uptake is required")
230+
raise TypeError(KWARG_ERROR_MSG.format(kwarg=UPTAKE_KWARG))
222231
return _value_iterator(uptake, round_decimals=2, ascending=False)
223232

224233

@@ -252,10 +261,10 @@ def centralsym_iterator(size: int | None = None, **kwargs) -> Iterator[int]:
252261
[5, 4, 6, 3, 7, 2, 8, 1, 9, 0, 10]
253262
254263
"""
255-
if size is None and "bvals" in kwargs:
256-
size = len(kwargs["bvals"])
264+
if size is None and BVALS_KWARG in kwargs:
265+
size = len(kwargs[BVALS_KWARG])
257266
if size is None:
258-
raise TypeError("Cannot build iterator without size")
267+
raise TypeError(ITERATOR_SIZE_ERROR_MSG)
259268
linear = list(range(size))
260269
return (
261270
x

test/test_iterators.py

Lines changed: 166 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,166 @@
1+
# emacs: -*- mode: python; py-indent-offset: 4; indent-tabs-mode: nil -*-
2+
# vi: set ft=python sts=4 ts=4 sw=4 et:
3+
#
4+
# Copyright The NiPreps Developers <[email protected]>
5+
#
6+
# Licensed under the Apache License, Version 2.0 (the "License");
7+
# you may not use this file except in compliance with the License.
8+
# You may obtain a copy of the License at
9+
#
10+
# http://www.apache.org/licenses/LICENSE-2.0
11+
#
12+
# Unless required by applicable law or agreed to in writing, software
13+
# distributed under the License is distributed on an "AS IS" BASIS,
14+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15+
# See the License for the specific language governing permissions and
16+
# limitations under the License.
17+
#
18+
# We support and encourage derived works from this project, please read
19+
# about our expectations at
20+
#
21+
# https://www.nipreps.org/community/licensing/
22+
#
23+
24+
import pytest
25+
26+
from nifreeze.utils.iterators import (
27+
BVALS_KWARG,
28+
ITERATOR_SIZE_ERROR_MSG,
29+
KWARG_ERROR_MSG,
30+
UPTAKE_KWARG,
31+
_value_iterator,
32+
bvalue_iterator,
33+
centralsym_iterator,
34+
linear_iterator,
35+
random_iterator,
36+
uptake_iterator,
37+
)
38+
39+
40+
@pytest.mark.parametrize(
41+
"values, ascending, expected",
42+
[
43+
# Simple integers
44+
([1, 2, 3], True, [0, 1, 2]),
45+
([1, 2, 3], False, [2, 1, 0]),
46+
# Repeated values
47+
([2, 1, 2, 1], True, [1, 3, 0, 2]),
48+
([2, 1, 2, 1], False, [2, 0, 3, 1]), # Ties are reversed due to reverse=True
49+
# Floats
50+
([1.01, 1.02, 0.99], True, [2, 0, 1]),
51+
([1.01, 1.02, 0.99], False, [1, 0, 2]),
52+
# Floats with rounding
53+
(
54+
[1.001, 1.002, 0.999],
55+
True,
56+
[0, 1, 2],
57+
), # All round to 1.00 (round_decimals=2), so original order
58+
(
59+
[1.001, 1.002, 0.999],
60+
False,
61+
[2, 1, 0],
62+
), # All round to 1.00 (round_decimals=2), ties are reversed due to reverse=True
63+
# Negative and positive
64+
([-1.2, 0.0, 3.4, -1.2], True, [0, 3, 1, 2]),
65+
([-1.2, 0.0, 3.4, -1.2], False, [2, 1, 3, 0]), # Ties are reversed due to reverse=True
66+
],
67+
)
68+
def test_value_iterator(values, ascending, expected):
69+
result = list(_value_iterator(values, ascending=ascending))
70+
assert result == expected
71+
72+
73+
def test_linear_iterator_error():
74+
with pytest.raises(TypeError, match=ITERATOR_SIZE_ERROR_MSG):
75+
list(linear_iterator())
76+
77+
78+
@pytest.mark.parametrize(
79+
"kwargs, expected",
80+
[
81+
({"size": 4}, [0, 1, 2, 3]),
82+
({"bvals": [0, 1000, 2000, 3000]}, [0, 1, 2, 3]),
83+
],
84+
)
85+
def test_linear_iterator(kwargs, expected):
86+
assert list(linear_iterator(**kwargs)) == expected
87+
88+
89+
def test_random_iterator_error():
90+
with pytest.raises(TypeError, match=ITERATOR_SIZE_ERROR_MSG):
91+
list(random_iterator())
92+
93+
94+
@pytest.mark.parametrize(
95+
"kwargs, expected",
96+
[
97+
({"size": 5, "seed": 1234}, [1, 2, 4, 0, 3]),
98+
({"bvals": [0, 1000, 2000, 3000], "seed": 42}, [2, 1, 3, 0]),
99+
],
100+
)
101+
def test_random_iterator(kwargs, expected):
102+
obtained = list(random_iterator(**kwargs))
103+
assert obtained == expected
104+
# Determinism check
105+
assert obtained == list(random_iterator(**kwargs))
106+
107+
108+
def test_centralsym_iterator_error():
109+
with pytest.raises(TypeError, match=ITERATOR_SIZE_ERROR_MSG):
110+
list(random_iterator())
111+
112+
113+
@pytest.mark.parametrize(
114+
"kwargs, expected",
115+
[
116+
({"size": 6}, [3, 2, 4, 1, 5, 0]),
117+
({"bvals": [1000] * 6}, [3, 2, 4, 1, 5, 0]),
118+
({"bvals": [0, 700, 1000, 2000, 3000]}, [2, 1, 3, 0, 4]),
119+
({"bvals": [0, 1000, 700, 2000, 3000]}, [2, 1, 3, 0, 4]),
120+
],
121+
)
122+
def test_centralsym_iterator(kwargs, expected):
123+
# The centralsym_iterator's output order depends only on the length
124+
assert list(centralsym_iterator(**kwargs)) == expected
125+
126+
127+
def test_bvalue_iterator_error():
128+
with pytest.raises(TypeError, match=KWARG_ERROR_MSG.format(kwarg=BVALS_KWARG)):
129+
list(bvalue_iterator())
130+
131+
132+
@pytest.mark.parametrize(
133+
"bvals, expected",
134+
[
135+
([0, 700, 1200], [0, 1, 2]),
136+
([0, 0, 1000, 700], [0, 1, 3, 2]),
137+
([0, 1000, 1500, 700, 2000], [0, 3, 1, 2, 4]),
138+
],
139+
)
140+
def test_bvalue_iterator(bvals, expected):
141+
obtained = list(bvalue_iterator(bvals=bvals))
142+
assert set(obtained) == set(range(len(bvals)))
143+
# Should be ordered by increasing bvalue
144+
sorted_bvals = [bvals[i] for i in obtained]
145+
assert sorted_bvals == sorted(bvals)
146+
147+
148+
def test_uptake_iterator_error():
149+
with pytest.raises(TypeError, match=KWARG_ERROR_MSG.format(kwarg=UPTAKE_KWARG)):
150+
list(uptake_iterator())
151+
152+
153+
@pytest.mark.parametrize(
154+
"uptake, expected",
155+
[
156+
([0.3, 0.2, 0.1], [0, 1, 2]),
157+
([0.2, 0.1, 0.3], [2, 1, 0]),
158+
([-1.02, 1.16, -0.56, 0.43], [1, 3, 2, 0]),
159+
],
160+
)
161+
def test_uptake_iterator_valid(uptake, expected):
162+
obtained = list(uptake_iterator(uptake=uptake))
163+
assert set(obtained) == set(range(len(uptake)))
164+
# Should be ordered by decreasing uptake
165+
sorted_uptake = [uptake[i] for i in obtained]
166+
assert sorted_uptake == sorted(uptake, reverse=True)

0 commit comments

Comments
 (0)