diff --git a/.gitignore b/.gitignore
index 18154f54..79c697e2 100644
--- a/.gitignore
+++ b/.gitignore
@@ -45,4 +45,4 @@ checkpoints/
.gradio
# Ignore generated sample .wav files
-**/*.wav
+**/*.wav
\ No newline at end of file
diff --git a/Chatterbox-Multilingual.png b/Chatterbox-Multilingual.png
new file mode 100644
index 00000000..01c0efaf
Binary files /dev/null and b/Chatterbox-Multilingual.png differ
diff --git a/README.md b/README.md
index 162ac079..d6651ef7 100644
--- a/README.md
+++ b/README.md
@@ -1,5 +1,5 @@
-
+
# Chatterbox TTS
@@ -10,14 +10,15 @@
_Made with ♥️ by
-We're excited to introduce Chatterbox, [Resemble AI's](https://resemble.ai) first production-grade open source TTS model. Licensed under MIT, Chatterbox has been benchmarked against leading closed-source systems like ElevenLabs, and is consistently preferred in side-by-side evaluations.
+We're excited to introduce **Chatterbox Multilingual**, [Resemble AI's](https://resemble.ai) first production-grade open source TTS model supporting **23 languages** out of the box. Licensed under MIT, Chatterbox has been benchmarked against leading closed-source systems like ElevenLabs, and is consistently preferred in side-by-side evaluations.
-Whether you're working on memes, videos, games, or AI agents, Chatterbox brings your content to life. It's also the first open source TTS model to support **emotion exaggeration control**, a powerful feature that makes your voices stand out. Try it now on our [Hugging Face Gradio app.](https://huggingface.co/spaces/ResembleAI/Chatterbox)
+Whether you're working on memes, videos, games, or AI agents, Chatterbox brings your content to life across languages. It's also the first open source TTS model to support **emotion exaggeration control**, along with robust **multilingual zero-shot voice cloning**. Try the English-only version on our [English Hugging Face Gradio app](https://huggingface.co/spaces/ResembleAI/Chatterbox), or the multilingual version on our [Multilingual Hugging Face Gradio app](https://huggingface.co/spaces/ResembleAI/Chatterbox-Multilingual-TTS).
If you like the model but need to scale or tune it for higher accuracy, check out our competitively priced TTS service (link). It delivers reliable performance with ultra-low latency of sub 200ms—ideal for production use in agents, applications, or interactive media.
# Key Details
-- SoTA zeroshot TTS
+- Multilingual, zero-shot TTS supporting 23 languages
+- SoTA zero-shot English TTS
- 0.5B Llama backbone
- Unique exaggeration/intensity control
- Ultra-stable with alignment-informed inference
@@ -26,9 +27,12 @@ If you like the model but need to scale or tune it for higher accuracy, check ou
- Easy voice conversion script
- [Outperforms ElevenLabs](https://podonos.com/resembleai/chatterbox)
+# Supported Languages
+Arabic (ar) • Danish (da) • German (de) • Greek (el) • English (en) • Spanish (es) • Finnish (fi) • French (fr) • Hebrew (he) • Hindi (hi) • Italian (it) • Japanese (ja) • Korean (ko) • Malay (ms) • Dutch (nl) • Norwegian (no) • Polish (pl) • Portuguese (pt) • Russian (ru) • Swedish (sv) • Swahili (sw) • Turkish (tr) • Chinese (zh)
# Tips
- **General Use (TTS and Voice Agents):**
- - The default settings (`exaggeration=0.5`, `cfg_weight=0.5`) work well for most prompts.
+ - Ensure that the reference clip matches the specified language tag. Otherwise, language transfer outputs may inherit the accent of the reference clip’s language. To mitigate this, set `cfg_weight` to `0`.
+ - The default settings (`exaggeration=0.5`, `cfg_weight=0.5`) work well for most prompts across all languages.
- If the reference speaker has a fast speaking style, lowering `cfg_weight` to around `0.3` can improve pacing.
- **Expressive or Dramatic Speech:**
@@ -50,19 +54,31 @@ git clone https://github.com/resemble-ai/chatterbox.git
cd chatterbox
pip install -e .
```
-We developed and tested Chatterbox on Python 3.11 on Debain 11 OS; the versions of the dependencies are pinned in `pyproject.toml` to ensure consistency. You can modify the code or dependencies in this installation mode.
-
+We developed and tested Chatterbox on Python 3.11 on Debian 11 OS; the versions of the dependencies are pinned in `pyproject.toml` to ensure consistency. You can modify the code or dependencies in this installation mode.
# Usage
```python
import torchaudio as ta
from chatterbox.tts import ChatterboxTTS
+from chatterbox.mtl_tts import ChatterboxMultilingualTTS
+# English example
model = ChatterboxTTS.from_pretrained(device="cuda")
text = "Ezreal and Jinx teamed up with Ahri, Yasuo, and Teemo to take down the enemy's Nexus in an epic late-game pentakill."
wav = model.generate(text)
-ta.save("test-1.wav", wav, model.sr)
+ta.save("test-english.wav", wav, model.sr)
+
+# Multilingual examples
+multilingual_model = ChatterboxMultilingualTTS.from_pretrained(device="cuda")
+
+french_text = "Bonjour, comment ça va? Ceci est le modèle de synthèse vocale multilingue Chatterbox, il prend en charge 23 langues."
+wav_french = multilingual_model.generate(french_text, language_id="fr")
+ta.save("test-french.wav", wav_french, multilingual_model.sr)
+
+chinese_text = "你好,今天天气真不错,希望你有一个愉快的周末。"
+wav_chinese = multilingual_model.generate(chinese_text, language_id="zh")
+ta.save("test-chinese.wav", wav_chinese, model.sr)
# If you want to synthesize with a different voice, specify the audio prompt
AUDIO_PROMPT_PATH = "YOUR_FILE.wav"
@@ -71,9 +87,6 @@ ta.save("test-2.wav", wav, model.sr)
```
See `example_tts.py` and `example_vc.py` for more examples.
-# Supported Lanugage
-Currenlty only English.
-
# Acknowledgements
- [Cosyvoice](https://github.com/FunAudioLLM/CosyVoice)
- [Real-Time-Voice-Cloning](https://github.com/CorentinJ/Real-Time-Voice-Cloning)
@@ -113,5 +126,16 @@ print(f"Extracted watermark: {watermark}")
👋 Join us on [Discord](https://discord.gg/rJq9cRJBJ6) and let's build something awesome together!
+# Citation
+If you find this model useful, please consider citing it.
+```
+@misc{chatterboxtts2025,
+ author = {{Resemble AI}},
+ title = {{Chatterbox-TTS}},
+ year = {2025},
+ howpublished = {\url{https://github.com/resemble-ai/chatterbox}},
+ note = {GitHub repository}
+}
+```
# Disclaimer
Don't use this model to do bad things. Prompts are sourced from freely available data on the internet.
diff --git a/chatterbox.pyproj b/chatterbox.pyproj
new file mode 100644
index 00000000..a450db6d
--- /dev/null
+++ b/chatterbox.pyproj
@@ -0,0 +1,180 @@
+
+
+ Debug
+ 2.0
+ 958ac2fa-ebae-40f7-bf36-e239ba9b0efd
+ .
+ run_tts_test.py
+
+
+ .
+ .
+ chatterbox
+ chatterbox
+ true
+ MSBuild|venv|C:\_myDrive\repos\auto-vlog\AutoVlogProj\AutoVlogProj.pyproj
+
+
+ true
+ false
+
+
+ true
+ false
+
+
\ No newline at end of file
diff --git a/example_tts.py b/example_tts.py
index 0efcf820..7083f841 100644
--- a/example_tts.py
+++ b/example_tts.py
@@ -1,6 +1,9 @@
+# example_tts.py
+
import torchaudio as ta
import torch
from chatterbox.tts import ChatterboxTTS
+from chatterbox.mtl_tts import ChatterboxMultilingualTTS
# Automatically detect the best available device
if torch.cuda.is_available():
@@ -14,11 +17,17 @@
model = ChatterboxTTS.from_pretrained(device=device)
-text = "Ezreal and Jinx teamed up with Ahri, Yasuo, and Teemo to take down the enemy's Nexus in an epic late-game pentakill."
+text = "Ezreal and Jinx teamed up with Ahri, Yasuo, and Teemo to take down the enemy's Nexus."
wav = model.generate(text)
-ta.save("test-1.wav", wav, model.sr)
+ta.save("test-default-voice.wav", wav, model.sr)
+
+multilingual_model = ChatterboxMultilingualTTS.from_pretrained(device=device)
+text = "Bonjour, comment ça va? Ceci est le modèle de synthèse vocale multilingue Chatterbox, il prend en charge 23 langues."
+wav = multilingual_model.generate(text, language_id="fr")
+ta.save("test-2.wav", wav, multilingual_model.sr)
+
# If you want to synthesize with a different voice, specify the audio prompt
AUDIO_PROMPT_PATH = "YOUR_FILE.wav"
wav = model.generate(text, audio_prompt_path=AUDIO_PROMPT_PATH)
-ta.save("test-2.wav", wav, model.sr)
+ta.save("test-3.wav", wav, model.sr)
diff --git a/example_vc.py b/example_vc.py
index 6255a0a3..90d66480 100644
--- a/example_vc.py
+++ b/example_vc.py
@@ -1,3 +1,5 @@
+# example_vc.py
+
import torch
import torchaudio as ta
diff --git a/example_vc_batching.py b/example_vc_batching.py
new file mode 100644
index 00000000..1256f509
--- /dev/null
+++ b/example_vc_batching.py
@@ -0,0 +1,43 @@
+# example_vc_batching.py
+
+import torchaudio as ta
+import torch
+from chatterbox.tts import ChatterboxTTS
+
+# Automatically detect the best available device
+if torch.cuda.is_available():
+ device = "cuda"
+elif torch.backends.mps.is_available():
+ device = "mps"
+else:
+ device = "cpu"
+
+print(f"Using device: {device}")
+
+
+model = ChatterboxTTS.from_pretrained(device=device)
+
+texts_batch = [
+    "This is the first sentence to be synthesized in a batch.",
+    "This is the second one.",
+]
+
+
+# If you want to synthesize with a different voice, specify the audio prompt:
+AUDIO_PROMPT_PATH = "YOUR_AUDIO_PROMPT.wav"
+
+
+# Batching - list of strings to synthesize multiple different texts in a single batch.
+# This is the most efficient way to process multiple different prompts at once.
+# Careful: each additional text allocates an additional KV cache (VRAM).
+wavs_batch = model.generate(texts_batch, audio_prompt_path=AUDIO_PROMPT_PATH)
+for i, wav in enumerate(wavs_batch):
+ ta.save(f"test-batch-{i+1}.wav", wav, model.sr)
+
+# Batching - Use num_return_sequences to generate multiple variations for each text.
+# This is highly efficient for creating diverse samples: the prompt is only processed once,
+# without allocating extra KV caches.
+num_variations = 3
+
+wavs_batch_multi = model.generate(texts_batch, audio_prompt_path=AUDIO_PROMPT_PATH, num_return_sequences=num_variations)
+for i, group in enumerate(wavs_batch_multi):
+ for j, wav in enumerate(group):
+ ta.save(f"test-batch-{i+1}-variant-{j+1}.wav", wav, model.sr)
\ No newline at end of file
diff --git a/gradio_tts_app.py b/gradio_tts_app.py
index cda7912b..a313437f 100644
--- a/gradio_tts_app.py
+++ b/gradio_tts_app.py
@@ -1,3 +1,5 @@
+# gradio_tts_app.py
+
import random
import numpy as np
import torch
diff --git a/gradio_vc_app.py b/gradio_vc_app.py
index 2202e6ba..13c2f219 100644
--- a/gradio_vc_app.py
+++ b/gradio_vc_app.py
@@ -1,3 +1,5 @@
+# gradio_vc_app.py
+
import torch
import gradio as gr
from chatterbox.vc import ChatterboxVC
diff --git a/multilingual_app.py b/multilingual_app.py
new file mode 100644
index 00000000..51e9c693
--- /dev/null
+++ b/multilingual_app.py
@@ -0,0 +1,317 @@
+import random
+import numpy as np
+import torch
+from chatterbox.mtl_tts import ChatterboxMultilingualTTS, SUPPORTED_LANGUAGES
+import gradio as gr
+
+DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
+print(f"🚀 Running on device: {DEVICE}")
+
+# --- Global Model Initialization ---
+MODEL = None
+
+LANGUAGE_CONFIG = {
+ "ar": {
+ "audio": "https://storage.googleapis.com/chatterbox-demo-samples/mtl_prompts/ar_f/ar_prompts2.flac",
+ "text": "في الشهر الماضي، وصلنا إلى معلم جديد بمليارين من المشاهدات على قناتنا على يوتيوب."
+ },
+ "da": {
+ "audio": "https://storage.googleapis.com/chatterbox-demo-samples/mtl_prompts/da_m1.flac",
+ "text": "Sidste måned nåede vi en ny milepæl med to milliarder visninger på vores YouTube-kanal."
+ },
+ "de": {
+ "audio": "https://storage.googleapis.com/chatterbox-demo-samples/mtl_prompts/de_f1.flac",
+ "text": "Letzten Monat haben wir einen neuen Meilenstein erreicht: zwei Milliarden Aufrufe auf unserem YouTube-Kanal."
+ },
+ "el": {
+ "audio": "https://storage.googleapis.com/chatterbox-demo-samples/mtl_prompts/el_m.flac",
+ "text": "Τον περασμένο μήνα, φτάσαμε σε ένα νέο ορόσημο με δύο δισεκατομμύρια προβολές στο κανάλι μας στο YouTube."
+ },
+ "en": {
+ "audio": "https://storage.googleapis.com/chatterbox-demo-samples/mtl_prompts/en_f1.flac",
+ "text": "Last month, we reached a new milestone with two billion views on our YouTube channel."
+ },
+ "es": {
+ "audio": "https://storage.googleapis.com/chatterbox-demo-samples/mtl_prompts/es_f1.flac",
+ "text": "El mes pasado alcanzamos un nuevo hito: dos mil millones de visualizaciones en nuestro canal de YouTube."
+ },
+ "fi": {
+ "audio": "https://storage.googleapis.com/chatterbox-demo-samples/mtl_prompts/fi_m.flac",
+ "text": "Viime kuussa saavutimme uuden virstanpylvään kahden miljardin katselukerran kanssa YouTube-kanavallamme."
+ },
+ "fr": {
+ "audio": "https://storage.googleapis.com/chatterbox-demo-samples/mtl_prompts/fr_f1.flac",
+ "text": "Le mois dernier, nous avons atteint un nouveau jalon avec deux milliards de vues sur notre chaîne YouTube."
+ },
+ "he": {
+ "audio": "https://storage.googleapis.com/chatterbox-demo-samples/mtl_prompts/he_m1.flac",
+ "text": "בחודש שעבר הגענו לאבן דרך חדשה עם שני מיליארד צפיות בערוץ היוטיוב שלנו."
+ },
+ "hi": {
+ "audio": "https://storage.googleapis.com/chatterbox-demo-samples/mtl_prompts/hi_f1.flac",
+ "text": "पिछले महीने हमने एक नया मील का पत्थर छुआ: हमारे YouTube चैनल पर दो अरब व्यूज़।"
+ },
+ "it": {
+ "audio": "https://storage.googleapis.com/chatterbox-demo-samples/mtl_prompts/it_m1.flac",
+ "text": "Il mese scorso abbiamo raggiunto un nuovo traguardo: due miliardi di visualizzazioni sul nostro canale YouTube."
+ },
+ "ja": {
+ "audio": "https://storage.googleapis.com/chatterbox-demo-samples/mtl_prompts/ja/ja_prompts1.flac",
+ "text": "先月、私たちのYouTubeチャンネルで二十億回の再生回数という新たなマイルストーンに到達しました。"
+ },
+ "ko": {
+ "audio": "https://storage.googleapis.com/chatterbox-demo-samples/mtl_prompts/ko_f.flac",
+ "text": "지난달 우리는 유튜브 채널에서 이십억 조회수라는 새로운 이정표에 도달했습니다."
+ },
+ "ms": {
+ "audio": "https://storage.googleapis.com/chatterbox-demo-samples/mtl_prompts/ms_f.flac",
+ "text": "Bulan lepas, kami mencapai pencapaian baru dengan dua bilion tontonan di saluran YouTube kami."
+ },
+ "nl": {
+ "audio": "https://storage.googleapis.com/chatterbox-demo-samples/mtl_prompts/nl_m.flac",
+ "text": "Vorige maand bereikten we een nieuwe mijlpaal met twee miljard weergaven op ons YouTube-kanaal."
+ },
+ "no": {
+ "audio": "https://storage.googleapis.com/chatterbox-demo-samples/mtl_prompts/no_f1.flac",
+ "text": "Forrige måned nådde vi en ny milepæl med to milliarder visninger på YouTube-kanalen vår."
+ },
+ "pl": {
+ "audio": "https://storage.googleapis.com/chatterbox-demo-samples/mtl_prompts/pl_m.flac",
+ "text": "W zeszłym miesiącu osiągnęliśmy nowy kamień milowy z dwoma miliardami wyświetleń na naszym kanale YouTube."
+ },
+ "pt": {
+ "audio": "https://storage.googleapis.com/chatterbox-demo-samples/mtl_prompts/pt_m1.flac",
+ "text": "No mês passado, alcançámos um novo marco: dois mil milhões de visualizações no nosso canal do YouTube."
+ },
+ "ru": {
+ "audio": "https://storage.googleapis.com/chatterbox-demo-samples/mtl_prompts/ru_m.flac",
+ "text": "В прошлом месяце мы достигли нового рубежа: два миллиарда просмотров на нашем YouTube-канале."
+ },
+ "sv": {
+ "audio": "https://storage.googleapis.com/chatterbox-demo-samples/mtl_prompts/sv_f.flac",
+ "text": "Förra månaden nådde vi en ny milstolpe med två miljarder visningar på vår YouTube-kanal."
+ },
+ "sw": {
+ "audio": "https://storage.googleapis.com/chatterbox-demo-samples/mtl_prompts/sw_m.flac",
+        "text": "Mwezi uliopita, tulifika hatua mpya ya maoni ya bilioni mbili kwenye kituo chetu cha YouTube."
+ },
+ "tr": {
+ "audio": "https://storage.googleapis.com/chatterbox-demo-samples/mtl_prompts/tr_m.flac",
+ "text": "Geçen ay YouTube kanalımızda iki milyar görüntüleme ile yeni bir dönüm noktasına ulaştık."
+ },
+ "zh": {
+ "audio": "https://storage.googleapis.com/chatterbox-demo-samples/mtl_prompts/zh_f2.flac",
+ "text": "上个月,我们达到了一个新的里程碑. 我们的YouTube频道观看次数达到了二十亿次,这绝对令人难以置信。"
+ },
+}
+
+# --- UI Helpers ---
+def default_audio_for_ui(lang: str) -> str | None:
+ return LANGUAGE_CONFIG.get(lang, {}).get("audio")
+
+
+def default_text_for_ui(lang: str) -> str:
+ return LANGUAGE_CONFIG.get(lang, {}).get("text", "")
+
+
+def get_supported_languages_display() -> str:
+ """Generate a formatted display of all supported languages."""
+ language_items = []
+ for code, name in sorted(SUPPORTED_LANGUAGES.items()):
+ language_items.append(f"**{name}** (`{code}`)")
+
+ # Split into 2 lines
+ mid = len(language_items) // 2
+ line1 = " • ".join(language_items[:mid])
+ line2 = " • ".join(language_items[mid:])
+
+ return f"""
+### 🌍 Supported Languages ({len(SUPPORTED_LANGUAGES)} total)
+{line1}
+
+{line2}
+"""
+
+
+def get_or_load_model():
+ """Loads the ChatterboxMultilingualTTS model if it hasn't been loaded already,
+ and ensures it's on the correct device."""
+ global MODEL
+ if MODEL is None:
+ print("Model not loaded, initializing...")
+ try:
+ MODEL = ChatterboxMultilingualTTS.from_pretrained(DEVICE)
+ if hasattr(MODEL, 'to') and str(MODEL.device) != DEVICE:
+ MODEL.to(DEVICE)
+ print(f"Model loaded successfully. Internal device: {getattr(MODEL, 'device', 'N/A')}")
+ except Exception as e:
+ print(f"Error loading model: {e}")
+ raise
+ return MODEL
+
+# Attempt to load the model at startup.
+try:
+ get_or_load_model()
+except Exception as e:
+ print(f"CRITICAL: Failed to load model on startup. Application may not function. Error: {e}")
+
+def set_seed(seed: int):
+ """Sets the random seed for reproducibility across torch, numpy, and random."""
+ torch.manual_seed(seed)
+ if DEVICE == "cuda":
+ torch.cuda.manual_seed(seed)
+ torch.cuda.manual_seed_all(seed)
+ random.seed(seed)
+ np.random.seed(seed)
+
+def resolve_audio_prompt(language_id: str, provided_path: str | None) -> str | None:
+ """
+ Decide which audio prompt to use:
+ - If user provided a path (upload/mic/url), use it.
+ - Else, fall back to language-specific default (if any).
+ """
+ if provided_path and str(provided_path).strip():
+ return provided_path
+ return LANGUAGE_CONFIG.get(language_id, {}).get("audio")
+
+
+def generate_tts_audio(
+ text_input: str,
+ language_id: str,
+ audio_prompt_path_input: str = None,
+ exaggeration_input: float = 0.5,
+ temperature_input: float = 0.8,
+ seed_num_input: int = 0,
+ cfgw_input: float = 0.5
+) -> tuple[int, np.ndarray]:
+ """
+    Generate high-quality speech audio from text using the Chatterbox Multilingual model with optional reference audio styling.
+    Supports the 23 languages listed in SUPPORTED_LANGUAGES.
+
+ This tool synthesizes natural-sounding speech from input text. When a reference audio file
+ is provided, it captures the speaker's voice characteristics and speaking style. The generated audio
+ maintains the prosody, tone, and vocal qualities of the reference speaker, or uses default voice if no reference is provided.
+
+ Args:
+ text_input (str): The text to synthesize into speech (maximum 300 characters)
+        language_id (str): The language code for synthesis (e.g. en, fr, de, es, it, pt, hi)
+ audio_prompt_path_input (str, optional): File path or URL to the reference audio file that defines the target voice style. Defaults to None.
+ exaggeration_input (float, optional): Controls speech expressiveness (0.25-2.0, neutral=0.5, extreme values may be unstable). Defaults to 0.5.
+ temperature_input (float, optional): Controls randomness in generation (0.05-5.0, higher=more varied). Defaults to 0.8.
+ seed_num_input (int, optional): Random seed for reproducible results (0 for random generation). Defaults to 0.
+ cfgw_input (float, optional): CFG/Pace weight controlling generation guidance (0.2-1.0). Defaults to 0.5, 0 for language transfer.
+
+ Returns:
+ tuple[int, np.ndarray]: A tuple containing the sample rate (int) and the generated audio waveform (numpy.ndarray)
+ """
+ current_model = get_or_load_model()
+
+ if current_model is None:
+ raise RuntimeError("TTS model is not loaded.")
+
+ if seed_num_input != 0:
+ set_seed(int(seed_num_input))
+
+ print(f"Generating audio for text: '{text_input[:50]}...'")
+
+ # Handle optional audio prompt
+ chosen_prompt = audio_prompt_path_input or default_audio_for_ui(language_id)
+
+ generate_kwargs = {
+ "exaggeration": exaggeration_input,
+ "temperature": temperature_input,
+ "cfg_weight": cfgw_input,
+ }
+ if chosen_prompt:
+ generate_kwargs["audio_prompt_path"] = chosen_prompt
+ print(f"Using audio prompt: {chosen_prompt}")
+ else:
+ print("No audio prompt provided; using default voice.")
+
+ wav = current_model.generate(
+ text_input[:300], # Truncate text to max chars
+ language_id=language_id,
+ **generate_kwargs
+ )
+ print("Audio generation complete.")
+ return (current_model.sr, wav.squeeze(0).numpy())
+
+with gr.Blocks() as demo:
+ gr.Markdown(
+ """
+ # Chatterbox Multilingual Demo
+ Generate high-quality multilingual speech from text with reference audio styling, supporting 23 languages.
+ """
+ )
+
+ # Display supported languages
+ gr.Markdown(get_supported_languages_display())
+ with gr.Row():
+ with gr.Column():
+ initial_lang = "fr"
+ text = gr.Textbox(
+ value=default_text_for_ui(initial_lang),
+ label="Text to synthesize (max chars 300)",
+ max_lines=5
+ )
+
+ language_id = gr.Dropdown(
+ choices=list(ChatterboxMultilingualTTS.get_supported_languages().keys()),
+ value=initial_lang,
+ label="Language",
+ info="Select the language for text-to-speech synthesis"
+ )
+
+ ref_wav = gr.Audio(
+ sources=["upload", "microphone"],
+ type="filepath",
+ label="Reference Audio File (Optional)",
+ value=default_audio_for_ui(initial_lang)
+ )
+
+ gr.Markdown(
+ "💡 **Note**: Ensure that the reference clip matches the specified language tag. Otherwise, language transfer outputs may inherit the accent of the reference clip's language. To mitigate this, set the CFG weight to 0.",
+ elem_classes=["audio-note"]
+ )
+
+ exaggeration = gr.Slider(
+ 0.25, 2, step=.05, label="Exaggeration (Neutral = 0.5, extreme values can be unstable)", value=.5
+ )
+ cfg_weight = gr.Slider(
+ 0.2, 1, step=.05, label="CFG/Pace", value=0.5
+ )
+
+ with gr.Accordion("More options", open=False):
+ seed_num = gr.Number(value=0, label="Random seed (0 for random)")
+ temp = gr.Slider(0.05, 5, step=.05, label="Temperature", value=.8)
+
+ run_btn = gr.Button("Generate", variant="primary")
+
+ with gr.Column():
+ audio_output = gr.Audio(label="Output Audio")
+
+ def on_language_change(lang, current_ref, current_text):
+ return default_audio_for_ui(lang), default_text_for_ui(lang)
+
+ language_id.change(
+ fn=on_language_change,
+ inputs=[language_id, ref_wav, text],
+ outputs=[ref_wav, text],
+ show_progress=False
+ )
+
+ run_btn.click(
+ fn=generate_tts_audio,
+ inputs=[
+ text,
+ language_id,
+ ref_wav,
+ exaggeration,
+ temp,
+ seed_num,
+ cfg_weight,
+ ],
+ outputs=[audio_output],
+ )
+
+demo.launch(mcp_server=True)
diff --git a/pyproject.toml b/pyproject.toml
index 9e1bbba4..780ec361 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -1,15 +1,15 @@
[project]
name = "chatterbox-tts"
-version = "0.1.2"
+version = "0.1.4"
description = "Chatterbox: Open Source TTS and Voice Conversion by Resemble AI"
readme = "README.md"
-requires-python = ">=3.9"
+requires-python = ">=3.10"
license = {file = "LICENSE"}
authors = [
{name = "resemble-ai", email = "engineering@resemble.ai"}
]
dependencies = [
- "numpy>=1.26.0",
+ "numpy>=1.24.0,<1.26.0",
"librosa==0.11.0",
"s3tokenizer",
"torch==2.6.0",
@@ -18,7 +18,12 @@ dependencies = [
"diffusers==0.29.0",
"resemble-perth==1.0.1",
"conformer==0.3.2",
- "safetensors==0.5.3"
+ "safetensors==0.5.3",
+ "spacy-pkuseg",
+ "pykakasi==2.3.0",
+ "gradio==5.44.1",
+ "russian-text-stresser @ git+https://github.com/Vuizur/add-stress-to-epub",
+
]
[project.urls]
diff --git a/run_tts_test.py b/run_tts_test.py
new file mode 100644
index 00000000..fbf4dd19
--- /dev/null
+++ b/run_tts_test.py
@@ -0,0 +1,108 @@
+# chatterbox/run_tts_test.py
+
+import torchaudio
+import torch
+from pathlib import Path
+import time
+import sys
+
+# Robust Path Setup
+# Add the project root to the Python path to allow importing the chatterbox module
+try:
+ # Assuming this script is in the 'chatterbox' directory, the project root is one level up.
+ project_root = Path(__file__).resolve().parent.parent
+ if str(project_root) not in sys.path:
+ sys.path.insert(0, str(project_root))
+
+ from chatterbox.tts import ChatterboxTTS
+except (ImportError, NameError):
+ print("Error: Could not import ChatterboxTTS.")
+ print("Please ensure this script is located in the 'chatterbox' directory and that the main project structure is intact.")
+ sys.exit(1)
+
+
+# 1. Define the texts you want to generate in a batch.
+texts_to_generate = [
+ "Artists create a sequence of drawings to visualize the animation, shot by shot, much like a comic book. This process helps the entire team understand the director's vision, plan the pacing of the story, and identify potential issues before any resource-intensive digital work commences. It serves as the essential blueprint that guides all subsequent stages of production, ensuring a cohesive final product.",
+ "Following this, the rigging process gives the models a digital skeleton, or armature, which allows animators to pose and move them realistically. This technical yet artistic step is fundamental for creating believable movement, as a well-constructed rig provides the intuitive controls necessary for bringing static models to life.",
+ "Simultaneously, the texturing process involves painting and applying surface details, such as skin, fabric, or metal, to the models. These detailed maps, known as textures, determine how light interacts with the surfaces, adding a layer of realism and visual richness. This combination of movement and surface artistry transforms simple geometric shapes into compelling characters.",
+ "Rendering is the computationally intensive process of calculating the final image from all the data, turning the 3D scene into a sequence of 2D frames. Finally, compositing combines these rendered layers with visual effects and color grading in post-production, seamlessly integrating every component to achieve the final, stunning look of the completed animation.",
+]
+
+
+# 2. Define the path to the voice prompt audio file.
+# This uses the project root to create a reliable path.
+AUDIO_PROMPT_PATH = project_root / "assets" / "audio_sample1.wav"
+
+# 3. Define the output directory for the generated audio files.
+OUTPUT_DIR = Path(__file__).resolve().parent / "tts_test_outputs"
+
+
+def main():
+ """Main function to run the TTS test."""
+
+ # Device Selection
+ if torch.cuda.is_available():
+ device = "cuda"
+ elif torch.backends.mps.is_available():
+ device = "mps"
+ else:
+ device = "cpu"
+
+ print(f"Using device: {device}")
+
+ # Validate Inputs
+ if not AUDIO_PROMPT_PATH.is_file():
+ print(f"Error: Audio prompt file not found at '{AUDIO_PROMPT_PATH}'")
+ return
+
+ # Create the output directory if it doesn't exist
+ OUTPUT_DIR.mkdir(exist_ok=True)
+ print(f"Audio files will be saved in: '{OUTPUT_DIR}'")
+
+ # Model Loading
+ print("\nLoading Chatterbox TTS model... (This may take a moment)")
+ start_time = time.time()
+ try:
+ # Enable BF16 and Compilation if running on CUDA
+ use_optimizations = (device == "cuda")
+
+ # Use the updated from_pretrained method
+ model = ChatterboxTTS.from_pretrained(
+ device=device,
+ use_bf16=use_optimizations,
+ compile_model=use_optimizations
+ )
+ except Exception as e:
+ print(f"Failed to load model: {e}")
+ import traceback
+ traceback.print_exc()
+ return
+ load_time = time.time() - start_time
+ print(f"Model loaded successfully in {load_time:.2f} seconds.")
+
+ # Audio Generation
+ print(f"\nGenerating audio for a batch of {len(texts_to_generate)} texts...")
+ start_time = time.time()
+
+ # The `generate` method handles batching when given a list of strings.
+ wavs_batch = model.generate(
+ texts_to_generate,
+ audio_prompt_path=str(AUDIO_PROMPT_PATH)
+ )
+
+ generation_time = time.time() - start_time
+ print(f"Batch generation completed in {generation_time:.2f} seconds.")
+
+ # Saving Outputs
+ print("\nSaving generated audio files...")
+ for i, wav_tensor in enumerate(wavs_batch):
+ output_filename = OUTPUT_DIR / f"output_batch_{i+1}.wav"
+ torchaudio.save(str(output_filename), wav_tensor, model.sr)
+ print(f" - Saved: {output_filename.name}")
+
+ print("\n--- Test Complete")
+
+
+if __name__ == "__main__":
+ main()
\ No newline at end of file
diff --git a/src/chatterbox/__init__.py b/src/chatterbox/__init__.py
index 4f5381fb..190cfbf2 100644
--- a/src/chatterbox/__init__.py
+++ b/src/chatterbox/__init__.py
@@ -8,3 +8,4 @@
from .tts import ChatterboxTTS
from .vc import ChatterboxVC
+from .mtl_tts import ChatterboxMultilingualTTS, SUPPORTED_LANGUAGES
\ No newline at end of file
diff --git a/src/chatterbox/models/s3gen/configs.py b/src/chatterbox/models/s3gen/configs.py
index b09b2e52..c243fc34 100644
--- a/src/chatterbox/models/s3gen/configs.py
+++ b/src/chatterbox/models/s3gen/configs.py
@@ -1,3 +1,5 @@
+#chatterbox/src/chatterbox/models/s3gen/configs.py
+
from ..utils import AttrDict
CFM_PARAMS = AttrDict({
diff --git a/src/chatterbox/models/s3gen/const.py b/src/chatterbox/models/s3gen/const.py
index 72de6a23..9f576bd5 100644
--- a/src/chatterbox/models/s3gen/const.py
+++ b/src/chatterbox/models/s3gen/const.py
@@ -1 +1,9 @@
+#chatterbox/src/chatterbox/models/s3gen/const.py
+
S3GEN_SR = 24000
+
+# The ratio of audio samples to speech tokens.
+# 1 token -> 2 mel frames (in CausalMaskedDiffWithXvec)
+# 1 mel frame -> 480 audio samples (hop_size in hifigan)
+# Total: 2 * 480 = 960
+TOKEN_TO_WAV_RATIO = 960
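+
+# Example: 100 speech tokens -> 100 * TOKEN_TO_WAV_RATIO = 96,000 samples,
+# i.e. 96,000 / S3GEN_SR = 4.0 seconds of audio at 24 kHz.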
\ No newline at end of file
diff --git a/src/chatterbox/models/s3gen/decoder.py b/src/chatterbox/models/s3gen/decoder.py
index c568c2df..dc2b3049 100644
--- a/src/chatterbox/models/s3gen/decoder.py
+++ b/src/chatterbox/models/s3gen/decoder.py
@@ -1,3 +1,5 @@
+#chatterbox/src/chatterbox/models/s3gen/decoder.py
+
# Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu, Zhihao Du)
#
# Licensed under the Apache License, Version 2.0 (the "License");
diff --git a/src/chatterbox/models/s3gen/f0_predictor.py b/src/chatterbox/models/s3gen/f0_predictor.py
index 172c5f50..51b48868 100644
--- a/src/chatterbox/models/s3gen/f0_predictor.py
+++ b/src/chatterbox/models/s3gen/f0_predictor.py
@@ -1,3 +1,5 @@
+#chatterbox/src/chatterbox/models/s3gen/f0_predictor.py
+
# Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu, Kai Hu)
#
# Licensed under the Apache License, Version 2.0 (the "License");
@@ -49,7 +51,18 @@ def __init__(self,
)
self.classifier = nn.Linear(in_features=cond_channels, out_features=self.num_class)
+ @property
+ def dtype(self):
+        try:
+            # Dynamically determine the dtype from parameters
+ return self.classifier.weight.dtype
+ except StopIteration:
+ return torch.float32
+
def forward(self, x: torch.Tensor) -> torch.Tensor:
- x = self.condnet(x)
+ # Ensure input matches the model's dtype (e.g., BF16)
+ if x.dtype != self.dtype:
+ x = x.to(self.dtype)
+
+ x = self.condnet(x) # This now works: BF16 input, BF16 weights
x = x.transpose(1, 2)
- return torch.abs(self.classifier(x).squeeze(-1))
+ return torch.abs(self.classifier(x).squeeze(-1))
\ No newline at end of file
diff --git a/src/chatterbox/models/s3gen/flow.py b/src/chatterbox/models/s3gen/flow.py
index 8069f4e0..75d1fd24 100644
--- a/src/chatterbox/models/s3gen/flow.py
+++ b/src/chatterbox/models/s3gen/flow.py
@@ -1,3 +1,5 @@
+#chatterbox/src/chatterbox/models/s3gen/flow.py
+
# Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu, Zhihao Du)
#
# Licensed under the Apache License, Version 2.0 (the "License");
@@ -14,6 +16,8 @@
import logging
import random
from typing import Dict, Optional
+
+logger = logging.getLogger(__name__)
import torch
import torch.nn as nn
from torch.nn import functional as F
@@ -76,6 +80,7 @@ def __init__(
self.decoder = decoder
self.length_regulator = length_regulator
self.only_mask_loss = only_mask_loss
+ self.fp16 = False # was missing
def forward(
self,
@@ -94,7 +99,7 @@ def forward(
# concat text and prompt_text
mask = (~make_pad_mask(token_len)).float().unsqueeze(-1).to(device)
- token = self.input_embedding(torch.clamp(token, min=0)) * mask
+ token = self.input_embedding(torch.clamp(token, min=0, max=self.input_embedding.num_embeddings-1)) * mask
# text encode
h, h_lengths = self.encoder(token, token_len)
@@ -130,46 +135,57 @@ def inference(self,
prompt_feat,
prompt_feat_len,
embedding,
- flow_cache):
- if self.fp16 is True:
- prompt_feat = prompt_feat.half()
- embedding = embedding.half()
+ finalize):
+ # Use the actual model dtype for inputs
+ expected_dtype = self.dtype
- assert token.shape[0] == 1
+ # Ensure inputs match the expected dtype (replaces the old self.fp16 check)
+ if prompt_feat.dtype != expected_dtype:
+ prompt_feat = prompt_feat.to(expected_dtype)
+ if embedding.dtype != expected_dtype:
+ embedding = embedding.to(expected_dtype)
+
+ B = token.shape[0]
# xvec projection
embedding = F.normalize(embedding, dim=1)
embedding = self.spk_embed_affine_layer(embedding)
# concat text and prompt_text
- token_len1, token_len2 = prompt_token.shape[1], token.shape[1]
token, token_len = torch.concat([prompt_token, token], dim=1), prompt_token_len + token_len
mask = (~make_pad_mask(token_len)).unsqueeze(-1).to(embedding)
- token = self.input_embedding(torch.clamp(token, min=0)) * mask
+ token = self.input_embedding(torch.clamp(token, min=0, max=self.input_embedding.num_embeddings-1)) * mask
- # text encode
- h, h_lengths = self.encoder(token, token_len)
+ # text encode (Using your corrected version)
+ h, _ = self.encoder(token, token_len) # Ignore h_lengths from encoder
+ h_lengths = token_len * self.token_mel_ratio
+
+ if finalize is False:
+ h = h[:, :-self.pre_lookahead_len * self.token_mel_ratio]
+ h_lengths = h_lengths - self.pre_lookahead_len * self.token_mel_ratio
+ h_lengths = torch.clamp(h_lengths, min=0)
+
+ mel_len1, mel_len2 = prompt_feat.shape[1], h.shape[1] - prompt_feat.shape[1]
h = self.encoder_proj(h)
- mel_len1, mel_len2 = prompt_feat.shape[1], int(token_len2 / self.input_frame_rate * 22050 / 256)
- h, h_lengths = self.length_regulator.inference(h[:, :token_len1], h[:, token_len1:], mel_len1, mel_len2, self.input_frame_rate)
# get conditions
- conds = torch.zeros([1, mel_len1 + mel_len2, self.output_size], device=token.device).to(h.dtype)
+ # Ensure conds tensor is created with the correct dtype (h.dtype)
+ conds = torch.zeros([B, mel_len1 + mel_len2, self.output_size], device=token.device).to(h.dtype)
conds[:, :mel_len1] = prompt_feat
conds = conds.transpose(1, 2)
- mask = (~make_pad_mask(torch.tensor([mel_len1 + mel_len2]))).to(h)
- feat, flow_cache = self.decoder(
+ mask = (~make_pad_mask(h_lengths)).to(h)
+ feat, _ = self.decoder(
mu=h.transpose(1, 2).contiguous(),
mask=mask.unsqueeze(1),
spks=embedding,
cond=conds,
- n_timesteps=10,
- prompt_len=mel_len1,
- flow_cache=flow_cache
+ n_timesteps=10
)
feat = feat[:, :, mel_len1:]
- assert feat.shape[2] == mel_len2
- return feat.float(), flow_cache
+
+ # CRITICAL FIX: Do NOT cast to .float(). Keep the precision (e.g., BF16) for the Vocoder.
+ # OLD LINE was: return feat.float(), None
+ return feat, None
class CausalMaskedDiffWithXvec(torch.nn.Module):
@@ -229,9 +245,17 @@ def __init__(
self.only_mask_loss = only_mask_loss
self.token_mel_ratio = token_mel_ratio
self.pre_lookahead_len = pre_lookahead_len
+ # Remove the reliance on the manual fp16 flag if it exists.
+ if hasattr(self, 'fp16'):
+ del self.fp16
- # FIXME: this was missing - just putting it in as false
- self.fp16 = False
+ # Add helper property for dtype
+ @property
+ def dtype(self):
+ try:
+ return next(self.parameters()).dtype
+ except StopIteration:
+ return torch.float32
@torch.inference_mode()
def inference(self,
@@ -243,11 +267,17 @@ def inference(self,
prompt_feat_len,
embedding,
finalize):
- if self.fp16 is True:
- prompt_feat = prompt_feat.half()
- embedding = embedding.half()
+
+ # Use the actual model dtype for inputs
+ expected_dtype = self.dtype
- assert token.shape[0] == 1
+ # Ensure inputs match the expected dtype
+ if prompt_feat.dtype != expected_dtype:
+ prompt_feat = prompt_feat.to(expected_dtype)
+ if embedding.dtype != expected_dtype:
+ embedding = embedding.to(expected_dtype)
+
+ B = token.shape[0]
# xvec projection
embedding = F.normalize(embedding, dim=1)
embedding = self.spk_embed_affine_layer(embedding)
@@ -255,21 +285,26 @@ def inference(self,
# concat text and prompt_text
token, token_len = torch.concat([prompt_token, token], dim=1), prompt_token_len + token_len
mask = (~make_pad_mask(token_len)).unsqueeze(-1).to(embedding)
- token = self.input_embedding(torch.clamp(token, min=0)) * mask
+ token = self.input_embedding(torch.clamp(token, min=0, max=self.input_embedding.num_embeddings-1)) * mask
# text encode
- h, h_lengths = self.encoder(token, token_len)
+ h, _ = self.encoder(token, token_len) # Ignore h_lengths from encoder
+ h_lengths = token_len * self.token_mel_ratio
+
if finalize is False:
h = h[:, :-self.pre_lookahead_len * self.token_mel_ratio]
+ h_lengths = h_lengths - self.pre_lookahead_len * self.token_mel_ratio
+ h_lengths = torch.clamp(h_lengths, min=0)
+
mel_len1, mel_len2 = prompt_feat.shape[1], h.shape[1] - prompt_feat.shape[1]
h = self.encoder_proj(h)
# get conditions
- conds = torch.zeros([1, mel_len1 + mel_len2, self.output_size], device=token.device).to(h.dtype)
+ conds = torch.zeros([B, mel_len1 + mel_len2, self.output_size], device=token.device).to(h.dtype)
conds[:, :mel_len1] = prompt_feat
conds = conds.transpose(1, 2)
- mask = (~make_pad_mask(torch.tensor([mel_len1 + mel_len2]))).to(h)
+ mask = (~make_pad_mask(h_lengths)).to(h)
feat, _ = self.decoder(
mu=h.transpose(1, 2).contiguous(),
mask=mask.unsqueeze(1),
@@ -278,5 +313,7 @@ def inference(self,
n_timesteps=10
)
feat = feat[:, :, mel_len1:]
- assert feat.shape[2] == mel_len2
- return feat.float(), None # NOTE jrm: why are they returning None here?
+ # assert feat.shape[2] == mel_len2 # This assertion is not batch-aware
+
+ # CRITICAL: Do NOT cast to .float() here. Keep the precision (e.g., BF16) for the Vocoder.
+ return feat, None
\ No newline at end of file
diff --git a/src/chatterbox/models/s3gen/flow_matching.py b/src/chatterbox/models/s3gen/flow_matching.py
index ecd69fa4..146acfaf 100644
--- a/src/chatterbox/models/s3gen/flow_matching.py
+++ b/src/chatterbox/models/s3gen/flow_matching.py
@@ -1,3 +1,5 @@
+#chatterbox/src/chatterbox/models/s3gen/flow_matching.py
+
# Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu, Zhihao Du)
#
# Licensed under the Apache License, Version 2.0 (the "License");
@@ -91,21 +93,28 @@ def solve_euler(self, x, t_span, mu, mask, spks, cond):
# Or in future might add like a return_all_steps flag
sol = []
+ B = x.size(0)
# Do not use concat, it may cause memory format changed and trt infer with wrong results!
- x_in = torch.zeros([2, 80, x.size(2)], device=x.device, dtype=x.dtype)
- mask_in = torch.zeros([2, 1, x.size(2)], device=x.device, dtype=x.dtype)
- mu_in = torch.zeros([2, 80, x.size(2)], device=x.device, dtype=x.dtype)
- t_in = torch.zeros([2], device=x.device, dtype=x.dtype)
- spks_in = torch.zeros([2, 80], device=x.device, dtype=x.dtype)
- cond_in = torch.zeros([2, 80, x.size(2)], device=x.device, dtype=x.dtype)
+ x_in = torch.zeros([B * 2, 80, x.size(2)], device=x.device, dtype=x.dtype)
+ mask_in = torch.zeros([B * 2, 1, x.size(2)], device=x.device, dtype=x.dtype)
+ mu_in = torch.zeros([B * 2, 80, x.size(2)], device=x.device, dtype=x.dtype)
+ t_in = torch.zeros([B * 2], device=x.device, dtype=x.dtype)
+ spks_in = torch.zeros([B * 2, 80], device=x.device, dtype=x.dtype)
+ cond_in = torch.zeros([B * 2, 80, x.size(2)], device=x.device, dtype=x.dtype)
for step in range(1, len(t_span)):
# Classifier-Free Guidance inference introduced in VoiceBox
- x_in[:] = x
- mask_in[:] = mask
- mu_in[0] = mu
- t_in[:] = t.unsqueeze(0)
- spks_in[0] = spks
- cond_in[0] = cond
+ x_in[:] = torch.cat([x, x], dim=0)
+ mask_in[:] = torch.cat([mask, mask], dim=0)
+ mu_in[0:B] = mu
+ # mu_in[B:] is zero for uncond
+ t_in[:] = t.expand(B * 2)
+ if spks is not None:
+ spks_in[0:B] = spks
+ # spks_in[B:] is zero for uncond
+ if cond is not None:
+ cond_in[0:B] = cond
+ # cond_in[B:] is zero for uncond
+
dphi_dt = self.forward_estimator(
x_in, mask_in,
mu_in, t_in,
@@ -127,12 +136,13 @@ def forward_estimator(self, x, mask, mu, t, spks, cond):
return self.estimator.forward(x, mask, mu, t, spks, cond)
else:
with self.lock:
- self.estimator.set_input_shape('x', (2, 80, x.size(2)))
- self.estimator.set_input_shape('mask', (2, 1, x.size(2)))
- self.estimator.set_input_shape('mu', (2, 80, x.size(2)))
- self.estimator.set_input_shape('t', (2,))
- self.estimator.set_input_shape('spks', (2, 80))
- self.estimator.set_input_shape('cond', (2, 80, x.size(2)))
+ B = x.size(0)
+ self.estimator.set_input_shape('x', (B, 80, x.size(2)))
+ self.estimator.set_input_shape('mask', (B, 1, x.size(2)))
+ self.estimator.set_input_shape('mu', (B, 80, x.size(2)))
+ self.estimator.set_input_shape('t', (B,))
+ self.estimator.set_input_shape('spks', (B, 80))
+ self.estimator.set_input_shape('cond', (B, 80, x.size(2)))
# run trt engine
self.estimator.execute_v2([x.contiguous().data_ptr(),
mask.contiguous().data_ptr(),
@@ -211,8 +221,9 @@ def forward(self, mu, mask, n_timesteps, temperature=1.0, spks=None, cond=None):
"""
z = self.rand_noise[:, :, :mu.size(2)].to(mu.device).to(mu.dtype) * temperature
+ z = z.expand(mu.shape[0], -1, -1)
# fix prompt and overlap part mu and z
t_span = torch.linspace(0, 1, n_timesteps + 1, device=mu.device, dtype=mu.dtype)
if self.t_scheduler == 'cosine':
t_span = 1 - torch.cos(t_span * 0.5 * torch.pi)
- return self.solve_euler(z, t_span=t_span, mu=mu, mask=mask, spks=spks, cond=cond), None
+ return self.solve_euler(z, t_span=t_span, mu=mu, mask=mask, spks=spks, cond=cond), None
\ No newline at end of file
diff --git a/src/chatterbox/models/s3gen/hifigan.py b/src/chatterbox/models/s3gen/hifigan.py
index 33f9387e..9914337d 100644
--- a/src/chatterbox/models/s3gen/hifigan.py
+++ b/src/chatterbox/models/s3gen/hifigan.py
@@ -1,3 +1,5 @@
+#chatterbox/src/chatterbox/models/s3gen/hifigan.py
+
# jrm: adapted from CosyVoice/cosyvoice/hifigan/generator.py
# most modules should be reusable, but I found their SineGen changed a bit.
@@ -194,32 +196,48 @@ def __init__(self, samp_rate, harmonic_num=0,
def _f02uv(self, f0):
# generate uv signal
- uv = (f0 > self.voiced_threshold).type(torch.float32)
+ # FIX: Ensure output matches the dtype of the input f0
+ uv = (f0 > self.voiced_threshold).to(f0.dtype)
return uv
@torch.no_grad()
def forward(self, f0):
"""
- :param f0: [B, 1, sample_len], Hz
- :return: [B, 1, sample_len]
+ :param f0: [B, 1, sample_len], Hz. Expected input dtype (e.g., BF16).
+ :return: [B, 1, sample_len]. Output matches input dtype.
"""
-
- F_mat = torch.zeros((f0.size(0), self.harmonic_num + 1, f0.size(-1))).to(f0.device)
+ # Stable FP32 Island
+ # Operations like cumsum and sin can be unstable in BF16/FP16.
+ # We perform the core math in FP32 for stability, then cast back.
+
+ # Store the target dtype (e.g., BF16)
+ target_dtype = f0.dtype
+
+ # Cast input F0 to FP32 for calculation stability
+ f0_fp32 = f0.float()
+
+ # Initialize F_mat in FP32
+ B, _, T = f0.shape
+ F_mat = torch.zeros((B, self.harmonic_num + 1, T), device=f0.device, dtype=torch.float32)
+
for i in range(self.harmonic_num + 1):
- F_mat[:, i: i + 1, :] = f0 * (i + 1) / self.sampling_rate
+ F_mat[:, i: i + 1, :] = f0_fp32 * (i + 1) / self.sampling_rate
+ # Calculations remain in FP32
theta_mat = 2 * np.pi * (torch.cumsum(F_mat, dim=-1) % 1)
u_dist = Uniform(low=-np.pi, high=np.pi)
- phase_vec = u_dist.sample(sample_shape=(f0.size(0), self.harmonic_num + 1, 1)).to(F_mat.device)
+
+ # Ensure phase_vec is FP32 (Uniform defaults to FP32)
+ phase_vec = u_dist.sample(sample_shape=(B, self.harmonic_num + 1, 1)).to(F_mat.device)
phase_vec[:, 0, :] = 0
- # generate sine waveforms
+ # generate sine waveforms (FP32)
sine_waves = self.sine_amp * torch.sin(theta_mat + phase_vec)
- # generate uv signal
- uv = self._f02uv(f0)
+ # generate uv signal (using FP32 version)
+ uv = self._f02uv(f0_fp32)
- # noise: for unvoiced should be similar to sine_amp
+ # noise (FP32): for unvoiced should be similar to sine_amp
# std = self.sine_amp/3 -> max value ~ self.sine_amp
# . for voiced regions is self.noise_std
noise_amp = uv * self.noise_std + (1 - uv) * self.sine_amp / 3
@@ -228,7 +246,9 @@ def forward(self, f0):
# first: set the unvoiced part to 0 by uv
# then: additive noise
sine_waves = sine_waves * uv + noise
- return sine_waves, uv, noise
+
+ # Cast final outputs back to the target dtype (e.g., BF16)
+ return sine_waves.to(target_dtype), uv.to(target_dtype), noise.to(target_dtype)
class SourceModuleHnNSF(torch.nn.Module):
@@ -264,6 +284,15 @@ def __init__(self, sampling_rate, upsample_scale, harmonic_num=0, sine_amp=0.1,
self.l_linear = torch.nn.Linear(harmonic_num + 1, 1)
self.l_tanh = torch.nn.Tanh()
+ # Add dtype property
+ @property
+ def dtype(self):
+ try:
+ # Dynamically determine dtype from the linear layer weights
+ return self.l_linear.weight.dtype
+ except StopIteration:
+ return torch.float32
+
def forward(self, x):
"""
Sine_source, noise_source = SourceModuleHnNSF(F0_sampled)
@@ -271,14 +300,22 @@ def forward(self, x):
Sine_source (batchsize, length, 1)
noise_source (batchsize, length 1)
"""
+ # Ensure input x (F0 from F0Predictor) matches the module dtype
+ if x.dtype != self.dtype:
+ x = x.to(self.dtype)
+
# source for harmonic branch
with torch.no_grad():
+ # SineGen now handles stability internally and returns the correct dtype (e.g., BF16)
sine_wavs, uv, _ = self.l_sin_gen(x.transpose(1, 2))
sine_wavs = sine_wavs.transpose(1, 2)
uv = uv.transpose(1, 2)
+
+ # This operation (the location of the matmul error) works because
+ # sine_wavs (BF16) matches l_linear weights (BF16)
sine_merge = self.l_tanh(self.l_linear(sine_wavs))
- # source for noise branch, in the same shape as uv
+ # source for noise branch, torch.randn_like respects the dtype of uv (BF16)
noise = torch.randn_like(uv) * self.sine_amp / 3
return sine_merge, noise, uv
@@ -379,6 +416,13 @@ def __init__(
self.stft_window = torch.from_numpy(get_window("hann", istft_params["n_fft"], fftbins=True).astype(np.float32))
self.f0_predictor = f0_predictor
+ @property
+ def dtype(self):
+ try:
+ return self.conv_pre.weight.dtype
+ except StopIteration:
+ return torch.float32
+
def remove_weight_norm(self):
print('Removing weight norm...')
for l in self.ups:
@@ -387,15 +431,16 @@ def remove_weight_norm(self):
l.remove_weight_norm()
remove_weight_norm(self.conv_pre)
remove_weight_norm(self.conv_post)
- self.m_source.remove_weight_norm()
+ # self.m_source.remove_weight_norm() # SourceModuleHnNSF does not have this method
for l in self.source_downs:
remove_weight_norm(l)
for l in self.source_resblocks:
l.remove_weight_norm()
def _stft(self, x):
+ # This function requires FP32 input for stft
spec = torch.stft(
- x,
+ x.float(),
self.istft_params["n_fft"], self.istft_params["hop_len"], self.istft_params["n_fft"], window=self.stft_window.to(x.device),
return_complex=True)
spec = torch.view_as_real(spec) # [B, F, TT, 2]
@@ -405,13 +450,25 @@ def _istft(self, magnitude, phase):
magnitude = torch.clip(magnitude, max=1e2)
real = magnitude * torch.cos(phase)
img = magnitude * torch.sin(phase)
- inverse_transform = torch.istft(torch.complex(real, img), self.istft_params["n_fft"], self.istft_params["hop_len"],
+ # This function requires FP32 complex input for istft
+ inverse_transform = torch.istft(torch.complex(real.float(), img.float()), self.istft_params["n_fft"], self.istft_params["hop_len"],
self.istft_params["n_fft"], window=self.stft_window.to(magnitude.device))
return inverse_transform
def decode(self, x: torch.Tensor, s: torch.Tensor = torch.zeros(1, 1, 0)) -> torch.Tensor:
+ expected_dtype = self.dtype
+ # Defensively cast inputs to match the model's expected dtype
+ if x.dtype != expected_dtype:
+ x = x.to(expected_dtype)
+ if s.dtype != expected_dtype:
+ s = s.to(expected_dtype)
+
+ # _stft handles casting its input to float32 internally
s_stft_real, s_stft_imag = self._stft(s.squeeze(1))
s_stft = torch.cat([s_stft_real, s_stft_imag], dim=1)
+
+ # s_stft is now FP32, cast it to the model's dtype for downstream layers
+ s_stft = s_stft.to(expected_dtype)
x = self.conv_pre(x)
for i in range(self.num_upsamples):
@@ -439,6 +496,7 @@ def decode(self, x: torch.Tensor, s: torch.Tensor = torch.zeros(1, 1, 0)) -> tor
magnitude = torch.exp(x[:, :self.istft_params["n_fft"] // 2 + 1, :])
phase = torch.sin(x[:, self.istft_params["n_fft"] // 2 + 1:, :]) # actually, sin is redundancy
+ # _istft handles casting its inputs to float32 internally
x = self._istft(magnitude, phase)
x = torch.clamp(x, -self.audio_limit, self.audio_limit)
return x
@@ -469,6 +527,6 @@ def inference(self, speech_feat: torch.Tensor, cache_source: torch.Tensor = torc
s = s.transpose(1, 2)
# use cache_source to avoid glitch
if cache_source.shape[2] != 0:
- s[:, :, :cache_source.shape[2]] = cache_source
+ s[:, :, :cache_source.shape[2]] = cache_source.to(s.dtype)
generated_speech = self.decode(x=speech_feat, s=s)
- return generated_speech, s
+ return generated_speech, s
\ No newline at end of file
diff --git a/src/chatterbox/models/s3gen/matcha/decoder.py b/src/chatterbox/models/s3gen/matcha/decoder.py
index 6919f32d..b18b3c18 100644
--- a/src/chatterbox/models/s3gen/matcha/decoder.py
+++ b/src/chatterbox/models/s3gen/matcha/decoder.py
@@ -1,3 +1,5 @@
+#chatterbox/src/chatterbox/models/s3gen/decoder.py
+
import math
from typing import Optional
diff --git a/src/chatterbox/models/s3gen/matcha/flow_matching.py b/src/chatterbox/models/s3gen/matcha/flow_matching.py
index add7b08c..eefe78cb 100644
--- a/src/chatterbox/models/s3gen/matcha/flow_matching.py
+++ b/src/chatterbox/models/s3gen/matcha/flow_matching.py
@@ -1,3 +1,4 @@
+#chatterbox/src/chatterbox/models/s3gen/flow_matching.py
from abc import ABC
import torch
diff --git a/src/chatterbox/models/s3gen/matcha/text_encoder.py b/src/chatterbox/models/s3gen/matcha/text_encoder.py
index 276eee73..9d8fbb1c 100644
--- a/src/chatterbox/models/s3gen/matcha/text_encoder.py
+++ b/src/chatterbox/models/s3gen/matcha/text_encoder.py
@@ -1,3 +1,5 @@
+#chatterbox/src/chatterbox/models/s3gen/text_encoder.py
+
""" from https://github.com/jaywalnut310/glow-tts """
import math
diff --git a/src/chatterbox/models/s3gen/matcha/transformer.py b/src/chatterbox/models/s3gen/matcha/transformer.py
index dd1afa3a..5bc77767 100644
--- a/src/chatterbox/models/s3gen/matcha/transformer.py
+++ b/src/chatterbox/models/s3gen/matcha/transformer.py
@@ -1,3 +1,5 @@
+#chatterbox/src/chatterbox/models/s3gen/matcha/transformer.py
+
from typing import Any, Dict, Optional
import torch
diff --git a/src/chatterbox/models/s3gen/s3gen.py b/src/chatterbox/models/s3gen/s3gen.py
index b1cf05e6..5fd27dee 100644
--- a/src/chatterbox/models/s3gen/s3gen.py
+++ b/src/chatterbox/models/s3gen/s3gen.py
@@ -1,3 +1,5 @@
+#chatterbox/src/chatterbox/models/s3gen/s3gen.py
+
# Modified from CosyVoice https://github.com/FunAudioLLM/CosyVoice
#
# Licensed under the Apache License, Version 2.0 (the "License");
@@ -33,11 +35,6 @@
from .configs import CFM_PARAMS
-def drop_invalid_tokens(x):
- assert len(x.shape) <= 2 and x.shape[0] == 1, "only batch size of one allowed for now"
- return x[x < SPEECH_VOCAB_SIZE]
-
-
# TODO: global resampler cache
@lru_cache(100)
def get_resampler(src_sr, dst_sr, device):
@@ -146,7 +143,7 @@ def embed_ref(
"Reference mel length is not equal to 2 * reference token length.\n"
)
ref_speech_tokens = ref_speech_tokens[:, :ref_mels_24.shape[1] // 2]
- ref_speech_token_lens[0] = ref_speech_tokens.shape[1]
+ ref_speech_token_lens.fill_(ref_speech_tokens.shape[1])
return dict(
prompt_token=ref_speech_tokens.to(device),
@@ -159,6 +156,7 @@ def embed_ref(
def forward(
self,
speech_tokens: torch.LongTensor,
+ speech_token_lens: torch.LongTensor,
# locally-computed ref embedding (mutex with ref_dict)
ref_wav: Optional[torch.Tensor],
ref_sr: Optional[int],
@@ -197,9 +195,6 @@ def forward(
if len(speech_tokens.shape) == 1:
speech_tokens = speech_tokens.unsqueeze(0)
- # assert speech_tokens.shape[0] == 1, "only batch size of one allowed for now"
- speech_token_lens = torch.LongTensor([speech_tokens.size(1)]).to(self.device)
-
output_mels, _ = self.flow.inference(
token=speech_tokens,
token_len=speech_token_lens,
@@ -235,9 +230,17 @@ def __init__(self):
trim_fade[n_trim:] = (torch.cos(torch.linspace(torch.pi, 0, n_trim)) + 1) / 2
self.register_buffer("trim_fade", trim_fade, persistent=False) # (buffers get automatic device casting)
+ @property
+ def dtype(self):
+ try:
+ return next(self.parameters()).dtype
+ except StopIteration:
+ return torch.float32
+
def forward(
self,
speech_tokens,
+ speech_token_lens: torch.LongTensor,
# locally-computed ref embedding (mutex with ref_dict)
ref_wav: Optional[torch.Tensor],
ref_sr: Optional[int],
@@ -245,16 +248,22 @@ def forward(
ref_dict: Optional[dict] = None,
finalize: bool = False
):
- output_mels = super().forward(speech_tokens, ref_wav=ref_wav, ref_sr=ref_sr, ref_dict=ref_dict, finalize=finalize)
+ # output_mels will be in the precision of the flow model (e.g., BF16)
+ output_mels = super().forward(speech_tokens, speech_token_lens, ref_wav=ref_wav, ref_sr=ref_sr, ref_dict=ref_dict, finalize=finalize)
- # TODO jrm: ignoring the speed control (mel interpolation) and the HiFTGAN caching mechanisms for now.
- hift_cache_source = torch.zeros(1, 1, 0).to(self.device)
+ # Ensure cache_source matches batch size, dtype, and device
+ B = output_mels.shape[0]
+ # Use self.dtype and correct batch size B instead of hardcoding
+ hift_cache_source = torch.zeros(B, 1, 0, device=self.device, dtype=self.dtype)
+ # The input (output_mels) already matches the expected dtype (self.dtype).
output_wavs, *_ = self.mel2wav.inference(speech_feat=output_mels, cache_source=hift_cache_source)
if not self.training:
# NOTE: ad-hoc method to reduce "spillover" from the reference clip.
- output_wavs[:, :len(self.trim_fade)] *= self.trim_fade
+ # Robustly apply trim_fade: Ensure slice length doesn't exceed wav length
+ trim_len = min(len(self.trim_fade), output_wavs.shape[1])
+ output_wavs[:, :trim_len] *= self.trim_fade[:trim_len]
return output_wavs
@@ -262,6 +271,7 @@ def forward(
def flow_inference(
self,
speech_tokens,
+ speech_token_lens: torch.LongTensor,
# locally-computed ref embedding (mutex with ref_dict)
ref_wav: Optional[torch.Tensor] = None,
ref_sr: Optional[int] = None,
@@ -269,18 +279,27 @@ def flow_inference(
ref_dict: Optional[dict] = None,
finalize: bool = False,
):
- return super().forward(speech_tokens, ref_wav=ref_wav, ref_sr=ref_sr, ref_dict=ref_dict, finalize=finalize)
+ return super().forward(speech_tokens, speech_token_lens, ref_wav=ref_wav, ref_sr=ref_sr, ref_dict=ref_dict, finalize=finalize)
@torch.inference_mode()
def hift_inference(self, speech_feat, cache_source: torch.Tensor = None):
if cache_source is None:
- cache_source = torch.zeros(1, 1, 0).to(self.device)
+ B = speech_feat.shape[0]
+ # Ensure cache_source matches dtype and device
+ cache_source = torch.zeros(B, 1, 0, device=self.device, dtype=self.dtype)
+
+ # Ensure input feat matches expected dtype (in case this is called standalone)
+ if speech_feat.dtype != self.dtype:
+ speech_feat = speech_feat.to(self.dtype)
+
return self.mel2wav.inference(speech_feat=speech_feat, cache_source=cache_source)
+
@torch.inference_mode()
def inference(
self,
speech_tokens,
+ speech_token_lens: Optional[torch.LongTensor] = None,
# locally-computed ref embedding (mutex with ref_dict)
ref_wav: Optional[torch.Tensor] = None,
ref_sr: Optional[int] = None,
@@ -289,10 +308,18 @@ def inference(
cache_source: torch.Tensor = None, # NOTE: this arg is for streaming, it can probably be removed here
finalize: bool = True,
):
- output_mels = self.flow_inference(speech_tokens, ref_wav=ref_wav, ref_sr=ref_sr, ref_dict=ref_dict, finalize=finalize)
+ if speech_token_lens is None:
+ speech_token_lens = torch.tensor([speech_tokens.size(1)] * speech_tokens.size(0), device=speech_tokens.device)
+
+ # output_mels will be in the model precision (e.g., BF16)
+ output_mels = self.flow_inference(speech_tokens, speech_token_lens, ref_wav=ref_wav, ref_sr=ref_sr, ref_dict=ref_dict, finalize=finalize)
+
+ # hift_inference handles dtype matching internally now
output_wavs, output_sources = self.hift_inference(output_mels, cache_source)
# NOTE: ad-hoc method to reduce "spillover" from the reference clip.
- output_wavs[:, :len(self.trim_fade)] *= self.trim_fade
+ # Robustly apply trim_fade
+ trim_len = min(len(self.trim_fade), output_wavs.shape[1])
+ output_wavs[:, :trim_len] *= self.trim_fade[:trim_len]
- return output_wavs, output_sources
+ return output_wavs, output_sources
\ No newline at end of file
diff --git a/src/chatterbox/models/s3gen/transformer/activation.py b/src/chatterbox/models/s3gen/transformer/activation.py
index 8cea5481..00180d0d 100644
--- a/src/chatterbox/models/s3gen/transformer/activation.py
+++ b/src/chatterbox/models/s3gen/transformer/activation.py
@@ -1,3 +1,5 @@
+#chatterbox/src/chatterbox/models/s3gen/transformer/activation.py
+
# Copyright (c) 2020 Johns Hopkins University (Shinji Watanabe)
# 2020 Northwestern Polytechnical University (Pengcheng Guo)
# 2020 Mobvoi Inc (Binbin Zhang)
diff --git a/src/chatterbox/models/s3gen/transformer/attention.py b/src/chatterbox/models/s3gen/transformer/attention.py
index 95e1d840..046b7c7c 100644
--- a/src/chatterbox/models/s3gen/transformer/attention.py
+++ b/src/chatterbox/models/s3gen/transformer/attention.py
@@ -1,3 +1,5 @@
+#chatterbox/src/chatterbox/models/s3gen/transformer/attention.py
+
# Copyright (c) 2019 Shigeki Karita
# 2020 Mobvoi Inc (Binbin Zhang)
# 2022 Xingchen Song (sxc19@mails.tsinghua.edu.cn)
diff --git a/src/chatterbox/models/s3gen/transformer/convolution.py b/src/chatterbox/models/s3gen/transformer/convolution.py
index 4d5d9614..7f6d963b 100644
--- a/src/chatterbox/models/s3gen/transformer/convolution.py
+++ b/src/chatterbox/models/s3gen/transformer/convolution.py
@@ -1,3 +1,5 @@
+#chatterbox/src/chatterbox/models/s3gen/transformer/convolution.py
+
# Copyright (c) 2020 Mobvoi Inc. (authors: Binbin Zhang, Di Wu)
# 2024 Alibaba Inc (Xiang Lyu)
#
diff --git a/src/chatterbox/models/s3gen/transformer/embedding.py b/src/chatterbox/models/s3gen/transformer/embedding.py
index eae8c8ec..bb662e4a 100644
--- a/src/chatterbox/models/s3gen/transformer/embedding.py
+++ b/src/chatterbox/models/s3gen/transformer/embedding.py
@@ -1,3 +1,5 @@
+#chatterbox/src/chatterbox/models/s3gen/transformer/embedding.py
+
# Copyright (c) 2020 Mobvoi Inc. (authors: Binbin Zhang, Di Wu)
# 2024 Alibaba Inc (Xiang Lyu)
#
diff --git a/src/chatterbox/models/s3gen/transformer/encoder_layer.py b/src/chatterbox/models/s3gen/transformer/encoder_layer.py
index efbb12dd..636b884f 100644
--- a/src/chatterbox/models/s3gen/transformer/encoder_layer.py
+++ b/src/chatterbox/models/s3gen/transformer/encoder_layer.py
@@ -1,3 +1,5 @@
+#chatterbox/src/chatterbox/models/s3gen/transformer/encoder_layer.py
+
# Copyright (c) 2021 Mobvoi Inc (Binbin Zhang, Di Wu)
# 2022 Xingchen Song (sxc19@mails.tsinghua.edu.cn)
#
diff --git a/src/chatterbox/models/s3gen/transformer/positionwise_feed_forward.py b/src/chatterbox/models/s3gen/transformer/positionwise_feed_forward.py
index b7a2cf6e..9a05a31f 100644
--- a/src/chatterbox/models/s3gen/transformer/positionwise_feed_forward.py
+++ b/src/chatterbox/models/s3gen/transformer/positionwise_feed_forward.py
@@ -1,3 +1,5 @@
+#chatterbox/src/chatterbox/models/s3gen/transformer/positionwise_feed_forward.py
+
# Copyright (c) 2019 Shigeki Karita
# 2020 Mobvoi Inc (Binbin Zhang)
#
diff --git a/src/chatterbox/models/s3gen/transformer/subsampling.py b/src/chatterbox/models/s3gen/transformer/subsampling.py
index e17c2e32..df570b14 100644
--- a/src/chatterbox/models/s3gen/transformer/subsampling.py
+++ b/src/chatterbox/models/s3gen/transformer/subsampling.py
@@ -1,3 +1,5 @@
+#chatterbox/src/chatterbox/models/s3gen/transformer/subsampling.py
+
# Copyright (c) 2021 Mobvoi Inc (Binbin Zhang, Di Wu)
# 2024 Alibaba Inc (Xiang Lyu)
#
diff --git a/src/chatterbox/models/s3gen/transformer/upsample_encoder.py b/src/chatterbox/models/s3gen/transformer/upsample_encoder.py
index 766a5e4e..2af7ef39 100644
--- a/src/chatterbox/models/s3gen/transformer/upsample_encoder.py
+++ b/src/chatterbox/models/s3gen/transformer/upsample_encoder.py
@@ -1,3 +1,5 @@
+#chatterbox/src/chatterbox/models/s3gen/transformer/upsample_encoder.py
+
# Copyright (c) 2021 Mobvoi Inc (Binbin Zhang, Di Wu)
# 2022 Xingchen Song (sxc19@mails.tsinghua.edu.cn)
# 2024 Alibaba Inc (Xiang Lyu)
diff --git a/src/chatterbox/models/s3gen/utils/class_utils.py b/src/chatterbox/models/s3gen/utils/class_utils.py
index cd31e480..2e176564 100644
--- a/src/chatterbox/models/s3gen/utils/class_utils.py
+++ b/src/chatterbox/models/s3gen/utils/class_utils.py
@@ -1,3 +1,5 @@
+#chatterbox/src/chatterbox/models/s3gen/utils/class_utils.py
+
# Copyright [2023-11-28]
# 2024 Alibaba Inc (authors: Xiang Lyu)
#
diff --git a/src/chatterbox/models/s3gen/utils/mask.py b/src/chatterbox/models/s3gen/utils/mask.py
index 08c97a3e..faa63f84 100644
--- a/src/chatterbox/models/s3gen/utils/mask.py
+++ b/src/chatterbox/models/s3gen/utils/mask.py
@@ -1,3 +1,5 @@
+#chatterbox/src/chatterbox/models/s3gen/utils/mask.py
+
# Copyright (c) 2019 Shigeki Karita
# 2020 Mobvoi Inc (Binbin Zhang)
# 2024 Alibaba Inc (authors: Xiang Lyu)
diff --git a/src/chatterbox/models/s3gen/utils/mel.py b/src/chatterbox/models/s3gen/utils/mel.py
index 5a9ff9d1..d020cc8d 100644
--- a/src/chatterbox/models/s3gen/utils/mel.py
+++ b/src/chatterbox/models/s3gen/utils/mel.py
@@ -1,8 +1,13 @@
+#chatterbox/src/chatterbox/models/s3gen/utils/mel.py
+
"""mel-spectrogram extraction in Matcha-TTS"""
+import logging
from librosa.filters import mel as librosa_mel_fn
import torch
import numpy as np
+logger = logging.getLogger(__name__)
+
# NOTE: they decalred these global vars
mel_basis = {}
@@ -42,10 +47,11 @@ def mel_spectrogram(y, n_fft=1920, num_mels=80, sampling_rate=24000, hop_size=48
if len(y.shape) == 1:
y = y[None, ]
- if torch.min(y) < -1.0:
- print("min value is ", torch.min(y))
- if torch.max(y) > 1.0:
- print("max value is ", torch.max(y))
+ # Debug: Check for audio clipping (values outside [-1.0, 1.0] range)
+ min_val = torch.min(y)
+ max_val = torch.max(y)
+ if min_val < -1.0 or max_val > 1.0:
+ logger.warning(f"Audio values outside normalized range: min={min_val.item():.4f}, max={max_val.item():.4f}")
global mel_basis, hann_window # pylint: disable=global-statement,global-variable-not-assigned
if f"{str(fmax)}_{str(y.device)}" not in mel_basis:
diff --git a/src/chatterbox/models/s3gen/xvector.py b/src/chatterbox/models/s3gen/xvector.py
index 6eb99af4..4c2a758a 100644
--- a/src/chatterbox/models/s3gen/xvector.py
+++ b/src/chatterbox/models/s3gen/xvector.py
@@ -1,3 +1,5 @@
+#chatterbox/src/chatterbox/models/s3gen/xvector.py
+
#!/usr/bin/env python3
# -*- encoding: utf-8 -*-
# Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved.
@@ -414,9 +416,20 @@ def __init__(
if m.bias is not None:
torch.nn.init.zeros_(m.bias)
+ @property
+ def dtype(self):
+ # We can infer the dtype from any parameter (e.g., the first one)
+ try:
+ return next(self.parameters()).dtype
+ except StopIteration:
+ return torch.float32 # Default if no parameters found
+
def forward(self, x):
+ # Ensure input matches the model dtype (e.g., BF16)
+ if x.dtype != self.dtype:
+ x = x.to(self.dtype)
x = x.permute(0, 2, 1) # (B,T,F) => (B,F,T)
- x = self.head(x)
+ x = self.head(x) # The crash occurred inside here (FCM)
x = self.xvector(x)
if self.output_level == "frame":
x = x.transpose(1, 2)
@@ -424,5 +437,5 @@ def forward(self, x):
def inference(self, audio_list):
speech, speech_lengths, speech_times = extract_feature(audio_list)
- results = self.forward(speech.to(torch.float32))
+ results = self.forward(speech)
return results
diff --git a/src/chatterbox/models/s3tokenizer/__init__.py b/src/chatterbox/models/s3tokenizer/__init__.py
index cb2973ab..b4799167 100644
--- a/src/chatterbox/models/s3tokenizer/__init__.py
+++ b/src/chatterbox/models/s3tokenizer/__init__.py
@@ -5,26 +5,9 @@
S3_TOKEN_RATE,
SPEECH_VOCAB_SIZE,
S3Tokenizer,
+ drop_invalid_tokens,
)
SOS = SPEECH_VOCAB_SIZE
-EOS = SPEECH_VOCAB_SIZE + 1
-
-
-
-def drop_invalid_tokens(x):
- """Drop SoS and EoS"""
- assert len(x.shape) == 1 or (len(x.shape) == 2 and x.shape[0] == 1), "only batch size of one allowed for now"
- if SOS in x:
- s = (x == SOS).nonzero(as_tuple=True)[0].squeeze(0) + 1
- else:
- s = 0
-
- if EOS in x:
- e = (x == EOS).nonzero(as_tuple=True)[0].squeeze(0)
- else:
- e = None
-
- x = x[s: e]
- return x
+EOS = SPEECH_VOCAB_SIZE + 1
\ No newline at end of file
diff --git a/src/chatterbox/models/s3tokenizer/s3tokenizer.py b/src/chatterbox/models/s3tokenizer/s3tokenizer.py
index 8648608a..e111fe69 100644
--- a/src/chatterbox/models/s3tokenizer/s3tokenizer.py
+++ b/src/chatterbox/models/s3tokenizer/s3tokenizer.py
@@ -1,3 +1,5 @@
+#chatterbox/src/chatterbox/models/s3tokenizer/s3tokenizer.py
+
from typing import List, Tuple
import numpy as np
@@ -19,11 +21,24 @@
SPEECH_VOCAB_SIZE = 6561
+def drop_invalid_tokens(x: torch.Tensor) -> List[torch.Tensor]:
+ """
+ Filters out invalid tokens from a batch of token sequences.
+ Invalid tokens are those with an ID >= SPEECH_VOCAB_SIZE.
+ Returns a list of tensors, as each sequence may have a different length after filtering.
+ """
+ if x.dim() == 1:
+ x = x.unsqueeze(0)
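+ # SOS (== SPEECH_VOCAB_SIZE) and EOS (== SPEECH_VOCAB_SIZE + 1) lie outside the valid
+ # range, so this filter drops them along with any other out-of-vocab IDs.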
+ return [row[row < SPEECH_VOCAB_SIZE] for row in x]
+
+
class S3Tokenizer(S3TokenizerV2):
"""
s3tokenizer.S3TokenizerV2 with the following changes:
- a more integrated `forward`
- compute `log_mel_spectrogram` using `_mel_filters` and `window` in `register_buffers`
+ - robust dtype handling for mixed precision support
+ - corrected batch normalization in log_mel_spectrogram
"""
ignore_state_dict_missing = ("_mel_filters", "window")
@@ -41,9 +56,12 @@ def __init__(
n_fft=self.n_fft,
n_mels=config.n_mels
)
+
+ # Buffers are registered in FP32, but as registered buffers they adopt
+ # the module's dtype whenever .to(dtype) is called.
self.register_buffer(
"_mel_filters",
- torch.FloatTensor(_mel_filters),
+ torch.from_numpy(_mel_filters).float(),
)
self.register_buffer(
@@ -51,6 +69,14 @@ def __init__(
torch.hann_window(self.n_fft),
)
+ @property
+ def dtype(self):
+ # Dynamically determine the dtype from the registered `window` buffer
+ try:
+ return self.window.dtype
+ except AttributeError:
+ # Buffer not registered yet; default to FP32
+ return torch.float32
+
def pad(self, wavs, sr) -> List[torch.Tensor]:
"""
Given a list of wavs with the same `sample_rate`, pad them so that the length is multiple of 40ms (S3 runs at 25 token/sec).
@@ -119,7 +145,9 @@ def forward(
else:
tokenizer = accelerator.unwrap_model(self)
- speech_tokens, speech_token_lens = tokenizer.quantize(mels, mel_lens.to(self.device))
+ # The parent S3TokenizerV2's quantize method is sensitive to dtype.
+ # For stability, we ensure the input is float32, creating an "FP32 island".
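+ # Note: `.float()` casts the tokenizer's parameters and buffers in place (and returns self),
+ # so the tokenizer remains FP32 after this call.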
+ speech_tokens, speech_token_lens = tokenizer.float().quantize(mels.float(), mel_lens.to(self.device))
return (
speech_tokens.long().detach(),
speech_token_lens.long().detach(),
@@ -150,19 +178,44 @@ def log_mel_spectrogram(
if not torch.is_tensor(audio):
audio = torch.from_numpy(audio)
- audio = audio.to(self.device)
+ # `torch.stft` requires FP32 input. Cast audio to float32 here.
+ audio = audio.to(self.device).float()
+
if padding > 0:
audio = F.pad(audio, (0, padding))
stft = torch.stft(
audio, self.n_fft, S3_HOP,
- window=self.window.to(self.device),
+ window=self.window,
return_complex=True
)
magnitudes = stft[..., :-1].abs()**2
- mel_spec = self._mel_filters.to(self.device) @ magnitudes
+ # Cast magnitudes to the target dtype BEFORE matrix multiplication
+ # This ensures the matrix multiplication is done in the desired precision (e.g., BF16).
+ magnitudes = magnitudes.to(self.dtype)
+
+ # This matrix multiplication now works because both operands are the same dtype.
+ mel_spec = self._mel_filters @ magnitudes
log_spec = torch.clamp(mel_spec, min=1e-10).log10()
- log_spec = torch.maximum(log_spec, log_spec.max() - 8.0)
+
+ # Robustness Enhancement: Handle potential instability in mixed precision
+ if not torch.isfinite(log_spec).all():
+ print("Warning: Numerical instability detected in log_mel_spectrogram (BF16/FP16). Falling back to FP32 calculation for this step.")
+ mel_spec_fp32 = mel_spec.float()
+ log_spec = torch.clamp(mel_spec_fp32, min=1e-10).log10()
+ log_spec = log_spec.to(self.dtype)
+
+ # Functional Fix: Correct normalization for batched input
+ B = log_spec.shape[0]
+ if B > 1:
+ # Flatten feature/time dimensions to find max per batch item
+ max_vals = log_spec.view(B, -1).max(dim=1, keepdim=True)[0].view(B, 1, 1)
+ else:
+ # Optimization for single input
+ max_vals = log_spec.max()
+
+ log_spec = torch.maximum(log_spec, max_vals - 8.0)
+
log_spec = (log_spec + 4.0) / 4.0
- return log_spec
+ return log_spec
\ No newline at end of file
diff --git a/src/chatterbox/models/t3/inference/alignment_stream_analyzer.py b/src/chatterbox/models/t3/inference/alignment_stream_analyzer.py
index d3a144f0..03115eb5 100644
--- a/src/chatterbox/models/t3/inference/alignment_stream_analyzer.py
+++ b/src/chatterbox/models/t3/inference/alignment_stream_analyzer.py
@@ -1,3 +1,5 @@
+#chatterbox/src/chatterbox/models/t3/inference/alignment_stream_analyzer.py
+
# Copyright (c) 2025 Resemble AI
# Author: John Meade, Jeremy Hsu
# MIT License
@@ -10,6 +12,9 @@
logger = logging.getLogger(__name__)
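+# (layer_idx, head_idx) pairs of attention heads whose maps are averaged below
+# to estimate text-to-speech alignment during generation.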
+LLAMA_ALIGNED_HEADS = [(12, 15), (13, 11), (9, 2)]
+
+
@dataclass
class AlignmentAnalysisResult:
# was this frame detected as being part of a noisy beginning chunk with potential hallucinations?
@@ -49,21 +54,22 @@ def __init__(self, tfmr, queue, text_tokens_slice, alignment_layer_idx=9, eos_id
self.complete = False
self.completed_at = None
+
+ # Track generated tokens for repetition detection
+ self.generated_tokens = []
# Using `output_attentions=True` is incompatible with optimized attention kernels, so
# using it for all layers slows things down too much. We can apply it to just one layer
# by intercepting the kwargs and adding a forward hook (credit: jrm)
- self.last_aligned_attn = None
- self._add_attention_spy(tfmr, alignment_layer_idx)
+ self.last_aligned_attns = []
+ for i, (layer_idx, head_idx) in enumerate(LLAMA_ALIGNED_HEADS):
+ self.last_aligned_attns += [None]
+ self._add_attention_spy(tfmr, i, layer_idx, head_idx)
- def _add_attention_spy(self, tfmr, alignment_layer_idx):
+ def _add_attention_spy(self, tfmr, buffer_idx, layer_idx, head_idx):
"""
Adds a forward hook to a specific attention layer to collect outputs.
- Using `output_attentions=True` is incompatible with optimized attention kernels, so
- using it for all layers slows things down too much.
- (credit: jrm)
"""
-
def attention_forward_hook(module, input, output):
"""
See `LlamaAttention.forward`; the output is a 3-tuple: `attn_output, attn_weights, past_key_value`.
@@ -71,27 +77,23 @@ def attention_forward_hook(module, input, output):
- When `output_attentions=True`, `LlamaSdpaAttention.forward` calls `LlamaAttention.forward`.
- `attn_output` has shape [B, H, T0, T0] for the 0th entry, and [B, H, 1, T0+i] for the rest i-th.
"""
- step_attention = output[1].cpu() # (B, 16, N, N)
- self.last_aligned_attn = step_attention[0].mean(0) # (N, N)
-
- target_layer = tfmr.layers[alignment_layer_idx].self_attn
- hook_handle = target_layer.register_forward_hook(attention_forward_hook)
-
- # Backup original forward
- original_forward = target_layer.forward
- def patched_forward(self, *args, **kwargs):
- kwargs['output_attentions'] = True
- return original_forward(*args, **kwargs)
-
- # TODO: how to unpatch it?
- target_layer.forward = MethodType(patched_forward, target_layer)
-
- def step(self, logits):
+ if isinstance(output, tuple) and len(output) > 1 and output[1] is not None:
+ step_attention = output[1].cpu() # (B, n_heads, T0, Ti)
+ self.last_aligned_attns[buffer_idx] = step_attention[0, head_idx] # (T0, Ti)
+
+ target_layer = tfmr.layers[layer_idx].self_attn
+ # Register the hook (the handle is not kept; the hook stays attached for the model's lifetime)
+ target_layer.register_forward_hook(attention_forward_hook)
+ if hasattr(tfmr, 'config') and hasattr(tfmr.config, 'output_attentions'):
+ self.original_output_attentions = tfmr.config.output_attentions
+ tfmr.config.output_attentions = True
+
+ def step(self, logits, next_token=None):
"""
Emits an AlignmentAnalysisResult into the output queue, and potentially modifies the logits to force an EOS.
"""
# extract approximate alignment matrix chunk (1 frame at a time after the first chunk)
- aligned_attn = self.last_aligned_attn # (N, N)
+ aligned_attn = torch.stack(self.last_aligned_attns).mean(dim=0) # (N, N)
i, j = self.text_tokens_slice
if self.curr_frame_pos == 0:
# first chunk has conditioning info, text tokens, and BOS token
@@ -133,22 +135,46 @@ def step(self, logits):
last_text_token_duration = A[15:, -3:].sum()
# Activations for the final token that last too long are likely hallucinations.
- long_tail = self.complete and (A[self.completed_at:, -3:].sum(dim=0).max() >= 10) # 400ms
+ long_tail = self.complete and (A[self.completed_at:, -3:].sum(dim=0).max() >= 5) # 200ms
# If there are activations in previous tokens after generation has completed, assume this is a repetition error.
- repetition = self.complete and (A[self.completed_at:, :-5].max(dim=1).values.sum() > 5)
+ alignment_repetition = self.complete and (A[self.completed_at:, :-5].max(dim=1).values.sum() > 5)
+
+ # Track generated tokens for repetition detection
+ if next_token is not None:
+ # Convert tensor to scalar if needed
+ if isinstance(next_token, torch.Tensor):
+ token_id = next_token.item() if next_token.numel() == 1 else next_token.view(-1)[0].item()
+ else:
+ token_id = next_token
+ self.generated_tokens.append(token_id)
+
+ # Keep only last 8 tokens to prevent memory issues
+ if len(self.generated_tokens) > 8:
+ self.generated_tokens = self.generated_tokens[-8:]
+
+ # Check for consecutive token repetition (same token emitted in the last two steps)
+ token_repetition = (
+ # self.complete and
+ len(self.generated_tokens) >= 3 and
+ len(set(self.generated_tokens[-2:])) == 1
+ )
+
+ if token_repetition:
+ repeated_token = self.generated_tokens[-1]
+ logger.warning(f"🚨 Detected 2x repetition of token {repeated_token}")
+
+ # Suppress EoS to prevent early termination
+ if cur_text_posn < S - 3 and S > 5: # Only suppress if text is longer than 5 tokens
+ logits[..., self.eos_idx] = -2**15
# If a bad ending is detected, force emit EOS by modifying logits
# NOTE: this means logits may be inconsistent with latents!
- if long_tail or repetition:
- logger.warn(f"forcing EOS token, {long_tail=}, {repetition=}")
+ if long_tail or alignment_repetition or token_repetition:
+ logger.warning(f"forcing EOS token, {long_tail=}, {alignment_repetition=}, {token_repetition=}")
# (±2**15 is safe for all dtypes >= 16bit)
logits = -(2**15) * torch.ones_like(logits)
logits[..., self.eos_idx] = 2**15
- # Suppress EoS to prevent early termination
- if cur_text_posn < S - 3: # FIXME: arbitrary
- logits[..., self.eos_idx] = -2**15
-
self.curr_frame_pos += 1
return logits
diff --git a/src/chatterbox/models/t3/inference/t3_hf_backend.py b/src/chatterbox/models/t3/inference/t3_hf_backend.py
index 69a6bf20..031b1a12 100644
--- a/src/chatterbox/models/t3/inference/t3_hf_backend.py
+++ b/src/chatterbox/models/t3/inference/t3_hf_backend.py
@@ -1,3 +1,4 @@
+#chatterbox/src/chatterbox/models/t3/inference/t3_hf_backend.py
from typing import Optional
import torch
diff --git a/src/chatterbox/models/t3/llama_configs.py b/src/chatterbox/models/t3/llama_configs.py
index 14d06816..fa3cfbba 100644
--- a/src/chatterbox/models/t3/llama_configs.py
+++ b/src/chatterbox/models/t3/llama_configs.py
@@ -1,3 +1,15 @@
+#chatterbox/src/chatterbox/models/t3/llama_configs.py
+
+# Detect Flash Attention 2
+try:
+ import flash_attn
+ ATTN_IMPLEMENTATION = "flash_attention_2"
+ print("Successfully detected Flash Attention 2. It will be used.")
+except ImportError:
+ # Fallback to PyTorch's built-in optimized attention
+ ATTN_IMPLEMENTATION = "sdpa"
+ print("Flash Attention 2 not found. Falling back to SDPA. Install 'flash-attn' for potentially better performance.")
+
LLAMA_520M_CONFIG_DICT = dict(
# Arbitrary small number that won't cause problems when loading.
# These param are unused due to custom input layers.
@@ -8,7 +20,7 @@
intermediate_size=4096,
num_hidden_layers=30,
num_attention_heads=16,
- attn_implementation="sdpa",
+ attn_implementation=ATTN_IMPLEMENTATION, # Use the detected implementation
head_dim=64,
tie_word_embeddings=False,
hidden_act="silu",
@@ -28,10 +40,10 @@
rope_type="llama3"
),
rope_theta=500000.0,
- torch_dtype="bfloat16",
+ torch_dtype="bfloat16", # This is informational; runtime dtype is set during loading.
use_cache=True,
)
LLAMA_CONFIGS = {
"Llama_520M": LLAMA_520M_CONFIG_DICT,
-}
+}
\ No newline at end of file
diff --git a/src/chatterbox/models/t3/modules/cond_enc.py b/src/chatterbox/models/t3/modules/cond_enc.py
index b5f15c68..2f2a17b7 100644
--- a/src/chatterbox/models/t3/modules/cond_enc.py
+++ b/src/chatterbox/models/t3/modules/cond_enc.py
@@ -1,3 +1,5 @@
+#chatterbox/src/chatterbox/models/t3/modules/cond_enc.py
+
from dataclasses import dataclass
from typing import Optional
@@ -25,8 +27,12 @@ def to(self, *, device=None, dtype=None):
"Cast to a device and dtype. Dtype casting is ignored for long/int tensors."
for k, v in self.__dict__.items():
if torch.is_tensor(v):
- is_fp = type(v.view(-1)[0].item()) is not int
- setattr(self, k, v.to(device=device, dtype=dtype if is_fp else None))
+ # Use torch.is_floating_point() for a robust check (safe for empty tensors)
+ if v.is_floating_point() and dtype is not None:
+ setattr(self, k, v.to(device=device, dtype=dtype))
+ elif device is not None:
+ # Only move to device if dtype casting is not applicable or not requested
+ setattr(self, k, v.to(device=device))
return self
def save(self, fpath):
diff --git a/src/chatterbox/models/t3/modules/learned_pos_emb.py b/src/chatterbox/models/t3/modules/learned_pos_emb.py
index 9b197f21..536b9d6f 100644
--- a/src/chatterbox/models/t3/modules/learned_pos_emb.py
+++ b/src/chatterbox/models/t3/modules/learned_pos_emb.py
@@ -1,3 +1,5 @@
+#chatterbox/src/chatterbox/models/t3/modules/learned_pos_emb.py
+
from typing import Union
import torch
diff --git a/src/chatterbox/models/t3/modules/perceiver.py b/src/chatterbox/models/t3/modules/perceiver.py
index be9c5b86..99e562d1 100644
--- a/src/chatterbox/models/t3/modules/perceiver.py
+++ b/src/chatterbox/models/t3/modules/perceiver.py
@@ -1,3 +1,5 @@
+#chatterbox/src/chatterbox/models/t3/modules/perceiver.py
+
# Copyright (c) 2025 Resemble AI
# Author: Manmay Nakhashi
# MIT License
diff --git a/src/chatterbox/models/t3/modules/t3_config.py b/src/chatterbox/models/t3/modules/t3_config.py
index 2769d835..81a46583 100644
--- a/src/chatterbox/models/t3/modules/t3_config.py
+++ b/src/chatterbox/models/t3/modules/t3_config.py
@@ -1,27 +1,43 @@
+#chatterbox/src/chatterbox/models/t3/modules/t3_config.py
+
from ..llama_configs import LLAMA_CONFIGS
class T3Config:
- start_text_token = 255
- stop_text_token = 0
- text_tokens_dict_size = 704
- max_text_tokens = 2048
-
- start_speech_token = 6561
- stop_speech_token = 6562
- speech_tokens_dict_size = 8194
- max_speech_tokens = 4096
-
- llama_config_name = "Llama_520M"
- input_pos_emb = "learned"
- speech_cond_prompt_len = 150
-
- # For T3CondEnc
- encoder_type = "voice_encoder"
- speaker_embed_size = 256
- use_perceiver_resampler = True
- emotion_adv = True
+ def __init__(self, text_tokens_dict_size=704):
+ self.start_text_token = 255
+ self.stop_text_token = 0
+ self.text_tokens_dict_size = text_tokens_dict_size
+ self.max_text_tokens = 2048
+
+ self.start_speech_token = 6561
+ self.stop_speech_token = 6562
+ self.speech_tokens_dict_size = 8194
+ self.max_speech_tokens = 4096
+
+ self.llama_config_name = "Llama_520M"
+ self.input_pos_emb = "learned"
+ self.speech_cond_prompt_len = 150
+
+ self.encoder_type = "voice_encoder"
+ self.speaker_embed_size = 256
+ self.use_perceiver_resampler = True
+ self.emotion_adv = True
@property
def n_channels(self):
return LLAMA_CONFIGS[self.llama_config_name]["hidden_size"]
+
+ @property
+ def is_multilingual(self):
+ return self.text_tokens_dict_size == 2454
+
+ @classmethod
+ def english_only(cls):
+ """Create configuration for English-only TTS model."""
+ return cls(text_tokens_dict_size=704)
+
+ @classmethod
+ def multilingual(cls):
+ """Create configuration for multilingual TTS model."""
+ return cls(text_tokens_dict_size=2454)
\ No newline at end of file
diff --git a/src/chatterbox/models/t3/t3.py b/src/chatterbox/models/t3/t3.py
index 4984470c..93d2babd 100644
--- a/src/chatterbox/models/t3/t3.py
+++ b/src/chatterbox/models/t3/t3.py
@@ -1,14 +1,18 @@
+#chatterbox/src/chatterbox/models/t3/t3.py
+
# Copyright (c) 2025 Resemble AI
# MIT License
import logging
from typing import Union, Optional, List
+logger = logging.getLogger(__name__)
+
from tqdm import tqdm
import torch
import torch.nn.functional as F
from torch import nn, Tensor
from transformers import LlamaModel, LlamaConfig
-from transformers.generation.logits_process import MinPLogitsWarper, RepetitionPenaltyLogitsProcessor, TopPLogitsWarper
+from transformers.generation.logits_process import TopPLogitsWarper, RepetitionPenaltyLogitsProcessor, MinPLogitsWarper
from .modules.learned_pos_emb import LearnedPositionEmbeddings
@@ -16,6 +20,7 @@
from .modules.t3_config import T3Config
from .llama_configs import LLAMA_CONFIGS
from .inference.t3_hf_backend import T3HuggingfaceBackend
+from .inference.alignment_stream_analyzer import AlignmentStreamAnalyzer
from ..utils import AttrDict
@@ -38,11 +43,33 @@ class T3(nn.Module):
different PE embedding space for speech.
"""
- def __init__(self, hp=T3Config()):
+ def __init__(self, hp=None, dtype=torch.float32):
+ if hp is None:
+ hp = T3Config.english_only() # Default to English-only config for backward compatibility
super().__init__()
self.hp = hp
- self.cfg = LlamaConfig(**LLAMA_CONFIGS[hp.llama_config_name])
- self.tfmr = LlamaModel(self.cfg)
+
+ config_dict = LLAMA_CONFIGS[hp.llama_config_name].copy()
+ # Note: We no longer need to set torch_dtype in the config, as it's not used for scratch initialization.
+ # config_dict['torch_dtype'] = dtype
+ self.cfg = LlamaConfig(**config_dict)
+
+ # --- DEFINITIVE FIX for Flash Attention Initialization ---
+ # The LlamaModel constructor does not respect config.torch_dtype for scratch initialization;
+ # it uses the global default. We temporarily set the global default dtype to our target dtype
+ # to ensure the model's layers are created in BF16 from the start.
+ original_dtype = torch.get_default_dtype()
+ try:
+ if dtype in [torch.float16, torch.bfloat16]:
+ torch.set_default_dtype(dtype)
+
+ self.tfmr = LlamaModel(self.cfg)
+
+ finally:
+ # Always restore the original default dtype to avoid side-effects elsewhere.
+ torch.set_default_dtype(original_dtype)
+ # --- End of Fix ---
+
self.dim = self.cfg.hidden_size
self.deepspeed_patch_applied = False
@@ -89,7 +116,8 @@ def prepare_input_embeds(
cond_emb = self.prepare_conditioning(t3_cond) # (B, len_cond, dim)
text_emb = self.text_emb(text_tokens) # (B, len_text, dim)
if cfg_weight > 0.0:
- text_emb[1].zero_() # CFG uncond
+ B = text_tokens.size(0) // 2
+ text_emb[B:].zero_() # CFG uncond
speech_emb = self.speech_emb(speech_tokens) # (B, len_speech, dim)
if self.hp.input_pos_emb == "learned":
@@ -101,10 +129,7 @@ def prepare_input_embeds(
cond_emb = cond_emb.expand(text_emb.size(0), -1, -1)
# concat
- embeds = torch.stack([
- torch.cat((ce, te, se))
- for ce, te, se in zip(cond_emb, text_emb, speech_emb)
- ]) # (B, length, dim)
+ embeds = torch.cat([cond_emb, text_emb, speech_emb], dim=1) # (B, length, dim)
return embeds, len_cond
def forward(
@@ -200,6 +225,7 @@ def loss(
return loss_text, loss_speech
+
@torch.inference_mode()
def inference(
self,
@@ -207,36 +233,49 @@ def inference(
t3_cond: T3Cond,
text_tokens: Tensor,
initial_speech_tokens: Optional[Tensor]=None,
-
# misc conditioning
prepend_prompt_speech_tokens: Optional[Tensor]=None,
-
# HF generate args
num_return_sequences=1,
max_new_tokens=None,
stop_on_eos=True,
do_sample=True,
temperature=0.8,
+ top_p=0.95,
min_p=0.05,
- top_p=1.00,
length_penalty=1.0,
repetition_penalty=1.2,
- cfg_weight=0,
+ cfg_weight=0.5,
):
"""
- Args:
- text_tokens: a 1D (unbatched) or 2D (batched) tensor.
+ Optimized batched inference with CFG support and early stopping.
"""
# Validate / sanitize inputs
assert prepend_prompt_speech_tokens is None, "not implemented"
_ensure_BOT_EOT(text_tokens, self.hp)
text_tokens = torch.atleast_2d(text_tokens).to(dtype=torch.long, device=self.device)
+
+ # Determine original batch size (B_orig) before CFG duplication
+ B_total = text_tokens.size(0)
+ if cfg_weight > 0.0:
+ # Inputs are expected to be already duplicated if CFG is used (handled in ChatterboxTTS.generate)
+ if B_total % 2 != 0:
+ raise ValueError("Batch size must be even when using CFG due to input duplication.")
+ B_orig = B_total // 2
+ else:
+ B_orig = B_total
+
+ max_new_tokens = max_new_tokens or self.hp.max_speech_tokens
# Default initial speech to a single start-of-speech token
if initial_speech_tokens is None:
- initial_speech_tokens = self.hp.start_speech_token * torch.ones_like(text_tokens[:, :1])
+ # Use B_total here as initial_speech_tokens must cover the full batch (cond + uncond)
+ initial_speech_tokens = self.hp.start_speech_token * torch.ones((B_total, 1), dtype=torch.long, device=self.device)
+
+ len_initial_speech = initial_speech_tokens.size(1)
- # Prepare custom input embeds
+ # Prepare custom input embeds [Cond, Text, InitialSpeech]
+ # This correctly prepares the embeddings including positional information for the prefill phase.
embeds, len_cond = self.prepare_input_embeds(
t3_cond=t3_cond,
text_tokens=text_tokens,
@@ -244,66 +283,41 @@ def inference(
cfg_weight=cfg_weight,
)
- # In order to use the standard HF generate method, we need to extend some methods to inject our custom logic
- # Note the llama-specific logic. Other tfmr types can be added later.
+ # Initialize the Huggingface backend if not already done
+ if not self.compiled:
+ # Default to None for English models, only create for multilingual
+ alignment_stream_analyzer = None
+ if self.hp.is_multilingual:
+ # NOTE: Ensure AlignmentStreamAnalyzer supports batching if required.
+ logger.info("Initializing AlignmentStreamAnalyzer for multilingual support.")
+
+ # Assuming AlignmentStreamAnalyzer is imported correctly in the file scope
+ try:
+ alignment_stream_analyzer = AlignmentStreamAnalyzer(
+ self.tfmr,
+ None,
+ text_tokens_slice=(len_cond, len_cond + text_tokens.size(-1)),
+ alignment_layer_idx=9,
+ eos_idx=self.hp.stop_speech_token,
+ )
+ except Exception as e:
+ logger.warning(f"AlignmentStreamAnalyzer initialization failed: {e}. Multilingual stability might be affected.")
- self.compiled = False
- # TODO? synchronize the expensive compile function
- # with self.compile_lock:
- if not self.compiled:
patched_model = T3HuggingfaceBackend(
config=self.cfg,
llama=self.tfmr,
speech_enc=self.speech_emb,
speech_head=self.speech_head,
- alignment_stream_analyzer=None,
+ alignment_stream_analyzer=alignment_stream_analyzer,
)
self.patched_model = patched_model
self.compiled = True
- # # Run normal generate method, which calls our custom extended methods
- # return self.patched_model.generate(
- # inputs=initial_speech_tokens,
- # decoder_cond=embeds,
- # bos_token_id=self.hp.start_speech_token,
- # eos_token_id=(self.hp.stop_speech_token if stop_on_eos else -1),
- # pad_token_id=self.hp.stop_speech_token,
- # max_new_tokens=max_new_tokens or self.hp.max_speech_tokens,
- # num_return_sequences=num_return_sequences,
- # temperature=temperature,
- # min_p=min_p,
- # length_penalty=length_penalty,
- # repetition_penalty=repetition_penalty,
- # do_sample=do_sample,
- # # cache_implementation=None if not self.compiled else "static",
- # )
-
+ inputs_embeds = embeds
device = embeds.device
- bos_token = torch.tensor([[self.hp.start_speech_token]], dtype=torch.long, device=device)
- bos_embed = self.speech_emb(bos_token) # shape: (B, 1, embed_dim)
- bos_embed = bos_embed + self.speech_pos_emb.get_fixed_embedding(0)
-
- # batch_size=2 for CFG
- bos_embed = torch.cat([bos_embed, bos_embed])
-
- # Combine condition and BOS token for the initial input if cfg_weight > 0
- if cfg_weight > 0:
- inputs_embeds = torch.cat([embeds, bos_embed], dim=1)
- else:
- inputs_embeds = embeds
-
- # Track generated token ids; start with the BOS token.
- generated_ids = bos_token.clone()
- predicted = [] # To store the predicted tokens
-
- # Instantiate the logits processors.
- min_p_warper = MinPLogitsWarper(min_p=min_p)
- top_p_warper = TopPLogitsWarper(top_p=top_p)
- repetition_penalty_processor = RepetitionPenaltyLogitsProcessor(penalty=float(repetition_penalty))
-
- # ---- Initial Forward Pass (no kv_cache yet) ----
+ # ---- Initial Forward Pass (Prefill/Prompt Processing) ----
output = self.patched_model(
inputs_embeds=inputs_embeds,
past_key_values=None,
@@ -315,45 +329,101 @@ def inference(
# Initialize kv_cache with the full context.
past = output.past_key_values
- # ---- Generation Loop using kv_cache ----
- for i in tqdm(range(max_new_tokens), desc="Sampling", dynamic_ncols=True):
- logits = output.logits[:, -1, :]
+ # ---- Generation Loop (Decoding) using kv_cache ----
- # CFG
- if cfg_weight > 0.0:
- logits_cond = logits[0:1]
- logits_uncond = logits[1:2]
- logits = logits_cond + cfg_weight * (logits_cond - logits_uncond)
+ # Instantiate the logits processors.
+ top_p_warper = TopPLogitsWarper(top_p=top_p)
+ min_p_warper = MinPLogitsWarper(min_p=min_p)
+ repetition_penalty_processor = RepetitionPenaltyLogitsProcessor(penalty=float(repetition_penalty))
- logits = logits.squeeze(1)
+ # Pre-allocate tensor for generated IDs (only the conditional part, B_orig)
+ # We use the stop token as the padding value.
+ max_len = len_initial_speech + max_new_tokens
+ generated_ids_cond = torch.full((B_orig, max_len), self.hp.stop_speech_token, dtype=torch.long, device=device)
+ # Copy the initial tokens (e.g., BOS)
+ generated_ids_cond[:, :len_initial_speech] = initial_speech_tokens[:B_orig, :]
+ # Track which sequences are finished
+ is_finished = torch.zeros(B_orig, dtype=torch.bool, device=device)
+ current_token_idx = len_initial_speech
+
+ for i in tqdm(range(max_new_tokens), desc="Sampling", dynamic_ncols=True):
+ # Get the logits for the last time step
+ logits_step = output.logits[:, -1, :] # (B_total, V)
+
+ # --- CFG combination ---
+ if cfg_weight > 0.0:
+ # Split the logits (B_total, V) into conditional and unconditional parts
+ cond, uncond = torch.split(logits_step, B_orig, dim=0)
+ cfg = torch.as_tensor(cfg_weight, device=cond.device, dtype=cond.dtype)
+ # Combine using CFG formula
+ logits = cond + cfg * (cond - uncond) # (B_orig, V)
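+ # cfg_weight = 0 reduces to the conditional logits; larger values push the
+ # distribution further away from the unconditional branch.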
+ else:
+ logits = logits_step # (B_orig, V)
+
+
+ # --- Apply Logit Processors and Sampling ---
+
+ # Apply alignment stream analyzer (if present, assumes it handles batching)
+ if self.patched_model.alignment_stream_analyzer is not None:
+ # We omit passing the last token here; the analyzer must rely on internal state if needed.
+ # This might need refinement for robust multilingual batched generation.
+ logits = self.patched_model.alignment_stream_analyzer.step(logits, next_token=None)
+
+ # Prepare inputs for processors: generated IDs up to the current step
+ ids_for_proc = generated_ids_cond[:, :current_token_idx]
+
+ # Apply repetition penalty
+ logits = repetition_penalty_processor(ids_for_proc, logits) # expects (B_orig, V)
+
# Apply temperature scaling.
if temperature != 1.0:
logits = logits / temperature
-
- # Apply repetition penalty and top‑p filtering.
- logits = repetition_penalty_processor(generated_ids, logits)
- logits = min_p_warper(None, logits)
- logits = top_p_warper(None, logits)
+
+ # Apply min_p and top_p filtering
+ logits = min_p_warper(ids_for_proc, logits)
+ logits = top_p_warper(ids_for_proc, logits)
# Convert logits to probabilities and sample the next token.
- probs = torch.softmax(logits, dim=-1)
- next_token = torch.multinomial(probs, num_samples=1) # shape: (B, 1)
+ # Use FP32 for softmax for better numerical stability when using half-precision.
+ probs = torch.softmax(logits.float(), dim=-1)
+ next_token = torch.multinomial(probs, num_samples=1) # shape: (B_orig, 1)
- predicted.append(next_token)
- generated_ids = torch.cat([generated_ids, next_token], dim=1)
+ # --- Update Generation State and Check Stopping Condition ---
- # Check for EOS token.
- if next_token.view(-1) == self.hp.stop_speech_token:
+ # If a sequence is finished, force the generation of the EOS token (padding)
+ next_token = torch.where(
+ is_finished.unsqueeze(1),
+ torch.tensor(self.hp.stop_speech_token, device=device, dtype=torch.long),
+ next_token
+ )
+
+ # Update the pre-allocated tensor in-place
+ generated_ids_cond[:, current_token_idx] = next_token.squeeze(-1)
+ current_token_idx += 1
+
+ # Update finished status (check if the newly generated token is EOS)
+ is_finished = is_finished | (next_token.squeeze(-1) == self.hp.stop_speech_token)
+
+ # Check if all sequences are finished (Early Stopping)
+ if is_finished.all():
+ logger.info(f"✅ All sequences finished. Stopping generation at step {i+1}")
break
+ # --- Prepare for Next Step ---
+
# Get embedding for the new token.
next_token_embed = self.speech_emb(next_token)
- next_token_embed = next_token_embed + self.speech_pos_emb.get_fixed_embedding(i + 1)
+
+ # Apply positional embedding for the next step
+ if self.hp.input_pos_emb == "learned":
+ # The position index is len_initial_speech + i
+ next_token_embed = next_token_embed + self.speech_pos_emb.get_fixed_embedding(len_initial_speech + i)
- # For CFG
+ # For CFG, duplicate the embeddings for the full batch (B_total)
if cfg_weight > 0.0:
- next_token_embed = torch.cat([next_token_embed, next_token_embed])
+ # Unconditional part uses the same embeddings as the conditional part for the next step input
+ next_token_embed = torch.cat([next_token_embed, next_token_embed], dim=0)
# Forward pass with only the new token and the cached past.
output = self.patched_model(
@@ -366,6 +436,15 @@ def inference(
# Update the kv_cache.
past = output.past_key_values
- # Concatenate all predicted tokens along the sequence dimension.
- predicted_tokens = torch.cat(predicted, dim=1) # shape: (B, num_tokens)
- return predicted_tokens
+ # Un-batch and trim EOS token
+ output_tokens = []
+ for gen_ids in generated_ids_cond:
+ # Remove initial speech tokens that were passed in
+ gen_ids = gen_ids[len_initial_speech:]
+ # Find the first EOS token and trim
+ eos_idx = (gen_ids == self.hp.stop_speech_token).nonzero(as_tuple=True)[0]
+ if len(eos_idx) > 0:
+ gen_ids = gen_ids[:eos_idx[0]]
+ output_tokens.append(gen_ids)
+
+ return output_tokens
\ No newline at end of file
diff --git a/src/chatterbox/models/tokenizers/__init__.py b/src/chatterbox/models/tokenizers/__init__.py
index 97457e6f..fdf6d727 100644
--- a/src/chatterbox/models/tokenizers/__init__.py
+++ b/src/chatterbox/models/tokenizers/__init__.py
@@ -1 +1 @@
-from .tokenizer import EnTokenizer
+from .tokenizer import EnTokenizer, MTLTokenizer
\ No newline at end of file
diff --git a/src/chatterbox/models/tokenizers/tokenizer.py b/src/chatterbox/models/tokenizers/tokenizer.py
index f3536bc2..15641a03 100644
--- a/src/chatterbox/models/tokenizers/tokenizer.py
+++ b/src/chatterbox/models/tokenizers/tokenizer.py
@@ -1,7 +1,13 @@
+#chatterbox/src/chatterbox/models/tokenizers/tokenizer.py
+
import logging
+import json
import torch
+from pathlib import Path
+from unicodedata import category, normalize
from tokenizers import Tokenizer
+from huggingface_hub import hf_hub_download
# Special tokens
@@ -28,7 +34,7 @@ def text_to_tokens(self, text: str):
text_tokens = torch.IntTensor(text_tokens).unsqueeze(0)
return text_tokens
- def encode( self, txt: str, verbose=False):
+ def encode(self, txt: str):
"""
clean_text > (append `lang_id`) > replace SPACE > encode text using Tokenizer
"""
@@ -41,10 +47,269 @@ def decode(self, seq):
if isinstance(seq, torch.Tensor):
seq = seq.cpu().numpy()
- txt: str = self.tokenizer.decode(seq,
- skip_special_tokens=False)
+ txt: str = self.tokenizer.decode(seq, skip_special_tokens=False)
txt = txt.replace(' ', '')
txt = txt.replace(SPACE, ' ')
txt = txt.replace(EOT, '')
txt = txt.replace(UNK, '')
return txt
+
+
+# Model repository
+REPO_ID = "ResembleAI/chatterbox"
+
+# Global instances for optional dependencies
+_kakasi = None
+_dicta = None
+_russian_stresser = None
+
+
+def is_kanji(c: str) -> bool:
+ """Check if character is kanji."""
+ return 19968 <= ord(c) <= 40959
+
+
+def is_katakana(c: str) -> bool:
+ """Check if character is katakana."""
+ return 12449 <= ord(c) <= 12538
+
+
+def hiragana_normalize(text: str) -> str:
+ """Japanese text normalization: converts kanji to hiragana; katakana remains the same."""
+ global _kakasi
+
+ try:
+ if _kakasi is None:
+ import pykakasi
+ _kakasi = pykakasi.kakasi()
+
+ result = _kakasi.convert(text)
+ out = []
+
+ for r in result:
+ inp = r['orig']
+ hira = r["hira"]
+
+ # Any kanji in the phrase
+ if any([is_kanji(c) for c in inp]):
+ if hira and hira[0] in ["は", "へ"]: # Safety check for empty hira
+ hira = " " + hira
+ out.append(hira)
+
+ # All katakana (guard against empty `inp`)
+ elif inp and all(is_katakana(c) for c in inp):
+ out.append(r['orig'])
+
+ else:
+ out.append(inp)
+
+ normalized_text = "".join(out)
+
+ # Decompose Japanese characters (NFKD) for tokenizer compatibility;
+ # `normalize` is imported from unicodedata at module level.
+ normalized_text = normalize('NFKD', normalized_text)
+
+ return normalized_text
+
+ except ImportError:
+ logger.warning("pykakasi not available - Japanese text processing skipped")
+ return text
+
+
+def add_hebrew_diacritics(text: str) -> str:
+ """Hebrew text normalization: adds diacritics to Hebrew text."""
+ global _dicta
+
+ try:
+ if _dicta is None:
+ from dicta_onnx import Dicta
+ _dicta = Dicta()
+
+ return _dicta.add_diacritics(text)
+
+ except ImportError:
+ logger.warning("dicta_onnx not available - Hebrew text processing skipped")
+ return text
+ except Exception as e:
+ logger.warning(f"Hebrew diacritization failed: {e}")
+ return text
+
+
+def korean_normalize(text: str) -> str:
+ """Korean text normalization: decompose syllables into Jamo for tokenization."""
+
+ def decompose_hangul(char):
+ """Decompose Korean syllable into Jamo components."""
+ if not ('\uac00' <= char <= '\ud7af'):
+ return char
+
+ # Hangul decomposition formula
+ base = ord(char) - 0xAC00
+ initial = chr(0x1100 + base // (21 * 28))
+ medial = chr(0x1161 + (base % (21 * 28)) // 28)
+ final = chr(0x11A7 + base % 28) if base % 28 > 0 else ''
+
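+ # e.g. '한' (U+D55C) decomposes to ᄒ (U+1112) + ᅡ (U+1161) + ᆫ (U+11AB)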
+ return initial + medial + final
+
+ # Decompose syllables and normalize punctuation
+ result = ''.join(decompose_hangul(char) for char in text)
+ return result.strip()
+
+
+class ChineseCangjieConverter:
+ """Converts Chinese characters to Cangjie codes for tokenization."""
+
+ def __init__(self, model_dir=None):
+ self.word2cj = {}
+ self.cj2word = {}
+ self.segmenter = None
+ self._load_cangjie_mapping(model_dir)
+ self._init_segmenter()
+
+ def _load_cangjie_mapping(self, model_dir=None):
+ """Load Cangjie mapping from HuggingFace model repository."""
+ try:
+ cangjie_file = hf_hub_download(
+ repo_id=REPO_ID,
+ filename="Cangjie5_TC.json",
+ cache_dir=model_dir
+ )
+
+ with open(cangjie_file, "r", encoding="utf-8") as fp:
+ data = json.load(fp)
+
+ for entry in data:
+ word, code = entry.split("\t")[:2]
+ self.word2cj[word] = code
+ if code not in self.cj2word:
+ self.cj2word[code] = [word]
+ else:
+ self.cj2word[code].append(word)
+
+ except Exception as e:
+ logger.warning(f"Could not load Cangjie mapping: {e}")
+
+ def _init_segmenter(self):
+ """Initialize pkuseg segmenter."""
+ try:
+ from spacy_pkuseg import pkuseg
+ self.segmenter = pkuseg()
+ except ImportError:
+ logger.warning("pkuseg not available - Chinese segmentation will be skipped")
+ self.segmenter = None
+
+ def _cangjie_encode(self, glyph: str):
+ """Encode a single Chinese glyph to Cangjie code."""
+ normed_glyph = glyph
+ code = self.word2cj.get(normed_glyph, None)
+ if code is None: # e.g. Japanese hiragana
+ return None
+ index = self.cj2word[code].index(normed_glyph)
+ index = str(index) if index > 0 else ""
+ return code + index
+
+
+
+ def __call__(self, text):
+ """Convert Chinese characters in text to Cangjie tokens."""
+ output = []
+ if self.segmenter is not None:
+ segmented_words = self.segmenter.cut(text)
+ full_text = " ".join(segmented_words)
+ else:
+ full_text = text
+
+ for t in full_text:
+ if category(t) == "Lo":
+ cangjie = self._cangjie_encode(t)
+ if cangjie is None:
+ output.append(t)
+ continue
+ code = []
+ for c in cangjie:
+ code.append(f"[cj_{c}]")
+ code.append("[cj_.]")
+ code = "".join(code)
+ output.append(code)
+ else:
+ output.append(t)
+ return "".join(output)
+
+
+def add_russian_stress(text: str) -> str:
+ """Russian text normalization: adds stress marks to Russian text."""
+ global _russian_stresser
+
+ try:
+ if _russian_stresser is None:
+ from russian_text_stresser.text_stresser import RussianTextStresser
+ _russian_stresser = RussianTextStresser()
+
+ return _russian_stresser.stress_text(text)
+
+ except ImportError:
+ logger.warning("russian_text_stresser not available - Russian stress labeling skipped")
+ return text
+ except Exception as e:
+ logger.warning(f"Russian stress labeling failed: {e}")
+ return text
+
+
+class MTLTokenizer:
+ def __init__(self, vocab_file_path):
+ self.tokenizer: Tokenizer = Tokenizer.from_file(vocab_file_path)
+ model_dir = Path(vocab_file_path).parent
+ self.cangjie_converter = ChineseCangjieConverter(model_dir)
+ self.check_vocabset_sot_eot()
+
+ def check_vocabset_sot_eot(self):
+ voc = self.tokenizer.get_vocab()
+ assert SOT in voc
+ assert EOT in voc
+
+ def preprocess_text(self, raw_text: str, language_id: str = None, lowercase: bool = True, nfkd_normalize: bool = True):
+ """
+ Text preprocessor that handles lowercase conversion and NFKD normalization.
+ """
+ preprocessed_text = raw_text
+ if lowercase:
+ preprocessed_text = preprocessed_text.lower()
+ if nfkd_normalize:
+ preprocessed_text = normalize("NFKD", preprocessed_text)
+
+ return preprocessed_text
+
+ def text_to_tokens(self, text: str, language_id: str = None, lowercase: bool = True, nfkd_normalize: bool = True):
+ text_tokens = self.encode(text, language_id=language_id, lowercase=lowercase, nfkd_normalize=nfkd_normalize)
+ text_tokens = torch.IntTensor(text_tokens).unsqueeze(0)
+ return text_tokens
+
+ def encode(self, txt: str, language_id: str = None, lowercase: bool = True, nfkd_normalize: bool = True):
+ txt = self.preprocess_text(txt, language_id=language_id, lowercase=lowercase, nfkd_normalize=nfkd_normalize)
+
+ # Language-specific text processing
+ if language_id == 'zh':
+ txt = self.cangjie_converter(txt)
+ elif language_id == 'ja':
+ txt = hiragana_normalize(txt)
+ elif language_id == 'he':
+ txt = add_hebrew_diacritics(txt)
+ elif language_id == 'ko':
+ txt = korean_normalize(txt)
+ elif language_id == 'ru':
+ txt = add_russian_stress(txt)
+
+ # Prepend language token
+ if language_id:
+ txt = f"[{language_id.lower()}]{txt}"
+
+ txt = txt.replace(' ', SPACE)
+ return self.tokenizer.encode(txt).ids
+
+ def decode(self, seq):
+ if isinstance(seq, torch.Tensor):
+ seq = seq.cpu().numpy()
+
+ txt = self.tokenizer.decode(seq, skip_special_tokens=False)
+ txt = txt.replace(' ', '').replace(SPACE, ' ').replace(EOT, '').replace(UNK, '')
+ return txt
diff --git a/src/chatterbox/models/utils.py b/src/chatterbox/models/utils.py
index a4abce5d..e0e1d2ee 100644
--- a/src/chatterbox/models/utils.py
+++ b/src/chatterbox/models/utils.py
@@ -1,3 +1,5 @@
+#chatterbox/src/chatterbox/models/utils.py
+
class AttrDict(dict):
def __init__(self, *args, **kwargs):
super(AttrDict, self).__init__(*args, **kwargs)
diff --git a/src/chatterbox/models/voice_encoder/config.py b/src/chatterbox/models/voice_encoder/config.py
index 8e9782a2..b3bec4b7 100644
--- a/src/chatterbox/models/voice_encoder/config.py
+++ b/src/chatterbox/models/voice_encoder/config.py
@@ -1,3 +1,5 @@
+#chatterbox/src/chatterbox/models/voice_encoder/config.py
+
class VoiceEncConfig:
num_mels = 40
sample_rate = 16000
diff --git a/src/chatterbox/models/voice_encoder/melspec.py b/src/chatterbox/models/voice_encoder/melspec.py
index 69147fc8..ac18b474 100644
--- a/src/chatterbox/models/voice_encoder/melspec.py
+++ b/src/chatterbox/models/voice_encoder/melspec.py
@@ -1,3 +1,4 @@
+#chatterbox/src/chatterbox/models/voice_encoder/melspec.py
from functools import lru_cache
from scipy import signal
diff --git a/src/chatterbox/models/voice_encoder/voice_encoder.py b/src/chatterbox/models/voice_encoder/voice_encoder.py
index d986f17f..9bcf337d 100644
--- a/src/chatterbox/models/voice_encoder/voice_encoder.py
+++ b/src/chatterbox/models/voice_encoder/voice_encoder.py
@@ -1,3 +1,4 @@
+#chatterbox/src/chatterbox/models/voice_encoder/voice_encoder.py
# Adapted from https://github.com/CorentinJ/Real-Time-Voice-Cloning
# MIT License
from typing import List, Union, Optional
diff --git a/src/chatterbox/mtl_tts.py b/src/chatterbox/mtl_tts.py
new file mode 100644
index 00000000..2c9cf052
--- /dev/null
+++ b/src/chatterbox/mtl_tts.py
@@ -0,0 +1,301 @@
+from dataclasses import dataclass
+from pathlib import Path
+import os
+
+import librosa
+import torch
+import perth
+import torch.nn.functional as F
+from safetensors.torch import load_file as load_safetensors
+from huggingface_hub import snapshot_download
+
+from .models.t3 import T3
+from .models.t3.modules.t3_config import T3Config
+from .models.s3tokenizer import S3_SR, drop_invalid_tokens
+from .models.s3gen import S3GEN_SR, S3Gen
+from .models.tokenizers import MTLTokenizer
+from .models.voice_encoder import VoiceEncoder
+from .models.t3.modules.cond_enc import T3Cond
+
+
+REPO_ID = "ResembleAI/chatterbox"
+
+# Supported languages for the multilingual model
+SUPPORTED_LANGUAGES = {
+ "ar": "Arabic",
+ "da": "Danish",
+ "de": "German",
+ "el": "Greek",
+ "en": "English",
+ "es": "Spanish",
+ "fi": "Finnish",
+ "fr": "French",
+ "he": "Hebrew",
+ "hi": "Hindi",
+ "it": "Italian",
+ "ja": "Japanese",
+ "ko": "Korean",
+ "ms": "Malay",
+ "nl": "Dutch",
+ "no": "Norwegian",
+ "pl": "Polish",
+ "pt": "Portuguese",
+ "ru": "Russian",
+ "sv": "Swedish",
+ "sw": "Swahili",
+ "tr": "Turkish",
+ "zh": "Chinese",
+}
+
+
+def punc_norm(text: str) -> str:
+ """
+ Quick cleanup of punctuation produced by LLMs, or of
+ characters rarely seen in the training data.
+ """
+ if len(text) == 0:
+ return "You need to add some text for me to talk."
+
+ # Capitalise first letter
+ if text[0].islower():
+ text = text[0].upper() + text[1:]
+
+ # Remove multiple space chars
+ text = " ".join(text.split())
+
+ # Replace uncommon/llm punc
+ punc_to_replace = [
+ ("...", ", "),
+ ("…", ", "),
+ (":", ","),
+ (" - ", ", "),
+ (";", ", "),
+ ("—", "-"),
+ ("–", "-"),
+ (" ,", ","),
+ ("“", "\""),
+ ("”", "\""),
+ ("‘", "'"),
+ ("’", "'"),
+ ]
+ for old_char_sequence, new_char in punc_to_replace:
+ text = text.replace(old_char_sequence, new_char)
+
+ # Add full stop if no ending punc
+ text = text.rstrip(" ")
+ sentence_enders = {".", "!", "?", "-", ",", "、", ",", "。", "?", "!"}
+ if not any(text.endswith(p) for p in sentence_enders):
+ text += "."
+
+ return text
+
+
+@dataclass
+class Conditionals:
+ """
+ Conditionals for T3 and S3Gen
+ - T3 conditionals:
+ - speaker_emb
+ - clap_emb
+ - cond_prompt_speech_tokens
+ - cond_prompt_speech_emb
+ - emotion_adv
+ - S3Gen conditionals:
+ - prompt_token
+ - prompt_token_len
+ - prompt_feat
+ - prompt_feat_len
+ - embedding
+ """
+ t3: T3Cond
+ gen: dict
+
+ def to(self, device):
+ self.t3 = self.t3.to(device=device)
+ for k, v in self.gen.items():
+ if torch.is_tensor(v):
+ self.gen[k] = v.to(device=device)
+ return self
+
+ def save(self, fpath: Path):
+ arg_dict = dict(
+ t3=self.t3.__dict__,
+ gen=self.gen
+ )
+ torch.save(arg_dict, fpath)
+
+ @classmethod
+ def load(cls, fpath, map_location="cpu"):
+ kwargs = torch.load(fpath, map_location=map_location, weights_only=True)
+ return cls(T3Cond(**kwargs['t3']), kwargs['gen'])
+
+
+class ChatterboxMultilingualTTS:
+ ENC_COND_LEN = 6 * S3_SR
+ DEC_COND_LEN = 10 * S3GEN_SR
+
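+ # Minimal usage sketch (assumes a CUDA device and a reference clip at "ref.wav"):
+ #   tts = ChatterboxMultilingualTTS.from_pretrained("cuda")
+ #   wav = tts.generate("Bonjour tout le monde", language_id="fr", audio_prompt_path="ref.wav")
+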
+ def __init__(
+ self,
+ t3: T3,
+ s3gen: S3Gen,
+ ve: VoiceEncoder,
+ tokenizer: MTLTokenizer,
+ device: str,
+ conds: Conditionals = None,
+ ):
+ self.sr = S3GEN_SR # sample rate of synthesized audio
+ self.t3 = t3
+ self.s3gen = s3gen
+ self.ve = ve
+ self.tokenizer = tokenizer
+ self.device = device
+ self.conds = conds
+ self.watermarker = perth.PerthImplicitWatermarker()
+
+ @classmethod
+ def get_supported_languages(cls):
+ """Return dictionary of supported language codes and names."""
+ return SUPPORTED_LANGUAGES.copy()
+
+ @classmethod
+ def from_local(cls, ckpt_dir, device) -> 'ChatterboxMultilingualTTS':
+ ckpt_dir = Path(ckpt_dir)
+
+ ve = VoiceEncoder()
+ ve.load_state_dict(
+ torch.load(ckpt_dir / "ve.pt", weights_only=True)
+ )
+ ve.to(device).eval()
+
+ t3 = T3(T3Config.multilingual())
+ t3_state = load_safetensors(ckpt_dir / "t3_mtl23ls_v2.safetensors")
+ if "model" in t3_state.keys():
+ t3_state = t3_state["model"][0]
+ t3.load_state_dict(t3_state)
+ t3.to(device).eval()
+
+ s3gen = S3Gen()
+ s3gen.load_state_dict(
+ torch.load(ckpt_dir / "s3gen.pt", weights_only=True)
+ )
+ s3gen.to(device).eval()
+
+ tokenizer = MTLTokenizer(
+ str(ckpt_dir / "grapheme_mtl_merged_expanded_v1.json")
+ )
+
+ conds = None
+ if (builtin_voice := ckpt_dir / "conds.pt").exists():
+ conds = Conditionals.load(builtin_voice).to(device)
+
+ return cls(t3, s3gen, ve, tokenizer, device, conds=conds)
+
+ @classmethod
+ def from_pretrained(cls, device: torch.device) -> 'ChatterboxMultilingualTTS':
+ ckpt_dir = Path(
+ snapshot_download(
+ repo_id=REPO_ID,
+ repo_type="model",
+ revision="main",
+ allow_patterns=["ve.pt", "t3_mtl23ls_v2.safetensors", "s3gen.pt", "grapheme_mtl_merged_expanded_v1.json", "conds.pt", "Cangjie5_TC.json"],
+ token=os.getenv("HF_TOKEN"),
+ )
+ )
+ return cls.from_local(ckpt_dir, device)
+
+ def prepare_conditionals(self, wav_fpath, exaggeration=0.5):
+ ## Load reference wav
+ s3gen_ref_wav, _sr = librosa.load(wav_fpath, sr=S3GEN_SR)
+
+ ref_16k_wav = librosa.resample(s3gen_ref_wav, orig_sr=S3GEN_SR, target_sr=S3_SR)
+
+ s3gen_ref_wav = s3gen_ref_wav[:self.DEC_COND_LEN]
+ s3gen_ref_dict = self.s3gen.embed_ref(s3gen_ref_wav, S3GEN_SR, device=self.device)
+
+ # Speech cond prompt tokens
+ t3_cond_prompt_tokens = None
+ if plen := self.t3.hp.speech_cond_prompt_len:
+ s3_tokzr = self.s3gen.tokenizer
+ t3_cond_prompt_tokens, _ = s3_tokzr.forward([ref_16k_wav[:self.ENC_COND_LEN]], max_len=plen)
+ t3_cond_prompt_tokens = torch.atleast_2d(t3_cond_prompt_tokens).to(self.device)
+
+ # Voice-encoder speaker embedding
+ ve_embed = torch.from_numpy(self.ve.embeds_from_wavs([ref_16k_wav], sample_rate=S3_SR))
+ ve_embed = ve_embed.mean(axis=0, keepdim=True).to(self.device)
+
+ t3_cond = T3Cond(
+ speaker_emb=ve_embed,
+ cond_prompt_speech_tokens=t3_cond_prompt_tokens,
+ emotion_adv=exaggeration * torch.ones(1, 1, 1),
+ ).to(device=self.device)
+ self.conds = Conditionals(t3_cond, s3gen_ref_dict)
+
+ def generate(
+ self,
+ text,
+ language_id,
+ audio_prompt_path=None,
+ exaggeration=0.5,
+ cfg_weight=0.5,
+ temperature=0.8,
+ repetition_penalty=2.0,
+ min_p=0.05,
+ top_p=1.0,
+ ):
+ # Validate language_id
+ if language_id and language_id.lower() not in SUPPORTED_LANGUAGES:
+ supported_langs = ", ".join(SUPPORTED_LANGUAGES.keys())
+ raise ValueError(
+ f"Unsupported language_id '{language_id}'. "
+ f"Supported languages: {supported_langs}"
+ )
+
+ if audio_prompt_path:
+ self.prepare_conditionals(audio_prompt_path, exaggeration=exaggeration)
+ else:
+ assert self.conds is not None, "Please `prepare_conditionals` first or specify `audio_prompt_path`"
+
+ # Update exaggeration if needed
+ if float(exaggeration) != float(self.conds.t3.emotion_adv[0, 0, 0].item()):
+ _cond: T3Cond = self.conds.t3
+ self.conds.t3 = T3Cond(
+ speaker_emb=_cond.speaker_emb,
+ cond_prompt_speech_tokens=_cond.cond_prompt_speech_tokens,
+ emotion_adv=exaggeration * torch.ones(1, 1, 1),
+ ).to(device=self.device)
+
+ # Norm and tokenize text
+ text = punc_norm(text)
+ text_tokens = self.tokenizer.text_to_tokens(text, language_id=language_id.lower() if language_id else None).to(self.device)
+ text_tokens = torch.cat([text_tokens, text_tokens], dim=0) # Need two seqs for CFG
+
+ sot = self.t3.hp.start_text_token
+ eot = self.t3.hp.stop_text_token
+ text_tokens = F.pad(text_tokens, (1, 0), value=sot)
+ text_tokens = F.pad(text_tokens, (0, 1), value=eot)
+
+ with torch.inference_mode():
+ speech_tokens = self.t3.inference(
+ t3_cond=self.conds.t3,
+ text_tokens=text_tokens,
+ max_new_tokens=1000, # TODO: use the value in config
+ temperature=temperature,
+ cfg_weight=cfg_weight,
+ repetition_penalty=repetition_penalty,
+ min_p=min_p,
+ top_p=top_p,
+ )
+ # T3.inference returns a list of per-sequence token tensors; take the single sequence.
+ speech_tokens = speech_tokens[0]
+
+ # drop_invalid_tokens returns a list of 1D tensors (one per row); unwrap the single result.
+ speech_tokens = drop_invalid_tokens(speech_tokens)[0]
+ speech_tokens = speech_tokens.to(self.device)
+
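+            # Decode the speech tokens into a waveform, conditioned on the reference voice.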
+ wav, _ = self.s3gen.inference(
+ speech_tokens=speech_tokens,
+ ref_dict=self.conds.gen,
+ )
+ wav = wav.squeeze(0).detach().cpu().numpy()
+ watermarked_wav = self.watermarker.apply_watermark(wav, sample_rate=self.sr)
+ return torch.from_numpy(watermarked_wav).unsqueeze(0)
diff --git a/src/chatterbox/tts.py b/src/chatterbox/tts.py
index 6d9b5ad5..b828d27c 100644
--- a/src/chatterbox/tts.py
+++ b/src/chatterbox/tts.py
@@ -1,3 +1,5 @@
+# chatterbox/src/chatterbox/tts.py
+
from dataclasses import dataclass
from pathlib import Path
@@ -7,10 +9,12 @@
import torch.nn.functional as F
from huggingface_hub import hf_hub_download
from safetensors.torch import load_file
+from typing import Union, List
from .models.t3 import T3
from .models.s3tokenizer import S3_SR, drop_invalid_tokens
from .models.s3gen import S3GEN_SR, S3Gen
+from .models.s3gen.const import TOKEN_TO_WAV_RATIO
from .models.tokenizers import EnTokenizer
from .models.voice_encoder import VoiceEncoder
from .models.t3.modules.cond_enc import T3Cond
@@ -81,11 +85,18 @@ class Conditionals:
t3: T3Cond
gen: dict
- def to(self, device):
- self.t3 = self.t3.to(device=device)
+ def to(self, device, dtype=None):
+ # Use the updated T3Cond.to method
+ self.t3 = self.t3.to(device=device, dtype=dtype)
+
+ # Update S3Gen conditionals (dictionary)
for k, v in self.gen.items():
if torch.is_tensor(v):
- self.gen[k] = v.to(device=device)
+ # Cast float tensors if dtype is provided
+ if dtype is not None and v.is_floating_point():
+ self.gen[k] = v.to(device=device, dtype=dtype)
+ else:
+ self.gen[k] = v.to(device=device)
return self
def save(self, fpath: Path):
@@ -115,23 +126,28 @@ def __init__(
tokenizer: EnTokenizer,
device: str,
conds: Conditionals = None,
+        dtype: torch.dtype = torch.float32, # Dtype the T3/S3Gen weights are cast to
):
- self.sr = S3GEN_SR # sample rate of synthesized audio
+ self.sr = S3GEN_SR
self.t3 = t3
self.s3gen = s3gen
self.ve = ve
self.tokenizer = tokenizer
self.device = device
self.conds = conds
+ self.dtype = dtype # Store the dtype
self.watermarker = perth.PerthImplicitWatermarker()
@classmethod
- def from_local(cls, ckpt_dir, device) -> 'ChatterboxTTS':
+ def from_local(cls, ckpt_dir, device, dtype=torch.float32, compile_model=False) -> 'ChatterboxTTS':
ckpt_dir = Path(ckpt_dir)
- # Always load to CPU first for non-CUDA devices to handle CUDA-saved models
+ # Handle device mapping and enforce FP32 on non-CUDA devices
if device in ["cpu", "mps"]:
map_location = torch.device('cpu')
+ if dtype != torch.float32:
+ print(f"Note: Forcing FP32 on {device}.")
+ dtype = torch.float32
else:
map_location = None
@@ -139,20 +155,42 @@ def from_local(cls, ckpt_dir, device) -> 'ChatterboxTTS':
ve.load_state_dict(
load_file(ckpt_dir / "ve.safetensors")
)
+ # Keep VoiceEncoder in FP32 for stability.
ve.to(device).eval()
- t3 = T3()
+ t3 = T3(dtype=dtype)
t3_state = load_file(ckpt_dir / "t3_cfg.safetensors")
if "model" in t3_state.keys():
t3_state = t3_state["model"][0]
t3.load_state_dict(t3_state)
- t3.to(device).eval()
+ t3.to(device=device, dtype=dtype).eval() # Cast T3 to target dtype
s3gen = S3Gen()
s3gen.load_state_dict(
load_file(ckpt_dir / "s3gen.safetensors"), strict=False
)
- s3gen.to(device).eval()
+ s3gen.to(device=device, dtype=dtype).eval() # Cast S3Gen to target dtype
+
+ # ------------------------------------
+ # Compilation Logic (torch.compile)
+ if compile_model and hasattr(torch, 'compile') and device == 'cuda':
+ # Disable torch.compile on Windows due to Triton dependency:
+ import sys
+ if sys.platform == "win32":
+ print("Skipping torch.compile on Windows due to lack of official Triton support.")
+ print("Flash Attention 2 will be used for acceleration, which is often sufficient.")
+ else:
+ print("Compiling models... (This might take a few minutes on the first run)")
+ try:
+ # Compile the Llama backbone (the most critical part of T3)
+ # mode="reduce-overhead" is often better for smaller batch sizes during decoding.
+ t3.tfmr = torch.compile(t3.tfmr, mode="reduce-overhead")
+ # Compile S3Gen
+ s3gen = torch.compile(s3gen, mode="reduce-overhead")
+ print("Compilation successful.")
+ except Exception as e:
+ print(f"Warning: Model compilation failed: {e}. Proceeding without compilation.")
+ # ------------------------------------
tokenizer = EnTokenizer(
str(ckpt_dir / "tokenizer.json")
@@ -160,12 +198,13 @@ def from_local(cls, ckpt_dir, device) -> 'ChatterboxTTS':
conds = None
if (builtin_voice := ckpt_dir / "conds.pt").exists():
- conds = Conditionals.load(builtin_voice, map_location=map_location).to(device)
+ # Load and cast conditionals using the updated .to() method
+ conds = Conditionals.load(builtin_voice, map_location=map_location).to(device=device, dtype=dtype)
- return cls(t3, s3gen, ve, tokenizer, device, conds=conds)
+ return cls(t3, s3gen, ve, tokenizer, device, conds=conds, dtype=dtype)
@classmethod
- def from_pretrained(cls, device) -> 'ChatterboxTTS':
+ def from_pretrained(cls, device, use_bf16=True, compile_model=False) -> 'ChatterboxTTS':
# Check if MPS is available on macOS
if device == "mps" and not torch.backends.mps.is_available():
if not torch.backends.mps.is_built():
@@ -174,68 +213,139 @@ def from_pretrained(cls, device) -> 'ChatterboxTTS':
print("MPS not available because the current MacOS version is not 12.3+ and/or you do not have an MPS-enabled device on this machine.")
device = "cpu"
+ # Determine the dtype
+ dtype = torch.float32
+ if device == "cuda" and use_bf16:
+ if torch.cuda.is_bf16_supported():
+ dtype = torch.bfloat16
+ print("Using BFloat16 precision.")
+ else:
+ # Fallback to FP16
+ dtype = torch.float16
+ print("BFloat16 not supported. Using Float16 precision.")
+
for fpath in ["ve.safetensors", "t3_cfg.safetensors", "s3gen.safetensors", "tokenizer.json", "conds.pt"]:
local_path = hf_hub_download(repo_id=REPO_ID, filename=fpath)
- return cls.from_local(Path(local_path).parent, device)
+ return cls.from_local(Path(local_path).parent, device, dtype=dtype, compile_model=compile_model)
- def prepare_conditionals(self, wav_fpath, exaggeration=0.5):
+ def prepare_conditionals(self, wav_fpaths: Union[str, List[str]], exaggeration=0.5):
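+        # Accepts a single path or a list of paths; each reference is truncated to
+        # DEC_COND_LEN samples and the batch is zero-padded to a common length.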
+ if isinstance(wav_fpaths, str):
+ wav_fpaths = [wav_fpaths]
+
## Load reference wav
- s3gen_ref_wav, _sr = librosa.load(wav_fpath, sr=S3GEN_SR)
+ s3gen_ref_wavs, ref_16k_wavs_np = [], []
+ for fpath in wav_fpaths:
+ s3gen_wav, _ = librosa.load(fpath, sr=S3GEN_SR)
+ s3gen_wav_tensor = torch.from_numpy(s3gen_wav[:self.DEC_COND_LEN])
+ s3gen_ref_wavs.append(s3gen_wav_tensor)
- ref_16k_wav = librosa.resample(s3gen_ref_wav, orig_sr=S3GEN_SR, target_sr=S3_SR)
+ ref_16k_wav, _ = librosa.load(fpath, sr=S3_SR)
+ ref_16k_wavs_np.append(ref_16k_wav)
- s3gen_ref_wav = s3gen_ref_wav[:self.DEC_COND_LEN]
- s3gen_ref_dict = self.s3gen.embed_ref(s3gen_ref_wav, S3GEN_SR, device=self.device)
+ s3gen_ref_batch = torch.nn.utils.rnn.pad_sequence(s3gen_ref_wavs, batch_first=True)
+ s3gen_ref_dict = self.s3gen.embed_ref(s3gen_ref_batch, S3GEN_SR, device=self.device)
# Speech cond prompt tokens
if plen := self.t3.hp.speech_cond_prompt_len:
s3_tokzr = self.s3gen.tokenizer
- t3_cond_prompt_tokens, _ = s3_tokzr.forward([ref_16k_wav[:self.ENC_COND_LEN]], max_len=plen)
- t3_cond_prompt_tokens = torch.atleast_2d(t3_cond_prompt_tokens).to(self.device)
+ ref_16k_prompts = [wav[:self.ENC_COND_LEN] for wav in ref_16k_wavs_np]
+ t3_cond_prompt_tokens, _ = s3_tokzr.forward(ref_16k_prompts, max_len=plen)
+ t3_cond_prompt_tokens = t3_cond_prompt_tokens.to(self.device)
+ else:
+ t3_cond_prompt_tokens = None
# Voice-encoder speaker embedding
- ve_embed = torch.from_numpy(self.ve.embeds_from_wavs([ref_16k_wav], sample_rate=S3_SR))
- ve_embed = ve_embed.mean(axis=0, keepdim=True).to(self.device)
+ ve_embeds = self.ve.embeds_from_wavs(ref_16k_wavs_np, sample_rate=S3_SR)
+ ve_embed = torch.from_numpy(ve_embeds).unsqueeze(1).to(self.device)
+ batch_size = len(wav_fpaths)
t3_cond = T3Cond(
speaker_emb=ve_embed,
cond_prompt_speech_tokens=t3_cond_prompt_tokens,
- emotion_adv=exaggeration * torch.ones(1, 1, 1),
+ emotion_adv=exaggeration * torch.ones(batch_size, 1, 1),
).to(device=self.device)
self.conds = Conditionals(t3_cond, s3gen_ref_dict)
def generate(
self,
- text,
+ text: Union[str, List[str]],
repetition_penalty=1.2,
min_p=0.05,
top_p=1.0,
- audio_prompt_path=None,
+ audio_prompt_path: Union[str, List[str]] = None,
exaggeration=0.5,
cfg_weight=0.5,
temperature=0.8,
- ):
+ num_return_sequences=1,
+ ) -> Union[torch.Tensor, List[torch.Tensor], List[List[torch.Tensor]]]:
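+        # Batched API: a single string returns one tensor, a list of strings returns a
+        # list of tensors, and num_return_sequences > 1 returns a list of tensors per text.
+        # Illustrative call (hypothetical file name):
+        #   tts.generate(["Hi there.", "Bye now."], audio_prompt_path="ref.wav")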
+ is_single_input = isinstance(text, str)
+ if is_single_input:
+ text = [text]
+ batch_size = len(text)
+
if audio_prompt_path:
self.prepare_conditionals(audio_prompt_path, exaggeration=exaggeration)
else:
assert self.conds is not None, "Please `prepare_conditionals` first or specify `audio_prompt_path`"
- # Update exaggeration if needed
- if exaggeration != self.conds.t3.emotion_adv[0, 0, 0]:
- _cond: T3Cond = self.conds.t3
- self.conds.t3 = T3Cond(
- speaker_emb=_cond.speaker_emb,
- cond_prompt_speech_tokens=_cond.cond_prompt_speech_tokens,
- emotion_adv=exaggeration * torch.ones(1, 1, 1),
- ).to(device=self.device)
+ # --- CRITICAL: Ensure Conditionals match Model dtype ---
+        # prepare_conditionals uses the VoiceEncoder (kept in FP32) to create embeddings.
+ # We must ensure they match T3's dtype (e.g., BF16) before inference.
+ if self.conds.t3.speaker_emb.dtype != self.dtype:
+ # Cast all floating point conditionals to match the model dtype
+ self.conds = self.conds.to(device=self.device, dtype=self.dtype)
+
+ # Broadcast conditioning if a single prompt is used for a batch of texts
+ current_cond_bs = self.conds.t3.speaker_emb.size(0)
+ if current_cond_bs == 1 and batch_size > 1:
+ t3c = self.conds.t3
+ t3c.speaker_emb = t3c.speaker_emb.expand(batch_size, -1, -1)
+ if t3c.cond_prompt_speech_tokens is not None:
+ t3c.cond_prompt_speech_tokens = t3c.cond_prompt_speech_tokens.expand(batch_size, -1)
+ if t3c.emotion_adv is not None:
+ t3c.emotion_adv = t3c.emotion_adv.expand(batch_size, -1, -1)
+
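+            # Expand the S3Gen reference dict the same way: *_len entries are 1-D
+            # lengths, everything else carries a leading batch dimension.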
+ gend = self.conds.gen
+ for k, v in gend.items():
+ if torch.is_tensor(v):
+ if k.endswith("_len"):
+ gend[k] = v.expand(batch_size)
+ else:
+ gend[k] = v.expand(batch_size, *v.shape[1:])
+ elif current_cond_bs != batch_size and not (current_cond_bs == 1 and batch_size == 1):
+ raise ValueError(f"Mismatch between number of texts ({batch_size}) and audio prompts ({current_cond_bs})")
+
+ # Update exaggeration if needed (Ensure dtype consistency and robust checks)
+            # Update exaggeration if needed (keep dtype consistent and guard against empty tensors)
+ # Use .item() for safe comparison and ensure the new tensor matches dtype
+ if exaggeration != self.conds.t3.emotion_adv[0, 0, 0].item():
+ self.conds.t3.emotion_adv = exaggeration * torch.ones(batch_size, 1, 1, device=self.device, dtype=self.dtype)
+ elif exaggeration != 0.5: # Default value check
+ # Handle case where it might be None initially but a new value is provided
+ self.conds.t3.emotion_adv = exaggeration * torch.ones(batch_size, 1, 1, device=self.device, dtype=self.dtype)
# Norm and tokenize text
- text = punc_norm(text)
- text_tokens = self.tokenizer.text_to_tokens(text).to(self.device)
+ texts = [punc_norm(t) for t in text]
+ tokenized_texts = [self.tokenizer.text_to_tokens(t).squeeze(0) for t in texts]
+ text_tokens = torch.nn.utils.rnn.pad_sequence(tokenized_texts, batch_first=True, padding_value=0).to(self.device)
+
+ # --- Start: Logic for num_return_sequences and CFG ---
+ t3_cond = self.conds.t3
+ gen_cond = self.conds.gen
+
+ # Expand inputs for num_return_sequences
+ if num_return_sequences > 1:
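+            # repeat_interleave keeps the N samples for each text adjacent, so the
+            # flat output list can be regrouped per input text after synthesis.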
+ text_tokens = text_tokens.repeat_interleave(num_return_sequences, dim=0)
+ t3_cond = T3Cond(**{k: v.repeat_interleave(num_return_sequences, dim=0) if torch.is_tensor(v) else v for k, v in t3_cond.__dict__.items()})
+ gen_cond = {k: v.repeat_interleave(num_return_sequences, dim=0) if torch.is_tensor(v) else v for k, v in gen_cond.items()}
if cfg_weight > 0.0:
- text_tokens = torch.cat([text_tokens, text_tokens], dim=0) # Need two seqs for CFG
+ # Duplicate text tokens and conditioning tensors for CFG
+ text_tokens = torch.cat([text_tokens, text_tokens], dim=0)
+ t3_cond = T3Cond(**{k: torch.cat([v, v], dim=0) if torch.is_tensor(v) else v for k, v in t3_cond.__dict__.items()})
+ # --- End: Logic for num_return_sequences and CFG ---
+
sot = self.t3.hp.start_text_token
eot = self.t3.hp.stop_text_token
@@ -243,30 +353,46 @@ def generate(
text_tokens = F.pad(text_tokens, (0, 1), value=eot)
with torch.inference_mode():
- speech_tokens = self.t3.inference(
- t3_cond=self.conds.t3,
+ # T3 generates a list of variable-length token sequences
+ speech_tokens_list = self.t3.inference(
+ t3_cond=t3_cond,
text_tokens=text_tokens,
- max_new_tokens=1000, # TODO: use the value in config
+ max_new_tokens=1000,
temperature=temperature,
cfg_weight=cfg_weight,
repetition_penalty=repetition_penalty,
min_p=min_p,
top_p=top_p,
)
- # Extract only the conditional batch.
- speech_tokens = speech_tokens[0]
-
- # TODO: output becomes 1D
- speech_tokens = drop_invalid_tokens(speech_tokens)
-
- speech_tokens = speech_tokens[speech_tokens < 6561]
- speech_tokens = speech_tokens.to(self.device)
+            # Pad the variable-length outputs to filter invalid tokens in batch, then re-pad the cleaned sequences for S3Gen
+ speech_tokens_padded = torch.nn.utils.rnn.pad_sequence(speech_tokens_list, batch_first=True, padding_value=self.t3.hp.stop_speech_token)
+ clean_tokens_list = drop_invalid_tokens(speech_tokens_padded)
+ s3gen_tokens_padded = torch.nn.utils.rnn.pad_sequence(clean_tokens_list, batch_first=True, padding_value=0)
+ s3gen_token_lens = torch.tensor([len(t) for t in clean_tokens_list], device=self.device)
- wav, _ = self.s3gen.inference(
- speech_tokens=speech_tokens,
- ref_dict=self.conds.gen,
+ wavs, _ = self.s3gen.inference(
+ speech_tokens=s3gen_tokens_padded.to(self.device),
+ speech_token_lens=s3gen_token_lens,
+ ref_dict=gen_cond,
)
- wav = wav.squeeze(0).detach().cpu().numpy()
- watermarked_wav = self.watermarker.apply_watermark(wav, sample_rate=self.sr)
- return torch.from_numpy(watermarked_wav).unsqueeze(0)
\ No newline at end of file
+
+ # Trim padding noise
+ audio_lengths = s3gen_token_lens * TOKEN_TO_WAV_RATIO
+ output_tensors = []
+ for i, wav in enumerate(wavs):
+ # Ensure output is FP32 for watermarking/saving compatibility
+ trimmed_wav = wav[:audio_lengths[i]].cpu().float().numpy()
+ watermarked_wav = self.watermarker.apply_watermark(trimmed_wav, sample_rate=self.sr)
+ output_tensors.append(torch.from_numpy(watermarked_wav).unsqueeze(0))
+
+ if num_return_sequences > 1:
+ # Group the flat list of outputs into a list of lists
+ grouped_outputs = [output_tensors[i:i + num_return_sequences] for i in range(0, len(output_tensors), num_return_sequences)]
+ if is_single_input:
+ return grouped_outputs[0]
+ return grouped_outputs
+
+ if is_single_input:
+ return output_tensors[0]
+ return output_tensors
\ No newline at end of file
diff --git a/src/chatterbox/vc.py b/src/chatterbox/vc.py
index a9c32ed3..c3de4f11 100644
--- a/src/chatterbox/vc.py
+++ b/src/chatterbox/vc.py
@@ -1,3 +1,5 @@
+# chatterbox/src/chatterbox/vc.py
+
from pathlib import Path
import librosa
@@ -5,9 +7,11 @@
import perth
from huggingface_hub import hf_hub_download
from safetensors.torch import load_file
+from typing import List, Union
from .models.s3tokenizer import S3_SR
from .models.s3gen import S3GEN_SR, S3Gen
+from .models.s3gen.const import TOKEN_TO_WAV_RATIO
REPO_ID = "ResembleAI/chatterbox"
@@ -16,12 +20,13 @@
class ChatterboxVC:
ENC_COND_LEN = 6 * S3_SR
DEC_COND_LEN = 10 * S3GEN_SR
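+    # Output waveform samples generated per speech token; used to trim padded batch outputs.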
+ TOKEN_TO_WAV_RATIO = 960
def __init__(
self,
s3gen: S3Gen,
device: str,
- ref_dict: dict=None,
+ ref_dict: dict = None,
):
self.sr = S3GEN_SR
self.s3gen = s3gen
@@ -38,13 +43,13 @@ def __init__(
@classmethod
def from_local(cls, ckpt_dir, device) -> 'ChatterboxVC':
ckpt_dir = Path(ckpt_dir)
-
+
# Always load to CPU first for non-CUDA devices to handle CUDA-saved models
if device in ["cpu", "mps"]:
map_location = torch.device('cpu')
else:
map_location = None
-
+
ref_dict = None
if (builtin_voice := ckpt_dir / "conds.pt").exists():
states = torch.load(builtin_voice, map_location=map_location)
@@ -67,38 +72,74 @@ def from_pretrained(cls, device) -> 'ChatterboxVC':
else:
print("MPS not available because the current MacOS version is not 12.3+ and/or you do not have an MPS-enabled device on this machine.")
device = "cpu"
-
+
for fpath in ["s3gen.safetensors", "conds.pt"]:
local_path = hf_hub_download(repo_id=REPO_ID, filename=fpath)
return cls.from_local(Path(local_path).parent, device)
- def set_target_voice(self, wav_fpath):
- ## Load reference wav
- s3gen_ref_wav, _sr = librosa.load(wav_fpath, sr=S3GEN_SR)
-
- s3gen_ref_wav = s3gen_ref_wav[:self.DEC_COND_LEN]
- self.ref_dict = self.s3gen.embed_ref(s3gen_ref_wav, S3GEN_SR, device=self.device)
+ def set_target_voice(self, wav_fpaths: Union[str, List[str]]):
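+        # Accepts one or more reference clips; each is truncated to DEC_COND_LEN
+        # samples and the batch is zero-padded before embedding.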
+ if isinstance(wav_fpaths, str):
+ wav_fpaths = [wav_fpaths]
+
+ s3gen_ref_wavs = []
+ for fpath in wav_fpaths:
+ s3gen_ref_wav, _ = librosa.load(fpath, sr=S3GEN_SR)
+ s3gen_ref_wav = s3gen_ref_wav[:self.DEC_COND_LEN]
+ s3gen_ref_wavs.append(torch.from_numpy(s3gen_ref_wav))
+
+ s3gen_ref_batch = torch.nn.utils.rnn.pad_sequence(s3gen_ref_wavs, batch_first=True)
+ self.ref_dict = self.s3gen.embed_ref(s3gen_ref_batch, S3GEN_SR, device=self.device)
def generate(
self,
- audio,
- target_voice_path=None,
- ):
+ audio: Union[str, List[str]],
+ target_voice_path: Union[str, List[str]] = None,
+ ) -> Union[torch.Tensor, List[torch.Tensor]]:
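+        # Batched API: a single source path returns one tensor, a list returns a list.
+        # Illustrative call (hypothetical file names):
+        #   vc.generate(["a.wav", "b.wav"], target_voice_path="voice.wav")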
+ is_single_input = isinstance(audio, str)
+ if is_single_input:
+ audio = [audio]
+ batch_size = len(audio)
+
if target_voice_path:
self.set_target_voice(target_voice_path)
else:
- assert self.ref_dict is not None, "Please `prepare_conditionals` first or specify `target_voice_path`"
+ assert self.ref_dict is not None, "Please call `set_target_voice` first or specify `target_voice_path`"
+
+ # Broadcast conditioning if a single prompt is used for a batch of inputs
+ current_cond_bs = self.ref_dict['embedding'].size(0)
+ if current_cond_bs == 1 and batch_size > 1:
+ for k, v in self.ref_dict.items():
+ if torch.is_tensor(v):
+ if k.endswith("_len"):
+ self.ref_dict[k] = v.expand(batch_size)
+ else:
+ self.ref_dict[k] = v.expand(batch_size, *v.shape[1:])
+ elif current_cond_bs != batch_size and not (current_cond_bs == 1 and batch_size == 1):
+ raise ValueError(f"Mismatch between number of source audios ({batch_size}) and target voice paths ({current_cond_bs})")
with torch.inference_mode():
- audio_16, _ = librosa.load(audio, sr=S3_SR)
- audio_16 = torch.from_numpy(audio_16).float().to(self.device)[None, ]
+ audios_16k = []
+ for a in audio:
+ audio_16, _ = librosa.load(a, sr=S3_SR)
+ audios_16k.append(torch.from_numpy(audio_16).float())
+
+ audio_16_padded = torch.nn.utils.rnn.pad_sequence(audios_16k, batch_first=True).to(self.device)
- s3_tokens, _ = self.s3gen.tokenizer(audio_16)
- wav, _ = self.s3gen.inference(
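+            # Tokenize the padded batch; per-item token lengths are kept for trimming later.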
+ s3_tokens, s3_token_lens = self.s3gen.tokenizer(audio_16_padded)
+ wavs, _ = self.s3gen.inference(
speech_tokens=s3_tokens,
+ speech_token_lens=s3_token_lens,
ref_dict=self.ref_dict,
)
- wav = wav.squeeze(0).detach().cpu().numpy()
- watermarked_wav = self.watermarker.apply_watermark(wav, sample_rate=self.sr)
- return torch.from_numpy(watermarked_wav).unsqueeze(0)
\ No newline at end of file
+ # Trim padding noise
+ audio_lengths = s3_token_lens * TOKEN_TO_WAV_RATIO
+ output_tensors = []
+ for i, wav in enumerate(wavs):
+ trimmed_wav = wav[:audio_lengths[i]].cpu().numpy()
+ watermarked_wav = self.watermarker.apply_watermark(trimmed_wav, sample_rate=self.sr)
+ output_tensors.append(torch.from_numpy(watermarked_wav).unsqueeze(0))
+
+ if is_single_input:
+ return output_tensors[0]
+ return output_tensors
\ No newline at end of file