diff --git a/WORKSPACE b/WORKSPACE
index a511e61d2e..4f5e23a9ee 100644
--- a/WORKSPACE
+++ b/WORKSPACE
@@ -636,3 +636,19 @@ cc_library(
 )
 """,
 )
+
+new_git_repository(
+    name = "dr_libs",
+    remote = "https://github.com/mackron/dr_libs",
+    commit = "24d738be2349fd4b6fe50eeaa81f5bd586267fd0",
+    build_file_content = """
+cc_library(
+    name = "dr",
+    hdrs = ["dr_flac.h", "dr_mp3.h", "dr_wav.h"],
+    visibility = ["//visibility:public"],
+    local_defines = [
+    ],
+)
+""",
+)
diff --git a/demos/audio/README.md b/demos/audio/README.md
new file mode 100644
index 0000000000..d91482f0cb
--- /dev/null
+++ b/demos/audio/README.md
@@ -0,0 +1,25 @@
+# Audio endpoints
+
+## Audio synthesis
+
+Export the SpeechT5 TTS model together with its vocoder, start the server, and request speech synthesis via the OpenAI-compatible `/v3/audio/speech` endpoint:
+
+python export_model.py text2speech --source_model microsoft/speecht5_tts --vocoder microsoft/speecht5_hifigan --weight-format fp16
+
+docker run -p 8000:8000 -d -v $(pwd)/models/:/models openvino/model_server --model_name speecht5_tts --model_path /models/microsoft/speecht5_tts --rest_port 8000
+
+curl http://localhost:8000/v3/audio/speech -H "Content-Type: application/json" -d "{\"model\": \"speecht5_tts\", \"input\": \"The quick brown fox jumped over the lazy dog.\"}" -o audio.wav
+
+## Audio transcription
+
+Export a Whisper model for GPU execution, start the server, and transcribe the generated audio via the `/v3/audio/transcriptions` endpoint:
+
+python export_model.py speech2text --source_model openai/whisper-large-v2 --weight-format fp16 --target_device GPU
+
+docker run -p 8000:8000 -it --device /dev/dri -u 0 -v $(pwd)/models/:/models openvino/model_server --model_name whisper --model_path /models/openai/whisper-large-v2 --rest_port 8000
+
+curl http://localhost:8000/v3/audio/transcriptions -H "Content-Type: multipart/form-data" -F file="@audio.wav" -F model="whisper"
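+
+## Python client example
+
+A minimal `requests`-based variant of the two calls above (a sketch, assuming the server was started as shown here on port 8000; the OpenAI-client scripts next to this README are the reference flow):
+
+```python
+import requests
+
+base = "http://localhost:8000/v3"
+
+# Text to speech: JSON body in, WAV bytes out.
+response = requests.post(
+    f"{base}/audio/speech",
+    json={"model": "speecht5_tts", "input": "The quick brown fox jumped over the lazy dog."},
+)
+response.raise_for_status()
+with open("audio.wav", "wb") as f:
+    f.write(response.content)
+
+# Speech to text: multipart form with the audio file and the served model name.
+with open("audio.wav", "rb") as f:
+    response = requests.post(
+        f"{base}/audio/transcriptions",
+        files={"file": ("audio.wav", f, "audio/wav")},
+        data={"model": "whisper"},
+    )
+response.raise_for_status()
+print(response.json()["text"])
+```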
diff --git a/demos/audio/openai_speech2text.py b/demos/audio/openai_speech2text.py
new file mode 100644
index 0000000000..cb8b958508
--- /dev/null
+++ b/demos/audio/openai_speech2text.py
@@ -0,0 +1,33 @@
+#
+# Copyright (c) 2025 Intel Corporation
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+from pathlib import Path
+from openai import OpenAI
+
+filename = "speech.wav"
+url = "http://localhost:8000/v3"
+
+speech_file_path = Path(__file__).parent / filename
+client = OpenAI(base_url=url, api_key="not_used")
+
+with open(speech_file_path, "rb") as audio_file:
+    transcript = client.audio.transcriptions.create(
+        model="whisper",
+        file=audio_file
+    )
+
+print(transcript.text)
diff --git a/demos/audio/openai_text2speech.py b/demos/audio/openai_text2speech.py
new file mode 100644
index 0000000000..89f0ae8e23
--- /dev/null
+++ b/demos/audio/openai_text2speech.py
@@ -0,0 +1,36 @@
+#
+# Copyright (c) 2025 Intel Corporation
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+from pathlib import Path
+from openai import OpenAI
+
+prompt = "Intel Corporation is an American multinational technology company headquartered in Santa Clara, California. Intel designs, manufactures, and sells computer components such as central processing units (CPUs) and related products for business and consumer markets. It was the world's third-largest semiconductor chip manufacturer by revenue in 2024 and has been included in the Fortune 500 list of the largest United States corporations by revenue since 2007. It was one of the first companies listed on Nasdaq. Since 2025, it has been partially owned by the United States government."
+filename = "speech.wav"
+url = "http://localhost:8000/v3"
+
+speech_file_path = Path(__file__).parent / filename
+client = OpenAI(base_url=url, api_key="not_used")
+
+with client.audio.speech.with_streaming_response.create(
+    model="speecht5_tts",
+    voice="unused",
+    input=prompt
+) as response:
+    response.stream_to_file(speech_file_path)
+
+print("Generation finished")
diff --git a/demos/common/export_models/export_model.py b/demos/common/export_models/export_model.py
index 2f5af4e64d..b5eb85b12e 100644
--- a/demos/common/export_models/export_model.py
+++ b/demos/common/export_models/export_model.py
@@ -85,8 +85,55 @@ def add_common_arguments(parser):
 parser_image_generation.add_argument('--max_num_images_per_prompt', type=int, default=0, help='Max allowed number of images client is allowed to request for a given prompt', dest='max_num_images_per_prompt')
 parser_image_generation.add_argument('--default_num_inference_steps', type=int, default=0, help='Default number of inference steps when not specified by client', dest='default_num_inference_steps')
 parser_image_generation.add_argument('--max_num_inference_steps', type=int, default=0, help='Max allowed number of inference steps client is allowed to request for a given prompt', dest='max_num_inference_steps')
+
+parser_text2speech = subparsers.add_parser('text2speech', help='export model for text2speech endpoint')
+add_common_arguments(parser_text2speech)
+parser_text2speech.add_argument('--num_streams', default=0, type=int, help='The number of parallel execution streams to use for the models in the pipeline.', dest='num_streams')
+parser_text2speech.add_argument('--vocoder', type=str, help='The vocoder model to use for text2speech. For example microsoft/speecht5_hifigan', dest='vocoder')
+
+parser_speech2text = subparsers.add_parser('speech2text', help='export model for speech2text endpoint')
+add_common_arguments(parser_speech2text)
+parser_speech2text.add_argument('--num_streams', default=0, type=int, help='The number of parallel execution streams to use for the models in the pipeline.', dest='num_streams')
 args = vars(parser.parse_args())
 
+tts_graph_template = """
+input_stream: "HTTP_REQUEST_PAYLOAD:input"
+output_stream: "HTTP_RESPONSE_PAYLOAD:output"
+node {
+  name: "TtsExecutor"
+  input_side_packet: "TTS_NODE_RESOURCES:tts_servable"
+  calculator: "TtsCalculator"
+  input_stream: "HTTP_REQUEST_PAYLOAD:input"
+  output_stream: "HTTP_RESPONSE_PAYLOAD:output"
+  node_options: {
+    [type.googleapis.com/mediapipe.TtsCalculatorOptions]: {
+      models_path: "{{model_path}}",
+      plugin_config: '{ "NUM_STREAMS": "{{num_streams|default(1, true)}}" }',
+      device: "{{target_device|default("CPU", true)}}"
+    }
+  }
+}
+"""
+
+stt_graph_template = """
+input_stream: "HTTP_REQUEST_PAYLOAD:input"
+output_stream: "HTTP_RESPONSE_PAYLOAD:output"
+node {
+  name: "SttExecutor"
+  input_side_packet: "STT_NODE_RESOURCES:stt_servable"
+  calculator: "SttCalculator"
+  input_stream: "HTTP_REQUEST_PAYLOAD:input"
+  output_stream: "HTTP_RESPONSE_PAYLOAD:output"
+  node_options: {
+    [type.googleapis.com/mediapipe.SttCalculatorOptions]: {
+      models_path: "{{model_path}}",
+      plugin_config: '{ "NUM_STREAMS": "{{num_streams|default(1, true)}}" }',
+      device: "{{target_device|default("CPU", true)}}"
+    }
+  }
+}
+"""
+
 embedding_graph_ov_template = """
 input_stream: "REQUEST_PAYLOAD:input"
 output_stream: "RESPONSE_PAYLOAD:output"
@@ -457,7 +504,34 @@ def export_embeddings_model_ov(model_repository_path, source_model, model_name,
     with open(os.path.join(model_repository_path, model_name, 'graph.pbtxt'), 'w') as f:
         f.write(graph_content)
     print("Created graph {}".format(os.path.join(model_repository_path, model_name, 'graph.pbtxt')))
     add_servable_to_config(config_file_path, model_name, os.path.relpath(os.path.join(model_repository_path, model_name), os.path.dirname(config_file_path)))
+
+def export_text2speech_model(model_repository_path, source_model, model_name, precision, task_parameters, config_file_path):
+    destination_path = os.path.join(model_repository_path, model_name)
+    print("Exporting text2speech model to", destination_path)
+    if not os.path.isdir(destination_path) or args['overwrite_models']:
+        optimum_command = "optimum-cli export openvino --model {} --weight-format {} --trust-remote-code --model-kwargs \"{{\\\"vocoder\\\": \\\"{}\\\"}}\" {}".format(source_model, precision, task_parameters['vocoder'], destination_path)
+        if os.system(optimum_command):
+            raise ValueError("Failed to export text2speech model", source_model)
+    gtemplate = jinja2.Environment(loader=jinja2.BaseLoader).from_string(tts_graph_template)
+    graph_content = gtemplate.render(model_path="./", **task_parameters)
+    with open(os.path.join(model_repository_path, model_name, 'graph.pbtxt'), 'w') as f:
+        f.write(graph_content)
+    print("Created graph {}".format(os.path.join(model_repository_path, model_name, 'graph.pbtxt')))
+    add_servable_to_config(config_file_path, model_name, os.path.relpath(os.path.join(model_repository_path, model_name), os.path.dirname(config_file_path)))
+
+def export_speech2text_model(model_repository_path, source_model, model_name, precision, task_parameters, config_file_path):
+    destination_path = os.path.join(model_repository_path, model_name)
print("Exporting speech2text model to ",destination_path) + if not os.path.isdir(destination_path) or args['overwrite_models']: + optimum_command = "optimum-cli export openvino --model {} --weight-format {} --trust-remote-code {}".format(source_model, precision, destination_path) + if os.system(optimum_command): + raise ValueError("Failed to export speech2text model", source_model) + gtemplate = jinja2.Environment(loader=jinja2.BaseLoader).from_string(stt_graph_template) + graph_content = gtemplate.render(model_path="./", **task_parameters) + with open(os.path.join(model_repository_path, model_name, 'graph.pbtxt'), 'w') as f: + f.write(graph_content) + print("Created graph {}".format(os.path.join(model_repository_path, model_name, 'graph.pbtxt'))) + add_servable_to_config(config_file_path, model_name, os.path.relpath( os.path.join(model_repository_path, model_name), os.path.dirname(config_file_path))) def export_rerank_model_ov(model_repository_path, source_model, model_name, precision, task_parameters, config_file_path, max_doc_length): destination_path = os.path.join(model_repository_path, model_name) @@ -585,7 +659,7 @@ def export_image_generation_model(model_repository_path, source_model, model_nam export_text_generation_model(args['model_repository_path'], args['source_model'], args['model_name'], args['precision'], template_parameters, args['config_file_path']) elif args['task'] == 'embeddings_ov': - export_embeddings_model_ov(args['model_repository_path'], args['source_model'], args['model_name'], args['precision'], template_parameters, args['config_file_path'], args['truncate']) + export_embeddings_model_ov(args['model_repository_path'], args['source_model'], args['model_name'], args['precision'], template_parameters) elif args['task'] == 'rerank': export_rerank_model(args['model_repository_path'], args['source_model'], args['model_name'] ,args['precision'], template_parameters, str(args['version']), args['config_file_path'], args['max_doc_length']) @@ -593,6 +667,11 @@ def export_image_generation_model(model_repository_path, source_model, model_nam elif args['task'] == 'rerank_ov': export_rerank_model_ov(args['model_repository_path'], args['source_model'], args['model_name'] ,args['precision'], template_parameters, args['config_file_path'], args['max_doc_length']) +elif args['task'] == 'text2speech': + export_text2speech_model(args['model_repository_path'], args['source_model'], args['model_name'], args['precision'], template_parameters, args['config_file_path']) + +elif args['task'] == 'speech2text': + export_speech2text_model(args['model_repository_path'], args['source_model'], args['model_name'] ,args['precision'], template_parameters, args['config_file_path']) elif args['task'] == 'image_generation': template_parameters = {k: v for k, v in args.items() if k in [ 'ov_cache_dir', diff --git a/src/BUILD b/src/BUILD index f5f04361f9..57539ee51a 100644 --- a/src/BUILD +++ b/src/BUILD @@ -558,6 +558,9 @@ ovms_cc_library( "//conditions:default": [], "//:not_disable_mediapipe" : [ "//src/image_gen:image_gen_calculator", + "//src/audio/speech_to_text:stt_calculator", + "//src/audio/text_to_speech:tts_calculator", + "//src/audio:audio_utils", "//src/image_gen:imagegen_init", "//src/llm:openai_completions_api_handler", "//src/embeddings:embeddingscalculator_ov", diff --git a/src/audio/BUILD b/src/audio/BUILD new file mode 100644 index 0000000000..1956d79a58 --- /dev/null +++ b/src/audio/BUILD @@ -0,0 +1,31 @@ +# +# Copyright (c) 2025 Intel Corporation +# +# Licensed under the Apache 
License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +load("@mediapipe//mediapipe/framework/port:build_config.bzl", "mediapipe_cc_proto_library", "mediapipe_proto_library") +load("//:common_settings.bzl", "ovms_cc_library") + +ovms_cc_library( + name = "audio_utils", + hdrs = ["audio_utils.hpp"], + srcs = ["audio_utils.cpp"], + visibility = ["//visibility:public"], + deps = [ + "//src:libovmslogging", + "//src/port:dr_audio", + "//src:libovmstimer", + ], + alwayslink = 1, +) diff --git a/src/audio/audio_utils.cpp b/src/audio/audio_utils.cpp new file mode 100644 index 0000000000..b3bc6c317d --- /dev/null +++ b/src/audio/audio_utils.cpp @@ -0,0 +1,187 @@ +//***************************************************************************** +// Copyright 2025 Intel Corporation +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +//***************************************************************************** + +#define DR_WAV_IMPLEMENTATION +#define DR_MP3_IMPLEMENTATION +#include "src/port/dr_audio.hpp" +#include "audio_utils.hpp" +#include "src/timer.hpp" +#include "src/logging.hpp" +#include +#include +#define PIPELINE_SUPPORTED_SAMPLE_RATE 16000 + +using namespace ovms; + +bool isWavBuffer(const std::string buf) { + // RIFF ref: https://en.wikipedia.org/wiki/Resource_Interchange_File_Format + // WAV ref: https://www.mmsp.ece.mcgill.ca/Documents/AudioFormats/WAVE/WAVE.html + SPDLOG_TRACE("isWavBuffer: buf {}", buf.substr(0, 12)); + if (buf.size() < 12 || buf.substr(0, 4) != "RIFF" || buf.substr(8, 4) != "WAVE") { + return false; + } + + uint32_t chunk_size = *reinterpret_cast(buf.data() + 4); + SPDLOG_TRACE("isWavBuffer: chunk_size {}", chunk_size); + if (chunk_size + 8 != buf.size()) { + return false; + } + + return true; +} +// https://github.com/openvinotoolkit/openvino.genai/blob/8698683535fe32b5e3cb6953000c4e0175841bd3/samples/c/whisper_speech_recognition/whisper_utils.c#L105 +void resample_audio(const float* input, + size_t inputLength, + float inputRate, + float targetRate, + std::vector& output) { + SPDLOG_LOGGER_DEBUG(stt_calculator_logger, "Input file sample rate: {}. 
Resampling to {} required", inputRate, targetRate); + float ratio = inputRate / targetRate; + + for (size_t i = 0; i < output.size(); i++) { + float src_idx = i * ratio; + size_t idx0 = (size_t)src_idx; + size_t idx1 = idx0 + 1; + + if (idx1 >= inputLength) { + output.data()[i] = input[inputLength - 1]; + } else { + float frac = src_idx - idx0; + output.data()[i] = input[idx0] * (1.0f - frac) + input[idx1] * frac; + } + } +} + +enum : unsigned int { + TENSOR_PREPARATION, + RESAMPLING, + TIMER_END +}; + +std::vector readWav(const std::string_view& wavData) { + Timer timer; + timer.start(TENSOR_PREPARATION); + drwav wav; + auto result = drwav_init_memory(&wav, wavData.data(), wavData.size(), nullptr); + if (result == false) { + throw std::runtime_error("WAV file parsing failed"); + } + if (wav.channels != 1 && wav.channels != 2) { + drwav_uninit(&wav); + throw std::runtime_error("WAV file must be mono or stereo"); + } + + const uint64_t n = + wavData.empty() ? wav.totalPCMFrameCount : wavData.size() / (wav.channels * wav.bitsPerSample / 8ul); + + std::vector pcm16; + pcm16.resize(n * wav.channels); + drwav_read_pcm_frames_s16(&wav, n, pcm16.data()); + drwav_uninit(&wav); + + // convert to mono, float + std::vector pcmf32; + pcmf32.resize(n); + if (wav.channels == 1) { + for (uint64_t i = 0; i < n; i++) { + pcmf32[i] = float(pcm16[i]) / 32768.0f; + } + } else { + for (uint64_t i = 0; i < n; i++) { + pcmf32[i] = float(pcm16[2 * i] + pcm16[2 * i + 1]) / 65536.0f; + } + } + timer.stop(TENSOR_PREPARATION); + auto tensorPreparationTime = (timer.elapsed(TENSOR_PREPARATION)) / 1000; + SPDLOG_LOGGER_DEBUG(stt_calculator_logger, "Tensor preparation time: {} ms size: {}", tensorPreparationTime, pcmf32.size()); + if (wav.sampleRate == PIPELINE_SUPPORTED_SAMPLE_RATE) { + return pcmf32; + } + + timer.start(RESAMPLING); + size_t outputLength = (size_t)(pcmf32.size() * PIPELINE_SUPPORTED_SAMPLE_RATE / wav.sampleRate); + std::vector output(outputLength); + resample_audio(reinterpret_cast(pcmf32.data()), pcmf32.size(), wav.sampleRate, PIPELINE_SUPPORTED_SAMPLE_RATE, output); + timer.stop(RESAMPLING); + auto resamplingTime = (timer.elapsed(RESAMPLING)) / 1000; + SPDLOG_LOGGER_DEBUG(stt_calculator_logger, "Resampling time: {} ms", resamplingTime); + return output; +} +#pragma warning(push) +#pragma warning(disable : 6262) +std::vector readMp3(const std::string_view& mp3Data) { + Timer timer; + timer.start(TENSOR_PREPARATION); + drmp3 mp3; + auto result = drmp3_init_memory(&mp3, mp3Data.data(), mp3Data.size(), nullptr); + if (result == 0) { + throw std::runtime_error("MP3 file parsing failed"); + } + + if (mp3.channels != 1 && mp3.channels != 2) { + drmp3_uninit(&mp3); + throw std::runtime_error("MP3 file must be mono or stereo"); + } + const uint64_t n = mp3.totalPCMFrameCount; + std::vector pcmf32; + pcmf32.resize(n * mp3.channels); + drmp3_read_pcm_frames_f32(&mp3, n, pcmf32.data()); + drmp3_uninit(&mp3); + timer.stop(TENSOR_PREPARATION); + auto tensorPreparationTime = (timer.elapsed(TENSOR_PREPARATION)) / 1000; + SPDLOG_LOGGER_DEBUG(stt_calculator_logger, "Tensor preparation time: {} ms size: {}", tensorPreparationTime, pcmf32.size()); + if (mp3.sampleRate == PIPELINE_SUPPORTED_SAMPLE_RATE) { + return pcmf32; + } + timer.start(RESAMPLING); + size_t outputLength = (size_t)(pcmf32.size() * PIPELINE_SUPPORTED_SAMPLE_RATE / mp3.sampleRate); + std::vector output(outputLength); + resample_audio(reinterpret_cast(pcmf32.data()), pcmf32.size(), mp3.sampleRate, PIPELINE_SUPPORTED_SAMPLE_RATE, output); + 
diff --git a/src/audio/audio_utils.hpp b/src/audio/audio_utils.hpp
new file mode 100644
index 0000000000..cbeea8b457
--- /dev/null
+++ b/src/audio/audio_utils.hpp
@@ -0,0 +1,27 @@
+//*****************************************************************************
+// Copyright 2025 Intel Corporation
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+//*****************************************************************************
+
+#pragma once
+
+#include <cstdint>
+#include <string>
+#include <string_view>
+#include <vector>
+
+bool isWavBuffer(const std::string buf);
+
+std::vector<float> readWav(const std::string_view& wavData);
+std::vector<float> readMp3(const std::string_view& mp3Data);
+void prepareAudioOutput(void** ppData, size_t& pDataSize, uint16_t bitsPerSample, size_t speechSize, const float* waveformPtr);
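The WAV buffers produced by prepareAudioOutput() (and accepted by isWavBuffer()) follow the standard RIFF layout. A small stdlib-only Python sketch for inspecting one — the helper is hypothetical and assumes a canonical "fmt " sub-chunk directly after the RIFF header, which is what dr_wav's sequential writer emits:

```python
import struct

def inspect_wav(buf: bytes) -> dict:
    # Same check as isWavBuffer(): "RIFF" magic, "WAVE" at offset 8,
    # and the chunk size at offset 4 must equal len(buf) - 8.
    if len(buf) < 12 or buf[0:4] != b"RIFF" or buf[8:12] != b"WAVE":
        raise ValueError("not a WAV buffer")
    (chunk_size,) = struct.unpack_from("<I", buf, 4)
    if chunk_size + 8 != len(buf):
        raise ValueError("inconsistent RIFF chunk size")
    # "fmt " sub-chunk fields; format tag 3 = IEEE float, as written by prepareAudioOutput().
    fmt, channels, sample_rate, _, _, bits = struct.unpack_from("<HHIIHH", buf, 20)
    return {"format": fmt, "channels": channels, "rate": sample_rate, "bits": bits}

with open("audio.wav", "rb") as f:
    print(inspect_wav(f.read()))
```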
+#
+
+load("@mediapipe//mediapipe/framework/port:build_config.bzl", "mediapipe_cc_proto_library", "mediapipe_proto_library")
+load("//:common_settings.bzl", "ovms_cc_library")
+
+ovms_cc_library(
+    name = "stt_servable",
+    hdrs = ["stt_servable.hpp"],
+    visibility = ["//visibility:public"],
+    alwayslink = 1,
+)
+
+ovms_cc_library(
+    name = "stt_calculator",
+    srcs = ["stt_calculator.cc"],
+    deps = [
+        "@mediapipe//mediapipe/framework:calculator_framework",
+        "//src:httppayload",
+        "//src:libovmslogging",
+        "stt_calculator_cc_proto",
+        "//src/port:dr_audio",
+        ":stt_servable",
+        "//third_party:genai",
+        "//src/audio:audio_utils",
+    ],
+    visibility = ["//visibility:public"],
+    alwayslink = 1,
+)
+
+mediapipe_proto_library(
+    name = "stt_calculator_proto",
+    srcs = ["stt_calculator.proto"],
+    visibility = ["//visibility:private"],
+    deps = [
+        "@mediapipe//mediapipe/framework:calculator_options_proto",
+        "@mediapipe//mediapipe/framework:calculator_proto",
+    ],
+)
diff --git a/src/audio/speech_to_text/stt_calculator.cc b/src/audio/speech_to_text/stt_calculator.cc
new file mode 100644
index 0000000000..5a766d18b1
--- /dev/null
+++ b/src/audio/speech_to_text/stt_calculator.cc
@@ -0,0 +1,133 @@
+//*****************************************************************************
+// Copyright 2025 Intel Corporation
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+//*****************************************************************************
+#include <string>
+
+#pragma warning(push)
+#pragma warning(disable : 4005 4309 6001 6385 6386 6326 6011 6246 4456 6246)
+#pragma GCC diagnostic push
+#pragma GCC diagnostic ignored "-Wdeprecated-declarations"
+#include "mediapipe/framework/calculator_framework.h"
+#include "mediapipe/framework/port/canonical_errors.h"
+#pragma GCC diagnostic pop
+#pragma warning(pop)
+
+#include "src/audio/audio_utils.hpp"
+#include "src/http_payload.hpp"
+#include "src/logging.hpp"
+#include <memory>
+#include <vector>
+
+#pragma warning(push)
+#pragma warning(disable : 6001 4324 6385 6386)
+#include "absl/strings/escaping.h"
+#include "absl/strings/str_cat.h"
+#pragma warning(pop)
+
+#include "stt_servable.hpp"
+
+#ifdef _WIN32
+#include <io.h>
+#include <fcntl.h>
+#endif
+
+using namespace ovms;
+
+namespace mediapipe {
+
+const std::string STT_SESSION_SIDE_PACKET_TAG = "STT_NODE_RESOURCES";
+
+class SttCalculator : public CalculatorBase {
+    static const std::string INPUT_TAG_NAME;
+    static const std::string OUTPUT_TAG_NAME;
+
+public:
+    static absl::Status GetContract(CalculatorContract* cc) {
+        RET_CHECK(!cc->Inputs().GetTags().empty());
+        RET_CHECK(!cc->Outputs().GetTags().empty());
+        cc->Inputs().Tag(INPUT_TAG_NAME).Set<ovms::HttpPayload>();
+        cc->InputSidePackets().Tag(STT_SESSION_SIDE_PACKET_TAG).Set<SttServableMap>();  // TODO: template?
+        cc->Outputs().Tag(OUTPUT_TAG_NAME).Set<std::string>();
+        return absl::OkStatus();
+    }
+
+    absl::Status Close(CalculatorContext* cc) final {
+        SPDLOG_LOGGER_DEBUG(stt_calculator_logger, "SpeechToTextCalculator [Node: {}] Close", cc->NodeName());
+        return absl::OkStatus();
+    }
+
+    absl::Status Open(CalculatorContext* cc) final {
+        SPDLOG_LOGGER_DEBUG(stt_calculator_logger, "SpeechToTextCalculator [Node: {}] Open start", cc->NodeName());
+        return absl::OkStatus();
+    }
+
+    absl::Status Process(CalculatorContext* cc) final {
+        SPDLOG_LOGGER_DEBUG(stt_calculator_logger, "SpeechToTextCalculator [Node: {}] Process start", cc->NodeName());
+
+        SttServableMap pipelinesMap = cc->InputSidePackets().Tag(STT_SESSION_SIDE_PACKET_TAG).Get<SttServableMap>();
+        auto it = pipelinesMap.find(cc->NodeName());
+        RET_CHECK(it != pipelinesMap.end()) << "Could not find initialized STT node named: " << cc->NodeName();
+        auto pipe = it->second;
+
+        auto payload = cc->Inputs().Tag(INPUT_TAG_NAME).Get<ovms::HttpPayload>();
+
+        std::unique_ptr<std::string> output;
+        if (absl::StartsWith(payload.uri, "/v3/audio/transcriptions")) {
+            if (payload.multipartParser->hasParseError())
+                return absl::InvalidArgumentError("Failed to parse multipart data");
+
+            std::string_view stream = payload.multipartParser->getFileContentByFieldName("stream");
+            if (!stream.empty()) {
+                return absl::InvalidArgumentError("streaming is not supported");
+            }
+            std::string_view file = payload.multipartParser->getFileContentByFieldName("file");
+            if (file.empty()) {
+                return absl::InvalidArgumentError(absl::StrCat("File parsing failed"));
+            }
+
+            std::vector<float> rawSpeech;
+            try {
+                if (isWavBuffer(std::string(file))) {
+                    SPDLOG_DEBUG("Received file format: wav");
+                    rawSpeech = readWav(file);
+                } else {
+                    SPDLOG_DEBUG("Received file format: mp3");
+                    rawSpeech = readMp3(file);
+                }
+            } catch (std::exception&) {
+                return absl::InvalidArgumentError("Received input file is not a valid WAV or MP3 audio file");
+            }
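+            // Note: the transcript is spliced into the JSON envelope below without any
+            // escaping, so a transcription containing '"' or '\' characters would yield
+            // invalid JSON; Whisper output is normally plain text, but this is a known gap.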
+            std::string result = "{\"text\": \"";
+            std::unique_lock lock(pipe->sttPipelineMutex);
+            result += pipe->sttPipeline->generate(rawSpeech);
+            result.append("\"}");
+            output = std::make_unique<std::string>(result);
+        } else {
+            return absl::InvalidArgumentError(absl::StrCat("Unsupported URI: ", payload.uri));
+        }
+
+        cc->Outputs().Tag(OUTPUT_TAG_NAME).Add(output.release(), cc->InputTimestamp());
+        SPDLOG_LOGGER_DEBUG(stt_calculator_logger, "SpeechToTextCalculator [Node: {}] Process end", cc->NodeName());
+
+        return absl::OkStatus();
+    }
+};
+
+const std::string SttCalculator::INPUT_TAG_NAME{"HTTP_REQUEST_PAYLOAD"};
+const std::string SttCalculator::OUTPUT_TAG_NAME{"HTTP_RESPONSE_PAYLOAD"};
+
+REGISTER_CALCULATOR(SttCalculator);
+
+}  // namespace mediapipe
diff --git a/src/audio/speech_to_text/stt_calculator.proto b/src/audio/speech_to_text/stt_calculator.proto
new file mode 100644
index 0000000000..90b48d5736
--- /dev/null
+++ b/src/audio/speech_to_text/stt_calculator.proto
@@ -0,0 +1,34 @@
+//*****************************************************************************
+// Copyright 2025 Intel Corporation
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+//*****************************************************************************
+
+syntax = "proto2";
+package mediapipe;
+
+import "mediapipe/framework/calculator.proto";
+
+message SttCalculatorOptions {
+    extend mediapipe.CalculatorOptions {
+        // https://github.com/google/mediapipe/issues/634 have to be unique in app
+        // no rule to obtain this
+        optional SttCalculatorOptions ext = 116423757;
+    }
+
+    // fields required for GenAI pipeline initialization
+    required string models_path = 1;
+    optional string device = 2;
+    optional string plugin_config = 3;
+}
diff --git a/src/audio/speech_to_text/stt_servable.hpp b/src/audio/speech_to_text/stt_servable.hpp
new file mode 100644
index 0000000000..1bcb4b9d67
--- /dev/null
+++ b/src/audio/speech_to_text/stt_servable.hpp
@@ -0,0 +1,54 @@
+//*****************************************************************************
+// Copyright 2025 Intel Corporation
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+//*****************************************************************************
+#pragma once
+
+#include <filesystem>
+#include <memory>
+#include <mutex>
+#include <string>
+#include <unordered_map>
+
+#pragma warning(push)
+#pragma warning(disable : 4005 4309 6001 6385 6386 6326 6011 4005 4456 6246)
+#pragma GCC diagnostic push
+#pragma GCC diagnostic ignored "-Wdeprecated-declarations"
+#include "mediapipe/framework/calculator_graph.h"
+#pragma GCC diagnostic pop
+#pragma warning(pop)
+
+#include "openvino/genai/whisper_pipeline.hpp"
+#include "src/audio/speech_to_text/stt_calculator.pb.h"
+
+namespace ovms {
+
+struct SttServable {
+    std::filesystem::path parsedModelsPath;
+    std::shared_ptr<ov::genai::WhisperPipeline> sttPipeline;
+    std::mutex sttPipelineMutex;
+
+    SttServable(const std::string& modelDir, const std::string& targetDevice, const std::string& graphPath) {
+        auto fsModelsPath = std::filesystem::path(modelDir);
+        if (fsModelsPath.is_relative()) {
+            parsedModelsPath = (std::filesystem::path(graphPath) / fsModelsPath);
+        } else {
+            parsedModelsPath = fsModelsPath.string();
+        }
+        sttPipeline = std::make_shared<ov::genai::WhisperPipeline>(parsedModelsPath.string(), targetDevice);
+    }
+};
+
+using SttServableMap = std::unordered_map<std::string, std::shared_ptr<SttServable>>;
+}  // namespace ovms
diff --git a/src/audio/text_to_speech/BUILD b/src/audio/text_to_speech/BUILD
new file mode 100644
index 0000000000..36a8239e68
--- /dev/null
+++ b/src/audio/text_to_speech/BUILD
@@ -0,0 +1,52 @@
+#
+# Copyright (c) 2025 Intel Corporation
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+load("@mediapipe//mediapipe/framework/port:build_config.bzl", "mediapipe_cc_proto_library", "mediapipe_proto_library")
+load("//:common_settings.bzl", "ovms_cc_library")
+
+ovms_cc_library(
+    name = "tts_servable",
+    hdrs = ["tts_servable.hpp"],
+    visibility = ["//visibility:public"],
+    alwayslink = 1,
+)
+
+ovms_cc_library(
+    name = "tts_calculator",
+    srcs = ["tts_calculator.cc"],
+    deps = [
+        "@mediapipe//mediapipe/framework:calculator_framework",
+        "//src:httppayload",
+        "//src:libovmslogging",
+        "tts_calculator_cc_proto",
+        "//src/port:dr_audio",
+        ":tts_servable",
+        "//third_party:genai",
+        "//src/audio:audio_utils",
+    ],
+    visibility = ["//visibility:public"],
+    alwayslink = 1,
+)
+
+mediapipe_proto_library(
+    name = "tts_calculator_proto",
+    srcs = ["tts_calculator.proto"],
+    visibility = ["//visibility:private"],
+    deps = [
+        "@mediapipe//mediapipe/framework:calculator_options_proto",
+        "@mediapipe//mediapipe/framework:calculator_proto",
+    ],
+)
diff --git a/src/audio/text_to_speech/tts_calculator.cc b/src/audio/text_to_speech/tts_calculator.cc
new file mode 100644
index 0000000000..b17dcfcce1
--- /dev/null
+++ b/src/audio/text_to_speech/tts_calculator.cc
@@ -0,0 +1,136 @@
+//*****************************************************************************
+// Copyright 2025 Intel Corporation
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+//*****************************************************************************
+#include <string>
+
+#pragma warning(push)
+#pragma warning(disable : 4005 4309 6001 6385 6386 6326 6011 6246 4456 6246)
+#pragma GCC diagnostic push
+#pragma GCC diagnostic ignored "-Wdeprecated-declarations"
+#include "mediapipe/framework/calculator_framework.h"
+#include "mediapipe/framework/port/canonical_errors.h"
+#pragma GCC diagnostic pop
+#pragma warning(pop)
+
+#include "src/audio/audio_utils.hpp"
+#include "src/http_payload.hpp"
+#include "src/logging.hpp"
+#include <memory>
+#include <vector>
+
+#pragma warning(push)
+#pragma warning(disable : 6001 4324 6385 6386)
+#include "absl/strings/escaping.h"
+#include "absl/strings/str_cat.h"
+#pragma warning(pop)
+
+#include "src/port/dr_audio.hpp"
+
+#include "tts_servable.hpp"
+
+#ifdef _WIN32
+#include <io.h>
+#include <fcntl.h>
+#endif
+
+using namespace ovms;
+
+namespace mediapipe {
+
+const std::string TTS_SESSION_SIDE_PACKET_TAG = "TTS_NODE_RESOURCES";
+
+class TtsCalculator : public CalculatorBase {
+    static const std::string INPUT_TAG_NAME;
+    static const std::string OUTPUT_TAG_NAME;
+
+public:
+    static absl::Status GetContract(CalculatorContract* cc) {
+        RET_CHECK(!cc->Inputs().GetTags().empty());
+        RET_CHECK(!cc->Outputs().GetTags().empty());
+        cc->Inputs().Tag(INPUT_TAG_NAME).Set<ovms::HttpPayload>();
+        cc->InputSidePackets().Tag(TTS_SESSION_SIDE_PACKET_TAG).Set<TtsServableMap>();  // TODO: template?
+        cc->Outputs().Tag(OUTPUT_TAG_NAME).Set<std::string>();
+        return absl::OkStatus();
+    }
+
+    absl::Status Close(CalculatorContext* cc) final {
+        SPDLOG_LOGGER_DEBUG(tts_calculator_logger, "TtsCalculator [Node: {}] Close", cc->NodeName());
+        return absl::OkStatus();
+    }
+
+    absl::Status Open(CalculatorContext* cc) final {
+        SPDLOG_LOGGER_DEBUG(tts_calculator_logger, "TtsCalculator [Node: {}] Open start", cc->NodeName());
+        return absl::OkStatus();
+    }
+
+    absl::Status Process(CalculatorContext* cc) final {
+        SPDLOG_LOGGER_DEBUG(tts_calculator_logger, "TtsCalculator [Node: {}] Process start", cc->NodeName());
+
+        TtsServableMap pipelinesMap = cc->InputSidePackets().Tag(TTS_SESSION_SIDE_PACKET_TAG).Get<TtsServableMap>();
+        auto it = pipelinesMap.find(cc->NodeName());
+        RET_CHECK(it != pipelinesMap.end()) << "Could not find initialized TTS node named: " << cc->NodeName();
+        auto pipe = it->second;
+
+        auto payload = cc->Inputs().Tag(INPUT_TAG_NAME).Get<ovms::HttpPayload>();
+
+        std::unique_ptr<std::string> output;
+        if (absl::StartsWith(payload.uri, "/v3/audio/speech")) {
+            if (payload.parsedJson->HasParseError())
+                return absl::InvalidArgumentError("Failed to parse JSON");
+
+            if (!payload.parsedJson->IsObject()) {
+                return absl::InvalidArgumentError("JSON body must be an object");
+            }
+            auto inputIt = payload.parsedJson->FindMember("input");
+            if (inputIt == payload.parsedJson->MemberEnd()) {
+                return absl::InvalidArgumentError("input field is missing in JSON body");
+            }
+            if (!inputIt->value.IsString()) {
+                return absl::InvalidArgumentError("input field is not a string");
+            }
+            auto streamIt = payload.parsedJson->FindMember("stream_format");
+            if (streamIt != payload.parsedJson->MemberEnd()) {
+                return absl::InvalidArgumentError("streaming is not supported");
+            }
+            std::unique_lock lock(pipe->ttsPipelineMutex);
+            auto generatedSpeech = pipe->ttsPipeline->generate(inputIt->value.GetString());
+            auto bitsPerSample = generatedSpeech.speeches[0].get_element_type().bitwidth();
+            auto speechSize = generatedSpeech.speeches[0].get_size();
+            ov::Tensor cpuTensor(generatedSpeech.speeches[0].get_element_type(), generatedSpeech.speeches[0].get_shape());
+            // copy results to release inference request
+            generatedSpeech.speeches[0].copy_to(cpuTensor);
+            lock.unlock();
+            void* ppData;
+            size_t pDataSize;
+            prepareAudioOutput(&ppData, pDataSize, bitsPerSample, speechSize, cpuTensor.data<float>());
+            output = std::make_unique<std::string>(reinterpret_cast<char*>(ppData), pDataSize);
+            drwav_free(ppData, NULL);  // the WAV bytes were copied into the response string above
+        } else {
+            return absl::InvalidArgumentError(absl::StrCat("Unsupported URI: ", payload.uri));
+        }
+
+        cc->Outputs().Tag(OUTPUT_TAG_NAME).Add(output.release(), cc->InputTimestamp());
+        SPDLOG_LOGGER_DEBUG(tts_calculator_logger, "TtsCalculator [Node: {}] Process end", cc->NodeName());
+
+        return absl::OkStatus();
+    }
+};
+
+const std::string TtsCalculator::INPUT_TAG_NAME{"HTTP_REQUEST_PAYLOAD"};
+const std::string TtsCalculator::OUTPUT_TAG_NAME{"HTTP_RESPONSE_PAYLOAD"};
+
+REGISTER_CALCULATOR(TtsCalculator);
+
+}  // namespace mediapipe
diff --git a/src/audio/text_to_speech/tts_calculator.proto b/src/audio/text_to_speech/tts_calculator.proto
new file mode 100644
index 0000000000..a75b681cce
--- /dev/null
+++ b/src/audio/text_to_speech/tts_calculator.proto
@@ -0,0 +1,34 @@
+//*****************************************************************************
+// Copyright 2025 Intel Corporation
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+//*****************************************************************************
+
+syntax = "proto2";
+package mediapipe;
+
+import "mediapipe/framework/calculator.proto";
+
+message TtsCalculatorOptions {
+    extend mediapipe.CalculatorOptions {
+        // https://github.com/google/mediapipe/issues/634 have to be unique in app
+        // no rule to obtain this
+        optional TtsCalculatorOptions ext = 116423755;
+    }
+
+    // fields required for GenAI pipeline initialization
+    required string models_path = 1;
+    optional string device = 2;
+    optional string plugin_config = 3;
+}
+//***************************************************************************** +#pragma once + +#include +#include +#include +#include + +#pragma warning(push) +#pragma warning(disable : 4005 4309 6001 6385 6386 6326 6011 4005 4456 6246) +#pragma GCC diagnostic push +#pragma GCC diagnostic ignored "-Wdeprecated-declarations" +#include "mediapipe/framework/calculator_graph.h" +#pragma GCC diagnostic pop +#pragma warning(pop) + +#include "openvino/genai/whisper_pipeline.hpp" +#include "openvino/genai/speech_generation/text2speech_pipeline.hpp" +#include "src/audio/text_to_speech/tts_calculator.pb.h" + +namespace ovms { + +struct TtsServable { + std::filesystem::path parsedModelsPath; + std::shared_ptr ttsPipeline; + std::mutex ttsPipelineMutex; + + TtsServable(const std::string& modelDir, const std::string& targetDevice, const std::string& graphPath) { + auto fsModelsPath = std::filesystem::path(modelDir); + if (fsModelsPath.is_relative()) { + parsedModelsPath = (std::filesystem::path(graphPath) / fsModelsPath); + } else { + parsedModelsPath = fsModelsPath.string(); + } + ttsPipeline = std::make_shared(parsedModelsPath.string(), targetDevice); + } +}; + +using TtsServableMap = std::unordered_map>; +} // namespace ovms diff --git a/src/logging.cpp b/src/logging.cpp index dbf24ab35a..d34c565b84 100644 --- a/src/logging.cpp +++ b/src/logging.cpp @@ -33,6 +33,8 @@ std::shared_ptr capi_logger = std::make_shared(" std::shared_ptr mediapipe_logger = std::make_shared("mediapipe"); std::shared_ptr llm_executor_logger = std::make_shared("llm_executor"); std::shared_ptr llm_calculator_logger = std::make_shared("llm_calculator"); +std::shared_ptr stt_calculator_logger = std::make_shared("stt_calculator"); +std::shared_ptr tts_calculator_logger = std::make_shared("tts_calculator"); std::shared_ptr embeddings_calculator_logger = std::make_shared("embeddings_calculator"); std::shared_ptr rerank_calculator_logger = std::make_shared("rerank_calculator"); #endif @@ -74,6 +76,8 @@ static void register_loggers(const std::string& log_level, std::vectorset_pattern(default_pattern); llm_executor_logger->set_pattern(default_pattern); llm_calculator_logger->set_pattern(default_pattern); + stt_calculator_logger->set_pattern(default_pattern); + tts_calculator_logger->set_pattern(default_pattern); rerank_calculator_logger->set_pattern(default_pattern); embeddings_calculator_logger->set_pattern(default_pattern); #endif @@ -92,6 +96,8 @@ static void register_loggers(const std::string& log_level, std::vectorsinks().push_back(sink); llm_executor_logger->sinks().push_back(sink); llm_calculator_logger->sinks().push_back(sink); + stt_calculator_logger->sinks().push_back(sink); + tts_calculator_logger->sinks().push_back(sink); rerank_calculator_logger->sinks().push_back(sink); embeddings_calculator_logger->sinks().push_back(sink); #endif @@ -111,6 +117,8 @@ static void register_loggers(const std::string& log_level, std::vector capi_logger; extern std::shared_ptr mediapipe_logger; extern std::shared_ptr llm_executor_logger; extern std::shared_ptr llm_calculator_logger; +extern std::shared_ptr stt_calculator_logger; +extern std::shared_ptr tts_calculator_logger; extern std::shared_ptr embeddings_calculator_logger; extern std::shared_ptr rerank_calculator_logger; #endif diff --git a/src/mediapipe_internal/mediapipegraphdefinition.cpp b/src/mediapipe_internal/mediapipegraphdefinition.cpp index 441752784e..c40a8d1087 100644 --- a/src/mediapipe_internal/mediapipegraphdefinition.cpp +++ 
@@ -61,6 +61,8 @@ const std::string MediapipeGraphDefinition::SCHEDULER_CLASS_NAME{"Mediapipe"};
 const std::string MediapipeGraphDefinition::PYTHON_NODE_CALCULATOR_NAME{"PythonExecutorCalculator"};
 const std::string MediapipeGraphDefinition::LLM_NODE_CALCULATOR_NAME{"LLMCalculator"};
 const std::string MediapipeGraphDefinition::IMAGE_GEN_CALCULATOR_NAME{"ImageGenCalculator"};
+const std::string MediapipeGraphDefinition::STT_NODE_CALCULATOR_NAME{"SttCalculator"};
+const std::string MediapipeGraphDefinition::TTS_NODE_CALCULATOR_NAME{"TtsCalculator"};
 const std::string MediapipeGraphDefinition::EMBEDDINGS_NODE_CALCULATOR_NAME{"EmbeddingsCalculatorOV"};
 const std::string MediapipeGraphDefinition::RERANK_NODE_CALCULATOR_NAME{"RerankCalculatorOV"};
 
@@ -566,6 +568,50 @@ Status MediapipeGraphDefinition::initializeNodes() {
             rerankServableMap.insert(std::pair<std::string, std::shared_ptr<RerankServable>>(nodeName, std::move(servable)));
             rerankServablesCleaningGuard.disableCleaning();
         }
+        if (endsWith(config.node(i).calculator(), STT_NODE_CALCULATOR_NAME)) {
+            auto& sttServableMap = this->sidePacketMaps.sttServableMap;
+            ResourcesCleaningGuard sttServablesCleaningGuard(sttServableMap);
+            if (!config.node(i).node_options().size()) {
+                SPDLOG_LOGGER_ERROR(modelmanager_logger, "SpeechToText node missing options in graph: {}.", this->name);
+                return StatusCode::LLM_NODE_MISSING_OPTIONS;
+            }
+            if (config.node(i).name().empty()) {
+                SPDLOG_LOGGER_ERROR(modelmanager_logger, "SpeechToText node name is missing in graph: {}.", this->name);
+                return StatusCode::LLM_NODE_MISSING_NAME;
+            }
+            std::string nodeName = config.node(i).name();
+            if (sttServableMap.find(nodeName) != sttServableMap.end()) {
+                SPDLOG_LOGGER_ERROR(modelmanager_logger, "SpeechToText node name: {} already used in graph: {}.", nodeName, this->name);
+                return StatusCode::LLM_NODE_NAME_ALREADY_EXISTS;
+            }
+            mediapipe::SttCalculatorOptions nodeOptions;
+            config.node(i).node_options(0).UnpackTo(&nodeOptions);
+            std::shared_ptr<SttServable> servable = std::make_shared<SttServable>(nodeOptions.models_path(), nodeOptions.device(), mgconfig.getBasePath());
+            sttServableMap.insert(std::pair<std::string, std::shared_ptr<SttServable>>(nodeName, std::move(servable)));
+            sttServablesCleaningGuard.disableCleaning();
+        }
+        if (endsWith(config.node(i).calculator(), TTS_NODE_CALCULATOR_NAME)) {
+            auto& ttsServableMap = this->sidePacketMaps.ttsServableMap;
+            ResourcesCleaningGuard ttsServablesCleaningGuard(ttsServableMap);
+            if (!config.node(i).node_options().size()) {
+                SPDLOG_LOGGER_ERROR(modelmanager_logger, "TextToSpeech node missing options in graph: {}.", this->name);
+                return StatusCode::LLM_NODE_MISSING_OPTIONS;
+            }
+            if (config.node(i).name().empty()) {
+                SPDLOG_LOGGER_ERROR(modelmanager_logger, "TextToSpeech node name is missing in graph: {}.", this->name);
+                return StatusCode::LLM_NODE_MISSING_NAME;
+            }
+            std::string nodeName = config.node(i).name();
+            if (ttsServableMap.find(nodeName) != ttsServableMap.end()) {
+                SPDLOG_LOGGER_ERROR(modelmanager_logger, "TextToSpeech node name: {} already used in graph: {}.", nodeName, this->name);
", nodeName, this->name); + return StatusCode::LLM_NODE_NAME_ALREADY_EXISTS; + } + mediapipe::TtsCalculatorOptions nodeOptions; + config.node(i).node_options(0).UnpackTo(&nodeOptions); + std::shared_ptr servable = std::make_shared(nodeOptions.models_path(), nodeOptions.device(), mgconfig.getBasePath()); + ttsServableMap.insert(std::pair>(nodeName, std::move(servable))); + ttsServablesCleaningGuard.disableCleaning(); + } } return StatusCode::OK; } diff --git a/src/mediapipe_internal/mediapipegraphdefinition.hpp b/src/mediapipe_internal/mediapipegraphdefinition.hpp index 1a6e98bfcf..6515faacd4 100644 --- a/src/mediapipe_internal/mediapipegraphdefinition.hpp +++ b/src/mediapipe_internal/mediapipegraphdefinition.hpp @@ -46,6 +46,8 @@ #include "../sidepacket_servable.hpp" #include "../embeddings/embeddings_servable.hpp" #include "../rerank/rerank_servable.hpp" +#include "../audio/speech_to_text/stt_servable.hpp" +#include "../audio/text_to_speech/tts_servable.hpp" namespace ovms { class MediapipeGraphDefinitionUnloadGuard; @@ -62,6 +64,8 @@ struct ImageGenerationPipelines; using PythonNodeResourcesMap = std::unordered_map>; using GenAiServableMap = std::unordered_map>; using RerankServableMap = std::unordered_map>; +using SttServableMap = std::unordered_map>; +using TtsServableMap = std::unordered_map>; using EmbeddingsServableMap = std::unordered_map>; using ImageGenerationPipelinesMap = std::unordered_map>; @@ -71,19 +75,25 @@ struct GraphSidePackets { ImageGenerationPipelinesMap imageGenPipelinesMap; EmbeddingsServableMap embeddingsServableMap; RerankServableMap rerankServableMap; + SttServableMap sttServableMap; + TtsServableMap ttsServableMap; void clear() { pythonNodeResourcesMap.clear(); genAiServableMap.clear(); imageGenPipelinesMap.clear(); embeddingsServableMap.clear(); rerankServableMap.clear(); + sttServableMap.clear(); + ttsServableMap.clear(); } bool empty() { return (pythonNodeResourcesMap.empty() && genAiServableMap.empty() && imageGenPipelinesMap.empty() && embeddingsServableMap.empty() && - rerankServableMap.empty()); + rerankServableMap.empty() && + sttServableMap.empty() && + ttsServableMap.empty()); } }; @@ -124,6 +134,8 @@ class MediapipeGraphDefinition { static const std::string IMAGE_GEN_CALCULATOR_NAME; static const std::string EMBEDDINGS_NODE_CALCULATOR_NAME; static const std::string RERANK_NODE_CALCULATOR_NAME; + static const std::string STT_NODE_CALCULATOR_NAME; + static const std::string TTS_NODE_CALCULATOR_NAME; Status waitForLoaded(std::unique_ptr& unloadGuard, const uint32_t waitForLoadedTimeoutMicroseconds = WAIT_FOR_LOADED_DEFAULT_TIMEOUT_MICROSECONDS); // Pipelines are not versioned and any available definition has constant version equal 1. 
diff --git a/src/mediapipe_internal/mediapipegraphexecutor.cpp b/src/mediapipe_internal/mediapipegraphexecutor.cpp
index aa95bf88ec..80f5b01fa6 100644
--- a/src/mediapipe_internal/mediapipegraphexecutor.cpp
+++ b/src/mediapipe_internal/mediapipegraphexecutor.cpp
@@ -47,6 +47,8 @@ MediapipeGraphExecutor::MediapipeGraphExecutor(
     const GenAiServableMap& llmNodeResourcesMap,
     const EmbeddingsServableMap& embeddingsServableMap,
     const RerankServableMap& rerankServableMap,
+    const SttServableMap& sttServableMap,
+    const TtsServableMap& ttsServableMap,
     PythonBackend* pythonBackend,
     MediapipeServableMetricReporter* mediapipeServableMetricReporter) :
     name(name),
@@ -56,7 +58,7 @@ MediapipeGraphExecutor::MediapipeGraphExecutor(
     outputTypes(std::move(outputTypes)),
     inputNames(std::move(inputNames)),
     outputNames(std::move(outputNames)),
-    sidePacketMaps({pythonNodeResourcesMap, llmNodeResourcesMap, {}, embeddingsServableMap, rerankServableMap}),
+    sidePacketMaps({pythonNodeResourcesMap, llmNodeResourcesMap, {}, embeddingsServableMap, rerankServableMap, sttServableMap, ttsServableMap}),
     pythonBackend(pythonBackend),
    currentStreamTimestamp(STARTING_TIMESTAMP),
     mediapipeServableMetricReporter(mediapipeServableMetricReporter) {}
@@ -88,6 +90,8 @@ const std::string MediapipeGraphExecutor::LLM_SESSION_SIDE_PACKET_TAG = "llm";
 const std::string MediapipeGraphExecutor::IMAGE_GEN_SESSION_SIDE_PACKET_TAG = "pipes";
 const std::string MediapipeGraphExecutor::EMBEDDINGS_SESSION_SIDE_PACKET_TAG = "embeddings_servable";
 const std::string MediapipeGraphExecutor::RERANK_SESSION_SIDE_PACKET_TAG = "rerank_servable";
+const std::string MediapipeGraphExecutor::STT_SESSION_SIDE_PACKET_TAG = "stt_servable";
+const std::string MediapipeGraphExecutor::TTS_SESSION_SIDE_PACKET_TAG = "tts_servable";
 
 const ::mediapipe::Timestamp MediapipeGraphExecutor::STARTING_TIMESTAMP = ::mediapipe::Timestamp(0);
 }  // namespace ovms
diff --git a/src/mediapipe_internal/mediapipegraphexecutor.hpp b/src/mediapipe_internal/mediapipegraphexecutor.hpp
index b2468f5540..52f56fdf53 100644
--- a/src/mediapipe_internal/mediapipegraphexecutor.hpp
+++ b/src/mediapipe_internal/mediapipegraphexecutor.hpp
@@ -93,6 +93,8 @@ class MediapipeGraphExecutor {
     static const std::string IMAGE_GEN_SESSION_SIDE_PACKET_TAG;
     static const std::string EMBEDDINGS_SESSION_SIDE_PACKET_TAG;
     static const std::string RERANK_SESSION_SIDE_PACKET_TAG;
+    static const std::string STT_SESSION_SIDE_PACKET_TAG;
+    static const std::string TTS_SESSION_SIDE_PACKET_TAG;
     static const ::mediapipe::Timestamp STARTING_TIMESTAMP;
 
     MediapipeGraphExecutor(const std::string& name, const std::string& version, const ::mediapipe::CalculatorGraphConfig& config,
@@ -103,6 +105,8 @@ class MediapipeGraphExecutor {
         const GenAiServableMap& llmNodeResourcesMap,
         const EmbeddingsServableMap& embeddingsServableMap,
         const RerankServableMap& rerankServableMap,
+        const SttServableMap& sttServableMap,
+        const TtsServableMap& ttsServableMap,
         PythonBackend* pythonBackend,
         MediapipeServableMetricReporter* mediapipeServableMetricReporter);
     MediapipeGraphExecutor(const std::string& name, const std::string& version, const ::mediapipe::CalculatorGraphConfig& config,
@@ -151,6 +155,9 @@ class MediapipeGraphExecutor {
             inputSidePackets[EMBEDDINGS_SESSION_SIDE_PACKET_TAG] = mediapipe::MakePacket<EmbeddingsServableMap>(this->sidePacketMaps.embeddingsServableMap).At(STARTING_TIMESTAMP);
             inputSidePackets[RERANK_SESSION_SIDE_PACKET_TAG] = mediapipe::MakePacket<RerankServableMap>(this->sidePacketMaps.rerankServableMap).At(STARTING_TIMESTAMP);
+            inputSidePackets[STT_SESSION_SIDE_PACKET_TAG] = mediapipe::MakePacket<SttServableMap>(this->sidePacketMaps.sttServableMap).At(STARTING_TIMESTAMP);
+            inputSidePackets[TTS_SESSION_SIDE_PACKET_TAG] = mediapipe::MakePacket<TtsServableMap>(this->sidePacketMaps.ttsServableMap).At(STARTING_TIMESTAMP);
+
             MP_RETURN_ON_FAIL(graph.StartRun(inputSidePackets), std::string("start MediaPipe graph: ") + this->name, StatusCode::MEDIAPIPE_GRAPH_START_ERROR);
 
             ::mediapipe::Packet packet;
diff --git a/src/port/BUILD b/src/port/BUILD
index 3e8f670289..2a64d583a2 100644
--- a/src/port/BUILD
+++ b/src/port/BUILD
@@ -21,3 +21,10 @@ ovms_cc_library(
     deps = ["@com_github_tencent_rapidjson//:rapidjson"],
     visibility = ["//visibility:public",],
 )
+
+ovms_cc_library(
+    name = "dr_audio",
+    hdrs = ["dr_audio.hpp"],
+    deps = ["@dr_libs//:dr"],
+    visibility = ["//visibility:public",],
+)
diff --git a/src/port/dr_audio.hpp b/src/port/dr_audio.hpp
new file mode 100644
index 0000000000..0152f1e48d
--- /dev/null
+++ b/src/port/dr_audio.hpp
@@ -0,0 +1,25 @@
+#pragma once
+//*****************************************************************************
+// Copyright 2025 Intel Corporation
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+//*****************************************************************************
+
+#pragma warning(push)
+#pragma warning(disable : 4245 4220)
+#include "dr_wav.h"  // NOLINT
+#pragma warning(pop)
+#pragma warning(push)
+#pragma warning(disable : 6386 6262)
+#include "dr_mp3.h"  // NOLINT
+#pragma warning(pop)
diff --git a/src/port/rapidjson_document.hpp b/src/port/rapidjson_document.hpp
index 9b2a64049f..690e60d878 100644
--- a/src/port/rapidjson_document.hpp
+++ b/src/port/rapidjson_document.hpp
@@ -15,8 +15,6 @@
 // limitations under the License.
diff --git a/src/port/rapidjson_document.hpp b/src/port/rapidjson_document.hpp
index 9b2a64049f..690e60d878 100644
--- a/src/port/rapidjson_document.hpp
+++ b/src/port/rapidjson_document.hpp
@@ -15,8 +15,6 @@
 // limitations under the License.
 //*****************************************************************************
-// Type that holds vector of pairs where first element is chat turn index and second is image tensor
-// this way we store information about which image is associated with which chat turn
 #pragma warning(push)
 #pragma warning(disable : 6313)
 #include <rapidjson/document.h>
diff --git a/src/test/mediapipeflow_test.cpp b/src/test/mediapipeflow_test.cpp
index 0fe5488560..7952afd3fe 100644
--- a/src/test/mediapipeflow_test.cpp
+++ b/src/test/mediapipeflow_test.cpp
@@ -2683,7 +2683,7 @@ class MediapipeSerialization : public ::testing::Test {
         std::vector<std::string> inputNames, std::vector<std::string> outputNames,
         const PythonNodeResourcesMap& pythonNodeResourcesMap,
         MediapipeServableMetricReporter* mediapipeServableMetricReporter) :
-        MediapipeGraphExecutor(name, version, config, inputTypes, outputTypes, inputNames, outputNames, pythonNodeResourcesMap, {}, {}, {}, nullptr, mediapipeServableMetricReporter) {}
+        MediapipeGraphExecutor(name, version, config, inputTypes, outputTypes, inputNames, outputNames, pythonNodeResourcesMap, {}, {}, {}, {}, {}, nullptr, mediapipeServableMetricReporter) {}
 };

 protected:
@@ -3931,6 +3931,8 @@ TEST(WhitelistRegistered, MediapipeCalculatorsList) {
     "SerializationCalculator",
     "SetLandmarkVisibilityCalculator",
     "SidePacketToStreamCalculator",
+    "SttCalculator",
+    "TtsCalculator",
     "SplitAffineMatrixVectorCalculator",
     "SplitClassificationListVectorCalculator",
     "SplitDetectionVectorCalculator",
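The WhitelistRegistered test above pins the exact set of calculators linked into the binary, which is why the two new names must be listed. Calculators enter that registry through MediaPipe's REGISTER_CALCULATOR macro; a skeletal example follows (illustrative only; the real SttCalculator and TtsCalculator bodies are added elsewhere in this PR):

// Illustrative skeleton - the real calculator implementations are not in this diff.
#include "mediapipe/framework/calculator_framework.h"

class ExampleCalculator : public mediapipe::CalculatorBase {
public:
    static absl::Status GetContract(mediapipe::CalculatorContract* cc) {
        return absl::OkStatus();
    }
    absl::Status Process(mediapipe::CalculatorContext* cc) override {
        return absl::OkStatus();
    }
};
// Linking this in makes "ExampleCalculator" appear in the registry the test enumerates.
REGISTER_CALCULATOR(ExampleCalculator);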
{"out"}, {}, {}, {}, {}, nullptr, this->reporter.get()}; + {"in"}, {"out"}, {}, {}, {}, {}, {}, {}, nullptr, this->reporter.get()}; // Mock receiving 3 requests with manually (client) assigned ascending order of timestamp and disconnection prepareRequest(this->firstRequest, {{"in", 3.5f}}, 3); // first request with timestamp 3 @@ -604,7 +604,7 @@ node { this->name, this->version, config, {{"in", mediapipe_packet_type_enum::OVTENSOR}}, {{"out", mediapipe_packet_type_enum::OVTENSOR}}, - {"in"}, {"out"}, {}, {}, {}, {}, nullptr, this->reporter.get()}; + {"in"}, {"out"}, {}, {}, {}, {}, {}, {}, nullptr, this->reporter.get()}; // Mock only 1 request and disconnect immediately prepareRequest(this->firstRequest, {{"in", 3.5f}}); @@ -1243,7 +1243,7 @@ node { {"out3", mediapipe_packet_type_enum::OVTENSOR}}, {"in1", "in2", "in3"}, {"out1", "out2", "out3"}, - {}, {}, {}, {}, nullptr, this->reporter.get()}; + {}, {}, {}, {}, {}, {}, nullptr, this->reporter.get()}; std::promise signalPromise; std::future signalFuture = signalPromise.get_future(); @@ -1295,7 +1295,7 @@ node { {"out3", mediapipe_packet_type_enum::OVTENSOR}}, {"in1", "in2", "in3"}, {"out1", "out2", "out3"}, - {}, {}, {}, {}, nullptr, this->reporter.get()}; + {}, {}, {}, {}, {}, {}, nullptr, this->reporter.get()}; std::promise signalPromise; std::future signalFuture = signalPromise.get_future(); @@ -1330,7 +1330,7 @@ node { this->name, this->version, config, {{"in", mediapipe_packet_type_enum::OVTENSOR}}, {{"out", mediapipe_packet_type_enum::OVTENSOR}}, - {"in"}, {"out"}, {}, {}, {}, {}, nullptr, this->reporter.get()}; + {"in"}, {"out"}, {}, {}, {}, {}, {}, {}, nullptr, this->reporter.get()}; std::promise signalPromise; std::future signalFuture = signalPromise.get_future(); @@ -1364,7 +1364,7 @@ node { this->name, this->version, config, {{"in", mediapipe_packet_type_enum::OVTENSOR}}, {{"out", mediapipe_packet_type_enum::OVTENSOR}}, - {"in"}, {"wrong_name"}, {}, {}, {}, {}, nullptr, this->reporter.get()}; // cannot install observer due to wrong output name (should never happen due to validation) + {"in"}, {"wrong_name"}, {}, {}, {}, {}, {}, {}, nullptr, this->reporter.get()}; // cannot install observer due to wrong output name (should never happen due to validation) EXPECT_CALL(this->stream, Read(_)).Times(0); EXPECT_CALL(this->stream, Write(_, _)).Times(0); @@ -1389,7 +1389,7 @@ node { this->name, this->version, config, {{"in", mediapipe_packet_type_enum::OVTENSOR}}, {{"out", mediapipe_packet_type_enum::OVTENSOR}}, - {"in"}, {"out"}, {}, {}, {}, {}, nullptr, this->reporter.get()}; + {"in"}, {"out"}, {}, {}, {}, {}, {}, {}, nullptr, this->reporter.get()}; prepareRequest(this->firstRequest, {}); EXPECT_CALL(this->stream, Read(_)) @@ -1417,7 +1417,7 @@ node { this->name, this->version, config, {{"in", mediapipe_packet_type_enum::OVTENSOR}}, {{"out", mediapipe_packet_type_enum::OVTENSOR}}, - {"in"}, {"out"}, {}, {}, {}, {}, nullptr, this->reporter.get()}; + {"in"}, {"out"}, {}, {}, {}, {}, {}, {}, nullptr, this->reporter.get()}; std::promise signalPromise; std::future signalFuture = signalPromise.get_future(); @@ -1453,7 +1453,7 @@ node { this->name, this->version, config, {{"in", mediapipe_packet_type_enum::OVTENSOR}}, {{"out", mediapipe_packet_type_enum::OVTENSOR}}, - {"in"}, {"out"}, {}, {}, {}, {}, nullptr, this->reporter.get()}; + {"in"}, {"out"}, {}, {}, {}, {}, {}, {}, nullptr, this->reporter.get()}; prepareRequest(this->firstRequest, {{"in", 3.5f}}); ASSERT_EQ(executor.inferStream(this->firstRequest, this->stream, 
@@ -1476,7 +1476,7 @@ node {
         this->name, this->version, config,
         {{"in", mediapipe_packet_type_enum::OVTENSOR}},
         {{"out", mediapipe_packet_type_enum::OVTENSOR}},
-        {"in"}, {"out"}, {}, {}, {}, {}, nullptr, this->reporter.get()};
+        {"in"}, {"out"}, {}, {}, {}, {}, {}, {}, nullptr, this->reporter.get()};

     // Invalid request - missing data in buffer
     prepareInvalidRequest(this->firstRequest, {"in"});  // no timestamp specified, server will assign one
@@ -1511,7 +1511,7 @@ node {
         this->name, this->version, config,
         {{"in", mediapipe_packet_type_enum::OVTENSOR}},
         {{"out", mediapipe_packet_type_enum::OVTENSOR}},
-        {"in"}, {"out"}, {}, {}, {}, {}, nullptr, this->reporter.get()};
+        {"in"}, {"out"}, {}, {}, {}, {}, {}, {}, nullptr, this->reporter.get()};

     std::promise<void> signalPromise[3];
     std::future<void> signalFuture[3] = {
@@ -1558,7 +1558,7 @@ node {
         this->name, this->version, config,
         {{"in", mediapipe_packet_type_enum::OVTENSOR}},
         {{"out", mediapipe_packet_type_enum::OVTENSOR}},
-        {"in"}, {"out"}, {}, {}, {}, {}, nullptr, this->reporter.get()};
+        {"in"}, {"out"}, {}, {}, {}, {}, {}, {}, nullptr, this->reporter.get()};

     prepareRequest(this->firstRequest, {{"in", 3.5f}}, 0);
     EXPECT_CALL(this->stream, Read(_))
@@ -1586,7 +1586,7 @@ node {
         this->name, this->version, config,
         {{"in", mediapipe_packet_type_enum::OVTENSOR}},
         {{"out", mediapipe_packet_type_enum::OVTENSOR}},
-        {"in"}, {"out"}, {}, {}, {}, {}, nullptr, this->reporter.get()};
+        {"in"}, {"out"}, {}, {}, {}, {}, {}, {}, nullptr, this->reporter.get()};

     prepareRequest(this->firstRequest, {{"in", 3.5f}});
     setRequestTimestamp(this->firstRequest, std::string("not an int"));
@@ -1621,7 +1621,7 @@ node {
         this->name, this->version, config,
         {{"in", mediapipe_packet_type_enum::OVTENSOR}},
         {{"out", mediapipe_packet_type_enum::OVTENSOR}},
-        {"in"}, {"out"}, {}, {}, {}, {}, nullptr, this->reporter.get()};
+        {"in"}, {"out"}, {}, {}, {}, {}, {}, {}, nullptr, this->reporter.get()};

     // Timestamps not allowed in stream
     // Expect continuity of operation and response with error message
@@ -1663,7 +1663,7 @@ node {
         this->name, this->version, config,
         {{"in", mediapipe_packet_type_enum::OVTENSOR}},
         {{"out", mediapipe_packet_type_enum::OVTENSOR}},
-        {"in"}, {"out"}, {}, {}, {}, {}, nullptr, this->reporter.get()};
+        {"in"}, {"out"}, {}, {}, {}, {}, {}, {}, nullptr, this->reporter.get()};

     // Allowed in stream
     for (auto timestamp : std::vector<::mediapipe::Timestamp>{
@@ -1699,7 +1699,7 @@ node {
         this->name, this->version, config,
         {{"in", mediapipe_packet_type_enum::OVTENSOR}},
         {{"out", mediapipe_packet_type_enum::OVTENSOR}},
-        {"in"}, {"out"}, {}, {}, {}, {}, nullptr, this->reporter.get()};
+        {"in"}, {"out"}, {}, {}, {}, {}, {}, {}, nullptr, this->reporter.get()};

     // Mock receiving 3 requests and disconnection
     prepareRequestWithParam(this->firstRequest, {{"in", 3.5f}}, {"val", 65});  // request with parameter val
@@ -1736,7 +1736,7 @@ node {
         this->name, this->version, config,
         {{"in", mediapipe_packet_type_enum::OVTENSOR}},
         {{"out", mediapipe_packet_type_enum::OVTENSOR}},
-        {"in"}, {"out"}, {}, {}, {}, {}, nullptr, this->reporter.get()};
+        {"in"}, {"out"}, {}, {}, {}, {}, {}, {}, nullptr, this->reporter.get()};

     // Mock receiving the invalid request and disconnection
     // Request with invalid param py (special pythons session side packet)
@@ -1765,7 +1765,7 @@ node {
         this->name, this->version, config,
         {{"in", mediapipe_packet_type_enum::OVTENSOR}},
         {{"out", mediapipe_packet_type_enum::OVTENSOR}},
-        {"in"}, {"out"}, {}, {}, {}, {}, nullptr, this->reporter.get()};
+        {"in"}, {"out"}, {}, {}, {}, {}, {}, {}, nullptr, this->reporter.get()};
{"out"}, {}, {}, {}, {}, nullptr, this->reporter.get()}; + {"in"}, {"out"}, {}, {}, {}, {}, {}, {}, nullptr, this->reporter.get()}; prepareRequest(this->firstRequest, {{"in", 3.5f}}); // missing required request param EXPECT_CALL(this->stream, Read(_)).Times(0); @@ -1791,7 +1791,7 @@ node { this->name, this->version, config, {{"in", mediapipe_packet_type_enum::OVTENSOR}}, {{"out", mediapipe_packet_type_enum::OVTENSOR}}, - {"in"}, {"out"}, {}, {}, {}, {}, nullptr, this->reporter.get()}; + {"in"}, {"out"}, {}, {}, {}, {}, {}, {}, nullptr, this->reporter.get()}; // Mock receiving 2 requests and disconnection prepareRequest(this->firstRequest, {{"in", 3.5f}}, std::nullopt, this->name, this->version); // no timestamp specified, server will assign one @@ -1825,7 +1825,7 @@ node { this->name, this->version, config, {{"in", mediapipe_packet_type_enum::OVTENSOR}}, {{"out", mediapipe_packet_type_enum::OVTENSOR}}, - {"in"}, {"out"}, {}, {}, {}, {}, nullptr, this->reporter.get()}; + {"in"}, {"out"}, {}, {}, {}, {}, {}, {}, nullptr, this->reporter.get()}; std::promise signalPromise; std::future signalFuture = signalPromise.get_future();