Skip to content

Commit b0679a1

Browse files
committed
fix tool result in chat history
1 parent 94811bc commit b0679a1

File tree

2 files changed

+223
-60
lines changed

2 files changed

+223
-60
lines changed

ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/MLAGUIAgentRunner.java

Lines changed: 122 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -395,20 +395,61 @@ private void processAGUIMessages(MLAgent mlAgent, Map<String, String> params) {
395395
}
396396

397397
// Check for tool result messages and extract them
398-
// Also track which messages have tool calls so we can include them in chat history
398+
// Also track assistant messages with tool calls
399399
List<Map<String, String>> toolResults = new ArrayList<>();
400400
List<Integer> toolCallMessageIndices = new ArrayList<>();
401+
List<Integer> toolResultMessageIndices = new ArrayList<>();
402+
List<String> assistantToolCallMessages = new ArrayList<>();
403+
int lastToolResultIndex = -1;
401404

402405
for (int i = 0; i < messageArray.size(); i++) {
403406
JsonElement messageElement = messageArray.get(i);
404407
if (messageElement.isJsonObject()) {
405408
JsonObject message = messageElement.getAsJsonObject();
406409
String role = getStringField(message, "role");
407410

408-
// Track assistant messages with tool calls
411+
// Track and extract assistant messages with tool calls
409412
if ("assistant".equals(role) && message.has("toolCalls")) {
410413
toolCallMessageIndices.add(i);
411-
log.debug("AG-UI: Found assistant message with tool calls at index {}", i);
414+
415+
// Convert to OpenAI format for interactions
416+
JsonElement toolCallsElement = message.get("toolCalls");
417+
if (toolCallsElement != null && toolCallsElement.isJsonArray()) {
418+
List<Map<String, Object>> toolCalls = new ArrayList<>();
419+
for (JsonElement tcElement : toolCallsElement.getAsJsonArray()) {
420+
if (tcElement.isJsonObject()) {
421+
JsonObject tc = tcElement.getAsJsonObject();
422+
Map<String, Object> toolCall = new HashMap<>();
423+
424+
// OpenAI format: id, type, and function at the same level
425+
String toolCallId = getStringField(tc, "id");
426+
String toolCallType = getStringField(tc, "type");
427+
428+
toolCall.put("id", toolCallId);
429+
toolCall.put("type", toolCallType != null ? toolCallType : "function");
430+
431+
JsonElement functionElement = tc.get("function");
432+
if (functionElement != null && functionElement.isJsonObject()) {
433+
JsonObject func = functionElement.getAsJsonObject();
434+
Map<String, String> function = new HashMap<>();
435+
function.put("name", getStringField(func, "name"));
436+
function.put("arguments", getStringField(func, "arguments"));
437+
toolCall.put("function", function);
438+
}
439+
toolCalls.add(toolCall);
440+
log.debug("AG-UI: Extracted tool call - id: {}, type: {}", toolCallId, toolCallType);
441+
}
442+
}
443+
444+
// Create assistant message with tool_calls in OpenAI format
445+
Map<String, Object> assistantMsg = new HashMap<>();
446+
assistantMsg.put("role", "assistant");
447+
assistantMsg.put("tool_calls", toolCalls);
448+
String assistantMessage = gson.toJson(assistantMsg);
449+
assistantToolCallMessages.add(assistantMessage);
450+
log.debug("AG-UI: Extracted assistant message with {} tool calls at index {}", toolCalls.size(), i);
451+
log.debug("AG-UI: Assistant message JSON: {}", assistantMessage);
452+
}
412453
}
413454

414455
if ("tool".equals(role)) {
@@ -420,17 +461,69 @@ private void processAGUIMessages(MLAgent mlAgent, Map<String, String> params) {
420461
toolResult.put("tool_call_id", toolCallId);
421462
toolResult.put("content", content);
422463
toolResults.add(toolResult);
423-
log.info("AG-UI: Extracted tool result for toolCallId: {}", toolCallId);
464+
toolResultMessageIndices.add(i);
465+
lastToolResultIndex = i;
466+
log.info("AG-UI: Extracted tool result for toolCallId: {} at index {}", toolCallId, i);
424467
}
425468
}
426469
}
427470
}
428471

429-
// If we found tool results, add them to params so they can be processed
430-
if (!toolResults.isEmpty()) {
431-
params.put("agui_tool_call_results", gson.toJson(toolResults));
432-
params.put("agui_tool_call_message_indices", gson.toJson(toolCallMessageIndices));
433-
log.info("AG-UI: Found {} tool results in messages, added to params", toolResults.size());
472+
// Only process the MOST RECENT tool execution
473+
// Check if there are any assistant messages after the last tool result
474+
boolean hasAssistantAfterToolResult = false;
475+
if (lastToolResultIndex >= 0) {
476+
for (int i = lastToolResultIndex + 1; i < messageArray.size(); i++) {
477+
JsonElement messageElement = messageArray.get(i);
478+
if (messageElement.isJsonObject()) {
479+
JsonObject message = messageElement.getAsJsonObject();
480+
String role = getStringField(message, "role");
481+
if ("assistant".equals(role)) {
482+
hasAssistantAfterToolResult = true;
483+
log.debug("AG-UI: Found assistant message at index {} after tool result at index {}", i, lastToolResultIndex);
484+
break;
485+
}
486+
}
487+
}
488+
}
489+
490+
boolean toolResultsAreRecent = !toolResults.isEmpty() && !hasAssistantAfterToolResult;
491+
492+
if (!toolResults.isEmpty() && toolResultsAreRecent) {
493+
// Only include the MOST RECENT tool execution (last tool call + result pair)
494+
// Find the assistant message that corresponds to the last tool result
495+
int lastToolCallIndex = -1;
496+
String lastToolCallMessage = null;
497+
498+
// The last tool result should correspond to the last assistant message with tool_calls
499+
if (!assistantToolCallMessages.isEmpty() && !toolCallMessageIndices.isEmpty()) {
500+
lastToolCallIndex = toolCallMessageIndices.get(toolCallMessageIndices.size() - 1);
501+
lastToolCallMessage = assistantToolCallMessages.get(assistantToolCallMessages.size() - 1);
502+
}
503+
504+
// Only include the last tool result
505+
Map<String, String> lastToolResult = toolResults.get(toolResults.size() - 1);
506+
List<Map<String, String>> recentToolResults = List.of(lastToolResult);
507+
508+
String toolResultsJson = gson.toJson(recentToolResults);
509+
params.put("agui_tool_call_results", toolResultsJson);
510+
511+
// Only pass the most recent assistant message with tool_calls
512+
if (lastToolCallMessage != null) {
513+
params.put("agui_assistant_tool_call_messages", gson.toJson(List.of(lastToolCallMessage)));
514+
log.debug("AG-UI: Added most recent assistant tool call message (index {}) to params", lastToolCallIndex);
515+
}
516+
517+
log.info("AG-UI: Found most recent tool result at index {} (out of {} total tool results), added to params",
518+
lastToolResultIndex, toolResults.size());
519+
log.debug("AG-UI: Recent tool result JSON: {}", toolResultsJson);
520+
log.debug("AG-UI: Recent tool result - toolCallId: {}, content length: {}",
521+
lastToolResult.get("tool_call_id"),
522+
lastToolResult.get("content") != null ? lastToolResult.get("content").length() : 0);
523+
} else if (!toolResults.isEmpty()) {
524+
log.info("AG-UI: Found {} tool results but they are not recent (last at index {}, total messages: {}), " +
525+
"skipping from interactions",
526+
toolResults.size(), lastToolResultIndex, messageArray.size());
434527
}
435528

436529
String chatHistoryQuestionTemplate = params.get(CHAT_HISTORY_QUESTION_TEMPLATE);
@@ -440,28 +533,25 @@ private void processAGUIMessages(MLAgent mlAgent, Map<String, String> params) {
440533

441534
StringBuilder chatHistoryBuilder = new StringBuilder();
442535

443-
// Find the first tool call message index to know where to stop building chat history
444-
int firstToolCallIndex = toolCallMessageIndices.isEmpty() ? messageArray.size() : toolCallMessageIndices.get(0);
445-
446536
for (int i = 0; i < messageArray.size() - 1; i++) {
447537
JsonElement messageElement = messageArray.get(i);
448538
if (messageElement.isJsonObject()) {
449539
JsonObject message = messageElement.getAsJsonObject();
450540
String role = getStringField(message, "role");
451541
String content = getStringField(message, "content");
452542

453-
// Skip tool messages - they're handled separately via interactions
543+
// Skip tool messages - they're not part of chat history
454544
if ("tool".equals(role)) {
455545
continue;
456546
}
457547

458-
// Skip assistant messages with tool calls and any messages after them
459-
// They'll be handled by the function calling interface
460-
if (i >= firstToolCallIndex) {
461-
log.debug("AG-UI: Skipping message at index {} (after tool call sequence)", i);
548+
// Skip assistant messages with tool_calls - they're not part of chat history
549+
if ("assistant".equals(role) && message.has("toolCalls")) {
550+
log.debug("AG-UI: Skipping assistant message with tool_calls at index {} (not included in chat history)", i);
462551
continue;
463552
}
464553

554+
// Include user messages and assistant messages with content (final answers)
465555
if (("user".equals(role) || "assistant".equals(role)) && content != null && !content.isEmpty()) {
466556
if (chatHistoryBuilder.length() > 0) {
467557
chatHistoryBuilder.append("\n");
@@ -478,37 +568,33 @@ private void processAGUIMessages(MLAgent mlAgent, Map<String, String> params) {
478568
} else {
479569
List<String> chatHistory = new ArrayList<>();
480570

481-
// Find the first tool call message index to know where to stop building chat history
482-
int firstToolCallIndex = toolCallMessageIndices.isEmpty() ? messageArray.size() : toolCallMessageIndices.get(0);
483-
484571
for (int i = 0; i < messageArray.size() - 1; i++) {
485572
JsonElement messageElement = messageArray.get(i);
486573
if (messageElement.isJsonObject()) {
487574
JsonObject message = messageElement.getAsJsonObject();
488575
String role = getStringField(message, "role");
489576
String content = getStringField(message, "content");
490577

491-
// Skip tool messages - they're handled separately via interactions
578+
// Skip tool messages - they're never part of chat history
492579
if ("tool".equals(role)) {
580+
log.debug("AG-UI: Skipping tool message at index {} (not included in chat history)", i);
493581
continue;
494582
}
495583

496-
// Skip assistant messages with tool calls and any messages after them
497-
// They'll be handled by the function calling interface
498-
if (i >= firstToolCallIndex) {
499-
log.debug("AG-UI: Skipping message at index {} (after tool call sequence)", i);
500-
continue;
501-
}
502-
503-
if (content != null && !content.isEmpty()) {
584+
if ("user".equals(role) && content != null && !content.isEmpty()) {
504585
Map<String, String> messageParams = new HashMap<>();
505-
506-
if ("user".equals(role)) {
507-
messageParams.put("question", processTextDoc(content));
508-
StringSubstitutor substitutor = new StringSubstitutor(messageParams, CHAT_HISTORY_MESSAGE_PREFIX, "}");
509-
String chatMessage = substitutor.replace(chatHistoryQuestionTemplate);
510-
chatHistory.add(chatMessage);
511-
} else if ("assistant".equals(role)) {
586+
messageParams.put("question", processTextDoc(content));
587+
StringSubstitutor substitutor = new StringSubstitutor(messageParams, CHAT_HISTORY_MESSAGE_PREFIX, "}");
588+
String chatMessage = substitutor.replace(chatHistoryQuestionTemplate);
589+
chatHistory.add(chatMessage);
590+
} else if ("assistant".equals(role)) {
591+
// Skip ALL assistant messages with tool_calls - they're never part of chat history
592+
// (matching backend behavior where only final answers are in chat history)
593+
if (message.has("toolCalls")) {
594+
log.debug("AG-UI: Skipping assistant message with tool_calls at index {} (not included in chat history)", i);
595+
} else if (content != null && !content.isEmpty()) {
596+
// Regular assistant message with content (final answer)
597+
Map<String, String> messageParams = new HashMap<>();
512598
messageParams.put("response", processTextDoc(content));
513599
StringSubstitutor substitutor = new StringSubstitutor(messageParams, CHAT_HISTORY_MESSAGE_PREFIX, "}");
514600
String chatMessage = substitutor.replace(chatHistoryResponseTemplate);

0 commit comments

Comments
 (0)