Skip to content

Commit ac37fed

Browse files
authored
Merge pull request #241 from jhlegarreta/ref/refactor-gtab-fixtures
REF: Allow bvals and bvecs to be separate fixtures
2 parents 6c00959 + 30c26ca commit ac37fed

File tree

3 files changed

+58
-16
lines changed

3 files changed

+58
-16
lines changed

pyproject.toml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -211,6 +211,8 @@ addopts = "-v --doctest-modules"
211211
doctest_optionflags = "ALLOW_UNICODE NORMALIZE_WHITESPACE ELLIPSIS"
212212
env = "PYTHONHASHSEED=0"
213213
markers = [
214+
"random_bval_data: Custom marker for random b-val data tests",
215+
"random_bvec_data: Custom marker for random b-vec data tests",
214216
"random_gtab_data: Custom marker for random gtab data tests",
215217
"random_dwi_data: Custom marker for random dwi data tests",
216218
"random_pet_data: Custom marker for random pet data tests",

test/conftest.py

Lines changed: 43 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -209,9 +209,7 @@ def setup_random_uniform_ndim_data(request):
209209
return _generate_random_uniform_nd_data(request, size, a, b)
210210

211211

212-
def _generate_random_choices(request, values, count):
213-
rng = request.node.rng
214-
212+
def _generate_random_choices(rng, values, count):
215213
values = set(values)
216214

217215
num_elements = len(values)
@@ -235,6 +233,46 @@ def _generate_random_choices(request, values, count):
235233
return sorted(selected_values)
236234

237235

236+
def _generate_random_bvals(rng, b0s, shells, n_gradients):
237+
# Generate a random number of elements for each shell
238+
bvals_shells = _generate_random_choices(rng, shells, n_gradients)
239+
240+
bvals = np.hstack([b0s * [0], bvals_shells])
241+
242+
return bvals
243+
244+
245+
def _generate_random_bvecs(rng, b0s, n_gradients):
246+
return np.hstack([np.zeros((3, b0s)), normalized_vector(rng.random((3, n_gradients)), axis=0)])
247+
248+
249+
@pytest.fixture(autouse=True)
250+
def setup_random_bval_data(request):
251+
"""Automatically generate random b-val data for tests."""
252+
marker = request.node.get_closest_marker("random_bval_data")
253+
254+
n_gradients = 10
255+
shells = (1000, 2000, 3000)
256+
b0s = 1
257+
if marker:
258+
n_gradients, shells, b0s = marker.args
259+
260+
rng = request.node.rng
261+
return _generate_random_bvals(rng, b0s, shells, n_gradients)
262+
263+
264+
@pytest.fixture
265+
def setup_random_bvec_data(request, bvals, bval_tolerance):
266+
"""Automatically generate random b-vec data for tests."""
267+
rng = request.node.rng
268+
269+
is_b0 = np.abs(bvals) <= bval_tolerance
270+
b0s = np.sum(is_b0)
271+
n_gradients = np.sum(~is_b0)
272+
273+
return _generate_random_bvecs(rng, b0s, n_gradients)
274+
275+
238276
@pytest.fixture(autouse=True)
239277
def setup_random_gtab_data(request):
240278
"""Automatically generate random gtab data for tests."""
@@ -248,13 +286,8 @@ def setup_random_gtab_data(request):
248286

249287
rng = request.node.rng
250288

251-
# Generate a random number of elements for each shell
252-
bvals_shells = _generate_random_choices(request, shells, n_gradients)
253-
254-
bvals = np.hstack([b0s * [0], bvals_shells])
255-
bvecs = np.hstack(
256-
[np.zeros((3, b0s)), normalized_vector(rng.random((3, n_gradients)), axis=0)]
257-
)
289+
bvals = _generate_random_bvals(rng, b0s, shells, n_gradients)
290+
bvecs = _generate_random_bvecs(rng, b0s, n_gradients)
258291

259292
return bvals, bvecs
260293

test/test_filtering.py

Lines changed: 13 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -93,16 +93,23 @@ def test_advanced_clip(
9393
)
9494

9595

96-
@pytest.mark.random_gtab_data(5, (1000, 2000, 3000), 1)
9796
@pytest.mark.parametrize(
98-
"index, expect_exception, expected_output",
97+
"bvals, index, expect_exception, expected_output, bval_tolerance",
9998
[
100-
(3, False, np.asarray([False, True, True, False, False, False])),
101-
(0, True, np.asarray([])),
99+
(
100+
np.asarray([0, 1000, 1000, 1000, 2000, 3000]),
101+
3,
102+
False,
103+
np.asarray([False, True, True, False, False, False]),
104+
50,
105+
),
106+
(np.asarray([0, 1000, 1000, 1000, 2000, 3000]), 0, True, np.asarray([]), 50),
102107
],
103108
)
104-
def test_dwi_select_shells(setup_random_gtab_data, index, expect_exception, expected_output):
105-
bvals, bvecs = setup_random_gtab_data
109+
def test_dwi_select_shells(
110+
setup_random_bvec_data, bvals, index, expect_exception, expected_output, bval_tolerance
111+
):
112+
bvecs = setup_random_bvec_data
106113

107114
gradients = np.vstack([bvecs, bvals[np.newaxis, :]], dtype="float32")
108115

0 commit comments

Comments
 (0)