mlx_audio/tts/models/qwen3_tts/qwen3_tts.py (102 additions, 0 deletions)
@@ -808,6 +808,8 @@ def generate(
                top_p=top_p,
                repetition_penalty=icl_rep_penalty,
                verbose=verbose,
                stream=stream,
                streaming_interval=streaming_interval,
            )
            return
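With `stream=True`, `generate` now forwards the streaming parameters into the ICL voice-cloning path, so it yields incremental `GenerationResult` chunks instead of a single full decode. A minimal consumer sketch, assuming a hypothetical `load_model` helper and keyword set (only `stream`, `streaming_interval`, and the chunk fields come from this diff):

```python
import numpy as np

# Hypothetical loader and arguments, for illustration only.
model = load_model("path/to/qwen3-tts")

chunks = []
for result in model.generate(
    text="Hello from the streaming path.",
    stream=True,
    streaming_interval=2.0,  # target seconds of audio per chunk
):
    # Chunks arrive with the overlap context already trimmed,
    # so they can be concatenated directly.
    chunks.append(np.array(result.audio))
    if getattr(result, "is_final_chunk", False):
        break

audio = np.concatenate(chunks)  # full waveform at result.sample_rate
```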

@@ -1257,6 +1259,8 @@ def _generate_icl(
        top_p: float = 1.0,
        repetition_penalty: float = 1.5,
        verbose: bool = False,
        stream: bool = False,
        streaming_interval: float = 2.0,
    ) -> Generator[GenerationResult, None, None]:
        """Generate speech using ICL (In-Context Learning) voice cloning.

@@ -1307,6 +1311,12 @@ def _generate_icl(
            leave=False,
        )

        # Streaming state
        # At 12.5 Hz, 25 tokens ≈ 2 seconds of audio
        streaming_chunk_size = max(1, int(streaming_interval * 12.5))
        decoded_tokens = 0  # Track how many tokens we've decoded and yielded
        context_size = 25  # Overlap tokens for smooth audio transitions (25 gives ~0.04% error vs full decode)

        for step in range(effective_max_tokens):
            # Forward pass through talker
            logits, hidden = self.talker(input_embeds, cache=cache)
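For reference, the chunk sizing works out as follows: the codec emits 12.5 tokens per second, so the default `streaming_interval=2.0` gives 25-token chunks. A small sketch of the arithmetic, assuming a 24 kHz output rate purely for illustration (the model's actual rate is `self.sample_rate`):

```python
token_rate = 12.5        # codec tokens per second (from this diff)
streaming_interval = 2.0  # requested seconds of audio per chunk
chunk_tokens = max(1, int(streaming_interval * token_rate))  # -> 25

# Assuming a 24 kHz output sample rate for illustration:
samples_per_token = int(24_000 / token_rate)      # -> 1920
chunk_samples = chunk_tokens * samples_per_token  # -> 48_000, i.e. 2.0 s
```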
@@ -1362,6 +1372,9 @@ def _generate_icl(
            all_codes = mx.concatenate(code_tokens, axis=1)
            generated_codes.append(all_codes)

            del code_cache
            mx.clear_cache()

            # Prepare next input
            if trailing_idx < trailing_text_hidden.shape[1]:
                text_embed = trailing_text_hidden[:, trailing_idx : trailing_idx + 1, :]
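For background on the `del code_cache` / `mx.clear_cache()` pair: MLX pools freed buffers in an allocator cache, so dropping the Python reference alone does not shrink memory use; `mx.clear_cache()` releases the pooled buffers. A runnable sketch of the periodic-clearing pattern, with a stand-in workload:

```python
import mlx.core as mx

def step_work(x: mx.array) -> mx.array:
    # Stand-in for a real per-step computation.
    return (x @ x.T).sum()

x = mx.random.normal((1024, 1024))
for step in range(200):
    out = step_work(x)
    mx.eval(out)  # materialize the lazy graph before freeing
    if step > 0 and step % 50 == 0:
        # Release MLX's cached buffers so peak memory stays flat
        # across long generations (mirrors the pattern in this diff).
        mx.clear_cache()
```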
@@ -1378,10 +1391,99 @@ def _generate_icl(
            input_embeds = text_embed + codec_embed
            mx.eval(input_embeds)

            # Periodically clear cache to prevent memory buildup during long generation
            if step > 0 and step % 50 == 0:
                mx.clear_cache()

            pbar.update(1)

            # Streaming: decode and yield audio chunks during generation
            new_tokens = len(generated_codes) - decoded_tokens
            if stream and new_tokens >= streaming_chunk_size:
                # Include context from previous tokens for smooth transitions
                start_idx = max(0, decoded_tokens - context_size)
                codes_chunk = mx.stack(generated_codes[start_idx:], axis=1)
                mx.eval(codes_chunk)

                audio_chunk = self._decode_chunk(
                    codes_chunk, chunk_tokens=streaming_chunk_size
                )

                # Trim the context overlap from audio (only yield new audio)
                if decoded_tokens > 0 and start_idx < decoded_tokens:
                    context_tokens = decoded_tokens - start_idx
                    samples_per_token = self.speech_tokenizer.decode_upsample_rate
                    trim_samples = context_tokens * samples_per_token
                    if trim_samples < audio_chunk.shape[0]:
                        audio_chunk = audio_chunk[trim_samples:]

                decoded_tokens = len(generated_codes)

                yield GenerationResult(
                    audio=audio_chunk,
                    samples=audio_chunk.shape[0],
                    sample_rate=self.sample_rate,
                    segment_idx=0,
                    token_count=new_tokens,
                    audio_duration=format_duration(
                        audio_chunk.shape[0] / self.sample_rate
                    ),
                    real_time_factor=0,
                    prompt={"tokens": new_tokens, "tokens-per-sec": 0},
                    audio_samples={
                        "samples": audio_chunk.shape[0],
                        "samples-per-sec": 0,
                    },
                    processing_time_seconds=0,
                    peak_memory_usage=mx.get_peak_memory() / 1e9,
                    is_streaming_chunk=True,
                )

                mx.clear_cache()
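The overlap trim operates in whole tokens: each chunk is decoded together with up to `context_size` already-yielded tokens for continuity, and the duplicated samples are cut from the front before the chunk is emitted. A worked sketch, assuming `decode_upsample_rate = 1920` (24 kHz / 12.5 Hz) purely for illustration:

```python
# Assumed values for illustration; the real ones come from the model.
decode_upsample_rate = 1920  # output samples per codec token (assumed)
decoded_tokens = 50          # tokens already decoded and yielded
context_size = 25

start_idx = max(0, decoded_tokens - context_size)    # 25
context_tokens = decoded_tokens - start_idx          # 25 overlap tokens
trim_samples = context_tokens * decode_upsample_rate  # 48_000 samples

# The decoded chunk starts at token 25; its first 48_000 samples
# duplicate audio that was already yielded, so they are dropped.
```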

        pbar.close()

        # Yield any remaining tokens
        if stream and len(generated_codes) > decoded_tokens:
            # Include context from previous tokens for smooth transitions
            start_idx = max(0, decoded_tokens - context_size)
            codes_chunk = mx.stack(generated_codes[start_idx:], axis=1)
            mx.eval(codes_chunk)

            audio_chunk = self._decode_chunk(
                codes_chunk, chunk_tokens=streaming_chunk_size
            )

            # Trim the context overlap from audio (only yield new audio)
            if decoded_tokens > 0 and start_idx < decoded_tokens:
                context_tokens = decoded_tokens - start_idx
                samples_per_token = self.speech_tokenizer.decode_upsample_rate
                trim_samples = context_tokens * samples_per_token
                if trim_samples < audio_chunk.shape[0]:
                    audio_chunk = audio_chunk[trim_samples:]

            new_tokens = len(generated_codes) - decoded_tokens

            yield GenerationResult(
                audio=audio_chunk,
                samples=audio_chunk.shape[0],
                sample_rate=self.sample_rate,
                segment_idx=0,
                token_count=new_tokens,
                audio_duration=format_duration(audio_chunk.shape[0] / self.sample_rate),
                real_time_factor=0,
                prompt={"tokens": new_tokens, "tokens-per-sec": 0},
                audio_samples={
                    "samples": audio_chunk.shape[0],
                    "samples-per-sec": 0,
                },
                processing_time_seconds=0,
                peak_memory_usage=mx.get_peak_memory() / 1e9,
                is_streaming_chunk=True,
                is_final_chunk=True,
            )
            return  # Skip non-streaming yield
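With `stream=True`, the generator exits here after the final flush; the code below, which handles the single non-streaming yield, only runs when streaming is disabled.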

        if not generated_codes:
            return
