Skip to content

Commit a18d5e9

Browse files
authored
add better word alignment for Cartesia (#3876)
Two irrelevant checks failed, skipping them for now.
1 parent 758fd13 commit a18d5e9

File tree

1 file changed

+35
-1
lines changed
  • livekit-plugins/livekit-plugins-cartesia/livekit/plugins/cartesia

1 file changed

+35
-1
lines changed

livekit-plugins/livekit-plugins-cartesia/livekit/plugins/cartesia/tts.py

Lines changed: 35 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
import json
2020
import os
2121
import weakref
22+
from collections import deque
2223
from dataclasses import dataclass, replace
2324
from typing import Any, Union, cast
2425

@@ -171,6 +172,19 @@ def __init__(
171172
elif isinstance(text_pacing, tts.SentenceStreamPacer):
172173
self._stream_pacer = text_pacing
173174

175+
if word_timestamps:
176+
if "preview" not in self._opts.model and self._opts.language not in {
177+
"en",
178+
"de",
179+
"es",
180+
"fr",
181+
}:
182+
# https://docs.cartesia.ai/api-reference/tts/compare-tts-endpoints
183+
logger.warning(
184+
"word_timestamps is only supported for languages en, de, es, and fr with `sonic` models"
185+
" or all languages with `preview` models"
186+
)
187+
174188
@property
175189
def model(self) -> str:
176190
return self._opts.model
@@ -348,6 +362,7 @@ async def _run(self, output_emitter: tts.AudioEmitter) -> None:
348362
stream=True,
349363
)
350364
input_sent_event = asyncio.Event()
365+
sent_tokens = deque[str]()
351366

352367
sent_tokenizer_stream = self._tts._sentence_tokenizer.stream()
353368
if self._tts._stream_pacer:
@@ -363,6 +378,7 @@ async def _sentence_stream_task(ws: aiohttp.ClientWebSocketResponse) -> None:
363378
token_pkt = base_pkt.copy()
364379
token_pkt["context_id"] = context_id
365380
token_pkt["transcript"] = ev.token + " "
381+
sent_tokens.append(ev.token + " ")
366382
token_pkt["continue"] = True
367383
self._mark_started()
368384
await ws.send_str(json.dumps(token_pkt))
@@ -371,6 +387,7 @@ async def _sentence_stream_task(ws: aiohttp.ClientWebSocketResponse) -> None:
371387
end_pkt = base_pkt.copy()
372388
end_pkt["context_id"] = context_id
373389
end_pkt["transcript"] = " "
390+
sent_tokens.append(" ")
374391
end_pkt["continue"] = False
375392
await ws.send_str(json.dumps(end_pkt))
376393
input_sent_event.set()
@@ -387,6 +404,7 @@ async def _input_task() -> None:
387404
async def _recv_task(ws: aiohttp.ClientWebSocketResponse) -> None:
388405
current_segment_id: str | None = None
389406
await input_sent_event.wait()
407+
skip_aligning = False
390408
while True:
391409
msg = await ws.receive(timeout=self._conn_options.timeout)
392410
if msg.type in (
@@ -416,10 +434,26 @@ async def _recv_task(ws: aiohttp.ClientWebSocketResponse) -> None:
416434
output_emitter.end_input()
417435
break
418436
elif word_timestamps := data.get("word_timestamps"):
437+
# assuming Cartesia echoes the sent text in the original format and order.
419438
for word, start, end in zip(
420439
word_timestamps["words"], word_timestamps["start"], word_timestamps["end"]
421440
):
422-
word = f"{word} " # TODO(long): any better way to format the words?
441+
if not sent_tokens or skip_aligning:
442+
word = f"{word} "
443+
skip_aligning = True
444+
else:
445+
sent = sent_tokens.popleft()
446+
if (idx := sent.find(word)) != -1:
447+
word, sent = sent[: idx + len(word)], sent[idx + len(word) :]
448+
if sent.strip():
449+
sent_tokens.appendleft(sent)
450+
elif sent and sent_tokens:
451+
# merge the remaining whitespace to the next sentence
452+
sent_tokens[0] = sent + sent_tokens[0]
453+
else:
454+
word = f"{word} "
455+
skip_aligning = True
456+
423457
output_emitter.push_timed_transcript(
424458
TimedString(text=word, start_time=start, end_time=end)
425459
)

0 commit comments

Comments (0)