diff --git a/runtime/gpu/model_repo/feature_extractor/1/model.py b/runtime/gpu/model_repo/feature_extractor/1/model.py
index 0623335c0..c459b9187 100644
--- a/runtime/gpu/model_repo/feature_extractor/1/model.py
+++ b/runtime/gpu/model_repo/feature_extractor/1/model.py
@@ -6,6 +6,7 @@ import _kaldifeat
 from typing import List
 import json
+import kaldi_native_fbank as knf
 
 
 class Fbank(torch.nn.Module):
@@ -53,10 +54,13 @@ def initialize(self, args):
         # Convert Triton types to numpy types
         output0_dtype = pb_utils.triton_string_to_numpy(
             output0_config['data_type'])
-        if output0_dtype == np.float32:
-            self.output0_dtype = torch.float32
+        if self.device == "cuda":
+            if output0_dtype == np.float32:
+                self.output0_dtype = torch.float32
+            else:
+                self.output0_dtype = torch.float16
         else:
-            self.output0_dtype = torch.float16
+            self.output0_dtype = output0_dtype
 
         # Get OUTPUT1 configuration
         output1_config = pb_utils.get_output_config_by_name(
@@ -66,7 +70,10 @@ def initialize(self, args):
             output1_config['data_type'])
 
         params = self.model_config['parameters']
-        opts = kaldifeat.FbankOptions()
+        if self.device == "cuda":
+            opts = kaldifeat.FbankOptions()
+        else:
+            opts = knf.FbankOptions()
         opts.frame_opts.dither = 0
 
         for li in params.items():
@@ -80,9 +87,11 @@ def initialize(self, args):
                 opts.frame_opts.frame_length_ms = float(value)
             elif key == "sample_rate":
                 opts.frame_opts.samp_freq = int(value)
-        opts.device = torch.device(self.device)
+        if self.device == "cuda":
+            opts.device = torch.device(self.device)
         self.opts = opts
-        self.feature_extractor = Fbank(self.opts)
+        if self.device == "cuda":
+            self.feature_extractor = Fbank(self.opts)
         self.feature_size = opts.mel_opts.num_bins
 
     def execute(self, requests):
@@ -102,10 +111,12 @@ def execute(self, requests):
           A list of pb_utils.InferenceResponse. The length of this list must
           be the same as `requests`
         """
+        batch_count = []
         total_waves = []
         batch_len = []
         responses = []
+        features = []
         for request in requests:
             input0 = pb_utils.get_input_tensor_by_name(request, "wav")
             input1 = pb_utils.get_input_tensor_by_name(request, "wav_lens")
@@ -116,38 +127,66 @@ def execute(self, requests):
             cur_batch = cur_b_wav.shape[0]
             cur_len = cur_b_wav.shape[1]
             batch_count.append(cur_batch)
-            batch_len.append(cur_len)
-            for wav, wav_len in zip(cur_b_wav, cur_b_wav_lens):
-                wav_len = wav_len[0]
-                wav = torch.tensor(wav[0:wav_len],
-                                   dtype=torch.float32,
-                                   device=self.device)
-                total_waves.append(wav)
-
-        features = self.feature_extractor(total_waves)
+            if self.device == "cuda":
+                batch_len.append(cur_len)
+                for wav, wav_len in zip(cur_b_wav, cur_b_wav_lens):
+                    wav_len = wav_len[0]
+                    wav = torch.tensor(wav[0:wav_len],
+                                       dtype=torch.float32,
+                                       device=self.device)
+                    total_waves.append(wav)
+            else:
+                fea_len = -1
+                for wav, wav_len in zip(cur_b_wav, cur_b_wav_lens):
+                    feature_extractor_cpu = knf.OnlineFbank(self.opts)
+                    feature_extractor_cpu.accept_waveform(self.opts.frame_opts.samp_freq,
+                                                          wav[0:wav_len[0]].tolist())
+                    frame_num = feature_extractor_cpu.num_frames_ready
+                    if frame_num > fea_len:
+                        fea_len = frame_num
+                    feature = np.zeros((frame_num, self.feature_size))
+                    for i in range(frame_num):
+                        feature[i] = feature_extractor_cpu.get_frame(i)
+                    features.append(feature)
+                batch_len.append(fea_len)
+
+        if self.device == "cuda":
+            features = self.feature_extractor(total_waves)
 
         idx = 0
         for b, l in zip(batch_count, batch_len):
-            expect_feat_len = _kaldifeat.num_frames(l, self.opts.frame_opts)
-            speech = torch.zeros((b, expect_feat_len, self.feature_size),
-                                 dtype=self.output0_dtype,
-                                 device=self.device)
-            speech_lengths = torch.zeros((b, 1),
-                                         dtype=torch.int32,
-                                         device=self.device)
+            if self.device == "cuda":
+                expect_feat_len = _kaldifeat.num_frames(l, self.opts.frame_opts)
+                speech = torch.zeros((b, expect_feat_len, self.feature_size),
+                                     dtype=self.output0_dtype,
+                                     device=self.device)
+                speech_lengths = torch.zeros((b, 1),
+                                             dtype=torch.int32,
+                                             device=self.device)
+            else:
+                speech = np.zeros((b, l, self.feature_size), dtype=self.output0_dtype)
+                speech_lengths = np.zeros((b, 1), dtype=np.int32)
             for i in range(b):
                 f = features[idx]
                 f_l = f.shape[0]
-                speech[i, 0:f_l, :] = f.to(self.output0_dtype)
+                if self.device == "cuda":
+                    speech[i, 0:f_l, :] = f.to(self.output0_dtype)
+                else:
+                    speech[i, 0:f_l, :] = f.astype(self.output0_dtype)
                 speech_lengths[i][0] = f_l
                 idx += 1
             # put speech feature on device will cause empty output
             # we will follow this issue and now temporarily put it on cpu
-            speech = speech.cpu()
-            speech_lengths = speech_lengths.cpu()
-            out0 = pb_utils.Tensor.from_dlpack("speech", to_dlpack(speech))
-            out1 = pb_utils.Tensor.from_dlpack("speech_lengths",
-                                               to_dlpack(speech_lengths))
+            if self.device == "cuda":
+                speech = speech.cpu()
+                speech_lengths = speech_lengths.cpu()
+
+                out0 = pb_utils.Tensor.from_dlpack("speech", to_dlpack(speech))
+                out1 = pb_utils.Tensor.from_dlpack("speech_lengths",
+                                                   to_dlpack(speech_lengths))
+            else:
+                out0 = pb_utils.Tensor("speech", speech)
+                out1 = pb_utils.Tensor("speech_lengths", speech_lengths)
+
             inference_response = pb_utils.InferenceResponse(
                 output_tensors=[out0, out1])
             responses.append(inference_response)
-        return responses
+        return responses
\ No newline at end of file
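Note for reviewers: the CPU branch added above drives kaldi_native_fbank's
streaming extractor. Below is a minimal standalone sketch of that API, not
part of the patch; the 16 kHz sample rate, 80 mel bins, and the synthetic
waveform are illustrative stand-ins for the values model.py reads from its
Triton config.

    import numpy as np
    import kaldi_native_fbank as knf

    opts = knf.FbankOptions()
    opts.frame_opts.dither = 0         # the patch also disables dither
    opts.frame_opts.samp_freq = 16000  # illustrative; normally from config
    opts.mel_opts.num_bins = 80        # illustrative; normally from config

    # One second of synthetic audio standing in for a real "wav" input.
    wav = np.random.uniform(-1.0, 1.0, 16000).astype(np.float32)

    # OnlineFbank is a streaming extractor: push samples in, then pull out
    # however many frames are ready, one frame at a time.
    extractor = knf.OnlineFbank(opts)
    extractor.accept_waveform(opts.frame_opts.samp_freq, wav.tolist())

    frame_num = extractor.num_frames_ready
    feature = np.zeros((frame_num, opts.mel_opts.num_bins), dtype=np.float32)
    for i in range(frame_num):
        feature[i] = extractor.get_frame(i)

    print(feature.shape)  # about (98, 80): 25 ms windows, 10 ms shift

Unlike the kaldifeat GPU path, which batches all collected waveforms into a
single feature_extractor call, the CPU path builds one OnlineFbank per
utterance and copies frames out individually; that is why the patch tracks
fea_len per request to size the zero-padded output batch.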