Skip to content

Commit c859ef6

Browse files
committed
✍️ add prob in tf augmentations
1 parent bc825a3 commit c859ef6

File tree

1 file changed

+13
-10
lines changed

1 file changed

+13
-10
lines changed

tensorflow_asr/augmentations/augments.py

Lines changed: 13 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -40,29 +40,32 @@
4040

4141

4242
class TFAugmentationExecutor:
43-
def __init__(self, augmentations: list):
43+
def __init__(self, augmentations: list, prob: float = 0.5):
4444
self.augmentations = augmentations
45+
self.prob = prob
4546

4647
@tf.function
4748
def augment(self, inputs):
4849
outputs = inputs
4950
for au in self.augmentations:
50-
outputs = au.augment(outputs)
51+
if tf.random.uniform([]) < self.prob:
52+
outputs = au.augment(outputs)
5153
return outputs
5254

5355

5456
class Augmentation:
5557
def __init__(self, config: dict = None, use_tf: bool = False):
5658
if not config: config = {}
59+
prob = float(config.pop("prob", 0.5))
5760
if use_tf:
58-
self.before = self.tf_parse(config.pop("before", {}))
59-
self.after = self.tf_parse(config.pop("after", {}))
61+
self.before = self.tf_parse(config.pop("before", {}), prob=prob)
62+
self.after = self.tf_parse(config.pop("after", {}), prob=prob)
6063
else:
61-
self.before = self.parse(config.pop("before", {}))
62-
self.after = self.parse(config.pop("after", {}))
64+
self.before = self.parse(config.pop("before", {}), prob=prob)
65+
self.after = self.parse(config.pop("after", {}), prob=prob)
6366

6467
@staticmethod
65-
def parse(config: dict) -> list:
68+
def parse(config: dict, prob: float = 0.5) -> naf.Sometimes:
6669
augmentations = []
6770
for key, value in config.items():
6871
au = AUGMENTATIONS.get(key, None)
@@ -71,10 +74,10 @@ def parse(config: dict) -> list:
7174
f"Available augmentations: {AUGMENTATIONS.keys()}")
7275
aug = au(**value) if value is not None else au()
7376
augmentations.append(aug)
74-
return naf.Sometimes(augmentations)
77+
return naf.Sometimes(augmentations, pipeline_p=prob)
7578

7679
@staticmethod
77-
def tf_parse(config: dict) -> list:
80+
def tf_parse(config: dict, prob: float = 0.5) -> TFAugmentationExecutor:
7881
augmentations = []
7982
for key, value in config.items():
8083
au = TFAUGMENTATIONS.get(key, None)
@@ -83,4 +86,4 @@ def tf_parse(config: dict) -> list:
8386
f"Available tf augmentations: {TFAUGMENTATIONS.keys()}")
8487
aug = au(**value) if value is not None else au()
8588
augmentations.append(aug)
86-
return TFAugmentationExecutor(augmentations)
89+
return TFAugmentationExecutor(augmentations, prob=prob)

0 commit comments

Comments
 (0)