Skip to content

Commit f597aee

Browse files
committed
Feat(rerank): Add Cohere Reranker support with topN result filtering
- Implemented rerank logic using Cohere Rerank API with WebClient Signed-off-by: KoreaNirsa <[email protected]>
1 parent 6a1f982 commit f597aee

File tree

5 files changed

+279
-0
lines changed

5 files changed

+279
-0
lines changed
Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,34 @@
1+
package org.springframework.ai.rag.postretrieval.rerank;
2+
3+
/**
4+
* Represents the API key holder for Cohere API authentication.
5+
*
6+
* @author KoreaNirsa
7+
*/
8+
public class CohereApi {
9+
private String apiKey;
10+
11+
public static Builder builder() {
12+
return new Builder();
13+
}
14+
15+
public String getApiKey() {
16+
return apiKey;
17+
}
18+
19+
public static class Builder {
20+
private final CohereApi instance = new CohereApi();
21+
22+
public Builder apiKey(String key) {
23+
instance.apiKey = key;
24+
return this;
25+
}
26+
27+
public CohereApi build() {
28+
if (instance.apiKey == null || instance.apiKey.isBlank()) {
29+
throw new IllegalArgumentException("API key must be provided.");
30+
}
31+
return instance;
32+
}
33+
}
34+
}
Lines changed: 113 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,113 @@
1+
package org.springframework.ai.rag.postretrieval.rerank;
2+
3+
import java.util.Collections;
4+
import java.util.Comparator;
5+
import java.util.List;
6+
import java.util.Map;
7+
import java.util.Optional;
8+
9+
import org.slf4j.Logger;
10+
import org.slf4j.LoggerFactory;
11+
import org.springframework.ai.document.Document;
12+
import org.springframework.http.HttpHeaders;
13+
import org.springframework.web.reactive.function.client.WebClient;
14+
15+
/**
16+
* A Reranker implementation that integrates with Cohere's Rerank API.
17+
* This component reorders retrieved documents based on semantic relevance to the input query.
18+
*
19+
* @author KoreaNirsa
20+
* @see <a href="https://docs.cohere.com/reference/rerank">Cohere Rerank API Documentation</a>
21+
*/
22+
public class CohereReranker {
23+
private static final String COHERE_RERANK_ENDPOINT = "https://api.cohere.ai/v1/rerank";
24+
25+
private static final Logger logger = LoggerFactory.getLogger(CohereReranker.class);
26+
27+
private static final int MAX_DOCUMENTS = 1000;
28+
29+
private final WebClient webClient;
30+
31+
/**
32+
* Constructs a CohereReranker that communicates with the Cohere Rerank API.
33+
* Initializes the internal WebClient with the provided API key for authorization.
34+
*
35+
* @param cohereApi the API configuration object containing the required API key (must not be null)
36+
* @throws IllegalArgumentException if cohereApi is null
37+
*/
38+
CohereReranker(CohereApi cohereApi) {
39+
if (cohereApi == null) {
40+
throw new IllegalArgumentException("CohereApi must not be null");
41+
}
42+
43+
this.webClient = WebClient.builder()
44+
.baseUrl(COHERE_RERANK_ENDPOINT)
45+
.defaultHeader(HttpHeaders.AUTHORIZATION, "Bearer " + cohereApi.getApiKey())
46+
.build();
47+
}
48+
49+
/**
50+
* Reranks a list of documents based on the provided query using the Cohere API.
51+
*
52+
* @param query The user input query.
53+
* @param documents The list of documents to rerank.
54+
* @param topN The number of top results to return (at most).
55+
* @return A reranked list of documents. If the API fails, returns the original list.
56+
*/
57+
public List<Document> rerank(String query, List<Document> documents, int topN) {
58+
if (topN < 1) {
59+
throw new IllegalArgumentException("topN must be ≥ 1. Provided: " + topN);
60+
}
61+
62+
if (documents == null || documents.isEmpty()) {
63+
logger.warn("Empty document list provided. Skipping rerank.");
64+
return Collections.emptyList();
65+
}
66+
67+
if (documents.size() > MAX_DOCUMENTS) {
68+
logger.warn("Cohere recommends ≤ {} documents per rerank request. Larger sizes may cause errors.", MAX_DOCUMENTS);
69+
return documents;
70+
}
71+
72+
int adjustedTopN = Math.min(topN, documents.size());
73+
74+
Map<String, Object> payload = Map.of(
75+
"query", query,
76+
"documents", documents.stream().map(Document::getText).toList(),
77+
"top_n", adjustedTopN
78+
);
79+
80+
// Call the API and process the result
81+
return sendRerankRequest(payload)
82+
.map(results -> results.stream()
83+
.sorted(Comparator.comparingDouble(RerankResponse.Result::getRelevanceScore).reversed())
84+
.map(r -> documents.get(r.getIndex()))
85+
.toList())
86+
.orElseGet(() -> {
87+
logger.warn("Cohere response is null or invalid");
88+
return documents;
89+
});
90+
}
91+
92+
/**
93+
* Sends a rerank request to the Cohere API and returns the result list.
94+
*
95+
* @param payload The request body including query, documents, and top_n.
96+
* @return An Optional list of reranked results, or empty if failed.
97+
*/
98+
private Optional<List<RerankResponse.Result>> sendRerankRequest(Map<String, Object> payload) {
99+
try {
100+
RerankResponse response = webClient.post()
101+
.bodyValue(payload)
102+
.retrieve()
103+
.bodyToMono(RerankResponse.class)
104+
.block();
105+
106+
return Optional.ofNullable(response)
107+
.map(RerankResponse::getResults);
108+
} catch (Exception e) {
109+
logger.error("Cohere rerank failed, fallback to original order: {}", e.getMessage(), e);
110+
return Optional.empty();
111+
}
112+
}
113+
}
Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,32 @@
1+
package org.springframework.ai.rag.postretrieval.rerank;
2+
3+
import org.springframework.ai.rag.postretrieval.document.DocumentPostProcessor;
4+
import org.springframework.beans.factory.annotation.Value;
5+
import org.springframework.boot.autoconfigure.condition.ConditionalOnProperty;
6+
import org.springframework.context.annotation.Bean;
7+
import org.springframework.context.annotation.Configuration;
8+
9+
/**
10+
* Rerank configuration that conditionally registers a DocumentPostProcessor
11+
* when rerank is enabled via application properties.
12+
*
13+
* This configuration is activated only when the following properties are set
14+
*
15+
* <ul>
16+
* <li>spring.ai.rerank.enabled=true</li>
17+
* <li>spring.ai.rerank.cohere.api-key=your-api-key</li>
18+
* </ul>
19+
*
20+
* @author KoreaNirsa
21+
*/
22+
@Configuration
23+
public class RerankConfig {
24+
@Value("${spring.ai.rerank.cohere.api-key}")
25+
private String apiKey;
26+
27+
@Bean
28+
@ConditionalOnProperty(name = "spring.ai.rerank.enabled", havingValue = "true")
29+
public DocumentPostProcessor rerankerPostProcessor() {
30+
return new RerankerPostProcessor(CohereApi.builder().apiKey(apiKey).build());
31+
}
32+
}
Lines changed: 51 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,51 @@
1+
package org.springframework.ai.rag.postretrieval.rerank;
2+
3+
import java.util.List;
4+
5+
import com.fasterxml.jackson.annotation.JsonProperty;
6+
7+
/**
8+
* Represents the response returned from Cohere's Rerank API.
9+
* The response includes a list of result objects that specify document indices
10+
* and their semantic relevance scores.
11+
*
12+
* @author KoreaNirsa
13+
*/
14+
public class RerankResponse {
15+
private List<Result> results;
16+
17+
public List<Result> getResults() {
18+
return results;
19+
}
20+
21+
public void setResults(List<Result> results) {
22+
this.results = results;
23+
}
24+
25+
/**
26+
* Represents a single reranked document result returned by the Cohere API.
27+
* Contains the original index and the computed relevance score.
28+
*/
29+
public static class Result {
30+
private int index;
31+
32+
@JsonProperty("relevance_score")
33+
private int relevanceScore;
34+
35+
public int getIndex() {
36+
return index;
37+
}
38+
39+
public void setIndex(int index) {
40+
this.index = index;
41+
}
42+
43+
public int getRelevanceScore() {
44+
return relevanceScore;
45+
}
46+
47+
public void setRelevanceScore(int relevanceScore) {
48+
this.relevanceScore = relevanceScore;
49+
}
50+
}
51+
}
Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,49 @@
1+
package org.springframework.ai.rag.postretrieval.rerank;
2+
3+
import java.util.List;
4+
5+
import org.springframework.ai.document.Document;
6+
import org.springframework.ai.rag.Query;
7+
import org.springframework.ai.rag.postretrieval.document.DocumentPostProcessor;
8+
9+
/**
10+
* The only supported entrypoint for rerank functionality in Spring AI RAG.
11+
* This component delegates reranking logic to CohereReranker, using the provided API key.
12+
*
13+
* This class is registered as a DocumentPostProcessor bean only if
14+
* spring.ai.rerank.enabled=true is set in the application properties.
15+
*
16+
* @author KoreaNirsa
17+
*/
18+
public class RerankerPostProcessor implements DocumentPostProcessor {
19+
private final CohereReranker reranker;
20+
21+
RerankerPostProcessor(CohereApi cohereApi) {
22+
this.reranker = new CohereReranker(cohereApi);
23+
}
24+
25+
/**
26+
* Processes the retrieved documents by applying semantic reranking using the Cohere API
27+
*
28+
* @param query the user's input query
29+
* @param documents the list of documents to be reranked
30+
* @return a list of documents sorted by relevance score
31+
*/
32+
@Override
33+
public List<Document> process(Query query, List<Document> documents) {
34+
int topN = extractTopN(query);
35+
return reranker.rerank(query.text(), documents, topN);
36+
}
37+
38+
/**
39+
* Extracts the top-N value from the query context.
40+
* If not present or invalid, it defaults to 3
41+
*
42+
* @param query the query containing optional context parameters
43+
* @return the number of top documents to return
44+
*/
45+
private int extractTopN(Query query) {
46+
Object value = query.context().get("topN");
47+
return (value instanceof Number num) ? num.intValue() : 3;
48+
}
49+
}

0 commit comments

Comments
 (0)