-
Notifications
You must be signed in to change notification settings - Fork 30.4k
Fix BatchEncoding.to() for nested elements #38985
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
bee69bb
to
32137fd
Compare
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yes, LGTM! Thank you for the PR.
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Neat
cc @itazap for tokenizers, feel free to merge it if you're happy |
@Rocketknight1, @itazap any update on this PR? I can confirm it would address #34634 and a Bark test mentioned in #39478 🙂 |
Good point - I took another look and I think this is safe, so I'm going to merge! If it breaks anything, anyone finding this PR can yell at me 😅 |
What does this PR do?
Extend
BatchEncoding.to()
to also work for nested elements.When using voice presets in Bark, the processor returns a
BatchEncoding
of{ "input_ids": torch.Tensor, "attention_mask": torch.Tensor, "history_prompt": BatchFeature}
Currently, only tensor elements are moved, so running on
cuda
the following code fails withRuntimeError: Expected all tensors to be on the same device, but found at least two devices, cuda:0 and cpu! (when checking argument for argument index in method wrapper_CUDA__index_select)
:A workaround was to manually do
inputs["history_prompt"].to(device)
. This PR fixes this by moving all nested elements with a callableto()
.Fixes #34634
Before submitting
Pull Request section?
to it if that's the case: BarkProcessor voice_preset doesn't work #34634 (comment)
documentation guidelines, and
here are tips on formatting docstrings.
Who can review?
@Rocketknight1