diff --git a/src/diffusers/models/attention_processor.py b/src/diffusers/models/attention_processor.py index 990245de1742..823e2c052fc8 100755 --- a/src/diffusers/models/attention_processor.py +++ b/src/diffusers/models/attention_processor.py @@ -3117,6 +3117,131 @@ def __call__( hidden_states = hidden_states / attn.rescale_output_factor return hidden_states + + +class F5TTSAttnProcessor2_0: + r""" + Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0). This is + used in the F5 TTS model. It applies rotary embedding to selected heads, and allows MHA, GQA or MQA. + """ + + def __init__(self, pe_attn_head: Optional[int] = None): + if not hasattr(F, "scaled_dot_product_attention"): + raise ImportError( + "F5TTSAttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0." + ) + self.pe_attn_head = pe_attn_head + + + def __call__( + self, + attn: Attention, + hidden_states: torch.Tensor, + encoder_hidden_states: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + rotary_emb: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + from .embeddings import apply_rotary_emb + + residual = hidden_states + + input_ndim = hidden_states.ndim + + if input_ndim == 4: + batch_size, channel, height, width = hidden_states.shape + hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2) + batch_size, sequence_length, _ = ( + hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape + ) + + if attention_mask is not None: + attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size) + # scaled_dot_product_attention expects attention_mask shape to be + # (batch, heads, source_length, target_length) + attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1]) + + query = attn.to_q(hidden_states) + + if encoder_hidden_states is None: + encoder_hidden_states = hidden_states + elif attn.norm_cross: + encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states) + + key = attn.to_k(encoder_hidden_states) + value = attn.to_v(encoder_hidden_states) + + head_dim = query.shape[-1] // attn.heads + kv_heads = key.shape[-1] // head_dim + + query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + + key = key.view(batch_size, -1, kv_heads, head_dim).transpose(1, 2) + value = value.view(batch_size, -1, kv_heads, head_dim).transpose(1, 2) + + if kv_heads != attn.heads: + # if GQA or MQA, repeat the key/value heads to reach the number of query heads. 
+ heads_per_kv_head = attn.heads // kv_heads + key = torch.repeat_interleave(key, heads_per_kv_head, dim=1, output_size=key.shape[1] * heads_per_kv_head) + value = torch.repeat_interleave( + value, heads_per_kv_head, dim=1, output_size=value.shape[1] * heads_per_kv_head + ) + + if attn.norm_q is not None: + query = attn.norm_q(query) + if attn.norm_k is not None: + key = attn.norm_k(key) + + # Apply RoPE if needed + if rotary_emb is not None: + query_dtype = query.dtype + key_dtype = key.dtype + query = query.to(torch.float32) + key = key.to(torch.float32) + + + if self.pe_attn_head is not None: + query_rotated = apply_rotary_emb( + query[..., :self.pe_attn_head, :], rotary_emb, use_real=True, use_real_unbind_dim=-2 + ) + + query_unrotated = query[..., self.pe_attn_head:, :] + query = torch.cat((query_rotated, query_unrotated), dim=-2) + + if not attn.is_cross_attention and self.pe_attn_head is not None: + key_to_rotate, key_unrotated = key[..., :self.pe_attn_head, :], key[..., self.pe_attn_head:, :] + key_rotated = apply_rotary_emb(key_to_rotate, rotary_emb, use_real=True, use_real_unbind_dim=-2) + key = torch.cat((key_rotated, key_unrotated), dim=-2) + + query = query.to(query_dtype) + key = key.to(key_dtype) + + # the output of sdp = (batch, num_heads, seq_len, head_dim) + # TODO: add support for attn.scale when we move to Torch 2.1 + hidden_states = F.scaled_dot_product_attention( + query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False + ) + + hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim) + hidden_states = hidden_states.to(query.dtype) + + # linear proj + hidden_states = attn.to_out[0](hidden_states) + # dropout + hidden_states = attn.to_out[1](hidden_states) + + if input_ndim == 4: + hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width) + + + # these two need to be no-ops + if attn.residual_connection: + hidden_states = hidden_states + residual + + hidden_states = hidden_states / attn.rescale_output_factor + + return hidden_states + + class HunyuanAttnProcessor2_0: diff --git a/src/diffusers/models/transformers/f5tts_transformer.py b/src/diffusers/models/transformers/f5tts_transformer.py new file mode 100644 index 000000000000..ff355fe33146 --- /dev/null +++ b/src/diffusers/models/transformers/f5tts_transformer.py @@ -0,0 +1,474 @@ +""" +ein notation: +b - batch +n - sequence +nt - text sequence +nw - raw wave length +d - dimension +""" + +from __future__ import annotations + +import torch +import torch.nn.functional as F +from torch import nn +import torchaudio +from ..normalization import GlobalResponseNorm, AdaLayerNorm, RMSNorm +import math +from ..embeddings import get_1d_rotary_pos_embed, apply_rotary_emb +from typing import Optional, Union +from ..attention_processor import F5TTSAttnProcessor2_0 +from ..attention import Attention +from einops import rearrange, repeat, reduce, pack, unpack +from torch import nn, einsum, tensor, Tensor, cat, stack, arange, is_tensor +# import jieba +# from pypinyin import Style, lazy_pinyin + + + +class AdaLayerNorm2(nn.Module): + def __init__(self, dim): + super().__init__() + + self.silu = nn.SiLU() + self.linear = nn.Linear(dim, dim * 6) + + self.norm = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6) + + def forward(self, x, emb=None): + emb = self.linear(self.silu(emb)) + shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = torch.chunk(emb, 6, dim=1) + + x = self.norm(x) * (1 + scale_msa[:, None]) + shift_msa[:, None] + 
return x, gate_msa, shift_mlp, scale_mlp, gate_mlp + + + + + + +class F5FeedForward(nn.Module): + def __init__(self, dim, dim_out=None, mult=4, dropout=0.0, approximate: str = "none"): + super().__init__() + inner_dim = int(dim * mult) + dim_out = dim_out if dim_out is not None else dim + + # a bit different ordering from the diffusers FeedForward class, here the inner projection weight comes first, in diffusers the activation comes first + activation = nn.GELU(approximate=approximate) + project_in = nn.Sequential(nn.Linear(dim, inner_dim), activation) + self.ff = nn.Sequential(project_in, nn.Dropout(dropout), nn.Linear(inner_dim, dim_out)) + + def forward(self, x): + return self.ff(x) + + +# Attention with possible joint part +# modified from diffusers/src/diffusers/models/attention_processor.py + + + + + +def is_package_available(package_name: str) -> bool: + try: + import importlib + + package_exists = importlib.util.find_spec(package_name) is not None + return package_exists + except Exception: + return False + + + + +class DiTBlock(nn.Module): + def __init__( + self, + dim, + heads, + dim_head, + ff_mult=4, + dropout=0.1, + qk_norm=None, + pe_attn_head=None, + attn_backend="torch", # "torch" or "flash_attn" + attn_mask_enabled=True, + ): + super().__init__() + + self.attn_norm = AdaLayerNorm2(dim) + self.attn = Attention( + processor=F5TTSAttnProcessor2_0( + pe_attn_head=pe_attn_head, + ), + query_dim=dim, + heads=heads, + dim_head=dim_head, + dropout=dropout, + qk_norm=qk_norm, + bias=True, + ) + + self.ff_norm = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6) + self.ff = F5FeedForward(dim=dim, mult=ff_mult, dropout=dropout, approximate="tanh") + + def forward(self, x, t, mask=None, rope=None): # x: noised input, t: time embedding + # pre-norm & modulation for attention input + norm, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.attn_norm(x, emb=t) + + # attention + attn_output = self.attn(hidden_states=norm, attention_mask=mask, rotary_emb=rope) + + # process attention output for input x + x = x + gate_msa.unsqueeze(1) * attn_output + + norm = self.ff_norm(x) * (1 + scale_mlp[:, None]) + shift_mlp[:, None] + ff_output = self.ff(norm) + x = x + gate_mlp.unsqueeze(1) * ff_output + + return x + + +class ConvPositionEmbedding(nn.Module): + def __init__(self, dim, kernel_size=31, groups=16): + super().__init__() + assert kernel_size % 2 != 0 + self.conv1d = nn.Sequential( + nn.Conv1d(dim, dim, kernel_size, groups=groups, padding=kernel_size // 2), + nn.Mish(), + nn.Conv1d(dim, dim, kernel_size, groups=groups, padding=kernel_size // 2), + nn.Mish(), + ) + + def forward(self, x: float["b n d"], mask: bool["b n"] | None = None): + if mask is not None: + mask = mask[..., None] + x = x.masked_fill(~mask, 0.0) + + x = x.permute(0, 2, 1) + x = self.conv1d(x) + out = x.permute(0, 2, 1) + + if mask is not None: + out = out.masked_fill(~mask, 0.0) + + return out + + + +class SinusPositionEmbedding(nn.Module): + def __init__(self, dim): + super().__init__() + self.dim = dim + + def forward(self, x, scale=1000): + device = x.device + half_dim = self.dim // 2 + emb = math.log(10000) / (half_dim - 1) + emb = torch.exp(torch.arange(half_dim, device=device).float() * -emb) + emb = scale * x.unsqueeze(1) * emb.unsqueeze(0) + emb = torch.cat((emb.sin(), emb.cos()), dim=-1) + return emb + + + +class TimestepEmbedding(nn.Module): + def __init__(self, dim, freq_embed_dim=256): + super().__init__() + self.time_embed = SinusPositionEmbedding(freq_embed_dim) + self.time_mlp = 
nn.Sequential(nn.Linear(freq_embed_dim, dim), nn.SiLU(), nn.Linear(dim, dim)) + + def forward(self, timestep: float["b"]): + time_hidden = self.time_embed(timestep) + time_hidden = time_hidden.to(timestep.dtype) + time = self.time_mlp(time_hidden) # b d + return time + +def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0, theta_rescale_factor=1.0): + # proposed by reddit user bloc97, to rescale rotary embeddings to longer sequence length without fine-tuning + # has some connection to NTK literature + # https://www.reddit.com/r/LocalLLaMA/comments/14lz7j5/ntkaware_scaled_rope_allows_llama_models_to_have/ + # https://github.com/lucidrains/rotary-embedding-torch/blob/main/rotary_embedding_torch/rotary_embedding_torch.py + theta *= theta_rescale_factor ** (dim / (dim - 2)) + freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim)) + t = torch.arange(end, device=freqs.device) # type: ignore + freqs = torch.outer(t, freqs).float() # type: ignore + freqs_cos = torch.cos(freqs) # real part + freqs_sin = torch.sin(freqs) # imaginary part + return torch.cat([freqs_cos, freqs_sin], dim=-1) + +def get_pos_embed_indices(start, length, max_pos, scale=1.0): + # length = length if isinstance(length, int) else length.max() + scale = scale * torch.ones_like(start, dtype=torch.float32) # in case scale is a scalar + pos = ( + start.unsqueeze(1) + + (torch.arange(length, device=start.device, dtype=torch.float32).unsqueeze(0) * scale.unsqueeze(1)).long() + ) + # avoid extra long error. + pos = torch.where(pos < max_pos, pos, max_pos - 1) + return pos + + + +class ConvNeXtV2Block(nn.Module): + def __init__( + self, + dim: int, + intermediate_dim: int, + dilation: int = 1, + ): + super().__init__() + padding = (dilation * (7 - 1)) // 2 + self.dwconv = nn.Conv1d( + dim, dim, kernel_size=7, padding=padding, groups=dim, dilation=dilation + ) # depthwise conv + self.norm = nn.LayerNorm(dim, eps=1e-6) + self.pwconv1 = nn.Linear(dim, intermediate_dim) # pointwise/1x1 convs, implemented with linear layers + self.act = nn.GELU() + self.grn = GlobalResponseNorm(intermediate_dim) + self.pwconv2 = nn.Linear(intermediate_dim, dim) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + residual = x + x = x.transpose(1, 2) # b n d -> b d n + x = self.dwconv(x) + x = x.transpose(1, 2) # b d n -> b n d + x = self.norm(x) + x = self.pwconv1(x) + x = self.act(x) + x = self.grn(x) + x = self.pwconv2(x) + return residual + x + + + + +# Text embedding + + +class TextEmbedding(nn.Module): + def __init__(self, text_num_embeds, text_dim, mask_padding=True, conv_layers=0, conv_mult=2): + super().__init__() + self.text_embed = nn.Embedding(text_num_embeds + 1, text_dim) # use 0 as filler token + + self.mask_padding = mask_padding # mask filler and batch padding tokens or not + + if conv_layers > 0: + self.extra_modeling = True + self.precompute_max_pos = 4096 # ~44s of 24khz audio + self.register_buffer("freqs_cis", precompute_freqs_cis(text_dim, self.precompute_max_pos), persistent=False) + self.text_blocks = nn.Sequential( + *[ConvNeXtV2Block(text_dim, text_dim * conv_mult) for _ in range(conv_layers)] + ) + else: + self.extra_modeling = False + + def forward(self, text: int["b nt"], seq_len, drop_text=False): # noqa: F722 + text = text + 1 # use 0 as filler token. 
preprocess of batch pad -1, see list_str_to_idx() + text = text[:, :seq_len] # curtail if character tokens are more than the mel spec tokens + batch, text_len = text.shape[0], text.shape[1] + text = F.pad(text, (0, seq_len - text_len), value=0) + if self.mask_padding: + text_mask = text == 0 + + if drop_text: # cfg for text + text = torch.zeros_like(text) + + text = self.text_embed(text) # b n -> b n d + + # possible extra modeling + if self.extra_modeling: + # sinus pos emb + batch_start = torch.zeros((batch,), dtype=torch.long) + pos_idx = get_pos_embed_indices(batch_start, seq_len, max_pos=self.precompute_max_pos) + text_pos_embed = self.freqs_cis[pos_idx] + text = text + text_pos_embed + + # convnextv2 blocks + if self.mask_padding: + text = text.masked_fill(text_mask.unsqueeze(-1).expand(-1, -1, text.size(-1)), 0.0) + for block in self.text_blocks: + text = block(text) + text = text.squeeze(0) # TODO for some reason an extra dimension is added + text = text.masked_fill(text_mask.unsqueeze(-1).expand(-1, -1, text.size(-1)), 0.0) + else: + text = self.text_blocks(text) + + return text + + +# noised input audio and context mixing embedding + + +class InputEmbedding(nn.Module): + def __init__(self, mel_dim, text_dim, out_dim): + super().__init__() + self.proj = nn.Linear(mel_dim * 2 + text_dim, out_dim) + self.conv_pos_embed = ConvPositionEmbedding(dim=out_dim) + + def forward(self, x: float["b n d"], cond: float["b n d"], text_embed: float["b n d"], drop_audio_cond=False): # noqa: F722 + if drop_audio_cond: # cfg for cond audio + cond = torch.zeros_like(cond) + + x = self.proj(torch.cat((x, cond, text_embed), dim=-1)) + x = self.conv_pos_embed(x) + x + return x + + +# Transformer backbone using DiT blocks + +# Adding this to decouple the conditoning encoding from the DiT backbone +class ConditioningEncoder(nn.Module): + def __init__( + self, + dim, + text_num_embeds, + text_dim, + text_mask_padding, + conv_layers, + mel_dim, + ): + super().__init__() + self.text_embed = TextEmbedding( + text_num_embeds, text_dim, mask_padding=text_mask_padding, conv_layers=conv_layers + ) + self.input_embed = InputEmbedding(mel_dim, text_dim, dim) + + def forward( + self, + x, # b n d + cond, # b n d + text, # b nt + drop_audio_cond: bool = False, + drop_text: bool = False, + cache: bool = True, + ): + seq_len = x.shape[1] + if cache: + if drop_text: + text_embed = self.text_embed(text, seq_len, drop_text=True) + else: + text_embed = self.text_embed(text, seq_len, drop_text=False) + else: + text_embed = self.text_embed(text, seq_len, drop_text=drop_text) + + x = self.input_embed(x, cond, text_embed, drop_audio_cond=drop_audio_cond) + + return x + + +class DiT(nn.Module): + def __init__( + self, + *, + dim, + depth=8, + heads=8, + dim_head=64, + dropout=0.1, + ff_mult=4, + mel_dim=100, + text_num_embeds=256, + text_dim=None, + text_mask_padding=True, + qk_norm=None, + conv_layers=0, + pe_attn_head=None, + attn_backend="torch", # "torch" | "flash_attn" + attn_mask_enabled=False, + long_skip_connection=False, + checkpoint_activations=False, + ): + super().__init__() + + self.time_embed = TimestepEmbedding(dim) + if text_dim is None: + text_dim = mel_dim + self.dim = dim + self.depth = depth + + self.transformer_blocks = nn.ModuleList( + [ + DiTBlock( + dim=dim, + heads=heads, + dim_head=dim_head, + ff_mult=ff_mult, + dropout=dropout, + qk_norm=qk_norm, + pe_attn_head=pe_attn_head, + attn_backend=attn_backend, + attn_mask_enabled=attn_mask_enabled, + ) + for _ in range(depth) + ] + ) + 
        self.long_skip_connection = nn.Linear(dim * 2, dim, bias=False) if long_skip_connection else None
+
+        self.norm_out = AdaLayerNorm(dim, chunk_dim=1)  # final modulation
+        self.proj_out = nn.Linear(dim, mel_dim)
+
+    def forward(
+        self,
+        x: float["b n d"],  # noised input audio  # noqa: F722
+        time: float["b"] | float[""],  # time step  # noqa: F821 F722
+        mask: bool["b n"] | None = None,  # noqa: F722
+        cache: bool = False,
+    ):
+        batch, seq_len = x.shape[0], x.shape[1]
+        if time.ndim == 0:
+            time = time.repeat(batch)
+
+        temb = self.time_embed(time)  # b d
+        rope = get_1d_rotary_pos_embed(pos=seq_len, dim=self.dim, use_real=True)
+
+        for block in self.transformer_blocks:
+            x = block(x, temb, mask=mask, rope=rope)
+
+        x = self.norm_out(x, temb=temb)
+        output = self.proj_out(x)
+
+        return output
+
+
+# TODO: add BigVGAN support later
+class MelSpec(nn.Module):
+    def __init__(
+        self,
+        n_fft=1024,
+        hop_length=256,
+        win_length=1024,
+        n_mel_channels=100,
+        target_sample_rate=24_000,
+    ):
+        super().__init__()
+
+        self.n_fft = n_fft
+        self.hop_length = hop_length
+        self.win_length = win_length
+        self.n_mel_channels = n_mel_channels
+        self.target_sample_rate = target_sample_rate
+
+        self.mel_stft = torchaudio.transforms.MelSpectrogram(
+            sample_rate=target_sample_rate,
+            n_fft=n_fft,
+            win_length=win_length,
+            hop_length=hop_length,
+            n_mels=n_mel_channels,
+            power=1,
+            center=True,
+            normalized=False,
+            norm=None,
+        )
+
+    def forward(self, wav):
+        mel = self.mel_stft(wav)
+        mel = mel.clamp(min=1e-5).log()
+        return mel
diff --git a/src/diffusers/pipelines/f5tts/convert_ckpt_to_diffusers.py b/src/diffusers/pipelines/f5tts/convert_ckpt_to_diffusers.py
new file mode 100644
index 000000000000..a29ce449250b
--- /dev/null
+++ b/src/diffusers/pipelines/f5tts/convert_ckpt_to_diffusers.py
@@ -0,0 +1,170 @@
+# adapted from the F5-TTS training script.
+import sys +sys.path.append('/Users/ayushmangal/f5_contri/F5-TTS/src') +import os +from importlib.resources import files +import torch +import hydra +from omegaconf import OmegaConf + +from f5_tts.model import CFM, Trainer +from f5_tts.model.dataset import load_dataset +from f5_tts.model.utils import get_tokenizer + + +os.chdir(str(files("f5_tts").joinpath("../.."))) # change working directory to root of project (local editable) +cfg_path = 'F5-TTS/src/f5_tts/configs/F5TTS_v1_Base.yaml' + +model_cfg = OmegaConf.load(cfg_path) +model_cls = hydra.utils.get_class(f"f5_tts.model.{model_cfg.model.backbone}") +model_arc = model_cfg.model.arch +tokenizer = model_cfg.model.tokenizer +mel_spec_type = model_cfg.model.mel_spec.mel_spec_type + + +# set text tokenizer +if tokenizer != "custom": + tokenizer_path = model_cfg.datasets.name +else: + tokenizer_path = model_cfg.model.tokenizer_path +vocab_char_map, vocab_size = get_tokenizer(tokenizer_path, tokenizer) + +# set model +model = CFM( + transformer=model_cls(**model_arc, text_num_embeds=vocab_size, mel_dim=model_cfg.model.mel_spec.n_mel_channels), + mel_spec_kwargs=model_cfg.model.mel_spec, + vocab_char_map=vocab_char_map, +) + +# save model +ckpt_path = 'model_1250000.safetensors' +from safetensors.torch import load_file + +checkpoint = load_file(ckpt_path) +checkpoint = {"ema_model_state_dict": checkpoint} + +checkpoint["model_state_dict"] = { + k.replace("ema_model.", ""): v + for k, v in checkpoint["ema_model_state_dict"].items() + if k not in ["initted", "update", "step"] +} +model.load_state_dict(checkpoint["model_state_dict"], strict=True) + + +with open('vocab.txt', "r", encoding="utf-8") as f: + vocab_char_map = {} + for i, char in enumerate(f): + vocab_char_map[char[:-1]] = i +vocab_size = len(vocab_char_map) + +from diffusers.pipelines.f5tts.pipeline_f5tts import DiT, ConditioningEncoder, F5FlowPipeline + +dit_config = { +"dim": 1024, +"depth": 22, +"heads": 16, +"ff_mult": 2, +"text_dim": 512, +"text_num_embeds": 256, +"text_mask_padding": True, +"qk_norm": None, # null | rms_norm +"conv_layers": 4, +"pe_attn_head": None, +"attn_backend": "torch", # torch | flash_attn +"attn_mask_enabled": False, +"checkpoint_activations": False, # recompute activations and save memory for extra compute +} + + +mel_spec_config = { + "target_sample_rate": 24000, + "n_mel_channels": 100, + "hop_length": 256, + "win_length": 1024, + "n_fft": 1024, +} + + +dit = DiT(**dit_config) +print("DiT model initialized with config:", dit_config) + +conditioning_encoder_config = { + 'dim': 1024, + 'text_num_embeds': vocab_size, + 'text_dim': 512, + 'text_mask_padding': True, + 'conv_layers': 4, + 'mel_dim': mel_spec_config['n_mel_channels'], +} +conditioning_encoder = ConditioningEncoder(**conditioning_encoder_config) +print("Conditioning Encoder initialized with config:", conditioning_encoder_config) + +f5_pipeline = F5FlowPipeline( + transformer=dit, + conditioning_encoder=conditioning_encoder, + odeint_kwargs={"method": "euler"}, + mel_spec_kwargs=mel_spec_config, + vocab_char_map=vocab_char_map, +) +print("F5FlowPipeline initialized with DiT and Conditioning Encoder.") + + + +def load_pipeline_components_from_state_dict(state_dict, f5_pipeline): + """ + Load the components of the F5FlowPipeline from a state_dict. + """ + # print('state_dict, ', state_dict) + conditioning_encoder_state_dict = {} + for key in f5_pipeline.conditioning_encoder.state_dict().keys(): + if 'transformer.' 
+ key in state_dict: + if 'grn' not in key: + conditioning_encoder_state_dict[key] = state_dict['transformer.' + key] + else: + grn_param = state_dict['transformer.' + key] + grn_param = grn_param.unsqueeze(0) + conditioning_encoder_state_dict[key] = grn_param + f5_pipeline.conditioning_encoder.load_state_dict(conditioning_encoder_state_dict) + + + transformer_state_dict = {} + # Load transformer + for key in f5_pipeline.transformer.state_dict().keys(): + if key in state_dict: + transformer_state_dict[key] = state_dict[key] + if 'transformer.' + key in state_dict: + transformer_state_dict[key] = state_dict['transformer.' + key] + f5_pipeline.transformer.load_state_dict(transformer_state_dict) + + return f5_pipeline + +f5_pipeline = load_pipeline_components_from_state_dict(model.state_dict(), f5_pipeline) + + +# check what keys have not changed +for key in f5_pipeline.conditioning_encoder.state_dict().keys(): + if key in model.state_dict(): + if not torch.allclose(f5_pipeline.conditioning_encoder.state_dict()[key], model.state_dict()[key], atol=1e-3): + print(f"Key {key} has changed in the conditioning encoder state dict.") + # Check if the key exists in the model state dict with a 'transformer.' prefix + elif 'transformer.' + key in model.state_dict(): + if not torch.allclose(f5_pipeline.conditioning_encoder.state_dict()[key], model.state_dict()['transformer.' + key], atol=1e-3): + print(f"Key {key} has changed in the conditioning encoder state dict.") + +for key in f5_pipeline.transformer.state_dict().keys(): + if key in model.state_dict(): + if not torch.allclose(f5_pipeline.transformer.state_dict()[key], model.state_dict()[key], atol=1e-3): + print(f"Key {key} has changed in the transformer state dict.") + print(f"Key {key} in model state dict: {model.state_dict()[key]}") + print(f"Key {key} in f5_pipeline state dict: {f5_pipeline.transformer.state_dict()[key]}") + break + elif 'transformer.' + key in model.state_dict(): + if not torch.allclose(f5_pipeline.transformer.state_dict()[key], model.state_dict()['transformer.' + key], atol=1e-3): + print(f"Key {key} has changed in the transformer state dict.") + print(f"Key {key} in model state dict: {model.state_dict()['transformer.' 
+ key]}") + print(f"Key {key} in f5_pipeline state dict: {f5_pipeline.transformer.state_dict()[key]}") + break \ No newline at end of file diff --git a/src/diffusers/pipelines/f5tts/pipeline_f5tts.py b/src/diffusers/pipelines/f5tts/pipeline_f5tts.py new file mode 100644 index 000000000000..7cf7f89c63a1 --- /dev/null +++ b/src/diffusers/pipelines/f5tts/pipeline_f5tts.py @@ -0,0 +1,359 @@ +""" +ein notation: +b - batch +n - sequence +nt - text sequence +nw - raw wave length +d - dimension +""" + +from __future__ import annotations +from random import random +from typing import Callable + +import torch +import torch.nn.functional as F +from torch import nn +from torch.nn.utils.rnn import pad_sequence +from torchdiffeq import odeint +import torchaudio + +import os +import random +from collections import defaultdict +from importlib.resources import files + +import torch +from torch.nn.utils.rnn import pad_sequence +from diffusers.pipelines.pipeline_utils import AudioPipelineOutput, DiffusionPipeline +from vocos import Vocos +from diffusers.models.transformers.f5tts_transformer import DiT, MelSpec, ConditioningEncoder +# helpers +import jieba +from pypinyin import lazy_pinyin, Style + + + + + + + +class F5FlowPipeline(DiffusionPipeline): + def __init__( + self, + transformer: DiT, + conditioning_encoder: ConditioningEncoder, + odeint_kwargs: dict = dict( + method="euler" + ), + mel_spec_kwargs: dict = dict(), + vocab_char_map: dict[str:int] | None = None, + ): + super().__init__() + self.transformer = transformer + self.conditioning_encoder = conditioning_encoder + self.mel_spec = MelSpec(**mel_spec_kwargs) + num_channels = self.mel_spec.n_mel_channels + self.num_channels = num_channels + # sampling related + self.odeint_kwargs = odeint_kwargs + # vocab map for tokenization + self.vocab_char_map = vocab_char_map + + + # simple utf-8 tokenizer, since paper went character based + def list_str_to_tensor(self, text: list[str], padding_value=-1) -> int["b nt"]: # noqa: F722 + list_tensors = [torch.tensor([*bytes(t, "UTF-8")]) for t in text] # ByT5 style + text = pad_sequence(list_tensors, padding_value=padding_value, batch_first=True) + return text + + + # char tokenizer, based on custom dataset's extracted .txt file + def list_str_to_idx( + self, + text: list[str] | list[list[str]], + vocab_char_map: dict[str, int], # {char: idx} + padding_value=-1, + ) -> int["b nt"]: # noqa: F722 + list_idx_tensors = [torch.tensor([vocab_char_map.get(c, 0) for c in t]) for t in text] # pinyin or char style + text = pad_sequence(list_idx_tensors, padding_value=padding_value, batch_first=True) + return text + + + + + def lens_to_mask(self, t: int["b"], length: int | None = None) -> bool["b n"]: # noqa: F722 F821 + if length is None: + length = t.amax() + + seq = torch.arange(length, device=t.device) + return seq[None, :] < t[:, None] + + + def get_epss_timesteps(self, n, device, dtype): + dt = 1 / 32 + predefined_timesteps = { + 5: [0, 2, 4, 8, 16, 32], + 6: [0, 2, 4, 6, 8, 16, 32], + 7: [0, 2, 4, 6, 8, 16, 24, 32], + 10: [0, 2, 4, 6, 8, 12, 16, 20, 24, 28, 32], + 12: [0, 2, 4, 6, 8, 10, 12, 14, 16, 20, 24, 28, 32], + 16: [0, 1, 2, 3, 4, 5, 6, 7, 8, 10, 12, 14, 16, 20, 24, 28, 32], + } + t = predefined_timesteps.get(n, []) + if not t: + return torch.linspace(0, 1, n + 1, device=device, dtype=dtype) + return dt * torch.tensor(t, device=device, dtype=dtype) + + + def convert_char_to_pinyin(self, text_list, polyphone=True): + if jieba.dt.initialized is False: + jieba.default_logger.setLevel(50) # CRITICAL + 
jieba.initialize() + + final_text_list = [] + custom_trans = str.maketrans( + {";": ",", "“": '"', "”": '"', "‘": "'", "’": "'"} + ) # add custom trans here, to address oov + + def is_chinese(c): + return ( + "\u3100" <= c <= "\u9fff" # common chinese characters + ) + + for text in text_list: + char_list = [] + text = text.translate(custom_trans) + for seg in jieba.cut(text): + seg_byte_len = len(bytes(seg, "UTF-8")) + if seg_byte_len == len(seg): # if pure alphabets and symbols + if char_list and seg_byte_len > 1 and char_list[-1] not in " :'\"": + char_list.append(" ") + char_list.extend(seg) + elif polyphone and seg_byte_len == 3 * len(seg): # if pure east asian characters + seg_ = lazy_pinyin(seg, style=Style.TONE3, tone_sandhi=True) + for i, c in enumerate(seg): + if is_chinese(c): + char_list.append(" ") + char_list.append(seg_[i]) + else: # if mixed characters, alphabets and symbols + for c in seg: + if ord(c) < 256: + char_list.extend(c) + elif is_chinese(c): + char_list.append(" ") + char_list.extend(lazy_pinyin(c, style=Style.TONE3, tone_sandhi=True)) + else: + char_list.append(c) + final_text_list.append(char_list) + + return final_text_list + + + def __call__( + self, + ref_audio, + text_list, + nfe_step=32, + cfg_strength=2.0, + sway_sampling_coef=-1, + speed=1, + max_duration=4096, + fix_duration=None, + seed=None, + device="cuda", + steps=32, + ): + + + # Prepare the text + + final_text_list = self.convert_char_to_pinyin(text_list) + ref_audio_len = ref_audio.shape[-1] // self.mel_spec.hop_length + if fix_duration is not None: + duration = int(fix_duration * self.mel_spec.target_sample_rate / self.mel_spec.hop_length) + else: + # Calculate duration + ref_text_len = len(ref_text.encode("utf-8")) + gen_text_len = len(gen_text.encode("utf-8")) + duration = ref_audio_len + int(ref_audio_len / ref_text_len * gen_text_len / local_speed) + + + cond = ref_audio + + if cond.ndim == 2: + cond = self.mel_spec(cond) + cond = cond.permute(0, 2, 1) + assert cond.shape[-1] == self.num_channels + + batch, cond_seq_len, device = *cond.shape[:2], cond.device + lens = torch.full((batch,), cond_seq_len, device=device, dtype=torch.long) + + if isinstance(final_text_list, list): + if self.vocab_char_map is not None: + text = self.list_str_to_idx(final_text_list, self.vocab_char_map).to(device) + else: + text = self.list_str_to_tensor(final_text_list).to(device) + assert text.shape[0] == batch + + # duration + cond_mask = self.lens_to_mask(lens) + if isinstance(duration, int): + duration = torch.full((batch,), duration, device=device, dtype=torch.long) + + duration = torch.maximum( + torch.maximum((text != -1).sum(dim=-1), lens) + 1, duration + ) # duration at least text/audio prompt length plus one token, so something is generated + duration = duration.clamp(max=max_duration) + max_duration = duration.amax() + + cond = F.pad(cond, (0, 0, 0, max_duration - cond_seq_len), value=0.0) + cond_mask = F.pad(cond_mask, (0, max_duration - cond_mask.shape[-1]), value=False) + cond_mask = cond_mask.unsqueeze(-1) + step_cond_input = torch.where( + cond_mask, cond, torch.zeros_like(cond) + ) # allow direct control (cut cond audio) with lens passed in + if batch > 1: + mask = self.lens_to_mask(duration) + else: # save memory and speed up, as single inference need no mask currently + mask = None + + if cfg_strength >= 1e-5 and mask is not None: + mask = torch.cat((mask, mask), dim=0) # for classifier-free guidance, we need to double the batch size + # neural ode + def fn(t, x): + # at each step, conditioning is 
fixed
+            # step_cond = torch.where(cond_mask, cond, torch.zeros_like(cond))
+            step_cond = self.conditioning_encoder(x, step_cond_input, text, drop_audio_cond=False, drop_text=False)
+
+            # predict flow (cond)
+            if cfg_strength < 1e-5:
+                pred = self.transformer(
+                    x=step_cond,
+                    time=t,
+                    mask=mask,
+                    cache=True,
+                )
+                return pred
+
+            # predict flow (cond and uncond), for classifier-free guidance
+            step_uncond = self.conditioning_encoder(x, step_cond_input, text, drop_audio_cond=True, drop_text=True)
+            step_cond = torch.cat((step_cond, step_uncond), dim=0)
+            pred_cfg = self.transformer(
+                x=step_cond,
+                time=t,
+                mask=mask,
+                cache=True,
+            )
+            pred, null_pred = torch.chunk(pred_cfg, 2, dim=0)
+            return pred + (pred - null_pred) * cfg_strength
+
+        # noise input
+        # seed per sample so that batched inference matches single-sample inference regardless of batch size
+        # (some residual difference may remain due to the convolutional layers)
+        y0 = []
+        for dur in duration:
+            if seed is not None:
+                torch.manual_seed(seed)
+            y0.append(torch.randn(dur, self.num_channels, device=device, dtype=step_cond_input.dtype))
+        y0 = pad_sequence(y0, padding_value=0, batch_first=True)
+        t_start = 0
+
+        # TODO: add Empirically Pruned Step Sampling for low NFE
+        t = torch.linspace(t_start, 1, steps + 1, device=device, dtype=step_cond_input.dtype)
+        if sway_sampling_coef is not None:
+            t = t + sway_sampling_coef * (torch.cos(torch.pi / 2 * t) - 1 + t)
+
+        trajectory = odeint(fn, y0, t, **self.odeint_kwargs)
+        # self.transformer.clear_cache()
+
+        sampled = trajectory[-1]
+        out = sampled
+        out = torch.where(cond_mask, cond, out)
+
+        out = out.to(torch.float32)  # generated mel spectrogram
+        out = out[:, ref_audio_len:, :]
+        out = out.permute(0, 2, 1)
+        generated_cpu = out[0]
+
+        # TODO: this should be returned wrapped in an HF output class (e.g. AudioPipelineOutput)
+        return generated_cpu
+
+
+if __name__ == "__main__":
+    print("entering main function")
+
+    dit_config = {
+        "dim": 1024,
+        "depth": 22,
+        "heads": 16,
+        "ff_mult": 2,
+        "text_dim": 512,
+        "text_num_embeds": 256,
+        "text_mask_padding": True,
+        "qk_norm": None,  # null | rms_norm
+        "conv_layers": 4,
+        "pe_attn_head": None,
+        "attn_backend": "torch",  # torch | flash_attn
+        "attn_mask_enabled": False,
+        "checkpoint_activations": False,  # recompute activations and save memory for extra compute
+    }
+
+    mel_spec_config = {
+        "target_sample_rate": 24000,
+        "n_mel_channels": 100,
+        "hop_length": 256,
+        "win_length": 1024,
+        "n_fft": 1024,
+    }
+
+    with open('vocab.txt', "r", encoding="utf-8") as f:
+        vocab_char_map = {}
+        for i, char in enumerate(f):
+            vocab_char_map[char[:-1]] = i
+    vocab_size = len(vocab_char_map)
+
+    dit = DiT(**dit_config)
+    print("DiT model initialized with config:", dit_config)
+
+    conditioning_encoder_config = {
+        'dim': 1024,
+        'text_num_embeds': vocab_size,
+        'text_dim': 512,
+        'text_mask_padding': True,
+        'conv_layers': 4,
+        'mel_dim': mel_spec_config['n_mel_channels'],
+    }
+    conditioning_encoder = ConditioningEncoder(**conditioning_encoder_config)
+    print("Conditioning Encoder initialized with config:", conditioning_encoder_config)
+
+    f5_pipeline = F5FlowPipeline(
+        transformer=dit,
+        conditioning_encoder=conditioning_encoder,
+        odeint_kwargs={"method": "euler"},
+        mel_spec_kwargs=mel_spec_config,
+        vocab_char_map=vocab_char_map,
+    )
+    print("F5FlowPipeline initialized with DiT and Conditioning Encoder.")
+
+    ref_audio = torch.randn(2, 16000)  # Dummy reference audio
+    duration = 250
+
+    ref_text = "This is a test sentence."  # Dummy reference text
+    gen_text = "This is a generated sentence."  # Dummy generated text
+
+    text = [ref_text + gen_text]  # Combine reference and generated text
+    text_list = text * 2
+
+    x = f5_pipeline(
+        ref_audio=ref_audio,
+        text_list=text_list,
+        fix_duration=4,
+        max_duration=4096,
+        device='cpu',
+        steps=2,
+    )
+    print("Generated output shape:", x.shape)
\ No newline at end of file
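
Note on trying the `__main__` example above: the pipeline currently returns a mel spectrogram, not audio, and `pipeline_f5tts.py` imports `Vocos` without using it yet. Below is a minimal sketch (not part of this diff) of vocoding the generated mel, assuming the `charactr/vocos-mel-24khz` checkpoint, which matches the 24 kHz / 100-bin mel settings in `mel_spec_config`:

import torch
from vocos import Vocos

# assumed vocoder checkpoint; chosen to match the 24 kHz, 100-mel configuration used above
vocoder = Vocos.from_pretrained("charactr/vocos-mel-24khz")

# `x` is the (n_mels, frames) tensor returned by F5FlowPipeline.__call__ in the __main__ example
with torch.no_grad():
    waveform = vocoder.decode(x.unsqueeze(0))  # batched mel in, waveform at 24 kHz out

print("Waveform shape:", waveform.shape)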