198 changes: 196 additions & 2 deletions libs/core/langchain_core/language_models/fake_chat_models.py
@@ -3,8 +3,8 @@
import asyncio
import re
import time
from collections.abc import AsyncIterator, Iterator
from typing import Any, Optional, Union, cast
from collections.abc import AsyncIterator, Iterator, Sequence
from typing import Any, Callable, Dict, Optional, Union, cast

from typing_extensions import override

@@ -17,6 +17,12 @@
from langchain_core.outputs import ChatGeneration, ChatGenerationChunk, ChatResult
from langchain_core.runnables import RunnableConfig

from langchain_core.language_models.base import LanguageModelInput
from langchain_core.runnables import Runnable
from langchain_core.runnables.config import run_in_executor
from langchain_core.tools import BaseTool


class FakeMessagesListChatModel(BaseChatModel):
"""Fake ChatModel for testing purposes."""
@@ -363,3 +369,191 @@ def _generate(
@property
def _llm_type(self) -> str:
return "parrot-fake-chat-model"


class FakeToolCallingListChatModel(BaseChatModel):
"""Fake Tool calling ChatModel for testing purposes."""

    responses: list[Union[dict, str, BaseMessage]]
    """List of responses to **cycle** through in order."""
    sleep: Optional[float] = None
    """Optional delay, in seconds, between streamed chunks."""
    i: int = 0
    """Internally incremented after every model invocation."""
    error_on_chunk_number: Optional[int] = None
    """If set, raise ``FakeListChatModelError`` when streaming reaches this chunk."""

@property
@override
def _llm_type(self) -> str:
return "fake-list-chat-model"

@override
def _generate(
self,
messages: list[BaseMessage],
stop: Optional[list[str]] = None,
run_manager: Optional[CallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> ChatResult:
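        # Return the next canned response, cycling back to the start of the
        # list once it is exhausted.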
response = self.responses[self.i]
if self.i < len(self.responses) - 1:
self.i += 1
else:
self.i = 0

        if isinstance(response, BaseMessage):
            message = response
        elif isinstance(response, str):
            message = AIMessage(content=response)
        elif isinstance(response, dict):
            message = AIMessage(
                content=response["content"], tool_calls=response["tool_calls"]
            )
        else:
            raise ValueError("Incorrect response type")

generation = ChatGeneration(message=message)
return ChatResult(generations=[generation])

@override
async def _agenerate(
self,
messages: list[BaseMessage],
stop: Optional[list[str]] = None,
run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> ChatResult:
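        # Delegate to the synchronous ``_generate`` in an executor thread.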
return await run_in_executor(
None,
self._generate,
messages,
stop=stop,
run_manager=run_manager.get_sync() if run_manager else None,
**kwargs,
)

@override
def _stream(
self,
messages: list[BaseMessage],
stop: Union[list[str], None] = None,
run_manager: Union[CallbackManagerForLLMRun, None] = None,
**kwargs: Any,
) -> Iterator[ChatGenerationChunk]:
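        # Emit the current canned response one character at a time; any tool
        # calls are attached to every emitted chunk.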
response = self.responses[self.i]
if self.i < len(self.responses) - 1:
self.i += 1
else:
self.i = 0

tool_calls = []
        if isinstance(response, BaseMessage):
            output = response.content
        elif isinstance(response, str):
            output = response
        elif isinstance(response, dict):
            output = response["content"]
            tool_calls = response["tool_calls"]
        else:
            raise ValueError("Incorrect response type")

for i_c, c in enumerate(output):
if self.sleep is not None:
time.sleep(self.sleep)
if (
self.error_on_chunk_number is not None
and i_c == self.error_on_chunk_number
):
raise FakeListChatModelError

            yield ChatGenerationChunk(
                message=AIMessageChunk(content=c, tool_calls=tool_calls)
            )

@override
async def _astream(
self,
messages: list[BaseMessage],
stop: Union[list[str], None] = None,
run_manager: Union[AsyncCallbackManagerForLLMRun, None] = None,
**kwargs: Any,
) -> AsyncIterator[ChatGenerationChunk]:
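        # Async variant of ``_stream``: emit the response one character at a
        # time, attaching any tool calls to every chunk.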
response = self.responses[self.i]
if self.i < len(self.responses) - 1:
self.i += 1
else:
self.i = 0

tool_calls = []
        if isinstance(response, BaseMessage):
            output = response.content
        elif isinstance(response, str):
            output = response
        elif isinstance(response, dict):
            output = response["content"]
            tool_calls = response["tool_calls"]
        else:
            raise ValueError("Incorrect response type")

for i_c, c in enumerate(output):
if self.sleep is not None:
await asyncio.sleep(self.sleep)
if (
self.error_on_chunk_number is not None
and i_c == self.error_on_chunk_number
):
raise FakeListChatModelError
            yield ChatGenerationChunk(
                message=AIMessageChunk(content=c, tool_calls=tool_calls)
            )

@property
@override
def _identifying_params(self) -> dict[str, Any]:
return {"responses": self.responses}

    # Manually override ``batch`` to preserve input ordering, with no concurrency.
    @override
    def batch(
self,
inputs: list[Any],
config: Optional[Union[RunnableConfig, list[RunnableConfig]]] = None,
*,
return_exceptions: bool = False,
**kwargs: Any,
) -> list[BaseMessage]:
if isinstance(config, list):
return [self.invoke(m, c, **kwargs) for m, c in zip(inputs, config)]
return [self.invoke(m, config, **kwargs) for m in inputs]

@override
async def abatch(
self,
inputs: list[Any],
config: Optional[Union[RunnableConfig, list[RunnableConfig]]] = None,
*,
return_exceptions: bool = False,
**kwargs: Any,
) -> list[BaseMessage]:
        # Do not use an async iterator here; explicit ordering is needed.
        if isinstance(config, list):
            return [
                await self.ainvoke(m, c, **kwargs) for m, c in zip(inputs, config)
            ]
        return [await self.ainvoke(m, config, **kwargs) for m in inputs]

@override
def bind_tools(
self,
tools: Sequence[
Union[Dict[str, Any], type, Callable, BaseTool] # noqa: UP006
],
*,
        tool_choice: Optional[str] = None,
**kwargs: Any,
) -> Runnable[LanguageModelInput, BaseMessage]:
"""Bind tools to the model.

Args:
tools: Sequence of tools to bind to the model.
            tool_choice: The tool to use. If ``"any"``, then any tool can be used.

Returns:
A Runnable that returns a message.
"""
kwargs["tool_choice"] = tool_choice
        return super().bind(tools=tools, **kwargs)