Fix Bark failing tests #39478

Open — wants to merge 9 commits into main

Changes from 4 commits
2 changes: 2 additions & 0 deletions src/transformers/generation/utils.py
@@ -3988,6 +3988,8 @@ def _beam_search(
            vocab_size = self.config.audio_vocab_size
        elif self.__class__.__name__ == "ImageGPTForCausalImageModeling":
            vocab_size = self.get_output_embeddings().out_features
        elif self.__class__.__name__ == "BarkSemanticModel":
            vocab_size = self.config.output_vocab_size
        else:
            vocab_size = self.config.get_text_config().vocab_size
        decoder_prompt_len = cur_len
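The change dispatches on the model's class name to pick the right vocabulary size for beam search. A minimal standalone sketch of that pattern — the classes and sizes here are hypothetical stand-ins, not the real transformers models:

```python
class _Config:
    # Hypothetical sizes standing in for real model configs.
    output_vocab_size = 10_000  # Bark's semantic-token vocabulary
    vocab_size = 32_000         # default text vocabulary


class BarkSemanticModel:
    config = _Config()


def resolve_vocab_size(model):
    # Mirrors the elif chain in _beam_search: special-case models whose
    # logits dimension is not the text vocab size.
    if model.__class__.__name__ == "BarkSemanticModel":
        return model.config.output_vocab_size
    return model.config.vocab_size


print(resolve_vocab_size(BarkSemanticModel()))  # 10000
```

Dispatching on `__class__.__name__` (rather than `isinstance`) avoids importing every model class into the generation utilities.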
76 changes: 58 additions & 18 deletions src/transformers/models/bark/processing_bark.py
@@ -52,9 +52,9 @@ class BarkProcessor(ProcessorMixin):
    attributes = ["tokenizer"]

    preset_shape = {
        "semantic_prompt": 1,  # 1D array of shape (X,)
        "coarse_prompt": 2,  # 2D array of shape (2,X)
        "fine_prompt": 2,  # 2D array of shape (8,X)
Comment on lines +56 to +58 — Contributor Author:
Add comments for clarity.
    }

    def __init__(self, tokenizer, speaker_embeddings=None):
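As a sketch of what these shapes mean, here is a hypothetical voice-preset dict that would pass the processor's dimension check (the sequence length `X = 256` is made up; real presets are loaded from `.npz` files on the Hub):

```python
import numpy as np

# Hypothetical sequence length; real presets come from .npz files on the Hub.
X = 256
voice_preset = {
    "semantic_prompt": np.zeros(X, dtype=np.int64),     # 1D array of shape (X,)
    "coarse_prompt": np.zeros((2, X), dtype=np.int64),  # 2D array of shape (2, X)
    "fine_prompt": np.zeros((8, X), dtype=np.int64),    # 2D array of shape (8, X)
}

preset_shape = {"semantic_prompt": 1, "coarse_prompt": 2, "fine_prompt": 2}
for key, ndim in preset_shape.items():
    # Same idea as _validate_voice_preset_dict: the number of dimensions must match.
    assert voice_preset[key].ndim == ndim, f"{key} voice preset must be a {ndim}D ndarray."
```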
@@ -114,6 +114,9 @@ def from_pretrained(
        else:
            speaker_embeddings = None

        if speaker_embeddings is not None:
            if "repo_or_path" in speaker_embeddings:
                speaker_embeddings["repo_or_path"] = pretrained_processor_name_or_path
Comment on lines +118 to +120 — @ebezzam (Contributor Author), Jul 18, 2025:
This is because the Suno models are badly configured and fetch speaker embeddings from Yoach's checkpoints (see `repo_or_path`). So when `from_pretrained` is used, the models get speaker embeddings from Yoach's checkpoint. Best is probably to open PRs on the Hub to fix the `repo_or_path` entry and then remove these lines?

@ebezzam (Contributor Author):
I've asked the Suno team to merge these two PRs, but I'm still waiting.

@ebezzam (Contributor Author):
But if Suno doesn't merge them, do we keep pulling from Yoach and remove these lines?

cc @eustlb

        tokenizer = AutoTokenizer.from_pretrained(pretrained_processor_name_or_path, **kwargs)

        return cls(tokenizer=tokenizer, speaker_embeddings=speaker_embeddings)
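The `repo_or_path` override can be illustrated standalone. The dict contents and the upstream repo name below are hypothetical; only the override logic mirrors the diff:

```python
# Hypothetical speaker-embeddings dict as loaded from the Hub JSON; the
# "repo_or_path" entry points at the wrong (upstream author's) repo.
speaker_embeddings = {
    "repo_or_path": "some-author/bark-checkpoints",  # hypothetical misconfigured entry
    "en_speaker_1": {"semantic_prompt": "speaker_embeddings/en_speaker_1_semantic_prompt.npy"},
}
pretrained_processor_name_or_path = "suno/bark-small"

# Same override as in from_pretrained: force embedding lookups to go
# through the repo the processor was actually loaded from.
if speaker_embeddings is not None and "repo_or_path" in speaker_embeddings:
    speaker_embeddings["repo_or_path"] = pretrained_processor_name_or_path

print(speaker_embeddings["repo_or_path"])  # suno/bark-small
```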
@@ -153,22 +156,21 @@ def save_pretrained(

        embeddings_dict["repo_or_path"] = save_directory

Comment on lines -157 to -158 — Contributor Author:
The main change is that I added a property to easily get the available voice presets. The rest is re-indenting, since we no longer need the `prompt_key != "repo_or_path"` check.

        for prompt_key in self.available_voice_presets:
            voice_preset = self._load_voice_preset(prompt_key)

            tmp_dict = {}
            for key in self.speaker_embeddings[prompt_key]:
                np.save(
                    os.path.join(
                        embeddings_dict["repo_or_path"], speaker_embeddings_directory, f"{prompt_key}_{key}"
                    ),
                    voice_preset[key],
                    allow_pickle=False,
                )
                tmp_dict[key] = os.path.join(speaker_embeddings_directory, f"{prompt_key}_{key}.npy")

            embeddings_dict[prompt_key] = tmp_dict

        with open(os.path.join(save_directory, speaker_embeddings_dict_path), "w") as fp:
            json.dump(embeddings_dict, fp)
@@ -222,6 +224,43 @@ def _validate_voice_preset_dict(self, voice_preset: Optional[dict] = None):
            if len(voice_preset[key].shape) != self.preset_shape[key]:
                raise ValueError(f"{key} voice preset must be a {str(self.preset_shape[key])}D ndarray.")

    @property
    def available_voice_presets(self) -> list:
        """
        Returns a list of available voice presets.

        Returns:
            `list[str]`: A list of voice preset names.
        """
        if self.speaker_embeddings is None:
            return []

        voice_presets = list(self.speaker_embeddings.keys())
        if "repo_or_path" in voice_presets:
            voice_presets.remove("repo_or_path")
        return voice_presets

    def _verify_speaker_embeddings(self, remove_unavailable: bool = True):
        # check which embeddings actually downloaded properly / are available
        unavailable_keys = []
        if self.speaker_embeddings is not None:
            for voice_preset in self.available_voice_presets:
                try:
                    voice_preset_dict = self._load_voice_preset(voice_preset)
                    self._validate_voice_preset_dict(voice_preset_dict)
                except Exception:
                    unavailable_keys.append(voice_preset)

        if unavailable_keys:
            logger.warning(
                f"The following {len(unavailable_keys)} speaker embeddings are not available: {unavailable_keys} "
                "If you would like to use them, please check the paths or try downloading them again."
            )

        if remove_unavailable:
            for voice_preset in unavailable_keys:
                del self.speaker_embeddings[voice_preset]

    def __call__(
        self,
        text=None,
@@ -247,7 +286,8 @@ def __call__(
            voice_preset (`str`, `dict[np.ndarray]`):
                The voice preset, i.e the speaker embeddings. It can either be a valid voice_preset name, e.g
                `"en_speaker_1"`, or directly a dictionary of `np.ndarray` embeddings for each submodel of `Bark`. Or
                it can be a valid file name of a local `.npz` single voice preset containing the keys
                `"semantic_prompt"`, `"coarse_prompt"` and `"fine_prompt"`.
            return_tensors (`str` or [`~utils.TensorType`], *optional*):
                If set, will return tensors of a particular framework. Acceptable values are:

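The listing-and-pruning pattern added to the processor can be sketched standalone. The mapping below is a hypothetical in-memory stand-in — the real processor stores file paths and loads `.npy` arrays in `_load_voice_preset` before validating them:

```python
# Hypothetical in-memory mapping; a None value simulates a preset whose
# files failed to download.
speaker_embeddings = {
    "repo_or_path": "suno/bark-small",
    "en_speaker_1": {"semantic_prompt": [0, 1, 2]},
    "en_speaker_2": None,  # simulates a preset that failed to download
}

# available_voice_presets: every key except the bookkeeping "repo_or_path" entry.
available = [k for k in speaker_embeddings if k != "repo_or_path"]

# _verify_speaker_embeddings(remove_unavailable=True): drop presets that
# fail to load or validate.
unavailable = [k for k in available if speaker_embeddings[k] is None]
for k in unavailable:
    del speaker_embeddings[k]

print(sorted(speaker_embeddings))  # ['en_speaker_1', 'repo_or_path']
```

Keeping `repo_or_path` as a sibling of the preset entries is why the extra filtering step is needed everywhere the dict is iterated.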
8 changes: 8 additions & 0 deletions tests/models/bark/test_processor_bark.py
@@ -55,6 +55,14 @@ def test_save_load_pretrained_additional_features(self):
            pretrained_processor_name_or_path=self.checkpoint,
            speaker_embeddings_dict_path=self.speaker_embeddings_dict_path,
        )
        """
        TODO (ebezzam): not all speaker embeddings are properly downloaded.
        My hypothesis: there are many files (~700 speaker embeddings) and some fail to download (not the same ones on different first runs).
        https://github.com/huggingface/transformers/blob/967045082faaaaf3d653bfe665080fd746b2bb60/src/transformers/models/bark/processing_bark.py#L89
        https://github.com/huggingface/transformers/blob/967045082faaaaf3d653bfe665080fd746b2bb60/src/transformers/models/bark/processing_bark.py#L188
        So for testing purposes, we remove the unavailable speaker embeddings before saving.
        """
        processor._verify_speaker_embeddings(remove_unavailable=True)
Comment — Contributor Author:
TL;DR: we need to remove speaker embeddings that couldn't be downloaded properly. Maybe because there are too many (~700)?

Member:
Format nit: let's use single-line comments (`# ...`), and add an HF team member (@eustlb?) in the TODO.

Contributor Author:
Sure! And I'm a new HF team member working with @eustlb 😉

        processor.save_pretrained(
            self.tmpdirname,
            speaker_embeddings_dict_path=self.speaker_embeddings_dict_path,