Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 19 additions & 4 deletions monai/transforms/utility/array.py
Original file line number Diff line number Diff line change
Expand Up @@ -1049,19 +1049,34 @@ class ConvertToMultiChannelBasedOnBratsClasses(Transform):
which include TC (Tumor core), WT (Whole tumor) and ET (Enhancing tumor):
label 1 is the necrotic and non-enhancing tumor core, which should be counted under TC and WT subregion,
label 2 is the peritumoral edema, which is counted only under WT subregion,
label 4 is the GD-enhancing tumor, which should be counted under ET, TC, WT subregions.
the specified `et_label` (default 4) is the GD-enhancing tumor, which should be counted under ET, TC, WT subregions.

Args:
et_label: the label used for the GD-enhancing tumor (ET).
- Use 4 for BraTS 2018-2022.
- Use 3 for BraTS 2023.
Defaults to 4.
"""

backend = [TransformBackends.TORCH, TransformBackends.NUMPY]

def __init__(self, et_label: int = 4) -> None:
if et_label in (1, 2):
raise ValueError(f"et_label cannot be 1 or 2, as these are reserved. Got {et_label}.")
self.et_label = et_label

def __call__(self, img: NdarrayOrTensor) -> NdarrayOrTensor:
# if img has channel dim, squeeze it
if img.ndim == 4 and img.shape[0] == 1:
img = img.squeeze(0)

result = [(img == 1) | (img == 4), (img == 1) | (img == 4) | (img == 2), img == 4]
# merge labels 1 (tumor non-enh) and 4 (tumor enh) and 2 (large edema) to WT
# label 4 is ET
result = [
(img == 1) | (img == self.et_label),
(img == 1) | (img == self.et_label) | (img == 2),
img == self.et_label,
]
# merge labels 1 (tumor non-enh) and self.et_label (tumor enh) and 2 (large edema) to WT
# self.et_label is ET (4 or 3)
return torch.stack(result, dim=0) if isinstance(img, torch.Tensor) else np.stack(result, axis=0)


Expand Down
16 changes: 12 additions & 4 deletions monai/transforms/utility/dictionary.py
Original file line number Diff line number Diff line change
Expand Up @@ -1297,19 +1297,27 @@ def __call__(self, data: Mapping[Hashable, Any]):
class ConvertToMultiChannelBasedOnBratsClassesd(MapTransform):
"""
Dictionary-based wrapper of :py:class:`monai.transforms.ConvertToMultiChannelBasedOnBratsClasses`.
Convert labels to multi channels based on brats18 classes:
Convert labels to multi channels based on brats classes:
label 1 is the necrotic and non-enhancing tumor core
label 2 is the peritumoral edema
label 4 is the GD-enhancing tumor
the specified `et_label` (default 4) is the GD-enhancing tumor
The possible classes are TC (Tumor core), WT (Whole tumor)
and ET (Enhancing tumor).

Args:
keys: keys of the corresponding items to be transformed.
et_label: the label used for the GD-enhancing tumor (ET).
- Use 4 for BraTS 2018-2022.
- Use 3 for BraTS 2023.
Defaults to 4.
allow_missing_keys: don't raise exception if key is missing.
"""

backend = ConvertToMultiChannelBasedOnBratsClasses.backend

def __init__(self, keys: KeysCollection, allow_missing_keys: bool = False):
def __init__(self, keys: KeysCollection, allow_missing_keys: bool = False, et_label: int = 4):
super().__init__(keys, allow_missing_keys)
self.converter = ConvertToMultiChannelBasedOnBratsClasses()
self.converter = ConvertToMultiChannelBasedOnBratsClasses(et_label=et_label)

def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> dict[Hashable, NdarrayOrTensor]:
d = dict(data)
Expand Down
32 changes: 32 additions & 0 deletions tests/transforms/test_convert_to_multi_channel.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,9 @@
from tests.test_utils import TEST_NDARRAYS, assert_allclose

TESTS = []
TESTS_ET_LABEL_3 = []

# Tests for default et_label = 4
for p in TEST_NDARRAYS:
TESTS.extend(
[
Expand All @@ -46,6 +49,23 @@
]
)

# Tests for et_label = 3
for p in TEST_NDARRAYS:
TESTS_ET_LABEL_3.extend(
[
[
p([[0, 1, 2], [1, 2, 3], [0, 1, 3]]),
p(
[
[[0, 1, 0], [1, 0, 1], [0, 1, 1]],
[[0, 1, 1], [1, 1, 1], [0, 1, 1]],
[[0, 0, 0], [0, 0, 1], [0, 0, 1]],
]
),
]
]
)


class TestConvertToMultiChannel(unittest.TestCase):
@parameterized.expand(TESTS)
Expand All @@ -54,6 +74,18 @@ def test_type_shape(self, data, expected_result):
assert_allclose(result, expected_result)
self.assertTrue(result.dtype in (bool, torch.bool))

@parameterized.expand(TESTS_ET_LABEL_3)
def test_type_shape_et_label_3(self, data, expected_result):
result = ConvertToMultiChannelBasedOnBratsClasses(et_label=3)(data)
assert_allclose(result, expected_result)
self.assertTrue(result.dtype in (bool, torch.bool))

def test_invalid_et_label(self):
with self.assertRaises(ValueError):
ConvertToMultiChannelBasedOnBratsClasses(et_label=1)
with self.assertRaises(ValueError):
ConvertToMultiChannelBasedOnBratsClasses(et_label=2)


if __name__ == "__main__":
unittest.main()
11 changes: 11 additions & 0 deletions tests/transforms/test_convert_to_multi_channeld.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,12 @@
np.array([[[0, 1, 0], [1, 0, 1], [0, 1, 1]], [[0, 1, 1], [1, 1, 1], [0, 1, 1]], [[0, 0, 0], [0, 0, 1], [0, 0, 1]]]),
]

TEST_CASE_ET_LABEL_3 = [
{"keys": "label", "et_label": 3},
{"label": np.array([[0, 1, 2], [1, 2, 3], [0, 1, 3]])},
np.array([[[0, 1, 0], [1, 0, 1], [0, 1, 1]], [[0, 1, 1], [1, 1, 1], [0, 1, 1]], [[0, 0, 0], [0, 0, 1], [0, 0, 1]]]),
]


class TestConvertToMultiChanneld(unittest.TestCase):

Expand All @@ -32,6 +38,11 @@ def test_type_shape(self, keys, data, expected_result):
result = ConvertToMultiChannelBasedOnBratsClassesd(**keys)(data)
np.testing.assert_equal(result["label"], expected_result)

@parameterized.expand([TEST_CASE_ET_LABEL_3])
def test_et_label_3(self, keys, data, expected_result):
result = ConvertToMultiChannelBasedOnBratsClassesd(**keys)(data)
np.testing.assert_equal(result["label"], expected_result)


if __name__ == "__main__":
unittest.main()
Loading