@@ -307,12 +307,17 @@ def __init__(self,
307307
308308 if positional_encoding == "sinusoid" :
309309 self .pe = PositionalEncoding (name = f"{ name } _pe" )
310+ elif positional_encoding == "sinusoid_v2" :
311+ self .pe = PositionalEncoding (alpha = 2 , beta = 0 , name = f"{ name } _pe" )
310312 elif positional_encoding == "sinusoid_concat" :
311313 self .pe = PositionalEncodingConcat (name = f"{ name } _pe" )
314+ elif positional_encoding == "sinusoid_concat_v2" :
315+ self .pe = PositionalEncodingConcat (alpha = 2 , beta = - 1 , name = f"{ name } _pe" )
312316 elif positional_encoding == "subsampling" :
313317 self .pe = tf .keras .layers .Activation ("linear" , name = f"{ name } _pe" )
314318 else :
315- raise ValueError ("positional_encoding must be either 'sinusoid' or 'subsampling'" )
319+ raise ValueError ("positional_encoding must be either 'sinusoid', \
320+ 'sinusoid_concat', 'sinusoid_v2', 'sinusoid_concat_v2' or 'subsampling'" )
316321
317322 self .linear = tf .keras .layers .Dense (
318323 dmodel , name = f"{ name } _linear" ,
@@ -373,6 +378,7 @@ def __init__(self,
373378 encoder_depth_multiplier : int = 1 ,
374379 encoder_fc_factor : float = 0.5 ,
375380 encoder_dropout : float = 0 ,
381+ encoder_trainable : bool = True ,
376382 prediction_embed_dim : int = 512 ,
377383 prediction_embed_dropout : int = 0 ,
378384 prediction_num_rnns : int = 1 ,
@@ -381,12 +387,16 @@ def __init__(self,
381387 prediction_rnn_implementation : int = 2 ,
382388 prediction_layer_norm : bool = True ,
383389 prediction_projection_units : int = 0 ,
390+ prediction_trainable : bool = True ,
384391 joint_dim : int = 1024 ,
385392 joint_activation : str = "tanh" ,
386393 prejoint_linear : bool = True ,
394+ postjoint_linear : bool = False ,
395+ joint_mode : str = "add" ,
396+ joint_trainable : bool = True ,
387397 kernel_regularizer = L2 ,
388398 bias_regularizer = L2 ,
389- name : str = "conformer_transducer " ,
399+ name : str = "conformer " ,
390400 ** kwargs ):
391401 super (Conformer , self ).__init__ (
392402 encoder = ConformerEncoder (
@@ -402,7 +412,9 @@ def __init__(self,
402412 fc_factor = encoder_fc_factor ,
403413 dropout = encoder_dropout ,
404414 kernel_regularizer = kernel_regularizer ,
405- bias_regularizer = bias_regularizer
415+ bias_regularizer = bias_regularizer ,
416+ trainable = encoder_trainable ,
417+ name = f"{ name } _encoder"
406418 ),
407419 vocabulary_size = vocabulary_size ,
408420 embed_dim = prediction_embed_dim ,
@@ -413,12 +425,17 @@ def __init__(self,
413425 rnn_implementation = prediction_rnn_implementation ,
414426 layer_norm = prediction_layer_norm ,
415427 projection_units = prediction_projection_units ,
428+ prediction_trainable = prediction_trainable ,
416429 joint_dim = joint_dim ,
417430 joint_activation = joint_activation ,
418431 prejoint_linear = prejoint_linear ,
432+ postjoint_linear = postjoint_linear ,
433+ joint_mode = joint_mode ,
434+ joint_trainable = joint_trainable ,
419435 kernel_regularizer = kernel_regularizer ,
420436 bias_regularizer = bias_regularizer ,
421- name = name , ** kwargs
437+ name = name ,
438+ ** kwargs
422439 )
423440 self .dmodel = encoder_dmodel
424441 self .time_reduction_factor = self .encoder .conv_subsampling .time_reduction_factor
0 commit comments