diff --git a/models/spring-ai-ollama/src/main/java/org/springframework/ai/ollama/OllamaChatModel.java b/models/spring-ai-ollama/src/main/java/org/springframework/ai/ollama/OllamaChatModel.java index 44dc45347b6..e6dd42ab0b6 100644 --- a/models/spring-ai-ollama/src/main/java/org/springframework/ai/ollama/OllamaChatModel.java +++ b/models/spring-ai-ollama/src/main/java/org/springframework/ai/ollama/OllamaChatModel.java @@ -263,6 +263,7 @@ private ChatResponse internalCall(Prompt prompt, ChatResponse previousChatRespon if (ollamaResponse.promptEvalCount() != null && ollamaResponse.evalCount() != null) { generationMetadata = ChatGenerationMetadata.builder() .finishReason(ollamaResponse.doneReason()) + .metadata("thinking", ollamaResponse.message().thinking()) .build(); } @@ -474,7 +475,8 @@ else if (message instanceof ToolResponseMessage toolMessage) { OllamaApi.ChatRequest.Builder requestBuilder = OllamaApi.ChatRequest.builder(requestOptions.getModel()) .stream(stream) .messages(ollamaMessages) - .options(requestOptions); + .options(requestOptions) + .think(requestOptions.isThink()); if (requestOptions.getFormat() != null) { requestBuilder.format(requestOptions.getFormat()); diff --git a/models/spring-ai-ollama/src/main/java/org/springframework/ai/ollama/api/OllamaApi.java b/models/spring-ai-ollama/src/main/java/org/springframework/ai/ollama/api/OllamaApi.java index b481386a479..b221ab5f439 100644 --- a/models/spring-ai-ollama/src/main/java/org/springframework/ai/ollama/api/OllamaApi.java +++ b/models/spring-ai-ollama/src/main/java/org/springframework/ai/ollama/api/OllamaApi.java @@ -52,6 +52,7 @@ * @author Thomas Vitale * @author Jonghoon Park * @author Alexandros Pappas + * @author Sun Yuhan * @since 0.8.0 */ // @formatter:off @@ -258,6 +259,7 @@ public Flux pullModel(PullModelRequest pullModelRequest) { * * @param role The role of the message of type {@link Role}. * @param content The content of the message. + * @param thinking The thinking of the model. * @param images The list of base64-encoded images to send with the message. * Requires multimodal models such as llava or bakllava. * @param toolCalls The relevant tool call. @@ -267,6 +269,7 @@ public Flux pullModel(PullModelRequest pullModelRequest) { public record Message( @JsonProperty("role") Role role, @JsonProperty("content") String content, + @JsonProperty("thinking") String thinking, @JsonProperty("images") List images, @JsonProperty("tool_calls") List toolCalls) { @@ -328,6 +331,7 @@ public static class Builder { private final Role role; private String content; + private String thinking; private List images; private List toolCalls; @@ -340,6 +344,11 @@ public Builder content(String content) { return this; } + public Builder thinking(String thinking) { + this.thinking = thinking; + return this; + } + public Builder images(List images) { this.images = images; return this; @@ -351,7 +360,7 @@ public Builder toolCalls(List toolCalls) { } public Message build() { - return new Message(this.role, this.content, this.images, this.toolCalls); + return new Message(this.role, this.content, this.thinking, this.images, this.toolCalls); } } } @@ -366,6 +375,7 @@ public Message build() { * @param keepAlive Controls how long the model will stay loaded into memory following this request (default: 5m). * @param tools List of tools the model has access to. * @param options Model-specific options. For example, "temperature" can be set through this field, if the model supports it. + * @param think The model should think before responding, if the model supports it. * You can use the {@link OllamaOptions} builder to create the options then {@link OllamaOptions#toMap()} to convert the options into a map. * * @see tools, - @JsonProperty("options") Map options + @JsonProperty("options") Map options, + @JsonProperty("think") Boolean think ) { public static Builder builder(String model) { @@ -455,6 +466,7 @@ public static class Builder { private String keepAlive; private List tools = List.of(); private Map options = Map.of(); + private Boolean think; public Builder(String model) { Assert.notNull(model, "The model can not be null."); @@ -499,8 +511,13 @@ public Builder options(OllamaOptions options) { return this; } + public Builder think(Boolean think) { + this.think = think; + return this; + } + public ChatRequest build() { - return new ChatRequest(this.model, this.messages, this.stream, this.format, this.keepAlive, this.tools, this.options); + return new ChatRequest(this.model, this.messages, this.stream, this.format, this.keepAlive, this.tools, this.options, this.think); } } } diff --git a/models/spring-ai-ollama/src/main/java/org/springframework/ai/ollama/api/OllamaApiHelper.java b/models/spring-ai-ollama/src/main/java/org/springframework/ai/ollama/api/OllamaApiHelper.java index 588b86c5364..b8728874cc9 100644 --- a/models/spring-ai-ollama/src/main/java/org/springframework/ai/ollama/api/OllamaApiHelper.java +++ b/models/spring-ai-ollama/src/main/java/org/springframework/ai/ollama/api/OllamaApiHelper.java @@ -25,6 +25,7 @@ /** * @author Christian Tzolov + * @author Sun Yuhan * @since 1.0.0 */ public final class OllamaApiHelper { @@ -81,12 +82,18 @@ public static ChatResponse merge(ChatResponse previous, ChatResponse current) { private static OllamaApi.Message merge(OllamaApi.Message previous, OllamaApi.Message current) { String content = mergeContent(previous, current); + String thinking = mergeThinking(previous, current); OllamaApi.Message.Role role = (current.role() != null ? current.role() : previous.role()); role = (role != null ? role : OllamaApi.Message.Role.ASSISTANT); List images = mergeImages(previous, current); List toolCalls = mergeToolCall(previous, current); - return OllamaApi.Message.builder(role).content(content).images(images).toolCalls(toolCalls).build(); + return OllamaApi.Message.builder(role) + .content(content) + .thinking(thinking) + .images(images) + .toolCalls(toolCalls) + .build(); } private static Instant merge(Instant previous, Instant current) { @@ -134,6 +141,17 @@ private static String mergeContent(OllamaApi.Message previous, OllamaApi.Message return previous.content() + current.content(); } + private static String mergeThinking(OllamaApi.Message previous, OllamaApi.Message current) { + if (previous == null || previous.thinking() == null) { + return (current != null ? current.thinking() : null); + } + if (current == null || current.thinking() == null) { + return (previous != null ? previous.thinking() : null); + } + + return previous.thinking() + current.thinking(); + } + private static List mergeToolCall(OllamaApi.Message previous, OllamaApi.Message current) { if (previous == null) { diff --git a/models/spring-ai-ollama/src/main/java/org/springframework/ai/ollama/api/OllamaModel.java b/models/spring-ai-ollama/src/main/java/org/springframework/ai/ollama/api/OllamaModel.java index 7602eca2584..aa91a2ec837 100644 --- a/models/spring-ai-ollama/src/main/java/org/springframework/ai/ollama/api/OllamaModel.java +++ b/models/spring-ai-ollama/src/main/java/org/springframework/ai/ollama/api/OllamaModel.java @@ -23,6 +23,7 @@ * * @author Siarhei Blashuk * @author Thomas Vitale + * @author Sun Yuhan * @since 1.0.0 */ public enum OllamaModel implements ChatModelDescription { @@ -32,6 +33,21 @@ public enum OllamaModel implements ChatModelDescription { */ QWEN_2_5_7B("qwen2.5"), + /** + * Qwen3 + */ + QWEN_3_8B("qwen3"), + + /** + * Qwen3 1.7b + */ + QWEN_3_1_7_B("qwen3:1.7b"), + + /** + * Qwen3 0.6b + */ + QWEN_3_06B("qwen3:0.6b"), + /** * QwQ is the reasoning model of the Qwen series. */ diff --git a/models/spring-ai-ollama/src/main/java/org/springframework/ai/ollama/api/OllamaOptions.java b/models/spring-ai-ollama/src/main/java/org/springframework/ai/ollama/api/OllamaOptions.java index a71be1ce2b2..473d8ca7e25 100644 --- a/models/spring-ai-ollama/src/main/java/org/springframework/ai/ollama/api/OllamaOptions.java +++ b/models/spring-ai-ollama/src/main/java/org/springframework/ai/ollama/api/OllamaOptions.java @@ -44,6 +44,7 @@ * @author Christian Tzolov * @author Thomas Vitale * @author Ilayaperumal Gopinathan + * @author Sun Yuhan * @since 0.8.0 * @see Ollama @@ -318,6 +319,14 @@ public class OllamaOptions implements ToolCallingChatOptions, EmbeddingOptions { @JsonProperty("truncate") private Boolean truncate; + /** + * The model should think before responding, if supported. + * If this value is not specified, it defaults to null, and Ollama will return + * the thought process within the `content` field of the response, wrapped in `<thinking>` tags. + */ + @JsonProperty("think") + private Boolean think; + @JsonIgnore private Boolean internalToolExecutionEnabled; @@ -365,6 +374,7 @@ public static OllamaOptions fromOptions(OllamaOptions fromOptions) { .format(fromOptions.getFormat()) .keepAlive(fromOptions.getKeepAlive()) .truncate(fromOptions.getTruncate()) + .think(fromOptions.isThink()) .useNUMA(fromOptions.getUseNUMA()) .numCtx(fromOptions.getNumCtx()) .numBatch(fromOptions.getNumBatch()) @@ -704,6 +714,15 @@ public void setTruncate(Boolean truncate) { this.truncate = truncate; } + @Override + public Boolean isThink() { + return this.think; + } + + public void setThink(Boolean think) { + this.think = think; + } + @Override @JsonIgnore public List getToolCallbacks() { @@ -804,7 +823,8 @@ public boolean equals(Object o) { && Objects.equals(this.repeatPenalty, that.repeatPenalty) && Objects.equals(this.presencePenalty, that.presencePenalty) && Objects.equals(this.frequencyPenalty, that.frequencyPenalty) - && Objects.equals(this.mirostat, that.mirostat) && Objects.equals(this.mirostatTau, that.mirostatTau) + && Objects.equals(this.think, that.think) && Objects.equals(this.mirostat, that.mirostat) + && Objects.equals(this.mirostatTau, that.mirostatTau) && Objects.equals(this.mirostatEta, that.mirostatEta) && Objects.equals(this.penalizeNewline, that.penalizeNewline) && Objects.equals(this.stop, that.stop) && Objects.equals(this.toolCallbacks, that.toolCallbacks) @@ -814,13 +834,13 @@ public boolean equals(Object o) { @Override public int hashCode() { - return Objects.hash(this.model, this.format, this.keepAlive, this.truncate, this.useNUMA, this.numCtx, - this.numBatch, this.numGPU, this.mainGPU, this.lowVRAM, this.f16KV, this.logitsAll, this.vocabOnly, - this.useMMap, this.useMLock, this.numThread, this.numKeep, this.seed, this.numPredict, this.topK, - this.topP, this.minP, this.tfsZ, this.typicalP, this.repeatLastN, this.temperature, this.repeatPenalty, - this.presencePenalty, this.frequencyPenalty, this.mirostat, this.mirostatTau, this.mirostatEta, - this.penalizeNewline, this.stop, this.toolCallbacks, this.toolNames, this.internalToolExecutionEnabled, - this.toolContext); + return Objects.hash(this.model, this.format, this.keepAlive, this.truncate, this.think, this.useNUMA, + this.numCtx, this.numBatch, this.numGPU, this.mainGPU, this.lowVRAM, this.f16KV, this.logitsAll, + this.vocabOnly, this.useMMap, this.useMLock, this.numThread, this.numKeep, this.seed, this.numPredict, + this.topK, this.topP, this.minP, this.tfsZ, this.typicalP, this.repeatLastN, this.temperature, + this.repeatPenalty, this.presencePenalty, this.frequencyPenalty, this.mirostat, this.mirostatTau, + this.mirostatEta, this.penalizeNewline, this.stop, this.toolCallbacks, this.toolNames, + this.internalToolExecutionEnabled, this.toolContext); } public static class Builder { @@ -852,6 +872,11 @@ public Builder truncate(Boolean truncate) { return this; } + public Builder think(Boolean think) { + this.options.think = think; + return this; + } + public Builder useNUMA(Boolean useNUMA) { this.options.useNUMA = useNUMA; return this; diff --git a/models/spring-ai-ollama/src/test/java/org/springframework/ai/ollama/OllamaChatModelMetadataTests.java b/models/spring-ai-ollama/src/test/java/org/springframework/ai/ollama/OllamaChatModelMetadataTests.java new file mode 100644 index 00000000000..9f2d1c44dc6 --- /dev/null +++ b/models/spring-ai-ollama/src/test/java/org/springframework/ai/ollama/OllamaChatModelMetadataTests.java @@ -0,0 +1,127 @@ +/* + * Copyright 2023-2025 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.ai.ollama; + +import io.micrometer.observation.tck.TestObservationRegistry; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; +import org.springframework.ai.chat.metadata.ChatGenerationMetadata; +import org.springframework.ai.chat.model.ChatResponse; +import org.springframework.ai.chat.prompt.Prompt; +import org.springframework.ai.ollama.api.OllamaApi; +import org.springframework.ai.ollama.api.OllamaModel; +import org.springframework.ai.ollama.api.OllamaOptions; +import org.springframework.beans.factory.annotation.Autowired; +import org.springframework.boot.SpringBootConfiguration; +import org.springframework.boot.test.context.SpringBootTest; +import org.springframework.context.annotation.Bean; + +import static org.assertj.core.api.Assertions.assertThat; + +/** + * Unit Tests for {@link OllamaChatModel} asserting AI metadata. + * + * @author Sun Yuhan + */ +@SpringBootTest(classes = OllamaChatModelMetadataTests.Config.class) +class OllamaChatModelMetadataTests extends BaseOllamaIT { + + private static final String MODEL = OllamaModel.QWEN_3_06B.getName(); + + @Autowired + TestObservationRegistry observationRegistry; + + @Autowired + OllamaChatModel chatModel; + + @BeforeEach + void beforeEach() { + this.observationRegistry.clear(); + } + + @Test + void ollamaThinkingMetadataCaptured() { + var options = OllamaOptions.builder().model(MODEL).think(true).build(); + + Prompt prompt = new Prompt("Why is the sky blue?", options); + + ChatResponse chatResponse = this.chatModel.call(prompt); + assertThat(chatResponse.getResult().getOutput().getText()).isNotEmpty(); + + chatResponse.getResults().forEach(generation -> { + ChatGenerationMetadata chatGenerationMetadata = generation.getMetadata(); + assertThat(chatGenerationMetadata).isNotNull(); + var thinking = chatGenerationMetadata.get("thinking"); + assertThat(thinking).isNotNull(); + }); + } + + @Test + void ollamaThinkingMetadataNotCapturedWhenNotSetThinkFlag() { + var options = OllamaOptions.builder().model(MODEL).build(); + + Prompt prompt = new Prompt("Why is the sky blue?", options); + + ChatResponse chatResponse = this.chatModel.call(prompt); + assertThat(chatResponse.getResult().getOutput().getText()).isNotEmpty(); + + chatResponse.getResults().forEach(generation -> { + ChatGenerationMetadata chatGenerationMetadata = generation.getMetadata(); + assertThat(chatGenerationMetadata).isNotNull(); + var thinking = chatGenerationMetadata.get("thinking"); + assertThat(thinking).isNull(); + }); + } + + @Test + void ollamaThinkingMetadataNotCapturedWhenSetThinkFlagToFalse() { + var options = OllamaOptions.builder().model(MODEL).think(false).build(); + + Prompt prompt = new Prompt("Why is the sky blue?", options); + + ChatResponse chatResponse = this.chatModel.call(prompt); + assertThat(chatResponse.getResult().getOutput().getText()).isNotEmpty(); + + chatResponse.getResults().forEach(generation -> { + ChatGenerationMetadata chatGenerationMetadata = generation.getMetadata(); + assertThat(chatGenerationMetadata).isNotNull(); + var thinking = chatGenerationMetadata.get("thinking"); + assertThat(thinking).isNull(); + }); + } + + @SpringBootConfiguration + static class Config { + + @Bean + public TestObservationRegistry observationRegistry() { + return TestObservationRegistry.create(); + } + + @Bean + public OllamaApi ollamaApi() { + return initializeOllama(MODEL); + } + + @Bean + public OllamaChatModel openAiChatModel(OllamaApi ollamaApi, TestObservationRegistry observationRegistry) { + return OllamaChatModel.builder().ollamaApi(ollamaApi).observationRegistry(observationRegistry).build(); + } + + } + +} diff --git a/models/spring-ai-ollama/src/test/java/org/springframework/ai/ollama/OllamaImage.java b/models/spring-ai-ollama/src/test/java/org/springframework/ai/ollama/OllamaImage.java index e82ecb9ab67..146c2c042d5 100644 --- a/models/spring-ai-ollama/src/test/java/org/springframework/ai/ollama/OllamaImage.java +++ b/models/spring-ai-ollama/src/test/java/org/springframework/ai/ollama/OllamaImage.java @@ -23,7 +23,7 @@ */ public final class OllamaImage { - public static final DockerImageName DEFAULT_IMAGE = DockerImageName.parse("ollama/ollama:0.6.7"); + public static final DockerImageName DEFAULT_IMAGE = DockerImageName.parse("ollama/ollama:0.9.0"); private OllamaImage() { diff --git a/models/spring-ai-ollama/src/test/java/org/springframework/ai/ollama/api/OllamaApiIT.java b/models/spring-ai-ollama/src/test/java/org/springframework/ai/ollama/api/OllamaApiIT.java index 98af032efbd..cf2ea4f8afd 100644 --- a/models/spring-ai-ollama/src/test/java/org/springframework/ai/ollama/api/OllamaApiIT.java +++ b/models/spring-ai-ollama/src/test/java/org/springframework/ai/ollama/api/OllamaApiIT.java @@ -33,14 +33,16 @@ import org.springframework.ai.ollama.api.OllamaApi.Message.Role; import static org.assertj.core.api.Assertions.assertThat; +import static org.junit.jupiter.api.Assertions.assertNull; /** * @author Christian Tzolov * @author Thomas Vitale + * @author Sun Yuhan */ public class OllamaApiIT extends BaseOllamaIT { - private static final String MODEL = OllamaModel.LLAMA3_2.getName(); + private static final String MODEL = OllamaModel.QWEN_3_1_7_B.getName(); @BeforeAll public static void beforeAll() throws IOException, InterruptedException { @@ -107,11 +109,67 @@ public void embedText() { assertThat(response).isNotNull(); assertThat(response.embeddings()).hasSize(1); - assertThat(response.embeddings().get(0)).hasSize(3072); + assertThat(response.embeddings().get(0)).hasSize(2048); assertThat(response.model()).isEqualTo(MODEL); assertThat(response.promptEvalCount()).isEqualTo(5); assertThat(response.loadDuration()).isGreaterThan(1); assertThat(response.totalDuration()).isGreaterThan(1); } + @Test + public void streamChatWithThinking() { + var request = ChatRequest.builder(MODEL) + .stream(true) + .think(true) + .messages(List.of(Message.builder(Role.USER).content("What are the planets in the solar system?").build())) + .options(OllamaOptions.builder().temperature(0.9).build().toMap()) + .build(); + + Flux response = getOllamaApi().streamingChat(request); + + List responses = response.collectList().block(); + System.out.println(responses); + + assertThat(responses).isNotNull(); + assertThat(responses.stream() + .filter(r -> r.message() != null) + .map(r -> r.message().thinking()) + .collect(Collectors.joining(System.lineSeparator()))).contains("solar"); + + ChatResponse lastResponse = responses.get(responses.size() - 1); + assertThat(lastResponse.message().content()).isEmpty(); + assertNull(lastResponse.message().thinking()); + assertThat(lastResponse.done()).isTrue(); + } + + @Test + public void streamChatWithoutThinking() { + var request = ChatRequest.builder(MODEL) + .stream(true) + .think(false) + .messages(List.of(Message.builder(Role.USER).content("What are the planets in the solar system?").build())) + .options(OllamaOptions.builder().temperature(0.9).build().toMap()) + .build(); + + Flux response = getOllamaApi().streamingChat(request); + + List responses = response.collectList().block(); + System.out.println(responses); + + assertThat(responses).isNotNull(); + + assertThat(responses.stream() + .filter(r -> r.message() != null) + .map(r -> r.message().content()) + .collect(Collectors.joining(System.lineSeparator()))).contains("Earth"); + + assertThat(responses.stream().filter(r -> r.message() != null).allMatch(r -> r.message().thinking() == null)) + .isTrue(); + + ChatResponse lastResponse = responses.get(responses.size() - 1); + assertThat(lastResponse.message().content()).isEmpty(); + assertNull(lastResponse.message().thinking()); + assertThat(lastResponse.done()).isTrue(); + } + } diff --git a/spring-ai-model/src/main/java/org/springframework/ai/chat/prompt/ChatOptions.java b/spring-ai-model/src/main/java/org/springframework/ai/chat/prompt/ChatOptions.java index 9f051ac0597..1ea4af559b4 100644 --- a/spring-ai-model/src/main/java/org/springframework/ai/chat/prompt/ChatOptions.java +++ b/spring-ai-model/src/main/java/org/springframework/ai/chat/prompt/ChatOptions.java @@ -83,6 +83,15 @@ public interface ChatOptions extends ModelOptions { @Nullable Double getTopP(); + /** + * Returns the think flag to use for the chat. + * @return the think flag to use for the chat + */ + @Nullable + default Boolean isThink() { + return false; + } + /** * Returns a copy of this {@link ChatOptions}. * @return a copy of this {@link ChatOptions} @@ -158,6 +167,13 @@ interface Builder { */ Builder topP(Double topP); + /** + * Builds with the think to use for the chat. + * @param think Whether to enable thinking mode + * @return the builder. + */ + Builder think(Boolean think); + /** * Build the {@link ChatOptions}. * @return the Chat options. diff --git a/spring-ai-model/src/main/java/org/springframework/ai/chat/prompt/DefaultChatOptions.java b/spring-ai-model/src/main/java/org/springframework/ai/chat/prompt/DefaultChatOptions.java index 1af33bf3467..4bef45a2741 100644 --- a/spring-ai-model/src/main/java/org/springframework/ai/chat/prompt/DefaultChatOptions.java +++ b/spring-ai-model/src/main/java/org/springframework/ai/chat/prompt/DefaultChatOptions.java @@ -41,6 +41,8 @@ public class DefaultChatOptions implements ChatOptions { private Double topP; + private Boolean think; + @Override public String getModel() { return this.model; @@ -113,6 +115,15 @@ public void setTopP(Double topP) { this.topP = topP; } + @Override + public Boolean isThink() { + return this.think; + } + + public void setThink(Boolean think) { + this.think = think; + } + @Override @SuppressWarnings("unchecked") public T copy() { @@ -125,6 +136,7 @@ public T copy() { copy.setTemperature(this.getTemperature()); copy.setTopK(this.getTopK()); copy.setTopP(this.getTopP()); + copy.setThink(this.isThink()); return (T) copy; } diff --git a/spring-ai-model/src/main/java/org/springframework/ai/chat/prompt/DefaultChatOptionsBuilder.java b/spring-ai-model/src/main/java/org/springframework/ai/chat/prompt/DefaultChatOptionsBuilder.java index 47ba5840109..a317c8c8106 100644 --- a/spring-ai-model/src/main/java/org/springframework/ai/chat/prompt/DefaultChatOptionsBuilder.java +++ b/spring-ai-model/src/main/java/org/springframework/ai/chat/prompt/DefaultChatOptionsBuilder.java @@ -73,6 +73,11 @@ public DefaultChatOptionsBuilder topP(Double topP) { return this; } + public DefaultChatOptionsBuilder think(Boolean think) { + this.options.setThink(think); + return this; + } + public ChatOptions build() { return this.options.copy(); } diff --git a/spring-ai-model/src/main/java/org/springframework/ai/model/tool/DefaultToolCallingChatOptions.java b/spring-ai-model/src/main/java/org/springframework/ai/model/tool/DefaultToolCallingChatOptions.java index 870db6931b9..7c9c9397fc1 100644 --- a/spring-ai-model/src/main/java/org/springframework/ai/model/tool/DefaultToolCallingChatOptions.java +++ b/spring-ai-model/src/main/java/org/springframework/ai/model/tool/DefaultToolCallingChatOptions.java @@ -70,6 +70,9 @@ public class DefaultToolCallingChatOptions implements ToolCallingChatOptions { @Nullable private Double topP; + @Nullable + private Boolean think; + @Override public List getToolCallbacks() { return List.copyOf(this.toolCallbacks); @@ -198,6 +201,16 @@ public void setTopP(@Nullable Double topP) { this.topP = topP; } + @Override + @Nullable + public Boolean isThink() { + return this.think; + } + + public void setThink(@Nullable Boolean think) { + this.think = think; + } + @Override @SuppressWarnings("unchecked") public T copy() { @@ -325,6 +338,12 @@ public ToolCallingChatOptions.Builder topP(@Nullable Double topP) { return this; } + @Override + public ToolCallingChatOptions.Builder think(Boolean think) { + this.options.setThink(think); + return this; + } + @Override public ToolCallingChatOptions build() { return this.options; diff --git a/spring-ai-model/src/main/java/org/springframework/ai/model/tool/ToolCallingChatOptions.java b/spring-ai-model/src/main/java/org/springframework/ai/model/tool/ToolCallingChatOptions.java index f06e71aa869..9cbdbe80c86 100644 --- a/spring-ai-model/src/main/java/org/springframework/ai/model/tool/ToolCallingChatOptions.java +++ b/spring-ai-model/src/main/java/org/springframework/ai/model/tool/ToolCallingChatOptions.java @@ -219,6 +219,9 @@ interface Builder extends ChatOptions.Builder { @Override Builder topP(@Nullable Double topP); + @Override + Builder think(@Nullable Boolean think); + @Override ToolCallingChatOptions build(); diff --git a/spring-ai-model/src/test/java/org/springframework/ai/chat/prompt/ChatOptionsBuilderTests.java b/spring-ai-model/src/test/java/org/springframework/ai/chat/prompt/ChatOptionsBuilderTests.java index bf8e0e1fd01..566ecc339e3 100644 --- a/spring-ai-model/src/test/java/org/springframework/ai/chat/prompt/ChatOptionsBuilderTests.java +++ b/spring-ai-model/src/test/java/org/springframework/ai/chat/prompt/ChatOptionsBuilderTests.java @@ -53,6 +53,7 @@ void shouldBuildWithAllOptions() { .topP(1.0) .topK(40) .stopSequences(List.of("stop1", "stop2")) + .think(true) .build(); assertThat(options.getModel()).isEqualTo("gpt-4"); @@ -60,6 +61,7 @@ void shouldBuildWithAllOptions() { assertThat(options.getTemperature()).isEqualTo(0.7); assertThat(options.getTopP()).isEqualTo(1.0); assertThat(options.getTopK()).isEqualTo(40); + assertThat(options.isThink()).isEqualTo(true); assertThat(options.getStopSequences()).containsExactly("stop1", "stop2"); } @@ -82,6 +84,7 @@ void shouldCopyOptions() { .temperature(0.7) .topP(1.0) .topK(40) + .think(true) .stopSequences(List.of("stop1", "stop2")) .build(); @@ -107,6 +110,7 @@ void shouldUpcastToChatOptions() { .temperature(0.7) .topP(1.0) .topK(40) + .think(true) .stopSequences(List.of("stop1", "stop2")) .toolNames(Set.of("function1", "function2")) .toolCallbacks(List.of(callback)) @@ -121,6 +125,7 @@ void shouldUpcastToChatOptions() { assertThat(chatOptions.getTemperature()).isEqualTo(0.7); assertThat(chatOptions.getTopP()).isEqualTo(1.0); assertThat(chatOptions.getTopK()).isEqualTo(40); + assertThat(chatOptions.isThink()).isEqualTo(true); assertThat(chatOptions.getStopSequences()).containsExactly("stop1", "stop2"); } diff --git a/spring-ai-model/src/test/java/org/springframework/ai/model/tool/DefaultToolCallingChatOptionsTests.java b/spring-ai-model/src/test/java/org/springframework/ai/model/tool/DefaultToolCallingChatOptionsTests.java index 45557f23a6d..7ce3d57af6a 100644 --- a/spring-ai-model/src/test/java/org/springframework/ai/model/tool/DefaultToolCallingChatOptionsTests.java +++ b/spring-ai-model/src/test/java/org/springframework/ai/model/tool/DefaultToolCallingChatOptionsTests.java @@ -188,6 +188,7 @@ void builderShouldCreateOptionsWithAllProperties() { .stopSequences(List.of("stop")) .topK(3) .topP(0.9) + .think(true) .build(); assertThat(options).satisfies(o -> { @@ -203,6 +204,7 @@ void builderShouldCreateOptionsWithAllProperties() { assertThat(o.getStopSequences()).containsExactly("stop"); assertThat(o.getTopK()).isEqualTo(3); assertThat(o.getTopP()).isEqualTo(0.9); + assertThat(o.isThink()).isEqualTo(true); }); }