Commit a303604

Fixed review comments
1 parent 04c89a2 commit a303604

3 files changed (+47, -10 lines)

src/cpp/src/whisper/models/with_past_decoder.cpp

Lines changed: 4 additions & 6 deletions
@@ -112,12 +112,10 @@ void WhisperWithPastDecoder::start_async(const Tensor& encoder_hidden_state,
     _set_encoder_hidden_states_tensor(encoder_hidden_state, batch_size, request);
     request.set_tensor("input_ids", input_ids);
 
-    if (!is_initial_step) {
-        if (m_past_decoder_has_cache_position) {
-            ov::Tensor cache_position_tensor = request.get_tensor("cache_position");
-            cache_position_tensor.set_shape({1});
-            cache_position_tensor.data<int64_t>()[0] = m_cache_position;
-        }
+    if (!is_initial_step && m_past_decoder_has_cache_position) {
+        ov::Tensor cache_position_tensor = request.get_tensor("cache_position");
+        cache_position_tensor.set_shape({1});
+        cache_position_tensor.data<int64_t>()[0] = m_cache_position;
     }
 
     _set_past_key_value(beam_idx);
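Note on the change above: `&&` short-circuits, so `m_past_decoder_has_cache_position` is still only checked on non-initial steps, and the merged guard is behaviorally identical to the old nested ifs. A minimal Python sketch of the same stepping contract against a plain OpenVINO `InferRequest` (the tensor names come from the diff; the helper and its arguments are hypothetical):

```python
import numpy as np
import openvino as ov

def run_decoder_step(request, input_ids, cache_position,
                     is_initial_step, has_cache_position):
    """Hypothetical analogue of WhisperWithPastDecoder::start_async."""
    request.set_tensor("input_ids", ov.Tensor(input_ids))
    # Same merged guard as the C++ change: short-circuit evaluation keeps
    # the cache_position update off the initial decode step.
    if not is_initial_step and has_cache_position:
        position = np.array([cache_position], dtype=np.int64)  # shape {1}
        request.set_tensor("cache_position", ov.Tensor(position))
    request.infer()  # the C++ method is start_async(); infer() keeps the sketch simple
```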

src/cpp/src/whisper/pipeline_static.cpp

Lines changed: 1 addition & 1 deletion
@@ -580,7 +580,7 @@ void add_cache_position_input(std::shared_ptr<ov::Model> model) {
     cache_position->set_friendly_name("cache_position");
     model->add_parameters({cache_position});
     std::shared_ptr<ov::Node> cache_pos_unsqueeze_arg;
-    if (unsqueeze_node->input(0).get_element_type() == ov::element::f32) {
+    if (matched_unsqueeze->input(0).get_element_type() == ov::element::f32) {
         cache_pos_unsqueeze_arg = std::make_shared<v0::Convert>(cache_position, ov::element::f32);
     } else {
         cache_pos_unsqueeze_arg = cache_position;
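The one-line fix makes the element-type check read from `matched_unsqueeze`, the node the pattern matcher found, instead of `unsqueeze_node`. For orientation, a rough Python analogue of the fixed transformation, assuming the OpenVINO graph-building API; `matched_unsqueeze` stands in for the matched node, and the opset import path varies across OpenVINO releases:

```python
import openvino as ov
from openvino.runtime import opset13 as ops  # import path varies by OV version

def add_cache_position_input(model, matched_unsqueeze):
    # Expose a new 1-D i64 "cache_position" input on the model.
    cache_position = ops.parameter([1], ov.Type.i64, name="cache_position")
    cache_position.set_friendly_name("cache_position")
    model.add_parameters([cache_position])
    # Mirror the fixed C++ branch: query the *matched* node's input type,
    # inserting a Convert only when the consumer expects f32.
    if matched_unsqueeze.input(0).get_element_type() == ov.Type.f32:
        arg = ops.convert(cache_position, ov.Type.f32)
    else:
        arg = cache_position
    matched_unsqueeze.input(0).replace_source_output(arg.output(0))
    model.validate_nodes_and_infer_types()
```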

tests/python_tests/test_whisper_pipeline.py

Lines changed: 42 additions & 3 deletions
@@ -55,11 +55,13 @@ def get_whisper_models_list(tiny_only=False):
 # used whisper models are relatively small
 # cache them in memory to speedup tests
 @functools.lru_cache()
-def read_whisper_model(params):
+def read_whisper_model(params, stateful=True):
     model_id, path = params
+    if not stateful:
+        path = pathlib.Path(f"{path}_with_past")
 
     if not (path / "openvino_encoder_model.xml").exists():
-        save_model(model_id=model_id, tmp_path=path)
+        save_model(model_id=model_id, tmp_path=path, stateful=stateful)
 
     opt_model = retry_request(lambda: OVModelForSpeechSeq2Seq.from_pretrained(
         path,
@@ -91,7 +93,7 @@ def read_whisper_model(params):
     )
 
 
-def save_model(model_id: str, tmp_path: pathlib.Path):
+def save_model(model_id: str, tmp_path: pathlib.Path, stateful=True):
     tokenizer = retry_request(lambda: AutoTokenizer.from_pretrained(model_id, trust_remote_code=True))
     ov_tokenizer, ov_detokenizer = openvino_tokenizers.convert_tokenizer(
         tokenizer,
@@ -109,6 +111,7 @@ def save_model(model_id: str, tmp_path: pathlib.Path):
         model_id,
         export=True,
         trust_remote_code=True,
+        stateful=stateful,
         compile=False,
         device="CPU",
         load_in_8bit=False,
@@ -223,6 +226,9 @@ def run_pipeline_with_ref(
     streamer: typing.Callable[[str], bool] | None = None,
 ):
     _, _, hf_pipe, genai_pipe = read_whisper_model((model_id, tmp_path))
+    _, _, _, genai_with_past_pipe = read_whisper_model(
+        (model_id, tmp_path), stateful=False
+    )
 
     if type(sample) is np.ndarray and len(sample.shape) == 1:
         sample = np.expand_dims(sample, 0)
@@ -233,6 +239,12 @@
 
     compare_results(hf_result, genai_result)
 
+    genai_with_past_result = run_genai(
+        genai_with_past_pipe, _sample, generation_config, streamer
+    )
+
+    compare_results(hf_result, genai_with_past_result)
+
 
 def compare_results(hf_result, genai_result):
     assert genai_result.texts[0] == hf_result["text"]
@@ -498,6 +510,33 @@ def test_longform_audio(model_descr, sample_from_dataset):
     assert "".join(streamer_result) == hf_result["text"]
 
 
+@pytest.mark.parametrize("model_descr", get_whisper_models_list())
+@pytest.mark.parametrize("sample_from_dataset", [*get_fixture_params_for_n_whisper_dataset_samples(n=2, long_form=True)], indirect=True)
+@pytest.mark.precommit
+@pytest.mark.xfail(condition=(sys.platform == "darwin"), reason="Ticket - 173169")
+def test_longform_audio_with_past(model_descr, sample_from_dataset):
+    _, _, hf_pipe, genai_pipe = read_whisper_model(model_descr, stateful=True)
+
+    streamer_result = []
+
+    genai_result = run_genai(
+        genai_pipe,
+        sample_from_dataset,
+        config=ov_genai.WhisperGenerationConfig(return_timestamps=True),
+        streamer=lambda x: streamer_result.append(x),
+    )
+
+    hf_result = run_huggingface(
+        hf_pipe,
+        sample_from_dataset,
+        config=ov_genai.WhisperGenerationConfig(return_timestamps=True),
+    )
+
+    compare_results(hf_result, genai_result)
+
+    assert "".join(streamer_result) == hf_result["text"]
+
+
 @pytest.mark.parametrize("model_descr", get_whisper_models_list())
 @pytest.mark.precommit
 @pytest.mark.xfail(condition=(sys.platform == "darwin"), reason="Ticket - 173169")
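The `stateful` flag threaded through these helpers maps directly onto optimum-intel's export option. A hedged sketch of the two exports the fixtures now produce, assuming optimum-intel is installed; the model id and output paths are illustrative:

```python
from optimum.intel.openvino import OVModelForSpeechSeq2Seq

model_id = "openai/whisper-tiny"
for stateful, path in [(True, "whisper-tiny"),
                       (False, "whisper-tiny_with_past")]:
    # stateful=True exports the default stateful decoder; stateful=False
    # exports the legacy "with past" decoder pair exercised by the new tests.
    model = OVModelForSpeechSeq2Seq.from_pretrained(
        model_id,
        export=True,
        stateful=stateful,
        compile=False,
        load_in_8bit=False,
    )
    model.save_pretrained(path)
```

Both layouts should transcribe identically; the updated `run_pipeline_with_ref` asserts exactly that by comparing each against the Hugging Face reference.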

0 commit comments
