Add StyleTTS 2 #35790
base: main
Conversation
cc @Cyrilvallez, modeling code is ready for review 🤗
Alright! Super sorry about the (very) long delay! Very nice implementation, not much to complain about here! Great work! 🤗
Mostly just a few recurrent but easy-to-fix things:
- Use only the config in `__init__` if possible. I did not check each one, and some may have additional params that do not really make sense in the config because they are hard-coded, such as the boolean args in `StyleTextToSpeech2AdainResBlock1d`, but let's still try to use the config as much as possible (and add additional parameters if still needed for internal-only purposes).
- If possible, all the `transpose` ops should have explicit dim numbers; let's try not to use `-1` if the number of dims is always the same (not sure here).
- The `weight_norm` function should probably always be one or the other, as layer names depend on it (if needed, let's use a hard PyTorch version check for this one).
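As a minimal sketch of the config-driven `__init__` pattern suggested here (the class name is from the PR, but the signature, config fields, and layer are illustrative assumptions, not the actual code):

```python
import torch
import torch.nn as nn


class StyleTextToSpeech2AdainResBlock1d(nn.Module):
    # Hypothetical sketch: sizes come from the config, not loose positional args
    def __init__(self, config, upsample=False):
        super().__init__()
        self.conv = nn.Conv1d(
            config.hidden_size, config.hidden_size, kernel_size=3, padding=1
        )
        # Internal-only, hard-coded flag kept as a plain argument, not a config field
        self.upsample = upsample
```

Parameters that are always hard-coded at call sites arguably belong as plain `__init__` arguments with defaults, keeping the config reserved for values users can actually vary.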
```python
if hasattr(nn.utils.parametrizations, "weight_norm"):
    weight_norm = nn.utils.parametrizations.weight_norm
```
We can only use this one anyway due to the layer naming in the state dicts, no? Did you add the check for older torch versions? Also, IMO it would help clarity to wrap the layer declaration directly with the function, as opposed to applying it afterwards.
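A hedged sketch of what wrapping at declaration time could look like (the layer and sizes are illustrative, with a fallback import for older torch versions):

```python
import torch
import torch.nn as nn

try:
    # torch >= 2.1: parametrization-based variant (changes state-dict key names)
    from torch.nn.utils.parametrizations import weight_norm
except ImportError:
    # older torch: the deprecated functional variant
    from torch.nn.utils import weight_norm

# Wrap the layer directly at declaration instead of applying weight_norm afterwards
conv = weight_norm(nn.Conv1d(16, 16, kernel_size=3, padding=1))
```

Note that the two variants produce different parameter names in the state dict (`weight_g`/`weight_v` vs. `parametrizations.weight.original0`/`original1`), which is why sticking to one matters for checkpoint loading.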
True! We do not need to check for older torch versions, right? `parametrizations.weight_norm` is supported anyway for torch>=2.1.
```python
    batch_first=True,
    enforce_sorted=False
)
self.lstm.flatten_parameters()
```
Hmm, do we need this call on every forward?
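For context, `flatten_parameters()` compacts the LSTM weights into one contiguous buffer for cuDNN; it usually only needs to run once after construction (or after moving the module to the GPU), not on every forward. A minimal sketch of that pattern, with illustrative names:

```python
import torch
import torch.nn as nn
from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence


class PackedLSTM(nn.Module):
    def __init__(self, input_size, hidden_size):
        super().__init__()
        self.lstm = nn.LSTM(input_size, hidden_size, batch_first=True)
        # Compact weights once here, rather than on every forward call
        self.lstm.flatten_parameters()

    def forward(self, x, lengths):
        # Pack variable-length sequences so padding is skipped by the LSTM
        packed = pack_padded_sequence(x, lengths, batch_first=True, enforce_sorted=False)
        out, _ = self.lstm(packed)
        out, _ = pad_packed_sequence(out, batch_first=True)
        return out
```

A per-forward call is mainly needed when the weights can become non-contiguous between calls (e.g. under `DataParallel`).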
```python
class StyleTextToSpeech2Decoder(StyleTextToSpeech2PretrainedModel):
    base_model_prefix = "decoder"
    config_class = StyleTextToSpeech2DecoderConfig
    main_input_name = "hidden_states"
```
Here it's actually different from the default, but from a quick glance at the library I'm not sure it's useful to set it.
Also, I assume modular could not be used at all here?
Any update on this?
Hey @monuminu, this PR went stale due to other priorities. I’ll try to get it moving ASAP, though I’m not sure I’ll manage before the summer break. Thanks for checking in, I’ll bump its priority.
What does this PR do?
Adds StyleTTS 2, supporting the original model as well as other checkpoints like Kokoro.
🆕 🔥 This implementation also adds batch support (early benchmarks show ~50% inference speed improvement for BS 128) and mask support for padded inputs.
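As a generic illustration of the mask support mentioned above (not the PR's actual API): padded positions in a batch of unequal-length inputs can be zeroed with a boolean length mask:

```python
import torch

lengths = torch.tensor([5, 3])          # true lengths of two batched inputs
max_len = int(lengths.max())
# (batch, max_len) boolean mask: True on real positions, False on padding
mask = torch.arange(max_len)[None, :] < lengths[:, None]
hidden = torch.randn(2, max_len, 8)     # e.g. padded hidden states
hidden = hidden * mask.unsqueeze(-1)    # zero out padded frames
```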
Note
This implementation differs slightly from the original codebase. The aim here is to have the clearest possible correspondence in naming and structure with the original papers (StyleTTS 2, which builds on StyleTTS).
TODO
Benchmarks
BS 1