Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,10 @@
import java.time.Duration;

import lombok.extern.log4j.Log4j2;
import software.amazon.awssdk.http.SdkHttpConfigurationOption;
import software.amazon.awssdk.http.async.SdkAsyncHttpClient;
import software.amazon.awssdk.http.nio.netty.NettyNioAsyncHttpClient;
import software.amazon.awssdk.utils.AttributeMap;

@Log4j2
public class MLHttpClientFactory {
Expand All @@ -20,7 +22,8 @@ public static SdkAsyncHttpClient getAsyncHttpClient(
Duration connectionTimeout,
Duration readTimeout,
int maxConnections,
boolean connectorPrivateIpEnabled
boolean connectorPrivateIpEnabled,
boolean connectorSslVerificationEnabled
) {
return doPrivileged(() -> {
log
Expand All @@ -35,7 +38,9 @@ public static SdkAsyncHttpClient getAsyncHttpClient(
.connectionTimeout(connectionTimeout)
.readTimeout(readTimeout)
.maxConcurrency(maxConnections)
.build();
.buildWithDefaults(AttributeMap.builder()
.put(SdkHttpConfigurationOption.TRUST_ALL_CERTIFICATES, !connectorSslVerificationEnabled)
.build());
return new MLValidatableAsyncHttpClient(delegate, connectorPrivateIpEnabled);
});
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -483,4 +483,8 @@ private MLCommonsSettings() {}
// Feature flag for streaming feature
public static final Setting<Boolean> ML_COMMONS_STREAM_ENABLED = Setting
.boolSetting(ML_PLUGIN_SETTING_PREFIX + "stream_enabled", false, Setting.Property.NodeScope, Setting.Property.Dynamic);

// Feature flag to enable or disable SSL verification of remote llm connectors
public static final Setting<Boolean> ML_COMMONS_CONNECTOR_SSL_VERIFICATION_ENABLED = Setting
.boolSetting(ML_PLUGIN_SETTING_PREFIX + "connector.ssl_verification_enabled", true, Setting.Property.NodeScope, Setting.Property.Dynamic);
}
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
import static org.opensearch.ml.common.settings.MLCommonsSettings.ML_COMMONS_REMOTE_INFERENCE_ENABLED;
import static org.opensearch.ml.common.settings.MLCommonsSettings.ML_COMMONS_STATIC_METRIC_COLLECTION_ENABLED;
import static org.opensearch.ml.common.settings.MLCommonsSettings.ML_COMMONS_STREAM_ENABLED;
import static org.opensearch.ml.common.settings.MLCommonsSettings.ML_COMMONS_CONNECTOR_SSL_VERIFICATION_ENABLED;

import java.util.ArrayList;
import java.util.List;
Expand Down Expand Up @@ -63,6 +64,8 @@ public class MLFeatureEnabledSetting {

private volatile Boolean isStreamEnabled;

private volatile Boolean isConnectorSslVerificationEnabled;

private final List<SettingsChangeListener> listeners = new ArrayList<>();

public MLFeatureEnabledSetting(ClusterService clusterService, Settings settings) {
Expand All @@ -83,6 +86,7 @@ public MLFeatureEnabledSetting(ClusterService clusterService, Settings settings)
isAgenticMemoryEnabled = ML_COMMONS_AGENTIC_MEMORY_ENABLED.get(settings);
isIndexInsightEnabled = ML_COMMONS_INDEX_INSIGHT_FEATURE_ENABLED.get(settings);
isStreamEnabled = ML_COMMONS_STREAM_ENABLED.get(settings);
isConnectorSslVerificationEnabled = ML_COMMONS_CONNECTOR_SSL_VERIFICATION_ENABLED.get(settings);

clusterService
.getClusterSettings()
Expand All @@ -109,6 +113,7 @@ public MLFeatureEnabledSetting(ClusterService clusterService, Settings settings)
clusterService.getClusterSettings().addSettingsUpdateConsumer(ML_COMMONS_MCP_CONNECTOR_ENABLED, it -> isMcpConnectorEnabled = it);
clusterService.getClusterSettings().addSettingsUpdateConsumer(ML_COMMONS_AGENTIC_MEMORY_ENABLED, it -> isAgenticMemoryEnabled = it);
clusterService.getClusterSettings().addSettingsUpdateConsumer(ML_COMMONS_STREAM_ENABLED, it -> isStreamEnabled = it);
clusterService.getClusterSettings().addSettingsUpdateConsumer(ML_COMMONS_CONNECTOR_SSL_VERIFICATION_ENABLED, it -> isConnectorSslVerificationEnabled = it);
clusterService
.getClusterSettings()
.addSettingsUpdateConsumer(ML_COMMONS_INDEX_INSIGHT_FEATURE_ENABLED, it -> isIndexInsightEnabled = it);
Expand Down Expand Up @@ -245,4 +250,8 @@ public boolean isIndexInsightEnabled() {
public boolean isStreamEnabled() {
return isStreamEnabled;
}

public boolean isConnectorSslVerificationEnabled() {
return isConnectorSslVerificationEnabled;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,9 @@ public class MLHttpClientFactoryTests {

@Test
public void test_getSdkAsyncHttpClient_success() {
SdkAsyncHttpClient client = MLHttpClientFactory.getAsyncHttpClient(Duration.ofSeconds(100), Duration.ofSeconds(100), 100, false);
SdkAsyncHttpClient client = MLHttpClientFactory.getAsyncHttpClient(Duration.ofSeconds(100), Duration.ofSeconds(100), 100, false, true);
assertNotNull(client);
client = MLHttpClientFactory.getAsyncHttpClient(Duration.ofSeconds(100), Duration.ofSeconds(100), 100, false, false);
assertNotNull(client);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -107,4 +107,9 @@ public void testAgenticMemoryDisabledMessage() {
public void testStreamDisabledByDefault() {
assertFalse(MLCommonsSettings.ML_COMMONS_STREAM_ENABLED.getDefault(null));
}

@Test
public void testConnectorSSLVerificationByDefault() {
assertTrue(MLCommonsSettings.ML_COMMONS_CONNECTOR_SSL_VERIFICATION_ENABLED.getDefault(null));
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,8 @@ public void setUp() {
MLCommonsSettings.ML_COMMONS_MCP_CONNECTOR_ENABLED,
MLCommonsSettings.ML_COMMONS_AGENTIC_MEMORY_ENABLED,
MLCommonsSettings.ML_COMMONS_INDEX_INSIGHT_FEATURE_ENABLED,
MLCommonsSettings.ML_COMMONS_STREAM_ENABLED
MLCommonsSettings.ML_COMMONS_STREAM_ENABLED,
MLCommonsSettings.ML_COMMONS_CONNECTOR_SSL_VERIFICATION_ENABLED
)
);
when(mockClusterService.getClusterSettings()).thenReturn(mockClusterSettings);
Expand All @@ -74,6 +75,7 @@ public void testDefaults_allFeaturesEnabled() {
.put("plugins.ml_commons.agentic_search_enabled", true)
.put("plugins.ml_commons.agentic_memory_enabled", true)
.put("plugins.ml_commons.stream_enabled", true)
.put("plugins.ml_commons.connector.ssl_verification_enabled", true)
.build();

MLFeatureEnabledSetting setting = new MLFeatureEnabledSetting(mockClusterService, settings);
Expand All @@ -93,6 +95,7 @@ public void testDefaults_allFeaturesEnabled() {
assertTrue(setting.isMcpConnectorEnabled());
assertTrue(setting.isAgenticMemoryEnabled());
assertTrue(setting.isStreamEnabled());
assertTrue(setting.isConnectorSslVerificationEnabled());
}

@Test
Expand All @@ -115,6 +118,7 @@ public void testDefaults_someFeaturesDisabled() {
.put("plugins.ml_commons.agentic_search_enabled", false)
.put("plugins.ml_commons.agentic_memory_enabled", false)
.put("plugins.ml_commons.stream_enabled", false)
.put("plugins.ml_commons.connector.ssl_verification_enabled", false)
.build();

MLFeatureEnabledSetting setting = new MLFeatureEnabledSetting(mockClusterService, settings);
Expand All @@ -134,6 +138,7 @@ public void testDefaults_someFeaturesDisabled() {
assertFalse(setting.isMcpConnectorEnabled());
assertFalse(setting.isAgenticMemoryEnabled());
assertFalse(setting.isStreamEnabled());
assertFalse(setting.isConnectorSslVerificationEnabled());
}

@Test
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,9 @@ public class AwsConnectorExecutor extends AbstractConnectorExecutor {
@Setter
private boolean connectorPrivateIpEnabled;

@Setter
private boolean connectorSslVerificationEnabled;

public AwsConnectorExecutor(Connector connector) {
super.initialize(connector);
this.connector = (AwsConnector) connector;
Expand Down Expand Up @@ -193,7 +196,7 @@ protected SdkAsyncHttpClient getHttpClient() {
this.httpClientRef
.compareAndSet(
null,
MLHttpClientFactory.getAsyncHttpClient(connectionTimeout, readTimeout, maxConnection, connectorPrivateIpEnabled)
MLHttpClientFactory.getAsyncHttpClient(connectionTimeout, readTimeout, maxConnection, connectorPrivateIpEnabled, connectorSslVerificationEnabled)
);
}
return httpClientRef.get();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,8 @@ public class HttpJsonConnectorExecutor extends AbstractConnectorExecutor {
private MLGuard mlGuard;
@Setter
private volatile boolean connectorPrivateIpEnabled;
@Setter
private boolean connectorSslVerificationEnabled;

private final AtomicReference<SdkAsyncHttpClient> httpClientRef = new AtomicReference<>();

Expand Down Expand Up @@ -183,7 +185,7 @@ protected SdkAsyncHttpClient getHttpClient() {
this.httpClientRef
.compareAndSet(
null,
MLHttpClientFactory.getAsyncHttpClient(connectionTimeout, readTimeout, maxConnection, connectorPrivateIpEnabled)
MLHttpClientFactory.getAsyncHttpClient(connectionTimeout, readTimeout, maxConnection, connectorPrivateIpEnabled, connectorSslVerificationEnabled)
);
}
return httpClientRef.get();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -184,6 +184,8 @@ default void setClient(Client client) {}

default void setConnectorPrivateIpEnabled(boolean connectorPrivateIpEnabled) {}

default void setConnectorSslVerificationEnabled(boolean connectorSslVerificationEnabled) {}

default void setXContentRegistry(NamedXContentRegistry xContentRegistry) {}

default void setClusterService(ClusterService clusterService) {}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,7 @@ public class RemoteModel implements Predictable {
public static final String USER_RATE_LIMITER_MAP = "user_rate_limiter_map";
public static final String GUARDRAILS = "guardrails";
public static final String CONNECTOR_PRIVATE_IP_ENABLED = "connectorPrivateIpEnabled";
public static final String CONNECTOR_SSL_VERIFICATION_ENABLED = "connectorSslVerificationEnabled";
public static final String SDK_CLIENT = "sdk_client";
public static final String SETTINGS = "settings";

Expand Down Expand Up @@ -127,6 +128,7 @@ public CompletionStage<Boolean> initModelAsync(MLModel model, Map<String, Object
this.connectorExecutor.setUserRateLimiterMap((Map<String, TokenBucket>) params.get(USER_RATE_LIMITER_MAP));
this.connectorExecutor.setMlGuard((MLGuard) params.get(GUARDRAILS));
this.connectorExecutor.setConnectorPrivateIpEnabled((boolean) params.getOrDefault(CONNECTOR_PRIVATE_IP_ENABLED, false));
this.connectorExecutor.setConnectorSslVerificationEnabled((boolean) params.getOrDefault(CONNECTOR_SSL_VERIFICATION_ENABLED, true));
return CompletableFuture.completedStage(true);
}).exceptionally(e -> {
log.error("Failed to init remote model.", e);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,7 @@ protected void doExecute(Task task, ActionRequest request, ActionListener<MLTask
RemoteConnectorExecutor connectorExecutor = MLEngineClassLoader
.initInstance(connector.getProtocol(), connector, Connector.class);
connectorExecutor.setConnectorPrivateIpEnabled(mlFeatureEnabledSetting.isConnectorPrivateIpEnabled());
connectorExecutor.setConnectorSslVerificationEnabled(mlFeatureEnabledSetting.isConnectorSslVerificationEnabled());
connectorExecutor.setScriptService(scriptService);
connectorExecutor.setClusterService(clusterService);
connectorExecutor.setClient(client);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@
import static org.opensearch.ml.engine.algorithms.remote.RemoteModel.SETTINGS;
import static org.opensearch.ml.engine.algorithms.remote.RemoteModel.USER_RATE_LIMITER_MAP;
import static org.opensearch.ml.engine.algorithms.remote.RemoteModel.XCONTENT_REGISTRY;
import static org.opensearch.ml.engine.algorithms.remote.RemoteModel.CONNECTOR_SSL_VERIFICATION_ENABLED;
import static org.opensearch.ml.engine.algorithms.text_embedding.TextEmbeddingDenseModel.ML_ENGINE;
import static org.opensearch.ml.engine.algorithms.text_embedding.TextEmbeddingDenseModel.MODEL_HELPER;
import static org.opensearch.ml.engine.algorithms.text_embedding.TextEmbeddingDenseModel.MODEL_ZIP_FILE;
Expand Down Expand Up @@ -1509,6 +1510,7 @@ private Map<String, Object> setUpParameterMap(String modelId, String tenantId) {
log.info("Setting up ML guard parameter for ML predictor.");
}
params.put(CONNECTOR_PRIVATE_IP_ENABLED, mlFeatureEnabledSetting.isConnectorPrivateIpEnabled());
params.put(CONNECTOR_SSL_VERIFICATION_ENABLED, mlFeatureEnabledSetting.isConnectorSslVerificationEnabled());
params.put(SDK_CLIENT, sdkClient);
params.put(SETTINGS, settings);
return Collections.unmodifiableMap(params);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1364,7 +1364,8 @@ public List<Setting<?>> getSettings() {
MLCommonsSettings.ML_COMMONS_INDEX_INSIGHT_FEATURE_ENABLED,
MLCommonsSettings.REMOTE_METADATA_GLOBAL_TENANT_ID,
MLCommonsSettings.REMOTE_METADATA_GLOBAL_RESOURCE_CACHE_TTL,
MLCommonsSettings.ML_COMMONS_STREAM_ENABLED
MLCommonsSettings.ML_COMMONS_STREAM_ENABLED,
MLCommonsSettings.ML_COMMONS_CONNECTOR_SSL_VERIFICATION_ENABLED
);
return settings;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
import static org.opensearch.ml.common.settings.MLCommonsSettings.ML_COMMONS_REMOTE_INFERENCE_ENABLED;
import static org.opensearch.ml.common.settings.MLCommonsSettings.ML_COMMONS_STATIC_METRIC_COLLECTION_ENABLED;
import static org.opensearch.ml.common.settings.MLCommonsSettings.ML_COMMONS_STREAM_ENABLED;
import static org.opensearch.ml.common.settings.MLCommonsSettings.ML_COMMONS_CONNECTOR_SSL_VERIFICATION_ENABLED;

import java.util.Set;

Expand Down Expand Up @@ -80,7 +81,8 @@ public void setUp() {
ML_COMMONS_AGENTIC_MEMORY_ENABLED,
ML_COMMONS_MCP_CONNECTOR_ENABLED,
ML_COMMONS_INDEX_INSIGHT_FEATURE_ENABLED,
ML_COMMONS_STREAM_ENABLED
ML_COMMONS_STREAM_ENABLED,
ML_COMMONS_CONNECTOR_SSL_VERIFICATION_ENABLED
)
)
);
Expand Down
Loading