diff --git a/src/transformers/models/qwen2_audio/modeling_qwen2_audio.py b/src/transformers/models/qwen2_audio/modeling_qwen2_audio.py
index fb9d013eeff3..b4d1f41f3ec2 100644
--- a/src/transformers/models/qwen2_audio/modeling_qwen2_audio.py
+++ b/src/transformers/models/qwen2_audio/modeling_qwen2_audio.py
@@ -19,7 +19,6 @@
 from typing import Callable, Optional, Union
 
 import torch
-import torch.utils.checkpoint
 from torch import nn
 
 from ...activations import ACT2FN
@@ -727,6 +726,7 @@ def forward(
         output_attentions: Optional[bool] = None,
         output_hidden_states: Optional[bool] = None,
         return_dict: Optional[bool] = None,
+        cache_position: Optional[torch.LongTensor] = None,
     ) -> Union[tuple, Qwen2AudioCausalLMOutputWithPast]:
         r"""
         feature_attention_mask (`torch.Tensor` of shape `(batch_size, feature_sequence_length)`):
@@ -845,6 +845,7 @@ def forward(
             output_attentions=output_attentions,
             output_hidden_states=output_hidden_states,
             return_dict=return_dict,
+            cache_position=cache_position,
         )
 
         logits = outputs[0]
@@ -878,5 +879,19 @@ def forward(
             attention_mask=attention_mask,
         )
 
+    def prepare_inputs_for_generation(self, *args, **kwargs):
+        # Overwritten -- we should not pass input_features when we are in cached decoding stage
+
+        input_features = kwargs.pop("input_features", None)
+        cache_position = kwargs.get("cache_position")
+
+        model_inputs = super().prepare_inputs_for_generation(*args, **kwargs)
+
+        if cache_position is not None and cache_position[0] == 0:
+            # input_features should only be passed when we are not in cached decoding stage
+            model_inputs["input_features"] = input_features
+
+        return model_inputs
+
 
 __all__ = ["Qwen2AudioForConditionalGeneration", "Qwen2AudioPreTrainedModel", "Qwen2AudioEncoder"]
diff --git a/tests/models/qwen2_audio/test_modeling_qwen2_audio.py b/tests/models/qwen2_audio/test_modeling_qwen2_audio.py
index 571ac0737081..4533fbbf99d8 100644
--- a/tests/models/qwen2_audio/test_modeling_qwen2_audio.py
+++ b/tests/models/qwen2_audio/test_modeling_qwen2_audio.py
@@ -34,6 +34,7 @@
     torch_device,
 )
 
+from ...generation.test_utils import GenerationTesterMixin
 from ...test_configuration_common import ConfigTester
 from ...test_modeling_common import ModelTesterMixin, floats_tensor, ids_tensor
 
@@ -132,14 +133,12 @@ def prepare_config_and_inputs_for_common(self):
 
 
 @require_torch
-class Qwen2AudioForConditionalGenerationModelTest(ModelTesterMixin, unittest.TestCase):
+class Qwen2AudioForConditionalGenerationModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase):
     """
     Model tester for `Qwen2AudioForConditionalGeneration`.
     """
 
     all_model_classes = (Qwen2AudioForConditionalGeneration,) if is_torch_available() else ()
-    # Doesn't run generation tests. TODO eustache/joao: some generation tests are broken, the errors seem cache-related
-    all_generative_model_classes = ()
     test_pruning = False
     test_head_masking = False
     _is_composite = True
diff --git a/tests/test_modeling_common.py b/tests/test_modeling_common.py
index 38c581992b31..754e267662b9 100755
--- a/tests/test_modeling_common.py
+++ b/tests/test_modeling_common.py
@@ -3484,6 +3484,7 @@ def flash_attn_inference_equivalence(self, attn_implementation: str, padding_sid
             model = model_class(config)
 
             model.to(torch_device)
+            model.to(torch.bfloat16)
             dummy_input = inputs_dict[model.main_input_name][:1]
             if dummy_input.dtype in [torch.float32, torch.float16]:
                 dummy_input = dummy_input.to(torch.bfloat16)
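
For reference, a minimal, self-contained sketch (not part of the patch) of the behaviour the `prepare_inputs_for_generation` override enforces: `input_features` are forwarded only on the prefill step (`cache_position[0] == 0`) and dropped on cached decoding steps. The helper name `prepare_inputs_sketch` and the dummy feature shape below are illustrative assumptions, not code from the repository.

```python
import torch


def prepare_inputs_sketch(input_features, cache_position, **kwargs):
    # Stand-in for super().prepare_inputs_for_generation(...): just collect the kwargs.
    model_inputs = dict(kwargs)
    # Only attach audio features on the first (prefill) step; later steps reuse the KV cache.
    if cache_position is not None and cache_position[0] == 0:
        model_inputs["input_features"] = input_features
    return model_inputs


features = torch.randn(1, 128, 3000)  # dummy mel features, shape chosen for illustration
prefill = prepare_inputs_sketch(features, cache_position=torch.arange(10))   # first forward pass
decode = prepare_inputs_sketch(features, cache_position=torch.tensor([10]))  # cached decoding step

assert "input_features" in prefill
assert "input_features" not in decode
```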