# Copyright 2024 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================

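"""Command-line RAG demo.

Loads a plain-text knowledge base into a FAISS vector store, retrieves the
most relevant chunk for each user question, and streams an answer from
DeepSeek-R1-Distill-Qwen-1.5B running on MindSpore via mindnlp.

Usage (the script name below is illustrative, not fixed by this repo):
    python rag_cli.py path/to/knowledge_base.txt
"""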

import argparse
from threading import Thread

import mindspore
import mindnlp  # imported for its MindSpore integration with transformers (enables the ms_dtype / mirror kwargs below)
from transformers import AutoTokenizer, AutoModelForCausalLM, TextIteratorStreamer

from langchain_community.vectorstores import FAISS
from langchain_text_splitters import CharacterTextSplitter

from embedding import EmbeddingsFunAdapter
from text import TextLoader


def load_knowledge_base(file_name):
    """Load a text file, split it into chunks, and index them in a FAISS vector store."""
    print(f"Loading knowledge base file: {file_name}")
    loader = TextLoader(file_name)
    texts = loader.load()
    text_splitter = CharacterTextSplitter(separator='\n', chunk_size=256, chunk_overlap=0)
    split_docs = text_splitter.split_text(texts)
    print(f"Document split into {len(split_docs)} chunks")

    embeddings = EmbeddingsFunAdapter("Qwen/Qwen3-Embedding-0.6B")
    faiss = FAISS.from_texts(split_docs, embeddings)
    print("FAISS vector store built.")
    return faiss


def load_model_and_tokenizer():
    """Load the DeepSeek-R1-Distill-Qwen-1.5B tokenizer and model."""
    print("Loading model...")
    tokenizer = AutoTokenizer.from_pretrained(
        'deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B',
        use_fast=False, mirror='modelscope', trust_remote_code=True)
    model = AutoModelForCausalLM.from_pretrained(
        'deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B',
        ms_dtype=mindspore.bfloat16, mirror='modelscope', device_map=0)
    print("Model loaded.")
    return tokenizer, model


def retrieve_knowledge(faiss, query):
    """Return the page content of the single most similar chunk."""
    docs = faiss.similarity_search(query, k=1)
    return docs[0].page_content


def generate_answer(tokenizer, model, query, knowledge=None):
    """Generate an answer with streaming output, optionally grounded on retrieved knowledge."""
    if knowledge:
        input_text = knowledge + "\n\n" + query
    else:
        input_text = query

    messages = [
        {"role": "user", "content": input_text}
    ]

    # Build the prompt with tokenizer.apply_chat_template
    try:
        prompt = tokenizer.apply_chat_template(
            messages,
            tokenize=False,
            add_generation_prompt=True
        )
    except Exception as e:
        print(f"⚠️ apply_chat_template failed, falling back to manual prompt construction: {e}")
        prompt = f"<|im_start|>user\n{input_text}<|im_end|>\n<|im_start|>assistant\n"

    # Tokenize
    inputs = tokenizer(prompt, return_tensors="pt", truncation=True, max_length=8192).to(model.device)

    # Create the streamer
    streamer = TextIteratorStreamer(
        tokenizer,
        skip_prompt=True,          # do not echo the prompt
        skip_special_tokens=True   # drop special tokens from the output
    )

    # Run generation in a background thread so the streamer can be consumed here
    def generate():
        model.generate(
            **inputs,
            streamer=streamer,
            max_new_tokens=512,
            temperature=0.001,
            top_p=0.9,
            do_sample=True,
            pad_token_id=tokenizer.eos_token_id
        )

    thread = Thread(target=generate)
    thread.start()

    # Print the generated text as it arrives
    print("Answer: ", end="", flush=True)
    generated_text = ""
    for new_text in streamer:
        print(new_text, end="", flush=True)
        generated_text += new_text
    print()  # newline

    return generated_text.strip()


def rag_pipeline(faiss, tokenizer, model, query, use_rag=True):
    """Answer a query, optionally retrieving supporting knowledge first."""
    if use_rag:
        knowledge = retrieve_knowledge(faiss, query)
        answer = generate_answer(tokenizer, model, query, knowledge)
        return answer, knowledge
    answer = generate_answer(tokenizer, model, query)
    return answer, ""


def main():
    parser = argparse.ArgumentParser(description="RAG Demo - Command Line Version")
    parser.add_argument("filename", help="Path to the knowledge base text file")
    args = parser.parse_args()

    # Load the knowledge base and the model
    faiss_db = load_knowledge_base(args.filename)
    tokenizer, model = load_model_and_tokenizer()

    print("\n" + "=" * 60)
    print("RAG system is ready!")
    print("Type 'quit' or 'exit' to leave the program.")
    print("=" * 60)

    while True:
        try:
            # Read the user's question
            query = input("\nEnter your question: ").strip()
            if query.lower() in ['quit', 'exit', 'bye']:
                print("Goodbye!")
                break
            if not query:
                print("The question cannot be empty, please try again.")
                continue

            # Ask whether to use retrieval augmentation
            use_rag_input = input("Enable retrieval augmentation (RAG)? [Y/n]: ").strip().lower()
            use_rag = use_rag_input not in ['n', 'no']

            # RAG flow
            if use_rag:
                print("Searching the knowledge base...")
                knowledge = retrieve_knowledge(faiss_db, query)
                print(f"Retrieved knowledge:\n{knowledge}")
                answer = generate_answer(tokenizer, model, query, knowledge)
            else:
                print("Generating an answer directly (no retrieval)...")
                answer = generate_answer(tokenizer, model, query)

        except KeyboardInterrupt:
            print("\n\nInterrupted by user, goodbye!")
            break


if __name__ == "__main__":
    main()