Skip to content

Commit 540a52e

Browse files
committed
Fix: support Claude V3 output parsing in Generative QA Processor
Signed-off-by: Anuj Soni <[email protected]>
1 parent c04f537 commit 540a52e

File tree

2 files changed

+86
-2
lines changed

2 files changed

+86
-2
lines changed

search-processors/src/main/java/org/opensearch/searchpipelines/questionanswering/generative/llm/DefaultLlmImpl.java

Lines changed: 34 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -191,8 +191,40 @@ protected ChatCompletionOutput buildChatCompletionOutput(ModelProvider provider,
191 191
answers = List.of(message.get(CONNECTOR_OUTPUT_MESSAGE_CONTENT));
192 192
}
193 193
} else if (provider == ModelProvider.BEDROCK) {
194-
answerField = "completion";
195-
fillAnswersOrErrors(dataAsMap, answers, errors, answerField, errorField, defaultErrorMessageField);
194+
// Handle both Claude V2 and V3 response formats
195+
if (dataAsMap.containsKey("completion")) {
196+
// Old Claude V2 format
197+
answerField = "completion";
198+
fillAnswersOrErrors(dataAsMap, answers, errors, answerField, errorField, defaultErrorMessageField);
199+
} else if (dataAsMap.containsKey("content")) {
200+
// New Claude V3 format
201+
Object contentObj = dataAsMap.get("content");
202+
if (contentObj instanceof List) {
203+
List<?> contentList = (List<?>) contentObj;
204+
if (!contentList.isEmpty()) {
205+
Object first = contentList.get(0);
206+
if (first instanceof Map) {
207+
Map<?, ?> firstMap = (Map<?, ?>) first;
208+
Object text = firstMap.get("text");
209+
if (text != null) {
210+
answers.add(text.toString());
211+
} else {
212+
errors.add("Claude V3 response missing 'text' field.");
213+
}
214+
} else {
215+
errors.add("Unexpected content format in Claude V3 response.");
216+
}
217+
} else {
218+
errors.add("Empty content list in Claude V3 response.");
219+
}
220+
} else {
221+
errors.add("Unexpected type for 'content' in Claude V3 response.");
222+
}
223+
} else {
224+
// Fallback error handling
225+
errors.add("Unsupported Claude response format: " + dataAsMap.keySet());
226+
log.error("Unknown Bedrock/Claude response format: {}", dataAsMap);
227+
}
196 228
} else if (provider == ModelProvider.COHERE) {
197 229
answerField = "text";
198 230
fillAnswersOrErrors(dataAsMap, answers, errors, answerField, errorField, defaultErrorMessageField);

search-processors/src/test/java/org/opensearch/searchpipelines/questionanswering/generative/llm/DefaultLlmImplTests.java

Lines changed: 52 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -143,6 +143,58 @@ public void onFailure(Exception e) {
143 143
assertTrue(mlInput.getInputDataset() instanceof RemoteInferenceInputDataSet);
144 144
}
145 145

146+
/**
 * Checks that a Bedrock response carrying a Claude V3-style payload (a "content" list whose
 * first element is a map with a "text" entry) is surfaced as the first answer of the
 * ChatCompletionOutput handed to the listener.
 *
 * The ML client is a Mockito mock whose three-argument predict is stubbed to call the supplied
 * listener with a prepared ModelTensorOutput. After the call, the test verifies that predict
 * was invoked exactly once and that the captured MLInput wraps a RemoteInferenceInputDataSet.
 *
 * @throws Exception declared by the signature; no statement in this body throws a checked exception explicitly
 */
public void testChatCompletionApiForBedrockClaudeV3() throws Exception {
147+
MachineLearningInternalClient mlClient = mock(MachineLearningInternalClient.class);
148+
ArgumentCaptor<MLInput> captor = ArgumentCaptor.forClass(MLInput.class);
149+
// `client` is not declared in this method; it comes from the enclosing test class.
DefaultLlmImpl connector = new DefaultLlmImpl("model_id", client);
150+
connector.setMlClient(mlClient);
151+
152+
// Claude V3-style response
153+
Map<String, Object> textPart = Map.of("type", "text", "text", "Hello from Claude V3");
154+
Map<String, Object> dataAsMap = Map.of("content", List.of(textPart));
155+
156+
// The response map is passed as the last constructor argument of the tensor.
ModelTensor tensor = new ModelTensor("tensor", new Number[0], new long[0], MLResultDataType.STRING, null, null, dataAsMap);
157+
ModelTensorOutput mlOutput = new ModelTensorOutput(List.of(new ModelTensors(List.of(tensor))));
158+
ActionFuture<MLOutput> future = mock(ActionFuture.class);
159+
when(future.actionGet(anyLong())).thenReturn(mlOutput);
160+
// NOTE(review): this stubs the two-argument predict overload, while the doAnswer stub and the
// verify below both target the three-argument overload — confirm this stub is still needed.
when(mlClient.predict(any(), any())).thenReturn(future);
161+
162+
ChatCompletionInput input = new ChatCompletionInput(
163+
"bedrock/model",
164+
"question",
165+
Collections.emptyList(),
166+
Collections.emptyList(),
167+
0,
168+
"prompt",
169+
"instructions",
170+
Llm.ModelProvider.BEDROCK,
171+
null,
172+
null
173+
);
174+
175+
// Stub the three-argument predict so that it replies through the listener it receives.
doAnswer(invocation -> {
176+
// Argument index 2 is the listener; answer it directly with the prepared output.
((ActionListener<MLOutput>) invocation.getArguments()[2]).onResponse(mlOutput);
177+
return null;
178+
}).when(mlClient).predict(any(), any(), any());
179+
180+
connector.doChatCompletion(input, new ActionListener<>() {
181+
@Override
182+
public void onResponse(ChatCompletionOutput output) {
183+
// Verify that we parsed the Claude V3 response correctly
184+
// NOTE(review): this assertion only runs if onResponse is invoked; nothing in this method
// fails the test when the callback never fires — consider recording that it was called.
assertEquals("Hello from Claude V3", output.getAnswers().get(0));
185+
}
186+
187+
@Override
188+
public void onFailure(Exception e) {
189+
fail("Claude V3 test failed: " + e.getMessage());
190+
}
191+
});
192+
193+
// Exactly one three-argument predict call is expected; capture its second argument.
verify(mlClient, times(1)).predict(any(), captor.capture(), any());
194+
MLInput mlInput = captor.getValue();
195+
assertTrue(mlInput.getInputDataset() instanceof RemoteInferenceInputDataSet);
196+
}
197+
146 198
public void testChatCompletionApiForBedrock() throws Exception {
147 199
MachineLearningInternalClient mlClient = mock(MachineLearningInternalClient.class);
148 200
ArgumentCaptor<MLInput> captor = ArgumentCaptor.forClass(MLInput.class);

0 commit comments

Comments
 (0)