
Whisper .generate() function not respecting max_new_tokens or max_length #36183

@mitchelldehaven

System Info

Output of `transformers-cli env`:

- `transformers` version: 4.48.3
- Platform: Linux-6.11.0-1003-nvidia-x86_64-with-glibc2.39
- Python version: 3.11.11
- Huggingface_hub version: 0.28.1
- Safetensors version: 0.5.2
- Accelerate version: not installed
- Accelerate config: not found
- PyTorch version (GPU?): 2.6.0+cu124 (True)
- Tensorflow version (GPU?): not installed (NA)
- Flax version (CPU?/GPU?/TPU?): not installed (NA)
- Jax version: not installed
- JaxLib version: not installed
- Using distributed or parallel set-up in script?: no
- Using GPU in script?: yes
- GPU type: NVIDIA GeForce RTX 3090

Who can help?

Tagging the relevant speech maintainers: @ylacombe, @eustlb

Information

  • The official example scripts
  • My own modified scripts

Tasks

  • An officially supported task in the examples folder (such as GLUE/SQuAD, ...)
  • My own task or dataset (give details below)

Reproduction

I'm running Whisper to filter out potentially incorrect transcripts from the yodas2 dataset.

When using the .generate() function with max_new_tokens=64, the following warning is emitted:

Token indices sequence length is longer than the specified maximum sequence length for this model (3572 > 1024). Running this sequence through the model will result in indexing errors

I manually checked the transcripts produced up to that point, and the longest one did not appear to be anywhere near this length. My understanding is that max_new_tokens should stop the model from generating beyond that limit, but the warning makes it sound as though generation continued and the returned tokens were then truncated. I also tried passing max_length=70 just to see whether it would suppress the message, but it did not.
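For what it's worth, the same warning text seems reproducible from the tokenizer alone, with no .generate() call involved, whenever a text encodes to more tokens than tokenizer.model_max_length (1024 for this checkpoint, going by the message above). A minimal sketch, just to isolate where the warning can originate:

    from transformers import WhisperTokenizer

    tokenizer = WhisperTokenizer.from_pretrained("distil-whisper/distil-small.en")
    # Encoding any text longer than tokenizer.model_max_length emits the same
    # "Token indices sequence length is longer than ..." warning; generation is never run.
    ids = tokenizer("test " * 2000).input_ids
    print(len(ids), tokenizer.model_max_length)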

Relevant portion of script used:

    from functools import partial

    import torch
    from tqdm import tqdm
    from transformers import WhisperForConditionalGeneration, WhisperProcessor, WhisperTokenizer

    model_name = "distil-whisper/distil-small.en"
    processor = WhisperProcessor.from_pretrained(model_name)
    tokenizer = WhisperTokenizer.from_pretrained(model_name)
    model = WhisperForConditionalGeneration.from_pretrained(model_name, use_flash_attention_2=True).cuda()
    model.compile()
    # collate_fn and Yodas2Dataset are my own helpers (omitted here).
    partial_collate_fn = partial(collate_fn, processor=processor, tokenizer=tokenizer)
    dataset = Yodas2Dataset(dir, data, processor, tokenizer)
    dataloader = torch.utils.data.DataLoader(dataset, batch_size=32, num_workers=8, shuffle=False, collate_fn=partial_collate_fn)
    transcripts = []
    with torch.no_grad(), torch.amp.autocast("cuda", dtype=torch.bfloat16):
        for batch in tqdm(dataloader):
            audio_features = batch["input_features"].cuda()
            outputs = model.generate(audio_features, max_length=70, max_new_tokens=64, use_cache=True)
            batch_transcripts = tokenizer.batch_decode(outputs, skip_special_tokens=True)
            transcripts.extend(batch_transcripts)
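To double-check whether .generate() itself respects the cap, here is a minimal self-contained sketch (my own, using random features of the expected shape rather than real audio) that prints the returned sequence length. I would expect it to be at most max_new_tokens plus the few forced decoder prompt tokens:

    import torch
    from transformers import WhisperForConditionalGeneration

    model = WhisperForConditionalGeneration.from_pretrained("distil-whisper/distil-small.en")
    # Random log-mel features of the shape the model expects (batch, 80 mel bins, 3000 frames);
    # real audio would normally go through WhisperProcessor instead.
    dummy_features = torch.randn(1, 80, 3000)
    out = model.generate(dummy_features, max_new_tokens=64)
    # Expected: out.shape[-1] <= 64 + number of forced prompt tokens
    # (<|startoftranscript|>, <|notimestamps|>, ...).
    print(out.shape[-1])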

Expected behavior

I would expect the indexing warning not to trigger, given that the specified max_new_tokens is well under the maximum supported sequence length.
