diff --git a/nemoguardrails/rails/llm/llmrails.py b/nemoguardrails/rails/llm/llmrails.py
index 3083505d3..dc415263c 100644
--- a/nemoguardrails/rails/llm/llmrails.py
+++ b/nemoguardrails/rails/llm/llmrails.py
@@ -593,8 +593,6 @@ def _create_action_llm_copy(
                 and isolated_llm.model_kwargs is not None
             ):
                 isolated_llm.model_kwargs = isolated_llm.model_kwargs.copy()
-            else:
-                isolated_llm.model_kwargs = {}
             log.debug(
                 "Successfully created isolated LLM copy for action: %s",
                 action_name
diff --git a/tests/test_llm_isolation.py b/tests/test_llm_isolation.py
index 8d07ea48f..c66e0d8e7 100644
--- a/tests/test_llm_isolation.py
+++ b/tests/test_llm_isolation.py
@@ -225,8 +225,7 @@ def test_create_action_llm_copy_with_none_model_kwargs(self, rails_with_mock_llm
 
         isolated_llm = rails._create_action_llm_copy(original_llm, "test_action")
 
-        assert isolated_llm.model_kwargs == {}
-        assert isinstance(isolated_llm.model_kwargs, dict)
+        assert isolated_llm.model_kwargs is None
 
     def test_create_action_llm_copy_handles_copy_failure(self, rails_with_mock_llm):
         """Test that copy failures raise detailed error message."""
diff --git a/tests/test_llm_isolation_model_kwargs_fix.py b/tests/test_llm_isolation_model_kwargs_fix.py
new file mode 100644
index 000000000..a4ace2b29
--- /dev/null
+++ b/tests/test_llm_isolation_model_kwargs_fix.py
@@ -0,0 +1,192 @@
+# SPDX-FileCopyrightText: Copyright (c) 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Tests for LLM isolation with models that don't have model_kwargs field."""
+
+from typing import Any, Dict, List, Optional
+from unittest.mock import Mock
+
+import pytest
+from langchain_core.language_models import BaseChatModel
+from langchain_core.messages import BaseMessage
+from langchain_core.outputs import ChatGeneration, ChatResult
+from pydantic import BaseModel, Field
+
+from nemoguardrails.rails.llm.config import RailsConfig
+from nemoguardrails.rails.llm.llmrails import LLMRails
+
+
+class StrictPydanticLLM(BaseModel):
+    """Mock Pydantic LLM that doesn't allow arbitrary attributes (like ChatNVIDIA)."""
+
+    class Config:
+        extra = "forbid"
+
+    temperature: float = Field(default=0.7)
+    max_tokens: Optional[int] = Field(default=None)
+
+
+class MockChatNVIDIA(BaseChatModel):
+    """Mock ChatNVIDIA-like model that doesn't have model_kwargs."""
+
+    model: str = "nvidia-model"
+    temperature: float = 0.7
+
+    class Config:
+        extra = "forbid"
+
+    def _generate(
+        self,
+        messages: List[BaseMessage],
+        stop: Optional[List[str]] = None,
+        run_manager: Optional[Any] = None,
+        **kwargs: Any,
+    ) -> ChatResult:
+        """Mock generation method."""
+        return ChatResult(generations=[ChatGeneration(message=Mock())])
+
+    @property
+    def _llm_type(self) -> str:
+        """Return the type of language model."""
+        return "nvidia"
+
+
+class FlexibleLLMWithModelKwargs(BaseModel):
+    """Mock LLM that has model_kwargs and allows modifications."""
+
+    class Config:
+        extra = "allow"
+
+    model_kwargs: Dict[str, Any] = Field(default_factory=dict)
+    temperature: float = 0.7
+
+
+class FlexibleLLMWithoutModelKwargs(BaseModel):
+    """Mock LLM that doesn't have model_kwargs but allows adding attributes."""
+
+    class Config:
+        extra = "allow"
+
+    temperature: float = 0.7
+    # no model_kwargs field
+
+
+@pytest.fixture
+def test_config():
+    """Create test configuration."""
+    return RailsConfig.from_content(
+        """
+        models:
+          - type: main
+            engine: openai
+            model: gpt-3.5-turbo
+        """
+    )
+
+
+class TestLLMIsolationModelKwargsFix:
+    """Test LLM isolation with different model types."""
+
+    def test_strict_pydantic_model_without_model_kwargs(self, test_config):
+        """Test isolation with strict Pydantic model that doesn't have model_kwargs."""
+        rails = LLMRails(config=test_config, verbose=False)
+
+        strict_llm = StrictPydanticLLM(temperature=0.5)
+
+        isolated_llm = rails._create_action_llm_copy(strict_llm, "test_action")
+
+        assert isolated_llm is not None
+        assert isolated_llm is not strict_llm
+        assert isolated_llm.temperature == 0.5
+        assert not hasattr(isolated_llm, "model_kwargs")
+
+    def test_mock_chat_nvidia_without_model_kwargs(self, test_config):
+        """Test with a ChatNVIDIA-like model that doesn't allow arbitrary attributes."""
+        rails = LLMRails(config=test_config, verbose=False)
+
+        nvidia_llm = MockChatNVIDIA()
+
+        isolated_llm = rails._create_action_llm_copy(nvidia_llm, "self_check_output")
+
+        assert isolated_llm is not None
+        assert isolated_llm is not nvidia_llm
+        assert isolated_llm.model == "nvidia-model"
+        assert isolated_llm.temperature == 0.7
+        assert not hasattr(isolated_llm, "model_kwargs")
+
+    def test_flexible_llm_with_model_kwargs(self, test_config):
+        """Test with LLM that has model_kwargs field."""
+        rails = LLMRails(config=test_config, verbose=False)
+
+        llm_with_kwargs = FlexibleLLMWithModelKwargs(
+            model_kwargs={"custom_param": "value"}, temperature=0.3
+        )
+
+        isolated_llm = rails._create_action_llm_copy(llm_with_kwargs, "test_action")
+
+        assert isolated_llm is not None
+        assert isolated_llm is not llm_with_kwargs
+        assert hasattr(isolated_llm, "model_kwargs")
+        assert isolated_llm.model_kwargs == {"custom_param": "value"}
+        assert isolated_llm.model_kwargs is not llm_with_kwargs.model_kwargs
+
+        isolated_llm.model_kwargs["new_param"] = "new_value"
+        assert "new_param" not in llm_with_kwargs.model_kwargs
+
+    def test_flexible_llm_without_model_kwargs_but_allows_adding(self, test_config):
+        """Test with LLM that doesn't have model_kwargs but allows adding attributes."""
+        rails = LLMRails(config=test_config, verbose=False)
+
+        flexible_llm = FlexibleLLMWithoutModelKwargs(temperature=0.8)
+
+        isolated_llm = rails._create_action_llm_copy(flexible_llm, "test_action")
+
+        assert isolated_llm is not None
+        assert isolated_llm is not flexible_llm
+        assert isolated_llm.temperature == 0.8
+        # since it allows extra attributes, model_kwargs might have been added
+        # but it shouldn't cause an error either way
+
+    def test_llm_with_none_model_kwargs(self, test_config):
+        """Test with LLM that has model_kwargs set to None."""
+        rails = LLMRails(config=test_config, verbose=False)
+
+        llm_with_none = FlexibleLLMWithModelKwargs(temperature=0.6)
+        llm_with_none.model_kwargs = None
+
+        isolated_llm = rails._create_action_llm_copy(llm_with_none, "test_action")
+
+        assert isolated_llm is not None
+        assert isolated_llm is not llm_with_none
+        if hasattr(isolated_llm, "model_kwargs"):
+            assert isolated_llm.model_kwargs in (None, {})
+
+    def test_copy_preserves_other_attributes(self, test_config):
+        """Test that copy preserves other attributes correctly."""
+        rails = LLMRails(config=test_config, verbose=False)
+
+        strict_llm = StrictPydanticLLM(temperature=0.2, max_tokens=100)
+        isolated_strict = rails._create_action_llm_copy(strict_llm, "action1")
+
+        assert isolated_strict.temperature == 0.2
+        assert isolated_strict.max_tokens == 100
+
+        flexible_llm = FlexibleLLMWithModelKwargs(
+            model_kwargs={"key": "value"}, temperature=0.9
+        )
+        isolated_flexible = rails._create_action_llm_copy(flexible_llm, "action2")
+
+        assert isolated_flexible.temperature == 0.9
+        assert isolated_flexible.model_kwargs == {"key": "value"}
diff --git a/tests/utils.py b/tests/utils.py
index 7b73d4767..2c71c7551 100644
--- a/tests/utils.py
+++ b/tests/utils.py
@@ -45,11 +45,41 @@ class FakeLLM(LLM):
     """Fake LLM wrapper for testing purposes."""
 
     responses: List
-    i: int = 0
     streaming: bool = False
    exception: Optional[Exception] = None
     token_usage: Optional[List[Dict[str, int]]] = None  # Token usage per response
     should_enable_stream_usage: bool = False
+    _shared_state: Optional[Dict] = None  # Shared state for isolated copies
+
+    def __init__(self, **kwargs):
+        """Initialize FakeLLM."""
+        # Extract initial counter value before parent init
+        initial_i = kwargs.pop("i", 0)
+        super().__init__(**kwargs)
+        # If no shared state, create one with initial counter
+        if self._shared_state is None:
+            self._shared_state = {"counter": initial_i}
+
+    def __copy__(self):
+        """Create a shallow copy that shares state with the original."""
+        new_instance = self.__class__.__new__(self.__class__)
+        new_instance.__dict__.update(self.__dict__)
+        # Share the same state dict so counter is synchronized
+        new_instance._shared_state = self._shared_state
+        return new_instance
+
+    @property
+    def i(self) -> int:
+        """Get current counter value from shared state."""
+        if self._shared_state:
+            return self._shared_state["counter"]
+        return 0
+
+    @i.setter
+    def i(self, value: int):
+        """Set counter value in shared state."""
+        if self._shared_state:
+            self._shared_state["counter"] = value
 
     @property
     def _llm_type(self) -> str:
@@ -67,14 +97,15 @@ def _call(
         if self.exception:
             raise self.exception
 
-        if self.i >= len(self.responses):
+        current_i = self.i
+        if current_i >= len(self.responses):
             raise RuntimeError(
-                f"No responses available for query number {self.i + 1} in FakeLLM. "
+                f"No responses available for query number {current_i + 1} in FakeLLM. "
                 "Most likely, too many LLM calls are made or additional responses need to be provided."
             )
 
-        response = self.responses[self.i]
-        self.i += 1
+        response = self.responses[current_i]
+        self.i = current_i + 1
 
         return response
 
@@ -88,15 +119,15 @@ async def _acall(
         if self.exception:
             raise self.exception
 
-        if self.i >= len(self.responses):
+        current_i = self.i
+        if current_i >= len(self.responses):
             raise RuntimeError(
-                f"No responses available for query number {self.i + 1} in FakeLLM. "
+                f"No responses available for query number {current_i + 1} in FakeLLM. "
                 "Most likely, too many LLM calls are made or additional responses need to be provided."
             )
 
-        response = self.responses[self.i]
-
-        self.i += 1
+        response = self.responses[current_i]
+        self.i = current_i + 1
 
         if self.streaming and run_manager:
             # To mock streaming, we just split in chunk by spaces