102 changes: 75 additions & 27 deletions docassemble/ALToolbox/llms.py
@@ -27,14 +27,44 @@
"IntakeQuestionList",
]

if os.getenv("OPENAI_API_KEY"):
client: Optional[OpenAI] = OpenAI()
elif get_config("open ai"):
api_key = get_config("open ai", {}).get("key")
client = OpenAI(api_key=api_key)

def _get_openai_api_key() -> Optional[str]:
"""
Get the OpenAI API key from configuration or environment in priority order:
1. openai api key (direct key configuration)
2. openai: key (standardized nested configuration)
3. open ai: key (legacy nested configuration with space)
4. OPENAI_API_KEY environment variable

Returns:
The API key if found, None otherwise
"""
# Priority 1: Direct key configuration "openai api key: sk-..."
direct_key = get_config("openai api key")
if direct_key:
return direct_key

# Priority 2: Standardized nested configuration "openai: key: sk-..."
openai_config = get_config("openai", {})
if isinstance(openai_config, dict) and openai_config.get("key"):
return openai_config.get("key")

# Priority 3: Legacy nested configuration "open ai: key: sk-..."
legacy_config = get_config("open ai", {})
if isinstance(legacy_config, dict) and legacy_config.get("key"):
return legacy_config.get("key")

# Priority 4: Environment variable
return os.getenv("OPENAI_API_KEY")


api_key = _get_openai_api_key()
if api_key:
client: Optional[OpenAI] = OpenAI(api_key=api_key)
else:
client = None
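
For reference, here is a minimal sketch (not part of the diff) of the lookup order the new helper follows, with a stand-in `get_config` in place of docassemble's real configuration reader:

```
# Illustrative stub only: docassemble's get_config normally reads the
# server's global YAML configuration.
_config = {
    "openai api key": None,           # 1. direct key (highest priority)
    "openai": {"key": "sk-nested"},   # 2. standardized nested form
    "open ai": {"key": "sk-legacy"},  # 3. legacy form with a space
}


def get_config(key, default=None):
    return _config.get(key, default)


# With the direct key unset, the standardized nested form wins, so
# _get_openai_api_key() returns "sk-nested" here; only if all three
# config entries were empty would OPENAI_API_KEY be consulted.
```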


always_reserved_names = set(
docassemble.base.util.__all__
+ keyword.kwlist
@@ -115,7 +145,8 @@ def chat_completion(
system_message: Optional[str] = None,
user_message: Optional[str] = None,
openai_client: Optional[OpenAI] = None,
openai_api: Optional[str] = None,
openai_api_key: Optional[str] = None,
openai_api: Optional[str] = None, # Kept for backwards compatibility
temperature: float = 0.5,
json_mode=False,
model: str = "gpt-4o",
@@ -133,7 +164,8 @@
system_message (str): The role the chat engine should play
user_message (str): The message (data) from the user
openai_client (Optional[OpenAI]): An OpenAI client object, optional. If omitted, will fall back to creating a new OpenAI client with the API key provided as an environment variable
openai_api (Optional[str]): the API key for an OpenAI client, optional. If provided, a new OpenAI client will be created.
openai_api_key (Optional[str]): the API key for an OpenAI client, optional. If provided, a new OpenAI client will be created.
openai_api (Optional[str]): the API key for an OpenAI client, optional. Deprecated, use openai_api_key instead. Kept for backwards compatibility.
temperature (float): The temperature to use for the GPT API
json_mode (bool): Whether to use JSON mode for the GPT API. Requires the word `json` in the system message, but will add it if you omit it.
model (str): The model to use for the GPT API
@@ -147,12 +179,20 @@
A string with the response from the API endpoint or JSON data if json_mode is True
"""
if not openai_base_url:
openai_base_url = (
get_config("open ai", {}).get("base url") or "https://api.openai.com/v1/"
)
# Check both new and legacy config formats for base URL
openai_config = get_config("openai", {})
if isinstance(openai_config, dict) and openai_config.get("base url"):
openai_base_url = openai_config.get("base url")
else:
openai_base_url = (
get_config("open ai", {}).get("base url")
or "https://api.openai.com/v1/"
)

if not openai_api:
openai_api = get_config("open ai", {}).get("key") or os.getenv("OPENAI_API_KEY")
# Handle backwards compatibility: prefer openai_api_key, fall back to openai_api, then config
api_key = openai_api_key or openai_api
if not api_key:
api_key = _get_openai_api_key()

if not messages and not system_message:
raise Exception(
@@ -192,12 +232,12 @@ def chat_completion(
if openai_base_url:
openai_client = None # Always override client in this circumstance
openai_client = (
openai_client or OpenAI(base_url=openai_base_url, api_key=openai_api) or client
openai_client or OpenAI(base_url=openai_base_url, api_key=api_key) or client
)

if not openai_client:
raise Exception(
"You need to pass an OpenAI client or API key to use this function, or the API key needs to be set in the environment or Docassemble configuration. Try adding a new section in your global config that looks like this:\n\nopen ai:\n key: sk-..."
"You need to pass an OpenAI client or API key to use this function, or the API key needs to be set in the environment or Docassemble configuration. Try adding a new section in your global config that looks like this:\n\nopenai api key: sk-..."
)

try:
@@ -260,7 +300,8 @@ def extract_fields_from_text(
text: str,
field_list: Dict[str, str],
openai_client: Optional[OpenAI] = None,
openai_api: Optional[str] = None,
openai_api_key: Optional[str] = None,
openai_api: Optional[str] = None, # Kept for backwards compatibility
temperature: float = 0,
model="gpt-4o-mini",
) -> Dict[str, Any]:
@@ -270,7 +311,8 @@ def extract_fields_from_text(
text (str): The text to extract fields from
field_list (Dict[str,str]): A list of fields to extract, with the key being the field name and the value being a description of the field
openai_client (Optional[OpenAI]): An OpenAI client object. Defaults to None.
openai_api (Optional[str]): An OpenAI API key. Defaults to None.
openai_api_key (Optional[str]): An OpenAI API key. Defaults to None.
openai_api (Optional[str]): An OpenAI API key. Deprecated, use openai_api_key instead. Kept for backwards compatibility.
temperature (float): The temperature to use for the OpenAI API. Defaults to 0.
model (str): The model to use for the OpenAI API. Defaults to "gpt-4o-mini".

@@ -292,7 +334,7 @@ def extract_fields_from_text(
user_message=text,
model=model,
openai_client=openai_client,
openai_api=openai_api,
openai_api_key=openai_api_key or openai_api,
temperature=temperature,
json_mode=True,
)
@@ -305,7 +347,8 @@ def match_goals_from_text(
user_response: str,
goals: Dict[str, str],
openai_client: Optional[OpenAI] = None,
openai_api: Optional[str] = None,
openai_api_key: Optional[str] = None,
openai_api: Optional[str] = None, # Kept for backwards compatibility
temperature: float = 0,
model="gpt-4o-mini",
) -> Dict[str, Any]:
@@ -316,7 +359,8 @@ def match_goals_from_text(
user_response (str): The user's response to the question
goals (Dict[str,str]): A list of goals to extract, with the key being the goal name and the value being a description of the goal
openai_client (Optional[OpenAI]): An OpenAI client object. Defaults to None.
openai_api (Optional[str]): An OpenAI API key. Defaults to None.
openai_api_key (Optional[str]): An OpenAI API key. Defaults to None.
openai_api (Optional[str]): An OpenAI API key. Deprecated, use openai_api_key instead. Kept for backwards compatibility.
temperature (float): The temperature to use for the OpenAI API. Defaults to 0.
model (str): The model to use for the OpenAI API. Defaults to "gpt-4o-mini".

@@ -352,7 +396,7 @@ def match_goals_from_text(
user_message=user_response,
model=model,
openai_client=openai_client,
openai_api=openai_api,
openai_api_key=openai_api_key or openai_api,
temperature=temperature,
json_mode=True,
)
@@ -365,7 +409,8 @@ def classify_text(
choices: Dict[str, str],
default_response: str = "null",
openai_client: Optional[OpenAI] = None,
openai_api: Optional[str] = None,
openai_api_key: Optional[str] = None,
openai_api: Optional[str] = None, # Kept for backwards compatibility
temperature: float = 0,
model="gpt-4o-mini",
) -> str:
@@ -376,7 +421,8 @@ def classify_text(
choices (Dict[str,str]): A list of choices to classify the text into, with the key being the choice name and the value being a description of the choice
default_response (str): The default response to return if the text cannot be classified. Defaults to "null".
openai_client (Optional[OpenAI]): An OpenAI client object, optional. If omitted, will fall back to creating a new OpenAI client with the API key provided as an environment variable
openai_api (Optional[str]): the API key for an OpenAI client, optional. If provided, a new OpenAI client will be created.
openai_api_key (Optional[str]): the API key for an OpenAI client, optional. If provided, a new OpenAI client will be created.
openai_api (Optional[str]): the API key for an OpenAI client, optional. Deprecated, use openai_api_key instead. Kept for backwards compatibility.
temperature (float): The temperature to use for GPT. Defaults to 0.
model (str): The model to use for the GPT API

@@ -395,7 +441,7 @@ def classify_text(
user_message=text,
model=model,
openai_client=openai_client,
openai_api=openai_api,
openai_api_key=openai_api_key or openai_api,
temperature=temperature,
json_mode=False,
)
@@ -407,7 +453,8 @@ def synthesize_user_responses(
messages: List[Dict[str, str]],
custom_instructions: Optional[str] = "",
openai_client: Optional[OpenAI] = None,
openai_api: Optional[str] = None,
openai_api_key: Optional[str] = None,
openai_api: Optional[str] = None, # Kept for backwards compatibility
temperature: float = 0,
model: str = "gpt-4o-mini",
) -> str:
@@ -418,7 +465,8 @@ def synthesize_user_responses(
messages (List[Dict[str, str]]): A list of questions from the LLM and responses from the user
custom_instructions (str): Custom instructions for the LLM to follow in constructing the synthesized response
openai_client (Optional[OpenAI]): An OpenAI client object, optional. If omitted, will fall back to creating a new OpenAI client with the API key provided as an environment variable
openai_api (Optional[str]): the API key for an OpenAI client, optional. If provided, a new OpenAI client will be created.
openai_api_key (Optional[str]): the API key for an OpenAI client, optional. If provided, a new OpenAI client will be created.
openai_api (Optional[str]): the API key for an OpenAI client, optional. Deprecated, use openai_api_key instead. Kept for backwards compatibility.
temperature (float): The temperature to use for GPT. Defaults to 0.
model (str): The model to use for the GPT API

@@ -449,7 +497,7 @@ def synthesize_user_responses(
],
model=model,
openai_client=openai_client,
openai_api=openai_api,
openai_api_key=openai_api_key or openai_api,
temperature=temperature,
json_mode=False,
)
@@ -629,7 +677,7 @@ class GoalSatisfactionList(DAList):
By default, this will use the OpenAI API key defined in the global configuration under this path:

```
open ai:
openai:
key: sk-...
```

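Taken together, the llms.py changes mean existing call sites that pass `openai_api` keep working while new code can use `openai_api_key`. A minimal usage sketch, assuming the package is installed; the key value is a placeholder:

```
from docassemble.ALToolbox.llms import chat_completion

# openai_api_key is the new, preferred parameter; passing the deprecated
# openai_api instead still works for existing call sites.
response = chat_completion(
    system_message="You are a helpful intake assistant.",
    user_message="Summarize the client's housing issue.",
    openai_api_key="sk-...",  # placeholder value
    model="gpt-4o",
    temperature=0.5,
)
print(response)
```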
148 changes: 148 additions & 0 deletions docassemble/ALToolbox/test_llms.py
@@ -0,0 +1,148 @@
# do not pre-load

import unittest
from unittest.mock import patch, MagicMock
import os
from .llms import _get_openai_api_key, chat_completion


class TestLLMSConfig(unittest.TestCase):
"""Test configuration lookup and parameter handling for LLM functions."""

@patch("docassemble.ALToolbox.llms.get_config")
@patch.dict(os.environ, {}, clear=True)
def test_get_openai_api_key_priority_order(self, mock_get_config):
"""Test that _get_openai_api_key follows the correct priority order."""

# Test priority 1: direct key configuration "openai api key"
mock_get_config.side_effect = lambda key, default=None: {
"openai api key": "direct-key"
}.get(key, default)

result = _get_openai_api_key()
self.assertEqual(result, "direct-key")

# Test priority 2: standardized nested configuration "openai: key"
mock_get_config.side_effect = lambda key, default=None: {
"openai api key": None,
"openai": {"key": "standardized-key"},
}.get(key, default)

result = _get_openai_api_key()
self.assertEqual(result, "standardized-key")

# Test priority 3: legacy nested configuration "open ai: key"
mock_get_config.side_effect = lambda key, default=None: {
"openai api key": None,
"openai": {},
"open ai": {"key": "legacy-key"},
}.get(key, default)

result = _get_openai_api_key()
self.assertEqual(result, "legacy-key")

@patch("docassemble.ALToolbox.llms.get_config")
@patch.dict(os.environ, {"OPENAI_API_KEY": "env-key"}, clear=True)
def test_get_openai_api_key_environment_fallback(self, mock_get_config):
"""Test that environment variable is used as last resort."""

# No config keys found, should fall back to environment
mock_get_config.side_effect = lambda key, default=None: default

result = _get_openai_api_key()
self.assertEqual(result, "env-key")

@patch("docassemble.ALToolbox.llms.get_config")
@patch.dict(os.environ, {}, clear=True)
def test_get_openai_api_key_no_key_found(self, mock_get_config):
"""Test that None is returned when no key is found."""

mock_get_config.side_effect = lambda key, default=None: default

result = _get_openai_api_key()
self.assertIsNone(result)

@patch("docassemble.ALToolbox.llms.get_config")
def test_get_openai_api_key_handles_non_dict_config(self, mock_get_config):
"""Test that non-dict config values are handled gracefully."""

# Test when config returns a string instead of dict
mock_get_config.side_effect = lambda key, default=None: {
"openai api key": None,
"openai": "not-a-dict", # This should be handled gracefully
"open ai": {"key": "fallback-key"},
}.get(key, default)

result = _get_openai_api_key()
self.assertEqual(result, "fallback-key")

@patch("docassemble.ALToolbox.llms.OpenAI")
@patch("docassemble.ALToolbox.llms._get_openai_api_key")
def test_chat_completion_parameter_backwards_compatibility(
self, mock_get_key, mock_openai
):
"""Test that both openai_api and openai_api_key parameters work."""

# Mock the OpenAI client
mock_client = MagicMock()
mock_openai.return_value = mock_client

# Mock successful response
mock_response = MagicMock()
mock_response.choices = [MagicMock()]
mock_response.choices[0].finish_reason = "stop"
mock_response.choices[0].message.content = "test response"
mock_client.chat.completions.create.return_value = mock_response

# Mock moderation
mock_moderation = MagicMock()
mock_moderation.results = [MagicMock()]
mock_moderation.results[0].flagged = False
mock_client.moderations.create.return_value = mock_moderation

mock_get_key.return_value = "config-key"

# Test new parameter name
result = chat_completion(
system_message="test", user_message="test", openai_api_key="new-param-key"
)

# Should use new parameter value
mock_openai.assert_called_with(
base_url="https://api.openai.com/v1/", api_key="new-param-key"
)

# Test old parameter name
result = chat_completion(
system_message="test", user_message="test", openai_api="old-param-key"
)

# Should use old parameter value
mock_openai.assert_called_with(
base_url="https://api.openai.com/v1/", api_key="old-param-key"
)

# Test both parameters (new should take priority)
result = chat_completion(
system_message="test",
user_message="test",
openai_api_key="new-param-key",
openai_api="old-param-key",
)

# Should prioritize new parameter
mock_openai.assert_called_with(
base_url="https://api.openai.com/v1/", api_key="new-param-key"
)

# Test neither parameter (should fall back to config)
result = chat_completion(system_message="test", user_message="test")

# Should use config key
mock_openai.assert_called_with(
base_url="https://api.openai.com/v1/", api_key="config-key"
)


if __name__ == "__main__":
unittest.main()
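
To exercise the new tests outside a full docassemble server, something like the following should work, assuming the package and its dependencies are importable in the current environment:

```
import unittest

from docassemble.ALToolbox import test_llms

# Load and run every test in the new module with verbose output.
suite = unittest.defaultTestLoader.loadTestsFromModule(test_llms)
unittest.TextTestRunner(verbosity=2).run(suite)
```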