@@ -719,7 +719,7 @@ private void sendFinalAnswer(
String finalAnswer
) {
// Send completion chunk for streaming
streamingWrapper.sendCompletionChunk(sessionId, parentInteractionId);
streamingWrapper.sendCompletionChunk(sessionId, parentInteractionId, null, null);

if (conversationIndexMemory != null) {
String copyOfFinalAnswer = finalAnswer;
@@ -794,6 +794,10 @@ private static String constructLLMPrompt(Map<String, Tool> tools, Map<String, St
@VisibleForTesting
static Map<String, String> constructLLMParams(LLMSpec llm, Map<String, String> parameters) {
Map<String, String> tmpParameters = new HashMap<>();

// Set agent type so streaming handlers can identify the chat agent
tmpParameters.put("agent_type", "chat");

if (llm.getParameters() != null) {
tmpParameters.putAll(llm.getParameters());
}
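Two details in this first file (the chat agent runner) are easy to miss. The added null, null arguments reflect that only the plan-execute-and-reflect (PER) agent has an executor agent whose memory ids belong in the completion chunk; the chat agent has none to report. And because agent_type is put into tmpParameters before llm.getParameters() is merged in, an agent_type supplied in the model's own parameters would silently override the hard-coded "chat". A minimal sketch of that merge order, using a hypothetical model parameter:

Map<String, String> llmParameters = Map.of("temperature", "0"); // stand-in for llm.getParameters()
Map<String, String> tmpParameters = new HashMap<>();
tmpParameters.put("agent_type", "chat"); // injected first
tmpParameters.putAll(llmParameters);     // merged second; an "agent_type" key here would win
// tmpParameters -> {agent_type=chat, temperature=0}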
--- next file ---
@@ -43,6 +43,7 @@
import java.util.function.Consumer;

import org.apache.commons.text.StringSubstitutor;
import org.opensearch.action.ActionRequest;
import org.opensearch.action.StepListener;
import org.opensearch.cluster.service.ClusterService;
import org.opensearch.common.settings.Settings;
@@ -57,7 +58,6 @@
import org.opensearch.ml.common.dataset.remote.RemoteInferenceInputDataSet;
import org.opensearch.ml.common.exception.MLException;
import org.opensearch.ml.common.input.execute.agent.AgentMLInput;
import org.opensearch.ml.common.input.remote.RemoteInferenceMLInput;
import org.opensearch.ml.common.output.model.ModelTensor;
import org.opensearch.ml.common.output.model.ModelTensorOutput;
import org.opensearch.ml.common.output.model.ModelTensors;
@@ -66,8 +66,6 @@
import org.opensearch.ml.common.transport.MLTaskResponse;
import org.opensearch.ml.common.transport.execute.MLExecuteTaskAction;
import org.opensearch.ml.common.transport.execute.MLExecuteTaskRequest;
import org.opensearch.ml.common.transport.prediction.MLPredictionTaskAction;
import org.opensearch.ml.common.transport.prediction.MLPredictionTaskRequest;
import org.opensearch.ml.common.utils.StringUtils;
import org.opensearch.ml.engine.encryptor.Encryptor;
import org.opensearch.ml.engine.memory.ConversationIndexMemory;
Expand All @@ -92,6 +90,7 @@ public class MLPlanExecuteAndReflectAgentRunner implements MLAgentRunner {
private final Map<String, Memory.Factory> memoryFactoryMap;
private SdkClient sdkClient;
private Encryptor encryptor;
private StreamingWrapper streamingWrapper;
// flag to track if task has been updated with executor memory ids or not
private boolean taskUpdated = false;
private final Map<String, Object> taskUpdates = new HashMap<>();
@@ -182,6 +181,9 @@ public MLPlanExecuteAndReflectAgentRunner(

@VisibleForTesting
void setupPromptParameters(Map<String, String> params) {
// Set agent type so streaming handlers can identify the PER agent
params.put("agent_type", "per");

// populated depending on whether the LLM is asked to plan or re-evaluate
// removed here so that an error is thrown if this field is not populated
params.remove(PROMPT_FIELD);
@@ -273,6 +275,7 @@ void populatePrompt(Map<String, String> allParams) {

@Override
public void run(MLAgent mlAgent, Map<String, String> apiParams, ActionListener<Object> listener, TransportChannel channel) {
this.streamingWrapper = new StreamingWrapper(channel, client);
Map<String, String> allParams = new HashMap<>();
allParams.putAll(apiParams);
allParams.putAll(mlAgent.getParameters());
@@ -387,16 +390,7 @@ private void executePlanningLoop(
return;
}

MLPredictionTaskRequest request = new MLPredictionTaskRequest(
llm.getModelId(),
RemoteInferenceMLInput
.builder()
.algorithm(FunctionName.REMOTE)
.inputDataset(RemoteInferenceInputDataSet.builder().parameters(allParams).build())
.build(),
null,
allParams.get(TENANT_ID_FIELD)
);
ActionRequest request = streamingWrapper.createPredictionRequest(llm, allParams, allParams.get(TENANT_ID_FIELD));

StepListener<MLTaskResponse> planListener = new StepListener<>();

@@ -550,8 +544,7 @@ private void executePlanningLoop(
log.error("Failed to run deep research agent", e);
finalListener.onFailure(e);
});

client.execute(MLPredictionTaskAction.INSTANCE, request, planListener);
streamingWrapper.executeRequest(request, planListener);
}

@VisibleForTesting
@@ -689,6 +682,9 @@ void saveAndReturnFinalResult(
}

memory.getMemoryManager().updateInteraction(parentInteractionId, updateContent, ActionListener.wrap(res -> {
// Send completion chunk to close streaming connection
streamingWrapper
.sendCompletionChunk(memory.getConversationId(), parentInteractionId, reactAgentMemoryId, reactParentInteractionId);
List<ModelTensors> finalModelTensors = createModelTensors(
memory.getConversationId(),
parentInteractionId,
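In the PER runner, the inlined MLPredictionTaskRequest construction moves behind StreamingWrapper.createPredictionRequest (note the matching import changes: ActionRequest comes in, RemoteInferenceMLInput and the prediction transport classes go out). The wrapper's body is not part of this diff, but in the non-streaming path it presumably still reduces to the request the removed lines built, roughly:

ActionRequest request = new MLPredictionTaskRequest(
    llm.getModelId(),
    RemoteInferenceMLInput
        .builder()
        .algorithm(FunctionName.REMOTE)
        .inputDataset(RemoteInferenceInputDataSet.builder().parameters(allParams).build())
        .build(),
    null,
    allParams.get(TENANT_ID_FIELD)
);

with streamingWrapper.executeRequest then dispatching it on MLPredictionTaskAction, as the removed client.execute call did.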
--- next file ---
@@ -83,11 +83,23 @@ public void executeRequest(ActionRequest request, ActionListener<MLTaskResponse>
client.execute(MLPredictionTaskAction.INSTANCE, request, listener);
}

public void sendCompletionChunk(String sessionId, String parentInteractionId) {
public void sendCompletionChunk(
String sessionId,
String parentInteractionId,
String executorMemoryId,
String executorParentInteractionId
) {
if (!isStreaming) {
return;
}
MLTaskResponse completionChunk = createStreamChunk("", sessionId, parentInteractionId, true);
MLTaskResponse completionChunk = createStreamChunk(
"",
sessionId,
parentInteractionId,
executorMemoryId,
executorParentInteractionId,
true
);
try {
channel.sendResponseBatch(completionChunk);
} catch (Exception e) {
@@ -114,20 +126,29 @@ public void sendFinalResponse(
public void sendToolResponse(String toolOutput, String sessionId, String parentInteractionId) {
if (isStreaming) {
try {
MLTaskResponse toolChunk = createStreamChunk(toolOutput, sessionId, parentInteractionId, false);
MLTaskResponse toolChunk = createStreamChunk(toolOutput, sessionId, parentInteractionId, null, null, false);
channel.sendResponseBatch(toolChunk);
} catch (Exception e) {
log.error("Failed to send tool response chunk", e);
}
}
}

private MLTaskResponse createStreamChunk(String toolOutput, String sessionId, String parentInteractionId, boolean isLast) {
private MLTaskResponse createStreamChunk(
String toolOutput,
String sessionId,
String parentInteractionId,
String executorMemoryId,
String executorParentInteractionId,
boolean isLast
) {
List<ModelTensor> tensors = Arrays
.asList(
ModelTensor.builder().name("response").dataAsMap(Map.of("content", toolOutput, "is_last", isLast)).build(),
ModelTensor.builder().name("memory_id").result(sessionId).build(),
ModelTensor.builder().name("parent_interaction_id").result(parentInteractionId).build()
ModelTensor.builder().name("parent_interaction_id").result(parentInteractionId).build(),
ModelTensor.builder().name("executor_agent_memory_id").result(executorMemoryId).build(),
ModelTensor.builder().name("executor_agent_parent_interaction_id").result(executorParentInteractionId).build()
);

ModelTensors modelTensors = ModelTensors.builder().mlModelTensors(tensors).build();
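With the executor pair appended, each chunk built by createStreamChunk now carries five tensors instead of three. The chat path passes null for the two executor fields, so clients should expect them to be empty outside PER runs. Rendered the way a client would see the final chunk of a PER run (all ids made up):

// response:                             {"content": "", "is_last": true}
// memory_id:                            "planner-mem-1"
// parent_interaction_id:                "planner-int-1"
// executor_agent_memory_id:             "executor-mem-1"
// executor_agent_parent_interaction_id: "executor-int-1"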
--- next file ---
@@ -95,6 +95,8 @@ public void startStream(
AtomicReference<String> toolUseId = new AtomicReference<>();
StringBuilder toolInputAccumulator = new StringBuilder();
AtomicReference<StreamState> currentState = new AtomicReference<>(StreamState.STREAMING_CONTENT);
String agentType = parameters.get("agent_type");
StringBuilder accumulatedContent = new StringBuilder();

// Build Bedrock client
BedrockRuntimeAsyncClient bedrockClient = buildBedrockRuntimeAsyncClient();
@@ -128,7 +130,18 @@
log.debug("Tool execution in progress - keeping stream open");
}
}).subscriber(event -> {
handleStreamEvent(event, listener, isStreamClosed, toolName, toolInput, toolUseId, toolInputAccumulator, currentState);
handleStreamEvent(
event,
listener,
isStreamClosed,
toolName,
toolInput,
toolUseId,
toolInputAccumulator,
currentState,
agentType,
accumulatedContent
);
}).build();

// Start streaming
@@ -183,18 +196,28 @@ private void handleStreamEvent(
AtomicReference<Map<String, Object>> toolInput,
AtomicReference<String> toolUseId,
StringBuilder toolInputAccumulator,
AtomicReference<StreamState> currentState
AtomicReference<StreamState> currentState,
String agentType,
StringBuilder accumulatedContent
) {
switch (currentState.get()) {
case STREAMING_CONTENT:
if (isToolUseDetected(event)) {
currentState.set(StreamState.TOOL_CALL_DETECTED);
extractToolInfo(event, toolName, toolUseId);
} else if (isContentDelta(event)) {
sendContentResponse(getTextContent(event), false, listener);
String content = getTextContent(event);
accumulatedContent.append(content);
sendContentResponse(content, false, listener);
} else if (isStreamComplete(event)) {
currentState.set(StreamState.COMPLETED);
sendCompletionResponse(isStreamClosed, listener);
// For the PER agent, keep the connection open after the planner LLM finishes
if ("per".equals(agentType)) {
currentState.set(StreamState.WAITING_FOR_TOOL_RESULT);
sendPlannerResponse(false, listener, String.valueOf(accumulatedContent));
} else {
currentState.set(StreamState.COMPLETED);
sendCompletionResponse(isStreamClosed, listener);
}
}
break;

@@ -225,6 +248,26 @@
}
}

private void sendPlannerResponse(
boolean isStreamClosed,
StreamPredictActionListener<MLTaskResponse, ?> listener,
String plannerContent
) {
if (!isStreamClosed) {
Map<String, Object> responseMap = new HashMap<>();
responseMap.put("output", Map.of("message", Map.of("content", List.of(Map.of("text", plannerContent)))));

ModelTensor tensor = ModelTensor.builder().name("response").dataAsMap(responseMap).build();

ModelTensors tensors = ModelTensors.builder().mlModelTensors(List.of(tensor)).build();

ModelTensorOutput output = ModelTensorOutput.builder().mlModelOutputs(List.of(tensors)).build();

listener.onResponse(MLTaskResponse.builder().output(output).build());
log.debug("Sent planner response for PER agent");
}
}

// TODO: refactor the event type checker methods
private void extractToolInfo(ConverseStreamOutput event, AtomicReference<String> toolName, AtomicReference<String> toolUseId) {
ContentBlockStartEvent startEvent = (ContentBlockStartEvent) event;
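The Bedrock handler's PER branch changes the terminal transition of the state machine: instead of STREAMING_CONTENT -> COMPLETED, a PER planner stream ends in WAITING_FOR_TOOL_RESULT, which keeps the connection open for the executor phase, and sendPlannerResponse flushes the accumulated text as one synthetic response in the Converse output shape the downstream parser already reads. An illustrative payload, with made-up planner text:

// {"output": {"message": {"content": [{"text": "Step 1: ..."}]}}}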
--- next file ---
@@ -76,7 +76,8 @@
) {
try {
log.info("Creating SSE connection for streaming request");
EventSourceListener listener = new HTTPEventSourceListener(actionListener, llmInterface);
String agentType = parameters.get("agent_type");
EventSourceListener listener = new HTTPEventSourceListener(actionListener, llmInterface, agentType);
Request request = ConnectorUtils.buildOKHttpStreamingRequest(action, connector, parameters, payload);

AccessController.doPrivileged((PrivilegedExceptionAction<Void>) () -> {
@@ -100,16 +101,23 @@ public final class HTTPEventSourceListener extends EventSourceListener {
private StreamPredictActionListener<MLTaskResponse, ?> streamActionListener;
private final String llmInterface;
private AtomicBoolean isStreamClosed;
private final String agentType;
private boolean functionCallInProgress = false;
private boolean agentExecutionInProgress = false;
private String accumulatedToolCallId = null;
private String accumulatedToolName = null;
private String accumulatedArguments = "";
private StringBuilder accumulatedContent = new StringBuilder();

public HTTPEventSourceListener(StreamPredictActionListener<MLTaskResponse, ?> streamActionListener, String llmInterface) {
public HTTPEventSourceListener(
StreamPredictActionListener<MLTaskResponse, ?> streamActionListener,
String llmInterface,
String agentType
) {
this.streamActionListener = streamActionListener;
this.llmInterface = llmInterface;
this.isStreamClosed = new AtomicBoolean(false);
this.agentType = agentType;
}

/***
@@ -206,14 +214,20 @@ private void processStreamChunk(Map<String, Object> dataMap) {
// Handle stop finish reason
String finishReason = extractPath(dataMap, "$.choices[0].finish_reason");
if ("stop".equals(finishReason)) {
agentExecutionInProgress = false;
sendCompletionResponse(isStreamClosed, streamActionListener);
// For the PER agent, keep the connection open after the planner LLM finishes
if ("per".equals(agentType)) {
completePlannerResponse();
} else {
agentExecutionInProgress = false;
sendCompletionResponse(isStreamClosed, streamActionListener);
}
return;
}

// Process content
String content = extractPath(dataMap, "$.choices[0].delta.content");
if (content != null && !content.isEmpty()) {
accumulatedContent.append(content);
sendContentResponse(content, false, streamActionListener);
}

@@ -268,6 +282,19 @@ private ModelTensorOutput createModelTensorOutput(Map<String, Object> responseDa
return ModelTensorOutput.builder().mlModelOutputs(List.of(tensors)).build();
}

private void completePlannerResponse() {
String fullContent = accumulatedContent.toString().trim();

// Create compatible response format
Map<String, Object> message = Map.of("content", fullContent);
Map<String, Object> choice = Map.of("message", message);
Map<String, Object> response = Map.of("choices", List.of(choice));

ModelTensorOutput output = createModelTensorOutput(response);
streamActionListener.onResponse(new MLTaskResponse(output));
agentExecutionInProgress = true;
}

private void accumulateFunctionCall(List<?> toolCalls) {
functionCallInProgress = true;
for (Object toolCall : toolCalls) {
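The OpenAI-compatible SSE listener mirrors the same idea: when finish_reason is "stop" on a PER run, completePlannerResponse replays the accumulated deltas as a single chat-completions-shaped message and sets agentExecutionInProgress to true so the connection is treated as still in use. An illustrative payload, with made-up planner text:

// {"choices": [{"message": {"content": "Step 1: ..."}}]}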
--- next file ---
@@ -710,7 +710,7 @@ public void testToolParameters() {
// Verify the size of parameters passed in the tool run method.
ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(Map.class);
verify(firstTool).run((Map<String, String>) argumentCaptor.capture(), any());
assertEquals(15, ((Map) argumentCaptor.getValue()).size());
assertEquals(16, ((Map) argumentCaptor.getValue()).size());

Mockito.verify(agentActionListener).onResponse(objectCaptor.capture());
ModelTensorOutput modelTensorOutput = (ModelTensorOutput) objectCaptor.getValue();
@@ -738,7 +738,7 @@ public void testToolUseOriginalInput() {
// Verify the size of parameters passed in the tool run method.
ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(Map.class);
verify(firstTool).run((Map<String, String>) argumentCaptor.capture(), any());
assertEquals(16, ((Map) argumentCaptor.getValue()).size());
assertEquals(17, ((Map) argumentCaptor.getValue()).size());
assertEquals("raw input", ((Map<?, ?>) argumentCaptor.getValue()).get("input"));

Mockito.verify(agentActionListener).onResponse(objectCaptor.capture());
@@ -804,7 +804,7 @@ public void testToolConfig() {
// Verify the size of parameters passed in the tool run method.
ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(Map.class);
verify(firstTool).run((Map<String, String>) argumentCaptor.capture(), any());
assertEquals(16, ((Map) argumentCaptor.getValue()).size());
assertEquals(17, ((Map) argumentCaptor.getValue()).size());
// The value of input should be "config_value".
assertEquals("config_value", ((Map<?, ?>) argumentCaptor.getValue()).get("input"));

@@ -834,7 +834,7 @@ public void testToolConfigWithInputPlaceholder() {
// Verify the size of parameters passed in the tool run method.
ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(Map.class);
verify(firstTool).run((Map<String, String>) argumentCaptor.capture(), any());
assertEquals(16, ((Map) argumentCaptor.getValue()).size());
assertEquals(17, ((Map) argumentCaptor.getValue()).size());
// The value of input should be replaced with the value associated with the key "key2" of the first tool.
assertEquals("value2", ((Map<?, ?>) argumentCaptor.getValue()).get("input"));

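Finally, the test updates: each parameter-count assertion rises by exactly one because constructLLMParams now injects agent_type=chat into the parameters handed to tools. A hypothetical spot check of the new entry, in the style of the surrounding assertions:

assertEquals("chat", ((Map<?, ?>) argumentCaptor.getValue()).get("agent_type"));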