4040
4141
4242class TFAugmentationExecutor :
43- def __init__ (self , augmentations : list ):
43+ def __init__ (self , augmentations : list , prob : float = 0.5 ):
4444 self .augmentations = augmentations
45+ self .prob = prob
4546
4647 @tf .function
4748 def augment (self , inputs ):
4849 outputs = inputs
4950 for au in self .augmentations :
50- outputs = au .augment (outputs )
51+ if tf .random .uniform ([]) < self .prob :
52+ outputs = au .augment (outputs )
5153 return outputs
5254
5355
5456class Augmentation :
5557 def __init__ (self , config : dict = None , use_tf : bool = False ):
5658 if not config : config = {}
59+ prob = float (config .pop ("prob" , 0.5 ))
5760 if use_tf :
58- self .before = self .tf_parse (config .pop ("before" , {}))
59- self .after = self .tf_parse (config .pop ("after" , {}))
61+ self .before = self .tf_parse (config .pop ("before" , {}), prob = prob )
62+ self .after = self .tf_parse (config .pop ("after" , {}), prob = prob )
6063 else :
61- self .before = self .parse (config .pop ("before" , {}))
62- self .after = self .parse (config .pop ("after" , {}))
64+ self .before = self .parse (config .pop ("before" , {}), prob = prob )
65+ self .after = self .parse (config .pop ("after" , {}), prob = prob )
6366
6467 @staticmethod
65- def parse (config : dict ) -> list :
68+ def parse (config : dict , prob : float = 0.5 ) -> naf . Sometimes :
6669 augmentations = []
6770 for key , value in config .items ():
6871 au = AUGMENTATIONS .get (key , None )
@@ -71,10 +74,10 @@ def parse(config: dict) -> list:
7174 f"Available augmentations: { AUGMENTATIONS .keys ()} " )
7275 aug = au (** value ) if value is not None else au ()
7376 augmentations .append (aug )
74- return naf .Sometimes (augmentations )
77+ return naf .Sometimes (augmentations , pipeline_p = prob )
7578
7679 @staticmethod
77- def tf_parse (config : dict ) -> list :
80+ def tf_parse (config : dict , prob : float = 0.5 ) -> TFAugmentationExecutor :
7881 augmentations = []
7982 for key , value in config .items ():
8083 au = TFAUGMENTATIONS .get (key , None )
@@ -83,4 +86,4 @@ def tf_parse(config: dict) -> list:
8386 f"Available tf augmentations: { TFAUGMENTATIONS .keys ()} " )
8487 aug = au (** value ) if value is not None else au ()
8588 augmentations .append (aug )
86- return TFAugmentationExecutor (augmentations )
89+ return TFAugmentationExecutor (augmentations , prob = prob )
0 commit comments