Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
97 changes: 68 additions & 29 deletions runtime/gpu/model_repo/feature_extractor/1/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import _kaldifeat
from typing import List
import json
import kaldi_native_fbank as knf


class Fbank(torch.nn.Module):
Expand Down Expand Up @@ -53,10 +54,13 @@ def initialize(self, args):
# Convert Triton types to numpy types
output0_dtype = pb_utils.triton_string_to_numpy(
output0_config['data_type'])
if output0_dtype == np.float32:
self.output0_dtype = torch.float32
if self.device == "cuda":
if output0_dtype == np.float32:
self.output0_dtype = torch.float32
else:
self.output0_dtype = torch.float16
else:
self.output0_dtype = torch.float16
self.output0_dtype = output0_dtype

# Get OUTPUT1 configuration
output1_config = pb_utils.get_output_config_by_name(
Expand All @@ -66,7 +70,10 @@ def initialize(self, args):
output1_config['data_type'])

params = self.model_config['parameters']
opts = kaldifeat.FbankOptions()
if self.device == "cuda":
opts = kaldifeat.FbankOptions()
else:
opts = knf.FbankOptions()
opts.frame_opts.dither = 0

for li in params.items():
Expand All @@ -80,9 +87,11 @@ def initialize(self, args):
opts.frame_opts.frame_length_ms = float(value)
elif key == "sample_rate":
opts.frame_opts.samp_freq = int(value)
opts.device = torch.device(self.device)
if self.device == "cuda":
opts.device = torch.device(self.device)
self.opts = opts
self.feature_extractor = Fbank(self.opts)
if self.device == "cuda":
self.feature_extractor = Fbank(self.opts)
self.feature_size = opts.mel_opts.num_bins

def execute(self, requests):
Expand All @@ -102,10 +111,12 @@ def execute(self, requests):
A list of pb_utils.InferenceResponse. The length of this list must
be the same as `requests`
"""

batch_count = []
total_waves = []
batch_len = []
responses = []
features = []
for request in requests:
input0 = pb_utils.get_input_tensor_by_name(request, "wav")
input1 = pb_utils.get_input_tensor_by_name(request, "wav_lens")
Expand All @@ -116,38 +127,66 @@ def execute(self, requests):
cur_batch = cur_b_wav.shape[0]
cur_len = cur_b_wav.shape[1]
batch_count.append(cur_batch)
batch_len.append(cur_len)
for wav, wav_len in zip(cur_b_wav, cur_b_wav_lens):
wav_len = wav_len[0]
wav = torch.tensor(wav[0:wav_len],
dtype=torch.float32,
device=self.device)
total_waves.append(wav)

features = self.feature_extractor(total_waves)
if self.device == "cuda":
batch_len.append(cur_len)
for wav, wav_len in zip(cur_b_wav, cur_b_wav_lens):
wav_len = wav_len[0]
wav = torch.tensor(wav[0:wav_len],
dtype=torch.float32,
device=self.device)
total_waves.append(wav)
else:
fea_len = -1
for wav, wav_len in zip(cur_b_wav, cur_b_wav_lens):
feature_extractor_cpu = knf.OnlineFbank(self.opts)
feature_extractor_cpu.accept_waveform(self.opts.frame_opts.samp_freq,
wav[0:wav_len[0]].tolist())
frame_num = feature_extractor_cpu.num_frames_ready
if frame_num > fea_len:
fea_len = frame_num
feature = np.zeros((frame_num, self.feature_size))
for i in range(frame_num):
feature[i] = feature_extractor_cpu.get_frame(i)
features.append(feature)
batch_len.append(fea_len)
if self.device == "cuda":
features = self.feature_extractor(total_waves)
idx = 0
for b, l in zip(batch_count, batch_len):
expect_feat_len = _kaldifeat.num_frames(l, self.opts.frame_opts)
speech = torch.zeros((b, expect_feat_len, self.feature_size),
dtype=self.output0_dtype,
device=self.device)
speech_lengths = torch.zeros((b, 1),
dtype=torch.int32,
device=self.device)
if self.device == "cuda":
expect_feat_len = _kaldifeat.num_frames(l, self.opts.frame_opts)
speech = torch.zeros((b, expect_feat_len, self.feature_size),
dtype=self.output0_dtype,
device=self.device)
speech_lengths = torch.zeros((b, 1),
dtype=torch.int32,
device=self.device)
else:
speech = np.zeros((b, l, self.feature_size), dtype=self.output0_dtype)
speech_lengths = np.zeros((b, 1), dtype=np.int32)
for i in range(b):
f = features[idx]
f_l = f.shape[0]
speech[i, 0:f_l, :] = f.to(self.output0_dtype)
if self.device == "cuda":
speech[i, 0:f_l, :] = f.to(self.output0_dtype)
else:
speech[i, 0:f_l, :] = f.astype(self.output0_dtype)
speech_lengths[i][0] = f_l
idx += 1
# Keeping the speech features on the device results in empty output.
# Until that issue is resolved, temporarily move them to the CPU.
speech = speech.cpu()
speech_lengths = speech_lengths.cpu()
out0 = pb_utils.Tensor.from_dlpack("speech", to_dlpack(speech))
out1 = pb_utils.Tensor.from_dlpack("speech_lengths",
to_dlpack(speech_lengths))
if self.device == "cuda":
speech = speech.cpu()
speech_lengths = speech_lengths.cpu()

out0 = pb_utils.Tensor.from_dlpack("speech", to_dlpack(speech))
out1 = pb_utils.Tensor.from_dlpack("speech_lengths",
to_dlpack(speech_lengths))
else:
out0 = pb_utils.Tensor("speech", speech)
out1 = pb_utils.Tensor("speech_lengths", speech_lengths)

inference_response = pb_utils.InferenceResponse(
output_tensors=[out0, out1])
responses.append(inference_response)
return responses
return responses