38 changes: 36 additions & 2 deletions src/transformers/models/whisper/generation_whisper.py
@@ -410,6 +410,7 @@ def generate(
return_segments: bool = False,
return_dict_in_generate: Optional[bool] = None,
force_unique_generate_call: Optional[bool] = None,
monitor_progress: Optional[Callable[[torch.Tensor], None]] = None,
**kwargs,
):
"""
@@ -461,6 +462,7 @@ def generate(
`FullyShardedDataParallel` and DeepSpeed ZeRO Stage 3).
return_timestamps (`bool`, *optional*):
Whether to return the timestamps with the text. This enables the `WhisperTimeStampLogitsProcessor`.
For audio inputs longer than 30 seconds, `return_timestamps=True` must be set.
task (`str`, *optional*):
Task to use for generation, either "translate" or "transcribe".
language (`str` or list of `str`, *optional*):
@@ -533,14 +535,19 @@ def generate(
force_unique_generate_call (`bool`, *optional*):
Whether to force a unique call to the underlying GenerationMixin's [`~generation.GenerationMixin.generate`] method. This is useful for assisted decoding and testing purposes, to ensure
that only one call to [`~generation.GenerationMixin.generate`] is made and that, therefore, decoder input token ids and eos token ids are returned.
monitor_progress (`Callable[[torch.Tensor], None]`, *optional*):
If provided, this function is called during transcription to report progress. It takes a single tensor
argument `p` of shape `(n, 2)`, where `n` is the batch size: `p[i, 0]` is the index of the audio frame
currently being transcribed for batch item `i`, and `p[i, 1]` is the total number of frames for batch
item `i`. No return value is expected. See the progress-bar example below.
kwargs (`dict[str, Any]`, *optional*):
Ad hoc parametrization of `generate_config` and/or additional model-specific kwargs that will be
forwarded to the `forward` function of the model. If the model is an encoder-decoder model, encoder
specific kwargs should not be prefixed and decoder specific kwargs should be prefixed with *decoder_*.
Return:
[`~utils.ModelOutput`] or `dict[str, Any]` or `torch.LongTensor`:

- A:
+ One of the following:
- [`~utils.ModelOutput`] when `return_dict_in_generate=True` and (`return_timestamps=False` or `force_unique_generate_call=True`), including the decoder input ids and end of sequence id.
- `dict[str, Any]` when (`return_dict_in_generate=True` and `return_timestamps=True`) or `return_segments=True` or `return_token_timestamps=True`.
- `torch.LongTensor` in all other cases, excluding the decoder input ids and end of sequence id.
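For instance, requesting segments yields the `dict` form (a sketch, assuming `inputs` prepared as in the example below; the `"sequences"` and `"segments"` key names are assumptions, not confirmed by this diff):
```python
>>> out = model.generate(**inputs, return_timestamps=True, return_segments=True)
>>> generated_ids = out["sequences"]  # token ids, as in a plain call
>>> segments = out["segments"]  # one list of timestamped segments per batch item
```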
@@ -586,13 +593,37 @@ def generate(
>>> inputs = inputs.to("cuda", torch.float32)

>>> # transcribe audio to ids
- >>> generated_ids = model.generate(**inputs)
+ >>> generated_ids = model.generate(**inputs, return_timestamps=True)

>>> transcription = processor.batch_decode(generated_ids, skip_special_tokens=True)
>>> transcription[0]
" Folks, if you watch the show, you know, I spent a lot of time right over there. Patiently and astutely scrutinizing the boxwood and mahogany chest set of the day's biggest stories developing the central headline pawns, definitely maneuvering an oso topical night to F6, fainting a classic Sicilian, nade door variation on the news, all the while seeing eight moves deep and patiently marshalling the latest press releases into a fisher's shows in Lip Nitsky attack that culminates in the elegant lethal slow-played, all-passant checkmate that is my nightly monologue. But sometimes, sometimes, folks, I. CHEERING AND APPLAUSE Sometimes I startle away, cubside down in the monkey bars of a condemned playground on a super fun site. Get all hept up on goofballs. Rummage that were discarded tag bag of defective toys. Yank out a fist bowl of disembodied doll limbs, toss them on a stained kid's place mat from a defunct dennies. set up a table inside a rusty cargo container down by the Wharf and challenged toothless drifters to the godless bughouse blitz of tournament that is my segment. Meanwhile."
```

The `monitor_progress` callback can be used to monitor the progress of the transcription:
```python
>>> from tqdm import tqdm

>>> # prepare inputs like above

>>> # define a callback to monitor the progress of the transcription.
>>> with tqdm(desc="Progress") as pbar:
...     def monitor_progress(p_batch):
...         i = torch.argmax(p_batch[:, 1])
...         p = p_batch[i].detach().cpu()
...         pbar.total = int(p[1])
...         pbar.n = int(p[0])
...         pbar.update()
...
...     # transcribe audio to ids while the progress bar is still open
...     generated_ids = model.generate(**inputs, return_timestamps=True, monitor_progress=monitor_progress)

>>> transcription = processor.batch_decode(generated_ids, skip_special_tokens=True)
>>> transcription[0]
Progress: 95%|█████████████████████████████████████████████████████████████████████████████████████████████████▎ | 8497/8901 [00:04<00:00, 2052.79it/s]
" Folks, if you watch the show, you know, I spent a lot of time right over there. Patiently and astutely scrutinizing the boxwood and mahogany chest set of the day's biggest stories developing the central headline pawns, definitely maneuvering an oso topical night to F6, fainting a classic Sicilian, nade door variation on the news, all the while seeing eight moves deep and patiently marshalling the latest press releases into a fisher's shows in Lip Nitsky attack that culminates in the elegant lethal slow-played, all-passant checkmate that is my nightly monologue. But sometimes, sometimes, folks, I. CHEERING AND APPLAUSE Sometimes I startle away, cubside down in the monkey bars of a condemned playground on a super fun site. Get all hept up on goofballs. Rummage that were discarded tag bag of defective toys. Yank out a fist bowl of disembodied doll limbs, toss them on a stained kid's place mat from a defunct dennies. set up a table inside a rusty cargo container down by the Wharf and challenged toothless drifters to the godless bughouse blitz of tournament that is my segment. Meanwhile."
```
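The callback is not tied to tqdm. As a minimal sketch (the `print_progress` name and the percent formatting are illustrative, not part of the API), a plain printing callback that reports the least-advanced batch item could look like this:
```python
>>> def print_progress(p_batch):
...     # p_batch has shape (batch_size, 2): column 0 is the frame currently
...     # being transcribed, column 1 the total number of frames for that item
...     fractions = p_batch[:, 0].float() / p_batch[:, 1].float()
...     print(f"transcribed: {fractions.min().item():.0%}")

>>> generated_ids = model.generate(**inputs, return_timestamps=True, monitor_progress=print_progress)
```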

- *Shortform transcription*: If the passed mel input features cover at most 30 seconds of audio, there are two possibilities (see the sketch after this list):
- `return_timestamps=False`: the whole audio will be transcribed with a single call to GenerationMixin's [`~generation.GenerationMixin.generate`].
- `return_timestamps=True`: the audio will be transcribed using the same logic as long-form transcription.
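As a sketch, assuming `inputs` holds mel features for a clip of at most 30 seconds, the two modes are selected as follows:
```python
>>> # single call to GenerationMixin's generate, no timestamp decoding
>>> generated_ids = model.generate(**inputs)

>>> # same input, transcribed with the segment-based long-form logic
>>> generated_ids = model.generate(**inputs, return_timestamps=True)
```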
@@ -763,6 +794,9 @@ def generate(

# 6 Transcribe audio until we reach the end of all input audios
while (seek < max_frames).any():
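    # report progress: stack `seek` (index of the frame currently being transcribed,
    # one entry per batch item) and `max_frames` (total frames per item) into the
    # (batch_size, 2) tensor expected by the callback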
    if monitor_progress is not None:
        monitor_progress(torch.stack((seek, max_frames), dim=1))

    # 6.1 NOTE: When in longform transcription mode and batch size > 1 we need to dynamically reduce the batch size during the loop
    # in case one audio finished earlier than another one. Thus, we need to keep a table of "previous-index-2-current-index" in order
    # to know which original audio is being decoded