From 11bb3602a6cd62870d59a9ed072572c8215ccfc3 Mon Sep 17 00:00:00 2001 From: Ilayaperumal Gopinathan Date: Tue, 10 Dec 2024 10:40:54 +0000 Subject: [PATCH 1/2] refactor ChatOptions Builder - Deprecate existing ChatOptionsBuilder and its inner class DefaultChatOptions - Create a new builder interface ChatOptions.Builder for building the chat options - Create an explicit DefaultChatOptions class - Create DefaultChatOptionsBuilder which can create DefaultChatOptions Resolves #1875 --- .../ai/anthropic/AnthropicChatModel.java | 15 +- .../converse/BedrockProxyChatModel.java | 13 +- .../BedrockConverseChatModelMain.java | 4 +- .../ai/huggingface/HuggingfaceChatModel.java | 7 - .../ai/minimax/MiniMaxChatModel.java | 17 ++- .../ai/mistralai/MistralAiChatModel.java | 13 +- .../ai/moonshot/MoonshotChatModel.java | 17 ++- .../ai/ollama/OllamaChatModel.java | 19 ++- .../ai/ollama/OllamaChatModelIT.java | 4 +- .../ai/ollama/OllamaChatRequestTests.java | 7 +- .../ai/openai/OpenAiChatModel.java | 17 ++- .../ai/openai/chat/OpenAiChatModelIT.java | 4 +- .../ai/qianfan/QianFanChatModel.java | 17 ++- .../ai/watsonx/WatsonxAiChatModelTest.java | 4 +- .../ai/zhipuai/ZhiPuAiChatModel.java | 13 +- .../ai/chat/model/ChatModel.java | 4 +- .../ai/chat/prompt/ChatOptions.java | 78 +++++++++++ .../ai/chat/prompt/ChatOptionsBuilder.java | 1 + .../ai/chat/prompt/DefaultChatOptions.java | 128 ++++++++++++++++++ .../prompt/DefaultChatOptionsBuilder.java | 72 ++++++++++ .../TranslationQueryTransformer.java | 4 +- .../ai/chat/ChatBuilderTests.java | 19 ++- .../chat/client/DefaultChatClientTests.java | 5 +- ...ModelCompletionObservationFilterTests.java | 8 +- ...odelCompletionObservationHandlerTests.java | 4 +- ...ChatModelMeterObservationHandlerTests.java | 4 +- .../ChatModelObservationContextTests.java | 4 +- ...elPromptContentObservationFilterTests.java | 8 +- ...lPromptContentObservationHandlerTests.java | 4 +- ...ltChatModelObservationConventionTests.java | 32 ++--- .../ai/prompt/PromptTemplateTest.java | 7 +- .../ai/prompt/PromptTests.java | 3 +- 32 files changed, 405 insertions(+), 151 deletions(-) create mode 100644 spring-ai-core/src/main/java/org/springframework/ai/chat/prompt/DefaultChatOptions.java create mode 100644 spring-ai-core/src/main/java/org/springframework/ai/chat/prompt/DefaultChatOptionsBuilder.java diff --git a/models/spring-ai-anthropic/src/main/java/org/springframework/ai/anthropic/AnthropicChatModel.java b/models/spring-ai-anthropic/src/main/java/org/springframework/ai/anthropic/AnthropicChatModel.java index c68fb6d29bc..4f1c09df77e 100644 --- a/models/spring-ai-anthropic/src/main/java/org/springframework/ai/anthropic/AnthropicChatModel.java +++ b/models/spring-ai-anthropic/src/main/java/org/springframework/ai/anthropic/AnthropicChatModel.java @@ -57,7 +57,6 @@ import org.springframework.ai.chat.observation.ChatModelObservationDocumentation; import org.springframework.ai.chat.observation.DefaultChatModelObservationConvention; import org.springframework.ai.chat.prompt.ChatOptions; -import org.springframework.ai.chat.prompt.ChatOptionsBuilder; import org.springframework.ai.chat.prompt.Prompt; import org.springframework.ai.model.Media; import org.springframework.ai.model.ModelOptionsUtils; @@ -471,13 +470,13 @@ private List getFunctionTools(Set functionNames) { } private ChatOptions buildRequestOptions(AnthropicApi.ChatCompletionRequest request) { - return ChatOptionsBuilder.builder() - .withModel(request.model()) - .withMaxTokens(request.maxTokens()) - .withStopSequences(request.stopSequences()) -
.withTemperature(request.temperature()) - .withTopK(request.topK()) - .withTopP(request.topP()) + return ChatOptions.builder() + .model(request.model()) + .maxTokens(request.maxTokens()) + .stopSequences(request.stopSequences()) + .temperature(request.temperature()) + .topK(request.topK()) + .topP(request.topP()) .build(); } diff --git a/models/spring-ai-bedrock-converse/src/main/java/org/springframework/ai/bedrock/converse/BedrockProxyChatModel.java b/models/spring-ai-bedrock-converse/src/main/java/org/springframework/ai/bedrock/converse/BedrockProxyChatModel.java index 9bf22cf5931..a1bbd10c0c4 100644 --- a/models/spring-ai-bedrock-converse/src/main/java/org/springframework/ai/bedrock/converse/BedrockProxyChatModel.java +++ b/models/spring-ai-bedrock-converse/src/main/java/org/springframework/ai/bedrock/converse/BedrockProxyChatModel.java @@ -93,7 +93,6 @@ import org.springframework.ai.chat.observation.ChatModelObservationDocumentation; import org.springframework.ai.chat.observation.DefaultChatModelObservationConvention; import org.springframework.ai.chat.prompt.ChatOptions; -import org.springframework.ai.chat.prompt.ChatOptionsBuilder; import org.springframework.ai.chat.prompt.Prompt; import org.springframework.ai.model.Media; import org.springframework.ai.model.ModelOptionsUtils; @@ -219,13 +218,13 @@ private ChatResponse internalCall(Prompt prompt, ChatResponse perviousChatRespon } private ChatOptions buildRequestOptions(ConverseRequest request) { - return ChatOptionsBuilder.builder() - .withModel(request.modelId()) - .withMaxTokens(request.inferenceConfig().maxTokens()) - .withStopSequences(request.inferenceConfig().stopSequences()) - .withTemperature(request.inferenceConfig().temperature() != null + return ChatOptions.builder() + .model(request.modelId()) + .maxTokens(request.inferenceConfig().maxTokens()) + .stopSequences(request.inferenceConfig().stopSequences()) + .temperature(request.inferenceConfig().temperature() != null ? request.inferenceConfig().temperature().doubleValue() : null) - .withTopP(request.inferenceConfig().topP() != null ? request.inferenceConfig().topP().doubleValue() : null) + .topP(request.inferenceConfig().topP() != null ? 
request.inferenceConfig().topP().doubleValue() : null) .build(); } diff --git a/models/spring-ai-bedrock-converse/src/test/java/org/springframework/ai/bedrock/converse/experiments/BedrockConverseChatModelMain.java b/models/spring-ai-bedrock-converse/src/test/java/org/springframework/ai/bedrock/converse/experiments/BedrockConverseChatModelMain.java index 7404f3d4b26..e9008c0ced5 100644 --- a/models/spring-ai-bedrock-converse/src/test/java/org/springframework/ai/bedrock/converse/experiments/BedrockConverseChatModelMain.java +++ b/models/spring-ai-bedrock-converse/src/test/java/org/springframework/ai/bedrock/converse/experiments/BedrockConverseChatModelMain.java @@ -20,7 +20,7 @@ import software.amazon.awssdk.regions.Region; import org.springframework.ai.bedrock.converse.BedrockProxyChatModel; -import org.springframework.ai.chat.prompt.ChatOptionsBuilder; +import org.springframework.ai.chat.prompt.ChatOptions; import org.springframework.ai.chat.prompt.Prompt; /** @@ -40,7 +40,7 @@ public static void main(String[] args) { // String modelId = "anthropic.claude-3-5-sonnet-20240620-v1:0"; String modelId = "ai21.jamba-1-5-large-v1:0"; - var prompt = new Prompt("Tell me a joke?", ChatOptionsBuilder.builder().withModel(modelId).build()); + var prompt = new Prompt("Tell me a joke?", ChatOptions.builder().model(modelId).build()); var chatModel = BedrockProxyChatModel.builder() .withCredentialsProvider(EnvironmentVariableCredentialsProvider.create()) diff --git a/models/spring-ai-huggingface/src/main/java/org/springframework/ai/huggingface/HuggingfaceChatModel.java b/models/spring-ai-huggingface/src/main/java/org/springframework/ai/huggingface/HuggingfaceChatModel.java index aa222f4f320..5546b4c54d2 100644 --- a/models/spring-ai-huggingface/src/main/java/org/springframework/ai/huggingface/HuggingfaceChatModel.java +++ b/models/spring-ai-huggingface/src/main/java/org/springframework/ai/huggingface/HuggingfaceChatModel.java @@ -27,8 +27,6 @@ import org.springframework.ai.chat.model.ChatModel; import org.springframework.ai.chat.model.ChatResponse; import org.springframework.ai.chat.model.Generation; -import org.springframework.ai.chat.prompt.ChatOptions; -import org.springframework.ai.chat.prompt.ChatOptionsBuilder; import org.springframework.ai.chat.prompt.Prompt; import org.springframework.ai.huggingface.api.TextGenerationInferenceApi; import org.springframework.ai.huggingface.invoker.ApiClient; @@ -128,9 +126,4 @@ public void setMaxNewTokens(int maxNewTokens) { this.maxNewTokens = maxNewTokens; } - @Override - public ChatOptions getDefaultOptions() { - return ChatOptionsBuilder.builder().build(); - } - } diff --git a/models/spring-ai-minimax/src/main/java/org/springframework/ai/minimax/MiniMaxChatModel.java b/models/spring-ai-minimax/src/main/java/org/springframework/ai/minimax/MiniMaxChatModel.java index 754dc77f589..bd70c9cdd2d 100644 --- a/models/spring-ai-minimax/src/main/java/org/springframework/ai/minimax/MiniMaxChatModel.java +++ b/models/spring-ai-minimax/src/main/java/org/springframework/ai/minimax/MiniMaxChatModel.java @@ -48,7 +48,6 @@ import org.springframework.ai.chat.observation.ChatModelObservationDocumentation; import org.springframework.ai.chat.observation.DefaultChatModelObservationConvention; import org.springframework.ai.chat.prompt.ChatOptions; -import org.springframework.ai.chat.prompt.ChatOptionsBuilder; import org.springframework.ai.chat.prompt.Prompt; import org.springframework.ai.minimax.api.MiniMaxApi; import org.springframework.ai.minimax.api.MiniMaxApi.ChatCompletion; @@ 
-374,14 +373,14 @@ protected boolean isToolCall(Generation generation, Set toolCallFinishRe } private ChatOptions buildRequestOptions(ChatCompletionRequest request) { - return ChatOptionsBuilder.builder() - .withModel(request.model()) - .withFrequencyPenalty(request.frequencyPenalty()) - .withMaxTokens(request.maxTokens()) - .withPresencePenalty(request.presencePenalty()) - .withStopSequences(request.stop()) - .withTemperature(request.temperature()) - .withTopP(request.topP()) + return ChatOptions.builder() + .model(request.model()) + .frequencyPenalty(request.frequencyPenalty()) + .maxTokens(request.maxTokens()) + .presencePenalty(request.presencePenalty()) + .stopSequences(request.stop()) + .temperature(request.temperature()) + .topP(request.topP()) .build(); } diff --git a/models/spring-ai-mistral-ai/src/main/java/org/springframework/ai/mistralai/MistralAiChatModel.java b/models/spring-ai-mistral-ai/src/main/java/org/springframework/ai/mistralai/MistralAiChatModel.java index c7b967a5673..ac027db009a 100644 --- a/models/spring-ai-mistral-ai/src/main/java/org/springframework/ai/mistralai/MistralAiChatModel.java +++ b/models/spring-ai-mistral-ai/src/main/java/org/springframework/ai/mistralai/MistralAiChatModel.java @@ -46,7 +46,6 @@ import org.springframework.ai.chat.observation.ChatModelObservationDocumentation; import org.springframework.ai.chat.observation.DefaultChatModelObservationConvention; import org.springframework.ai.chat.prompt.ChatOptions; -import org.springframework.ai.chat.prompt.ChatOptionsBuilder; import org.springframework.ai.chat.prompt.Prompt; import org.springframework.ai.mistralai.api.MistralAiApi; import org.springframework.ai.mistralai.api.MistralAiApi.ChatCompletion; @@ -406,12 +405,12 @@ private List getFunctionTools(Set functionNam } private ChatOptions buildRequestOptions(MistralAiApi.ChatCompletionRequest request) { - return ChatOptionsBuilder.builder() - .withModel(request.model()) - .withMaxTokens(request.maxTokens()) - .withStopSequences(request.stop()) - .withTemperature(request.temperature()) - .withTopP(request.topP()) + return ChatOptions.builder() + .model(request.model()) + .maxTokens(request.maxTokens()) + .stopSequences(request.stop()) + .temperature(request.temperature()) + .topP(request.topP()) .build(); } diff --git a/models/spring-ai-moonshot/src/main/java/org/springframework/ai/moonshot/MoonshotChatModel.java b/models/spring-ai-moonshot/src/main/java/org/springframework/ai/moonshot/MoonshotChatModel.java index 0d321996bc5..1f4d6b7198f 100644 --- a/models/spring-ai-moonshot/src/main/java/org/springframework/ai/moonshot/MoonshotChatModel.java +++ b/models/spring-ai-moonshot/src/main/java/org/springframework/ai/moonshot/MoonshotChatModel.java @@ -47,7 +47,6 @@ import org.springframework.ai.chat.observation.ChatModelObservationDocumentation; import org.springframework.ai.chat.observation.DefaultChatModelObservationConvention; import org.springframework.ai.chat.prompt.ChatOptions; -import org.springframework.ai.chat.prompt.ChatOptionsBuilder; import org.springframework.ai.chat.prompt.Prompt; import org.springframework.ai.model.ModelOptionsUtils; import org.springframework.ai.model.function.FunctionCallback; @@ -420,14 +419,14 @@ else if (message.getMessageType() == MessageType.TOOL) { } private ChatOptions buildRequestOptions(MoonshotApi.ChatCompletionRequest request) { - return ChatOptionsBuilder.builder() - .withModel(request.model()) - .withFrequencyPenalty(request.frequencyPenalty()) - .withMaxTokens(request.maxTokens()) - 
.withPresencePenalty(request.presencePenalty()) - .withStopSequences(request.stop()) - .withTemperature(request.temperature()) - .withTopP(request.topP()) + return ChatOptions.builder() + .model(request.model()) + .frequencyPenalty(request.frequencyPenalty()) + .maxTokens(request.maxTokens()) + .presencePenalty(request.presencePenalty()) + .stopSequences(request.stop()) + .temperature(request.temperature()) + .topP(request.topP()) .build(); } diff --git a/models/spring-ai-ollama/src/main/java/org/springframework/ai/ollama/OllamaChatModel.java b/models/spring-ai-ollama/src/main/java/org/springframework/ai/ollama/OllamaChatModel.java index 09cf70af019..fd2258bcbac 100644 --- a/models/spring-ai-ollama/src/main/java/org/springframework/ai/ollama/OllamaChatModel.java +++ b/models/spring-ai-ollama/src/main/java/org/springframework/ai/ollama/OllamaChatModel.java @@ -45,7 +45,6 @@ import org.springframework.ai.chat.observation.ChatModelObservationDocumentation; import org.springframework.ai.chat.observation.DefaultChatModelObservationConvention; import org.springframework.ai.chat.prompt.ChatOptions; -import org.springframework.ai.chat.prompt.ChatOptionsBuilder; import org.springframework.ai.chat.prompt.Prompt; import org.springframework.ai.model.ModelOptionsUtils; import org.springframework.ai.model.function.FunctionCallback; @@ -421,15 +420,15 @@ private List getFunctionTools(Set functionNames) { private ChatOptions buildRequestOptions(OllamaApi.ChatRequest request) { var options = ModelOptionsUtils.mapToClass(request.options(), OllamaOptions.class); - return ChatOptionsBuilder.builder() - .withModel(request.model()) - .withFrequencyPenalty(options.getFrequencyPenalty()) - .withMaxTokens(options.getMaxTokens()) - .withPresencePenalty(options.getPresencePenalty()) - .withStopSequences(options.getStopSequences()) - .withTemperature(options.getTemperature()) - .withTopK(options.getTopK()) - .withTopP(options.getTopP()) + return ChatOptions.builder() + .model(request.model()) + .frequencyPenalty(options.getFrequencyPenalty()) + .maxTokens(options.getMaxTokens()) + .presencePenalty(options.getPresencePenalty()) + .stopSequences(options.getStopSequences()) + .temperature(options.getTemperature()) + .topK(options.getTopK()) + .topP(options.getTopP()) .build(); } diff --git a/models/spring-ai-ollama/src/test/java/org/springframework/ai/ollama/OllamaChatModelIT.java b/models/spring-ai-ollama/src/test/java/org/springframework/ai/ollama/OllamaChatModelIT.java index 8cc3b4042f6..58eb514501a 100644 --- a/models/spring-ai-ollama/src/test/java/org/springframework/ai/ollama/OllamaChatModelIT.java +++ b/models/spring-ai-ollama/src/test/java/org/springframework/ai/ollama/OllamaChatModelIT.java @@ -29,7 +29,7 @@ import org.springframework.ai.chat.metadata.Usage; import org.springframework.ai.chat.model.ChatResponse; import org.springframework.ai.chat.model.Generation; -import org.springframework.ai.chat.prompt.ChatOptionsBuilder; +import org.springframework.ai.chat.prompt.ChatOptions; import org.springframework.ai.chat.prompt.Prompt; import org.springframework.ai.chat.prompt.PromptTemplate; import org.springframework.ai.chat.prompt.SystemPromptTemplate; @@ -91,7 +91,7 @@ void roleTest() { UserMessage userMessage = new UserMessage("Tell me about 5 famous pirates from the Golden Age of Piracy."); // portable/generic options - var portableOptions = ChatOptionsBuilder.builder().withTemperature(0.7).build(); + var portableOptions = ChatOptions.builder().temperature(0.7).build(); Prompt prompt = new 
Prompt(List.of(systemMessage, userMessage), portableOptions); diff --git a/models/spring-ai-ollama/src/test/java/org/springframework/ai/ollama/OllamaChatRequestTests.java b/models/spring-ai-ollama/src/test/java/org/springframework/ai/ollama/OllamaChatRequestTests.java index 5e8e74107c7..0d24a458332 100644 --- a/models/spring-ai-ollama/src/test/java/org/springframework/ai/ollama/OllamaChatRequestTests.java +++ b/models/spring-ai-ollama/src/test/java/org/springframework/ai/ollama/OllamaChatRequestTests.java @@ -19,7 +19,6 @@ import org.junit.jupiter.api.Test; import org.springframework.ai.chat.prompt.ChatOptions; -import org.springframework.ai.chat.prompt.ChatOptionsBuilder; import org.springframework.ai.chat.prompt.Prompt; import org.springframework.ai.ollama.api.OllamaApi; import org.springframework.ai.ollama.api.OllamaOptions; @@ -78,11 +77,7 @@ public void createRequestWithPromptOllamaOptions() { public void createRequestWithPromptPortableChatOptions() { // Ollama runtime options. - ChatOptions portablePromptOptions = ChatOptionsBuilder.builder() - .withTemperature(0.9) - .withTopK(100) - .withTopP(0.6) - .build(); + ChatOptions portablePromptOptions = ChatOptions.builder().temperature(0.9).topK(100).topP(0.6).build(); var request = this.chatModel.ollamaChatRequest(new Prompt("Test message content", portablePromptOptions), true); diff --git a/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/OpenAiChatModel.java b/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/OpenAiChatModel.java index 77be632b5e3..bb7c0dcee82 100644 --- a/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/OpenAiChatModel.java +++ b/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/OpenAiChatModel.java @@ -53,7 +53,6 @@ import org.springframework.ai.chat.observation.ChatModelObservationDocumentation; import org.springframework.ai.chat.observation.DefaultChatModelObservationConvention; import org.springframework.ai.chat.prompt.ChatOptions; -import org.springframework.ai.chat.prompt.ChatOptionsBuilder; import org.springframework.ai.chat.prompt.Prompt; import org.springframework.ai.model.Media; import org.springframework.ai.model.ModelOptionsUtils; @@ -595,14 +594,14 @@ private List getFunctionTools(Set functionNames) } private ChatOptions buildRequestOptions(OpenAiApi.ChatCompletionRequest request) { - return ChatOptionsBuilder.builder() - .withModel(request.model()) - .withFrequencyPenalty(request.frequencyPenalty()) - .withMaxTokens(request.maxTokens()) - .withPresencePenalty(request.presencePenalty()) - .withStopSequences(request.stop()) - .withTemperature(request.temperature()) - .withTopP(request.topP()) + return ChatOptions.builder() + .model(request.model()) + .frequencyPenalty(request.frequencyPenalty()) + .maxTokens(request.maxTokens()) + .presencePenalty(request.presencePenalty()) + .stopSequences(request.stop()) + .temperature(request.temperature()) + .topP(request.topP()) .build(); } diff --git a/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/chat/OpenAiChatModelIT.java b/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/chat/OpenAiChatModelIT.java index 9f84ea83f0c..f0314b1980b 100644 --- a/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/chat/OpenAiChatModelIT.java +++ b/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/chat/OpenAiChatModelIT.java @@ -43,7 +43,7 @@ import org.springframework.ai.chat.metadata.Usage; import 
org.springframework.ai.chat.model.ChatResponse; import org.springframework.ai.chat.model.Generation; -import org.springframework.ai.chat.prompt.ChatOptionsBuilder; +import org.springframework.ai.chat.prompt.ChatOptions; import org.springframework.ai.chat.prompt.Prompt; import org.springframework.ai.chat.prompt.PromptTemplate; import org.springframework.ai.chat.prompt.SystemPromptTemplate; @@ -518,7 +518,7 @@ void multiModalityInputAudio(String modelName) { List.of(new Media(MimeTypeUtils.parseMimeType("audio/mp3"), audioResource))); ChatResponse response = chatModel - .call(new Prompt(List.of(userMessage), ChatOptionsBuilder.builder().withModel(modelName).build())); + .call(new Prompt(List.of(userMessage), ChatOptions.builder().model(modelName).build())); logger.info(response.getResult().getOutput().getText()); assertThat(response.getResult().getOutput().getText()).containsIgnoringCase("hobbits"); diff --git a/models/spring-ai-qianfan/src/main/java/org/springframework/ai/qianfan/QianFanChatModel.java b/models/spring-ai-qianfan/src/main/java/org/springframework/ai/qianfan/QianFanChatModel.java index 074655b5e72..dc5735d9d3b 100644 --- a/models/spring-ai-qianfan/src/main/java/org/springframework/ai/qianfan/QianFanChatModel.java +++ b/models/spring-ai-qianfan/src/main/java/org/springframework/ai/qianfan/QianFanChatModel.java @@ -41,7 +41,6 @@ import org.springframework.ai.chat.observation.ChatModelObservationDocumentation; import org.springframework.ai.chat.observation.DefaultChatModelObservationConvention; import org.springframework.ai.chat.prompt.ChatOptions; -import org.springframework.ai.chat.prompt.ChatOptionsBuilder; import org.springframework.ai.chat.prompt.Prompt; import org.springframework.ai.model.ModelOptionsUtils; import org.springframework.ai.qianfan.api.QianFanApi; @@ -278,14 +277,14 @@ public ChatOptions getDefaultOptions() { } private ChatOptions buildRequestOptions(QianFanApi.ChatCompletionRequest request) { - return ChatOptionsBuilder.builder() - .withModel(request.model()) - .withFrequencyPenalty(request.frequencyPenalty()) - .withMaxTokens(request.maxTokens()) - .withPresencePenalty(request.presencePenalty()) - .withStopSequences(request.stop()) - .withTemperature(request.temperature()) - .withTopP(request.topP()) + return ChatOptions.builder() + .model(request.model()) + .frequencyPenalty(request.frequencyPenalty()) + .maxTokens(request.maxTokens()) + .presencePenalty(request.presencePenalty()) + .stopSequences(request.stop()) + .temperature(request.temperature()) + .topP(request.topP()) .build(); } diff --git a/models/spring-ai-watsonx-ai/src/test/java/org/springframework/ai/watsonx/WatsonxAiChatModelTest.java b/models/spring-ai-watsonx-ai/src/test/java/org/springframework/ai/watsonx/WatsonxAiChatModelTest.java index 3607fbbe26b..9909600409e 100644 --- a/models/spring-ai-watsonx-ai/src/test/java/org/springframework/ai/watsonx/WatsonxAiChatModelTest.java +++ b/models/spring-ai-watsonx-ai/src/test/java/org/springframework/ai/watsonx/WatsonxAiChatModelTest.java @@ -31,7 +31,7 @@ import org.springframework.ai.chat.metadata.ChatGenerationMetadata; import org.springframework.ai.chat.model.ChatResponse; import org.springframework.ai.chat.model.Generation; -import org.springframework.ai.chat.prompt.ChatOptionsBuilder; +import org.springframework.ai.chat.prompt.ChatOptions; import org.springframework.ai.chat.prompt.Prompt; import org.springframework.ai.watsonx.api.WatsonxAiApi; import org.springframework.ai.watsonx.api.WatsonxAiChatRequest; @@ -54,7 +54,7 @@ public class 
WatsonxAiChatModelTest { @Test public void testCreateRequestWithNoModelId() { - var options = ChatOptionsBuilder.builder().withTemperature(0.9).withTopK(100).withTopP(0.6).build(); + var options = ChatOptions.builder().temperature(0.9).topK(100).topP(0.6).build(); Prompt prompt = new Prompt("Test message", options); diff --git a/models/spring-ai-zhipuai/src/main/java/org/springframework/ai/zhipuai/ZhiPuAiChatModel.java b/models/spring-ai-zhipuai/src/main/java/org/springframework/ai/zhipuai/ZhiPuAiChatModel.java index 6adac855999..e3043b1d6bd 100644 --- a/models/spring-ai-zhipuai/src/main/java/org/springframework/ai/zhipuai/ZhiPuAiChatModel.java +++ b/models/spring-ai-zhipuai/src/main/java/org/springframework/ai/zhipuai/ZhiPuAiChatModel.java @@ -50,7 +50,6 @@ import org.springframework.ai.chat.observation.ChatModelObservationDocumentation; import org.springframework.ai.chat.observation.DefaultChatModelObservationConvention; import org.springframework.ai.chat.prompt.ChatOptions; -import org.springframework.ai.chat.prompt.ChatOptionsBuilder; import org.springframework.ai.chat.prompt.Prompt; import org.springframework.ai.model.ModelOptionsUtils; import org.springframework.ai.model.function.FunctionCallback; @@ -459,12 +458,12 @@ else if (mediaContentData instanceof String text) { } private ChatOptions buildRequestOptions(ZhiPuAiApi.ChatCompletionRequest request) { - return ChatOptionsBuilder.builder() - .withModel(request.model()) - .withMaxTokens(request.maxTokens()) - .withStopSequences(request.stop()) - .withTemperature(request.temperature()) - .withTopP(request.topP()) + return ChatOptions.builder() + .model(request.model()) + .maxTokens(request.maxTokens()) + .stopSequences(request.stop()) + .temperature(request.temperature()) + .topP(request.topP()) .build(); } diff --git a/spring-ai-core/src/main/java/org/springframework/ai/chat/model/ChatModel.java b/spring-ai-core/src/main/java/org/springframework/ai/chat/model/ChatModel.java index 45ffa1bc76f..95bebeb2ea6 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/chat/model/ChatModel.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/chat/model/ChatModel.java @@ -43,7 +43,9 @@ default String call(Message... messages) { @Override ChatResponse call(Prompt prompt); - ChatOptions getDefaultOptions(); + default ChatOptions getDefaultOptions() { + return ChatOptions.builder().build(); + } default Flux stream(Prompt prompt) { throw new UnsupportedOperationException("streaming is not supported"); diff --git a/spring-ai-core/src/main/java/org/springframework/ai/chat/prompt/ChatOptions.java b/spring-ai-core/src/main/java/org/springframework/ai/chat/prompt/ChatOptions.java index 5fb6dffb1ed..81af1b7f349 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/chat/prompt/ChatOptions.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/chat/prompt/ChatOptions.java @@ -89,4 +89,82 @@ public interface ChatOptions extends ModelOptions { */ ChatOptions copy(); + /** + * Creates a new {@link ChatOptions.Builder} to create the default + * {@link ChatOptions}. + * @return Returns a new {@link ChatOptions.Builder}. + */ + static ChatOptions.Builder builder() { + return new DefaultChatOptionsBuilder(); + } + + /** + * Builder for creating {@link ChatOptions} instance. + */ + interface Builder { + + /** + * Builds with the model to use for the chat. + * @param model + * @return the builder + */ + Builder model(String model); + + /** + * Builds with the frequency penalty to use for the chat. 
+ * @param frequencyPenalty + * @return the builder. + */ + Builder frequencyPenalty(Double frequencyPenalty); + + /** + * Builds with the maximum number of tokens to use for the chat. + * @param maxTokens + * @return the builder. + */ + Builder maxTokens(Integer maxTokens); + + /** + * Builds with the presence penalty to use for the chat. + * @param presencePenalty + * @return the builder. + */ + Builder presencePenalty(Double presencePenalty); + + /** + * Builds with the stop sequences to use for the chat. + * @param stopSequences + * @return the builder. + */ + Builder stopSequences(List stopSequences); + + /** + * Builds with the temperature to use for the chat. + * @param temperature + * @return the builder. + */ + Builder temperature(Double temperature); + + /** + * Builds with the top K to use for the chat. + * @param topK + * @return the builder. + */ + Builder topK(Integer topK); + + /** + * Builds with the top P to use for the chat. + * @param topP + * @return the builder. + */ + Builder topP(Double topP); + + /** + * Build the {@link ChatOptions}. + * @return the Chat options. + */ + ChatOptions build(); + + } + } diff --git a/spring-ai-core/src/main/java/org/springframework/ai/chat/prompt/ChatOptionsBuilder.java b/spring-ai-core/src/main/java/org/springframework/ai/chat/prompt/ChatOptionsBuilder.java index cc59c1ca6e4..a35bd3dc370 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/chat/prompt/ChatOptionsBuilder.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/chat/prompt/ChatOptionsBuilder.java @@ -18,6 +18,7 @@ import java.util.List; +@Deprecated(forRemoval = true, since = "1.0.0-M5") public final class ChatOptionsBuilder { private final DefaultChatOptions options = new DefaultChatOptions(); diff --git a/spring-ai-core/src/main/java/org/springframework/ai/chat/prompt/DefaultChatOptions.java b/spring-ai-core/src/main/java/org/springframework/ai/chat/prompt/DefaultChatOptions.java new file mode 100644 index 00000000000..9feca34a99e --- /dev/null +++ b/spring-ai-core/src/main/java/org/springframework/ai/chat/prompt/DefaultChatOptions.java @@ -0,0 +1,128 @@ +/* + * Copyright 2024-2024 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.ai.chat.prompt; + +import java.util.List; + +/** + * Default implementation for the {@link ChatOptions}. 
+ */ +public class DefaultChatOptions implements ChatOptions { + + private String model; + + private Double frequencyPenalty; + + private Integer maxTokens; + + private Double presencePenalty; + + private List stopSequences; + + private Double temperature; + + private Integer topK; + + private Double topP; + + @Override + public String getModel() { + return this.model; + } + + public void setModel(String model) { + this.model = model; + } + + @Override + public Double getFrequencyPenalty() { + return this.frequencyPenalty; + } + + public void setFrequencyPenalty(Double frequencyPenalty) { + this.frequencyPenalty = frequencyPenalty; + } + + @Override + public Integer getMaxTokens() { + return this.maxTokens; + } + + public void setMaxTokens(Integer maxTokens) { + this.maxTokens = maxTokens; + } + + @Override + public Double getPresencePenalty() { + return this.presencePenalty; + } + + public void setPresencePenalty(Double presencePenalty) { + this.presencePenalty = presencePenalty; + } + + @Override + public List getStopSequences() { + return this.stopSequences; + } + + public void setStopSequences(List stopSequences) { + this.stopSequences = stopSequences; + } + + @Override + public Double getTemperature() { + return this.temperature; + } + + public void setTemperature(Double temperature) { + this.temperature = temperature; + } + + @Override + public Integer getTopK() { + return this.topK; + } + + public void setTopK(Integer topK) { + this.topK = topK; + } + + @Override + public Double getTopP() { + return this.topP; + } + + public void setTopP(Double topP) { + this.topP = topP; + } + + @Override + public ChatOptions copy() { + return ChatOptions.builder() + .model(this.model) + .frequencyPenalty(this.frequencyPenalty) + .maxTokens(this.maxTokens) + .presencePenalty(this.presencePenalty) + .stopSequences(this.stopSequences != null ? List.copyOf(this.stopSequences) : null) + .temperature(this.temperature) + .topK(this.topK) + .topP(this.topP) + .build(); + } + +} diff --git a/spring-ai-core/src/main/java/org/springframework/ai/chat/prompt/DefaultChatOptionsBuilder.java b/spring-ai-core/src/main/java/org/springframework/ai/chat/prompt/DefaultChatOptionsBuilder.java new file mode 100644 index 00000000000..d85b60478ba --- /dev/null +++ b/spring-ai-core/src/main/java/org/springframework/ai/chat/prompt/DefaultChatOptionsBuilder.java @@ -0,0 +1,72 @@ +/* + * Copyright 2024-2024 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.ai.chat.prompt; + +import java.util.List; + +/** + * Implementation of {@link ChatOptions.Builder} to create {@link DefaultChatOptions}. 
+ */ +public class DefaultChatOptionsBuilder implements ChatOptions.Builder { + + private final DefaultChatOptions options = new DefaultChatOptions(); + + public ChatOptions.Builder model(String model) { + this.options.setModel(model); + return this; + } + + public ChatOptions.Builder frequencyPenalty(Double frequencyPenalty) { + this.options.setFrequencyPenalty(frequencyPenalty); + return this; + } + + public ChatOptions.Builder maxTokens(Integer maxTokens) { + this.options.setMaxTokens(maxTokens); + return this; + } + + public ChatOptions.Builder presencePenalty(Double presencePenalty) { + this.options.setPresencePenalty(presencePenalty); + return this; + } + + public ChatOptions.Builder stopSequences(List stop) { + this.options.setStopSequences(stop); + return this; + } + + public ChatOptions.Builder temperature(Double temperature) { + this.options.setTemperature(temperature); + return this; + } + + public ChatOptions.Builder topK(Integer topK) { + this.options.setTopK(topK); + return this; + } + + public ChatOptions.Builder topP(Double topP) { + this.options.setTopP(topP); + return this; + } + + public ChatOptions build() { + return this.options; + } + +} diff --git a/spring-ai-core/src/main/java/org/springframework/ai/rag/preretrieval/query/transformation/TranslationQueryTransformer.java b/spring-ai-core/src/main/java/org/springframework/ai/rag/preretrieval/query/transformation/TranslationQueryTransformer.java index 4fe0837ea30..2a635d59daf 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/rag/preretrieval/query/transformation/TranslationQueryTransformer.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/rag/preretrieval/query/transformation/TranslationQueryTransformer.java @@ -20,7 +20,7 @@ import org.slf4j.LoggerFactory; import org.springframework.ai.chat.client.ChatClient; -import org.springframework.ai.chat.prompt.ChatOptionsBuilder; +import org.springframework.ai.chat.prompt.ChatOptions; import org.springframework.ai.chat.prompt.PromptTemplate; import org.springframework.ai.rag.Query; import org.springframework.ai.util.PromptAssert; @@ -91,7 +91,7 @@ public Query transform(Query query) { .user(user -> user.text(this.promptTemplate.getTemplate()) .param("targetLanguage", this.targetLanguage) .param("query", query.text())) - .options(ChatOptionsBuilder.builder().withTemperature(0.0).build()) + .options(ChatOptions.builder().temperature(0.0).build()) .call() .content(); diff --git a/spring-ai-core/src/test/java/org/springframework/ai/chat/ChatBuilderTests.java b/spring-ai-core/src/test/java/org/springframework/ai/chat/ChatBuilderTests.java index de1a69af65e..bcdaf87761b 100644 --- a/spring-ai-core/src/test/java/org/springframework/ai/chat/ChatBuilderTests.java +++ b/spring-ai-core/src/test/java/org/springframework/ai/chat/ChatBuilderTests.java @@ -24,7 +24,6 @@ import org.junit.jupiter.api.Test; import org.springframework.ai.chat.prompt.ChatOptions; -import org.springframework.ai.chat.prompt.ChatOptionsBuilder; import org.springframework.ai.chat.prompt.Prompt; import org.springframework.ai.model.function.FunctionCallback; import org.springframework.ai.model.function.FunctionCallingOptions; @@ -45,11 +44,7 @@ void createNewChatOptionsTest() { Double topP = 2.2; Integer topK = 111; - ChatOptions options = ChatOptionsBuilder.builder() - .withTemperature(temperature) - .withTopK(topK) - .withTopP(topP) - .build(); + ChatOptions options = ChatOptions.builder().temperature(temperature).topK(topK).topP(topP).build(); 
assertThat(options.getTemperature()).isEqualTo(temperature); assertThat(options.getTopP()).isEqualTo(topP); @@ -62,11 +57,13 @@ void duplicateChatOptionsTest() { Double initTopP = 2.2; Integer initTopK = 111; - ChatOptions options = ChatOptionsBuilder.builder() - .withTemperature(initTemperature) - .withTopP(initTopP) - .withTopK(initTopK) - .build(); + ChatOptions options1 = ChatOptions.builder().temperature(initTemperature).topP(initTopP).topK(initTopK).build(); + + ChatOptions options2 = options1.copy(); + + assertThat(options2.getTemperature()).isEqualTo(initTemperature); + assertThat(options2.getTopP()).isEqualTo(initTopP); + assertThat(options2.getTopK()).isEqualTo(initTopK); } diff --git a/spring-ai-core/src/test/java/org/springframework/ai/chat/client/DefaultChatClientTests.java b/spring-ai-core/src/test/java/org/springframework/ai/chat/client/DefaultChatClientTests.java index c610daf4401..5f7951ca5cd 100644 --- a/spring-ai-core/src/test/java/org/springframework/ai/chat/client/DefaultChatClientTests.java +++ b/spring-ai-core/src/test/java/org/springframework/ai/chat/client/DefaultChatClientTests.java @@ -43,7 +43,6 @@ import org.springframework.ai.chat.model.ChatResponse; import org.springframework.ai.chat.model.Generation; import org.springframework.ai.chat.prompt.ChatOptions; -import org.springframework.ai.chat.prompt.ChatOptionsBuilder; import org.springframework.ai.chat.prompt.Prompt; import org.springframework.ai.converter.ListOutputConverter; import org.springframework.ai.converter.StructuredOutputConverter; @@ -115,7 +114,7 @@ void whenPromptWithMessagesThenReturn() { @Test void whenPromptWithOptionsThenReturn() { ChatClient chatClient = new DefaultChatClientBuilder(mock(ChatModel.class)).build(); - ChatOptions chatOptions = ChatOptionsBuilder.builder().build(); + ChatOptions chatOptions = ChatOptions.builder().build(); Prompt prompt = new Prompt(List.of(), chatOptions); DefaultChatClient.DefaultChatClientRequestSpec spec = (DefaultChatClient.DefaultChatClientRequestSpec) chatClient .prompt(prompt); @@ -1345,7 +1344,7 @@ void whenOptionsIsNullThenThrow() { void whenOptionsThenReturn() { ChatClient chatClient = new DefaultChatClientBuilder(mock(ChatModel.class)).build(); ChatClient.ChatClientRequestSpec spec = chatClient.prompt(); - ChatOptions options = ChatOptionsBuilder.builder().build(); + ChatOptions options = ChatOptions.builder().build(); spec = spec.options(options); DefaultChatClient.DefaultChatClientRequestSpec defaultSpec = (DefaultChatClient.DefaultChatClientRequestSpec) spec; assertThat(defaultSpec.getChatOptions()).isEqualTo(options); diff --git a/spring-ai-core/src/test/java/org/springframework/ai/chat/observation/ChatModelCompletionObservationFilterTests.java b/spring-ai-core/src/test/java/org/springframework/ai/chat/observation/ChatModelCompletionObservationFilterTests.java index 2ee37ac283d..a1cbda26f3f 100644 --- a/spring-ai-core/src/test/java/org/springframework/ai/chat/observation/ChatModelCompletionObservationFilterTests.java +++ b/spring-ai-core/src/test/java/org/springframework/ai/chat/observation/ChatModelCompletionObservationFilterTests.java @@ -25,7 +25,7 @@ import org.springframework.ai.chat.messages.AssistantMessage; import org.springframework.ai.chat.model.ChatResponse; import org.springframework.ai.chat.model.Generation; -import org.springframework.ai.chat.prompt.ChatOptionsBuilder; +import org.springframework.ai.chat.prompt.ChatOptions; import org.springframework.ai.chat.prompt.Prompt; import static org.assertj.core.api.Assertions.assertThat; 
@@ -53,7 +53,7 @@ void whenEmptyResponseThenReturnOriginalContext() { var expectedContext = ChatModelObservationContext.builder() .prompt(generatePrompt()) .provider("superprovider") - .requestOptions(ChatOptionsBuilder.builder().withModel("mistral").build()) + .requestOptions(ChatOptions.builder().model("mistral").build()) .build(); var actualContext = this.observationFilter.map(expectedContext); @@ -65,7 +65,7 @@ void whenEmptyCompletionThenReturnOriginalContext() { var expectedContext = ChatModelObservationContext.builder() .prompt(generatePrompt()) .provider("superprovider") - .requestOptions(ChatOptionsBuilder.builder().withModel("mistral").build()) + .requestOptions(ChatOptions.builder().model("mistral").build()) .build(); expectedContext.setResponse(new ChatResponse(List.of(new Generation(new AssistantMessage(""))))); var actualContext = this.observationFilter.map(expectedContext); @@ -78,7 +78,7 @@ void whenCompletionWithTextThenAugmentContext() { var originalContext = ChatModelObservationContext.builder() .prompt(generatePrompt()) .provider("superprovider") - .requestOptions(ChatOptionsBuilder.builder().withModel("mistral").build()) + .requestOptions(ChatOptions.builder().model("mistral").build()) .build(); originalContext.setResponse(new ChatResponse(List.of(new Generation(new AssistantMessage("say please")), new Generation(new AssistantMessage("seriously, say please"))))); diff --git a/spring-ai-core/src/test/java/org/springframework/ai/chat/observation/ChatModelCompletionObservationHandlerTests.java b/spring-ai-core/src/test/java/org/springframework/ai/chat/observation/ChatModelCompletionObservationHandlerTests.java index 225fcbce50c..8e8218acf78 100644 --- a/spring-ai-core/src/test/java/org/springframework/ai/chat/observation/ChatModelCompletionObservationHandlerTests.java +++ b/spring-ai-core/src/test/java/org/springframework/ai/chat/observation/ChatModelCompletionObservationHandlerTests.java @@ -29,7 +29,7 @@ import org.springframework.ai.chat.messages.AssistantMessage; import org.springframework.ai.chat.model.ChatResponse; import org.springframework.ai.chat.model.Generation; -import org.springframework.ai.chat.prompt.ChatOptionsBuilder; +import org.springframework.ai.chat.prompt.ChatOptions; import org.springframework.ai.chat.prompt.Prompt; import org.springframework.ai.observation.conventions.AiObservationAttributes; import org.springframework.ai.observation.conventions.AiObservationEventNames; @@ -49,7 +49,7 @@ void whenCompletionWithTextThenSpanEvent() { var observationContext = ChatModelObservationContext.builder() .prompt(new Prompt("supercalifragilisticexpialidocious")) .provider("mary-poppins") - .requestOptions(ChatOptionsBuilder.builder().withModel("spoonful-of-sugar").build()) + .requestOptions(ChatOptions.builder().model("spoonful-of-sugar").build()) .build(); observationContext.setResponse(new ChatResponse(List.of(new Generation(new AssistantMessage("say please")), new Generation(new AssistantMessage("seriously, say please"))))); diff --git a/spring-ai-core/src/test/java/org/springframework/ai/chat/observation/ChatModelMeterObservationHandlerTests.java b/spring-ai-core/src/test/java/org/springframework/ai/chat/observation/ChatModelMeterObservationHandlerTests.java index cf097fc89f5..ceea12131f2 100644 --- a/spring-ai-core/src/test/java/org/springframework/ai/chat/observation/ChatModelMeterObservationHandlerTests.java +++ b/spring-ai-core/src/test/java/org/springframework/ai/chat/observation/ChatModelMeterObservationHandlerTests.java @@ -30,7 +30,7 @@ import 
org.springframework.ai.chat.metadata.Usage; import org.springframework.ai.chat.model.ChatResponse; import org.springframework.ai.chat.model.Generation; -import org.springframework.ai.chat.prompt.ChatOptionsBuilder; +import org.springframework.ai.chat.prompt.ChatOptions; import org.springframework.ai.chat.prompt.Prompt; import org.springframework.ai.observation.conventions.AiObservationMetricAttributes; import org.springframework.ai.observation.conventions.AiObservationMetricNames; @@ -94,7 +94,7 @@ private ChatModelObservationContext generateObservationContext() { return ChatModelObservationContext.builder() .prompt(generatePrompt()) .provider("superprovider") - .requestOptions(ChatOptionsBuilder.builder().withModel("mistral").build()) + .requestOptions(ChatOptions.builder().model("mistral").build()) .build(); } diff --git a/spring-ai-core/src/test/java/org/springframework/ai/chat/observation/ChatModelObservationContextTests.java b/spring-ai-core/src/test/java/org/springframework/ai/chat/observation/ChatModelObservationContextTests.java index a7c62a462b9..3e52f6a9fa9 100644 --- a/spring-ai-core/src/test/java/org/springframework/ai/chat/observation/ChatModelObservationContextTests.java +++ b/spring-ai-core/src/test/java/org/springframework/ai/chat/observation/ChatModelObservationContextTests.java @@ -18,7 +18,7 @@ import org.junit.jupiter.api.Test; -import org.springframework.ai.chat.prompt.ChatOptionsBuilder; +import org.springframework.ai.chat.prompt.ChatOptions; import org.springframework.ai.chat.prompt.Prompt; import static org.assertj.core.api.Assertions.assertThat; @@ -36,7 +36,7 @@ void whenMandatoryRequestOptionsThenReturn() { var observationContext = ChatModelObservationContext.builder() .prompt(generatePrompt()) .provider("superprovider") - .requestOptions(ChatOptionsBuilder.builder().withModel("supermodel").build()) + .requestOptions(ChatOptions.builder().model("supermodel").build()) .build(); assertThat(observationContext).isNotNull(); diff --git a/spring-ai-core/src/test/java/org/springframework/ai/chat/observation/ChatModelPromptContentObservationFilterTests.java b/spring-ai-core/src/test/java/org/springframework/ai/chat/observation/ChatModelPromptContentObservationFilterTests.java index 8e33c73e0c1..c05dd3ef9aa 100644 --- a/spring-ai-core/src/test/java/org/springframework/ai/chat/observation/ChatModelPromptContentObservationFilterTests.java +++ b/spring-ai-core/src/test/java/org/springframework/ai/chat/observation/ChatModelPromptContentObservationFilterTests.java @@ -24,7 +24,7 @@ import org.springframework.ai.chat.messages.SystemMessage; import org.springframework.ai.chat.messages.UserMessage; -import org.springframework.ai.chat.prompt.ChatOptionsBuilder; +import org.springframework.ai.chat.prompt.ChatOptions; import org.springframework.ai.chat.prompt.Prompt; import static org.assertj.core.api.Assertions.assertThat; @@ -52,7 +52,7 @@ void whenEmptyPromptThenReturnOriginalContext() { var expectedContext = ChatModelObservationContext.builder() .prompt(new Prompt(List.of())) .provider("superprovider") - .requestOptions(ChatOptionsBuilder.builder().withModel("mistral").build()) + .requestOptions(ChatOptions.builder().model("mistral").build()) .build(); var actualContext = this.observationFilter.map(expectedContext); @@ -64,7 +64,7 @@ void whenPromptWithTextThenAugmentContext() { var originalContext = ChatModelObservationContext.builder() .prompt(new Prompt("supercalifragilisticexpialidocious")) .provider("superprovider") - 
.requestOptions(ChatOptionsBuilder.builder().withModel("mistral").build()) + .requestOptions(ChatOptions.builder().model("mistral").build()) .build(); var augmentedContext = this.observationFilter.map(originalContext); @@ -78,7 +78,7 @@ void whenPromptWithMessagesThenAugmentContext() { .prompt(new Prompt(List.of(new SystemMessage("you're a chimney sweep"), new UserMessage("supercalifragilisticexpialidocious")))) .provider("superprovider") - .requestOptions(ChatOptionsBuilder.builder().withModel("mistral").build()) + .requestOptions(ChatOptions.builder().model("mistral").build()) .build(); var augmentedContext = this.observationFilter.map(originalContext); diff --git a/spring-ai-core/src/test/java/org/springframework/ai/chat/observation/ChatModelPromptContentObservationHandlerTests.java b/spring-ai-core/src/test/java/org/springframework/ai/chat/observation/ChatModelPromptContentObservationHandlerTests.java index 4064d9570da..ab90a855100 100644 --- a/spring-ai-core/src/test/java/org/springframework/ai/chat/observation/ChatModelPromptContentObservationHandlerTests.java +++ b/spring-ai-core/src/test/java/org/springframework/ai/chat/observation/ChatModelPromptContentObservationHandlerTests.java @@ -24,7 +24,7 @@ import io.opentelemetry.sdk.trace.SdkTracerProvider; import org.junit.jupiter.api.Test; -import org.springframework.ai.chat.prompt.ChatOptionsBuilder; +import org.springframework.ai.chat.prompt.ChatOptions; import org.springframework.ai.chat.prompt.Prompt; import org.springframework.ai.observation.conventions.AiObservationAttributes; import org.springframework.ai.observation.conventions.AiObservationEventNames; @@ -44,7 +44,7 @@ void whenPromptWithTextThenSpanEvent() { var observationContext = ChatModelObservationContext.builder() .prompt(new Prompt("supercalifragilisticexpialidocious")) .provider("mary-poppins") - .requestOptions(ChatOptionsBuilder.builder().withModel("spoonful-of-sugar").build()) + .requestOptions(ChatOptions.builder().model("spoonful-of-sugar").build()) .build(); var sdkTracer = SdkTracerProvider.builder().build().get("test"); var otelTracer = new OtelTracer(sdkTracer, new OtelCurrentTraceContext(), null); diff --git a/spring-ai-core/src/test/java/org/springframework/ai/chat/observation/DefaultChatModelObservationConventionTests.java b/spring-ai-core/src/test/java/org/springframework/ai/chat/observation/DefaultChatModelObservationConventionTests.java index 8c371bfd2e2..dd4840d07e2 100644 --- a/spring-ai-core/src/test/java/org/springframework/ai/chat/observation/DefaultChatModelObservationConventionTests.java +++ b/spring-ai-core/src/test/java/org/springframework/ai/chat/observation/DefaultChatModelObservationConventionTests.java @@ -28,7 +28,7 @@ import org.springframework.ai.chat.metadata.Usage; import org.springframework.ai.chat.model.ChatResponse; import org.springframework.ai.chat.model.Generation; -import org.springframework.ai.chat.prompt.ChatOptionsBuilder; +import org.springframework.ai.chat.prompt.ChatOptions; import org.springframework.ai.chat.prompt.Prompt; import static org.assertj.core.api.Assertions.assertThat; @@ -54,7 +54,7 @@ void contextualNameWhenModelIsDefined() { ChatModelObservationContext observationContext = ChatModelObservationContext.builder() .prompt(generatePrompt()) .provider("superprovider") - .requestOptions(ChatOptionsBuilder.builder().withModel("mistral").build()) + .requestOptions(ChatOptions.builder().model("mistral").build()) .build(); assertThat(this.observationConvention.getContextualName(observationContext)).isEqualTo("chat 
mistral"); } @@ -64,7 +64,7 @@ void contextualNameWhenModelIsNotDefined() { ChatModelObservationContext observationContext = ChatModelObservationContext.builder() .prompt(generatePrompt()) .provider("superprovider") - .requestOptions(ChatOptionsBuilder.builder().build()) + .requestOptions(ChatOptions.builder().build()) .build(); assertThat(this.observationConvention.getContextualName(observationContext)).isEqualTo("chat"); } @@ -74,7 +74,7 @@ void supportsOnlyChatModelObservationContext() { ChatModelObservationContext observationContext = ChatModelObservationContext.builder() .prompt(generatePrompt()) .provider("superprovider") - .requestOptions(ChatOptionsBuilder.builder().withModel("mistral").build()) + .requestOptions(ChatOptions.builder().model("mistral").build()) .build(); assertThat(this.observationConvention.supportsContext(observationContext)).isTrue(); assertThat(this.observationConvention.supportsContext(new Observation.Context())).isFalse(); @@ -85,7 +85,7 @@ void shouldHaveLowCardinalityKeyValuesWhenDefined() { ChatModelObservationContext observationContext = ChatModelObservationContext.builder() .prompt(generatePrompt()) .provider("superprovider") - .requestOptions(ChatOptionsBuilder.builder().withModel("mistral").build()) + .requestOptions(ChatOptions.builder().model("mistral").build()) .build(); assertThat(this.observationConvention.getLowCardinalityKeyValues(observationContext)).contains( KeyValue.of(LowCardinalityKeyNames.AI_OPERATION_TYPE.asString(), "chat"), @@ -98,15 +98,15 @@ void shouldHaveKeyValuesWhenDefinedAndResponse() { ChatModelObservationContext observationContext = ChatModelObservationContext.builder() .prompt(generatePrompt()) .provider("superprovider") - .requestOptions(ChatOptionsBuilder.builder() - .withModel("mistral") - .withFrequencyPenalty(0.8) - .withMaxTokens(200) - .withPresencePenalty(1.0) - .withStopSequences(List.of("addio", "bye")) - .withTemperature(0.5) - .withTopK(1) - .withTopP(0.9) + .requestOptions(ChatOptions.builder() + .model("mistral") + .frequencyPenalty(0.8) + .maxTokens(200) + .presencePenalty(1.0) + .stopSequences(List.of("addio", "bye")) + .temperature(0.5) + .topK(1) + .topP(0.9) .build()) .build(); observationContext.setResponse(new ChatResponse( @@ -139,7 +139,7 @@ void shouldNotHaveKeyValuesWhenMissing() { ChatModelObservationContext observationContext = ChatModelObservationContext.builder() .prompt(generatePrompt()) .provider("superprovider") - .requestOptions(ChatOptionsBuilder.builder().build()) + .requestOptions(ChatOptions.builder().build()) .build(); assertThat(this.observationConvention.getLowCardinalityKeyValues(observationContext)) .contains(KeyValue.of(LowCardinalityKeyNames.REQUEST_MODEL.asString(), KeyValue.NONE_VALUE)) @@ -165,7 +165,7 @@ void shouldNotHaveKeyValuesWhenEmptyValues() { ChatModelObservationContext observationContext = ChatModelObservationContext.builder() .prompt(generatePrompt()) .provider("superprovider") - .requestOptions(ChatOptionsBuilder.builder().withStopSequences(List.of()).build()) + .requestOptions(ChatOptions.builder().stopSequences(List.of()).build()) .build(); observationContext.setResponse(new ChatResponse( List.of(new Generation(new AssistantMessage("response"), diff --git a/spring-ai-core/src/test/java/org/springframework/ai/prompt/PromptTemplateTest.java b/spring-ai-core/src/test/java/org/springframework/ai/prompt/PromptTemplateTest.java index 57c0a3c76fb..0ab44868dd2 100644 --- a/spring-ai-core/src/test/java/org/springframework/ai/prompt/PromptTemplateTest.java +++ 
b/spring-ai-core/src/test/java/org/springframework/ai/prompt/PromptTemplateTest.java @@ -29,7 +29,6 @@ import org.springframework.ai.chat.messages.Message; import org.springframework.ai.chat.prompt.ChatOptions; -import org.springframework.ai.chat.prompt.ChatOptionsBuilder; import org.springframework.ai.chat.prompt.Prompt; import org.springframework.ai.chat.prompt.PromptTemplate; import org.springframework.core.io.InputStreamResource; @@ -58,7 +57,7 @@ private static void assertEqualsWithNormalizedEOLs(String expected, String actua public void testCreateWithEmptyModelAndChatOptions() { String template = "This is a test prompt with no variables"; PromptTemplate promptTemplate = new PromptTemplate(template); - ChatOptions chatOptions = ChatOptionsBuilder.builder().withTemperature(0.7).withTopK(3).build(); + ChatOptions chatOptions = ChatOptions.builder().temperature(0.7).topK(3).build(); Prompt prompt = promptTemplate.create(chatOptions); @@ -74,7 +73,7 @@ public void testCreateWithModelAndChatOptions() { model.put("name", "Alice"); model.put("age", 30); PromptTemplate promptTemplate = new PromptTemplate(template, model); - ChatOptions chatOptions = ChatOptionsBuilder.builder().withTemperature(0.5).withMaxTokens(100).build(); + ChatOptions chatOptions = ChatOptions.builder().temperature(0.5).maxTokens(100).build(); Prompt prompt = promptTemplate.create(model, chatOptions); @@ -93,7 +92,7 @@ public void testCreateWithOverriddenModelAndChatOptions() { Map overriddenModel = new HashMap<>(); overriddenModel.put("color", "red"); - ChatOptions chatOptions = ChatOptionsBuilder.builder().withTemperature(0.8).build(); + ChatOptions chatOptions = ChatOptions.builder().temperature(0.8).build(); Prompt prompt = promptTemplate.create(overriddenModel, chatOptions); diff --git a/spring-ai-core/src/test/java/org/springframework/ai/prompt/PromptTests.java b/spring-ai-core/src/test/java/org/springframework/ai/prompt/PromptTests.java index 1fecc7d39b5..0ef04415c04 100644 --- a/spring-ai-core/src/test/java/org/springframework/ai/prompt/PromptTests.java +++ b/spring-ai-core/src/test/java/org/springframework/ai/prompt/PromptTests.java @@ -24,7 +24,6 @@ import org.junit.jupiter.api.Test; import org.springframework.ai.chat.prompt.ChatOptions; -import org.springframework.ai.chat.prompt.ChatOptionsBuilder; import org.springframework.ai.chat.prompt.Prompt; import org.springframework.ai.chat.prompt.PromptTemplate; import org.springframework.ai.chat.prompt.SystemPromptTemplate; @@ -139,7 +138,7 @@ public void testPromptCopy() { model.put("name", "Alice"); model.put("age", 30); PromptTemplate promptTemplate = new PromptTemplate(template, model); - ChatOptions chatOptions = ChatOptionsBuilder.builder().withTemperature(0.5).withMaxTokens(100).build(); + ChatOptions chatOptions = ChatOptions.builder().temperature(0.5).maxTokens(100).build(); Prompt prompt = promptTemplate.create(model, chatOptions); From 1d9a32da543b17f365e46de40e27b75c2dd860b6 Mon Sep 17 00:00:00 2001 From: Ilayaperumal Gopinathan Date: Wed, 11 Dec 2024 16:04:14 +0000 Subject: [PATCH 2/2] Add javadoc for the deprecated Builder --- .../org/springframework/ai/chat/prompt/ChatOptionsBuilder.java | 3 +++ 1 file changed, 3 insertions(+) diff --git a/spring-ai-core/src/main/java/org/springframework/ai/chat/prompt/ChatOptionsBuilder.java b/spring-ai-core/src/main/java/org/springframework/ai/chat/prompt/ChatOptionsBuilder.java index a35bd3dc370..221b6abe31b 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/chat/prompt/ChatOptionsBuilder.java +++ 
b/spring-ai-core/src/main/java/org/springframework/ai/chat/prompt/ChatOptionsBuilder.java @@ -18,6 +18,9 @@ import java.util.List; +/** + * @deprecated Use {@link ChatOptions.Builder} instead. + */ @Deprecated(forRemoval = true, since = "1.0.0-M5") public final class ChatOptionsBuilder {
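For reference, a minimal usage sketch of the builder API this series introduces, based only on the methods visible in the diff above (the class name, model id, and option values below are purely illustrative, not part of the patch):

import java.util.List;

import org.springframework.ai.chat.prompt.ChatOptions;
import org.springframework.ai.chat.prompt.Prompt;

class ChatOptionsBuilderUsageSketch {

	void sketch() {
		// Before this series (ChatOptionsBuilder is now deprecated for removal):
		// ChatOptions options = ChatOptionsBuilder.builder().withTemperature(0.7).withTopK(40).build();

		// After: the portable builder obtained via ChatOptions.builder(),
		// backed by DefaultChatOptionsBuilder and producing a DefaultChatOptions.
		ChatOptions options = ChatOptions.builder()
			.model("mistral")                  // illustrative model id
			.temperature(0.7)
			.topK(40)
			.topP(0.9)
			.maxTokens(200)
			.stopSequences(List.of("bye"))
			.build();

		// The options plug into a Prompt, as in the updated tests.
		Prompt prompt = new Prompt("Tell me a joke?", options);

		// copy() returns an equivalent ChatOptions built through the same builder.
		ChatOptions copy = options.copy();
	}

}

The deprecated withXxx(...) chain still compiles until ChatOptionsBuilder is removed, but per the added javadoc new code should obtain the builder through ChatOptions.builder().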