Skip to content

Commit 7e29a70

Browse files
committed
Added handling of cache_position for stateless Whisper pipeline
1 parent 6b728f5 commit 7e29a70

File tree

8 files changed

+28
-20
lines changed

8 files changed

+28
-20
lines changed

src/cpp/src/utils.cpp

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -840,6 +840,14 @@ void export_model(ov::CompiledModel& compiled_model, const std::filesystem::path
840840
out.close();
841841
}
842842

843+
bool has_input(const std::shared_ptr<ov::Model>& model, const std::string& name) {
844+
auto inputs = model->inputs();
845+
auto it = std::find_if(inputs.begin(), inputs.end(), [&](const auto& port) {
846+
return port.get_names().count(name) != 0;
847+
});
848+
return it != inputs.end();
849+
}
850+
843851
} // namespace utils
844852
} // namespace genai
845853
} // namespace ov

src/cpp/src/utils.hpp

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -309,6 +309,11 @@ ov::CompiledModel import_model(const std::filesystem::path& blob_path,
309309
*/
310310
void export_model(ov::CompiledModel& compiled_model, const std::filesystem::path& blob_path);
311311

312+
/**
313+
* @brief Checks if the model has an input with the specified name.
314+
*/
315+
bool has_input(const std::shared_ptr<Model>& model, const std::string& name);
316+
312317
} // namespace utils
313318
} // namespace genai
314319
} // namespace ov

src/cpp/src/whisper/models/statefull_decoder.cpp

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,6 @@
44
#include "statefull_decoder.hpp"
55

66
#include "utils.hpp"
7-
#include "whisper/whisper_utils.hpp"
87

98
namespace {
109
void reshape_hidden_states_to_static(std::shared_ptr<ov::Model> model, const ov::PartialShape& lhstates_shape) {
@@ -26,7 +25,7 @@ WhisperStatefullDecoder::WhisperStatefullDecoder(const std::filesystem::path& mo
2625

2726
auto model = core.read_model(models_path / "openvino_decoder_model.xml", {}, properties);
2827

29-
m_has_cache_position = ov::genai::utils::input_exists(model, "cache_position");
28+
m_has_cache_position = utils::has_input(model, "cache_position");
3029

3130
ov::CompiledModel compiled_model;
3231
if (device == "NPU") {

src/cpp/src/whisper/models/with_past_decoder.cpp

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -86,8 +86,11 @@ WhisperWithPastDecoder::WhisperWithPastDecoder(const std::filesystem::path& mode
8686
"To obtain stateful decoder model use latest `optimum-intel` package:\n"
8787
"pip install optimum-intel@git+https://github.com/huggingface/optimum-intel.git@main\n"
8888
"optimum-cli export openvino --trust-remote-code --model openai/whisper-tiny whisper-tiny");
89+
8990
ov::Core core = utils::singleton_core();
9091

92+
m_has_cache_position = utils::has_input(core.read_model(models_path / "openvino_decoder_model.xml"), "cache_position");
93+
9194
auto compiled_model = core.compile_model(models_path / "openvino_decoder_model.xml", device, properties);
9295
utils::print_compiled_model_properties(compiled_model, "whisper decoder model");
9396
m_request_decoder = compiled_model.create_infer_request();
@@ -110,9 +113,11 @@ void WhisperWithPastDecoder::start_async(const Tensor& encoder_hidden_state,
110113
request.set_tensor("input_ids", input_ids);
111114

112115
if (!is_initial_step) {
113-
ov::Tensor cache_position_tensor = request.get_tensor("cache_position");
114-
cache_position_tensor.set_shape({1});
115-
cache_position_tensor.data<int64_t>()[0] = m_cache_position;
116+
if (m_has_cache_position) {
117+
ov::Tensor cache_position_tensor = request.get_tensor("cache_position");
118+
cache_position_tensor.set_shape({1});
119+
cache_position_tensor.data<int64_t>()[0] = m_cache_position;
120+
}
116121
}
117122

118123
_set_past_key_value(beam_idx);

src/cpp/src/whisper/models/with_past_decoder.hpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@ class WhisperWithPastDecoder : public WhisperDecoder {
2626
size_t m_cache_position = 0;
2727
bool m_initial_past_key_value_set = false;
2828
bool m_past_key_value_linked = false;
29+
bool m_has_cache_position = true;
2930

3031
void _set_past_key_value(const Tensor& beam_idx);
3132
};

src/cpp/src/whisper/pipeline_static.cpp

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1006,7 +1006,7 @@ std::shared_ptr<ov::Model> prepare_decoder_model(std::shared_ptr<ov::Model>& mod
10061006
// 3) Expose all states that requires initialization on the first run as outputs
10071007
expose_runtime_states_as_outputs(decoder_model);
10081008
// 4) Remove cache_position input if it exists
1009-
if (ov::genai::utils::input_exists(decoder_model, "cache_position")) {
1009+
if (ov::genai::utils::has_input(decoder_model, "cache_position")) {
10101010
remove_cache_position(decoder_model);
10111011
}
10121012
// 5) Normalize output names - should be done in stateful_to_stateless_transformation
@@ -1023,10 +1023,6 @@ std::shared_ptr<ov::Model> prepare_decoder_with_past_model(std::shared_ptr<ov::M
10231023
normalize_output_key_value_names(decoder_with_past_model);
10241024
expose_runtime_states_as_inputs(decoder_with_past_model);
10251025

1026-
if (!ov::genai::utils::input_exists(decoder_with_past_model, "cache_position")) {
1027-
add_cache_position_input(decoder_with_past_model);
1028-
}
1029-
10301026
decoder_with_past_model->reshape({{"input_ids", ov::PartialShape({-1, 1})}});
10311027
decoder_with_past_model->set_friendly_name("Model6");
10321028

@@ -1066,6 +1062,10 @@ WhisperPipeline::StaticWhisperPipeline::StaticWhisperPipeline(const std::filesys
10661062
if (!decoder_model || !decoder_with_past_model)
10671063
OPENVINO_THROW("Decoder/decoder_with_past model is not valid !");
10681064

1065+
if (!ov::genai::utils::has_input(decoder_with_past_model, "cache_position")) {
1066+
add_cache_position_input(decoder_with_past_model);
1067+
}
1068+
10691069
add_attention_mask_input(decoder_model, true /* transform_cross_attn */, last_hidden_state_shape[1].get_length());
10701070
// NB: Note, there is no need to transform cross attention for decoder_with_past_model
10711071
// as it accepts only single token and there can't be any padding.

src/cpp/src/whisper/whisper_utils.cpp

Lines changed: 0 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -57,14 +57,6 @@ int64_t argmax(const ov::Tensor& logits, const size_t batch_idx) {
5757
return out_token;
5858
}
5959

60-
bool input_exists(const std::shared_ptr<ov::Model>& model, const std::string& name) {
61-
auto inputs = model->inputs();
62-
auto it = std::find_if(inputs.begin(), inputs.end(), [&](const auto& port) {
63-
return port.get_names().count(name) != 0;
64-
});
65-
return it != inputs.end();
66-
}
67-
6860
} // namespace utils
6961
} // namespace genai
7062
} // namespace ov

src/cpp/src/whisper/whisper_utils.hpp

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -19,8 +19,6 @@ void filter_non_segment_metrics(ov::genai::RawPerfMetrics& raw_metrics,
1919

2020
int64_t argmax(const ov::Tensor& logits, const size_t batch_idx);
2121

22-
bool input_exists(const std::shared_ptr<ov::Model>& model, const std::string& name);
23-
2422
} // namespace utils
2523
} // namespace genai
2624
} // namespace ov

0 commit comments

Comments (0)