Skip to content
Closed
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,6 @@
import org.springframework.ai.chat.observation.ChatModelObservationDocumentation;
import org.springframework.ai.chat.observation.DefaultChatModelObservationConvention;
import org.springframework.ai.chat.prompt.ChatOptions;
import org.springframework.ai.chat.prompt.ChatOptionsBuilder;
import org.springframework.ai.chat.prompt.Prompt;
import org.springframework.ai.model.Media;
import org.springframework.ai.model.ModelOptionsUtils;
Expand Down Expand Up @@ -471,13 +470,13 @@ private List<AnthropicApi.Tool> getFunctionTools(Set<String> functionNames) {
}

private ChatOptions buildRequestOptions(AnthropicApi.ChatCompletionRequest request) {
return ChatOptionsBuilder.builder()
.withModel(request.model())
.withMaxTokens(request.maxTokens())
.withStopSequences(request.stopSequences())
.withTemperature(request.temperature())
.withTopK(request.topK())
.withTopP(request.topP())
return ChatOptions.builder()
.model(request.model())
.maxTokens(request.maxTokens())
.stopSequences(request.stopSequences())
.temperature(request.temperature())
.topK(request.topK())
.topP(request.topP())
.build();
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,6 @@
import org.springframework.ai.chat.observation.ChatModelObservationDocumentation;
import org.springframework.ai.chat.observation.DefaultChatModelObservationConvention;
import org.springframework.ai.chat.prompt.ChatOptions;
import org.springframework.ai.chat.prompt.ChatOptionsBuilder;
import org.springframework.ai.chat.prompt.Prompt;
import org.springframework.ai.model.Media;
import org.springframework.ai.model.ModelOptionsUtils;
Expand Down Expand Up @@ -219,13 +218,13 @@ private ChatResponse internalCall(Prompt prompt, ChatResponse perviousChatRespon
}

private ChatOptions buildRequestOptions(ConverseRequest request) {
return ChatOptionsBuilder.builder()
.withModel(request.modelId())
.withMaxTokens(request.inferenceConfig().maxTokens())
.withStopSequences(request.inferenceConfig().stopSequences())
.withTemperature(request.inferenceConfig().temperature() != null
return ChatOptions.builder()
.model(request.modelId())
.maxTokens(request.inferenceConfig().maxTokens())
.stopSequences(request.inferenceConfig().stopSequences())
.temperature(request.inferenceConfig().temperature() != null
? request.inferenceConfig().temperature().doubleValue() : null)
.withTopP(request.inferenceConfig().topP() != null ? request.inferenceConfig().topP().doubleValue() : null)
.topP(request.inferenceConfig().topP() != null ? request.inferenceConfig().topP().doubleValue() : null)
.build();
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
import software.amazon.awssdk.regions.Region;

import org.springframework.ai.bedrock.converse.BedrockProxyChatModel;
import org.springframework.ai.chat.prompt.ChatOptionsBuilder;
import org.springframework.ai.chat.prompt.ChatOptions;
import org.springframework.ai.chat.prompt.Prompt;

/**
Expand All @@ -40,7 +40,7 @@ public static void main(String[] args) {

// String modelId = "anthropic.claude-3-5-sonnet-20240620-v1:0";
String modelId = "ai21.jamba-1-5-large-v1:0";
var prompt = new Prompt("Tell me a joke?", ChatOptionsBuilder.builder().withModel(modelId).build());
var prompt = new Prompt("Tell me a joke?", ChatOptions.builder().model(modelId).build());

var chatModel = BedrockProxyChatModel.builder()
.withCredentialsProvider(EnvironmentVariableCredentialsProvider.create())
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,8 +27,6 @@
import org.springframework.ai.chat.model.ChatModel;
import org.springframework.ai.chat.model.ChatResponse;
import org.springframework.ai.chat.model.Generation;
import org.springframework.ai.chat.prompt.ChatOptions;
import org.springframework.ai.chat.prompt.ChatOptionsBuilder;
import org.springframework.ai.chat.prompt.Prompt;
import org.springframework.ai.huggingface.api.TextGenerationInferenceApi;
import org.springframework.ai.huggingface.invoker.ApiClient;
Expand Down Expand Up @@ -128,9 +126,4 @@ public void setMaxNewTokens(int maxNewTokens) {
this.maxNewTokens = maxNewTokens;
}

@Override
public ChatOptions getDefaultOptions() {
return ChatOptionsBuilder.builder().build();
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,6 @@
import org.springframework.ai.chat.observation.ChatModelObservationDocumentation;
import org.springframework.ai.chat.observation.DefaultChatModelObservationConvention;
import org.springframework.ai.chat.prompt.ChatOptions;
import org.springframework.ai.chat.prompt.ChatOptionsBuilder;
import org.springframework.ai.chat.prompt.Prompt;
import org.springframework.ai.minimax.api.MiniMaxApi;
import org.springframework.ai.minimax.api.MiniMaxApi.ChatCompletion;
Expand Down Expand Up @@ -374,14 +373,14 @@ protected boolean isToolCall(Generation generation, Set<String> toolCallFinishRe
}

private ChatOptions buildRequestOptions(ChatCompletionRequest request) {
return ChatOptionsBuilder.builder()
.withModel(request.model())
.withFrequencyPenalty(request.frequencyPenalty())
.withMaxTokens(request.maxTokens())
.withPresencePenalty(request.presencePenalty())
.withStopSequences(request.stop())
.withTemperature(request.temperature())
.withTopP(request.topP())
return ChatOptions.builder()
.model(request.model())
.frequencyPenalty(request.frequencyPenalty())
.maxTokens(request.maxTokens())
.presencePenalty(request.presencePenalty())
.stopSequences(request.stop())
.temperature(request.temperature())
.topP(request.topP())
.build();
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,6 @@
import org.springframework.ai.chat.observation.ChatModelObservationDocumentation;
import org.springframework.ai.chat.observation.DefaultChatModelObservationConvention;
import org.springframework.ai.chat.prompt.ChatOptions;
import org.springframework.ai.chat.prompt.ChatOptionsBuilder;
import org.springframework.ai.chat.prompt.Prompt;
import org.springframework.ai.mistralai.api.MistralAiApi;
import org.springframework.ai.mistralai.api.MistralAiApi.ChatCompletion;
Expand Down Expand Up @@ -406,12 +405,12 @@ private List<MistralAiApi.FunctionTool> getFunctionTools(Set<String> functionNam
}

private ChatOptions buildRequestOptions(MistralAiApi.ChatCompletionRequest request) {
return ChatOptionsBuilder.builder()
.withModel(request.model())
.withMaxTokens(request.maxTokens())
.withStopSequences(request.stop())
.withTemperature(request.temperature())
.withTopP(request.topP())
return ChatOptions.builder()
.model(request.model())
.maxTokens(request.maxTokens())
.stopSequences(request.stop())
.temperature(request.temperature())
.topP(request.topP())
.build();
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,6 @@
import org.springframework.ai.chat.observation.ChatModelObservationDocumentation;
import org.springframework.ai.chat.observation.DefaultChatModelObservationConvention;
import org.springframework.ai.chat.prompt.ChatOptions;
import org.springframework.ai.chat.prompt.ChatOptionsBuilder;
import org.springframework.ai.chat.prompt.Prompt;
import org.springframework.ai.model.ModelOptionsUtils;
import org.springframework.ai.model.function.FunctionCallback;
Expand Down Expand Up @@ -420,14 +419,14 @@ else if (message.getMessageType() == MessageType.TOOL) {
}

private ChatOptions buildRequestOptions(MoonshotApi.ChatCompletionRequest request) {
return ChatOptionsBuilder.builder()
.withModel(request.model())
.withFrequencyPenalty(request.frequencyPenalty())
.withMaxTokens(request.maxTokens())
.withPresencePenalty(request.presencePenalty())
.withStopSequences(request.stop())
.withTemperature(request.temperature())
.withTopP(request.topP())
return ChatOptions.builder()
.model(request.model())
.frequencyPenalty(request.frequencyPenalty())
.maxTokens(request.maxTokens())
.presencePenalty(request.presencePenalty())
.stopSequences(request.stop())
.temperature(request.temperature())
.topP(request.topP())
.build();
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,6 @@
import org.springframework.ai.chat.observation.ChatModelObservationDocumentation;
import org.springframework.ai.chat.observation.DefaultChatModelObservationConvention;
import org.springframework.ai.chat.prompt.ChatOptions;
import org.springframework.ai.chat.prompt.ChatOptionsBuilder;
import org.springframework.ai.chat.prompt.Prompt;
import org.springframework.ai.model.ModelOptionsUtils;
import org.springframework.ai.model.function.FunctionCallback;
Expand Down Expand Up @@ -421,15 +420,15 @@ private List<ChatRequest.Tool> getFunctionTools(Set<String> functionNames) {

private ChatOptions buildRequestOptions(OllamaApi.ChatRequest request) {
var options = ModelOptionsUtils.mapToClass(request.options(), OllamaOptions.class);
return ChatOptionsBuilder.builder()
.withModel(request.model())
.withFrequencyPenalty(options.getFrequencyPenalty())
.withMaxTokens(options.getMaxTokens())
.withPresencePenalty(options.getPresencePenalty())
.withStopSequences(options.getStopSequences())
.withTemperature(options.getTemperature())
.withTopK(options.getTopK())
.withTopP(options.getTopP())
return ChatOptions.builder()
.model(request.model())
.frequencyPenalty(options.getFrequencyPenalty())
.maxTokens(options.getMaxTokens())
.presencePenalty(options.getPresencePenalty())
.stopSequences(options.getStopSequences())
.temperature(options.getTemperature())
.topK(options.getTopK())
.topP(options.getTopP())
.build();
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@
import org.springframework.ai.chat.metadata.Usage;
import org.springframework.ai.chat.model.ChatResponse;
import org.springframework.ai.chat.model.Generation;
import org.springframework.ai.chat.prompt.ChatOptionsBuilder;
import org.springframework.ai.chat.prompt.ChatOptions;
import org.springframework.ai.chat.prompt.Prompt;
import org.springframework.ai.chat.prompt.PromptTemplate;
import org.springframework.ai.chat.prompt.SystemPromptTemplate;
Expand Down Expand Up @@ -91,7 +91,7 @@ void roleTest() {
UserMessage userMessage = new UserMessage("Tell me about 5 famous pirates from the Golden Age of Piracy.");

// portable/generic options
var portableOptions = ChatOptionsBuilder.builder().withTemperature(0.7).build();
var portableOptions = ChatOptions.builder().temperature(0.7).build();

Prompt prompt = new Prompt(List.of(systemMessage, userMessage), portableOptions);

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,6 @@
import org.junit.jupiter.api.Test;

import org.springframework.ai.chat.prompt.ChatOptions;
import org.springframework.ai.chat.prompt.ChatOptionsBuilder;
import org.springframework.ai.chat.prompt.Prompt;
import org.springframework.ai.ollama.api.OllamaApi;
import org.springframework.ai.ollama.api.OllamaOptions;
Expand Down Expand Up @@ -78,11 +77,7 @@ public void createRequestWithPromptOllamaOptions() {
public void createRequestWithPromptPortableChatOptions() {

// Ollama runtime options.
ChatOptions portablePromptOptions = ChatOptionsBuilder.builder()
.withTemperature(0.9)
.withTopK(100)
.withTopP(0.6)
.build();
ChatOptions portablePromptOptions = ChatOptions.builder().temperature(0.9).topK(100).topP(0.6).build();

var request = this.chatModel.ollamaChatRequest(new Prompt("Test message content", portablePromptOptions), true);

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,6 @@
import org.springframework.ai.chat.observation.ChatModelObservationDocumentation;
import org.springframework.ai.chat.observation.DefaultChatModelObservationConvention;
import org.springframework.ai.chat.prompt.ChatOptions;
import org.springframework.ai.chat.prompt.ChatOptionsBuilder;
import org.springframework.ai.chat.prompt.Prompt;
import org.springframework.ai.model.Media;
import org.springframework.ai.model.ModelOptionsUtils;
Expand Down Expand Up @@ -595,14 +594,14 @@ private List<OpenAiApi.FunctionTool> getFunctionTools(Set<String> functionNames)
}

private ChatOptions buildRequestOptions(OpenAiApi.ChatCompletionRequest request) {
return ChatOptionsBuilder.builder()
.withModel(request.model())
.withFrequencyPenalty(request.frequencyPenalty())
.withMaxTokens(request.maxTokens())
.withPresencePenalty(request.presencePenalty())
.withStopSequences(request.stop())
.withTemperature(request.temperature())
.withTopP(request.topP())
return ChatOptions.builder()
.model(request.model())
.frequencyPenalty(request.frequencyPenalty())
.maxTokens(request.maxTokens())
.presencePenalty(request.presencePenalty())
.stopSequences(request.stop())
.temperature(request.temperature())
.topP(request.topP())
.build();
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@
import org.springframework.ai.chat.metadata.Usage;
import org.springframework.ai.chat.model.ChatResponse;
import org.springframework.ai.chat.model.Generation;
import org.springframework.ai.chat.prompt.ChatOptionsBuilder;
import org.springframework.ai.chat.prompt.ChatOptions;
import org.springframework.ai.chat.prompt.Prompt;
import org.springframework.ai.chat.prompt.PromptTemplate;
import org.springframework.ai.chat.prompt.SystemPromptTemplate;
Expand Down Expand Up @@ -518,7 +518,7 @@ void multiModalityInputAudio(String modelName) {
List.of(new Media(MimeTypeUtils.parseMimeType("audio/mp3"), audioResource)));

ChatResponse response = chatModel
.call(new Prompt(List.of(userMessage), ChatOptionsBuilder.builder().withModel(modelName).build()));
.call(new Prompt(List.of(userMessage), ChatOptions.builder().model(modelName).build()));

logger.info(response.getResult().getOutput().getText());
assertThat(response.getResult().getOutput().getText()).containsIgnoringCase("hobbits");
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,6 @@
import org.springframework.ai.chat.observation.ChatModelObservationDocumentation;
import org.springframework.ai.chat.observation.DefaultChatModelObservationConvention;
import org.springframework.ai.chat.prompt.ChatOptions;
import org.springframework.ai.chat.prompt.ChatOptionsBuilder;
import org.springframework.ai.chat.prompt.Prompt;
import org.springframework.ai.model.ModelOptionsUtils;
import org.springframework.ai.qianfan.api.QianFanApi;
Expand Down Expand Up @@ -278,14 +277,14 @@ public ChatOptions getDefaultOptions() {
}

private ChatOptions buildRequestOptions(QianFanApi.ChatCompletionRequest request) {
return ChatOptionsBuilder.builder()
.withModel(request.model())
.withFrequencyPenalty(request.frequencyPenalty())
.withMaxTokens(request.maxTokens())
.withPresencePenalty(request.presencePenalty())
.withStopSequences(request.stop())
.withTemperature(request.temperature())
.withTopP(request.topP())
return ChatOptions.builder()
.model(request.model())
.frequencyPenalty(request.frequencyPenalty())
.maxTokens(request.maxTokens())
.presencePenalty(request.presencePenalty())
.stopSequences(request.stop())
.temperature(request.temperature())
.topP(request.topP())
.build();
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@
import org.springframework.ai.chat.metadata.ChatGenerationMetadata;
import org.springframework.ai.chat.model.ChatResponse;
import org.springframework.ai.chat.model.Generation;
import org.springframework.ai.chat.prompt.ChatOptionsBuilder;
import org.springframework.ai.chat.prompt.ChatOptions;
import org.springframework.ai.chat.prompt.Prompt;
import org.springframework.ai.watsonx.api.WatsonxAiApi;
import org.springframework.ai.watsonx.api.WatsonxAiChatRequest;
Expand All @@ -54,7 +54,7 @@ public class WatsonxAiChatModelTest {

@Test
public void testCreateRequestWithNoModelId() {
var options = ChatOptionsBuilder.builder().withTemperature(0.9).withTopK(100).withTopP(0.6).build();
var options = ChatOptions.builder().temperature(0.9).topK(100).topP(0.6).build();

Prompt prompt = new Prompt("Test message", options);

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,6 @@
import org.springframework.ai.chat.observation.ChatModelObservationDocumentation;
import org.springframework.ai.chat.observation.DefaultChatModelObservationConvention;
import org.springframework.ai.chat.prompt.ChatOptions;
import org.springframework.ai.chat.prompt.ChatOptionsBuilder;
import org.springframework.ai.chat.prompt.Prompt;
import org.springframework.ai.model.ModelOptionsUtils;
import org.springframework.ai.model.function.FunctionCallback;
Expand Down Expand Up @@ -459,12 +458,12 @@ else if (mediaContentData instanceof String text) {
}

private ChatOptions buildRequestOptions(ZhiPuAiApi.ChatCompletionRequest request) {
return ChatOptionsBuilder.builder()
.withModel(request.model())
.withMaxTokens(request.maxTokens())
.withStopSequences(request.stop())
.withTemperature(request.temperature())
.withTopP(request.topP())
return ChatOptions.builder()
.model(request.model())
.maxTokens(request.maxTokens())
.stopSequences(request.stop())
.temperature(request.temperature())
.topP(request.topP())
.build();
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,9 @@ default String call(Message... messages) {
@Override
ChatResponse call(Prompt prompt);

ChatOptions getDefaultOptions();
default ChatOptions getDefaultOptions() {
return ChatOptions.builder().build();
}

default Flux<ChatResponse> stream(Prompt prompt) {
throw new UnsupportedOperationException("streaming is not supported");
Expand Down
Loading
Loading