Skip to content

Commit 5d0a355

Browse files
committed
Move aggregation logic when skip_tts is on to the assistant aggregator
1 parent 840b2d0 commit 5d0a355

File tree

3 files changed

+49
-38
lines changed

3 files changed

+49
-38
lines changed

src/pipecat/processors/aggregators/llm_response.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,6 @@
2222
from pipecat.audio.turn.smart_turn.base_smart_turn import SmartTurnParams
2323
from pipecat.audio.vad.vad_analyzer import VADParams
2424
from pipecat.frames.frames import (
25-
AggregatedLLMTextFrame,
2625
BotStartedSpeakingFrame,
2726
BotStoppedSpeakingFrame,
2827
CancelFrame,

src/pipecat/processors/aggregators/llm_response_universal.py

Lines changed: 48 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323
from pipecat.audio.turn.smart_turn.base_smart_turn import SmartTurnParams
2424
from pipecat.audio.vad.vad_analyzer import VADParams
2525
from pipecat.frames.frames import (
26+
AggregatedLLMTextFrame,
2627
BotStartedSpeakingFrame,
2728
BotStoppedSpeakingFrame,
2829
CancelFrame,
@@ -46,6 +47,7 @@
4647
LLMRunFrame,
4748
LLMSetToolChoiceFrame,
4849
LLMSetToolsFrame,
50+
LLMTextFrame,
4951
SpeechControlParamsFrame,
5052
StartFrame,
5153
TextFrame,
@@ -65,6 +67,7 @@
6567
LLMUserAggregatorParams,
6668
)
6769
from pipecat.processors.frame_processor import FrameDirection, FrameProcessor
70+
from pipecat.utils.string import match_endofsentence
6871
from pipecat.utils.time import time_now_iso8601
6972

7073

@@ -565,6 +568,9 @@ def __init__(
565568
self._function_calls_in_progress: Dict[str, Optional[FunctionCallInProgressFrame]] = {}
566569
self._context_updated_tasks: Set[asyncio.Task] = set()
567570

571+
self._llm_aggregation: str = ""
572+
self._skip_tts: Optional[bool] = None
573+
568574
@property
569575
def has_function_calls_in_progress(self) -> bool:
570576
"""Check if there are any function calls currently in progress.
@@ -588,6 +594,8 @@ async def process_frame(self, frame: Frame, direction: FrameDirection):
588594
await self.push_frame(frame, direction)
589595
elif isinstance(frame, LLMFullResponseStartFrame):
590596
await self._handle_llm_start(frame)
597+
elif isinstance(frame, LLMTextFrame):
598+
await self._handle_llm_text(frame)
591599
elif isinstance(frame, LLMFullResponseEndFrame):
592600
await self._handle_llm_end(frame)
593601
elif isinstance(frame, TextFrame):
@@ -787,12 +795,50 @@ async def _handle_user_image_frame(self, frame: UserImageRawFrame):
787795
await self.push_aggregation()
788796
await self.push_context_frame(FrameDirection.UPSTREAM)
789797

790-
async def _handle_llm_start(self, _: LLMFullResponseStartFrame):
798+
async def _handle_llm_start(self, frame: LLMFullResponseStartFrame):
791799
self._started += 1
800+
if self._skip_tts is None:
801+
self._skip_tts = frame.skip_tts
802+
await self._maybe_push_llm_aggregation(frame)
803+
804+
async def _handle_llm_text(self, frame: LLMTextFrame):
805+
await self._handle_text(frame)
806+
if self._skip_tts or frame.skip_tts:
807+
self._llm_aggregation += frame.text
808+
await self._maybe_push_llm_aggregation(frame)
792809

793-
async def _handle_llm_end(self, _: LLMFullResponseEndFrame):
810+
async def _handle_llm_end(self, frame: LLMFullResponseEndFrame):
794811
self._started -= 1
795812
await self.push_aggregation()
813+
await self._maybe_push_llm_aggregation(frame)
814+
815+
async def _maybe_push_llm_aggregation(
816+
self, frame: LLMFullResponseStartFrame | LLMTextFrame | LLMFullResponseEndFrame
817+
):
818+
should_push = False
819+
if self._skip_tts and not frame.skip_tts:
820+
# if the skip_tts flag switches to false, push the current aggregation
821+
should_push = True
822+
self._skip_tts = frame.skip_tts
823+
if self._skip_tts:
824+
if self._skip_tts and isinstance(frame, LLMFullResponseEndFrame):
825+
# on end frame, always push the aggregation
826+
should_push = True
827+
elif len(self._llm_aggregation) > 0 and match_endofsentence(self._llm_aggregation):
828+
# push aggregation on end of sentence
829+
should_push = True
830+
831+
if not should_push:
832+
return
833+
834+
text = self._llm_aggregation.lstrip("\n")
835+
if not text.strip():
836+
# don't push empty text
837+
return
838+
839+
llm_frame = AggregatedLLMTextFrame(text=text, aggregated_by="sentence")
840+
await self.push_frame(llm_frame)
841+
self._llm_aggregation = ""
796842

797843
async def _handle_text(self, frame: TextFrame):
798844
if not self._started or not frame.append_to_context:

src/pipecat/processors/frameworks/rtvi.py

Lines changed: 1 addition & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -984,8 +984,6 @@ def __init__(
984984
self._last_user_audio_level = 0
985985
self._last_bot_audio_level = 0
986986

987-
self._skip_tts = None
988-
989987
if self._params.system_logs_enabled:
990988
self._system_logger_id = logger.add(self._logger_sink)
991989

@@ -1024,16 +1022,6 @@ async def send_rtvi_message(self, model: BaseModel, exclude_none: bool = True):
10241022
if self._rtvi:
10251023
await self._rtvi.push_transport_message(model, exclude_none)
10261024

1027-
async def send_aggregated_llm_text(self, text: str, aggregated_by: Optional[str] = None):
1028-
"""Send aggregated LLM text as a bot output message.
1029-
1030-
Args:
1031-
text: The aggregated text to send.
1032-
aggregated_by: The method of aggregation (e.g., "word", "sentence").
1033-
"""
1034-
if self._rtvi:
1035-
await self._rtvi.push_aggregated_llm_text(text, aggregated_by)
1036-
10371025
async def on_push_frame(self, data: FramePushed):
10381026
"""Process a frame being pushed through the pipeline.
10391027
@@ -1171,30 +1159,14 @@ async def _handle_llm_text_frame(self, frame: LLMTextFrame):
11711159
message = RTVIBotLLMTextMessage(data=RTVITextMessageData(text=frame.text))
11721160
await self.send_rtvi_message(message)
11731161

1174-
# initialize skip_tts on first LLMTextFrame
1175-
if self._skip_tts is None:
1176-
self._skip_tts = frame.skip_tts
1177-
1178-
orig_text = self._bot_transcription
1162+
# TODO: Remove all this logic when we fully deprecate bot-transcription messages.
11791163
self._bot_transcription += frame.text
11801164

11811165
if match_endofsentence(self._bot_transcription) and len(self._bot_transcription) > 0:
1182-
# TODO: Remove this message when we fully deprecate bot-transcription messages.
11831166
await self.send_rtvi_message(
11841167
RTVIBotTranscriptionMessage(data=RTVITextMessageData(text=self._bot_transcription))
11851168
)
1186-
if frame.skip_tts:
1187-
await self.send_aggregated_llm_text(
1188-
text=self._bot_transcription, aggregated_by="sentence"
1189-
)
11901169
self._bot_transcription = ""
1191-
elif not frame.skip_tts and self._skip_tts:
1192-
# We just switched from skipping TTS to not skipping TTS.
1193-
# Send any dangling transcription.
1194-
if len(orig_text) > 0:
1195-
await self.send_aggregated_llm_text(text=orig_text, aggregated_by="sentence")
1196-
self._bot_transcription = frame.text
1197-
self._skip_tts = frame.skip_tts
11981170

11991171
async def _handle_user_transcriptions(self, frame: Frame):
12001172
"""Handle user transcription frames."""
@@ -1424,12 +1396,6 @@ async def push_transport_message(self, model: BaseModel, exclude_none: bool = Tr
14241396
)
14251397
await self.push_frame(frame)
14261398

1427-
async def push_aggregated_llm_text(self, text: str, aggregated_by: Optional[str] = None):
1428-
"""Push an aggregated LLM text frame."""
1429-
frame = AggregatedLLMTextFrame(text=text, aggregated_by=aggregated_by)
1430-
frame.skip_tts = True
1431-
await self.push_frame(frame)
1432-
14331399
async def handle_message(self, message: RTVIMessage):
14341400
"""Handle an incoming RTVI message.
14351401

0 commit comments

Comments
 (0)