Skip to content

Commit dc72c78

Browse files
authored
Merge pull request #250 from jhlegarreta/fix/remove-pet-thru-factory
FIX: Allow instantiating the PET model through the factory
2 parents 7ffe43e + 2706a76 commit dc72c78

File tree

2 files changed

+52
-13
lines changed

2 files changed

+52
-13
lines changed

src/nifreeze/model/base.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -71,8 +71,10 @@ def init(model: str | None = None, **kwargs):
7171
if model.lower() in ("gqi", "dti", "dki", "pet"):
7272
from importlib import import_module
7373

74-
dmrimod = import_module("nifreeze.model.dmri")
75-
Model = getattr(dmrimod, f"{model.upper()}Model")
74+
thismod = import_module(
75+
f"nifreeze.model.{'pet' if model.lower() == 'pet' else 'dmri'}"
76+
)
77+
Model = getattr(thismod, f"{model.upper()}Model")
7678
return Model(kwargs.pop("dataset"), **kwargs)
7779

7880
raise NotImplementedError(UNSUPPORTED_MODEL_ERROR_MSG.format(model=model))

test/test_model.py

Lines changed: 48 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323
"""Unit tests exercising models."""
2424

2525
import contextlib
26+
from typing import List
2627

2728
import numpy as np
2829
import pytest
@@ -43,7 +44,13 @@
4344

4445

4546
# Dummy classes to simulate model factory essential features
46-
class DummyModel:
47+
class DummyDMRIModel:
48+
def __init__(self, dataset, **kwargs):
49+
self._dataset = dataset
50+
self._kwargs = kwargs
51+
52+
53+
class DummyPETModel:
4754
def __init__(self, dataset, **kwargs):
4855
self._dataset = dataset
4956
self._kwargs = kwargs
@@ -302,29 +309,59 @@ def __init__(self, _dataset, **kwargs):
302309
@pytest.mark.parametrize(
303310
"model_name, expected_cls",
304311
[
305-
("gqi", DummyModel),
306-
("dti", DummyModel),
307-
("DTI", DummyModel),
308-
("dki", DummyModel),
312+
("gqi", DummyDMRIModel),
313+
("dti", DummyDMRIModel),
314+
("DTI", DummyDMRIModel),
315+
("dki", DummyDMRIModel),
316+
("pet", DummyPETModel),
317+
("PET", DummyPETModel),
309318
],
310319
)
311320
def test_model_factory_valid_models(monkeypatch, model_name, expected_cls):
321+
# Track which module names were requested by the factory
322+
imported_modules: List[str] = []
323+
312324
# Monkeypatch import_module to return a dummy module with DTIModel, DKIModel, etc.
313325
class DummyDMRI:
314-
DTIModel = DummyModel
315-
DKIModel = DummyModel
316-
GQIModel = DummyModel
326+
DTIModel = DummyDMRIModel
327+
DKIModel = DummyDMRIModel
328+
GQIModel = DummyDMRIModel
329+
330+
class DummyPET:
331+
# Use a distinct DummyPETModel so we can explicitly verify the factory
332+
# resolves to nifreeze.model.pet:PETModel (not to a dMRI model).
333+
PETModel = DummyPETModel
317334

318335
def dummy_import_module(name):
319-
assert name == "nifreeze.model.dmri"
320-
return DummyDMRI
336+
imported_modules.append(name)
337+
if name == "nifreeze.model.dmri":
338+
return DummyDMRI
339+
if name == "nifreeze.model.pet":
340+
return DummyPET
341+
raise ImportError(f"Unexpected import: {name}")
321342

322343
monkeypatch.setattr("importlib.import_module", dummy_import_module)
323344
model_instance = model.ModelFactory.init(model_name, dataset=DummyDataset(), extra="value")
324-
assert isinstance(model_instance, expected_cls)
345+
assert model_instance.__class__ is expected_cls
325346
assert isinstance(model_instance._dataset, DummyDataset)
326347
assert model_instance._kwargs.get("extra") == "value"
327348

349+
# Check the imported modules
350+
if model_name.lower() == "pet":
351+
assert "nifreeze.model.pet" in imported_modules, (
352+
"Factory should import 'nifreeze.model.pet' when model_name is 'pet'"
353+
)
354+
assert "nifreeze.model.dmri" not in imported_modules, (
355+
"Factory should not import 'nifreeze.model.dmri' when resolving PET models"
356+
)
357+
else:
358+
assert "nifreeze.model.dmri" in imported_modules, (
359+
"Factory should import 'nifreeze.model.dmri' for dMRI model names"
360+
)
361+
assert "nifreeze.model.pet" not in imported_modules, (
362+
"Factory should not import 'nifreeze.model.pet' when resolving dMRI models"
363+
)
364+
328365

329366
def test_factory_initializations(datadir):
330367
"""Check that the two different initialisations result in the same models"""

0 commit comments

Comments
 (0)