Commit 7623aa3

Fix Qwen2AudioForConditionalGeneration.forward() and test_flash_attn_kernels_inference_equivalence (#39503)
* Add missing cache_position argument.
* Pass cache_position to language model.
* Overwrite prepare_inputs_for_generation.
* Set model to half precision for Flash Attention test.
* Cast model to bfloat16.
1 parent 28f2619 commit 7623aa3

3 files changed: 19 additions, 4 deletions

src/transformers/models/qwen2_audio/modeling_qwen2_audio.py

Lines changed: 16 additions & 1 deletion
@@ -19,7 +19,6 @@
 from typing import Callable, Optional, Union
 
 import torch
-import torch.utils.checkpoint
 from torch import nn
 
 from ...activations import ACT2FN
@@ -727,6 +726,7 @@ def forward(
         output_attentions: Optional[bool] = None,
         output_hidden_states: Optional[bool] = None,
         return_dict: Optional[bool] = None,
+        cache_position: Optional[torch.LongTensor] = None,
     ) -> Union[tuple, Qwen2AudioCausalLMOutputWithPast]:
         r"""
             feature_attention_mask (`torch.Tensor` of shape `(batch_size, feature_sequence_length)`):
@@ -845,6 +845,7 @@ def forward(
             output_attentions=output_attentions,
             output_hidden_states=output_hidden_states,
             return_dict=return_dict,
+            cache_position=cache_position,
         )
 
         logits = outputs[0]
@@ -878,5 +879,19 @@ def forward(
             attention_mask=attention_mask,
         )
 
+    def prepare_inputs_for_generation(self, *args, **kwargs):
+        # Overwritten -- we should not pass input_features when we are in cached decoding stage
+
+        input_features = kwargs.pop("input_features", None)
+        cache_position = kwargs.get("cache_position")
+
+        model_inputs = super().prepare_inputs_for_generation(*args, **kwargs)
+
+        if cache_position is not None and cache_position[0] == 0:
+            # input_features should only be passed when we are not in cached decoding stage
+            model_inputs["input_features"] = input_features
+
+        return model_inputs
+
 
 __all__ = ["Qwen2AudioForConditionalGeneration", "Qwen2AudioPreTrainedModel", "Qwen2AudioEncoder"]
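
A minimal standalone sketch of what the new override does during generation (not the commit's code; the helper name and tensor shapes are illustrative): on the prefill step cache_position starts at 0, so the audio features are forwarded, while later cached decoding steps drop them.

import torch

def prepare_inputs_sketch(cache_position, input_features, input_ids):
    # Mirrors the override above: keep input_features only for the first,
    # uncached forward pass.
    model_inputs = {"input_ids": input_ids, "cache_position": cache_position}
    if cache_position is not None and cache_position[0] == 0:
        model_inputs["input_features"] = input_features
    return model_inputs

# Prefill: positions 0..3, so the audio features are included.
prefill = prepare_inputs_sketch(
    torch.arange(4), torch.randn(1, 128, 3000), torch.tensor([[1, 2, 3, 4]])
)
assert "input_features" in prefill

# Cached decoding: only the new position (4) is processed, features are omitted.
decode = prepare_inputs_sketch(
    torch.tensor([4]), torch.randn(1, 128, 3000), torch.tensor([[5]])
)
assert "input_features" not in decode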

tests/models/qwen2_audio/test_modeling_qwen2_audio.py

Lines changed: 2 additions & 3 deletions
@@ -34,6 +34,7 @@
     torch_device,
 )
 
+from ...generation.test_utils import GenerationTesterMixin
 from ...test_configuration_common import ConfigTester
 from ...test_modeling_common import ModelTesterMixin, floats_tensor, ids_tensor
 
@@ -132,14 +133,12 @@ def prepare_config_and_inputs_for_common(self):
 
 
 @require_torch
-class Qwen2AudioForConditionalGenerationModelTest(ModelTesterMixin, unittest.TestCase):
+class Qwen2AudioForConditionalGenerationModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase):
     """
     Model tester for `Qwen2AudioForConditionalGeneration`.
     """
 
     all_model_classes = (Qwen2AudioForConditionalGeneration,) if is_torch_available() else ()
-    # Doesn't run generation tests. TODO eustache/joao: some generation tests are broken, the errors seem cache-related
-    all_generative_model_classes = ()
     test_pruning = False
     test_head_masking = False
     _is_composite = True

tests/test_modeling_common.py

Lines changed: 1 addition & 0 deletions
@@ -3484,6 +3484,7 @@ def flash_attn_inference_equivalence(self, attn_implementation: str, padding_sid
         model = model_class(config)
 
         model.to(torch_device)
+        model.to(torch.bfloat16)
         dummy_input = inputs_dict[model.main_input_name][:1]
         if dummy_input.dtype in [torch.float32, torch.float16]:
             dummy_input = dummy_input.to(torch.bfloat16)
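
The test change reflects a Flash Attention constraint: the kernels only support fp16/bf16, so the model must be cast to half precision before the equivalence check. A hedged usage sketch of the same requirement outside the test suite (the checkpoint id is a placeholder; a CUDA device and the flash-attn package are assumed):

import torch
from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained(
    "some/causal-lm-checkpoint",          # placeholder checkpoint id
    torch_dtype=torch.bfloat16,           # Flash Attention 2 requires half precision
    attn_implementation="flash_attention_2",
).to("cuda")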
