From ef97fc7be9bd168cfbc4036f701e2169cc1950fd Mon Sep 17 00:00:00 2001 From: ayushtues Date: Sat, 19 Jul 2025 15:14:08 +0530 Subject: [PATCH 1/8] Add placeholder files --- src/diffusers/pipelines/f5tts/modeling_f5tts.py | 0 src/diffusers/pipelines/f5tts/pipeline_f5tts.py | 0 2 files changed, 0 insertions(+), 0 deletions(-) create mode 100644 src/diffusers/pipelines/f5tts/modeling_f5tts.py create mode 100644 src/diffusers/pipelines/f5tts/pipeline_f5tts.py diff --git a/src/diffusers/pipelines/f5tts/modeling_f5tts.py b/src/diffusers/pipelines/f5tts/modeling_f5tts.py new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/src/diffusers/pipelines/f5tts/pipeline_f5tts.py b/src/diffusers/pipelines/f5tts/pipeline_f5tts.py new file mode 100644 index 000000000000..e69de29bb2d1 From 0156dcbc586c69fb5cbbb028f879b3750f5678a1 Mon Sep 17 00:00:00 2001 From: ayushtues Date: Sat, 19 Jul 2025 18:33:31 +0530 Subject: [PATCH 2/8] Add first version of self-contained F5 DiT --- .../models/transformers/f5tts_transformer.py | 678 ++++++++++++++++++ 1 file changed, 678 insertions(+) create mode 100644 src/diffusers/models/transformers/f5tts_transformer.py diff --git a/src/diffusers/models/transformers/f5tts_transformer.py b/src/diffusers/models/transformers/f5tts_transformer.py new file mode 100644 index 000000000000..f792bf6ffbb7 --- /dev/null +++ b/src/diffusers/models/transformers/f5tts_transformer.py @@ -0,0 +1,678 @@ +""" +ein notation: +b - batch +n - sequence +nt - text sequence +nw - raw wave length +d - dimension +""" + +from __future__ import annotations + +import torch +import torch.nn.functional as F +from torch import nn +from ..normalization import GlobalResponseNorm, AdaLayerNorm +import math +from ..embeddings import get_1d_rotary_pos_embed, apply_rotary_emb +from typing import Optional, Union + +from einops import rearrange, repeat, reduce, pack, unpack +from torch import nn, einsum, tensor, Tensor, cat, stack, arange, is_tensor + + +def rotate_half(x): + x = rearrange(x, '... (d r) -> ... d r', r = 2) + x1, x2 = x.unbind(dim = -1) + x = stack((-x2, x1), dim = -1) + return rearrange(x, '... d r -> ... (d r)') + +def apply_rotary_pos_emb(t, freqs, scale = 1): + rot_dim, seq_len, orig_dtype = freqs.shape[-1], t.shape[-2], t.dtype + + freqs = freqs[:, -seq_len:, :] + scale = scale[:, -seq_len:, :] if isinstance(scale, torch.Tensor) else scale + + if t.ndim == 4 and freqs.ndim == 3: + freqs = rearrange(freqs, 'b n d -> b 1 n d') + + # partial rotary embeddings, Wang et al. 
GPT-J + t, t_unrotated = t[..., :rot_dim], t[..., rot_dim:] + t = (t * freqs.cos() * scale) + (rotate_half(t) * freqs.sin() * scale) + out = cat((t, t_unrotated), dim = -1) + + return out.type(orig_dtype) + +class AdaLayerNorm2(nn.Module): + def __init__(self, dim): + super().__init__() + + self.silu = nn.SiLU() + self.linear = nn.Linear(dim, dim * 6) + + self.norm = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6) + + def forward(self, x, emb=None): + emb = self.linear(self.silu(emb)) + shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = torch.chunk(emb, 6, dim=1) + + x = self.norm(x) * (1 + scale_msa[:, None]) + shift_msa[:, None] + return x, gate_msa, shift_mlp, scale_mlp, gate_mlp + + + + +class RMSNorm(nn.Module): + def __init__(self, dim: int, eps: float): + super().__init__() + self.eps = eps + self.weight = nn.Parameter(torch.ones(dim)) + self.native_rms_norm = float(torch.__version__[:3]) >= 2.4 + + def forward(self, x): + if self.native_rms_norm: + if self.weight.dtype in [torch.float16, torch.bfloat16]: + x = x.to(self.weight.dtype) + x = F.rms_norm(x, normalized_shape=(x.shape[-1],), weight=self.weight, eps=self.eps) + else: + variance = x.to(torch.float32).pow(2).mean(-1, keepdim=True) + x = x * torch.rsqrt(variance + self.eps) + if self.weight.dtype in [torch.float16, torch.bfloat16]: + x = x.to(self.weight.dtype) + x = x * self.weight + + return x + + +class FeedForward(nn.Module): + def __init__(self, dim, dim_out=None, mult=4, dropout=0.0, approximate: str = "none"): + super().__init__() + inner_dim = int(dim * mult) + dim_out = dim_out if dim_out is not None else dim + + activation = nn.GELU(approximate=approximate) + project_in = nn.Sequential(nn.Linear(dim, inner_dim), activation) + self.ff = nn.Sequential(project_in, nn.Dropout(dropout), nn.Linear(inner_dim, dim_out)) + + def forward(self, x): + return self.ff(x) + + +# Attention with possible joint part +# modified from diffusers/src/diffusers/models/attention_processor.py + + +class Attention(nn.Module): + def __init__( + self, + processor: AttnProcessor, + dim: int, + heads: int = 8, + dim_head: int = 64, + dropout: float = 0.0, + context_dim: Optional[int] = None, # if not None -> joint attention + context_pre_only: bool = False, + qk_norm: Optional[str] = None, + ): + super().__init__() + + if not hasattr(F, "scaled_dot_product_attention"): + raise ImportError("Attention equires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.") + + self.processor = processor + + self.dim = dim + self.heads = heads + self.inner_dim = dim_head * heads + self.dropout = dropout + + self.context_dim = context_dim + self.context_pre_only = context_pre_only + + self.to_q = nn.Linear(dim, self.inner_dim) + self.to_k = nn.Linear(dim, self.inner_dim) + self.to_v = nn.Linear(dim, self.inner_dim) + + if qk_norm is None: + self.q_norm = None + self.k_norm = None + elif qk_norm == "rms_norm": + self.q_norm = RMSNorm(dim_head, eps=1e-6) + self.k_norm = RMSNorm(dim_head, eps=1e-6) + else: + raise ValueError(f"Unimplemented qk_norm: {qk_norm}") + + if self.context_dim is not None: + self.to_q_c = nn.Linear(context_dim, self.inner_dim) + self.to_k_c = nn.Linear(context_dim, self.inner_dim) + self.to_v_c = nn.Linear(context_dim, self.inner_dim) + if qk_norm is None: + self.c_q_norm = None + self.c_k_norm = None + elif qk_norm == "rms_norm": + self.c_q_norm = RMSNorm(dim_head, eps=1e-6) + self.c_k_norm = RMSNorm(dim_head, eps=1e-6) + + self.to_out = nn.ModuleList([]) + self.to_out.append(nn.Linear(self.inner_dim, dim)) + 
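+        # to_out is a two-slot ModuleList: [0] output projection, [1] dropout, matching how the attention processor indexes it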
self.to_out.append(nn.Dropout(dropout)) + + if self.context_dim is not None and not self.context_pre_only: + self.to_out_c = nn.Linear(self.inner_dim, context_dim) + + def forward( + self, + x: float["b n d"], # noised input x + c: float["b n d"] = None, # context c + mask: bool["b n"] | None = None, + rope=None, # rotary position embedding for x + c_rope=None, # rotary position embedding for c + ) -> torch.Tensor: + if c is not None: + return self.processor(self, x, c=c, mask=mask, rope=rope, c_rope=c_rope) + else: + return self.processor(self, x, mask=mask, rope=rope) + + + +def is_package_available(package_name: str) -> bool: + try: + import importlib + + package_exists = importlib.util.find_spec(package_name) is not None + return package_exists + except Exception: + return False + +# Attention processor + + + +class AttnProcessor: + def __init__( + self, + pe_attn_head: int | None = None, # number of attention head to apply rope, None for all + attn_backend: str = "torch", # "torch" or "flash_attn" + attn_mask_enabled: bool = True, + ): + if attn_backend == "flash_attn": + assert is_package_available("flash_attn"), "Please install flash-attn first." + + self.pe_attn_head = pe_attn_head + self.attn_backend = attn_backend + self.attn_mask_enabled = attn_mask_enabled + + def __call__( + self, + attn: Attention, + x: float["b n d"], # noised input x + mask: bool["b n"] | None = None, + rope=None, # rotary position embedding + ) -> torch.FloatTensor: + batch_size = x.shape[0] + + # `sample` projections + query = attn.to_q(x) + key = attn.to_k(x) + value = attn.to_v(x) + + # attention + inner_dim = key.shape[-1] + head_dim = inner_dim // attn.heads + query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + + # qk norm + if attn.q_norm is not None: + query = attn.q_norm(query) + if attn.k_norm is not None: + key = attn.k_norm(key) + + # apply rotary position embedding + if rope is not None: + freqs, xpos_scale = rope + q_xpos_scale, k_xpos_scale = (xpos_scale, xpos_scale**-1.0) if xpos_scale is not None else (1.0, 1.0) + + if self.pe_attn_head is not None: + pn = self.pe_attn_head + query[:, :pn, :, :] = apply_rotary_pos_emb(query[:, :pn, :, :], freqs, q_xpos_scale) + key[:, :pn, :, :] = apply_rotary_pos_emb(key[:, :pn, :, :], freqs, k_xpos_scale) + else: + query = apply_rotary_pos_emb(query, freqs, q_xpos_scale) + key = apply_rotary_pos_emb(key, freqs, k_xpos_scale) + + if self.attn_backend == "torch": + # mask. e.g. 
inference got a batch with different target durations, mask out the padding + if self.attn_mask_enabled and mask is not None: + attn_mask = mask + attn_mask = attn_mask.unsqueeze(1).unsqueeze(1) # 'b n -> b 1 1 n' + attn_mask = attn_mask.expand(batch_size, attn.heads, query.shape[-2], key.shape[-2]) + else: + attn_mask = None + x = F.scaled_dot_product_attention(query, key, value, attn_mask=attn_mask, dropout_p=0.0, is_causal=False) + x = x.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim) + + # elif self.attn_backend == "flash_attn": + # query = query.transpose(1, 2) # [b, h, n, d] -> [b, n, h, d] + # key = key.transpose(1, 2) + # value = value.transpose(1, 2) + # if self.attn_mask_enabled and mask is not None: + # query, indices, q_cu_seqlens, q_max_seqlen_in_batch, _ = unpad_input(query, mask) + # key, _, k_cu_seqlens, k_max_seqlen_in_batch, _ = unpad_input(key, mask) + # value, _, _, _, _ = unpad_input(value, mask) + # x = flash_attn_varlen_func( + # query, + # key, + # value, + # q_cu_seqlens, + # k_cu_seqlens, + # q_max_seqlen_in_batch, + # k_max_seqlen_in_batch, + # ) + # x = pad_input(x, indices, batch_size, q_max_seqlen_in_batch) + # x = x.reshape(batch_size, -1, attn.heads * head_dim) + # else: + # x = flash_attn_func(query, key, value, dropout_p=0.0, causal=False) + # x = x.reshape(batch_size, -1, attn.heads * head_dim) + + x = x.to(query.dtype) + + # linear proj + x = attn.to_out[0](x) + # dropout + x = attn.to_out[1](x) + + if mask is not None: + mask = mask.unsqueeze(-1) + x = x.masked_fill(~mask, 0.0) + + return x + + +class DiTBlock(nn.Module): + def __init__( + self, + dim, + heads, + dim_head, + ff_mult=4, + dropout=0.1, + qk_norm=None, + pe_attn_head=None, + attn_backend="torch", # "torch" or "flash_attn" + attn_mask_enabled=True, + ): + super().__init__() + + self.attn_norm = AdaLayerNorm2(dim) + self.attn = Attention( + processor=AttnProcessor( + pe_attn_head=pe_attn_head, + attn_backend=attn_backend, + attn_mask_enabled=attn_mask_enabled, + ), + dim=dim, + heads=heads, + dim_head=dim_head, + dropout=dropout, + qk_norm=qk_norm, + ) + + self.ff_norm = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6) + self.ff = FeedForward(dim=dim, mult=ff_mult, dropout=dropout, approximate="tanh") + + def forward(self, x, t, mask=None, rope=None): # x: noised input, t: time embedding + # pre-norm & modulation for attention input + norm, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.attn_norm(x, emb=t) + + # attention + attn_output = self.attn(x=norm, mask=mask, rope=rope) + + # process attention output for input x + x = x + gate_msa.unsqueeze(1) * attn_output + + norm = self.ff_norm(x) * (1 + scale_mlp[:, None]) + shift_mlp[:, None] + ff_output = self.ff(norm) + x = x + gate_mlp.unsqueeze(1) * ff_output + + return x + + +class ConvPositionEmbedding(nn.Module): + def __init__(self, dim, kernel_size=31, groups=16): + super().__init__() + assert kernel_size % 2 != 0 + self.conv1d = nn.Sequential( + nn.Conv1d(dim, dim, kernel_size, groups=groups, padding=kernel_size // 2), + nn.Mish(), + nn.Conv1d(dim, dim, kernel_size, groups=groups, padding=kernel_size // 2), + nn.Mish(), + ) + + def forward(self, x: float["b n d"], mask: bool["b n"] | None = None): + if mask is not None: + mask = mask[..., None] + x = x.masked_fill(~mask, 0.0) + + x = x.permute(0, 2, 1) + x = self.conv1d(x) + out = x.permute(0, 2, 1) + + if mask is not None: + out = out.masked_fill(~mask, 0.0) + + return out + + + +class SinusPositionEmbedding(nn.Module): + def __init__(self, dim): + 
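+        # classic sinusoidal embedding: frequencies follow 10000^(-k/(half_dim-1)), and sin/cos halves are concatenated to `dim` features;
+        # forward() multiplies the input by `scale` (default 1000) before embedding, as is common for diffusion timestep embeddings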
super().__init__() + self.dim = dim + + def forward(self, x, scale=1000): + device = x.device + half_dim = self.dim // 2 + emb = math.log(10000) / (half_dim - 1) + emb = torch.exp(torch.arange(half_dim, device=device).float() * -emb) + emb = scale * x.unsqueeze(1) * emb.unsqueeze(0) + emb = torch.cat((emb.sin(), emb.cos()), dim=-1) + return emb + + + +class TimestepEmbedding(nn.Module): + def __init__(self, dim, freq_embed_dim=256): + super().__init__() + self.time_embed = SinusPositionEmbedding(freq_embed_dim) + self.time_mlp = nn.Sequential(nn.Linear(freq_embed_dim, dim), nn.SiLU(), nn.Linear(dim, dim)) + + def forward(self, timestep: float["b"]): + time_hidden = self.time_embed(timestep) + time_hidden = time_hidden.to(timestep.dtype) + time = self.time_mlp(time_hidden) # b d + return time + +def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0, theta_rescale_factor=1.0): + # proposed by reddit user bloc97, to rescale rotary embeddings to longer sequence length without fine-tuning + # has some connection to NTK literature + # https://www.reddit.com/r/LocalLLaMA/comments/14lz7j5/ntkaware_scaled_rope_allows_llama_models_to_have/ + # https://github.com/lucidrains/rotary-embedding-torch/blob/main/rotary_embedding_torch/rotary_embedding_torch.py + theta *= theta_rescale_factor ** (dim / (dim - 2)) + freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim)) + t = torch.arange(end, device=freqs.device) # type: ignore + freqs = torch.outer(t, freqs).float() # type: ignore + freqs_cos = torch.cos(freqs) # real part + freqs_sin = torch.sin(freqs) # imaginary part + return torch.cat([freqs_cos, freqs_sin], dim=-1) + +def get_pos_embed_indices(start, length, max_pos, scale=1.0): + # length = length if isinstance(length, int) else length.max() + scale = scale * torch.ones_like(start, dtype=torch.float32) # in case scale is a scalar + pos = ( + start.unsqueeze(1) + + (torch.arange(length, device=start.device, dtype=torch.float32).unsqueeze(0) * scale.unsqueeze(1)).long() + ) + # avoid extra long error. 
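+    # clamp out-of-range positions to the last precomputed index so indexing freqs_cis never overflows
+    # e.g. start=[0], length=4, scale=1.0, max_pos=4096 -> pos=[[0, 1, 2, 3]]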
+ pos = torch.where(pos < max_pos, pos, max_pos - 1) + return pos + + + +class ConvNeXtV2Block(nn.Module): + def __init__( + self, + dim: int, + intermediate_dim: int, + dilation: int = 1, + ): + super().__init__() + padding = (dilation * (7 - 1)) // 2 + self.dwconv = nn.Conv1d( + dim, dim, kernel_size=7, padding=padding, groups=dim, dilation=dilation + ) # depthwise conv + self.norm = nn.LayerNorm(dim, eps=1e-6) + self.pwconv1 = nn.Linear(dim, intermediate_dim) # pointwise/1x1 convs, implemented with linear layers + self.act = nn.GELU() + self.grn = GlobalResponseNorm(intermediate_dim) + self.pwconv2 = nn.Linear(intermediate_dim, dim) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + residual = x + x = x.transpose(1, 2) # b n d -> b d n + x = self.dwconv(x) + x = x.transpose(1, 2) # b d n -> b n d + x = self.norm(x) + x = self.pwconv1(x) + x = self.act(x) + x = self.grn(x) + x = self.pwconv2(x) + return residual + x + + + + +# Text embedding + + +class TextEmbedding(nn.Module): + def __init__(self, text_num_embeds, text_dim, mask_padding=True, conv_layers=0, conv_mult=2): + super().__init__() + self.text_embed = nn.Embedding(text_num_embeds + 1, text_dim) # use 0 as filler token + + self.mask_padding = mask_padding # mask filler and batch padding tokens or not + + if conv_layers > 0: + self.extra_modeling = True + self.precompute_max_pos = 4096 # ~44s of 24khz audio + self.register_buffer("freqs_cis", precompute_freqs_cis(text_dim, self.precompute_max_pos), persistent=False) + self.text_blocks = nn.Sequential( + *[ConvNeXtV2Block(text_dim, text_dim * conv_mult) for _ in range(conv_layers)] + ) + else: + self.extra_modeling = False + + def forward(self, text: int["b nt"], seq_len, drop_text=False): # noqa: F722 + text = text + 1 # use 0 as filler token. 
preprocess of batch pad -1, see list_str_to_idx() + text = text[:, :seq_len] # curtail if character tokens are more than the mel spec tokens + batch, text_len = text.shape[0], text.shape[1] + text = F.pad(text, (0, seq_len - text_len), value=0) + if self.mask_padding: + text_mask = text == 0 + + if drop_text: # cfg for text + text = torch.zeros_like(text) + + text = self.text_embed(text) # b n -> b n d + + # possible extra modeling + if self.extra_modeling: + # sinus pos emb + batch_start = torch.zeros((batch,), dtype=torch.long) + pos_idx = get_pos_embed_indices(batch_start, seq_len, max_pos=self.precompute_max_pos) + text_pos_embed = self.freqs_cis[pos_idx] + text = text + text_pos_embed + + # convnextv2 blocks + if self.mask_padding: + text = text.masked_fill(text_mask.unsqueeze(-1).expand(-1, -1, text.size(-1)), 0.0) + for block in self.text_blocks: + text = block(text) + text = text.masked_fill(text_mask.unsqueeze(-1).expand(-1, -1, text.size(-1)), 0.0) + else: + text = self.text_blocks(text) + + return text + + +# noised input audio and context mixing embedding + + +class InputEmbedding(nn.Module): + def __init__(self, mel_dim, text_dim, out_dim): + super().__init__() + self.proj = nn.Linear(mel_dim * 2 + text_dim, out_dim) + self.conv_pos_embed = ConvPositionEmbedding(dim=out_dim) + + def forward(self, x: float["b n d"], cond: float["b n d"], text_embed: float["b n d"], drop_audio_cond=False): # noqa: F722 + if drop_audio_cond: # cfg for cond audio + cond = torch.zeros_like(cond) + + x = self.proj(torch.cat((x, cond, text_embed), dim=-1)) + x = self.conv_pos_embed(x) + x + return x + + +# Transformer backbone using DiT blocks + + +class DiT(nn.Module): + def __init__( + self, + *, + dim, + depth=8, + heads=8, + dim_head=64, + dropout=0.1, + ff_mult=4, + mel_dim=100, + text_num_embeds=256, + text_dim=None, + text_mask_padding=True, + qk_norm=None, + conv_layers=0, + pe_attn_head=None, + attn_backend="torch", # "torch" | "flash_attn" + attn_mask_enabled=False, + long_skip_connection=False, + checkpoint_activations=False, + ): + super().__init__() + + self.time_embed = TimestepEmbedding(dim) + if text_dim is None: + text_dim = mel_dim + self.text_embed = TextEmbedding( + text_num_embeds, text_dim, mask_padding=text_mask_padding, conv_layers=conv_layers + ) + self.text_cond, self.text_uncond = None, None # text cache + self.input_embed = InputEmbedding(mel_dim, text_dim, dim) + self.dim = dim + self.depth = depth + + self.transformer_blocks = nn.ModuleList( + [ + DiTBlock( + dim=dim, + heads=heads, + dim_head=dim_head, + ff_mult=ff_mult, + dropout=dropout, + qk_norm=qk_norm, + pe_attn_head=pe_attn_head, + attn_backend=attn_backend, + attn_mask_enabled=attn_mask_enabled, + ) + for _ in range(depth) + ] + ) + self.long_skip_connection = nn.Linear(dim * 2, dim, bias=False) if long_skip_connection else None + + self.norm_out = AdaLayerNorm(dim) # final modulation + self.proj_out = nn.Linear(dim, mel_dim) + + self.checkpoint_activations = checkpoint_activations + + self.initialize_weights() + + def initialize_weights(self): + # Zero-out AdaLN layers in DiT blocks: + for block in self.transformer_blocks: + nn.init.constant_(block.attn_norm.linear.weight, 0) + nn.init.constant_(block.attn_norm.linear.bias, 0) + + # Zero-out output layers: + nn.init.constant_(self.norm_out.linear.weight, 0) + nn.init.constant_(self.norm_out.linear.bias, 0) + nn.init.constant_(self.proj_out.weight, 0) + nn.init.constant_(self.proj_out.bias, 0) + + def ckpt_wrapper(self, module): + # 
https://github.com/chuanyangjin/fast-DiT/blob/main/models.py + def ckpt_forward(*inputs): + outputs = module(*inputs) + return outputs + + return ckpt_forward + + def get_input_embed( + self, + x, # b n d + cond, # b n d + text, # b nt + drop_audio_cond: bool = False, + drop_text: bool = False, + cache: bool = True, + ): + seq_len = x.shape[1] + if cache: + if drop_text: + if self.text_uncond is None: + self.text_uncond = self.text_embed(text, seq_len, drop_text=True) + text_embed = self.text_uncond + else: + if self.text_cond is None: + self.text_cond = self.text_embed(text, seq_len, drop_text=False) + text_embed = self.text_cond + else: + text_embed = self.text_embed(text, seq_len, drop_text=drop_text) + + x = self.input_embed(x, cond, text_embed, drop_audio_cond=drop_audio_cond) + + return x + + def clear_cache(self): + self.text_cond, self.text_uncond = None, None + + def forward( + self, + x: float["b n d"], # nosied input audio # noqa: F722 + cond: float["b n d"], # masked cond audio # noqa: F722 + text: int["b nt"], # text # noqa: F722 + time: float["b"] | float[""], # time step # noqa: F821 F722 + mask: bool["b n"] | None = None, # noqa: F722 + drop_audio_cond: bool = False, # cfg for cond audio + drop_text: bool = False, # cfg for text + cfg_infer: bool = False, # cfg inference, pack cond & uncond forward + cache: bool = False, + ): + batch, seq_len = x.shape[0], x.shape[1] + if time.ndim == 0: + time = time.repeat(batch) + + # t: conditioning time, text: text, x: noised audio + cond audio + text + t = self.time_embed(time) + if cfg_infer: # pack cond & uncond forward: b n d -> 2b n d + x_cond = self.get_input_embed(x, cond, text, drop_audio_cond=False, drop_text=False, cache=cache) + x_uncond = self.get_input_embed(x, cond, text, drop_audio_cond=True, drop_text=True, cache=cache) + x = torch.cat((x_cond, x_uncond), dim=0) + t = torch.cat((t, t), dim=0) + mask = torch.cat((mask, mask), dim=0) if mask is not None else None + else: + x = self.get_input_embed(x, cond, text, drop_audio_cond=drop_audio_cond, drop_text=drop_text, cache=cache) + + rope = get_1d_rotary_pos_embed(seq_len, self.dim, device=x.device) + + for block in self.transformer_blocks: + if self.checkpoint_activations: + # https://pytorch.org/docs/stable/checkpoint.html#torch.utils.checkpoint.checkpoint + x = torch.utils.checkpoint.checkpoint(self.ckpt_wrapper(block), x, t, mask, rope, use_reentrant=False) + else: + x = block(x, t, mask=mask, rope=rope) + + x = self.norm_out(x, t) + output = self.proj_out(x) + + return output From 1450cf285d9cfa72f9196991faed1a694da5cec1 Mon Sep 17 00:00:00 2001 From: ayushtues Date: Sat, 19 Jul 2025 21:11:22 +0530 Subject: [PATCH 3/8] Add F5 pipeline code in a single file --- .../pipelines/f5tts/modeling_f5tts.py | 0 .../pipelines/f5tts/pipeline_f5tts.py | 863 ++++++++++++++++++ 2 files changed, 863 insertions(+) delete mode 100644 src/diffusers/pipelines/f5tts/modeling_f5tts.py diff --git a/src/diffusers/pipelines/f5tts/modeling_f5tts.py b/src/diffusers/pipelines/f5tts/modeling_f5tts.py deleted file mode 100644 index e69de29bb2d1..000000000000 diff --git a/src/diffusers/pipelines/f5tts/pipeline_f5tts.py b/src/diffusers/pipelines/f5tts/pipeline_f5tts.py index e69de29bb2d1..33f1b3bb7f3e 100644 --- a/src/diffusers/pipelines/f5tts/pipeline_f5tts.py +++ b/src/diffusers/pipelines/f5tts/pipeline_f5tts.py @@ -0,0 +1,863 @@ +""" +ein notation: +b - batch +n - sequence +nt - text sequence +nw - raw wave length +d - dimension +""" + +from __future__ import annotations + +from random import 
random +from typing import Callable + +import torch +import torch.nn.functional as F +from torch import nn +from torch.nn.utils.rnn import pad_sequence +from torchdiffeq import odeint + +import torchaudio +from librosa.filters import mel as librosa_mel_fn + +import os +import random +from collections import defaultdict +from importlib.resources import files + +import jieba +import torch +from pypinyin import Style, lazy_pinyin +from torch.nn.utils.rnn import pad_sequence + + +# helpers + + +def exists(v): + return v is not None + + +def default(v, d): + return v if exists(v) else d + + +def is_package_available(package_name: str) -> bool: + try: + import importlib + + package_exists = importlib.util.find_spec(package_name) is not None + return package_exists + except Exception: + return False + + + +def lens_to_mask(t: int["b"], length: int | None = None) -> bool["b n"]: # noqa: F722 F821 + if not exists(length): + length = t.amax() + + seq = torch.arange(length, device=t.device) + return seq[None, :] < t[:, None] + + +def mask_from_start_end_indices(seq_len: int["b"], start: int["b"], end: int["b"]): # noqa: F722 F821 + max_seq_len = seq_len.max().item() + seq = torch.arange(max_seq_len, device=start.device).long() + start_mask = seq[None, :] >= start[:, None] + end_mask = seq[None, :] < end[:, None] + return start_mask & end_mask + + +def mask_from_frac_lengths(seq_len: int["b"], frac_lengths: float["b"]): # noqa: F722 F821 + lengths = (frac_lengths * seq_len).long() + max_start = seq_len - lengths + + rand = torch.rand_like(frac_lengths) + start = (max_start * rand).long().clamp(min=0) + end = start + lengths + + return mask_from_start_end_indices(seq_len, start, end) + + +def maybe_masked_mean(t: float["b n d"], mask: bool["b n"] = None) -> float["b d"]: # noqa: F722 + if not exists(mask): + return t.mean(dim=1) + + t = torch.where(mask[:, :, None], t, torch.tensor(0.0, device=t.device)) + num = t.sum(dim=1) + den = mask.float().sum(dim=1) + + return num / den.clamp(min=1.0) + + +# simple utf-8 tokenizer, since paper went character based +def list_str_to_tensor(text: list[str], padding_value=-1) -> int["b nt"]: # noqa: F722 + list_tensors = [torch.tensor([*bytes(t, "UTF-8")]) for t in text] # ByT5 style + text = pad_sequence(list_tensors, padding_value=padding_value, batch_first=True) + return text + + +# char tokenizer, based on custom dataset's extracted .txt file +def list_str_to_idx( + text: list[str] | list[list[str]], + vocab_char_map: dict[str, int], # {char: idx} + padding_value=-1, +) -> int["b nt"]: # noqa: F722 + list_idx_tensors = [torch.tensor([vocab_char_map.get(c, 0) for c in t]) for t in text] # pinyin or char style + text = pad_sequence(list_idx_tensors, padding_value=padding_value, batch_first=True) + return text + + +# Get tokenizer + + +def get_tokenizer(dataset_name, tokenizer: str = "pinyin"): + """ + tokenizer - "pinyin" do g2p for only chinese characters, need .txt vocab_file + - "char" for char-wise tokenizer, need .txt vocab_file + - "byte" for utf-8 tokenizer + - "custom" if you're directly passing in a path to the vocab.txt you want to use + vocab_size - if use "pinyin", all available pinyin types, common alphabets (also those with accent) and symbols + - if use "char", derived from unfiltered character & symbol counts of custom dataset + - if use "byte", set to 256 (unicode byte range) + """ + if tokenizer in ["pinyin", "char"]: + tokenizer_path = os.path.join(files("f5_tts").joinpath("../../data"), f"{dataset_name}_{tokenizer}/vocab.txt") + with 
open(tokenizer_path, "r", encoding="utf-8") as f: + vocab_char_map = {} + for i, char in enumerate(f): + vocab_char_map[char[:-1]] = i + vocab_size = len(vocab_char_map) + assert vocab_char_map[" "] == 0, "make sure space is of idx 0 in vocab.txt, cuz 0 is used for unknown char" + + elif tokenizer == "byte": + vocab_char_map = None + vocab_size = 256 + + elif tokenizer == "custom": + with open(dataset_name, "r", encoding="utf-8") as f: + vocab_char_map = {} + for i, char in enumerate(f): + vocab_char_map[char[:-1]] = i + vocab_size = len(vocab_char_map) + + return vocab_char_map, vocab_size + + +# convert char to pinyin + + +def convert_char_to_pinyin(text_list, polyphone=True): + if jieba.dt.initialized is False: + jieba.default_logger.setLevel(50) # CRITICAL + jieba.initialize() + + final_text_list = [] + custom_trans = str.maketrans( + {";": ",", "“": '"', "”": '"', "‘": "'", "’": "'"} + ) # add custom trans here, to address oov + + def is_chinese(c): + return ( + "\u3100" <= c <= "\u9fff" # common chinese characters + ) + + for text in text_list: + char_list = [] + text = text.translate(custom_trans) + for seg in jieba.cut(text): + seg_byte_len = len(bytes(seg, "UTF-8")) + if seg_byte_len == len(seg): # if pure alphabets and symbols + if char_list and seg_byte_len > 1 and char_list[-1] not in " :'\"": + char_list.append(" ") + char_list.extend(seg) + elif polyphone and seg_byte_len == 3 * len(seg): # if pure east asian characters + seg_ = lazy_pinyin(seg, style=Style.TONE3, tone_sandhi=True) + for i, c in enumerate(seg): + if is_chinese(c): + char_list.append(" ") + char_list.append(seg_[i]) + else: # if mixed characters, alphabets and symbols + for c in seg: + if ord(c) < 256: + char_list.extend(c) + elif is_chinese(c): + char_list.append(" ") + char_list.extend(lazy_pinyin(c, style=Style.TONE3, tone_sandhi=True)) + else: + char_list.append(c) + final_text_list.append(char_list) + + return final_text_list + + +# filter func for dirty data with many repetitions + + +def repetition_found(text, length=2, tolerance=10): + pattern_count = defaultdict(int) + for i in range(len(text) - length + 1): + pattern = text[i : i + length] + pattern_count[pattern] += 1 + for pattern, count in pattern_count.items(): + if count > tolerance: + return True + return False + + +# get the empirically pruned step for sampling + + +def get_epss_timesteps(n, device, dtype): + dt = 1 / 32 + predefined_timesteps = { + 5: [0, 2, 4, 8, 16, 32], + 6: [0, 2, 4, 6, 8, 16, 32], + 7: [0, 2, 4, 6, 8, 16, 24, 32], + 10: [0, 2, 4, 6, 8, 12, 16, 20, 24, 28, 32], + 12: [0, 2, 4, 6, 8, 10, 12, 14, 16, 20, 24, 28, 32], + 16: [0, 1, 2, 3, 4, 5, 6, 7, 8, 10, 12, 14, 16, 20, 24, 28, 32], + } + t = predefined_timesteps.get(n, []) + if not t: + return torch.linspace(0, 1, n + 1, device=device, dtype=dtype) + return dt * torch.tensor(t, device=device, dtype=dtype) + + + + +mel_basis_cache = {} +hann_window_cache = {} + + +def get_bigvgan_mel_spectrogram( + waveform, + n_fft=1024, + n_mel_channels=100, + target_sample_rate=24000, + hop_length=256, + win_length=1024, + fmin=0, + fmax=None, + center=False, +): # Copy from https://github.com/NVIDIA/BigVGAN/tree/main + device = waveform.device + key = f"{n_fft}_{n_mel_channels}_{target_sample_rate}_{hop_length}_{win_length}_{fmin}_{fmax}_{device}" + + if key not in mel_basis_cache: + mel = librosa_mel_fn(sr=target_sample_rate, n_fft=n_fft, n_mels=n_mel_channels, fmin=fmin, fmax=fmax) + mel_basis_cache[key] = torch.from_numpy(mel).float().to(device) # TODO: why they need .float()? 
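+        # cache the mel filterbank and Hann window per STFT configuration and device so repeated calls skip recomputation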
+ hann_window_cache[key] = torch.hann_window(win_length).to(device) + + mel_basis = mel_basis_cache[key] + hann_window = hann_window_cache[key] + + padding = (n_fft - hop_length) // 2 + waveform = torch.nn.functional.pad(waveform.unsqueeze(1), (padding, padding), mode="reflect").squeeze(1) + + spec = torch.stft( + waveform, + n_fft, + hop_length=hop_length, + win_length=win_length, + window=hann_window, + center=center, + pad_mode="reflect", + normalized=False, + onesided=True, + return_complex=True, + ) + spec = torch.sqrt(torch.view_as_real(spec).pow(2).sum(-1) + 1e-9) + + mel_spec = torch.matmul(mel_basis, spec) + mel_spec = torch.log(torch.clamp(mel_spec, min=1e-5)) + + return mel_spec + + +def get_vocos_mel_spectrogram( + waveform, + n_fft=1024, + n_mel_channels=100, + target_sample_rate=24000, + hop_length=256, + win_length=1024, +): + mel_stft = torchaudio.transforms.MelSpectrogram( + sample_rate=target_sample_rate, + n_fft=n_fft, + win_length=win_length, + hop_length=hop_length, + n_mels=n_mel_channels, + power=1, + center=True, + normalized=False, + norm=None, + ).to(waveform.device) + if len(waveform.shape) == 3: + waveform = waveform.squeeze(1) # 'b 1 nw -> b nw' + + assert len(waveform.shape) == 2 + + mel = mel_stft(waveform) + mel = mel.clamp(min=1e-5).log() + return mel + +class MelSpec(nn.Module): + def __init__( + self, + n_fft=1024, + hop_length=256, + win_length=1024, + n_mel_channels=100, + target_sample_rate=24_000, + mel_spec_type="vocos", + ): + super().__init__() + assert mel_spec_type in ["vocos", "bigvgan"], print("We only support two extract mel backend: vocos or bigvgan") + + self.n_fft = n_fft + self.hop_length = hop_length + self.win_length = win_length + self.n_mel_channels = n_mel_channels + self.target_sample_rate = target_sample_rate + + if mel_spec_type == "vocos": + self.extractor = get_vocos_mel_spectrogram + elif mel_spec_type == "bigvgan": + self.extractor = get_bigvgan_mel_spectrogram + + self.register_buffer("dummy", torch.tensor(0), persistent=False) + + def forward(self, wav): + if self.dummy.device != wav.device: + self.to(wav.device) + + mel = self.extractor( + waveform=wav, + n_fft=self.n_fft, + n_mel_channels=self.n_mel_channels, + target_sample_rate=self.target_sample_rate, + hop_length=self.hop_length, + win_length=self.win_length, + ) + + return mel + + +class CFM(nn.Module): + def __init__( + self, + transformer: nn.Module, + sigma=0.0, + odeint_kwargs: dict = dict( + # atol = 1e-5, + # rtol = 1e-5, + method="euler" # 'midpoint' + ), + audio_drop_prob=0.3, + cond_drop_prob=0.2, + num_channels=None, + mel_spec_module: nn.Module | None = None, + mel_spec_kwargs: dict = dict(), + frac_lengths_mask: tuple[float, float] = (0.7, 1.0), + vocab_char_map: dict[str:int] | None = None, + ): + super().__init__() + + self.frac_lengths_mask = frac_lengths_mask + + # mel spec + self.mel_spec = default(mel_spec_module, MelSpec(**mel_spec_kwargs)) + num_channels = default(num_channels, self.mel_spec.n_mel_channels) + self.num_channels = num_channels + + # classifier-free guidance + self.audio_drop_prob = audio_drop_prob + self.cond_drop_prob = cond_drop_prob + + # transformer + self.transformer = transformer + dim = transformer.dim + self.dim = dim + + # conditional flow related + self.sigma = sigma + + # sampling related + self.odeint_kwargs = odeint_kwargs + + # vocab map for tokenization + self.vocab_char_map = vocab_char_map + + @property + def device(self): + return next(self.parameters()).device + + @torch.no_grad() + def sample( + self, + cond: 
float["b n d"] | float["b nw"], # noqa: F722 + text: int["b nt"] | list[str], # noqa: F722 + duration: int | int["b"], # noqa: F821 + *, + lens: int["b"] | None = None, # noqa: F821 + steps=32, + cfg_strength=1.0, + sway_sampling_coef=None, + seed: int | None = None, + max_duration=4096, + vocoder: Callable[[float["b d n"]], float["b nw"]] | None = None, # noqa: F722 + use_epss=True, + no_ref_audio=False, + duplicate_test=False, + t_inter=0.1, + edit_mask=None, + ): + self.eval() + # raw wave + + if cond.ndim == 2: + cond = self.mel_spec(cond) + cond = cond.permute(0, 2, 1) + assert cond.shape[-1] == self.num_channels + + cond = cond.to(next(self.parameters()).dtype) + + batch, cond_seq_len, device = *cond.shape[:2], cond.device + if not exists(lens): + lens = torch.full((batch,), cond_seq_len, device=device, dtype=torch.long) + + # text + + if isinstance(text, list): + if exists(self.vocab_char_map): + text = list_str_to_idx(text, self.vocab_char_map).to(device) + else: + text = list_str_to_tensor(text).to(device) + assert text.shape[0] == batch + + # duration + + cond_mask = lens_to_mask(lens) + if edit_mask is not None: + cond_mask = cond_mask & edit_mask + + if isinstance(duration, int): + duration = torch.full((batch,), duration, device=device, dtype=torch.long) + + duration = torch.maximum( + torch.maximum((text != -1).sum(dim=-1), lens) + 1, duration + ) # duration at least text/audio prompt length plus one token, so something is generated + duration = duration.clamp(max=max_duration) + max_duration = duration.amax() + + # duplicate test corner for inner time step oberservation + if duplicate_test: + test_cond = F.pad(cond, (0, 0, cond_seq_len, max_duration - 2 * cond_seq_len), value=0.0) + + cond = F.pad(cond, (0, 0, 0, max_duration - cond_seq_len), value=0.0) + if no_ref_audio: + cond = torch.zeros_like(cond) + + cond_mask = F.pad(cond_mask, (0, max_duration - cond_mask.shape[-1]), value=False) + cond_mask = cond_mask.unsqueeze(-1) + step_cond = torch.where( + cond_mask, cond, torch.zeros_like(cond) + ) # allow direct control (cut cond audio) with lens passed in + + if batch > 1: + mask = lens_to_mask(duration) + else: # save memory and speed up, as single inference need no mask currently + mask = None + + # neural ode + + def fn(t, x): + # at each step, conditioning is fixed + # step_cond = torch.where(cond_mask, cond, torch.zeros_like(cond)) + + # predict flow (cond) + if cfg_strength < 1e-5: + pred = self.transformer( + x=x, + cond=step_cond, + text=text, + time=t, + mask=mask, + drop_audio_cond=False, + drop_text=False, + cache=True, + ) + return pred + + # predict flow (cond and uncond), for classifier-free guidance + pred_cfg = self.transformer( + x=x, + cond=step_cond, + text=text, + time=t, + mask=mask, + cfg_infer=True, + cache=True, + ) + pred, null_pred = torch.chunk(pred_cfg, 2, dim=0) + return pred + (pred - null_pred) * cfg_strength + + # noise input + # to make sure batch inference result is same with different batch size, and for sure single inference + # still some difference maybe due to convolutional layers + y0 = [] + for dur in duration: + if exists(seed): + torch.manual_seed(seed) + y0.append(torch.randn(dur, self.num_channels, device=self.device, dtype=step_cond.dtype)) + y0 = pad_sequence(y0, padding_value=0, batch_first=True) + + t_start = 0 + + # duplicate test corner for inner time step oberservation + if duplicate_test: + t_start = t_inter + y0 = (1 - t_start) * y0 + t_start * test_cond + steps = int(steps * (1 - t_start)) + + if t_start == 0 and 
use_epss: # use Empirically Pruned Step Sampling for low NFE + t = get_epss_timesteps(steps, device=self.device, dtype=step_cond.dtype) + else: + t = torch.linspace(t_start, 1, steps + 1, device=self.device, dtype=step_cond.dtype) + if sway_sampling_coef is not None: + t = t + sway_sampling_coef * (torch.cos(torch.pi / 2 * t) - 1 + t) + + trajectory = odeint(fn, y0, t, **self.odeint_kwargs) + self.transformer.clear_cache() + + sampled = trajectory[-1] + out = sampled + out = torch.where(cond_mask, cond, out) + + if exists(vocoder): + out = out.permute(0, 2, 1) + out = vocoder(out) + + return out, trajectory + + def forward( + self, + inp: float["b n d"] | float["b nw"], # mel or raw wave # noqa: F722 + text: int["b nt"] | list[str], # noqa: F722 + *, + lens: int["b"] | None = None, # noqa: F821 + noise_scheduler: str | None = None, + ): + # handle raw wave + if inp.ndim == 2: + inp = self.mel_spec(inp) + inp = inp.permute(0, 2, 1) + assert inp.shape[-1] == self.num_channels + + batch, seq_len, dtype, device, _σ1 = *inp.shape[:2], inp.dtype, self.device, self.sigma + + # handle text as string + if isinstance(text, list): + if exists(self.vocab_char_map): + text = list_str_to_idx(text, self.vocab_char_map).to(device) + else: + text = list_str_to_tensor(text).to(device) + assert text.shape[0] == batch + + # lens and mask + if not exists(lens): + lens = torch.full((batch,), seq_len, device=device) + + mask = lens_to_mask(lens, length=seq_len) # useless here, as collate_fn will pad to max length in batch + + # get a random span to mask out for training conditionally + frac_lengths = torch.zeros((batch,), device=self.device).float().uniform_(*self.frac_lengths_mask) + rand_span_mask = mask_from_frac_lengths(lens, frac_lengths) + + if exists(mask): + rand_span_mask &= mask + + # mel is x1 + x1 = inp + + # x0 is gaussian noise + x0 = torch.randn_like(x1) + + # time step + time = torch.rand((batch,), dtype=dtype, device=self.device) + # TODO. 
noise_scheduler + + # sample xt (φ_t(x) in the paper) + t = time.unsqueeze(-1).unsqueeze(-1) + φ = (1 - t) * x0 + t * x1 + flow = x1 - x0 + + # only predict what is within the random mask span for infilling + cond = torch.where(rand_span_mask[..., None], torch.zeros_like(x1), x1) + + # transformer and cfg training with a drop rate + drop_audio_cond = random() < self.audio_drop_prob # p_drop in voicebox paper + if random() < self.cond_drop_prob: # p_uncond in voicebox paper + drop_audio_cond = True + drop_text = True + else: + drop_text = False + + # apply mask will use more memory; might adjust batchsize or batchsampler long sequence threshold + pred = self.transformer( + x=φ, cond=cond, text=text, time=time, drop_audio_cond=drop_audio_cond, drop_text=drop_text, mask=mask + ) + + # flow matching loss + loss = F.mse_loss(pred, flow, reduction="none") + loss = loss[rand_span_mask] + + return loss.mean(), cond, pred + + +def infer_process( + ref_audio, + ref_text, + gen_text, + model_obj, + vocoder, + mel_spec_type=mel_spec_type, + show_info=print, + progress=tqdm, + target_rms=target_rms, + cross_fade_duration=cross_fade_duration, + nfe_step=nfe_step, + cfg_strength=cfg_strength, + sway_sampling_coef=sway_sampling_coef, + speed=speed, + fix_duration=fix_duration, + device=device, +): + # Split the input text into batches + audio, sr = torchaudio.load(ref_audio) + max_chars = int(len(ref_text.encode("utf-8")) / (audio.shape[-1] / sr) * (22 - audio.shape[-1] / sr) * speed) + gen_text_batches = chunk_text(gen_text, max_chars=max_chars) + for i, gen_text in enumerate(gen_text_batches): + print(f"gen_text {i}", gen_text) + print("\n") + + show_info(f"Generating audio in {len(gen_text_batches)} batches...") + return next( + infer_batch_process( + (audio, sr), + ref_text, + gen_text_batches, + model_obj, + vocoder, + mel_spec_type=mel_spec_type, + progress=progress, + target_rms=target_rms, + cross_fade_duration=cross_fade_duration, + nfe_step=nfe_step, + cfg_strength=cfg_strength, + sway_sampling_coef=sway_sampling_coef, + speed=speed, + fix_duration=fix_duration, + device=device, + ) + ) + + +# infer batches + + +def infer_batch_process( + ref_audio, + ref_text, + gen_text_batches, + model_obj, + vocoder, + mel_spec_type="vocos", + progress=tqdm, + target_rms=0.1, + cross_fade_duration=0.15, + nfe_step=32, + cfg_strength=2.0, + sway_sampling_coef=-1, + speed=1, + fix_duration=None, + device=None, + streaming=False, + chunk_size=2048, +): + audio, sr = ref_audio + if audio.shape[0] > 1: + audio = torch.mean(audio, dim=0, keepdim=True) + + rms = torch.sqrt(torch.mean(torch.square(audio))) + if rms < target_rms: + audio = audio * target_rms / rms + if sr != target_sample_rate: + resampler = torchaudio.transforms.Resample(sr, target_sample_rate) + audio = resampler(audio) + audio = audio.to(device) + + generated_waves = [] + spectrograms = [] + + if len(ref_text[-1].encode("utf-8")) == 1: + ref_text = ref_text + " " + + def process_batch(gen_text): + local_speed = speed + if len(gen_text.encode("utf-8")) < 10: + local_speed = 0.3 + + # Prepare the text + text_list = [ref_text + gen_text] + final_text_list = convert_char_to_pinyin(text_list) + + ref_audio_len = audio.shape[-1] // hop_length + if fix_duration is not None: + duration = int(fix_duration * target_sample_rate / hop_length) + else: + # Calculate duration + ref_text_len = len(ref_text.encode("utf-8")) + gen_text_len = len(gen_text.encode("utf-8")) + duration = ref_audio_len + int(ref_audio_len / ref_text_len * gen_text_len / local_speed) + + 
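+        # duration heuristic: keep the reference frames and add frames proportional to the byte-length ratio of gen_text to ref_text, divided by local_speed
+        # e.g. ref_audio_len=250 frames, ref_text=40 bytes, gen_text=80 bytes, local_speed=1.0 -> duration = 250 + 500 = 750 frames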
# inference + with torch.inference_mode(): + generated, _ = model_obj.sample( + cond=audio, + text=final_text_list, + duration=duration, + steps=nfe_step, + cfg_strength=cfg_strength, + sway_sampling_coef=sway_sampling_coef, + ) + del _ + + generated = generated.to(torch.float32) # generated mel spectrogram + generated = generated[:, ref_audio_len:, :] + generated = generated.permute(0, 2, 1) + if mel_spec_type == "vocos": + generated_wave = vocoder.decode(generated) + elif mel_spec_type == "bigvgan": + generated_wave = vocoder(generated) + if rms < target_rms: + generated_wave = generated_wave * rms / target_rms + + # wav -> numpy + generated_wave = generated_wave.squeeze().cpu().numpy() + + if streaming: + for j in range(0, len(generated_wave), chunk_size): + yield generated_wave[j : j + chunk_size], target_sample_rate + else: + generated_cpu = generated[0].cpu().numpy() + del generated + yield generated_wave, generated_cpu + + if streaming: + for gen_text in progress.tqdm(gen_text_batches) if progress is not None else gen_text_batches: + for chunk in process_batch(gen_text): + yield chunk + else: + with ThreadPoolExecutor() as executor: + futures = [executor.submit(process_batch, gen_text) for gen_text in gen_text_batches] + for future in progress.tqdm(futures) if progress is not None else futures: + result = future.result() + if result: + generated_wave, generated_mel_spec = next(result) + generated_waves.append(generated_wave) + spectrograms.append(generated_mel_spec) + + if generated_waves: + if cross_fade_duration <= 0: + # Simply concatenate + final_wave = np.concatenate(generated_waves) + else: + # Combine all generated waves with cross-fading + final_wave = generated_waves[0] + for i in range(1, len(generated_waves)): + prev_wave = final_wave + next_wave = generated_waves[i] + + # Calculate cross-fade samples, ensuring it does not exceed wave lengths + cross_fade_samples = int(cross_fade_duration * target_sample_rate) + cross_fade_samples = min(cross_fade_samples, len(prev_wave), len(next_wave)) + + if cross_fade_samples <= 0: + # No overlap possible, concatenate + final_wave = np.concatenate([prev_wave, next_wave]) + continue + + # Overlapping parts + prev_overlap = prev_wave[-cross_fade_samples:] + next_overlap = next_wave[:cross_fade_samples] + + # Fade out and fade in + fade_out = np.linspace(1, 0, cross_fade_samples) + fade_in = np.linspace(0, 1, cross_fade_samples) + + # Cross-faded overlap + cross_faded_overlap = prev_overlap * fade_out + next_overlap * fade_in + + # Combine + new_wave = np.concatenate( + [prev_wave[:-cross_fade_samples], cross_faded_overlap, next_wave[cross_fade_samples:]] + ) + + final_wave = new_wave + + # Create a combined spectrogram + combined_spectrogram = np.concatenate(spectrograms, axis=1) + + yield final_wave, target_sample_rate, combined_spectrogram + + else: + yield None, target_sample_rate, None + + + +# load vocoder +def load_vocoder(vocoder_name="vocos", is_local=False, local_path="", device=device, hf_cache_dir=None): + if vocoder_name == "vocos": + # vocoder = Vocos.from_pretrained("charactr/vocos-mel-24khz").to(device) + if is_local: + print(f"Load vocos from local path {local_path}") + config_path = f"{local_path}/config.yaml" + model_path = f"{local_path}/pytorch_model.bin" + else: + print("Download Vocos from huggingface charactr/vocos-mel-24khz") + repo_id = "charactr/vocos-mel-24khz" + config_path = hf_hub_download(repo_id=repo_id, cache_dir=hf_cache_dir, filename="config.yaml") + model_path = hf_hub_download(repo_id=repo_id, 
cache_dir=hf_cache_dir, filename="pytorch_model.bin") + vocoder = Vocos.from_hparams(config_path) + state_dict = torch.load(model_path, map_location="cpu", weights_only=True) + from vocos.feature_extractors import EncodecFeatures + + if isinstance(vocoder.feature_extractor, EncodecFeatures): + encodec_parameters = { + "feature_extractor.encodec." + key: value + for key, value in vocoder.feature_extractor.encodec.state_dict().items() + } + state_dict.update(encodec_parameters) + vocoder.load_state_dict(state_dict) + vocoder = vocoder.eval().to(device) + elif vocoder_name == "bigvgan": + try: + from third_party.BigVGAN import bigvgan + except ImportError: + print("You need to follow the README to init submodule and change the BigVGAN source code.") + if is_local: + # download generator from https://huggingface.co/nvidia/bigvgan_v2_24khz_100band_256x/tree/main + vocoder = bigvgan.BigVGAN.from_pretrained(local_path, use_cuda_kernel=False) + else: + vocoder = bigvgan.BigVGAN.from_pretrained( + "nvidia/bigvgan_v2_24khz_100band_256x", use_cuda_kernel=False, cache_dir=hf_cache_dir + ) + + vocoder.remove_weight_norm() + vocoder = vocoder.eval().to(device) + return vocoder From c7ee59404fae30886c30ff5d1d8c321e48375ca0 Mon Sep 17 00:00:00 2001 From: ayushtues Date: Mon, 21 Jul 2025 20:08:53 +0530 Subject: [PATCH 4/8] Use diffusers attention --- src/diffusers/models/attention_processor.py | 126 ++++++++++ .../models/transformers/f5tts_transformer.py | 231 +----------------- 2 files changed, 135 insertions(+), 222 deletions(-) diff --git a/src/diffusers/models/attention_processor.py b/src/diffusers/models/attention_processor.py index 990245de1742..f4746918cf17 100755 --- a/src/diffusers/models/attention_processor.py +++ b/src/diffusers/models/attention_processor.py @@ -3117,6 +3117,132 @@ def __call__( hidden_states = hidden_states / attn.rescale_output_factor return hidden_states + + +class F5TTSAttnProcessor2_0: + r""" + Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0). This is + used in the F5 TTS model. It applies rotary embedding to selected heads, and allows MHA, GQA or MQA. + """ + + def __init__(self, pe_attn_head: Optional[int] = None): + if not hasattr(F, "scaled_dot_product_attention"): + raise ImportError( + "F5TTSAttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0." 
+ ) + self.pe_attn_head = pe_attn_head + + + def __call__( + self, + attn: Attention, + hidden_states: torch.Tensor, + encoder_hidden_states: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + rotary_emb: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + from .embeddings import apply_rotary_emb + + residual = hidden_states + + input_ndim = hidden_states.ndim + + if input_ndim == 4: + batch_size, channel, height, width = hidden_states.shape + hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2) + + batch_size, sequence_length, _ = ( + hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape + ) + + if attention_mask is not None: + attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size) + # scaled_dot_product_attention expects attention_mask shape to be + # (batch, heads, source_length, target_length) + attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1]) + + query = attn.to_q(hidden_states) + + if encoder_hidden_states is None: + encoder_hidden_states = hidden_states + elif attn.norm_cross: + encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states) + + key = attn.to_k(encoder_hidden_states) + value = attn.to_v(encoder_hidden_states) + + head_dim = query.shape[-1] // attn.heads + kv_heads = key.shape[-1] // head_dim + + query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + + key = key.view(batch_size, -1, kv_heads, head_dim).transpose(1, 2) + value = value.view(batch_size, -1, kv_heads, head_dim).transpose(1, 2) + + if kv_heads != attn.heads: + # if GQA or MQA, repeat the key/value heads to reach the number of query heads. + heads_per_kv_head = attn.heads // kv_heads + key = torch.repeat_interleave(key, heads_per_kv_head, dim=1, output_size=key.shape[1] * heads_per_kv_head) + value = torch.repeat_interleave( + value, heads_per_kv_head, dim=1, output_size=value.shape[1] * heads_per_kv_head + ) + + if attn.norm_q is not None: + query = attn.norm_q(query) + if attn.norm_k is not None: + key = attn.norm_k(key) + + # Apply RoPE if needed + if rotary_emb is not None: + query_dtype = query.dtype + key_dtype = key.dtype + query = query.to(torch.float32) + key = key.to(torch.float32) + + + if self.pe_attn_head is not None: + query_rotated = apply_rotary_emb( + query[..., :self.pe_attn_head, :], rotary_emb, use_real=True, use_real_unbind_dim=-2 + ) + + query_unrotated = query[..., self.pe_attn_head:, :] + query = torch.cat((query_rotated, query_unrotated), dim=-2) + + if not attn.is_cross_attention: + key_to_rotate, key_unrotated = key[..., :self.pe_attn_head, :], key[..., self.pe_attn_head:, :] + key_rotated = apply_rotary_emb(key_to_rotate, rotary_emb, use_real=True, use_real_unbind_dim=-2) + key = torch.cat((key_rotated, key_unrotated), dim=-2) + + query = query.to(query_dtype) + key = key.to(key_dtype) + + # the output of sdp = (batch, num_heads, seq_len, head_dim) + # TODO: add support for attn.scale when we move to Torch 2.1 + hidden_states = F.scaled_dot_product_attention( + query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False + ) + + hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim) + hidden_states = hidden_states.to(query.dtype) + + # linear proj + hidden_states = attn.to_out[0](hidden_states) + # dropout + hidden_states = attn.to_out[1](hidden_states) + + if input_ndim == 4: + hidden_states = 
hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width) + + + # these two need to be no-ops + if attn.residual_connection: + hidden_states = hidden_states + residual + + hidden_states = hidden_states / attn.rescale_output_factor + + return hidden_states + + class HunyuanAttnProcessor2_0: diff --git a/src/diffusers/models/transformers/f5tts_transformer.py b/src/diffusers/models/transformers/f5tts_transformer.py index f792bf6ffbb7..2e197b9568bc 100644 --- a/src/diffusers/models/transformers/f5tts_transformer.py +++ b/src/diffusers/models/transformers/f5tts_transformer.py @@ -12,36 +12,17 @@ import torch import torch.nn.functional as F from torch import nn -from ..normalization import GlobalResponseNorm, AdaLayerNorm +from ..normalization import GlobalResponseNorm, AdaLayerNorm, RMSNorm import math from ..embeddings import get_1d_rotary_pos_embed, apply_rotary_emb from typing import Optional, Union - +from ..attention_processor import F5TTSAttnProcessor2_0 +from ..attention import Attention from einops import rearrange, repeat, reduce, pack, unpack from torch import nn, einsum, tensor, Tensor, cat, stack, arange, is_tensor -def rotate_half(x): - x = rearrange(x, '... (d r) -> ... d r', r = 2) - x1, x2 = x.unbind(dim = -1) - x = stack((-x2, x1), dim = -1) - return rearrange(x, '... d r -> ... (d r)') - -def apply_rotary_pos_emb(t, freqs, scale = 1): - rot_dim, seq_len, orig_dtype = freqs.shape[-1], t.shape[-2], t.dtype - - freqs = freqs[:, -seq_len:, :] - scale = scale[:, -seq_len:, :] if isinstance(scale, torch.Tensor) else scale - if t.ndim == 4 and freqs.ndim == 3: - freqs = rearrange(freqs, 'b n d -> b 1 n d') - - # partial rotary embeddings, Wang et al. GPT-J - t, t_unrotated = t[..., :rot_dim], t[..., rot_dim:] - t = (t * freqs.cos() * scale) + (rotate_half(t) * freqs.sin() * scale) - out = cat((t, t_unrotated), dim = -1) - - return out.type(orig_dtype) class AdaLayerNorm2(nn.Module): def __init__(self, dim): @@ -62,34 +43,15 @@ def forward(self, x, emb=None): -class RMSNorm(nn.Module): - def __init__(self, dim: int, eps: float): - super().__init__() - self.eps = eps - self.weight = nn.Parameter(torch.ones(dim)) - self.native_rms_norm = float(torch.__version__[:3]) >= 2.4 - - def forward(self, x): - if self.native_rms_norm: - if self.weight.dtype in [torch.float16, torch.bfloat16]: - x = x.to(self.weight.dtype) - x = F.rms_norm(x, normalized_shape=(x.shape[-1],), weight=self.weight, eps=self.eps) - else: - variance = x.to(torch.float32).pow(2).mean(-1, keepdim=True) - x = x * torch.rsqrt(variance + self.eps) - if self.weight.dtype in [torch.float16, torch.bfloat16]: - x = x.to(self.weight.dtype) - x = x * self.weight - - return x -class FeedForward(nn.Module): +class F5FeedForward(nn.Module): def __init__(self, dim, dim_out=None, mult=4, dropout=0.0, approximate: str = "none"): super().__init__() inner_dim = int(dim * mult) dim_out = dim_out if dim_out is not None else dim - + + # a bit different ordering from the diffusers FeedForward class, here the inner projection weight comes first, in diffusers the activation comes first activation = nn.GELU(approximate=approximate) project_in = nn.Sequential(nn.Linear(dim, inner_dim), activation) self.ff = nn.Sequential(project_in, nn.Dropout(dropout), nn.Linear(inner_dim, dim_out)) @@ -102,76 +64,6 @@ def forward(self, x): # modified from diffusers/src/diffusers/models/attention_processor.py -class Attention(nn.Module): - def __init__( - self, - processor: AttnProcessor, - dim: int, - heads: int = 8, - dim_head: int = 64, - 
dropout: float = 0.0, - context_dim: Optional[int] = None, # if not None -> joint attention - context_pre_only: bool = False, - qk_norm: Optional[str] = None, - ): - super().__init__() - - if not hasattr(F, "scaled_dot_product_attention"): - raise ImportError("Attention equires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.") - - self.processor = processor - - self.dim = dim - self.heads = heads - self.inner_dim = dim_head * heads - self.dropout = dropout - - self.context_dim = context_dim - self.context_pre_only = context_pre_only - - self.to_q = nn.Linear(dim, self.inner_dim) - self.to_k = nn.Linear(dim, self.inner_dim) - self.to_v = nn.Linear(dim, self.inner_dim) - - if qk_norm is None: - self.q_norm = None - self.k_norm = None - elif qk_norm == "rms_norm": - self.q_norm = RMSNorm(dim_head, eps=1e-6) - self.k_norm = RMSNorm(dim_head, eps=1e-6) - else: - raise ValueError(f"Unimplemented qk_norm: {qk_norm}") - - if self.context_dim is not None: - self.to_q_c = nn.Linear(context_dim, self.inner_dim) - self.to_k_c = nn.Linear(context_dim, self.inner_dim) - self.to_v_c = nn.Linear(context_dim, self.inner_dim) - if qk_norm is None: - self.c_q_norm = None - self.c_k_norm = None - elif qk_norm == "rms_norm": - self.c_q_norm = RMSNorm(dim_head, eps=1e-6) - self.c_k_norm = RMSNorm(dim_head, eps=1e-6) - - self.to_out = nn.ModuleList([]) - self.to_out.append(nn.Linear(self.inner_dim, dim)) - self.to_out.append(nn.Dropout(dropout)) - - if self.context_dim is not None and not self.context_pre_only: - self.to_out_c = nn.Linear(self.inner_dim, context_dim) - - def forward( - self, - x: float["b n d"], # noised input x - c: float["b n d"] = None, # context c - mask: bool["b n"] | None = None, - rope=None, # rotary position embedding for x - c_rope=None, # rotary position embedding for c - ) -> torch.Tensor: - if c is not None: - return self.processor(self, x, c=c, mask=mask, rope=rope, c_rope=c_rope) - else: - return self.processor(self, x, mask=mask, rope=rope) @@ -184,110 +76,7 @@ def is_package_available(package_name: str) -> bool: except Exception: return False -# Attention processor - - - -class AttnProcessor: - def __init__( - self, - pe_attn_head: int | None = None, # number of attention head to apply rope, None for all - attn_backend: str = "torch", # "torch" or "flash_attn" - attn_mask_enabled: bool = True, - ): - if attn_backend == "flash_attn": - assert is_package_available("flash_attn"), "Please install flash-attn first." 
- - self.pe_attn_head = pe_attn_head - self.attn_backend = attn_backend - self.attn_mask_enabled = attn_mask_enabled - def __call__( - self, - attn: Attention, - x: float["b n d"], # noised input x - mask: bool["b n"] | None = None, - rope=None, # rotary position embedding - ) -> torch.FloatTensor: - batch_size = x.shape[0] - - # `sample` projections - query = attn.to_q(x) - key = attn.to_k(x) - value = attn.to_v(x) - - # attention - inner_dim = key.shape[-1] - head_dim = inner_dim // attn.heads - query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) - key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) - value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) - - # qk norm - if attn.q_norm is not None: - query = attn.q_norm(query) - if attn.k_norm is not None: - key = attn.k_norm(key) - - # apply rotary position embedding - if rope is not None: - freqs, xpos_scale = rope - q_xpos_scale, k_xpos_scale = (xpos_scale, xpos_scale**-1.0) if xpos_scale is not None else (1.0, 1.0) - - if self.pe_attn_head is not None: - pn = self.pe_attn_head - query[:, :pn, :, :] = apply_rotary_pos_emb(query[:, :pn, :, :], freqs, q_xpos_scale) - key[:, :pn, :, :] = apply_rotary_pos_emb(key[:, :pn, :, :], freqs, k_xpos_scale) - else: - query = apply_rotary_pos_emb(query, freqs, q_xpos_scale) - key = apply_rotary_pos_emb(key, freqs, k_xpos_scale) - - if self.attn_backend == "torch": - # mask. e.g. inference got a batch with different target durations, mask out the padding - if self.attn_mask_enabled and mask is not None: - attn_mask = mask - attn_mask = attn_mask.unsqueeze(1).unsqueeze(1) # 'b n -> b 1 1 n' - attn_mask = attn_mask.expand(batch_size, attn.heads, query.shape[-2], key.shape[-2]) - else: - attn_mask = None - x = F.scaled_dot_product_attention(query, key, value, attn_mask=attn_mask, dropout_p=0.0, is_causal=False) - x = x.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim) - - # elif self.attn_backend == "flash_attn": - # query = query.transpose(1, 2) # [b, h, n, d] -> [b, n, h, d] - # key = key.transpose(1, 2) - # value = value.transpose(1, 2) - # if self.attn_mask_enabled and mask is not None: - # query, indices, q_cu_seqlens, q_max_seqlen_in_batch, _ = unpad_input(query, mask) - # key, _, k_cu_seqlens, k_max_seqlen_in_batch, _ = unpad_input(key, mask) - # value, _, _, _, _ = unpad_input(value, mask) - # x = flash_attn_varlen_func( - # query, - # key, - # value, - # q_cu_seqlens, - # k_cu_seqlens, - # q_max_seqlen_in_batch, - # k_max_seqlen_in_batch, - # ) - # x = pad_input(x, indices, batch_size, q_max_seqlen_in_batch) - # x = x.reshape(batch_size, -1, attn.heads * head_dim) - # else: - # x = flash_attn_func(query, key, value, dropout_p=0.0, causal=False) - # x = x.reshape(batch_size, -1, attn.heads * head_dim) - - x = x.to(query.dtype) - - # linear proj - x = attn.to_out[0](x) - # dropout - x = attn.to_out[1](x) - - if mask is not None: - mask = mask.unsqueeze(-1) - x = x.masked_fill(~mask, 0.0) - - return x class DiTBlock(nn.Module): @@ -307,12 +96,10 @@ def __init__( self.attn_norm = AdaLayerNorm2(dim) self.attn = Attention( - processor=AttnProcessor( + processor=F5TTSAttnProcessor2_0( pe_attn_head=pe_attn_head, - attn_backend=attn_backend, - attn_mask_enabled=attn_mask_enabled, ), - dim=dim, + query_dim=dim, heads=heads, dim_head=dim_head, dropout=dropout, @@ -320,7 +107,7 @@ def __init__( ) self.ff_norm = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6) - self.ff = FeedForward(dim=dim, mult=ff_mult, dropout=dropout, 
approximate="tanh") + self.ff = F5FeedForward(dim=dim, mult=ff_mult, dropout=dropout, approximate="tanh") def forward(self, x, t, mask=None, rope=None): # x: noised input, t: time embedding # pre-norm & modulation for attention input From 380963d3a47c86288907007502ec74485a725f42 Mon Sep 17 00:00:00 2001 From: ayushtues Date: Sat, 26 Jul 2025 14:54:49 +0530 Subject: [PATCH 5/8] Remove conditioning encoding from DiT --- .../models/transformers/f5tts_transformer.py | 643 ++++++++++-- .../pipelines/f5tts/pipeline_f5tts.py | 966 +++--------------- 2 files changed, 725 insertions(+), 884 deletions(-) diff --git a/src/diffusers/models/transformers/f5tts_transformer.py b/src/diffusers/models/transformers/f5tts_transformer.py index 2e197b9568bc..9671cdc9e564 100644 --- a/src/diffusers/models/transformers/f5tts_transformer.py +++ b/src/diffusers/models/transformers/f5tts_transformer.py @@ -20,7 +20,8 @@ from ..attention import Attention from einops import rearrange, repeat, reduce, pack, unpack from torch import nn, einsum, tensor, Tensor, cat, stack, arange, is_tensor - +import jieba +from pypinyin import Style, lazy_pinyin @@ -313,6 +314,49 @@ def forward(self, x: float["b n d"], cond: float["b n d"], text_embed: float["b # Transformer backbone using DiT blocks +# Adding this to decouple the conditoning encoding from the DiT backbone +class ConditioningEncoder(nn.Module): + def __init__( + self, + dim + text_num_embeds, + text_dim, + text_mask_padding, + conv_layers, + mel_dim, + ): + super().__init__() + self.text_embed = TextEmbedding( + text_num_embeds, text_dim, mask_padding=text_mask_padding, conv_layers=conv_layers + ) + self.input_embed = InputEmbedding(mel_dim, text_dim, dim) + + def forward( + self, + x, # b n d + cond, # b n d + text, # b nt + drop_audio_cond: bool = False, + drop_text: bool = False, + cache: bool = True, + ): + seq_len = x.shape[1] + if cache: + if drop_text: + if self.text_uncond is None: + self.text_uncond = self.text_embed(text, seq_len, drop_text=True) + text_embed = self.text_uncond + else: + if self.text_cond is None: + self.text_cond = self.text_embed(text, seq_len, drop_text=False) + text_embed = self.text_cond + else: + text_embed = self.text_embed(text, seq_len, drop_text=drop_text) + + x = self.input_embed(x, cond, text_embed, drop_audio_cond=drop_audio_cond) + + return x + class DiT(nn.Module): def __init__( @@ -341,11 +385,6 @@ def __init__( self.time_embed = TimestepEmbedding(dim) if text_dim is None: text_dim = mel_dim - self.text_embed = TextEmbedding( - text_num_embeds, text_dim, mask_padding=text_mask_padding, conv_layers=conv_layers - ) - self.text_cond, self.text_uncond = None, None # text cache - self.input_embed = InputEmbedding(mel_dim, text_dim, dim) self.dim = dim self.depth = depth @@ -370,69 +409,12 @@ def __init__( self.norm_out = AdaLayerNorm(dim) # final modulation self.proj_out = nn.Linear(dim, mel_dim) - self.checkpoint_activations = checkpoint_activations - - self.initialize_weights() - - def initialize_weights(self): - # Zero-out AdaLN layers in DiT blocks: - for block in self.transformer_blocks: - nn.init.constant_(block.attn_norm.linear.weight, 0) - nn.init.constant_(block.attn_norm.linear.bias, 0) - - # Zero-out output layers: - nn.init.constant_(self.norm_out.linear.weight, 0) - nn.init.constant_(self.norm_out.linear.bias, 0) - nn.init.constant_(self.proj_out.weight, 0) - nn.init.constant_(self.proj_out.bias, 0) - - def ckpt_wrapper(self, module): - # https://github.com/chuanyangjin/fast-DiT/blob/main/models.py - def 
ckpt_forward(*inputs): - outputs = module(*inputs) - return outputs - - return ckpt_forward - - def get_input_embed( - self, - x, # b n d - cond, # b n d - text, # b nt - drop_audio_cond: bool = False, - drop_text: bool = False, - cache: bool = True, - ): - seq_len = x.shape[1] - if cache: - if drop_text: - if self.text_uncond is None: - self.text_uncond = self.text_embed(text, seq_len, drop_text=True) - text_embed = self.text_uncond - else: - if self.text_cond is None: - self.text_cond = self.text_embed(text, seq_len, drop_text=False) - text_embed = self.text_cond - else: - text_embed = self.text_embed(text, seq_len, drop_text=drop_text) - - x = self.input_embed(x, cond, text_embed, drop_audio_cond=drop_audio_cond) - - return x - - def clear_cache(self): - self.text_cond, self.text_uncond = None, None def forward( self, x: float["b n d"], # nosied input audio # noqa: F722 - cond: float["b n d"], # masked cond audio # noqa: F722 - text: int["b nt"], # text # noqa: F722 time: float["b"] | float[""], # time step # noqa: F821 F722 mask: bool["b n"] | None = None, # noqa: F722 - drop_audio_cond: bool = False, # cfg for cond audio - drop_text: bool = False, # cfg for text - cfg_infer: bool = False, # cfg inference, pack cond & uncond forward cache: bool = False, ): batch, seq_len = x.shape[0], x.shape[1] @@ -441,25 +423,534 @@ def forward( # t: conditioning time, text: text, x: noised audio + cond audio + text t = self.time_embed(time) - if cfg_infer: # pack cond & uncond forward: b n d -> 2b n d - x_cond = self.get_input_embed(x, cond, text, drop_audio_cond=False, drop_text=False, cache=cache) - x_uncond = self.get_input_embed(x, cond, text, drop_audio_cond=True, drop_text=True, cache=cache) - x = torch.cat((x_cond, x_uncond), dim=0) - t = torch.cat((t, t), dim=0) - mask = torch.cat((mask, mask), dim=0) if mask is not None else None - else: - x = self.get_input_embed(x, cond, text, drop_audio_cond=drop_audio_cond, drop_text=drop_text, cache=cache) + + # if cfg_infer: # pack cond & uncond forward: b n d -> 2b n d + # x_cond = self. + # (x, cond, text, drop_audio_cond=False, drop_text=False, cache=cache) + # x_uncond = self. + # (x, cond, text, drop_audio_cond=True, drop_text=True, cache=cache) + # x = torch.cat((x_cond, x_uncond), dim=0) + # t = torch.cat((t, t), dim=0) + # mask = torch.cat((mask, mask), dim=0) if mask is not None else None + # else: + # x = self. 
+ # (x, cond, text, drop_audio_cond=drop_audio_cond, drop_text=drop_text, cache=cache) rope = get_1d_rotary_pos_embed(seq_len, self.dim, device=x.device) - for block in self.transformer_blocks: - if self.checkpoint_activations: - # https://pytorch.org/docs/stable/checkpoint.html#torch.utils.checkpoint.checkpoint - x = torch.utils.checkpoint.checkpoint(self.ckpt_wrapper(block), x, t, mask, rope, use_reentrant=False) - else: x = block(x, t, mask=mask, rope=rope) x = self.norm_out(x, t) output = self.proj_out(x) return output + + + + +def exists(v): + return v is not None + + +def default(v, d): + return v if exists(v) else d + + + + + + + + + + + + + +# Get tokenizer + + +def get_tokenizer(dataset_name, tokenizer: str = "pinyin"): + """ + tokenizer - "pinyin" do g2p for only chinese characters, need .txt vocab_file + - "char" for char-wise tokenizer, need .txt vocab_file + - "byte" for utf-8 tokenizer + - "custom" if you're directly passing in a path to the vocab.txt you want to use + vocab_size - if use "pinyin", all available pinyin types, common alphabets (also those with accent) and symbols + - if use "char", derived from unfiltered character & symbol counts of custom dataset + - if use "byte", set to 256 (unicode byte range) + """ + if tokenizer in ["pinyin", "char"]: + tokenizer_path = os.path.join(files("f5_tts").joinpath("../../data"), f"{dataset_name}_{tokenizer}/vocab.txt") + with open(tokenizer_path, "r", encoding="utf-8") as f: + vocab_char_map = {} + for i, char in enumerate(f): + vocab_char_map[char[:-1]] = i + vocab_size = len(vocab_char_map) + assert vocab_char_map[" "] == 0, "make sure space is of idx 0 in vocab.txt, cuz 0 is used for unknown char" + + elif tokenizer == "byte": + vocab_char_map = None + vocab_size = 256 + + elif tokenizer == "custom": + with open(dataset_name, "r", encoding="utf-8") as f: + vocab_char_map = {} + for i, char in enumerate(f): + vocab_char_map[char[:-1]] = i + vocab_size = len(vocab_char_map) + + return vocab_char_map, vocab_size + + +# convert char to pinyin + + + + + + + + + +def get_vocos_mel_spectrogram( + waveform, + n_fft=1024, + n_mel_channels=100, + target_sample_rate=24000, + hop_length=256, + win_length=1024, +): + mel_stft = torchaudio.transforms.MelSpectrogram( + sample_rate=target_sample_rate, + n_fft=n_fft, + win_length=win_length, + hop_length=hop_length, + n_mels=n_mel_channels, + power=1, + center=True, + normalized=False, + norm=None, + ).to(waveform.device) + if len(waveform.shape) == 3: + waveform = waveform.squeeze(1) # 'b 1 nw -> b nw' + + assert len(waveform.shape) == 2 + + mel = mel_stft(waveform) + mel = mel.clamp(min=1e-5).log() + return mel + +class MelSpec(nn.Module): + def __init__( + self, + n_fft=1024, + hop_length=256, + win_length=1024, + n_mel_channels=100, + target_sample_rate=24_000, + mel_spec_type="vocos", + ): + super().__init__() + assert mel_spec_type in ["vocos", "bigvgan"], print("We only support two extract mel backend: vocos or bigvgan") + + self.n_fft = n_fft + self.hop_length = hop_length + self.win_length = win_length + self.n_mel_channels = n_mel_channels + self.target_sample_rate = target_sample_rate + + #TODO - add BigVGAN support later + self.extractor = get_vocos_mel_spectrogram + + + self.register_buffer("dummy", torch.tensor(0), persistent=False) + + def forward(self, wav): + if self.dummy.device != wav.device: + self.to(wav.device) + + mel = self.extractor( + waveform=wav, + n_fft=self.n_fft, + n_mel_channels=self.n_mel_channels, + target_sample_rate=self.target_sample_rate, + 
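# hop_length sets the STFT frame stride and win_length the analysis window;
# with the defaults used here (hop 256, window 1024 at 24 kHz) each mel frame advances roughly 10.7 ms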
hop_length=self.hop_length, + win_length=self.win_length, + ) + + return mel + + +class CFM(nn.Module): + def __init__( + self, + transformer: nn.Module, + sigma=0.0, + odeint_kwargs: dict = dict( + # atol = 1e-5, + # rtol = 1e-5, + method="euler" # 'midpoint' + ), + audio_drop_prob=0.3, + cond_drop_prob=0.2, + num_channels=None, + mel_spec_module: nn.Module | None = None, + mel_spec_kwargs: dict = dict(), + frac_lengths_mask: tuple[float, float] = (0.7, 1.0), + vocab_char_map: dict[str:int] | None = None, + ): + super().__init__() + + self.frac_lengths_mask = frac_lengths_mask + + # mel spec + self.mel_spec = default(mel_spec_module, MelSpec(**mel_spec_kwargs)) + num_channels = default(num_channels, self.mel_spec.n_mel_channels) + self.num_channels = num_channels + + # classifier-free guidance + self.audio_drop_prob = audio_drop_prob + self.cond_drop_prob = cond_drop_prob + + # transformer + self.transformer = transformer + dim = transformer.dim + self.dim = dim + + # conditional flow related + self.sigma = sigma + + # sampling related + self.odeint_kwargs = odeint_kwargs + + # vocab map for tokenization + self.vocab_char_map = vocab_char_map + + @property + def device(self): + return next(self.parameters()).device + + + # simple utf-8 tokenizer, since paper went character based + def list_str_to_tensor(self, text: list[str], padding_value=-1) -> int["b nt"]: # noqa: F722 + list_tensors = [torch.tensor([*bytes(t, "UTF-8")]) for t in text] # ByT5 style + text = pad_sequence(list_tensors, padding_value=padding_value, batch_first=True) + return text + + + # char tokenizer, based on custom dataset's extracted .txt file + def list_str_to_idx( + self, + text: list[str] | list[list[str]], + vocab_char_map: dict[str, int], # {char: idx} + padding_value=-1, + ) -> int["b nt"]: # noqa: F722 + list_idx_tensors = [torch.tensor([vocab_char_map.get(c, 0) for c in t]) for t in text] # pinyin or char style + text = pad_sequence(list_idx_tensors, padding_value=padding_value, batch_first=True) + return text + + + def mask_from_frac_lengths(self, seq_len: int["b"], frac_lengths: float["b"]): # noqa: F722 F821 + lengths = (frac_lengths * seq_len).long() + max_start = seq_len - lengths + + rand = torch.rand_like(frac_lengths) + start = (max_start * rand).long().clamp(min=0) + end = start + lengths + + return mask_from_start_end_indices(seq_len, start, end) + + + + def lens_to_mask(self, t: int["b"], length: int | None = None) -> bool["b n"]: # noqa: F722 F821 + if not exists(length): + length = t.amax() + + seq = torch.arange(length, device=t.device) + return seq[None, :] < t[:, None] + + + def get_epss_timesteps(self, n, device, dtype): + dt = 1 / 32 + predefined_timesteps = { + 5: [0, 2, 4, 8, 16, 32], + 6: [0, 2, 4, 6, 8, 16, 32], + 7: [0, 2, 4, 6, 8, 16, 24, 32], + 10: [0, 2, 4, 6, 8, 12, 16, 20, 24, 28, 32], + 12: [0, 2, 4, 6, 8, 10, 12, 14, 16, 20, 24, 28, 32], + 16: [0, 1, 2, 3, 4, 5, 6, 7, 8, 10, 12, 14, 16, 20, 24, 28, 32], + } + t = predefined_timesteps.get(n, []) + if not t: + return torch.linspace(0, 1, n + 1, device=device, dtype=dtype) + return dt * torch.tensor(t, device=device, dtype=dtype) + + + def convert_char_to_pinyin(self, text_list, polyphone=True): + if jieba.dt.initialized is False: + jieba.default_logger.setLevel(50) # CRITICAL + jieba.initialize() + + final_text_list = [] + custom_trans = str.maketrans( + {";": ",", "“": '"', "”": '"', "‘": "'", "’": "'"} + ) # add custom trans here, to address oov + + def is_chinese(c): + return ( + "\u3100" <= c <= "\u9fff" # common chinese 
characters + ) + + for text in text_list: + char_list = [] + text = text.translate(custom_trans) + for seg in jieba.cut(text): + seg_byte_len = len(bytes(seg, "UTF-8")) + if seg_byte_len == len(seg): # if pure alphabets and symbols + if char_list and seg_byte_len > 1 and char_list[-1] not in " :'\"": + char_list.append(" ") + char_list.extend(seg) + elif polyphone and seg_byte_len == 3 * len(seg): # if pure east asian characters + seg_ = lazy_pinyin(seg, style=Style.TONE3, tone_sandhi=True) + for i, c in enumerate(seg): + if is_chinese(c): + char_list.append(" ") + char_list.append(seg_[i]) + else: # if mixed characters, alphabets and symbols + for c in seg: + if ord(c) < 256: + char_list.extend(c) + elif is_chinese(c): + char_list.append(" ") + char_list.extend(lazy_pinyin(c, style=Style.TONE3, tone_sandhi=True)) + else: + char_list.append(c) + final_text_list.append(char_list) + + return final_text_list + + + + + @torch.no_grad() + def sample( + self, + cond: float["b n d"] | float["b nw"], # noqa: F722 + text: int["b nt"] | list[str], # noqa: F722 + duration: int | int["b"], # noqa: F821 + *, + lens: int["b"] | None = None, # noqa: F821 + steps=32, + cfg_strength=1.0, + sway_sampling_coef=None, + seed: int | None = None, + max_duration=4096, + vocoder: Callable[[float["b d n"]], float["b nw"]] | None = None, # noqa: F722 + use_epss=True, + no_ref_audio=False, + duplicate_test=False, + t_inter=0.1, + edit_mask=None, + ): + self.eval() + # raw wave + + if cond.ndim == 2: + cond = self.mel_spec(cond) + cond = cond.permute(0, 2, 1) + assert cond.shape[-1] == self.num_channels + + cond = cond.to(next(self.parameters()).dtype) + + batch, cond_seq_len, device = *cond.shape[:2], cond.device + if not exists(lens): + lens = torch.full((batch,), cond_seq_len, device=device, dtype=torch.long) + + # text + + if isinstance(text, list): + if exists(self.vocab_char_map): + text = self.list_str_to_idx(text, self.vocab_char_map).to(device) + else: + text = self.list_str_to_tensor(text).to(device) + assert text.shape[0] == batch + + # duration + + cond_mask = lens_to_mask(lens) + if edit_mask is not None: + cond_mask = cond_mask & edit_mask + + if isinstance(duration, int): + duration = torch.full((batch,), duration, device=device, dtype=torch.long) + + duration = torch.maximum( + torch.maximum((text != -1).sum(dim=-1), lens) + 1, duration + ) # duration at least text/audio prompt length plus one token, so something is generated + duration = duration.clamp(max=max_duration) + max_duration = duration.amax() + + # duplicate test corner for inner time step oberservation + if duplicate_test: + test_cond = F.pad(cond, (0, 0, cond_seq_len, max_duration - 2 * cond_seq_len), value=0.0) + + cond = F.pad(cond, (0, 0, 0, max_duration - cond_seq_len), value=0.0) + if no_ref_audio: + cond = torch.zeros_like(cond) + + cond_mask = F.pad(cond_mask, (0, max_duration - cond_mask.shape[-1]), value=False) + cond_mask = cond_mask.unsqueeze(-1) + step_cond = torch.where( + cond_mask, cond, torch.zeros_like(cond) + ) # allow direct control (cut cond audio) with lens passed in + + if batch > 1: + mask = lens_to_mask(duration) + else: # save memory and speed up, as single inference need no mask currently + mask = None + + # neural ode + + def fn(t, x): + # at each step, conditioning is fixed + # step_cond = torch.where(cond_mask, cond, torch.zeros_like(cond)) + + # predict flow (cond) + if cfg_strength < 1e-5: + pred = self.transformer( + x=x, + cond=step_cond, + text=text, + time=t, + mask=mask, + drop_audio_cond=False, + 
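# cfg_strength is effectively zero in this branch, so only the conditional flow is predicted:
# neither the audio prompt nor the text conditioning is dropped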
drop_text=False, + cache=True, + ) + return pred + + # predict flow (cond and uncond), for classifier-free guidance + pred_cfg = self.transformer( + x=x, + cond=step_cond, + text=text, + time=t, + mask=mask, + cfg_infer=True, + cache=True, + ) + pred, null_pred = torch.chunk(pred_cfg, 2, dim=0) + return pred + (pred - null_pred) * cfg_strength + + # noise input + # to make sure batch inference result is same with different batch size, and for sure single inference + # still some difference maybe due to convolutional layers + y0 = [] + for dur in duration: + if exists(seed): + torch.manual_seed(seed) + y0.append(torch.randn(dur, self.num_channels, device=self.device, dtype=step_cond.dtype)) + y0 = pad_sequence(y0, padding_value=0, batch_first=True) + + t_start = 0 + + # duplicate test corner for inner time step oberservation + if duplicate_test: + t_start = t_inter + y0 = (1 - t_start) * y0 + t_start * test_cond + steps = int(steps * (1 - t_start)) + + if t_start == 0 and use_epss: # use Empirically Pruned Step Sampling for low NFE + t = self.get_epss_timesteps(steps, device=self.device, dtype=step_cond.dtype) + else: + t = torch.linspace(t_start, 1, steps + 1, device=self.device, dtype=step_cond.dtype) + if sway_sampling_coef is not None: + t = t + sway_sampling_coef * (torch.cos(torch.pi / 2 * t) - 1 + t) + + trajectory = odeint(fn, y0, t, **self.odeint_kwargs) + self.transformer.clear_cache() + + sampled = trajectory[-1] + out = sampled + out = torch.where(cond_mask, cond, out) + + if exists(vocoder): + out = out.permute(0, 2, 1) + out = vocoder(out) + + return out, trajectory + + def forward( + self, + inp: float["b n d"] | float["b nw"], # mel or raw wave # noqa: F722 + text: int["b nt"] | list[str], # noqa: F722 + *, + lens: int["b"] | None = None, # noqa: F821 + noise_scheduler: str | None = None, + ): + # handle raw wave + if inp.ndim == 2: + inp = self.mel_spec(inp) + inp = inp.permute(0, 2, 1) + assert inp.shape[-1] == self.num_channels + + batch, seq_len, dtype, device, _σ1 = *inp.shape[:2], inp.dtype, self.device, self.sigma + + # handle text as string + if isinstance(text, list): + if exists(self.vocab_char_map): + text = self.list_str_to_idx(text, self.vocab_char_map).to(device) + else: + text = self.list_str_to_tensor(text).to(device) + assert text.shape[0] == batch + + # lens and mask + if not exists(lens): + lens = torch.full((batch,), seq_len, device=device) + + mask = lens_to_mask(lens, length=seq_len) # useless here, as collate_fn will pad to max length in batch + + # get a random span to mask out for training conditionally + frac_lengths = torch.zeros((batch,), device=self.device).float().uniform_(*self.frac_lengths_mask) + rand_span_mask = self.mask_from_frac_lengths(lens, frac_lengths) + + if exists(mask): + rand_span_mask &= mask + + # mel is x1 + x1 = inp + + # x0 is gaussian noise + x0 = torch.randn_like(x1) + + # time step + time = torch.rand((batch,), dtype=dtype, device=self.device) + # TODO. 
noise_scheduler + + # sample xt (φ_t(x) in the paper) + t = time.unsqueeze(-1).unsqueeze(-1) + φ = (1 - t) * x0 + t * x1 + flow = x1 - x0 + + # only predict what is within the random mask span for infilling + cond = torch.where(rand_span_mask[..., None], torch.zeros_like(x1), x1) + + # transformer and cfg training with a drop rate + drop_audio_cond = random() < self.audio_drop_prob # p_drop in voicebox paper + if random() < self.cond_drop_prob: # p_uncond in voicebox paper + drop_audio_cond = True + drop_text = True + else: + drop_text = False + + # apply mask will use more memory; might adjust batchsize or batchsampler long sequence threshold + pred = self.transformer( + x=φ, cond=cond, text=text, time=time, drop_audio_cond=drop_audio_cond, drop_text=drop_text, mask=mask + ) + + # flow matching loss + loss = F.mse_loss(pred, flow, reduction="none") + loss = loss[rand_span_mask] + + return loss.mean(), cond, pred diff --git a/src/diffusers/pipelines/f5tts/pipeline_f5tts.py b/src/diffusers/pipelines/f5tts/pipeline_f5tts.py index 33f1b3bb7f3e..7e7ef3c2146c 100644 --- a/src/diffusers/pipelines/f5tts/pipeline_f5tts.py +++ b/src/diffusers/pipelines/f5tts/pipeline_f5tts.py @@ -26,838 +26,188 @@ from collections import defaultdict from importlib.resources import files -import jieba + import torch -from pypinyin import Style, lazy_pinyin from torch.nn.utils.rnn import pad_sequence - +from ..pipeline_utils import AudioPipelineOutput, DiffusionPipeline +from vocos import Vocos # helpers -def exists(v): - return v is not None - - -def default(v, d): - return v if exists(v) else d - - -def is_package_available(package_name: str) -> bool: - try: - import importlib - - package_exists = importlib.util.find_spec(package_name) is not None - return package_exists - except Exception: - return False - - - -def lens_to_mask(t: int["b"], length: int | None = None) -> bool["b n"]: # noqa: F722 F821 - if not exists(length): - length = t.amax() - - seq = torch.arange(length, device=t.device) - return seq[None, :] < t[:, None] - - -def mask_from_start_end_indices(seq_len: int["b"], start: int["b"], end: int["b"]): # noqa: F722 F821 - max_seq_len = seq_len.max().item() - seq = torch.arange(max_seq_len, device=start.device).long() - start_mask = seq[None, :] >= start[:, None] - end_mask = seq[None, :] < end[:, None] - return start_mask & end_mask - - -def mask_from_frac_lengths(seq_len: int["b"], frac_lengths: float["b"]): # noqa: F722 F821 - lengths = (frac_lengths * seq_len).long() - max_start = seq_len - lengths - - rand = torch.rand_like(frac_lengths) - start = (max_start * rand).long().clamp(min=0) - end = start + lengths - - return mask_from_start_end_indices(seq_len, start, end) - - -def maybe_masked_mean(t: float["b n d"], mask: bool["b n"] = None) -> float["b d"]: # noqa: F722 - if not exists(mask): - return t.mean(dim=1) - - t = torch.where(mask[:, :, None], t, torch.tensor(0.0, device=t.device)) - num = t.sum(dim=1) - den = mask.float().sum(dim=1) - - return num / den.clamp(min=1.0) - - -# simple utf-8 tokenizer, since paper went character based -def list_str_to_tensor(text: list[str], padding_value=-1) -> int["b nt"]: # noqa: F722 - list_tensors = [torch.tensor([*bytes(t, "UTF-8")]) for t in text] # ByT5 style - text = pad_sequence(list_tensors, padding_value=padding_value, batch_first=True) - return text -# char tokenizer, based on custom dataset's extracted .txt file -def list_str_to_idx( - text: list[str] | list[list[str]], - vocab_char_map: dict[str, int], # {char: idx} - padding_value=-1, -) -> 
int["b nt"]: # noqa: F722 - list_idx_tensors = [torch.tensor([vocab_char_map.get(c, 0) for c in t]) for t in text] # pinyin or char style - text = pad_sequence(list_idx_tensors, padding_value=padding_value, batch_first=True) - return text -# Get tokenizer - -def get_tokenizer(dataset_name, tokenizer: str = "pinyin"): - """ - tokenizer - "pinyin" do g2p for only chinese characters, need .txt vocab_file - - "char" for char-wise tokenizer, need .txt vocab_file - - "byte" for utf-8 tokenizer - - "custom" if you're directly passing in a path to the vocab.txt you want to use - vocab_size - if use "pinyin", all available pinyin types, common alphabets (also those with accent) and symbols - - if use "char", derived from unfiltered character & symbol counts of custom dataset - - if use "byte", set to 256 (unicode byte range) - """ - if tokenizer in ["pinyin", "char"]: - tokenizer_path = os.path.join(files("f5_tts").joinpath("../../data"), f"{dataset_name}_{tokenizer}/vocab.txt") - with open(tokenizer_path, "r", encoding="utf-8") as f: - vocab_char_map = {} - for i, char in enumerate(f): - vocab_char_map[char[:-1]] = i - vocab_size = len(vocab_char_map) - assert vocab_char_map[" "] == 0, "make sure space is of idx 0 in vocab.txt, cuz 0 is used for unknown char" - - elif tokenizer == "byte": - vocab_char_map = None - vocab_size = 256 - - elif tokenizer == "custom": - with open(dataset_name, "r", encoding="utf-8") as f: - vocab_char_map = {} - for i, char in enumerate(f): - vocab_char_map[char[:-1]] = i - vocab_size = len(vocab_char_map) - - return vocab_char_map, vocab_size - - -# convert char to pinyin - - -def convert_char_to_pinyin(text_list, polyphone=True): - if jieba.dt.initialized is False: - jieba.default_logger.setLevel(50) # CRITICAL - jieba.initialize() - - final_text_list = [] - custom_trans = str.maketrans( - {";": ",", "“": '"', "”": '"', "‘": "'", "’": "'"} - ) # add custom trans here, to address oov - - def is_chinese(c): - return ( - "\u3100" <= c <= "\u9fff" # common chinese characters - ) - - for text in text_list: - char_list = [] - text = text.translate(custom_trans) - for seg in jieba.cut(text): - seg_byte_len = len(bytes(seg, "UTF-8")) - if seg_byte_len == len(seg): # if pure alphabets and symbols - if char_list and seg_byte_len > 1 and char_list[-1] not in " :'\"": - char_list.append(" ") - char_list.extend(seg) - elif polyphone and seg_byte_len == 3 * len(seg): # if pure east asian characters - seg_ = lazy_pinyin(seg, style=Style.TONE3, tone_sandhi=True) - for i, c in enumerate(seg): - if is_chinese(c): - char_list.append(" ") - char_list.append(seg_[i]) - else: # if mixed characters, alphabets and symbols - for c in seg: - if ord(c) < 256: - char_list.extend(c) - elif is_chinese(c): - char_list.append(" ") - char_list.extend(lazy_pinyin(c, style=Style.TONE3, tone_sandhi=True)) - else: - char_list.append(c) - final_text_list.append(char_list) - - return final_text_list - - -# filter func for dirty data with many repetitions - - -def repetition_found(text, length=2, tolerance=10): - pattern_count = defaultdict(int) - for i in range(len(text) - length + 1): - pattern = text[i : i + length] - pattern_count[pattern] += 1 - for pattern, count in pattern_count.items(): - if count > tolerance: - return True - return False - - -# get the empirically pruned step for sampling - - -def get_epss_timesteps(n, device, dtype): - dt = 1 / 32 - predefined_timesteps = { - 5: [0, 2, 4, 8, 16, 32], - 6: [0, 2, 4, 6, 8, 16, 32], - 7: [0, 2, 4, 6, 8, 16, 24, 32], - 10: [0, 2, 4, 6, 8, 12, 16, 20, 
24, 28, 32], - 12: [0, 2, 4, 6, 8, 10, 12, 14, 16, 20, 24, 28, 32], - 16: [0, 1, 2, 3, 4, 5, 6, 7, 8, 10, 12, 14, 16, 20, 24, 28, 32], - } - t = predefined_timesteps.get(n, []) - if not t: - return torch.linspace(0, 1, n + 1, device=device, dtype=dtype) - return dt * torch.tensor(t, device=device, dtype=dtype) - - - - -mel_basis_cache = {} -hann_window_cache = {} - - -def get_bigvgan_mel_spectrogram( - waveform, - n_fft=1024, - n_mel_channels=100, - target_sample_rate=24000, - hop_length=256, - win_length=1024, - fmin=0, - fmax=None, - center=False, -): # Copy from https://github.com/NVIDIA/BigVGAN/tree/main - device = waveform.device - key = f"{n_fft}_{n_mel_channels}_{target_sample_rate}_{hop_length}_{win_length}_{fmin}_{fmax}_{device}" - - if key not in mel_basis_cache: - mel = librosa_mel_fn(sr=target_sample_rate, n_fft=n_fft, n_mels=n_mel_channels, fmin=fmin, fmax=fmax) - mel_basis_cache[key] = torch.from_numpy(mel).float().to(device) # TODO: why they need .float()? - hann_window_cache[key] = torch.hann_window(win_length).to(device) - - mel_basis = mel_basis_cache[key] - hann_window = hann_window_cache[key] - - padding = (n_fft - hop_length) // 2 - waveform = torch.nn.functional.pad(waveform.unsqueeze(1), (padding, padding), mode="reflect").squeeze(1) - - spec = torch.stft( - waveform, - n_fft, - hop_length=hop_length, - win_length=win_length, - window=hann_window, - center=center, - pad_mode="reflect", - normalized=False, - onesided=True, - return_complex=True, - ) - spec = torch.sqrt(torch.view_as_real(spec).pow(2).sum(-1) + 1e-9) - - mel_spec = torch.matmul(mel_basis, spec) - mel_spec = torch.log(torch.clamp(mel_spec, min=1e-5)) - - return mel_spec - - -def get_vocos_mel_spectrogram( - waveform, - n_fft=1024, - n_mel_channels=100, - target_sample_rate=24000, - hop_length=256, - win_length=1024, -): - mel_stft = torchaudio.transforms.MelSpectrogram( - sample_rate=target_sample_rate, - n_fft=n_fft, - win_length=win_length, - hop_length=hop_length, - n_mels=n_mel_channels, - power=1, - center=True, - normalized=False, - norm=None, - ).to(waveform.device) - if len(waveform.shape) == 3: - waveform = waveform.squeeze(1) # 'b 1 nw -> b nw' - - assert len(waveform.shape) == 2 - - mel = mel_stft(waveform) - mel = mel.clamp(min=1e-5).log() - return mel - -class MelSpec(nn.Module): +class F5FlowPipeline(DiffusionPipeline): def __init__( self, - n_fft=1024, - hop_length=256, - win_length=1024, - n_mel_channels=100, - target_sample_rate=24_000, + model: CFM, + vocoder: Callable[[float["b d n"]], float["b nw"]] | None = None, # noqa: F722 mel_spec_type="vocos", - ): - super().__init__() - assert mel_spec_type in ["vocos", "bigvgan"], print("We only support two extract mel backend: vocos or bigvgan") - - self.n_fft = n_fft - self.hop_length = hop_length - self.win_length = win_length - self.n_mel_channels = n_mel_channels - self.target_sample_rate = target_sample_rate - - if mel_spec_type == "vocos": - self.extractor = get_vocos_mel_spectrogram - elif mel_spec_type == "bigvgan": - self.extractor = get_bigvgan_mel_spectrogram - - self.register_buffer("dummy", torch.tensor(0), persistent=False) - - def forward(self, wav): - if self.dummy.device != wav.device: - self.to(wav.device) - - mel = self.extractor( - waveform=wav, - n_fft=self.n_fft, - n_mel_channels=self.n_mel_channels, - target_sample_rate=self.target_sample_rate, - hop_length=self.hop_length, - win_length=self.win_length, - ) - - return mel - - -class CFM(nn.Module): - def __init__( - self, - transformer: nn.Module, - sigma=0.0, - 
odeint_kwargs: dict = dict( - # atol = 1e-5, - # rtol = 1e-5, - method="euler" # 'midpoint' - ), - audio_drop_prob=0.3, - cond_drop_prob=0.2, - num_channels=None, - mel_spec_module: nn.Module | None = None, mel_spec_kwargs: dict = dict(), - frac_lengths_mask: tuple[float, float] = (0.7, 1.0), - vocab_char_map: dict[str:int] | None = None, + device=None, ): super().__init__() - - self.frac_lengths_mask = frac_lengths_mask - - # mel spec - self.mel_spec = default(mel_spec_module, MelSpec(**mel_spec_kwargs)) - num_channels = default(num_channels, self.mel_spec.n_mel_channels) - self.num_channels = num_channels - - # classifier-free guidance - self.audio_drop_prob = audio_drop_prob - self.cond_drop_prob = cond_drop_prob - - # transformer - self.transformer = transformer - dim = transformer.dim - self.dim = dim - - # conditional flow related - self.sigma = sigma - - # sampling related - self.odeint_kwargs = odeint_kwargs - - # vocab map for tokenization - self.vocab_char_map = vocab_char_map - - @property - def device(self): - return next(self.parameters()).device - - @torch.no_grad() - def sample( - self, - cond: float["b n d"] | float["b nw"], # noqa: F722 - text: int["b nt"] | list[str], # noqa: F722 - duration: int | int["b"], # noqa: F821 - *, - lens: int["b"] | None = None, # noqa: F821 - steps=32, - cfg_strength=1.0, - sway_sampling_coef=None, - seed: int | None = None, - max_duration=4096, - vocoder: Callable[[float["b d n"]], float["b nw"]] | None = None, # noqa: F722 - use_epss=True, - no_ref_audio=False, - duplicate_test=False, - t_inter=0.1, - edit_mask=None, + self.model = model + self.vocoder = vocoder + self.mel_spec_type = mel_spec_type + self.mel_spec_kwargs = mel_spec_kwargs + self.device = device or torch.device("cpu") + + def __call__(self, *args, **kwargs): + return infer_process(*args, **kwargs, model_obj=self.model, vocoder=self.vocoder, device=self.device) + + + def infer_batch_process( + ref_audio, + ref_text, + gen_text_batches, + model_obj, + vocoder, + mel_spec_type="vocos", + progress=tqdm, + target_rms=0.1, + cross_fade_duration=0.15, + nfe_step=32, + cfg_strength=2.0, + sway_sampling_coef=-1, + speed=1, + fix_duration=None, + device=None, + streaming=False, + chunk_size=2048, ): - self.eval() - # raw wave - - if cond.ndim == 2: - cond = self.mel_spec(cond) - cond = cond.permute(0, 2, 1) - assert cond.shape[-1] == self.num_channels - - cond = cond.to(next(self.parameters()).dtype) - - batch, cond_seq_len, device = *cond.shape[:2], cond.device - if not exists(lens): - lens = torch.full((batch,), cond_seq_len, device=device, dtype=torch.long) - - # text - - if isinstance(text, list): - if exists(self.vocab_char_map): - text = list_str_to_idx(text, self.vocab_char_map).to(device) + audio, sr = ref_audio + if audio.shape[0] > 1: + audio = torch.mean(audio, dim=0, keepdim=True) + + rms = torch.sqrt(torch.mean(torch.square(audio))) + if rms < target_rms: + audio = audio * target_rms / rms + if sr != target_sample_rate: + resampler = torchaudio.transforms.Resample(sr, target_sample_rate) + audio = resampler(audio) + audio = audio.to(device) + + generated_waves = [] + spectrograms = [] + + if len(ref_text[-1].encode("utf-8")) == 1: + ref_text = ref_text + " " + + def process_batch(gen_text): + local_speed = speed + if len(gen_text.encode("utf-8")) < 10: + local_speed = 0.3 + + # Prepare the text + text_list = [ref_text + gen_text] + final_text_list = convert_char_to_pinyin(text_list) + + ref_audio_len = audio.shape[-1] // hop_length + if fix_duration is not None: + 
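# fix_duration is the total (reference + generated) length in seconds;
# convert it to mel-frame units: seconds * target_sample_rate / hop_length frames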
duration = int(fix_duration * target_sample_rate / hop_length) else: - text = list_str_to_tensor(text).to(device) - assert text.shape[0] == batch - - # duration - - cond_mask = lens_to_mask(lens) - if edit_mask is not None: - cond_mask = cond_mask & edit_mask - - if isinstance(duration, int): - duration = torch.full((batch,), duration, device=device, dtype=torch.long) - - duration = torch.maximum( - torch.maximum((text != -1).sum(dim=-1), lens) + 1, duration - ) # duration at least text/audio prompt length plus one token, so something is generated - duration = duration.clamp(max=max_duration) - max_duration = duration.amax() - - # duplicate test corner for inner time step oberservation - if duplicate_test: - test_cond = F.pad(cond, (0, 0, cond_seq_len, max_duration - 2 * cond_seq_len), value=0.0) - - cond = F.pad(cond, (0, 0, 0, max_duration - cond_seq_len), value=0.0) - if no_ref_audio: - cond = torch.zeros_like(cond) - - cond_mask = F.pad(cond_mask, (0, max_duration - cond_mask.shape[-1]), value=False) - cond_mask = cond_mask.unsqueeze(-1) - step_cond = torch.where( - cond_mask, cond, torch.zeros_like(cond) - ) # allow direct control (cut cond audio) with lens passed in - - if batch > 1: - mask = lens_to_mask(duration) - else: # save memory and speed up, as single inference need no mask currently - mask = None - - # neural ode - - def fn(t, x): - # at each step, conditioning is fixed - # step_cond = torch.where(cond_mask, cond, torch.zeros_like(cond)) - - # predict flow (cond) - if cfg_strength < 1e-5: - pred = self.transformer( - x=x, - cond=step_cond, - text=text, - time=t, - mask=mask, - drop_audio_cond=False, - drop_text=False, - cache=True, + # Calculate duration + ref_text_len = len(ref_text.encode("utf-8")) + gen_text_len = len(gen_text.encode("utf-8")) + duration = ref_audio_len + int(ref_audio_len / ref_text_len * gen_text_len / local_speed) + + # inference + with torch.inference_mode(): + generated, _ = model_obj.sample( + cond=audio, + text=final_text_list, + duration=duration, + steps=nfe_step, + cfg_strength=cfg_strength, + sway_sampling_coef=sway_sampling_coef, ) - return pred - - # predict flow (cond and uncond), for classifier-free guidance - pred_cfg = self.transformer( - x=x, - cond=step_cond, - text=text, - time=t, - mask=mask, - cfg_infer=True, - cache=True, - ) - pred, null_pred = torch.chunk(pred_cfg, 2, dim=0) - return pred + (pred - null_pred) * cfg_strength - - # noise input - # to make sure batch inference result is same with different batch size, and for sure single inference - # still some difference maybe due to convolutional layers - y0 = [] - for dur in duration: - if exists(seed): - torch.manual_seed(seed) - y0.append(torch.randn(dur, self.num_channels, device=self.device, dtype=step_cond.dtype)) - y0 = pad_sequence(y0, padding_value=0, batch_first=True) - - t_start = 0 - - # duplicate test corner for inner time step oberservation - if duplicate_test: - t_start = t_inter - y0 = (1 - t_start) * y0 + t_start * test_cond - steps = int(steps * (1 - t_start)) - - if t_start == 0 and use_epss: # use Empirically Pruned Step Sampling for low NFE - t = get_epss_timesteps(steps, device=self.device, dtype=step_cond.dtype) + del _ + + generated = generated.to(torch.float32) # generated mel spectrogram + generated = generated[:, ref_audio_len:, :] + generated = generated.permute(0, 2, 1) + if mel_spec_type == "vocos": + generated_wave = vocoder.decode(generated) + elif mel_spec_type == "bigvgan": + generated_wave = vocoder(generated) + if rms < target_rms: + 
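# the reference audio was boosted to target_rms before inference;
# scale the generated wave back down so it matches the original reference loudness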
generated_wave = generated_wave * rms / target_rms + + # wav -> numpy + generated_wave = generated_wave.squeeze().cpu().numpy() + + if streaming: + for j in range(0, len(generated_wave), chunk_size): + yield generated_wave[j : j + chunk_size], target_sample_rate + else: + generated_cpu = generated[0].cpu().numpy() + del generated + yield generated_wave, generated_cpu + + if streaming: + for gen_text in progress.tqdm(gen_text_batches) if progress is not None else gen_text_batches: + for chunk in process_batch(gen_text): + yield chunk else: - t = torch.linspace(t_start, 1, steps + 1, device=self.device, dtype=step_cond.dtype) - if sway_sampling_coef is not None: - t = t + sway_sampling_coef * (torch.cos(torch.pi / 2 * t) - 1 + t) - - trajectory = odeint(fn, y0, t, **self.odeint_kwargs) - self.transformer.clear_cache() - - sampled = trajectory[-1] - out = sampled - out = torch.where(cond_mask, cond, out) - - if exists(vocoder): - out = out.permute(0, 2, 1) - out = vocoder(out) - - return out, trajectory - - def forward( - self, - inp: float["b n d"] | float["b nw"], # mel or raw wave # noqa: F722 - text: int["b nt"] | list[str], # noqa: F722 - *, - lens: int["b"] | None = None, # noqa: F821 - noise_scheduler: str | None = None, - ): - # handle raw wave - if inp.ndim == 2: - inp = self.mel_spec(inp) - inp = inp.permute(0, 2, 1) - assert inp.shape[-1] == self.num_channels - - batch, seq_len, dtype, device, _σ1 = *inp.shape[:2], inp.dtype, self.device, self.sigma - - # handle text as string - if isinstance(text, list): - if exists(self.vocab_char_map): - text = list_str_to_idx(text, self.vocab_char_map).to(device) - else: - text = list_str_to_tensor(text).to(device) - assert text.shape[0] == batch - - # lens and mask - if not exists(lens): - lens = torch.full((batch,), seq_len, device=device) - - mask = lens_to_mask(lens, length=seq_len) # useless here, as collate_fn will pad to max length in batch - - # get a random span to mask out for training conditionally - frac_lengths = torch.zeros((batch,), device=self.device).float().uniform_(*self.frac_lengths_mask) - rand_span_mask = mask_from_frac_lengths(lens, frac_lengths) + with ThreadPoolExecutor() as executor: + futures = [executor.submit(process_batch, gen_text) for gen_text in gen_text_batches] + for future in progress.tqdm(futures) if progress is not None else futures: + result = future.result() + if result: + generated_wave, generated_mel_spec = next(result) + generated_waves.append(generated_wave) + spectrograms.append(generated_mel_spec) + + if generated_waves: + if cross_fade_duration <= 0: + # Simply concatenate + final_wave = np.concatenate(generated_waves) + else: + # Combine all generated waves with cross-fading + final_wave = generated_waves[0] + for i in range(1, len(generated_waves)): + prev_wave = final_wave + next_wave = generated_waves[i] + + # Calculate cross-fade samples, ensuring it does not exceed wave lengths + cross_fade_samples = int(cross_fade_duration * target_sample_rate) + cross_fade_samples = min(cross_fade_samples, len(prev_wave), len(next_wave)) + + if cross_fade_samples <= 0: + # No overlap possible, concatenate + final_wave = np.concatenate([prev_wave, next_wave]) + continue + + # Overlapping parts + prev_overlap = prev_wave[-cross_fade_samples:] + next_overlap = next_wave[:cross_fade_samples] + + # Fade out and fade in + fade_out = np.linspace(1, 0, cross_fade_samples) + fade_in = np.linspace(0, 1, cross_fade_samples) + + # Cross-faded overlap + cross_faded_overlap = prev_overlap * fade_out + next_overlap * 
fade_in + + # Combine + new_wave = np.concatenate( + [prev_wave[:-cross_fade_samples], cross_faded_overlap, next_wave[cross_fade_samples:]] + ) + + final_wave = new_wave + + # Create a combined spectrogram + combined_spectrogram = np.concatenate(spectrograms, axis=1) + + yield final_wave, target_sample_rate, combined_spectrogram - if exists(mask): - rand_span_mask &= mask - - # mel is x1 - x1 = inp - - # x0 is gaussian noise - x0 = torch.randn_like(x1) - - # time step - time = torch.rand((batch,), dtype=dtype, device=self.device) - # TODO. noise_scheduler - - # sample xt (φ_t(x) in the paper) - t = time.unsqueeze(-1).unsqueeze(-1) - φ = (1 - t) * x0 + t * x1 - flow = x1 - x0 - - # only predict what is within the random mask span for infilling - cond = torch.where(rand_span_mask[..., None], torch.zeros_like(x1), x1) - - # transformer and cfg training with a drop rate - drop_audio_cond = random() < self.audio_drop_prob # p_drop in voicebox paper - if random() < self.cond_drop_prob: # p_uncond in voicebox paper - drop_audio_cond = True - drop_text = True - else: - drop_text = False - - # apply mask will use more memory; might adjust batchsize or batchsampler long sequence threshold - pred = self.transformer( - x=φ, cond=cond, text=text, time=time, drop_audio_cond=drop_audio_cond, drop_text=drop_text, mask=mask - ) - - # flow matching loss - loss = F.mse_loss(pred, flow, reduction="none") - loss = loss[rand_span_mask] - - return loss.mean(), cond, pred - - -def infer_process( - ref_audio, - ref_text, - gen_text, - model_obj, - vocoder, - mel_spec_type=mel_spec_type, - show_info=print, - progress=tqdm, - target_rms=target_rms, - cross_fade_duration=cross_fade_duration, - nfe_step=nfe_step, - cfg_strength=cfg_strength, - sway_sampling_coef=sway_sampling_coef, - speed=speed, - fix_duration=fix_duration, - device=device, -): - # Split the input text into batches - audio, sr = torchaudio.load(ref_audio) - max_chars = int(len(ref_text.encode("utf-8")) / (audio.shape[-1] / sr) * (22 - audio.shape[-1] / sr) * speed) - gen_text_batches = chunk_text(gen_text, max_chars=max_chars) - for i, gen_text in enumerate(gen_text_batches): - print(f"gen_text {i}", gen_text) - print("\n") - - show_info(f"Generating audio in {len(gen_text_batches)} batches...") - return next( - infer_batch_process( - (audio, sr), - ref_text, - gen_text_batches, - model_obj, - vocoder, - mel_spec_type=mel_spec_type, - progress=progress, - target_rms=target_rms, - cross_fade_duration=cross_fade_duration, - nfe_step=nfe_step, - cfg_strength=cfg_strength, - sway_sampling_coef=sway_sampling_coef, - speed=speed, - fix_duration=fix_duration, - device=device, - ) - ) - - -# infer batches - - -def infer_batch_process( - ref_audio, - ref_text, - gen_text_batches, - model_obj, - vocoder, - mel_spec_type="vocos", - progress=tqdm, - target_rms=0.1, - cross_fade_duration=0.15, - nfe_step=32, - cfg_strength=2.0, - sway_sampling_coef=-1, - speed=1, - fix_duration=None, - device=None, - streaming=False, - chunk_size=2048, -): - audio, sr = ref_audio - if audio.shape[0] > 1: - audio = torch.mean(audio, dim=0, keepdim=True) - - rms = torch.sqrt(torch.mean(torch.square(audio))) - if rms < target_rms: - audio = audio * target_rms / rms - if sr != target_sample_rate: - resampler = torchaudio.transforms.Resample(sr, target_sample_rate) - audio = resampler(audio) - audio = audio.to(device) - - generated_waves = [] - spectrograms = [] - - if len(ref_text[-1].encode("utf-8")) == 1: - ref_text = ref_text + " " - - def process_batch(gen_text): - local_speed = 
speed - if len(gen_text.encode("utf-8")) < 10: - local_speed = 0.3 - - # Prepare the text - text_list = [ref_text + gen_text] - final_text_list = convert_char_to_pinyin(text_list) - - ref_audio_len = audio.shape[-1] // hop_length - if fix_duration is not None: - duration = int(fix_duration * target_sample_rate / hop_length) - else: - # Calculate duration - ref_text_len = len(ref_text.encode("utf-8")) - gen_text_len = len(gen_text.encode("utf-8")) - duration = ref_audio_len + int(ref_audio_len / ref_text_len * gen_text_len / local_speed) - - # inference - with torch.inference_mode(): - generated, _ = model_obj.sample( - cond=audio, - text=final_text_list, - duration=duration, - steps=nfe_step, - cfg_strength=cfg_strength, - sway_sampling_coef=sway_sampling_coef, - ) - del _ - - generated = generated.to(torch.float32) # generated mel spectrogram - generated = generated[:, ref_audio_len:, :] - generated = generated.permute(0, 2, 1) - if mel_spec_type == "vocos": - generated_wave = vocoder.decode(generated) - elif mel_spec_type == "bigvgan": - generated_wave = vocoder(generated) - if rms < target_rms: - generated_wave = generated_wave * rms / target_rms - - # wav -> numpy - generated_wave = generated_wave.squeeze().cpu().numpy() - - if streaming: - for j in range(0, len(generated_wave), chunk_size): - yield generated_wave[j : j + chunk_size], target_sample_rate - else: - generated_cpu = generated[0].cpu().numpy() - del generated - yield generated_wave, generated_cpu - - if streaming: - for gen_text in progress.tqdm(gen_text_batches) if progress is not None else gen_text_batches: - for chunk in process_batch(gen_text): - yield chunk - else: - with ThreadPoolExecutor() as executor: - futures = [executor.submit(process_batch, gen_text) for gen_text in gen_text_batches] - for future in progress.tqdm(futures) if progress is not None else futures: - result = future.result() - if result: - generated_wave, generated_mel_spec = next(result) - generated_waves.append(generated_wave) - spectrograms.append(generated_mel_spec) - - if generated_waves: - if cross_fade_duration <= 0: - # Simply concatenate - final_wave = np.concatenate(generated_waves) else: - # Combine all generated waves with cross-fading - final_wave = generated_waves[0] - for i in range(1, len(generated_waves)): - prev_wave = final_wave - next_wave = generated_waves[i] - - # Calculate cross-fade samples, ensuring it does not exceed wave lengths - cross_fade_samples = int(cross_fade_duration * target_sample_rate) - cross_fade_samples = min(cross_fade_samples, len(prev_wave), len(next_wave)) - - if cross_fade_samples <= 0: - # No overlap possible, concatenate - final_wave = np.concatenate([prev_wave, next_wave]) - continue + yield None, target_sample_rate, None - # Overlapping parts - prev_overlap = prev_wave[-cross_fade_samples:] - next_overlap = next_wave[:cross_fade_samples] - # Fade out and fade in - fade_out = np.linspace(1, 0, cross_fade_samples) - fade_in = np.linspace(0, 1, cross_fade_samples) - # Cross-faded overlap - cross_faded_overlap = prev_overlap * fade_out + next_overlap * fade_in - - # Combine - new_wave = np.concatenate( - [prev_wave[:-cross_fade_samples], cross_faded_overlap, next_wave[cross_fade_samples:]] - ) - - final_wave = new_wave - - # Create a combined spectrogram - combined_spectrogram = np.concatenate(spectrograms, axis=1) - - yield final_wave, target_sample_rate, combined_spectrogram - - else: - yield None, target_sample_rate, None - - - -# load vocoder -def load_vocoder(vocoder_name="vocos", is_local=False, 
local_path="", device=device, hf_cache_dir=None): - if vocoder_name == "vocos": - # vocoder = Vocos.from_pretrained("charactr/vocos-mel-24khz").to(device) - if is_local: - print(f"Load vocos from local path {local_path}") - config_path = f"{local_path}/config.yaml" - model_path = f"{local_path}/pytorch_model.bin" - else: - print("Download Vocos from huggingface charactr/vocos-mel-24khz") - repo_id = "charactr/vocos-mel-24khz" - config_path = hf_hub_download(repo_id=repo_id, cache_dir=hf_cache_dir, filename="config.yaml") - model_path = hf_hub_download(repo_id=repo_id, cache_dir=hf_cache_dir, filename="pytorch_model.bin") - vocoder = Vocos.from_hparams(config_path) - state_dict = torch.load(model_path, map_location="cpu", weights_only=True) - from vocos.feature_extractors import EncodecFeatures - - if isinstance(vocoder.feature_extractor, EncodecFeatures): - encodec_parameters = { - "feature_extractor.encodec." + key: value - for key, value in vocoder.feature_extractor.encodec.state_dict().items() - } - state_dict.update(encodec_parameters) - vocoder.load_state_dict(state_dict) - vocoder = vocoder.eval().to(device) - elif vocoder_name == "bigvgan": - try: - from third_party.BigVGAN import bigvgan - except ImportError: - print("You need to follow the README to init submodule and change the BigVGAN source code.") - if is_local: - # download generator from https://huggingface.co/nvidia/bigvgan_v2_24khz_100band_256x/tree/main - vocoder = bigvgan.BigVGAN.from_pretrained(local_path, use_cuda_kernel=False) - else: - vocoder = bigvgan.BigVGAN.from_pretrained( - "nvidia/bigvgan_v2_24khz_100band_256x", use_cuda_kernel=False, cache_dir=hf_cache_dir - ) - vocoder.remove_weight_norm() - vocoder = vocoder.eval().to(device) - return vocoder From ce4237fc33e100a60438028edc7f7599d5d684fb Mon Sep 17 00:00:00 2001 From: ayushtues Date: Tue, 29 Jul 2025 08:13:18 +0530 Subject: [PATCH 6/8] Integrate CFM into the pipeline class --- .../models/transformers/f5tts_transformer.py | 513 +----------------- .../pipelines/f5tts/pipeline_f5tts.py | 381 ++++++++----- 2 files changed, 246 insertions(+), 648 deletions(-) diff --git a/src/diffusers/models/transformers/f5tts_transformer.py b/src/diffusers/models/transformers/f5tts_transformer.py index 9671cdc9e564..1824571dcbe6 100644 --- a/src/diffusers/models/transformers/f5tts_transformer.py +++ b/src/diffusers/models/transformers/f5tts_transformer.py @@ -420,22 +420,6 @@ def forward( batch, seq_len = x.shape[0], x.shape[1] if time.ndim == 0: time = time.repeat(batch) - - # t: conditioning time, text: text, x: noised audio + cond audio + text - t = self.time_embed(time) - - # if cfg_infer: # pack cond & uncond forward: b n d -> 2b n d - # x_cond = self. - # (x, cond, text, drop_audio_cond=False, drop_text=False, cache=cache) - # x_uncond = self. - # (x, cond, text, drop_audio_cond=True, drop_text=True, cache=cache) - # x = torch.cat((x_cond, x_uncond), dim=0) - # t = torch.cat((t, t), dim=0) - # mask = torch.cat((mask, mask), dim=0) if mask is not None else None - # else: - # x = self. 
- # (x, cond, text, drop_audio_cond=drop_audio_cond, drop_text=drop_text, cache=cache) - rope = get_1d_rotary_pos_embed(seq_len, self.dim, device=x.device) for block in self.transformer_blocks: x = block(x, t, mask=mask, rope=rope) @@ -448,99 +432,7 @@ def forward( -def exists(v): - return v is not None - - -def default(v, d): - return v if exists(v) else d - - - - - - - - - - - - - -# Get tokenizer - - -def get_tokenizer(dataset_name, tokenizer: str = "pinyin"): - """ - tokenizer - "pinyin" do g2p for only chinese characters, need .txt vocab_file - - "char" for char-wise tokenizer, need .txt vocab_file - - "byte" for utf-8 tokenizer - - "custom" if you're directly passing in a path to the vocab.txt you want to use - vocab_size - if use "pinyin", all available pinyin types, common alphabets (also those with accent) and symbols - - if use "char", derived from unfiltered character & symbol counts of custom dataset - - if use "byte", set to 256 (unicode byte range) - """ - if tokenizer in ["pinyin", "char"]: - tokenizer_path = os.path.join(files("f5_tts").joinpath("../../data"), f"{dataset_name}_{tokenizer}/vocab.txt") - with open(tokenizer_path, "r", encoding="utf-8") as f: - vocab_char_map = {} - for i, char in enumerate(f): - vocab_char_map[char[:-1]] = i - vocab_size = len(vocab_char_map) - assert vocab_char_map[" "] == 0, "make sure space is of idx 0 in vocab.txt, cuz 0 is used for unknown char" - - elif tokenizer == "byte": - vocab_char_map = None - vocab_size = 256 - - elif tokenizer == "custom": - with open(dataset_name, "r", encoding="utf-8") as f: - vocab_char_map = {} - for i, char in enumerate(f): - vocab_char_map[char[:-1]] = i - vocab_size = len(vocab_char_map) - - return vocab_char_map, vocab_size - - -# convert char to pinyin - - - - - - - - - -def get_vocos_mel_spectrogram( - waveform, - n_fft=1024, - n_mel_channels=100, - target_sample_rate=24000, - hop_length=256, - win_length=1024, -): - mel_stft = torchaudio.transforms.MelSpectrogram( - sample_rate=target_sample_rate, - n_fft=n_fft, - win_length=win_length, - hop_length=hop_length, - n_mels=n_mel_channels, - power=1, - center=True, - normalized=False, - norm=None, - ).to(waveform.device) - if len(waveform.shape) == 3: - waveform = waveform.squeeze(1) # 'b 1 nw -> b nw' - - assert len(waveform.shape) == 2 - - mel = mel_stft(waveform) - mel = mel.clamp(min=1e-5).log() - return mel - +#TODO - add BigVGAN support later class MelSpec(nn.Module): def __init__( self, @@ -549,10 +441,8 @@ def __init__( win_length=1024, n_mel_channels=100, target_sample_rate=24_000, - mel_spec_type="vocos", ): super().__init__() - assert mel_spec_type in ["vocos", "bigvgan"], print("We only support two extract mel backend: vocos or bigvgan") self.n_fft = n_fft self.hop_length = hop_length @@ -560,397 +450,22 @@ def __init__( self.n_mel_channels = n_mel_channels self.target_sample_rate = target_sample_rate - #TODO - add BigVGAN support later - self.extractor = get_vocos_mel_spectrogram - + self.mel_stft = torchaudio.transforms.MelSpectrogram( + sample_rate=target_sample_rate, + n_fft=n_fft, + win_length=win_length, + hop_length=hop_length, + n_mels=n_mel_channels, + power=1, + center=True, + normalized=False, + norm=None, + ) - self.register_buffer("dummy", torch.tensor(0), persistent=False) def forward(self, wav): - if self.dummy.device != wav.device: - self.to(wav.device) - - mel = self.extractor( - waveform=wav, - n_fft=self.n_fft, - n_mel_channels=self.n_mel_channels, - target_sample_rate=self.target_sample_rate, - hop_length=self.hop_length, - 
win_length=self.win_length, - ) - + mel = self.mel_stft(wav) + mel = mel.clamp(min=1e-5).log() return mel -class CFM(nn.Module): - def __init__( - self, - transformer: nn.Module, - sigma=0.0, - odeint_kwargs: dict = dict( - # atol = 1e-5, - # rtol = 1e-5, - method="euler" # 'midpoint' - ), - audio_drop_prob=0.3, - cond_drop_prob=0.2, - num_channels=None, - mel_spec_module: nn.Module | None = None, - mel_spec_kwargs: dict = dict(), - frac_lengths_mask: tuple[float, float] = (0.7, 1.0), - vocab_char_map: dict[str:int] | None = None, - ): - super().__init__() - - self.frac_lengths_mask = frac_lengths_mask - - # mel spec - self.mel_spec = default(mel_spec_module, MelSpec(**mel_spec_kwargs)) - num_channels = default(num_channels, self.mel_spec.n_mel_channels) - self.num_channels = num_channels - - # classifier-free guidance - self.audio_drop_prob = audio_drop_prob - self.cond_drop_prob = cond_drop_prob - - # transformer - self.transformer = transformer - dim = transformer.dim - self.dim = dim - - # conditional flow related - self.sigma = sigma - - # sampling related - self.odeint_kwargs = odeint_kwargs - - # vocab map for tokenization - self.vocab_char_map = vocab_char_map - - @property - def device(self): - return next(self.parameters()).device - - - # simple utf-8 tokenizer, since paper went character based - def list_str_to_tensor(self, text: list[str], padding_value=-1) -> int["b nt"]: # noqa: F722 - list_tensors = [torch.tensor([*bytes(t, "UTF-8")]) for t in text] # ByT5 style - text = pad_sequence(list_tensors, padding_value=padding_value, batch_first=True) - return text - - - # char tokenizer, based on custom dataset's extracted .txt file - def list_str_to_idx( - self, - text: list[str] | list[list[str]], - vocab_char_map: dict[str, int], # {char: idx} - padding_value=-1, - ) -> int["b nt"]: # noqa: F722 - list_idx_tensors = [torch.tensor([vocab_char_map.get(c, 0) for c in t]) for t in text] # pinyin or char style - text = pad_sequence(list_idx_tensors, padding_value=padding_value, batch_first=True) - return text - - - def mask_from_frac_lengths(self, seq_len: int["b"], frac_lengths: float["b"]): # noqa: F722 F821 - lengths = (frac_lengths * seq_len).long() - max_start = seq_len - lengths - - rand = torch.rand_like(frac_lengths) - start = (max_start * rand).long().clamp(min=0) - end = start + lengths - - return mask_from_start_end_indices(seq_len, start, end) - - - - def lens_to_mask(self, t: int["b"], length: int | None = None) -> bool["b n"]: # noqa: F722 F821 - if not exists(length): - length = t.amax() - - seq = torch.arange(length, device=t.device) - return seq[None, :] < t[:, None] - - - def get_epss_timesteps(self, n, device, dtype): - dt = 1 / 32 - predefined_timesteps = { - 5: [0, 2, 4, 8, 16, 32], - 6: [0, 2, 4, 6, 8, 16, 32], - 7: [0, 2, 4, 6, 8, 16, 24, 32], - 10: [0, 2, 4, 6, 8, 12, 16, 20, 24, 28, 32], - 12: [0, 2, 4, 6, 8, 10, 12, 14, 16, 20, 24, 28, 32], - 16: [0, 1, 2, 3, 4, 5, 6, 7, 8, 10, 12, 14, 16, 20, 24, 28, 32], - } - t = predefined_timesteps.get(n, []) - if not t: - return torch.linspace(0, 1, n + 1, device=device, dtype=dtype) - return dt * torch.tensor(t, device=device, dtype=dtype) - - - def convert_char_to_pinyin(self, text_list, polyphone=True): - if jieba.dt.initialized is False: - jieba.default_logger.setLevel(50) # CRITICAL - jieba.initialize() - - final_text_list = [] - custom_trans = str.maketrans( - {";": ",", "“": '"', "”": '"', "‘": "'", "’": "'"} - ) # add custom trans here, to address oov - - def is_chinese(c): - return ( - "\u3100" <= c <= 
"\u9fff" # common chinese characters - ) - - for text in text_list: - char_list = [] - text = text.translate(custom_trans) - for seg in jieba.cut(text): - seg_byte_len = len(bytes(seg, "UTF-8")) - if seg_byte_len == len(seg): # if pure alphabets and symbols - if char_list and seg_byte_len > 1 and char_list[-1] not in " :'\"": - char_list.append(" ") - char_list.extend(seg) - elif polyphone and seg_byte_len == 3 * len(seg): # if pure east asian characters - seg_ = lazy_pinyin(seg, style=Style.TONE3, tone_sandhi=True) - for i, c in enumerate(seg): - if is_chinese(c): - char_list.append(" ") - char_list.append(seg_[i]) - else: # if mixed characters, alphabets and symbols - for c in seg: - if ord(c) < 256: - char_list.extend(c) - elif is_chinese(c): - char_list.append(" ") - char_list.extend(lazy_pinyin(c, style=Style.TONE3, tone_sandhi=True)) - else: - char_list.append(c) - final_text_list.append(char_list) - - return final_text_list - - - - - @torch.no_grad() - def sample( - self, - cond: float["b n d"] | float["b nw"], # noqa: F722 - text: int["b nt"] | list[str], # noqa: F722 - duration: int | int["b"], # noqa: F821 - *, - lens: int["b"] | None = None, # noqa: F821 - steps=32, - cfg_strength=1.0, - sway_sampling_coef=None, - seed: int | None = None, - max_duration=4096, - vocoder: Callable[[float["b d n"]], float["b nw"]] | None = None, # noqa: F722 - use_epss=True, - no_ref_audio=False, - duplicate_test=False, - t_inter=0.1, - edit_mask=None, - ): - self.eval() - # raw wave - - if cond.ndim == 2: - cond = self.mel_spec(cond) - cond = cond.permute(0, 2, 1) - assert cond.shape[-1] == self.num_channels - - cond = cond.to(next(self.parameters()).dtype) - - batch, cond_seq_len, device = *cond.shape[:2], cond.device - if not exists(lens): - lens = torch.full((batch,), cond_seq_len, device=device, dtype=torch.long) - - # text - - if isinstance(text, list): - if exists(self.vocab_char_map): - text = self.list_str_to_idx(text, self.vocab_char_map).to(device) - else: - text = self.list_str_to_tensor(text).to(device) - assert text.shape[0] == batch - - # duration - - cond_mask = lens_to_mask(lens) - if edit_mask is not None: - cond_mask = cond_mask & edit_mask - - if isinstance(duration, int): - duration = torch.full((batch,), duration, device=device, dtype=torch.long) - - duration = torch.maximum( - torch.maximum((text != -1).sum(dim=-1), lens) + 1, duration - ) # duration at least text/audio prompt length plus one token, so something is generated - duration = duration.clamp(max=max_duration) - max_duration = duration.amax() - - # duplicate test corner for inner time step oberservation - if duplicate_test: - test_cond = F.pad(cond, (0, 0, cond_seq_len, max_duration - 2 * cond_seq_len), value=0.0) - - cond = F.pad(cond, (0, 0, 0, max_duration - cond_seq_len), value=0.0) - if no_ref_audio: - cond = torch.zeros_like(cond) - - cond_mask = F.pad(cond_mask, (0, max_duration - cond_mask.shape[-1]), value=False) - cond_mask = cond_mask.unsqueeze(-1) - step_cond = torch.where( - cond_mask, cond, torch.zeros_like(cond) - ) # allow direct control (cut cond audio) with lens passed in - - if batch > 1: - mask = lens_to_mask(duration) - else: # save memory and speed up, as single inference need no mask currently - mask = None - - # neural ode - - def fn(t, x): - # at each step, conditioning is fixed - # step_cond = torch.where(cond_mask, cond, torch.zeros_like(cond)) - - # predict flow (cond) - if cfg_strength < 1e-5: - pred = self.transformer( - x=x, - cond=step_cond, - text=text, - time=t, - mask=mask, - 
drop_audio_cond=False, - drop_text=False, - cache=True, - ) - return pred - - # predict flow (cond and uncond), for classifier-free guidance - pred_cfg = self.transformer( - x=x, - cond=step_cond, - text=text, - time=t, - mask=mask, - cfg_infer=True, - cache=True, - ) - pred, null_pred = torch.chunk(pred_cfg, 2, dim=0) - return pred + (pred - null_pred) * cfg_strength - - # noise input - # to make sure batch inference result is same with different batch size, and for sure single inference - # still some difference maybe due to convolutional layers - y0 = [] - for dur in duration: - if exists(seed): - torch.manual_seed(seed) - y0.append(torch.randn(dur, self.num_channels, device=self.device, dtype=step_cond.dtype)) - y0 = pad_sequence(y0, padding_value=0, batch_first=True) - - t_start = 0 - - # duplicate test corner for inner time step oberservation - if duplicate_test: - t_start = t_inter - y0 = (1 - t_start) * y0 + t_start * test_cond - steps = int(steps * (1 - t_start)) - - if t_start == 0 and use_epss: # use Empirically Pruned Step Sampling for low NFE - t = self.get_epss_timesteps(steps, device=self.device, dtype=step_cond.dtype) - else: - t = torch.linspace(t_start, 1, steps + 1, device=self.device, dtype=step_cond.dtype) - if sway_sampling_coef is not None: - t = t + sway_sampling_coef * (torch.cos(torch.pi / 2 * t) - 1 + t) - - trajectory = odeint(fn, y0, t, **self.odeint_kwargs) - self.transformer.clear_cache() - - sampled = trajectory[-1] - out = sampled - out = torch.where(cond_mask, cond, out) - - if exists(vocoder): - out = out.permute(0, 2, 1) - out = vocoder(out) - - return out, trajectory - - def forward( - self, - inp: float["b n d"] | float["b nw"], # mel or raw wave # noqa: F722 - text: int["b nt"] | list[str], # noqa: F722 - *, - lens: int["b"] | None = None, # noqa: F821 - noise_scheduler: str | None = None, - ): - # handle raw wave - if inp.ndim == 2: - inp = self.mel_spec(inp) - inp = inp.permute(0, 2, 1) - assert inp.shape[-1] == self.num_channels - - batch, seq_len, dtype, device, _σ1 = *inp.shape[:2], inp.dtype, self.device, self.sigma - - # handle text as string - if isinstance(text, list): - if exists(self.vocab_char_map): - text = self.list_str_to_idx(text, self.vocab_char_map).to(device) - else: - text = self.list_str_to_tensor(text).to(device) - assert text.shape[0] == batch - - # lens and mask - if not exists(lens): - lens = torch.full((batch,), seq_len, device=device) - - mask = lens_to_mask(lens, length=seq_len) # useless here, as collate_fn will pad to max length in batch - - # get a random span to mask out for training conditionally - frac_lengths = torch.zeros((batch,), device=self.device).float().uniform_(*self.frac_lengths_mask) - rand_span_mask = self.mask_from_frac_lengths(lens, frac_lengths) - - if exists(mask): - rand_span_mask &= mask - - # mel is x1 - x1 = inp - - # x0 is gaussian noise - x0 = torch.randn_like(x1) - - # time step - time = torch.rand((batch,), dtype=dtype, device=self.device) - # TODO. 
noise_scheduler - - # sample xt (φ_t(x) in the paper) - t = time.unsqueeze(-1).unsqueeze(-1) - φ = (1 - t) * x0 + t * x1 - flow = x1 - x0 - - # only predict what is within the random mask span for infilling - cond = torch.where(rand_span_mask[..., None], torch.zeros_like(x1), x1) - - # transformer and cfg training with a drop rate - drop_audio_cond = random() < self.audio_drop_prob # p_drop in voicebox paper - if random() < self.cond_drop_prob: # p_uncond in voicebox paper - drop_audio_cond = True - drop_text = True - else: - drop_text = False - - # apply mask will use more memory; might adjust batchsize or batchsampler long sequence threshold - pred = self.transformer( - x=φ, cond=cond, text=text, time=time, drop_audio_cond=drop_audio_cond, drop_text=drop_text, mask=mask - ) - - # flow matching loss - loss = F.mse_loss(pred, flow, reduction="none") - loss = loss[rand_span_mask] - - return loss.mean(), cond, pred diff --git a/src/diffusers/pipelines/f5tts/pipeline_f5tts.py b/src/diffusers/pipelines/f5tts/pipeline_f5tts.py index 7e7ef3c2146c..8708e2386af8 100644 --- a/src/diffusers/pipelines/f5tts/pipeline_f5tts.py +++ b/src/diffusers/pipelines/f5tts/pipeline_f5tts.py @@ -19,7 +19,6 @@ from torchdiffeq import odeint import torchaudio -from librosa.filters import mel as librosa_mel_fn import os import random @@ -31,7 +30,7 @@ from torch.nn.utils.rnn import pad_sequence from ..pipeline_utils import AudioPipelineOutput, DiffusionPipeline from vocos import Vocos - +from ...models.transformers.f5tts_transformer import DiT, MelSpec, ConditioningEncoder # helpers @@ -40,174 +39,258 @@ + class F5FlowPipeline(DiffusionPipeline): def __init__( self, - model: CFM, - vocoder: Callable[[float["b d n"]], float["b nw"]] | None = None, # noqa: F722 - mel_spec_type="vocos", + transformer: DiT, + conditioning_encoder: ConditioningEncoder, + odeint_kwargs: dict = dict( + # atol = 1e-5, + # rtol = 1e-5, + method="euler" # 'midpoint' + ), + num_channels=None, + mel_spec_module: nn.Module | None = None, mel_spec_kwargs: dict = dict(), - device=None, + vocab_char_map: dict[str:int] | None = None, ): super().__init__() - self.model = model - self.vocoder = vocoder - self.mel_spec_type = mel_spec_type - self.mel_spec_kwargs = mel_spec_kwargs + self.transformer = transformer + self.mel_spec = default(mel_spec_module, MelSpec(**mel_spec_kwargs)) + num_channels = default(num_channels, self.mel_spec.n_mel_channels) + self.num_channels = num_channels + # sampling related + self.odeint_kwargs = odeint_kwargs + # vocab map for tokenization + self.vocab_char_map = vocab_char_map self.device = device or torch.device("cpu") - def __call__(self, *args, **kwargs): - return infer_process(*args, **kwargs, model_obj=self.model, vocoder=self.vocoder, device=self.device) + # simple utf-8 tokenizer, since paper went character based + def list_str_to_tensor(self, text: list[str], padding_value=-1) -> int["b nt"]: # noqa: F722 + list_tensors = [torch.tensor([*bytes(t, "UTF-8")]) for t in text] # ByT5 style + text = pad_sequence(list_tensors, padding_value=padding_value, batch_first=True) + return text - def infer_batch_process( + + # char tokenizer, based on custom dataset's extracted .txt file + def list_str_to_idx( + self, + text: list[str] | list[list[str]], + vocab_char_map: dict[str, int], # {char: idx} + padding_value=-1, + ) -> int["b nt"]: # noqa: F722 + list_idx_tensors = [torch.tensor([vocab_char_map.get(c, 0) for c in t]) for t in text] # pinyin or char style + text = pad_sequence(list_idx_tensors, 
padding_value=padding_value, batch_first=True) + return text + + + + + def lens_to_mask(self, t: int["b"], length: int | None = None) -> bool["b n"]: # noqa: F722 F821 + if not exists(length): + length = t.amax() + + seq = torch.arange(length, device=t.device) + return seq[None, :] < t[:, None] + + + def get_epss_timesteps(self, n, device, dtype): + dt = 1 / 32 + predefined_timesteps = { + 5: [0, 2, 4, 8, 16, 32], + 6: [0, 2, 4, 6, 8, 16, 32], + 7: [0, 2, 4, 6, 8, 16, 24, 32], + 10: [0, 2, 4, 6, 8, 12, 16, 20, 24, 28, 32], + 12: [0, 2, 4, 6, 8, 10, 12, 14, 16, 20, 24, 28, 32], + 16: [0, 1, 2, 3, 4, 5, 6, 7, 8, 10, 12, 14, 16, 20, 24, 28, 32], + } + t = predefined_timesteps.get(n, []) + if not t: + return torch.linspace(0, 1, n + 1, device=device, dtype=dtype) + return dt * torch.tensor(t, device=device, dtype=dtype) + + + def convert_char_to_pinyin(self, text_list, polyphone=True): + if jieba.dt.initialized is False: + jieba.default_logger.setLevel(50) # CRITICAL + jieba.initialize() + + final_text_list = [] + custom_trans = str.maketrans( + {";": ",", "“": '"', "”": '"', "‘": "'", "’": "'"} + ) # add custom trans here, to address oov + + def is_chinese(c): + return ( + "\u3100" <= c <= "\u9fff" # common chinese characters + ) + + for text in text_list: + char_list = [] + text = text.translate(custom_trans) + for seg in jieba.cut(text): + seg_byte_len = len(bytes(seg, "UTF-8")) + if seg_byte_len == len(seg): # if pure alphabets and symbols + if char_list and seg_byte_len > 1 and char_list[-1] not in " :'\"": + char_list.append(" ") + char_list.extend(seg) + elif polyphone and seg_byte_len == 3 * len(seg): # if pure east asian characters + seg_ = lazy_pinyin(seg, style=Style.TONE3, tone_sandhi=True) + for i, c in enumerate(seg): + if is_chinese(c): + char_list.append(" ") + char_list.append(seg_[i]) + else: # if mixed characters, alphabets and symbols + for c in seg: + if ord(c) < 256: + char_list.extend(c) + elif is_chinese(c): + char_list.append(" ") + char_list.extend(lazy_pinyin(c, style=Style.TONE3, tone_sandhi=True)) + else: + char_list.append(c) + final_text_list.append(char_list) + + return final_text_list + + + def __call__( + self, ref_audio, ref_text, - gen_text_batches, - model_obj, - vocoder, - mel_spec_type="vocos", - progress=tqdm, - target_rms=0.1, - cross_fade_duration=0.15, + gen_text, nfe_step=32, cfg_strength=2.0, sway_sampling_coef=-1, speed=1, fix_duration=None, - device=None, - streaming=False, - chunk_size=2048, ): - audio, sr = ref_audio - if audio.shape[0] > 1: - audio = torch.mean(audio, dim=0, keepdim=True) - - rms = torch.sqrt(torch.mean(torch.square(audio))) - if rms < target_rms: - audio = audio * target_rms / rms - if sr != target_sample_rate: - resampler = torchaudio.transforms.Resample(sr, target_sample_rate) - audio = resampler(audio) - audio = audio.to(device) - - generated_waves = [] - spectrograms = [] - - if len(ref_text[-1].encode("utf-8")) == 1: - ref_text = ref_text + " " - - def process_batch(gen_text): - local_speed = speed - if len(gen_text.encode("utf-8")) < 10: - local_speed = 0.3 - - # Prepare the text - text_list = [ref_text + gen_text] - final_text_list = convert_char_to_pinyin(text_list) - - ref_audio_len = audio.shape[-1] // hop_length - if fix_duration is not None: - duration = int(fix_duration * target_sample_rate / hop_length) - else: - # Calculate duration - ref_text_len = len(ref_text.encode("utf-8")) - gen_text_len = len(gen_text.encode("utf-8")) - duration = ref_audio_len + int(ref_audio_len / ref_text_len * gen_text_len / 
local_speed) - - # inference - with torch.inference_mode(): - generated, _ = model_obj.sample( - cond=audio, - text=final_text_list, - duration=duration, - steps=nfe_step, - cfg_strength=cfg_strength, - sway_sampling_coef=sway_sampling_coef, - ) - del _ - - generated = generated.to(torch.float32) # generated mel spectrogram - generated = generated[:, ref_audio_len:, :] - generated = generated.permute(0, 2, 1) - if mel_spec_type == "vocos": - generated_wave = vocoder.decode(generated) - elif mel_spec_type == "bigvgan": - generated_wave = vocoder(generated) - if rms < target_rms: - generated_wave = generated_wave * rms / target_rms - - # wav -> numpy - generated_wave = generated_wave.squeeze().cpu().numpy() - - if streaming: - for j in range(0, len(generated_wave), chunk_size): - yield generated_wave[j : j + chunk_size], target_sample_rate - else: - generated_cpu = generated[0].cpu().numpy() - del generated - yield generated_wave, generated_cpu - - if streaming: - for gen_text in progress.tqdm(gen_text_batches) if progress is not None else gen_text_batches: - for chunk in process_batch(gen_text): - yield chunk + + + # Prepare the text + text_list = [ref_text + gen_text] + final_text_list = convert_char_to_pinyin(text_list) + ref_audio_len = audio.shape[-1] // hop_length + if fix_duration is not None: + duration = int(fix_duration * target_sample_rate / hop_length) else: - with ThreadPoolExecutor() as executor: - futures = [executor.submit(process_batch, gen_text) for gen_text in gen_text_batches] - for future in progress.tqdm(futures) if progress is not None else futures: - result = future.result() - if result: - generated_wave, generated_mel_spec = next(result) - generated_waves.append(generated_wave) - spectrograms.append(generated_mel_spec) - - if generated_waves: - if cross_fade_duration <= 0: - # Simply concatenate - final_wave = np.concatenate(generated_waves) - else: - # Combine all generated waves with cross-fading - final_wave = generated_waves[0] - for i in range(1, len(generated_waves)): - prev_wave = final_wave - next_wave = generated_waves[i] - - # Calculate cross-fade samples, ensuring it does not exceed wave lengths - cross_fade_samples = int(cross_fade_duration * target_sample_rate) - cross_fade_samples = min(cross_fade_samples, len(prev_wave), len(next_wave)) - - if cross_fade_samples <= 0: - # No overlap possible, concatenate - final_wave = np.concatenate([prev_wave, next_wave]) - continue - - # Overlapping parts - prev_overlap = prev_wave[-cross_fade_samples:] - next_overlap = next_wave[:cross_fade_samples] - - # Fade out and fade in - fade_out = np.linspace(1, 0, cross_fade_samples) - fade_in = np.linspace(0, 1, cross_fade_samples) - - # Cross-faded overlap - cross_faded_overlap = prev_overlap * fade_out + next_overlap * fade_in - - # Combine - new_wave = np.concatenate( - [prev_wave[:-cross_fade_samples], cross_faded_overlap, next_wave[cross_fade_samples:]] - ) - - final_wave = new_wave - - # Create a combined spectrogram - combined_spectrogram = np.concatenate(spectrograms, axis=1) - - yield final_wave, target_sample_rate, combined_spectrogram + # Calculate duration + ref_text_len = len(ref_text.encode("utf-8")) + gen_text_len = len(gen_text.encode("utf-8")) + duration = ref_audio_len + int(ref_audio_len / ref_text_len * gen_text_len / local_speed) - else: - yield None, target_sample_rate, None + cond = ref_audio + + if cond.ndim == 2: + cond = self.mel_spec(cond) + cond = cond.permute(0, 2, 1) + assert cond.shape[-1] == self.num_channels + + batch, cond_seq_len, device 
= *cond.shape[:2], cond.device + lens = torch.full((batch,), cond_seq_len, device=device, dtype=torch.long) + + if isinstance(final_text_list, list): + if exists(self.vocab_char_map): + text = self.list_str_to_idx(final_text_list, self.vocab_char_map).to(device) + else: + text = self.list_str_to_tensor(final_text_list).to(device) + assert text.shape[0] == batch + + # duration + cond_mask = lens_to_mask(lens) + if isinstance(duration, int): + duration = torch.full((batch,), duration, device=device, dtype=torch.long) + + duration = torch.maximum( + torch.maximum((text != -1).sum(dim=-1), lens) + 1, duration + ) # duration at least text/audio prompt length plus one token, so something is generated + duration = duration.clamp(max=max_duration) + max_duration = duration.amax() + + cond = F.pad(cond, (0, 0, 0, max_duration - cond_seq_len), value=0.0) + cond_mask = F.pad(cond_mask, (0, max_duration - cond_mask.shape[-1]), value=False) + cond_mask = cond_mask.unsqueeze(-1) + step_cond = torch.where( + cond_mask, cond, torch.zeros_like(cond) + ) # allow direct control (cut cond audio) with lens passed in + + + + + + step_cond = self.conditioning_encoder(x, step_cond, text) + + if batch > 1: + mask = lens_to_mask(duration) + else: # save memory and speed up, as single inference need no mask currently + mask = None + + # neural ode + def fn(t, x): + # at each step, conditioning is fixed + # step_cond = torch.where(cond_mask, cond, torch.zeros_like(cond)) + step_cond = self.conditioning_encoder(x, step_cond, text, drop_audio_cond=False, drop_text=False) + # predict flow (cond) + if cfg_strength < 1e-5: + pred = self.transformer( + x=x, + cond=step_cond, + time=t, + mask=mask, + cache=True, + ) + return pred + + # predict flow (cond and uncond), for classifier-free guidance + + step_uncond = self.conditioning_encoder(x, step_uncond, text, drop_audio_cond=False, drop_text=False) + step_cond = torch.cat((step_cond, step_uncond), dim=0) + x = torch.cat((x, x), dim=0) + t = torch.cat((t, t), dim=0) + pred_cfg = self.transformer( + x=x, + cond=step_cond, + time=t, + mask=mask, + cache=True, + ) + pred, null_pred = torch.chunk(pred_cfg, 2, dim=0) + return pred + (pred - null_pred) * cfg_strength + + # noise input + # to make sure batch inference result is same with different batch size, and for sure single inference + # still some difference maybe due to convolutional layers + y0 = [] + for dur in duration: + if exists(seed): + torch.manual_seed(seed) + y0.append(torch.randn(dur, self.num_channels, device=self.device, dtype=step_cond.dtype)) + y0 = pad_sequence(y0, padding_value=0, batch_first=True) + t_start = 0 + + # TODO Add Empirically Pruned Step Sampling for low NFE + t = torch.linspace(t_start, 1, steps + 1, device=self.device, dtype=step_cond.dtype) + if sway_sampling_coef is not None: + t = t + sway_sampling_coef * (torch.cos(torch.pi / 2 * t) - 1 + t) + + trajectory = odeint(fn, y0, t, **self.odeint_kwargs) + self.transformer.clear_cache() + + sampled = trajectory[-1] + out = sampled + out = torch.where(cond_mask, cond, out) + + out = out.to(torch.float32) # generated mel spectrogram + out = out[:, ref_audio_len:, :] + out = out.permute(0, 2, 1) + generated_cpu = out[0].cpu().numpy() + + + # This need to be in HF Output format + return generated_cpu From f127e453dcdfd513858bc9d66c94e4c01a08c8c6 Mon Sep 17 00:00:00 2001 From: ayushtues Date: Sat, 2 Aug 2025 18:44:03 +0530 Subject: [PATCH 7/8] F5 pipeline definition working --- .../pipelines/f5tts/pipeline_f5tts.py | 76 ++++++++++++++----- 1 file 
changed, 58 insertions(+), 18 deletions(-) diff --git a/src/diffusers/pipelines/f5tts/pipeline_f5tts.py b/src/diffusers/pipelines/f5tts/pipeline_f5tts.py index 8708e2386af8..d351ee375f9f 100644 --- a/src/diffusers/pipelines/f5tts/pipeline_f5tts.py +++ b/src/diffusers/pipelines/f5tts/pipeline_f5tts.py @@ -8,7 +8,6 @@ """ from __future__ import annotations - from random import random from typing import Callable @@ -17,7 +16,6 @@ from torch import nn from torch.nn.utils.rnn import pad_sequence from torchdiffeq import odeint - import torchaudio import os @@ -25,12 +23,11 @@ from collections import defaultdict from importlib.resources import files - import torch from torch.nn.utils.rnn import pad_sequence -from ..pipeline_utils import AudioPipelineOutput, DiffusionPipeline +from diffusers.pipelines.pipeline_utils import AudioPipelineOutput, DiffusionPipeline from vocos import Vocos -from ...models.transformers.f5tts_transformer import DiT, MelSpec, ConditioningEncoder +from diffusers.models.transformers.f5tts_transformer import DiT, MelSpec, ConditioningEncoder # helpers @@ -46,25 +43,20 @@ def __init__( transformer: DiT, conditioning_encoder: ConditioningEncoder, odeint_kwargs: dict = dict( - # atol = 1e-5, - # rtol = 1e-5, - method="euler" # 'midpoint' + method="euler" ), - num_channels=None, - mel_spec_module: nn.Module | None = None, mel_spec_kwargs: dict = dict(), vocab_char_map: dict[str:int] | None = None, ): super().__init__() self.transformer = transformer - self.mel_spec = default(mel_spec_module, MelSpec(**mel_spec_kwargs)) - num_channels = default(num_channels, self.mel_spec.n_mel_channels) + self.mel_spec = MelSpec(**mel_spec_kwargs) + num_channels = self.mel_spec.n_mel_channels self.num_channels = num_channels # sampling related self.odeint_kwargs = odeint_kwargs # vocab map for tokenization self.vocab_char_map = vocab_char_map - self.device = device or torch.device("cpu") # simple utf-8 tokenizer, since paper went character based @@ -217,9 +209,6 @@ def __call__( cond_mask, cond, torch.zeros_like(cond) ) # allow direct control (cut cond audio) with lens passed in - - - step_cond = self.conditioning_encoder(x, step_cond, text) @@ -267,12 +256,12 @@ def fn(t, x): for dur in duration: if exists(seed): torch.manual_seed(seed) - y0.append(torch.randn(dur, self.num_channels, device=self.device, dtype=step_cond.dtype)) + y0.append(torch.randn(dur, self.num_channels, device=self.transformer.device, dtype=step_cond.dtype)) y0 = pad_sequence(y0, padding_value=0, batch_first=True) t_start = 0 # TODO Add Empirically Pruned Step Sampling for low NFE - t = torch.linspace(t_start, 1, steps + 1, device=self.device, dtype=step_cond.dtype) + t = torch.linspace(t_start, 1, steps + 1, device=self.transformer.device, dtype=step_cond.dtype) if sway_sampling_coef is not None: t = t + sway_sampling_coef * (torch.cos(torch.pi / 2 * t) - 1 + t) @@ -294,3 +283,54 @@ def fn(t, x): +if __name__ == "__main__": + print('entering main funcitn') + + dit_config = { + "dim": 1024, + "depth": 22, + "heads": 16, + "ff_mult": 2, + "text_dim": 512, + "text_num_embeds": 256, + "text_mask_padding": True, + "qk_norm": None, # null | rms_norm + "conv_layers": 4, + "pe_attn_head": None, + "attn_backend": "torch", # torch | flash_attn + "attn_mask_enabled": False, + "checkpoint_activations": False, # recompute activations and save memory for extra compute + } + + + mel_spec_config = { + "target_sample_rate": 24000, + "n_mel_channels": 100, + "hop_length": 256, + "win_length": 1024, + "n_fft": 1024, + } + + + dit = 
DiT(**dit_config) + print("DiT model initialized with config:", dit_config) + + conditioning_encoder_config = { + 'dim': 1024, + 'text_num_embeds': 256, + 'text_dim': 512, + 'text_mask_padding': True, + 'conv_layers': 4, + 'mel_dim': mel_spec_config['n_mel_channels'], + } + conditioning_encoder = ConditioningEncoder(**conditioning_encoder_config) + print("Conditioning Encoder initialized with config:", conditioning_encoder_config) + + f5_pipeline = F5FlowPipeline( + transformer=dit, + conditioning_encoder=conditioning_encoder, + odeint_kwargs={"method": "euler"}, + mel_spec_kwargs=mel_spec_config, + ) + print("F5FlowPipeline initialized with DiT and Conditioning Encoder.") + From 7fe21e677039a2d38ba18e333ca85909394ec447 Mon Sep 17 00:00:00 2001 From: ayushtues Date: Sun, 3 Aug 2025 12:05:34 +0530 Subject: [PATCH 8/8] Make forwrad pass work --- src/diffusers/models/attention_processor.py | 3 +- .../models/transformers/f5tts_transformer.py | 27 +++--- .../pipelines/f5tts/pipeline_f5tts.py | 83 ++++++++++++------- 3 files changed, 67 insertions(+), 46 deletions(-) diff --git a/src/diffusers/models/attention_processor.py b/src/diffusers/models/attention_processor.py index f4746918cf17..823e2c052fc8 100755 --- a/src/diffusers/models/attention_processor.py +++ b/src/diffusers/models/attention_processor.py @@ -3150,7 +3150,6 @@ def __call__( if input_ndim == 4: batch_size, channel, height, width = hidden_states.shape hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2) - batch_size, sequence_length, _ = ( hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape ) @@ -3208,7 +3207,7 @@ def __call__( query_unrotated = query[..., self.pe_attn_head:, :] query = torch.cat((query_rotated, query_unrotated), dim=-2) - if not attn.is_cross_attention: + if not attn.is_cross_attention and self.pe_attn_head is not None: key_to_rotate, key_unrotated = key[..., :self.pe_attn_head, :], key[..., self.pe_attn_head:, :] key_rotated = apply_rotary_emb(key_to_rotate, rotary_emb, use_real=True, use_real_unbind_dim=-2) key = torch.cat((key_rotated, key_unrotated), dim=-2) diff --git a/src/diffusers/models/transformers/f5tts_transformer.py b/src/diffusers/models/transformers/f5tts_transformer.py index 1824571dcbe6..4c23d07612f0 100644 --- a/src/diffusers/models/transformers/f5tts_transformer.py +++ b/src/diffusers/models/transformers/f5tts_transformer.py @@ -12,6 +12,7 @@ import torch import torch.nn.functional as F from torch import nn +import torchaudio from ..normalization import GlobalResponseNorm, AdaLayerNorm, RMSNorm import math from ..embeddings import get_1d_rotary_pos_embed, apply_rotary_emb @@ -20,8 +21,8 @@ from ..attention import Attention from einops import rearrange, repeat, reduce, pack, unpack from torch import nn, einsum, tensor, Tensor, cat, stack, arange, is_tensor -import jieba -from pypinyin import Style, lazy_pinyin +# import jieba +# from pypinyin import Style, lazy_pinyin @@ -115,7 +116,7 @@ def forward(self, x, t, mask=None, rope=None): # x: noised input, t: time embed norm, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.attn_norm(x, emb=t) # attention - attn_output = self.attn(x=norm, mask=mask, rope=rope) + attn_output = self.attn(hidden_states=norm, attention_mask=mask, rotary_emb=rope) # process attention output for input x x = x + gate_msa.unsqueeze(1) * attn_output @@ -287,6 +288,7 @@ def forward(self, text: int["b nt"], seq_len, drop_text=False): # noqa: F722 text = text.masked_fill(text_mask.unsqueeze(-1).expand(-1, -1, 
text.size(-1)), 0.0) for block in self.text_blocks: text = block(text) + text = text.squeeze(0) # TODO for some reason an extra dimension is added text = text.masked_fill(text_mask.unsqueeze(-1).expand(-1, -1, text.size(-1)), 0.0) else: text = self.text_blocks(text) @@ -318,7 +320,7 @@ def forward(self, x: float["b n d"], cond: float["b n d"], text_embed: float["b class ConditioningEncoder(nn.Module): def __init__( self, - dim + dim, text_num_embeds, text_dim, text_mask_padding, @@ -343,13 +345,9 @@ def forward( seq_len = x.shape[1] if cache: if drop_text: - if self.text_uncond is None: - self.text_uncond = self.text_embed(text, seq_len, drop_text=True) - text_embed = self.text_uncond + text_embed = self.text_embed(text, seq_len, drop_text=True) else: - if self.text_cond is None: - self.text_cond = self.text_embed(text, seq_len, drop_text=False) - text_embed = self.text_cond + text_embed = self.text_embed(text, seq_len, drop_text=False) else: text_embed = self.text_embed(text, seq_len, drop_text=drop_text) @@ -406,7 +404,7 @@ def __init__( ) self.long_skip_connection = nn.Linear(dim * 2, dim, bias=False) if long_skip_connection else None - self.norm_out = AdaLayerNorm(dim) # final modulation + self.norm_out = AdaLayerNorm(dim, chunk_dim=1) # final modulation self.proj_out = nn.Linear(dim, mel_dim) @@ -420,11 +418,12 @@ def forward( batch, seq_len = x.shape[0], x.shape[1] if time.ndim == 0: time = time.repeat(batch) - rope = get_1d_rotary_pos_embed(seq_len, self.dim, device=x.device) + temb = self.time_embed(time) # b d + rope = get_1d_rotary_pos_embed(pos=seq_len, dim=self.dim, use_real=True) for block in self.transformer_blocks: - x = block(x, t, mask=mask, rope=rope) + x = block(x, temb, mask=mask, rope=rope) - x = self.norm_out(x, t) + x = self.norm_out(x, temb=temb) output = self.proj_out(x) return output diff --git a/src/diffusers/pipelines/f5tts/pipeline_f5tts.py b/src/diffusers/pipelines/f5tts/pipeline_f5tts.py index d351ee375f9f..7cf7f89c63a1 100644 --- a/src/diffusers/pipelines/f5tts/pipeline_f5tts.py +++ b/src/diffusers/pipelines/f5tts/pipeline_f5tts.py @@ -29,7 +29,8 @@ from vocos import Vocos from diffusers.models.transformers.f5tts_transformer import DiT, MelSpec, ConditioningEncoder # helpers - +import jieba +from pypinyin import lazy_pinyin, Style @@ -50,6 +51,7 @@ def __init__( ): super().__init__() self.transformer = transformer + self.conditioning_encoder = conditioning_encoder self.mel_spec = MelSpec(**mel_spec_kwargs) num_channels = self.mel_spec.n_mel_channels self.num_channels = num_channels @@ -81,7 +83,7 @@ def list_str_to_idx( def lens_to_mask(self, t: int["b"], length: int | None = None) -> bool["b n"]: # noqa: F722 F821 - if not exists(length): + if length is None: length = t.amax() seq = torch.arange(length, device=t.device) @@ -151,22 +153,25 @@ def is_chinese(c): def __call__( self, ref_audio, - ref_text, - gen_text, + text_list, nfe_step=32, cfg_strength=2.0, sway_sampling_coef=-1, speed=1, + max_duration=4096, fix_duration=None, + seed=None, + device="cuda", + steps=32, ): # Prepare the text - text_list = [ref_text + gen_text] - final_text_list = convert_char_to_pinyin(text_list) - ref_audio_len = audio.shape[-1] // hop_length + + final_text_list = self.convert_char_to_pinyin(text_list) + ref_audio_len = ref_audio.shape[-1] // self.mel_spec.hop_length if fix_duration is not None: - duration = int(fix_duration * target_sample_rate / hop_length) + duration = int(fix_duration * self.mel_spec.target_sample_rate / self.mel_spec.hop_length) else: # Calculate 
duration ref_text_len = len(ref_text.encode("utf-8")) @@ -185,14 +190,14 @@ def __call__( lens = torch.full((batch,), cond_seq_len, device=device, dtype=torch.long) if isinstance(final_text_list, list): - if exists(self.vocab_char_map): + if self.vocab_char_map is not None: text = self.list_str_to_idx(final_text_list, self.vocab_char_map).to(device) else: text = self.list_str_to_tensor(final_text_list).to(device) assert text.shape[0] == batch # duration - cond_mask = lens_to_mask(lens) + cond_mask = self.lens_to_mask(lens) if isinstance(duration, int): duration = torch.full((batch,), duration, device=device, dtype=torch.long) @@ -205,23 +210,21 @@ def __call__( cond = F.pad(cond, (0, 0, 0, max_duration - cond_seq_len), value=0.0) cond_mask = F.pad(cond_mask, (0, max_duration - cond_mask.shape[-1]), value=False) cond_mask = cond_mask.unsqueeze(-1) - step_cond = torch.where( + step_cond_input = torch.where( cond_mask, cond, torch.zeros_like(cond) ) # allow direct control (cut cond audio) with lens passed in - - - step_cond = self.conditioning_encoder(x, step_cond, text) - if batch > 1: - mask = lens_to_mask(duration) + mask = self.lens_to_mask(duration) else: # save memory and speed up, as single inference need no mask currently mask = None - + + if cfg_strength >= 1e-5 and mask is not None: + mask = torch.cat((mask, mask), dim=0) # for classifier-free guidance, we need to double the batch size # neural ode def fn(t, x): # at each step, conditioning is fixed # step_cond = torch.where(cond_mask, cond, torch.zeros_like(cond)) - step_cond = self.conditioning_encoder(x, step_cond, text, drop_audio_cond=False, drop_text=False) + step_cond = self.conditioning_encoder(x, step_cond_input, text, drop_audio_cond=False, drop_text=False) # predict flow (cond) if cfg_strength < 1e-5: pred = self.transformer( @@ -234,14 +237,10 @@ def fn(t, x): return pred # predict flow (cond and uncond), for classifier-free guidance - - step_uncond = self.conditioning_encoder(x, step_uncond, text, drop_audio_cond=False, drop_text=False) + step_uncond = self.conditioning_encoder(x, step_cond_input, text, drop_audio_cond=False, drop_text=False) step_cond = torch.cat((step_cond, step_uncond), dim=0) - x = torch.cat((x, x), dim=0) - t = torch.cat((t, t), dim=0) pred_cfg = self.transformer( - x=x, - cond=step_cond, + x=step_cond, time=t, mask=mask, cache=True, @@ -254,19 +253,19 @@ def fn(t, x): # still some difference maybe due to convolutional layers y0 = [] for dur in duration: - if exists(seed): + if seed is not None: torch.manual_seed(seed) - y0.append(torch.randn(dur, self.num_channels, device=self.transformer.device, dtype=step_cond.dtype)) + y0.append(torch.randn(dur, self.num_channels, device=device, dtype=step_cond_input.dtype)) y0 = pad_sequence(y0, padding_value=0, batch_first=True) t_start = 0 # TODO Add Empirically Pruned Step Sampling for low NFE - t = torch.linspace(t_start, 1, steps + 1, device=self.transformer.device, dtype=step_cond.dtype) + t = torch.linspace(t_start, 1, steps + 1, device=device, dtype=step_cond_input.dtype) if sway_sampling_coef is not None: t = t + sway_sampling_coef * (torch.cos(torch.pi / 2 * t) - 1 + t) trajectory = odeint(fn, y0, t, **self.odeint_kwargs) - self.transformer.clear_cache() + # self.transformer.clear_cache() sampled = trajectory[-1] out = sampled @@ -275,7 +274,7 @@ def fn(t, x): out = out.to(torch.float32) # generated mel spectrogram out = out[:, ref_audio_len:, :] out = out.permute(0, 2, 1) - generated_cpu = out[0].cpu().numpy() + generated_cpu = out[0] # This need to 
be in HF Output format @@ -312,12 +311,19 @@ def fn(t, x): } + with open('vocab.txt', "r", encoding="utf-8") as f: + vocab_char_map = {} + for i, char in enumerate(f): + vocab_char_map[char[:-1]] = i + vocab_size = len(vocab_char_map) + + dit = DiT(**dit_config) print("DiT model initialized with config:", dit_config) conditioning_encoder_config = { 'dim': 1024, - 'text_num_embeds': 256, + 'text_num_embeds': vocab_size, 'text_dim': 512, 'text_mask_padding': True, 'conv_layers': 4, @@ -331,6 +337,23 @@ def fn(t, x): conditioning_encoder=conditioning_encoder, odeint_kwargs={"method": "euler"}, mel_spec_kwargs=mel_spec_config, + vocab_char_map=vocab_char_map, ) print("F5FlowPipeline initialized with DiT and Conditioning Encoder.") + import torch + ref_audio = torch.randn(2, 16000) # Dummy reference audio + duration = 250 + + ref_text = "This is a test sentence." # Dummy reference text + gen_text = "This is a generated sentence." # Dummy generated text + + text = [ref_text+gen_text] # Combine reference and generated text + text_list = text * 2 + + x = f5_pipeline(ref_audio=ref_audio, + text_list=text_list, + fix_duration=4, + max_duration=4096, device='cpu', steps=2) + print("Generated output shape:", x.shape) + \ No newline at end of file