|
23 | 23 | """Unit tests exercising models.""" |
24 | 24 |
|
25 | 25 | import contextlib |
| 26 | +from typing import List |
26 | 27 |
|
27 | 28 | import numpy as np |
28 | 29 | import pytest |
|
43 | 44 |
|
44 | 45 |
|
45 | 46 | # Dummy classes to simulate model factory essential features |
46 | | -class DummyModel: |
| 47 | +class DummyDMRIModel: |
| 48 | + def __init__(self, dataset, **kwargs): |
| 49 | + self._dataset = dataset |
| 50 | + self._kwargs = kwargs |
| 51 | + |
| 52 | + |
| 53 | +class DummyPETModel: |
47 | 54 | def __init__(self, dataset, **kwargs): |
48 | 55 | self._dataset = dataset |
49 | 56 | self._kwargs = kwargs |
@@ -302,29 +309,59 @@ def __init__(self, _dataset, **kwargs): |
302 | 309 | @pytest.mark.parametrize( |
303 | 310 | "model_name, expected_cls", |
304 | 311 | [ |
305 | | - ("gqi", DummyModel), |
306 | | - ("dti", DummyModel), |
307 | | - ("DTI", DummyModel), |
308 | | - ("dki", DummyModel), |
| 312 | + ("gqi", DummyDMRIModel), |
| 313 | + ("dti", DummyDMRIModel), |
| 314 | + ("DTI", DummyDMRIModel), |
| 315 | + ("dki", DummyDMRIModel), |
| 316 | + ("pet", DummyPETModel), |
| 317 | + ("PET", DummyPETModel), |
309 | 318 | ], |
310 | 319 | ) |
311 | 320 | def test_model_factory_valid_models(monkeypatch, model_name, expected_cls): |
| 321 | + # Track which module names were requested by the factory |
| 322 | + imported_modules: List[str] = [] |
| 323 | + |
312 | 324 | # Monkeypatch import_module to return a dummy module with DTIModel, DKIModel, etc. |
313 | 325 | class DummyDMRI: |
314 | | - DTIModel = DummyModel |
315 | | - DKIModel = DummyModel |
316 | | - GQIModel = DummyModel |
| 326 | + DTIModel = DummyDMRIModel |
| 327 | + DKIModel = DummyDMRIModel |
| 328 | + GQIModel = DummyDMRIModel |
| 329 | + |
| 330 | + class DummyPET: |
| 331 | + # Use a distinct DummyPETModel so we can explicitly verify the factory |
| 332 | + # resolves to nifreeze.model.pet:PETModel (not to a dMRI model). |
| 333 | + PETModel = DummyPETModel |
317 | 334 |
|
318 | 335 | def dummy_import_module(name): |
319 | | - assert name == "nifreeze.model.dmri" |
320 | | - return DummyDMRI |
| 336 | + imported_modules.append(name) |
| 337 | + if name == "nifreeze.model.dmri": |
| 338 | + return DummyDMRI |
| 339 | + if name == "nifreeze.model.pet": |
| 340 | + return DummyPET |
| 341 | + raise ImportError(f"Unexpected import: {name}") |
321 | 342 |
|
322 | 343 | monkeypatch.setattr("importlib.import_module", dummy_import_module) |
323 | 344 | model_instance = model.ModelFactory.init(model_name, dataset=DummyDataset(), extra="value") |
324 | | - assert isinstance(model_instance, expected_cls) |
| 345 | + assert model_instance.__class__ is expected_cls |
325 | 346 | assert isinstance(model_instance._dataset, DummyDataset) |
326 | 347 | assert model_instance._kwargs.get("extra") == "value" |
327 | 348 |
|
| 349 | + # Check the imported modules |
| 350 | + if model_name.lower() == "pet": |
| 351 | + assert "nifreeze.model.pet" in imported_modules, ( |
| 352 | + "Factory should import 'nifreeze.model.pet' when model_name is 'pet'" |
| 353 | + ) |
| 354 | + assert "nifreeze.model.dmri" not in imported_modules, ( |
| 355 | + "Factory should not import 'nifreeze.model.dmri' when resolving PET models" |
| 356 | + ) |
| 357 | + else: |
| 358 | + assert "nifreeze.model.dmri" in imported_modules, ( |
| 359 | + "Factory should import 'nifreeze.model.dmri' for dMRI model names" |
| 360 | + ) |
| 361 | + assert "nifreeze.model.pet" not in imported_modules, ( |
| 362 | + "Factory should not import 'nifreeze.model.pet' when resolving dMRI models" |
| 363 | + ) |
| 364 | + |
328 | 365 |
|
329 | 366 | def test_factory_initializations(datadir): |
330 | 367 | """Check that the two different initialisations result in the same models""" |
|
0 commit comments