Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
35 changes: 22 additions & 13 deletions src/nifreeze/utils/iterators.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,15 @@
from itertools import chain, zip_longest
from typing import Iterator

ITERATOR_SIZE_ERROR_MSG = "Cannot build iterator without size."
"""Iterator size argument error message."""
KWARG_ERROR_MSG = "Keyword argument {kwarg} is required."
"""Iterator keyword argument error message."""
BVALS_KWARG = "bvals"
"""b-vals keyword argument name."""
UPTAKE_KWARG = "uptake"
"""Uptake keyword argument name."""


def linear_iterator(size: int | None = None, **kwargs) -> Iterator[int]:
"""
Expand Down Expand Up @@ -55,10 +64,10 @@ def linear_iterator(size: int | None = None, **kwargs) -> Iterator[int]:
[0, 1, 2, 3, 4, 5, 6, 7, 8, 9]

"""
if size is None and "bvals" in kwargs:
size = len(kwargs["bvals"])
if size is None and BVALS_KWARG in kwargs:
size = len(kwargs[BVALS_KWARG])
if size is None:
raise TypeError("Cannot build iterator without size")
raise TypeError(ITERATOR_SIZE_ERROR_MSG)

return (s for s in range(size))

Expand Down Expand Up @@ -109,10 +118,10 @@ def random_iterator(size: int | None = None, **kwargs) -> Iterator[int]:

"""

if size is None and "bvals" in kwargs:
size = len(kwargs["bvals"])
if size is None and BVALS_KWARG in kwargs:
size = len(kwargs[BVALS_KWARG])
if size is None:
raise TypeError("Cannot build iterator without size")
raise TypeError(ITERATOR_SIZE_ERROR_MSG)

_seed = kwargs.get("seed", None)
_seed = 20210324 if _seed is True else _seed
Expand Down Expand Up @@ -184,9 +193,9 @@ def bvalue_iterator(*_, **kwargs) -> Iterator[int]:
[0, 1, 8, 4, 5, 2, 3, 6, 7]

"""
bvals = kwargs.get("bvals", None)
bvals = kwargs.get(BVALS_KWARG, None)
if bvals is None:
raise TypeError("Keyword argument bvals is required")
raise TypeError(KWARG_ERROR_MSG.format(kwarg=BVALS_KWARG))
return _value_iterator(bvals, round_decimals=2, ascending=True)


Expand Down Expand Up @@ -216,9 +225,9 @@ def uptake_iterator(*_, **kwargs) -> Iterator[int]:
[3, 7, 1, 8, 2, 5, 6, 0, 4]

"""
uptake = kwargs.get("uptake", None)
uptake = kwargs.get(UPTAKE_KWARG, None)
if uptake is None:
raise TypeError("Keyword argument uptake is required")
raise TypeError(KWARG_ERROR_MSG.format(kwarg=UPTAKE_KWARG))
return _value_iterator(uptake, round_decimals=2, ascending=False)


Expand Down Expand Up @@ -252,10 +261,10 @@ def centralsym_iterator(size: int | None = None, **kwargs) -> Iterator[int]:
[5, 4, 6, 3, 7, 2, 8, 1, 9, 0, 10]

"""
if size is None and "bvals" in kwargs:
size = len(kwargs["bvals"])
if size is None and BVALS_KWARG in kwargs:
size = len(kwargs[BVALS_KWARG])
if size is None:
raise TypeError("Cannot build iterator without size")
raise TypeError(ITERATOR_SIZE_ERROR_MSG)
linear = list(range(size))
return (
x
Expand Down
166 changes: 166 additions & 0 deletions test/test_iterators.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,166 @@
# emacs: -*- mode: python; py-indent-offset: 4; indent-tabs-mode: nil -*-
# vi: set ft=python sts=4 ts=4 sw=4 et:
#
# Copyright The NiPreps Developers <[email protected]>
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# We support and encourage derived works from this project, please read
# about our expectations at
#
# https://www.nipreps.org/community/licensing/
#

import pytest

from nifreeze.utils.iterators import (
BVALS_KWARG,
ITERATOR_SIZE_ERROR_MSG,
KWARG_ERROR_MSG,
UPTAKE_KWARG,
_value_iterator,
bvalue_iterator,
centralsym_iterator,
linear_iterator,
random_iterator,
uptake_iterator,
)


@pytest.mark.parametrize(
"values, ascending, expected",
[
# Simple integers
([1, 2, 3], True, [0, 1, 2]),
([1, 2, 3], False, [2, 1, 0]),
# Repeated values
([2, 1, 2, 1], True, [1, 3, 0, 2]),
([2, 1, 2, 1], False, [2, 0, 3, 1]), # Ties are reversed due to reverse=True
# Floats
([1.01, 1.02, 0.99], True, [2, 0, 1]),
([1.01, 1.02, 0.99], False, [1, 0, 2]),
# Floats with rounding
(
[1.001, 1.002, 0.999],
True,
[0, 1, 2],
), # All round to 1.00 (round_decimals=2), so original order
(
[1.001, 1.002, 0.999],
False,
[2, 1, 0],
), # All round to 1.00 (round_decimals=2), ties are reversed due to reverse=True
# Negative and positive
([-1.2, 0.0, 3.4, -1.2], True, [0, 3, 1, 2]),
([-1.2, 0.0, 3.4, -1.2], False, [2, 1, 3, 0]), # Ties are reversed due to reverse=True
],
)
def test_value_iterator(values, ascending, expected):
result = list(_value_iterator(values, ascending=ascending))
assert result == expected


def test_linear_iterator_error():
with pytest.raises(TypeError, match=ITERATOR_SIZE_ERROR_MSG):
list(linear_iterator())


@pytest.mark.parametrize(
"kwargs, expected",
[
({"size": 4}, [0, 1, 2, 3]),
({"bvals": [0, 1000, 2000, 3000]}, [0, 1, 2, 3]),
],
)
def test_linear_iterator(kwargs, expected):
assert list(linear_iterator(**kwargs)) == expected


def test_random_iterator_error():
with pytest.raises(TypeError, match=ITERATOR_SIZE_ERROR_MSG):
list(random_iterator())


@pytest.mark.parametrize(
"kwargs, expected",
[
({"size": 5, "seed": 1234}, [1, 2, 4, 0, 3]),
({"bvals": [0, 1000, 2000, 3000], "seed": 42}, [2, 1, 3, 0]),
],
)
def test_random_iterator(kwargs, expected):
obtained = list(random_iterator(**kwargs))
assert obtained == expected
# Determinism check
assert obtained == list(random_iterator(**kwargs))


def test_centralsym_iterator_error():
with pytest.raises(TypeError, match=ITERATOR_SIZE_ERROR_MSG):
list(random_iterator())


@pytest.mark.parametrize(
"kwargs, expected",
[
({"size": 6}, [3, 2, 4, 1, 5, 0]),
({"bvals": [1000] * 6}, [3, 2, 4, 1, 5, 0]),
({"bvals": [0, 700, 1000, 2000, 3000]}, [2, 1, 3, 0, 4]),
({"bvals": [0, 1000, 700, 2000, 3000]}, [2, 1, 3, 0, 4]),
],
)
def test_centralsym_iterator(kwargs, expected):
# The centralsym_iterator's output order depends only on the length
assert list(centralsym_iterator(**kwargs)) == expected


def test_bvalue_iterator_error():
with pytest.raises(TypeError, match=KWARG_ERROR_MSG.format(kwarg=BVALS_KWARG)):
list(bvalue_iterator())


@pytest.mark.parametrize(
"bvals, expected",
[
([0, 700, 1200], [0, 1, 2]),
([0, 0, 1000, 700], [0, 1, 3, 2]),
([0, 1000, 1500, 700, 2000], [0, 3, 1, 2, 4]),
],
)
def test_bvalue_iterator(bvals, expected):
obtained = list(bvalue_iterator(bvals=bvals))
assert set(obtained) == set(range(len(bvals)))
# Should be ordered by increasing bvalue
sorted_bvals = [bvals[i] for i in obtained]
assert sorted_bvals == sorted(bvals)


def test_uptake_iterator_error():
with pytest.raises(TypeError, match=KWARG_ERROR_MSG.format(kwarg=UPTAKE_KWARG)):
list(uptake_iterator())


@pytest.mark.parametrize(
"uptake, expected",
[
([0.3, 0.2, 0.1], [0, 1, 2]),
([0.2, 0.1, 0.3], [2, 1, 0]),
([-1.02, 1.16, -0.56, 0.43], [1, 3, 2, 0]),
],
)
def test_uptake_iterator_valid(uptake, expected):
obtained = list(uptake_iterator(uptake=uptake))
assert set(obtained) == set(range(len(uptake)))
# Should be ordered by decreasing uptake
sorted_uptake = [uptake[i] for i in obtained]
assert sorted_uptake == sorted(uptake, reverse=True)