Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion src/nifreeze/estimator.py
Original file line number Diff line number Diff line change
Expand Up @@ -125,7 +125,7 @@ def run(self, dataset: DatasetT, **kwargs) -> Self:

# Prepare iterator
iterfunc = getattr(iterators, f"{self._strategy}_iterator")
index_iter = iterfunc(len(dataset), seed=kwargs.get("seed", None))
index_iter = iterfunc(size=len(dataset), seed=kwargs.get("seed", None))

# Initialize model
if isinstance(self._model, str):
Expand Down
232 changes: 130 additions & 102 deletions src/nifreeze/utils/iterators.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,24 @@
from itertools import chain, zip_longest
from typing import Iterator

ITERATOR_SIZE_ERROR_MSG = "Cannot build iterator without size."
SIZE_KEYS = ("size", "bvals", "uptake")
"""Keys that may be used to infer the number of volumes in a dataset. When the
size of the structure to iterate over is not given explicitly, these keys
correspond to properties that distinguish one imaging modality from another, and
are part of the 4th axis (e.g. diffusion gradients in DWI or update in PET)."""

SIZE_KEYS_DOC = """
size : obj:`int`, optional
Size of the structure to iterate over.
bvals : :obj:`list`, optional
List of b-values corresponding to all orientations of a DWI dataset.
uptake : :obj:`list`, optional
List of uptake values corresponding to all volumes of the dataset.
"""

ITERATOR_SIZE_ERROR_MSG = (
f"None of {SIZE_KEYS} were provided to infer size: cannot build iterator without size."
)
"""Iterator size argument error message."""
KWARG_ERROR_MSG = "Keyword argument {kwarg} is required."
"""Iterator keyword argument error message."""
Expand All @@ -36,92 +53,65 @@
"""Uptake keyword argument name."""


def linear_iterator(size: int | None = None, **kwargs) -> Iterator[int]:
"""
Traverse the dataset volumes in ascending order.
def _get_size_from_kwargs(kwargs: dict) -> int:
"""Extract the size from kwargs, ensuring only one key is used.

Parameters
----------
size : :obj:`int` or ``None``, optional
Number of volumes in the dataset.
If ``None``, ``size`` will be inferred from the ``bvals`` keyword argument.

Other Parameters
----------------
bvals : :obj:`list`
List of b-values corresponding to all orientations of a DWI dataset.
If ``size`` is provided, this argument will be ignored.
Otherwise, ``size`` will be inferred from the length of ``bvals``.
kwargs : :obj:`dict`
The keyword arguments passed to the iterator function.

Yields
------
Returns
-------
:obj:`int`
The next index.

Examples
--------
>>> list(linear_iterator(10))
[0, 1, 2, 3, 4, 5, 6, 7, 8, 9]
The inferred size.

Raises
------
:exc:`ValueError`
If size could not be extracted.
"""
if size is None and BVALS_KWARG in kwargs:
size = len(kwargs[BVALS_KWARG])
if size is None:
raise TypeError(ITERATOR_SIZE_ERROR_MSG)
candidates = [kwargs[k] for k in SIZE_KEYS if k in kwargs]
if candidates:
return candidates[0] if isinstance(candidates[0], int) else len(candidates[0])
raise ValueError(ITERATOR_SIZE_ERROR_MSG)


def linear_iterator(**kwargs) -> Iterator[int]:
size = _get_size_from_kwargs(kwargs)
return (s for s in range(size))


def random_iterator(size: int | None = None, **kwargs) -> Iterator[int]:
"""
Traverse the dataset volumes randomly.
linear_iterator.__doc__ = f"""
Traverse the dataset volumes in ascending order.

If the ``seed`` key is present in the keyword arguments, initializes the seed
of Python's ``random`` pseudo-random number generator library with the given
value. Specifically, if ``False``, ``None`` is used as the seed; if ``True``, a
default seed value is used.
Other Parameters
----------------
{SIZE_KEYS_DOC}

Parameters
----------
size : :obj:`int` or ``None``, optional
Number of volumes in the dataset.
If ``None``, ``size`` will be inferred from the ``bvals`` keyword argument.
Notes
-----
Only one of the above keyword arguments may be provided at a time. If ``size``
is given, all other size-related keyword arguments will be ignored. If ``size``
is not provided, the function will attempt to infer the number of volumes from
the length or value of the provided keyword argument. If more than one such
keyword is provided, a :exc:`ValueError` will be raised.

Other Parameters
----------------
bvals : :obj:`list`
List of b-values corresponding to all orientations of a DWI dataset.
If ``size`` is provided, this argument will be ignored.
Otherwise, ``size`` will be inferred from the length of ``bvals``.
seed : :obj:`int`, :obj:`bool`, :obj:`str`, or ``None``
If :obj:`int` or :obj:`str` or ``None``, initializes the seed of Python's random generator
with the given value.
If ``False``, the random generator is passed ``None``.
If ``True``, a default seed value is set.
Yields
------
:obj:`int`
The next index.

Yields
------
:obj:`int`
The next index.
Examples
--------
>>> list(linear_iterator(size=10))
[0, 1, 2, 3, 4, 5, 6, 7, 8, 9]

Examples
--------
>>> list(random_iterator(15, seed=0)) # seed is 0
[1, 10, 9, 5, 11, 2, 3, 7, 8, 4, 0, 14, 12, 6, 13]
>>> # seed is True -> the default value 20210324 is set
>>> list(random_iterator(15, seed=True))
[1, 12, 14, 5, 0, 11, 10, 9, 7, 8, 3, 13, 2, 6, 4]
>>> list(random_iterator(15, seed=20210324))
[1, 12, 14, 5, 0, 11, 10, 9, 7, 8, 3, 13, 2, 6, 4]
>>> list(random_iterator(15, seed=42)) # seed is 42
[8, 13, 7, 6, 14, 12, 5, 2, 9, 3, 4, 11, 0, 1, 10]
"""

"""

if size is None and BVALS_KWARG in kwargs:
size = len(kwargs[BVALS_KWARG])
if size is None:
raise TypeError(ITERATOR_SIZE_ERROR_MSG)
def random_iterator(**kwargs) -> Iterator[int]:
size = _get_size_from_kwargs(kwargs)

_seed = kwargs.get("seed", None)
_seed = 20210324 if _seed is True else _seed
Expand All @@ -133,6 +123,51 @@ def random_iterator(size: int | None = None, **kwargs) -> Iterator[int]:
return (x for x in index_order)


random_iterator.__doc__ = f"""
Traverse the dataset volumes randomly.

If the ``seed`` key is present in the keyword arguments, initializes the seed
of Python's ``random`` pseudo-random number generator library with the given
value. Specifically, if ``False``, ``None`` is used as the seed; if ``True``, a
default seed value is used.

Other Parameters
----------------
seed : :obj:`int`, :obj:`bool`, :obj:`str`, or ``None``
If :obj:`int` or :obj:`str` or ``None``, initializes the seed of Python's random generator
with the given value. If ``False``, the random generator is passed ``None``.
If ``True``, a default seed value is set.

{SIZE_KEYS_DOC}

Notes
-----
Only one of the above keyword arguments may be provided at a time. If ``size``
is given, all other size-related keyword arguments will be ignored. If ``size``
is not provided, the function will attempt to infer the number of volumes from
the length or value of the provided keyword argument. If more than one such
keyword is provided, a :exc:`ValueError` will be raised.

Yields
------
:obj:`int`
The next index.

Examples
--------
>>> list(random_iterator(size=15, seed=0)) # seed is 0
[1, 10, 9, 5, 11, 2, 3, 7, 8, 4, 0, 14, 12, 6, 13]
>>> # seed is True -> the default value 20210324 is set
>>> list(random_iterator(size=15, seed=True))
[1, 12, 14, 5, 0, 11, 10, 9, 7, 8, 3, 13, 2, 6, 4]
>>> list(random_iterator(size=15, seed=20210324))
[1, 12, 14, 5, 0, 11, 10, 9, 7, 8, 3, 13, 2, 6, 4]
>>> list(random_iterator(size=15, seed=42)) # seed is 42
[8, 13, 7, 6, 14, 12, 5, 2, 9, 3, 4, 11, 0, 1, 10]

"""


def _value_iterator(values: list, ascending: bool, round_decimals: int = 2) -> Iterator[int]:
"""
Traverse the given values in ascending or descenting order.
Expand Down Expand Up @@ -231,40 +266,9 @@ def uptake_iterator(*_, **kwargs) -> Iterator[int]:
return _value_iterator(uptake, ascending=False, **kwargs)


def centralsym_iterator(size: int | None = None, **kwargs) -> Iterator[int]:
"""
Traverse the dataset starting from the center and alternatingly progressing to the sides.

Parameters
----------
size : :obj:`int` or ``None``, optional
Number of volumes in the dataset.
If ``None``, ``size`` will be inferred from the ``bvals`` keyword argument.

Other Parameters
----------------
bvals : :obj:`list`
List of b-values corresponding to all orientations of the dataset.
If ``size`` is provided, this argument will be ignored.
Otherwise, ``size`` will be inferred from the length of ``bvals``.

Yields
------
:obj:`~int`
The next index.

Examples
--------
>>> list(centralsym_iterator(10))
[5, 4, 6, 3, 7, 2, 8, 1, 9, 0]
>>> list(centralsym_iterator(11))
[5, 4, 6, 3, 7, 2, 8, 1, 9, 0, 10]
def centralsym_iterator(**kwargs) -> Iterator[int]:
size = _get_size_from_kwargs(kwargs)

"""
if size is None and BVALS_KWARG in kwargs:
size = len(kwargs[BVALS_KWARG])
if size is None:
raise TypeError(ITERATOR_SIZE_ERROR_MSG)
linear = list(range(size))
return (
x
Expand All @@ -276,3 +280,27 @@ def centralsym_iterator(size: int | None = None, **kwargs) -> Iterator[int]:
)
if x is not None
)


centralsym_iterator.__doc__ = f"""
Traverse the dataset starting from the center and alternatingly progressing to the sides.

Other Parameters
----------------
{SIZE_KEYS_DOC}

Notes
-----
Only one of the above keyword arguments may be provided at a time. If ``size``
is given, all other size-related keyword arguments will be ignored. If ``size``
is not provided, the function will attempt to infer the number of volumes from
the length or value of the provided keyword argument. If more than one such
keyword is provided, a :exc:`ValueError` will be raised.

Examples
--------
>>> list(centralsym_iterator(size=10))
[5, 4, 6, 3, 7, 2, 8, 1, 9, 0]
>>> list(centralsym_iterator(size=11))
[5, 4, 6, 3, 7, 2, 8, 1, 9, 0, 10]
"""
12 changes: 9 additions & 3 deletions test/test_iterators.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,8 @@
# https://www.nipreps.org/community/licensing/
#

import re

import pytest

from nifreeze.utils.iterators import (
Expand Down Expand Up @@ -85,7 +87,7 @@ def test_value_iterator(values, ascending, round_decimals, expected):


def test_linear_iterator_error():
with pytest.raises(TypeError, match=ITERATOR_SIZE_ERROR_MSG):
with pytest.raises(ValueError, match=re.escape(ITERATOR_SIZE_ERROR_MSG)):
list(linear_iterator())


Expand All @@ -94,14 +96,15 @@ def test_linear_iterator_error():
[
({"size": 4}, [0, 1, 2, 3]),
({"bvals": [0, 1000, 2000, 3000]}, [0, 1, 2, 3]),
({"uptake": [-1.02, -0.56, 0.43, 1.16]}, [0, 1, 2, 3]),
],
)
def test_linear_iterator(kwargs, expected):
assert list(linear_iterator(**kwargs)) == expected


def test_random_iterator_error():
with pytest.raises(TypeError, match=ITERATOR_SIZE_ERROR_MSG):
with pytest.raises(ValueError, match=re.escape(ITERATOR_SIZE_ERROR_MSG)):
list(random_iterator())


Expand All @@ -110,6 +113,7 @@ def test_random_iterator_error():
[
({"size": 5, "seed": 1234}, [1, 2, 4, 0, 3]),
({"bvals": [0, 1000, 2000, 3000], "seed": 42}, [2, 1, 3, 0]),
({"uptake": [-1.02, -0.56, 0.43, 1.16], "seed": True}, [3, 0, 1, 2]),
],
)
def test_random_iterator(kwargs, expected):
Expand All @@ -120,7 +124,7 @@ def test_random_iterator(kwargs, expected):


def test_centralsym_iterator_error():
with pytest.raises(TypeError, match=ITERATOR_SIZE_ERROR_MSG):
with pytest.raises(ValueError, match=re.escape(ITERATOR_SIZE_ERROR_MSG)):
list(random_iterator())


Expand All @@ -131,6 +135,8 @@ def test_centralsym_iterator_error():
({"bvals": [1000] * 6}, [3, 2, 4, 1, 5, 0]),
({"bvals": [0, 700, 1000, 2000, 3000]}, [2, 1, 3, 0, 4]),
({"bvals": [0, 1000, 700, 2000, 3000]}, [2, 1, 3, 0, 4]),
({"uptake": [0.32, 0.27, -0.12]}, [1, 0, 2]),
({"uptake": [-1.02, -0.56, 0.43, 0.89, 1.16]}, [2, 1, 3, 0, 4]),
],
)
def test_centralsym_iterator(kwargs, expected):
Expand Down