Skip to content

Commit 52b3ed2

Browse files
authored
Added (Rand)ScaleScaleIntensityFixedMean(d) and modified (Rand)AdjustContrast(d) (#6542)
### Description This PR adds the intensity transform **ScaleIntensityFixedMean** (including its random and random dictionary version) and modifies the intensity transform **AdjustContrast** and its dictionary version. It adds functionality available in the corresponding nnU-Net transforms [**ContrastAugmentationTransform**](https://github.com/MIC-DKFZ/batchgenerators/blob/7fb802b28b045b21346b197735d64f12fbb070aa/batchgenerators/transforms/color_transforms.py#L25) and [**GammaTransform**](https://github.com/MIC-DKFZ/batchgenerators/blob/7fb802b28b045b21346b197735d64f12fbb070aa/batchgenerators/transforms/color_transforms.py#L132). Specifically, **ScaleIntensityFixedMean** scales the intensity of the input image by a factor _v = v * (1 + factor)_ (same as the existing **ScaleIntensity** transform when used with the factor argument). The added functionality is provided by two arguments: 1. _fixed_mean_: subtract the mean intensity before scaling with _factor_, then add the same value after scaling to ensure that the output has the same mean intensity as the input. 2. _preserve_range_: clips the output array/tensor to the range of the input array/tensor AdjustContrast was modified by adding two arguments: 1. _invert_image_: multiplies all intensity values by -1 before gamma transform and again after gamma transform 2. _retain_stats_: applies a scaling factor and an offset to all intensity values after the gamma transform to ensure that the output intensity distribution has the same mean and standard deviation as the intensity distribution of the input ### Types of changes <!--- Put an `x` in all the boxes that apply, and remove the not applicable items --> - [x] Non-breaking change (fix or new feature that would not break existing functionality). - [ ] Breaking change (fix or new feature that would cause existing functionality to change). - [x] New tests added to cover the changes. - [ ] Integration tests passed locally by running `./runtests.sh -f -u --net --coverage`. - [x] Quick tests passed locally by running `./runtests.sh --quick --unittests --disttests`. - [x] In-line docstrings updated. - [x] Documentation updated, tested `make html` command in the `docs/` folder. --------- Signed-off-by: Aaron Kujawa <[email protected]>
1 parent c33f1ba commit 52b3ed2

9 files changed

+593
-37
lines changed

docs/source/transforms.rst

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -317,6 +317,18 @@ Intensity
317317
:members:
318318
:special-members: __call__
319319

320+
`ScaleIntensityFixedMean`
321+
"""""""""""""""""""""""""
322+
.. autoclass:: ScaleIntensityFixedMean
323+
:members:
324+
:special-members: __call__
325+
326+
`RandScaleIntensityFixedMean`
327+
"""""""""""""""""""""""""""""
328+
.. autoclass:: RandScaleIntensityFixedMean
329+
:members:
330+
:special-members: __call__
331+
320332
`NormalizeIntensity`
321333
""""""""""""""""""""
322334
.. image:: https://github.com/Project-MONAI/DocImages/raw/main/transforms/NormalizeIntensity.png
@@ -1386,6 +1398,12 @@ Intensity (Dict)
13861398
:members:
13871399
:special-members: __call__
13881400

1401+
`RandScaleIntensityFixedMeand`
1402+
"""""""""""""""""""""""""""""""
1403+
.. autoclass:: RandScaleIntensityFixedMeand
1404+
:members:
1405+
:special-members: __call__
1406+
13891407
`NormalizeIntensityd`
13901408
"""""""""""""""""""""
13911409
.. image:: https://github.com/Project-MONAI/DocImages/raw/main/transforms/NormalizeIntensityd.png

monai/transforms/__init__.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -118,10 +118,12 @@
118118
RandKSpaceSpikeNoise,
119119
RandRicianNoise,
120120
RandScaleIntensity,
121+
RandScaleIntensityFixedMean,
121122
RandShiftIntensity,
122123
RandStdShiftIntensity,
123124
SavitzkyGolaySmooth,
124125
ScaleIntensity,
126+
ScaleIntensityFixedMean,
125127
ScaleIntensityRange,
126128
ScaleIntensityRangePercentiles,
127129
ShiftIntensity,
@@ -198,6 +200,9 @@
198200
RandScaleIntensityd,
199201
RandScaleIntensityD,
200202
RandScaleIntensityDict,
203+
RandScaleIntensityFixedMeand,
204+
RandScaleIntensityFixedMeanD,
205+
RandScaleIntensityFixedMeanDict,
201206
RandShiftIntensityd,
202207
RandShiftIntensityD,
203208
RandShiftIntensityDict,

monai/transforms/intensity/array.py

Lines changed: 223 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,8 @@
4848
"RandBiasField",
4949
"ScaleIntensity",
5050
"RandScaleIntensity",
51+
"ScaleIntensityFixedMean",
52+
"RandScaleIntensityFixedMean",
5153
"NormalizeIntensity",
5254
"ThresholdIntensity",
5355
"ScaleIntensityRange",
@@ -466,6 +468,161 @@ def __call__(self, img: NdarrayOrTensor) -> NdarrayOrTensor:
466468
return ret
467469

468470

471+
class ScaleIntensityFixedMean(Transform):
472+
"""
473+
Scale the intensity of input image by ``v = v * (1 + factor)``, then shift the output so that the output image has the
474+
same mean as the input.
475+
"""
476+
477+
backend = [TransformBackends.TORCH, TransformBackends.NUMPY]
478+
479+
def __init__(
480+
self,
481+
factor: float = 0,
482+
preserve_range: bool = False,
483+
fixed_mean: bool = True,
484+
channel_wise: bool = False,
485+
dtype: DtypeLike = np.float32,
486+
) -> None:
487+
"""
488+
Args:
489+
factor: factor scale by ``v = v * (1 + factor)``.
490+
preserve_range: clips the output array/tensor to the range of the input array/tensor
491+
fixed_mean: subtract the mean intensity before scaling with `factor`, then add the same value after scaling
492+
to ensure that the output has the same mean as the input.
493+
channel_wise: if True, scale on each channel separately. `preserve_range` and `fixed_mean` are also applied
494+
on each channel separately if `channel_wise` is True. Please ensure that the first dimension represents the
495+
channel of the image if True.
496+
dtype: output data type, if None, same as input image. defaults to float32.
497+
"""
498+
self.factor = factor
499+
self.preserve_range = preserve_range
500+
self.fixed_mean = fixed_mean
501+
self.channel_wise = channel_wise
502+
self.dtype = dtype
503+
504+
def __call__(self, img: NdarrayOrTensor, factor=None) -> NdarrayOrTensor:
505+
"""
506+
Apply the transform to `img`.
507+
Args:
508+
img: the input tensor/array
509+
factor: factor scale by ``v = v * (1 + factor)``
510+
511+
"""
512+
513+
factor = factor if factor is not None else self.factor
514+
515+
img = convert_to_tensor(img, track_meta=get_track_meta())
516+
img_t = convert_to_tensor(img, track_meta=False)
517+
ret: NdarrayOrTensor
518+
if self.channel_wise:
519+
out = []
520+
for d in img_t:
521+
if self.preserve_range:
522+
clip_min = d.min()
523+
clip_max = d.max()
524+
525+
if self.fixed_mean:
526+
mn = d.mean()
527+
d = d - mn
528+
529+
out_channel = d * (1 + factor)
530+
531+
if self.fixed_mean:
532+
out_channel = out_channel + mn
533+
534+
if self.preserve_range:
535+
out_channel = clip(out_channel, clip_min, clip_max)
536+
537+
out.append(out_channel)
538+
ret = torch.stack(out) # type: ignore
539+
else:
540+
if self.preserve_range:
541+
clip_min = img_t.min()
542+
clip_max = img_t.max()
543+
544+
if self.fixed_mean:
545+
mn = img_t.mean()
546+
img_t = img_t - mn
547+
548+
ret = img_t * (1 + factor)
549+
550+
if self.fixed_mean:
551+
ret = ret + mn
552+
553+
if self.preserve_range:
554+
ret = clip(ret, clip_min, clip_max)
555+
556+
ret = convert_to_dst_type(ret, dst=img, dtype=self.dtype or img_t.dtype)[0]
557+
return ret
558+
559+
560+
class RandScaleIntensityFixedMean(RandomizableTransform):
561+
"""
562+
Randomly scale the intensity of input image by ``v = v * (1 + factor)`` where the `factor`
563+
is randomly picked. Subtract the mean intensity before scaling with `factor`, then add the same value after scaling
564+
to ensure that the output has the same mean as the input.
565+
"""
566+
567+
backend = ScaleIntensityFixedMean.backend
568+
569+
def __init__(
570+
self,
571+
prob: float = 0.1,
572+
factors: Sequence[float] | float = 0,
573+
fixed_mean: bool = True,
574+
preserve_range: bool = False,
575+
dtype: DtypeLike = np.float32,
576+
) -> None:
577+
"""
578+
Args:
579+
factors: factor range to randomly scale by ``v = v * (1 + factor)``.
580+
if single number, factor value is picked from (-factors, factors).
581+
preserve_range: clips the output array/tensor to the range of the input array/tensor
582+
fixed_mean: subtract the mean intensity before scaling with `factor`, then add the same value after scaling
583+
to ensure that the output has the same mean as the input.
584+
channel_wise: if True, scale on each channel separately. `preserve_range` and `fixed_mean` are also applied
585+
on each channel separately if `channel_wise` is True. Please ensure that the first dimension represents the
586+
channel of the image if True.
587+
dtype: output data type, if None, same as input image. defaults to float32.
588+
589+
"""
590+
RandomizableTransform.__init__(self, prob)
591+
if isinstance(factors, (int, float)):
592+
self.factors = (min(-factors, factors), max(-factors, factors))
593+
elif len(factors) != 2:
594+
raise ValueError("factors should be a number or pair of numbers.")
595+
else:
596+
self.factors = (min(factors), max(factors))
597+
self.factor = self.factors[0]
598+
self.fixed_mean = fixed_mean
599+
self.preserve_range = preserve_range
600+
self.dtype = dtype
601+
602+
self.scaler = ScaleIntensityFixedMean(
603+
factor=self.factor, fixed_mean=self.fixed_mean, preserve_range=self.preserve_range, dtype=self.dtype
604+
)
605+
606+
def randomize(self, data: Any | None = None) -> None:
607+
super().randomize(None)
608+
if not self._do_transform:
609+
return None
610+
self.factor = self.R.uniform(low=self.factors[0], high=self.factors[1])
611+
612+
def __call__(self, img: NdarrayOrTensor, randomize: bool = True) -> NdarrayOrTensor:
613+
"""
614+
Apply the transform to `img`.
615+
"""
616+
img = convert_to_tensor(img, track_meta=get_track_meta())
617+
if randomize:
618+
self.randomize()
619+
620+
if not self._do_transform:
621+
return convert_data_type(img, dtype=self.dtype)[0]
622+
623+
return self.scaler(img, self.factor)
624+
625+
469626
class RandScaleIntensity(RandomizableTransform):
470627
"""
471628
Randomly scale the intensity of input image by ``v = v * (1 + factor)`` where the `factor`
@@ -799,48 +956,99 @@ def __call__(self, img: NdarrayOrTensor) -> NdarrayOrTensor:
799956

800957
class AdjustContrast(Transform):
801958
"""
802-
Changes image intensity by gamma. Each pixel/voxel intensity is updated as::
959+
Changes image intensity with gamma transform. Each pixel/voxel intensity is updated as::
803960
804961
x = ((x - min) / intensity_range) ^ gamma * intensity_range + min
805962
806963
Args:
807964
gamma: gamma value to adjust the contrast as function.
965+
invert_image: whether to invert the image before applying gamma augmentation. If True, multiply all intensity
966+
values with -1 before the gamma transform and again after the gamma transform. This behaviour is mimicked
967+
from `nnU-Net <https://www.nature.com/articles/s41592-020-01008-z>`_, specifically `this
968+
<https://github.com/MIC-DKFZ/batchgenerators/blob/7fb802b28b045b21346b197735d64f12fbb070aa/batchgenerators/augmentations/color_augmentations.py#L107>`_
969+
function.
970+
retain_stats: if True, applies a scaling factor and an offset to all intensity values after gamma transform to
971+
ensure that the output intensity distribution has the same mean and standard deviation as the intensity
972+
distribution of the input. This behaviour is mimicked from `nnU-Net
973+
<https://www.nature.com/articles/s41592-020-01008-z>`_, specifically `this
974+
<https://github.com/MIC-DKFZ/batchgenerators/blob/7fb802b28b045b21346b197735d64f12fbb070aa/batchgenerators/augmentations/color_augmentations.py#L107>`_
975+
function.
808976
"""
809977

810978
backend = [TransformBackends.TORCH, TransformBackends.NUMPY]
811979

812-
def __init__(self, gamma: float) -> None:
980+
def __init__(self, gamma: float, invert_image: bool = False, retain_stats: bool = False) -> None:
813981
if not isinstance(gamma, (int, float)):
814982
raise ValueError(f"gamma must be a float or int number, got {type(gamma)} {gamma}.")
815983
self.gamma = gamma
984+
self.invert_image = invert_image
985+
self.retain_stats = retain_stats
816986

817-
def __call__(self, img: NdarrayOrTensor) -> NdarrayOrTensor:
987+
def __call__(self, img: NdarrayOrTensor, gamma=None) -> NdarrayOrTensor:
818988
"""
819989
Apply the transform to `img`.
990+
gamma: gamma value to adjust the contrast as function.
820991
"""
821992
img = convert_to_tensor(img, track_meta=get_track_meta())
993+
gamma = gamma if gamma is not None else self.gamma
994+
995+
if self.invert_image:
996+
img = -img
997+
998+
if self.retain_stats:
999+
mn = img.mean()
1000+
sd = img.std()
1001+
8221002
epsilon = 1e-7
8231003
img_min = img.min()
8241004
img_range = img.max() - img_min
825-
ret: NdarrayOrTensor = ((img - img_min) / float(img_range + epsilon)) ** self.gamma * img_range + img_min
1005+
ret: NdarrayOrTensor = ((img - img_min) / float(img_range + epsilon)) ** gamma * img_range + img_min
1006+
1007+
if self.retain_stats:
1008+
# zero mean and normalize
1009+
ret = ret - ret.mean()
1010+
ret = ret / (ret.std() + 1e-8)
1011+
# restore old mean and standard deviation
1012+
ret = sd * ret + mn
1013+
1014+
if self.invert_image:
1015+
ret = -ret
1016+
8261017
return ret
8271018

8281019

8291020
class RandAdjustContrast(RandomizableTransform):
8301021
"""
831-
Randomly changes image intensity by gamma. Each pixel/voxel intensity is updated as::
1022+
Randomly changes image intensity with gamma transform. Each pixel/voxel intensity is updated as:
8321023
8331024
x = ((x - min) / intensity_range) ^ gamma * intensity_range + min
8341025
8351026
Args:
8361027
prob: Probability of adjustment.
8371028
gamma: Range of gamma values.
8381029
If single number, value is picked from (0.5, gamma), default is (0.5, 4.5).
1030+
invert_image: whether to invert the image before applying gamma augmentation. If True, multiply all intensity
1031+
values with -1 before the gamma transform and again after the gamma transform. This behaviour is mimicked
1032+
from `nnU-Net <https://www.nature.com/articles/s41592-020-01008-z>`_, specifically `this
1033+
<https://github.com/MIC-DKFZ/batchgenerators/blob/7fb802b28b045b21346b197735d64f12fbb070aa/batchgenerators/augmentations/color_augmentations.py#L107>`_
1034+
function.
1035+
retain_stats: if True, applies a scaling factor and an offset to all intensity values after gamma transform to
1036+
ensure that the output intensity distribution has the same mean and standard deviation as the intensity
1037+
distribution of the input. This behaviour is mimicked from `nnU-Net
1038+
<https://www.nature.com/articles/s41592-020-01008-z>`_, specifically `this
1039+
<https://github.com/MIC-DKFZ/batchgenerators/blob/7fb802b28b045b21346b197735d64f12fbb070aa/batchgenerators/augmentations/color_augmentations.py#L107>`_
1040+
function.
8391041
"""
8401042

8411043
backend = AdjustContrast.backend
8421044

843-
def __init__(self, prob: float = 0.1, gamma: Sequence[float] | float = (0.5, 4.5)) -> None:
1045+
def __init__(
1046+
self,
1047+
prob: float = 0.1,
1048+
gamma: Sequence[float] | float = (0.5, 4.5),
1049+
invert_image: bool = False,
1050+
retain_stats: bool = False,
1051+
) -> None:
8441052
RandomizableTransform.__init__(self, prob)
8451053

8461054
if isinstance(gamma, (int, float)):
@@ -854,7 +1062,13 @@ def __init__(self, prob: float = 0.1, gamma: Sequence[float] | float = (0.5, 4.5
8541062
else:
8551063
self.gamma = (min(gamma), max(gamma))
8561064

857-
self.gamma_value: float | None = None
1065+
self.gamma_value: float = 1.0
1066+
self.invert_image: bool = invert_image
1067+
self.retain_stats: bool = retain_stats
1068+
1069+
self.adjust_contrast = AdjustContrast(
1070+
self.gamma_value, invert_image=self.invert_image, retain_stats=self.retain_stats
1071+
)
8581072

8591073
def randomize(self, data: Any | None = None) -> None:
8601074
super().randomize(None)
@@ -875,7 +1089,8 @@ def __call__(self, img: NdarrayOrTensor, randomize: bool = True) -> NdarrayOrTen
8751089

8761090
if self.gamma_value is None:
8771091
raise RuntimeError("gamma_value is not set, please call `randomize` function first.")
878-
return AdjustContrast(self.gamma_value)(img)
1092+
1093+
return self.adjust_contrast(img, self.gamma_value)
8791094

8801095

8811096
class ScaleIntensityRangePercentiles(Transform):

0 commit comments

Comments
 (0)