diff --git a/src/nifreeze/utils/iterators.py b/src/nifreeze/utils/iterators.py index e3fdb696..5429324a 100644 --- a/src/nifreeze/utils/iterators.py +++ b/src/nifreeze/utils/iterators.py @@ -193,10 +193,10 @@ def bvalue_iterator(*_, **kwargs) -> Iterator[int]: [0, 1, 8, 4, 5, 2, 3, 6, 7] """ - bvals = kwargs.get(BVALS_KWARG, None) + bvals = kwargs.pop(BVALS_KWARG, None) if bvals is None: raise TypeError(KWARG_ERROR_MSG.format(kwarg=BVALS_KWARG)) - return _value_iterator(bvals, round_decimals=2, ascending=True) + return _value_iterator(bvals, ascending=True, **kwargs) def uptake_iterator(*_, **kwargs) -> Iterator[int]: @@ -225,10 +225,10 @@ def uptake_iterator(*_, **kwargs) -> Iterator[int]: [3, 7, 1, 8, 2, 5, 6, 0, 4] """ - uptake = kwargs.get(UPTAKE_KWARG, None) + uptake = kwargs.pop(UPTAKE_KWARG, None) if uptake is None: raise TypeError(KWARG_ERROR_MSG.format(kwarg=UPTAKE_KWARG)) - return _value_iterator(uptake, round_decimals=2, ascending=False) + return _value_iterator(uptake, ascending=False, **kwargs) def centralsym_iterator(size: int | None = None, **kwargs) -> Iterator[int]: diff --git a/test/test_iterators.py b/test/test_iterators.py index b774c17b..be2b63a3 100644 --- a/test/test_iterators.py +++ b/test/test_iterators.py @@ -38,35 +38,49 @@ @pytest.mark.parametrize( - "values, ascending, expected", + "values, ascending, round_decimals, expected", [ # Simple integers - ([1, 2, 3], True, [0, 1, 2]), - ([1, 2, 3], False, [2, 1, 0]), + ([1, 2, 3], True, 2, [0, 1, 2]), + ([1, 2, 3], False, 2, [2, 1, 0]), # Repeated values - ([2, 1, 2, 1], True, [1, 3, 0, 2]), - ([2, 1, 2, 1], False, [2, 0, 3, 1]), # Ties are reversed due to reverse=True + ([2, 1, 2, 1], True, 2, [1, 3, 0, 2]), + ([2, 1, 2, 1], False, 2, [2, 0, 3, 1]), # Ties are reversed due to reverse=True # Floats - ([1.01, 1.02, 0.99], True, [2, 0, 1]), - ([1.01, 1.02, 0.99], False, [1, 0, 2]), + ([1.01, 1.02, 0.99], True, 2, [2, 0, 1]), + ([1.01, 1.02, 0.99], False, 2, [1, 0, 2]), # Floats with rounding ( [1.001, 1.002, 0.999], True, + 2, [0, 1, 2], ), # All round to 1.00 (round_decimals=2), so original order + ( + [1.001, 1.002, 0.999], + True, + 4, + [2, 0, 1], + ), ( [1.001, 1.002, 0.999], False, + 2, [2, 1, 0], ), # All round to 1.00 (round_decimals=2), ties are reversed due to reverse=True + ( + [1.001, 1.002, 0.999], + False, + 4, + [1, 0, 2], + ), # Negative and positive - ([-1.2, 0.0, 3.4, -1.2], True, [0, 3, 1, 2]), - ([-1.2, 0.0, 3.4, -1.2], False, [2, 1, 3, 0]), # Ties are reversed due to reverse=True + ([-1.2, 0.0, 3.4, -1.2], True, 2, [0, 3, 1, 2]), + ([-1.2, 0.0, 3.4, -1.2], False, 2, [2, 1, 3, 0]), # Ties are reversed due to reverse=True ], ) -def test_value_iterator(values, ascending, expected): - result = list(_value_iterator(values, ascending=ascending)) +def test_value_iterator(values, ascending, round_decimals, expected): + result = list(_value_iterator(values, ascending=ascending, round_decimals=round_decimals)) assert result == expected