
Commit eec8154 (parent: a81b502)

Author: xusenlin

Add Code Interpreter demo

File tree: 10 files changed, +364 −11 lines


api/routes/chat.py

Lines changed: 4 additions & 3 deletions

```diff
@@ -93,7 +93,7 @@ async def create_chat_completion(request: ChatCompletionRequest, raw_request: Re
         res, function_call = parse_response(content["text"])
         content["text"] = res
 
-        if isinstance(function_call, dict):
+        if isinstance(function_call, dict) and "arguments" in function_call:
             finish_reason = "function_call"
             function_call = FunctionCallResponse(**function_call)
@@ -126,6 +126,7 @@ async def chat_completion_stream_generator(
     https://developer.mozilla.org/en-US/docs/Web/API/Server-sent_events/Using_server-sent_events#event_stream_format
     """
     _id = f"chatcmpl-{secrets.token_hex(12)}"
+    use_tool = bool(gen_params["functions"] is not None)
     for i in range(n):
         # First chunk with role
         choice_data = ChatCompletionResponseStreamChoice(
@@ -160,14 +161,14 @@ async def chat_completion_stream_generator(
             function_call = None
             if finish_reason == "function_call" and "chatglm3" in config.MODEL_NAME.lower():
                 try:
-                    function_call = process_response_v3(decoded_unicode, use_tool=True)
+                    function_call = process_response_v3(decoded_unicode, use_tool=use_tool)
                 except:
                     logger.warning("Failed to parse tool call")
 
             elif finish_reason == "function_call" and "qwen" in config.MODEL_NAME.lower():
                 _, function_call = parse_response(decoded_unicode)
 
-            if isinstance(function_call, dict):
+            if isinstance(function_call, dict) and "arguments" in function_call:
                 function_call = FunctionCallResponse(**function_call)
 
                 delta = DeltaMessage(
```
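The new `"arguments"` check is easy to motivate with a toy case. A minimal sketch (not the project's code; `get_weather` and the malformed dict are invented): a parsed tool call that lacks `arguments` would previously reach `FunctionCallResponse(**function_call)` and fail validation, whereas it now falls through as a plain text reply.

```python
# Hypothetical illustration of the new guard in api/routes/chat.py.
function_call = {"name": "get_weather"}  # model emitted a call with no arguments

if isinstance(function_call, dict) and "arguments" in function_call:
    finish_reason = "function_call"  # well-formed call: report it as such
else:
    finish_reason = "stop"           # malformed call: degrade to plain text

print(finish_reason)  # -> "stop"
```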

examples/chatglm3/tool_using.py

Lines changed: 1 addition & 0 deletions

```diff
@@ -65,6 +65,7 @@ def run_conversation(query: str, stream=False, functions=None, max_retry=5):
         params["messages"].append(
             {
                 "role": "assistant",
+                "function_call": function_call,
                 "content": output
             }
         )
```
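With this one-line fix, a replayed conversation round-trips the tool call itself, not just its textual output. A hedged sketch of what the history might look like after one call (the `calculator` function and all values are invented for the example):

```python
# Hypothetical message history after the fix: the assistant turn now records
# the function_call it made, so the model can line it up with the "function"
# result on the next request. All names and values below are invented.
messages = [
    {"role": "user", "content": "What is 34 * 56?"},
    {
        "role": "assistant",
        "function_call": {"name": "calculator", "arguments": '{"expression": "34 * 56"}'},
        "content": "",
    },
    {"role": "function", "name": "calculator", "content": "1904"},
]
```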

streamlit-demo/.env

Lines changed: 3 additions & 1 deletion

```diff
@@ -4,4 +4,6 @@ TOOL_CHAT_API_BASE = "http://192.168.20.59:7891/v1" # Tool model API endpoint
 EMBEDDING_API_BASE = "http://192.168.0.53:7891/v1" # Embedding model API endpoint (optional)
 API_KEY = "xxx" # Not required by default
 EMBEDDING_NAME = "" # Path to a local embedding model (optional; configure either EMBEDDING_API_BASE or EMBEDDING_NAME, not both)
-SERPAPI_API_KEY = "" # Required for the search feature
+SERPAPI_API_KEY = "" # Required for the search feature
+IPYKERNEL = "llm" # Name of the Python kernel
+INTERPRETER_CHAT_API_BASE = "http://192.168.20.59:7891/v1" # Code interpreter model API endpoint (optional)
```
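As far as this diff shows, `INTERPRETER_CHAT_API_BASE` both gates the new page in `streamlit_app.py` and sets `openai.api_base` in the demo; `IPYKERNEL` is presumably read by the kernel wrapper, which is not part of this commit. A minimal sketch of how the two settings are consumed:

```python
# Minimal sketch of reading the new settings. Only INTERPRETER_CHAT_API_BASE
# is shown being consumed in this commit; the IPYKERNEL lookup is an assumption.
import os

api_base = os.getenv("INTERPRETER_CHAT_API_BASE", "")  # empty -> page stays hidden
kernel_name = os.getenv("IPYKERNEL", "llm")            # kernel registered via ipython
```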

streamlit-demo/README.md

Lines changed: 6 additions & 0 deletions

```diff
@@ -8,3 +8,9 @@ streamlit run streamlit_app.py
 
 **See [.env](.env) for environment variable configuration**
 
+## Code Interpreter (based on the ChatGLM3 model) [beta]
+
+```shell
+ipython kernel install --name llm --user
+```
```
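Before launching the demo it may be worth confirming that the kernelspec actually registered; a small check using `jupyter_client` (this check is not part of the repo):

```python
# Optional sanity check (not from the repo): list installed kernelspecs and
# make sure the "llm" kernel created by the command above is among them.
from jupyter_client.kernelspec import KernelSpecManager

specs = KernelSpecManager().find_kernel_specs()  # {kernel_name: resource_dir}
assert "llm" in specs, "run: ipython kernel install --name llm --user"
print(specs["llm"])
```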

streamlit-demo/streamlit_app.py

Lines changed: 4 additions & 1 deletion

```diff
@@ -7,7 +7,7 @@
 
 def main():
     from streamlit_gallery.apps import gallery
-    from streamlit_gallery.components import chat, doc_chat, sql_chat, search_chat, tool_chat
+    from streamlit_gallery.components import chat, doc_chat, sql_chat, search_chat, tool_chat, code_interpreter
 
     page = page_group("p")
 
@@ -32,6 +32,9 @@ def main():
     if os.getenv("TOOL_CHAT_API_BASE", ""):
         page.item("Tool Chat", tool_chat)
 
+    if os.getenv("INTERPRETER_CHAT_API_BASE", ""):
+        page.item("Code Interpreter", code_interpreter)
+
     with st.expander("🐧 PARAMTERS", False):
         max_tokens = st.slider("MaxTokens", 20, 4096, 1024)
         temperature = st.slider("Temperature", 0.0, 1.0, 0.9)
```

streamlit-demo/streamlit_gallery/components/__init__.py

Lines changed: 1 addition & 0 deletions

```diff
@@ -3,3 +3,4 @@
 from .sql_chat.streamlit_app import main as sql_chat
 from .search_chat.streamlit_app import main as search_chat
 from .tool_chat.streamlit_app import main as tool_chat
+from .code_interpreter.streamlit_app import main as code_interpreter
```
streamlit-demo/streamlit_gallery/components/code_interpreter/streamlit_app.py (new file)

Lines changed: 121 additions & 0 deletions

```python
import os

import openai
import streamlit as st

from .utils import CodeKernel, extract_code, execute, postprocess_text


@st.cache_resource
def get_kernel():
    return CodeKernel()


# System prompt (in Chinese): "You are an intelligent AI assistant called ChatGLM.
# You are connected to a computer, but note that it cannot access the internet.
# When using Python to solve a task, you can run code and get the results; if a
# run has errors, you should improve the code as much as possible. You can work
# with files the user uploads to the computer; the default storage path is /mnt/data/."
SYSTEM_MESSAGE = [
    {
        "role": "system",
        "content": "你是一位智能AI助手,你叫ChatGLM,你连接着一台电脑,但请注意不能联网。在使用Python解决任务时,你可以运行代码并得到结果,如果运行结果有错误,你需要尽可能对代码进行改进。你可以处理用户上传到电脑上的文件,文件默认存储路径是/mnt/data/。"
    }
]


def chat_once(message_placeholder):
    params = dict(
        model="chatglm3",
        messages=SYSTEM_MESSAGE + st.session_state.messages,
        stream=True,
        max_tokens=st.session_state.get("max_tokens", 512),
        temperature=st.session_state.get("temperature", 0.9),
    )
    response = openai.ChatCompletion.create(**params)

    display = ""
    for _ in range(5):  # at most five generate -> execute round-trips
        full_response = ""
        for chunk in response:
            content = chunk.choices[0].delta.get("content", "")
            full_response += content
            display += content
            message_placeholder.markdown(postprocess_text(display) + "▌")

            if chunk.choices[0].finish_reason == "stop":
                message_placeholder.markdown(postprocess_text(display) + "▌")
                st.session_state.messages.append(
                    {
                        "role": "assistant",
                        "content": full_response
                    }
                )
                return

            elif chunk.choices[0].finish_reason == "function_call":
                try:
                    code = extract_code(full_response)
                except:
                    continue

                with message_placeholder:
                    with st.spinner("Executing code..."):
                        try:
                            res_type, res = execute(code, get_kernel())
                        except Exception as e:
                            st.error(f"Error when executing code: {e}")
                            return

                if res_type == "text":
                    res = postprocess_text(res)
                    display += "\n" + res
                    message_placeholder.markdown(postprocess_text(display) + "▌")
                elif res_type == "image":
                    st.image(res)

                st.session_state.messages.append(
                    {
                        "role": "assistant",
                        "content": full_response,
                        "function_call": {"name": "interpreter", "arguments": ""},
                    }
                )
                st.session_state.messages.append(
                    {
                        "role": "function",
                        "content": "[Image]" if res_type == "image" else res,  # result returned by the function call
                    }
                )

                break

        params["messages"] = st.session_state.messages
        response = openai.ChatCompletion.create(**params)


def main():
    st.title("💬 Code Interpreter")

    openai.api_base = os.getenv("INTERPRETER_CHAT_API_BASE", "http://192.168.20.59:7891/v1")
    openai.api_key = os.getenv("API_KEY", "xxx")

    if "messages" not in st.session_state:
        st.session_state.messages = []

    for message in st.session_state.messages:
        role = message["role"]
        if role in ["user", "function"]:
            with st.chat_message("user"):
                st.markdown(message["content"])
        else:
            with st.chat_message("assistant"):
                st.markdown(postprocess_text(message["content"]))

    if prompt := st.chat_input("What is up?"):
        st.session_state.messages.append({"role": "user", "content": prompt})
        with st.chat_message("user"):
            st.markdown(prompt)

        with st.chat_message("assistant"):
            message_placeholder = st.empty()
            chat_once(message_placeholder)


if __name__ == "__main__":
    main()
```
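The demo leans on four helpers from a sibling `.utils` module that this commit does not include. A rough, hedged sketch of the contract the code above appears to assume, backed by `jupyter_client` and the `llm` kernel from the README; every name and detail below is an assumption, not the repository's implementation (`postprocess_text`, presumably a markup cleaner, is omitted):

```python
# Speculative sketch of the .utils helpers imported above; all signatures and
# internals are inferred from how the demo calls them.
import re

from jupyter_client.manager import KernelManager


class CodeKernel:
    """Wrap a persistent ipykernel so state survives across executions."""

    def __init__(self, kernel_name: str = "llm"):  # matches IPYKERNEL in .env
        self.manager = KernelManager(kernel_name=kernel_name)
        self.manager.start_kernel()
        self.client = self.manager.blocking_client()
        self.client.start_channels()


def extract_code(text: str) -> str:
    """Pull the last fenced code block out of the model's reply."""
    blocks = re.findall(r"```(?:[^\n]*)\n(.*?)```", text, re.DOTALL)
    if not blocks:
        raise ValueError("no code block in response")
    return blocks[-1]


def execute(code: str, kernel: CodeKernel, timeout: float = 30.0):
    """Run code on the kernel; return ("image", png) or ("text", output)."""
    kernel.client.execute(code)
    outputs = []
    while True:
        msg = kernel.client.get_iopub_msg(timeout=timeout)
        content, msg_type = msg["content"], msg["msg_type"]
        if msg_type == "status" and content["execution_state"] == "idle":
            break  # kernel finished this execution request
        if msg_type == "stream":
            outputs.append(content["text"])
        elif msg_type in ("execute_result", "display_data"):
            if "image/png" in content["data"]:
                return "image", content["data"]["image/png"]  # base64 PNG
            outputs.append(content["data"].get("text/plain", ""))
        elif msg_type == "error":
            outputs.append("\n".join(content["traceback"]))
    return "text", "".join(outputs)
```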
