
Conversation

@ebezzam (Contributor) commented Jul 18, 2025

What does this PR do?

Fixes breaking Qwen2Audio tests: https://github.com/huggingface/transformers/actions/runs/16361063842/job/46229139255

The errors don't display in the Model CI run above, but this is what I got:

FAILED tests/models/qwen2_audio/test_modeling_qwen2_audio.py::Qwen2AudioForConditionalGenerationIntegrationTest::test_small_model_integration_test_batch - TypeError: Qwen2AudioForConditionalGeneration.forward() got an unexpected keyword argument 'cache_position'
FAILED tests/models/qwen2_audio/test_modeling_qwen2_audio.py::Qwen2AudioForConditionalGenerationIntegrationTest::test_small_model_integration_test_multiturn - TypeError: Qwen2AudioForConditionalGeneration.forward() got an unexpected keyword argument 'cache_position'
FAILED tests/models/qwen2_audio/test_modeling_qwen2_audio.py::Qwen2AudioForConditionalGenerationIntegrationTest::test_small_model_integration_test_single - TypeError: Qwen2AudioForConditionalGeneration.forward() got an unexpected keyword argument 'cache_position'
== 3 failed, 77 passed, 43 skipped, 4 warnings in 82.32s (0:01:22)

In short, cache_position was missing from Qwen2AudioForConditionalGeneration.forward(), and adding it resolves the tests 👍

However, it probably needs to be processed in some way? Like in Qwen2VLForConditionalGeneration?

If it's similar to what's done in Qwen2VLForConditionalGeneration, happy to do it!

cc @eustlb, @gante as it seems you were aware of cache related issues here
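For illustration, the fix boils down to a pass-through. Here is a toy sketch with hypothetical class names (not the actual transformers code):

```python
# Toy sketch of the fix (illustrative class names, not the actual transformers
# code): the wrapper's forward() must accept `cache_position` and pass it
# through unchanged to the inner language model.
class ToyLanguageModel:
    def forward(self, inputs_embeds=None, attention_mask=None, cache_position=None):
        # A real LM would use cache_position to index its KV cache; here we
        # just echo it back so the pass-through is visible.
        return {"cache_position": cache_position}


class ToyAudioForConditionalGeneration:
    def __init__(self):
        self.language_model = ToyLanguageModel()

    # Before the fix, `cache_position` was missing from this signature, so
    # generate() raised:
    #   TypeError: forward() got an unexpected keyword argument 'cache_position'
    def forward(self, inputs_embeds=None, attention_mask=None, cache_position=None):
        return self.language_model.forward(
            inputs_embeds=inputs_embeds,
            attention_mask=attention_mask,
            cache_position=cache_position,  # pass through, no extra processing
        )


model = ToyAudioForConditionalGeneration()
print(model.forward(cache_position=[0, 1, 2])["cache_position"])  # [0, 1, 2]
```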


@gante (Member) left a comment

Thank you for opening the PR 🙏

We'll need to pass cache_position to self.language_model as well. The non-legacy processing branch in Qwen2AudioForConditionalGeneration.forward only prepares inputs_embeds (and not position-based tensors, like position_ids or attention_mask), so further processing of cache_position probably isn't needed. I'm assuming that before this fix, we were relying on the language_model's default values for cache_position, so in practice this change won't have an impact on the outputs from interfaces like generate.

Let's make sure the slow tests are passing after the changes!

@gante (Member) commented Jul 18, 2025

@ebezzam after your changes, remove this line as well. If this was the only cache-related issue, then CI will be happy :D

(the line I linked prevents generate-related quick tests from running)

@lovenya commented Jul 20, 2025

This is such an important patch. Thank you so much; I was having so many issues with plain inference.

@ebezzam (Contributor, Author) commented Jul 21, 2025

Thank you @gante for the pointers!

I can confirm that slow tests are passing:

# RUN_SLOW=1 pytest tests/models/qwen2_audio/test_modeling_qwen2_audio.py
...
============= 15 passed, 33 skipped in 45.03s =============

UPDATE after seeing failing CI

I tried different combinations of:

  • keeping all_generative_model_classes = () inside Qwen2AudioForConditionalGenerationModelTest
  • adding GenerationTesterMixin to inheritance list of Qwen2AudioForConditionalGenerationModelTest
# COMMAND: RUN_SLOW=1 pytest tests/models/qwen2_audio

# -- without `all_generative_model_classes = ()` and without `GenerationTesterMixin`
# -- similar to https://app.circleci.com/pipelines/github/huggingface/transformers/138686/workflows/39334258-3551-4727-9629-adcf04e95a67/jobs/1838253
FAILED tests/models/qwen2_audio/test_modeling_qwen2_audio.py::Qwen2AudioForConditionalGenerationModelTest::test_generation_tester_mixin_inheritance - AssertionError: False is not true : This model can call `generate` from `GenerationMixin`, so one of two things must happen: 1) the tester must inherit from `GenerationTesterMixi...
==== 1 failed, 93 passed, 77 skipped, 4 warnings in 164.27s (0:02:44) =========

# -- without `all_generative_model_classes = ()` and with `GenerationTesterMixin`
FAILED tests/models/qwen2_audio/test_modeling_qwen2_audio.py::Qwen2AudioForConditionalGenerationModelTest::test_assisted_decoding_matches_greedy_search_0_random - RuntimeError: The size of tensor a (2) must match the size of tensor b (26) at non-singleton dimension 1
FAILED tests/models/qwen2_audio/test_modeling_qwen2_audio.py::Qwen2AudioForConditionalGenerationModelTest::test_assisted_decoding_matches_greedy_search_1_same - RuntimeError: The size of tensor a (2) must match the size of tensor b (27) at non-singleton dimension 1
FAILED tests/models/qwen2_audio/test_modeling_qwen2_audio.py::Qwen2AudioForConditionalGenerationModelTest::test_assisted_decoding_sample - RuntimeError: The size of tensor a (2) must match the size of tensor b (27) at non-singleton dimension 1
==== 3 failed, 124 passed, 87 skipped, 4 warnings in 179.77s (0:02:59) ======

# with `all_generative_model_classes = ()` and without `GenerationTesterMixin`
==== 95 passed, 76 skipped, 4 warnings in 165.05s (0:02:45) =======
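The combination that passed can be sketched with a toy version of the tester machinery (illustrative class names; the real logic lives in transformers' GenerationTesterMixin): an empty `all_generative_model_classes` tuple opts the model out of the shared generation tests.

```python
# Toy sketch of the tester trade-off tried above (hypothetical classes, not
# the actual transformers test code). An empty `all_generative_model_classes`
# tuple tells the shared suite to skip generation tests for this model.
class ToyGenerationTesterMixin:
    all_generative_model_classes = ("ToyModelForCausalLM",)

    def should_run_generation_tests(self):
        return bool(self.all_generative_model_classes)


class ToyQwen2AudioModelTest(ToyGenerationTesterMixin):
    all_generative_model_classes = ()  # opt out of generation tests


print(ToyQwen2AudioModelTest().should_run_generation_tests())  # False
```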

@ebezzam ebezzam requested a review from gante July 21, 2025 08:25
@eustlb (Contributor) commented Jul 22, 2025

The remaining failing tests should be fixed by adding this method, to ensure we don't trigger merging of input features when not in the prefill stage.
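A toy sketch of the idea (a hypothetical helper, not the method added in the PR): audio features should only be merged into the text embeddings during the prefill step, i.e. when cache_position starts at 0. Later decode steps process a single new token, so merging again would cause shape mismatches like the assisted-decoding failures above.

```python
# Hypothetical helper illustrating prefill detection: the prefill step is the
# one whose cache positions start at 0; any later step is pure decoding.
def is_prefill(cache_position):
    return len(cache_position) > 0 and cache_position[0] == 0


def maybe_merge_features(text_embeds, audio_embeds, cache_position):
    if is_prefill(cache_position):
        # Merge audio features into the sequence only once, at prefill.
        return text_embeds + [f"<audio:{a}>" for a in audio_embeds]
    return text_embeds  # decode step: leave embeddings untouched


# Prefill covers positions 0..2, so features are merged; the decode step at
# position 3 leaves the embeddings alone.
print(maybe_merge_features(["t0", "t1", "t2"], ["a0"], [0, 1, 2]))
print(maybe_merge_features(["t3"], ["a0"], [3]))
```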

@ebezzam ebezzam added the Audio label Jul 22, 2025
@ebezzam ebezzam changed the title Add missing cache_position to Qwen2AudioForConditionalGeneration.forward() Fix Qwen2AudioForConditionalGeneration.forward() Jul 22, 2025
@eustlb (Contributor) commented Jul 24, 2025

run-slow: qwen2_audio


This comment contains run-slow, running the specified jobs:

models: ['models/qwen2_audio']
quantizations: [] ...

@gante (Member) left a comment

Thank you for iterating, LGTM 🙏

@gante (Member) commented Jul 24, 2025

(having a look at the CI failures, which seem unrelated to these changes)

@gante (Member) commented Jul 24, 2025

The root issue in CI is solved in timm, and the fix has already been released. However, that release happened after we last built our CI images.

I've manually triggered a rebuild of our CI images (here); it should be a matter of waiting for that job to finish, then rerunning this PR's CI.

@gante (Member) commented Jul 24, 2025

OK, this needs more work: our setup file has "timm<=1.0.11", but timm==1.0.18 is installed on the CI machine.

@gante (Member) commented Jul 24, 2025

run-slow: qwen2_audio


This comment contains run-slow, running the specified jobs:

models: ['models/qwen2_audio']
quantizations: [] ...

@gante (Member) commented Jul 24, 2025

the CI job failed before running the slow tests, but I'm getting the following failures locally:

FAILED tests/models/qwen2_audio/test_modeling_qwen2_audio.py::Qwen2AudioForConditionalGenerationModelTest::test_eager_matches_fa2_generate - RuntimeError: cu_seqlens_q must have shape (batch_size + 1)
FAILED tests/models/qwen2_audio/test_modeling_qwen2_audio.py::Qwen2AudioForConditionalGenerationModelTest::test_eager_padding_matches_padding_free_with_position_ids - TypeError: can't assign a NoneType to a torch.cuda.LongTensor
FAILED tests/models/qwen2_audio/test_modeling_qwen2_audio.py::Qwen2AudioForConditionalGenerationModelTest::test_flash_attention_2_padding_matches_padding_free_with_position_ids - TypeError: can't assign a NoneType to a torch.cuda.LongTensor
FAILED tests/models/qwen2_audio/test_modeling_qwen2_audio.py::Qwen2AudioForConditionalGenerationModelTest::test_flash_attention_2_padding_matches_padding_free_with_position_ids_and_fa_kwargs - TypeError: can't assign a NoneType to a torch.cuda.LongTensor
FAILED tests/models/qwen2_audio/test_modeling_qwen2_audio.py::Qwen2AudioForConditionalGenerationModelTest::test_flash_attn_2_inference_equivalence - RuntimeError: FlashAttention only support fp16 and bf16 data type
FAILED tests/models/qwen2_audio/test_modeling_qwen2_audio.py::Qwen2AudioForConditionalGenerationModelTest::test_flash_attn_kernels_inference_equivalence - RuntimeError: FlashAttention only support fp16 and bf16 data type
FAILED tests/models/qwen2_audio/test_modeling_qwen2_audio.py::Qwen2AudioForConditionalGenerationModelTest::test_sdpa_padding_matches_padding_free_with_position_ids - TypeError: can't assign a NoneType to a torch.cuda.LongTensor

cc @ebezzam

@ebezzam (Contributor, Author) commented Jul 24, 2025

Thanks for the update! I'll look into that.


[For maintainers] Suggested jobs to run (before merge)

run-slow: qwen2_audio

@@ -3484,6 +3484,7 @@ def flash_attn_inference_equivalence(self, attn_implementation: str, padding_sid
model = model_class(config)

model.to(torch_device)
model.half()
@ebezzam (Contributor, Author) left a comment

Resolves test that is expecting half-precision inputs

# RUN_SLOW=1 pytest tests/models/qwen2_audio/test_modeling_qwen2_audio.py::Qwen2AudioForConditionalGenerationModelTest::test_flash_attn_kernels_inference_equivalence

# -- previously
FAILED tests/models/qwen2_audio/test_modeling_qwen2_audio.py::Qwen2AudioForConditionalGenerationModelTest::test_flash_attn_kernels_inference_equivalence - RuntimeError: FlashAttention only supports fp16, bf16, and fp8_e4m3 data type

Also fixes the similar test for Qwen2_VL (below), and maybe for other models!

RUN_SLOW=1 pytest tests/models/qwen2_vl/test_modeling_qwen2_vl.py::Qwen2VLModelTest::test_flash_attn_kernels_inference_equivalence

@gante (Member) left a comment

Yes, this was broken recently (~1 week ago, we could see the models being loaded in BF16)

Can we set the original dtype, i.e. model.to(torch.bfloat16), instead?
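A minimal sketch of the dtype change under discussion, using a generic torch module rather than the actual transformers test: converting the model to a half-precision dtype satisfies FlashAttention's fp16/bf16 requirement, and bf16 keeps the dtype the checkpoints were originally loaded in.

```python
# Generic illustration (not the literal test code): either conversion yields
# a half-precision model, but bf16 matches the original checkpoint dtype.
import torch

model = torch.nn.Linear(4, 4)
model.to(torch.bfloat16)  # reviewer-preferred: restore the original bf16 dtype
# model.half()            # the diff's alternative: fp16 also satisfies FA2

print(next(model.parameters()).dtype)  # torch.bfloat16
```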

@ebezzam (Contributor, Author) commented Jul 28, 2025

run-slow: qwen2_audio


This comment contains run-slow, running the specified jobs:

models: ['models/qwen2_audio']
quantizations: [] ...

@ebezzam ebezzam changed the title Fix Qwen2AudioForConditionalGeneration.forward() Fix Qwen2AudioForConditionalGeneration.forward() and test_flash_attn_kernels_inference_equivalence Jul 28, 2025
@ebezzam ebezzam merged commit 7623aa3 into huggingface:main Jul 28, 2025
26 checks passed
@ebezzam ebezzam deleted the qwen2audio branch July 28, 2025 14:35