
Conversation

@eginhard (Contributor) commented on Jun 23, 2025

What does this PR do?

Extend BatchEncoding.to() to also work for nested elements.

When using voice presets in Bark, the processor returns a BatchEncoding of the form

  • { "input_ids": torch.Tensor, "attention_mask": torch.Tensor, "history_prompt": BatchFeature }

Currently, only tensor elements are moved, so on CUDA the following code fails with RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cuda:0 and cpu! (when checking argument for argument index in method wrapper_CUDA__index_select):

import scipy.io.wavfile
import torch

from transformers import AutoProcessor
from transformers import BarkModel

# load the model and move it to GPU if available
model = BarkModel.from_pretrained("suno/bark-small")
device = "cuda:0" if torch.cuda.is_available() else "cpu"
model = model.to(device)

sampling_rate = model.generation_config.sample_rate
processor = AutoProcessor.from_pretrained("suno/bark-small")
voice_preset = "v2/en_speaker_6"

# prepare the inputs
text_prompt = "Let's try generating speech, with Bark, a text-to-speech model"
inputs = processor(text_prompt, voice_preset=voice_preset)

# generate speech (fails here without this fix: the nested history_prompt stays on CPU)
speech_output = model.generate(**inputs.to(device))
scipy.io.wavfile.write("bark_out.wav", rate=sampling_rate, data=speech_output[0].cpu().numpy())

A workaround was to manually call inputs["history_prompt"].to(device). This PR fixes the issue by moving all nested elements that have a callable to(), as sketched below.
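
For illustration, here is a minimal sketch of the idea (not the exact code merged in this PR; the helper name move_to_device is hypothetical): any value that exposes a callable to(), such as a tensor or a nested BatchFeature, gets moved along with the top-level tensors, while everything else is passed through unchanged.

import torch

def move_to_device(data: dict, device) -> dict:
    """Return a copy of `data` with every value that has a callable .to() moved to `device`."""
    return {
        key: value.to(device) if callable(getattr(value, "to", None)) else value
        for key, value in data.items()
    }

# Example: the tensor moves to the target device, the plain string stays as-is.
batch = {"input_ids": torch.zeros(1, 4, dtype=torch.long), "voice_preset": "v2/en_speaker_6"}
batch = move_to_device(batch, "cpu")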

Fixes #34634

Who can review?

@Rocketknight1

@Rocketknight1 force-pushed the fix-batchencoding-to branch from bee69bb to 32137fd on June 23, 2025 at 14:29
@Rocketknight1 (Member) left a comment


Yes, LGTM! Thank you for the PR.

@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

@kylesayrs (Contributor) left a comment


Neat

@Rocketknight1 (Member) commented

cc @itazap for tokenizers, feel free to merge it if you're happy

@ebezzam mentioned this pull request on Jul 17, 2025

@ebezzam (Contributor) commented on Jul 17, 2025

@Rocketknight1, @itazap any update on this PR?

I can confirm it would address #34634 and a Bark test mentioned in #39478 🙂

@Rocketknight1 (Member) commented

Good point - I took another look and I think this is safe, so I'm going to merge! If it breaks anything, anyone finding this PR can yell at me 😅

@Rocketknight1 merged commit 561a79a into huggingface:main on July 18, 2025
20 checks passed
@eginhard deleted the fix-batchencoding-to branch on July 18, 2025 at 13:17
zucchini-nlp pushed a commit to zucchini-nlp/transformers that referenced this pull request Jul 22, 2025
zaristei pushed a commit to zaristei/transformers that referenced this pull request Sep 9, 2025

Successfully merging this pull request may close these issues.

BarkProcessor voice_preset doesn't work
5 participants