Skip to content

Commit 23c9ae3

Browse files
authored
Fix tool dumping issues (#418)
The issue was caused by two JSON Schema violations: (1) the `additionalProperties` field was missing from nested objects, and (2) individual parameters carried an invalid `"required": True` entry (`required` must be an array at the object level, not a boolean on each parameter). Both issues are now resolved: the generated JSON Schema is well-formed and meets the requirements of OpenAI's API.
1 parent 4299faa commit 23c9ae3

File tree

3 files changed

+243
-13
lines changed

3 files changed

+243
-13
lines changed

src/neo4j_graphrag/tool.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -25,8 +25,6 @@ class ToolParameter(BaseModel):
2525
def model_dump_tool(self) -> Dict[str, Any]:
2626
"""Convert the parameter to a dictionary format for tool usage."""
2727
result: Dict[str, Any] = {"type": self.type, "description": self.description}
28-
if self.required:
29-
result["required"] = True
3028
return result
3129

3230
@classmethod
@@ -183,8 +181,8 @@ def model_dump_tool(self, exclude: Optional[list[str]] = None) -> Dict[str, Any]
183181
if self.required_properties and "required" not in exclude:
184182
result["required"] = self.required_properties
185183

186-
if not self.additional_properties and "additional_properties" not in exclude:
187-
result["additionalProperties"] = False
184+
if "additional_properties" not in exclude:
185+
result["additionalProperties"] = self.additional_properties
188186

189187
return result
190188

tests/unit/retrievers/test_retriever_parameter_inference.py

Lines changed: 103 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
Tests for retriever parameter inference and convert_to_tool functionality.
1919
"""
2020

21+
import pytest
2122
from unittest.mock import MagicMock, patch
2223
from typing import Optional, Any, Dict
2324

@@ -468,3 +469,105 @@ def get_search_results(self, test_param: str) -> RawSearchResult:
468469

469470
# Should use fallback description
470471
assert properties["test_param"].description == "Parameter test_param"
472+
473+
474+
class TestOpenAICompatibilityFix:
475+
"""Test the specific fixes for OpenAI API compatibility."""
476+
477+
@patch("neo4j_graphrag.retrievers.base.get_version")
478+
def test_text2cypher_retriever_openai_schema_compatibility(self, mock_get_version):
479+
"""Test that Text2CypherRetriever generates OpenAI-compatible schema.
480+
481+
This test specifically covers the bug that was causing:
482+
'Invalid schema for function 't2c_retriever': True is not of type 'array''
483+
"""
484+
mock_get_version.return_value = ((5, 20, 0), False, False)
485+
486+
driver = create_mock_driver()
487+
llm = create_mock_llm()
488+
retriever = Text2CypherRetriever(
489+
driver=driver, llm=llm, neo4j_schema="(Person)-[:KNOWS]->(Person)"
490+
)
491+
492+
# Convert to tool (this is where the original bug occurred)
493+
tool = retriever.convert_to_tool(
494+
name="t2c_retriever",
495+
description="Use this tool when no other tool can help. It will directly try to build a Cypher query to query the graph.",
496+
)
497+
498+
# Get the tool parameters schema
499+
schema = tool.get_parameters()
500+
501+
# Verify JSON Schema structure is correct for OpenAI
502+
assert schema["type"] == "object"
503+
assert "properties" in schema
504+
assert "required" in schema
505+
assert "additionalProperties" in schema
506+
507+
# Check that required is an array, not a boolean
508+
assert isinstance(schema["required"], list)
509+
assert "query_text" in schema["required"]
510+
511+
# Check individual properties don't have 'required' field
512+
for prop_name, prop_schema in schema["properties"].items():
513+
assert (
514+
"required" not in prop_schema
515+
), f"Property {prop_name} should not have 'required' field"
516+
517+
# Check the specific property that was causing issues
518+
prompt_params_schema = schema["properties"]["prompt_params"]
519+
assert prompt_params_schema["type"] == "object"
520+
assert "additionalProperties" in prompt_params_schema
521+
assert prompt_params_schema["additionalProperties"] is True
522+
523+
# Ensure the schema is valid JSON Schema format
524+
import json
525+
526+
try:
527+
# This should not raise any exceptions
528+
json_str = json.dumps(schema)
529+
parsed = json.loads(json_str)
530+
assert parsed == schema
531+
except (TypeError, ValueError) as e:
532+
pytest.fail(f"Schema is not JSON serializable: {e}")
533+
534+
@patch("neo4j_graphrag.retrievers.base.get_version")
535+
def test_tools_retriever_with_t2c_tool_integration(self, mock_get_version):
536+
"""Integration test showing the full ToolsRetriever + Text2CypherRetriever workflow."""
537+
mock_get_version.return_value = ((5, 20, 0), False, False)
538+
539+
driver = create_mock_driver()
540+
llm = create_mock_llm()
541+
542+
# Create a Text2CypherRetriever
543+
t2c_retriever = Text2CypherRetriever(
544+
driver=driver, llm=llm, neo4j_schema="(Movie)-[:ACTED_IN]-(Person)"
545+
)
546+
547+
# Convert it to a tool (this was failing before the fix)
548+
t2c_tool = t2c_retriever.convert_to_tool(
549+
name="t2c_retriever",
550+
description="Generate Cypher queries from natural language",
551+
)
552+
553+
# Create ToolsRetriever with the t2c_tool
554+
tools_retriever = ToolsRetriever(driver=driver, llm=llm, tools=[t2c_tool])
555+
556+
# Verify that the tools_retriever was created successfully
557+
assert len(tools_retriever._tools) == 1
558+
assert tools_retriever._tools[0].get_name() == "t2c_retriever"
559+
560+
# Get the tool's parameters to verify schema structure
561+
tool_params = t2c_tool.get_parameters()
562+
563+
# This should have the correct structure that OpenAI expects
564+
assert tool_params["type"] == "object"
565+
assert isinstance(tool_params["required"], list)
566+
assert "additionalProperties" in tool_params
567+
568+
# All nested objects should also have additionalProperties
569+
for prop_name, prop_schema in tool_params["properties"].items():
570+
if prop_schema.get("type") == "object":
571+
assert (
572+
"additionalProperties" in prop_schema
573+
), f"Nested object {prop_name} missing additionalProperties"

tests/unit/tool/test_tool.py

Lines changed: 138 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,8 @@ def test_string_parameter() -> None:
2121
d = param.model_dump_tool()
2222
assert d["type"] == ParameterType.STRING
2323
assert d["enum"] == ["a", "b"]
24-
assert d["required"] is True
24+
# Note: 'required' is handled at the object level, not individual parameter level
25+
assert "required" not in d
2526

2627

2728
def test_integer_parameter() -> None:
@@ -141,38 +142,166 @@ def test_from_dict() -> None:
141142

142143

143144
def test_required_parameter() -> None:
144-
# Test that required=True is included in model_dump_tool output for different parameter types
145+
# Test that individual parameters don't include 'required' field (it's handled at object level)
145146
string_param = StringParameter(description="Required string", required=True)
146-
assert string_param.model_dump_tool()["required"] is True
147+
assert "required" not in string_param.model_dump_tool()
147148

148149
integer_param = IntegerParameter(description="Required integer", required=True)
149-
assert integer_param.model_dump_tool()["required"] is True
150+
assert "required" not in integer_param.model_dump_tool()
150151

151152
number_param = NumberParameter(description="Required number", required=True)
152-
assert number_param.model_dump_tool()["required"] is True
153+
assert "required" not in number_param.model_dump_tool()
153154

154155
boolean_param = BooleanParameter(description="Required boolean", required=True)
155-
assert boolean_param.model_dump_tool()["required"] is True
156+
assert "required" not in boolean_param.model_dump_tool()
156157

157158
array_param = ArrayParameter(
158159
description="Required array",
159160
items=StringParameter(description="item"),
160161
required=True,
161162
)
162-
assert array_param.model_dump_tool()["required"] is True
163+
assert "required" not in array_param.model_dump_tool()
163164

164165
object_param = ObjectParameter(
165166
description="Required object",
166167
properties={"prop": StringParameter(description="property")},
167168
required=True,
168169
)
169-
assert object_param.model_dump_tool()["required"] is True
170+
assert "required" not in object_param.model_dump_tool()
170171

171-
# Test that required=False doesn't include the required field
172+
# Test that optional parameters also don't include the required field
172173
optional_param = StringParameter(description="Optional string", required=False)
173174
assert "required" not in optional_param.model_dump_tool()
174175

175176

177+
def test_object_parameter_additional_properties_always_present() -> None:
178+
"""Test that additionalProperties is always present in ObjectParameter schema, fixing OpenAI compatibility."""
179+
180+
# Test additionalProperties=True (default)
181+
obj_param_true = ObjectParameter(
182+
description="Object with additional properties",
183+
properties={"prop": StringParameter(description="A property")},
184+
additional_properties=True,
185+
)
186+
schema_true = obj_param_true.model_dump_tool()
187+
assert "additionalProperties" in schema_true
188+
assert schema_true["additionalProperties"] is True
189+
190+
# Test additionalProperties=False
191+
obj_param_false = ObjectParameter(
192+
description="Object without additional properties",
193+
properties={"prop": StringParameter(description="A property")},
194+
additional_properties=False,
195+
)
196+
schema_false = obj_param_false.model_dump_tool()
197+
assert "additionalProperties" in schema_false
198+
assert schema_false["additionalProperties"] is False
199+
200+
201+
def test_json_schema_compatibility() -> None:
202+
"""Test that the generated schema is compatible with JSON Schema specification."""
203+
204+
# Create a complex object with nested properties and required fields
205+
nested_obj = ObjectParameter(
206+
description="Nested object",
207+
properties={
208+
"nested_prop": StringParameter(description="Nested string"),
209+
},
210+
additional_properties=True,
211+
)
212+
213+
main_obj = ObjectParameter(
214+
description="Main object",
215+
properties={
216+
"required_string": StringParameter(description="Required string"),
217+
"optional_number": NumberParameter(description="Optional number"),
218+
"nested_object": nested_obj,
219+
},
220+
required_properties=["required_string"],
221+
additional_properties=False,
222+
)
223+
224+
schema = main_obj.model_dump_tool()
225+
226+
# Verify JSON Schema structure
227+
assert schema["type"] == "object"
228+
assert "properties" in schema
229+
assert "required" in schema
230+
assert "additionalProperties" in schema
231+
232+
# Check required is an array (not boolean on individual properties)
233+
assert isinstance(schema["required"], list)
234+
assert "required_string" in schema["required"]
235+
assert len(schema["required"]) == 1
236+
237+
# Check individual properties don't have 'required' field
238+
for prop_name, prop_schema in schema["properties"].items():
239+
assert "required" not in prop_schema
240+
241+
# Check additionalProperties is properly set at all levels
242+
assert schema["additionalProperties"] is False
243+
assert schema["properties"]["nested_object"]["additionalProperties"] is True
244+
245+
246+
def test_text2cypher_retriever_schema_compatibility() -> None:
247+
"""Test the specific schema structure that caused the OpenAI API error."""
248+
249+
# Simulate the Text2CypherRetriever parameter structure
250+
prompt_params = ObjectParameter(
251+
description="Parameter prompt_params",
252+
properties={},
253+
additional_properties=True, # This was missing in the original bug
254+
)
255+
256+
t2c_params = ObjectParameter(
257+
description="Parameters for Text2CypherRetriever",
258+
properties={
259+
"query_text": StringParameter(description="Parameter query_text"),
260+
"prompt_params": prompt_params,
261+
},
262+
required_properties=["query_text"],
263+
additional_properties=False,
264+
)
265+
266+
schema = t2c_params.model_dump_tool()
267+
268+
# Verify the fix: prompt_params should have additionalProperties
269+
prompt_params_schema = schema["properties"]["prompt_params"]
270+
assert "additionalProperties" in prompt_params_schema
271+
assert prompt_params_schema["additionalProperties"] is True
272+
273+
# Verify query_text doesn't have individual 'required' field
274+
query_text_schema = schema["properties"]["query_text"]
275+
assert "required" not in query_text_schema
276+
277+
# Verify required array at object level
278+
assert schema["required"] == ["query_text"]
279+
280+
281+
def test_exclude_parameter_in_object_schema() -> None:
282+
"""Test that exclude parameter works correctly in ObjectParameter.model_dump_tool()."""
283+
284+
obj_param = ObjectParameter(
285+
description="Test object",
286+
properties={
287+
"prop1": StringParameter(description="Property 1"),
288+
"prop2": IntegerParameter(description="Property 2"),
289+
},
290+
required_properties=["prop1"],
291+
additional_properties=True,
292+
)
293+
294+
# Test excluding required field
295+
schema_no_required = obj_param.model_dump_tool(exclude=["required"])
296+
assert "required" not in schema_no_required
297+
assert "additionalProperties" in schema_no_required # Should still be present
298+
299+
# Test excluding additionalProperties field
300+
schema_no_additional = obj_param.model_dump_tool(exclude=["additional_properties"])
301+
assert "additionalProperties" not in schema_no_additional
302+
assert "required" in schema_no_additional # Should still be present
303+
304+
176305
def test_tool_class() -> None:
177306
def dummy_func(**kwargs: Any) -> dict[str, Any]:
178307
return kwargs

0 commit comments

Comments
 (0)