
Commit fbf1874

Merge pull request #2 from liberate-org/exp/interrupt (Exp/interrupt)
2 parents: 96465e6 + 5a6fce9
File tree

3 files changed: +95, -2 lines changed

vocode/streaming/agent/chat_gpt_agent.py

Lines changed: 2 additions & 1 deletion
@@ -55,7 +55,7 @@ def __init__(
         else:
             self.openaiAsyncClient = AsyncOpenAI(
                 base_url = "https://api.openai.com/v1",
-                api_key = openai_api_key or getenv("OPENAI_API_KEY")
+                api_key = openai_api_key or getenv("OPENAI_API_KEY"),
             )
             self.openaiSyncClient = OpenAI(
                 base_url = "https://api.openai.com/v1",
@@ -141,6 +141,7 @@ async def respond(
             text = self.first_response
         else:
             chat_parameters = self.get_chat_parameters()
+            chat_parameters["stream"] = True
             # chat_completion = await openai.ChatCompletion.acreate(**chat_parameters)
             chat_completion = await self.openaiAsyncClient.chat.completions.create(**chat_parameters)
             text = chat_completion.choices[0].message.content
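With the 1.x openai Python client, passing stream=True makes chat.completions.create return an async stream of chunks rather than a single completion object, so the unchanged line that reads chat_completion.choices[0].message.content would likely need a follow-up change. A minimal sketch (not taken from this repo; model name and helper are illustrative only) of how a streamed chat completion is typically aggregated:

    import os

    from openai import AsyncOpenAI


    async def stream_chat_text(messages: list[dict]) -> str:
        # Assumes OPENAI_API_KEY is set; the model name is only an example.
        client = AsyncOpenAI(api_key=os.getenv("OPENAI_API_KEY"))
        stream = await client.chat.completions.create(
            model="gpt-3.5-turbo",
            messages=messages,
            stream=True,
        )
        pieces = []
        async for chunk in stream:
            delta = chunk.choices[0].delta.content
            if delta:  # the final chunk's delta.content is None
                pieces.append(delta)
        return "".join(pieces)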

vocode/streaming/models/synthesizer.py

Lines changed: 2 additions & 0 deletions
@@ -43,6 +43,8 @@ class SynthesizerConfig(TypedModel, type=SynthesizerType.BASE.value):
     audio_encoding: AudioEncoding
     should_encode_as_wav: bool = False
     sentiment_config: Optional[SentimentConfig] = None
+    reengage_timeout: Optional[float] = None
+    reengage_options: Optional[List[str]] = None

     class Config:
         arbitrary_types_allowed = True
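Both new fields are optional, so re-engagement stays off unless a config opts in. A hedged sketch of enabling it, assuming the base config's required fields are sampling_rate and audio_encoding as in upstream vocode (values here are illustrative):

    from vocode.streaming.models.audio_encoding import AudioEncoding
    from vocode.streaming.models.synthesizer import SynthesizerConfig

    config = SynthesizerConfig(
        sampling_rate=8000,
        audio_encoding=AudioEncoding.MULAW,
        reengage_timeout=10.0,  # seconds of mutual silence before re-engaging
        reengage_options=[
            "Are you still there?",
            "Just checking in, can you hear me?",
        ],
    )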

vocode/streaming/streaming_conversation.py

Lines changed: 91 additions & 1 deletion
@@ -73,7 +73,6 @@

 OutputDeviceType = TypeVar("OutputDeviceType", bound=BaseOutputDevice)

-
 class StreamingConversation(Generic[OutputDeviceType]):
     class QueueingInterruptibleEventFactory(InterruptibleEventFactory):
         def __init__(self, conversation: "StreamingConversation"):
@@ -119,11 +118,27 @@ def __init__(
             self.conversation = conversation
             self.interruptible_event_factory = interruptible_event_factory

+        def kill_tasks_when_human_is_talking(self):
+            has_task = self.conversation.synthesis_results_worker.current_task is not None
+            if has_task and not self.conversation.synthesis_results_worker.current_task.done():
+                self.conversation.logger.info("###### Synthesis task is running, attempting to cancel it ######")
+                self.conversation.synthesis_results_worker.current_task.cancel()
+                self.conversation.logger.info("###### Synthesis task is running, has been canceled ######")
+            has_agent_task = self.conversation.agent_responses_worker.current_task
+            if has_agent_task and not self.conversation.agent_responses_worker.current_task.done():
+                self.conversation.logger.info("&&&&&&& Agent Response task is running, attempting to cancel it &&&&&&&")
+                self.conversation.agent_responses_worker.current_task.cancel()
+                self.conversation.logger.info("&&&&&&& Agent Response task is running, has been canceled &&&&&&&")
+
         async def process(self, transcription: Transcription):
             self.conversation.mark_last_action_timestamp()
             if transcription.message.strip() == "":
                 self.conversation.logger.info("Ignoring empty transcription")
                 return
+            elif transcription.message.strip() == "<INTERRUPT>" and transcription.confidence == 1.0:
+                # self.kill_tasks_when_human_is_talking()
+                self.conversation.broadcast_interrupt()
+
             if transcription.is_final:
                 self.conversation.logger.debug(
                     "Got transcription: {}, confidence: {}".format(
@@ -156,6 +171,10 @@ async def process(self, transcription: Transcription):
                     )
                 )
                 self.output_queue.put_nowait(event)
+                self.conversation.mark_last_final_transcript_from_human()
+            # else:
+            #     self.kill_tasks_when_human_is_talking()
+            #     self.conversation.broadcast_interrupt()

     class FillerAudioWorker(InterruptibleAgentResponseWorker):
         """
@@ -365,6 +384,7 @@ async def process(
                             await self.conversation.terminate()
                     except asyncio.TimeoutError:
                         pass
+                self.conversation.mark_last_agent_response()
             except asyncio.CancelledError:
                 pass

@@ -508,6 +528,12 @@ async def start(self, mark_ready: Optional[Callable[[], Awaitable[None]]] = None
         self.check_for_idle_task = asyncio.create_task(self.check_for_idle())
         if len(self.events_manager.subscriptions) > 0:
            self.events_task = asyncio.create_task(self.events_manager.start())
+        if (
+            self.synthesizer.get_synthesizer_config().reengage_timeout and
+            (self.synthesizer.get_synthesizer_config().reengage_options and
+            len(self.synthesizer.get_synthesizer_config().reengage_options) > 0)
+        ):
+            self.human_prompt_checker = asyncio.create_task(self.check_if_human_should_be_prompted())

     async def send_initial_message(self, initial_message: BaseMessage):
         # TODO: configure if initial message is interruptible
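The gate above calls get_synthesizer_config() three times inside one condition. An equivalent, arguably clearer drop-in form for start() (a sketch only, not part of the commit):

    # Sketch: the same gating logic with the config read once.
    synthesizer_config = self.synthesizer.get_synthesizer_config()
    if synthesizer_config.reengage_timeout and synthesizer_config.reengage_options:
        # A non-empty list is already truthy, so an explicit len(...) > 0 check is redundant.
        self.human_prompt_checker = asyncio.create_task(
            self.check_if_human_should_be_prompted()
        )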
@@ -571,6 +597,13 @@ def warmup_synthesizer(self):
     def mark_last_action_timestamp(self):
         self.last_action_timestamp = time.time()

+    def mark_last_final_transcript_from_human(self):
+        self.last_final_transcript_from_human = time.time()
+
+    def mark_last_agent_response(self):
+        self.last_agent_response = time.time()
+
+
     def broadcast_interrupt(self):
         """Stops all inflight events and cancels all workers that are sending output
@@ -588,13 +621,32 @@ def broadcast_interrupt(self):
                 break
         self.agent.cancel_current_task()
         self.agent_responses_worker.cancel_current_task()
+
+        # Clearing these queues cuts the delay between the end of the human's interruption and the bot speaking by ~1 second (from ~4.5s to ~3.5s).
+        self.clear_queue(self.agent.output_queue, 'agent.output_queue')
+        self.clear_queue(self.agent_responses_worker.output_queue, 'agent_responses_worker.output_queue')
+        self.clear_queue(self.agent_responses_worker.input_queue, 'agent_responses_worker.input_queue')
+        self.clear_queue(self.synthesis_results_worker.output_queue, 'synthesis_results_worker.output_queue')
+        self.clear_queue(self.synthesis_results_worker.input_queue, 'synthesis_results_worker.input_queue')
+        if hasattr(self.output_device, 'queue'):
+            self.clear_queue(self.output_device.queue, 'output_device.queue')
+
         return num_interrupts > 0

     def is_interrupt(self, transcription: Transcription):
         return transcription.confidence >= (
             self.transcriber.get_transcriber_config().min_interrupt_confidence or 0
         )

+    @staticmethod
+    def clear_queue(q: asyncio.Queue, queue_name: str):
+        while not q.empty():
+            logging.info(f'Clearing queue {queue_name} with size {q.qsize()}')
+            try:
+                q.get_nowait()
+            except asyncio.QueueEmpty:
+                continue
+
     async def send_speech_to_output(
         self,
         message: str,
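clear_queue drains synchronously with get_nowait(), which is reasonable here because the workers share one event loop (asyncio.Queue is not thread-safe). A self-contained sketch of the same drain pattern, with illustrative names only:

    import asyncio


    def drain(q: asyncio.Queue) -> int:
        """Remove every item currently in the queue; return how many were dropped."""
        dropped = 0
        while True:
            try:
                q.get_nowait()
                dropped += 1
            except asyncio.QueueEmpty:
                return dropped


    async def main() -> None:
        q: asyncio.Queue = asyncio.Queue()
        for i in range(3):
            q.put_nowait(i)
        print(drain(q), q.empty())  # 3 True


    asyncio.run(main())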
@@ -726,3 +778,41 @@ async def terminate(self):

     def is_active(self):
         return self.active
+
+    async def check_if_human_should_be_prompted(self):
+        self.logger.debug("starting should prompt user task")
+        self.last_agent_response = None
+        self.last_final_transcript_from_human = None
+        reengage_timeout = self.synthesizer.get_synthesizer_config().reengage_timeout
+        reengage_options = self.synthesizer.get_synthesizer_config().reengage_options
+        while self.active:
+            if self.last_agent_response and self.last_final_transcript_from_human:
+                last_human_touchpoint = time.time() - self.last_final_transcript_from_human
+                last_agent_touchpoint = time.time() - self.last_agent_response
+                if last_human_touchpoint >= reengage_timeout and last_agent_touchpoint >= reengage_timeout:
+                    reengage_statement = random.choice(reengage_options)
+                    self.logger.debug(f"Prompting user with {reengage_statement}: no interaction has happened in {reengage_timeout} seconds")
+                    self.chunk_size = (
+                        get_chunk_size_per_second(
+                            self.synthesizer.get_synthesizer_config().audio_encoding,
+                            self.synthesizer.get_synthesizer_config().sampling_rate,
+                        )
+                        * TEXT_TO_SPEECH_CHUNK_SIZE_SECONDS
+                    )
+                    message = BaseMessage(text=reengage_statement)
+                    synthesis_result = await self.synthesizer.create_speech(
+                        message,
+                        self.chunk_size,
+                        bot_sentiment=self.bot_sentiment,
+                    )
+                    self.agent_responses_worker.produce_interruptible_agent_response_event_nonblocking(
+                        (message, synthesis_result),
+                        is_interruptible=True,
+                        agent_response_tracker=asyncio.Event(),
+                    )
+                    self.mark_last_agent_response()
+                await asyncio.sleep(1)
+            else:
+                await asyncio.sleep(1)
+        self.logger.debug("stopped check if human should be prompted")
+
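The watcher added above is a once-per-second poll over two timestamps: it only re-engages when both the human and the agent have been silent for reengage_timeout, and it resets the agent timestamp after speaking so it does not re-prompt every second. A stripped-down, self-contained sketch of the same pattern; all names here are illustrative, not the conversation's real API:

    import asyncio
    import random
    import time
    from typing import Awaitable, Callable, List, Optional


    class ReengageWatcher:
        """Once-per-second poll that re-engages after mutual silence (illustrative only)."""

        def __init__(self, timeout: float, options: List[str]):
            self.timeout = timeout
            self.options = options
            self.last_human: Optional[float] = None  # set when a final human transcript arrives
            self.last_agent: Optional[float] = None  # set when the agent finishes speaking
            self.active = True

        async def run(self, say: Callable[[str], Awaitable[None]]) -> None:
            while self.active:
                now = time.time()
                if (
                    self.last_human is not None
                    and self.last_agent is not None
                    and now - self.last_human >= self.timeout
                    and now - self.last_agent >= self.timeout
                ):
                    # Both sides have been quiet long enough: speak a canned re-engagement line.
                    await say(random.choice(self.options))
                    self.last_agent = time.time()  # prevents re-prompting every second
                await asyncio.sleep(1)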
