Skip to content

Commit 4d5b90d

Browse files
authored
Merge branch 'feature/agentic_memory_integration' into ag-ui-support
Signed-off-by: Jiaping Zeng <[email protected]>
2 parents d0d8078 + 14b2206 commit 4d5b90d

File tree

63 files changed

+4852
-571
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

63 files changed

+4852
-571
lines changed

common/src/main/java/org/opensearch/ml/common/CommonValue.java

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -113,6 +113,7 @@ public class CommonValue {
113113
public static final String CLIENT_CONFIG_FIELD = "client_config";
114114
public static final String URL_FIELD = "url";
115115
public static final String HEADERS_FIELD = "headers";
116+
public static final String CONNECTOR_ACTION_FIELD = "connector_action";
116117

117118
// MCP Constants
118119
public static final String MCP_TOOL_NAME_FIELD = "name";

common/src/main/java/org/opensearch/ml/common/MLAgentType.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@ public static MLAgentType from(String value) {
2121
try {
2222
return MLAgentType.valueOf(value.toUpperCase(Locale.ROOT));
2323
} catch (Exception e) {
24-
throw new IllegalArgumentException("Wrong Agent type");
24+
throw new IllegalArgumentException(value + " is not a valid Agent Type");
2525
}
2626
}
2727
}
Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,24 @@
1+
/*
2+
* Copyright OpenSearch Contributors
3+
* SPDX-License-Identifier: Apache-2.0
4+
*/
5+
6+
package org.opensearch.ml.common;
7+
8+
import java.util.Locale;
9+
10+
public enum MLMemoryType {
11+
CONVERSATION_INDEX,
12+
AGENTIC_MEMORY;
13+
14+
public static MLMemoryType from(String value) {
15+
if (value != null) {
16+
try {
17+
return MLMemoryType.valueOf(value.toUpperCase(Locale.ROOT));
18+
} catch (Exception e) {
19+
throw new IllegalArgumentException("Wrong Memory type");
20+
}
21+
}
22+
return null;
23+
}
24+
}

common/src/main/java/org/opensearch/ml/common/agent/MLAgent.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -113,7 +113,7 @@ private void validate() {
113113
String.format("Agent name cannot be empty or exceed max length of %d characters", MLAgent.AGENT_NAME_MAX_LENGTH)
114114
);
115115
}
116-
validateMLAgentType(type);
116+
MLAgentType.from(type);
117117
if (type.equalsIgnoreCase(MLAgentType.CONVERSATIONAL.toString()) && llm == null) {
118118
throw new IllegalArgumentException("We need model information for the conversational agent type");
119119
}

common/src/main/java/org/opensearch/ml/common/agent/MLMemorySpec.java

Lines changed: 14 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -26,32 +26,37 @@ public class MLMemorySpec implements ToXContentObject {
2626
public static final String MEMORY_TYPE_FIELD = "type";
2727
public static final String WINDOW_SIZE_FIELD = "window_size";
2828
public static final String SESSION_ID_FIELD = "session_id";
29+
public static final String MEMORY_CONTAINER_ID_FIELD = "memory_container_id";
2930

3031
private String type;
3132
@Setter
3233
private String sessionId;
3334
private Integer windowSize;
35+
private String memoryContainerId;
3436

3537
@Builder(toBuilder = true)
36-
public MLMemorySpec(String type, String sessionId, Integer windowSize) {
38+
public MLMemorySpec(String type, String sessionId, Integer windowSize, String memoryContainerId) {
3739
if (type == null) {
3840
throw new IllegalArgumentException("agent name is null");
3941
}
4042
this.type = type;
4143
this.sessionId = sessionId;
4244
this.windowSize = windowSize;
45+
this.memoryContainerId = memoryContainerId;
4346
}
4447

4548
public MLMemorySpec(StreamInput input) throws IOException {
4649
type = input.readString();
4750
sessionId = input.readOptionalString();
4851
windowSize = input.readOptionalInt();
52+
memoryContainerId = input.readOptionalString();
4953
}
5054

5155
public void writeTo(StreamOutput out) throws IOException {
5256
out.writeString(type);
5357
out.writeOptionalString(sessionId);
5458
out.writeOptionalInt(windowSize);
59+
out.writeOptionalString(memoryContainerId);
5560
}
5661

5762
@Override
@@ -64,6 +69,9 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws
6469
if (sessionId != null) {
6570
builder.field(SESSION_ID_FIELD, sessionId);
6671
}
72+
if (memoryContainerId != null) {
73+
builder.field(MEMORY_CONTAINER_ID_FIELD, memoryContainerId);
74+
}
6775
builder.endObject();
6876
return builder;
6977
}
@@ -72,6 +80,7 @@ public static MLMemorySpec parse(XContentParser parser) throws IOException {
7280
String type = null;
7381
String sessionId = null;
7482
Integer windowSize = null;
83+
String memoryContainerId = null;
7584

7685
ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.currentToken(), parser);
7786
while (parser.nextToken() != XContentParser.Token.END_OBJECT) {
@@ -88,12 +97,15 @@ public static MLMemorySpec parse(XContentParser parser) throws IOException {
8897
case WINDOW_SIZE_FIELD:
8998
windowSize = parser.intValue();
9099
break;
100+
case MEMORY_CONTAINER_ID_FIELD:
101+
memoryContainerId = parser.text();
102+
break;
91103
default:
92104
parser.skipChildren();
93105
break;
94106
}
95107
}
96-
return MLMemorySpec.builder().type(type).sessionId(sessionId).windowSize(windowSize).build();
108+
return MLMemorySpec.builder().type(type).sessionId(sessionId).windowSize(windowSize).memoryContainerId(memoryContainerId).build();
97109
}
98110

99111
public static MLMemorySpec fromStream(StreamInput in) throws IOException {

common/src/main/java/org/opensearch/ml/common/connector/AbstractConnector.java

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -121,8 +121,11 @@ public <T> void parseResponse(T response, List<ModelTensor> modelTensors, boolea
121121

122122
@Override
123123
public Optional<ConnectorAction> findAction(String action) {
124-
if (actions != null) {
125-
return actions.stream().filter(a -> a.getActionType().name().equalsIgnoreCase(action)).findFirst();
124+
if (actions != null && action != null) {
125+
if (ConnectorAction.ActionType.isValidAction(action)) {
126+
return actions.stream().filter(a -> a.getActionType().name().equalsIgnoreCase(action)).findFirst();
127+
}
128+
return actions.stream().filter(a -> action.equals(a.getName())).findFirst();
126129
}
127130
return Optional.empty();
128131
}

common/src/main/java/org/opensearch/ml/common/connector/ConnectorAction.java

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken;
99

1010
import java.io.IOException;
11+
import java.util.Arrays;
1112
import java.util.HashSet;
1213
import java.util.List;
1314
import java.util.Locale;
@@ -33,6 +34,7 @@
3334
public class ConnectorAction implements ToXContentObject, Writeable {
3435

3536
public static final String ACTION_TYPE_FIELD = "action_type";
37+
public static final String NAME_FIELD = "name";
3638
public static final String METHOD_FIELD = "method";
3739
public static final String URL_FIELD = "url";
3840
public static final String HEADERS_FIELD = "headers";
@@ -52,6 +54,7 @@ public class ConnectorAction implements ToXContentObject, Writeable {
5254
private static final Logger logger = LogManager.getLogger(ConnectorAction.class);
5355

5456
private ActionType actionType;
57+
private String name;
5558
private String method;
5659
private String url;
5760
private Map<String, String> headers;
@@ -62,6 +65,7 @@ public class ConnectorAction implements ToXContentObject, Writeable {
6265
@Builder(toBuilder = true)
6366
public ConnectorAction(
6467
ActionType actionType,
68+
String name,
6569
String method,
6670
String url,
6771
Map<String, String> headers,
@@ -78,7 +82,11 @@ public ConnectorAction(
7882
if (method == null) {
7983
throw new IllegalArgumentException("method can't be null");
8084
}
85+
if (name != null && ActionType.isValidAction(name)) {
86+
throw new IllegalArgumentException("name can't be one of action type " + Arrays.toString(ActionType.values()));
87+
}
8188
this.actionType = actionType;
89+
this.name = name;
8290
this.method = method;
8391
this.url = url;
8492
this.headers = headers;
@@ -97,6 +105,7 @@ public ConnectorAction(StreamInput input) throws IOException {
97105
this.requestBody = input.readOptionalString();
98106
this.preProcessFunction = input.readOptionalString();
99107
this.postProcessFunction = input.readOptionalString();
108+
this.name = input.readOptionalString();// TODO: add version check
100109
}
101110

102111
@Override
@@ -113,6 +122,7 @@ public void writeTo(StreamOutput out) throws IOException {
113122
out.writeOptionalString(requestBody);
114123
out.writeOptionalString(preProcessFunction);
115124
out.writeOptionalString(postProcessFunction);
125+
out.writeOptionalString(name); // TODO: add version check
116126
}
117127

118128
@Override
@@ -139,6 +149,9 @@ public XContentBuilder toXContent(XContentBuilder xContentBuilder, Params params
139149
if (postProcessFunction != null) {
140150
builder.field(ACTION_POST_PROCESS_FUNCTION, postProcessFunction);
141151
}
152+
if (name != null) {
153+
builder.field(NAME_FIELD, name);
154+
}
142155
return builder.endObject();
143156
}
144157

@@ -149,6 +162,7 @@ public static ConnectorAction fromStream(StreamInput in) throws IOException {
149162

150163
public static ConnectorAction parse(XContentParser parser) throws IOException {
151164
ActionType actionType = null;
165+
String name = null;
152166
String method = null;
153167
String url = null;
154168
Map<String, String> headers = null;
@@ -165,6 +179,9 @@ public static ConnectorAction parse(XContentParser parser) throws IOException {
165179
case ACTION_TYPE_FIELD:
166180
actionType = ActionType.valueOf(parser.text().toUpperCase(Locale.ROOT));
167181
break;
182+
case NAME_FIELD:
183+
name = parser.text();
184+
break;
168185
case METHOD_FIELD:
169186
method = parser.text();
170187
break;
@@ -191,6 +208,7 @@ public static ConnectorAction parse(XContentParser parser) throws IOException {
191208
return ConnectorAction
192209
.builder()
193210
.actionType(actionType)
211+
.name(name)
194212
.method(method)
195213
.url(url)
196214
.headers(headers)

common/src/main/java/org/opensearch/ml/common/connector/HttpConnector.java

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
import static org.opensearch.ml.common.connector.ConnectorProtocols.validateProtocol;
1414
import static org.opensearch.ml.common.utils.StringUtils.getParameterMap;
1515
import static org.opensearch.ml.common.utils.StringUtils.isJson;
16+
import static org.opensearch.ml.common.utils.StringUtils.isJsonOrNdjson;
1617
import static org.opensearch.ml.common.utils.StringUtils.parseParameters;
1718

1819
import java.io.IOException;
@@ -358,12 +359,15 @@ public <T> T createPayload(String action, Map<String, String> parameters) {
358359
StringSubstitutor substitutor = new StringSubstitutor(parameters, "${parameters.", "}");
359360
payload = substitutor.replace(payload);
360361

361-
if (!isJson(payload)) {
362+
if (!isJsonOrNdjson(payload)) {
362363
throw new IllegalArgumentException("Invalid payload: " + payload);
363364
} else if (neededStreamParameterInPayload(parameters)) {
364-
JsonObject jsonObject = JsonParser.parseString(payload).getAsJsonObject();
365-
jsonObject.addProperty("stream", true);
366-
payload = jsonObject.toString();
365+
// Only add stream parameter for single JSON objects (not NDJSON)
366+
if (isJson(payload)) {
367+
JsonObject jsonObject = JsonParser.parseString(payload).getAsJsonObject();
368+
jsonObject.addProperty("stream", true);
369+
payload = jsonObject.toString();
370+
}
367371
}
368372
return (T) payload;
369373
}

common/src/main/java/org/opensearch/ml/common/conversation/Interaction.java

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@
2828
import org.opensearch.core.xcontent.ToXContentObject;
2929
import org.opensearch.core.xcontent.XContentBuilder;
3030
import org.opensearch.ml.common.CommonValue;
31+
import org.opensearch.ml.common.memory.Message;
3132
import org.opensearch.search.SearchHit;
3233

3334
import lombok.AllArgsConstructor;
@@ -39,7 +40,7 @@
3940
*/
4041
@Builder
4142
@AllArgsConstructor
42-
public class Interaction implements Writeable, ToXContentObject {
43+
public class Interaction implements Writeable, ToXContentObject, Message {
4344

4445
@Getter
4546
private String id;
@@ -275,4 +276,13 @@ public String toString() {
275276
+ "}";
276277
}
277278

279+
@Override
280+
public String getType() {
281+
return "";
282+
}
283+
284+
@Override
285+
public String getContent() {
286+
return "";
287+
}
278288
}
Lines changed: 57 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,57 @@
1+
/*
2+
* Copyright OpenSearch Contributors
3+
* SPDX-License-Identifier: Apache-2.0
4+
*/
5+
6+
package org.opensearch.ml.common.memory;
7+
8+
import java.util.List;
9+
import java.util.Map;
10+
11+
import org.opensearch.core.action.ActionListener;
12+
13+
/**
14+
* A general memory interface.
15+
* @param <T> Message type
16+
* @param <R> Save response type
17+
* @param <S> Update response type
18+
*/
19+
public interface Memory<T extends Message, R, S> {
20+
21+
/**
22+
* Get memory type.
23+
* @return memory type
24+
*/
25+
String getType();
26+
27+
/**
28+
* Get memory ID.
29+
* @return memory ID
30+
*/
31+
String getId();
32+
33+
default void save(Message message, String parentId, Integer traceNum, String action) {}
34+
35+
default void save(Message message, String parentId, Integer traceNum, String action, ActionListener<R> listener) {}
36+
37+
default void update(String messageId, Map<String, Object> updateContent, ActionListener<S> updateListener) {}
38+
39+
default void getMessages(int size, ActionListener<List<T>> listener) {}
40+
41+
/**
42+
* Clear all memory.
43+
*/
44+
void clear();
45+
46+
void deleteInteractionAndTrace(String regenerateInteractionId, ActionListener<Boolean> wrap);
47+
48+
interface Factory<M extends Memory> {
49+
/**
50+
* Create an instance of this Memory.
51+
*
52+
* @param params Parameters for the memory
53+
* @param listener Action listener for the memory creation action
54+
*/
55+
void create(Map<String, Object> params, ActionListener<M> listener);
56+
}
57+
}

0 commit comments

Comments
 (0)