Skip to content
This repository was archived by the owner on Jul 7, 2023. It is now read-only.

Commit 2a07e8f

Browse files
T2T TeamRyan Sepassi
authored andcommitted
Factor out common audio feature extraction and apply it to Librispeech dataset.
PiperOrigin-RevId: 179931584
1 parent ee947c9 commit 2a07e8f

File tree

3 files changed

+367
-169
lines changed

3 files changed

+367
-169
lines changed

setup.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@
3030
'gym',
3131
'numpy',
3232
'requests',
33+
'scipy',
3334
'sympy',
3435
'six',
3536
],

tensor2tensor/data_generators/librispeech.py

Lines changed: 34 additions & 169 deletions
Original file line numberDiff line numberDiff line change
@@ -16,23 +16,14 @@
1616
"""Librispeech dataset."""
1717

1818
import os
19-
from subprocess import call
2019
import tarfile
21-
import wave
2220

2321
# Dependency imports
2422

25-
import numpy as np
26-
2723
from tensor2tensor.data_generators import generator_utils
28-
from tensor2tensor.data_generators import problem
29-
from tensor2tensor.data_generators import text_encoder
30-
from tensor2tensor.layers import common_layers
31-
from tensor2tensor.utils import modality
24+
from tensor2tensor.data_generators import speech_recognition
3225
from tensor2tensor.utils import registry
3326

34-
import tensorflow as tf
35-
3627

3728
_LIBRISPEECH_TRAIN_DATASETS = [
3829
[
@@ -86,130 +77,13 @@ def _collect_data(directory, input_ext, transcription_ext):
8677
return data_files
8778

8879

89-
def _get_audio_data(filepath):
90-
# Construct a true .wav file.
91-
out_filepath = filepath.strip(".flac") + ".wav"
92-
# Assumes sox is installed on system. Sox converts from FLAC to WAV.
93-
call(["sox", filepath, out_filepath])
94-
wav_file = wave.open(open(out_filepath))
95-
frame_count = wav_file.getnframes()
96-
byte_array = wav_file.readframes(frame_count)
97-
98-
data = np.fromstring(byte_array, np.uint8).tolist()
99-
return data, frame_count, wav_file.getsampwidth(), wav_file.getnchannels()
100-
101-
102-
class LibrispeechTextEncoder(text_encoder.TextEncoder):
103-
104-
def encode(self, s):
105-
return [self._num_reserved_ids + ord(c) for c in s]
106-
107-
def decode(self, ids):
108-
"""Transform a sequence of int ids into a human-readable string.
109-
110-
EOS is not expected in ids.
111-
112-
Args:
113-
ids: list of integers to be converted.
114-
Returns:
115-
s: human-readable string.
116-
"""
117-
decoded_ids = []
118-
for id_ in ids:
119-
if 0 <= id_ < self._num_reserved_ids:
120-
decoded_ids.append(text_encoder.RESERVED_TOKENS[int(id_)])
121-
else:
122-
decoded_ids.append(id_ - self._num_reserved_ids)
123-
return "".join([chr(d) for d in decoded_ids])
124-
125-
126-
@registry.register_audio_modality
127-
class LibrispeechModality(modality.Modality):
128-
"""Performs strided conv compressions for audio spectral data."""
129-
130-
def bottom(self, inputs):
131-
"""Transform input from data space to model space.
132-
133-
Args:
134-
inputs: A Tensor with shape [batch, ...]
135-
Returns:
136-
body_input: A Tensor with shape [batch, ?, ?, body_input_depth].
137-
"""
138-
with tf.variable_scope(self.name):
139-
# TODO(aidangomez): Will need to sort out a better audio pipeline
140-
def xnet_resblock(x, filters, res_relu, name):
141-
with tf.variable_scope(name):
142-
# We only stride along the length dimension to preserve the spectral
143-
# bins (which are tiny in dimensionality relative to length)
144-
y = common_layers.separable_conv_block(
145-
x,
146-
filters, [((1, 1), (3, 3)), ((1, 1), (3, 3))],
147-
first_relu=True,
148-
padding="SAME",
149-
force2d=True,
150-
name="sep_conv_block")
151-
y = common_layers.pool(y, (3, 3), "MAX", "SAME", strides=(2, 1))
152-
return y + common_layers.conv_block(
153-
x,
154-
filters, [((1, 1), (1, 1))],
155-
padding="SAME",
156-
strides=(2, 1),
157-
first_relu=res_relu,
158-
force2d=True,
159-
name="res_conv0")
160-
161-
# Rescale from UINT8 to floats in [-1,-1]
162-
signals = (tf.to_float(inputs)-127)/128.
163-
signals = tf.squeeze(signals, [2, 3])
164-
165-
# `stfts` is a complex64 Tensor representing the short-time Fourier
166-
# Transform of each signal in `signals`. Its shape is
167-
# [batch_size, ?, fft_unique_bins]
168-
# where fft_unique_bins = fft_length // 2 + 1 = 513.
169-
stfts = tf.contrib.signal.stft(signals, frame_length=1024, frame_step=512,
170-
fft_length=1024)
171-
172-
# An energy spectrogram is the magnitude of the complex-valued STFT.
173-
# A float32 Tensor of shape [batch_size, ?, 513].
174-
magnitude_spectrograms = tf.abs(stfts)
175-
176-
# Warp the linear-scale, magnitude spectrograms into the mel-scale.
177-
num_spectrogram_bins = magnitude_spectrograms.shape[-1].value
178-
lower_edge_hertz, upper_edge_hertz, num_mel_bins = 80.0, 7600.0, 64
179-
sample_rate = 16000
180-
linear_to_mel_weight_matrix = (
181-
tf.contrib.signal.linear_to_mel_weight_matrix(
182-
num_mel_bins, num_spectrogram_bins, sample_rate, lower_edge_hertz,
183-
upper_edge_hertz))
184-
mel_spectrograms = tf.tensordot(
185-
magnitude_spectrograms, linear_to_mel_weight_matrix, 1)
186-
# Note: Shape inference for tensordot does not currently handle this case.
187-
mel_spectrograms.set_shape(magnitude_spectrograms.shape[:-1].concatenate(
188-
linear_to_mel_weight_matrix.shape[-1:]))
189-
190-
x = tf.expand_dims(mel_spectrograms, 2)
191-
x.set_shape([None, None, None, num_mel_bins])
192-
for i in xrange(self._model_hparams.audio_compression):
193-
x = xnet_resblock(x, 2**(i + 1), True, "compress_block_%d" % i)
194-
return xnet_resblock(x, self._body_input_depth, False,
195-
"compress_block_final")
196-
197-
19880
@registry.register_problem()
199-
class Librispeech(problem.Problem):
200-
"""Problem spec for English word to dictionary definition."""
81+
class Librispeech(speech_recognition.SpeechRecognitionProblem):
82+
"""Problem spec for Librispeech using clean and noisy data."""
20183

202-
@property
203-
def is_character_level(self):
204-
return True
205-
206-
@property
207-
def input_space_id(self):
208-
return problem.SpaceID.AUDIO_SPECTRAL
209-
210-
@property
211-
def target_space_id(self):
212-
return problem.SpaceID.EN_CHR
84+
# Select only the clean data
85+
TRAIN_DATASETS = _LIBRISPEECH_TRAIN_DATASETS
86+
DEV_DATASETS = _LIBRISPEECH_TEST_DATASETS
21387

21488
@property
21589
def num_shards(self):
@@ -228,26 +102,8 @@ def use_train_shards_for_dev(self):
228102
"""If true, we only generate training data and hold out shards for dev."""
229103
return False
230104

231-
def feature_encoders(self, _):
232-
return {
233-
"inputs": text_encoder.TextEncoder(),
234-
"targets": LibrispeechTextEncoder(),
235-
}
236-
237-
def example_reading_spec(self):
238-
data_fields = {
239-
"inputs": tf.VarLenFeature(tf.int64),
240-
"targets": tf.VarLenFeature(tf.int64),
241-
}
242-
data_items_to_decoders = None
243-
return (data_fields, data_items_to_decoders)
244-
245-
def generator(self, data_dir, tmp_dir, training,
105+
def generator(self, data_dir, tmp_dir, datasets,
246106
eos_list=None, start_from=0, how_many=0):
247-
eos_list = [1] if eos_list is None else eos_list
248-
datasets = (_LIBRISPEECH_TRAIN_DATASETS if training
249-
else _LIBRISPEECH_TEST_DATASETS)
250-
num_reserved_ids = self.feature_encoders(None)["targets"].num_reserved_ids
251107
i = 0
252108
for url, subdir in datasets:
253109
filename = os.path.basename(url)
@@ -267,44 +123,53 @@ def generator(self, data_dir, tmp_dir, training,
267123
data_dir = os.path.join(tmp_dir, "LibriSpeech", subdir)
268124
data_files = _collect_data(data_dir, "flac", "txt")
269125
data_pairs = data_files.values()
126+
127+
encoders = self.feature_encoders(None)
128+
audio_encoder = encoders["waveforms"]
129+
text_encoder = encoders["targets"]
130+
270131
for media_file, text_data in sorted(data_pairs)[start_from:]:
271132
if how_many > 0 and i == how_many:
272133
return
273134
i += 1
274-
audio_data, sample_count, sample_width, num_channels = _get_audio_data(
275-
media_file)
276-
label = [num_reserved_ids + ord(c) for c in text_data] + eos_list
277135
yield {
278-
"inputs": audio_data,
279-
"audio/channel_count": [num_channels],
280-
"audio/sample_count": [sample_count],
281-
"audio/sample_width": [sample_width],
282-
"targets": label
136+
"waveforms": audio_encoder.encode(media_file),
137+
"targets": text_encoder.encode(text_data)
283138
}
284139

285140
def generate_data(self, data_dir, tmp_dir, task_id=-1):
286141
train_paths = self.training_filepaths(
287142
data_dir, self.num_shards, shuffled=False)
288143
dev_paths = self.dev_filepaths(
289144
data_dir, self.num_dev_shards, shuffled=False)
145+
290146
if self.use_train_shards_for_dev:
291147
all_paths = train_paths + dev_paths
292148
generator_utils.generate_files(
293-
self.generator(data_dir, tmp_dir, True), all_paths)
149+
self.generator(data_dir, tmp_dir, self.TRAIN_DATASETS), all_paths)
294150
generator_utils.shuffle_dataset(all_paths)
295151
else:
296152
generator_utils.generate_dataset_and_shuffle(
297-
self.generator(data_dir, tmp_dir, True), train_paths,
298-
self.generator(data_dir, tmp_dir, False), dev_paths)
153+
self.generator(data_dir, tmp_dir, self.TRAIN_DATASETS), train_paths,
154+
self.generator(data_dir, tmp_dir, self.DEV_DATASETS), dev_paths)
299155

300-
def hparams(self, defaults, unused_model_hparams):
301-
p = defaults
302-
p.stop_at_eos = int(False)
303-
p.input_modality = {"inputs": ("audio:librispeech_modality", None)}
304-
p.target_modality = (registry.Modalities.SYMBOL, 256)
305156

306-
def preprocess_example(self, example, mode, hparams):
307-
return example
157+
@registry.register_problem()
158+
class LibrispeechCleanSmall(Librispeech):
159+
"""Problem spec for Librispeech using 100h clean train data."""
160+
161+
# Select only the clean data
162+
TRAIN_DATASETS = _LIBRISPEECH_TRAIN_DATASETS[:1]
163+
DEV_DATASETS = _LIBRISPEECH_TEST_DATASETS[:1]
164+
165+
166+
@registry.register_problem()
167+
class LibrispeechClean(Librispeech):
168+
"""Problem spec for Librispeech using 460h clean train data."""
169+
170+
# Select only the clean data
171+
TRAIN_DATASETS = _LIBRISPEECH_TRAIN_DATASETS[:2]
172+
DEV_DATASETS = _LIBRISPEECH_TEST_DATASETS[:1]
308173

309174

310175
# TODO(lukaszkaiser): clean up hparams or remove from here.

0 commit comments

Comments
 (0)