Skip to content

Commit 0358cd7

Browse files
committed
feat: support ToolMessage in message dicts
1 parent 30856b7 commit 0358cd7

File tree

1 file changed

+14
-1
lines changed

1 file changed

+14
-1
lines changed

nemoguardrails/integrations/langchain/runnable_rails.py

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@
2626
BaseMessage,
2727
HumanMessage,
2828
SystemMessage,
29+
ToolMessage,
2930
)
3031
from langchain_core.prompt_values import ChatPromptValue, StringPromptValue
3132
from langchain_core.runnables import Runnable, RunnableConfig, RunnableSerializable
@@ -234,11 +235,23 @@ def _create_passthrough_messages(self, _input) -> List[Dict[str, Any]]:
234235
def _message_to_dict(self, msg: BaseMessage) -> Dict[str, Any]:
235236
"""Convert a BaseMessage to dictionary format."""
236237
if isinstance(msg, AIMessage):
237-
return {"role": "assistant", "content": msg.content}
238+
result = {"role": "assistant", "content": msg.content}
239+
if hasattr(msg, "tool_calls") and msg.tool_calls:
240+
result["tool_calls"] = msg.tool_calls
241+
return result
238242
elif isinstance(msg, HumanMessage):
239243
return {"role": "user", "content": msg.content}
240244
elif isinstance(msg, SystemMessage):
241245
return {"role": "system", "content": msg.content}
246+
elif isinstance(msg, ToolMessage):
247+
result = {
248+
"role": "tool",
249+
"content": msg.content,
250+
"tool_call_id": msg.tool_call_id,
251+
}
252+
if hasattr(msg, "name") and msg.name:
253+
result["name"] = msg.name
254+
return result
242255
else: # Handle other message types
243256
role = getattr(msg, "type", "user")
244257
return {"role": role, "content": msg.content}

0 commit comments

Comments
 (0)