Skip to content

Commit 89c7277

Browse files
committed
Various fixes:
1. Fixed pattern_pair_aggregator to support various ways of handling pattern matches (remove, keep and just trigger a callback, or aggregate 2. Fixed ivr_navigator use of pattern_pair_aggregator 3. Test fixes -- Tests now pass
1 parent fa68db9 commit 89c7277

File tree

6 files changed

+95
-48
lines changed

6 files changed

+95
-48
lines changed

examples/foundational/35-pattern-pair-voice-switching.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -111,7 +111,7 @@ async def run_bot(transport: BaseTransport, runner_args: RunnerArguments):
111111
start_pattern="<voice>",
112112
end_pattern="</voice>",
113113
type="voice",
114-
remove_match=True,
114+
action="remove", # Remove tags from final text
115115
)
116116

117117
# Register handler for voice switching

src/pipecat/extensions/ivr/ivr_navigator.py

Lines changed: 6 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -114,19 +114,15 @@ def _get_conversation_history(self) -> List[dict]:
114114
def _setup_xml_patterns(self):
115115
"""Set up XML pattern detection and handlers."""
116116
# Register DTMF pattern
117-
self._aggregator.add_pattern_pair(
118-
"dtmf", "<dtmf>", "</dtmf>", type="dtmf", remove_match=True
119-
)
117+
self._aggregator.add_pattern_pair("dtmf", "<dtmf>", "</dtmf>", type="dtmf", action="remove")
120118
self._aggregator.on_pattern_match("dtmf", self._handle_dtmf_action)
121119

122120
# Register mode pattern
123-
self._aggregator.add_pattern_pair(
124-
"mode", "<mode>", "</mode>", type="mode", remove_match=True
125-
)
121+
self._aggregator.add_pattern_pair("mode", "<mode>", "</mode>", type="mode", action="remove")
126122
self._aggregator.on_pattern_match("mode", self._handle_mode_action)
127123

128124
# Register IVR pattern
129-
self._aggregator.add_pattern_pair("ivr", "<ivr>", "</ivr>", type="ivr", remove_match=True)
125+
self._aggregator.add_pattern_pair("ivr", "<ivr>", "</ivr>", type="ivr", action="remove")
130126
self._aggregator.on_pattern_match("ivr", self._handle_ivr_action)
131127

132128
async def process_frame(self, frame: Frame, direction: FrameDirection):
@@ -163,7 +159,7 @@ async def _handle_dtmf_action(self, match: PatternMatch):
163159
Args:
164160
match: The pattern match containing DTMF content.
165161
"""
166-
value = match.content
162+
value = match.text
167163
logger.debug(f"DTMF detected: {value}")
168164

169165
try:
@@ -184,7 +180,7 @@ async def _handle_ivr_action(self, match: PatternMatch):
184180
Args:
185181
match: The pattern match containing IVR status content.
186182
"""
187-
status = match.content
183+
status = match.text
188184
logger.trace(f"IVR status detected: {status}")
189185

190186
# Convert string to enum, with validation
@@ -215,7 +211,7 @@ async def _handle_mode_action(self, match: PatternMatch):
215211
Args:
216212
match: The pattern match containing mode content.
217213
"""
218-
mode = match.content
214+
mode = match.text
219215
logger.debug(f"Mode detected: {mode}")
220216
if mode == "conversation":
221217
await self._handle_conversation()

src/pipecat/utils/text/pattern_pair_aggregator.py

Lines changed: 30 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
"""
1313

1414
import re
15-
from typing import Awaitable, Callable, List, Optional, Tuple
15+
from typing import Awaitable, Callable, List, Literal, Optional, Tuple
1616

1717
from loguru import logger
1818

@@ -83,9 +83,9 @@ def text(self) -> Aggregation:
8383
Returns:
8484
The text that has been accumulated in the buffer.
8585
"""
86-
start, curtype = self._match_start_of_pattern(self._text)
87-
if curtype:
88-
return Aggregation(self._text, curtype)
86+
pattern_start = self._match_start_of_pattern(self._text)
87+
if pattern_start:
88+
return Aggregation(self._text, pattern_start[1].get("type", "sentence"))
8989
return Aggregation(self._text, "sentence")
9090

9191
def add_pattern_pair(
@@ -94,7 +94,7 @@ def add_pattern_pair(
9494
start_pattern: str,
9595
end_pattern: str,
9696
type: str,
97-
remove_match: bool = True,
97+
action: Literal["remove", "keep", "aggregate"] = "remove",
9898
) -> "PatternPairAggregator":
9999
"""Add a pattern pair to detect in the text.
100100
@@ -108,7 +108,12 @@ def add_pattern_pair(
108108
end_pattern: Pattern that marks the end of content.
109109
type: The type of aggregation the matched content represents
110110
(e.g., 'code', 'speaker', 'custom').
111-
remove_match: Whether to remove the matched content from the text returned.
111+
action: What to do when a complete pattern is matched:
112+
- "remove": Remove the matched pattern from the text.
113+
- "keep": Keep the matched pattern in the text and treat it as
114+
normal text.
115+
- "aggregate": Return the matched pattern as a separate
116+
aggregation object.
112117
113118
Returns:
114119
Self for method chaining.
@@ -117,7 +122,7 @@ def add_pattern_pair(
117122
"start": start_pattern,
118123
"end": end_pattern,
119124
"type": type,
120-
"remove_match": remove_match,
125+
"action": action,
121126
}
122127
return self
123128

@@ -162,7 +167,7 @@ async def _process_complete_patterns(self, text: str) -> Tuple[List[PatternMatch
162167
# Escape special regex characters in the patterns
163168
start = re.escape(pattern_info["start"])
164169
end = re.escape(pattern_info["end"])
165-
remove_match = pattern_info["remove_match"]
170+
action = pattern_info["action"]
166171
match_type = pattern_info["type"]
167172

168173
# Create regex to match from start pattern to end pattern
@@ -190,15 +195,15 @@ async def _process_complete_patterns(self, text: str) -> Tuple[List[PatternMatch
190195
logger.error(f"Error in pattern handler for {pattern_id}: {e}")
191196

192197
# Remove the pattern from the text if configured
193-
if remove_match:
198+
if action == "remove":
194199
processed_text = processed_text.replace(full_match, "", 1)
195200
# modified = True
196201
else:
197202
all_matches.append(pattern_match)
198203

199204
return all_matches, processed_text
200205

201-
def _match_start_of_pattern(self, text: str) -> Optional[Tuple[int, str]]:
206+
def _match_start_of_pattern(self, text: str) -> Optional[Tuple[int, dict]]:
202207
"""Check if text contains incomplete pattern pairs.
203208
204209
Determines whether the text contains any start patterns without
@@ -225,9 +230,9 @@ def _match_start_of_pattern(self, text: str) -> Optional[Tuple[int, str]]:
225230
# Which is why we base the return on the first found.
226231
if start_count > end_count:
227232
start_index = text.find(start)
228-
return [start_index, pattern_info["type"]]
233+
return [start_index, pattern_info]
229234

230-
return None, None
235+
return None
231236

232237
async def aggregate(self, text: str) -> Optional[PatternMatch]:
233238
"""Aggregate text and process pattern pairs.
@@ -258,17 +263,22 @@ async def aggregate(self, text: str) -> Optional[PatternMatch]:
258263
logger.warning(
259264
f"Multiple patterns matched: {[p.pattern_id for p in patterns]}. Only the first pattern will be returned."
260265
)
261-
self._text = ""
262-
return patterns[0]
266+
# If the pattern found is set to be aggregated, return it
267+
action = self._patterns[patterns[0].pattern_id].get("action", "remove")
268+
if action == "aggregate":
269+
self._text = ""
270+
print(f"Returning pattern: {patterns[0]}")
271+
return patterns[0]
263272

264273
# Check if we have incomplete patterns
265-
start, curtype = self._match_start_of_pattern(self._text)
266-
if start is not None:
267-
# Still waiting for complete patterns
268-
if start == 0:
274+
pattern_start = self._match_start_of_pattern(self._text)
275+
if pattern_start is not None:
276+
# If the start pattern is at the beginning or should not be separately aggregated, return None
277+
if pattern_start[0] == 0 or pattern_start[1].get("action", "remove") != "aggregate":
269278
return None
270-
result = self._text[:start]
271-
self._text = self._text[start:]
279+
# Otherwise, strip the text up to the start pattern and return it
280+
result = self._text[: pattern_start[0]]
281+
self._text = self._text[pattern_start[0] :]
272282
return PatternMatch(f"_sentence", result, result, "sentence")
273283

274284
# Find sentence boundary if no incomplete patterns

tests/test_pattern_pair_aggregator.py

Lines changed: 45 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -14,29 +14,32 @@ class TestPatternPairAggregator(unittest.IsolatedAsyncioTestCase):
1414
def setUp(self):
1515
self.aggregator = PatternPairAggregator()
1616
self.test_handler = AsyncMock()
17+
self.code_handler = AsyncMock()
1718

1819
# Add a test pattern
1920
self.aggregator.add_pattern_pair(
2021
pattern_id="test_pattern",
2122
start_pattern="<test>",
2223
end_pattern="</test>",
2324
type="test",
24-
remove_match=True,
25+
action="remove",
2526
)
2627
self.aggregator.add_pattern_pair(
2728
pattern_id="code_pattern",
2829
start_pattern="<code>",
2930
end_pattern="</code>",
3031
type="code",
31-
remove_match=False,
32+
action="aggregate",
3233
)
3334

3435
# Register the mock handler
3536
self.aggregator.on_pattern_match("test_pattern", self.test_handler)
37+
self.aggregator.on_pattern_match("code_pattern", self.code_handler)
3638

3739
async def test_pattern_match_and_removal(self):
3840
# First part doesn't complete the pattern
3941
result = await self.aggregator.aggregate("Hello <test>pattern")
42+
print(f"result: {result}")
4043
self.assertIsNone(result)
4144
self.assertEqual(self.aggregator.text.text, "Hello <test>pattern")
4245
self.assertEqual(self.aggregator.text.type, "test")
@@ -50,7 +53,7 @@ async def test_pattern_match_and_removal(self):
5053
self.assertIsInstance(call_args, PatternMatch)
5154
self.assertEqual(call_args.pattern_id, "test_pattern")
5255
self.assertEqual(call_args.full_match, "<test>pattern content</test>")
53-
self.assertEqual(call_args.content, "pattern content")
56+
self.assertEqual(call_args.text, "pattern content")
5457

5558
# The exclamation point should be treated as a sentence boundary,
5659
# so the result should include just text up to and including "!"
@@ -64,6 +67,33 @@ async def test_pattern_match_and_removal(self):
6467
# Buffer should be empty after returning a complete sentence
6568
self.assertEqual(self.aggregator.text.text, "")
6669

70+
async def test_pattern_match_and_aggregate(self):
71+
# First part doesn't complete the pattern
72+
result = await self.aggregator.aggregate("Here is code <code>pattern")
73+
print(f"result: {result}")
74+
self.assertEqual(result.text, "Here is code ")
75+
self.assertEqual(self.aggregator.text.text, "<code>pattern")
76+
self.assertEqual(self.aggregator.text.type, "code")
77+
78+
# Second part completes the pattern and includes an exclamation point
79+
result = await self.aggregator.aggregate(" content</code>")
80+
81+
# Verify the handler was called with correct PatternMatch object
82+
self.code_handler.assert_called_once()
83+
call_args = self.code_handler.call_args[0][0]
84+
self.assertIsInstance(call_args, PatternMatch)
85+
self.assertEqual(call_args.pattern_id, "code_pattern")
86+
self.assertEqual(call_args.full_match, "<code>pattern content</code>")
87+
self.assertEqual(call_args.text, "pattern content")
88+
89+
# Next sentence should be processed separately
90+
result = await self.aggregator.aggregate(" This is another sentence.")
91+
self.assertEqual(result.text, " This is another sentence.")
92+
self.assertEqual(result.type, "sentence")
93+
94+
# Buffer should be empty after returning a complete sentence
95+
self.assertEqual(self.aggregator.text.text, "")
96+
6797
async def test_incomplete_pattern(self):
6898
# Add text with incomplete pattern
6999
result = await self.aggregator.aggregate("Hello <test>pattern content")
@@ -88,14 +118,19 @@ async def test_multiple_patterns(self):
88118
emphasis_handler = AsyncMock()
89119

90120
self.aggregator.add_pattern_pair(
91-
pattern_id="voice", start_pattern="<voice>", end_pattern="</voice>", remove_match=True
121+
pattern_id="voice",
122+
start_pattern="<voice>",
123+
end_pattern="</voice>",
124+
type="voice",
125+
action="remove",
92126
)
93127

94128
self.aggregator.add_pattern_pair(
95129
pattern_id="emphasis",
96130
start_pattern="<em>",
97131
end_pattern="</em>",
98-
remove_match=False, # Keep emphasis tags
132+
type="emphasis",
133+
action="keep", # Keep emphasis tags
99134
)
100135

101136
self.aggregator.on_pattern_match("voice", voice_handler)
@@ -109,15 +144,15 @@ async def test_multiple_patterns(self):
109144
voice_handler.assert_called_once()
110145
voice_match = voice_handler.call_args[0][0]
111146
self.assertEqual(voice_match.pattern_id, "voice")
112-
self.assertEqual(voice_match.content, "female")
147+
self.assertEqual(voice_match.text, "female")
113148

114149
emphasis_handler.assert_called_once()
115150
emphasis_match = emphasis_handler.call_args[0][0]
116151
self.assertEqual(emphasis_match.pattern_id, "emphasis")
117-
self.assertEqual(emphasis_match.content, "very")
152+
self.assertEqual(emphasis_match.text, "very")
118153

119154
# Voice pattern should be removed, emphasis pattern should remain
120-
self.assertEqual(result, "Hello I am <em>very</em> excited to meet you!")
155+
self.assertEqual(result.text, "Hello I am <em>very</em> excited to meet you!")
121156

122157
# Buffer should be empty
123158
self.assertEqual(self.aggregator.text.text, "")
@@ -149,10 +184,10 @@ async def test_pattern_across_sentences(self):
149184
# Handler should be called with entire content
150185
self.test_handler.assert_called_once()
151186
call_args = self.test_handler.call_args[0][0]
152-
self.assertEqual(call_args.content, "This is sentence one. This is sentence two.")
187+
self.assertEqual(call_args.text, "This is sentence one. This is sentence two.")
153188

154189
# Pattern should be removed, resulting in text with sentences merged
155-
self.assertEqual(result, "Hello Final sentence.")
190+
self.assertEqual(result.text, "Hello Final sentence.")
156191

157192
# Buffer should be empty
158193
self.assertEqual(self.aggregator.text.text, "")

tests/test_piper_tts.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -74,6 +74,7 @@ async def handler(request):
7474
]
7575

7676
expected_returned_frames = [
77+
TTSTextFrame,
7778
TTSStartedFrame,
7879
TTSAudioRawFrame,
7980
TTSAudioRawFrame,
@@ -121,7 +122,7 @@ async def handler(_request):
121122
TTSSpeakFrame(text="Error case."),
122123
]
123124

124-
expected_down_frames = [TTSStoppedFrame, TTSTextFrame]
125+
expected_down_frames = [TTSTextFrame, TTSStoppedFrame, TTSTextFrame]
125126

126127
expected_up_frames = [ErrorFrame]
127128

tests/test_simple_text_aggregator.py

Lines changed: 11 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -15,15 +15,20 @@ def setUp(self):
1515

1616
async def test_reset_aggregations(self):
1717
assert await self.aggregator.aggregate("Hello ") == None
18-
assert self.aggregator.text == "Hello "
18+
assert self.aggregator.text.text == "Hello "
1919
await self.aggregator.reset()
20-
assert self.aggregator.text == ""
20+
assert self.aggregator.text.text == ""
2121

2222
async def test_simple_sentence(self):
2323
assert await self.aggregator.aggregate("Hello ") == None
24-
assert await self.aggregator.aggregate("Pipecat!") == "Hello Pipecat!"
25-
assert self.aggregator.text == ""
24+
aggregate = await self.aggregator.aggregate("Pipecat!")
25+
assert aggregate.text == "Hello Pipecat!"
26+
assert aggregate.type == "sentence"
27+
assert self.aggregator.text.text == ""
2628

2729
async def test_multiple_sentences(self):
28-
assert await self.aggregator.aggregate("Hello Pipecat! How are ") == "Hello Pipecat!"
29-
assert await self.aggregator.aggregate("you?") == " How are you?"
30+
aggregate = await self.aggregator.aggregate("Hello Pipecat! How are ")
31+
assert aggregate.text == "Hello Pipecat!"
32+
assert self.aggregator.text.text == " How are "
33+
aggregate = await self.aggregator.aggregate("you?")
34+
assert aggregate.text == " How are you?"

0 commit comments

Comments
 (0)