@@ -32,6 +32,25 @@ def resolve_refs(schema, root=None):
32
32
return schema
33
33
34
34
35
+ def process_nullable_types (
36
+ schema : Dict [str , Any ] | list [Any ] | Any ,
37
+ ) -> Dict [str , Any ] | list [Any ]:
38
+ """Process the schema to handle nullable types by removing anyOf with null and keeping the base type."""
39
+ if isinstance (schema , dict ):
40
+ if "anyOf" in schema and len (schema ["anyOf" ]) == 2 :
41
+ types = [t .get ("type" ) for t in schema ["anyOf" ]]
42
+ if "null" in types :
43
+ non_null_type = next (
44
+ t for t in schema ["anyOf" ] if t .get ("type" ) != "null"
45
+ )
46
+ return non_null_type
47
+
48
+ return {k : process_nullable_types (v ) for k , v in schema .items ()}
49
+ elif isinstance (schema , list ):
50
+ return [process_nullable_types (item ) for item in schema ]
51
+ return schema
52
+
53
+
35
54
def generate_schema_from_graph (graph : CompiledStateGraph ) -> Dict [str , Any ]:
36
55
"""Extract input/output schema from a LangGraph graph"""
37
56
schema = {
@@ -42,24 +61,29 @@ def generate_schema_from_graph(graph: CompiledStateGraph) -> Dict[str, Any]:
42
61
if hasattr (graph , "input_schema" ):
43
62
if hasattr (graph .input_schema , "model_json_schema" ):
44
63
input_schema = graph .input_schema .model_json_schema ()
45
-
46
64
unpacked_ref_def_properties = resolve_refs (input_schema )
47
65
48
- schema ["input" ]["properties" ] = unpacked_ref_def_properties .get (
49
- "properties" , {}
66
+ # Process the schema to handle nullable types
67
+ processed_properties = process_nullable_types (
68
+ unpacked_ref_def_properties .get ("properties" , {})
50
69
)
70
+
71
+ schema ["input" ]["properties" ] = processed_properties
51
72
schema ["input" ]["required" ] = unpacked_ref_def_properties .get (
52
73
"required" , []
53
74
)
54
75
55
76
if hasattr (graph , "output_schema" ):
56
77
if hasattr (graph .output_schema , "model_json_schema" ):
57
78
output_schema = graph .output_schema .model_json_schema ()
58
-
59
79
unpacked_ref_def_properties = resolve_refs (output_schema )
60
- schema ["output" ]["properties" ] = unpacked_ref_def_properties .get (
61
- "properties" , {}
80
+
81
+ # Process the schema to handle nullable types
82
+ processed_properties = process_nullable_types (
83
+ unpacked_ref_def_properties .get ("properties" , {})
62
84
)
85
+
86
+ schema ["output" ]["properties" ] = processed_properties
63
87
schema ["output" ]["required" ] = unpacked_ref_def_properties .get (
64
88
"required" , []
65
89
)
0 commit comments