
Commit 9b3203f

poke1024 and ebezzam authored

Add callback to monitor progress in whisper transcription (#37483)

* Add callback to monitor progress in whisper transcription
* Added `` around variables, rewording
* Add example of `monitor_progress`

Co-authored-by: Eric B <[email protected]>

1 parent 7abb5d3 commit 9b3203f

File tree

1 file changed: +36, -2 lines changed

src/transformers/models/whisper/generation_whisper.py

Lines changed: 36 additions & 2 deletions
@@ -410,6 +410,7 @@ def generate(
         return_segments: bool = False,
         return_dict_in_generate: Optional[bool] = None,
         force_unique_generate_call: Optional[bool] = None,
+        monitor_progress: Optional[Callable[[torch.Tensor], None]] = None,
         **kwargs,
     ):
         """
@@ -461,6 +462,7 @@ def generate(
                 `FullyShardedDataParallel` and DeepSpeed ZeRO Stage 3).
             return_timestamps (`bool`, *optional*):
                 Whether to return the timestamps with the text. This enables the `WhisperTimestampsLogitsProcessor`.
+                For audio longer than 30 seconds, it is necessary to set `return_timestamps=True`.
             task (`str`, *optional*):
                 Task to use for generation, either "translate" or "transcribe".
             language (`str` or list of `str`, *optional*):
@@ -533,14 +535,19 @@ def generate(
             force_unique_generate_call (`bool`, *optional*):
                 Whether to force a unique call to the underlying GenerationMixin's [`~generation.GenerationMixin.generate`] method. This is useful for assisted decoding and testing purposes to ensure
                 that only one call to [`~generation.GenerationMixin.generate`] is made and therefore decoder input token ids and eos token ids are returned.
+            monitor_progress (`Callable[[torch.Tensor], None]`, *optional*):
+                If provided, this function is called to report the progress of the audio transcription. The function
+                takes a tensor argument `p` of shape `(n, 2)`, where `n` is the batch size. `p[i, 0]` contains the
+                index of the audio frame that is currently being transcribed for batch item `i`. `p[i, 1]` contains
+                the total number of frames for batch item `i`. No return value is expected.
             kwargs (`dict[str, Any]`, *optional*):
                 Ad hoc parametrization of `generate_config` and/or additional model-specific kwargs that will be
                 forwarded to the `forward` function of the model. If the model is an encoder-decoder model, encoder
                 specific kwargs should not be prefixed and decoder specific kwargs should be prefixed with *decoder_*.
         Return:
             [`~utils.ModelOutput`] or `dict[str, Any]` or `torch.LongTensor`:

-            A:
+            One of the following:
                 - [`~utils.ModelOutput`] when `return_dict_in_generate=True` and (`return_timestamps=False` or `force_unique_generate_call=True`), including the decoder input ids and end of sequence id.
                 - `dict[str, Any]` when (`return_dict_in_generate=True` and `return_timestamps=True`) or `return_segments=True` or `return_token_timestamps=True`.
                 - `torch.LongTensor` in all other cases, excluding the decoder input ids and end of sequence id.
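
To make the documented contract concrete, here is a minimal sketch (not part of this commit) of a callback that simply prints per-item progress; the name `print_progress` is illustrative:

```python
import torch

def print_progress(p: torch.Tensor) -> None:
    # `p` has shape (n, 2): column 0 holds the index of the frame currently
    # being transcribed, column 1 the total number of frames per batch item.
    for i, (current, total) in enumerate(p.tolist()):
        print(f"batch item {i}: frame {int(current)} of {int(total)}")
```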
@@ -586,13 +593,37 @@ def generate(
         >>> inputs = inputs.to("cuda", torch.float32)

         >>> # transcribe audio to ids
-        >>> generated_ids = model.generate(**inputs)
+        >>> generated_ids = model.generate(**inputs, return_timestamps=True)

         >>> transcription = processor.batch_decode(generated_ids, skip_special_tokens=True)
         >>> transcription[0]
         " Folks, if you watch the show, you know, I spent a lot of time right over there. Patiently and astutely scrutinizing the boxwood and mahogany chest set of the day's biggest stories developing the central headline pawns, definitely maneuvering an oso topical night to F6, fainting a classic Sicilian, nade door variation on the news, all the while seeing eight moves deep and patiently marshalling the latest press releases into a fisher's shows in Lip Nitsky attack that culminates in the elegant lethal slow-played, all-passant checkmate that is my nightly monologue. But sometimes, sometimes, folks, I. CHEERING AND APPLAUSE Sometimes I startle away, cubside down in the monkey bars of a condemned playground on a super fun site. Get all hept up on goofballs. Rummage that were discarded tag bag of defective toys. Yank out a fist bowl of disembodied doll limbs, toss them on a stained kid's place mat from a defunct dennies. set up a table inside a rusty cargo container down by the Wharf and challenged toothless drifters to the godless bughouse blitz of tournament that is my segment. Meanwhile."
         ```

+        The `monitor_progress` callback can be used to monitor the progress of the transcription:
+        ```python
+        >>> from tqdm import tqdm
+
+        >>> # prepare inputs like above
+
+        >>> # define a callback to monitor the progress of the transcription
+        >>> with tqdm(desc="Progress") as pbar:
+        >>>     def monitor_progress(p_batch):
+        >>>         i = torch.argmax(p_batch[:, 1])
+        >>>         p = p_batch[i].detach().cpu()
+        >>>         pbar.total = int(p[1])
+        >>>         pbar.n = int(p[0])
+        >>>         pbar.refresh()
+
+        >>>     # transcribe audio to ids
+        >>>     generated_ids = model.generate(**inputs, return_timestamps=True, monitor_progress=monitor_progress)
+
+        >>> transcription = processor.batch_decode(generated_ids, skip_special_tokens=True)
+        >>> transcription[0]
+        Progress: 95%|█████████████████████████████████████████████████████████████████████████████████████████████████▎ | 8497/8901 [00:04<00:00, 2052.79it/s]
+        " Folks, if you watch the show, you know, I spent a lot of time right over there. Patiently and astutely scrutinizing the boxwood and mahogany chest set of the day's biggest stories developing the central headline pawns, definitely maneuvering an oso topical night to F6, fainting a classic Sicilian, nade door variation on the news, all the while seeing eight moves deep and patiently marshalling the latest press releases into a fisher's shows in Lip Nitsky attack that culminates in the elegant lethal slow-played, all-passant checkmate that is my nightly monologue. But sometimes, sometimes, folks, I. CHEERING AND APPLAUSE Sometimes I startle away, cubside down in the monkey bars of a condemned playground on a super fun site. Get all hept up on goofballs. Rummage that were discarded tag bag of defective toys. Yank out a fist bowl of disembodied doll limbs, toss them on a stained kid's place mat from a defunct dennies. set up a table inside a rusty cargo container down by the Wharf and challenged toothless drifters to the godless bughouse blitz of tournament that is my segment. Meanwhile."
+        ```
+
         - *Shortform transcription*: If passed mel input features are <= 30 seconds, there are two possibilities:
             - `return_timestamps=False`: the whole audio will be transcribed with a single call to GenerationMixin's [`~generation.GenerationMixin.generate`].
             - `return_timestamps=True`: the audio will be transcribed using the same logic as long-form transcription.
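
The documented example drives a single bar from the batch item with the most frames. For batch sizes greater than one, a per-item variant is also possible; the following is a sketch (not part of this commit) that assumes `model`, `processor`, and `inputs` are prepared as in the example above:

```python
from tqdm import tqdm
import torch

# one bar per batch item; `position` keeps the bars on separate terminal lines
batch_size = inputs["input_features"].shape[0]
bars = [tqdm(desc=f"item {i}", position=i) for i in range(batch_size)]

def monitor_progress(p_batch: torch.Tensor) -> None:
    # p_batch[i] = (current frame index, total frames) for batch item i
    for i, (current, total) in enumerate(p_batch.detach().cpu().tolist()):
        bars[i].total = int(total)
        bars[i].n = int(current)
        bars[i].refresh()

generated_ids = model.generate(**inputs, return_timestamps=True, monitor_progress=monitor_progress)
for bar in bars:
    bar.close()
```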
@@ -763,6 +794,9 @@ def generate(

        # 6 Transcribe audio until we reach the end of all input audios
        while (seek < max_frames).any():
+            if monitor_progress is not None:
+                monitor_progress(torch.stack((seek, max_frames), dim=1))
+
            # 6.1 NOTE: When in longform transcription mode and batch size > 1 we need to dynamically reduce the batch size during the loop
            # in case one audio finished earlier than another one. Thus, we need to keep a table of "previous-index-2-current-index" in order
            # to know which original audio is being decoded
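
For reference, the tensor handed to the callback is built with `torch.stack((seek, max_frames), dim=1)`. A standalone snippet (illustrative values, not from the commit) showing the resulting shape:

```python
import torch

# illustrative values: current frame index and total frames for two audios
seek = torch.tensor([1500, 3000])
max_frames = torch.tensor([8901, 6002])

p = torch.stack((seek, max_frames), dim=1)
print(p)        # tensor([[1500, 8901], [3000, 6002]])
print(p.shape)  # torch.Size([2, 2]) -> (batch_size, 2)
```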
