|  | 
| 29 | 29 | from dipy.sims.voxel import single_tensor | 
| 30 | 30 | 
 | 
| 31 | 31 | from nifreeze import model | 
|  | 32 | +from nifreeze.data.base import BaseDataset | 
| 32 | 33 | from nifreeze.data.dmri import DEFAULT_MAX_S0, DEFAULT_MIN_S0, DWI | 
| 33 | 34 | from nifreeze.model._dipy import GaussianProcessModel | 
| 34 | 35 | from nifreeze.model.base import mask_absence_warn_msg | 
| @@ -192,7 +193,86 @@ def test_dti_model(setup_random_dwi_data): | 
| 192 | 193 |     assert predicted.shape == dwi_dataobj.shape[:-1] | 
| 193 | 194 | 
 | 
| 194 | 195 | 
 | 
| 195 |  | -def test_factory(datadir): | 
|  | 196 | +def test_factory_none_raises(setup_random_base_data): | 
|  | 197 | +    dataobj, affine, brainmask, motion_affines, datahdr = setup_random_base_data | 
|  | 198 | +    dataset = BaseDataset( | 
|  | 199 | +        dataobj=dataobj, | 
|  | 200 | +        affine=affine, | 
|  | 201 | +        brainmask=brainmask, | 
|  | 202 | +        motion_affines=motion_affines, | 
|  | 203 | +        datahdr=datahdr, | 
|  | 204 | +    ) | 
|  | 205 | +    with pytest.raises(RuntimeError, match="No model identifier provided."): | 
|  | 206 | +        model.ModelFactory.init(None, dataset=dataset) | 
|  | 207 | + | 
|  | 208 | + | 
|  | 209 | +@pytest.mark.parametrize( | 
|  | 210 | +    "name, expected_cls", | 
|  | 211 | +    [ | 
|  | 212 | +        ("avg", model.ExpectationModel), | 
|  | 213 | +        ("average", model.ExpectationModel), | 
|  | 214 | +        ("mean", model.ExpectationModel), | 
|  | 215 | +    ], | 
|  | 216 | +) | 
|  | 217 | +def test_factory_variants(name, expected_cls, setup_random_base_data): | 
|  | 218 | +    dataobj, affine, brainmask, motion_affines, datahdr = setup_random_base_data | 
|  | 219 | +    dataset = BaseDataset( | 
|  | 220 | +        dataobj=dataobj, | 
|  | 221 | +        affine=affine, | 
|  | 222 | +        brainmask=brainmask, | 
|  | 223 | +        motion_affines=motion_affines, | 
|  | 224 | +        datahdr=datahdr, | 
|  | 225 | +    ) | 
|  | 226 | +    model_instance = model.ModelFactory.init(name, dataset=dataset) | 
|  | 227 | +    assert isinstance(model_instance, expected_cls) | 
|  | 228 | + | 
|  | 229 | + | 
|  | 230 | +@pytest.mark.parametrize("name", ["avgdwi", "averagedwi", "meandwi"]) | 
|  | 231 | +def test_factory_avgdwi_variants(monkeypatch, name, setup_random_dwi_data): | 
|  | 232 | +    ( | 
|  | 233 | +        dwi_dataobj, | 
|  | 234 | +        affine, | 
|  | 235 | +        brainmask_dataobj, | 
|  | 236 | +        b0_dataobj, | 
|  | 237 | +        gradients, | 
|  | 238 | +        _, | 
|  | 239 | +    ) = setup_random_dwi_data | 
|  | 240 | + | 
|  | 241 | +    dataset = DWI( | 
|  | 242 | +        dataobj=dwi_dataobj, | 
|  | 243 | +        affine=affine, | 
|  | 244 | +        brainmask=brainmask_dataobj, | 
|  | 245 | +        bzero=b0_dataobj, | 
|  | 246 | +        gradients=gradients, | 
|  | 247 | +    ) | 
|  | 248 | + | 
|  | 249 | +    # Dummy class to simulate AverageDWIModel | 
|  | 250 | +    class DummyAvgDWI: | 
|  | 251 | +        def __init__(self, _dataset, **kwargs): | 
|  | 252 | +            self._dataset = _dataset | 
|  | 253 | +            self._kwargs = kwargs | 
|  | 254 | + | 
|  | 255 | +    # Patch import for AverageDWIModel | 
|  | 256 | +    import sys | 
|  | 257 | +    import types as _types | 
|  | 258 | + | 
|  | 259 | +    old_module = sys.modules.get("nifreeze.model.dmri") | 
|  | 260 | +    dmri_module = _types.ModuleType("nifreeze.model.dmri") | 
|  | 261 | +    dmri_module.AverageDWIModel = DummyAvgDWI | 
|  | 262 | +    sys.modules["nifreeze.model.dmri"] = dmri_module | 
|  | 263 | + | 
|  | 264 | +    try: | 
|  | 265 | +        model_instance = model.ModelFactory.init(name, dataset=dataset) | 
|  | 266 | +        assert isinstance(model_instance, DummyAvgDWI) | 
|  | 267 | +    finally: | 
|  | 268 | +        # Restore previous state | 
|  | 269 | +        if old_module is not None: | 
|  | 270 | +            sys.modules["nifreeze.model.dmri"] = old_module | 
|  | 271 | +        else: | 
|  | 272 | +            del sys.modules["nifreeze.model.dmri"] | 
|  | 273 | + | 
|  | 274 | + | 
|  | 275 | +def test_factory_initializations(datadir): | 
| 196 | 276 |     """Check that the two different initialisations result in the same models""" | 
| 197 | 277 | 
 | 
| 198 | 278 |     # Load test data | 
|  | 
0 commit comments