
Conversation

eustlb
Contributor

@eustlb eustlb commented Jan 20, 2025

What does this PR do?

Adds StyleTTS 2, supporting the original model as well as other checkpoints such as Kokoro.

🆕 🔥 This implementation also adds batch support (early benchmarks show ~50% inference speed improvement at BS 128) and mask support for padded inputs.
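As a hedged sketch (not the PR's actual code), mask support for padded inputs typically starts from a boolean padding mask built from per-sample lengths, along these lines:

```python
import torch

# Illustrative only: build a padding mask from per-sample lengths so that
# padded positions can be excluded during batched inference.
lengths = torch.tensor([5, 3, 4])   # true length of each sample in the batch
max_len = int(lengths.max())
# True where the position holds real content, False where it is padding.
mask = torch.arange(max_len)[None, :] < lengths[:, None]  # shape (batch, max_len)
print(mask.shape)  # torch.Size([3, 5])
print(int(mask.sum()))  # 12 non-padded positions
```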

Note

This implementation differs slightly from the original codebase. The aim is to keep naming and structure as close as possible to the original papers (StyleTTS 2, which builds on StyleTTS).

TODO

  • modeling code → ready for review
  • tokenization → add other languages + docstring
  • processor → add docstring
  • tests
  • model_doc

Benchmarks

BS 1

@eustlb eustlb marked this pull request as ready for review February 25, 2025 11:33
@eustlb eustlb requested a review from Cyrilvallez February 25, 2025 11:37
@eustlb
Contributor Author

eustlb commented Feb 25, 2025

cc @Cyrilvallez, modeling code is ready for review 🤗

Member

@Cyrilvallez Cyrilvallez left a comment


Alright! Super sorry about the (very) long delay! Very very nice implementation, not much to complain about here! Great work! 🤗
Mostly just a few recurring but easy-to-fix points:

  • Use only the config in __init__ where possible. I did not check each one, and some may have additional params that do not really make sense in the config because they are hard-coded (e.g. the boolean args in StyleTextToSpeech2AdainResBlock1d), but let's still rely on the config as much as possible (and add extra parameters only where needed for internal purposes)
  • If possible, all the transpose ops should have explicit dim numbers; let's try not to use -1 if the number of dims is always the same (not sure it is here)
  • The weight_norm function should probably always be one or the other, since layer names depend on it (if needed, let's use a hard PyTorch version check for this one)
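To illustrate the transpose point above (a sketch, not code from the PR): when the tensor rank is fixed, explicit dimension indices make the intent visible where negative indexing hides it.

```python
import torch

# Illustrative only: with a fixed, known rank, explicit dim indices in
# transpose read better than -1, even though both give the same result.
x = torch.randn(4, 10, 64)        # (batch, seq_len, channels), rank is always 3
y_explicit = x.transpose(1, 2)    # clearly swaps seq_len and channels
y_negative = x.transpose(1, -1)   # same result here, but the intent is hidden
assert torch.equal(y_explicit, y_negative)
print(y_explicit.shape)  # torch.Size([4, 64, 10])
```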

Comment on lines 149 to 150
if hasattr(nn.utils.parametrizations, "weight_norm"):
weight_norm = nn.utils.parametrizations.weight_norm
Member


We can only use this one anyway due to the layer naming in the state dicts, no? Did you add the check for older torch versions? Also, IMO it would help clarity to wrap the layer declaration directly with the function, rather than applying it afterwards.

Contributor Author


True! We do not need to check for older torch versions, right? parametrizations.weight_norm is available for torch>=2.1 anyway.
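If a hard version check were ever needed (per the review suggestion), a minimal sketch of the comparison logic could look like this; the helper name is illustrative, not from the PR.

```python
# Illustrative only: decide which weight_norm variant applies for a given
# torch version string. Per the thread, torch >= 2.1 ships
# nn.utils.parametrizations.weight_norm; older versions only have the
# deprecated nn.utils.weight_norm.
def uses_parametrizations(torch_version: str) -> bool:
    major, minor = (int(part) for part in torch_version.split(".")[:2])
    return (major, minor) >= (2, 1)

print(uses_parametrizations("2.1.0"))  # True
print(uses_parametrizations("2.0.1"))  # False
```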

batch_first=True,
enforce_sorted=False
)
self.lstm.flatten_parameters()
Member


Hmm, do we need this call on every forward?
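For context on the question above (a sketch, not the PR's class): flatten_parameters() compacts the LSTM weights into one contiguous block for cuDNN, so calling it once after the module is built (or after moving it to its device) is typically enough; repeating it in every forward is redundant unless the weights get re-allocated.

```python
import torch
import torch.nn as nn

# Illustrative only: call flatten_parameters() once at construction time
# instead of on every forward. The class name is made up for this sketch.
class LstmSketch(nn.Module):
    def __init__(self, hidden_size: int = 8):
        super().__init__()
        self.lstm = nn.LSTM(hidden_size, hidden_size, batch_first=True)
        self.lstm.flatten_parameters()  # once here rather than per forward

    def forward(self, x):
        out, _ = self.lstm(x)
        return out

out = LstmSketch()(torch.randn(2, 5, 8))
print(out.shape)  # torch.Size([2, 5, 8])
```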

class StyleTextToSpeech2Decoder(StyleTextToSpeech2PretrainedModel):
base_model_prefix = "decoder"
config_class = StyleTextToSpeech2DecoderConfig
main_input_name = "hidden_states"
Member


Here it's actually different from the default, but from a quick glance at the library I'm not sure it's useful to set it.

@Cyrilvallez
Member

Also, I assume modular could not be used at all here?

@monuminu

Any update on this?

@eustlb
Contributor Author

eustlb commented Jul 28, 2025

Hey @monuminu, this PR went stale due to other priorities. I’ll try to get it moving ASAP, though I’m not sure I’ll manage before the summer break. Thanks for checking in, I’ll bump its priority.
