
Commit eadb06b

Adding support for new bot-output RTVI Message:

1. TTSTextFrames now include metadata about whether the text was spoken or not, along with a type string describing what the text represents, e.g. "sentence", "word", "custom aggregation".
2. Expanded how aggregators work so that the aggregate method returns the aggregated text along with the type of aggregation used to create it.
3. Deprecated the RTVI bot-transcription event in favor of the new bot-output event described in item 4.
4. Introduced support for a new bot-output event. This event is meant to be the one-stop shop for communicating what the bot actually "says". It is based on TTSTextFrames, communicating both sentence by sentence (or whatever aggregation is used) and word by word. In addition, it includes LLMTextFrames, aggregated by sentence, when TTS is turned off (i.e. skip_tts is true).

Resolves pipecat-ai/pipecat-client-web#158
1 parent 84ed246 commit eadb06b

12 files changed: +268, -108 lines

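For reference, a minimal sketch of what the new event looks like on the wire, built with the message models this commit adds in rtvi.py below (the "rtvi-ai" label value is assumed from the existing RTVI message conventions):

from pipecat.processors.frameworks.rtvi import (
    RTVIBotOutputMessage,
    RTVIBotOutputMessageData,
)

# A sentence the bot actually spoke via TTS.
message = RTVIBotOutputMessage(
    data=RTVIBotOutputMessageData(text="Hi there.", spoken=True, aggregated_by="sentence")
)
print(message.model_dump_json(exclude_none=True))
# -> {"label": "rtvi-ai", "type": "bot-output",
#     "data": {"text": "Hi there.", "spoken": true, "aggregated_by": "sentence"}}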

src/pipecat/frames/frames.py

Lines changed: 3 additions & 0 deletions
@@ -351,6 +351,9 @@ class LLMTextFrame(TextFrame):
 class TTSTextFrame(TextFrame):
     """Text frame generated by Text-to-Speech services."""

+    aggregated_by: Literal["sentence", "word"] | str
+    spoken: Optional[bool] = True  # Whether this text has been spoken by TTS
+
     pass

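For illustration, a minimal sketch of constructing the extended frame, mirroring how the services below pass the new fields (the positional text argument is assumed from the existing TextFrame dataclass):

from pipecat.frames.frames import TTSTextFrame

# Word-level text that was actually synthesized and spoken.
spoken_word = TTSTextFrame("hello", aggregated_by="word", spoken=True)

# Sentence-level text that was deliberately not spoken (e.g. when skip_tts is set).
silent_sentence = TTSTextFrame("Hi there.", aggregated_by="sentence", spoken=False)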

src/pipecat/processors/frameworks/rtvi.py

Lines changed: 77 additions & 4 deletions
@@ -704,6 +704,29 @@ class RTVITextMessageData(BaseModel):
     text: str


+class RTVIBotOutputMessageData(RTVITextMessageData):
+    """Data for bot output RTVI messages.
+
+    Extends RTVITextMessageData to include metadata about the output.
+    """
+
+    spoken: bool = True  # Indicates if the text has been spoken by TTS
+    aggregated_by: Optional[Literal["word", "sentence"] | str] = None
+    # Indicates what form the text is in (e.g., by word, sentence, etc.)
+
+
+class RTVIBotOutputMessage(BaseModel):
+    """Message containing bot output text.
+
+    An event meant to holistically represent what the bot is outputting,
+    along with metadata about the output and whether it has been spoken.
+    """
+
+    label: RTVIMessageLiteral = RTVI_MESSAGE_LABEL
+    type: Literal["bot-output"] = "bot-output"
+    data: RTVIBotOutputMessageData
+
+
 class RTVIBotTranscriptionMessage(BaseModel):
     """Message containing bot transcription text.
@@ -960,6 +983,8 @@ def __init__(
         self._last_user_audio_level = 0
         self._last_bot_audio_level = 0

+        self._skip_tts = None
+
         if self._params.system_logs_enabled:
             self._system_logger_id = logger.add(self._logger_sink)

@@ -1050,8 +1075,7 @@ async def on_push_frame(self, data: FramePushed):
             await self.send_rtvi_message(RTVIBotTTSStoppedMessage())
         elif isinstance(frame, TTSTextFrame) and self._params.bot_tts_enabled:
             if isinstance(src, BaseOutputTransport):
-                message = RTVIBotTTSTextMessage(data=RTVITextMessageData(text=frame.text))
-                await self.send_rtvi_message(message)
+                await self._handle_tts_text_frame(frame)
             else:
                 mark_as_seen = False
         elif isinstance(frame, MetricsFrame) and self._params.metrics_enabled:
@@ -1115,14 +1139,63 @@ async def _handle_bot_speaking(self, frame: Frame):
         if message:
             await self.send_rtvi_message(message)

+    async def _handle_tts_text_frame(self, frame: TTSTextFrame):
+        """Handle TTS text output frames."""
+        # send the tts-text message
+        message = RTVIBotTTSTextMessage(data=RTVITextMessageData(text=frame.text))
+        await self.send_rtvi_message(message)
+        # send the bot-output message
+        message = RTVIBotOutputMessage(
+            data=RTVIBotOutputMessageData(
+                text=frame.text, spoken=frame.spoken, aggregated_by=frame.aggregated_by
+            )
+        )
+        await self.send_rtvi_message(message)
+
     async def _handle_llm_text_frame(self, frame: LLMTextFrame):
         """Handle LLM text output frames."""
         message = RTVIBotLLMTextMessage(data=RTVITextMessageData(text=frame.text))
         await self.send_rtvi_message(message)

+        # initialize skip_tts on first LLMTextFrame
+        if self._skip_tts is None:
+            self._skip_tts = frame.skip_tts
+
+        messages = []
+        should_reset_transcription = False
         self._bot_transcription += frame.text
-        if match_endofsentence(self._bot_transcription):
-            await self._push_bot_transcription()
+
+        if not frame.skip_tts and self._skip_tts:
+            # We just switched from skipping TTS to not skipping TTS.
+            # Send and reset any existing transcription.
+            if len(self._bot_transcription) > 0:
+                messages.append(
+                    RTVIBotOutputMessage(
+                        data=RTVIBotOutputMessageData(
+                            text=self._bot_transcription, spoken=False, aggregated_by="sentence"
+                        )
+                    )
+                )
+                should_reset_transcription = True
+
+        if match_endofsentence(self._bot_transcription) and len(self._bot_transcription) > 0:
+            messages.append(
+                RTVIBotTranscriptionMessage(data=RTVITextMessageData(text=self._bot_transcription))
+            )
+            if frame.skip_tts:
+                messages.append(
+                    RTVIBotOutputMessage(
+                        data=RTVIBotOutputMessageData(
+                            text=self._bot_transcription, spoken=False, aggregated_by="sentence"
+                        )
+                    )
+                )
+            should_reset_transcription = True
+
+        for msg in messages:
+            await self.send_rtvi_message(msg)
+        if should_reset_transcription:
+            self._bot_transcription = ""

     async def _handle_user_transcriptions(self, frame: Frame):
         """Handle user transcription frames."""

src/pipecat/services/aws/nova_sonic/llm.py

Lines changed: 4 additions & 2 deletions
@@ -1027,7 +1027,7 @@ async def _report_assistant_response_text_added(self, text):
         logger.debug(f"Assistant response text added: {text}")

         # Report the text of the assistant response.
-        await self.push_frame(TTSTextFrame(text))
+        await self.push_frame(TTSTextFrame(text, aggregated_by="sentence", spoken=True))

         # HACK: here we're also buffering the assistant text ourselves as a
         # backup rather than relying solely on the assistant context aggregator
@@ -1060,7 +1060,9 @@ async def _report_assistant_response_ended(self):
         # TTSTextFrame would be ignored otherwise (the interruption frame
         # would have cleared the assistant aggregator state).
         await self.push_frame(LLMFullResponseStartFrame())
-        await self.push_frame(TTSTextFrame(self._assistant_text_buffer))
+        await self.push_frame(
+            TTSTextFrame(self._assistant_text_buffer, aggregated_by="sentence", spoken=True)
+        )
         self._may_need_repush_assistant_text = False

         # Report the end of the assistant response.

src/pipecat/services/google/gemini_live/llm.py

Lines changed: 1 addition & 1 deletion
@@ -1459,7 +1459,7 @@ async def _handle_msg_output_transcription(self, message: LiveServerMessage):
         self._llm_output_buffer += text

         await self.push_frame(LLMTextFrame(text=text))
-        await self.push_frame(TTSTextFrame(text=text))
+        await self.push_frame(TTSTextFrame(text=text, aggregated_by="sentence", spoken=True))

     async def _handle_msg_grounding_metadata(self, message: LiveServerMessage):
         """Handle dedicated grounding metadata messages."""

src/pipecat/services/openai/realtime/llm.py

Lines changed: 1 addition & 1 deletion
@@ -673,7 +673,7 @@ async def _handle_evt_text_delta(self, evt):
     async def _handle_evt_audio_transcript_delta(self, evt):
         if evt.delta:
             await self.push_frame(LLMTextFrame(evt.delta))
-            await self.push_frame(TTSTextFrame(evt.delta))
+            await self.push_frame(TTSTextFrame(evt.delta, aggregated_by="sentence", spoken=True))

     async def _handle_evt_function_call_arguments_done(self, evt):
         """Handle completion of function call arguments.

src/pipecat/services/openai_realtime_beta/openai.py

Lines changed: 1 addition & 1 deletion
@@ -654,7 +654,7 @@ async def _handle_evt_text_delta(self, evt):
     async def _handle_evt_audio_transcript_delta(self, evt):
         if evt.delta:
             await self.push_frame(LLMTextFrame(evt.delta))
-            await self.push_frame(TTSTextFrame(evt.delta))
+            await self.push_frame(TTSTextFrame(evt.delta, aggregated_by="sentence", spoken=True))

     async def _handle_evt_speech_started(self, evt):
         await self._truncate_current_audio_response()

src/pipecat/services/tts_service.py

Lines changed: 46 additions & 27 deletions
@@ -101,6 +101,8 @@ def __init__(
         sample_rate: Optional[int] = None,
         # Text aggregator to aggregate incoming tokens and decide when to push to the TTS.
         text_aggregator: Optional[BaseTextAggregator] = None,
+        # Types of text aggregations that should not be spoken.
+        skip_aggregator_types: Optional[List[str]] = [],
         # Text filter executed after text has been aggregated.
         text_filters: Optional[Sequence[BaseTextFilter]] = None,
         text_filter: Optional[BaseTextFilter] = None,
@@ -120,6 +122,7 @@ def __init__(
             pause_frame_processing: Whether to pause frame processing during audio generation.
             sample_rate: Output sample rate for generated audio.
             text_aggregator: Custom text aggregator for processing incoming text.
+            skip_aggregator_types: List of aggregation types that should not be spoken.
             text_filters: Sequence of text filters to apply after aggregation.
             text_filter: Single text filter (deprecated, use text_filters).
@@ -142,6 +145,7 @@ def __init__(
         self._voice_id: str = ""
         self._settings: Dict[str, Any] = {}
         self._text_aggregator: BaseTextAggregator = text_aggregator or SimpleTextAggregator()
+        self._skip_aggregator_types: List[str] = skip_aggregator_types or []
         self._text_filters: Sequence[BaseTextFilter] = text_filters or []
         self._transport_destination: Optional[str] = transport_destination
         self._tracing_enabled: bool = False
@@ -351,10 +355,14 @@ async def process_frame(self, frame: Frame, direction: FrameDirection):
             # pause to avoid audio overlapping.
             await self._maybe_pause_frame_processing()

-            sentence = self._text_aggregator.text
+            aggregate = self._text_aggregator.text
             await self._text_aggregator.reset()
             self._processing_text = False
-            await self._push_tts_frames(sentence)
+            await self._push_tts_frames(
+                text=aggregate.text,
+                should_speak=aggregate.type not in self._skip_aggregator_types,
+                aggregated_by=aggregate.type,
+            )
             if isinstance(frame, LLMFullResponseEndFrame):
                 if self._push_text_frames:
                     await self.push_frame(frame, direction)
@@ -363,7 +371,7 @@ async def process_frame(self, frame: Frame, direction: FrameDirection):
         elif isinstance(frame, TTSSpeakFrame):
             # Store if we were processing text or not so we can set it back.
             processing_text = self._processing_text
-            await self._push_tts_frames(frame.text)
+            await self._push_tts_frames(frame.text, should_speak=True, aggregated_by="word")
             # We pause processing incoming frames because we are sending data to
             # the TTS. We pause to avoid audio overlapping.
             await self._maybe_pause_frame_processing()
@@ -455,42 +463,53 @@ async def _process_text_frame(self, frame: TextFrame):
         text: Optional[str] = None
         if not self._aggregate_sentences:
             text = frame.text
+            should_speak = True
+            aggregated_by = "token"
         else:
-            text = await self._text_aggregator.aggregate(frame.text)
+            aggregate = await self._text_aggregator.aggregate(frame.text)
+            if aggregate:
+                text = aggregate.text
+                should_speak = aggregate.type not in self._skip_aggregator_types
+                aggregated_by = aggregate.type

         if text:
-            await self._push_tts_frames(text)
+            logger.trace(f"Pushing TTS frames for text: {text}, {should_speak}, {aggregated_by}")
+            await self._push_tts_frames(text, should_speak, aggregated_by)

-    async def _push_tts_frames(self, text: str):
-        # Remove leading newlines only
-        text = text.lstrip("\n")
+    async def _push_tts_frames(self, text: str, should_speak: bool, aggregated_by: str):
+        if should_speak:
+            # Remove leading newlines only
+            text = text.lstrip("\n")

-        # Don't send only whitespace. This causes problems for some TTS models. But also don't
-        # strip all whitespace, as whitespace can influence prosody.
-        if not text.strip():
-            return
+            # Don't send only whitespace. This causes problems for some TTS models. But also don't
+            # strip all whitespace, as whitespace can influence prosody.
+            if not text.strip():
+                return

-        # This is just a flag that indicates if we sent something to the TTS
-        # service. It will be cleared if we sent text because of a TTSSpeakFrame
-        # or when we received an LLMFullResponseEndFrame
-        self._processing_text = True
+            # This is just a flag that indicates if we sent something to the TTS
+            # service. It will be cleared if we sent text because of a TTSSpeakFrame
+            # or when we received an LLMFullResponseEndFrame
+            self._processing_text = True

-        await self.start_processing_metrics()
+            await self.start_processing_metrics()

-        # Process all filters.
-        for filter in self._text_filters:
-            await filter.reset_interruption()
-            text = await filter.filter(text)
+            # Process all filters.
+            for filter in self._text_filters:
+                await filter.reset_interruption()
+                text = await filter.filter(text)

-        if text:
-            await self.process_generator(self.run_tts(text))
+            if text:
+                await self.push_frame(TTSTextFrame(text, spoken=True, aggregated_by=aggregated_by))
+                await self.process_generator(self.run_tts(text))

-        await self.stop_processing_metrics()
+            await self.stop_processing_metrics()

-        if self._push_text_frames:
+        if self._push_text_frames or not should_speak:
             # We send the original text after the audio. This way, if we are
             # interrupted, the text is not added to the assistant context.
-            await self.push_frame(TTSTextFrame(text))
+            await self.push_frame(
+                TTSTextFrame(text, spoken=should_speak, aggregated_by=aggregated_by)
+            )

     async def _stop_frame_handler(self):
         has_started = False
@@ -616,7 +635,7 @@ async def _words_task_handler(self):
                     frame = TTSStoppedFrame()
                     frame.pts = last_pts
                 else:
-                    frame = TTSTextFrame(word)
+                    frame = TTSTextFrame(word, spoken=True, aggregated_by="word")
                     frame.pts = self._initial_word_timestamp + timestamp
                 if frame:
                     last_pts = frame.pts
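A sketch of opting text out of synthesis with the new constructor argument; CartesiaTTSService stands in for any TTSService subclass, and the "code" aggregation type is hypothetical, something a custom aggregator (rather than the default SimpleTextAggregator) might produce:

from pipecat.services.cartesia.tts import CartesiaTTSService
from pipecat.utils.text.simple_text_aggregator import SimpleTextAggregator

tts = CartesiaTTSService(
    api_key="...",
    voice_id="...",
    # Default sentence aggregation; a custom aggregator could return other types.
    text_aggregator=SimpleTextAggregator(),
    # Hypothetical custom type: aggregations tagged "code" by such an aggregator
    # would be pushed as TTSTextFrames with spoken=False instead of synthesized.
    skip_aggregator_types=["code"],
)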

src/pipecat/utils/text/base_text_aggregator.py

Lines changed: 34 additions & 4 deletions
@@ -12,9 +12,38 @@
 """

 from abc import ABC, abstractmethod
+from dataclasses import dataclass
 from typing import Optional


+@dataclass
+class Aggregation:
+    """Data class representing aggregated text and its type.
+
+    An Aggregation object is created whenever a stream of text is aggregated by
+    a text aggregator. It contains the aggregated text and a type indicating
+    the nature of the aggregation.
+    """
+
+    def __init__(self, text: str, type: str):
+        """Initialize an aggregation instance.
+
+        Args:
+            text: The aggregated text content.
+            type: The type of aggregation the text represents (e.g., 'sentence', 'word', 'token', 'my_custom_aggregation').
+        """
+        self.text = text
+        self.type = type
+
+    def __str__(self) -> str:
+        """Return a string representation of the aggregation.
+
+        Returns:
+            A descriptive string showing the type and text of the aggregation.
+        """
+        return f"Aggregation by {self.type}: {self.text}"
+
+
 class BaseTextAggregator(ABC):
     """Base class for text aggregators in the Pipecat framework.

3059

3160
@property
3261
@abstractmethod
33-
def text(self) -> str:
62+
def text(self) -> Aggregation:
3463
"""Get the currently aggregated text.
3564
3665
Subclasses must implement this property to return the text that has
@@ -42,12 +71,13 @@ def text(self) -> str:
4271
pass
4372

4473
@abstractmethod
45-
async def aggregate(self, text: str) -> Optional[str]:
74+
async def aggregate(self, text: str) -> Optional[Aggregation]:
4675
"""Aggregate the specified text with the currently accumulated text.
4776
4877
This method should be implemented to define how the new text contributes
49-
to the aggregation process. It returns the updated aggregated text if
50-
it's ready to be processed, or None otherwise.
78+
to the aggregation process. It returns the aggregated text and a string
79+
describing how it was aggregated if it's ready to be processed,
80+
or None otherwise.
5181
5282
Subclasses should implement their specific logic for:
5383
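To show the new contract end to end, a minimal custom aggregator sketch that returns typed Aggregation objects. It assumes the base class's remaining abstract hooks are handle_interruption() and reset(); check the base class for the exact set:

from typing import Optional

from pipecat.utils.text.base_text_aggregator import Aggregation, BaseTextAggregator


class WordAggregator(BaseTextAggregator):
    """Hypothetical aggregator that emits one Aggregation per whitespace-ended word."""

    def __init__(self):
        self._text = ""

    @property
    def text(self) -> Aggregation:
        # Whatever is buffered so far, tagged with this aggregator's type.
        return Aggregation(self._text, "word")

    async def aggregate(self, text: str) -> Optional[Aggregation]:
        self._text += text
        if self._text.endswith(" "):
            aggregation = Aggregation(self._text, "word")
            self._text = ""
            return aggregation
        return None  # not ready yet; keep buffering

    async def handle_interruption(self):
        self._text = ""

    async def reset(self):
        self._text = ""

Paired with skip_aggregator_types=["word"] on a TTS service, every aggregation from this class would be emitted as unspoken bot output.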
