Skip to content

fix: Incorrect order when Advisors have the same order #3461

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,8 @@ private ChatModelCallAdvisor(ChatModel chatModel) {
@Override
public ChatClientResponse adviseCall(ChatClientRequest chatClientRequest, CallAdvisorChain callAdvisorChain) {
Assert.notNull(chatClientRequest, "the chatClientRequest cannot be null");
Assert.isTrue(!callAdvisorChain.hasNextCallAdvisor(),
"ChatModelCallAdvisor should be the last CallAdvisor in the chain");

ChatClientRequest formattedChatClientRequest = augmentWithFormatInstructions(chatClientRequest);

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,8 @@ private ChatModelStreamAdvisor(ChatModel chatModel) {
public Flux<ChatClientResponse> adviseStream(ChatClientRequest chatClientRequest,
StreamAdvisorChain streamAdvisorChain) {
Assert.notNull(chatClientRequest, "the chatClientRequest cannot be null");
Assert.isTrue(!streamAdvisorChain.hasNextStreamAdvisor(),
"ChatModelStreamAdvisor should be the last StreamAdvisor in the chain");

return this.chatModel.stream(chatClientRequest.prompt())
.map(chatResponse -> ChatClientResponse.builder()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -146,11 +146,21 @@ public List<CallAdvisor> getCallAdvisors() {
return this.originalCallAdvisors;
}

@Override
public boolean hasNextCallAdvisor() {
return !this.callAdvisors.isEmpty();
}

@Override
public List<StreamAdvisor> getStreamAdvisors() {
return this.originalStreamAdvisors;
}

@Override
public boolean hasNextStreamAdvisor() {
return !this.streamAdvisors.isEmpty();
}

@Override
public ObservationRegistry getObservationRegistry() {
return this.observationRegistry;
Expand Down Expand Up @@ -192,7 +202,7 @@ public Builder pushAll(List<? extends Advisor> advisors) {
.toList();

if (!CollectionUtils.isEmpty(callAroundAdvisorList)) {
callAroundAdvisorList.forEach(this.callAdvisors::push);
this.callAdvisors.addAll(callAroundAdvisorList);
}

List<StreamAdvisor> streamAroundAdvisorList = advisors.stream()
Expand All @@ -201,7 +211,7 @@ public Builder pushAll(List<? extends Advisor> advisors) {
.toList();

if (!CollectionUtils.isEmpty(streamAroundAdvisorList)) {
streamAroundAdvisorList.forEach(this.streamAdvisors::push);
this.streamAdvisors.addAll(streamAroundAdvisorList);
}

this.reOrder();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
* @author Christian Tzolov
* @author Dariusz Jedrzejczyk
* @author Thomas Vitale
* @author YunKui Lu
* @since 1.0.0
*/
public interface CallAdvisorChain extends AdvisorChain {
Expand All @@ -44,4 +45,11 @@ public interface CallAdvisorChain extends AdvisorChain {
*/
List<CallAdvisor> getCallAdvisors();

/**
* Returns true if there is a next {@link CallAdvisor} in the chain.
*/
default boolean hasNextCallAdvisor() {
throw new UnsupportedOperationException("This CallAdvisorChain does not support hasNextCallAdvisor()");
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
* @author Christian Tzolov
* @author Dariusz Jedrzejczyk
* @author Thomas Vitale
* @author YunKui Lu
* @since 1.0.0
*/
public interface StreamAdvisorChain extends AdvisorChain {
Expand All @@ -46,4 +47,11 @@ public interface StreamAdvisorChain extends AdvisorChain {
*/
List<StreamAdvisor> getStreamAdvisors();

/**
* Returns true if there is a next {@link StreamAdvisor} in the chain.
*/
default boolean hasNextStreamAdvisor() {
throw new UnsupportedOperationException("This StreamAdvisorChain does not support hasNextStreamAdvisor()");
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,13 @@

import org.junit.jupiter.api.Test;

import org.springframework.ai.chat.client.ChatClientRequest;
import org.springframework.ai.chat.client.advisor.api.CallAdvisorChain;
import org.springframework.ai.chat.model.ChatModel;

import static org.assertj.core.api.Assertions.assertThatThrownBy;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.when;

/**
* Unit tests for {@link ChatModelCallAdvisor}.
Expand All @@ -34,4 +40,19 @@ void whenChatModelIsNullThenThrow() {
.hasMessage("chatModel cannot be null");
}

@Test
void whenNotLastInChainThrow() {
ChatModel chatModel = mock(ChatModel.class);
ChatClientRequest chatClientRequest = mock(ChatClientRequest.class);
CallAdvisorChain callAdvisorChain = mock(CallAdvisorChain.class);

when(callAdvisorChain.hasNextCallAdvisor()).thenReturn(true);

ChatModelCallAdvisor chatModelCallAdvisor = ChatModelCallAdvisor.builder().chatModel(chatModel).build();

assertThatThrownBy(() -> chatModelCallAdvisor.adviseCall(chatClientRequest, callAdvisorChain))
.isInstanceOf(IllegalArgumentException.class)
.hasMessage("ChatModelCallAdvisor should be the last CallAdvisor in the chain");
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,13 @@

import org.junit.jupiter.api.Test;

import org.springframework.ai.chat.client.ChatClientRequest;
import org.springframework.ai.chat.client.advisor.api.StreamAdvisorChain;
import org.springframework.ai.chat.model.ChatModel;

import static org.assertj.core.api.Assertions.assertThatThrownBy;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.when;

/**
* Unit tests for {@link ChatModelStreamAdvisor}.
Expand All @@ -34,4 +40,19 @@ void whenChatModelIsNullThenThrow() {
.hasMessage("chatModel cannot be null");
}

@Test
void whenNotLastInChainThrow() {
ChatModel chatModel = mock(ChatModel.class);
ChatClientRequest chatClientRequest = mock(ChatClientRequest.class);
StreamAdvisorChain streamAdvisorChain = mock(StreamAdvisorChain.class);

when(streamAdvisorChain.hasNextStreamAdvisor()).thenReturn(true);

ChatModelStreamAdvisor chatModelStreamAdvisor = ChatModelStreamAdvisor.builder().chatModel(chatModel).build();

assertThatThrownBy(() -> chatModelStreamAdvisor.adviseStream(chatClientRequest, streamAdvisorChain))
.isInstanceOf(IllegalArgumentException.class)
.hasMessage("ChatModelStreamAdvisor should be the last StreamAdvisor in the chain");
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -103,6 +103,34 @@ void getCallAdvisors() {
assertThat(chain.getCallAdvisors()).containsExactlyInAnyOrder(advisors.toArray(new CallAdvisor[0]));
}

@Test
void hasNextCallAdvisor() {
// The first advisor
TestAdvisor advisor1 = new TestAdvisor("advisor1", 1) {
@Override
public ChatClientResponse adviseCall(ChatClientRequest chatClientRequest,
CallAdvisorChain callAdvisorChain) {
assertThat(callAdvisorChain.hasNextCallAdvisor()).isTrue();
return callAdvisorChain.nextCall(chatClientRequest);
}
};

// The last advisor
TestAdvisor advisor2 = new TestAdvisor("advisor2", 2) {
@Override
public ChatClientResponse adviseCall(ChatClientRequest chatClientRequest,
CallAdvisorChain callAdvisorChain) {
assertThat(callAdvisorChain.hasNextCallAdvisor()).isFalse();
return null;
}
};

List<CallAdvisor> advisors = List.of(advisor1, advisor2);
CallAdvisorChain chain = DefaultAroundAdvisorChain.builder(ObservationRegistry.NOOP).pushAll(advisors).build();

chain.nextCall(mock(ChatClientRequest.class));
}

@Test
void getStreamAdvisors() {
StreamAdvisor mockAdvisor1 = mock(StreamAdvisor.class);
Expand All @@ -125,4 +153,86 @@ void getStreamAdvisors() {
assertThat(chain.getStreamAdvisors()).containsExactlyInAnyOrder(advisors.toArray(new StreamAdvisor[0]));
}

@Test
void hasNextStreamAdvisor() {
// The first advisor
TestAdvisor advisor1 = new TestAdvisor("advisor1", 1) {
@Override
public Flux<ChatClientResponse> adviseStream(ChatClientRequest chatClientRequest,
StreamAdvisorChain streamAdvisorChain) {
assertThat(streamAdvisorChain.hasNextStreamAdvisor()).isTrue();
return streamAdvisorChain.nextStream(chatClientRequest);
}
};
// The last advisor
TestAdvisor advisor2 = new TestAdvisor("advisor2", 2) {
@Override
public Flux<ChatClientResponse> adviseStream(ChatClientRequest chatClientRequest,
StreamAdvisorChain streamAdvisorChain) {
assertThat(streamAdvisorChain.hasNextStreamAdvisor()).isFalse();
return Flux.empty();
}
};

List<StreamAdvisor> advisors = List.of(advisor1, advisor2);
StreamAdvisorChain chain = DefaultAroundAdvisorChain.builder(ObservationRegistry.NOOP)
.pushAll(advisors)
.build();

chain.nextStream(mock(ChatClientRequest.class)).blockLast();
}

@Test
void testOrder() {
TestAdvisor advisor1 = new TestAdvisor("advisor1", 1);
TestAdvisor advisor21 = new TestAdvisor("advisor2_1", 2);
TestAdvisor advisor22 = new TestAdvisor("advisor2_2", 2);
TestAdvisor advisor3 = new TestAdvisor("advisor3", 3);

var advisors = List.of(advisor3, advisor1, advisor21, advisor22);

DefaultAroundAdvisorChain chain = DefaultAroundAdvisorChain.builder(ObservationRegistry.NOOP)
.pushAll(advisors)
.build();

assertThat(chain.getStreamAdvisors()).containsExactly(advisor1, advisor21, advisor22, advisor3);
assertThat(chain.getCallAdvisors()).containsExactly(advisor1, advisor21, advisor22, advisor3);
}

private static class TestAdvisor implements CallAdvisor, StreamAdvisor {

private final String name;

private final int order;

private TestAdvisor(String name, int order) {
this.name = name;
this.order = order;
}

@Override
public String getName() {
return name;
}

@Override
public int getOrder() {
return order;
}

@Override
public ChatClientResponse adviseCall(ChatClientRequest chatClientRequest, CallAdvisorChain callAdvisorChain) {
System.out.println(callAdvisorChain.hasNextCallAdvisor());
return callAdvisorChain.nextCall(chatClientRequest);
}

@Override
public Flux<ChatClientResponse> adviseStream(ChatClientRequest chatClientRequest,
StreamAdvisorChain streamAdvisorChain) {
System.out.println(streamAdvisorChain.hasNextStreamAdvisor());
return streamAdvisorChain.nextStream(chatClientRequest);
}

}

}