
Commit 74ec271

⚡ Add gradients accumulation for CTC models
1 parent 3393571 commit 74ec271

File tree: 6 files changed, +283 −5 lines changed
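This commit wires gradient accumulation into the CTC training path. Instead of applying gradients after every batch, the trainer sums them over accumulation_steps smaller batches in a set of non-trainable buffer variables and runs a single optimizer update per cycle, trading extra forward/backward passes for a lower peak memory footprint while keeping the effective batch size large. The diffs below cover the DeepSpeech2 and Jasper example configs and training scripts, the GradientAccumulation buffers in tensorflow_asr/optimizers/accumulation.py, and the new CTCTrainerGA runner in tensorflow_asr/runners/ctc_runners.py.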

examples/deepspeech2/config.yml

Lines changed: 2 additions & 1 deletion
@@ -67,8 +67,9 @@ learning_config:
     learning_rate: 0.0001

   running_config:
-    batch_size: 8
+    batch_size: 4
     num_epochs: 20
+    accumulation_steps: 8
     outdir: /mnt/d/SpeechProcessing/Trained/local/deepspeech2
     log_interval_steps: 400
     save_interval_steps: 400
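With these values, each optimizer update accumulates gradients over 8 micro-batches of 4 utterances each, an effective per-replica batch of 4 × 8 = 32 per update (up from the previous 8), without ever holding 32 utterances in GPU memory at once.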
examples/deepspeech2/train_ga_ds2.py

Lines changed: 114 additions & 0 deletions
@@ -0,0 +1,114 @@
# Copyright 2020 Huy Le Nguyen (@usimarit)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
import argparse
from tensorflow_asr.utils import setup_environment, setup_strategy

setup_environment()
import tensorflow as tf

DEFAULT_YAML = os.path.join(os.path.abspath(os.path.dirname(__file__)), "config.yml")

tf.keras.backend.clear_session()

parser = argparse.ArgumentParser(prog="Deep Speech 2 Training")

parser.add_argument("--config", "-c", type=str, default=DEFAULT_YAML,
                    help="The file path of the model configuration file")

parser.add_argument("--max_ckpts", type=int, default=10,
                    help="Max number of checkpoints to keep")

parser.add_argument("--tbs", type=int, default=None,
                    help="Train batch size per replica")

parser.add_argument("--ebs", type=int, default=None,
                    help="Evaluation batch size per replica")

parser.add_argument("--acs", type=int, default=None,
                    help="Train accumulation steps")

parser.add_argument("--tfrecords", default=False, action="store_true",
                    help="Whether to use tfrecords dataset")

parser.add_argument("--devices", type=int, nargs="*", default=[0],
                    help="Device ids to apply distributed training")

parser.add_argument("--mxp", default=False, action="store_true",
                    help="Enable mixed precision")

parser.add_argument("--cache", default=False, action="store_true",
                    help="Enable caching for dataset")

args = parser.parse_args()

tf.config.optimizer.set_experimental_options({"auto_mixed_precision": args.mxp})

strategy = setup_strategy(args.devices)

from tensorflow_asr.configs.config import Config
from tensorflow_asr.datasets.asr_dataset import ASRTFRecordDataset, ASRSliceDataset
from tensorflow_asr.featurizers.speech_featurizers import TFSpeechFeaturizer
from tensorflow_asr.featurizers.text_featurizers import CharFeaturizer
from tensorflow_asr.runners.ctc_runners import CTCTrainerGA
from tensorflow_asr.models.deepspeech2 import DeepSpeech2

config = Config(args.config, learning=True)
speech_featurizer = TFSpeechFeaturizer(config.speech_config)
text_featurizer = CharFeaturizer(config.decoder_config)

if args.tfrecords:
    train_dataset = ASRTFRecordDataset(
        data_paths=config.learning_config.dataset_config.train_paths,
        tfrecords_dir=config.learning_config.dataset_config.tfrecords_dir,
        speech_featurizer=speech_featurizer,
        text_featurizer=text_featurizer,
        augmentations=config.learning_config.augmentations,
        stage="train", cache=args.cache, shuffle=True
    )
    eval_dataset = ASRTFRecordDataset(
        data_paths=config.learning_config.dataset_config.eval_paths,
        tfrecords_dir=config.learning_config.dataset_config.tfrecords_dir,
        speech_featurizer=speech_featurizer,
        text_featurizer=text_featurizer,
        stage="eval", cache=args.cache, shuffle=True
    )
else:
    train_dataset = ASRSliceDataset(
        speech_featurizer=speech_featurizer,
        text_featurizer=text_featurizer,
        data_paths=config.learning_config.dataset_config.train_paths,
        augmentations=config.learning_config.augmentations,
        stage="train", cache=args.cache, shuffle=True
    )
    eval_dataset = ASRSliceDataset(
        speech_featurizer=speech_featurizer,
        text_featurizer=text_featurizer,
        data_paths=config.learning_config.dataset_config.eval_paths,
        stage="eval", cache=args.cache, shuffle=True
    )

ctc_trainer = CTCTrainerGA(text_featurizer, config.learning_config.running_config)
# Build DS2 model
with ctc_trainer.strategy.scope():
    ds2_model = DeepSpeech2(**config.model_config, vocabulary_size=text_featurizer.num_classes)
    ds2_model._build(speech_featurizer.shape)
    ds2_model.summary(line_length=120)
    # Compile
    ctc_trainer.compile(ds2_model, config.learning_config.optimizer_config,
                        max_to_keep=args.max_ckpts)

ctc_trainer.fit(train_dataset, eval_dataset,
                train_bs=args.tbs, eval_bs=args.ebs, train_acs=args.acs)
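Assuming the new script lives at examples/deepspeech2/train_ga_ds2.py (a path inferred by symmetry with examples/jasper/train_ga_jasper.py below), a typical multi-GPU run with accumulation might look like:

    python examples/deepspeech2/train_ga_ds2.py --devices 0 1 --mxp --cache --tbs 4 --acs 8

Here --tbs and --acs override batch_size and accumulation_steps from config.yml; left at their None defaults, they presumably fall back to the values in the config.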

examples/jasper/config.yml

Lines changed: 2 additions & 1 deletion
@@ -74,8 +74,9 @@ learning_config:
     learning_rate: 0.0001

   running_config:
-    batch_size: 8
+    batch_size: 4
     num_epochs: 20
+    accumulation_steps: 8
     outdir: /mnt/d/SpeechProcessing/Trained/local/jasper
     log_interval_steps: 400
     save_interval_steps: 400

examples/jasper/train_ga_jasper.py

Lines changed: 114 additions & 0 deletions
@@ -0,0 +1,114 @@
# Copyright 2020 Huy Le Nguyen (@usimarit)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
import argparse
from tensorflow_asr.utils import setup_environment, setup_strategy

setup_environment()
import tensorflow as tf

DEFAULT_YAML = os.path.join(os.path.abspath(os.path.dirname(__file__)), "config.yml")

tf.keras.backend.clear_session()

parser = argparse.ArgumentParser(prog="Jasper Training")

parser.add_argument("--config", "-c", type=str, default=DEFAULT_YAML,
                    help="The file path of the model configuration file")

parser.add_argument("--max_ckpts", type=int, default=10,
                    help="Max number of checkpoints to keep")

parser.add_argument("--tbs", type=int, default=None,
                    help="Train batch size per replica")

parser.add_argument("--ebs", type=int, default=None,
                    help="Evaluation batch size per replica")

parser.add_argument("--acs", type=int, default=None,
                    help="Train accumulation steps")

parser.add_argument("--tfrecords", default=False, action="store_true",
                    help="Whether to use tfrecords dataset")

parser.add_argument("--devices", type=int, nargs="*", default=[0],
                    help="Device ids to apply distributed training")

parser.add_argument("--mxp", default=False, action="store_true",
                    help="Enable mixed precision")

parser.add_argument("--cache", default=False, action="store_true",
                    help="Enable caching for dataset")

args = parser.parse_args()

tf.config.optimizer.set_experimental_options({"auto_mixed_precision": args.mxp})

strategy = setup_strategy(args.devices)

from tensorflow_asr.configs.config import Config
from tensorflow_asr.datasets.asr_dataset import ASRTFRecordDataset, ASRSliceDataset
from tensorflow_asr.featurizers.speech_featurizers import TFSpeechFeaturizer
from tensorflow_asr.featurizers.text_featurizers import CharFeaturizer
from tensorflow_asr.runners.ctc_runners import CTCTrainerGA
from tensorflow_asr.models.jasper import Jasper

config = Config(args.config, learning=True)
speech_featurizer = TFSpeechFeaturizer(config.speech_config)
text_featurizer = CharFeaturizer(config.decoder_config)

if args.tfrecords:
    train_dataset = ASRTFRecordDataset(
        data_paths=config.learning_config.dataset_config.train_paths,
        tfrecords_dir=config.learning_config.dataset_config.tfrecords_dir,
        speech_featurizer=speech_featurizer,
        text_featurizer=text_featurizer,
        augmentations=config.learning_config.augmentations,
        stage="train", cache=args.cache, shuffle=True
    )
    eval_dataset = ASRTFRecordDataset(
        data_paths=config.learning_config.dataset_config.eval_paths,
        tfrecords_dir=config.learning_config.dataset_config.tfrecords_dir,
        speech_featurizer=speech_featurizer,
        text_featurizer=text_featurizer,
        stage="eval", cache=args.cache, shuffle=True
    )
else:
    train_dataset = ASRSliceDataset(
        speech_featurizer=speech_featurizer,
        text_featurizer=text_featurizer,
        data_paths=config.learning_config.dataset_config.train_paths,
        augmentations=config.learning_config.augmentations,
        stage="train", cache=args.cache, shuffle=True
    )
    eval_dataset = ASRSliceDataset(
        speech_featurizer=speech_featurizer,
        text_featurizer=text_featurizer,
        data_paths=config.learning_config.dataset_config.eval_paths,
        stage="eval", cache=args.cache, shuffle=True
    )

ctc_trainer = CTCTrainerGA(text_featurizer, config.learning_config.running_config)
# Build Jasper model
with ctc_trainer.strategy.scope():
    jasper = Jasper(**config.model_config, vocabulary_size=text_featurizer.num_classes)
    jasper._build(speech_featurizer.shape)
    jasper.summary(line_length=120)
    # Compile
    ctc_trainer.compile(jasper, config.learning_config.optimizer_config,
                        max_to_keep=args.max_ckpts)

ctc_trainer.fit(train_dataset, eval_dataset,
                train_bs=args.tbs, eval_bs=args.ebs, train_acs=args.acs)
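The Jasper script mirrors the DeepSpeech2 one, so the invocation is analogous:

    python examples/jasper/train_ga_jasper.py --devices 0 1 --tbs 4 --acs 8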

tensorflow_asr/optimizers/accumulation.py

Lines changed: 4 additions & 3 deletions
@@ -21,15 +21,16 @@ def __init__(self, trainable_variables):
             tf.Variable(
                 tf.zeros_like(g),
                 trainable=False,
-                synchronization=tf.VariableSynchronization.ON_READ
+                synchronization=tf.VariableSynchronization.ON_READ,
+                aggregation=tf.VariableAggregation.ONLY_FIRST_REPLICA
             ) for g in trainable_variables
         ]

     def reset(self):
         for i, g in enumerate(self.gradients):
-            self.gradients[i].assign(tf.zeros_like(g))
+            self.gradients[i].assign(tf.zeros_like(g), read_value=False)

     def accumulate(self, step_gradients):
         for i, g in enumerate(step_gradients):
             if g is None: continue
-            self.gradients[i].assign_add(g)
+            self.gradients[i].assign_add(g, read_value=False)
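In the constructor, aggregation=ONLY_FIRST_REPLICA tells tf.distribute how to resolve the per-replica copies of an ON_READ buffer when it is read outside a replica context (take the first replica's value instead of trying to combine them), and read_value=False skips materializing a result tensor for every assign, a small graph-mode saving. Below is a minimal single-device sketch of the accumulate / apply / reset cycle this class supports; the toy model, data, and loss-scaling choice are illustrative assumptions, not part of the commit:

    import tensorflow as tf
    from tensorflow_asr.optimizers.accumulation import GradientAccumulation

    ACCUMULATION_STEPS = 8

    # Toy model and optimizer, stand-ins for the real ASR model.
    model = tf.keras.Sequential([tf.keras.layers.Dense(2, input_shape=(4,))])
    optimizer = tf.keras.optimizers.Adam()
    accumulation = GradientAccumulation(model.trainable_variables)

    for _ in range(ACCUMULATION_STEPS):
        x = tf.random.normal([4, 4])
        y = tf.random.normal([4, 2])
        with tf.GradientTape() as tape:
            loss = tf.reduce_mean(tf.square(model(x, training=True) - y))
            # Scale inside the tape so the accumulated sum approximates
            # one large-batch average rather than a sum of averages.
            scaled_loss = loss / ACCUMULATION_STEPS
        grads = tape.gradient(scaled_loss, model.trainable_variables)
        accumulation.accumulate(grads)  # buffers += micro-batch gradients

    # One optimizer step for the whole cycle, then clear the buffers.
    optimizer.apply_gradients(zip(accumulation.gradients, model.trainable_variables))
    accumulation.reset()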

tensorflow_asr/runners/ctc_runners.py

Lines changed: 47 additions & 0 deletions
@@ -19,6 +19,7 @@
 from ..featurizers.text_featurizers import TextFeaturizer
 from ..losses.ctc_losses import ctc_loss
 from .base_runners import BaseTrainer
+from ..optimizers.accumulation import GradientAccumulation


 class CTCTrainer(BaseTrainer):

@@ -89,3 +90,49 @@ def compile(self, model: tf.keras.Model,
             self.model = model
             self.optimizer = tf.keras.optimizers.get(optimizer)
             self.create_checkpoint_manager(max_to_keep, model=self.model, optimizer=self.optimizer)
+
+
+class CTCTrainerGA(CTCTrainer):
+    """ Trainer for CTC models using gradient accumulation """
+
+    @tf.function
+    def _train_function(self, iterator):
+        for _ in range(self.config.accumulation_steps):
+            batch = next(iterator)
+            self.strategy.run(self._train_step, args=(batch,))
+        self.strategy.run(self._apply_gradients, args=())
+
+    @tf.function
+    def _apply_gradients(self):
+        self.optimizer.apply_gradients(
+            zip(self.accumulation.gradients, self.model.trainable_variables))
+        self.accumulation.reset()
+
+    @tf.function(experimental_relax_shapes=True)
+    def _train_step(self, batch):
+        _, features, input_length, labels, label_length, _ = batch
+
+        with tf.GradientTape() as tape:
+            y_pred = self.model(features, training=True)
+            tape.watch(y_pred)
+            per_train_loss = ctc_loss(
+                y_true=labels, y_pred=y_pred,
+                input_length=(input_length // self.model.time_reduction_factor),
+                label_length=label_length,
+                blank=self.text_featurizer.blank
+            )
+            train_loss = tf.nn.compute_average_loss(per_train_loss,
+                                                    global_batch_size=self.global_batch_size)
+
+        gradients = tape.gradient(train_loss, self.model.trainable_variables)
+        self.accumulation.accumulate(gradients)
+        self.train_metrics["ctc_loss"].update_state(per_train_loss)
+
+    def compile(self, model: tf.keras.Model,
+                optimizer: any,
+                max_to_keep: int = 10):
+        with self.strategy.scope():
+            self.model = model
+            self.optimizer = tf.keras.optimizers.get(optimizer)
+            self.create_checkpoint_manager(max_to_keep, model=self.model, optimizer=self.optimizer)
+            self.accumulation = GradientAccumulation(self.model.trainable_variables)
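Taken together: _train_function pulls accumulation_steps batches per outer step and runs _train_step on each, which only computes gradients, adds them to the buffers, and updates the loss metric; _apply_gradients then performs a single optimizer update from the summed buffers and resets them. compile allocates the GradientAccumulation buffers inside the strategy scope so each replica holds its own copy. One caveat: the buffers hold a sum over micro-batches while each micro-batch loss is averaged via tf.nn.compute_average_loss over global_batch_size, so whether an extra division by accumulation_steps is warranted depends on how BaseTrainer defines global_batch_size, which this diff does not show.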
