
Commit e44d560

🚀 Add ContextNet

1 parent a826084 commit e44d560

5 files changed: +281 additions, -13 deletions

examples/conformer/masking/trainer.py

Lines changed: 4 additions & 4 deletions
@@ -13,11 +13,11 @@ def _train_step(self, batch):
         mask = create_padding_mask(features, input_length, self.model.time_reduction_factor)
 
         with tf.GradientTape() as tape:
-            logits = self.model([features, pred_inp], training=True, mask=mask)
+            logits = self.model([features, input_length, pred_inp, label_length + 1], training=True, mask=mask)
             tape.watch(logits)
             per_train_loss = rnnt_loss(
                 logits=logits, labels=labels, label_length=label_length,
-                logit_length=(input_length // self.model.time_reduction_factor),
+                logit_length=tf.cast(tf.math.ceil(input_length / self.model.time_reduction_factor), dtype=tf.int32),
                 blank=self.text_featurizer.blank
             )
             train_loss = tf.nn.compute_average_loss(per_train_loss,
@@ -37,11 +37,11 @@ def _train_step(self, batch):
         mask = create_padding_mask(features, input_length, self.model.time_reduction_factor)
 
         with tf.GradientTape() as tape:
-            logits = self.model([features, pred_inp], training=True, mask=mask)
+            logits = self.model([features, input_length, pred_inp, label_length + 1], training=True, mask=mask)
             tape.watch(logits)
             per_train_loss = rnnt_loss(
                 logits=logits, labels=labels, label_length=label_length,
-                logit_length=(input_length // self.model.time_reduction_factor),
+                logit_length=tf.cast(tf.math.ceil(input_length / self.model.time_reduction_factor), dtype=tf.int32),
                 blank=self.text_featurizer.blank
             )
             train_loss = tf.nn.compute_average_loss(
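
Note on the logit-length change: the encoder's strided convolutions use "same" padding, so an utterance of length T comes out with ceil(T / stride) frames, not floor(T / stride). Plain `//` under-reports the logit length whenever `input_length` is not a multiple of `time_reduction_factor`, telling `rnnt_loss` to ignore the last valid encoder frame. (The `label_length + 1` passed to the model reflects that `pred_inp` is the label sequence with the blank prepended, so it is one step longer than the labels.) A minimal sketch of the difference, with illustrative values:

import tensorflow as tf

# Illustrative lengths; time_reduction_factor is the product of the conv strides.
input_length = tf.constant([970, 1024], dtype=tf.int32)
time_reduction_factor = 4

floor_len = input_length // time_reduction_factor                                 # [242, 256]
ceil_len = tf.cast(tf.math.ceil(input_length / time_reduction_factor), tf.int32)  # [243, 256]
# For the 970-frame utterance, floor division drops the final encoder frame.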

setup.py

Lines changed: 1 addition & 1 deletion
@@ -37,7 +37,7 @@
 
 setuptools.setup(
     name="TensorFlowASR",
-    version="0.4.5",
+    version="0.5.0",
     author="Huy Le Nguyen",
     author_email="[email protected]",
     description="Almost State-of-the-art Automatic Speech Recognition using Tensorflow 2",
tensorflow_asr/models/contextnet.py

Lines changed: 266 additions & 0 deletions
@@ -0,0 +1,266 @@
+from typing import List
+import tensorflow as tf
+from .transducer import Transducer
+from ..utils.utils import merge_two_last_dims
+
+L2 = tf.keras.regularizers.l2(1e-6)
+
+
+def get_activation(activation: str = "silu"):
+    activation = activation.lower()
+    if activation in ["silu", "swish"]: return tf.nn.silu
+    elif activation == "relu": return tf.nn.relu
+    elif activation == "linear": return tf.keras.activations.linear
+    else: raise ValueError("activation must be either 'silu', 'swish', 'relu' or 'linear'")
+
+
+class Reshape(tf.keras.layers.Layer):
+    def call(self, inputs): return merge_two_last_dims(inputs)
+
+
+class ResConvModule(tf.keras.layers.Layer):
+    def __init__(self,
+                 filters: int = 256,
+                 kernel_regularizer=None,
+                 bias_regularizer=None,
+                 **kwargs):
+        super(ResConvModule, self).__init__(**kwargs)
+        self.conv = tf.keras.layers.Conv1D(
+            filters=filters, kernel_size=1, strides=1, padding="same",
+            kernel_regularizer=kernel_regularizer,
+            bias_regularizer=bias_regularizer, name=f"{self.name}_conv"
+        )
+        self.bn = tf.keras.layers.BatchNormalization(name=f"{self.name}_bn")
+
+    def call(self, inputs, training=False, **kwargs):
+        outputs = self.conv(inputs, training=training)
+        outputs = self.bn(outputs, training=training)
+        return outputs
+
+
+class ConvModule(tf.keras.layers.Layer):
+    def __init__(self,
+                 kernel_size: int = 3,
+                 strides: int = 1,
+                 filters: int = 256,
+                 activation: str = "silu",
+                 kernel_regularizer=None,
+                 bias_regularizer=None,
+                 **kwargs):
+        super(ConvModule, self).__init__(**kwargs)
+        self.conv = tf.keras.layers.SeparableConv1D(
+            filters=filters, kernel_size=kernel_size, strides=strides, padding="same",
+            depthwise_regularizer=kernel_regularizer, pointwise_regularizer=kernel_regularizer,
+            bias_regularizer=bias_regularizer, name=f"{self.name}_conv"
+        )
+        self.bn = tf.keras.layers.BatchNormalization(name=f"{self.name}_bn")
+        self.activation = get_activation(activation)
+
+    def call(self, inputs, training=False, **kwargs):
+        outputs = self.conv(inputs, training=training)
+        outputs = self.bn(outputs, training=training)
+        outputs = self.activation(outputs)
+        return outputs
+
+
+class SEModule(tf.keras.layers.Layer):
+    def __init__(self,
+                 kernel_size: int = 3,
+                 strides: int = 1,
+                 filters: int = 256,
+                 activation: str = "silu",
+                 kernel_regularizer=None,
+                 bias_regularizer=None,
+                 **kwargs):
+        super(SEModule, self).__init__(**kwargs)
+        self.conv = ConvModule(
+            kernel_size=kernel_size, strides=strides,
+            filters=filters, activation=activation,
+            kernel_regularizer=kernel_regularizer, bias_regularizer=bias_regularizer,
+            name=f"{self.name}_conv_module"
+        )
+        self.activation = get_activation(activation)
+        self.fc1 = tf.keras.layers.Dense(filters // 8, name=f"{self.name}_fc1")
+        self.fc2 = tf.keras.layers.Dense(filters, name=f"{self.name}_fc2")
+
+    def call(self, inputs, training=False, **kwargs):
+        features, input_length = inputs
+        outputs = self.conv(features, training=training)
+
+        se = tf.reduce_sum(outputs, axis=1) / tf.expand_dims(tf.cast(input_length, dtype=outputs.dtype), axis=1)
+        se = self.fc1(se, training=training)
+        se = self.activation(se)
+        se = self.fc2(se, training=training)
+        se = self.activation(se)
+        se = tf.nn.sigmoid(se)
+        se = tf.expand_dims(se, axis=1)
+
+        outputs = tf.multiply(outputs, se)
+        return outputs
+
+
+class ConvBlock(tf.keras.layers.Layer):
+    def __init__(self,
+                 nlayers: int = 3,
+                 kernel_size: int = 3,
+                 filters: int = 256,
+                 strides: int = 1,
+                 residual: bool = True,
+                 activation: str = 'silu',
+                 kernel_regularizer=None,
+                 bias_regularizer=None,
+                 **kwargs):
+        super(ConvBlock, self).__init__(**kwargs)
+
+        self.dmodel = filters
+        self.time_reduction_factor = strides
+
+        self.convs = []
+        for i in range(nlayers - 1):
+            self.convs.append(
+                ConvModule(
+                    kernel_size=kernel_size, strides=1,
+                    filters=filters, activation=activation,
+                    kernel_regularizer=kernel_regularizer, bias_regularizer=bias_regularizer,
+                    name=f"{self.name}_conv_module_{i}"
+                )
+            )
+
+        self.last_conv = ConvModule(
+            kernel_size=kernel_size, strides=strides,
+            filters=filters, activation=activation,
+            kernel_regularizer=kernel_regularizer, bias_regularizer=bias_regularizer,
+            name=f"{self.name}_conv_module_{nlayers - 1}"
+        )
+
+        self.se = SEModule(
+            kernel_size=kernel_size, strides=1, filters=filters, activation=activation,
+            kernel_regularizer=kernel_regularizer, bias_regularizer=bias_regularizer,
+            name=f"{self.name}_se"
+        )
+
+        self.residual = None
+        if residual:
+            self.residual = ResConvModule(
+                filters=filters, kernel_regularizer=kernel_regularizer,
+                bias_regularizer=bias_regularizer, name=f"{self.name}_residual"
+            )
+
+        self.activation = get_activation(activation)
+
+    def call(self, inputs, training=False, **kwargs):
+        features, input_length = inputs
+        outputs = features
+        for conv in self.convs:
+            outputs = conv(outputs, training=training)
+        outputs = self.last_conv(outputs, training=training)
+        input_length = tf.math.ceil(input_length / self.last_conv.conv.strides[0])
+        outputs = self.se([outputs, input_length], training=training)
+        if self.residual is not None:
+            res = self.residual(features, training=training)
+            outputs = tf.add(outputs, res)
+        outputs = self.activation(outputs)
+        return outputs, input_length
+
+
+class ContextNetEncoder(tf.keras.Model):
+    def __init__(self,
+                 blocks: List[dict] = [],
+                 kernel_regularizer=None,
+                 bias_regularizer=None,
+                 **kwargs):
+        super(ContextNetEncoder, self).__init__(**kwargs)
+
+        self.reshape = Reshape(name=f"{self.name}_reshape")
+
+        self.blocks = []
+        for i, config in enumerate(blocks):
+            self.blocks.append(
+                ConvBlock(**config, kernel_regularizer=kernel_regularizer,
+                          bias_regularizer=bias_regularizer, name=f"{self.name}_block_{i}")
+            )
+
+    def call(self, inputs, training=False, **kwargs):
+        outputs, input_length = inputs
+        outputs = self.reshape(outputs)
+        for block in self.blocks:
+            outputs, input_length = block([outputs, input_length], training=training)
+        return outputs
+
+
+class ContextNet(Transducer):
+    def __init__(self,
+                 vocabulary_size: int,
+                 encoder_blocks: List[dict],
+                 prediction_embed_dim: int = 512,
+                 prediction_embed_dropout: int = 0,
+                 prediction_num_rnns: int = 1,
+                 prediction_rnn_units: int = 320,
+                 prediction_rnn_type: str = "lstm",
+                 prediction_rnn_implementation: int = 2,
+                 prediction_layer_norm: bool = True,
+                 prediction_projection_units: int = 0,
+                 joint_dim: int = 1024,
+                 kernel_regularizer=L2,
+                 bias_regularizer=L2,
+                 name: str = "contextnet",
+                 **kwargs):
+        super(ContextNet, self).__init__(
+            encoder=ContextNetEncoder(
+                blocks=encoder_blocks,
+                kernel_regularizer=kernel_regularizer,
+                bias_regularizer=bias_regularizer,
+                name=f"{name}_encoder"
+            ),
+            vocabulary_size=vocabulary_size,
+            embed_dim=prediction_embed_dim,
+            embed_dropout=prediction_embed_dropout,
+            num_rnns=prediction_num_rnns,
+            rnn_units=prediction_rnn_units,
+            rnn_type=prediction_rnn_type,
+            rnn_implementation=prediction_rnn_implementation,
+            layer_norm=prediction_layer_norm,
+            projection_units=prediction_projection_units,
+            joint_dim=joint_dim,
+            kernel_regularizer=kernel_regularizer,
+            bias_regularizer=bias_regularizer,
+            name=name, **kwargs
+        )
+        self.dmodel = self.encoder.blocks[-1].dmodel
+        self.time_reduction_factor = 1
+        for block in self.encoder.blocks:
+            self.time_reduction_factor *= block.time_reduction_factor
+
+    def call(self, inputs, training=False, **kwargs):
+        """
+        Transducer Model call function
+        Args:
+            features: audio features in shape [B, T, F, C]
+            input_length: shape [B]
+            predicted: predicted sequence of character ids, in shape [B, U]
+            training: python boolean
+            **kwargs: other keyword arguments
+
+        Returns:
+            `logits` with shape [B, T, U, vocab]
+        """
+        features, input_length, predicted, label_length = inputs
+        enc = self.encoder([features, input_length], training=training, **kwargs)
+        pred = self.predict_net([predicted, label_length], training=training, **kwargs)
+        outputs = self.joint_net([enc, pred], training=training, **kwargs)
+        return outputs
+
+    def encoder_inference(self, features):
+        """Infer function for encoder (or encoders)
+
+        Args:
+            features (tf.Tensor): features with shape [T, F, C]
+
+        Returns:
+            tf.Tensor: output of encoders with shape [T, E]
+        """
+        with tf.name_scope(f"{self.name}_encoder"):
+            input_length = tf.expand_dims(tf.shape(features)[0], axis=0)
+            outputs = tf.expand_dims(features, axis=0)
+            outputs = self.encoder([outputs, input_length], training=False)
+            return tf.squeeze(outputs, axis=0)
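
For context, a hedged usage sketch of the new model. The import path and block configs below are illustrative (two small blocks, not the paper's 23-block recipe); the dict keys mirror ConvBlock's keyword arguments above. Note that since this commit's residual projection (ResConvModule) is fixed at stride 1, a strided block should set residual=False so the residual add keeps compatible shapes:

import tensorflow as tf
from tensorflow_asr.models.contextnet import ContextNet  # path assumed for this new file

# Two illustrative encoder blocks; keys come from ConvBlock's constructor.
encoder_blocks = [
    {"nlayers": 1, "kernel_size": 5, "filters": 256, "strides": 1, "residual": True, "activation": "silu"},
    {"nlayers": 5, "kernel_size": 5, "filters": 256, "strides": 2, "residual": False, "activation": "silu"},
]

model = ContextNet(vocabulary_size=29, encoder_blocks=encoder_blocks)

features = tf.random.normal([2, 100, 80, 1])      # [B, T, F, C]
input_length = tf.constant([100, 87], tf.int32)   # valid frames per utterance
predicted = tf.zeros([2, 12], tf.int32)           # blank-prepended label ids
label_length = tf.constant([12, 9], tf.int32)

logits = model([features, input_length, predicted, label_length], training=False)
# logits: [B, ceil(T / time_reduction_factor), U, vocab], here [2, 50, 12, 29]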

tensorflow_asr/models/transducer.py

Lines changed: 6 additions & 4 deletions
@@ -98,10 +98,12 @@ def get_initial_state(self):
     def call(self, inputs, training=False, **kwargs):
         # inputs has shape [B, U]
         # use tf.gather_nd instead of tf.gather for tflite conversion
-        outputs = self.embed(inputs, training=training)
+        outputs, label_length = inputs
+        outputs = self.embed(outputs, training=training)
         outputs = self.do(outputs, training=training)
         for rnn in self.rnns:
-            outputs = rnn["rnn"](outputs, training=training)
+            mask = tf.sequence_mask(label_length)
+            outputs = rnn["rnn"](outputs, training=training, mask=mask)
             outputs = outputs[0]
             if rnn["ln"] is not None:
                 outputs = rnn["ln"](outputs, training=training)
@@ -268,9 +270,9 @@ def call(self, inputs, training=False, **kwargs):
         Returns:
             `logits` with shape [B, T, U, vocab]
         """
-        features, predicted = inputs
+        features, _, predicted, label_length = inputs
         enc = self.encoder(features, training=training, **kwargs)
-        pred = self.predict_net(predicted, training=training, **kwargs)
+        pred = self.predict_net([predicted, label_length], training=training, **kwargs)
         outputs = self.joint_net([enc, pred], training=training, **kwargs)
         return outputs
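
The prediction network now receives `label_length` so its RNNs can skip padded label positions: `tf.sequence_mask` turns per-example lengths into a boolean [B, U] mask, and a Keras RNN given `mask=` ignores the False steps instead of running the recurrence over padding. A small sketch of what the mask looks like (lengths are illustrative):

import tensorflow as tf

label_length = tf.constant([5, 3, 1], dtype=tf.int32)
mask = tf.sequence_mask(label_length)  # maxlen defaults to the batch maximum, here 5
print(mask.numpy())
# [[ True  True  True  True  True]
#  [ True  True  True False False]
#  [ True False False False False]]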

tensorflow_asr/runners/transducer_runners.py

Lines changed: 4 additions & 4 deletions
@@ -49,11 +49,11 @@ def _train_step(self, batch):
         _, features, input_length, labels, label_length, pred_inp = batch
 
         with tf.GradientTape() as tape:
-            logits = self.model([features, pred_inp], training=True)
+            logits = self.model([features, input_length, pred_inp, label_length + 1], training=True)
             tape.watch(logits)
             per_train_loss = rnnt_loss(
                 logits=logits, labels=labels, label_length=label_length,
-                logit_length=(input_length // self.model.time_reduction_factor),
+                logit_length=tf.cast(tf.math.ceil(input_length / self.model.time_reduction_factor), dtype=tf.int32),
                 blank=self.text_featurizer.blank
             )
             train_loss = tf.nn.compute_average_loss(per_train_loss,
@@ -108,11 +108,11 @@ def _train_step(self, batch):
         _, features, input_length, labels, label_length, pred_inp = batch
 
         with tf.GradientTape() as tape:
-            logits = self.model([features, pred_inp], training=True)
+            logits = self.model([features, input_length, pred_inp, label_length + 1], training=True)
             tape.watch(logits)
             per_train_loss = rnnt_loss(
                 logits=logits, labels=labels, label_length=label_length,
-                logit_length=(input_length // self.model.time_reduction_factor),
+                logit_length=tf.cast(tf.math.ceil(input_length / self.model.time_reduction_factor), dtype=tf.int32),
                 blank=self.text_featurizer.blank
             )
             train_loss = tf.nn.compute_average_loss(
