Skip to content

Commit 4a96c3c

Browse files
authored
Merge pull request #237 from jhlegarreta/enh/testing-model-factory
TST: Improve `ModelFactory` coverage
2 parents ac37fed + e1d529d commit 4a96c3c

File tree

3 files changed

+110
-1
lines changed

3 files changed

+110
-1
lines changed

pyproject.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -211,6 +211,7 @@ addopts = "-v --doctest-modules"
211211
doctest_optionflags = "ALLOW_UNICODE NORMALIZE_WHITESPACE ELLIPSIS"
212212
env = "PYTHONHASHSEED=0"
213213
markers = [
214+
"random_base_data: Custom marker for random base data tests",
214215
"random_bval_data: Custom marker for random b-val data tests",
215216
"random_bvec_data: Custom marker for random b-vec data tests",
216217
"random_gtab_data: Custom marker for random gtab data tests",

test/conftest.py

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -292,6 +292,34 @@ def setup_random_gtab_data(request):
292292
return bvals, bvecs
293293

294294

295+
@pytest.fixture(autouse=True)
296+
def setup_random_base_data(request):
297+
"""Automatically generate random BaseDataset data for tests."""
298+
marker = request.node.get_closest_marker("random_base_data")
299+
300+
vol_size = (4, 4, 4)
301+
volumes = 5
302+
if marker:
303+
vol_size, volumes = marker.args
304+
305+
rng = request.node.rng
306+
307+
base_dataobj, affine = _generate_random_uniform_spatial_data(
308+
request, (*vol_size, volumes), 0.0, 1.0
309+
)
310+
brainmask_dataobj = rng.choice([True, False], size=vol_size).astype(np.uint8)
311+
motion_affines = rng.random((volumes, 4, 4))
312+
datahdr = None
313+
314+
return (
315+
base_dataobj,
316+
affine,
317+
brainmask_dataobj,
318+
motion_affines,
319+
datahdr,
320+
)
321+
322+
295323
@pytest.fixture(autouse=True)
296324
def setup_random_dwi_data(request, setup_random_gtab_data):
297325
"""Automatically generate random DWI data for tests."""

test/test_model.py

Lines changed: 81 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@
2929
from dipy.sims.voxel import single_tensor
3030

3131
from nifreeze import model
32+
from nifreeze.data.base import BaseDataset
3233
from nifreeze.data.dmri import DEFAULT_MAX_S0, DEFAULT_MIN_S0, DWI
3334
from nifreeze.model._dipy import GaussianProcessModel
3435
from nifreeze.model.base import mask_absence_warn_msg
@@ -192,7 +193,86 @@ def test_dti_model(setup_random_dwi_data):
192193
assert predicted.shape == dwi_dataobj.shape[:-1]
193194

194195

195-
def test_factory(datadir):
196+
def test_factory_none_raises(setup_random_base_data):
197+
dataobj, affine, brainmask, motion_affines, datahdr = setup_random_base_data
198+
dataset = BaseDataset(
199+
dataobj=dataobj,
200+
affine=affine,
201+
brainmask=brainmask,
202+
motion_affines=motion_affines,
203+
datahdr=datahdr,
204+
)
205+
with pytest.raises(RuntimeError, match="No model identifier provided."):
206+
model.ModelFactory.init(None, dataset=dataset)
207+
208+
209+
@pytest.mark.parametrize(
210+
"name, expected_cls",
211+
[
212+
("avg", model.ExpectationModel),
213+
("average", model.ExpectationModel),
214+
("mean", model.ExpectationModel),
215+
],
216+
)
217+
def test_factory_variants(name, expected_cls, setup_random_base_data):
218+
dataobj, affine, brainmask, motion_affines, datahdr = setup_random_base_data
219+
dataset = BaseDataset(
220+
dataobj=dataobj,
221+
affine=affine,
222+
brainmask=brainmask,
223+
motion_affines=motion_affines,
224+
datahdr=datahdr,
225+
)
226+
model_instance = model.ModelFactory.init(name, dataset=dataset)
227+
assert isinstance(model_instance, expected_cls)
228+
229+
230+
@pytest.mark.parametrize("name", ["avgdwi", "averagedwi", "meandwi"])
231+
def test_factory_avgdwi_variants(monkeypatch, name, setup_random_dwi_data):
232+
(
233+
dwi_dataobj,
234+
affine,
235+
brainmask_dataobj,
236+
b0_dataobj,
237+
gradients,
238+
_,
239+
) = setup_random_dwi_data
240+
241+
dataset = DWI(
242+
dataobj=dwi_dataobj,
243+
affine=affine,
244+
brainmask=brainmask_dataobj,
245+
bzero=b0_dataobj,
246+
gradients=gradients,
247+
)
248+
249+
# Dummy class to simulate AverageDWIModel
250+
class DummyAvgDWI:
251+
def __init__(self, _dataset, **kwargs):
252+
self._dataset = _dataset
253+
self._kwargs = kwargs
254+
255+
# Patch import for AverageDWIModel
256+
import sys
257+
import types as _types
258+
259+
old_module = sys.modules.get("nifreeze.model.dmri")
260+
dmri_module = _types.ModuleType("nifreeze.model.dmri")
261+
dmri_module.AverageDWIModel = DummyAvgDWI
262+
sys.modules["nifreeze.model.dmri"] = dmri_module
263+
264+
try:
265+
model_instance = model.ModelFactory.init(name, dataset=dataset)
266+
assert isinstance(model_instance, DummyAvgDWI)
267+
finally:
268+
# Restore previous state
269+
if old_module is not None:
270+
sys.modules["nifreeze.model.dmri"] = old_module
271+
else:
272+
del sys.modules["nifreeze.model.dmri"]
273+
274+
275+
def test_factory_initializations(datadir):
196276
"""Check that the two different initialisations result in the same models"""
197277

198278
# Load test data

0 commit comments

Comments
 (0)