|
12 | 12 | # See the License for the specific language governing permissions and |
13 | 13 | # limitations under the License. |
14 | 14 |
|
| 15 | +import os |
15 | 16 | import argparse |
16 | 17 | from tensorflow_asr.utils import setup_environment, setup_devices |
17 | 18 |
|
|
# --blank is the integer index of the CTC/RNN-T blank token, not a file path:
# the previous help string ("Path to conformer tflite") was a copy-paste error.
parser.add_argument("--blank", type=int, default=0,
                    help="Index of the blank token in the vocabulary")

# beam_width == 0 selects greedy decoding further down; > 0 enables beam search.
parser.add_argument("--beam_width", type=int, default=0, help="Beam width")

parser.add_argument("--num_rnns", type=int, default=1,
                    help="Number of RNN layers in prediction network")
|
|
parser.add_argument("--cpu", default=False, action="store_true", help="Whether to only use cpu")
parser.add_argument("--subwords", type=str, default=None, help="Path to file that stores generated subwords")
parser.add_argument("--output_name", type=str, default="test", help="Result filename name prefix")

args = parser.parse_args()

# Configure visible devices before the TF model modules below are imported,
# so device placement takes effect for everything they build.
setup_devices([args.device], cpu=args.cpu)
53 | 62 |
|
54 | 63 | from tensorflow_asr.configs.config import Config |
55 | 64 | from tensorflow_asr.featurizers.speech_featurizers import read_raw_audio |
56 | 65 | from tensorflow_asr.featurizers.speech_featurizers import TFSpeechFeaturizer |
57 | | -from tensorflow_asr.featurizers.text_featurizers import CharFeaturizer |
| 66 | +from tensorflow_asr.featurizers.text_featurizers import CharFeaturizer, SubwordFeaturizer |
58 | 67 | from tensorflow_asr.models.conformer import Conformer |
59 | 68 |
|
config = Config(args.config, learning=False)
speech_featurizer = TFSpeechFeaturizer(config.speech_config)

# Prefer subword units when a generated subword file was supplied and exists;
# otherwise fall back to character-level decoding.
use_subwords = bool(args.subwords) and os.path.exists(args.subwords)
if use_subwords:
    print("Loading subwords ...")
    text_featurizer = SubwordFeaturizer.load_from_file(config.decoder_config, args.subwords)
else:
    text_featurizer = CharFeaturizer(config.decoder_config)

# Propagate the CLI beam width into the decoder config consumed by beam decoding.
text_featurizer.decoder_config.beam_width = args.beam_width
63 | 77 |
|
# Build the Conformer; the output vocabulary size comes from the text featurizer,
# so both the char and subword paths produce a correctly sized model.
conformer = Conformer(**config.model_config, vocabulary_size=text_featurizer.num_classes)
|
conformer.add_featurizers(speech_featurizer, text_featurizer)

signal = read_raw_audio(args.filename)

# signal[None, ...] adds the batch axis; pick beam search only when a
# non-zero --beam_width was requested, otherwise decode greedily.
decode = conformer.recognize_beam if args.beam_width else conformer.recognize
transcript = decode(signal[None, ...])

tf.print("Transcript:", transcript[0])
0 commit comments