Skip to content

Commit 4299faa

Browse files
authored
Add ToolsRetriever class and Retriever.convert_to_tool() method (#332)
* Add ToolsRetriever class and convert_retriever_to_tool() fn
1 parent 09a65fc commit 4299faa

File tree

13 files changed

+2327
-10
lines changed

13 files changed

+2327
-10
lines changed

CHANGELOG.md

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,11 +2,15 @@
22

33
## Next
44

5+
### Added
6+
7+
- Added a `ToolsRetriever` retriever that uses an LLM to decide on what tools to use to find the relevant data.
8+
- Added `convert_to_tool` method to the `Retriever` interface to convert a Retriever to a Tool so it can be used within the ToolsRetriever. This is useful when you might want to have both a VectorRetriever and a Text2CypherRetriever as a fallback.
9+
510
### Fixed
611

712
- Fixed an edge case where the LLM can output a property with type 'map', which was causing errors during import as it is not a valid property type in Neo4j.
813

9-
1014
## 1.9.1
1115

1216
### Fixed
@@ -26,6 +30,7 @@
2630

2731
- Added automatic rate limiting with retry logic and exponential backoff for all LLM providers using tenacity. The `RateLimitHandler` interface allows for custom rate limiting strategies, including the ability to disable rate limiting entirely.
2832

33+
2934
## 1.8.0
3035

3136
### Added

examples/customize/llms/openai_tool_calls.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,12 @@
1717

1818
from neo4j_graphrag.llm import OpenAILLM
1919
from neo4j_graphrag.llm.types import ToolCallResponse
20-
from neo4j_graphrag.tool import Tool, ObjectParameter, StringParameter, IntegerParameter
20+
from neo4j_graphrag.tool import (
21+
Tool,
22+
ObjectParameter,
23+
StringParameter,
24+
IntegerParameter,
25+
)
2126

2227
# Load environment variables from .env file (OPENAI_API_KEY required for this example)
2328
load_dotenv()

examples/customize/llms/vertexai_tool_calls.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,12 @@
1111

1212
from neo4j_graphrag.llm import VertexAILLM
1313
from neo4j_graphrag.llm.types import ToolCallResponse
14-
from neo4j_graphrag.tool import Tool, ObjectParameter, StringParameter, IntegerParameter
14+
from neo4j_graphrag.tool import (
15+
Tool,
16+
ObjectParameter,
17+
StringParameter,
18+
IntegerParameter,
19+
)
1520

1621
# Load environment variables from .env file
1722
load_dotenv()
Lines changed: 151 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,151 @@
1+
# Copyright (c) "Neo4j"
2+
# Neo4j Sweden AB [https://neo4j.com]
3+
# #
4+
# Licensed under the Apache License, Version 2.0 (the "License");
5+
# you may not use this file except in compliance with the License.
6+
# You may obtain a copy of the License at
7+
# #
8+
# https://www.apache.org/licenses/LICENSE-2.0
9+
# #
10+
# Unless required by applicable law or agreed to in writing, software
11+
# distributed under the License is distributed on an "AS IS" BASIS,
12+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
# See the License for the specific language governing permissions and
14+
# limitations under the License.
15+
#
16+
"""
17+
Example demonstrating how to create multiple domain-specific tools from retrievers.
18+
19+
This example shows:
20+
1. How to create multiple tools from the same retriever type for different use cases
21+
2. How to provide custom parameter descriptions for each tool
22+
3. How type inference works automatically while descriptions are explicit
23+
"""
24+
25+
import neo4j
26+
from typing import cast, Any, Optional
27+
from unittest.mock import MagicMock
28+
29+
from neo4j_graphrag.retrievers.base import Retriever
30+
from neo4j_graphrag.types import RawSearchResult
31+
32+
33+
class MockVectorRetriever(Retriever):
    """Stand-in vector retriever used to demonstrate tool conversion.

    No search is performed: every call to ``get_search_results`` yields an
    empty record list whose metadata names the index this instance was
    created with.
    """

    # Turn off Neo4j version verification for this mock.
    VERIFY_NEO4J_VERSION = False

    def __init__(self, driver: neo4j.Driver, index_name: str):
        super().__init__(driver)
        # Remembered only so it can be echoed back in the result metadata.
        self.index_name = index_name

    def get_search_results(
        self,
        query_vector: Optional[list[float]] = None,
        query_text: Optional[str] = None,
        top_k: int = 5,
        effective_search_ratio: int = 1,
        filters: Optional[dict[str, Any]] = None,
    ) -> RawSearchResult:
        """Return a mocked, empty vector search result.

        All search arguments are accepted but ignored.

        Returns:
            RawSearchResult: No records; metadata is ``{"index": <index_name>}``.
        """
        index_metadata = {"index": self.index_name}
        return RawSearchResult(records=[], metadata=index_metadata)
53+
54+
55+
def main() -> None:
    """Build three domain-specific tools from one retriever type and print them.

    Each tool wraps a ``MockVectorRetriever`` created for a different index
    name and is given its own name, description and parameter descriptions.
    The function then prints every tool's parameters and finally compares how
    the shared ``query_text`` and ``top_k`` parameters are described per tool.
    """
    # A MagicMock stands in for the driver; real usage would pass a Neo4j driver.
    mock_driver = cast(Any, MagicMock())

    # One retriever per domain, all of the same retriever type.
    # In practice, these would point to different vector indexes.
    movie_retriever = MockVectorRetriever(
        driver=mock_driver, index_name="movie_embeddings"
    )
    product_retriever = MockVectorRetriever(
        driver=mock_driver, index_name="product_embeddings"
    )
    document_retriever = MockVectorRetriever(
        driver=mock_driver, index_name="document_embeddings"
    )

    # Movie recommendation tool.
    movie_tool = movie_retriever.convert_to_tool(
        name="movie_search",
        description="Find movie recommendations based on plot, genre, or actor preferences",
        parameter_descriptions={
            "query_text": "Movie title, plot description, genre, or actor name",
            "query_vector": "Pre-computed embedding vector for movie search",
            "top_k": "Number of movie recommendations to return (1-20)",
            "filters": "Optional filters for genre, year, rating, etc.",
            "effective_search_ratio": "Search pool multiplier for better accuracy",
        },
    )

    # Product search tool.
    product_tool = product_retriever.convert_to_tool(
        name="product_search",
        description="Search for products matching customer needs and preferences",
        parameter_descriptions={
            "query_text": "Product name, description, or customer need",
            "query_vector": "Pre-computed embedding for product matching",
            "top_k": "Maximum number of product results (1-50)",
            "filters": "Filters for price range, brand, category, availability",
            "effective_search_ratio": "Breadth vs precision trade-off for search",
        },
    )

    # Document search tool; no description is supplied here for
    # "effective_search_ratio".
    document_tool = document_retriever.convert_to_tool(
        name="document_search",
        description="Find relevant documents and knowledge articles",
        parameter_descriptions={
            "query_text": "Question, keywords, or topic to search for",
            "query_vector": "Semantic embedding for document retrieval",
            "top_k": "Number of relevant documents to retrieve (1-10)",
            "filters": "Document type, date range, or department filters",
        },
    )

    domain_tools = [movie_tool, product_tool, document_tool]

    # Per-tool report: name, description, then one line per parameter.
    for domain_tool in domain_tools:
        print(f"\n=== {domain_tool.get_name().upper()} ===")
        print(f"Description: {domain_tool.get_description()}")
        print("Parameters:")

        schema = domain_tool.get_parameters()
        required_names = schema.get("required", [])
        for param_name, param_def in schema["properties"].items():
            if param_name in required_names:
                required = "required"
            else:
                required = "optional"
            print(
                f" - {param_name} ({param_def['type']}, {required}): {param_def['description']}"
            )

    def print_descriptions(shared_param: str) -> None:
        """Print each tool's description of the parameter ``shared_param``."""
        for domain_tool in domain_tools:
            properties = domain_tool.get_parameters()["properties"]
            description = properties[shared_param]["description"]
            print(f" {domain_tool.get_name()}: {description}")

    # The same parameter name carries a different description in each tool.
    print("\n=== PARAMETER COMPARISON ===")
    print("Same parameter 'query_text' with different contextual descriptions:")
    print_descriptions("query_text")

    print("\nSame parameter 'top_k' with different contextual descriptions:")
    print_descriptions("top_k")
148+
149+
150+
if __name__ == "__main__":
151+
main()
Lines changed: 114 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,114 @@
1+
# Copyright (c) "Neo4j"
2+
# Neo4j Sweden AB [https://neo4j.com]
3+
# #
4+
# Licensed under the Apache License, Version 2.0 (the "License");
5+
# you may not use this file except in compliance with the License.
6+
# You may obtain a copy of the License at
7+
# #
8+
# https://www.apache.org/licenses/LICENSE-2.0
9+
# #
10+
# Unless required by applicable law or agreed to in writing, software
11+
# distributed under the License is distributed on an "AS IS" BASIS,
12+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
# See the License for the specific language governing permissions and
14+
# limitations under the License.
15+
#
16+
"""
17+
Example demonstrating how to convert a retriever to a tool.
18+
19+
This example shows:
20+
1. How to convert a custom StaticRetriever to a Tool using the convert_to_tool method
21+
2. How to define parameters for the tool in the retriever class
22+
3. How to execute the tool
23+
"""
24+
25+
import neo4j
26+
from typing import Optional, Any, cast
27+
from unittest.mock import MagicMock
28+
29+
from neo4j_graphrag.retrievers.base import Retriever
30+
from neo4j_graphrag.types import RawSearchResult
31+
32+
33+
# Create a Retriever that returns static results about Neo4j
34+
# This would illustrate the conversion process of any Retriever (Vector, Hybrid, etc.)
35+
class StaticRetriever(Retriever):
    """Retriever that always answers with the same fixed text about Neo4j."""

    # Disable Neo4j version verification
    VERIFY_NEO4J_VERSION = False

    def __init__(self, driver: neo4j.Driver):
        # Nothing extra to set up; the base class receives the driver.
        super().__init__(driver)

    def get_search_results(
        self, query_text: Optional[str] = None, **kwargs: Any
    ) -> RawSearchResult:
        """Return the fixed Neo4j overview, whatever the query is.

        Args:
            query_text (Optional[str]): The query; it does not change the
                returned text and is only echoed back in the metadata.
            **kwargs (Any): Accepted and ignored.

        Returns:
            RawSearchResult: A single record holding the overview under the
                ``"result"`` key, with metadata ``{"query": query_text}``.
        """
        introduction = (
            "# Neo4j Graph Database\n\n"
            "Neo4j is a graph database management system developed by Neo4j, Inc. "
            "It is an ACID-compliant transactional database with native graph storage and processing.\n\n"
        )
        feature_lines = [
            "- **Cypher Query Language**: Neo4j's intuitive query language designed specifically for working with graph data\n",
            "- **Property Graphs**: Both nodes and relationships can have properties (key-value pairs)\n",
            "- **ACID Compliance**: Ensures data integrity with full transaction support\n",
            "- **Native Graph Storage**: Optimized storage for graph data structures\n",
            "- **High Availability**: Clustering for enterprise deployments\n",
            "- **Scalability**: Handles billions of nodes and relationships",
        ]
        neo4j_info = introduction + "## Key Features:\n\n" + "".join(feature_lines)

        # Wrap the text in a single Neo4j record.
        static_record = neo4j.Record({"result": neo4j_info})
        return RawSearchResult(
            records=[static_record], metadata={"query": query_text}
        )
76+
77+
78+
def main() -> None:
    """Convert a ``StaticRetriever`` into a tool, describe it, and run it once."""
    # The retriever gets a MagicMock in place of a driver.
    static_retriever = StaticRetriever(driver=cast(Any, MagicMock()))

    # Convert the retriever to a tool with a custom parameter description.
    descriptions = {
        "query_text": "Any query about Neo4j (the tool returns general information regardless)"
    }
    static_tool = static_retriever.convert_to_tool(
        name="Neo4jInfoTool",
        description="Get general information about Neo4j graph database",
        parameter_descriptions=descriptions,
    )

    # Show what the resulting tool exposes.
    print("Example: StaticRetriever with specific parameters")
    print(f"Tool Name: {static_tool.get_name()}")
    print(f"Tool Description: {static_tool.get_description()}")
    print(f"Tool Parameters: {static_tool.get_parameters()}")
    print()

    # Run the tool directly (in a real application, an LLM would decide when
    # to call it). Any failure is reported rather than raised so the example
    # always finishes.
    try:
        print("\nExecuting the static retriever tool...")
        static_result = static_tool.execute(
            query_text="What is Neo4j?",
        )
        print("Static Search Results:")
        for position, item in enumerate(static_result, start=1):
            # Only the first 100 characters of each item are shown.
            print(f"{position}. {str(item)[:100]}...")
    except Exception as error:
        print(f"Error executing tool: {error}")
111+
112+
113+
if __name__ == "__main__":
114+
main()

0 commit comments

Comments
 (0)