4 changes: 4 additions & 0 deletions .gitignore
@@ -1,3 +1,7 @@
.vscode
ptuning/data
ptuning/output

# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
76 changes: 43 additions & 33 deletions api.py
@@ -1,40 +1,49 @@
import json
import datetime
import torch
import uvicorn
from typing import List
from fastapi import FastAPI, Request
from transformers import AutoTokenizer, AutoModel
import uvicorn, json, datetime
import torch
from pydantic import BaseModel
from utils import load_model_on_gpus


DEVICE = "cuda"
DEVICE_ID = "0"
CUDA_DEVICE = f"{DEVICE}:{DEVICE_ID}" if DEVICE_ID else DEVICE
devices_list = [
    'cuda:0',
    'cuda:1'
]


def torch_gc():
def _torch_gc():
    if torch.cuda.is_available():
        with torch.cuda.device(CUDA_DEVICE):
            torch.cuda.empty_cache()
            torch.cuda.ipc_collect()
        for item in devices_list:
            with torch.cuda.device(item):
                torch.cuda.empty_cache()
                torch.cuda.ipc_collect()


class Question(BaseModel):
    prompt: str
    history: List[str] = []
    max_length: int = 2048
    top_p: float = 0.7
    temperature: float = 0.95


app = FastAPI()


@app.post("/")
async def create_item(request: Request):
    global model, tokenizer
    json_post_raw = await request.json()
    json_post = json.dumps(json_post_raw)
    json_post_list = json.loads(json_post)
    prompt = json_post_list.get('prompt')
    history = json_post_list.get('history')
    max_length = json_post_list.get('max_length')
    top_p = json_post_list.get('top_p')
    temperature = json_post_list.get('temperature')
    response, history = model.chat(tokenizer,
                                   prompt,
                                   history=history,
                                   max_length=max_length if max_length else 2048,
                                   top_p=top_p if top_p else 0.7,
                                   temperature=temperature if temperature else 0.95)
@app.post('/chat/')
async def chat(question: Question):
    response, history = model.chat(
        tokenizer,
        question.prompt,
        history=question.history,
        max_length=question.max_length,
        top_p=question.top_p,
        temperature=question.temperature
    )
    now = datetime.datetime.now()
    time = now.strftime("%Y-%m-%d %H:%M:%S")
    answer = {
@@ -43,14 +52,15 @@ async def create_item(request: Request):
        "status": 200,
        "time": time
    }
    log = "[" + time + "] " + '", prompt:"' + prompt + '", response:"' + repr(response) + '"'
    print(log)
    torch_gc()
    _torch_gc()
    return answer


if __name__ == '__main__':
    tokenizer = AutoTokenizer.from_pretrained("THUDM/chatglm-6b", trust_remote_code=True)
    model = AutoModel.from_pretrained("THUDM/chatglm-6b", trust_remote_code=True).half().cuda()
if __name__ == "__main__":
    tokenizer = AutoTokenizer.from_pretrained(
        "THUDM/chatglm-6b", trust_remote_code=True
    )
    model = load_model_on_gpus("THUDM/chatglm-6b", num_gpus=2)
    # model = AutoModel.from_pretrained("THUDM/chatglm-6b", trust_remote_code=True).half().cuda()
    model.eval()
    uvicorn.run(app, host='0.0.0.0', port=8000, workers=1)
    uvicorn.run(app, host="127.0.0.1", port=11001, workers=1)
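Note: with the route moved from "/" to "/chat/" and the request body now validated by the Question model, a client call against this revised API could look like the minimal sketch below (assuming the server is running locally on port 11001 as configured in uvicorn.run above; the prompt text and the empty first-turn history are illustrative only).

import requests

payload = {
    "prompt": "Hello",   # illustrative prompt
    "history": [],       # first turn, so no prior exchanges
    "max_length": 2048,
    "top_p": 0.7,
    "temperature": 0.95,
}
resp = requests.post("http://127.0.0.1:11001/chat/", json=payload)
print(resp.json())  # the handler returns the model response plus "status" and "time"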
59 changes: 59 additions & 0 deletions cli_demo_gpus.py
@@ -0,0 +1,59 @@
import os
import platform
import signal
from transformers import AutoTokenizer, AutoModel
from utils import load_model_on_gpus

tokenizer = AutoTokenizer.from_pretrained("THUDM/chatglm-6b", trust_remote_code=True)
model = load_model_on_gpus("THUDM/chatglm-6b", num_gpus=2)
# model = AutoModel.from_pretrained("THUDM/chatglm-6b", trust_remote_code=True).half().cuda()
model = model.eval()

os_name = platform.system()
clear_command = 'cls' if os_name == 'Windows' else 'clear'
stop_stream = False


def build_prompt(history):
    prompt = "欢迎使用 ChatGLM-6B 模型,输入内容即可进行对话,clear 清空对话历史,stop 终止程序"
    for query, response in history:
        prompt += f"\n\n用户:{query}"
        prompt += f"\n\nChatGLM-6B:{response}"
    return prompt


def signal_handler(signal, frame):
    global stop_stream
    stop_stream = True


def main():
    history = []
    global stop_stream
    print("欢迎使用 ChatGLM-6B 模型,输入内容即可进行对话,clear 清空对话历史,stop 终止程序")
    while True:
        query = input("\n用户:")
        if query.strip() == "stop":
            break
        if query.strip() == "clear":
            history = []
            os.system(clear_command)
            print("欢迎使用 ChatGLM-6B 模型,输入内容即可进行对话,clear 清空对话历史,stop 终止程序")
            continue
        count = 0
        for response, history in model.stream_chat(tokenizer, query, history=history):
            if stop_stream:
                stop_stream = False
                break
            else:
                count += 1
                if count % 8 == 0:
                    os.system(clear_command)
                    print(build_prompt(history), flush=True)
                    signal.signal(signal.SIGINT, signal_handler)
        os.system(clear_command)
        print(build_prompt(history), flush=True)


if __name__ == "__main__":
    main()
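Both api.py and this new CLI demo rely on load_model_on_gpus from utils. For orientation only, a helper with that signature might look like the sketch below, which spreads ChatGLM-6B's transformer layers across the visible GPUs with accelerate's dispatch_model; the layer count and device-map layout here are assumptions, not the repository's actual utils.py.

from typing import Dict, Optional

from accelerate import dispatch_model
from transformers import AutoModel


def auto_configure_device_map(num_gpus: int) -> Dict[str, int]:
    # Assumption: ChatGLM-6B has 28 transformer layers; keep the word embeddings,
    # final layernorm and lm_head on GPU 0 and spread the layers over the rest.
    num_layers = 28
    per_gpu = (num_layers + 2) / num_gpus
    device_map = {
        "transformer.word_embeddings": 0,
        "transformer.final_layernorm": 0,
        "lm_head": 0,
    }
    used, gpu = 2, 0
    for i in range(num_layers):
        if used >= per_gpu:
            gpu, used = gpu + 1, 0
        device_map[f"transformer.layers.{i}"] = min(gpu, num_gpus - 1)
        used += 1
    return device_map


def load_model_on_gpus(checkpoint_path: str, num_gpus: int = 2,
                       device_map: Optional[Dict[str, int]] = None):
    # Load in half precision first, then dispatch modules to the GPUs.
    model = AutoModel.from_pretrained(
        checkpoint_path, trust_remote_code=True).half()
    if device_map is None:
        device_map = auto_configure_device_map(num_gpus)
    return dispatch_model(model, device_map=device_map)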
6 changes: 3 additions & 3 deletions ptuning/evaluate.sh
@@ -2,10 +2,10 @@ PRE_SEQ_LEN=128
CHECKPOINT=adgen-chatglm-6b-pt-128-2e-2
STEP=3000

CUDA_VISIBLE_DEVICES=0 python3 main.py \
CUDA_VISIBLE_DEVICES=0,1,2,3 python3 main.py \
    --do_predict \
    --validation_file AdvertiseGen/dev.json \
    --test_file AdvertiseGen/dev.json \
    --validation_file data/AdvertiseGen/dev.json \
    --test_file data/AdvertiseGen/dev.json \
    --overwrite_cache \
    --prompt_column content \
    --response_column summary \
8 changes: 5 additions & 3 deletions ptuning/train.sh
@@ -1,10 +1,12 @@
PRE_SEQ_LEN=128
LR=2e-2

CUDA_VISIBLE_DEVICES=0 python3 main.py \
export 'PYTORCH_CUDA_ALLOC_CONF=max_split_size_mb:32'

CUDA_VISIBLE_DEVICES=0,1,2,3 python3 main.py \
    --do_train \
    --train_file AdvertiseGen/train.json \
    --validation_file AdvertiseGen/dev.json \
    --train_file data/AdvertiseGen/train.json \
    --validation_file data/AdvertiseGen/dev.json \
    --prompt_column content \
    --response_column summary \
    --overwrite_cache \
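The revised train.sh caps the CUDA allocator's split size and exposes four GPUs. If there is any doubt that those settings reach the training process, a quick check like the one below (purely illustrative, not part of this change) can be run before training starts:

import os

import torch

# Expect "max_split_size_mb:32" and, with CUDA_VISIBLE_DEVICES=0,1,2,3, a device count of 4.
print(os.environ.get("PYTORCH_CUDA_ALLOC_CONF"))
print(torch.cuda.device_count())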
5 changes: 2 additions & 3 deletions ptuning/web_demo.py
@@ -119,8 +119,7 @@ def reset_state():
def main():
    global model, tokenizer

    parser = HfArgumentParser((
        ModelArguments))
    parser = HfArgumentParser((ModelArguments))
    if len(sys.argv) == 2 and sys.argv[1].endswith(".json"):
        # If we pass only one argument to the script and it's the path to a json file,
        # let's parse it to get our arguments.
@@ -158,7 +157,7 @@ def main():
    model.transformer.prefix_encoder.float().cuda()

    model = model.eval()
    demo.queue().launch(share=False, inbrowser=True)
    demo.queue().launch(share=False, inbrowser=True, server_port=11001)



5 changes: 3 additions & 2 deletions ptuning/web_demo.sh
@@ -1,7 +1,8 @@
PRE_SEQ_LEN=128

CUDA_VISIBLE_DEVICES=0 python3 web_demo.py \
CUDA_VISIBLE_DEVICES=0,1 python3 web_demo.py \
    --model_name_or_path THUDM/chatglm-6b \
    --ptuning_checkpoint output/adgen-chatglm-6b-pt-128-2e-2/checkpoint-3000 \
    --pre_seq_len $PRE_SEQ_LEN
    --pre_seq_len $PRE_SEQ_LEN \
    --quantization_bit 4

28 changes: 17 additions & 11 deletions web_demo.py
@@ -1,9 +1,12 @@
from transformers import AutoModel, AutoTokenizer
import gradio as gr
import mdtex2html
from transformers import AutoModel, AutoTokenizer
from utils import load_model_on_gpus


tokenizer = AutoTokenizer.from_pretrained("THUDM/chatglm-6b", trust_remote_code=True)
model = AutoModel.from_pretrained("THUDM/chatglm-6b", trust_remote_code=True).half().cuda()
# model = AutoModel.from_pretrained("THUDM/chatglm-6b", trust_remote_code=True).half().cuda()
model = load_model_on_gpus("THUDM/chatglm-6b", num_gpus=2)
model = model.eval()

"""Override Chatbot.postprocess"""
@@ -60,7 +63,7 @@ def predict(input, chatbot, max_length, top_p, temperature, history):
    chatbot.append((parse_text(input), ""))
    for response, history in model.stream_chat(tokenizer, input, history, max_length=max_length, top_p=top_p,
                                               temperature=temperature):
        chatbot[-1] = (parse_text(input), parse_text(response))
        chatbot[-1] = (parse_text(input), parse_text(response))

        yield chatbot, history

@@ -74,21 +77,24 @@ def reset_state():


with gr.Blocks() as demo:
    gr.HTML("""<h1 align="center">ChatGLM</h1>""")
    gr.HTML("""<h1 align="center">CodeLab</h1>""")

    chatbot = gr.Chatbot()
    with gr.Row():
        with gr.Column(scale=4):
            with gr.Column(scale=12):
                user_input = gr.Textbox(show_label=False, placeholder="Input...", lines=10).style(
                user_input = gr.Textbox(show_label=False, placeholder="输入聊天内容", lines=10).style(
                    container=False)
            with gr.Column(min_width=32, scale=1):
                submitBtn = gr.Button("Submit", variant="primary")
                submitBtn = gr.Button("发送", variant="primary")
        with gr.Column(scale=1):
            emptyBtn = gr.Button("Clear History")
            max_length = gr.Slider(0, 4096, value=2048, step=1.0, label="Maximum length", interactive=True)
            top_p = gr.Slider(0, 1, value=0.7, step=0.01, label="Top P", interactive=True)
            temperature = gr.Slider(0, 1, value=0.95, step=0.01, label="Temperature", interactive=True)
            emptyBtn = gr.Button("清除历史记录")
            max_length = gr.Slider(
                0, 4096, value=2048, step=1.0, label="Maximum length", interactive=True)
            top_p = gr.Slider(0, 1, value=0.7, step=0.01,
                              label="Top P", interactive=True)
            temperature = gr.Slider(
                0, 1, value=0.95, step=0.01, label="Temperature", interactive=True)

    history = gr.State([])

@@ -98,4 +104,4 @@ def reset_state():

    emptyBtn.click(reset_state, outputs=[chatbot, history], show_progress=True)

demo.queue().launch(share=False, inbrowser=True)
demo.queue().launch(share=False, inbrowser=False, server_port=11001)