Skip to content

Commit fa709a8

Browse files
committed
add bedrock support
Signed-off-by: Jiaping Zeng <[email protected]>
1 parent b0679a1 commit fa709a8

File tree

9 files changed

+440
-268
lines changed

9 files changed

+440
-268
lines changed

common/src/main/java/org/opensearch/ml/common/transport/execute/MLExecuteTaskResponse.java

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,8 @@
2626
import org.opensearch.ml.common.output.model.ModelTensorOutput;
2727
import org.opensearch.ml.common.output.model.ModelTensors;
2828

29+
import com.google.gson.Gson;
30+
2931
import lombok.Builder;
3032
import lombok.Getter;
3133
import lombok.NonNull;
@@ -103,7 +105,7 @@ private AGUIOutput extractAGUIOutput(ModelTensorOutput modelOutput) {
103105
try {
104106
ModelTensors modelTensors = modelOutput.getMlModelOutputs().get(0);
105107
String eventsJson = modelTensors.getMlModelTensors().get(0).getResult();
106-
com.google.gson.Gson gson = new com.google.gson.Gson();
108+
Gson gson = new Gson();
107109
List<Object> events = gson.fromJson(eventsJson, List.class);
108110
return AGUIOutput.builder().events(events).build();
109111
} catch (Exception e) {

common/src/test/java/org/opensearch/ml/common/agui/AGUIInputConverterTest.java

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,8 @@
1414
import org.opensearch.ml.common.dataset.remote.RemoteInferenceInputDataSet;
1515
import org.opensearch.ml.common.input.execute.agent.AgentMLInput;
1616

17+
import com.google.gson.JsonObject;
18+
1719
public class AGUIInputConverterTest {
1820

1921
@Test
@@ -254,7 +256,7 @@ public void testReconstructAGUIInput() {
254256
parameters.put("agui_messages", "[]");
255257
parameters.put("agui_tools", "[]");
256258

257-
com.google.gson.JsonObject result = AGUIInputConverter.reconstructAGUIInput(parameters);
259+
JsonObject result = AGUIInputConverter.reconstructAGUIInput(parameters);
258260

259261
assertNotNull("Reconstructed result should not be null", result);
260262
assertTrue("Should contain threadId", result.has("threadId"));

common/src/test/java/org/opensearch/ml/common/utils/StringUtilsTest.java

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717

1818
import java.io.IOException;
1919
import java.lang.reflect.Constructor;
20+
import java.lang.reflect.Type;
2021
import java.util.ArrayList;
2122
import java.util.Arrays;
2223
import java.util.HashMap;
@@ -1201,9 +1202,9 @@ public void testSerializeFloatNaNAndInfinity_BecomesNull_InPojo() {
12011202
@Test
12021203
public void testDeserializeScientificNotation_ToFloatAndPrimitive() {
12031204
String jsonObj = "{\"fObj\":1.23e-5}";
1204-
java.lang.reflect.Type mapType = new com.google.gson.reflect.TypeToken<java.util.Map<String, Float>>() {
1205+
Type mapType = new TypeToken<Map<String, Float>>() {
12051206
}.getType();
1206-
java.util.Map<String, Float> m = StringUtils.PLAIN_NUMBER_GSON.fromJson(jsonObj, mapType);
1207+
Map<String, Float> m = StringUtils.PLAIN_NUMBER_GSON.fromJson(jsonObj, mapType);
12071208
assertEquals(1.23e-5f, m.get("fObj"), 1e-9f);
12081209

12091210
String jsonArr = "[4.56e1]";
@@ -1215,9 +1216,9 @@ public void testDeserializeScientificNotation_ToFloatAndPrimitive() {
12151216
public void testDeserializeNullFloat_ToNull() {
12161217
String json = "{\"fObj\":null,\"fPrim\":1.0}";
12171218

1218-
java.lang.reflect.Type mapType = new TypeToken<java.util.Map<String, JsonElement>>() {
1219+
Type mapType = new TypeToken<Map<String, JsonElement>>() {
12191220
}.getType();
1220-
java.util.Map<String, JsonElement> m = StringUtils.PLAIN_NUMBER_GSON.fromJson(json, mapType);
1221+
Map<String, JsonElement> m = StringUtils.PLAIN_NUMBER_GSON.fromJson(json, mapType);
12211222

12221223
assertTrue(m.containsKey("fObj"));
12231224
assertTrue(m.get("fObj").isJsonNull());

ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/MLAGUIAgentRunner.java

Lines changed: 102 additions & 77 deletions
Large diffs are not rendered by default.

ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/MLChatAgentRunner.java

Lines changed: 64 additions & 124 deletions
Large diffs are not rendered by default.

ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/remote/AwsConnectorExecutor.java

Lines changed: 7 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,6 @@
1818
import java.util.Locale;
1919
import java.util.Map;
2020
import java.util.concurrent.CompletableFuture;
21-
import java.util.concurrent.atomic.AtomicReference;
2221

2322
import org.apache.commons.text.StringEscapeUtils;
2423
import org.apache.logging.log4j.Logger;
@@ -41,8 +40,6 @@
4140
import org.opensearch.transport.StreamTransportService;
4241
import org.opensearch.transport.client.Client;
4342

44-
import com.google.common.annotations.VisibleForTesting;
45-
4643
import lombok.Getter;
4744
import lombok.Setter;
4845
import lombok.extern.log4j.Log4j2;
@@ -73,18 +70,19 @@ public class AwsConnectorExecutor extends AbstractConnectorExecutor {
7370
@Getter
7471
private MLGuard mlGuard;
7572

76-
private final AtomicReference<SdkAsyncHttpClient> httpClientRef = new AtomicReference<>();
73+
private SdkAsyncHttpClient httpClient;
7774

7875
@Setter
7976
@Getter
8077
private StreamTransportService streamTransportService;
8178

82-
@Setter
83-
private boolean connectorPrivateIpEnabled;
84-
8579
public AwsConnectorExecutor(Connector connector) {
8680
super.initialize(connector);
8781
this.connector = (AwsConnector) connector;
82+
Duration connectionTimeout = Duration.ofSeconds(super.getConnectorClientConfig().getConnectionTimeout());
83+
Duration readTimeout = Duration.ofSeconds(super.getConnectorClientConfig().getReadTimeout());
84+
Integer maxConnection = super.getConnectorClientConfig().getMaxConnections();
85+
this.httpClient = MLHttpClientFactory.getAsyncHttpClient(connectionTimeout, readTimeout, maxConnection);
8886
}
8987

9088
@Override
@@ -131,8 +129,7 @@ public void invokeRemoteService(
131129
)
132130
)
133131
.build();
134-
AccessController
135-
.doPrivileged((PrivilegedExceptionAction<CompletableFuture<Void>>) () -> getHttpClient().execute(executeRequest));
132+
AccessController.doPrivileged((PrivilegedExceptionAction<CompletableFuture<Void>>) () -> httpClient.execute(executeRequest));
136133
} catch (RuntimeException exception) {
137134
log.error("Failed to execute {} in aws connector: {}", action, exception.getMessage(), exception);
138135
actionListener.onFailure(exception);
@@ -157,7 +154,7 @@ public void invokeRemoteServiceStream(
157154
llmInterface = StringEscapeUtils.unescapeJava(llmInterface);
158155
validateLLMInterface(llmInterface);
159156

160-
StreamingHandler handler = StreamingHandlerFactory.createHandler(llmInterface, connector, getHttpClient(), null);
157+
StreamingHandler handler = StreamingHandlerFactory.createHandler(llmInterface, connector, getHttpClient(), null, parameters);
161158
handler.startStream(action, parameters, payload, actionListener);
162159
} catch (Exception e) {
163160
log.error("Failed to execute streaming", e);
@@ -183,19 +180,4 @@ private void validateLLMInterface(String llmInterface) {
183180
throw new IllegalArgumentException(String.format("Unsupported llm interface: %s", llmInterface));
184181
}
185182
}
186-
187-
@VisibleForTesting
188-
protected SdkAsyncHttpClient getHttpClient() {
189-
if (httpClientRef.get() == null) {
190-
Duration connectionTimeout = Duration.ofSeconds(super.getConnectorClientConfig().getConnectionTimeout());
191-
Duration readTimeout = Duration.ofSeconds(super.getConnectorClientConfig().getReadTimeout());
192-
Integer maxConnection = super.getConnectorClientConfig().getMaxConnections();
193-
this.httpClientRef
194-
.compareAndSet(
195-
null,
196-
MLHttpClientFactory.getAsyncHttpClient(connectionTimeout, readTimeout, maxConnection, connectorPrivateIpEnabled)
197-
);
198-
}
199-
return httpClientRef.get();
200-
}
201183
}

0 commit comments

Comments
 (0)