
Commit 8461208

Refactor DeepgramTTSService to use HTTP directly

1 parent 5db0871 · commit 8461208

3 files changed: +108 additions, -77 deletions

CHANGELOG.md

Lines changed: 7 additions & 0 deletions
@@ -151,6 +151,13 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
 
 ### Changed
 
+- Refactor `DeepgramTTSService` to use a direct HTTP connection. This results
+  in a significant TTFB reduction when compared to using the Deepgram Python
+  SDK.
+
+  Note: an `aiohttp_session` is now required when initializing
+  `DeepgramTTSService`.
+
 - `DailyTransport` triggers `on_error` event if transcription can't be started
   or stopped.
 
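The net effect for callers: the service must now be constructed while a shared `aiohttp.ClientSession` is open. A minimal sketch of the new initialization, assuming only the constructor signature from the src/pipecat/services/deepgram/tts.py diff below; the `main()` wrapper is illustrative scaffolding, not part of this commit:

import os

import aiohttp

from pipecat.services.deepgram.tts import DeepgramTTSService


async def main():
    # The service no longer owns its HTTP client; callers supply a shared
    # aiohttp session, which enables connection pooling across requests.
    async with aiohttp.ClientSession() as session:
        tts = DeepgramTTSService(
            api_key=os.getenv("DEEPGRAM_API_KEY"),
            voice="aura-2-helena-en",  # default voice per the new signature
            aiohttp_session=session,   # now required
        )
        ...  # build and run the pipeline while the session is still open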

examples/foundational/07c-interruptible-deepgram.py

Lines changed: 54 additions & 48 deletions
@@ -7,6 +7,7 @@
 
 import os
 
+import aiohttp
 from dotenv import load_dotenv
 from loguru import logger
 
@@ -60,58 +61,63 @@
 async def run_bot(transport: BaseTransport, runner_args: RunnerArguments):
     logger.info(f"Starting bot")
 
-    stt = DeepgramSTTService(api_key=os.getenv("DEEPGRAM_API_KEY"))
+    async with aiohttp.ClientSession() as session:
+        stt = DeepgramSTTService(api_key=os.getenv("DEEPGRAM_API_KEY"))
 
-    tts = DeepgramTTSService(api_key=os.getenv("DEEPGRAM_API_KEY"), voice="aura-2-andromeda-en")
+        tts = DeepgramTTSService(
+            api_key=os.getenv("DEEPGRAM_API_KEY"),
+            voice="aura-2-andromeda-en",
+            aiohttp_session=session,
+        )
 
-    llm = OpenAILLMService(api_key=os.getenv("OPENAI_API_KEY"))
+        llm = OpenAILLMService(api_key=os.getenv("OPENAI_API_KEY"))
 
-    messages = [
-        {
-            "role": "system",
-            "content": "You are a helpful LLM in a WebRTC call. Your goal is to demonstrate your capabilities in a succinct way. Your output will be converted to audio so don't include special characters in your answers. Respond to what the user said in a creative and helpful way.",
-        },
-    ]
-
-    context = LLMContext(messages)
-    context_aggregator = LLMContextAggregatorPair(context)
-
-    pipeline = Pipeline(
-        [
-            transport.input(),  # Transport user input
-            stt,  # STT
-            context_aggregator.user(),  # User responses
-            llm,  # LLM
-            tts,  # TTS
-            transport.output(),  # Transport bot output
-            context_aggregator.assistant(),  # Assistant spoken responses
+        messages = [
+            {
+                "role": "system",
+                "content": "You are a helpful LLM in a WebRTC call. Your goal is to demonstrate your capabilities in a succinct way. Your output will be converted to audio so don't include special characters in your answers. Respond to what the user said in a creative and helpful way.",
+            },
         ]
-    )
-
-    task = PipelineTask(
-        pipeline,
-        params=PipelineParams(
-            enable_metrics=True,
-            enable_usage_metrics=True,
-        ),
-        idle_timeout_secs=runner_args.pipeline_idle_timeout_secs,
-    )
-
-    @transport.event_handler("on_client_connected")
-    async def on_client_connected(transport, client):
-        logger.info(f"Client connected")
-        # Kick off the conversation.
-        messages.append({"role": "system", "content": "Please introduce yourself to the user."})
-        await task.queue_frames([LLMRunFrame()])
-
-    @transport.event_handler("on_client_disconnected")
-    async def on_client_disconnected(transport, client):
-        logger.info(f"Client disconnected")
-        await task.cancel()
-
-    runner = PipelineRunner(handle_sigint=runner_args.handle_sigint)
-
-    await runner.run(task)
+
+        context = LLMContext(messages)
+        context_aggregator = LLMContextAggregatorPair(context)
+
+        pipeline = Pipeline(
+            [
+                transport.input(),  # Transport user input
+                stt,  # STT
+                context_aggregator.user(),  # User responses
+                llm,  # LLM
+                tts,  # TTS
+                transport.output(),  # Transport bot output
+                context_aggregator.assistant(),  # Assistant spoken responses
+            ]
+        )
+
+        task = PipelineTask(
+            pipeline,
+            params=PipelineParams(
+                enable_metrics=True,
+                enable_usage_metrics=True,
+            ),
+            idle_timeout_secs=runner_args.pipeline_idle_timeout_secs,
+        )
+
+        @transport.event_handler("on_client_connected")
+        async def on_client_connected(transport, client):
+            logger.info(f"Client connected")
+            # Kick off the conversation.
+            messages.append({"role": "system", "content": "Please introduce yourself to the user."})
+            await task.queue_frames([LLMRunFrame()])
+
+        @transport.event_handler("on_client_disconnected")
+        async def on_client_disconnected(transport, client):
+            logger.info(f"Client disconnected")
+            await task.cancel()
+
+        runner = PipelineRunner(handle_sigint=runner_args.handle_sigint)
+
+        await runner.run(task)
 
 
 async def bot(runner_args: RunnerArguments):

src/pipecat/services/deepgram/tts.py

Lines changed: 47 additions & 29 deletions
@@ -12,6 +12,7 @@
 
 from typing import AsyncGenerator, Optional
 
+import aiohttp
 from loguru import logger
 
 from pipecat.frames.frames import (
@@ -24,13 +25,6 @@
 from pipecat.services.tts_service import TTSService
 from pipecat.utils.tracing.service_decorators import traced_tts
 
-try:
-    from deepgram import DeepgramClient, DeepgramClientOptions, SpeakOptions
-except ModuleNotFoundError as e:
-    logger.error(f"Exception: {e}")
-    logger.error("In order to use Deepgram, you need to `pip install pipecat-ai[deepgram]`.")
-    raise Exception(f"Missing module: {e}")
-
 
 class DeepgramTTSService(TTSService):
     """Deepgram text-to-speech service.
@@ -45,7 +39,8 @@ def __init__(
         *,
         api_key: str,
         voice: str = "aura-2-helena-en",
-        base_url: str = "",
+        aiohttp_session: aiohttp.ClientSession,
+        base_url: str = "https://api.deepgram.com",
         sample_rate: Optional[int] = None,
         encoding: str = "linear16",
         **kwargs,
@@ -55,21 +50,22 @@
         Args:
             api_key: Deepgram API key for authentication.
             voice: Voice model to use for synthesis. Defaults to "aura-2-helena-en".
-            base_url: Custom base URL for Deepgram API. Uses default if empty.
+            aiohttp_session: Shared aiohttp session for HTTP requests with connection pooling.
+            base_url: Custom base URL for Deepgram API. Defaults to "https://api.deepgram.com".
             sample_rate: Audio sample rate in Hz. If None, uses service default.
             encoding: Audio encoding format. Defaults to "linear16".
             **kwargs: Additional arguments passed to parent TTSService class.
         """
         super().__init__(sample_rate=sample_rate, **kwargs)
 
+        self._api_key = api_key
+        self._session = aiohttp_session
+        self._base_url = base_url
         self._settings = {
             "encoding": encoding,
         }
         self.set_voice(voice)
 
-        client_options = DeepgramClientOptions(url=base_url)
-        self._deepgram_client = DeepgramClient(api_key, config=client_options)
-
     def can_generate_metrics(self) -> bool:
         """Check if the service can generate metrics.
 
@@ -90,27 +86,49 @@ async def run_tts(self, text: str) -> AsyncGenerator[Frame, None]:
         """
         logger.debug(f"{self}: Generating TTS [{text}]")
 
-        options = SpeakOptions(
-            model=self._voice_id,
-            encoding=self._settings["encoding"],
-            sample_rate=self.sample_rate,
-            container="none",
-        )
+        # Build URL with parameters
+        url = f"{self._base_url}/v1/speak"
 
-        try:
-            await self.start_ttfb_metrics()
+        headers = {"Authorization": f"Token {self._api_key}", "Content-Type": "application/json"}
 
-            response = await self._deepgram_client.speak.asyncrest.v("1").stream_raw(
-                {"text": text}, options
-            )
+        params = {
+            "model": self._voice_id,
+            "encoding": self._settings["encoding"],
+            "sample_rate": self.sample_rate,
+            "container": "none",
+        }
 
-            await self.start_tts_usage_metrics(text)
-            yield TTSStartedFrame()
+        payload = {
+            "text": text,
+        }
+
+        try:
+            await self.start_ttfb_metrics()
 
-            async for data in response.aiter_bytes():
-                await self.stop_ttfb_metrics()
-                if data:
-                    yield TTSAudioRawFrame(audio=data, sample_rate=self.sample_rate, num_channels=1)
+            async with self._session.post(
+                url, headers=headers, json=payload, params=params
+            ) as response:
+                if response.status != 200:
+                    error_text = await response.text()
+                    raise Exception(f"HTTP {response.status}: {error_text}")
+
+                await self.start_tts_usage_metrics(text)
+                yield TTSStartedFrame()
+
+                CHUNK_SIZE = self.chunk_size
+
+                first_chunk = True
+                async for chunk in response.content.iter_chunked(CHUNK_SIZE):
+                    if first_chunk:
+                        await self.stop_ttfb_metrics()
+                        first_chunk = False
+
+                    if chunk:
+                        yield TTSAudioRawFrame(
+                            audio=chunk,
+                            sample_rate=self.sample_rate,
+                            num_channels=1,
+                        )
 
             yield TTSStoppedFrame()
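Outside the service class, the refactored `run_tts()` boils down to one streaming POST. This standalone sketch mirrors the endpoint, headers, query parameters, and chunked-read pattern from the diff above; the model name, sample rate, and 8192-byte chunk size are illustrative stand-ins for the values the service takes from its settings:

import asyncio
import os

import aiohttp


async def speak(text: str) -> bytes:
    """Fetch raw PCM audio for `text` from Deepgram's speak endpoint."""
    url = "https://api.deepgram.com/v1/speak"
    headers = {
        "Authorization": f"Token {os.getenv('DEEPGRAM_API_KEY')}",
        "Content-Type": "application/json",
    }
    params = {
        "model": "aura-2-helena-en",
        "encoding": "linear16",
        "sample_rate": 24000,
        "container": "none",
    }
    audio = bytearray()
    async with aiohttp.ClientSession() as session:
        async with session.post(url, headers=headers, json={"text": text}, params=params) as resp:
            if resp.status != 200:
                raise Exception(f"HTTP {resp.status}: {await resp.text()}")
            # Audio bytes stream back as they are synthesized; the first
            # chunk is what the service's TTFB metric measures.
            async for chunk in resp.content.iter_chunked(8192):
                audio.extend(chunk)
    return bytes(audio)


if __name__ == "__main__":
    asyncio.run(speak("Hello there!"))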
