Skip to content

Commit 4254954

Browse files
committed
add transformers to initialization args
1 parent 6b2d9b3 commit 4254954

File tree

2 files changed

+43
-8
lines changed

2 files changed

+43
-8
lines changed

src/pipecat/processors/frameworks/rtvi.py

Lines changed: 21 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -937,6 +937,10 @@ class RTVIObserverParams:
937937
skip_aggregator_types: List of aggregation types to skip sending as tts/output messages.
938938
Note: if using this to avoid sending secure information, be sure to also disable
939939
bot_llm_enabled to avoid leaking through LLM messages.
940+
bot_output_transforms: A list of callables to transform text just before sending it
941+
to TTS. Each async callable takes the aggregated text and its type, and returns the
942+
transformed text. To register, provide a list of tuples of
943+
(aggregation_type | '*', transform_function).
940944
audio_level_period_secs: How often audio levels should be sent if enabled.
941945
"""
942946

@@ -953,6 +957,14 @@ class RTVIObserverParams:
953957
system_logs_enabled: bool = False
954958
errors_enabled: Optional[bool] = None
955959
skip_aggregator_types: Optional[List[AggregationType | str]] = None
960+
bot_output_transforms: Optional[
961+
List[
962+
Tuple[
963+
AggregationType | str,
964+
Callable[[str, AggregationType | str], Awaitable[str]],
965+
]
966+
]
967+
] = None
956968
audio_level_period_secs: float = 0.15
957969

958970

@@ -1005,15 +1017,17 @@ def __init__(
10051017
DeprecationWarning,
10061018
)
10071019

1008-
self._aggregation_transforms: List[Tuple[str, Callable[[str, str], Awaitable[str]]]] = []
1020+
self._aggregation_transforms: List[
1021+
Tuple[AggregationType | str, Callable[[str, AggregationType | str], Awaitable[str]]]
1022+
] = self._params.bot_output_transforms or []
10091023

10101024
def add_bot_output_transformer(
1011-
self, transform_function: Callable[[str, str], Awaitable[str]], aggregation_type: str = "*"
1025+
self,
1026+
transform_function: Callable[[str, AggregationType | str], Awaitable[str]],
1027+
aggregation_type: AggregationType | str = "*",
10121028
):
10131029
"""Transform text for a specific aggregation type before sending as Bot Output or TTS.
10141030
1015-
# TODO: What if someone wanted to remove a registered transform?
1016-
10171031
Args:
10181032
transform_function: The function to apply for transformation. This function should take
10191033
the text and aggregation type as input and return the transformed text.
@@ -1024,7 +1038,9 @@ def add_bot_output_transformer(
10241038
self._aggregation_transforms.append((aggregation_type, transform_function))
10251039

10261040
def remove_bot_output_transformer(
1027-
self, transform_function: Callable[[str, str], Awaitable[str]], aggregation_type: str = "*"
1041+
self,
1042+
transform_function: Callable[[str, AggregationType | str], Awaitable[str]],
1043+
aggregation_type: AggregationType | str = "*",
10281044
):
10291045
"""Remove a text transformer for a specific aggregation type.
10301046

src/pipecat/services/tts_service.py

Lines changed: 22 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -107,6 +107,14 @@ def __init__(
107107
text_aggregator: Optional[BaseTextAggregator] = None,
108108
# Types of text aggregations that should not be spoken.
109109
skip_aggregator_types: Optional[List[str]] = [],
110+
# A list of callables to transform text just before sending it to TTS.
111+
# Each async callable takes the aggregated text and its type, and returns the transformed text.
112+
# To register, provide a list of tuples of (aggregation_type | '*', transform_function).
113+
text_transforms: Optional[
114+
List[
115+
Tuple[AggregationType | str, Callable[[str, str | AggregationType], Awaitable[str]]]
116+
]
117+
] = None,
110118
# Text filter executed after text has been aggregated.
111119
text_filters: Optional[Sequence[BaseTextFilter]] = None,
112120
text_filter: Optional[BaseTextFilter] = None,
@@ -131,6 +139,11 @@ def __init__(
131139
Use an LLMTextProcessor before the TTSService for custom text aggregation.
132140
133141
skip_aggregator_types: List of aggregation types that should not be spoken.
142+
text_transforms: A list of callables to transform text just before sending it
143+
to TTS. Each async callable takes the aggregated text and its type, and returns the
144+
transformed text. To register, provide a list of tuples of
145+
(aggregation_type | '*', transform_function).
146+
134147
text_filters: Sequence of text filters to apply after aggregation.
135148
text_filter: Single text filter (deprecated, use text_filters).
136149
@@ -164,7 +177,9 @@ def __init__(
164177
)
165178

166179
self._skip_aggregator_types: List[str] = skip_aggregator_types or []
167-
self._text_transforms: List[Tuple[str, Callable[[str, str], Awaitable[str]]]] = []
180+
self._text_transforms: List[
181+
Tuple[AggregationType | str, Callable[[str, AggregationType | str], Awaitable[str]]]
182+
] = text_transforms or []
168183
# TODO: Deprecate _text_filters when added to LLMTextProcessor
169184
self._text_filters: Sequence[BaseTextFilter] = text_filters or []
170185
self._transport_destination: Optional[str] = transport_destination
@@ -323,7 +338,9 @@ async def cancel(self, frame: CancelFrame):
323338
self._stop_frame_task = None
324339

325340
def add_text_transformer(
326-
self, transform_function: Callable[[str, str], Awaitable[str]], aggregation_type: str = "*"
341+
self,
342+
transform_function: Callable[[str, AggregationType | str], Awaitable[str]],
343+
aggregation_type: AggregationType | str = "*",
327344
):
328345
"""Transform text for a specific aggregation type.
329346
@@ -337,7 +354,9 @@ def add_text_transformer(
337354
self._text_transforms.append((aggregation_type, transform_function))
338355

339356
def remove_text_transformer(
340-
self, transform_function: Callable[[str, str], Awaitable[str]], aggregation_type: str = "*"
357+
self,
358+
transform_function: Callable[[str, AggregationType | str], Awaitable[str]],
359+
aggregation_type: AggregationType | str = "*",
341360
):
342361
"""Remove a text transformer for a specific aggregation type.
343362

0 commit comments

Comments
 (0)