Skip to content

Commit 2dea1ad

Browse files
committed
add unit test
Signed-off-by: Hailong Cui <[email protected]>
1 parent 161068b commit 2dea1ad

File tree

6 files changed

+179
-27
lines changed

6 files changed

+179
-27
lines changed

src/main/java/org/opensearch/agent/tools/CreateAlertTool.java

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,6 @@
1818
import java.util.regex.Pattern;
1919

2020
import org.apache.commons.text.StringSubstitutor;
21-
import org.apache.logging.log4j.util.Strings;
2221
import org.opensearch.action.ActionRequest;
2322
import org.opensearch.action.admin.indices.get.GetIndexRequest;
2423
import org.opensearch.action.support.IndicesOptions;
@@ -27,6 +26,7 @@
2726
import org.opensearch.client.Client;
2827
import org.opensearch.cluster.metadata.MappingMetadata;
2928
import org.opensearch.core.action.ActionListener;
29+
import org.opensearch.core.common.Strings;
3030
import org.opensearch.core.common.logging.LoggerMessageFormat;
3131
import org.opensearch.ml.common.FunctionName;
3232
import org.opensearch.ml.common.dataset.remote.RemoteInferenceInputDataSet;
@@ -275,7 +275,7 @@ public void init(Client client) {
275275
@Override
276276
public CreateAlertTool create(Map<String, Object> params) {
277277
String modelId = (String) params.get(MODEL_ID);
278-
if (Strings.isBlank(modelId)) {
278+
if (Strings.isNullOrEmpty(modelId)) {
279279
throw new IllegalArgumentException("model_id cannot be null or blank.");
280280
}
281281
String modelType = (String) params.getOrDefault("model_type", ModelType.CLAUDE.toString());

src/main/java/org/opensearch/agent/tools/CreateAnomalyDetectorTool.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@
3131
import org.opensearch.client.Client;
3232
import org.opensearch.cluster.metadata.MappingMetadata;
3333
import org.opensearch.core.action.ActionListener;
34+
import org.opensearch.core.common.Strings;
3435
import org.opensearch.ml.common.FunctionName;
3536
import org.opensearch.ml.common.dataset.remote.RemoteInferenceInputDataSet;
3637
import org.opensearch.ml.common.input.MLInput;
@@ -44,7 +45,6 @@
4445

4546
import com.google.common.collect.ImmutableMap;
4647

47-
import joptsimple.internal.Strings;
4848
import lombok.Getter;
4949
import lombok.Setter;
5050
import lombok.extern.log4j.Log4j2;

src/main/java/org/opensearch/agent/tools/PainlessTool.java

Lines changed: 26 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -10,15 +10,15 @@
1010
import java.util.Map;
1111

1212
import org.opensearch.core.action.ActionListener;
13+
import org.opensearch.core.common.Strings;
1314
import org.opensearch.ml.common.spi.tools.Tool;
1415
import org.opensearch.ml.common.spi.tools.ToolAnnotation;
16+
import org.opensearch.ml.common.utils.StringUtils;
1517
import org.opensearch.script.Script;
1618
import org.opensearch.script.ScriptService;
1719
import org.opensearch.script.ScriptType;
1820
import org.opensearch.script.TemplateScript;
1921

20-
import com.google.gson.Gson;
21-
2222
import lombok.Getter;
2323
import lombok.Setter;
2424
import lombok.extern.log4j.Log4j2;
@@ -49,40 +49,47 @@ public class PainlessTool implements Tool {
4949
private String version;
5050

5151
private ScriptService scriptService;
52-
@Setter
5352
private String scriptCode;
5453

5554
public PainlessTool(ScriptService scriptEngine, String script) {
5655
this.scriptService = scriptEngine;
5756
this.scriptCode = script;
5857
}
5958

60-
private Gson gson = new Gson();
61-
6259
@Override
6360
public <T> void run(Map<String, String> parameters, ActionListener<T> listener) {
6461
Script script = new Script(ScriptType.INLINE, "painless", scriptCode, Collections.emptyMap());
62+
Map<String, Object> flattenedParameters = getFlattenedParameters(parameters);
63+
TemplateScript templateScript = scriptService.compile(script, TemplateScript.CONTEXT).newInstance(flattenedParameters);
64+
try {
65+
String result = templateScript.execute();
66+
listener.onResponse(result == null ? (T) "" : (T) result);
67+
} catch (Exception e) {
68+
listener.onFailure(e);
69+
}
70+
}
71+
72+
@Override
73+
public boolean validate(Map<String, String> map) {
74+
return true;
75+
}
76+
77+
Map<String, Object> getFlattenedParameters(Map<String, String> parameters) {
6578
Map<String, Object> flattenedParameters = new HashMap<>();
6679
for (Map.Entry<String, String> entry : parameters.entrySet()) {
67-
// keep original values and flatten
80+
// keep both original values and flatten
6881
flattenedParameters.put(entry.getKey(), entry.getValue());
69-
// TODO default is json parser. we may support format
7082
try {
83+
// default is json parser, we may add more...
7184
String value = org.apache.commons.text.StringEscapeUtils.unescapeJson(entry.getValue());
72-
Map<String, ?> map = gson.fromJson(value, Map.class);
85+
Map<String, ?> map = StringUtils.fromJson(value, "");
7386
flattenMap(map, flattenedParameters, entry.getKey());
7487
} catch (Throwable ignored) {}
7588
}
76-
TemplateScript templateScript = scriptService.compile(script, TemplateScript.CONTEXT).newInstance(flattenedParameters);
77-
try {
78-
String result = templateScript.execute();
79-
listener.onResponse(result == null ? (T) "" : (T) result);
80-
} catch (Exception e) {
81-
listener.onFailure(e);
82-
}
89+
return flattenedParameters;
8390
}
8491

85-
private void flattenMap(Map<String, ?> map, Map<String, Object> flatMap, String prefix) {
92+
void flattenMap(Map<String, ?> map, Map<String, Object> flatMap, String prefix) {
8693
for (Map.Entry<String, ?> entry : map.entrySet()) {
8794
String key = entry.getKey();
8895
if (prefix != null && !prefix.isEmpty()) {
@@ -97,11 +104,6 @@ private void flattenMap(Map<String, ?> map, Map<String, Object> flatMap, String
97104
}
98105
}
99106

100-
@Override
101-
public boolean validate(Map<String, String> map) {
102-
return true;
103-
}
104-
105107
public static class Factory implements Tool.Factory<PainlessTool> {
106108
private ScriptService scriptService;
107109

@@ -127,7 +129,9 @@ public void init(ScriptService scriptService) {
127129
@Override
128130
public PainlessTool create(Map<String, Object> map) {
129131
String script = (String) map.get("script");
130-
// TODO add script non null/empty check
132+
if (Strings.isNullOrEmpty(script)) {
133+
throw new IllegalArgumentException("script is required");
134+
}
131135
return new PainlessTool(scriptService, script);
132136
}
133137

Lines changed: 147 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,147 @@
1+
/*
2+
* Copyright OpenSearch Contributors
3+
* SPDX-License-Identifier: Apache-2.0
4+
*/
5+
6+
package org.opensearch.agent.tools;
7+
8+
import static org.junit.Assert.assertEquals;
9+
import static org.mockito.ArgumentMatchers.any;
10+
import static org.mockito.Mockito.times;
11+
import static org.mockito.Mockito.verify;
12+
import static org.mockito.Mockito.when;
13+
14+
import java.util.HashMap;
15+
import java.util.Map;
16+
17+
import org.apache.commons.text.StringEscapeUtils;
18+
import org.junit.Before;
19+
import org.junit.Test;
20+
import org.mockito.ArgumentCaptor;
21+
import org.mockito.Mock;
22+
import org.mockito.MockitoAnnotations;
23+
import org.opensearch.core.action.ActionListener;
24+
import org.opensearch.script.ScriptService;
25+
import org.opensearch.script.TemplateScript;
26+
27+
import com.google.gson.Gson;
28+
29+
/**
30+
* this is a test file to test PainlessTool with junit
31+
*/
32+
public class PainlessToolTests {
33+
@Mock
34+
private ScriptService scriptService;
35+
@Mock
36+
private TemplateScript templateScript;
37+
@Mock
38+
private ActionListener<String> actionListener;
39+
40+
@Before
41+
public void setup() {
42+
MockitoAnnotations.openMocks(this);
43+
TemplateScript.Factory factory = new TemplateScript.Factory() {
44+
@Override
45+
public TemplateScript newInstance(Map<String, Object> params) {
46+
return templateScript;
47+
}
48+
};
49+
50+
when(scriptService.compile(any(), any())).thenReturn(factory);
51+
52+
PainlessTool.Factory.getInstance().init(scriptService);
53+
}
54+
55+
@Test
56+
public void testRun() {
57+
String script = "return 'Hello World';";
58+
PainlessTool tool = PainlessTool.Factory.getInstance().create(Map.of("script", script));
59+
when(templateScript.execute()).thenReturn("hello");
60+
tool.run(Map.of(), actionListener);
61+
62+
verify(templateScript).execute();
63+
verify(scriptService).compile(any(), any());
64+
ArgumentCaptor<String> responseCaptor = ArgumentCaptor.forClass(String.class);
65+
verify(actionListener, times(1)).onResponse(responseCaptor.capture());
66+
assertEquals("hello", responseCaptor.getValue());
67+
}
68+
69+
// test run wit exception
70+
@Test
71+
public void testRun_with_exception() {
72+
String script = "return 'Hello World';";
73+
PainlessTool tool = PainlessTool.Factory.getInstance().create(Map.of("script", script));
74+
when(templateScript.execute()).thenThrow(new RuntimeException("error"));
75+
tool.run(Map.of(), actionListener);
76+
77+
verify(templateScript).execute();
78+
verify(scriptService).compile(any(), any());
79+
ArgumentCaptor<Exception> exceptionCaptor = ArgumentCaptor.forClass(Exception.class);
80+
verify(actionListener, times(1)).onFailure(exceptionCaptor.capture());
81+
assertEquals("error", exceptionCaptor.getValue().getMessage());
82+
}
83+
84+
// test factory create
85+
@Test
86+
public void testFactory_create() {
87+
String script = "return 'Hello World';";
88+
PainlessTool tool = PainlessTool.Factory.getInstance().create(Map.of("script", script));
89+
assertEquals(PainlessTool.TYPE, tool.getType());
90+
assertEquals("PainlessTool", tool.getName());
91+
assertEquals("Use this tool to execute painless script", tool.getDescription());
92+
}
93+
94+
// test factory create with exception
95+
@Test(expected = IllegalArgumentException.class)
96+
public void testFactory_create_with_exception() {
97+
PainlessTool.Factory.getInstance().create(Map.of());
98+
}
99+
100+
// test flattenMap
101+
@Test
102+
public void testFlattenMap_without_prefix() {
103+
String script = "return 'Hello World';";
104+
PainlessTool tool = PainlessTool.Factory.getInstance().create(Map.of("script", script));
105+
Map<String, Object> map = Map.of("a", Map.of("b", "c"), "k", "v");
106+
Map<String, Object> resultMap = new HashMap<>();
107+
tool.flattenMap(map, resultMap, "");
108+
assertEquals(Map.of("a.b", "c", "k", "v"), resultMap);
109+
}
110+
111+
// with prefix
112+
@Test
113+
public void testFlattenMap_with_prefix() {
114+
String script = "return 'Hello World';";
115+
PainlessTool tool = PainlessTool.Factory.getInstance().create(Map.of("script", script));
116+
Map<String, Object> map = Map.of("a", Map.of("b", "c"), "k", "v");
117+
Map<String, Object> resultMap = new HashMap<>();
118+
tool.flattenMap(map, resultMap, "prefix");
119+
assertEquals(Map.of("prefix.a.b", "c", "prefix.k", "v"), resultMap);
120+
}
121+
122+
// nest map with depth 3
123+
@Test
124+
public void testFlattenMap_with_depth_3() {
125+
String script = "return 'Hello World';";
126+
PainlessTool tool = PainlessTool.Factory.getInstance().create(Map.of("script", script));
127+
Map<String, Object> map = Map.of("a", Map.of("b", Map.of("c", "d"), "k", "v"));
128+
Gson gson = new Gson();
129+
System.out.println(StringEscapeUtils.escapeJson(gson.toJson(map)));
130+
Map<String, Object> resultMap = new HashMap<>();
131+
tool.flattenMap(map, resultMap, "");
132+
assertEquals(Map.of("a.b.c", "d", "a.k", "v"), resultMap);
133+
}
134+
135+
// test getFlattenedParameters
136+
@Test
137+
public void testGetFlattenedParameters() {
138+
String script = "return 'Hello World';";
139+
PainlessTool tool = PainlessTool.Factory.getInstance().create(Map.of("script", script));
140+
Map<String, String> map = Map.of("k", "{\\\"a\\\":{\\\"k\\\":\\\"v\\\",\\\"b\\\":{\\\"c\\\":\\\"d\\\"}}}");
141+
Map<String, Object> resultMap = tool.getFlattenedParameters(map);
142+
assertEquals(
143+
Map.of("k.a.b.c", "d", "k.a.k", "v", "k", "{\\\"a\\\":{\\\"k\\\":\\\"v\\\",\\\"b\\\":{\\\"c\\\":\\\"d\\\"}}}"),
144+
resultMap
145+
);
146+
}
147+
}

src/test/java/org/opensearch/agent/tools/SearchAnomalyDetectorsToolTests.java

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -93,7 +93,8 @@ public void setup() {
9393
null,
9494
null,
9595
null,
96-
null
96+
null,
97+
true
9798
);
9899
}
99100

src/test/java/org/opensearch/integTest/PainlessToolIT.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -49,7 +49,7 @@ public void test_execute_with_parameter() {
4949
Assert.assertEquals("12", result);
5050
}
5151

52-
public void test_execute_with_parameter2() throws URISyntaxException, IOException {
52+
public void test_execute_with_parsing_input() throws URISyntaxException, IOException {
5353
String script =
5454
"return 'An example output: with ppl:<ppl>' + params.get('PPL.output.ppl') + '</ppl>, and this is ppl result: <ppl_result>' + params.get('PPL.output.executionResult') + '</ppl_result>'";
5555
String mockPPLOutput = "return '{\\\\\"executionResult\\\\\":\\\\\"result\\\\\",\\\\\"ppl\\\\\":\\\\\"source=demo| head 1\\\\\"}'";

0 commit comments

Comments
 (0)