Fix Qwen2AudioForConditionalGeneration.forward() #39503

Open: wants to merge 3 commits into main

ebezzam
Contributor

@ebezzam ebezzam commented Jul 18, 2025

What does this PR do?

Fixes failing Qwen2Audio tests: https://github.com/huggingface/transformers/actions/runs/16361063842/job/46229139255

The errors are not displayed in the Model CI run linked above, but this is what I got:

FAILED tests/models/qwen2_audio/test_modeling_qwen2_audio.py::Qwen2AudioForConditionalGenerationIntegrationTest::test_small_model_integration_test_batch - TypeError: Qwen2AudioForConditionalGeneration.forward() got an unexpected keyword argument 'cache_position'
FAILED tests/models/qwen2_audio/test_modeling_qwen2_audio.py::Qwen2AudioForConditionalGenerationIntegrationTest::test_small_model_integration_test_multiturn - TypeError: Qwen2AudioForConditionalGeneration.forward() got an unexpected keyword argument 'cache_position'
FAILED tests/models/qwen2_audio/test_modeling_qwen2_audio.py::Qwen2AudioForConditionalGenerationIntegrationTest::test_small_model_integration_test_single - TypeError: Qwen2AudioForConditionalGeneration.forward() got an unexpected keyword argument 'cache_position'
== 3 failed, 77 passed, 43 skipped, 4 warnings in 82.32s (0:01:22)

In short, cache_position was missing in forward of Qwen2AudioForConditionalGeneration and adding it resolves the tests 👍
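To make the failure mode concrete, here is a minimal sketch in plain Python, with no dependency on transformers. The class names and `call_like_generate` are hypothetical stand-ins, not the library's code; the sketch only assumes that the caller passes `cache_position` as a keyword argument, as the tracebacks above indicate.

```python
class ForwardWithoutCachePosition:
    """Hypothetical model whose forward() signature lacks `cache_position`."""

    def forward(self, input_ids, attention_mask=None):
        return {"num_tokens": len(input_ids)}


class ForwardWithCachePosition:
    """Same toy model, with `cache_position` added to the signature."""

    def forward(self, input_ids, attention_mask=None, cache_position=None):
        return {"num_tokens": len(input_ids), "cache_position": cache_position}


def call_like_generate(model):
    """Toy caller (an assumption) that always passes `cache_position` by keyword."""
    try:
        return model.forward(input_ids=[1, 2, 3], cache_position=[0, 1, 2])
    except TypeError as err:
        # Python raises this when a keyword has no matching parameter.
        return f"TypeError: {err}"


broken = call_like_generate(ForwardWithoutCachePosition())
fixed = call_like_generate(ForwardWithCachePosition())
print(broken)
print(fixed)
```

Accepting the argument is enough to make the call succeed in this toy setting; it says nothing yet about whether the value should be used or passed on.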

However, it probably needs to be processed in some way, as in Qwen2VLForConditionalGeneration?

If the processing is similar to that in Qwen2VLForConditionalGeneration, I'm happy to do it!

cc @eustlb, @gante as it seems you were aware of cache related issues here

@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

Member

@gante gante left a comment

Thank you for opening the PR 🙏

We'll need to pass cache_position to self.language_model as well. The non-legacy processing branch in Qwen2AudioForConditionalGeneration.forward only prepares inputs_embeds (and not position-based tensors, like position_ids or attention_mask), so further processing of cache_position probably isn't needed. I'm assuming that before this fix, we were relying on the language_model's default values for cache_position, so in practice this change won't have an impact on the outputs from interfaces like generate.
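A rough sketch of this reasoning, in plain Python rather than the real model code: `ToyLanguageModel` and `ToyAudioWrapper` are hypothetical, and the default used here (positions counted from the number of already cached tokens) is an assumption about how the inner model fills in a missing `cache_position`.

```python
class ToyLanguageModel:
    """Hypothetical inner decoder that derives cache_position when none is given."""

    def forward(self, inputs_embeds, past_length=0, cache_position=None):
        if cache_position is None:
            # Assumed default: positions continue from the cached length.
            cache_position = list(range(past_length, past_length + len(inputs_embeds)))
        return {"cache_position": cache_position}


class ToyAudioWrapper:
    """Hypothetical wrapper that only prepares inputs_embeds."""

    def __init__(self):
        self.language_model = ToyLanguageModel()

    def forward(self, inputs_embeds, past_length=0, cache_position=None):
        # No position-based tensors are built here, so cache_position is
        # handed to the inner model unchanged.
        return self.language_model.forward(
            inputs_embeds, past_length=past_length, cache_position=cache_position
        )


wrapper = ToyAudioWrapper()
step = [0.5]  # one new token embedding (placeholder value)

explicit = wrapper.forward(step, past_length=3, cache_position=[3])
defaulted = wrapper.forward(step, past_length=3)
print(explicit == defaulted)
```

Under that assumption, passing the value explicitly and relying on the default give the same result, which is why the outputs are not expected to change.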

Let's make sure the slow tests are passing after the changes!

@gante
Member

gante commented Jul 18, 2025

@ebezzam after your changes, remove this line as well. If this was the only cache-related issue, then CI will be happy :D

(the line I linked prevents generate-related quick tests from running)

@lovenya

lovenya commented Jul 20, 2025

This is such an important patch. Thank you so much, I was having so many issues with inference alone.

@ebezzam
Contributor Author

ebezzam commented Jul 21, 2025

Thank you @gante for the pointers!

I can confirm that slow tests are passing:

# RUN_SLOW=1 pytest tests/models/qwen2_audio/test_modeling_qwen2_audio.py
...
============= 15 passed, 33 skipped in 45.03s =============

UPDATE after seeing failing CI

I tried different combinations of:

  • keeping all_generative_model_classes = () inside Qwen2AudioForConditionalGenerationModelTest
  • adding GenerationTesterMixin to inheritance list of Qwen2AudioForConditionalGenerationModelTest
# COMMAND: RUN_SLOW=1 pytest tests/models/qwen2_audio

# -- without `all_generative_model_classes = ()` and without `GenerationTesterMixin`
# -- similar to https://app.circleci.com/pipelines/github/huggingface/transformers/138686/workflows/39334258-3551-4727-9629-adcf04e95a67/jobs/1838253
FAILED tests/models/qwen2_audio/test_modeling_qwen2_audio.py::Qwen2AudioForConditionalGenerationModelTest::test_generation_tester_mixin_inheritance - AssertionError: False is not true : This model can call `generate` from `GenerationMixin`, so one of two things must happen: 1) the tester must inherit from `GenerationTesterMixi...
==== 1 failed, 93 passed, 77 skipped, 4 warnings in 164.27s (0:02:44) =========

# -- without `all_generative_model_classes = ()` and with `GenerationTesterMixin`
FAILED tests/models/qwen2_audio/test_modeling_qwen2_audio.py::Qwen2AudioForConditionalGenerationModelTest::test_assisted_decoding_matches_greedy_search_0_random - RuntimeError: The size of tensor a (2) must match the size of tensor b (26) at non-singleton dimension 1
FAILED tests/models/qwen2_audio/test_modeling_qwen2_audio.py::Qwen2AudioForConditionalGenerationModelTest::test_assisted_decoding_matches_greedy_search_1_same - RuntimeError: The size of tensor a (2) must match the size of tensor b (27) at non-singleton dimension 1
FAILED tests/models/qwen2_audio/test_modeling_qwen2_audio.py::Qwen2AudioForConditionalGenerationModelTest::test_assisted_decoding_sample - RuntimeError: The size of tensor a (2) must match the size of tensor b (27) at non-singleton dimension 1
==== 3 failed, 124 passed, 87 skipped, 4 warnings in 179.77s (0:02:59) ======

# -- with `all_generative_model_classes = ()` and without `GenerationTesterMixin`
==== 95 passed, 76 skipped, 4 warnings in 165.05s (0:02:45) =======
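The three runs above can be read as a small decision table. The sketch below is a toy model of that table, not the actual test in transformers: the assertion message is truncated in the log, so the second condition (an empty `all_generative_model_classes`) is inferred from the runs, and `GenerationTesterMixin` here is an empty stand-in class.

```python
class GenerationTesterMixin:
    """Empty stand-in for the real mixin of the same name."""


def inheritance_check_passes(tester_cls):
    """Toy version of the check, inferred from the three runs (an assumption)."""
    inherits_mixin = issubclass(tester_cls, GenerationTesterMixin)
    opted_out = getattr(tester_cls, "all_generative_model_classes", None) == ()
    return inherits_mixin or opted_out


class NeitherTester:
    """No opt-out and no mixin: corresponds to the first run."""


class WithMixinTester(GenerationTesterMixin):
    """Mixin inherited: corresponds to the second run."""


class OptedOutTester:
    """Opt-out kept, no mixin: corresponds to the third run."""

    all_generative_model_classes = ()


results = {
    cls.__name__: inheritance_check_passes(cls)
    for cls in (NeitherTester, WithMixinTester, OptedOutTester)
}
print(results)
```

In the second configuration the inheritance check itself no longer fails, but the assisted decoding tests pulled in by the mixin do, which is what the shape mismatch errors show.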

@ebezzam ebezzam requested a review from gante July 21, 2025 08:25
@eustlb
Contributor

eustlb commented Jul 22, 2025

The remaining failing tests should be fixed by adding this method, which ensures we don't trigger merging of input features when not in the prefill stage.
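The linked method is not reproduced in this thread, so the following is only a sketch of the idea in plain Python. `prepare_inputs_for_step` is a hypothetical helper, and treating "`cache_position` starts at 0" as the marker of the prefill stage is an assumption.

```python
def prepare_inputs_for_step(input_ids, input_features, cache_position):
    """Hypothetical helper: forward audio features only during prefill."""
    model_inputs = {"input_ids": input_ids, "cache_position": cache_position}

    # Assumed convention: the first forward pass starts at position 0.
    is_prefill = len(cache_position) > 0 and cache_position[0] == 0
    if is_prefill:
        # Audio features are merged into the embeddings once, at prefill.
        model_inputs["input_features"] = input_features
    return model_inputs


audio = [[0.1, 0.2], [0.3, 0.4]]  # placeholder feature values

prefill_inputs = prepare_inputs_for_step([11, 12, 13], audio, cache_position=[0, 1, 2])
decode_inputs = prepare_inputs_for_step([14], audio, cache_position=[3])

print("input_features" in prefill_inputs)
print("input_features" in decode_inputs)
```

With the features dropped on later steps, a decoding step that sees only the newest tokens never attempts to merge audio features sized for the full prompt.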

Contributor

[For maintainers] Suggested jobs to run (before merge)

run-slow: qwen2_audio

@ebezzam ebezzam added the Audio label Jul 22, 2025
@ebezzam ebezzam changed the title from "Add missing cache_position to Qwen2AudioForConditionalGeneration.forward()" to "Fix Qwen2AudioForConditionalGeneration.forward()" Jul 22, 2025