
@zhaozhenyu-newsbreak zhaozhenyu-newsbreak commented Mar 19, 2025

Fix unexpected keyword argument input_ids when setting up no-speech detection for Whisper

What does this PR do?

This PR fixes the unexpected keyword argument error for input_ids that is raised when using WhisperForConditionalGeneration or the pipeline to do ASR with no_speech_threshold set. The error is similar to, but not the same as, the one discussed at https://discuss.huggingface.co/t/unexpected-keywork-argument/91356

Fixes # (issue)

The root cause is that _setup_no_speech_detection adds an "input_ids" keyword argument, but the forward function of WhisperForConditionalGeneration does not accept this argument. I think changing input_ids to decoder_input_ids will fix this issue.
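The mismatch described above can be reproduced without transformers. The following is a minimal sketch: `forward` and `call_forward` are hypothetical stand-ins written for illustration, not the library's code, and the parameter list is reduced to the two names that matter here.

```python
def forward(input_features=None, decoder_input_ids=None):
    """Hypothetical stand-in for a forward() that has no `input_ids` parameter."""
    return {"input_features": input_features, "decoder_input_ids": decoder_input_ids}


def call_forward(inputs):
    """Call forward(**inputs) and return the TypeError message, or None on success."""
    try:
        forward(**inputs)
    except TypeError as exc:
        return str(exc)
    return None


# A stored inputs dict that uses the name the stand-in does not accept.
bad_inputs = {"input_features": [0.1, 0.2], "input_ids": [1, 2]}
# The same dict with the name the stand-in does accept.
good_inputs = {"input_features": [0.1, 0.2], "decoder_input_ids": [1, 2]}

error_message = call_forward(bad_inputs)
print(error_message)
print(call_forward(good_inputs))
```

The first call fails with the same kind of TypeError as in the report, because unpacking a dict with `**` turns every key into a keyword argument. Whether renaming the key is enough in the real code path is a separate question.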

Before submitting

@github-actions github-actions bot marked this pull request as draft March 19, 2025 03:00
Contributor

Hi 👋, thank you for opening this pull request! The pull request is converted to draft by default. When it is ready for review, please click the Ready for review button (at the bottom of the PR page).

@zhaozhenyu-newsbreak zhaozhenyu-newsbreak marked this pull request as ready for review March 19, 2025 03:03
@gante
Member

gante commented Mar 19, 2025

Hey @zhaozhenyu-newsbreak 👋 Thank you for opening the PR!

Could you share a short reproducible script this PR is meant to fix? I confess I'm not seeing the issue :)

@zhaozhenyu-newsbreak
Author

zhaozhenyu-newsbreak commented Mar 19, 2025

Of course! I followed the instructions for whisper_large_v3 with transformers version 4.49.0,
and here is my code:

import torch
from transformers import AutoModelForSpeechSeq2Seq, AutoProcessor, pipeline


class WhisperAsrPipeline:
    def __init__(self, model_name='openai/whisper-large-v3', max_length=400):
        self.device = 'cuda' if torch.cuda.is_available() else 'cpu'
        self.torch_dtype = torch.float16 if torch.cuda.is_available() else torch.float32
        self.processor = AutoProcessor.from_pretrained(model_name)
        model = AutoModelForSpeechSeq2Seq.from_pretrained(model_name).to(self.device)
        model.eval()
        self.model = model
        self.pipeline = pipeline( 
                                "automatic-speech-recognition",
                                model=model,
                                tokenizer=self.processor.tokenizer,
                                feature_extractor=self.processor.feature_extractor,
                                torch_dtype=self.torch_dtype,
                                device=self.device)
        self.generate_kwargs = {
            "max_new_tokens": max_length,
            "num_beams": 1,
            "condition_on_prev_tokens": False,
            "compression_ratio_threshold": 1.35,  # zlib compression ratio threshold (in token space)
            "temperature": 0.0, #(0.0, 0.2, 0.4, 0.6, 0.8, 1.0),
            "logprob_threshold": -1.0,
            "no_speech_threshold": 0.6,
            "return_timestamps": True,
        }
        
    def asr_one_video(self, video_path):
        with torch.no_grad():
            transcription = self.pipeline(video_path, generate_kwargs=self.generate_kwargs)['text']
        return transcription

if __name__ == '__main__':
    whisper_asr = WhisperAsrPipeline(model_name='openai/whisper-large-v3', max_length=400)
    video_path = '/path/to/video.mp4'
    transcription = whisper_asr.asr_one_video(video_path)
    print(transcription)

but I encountered this error:

 File "/home/zhenyu/LlmTraining/llm_training/feature_extract/audio_feature/whisper_asr.py", line 81, in <module>
    transcription = whisper_asr.asr_one_video(video_path)
  File "/home/zhenyu/LlmTraining/llm_training/feature_extract/audio_feature/whisper_asr.py", line 73, in asr_one_video
    transcription = self.pipeline(video_path, generate_kwargs=self.generate_kwargs)['text']
  File "/home/zhenyu/miniconda3/envs/llm-training/lib/python3.10/site-packages/transformers/pipelines/automatic_speech_recognition.py", line 283, in __call__
    return super().__call__(inputs, **kwargs)
  File "/home/zhenyu/miniconda3/envs/llm-training/lib/python3.10/site-packages/transformers/pipelines/base.py", line 1360, in __call__
    return next(
  File "/home/zhenyu/miniconda3/envs/llm-training/lib/python3.10/site-packages/transformers/pipelines/pt_utils.py", line 124, in __next__
    item = next(self.iterator)
  File "/home/zhenyu/miniconda3/envs/llm-training/lib/python3.10/site-packages/transformers/pipelines/pt_utils.py", line 269, in __next__
    processed = self.infer(next(self.iterator), **self.params)
  File "/home/zhenyu/miniconda3/envs/llm-training/lib/python3.10/site-packages/transformers/pipelines/base.py", line 1275, in forward
    model_outputs = self._forward(model_inputs, **forward_params)
  File "/home/zhenyu/miniconda3/envs/llm-training/lib/python3.10/site-packages/transformers/pipelines/automatic_speech_recognition.py", line 521, in _forward
    tokens = self.model.generate(
  File "/home/zhenyu/miniconda3/envs/llm-training/lib/python3.10/site-packages/transformers/models/whisper/generation_whisper.py", line 774, in generate
    ) = self.generate_with_fallback(
  File "/home/zhenyu/miniconda3/envs/llm-training/lib/python3.10/site-packages/transformers/models/whisper/generation_whisper.py", line 950, in generate_with_fallback
    seek_outputs = super().generate(
  File "/home/zhenyu/miniconda3/envs/llm-training/lib/python3.10/site-packages/torch/utils/_contextlib.py", line 116, in decorate_context
    return func(*args, **kwargs)
  File "/home/zhenyu/miniconda3/envs/llm-training/lib/python3.10/site-packages/transformers/generation/utils.py", line 2223, in generate
    result = self._sample(
  File "/home/zhenyu/miniconda3/envs/llm-training/lib/python3.10/site-packages/transformers/generation/utils.py", line 3231, in _sample
    next_token_scores = logits_processor(input_ids, next_token_logits)
  File "/home/zhenyu/miniconda3/envs/llm-training/lib/python3.10/site-packages/transformers/generation/logits_process.py", line 88, in __call__
    scores = processor(input_ids, scores)
  File "/home/zhenyu/miniconda3/envs/llm-training/lib/python3.10/site-packages/transformers/generation/logits_process.py", line 2035, in __call__
    logits = self.model(**self.inputs).logits
  File "/home/zhenyu/miniconda3/envs/llm-training/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1739, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/zhenyu/miniconda3/envs/llm-training/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1750, in _call_impl
    return forward_call(*args, **kwargs)
TypeError: WhisperForConditionalGeneration.forward() got an unexpected keyword argument 'input_ids'

I think _setup_no_speech_detection adds an "input_ids" keyword argument, but the forward function of WhisperForConditionalGeneration does not accept this argument. Adding **kwargs to the forward function would fix this issue, but I'm not sure it is an elegant fix.

@ArthurZucker
Collaborator

I think _setup_no_speech_detection adds an "input_ids" keyword argument, but the forward function of WhisperForConditionalGeneration does not accept this argument. Adding **kwargs to the forward function would fix this issue, but I'm not sure it is an elegant fix.

It's expecting decoder_input_ids! The fix should be a tad more involved.

@gante
Member

gante commented Mar 20, 2025

Adding **kwargs to the forward function would fix this issue, but I'm not sure it is an elegant fix.

Yeah, adding **kwargs usually creates more issues than it solves 😬 From the stack trace, it seems like it's passing one input too many, or the input has a bad name.

I would gladly accept a PR that corrects the input preparation, rather than changing the signature of model.forward.
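A minimal sketch of why a catch-all `**kwargs` can hide this kind of bug, and of how the caller's inputs could be checked against a signature instead. All three functions are hypothetical stand-ins for illustration, not transformers code.

```python
import inspect


def strict_forward(input_features=None, decoder_input_ids=None):
    """Hypothetical stand-in: rejects names it does not declare."""
    return decoder_input_ids


def lenient_forward(input_features=None, decoder_input_ids=None, **kwargs):
    """Hypothetical stand-in: unknown names land in kwargs and are ignored."""
    return decoder_input_ids


def unexpected_kwargs(fn, inputs):
    """Return the keys of `inputs` that `fn` does not declare as parameters."""
    params = inspect.signature(fn).parameters
    # A **kwargs parameter accepts any name, so nothing can be flagged.
    if any(p.kind is inspect.Parameter.VAR_KEYWORD for p in params.values()):
        return []
    return sorted(name for name in inputs if name not in params)


inputs = {"input_features": [0.1], "input_ids": [1, 2]}

# With **kwargs the call succeeds, but the token ids never reach the decoder argument.
swallowed = lenient_forward(**inputs)
# Against the strict signature the misnamed input is reported by name.
flagged = unexpected_kwargs(strict_forward, inputs)

print(swallowed)
print(flagged)
```

In this sketch the lenient version trades a loud TypeError for a silent wrong result, which is the kind of trade-off that makes fixing the input preparation the safer route.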

(P.S.: I've edited the PR header, to avoid pinging everyone :) )

@zhaozhenyu-newsbreak
Author

I think _setup_no_speech_detection adds an "input_ids" keyword argument, but the forward function of WhisperForConditionalGeneration does not accept this argument. Adding **kwargs to the forward function would fix this issue, but I'm not sure it is an elegant fix.

It's expecting decoder_input_ids! The fix should be a tad more involved.

Actually, changing "input_ids" to "decoder_input_ids" will also cause a problem, because GenerationMixin.prepare_inputs_for_generation() requires a positional argument named "input_ids". It may be complex to fix the issue elegantly, and I need more time to dig deeper. Thanks!
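The naming tension described above can be sketched as follows. This is an assumption-laden illustration, not the library's implementation: `prepare_inputs`, `forward`, and `rename_only` are hypothetical stand-ins, and the only property borrowed from the discussion is that the preparation step takes `input_ids` positionally while the forward step expects `decoder_input_ids`.

```python
def prepare_inputs(input_ids, **model_kwargs):
    """Hypothetical preparation step: requires `input_ids` and emits the
    argument names that the stand-in forward() below accepts."""
    prepared = dict(model_kwargs)
    prepared["decoder_input_ids"] = input_ids
    return prepared


def forward(input_features=None, decoder_input_ids=None):
    """Hypothetical stand-in for a forward() without an `input_ids` parameter."""
    return {"input_features": input_features, "decoder_input_ids": decoder_input_ids}


def rename_only(inputs):
    """Rename the key in the stored dict and do nothing else."""
    renamed = dict(inputs)
    renamed["decoder_input_ids"] = renamed.pop("input_ids")
    return renamed


stored = {"input_ids": [1, 2], "input_features": [0.5]}

# Renaming the stored key satisfies forward() but breaks the preparation step,
# which no longer receives its required `input_ids` argument.
try:
    prepare_inputs(**rename_only(stored))
    rename_error = None
except TypeError as exc:
    rename_error = str(exc)

# Keeping `input_ids` up to the preparation boundary and translating the name
# there lets both steps see the name they expect.
outputs = forward(**prepare_inputs(**stored))

print(rename_error)
print(outputs)
```

In this sketch the rename has to happen at the boundary between the two steps rather than in the stored dict, which is one way to read "a tad more involved".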

@adelgiudice

I'm also experiencing this issue as I attempt to upgrade transformers while still using whisper 3 turbo. Do we know why the original change was made (to line 1237 of generation_whisper)?

@ArthurZucker
Collaborator

ArthurZucker commented Jul 16, 2025

cc @ebezzam if you can have a look
