diff --git a/CHANGELOG.md b/CHANGELOG.md
index fc13f8fab0..75dcc1b314 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -9,6 +9,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
 
 ### Added
 
+- Added support for the Sarvam Speech-to-Text service (`SarvamSTTService`) with streaming WebSocket
+  support for `saarika` (STT) and `saaras` (STT-translate) models.
+
 - Added a new `DeepgramHttpTTSService`, which delivers a meaningful reduction in latency when
   compared to the `DeepgramTTSService`.
 
diff --git a/examples/foundational/07z-interruptible-sarvam-http.py b/examples/foundational/07z-interruptible-sarvam-http.py
index 29851d2547..0821167ef8 100644
--- a/examples/foundational/07z-interruptible-sarvam-http.py
+++ b/examples/foundational/07z-interruptible-sarvam-http.py
@@ -22,8 +22,8 @@
 from pipecat.processors.aggregators.llm_response_universal import LLMContextAggregatorPair
 from pipecat.runner.types import RunnerArguments
 from pipecat.runner.utils import create_transport
-from pipecat.services.deepgram.stt import DeepgramSTTService
 from pipecat.services.openai.llm import OpenAILLMService
+from pipecat.services.sarvam.stt import SarvamSTTService
 from pipecat.services.sarvam.tts import SarvamHttpTTSService
 from pipecat.transcriptions.language import Language
 from pipecat.transports.base_transport import BaseTransport, TransportParams
@@ -63,7 +63,10 @@ async def run_bot(transport: BaseTransport, runner_args: RunnerArguments):
 
     # Create an HTTP session
     async with aiohttp.ClientSession() as session:
-        stt = DeepgramSTTService(api_key=os.getenv("DEEPGRAM_API_KEY"))
+        stt = SarvamSTTService(
+            api_key=os.getenv("SARVAM_API_KEY"),
+            model="saarika:v2.5",
+        )
 
         tts = SarvamHttpTTSService(
             api_key=os.getenv("SARVAM_API_KEY"),
diff --git a/examples/foundational/07z-interruptible-sarvam.py b/examples/foundational/07z-interruptible-sarvam.py
index 44e0b7844b..3123df31d9 100644
--- a/examples/foundational/07z-interruptible-sarvam.py
+++ b/examples/foundational/07z-interruptible-sarvam.py
@@ -24,8 +24,8 @@
 from pipecat.processors.aggregators.llm_response_universal import LLMContextAggregatorPair
 from pipecat.runner.types import RunnerArguments
 from pipecat.runner.utils import create_transport
-from pipecat.services.deepgram.stt import DeepgramSTTService
 from pipecat.services.openai.llm import OpenAILLMService
+from pipecat.services.sarvam.stt import SarvamSTTService
 from pipecat.services.sarvam.tts import SarvamTTSService
 from pipecat.transports.base_transport import BaseTransport, TransportParams
 from pipecat.transports.daily.transport import DailyParams
@@ -62,7 +62,10 @@ async def run_bot(transport: BaseTransport, runner_args: RunnerArguments):
 
     logger.info(f"Starting bot")
 
-    stt = DeepgramSTTService(api_key=os.getenv("DEEPGRAM_API_KEY"))
+    stt = SarvamSTTService(
+        api_key=os.getenv("SARVAM_API_KEY"),
+        model="saarika:v2.5",
+    )
 
     tts = SarvamTTSService(
         api_key=os.getenv("SARVAM_API_KEY"),
diff --git a/pyproject.toml b/pyproject.toml
index 5760361b93..00cdb8c570 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -93,7 +93,7 @@ rime = [ "pipecat-ai[websockets-base]" ]
 riva = [ "nvidia-riva-client~=2.21.1" ]
 runner = [ "python-dotenv>=1.0.0,<2.0.0", "uvicorn>=0.32.0,<1.0.0", "fastapi>=0.115.6,<0.117.0", "pipecat-ai-small-webrtc-prebuilt>=1.0.0"]
 sambanova = []
-sarvam = [ "pipecat-ai[websockets-base]" ]
+sarvam = [ "sarvamai==0.1.21", "pipecat-ai[websockets-base]" ]
 sentry = [ "sentry-sdk>=2.28.0,<3" ]
 local-smart-turn = [ "coremltools>=8.0", "transformers", "torch>=2.5.0,<3", "torchaudio>=2.5.0,<3" ]
 local-smart-turn-v3 = [ "transformers", "onnxruntime>=1.20.1,<2" ]
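
Editor's note (not part of the patch): the two examples above configure the `saarika` STT model with an explicit model tag. For the `saaras` STT-translate path, the service added below rejects a `language` parameter and instead accepts an optional `prompt`. A minimal sketch, assuming a hypothetical "saaras:v2.5" model tag and that SARVAM_API_KEY is set in the environment:

    import os

    from pipecat.services.sarvam.stt import SarvamSTTService

    # saaras models auto-detect the spoken language, so no `language` is passed;
    # the prompt (saaras-only) nudges the translation style.
    stt = SarvamSTTService(
        api_key=os.getenv("SARVAM_API_KEY"),
        model="saaras:v2.5",  # hypothetical tag, for illustration only
        params=SarvamSTTService.InputParams(
            prompt="Prefer formal English in the translated output.",
        ),
    )

Passing `language` with a saaras model (or `prompt` with a saarika model) raises ValueError in the constructor below.
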
"torchaudio>=2.5.0,<3" ] local-smart-turn-v3 = [ "transformers", "onnxruntime>=1.20.1,<2" ] diff --git a/src/pipecat/services/sarvam/stt.py b/src/pipecat/services/sarvam/stt.py new file mode 100644 index 0000000000..27816163d5 --- /dev/null +++ b/src/pipecat/services/sarvam/stt.py @@ -0,0 +1,468 @@ +"""Sarvam AI Speech-to-Text service implementation. + +This module provides a streaming Speech-to-Text service using Sarvam AI's WebSocket-based +API. It supports real-time transcription with Voice Activity Detection (VAD) and +can handle multiple audio formats for Indian language speech recognition. +""" + +import base64 +from typing import Optional + +from loguru import logger +from pydantic import BaseModel + +from pipecat.frames.frames import ( + CancelFrame, + EndFrame, + ErrorFrame, + StartFrame, + TranscriptionFrame, +) +from pipecat.services.stt_service import STTService +from pipecat.transcriptions.language import Language +from pipecat.utils.time import time_now_iso8601 +from pipecat.utils.tracing.service_decorators import traced_stt + +try: + from sarvamai import AsyncSarvamAI + from sarvamai.core.api_error import ApiError + from sarvamai.core.events import EventType +except ModuleNotFoundError as e: + logger.error(f"Exception: {e}") + logger.error("In order to use Sarvam, you need to `pip install pipecat-ai[sarvam]`.") + raise Exception(f"Missing module: {e}") + + +def language_to_sarvam_language(language: Language) -> str: + """Convert a Language enum to Sarvam's language code format. + + Args: + language: The Language enum value to convert. + + Returns: + The Sarvam language code string. + """ + # Mapping of pipecat Language enum to Sarvam language codes + SARVAM_LANGUAGES = { + Language.BN_IN: "bn-IN", + Language.GU_IN: "gu-IN", + Language.HI_IN: "hi-IN", + Language.KN_IN: "kn-IN", + Language.ML_IN: "ml-IN", + Language.MR_IN: "mr-IN", + Language.TA_IN: "ta-IN", + Language.TE_IN: "te-IN", + Language.PA_IN: "pa-IN", + Language.OR_IN: "od-IN", + Language.EN_IN: "en-IN", + Language.AS_IN: "as-IN", + } + + return SARVAM_LANGUAGES.get( + language, "unknown" + ) # Default to unknown (Sarvam models auto-detect the language) + + +class SarvamSTTService(STTService): + """Sarvam speech-to-text service. + + Provides real-time speech recognition using Sarvam's WebSocket API. + """ + + class InputParams(BaseModel): + """Configuration parameters for Sarvam STT service. + + Parameters: + language: Target language for transcription. Defaults to None (required for saarika models). + prompt: Optional prompt to guide translation style/context for STT-Translate models. + Only applicable to saaras (STT-Translate) models. Defaults to None. + vad_signals: Enable VAD signals in response. Defaults to True. + high_vad_sensitivity: Enable high VAD (Voice Activity Detection) sensitivity. Defaults to False. + """ + + language: Optional[Language] = None + prompt: Optional[str] = None + vad_signals: bool = True + high_vad_sensitivity: bool = False + + def __init__( + self, + *, + api_key: str, + model: str = "saarika:v2.5", + sample_rate: Optional[int] = None, + input_audio_codec: str = "wav", + params: Optional[InputParams] = None, + **kwargs, + ): + """Initialize the Sarvam STT service. + + Args: + api_key: Sarvam API key for authentication. + model: Sarvam model to use for transcription. + sample_rate: Audio sample rate. Defaults to 16000 if not specified. + input_audio_codec: Audio codec/format of the input file. Defaults to "wav". + params: Configuration parameters for Sarvam STT service. 
+
+    def __init__(
+        self,
+        *,
+        api_key: str,
+        model: str = "saarika:v2.5",
+        sample_rate: Optional[int] = None,
+        input_audio_codec: str = "wav",
+        params: Optional[InputParams] = None,
+        **kwargs,
+    ):
+        """Initialize the Sarvam STT service.
+
+        Args:
+            api_key: Sarvam API key for authentication.
+            model: Sarvam model to use for transcription.
+            sample_rate: Audio sample rate. Defaults to 16000 if not specified.
+            input_audio_codec: Audio codec/format of the input file. Defaults to "wav".
+            params: Configuration parameters for the Sarvam STT service.
+            **kwargs: Additional arguments passed to the parent STTService.
+        """
+        params = params or SarvamSTTService.InputParams()
+
+        # Validate that saaras models don't accept a language parameter
+        if "saaras" in model.lower():
+            if params.language is not None:
+                raise ValueError(
+                    f"Model '{model}' does not accept a language parameter. "
+                    "STT-Translate models auto-detect language."
+                )
+
+        # Validate that saarika models don't accept a prompt parameter
+        if "saarika" in model.lower():
+            if params.prompt is not None:
+                raise ValueError(
+                    f"Model '{model}' does not accept a prompt parameter. "
+                    "Prompts are only supported for STT-Translate models."
+                )
+
+        super().__init__(sample_rate=sample_rate, **kwargs)
+
+        self.set_model_name(model)
+        self._api_key = api_key
+        self._language_code = params.language
+        # For saarika models, default to "unknown" if language is not provided
+        if params.language:
+            self._language_string = language_to_sarvam_language(params.language)
+        elif "saarika" in model.lower():
+            self._language_string = "unknown"
+        else:
+            self._language_string = None
+        self._prompt = params.prompt
+
+        # Store connection parameters
+        self._vad_signals = params.vad_signals
+        self._high_vad_sensitivity = params.high_vad_sensitivity
+        self._input_audio_codec = input_audio_codec
+
+        # Initialize the Sarvam SDK client
+        self._sarvam_client = AsyncSarvamAI(api_subscription_key=api_key)
+        self._websocket_context = None
+        self._socket_client = None
+        self._receive_task = None
+
+    def language_to_service_language(self, language: Language) -> str:
+        """Convert a pipecat Language enum to Sarvam's language code.
+
+        Args:
+            language: The Language enum value to convert.
+
+        Returns:
+            The Sarvam language code string.
+        """
+        return language_to_sarvam_language(language)
+
+    def can_generate_metrics(self) -> bool:
+        """Check if this service can generate processing metrics.
+
+        Returns:
+            True, as the Sarvam service supports metrics generation.
+        """
+        return True
+
+    async def set_language(self, language: Language):
+        """Set the recognition language and reconnect.
+
+        Args:
+            language: The language to use for speech recognition.
+        """
+        # saaras models do not accept a language parameter
+        if "saaras" in self.model_name.lower():
+            raise ValueError(
+                f"Model '{self.model_name}' (saaras) does not accept a language parameter. "
+                "saaras models auto-detect language."
+            )
+
+        logger.info(f"Switching STT language to: [{language}]")
+        self._language_code = language
+        self._language_string = language_to_sarvam_language(language)
+        await self._disconnect()
+        await self._connect()
+
+    async def set_prompt(self, prompt: Optional[str]):
+        """Set the translation prompt and reconnect.
+
+        Args:
+            prompt: Prompt text to guide translation style/context.
+                Pass None to clear/disable the prompt.
+                Only applicable to STT-Translate models, not STT models.
+        """
+        # saarika models do not accept a prompt parameter
+        if "saarika" in self.model_name.lower():
+            if prompt is not None:
+                raise ValueError(
+                    f"Model '{self.model_name}' does not accept a prompt parameter. "
+                    "Prompts are only supported for STT-Translate models."
+                )
+            # If prompt is None and it's saarika, silently return (no-op)
+            return
+
+        logger.info("Updating STT-Translate prompt.")
+        self._prompt = prompt
+        await self._disconnect()
+        await self._connect()
+ """ + await super().start(frame) + await self._connect() + + async def stop(self, frame: EndFrame): + """Stop the Sarvam STT service. + + Args: + frame: The end frame. + """ + await super().stop(frame) + await self._disconnect() + + async def cancel(self, frame: CancelFrame): + """Cancel the Sarvam STT service. + + Args: + frame: The cancel frame. + """ + await super().cancel(frame) + await self._disconnect() + + async def run_stt(self, audio: bytes): + """Send audio data to Sarvam for transcription. + + Args: + audio: Raw audio bytes to transcribe. + + Yields: + Frame: None (transcription results come via WebSocket callbacks). + """ + if not self._socket_client: + logger.warning("WebSocket not connected, cannot process audio") + yield None + return + + try: + # Convert audio bytes to base64 for Sarvam API + audio_base64 = base64.b64encode(audio).decode("utf-8") + + # Convert input_audio_codec to encoding format (prepend "audio/" if needed) + encoding = ( + self._input_audio_codec + if self._input_audio_codec.startswith("audio/") + else f"audio/{self._input_audio_codec}" + ) + + # Build method arguments + method_kwargs = { + "audio": audio_base64, + "encoding": encoding, + "sample_rate": self.sample_rate, + } + + # Use appropriate method based on service type + if "saarika" in self.model_name.lower(): + # STT service + await self._socket_client.transcribe(**method_kwargs) + else: + # STT-Translate service - auto-detects input language and returns translated text + await self._socket_client.translate(**method_kwargs) + + except Exception as e: + logger.error(f"Error sending audio to Sarvam: {e}") + await self.push_error(ErrorFrame(f"Failed to send audio: {e}")) + + yield None + + async def _connect(self): + """Connect to Sarvam WebSocket API using the SDK.""" + logger.debug("Connecting to Sarvam") + + try: + # Convert boolean parameters to string for SDK + vad_signals_str = "true" if self._vad_signals else "false" + high_vad_sensitivity_str = "true" if self._high_vad_sensitivity else "false" + + # Build common connection parameters + connect_kwargs = { + "model": self.model_name, + "vad_signals": vad_signals_str, + "high_vad_sensitivity": high_vad_sensitivity_str, + "input_audio_codec": self._input_audio_codec, + "sample_rate": str(self.sample_rate), + } + + # Choose the appropriate service based on model + if "saarika" in self.model_name.lower(): + # STT service - requires language_code + connect_kwargs["language_code"] = self._language_string + self._websocket_context = self._sarvam_client.speech_to_text_streaming.connect( + **connect_kwargs + ) + else: + # STT-Translate service - auto-detects input language and returns translated text + self._websocket_context = ( + self._sarvam_client.speech_to_text_translate_streaming.connect(**connect_kwargs) + ) + + # Enter the async context manager + self._socket_client = await self._websocket_context.__aenter__() + + # Set prompt if provided (only for STT-Translate models, after connection) + if self._prompt is not None and "saaras" in self.model_name.lower(): + await self._socket_client.set_prompt(self._prompt) + + # Register event handler for incoming messages + def _message_handler(message): + """Wrapper to handle async response handler.""" + # Use Pipecat's built-in task management + self.create_task(self._handle_message(message)) + + self._socket_client.on(EventType.MESSAGE, _message_handler) + + # Start receive task using Pipecat's task management + self._receive_task = self.create_task(self._receive_task_handler()) + + 
logger.info("Connected to Sarvam successfully") + + except ApiError as e: + logger.error(f"Sarvam API error: {e}") + await self.push_error(ErrorFrame(f"Sarvam API error: {e}")) + except Exception as e: + logger.error(f"Failed to connect to Sarvam: {e}") + self._socket_client = None + self._websocket_context = None + await self.push_error(ErrorFrame(f"Failed to connect to Sarvam: {e}")) + + async def _disconnect(self): + """Disconnect from Sarvam WebSocket API using SDK.""" + if self._receive_task: + await self.cancel_task(self._receive_task) + self._receive_task = None + + if self._websocket_context and self._socket_client: + try: + # Exit the async context manager + await self._websocket_context.__aexit__(None, None, None) + except Exception as e: + logger.error(f"Error closing WebSocket connection: {e}") + finally: + logger.debug("Disconnected from Sarvam WebSocket") + self._socket_client = None + self._websocket_context = None + + async def _receive_task_handler(self): + """Handle incoming messages from Sarvam WebSocket. + + This task wraps the SDK's start_listening() method which processes + messages via the registered event handler callback. + """ + if not self._socket_client: + return + + try: + # Start listening for messages from the Sarvam SDK + # Messages will be handled via the _message_handler callback + await self._socket_client.start_listening() + except Exception as e: + logger.error(f"Error in Sarvam receive task: {e}") + await self.push_error(ErrorFrame(f"Sarvam receive task error: {e}")) + + async def _handle_message(self, message): + """Handle incoming WebSocket message from Sarvam SDK. + + Processes transcription data and VAD events from the Sarvam service. + + Args: + message: The parsed response object from Sarvam WebSocket. + """ + logger.debug(f"Received response: {message}") + + try: + if message.type == "events": + # VAD event + signal = message.data.signal_type + timestamp = message.data.occured_at + logger.debug(f"VAD Signal: {signal}, Occurred at: {timestamp}") + + if signal == "START_SPEECH": + await self.start_metrics() + logger.debug("User started speaking") + await self._call_event_handler("on_speech_started") + + elif message.type == "data": + await self.stop_ttfb_metrics() + transcript = message.data.transcript + language_code = message.data.language_code + # Prefer language from message (auto-detected for translate models). Fallback to configured. + if language_code: + language = self._map_language_code_to_enum(language_code) + elif self._language_string: + language = self._map_language_code_to_enum(self._language_string) + else: + language = Language.HI_IN + + # Emit utterance end event + await self._call_event_handler("on_utterance_end") + + if transcript and transcript.strip(): + # Record tracing for this transcription event + await self._handle_transcription(transcript, True, language) + await self.push_frame( + TranscriptionFrame( + transcript, + self._user_id, + time_now_iso8601(), + language, + result=(message.dict() if hasattr(message, "dict") else str(message)), + ) + ) + + await self.stop_processing_metrics() + + except Exception as e: + logger.error(f"Error handling Sarvam message: {e}") + await self.push_error(ErrorFrame(f"Failed to handle message: {e}")) + await self.stop_all_metrics() + + @traced_stt + async def _handle_transcription( + self, transcript: str, is_final: bool, language: Optional[Language] = None + ): + """Handle a transcription result with tracing. + + This method is decorated with @traced_stt for observability. 
+ """ + pass + + def _map_language_code_to_enum(self, language_code: str) -> Language: + """Map Sarvam language code to pipecat Language enum.""" + mapping = { + "bn-IN": Language.BN_IN, + "gu-IN": Language.GU_IN, + "hi-IN": Language.HI_IN, + "kn-IN": Language.KN_IN, + "ml-IN": Language.ML_IN, + "mr-IN": Language.MR_IN, + "ta-IN": Language.TA_IN, + "te-IN": Language.TE_IN, + "pa-IN": Language.PA_IN, + "od-IN": Language.OR_IN, + "en-US": Language.EN_US, + "en-IN": Language.EN_IN, + "as-IN": Language.AS_IN, + } + return mapping.get(language_code, Language.HI_IN) + + async def start_metrics(self): + """Start TTFB and processing metrics collection.""" + await self.start_ttfb_metrics() + await self.start_processing_metrics() diff --git a/uv.lock b/uv.lock index 129491a615..f4c9f467af 100644 --- a/uv.lock +++ b/uv.lock @@ -4550,6 +4550,7 @@ runner = [ { name = "uvicorn" }, ] sarvam = [ + { name = "sarvamai" }, { name = "websockets" }, ] sentry = [ @@ -4704,6 +4705,7 @@ requires-dist = [ { name = "python-dotenv", marker = "extra == 'runner'", specifier = ">=1.0.0,<2.0.0" }, { name = "pyvips", extras = ["binary"], marker = "extra == 'moondream'", specifier = "~=3.0.0" }, { name = "resampy", specifier = "~=0.4.3" }, + { name = "sarvamai", marker = "extra == 'sarvam'", specifier = "==0.1.21" }, { name = "sentry-sdk", marker = "extra == 'sentry'", specifier = ">=2.28.0,<3" }, { name = "simli-ai", marker = "extra == 'simli'", specifier = "~=0.1.10" }, { name = "soundfile", marker = "extra == 'soundfile'", specifier = "~=0.13.0" }, @@ -6212,6 +6214,22 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/2c/c3/c0be1135726618dc1e28d181b8c442403d8dbb9e273fd791de2d4384bcdd/safetensors-0.6.2-cp38-abi3-win_amd64.whl", hash = "sha256:c7b214870df923cbc1593c3faee16bec59ea462758699bd3fee399d00aac072c", size = 320192, upload-time = "2025-08-08T13:13:59.467Z" }, ] +[[package]] +name = "sarvamai" +version = "0.1.21" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "httpx" }, + { name = "pydantic" }, + { name = "pydantic-core" }, + { name = "typing-extensions" }, + { name = "websockets" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/e9/08/e5efcb30818ed220b818319255c22fd91e379489ebaa93efd6f444fb4987/sarvamai-0.1.21.tar.gz", hash = "sha256:865065635b2b99d40f5519308832954015627938e06a6333b5f62ae9c36278bb", size = 87386, upload-time = "2025-10-07T07:37:47.085Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/2e/4e/b9933f72681b7aed91b86913337dd3981fad97027881fbc66c3c5eb03568/sarvamai-0.1.21-py3-none-any.whl", hash = "sha256:daa4e5d16635fe434f5f270cee416849249285369141d77132a17f0bf670f120", size = 175204, upload-time = "2025-10-07T07:37:46.024Z" }, +] + [[package]] name = "scipy" version = "1.15.3"