1919import json
2020import os
2121import weakref
22+ from collections import deque
2223from dataclasses import dataclass , replace
2324from typing import Any , Union , cast
2425
@@ -171,6 +172,19 @@ def __init__(
171172 elif isinstance (text_pacing , tts .SentenceStreamPacer ):
172173 self ._stream_pacer = text_pacing
173174
175+ if word_timestamps :
176+ if "preview" not in self ._opts .model and self ._opts .language not in {
177+ "en" ,
178+ "de" ,
179+ "es" ,
180+ "fr" ,
181+ }:
182+ # https://docs.cartesia.ai/api-reference/tts/compare-tts-endpoints
183+ logger .warning (
184+ "word_timestamps is only supported for languages en, de, es, and fr with `sonic` models"
185+ " or all languages with `preview` models"
186+ )
187+
@property
def model(self) -> str:
    """Return the model name stored on this instance's options."""
    options = self._opts
    return options.model
@@ -348,6 +362,7 @@ async def _run(self, output_emitter: tts.AudioEmitter) -> None:
348362 stream = True ,
349363 )
350364 input_sent_event = asyncio .Event ()
365+ sent_tokens = deque [str ]()
351366
352367 sent_tokenizer_stream = self ._tts ._sentence_tokenizer .stream ()
353368 if self ._tts ._stream_pacer :
@@ -363,6 +378,7 @@ async def _sentence_stream_task(ws: aiohttp.ClientWebSocketResponse) -> None:
363378 token_pkt = base_pkt .copy ()
364379 token_pkt ["context_id" ] = context_id
365380 token_pkt ["transcript" ] = ev .token + " "
381+ sent_tokens .append (ev .token + " " )
366382 token_pkt ["continue" ] = True
367383 self ._mark_started ()
368384 await ws .send_str (json .dumps (token_pkt ))
@@ -371,6 +387,7 @@ async def _sentence_stream_task(ws: aiohttp.ClientWebSocketResponse) -> None:
371387 end_pkt = base_pkt .copy ()
372388 end_pkt ["context_id" ] = context_id
373389 end_pkt ["transcript" ] = " "
390+ sent_tokens .append (" " )
374391 end_pkt ["continue" ] = False
375392 await ws .send_str (json .dumps (end_pkt ))
376393 input_sent_event .set ()
@@ -387,6 +404,7 @@ async def _input_task() -> None:
387404 async def _recv_task (ws : aiohttp .ClientWebSocketResponse ) -> None :
388405 current_segment_id : str | None = None
389406 await input_sent_event .wait ()
407+ skip_aligning = False
390408 while True :
391409 msg = await ws .receive (timeout = self ._conn_options .timeout )
392410 if msg .type in (
@@ -416,10 +434,26 @@ async def _recv_task(ws: aiohttp.ClientWebSocketResponse) -> None:
416434 output_emitter .end_input ()
417435 break
418436 elif word_timestamps := data .get ("word_timestamps" ):
437+ # assuming Cartesia echoes the sent text in the original format and order.
419438 for word , start , end in zip (
420439 word_timestamps ["words" ], word_timestamps ["start" ], word_timestamps ["end" ]
421440 ):
422- word = f"{ word } " # TODO(long): any better way to format the words?
441+ if not sent_tokens or skip_aligning :
442+ word = f"{ word } "
443+ skip_aligning = True
444+ else :
445+ sent = sent_tokens .popleft ()
446+ if (idx := sent .find (word )) != - 1 :
447+ word , sent = sent [: idx + len (word )], sent [idx + len (word ) :]
448+ if sent .strip ():
449+ sent_tokens .appendleft (sent )
450+ elif sent and sent_tokens :
451+ # merge the remaining whitespace to the next sentence
452+ sent_tokens [0 ] = sent + sent_tokens [0 ]
453+ else :
454+ word = f"{ word } "
455+ skip_aligning = True
456+
423457 output_emitter .push_timed_transcript (
424458 TimedString (text = word , start_time = start , end_time = end )
425459 )
0 commit comments