diff --git a/livekit-plugins/livekit-plugins-oracle/README.md b/livekit-plugins/livekit-plugins-oracle/README.md new file mode 100644 index 0000000000..9fcd73df70 --- /dev/null +++ b/livekit-plugins/livekit-plugins-oracle/README.md @@ -0,0 +1,17 @@ +# Oracle plugins for LiveKit Agents + +Support for Oracle's RTS, GenAI, and TTS services. + +See [https://docs.livekit.io/agents/integrations/oracle/](https://docs.livekit.io/agents/integrations/oracle/) for more information. + +## Installation + +```bash +pip install "livekit-plugins-oracle ~= 1.2" +pip install "oci-ai-speech-realtime ~= 2.2" +``` + +## Pre-requisites + +For credentials, you will need an Oracle Cloud Infrastructure (OCI) account, and you must pass the credential information into whichever plug-ins +you use (STT, LLM, and / or TTS). diff --git a/livekit-plugins/livekit-plugins-oracle/livekit/plugins/oracle/__init__.py b/livekit-plugins/livekit-plugins-oracle/livekit/plugins/oracle/__init__.py new file mode 100644 index 0000000000..85738e27fd --- /dev/null +++ b/livekit-plugins/livekit-plugins-oracle/livekit/plugins/oracle/__init__.py @@ -0,0 +1,50 @@ +# SPDX-License-Identifier: Apache-2.0 +# Copyright (c) 2025 Oracle Corporation and/or its affiliates. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Oracle plug-ins for LiveKit Agents + +Support for Oracle RTS, GenAI, and TTS services. 
+""" + +from livekit.agents import Plugin + +from .llm import LLM +from .log import logger +from .oracle_llm import BackEnd, Role +from .stt import STT +from .tts import TTS +from .utils import AuthenticationType +from .version import __version__ + +__all__ = ["STT", "LLM", "TTS", "AuthenticationType", "BackEnd", "Role", "__version__"] + + +class OraclePlugin(Plugin): + def __init__(self) -> None: + super().__init__(__name__, __version__, __package__, logger) + + +Plugin.register_plugin(OraclePlugin()) + + +# Cleanup docs of unexported modules +_module = dir() +NOT_IN_ALL = [m for m in _module if m not in __all__] + + +__pdoc__ = {} +for n in NOT_IN_ALL: + __pdoc__[n] = False diff --git a/livekit-plugins/livekit-plugins-oracle/livekit/plugins/oracle/audio_cache.py b/livekit-plugins/livekit-plugins-oracle/livekit/plugins/oracle/audio_cache.py new file mode 100644 index 0000000000..e57e4b66cb --- /dev/null +++ b/livekit-plugins/livekit-plugins-oracle/livekit/plugins/oracle/audio_cache.py @@ -0,0 +1,220 @@ +# SPDX-License-Identifier: Apache-2.0 +# Copyright (c) 2025 Oracle Corporation and/or its affiliates. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +This module implements simple audio caching used by the Oracle LiveKit TTS plug-in. 
+ +Author: Keith Schnable (at Oracle Corporation) +Date: 2025-08-12 +""" + +import json +import os +import time + +from livekit.agents import utils + +INDEX_FILE_NAME = "index.json" + + +class AudioCache: + """ + The audio cache class. + """ + + def __init__( + self, *, audio_cache_file_path: str, audio_cache_maximum_number_of_utterances: int + ): + self._audio_cache_file_path = audio_cache_file_path + self._audio_cache_maximum_number_of_utterances = audio_cache_maximum_number_of_utterances + + if not os.path.exists(self._audio_cache_file_path): + os.makedirs(self._audio_cache_file_path) + + self._index_file_spec = os.path.join(self._audio_cache_file_path, INDEX_FILE_NAME) + + if os.path.exists(self._index_file_spec): + with open(self._index_file_spec, encoding="utf-8") as file: + index_json_text = file.read() + self._index_dictionary = json.loads(index_json_text) + else: + self._index_dictionary = {} + + def get_audio_bytes( + self, *, text: str, voice: str, audio_rate: int, audio_channels: int, audio_bits: int + ): + """ + Get the audio bytes for the specified text, voice, audio rate, audio channels, and audio bits. + + Parameters: + text (str): The text. + voice (str): The voice. + audio_rate (int): The audio rate (16000 for example). + audio_channels (int): The audio channels (1 for example). + audio_bits (int): The audio bits (16 for example). + + Returns: + bytes: The audio bytes. 
+ """ + + key = AudioCache.form_key( + text=text, + voice=voice, + audio_rate=audio_rate, + audio_channels=audio_channels, + audio_bits=audio_bits, + ) + + if key in self._index_dictionary: + dictionary = self._index_dictionary[key] + audio_bytes_file_name = dictionary["audio_bytes_file_name"] + audio_bytes_file_spec = os.path.join(self._audio_cache_file_path, audio_bytes_file_name) + if os.path.exists(audio_bytes_file_spec): + write_index_dictionary = True + dictionary["last_accessed_milliseconds"] = int(time.time() * 1000) + with open(audio_bytes_file_spec, "rb") as file: + audio_bytes = file.read() + else: + del self._index_dictionary[key] + write_index_dictionary = True + audio_bytes = None + else: + write_index_dictionary = False + audio_bytes = None + + if write_index_dictionary: + with open(self._index_file_spec, "w", encoding="utf-8") as file: + json.dump(self._index_dictionary, file, indent=4) + + return audio_bytes + + def set_audio_bytes( + self, + *, + text: str, + voice: str, + audio_rate: int, + audio_channels: int, + audio_bits: int, + audio_bytes: bytes, + ): + """ + Set the audio bytes for the specified text, voice, audio rate, audio channels, audio bits, and audio bytes. + + Parameters: + text (str): The text. + voice (str): The voice. + audio_rate (int): The audio rate (16000 for example). + audio_channels (int): The audio channels (1 for example). + audio_bits (int): The audio bits (16 for example). + audio_bytes (bytes) : The audio bytes. 
+ + Returns: + (nothing) + """ + + key = AudioCache.form_key( + text=text, + voice=voice, + audio_rate=audio_rate, + audio_channels=audio_channels, + audio_bits=audio_bits, + ) + + if key in self._index_dictionary: + dictionary = self._index_dictionary[key] + audio_bytes_file_name = dictionary["audio_bytes_file_name"] + write_index_dictionary = False + else: + audio_bytes_file_name = str(utils.shortuuid()) + dictionary = {} + dictionary["audio_bytes_file_name"] = audio_bytes_file_name + dictionary["created_milliseconds"] = int(time.time() * 1000) + dictionary["last_accessed_milliseconds"] = dictionary["created_milliseconds"] + self._index_dictionary[key] = dictionary + write_index_dictionary = True + + audio_bytes_file_spec = os.path.join(self._audio_cache_file_path, audio_bytes_file_name) + + with open(audio_bytes_file_spec, "wb") as file: + file.write(audio_bytes) + + if write_index_dictionary: + with open(self._index_file_spec, "w", encoding="utf-8") as file: + json.dump(self._index_dictionary, file, indent=4) + + self.clean_up_old_utterances() + + def clean_up_old_utterances(self): + """ + Clean up old utterance files based on the audio_cache_maximum_number_of_utterances parameter. + The oldest utterances get deleted first. 
+ + Parameters: + (none) + + Returns: + (nothing) + """ + + while len(self._index_dictionary) > self._audio_cache_maximum_number_of_utterances: + oldest_key = None + oldest_dictionary = None + for key, dictionary in self._index_dictionary.items():  # iterate (key, entry) pairs, not bare keys + if ( + oldest_dictionary is None + or dictionary["last_accessed_milliseconds"] + < oldest_dictionary["last_accessed_milliseconds"] + ): + oldest_key = key + oldest_dictionary = dictionary + + audio_bytes_file_name = oldest_dictionary["audio_bytes_file_name"] + audio_bytes_file_spec = os.path.join(self._audio_cache_file_path, audio_bytes_file_name) + if os.path.exists(audio_bytes_file_spec): + os.remove(audio_bytes_file_spec) + del self._index_dictionary[oldest_key] + + with open(self._index_file_spec, "w", encoding="utf-8") as file: + json.dump(self._index_dictionary, file, indent=4) + + @staticmethod + def form_key(*, text: str, voice: str, audio_rate: int, audio_channels: int, audio_bits: int): + """ + Form the key for the specified text, voice, audio rate, audio channels, and audio bits. + + Parameters: + text (str): The text. + voice (str): The voice. + audio_rate (int): The audio rate (16000 for example). + audio_channels (int): The audio channels (1 for example). + audio_bits (int): The audio bits (16 for example). + + Returns: + str: The tab-separated key (voice, rate, channels, bits, text). + """ + + key = ( + voice + + "\t" + + str(audio_rate) + + "\t" + + str(audio_channels) + + "\t" + + str(audio_bits) + + "\t" + + text + ) + return key diff --git a/livekit-plugins/livekit-plugins-oracle/livekit/plugins/oracle/llm.py b/livekit-plugins/livekit-plugins-oracle/livekit/plugins/oracle/llm.py new file mode 100644 index 0000000000..4a614b78e1 --- /dev/null +++ b/livekit-plugins/livekit-plugins-oracle/livekit/plugins/oracle/llm.py @@ -0,0 +1,362 @@ +# SPDX-License-Identifier: Apache-2.0 +# Copyright (c) 2025 Oracle Corporation and/or its affiliates. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +This module is the Oracle LiveKit LLM plug-in. + +Author: Keith Schnable (at Oracle Corporation) +Date: 2025-08-12 +""" + +from __future__ import annotations + +import ast +import json + +from livekit.agents import llm, utils +from livekit.agents.types import DEFAULT_API_CONNECT_OPTIONS, APIConnectOptions +from livekit.plugins.openai.utils import to_fnc_ctx + +from .log import logger +from .oracle_llm import ( + CONTENT_TYPE_STRING, + TOOL_CALL_DESCRIPTION, + TOOL_CALL_PREFIX, + BackEnd, + OracleLLM, + OracleLLMContent, + OracleTool, + OracleValue, + Role, +) +from .utils import AuthenticationType + + +class LLM(llm.LLM): + """ + The Oracle LiveKit LLM plug-in class. This derives from livekit.agents.llm.LLM. 
+ """ + + def __init__( + self, + *, + base_url: str, # must be specified + authentication_type: AuthenticationType = AuthenticationType.SECURITY_TOKEN, + authentication_configuration_file_spec: str = "~/.oci/config", + authentication_profile_name: str = "DEFAULT", + back_end: BackEnd = BackEnd.GEN_AI_LLM, + # these apply only if back_end == BackEnd.GEN_AI_LLM + compartment_id: str | None = None, # must be specified + model_type: str = "GENERIC", # must be "GENERIC" or "COHERE" + model_id: str | None = None, # must be specified or model_name must be specified + model_name: str | None = None, # must be specified or model_id must be specified + maximum_number_of_tokens: int | None = None, + temperature: float | None = None, + top_p: float | None = None, + top_k: int | None = None, + frequency_penalty: float | None = None, + presence_penalty: float | None = None, + seed: int | None = None, + # these apply only if back_end == BackEnd.GEN_AI_AGENT + agent_endpoint_id: str | None = None, # must be specified + ) -> None: + """ + Create a new instance of the OracleLLM class to access Oracle's GenAI service. This has LiveKit dependencies. + + Args: + base_url: Base URL. Type is str. No default (must be specified). + authentication_type: Authentication type. Type is AuthenticationType enum (API_KEY, SECURITY_TOKEN, INSTANCE_PRINCIPAL, or RESOURCE_PRINCIPAL). Default is SECURITY_TOKEN. + authentication_configuration_file_spec: Authentication configuration file spec. Type is str. Default is "~/.oci/config". Applies only for API_KEY or SECURITY_TOKEN. + authentication_profile_name: Authentication profile name. Type is str. Default is "DEFAULT". Applies only for API_KEY or SECURITY_TOKEN. + back_end: Back-end. Type is BackEnd enum (GEN_AI_LLM or GEN_AI_AGENT). Default is GEN_AI_LLM. + compartment_id: Compartment ID. Type is str. Default is None (must be specified). Applies only for GEN_AI_LLM. + model_type: Model type. Type is str. Default is "GENERIC". 
Must be one of "GENERIC" or "COHERE". Applies only for GEN_AI_LLM. + model_id: Model ID. Type is str. Default is None (must be specified or model_name must be specified). Applies only for GEN_AI_LLM. + model_name: Model name. Type is str. Default is None (must be specified or model_id must be specified). Applies only for GEN_AI_LLM. + maximum_number_of_tokens: Maximum number of tokens. Type is int. Default is None. Applies only for GEN_AI_LLM. + temperature: Temperature. Type is float. Default is None. Applies only for GEN_AI_LLM. + top_p: Top-P. Type is float. Default is None. Applies only for GEN_AI_LLM. + top_k: Top-K. Type is int. Default is None. Applies only for GEN_AI_LLM. + frequency_penalty: Frequency penalty. Type is float. Default is None. Applies only for GEN_AI_LLM. + presence_penalty: Presence penalty. Type is float. Default is None. Applies only for GEN_AI_LLM. + seed: Seed. Type is int. Default is None. Applies only for GEN_AI_LLM. + agent_endpoint_id: Agent endpoint ID. Type is str. Default is None (must be specified). Applies only for GEN_AI_AGENT. + """ + + super().__init__() + + self._oracle_llm = OracleLLM( + base_url=base_url, + authentication_type=authentication_type, + authentication_configuration_file_spec=authentication_configuration_file_spec, + authentication_profile_name=authentication_profile_name, + back_end=back_end, + compartment_id=compartment_id, + model_type=model_type, + model_id=model_id, + model_name=model_name, + maximum_number_of_tokens=maximum_number_of_tokens, + temperature=temperature, + top_p=top_p, + top_k=top_k, + frequency_penalty=frequency_penalty, + presence_penalty=presence_penalty, + seed=seed, + agent_endpoint_id=agent_endpoint_id, + ) + + # + # currently this is never cleaned up because it appears that the past tool calls may + # always be needed to construct the entire conversation history. if this is not actually + # the case, theoretically old keys that are no longer referenced should be removed. 
+ # + self._call_id_to_tool_call_dictionary = {} + + logger.debug("Initialized LLM.") + + def chat( + self, + *, + chat_ctx: llm.ChatContext, + conn_options: APIConnectOptions = DEFAULT_API_CONNECT_OPTIONS, + tools=None, + tool_choice=None, + extra_kwargs=None, + ) -> LLMStream: + return LLMStream( + oracle_llm_livekit_plugin=self, + chat_ctx=chat_ctx, + conn_options=conn_options, + tools=tools, + ) + + +class LLMStream(llm.LLMStream): + """ + The LLM stream class. This derives from livekit.agents.llm.LLMStream. + """ + + def __init__( + self, + *, + oracle_llm_livekit_plugin: LLM, + chat_ctx: llm.ChatContext, + conn_options: None, + tools: None, + ) -> None: + super().__init__( + oracle_llm_livekit_plugin, chat_ctx=chat_ctx, tools=None, conn_options=conn_options + ) + + self._oracle_llm_livekit_plugin = oracle_llm_livekit_plugin + + self._tools = LLMStream.convert_tools(tools) + + logger.debug("Converted tools.") + + async def _run(self) -> None: + oracle_llm_content_list = [] + + for chat_message in self._chat_ctx._items: + if chat_message.type == "message": + role = Role(chat_message.role.upper()) + for message in chat_message.content: + oracle_llm_content = OracleLLMContent(message, CONTENT_TYPE_STRING, role) + oracle_llm_content_list.append(oracle_llm_content) + + elif chat_message.type == "function_call_output": + call_id = chat_message.call_id + + tool_call = self._oracle_llm_livekit_plugin._call_id_to_tool_call_dictionary.get( + call_id + ) + + if tool_call is not None: + try: + output_json = json.loads(chat_message.output) + message = output_json["text"] + except Exception: + message = chat_message.output + + oracle_llm_content = OracleLLMContent( + tool_call, CONTENT_TYPE_STRING, Role.ASSISTANT + ) + oracle_llm_content_list.append(oracle_llm_content) + + oracle_llm_content = OracleLLMContent( + "The function result of " + tool_call + " is: " + message, + CONTENT_TYPE_STRING, + Role.SYSTEM, + ) + oracle_llm_content_list.append(oracle_llm_content) + + 
logger.debug( + "Before running content thru LLM. Content list count: " + + str(len(oracle_llm_content_list)) + ) + + response_messages = self._oracle_llm_livekit_plugin._oracle_llm.run( + oracle_llm_content_list=oracle_llm_content_list, tools=self._tools + ) + + logger.debug( + "After running content thru LLM. Response message list count: " + + str(len(response_messages)) + + "." + ) + + for response_message in response_messages: + if response_message.startswith(TOOL_CALL_PREFIX): + tool_call = response_message + + logger.debug("External tool call needs to be made: " + tool_call) + + function_name, function_parameters = ( + LLMStream.get_name_and_arguments_from_tool_call(tool_call) + ) + + tool = None + for temp_tool in self._tools or []:  # convert_tools returns None when no tools are registered + if temp_tool.name == function_name and len(temp_tool.parameters) == len( + function_parameters + ): + tool = temp_tool + + if tool is None: + raise Exception( + "Unknown function name: " + + function_name + + " in " + + TOOL_CALL_DESCRIPTION + + " response message: " + + tool_call + + "." 
+ ) + + # Serialize with json.dumps so quotes/backslashes in string values are escaped and + # booleans/None/containers become valid JSON instead of Python reprs. + function_arguments = {} + for i in range(len(function_parameters)): + parameter = tool.parameters[i] + value = function_parameters[i] + # String-typed parameters are always sent as JSON strings, as before. + if parameter.type in {"string", "str"}: + value = str(value) + function_arguments[parameter.name] = value + function_parameters_text = json.dumps( + function_arguments, separators=(",", ":"), ensure_ascii=False, default=str + ) + + call_id = utils.shortuuid() + + self._oracle_llm_livekit_plugin._call_id_to_tool_call_dictionary[call_id] = ( + tool_call + ) + + function_tool_call = llm.FunctionToolCall( + name=function_name, arguments=function_parameters_text, call_id=call_id + ) + + choice_delta = llm.ChoiceDelta( + role=Role.ASSISTANT.name.lower(), content=None, tool_calls=[function_tool_call] + ) + + chat_chunk = llm.ChatChunk(id=utils.shortuuid(), delta=choice_delta, usage=None) + + self._event_ch.send_nowait(chat_chunk) + + logger.debug("Added tool call to event channel: " + tool_call) + + else: + logger.debug("LLM response message: " + response_message) + + chat_chunk = llm.ChatChunk( + id=utils.shortuuid(), + delta=llm.ChoiceDelta( + content=response_message, role=Role.ASSISTANT.name.lower() + ), + ) + + self._event_ch.send_nowait(chat_chunk) + + logger.debug("Added response message to event channel: " + response_message) + + @staticmethod + def convert_tools(livekit_tools): + tools = [] + + if livekit_tools is not None: + function_contexts = to_fnc_ctx(livekit_tools) + + for function_context in function_contexts: + type = function_context["type"] + if type == "function": + function = function_context["function"] + + function_name = function["name"] + function_description = function["description"] + if function_description is None or len(function_description) == 0: + function_description = function_name + + function_parameters = function["parameters"] + + parameters = [] + for property_key, 
property_value in function_parameters["properties"].items(): + parameter_name = property_key + if "description" in property_value: + parameter_description = property_value["description"] + elif "title" in property_value: + parameter_description = property_value["title"] + else: + parameter_description = parameter_name + parameter_type = property_value["type"] + + parameter = OracleValue( + parameter_name, parameter_description, parameter_type + ) + parameters.append(parameter) + + tool = OracleTool(function_name, function_description, parameters) + tools.append(tool) + + if len(tools) == 0: + return None + + return tools + + @staticmethod + def get_name_and_arguments_from_tool_call(tool_call): + tool_call = tool_call[len(TOOL_CALL_PREFIX) :].strip() + + function_name, function_parameters = LLMStream.parse_function_call( + tool_call, TOOL_CALL_DESCRIPTION + ) + + return function_name, function_parameters + + @staticmethod + def parse_function_call(code_string, description): + expression = ast.parse(code_string, mode="eval").body + + if not isinstance(expression, ast.Call): + raise Exception("Invalid " + description + ": " + code_string + ".") + + function_name = expression.func.id if isinstance(expression.func, ast.Name) else None + if not function_name: + raise Exception("Invalid " + description + ": " + code_string + ".") + + function_parameters = [ast.literal_eval(parameter) for parameter in expression.args] + + return function_name, function_parameters diff --git a/livekit-plugins/livekit-plugins-oracle/livekit/plugins/oracle/log.py b/livekit-plugins/livekit-plugins-oracle/livekit/plugins/oracle/log.py new file mode 100644 index 0000000000..841b55c64d --- /dev/null +++ b/livekit-plugins/livekit-plugins-oracle/livekit/plugins/oracle/log.py @@ -0,0 +1,33 @@ +# SPDX-License-Identifier: Apache-2.0 +# Copyright (c) 2025 Oracle Corporation and/or its affiliates. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import logging + +TRACE_LEVEL = 5 + +logging.addLevelName(TRACE_LEVEL, "TRACE") +logging.TRACE = TRACE_LEVEL + + +def trace(self, message, *args, **kwargs): + if self.isEnabledFor(TRACE_LEVEL): + self._log(TRACE_LEVEL, message, args, **kwargs) + + +logging.Logger.trace = trace + +logging.basicConfig(level=TRACE_LEVEL, format="%(asctime)s - %(levelname)s - %(message)s") + +logger = logging.getLogger("livekit.plugins.oracle") diff --git a/livekit-plugins/livekit-plugins-oracle/livekit/plugins/oracle/oracle_llm.py b/livekit-plugins/livekit-plugins-oracle/livekit/plugins/oracle/oracle_llm.py new file mode 100644 index 0000000000..514db363fd --- /dev/null +++ b/livekit-plugins/livekit-plugins-oracle/livekit/plugins/oracle/oracle_llm.py @@ -0,0 +1,632 @@ +# SPDX-License-Identifier: Apache-2.0 +# Copyright (c) 2025 Oracle Corporation and/or its affiliates. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +""" +This module wraps Oracle's LLM cloud service. While it is used by the Oracle LiveKit LLM plug-in, +it is completely independent of LiveKit and could be used in other environments besides LiveKit. + +Author: Keith Schnable (at Oracle Corporation) +Date: 2025-08-12 +""" + +from __future__ import annotations + +import uuid +from dataclasses import dataclass +from enum import Enum +from typing import Any +from urllib.parse import urlparse + +import oci + +from .log import logger +from .utils import AuthenticationType, get_config_and_signer + + +class BackEnd(Enum): + """Back-ends as enumerator.""" + + GEN_AI_LLM = "GEN_AI_LLM" + GEN_AI_AGENT = "GEN_AI_AGENT" + + +CONTENT_TYPE_STRING = "string" + + +class Role(Enum): + """Roles as enumerator.""" + + USER = "USER" + SYSTEM = "SYSTEM" + ASSISTANT = "ASSISTANT" + DEVELOPER = "DEVELOPER" + + +TOOL_CALL_PREFIX = "TOOL-CALL:" +TOOL_CALL_DESCRIPTION = "tool-call" + + +class OracleLLM: + """ + The Oracle LLM class. This class wraps the Oracle LLM service. 
+ """ + + def __init__( + self, + *, + base_url: str, # must be specified + authentication_type: AuthenticationType = AuthenticationType.SECURITY_TOKEN, + authentication_configuration_file_spec: str = "~/.oci/config", + authentication_profile_name: str = "DEFAULT", + back_end: BackEnd = BackEnd.GEN_AI_LLM, + # these apply only if back_end == BackEnd.GEN_AI_LLM + compartment_id: str | None = None, # must be specified + model_type: str = "GENERIC", # must be "GENERIC" or "COHERE" + model_id: str | None = None, # must be specified or model_name must be specified + model_name: str | None = None, # must be specified or model_id must be specified + maximum_number_of_tokens: int | None = None, + temperature: float | None = None, + top_p: float | None = None, + top_k: int | None = None, + frequency_penalty: float | None = None, + presence_penalty: float | None = None, + seed: int | None = None, + # these apply only if back_end == BackEnd.GEN_AI_AGENT + agent_endpoint_id: str | None = None, # must be specified + ) -> None: + """ + Create a new instance of the OracleLLM class to access Oracle's GenAI service. This has no LiveKit dependencies. + + Args: + base_url: Base URL. Type is str. No default (must be specified). + authentication_type: Authentication type. Type is AuthenticationType enum (API_KEY, SECURITY_TOKEN, INSTANCE_PRINCIPAL, or RESOURCE_PRINCIPAL). Default is SECURITY_TOKEN. + authentication_configuration_file_spec: Authentication configuration file spec. Type is str. Default is "~/.oci/config". Applies only for API_KEY or SECURITY_TOKEN. + authentication_profile_name: Authentication profile name. Type is str. Default is "DEFAULT". Applies only for API_KEY or SECURITY_TOKEN. + back_end: Back-end. Type is BackEnd enum (GEN_AI_LLM or GEN_AI_AGENT). Default is GEN_AI_LLM. + compartment_id: Compartment ID. Type is str. Default is None (must be specified). Applies only for GEN_AI_LLM. + model_type: Model type. Type is str. Default is "GENERIC". 
Must be one of "GENERIC" or "COHERE". Applies only for GEN_AI_LLM. + model_id: Model ID. Type is str. Default is None (must be specified or model_name must be specified). Applies only for GEN_AI_LLM. + model_name: Model name. Type is str. Default is None (must be specified or model_id must be specified). Applies only for GEN_AI_LLM. + maximum_number_of_tokens: Maximum number of tokens. Type is int. Default is None. Applies only for GEN_AI_LLM. + temperature: Temperature. Type is float. Default is None. Applies only for GEN_AI_LLM. + top_p: Top-P. Type is float. Default is None. Applies only for GEN_AI_LLM. + top_k: Top-K. Type is int. Default is None. Applies only for GEN_AI_LLM. + frequency_penalty: Frequency penalty. Type is float. Default is None. Applies only for GEN_AI_LLM. + presence_penalty: Presence penalty. Type is float. Default is None. Applies only for GEN_AI_LLM. + seed: Seed. Type is int. Default is None. Applies only for GEN_AI_LLM. + agent_endpoint_id: Agent endpoint ID. Type is str. Default is None (must be specified). Applies only for GEN_AI_AGENT. 
+ """ + + self._parameters = Parameters() + + self._parameters.base_url = base_url + self._parameters.authentication_type = authentication_type + self._parameters.authentication_configuration_file_spec = ( + authentication_configuration_file_spec + ) + self._parameters.authentication_profile_name = authentication_profile_name + self._parameters.back_end = back_end + + self._parameters.compartment_id = compartment_id + self._parameters.model_type = model_type + self._parameters.model_id = model_id + self._parameters.model_name = model_name + self._parameters.maximum_number_of_tokens = maximum_number_of_tokens + self._parameters.temperature = temperature + self._parameters.top_p = top_p + self._parameters.top_k = top_k + self._parameters.frequency_penalty = frequency_penalty + self._parameters.presence_penalty = presence_penalty + self._parameters.seed = seed + + self._parameters.agent_endpoint_id = agent_endpoint_id + + self.validate_parameters() + + self._number_of_runs = 0 + + self._output_tool_descriptions = True + + if self._parameters.back_end == BackEnd.GEN_AI_LLM: + self.initialize_for_llm() + else: # if self._parameters.back_end == BackEnd.GEN_AI_AGENT: + self.initialize_for_agent() + + logger.debug("Initialized OracleLLM.") + + def validate_parameters(self): + if not isinstance(self._parameters.base_url, str): + raise TypeError("The base_url parameter must be a string.") + self._parameters.base_url = self._parameters.base_url.strip() + parsed = urlparse(self._parameters.base_url) + if not all([parsed.scheme, parsed.netloc]): + raise ValueError("The base_url parameter must be a valid URL.") + + if not isinstance(self._parameters.authentication_type, AuthenticationType): + raise TypeError( + "The authentication_type parameter must be one of the AuthenticationType enum members." 
+ ) + + if self._parameters.authentication_type in { + AuthenticationType.API_KEY, + AuthenticationType.SECURITY_TOKEN, + }: + if not isinstance(self._parameters.authentication_configuration_file_spec, str): + raise TypeError( + "The authentication_configuration_file_spec parameter must be a string." + ) + self._parameters.authentication_configuration_file_spec = ( + self._parameters.authentication_configuration_file_spec.strip() + ) + if len(self._parameters.authentication_configuration_file_spec) == 0: + raise ValueError( + "The authentication_configuration_file_spec parameter must not be an empty string." + ) + + if not isinstance(self._parameters.authentication_profile_name, str): + raise TypeError("The authentication_profile_name parameter must be a string.") + self._parameters.authentication_profile_name = ( + self._parameters.authentication_profile_name.strip() + ) + if len(self._parameters.authentication_profile_name) == 0: + raise ValueError( + "The authentication_profile_name parameter must not be an empty string." 
+ ) + + if not isinstance(self._parameters.back_end, BackEnd): + raise TypeError("The back_end parameter must be one of the BackEnd enum members.") + + if self._parameters.back_end == BackEnd.GEN_AI_LLM: + if not isinstance(self._parameters.compartment_id, str): + raise TypeError("The compartment_id parameter must be a string.") + self._parameters.compartment_id = self._parameters.compartment_id.strip() + if len(self._parameters.compartment_id) == 0: + raise ValueError("The compartment_id parameter must not be an empty string.") + + if not isinstance(self._parameters.model_type, str): + raise TypeError("The model_type parameter must be a string.") + self._parameters.model_type = self._parameters.model_type.strip().upper() + if self._parameters.model_type not in {"GENERIC", "COHERE"}: + raise ValueError("The model_type parameter must be 'GENERIC' or 'COHERE'.") + + if self._parameters.model_id is not None: + if not isinstance(self._parameters.model_id, str): + raise TypeError("The model_id parameter must be a string.") + self._parameters.model_id = self._parameters.model_id.strip() + if len(self._parameters.model_id) == 0: + raise ValueError("The model_id parameter must not be an empty string.") + + if self._parameters.model_name is not None: + if not isinstance(self._parameters.model_name, str): + raise TypeError("The model_name parameter must be a string.") + self._parameters.model_name = self._parameters.model_name.strip() + if len(self._parameters.model_name) == 0: + raise ValueError("The model_name parameter must not be an empty string.") + + if self._parameters.model_id is None: + if self._parameters.model_name is None: + raise TypeError( + "Either the model_id or the model_name parameter must not be None." 
+ ) + elif self._parameters.model_name is not None: + raise TypeError("Either the model_id or the model_name parameter must be None.") + + if self._parameters.maximum_number_of_tokens is not None: + if not isinstance(self._parameters.maximum_number_of_tokens, int): + raise TypeError("The maximum_number_of_tokens parameter must be an integer.") + if self._parameters.maximum_number_of_tokens <= 0: + raise ValueError( + "The maximum_number_of_tokens parameter must be greater than 0." + ) + + if self._parameters.temperature is not None: + if not isinstance(self._parameters.temperature, float): + raise TypeError("The temperature parameter must be a float.") + if self._parameters.temperature < 0: + raise ValueError( + "The maximum_number_of_tokens parameter must be greater than or equal to 0." + ) + + if self._parameters.top_p is not None: + if not isinstance(self._parameters.top_p, float): + raise TypeError("The top_p parameter must be a float.") + if self._parameters.top_p < 0 or self._parameters.top_p > 1: + raise ValueError("The top_p parameter must be between 0 and 1.") + + if self._parameters.top_k is not None: + if not isinstance(self._parameters.top_k, int): + raise TypeError("The top_k parameter must be an integer.") + if self._parameters.top_k <= 0: + raise ValueError("The top_k parameter must be greater than 0.") + + if self._parameters.frequency_penalty is not None: + if not isinstance(self._parameters.frequency_penalty, float): + raise TypeError("The frequency_penalty parameter must be a float.") + if self._parameters.frequency_penalty < 0: + raise ValueError( + "The frequency_penalty parameter must be greater than or equal to 0." + ) + + if self._parameters.presence_penalty is not None: + if not isinstance(self._parameters.presence_penalty, float): + raise TypeError("The presence_penalty parameter must be a float.") + if self._parameters.presence_penalty < 0: + raise ValueError( + "The presence_penalty parameter must be greater than or equal to 0." 
def initialize_for_llm(self):
    """
    Build the GenAI inference client used when the back end is the GenAI LLM.

    Authentication material is obtained from get_config_and_signer() using the
    authentication parameters stored on this instance. The client is created against
    the configured base URL with SDK retries disabled, and the signer keyword is only
    supplied when a signer was actually returned.
    """
    params = self._parameters

    auth = get_config_and_signer(
        authentication_type=params.authentication_type,
        authentication_configuration_file_spec=params.authentication_configuration_file_spec,
        authentication_profile_name=params.authentication_profile_name,
    )
    config, signer = auth["config"], auth["signer"]

    client_arguments = {
        "config": config,
        "service_endpoint": params.base_url,
        "retry_strategy": oci.retry.NoneRetryStrategy(),
    }
    if signer is not None:
        # Only pass a signer when one exists; otherwise the client relies on config alone.
        client_arguments["signer"] = signer

    self._generative_ai_inference_client = (
        oci.generative_ai_inference.GenerativeAiInferenceClient(**client_arguments)
    )

    logger.debug("Initialized for GenAI LLM.")
def run(
    self,
    *,
    oracle_llm_content_list: list[OracleLLMContent] = None,
    tools: list[OracleTool] = None,
) -> list[str]:
    """
    Run one exchange against the configured back end and return its response messages.

    The call is dispatched to run_for_llm() when the back end is BackEnd.GEN_AI_LLM and to
    run_for_agent() for any other back end; both keyword arguments are forwarded unchanged.
    The run counter is incremented only after the selected method has returned.
    """
    if self._parameters.back_end == BackEnd.GEN_AI_LLM:
        handler = self.run_for_llm
    else:  # BackEnd.GEN_AI_AGENT
        handler = self.run_for_agent

    response_messages = handler(oracle_llm_content_list=oracle_llm_content_list, tools=tools)

    self._number_of_runs += 1

    return response_messages
[text_content] + + temp_message_list.append(message) + + if len(temp_messages) > 0: + temp_messages += "\n" + + temp_messages += tool_descriptions + + for oracle_llm_content in oracle_llm_content_list: + if oracle_llm_content.content_type == CONTENT_TYPE_STRING: + text_content = oci.generative_ai_inference.models.TextContent() + text_content.text = oracle_llm_content.content_data + + message = oci.generative_ai_inference.models.Message() + message.role = oracle_llm_content.role.name + message.content = [text_content] + + temp_message_list.append(message) + + if len(temp_messages) > 0: + temp_messages += "\n" + + temp_messages += oracle_llm_content.content_data + + if self._parameters.model_type == "GENERIC": + chat_request = oci.generative_ai_inference.models.GenericChatRequest() + chat_request.messages = temp_message_list + + elif self._parameters.model_type == "COHERE": + chat_request = oci.generative_ai_inference.models.CohereChatRequest() + chat_request.message = temp_messages + + if self._parameters.maximum_number_of_tokens is not None: + chat_request.max_tokens = self._parameters.maximum_number_of_tokens + if self._parameters.temperature is not None: + chat_request.temperature = self._parameters.temperature + if self._parameters.frequency_penalty is not None: + chat_request.frequency_penalty = self._parameters.frequency_penalty + if self._parameters.presence_penalty is not None: + chat_request.presence_penalty = self._parameters.presence_penalty + if self._parameters.top_p is not None: + chat_request.top_p = self._parameters.top_p + if self._parameters.top_k is not None: + chat_request.top_k = self._parameters.top_k + if self._parameters.seed is not None: + chat_request.seed = self._parameters.seed + + serving_mode = oci.generative_ai_inference.models.OnDemandServingMode( + model_id=self._parameters.model_name + if self._parameters.model_id is None + else self._parameters.model_id + ) + + chat_details = oci.generative_ai_inference.models.ChatDetails() + 
def run_for_agent(
    self,
    *,
    oracle_llm_content_list: list[OracleLLMContent] = None,
    tools: list[OracleTool] = None,
) -> list[str]:
    """
    Send one user message to the GenAI agent session and return its reply.

    The user message is built from the tool descriptions (first run only) followed by the
    most recent content item whose type is CONTENT_TYPE_STRING; earlier items are not sent.

    Returns:
        A single-element list holding the text of the agent's reply.

    Raises:
        Exception: If the reply contains TOOL_CALL_PREFIX anywhere after its first character.
    """
    contents = [] if oracle_llm_content_list is None else oracle_llm_content_list

    user_message = ""

    # Tool descriptions are only included on the very first run of this instance.
    if self._number_of_runs == 0:
        tool_descriptions = self.get_tool_descriptions(tools)
        if tool_descriptions is not None:
            user_message += tool_descriptions

    # Walk backwards so that the first match is the latest string content item.
    latest_string_item = next(
        (item for item in reversed(contents) if item.content_type == CONTENT_TYPE_STRING),
        None,
    )
    if latest_string_item is not None:
        if user_message:
            user_message += "\n"
        user_message += latest_string_item.content_data

    logger.debug(user_message)

    chat_details = oci.generative_ai_agent_runtime.models.ChatDetails(
        session_id=self._session_id, user_message=user_message, should_stream=False
    )

    logger.debug("Before calling GenAI agent.")

    response = self._generative_ai_agent_runtime_client.chat(
        agent_endpoint_id=self._parameters.agent_endpoint_id, chat_details=chat_details
    )

    logger.debug("After calling GenAI agent.")

    response_message = response.data.message.content.text

    logger.debug(response_message)

    response_messages = [response_message]

    # A tool call prefix is only acceptable at the very start of the reply.
    if (
        TOOL_CALL_PREFIX not in response_message
        or response_message.find(TOOL_CALL_PREFIX, 1) == -1
    ):
        return response_messages

    raise Exception(
        "Unexpectedly received a response message with an embedded "
        + TOOL_CALL_DESCRIPTION
        + "."
    )
class Parameters:
    """
    The parameters class. This class contains all parameter information for the Oracle LLM class.

    Only attribute annotations are declared here; the values are assigned by the owning
    object and checked by its validate_parameters() method. The annotations mirror those
    checks: back_end must be a BackEnd enum member, and every attribute typed "| None" is
    one that validation accepts as None.
    """

    base_url: str

    # Validated with isinstance(..., BackEnd), so this is the enum, not a string.
    back_end: BackEnd

    # Attributes validated when the back end is BackEnd.GEN_AI_LLM.
    compartment_id: str
    authentication_type: AuthenticationType
    authentication_configuration_file_spec: str
    authentication_profile_name: str
    model_type: str
    model_id: str | None
    model_name: str | None
    maximum_number_of_tokens: int | None
    temperature: float | None
    top_p: float | None
    top_k: int | None
    frequency_penalty: float | None
    presence_penalty: float | None
    seed: int | None

    # Attribute validated when the back end is BackEnd.GEN_AI_AGENT.
    agent_endpoint_id: str
def __init__(
    self,
    *,
    base_url: str,  # must be specified
    compartment_id: str,  # must be specified
    authentication_type: AuthenticationType = AuthenticationType.SECURITY_TOKEN,
    authentication_configuration_file_spec: str = "~/.oci/config",
    authentication_profile_name: str = "DEFAULT",
    request_id_prefix: str = "",
    sample_rate: int = 16000,
    language_code: str = "en-US",
    model_domain: str = "GENERIC",
    is_ack_enabled: bool = False,
    partial_silence_threshold_milliseconds: int = 0,
    final_silence_threshold_milliseconds: int = 2000,
    stabilize_partial_results: str = "NONE",
    punctuation: str = "NONE",
    customizations: list[dict] | None = None,
    should_ignore_invalid_customizations: bool = False,
    return_partial_results: bool = False,
) -> None:
    """
    Create a new instance of the OracleSTT class to access Oracle's RTS service. This has no LiveKit dependencies.

    The constructor schedules background work with asyncio.create_task(), so it must be
    called while an asyncio event loop is running.

    Args:
        base_url: Base URL. Type is str. No default (must be specified).
        compartment_id: Compartment ID. Type is str. No default (must be specified).
        authentication_type: Authentication type. Type is AuthenticationType enum (API_KEY, SECURITY_TOKEN, INSTANCE_PRINCIPAL, or RESOURCE_PRINCIPAL). Default is SECURITY_TOKEN.
        authentication_configuration_file_spec: Authentication configuration file spec. Type is str. Default is "~/.oci/config". Applies only for API_KEY or SECURITY_TOKEN.
        authentication_profile_name: Authentication profile name. Type is str. Default is "DEFAULT". Applies only for API_KEY or SECURITY_TOKEN.
        request_id_prefix: Request ID prefix. Type is str. Default is "".
        sample_rate: Sample rate. Type is int. Default is 16000.
        language_code: Language code. Type is str. Default is "en-US".
        model_domain: Model domain. Type is str. Default is "GENERIC".
        is_ack_enabled: Is-ack-enabled flag. Type is bool. Default is False.
        partial_silence_threshold_milliseconds: Partial silence threshold milliseconds. Type is int. Default is 0.
        final_silence_threshold_milliseconds: Final silence threshold milliseconds. Type is int. Default is 2000.
        stabilize_partial_results: Stabilize partial results. Type is str. Default is "NONE". Must be one of "NONE", "LOW", "MEDIUM", or "HIGH".
        punctuation: Punctuation. Type is str. Default is "NONE". Must be one of "NONE", "SPOKEN", or "AUTO".
        customizations: Customizations. Type is list[dict]. Default is None.
        should_ignore_invalid_customizations: Should-ignore-invalid-customizations flag. Type is bool. Default is False.
        return_partial_results: Return-partial-results flag. Type is bool. Default is False.

    Raises:
        TypeError: If validate_parameters() finds a parameter of the wrong type.
        ValueError: If validate_parameters() finds a parameter with an invalid value.
    """

    self._parameters = Parameters()
    self._parameters.base_url = base_url
    self._parameters.compartment_id = compartment_id
    self._parameters.authentication_type = authentication_type
    self._parameters.authentication_configuration_file_spec = (
        authentication_configuration_file_spec
    )
    self._parameters.authentication_profile_name = authentication_profile_name
    self._parameters.request_id_prefix = request_id_prefix
    self._parameters.sample_rate = sample_rate
    self._parameters.language_code = language_code
    self._parameters.model_domain = model_domain
    self._parameters.is_ack_enabled = is_ack_enabled
    self._parameters.partial_silence_threshold_milliseconds = (
        partial_silence_threshold_milliseconds
    )
    self._parameters.final_silence_threshold_milliseconds = final_silence_threshold_milliseconds
    self._parameters.stabilize_partial_results = stabilize_partial_results
    self._parameters.punctuation = punctuation
    self._parameters.customizations = customizations
    self._parameters.should_ignore_invalid_customizations = should_ignore_invalid_customizations
    self._parameters.return_partial_results = return_partial_results

    # Fail fast on bad arguments before any queue, task, or client is created.
    self.validate_parameters()

    self._audio_bytes_queue = asyncio.Queue()
    self._speech_result_queue = asyncio.Queue()

    self._real_time_speech_client = None
    self._connected = False

    # Keep a strong reference to the task: the event loop only holds weak references,
    # so an unreferenced task may be garbage-collected before it finishes.
    self._add_audio_bytes_task = asyncio.create_task(self.add_audio_bytes_background_task())

    self.real_time_speech_client_open()

    logger.debug("Initialized OracleSTT.")
+ ) + + if not isinstance(self._parameters.authentication_profile_name, str): + raise TypeError("The authentication_profile_name parameter must be a string.") + self._parameters.authentication_profile_name = ( + self._parameters.authentication_profile_name.strip() + ) + if len(self._parameters.authentication_profile_name) == 0: + raise ValueError( + "The authentication_profile_name parameter must not be an empty string." + ) + + if not isinstance(self._parameters.request_id_prefix, str): + raise TypeError("The request_id_prefix parameter must be a string.") + + if not isinstance(self._parameters.sample_rate, int): + raise TypeError("The sample_rate parameter must be an integer.") + if self._parameters.sample_rate <= 0: + raise ValueError("The sample_rate parameter must be greater than 0.") + + if not isinstance(self._parameters.language_code, str): + raise TypeError("The language_code parameter must be a string.") + self._parameters.language_code = self._parameters.language_code.strip() + if len(self._parameters.language_code) == 0: + raise ValueError("The language_code parameter must not be an empty string.") + + if not isinstance(self._parameters.model_domain, str): + raise TypeError("The model_domain parameter must be a string.") + self._parameters.model_domain = self._parameters.model_domain.strip() + if len(self._parameters.model_domain) == 0: + raise ValueError("The model_domain parameter must not be an empty string.") + + if not isinstance(self._parameters.is_ack_enabled, bool): + raise TypeError("The is_ack_enabled parameter must be a boolean.") + + if not isinstance(self._parameters.final_silence_threshold_milliseconds, int): + raise TypeError( + "The final_silence_threshold_milliseconds parameter must be an integer." + ) + if self._parameters.final_silence_threshold_milliseconds <= 0: + raise ValueError( + "The final_silence_threshold_milliseconds parameter must be greater than 0." 
+ ) + + if not isinstance(self._parameters.return_partial_results, bool): + raise TypeError("The return_partial_results parameter must be a boolean.") + + if self._parameters.return_partial_results: + if not isinstance(self._parameters.partial_silence_threshold_milliseconds, int): + raise TypeError( + "The partial_silence_threshold_milliseconds parameter must be an integer." + ) + if self._parameters.partial_silence_threshold_milliseconds <= 0: + raise ValueError( + "The partial_silence_threshold_milliseconds parameter must be greater than 0." + ) + else: + self._parameters.partial_silence_threshold_milliseconds = ( + self._parameters.final_silence_threshold_milliseconds + ) + + if not isinstance(self._parameters.stabilize_partial_results, str): + raise TypeError("The stabilize_partial_results parameter must be a string.") + self._parameters.stabilize_partial_results = ( + self._parameters.stabilize_partial_results.strip().upper() + ) + if self._parameters.stabilize_partial_results not in {"NONE", "LOW", "MEDIUM", "HIGH"}: + raise ValueError( + "The stabilize_partial_results parameter must be 'NONE', 'LOW', 'MEDIUM', or 'HIGH'." 
def real_time_speech_client_open(self) -> None:
    """
    (Re)create the realtime speech client and start connecting in the background.

    Any existing client is closed first. A fresh RealtimeParameters object is filled in
    from the validated parameters, a new RealtimeSpeechClient is created with this object
    as its listener, and the connection attempt is scheduled as an asyncio task. Because
    asyncio.create_task() is used, this must be called while an event loop is running.
    """
    self.real_time_speech_client_close()

    config_and_signer = get_config_and_signer(
        authentication_type=self._parameters.authentication_type,
        authentication_configuration_file_spec=self._parameters.authentication_configuration_file_spec,
        authentication_profile_name=self._parameters.authentication_profile_name,
    )
    config = config_and_signer["config"]
    signer = config_and_signer["signer"]

    real_time_parameters = RealtimeParameters()

    real_time_parameters.encoding = "audio/raw;rate=" + str(self._parameters.sample_rate)
    real_time_parameters.language_code = self._parameters.language_code
    real_time_parameters.model_domain = self._parameters.model_domain
    real_time_parameters.is_ack_enabled = self._parameters.is_ack_enabled
    real_time_parameters.partial_silence_threshold_in_ms = (
        self._parameters.partial_silence_threshold_milliseconds
    )
    real_time_parameters.final_silence_threshold_in_ms = (
        self._parameters.final_silence_threshold_milliseconds
    )
    real_time_parameters.stabilize_partial_results = self._parameters.stabilize_partial_results
    real_time_parameters.punctuation = self._parameters.punctuation
    if self._parameters.customizations is not None:
        # The attribute is "customizations"; the former "_customizations" does not exist
        # on the parameters object and raised AttributeError whenever a list was supplied.
        real_time_parameters.customizations = self._parameters.customizations
    # NOTE(review): assumed to be set unconditionally, not only when customizations are
    # supplied — confirm against the intended behavior.
    real_time_parameters.should_ignore_invalid_customizations = (
        self._parameters.should_ignore_invalid_customizations
    )

    real_time_speech_client_listener = self

    compartment_id = self._parameters.compartment_id

    #
    # TODO: The self._parameters.request_id_prefix parameter is never used because there is no clear
    #       way to set the opc_request_id using the RealtimeParameters and the RealtimeSpeechClient
    #       classes. AIServiceSpeechClient.create_realtime_session_token() accepts an opc_request_id,
    #       but it doesn't seem to support the four different ways of authenticating which require
    #       both the "config" and "signer" parameters.
    #

    self._real_time_speech_client = RealtimeSpeechClient(
        config,
        real_time_parameters,
        real_time_speech_client_listener,
        self._parameters.base_url,
        signer,
        compartment_id,
    )

    # Keep a strong reference so the connect task cannot be garbage-collected mid-flight;
    # the event loop itself only holds weak references to tasks.
    self._connect_task = asyncio.create_task(self.connect_background_task())
async def add_audio_bytes_background_task(self) -> None:
    """
    Forward queued audio to the realtime speech client for as long as the task runs.

    Each iteration checks whether a client exists, is not flagged as closed, and has
    reported a connect message. When any of these is not true the loop sleeps for 10 ms
    and checks again; otherwise it waits for the next chunk on the audio queue and sends
    it. The loop has no exit condition of its own.
    """
    while True:
        client = self._real_time_speech_client
        if client is None or client.close_flag or not self._connected:
            # Not ready to stream yet: back off briefly instead of consuming the queue.
            await asyncio.sleep(0.010)
            continue

        # NOTE(review): logging.Logger has no trace() method by default — confirm the
        # plug-in logger provides one.
        logger.trace("Adding audio frame data to RTS SDK.")
        audio_bytes = await self._audio_bytes_queue.get()
        # Re-read the attribute: the client may have been replaced while waiting on the queue.
        await self._real_time_speech_client.send_data(audio_bytes)
class Parameters:
    """
    The parameters class. This class contains all parameter information for the Oracle STT class.

    Only attribute annotations are declared here; OracleSTT.__init__ assigns the values
    and OracleSTT.validate_parameters() checks and normalizes them. Every attribute that
    the constructor assigns is declared, including request_id_prefix.
    """

    base_url: str
    compartment_id: str
    authentication_type: AuthenticationType
    authentication_configuration_file_spec: str
    authentication_profile_name: str
    # Assigned and type-checked by OracleSTT but not yet passed to the service.
    request_id_prefix: str
    sample_rate: int
    language_code: str
    model_domain: str
    is_ack_enabled: bool
    partial_silence_threshold_milliseconds: int
    final_silence_threshold_milliseconds: int
    stabilize_partial_results: str
    punctuation: str
    # None means that no customizations are sent to the service.
    customizations: list[dict] | None
    should_ignore_invalid_customizations: bool
    return_partial_results: bool
This module wraps Oracle's TTS cloud service. While it is used by the Oracle LiveKit TTS plug-in,
it is completely independent of LiveKit and could be used in other environments besides LiveKit.
No default (must be specified). + authentication_type: Authentication type. Type is AuthenticationType enum (API_KEY, SECURITY_TOKEN, INSTANCE_PRINCIPAL, or RESOURCE_PRINCIPAL). Default is SECURITY_TOKEN. + authentication_configuration_file_spec: Authentication configuration file spec. Type is str. Default is "~/.oci/config". Applies only for API_KEY or SECURITY_TOKEN. + authentication_profile_name: Authentication profile name. Type is str. Default is "DEFAULT". Applies only for API_KEY or SECURITY_TOKEN. + request_id_prefix: Request ID prefix. Type is str. Default is "". + voice: Voice. Type is str. Default is "Victoria". + sample_rate: Sample rate. Type is int. Default is 16000. + """ + + self._parameters = Parameters() + self._parameters.base_url = base_url + self._parameters.compartment_id = compartment_id + self._parameters.authentication_type = authentication_type + self._parameters.authentication_configuration_file_spec = ( + authentication_configuration_file_spec + ) + self._parameters.authentication_profile_name = authentication_profile_name + self._parameters.request_id_prefix = request_id_prefix + self._parameters.voice = voice + self._parameters.sample_rate = sample_rate + + self.validate_parameters() + + configAndSigner = get_config_and_signer( + authentication_type=self._parameters.authentication_type, + authentication_configuration_file_spec=self._parameters.authentication_configuration_file_spec, + authentication_profile_name=self._parameters.authentication_profile_name, + ) + config = configAndSigner["config"] + signer = configAndSigner["signer"] + + if signer is None: + self._ai_service_speech_client = AIServiceSpeechClient( + config=config, service_endpoint=self._parameters.base_url + ) + else: + self._ai_service_speech_client = AIServiceSpeechClient( + config=config, service_endpoint=self._parameters.base_url, signer=signer + ) + + logger.debug("Initialized OracleTTS.") + + def validate_parameters(self): + if not 
isinstance(self._parameters.base_url, str): + raise TypeError("The base_url parameter must be a string.") + self._parameters.base_url = self._parameters.base_url.strip() + parsed = urlparse(self._parameters.base_url) + if not all([parsed.scheme, parsed.netloc]): + raise ValueError("The base_url parameter must be a valid URL.") + + if not isinstance(self._parameters.compartment_id, str): + raise TypeError("The compartment_id parameter must be a string.") + self._parameters.compartment_id = self._parameters.compartment_id.strip() + if len(self._parameters.compartment_id) == 0: + raise ValueError("The compartment_id parameter must not be an empty string.") + + if not isinstance(self._parameters.authentication_type, AuthenticationType): + raise TypeError( + "The authentication_type parameter must be one of the AuthenticationType enum members." + ) + + if self._parameters.authentication_type in { + AuthenticationType.API_KEY, + AuthenticationType.SECURITY_TOKEN, + }: + if not isinstance(self._parameters.authentication_configuration_file_spec, str): + raise TypeError( + "The authentication_configuration_file_spec parameter must be a string." + ) + self._parameters.authentication_configuration_file_spec = ( + self._parameters.authentication_configuration_file_spec.strip() + ) + if len(self._parameters.authentication_configuration_file_spec) == 0: + raise ValueError( + "The authentication_configuration_file_spec parameter must not be an empty string." + ) + + if not isinstance(self._parameters.authentication_profile_name, str): + raise TypeError("The authentication_profile_name parameter must be a string.") + self._parameters.authentication_profile_name = ( + self._parameters.authentication_profile_name.strip() + ) + if len(self._parameters.authentication_profile_name) == 0: + raise ValueError( + "The authentication_profile_name parameter must not be an empty string." 
+ ) + + if not isinstance(self._parameters.request_id_prefix, str): + raise TypeError("The request_id_prefix parameter must be a string.") + + if not isinstance(self._parameters.voice, str): + raise TypeError("The voice parameter must be a string.") + self._parameters.voice = self._parameters.voice.strip() + if len(self._parameters.voice) == 0: + raise ValueError("The voice parameter must not be an empty string.") + + if not isinstance(self._parameters.sample_rate, int): + raise TypeError("The sample_rate parameter must be an integer.") + if self._parameters.sample_rate <= 0: + raise ValueError("The sample_rate parameter must be greater than 0.") + + async def synthesize_speech(self, *, text: str) -> bytes: + def sync_call(): + request_id = self._parameters.request_id_prefix + short_uuid() + + logger.debug("Before call to TTS service for: " + text) + + # + # this link may help if ever setting is_stream_enabled = True. this will only noticeably reduce latency + # if multiple sentences are passed into synthesize_speech() at a time. + # + # https://confluence.oraclecorp.com/confluence/pages/viewpage.action?pageId=11517257226 + # + response = self._ai_service_speech_client.synthesize_speech( + synthesize_speech_details=oci.ai_speech.models.SynthesizeSpeechDetails( + text=text, + is_stream_enabled=False, + compartment_id=self._parameters.compartment_id, + configuration=TtsOracleConfiguration( + model_family="ORACLE", + model_details=TtsOracleTts2NaturalModelDetails( + model_name="TTS_2_NATURAL", voice_id=self._parameters.voice + ), + speech_settings=TtsOracleSpeechSettings( + text_type="TEXT", + sample_rate_in_hz=self._parameters.sample_rate, + output_format="PCM", + ), + ), + ), + opc_request_id=request_id, + ) + + logger.debug("After call to TTS service for: " + text) + + if response is None or response.status != 200: + logger.error("Error calling TTS service for: " + text) + return None + + # + # the data is in .wav file format so remove the 44-byte .wav header. 
+ # + audio_bytes = response.data.content[44:] + + return audio_bytes + + return await asyncio.to_thread(sync_call) + + +@staticmethod +def short_uuid() -> str: + uuid4 = uuid.uuid4() + base64EncodedUUID = base64.urlsafe_b64encode(uuid4.bytes) + return base64EncodedUUID.rstrip(b"=").decode("ascii") + + +class Parameters: + """ + The parameters class. This class contains all parameter information for the Oracle TTS class. + """ + + base_url: str + compartment_id: str + authentication_type: AuthenticationType + authentication_configuration_file_spec: str + authentication_profile_name: str + voice: str + sample_rate: int diff --git a/livekit-plugins/livekit-plugins-oracle/livekit/plugins/oracle/py.typed b/livekit-plugins/livekit-plugins-oracle/livekit/plugins/oracle/py.typed new file mode 100644 index 0000000000..e69de29bb2 diff --git a/livekit-plugins/livekit-plugins-oracle/livekit/plugins/oracle/stt.py b/livekit-plugins/livekit-plugins-oracle/livekit/plugins/oracle/stt.py new file mode 100644 index 0000000000..12d013bd00 --- /dev/null +++ b/livekit-plugins/livekit-plugins-oracle/livekit/plugins/oracle/stt.py @@ -0,0 +1,213 @@ +# SPDX-License-Identifier: Apache-2.0 +# Copyright (c) 2025 Oracle Corporation and/or its affiliates. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +This module is the Oracle LiveKit STT plug-in. 
# SPDX-License-Identifier: Apache-2.0
# Copyright (c) 2025 Oracle Corporation and/or its affiliates.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""
This module is the Oracle LiveKit STT plug-in.

Author: Keith Schnable (at Oracle Corporation)
Date: 2025-08-12
"""

from __future__ import annotations

import asyncio
from collections.abc import AsyncIterator

from livekit import rtc
from livekit.agents import stt

from .log import logger
from .oracle_stt import OracleSTT
from .utils import AuthenticationType

REQUIRED_REAL_TIME_SPEECH_SERVICE_AUDIO_RATE = 16000
REQUIRED_REAL_TIME_SPEECH_SERVICE_IS_ACK_ENABLED = False

# Numeric level used for per-frame messages; it sits below logging.DEBUG (10).
# NOTE(review): assumes the package logger is a standard logging.Logger, which has no
# .trace() method -- confirm against log.py.
_TRACE_LOG_LEVEL = 5


class STT(stt.STT):
    """
    The Oracle LiveKit STT plug-in class. This derives from livekit.agents.stt.STT.
    """

    def __init__(
        self,
        *,
        base_url: str,  # must be specified
        compartment_id: str,  # must be specified
        authentication_type: AuthenticationType = AuthenticationType.SECURITY_TOKEN,
        authentication_configuration_file_spec: str = "~/.oci/config",
        authentication_profile_name: str = "DEFAULT",
        language_code: str = "en-US",
        model_domain: str = "GENERIC",
        partial_silence_threshold_milliseconds: int = 0,
        final_silence_threshold_milliseconds: int = 2000,
        stabilize_partial_results: str = "NONE",
        punctuation: str = "NONE",
        customizations: list[dict] | None = None,
        should_ignore_invalid_customizations: bool = False,
        return_partial_results: bool = False,
    ) -> None:
        """
        Create a new instance of the STT class to access Oracle's RTS service. This has LiveKit dependencies.

        Args:
            base_url: Base URL. Type is str. No default (must be specified).
            compartment_id: Compartment ID. Type is str. No default (must be specified).
            authentication_type: Authentication type. Type is AuthenticationType enum (API_KEY, SECURITY_TOKEN, INSTANCE_PRINCIPAL, or RESOURCE_PRINCIPAL). Default is SECURITY_TOKEN.
            authentication_configuration_file_spec: Authentication configuration file spec. Type is str. Default is "~/.oci/config". Applies only for API_KEY or SECURITY_TOKEN.
            authentication_profile_name: Authentication profile name. Type is str. Default is "DEFAULT". Applies only for API_KEY or SECURITY_TOKEN.
            language_code: Language code. Type is str. Default is "en-US".
            model_domain: Model domain. Type is str. Default is "GENERIC".
            partial_silence_threshold_milliseconds: Partial silence threshold milliseconds. Type is int. Default is 0.
            final_silence_threshold_milliseconds: Final silence threshold milliseconds. Type is int. Default is 2000.
            stabilize_partial_results: Stabilize partial results. Type is str. Default is "NONE". Must be one of "NONE", "LOW", "MEDIUM", or "HIGH".
            punctuation: Punctuation. Type is str. Default is "NONE". Must be one of "NONE", "SPOKEN", or "AUTO".
            customizations: Customizations. Type is list[dict]. Default is None.
            should_ignore_invalid_customizations: Should-ignore-invalid-customizations flag. Type is bool. Default is False.
            return_partial_results: Return-partial-results flag. Type is bool. Default is False.
        """

        # Interim results are only advertised when partial results were requested.
        capabilities = stt.STTCapabilities(streaming=True, interim_results=return_partial_results)
        super().__init__(capabilities=capabilities)

        self._oracle_stt = OracleSTT(
            base_url=base_url,
            compartment_id=compartment_id,
            authentication_type=authentication_type,
            authentication_configuration_file_spec=authentication_configuration_file_spec,
            authentication_profile_name=authentication_profile_name,
            request_id_prefix="live-kit-stt-plug-in-",
            sample_rate=REQUIRED_REAL_TIME_SPEECH_SERVICE_AUDIO_RATE,
            language_code=language_code,
            model_domain=model_domain,
            is_ack_enabled=REQUIRED_REAL_TIME_SPEECH_SERVICE_IS_ACK_ENABLED,
            partial_silence_threshold_milliseconds=partial_silence_threshold_milliseconds,
            final_silence_threshold_milliseconds=final_silence_threshold_milliseconds,
            stabilize_partial_results=stabilize_partial_results,
            punctuation=punctuation,
            customizations=customizations,
            should_ignore_invalid_customizations=should_ignore_invalid_customizations,
            return_partial_results=return_partial_results,
        )

        logger.debug("Initialized STT.")

    async def get_speech_event(self) -> stt.SpeechEvent | None:
        """
        Return the next queued speech result as a LiveKit speech event.

        Returns:
            A FINAL_TRANSCRIPT or INTERIM_TRANSCRIPT event built from the oldest queued
            result, or None when the queue is currently empty.
        """
        speech_result_queue = self._oracle_stt.get_speech_result_queue()

        if speech_result_queue.empty():
            return None

        speech_result = await speech_result_queue.get()

        speech_data = stt.SpeechData(
            language="multi",  # this must be "multi" or 4-second delays seem to occur before any tts occurs.
            text=speech_result.text,
        )

        if speech_result.is_final:
            event_type = stt.SpeechEventType.FINAL_TRANSCRIPT
        else:
            event_type = stt.SpeechEventType.INTERIM_TRANSCRIPT
        speech_event = stt.SpeechEvent(type=event_type, alternatives=[speech_data])

        logger.debug(
            "Returning %s speech result to LiveKit: %s",
            "final" if speech_result.is_final else "partial",
            speech_result.text,
        )

        return speech_event

    # STT method.
    def stream(self, *, language=None, conn_options=None) -> SpeechStream:
        """Create a new speech stream bound to this plug-in; both arguments are unused."""
        return SpeechStream(self)

    # STT method.
    async def _recognize_impl(
        self,
        audio_buffer,
        *,
        language=None,
        conn_options=None,
    ) -> stt.SpeechEvent:
        """
        Return a fixed placeholder final transcript; the audio buffer is not inspected.

        NOTE(review): the "zz" text looks like a stub for non-streaming recognition --
        confirm whether this should raise instead.
        """
        speech_data = stt.SpeechData(language="multi", text="zz")

        return stt.SpeechEvent(
            type=stt.SpeechEventType.FINAL_TRANSCRIPT, alternatives=[speech_data]
        )

    # STT method.
    def on_start(self, participant_id: str, room_id: str):
        """No-op hook."""
        pass

    # STT method.
    def on_stop(self):
        """No-op hook."""
        pass


class SpeechStream:
    """
    The STT stream class.

    Audio frames pushed with push_frame() are forwarded to the Oracle STT wrapper, and the
    resulting speech events are produced by asynchronous iteration over the stream.
    """

    def __init__(self, oracle_stt_livekit_plugin: STT):
        self._running = True
        self._queue = asyncio.Queue()

        self._oracle_stt_livekit_plugin = oracle_stt_livekit_plugin

        # Created lazily, the first time a frame with a different sample rate arrives.
        self._audio_resampler = None

    async def __aenter__(self):
        return self

    async def __aexit__(self, exc_type, exc_val, exc_tb):
        # Lets the event loop in _event_stream() finish after its current iteration.
        self._running = False

    def __aiter__(self) -> AsyncIterator[stt.SpeechEvent]:
        return self._event_stream()

    def push_frame(self, frame: rtc.AudioFrame):
        """Queue one audio frame for processing by the event stream."""
        self._queue.put_nowait(frame)

    async def _event_stream(self) -> AsyncIterator[stt.SpeechEvent]:
        """
        Consume queued frames, feed them to the Oracle STT wrapper, and yield speech events.

        Frames whose sample rate differs from the required rate are resampled first. After
        each frame, every speech event that is already available is yielded.
        """
        plugin = self._oracle_stt_livekit_plugin

        while self._running:
            frame = await self._queue.get()

            logger.log(_TRACE_LOG_LEVEL, "Received audio frame data from LiveKit.")

            if frame.sample_rate != REQUIRED_REAL_TIME_SPEECH_SERVICE_AUDIO_RATE:
                if self._audio_resampler is None:
                    self._audio_resampler = rtc.AudioResampler(
                        input_rate=frame.sample_rate,
                        output_rate=REQUIRED_REAL_TIME_SPEECH_SERVICE_AUDIO_RATE,
                        quality=rtc.AudioResamplerQuality.HIGH,
                    )
                frame = self._audio_resampler.push(frame)

            # The resampler result may be a list of frames; normalize to a list either way.
            frames = frame if isinstance(frame, list) else [frame]

            for audio_frame in frames:
                plugin._oracle_stt.add_audio_bytes(audio_frame.data)

            # Drain every speech event that is ready before waiting for the next frame.
            while True:
                speech_event = await plugin.get_speech_event()
                if speech_event is None:
                    break
                yield speech_event
# SPDX-License-Identifier: Apache-2.0
# Copyright (c) 2025 Oracle Corporation and/or its affiliates.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""
This module is the Oracle LiveKit TTS plug-in.

Author: Keith Schnable (at Oracle Corporation)
Date: 2025-08-12
"""

from __future__ import annotations

from livekit.agents import tts, utils
from livekit.agents.types import DEFAULT_API_CONNECT_OPTIONS, APIConnectOptions

from .audio_cache import AudioCache
from .log import logger
from .oracle_tts import OracleTTS
from .utils import AuthenticationType

REQUIRED_LIVE_KIT_AUDIO_RATE = 24000
REQUIRED_LIVE_KIT_AUDIO_CHANNELS = 1
REQUIRED_LIVE_KIT_AUDIO_BITS = 16


class TTS(tts.TTS):
    """
    The Oracle LiveKit TTS plug-in class. This derives from livekit.agents.tts.TTS.
    """

    def __init__(
        self,
        *,
        base_url: str,  # must be specified
        compartment_id: str,  # must be specified
        authentication_type: AuthenticationType = AuthenticationType.SECURITY_TOKEN,
        authentication_configuration_file_spec: str = "~/.oci/config",
        authentication_profile_name: str = "DEFAULT",
        voice: str = "Victoria",
        audio_cache_file_path: str | None = None,
        audio_cache_maximum_utterance_length: int = 100,
        audio_cache_maximum_number_of_utterances: int = 100,
    ) -> None:
        """
        Create a new instance of the TTS class to access Oracle's TTS service. This has LiveKit dependencies.

        Args:
            base_url: Base URL. Type is str. No default (must be specified).
            compartment_id: Compartment ID. Type is str. No default (must be specified).
            authentication_type: Authentication type. Type is AuthenticationType enum (API_KEY, SECURITY_TOKEN, INSTANCE_PRINCIPAL, or RESOURCE_PRINCIPAL). Default is SECURITY_TOKEN.
            authentication_configuration_file_spec: Authentication configuration file spec. Type is str. Default is "~/.oci/config". Applies only for API_KEY or SECURITY_TOKEN.
            authentication_profile_name: Authentication profile name. Type is str. Default is "DEFAULT". Applies only for API_KEY or SECURITY_TOKEN.
            voice: Voice. Type is str. Default is "Victoria".
            audio_cache_file_path: Audio cache file path. Type is str. Default is None.
            audio_cache_maximum_utterance_length: Audio cache maximum utterance length. Type is int. Default is 100.
            audio_cache_maximum_number_of_utterances: Audio cache maximum number of utterances. Type is int. Default is 100.

        Raises:
            TypeError: If an audio cache parameter has the wrong type.
            ValueError: If an audio cache parameter is empty or not greater than 0.
        """

        capabilities = tts.TTSCapabilities(streaming=False)

        super().__init__(
            capabilities=capabilities,
            sample_rate=REQUIRED_LIVE_KIT_AUDIO_RATE,
            num_channels=REQUIRED_LIVE_KIT_AUDIO_CHANNELS,
        )

        self._oracle_tts = OracleTTS(
            base_url=base_url,
            compartment_id=compartment_id,
            authentication_type=authentication_type,
            authentication_configuration_file_spec=authentication_configuration_file_spec,
            authentication_profile_name=authentication_profile_name,
            request_id_prefix="live-kit-tts-plug-in-",
            voice=voice,
            sample_rate=REQUIRED_LIVE_KIT_AUDIO_RATE,
        )

        # The cache settings are only validated when a cache file path was supplied.
        if audio_cache_file_path is not None:
            if not isinstance(audio_cache_file_path, str):
                raise TypeError("The audio_cache_file_path parameter must be a string.")
            audio_cache_file_path = audio_cache_file_path.strip()
            if len(audio_cache_file_path) == 0:
                raise ValueError("The audio_cache_file_path parameter must not be an empty string.")

            if not isinstance(audio_cache_maximum_utterance_length, int):
                raise TypeError(
                    "The audio_cache_maximum_utterance_length parameter must be an integer."
                )
            if audio_cache_maximum_utterance_length <= 0:
                raise ValueError(
                    "The audio_cache_maximum_utterance_length parameter must be greater than 0."
                )

            if not isinstance(audio_cache_maximum_number_of_utterances, int):
                raise TypeError(
                    "The audio_cache_maximum_number_of_utterances parameter must be an integer."
                )
            if audio_cache_maximum_number_of_utterances <= 0:
                raise ValueError(
                    "The audio_cache_maximum_number_of_utterances parameter must be greater than 0."
                )

        if audio_cache_file_path is None:
            self._audio_cache = None
        else:
            self._audio_cache = AudioCache(
                audio_cache_file_path=audio_cache_file_path,
                audio_cache_maximum_number_of_utterances=audio_cache_maximum_number_of_utterances,
            )
            self._voice = voice
            self._audio_cache_maximum_utterance_length = audio_cache_maximum_utterance_length

        logger.debug("Initialized TTS.")

    def synthesize(
        self, text: str, *, conn_options: APIConnectOptions = DEFAULT_API_CONNECT_OPTIONS
    ) -> ChunkedStream:
        """
        Create a chunked stream that synthesizes the given text.

        Args:
            text: The text to synthesize.
            conn_options: Connection options passed through to the stream. Defaults to
                DEFAULT_API_CONNECT_OPTIONS so the argument may be omitted.
        """
        return ChunkedStream(tts=self, text=text, conn_options=conn_options)


class ChunkedStream(tts.ChunkedStream):
    """
    The TTS chunked stream class. This derives from livekit.agents.tts.ChunkedStream.
    """

    def __init__(
        self,
        *,
        tts: tts.TTS,
        text: str,
        conn_options: APIConnectOptions = DEFAULT_API_CONNECT_OPTIONS,
    ) -> None:
        super().__init__(tts=tts, input_text=text, conn_options=conn_options)

        self._oracle_tts_livekit_plugin = tts

    async def _run(self, audio_emitter: tts.AudioEmitter) -> None:
        """
        Produce audio for the input text and push it to the emitter.

        The audio cache (when configured) is consulted first; on a miss the Oracle TTS
        service is called. Freshly synthesized audio is stored in the cache when the text
        is no longer than the configured maximum utterance length. Nothing is emitted when
        no audio bytes could be obtained.
        """
        plugin = self._oracle_tts_livekit_plugin
        cache = plugin._audio_cache

        logger.debug("Received text from LiveKit for TTS: %s", self._input_text)

        if cache is None:
            cache_key = None
            audio_bytes = None
        else:
            # The same key identifies the utterance for both the cache lookup and the store.
            cache_key = {
                "text": self._input_text,
                "voice": plugin._voice,
                "audio_rate": REQUIRED_LIVE_KIT_AUDIO_RATE,
                "audio_channels": REQUIRED_LIVE_KIT_AUDIO_CHANNELS,
                "audio_bits": REQUIRED_LIVE_KIT_AUDIO_BITS,
            }
            audio_bytes = cache.get_audio_bytes(**cache_key)

        logger.debug("TTS is" + (" not" if audio_bytes is None else "") + " cached.")

        audio_bytes_from_cache = audio_bytes is not None

        if not audio_bytes_from_cache:
            logger.debug("Before getting TTS audio bytes.")

            audio_bytes = await plugin._oracle_tts.synthesize_speech(text=self._input_text)

            logger.debug("After getting TTS audio bytes.")

        if audio_bytes is None:
            # Synthesis produced nothing, so there is nothing to emit or cache.
            return

        audio_emitter.initialize(
            request_id=utils.shortuuid(),
            sample_rate=REQUIRED_LIVE_KIT_AUDIO_RATE,
            num_channels=REQUIRED_LIVE_KIT_AUDIO_CHANNELS,
            mime_type="audio/pcm",
        )

        audio_emitter.push(audio_bytes)
        audio_emitter.flush()

        if (
            not audio_bytes_from_cache
            and cache is not None
            and len(self._input_text) <= plugin._audio_cache_maximum_utterance_length
        ):
            cache.set_audio_bytes(**cache_key, audio_bytes=audio_bytes)
# SPDX-License-Identifier: Apache-2.0
# Copyright (c) 2025 Oracle Corporation and/or its affiliates.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""
This module provides utilities used throughout the Oracle LiveKit plug-in code.

Author: Keith Schnable (at Oracle Corporation)
Date: 2025-08-12
"""

from __future__ import annotations

from enum import Enum

import oci


class AuthenticationType(Enum):
    """Authentication types as enumerator."""

    API_KEY = "API_KEY"
    SECURITY_TOKEN = "SECURITY_TOKEN"
    INSTANCE_PRINCIPAL = "INSTANCE_PRINCIPAL"
    RESOURCE_PRINCIPAL = "RESOURCE_PRINCIPAL"


def get_config_and_signer(
    *,
    authentication_type: AuthenticationType | None = None,
    authentication_configuration_file_spec: str | None = None,
    authentication_profile_name: str | None = None,
) -> dict:
    """
    Build the OCI config and signer for the given authentication type.

    Args:
        authentication_type: One of the AuthenticationType members.
        authentication_configuration_file_spec: Path of the OCI configuration file. Used
            only for API_KEY and SECURITY_TOKEN.
        authentication_profile_name: Profile name inside the configuration file. Used only
            for API_KEY and SECURITY_TOKEN.

    Returns:
        A dict with the keys "config" and "signer". For API_KEY the signer is None; for
        INSTANCE_PRINCIPAL and RESOURCE_PRINCIPAL the config is an empty dict. Any other
        value (including None) yields an empty config and a None signer.
    """
    config = {}
    signer = None

    # API_KEY
    if authentication_type == AuthenticationType.API_KEY:
        config = oci.config.from_file(
            authentication_configuration_file_spec, authentication_profile_name
        )

    # SECURITY_TOKEN
    elif authentication_type == AuthenticationType.SECURITY_TOKEN:
        config = oci.config.from_file(
            authentication_configuration_file_spec, authentication_profile_name
        )
        with open(config["security_token_file"], encoding="utf-8") as f:
            # Strip the line terminator so it does not end up inside the token value.
            token = f.readline().strip()
        private_key = oci.signer.load_private_key_from_file(config["key_file"])
        signer = oci.auth.signers.SecurityTokenSigner(token=token, private_key=private_key)

    # INSTANCE_PRINCIPAL
    elif authentication_type == AuthenticationType.INSTANCE_PRINCIPAL:
        signer = oci.auth.signers.InstancePrincipalsSecurityTokenSigner()

    # RESOURCE_PRINCIPAL
    elif authentication_type == AuthenticationType.RESOURCE_PRINCIPAL:
        signer = oci.auth.signers.get_resource_principals_signer()

    return {"config": config, "signer": signer}
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +__version__ = "1.2.11" diff --git a/livekit-plugins/livekit-plugins-oracle/pyproject.toml b/livekit-plugins/livekit-plugins-oracle/pyproject.toml new file mode 100644 index 0000000000..2bebf44edf --- /dev/null +++ b/livekit-plugins/livekit-plugins-oracle/pyproject.toml @@ -0,0 +1,42 @@ +[build-system] +requires = ["hatchling"] +build-backend = "hatchling.build" + +[project] +name = "livekit-plugins-oracle" +dynamic = ["version"] +description = "LiveKit Agents Plugin for Oracle" +readme = "README.md" +license = "Apache-2.0" +requires-python = ">=3.9.0" +authors = [{ name = "LiveKit" }] +keywords = ["webrtc", "realtime", "audio", "video", "livekit"] +classifiers = [ + "Intended Audience :: Developers", + "License :: OSI Approved :: Apache Software License", + "Topic :: Multimedia :: Sound/Audio", + "Topic :: Multimedia :: Video", + "Topic :: Scientific/Engineering :: Artificial Intelligence", + "Programming Language :: Python :: 3", + "Programming Language :: Python :: 3.9", + "Programming Language :: Python :: 3.10", + "Programming Language :: Python :: 3 :: Only", +] +dependencies = [ + "livekit-agents ~= 1.2", + "oci-ai-speech-realtime ~= 2.1", +] + +[project.urls] +Documentation = "https://docs.livekit.io" +Website = "https://livekit.io/" +Source = "https://github.com/livekit/agents" + +[tool.hatch.version] +path = "livekit/plugins/oracle/version.py" + +[tool.hatch.build.targets.wheel] +packages = ["livekit"] + 
+[tool.hatch.build.targets.sdist] +include = ["/livekit"]