diff --git a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/MLChatAgentRunner.java b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/MLChatAgentRunner.java index 103f3f89b3..815abae41e 100644 --- a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/MLChatAgentRunner.java +++ b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/MLChatAgentRunner.java @@ -719,7 +719,7 @@ private void sendFinalAnswer( String finalAnswer ) { // Send completion chunk for streaming - streamingWrapper.sendCompletionChunk(sessionId, parentInteractionId); + streamingWrapper.sendCompletionChunk(sessionId, parentInteractionId, null, null); if (conversationIndexMemory != null) { String copyOfFinalAnswer = finalAnswer; @@ -794,6 +794,10 @@ private static String constructLLMPrompt(Map tools, Map constructLLMParams(LLMSpec llm, Map parameters) { Map tmpParameters = new HashMap<>(); + + // Set agent type for Chat agent for streaming + tmpParameters.put("agent_type", "chat"); + if (llm.getParameters() != null) { tmpParameters.putAll(llm.getParameters()); } diff --git a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/MLPlanExecuteAndReflectAgentRunner.java b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/MLPlanExecuteAndReflectAgentRunner.java index 762b60ca5c..b2eeada1b8 100644 --- a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/MLPlanExecuteAndReflectAgentRunner.java +++ b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/MLPlanExecuteAndReflectAgentRunner.java @@ -43,6 +43,7 @@ import java.util.function.Consumer; import org.apache.commons.text.StringSubstitutor; +import org.opensearch.action.ActionRequest; import org.opensearch.action.StepListener; import org.opensearch.cluster.service.ClusterService; import org.opensearch.common.settings.Settings; @@ -57,7 +58,6 @@ import org.opensearch.ml.common.dataset.remote.RemoteInferenceInputDataSet; import org.opensearch.ml.common.exception.MLException; import org.opensearch.ml.common.input.execute.agent.AgentMLInput; -import org.opensearch.ml.common.input.remote.RemoteInferenceMLInput; import org.opensearch.ml.common.output.model.ModelTensor; import org.opensearch.ml.common.output.model.ModelTensorOutput; import org.opensearch.ml.common.output.model.ModelTensors; @@ -66,8 +66,6 @@ import org.opensearch.ml.common.transport.MLTaskResponse; import org.opensearch.ml.common.transport.execute.MLExecuteTaskAction; import org.opensearch.ml.common.transport.execute.MLExecuteTaskRequest; -import org.opensearch.ml.common.transport.prediction.MLPredictionTaskAction; -import org.opensearch.ml.common.transport.prediction.MLPredictionTaskRequest; import org.opensearch.ml.common.utils.StringUtils; import org.opensearch.ml.engine.encryptor.Encryptor; import org.opensearch.ml.engine.memory.ConversationIndexMemory; @@ -92,6 +90,7 @@ public class MLPlanExecuteAndReflectAgentRunner implements MLAgentRunner { private final Map memoryFactoryMap; private SdkClient sdkClient; private Encryptor encryptor; + private StreamingWrapper streamingWrapper; // flag to track if task has been updated with executor memory ids or not private boolean taskUpdated = false; private final Map taskUpdates = new HashMap<>(); @@ -182,6 +181,9 @@ public MLPlanExecuteAndReflectAgentRunner( @VisibleForTesting void setupPromptParameters(Map params) { + // Set agent type for PER agent for streaming + params.put("agent_type", "per"); + // populated depending on whether LLM is asked to plan or re-evaluate // removed here, so that error is thrown in case this field is not populated params.remove(PROMPT_FIELD); @@ -273,6 +275,7 @@ void populatePrompt(Map allParams) { @Override public void run(MLAgent mlAgent, Map apiParams, ActionListener listener, TransportChannel channel) { + this.streamingWrapper = new StreamingWrapper(channel, client); Map allParams = new HashMap<>(); allParams.putAll(apiParams); allParams.putAll(mlAgent.getParameters()); @@ -387,16 +390,7 @@ private void executePlanningLoop( return; } - MLPredictionTaskRequest request = new MLPredictionTaskRequest( - llm.getModelId(), - RemoteInferenceMLInput - .builder() - .algorithm(FunctionName.REMOTE) - .inputDataset(RemoteInferenceInputDataSet.builder().parameters(allParams).build()) - .build(), - null, - allParams.get(TENANT_ID_FIELD) - ); + ActionRequest request = streamingWrapper.createPredictionRequest(llm, allParams, allParams.get(TENANT_ID_FIELD)); StepListener planListener = new StepListener<>(); @@ -550,8 +544,7 @@ private void executePlanningLoop( log.error("Failed to run deep research agent", e); finalListener.onFailure(e); }); - - client.execute(MLPredictionTaskAction.INSTANCE, request, planListener); + streamingWrapper.executeRequest(request, planListener); } @VisibleForTesting @@ -689,6 +682,9 @@ void saveAndReturnFinalResult( } memory.getMemoryManager().updateInteraction(parentInteractionId, updateContent, ActionListener.wrap(res -> { + // Send completion chunk to close streaming connection + streamingWrapper + .sendCompletionChunk(memory.getConversationId(), parentInteractionId, reactAgentMemoryId, reactParentInteractionId); List finalModelTensors = createModelTensors( memory.getConversationId(), parentInteractionId, diff --git a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/StreamingWrapper.java b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/StreamingWrapper.java index beb5e60f53..9f28e3a246 100644 --- a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/StreamingWrapper.java +++ b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/StreamingWrapper.java @@ -83,11 +83,23 @@ public void executeRequest(ActionRequest request, ActionListener client.execute(MLPredictionTaskAction.INSTANCE, request, listener); } - public void sendCompletionChunk(String sessionId, String parentInteractionId) { + public void sendCompletionChunk( + String sessionId, + String parentInteractionId, + String executorMemoryId, + String executorParentInteractionId + ) { if (!isStreaming) { return; } - MLTaskResponse completionChunk = createStreamChunk("", sessionId, parentInteractionId, true); + MLTaskResponse completionChunk = createStreamChunk( + "", + sessionId, + parentInteractionId, + executorMemoryId, + executorParentInteractionId, + true + ); try { channel.sendResponseBatch(completionChunk); } catch (Exception e) { @@ -114,7 +126,7 @@ public void sendFinalResponse( public void sendToolResponse(String toolOutput, String sessionId, String parentInteractionId) { if (isStreaming) { try { - MLTaskResponse toolChunk = createStreamChunk(toolOutput, sessionId, parentInteractionId, false); + MLTaskResponse toolChunk = createStreamChunk(toolOutput, sessionId, parentInteractionId, null, null, false); channel.sendResponseBatch(toolChunk); } catch (Exception e) { log.error("Failed to send tool response chunk", e); @@ -122,12 +134,21 @@ public void sendToolResponse(String toolOutput, String sessionId, String parentI } } - private MLTaskResponse createStreamChunk(String toolOutput, String sessionId, String parentInteractionId, boolean isLast) { + private MLTaskResponse createStreamChunk( + String toolOutput, + String sessionId, + String parentInteractionId, + String executorMemoryId, + String executorParentInteractionId, + boolean isLast + ) { List tensors = Arrays .asList( ModelTensor.builder().name("response").dataAsMap(Map.of("content", toolOutput, "is_last", isLast)).build(), ModelTensor.builder().name("memory_id").result(sessionId).build(), - ModelTensor.builder().name("parent_interaction_id").result(parentInteractionId).build() + ModelTensor.builder().name("parent_interaction_id").result(parentInteractionId).build(), + ModelTensor.builder().name("executor_agent_memory_id").result(executorMemoryId).build(), + ModelTensor.builder().name("executor_agent_parent_interaction_id").result(executorParentInteractionId).build() ); ModelTensors modelTensors = ModelTensors.builder().mlModelTensors(tensors).build(); diff --git a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/remote/streaming/BedrockStreamingHandler.java b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/remote/streaming/BedrockStreamingHandler.java index 0ec9cce537..583c10bf0b 100644 --- a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/remote/streaming/BedrockStreamingHandler.java +++ b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/remote/streaming/BedrockStreamingHandler.java @@ -95,6 +95,8 @@ public void startStream( AtomicReference toolUseId = new AtomicReference<>(); StringBuilder toolInputAccumulator = new StringBuilder(); AtomicReference currentState = new AtomicReference<>(StreamState.STREAMING_CONTENT); + String agentType = parameters.get("agent_type"); + StringBuilder accumulatedContent = new StringBuilder(); // Build Bedrock client BedrockRuntimeAsyncClient bedrockClient = buildBedrockRuntimeAsyncClient(); @@ -128,7 +130,18 @@ public void startStream( log.debug("Tool execution in progress - keeping stream open"); } }).subscriber(event -> { - handleStreamEvent(event, listener, isStreamClosed, toolName, toolInput, toolUseId, toolInputAccumulator, currentState); + handleStreamEvent( + event, + listener, + isStreamClosed, + toolName, + toolInput, + toolUseId, + toolInputAccumulator, + currentState, + agentType, + accumulatedContent + ); }).build(); // Start streaming @@ -183,7 +196,9 @@ private void handleStreamEvent( AtomicReference> toolInput, AtomicReference toolUseId, StringBuilder toolInputAccumulator, - AtomicReference currentState + AtomicReference currentState, + String agentType, + StringBuilder accumulatedContent ) { switch (currentState.get()) { case STREAMING_CONTENT: @@ -191,10 +206,18 @@ private void handleStreamEvent( currentState.set(StreamState.TOOL_CALL_DETECTED); extractToolInfo(event, toolName, toolUseId); } else if (isContentDelta(event)) { - sendContentResponse(getTextContent(event), false, listener); + String content = getTextContent(event); + accumulatedContent.append(content); + sendContentResponse(content, false, listener); } else if (isStreamComplete(event)) { - currentState.set(StreamState.COMPLETED); - sendCompletionResponse(isStreamClosed, listener); + // For PER agent, we should keep the connection open after the planner LLM finish + if ("per".equals(agentType)) { + currentState.set(StreamState.WAITING_FOR_TOOL_RESULT); + sendPlannerResponse(false, listener, String.valueOf(accumulatedContent)); + } else { + currentState.set(StreamState.COMPLETED); + sendCompletionResponse(isStreamClosed, listener); + } } break; @@ -225,6 +248,26 @@ private void handleStreamEvent( } } + private void sendPlannerResponse( + boolean isStreamClosed, + StreamPredictActionListener listener, + String plannerContent + ) { + if (!isStreamClosed) { + Map responseMap = new HashMap<>(); + responseMap.put("output", Map.of("message", Map.of("content", List.of(Map.of("text", plannerContent))))); + + ModelTensor tensor = ModelTensor.builder().name("response").dataAsMap(responseMap).build(); + + ModelTensors tensors = ModelTensors.builder().mlModelTensors(List.of(tensor)).build(); + + ModelTensorOutput output = ModelTensorOutput.builder().mlModelOutputs(List.of(tensors)).build(); + + listener.onResponse(MLTaskResponse.builder().output(output).build()); + log.debug("Sent planner response for PER agent"); + } + } + // TODO: refactor the event type checker methods private void extractToolInfo(ConverseStreamOutput event, AtomicReference toolName, AtomicReference toolUseId) { ContentBlockStartEvent startEvent = (ContentBlockStartEvent) event; diff --git a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/remote/streaming/HttpStreamingHandler.java b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/remote/streaming/HttpStreamingHandler.java index 15078dfccb..2c91e3e353 100644 --- a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/remote/streaming/HttpStreamingHandler.java +++ b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/remote/streaming/HttpStreamingHandler.java @@ -76,7 +76,8 @@ public void startStream( ) { try { log.info("Creating SSE connection for streaming request"); - EventSourceListener listener = new HTTPEventSourceListener(actionListener, llmInterface); + String agentType = parameters.get("agent_type"); + EventSourceListener listener = new HTTPEventSourceListener(actionListener, llmInterface, agentType); Request request = ConnectorUtils.buildOKHttpStreamingRequest(action, connector, parameters, payload); AccessController.doPrivileged((PrivilegedExceptionAction) () -> { @@ -100,16 +101,23 @@ public final class HTTPEventSourceListener extends EventSourceListener { private StreamPredictActionListener streamActionListener; private final String llmInterface; private AtomicBoolean isStreamClosed; + private final String agentType; private boolean functionCallInProgress = false; private boolean agentExecutionInProgress = false; private String accumulatedToolCallId = null; private String accumulatedToolName = null; private String accumulatedArguments = ""; + private StringBuilder accumulatedContent = new StringBuilder(); - public HTTPEventSourceListener(StreamPredictActionListener streamActionListener, String llmInterface) { + public HTTPEventSourceListener( + StreamPredictActionListener streamActionListener, + String llmInterface, + String agentType + ) { this.streamActionListener = streamActionListener; this.llmInterface = llmInterface; this.isStreamClosed = new AtomicBoolean(false); + this.agentType = agentType; } /*** @@ -206,14 +214,20 @@ private void processStreamChunk(Map dataMap) { // Handle stop finish reason String finishReason = extractPath(dataMap, "$.choices[0].finish_reason"); if ("stop".equals(finishReason)) { - agentExecutionInProgress = false; - sendCompletionResponse(isStreamClosed, streamActionListener); + // For PER agent, we should keep the connection open after the planner LLM finish + if ("per".equals(agentType)) { + completePlannerResponse(); + } else { + agentExecutionInProgress = false; + sendCompletionResponse(isStreamClosed, streamActionListener); + } return; } // Process content String content = extractPath(dataMap, "$.choices[0].delta.content"); if (content != null && !content.isEmpty()) { + accumulatedContent.append(content); sendContentResponse(content, false, streamActionListener); } @@ -268,6 +282,19 @@ private ModelTensorOutput createModelTensorOutput(Map responseDa return ModelTensorOutput.builder().mlModelOutputs(List.of(tensors)).build(); } + private void completePlannerResponse() { + String fullContent = accumulatedContent.toString().trim(); + + // Create compatible response format + Map message = Map.of("content", fullContent); + Map choice = Map.of("message", message); + Map response = Map.of("choices", List.of(choice)); + + ModelTensorOutput output = createModelTensorOutput(response); + streamActionListener.onResponse(new MLTaskResponse(output)); + agentExecutionInProgress = true; + } + private void accumulateFunctionCall(List toolCalls) { functionCallInProgress = true; for (Object toolCall : toolCalls) { diff --git a/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/agent/MLChatAgentRunnerTest.java b/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/agent/MLChatAgentRunnerTest.java index f6c3e3618e..c63db9df4f 100644 --- a/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/agent/MLChatAgentRunnerTest.java +++ b/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/agent/MLChatAgentRunnerTest.java @@ -710,7 +710,7 @@ public void testToolParameters() { // Verify the size of parameters passed in the tool run method. ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(Map.class); verify(firstTool).run((Map) argumentCaptor.capture(), any()); - assertEquals(15, ((Map) argumentCaptor.getValue()).size()); + assertEquals(16, ((Map) argumentCaptor.getValue()).size()); Mockito.verify(agentActionListener).onResponse(objectCaptor.capture()); ModelTensorOutput modelTensorOutput = (ModelTensorOutput) objectCaptor.getValue(); @@ -738,7 +738,7 @@ public void testToolUseOriginalInput() { // Verify the size of parameters passed in the tool run method. ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(Map.class); verify(firstTool).run((Map) argumentCaptor.capture(), any()); - assertEquals(16, ((Map) argumentCaptor.getValue()).size()); + assertEquals(17, ((Map) argumentCaptor.getValue()).size()); assertEquals("raw input", ((Map) argumentCaptor.getValue()).get("input")); Mockito.verify(agentActionListener).onResponse(objectCaptor.capture()); @@ -804,7 +804,7 @@ public void testToolConfig() { // Verify the size of parameters passed in the tool run method. ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(Map.class); verify(firstTool).run((Map) argumentCaptor.capture(), any()); - assertEquals(16, ((Map) argumentCaptor.getValue()).size()); + assertEquals(17, ((Map) argumentCaptor.getValue()).size()); // The value of input should be "config_value". assertEquals("config_value", ((Map) argumentCaptor.getValue()).get("input")); @@ -834,7 +834,7 @@ public void testToolConfigWithInputPlaceholder() { // Verify the size of parameters passed in the tool run method. ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(Map.class); verify(firstTool).run((Map) argumentCaptor.capture(), any()); - assertEquals(16, ((Map) argumentCaptor.getValue()).size()); + assertEquals(17, ((Map) argumentCaptor.getValue()).size()); // The value of input should be replaced with the value associated with the key "key2" of the first tool. assertEquals("value2", ((Map) argumentCaptor.getValue()).get("input")); diff --git a/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/agent/MLPlanExecuteAndReflectAgentRunnerTest.java b/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/agent/MLPlanExecuteAndReflectAgentRunnerTest.java index 7ed4e91b1c..7095c7e110 100644 --- a/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/agent/MLPlanExecuteAndReflectAgentRunnerTest.java +++ b/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/agent/MLPlanExecuteAndReflectAgentRunnerTest.java @@ -11,6 +11,7 @@ import static org.junit.Assert.assertNull; import static org.junit.Assert.assertThrows; import static org.junit.Assert.assertTrue; +import static org.junit.Assert.fail; import static org.mockito.ArgumentMatchers.any; import static org.mockito.ArgumentMatchers.anyInt; import static org.mockito.ArgumentMatchers.anyString; @@ -115,6 +116,8 @@ public class MLPlanExecuteAndReflectAgentRunnerTest extends MLStaticMockBase { private MLTaskResponse mlTaskResponse; @Mock private MLExecuteTaskResponse mlExecuteTaskResponse; + @Mock + private StreamingWrapper streamingWrapper; @Captor private ArgumentCaptor objectCaptor; @@ -174,6 +177,15 @@ public void setup() { encryptor ); + // Set streaming wrapper + try { + java.lang.reflect.Field streamingWrapperField = MLPlanExecuteAndReflectAgentRunner.class.getDeclaredField("streamingWrapper"); + streamingWrapperField.setAccessible(true); + streamingWrapperField.set(mlPlanExecuteAndReflectAgentRunner, streamingWrapper); + } catch (Exception e) { + fail("Exception thrown: " + e.getMessage()); + } + // Setup tools when(firstToolFactory.create(any())).thenReturn(firstTool); when(secondToolFactory.create(any())).thenReturn(secondTool); diff --git a/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/agent/StreamingWrapperTest.java b/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/agent/StreamingWrapperTest.java index 79fba06c5b..18e7cbd2db 100644 --- a/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/agent/StreamingWrapperTest.java +++ b/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/agent/StreamingWrapperTest.java @@ -159,7 +159,7 @@ public void testExecuteRequestNonStreaming() { @Test public void testSendCompletionChunkStreaming() throws Exception { - streamingWrapper.sendCompletionChunk("session1", "parent1"); + streamingWrapper.sendCompletionChunk("session1", "parent1", "executeMemory", "executeParent"); ArgumentCaptor responseCaptor = ArgumentCaptor.forClass(MLTaskResponse.class); verify(channel).sendResponseBatch(responseCaptor.capture()); @@ -170,7 +170,7 @@ public void testSendCompletionChunkStreaming() throws Exception { @Test public void testSendCompletionChunkNonStreaming() throws Exception { - nonStreamingWrapper.sendCompletionChunk("session1", "parent1"); + nonStreamingWrapper.sendCompletionChunk("session1", "parent1", "executeMemory", "executeParent"); verify(channel, never()).sendResponseBatch(any()); } @@ -180,7 +180,7 @@ public void testSendCompletionChunkWithException() throws Exception { doThrow(new RuntimeException("Channel error")).when(channel).sendResponseBatch(any()); // Should not throw exception, just log warning - streamingWrapper.sendCompletionChunk("session1", "parent1"); + streamingWrapper.sendCompletionChunk("session1", "parent1", "executeMemory", "executeParent"); verify(channel).sendResponseBatch(any()); } @@ -236,7 +236,7 @@ public void testSendToolResponseWithException() throws Exception { @Test public void testCreateStreamChunkStructure() throws Exception { - streamingWrapper.sendCompletionChunk("test-session", "test-parent"); + streamingWrapper.sendCompletionChunk("test-session", "test-parent", "executeMemory", "executeParent"); ArgumentCaptor responseCaptor = ArgumentCaptor.forClass(MLTaskResponse.class); verify(channel).sendResponseBatch(responseCaptor.capture()); @@ -245,7 +245,7 @@ public void testCreateStreamChunkStructure() throws Exception { ModelTensorOutput output = (ModelTensorOutput) response.getOutput(); List tensors = output.getMlModelOutputs().get(0).getMlModelTensors(); - assertEquals(3, tensors.size()); + assertEquals(5, tensors.size()); // Find specific tensors by name ModelTensor memoryTensor = tensors.stream().filter(t -> "memory_id".equals(t.getName())).findFirst().orElse(null); diff --git a/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/remote/streaming/BedrockStreamingHandlerTest.java b/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/remote/streaming/BedrockStreamingHandlerTest.java new file mode 100644 index 0000000000..5ebdefa55d --- /dev/null +++ b/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/remote/streaming/BedrockStreamingHandlerTest.java @@ -0,0 +1,80 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.ml.engine.algorithms.remote.streaming; + +import static org.junit.Assert.assertNotNull; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.Mockito.doAnswer; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.when; + +import java.util.HashMap; +import java.util.Map; + +import org.junit.Before; +import org.junit.Test; +import org.mockito.Mock; +import org.mockito.MockitoAnnotations; +import org.opensearch.ml.common.connector.AwsConnector; +import org.opensearch.ml.common.exception.MLException; +import org.opensearch.ml.common.transport.MLTaskResponse; + +import software.amazon.awssdk.http.async.SdkAsyncHttpClient; + +public class BedrockStreamingHandlerTest { + + @Mock + private SdkAsyncHttpClient httpClient; + @Mock + private AwsConnector connector; + @Mock + private StreamPredictActionListener actionListener; + + private BedrockStreamingHandler bedrockStreamingHandler; + + @Before + public void setup() { + MockitoAnnotations.openMocks(this); + + when(connector.getRegion()).thenReturn("us-east-1"); + when(connector.getAccessKey()).thenReturn("test-access-key"); + when(connector.getSecretKey()).thenReturn("test-secret-key"); + + bedrockStreamingHandler = new BedrockStreamingHandler(httpClient, connector); + } + + @Test + public void testConstructor() { + assertNotNull(bedrockStreamingHandler); + } + + @Test + public void testHandleError() { + Exception testException = new RuntimeException("Test error"); + + doAnswer(invocation -> { + MLException exception = invocation.getArgument(0); + assertNotNull(exception); + return null; + }).when(actionListener).onFailure(any(MLException.class)); + + bedrockStreamingHandler.handleError(testException, actionListener); + verify(actionListener).onFailure(any(MLException.class)); + } + + @Test + public void testStartStreamInvalidPayload() { + Map parameters = new HashMap<>(); + parameters.put("model", "test-model"); + parameters.put("agent_type", "test"); + + String invalidPayload = "invalid json"; + + bedrockStreamingHandler.startStream("test_action", parameters, invalidPayload, actionListener); + + verify(actionListener).onFailure(any(MLException.class)); + } +} diff --git a/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/remote/streaming/HttpStreamingHandlerTest.java b/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/remote/streaming/HttpStreamingHandlerTest.java new file mode 100644 index 0000000000..7be599118d --- /dev/null +++ b/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/remote/streaming/HttpStreamingHandlerTest.java @@ -0,0 +1,105 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.ml.engine.algorithms.remote.streaming; + +import static org.junit.Assert.assertNotNull; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.Mockito.doAnswer; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.when; +import static org.opensearch.ml.engine.algorithms.agent.AgentUtils.LLM_INTERFACE_OPENAI_V1_CHAT_COMPLETIONS; + +import java.util.HashMap; +import java.util.Map; + +import org.junit.Before; +import org.junit.Test; +import org.mockito.Mock; +import org.mockito.MockitoAnnotations; +import org.opensearch.ml.common.connector.Connector; +import org.opensearch.ml.common.connector.ConnectorClientConfig; +import org.opensearch.ml.common.exception.MLException; +import org.opensearch.ml.common.transport.MLTaskResponse; + +public class HttpStreamingHandlerTest { + + @Mock + private Connector connector; + @Mock + private ConnectorClientConfig connectorClientConfig; + @Mock + private StreamPredictActionListener actionListener; + + private HttpStreamingHandler httpStreamingHandler; + + @Before + public void setup() { + MockitoAnnotations.openMocks(this); + + when(connectorClientConfig.getConnectionTimeout()).thenReturn(30); + when(connectorClientConfig.getReadTimeout()).thenReturn(30); + + httpStreamingHandler = new HttpStreamingHandler(LLM_INTERFACE_OPENAI_V1_CHAT_COMPLETIONS, connector, connectorClientConfig); + } + + @Test + public void testConstructor() { + assertNotNull(httpStreamingHandler); + } + + @Test + public void testHandleError() { + Exception testException = new RuntimeException("Test error"); + + doAnswer(invocation -> { + MLException exception = invocation.getArgument(0); + assertNotNull(exception); + return null; + }).when(actionListener).onFailure(any(MLException.class)); + + httpStreamingHandler.handleError(testException, actionListener); + verify(actionListener).onFailure(any(MLException.class)); + } + + @Test + public void testStartStreamWithException() { + Map parameters = new HashMap<>(); + parameters.put("agent_type", "test"); + + when(connector.getActions()).thenReturn(null); + httpStreamingHandler.startStream("test_action", parameters, "test_payload", actionListener); + + verify(actionListener).onFailure(any(MLException.class)); + } + + @Test + public void testHTTPEventSourceListenerConstructor() { + HttpStreamingHandler.HTTPEventSourceListener listener = httpStreamingHandler.new HTTPEventSourceListener( + actionListener, LLM_INTERFACE_OPENAI_V1_CHAT_COMPLETIONS, "test_agent" + ); + assertNotNull(listener); + } + + @Test + public void testHTTPEventSourceListenerOnFailureWithThrowable() { + HttpStreamingHandler.HTTPEventSourceListener listener = httpStreamingHandler.new HTTPEventSourceListener( + actionListener, LLM_INTERFACE_OPENAI_V1_CHAT_COMPLETIONS, "test_agent" + ); + + RuntimeException testException = new RuntimeException("Test error"); + listener.onFailure(null, testException, null); + + verify(actionListener).onFailure(any(MLException.class)); + } + + @Test(expected = IllegalArgumentException.class) + public void testUnsupportedLLMInterface() { + HttpStreamingHandler.HTTPEventSourceListener listener = httpStreamingHandler.new HTTPEventSourceListener( + actionListener, "unsupported_interface", "test_agent" + ); + listener.onEvent(null, null, null, "test data"); + } +} diff --git a/plugin/src/main/java/org/opensearch/ml/rest/RestMLExecuteStreamAction.java b/plugin/src/main/java/org/opensearch/ml/rest/RestMLExecuteStreamAction.java index 4a07b79cab..8e84b499e7 100644 --- a/plugin/src/main/java/org/opensearch/ml/rest/RestMLExecuteStreamAction.java +++ b/plugin/src/main/java/org/opensearch/ml/rest/RestMLExecuteStreamAction.java @@ -20,6 +20,7 @@ import java.io.ByteArrayOutputStream; import java.io.IOException; import java.nio.ByteBuffer; +import java.util.ArrayList; import java.util.LinkedHashMap; import java.util.List; import java.util.Locale; @@ -346,21 +347,33 @@ private HttpChunk convertToHttpChunk(MLTaskResponse response) throws IOException // Regular response - extract values and build proper structure String memoryId = extractTensorResult(response, "memory_id"); String parentInteractionId = extractTensorResult(response, "parent_interaction_id"); + String executorMemoryId = extractTensorResult(response, "executor_agent_memory_id"); + String executorParentInteractionId = extractTensorResult(response, "executor_agent_parent_interaction_id"); String content = dataMap.containsKey("content") ? (String) dataMap.get("content") : ""; isLast = dataMap.containsKey("is_last") ? Boolean.TRUE.equals(dataMap.get("is_last")) : false; boolean finalIsLast = isLast; - List orderedTensors = List - .of( - ModelTensor.builder().name("memory_id").result(memoryId).build(), - ModelTensor.builder().name("parent_interaction_id").result(parentInteractionId).build(), - ModelTensor.builder().name("response").dataAsMap(new LinkedHashMap() { - { - put("content", content); - put("is_last", finalIsLast); - } - }).build() - ); + List orderedTensors = new ArrayList<>(); + orderedTensors.add(ModelTensor.builder().name("memory_id").result(memoryId).build()); + orderedTensors.add(ModelTensor.builder().name("parent_interaction_id").result(parentInteractionId).build()); + + if (executorMemoryId != null && !executorMemoryId.isEmpty()) { + orderedTensors.add(ModelTensor.builder().name("executor_agent_memory_id").result(executorMemoryId).build()); + } + + if (executorParentInteractionId != null && !executorParentInteractionId.isEmpty()) { + orderedTensors + .add( + ModelTensor.builder().name("executor_agent_parent_interaction_id").result(executorParentInteractionId).build() + ); + } + + orderedTensors.add(ModelTensor.builder().name("response").dataAsMap(new LinkedHashMap() { + { + put("content", content); + put("is_last", finalIsLast); + } + }).build()); ModelTensors tensors = ModelTensors.builder().mlModelTensors(orderedTensors).build(); ModelTensorOutput tensorOutput = ModelTensorOutput.builder().mlModelOutputs(List.of(tensors)).build();