
Fix Qwen2AudioForConditionalGeneration.forward() and test_flash_attn_kernels_inference_equivalence #39503


Merged: 11 commits, Jul 28, 2025
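
For context, the cached-decoding path that the forward() fix targets is exercised by an ordinary generate() call on Qwen2-Audio. A minimal sketch follows; the checkpoint name, prompt template, silent dummy waveform, and the processor keyword (audios vs. audio, which varies across transformers versions) are illustrative assumptions, not part of this PR:

import numpy as np
from transformers import AutoProcessor, Qwen2AudioForConditionalGeneration

model_id = "Qwen/Qwen2-Audio-7B-Instruct"  # assumed checkpoint
processor = AutoProcessor.from_pretrained(model_id)
model = Qwen2AudioForConditionalGeneration.from_pretrained(model_id)

audio = np.zeros(16000, dtype=np.float32)  # 1 s of silence stands in for a real clip
prompt = "<|audio_bos|><|AUDIO|><|audio_eos|>Describe the audio."
inputs = processor(text=prompt, audios=audio, sampling_rate=16000, return_tensors="pt")

# The prefill step consumes input_features to build the audio embeddings; with the
# override added below, later cached decoding steps no longer re-pass them and only
# cache_position advances.
output_ids = model.generate(**inputs, max_new_tokens=20)
print(processor.batch_decode(output_ids, skip_special_tokens=True)[0])
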
17 changes: 16 additions & 1 deletion src/transformers/models/qwen2_audio/modeling_qwen2_audio.py
@@ -19,7 +19,6 @@
 from typing import Callable, Optional, Union

 import torch
-import torch.utils.checkpoint
 from torch import nn

 from ...activations import ACT2FN
@@ -727,6 +726,7 @@ def forward(
         output_attentions: Optional[bool] = None,
         output_hidden_states: Optional[bool] = None,
         return_dict: Optional[bool] = None,
+        cache_position: Optional[torch.LongTensor] = None,
     ) -> Union[tuple, Qwen2AudioCausalLMOutputWithPast]:
         r"""
         feature_attention_mask (`torch.Tensor` of shape `(batch_size, feature_sequence_length)`):
@@ -845,6 +845,7 @@ def forward(
             output_attentions=output_attentions,
             output_hidden_states=output_hidden_states,
             return_dict=return_dict,
+            cache_position=cache_position,
         )

         logits = outputs[0]
@@ -878,5 +879,19 @@ def forward(
             attention_mask=attention_mask,
         )

+    def prepare_inputs_for_generation(self, *args, **kwargs):
+        # Overwritten -- we should not pass input_features when we are in cached decoding stage
+
+        input_features = kwargs.pop("input_features", None)
+        cache_position = kwargs.get("cache_position")
+
+        model_inputs = super().prepare_inputs_for_generation(*args, **kwargs)
+
+        if cache_position is not None and cache_position[0] == 0:
+            # input_features should only be passed when we are not in cached decoding stage
+            model_inputs["input_features"] = input_features
+
+        return model_inputs
+

 __all__ = ["Qwen2AudioForConditionalGeneration", "Qwen2AudioPreTrainedModel", "Qwen2AudioEncoder"]
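
Restated outside the model for clarity, the gate added in prepare_inputs_for_generation reduces to the check below (a standalone sketch with a hypothetical helper name, not library API):

import torch

def should_pass_input_features(cache_position: torch.LongTensor) -> bool:
    # Prefill: cache_position starts at 0, so the audio features still need to be
    # encoded and merged into the prompt embeddings.
    # Cached decoding: cache_position[0] > 0, the audio is already represented in
    # the key/value cache, so input_features must not be passed again.
    return cache_position is not None and cache_position[0].item() == 0

assert should_pass_input_features(torch.arange(12))        # prefill step
assert not should_pass_input_features(torch.tensor([12]))  # later decoding step
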
5 changes: 2 additions & 3 deletions tests/models/qwen2_audio/test_modeling_qwen2_audio.py
@@ -34,6 +34,7 @@
     torch_device,
 )

+from ...generation.test_utils import GenerationTesterMixin
 from ...test_configuration_common import ConfigTester
 from ...test_modeling_common import ModelTesterMixin, floats_tensor, ids_tensor

@@ -132,14 +133,12 @@ def prepare_config_and_inputs_for_common(self):


 @require_torch
-class Qwen2AudioForConditionalGenerationModelTest(ModelTesterMixin, unittest.TestCase):
+class Qwen2AudioForConditionalGenerationModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase):
     """
     Model tester for `Qwen2AudioForConditionalGeneration`.
     """

     all_model_classes = (Qwen2AudioForConditionalGeneration,) if is_torch_available() else ()
-    # Doesn't run generation tests. TODO eustache/joao: some generation tests are broken, the errors seem cache-related
-    all_generative_model_classes = ()
     test_pruning = False
     test_head_masking = False
     _is_composite = True
1 change: 1 addition & 0 deletions tests/test_modeling_common.py
@@ -3484,6 +3484,7 @@ def flash_attn_inference_equivalence(self, attn_implementation: str, padding_side: str):
         model = model_class(config)

         model.to(torch_device)
+        model.to(torch.bfloat16)
         dummy_input = inputs_dict[model.main_input_name][:1]
         if dummy_input.dtype in [torch.float32, torch.float16]:
             dummy_input = dummy_input.to(torch.bfloat16)
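
The intent of the added cast, restated as a small helper (a sketch only; the helper name is hypothetical): flash-attention kernels only run on fp16/bf16 activations, so the weights are cast to bfloat16 up front, and only floating-point inputs are cast to match, leaving integer inputs such as input_ids untouched.

import torch

def prepare_for_flash_attn_check(model, inputs_dict):
    # Mirror of the test setup: move to the accelerator, then cast the weights so
    # that flash-attention kernels (which require fp16/bf16) can actually run.
    model.to("cuda")
    model.to(torch.bfloat16)
    dummy_input = inputs_dict[model.main_input_name][:1]
    if dummy_input.dtype in (torch.float32, torch.float16):
        # Float inputs (e.g. audio features) follow the weight dtype; integer
        # inputs such as input_ids are left as-is.
        dummy_input = dummy_input.to(torch.bfloat16)
    return model, dummy_input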