
Commit 08f6994 (parent: 1027c9a)

7 files changed: +25 −11 lines


src/neo4j_graphrag/llm/anthropic_llm.py

Lines changed: 2 additions & 0 deletions

@@ -85,6 +85,8 @@ def get_messages(
             if i["role"] == "system":
                 system_instruction = i["content"]
             else:
+                if i["role"] not in ("user", "assistant"):
+                    raise ValueError(f"Unknown role: {i['role']}")
                 messages.append(
                     self.anthropic.types.MessageParam(
                         role=i["role"],

src/neo4j_graphrag/llm/base.py

Lines changed: 1 addition & 0 deletions

@@ -159,6 +159,7 @@ def _invoke_with_tools(
     ) -> ToolCallResponse:
         raise NotImplementedError("This LLM provider does not support tool calling.")
 
+    @async_rate_limit_handler
     async def ainvoke_with_tools(
         self,
         input: str,
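
Note: this hunk applies the library's async rate-limit decorator to ainvoke_with_tools so the async tool-calling path is covered as well. The decorator's implementation is not part of this commit; the snippet below is only a rough, self-contained sketch of what an async rate-limit handler typically does (retry with exponential backoff). The RateLimitError type, retry count, and backoff values are placeholder assumptions, not the library's actual code.

import asyncio
import functools

class RateLimitError(Exception):
    """Placeholder for whatever rate-limit exception a provider raises."""

def async_rate_limit_handler(func, max_retries: int = 3, base_delay: float = 1.0):
    """Sketch: retry an async call with exponential backoff when rate limited."""
    @functools.wraps(func)
    async def wrapper(*args, **kwargs):
        for attempt in range(max_retries + 1):
            try:
                return await func(*args, **kwargs)
            except RateLimitError:
                if attempt == max_retries:
                    raise
                # back off 1s, 2s, 4s, ... before retrying
                await asyncio.sleep(base_delay * 2 ** attempt)
    return wrapper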

src/neo4j_graphrag/llm/cohere_llm.py

Lines changed: 4 additions & 2 deletions

@@ -81,12 +81,14 @@ def get_messages(
         for i in input:
             if i["role"] == "system":
                 messages.append(self.cohere.SystemChatMessageV2(content=i["content"]))
-            if i["role"] == "user":
+            elif i["role"] == "user":
                 messages.append(self.cohere.UserChatMessageV2(content=i["content"]))
-            if i["role"] == "assistant":
+            elif i["role"] == "assistant":
                 messages.append(
                     self.cohere.AssistantChatMessageV2(content=i["content"])
                 )
+            else:
+                raise ValueError(f"Unknown role: {i['role']}")
         return messages
 
     def _invoke(
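
The switch from independent `if` checks to an `if`/`elif`/`else` chain changes the failure mode: previously a message with an unrecognised role matched none of the branches and was silently dropped from the request, whereas now conversion fails fast. A stand-alone sketch of the new behaviour, where plain dicts stand in for cohere's SystemChatMessageV2 / UserChatMessageV2 / AssistantChatMessageV2 classes and to_chat_messages is an illustrative name only:

def to_chat_messages(history: list[dict[str, str]]) -> list[dict[str, str]]:
    """Sketch of the new validation; dicts stand in for the cohere SDK message classes."""
    messages = []
    for m in history:
        if m["role"] == "system":
            messages.append({"kind": "system", "content": m["content"]})
        elif m["role"] == "user":
            messages.append({"kind": "user", "content": m["content"]})
        elif m["role"] == "assistant":
            messages.append({"kind": "assistant", "content": m["content"]})
        else:
            # With the old independent `if` checks this message was silently skipped.
            raise ValueError(f"Unknown role: {m['role']}")
    return messages

try:
    to_chat_messages([{"role": "tool", "content": "..."}])
except ValueError as e:
    print(e)  # -> Unknown role: tool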

src/neo4j_graphrag/llm/mistralai_llm.py

Lines changed: 1 addition & 0 deletions

@@ -86,6 +86,7 @@ def get_messages(
             if m["role"] == "assistant":
                 messages.append(AssistantMessage(content=m["content"]))
                 continue
+            raise ValueError(f"Unknown role: {m['role']}")
         return messages
 
     def _invoke(

src/neo4j_graphrag/llm/vertexai_llm.py

Lines changed: 5 additions & 3 deletions

@@ -97,25 +97,27 @@ def get_messages(
         messages = []
         system_instruction = self.system_instruction
         for message in input:
-            if message.get("role") == "system":
+            role = message.get("role")
+            if role == "system":
                 system_instruction = message.get("content")
                 continue
-            if message.get("role") == "user":
+            if role == "user":
                 messages.append(
                     Content(
                         role="user",
                         parts=[Part.from_text(message.get("content", ""))],
                     )
                 )
                 continue
-            if message.get("role") == "assistant":
+            if role == "assistant":
                 messages.append(
                     Content(
                         role="model",
                         parts=[Part.from_text(message.get("content", ""))],
                     )
                 )
                 continue
+            raise ValueError(f"Unknown role: {role}")
         return system_instruction, messages
 
     def _invoke(
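
Besides hoisting message.get("role") into a local variable, this hunk makes the role mapping exhaustive: "system" content is pulled into system_instruction, "user" stays "user", "assistant" is sent as "model" (the label Gemini's chat API uses for the assistant turn), and anything else now raises. A stand-alone sketch of that mapping, with no vertexai import; ROLE_TO_VERTEX and map_role are illustrative names, not part of the library:

ROLE_TO_VERTEX = {
    "user": "user",
    "assistant": "model",  # Gemini's chat API labels the assistant turn "model"
}

def map_role(role: str) -> str:
    # "system" is handled separately (routed to system_instruction), so it is
    # deliberately absent from the map here.
    try:
        return ROLE_TO_VERTEX[role]
    except KeyError:
        raise ValueError(f"Unknown role: {role}") from None

assert map_role("assistant") == "model"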

tests/unit/llm/test_anthropic_llm.py

Lines changed: 9 additions & 0 deletions

@@ -83,6 +83,15 @@ def test_anthropic_llm_get_messages_without_system_instructions() -> None:
     assert actual["content"] == expected["content"]
 
 
+def test_anthropic_llm_get_messages_unknown_role() -> None:
+    llm = AnthropicLLM(api_key="my key", model_name="claude")
+    message_history = [
+        LLMMessage(**{"role": "unknown role", "content": "Usually around 6am."}),  # type: ignore[typeddict-item]
+    ]
+    with pytest.raises(ValueError, match="Unknown role"):
+        llm.get_messages(message_history)
+
+
 def test_anthropic_invoke_happy_path(mock_anthropic: Mock) -> None:
     mock_anthropic.Anthropic.return_value.messages.create.return_value = MagicMock(
         content=[MagicMock(text="generated text")]

tests/unit/llm/test_vertexai_llm.py

Lines changed: 3 additions & 6 deletions

@@ -16,7 +16,6 @@
 from unittest.mock import AsyncMock, MagicMock, Mock, patch
 
 import pytest
-from neo4j_graphrag.exceptions import LLMGenerationError
 from neo4j_graphrag.llm.types import ToolCallResponse
 from neo4j_graphrag.llm.vertexai_llm import VertexAILLM
 from neo4j_graphrag.tool import Tool
@@ -89,16 +88,14 @@ def test_vertexai_get_messages(GenerativeModelMock: MagicMock) -> None:
 
 
 @patch("neo4j_graphrag.llm.vertexai_llm.GenerativeModel")
 def test_vertexai_get_messages_validation_error(GenerativeModelMock: MagicMock) -> None:
-    system_instruction = "You are a helpful assistant."
     model_name = "gemini-1.5-flash-001"
-    question = "hi!"
     message_history = [
         LLMMessage(**{"role": "model", "content": "hello!"}),  # type: ignore[typeddict-item]
     ]
 
-    llm = VertexAILLM(model_name=model_name, system_instruction=system_instruction)
-    with pytest.raises(LLMGenerationError, match="Input validation failed"):
-        llm.invoke(question, message_history)
+    llm = VertexAILLM(model_name=model_name)
+    with pytest.raises(ValueError, match="Unknown role"):
+        llm.get_messages(message_history)
 
 
 @pytest.mark.asyncio
