-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathapp.py
More file actions
129 lines (96 loc) · 4.44 KB
/
app.py
File metadata and controls
129 lines (96 loc) · 4.44 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
import streamlit as st
import pandas as pd
from transformers import GenerationConfig, BartModel, BartTokenizer, AutoTokenizer, AutoModelForSeq2SeqLM, TextStreamer
import torch
import time
import sys, os
# Prepend this script's own directory to the import search path so that
# sibling modules resolve regardless of the working directory.
_script_dir = os.path.dirname(__file__)
path = os.path.abspath(_script_dir)
sys.path[:0] = [path]

# Page header: title, subtitle caption, then a horizontal rule.
st.title("Dialogue Text Summarization")
st.caption("Natural Language Processing Project 20232")
st.write("---")
class StreamlitTextStreamer(TextStreamer):
    """Stream generated text into a Streamlit element and report run statistics.

    ``TextStreamer`` calls :meth:`on_finalized_text` with each printable chunk
    of decoded text while ``generate`` runs. This subclass does three things:

    * appends every chunk to ``self.text``;
    * re-renders the accumulated text in ``st_container`` as a level-6
      markdown heading;
    * on the final call, writes a one-row timing table into
      ``st_info_container``.

    Args:
        tokenizer: Tokenizer that ``TextStreamer`` uses to decode token ids.
        st_container: Streamlit element (e.g. ``st.empty()``) showing the
            running text.
        st_info_container: Streamlit element that receives the stats table.
        skip_prompt: Forwarded to ``TextStreamer``.
        **decode_kwargs: Forwarded to ``TextStreamer``, which passes them on
            to the tokenizer's decode call.
    """

    def __init__(self, tokenizer, st_container, st_info_container, skip_prompt=False, **decode_kwargs):
        super().__init__(tokenizer, skip_prompt, **decode_kwargs)
        self.st_container = st_container
        self.st_info_container = st_info_container
        self.text = ""  # all text streamed so far
        # Start timing at construction. The streamer is created immediately
        # before ``generate`` is called, so "First token" reflects the real
        # wait. Starting the clock on the first callback (as before) made
        # that wait ~0 and dropped the pre-first-chunk time from the total.
        self.start_time = time.time()
        self.first_token_time = None
        # NOTE: this counts whitespace-separated words of decoded text, not
        # tokenizer tokens, despite the "tokens" naming used in the stats.
        self.total_tokens = 0

    def on_finalized_text(self, text: str, stream_end: bool = False):
        """Append a decoded chunk, re-render it, and emit stats at stream end.

        Args:
            text: Newly finalized text produced by ``TextStreamer``.
            stream_end: True on the last call for this generation.
        """
        # The first chunk containing visible characters marks the first token.
        if self.first_token_time is None and text.strip():
            self.first_token_time = time.time()

        self.text += text
        self.total_tokens += len(text.split())
        self.st_container.markdown("###### " + self.text)
        # Short pause so the text visibly "types out". Because generation is
        # synchronous with this callback, the pause is included in the timings.
        time.sleep(0.03)

        if stream_end:
            total_time = time.time() - self.start_time
            first_token_wait_time = (
                self.first_token_time - self.start_time
                if self.first_token_time is not None
                else None
            )
            tokens_per_second = self.total_tokens / total_time if total_time > 0 else None
            df = pd.DataFrame(data={
                "First token": [first_token_wait_time],
                "Total tokens": [self.total_tokens],
                "Time taken": [total_time],
                "Token per second": [tokens_per_second],
            })
            self.st_info_container.table(df)
def generate_summary(model, input_text, generation_config, tokenizer, st_container, st_info_container) -> str:
    """Summarize a dialogue with sampling and stream the result into Streamlit.

    The dialogue is wrapped in an instruction prompt that also requests a
    summary of roughly 15% of the dialogue's word count (at least 1 word).
    Decoded text streams into ``st_container`` and timing statistics go into
    ``st_info_container`` via ``StreamlitTextStreamer``.

    Args:
        model: Seq2seq model providing ``generate`` and ``device``.
        input_text: The dialogue to summarize.
        generation_config: Generation settings passed to ``model.generate``.
        tokenizer: Tokenizer used to encode the prompt and decode the output.
        st_container: Streamlit element that shows the streamed summary.
        st_info_container: Streamlit element that shows the stats table.

    Returns:
        The full decoded summary text accumulated by the streamer.
    """
    prefix = "Summarize the following conversation: \n###\n"
    suffix = "\n### Summary:"
    # Ask for about 15% of the dialogue's word count, never fewer than 1 word.
    target_length = max(1, int(0.15 * len(input_text.split())))
    # NOTE(review): no separator sits between the dialogue and the length
    # instruction. Kept byte-identical so model inputs do not change; confirm
    # against the prompt format used during fine-tuning.
    prompt = prefix + input_text + f"The generated summary should be around {target_length} words." + suffix
    # Keep the inputs on the same device as the model weights.
    input_ids = tokenizer.encode(prompt, return_tensors="pt").to(model.device)

    # ``decoder_start_token_id`` was previously passed here as a *decode*
    # kwarg, where it has no effect; it has been dropped.
    streamer = StreamlitTextStreamer(tokenizer, st_container, st_info_container, skip_special_tokens=True)
    model.generate(input_ids, streamer=streamer, do_sample=True, generation_config=generation_config)
    return streamer.text
@st.cache_resource
def _load_model(checkpoint):
    """Load (once per checkpoint) and return ``(tokenizer, model)``.

    Streamlit re-executes this whole script on every widget interaction.
    Caching avoids re-loading the weights on each rerun.
    """
    tokenizer = AutoTokenizer.from_pretrained(checkpoint)
    model = AutoModelForSeq2SeqLM.from_pretrained(checkpoint)
    return tokenizer, model


with st.sidebar:
    checkpoint = st.selectbox("Model", options=[
        "Choose model",
        "dtruong46me/train-bart-base",
        "dtruong46me/flant5-small",
        "dtruong46me/flant5-base",
        "dtruong46me/flan-t5-s",
        "ntluongg/bart-base-luong",
    ])
    st.button("Model detail", use_container_width=True)
    st.write("-----")
    st.write("**Generate Options:**")
    min_new_tokens = st.number_input("Min new tokens", min_value=1, max_value=64, value=10)
    max_new_tokens = st.number_input("Max new tokens", min_value=64, max_value=128, value=64)
    # Sampling needs a strictly positive temperature; 0.0 makes generate() raise.
    temperature = st.number_input("Temperature", min_value=0.05, max_value=1.0, value=0.9, step=0.05)
    top_k = st.number_input("Top_k", min_value=1, max_value=50, step=1, value=20)
    top_p = st.number_input("Top_p", min_value=0.01, max_value=1.00, step=0.01, value=1.0)

height = 200
input_text = st.text_area("Dialogue", height=height)

generation_config = GenerationConfig(
    min_new_tokens=min_new_tokens,
    # Honor the sidebar setting (this was previously hard-coded to 320).
    max_new_tokens=max_new_tokens,
    temperature=temperature,
    top_p=top_p,
    top_k=top_k,
    # Sampling is always used; declaring it here keeps the config consistent
    # with the sampling parameters above.
    do_sample=True,
)

# NOTE(review): ``device`` is computed but not used. The model is never moved
# to it, so it stays where ``from_pretrained`` put it (CPU by default).
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

if checkpoint == "Choose model":
    tokenizer = None
    model = None
else:
    tokenizer, model = _load_model(checkpoint)

if st.button("Submit"):
    st.write("---")
    st.write("## Summary")
    if checkpoint == "Choose model":
        st.error("Please select a model!")
    elif not input_text.strip():
        # Stop here: previously the error was shown but generation still ran.
        st.error("Please enter a dialogue!")
    else:
        st_container = st.empty()
        st_info_container = st.empty()
        # Collapse all whitespace (including newlines) into single spaces.
        generate_summary(model, " ".join(input_text.split()), generation_config, tokenizer, st_container, st_info_container)