diff --git a/src/pipecat/metrics/metrics.py b/src/pipecat/metrics/metrics.py index 4884d14f68..f4810cd0ca 100644 --- a/src/pipecat/metrics/metrics.py +++ b/src/pipecat/metrics/metrics.py @@ -87,6 +87,96 @@ class TTSUsageMetricsData(MetricsData): value: int +class STTUsage(BaseModel): + """Audio usage statistics for STT operations. + + Parameters: + audio_duration_seconds: Duration of audio processed in seconds. + requests: Number of STT requests made. + + # Content metrics (similar to TTS character counting) + word_count: Number of words transcribed. + character_count: Number of characters transcribed. + + # Performance metrics + processing_time_seconds: Total processing time in seconds. + real_time_factor: Processing time / audio duration (< 1.0 is faster than real-time). + words_per_second: Words transcribed per second (throughput). + time_to_first_transcript: Time from audio start to first transcription (like TTFT in LLMs). + time_to_final_transcript: Time from audio start to final transcription. + + # Quality metrics + average_confidence: Average confidence score (0.0 to 1.0). + word_error_rate: Word Error Rate percentage (if ground truth available). + proper_noun_accuracy: Proper noun transcription accuracy percentage. + + # Audio metadata + sample_rate: Audio sample rate in Hz (e.g., 16000). + channels: Number of audio channels (1 for mono, 2 for stereo). + encoding: Audio encoding format (e.g., "LINEAR16", "OPUS"). + + # Cost tracking + cost_per_word: Cost per word transcribed. + estimated_cost: Estimated total cost for this transcription. + + Calculation Examples: + # Words Per Second (WPS) + words_per_second = word_count / processing_time_seconds + + # Real-Time Factor (RTF) + real_time_factor = processing_time_seconds / audio_duration_seconds + # RTF < 1.0 means faster than real-time (good!) + # RTF = 0.5 means processing took half the audio duration + + # Word Error Rate (WER) - requires ground truth + wer = (substitutions + insertions + deletions) / total_reference_words * 100 + + # Cost Per Word + cost_per_word = estimated_cost / word_count + + # Time to First Transcript (TTFT) + ttft = timestamp_first_transcript - audio_start_time + """ + + audio_duration_seconds: float + requests: int = 1 + + # Content metrics + word_count: Optional[int] = None + character_count: Optional[int] = None + + # Performance metrics + processing_time_seconds: Optional[float] = None + real_time_factor: Optional[float] = None # processing_time / audio_duration + words_per_second: Optional[float] = None # word_count / processing_time + time_to_first_transcript: Optional[float] = None # TTFT in seconds + time_to_final_transcript: Optional[float] = None # Total latency + + # Quality metrics + average_confidence: Optional[float] = None # 0.0 to 1.0 + word_error_rate: Optional[float] = None # WER percentage + proper_noun_accuracy: Optional[float] = None # Proper noun accuracy percentage + + # Audio metadata + sample_rate: Optional[int] = None + channels: Optional[int] = None + encoding: Optional[str] = None + + # Cost tracking + cost_per_word: Optional[float] = None + estimated_cost: Optional[float] = None + + +class STTUsageMetricsData(MetricsData): + """Speech-to-Text usage metrics data. + + Parameters: + value: Audio duration and request statistics for the STT operation. + """ + + value: STTUsage + + class SmartTurnMetricsData(MetricsData): """Metrics data for smart turn predictions. 
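For context, a minimal sketch (not part of the patch) of how the derived fields of the new STTUsage model relate, following the formulas in its docstring. The transcript, timings, and the $0.006/min rate are made-up illustrative values; the only API assumed beyond this patch is the MetricsData base fields (processor, model) defined elsewhere in this module.

from pipecat.metrics.metrics import STTUsage, STTUsageMetricsData

transcript = "hello world this is a test"
audio_duration = 5.5     # seconds of audio processed
processing_time = 2.3    # seconds spent transcribing

word_count = len(transcript.split())  # 6
usage = STTUsage(
    audio_duration_seconds=audio_duration,
    word_count=word_count,
    character_count=len(transcript),                    # 26
    processing_time_seconds=processing_time,
    real_time_factor=processing_time / audio_duration,  # ~0.42 (< 1.0 means faster than real-time)
    words_per_second=word_count / processing_time,      # ~2.6 WPS
    estimated_cost=(audio_duration / 60.0) * 0.006,     # ~$0.00055 at an assumed $0.006/min
)
data = STTUsageMetricsData(processor="ExampleSTTService", model="example-model", value=usage)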
diff --git a/src/pipecat/processors/frame_processor.py b/src/pipecat/processors/frame_processor.py index 1ca3333b58..edc67808ed 100644 --- a/src/pipecat/processors/frame_processor.py +++ b/src/pipecat/processors/frame_processor.py @@ -434,6 +434,49 @@ async def start_tts_usage_metrics(self, text: str): if frame: await self.push_frame(frame) + async def start_stt_usage_metrics( + self, + audio_duration: float, + transcript: Optional[str] = None, + processing_time: Optional[float] = None, + confidence: Optional[float] = None, + sample_rate: Optional[int] = None, + channels: Optional[int] = None, + encoding: Optional[str] = None, + cost_per_minute: Optional[float] = None, + ttft: Optional[float] = None, + ground_truth: Optional[str] = None, + ): + """Start enhanced STT usage metrics collection with automatic calculations. + + Args: + audio_duration: Duration of audio processed in seconds (required). + transcript: The transcribed text (used to calculate word/character counts). + processing_time: Time taken to process the audio in seconds. + confidence: Average confidence score from 0.0 to 1.0. + sample_rate: Audio sample rate in Hz (e.g., 16000). + channels: Number of audio channels (1 for mono, 2 for stereo). + encoding: Audio encoding format (e.g., "LINEAR16", "OPUS"). + cost_per_minute: Cost per minute of audio (for cost estimation). + ttft: Time to first transcript in seconds. + ground_truth: Reference transcript for WER calculation (optional, for testing). + """ + if self.can_generate_metrics() and self.usage_metrics_enabled: + frame = await self._metrics.start_stt_usage_metrics( + audio_duration=audio_duration, + transcript=transcript, + processing_time=processing_time, + confidence=confidence, + sample_rate=sample_rate, + channels=channels, + encoding=encoding, + cost_per_minute=cost_per_minute, + ttft=ttft, + ground_truth=ground_truth, + ) + if frame: + await self.push_frame(frame) + async def stop_all_metrics(self): """Stop all active metrics collection.""" await self.stop_ttfb_metrics() diff --git a/src/pipecat/processors/frameworks/rtvi.py b/src/pipecat/processors/frameworks/rtvi.py index 08d127ef6b..8fedb8555b 100644 --- a/src/pipecat/processors/frameworks/rtvi.py +++ b/src/pipecat/processors/frameworks/rtvi.py @@ -65,6 +65,7 @@ from pipecat.metrics.metrics import ( LLMUsageMetricsData, ProcessingMetricsData, + STTUsageMetricsData, TTFBMetricsData, TTSUsageMetricsData, ) @@ -1197,6 +1198,10 @@ async def _handle_metrics(self, frame: MetricsFrame): if "characters" not in metrics: metrics["characters"] = [] metrics["characters"].append(d.model_dump(exclude_none=True)) + elif isinstance(d, STTUsageMetricsData): + if "stt" not in metrics: + metrics["stt"] = [] + metrics["stt"].append(d.model_dump(exclude_none=True)) message = RTVIMetricsMessage(data=metrics) await self.send_rtvi_message(message) diff --git a/src/pipecat/processors/metrics/frame_processor_metrics.py b/src/pipecat/processors/metrics/frame_processor_metrics.py index fd93241ed5..ca5f777b04 100644 --- a/src/pipecat/processors/metrics/frame_processor_metrics.py +++ b/src/pipecat/processors/metrics/frame_processor_metrics.py @@ -17,6 +17,8 @@ LLMUsageMetricsData, MetricsData, ProcessingMetricsData, + STTUsage, + STTUsageMetricsData, TTFBMetricsData, TTSUsageMetricsData, ) @@ -44,6 +46,7 @@ def __init__(self): self._start_ttfb_time = 0 self._start_processing_time = 0 self._last_ttfb_time = 0 + self._last_processing_time = 0 self._should_report_ttfb = True async def setup(self, task_manager: BaseTaskManager): @@ -83,6 +86,22 @@ 
def ttfb(self) -> Optional[float]: return None + @property + def processing_time(self) -> Optional[float]: + """Get the current processing time value in seconds. + + Returns: + The processing time value in seconds, or None if not measured. + """ + if self._last_processing_time > 0: + return self._last_processing_time + + # If processing is in progress, calculate current value + if self._start_processing_time > 0: + return time.time() - self._start_processing_time + + return None + def _processor_name(self): """Get the processor name from core metrics data.""" return self._core_metrics_data.processor @@ -149,6 +168,7 @@ async def stop_processing_metrics(self): return None value = time.time() - self._start_processing_time + self._last_processing_time = value logger.debug(f"{self._processor_name()} processing time: {value}") processing = ProcessingMetricsData( processor=self._processor_name(), value=value, model=self._model_name() @@ -190,3 +210,167 @@ async def start_tts_usage_metrics(self, text: str): ) logger.debug(f"{self._processor_name()} usage characters: {characters.value}") return MetricsFrame(data=[characters]) + + async def start_stt_usage_metrics( + self, + audio_duration: float, + transcript: Optional[str] = None, + processing_time: Optional[float] = None, + confidence: Optional[float] = None, + sample_rate: Optional[int] = None, + channels: Optional[int] = None, + encoding: Optional[str] = None, + cost_per_minute: Optional[float] = None, + ttft: Optional[float] = None, + ground_truth: Optional[str] = None, + ): + """Record enhanced STT usage metrics with automatic calculations. + + Args: + audio_duration: Duration of audio processed in seconds (required). + transcript: The transcribed text (used to calculate word/character counts). + processing_time: Time taken to process the audio in seconds. + confidence: Average confidence score from 0.0 to 1.0. + sample_rate: Audio sample rate in Hz (e.g., 16000). + channels: Number of audio channels (1 for mono, 2 for stereo). + encoding: Audio encoding format (e.g., "LINEAR16", "OPUS"). + cost_per_minute: Cost per minute of audio (for cost estimation). + ttft: Time to first transcript in seconds. + ground_truth: Reference transcript for WER calculation (optional, for testing). + + Returns: + MetricsFrame containing comprehensive STT usage data. + + Example: + # Basic usage (backward compatible) + await self.start_stt_usage_metrics(audio_duration=5.5) + + # Enhanced usage with all metrics + await self.start_stt_usage_metrics( + audio_duration=5.5, + transcript="Hello world this is a test", + processing_time=2.3, + confidence=0.95, + sample_rate=16000, + cost_per_minute=0.006, + ttft=0.5 + ) + """ + # Calculate content metrics from transcript + word_count = None + character_count = None + if transcript: + word_count = len(transcript.split()) + character_count = len(transcript) + + # Calculate performance metrics + real_time_factor = None + words_per_second = None + if processing_time and audio_duration > 0: + # RTF = processing_time / audio_duration + # RTF < 1.0 means faster than real-time (good!) 
+ real_time_factor = processing_time / audio_duration + + if word_count and processing_time and processing_time > 0: + # WPS = total_words / processing_time + words_per_second = word_count / processing_time + + # Calculate cost metrics + estimated_cost = None + cost_per_word = None + if cost_per_minute and audio_duration > 0: + # Convert audio duration to minutes and calculate cost + audio_minutes = audio_duration / 60.0 + estimated_cost = audio_minutes * cost_per_minute + + if word_count and word_count > 0: + cost_per_word = estimated_cost / word_count + + # Calculate WER if ground truth is provided + word_error_rate = None + if ground_truth and transcript: + word_error_rate = self._calculate_wer(transcript, ground_truth) + + # Build usage metrics + usage = STTUsage( + audio_duration_seconds=audio_duration, + requests=1, + word_count=word_count, + character_count=character_count, + processing_time_seconds=processing_time, + real_time_factor=real_time_factor, + words_per_second=words_per_second, + time_to_first_transcript=ttft, + time_to_final_transcript=processing_time, + average_confidence=confidence, + word_error_rate=word_error_rate, + sample_rate=sample_rate, + channels=channels, + encoding=encoding, + cost_per_word=cost_per_word, + estimated_cost=estimated_cost, + ) + + value = STTUsageMetricsData( + processor=self._processor_name(), model=self._model_name(), value=usage + ) + + # Build comprehensive log message + log_parts = [f"{self._processor_name()} STT usage:"] + log_parts.append(f"{audio_duration:.3f}s audio") + if word_count: + log_parts.append(f"{word_count} words") + if words_per_second: + log_parts.append(f"{words_per_second:.1f} WPS") + if real_time_factor: + log_parts.append(f"RTF={real_time_factor:.2f}") + if estimated_cost: + log_parts.append(f"${estimated_cost:.4f}") + + logger.debug(", ".join(log_parts)) + + return MetricsFrame(data=[value]) + + def _calculate_wer(self, hypothesis: str, reference: str) -> float: + """Calculate Word Error Rate (WER) between hypothesis and reference. + + Args: + hypothesis: The transcribed text. + reference: The ground truth text. + + Returns: + WER as a percentage (0-100). 
+ + Formula: + WER = (Substitutions + Insertions + Deletions) / Total_Reference_Words * 100 + """ + # Split into words + hyp_words = hypothesis.lower().split() + ref_words = reference.lower().split() + + # Create matrix for dynamic programming + d = [[0] * (len(ref_words) + 1) for _ in range(len(hyp_words) + 1)] + + # Initialize first row and column + for i in range(len(hyp_words) + 1): + d[i][0] = i + for j in range(len(ref_words) + 1): + d[0][j] = j + + # Calculate edit distance + for i in range(1, len(hyp_words) + 1): + for j in range(1, len(ref_words) + 1): + if hyp_words[i - 1] == ref_words[j - 1]: + d[i][j] = d[i - 1][j - 1] + else: + substitution = d[i - 1][j - 1] + 1 + insertion = d[i][j - 1] + 1 + deletion = d[i - 1][j] + 1 + d[i][j] = min(substitution, insertion, deletion) + + # Calculate WER percentage + if len(ref_words) == 0: + return 0.0 if len(hyp_words) == 0 else 100.0 + + wer = (d[len(hyp_words)][len(ref_words)] / len(ref_words)) * 100 + return round(wer, 2) diff --git a/src/pipecat/services/assemblyai/stt.py b/src/pipecat/services/assemblyai/stt.py index b3f20800c1..361d9dc69d 100644 --- a/src/pipecat/services/assemblyai/stt.py +++ b/src/pipecat/services/assemblyai/stt.py @@ -312,6 +312,35 @@ async def _handle_transcription(self, message: TurnMessage): if message.end_of_turn and ( not self._connection_params.formatted_finals or message.turn_is_formatted ): + # Calculate audio duration from word timings + audio_duration = None + confidence = None + if message.words and len(message.words) > 0: + # Audio duration is from first word start to last word end (in milliseconds) + audio_duration = (message.words[-1].end - message.words[0].start) / 1000.0 + # Calculate average confidence from all words + if all(word.confidence for word in message.words): + confidence = sum(word.confidence for word in message.words) / len(message.words) + + # Get performance metrics + ttft = self._metrics.ttfb if hasattr(self, "_metrics") else None + processing_time = self._metrics.processing_time if hasattr(self, "_metrics") else None + + # Calculate STT usage metrics + if audio_duration: + await self.start_stt_usage_metrics( + audio_duration=audio_duration, + transcript=message.transcript, + processing_time=processing_time, + ttft=ttft, + confidence=confidence, + sample_rate=self.sample_rate, + channels=1, + encoding=self._connection_params.encoding.upper(), + cost_per_minute=self._cost_per_minute, + ground_truth=self._ground_truth, + ) + await self.push_frame( TranscriptionFrame( message.transcript, diff --git a/src/pipecat/services/aws/stt.py b/src/pipecat/services/aws/stt.py index b019fc0585..89133108a5 100644 --- a/src/pipecat/services/aws/stt.py +++ b/src/pipecat/services/aws/stt.py @@ -499,6 +499,34 @@ async def _receive_loop(self): if transcript: await self.stop_ttfb_metrics() if is_final: + # Calculate audio duration from start and end time if available + audio_duration = None + if "StartTime" in result and "EndTime" in result: + audio_duration = result["EndTime"] - result["StartTime"] + + # Get performance metrics + ttft = self._metrics.ttfb if hasattr(self, "_metrics") else None + processing_time = ( + self._metrics.processing_time + if hasattr(self, "_metrics") + else None + ) + + # AWS doesn't provide confidence in streaming results + # Calculate STT usage metrics + if audio_duration: + await self.start_stt_usage_metrics( + audio_duration=audio_duration, + transcript=transcript, + processing_time=processing_time, + ttft=ttft, + sample_rate=self._settings["sample_rate"], + 
channels=self._settings["number_of_channels"], + encoding=self._settings["media_encoding"].upper(), + cost_per_minute=self._cost_per_minute, + ground_truth=self._ground_truth, + ) + await self.push_frame( TranscriptionFrame( transcript, diff --git a/src/pipecat/services/azure/stt.py b/src/pipecat/services/azure/stt.py index 586a94e442..79df526e43 100644 --- a/src/pipecat/services/azure/stt.py +++ b/src/pipecat/services/azure/stt.py @@ -188,6 +188,36 @@ async def _handle_transcription( def _on_handle_recognized(self, event): if event.result.reason == ResultReason.RecognizedSpeech and len(event.result.text) > 0: language = getattr(event.result, "language", None) or self._settings.get("language") + + # Calculate audio duration from event if available + audio_duration = getattr(event.result, "duration", None) + if audio_duration: + audio_duration = audio_duration.total_seconds() + + # Get performance metrics + ttft = self._metrics.ttfb if hasattr(self, "_metrics") else None + processing_time = self._metrics.processing_time if hasattr(self, "_metrics") else None + + # Get confidence if available (Azure provides this in some result types) + confidence = getattr(event.result, "confidence", None) + + # Calculate STT usage metrics if we have audio duration + if audio_duration: + asyncio.run_coroutine_threadsafe( + self.start_stt_usage_metrics( + audio_duration=audio_duration, + transcript=event.result.text, + processing_time=processing_time, + ttft=ttft, + confidence=confidence, + sample_rate=self.sample_rate, + channels=1, + encoding="LINEAR16", + ground_truth=self._ground_truth, + ), + self.get_event_loop(), + ) + frame = TranscriptionFrame( event.result.text, self._user_id, diff --git a/src/pipecat/services/cartesia/stt.py b/src/pipecat/services/cartesia/stt.py index b4e232c4ac..f87e65b403 100644 --- a/src/pipecat/services/cartesia/stt.py +++ b/src/pipecat/services/cartesia/stt.py @@ -341,6 +341,30 @@ async def _on_transcript(self, data): if len(transcript) > 0: await self.stop_ttfb_metrics() if is_final: + # Get performance metrics + ttft = self._metrics.ttfb if hasattr(self, "_metrics") else None + processing_time = ( + self._metrics.processing_time if hasattr(self, "_metrics") else None + ) + + # Cartesia may provide audio duration in the response + audio_duration = data.get("audio_duration") + + # Cartesia doesn't provide confidence scores + # Calculate STT usage metrics if we have duration + if audio_duration: + await self.start_stt_usage_metrics( + audio_duration=audio_duration, + transcript=transcript, + processing_time=processing_time, + ttft=ttft, + sample_rate=self.sample_rate, + channels=1, + encoding=self._settings.encoding.upper(), + cost_per_minute=self._cost_per_minute, + ground_truth=self._ground_truth, + ) + await self.push_frame( TranscriptionFrame( transcript, diff --git a/src/pipecat/services/deepgram/flux/stt.py b/src/pipecat/services/deepgram/flux/stt.py index f0b1a5baa9..c38c1f7570 100644 --- a/src/pipecat/services/deepgram/flux/stt.py +++ b/src/pipecat/services/deepgram/flux/stt.py @@ -558,6 +558,28 @@ async def _handle_end_of_turn(self, transcript: str, data: Dict[str, Any]): """ logger.debug("User stopped speaking") + # Get audio duration from data if available + audio_duration = data.get("duration") + + # Get performance metrics + ttft = self._metrics.ttfb if hasattr(self, "_metrics") else None + processing_time = self._metrics.processing_time if hasattr(self, "_metrics") else None + + # Deepgram Flux doesn't provide confidence in EndOfTurn events + # Calculate STT 
usage metrics + if audio_duration: + await self.start_stt_usage_metrics( + audio_duration=audio_duration, + transcript=transcript, + processing_time=processing_time, + ttft=ttft, + sample_rate=self.sample_rate, + channels=1, + encoding="LINEAR16", + cost_per_minute=self._cost_per_minute, + ground_truth=self._ground_truth, + ) + await self.push_frame( TranscriptionFrame( transcript, diff --git a/src/pipecat/services/deepgram/stt.py b/src/pipecat/services/deepgram/stt.py index fb5f670298..aa67a55aa5 100644 --- a/src/pipecat/services/deepgram/stt.py +++ b/src/pipecat/services/deepgram/stt.py @@ -290,6 +290,31 @@ async def _on_message(self, *args, **kwargs): if len(transcript) > 0: await self.stop_ttfb_metrics() if is_final: + # Report comprehensive STT usage metrics + # Deepgram includes duration and confidence in the result metadata + if hasattr(result, "duration") and result.duration: + audio_duration = result.duration + + # Get performance metrics + ttft = self._metrics.ttfb + processing_time = self._metrics.processing_time + + # Get confidence score if available + confidence = None + if result.channel.alternatives[0].confidence: + confidence = result.channel.alternatives[0].confidence + + # Use configured cost_per_minute (if provided) + await self.start_stt_usage_metrics( + audio_duration=audio_duration, + transcript=transcript, + processing_time=processing_time, + ttft=ttft, + confidence=confidence, + cost_per_minute=self._cost_per_minute, + ground_truth=self._ground_truth, + ) + await self.push_frame( TranscriptionFrame( transcript, diff --git a/src/pipecat/services/elevenlabs/stt.py b/src/pipecat/services/elevenlabs/stt.py index 291bad4142..88e8c70634 100644 --- a/src/pipecat/services/elevenlabs/stt.py +++ b/src/pipecat/services/elevenlabs/stt.py @@ -323,6 +323,34 @@ async def run_stt(self, audio: bytes) -> AsyncGenerator[Frame, None]: # Use the language_code returned by the API detected_language = result.get("language_code", "eng") + # Calculate audio duration (audio is WAV format, calculate from bytes) + # WAV header is 44 bytes, then raw PCM data + # For 16-bit PCM: duration = (bytes - 44) / (sample_rate * 2) + audio_duration = None + if len(audio) > 44: + audio_duration = (len(audio) - 44) / (self.sample_rate * 2) + + # Get performance metrics + ttft = self._metrics.ttfb if hasattr(self, "_metrics") else None + processing_time = ( + self._metrics.processing_time if hasattr(self, "_metrics") else None + ) + + # ElevenLabs doesn't provide confidence scores + # Calculate STT usage metrics + if audio_duration: + await self.start_stt_usage_metrics( + audio_duration=audio_duration, + transcript=text, + processing_time=processing_time, + ttft=ttft, + sample_rate=self.sample_rate, + channels=1, + encoding="LINEAR16", + cost_per_minute=self._cost_per_minute, + ground_truth=self._ground_truth, + ) + await self._handle_transcription(text, True, detected_language) logger.debug(f"Transcription: [{text}]") diff --git a/src/pipecat/services/fal/stt.py b/src/pipecat/services/fal/stt.py index 202c03c1bb..4ce8cf9bd5 100644 --- a/src/pipecat/services/fal/stt.py +++ b/src/pipecat/services/fal/stt.py @@ -287,6 +287,34 @@ async def run_stt(self, audio: bytes) -> AsyncGenerator[Frame, None]: if response and "text" in response: text = response["text"].strip() if text: # Only yield non-empty text + # Calculate audio duration (audio is WAV format, calculate from bytes) + # WAV header is 44 bytes, then raw PCM data + # For 16-bit PCM: duration = (bytes - 44) / (sample_rate * 2) + audio_duration = None + if 
len(audio) > 44: + audio_duration = (len(audio) - 44) / (self.sample_rate * 2) + + # Get performance metrics + ttft = self._metrics.ttfb if hasattr(self, "_metrics") else None + processing_time = ( + self._metrics.processing_time if hasattr(self, "_metrics") else None + ) + + # Fal doesn't provide confidence scores + # Calculate STT usage metrics + if audio_duration: + await self.start_stt_usage_metrics( + audio_duration=audio_duration, + transcript=text, + processing_time=processing_time, + ttft=ttft, + sample_rate=self.sample_rate, + channels=1, + encoding="LINEAR16", + cost_per_minute=self._cost_per_minute, + ground_truth=self._ground_truth, + ) + await self._handle_transcription(text, True, self._settings["language"]) logger.debug(f"Transcription: [{text}]") yield TranscriptionFrame( diff --git a/src/pipecat/services/gladia/stt.py b/src/pipecat/services/gladia/stt.py index f9ff91b4a6..1951f3f378 100644 --- a/src/pipecat/services/gladia/stt.py +++ b/src/pipecat/services/gladia/stt.py @@ -591,6 +591,37 @@ async def _receive_task_handler(self): transcript = utterance["text"] is_final = content["data"]["is_final"] if is_final: + # Get audio duration from utterance if available + audio_duration = None + if "time_begin" in utterance and "time_end" in utterance: + audio_duration = ( + utterance["time_end"] - utterance["time_begin"] + ) / 1000.0 + + # Get performance metrics + ttft = self._metrics.ttfb if hasattr(self, "_metrics") else None + processing_time = ( + self._metrics.processing_time if hasattr(self, "_metrics") else None + ) + + # Get confidence if available from utterance + confidence = utterance.get("confidence") + + # Calculate STT usage metrics + if audio_duration: + await self.start_stt_usage_metrics( + audio_duration=audio_duration, + transcript=transcript, + processing_time=processing_time, + ttft=ttft, + confidence=confidence, + sample_rate=self.sample_rate, + channels=1, + encoding="LINEAR16", + cost_per_minute=self._cost_per_minute, + ground_truth=self._ground_truth, + ) + await self.push_frame( TranscriptionFrame( transcript, diff --git a/src/pipecat/services/google/stt.py b/src/pipecat/services/google/stt.py index b9e56f55bf..ad0da349b7 100644 --- a/src/pipecat/services/google/stt.py +++ b/src/pipecat/services/google/stt.py @@ -859,6 +859,38 @@ async def _process_responses(self, streaming_recognize): if result.is_final: self._last_transcript_was_final = True + + # Get performance metrics + ttft = self._metrics.ttfb if hasattr(self, "_metrics") else None + processing_time = ( + self._metrics.processing_time if hasattr(self, "_metrics") else None + ) + + # Get confidence if available (Google provides word-level confidence) + confidence = None + if result.alternatives[0].confidence: + confidence = result.alternatives[0].confidence + + # Calculate audio duration from result end time if available + audio_duration = None + if hasattr(result, "result_end_offset") and result.result_end_offset: + audio_duration = result.result_end_offset.total_seconds() + + # Calculate STT usage metrics + if audio_duration: + await self.start_stt_usage_metrics( + audio_duration=audio_duration, + transcript=transcript, + processing_time=processing_time, + ttft=ttft, + confidence=confidence, + sample_rate=self.sample_rate, + channels=1, + encoding="LINEAR16", + cost_per_minute=self._cost_per_minute, + ground_truth=self._ground_truth, + ) + await self.push_frame( TranscriptionFrame( transcript, diff --git a/src/pipecat/services/riva/stt.py b/src/pipecat/services/riva/stt.py index 
eddd3da9e6..45dec157a4 100644 --- a/src/pipecat/services/riva/stt.py +++ b/src/pipecat/services/riva/stt.py @@ -310,6 +310,41 @@ async def _handle_response(self, response): if transcript and len(transcript) > 0: await self.stop_ttfb_metrics() if result.is_final: + # Calculate audio duration if available + audio_duration = None + if hasattr(result, "audio_processed") and result.audio_processed: + # audio_processed is in seconds + audio_duration = result.audio_processed + + # Get performance metrics + ttft = self._metrics.ttfb if hasattr(self, "_metrics") else None + processing_time = ( + self._metrics.processing_time if hasattr(self, "_metrics") else None + ) + + # Get confidence from alternative if available + confidence = None + if ( + hasattr(result.alternatives[0], "confidence") + and result.alternatives[0].confidence + ): + confidence = result.alternatives[0].confidence + + # Calculate STT usage metrics + if audio_duration: + await self.start_stt_usage_metrics( + audio_duration=audio_duration, + transcript=transcript, + processing_time=processing_time, + ttft=ttft, + confidence=confidence, + sample_rate=self.sample_rate, + channels=1, + encoding="LINEAR16", + cost_per_minute=self._cost_per_minute, + ground_truth=self._ground_truth, + ) + await self.stop_processing_metrics() await self.push_frame( TranscriptionFrame( diff --git a/src/pipecat/services/soniox/stt.py b/src/pipecat/services/soniox/stt.py index 1cf2d51948..8fd448b959 100644 --- a/src/pipecat/services/soniox/stt.py +++ b/src/pipecat/services/soniox/stt.py @@ -309,6 +309,38 @@ async def _receive_task_handler(self): async def send_endpoint_transcript(): if self._final_transcription_buffer: text = "".join(map(lambda token: token["text"], self._final_transcription_buffer)) + + # Calculate audio duration from final token timings if available + audio_duration = None + if self._final_transcription_buffer and len(self._final_transcription_buffer) > 0: + first_token = self._final_transcription_buffer[0] + last_token = self._final_transcription_buffer[-1] + if "start_ms" in first_token and "duration_ms" in last_token: + start_ms = first_token["start_ms"] + end_ms = last_token["start_ms"] + last_token["duration_ms"] + audio_duration = (end_ms - start_ms) / 1000.0 + + # Get performance metrics + ttft = self._metrics.ttfb if hasattr(self, "_metrics") else None + processing_time = ( + self._metrics.processing_time if hasattr(self, "_metrics") else None + ) + + # Soniox doesn't provide confidence in the basic token format + # Calculate STT usage metrics + if audio_duration: + await self.start_stt_usage_metrics( + audio_duration=audio_duration, + transcript=text, + processing_time=processing_time, + ttft=ttft, + sample_rate=self.sample_rate, + channels=1, + encoding="LINEAR16", + cost_per_minute=self._cost_per_minute, + ground_truth=self._ground_truth, + ) + await self.push_frame( TranscriptionFrame( text=text, diff --git a/src/pipecat/services/speechmatics/stt.py b/src/pipecat/services/speechmatics/stt.py index 901edb0e8a..ea306f8c3c 100644 --- a/src/pipecat/services/speechmatics/stt.py +++ b/src/pipecat/services/speechmatics/stt.py @@ -761,6 +761,45 @@ async def _send_frames(self, finalized: bool = False) -> None: # If final, then re-parse into TranscriptionFrame if finalized: + # Calculate audio duration from speech fragments + audio_duration = None + if self._speech_fragments and len(self._speech_fragments) > 0: + # Get first and last fragment timing + first_fragment = self._speech_fragments[0] + last_fragment = 
self._speech_fragments[-1] + audio_duration = last_fragment.end_time - first_fragment.start_time + + # Get performance metrics + ttft = self._metrics.ttfb if hasattr(self, "_metrics") else None + processing_time = self._metrics.processing_time if hasattr(self, "_metrics") else None + + # Calculate average confidence from fragments + confidence = None + if self._speech_fragments and all( + hasattr(f, "confidence") and f.confidence for f in self._speech_fragments + ): + confidence = sum(f.confidence for f in self._speech_fragments) / len( + self._speech_fragments + ) + + # Get transcript text for metrics + transcript_text = " ".join(frame.text for frame in speech_frames if frame.text) + + # Calculate STT usage metrics + if audio_duration and transcript_text: + await self.start_stt_usage_metrics( + audio_duration=audio_duration, + transcript=transcript_text, + processing_time=processing_time, + ttft=ttft, + confidence=confidence, + sample_rate=self.sample_rate, + channels=1, + encoding="LINEAR16", + cost_per_minute=self._cost_per_minute, + ground_truth=self._ground_truth, + ) + # Reset the speech fragments self._speech_fragments.clear() diff --git a/src/pipecat/services/stt_service.py b/src/pipecat/services/stt_service.py index 6fb96c571a..05107d57d3 100644 --- a/src/pipecat/services/stt_service.py +++ b/src/pipecat/services/stt_service.py @@ -61,6 +61,7 @@ def __init__( audio_passthrough=True, # STT input sample rate sample_rate: Optional[int] = None, + cost_per_minute: Optional[float] = None, **kwargs, ): """Initialize the STT service. @@ -70,6 +71,9 @@ def __init__( Defaults to True. sample_rate: The sample rate for audio input. If None, will be determined from the start frame. + cost_per_minute: Cost per minute of audio for usage metrics. If None, cost metrics + will not be calculated. Refer to your STT provider's pricing docs for the + specific cost of your model. **kwargs: Additional arguments passed to the parent AIService. """ super().__init__(**kwargs) @@ -80,6 +84,8 @@ def __init__( self._tracing_enabled: bool = False self._muted: bool = False self._user_id: str = "" + self._ground_truth: Optional[str] = None + self._cost_per_minute = cost_per_minute self._register_event_handler("on_connected") self._register_event_handler("on_disconnected") @@ -119,6 +125,16 @@ async def set_language(self, language: Language): """ pass + def set_ground_truth(self, ground_truth: Optional[str]): + """Set ground truth text for testing/evaluation purposes. + + This allows WER and other accuracy metrics to be calculated automatically. + + Args: + ground_truth: The expected correct transcription for accuracy comparison. + """ + self._ground_truth = ground_truth + @abstractmethod async def run_stt(self, audio: bytes) -> AsyncGenerator[Frame, None]: """Run speech-to-text on the provided audio data. 
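For context, a minimal sketch (not part of the patch) of the two new knobs added to STTService: the cost_per_minute constructor argument and set_ground_truth(). It assumes the concrete service forwards extra keyword arguments to STTService.__init__, which this patch already relies on for _cost_per_minute; the rate shown is illustrative, not a quoted price.

import os
from pipecat.services.deepgram.stt import DeepgramSTTService

stt = DeepgramSTTService(
    api_key=os.getenv("DEEPGRAM_API_KEY"),
    cost_per_minute=0.006,  # illustrative; use your provider's actual per-minute rate
)

# Optional, intended for tests/evaluation: supply a reference transcript so the
# emitted STTUsageMetricsData includes a word_error_rate.
stt.set_ground_truth("You might also want to consider setting up a page on Facebook")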
diff --git a/src/pipecat/services/whisper/base_stt.py b/src/pipecat/services/whisper/base_stt.py index 3d9151e379..77b19abf37 100644 --- a/src/pipecat/services/whisper/base_stt.py +++ b/src/pipecat/services/whisper/base_stt.py @@ -216,6 +216,24 @@ async def run_stt(self, audio: bytes) -> AsyncGenerator[Frame, None]: text = response.text.strip() + # Calculate audio duration for usage metrics + # For 16-bit PCM audio: duration = bytes / (sample_rate * 2) + audio_duration = len(audio) / (self.sample_rate * 2) + + # Get performance metrics from metrics object + ttft = self._metrics.ttfb + processing_time = self._metrics.processing_time + + # Pass comprehensive metrics including processing time, TTFT, and ground truth (if available) + await self.start_stt_usage_metrics( + audio_duration=audio_duration, + transcript=text, + processing_time=processing_time, + ttft=ttft, + sample_rate=self.sample_rate, + ground_truth=self._ground_truth, + ) + if text: await self._handle_transcription(text, True, self._language) logger.debug(f"Transcription: [{text}]") diff --git a/tests/test_audio/sample.wav b/tests/test_audio/sample.wav new file mode 100644 index 0000000000..90104f1786 Binary files /dev/null and b/tests/test_audio/sample.wav differ diff --git a/tests/test_stt_metrics_providers.py b/tests/test_stt_metrics_providers.py new file mode 100644 index 0000000000..8db6669c1a --- /dev/null +++ b/tests/test_stt_metrics_providers.py @@ -0,0 +1,592 @@ +# +# Copyright (c) 2024-2025 Daily +# +# SPDX-License-Identifier: BSD 2-Clause License +# + +""" +Pytest-based STT Metrics Comparison Tests + +This test suite compares multiple STT providers with the same audio samples +and verifies their metrics (accuracy, speed, cost, quality). + +Usage: + pytest test_stt_metrics_providers.py + pytest test_stt_metrics_providers.py -v + pytest test_stt_metrics_providers.py -k "test_deepgram" + pytest test_stt_metrics_providers.py --html=report.html +""" + +import asyncio +import json +import os +import wave +from dataclasses import asdict, dataclass +from pathlib import Path +from typing import Dict, List, Optional + +import pytest +from loguru import logger + +from pipecat.frames.frames import EndFrame, InputAudioRawFrame, MetricsFrame +from pipecat.observers.base_observer import BaseObserver, FramePushed +from pipecat.pipeline.pipeline import Pipeline +from pipecat.pipeline.runner import PipelineRunner +from pipecat.pipeline.task import PipelineParams, PipelineTask + + +@dataclass +class STTTestResult: + """Results from testing a single STT provider.""" + + provider: str + model: str + + # Accuracy + transcript: str + word_error_rate: Optional[float] = None + + # Performance + audio_duration: float = 0.0 + processing_time: Optional[float] = None + real_time_factor: Optional[float] = None + words_per_second: Optional[float] = None + ttft: Optional[float] = None + + # Quality + average_confidence: Optional[float] = None + + # Cost + estimated_cost: Optional[float] = None + cost_per_word: Optional[float] = None + + # Metadata + requests: Optional[int] = None + word_count: Optional[int] = None + character_count: Optional[int] = None + + +@dataclass +class AudioTestCase: + """A test case with audio and expected transcript.""" + + name: str + audio_file: str + ground_truth: str + duration: float + + +class STTMetricsCollector(BaseObserver): + """Collects metrics from STT services.""" + + def __init__(self): + super().__init__() + self.results: List[STTTestResult] = [] + self.transcriptions: List[str] = [] # Collect all final 
transcriptions + self.metrics_data = None + self.latest_metrics = None # Track the most recent/complete metrics + + async def on_push_frame(self, data: FramePushed): + from pipecat.frames.frames import InterimTranscriptionFrame, TranscriptionFrame + + # Capture transcript (only final transcriptions, not interim) + if isinstance(data.frame, TranscriptionFrame) and not isinstance( + data.frame, InterimTranscriptionFrame + ): + # Store each final transcription separately instead of concatenating + self.transcriptions.append(data.frame.text) + + # Capture metrics (always update to the latest) + if isinstance(data.frame, MetricsFrame): + for metric_data in data.frame.data: + if hasattr(metric_data, "value") and hasattr( + metric_data.value, "audio_duration_seconds" + ): + # Keep updating with the latest metrics + self.latest_metrics = metric_data + + # When EndFrame arrives, finalize result with the latest metrics + if isinstance(data.frame, EndFrame): + if self.latest_metrics and not self.results: + self.metrics_data = self.latest_metrics + self._create_result() + + def _create_result(self): + """Create result from collected metrics and transcript.""" + if not self.metrics_data: + return + + # Use the last final transcription (for single-segment audio testing) + # or join multiple transcriptions if there are multiple segments + if len(self.transcriptions) == 1: + final_transcript = self.transcriptions[0].strip() + elif len(self.transcriptions) > 1: + # Log warning if multiple transcriptions detected (possible duplicate issue) + logger.warning( + f"Multiple final transcriptions detected ({len(self.transcriptions)}): {self.transcriptions}" + ) + # Use the last one as it's typically the most complete + final_transcript = self.transcriptions[-1].strip() + else: + final_transcript = "" + + usage = self.metrics_data.value + result = STTTestResult( + provider=self.metrics_data.processor, + model=self.metrics_data.model or "unknown", + transcript=final_transcript, + word_error_rate=usage.word_error_rate, + audio_duration=usage.audio_duration_seconds, + processing_time=usage.processing_time_seconds, + real_time_factor=usage.real_time_factor, + words_per_second=usage.words_per_second, + ttft=usage.time_to_first_transcript, + average_confidence=usage.average_confidence, + estimated_cost=usage.estimated_cost, + cost_per_word=usage.cost_per_word, + requests=usage.requests, + word_count=usage.word_count, + character_count=usage.character_count, + ) + self.results.append(result) + # Clear metrics_data so we don't create duplicate results + self.metrics_data = None + + +async def run_stt_service_test( + service_name: str, stt_service, test_case: AudioTestCase, ground_truth: Optional[str] = None +) -> Optional[STTTestResult]: + """Run a single STT service with a test case and return results.""" + + logger.info(f"Testing {service_name} with '{test_case.name}'...") + + # Set ground truth on the service for automatic WER calculation + if ground_truth: + stt_service.set_ground_truth(ground_truth) + + # Create observer + collector = STTMetricsCollector() + + # Create pipeline + pipeline = Pipeline([stt_service]) + + # Create task with metrics enabled + task = PipelineTask( + pipeline, + params=PipelineParams( + enable_metrics=True, + enable_usage_metrics=True, + ), + ) + task.add_observer(collector) + + try: + # Load audio file + with wave.open(test_case.audio_file, "rb") as wf: + sample_rate = wf.getframerate() + num_channels = wf.getnchannels() + audio_data = wf.readframes(wf.getnframes()) + + # Send VAD frame to 
indicate user started speaking (for SegmentedSTTService) + from pipecat.frames.frames import UserStartedSpeakingFrame, UserStoppedSpeakingFrame + + await task.queue_frame(UserStartedSpeakingFrame()) + + # Send audio frames + audio_frame = InputAudioRawFrame( + audio=audio_data, sample_rate=sample_rate, num_channels=num_channels + ) + await task.queue_frame(audio_frame) + + # Send VAD frame to indicate user stopped speaking (triggers transcription) + await task.queue_frame(UserStoppedSpeakingFrame()) + + await task.queue_frame(EndFrame()) + + # Run pipeline + runner = PipelineRunner() + await runner.run(task) + + # Return results + return collector.results[0] if collector.results else None + + except Exception as e: + logger.error(f"Error testing {service_name}: {e}") + return None + + +async def run_stt_comparison(test_cases: List[AudioTestCase], providers: Dict) -> Dict: + """Run comparison across all providers and test cases.""" + + all_results = {} + + for test_case in test_cases: + logger.info(f"\n{'=' * 70}") + logger.info(f"Test Case: {test_case.name}") + logger.info(f"Ground Truth: '{test_case.ground_truth}'") + logger.info(f"{'=' * 70}\n") + + test_results = [] + + for provider_name, stt_service in providers.items(): + try: + result = await run_stt_service_test( + provider_name, stt_service, test_case, test_case.ground_truth + ) + if result: + test_results.append(result) + wer_str = ( + f"WER: {result.word_error_rate:.2f}%" if result.word_error_rate else "" + ) + logger.info(f"โœ… {provider_name}: '{result.transcript}' {wer_str}") + except Exception as e: + logger.error(f"โŒ {provider_name} failed: {e}") + + all_results[test_case.name] = test_results + + return all_results + + +def generate_comparison_report(results: Dict, output_file: str = "stt_comparison_report.json"): + """Generate a detailed comparison report.""" + + logger.info(f"\n{'=' * 70}") + logger.info("STT METRICS COMPARISON REPORT") + logger.info(f"{'=' * 70}\n") + + report = {} + + for test_name, test_results in results.items(): + logger.info(f"\n๐Ÿ“Š Test: {test_name}") + logger.info("-" * 70) + + if not test_results: + logger.warning("No results for this test case") + continue + + # Sort by different metrics + by_accuracy = sorted( + test_results, key=lambda x: x.word_error_rate if x.word_error_rate is not None else 100 + ) + by_speed = sorted( + test_results, + key=lambda x: x.real_time_factor if x.real_time_factor is not None else 999, + ) + by_cost = sorted( + test_results, key=lambda x: x.estimated_cost if x.estimated_cost is not None else 999 + ) + + logger.info("\n๐ŸŽฏ ACCURACY (by WER - lower is better):") + for result in by_accuracy: + wer_str = ( + f"{result.word_error_rate:.2f}%" if result.word_error_rate is not None else "N/A" + ) + conf_str = ( + f"{result.average_confidence:.2%}" + if result.average_confidence is not None + else "N/A" + ) + logger.info( + f" {result.provider:20} WER: {wer_str:8} | Confidence: {conf_str} | '{result.transcript}'" + ) + + logger.info("\nโšก SPEED (by RTF - lower is better):") + for result in by_speed: + rtf_str = ( + f"{result.real_time_factor:.3f}" if result.real_time_factor is not None else "N/A" + ) + wps_str = ( + f"{result.words_per_second:.1f}" if result.words_per_second is not None else "N/A" + ) + ttft_str = f"{result.ttft:.3f}s" if result.ttft is not None else "N/A" + logger.info( + f" {result.provider:20} RTF: {rtf_str:8} | WPS: {wps_str:8} | TTFT: {ttft_str}" + ) + + logger.info("\n๐Ÿ’ฐ COST (lower is better):") + for result in by_cost: + cost_str = ( + 
f"${result.estimated_cost:.6f}" if result.estimated_cost is not None else "N/A" + ) + cpw_str = f"${result.cost_per_word:.6f}" if result.cost_per_word is not None else "N/A" + logger.info(f" {result.provider:20} Total: {cost_str:12} | Per word: {cpw_str}") + + # Convert to dict for JSON + report[test_name] = [asdict(r) for r in test_results] + + # Save to JSON + output_path = Path(__file__).parent / output_file + with open(output_path, "w") as f: + json.dump(report, f, indent=2) + + logger.info(f"\n๐Ÿ“ Full report saved to: {output_path}") + + # Recommendations + logger.info(f"\n{'=' * 70}") + logger.info("๐ŸŽฏ RECOMMENDATIONS") + logger.info(f"{'=' * 70}") + + # Find best overall + for test_name, test_results in results.items(): + if not test_results: + continue + + best_accuracy = min( + test_results, key=lambda x: x.word_error_rate if x.word_error_rate is not None else 100 + ) + best_speed = min( + test_results, + key=lambda x: x.real_time_factor if x.real_time_factor is not None else 999, + ) + best_cost = min( + test_results, key=lambda x: x.estimated_cost if x.estimated_cost is not None else 999 + ) + + logger.info(f"\nFor '{test_name}':") + if best_accuracy.word_error_rate is not None: + logger.info( + f" ๐ŸŽฏ Best Accuracy: {best_accuracy.provider} (WER: {best_accuracy.word_error_rate:.2f}%)" + ) + if best_speed.real_time_factor is not None: + logger.info( + f" โšก Fastest: {best_speed.provider} (RTF: {best_speed.real_time_factor:.3f})" + ) + if best_cost.estimated_cost is not None: + logger.info(f" ๐Ÿ’ฐ Cheapest: {best_cost.provider} (${best_cost.estimated_cost:.6f})") + + +# ============================================================================ +# Pytest Fixtures +# ============================================================================ + + +@pytest.fixture(scope="session") +def test_audio_dir(): + """Get the test_audio directory path.""" + return Path(__file__).parent / "test_audio" + + +@pytest.fixture(scope="session") +def test_cases(test_audio_dir): + """Get list of test cases from test_audio directory.""" + test_cases = [] + + test_definitions = [ + ("greeting.wav", "Hello, how are you today?", 2.5), + ("technical.wav", "The API endpoint returns JSON with authentication headers.", 4.0), + ("numbers.wav", "The meeting is scheduled for March 15th at 3:30 PM.", 3.5), + ("sample.wav", "You might also want to consider setting up a page on Facebook", 3.0), + ] + + for audio_file, ground_truth, duration in test_definitions: + audio_path = test_audio_dir / audio_file + if audio_path.exists(): + test_cases.append( + AudioTestCase( + name=audio_file.replace(".wav", ""), + audio_file=str(audio_path), + ground_truth=ground_truth, + duration=duration, + ) + ) + + if not test_cases: + pytest.skip( + f"No test audio files found in {test_audio_dir}. Please create test audio files first." 
+ ) + + return test_cases + + +@pytest.fixture(scope="session") +def available_providers(): + """Get dictionary of available STT providers based on environment variables.""" + providers = {} + + # Deepgram + if os.getenv("DEEPGRAM_API_KEY"): + try: + from pipecat.services.deepgram.stt import DeepgramSTTService + + providers["deepgram"] = DeepgramSTTService(api_key=os.getenv("DEEPGRAM_API_KEY")) + logger.info("โœ… Deepgram configured") + except ImportError: + logger.warning("โš ๏ธ Deepgram not available (install pipecat-ai[deepgram])") + + # OpenAI Whisper + if os.getenv("OPENAI_API_KEY"): + try: + from pipecat.services.openai.stt import OpenAISTTService + + providers["openai"] = OpenAISTTService(api_key=os.getenv("OPENAI_API_KEY")) + logger.info("โœ… OpenAI configured") + except ImportError: + logger.warning("โš ๏ธ OpenAI not available (install pipecat-ai[openai])") + + # Azure + if os.getenv("AZURE_SPEECH_API_KEY") and os.getenv("AZURE_SPEECH_REGION"): + try: + from pipecat.services.azure.stt import AzureSTTService + + providers["azure"] = AzureSTTService( + api_key=os.getenv("AZURE_SPEECH_API_KEY"), + region=os.getenv("AZURE_SPEECH_REGION"), + ) + logger.info("โœ… Azure configured") + except ImportError: + logger.warning("โš ๏ธ Azure not available (install pipecat-ai[azure])") + + # Google + if os.getenv("GOOGLE_APPLICATION_CREDENTIALS"): + try: + from pipecat.services.google.stt import GoogleSTTService + + providers["google"] = GoogleSTTService() + logger.info("โœ… Google configured") + except ImportError: + logger.warning("โš ๏ธ Google not available (install pipecat-ai[google])") + + # Groq Whisper + if os.getenv("GROQ_API_KEY"): + try: + from pipecat.services.groq.stt import GroqSTTService + + providers["groq"] = GroqSTTService(api_key=os.getenv("GROQ_API_KEY")) + logger.info("โœ… Groq configured") + except ImportError: + logger.warning("โš ๏ธ Groq not available (install pipecat-ai[groq])") + + # AssemblyAI + if os.getenv("ASSEMBLYAI_API_KEY"): + try: + from pipecat.services.assemblyai.stt import AssemblyAISTTService + + providers["assemblyai"] = AssemblyAISTTService(api_key=os.getenv("ASSEMBLYAI_API_KEY")) + logger.info("โœ… AssemblyAI configured") + except ImportError: + logger.warning("โš ๏ธ AssemblyAI not available (install pipecat-ai[assemblyai])") + + # ElevenLabs + if os.getenv("ELEVENLABS_API_KEY"): + try: + from pipecat.services.elevenlabs.stt import ElevenLabsSTTService + + providers["elevenlabs"] = ElevenLabsSTTService(api_key=os.getenv("ELEVENLABS_API_KEY")) + logger.info("โœ… ElevenLabs configured") + except ImportError: + logger.warning("โš ๏ธ ElevenLabs not available (install pipecat-ai[elevenlabs])") + + if not providers: + pytest.skip("No STT providers configured. 
Please set API keys as environment variables.") + + return providers + + +# ============================================================================ +# Pytest Test Functions +# ============================================================================ + + +@pytest.mark.asyncio +async def test_stt_service_basic(available_providers, test_cases): + """Test that at least one STT service works with at least one test case.""" + # This is a basic smoke test + provider_name = list(available_providers.keys())[0] + stt_service = available_providers[provider_name] + test_case = test_cases[0] + + result = await run_stt_service_test( + provider_name, stt_service, test_case, test_case.ground_truth + ) + + assert result is not None, f"Failed to get result from {provider_name}" + assert result.transcript, "Transcript should not be empty" + # The provider name from metrics includes the service class name, just verify it's not empty + assert result.provider, "Provider name should not be empty" + logger.info(f"โœ… Test passed for {provider_name}: '{result.transcript}'") + + +def get_comprehensive_test_params(available_providers, test_cases): + """Generate parameters for comprehensive testing (provider x test_case combinations).""" + params = [] + ids = [] + for provider_name, provider_service in available_providers.items(): + for test_case in test_cases: + params.append((provider_name, provider_service, test_case)) + ids.append(f"{provider_name}-{test_case.name}") + return {"argnames": "provider_name,provider_service,test_case", "argvalues": params, "ids": ids} + + +@pytest.fixture(scope="module") +def comprehensive_params(available_providers, test_cases): + """Fixture to provide comprehensive test parameters.""" + return get_comprehensive_test_params(available_providers, test_cases) + + +@pytest.mark.asyncio +async def test_stt_providers_comprehensive(available_providers, test_cases): + """Test each STT provider with each test case comprehensively (all combinations).""" + # Test each provider with each test case + for provider_name, provider_service in available_providers.items(): + for test_case in test_cases: + logger.info(f"\n๐Ÿงช Testing {provider_name} with '{test_case.name}'...") + + result = await run_stt_service_test( + provider_name, provider_service, test_case, test_case.ground_truth + ) + + # Assertions + assert result is not None, f"Failed to get result from {provider_name}" + assert result.transcript, f"Transcript should not be empty for {provider_name}" + assert result.provider, "Provider name should not be empty" + + # Log results + logger.info(f"โœ… {provider_name} - {test_case.name}: '{result.transcript}'") + if result.word_error_rate is not None: + logger.info(f" WER: {result.word_error_rate:.2f}%") + assert result.word_error_rate >= 0, "WER should be non-negative" + + if result.real_time_factor is not None: + logger.info(f" RTF: {result.real_time_factor:.3f}") + assert result.real_time_factor > 0, "RTF should be positive" + + +@pytest.mark.asyncio +async def test_stt_comparison_all(available_providers, test_cases, tmp_path): + """Run full comparison across all providers and generate report.""" + logger.info("๐ŸŽค Running comprehensive STT comparison") + logger.info(f"Providers: {list(available_providers.keys())}") + logger.info(f"Test cases: {[tc.name for tc in test_cases]}") + + # Run comparison + results = await run_stt_comparison(test_cases, available_providers) + + # Verify results + assert results, "Should have results" + for test_name, test_results in results.items(): + assert 
test_results, f"Should have results for {test_name}" + for result in test_results: + assert result.transcript, f"Transcript should not be empty for {result.provider}" + + # Generate report + output_file = tmp_path / "stt_comparison_report.json" + + # Save report + report = {} + for test_name, test_results in results.items(): + report[test_name] = [asdict(r) for r in test_results] + + with open(output_file, "w") as f: + json.dump(report, f, indent=2) + + logger.info(f"๐Ÿ“ Report saved to: {output_file}") + + # Verify report was created + assert output_file.exists(), "Report file should be created" + + # Display summary + for test_name, test_results in results.items(): + logger.info(f"\n๐Ÿ“Š Test: {test_name}") + for result in test_results: + wer_str = f"WER: {result.word_error_rate:.2f}%" if result.word_error_rate else "N/A" + logger.info(f" {result.provider}: '{result.transcript}' ({wer_str})")