From 524f12cec9004014ba2360b73ea13884394cd42d Mon Sep 17 00:00:00 2001 From: Sun Yuhan <1085481446@qq.com> Date: Fri, 30 May 2025 11:24:09 +0800 Subject: [PATCH 1/6] feat: Added support for the "think" field in Ollama 1. Added the `think` field to Ollama's `ChatRequest` 2. Added the `thinking` field to Ollama's `Message` 3. Added the `think` property to `OllamaOptions`, allowing users to specify whether to enable or disable thinking Signed-off-by: Sun Yuhan <1085481446@qq.com> --- .../ai/ollama/OllamaChatModel.java | 3 +- .../ai/ollama/api/OllamaApi.java | 23 +++++++++-- .../ai/ollama/api/OllamaApiHelper.java | 20 +++++++++- .../ai/ollama/api/OllamaModel.java | 11 +++++ .../ai/ollama/api/OllamaOptions.java | 40 +++++++++++++++---- .../ai/ollama/OllamaImage.java | 2 +- .../ai/ollama/api/OllamaApiIT.java | 29 ++++++++++++++ 7 files changed, 114 insertions(+), 14 deletions(-) diff --git a/models/spring-ai-ollama/src/main/java/org/springframework/ai/ollama/OllamaChatModel.java b/models/spring-ai-ollama/src/main/java/org/springframework/ai/ollama/OllamaChatModel.java index a75a274a797..3b15ccc0477 100644 --- a/models/spring-ai-ollama/src/main/java/org/springframework/ai/ollama/OllamaChatModel.java +++ b/models/spring-ai-ollama/src/main/java/org/springframework/ai/ollama/OllamaChatModel.java @@ -460,7 +460,8 @@ else if (message instanceof ToolResponseMessage toolMessage) { OllamaApi.ChatRequest.Builder requestBuilder = OllamaApi.ChatRequest.builder(requestOptions.getModel()) .stream(stream) .messages(ollamaMessages) - .options(requestOptions); + .options(requestOptions) + .think(requestOptions.getThink()); if (requestOptions.getFormat() != null) { requestBuilder.format(requestOptions.getFormat()); diff --git a/models/spring-ai-ollama/src/main/java/org/springframework/ai/ollama/api/OllamaApi.java b/models/spring-ai-ollama/src/main/java/org/springframework/ai/ollama/api/OllamaApi.java index e0ffc06c31d..d03dc97700b 100644 --- 
a/models/spring-ai-ollama/src/main/java/org/springframework/ai/ollama/api/OllamaApi.java +++ b/models/spring-ai-ollama/src/main/java/org/springframework/ai/ollama/api/OllamaApi.java @@ -51,6 +51,7 @@ * @author Christian Tzolov * @author Thomas Vitale * @author Jonghoon Park + * @author Sun Yuhan * @since 0.8.0 */ // @formatter:off @@ -251,6 +252,7 @@ public Flux pullModel(PullModelRequest pullModelRequest) { * * @param role The role of the message of type {@link Role}. * @param content The content of the message. + * @param thinking The thinking of the model. * @param images The list of base64-encoded images to send with the message. * Requires multimodal models such as llava or bakllava. * @param toolCalls The relevant tool call. @@ -260,6 +262,7 @@ public Flux pullModel(PullModelRequest pullModelRequest) { public record Message( @JsonProperty("role") Role role, @JsonProperty("content") String content, + @JsonProperty("thinking") String thinking, @JsonProperty("images") List images, @JsonProperty("tool_calls") List toolCalls) { @@ -321,6 +324,7 @@ public static class Builder { private final Role role; private String content; + private String thinking; private List images; private List toolCalls; @@ -333,6 +337,11 @@ public Builder content(String content) { return this; } + public Builder thinking(String thinking) { + this.thinking = thinking; + return this; + } + public Builder images(List images) { this.images = images; return this; @@ -344,7 +353,7 @@ public Builder toolCalls(List toolCalls) { } public Message build() { - return new Message(this.role, this.content, this.images, this.toolCalls); + return new Message(this.role, this.content, this.thinking, this.images, this.toolCalls); } } } @@ -359,6 +368,7 @@ public Message build() { * @param keepAlive Controls how long the model will stay loaded into memory following this request (default: 5m). * @param tools List of tools the model has access to. * @param options Model-specific options. 
For example, "temperature" can be set through this field, if the model supports it. + * @param think The model should think before responding, if the model supports it. * You can use the {@link OllamaOptions} builder to create the options then {@link OllamaOptions#toMap()} to convert the options into a map. * * @see tools, - @JsonProperty("options") Map options + @JsonProperty("options") Map options, + @JsonProperty("think") Boolean think ) { public static Builder builder(String model) { @@ -448,6 +459,7 @@ public static class Builder { private String keepAlive; private List tools = List.of(); private Map options = Map.of(); + private boolean think; public Builder(String model) { Assert.notNull(model, "The model can not be null."); @@ -492,8 +504,13 @@ public Builder options(OllamaOptions options) { return this; } + public Builder think(boolean think) { + this.think = think; + return this; + } + public ChatRequest build() { - return new ChatRequest(this.model, this.messages, this.stream, this.format, this.keepAlive, this.tools, this.options); + return new ChatRequest(this.model, this.messages, this.stream, this.format, this.keepAlive, this.tools, this.options, this.think); } } } diff --git a/models/spring-ai-ollama/src/main/java/org/springframework/ai/ollama/api/OllamaApiHelper.java b/models/spring-ai-ollama/src/main/java/org/springframework/ai/ollama/api/OllamaApiHelper.java index 588b86c5364..b8728874cc9 100644 --- a/models/spring-ai-ollama/src/main/java/org/springframework/ai/ollama/api/OllamaApiHelper.java +++ b/models/spring-ai-ollama/src/main/java/org/springframework/ai/ollama/api/OllamaApiHelper.java @@ -25,6 +25,7 @@ /** * @author Christian Tzolov + * @author Sun Yuhan * @since 1.0.0 */ public final class OllamaApiHelper { @@ -81,12 +82,18 @@ public static ChatResponse merge(ChatResponse previous, ChatResponse current) { private static OllamaApi.Message merge(OllamaApi.Message previous, OllamaApi.Message current) { String content = mergeContent(previous, 
current); + String thinking = mergeThinking(previous, current); OllamaApi.Message.Role role = (current.role() != null ? current.role() : previous.role()); role = (role != null ? role : OllamaApi.Message.Role.ASSISTANT); List images = mergeImages(previous, current); List toolCalls = mergeToolCall(previous, current); - return OllamaApi.Message.builder(role).content(content).images(images).toolCalls(toolCalls).build(); + return OllamaApi.Message.builder(role) + .content(content) + .thinking(thinking) + .images(images) + .toolCalls(toolCalls) + .build(); } private static Instant merge(Instant previous, Instant current) { @@ -134,6 +141,17 @@ private static String mergeContent(OllamaApi.Message previous, OllamaApi.Message return previous.content() + current.content(); } + private static String mergeThinking(OllamaApi.Message previous, OllamaApi.Message current) { + if (previous == null || previous.thinking() == null) { + return (current != null ? current.thinking() : null); + } + if (current == null || current.thinking() == null) { + return (previous != null ? 
previous.thinking() : null); + } + + return previous.thinking() + current.thinking(); + } + private static List mergeToolCall(OllamaApi.Message previous, OllamaApi.Message current) { if (previous == null) { diff --git a/models/spring-ai-ollama/src/main/java/org/springframework/ai/ollama/api/OllamaModel.java b/models/spring-ai-ollama/src/main/java/org/springframework/ai/ollama/api/OllamaModel.java index 7602eca2584..720c28631f2 100644 --- a/models/spring-ai-ollama/src/main/java/org/springframework/ai/ollama/api/OllamaModel.java +++ b/models/spring-ai-ollama/src/main/java/org/springframework/ai/ollama/api/OllamaModel.java @@ -23,6 +23,7 @@ * * @author Siarhei Blashuk * @author Thomas Vitale + * @author Sun Yuhan * @since 1.0.0 */ public enum OllamaModel implements ChatModelDescription { @@ -32,6 +33,16 @@ public enum OllamaModel implements ChatModelDescription { */ QWEN_2_5_7B("qwen2.5"), + /** + * Qwen3 + */ + QWEN_3_8B("qwen3"), + + /** + * Qwen3 4b + */ + QWEN_3_4B("qwen3:4b"), + /** * QwQ is the reasoning model of the Qwen series. */ diff --git a/models/spring-ai-ollama/src/main/java/org/springframework/ai/ollama/api/OllamaOptions.java b/models/spring-ai-ollama/src/main/java/org/springframework/ai/ollama/api/OllamaOptions.java index a71be1ce2b2..ca5aca624a5 100644 --- a/models/spring-ai-ollama/src/main/java/org/springframework/ai/ollama/api/OllamaOptions.java +++ b/models/spring-ai-ollama/src/main/java/org/springframework/ai/ollama/api/OllamaOptions.java @@ -44,6 +44,7 @@ * @author Christian Tzolov * @author Thomas Vitale * @author Ilayaperumal Gopinathan + * @author Sun Yuhan * @since 0.8.0 * @see Ollama @@ -318,6 +319,14 @@ public class OllamaOptions implements ToolCallingChatOptions, EmbeddingOptions { @JsonProperty("truncate") private Boolean truncate; + /** + * The model should think before responding, if supported. 
+ * If this value is not specified, it defaults to null, and Ollama will return + * the thought process within the `content` field of the response, wrapped in `<thinking>` tags. + */ + @JsonProperty("think") + private Boolean think; + @JsonIgnore private Boolean internalToolExecutionEnabled; @@ -365,6 +374,7 @@ public static OllamaOptions fromOptions(OllamaOptions fromOptions) { .format(fromOptions.getFormat()) .keepAlive(fromOptions.getKeepAlive()) .truncate(fromOptions.getTruncate()) + .think(fromOptions.getThink()) .useNUMA(fromOptions.getUseNUMA()) .numCtx(fromOptions.getNumCtx()) .numBatch(fromOptions.getNumBatch()) @@ -704,6 +714,14 @@ public void setTruncate(Boolean truncate) { this.truncate = truncate; } + public Boolean getThink() { + return this.think; + } + + public void setThink(Boolean think) { + this.think = think; + } + @Override @JsonIgnore public List getToolCallbacks() { @@ -804,7 +822,8 @@ public boolean equals(Object o) { && Objects.equals(this.repeatPenalty, that.repeatPenalty) && Objects.equals(this.presencePenalty, that.presencePenalty) && Objects.equals(this.frequencyPenalty, that.frequencyPenalty) - && Objects.equals(this.mirostat, that.mirostat) && Objects.equals(this.mirostatTau, that.mirostatTau) + && Objects.equals(this.think, that.think) && Objects.equals(this.mirostat, that.mirostat) + && Objects.equals(this.mirostatTau, that.mirostatTau) && Objects.equals(this.mirostatEta, that.mirostatEta) && Objects.equals(this.penalizeNewline, that.penalizeNewline) && Objects.equals(this.stop, that.stop) && Objects.equals(this.toolCallbacks, that.toolCallbacks) @@ -814,13 +833,13 @@ public boolean equals(Object o) { @Override public int hashCode() { - return Objects.hash(this.model, this.format, this.keepAlive, this.truncate, this.useNUMA, this.numCtx, - this.numBatch, this.numGPU, this.mainGPU, this.lowVRAM, this.f16KV, this.logitsAll, this.vocabOnly, - this.useMMap, this.useMLock, this.numThread, this.numKeep, this.seed, this.numPredict, 
this.topK, - this.topP, this.minP, this.tfsZ, this.typicalP, this.repeatLastN, this.temperature, this.repeatPenalty, - this.presencePenalty, this.frequencyPenalty, this.mirostat, this.mirostatTau, this.mirostatEta, - this.penalizeNewline, this.stop, this.toolCallbacks, this.toolNames, this.internalToolExecutionEnabled, - this.toolContext); + return Objects.hash(this.model, this.format, this.keepAlive, this.truncate, this.think, this.useNUMA, + this.numCtx, this.numBatch, this.numGPU, this.mainGPU, this.lowVRAM, this.f16KV, this.logitsAll, + this.vocabOnly, this.useMMap, this.useMLock, this.numThread, this.numKeep, this.seed, this.numPredict, + this.topK, this.topP, this.minP, this.tfsZ, this.typicalP, this.repeatLastN, this.temperature, + this.repeatPenalty, this.presencePenalty, this.frequencyPenalty, this.mirostat, this.mirostatTau, + this.mirostatEta, this.penalizeNewline, this.stop, this.toolCallbacks, this.toolNames, + this.internalToolExecutionEnabled, this.toolContext); } public static class Builder { @@ -852,6 +871,11 @@ public Builder truncate(Boolean truncate) { return this; } + public Builder think(Boolean think) { + this.options.think = think; + return this; + } + public Builder useNUMA(Boolean useNUMA) { this.options.useNUMA = useNUMA; return this; diff --git a/models/spring-ai-ollama/src/test/java/org/springframework/ai/ollama/OllamaImage.java b/models/spring-ai-ollama/src/test/java/org/springframework/ai/ollama/OllamaImage.java index 2220bf22695..146c2c042d5 100644 --- a/models/spring-ai-ollama/src/test/java/org/springframework/ai/ollama/OllamaImage.java +++ b/models/spring-ai-ollama/src/test/java/org/springframework/ai/ollama/OllamaImage.java @@ -23,7 +23,7 @@ */ public final class OllamaImage { - public static final DockerImageName DEFAULT_IMAGE = DockerImageName.parse("ollama/ollama:0.5.2"); + public static final DockerImageName DEFAULT_IMAGE = DockerImageName.parse("ollama/ollama:0.9.0"); private OllamaImage() { diff --git 
a/models/spring-ai-ollama/src/test/java/org/springframework/ai/ollama/api/OllamaApiIT.java b/models/spring-ai-ollama/src/test/java/org/springframework/ai/ollama/api/OllamaApiIT.java index 98af032efbd..582556c10de 100644 --- a/models/spring-ai-ollama/src/test/java/org/springframework/ai/ollama/api/OllamaApiIT.java +++ b/models/spring-ai-ollama/src/test/java/org/springframework/ai/ollama/api/OllamaApiIT.java @@ -33,6 +33,7 @@ import org.springframework.ai.ollama.api.OllamaApi.Message.Role; import static org.assertj.core.api.Assertions.assertThat; +import static org.junit.jupiter.api.Assertions.assertNull; /** * @author Christian Tzolov @@ -114,4 +115,32 @@ public void embedText() { assertThat(response.totalDuration()).isGreaterThan(1); } + @Test + public void chatWithThinking() { + var request = ChatRequest.builder(MODEL) + .stream(true) + .think(true) + .messages(List.of(Message.builder(Role.USER) + .content("What is the capital of Bulgaria and what is the size? " + "What it the national anthem?") + .build())) + .options(OllamaOptions.builder().temperature(0.9).build().toMap()) + .build(); + + Flux response = getOllamaApi().streamingChat(request); + + List responses = response.collectList().block(); + System.out.println(responses); + + assertThat(responses).isNotNull(); + assertThat(responses.stream() + .filter(r -> r.message() != null) + .map(r -> r.message().thinking()) + .collect(Collectors.joining(System.lineSeparator()))).contains("Sofia"); + + ChatResponse lastResponse = responses.get(responses.size() - 1); + assertThat(lastResponse.message().content()).isEmpty(); + assertNull(lastResponse.message().thinking()); + assertThat(lastResponse.done()).isTrue(); + } + } From af9958bcf6694e35e67bd56288eb5a7920fa9bdb Mon Sep 17 00:00:00 2001 From: Sun Yuhan <1085481446@qq.com> Date: Fri, 30 May 2025 11:37:21 +0800 Subject: [PATCH 2/6] feat: Remove problematic tweaks Signed-off-by: Sun Yuhan <1085481446@qq.com> --- .../ai/ollama/OllamaImage.java | 2 +- 
.../ai/ollama/api/OllamaApiIT.java | 29 ------------------- 2 files changed, 1 insertion(+), 30 deletions(-) diff --git a/models/spring-ai-ollama/src/test/java/org/springframework/ai/ollama/OllamaImage.java b/models/spring-ai-ollama/src/test/java/org/springframework/ai/ollama/OllamaImage.java index 146c2c042d5..2220bf22695 100644 --- a/models/spring-ai-ollama/src/test/java/org/springframework/ai/ollama/OllamaImage.java +++ b/models/spring-ai-ollama/src/test/java/org/springframework/ai/ollama/OllamaImage.java @@ -23,7 +23,7 @@ */ public final class OllamaImage { - public static final DockerImageName DEFAULT_IMAGE = DockerImageName.parse("ollama/ollama:0.9.0"); + public static final DockerImageName DEFAULT_IMAGE = DockerImageName.parse("ollama/ollama:0.5.2"); private OllamaImage() { diff --git a/models/spring-ai-ollama/src/test/java/org/springframework/ai/ollama/api/OllamaApiIT.java b/models/spring-ai-ollama/src/test/java/org/springframework/ai/ollama/api/OllamaApiIT.java index 582556c10de..98af032efbd 100644 --- a/models/spring-ai-ollama/src/test/java/org/springframework/ai/ollama/api/OllamaApiIT.java +++ b/models/spring-ai-ollama/src/test/java/org/springframework/ai/ollama/api/OllamaApiIT.java @@ -33,7 +33,6 @@ import org.springframework.ai.ollama.api.OllamaApi.Message.Role; import static org.assertj.core.api.Assertions.assertThat; -import static org.junit.jupiter.api.Assertions.assertNull; /** * @author Christian Tzolov @@ -115,32 +114,4 @@ public void embedText() { assertThat(response.totalDuration()).isGreaterThan(1); } - @Test - public void chatWithThinking() { - var request = ChatRequest.builder(MODEL) - .stream(true) - .think(true) - .messages(List.of(Message.builder(Role.USER) - .content("What is the capital of Bulgaria and what is the size? 
" + "What it the national anthem?") - .build())) - .options(OllamaOptions.builder().temperature(0.9).build().toMap()) - .build(); - - Flux response = getOllamaApi().streamingChat(request); - - List responses = response.collectList().block(); - System.out.println(responses); - - assertThat(responses).isNotNull(); - assertThat(responses.stream() - .filter(r -> r.message() != null) - .map(r -> r.message().thinking()) - .collect(Collectors.joining(System.lineSeparator()))).contains("Sofia"); - - ChatResponse lastResponse = responses.get(responses.size() - 1); - assertThat(lastResponse.message().content()).isEmpty(); - assertNull(lastResponse.message().thinking()); - assertThat(lastResponse.done()).isTrue(); - } - } From 91fb15181928c938fb69245fbca484d211078154 Mon Sep 17 00:00:00 2001 From: Sun Yuhan <1085481446@qq.com> Date: Fri, 30 May 2025 11:49:30 +0800 Subject: [PATCH 3/6] fix: Adjust the type of the "think" field in Ollama to support its default behavior, thereby ensuring compatibility with older versions of Ollama calls. 
Signed-off-by: Sun Yuhan <1085481446@qq.com> --- .../java/org/springframework/ai/ollama/api/OllamaApi.java | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/models/spring-ai-ollama/src/main/java/org/springframework/ai/ollama/api/OllamaApi.java b/models/spring-ai-ollama/src/main/java/org/springframework/ai/ollama/api/OllamaApi.java index d03dc97700b..b08fe066f45 100644 --- a/models/spring-ai-ollama/src/main/java/org/springframework/ai/ollama/api/OllamaApi.java +++ b/models/spring-ai-ollama/src/main/java/org/springframework/ai/ollama/api/OllamaApi.java @@ -459,7 +459,7 @@ public static class Builder { private String keepAlive; private List tools = List.of(); private Map options = Map.of(); - private boolean think; + private Boolean think; public Builder(String model) { Assert.notNull(model, "The model can not be null."); @@ -504,7 +504,7 @@ public Builder options(OllamaOptions options) { return this; } - public Builder think(boolean think) { + public Builder think(Boolean think) { this.think = think; return this; } From f77e08a69a3afe1bd90750c28f96a89f48151987 Mon Sep 17 00:00:00 2001 From: Sun Yuhan <1085481446@qq.com> Date: Fri, 30 May 2025 15:48:48 +0800 Subject: [PATCH 4/6] feat: Add unit tests for ollama's think support. 
upgrade the test-container image version of ollama to 0.9.0 Signed-off-by: Sun Yuhan <1085481446@qq.com> --- .../ai/ollama/api/OllamaModel.java | 9 ++- .../ai/ollama/OllamaImage.java | 2 +- .../ai/ollama/api/OllamaApiIT.java | 62 ++++++++++++++++++- 3 files changed, 68 insertions(+), 5 deletions(-) diff --git a/models/spring-ai-ollama/src/main/java/org/springframework/ai/ollama/api/OllamaModel.java b/models/spring-ai-ollama/src/main/java/org/springframework/ai/ollama/api/OllamaModel.java index 720c28631f2..aa91a2ec837 100644 --- a/models/spring-ai-ollama/src/main/java/org/springframework/ai/ollama/api/OllamaModel.java +++ b/models/spring-ai-ollama/src/main/java/org/springframework/ai/ollama/api/OllamaModel.java @@ -39,9 +39,14 @@ public enum OllamaModel implements ChatModelDescription { QWEN_3_8B("qwen3"), /** - * Qwen3 4b + * Qwen3 1.7b */ - QWEN_3_4B("qwen3:4b"), + QWEN_3_1_7_B("qwen3:1.7b"), + + /** + * Qwen3 0.6b + */ + QWEN_3_06B("qwen3:0.6b"), /** * QwQ is the reasoning model of the Qwen series. 
diff --git a/models/spring-ai-ollama/src/test/java/org/springframework/ai/ollama/OllamaImage.java b/models/spring-ai-ollama/src/test/java/org/springframework/ai/ollama/OllamaImage.java index 2220bf22695..146c2c042d5 100644 --- a/models/spring-ai-ollama/src/test/java/org/springframework/ai/ollama/OllamaImage.java +++ b/models/spring-ai-ollama/src/test/java/org/springframework/ai/ollama/OllamaImage.java @@ -23,7 +23,7 @@ */ public final class OllamaImage { - public static final DockerImageName DEFAULT_IMAGE = DockerImageName.parse("ollama/ollama:0.5.2"); + public static final DockerImageName DEFAULT_IMAGE = DockerImageName.parse("ollama/ollama:0.9.0"); private OllamaImage() { diff --git a/models/spring-ai-ollama/src/test/java/org/springframework/ai/ollama/api/OllamaApiIT.java b/models/spring-ai-ollama/src/test/java/org/springframework/ai/ollama/api/OllamaApiIT.java index 98af032efbd..cf2ea4f8afd 100644 --- a/models/spring-ai-ollama/src/test/java/org/springframework/ai/ollama/api/OllamaApiIT.java +++ b/models/spring-ai-ollama/src/test/java/org/springframework/ai/ollama/api/OllamaApiIT.java @@ -33,14 +33,16 @@ import org.springframework.ai.ollama.api.OllamaApi.Message.Role; import static org.assertj.core.api.Assertions.assertThat; +import static org.junit.jupiter.api.Assertions.assertNull; /** * @author Christian Tzolov * @author Thomas Vitale + * @author Sun Yuhan */ public class OllamaApiIT extends BaseOllamaIT { - private static final String MODEL = OllamaModel.LLAMA3_2.getName(); + private static final String MODEL = OllamaModel.QWEN_3_1_7_B.getName(); @BeforeAll public static void beforeAll() throws IOException, InterruptedException { @@ -107,11 +109,67 @@ public void embedText() { assertThat(response).isNotNull(); assertThat(response.embeddings()).hasSize(1); - assertThat(response.embeddings().get(0)).hasSize(3072); + assertThat(response.embeddings().get(0)).hasSize(2048); assertThat(response.model()).isEqualTo(MODEL); 
assertThat(response.promptEvalCount()).isEqualTo(5); assertThat(response.loadDuration()).isGreaterThan(1); assertThat(response.totalDuration()).isGreaterThan(1); } + @Test + public void streamChatWithThinking() { + var request = ChatRequest.builder(MODEL) + .stream(true) + .think(true) + .messages(List.of(Message.builder(Role.USER).content("What are the planets in the solar system?").build())) + .options(OllamaOptions.builder().temperature(0.9).build().toMap()) + .build(); + + Flux response = getOllamaApi().streamingChat(request); + + List responses = response.collectList().block(); + System.out.println(responses); + + assertThat(responses).isNotNull(); + assertThat(responses.stream() + .filter(r -> r.message() != null) + .map(r -> r.message().thinking()) + .collect(Collectors.joining(System.lineSeparator()))).contains("solar"); + + ChatResponse lastResponse = responses.get(responses.size() - 1); + assertThat(lastResponse.message().content()).isEmpty(); + assertNull(lastResponse.message().thinking()); + assertThat(lastResponse.done()).isTrue(); + } + + @Test + public void streamChatWithoutThinking() { + var request = ChatRequest.builder(MODEL) + .stream(true) + .think(false) + .messages(List.of(Message.builder(Role.USER).content("What are the planets in the solar system?").build())) + .options(OllamaOptions.builder().temperature(0.9).build().toMap()) + .build(); + + Flux response = getOllamaApi().streamingChat(request); + + List responses = response.collectList().block(); + System.out.println(responses); + + assertThat(responses).isNotNull(); + + assertThat(responses.stream() + .filter(r -> r.message() != null) + .map(r -> r.message().content()) + .collect(Collectors.joining(System.lineSeparator()))).contains("Earth"); + + assertThat(responses.stream().filter(r -> r.message() != null).allMatch(r -> r.message().thinking() == null)) + .isTrue(); + + ChatResponse lastResponse = responses.get(responses.size() - 1); + 
assertThat(lastResponse.message().content()).isEmpty(); + assertNull(lastResponse.message().thinking()); + assertThat(lastResponse.done()).isTrue(); + } + } From b188adcab68373c2ee57f4134f4e2b87b04913ed Mon Sep 17 00:00:00 2001 From: Sun Yuhan Date: Fri, 13 Jun 2025 13:59:54 +0800 Subject: [PATCH 5/6] feat: Propagate the `thinking` returned by Ollama back to `ChatGenerationMetadata`, and added corresponding unit tests. Signed-off-by: Sun Yuhan --- .../ai/ollama/OllamaChatModel.java | 1 + .../ollama/OllamaChatModelMetadataTests.java | 127 ++++++++++++++++++ 2 files changed, 128 insertions(+) create mode 100644 models/spring-ai-ollama/src/test/java/org/springframework/ai/ollama/OllamaChatModelMetadataTests.java diff --git a/models/spring-ai-ollama/src/main/java/org/springframework/ai/ollama/OllamaChatModel.java b/models/spring-ai-ollama/src/main/java/org/springframework/ai/ollama/OllamaChatModel.java index 3b15ccc0477..9cd442f99d4 100644 --- a/models/spring-ai-ollama/src/main/java/org/springframework/ai/ollama/OllamaChatModel.java +++ b/models/spring-ai-ollama/src/main/java/org/springframework/ai/ollama/OllamaChatModel.java @@ -249,6 +249,7 @@ private ChatResponse internalCall(Prompt prompt, ChatResponse previousChatRespon if (ollamaResponse.promptEvalCount() != null && ollamaResponse.evalCount() != null) { generationMetadata = ChatGenerationMetadata.builder() .finishReason(ollamaResponse.doneReason()) + .metadata("thinking", ollamaResponse.message().thinking()) .build(); } diff --git a/models/spring-ai-ollama/src/test/java/org/springframework/ai/ollama/OllamaChatModelMetadataTests.java b/models/spring-ai-ollama/src/test/java/org/springframework/ai/ollama/OllamaChatModelMetadataTests.java new file mode 100644 index 00000000000..9f2d1c44dc6 --- /dev/null +++ b/models/spring-ai-ollama/src/test/java/org/springframework/ai/ollama/OllamaChatModelMetadataTests.java @@ -0,0 +1,127 @@ +/* + * Copyright 2023-2025 the original author or authors. 
+ * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.ai.ollama; + +import io.micrometer.observation.tck.TestObservationRegistry; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; +import org.springframework.ai.chat.metadata.ChatGenerationMetadata; +import org.springframework.ai.chat.model.ChatResponse; +import org.springframework.ai.chat.prompt.Prompt; +import org.springframework.ai.ollama.api.OllamaApi; +import org.springframework.ai.ollama.api.OllamaModel; +import org.springframework.ai.ollama.api.OllamaOptions; +import org.springframework.beans.factory.annotation.Autowired; +import org.springframework.boot.SpringBootConfiguration; +import org.springframework.boot.test.context.SpringBootTest; +import org.springframework.context.annotation.Bean; + +import static org.assertj.core.api.Assertions.assertThat; + +/** + * Unit Tests for {@link OllamaChatModel} asserting AI metadata. 
+ * + * @author Sun Yuhan + */ +@SpringBootTest(classes = OllamaChatModelMetadataTests.Config.class) +class OllamaChatModelMetadataTests extends BaseOllamaIT { + + private static final String MODEL = OllamaModel.QWEN_3_06B.getName(); + + @Autowired + TestObservationRegistry observationRegistry; + + @Autowired + OllamaChatModel chatModel; + + @BeforeEach + void beforeEach() { + this.observationRegistry.clear(); + } + + @Test + void ollamaThinkingMetadataCaptured() { + var options = OllamaOptions.builder().model(MODEL).think(true).build(); + + Prompt prompt = new Prompt("Why is the sky blue?", options); + + ChatResponse chatResponse = this.chatModel.call(prompt); + assertThat(chatResponse.getResult().getOutput().getText()).isNotEmpty(); + + chatResponse.getResults().forEach(generation -> { + ChatGenerationMetadata chatGenerationMetadata = generation.getMetadata(); + assertThat(chatGenerationMetadata).isNotNull(); + var thinking = chatGenerationMetadata.get("thinking"); + assertThat(thinking).isNotNull(); + }); + } + + @Test + void ollamaThinkingMetadataNotCapturedWhenNotSetThinkFlag() { + var options = OllamaOptions.builder().model(MODEL).build(); + + Prompt prompt = new Prompt("Why is the sky blue?", options); + + ChatResponse chatResponse = this.chatModel.call(prompt); + assertThat(chatResponse.getResult().getOutput().getText()).isNotEmpty(); + + chatResponse.getResults().forEach(generation -> { + ChatGenerationMetadata chatGenerationMetadata = generation.getMetadata(); + assertThat(chatGenerationMetadata).isNotNull(); + var thinking = chatGenerationMetadata.get("thinking"); + assertThat(thinking).isNull(); + }); + } + + @Test + void ollamaThinkingMetadataNotCapturedWhenSetThinkFlagToFalse() { + var options = OllamaOptions.builder().model(MODEL).think(false).build(); + + Prompt prompt = new Prompt("Why is the sky blue?", options); + + ChatResponse chatResponse = this.chatModel.call(prompt); + assertThat(chatResponse.getResult().getOutput().getText()).isNotEmpty(); 
+ + chatResponse.getResults().forEach(generation -> { + ChatGenerationMetadata chatGenerationMetadata = generation.getMetadata(); + assertThat(chatGenerationMetadata).isNotNull(); + var thinking = chatGenerationMetadata.get("thinking"); + assertThat(thinking).isNull(); + }); + } + + @SpringBootConfiguration + static class Config { + + @Bean + public TestObservationRegistry observationRegistry() { + return TestObservationRegistry.create(); + } + + @Bean + public OllamaApi ollamaApi() { + return initializeOllama(MODEL); + } + + @Bean + public OllamaChatModel openAiChatModel(OllamaApi ollamaApi, TestObservationRegistry observationRegistry) { + return OllamaChatModel.builder().ollamaApi(ollamaApi).observationRegistry(observationRegistry).build(); + } + + } + +} From 552a3467cc726e0abbbde6e121589d423004bc0c Mon Sep 17 00:00:00 2001 From: Sun Yuhan Date: Mon, 16 Jun 2025 10:44:59 +0800 Subject: [PATCH 6/6] feat: Add support for the think flag at the ChatMode level. Signed-off-by: Sun Yuhan --- .../ai/ollama/OllamaChatModel.java | 2 +- .../ai/ollama/api/OllamaOptions.java | 5 +++-- .../ai/chat/prompt/ChatOptions.java | 16 ++++++++++++++++ .../ai/chat/prompt/DefaultChatOptions.java | 12 ++++++++++++ .../prompt/DefaultChatOptionsBuilder.java | 5 +++++ .../tool/DefaultToolCallingChatOptions.java | 19 +++++++++++++++++++ .../ai/model/tool/ToolCallingChatOptions.java | 3 +++ .../chat/prompt/ChatOptionsBuilderTests.java | 5 +++++ .../DefaultToolCallingChatOptionsTests.java | 2 ++ 9 files changed, 66 insertions(+), 3 deletions(-) diff --git a/models/spring-ai-ollama/src/main/java/org/springframework/ai/ollama/OllamaChatModel.java b/models/spring-ai-ollama/src/main/java/org/springframework/ai/ollama/OllamaChatModel.java index 9cd442f99d4..d4f8fc3007a 100644 --- a/models/spring-ai-ollama/src/main/java/org/springframework/ai/ollama/OllamaChatModel.java +++ b/models/spring-ai-ollama/src/main/java/org/springframework/ai/ollama/OllamaChatModel.java @@ -462,7 +462,7 @@ else if (message 
instanceof ToolResponseMessage toolMessage) { .stream(stream) .messages(ollamaMessages) .options(requestOptions) - .think(requestOptions.getThink()); + .think(requestOptions.isThink()); if (requestOptions.getFormat() != null) { requestBuilder.format(requestOptions.getFormat()); diff --git a/models/spring-ai-ollama/src/main/java/org/springframework/ai/ollama/api/OllamaOptions.java b/models/spring-ai-ollama/src/main/java/org/springframework/ai/ollama/api/OllamaOptions.java index ca5aca624a5..473d8ca7e25 100644 --- a/models/spring-ai-ollama/src/main/java/org/springframework/ai/ollama/api/OllamaOptions.java +++ b/models/spring-ai-ollama/src/main/java/org/springframework/ai/ollama/api/OllamaOptions.java @@ -374,7 +374,7 @@ public static OllamaOptions fromOptions(OllamaOptions fromOptions) { .format(fromOptions.getFormat()) .keepAlive(fromOptions.getKeepAlive()) .truncate(fromOptions.getTruncate()) - .think(fromOptions.getThink()) + .think(fromOptions.isThink()) .useNUMA(fromOptions.getUseNUMA()) .numCtx(fromOptions.getNumCtx()) .numBatch(fromOptions.getNumBatch()) @@ -714,7 +714,8 @@ public void setTruncate(Boolean truncate) { this.truncate = truncate; } - public Boolean getThink() { + @Override + public Boolean isThink() { return this.think; } diff --git a/spring-ai-model/src/main/java/org/springframework/ai/chat/prompt/ChatOptions.java b/spring-ai-model/src/main/java/org/springframework/ai/chat/prompt/ChatOptions.java index 9f051ac0597..1ea4af559b4 100644 --- a/spring-ai-model/src/main/java/org/springframework/ai/chat/prompt/ChatOptions.java +++ b/spring-ai-model/src/main/java/org/springframework/ai/chat/prompt/ChatOptions.java @@ -83,6 +83,15 @@ public interface ChatOptions extends ModelOptions { @Nullable Double getTopP(); + /** + * Returns the think flag to use for the chat. + * @return the think flag to use for the chat + */ + @Nullable + default Boolean isThink() { + return false; + } + /** * Returns a copy of this {@link ChatOptions}. 
* @return a copy of this {@link ChatOptions} @@ -158,6 +167,13 @@ interface Builder { */ Builder topP(Double topP); + /** + * Sets the think flag to use for the chat. + * @param think Whether to enable thinking mode + * @return the builder. + */ + Builder think(Boolean think); + /** + * Build the {@link ChatOptions}. + * @return the Chat options. diff --git a/spring-ai-model/src/main/java/org/springframework/ai/chat/prompt/DefaultChatOptions.java b/spring-ai-model/src/main/java/org/springframework/ai/chat/prompt/DefaultChatOptions.java index 1af33bf3467..4bef45a2741 100644 --- a/spring-ai-model/src/main/java/org/springframework/ai/chat/prompt/DefaultChatOptions.java +++ b/spring-ai-model/src/main/java/org/springframework/ai/chat/prompt/DefaultChatOptions.java @@ -41,6 +41,8 @@ public class DefaultChatOptions implements ChatOptions { private Double topP; + private Boolean think; + @Override public String getModel() { return this.model; @@ -113,6 +115,15 @@ public void setTopP(Double topP) { this.topP = topP; } + @Override + public Boolean isThink() { + return this.think; + } + + public void setThink(Boolean think) { + this.think = think; + } + @Override @SuppressWarnings("unchecked") public T copy() { @@ -125,6 +136,7 @@ public T copy() { copy.setTemperature(this.getTemperature()); copy.setTopK(this.getTopK()); copy.setTopP(this.getTopP()); + copy.setThink(this.isThink()); return (T) copy; } diff --git a/spring-ai-model/src/main/java/org/springframework/ai/chat/prompt/DefaultChatOptionsBuilder.java b/spring-ai-model/src/main/java/org/springframework/ai/chat/prompt/DefaultChatOptionsBuilder.java index 47ba5840109..a317c8c8106 100644 --- a/spring-ai-model/src/main/java/org/springframework/ai/chat/prompt/DefaultChatOptionsBuilder.java +++ b/spring-ai-model/src/main/java/org/springframework/ai/chat/prompt/DefaultChatOptionsBuilder.java @@ -73,6 +73,11 @@ public DefaultChatOptionsBuilder topP(Double topP) { return this; } + public DefaultChatOptionsBuilder think(Boolean 
think) { + this.options.setThink(think); + return this; + } + public ChatOptions build() { return this.options.copy(); } diff --git a/spring-ai-model/src/main/java/org/springframework/ai/model/tool/DefaultToolCallingChatOptions.java b/spring-ai-model/src/main/java/org/springframework/ai/model/tool/DefaultToolCallingChatOptions.java index 870db6931b9..7c9c9397fc1 100644 --- a/spring-ai-model/src/main/java/org/springframework/ai/model/tool/DefaultToolCallingChatOptions.java +++ b/spring-ai-model/src/main/java/org/springframework/ai/model/tool/DefaultToolCallingChatOptions.java @@ -70,6 +70,9 @@ public class DefaultToolCallingChatOptions implements ToolCallingChatOptions { @Nullable private Double topP; + @Nullable + private Boolean think; + @Override public List getToolCallbacks() { return List.copyOf(this.toolCallbacks); @@ -198,6 +201,16 @@ public void setTopP(@Nullable Double topP) { this.topP = topP; } + @Override + @Nullable + public Boolean isThink() { + return this.think; + } + + public void setThink(@Nullable Boolean think) { + this.think = think; + } + @Override @SuppressWarnings("unchecked") public T copy() { @@ -325,6 +338,12 @@ public ToolCallingChatOptions.Builder topP(@Nullable Double topP) { return this; } + @Override + public ToolCallingChatOptions.Builder think(Boolean think) { + this.options.setThink(think); + return this; + } + @Override public ToolCallingChatOptions build() { return this.options; diff --git a/spring-ai-model/src/main/java/org/springframework/ai/model/tool/ToolCallingChatOptions.java b/spring-ai-model/src/main/java/org/springframework/ai/model/tool/ToolCallingChatOptions.java index f06e71aa869..9cbdbe80c86 100644 --- a/spring-ai-model/src/main/java/org/springframework/ai/model/tool/ToolCallingChatOptions.java +++ b/spring-ai-model/src/main/java/org/springframework/ai/model/tool/ToolCallingChatOptions.java @@ -219,6 +219,9 @@ interface Builder extends ChatOptions.Builder { @Override Builder topP(@Nullable Double topP); + @Override + 
Builder think(@Nullable Boolean think); + @Override ToolCallingChatOptions build(); diff --git a/spring-ai-model/src/test/java/org/springframework/ai/chat/prompt/ChatOptionsBuilderTests.java b/spring-ai-model/src/test/java/org/springframework/ai/chat/prompt/ChatOptionsBuilderTests.java index bf8e0e1fd01..566ecc339e3 100644 --- a/spring-ai-model/src/test/java/org/springframework/ai/chat/prompt/ChatOptionsBuilderTests.java +++ b/spring-ai-model/src/test/java/org/springframework/ai/chat/prompt/ChatOptionsBuilderTests.java @@ -53,6 +53,7 @@ void shouldBuildWithAllOptions() { .topP(1.0) .topK(40) .stopSequences(List.of("stop1", "stop2")) + .think(true) .build(); assertThat(options.getModel()).isEqualTo("gpt-4"); @@ -60,6 +61,7 @@ void shouldBuildWithAllOptions() { assertThat(options.getTemperature()).isEqualTo(0.7); assertThat(options.getTopP()).isEqualTo(1.0); assertThat(options.getTopK()).isEqualTo(40); + assertThat(options.isThink()).isEqualTo(true); assertThat(options.getStopSequences()).containsExactly("stop1", "stop2"); } @@ -82,6 +84,7 @@ void shouldCopyOptions() { .temperature(0.7) .topP(1.0) .topK(40) + .think(true) .stopSequences(List.of("stop1", "stop2")) .build(); @@ -107,6 +110,7 @@ void shouldUpcastToChatOptions() { .temperature(0.7) .topP(1.0) .topK(40) + .think(true) .stopSequences(List.of("stop1", "stop2")) .toolNames(Set.of("function1", "function2")) .toolCallbacks(List.of(callback)) @@ -121,6 +125,7 @@ void shouldUpcastToChatOptions() { assertThat(chatOptions.getTemperature()).isEqualTo(0.7); assertThat(chatOptions.getTopP()).isEqualTo(1.0); assertThat(chatOptions.getTopK()).isEqualTo(40); + assertThat(chatOptions.isThink()).isEqualTo(true); assertThat(chatOptions.getStopSequences()).containsExactly("stop1", "stop2"); } diff --git a/spring-ai-model/src/test/java/org/springframework/ai/model/tool/DefaultToolCallingChatOptionsTests.java b/spring-ai-model/src/test/java/org/springframework/ai/model/tool/DefaultToolCallingChatOptionsTests.java index 
45557f23a6d..7ce3d57af6a 100644 --- a/spring-ai-model/src/test/java/org/springframework/ai/model/tool/DefaultToolCallingChatOptionsTests.java +++ b/spring-ai-model/src/test/java/org/springframework/ai/model/tool/DefaultToolCallingChatOptionsTests.java @@ -188,6 +188,7 @@ void builderShouldCreateOptionsWithAllProperties() { .stopSequences(List.of("stop")) .topK(3) .topP(0.9) + .think(true) .build(); assertThat(options).satisfies(o -> { @@ -203,6 +204,7 @@ void builderShouldCreateOptionsWithAllProperties() { assertThat(o.getStopSequences()).containsExactly("stop"); assertThat(o.getTopK()).isEqualTo(3); assertThat(o.getTopP()).isEqualTo(0.9); + assertThat(o.isThink()).isEqualTo(true); }); }