Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 19 additions & 1 deletion livekit-agents/livekit/agents/inference/stt.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,7 @@ class DeepgramOptions(TypedDict, total=False):
numerals: bool
mip_opt_out: bool # default: False
vad_events: bool # default: False
diarize: bool
diarize: bool # when True, enables speaker diarization (default off)
dictation: bool
detect_language: bool
no_delay: bool # default: True
Expand Down Expand Up @@ -105,6 +105,7 @@ class AssemblyaiOptions(TypedDict, total=False):
language_detection: bool
inactivity_timeout: float # seconds
prompt: str # default: not specified (u3-rt-pro only, mutually exclusive with keyterms_prompt)
speaker_labels: bool # when True, enables speaker diarization (default off)


class ElevenlabsOptions(TypedDict, total=False):
Expand Down Expand Up @@ -332,10 +333,19 @@ def __init__(
a list of FallbackModel instances.
conn_options (APIConnectOptions, optional): Connection options for request attempts.
"""
# Check extra_kwargs for diarization parameters across different providers
# Deepgram uses "diarize", AssemblyAI uses "speaker_labels"
diarization_enabled = False
if is_given(extra_kwargs):
diarization_enabled = bool(
extra_kwargs.get("diarize") or extra_kwargs.get("speaker_labels")
)

super().__init__(
capabilities=stt.STTCapabilities(
streaming=True,
interim_results=True,
diarization=diarization_enabled,
aligned_transcript="word",
offline_recognize=False,
),
Expand Down Expand Up @@ -452,6 +462,12 @@ def update_options(
self._opts.language = LanguageCode(language)
if is_given(extra):
self._opts.extra_kwargs.update(extra)
# Update diarization capability based on extra_kwargs
diarization_enabled = bool(
self._opts.extra_kwargs.get("diarize")
or self._opts.extra_kwargs.get("speaker_labels")
)
self._capabilities = replace(self._capabilities, diarization=diarization_enabled)

for stream in self._streams:
stream.update_options(model=model, language=language, extra=extra)
Expand Down Expand Up @@ -689,13 +705,15 @@ def _build_speech_data(self, data: dict) -> stt.SpeechData:
end_time=self.start_time_offset + data.get("start", 0) + data.get("duration", 0),
confidence=data.get("confidence", 1.0),
text=data.get("transcript", ""),
speaker_id=data.get("speaker_id"),
words=[
TimedString(
text=word.get("word", ""),
start_time=word.get("start", 0) + self.start_time_offset,
end_time=word.get("end", 0) + self.start_time_offset,
start_time_offset=self.start_time_offset,
confidence=word.get("confidence", 0.0),
speaker_id=word.get("speaker_id"),
)
for word in words
],
Expand Down
3 changes: 3 additions & 0 deletions livekit-agents/livekit/agents/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -120,6 +120,7 @@ class TimedString(str):
confidence: NotGivenOr[float]
start_time_offset: NotGivenOr[float]
# offset relative to the start of the audio input stream or session in seconds, used in STT plugins
speaker_id: str | None

def __new__(
cls,
Expand All @@ -128,10 +129,12 @@ def __new__(
end_time: NotGivenOr[float] = NOT_GIVEN,
confidence: NotGivenOr[float] = NOT_GIVEN,
start_time_offset: NotGivenOr[float] = NOT_GIVEN,
speaker_id: str | None = None,
) -> "TimedString":
obj = super().__new__(cls, text)
obj.start_time = start_time
obj.end_time = end_time
obj.confidence = confidence
obj.start_time_offset = start_time_offset
obj.speaker_id = speaker_id
return obj
33 changes: 33 additions & 0 deletions tests/test_inference_stt_fallback.py
Original file line number Diff line number Diff line change
Expand Up @@ -257,3 +257,36 @@ def test_connect_options_full_custom(self):
assert stt._opts.conn_options.timeout == 60.0
assert stt._opts.conn_options.max_retry == 10
assert stt._opts.conn_options.retry_interval == 2.0


class TestSTTDiarizationCapabilities:
    """Check that ``capabilities.diarization`` follows the provider flags in extra_kwargs."""

    def test_no_diarization_by_default(self):
        """An STT created with no diarization option reports the capability as False."""
        plain_stt = _make_stt()
        assert plain_stt.capabilities.diarization is False

    def test_diarization_enabled_with_deepgram_diarize(self):
        """Setting Deepgram's 'diarize' option to True switches the capability on."""
        provider_options = {"diarize": True}
        diarizing_stt = _make_stt(extra_kwargs=provider_options)
        assert diarizing_stt.capabilities.diarization is True

    def test_diarization_disabled_with_diarize_false(self):
        """An explicit 'diarize': False leaves the capability off."""
        provider_options = {"diarize": False}
        non_diarizing_stt = _make_stt(extra_kwargs=provider_options)
        assert non_diarizing_stt.capabilities.diarization is False

    def test_diarization_enabled_with_assemblyai_speaker_labels(self):
        """Setting AssemblyAI's 'speaker_labels' option to True switches the capability on."""
        assemblyai_stt = _make_stt(
            model="assemblyai/universal-streaming",
            extra_kwargs={"speaker_labels": True},
        )
        assert assemblyai_stt.capabilities.diarization is True

    def test_update_options_toggles_diarization(self):
        """update_options turns the capability on and then back off again."""
        toggled_stt = _make_stt()
        # Starts disabled because no diarization option was supplied.
        assert toggled_stt.capabilities.diarization is False

        toggled_stt.update_options(extra={"diarize": True})
        assert toggled_stt.capabilities.diarization is True

        # A later False for the same key must override the earlier True.
        toggled_stt.update_options(extra={"diarize": False})
        assert toggled_stt.capabilities.diarization is False
Loading