Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -466,7 +466,7 @@ public void createConnector() {
.name("test")
.description("description")
.version("testModelVersion")
.protocol("testProtocol")
.protocol("http")
.parameters(params)
.credential(credentials)
.actions(null)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1033,7 +1033,7 @@ public void createConnector() {
.name("test")
.description("description")
.version("testModelVersion")
.protocol("testProtocol")
.protocol("http")
.parameters(params)
.credential(credentials)
.actions(null)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -163,25 +163,28 @@ public static ConnectorAction parse(XContentParser parser) throws IOException {

switch (fieldName) {
case ACTION_TYPE_FIELD:
actionType = ActionType.valueOf(parser.text().toUpperCase(Locale.ROOT));
String actionTypeText = parser.textOrNull();
if (actionTypeText != null) {
actionType = ActionType.valueOf(actionTypeText.toUpperCase(Locale.ROOT));
}
break;
case METHOD_FIELD:
method = parser.text();
method = parser.textOrNull();
break;
case URL_FIELD:
url = parser.text();
url = parser.textOrNull();
break;
case HEADERS_FIELD:
headers = parser.mapStrings();
break;
case REQUEST_BODY_FIELD:
requestBody = parser.text();
requestBody = parser.textOrNull();
break;
case ACTION_PRE_PROCESS_FUNCTION:
preProcessFunction = parser.text();
preProcessFunction = parser.textOrNull();
break;
case ACTION_POST_PROCESS_FUNCTION:
postProcessFunction = parser.text();
postProcessFunction = parser.textOrNull();
break;
default:
parser.skipChildren();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
import static org.opensearch.ml.common.CommonValue.VERSION_3_0_0;
import static org.opensearch.ml.common.connector.ConnectorProtocols.MCP_SSE;
import static org.opensearch.ml.common.connector.ConnectorProtocols.MCP_STREAMABLE_HTTP;
import static org.opensearch.ml.common.connector.ConnectorProtocols.validateProtocol;
import static org.opensearch.ml.common.utils.StringUtils.getParameterMap;

import java.io.IOException;
Expand Down Expand Up @@ -106,9 +107,7 @@ public MLCreateConnectorInput(
if (version == null) {
throw new IllegalArgumentException("Connector version is null");
}
if (protocol == null) {
throw new IllegalArgumentException("Connector protocol is null");
}
validateProtocol(protocol);
boolean isMcpConnector = (protocol.equals(MCP_SSE) || protocol.equals(MCP_STREAMABLE_HTTP));
if ((credential == null || credential.isEmpty()) && !isMcpConnector) {
throw new IllegalArgumentException("Connector credential is null or empty list");
Expand Down Expand Up @@ -183,7 +182,19 @@ public static MLCreateConnectorInput parse(XContentParser parser, boolean update
parameters = getParameterMap(parser.map());
break;
case CONNECTOR_CREDENTIAL_FIELD:
credential = parser.mapStrings();
ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.currentToken(), parser);
credential = new HashMap<>();
while (parser.nextToken() != XContentParser.Token.END_OBJECT) {
String key = parser.currentName();
parser.nextToken();
if (parser.currentToken() != XContentParser.Token.VALUE_STRING
&& parser.currentToken() != XContentParser.Token.VALUE_NULL) {
throw new IllegalArgumentException(
"Credential values must be strings, found invalid type: " + parser.currentToken()
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Better to avoid putting credential token in the exception string ?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

parser.currentToken does not print the value itself but rather the type of object.

);
}
credential.put(key, parser.textOrNull());
}
break;
case CONNECTOR_ACTIONS_FIELD:
actions = new ArrayList<>();
Expand All @@ -196,7 +207,14 @@ public static MLCreateConnectorInput parse(XContentParser parser, boolean update
ensureExpectedToken(XContentParser.Token.START_ARRAY, parser.currentToken(), parser);
backendRoles = new ArrayList<>();
while (parser.nextToken() != XContentParser.Token.END_ARRAY) {
backendRoles.add(parser.text());
if (parser.currentToken() != XContentParser.Token.VALUE_STRING
&& parser.currentToken() != XContentParser.Token.VALUE_NUMBER
&& parser.currentToken() != XContentParser.Token.VALUE_NULL) {
throw new IllegalArgumentException(
"Backend roles must contain only string values, found invalid type: " + parser.currentToken()
);
}
backendRoles.add(parser.textOrNull());
}
break;
case ADD_ALL_BACKEND_ROLES_FIELD:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -166,7 +166,10 @@ public void constructorMLCreateConnectorInput_NullProtocol() {
.addAllBackendRoles(false)
.build();
});
assertEquals("Connector protocol is null", exception.getMessage());
assertEquals(
"Connector protocol is null. Please use one of [aws_sigv4, http, mcp_sse, mcp_streamable_http]",
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Better to use ConnectorProtocols.VALID_PROTOCOLS here also, Otherwise we need to keep on changing this exception string here whenever we add new protocols.
Or use contains() API and remove the protocol list part.

exception.getMessage()
);
}

@Test
Expand Down Expand Up @@ -562,4 +565,54 @@ private void readInputStream(MLCreateConnectorInput input, Consumer<MLCreateConn
verify.accept(parsedInput);
}

@Test
public void testParse_BackendRolesWithJsonObject_ShouldThrowException() throws IOException {
// Test that backend_roles array containing a JSON object throws IllegalArgumentException
String jsonWithObjectInBackendRoles = """
{
"name": "test_connector_name",
"credential": {"key": "test_key_value"},
"version": "1",
"protocol": "http",
"backend_roles": [
{
"role_name": "admin",
"permissions": "all"
}
]
}
""";

XContentParser parser = createParser(jsonWithObjectInBackendRoles);

Throwable exception = assertThrows(IllegalArgumentException.class, () -> { MLCreateConnectorInput.parse(parser); });

assertTrue(exception.getMessage().contains("Backend roles must contain only string values"));
assertTrue(exception.getMessage().contains("START_OBJECT"));
}

@Test
public void testParse_CredentialWithJsonObject_ShouldThrowException() throws IOException {
// Test that credential values containing JSON objects throw IllegalArgumentException
String jsonWithObjectInCredential = """
{
"name": "test_connector_name",
"credential": {
"key": {
"nested": "object"
}
},
"version": "1",
"protocol": "http"
}
""";

XContentParser parser = createParser(jsonWithObjectInCredential);

Throwable exception = assertThrows(IllegalArgumentException.class, () -> { MLCreateConnectorInput.parse(parser); });

assertTrue(exception.getMessage().contains("Credential values must be strings"));
assertTrue(exception.getMessage().contains("START_OBJECT"));
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -354,7 +354,7 @@ public void getTraces(String interactionId, int from, int maxResults, ActionList
ActionListener<GetResponse> al = ActionListener.wrap(getResponse -> {
// If the interaction doesn't exist, fail
if (!(getResponse.isExists() && getResponse.getId().equals(interactionId))) {
throw new ResourceNotFoundException("Message [" + interactionId + "] not found");
throw new ResourceNotFoundException("Message ID not found");
}
Interaction interaction = Interaction.fromMap(interactionId, getResponse.getSourceAsMap());
// checks if the user has permission to access the conversation that the interaction belongs to
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -185,8 +185,8 @@ public void execute(Input input, ActionListener<Output> listener, TransportChann
log.error("Failed to get Agent index", cause);
listener.onFailure(new OpenSearchStatusException("Failed to get agent index", RestStatus.NOT_FOUND));
} else {
log.error("Failed to get ML Agent {}", agentId, cause);
listener.onFailure(cause);
log.error("Failed to get ML Agent", cause);
listener.onFailure(new OpenSearchStatusException("Failed to get agent", RestStatus.NOT_FOUND));
}
} else {
try {
Expand Down Expand Up @@ -351,7 +351,7 @@ public void execute(Input input, ActionListener<Output> listener, TransportChann
listener
.onFailure(
new OpenSearchStatusException(
"Failed to find agent with the provided agent id: " + agentId,
"Failed to find agent with the provided agent id",
RestStatus.NOT_FOUND
)
);
Expand Down
Loading