From cc705a7f21d88bec0742ebba830129dc04d21f94 Mon Sep 17 00:00:00 2001 From: YunKui Lu Date: Fri, 6 Jun 2025 20:41:49 +0800 Subject: [PATCH] fix: Incorrect order when Advisors have the same order - Fix incorrect order when Advisors have the same order - Added the `hasNextCallAdvisor` method to `CallAdvisorChain`. - Added the `hasNextStreamAdvisor` method to `StreamAdvisorChain`. - The `DefaultAroundAdvisorChain` implements both `hasNextCallAdvisor` and `hasNextStreamAdvisor` method. - Added a last StreamAdvisor check in `ChatModelStreamAdvisor`. - Added a last CallAdvisor check in `ChatModelCallAdvisor`. - Updated the corresponding test cases. Signed-off-by: YunKui Lu --- .../client/advisor/ChatModelCallAdvisor.java | 2 + .../advisor/ChatModelStreamAdvisor.java | 2 + .../advisor/DefaultAroundAdvisorChain.java | 14 ++- .../client/advisor/api/CallAdvisorChain.java | 8 ++ .../advisor/api/StreamAdvisorChain.java | 8 ++ .../advisor/ChatModelCallAdvisorTests.java | 21 ++++ .../advisor/ChatModelStreamAdvisorTests.java | 21 ++++ .../DefaultAroundAdvisorChainTests.java | 110 ++++++++++++++++++ 8 files changed, 184 insertions(+), 2 deletions(-) diff --git a/spring-ai-client-chat/src/main/java/org/springframework/ai/chat/client/advisor/ChatModelCallAdvisor.java b/spring-ai-client-chat/src/main/java/org/springframework/ai/chat/client/advisor/ChatModelCallAdvisor.java index 051cbdd808c..0a0dbd5067b 100644 --- a/spring-ai-client-chat/src/main/java/org/springframework/ai/chat/client/advisor/ChatModelCallAdvisor.java +++ b/spring-ai-client-chat/src/main/java/org/springframework/ai/chat/client/advisor/ChatModelCallAdvisor.java @@ -48,6 +48,8 @@ private ChatModelCallAdvisor(ChatModel chatModel) { @Override public ChatClientResponse adviseCall(ChatClientRequest chatClientRequest, CallAdvisorChain callAdvisorChain) { Assert.notNull(chatClientRequest, "the chatClientRequest cannot be null"); + Assert.isTrue(!callAdvisorChain.hasNextCallAdvisor(), + "ChatModelCallAdvisor should be the last CallAdvisor in the chain"); ChatClientRequest formattedChatClientRequest = augmentWithFormatInstructions(chatClientRequest); diff --git a/spring-ai-client-chat/src/main/java/org/springframework/ai/chat/client/advisor/ChatModelStreamAdvisor.java b/spring-ai-client-chat/src/main/java/org/springframework/ai/chat/client/advisor/ChatModelStreamAdvisor.java index de691318df1..2b38cbd7837 100644 --- a/spring-ai-client-chat/src/main/java/org/springframework/ai/chat/client/advisor/ChatModelStreamAdvisor.java +++ b/spring-ai-client-chat/src/main/java/org/springframework/ai/chat/client/advisor/ChatModelStreamAdvisor.java @@ -48,6 +48,8 @@ private ChatModelStreamAdvisor(ChatModel chatModel) { public Flux adviseStream(ChatClientRequest chatClientRequest, StreamAdvisorChain streamAdvisorChain) { Assert.notNull(chatClientRequest, "the chatClientRequest cannot be null"); + Assert.isTrue(!streamAdvisorChain.hasNextStreamAdvisor(), + "ChatModelStreamAdvisor should be the last StreamAdvisor in the chain"); return this.chatModel.stream(chatClientRequest.prompt()) .map(chatResponse -> ChatClientResponse.builder() diff --git a/spring-ai-client-chat/src/main/java/org/springframework/ai/chat/client/advisor/DefaultAroundAdvisorChain.java b/spring-ai-client-chat/src/main/java/org/springframework/ai/chat/client/advisor/DefaultAroundAdvisorChain.java index e9d1a1d2066..b0e01868955 100644 --- a/spring-ai-client-chat/src/main/java/org/springframework/ai/chat/client/advisor/DefaultAroundAdvisorChain.java +++ b/spring-ai-client-chat/src/main/java/org/springframework/ai/chat/client/advisor/DefaultAroundAdvisorChain.java @@ -146,11 +146,21 @@ public List getCallAdvisors() { return this.originalCallAdvisors; } + @Override + public boolean hasNextCallAdvisor() { + return !this.callAdvisors.isEmpty(); + } + @Override public List getStreamAdvisors() { return this.originalStreamAdvisors; } + @Override + public boolean hasNextStreamAdvisor() { + return !this.streamAdvisors.isEmpty(); + } + @Override public ObservationRegistry getObservationRegistry() { return this.observationRegistry; @@ -192,7 +202,7 @@ public Builder pushAll(List advisors) { .toList(); if (!CollectionUtils.isEmpty(callAroundAdvisorList)) { - callAroundAdvisorList.forEach(this.callAdvisors::push); + this.callAdvisors.addAll(callAroundAdvisorList); } List streamAroundAdvisorList = advisors.stream() @@ -201,7 +211,7 @@ public Builder pushAll(List advisors) { .toList(); if (!CollectionUtils.isEmpty(streamAroundAdvisorList)) { - streamAroundAdvisorList.forEach(this.streamAdvisors::push); + this.streamAdvisors.addAll(streamAroundAdvisorList); } this.reOrder(); diff --git a/spring-ai-client-chat/src/main/java/org/springframework/ai/chat/client/advisor/api/CallAdvisorChain.java b/spring-ai-client-chat/src/main/java/org/springframework/ai/chat/client/advisor/api/CallAdvisorChain.java index cfb009b2fb8..2435ee1dcf9 100644 --- a/spring-ai-client-chat/src/main/java/org/springframework/ai/chat/client/advisor/api/CallAdvisorChain.java +++ b/spring-ai-client-chat/src/main/java/org/springframework/ai/chat/client/advisor/api/CallAdvisorChain.java @@ -28,6 +28,7 @@ * @author Christian Tzolov * @author Dariusz Jedrzejczyk * @author Thomas Vitale + * @author YunKui Lu * @since 1.0.0 */ public interface CallAdvisorChain extends AdvisorChain { @@ -44,4 +45,11 @@ public interface CallAdvisorChain extends AdvisorChain { */ List getCallAdvisors(); + /** + * Returns true if there is a next {@link CallAdvisor} in the chain. + */ + default boolean hasNextCallAdvisor() { + throw new UnsupportedOperationException("This CallAdvisorChain does not support hasNextCallAdvisor()"); + } + } diff --git a/spring-ai-client-chat/src/main/java/org/springframework/ai/chat/client/advisor/api/StreamAdvisorChain.java b/spring-ai-client-chat/src/main/java/org/springframework/ai/chat/client/advisor/api/StreamAdvisorChain.java index 230192de3aa..3e7c8377dbb 100644 --- a/spring-ai-client-chat/src/main/java/org/springframework/ai/chat/client/advisor/api/StreamAdvisorChain.java +++ b/spring-ai-client-chat/src/main/java/org/springframework/ai/chat/client/advisor/api/StreamAdvisorChain.java @@ -30,6 +30,7 @@ * @author Christian Tzolov * @author Dariusz Jedrzejczyk * @author Thomas Vitale + * @author YunKui Lu * @since 1.0.0 */ public interface StreamAdvisorChain extends AdvisorChain { @@ -46,4 +47,11 @@ public interface StreamAdvisorChain extends AdvisorChain { */ List getStreamAdvisors(); + /** + * Returns true if there is a next {@link StreamAdvisor} in the chain. + */ + default boolean hasNextStreamAdvisor() { + throw new UnsupportedOperationException("This StreamAdvisorChain does not support hasNextStreamAdvisor()"); + } + } diff --git a/spring-ai-client-chat/src/test/java/org/springframework/ai/chat/client/advisor/ChatModelCallAdvisorTests.java b/spring-ai-client-chat/src/test/java/org/springframework/ai/chat/client/advisor/ChatModelCallAdvisorTests.java index d0c6024d92e..83c9fd8a2dd 100644 --- a/spring-ai-client-chat/src/test/java/org/springframework/ai/chat/client/advisor/ChatModelCallAdvisorTests.java +++ b/spring-ai-client-chat/src/test/java/org/springframework/ai/chat/client/advisor/ChatModelCallAdvisorTests.java @@ -18,7 +18,13 @@ import org.junit.jupiter.api.Test; +import org.springframework.ai.chat.client.ChatClientRequest; +import org.springframework.ai.chat.client.advisor.api.CallAdvisorChain; +import org.springframework.ai.chat.model.ChatModel; + import static org.assertj.core.api.Assertions.assertThatThrownBy; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.when; /** * Unit tests for {@link ChatModelCallAdvisor}. @@ -34,4 +40,19 @@ void whenChatModelIsNullThenThrow() { .hasMessage("chatModel cannot be null"); } + @Test + void whenNotLastInChainThrow() { + ChatModel chatModel = mock(ChatModel.class); + ChatClientRequest chatClientRequest = mock(ChatClientRequest.class); + CallAdvisorChain callAdvisorChain = mock(CallAdvisorChain.class); + + when(callAdvisorChain.hasNextCallAdvisor()).thenReturn(true); + + ChatModelCallAdvisor chatModelCallAdvisor = ChatModelCallAdvisor.builder().chatModel(chatModel).build(); + + assertThatThrownBy(() -> chatModelCallAdvisor.adviseCall(chatClientRequest, callAdvisorChain)) + .isInstanceOf(IllegalArgumentException.class) + .hasMessage("ChatModelCallAdvisor should be the last CallAdvisor in the chain"); + } + } diff --git a/spring-ai-client-chat/src/test/java/org/springframework/ai/chat/client/advisor/ChatModelStreamAdvisorTests.java b/spring-ai-client-chat/src/test/java/org/springframework/ai/chat/client/advisor/ChatModelStreamAdvisorTests.java index f613e2a3bd2..8d89088fc8b 100644 --- a/spring-ai-client-chat/src/test/java/org/springframework/ai/chat/client/advisor/ChatModelStreamAdvisorTests.java +++ b/spring-ai-client-chat/src/test/java/org/springframework/ai/chat/client/advisor/ChatModelStreamAdvisorTests.java @@ -18,7 +18,13 @@ import org.junit.jupiter.api.Test; +import org.springframework.ai.chat.client.ChatClientRequest; +import org.springframework.ai.chat.client.advisor.api.StreamAdvisorChain; +import org.springframework.ai.chat.model.ChatModel; + import static org.assertj.core.api.Assertions.assertThatThrownBy; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.when; /** * Unit tests for {@link ChatModelStreamAdvisor}. @@ -34,4 +40,19 @@ void whenChatModelIsNullThenThrow() { .hasMessage("chatModel cannot be null"); } + @Test + void whenNotLastInChainThrow() { + ChatModel chatModel = mock(ChatModel.class); + ChatClientRequest chatClientRequest = mock(ChatClientRequest.class); + StreamAdvisorChain streamAdvisorChain = mock(StreamAdvisorChain.class); + + when(streamAdvisorChain.hasNextStreamAdvisor()).thenReturn(true); + + ChatModelStreamAdvisor chatModelStreamAdvisor = ChatModelStreamAdvisor.builder().chatModel(chatModel).build(); + + assertThatThrownBy(() -> chatModelStreamAdvisor.adviseStream(chatClientRequest, streamAdvisorChain)) + .isInstanceOf(IllegalArgumentException.class) + .hasMessage("ChatModelStreamAdvisor should be the last StreamAdvisor in the chain"); + } + } diff --git a/spring-ai-client-chat/src/test/java/org/springframework/ai/chat/client/advisor/DefaultAroundAdvisorChainTests.java b/spring-ai-client-chat/src/test/java/org/springframework/ai/chat/client/advisor/DefaultAroundAdvisorChainTests.java index ed00537f716..866c44d5874 100644 --- a/spring-ai-client-chat/src/test/java/org/springframework/ai/chat/client/advisor/DefaultAroundAdvisorChainTests.java +++ b/spring-ai-client-chat/src/test/java/org/springframework/ai/chat/client/advisor/DefaultAroundAdvisorChainTests.java @@ -103,6 +103,34 @@ void getCallAdvisors() { assertThat(chain.getCallAdvisors()).containsExactlyInAnyOrder(advisors.toArray(new CallAdvisor[0])); } + @Test + void hasNextCallAdvisor() { + // The first advisor + TestAdvisor advisor1 = new TestAdvisor("advisor1", 1) { + @Override + public ChatClientResponse adviseCall(ChatClientRequest chatClientRequest, + CallAdvisorChain callAdvisorChain) { + assertThat(callAdvisorChain.hasNextCallAdvisor()).isTrue(); + return callAdvisorChain.nextCall(chatClientRequest); + } + }; + + // The last advisor + TestAdvisor advisor2 = new TestAdvisor("advisor2", 2) { + @Override + public ChatClientResponse adviseCall(ChatClientRequest chatClientRequest, + CallAdvisorChain callAdvisorChain) { + assertThat(callAdvisorChain.hasNextCallAdvisor()).isFalse(); + return null; + } + }; + + List advisors = List.of(advisor1, advisor2); + CallAdvisorChain chain = DefaultAroundAdvisorChain.builder(ObservationRegistry.NOOP).pushAll(advisors).build(); + + chain.nextCall(mock(ChatClientRequest.class)); + } + @Test void getStreamAdvisors() { StreamAdvisor mockAdvisor1 = mock(StreamAdvisor.class); @@ -125,4 +153,86 @@ void getStreamAdvisors() { assertThat(chain.getStreamAdvisors()).containsExactlyInAnyOrder(advisors.toArray(new StreamAdvisor[0])); } + @Test + void hasNextStreamAdvisor() { + // The first advisor + TestAdvisor advisor1 = new TestAdvisor("advisor1", 1) { + @Override + public Flux adviseStream(ChatClientRequest chatClientRequest, + StreamAdvisorChain streamAdvisorChain) { + assertThat(streamAdvisorChain.hasNextStreamAdvisor()).isTrue(); + return streamAdvisorChain.nextStream(chatClientRequest); + } + }; + // The last advisor + TestAdvisor advisor2 = new TestAdvisor("advisor2", 2) { + @Override + public Flux adviseStream(ChatClientRequest chatClientRequest, + StreamAdvisorChain streamAdvisorChain) { + assertThat(streamAdvisorChain.hasNextStreamAdvisor()).isFalse(); + return Flux.empty(); + } + }; + + List advisors = List.of(advisor1, advisor2); + StreamAdvisorChain chain = DefaultAroundAdvisorChain.builder(ObservationRegistry.NOOP) + .pushAll(advisors) + .build(); + + chain.nextStream(mock(ChatClientRequest.class)).blockLast(); + } + + @Test + void testOrder() { + TestAdvisor advisor1 = new TestAdvisor("advisor1", 1); + TestAdvisor advisor21 = new TestAdvisor("advisor2_1", 2); + TestAdvisor advisor22 = new TestAdvisor("advisor2_2", 2); + TestAdvisor advisor3 = new TestAdvisor("advisor3", 3); + + var advisors = List.of(advisor3, advisor1, advisor21, advisor22); + + DefaultAroundAdvisorChain chain = DefaultAroundAdvisorChain.builder(ObservationRegistry.NOOP) + .pushAll(advisors) + .build(); + + assertThat(chain.getStreamAdvisors()).containsExactly(advisor1, advisor21, advisor22, advisor3); + assertThat(chain.getCallAdvisors()).containsExactly(advisor1, advisor21, advisor22, advisor3); + } + + private static class TestAdvisor implements CallAdvisor, StreamAdvisor { + + private final String name; + + private final int order; + + private TestAdvisor(String name, int order) { + this.name = name; + this.order = order; + } + + @Override + public String getName() { + return name; + } + + @Override + public int getOrder() { + return order; + } + + @Override + public ChatClientResponse adviseCall(ChatClientRequest chatClientRequest, CallAdvisorChain callAdvisorChain) { + System.out.println(callAdvisorChain.hasNextCallAdvisor()); + return callAdvisorChain.nextCall(chatClientRequest); + } + + @Override + public Flux adviseStream(ChatClientRequest chatClientRequest, + StreamAdvisorChain streamAdvisorChain) { + System.out.println(streamAdvisorChain.hasNextStreamAdvisor()); + return streamAdvisorChain.nextStream(chatClientRequest); + } + + } + }