diff --git a/src/transformers/generation/utils.py b/src/transformers/generation/utils.py
index acd53a20b79c..c1b0e6a7b746 100644
--- a/src/transformers/generation/utils.py
+++ b/src/transformers/generation/utils.py
@@ -4005,6 +4005,8 @@ def _beam_search(
             vocab_size = self.config.audio_vocab_size
         elif self.__class__.__name__ == "ImageGPTForCausalImageModeling":
             vocab_size = self.get_output_embeddings().out_features
+        elif self.__class__.__name__ == "BarkSemanticModel":
+            vocab_size = self.config.output_vocab_size
         else:
             vocab_size = self.config.get_text_config().vocab_size
         decoder_prompt_len = cur_len
diff --git a/src/transformers/models/bark/processing_bark.py b/src/transformers/models/bark/processing_bark.py
index 9825a34a3ba0..155f15cced20 100644
--- a/src/transformers/models/bark/processing_bark.py
+++ b/src/transformers/models/bark/processing_bark.py
@@ -53,9 +53,9 @@ class BarkProcessor(ProcessorMixin):
     attributes = ["tokenizer"]
 
     preset_shape = {
-        "semantic_prompt": 1,
-        "coarse_prompt": 2,
-        "fine_prompt": 2,
+        "semantic_prompt": 1,  # 1D array of shape (X,)
+        "coarse_prompt": 2,  # 2D array of shape (2, X)
+        "fine_prompt": 2,  # 2D array of shape (8, X)
     }
 
     def __init__(self, tokenizer, speaker_embeddings=None):
@@ -115,6 +115,9 @@ def from_pretrained(
         else:
             speaker_embeddings = None
 
+        if speaker_embeddings is not None:
+            if "repo_or_path" in speaker_embeddings:
+                speaker_embeddings["repo_or_path"] = pretrained_processor_name_or_path
         tokenizer = AutoTokenizer.from_pretrained(pretrained_processor_name_or_path, **kwargs)
 
         return cls(tokenizer=tokenizer, speaker_embeddings=speaker_embeddings)
@@ -154,22 +157,21 @@ def save_pretrained(
 
             embeddings_dict["repo_or_path"] = save_directory
 
-            for prompt_key in self.speaker_embeddings:
-                if prompt_key != "repo_or_path":
-                    voice_preset = self._load_voice_preset(prompt_key)
+            for prompt_key in self.available_voice_presets:
+                voice_preset = self._load_voice_preset(prompt_key)
 
-                    tmp_dict = {}
-                    for key in self.speaker_embeddings[prompt_key]:
-                        np.save(
-                            os.path.join(
-                                embeddings_dict["repo_or_path"], speaker_embeddings_directory, f"{prompt_key}_{key}"
-                            ),
-                            voice_preset[key],
-                            allow_pickle=False,
-                        )
-                        tmp_dict[key] = os.path.join(speaker_embeddings_directory, f"{prompt_key}_{key}.npy")
+                tmp_dict = {}
+                for key in self.speaker_embeddings[prompt_key]:
+                    np.save(
+                        os.path.join(
+                            embeddings_dict["repo_or_path"], speaker_embeddings_directory, f"{prompt_key}_{key}"
+                        ),
+                        voice_preset[key],
+                        allow_pickle=False,
+                    )
+                    tmp_dict[key] = os.path.join(speaker_embeddings_directory, f"{prompt_key}_{key}.npy")
 
-                    embeddings_dict[prompt_key] = tmp_dict
+                embeddings_dict[prompt_key] = tmp_dict
 
         with open(os.path.join(save_directory, speaker_embeddings_dict_path), "w") as fp:
             json.dump(embeddings_dict, fp)
@@ -223,6 +225,45 @@ def _validate_voice_preset_dict(self, voice_preset: Optional[dict] = None):
             if len(voice_preset[key].shape) != self.preset_shape[key]:
                 raise ValueError(f"{key} voice preset must be a {str(self.preset_shape[key])}D ndarray.")
 
+    @property
+    def available_voice_presets(self) -> list:
+        """
+        Returns a list of available voice presets.
+
+        Returns:
+            `list[str]`: A list of voice preset names.
+ """ + if self.speaker_embeddings is None: + return [] + + voice_presets = list(self.speaker_embeddings.keys()) + if "repo_or_path" in voice_presets: + voice_presets.remove("repo_or_path") + return voice_presets + + def _verify_speaker_embeddings(self, remove_unavailable: bool = True): + # check which actually downloaded properly / are available + unavailable_keys = [] + if self.speaker_embeddings is not None: + for voice_preset in self.available_voice_presets: + try: + voice_preset_dict = self._load_voice_preset(voice_preset) + except ValueError: + # error from `_load_voice_preset` of path not existing + unavailable_keys.append(voice_preset) + continue + self._validate_voice_preset_dict(voice_preset_dict) + + if unavailable_keys: + logger.warning( + f"The following {len(unavailable_keys)} speaker embeddings are not available: {unavailable_keys} " + "If you would like to use them, please check the paths or try downloading them again." + ) + + if remove_unavailable: + for voice_preset in unavailable_keys: + del self.speaker_embeddings[voice_preset] + def __call__( self, text=None, @@ -248,7 +289,8 @@ def __call__( voice_preset (`str`, `dict[np.ndarray]`): The voice preset, i.e the speaker embeddings. It can either be a valid voice_preset name, e.g `"en_speaker_1"`, or directly a dictionary of `np.ndarray` embeddings for each submodel of `Bark`. Or - it can be a valid file name of a local `.npz` single voice preset. + it can be a valid file name of a local `.npz` single voice preset containing the keys + `"semantic_prompt"`, `"coarse_prompt"` and `"fine_prompt"`. return_tensors (`str` or [`~utils.TensorType`], *optional*): If set, will return tensors of a particular framework. Acceptable values are: diff --git a/tests/models/bark/test_processor_bark.py b/tests/models/bark/test_processor_bark.py index 447d38b95654..e20c3b302f33 100644 --- a/tests/models/bark/test_processor_bark.py +++ b/tests/models/bark/test_processor_bark.py @@ -55,6 +55,13 @@ def test_save_load_pretrained_additional_features(self): pretrained_processor_name_or_path=self.checkpoint, speaker_embeddings_dict_path=self.speaker_embeddings_dict_path, ) + + # TODO (ebezzam) not all speaker embedding are properly downloaded. + # My hypothesis: there are many files (~700 speaker embeddings) and some fail to download (not the same at different first runs) + # https://github.com/huggingface/transformers/blob/967045082faaaaf3d653bfe665080fd746b2bb60/src/transformers/models/bark/processing_bark.py#L89 + # https://github.com/huggingface/transformers/blob/967045082faaaaf3d653bfe665080fd746b2bb60/src/transformers/models/bark/processing_bark.py#L188 + # So for testing purposes, we will remove the unavailable speaker embeddings before saving. + processor._verify_speaker_embeddings(remove_unavailable=True) processor.save_pretrained( self.tmpdirname, speaker_embeddings_dict_path=self.speaker_embeddings_dict_path,