fix(llmrails): handle LLM models without model_kwargs field in isolation #1336

Open. Wants to merge 1 commit into base: develop.
2 changes: 0 additions & 2 deletions nemoguardrails/rails/llm/llmrails.py
@@ -593,8 +593,6 @@ def _create_action_llm_copy(
             and isolated_llm.model_kwargs is not None
         ):
             isolated_llm.model_kwargs = isolated_llm.model_kwargs.copy()
-        else:
-            isolated_llm.model_kwargs = {}
 
         log.debug(
             "Successfully created isolated LLM copy for action: %s", action_name
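For context, a minimal sketch of the copy helper's behavior after this change. This is an illustration, not the code from llmrails.py: the real helper is a method on LLMRails and does more (error handling, logging), but its model_kwargs handling reduces to the guarded copy shown here.

# Illustrative sketch only; the real implementation lives in
# nemoguardrails/rails/llm/llmrails.py as LLMRails._create_action_llm_copy.
import copy


def create_action_llm_copy_sketch(llm, action_name):
    """Shallow-copy an LLM for one action, isolating mutable kwargs."""
    isolated_llm = copy.copy(llm)
    # Only detach model_kwargs when the field exists and is populated.
    # The removed else branch used to assign {} here, which fails on
    # models that forbid undeclared attributes (ChatNVIDIA-style models).
    if hasattr(isolated_llm, "model_kwargs") and isolated_llm.model_kwargs is not None:
        isolated_llm.model_kwargs = isolated_llm.model_kwargs.copy()
    return isolated_llm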
3 changes: 1 addition & 2 deletions tests/test_llm_isolation.py
@@ -225,8 +225,7 @@ def test_create_action_llm_copy_with_none_model_kwargs(self, rails_with_mock_llm
 
         isolated_llm = rails._create_action_llm_copy(original_llm, "test_action")
 
-        assert isolated_llm.model_kwargs == {}
-        assert isinstance(isolated_llm.model_kwargs, dict)
+        assert isolated_llm.model_kwargs is None
 
     def test_create_action_llm_copy_handles_copy_failure(self, rails_with_mock_llm):
         """Test that copy failures raise detailed error message."""
192 changes: 192 additions & 0 deletions tests/test_llm_isolation_model_kwargs_fix.py
@@ -0,0 +1,192 @@
# SPDX-FileCopyrightText: Copyright (c) 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Tests for LLM isolation with models that don't have model_kwargs field."""

from typing import Any, Dict, List, Optional
from unittest.mock import Mock

import pytest
from langchain_core.language_models import BaseChatModel
from langchain_core.messages import BaseMessage
from langchain_core.outputs import ChatGeneration, ChatResult
from pydantic import BaseModel, Field

from nemoguardrails.rails.llm.config import RailsConfig
from nemoguardrails.rails.llm.llmrails import LLMRails


class StrictPydanticLLM(BaseModel):
    """Mock Pydantic LLM that doesn't allow arbitrary attributes (like ChatNVIDIA)."""

    class Config:
        extra = "forbid"

    temperature: float = Field(default=0.7)
    max_tokens: Optional[int] = Field(default=None)


class MockChatNVIDIA(BaseChatModel):
    """Mock ChatNVIDIA-like model that doesn't have model_kwargs."""

    model: str = "nvidia-model"
    temperature: float = 0.7

    class Config:
        extra = "forbid"

    def _generate(
        self,
        messages: List[BaseMessage],
        stop: Optional[List[str]] = None,
        run_manager: Optional[Any] = None,
        **kwargs: Any,
    ) -> ChatResult:
        """Mock generation method."""
        return ChatResult(generations=[ChatGeneration(message=Mock())])

    @property
    def _llm_type(self) -> str:
        """Return the type of language model."""
        return "nvidia"


class FlexibleLLMWithModelKwargs(BaseModel):
    """Mock LLM that has model_kwargs and allows modifications."""

    class Config:
        extra = "allow"

    model_kwargs: Dict[str, Any] = Field(default_factory=dict)
    temperature: float = 0.7


class FlexibleLLMWithoutModelKwargs(BaseModel):
    """Mock LLM that doesn't have model_kwargs but allows adding attributes."""

    class Config:
        extra = "allow"

    temperature: float = 0.7
    # no model_kwargs field


@pytest.fixture
def test_config():
    """Create test configuration."""
    return RailsConfig.from_content(
        """
models:
  - type: main
    engine: openai
    model: gpt-3.5-turbo
"""
    )


class TestLLMIsolationModelKwargsFix:
    """Test LLM isolation with different model types."""

    def test_strict_pydantic_model_without_model_kwargs(self, test_config):
        """Test isolation with strict Pydantic model that doesn't have model_kwargs."""
        rails = LLMRails(config=test_config, verbose=False)

        strict_llm = StrictPydanticLLM(temperature=0.5)

        isolated_llm = rails._create_action_llm_copy(strict_llm, "test_action")

        assert isolated_llm is not None
        assert isolated_llm is not strict_llm
        assert isolated_llm.temperature == 0.5
        assert not hasattr(isolated_llm, "model_kwargs")

    def test_mock_chat_nvidia_without_model_kwargs(self, test_config):
        """Test with a ChatNVIDIA-like model that doesn't allow arbitrary attributes."""
        rails = LLMRails(config=test_config, verbose=False)

        nvidia_llm = MockChatNVIDIA()

        isolated_llm = rails._create_action_llm_copy(nvidia_llm, "self_check_output")

        assert isolated_llm is not None
        assert isolated_llm is not nvidia_llm
        assert isolated_llm.model == "nvidia-model"
        assert isolated_llm.temperature == 0.7
        assert not hasattr(isolated_llm, "model_kwargs")

    def test_flexible_llm_with_model_kwargs(self, test_config):
        """Test with LLM that has model_kwargs field."""
        rails = LLMRails(config=test_config, verbose=False)

        llm_with_kwargs = FlexibleLLMWithModelKwargs(
            model_kwargs={"custom_param": "value"}, temperature=0.3
        )

        isolated_llm = rails._create_action_llm_copy(llm_with_kwargs, "test_action")

        assert isolated_llm is not None
        assert isolated_llm is not llm_with_kwargs
        assert hasattr(isolated_llm, "model_kwargs")
        assert isolated_llm.model_kwargs == {"custom_param": "value"}
        assert isolated_llm.model_kwargs is not llm_with_kwargs.model_kwargs

        isolated_llm.model_kwargs["new_param"] = "new_value"
        assert "new_param" not in llm_with_kwargs.model_kwargs

    def test_flexible_llm_without_model_kwargs_but_allows_adding(self, test_config):
        """Test with LLM that doesn't have model_kwargs but allows adding attributes."""
        rails = LLMRails(config=test_config, verbose=False)

        flexible_llm = FlexibleLLMWithoutModelKwargs(temperature=0.8)

        isolated_llm = rails._create_action_llm_copy(flexible_llm, "test_action")

        assert isolated_llm is not None
        assert isolated_llm is not flexible_llm
        assert isolated_llm.temperature == 0.8
        # since it allows extra attributes, model_kwargs might have been added
        # but it shouldn't cause an error either way

    def test_llm_with_none_model_kwargs(self, test_config):
        """Test with LLM that has model_kwargs set to None."""
        rails = LLMRails(config=test_config, verbose=False)

        llm_with_none = FlexibleLLMWithModelKwargs(temperature=0.6)
        llm_with_none.model_kwargs = None

        isolated_llm = rails._create_action_llm_copy(llm_with_none, "test_action")

        assert isolated_llm is not None
        assert isolated_llm is not llm_with_none
        if hasattr(isolated_llm, "model_kwargs"):
            assert isolated_llm.model_kwargs in (None, {})

    def test_copy_preserves_other_attributes(self, test_config):
        """Test that copy preserves other attributes correctly."""
        rails = LLMRails(config=test_config, verbose=False)

        strict_llm = StrictPydanticLLM(temperature=0.2, max_tokens=100)
        isolated_strict = rails._create_action_llm_copy(strict_llm, "action1")

        assert isolated_strict.temperature == 0.2
        assert isolated_strict.max_tokens == 100

        flexible_llm = FlexibleLLMWithModelKwargs(
            model_kwargs={"key": "value"}, temperature=0.9
        )
        isolated_flexible = rails._create_action_llm_copy(flexible_llm, "action2")

        assert isolated_flexible.temperature == 0.9
        assert isolated_flexible.model_kwargs == {"key": "value"}
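Why the strict-model cases above matter: Pydantic models declared with extra = "forbid" reject assignment of attributes that are not declared fields, which is exactly what the removed else branch attempted when a model had no model_kwargs field. A standalone sketch of that failure mode, assuming Pydantic v2 (the exact exception type and message vary by Pydantic version), not part of this PR's test suite:

# Standalone illustration of why injecting model_kwargs on a strict model
# fails; assumes Pydantic v2 idioms.
from pydantic import BaseModel, ConfigDict


class StrictModel(BaseModel):
    model_config = ConfigDict(extra="forbid")

    temperature: float = 0.7


llm = StrictModel()
try:
    # Mirrors the old behavior: setting an attribute the model never declared.
    llm.model_kwargs = {}
except Exception as exc:  # Pydantic rejects the undeclared attribute here
    print(f"rejected: {exc}")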
51 changes: 41 additions & 10 deletions tests/utils.py
@@ -45,11 +45,41 @@ class FakeLLM(LLM):
     """Fake LLM wrapper for testing purposes."""
 
     responses: List
-    i: int = 0
     streaming: bool = False
     exception: Optional[Exception] = None
     token_usage: Optional[List[Dict[str, int]]] = None  # Token usage per response
     should_enable_stream_usage: bool = False
+    _shared_state: Optional[Dict] = None  # Shared state for isolated copies
+
+    def __init__(self, **kwargs):
+        """Initialize FakeLLM."""
+        # Extract initial counter value before parent init
+        initial_i = kwargs.pop("i", 0)
+        super().__init__(**kwargs)
+        # If no shared state, create one with initial counter
+        if self._shared_state is None:
+            self._shared_state = {"counter": initial_i}
+
+    def __copy__(self):
+        """Create a shallow copy that shares state with the original."""
+        new_instance = self.__class__.__new__(self.__class__)
+        new_instance.__dict__.update(self.__dict__)
+        # Share the same state dict so counter is synchronized
+        new_instance._shared_state = self._shared_state
+        return new_instance
+
+    @property
+    def i(self) -> int:
+        """Get current counter value from shared state."""
+        if self._shared_state:
+            return self._shared_state["counter"]
+        return 0
+
+    @i.setter
+    def i(self, value: int):
+        """Set counter value in shared state."""
+        if self._shared_state:
+            self._shared_state["counter"] = value
 
     @property
     def _llm_type(self) -> str:
@@ -67,14 +97,15 @@ def _call(
         if self.exception:
             raise self.exception
 
-        if self.i >= len(self.responses):
+        current_i = self.i
+        if current_i >= len(self.responses):
             raise RuntimeError(
-                f"No responses available for query number {self.i + 1} in FakeLLM. "
+                f"No responses available for query number {current_i + 1} in FakeLLM. "
                 "Most likely, too many LLM calls are made or additional responses need to be provided."
             )
 
-        response = self.responses[self.i]
-        self.i += 1
+        response = self.responses[current_i]
+        self.i = current_i + 1
         return response
 
     async def _acall(
@@ -88,15 +119,15 @@ async def _acall(
         if self.exception:
             raise self.exception
 
-        if self.i >= len(self.responses):
+        current_i = self.i
+        if current_i >= len(self.responses):
             raise RuntimeError(
-                f"No responses available for query number {self.i + 1} in FakeLLM. "
+                f"No responses available for query number {current_i + 1} in FakeLLM. "
                 "Most likely, too many LLM calls are made or additional responses need to be provided."
             )
 
-        response = self.responses[self.i]
-
-        self.i += 1
+        response = self.responses[current_i]
+        self.i = current_i + 1
 
         if self.streaming and run_manager:
             # To mock streaming, we just split in chunk by spaces
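The shared-state counter above exists so that per-action isolated copies of FakeLLM keep consuming the same response queue as the original. A small standalone sketch of the pattern, mirroring the idea rather than importing tests/utils.py:

# Standalone sketch of the shared-counter pattern used by FakeLLM above;
# SharedCounterLLM is a hypothetical stand-in, not a helper from this repo.
import copy
from typing import Dict, List


class SharedCounterLLM:
    def __init__(self, responses: List[str]):
        self.responses = responses
        self._shared_state: Dict[str, int] = {"counter": 0}

    def __copy__(self):
        new_instance = self.__class__.__new__(self.__class__)
        new_instance.__dict__.update(self.__dict__)
        # Copies point at the same state dict, so advancing the counter in
        # either object is visible to both.
        new_instance._shared_state = self._shared_state
        return new_instance

    def call(self) -> str:
        i = self._shared_state["counter"]
        response = self.responses[i]
        self._shared_state["counter"] = i + 1
        return response


original = SharedCounterLLM(["first", "second"])
isolated = copy.copy(original)
assert original.call() == "first"
assert isolated.call() == "second"  # the copy sees the advanced counter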