From 101bbc461131cf3d19e92a0e555fda4db483078a Mon Sep 17 00:00:00 2001 From: Sun Yuhan <1085481446@qq.com> Date: Tue, 3 Jun 2025 11:44:13 +0800 Subject: [PATCH 1/2] fix: Fixed the issue where tool call information was lost when using DefaultChatOptions. Signed-off-by: Sun Yuhan <1085481446@qq.com> --- .../chat/client/DefaultChatClientUtils.java | 20 +++++++ .../client/DefaultChatClientUtilsTests.java | 60 +++++++++++++++++++ 2 files changed, 80 insertions(+) diff --git a/spring-ai-client-chat/src/main/java/org/springframework/ai/chat/client/DefaultChatClientUtils.java b/spring-ai-client-chat/src/main/java/org/springframework/ai/chat/client/DefaultChatClientUtils.java index 10f623e2b70..f68708e7893 100644 --- a/spring-ai-client-chat/src/main/java/org/springframework/ai/chat/client/DefaultChatClientUtils.java +++ b/spring-ai-client-chat/src/main/java/org/springframework/ai/chat/client/DefaultChatClientUtils.java @@ -27,8 +27,10 @@ import org.springframework.ai.chat.messages.SystemMessage; import org.springframework.ai.chat.messages.UserMessage; import org.springframework.ai.chat.prompt.ChatOptions; +import org.springframework.ai.chat.prompt.DefaultChatOptions; import org.springframework.ai.chat.prompt.Prompt; import org.springframework.ai.chat.prompt.PromptTemplate; +import org.springframework.ai.model.tool.DefaultToolCallingChatOptions; import org.springframework.ai.model.tool.ToolCallingChatOptions; import org.springframework.ai.tool.ToolCallback; import org.springframework.util.Assert; @@ -39,6 +41,7 @@ * Utilities for supporting the {@link DefaultChatClient} implementation. * * @author Thomas Vitale + * @author Sun Yuhan * @since 1.0.0 */ final class DefaultChatClientUtils { @@ -94,6 +97,23 @@ static ChatClientRequest toChatClientRequest(DefaultChatClient.DefaultChatClient */ ChatOptions processedChatOptions = inputRequest.getChatOptions(); + + if (processedChatOptions instanceof DefaultChatOptions defaultChatOptions) { + if (!inputRequest.getToolNames().isEmpty() || !inputRequest.getToolCallbacks().isEmpty() + || !CollectionUtils.isEmpty(inputRequest.getToolContext())) { + processedChatOptions = DefaultToolCallingChatOptions.builder() + .model(defaultChatOptions.getModel()) + .frequencyPenalty(defaultChatOptions.getFrequencyPenalty()) + .maxTokens(defaultChatOptions.getMaxTokens()) + .presencePenalty(defaultChatOptions.getPresencePenalty()) + .stopSequences(defaultChatOptions.getStopSequences()) + .temperature(defaultChatOptions.getTemperature()) + .topK(defaultChatOptions.getTopK()) + .topP(defaultChatOptions.getTopP()) + .build(); + } + } + if (processedChatOptions instanceof ToolCallingChatOptions toolCallingChatOptions) { if (!inputRequest.getToolNames().isEmpty()) { Set toolNames = ToolCallingChatOptions diff --git a/spring-ai-client-chat/src/test/java/org/springframework/ai/chat/client/DefaultChatClientUtilsTests.java b/spring-ai-client-chat/src/test/java/org/springframework/ai/chat/client/DefaultChatClientUtilsTests.java index 9d4d4962069..7b8d5491f90 100644 --- a/spring-ai-client-chat/src/test/java/org/springframework/ai/chat/client/DefaultChatClientUtilsTests.java +++ b/spring-ai-client-chat/src/test/java/org/springframework/ai/chat/client/DefaultChatClientUtilsTests.java @@ -26,6 +26,7 @@ import org.springframework.ai.chat.messages.SystemMessage; import org.springframework.ai.chat.messages.UserMessage; import org.springframework.ai.chat.model.ChatModel; +import org.springframework.ai.chat.prompt.DefaultChatOptions; import org.springframework.ai.content.Media; import org.springframework.ai.model.tool.ToolCallingChatOptions; import org.springframework.ai.template.TemplateRenderer; @@ -43,6 +44,7 @@ * Unit tests for {@link DefaultChatClientUtils}. * * @author Thomas Vitale + * @author Sun Yuhan */ class DefaultChatClientUtilsTests { @@ -322,6 +324,64 @@ void whenToolContextAndChatOptionsAreProvidedThenTheValuesAreMerged() { .containsAllEntriesOf(toolContext2); } + @Test + void whenToolNamesAndChatOptionsAreDefaultChatOptions() { + Set toolNames1 = Set.of("toolA", "toolB"); + DefaultChatOptions chatOptions = new DefaultChatOptions(); + ChatModel chatModel = mock(ChatModel.class); + DefaultChatClient.DefaultChatClientRequestSpec inputRequest = (DefaultChatClient.DefaultChatClientRequestSpec) ChatClient + .create(chatModel) + .prompt() + .options(chatOptions) + .toolNames(toolNames1.toArray(new String[0])); + + ChatClientRequest result = DefaultChatClientUtils.toChatClientRequest(inputRequest); + + assertThat(result).isNotNull(); + assertThat(result.prompt().getOptions()).isInstanceOf(ToolCallingChatOptions.class); + ToolCallingChatOptions resultOptions = (ToolCallingChatOptions) result.prompt().getOptions(); + assertThat(resultOptions).isNotNull(); + assertThat(resultOptions.getToolNames()).containsExactlyInAnyOrderElementsOf(toolNames1); + } + + @Test + void whenToolCallbacksAndChatOptionsAreDefaultChatOptions() { + ToolCallback toolCallback1 = new TestToolCallback("tool1"); + DefaultChatOptions chatOptions = new DefaultChatOptions(); + ChatModel chatModel = mock(ChatModel.class); + DefaultChatClient.DefaultChatClientRequestSpec inputRequest = (DefaultChatClient.DefaultChatClientRequestSpec) ChatClient + .create(chatModel) + .prompt() + .options(chatOptions) + .toolCallbacks(toolCallback1); + + ChatClientRequest result = DefaultChatClientUtils.toChatClientRequest(inputRequest); + + assertThat(result).isNotNull(); + assertThat(result.prompt().getOptions()).isInstanceOf(ToolCallingChatOptions.class); + ToolCallingChatOptions resultOptions = (ToolCallingChatOptions) result.prompt().getOptions(); + assertThat(resultOptions).isNotNull(); + assertThat(resultOptions.getToolCallbacks()).containsExactlyInAnyOrder(toolCallback1); + } + + @Test + void whenToolContextAndChatOptionsAreDefaultChatOptions() { + Map toolContext1 = Map.of("key1", "value1"); + DefaultChatOptions chatOptions = new DefaultChatOptions(); + ChatModel chatModel = mock(ChatModel.class); + DefaultChatClient.DefaultChatClientRequestSpec inputRequest = (DefaultChatClient.DefaultChatClientRequestSpec) ChatClient + .create(chatModel) + .prompt() + .options(chatOptions) + .toolContext(toolContext1); + + ChatClientRequest result = DefaultChatClientUtils.toChatClientRequest(inputRequest); + assertThat(result.prompt().getOptions()).isInstanceOf(ToolCallingChatOptions.class); + ToolCallingChatOptions resultOptions = (ToolCallingChatOptions) result.prompt().getOptions(); + assertThat(resultOptions).isNotNull(); + assertThat(resultOptions.getToolContext()).containsAllEntriesOf(toolContext1); + } + @Test void whenAdvisorParamsAreProvidedThenTheyAreAddedToContext() { Map advisorParams = Map.of("key1", "value1", "key2", "value2"); From 36a93d2eaba35056165afca9d8d651ef6e2ff705 Mon Sep 17 00:00:00 2001 From: Sun Yuhan Date: Fri, 25 Jul 2025 10:27:37 +0800 Subject: [PATCH 2/2] optimization: Simplify the method for converting `defaultChatOptions` to `DefaultToolCallingChatOptions` Signed-off-by: Sun Yuhan --- .../ai/chat/client/DefaultChatClientUtils.java | 13 +++---------- 1 file changed, 3 insertions(+), 10 deletions(-) diff --git a/spring-ai-client-chat/src/main/java/org/springframework/ai/chat/client/DefaultChatClientUtils.java b/spring-ai-client-chat/src/main/java/org/springframework/ai/chat/client/DefaultChatClientUtils.java index f68708e7893..030cf39f35c 100644 --- a/spring-ai-client-chat/src/main/java/org/springframework/ai/chat/client/DefaultChatClientUtils.java +++ b/spring-ai-client-chat/src/main/java/org/springframework/ai/chat/client/DefaultChatClientUtils.java @@ -30,6 +30,7 @@ import org.springframework.ai.chat.prompt.DefaultChatOptions; import org.springframework.ai.chat.prompt.Prompt; import org.springframework.ai.chat.prompt.PromptTemplate; +import org.springframework.ai.model.ModelOptionsUtils; import org.springframework.ai.model.tool.DefaultToolCallingChatOptions; import org.springframework.ai.model.tool.ToolCallingChatOptions; import org.springframework.ai.tool.ToolCallback; @@ -101,16 +102,8 @@ static ChatClientRequest toChatClientRequest(DefaultChatClient.DefaultChatClient if (processedChatOptions instanceof DefaultChatOptions defaultChatOptions) { if (!inputRequest.getToolNames().isEmpty() || !inputRequest.getToolCallbacks().isEmpty() || !CollectionUtils.isEmpty(inputRequest.getToolContext())) { - processedChatOptions = DefaultToolCallingChatOptions.builder() - .model(defaultChatOptions.getModel()) - .frequencyPenalty(defaultChatOptions.getFrequencyPenalty()) - .maxTokens(defaultChatOptions.getMaxTokens()) - .presencePenalty(defaultChatOptions.getPresencePenalty()) - .stopSequences(defaultChatOptions.getStopSequences()) - .temperature(defaultChatOptions.getTemperature()) - .topK(defaultChatOptions.getTopK()) - .topP(defaultChatOptions.getTopP()) - .build(); + processedChatOptions = ModelOptionsUtils.copyToTarget(defaultChatOptions, ChatOptions.class, + DefaultToolCallingChatOptions.class); } }