Skip to content

Commit fbda9d8

Browse files
committed
Fixed review comments
1 parent 04c89a2 commit fbda9d8

File tree

3 files changed

+46
-17
lines changed

3 files changed

+46
-17
lines changed

src/cpp/src/whisper/models/with_past_decoder.cpp

Lines changed: 4 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -112,12 +112,10 @@ void WhisperWithPastDecoder::start_async(const Tensor& encoder_hidden_state,
112112
_set_encoder_hidden_states_tensor(encoder_hidden_state, batch_size, request);
113113
request.set_tensor("input_ids", input_ids);
114114

115-
if (!is_initial_step) {
116-
if (m_past_decoder_has_cache_position) {
117-
ov::Tensor cache_position_tensor = request.get_tensor("cache_position");
118-
cache_position_tensor.set_shape({1});
119-
cache_position_tensor.data<int64_t>()[0] = m_cache_position;
120-
}
115+
if (!is_initial_step && m_past_decoder_has_cache_position) {
116+
ov::Tensor cache_position_tensor = request.get_tensor("cache_position");
117+
cache_position_tensor.set_shape({1});
118+
cache_position_tensor.data<int64_t>()[0] = m_cache_position;
121119
}
122120

123121
_set_past_key_value(beam_idx);

src/cpp/src/whisper/pipeline_static.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -580,7 +580,7 @@ void add_cache_position_input(std::shared_ptr<ov::Model> model) {
580580
cache_position->set_friendly_name("cache_position");
581581
model->add_parameters({cache_position});
582582
std::shared_ptr<ov::Node> cache_pos_unsqueeze_arg;
583-
if (unsqueeze_node->input(0).get_element_type() == ov::element::f32) {
583+
if (matched_unsqueeze->input(0).get_element_type() == ov::element::f32) {
584584
cache_pos_unsqueeze_arg = std::make_shared<v0::Convert>(cache_position, ov::element::f32);
585585
} else {
586586
cache_pos_unsqueeze_arg = cache_position;

tests/python_tests/test_whisper_pipeline.py

Lines changed: 41 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,6 @@
44
import openvino_genai as ov_genai
55
import functools
66
import pytest
7-
import sys
87
import openvino_tokenizers
98
import openvino
109
import datasets
@@ -55,11 +54,13 @@ def get_whisper_models_list(tiny_only=False):
5554
# used whisper models are relatively small
5655
# cache them in memory to speedup tests
5756
@functools.lru_cache()
58-
def read_whisper_model(params):
57+
def read_whisper_model(params, stateful=True):
5958
model_id, path = params
59+
if not stateful:
60+
path = pathlib.Path(f"{path}_with_past")
6061

6162
if not (path / "openvino_encoder_model.xml").exists():
62-
save_model(model_id=model_id, tmp_path=path)
63+
save_model(model_id=model_id, tmp_path=path, stateful=stateful)
6364

6465
opt_model = retry_request(lambda: OVModelForSpeechSeq2Seq.from_pretrained(
6566
path,
@@ -91,7 +92,7 @@ def read_whisper_model(params):
9192
)
9293

9394

94-
def save_model(model_id: str, tmp_path: pathlib.Path):
95+
def save_model(model_id: str, tmp_path: pathlib.Path, stateful=True):
9596
tokenizer = retry_request(lambda: AutoTokenizer.from_pretrained(model_id, trust_remote_code=True))
9697
ov_tokenizer, ov_detokenizer = openvino_tokenizers.convert_tokenizer(
9798
tokenizer,
@@ -109,6 +110,7 @@ def save_model(model_id: str, tmp_path: pathlib.Path):
109110
model_id,
110111
export=True,
111112
trust_remote_code=True,
113+
stateful=stateful,
112114
compile=False,
113115
device="CPU",
114116
load_in_8bit=False,
@@ -223,6 +225,9 @@ def run_pipeline_with_ref(
223225
streamer: typing.Callable[[str], bool] | None = None,
224226
):
225227
_, _, hf_pipe, genai_pipe = read_whisper_model((model_id, tmp_path))
228+
_, _, _, genai_with_past_pipe = read_whisper_model(
229+
(model_id, tmp_path), stateful=False
230+
)
226231

227232
if type(sample) is np.ndarray and len(sample.shape) == 1:
228233
sample = np.expand_dims(sample, 0)
@@ -233,6 +238,12 @@ def run_pipeline_with_ref(
233238

234239
compare_results(hf_result, genai_result)
235240

241+
genai_with_past_result = run_genai(
242+
genai_with_past_pipe, _sample, generation_config, streamer
243+
)
244+
245+
compare_results(hf_result, genai_with_past_result)
246+
236247

237248
def compare_results(hf_result, genai_result):
238249
assert genai_result.texts[0] == hf_result["text"]
@@ -446,7 +457,6 @@ def test_language_autodetect(model_descr, sample_from_dataset):
446457
@pytest.mark.parametrize("model_descr", get_whisper_models_list(tiny_only=True))
447458
@pytest.mark.parametrize("sample_from_dataset", [*get_fixture_params_for_n_whisper_dataset_samples(n=1)], indirect=True)
448459
@pytest.mark.precommit
449-
@pytest.mark.xfail(condition=(sys.platform == "darwin"), reason="Ticket - 173169")
450460
def test_return_timestamps_short_form(model_descr, sample_from_dataset):
451461
run_pipeline_with_ref(
452462
model_id=model_descr[0],
@@ -459,7 +469,6 @@ def test_return_timestamps_short_form(model_descr, sample_from_dataset):
459469
@pytest.mark.parametrize("model_descr", get_whisper_models_list(tiny_only=True))
460470
@pytest.mark.parametrize("sample_from_dataset", [*get_fixture_params_for_n_whisper_dataset_samples(n=1)], indirect=True)
461471
@pytest.mark.precommit
462-
@pytest.mark.xfail(condition=(sys.platform == "darwin"), reason="Ticket - 173169")
463472
def test_return_timestamps_max_new_tokens_short_form(model_descr, sample_from_dataset):
464473
run_pipeline_with_ref(
465474
model_id=model_descr[0],
@@ -474,7 +483,6 @@ def test_return_timestamps_max_new_tokens_short_form(model_descr, sample_from_da
474483
@pytest.mark.parametrize("model_descr", get_whisper_models_list())
475484
@pytest.mark.parametrize("sample_from_dataset", [*get_fixture_params_for_n_whisper_dataset_samples(n=10, long_form=True)], indirect=True)
476485
@pytest.mark.precommit
477-
@pytest.mark.xfail(condition=(sys.platform == "darwin"), reason="Ticket - 173169")
478486
def test_longform_audio(model_descr, sample_from_dataset):
479487
_, _, hf_pipe, genai_pipe = read_whisper_model(model_descr)
480488

@@ -498,9 +506,34 @@ def test_longform_audio(model_descr, sample_from_dataset):
498506
assert "".join(streamer_result) == hf_result["text"]
499507

500508

509+
@pytest.mark.parametrize("model_descr", get_whisper_models_list())
@pytest.mark.parametrize("sample_from_dataset", [*get_fixture_params_for_n_whisper_dataset_samples(n=2, long_form=True)], indirect=True)
@pytest.mark.precommit
def test_longform_audio_with_past(model_descr, sample_from_dataset):
    """Check long-form transcription on the non-stateful ("with past") export.

    Runs the GenAI pipeline and the Hugging Face reference pipeline on the
    same long-form sample with timestamps enabled, then checks that both the
    final result and the concatenated streamer chunks match the reference.
    """
    # stateful=False is what makes read_whisper_model use the
    # "<path>_with_past" export. Passing stateful=True (the default) would
    # load the stateful model and merely repeat test_longform_audio.
    _, _, hf_pipe, genai_pipe = read_whisper_model(model_descr, stateful=False)

    # Collects every chunk handed to the streamer callback, in order.
    streamer_result = []

    genai_result = run_genai(
        genai_pipe,
        sample_from_dataset,
        config=ov_genai.WhisperGenerationConfig(return_timestamps=True),
        streamer=lambda x: streamer_result.append(x),
    )

    hf_result = run_huggingface(
        hf_pipe,
        sample_from_dataset,
        config=ov_genai.WhisperGenerationConfig(return_timestamps=True),
    )

    compare_results(hf_result, genai_result)

    # The streamed chunks, joined, must reproduce the full reference text.
    assert "".join(streamer_result) == hf_result["text"]
533+
534+
501535
@pytest.mark.parametrize("model_descr", get_whisper_models_list())
502536
@pytest.mark.precommit
503-
@pytest.mark.xfail(condition=(sys.platform == "darwin"), reason="Ticket - 173169")
504537
def test_shortform(model_descr):
505538
samples = []
506539
ds = datasets.load_dataset(
@@ -520,7 +553,6 @@ def test_shortform(model_descr):
520553
@pytest.mark.parametrize("model_descr", get_whisper_models_list(tiny_only=True))
521554
@pytest.mark.parametrize("sample_from_dataset", [*get_fixture_params_for_n_whisper_dataset_samples(n=2, long_form=True)], indirect=True)
522555
@pytest.mark.precommit
523-
@pytest.mark.xfail(condition=(sys.platform == "darwin"), reason="Ticket - 173169")
524556
def test_beam_search(model_descr, sample_from_dataset):
525557
# use only 30 seconds of audio due to beam search results wrong with enabled timestamps
526558
# ticket: 167239
@@ -599,7 +631,6 @@ def test_random_sampling(model_descr, sample_from_dataset):
599631
@pytest.mark.parametrize("model_descr", get_whisper_models_list(tiny_only=True))
600632
@pytest.mark.parametrize("sample_from_dataset", [{"language" : "en", "sample_id": 0}], indirect=True)
601633
@pytest.mark.precommit
602-
@pytest.mark.xfail(condition=(sys.platform == "darwin"), reason="Ticket - 173169")
603634
def test_perf_metrics(model_descr, sample_from_dataset):
604635
model_id, path, hf_pipe, genai_pipe = read_whisper_model(model_descr)
605636

0 commit comments

Comments
 (0)