
Conversation

poke1024 (Contributor)

This PR adds a callback to the `generate` function of `WhisperGenerationMixin`, giving callers the ability to monitor the progress of Whisper transcriptions.

This is useful when transcription happens in a notebook or a UI, and callers want to show users a progress bar or similar feedback during long-running calls (e.g. >1 minute).

Reviewer suggestion: @eustlb

@github-actions github-actions bot marked this pull request as draft April 14, 2025 08:05

Hi 👋, thank you for opening this pull request! The pull request is converted to draft by default. The CI will be paused while the PR is in draft mode. When it is ready for review, please click the Ready for review button (at the bottom of the PR page). This will assign reviewers and trigger CI.

@poke1024 (Contributor, Author)

Example code:

import torch
from transformers import AutoModelForSpeechSeq2Seq, AutoProcessor, pipeline
from datasets import load_dataset


device = "cuda:0" if torch.cuda.is_available() else "cpu"
torch_dtype = torch.float16 if torch.cuda.is_available() else torch.float32

model_id = "openai/whisper-large-v3"

model = AutoModelForSpeechSeq2Seq.from_pretrained(
    model_id, torch_dtype=torch_dtype, low_cpu_mem_usage=True, use_safetensors=True
)
model.to(device)

processor = AutoProcessor.from_pretrained(model_id)

pipe = pipeline(
    "automatic-speech-recognition",
    model=model,
    tokenizer=processor.tokenizer,
    feature_extractor=processor.feature_extractor,
    torch_dtype=torch_dtype,
    device=device,
)

dataset = load_dataset("distil-whisper/librispeech_long", "clean", split="validation")
sample = dataset[0]["audio"]

def monitor_progress(p):
    print(p)

generate_kwargs = {
    "monitor_progress": monitor_progress,
    "return_timestamps": True
}

result = pipe(sample, generate_kwargs=generate_kwargs)
print(result["text"])

This will output progress as:

tensor([[   0, 6245]])
tensor([[2868, 6245]])
tensor([[5724, 6245]])
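Each value passed to the callback is a 1×2 tensor of the form [[current_frames, total_frames]], so a caller can derive a completion fraction from it. A minimal sketch of such a consumer (the class name `TranscriptionProgress` is illustrative, not part of the PR; it assumes the payload shape shown above):

```python
class TranscriptionProgress:
    """Track Whisper transcription progress from the monitor_progress callback.

    Assumes the callback receives a 1x2 tensor-like value
    [[current_frames, total_frames]], as in the printed output above.
    """

    def __init__(self):
        self.current = 0
        self.total = None

    def __call__(self, p):
        # int(...) works for both torch tensor elements and plain ints,
        # so this also accepts a nested list for testing.
        self.current, self.total = (int(v) for v in p[0])

    @property
    def fraction(self):
        return self.current / self.total if self.total else 0.0


progress = TranscriptionProgress()
progress([[2868, 6245]])  # simulate one callback invocation
print(f"{progress.fraction:.1%}")  # prints "45.9%"
```

An instance of this class can be passed directly as the `monitor_progress` entry in `generate_kwargs`, since it is callable.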

@poke1024 poke1024 marked this pull request as ready for review April 14, 2025 09:01
@ArthurZucker (Collaborator) left a comment


Why not!
Happy to add this if you can also provide an example of a useful monitor_progress and a tad bit of doc! 🤗

@poke1024 (Contributor, Author)

poke1024 commented May 2, 2025

https://colab.research.google.com/drive/1wEIr1m7-D-EN9M_ygo388bjF8dMK1Oan?usp=sharing demonstrates the callback through a tqdm progress bar in a notebook.

As Colab can be sluggish, here is a real-time screen recording showing the same notebook doing a transcription of a 10 minute audio file on a MacBook M4 Max using mps: https://www.dropbox.com/scl/fi/6dzdh1konw5aj7iufr6b6/whisper_progress_monitor.mov?rlkey=kycpt3o4h84e6pzhbp7as8eft&st=ulwz502l&dl=0
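A tqdm-driven progress bar along the lines of that notebook can be sketched roughly as follows (a hedged sketch, not the exact Colab code; `WhisperProgressBar` is an illustrative name, and it assumes the [[current_frames, total_frames]] callback payload shown earlier):

```python
from tqdm.auto import tqdm


class WhisperProgressBar:
    """Drive a tqdm bar from Whisper's monitor_progress callback.

    Assumes the callback receives a 1x2 tensor-like value
    [[current_frames, total_frames]], as in the earlier example output.
    """

    def __init__(self):
        self._bar = None

    def __call__(self, p):
        current, total = (int(v) for v in p[0])
        if self._bar is None:
            # Lazily create the bar on the first callback, once the
            # total number of frames is known.
            self._bar = tqdm(total=total, unit="frames")
        # tqdm.update takes an increment, so advance by the delta.
        self._bar.update(current - self._bar.n)
        if current >= total:
            self._bar.close()
```

An instance is then passed to the pipeline exactly like the plain callback, e.g. `generate_kwargs = {"monitor_progress": WhisperProgressBar()}`.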

This new monitor callback would also be very useful for a web app I am currently developing for my PhD thesis, which allows running various ML tasks via a UI.

@ArthurZucker (Collaborator)

cc @ebezzam !

@ebezzam (Contributor)

ebezzam commented Jul 17, 2025

@poke1024 thanks for the contribution! Indeed it's a nice feature to keep track of progress for long transcriptions 👏

Could you resync with main? Your current snippet didn't work for me. I had to directly pass return_timestamps=True to pipe (maybe something changed since you last tried):

generate_kwargs = {
    "monitor_progress": monitor_progress,
}
result = pipe(sample, return_timestamps=True, generate_kwargs=generate_kwargs)

@poke1024 (Contributor, Author)

poke1024 commented Jul 28, 2025

@ebezzam Synced with main and updated the Colab example code to your version (return_timestamps indeed seems to have moved from generate_kwargs into the pipe call since the last version). The current tests_torch failure looks unrelated to me.

@ArthurZucker ArthurZucker requested a review from ebezzam July 28, 2025 11:32
@ebezzam (Contributor) left a comment


thanks @poke1024! Yes the failing tests are unrelated.

I've also updated the docstrings to show an example of your new feature.

@ArthurZucker LGTM for merging 👍

@ArthurZucker (Collaborator) left a comment


Thanks both!

@ebezzam ebezzam enabled auto-merge (squash) July 30, 2025 14:59
@ebezzam ebezzam disabled auto-merge July 30, 2025 15:00

[For maintainers] Suggested jobs to run (before merge)

run-slow: whisper

@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

@ebezzam ebezzam merged commit 9b3203f into huggingface:main Jul 30, 2025
25 checks passed