|  | 
|  | 1 | +# emacs: -*- mode: python; py-indent-offset: 4; indent-tabs-mode: nil -*- | 
|  | 2 | +# vi: set ft=python sts=4 ts=4 sw=4 et: | 
|  | 3 | +# | 
|  | 4 | +# Copyright The NiPreps Developers <[email protected]> | 
|  | 5 | +# | 
|  | 6 | +# Licensed under the Apache License, Version 2.0 (the "License"); | 
|  | 7 | +# you may not use this file except in compliance with the License. | 
|  | 8 | +# You may obtain a copy of the License at | 
|  | 9 | +# | 
|  | 10 | +#     http://www.apache.org/licenses/LICENSE-2.0 | 
|  | 11 | +# | 
|  | 12 | +# Unless required by applicable law or agreed to in writing, software | 
|  | 13 | +# distributed under the License is distributed on an "AS IS" BASIS, | 
|  | 14 | +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | 
|  | 15 | +# See the License for the specific language governing permissions and | 
|  | 16 | +# limitations under the License. | 
|  | 17 | +# | 
|  | 18 | +# We support and encourage derived works from this project, please read | 
|  | 19 | +# about our expectations at | 
|  | 20 | +# | 
|  | 21 | +#     https://www.nipreps.org/community/licensing/ | 
|  | 22 | +# | 
|  | 23 | + | 
|  | 24 | +import pytest | 
|  | 25 | + | 
|  | 26 | +from nifreeze.utils.iterators import ( | 
|  | 27 | +    BVALS_KWARG, | 
|  | 28 | +    ITERATOR_SIZE_ERROR_MSG, | 
|  | 29 | +    KWARG_ERROR_MSG, | 
|  | 30 | +    UPTAKE_KWARG, | 
|  | 31 | +    _value_iterator, | 
|  | 32 | +    bvalue_iterator, | 
|  | 33 | +    centralsym_iterator, | 
|  | 34 | +    linear_iterator, | 
|  | 35 | +    random_iterator, | 
|  | 36 | +    uptake_iterator, | 
|  | 37 | +) | 
|  | 38 | + | 
|  | 39 | + | 
|  | 40 | +@pytest.mark.parametrize( | 
|  | 41 | +    "values, ascending, expected", | 
|  | 42 | +    [ | 
|  | 43 | +        # Simple integers | 
|  | 44 | +        ([1, 2, 3], True, [0, 1, 2]), | 
|  | 45 | +        ([1, 2, 3], False, [2, 1, 0]), | 
|  | 46 | +        # Repeated values | 
|  | 47 | +        ([2, 1, 2, 1], True, [1, 3, 0, 2]), | 
|  | 48 | +        ([2, 1, 2, 1], False, [2, 0, 3, 1]),  # Ties are reversed due to reverse=True | 
|  | 49 | +        # Floats | 
|  | 50 | +        ([1.01, 1.02, 0.99], True, [2, 0, 1]), | 
|  | 51 | +        ([1.01, 1.02, 0.99], False, [1, 0, 2]), | 
|  | 52 | +        # Floats with rounding | 
|  | 53 | +        ( | 
|  | 54 | +            [1.001, 1.002, 0.999], | 
|  | 55 | +            True, | 
|  | 56 | +            [0, 1, 2], | 
|  | 57 | +        ),  # All round to 1.00 (round_decimals=2), so original order | 
|  | 58 | +        ( | 
|  | 59 | +            [1.001, 1.002, 0.999], | 
|  | 60 | +            False, | 
|  | 61 | +            [2, 1, 0], | 
|  | 62 | +        ),  # All round to 1.00 (round_decimals=2), ties are reversed due to reverse=True | 
|  | 63 | +        # Negative and positive | 
|  | 64 | +        ([-1.2, 0.0, 3.4, -1.2], True, [0, 3, 1, 2]), | 
|  | 65 | +        ([-1.2, 0.0, 3.4, -1.2], False, [2, 1, 3, 0]),  # Ties are reversed due to reverse=True | 
|  | 66 | +    ], | 
|  | 67 | +) | 
|  | 68 | +def test_value_iterator(values, ascending, expected): | 
|  | 69 | +    result = list(_value_iterator(values, ascending=ascending)) | 
|  | 70 | +    assert result == expected | 
|  | 71 | + | 
|  | 72 | + | 
|  | 73 | +def test_linear_iterator_error(): | 
|  | 74 | +    with pytest.raises(TypeError, match=ITERATOR_SIZE_ERROR_MSG): | 
|  | 75 | +        list(linear_iterator()) | 
|  | 76 | + | 
|  | 77 | + | 
|  | 78 | +@pytest.mark.parametrize( | 
|  | 79 | +    "kwargs, expected", | 
|  | 80 | +    [ | 
|  | 81 | +        ({"size": 4}, [0, 1, 2, 3]), | 
|  | 82 | +        ({"bvals": [0, 1000, 2000, 3000]}, [0, 1, 2, 3]), | 
|  | 83 | +    ], | 
|  | 84 | +) | 
|  | 85 | +def test_linear_iterator(kwargs, expected): | 
|  | 86 | +    assert list(linear_iterator(**kwargs)) == expected | 
|  | 87 | + | 
|  | 88 | + | 
|  | 89 | +def test_random_iterator_error(): | 
|  | 90 | +    with pytest.raises(TypeError, match=ITERATOR_SIZE_ERROR_MSG): | 
|  | 91 | +        list(random_iterator()) | 
|  | 92 | + | 
|  | 93 | + | 
|  | 94 | +@pytest.mark.parametrize( | 
|  | 95 | +    "kwargs, expected", | 
|  | 96 | +    [ | 
|  | 97 | +        ({"size": 5, "seed": 1234}, [1, 2, 4, 0, 3]), | 
|  | 98 | +        ({"bvals": [0, 1000, 2000, 3000], "seed": 42}, [2, 1, 3, 0]), | 
|  | 99 | +    ], | 
|  | 100 | +) | 
|  | 101 | +def test_random_iterator(kwargs, expected): | 
|  | 102 | +    obtained = list(random_iterator(**kwargs)) | 
|  | 103 | +    assert obtained == expected | 
|  | 104 | +    # Determinism check | 
|  | 105 | +    assert obtained == list(random_iterator(**kwargs)) | 
|  | 106 | + | 
|  | 107 | + | 
|  | 108 | +def test_centralsym_iterator_error(): | 
|  | 109 | +    with pytest.raises(TypeError, match=ITERATOR_SIZE_ERROR_MSG): | 
|  | 110 | +        list(random_iterator()) | 
|  | 111 | + | 
|  | 112 | + | 
|  | 113 | +@pytest.mark.parametrize( | 
|  | 114 | +    "kwargs, expected", | 
|  | 115 | +    [ | 
|  | 116 | +        ({"size": 6}, [3, 2, 4, 1, 5, 0]), | 
|  | 117 | +        ({"bvals": [1000] * 6}, [3, 2, 4, 1, 5, 0]), | 
|  | 118 | +        ({"bvals": [0, 700, 1000, 2000, 3000]}, [2, 1, 3, 0, 4]), | 
|  | 119 | +        ({"bvals": [0, 1000, 700, 2000, 3000]}, [2, 1, 3, 0, 4]), | 
|  | 120 | +    ], | 
|  | 121 | +) | 
|  | 122 | +def test_centralsym_iterator(kwargs, expected): | 
|  | 123 | +    # The centralsym_iterator's output order depends only on the length | 
|  | 124 | +    assert list(centralsym_iterator(**kwargs)) == expected | 
|  | 125 | + | 
|  | 126 | + | 
|  | 127 | +def test_bvalue_iterator_error(): | 
|  | 128 | +    with pytest.raises(TypeError, match=KWARG_ERROR_MSG.format(kwarg=BVALS_KWARG)): | 
|  | 129 | +        list(bvalue_iterator()) | 
|  | 130 | + | 
|  | 131 | + | 
|  | 132 | +@pytest.mark.parametrize( | 
|  | 133 | +    "bvals, expected", | 
|  | 134 | +    [ | 
|  | 135 | +        ([0, 700, 1200], [0, 1, 2]), | 
|  | 136 | +        ([0, 0, 1000, 700], [0, 1, 3, 2]), | 
|  | 137 | +        ([0, 1000, 1500, 700, 2000], [0, 3, 1, 2, 4]), | 
|  | 138 | +    ], | 
|  | 139 | +) | 
|  | 140 | +def test_bvalue_iterator(bvals, expected): | 
|  | 141 | +    obtained = list(bvalue_iterator(bvals=bvals)) | 
|  | 142 | +    assert set(obtained) == set(range(len(bvals))) | 
|  | 143 | +    # Should be ordered by increasing bvalue | 
|  | 144 | +    sorted_bvals = [bvals[i] for i in obtained] | 
|  | 145 | +    assert sorted_bvals == sorted(bvals) | 
|  | 146 | + | 
|  | 147 | + | 
|  | 148 | +def test_uptake_iterator_error(): | 
|  | 149 | +    with pytest.raises(TypeError, match=KWARG_ERROR_MSG.format(kwarg=UPTAKE_KWARG)): | 
|  | 150 | +        list(uptake_iterator()) | 
|  | 151 | + | 
|  | 152 | + | 
|  | 153 | +@pytest.mark.parametrize( | 
|  | 154 | +    "uptake, expected", | 
|  | 155 | +    [ | 
|  | 156 | +        ([0.3, 0.2, 0.1], [0, 1, 2]), | 
|  | 157 | +        ([0.2, 0.1, 0.3], [2, 1, 0]), | 
|  | 158 | +        ([-1.02, 1.16, -0.56, 0.43], [1, 3, 2, 0]), | 
|  | 159 | +    ], | 
|  | 160 | +) | 
|  | 161 | +def test_uptake_iterator_valid(uptake, expected): | 
|  | 162 | +    obtained = list(uptake_iterator(uptake=uptake)) | 
|  | 163 | +    assert set(obtained) == set(range(len(uptake))) | 
|  | 164 | +    # Should be ordered by decreasing uptake | 
|  | 165 | +    sorted_uptake = [uptake[i] for i in obtained] | 
|  | 166 | +    assert sorted_uptake == sorted(uptake, reverse=True) | 
0 commit comments