Skip to content

Commit 029c5a6

Browse files
Merge pull request #32 from UiPath/fix/infer-optional-inputs
fix: process infer optional types
2 parents a603a7e + 9164dc8 commit 029c5a6

File tree

1 file changed

+30
-6
lines changed

1 file changed

+30
-6
lines changed

src/uipath_langchain/_cli/cli_init.py

Lines changed: 30 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,25 @@ def resolve_refs(schema, root=None):
3232
return schema
3333

3434

35+
def process_nullable_types(
36+
schema: Dict[str, Any] | list[Any] | Any,
37+
) -> Dict[str, Any] | list[Any]:
38+
"""Process the schema to handle nullable types by removing anyOf with null and keeping the base type."""
39+
if isinstance(schema, dict):
40+
if "anyOf" in schema and len(schema["anyOf"]) == 2:
41+
types = [t.get("type") for t in schema["anyOf"]]
42+
if "null" in types:
43+
non_null_type = next(
44+
t for t in schema["anyOf"] if t.get("type") != "null"
45+
)
46+
return non_null_type
47+
48+
return {k: process_nullable_types(v) for k, v in schema.items()}
49+
elif isinstance(schema, list):
50+
return [process_nullable_types(item) for item in schema]
51+
return schema
52+
53+
3554
def generate_schema_from_graph(graph: CompiledStateGraph) -> Dict[str, Any]:
3655
"""Extract input/output schema from a LangGraph graph"""
3756
schema = {
@@ -42,24 +61,29 @@ def generate_schema_from_graph(graph: CompiledStateGraph) -> Dict[str, Any]:
4261
if hasattr(graph, "input_schema"):
4362
if hasattr(graph.input_schema, "model_json_schema"):
4463
input_schema = graph.input_schema.model_json_schema()
45-
4664
unpacked_ref_def_properties = resolve_refs(input_schema)
4765

48-
schema["input"]["properties"] = unpacked_ref_def_properties.get(
49-
"properties", {}
66+
# Process the schema to handle nullable types
67+
processed_properties = process_nullable_types(
68+
unpacked_ref_def_properties.get("properties", {})
5069
)
70+
71+
schema["input"]["properties"] = processed_properties
5172
schema["input"]["required"] = unpacked_ref_def_properties.get(
5273
"required", []
5374
)
5475

5576
if hasattr(graph, "output_schema"):
5677
if hasattr(graph.output_schema, "model_json_schema"):
5778
output_schema = graph.output_schema.model_json_schema()
58-
5979
unpacked_ref_def_properties = resolve_refs(output_schema)
60-
schema["output"]["properties"] = unpacked_ref_def_properties.get(
61-
"properties", {}
80+
81+
# Process the schema to handle nullable types
82+
processed_properties = process_nullable_types(
83+
unpacked_ref_def_properties.get("properties", {})
6284
)
85+
86+
schema["output"]["properties"] = processed_properties
6387
schema["output"]["required"] = unpacked_ref_def_properties.get(
6488
"required", []
6589
)

0 commit comments

Comments
 (0)