Skip to content

Commit 713b488

Browse files
committed
Final PR Feedback changes
1 parent 71b87fd commit 713b488

File tree

5 files changed

+56
-43
lines changed

5 files changed

+56
-43
lines changed

src/pipecat/services/tts_service.py

Lines changed: 28 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -584,36 +584,38 @@ async def _push_tts_frames(self, src_frame: AggregatedTextFrame):
584584
await filter.reset_interruption()
585585
text = await filter.filter(text)
586586

587-
if text:
588-
if not self._push_text_frames:
589-
# In a typical pipeline, there is an assistant context aggregator
590-
# that listens for TTSTextFrames to add spoken text to the context.
591-
# If the TTS service supports word timestamps, then _push_text_frames
592-
# is set to False and these are sent word by word as part of the
593-
# _words_task_handler in the WordTTSService subclass. However, to
594-
# support use cases where an observer may want the full text before
595-
# the audio is generated, we send along the AggregatedTextFrame here,
596-
# but we set append_to_context to False so it does not cause duplication
597-
# in the context. This is primarily used by the RTVIObserver to
598-
# generate a complete bot-output.
599-
src_frame.append_to_context = False
600-
await self.push_frame(src_frame)
601-
# Note: Text transformations only affect the text sent to the TTS. This allows
602-
# for explicit TTS-specific modifications (e.g., inserting TTS supported tags
603-
# for spelling or emotion or replacing an @ with "at"). For TTS services that
604-
# support word-level timestamps, this DOES affect the resulting context as the
605-
# the context is built from the TTSTextFrames generated during word timestamping.
606-
for aggregation_type, transform in self._text_transforms:
607-
if aggregation_type == type or aggregation_type == "*":
608-
text = await transform(text, type)
609-
await self.process_generator(self.run_tts(text))
587+
if not text.strip():
588+
await self.stop_processing_metrics()
589+
return
590+
591+
# To support use cases that may want to know the text before it's spoken, we
592+
# push the AggregatedTextFrame version before transforming and sending to TTS.
593+
# However, we do not want to add this text to the assistant context until it
594+
# is spoken, so we set append_to_context to False.
595+
src_frame.append_to_context = False
596+
await self.push_frame(src_frame)
597+
598+
# Note: Text transformations are meant to only affect the text sent to the TTS for
599+
# TTS-specific purposes. This allows for explicit TTS modifications (e.g., inserting
600+
# TTS supported tags for spelling or emotion or replacing an @ with "at"). For TTS
601+
# services that support word-level timestamps, this CAN affect the resulting context
602+
# since the TTSTextFrames are generated from the TTS output stream
603+
transformed_text = text
604+
for aggregation_type, transform in self._text_transforms:
605+
if aggregation_type == type or aggregation_type == "*":
606+
transformed_text = await transform(transformed_text, type)
607+
await self.process_generator(self.run_tts(transformed_text))
610608

611609
await self.stop_processing_metrics()
612610

613611
if self._push_text_frames:
614-
# In the case where the TTS service does not support word timestamps,
615-
# we send the full aggregated text after the audio. This way, if we are
616-
# interrupted, the text is not added to the assistant context.
612+
# In TTS services that support word timestamps, the TTSTextFrames
613+
# are pushed as words are spoken. However, in the case where the TTS service
614+
# does not support word timestamps (i.e. _push_text_frames is True), we send
615+
# the original (non-transformed) text after the TTS generation has completed.
616+
# This way, if we are interrupted, the text is not added to the assistant
617+
# context and the context that IS added does not include TTS-specific tags
618+
# or transformations.
617619
frame = TTSTextFrame(text, aggregated_by=type)
618620
frame.includes_inter_frame_spaces = self.includes_inter_frame_spaces
619621
await self.push_frame(frame)

src/pipecat/utils/text/base_text_aggregator.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,9 +13,20 @@
1313

1414
from abc import ABC, abstractmethod
1515
from dataclasses import dataclass
16+
from enum import Enum
1617
from typing import Optional
1718

1819

20+
class AggregationType(str, Enum):
21+
"""Built-in aggregation strings."""
22+
23+
SENTENCE = "sentence"
24+
WORD = "word"
25+
26+
def __str__(self):
27+
return self.value
28+
29+
1930
@dataclass
2031
class Aggregation:
2132
"""Data class representing aggregated text and its type.

src/pipecat/utils/text/pattern_pair_aggregator.py

Lines changed: 11 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@
1818
from loguru import logger
1919

2020
from pipecat.utils.string import match_endofsentence
21-
from pipecat.utils.text.base_text_aggregator import Aggregation, BaseTextAggregator
21+
from pipecat.utils.text.base_text_aggregator import Aggregation, AggregationType, BaseTextAggregator
2222

2323

2424
class MatchAction(Enum):
@@ -110,8 +110,8 @@ def text(self) -> Aggregation:
110110
"""
111111
pattern_start = self._match_start_of_pattern(self._text)
112112
if pattern_start:
113-
return Aggregation(self._text, pattern_start[1].get("type", "sentence"))
114-
return Aggregation(self._text, "sentence")
113+
return Aggregation(self._text, pattern_start[1].get("type", AggregationType.SENTENCE))
114+
return Aggregation(self._text, AggregationType.SENTENCE)
115115

116116
def add_pattern(
117117
self,
@@ -128,8 +128,8 @@ def add_pattern(
128128
129129
Args:
130130
type: Identifier for this pattern pair. Should be unique and ideally descriptive.
131-
(e.g., 'code', 'speaker', 'custom'). type can not be 'sentence' as that is
132-
reserved for the default behavior.
131+
(e.g., 'code', 'speaker', 'custom'). type can not be 'sentence' or 'word' as
132+
those are reserved for the default behavior.
133133
start_pattern: Pattern that marks the beginning of content.
134134
end_pattern: Pattern that marks the end of content.
135135
action: What to do when a complete pattern is matched:
@@ -143,9 +143,9 @@ def add_pattern(
143143
Returns:
144144
Self for method chaining.
145145
"""
146-
if type == "sentence":
146+
if type in [AggregationType.SENTENCE, AggregationType.WORD]:
147147
raise ValueError(
148-
"The aggregation type 'sentence' is reserved for default behavior and can not be used for custom patterns."
148+
f"The aggregation type '{type}' is reserved for default behavior and can not be used for custom patterns."
149149
)
150150
self._patterns[type] = {
151151
"start": start_pattern,
@@ -169,8 +169,8 @@ def add_pattern_pair(
169169
170170
Args:
171171
pattern_id: Identifier for this pattern pair. Should be unique and ideally descriptive.
172-
(e.g., 'code', 'speaker', 'custom'). pattern_id can not be 'sentence' as that is
173-
reserved for the default behavior.
172+
(e.g., 'code', 'speaker', 'custom'). pattern_id can not be 'sentence' or 'word'
173+
as those are reserved for the default behavior.
174174
start_pattern: Pattern that marks the beginning of content.
175175
end_pattern: Pattern that marks the end of content.
176176
remove_match: If True, the matched pattern will be removed from the text. (Same as MatchAction.REMOVE)
@@ -345,15 +345,15 @@ async def aggregate(self, text: str) -> Optional[PatternMatch]:
345345
# Otherwise, strip the text up to the start pattern and return it
346346
result = self._text[: pattern_start[0]]
347347
self._text = self._text[pattern_start[0] :]
348-
return PatternMatch(content=result, type="sentence", full_match=result)
348+
return PatternMatch(content=result, type=AggregationType.SENTENCE, full_match=result)
349349

350350
# Find sentence boundary if no incomplete patterns
351351
eos_marker = match_endofsentence(self._text)
352352
if eos_marker:
353353
# Extract text up to the sentence boundary
354354
result = self._text[:eos_marker]
355355
self._text = self._text[eos_marker:]
356-
return PatternMatch(content=result, type="sentence", full_match=result)
356+
return PatternMatch(content=result, type=AggregationType.SENTENCE, full_match=result)
357357

358358
# No complete sentence found yet
359359
return None

src/pipecat/utils/text/simple_text_aggregator.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414
from typing import Optional
1515

1616
from pipecat.utils.string import match_endofsentence
17-
from pipecat.utils.text.base_text_aggregator import Aggregation, BaseTextAggregator
17+
from pipecat.utils.text.base_text_aggregator import Aggregation, AggregationType, BaseTextAggregator
1818

1919

2020
class SimpleTextAggregator(BaseTextAggregator):
@@ -39,7 +39,7 @@ def text(self) -> Aggregation:
3939
Returns:
4040
The text that has been accumulated in the buffer.
4141
"""
42-
return Aggregation(self._text, "sentence")
42+
return Aggregation(self._text, AggregationType.SENTENCE)
4343

4444
async def aggregate(self, text: str) -> Optional[Aggregation]:
4545
"""Aggregate text and return completed sentences.
@@ -64,7 +64,7 @@ async def aggregate(self, text: str) -> Optional[Aggregation]:
6464
result = self._text[:eos_end_marker]
6565
self._text = self._text[eos_end_marker:]
6666

67-
return Aggregation(result, "sentence") if result else None
67+
return Aggregation(result, AggregationType.SENTENCE) if result else None
6868

6969
async def handle_interruption(self):
7070
"""Handle interruptions by clearing the text buffer.

src/pipecat/utils/text/skip_tags_aggregator.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414
from typing import Optional, Sequence
1515

1616
from pipecat.utils.string import StartEndTags, match_endofsentence, parse_start_end_tags
17-
from pipecat.utils.text.base_text_aggregator import Aggregation, BaseTextAggregator
17+
from pipecat.utils.text.base_text_aggregator import Aggregation, AggregationType, BaseTextAggregator
1818

1919

2020
class SkipTagsAggregator(BaseTextAggregator):
@@ -49,7 +49,7 @@ def text(self) -> str:
4949
Returns:
5050
The current text buffer content that hasn't been processed yet.
5151
"""
52-
return Aggregation(self._text, "sentence")
52+
return Aggregation(self._text, AggregationType.SENTENCE)
5353

5454
async def aggregate(self, text: str) -> Optional[Aggregation]:
5555
"""Aggregate text while respecting tag boundaries.
@@ -80,7 +80,7 @@ async def aggregate(self, text: str) -> Optional[Aggregation]:
8080
# Extract text up to the sentence boundary
8181
result = self._text[:eos_marker]
8282
self._text = self._text[eos_marker:]
83-
return Aggregation(result, "sentence")
83+
return Aggregation(result, AggregationType.SENTENCE)
8484

8585
# No complete sentence found yet
8686
return None

0 commit comments

Comments
 (0)