diff --git a/docassemble/ALToolbox/llms.py b/docassemble/ALToolbox/llms.py
index f1bc938..c2a7065 100644
--- a/docassemble/ALToolbox/llms.py
+++ b/docassemble/ALToolbox/llms.py
@@ -27,14 +27,44 @@
     "IntakeQuestionList",
 ]
 
-if os.getenv("OPENAI_API_KEY"):
-    client: Optional[OpenAI] = OpenAI()
-elif get_config("open ai"):
-    api_key = get_config("open ai", {}).get("key")
-    client = OpenAI(api_key=api_key)
+
+def _get_openai_api_key() -> Optional[str]:
+    """
+    Get the OpenAI API key from configuration or environment in priority order:
+    1. openai api key (direct key configuration)
+    2. openai: key (standardized nested configuration)
+    3. open ai: key (legacy nested configuration with space)
+    4. OPENAI_API_KEY environment variable
+
+    Returns:
+        The API key if found, None otherwise
+    """
+    # Priority 1: Direct key configuration "openai api key: sk-..."
+    direct_key = get_config("openai api key")
+    if direct_key:
+        return direct_key
+
+    # Priority 2: Standardized nested configuration "openai: key: sk-..."
+    openai_config = get_config("openai", {})
+    if isinstance(openai_config, dict) and openai_config.get("key"):
+        return openai_config.get("key")
+
+    # Priority 3: Legacy nested configuration "open ai: key: sk-..."
+    legacy_config = get_config("open ai", {})
+    if isinstance(legacy_config, dict) and legacy_config.get("key"):
+        return legacy_config.get("key")
+
+    # Priority 4: Environment variable
+    return os.getenv("OPENAI_API_KEY")
+
+
+api_key = _get_openai_api_key()
+if api_key:
+    client: Optional[OpenAI] = OpenAI(api_key=api_key)
 else:
     client = None
 
+
 always_reserved_names = set(
     docassemble.base.util.__all__
     + keyword.kwlist
@@ -115,7 +145,8 @@ def chat_completion(
     system_message: Optional[str] = None,
     user_message: Optional[str] = None,
     openai_client: Optional[OpenAI] = None,
-    openai_api: Optional[str] = None,
+    openai_api_key: Optional[str] = None,
+    openai_api: Optional[str] = None,  # Kept for backwards compatibility
     temperature: float = 0.5,
     json_mode=False,
     model: str = "gpt-4o",
@@ -133,7 +164,8 @@
         system_message (str): The role the chat engine should play
         user_message (str): The message (data) from the user
         openai_client (Optional[OpenAI]): An OpenAI client object, optional. If omitted, will fall back to creating a new OpenAI client with the API key provided as an environment variable
-        openai_api (Optional[str]): the API key for an OpenAI client, optional. If provided, a new OpenAI client will be created.
+        openai_api_key (Optional[str]): the API key for an OpenAI client, optional. If provided, a new OpenAI client will be created.
+        openai_api (Optional[str]): the API key for an OpenAI client, optional. Deprecated, use openai_api_key instead. Kept for backwards compatibility.
         temperature (float): The temperature to use for the GPT API
         json_mode (bool): Whether to use JSON mode for the GPT API. Requires the word `json` in the system message, but will add if you omit it.
         model (str): The model to use for the GPT API
@@ -147,12 +179,20 @@ def chat_completion(
         A string with the response from the API endpoint or JSON data if json_mode is True
     """
     if not openai_base_url:
-        openai_base_url = (
-            get_config("open ai", {}).get("base url") or "https://api.openai.com/v1/"
-        )
+        # Check both new and legacy config formats for base URL
+        openai_config = get_config("openai", {})
+        if isinstance(openai_config, dict) and openai_config.get("base url"):
+            openai_base_url = openai_config.get("base url")
+        else:
+            openai_base_url = (
+                get_config("open ai", {}).get("base url")
+                or "https://api.openai.com/v1/"
+            )
 
-    if not openai_api:
-        openai_api = get_config("open ai", {}).get("key") or os.getenv("OPENAI_API_KEY")
+    # Handle backwards compatibility: prefer openai_api_key, fall back to openai_api, then config
+    api_key = openai_api_key or openai_api
+    if not api_key:
+        api_key = _get_openai_api_key()
 
     if not messages and not system_message:
         raise Exception(
@@ -192,12 +232,12 @@
     if openai_base_url:
         openai_client = None  # Always override client in this circumstance
     openai_client = (
-        openai_client or OpenAI(base_url=openai_base_url, api_key=openai_api) or client
+        openai_client or OpenAI(base_url=openai_base_url, api_key=api_key) or client
     )
 
     if not openai_client:
         raise Exception(
-            "You need to pass an OpenAI client or API key to use this function, or the API key needs to be set in the environment or Docassemble configuration. Try adding a new section in your global config that looks like this:\n\nopen ai:\n  key: sk-..."
+            "You need to pass an OpenAI client or API key to use this function, or the API key needs to be set in the environment or Docassemble configuration. Try adding a new section in your global config that looks like this:\n\nopenai api key: sk-..."
         )
 
     try:
@@ -260,7 +300,8 @@ def extract_fields_from_text(
     text: str,
     field_list: Dict[str, str],
     openai_client: Optional[OpenAI] = None,
-    openai_api: Optional[str] = None,
+    openai_api_key: Optional[str] = None,
+    openai_api: Optional[str] = None,  # Kept for backwards compatibility
     temperature: float = 0,
     model="gpt-4o-mini",
 ) -> Dict[str, Any]:
@@ -270,7 +311,8 @@ def extract_fields_from_text(
         text (str): The text to extract fields from
         field_list (Dict[str,str]): A list of fields to extract, with the key being the field name and the value being a description of the field
         openai_client (Optional[OpenAI]): An OpenAI client object. Defaults to None.
-        openai_api (Optional[str]): An OpenAI API key. Defaults to None.
+        openai_api_key (Optional[str]): An OpenAI API key. Defaults to None.
+        openai_api (Optional[str]): An OpenAI API key. Deprecated, use openai_api_key instead. Kept for backwards compatibility.
         temperature (float): The temperature to use for the OpenAI API. Defaults to 0.
         model (str): The model to use for the OpenAI API. Defaults to "gpt-4o-mini".
 
@@ -292,7 +334,7 @@ def extract_fields_from_text(
         user_message=text,
         model=model,
         openai_client=openai_client,
-        openai_api=openai_api,
+        openai_api_key=openai_api_key or openai_api,
         temperature=temperature,
         json_mode=True,
     )
@@ -305,7 +347,8 @@ def match_goals_from_text(
     user_response: str,
     goals: Dict[str, str],
     openai_client: Optional[OpenAI] = None,
-    openai_api: Optional[str] = None,
+    openai_api_key: Optional[str] = None,
+    openai_api: Optional[str] = None,  # Kept for backwards compatibility
     temperature: float = 0,
     model="gpt-4o-mini",
 ) -> Dict[str, Any]:
@@ -316,7 +359,8 @@ def match_goals_from_text(
         user_response (str): The user's response to the question
         goals (Dict[str,str]): A list of goals to extract, with the key being the goal name and the value being a description of the goal
         openai_client (Optional[OpenAI]): An OpenAI client object. Defaults to None.
-        openai_api (Optional[str]): An OpenAI API key. Defaults to None.
+        openai_api_key (Optional[str]): An OpenAI API key. Defaults to None.
+        openai_api (Optional[str]): An OpenAI API key. Deprecated, use openai_api_key instead. Kept for backwards compatibility.
         temperature (float): The temperature to use for the OpenAI API. Defaults to 0.
         model (str): The model to use for the OpenAI API. Defaults to "gpt-4o-mini".
 
@@ -352,7 +396,7 @@ def match_goals_from_text(
         user_message=user_response,
         model=model,
         openai_client=openai_client,
-        openai_api=openai_api,
+        openai_api_key=openai_api_key or openai_api,
         temperature=temperature,
         json_mode=True,
     )
@@ -365,7 +409,8 @@ def classify_text(
     choices: Dict[str, str],
     default_response: str = "null",
     openai_client: Optional[OpenAI] = None,
-    openai_api: Optional[str] = None,
+    openai_api_key: Optional[str] = None,
+    openai_api: Optional[str] = None,  # Kept for backwards compatibility
     temperature: float = 0,
     model="gpt-4o-mini",
 ) -> str:
@@ -376,7 +421,8 @@ def classify_text(
         choices (Dict[str,str]): A list of choices to classify the text into, with the key being the choice name and the value being a description of the choice
         default_response (str): The default response to return if the text cannot be classified. Defaults to "null".
         openai_client (Optional[OpenAI]): An OpenAI client object, optional. If omitted, will fall back to creating a new OpenAI client with the API key provided as an environment variable
-        openai_api (Optional[str]): the API key for an OpenAI client, optional. If provided, a new OpenAI client will be created.
+        openai_api_key (Optional[str]): the API key for an OpenAI client, optional. If provided, a new OpenAI client will be created.
+        openai_api (Optional[str]): the API key for an OpenAI client, optional. Deprecated, use openai_api_key instead. Kept for backwards compatibility.
         temperature (float): The temperature to use for GPT. Defaults to 0.
         model (str): The model to use for the GPT API
 
@@ -395,7 +441,7 @@ def classify_text(
         user_message=text,
         model=model,
         openai_client=openai_client,
-        openai_api=openai_api,
+        openai_api_key=openai_api_key or openai_api,
         temperature=temperature,
         json_mode=False,
     )
@@ -407,7 +453,8 @@ def synthesize_user_responses(
     messages: List[Dict[str, str]],
     custom_instructions: Optional[str] = "",
    openai_client: Optional[OpenAI] = None,
-    openai_api: Optional[str] = None,
+    openai_api_key: Optional[str] = None,
+    openai_api: Optional[str] = None,  # Kept for backwards compatibility
     temperature: float = 0,
     model: str = "gpt-4o-mini",
 ) -> str:
@@ -418,7 +465,8 @@ def synthesize_user_responses(
         messages (List[Dict[str, str]]): A list of questions from the LLM and responses from the user
         custom_instructions (str): Custom instructions for the LLM to follow in constructing the synthesized response
         openai_client (Optional[OpenAI]): An OpenAI client object, optional. If omitted, will fall back to creating a new OpenAI client with the API key provided as an environment variable
-        openai_api (Optional[str]): the API key for an OpenAI client, optional. If provided, a new OpenAI client will be created.
+        openai_api_key (Optional[str]): the API key for an OpenAI client, optional. If provided, a new OpenAI client will be created.
+        openai_api (Optional[str]): the API key for an OpenAI client, optional. Deprecated, use openai_api_key instead. Kept for backwards compatibility.
         temperature (float): The temperature to use for GPT. Defaults to 0.
         model (str): The model to use for the GPT API
 
@@ -449,7 +497,7 @@ def synthesize_user_responses(
         ],
         model=model,
         openai_client=openai_client,
-        openai_api=openai_api,
+        openai_api_key=openai_api_key or openai_api,
         temperature=temperature,
         json_mode=False,
     )
@@ -629,7 +677,7 @@ class GoalSatisfactionList(DAList):
     By default, this will use the OpenAI API key defined in the global configuration under this path:
 
     ```
-    open ai:
+    openai:
       key: sk-...
     ```
 
diff --git a/docassemble/ALToolbox/test_llms.py b/docassemble/ALToolbox/test_llms.py
new file mode 100644
index 0000000..4735e60
--- /dev/null
+++ b/docassemble/ALToolbox/test_llms.py
@@ -0,0 +1,148 @@
+# do not pre-load
+
+import unittest
+from unittest.mock import patch, MagicMock
+import os
+from .llms import _get_openai_api_key, chat_completion
+
+
+class TestLLMSConfig(unittest.TestCase):
+    """Test configuration lookup and parameter handling for LLM functions."""
+
+    @patch("docassemble.ALToolbox.llms.get_config")
+    @patch.dict(os.environ, {}, clear=True)
+    def test_get_openai_api_key_priority_order(self, mock_get_config):
+        """Test that _get_openai_api_key follows the correct priority order."""
+
+        # Test priority 1: direct key configuration "openai api key"
+        mock_get_config.side_effect = lambda key, default=None: {
+            "openai api key": "direct-key"
+        }.get(key, default)
+
+        result = _get_openai_api_key()
+        self.assertEqual(result, "direct-key")
+
+        # Test priority 2: standardized nested configuration "openai: key"
+        mock_get_config.side_effect = lambda key, default=None: {
+            "openai api key": None,
+            "openai": {"key": "standardized-key"},
+        }.get(key, default)
+
+        result = _get_openai_api_key()
+        self.assertEqual(result, "standardized-key")
+
+        # Test priority 3: legacy nested configuration "open ai: key"
+        mock_get_config.side_effect = lambda key, default=None: {
+            "openai api key": None,
+            "openai": {},
+            "open ai": {"key": "legacy-key"},
+        }.get(key, default)
+
+        result = _get_openai_api_key()
+        self.assertEqual(result, "legacy-key")
+
+    @patch("docassemble.ALToolbox.llms.get_config")
+    @patch.dict(os.environ, {"OPENAI_API_KEY": "env-key"}, clear=True)
+    def test_get_openai_api_key_environment_fallback(self, mock_get_config):
+        """Test that the environment variable is used as a last resort."""
+
+        # No config keys found, should fall back to environment
+        mock_get_config.side_effect = lambda key, default=None: default
+
+        result = _get_openai_api_key()
+        self.assertEqual(result, "env-key")
+
+    @patch("docassemble.ALToolbox.llms.get_config")
+    @patch.dict(os.environ, {}, clear=True)
+    def test_get_openai_api_key_no_key_found(self, mock_get_config):
+        """Test that None is returned when no key is found."""
+
+        mock_get_config.side_effect = lambda key, default=None: default
+
+        result = _get_openai_api_key()
+        self.assertIsNone(result)
+
+    @patch("docassemble.ALToolbox.llms.get_config")
+    def test_get_openai_api_key_handles_non_dict_config(self, mock_get_config):
+        """Test that non-dict config values are handled gracefully."""
+
+        # Test when config returns a string instead of dict
+        mock_get_config.side_effect = lambda key, default=None: {
+            "openai api key": None,
+            "openai": "not-a-dict",  # This should be handled gracefully
+            "open ai": {"key": "fallback-key"},
+        }.get(key, default)
+
+        result = _get_openai_api_key()
+        self.assertEqual(result, "fallback-key")
+
+    @patch("docassemble.ALToolbox.llms.OpenAI")
+    @patch("docassemble.ALToolbox.llms._get_openai_api_key")
+    def test_chat_completion_parameter_backwards_compatibility(
+        self, mock_get_key, mock_openai
+    ):
+        """Test that both openai_api and openai_api_key parameters work."""
+
+        # Mock the OpenAI client
+        mock_client = MagicMock()
+        mock_openai.return_value = mock_client
+
+        # Mock successful response
+        mock_response = MagicMock()
+        mock_response.choices = [MagicMock()]
+        mock_response.choices[0].finish_reason = "stop"
+        mock_response.choices[0].message.content = "test response"
+        mock_client.chat.completions.create.return_value = mock_response
+
+        # Mock moderation
+        mock_moderation = MagicMock()
+        mock_moderation.results = [MagicMock()]
+        mock_moderation.results[0].flagged = False
+        mock_client.moderations.create.return_value = mock_moderation
+
+        mock_get_key.return_value = "config-key"
+
+        # Test new parameter name
+        result = chat_completion(
+            system_message="test", user_message="test", openai_api_key="new-param-key"
+        )
+
+        # Should use new parameter value
+        mock_openai.assert_called_with(
+            base_url="https://api.openai.com/v1/", api_key="new-param-key"
+        )
+
+        # Test old parameter name
+        result = chat_completion(
+            system_message="test", user_message="test", openai_api="old-param-key"
+        )
+
+        # Should use old parameter value
+        mock_openai.assert_called_with(
+            base_url="https://api.openai.com/v1/", api_key="old-param-key"
+        )
+
+        # Test both parameters (new should take priority)
+        result = chat_completion(
+            system_message="test",
+            user_message="test",
+            openai_api_key="new-param-key",
+            openai_api="old-param-key",
+        )
+
+        # Should prioritize new parameter
+        mock_openai.assert_called_with(
+            base_url="https://api.openai.com/v1/", api_key="new-param-key"
+        )
+
+        # Test neither parameter (should fall back to config)
+        result = chat_completion(system_message="test", user_message="test")
+
+        # Should use config key
+        mock_openai.assert_called_with(
+            base_url="https://api.openai.com/v1/", api_key="config-key"
+        )
+
+
+if __name__ == "__main__":
+    unittest.main()
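Taken together, the changes above let callers migrate at their own pace. A minimal usage sketch of the new behavior (the prompt strings and the `sk-...` placeholder are illustrative only; the function and parameter names are from this diff):

```python
from docassemble.ALToolbox.llms import chat_completion

# Preferred: the new openai_api_key parameter.
answer = chat_completion(
    system_message="You are a helpful assistant.",
    user_message="Summarize the client's intake answers.",
    openai_api_key="sk-...",  # placeholder key
)

# Deprecated but still accepted: the old openai_api parameter. If both
# are passed, openai_api_key wins, because chat_completion computes
# `api_key = openai_api_key or openai_api`.
answer = chat_completion(
    system_message="You are a helpful assistant.",
    user_message="Summarize the client's intake answers.",
    openai_api="sk-...",  # placeholder key
)

# With neither parameter, the key comes from _get_openai_api_key(), which
# checks, in order: "openai api key", "openai: key", "open ai: key", then
# the OPENAI_API_KEY environment variable.
answer = chat_completion(
    system_message="You are a helpful assistant.",
    user_message="Summarize the client's intake answers.",
)
```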
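For reviewers who want to exercise the new tests locally, a sketch using only the standard-library test runner (this assumes an environment where `docassemble.base.util` is importable, since `llms.py` imports it at module load):

```python
import unittest

# Load only the new configuration tests from test_llms.py and run them
# with verbose output.
suite = unittest.defaultTestLoader.loadTestsFromName(
    "docassemble.ALToolbox.test_llms.TestLLMSConfig"
)
unittest.TextTestRunner(verbosity=2).run(suite)
```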