Skip to content

Commit bc825a3

Browse files
committed
✍️ fix subword featurizer iextract
1 parent 4c9255c commit bc825a3

File tree

3 files changed

+10
-8
lines changed

3 files changed

+10
-8
lines changed

tensorflow_asr/datasets/base_dataset.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -37,10 +37,13 @@ def __init__(self,
3737
stage: str = "train",
3838
**kwargs):
3939
self.data_paths = data_paths or []
40+
if not isinstance(self.data_paths, list):
41+
raise ValueError('data_paths must be a list of string paths')
4042
self.augmentations = augmentations # apply augmentation
4143
self.cache = cache # whether to cache transformed dataset to memory
4244
self.shuffle = shuffle # whether to shuffle tf.data.Dataset
43-
if buffer_size <= 0 and shuffle: raise ValueError("buffer_size must be positive when shuffle is on")
45+
if buffer_size <= 0 and shuffle:
46+
raise ValueError("buffer_size must be positive when shuffle is on")
4447
self.buffer_size = buffer_size # shuffle buffer size
4548
self.stage = stage # for defining tfrecords files
4649
self.use_tf = use_tf

tensorflow_asr/featurizers/text_featurizers.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -295,8 +295,10 @@ def iextract(self, indices: tf.Tensor) -> tf.Tensor:
295295
def cond(batch, total, _): return tf.less(batch, total)
296296

297297
def body(batch, total, transcripts):
298-
upoints = self.indices2upoints(indices[batch])
299-
transcripts = transcripts.write(batch, tf.strings.unicode_encode(upoints, "UTF-8"))
298+
norm_indices = self.normalize_indices(indices[batch])
299+
norm_indices = tf.gather_nd(norm_indices, tf.where(tf.not_equal(norm_indices, 0)))
300+
decoded = tf.numpy_function(self.subwords.decode, inp=[norm_indices], Tout=tf.string)
301+
transcripts = transcripts.write(batch, decoded)
300302
return batch + 1, total, transcripts
301303

302304
_, _, transcripts = tf.while_loop(cond, body, loop_vars=[batch, total, transcripts])

tensorflow_asr/models/transducer.py

Lines changed: 2 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -456,14 +456,11 @@ def _perform_greedy_batch(self,
456456
encoded: tf.Tensor,
457457
encoded_length: tf.Tensor,
458458
parallel_iterations: int = 10,
459-
swap_memory: bool = False,
460-
version: str = 'v1'):
459+
swap_memory: bool = False):
461460
with tf.name_scope(f"{self.name}_perform_greedy_batch"):
462461
total_batch = tf.shape(encoded)[0]
463462
batch = tf.constant(0, dtype=tf.int32)
464463

465-
greedy_fn = self._perform_greedy if version == 'v1' else self._perform_greedy_v2
466-
467464
decoded = tf.TensorArray(
468465
dtype=tf.int32, size=total_batch, dynamic_size=False,
469466
clear_after_read=False, element_shape=tf.TensorShape([None])
@@ -472,7 +469,7 @@ def _perform_greedy_batch(self,
472469
def condition(batch, _): return tf.less(batch, total_batch)
473470

474471
def body(batch, decoded):
475-
hypothesis = greedy_fn(
472+
hypothesis = self._perform_greedy(
476473
encoded=encoded[batch],
477474
encoded_length=encoded_length[batch],
478475
predicted=tf.constant(self.text_featurizer.blank, dtype=tf.int32),

0 commit comments

Comments
 (0)