From 69a6e4b9163f9a91c70743d42a5728d81e412d97 Mon Sep 17 00:00:00 2001 From: Eric B Date: Fri, 18 Jul 2025 16:34:43 +0200 Subject: [PATCH 1/5] Add missing cache_position argument. --- src/transformers/models/qwen2_audio/modeling_qwen2_audio.py | 1 + 1 file changed, 1 insertion(+) diff --git a/src/transformers/models/qwen2_audio/modeling_qwen2_audio.py b/src/transformers/models/qwen2_audio/modeling_qwen2_audio.py index fb9d013eeff3..7fbb68e521da 100644 --- a/src/transformers/models/qwen2_audio/modeling_qwen2_audio.py +++ b/src/transformers/models/qwen2_audio/modeling_qwen2_audio.py @@ -727,6 +727,7 @@ def forward( output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, ) -> Union[tuple, Qwen2AudioCausalLMOutputWithPast]: r""" feature_attention_mask (`torch.Tensor` of shape `(batch_size, feature_sequence_length)`): From 73af10772def2195d4aa97e586dd2223670a5974 Mon Sep 17 00:00:00 2001 From: Eric B Date: Mon, 21 Jul 2025 10:22:33 +0200 Subject: [PATCH 2/5] Pass cache_position to language model. --- src/transformers/models/qwen2_audio/modeling_qwen2_audio.py | 1 + tests/models/qwen2_audio/test_modeling_qwen2_audio.py | 2 -- 2 files changed, 1 insertion(+), 2 deletions(-) diff --git a/src/transformers/models/qwen2_audio/modeling_qwen2_audio.py b/src/transformers/models/qwen2_audio/modeling_qwen2_audio.py index 7fbb68e521da..8cb5fde9c23f 100644 --- a/src/transformers/models/qwen2_audio/modeling_qwen2_audio.py +++ b/src/transformers/models/qwen2_audio/modeling_qwen2_audio.py @@ -846,6 +846,7 @@ def forward( output_attentions=output_attentions, output_hidden_states=output_hidden_states, return_dict=return_dict, + cache_position=cache_position, ) logits = outputs[0] diff --git a/tests/models/qwen2_audio/test_modeling_qwen2_audio.py b/tests/models/qwen2_audio/test_modeling_qwen2_audio.py index 571ac0737081..ad7339799274 100644 --- a/tests/models/qwen2_audio/test_modeling_qwen2_audio.py +++ b/tests/models/qwen2_audio/test_modeling_qwen2_audio.py @@ -138,8 +138,6 @@ class Qwen2AudioForConditionalGenerationModelTest(ModelTesterMixin, unittest.Tes """ all_model_classes = (Qwen2AudioForConditionalGeneration,) if is_torch_available() else () - # Doesn't run generation tests. TODO eustache/joao: some generation tests are broken, the errors seem cache-related - all_generative_model_classes = () test_pruning = False test_head_masking = False _is_composite = True From d46f20c64bfa2a0a41d90c9a1107ade6094752e9 Mon Sep 17 00:00:00 2001 From: Eric B Date: Tue, 22 Jul 2025 19:02:25 +0200 Subject: [PATCH 3/5] Overwrite prepare_inputs_for_generation. --- .../models/qwen2_audio/modeling_qwen2_audio.py | 15 ++++++++++++++- .../qwen2_audio/test_modeling_qwen2_audio.py | 3 ++- 2 files changed, 16 insertions(+), 2 deletions(-) diff --git a/src/transformers/models/qwen2_audio/modeling_qwen2_audio.py b/src/transformers/models/qwen2_audio/modeling_qwen2_audio.py index 8cb5fde9c23f..b4d1f41f3ec2 100644 --- a/src/transformers/models/qwen2_audio/modeling_qwen2_audio.py +++ b/src/transformers/models/qwen2_audio/modeling_qwen2_audio.py @@ -19,7 +19,6 @@ from typing import Callable, Optional, Union import torch -import torch.utils.checkpoint from torch import nn from ...activations import ACT2FN @@ -880,5 +879,19 @@ def forward( attention_mask=attention_mask, ) + def prepare_inputs_for_generation(self, *args, **kwargs): + # Overwritten -- we should not pass input_features when we are in cached decoding stage + + input_features = kwargs.pop("input_features", None) + cache_position = kwargs.get("cache_position") + + model_inputs = super().prepare_inputs_for_generation(*args, **kwargs) + + if cache_position is not None and cache_position[0] == 0: + # input_features should only be passed when we are not in cached decoding stage + model_inputs["input_features"] = input_features + + return model_inputs + __all__ = ["Qwen2AudioForConditionalGeneration", "Qwen2AudioPreTrainedModel", "Qwen2AudioEncoder"] diff --git a/tests/models/qwen2_audio/test_modeling_qwen2_audio.py b/tests/models/qwen2_audio/test_modeling_qwen2_audio.py index ad7339799274..4533fbbf99d8 100644 --- a/tests/models/qwen2_audio/test_modeling_qwen2_audio.py +++ b/tests/models/qwen2_audio/test_modeling_qwen2_audio.py @@ -34,6 +34,7 @@ torch_device, ) +from ...generation.test_utils import GenerationTesterMixin from ...test_configuration_common import ConfigTester from ...test_modeling_common import ModelTesterMixin, floats_tensor, ids_tensor @@ -132,7 +133,7 @@ def prepare_config_and_inputs_for_common(self): @require_torch -class Qwen2AudioForConditionalGenerationModelTest(ModelTesterMixin, unittest.TestCase): +class Qwen2AudioForConditionalGenerationModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase): """ Model tester for `Qwen2AudioForConditionalGeneration`. """ From 3a3a470e613c80d90cad46010a8650960d1b78f5 Mon Sep 17 00:00:00 2001 From: Eric B Date: Mon, 28 Jul 2025 15:05:31 +0200 Subject: [PATCH 4/5] Set model to half precision for Flash Attention test. --- tests/test_modeling_common.py | 1 + 1 file changed, 1 insertion(+) diff --git a/tests/test_modeling_common.py b/tests/test_modeling_common.py index 38c581992b31..51ab326d3940 100755 --- a/tests/test_modeling_common.py +++ b/tests/test_modeling_common.py @@ -3484,6 +3484,7 @@ def flash_attn_inference_equivalence(self, attn_implementation: str, padding_sid model = model_class(config) model.to(torch_device) + model.half() dummy_input = inputs_dict[model.main_input_name][:1] if dummy_input.dtype in [torch.float32, torch.float16]: dummy_input = dummy_input.to(torch.bfloat16) From e06ba17a97b634d8fc5f4cace151e2a35049abcf Mon Sep 17 00:00:00 2001 From: Eric B Date: Mon, 28 Jul 2025 15:30:01 +0200 Subject: [PATCH 5/5] Cast model to bfloat16. --- tests/test_modeling_common.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test_modeling_common.py b/tests/test_modeling_common.py index 51ab326d3940..754e267662b9 100755 --- a/tests/test_modeling_common.py +++ b/tests/test_modeling_common.py @@ -3484,7 +3484,7 @@ def flash_attn_inference_equivalence(self, attn_implementation: str, padding_sid model = model_class(config) model.to(torch_device) - model.half() + model.to(torch.bfloat16) dummy_input = inputs_dict[model.main_input_name][:1] if dummy_input.dtype in [torch.float32, torch.float16]: dummy_input = dummy_input.to(torch.bfloat16)