Skip to content

Commit caa6d57

Browse files
authored
Merge branch 'opensearch-project:main' into feature/context_manager_hooks_inline_create
2 parents e1bc0e0 + 7d25d56 commit caa6d57

File tree

84 files changed

+3234
-618
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

84 files changed

+3234
-618
lines changed

.github/workflows/maven-publish.yml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,8 +5,8 @@ on:
55
push:
66
branches:
77
- main
8-
- 1.*
9-
- 2.*
8+
- '[0-9]+.[0-9]+'
9+
- '[0-9]+.x'
1010

1111
jobs:
1212
build-and-publish-snapshots:

build.gradle

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -96,6 +96,15 @@ subprojects {
9696
// Force spotless depending on newer version of guava due to CVE-2023-2976. Remove after spotless upgrades.
9797
resolutionStrategy.force "com.google.guava:guava:32.1.3-jre"
9898
resolutionStrategy.force 'org.apache.commons:commons-compress:1.26.0'
99+
resolutionStrategy.force "io.netty:netty-buffer:${versions.netty}"
100+
resolutionStrategy.force "io.netty:netty-codec:${versions.netty}"
101+
resolutionStrategy.force "io.netty:netty-codec-http:${versions.netty}"
102+
resolutionStrategy.force "io.netty:netty-codec-http2:${versions.netty}"
103+
resolutionStrategy.force "io.netty:netty-common:${versions.netty}"
104+
resolutionStrategy.force "io.netty:netty-handler:${versions.netty}"
105+
resolutionStrategy.force "io.netty:netty-resolver:${versions.netty}"
106+
resolutionStrategy.force "io.netty:netty-transport:${versions.netty}"
107+
resolutionStrategy.force "io.netty:netty-transport-native-unix-common:${versions.netty}"
99108
}
100109

101110
apply plugin: 'com.diffplug.spotless'

common/src/main/java/org/opensearch/ml/common/connector/MLPostProcessFunction.java

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@ public class MLPostProcessFunction {
3131
public static final String BEDROCK_EMBEDDING = "connector.post_process.bedrock.embedding";
3232
public static final String BEDROCK_V2_EMBEDDING_FLOAT = "connector.post_process.bedrock_v2.embedding.float";
3333
public static final String BEDROCK_V2_EMBEDDING_BINARY = "connector.post_process.bedrock_v2.embedding.binary";
34+
public static final String BEDROCK_NOVA_EMBEDDING = "connector.post_process.bedrock.nova.embedding";
3435
public static final String BEDROCK_BATCH_JOB_ARN = "connector.post_process.bedrock.batch_job_arn";
3536
public static final String COHERE_RERANK = "connector.post_process.cohere.rerank";
3637
public static final String BEDROCK_RERANK = "connector.post_process.bedrock.rerank";
@@ -62,6 +63,7 @@ public class MLPostProcessFunction {
6263
JSON_PATH_EXPRESSION.put(BEDROCK_EMBEDDING, "$.embedding");
6364
JSON_PATH_EXPRESSION.put(BEDROCK_V2_EMBEDDING_FLOAT, "$.embeddingsByType.float");
6465
JSON_PATH_EXPRESSION.put(BEDROCK_V2_EMBEDDING_BINARY, "$.embeddingsByType.binary");
66+
JSON_PATH_EXPRESSION.put(BEDROCK_NOVA_EMBEDDING, "$.embeddings[*].embedding");
6567
JSON_PATH_EXPRESSION.put(BEDROCK_BATCH_JOB_ARN, "$");
6668
JSON_PATH_EXPRESSION.put(COHERE_RERANK, "$.results");
6769
JSON_PATH_EXPRESSION.put(BEDROCK_RERANK, "$.results");
@@ -78,6 +80,7 @@ public class MLPostProcessFunction {
7880
POST_PROCESS_FUNCTIONS.put(BEDROCK_EMBEDDING, bedrockEmbeddingPostProcessFunction);
7981
POST_PROCESS_FUNCTIONS.put(BEDROCK_V2_EMBEDDING_FLOAT, bedrockEmbeddingPostProcessFunction);
8082
POST_PROCESS_FUNCTIONS.put(BEDROCK_V2_EMBEDDING_BINARY, bedrockEmbeddingPostProcessFunction);
83+
POST_PROCESS_FUNCTIONS.put(BEDROCK_NOVA_EMBEDDING, embeddingPostProcessFunction);
8184
POST_PROCESS_FUNCTIONS.put(BEDROCK_BATCH_JOB_ARN, batchJobArnPostProcessFunction);
8285
POST_PROCESS_FUNCTIONS.put(COHERE_RERANK, cohereRerankPostProcessFunction);
8386
POST_PROCESS_FUNCTIONS.put(BEDROCK_RERANK, bedrockRerankPostProcessFunction);

common/src/main/java/org/opensearch/ml/common/connector/MLPreProcessFunction.java

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,13 +9,16 @@
99
import java.util.Map;
1010
import java.util.function.Function;
1111

12+
import org.opensearch.ml.common.connector.functions.preprocess.AudioEmbeddingPreProcessFunction;
1213
import org.opensearch.ml.common.connector.functions.preprocess.BedrockEmbeddingPreProcessFunction;
1314
import org.opensearch.ml.common.connector.functions.preprocess.BedrockRerankPreProcessFunction;
1415
import org.opensearch.ml.common.connector.functions.preprocess.CohereEmbeddingPreProcessFunction;
1516
import org.opensearch.ml.common.connector.functions.preprocess.CohereMultiModalEmbeddingPreProcessFunction;
1617
import org.opensearch.ml.common.connector.functions.preprocess.CohereRerankPreProcessFunction;
18+
import org.opensearch.ml.common.connector.functions.preprocess.ImageEmbeddingPreProcessFunction;
1719
import org.opensearch.ml.common.connector.functions.preprocess.MultiModalConnectorPreProcessFunction;
1820
import org.opensearch.ml.common.connector.functions.preprocess.OpenAIEmbeddingPreProcessFunction;
21+
import org.opensearch.ml.common.connector.functions.preprocess.VideoEmbeddingPreProcessFunction;
1922
import org.opensearch.ml.common.dataset.remote.RemoteInferenceInputDataSet;
2023
import org.opensearch.ml.common.input.MLInput;
2124

@@ -27,6 +30,10 @@ public class MLPreProcessFunction {
2730
public static final String TEXT_DOCS_TO_OPENAI_EMBEDDING_INPUT = "connector.pre_process.openai.embedding";
2831
public static final String TEXT_DOCS_TO_BEDROCK_EMBEDDING_INPUT = "connector.pre_process.bedrock.embedding";
2932
public static final String TEXT_IMAGE_TO_BEDROCK_EMBEDDING_INPUT = "connector.pre_process.bedrock.multimodal_embedding";
33+
public static final String TEXT_TO_BEDROCK_NOVA_EMBEDDING_INPUT = "connector.pre_process.bedrock.nova.text_embedding";
34+
public static final String IMAGE_TO_BEDROCK_NOVA_EMBEDDING_INPUT = "connector.pre_process.bedrock.nova.image_embedding";
35+
public static final String VIDEO_TO_BEDROCK_NOVA_EMBEDDING_INPUT = "connector.pre_process.bedrock.nova.video_embedding";
36+
public static final String AUDIO_TO_BEDROCK_NOVA_EMBEDDING_INPUT = "connector.pre_process.bedrock.nova.audio_embedding";
3037
public static final String TEXT_DOCS_TO_DEFAULT_EMBEDDING_INPUT = "connector.pre_process.default.embedding";
3138
public static final String TEXT_SIMILARITY_TO_COHERE_RERANK_INPUT = "connector.pre_process.cohere.rerank";
3239
public static final String TEXT_SIMILARITY_TO_BEDROCK_RERANK_INPUT = "connector.pre_process.bedrock.rerank";
@@ -42,11 +49,18 @@ public class MLPreProcessFunction {
4249
CohereRerankPreProcessFunction cohereRerankPreProcessFunction = new CohereRerankPreProcessFunction();
4350
BedrockRerankPreProcessFunction bedrockRerankPreProcessFunction = new BedrockRerankPreProcessFunction();
4451
MultiModalConnectorPreProcessFunction multiModalEmbeddingPreProcessFunction = new MultiModalConnectorPreProcessFunction();
52+
ImageEmbeddingPreProcessFunction imageEmbeddingPreProcessFunction = new ImageEmbeddingPreProcessFunction();
53+
VideoEmbeddingPreProcessFunction videoEmbeddingPreProcessFunction = new VideoEmbeddingPreProcessFunction();
54+
AudioEmbeddingPreProcessFunction audioEmbeddingPreProcessFunction = new AudioEmbeddingPreProcessFunction();
4555
CohereMultiModalEmbeddingPreProcessFunction cohereMultiModalEmbeddingPreProcessFunction =
4656
new CohereMultiModalEmbeddingPreProcessFunction();
4757
PRE_PROCESS_FUNCTIONS.put(TEXT_DOCS_TO_COHERE_EMBEDDING_INPUT, cohereEmbeddingPreProcessFunction);
4858
PRE_PROCESS_FUNCTIONS.put(IMAGE_TO_COHERE_MULTI_MODAL_EMBEDDING_INPUT, cohereMultiModalEmbeddingPreProcessFunction);
4959
PRE_PROCESS_FUNCTIONS.put(TEXT_IMAGE_TO_BEDROCK_EMBEDDING_INPUT, multiModalEmbeddingPreProcessFunction);
60+
PRE_PROCESS_FUNCTIONS.put(TEXT_TO_BEDROCK_NOVA_EMBEDDING_INPUT, bedrockEmbeddingPreProcessFunction);
61+
PRE_PROCESS_FUNCTIONS.put(IMAGE_TO_BEDROCK_NOVA_EMBEDDING_INPUT, imageEmbeddingPreProcessFunction);
62+
PRE_PROCESS_FUNCTIONS.put(VIDEO_TO_BEDROCK_NOVA_EMBEDDING_INPUT, videoEmbeddingPreProcessFunction);
63+
PRE_PROCESS_FUNCTIONS.put(AUDIO_TO_BEDROCK_NOVA_EMBEDDING_INPUT, audioEmbeddingPreProcessFunction);
5064
PRE_PROCESS_FUNCTIONS.put(TEXT_DOCS_TO_OPENAI_EMBEDDING_INPUT, openAIEmbeddingPreProcessFunction);
5165
PRE_PROCESS_FUNCTIONS.put(TEXT_DOCS_TO_DEFAULT_EMBEDDING_INPUT, openAIEmbeddingPreProcessFunction);
5266
PRE_PROCESS_FUNCTIONS.put(TEXT_DOCS_TO_BEDROCK_EMBEDDING_INPUT, bedrockEmbeddingPreProcessFunction);
Lines changed: 56 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,56 @@
1+
/*
2+
* Copyright OpenSearch Contributors
3+
* SPDX-License-Identifier: Apache-2.0
4+
*/
5+
6+
package org.opensearch.ml.common.connector.functions.preprocess;
7+
8+
import static org.opensearch.ml.common.utils.StringUtils.convertScriptStringToJsonString;
9+
10+
import java.util.HashMap;
11+
import java.util.List;
12+
import java.util.Map;
13+
14+
import org.opensearch.ml.common.dataset.TextDocsInputDataSet;
15+
import org.opensearch.ml.common.dataset.remote.RemoteInferenceInputDataSet;
16+
import org.opensearch.ml.common.input.MLInput;
17+
18+
/**
19+
* This class provides a pre-processing function for Bedrock Nova audio input data.
20+
* It takes an instance of {@link MLInput} as input and returns an instance of {@link RemoteInferenceInputDataSet}.
21+
* The input data is expected to be of type {@link TextDocsInputDataSet}, with document representing an audio.
22+
* The function validates the input data and then processes it to create a {@link RemoteInferenceInputDataSet} object.
23+
* If the input data is already of type {@link RemoteInferenceInputDataSet}, it is returned directly.
24+
*/
25+
public class AudioEmbeddingPreProcessFunction extends ConnectorPreProcessFunction {
26+
27+
public AudioEmbeddingPreProcessFunction() {
28+
this.returnDirectlyForRemoteInferenceInput = true;
29+
}
30+
31+
@Override
32+
public void validate(MLInput mlInput) {
33+
validateTextDocsInput(mlInput);
34+
List<String> docs = ((TextDocsInputDataSet) mlInput.getInputDataset()).getDocs();
35+
if (docs.size() == 0) {
36+
throw new IllegalArgumentException("No input audio provided");
37+
}
38+
}
39+
40+
/**
41+
* @param mlInput The input data to be processed.
42+
* This method validates the input data and then processes it to create a {@link RemoteInferenceInputDataSet} object.
43+
* If the input data is already of type {@link RemoteInferenceInputDataSet}, it is returned directly.
44+
*/
45+
@Override
46+
public RemoteInferenceInputDataSet process(MLInput mlInput) {
47+
TextDocsInputDataSet inputData = (TextDocsInputDataSet) mlInput.getInputDataset();
48+
Map<String, String> parametersMap = new HashMap<>();
49+
parametersMap.put("inputAudio", inputData.getDocs().get(0));
50+
return RemoteInferenceInputDataSet
51+
.builder()
52+
.parameters(convertScriptStringToJsonString(Map.of("parameters", parametersMap)))
53+
.build();
54+
55+
}
56+
}
Lines changed: 56 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,56 @@
1+
/*
2+
* Copyright OpenSearch Contributors
3+
* SPDX-License-Identifier: Apache-2.0
4+
*/
5+
6+
package org.opensearch.ml.common.connector.functions.preprocess;
7+
8+
import static org.opensearch.ml.common.utils.StringUtils.convertScriptStringToJsonString;
9+
10+
import java.util.HashMap;
11+
import java.util.List;
12+
import java.util.Map;
13+
14+
import org.opensearch.ml.common.dataset.TextDocsInputDataSet;
15+
import org.opensearch.ml.common.dataset.remote.RemoteInferenceInputDataSet;
16+
import org.opensearch.ml.common.input.MLInput;
17+
18+
/**
19+
* This class provides a pre-processing function for Bedrock Nova image input data.
20+
* It takes an instance of {@link MLInput} as input and returns an instance of {@link RemoteInferenceInputDataSet}.
21+
* The input data is expected to be of type {@link TextDocsInputDataSet}, with document representing an image.
22+
* The function validates the input data and then processes it to create a {@link RemoteInferenceInputDataSet} object.
23+
* If the input data is already of type {@link RemoteInferenceInputDataSet}, it is returned directly.
24+
*/
25+
public class ImageEmbeddingPreProcessFunction extends ConnectorPreProcessFunction {
26+
27+
public ImageEmbeddingPreProcessFunction() {
28+
this.returnDirectlyForRemoteInferenceInput = true;
29+
}
30+
31+
@Override
32+
public void validate(MLInput mlInput) {
33+
validateTextDocsInput(mlInput);
34+
List<String> docs = ((TextDocsInputDataSet) mlInput.getInputDataset()).getDocs();
35+
if (docs.size() == 0) {
36+
throw new IllegalArgumentException("No input image provided");
37+
}
38+
}
39+
40+
/**
41+
* @param mlInput The input data to be processed.
42+
* This method validates the input data and then processes it to create a {@link RemoteInferenceInputDataSet} object.
43+
* If the input data is already of type {@link RemoteInferenceInputDataSet}, it is returned directly.
44+
*/
45+
@Override
46+
public RemoteInferenceInputDataSet process(MLInput mlInput) {
47+
TextDocsInputDataSet inputData = (TextDocsInputDataSet) mlInput.getInputDataset();
48+
Map<String, String> parametersMap = new HashMap<>();
49+
parametersMap.put("inputImage", inputData.getDocs().get(0));
50+
return RemoteInferenceInputDataSet
51+
.builder()
52+
.parameters(convertScriptStringToJsonString(Map.of("parameters", parametersMap)))
53+
.build();
54+
55+
}
56+
}
Lines changed: 56 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,56 @@
1+
/*
2+
* Copyright OpenSearch Contributors
3+
* SPDX-License-Identifier: Apache-2.0
4+
*/
5+
6+
package org.opensearch.ml.common.connector.functions.preprocess;
7+
8+
import static org.opensearch.ml.common.utils.StringUtils.convertScriptStringToJsonString;
9+
10+
import java.util.HashMap;
11+
import java.util.List;
12+
import java.util.Map;
13+
14+
import org.opensearch.ml.common.dataset.TextDocsInputDataSet;
15+
import org.opensearch.ml.common.dataset.remote.RemoteInferenceInputDataSet;
16+
import org.opensearch.ml.common.input.MLInput;
17+
18+
/**
19+
* This class provides a pre-processing function for Bedrock Nova video input data.
20+
* It takes an instance of {@link MLInput} as input and returns an instance of {@link RemoteInferenceInputDataSet}.
21+
* The input data is expected to be of type {@link TextDocsInputDataSet}, with document representing a video.
22+
* The function validates the input data and then processes it to create a {@link RemoteInferenceInputDataSet} object.
23+
* If the input data is already of type {@link RemoteInferenceInputDataSet}, it is returned directly.
24+
*/
25+
public class VideoEmbeddingPreProcessFunction extends ConnectorPreProcessFunction {
26+
27+
public VideoEmbeddingPreProcessFunction() {
28+
this.returnDirectlyForRemoteInferenceInput = true;
29+
}
30+
31+
@Override
32+
public void validate(MLInput mlInput) {
33+
validateTextDocsInput(mlInput);
34+
List<String> docs = ((TextDocsInputDataSet) mlInput.getInputDataset()).getDocs();
35+
if (docs.size() == 0) {
36+
throw new IllegalArgumentException("No input video provided");
37+
}
38+
}
39+
40+
/**
41+
* @param mlInput The input data to be processed.
42+
* This method validates the input data and then processes it to create a {@link RemoteInferenceInputDataSet} object.
43+
* If the input data is already of type {@link RemoteInferenceInputDataSet}, it is returned directly.
44+
*/
45+
@Override
46+
public RemoteInferenceInputDataSet process(MLInput mlInput) {
47+
TextDocsInputDataSet inputData = (TextDocsInputDataSet) mlInput.getInputDataset();
48+
Map<String, String> parametersMap = new HashMap<>();
49+
parametersMap.put("inputVideo", inputData.getDocs().get(0));
50+
return RemoteInferenceInputDataSet
51+
.builder()
52+
.parameters(convertScriptStringToJsonString(Map.of("parameters", parametersMap)))
53+
.build();
54+
55+
}
56+
}

common/src/main/java/org/opensearch/ml/common/httpclient/MLHttpClientFactory.java

Lines changed: 24 additions & 82 deletions
Original file line numberDiff line numberDiff line change
@@ -5,15 +5,9 @@
55

66
package org.opensearch.ml.common.httpclient;
77

8-
import java.net.Inet4Address;
9-
import java.net.InetAddress;
10-
import java.net.UnknownHostException;
11-
import java.time.Duration;
12-
import java.util.Arrays;
13-
import java.util.Locale;
14-
import java.util.concurrent.atomic.AtomicBoolean;
8+
import static org.opensearch.secure_sm.AccessController.doPrivileged;
159

16-
import org.opensearch.common.util.concurrent.ThreadContextAccess;
10+
import java.time.Duration;
1711

1812
import lombok.extern.log4j.Log4j2;
1913
import software.amazon.awssdk.http.async.SdkAsyncHttpClient;
@@ -22,79 +16,27 @@
2216
@Log4j2
2317
public class MLHttpClientFactory {
2418

25-
public static SdkAsyncHttpClient getAsyncHttpClient(Duration connectionTimeout, Duration readTimeout, int maxConnections) {
26-
return ThreadContextAccess
27-
.doPrivileged(
28-
() -> NettyNioAsyncHttpClient
29-
.builder()
30-
.connectionTimeout(connectionTimeout)
31-
.readTimeout(readTimeout)
32-
.maxConcurrency(maxConnections)
33-
.build()
34-
);
35-
}
36-
37-
/**
38-
* Validate the input parameters, such as protocol, host and port.
39-
* @param protocol The protocol supported in remote inference, currently only http and https are supported.
40-
* @param host The host name of the remote inference server, host must be a valid ip address or domain name and must not be localhost.
41-
* @param port The port number of the remote inference server, port number must be in range [0, 65536].
42-
* @param connectorPrivateIpEnabled The port number of the remote inference server, port number must be in range [0, 65536].
43-
* @throws UnknownHostException Allow to use private IP or not.
44-
*/
45-
public static void validate(String protocol, String host, int port, AtomicBoolean connectorPrivateIpEnabled)
46-
throws UnknownHostException {
47-
if (protocol != null && !"http".equalsIgnoreCase(protocol) && !"https".equalsIgnoreCase(protocol)) {
48-
log.error("Remote inference protocol is not http or https: {}", protocol);
49-
throw new IllegalArgumentException("Protocol is not http or https: " + protocol);
50-
}
51-
// When port is not specified, the default port is -1, and we need to set it to 80 or 443 based on protocol.
52-
if (port == -1) {
53-
if (protocol == null || "http".equals(protocol.toLowerCase(Locale.getDefault()))) {
54-
port = 80;
55-
} else {
56-
port = 443;
57-
}
58-
}
59-
if (port < 0 || port > 65536) {
60-
log.error("Remote inference port out of range: {}", port);
61-
throw new IllegalArgumentException("Port out of range: " + port);
62-
}
63-
validateIp(host, connectorPrivateIpEnabled);
64-
}
65-
66-
private static void validateIp(String hostName, AtomicBoolean connectorPrivateIpEnabled) throws UnknownHostException {
67-
InetAddress[] addresses = InetAddress.getAllByName(hostName);
68-
if ((connectorPrivateIpEnabled == null || !connectorPrivateIpEnabled.get()) && hasPrivateIpAddress(addresses)) {
69-
log.error("Remote inference host name has private ip address: {}", hostName);
70-
throw new IllegalArgumentException("Remote inference host name has private ip address: " + hostName);
71-
}
72-
}
73-
74-
private static boolean hasPrivateIpAddress(InetAddress[] ipAddress) {
75-
for (InetAddress ip : ipAddress) {
76-
if (ip instanceof Inet4Address) {
77-
byte[] bytes = ip.getAddress();
78-
if (bytes.length != 4) {
79-
return true;
80-
} else {
81-
if (isPrivateIPv4(bytes)) {
82-
return true;
83-
}
84-
}
85-
}
86-
}
87-
return Arrays.stream(ipAddress).anyMatch(x -> x.isSiteLocalAddress() || x.isLoopbackAddress() || x.isAnyLocalAddress());
88-
}
89-
90-
private static boolean isPrivateIPv4(byte[] bytes) {
91-
int first = bytes[0] & 0xff;
92-
int second = bytes[1] & 0xff;
93-
94-
// 127.0.0.1, 10.x.x.x, 172.16-31.x.x, 192.168.x.x, 169.254.x.x
95-
return (first == 10)
96-
|| (first == 172 && second >= 16 && second <= 31)
97-
|| (first == 192 && second == 168)
98-
|| (first == 169 && second == 254);
19+
public static SdkAsyncHttpClient getAsyncHttpClient(
20+
Duration connectionTimeout,
21+
Duration readTimeout,
22+
int maxConnections,
23+
boolean connectorPrivateIpEnabled
24+
) {
25+
return doPrivileged(() -> {
26+
log
27+
.debug(
28+
"Creating MLHttpClient with connectionTimeout: {}, readTimeout: {}, maxConnections: {}",
29+
connectionTimeout,
30+
readTimeout,
31+
maxConnections
32+
);
33+
SdkAsyncHttpClient delegate = NettyNioAsyncHttpClient
34+
.builder()
35+
.connectionTimeout(connectionTimeout)
36+
.readTimeout(readTimeout)
37+
.maxConcurrency(maxConnections)
38+
.build();
39+
return new MLValidatableAsyncHttpClient(delegate, connectorPrivateIpEnabled);
40+
});
9941
}
10042
}

0 commit comments

Comments
 (0)