Skip to content

Commit 5ad52b0

Browse files
committed
Fixes for Whisper stateful/static pipelines
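- Fix the input-tensor reset in encode() so that it depends on the device actually used in the stateful pipeline.
- Fix usage of the STATIC_PIPELINE property option in the Whisper static pipeline (the option is removed before properties are passed on to model reading and compilation).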
1 parent 9d8b239 commit 5ad52b0

File tree

2 files changed

+13
-9
lines changed

2 files changed

+13
-9
lines changed

src/cpp/src/whisper/pipeline_static.cpp

Lines changed: 11 additions & 7 deletions
Original file line number | Diff line number | Diff line change
@@ -990,18 +990,22 @@ WhisperPipeline::StaticWhisperPipeline::StaticWhisperPipeline(const std::filesys
990990
, m_sampler(m_tokenizer) {
991991
ov::Core core = utils::singleton_core();
992992

993-
auto encoder_model = core.read_model(models_path / "openvino_encoder_model.xml", {}, properties);
993+
// Remove "STATIC_PIPELINE" as we don't need to pass it further
994+
auto model_properties = properties;
995+
utils::pop_option(model_properties, "STATIC_PIPELINE");
996+
997+
auto encoder_model = core.read_model(models_path / "openvino_encoder_model.xml", {}, model_properties);
994998
reshape_to_static_encoder(encoder_model, m_feature_extractor.feature_size);
995999
auto last_hidden_state_shape = get_encoder_hidden_state_shape(encoder_model);
9961000

9971001
std::shared_ptr<ov::Model> decoder_model;
9981002
std::shared_ptr<ov::Model> decoder_with_past_model;
9991003

10001004
if (std::filesystem::exists(models_path / "openvino_decoder_with_past_model.xml") ) {
1001-
decoder_model = core.read_model(models_path / "openvino_decoder_model.xml", {}, properties);
1002-
decoder_with_past_model = core.read_model(models_path / "openvino_decoder_with_past_model.xml", {}, properties);
1005+
decoder_model = core.read_model(models_path / "openvino_decoder_model.xml", {}, model_properties);
1006+
decoder_with_past_model = core.read_model(models_path / "openvino_decoder_with_past_model.xml", {}, model_properties);
10031007
} else {
1004-
auto model = core.read_model(models_path / "openvino_decoder_model.xml", {}, properties);
1008+
auto model = core.read_model(models_path / "openvino_decoder_model.xml", {}, model_properties);
10051009
ov::pass::StatefulToStateless().run_on_model(model);
10061010

10071011
decoder_model = prepare_decoder_model(model);
@@ -1030,15 +1034,15 @@ WhisperPipeline::StaticWhisperPipeline::StaticWhisperPipeline(const std::filesys
10301034
preprocess_decoder(decoder_with_past_model);
10311035

10321036
ov::CompiledModel compiled_model;
1033-
compiled_model = core.compile_model(encoder_model, "NPU", properties);
1037+
compiled_model = core.compile_model(encoder_model, "NPU", model_properties);
10341038
ov::genai::utils::print_compiled_model_properties(compiled_model, "Static Whisper encoder model");
10351039
m_models.encoder = compiled_model.create_infer_request();
10361040

1037-
compiled_model = core.compile_model(decoder_with_past_model, "NPU", properties);
1041+
compiled_model = core.compile_model(decoder_with_past_model, "NPU", model_properties);
10381042
ov::genai::utils::print_compiled_model_properties(compiled_model, "Static Whisper decoder with past model");
10391043
m_models.decoder_with_past = compiled_model.create_infer_request();
10401044

1041-
compiled_model = core.compile_model(decoder_model, "NPU", properties);
1045+
compiled_model = core.compile_model(decoder_model, "NPU", model_properties);
10421046
ov::genai::utils::print_compiled_model_properties(compiled_model, "Static Whisper decoder model");
10431047
m_models.decoder = compiled_model.create_infer_request();
10441048

src/cpp/src/whisper/whisper.cpp

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -212,8 +212,8 @@ ov::Tensor encode(ov::InferRequest& request,
212212
raw_metrics.m_inference_durations[0] += MicroSeconds(infer_ms);
213213

214214
// reset input tensor
215-
auto m_is_npu = true;
216-
uint8_t batch_size = m_is_npu ? 1 : 0;
215+
auto devices = request.get_compiled_model().get_property(ov::execution_devices);
216+
uint8_t batch_size = (devices[0] == "NPU") ? 1 : 0;
217217
request.set_tensor("input_features", ov::Tensor(ov::element::f32, {batch_size, feature_size, nb_max_frames}));
218218

219219
return request.get_tensor("last_hidden_state");

0 commit comments

Comments
 (0)