Skip to content

fix(core): properly handle empty or missing arguments for ToolCallChunk #32017

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 10 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 6 additions & 2 deletions libs/core/langchain_core/messages/ai.py
Original file line number Diff line number Diff line change
Expand Up @@ -355,8 +355,12 @@ def add_chunk_to_invalid_tool_calls(chunk: ToolCallChunk) -> None:

for chunk in self.tool_call_chunks:
try:
args_ = parse_partial_json(chunk["args"]) if chunk["args"] != "" else {} # type: ignore[arg-type]
if isinstance(args_, dict):
args_ = (
parse_partial_json(chunk["args"])
if chunk["args"] is not None
else None
)
if isinstance(args_, dict) or args_ is None:
tool_calls.append(
create_tool_call(
name=chunk["name"] or "",
Expand Down
4 changes: 2 additions & 2 deletions libs/core/langchain_core/messages/tool.py
Original file line number Diff line number Diff line change
Expand Up @@ -195,7 +195,7 @@ class ToolCall(TypedDict):

name: str
"""The name of the tool to be called."""
args: dict[str, Any]
args: Optional[dict[str, Any]]
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why is this change necessary?

An empty dict represents no args currently. I'd rather not add a second way to represent the same information

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, an empty dict should represent a completed tool call that has no arguments.

None is used exclusively to represent the in-progress state where arguments have not yet been received during streaming.

Using an empty dictionary for both the "no arguments" state and the "arguments are coming" state makes them indistinguishable. We wouldn't know how to interpret the message, e.g. two scenarios to illustrate this:

  1. Final State (No Arguments): The model finishes a tool call and intends for it to have zero args. It correctly sends args={}. We'd see this and execute the tool immediately.
  2. In-Transit State (Streaming): The model starts a tool call and is about to stream the arguments. If it used {} as a temporary placeholder, we'd see args={}

If we went back to args_ = parse_partial_json(chunk["args"]) if chunk["args"] != "" else {}, it doesn't handle the case where chunk["args"] is literally None. When a streaming chunk arrives with args=None, chunk["args"] != "" evaluates to True (since None != ""). The code then tries to execute parse_partial_json(None), which results in a TypeError

Even if it mapped None to {} we still lose the ability to distinguish between the two different scenarios.

"""The arguments to the tool call."""
id: Optional[str]
"""An identifier associated with the tool call.
Expand All @@ -209,7 +209,7 @@ class ToolCall(TypedDict):
def tool_call(
*,
name: str,
args: dict[str, Any],
args: Optional[dict[str, Any]],
id: Optional[str], # noqa: A002
) -> ToolCall:
"""Create a tool call.
Expand Down
7 changes: 6 additions & 1 deletion libs/core/langchain_core/tools/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -1109,7 +1109,12 @@ def _prep_run_args(
config = ensure_config(config)
if _is_tool_call(value):
tool_call_id: Optional[str] = cast("ToolCall", value)["id"]
tool_input: Union[str, dict] = cast("ToolCall", value)["args"].copy()
args = cast("ToolCall", value)["args"]
tool_input: Union[str, dict] = (
args.copy()
if isinstance(args, dict)
else (args if args is not None else {})
)
else:
tool_call_id = None
tool_input = cast("Union[str, dict]", value)
Expand Down
18 changes: 16 additions & 2 deletions libs/core/tests/unit_tests/prompts/__snapshots__/test_chat.ambr
Original file line number Diff line number Diff line change
Expand Up @@ -1001,8 +1001,15 @@
''',
'properties': dict({
'args': dict({
'anyOf': list([
dict({
'type': 'object',
}),
dict({
'type': 'null',
}),
]),
'title': 'Args',
'type': 'object',
}),
'id': dict({
'anyOf': list([
Expand Down Expand Up @@ -2433,8 +2440,15 @@
''',
'properties': dict({
'args': dict({
'anyOf': list([
dict({
'type': 'object',
}),
dict({
'type': 'null',
}),
]),
'title': 'Args',
'type': 'object',
}),
'id': dict({
'anyOf': list([
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1404,8 +1404,15 @@
''',
'properties': dict({
'args': dict({
'anyOf': list([
dict({
'type': 'object',
}),
dict({
'type': 'null',
}),
]),
'title': 'Args',
'type': 'object',
}),
'id': dict({
'anyOf': list([
Expand Down
6 changes: 4 additions & 2 deletions libs/core/tests/unit_tests/test_messages.py
Original file line number Diff line number Diff line change
Expand Up @@ -160,7 +160,9 @@ def test_message_chunks() -> None:
create_tool_call_chunk(name="tool1", args="", id="1", index=0)
],
)
assert ai_msg_chunk.tool_calls == [create_tool_call(name="tool1", args={}, id="1")]
assert ai_msg_chunk.tool_call_chunks == [
create_tool_call_chunk(name="tool1", args="", id="1", index=0)
]

# Test token usage
left = AIMessageChunk(
Expand Down Expand Up @@ -455,9 +457,9 @@ def test_message_chunk_to_message() -> None:
tool_calls=[
create_tool_call(name="tool1", args={"a": 1}, id="1"),
create_tool_call(name="tool2", args={}, id="2"),
create_tool_call(name="tool3", args=None, id="3"),
],
invalid_tool_calls=[
create_invalid_tool_call(name="tool3", args=None, id="3", error=None),
create_invalid_tool_call(name="tool4", args="abc", id="4", error=None),
],
)
Expand Down
6 changes: 3 additions & 3 deletions libs/core/tests/unit_tests/test_tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -2609,15 +2609,15 @@ async def async_no_op(foo: int) -> str:
"type": "tool_call",
}

assert tool.invoke(tool_call["args"]) == "good"
assert tool.invoke(tool_call.get("args", {})) == "good"
assert tool_call == {
"name": "sample_tool",
"args": {"foo": 2},
"id": "call_0_82c17db8-95df-452f-a4c2-03f809022134",
"type": "tool_call",
}

assert await tool.ainvoke(tool_call["args"]) == "good"
assert tool.invoke(tool_call.get("args", {})) == "good"

assert tool_call == {
"name": "sample_tool",
Expand Down Expand Up @@ -2657,7 +2657,7 @@ async def async_no_op(foo: int) -> str:
"type": "tool_call",
}

assert tool.invoke(tool_call["args"]) == "good"
assert tool.invoke(tool_call.get("args", {})) == "good"
assert tool_call == {
"name": "sample_tool",
"args": {"foo": 2},
Expand Down
Loading
Loading