|
20 | 20 |
|
21 | 21 | from nemoguardrails import RailsConfig
|
22 | 22 | from nemoguardrails.llm.filters import conversation_to_events
|
23 |
| -from nemoguardrails.llm.prompts import get_prompt |
| 23 | +from nemoguardrails.llm.prompts import get_prompt, get_task_model |
24 | 24 | from nemoguardrails.llm.taskmanager import LLMTaskManager
|
25 | 25 | from nemoguardrails.llm.types import Task
|
26 | 26 |
|
@@ -457,3 +457,59 @@ def test_reasoning_traces_not_included_in_prompt_history():
|
457 | 457 | "Hi there!" in rendered_prompt
|
458 | 458 | or "I don't have access to real-time weather information." in rendered_prompt
|
459 | 459 | )
|
| 460 | + |
| 461 | + |
| 462 | +def test_get_task_model_with_empty_models(): |
| 463 | + """Test that get_task_model returns None when models list is empty. |
| 464 | +
|
| 465 | + This tests the fix for the IndexError that occurred when the models list was empty. |
| 466 | + """ |
| 467 | + config = RailsConfig.parse_object({"models": []}) |
| 468 | + |
| 469 | + result = get_task_model(config, "main") |
| 470 | + assert result is None |
| 471 | + |
| 472 | + result = get_task_model(config, Task.GENERAL) |
| 473 | + assert result is None |
| 474 | + |
| 475 | + |
| 476 | +def test_get_task_model_with_no_matching_models(): |
| 477 | + """Test that get_task_model returns None when no models match the requested type.""" |
| 478 | + config = RailsConfig.parse_object( |
| 479 | + { |
| 480 | + "models": [ |
| 481 | + { |
| 482 | + "type": "embeddings", |
| 483 | + "engine": "openai", |
| 484 | + "model": "text-embedding-ada-002", |
| 485 | + } |
| 486 | + ] |
| 487 | + } |
| 488 | + ) |
| 489 | + |
| 490 | + result = get_task_model(config, "main") |
| 491 | + assert result is None |
| 492 | + |
| 493 | + |
| 494 | +def test_get_task_model_with_main_model(): |
| 495 | + """Test that get_task_model returns the main model when present.""" |
| 496 | + config = RailsConfig.parse_object( |
| 497 | + {"models": [{"type": "main", "engine": "openai", "model": "gpt-3.5-turbo"}]} |
| 498 | + ) |
| 499 | + |
| 500 | + result = get_task_model(config, "main") |
| 501 | + assert result is not None |
| 502 | + assert result.type == "main" |
| 503 | + assert result.engine == "openai" |
| 504 | + assert result.model == "gpt-3.5-turbo" |
| 505 | + |
| 506 | + |
| 507 | +def test_get_task_model_fallback_to_main(): |
| 508 | + """Test that get_task_model falls back to main model when specific task model not found.""" |
| 509 | + config = RailsConfig.parse_object( |
| 510 | + {"models": [{"type": "main", "engine": "openai", "model": "gpt-3.5-turbo"}]} |
| 511 | + ) |
| 512 | + |
| 513 | + result = get_task_model(config, "some_other_task") |
| 514 | + assert result is not None |
| 515 | + assert result.type == "main" |
0 commit comments