1
+ #!/usr/bin/env python
2
+ # coding=utf-8
3
+ """
4
+ 基于 DeepSeek Coder 模型的代码助手机器人
5
+ """
6
+
7
+ import os
8
+ import argparse
9
+ import re
10
+ import time
11
+ from rich .console import Console
12
+ from rich .markdown import Markdown
13
+ from rich .panel import Panel
14
+ from rich .syntax import Syntax
15
+ from prompt_toolkit import PromptSession
16
+ from prompt_toolkit .history import FileHistory
17
+ from prompt_toolkit .auto_suggest import AutoSuggestFromHistory
18
+ from mindnlp .transformers import AutoModelForCausalLM , AutoTokenizer
19
+
20
+ console = Console ()
21
+
22
+ class CodeAssistant :
23
+ """代码助手类,使用 DeepSeek Coder 模型提供代码生成和解释服务"""
24
+
25
+ def __init__ (self , model_name = "deepseek-ai/deepseek-coder-1.3b-base" ):
26
+ """初始化代码助手"""
27
+ self .model_name = model_name
28
+
29
+ # 加载模型和分词器
30
+ console .print (f"正在加载 [bold]{ model_name } [/bold] 模型..." , style = "yellow" )
31
+ self .tokenizer = AutoTokenizer .from_pretrained (model_name )
32
+ self .model = AutoModelForCausalLM .from_pretrained (model_name )
33
+ console .print ("模型加载完成!" , style = "green" )
34
+
35
+ # 对话历史
36
+ self .conversation_history = []
37
+
38
+ # 命令列表
39
+ self .commands = {
40
+ "/help" : self .show_help ,
41
+ "/clear" : self .clear_history ,
42
+ "/save" : self .save_conversation ,
43
+ "/exit" : lambda : "exit" ,
44
+ "/examples" : self .show_examples
45
+ }
46
+
47
+ def start (self ):
48
+ """启动交互式代码助手"""
49
+ console .print (Panel .fit (
50
+ "[bold]DeepSeek Coder 代码助手[/bold]\n \n "
51
+ "一个基于 DeepSeek Coder 模型的代码生成和解释工具\n "
52
+ "输入 [bold blue]/help[/bold blue] 查看帮助信息\n "
53
+ "输入 [bold blue]/exit[/bold blue] 退出程序" ,
54
+ title = "欢迎使用" ,
55
+ border_style = "green"
56
+ ))
57
+
58
+ # 创建历史记录文件
59
+ history_file = os .path .expanduser ("~/.code_assistant_history" )
60
+ session = PromptSession (history = FileHistory (history_file ),
61
+ auto_suggest = AutoSuggestFromHistory ())
62
+
63
+ while True :
64
+ try :
65
+ user_input = session .prompt ("\n [用户] > " )
66
+
67
+ # 处理命令
68
+ if user_input .strip ().startswith ("/" ):
69
+ command = user_input .strip ().split ()[0 ]
70
+ if command in self .commands :
71
+ result = self .commands [command ]()
72
+ if result == "exit" :
73
+ break
74
+ continue
75
+
76
+ if not user_input .strip ():
77
+ continue
78
+
79
+ # 将用户输入添加到历史记录
80
+ self .conversation_history .append (f"[用户] { user_input } " )
81
+
82
+ # 获取回复
83
+ start_time = time .time ()
84
+ console .print ("[AI 思考中...]" , style = "yellow" )
85
+
86
+ response = self .generate_response (user_input )
87
+
88
+ # 提取代码块
89
+ code_blocks = self .extract_code_blocks (response )
90
+
91
+ # 格式化输出
92
+ console .print ("\n [AI 助手]" , style = "bold green" )
93
+
94
+ # 如果有代码块,特殊处理
95
+ if code_blocks :
96
+ parts = re .split (r'```(?:\w+)?\n|```' , response )
97
+ i = 0
98
+ for part in parts :
99
+ if part .strip ():
100
+ if i % 2 == 0 : # 文本部分
101
+ console .print (Markdown (part .strip ()))
102
+ else : # 代码部分
103
+ lang = self .detect_language (code_blocks [(i - 1 )// 2 ])
104
+ console .print (Syntax (code_blocks [(i - 1 )// 2 ], lang , theme = "monokai" ,
105
+ line_numbers = True , word_wrap = True ))
106
+ i += 1
107
+ else :
108
+ # 没有代码块,直接显示为Markdown
109
+ console .print (Markdown (response ))
110
+
111
+ elapsed_time = time .time () - start_time
112
+ console .print (f"[生成用时: { elapsed_time :.2f} 秒]" , style = "dim" )
113
+
114
+ # 将回复添加到历史记录
115
+ self .conversation_history .append (f"[AI] { response } " )
116
+
117
+ except KeyboardInterrupt :
118
+ console .print ("\n 中断操作..." , style = "bold red" )
119
+ break
120
+ except Exception as e :
121
+ console .print (f"\n 发生错误: { str (e )} " , style = "bold red" )
122
+
123
+ def generate_response (self , prompt , max_length = 1000 , temperature = 0.7 ):
124
+ """生成回复"""
125
+ # 处理提示
126
+ if "代码" in prompt or "函数" in prompt or "实现" in prompt or "编写" in prompt :
127
+ # 检测是否已经包含了代码格式声明
128
+ if not "```" in prompt :
129
+ prompt = f"```python\n # { prompt } \n "
130
+
131
+ inputs = self .tokenizer (prompt , return_tensors = "ms" )
132
+
133
+ # 生成回复
134
+ generated_ids = self .model .generate (
135
+ inputs .input_ids ,
136
+ max_length = max_length ,
137
+ do_sample = True ,
138
+ temperature = temperature ,
139
+ top_p = 0.95 ,
140
+ top_k = 50 ,
141
+ )
142
+
143
+ response = self .tokenizer .decode (generated_ids [0 ], skip_special_tokens = True )
144
+
145
+ # 清理响应,如果有的话
146
+ if prompt in response :
147
+ response = response .replace (prompt , "" , 1 ).strip ()
148
+
149
+ return response
150
+
151
+ def extract_code_blocks (self , text ):
152
+ """从文本中提取代码块"""
153
+ pattern = r'```(?:\w+)?\n(.*?)```'
154
+ matches = re .findall (pattern , text , re .DOTALL )
155
+ return matches
156
+
157
+ def detect_language (self , code ):
158
+ """简单检测代码语言"""
159
+ if "def " in code and ":" in code :
160
+ return "python"
161
+ elif "{" in code and "}" in code and ";" in code :
162
+ if "public class" in code or "private" in code :
163
+ return "java"
164
+ elif "function" in code or "var" in code or "let" in code or "const" in code :
165
+ return "javascript"
166
+ else :
167
+ return "cpp"
168
+ elif "<" in code and ">" in code and ("</" in code or "/>" in code ):
169
+ return "html"
170
+ else :
171
+ return "text"
172
+
173
+ def show_help (self ):
174
+ """显示帮助信息"""
175
+ help_text = """
176
+ # 可用命令:
177
+
178
+ - `/help` - 显示此帮助信息
179
+ - `/clear` - 清除当前对话历史
180
+ - `/save` - 保存当前对话到文件
181
+ - `/examples` - 显示示例提示
182
+ - `/exit` - 退出程序
183
+
184
+ # 使用技巧:
185
+
186
+ 1. 提供详细的需求描述以获得更好的代码生成效果
187
+ 2. 如果生成的代码不满意,可以要求修改或优化
188
+ 3. 可以请求解释已有代码或调试问题
189
+ 4. 对复杂功能,建议分步骤请求实现
190
+ """
191
+ console .print (Markdown (help_text ))
192
+
193
+ def clear_history (self ):
194
+ """清除对话历史"""
195
+ self .conversation_history = []
196
+ console .print ("已清除对话历史" , style = "green" )
197
+
198
+ def save_conversation (self ):
199
+ """保存对话到文件"""
200
+ if not self .conversation_history :
201
+ console .print ("没有对话内容可保存" , style = "yellow" )
202
+ return
203
+
204
+ filename = f"code_assistant_conversation_{ int (time .time ())} .md"
205
+ with open (filename , "w" , encoding = "utf-8" ) as f :
206
+ f .write ("# DeepSeek Coder 代码助手对话记录\n \n " )
207
+ for entry in self .conversation_history :
208
+ if entry .startswith ("[用户]" ):
209
+ f .write (f"## { entry } \n \n " )
210
+ else :
211
+ f .write (f"{ entry [5 :]} \n \n " )
212
+
213
+ console .print (f"对话已保存到 { filename } " , style = "green" )
214
+
215
+ def show_examples (self ):
216
+ """显示示例提示"""
217
+ examples = """
218
+ # 示例提示:
219
+
220
+ 1. "实现一个Python函数,计算两个日期之间的工作日数量"
221
+
222
+ 2. "编写一个简单的Flask API,具有用户注册和登录功能"
223
+
224
+ 3. "创建一个二分查找算法的JavaScript实现"
225
+
226
+ 4. "使用pandas分析CSV数据并生成统计报告"
227
+
228
+ 5. "实现一个简单的React组件,显示待办事项列表"
229
+
230
+ 6. "解释以下代码的功能:
231
+ ```python
232
+ def mystery(arr):
233
+ return [x for x in arr if x == x[::-1]]
234
+ ```"
235
+
236
+ 7. "优化下面的排序算法:
237
+ ```python
238
+ def sort(arr):
239
+ for i in range(len(arr)):
240
+ for j in range(len(arr)):
241
+ if arr[i] < arr[j]:
242
+ arr[i], arr[j] = arr[j], arr[i]
243
+ return arr
244
+ ```"
245
+ """
246
+ console .print (Markdown (examples ))
247
+
248
+
249
+ def main ():
250
+ """主函数"""
251
+ parser = argparse .ArgumentParser (description = "DeepSeek Coder 代码助手" )
252
+ parser .add_argument ("--model" , type = str , default = "deepseek-ai/deepseek-coder-1.3b-base" ,
253
+ help = "使用的模型名称或路径" )
254
+ args = parser .parse_args ()
255
+
256
+ # 创建并启动代码助手
257
+ assistant = CodeAssistant (model_name = args .model )
258
+ assistant .start ()
259
+
260
+
261
+ if __name__ == "__main__" :
262
+ main ()
0 commit comments