Commit 5b05799

Enable WhisperStatefulImpl for NPU, fix Whisper pipelines for transformers 4.53.3 & 4.55 (#2126)

Authored by eshiryae, AsyaPronina, as-suvorov, and Wovchena

Ticket: 174805

Co-authored-by: Anastasiya (Asya) Pronina <[email protected]>
Co-authored-by: Alexander Suvorov <[email protected]>
Co-authored-by: Vladimir Zlobin <[email protected]>

1 parent 1ed328b · commit 5b05799

File tree: 11 files changed, +205 −35 lines

src/cpp/src/utils.cpp

Lines changed: 40 additions & 5 deletions
@@ -74,7 +74,6 @@ std::optional<uint32_t> pop_int_and_cast(ov::AnyMap& config, const std::string&
 }

 void update_npu_config(ov::AnyMap& config,
-                       const std::shared_ptr<ov::Model>& model,
                        const ov::genai::utils::KVAxesPosition& kv_pos,
                        const ov::genai::utils::KVDesc& kv_desc) {
     update_config(config, {"NPU_USE_NPUW", "YES"});
@@ -97,6 +96,26 @@ void update_npu_config(ov::AnyMap& config,
     rename_key(config, "++SHARED_HEAD_CONFIG", "++NPUW_LLM_SHARED_HEAD_CONFIG");
 }

+void update_npu_config_whisper(ov::AnyMap& config,
+                               const ov::genai::utils::KVAxesPosition& kv_pos,
+                               const ov::genai::utils::KVDesc& kv_desc) {
+    update_config(config, {"NPU_USE_NPUW", "YES"});
+    update_config(config, {"NPUW_ONLINE_PIPELINE", "NONE"});
+    update_config(config, {"NPUW_FUNCALL_FOR_ALL", "NO"});
+    update_config(config, {"NPUW_FOLD", "NO"});
+    update_config(config, {"NPUW_LLM", "YES"});
+    update_config(config, {"NPUW_WHISPER", "YES"});
+
+    update_config(config, {"NPUW_LLM_BATCH_DIM", kv_pos.batch});
+    update_config(config, {"NPUW_LLM_SEQ_LEN_DIM", kv_pos.seq_len});
+
+    update_config(config, {"NPUW_LLM_MAX_PROMPT_LEN", kv_desc.max_prompt_len});
+    update_config(config, {"NPUW_LLM_MIN_RESPONSE_LEN", kv_desc.min_response_len});
+
+    // To disable chunking
+    update_config(config, {"NPUW_LLM_PREFILL_HINT", "STATIC"});
+}
+
 inline bool is_paged_attention_available() {
 #if defined(OPENVINO_ARCH_X86_64) || defined(OPENVINO_ARCH_ARM64)
     return true;
@@ -554,7 +573,8 @@ void print_scheduler_config_info(const SchedulerConfig &scheduler_config) {
 std::pair<ov::CompiledModel, KVDesc>
 compile_decoder_for_npu(const std::shared_ptr<ov::Model>& model,
                         const ov::AnyMap& config,
-                        const KVAxesPosition& kv_pos) {
+                        const KVAxesPosition& kv_pos,
+                        const bool is_whisper) {
     ov::CompiledModel compiled;
     ov::AnyMap properties = config;
     KVDesc kv_desc;
@@ -575,9 +595,16 @@ compile_decoder_for_npu(const std::shared_ptr<ov::Model>& model,
         kv_desc.max_prompt_len = compiled.get_property("NPUW_LLM_MAX_PROMPT_LEN").as<uint32_t>();
         kv_desc.min_response_len = compiled.get_property("NPUW_LLM_MIN_RESPONSE_LEN").as<uint32_t>();
     } else {
-        kv_desc.max_prompt_len = pop_int_and_cast(properties, "MAX_PROMPT_LEN").value_or(1024u);
-        kv_desc.min_response_len = pop_int_and_cast(properties, "MIN_RESPONSE_LEN").value_or(128u);
-        update_npu_config(properties, model, kv_pos, kv_desc);
+        if (is_whisper) {
+            kv_desc.max_prompt_len = pop_int_and_cast(properties, "MAX_PROMPT_LEN").value_or(4u);
+            // kvcache size for Whisper = 448u (MAX_PROMPT_LEN + MIN_RESPONSE_LEN)
+            kv_desc.min_response_len = pop_int_and_cast(properties, "MIN_RESPONSE_LEN").value_or(444u);
+            update_npu_config_whisper(properties, kv_pos, kv_desc);
+        } else {
+            kv_desc.max_prompt_len = pop_int_and_cast(properties, "MAX_PROMPT_LEN").value_or(1024u);
+            kv_desc.min_response_len = pop_int_and_cast(properties, "MIN_RESPONSE_LEN").value_or(128u);
+            update_npu_config(properties, kv_pos, kv_desc);
+        }
         compiled = ov::genai::utils::singleton_core().compile_model(model, "NPU", properties);
         // Also export compiled model if required
         if (export_blob) {
@@ -813,6 +840,14 @@ void export_model(ov::CompiledModel& compiled_model, const std::filesystem::path
     out.close();
 }

+bool has_input(const std::shared_ptr<ov::Model>& model, const std::string& name) {
+    auto inputs = model->inputs();
+    auto it = std::find_if(inputs.begin(), inputs.end(), [&](const auto& port) {
+        return port.get_names().count(name) != 0;
+    });
+    return it != inputs.end();
+}
+
 }  // namespace utils
 }  // namespace genai
 }  // namespace ov
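
For orientation, a minimal standalone sketch (not part of the commit) of what the new utils::has_input() helper does: it scans the model's input ports for one carrying the given tensor name, so callers can skip binding tensors that a particular export does not expose. The IR path below is hypothetical.

    #include <iostream>
    #include <memory>
    #include <string>

    #include <openvino/openvino.hpp>

    // Same predicate as utils::has_input(): true if any input port of the
    // model lists `name` among its tensor names.
    static bool has_input(const std::shared_ptr<ov::Model>& model, const std::string& name) {
        for (const auto& port : model->inputs()) {
            if (port.get_names().count(name) != 0) {
                return true;
            }
        }
        return false;
    }

    int main() {
        ov::Core core;
        // Hypothetical decoder IR; newer exports may omit the "cache_position"
        // input, which is exactly what the Whisper pipelines now probe for.
        auto model = core.read_model("openvino_decoder_model.xml");
        std::cout << std::boolalpha << has_input(model, "cache_position") << "\n";
    }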

src/cpp/src/utils.hpp

Lines changed: 7 additions & 1 deletion
@@ -193,7 +193,8 @@ struct KVDesc {

 std::pair<ov::CompiledModel, KVDesc> compile_decoder_for_npu(const std::shared_ptr<ov::Model>& model,
                                                              const ov::AnyMap& config,
-                                                             const KVAxesPosition& kv_pos);
+                                                             const KVAxesPosition& kv_pos,
+                                                             const bool is_whisper = false);

 /// @brief SharedOptional is a wrapper around a reference to an existing object and an optional shared alternative value.
 /// The difference from std::optional is that the default state is not empty and contains a reference to an existing object outside the class.
@@ -308,6 +309,11 @@ ov::CompiledModel import_model(const std::filesystem::path& blob_path,
 */
 void export_model(ov::CompiledModel& compiled_model, const std::filesystem::path& blob_path);

+/**
+ * @brief Checks if the model has an input with the specified name.
+ */
+bool has_input(const std::shared_ptr<Model>& model, const std::string& name);
+
 }  // namespace utils
 }  // namespace genai
 }  // namespace ov
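
A note on the defaults that is_whisper = true selects in compile_decoder_for_npu() above: MAX_PROMPT_LEN = 4 and MIN_RESPONSE_LEN = 444 sum to 448, which matches Whisper's maximum decoder sequence length (max_target_positions = 448 in the Hugging Face config), consistent with the KV-cache comment in the implementation.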

src/cpp/src/whisper/models/decoder.cpp

Lines changed: 7 additions & 2 deletions
@@ -12,14 +12,19 @@
 namespace ov::genai {
 std::shared_ptr<WhisperDecoder> WhisperDecoder::from_path(const std::filesystem::path& models_path,
                                                           const std::string& device,
-                                                          const ov::AnyMap& properties) {
+                                                          const ov::AnyMap& properties,
+                                                          const ov::PartialShape& lhs_shape) {
     bool has_decoder_with_past = std::filesystem::exists(models_path / "openvino_decoder_with_past_model.xml");

     if (has_decoder_with_past) {
+        if (device == "NPU") {
+            OPENVINO_THROW("For NPU, 3-model whisper pipeline works only with STATIC_PIPELINE : YES configuration "
+                           "(which is default for NPU).");
+        }
         return std::make_shared<WhisperWithPastDecoder>(models_path, device, properties);
     }

-    return std::make_shared<WhisperStatefullDecoder>(models_path, device, properties);
+    return std::make_shared<WhisperStatefullDecoder>(models_path, device, properties, lhs_shape);
 }

 std::pair<int64_t, float> WhisperDecoder::detect_language(const ov::Tensor& encoder_hidden_state,

src/cpp/src/whisper/models/decoder.hpp

Lines changed: 2 additions & 1 deletion
@@ -13,7 +13,8 @@ class WhisperDecoder {
 public:
     static std::shared_ptr<WhisperDecoder> from_path(const std::filesystem::path& models_path,
                                                      const std::string& device,
-                                                     const ov::AnyMap& properties);
+                                                     const ov::AnyMap& properties,
+                                                     const ov::PartialShape& lhs_shape);

     std::pair<int64_t, float> detect_language(const Tensor& encoder_hidden_state, const int64_t decoder_start_token_id);

src/cpp/src/whisper/models/statefull_decoder.cpp

Lines changed: 34 additions & 6 deletions
@@ -1,21 +1,45 @@
-// Copyright (C) 2024 Intel Corporation
+// Copyright (C) 2024-2025 Intel Corporation
 // SPDX-License-Identifier: Apache-2.0

 #include "statefull_decoder.hpp"

 #include "utils.hpp"

+namespace {
+void reshape_hidden_states_to_static(std::shared_ptr<ov::Model> model, const ov::PartialShape& lhstates_shape) {
+    ov::PartialShape new_shape = model->input("encoder_hidden_states").get_partial_shape();
+    OPENVINO_ASSERT(new_shape.size() > 1 && lhstates_shape.size() > 1);
+    new_shape[1] = lhstates_shape[1];
+    std::map<std::string, ov::PartialShape> name_to_shape{{"encoder_hidden_states", new_shape}};
+    model->reshape(name_to_shape);
+}
+
+}  // anonymous
+
 namespace ov::genai {
 WhisperStatefullDecoder::WhisperStatefullDecoder(const std::filesystem::path& models_path,
                                                  const std::string& device,
-                                                 const ov::AnyMap& properties) {
+                                                 const ov::AnyMap& properties,
+                                                 const ov::PartialShape& lhs_shape) {
     ov::Core core = utils::singleton_core();

     auto model = core.read_model(models_path / "openvino_decoder_model.xml", {}, properties);

-    utils::apply_slice_before_matmul_transformation(model);
+    m_has_cache_position = utils::has_input(model, "cache_position");
+
+    ov::CompiledModel compiled_model;
+    if (device == "NPU") {
+        auto kv_pos = ov::genai::utils::get_kv_axes_pos(model);
+
+        reshape_hidden_states_to_static(model, lhs_shape);

-    auto compiled_model = core.compile_model(model, device, properties);
+        utils::KVDesc kv_desc;
+        std::tie(compiled_model, kv_desc) = utils::compile_decoder_for_npu(model, properties, kv_pos, true);
+    } else {
+        utils::apply_slice_before_matmul_transformation(model);
+
+        compiled_model = core.compile_model(model, device, properties);
+    }

     utils::print_compiled_model_properties(compiled_model, "whisper decoder model");
     m_request = compiled_model.create_infer_request();
@@ -29,7 +53,9 @@ void WhisperStatefullDecoder::start_async(const Tensor& encoder_hidden_state,

     _set_encoder_hidden_states_tensor(encoder_hidden_state, batch_size, m_request);

-    _set_cache_position_tensor(seq_len);
+    if (m_has_cache_position) {
+        _set_cache_position_tensor(seq_len);
+    }
     m_request.set_tensor("input_ids", input_ids);
     m_request.set_tensor("beam_idx", beam_idx);

@@ -58,7 +84,9 @@ Tensor WhisperStatefullDecoder::wait() {

 void WhisperStatefullDecoder::reset_state() {
     m_request.reset_state();
-    m_request.set_tensor("cache_position", create_host_tensor(ov::element::i64, {0}));
+    if (m_has_cache_position) {
+        m_request.set_tensor("cache_position", create_host_tensor(ov::element::i64, {0}));
+    }

     Shape encoder_hidden_states_shape{m_request.get_tensor("encoder_hidden_states").get_shape()};
     encoder_hidden_states_shape[0] = 0;
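
For intuition, a hedged sketch (not part of the commit) of what reshape_hidden_states_to_static() achieves, assuming a whisper-tiny export whose encoder emits a static last_hidden_state of [1, 1500, 384] (1500 encoder frames per 30-second window): only the frame axis (dim 1) of the decoder's encoder_hidden_states input is pinned, while the other dimensions keep whatever the IR declares.

    #include <map>
    #include <memory>

    #include <openvino/openvino.hpp>

    // Pin only dim 1 (encoder frames) of the decoder's encoder_hidden_states
    // input to the encoder's static frame count, e.g. {1, 1500, 384}.
    void pin_frame_axis(const std::shared_ptr<ov::Model>& decoder,
                        const ov::PartialShape& encoder_lhs_shape) {
        ov::PartialShape shape = decoder->input("encoder_hidden_states").get_partial_shape();
        shape[1] = encoder_lhs_shape[1];  // 1500 in the whisper-tiny example
        decoder->reshape(std::map<std::string, ov::PartialShape>{{"encoder_hidden_states", shape}});
    }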

src/cpp/src/whisper/models/statefull_decoder.hpp

Lines changed: 4 additions & 2 deletions
@@ -1,4 +1,4 @@
-// Copyright (C) 2024 Intel Corporation
+// Copyright (C) 2024-2025 Intel Corporation
 // SPDX-License-Identifier: Apache-2.0

 #pragma once
@@ -12,7 +12,8 @@ class WhisperStatefullDecoder : public WhisperDecoder {
 public:
     WhisperStatefullDecoder(const std::filesystem::path& models_path,
                             const std::string& device,
-                            const ov::AnyMap& properties);
+                            const ov::AnyMap& properties,
+                            const ov::PartialShape& lhs_shape);

     void start_async(const Tensor& encoder_hidden_state, const Tensor& input_ids, const Tensor& beam_idx) override;

@@ -27,5 +28,6 @@ class WhisperStatefullDecoder : public WhisperDecoder {

 private:
     ov::InferRequest m_request;
+    bool m_has_cache_position = true;
 };
 }  // namespace ov::genai

src/cpp/src/whisper/models/with_past_decoder.cpp

Lines changed: 4 additions & 1 deletion
@@ -86,12 +86,15 @@ WhisperWithPastDecoder::WhisperWithPastDecoder(const std::filesystem::path& mode
                     "To obtain stateful decoder model use latest `optimum-intel` package:\n"
                     "pip install optimum-intel@git+https://github.com/huggingface/optimum-intel.git@main\n"
                     "optimum-cli export openvino --trust-remote-code --model openai/whisper-tiny whisper-tiny");
+
     ov::Core core = utils::singleton_core();

     auto compiled_model = core.compile_model(models_path / "openvino_decoder_model.xml", device, properties);
     utils::print_compiled_model_properties(compiled_model, "whisper decoder model");
     m_request_decoder = compiled_model.create_infer_request();

+    m_past_decoder_has_cache_position =
+        utils::has_input(core.read_model(models_path / "openvino_decoder_with_past_model.xml"), "cache_position");
     compiled_model = core.compile_model(models_path / "openvino_decoder_with_past_model.xml", device, properties);
     utils::print_compiled_model_properties(compiled_model, "whisper decoder with past model");
     m_request_decoder_with_past = compiled_model.create_infer_request();
@@ -109,7 +112,7 @@ void WhisperWithPastDecoder::start_async(const Tensor& encoder_hidden_state,
     _set_encoder_hidden_states_tensor(encoder_hidden_state, batch_size, request);
     request.set_tensor("input_ids", input_ids);

-    if (!is_initial_step) {
+    if (!is_initial_step && m_past_decoder_has_cache_position) {
         ov::Tensor cache_position_tensor = request.get_tensor("cache_position");
         cache_position_tensor.set_shape({1});
         cache_position_tensor.data<int64_t>()[0] = m_cache_position;
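
The m_past_decoder_has_cache_position guard lets the same with-past pipeline drive both older decoder IRs that expose a cache_position input and newer exports that omit it; presumably the latter is what changed with models exported under transformers 4.53.3 / 4.55, per the commit title.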

src/cpp/src/whisper/models/with_past_decoder.hpp

Lines changed: 1 addition & 0 deletions
@@ -26,6 +26,7 @@ class WhisperWithPastDecoder : public WhisperDecoder {
     size_t m_cache_position = 0;
     bool m_initial_past_key_value_set = false;
     bool m_past_key_value_linked = false;
+    bool m_past_decoder_has_cache_position = true;

     void _set_past_key_value(const Tensor& beam_idx);
 };

src/cpp/src/whisper/pipeline.cpp

Lines changed: 36 additions & 4 deletions
@@ -42,6 +42,25 @@ ov::InferRequest init_model(ov::CompiledModel& compiled) {
     }
 }

+void reshape_to_static_encoder(std::shared_ptr<ov::Model> model,
+                               const size_t batch_size,
+                               const size_t feature_size) {
+    std::map<std::string, ov::PartialShape> new_shapes;
+    for (auto input : model->inputs()) {
+        const auto& input_name = input.get_any_name();
+        ov::PartialShape new_shape;
+        if (input_name.find("input_features") != std::string::npos) {
+            const auto& partial_shape = input.get_partial_shape();
+            OPENVINO_ASSERT(partial_shape.size() >= 3);
+            new_shape = partial_shape;
+            new_shape[0] = batch_size;  // batch_dim
+            new_shape[1] = feature_size;
+            new_shapes.emplace(input_name, new_shape);
+        }
+    }
+    model->reshape(new_shapes);
+}
+
 }  // namespace

 namespace ov {
@@ -55,13 +74,20 @@ class WhisperPipeline::WhisperPipelineStatefulImpl : public WhisperPipeline::Whi
     : WhisperPipelineImplBase{models_path},
       m_sampler(m_tokenizer) {
     ov::Core core = utils::singleton_core();
+    ov::CompiledModel compiled_model;
+    if (device == "NPU") {
+        auto encoder_model = core.read_model(models_path / "openvino_encoder_model.xml", {}, properties);
+        // NB: only batch_size == 1 is supported now for NPU
+        reshape_to_static_encoder(encoder_model, 1, m_feature_extractor.feature_size);
+        compiled_model = core.compile_model(encoder_model, "NPU", properties);
+    } else {
+        compiled_model = core.compile_model(models_path / "openvino_encoder_model.xml", device, properties);
+    }

-    ov::CompiledModel compiled_model =
-        core.compile_model(models_path / "openvino_encoder_model.xml", device, properties);
     ov::genai::utils::print_compiled_model_properties(compiled_model, "whisper encoder model");
     m_encoder = init_model(compiled_model);

-    m_decoder = WhisperDecoder::from_path(models_path, device, properties);
+    m_decoder = WhisperDecoder::from_path(models_path, device, properties, m_encoder.get_compiled_model().output("last_hidden_state").get_partial_shape());

     // If eos_token_id was not provided, take value
     if (m_generation_config.eos_token_id == -1) {
@@ -155,7 +181,13 @@ ov::genai::WhisperPipeline::WhisperPipeline(const std::filesystem::path& models_
                                             const ov::AnyMap& properties) {
     auto start_time = std::chrono::steady_clock::now();
     if (device == "NPU") {
-        m_impl = std::make_unique<StaticWhisperPipeline>(models_path, properties);
+        auto properties_copy = properties;
+        const bool use_static_pipeline = utils::pop_or_default(properties_copy, "STATIC_PIPELINE", true);
+        if (!use_static_pipeline) {
+            m_impl = std::make_unique<WhisperPipelineStatefulImpl>(models_path, device, properties_copy);
+        } else {
+            m_impl = std::make_unique<StaticWhisperPipeline>(models_path, properties_copy);
+        }
     } else {
         m_impl = std::make_unique<WhisperPipelineStatefulImpl>(models_path, device, properties);
     }
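
To round this out, a hedged usage sketch (not part of the commit): opting out of the static Whisper pipeline on NPU via the new STATIC_PIPELINE property, which pop_or_default() reads as a boolean defaulting to true. The model directory is hypothetical and the audio is placeholder silence; see the whisper_speech_recognition sample for real input handling.

    #include <iostream>

    #include "openvino/genai/whisper_pipeline.hpp"

    int main() {
        // STATIC_PIPELINE = false routes NPU to WhisperPipelineStatefulImpl
        // instead of the default StaticWhisperPipeline.
        ov::genai::WhisperPipeline pipeline("whisper-tiny", "NPU", {{"STATIC_PIPELINE", false}});

        // RawSpeechInput is a std::vector<float> of 16 kHz PCM samples;
        // two seconds of silence here as a stand-in.
        ov::genai::RawSpeechInput raw_speech(16000 * 2, 0.0f);
        std::cout << pipeline.generate(raw_speech) << "\n";
    }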
