System Info
`transformers-cli env`
- `transformers` version: 4.48.3
- Platform: Linux-6.11.0-1003-nvidia-x86_64-with-glibc2.39
- Python version: 3.11.11
- Huggingface_hub version: 0.28.1
- Safetensors version: 0.5.2
- Accelerate version: not installed
- Accelerate config: not found
- PyTorch version (GPU?): 2.6.0+cu124 (True)
- Tensorflow version (GPU?): not installed (NA)
- Flax version (CPU?/GPU?/TPU?): not installed (NA)
- Jax version: not installed
- JaxLib version: not installed
- Using distributed or parallel set-up in script?: no
- Using GPU in script?: yes
- GPU type: NVIDIA GeForce RTX 3090
Who can help?
Tagging the relevant speech maintainers: @ylacombe, @eustlb
Information
- The official example scripts
- My own modified scripts
Tasks
- An officially supported task in the examples folder (such as GLUE/SQuAD, ...)
- My own task or dataset (give details below)
Reproduction
I'm running Whisper to filter out potentially incorrect transcripts from the yodas2 dataset. When using the `.generate()` function with `max_new_tokens=64`, the following warning is emitted:
Token indices sequence length is longer than the specified maximum sequence length for this model (3572 > 1024). Running this sequence through the model will result in indexing errors
I manually checked the transcripts processed up to that point and the longest one didn't appear to be anywhere near this length. I assume that `max_new_tokens` stops the model from generating after that point, but the logging message makes it sound as though processing continued and the returned tokens were truncated. I also tried `max_length=70` just to see whether it prevented the message, but it did not.
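For context, the same warning string can also be emitted by the tokenizer itself whenever it encodes text longer than `tokenizer.model_max_length` (1024 for this checkpoint, going by the warning), with no model call involved. A minimal sketch of that, using an arbitrary long string, in case it helps narrow down where the warning originates:

```python
from transformers import WhisperTokenizer

tokenizer = WhisperTokenizer.from_pretrained("distil-whisper/distil-small.en")

# Encoding any text that tokenizes to more than tokenizer.model_max_length
# triggers the same "Token indices sequence length is longer ..." warning.
ids = tokenizer("hello " * 4000).input_ids
print(len(ids), tokenizer.model_max_length)
```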
Relevant portion of script used:
from functools import partial

import torch
from tqdm import tqdm
from transformers import WhisperForConditionalGeneration, WhisperProcessor, WhisperTokenizer

model_name = "distil-whisper/distil-small.en"
processor = WhisperProcessor.from_pretrained(model_name)
tokenizer = WhisperTokenizer.from_pretrained(model_name)
model = WhisperForConditionalGeneration.from_pretrained(model_name, use_flash_attention_2=True).cuda()
model.compile()

# collate_fn, Yodas2Dataset, dir, and data are defined elsewhere in the script
partial_collate_fn = partial(collate_fn, processor=processor, tokenizer=tokenizer)
dataset = Yodas2Dataset(dir, data, processor, tokenizer)
dataloader = torch.utils.data.DataLoader(dataset, batch_size=32, num_workers=8, shuffle=False, collate_fn=partial_collate_fn)

transcripts = []
with torch.no_grad(), torch.amp.autocast("cuda", dtype=torch.bfloat16):
    for batch in tqdm(dataloader):
        audio_features = batch["input_features"].cuda()
        outputs = model.generate(audio_features, max_length=70, max_new_tokens=64, use_cache=True)
        batch_transcripts = tokenizer.batch_decode(outputs, skip_special_tokens=True)
        transcripts.extend(batch_transcripts)
Expected behavior
I would expect the indexing warning not to be triggered, given that the specified `max_new_tokens` is well under the maximum supported sequence length.
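A quick way to illustrate this (a sketch of my own, assuming the forced decoder prompt only adds a handful of tokens) is to inspect the shape of the returned ids right after the `generate()` call above:

```python
# Run right after generate(): with max_new_tokens=64 the padded output length
# should be roughly 64 plus the few forced decoder prompt tokens,
# nowhere near the 1024 limit mentioned in the warning.
print(outputs.shape)
```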