Skip to content

Commit 1306f49

Browse files
oskarhanestellasia
authored andcommitted
Add ToolsRetriever class and convert_retriever_to_tool() fn
1 parent 65a85c7 commit 1306f49

File tree

17 files changed

+1533
-15
lines changed

17 files changed

+1533
-15
lines changed

examples/customize/llms/openai_tool_calls.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,12 @@
1717

1818
from neo4j_graphrag.llm import OpenAILLM
1919
from neo4j_graphrag.llm.types import ToolCallResponse
20-
from neo4j_graphrag.tool import Tool, ObjectParameter, StringParameter, IntegerParameter
20+
from neo4j_graphrag.tools.tool import (
21+
Tool,
22+
ObjectParameter,
23+
StringParameter,
24+
IntegerParameter,
25+
)
2126

2227
# Load environment variables from .env file (OPENAI_API_KEY required for this example)
2328
load_dotenv()

examples/customize/llms/vertexai_tool_calls.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,12 @@
1111

1212
from neo4j_graphrag.llm import VertexAILLM
1313
from neo4j_graphrag.llm.types import ToolCallResponse
14-
from neo4j_graphrag.tool import Tool, ObjectParameter, StringParameter, IntegerParameter
14+
from neo4j_graphrag.tools.tool import (
15+
Tool,
16+
ObjectParameter,
17+
StringParameter,
18+
IntegerParameter,
19+
)
1520

1621
# Load environment variables from .env file
1722
load_dotenv()
Lines changed: 121 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,121 @@
1+
# Copyright (c) "Neo4j"
2+
# Neo4j Sweden AB [https://neo4j.com]
3+
# #
4+
# Licensed under the Apache License, Version 2.0 (the "License");
5+
# you may not use this file except in compliance with the License.
6+
# You may obtain a copy of the License at
7+
# #
8+
# https://www.apache.org/licenses/LICENSE-2.0
9+
# #
10+
# Unless required by applicable law or agreed to in writing, software
11+
# distributed under the License is distributed on an "AS IS" BASIS,
12+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
# See the License for the specific language governing permissions and
14+
# limitations under the License.
15+
#
16+
"""
17+
Example demonstrating how to convert a retriever to a tool.
18+
19+
This example shows:
20+
1. How to convert a custom StaticRetriever to a Tool
21+
2. How to define parameters for the tool
22+
3. How to execute the tool
23+
"""
24+
25+
import neo4j
26+
from typing import Optional, Any, cast
27+
from unittest.mock import MagicMock
28+
29+
from neo4j_graphrag.retrievers.base import Retriever
30+
from neo4j_graphrag.types import RawSearchResult
31+
from neo4j_graphrag.tools.tool import (
32+
StringParameter,
33+
ObjectParameter,
34+
)
35+
from neo4j_graphrag.tools.utils import convert_retriever_to_tool
36+
37+
38+
# Create a Retriever that returns static results about Neo4j
39+
# This would illustrate the conversion process of any Retriever (Vector, Hybrid, etc.)
40+
class StaticRetriever(Retriever):
41+
"""A retriever that returns static results about Neo4j."""
42+
43+
# Disable Neo4j version verification
44+
VERIFY_NEO4J_VERSION = False
45+
46+
def __init__(self, driver: neo4j.Driver):
47+
# Call the parent class constructor with the driver
48+
super().__init__(driver)
49+
50+
def get_search_results(
51+
self, query_text: Optional[str] = None, **kwargs: Any
52+
) -> RawSearchResult:
53+
"""Return static information about Neo4j regardless of the query."""
54+
# Create formatted Neo4j information
55+
neo4j_info = (
56+
"# Neo4j Graph Database\n\n"
57+
"Neo4j is a graph database management system developed by Neo4j, Inc. "
58+
"It is an ACID-compliant transactional database with native graph storage and processing.\n\n"
59+
"## Key Features:\n\n"
60+
"- **Cypher Query Language**: Neo4j's intuitive query language designed specifically for working with graph data\n"
61+
"- **Property Graphs**: Both nodes and relationships can have properties (key-value pairs)\n"
62+
"- **ACID Compliance**: Ensures data integrity with full transaction support\n"
63+
"- **Native Graph Storage**: Optimized storage for graph data structures\n"
64+
"- **High Availability**: Clustering for enterprise deployments\n"
65+
"- **Scalability**: Handles billions of nodes and relationships"
66+
)
67+
68+
# Create a Neo4j record with the information
69+
records = [neo4j.Record({"result": neo4j_info})]
70+
71+
# Return a RawSearchResult with the records and metadata
72+
return RawSearchResult(records=records, metadata={"query": query_text})
73+
74+
75+
def main() -> None:
76+
# Convert a StaticRetriever to a tool with specific parameters
77+
static_retriever = StaticRetriever(driver=cast(Any, MagicMock()))
78+
79+
# Define parameters for the static retriever tool
80+
static_parameters = ObjectParameter(
81+
description="Parameters for the Neo4j information retriever",
82+
properties={
83+
"query_text": StringParameter(
84+
description="The query about Neo4j (any query will return general Neo4j information)",
85+
required=True,
86+
),
87+
},
88+
)
89+
90+
# Convert the retriever to a tool with specific parameters
91+
static_tool = convert_retriever_to_tool(
92+
retriever=static_retriever,
93+
description="Get general information about Neo4j graph database",
94+
parameters=static_parameters,
95+
name="Neo4jInfoTool",
96+
)
97+
98+
# Print tool information
99+
print("Example: StaticRetriever with specific parameters")
100+
print(f"Tool Name: {static_tool.get_name()}")
101+
print(f"Tool Description: {static_tool.get_description()}")
102+
print(f"Tool Parameters: {static_tool.get_parameters()}")
103+
print()
104+
105+
# Execute the tools (in a real application, this would be done by instructions from an LLM)
106+
try:
107+
# Execute the static retriever tool
108+
print("\nExecuting the static retriever tool...")
109+
static_result = static_tool.execute(
110+
query="What is Neo4j?",
111+
)
112+
print("Static Search Results:")
113+
for i, item in enumerate(static_result):
114+
print(f"{i + 1}. {str(item)[:100]}...")
115+
116+
except Exception as e:
117+
print(f"Error executing tool: {e}")
118+
119+
120+
if __name__ == "__main__":
121+
main()

0 commit comments

Comments
 (0)