Skip to content

Commit b88ee1b

Browse files
committed
feat: add rag
1 parent 91aab3e commit b88ee1b

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

100 files changed

+18416
-0
lines changed
276 Bytes
Binary file not shown.
Lines changed: 52 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,52 @@
1+
package com.tinyengine.it.controller;
2+
3+
import com.tinyengine.it.common.base.Result;
4+
import com.tinyengine.it.common.log.SystemControllerLog;
5+
import com.tinyengine.it.rag.VectorStorageService;
6+
import com.tinyengine.it.rag.entity.VectorDocument;
7+
import io.swagger.v3.oas.annotations.Operation;
8+
import io.swagger.v3.oas.annotations.Parameter;
9+
import io.swagger.v3.oas.annotations.media.Content;
10+
import io.swagger.v3.oas.annotations.media.Schema;
11+
import io.swagger.v3.oas.annotations.responses.ApiResponse;
12+
import io.swagger.v3.oas.annotations.tags.Tag;
13+
import org.springframework.validation.annotation.Validated;
14+
import org.springframework.web.bind.annotation.PostMapping;
15+
import org.springframework.web.bind.annotation.RequestBody;
16+
import org.springframework.web.bind.annotation.RequestMapping;
17+
import org.springframework.web.bind.annotation.RestController;
18+
19+
import java.util.List;
20+
21+
/**
22+
* The type vector storage controller.
23+
*
24+
* @since 2025-9-25
25+
*/
26+
@Validated
27+
@RestController
28+
@RequestMapping("/app-center/api")
29+
@Tag(name = "VectorStorage")
30+
public class VectorStorageController {
31+
32+
/**
33+
* file storage
34+
*
35+
* @param filePath the filePath
36+
* @return ai回答信息 result
37+
*/
38+
@Operation(summary = "文件向量存储", description = "文件向量存储",
39+
parameters = {
40+
@Parameter(name = "filePath", description = "入参对象")
41+
}, responses = {
42+
@ApiResponse(responseCode = "200", description = "返回信息",
43+
content = @Content(mediaType = "application/json", schema = @Schema())),
44+
@ApiResponse(responseCode = "400", description = "请求失败")
45+
})
46+
@SystemControllerLog(description = "文件向量存储")
47+
@PostMapping("/vector-storage/create")
48+
public Result<VectorDocument> create(@RequestBody List<String> filePath) {
49+
VectorDocument vectorDocument = VectorStorageService.initializeKnowledgeBase(filePath);
50+
return Result.success(vectorDocument);
51+
}
52+
}
Lines changed: 337 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,337 @@
1+
package com.tinyengine.it.rag;
2+
3+
import com.tinyengine.it.common.exception.ExceptionEnum;
4+
import com.tinyengine.it.common.exception.ServiceException;
5+
import com.tinyengine.it.rag.entity.RAGConfig;
6+
import com.tinyengine.it.rag.entity.VectorDocument;
7+
import dev.langchain4j.data.document.Document;
8+
import dev.langchain4j.data.document.DocumentSplitter;
9+
import dev.langchain4j.data.document.loader.FileSystemDocumentLoader;
10+
import dev.langchain4j.data.document.parser.TextDocumentParser;
11+
import dev.langchain4j.data.document.splitter.DocumentSplitters;
12+
import dev.langchain4j.data.embedding.Embedding;
13+
import dev.langchain4j.data.segment.TextSegment;
14+
import dev.langchain4j.model.embedding.EmbeddingModel;
15+
import dev.langchain4j.store.embedding.EmbeddingMatch;
16+
import dev.langchain4j.store.embedding.EmbeddingSearchRequest;
17+
import dev.langchain4j.store.embedding.EmbeddingStore;
18+
import lombok.extern.slf4j.Slf4j;
19+
20+
import java.nio.file.Path;
21+
import java.nio.file.Paths;
22+
import java.util.ArrayList;
23+
import java.util.List;
24+
import java.util.concurrent.ExecutorService;
25+
import java.util.concurrent.Executors;
26+
27+
@Slf4j
28+
public class VectorStorageService {
29+
private static EmbeddingModel embeddingModel = null;
30+
private static EmbeddingStore<TextSegment> embeddingStore = null;
31+
private final ExecutorService executorService;
32+
33+
/**
 * Creates the service and registers the embedding model and embedding store.
 *
 * <p>The model and the store are kept in static fields, so constructing a new instance
 * replaces the model and store used by the static methods of this class as well.
 *
 * @param embeddingModel the model used to turn text into embeddings
 * @param embeddingStore the store that receives and searches the embeddings
 */
public VectorStorageService(EmbeddingModel embeddingModel, EmbeddingStore<TextSegment> embeddingStore) {
    // One worker thread per available processor.
    int poolSize = Runtime.getRuntime().availableProcessors();
    this.executorService = Executors.newFixedThreadPool(poolSize);

    VectorStorageService.embeddingModel = embeddingModel;
    VectorStorageService.embeddingStore = embeddingStore;
}
41+
42+
/**
43+
* 添加文档到知识库
44+
*/
45+
public static VectorDocument initializeKnowledgeBase(List<String> documentPaths) {
46+
return initializeKnowledgeBase(documentPaths, null);
47+
}
48+
49+
/**
50+
* 添加文档到知识库(支持自定义元数据)
51+
*/
52+
public static VectorDocument initializeKnowledgeBase(List<String> documentPaths, String documentSetId) {
53+
try {
54+
List<Document> documents = loadDocuments(documentPaths, documentSetId);
55+
56+
if (documents.isEmpty()) {
57+
throw new ServiceException(ExceptionEnum.CM001.getResultCode(), "未成功加载任何文档");
58+
}
59+
60+
log.info("成功加载 {} 个文档", documents.size());
61+
62+
// 文档切分
63+
List<TextSegment> segments = splitDocuments(documents);
64+
log.info("生成 {} 个文本段", segments.size());
65+
66+
// 向量化并存储
67+
return embedAndStore(segments);
68+
69+
} catch (ServiceException e) {
70+
throw e;
71+
} catch (Exception e) {
72+
log.error("文档添加到知识库失败", e);
73+
throw new ServiceException(ExceptionEnum.CM001.getResultCode(), "文档处理失败: " + e.getMessage());
74+
}
75+
}
76+
77+
/**
78+
* 加载文档(支持元数据)
79+
*/
80+
private static List<Document> loadDocuments(List<String> documentPaths, String documentSetId) {
81+
List<Document> documents = new ArrayList<>();
82+
83+
for (String path : documentPaths) {
84+
try {
85+
Path filePath = Paths.get(path);
86+
Document document;
87+
88+
if (path.toLowerCase().endsWith(".pdf")) {
89+
document = FileSystemDocumentLoader.loadDocument(filePath);
90+
} else if (path.toLowerCase().endsWith(".txt") || path.toLowerCase().endsWith(".md")) {
91+
document = FileSystemDocumentLoader.loadDocument(filePath, new TextDocumentParser());
92+
} else {
93+
log.warn("不支持的文档格式: {}", path);
94+
continue;
95+
}
96+
97+
// 添加元数据
98+
if (documentSetId != null) {
99+
document.metadata().put("documentSetId", documentSetId);
100+
}
101+
document.metadata().put("source", path);
102+
document.metadata().put("timestamp", String.valueOf(System.currentTimeMillis()));
103+
104+
documents.add(document);
105+
log.info("✓ 加载文档: {}", path);
106+
107+
} catch (Exception e) {
108+
log.error("✗ 加载文档失败: {} - {}", path, e.getMessage());
109+
}
110+
}
111+
112+
return documents;
113+
}
114+
115+
/**
116+
* 文档切分
117+
*/
118+
private static List<TextSegment> splitDocuments(List<Document> documents) {
119+
DocumentSplitter splitter = DocumentSplitters.recursive(
120+
RAGConfig.CHUNK_SIZE,
121+
RAGConfig.CHUNK_OVERLAP
122+
);
123+
return splitter.splitAll(documents);
124+
}
125+
126+
/**
127+
* 向量化并存储(优化性能版本)
128+
*/
129+
private static VectorDocument embedAndStore(List<TextSegment> segments) {
130+
log.info("开始向量化存储...");
131+
long startTime = System.currentTimeMillis();
132+
133+
int successCount = 0;
134+
int errorCount = 0;
135+
136+
// 批量处理,提高性能
137+
int batchSize = 50;
138+
for (int i = 0; i < segments.size(); i += batchSize) {
139+
int end = Math.min(i + batchSize, segments.size());
140+
List<TextSegment> batch = segments.subList(i, end);
141+
142+
BatchResult result = processBatch(batch, i, segments.size());
143+
successCount += result.successCount;
144+
errorCount += result.errorCount;
145+
}
146+
147+
long endTime = System.currentTimeMillis();
148+
log.info("向量化完成: {} 成功, {} 失败, 耗时: {}ms", successCount, errorCount, (endTime - startTime));
149+
150+
return new VectorDocument(successCount, errorCount);
151+
}
152+
153+
/**
154+
* 处理批次数据的内部类
155+
*/
156+
private static class BatchResult {
157+
int successCount;
158+
int errorCount;
159+
160+
BatchResult(int successCount, int errorCount) {
161+
this.successCount = successCount;
162+
this.errorCount = errorCount;
163+
}
164+
}
165+
166+
/**
167+
* 处理批次数据
168+
*/
169+
private static BatchResult processBatch(List<TextSegment> batch, int startIndex, int totalSize) {
170+
int successCount = 0;
171+
int errorCount = 0;
172+
173+
List<Embedding> embeddings = new ArrayList<>();
174+
List<TextSegment> segmentsToStore = new ArrayList<>();
175+
176+
for (int i = 0; i < batch.size(); i++) {
177+
TextSegment segment = batch.get(i);
178+
try {
179+
Embedding embedding = embeddingModel.embed(segment.text()).content();
180+
embeddings.add(embedding);
181+
segmentsToStore.add(segment);
182+
successCount++;
183+
184+
if ((startIndex + i + 1) % 100 == 0) {
185+
log.info("已处理 {}/{} 个文本段", (startIndex + i + 1), totalSize);
186+
}
187+
} catch (Exception e) {
188+
errorCount++;
189+
log.error("向量化失败 [{}]: {}", (startIndex + i + 1),
190+
segment.text().substring(0, Math.min(100, segment.text().length())));
191+
}
192+
}
193+
194+
// 批量存储到 Chroma
195+
if (!embeddings.isEmpty()) {
196+
try {
197+
// 修正:使用正确的批量添加方法
198+
for (int i = 0; i < embeddings.size(); i++) {
199+
embeddingStore.add(embeddings.get(i), segmentsToStore.get(i));
200+
}
201+
log.debug("成功存储 {} 个文本段到 Chroma", embeddings.size());
202+
} catch (Exception e) {
203+
log.error("批量存储到 Chroma 失败", e);
204+
errorCount += embeddings.size(); // 标记为失败
205+
successCount -= embeddings.size();
206+
}
207+
}
208+
209+
return new BatchResult(successCount, errorCount);
210+
}
211+
212+
/**
213+
* 向量库检索
214+
*/
215+
public List<EmbeddingMatch<TextSegment>> search(String query, int maxResults, double minScore) {
216+
return search(query, maxResults, minScore, null);
217+
}
218+
219+
/**
220+
* 带过滤条件的检索 - 修正版本
221+
*/
222+
public List<EmbeddingMatch<TextSegment>> search(String query, int maxResults, double minScore, String documentSetId) {
223+
try {
224+
Embedding queryEmbedding = embeddingModel.embed(query).content();
225+
226+
// 修正:使用正确的搜索请求构建方式
227+
EmbeddingSearchRequest searchRequest = EmbeddingSearchRequest.builder()
228+
.queryEmbedding(queryEmbedding)
229+
.maxResults(maxResults)
230+
.minScore(minScore)
231+
.build();
232+
233+
List<EmbeddingMatch<TextSegment>> results = embeddingStore.search(searchRequest).matches();
234+
235+
// 如果指定了文档集ID,进行过滤
236+
if (documentSetId != null) {
237+
results = filterByDocumentSetId(results, documentSetId);
238+
}
239+
240+
log.info("检索到 {} 个相关文档", results.size());
241+
return results;
242+
243+
} catch (Exception e) {
244+
log.error("检索失败", e);
245+
throw new ServiceException(ExceptionEnum.CM001.getResultCode(), "检索失败: " + e.getMessage());
246+
}
247+
}
248+
249+
/**
250+
* 根据文档集ID过滤结果
251+
*/
252+
private List<EmbeddingMatch<TextSegment>> filterByDocumentSetId(
253+
List<EmbeddingMatch<TextSegment>> results, String documentSetId) {
254+
255+
List<EmbeddingMatch<TextSegment>> filteredResults = new ArrayList<>();
256+
257+
for (EmbeddingMatch<TextSegment> match : results) {
258+
String docSetId = match.embedded().metadata().getString("documentSetId");
259+
if (documentSetId.equals(docSetId)) {
260+
filteredResults.add(match);
261+
}
262+
}
263+
264+
return filteredResults;
265+
}
266+
267+
/**
268+
* 完整的问答流程
269+
*/
270+
public List<EmbeddingMatch<TextSegment>> askQuestion(String question) {
271+
return askQuestion(question, RAGConfig.MAX_RESULTS, RAGConfig.MIN_SCORE, null);
272+
}
273+
274+
public List<EmbeddingMatch<TextSegment>> askQuestion(String question, int maxResults, double minScore, String documentSetId) {
275+
try {
276+
long startTime = System.currentTimeMillis();
277+
278+
// 1. 检索相关文档
279+
List<EmbeddingMatch<TextSegment>> searchResults = search(question, maxResults, minScore, documentSetId);
280+
long retrievalTime = System.currentTimeMillis() - startTime;
281+
282+
log.info("检索耗时: {}ms", retrievalTime);
283+
284+
if (searchResults.isEmpty()) {
285+
log.warn("未找到相关文档");
286+
return searchResults;
287+
}
288+
289+
// 打印检索结果
290+
for (int i = 0; i < Math.min(3, searchResults.size()); i++) {
291+
EmbeddingMatch<TextSegment> match = searchResults.get(i);
292+
log.debug("结果 {} - 相似度: {:.4f}", i + 1, match.score());
293+
}
294+
295+
return searchResults;
296+
297+
} catch (Exception e) {
298+
log.error("问答流程失败", e);
299+
throw new ServiceException(ExceptionEnum.CM001.getResultCode(), "问答失败: " + e.getMessage());
300+
}
301+
}
302+
303+
/**
304+
* 清空向量库
305+
*/
306+
public void clearVectorStore() {
307+
try {
308+
309+
// 在 0.29.0 版本中,可能需要通过其他方式清空
310+
log.info("请通过 Chroma API 清空向量库数据");
311+
} catch (Exception e) {
312+
log.error("清空向量库失败", e);
313+
throw new ServiceException(ExceptionEnum.CM001.getResultCode(), "清空向量库失败");
314+
}
315+
}
316+
317+
/**
318+
* 关闭资源
319+
*/
320+
public void shutdown() {
321+
executorService.shutdown();
322+
log.info("VectorStorageService 已关闭");
323+
}
324+
325+
/**
326+
* 获取向量库统计信息
327+
*/
328+
public void getVectorStoreStats() {
329+
try {
330+
log.info("向量库服务运行中 - 模型: {}, 存储: {}",
331+
embeddingModel.getClass().getSimpleName(),
332+
embeddingStore.getClass().getSimpleName());
333+
} catch (Exception e) {
334+
log.error("获取向量库统计信息失败", e);
335+
}
336+
}
337+
}

0 commit comments

Comments
 (0)