Skip to content

Commit 7b29fc9

Browse files
oskarhanestellasia
authored andcommitted
Address PR comments: improve validation and remove redundant code
- Remove redundant convert_retriever_to_tool function from utils.py - Add validation for tool name uniqueness in ToolsRetriever - Add parameter type validation in Tool constructor - Update convert_to_tool() to use search() method instead of get_search_results() - This ensures retriever result_formatter is applied for consistent formatting - Update ToolsRetriever to handle RetrieverResult objects from formatted tools - Create consistent record structure with tool attribution and metadata Fixes the result formatting inconsistency issue identified in PR review. Each tool now returns consistently formatted results while preserving the original retriever's formatting logic.
1 parent 58a6309 commit 7b29fc9

File tree

7 files changed

+230
-305
lines changed

7 files changed

+230
-305
lines changed

examples/retrieve/tools/tools_retriever_example.py

Lines changed: 8 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,6 @@
3838
StringParameter,
3939
Tool,
4040
)
41-
from neo4j_graphrag.tools.utils import convert_retriever_to_tool
4241
from neo4j_graphrag.llm.openai_llm import OpenAILLM
4342

4443
# Load environment variables from .env file (OPENAI_API_KEY required for this example)
@@ -241,23 +240,13 @@ def main() -> None:
241240
# Create retrievers
242241
neo4j_retriever = Neo4jInfoRetriever(driver=driver)
243242

244-
# Define parameters for the tools
245-
neo4j_parameters = ObjectParameter(
246-
description="Parameters for Neo4j information retrieval",
247-
properties={
248-
"query": StringParameter(
249-
description="The query about Neo4j",
250-
),
251-
},
252-
required_properties=["query"],
253-
)
254-
255243
# Convert retrievers to tools
256-
neo4j_tool = convert_retriever_to_tool(
257-
retriever=neo4j_retriever,
244+
neo4j_tool = neo4j_retriever.convert_to_tool(
258245
name="neo4j_info_tool",
259246
description="Get information about Neo4j graph database",
260-
parameters=neo4j_parameters,
247+
parameter_descriptions={
248+
"query_text": "The query about Neo4j",
249+
},
261250
)
262251

263252
# Create a calendar tool
@@ -325,7 +314,10 @@ def main() -> None:
325314
print("\nRESULTS:")
326315
for i, record in enumerate(result.records):
327316
print(f"\n--- Result {i + 1} ---")
328-
print(record)
317+
print(f"Content: {record.get('content', 'N/A')}")
318+
print(f"Tool: {record.get('tool_name', 'Unknown')}")
319+
if record.get("metadata"):
320+
print(f"Metadata: {record.get('metadata')}")
329321
except Exception as e:
330322
print(f"Error: {str(e)}")
331323
print(f"{'=' * 80}")

src/neo4j_graphrag/retrievers/base.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -430,7 +430,7 @@ def convert_to_tool(
430430

431431
# Define a function that matches the Callable[[str, ...], Any] signature
432432
def execute_func(**kwargs: Any) -> Any:
433-
return self.get_search_results(**kwargs)
433+
return self.search(**kwargs)
434434

435435
# Create a Tool object from the retriever
436436
return Tool(

src/neo4j_graphrag/retrievers/tools_retriever.py

Lines changed: 55 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -55,10 +55,24 @@ def __init__(
5555
super().__init__(driver, neo4j_database)
5656
self.llm = llm
5757
self._tools = list(tools) # Make a copy to allow modification
58+
self._validate_tool_names()
5859
self.system_instruction = (
5960
system_instruction or self._get_default_system_instruction()
6061
)
6162

63+
def _validate_tool_names(self) -> None:
64+
"""Validate that all tool names are unique."""
65+
tool_names = [tool.get_name() for tool in self._tools]
66+
duplicate_names = [
67+
name for name in set(tool_names) if tool_names.count(name) > 1
68+
]
69+
70+
if duplicate_names:
71+
raise ValueError(
72+
f"Duplicate tool names found: {duplicate_names}. "
73+
"All tools must have unique names for proper LLM tool selection."
74+
)
75+
6276
def _get_default_system_instruction(self) -> str:
6377
"""Get the default system instruction for the LLM."""
6478
return (
@@ -129,12 +143,48 @@ def get_search_results(
129143

130144
# Execute the tool with the provided arguments
131145
tool_result = selected_tool.execute(**tool_args)
132-
# If the tool result is a RawSearchResult, extract its records
133-
if hasattr(tool_result, "records"):
134-
all_records.extend(tool_result.records)
146+
147+
# Handle different tool result types
148+
if hasattr(tool_result, "items") and not callable(
149+
getattr(tool_result, "items")
150+
):
151+
# RetrieverResult from formatted retriever tools
152+
for item in tool_result.items:
153+
record = neo4j.Record(
154+
{
155+
"content": item.content,
156+
"tool_name": tool_name,
157+
"metadata": {
158+
**(item.metadata or {}),
159+
"tool": tool_name,
160+
},
161+
}
162+
)
163+
all_records.append(record)
164+
elif hasattr(tool_result, "records"):
165+
# RawSearchResult from raw retriever tools (legacy)
166+
for record in tool_result.records:
167+
# Wrap raw records with tool attribution
168+
attributed_record = neo4j.Record(
169+
{
170+
"content": str(record),
171+
"tool_name": tool_name,
172+
"metadata": {
173+
"original_record": dict(record),
174+
"tool": tool_name,
175+
},
176+
}
177+
)
178+
all_records.append(attributed_record)
135179
else:
136-
# Create a record from the tool result
137-
record = neo4j.Record({"result": tool_result})
180+
# Handle non-retriever tools or simple return values
181+
record = neo4j.Record(
182+
{
183+
"content": str(tool_result),
184+
"tool_name": tool_name,
185+
"metadata": {"tool": tool_name},
186+
}
187+
)
138188
all_records.append(record)
139189

140190
# Combine metadata from all tool calls

src/neo4j_graphrag/tools/tool.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -231,8 +231,13 @@ def __init__(
231231
self._parameters = ObjectParameter.model_validate(parameters)
232232
elif isinstance(parameters, ObjectParameter):
233233
self._parameters = parameters
234-
else:
234+
elif parameters is None:
235235
self._parameters = None
236+
else:
237+
raise TypeError(
238+
f"Parameters must be None, dict, or ObjectParameter, "
239+
f"got {type(parameters).__name__}: {parameters}"
240+
)
236241

237242
def get_name(self) -> str:
238243
"""Get the name of the tool.

src/neo4j_graphrag/tools/utils.py

Lines changed: 0 additions & 76 deletions
This file was deleted.

tests/unit/retrievers/test_retriever_parameter_inference.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -402,7 +402,7 @@ def test_tool_execution(self):
402402

403403
# Check that we get a result (even if empty due to mocking)
404404
assert result is not None
405-
assert hasattr(result, "records")
405+
assert hasattr(result, "items") # Should return RetrieverResult now
406406
assert hasattr(result, "metadata")
407407

408408
def test_tool_execution_with_validation(self):

0 commit comments

Comments
 (0)