48
48
"RandBiasField" ,
49
49
"ScaleIntensity" ,
50
50
"RandScaleIntensity" ,
51
+ "ScaleIntensityFixedMean" ,
52
+ "RandScaleIntensityFixedMean" ,
51
53
"NormalizeIntensity" ,
52
54
"ThresholdIntensity" ,
53
55
"ScaleIntensityRange" ,
@@ -466,6 +468,161 @@ def __call__(self, img: NdarrayOrTensor) -> NdarrayOrTensor:
466
468
return ret
467
469
468
470
471
+ class ScaleIntensityFixedMean (Transform ):
472
+ """
473
+ Scale the intensity of input image by ``v = v * (1 + factor)``, then shift the output so that the output image has the
474
+ same mean as the input.
475
+ """
476
+
477
+ backend = [TransformBackends .TORCH , TransformBackends .NUMPY ]
478
+
479
+ def __init__ (
480
+ self ,
481
+ factor : float = 0 ,
482
+ preserve_range : bool = False ,
483
+ fixed_mean : bool = True ,
484
+ channel_wise : bool = False ,
485
+ dtype : DtypeLike = np .float32 ,
486
+ ) -> None :
487
+ """
488
+ Args:
489
+ factor: factor scale by ``v = v * (1 + factor)``.
490
+ preserve_range: clips the output array/tensor to the range of the input array/tensor
491
+ fixed_mean: subtract the mean intensity before scaling with `factor`, then add the same value after scaling
492
+ to ensure that the output has the same mean as the input.
493
+ channel_wise: if True, scale on each channel separately. `preserve_range` and `fixed_mean` are also applied
494
+ on each channel separately if `channel_wise` is True. Please ensure that the first dimension represents the
495
+ channel of the image if True.
496
+ dtype: output data type, if None, same as input image. defaults to float32.
497
+ """
498
+ self .factor = factor
499
+ self .preserve_range = preserve_range
500
+ self .fixed_mean = fixed_mean
501
+ self .channel_wise = channel_wise
502
+ self .dtype = dtype
503
+
504
+ def __call__ (self , img : NdarrayOrTensor , factor = None ) -> NdarrayOrTensor :
505
+ """
506
+ Apply the transform to `img`.
507
+ Args:
508
+ img: the input tensor/array
509
+ factor: factor scale by ``v = v * (1 + factor)``
510
+
511
+ """
512
+
513
+ factor = factor if factor is not None else self .factor
514
+
515
+ img = convert_to_tensor (img , track_meta = get_track_meta ())
516
+ img_t = convert_to_tensor (img , track_meta = False )
517
+ ret : NdarrayOrTensor
518
+ if self .channel_wise :
519
+ out = []
520
+ for d in img_t :
521
+ if self .preserve_range :
522
+ clip_min = d .min ()
523
+ clip_max = d .max ()
524
+
525
+ if self .fixed_mean :
526
+ mn = d .mean ()
527
+ d = d - mn
528
+
529
+ out_channel = d * (1 + factor )
530
+
531
+ if self .fixed_mean :
532
+ out_channel = out_channel + mn
533
+
534
+ if self .preserve_range :
535
+ out_channel = clip (out_channel , clip_min , clip_max )
536
+
537
+ out .append (out_channel )
538
+ ret = torch .stack (out ) # type: ignore
539
+ else :
540
+ if self .preserve_range :
541
+ clip_min = img_t .min ()
542
+ clip_max = img_t .max ()
543
+
544
+ if self .fixed_mean :
545
+ mn = img_t .mean ()
546
+ img_t = img_t - mn
547
+
548
+ ret = img_t * (1 + factor )
549
+
550
+ if self .fixed_mean :
551
+ ret = ret + mn
552
+
553
+ if self .preserve_range :
554
+ ret = clip (ret , clip_min , clip_max )
555
+
556
+ ret = convert_to_dst_type (ret , dst = img , dtype = self .dtype or img_t .dtype )[0 ]
557
+ return ret
558
+
559
+
560
+ class RandScaleIntensityFixedMean (RandomizableTransform ):
561
+ """
562
+ Randomly scale the intensity of input image by ``v = v * (1 + factor)`` where the `factor`
563
+ is randomly picked. Subtract the mean intensity before scaling with `factor`, then add the same value after scaling
564
+ to ensure that the output has the same mean as the input.
565
+ """
566
+
567
+ backend = ScaleIntensityFixedMean .backend
568
+
569
+ def __init__ (
570
+ self ,
571
+ prob : float = 0.1 ,
572
+ factors : Sequence [float ] | float = 0 ,
573
+ fixed_mean : bool = True ,
574
+ preserve_range : bool = False ,
575
+ dtype : DtypeLike = np .float32 ,
576
+ ) -> None :
577
+ """
578
+ Args:
579
+ factors: factor range to randomly scale by ``v = v * (1 + factor)``.
580
+ if single number, factor value is picked from (-factors, factors).
581
+ preserve_range: clips the output array/tensor to the range of the input array/tensor
582
+ fixed_mean: subtract the mean intensity before scaling with `factor`, then add the same value after scaling
583
+ to ensure that the output has the same mean as the input.
584
+ channel_wise: if True, scale on each channel separately. `preserve_range` and `fixed_mean` are also applied
585
+ on each channel separately if `channel_wise` is True. Please ensure that the first dimension represents the
586
+ channel of the image if True.
587
+ dtype: output data type, if None, same as input image. defaults to float32.
588
+
589
+ """
590
+ RandomizableTransform .__init__ (self , prob )
591
+ if isinstance (factors , (int , float )):
592
+ self .factors = (min (- factors , factors ), max (- factors , factors ))
593
+ elif len (factors ) != 2 :
594
+ raise ValueError ("factors should be a number or pair of numbers." )
595
+ else :
596
+ self .factors = (min (factors ), max (factors ))
597
+ self .factor = self .factors [0 ]
598
+ self .fixed_mean = fixed_mean
599
+ self .preserve_range = preserve_range
600
+ self .dtype = dtype
601
+
602
+ self .scaler = ScaleIntensityFixedMean (
603
+ factor = self .factor , fixed_mean = self .fixed_mean , preserve_range = self .preserve_range , dtype = self .dtype
604
+ )
605
+
606
+ def randomize (self , data : Any | None = None ) -> None :
607
+ super ().randomize (None )
608
+ if not self ._do_transform :
609
+ return None
610
+ self .factor = self .R .uniform (low = self .factors [0 ], high = self .factors [1 ])
611
+
612
+ def __call__ (self , img : NdarrayOrTensor , randomize : bool = True ) -> NdarrayOrTensor :
613
+ """
614
+ Apply the transform to `img`.
615
+ """
616
+ img = convert_to_tensor (img , track_meta = get_track_meta ())
617
+ if randomize :
618
+ self .randomize ()
619
+
620
+ if not self ._do_transform :
621
+ return convert_data_type (img , dtype = self .dtype )[0 ]
622
+
623
+ return self .scaler (img , self .factor )
624
+
625
+
469
626
class RandScaleIntensity (RandomizableTransform ):
470
627
"""
471
628
Randomly scale the intensity of input image by ``v = v * (1 + factor)`` where the `factor`
@@ -799,48 +956,99 @@ def __call__(self, img: NdarrayOrTensor) -> NdarrayOrTensor:
799
956
800
957
class AdjustContrast (Transform ):
801
958
"""
802
- Changes image intensity by gamma. Each pixel/voxel intensity is updated as::
959
+ Changes image intensity with gamma transform . Each pixel/voxel intensity is updated as::
803
960
804
961
x = ((x - min) / intensity_range) ^ gamma * intensity_range + min
805
962
806
963
Args:
807
964
gamma: gamma value to adjust the contrast as function.
965
+ invert_image: whether to invert the image before applying gamma augmentation. If True, multiply all intensity
966
+ values with -1 before the gamma transform and again after the gamma transform. This behaviour is mimicked
967
+ from `nnU-Net <https://www.nature.com/articles/s41592-020-01008-z>`_, specifically `this
968
+ <https://github.com/MIC-DKFZ/batchgenerators/blob/7fb802b28b045b21346b197735d64f12fbb070aa/batchgenerators/augmentations/color_augmentations.py#L107>`_
969
+ function.
970
+ retain_stats: if True, applies a scaling factor and an offset to all intensity values after gamma transform to
971
+ ensure that the output intensity distribution has the same mean and standard deviation as the intensity
972
+ distribution of the input. This behaviour is mimicked from `nnU-Net
973
+ <https://www.nature.com/articles/s41592-020-01008-z>`_, specifically `this
974
+ <https://github.com/MIC-DKFZ/batchgenerators/blob/7fb802b28b045b21346b197735d64f12fbb070aa/batchgenerators/augmentations/color_augmentations.py#L107>`_
975
+ function.
808
976
"""
809
977
810
978
backend = [TransformBackends .TORCH , TransformBackends .NUMPY ]
811
979
812
- def __init__ (self , gamma : float ) -> None :
980
+ def __init__ (self , gamma : float , invert_image : bool = False , retain_stats : bool = False ) -> None :
813
981
if not isinstance (gamma , (int , float )):
814
982
raise ValueError (f"gamma must be a float or int number, got { type (gamma )} { gamma } ." )
815
983
self .gamma = gamma
984
+ self .invert_image = invert_image
985
+ self .retain_stats = retain_stats
816
986
817
- def __call__ (self , img : NdarrayOrTensor ) -> NdarrayOrTensor :
987
+ def __call__ (self , img : NdarrayOrTensor , gamma = None ) -> NdarrayOrTensor :
818
988
"""
819
989
Apply the transform to `img`.
990
+ gamma: gamma value to adjust the contrast as function.
820
991
"""
821
992
img = convert_to_tensor (img , track_meta = get_track_meta ())
993
+ gamma = gamma if gamma is not None else self .gamma
994
+
995
+ if self .invert_image :
996
+ img = - img
997
+
998
+ if self .retain_stats :
999
+ mn = img .mean ()
1000
+ sd = img .std ()
1001
+
822
1002
epsilon = 1e-7
823
1003
img_min = img .min ()
824
1004
img_range = img .max () - img_min
825
- ret : NdarrayOrTensor = ((img - img_min ) / float (img_range + epsilon )) ** self .gamma * img_range + img_min
1005
+ ret : NdarrayOrTensor = ((img - img_min ) / float (img_range + epsilon )) ** gamma * img_range + img_min
1006
+
1007
+ if self .retain_stats :
1008
+ # zero mean and normalize
1009
+ ret = ret - ret .mean ()
1010
+ ret = ret / (ret .std () + 1e-8 )
1011
+ # restore old mean and standard deviation
1012
+ ret = sd * ret + mn
1013
+
1014
+ if self .invert_image :
1015
+ ret = - ret
1016
+
826
1017
return ret
827
1018
828
1019
829
1020
class RandAdjustContrast (RandomizableTransform ):
830
1021
"""
831
- Randomly changes image intensity by gamma. Each pixel/voxel intensity is updated as: :
1022
+ Randomly changes image intensity with gamma transform . Each pixel/voxel intensity is updated as:
832
1023
833
1024
x = ((x - min) / intensity_range) ^ gamma * intensity_range + min
834
1025
835
1026
Args:
836
1027
prob: Probability of adjustment.
837
1028
gamma: Range of gamma values.
838
1029
If single number, value is picked from (0.5, gamma), default is (0.5, 4.5).
1030
+ invert_image: whether to invert the image before applying gamma augmentation. If True, multiply all intensity
1031
+ values with -1 before the gamma transform and again after the gamma transform. This behaviour is mimicked
1032
+ from `nnU-Net <https://www.nature.com/articles/s41592-020-01008-z>`_, specifically `this
1033
+ <https://github.com/MIC-DKFZ/batchgenerators/blob/7fb802b28b045b21346b197735d64f12fbb070aa/batchgenerators/augmentations/color_augmentations.py#L107>`_
1034
+ function.
1035
+ retain_stats: if True, applies a scaling factor and an offset to all intensity values after gamma transform to
1036
+ ensure that the output intensity distribution has the same mean and standard deviation as the intensity
1037
+ distribution of the input. This behaviour is mimicked from `nnU-Net
1038
+ <https://www.nature.com/articles/s41592-020-01008-z>`_, specifically `this
1039
+ <https://github.com/MIC-DKFZ/batchgenerators/blob/7fb802b28b045b21346b197735d64f12fbb070aa/batchgenerators/augmentations/color_augmentations.py#L107>`_
1040
+ function.
839
1041
"""
840
1042
841
1043
backend = AdjustContrast .backend
842
1044
843
- def __init__ (self , prob : float = 0.1 , gamma : Sequence [float ] | float = (0.5 , 4.5 )) -> None :
1045
+ def __init__ (
1046
+ self ,
1047
+ prob : float = 0.1 ,
1048
+ gamma : Sequence [float ] | float = (0.5 , 4.5 ),
1049
+ invert_image : bool = False ,
1050
+ retain_stats : bool = False ,
1051
+ ) -> None :
844
1052
RandomizableTransform .__init__ (self , prob )
845
1053
846
1054
if isinstance (gamma , (int , float )):
@@ -854,7 +1062,13 @@ def __init__(self, prob: float = 0.1, gamma: Sequence[float] | float = (0.5, 4.5
854
1062
else :
855
1063
self .gamma = (min (gamma ), max (gamma ))
856
1064
857
- self .gamma_value : float | None = None
1065
+ self .gamma_value : float = 1.0
1066
+ self .invert_image : bool = invert_image
1067
+ self .retain_stats : bool = retain_stats
1068
+
1069
+ self .adjust_contrast = AdjustContrast (
1070
+ self .gamma_value , invert_image = self .invert_image , retain_stats = self .retain_stats
1071
+ )
858
1072
859
1073
def randomize (self , data : Any | None = None ) -> None :
860
1074
super ().randomize (None )
@@ -875,7 +1089,8 @@ def __call__(self, img: NdarrayOrTensor, randomize: bool = True) -> NdarrayOrTen
875
1089
876
1090
if self .gamma_value is None :
877
1091
raise RuntimeError ("gamma_value is not set, please call `randomize` function first." )
878
- return AdjustContrast (self .gamma_value )(img )
1092
+
1093
+ return self .adjust_contrast (img , self .gamma_value )
879
1094
880
1095
881
1096
class ScaleIntensityRangePercentiles (Transform ):
0 commit comments