-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathapp.py
More file actions
89 lines (69 loc) · 3 KB
/
app.py
File metadata and controls
89 lines (69 loc) · 3 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
# Force transformers to skip TensorFlow entirely; must be set before
# transformers is imported anywhere in the process.
import os

os.environ["TRANSFORMERS_NO_TF"] = "1"

import streamlit as st

# The heavyweight imports run under a spinner so the user gets feedback
# during the slow first load of torch/transformers.
with st.spinner('Loading model and tokenizer...'):
    import torch
    from transformers import BertTokenizer, BertForSequenceClassification
    import json
    import numpy as np

# Prefer GPU for inference when available, otherwise fall back to CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
class SimpleTokenizer:
    """Lightweight holder for an emoji vocabulary.

    Wraps a token -> id mapping (presumably loaded from a Keras-style
    tokenizer config — see the JSON loader below) and precomputes the
    reverse id -> token mapping used to decode model predictions.
    """

    def __init__(self, word_index):
        # Forward map: emoji token -> integer class index.
        self.word_index = word_index
        # Reverse map: integer class index -> emoji token.
        self.index_word = dict(zip(word_index.values(), word_index.keys()))
@st.cache_resource
def load_model_and_tokenizers():
    """Load the BERT tokenizer, the fine-tuned classifier, and the emoji vocab.

    Cached by Streamlit (st.cache_resource) so the expensive loads happen
    only once per server process.

    Returns:
        (bert_tokenizer, model, emoji_tokenizer) where model is already on
        `device` and switched to eval mode.
    """
    bert_tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')

    model = BertForSequenceClassification.from_pretrained('emoji_prediction_model_torch_500')
    model = model.to(device)
    model.eval()  # inference mode: disables dropout/batch-norm updates

    # The emoji vocabulary ships as a tokenizer-config JSON; only its
    # 'word_index' mapping is needed here.
    with open('emoji_tokenizer_config_500torch.json', 'r') as config_file:
        word_index = json.load(config_file)['word_index']
    emoji_tokenizer = SimpleTokenizer(word_index)

    return bert_tokenizer, model, emoji_tokenizer
@st.cache_data
def get_index_to_emoji(_emoji_tokenizer):
    """Return the id -> emoji-token mapping used to decode predictions.

    SimpleTokenizer already builds exactly this reverse mapping in its
    __init__ (as `index_word`), so reuse it instead of recomputing the
    inversion from `word_index`. A copy is returned to preserve the
    original behavior of handing out a freshly built dict.

    The leading underscore on the parameter tells st.cache_data not to
    hash the tokenizer object.
    """
    return dict(_emoji_tokenizer.index_word)
# Materialize the cached resources at import time. Streamlit re-executes the
# whole script on every interaction, but these calls hit the cache after the
# first run, so the heavy loading happens only once.
tokenizer, model, emoji_tokenizer = load_model_and_tokenizers()
index_to_emoji = get_index_to_emoji(emoji_tokenizer)
def preprocess_text(text, tokenizer, max_len):
    """Encode *text* with the BERT tokenizer for classification.

    Pads or truncates to exactly `max_len` tokens (special tokens included)
    and returns the pair (input_ids, attention_mask) — with
    return_tensors='pt' these come back as tensors of shape (1, max_len).
    """
    encoded = tokenizer.encode_plus(
        text,
        add_special_tokens=True,
        max_length=max_len,
        return_token_type_ids=False,
        padding='max_length',
        truncation=True,
        return_attention_mask=True,
        return_tensors='pt',
    )
    return encoded['input_ids'], encoded['attention_mask']
@st.cache_data
def predict_emojis(text, _model, _tokenizer, _emoji_tokenizer, _index_to_emoji):
    """Run the classifier on *text* and return its top-5 emoji predictions.

    Underscore-prefixed parameters are excluded from Streamlit's cache
    hashing, so only `text` keys the cache.

    Returns:
        (top_emojis, top_probabilities, filtered_results) where
        filtered_results is a list of (emoji, probability) pairs with the
        empty-string placeholders (indices missing from the vocab) removed.
    """
    max_len = 128
    input_ids, attention_mask = preprocess_text(text, _tokenizer, max_len)
    input_ids = input_ids.to(device)
    attention_mask = attention_mask.to(device)

    # Forward pass without gradient tracking; softmax over the class axis.
    with torch.no_grad():
        logits = _model(input_ids, attention_mask=attention_mask).logits
    probabilities = torch.nn.functional.softmax(logits, dim=1).cpu().numpy()[0]

    # Indices of the 5 highest-probability classes, most probable first.
    top_n = 5
    top_indices = probabilities.argsort()[-top_n:][::-1]

    top_emojis = [_index_to_emoji.get(idx, '') for idx in top_indices]
    top_probabilities = probabilities[top_indices]
    filtered_results = [
        (emoji, float(prob))
        for emoji, prob in zip(top_emojis, top_probabilities)
        if emoji != ''
    ]
    return top_emojis, top_probabilities, filtered_results
# --- Page UI -----------------------------------------------------------------
st.title("Emoji Prediction Model")
text = st.text_input("Enter text to predict emojis:")

if st.button("Predict"):
    if not text:
        # Guard clause: nothing to predict on.
        st.write("Please enter some text to predict emojis.")
    else:
        with st.spinner('Predicting...'):
            emojis, probabilities, filtered_results = predict_emojis(
                text, model, tokenizer, emoji_tokenizer, index_to_emoji
            )
            # Show only predictions that mapped to a real emoji token.
            for emoji, prob in filtered_results:
                st.write(f"{emoji}: {prob:.2%}")