Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -202,4 +202,26 @@ public void test_ImmutableEmptyParametersMap() {
Output output = outputCaptor.getValue();
Assert.assertTrue(output instanceof ModelTensorOutput);
}

@Test
public void test_ToolExecutionFailsWithoutProperPermission() {
when(toolMLInput.getToolName()).thenReturn("TestTool");
when(toolMLInput.getInputDataset()).thenReturn(inputDataSet);
when(inputDataSet.getParameters()).thenReturn(parameters);
when(toolFactory.create(any())).thenReturn(tool);
when(tool.validate(parameters)).thenReturn(true);

Mockito.doAnswer(invocation -> {
ActionListener<Object> listener = invocation.getArgument(1);
listener.onFailure(new SecurityException("no permissions for [indices:data/read/search] and User [name=test_user]"));
return null;
}).when(tool).run(Mockito.eq(parameters), any());

mlToolExecutor.execute(toolMLInput, actionListener);

Mockito.verify(actionListener).onFailure(exceptionCaptor.capture());
Exception exception = exceptionCaptor.getValue();
Assert.assertTrue(exception instanceof SecurityException);
Assert.assertTrue(exception.getMessage().contains("no permissions"));
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -217,6 +217,21 @@ public void testRun_StringConversion_AddPersistentNote() {
assertEquals("[\"existing\",\"new note\"]", parameters.get(ReadFromScratchPadTool.SCRATCHPAD_NOTES_KEY));
}

@Test
public void testRun_SecurityException() {
Map<String, String> parameters = new HashMap<>();
parameters.put(ReadFromScratchPadTool.SCRATCHPAD_NOTES_KEY, "[\"confidential data\"]");

SecurityException securityException = new SecurityException("no permissions for [indices:data/read/get] and User [name=test_user]");
listener.onFailure(securityException);

ArgumentCaptor<Exception> exceptionCaptor = ArgumentCaptor.forClass(Exception.class);
verify(listener).onFailure(exceptionCaptor.capture());
Exception exception = exceptionCaptor.getValue();
assertTrue(exception instanceof SecurityException);
assertTrue(exception.getMessage().contains("no permissions"));
}

@Test
public void testFactory() {
ReadFromScratchPadTool.Factory factory = ReadFromScratchPadTool.Factory.getInstance();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -222,7 +222,23 @@ public void testRun_StringConversion_WithJsonArray() {
ArgumentCaptor<String> captor = ArgumentCaptor.forClass(String.class);
verify(listener).onResponse(captor.capture());
assertEquals("Wrote to scratchpad: new note", captor.getValue());
assertEquals("[\"existing note\",\"new note\"]", parameters.get(WriteToScratchPadTool.SCRATCHPAD_NOTES_KEY));
}

@Test
public void testRun_SecurityException() {
Map<String, String> parameters = new HashMap<>();
parameters.put(WriteToScratchPadTool.NOTES_KEY, "confidential test data");

SecurityException securityException = new SecurityException(
"no permissions for [indices:data/write/index] and User [name=test_user]"
);
listener.onFailure(securityException);

ArgumentCaptor<Exception> exceptionCaptor = ArgumentCaptor.forClass(Exception.class);
verify(listener).onFailure(exceptionCaptor.capture());
Exception exception = exceptionCaptor.getValue();
assertTrue(exception instanceof SecurityException);
assertTrue(exception.getMessage().contains("no permissions"));
}

@Test
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,9 +14,13 @@
import java.util.Objects;

import org.apache.commons.lang3.StringUtils;
import org.apache.hc.core5.http.HttpHost;
import org.junit.Before;
import org.opensearch.client.Response;
import org.opensearch.client.ResponseException;
import org.opensearch.client.RestClient;
import org.opensearch.common.settings.Settings;
import org.opensearch.commons.rest.SecureRestClientBuilder;
import org.opensearch.ml.engine.tools.ListIndexTool;
import org.opensearch.ml.rest.RestBaseAgentToolsIT;
import org.opensearch.ml.utils.TestHelper;
Expand All @@ -37,6 +41,44 @@ public void setUpCluster() throws Exception {
registerListIndexFlowAgent();
}

public void testListIndexWithNoPermissions() throws Exception {
if (!isHttps()) {
log.info("Skipping permission test as security is not enabled");
return;
}

String noPermissionUser = "no_permission_user";
String password = "TestPassword123!";

try {
createUser(noPermissionUser, password, new ArrayList<>());

final RestClient noPermissionClient = new SecureRestClientBuilder(
getClusterHosts().toArray(new HttpHost[0]),
isHttps(),
noPermissionUser,
password
).setSocketTimeout(60000).build();

try {
ResponseException exception = expectThrows(ResponseException.class, () -> {
TestHelper
.makeRequest(noPermissionClient, "POST", "/_plugins/_ml/agents/" + agentId + "/_execute", null, question, null);
});

String errorMessage = exception.getMessage().toLowerCase();
assertTrue(
"Expected permission error, got: " + errorMessage,
errorMessage.contains("no permissions") || errorMessage.contains("forbidden") || errorMessage.contains("unauthorized")
);
} finally {
noPermissionClient.close();
}
} finally {
deleteUser(noPermissionUser);
}
}

private List<String> createIndices(int count) throws IOException {
List<String> indices = new ArrayList<>();
for (int i = 0; i < count; i++) {
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/

package org.opensearch.ml.tools;

import static org.opensearch.ml.utils.TestHelper.makeRequest;

import org.junit.Before;
import org.opensearch.ml.rest.MLCommonsRestTestCase;
import org.opensearch.test.OpenSearchIntegTestCase;

@OpenSearchIntegTestCase.ClusterScope(scope = OpenSearchIntegTestCase.Scope.SUITE, numDataNodes = 3)
public class ScratchPadToolIT extends MLCommonsRestTestCase {

@Before
public void setUp() throws Exception {
super.setUp();
}

public void testScratchpadSizeLimit() throws Exception {
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

thanks for verifying the exception in UT.

String largeContent = "A".repeat(100 * 1024 * 1024);
String requestBody = String.format("{\"parameters\":{\"notes\":\"%s\"}}", largeContent);

Exception exception = expectThrows(Exception.class, () -> {
makeRequest(client(), "POST", "/_plugins/_ml/tools/_execute/WriteToScratchPadTool", null, requestBody, null);
});

String errorMessage = exception.getMessage().toLowerCase();
assertTrue(
"Expected HTTP content length error, got: " + errorMessage,
errorMessage.contains("content length")
|| errorMessage.contains("too large")
|| errorMessage.contains("entity too large")
|| errorMessage.contains("413")
);
}
}
Loading