Skip to content

Commit 8a0d021

Browse files
moyu026 and liudengjin authored
add new rag method (#2204)
Co-authored-by: liudengjin <[email protected]>
1 parent 4e8cc09 commit 8a0d021

File tree

4 files changed

+193
-7
lines changed

4 files changed

+193
-7
lines changed

examples/transformers/rag/embedding.py

Lines changed: 12 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -20,18 +20,25 @@
2020

2121
from langchain.embeddings.base import Embeddings
2222

23-
from mindnlp.sentence import SentenceTransformer
23+
from sentence_transformers import SentenceTransformer
2424

2525

2626
class EmbeddingsFunAdapter(Embeddings):
    """Adapt a sentence-transformers model to the langchain ``Embeddings`` interface.

    Wraps a :class:`SentenceTransformer` so langchain vector stores (e.g. FAISS)
    can embed documents and queries with it.
    """

    def __init__(self, embed_model):
        # Model name or local path passed straight to SentenceTransformer.
        self.embed_model = embed_model
        self.embedding_model = SentenceTransformer(model_name_or_path=self.embed_model)

    def encode_texts(self, texts: List[str]) -> List[List[float]]:
        """Embed *texts* and return plain Python lists of floats.

        Newlines are replaced with spaces first, as embedding models are
        typically trained on single-line inputs.
        """
        texts = [t.replace("\n", " ") for t in texts]
        embeddings = self.embedding_model.encode(texts)
        # Fix: the original assigned `embedding.tolist()` back into the
        # container returned by encode(); when that container is a 2-D numpy
        # array the assignment coerces the list back into an ndarray row, so
        # the function returned an ndarray instead of the declared
        # List[List[float]]. Build a new list instead — works whether
        # encode() returns an ndarray or a list of arrays.
        return [embedding.tolist() for embedding in embeddings]

    def embed_documents(self, texts: List[str]) -> List[List[float]]:
        """Embed a batch of documents (langchain Embeddings API)."""
        embeddings = self.encode_texts(texts)
        return embeddings

    def embed_query(self, text: str) -> List[float]:
        """Embed a single query string (langchain Embeddings API)."""
        embeddings = self.encode_texts([text])
        return embeddings[0]
Lines changed: 172 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,172 @@
1+
# Copyright 2024 Huawei Technologies Co., Ltd
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
# ============================================================================
15+
16+
17+
import argparse
18+
from langchain_community.vectorstores import FAISS
19+
from langchain_text_splitters import CharacterTextSplitter
20+
21+
import mindnlp
22+
from embedding import EmbeddingsFunAdapter
23+
from text import TextLoader
24+
from threading import Thread
25+
26+
import mindspore
27+
from transformers import AutoTokenizer, AutoModelForCausalLM, TextIteratorStreamer
28+
29+
def load_knowledge_base(file_name):
    """Build a FAISS vector store from a plain-text knowledge file.

    The file is loaded, split on newlines into <=256-character chunks, and
    indexed with a Qwen3 embedding model.

    Args:
        file_name: path to the knowledge-base text file.

    Returns:
        A populated ``FAISS`` vector store ready for similarity search.
    """
    print(f"正在加载知识库文件: {file_name}")
    raw_text = TextLoader(file_name).load()

    splitter = CharacterTextSplitter(separator='\n', chunk_size=256, chunk_overlap=0)
    chunks = splitter.split_text(raw_text)
    print(f"文档已切分为 {len(chunks)} 个片段")

    embedder = EmbeddingsFunAdapter("Qwen/Qwen3-Embedding-0.6B")
    store = FAISS.from_texts(chunks, embedder)
    print("FAISS 向量数据库构建完成。")
    return store
41+
42+
43+
def load_model_and_tokenizer():
    """Load the DeepSeek-R1-Distill-Qwen-1.5B chat model and its tokenizer.

    Both are fetched from the ModelScope mirror; the model is loaded in
    bfloat16 on device 0.

    Returns:
        ``(tokenizer, model)`` tuple.
    """
    print("正在加载模型")
    model_id = 'deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B'
    tokenizer = AutoTokenizer.from_pretrained(
        model_id, use_fast=False, mirror='modelscope', trust_remote_code=True
    )
    model = AutoModelForCausalLM.from_pretrained(
        model_id, ms_dtype=mindspore.bfloat16, mirror='modelscope', device_map=0
    )
    print("模型加载完成。")
    return tokenizer, model
51+
52+
53+
def retrieve_knowledge(faiss, query):
    """Return the text of the single most similar knowledge chunk for *query*."""
    top_hit = faiss.similarity_search(query, k=1)[0]
    return top_hit.page_content
56+
57+
def generate_answer(tokenizer, model, query, knowledge=None):
    """Generate a streamed chat answer for *query*, optionally grounded in *knowledge*.

    When *knowledge* is truthy it is prepended to the query, the combined text
    is formatted through the tokenizer's chat template, and generation runs in
    a background thread while decoded tokens are printed to stdout as they
    arrive.

    Args:
        tokenizer: chat tokenizer (must support ``apply_chat_template``).
        model: causal LM exposing ``.generate()`` and ``.device``.
        query: the user's question.
        knowledge: optional retrieved context; None/"" means plain chat.

    Returns:
        The complete generated text, stripped of surrounding whitespace.
    """
    if knowledge:
        input_text = knowledge + "\n\n" + query
    else:
        input_text = query

    messages = [
        {"role": "user", "content": input_text}
    ]

    # Build the prompt via tokenizer.apply_chat_template
    try:
        prompt = tokenizer.apply_chat_template(
            messages,
            tokenize=False,
            add_generation_prompt=True
        )
    except Exception as e:
        # Fall back to hand-built Qwen-style chat markup when the tokenizer
        # has no usable chat template.
        print(f"⚠️ apply_chat_template 失败,使用手动拼接: {e}")
        prompt = f"<|im_start|>user\n{input_text}<|im_end|>\n<|im_start|>assistant\n"

    # Tokenize; truncate overlong prompts to an 8192-token budget
    inputs = tokenizer(prompt, return_tensors="pt", truncation=True, max_length=8192).to(model.device)

    # Streamer yields decoded text chunks as generation progresses
    streamer = TextIteratorStreamer(
        tokenizer,
        skip_prompt=True,  # do not replay the input prompt
        skip_special_tokens=True  # drop special tokens from the output
    )

    # Run generation in a background thread so this thread can consume the
    # streamer; the for-loop below ends when generation finishes.
    def generate():
        model.generate(
            **inputs,
            streamer=streamer,
            max_new_tokens=512,
            temperature=0.001,  # near-greedy despite do_sample=True
            top_p=0.9,
            do_sample=True,
            pad_token_id=tokenizer.eos_token_id
        )

    thread = Thread(target=generate)
    thread.start()

    # Echo tokens live while accumulating the full answer
    print("回答: ", end="", flush=True)
    generated_text = ""
    for new_text in streamer:
        print(new_text, end="", flush=True)
        generated_text += new_text
    print()  # newline after the streamed answer

    return generated_text.strip()
112+
113+
114+
115+
def rag_pipeline(faiss, tokenizer, model, query, use_rag=True):
    """Answer *query*, optionally retrieval-augmented.

    Args:
        faiss: vector store used for retrieval when ``use_rag`` is True.
        tokenizer, model: passed through to :func:`generate_answer`.
        query: the user's question.
        use_rag: when True, retrieve the top chunk and ground the answer on it.

    Returns:
        ``(answer, knowledge)`` — *knowledge* is ``""`` when retrieval is off.
    """
    if not use_rag:
        return generate_answer(tokenizer, model, query, ""), ""

    context = retrieve_knowledge(faiss, query)
    return generate_answer(tokenizer, model, query, context), context
123+
124+
125+
def main():
    """Command-line REPL: build the knowledge base, load the model, then
    answer user questions interactively, with optional retrieval augmentation.
    """
    parser = argparse.ArgumentParser(description="RAG Demo - Command Line Version")
    parser.add_argument("filename", help="知识库文本文件路径")
    args = parser.parse_args()

    # Load the knowledge base and the chat model once, up front
    faiss_db = load_knowledge_base(args.filename)
    tokenizer, model = load_model_and_tokenizer()

    print("\n" + "=" * 60)
    print("RAG系统已准备就绪!")
    print("输入 'quit' 或 'exit' 退出程序。")
    print("=" * 60)

    while True:
        try:
            # Get the user's question
            query = input("\n请输入您的问题: ").strip()
            if query.lower() in ['quit', 'exit', 'bye']:
                print("再见!")
                break
            if not query:
                print("问题不能为空,请重新输入。")
                continue

            # Ask whether to use retrieval augmentation. The input is already
            # lowercased, so the original's extra 'N'/'NO' entries were dead.
            use_rag_input = input("是否启用检索增强 (RAG)? [Y/n]: ").strip().lower()
            use_rag = use_rag_input not in ('n', 'no')

            if use_rag:
                print("正在检索知识库...")
                knowledge = retrieve_knowledge(faiss_db, query)
                print(f"检索到的知识:\n{knowledge}")
                answer = generate_answer(tokenizer, model, query, knowledge)
            else:
                print("直接生成回答(无检索)...")
                answer = generate_answer(tokenizer, model, query)

        # Fix: input() raises EOFError when stdin closes (Ctrl-D / piped
        # input exhausted); the original only caught KeyboardInterrupt and
        # crashed with a traceback on EOF. Exit cleanly in both cases.
        except (KeyboardInterrupt, EOFError):
            print("\n\n程序被用户中断,再见!")
            break


if __name__ == "__main__":
    main()

examples/transformers/rag/readme.md

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
#### Install dependencies
44

55
```
6-
pip install mindnlp langchain langchain-community faiss-cpu
6+
pip install -r requirements.txt
77
```
88

99
### Download knowledge file
@@ -16,5 +16,5 @@ wget https://raw.githubusercontent.com/limchiahooi/nlp-chinese/master/%E8%A5%BF%
1616
### Run RAG Demo
1717

1818
```
19-
streamlit run startup.py xiyouji.txt
19+
python newchat.py xiyouji.txt
2020
```
Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,7 @@
1+
protobuf==3.20.3
2+
streamlit
3+
langchain
4+
langchain-community
5+
faiss-cpu
6+
transformers==4.55.4
7+
sentence_transformers

0 commit comments

Comments
 (0)