diff --git a/nemoguardrails/llm/prompts.py b/nemoguardrails/llm/prompts.py
index 390bbeeb1..03d09be2d 100644
--- a/nemoguardrails/llm/prompts.py
+++ b/nemoguardrails/llm/prompts.py
@@ -139,7 +139,8 @@ def get_task_model(config: RailsConfig, task: Union[str, Task]) -> Model:
 
     if not _models:
         _models = [model for model in config.models if model.type == "main"]
 
-    return _models[0]
+    if _models:
+        return _models[0]
 
     return None
diff --git a/tests/test_llm_task_manager.py b/tests/test_llm_task_manager.py
index 9afd48914..c833fb709 100644
--- a/tests/test_llm_task_manager.py
+++ b/tests/test_llm_task_manager.py
@@ -20,7 +20,7 @@
 
 from nemoguardrails import RailsConfig
 from nemoguardrails.llm.filters import conversation_to_events
-from nemoguardrails.llm.prompts import get_prompt
+from nemoguardrails.llm.prompts import get_prompt, get_task_model
 from nemoguardrails.llm.taskmanager import LLMTaskManager
 from nemoguardrails.llm.types import Task
 
@@ -457,3 +457,59 @@
         "Hi there!" in rendered_prompt
         or "I don't have access to real-time weather information." in rendered_prompt
     )
+
+
+def test_get_task_model_with_empty_models():
+    """Test that get_task_model returns None when models list is empty.
+
+    This tests the fix for the IndexError that occurred when the models list was empty.
+    """
+    config = RailsConfig.parse_object({"models": []})
+
+    result = get_task_model(config, "main")
+    assert result is None
+
+    result = get_task_model(config, Task.GENERAL)
+    assert result is None
+
+
+def test_get_task_model_with_no_matching_models():
+    """Test that get_task_model returns None when no models match the requested type."""
+    config = RailsConfig.parse_object(
+        {
+            "models": [
+                {
+                    "type": "embeddings",
+                    "engine": "openai",
+                    "model": "text-embedding-ada-002",
+                }
+            ]
+        }
+    )
+
+    result = get_task_model(config, "main")
+    assert result is None
+
+
+def test_get_task_model_with_main_model():
+    """Test that get_task_model returns the main model when present."""
+    config = RailsConfig.parse_object(
+        {"models": [{"type": "main", "engine": "openai", "model": "gpt-3.5-turbo"}]}
+    )
+
+    result = get_task_model(config, "main")
+    assert result is not None
+    assert result.type == "main"
+    assert result.engine == "openai"
+    assert result.model == "gpt-3.5-turbo"
+
+
+def test_get_task_model_fallback_to_main():
+    """Test that get_task_model falls back to main model when specific task model not found."""
+    config = RailsConfig.parse_object(
+        {"models": [{"type": "main", "engine": "openai", "model": "gpt-3.5-turbo"}]}
+    )
+
+    result = get_task_model(config, "some_other_task")
+    assert result is not None
+    assert result.type == "main"
diff --git a/tests/test_llmrails.py b/tests/test_llmrails.py
index 96ff01b67..2451dd7b2 100644
--- a/tests/test_llmrails.py
+++ b/tests/test_llmrails.py
@@ -856,6 +856,42 @@ async def test_other_models_honored(mock_init, llm_config_with_multiple_models):
     assert any(event.get("intent") == "express greeting" for event in new_events)
 
 
+@pytest.mark.asyncio
+async def test_llm_constructor_with_empty_models_config():
+    """Test that LLMRails can be initialized with constructor LLM when config has empty models list.
+
+    This tests the fix for the IndexError that occurred when providing an LLM via constructor
+    but having an empty models list in the config.
+    """
+    config = RailsConfig.parse_object(
+        {
+            "models": [],
+            "user_messages": {
+                "express greeting": ["Hello!"],
+            },
+            "flows": [
+                {
+                    "elements": [
+                        {"user": "express greeting"},
+                        {"bot": "express greeting"},
+                    ]
+                },
+            ],
+            "bot_messages": {
+                "express greeting": ["Hello! How are you?"],
+            },
+        }
+    )
+
+    injected_llm = FakeLLM(responses=["express greeting"])
+    llm_rails = LLMRails(config=config, llm=injected_llm)
+    assert llm_rails.llm == injected_llm
+
+    events = [{"type": "UtteranceUserActionFinished", "final_transcript": "Hello!"}]
+    new_events = await llm_rails.runtime.generate_events(events)
+    assert any(event.get("intent") == "express greeting" for event in new_events)
+
+
 @pytest.mark.asyncio
 @patch(
     "nemoguardrails.rails.llm.llmrails.init_llm_model",