
fix(prompts): prevent IndexError when LLM provided via constructor with empty models config #1334

Open · wants to merge 1 commit into develop
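
In practice the bug surfaced as follows: with an LLM passed to the LLMRails constructor and no models declared in the config, get_task_model indexed into an empty list. A minimal reproduction sketch (FakeLLM is the repo's test helper, any LangChain-compatible LLM can stand in; the import path is an assumption):

from nemoguardrails import LLMRails, RailsConfig
from tests.utils import FakeLLM  # test helper; assumed import path

config = RailsConfig.parse_object({"models": []})  # no models in the config
llm = FakeLLM(responses=["express greeting"])

# Before this fix: IndexError (list index out of range) inside get_task_model,
# because `return _models[0]` ran against an empty list.
# After this fix: get_task_model returns None and the injected LLM is used.
rails = LLMRails(config=config, llm=llm)
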
nemoguardrails/llm/prompts.py (2 additions, 1 deletion)
@@ -139,7 +139,8 @@ def get_task_model(config: RailsConfig, task: Union[str, Task]) -> Model:
     if not _models:
         _models = [model for model in config.models if model.type == "main"]
 
-    return _models[0]
+    if _models:
+        return _models[0]
 
     return None
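
Because the function can now return None, it effectively behaves as Optional[Model] despite the Model annotation, and call sites are expected to handle the empty case. A sketch of the intended caller pattern (hypothetical caller code, not part of this diff; constructor_llm and make_llm are illustrative placeholders):

from nemoguardrails.llm.prompts import get_task_model

model_config = get_task_model(config, "main")
if model_config is None:
    # No model declared in the config: fall back to the LLM injected
    # through the LLMRails constructor (placeholder logic).
    llm = constructor_llm
else:
    llm = make_llm(model_config.engine, model_config.model)  # hypothetical helper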

tests/test_llm_task_manager.py (57 additions, 1 deletion)
@@ -20,7 +20,7 @@
 
 from nemoguardrails import RailsConfig
 from nemoguardrails.llm.filters import conversation_to_events
-from nemoguardrails.llm.prompts import get_prompt
+from nemoguardrails.llm.prompts import get_prompt, get_task_model
 from nemoguardrails.llm.taskmanager import LLMTaskManager
 from nemoguardrails.llm.types import Task
@@ -457,3 +457,59 @@ def test_reasoning_traces_not_included_in_prompt_history():
         "Hi there!" in rendered_prompt
         or "I don't have access to real-time weather information." in rendered_prompt
     )
+
+
+def test_get_task_model_with_empty_models():
+    """Test that get_task_model returns None when the models list is empty.
+
+    This tests the fix for the IndexError that occurred when the models list was empty.
+    """
+    config = RailsConfig.parse_object({"models": []})
+
+    result = get_task_model(config, "main")
+    assert result is None
+
+    result = get_task_model(config, Task.GENERAL)
+    assert result is None
+
+
+def test_get_task_model_with_no_matching_models():
+    """Test that get_task_model returns None when no models match the requested type."""
+    config = RailsConfig.parse_object(
+        {
+            "models": [
+                {
+                    "type": "embeddings",
+                    "engine": "openai",
+                    "model": "text-embedding-ada-002",
+                }
+            ]
+        }
+    )
+
+    result = get_task_model(config, "main")
+    assert result is None
+
+
+def test_get_task_model_with_main_model():
+    """Test that get_task_model returns the main model when one is present."""
+    config = RailsConfig.parse_object(
+        {"models": [{"type": "main", "engine": "openai", "model": "gpt-3.5-turbo"}]}
+    )
+
+    result = get_task_model(config, "main")
+    assert result is not None
+    assert result.type == "main"
+    assert result.engine == "openai"
+    assert result.model == "gpt-3.5-turbo"
+
+
+def test_get_task_model_fallback_to_main():
+    """Test that get_task_model falls back to the main model when no task-specific model is found."""
+    config = RailsConfig.parse_object(
+        {"models": [{"type": "main", "engine": "openai", "model": "gpt-3.5-turbo"}]}
+    )
+
+    result = get_task_model(config, "some_other_task")
+    assert result is not None
+    assert result.type == "main"
tests/test_llmrails.py (36 additions, 0 deletions)
@@ -856,6 +856,42 @@ async def test_other_models_honored(mock_init, llm_config_with_multiple_models):
     assert any(event.get("intent") == "express greeting" for event in new_events)
+
+
+@pytest.mark.asyncio
+async def test_llm_constructor_with_empty_models_config():
+    """Test that LLMRails can be initialized with a constructor-provided LLM when the config has an empty models list.
+
+    This tests the fix for the IndexError that occurred when an LLM was provided
+    via the constructor while the config's models list was empty.
+    """
+    config = RailsConfig.parse_object(
+        {
+            "models": [],
+            "user_messages": {
+                "express greeting": ["Hello!"],
+            },
+            "flows": [
+                {
+                    "elements": [
+                        {"user": "express greeting"},
+                        {"bot": "express greeting"},
+                    ]
+                },
+            ],
+            "bot_messages": {
+                "express greeting": ["Hello! How are you?"],
+            },
+        }
+    )
+
+    injected_llm = FakeLLM(responses=["express greeting"])
+    llm_rails = LLMRails(config=config, llm=injected_llm)
+    assert llm_rails.llm == injected_llm
+
+    events = [{"type": "UtteranceUserActionFinished", "final_transcript": "Hello!"}]
+    new_events = await llm_rails.runtime.generate_events(events)
+    assert any(event.get("intent") == "express greeting" for event in new_events)
 
 
 @pytest.mark.asyncio
 @patch(
     "nemoguardrails.rails.llm.llmrails.init_llm_model",