diff --git a/langgraph_supervisor/supervisor.py b/langgraph_supervisor/supervisor.py index 06fe325..a5cbdd4 100644 --- a/langgraph_supervisor/supervisor.py +++ b/langgraph_supervisor/supervisor.py @@ -3,7 +3,7 @@ from uuid import UUID, uuid5 from langchain_core.language_models import BaseChatModel, LanguageModelLike -from langchain_core.messages import AnyMessage +from langchain_core.messages import AnyMessage, ToolMessage from langchain_core.runnables import RunnableConfig from langchain_core.tools import BaseTool from langgraph.graph import END, START, StateGraph @@ -75,7 +75,10 @@ def _process_output(output: dict) -> dict: if output_mode == "full_history": pass elif output_mode == "last_message": - messages = messages[-1:] + if isinstance(messages[-1], ToolMessage): + messages = messages[-2:] + else: + messages = messages[-1:] else: raise ValueError(