Skip to content

Commit 3974b68

Browse files
committed
chat: add version that would use finetuned model
1 parent 16e3393 commit 3974b68

File tree

1 file changed

+175
-0
lines changed

1 file changed

+175
-0
lines changed

chat/finetuned.py

Lines changed: 175 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,175 @@
1+
# This chat variant determines if the user's query is related to a widget or a search
2+
import re
3+
import time
4+
import uuid
5+
import traceback
6+
from dataclasses import dataclass, asdict
7+
from typing import Any, Dict, List, Optional, Union, Literal, TypedDict, Callable
8+
9+
from gpt_index.utils import ErrorToRetry, retry_on_exceptions_with_backoff
10+
from langchain.llms import OpenAI
11+
from langchain.chat_models import ChatOpenAI
12+
from langchain.prompts import PromptTemplate
13+
from langchain.chains import LLMChain
14+
15+
import context
16+
import utils
17+
import utils.timing as timing
18+
from utils import error_wrap, ensure_wallet_connected, ConnectedWalletRequired, FetchError, ExecError
19+
import registry
20+
import streaming
21+
from chat.container import ContainerMixin, dataclass_to_container_params
22+
from .base import (
23+
BaseChat, ChatHistory, Response, ChatOutputParser,
24+
)
25+
from integrations import (
26+
etherscan, defillama, center, opensea,
27+
)
28+
from ui_workflows import (
29+
aave, ens
30+
)
31+
from ui_workflows.multistep_handler import register_ens_domain, exec_aave_operation
32+
from tools.index_widget import (
33+
34+
35+
# Inline commands emitted by the model look like <|command(params)|>;
# `command` is everything up to the opening paren, `params` is the
# paren-delimited argument text (no ')', '<', '{' or '}' allowed inside).
RE_COMMAND = re.compile(r"\<\|(?P<command>[^(]+)\((?P<params>[^)<{}]*)\)\|\>")

# Prompt layout the finetuned model was trained on: user question,
# chat history, retrieved widget task info, then the bot turn marker.
TEMPLATE = '<user>{question}<hist>{chat_history}<task>{task_info}<bot>'

# Upper bound on chat-history tokens folded into the prompt.
HISTORY_TOKEN_LIMIT = 1800
40+
41+
42+
@registry.register_class
class RephraseWidgetSearchChat(BaseChat):
    """Chat variant backed by a finetuned model.

    Looks up the top-k most relevant widgets for the user's query,
    builds a prompt from TEMPLATE (question + history + widget task
    info), streams the model's completion, and incrementally evaluates
    any inline ``<|command(params)|>`` fetch commands before forwarding
    visible text to the client via ``send``.

    NOTE(review): depends on ``iterative_evaluate``, presumably from the
    ``tools.index_widget`` import at the top of this file (that import
    appears truncated in this view) — confirm against the full file.
    """

    def __init__(self, widget_index: Any, top_k: int = 3, show_thinking: bool = True) -> None:
        """Store the widget retrieval index and build the prompt template.

        widget_index: similarity-search index over widget descriptions
            (must expose ``similarity_search(query, k=...)``).
        top_k: number of widget candidates to retrieve per query.
        show_thinking: stored on the instance; not read anywhere in this
            block — presumably consumed elsewhere (TODO confirm).
        """
        super().__init__()
        self.output_parser = ChatOutputParser()
        self.widget_prompt = PromptTemplate(
            input_variables=["task_info", "chat_history", "question"],
            template=TEMPLATE,
            output_parser=self.output_parser,
        )
        self.widget_index = widget_index
        self.top_k = top_k
        self.show_thinking = show_thinking

    def receive_input(
            self,
            history: ChatHistory,
            userinput: str,
            send: Callable,
            message_id: Optional[uuid.UUID] = None,
            before_message_id: Optional[uuid.UUID] = None,
    ) -> None:
        """Handle one user message end-to-end.

        Records the user message in ``history``, retrieves widget
        candidates, runs the streaming chain (tokens arrive in
        ``injection_handler`` below), flushes the final bot message, and
        appends a system message with timing information.

        history: conversation log; mutated (user, bot, system messages added).
        userinput: raw user text; stripped before use.
        send: callable delivering a ``Response`` to the client; returns a
            chat message id used for subsequent append/replace operations.
        message_id / before_message_id: optional ordering/identity hooks
            passed through to ``history`` and ``send``.
        """
        userinput = userinput.strip()
        # Snapshot history BEFORE adding the new user message, so the prompt's
        # {chat_history} does not duplicate {question}.
        history_string = history.to_string(system_prefix=None, token_limit=HISTORY_TOKEN_LIMIT, before_message_id=before_message_id)  # omit system messages

        history.add_user_message(userinput, message_id=message_id, before_message_id=before_message_id)
        timing.init()

        # Streaming state shared (via nonlocal) with the closures below.
        bot_chat_message_id = None
        bot_response = ''
        has_sent_bot_response = False  # NOTE(review): written but never read in this block

        def bot_flush(response):
            """Replace the streamed bot message with its final, stripped text
            and record it in history."""
            nonlocal bot_chat_message_id
            response = response.strip()
            send(Response(
                response=response,
                still_thinking=False,
                actor='bot',
                operation='replace',
            ), last_chat_message_id=bot_chat_message_id, before_message_id=before_message_id)
            history.add_bot_message(response, message_id=bot_chat_message_id, before_message_id=before_message_id)

        def bot_new_token_handler(token):
            """Forward one visible token to the client, creating the bot
            message on the first non-whitespace token and appending after."""
            nonlocal bot_chat_message_id, bot_response, has_sent_bot_response

            bot_response += token
            if not bot_response.strip():
                # don't start returning something until we have the first non-whitespace char
                return

            timing.log('first_visible_bot_token')
            bot_chat_message_id = send(Response(
                response=token,
                still_thinking=False,
                actor='bot',
                # first call creates the client-side message; later calls append to it
                operation='append' if bot_chat_message_id is not None else 'create',
            ), last_chat_message_id=bot_chat_message_id, before_message_id=before_message_id)
            has_sent_bot_response = True

        new_token_handler = bot_new_token_handler
        response_buffer = ""
        # FSM: 0 = waiting for response_prefix, 1 = streaming the response
        # (evaluating fetch commands), 2 = terminal, output masked.
        response_state = 0  # finite-state machine state
        response_prefix = "## Response:"

        def injection_handler(token):
            """Per-token callback from the streaming chain: buffer raw model
            output, gate on the response prefix, evaluate inline
            ``<|...|>`` commands, and emit visible text downstream."""
            nonlocal new_token_handler, response_buffer, response_state, response_prefix

            timing.log('first_token')
            timing.log('first_widget_token')  # for comparison with basic agent

            response_buffer += token
            if response_state == 0:  # we are still waiting for response_prefix to appear
                if response_prefix not in response_buffer:
                    # keep waiting
                    return
                else:
                    # we have found the response_prefix, trim everything before that
                    timing.log('first_widget_response_token')
                    response_state = 1
                    response_buffer = response_buffer[response_buffer.index(response_prefix) + len(response_prefix):]

            if response_state == 1:  # we are going to output the response incrementally, evaluating any fetch commands
                while '<|' in response_buffer:
                    if '|>' in response_buffer:
                        # parse fetch command
                        response_buffer = iterative_evaluate(response_buffer)
                        if len(response_buffer.split('<|')) == len(response_buffer.split('|>')):
                            # matching pairs of open/close, just flush
                            # NB: for better frontend parsing of nested widgets, we need an invariant that
                            # there are no two independent widgets on the same line, otherwise we can't
                            # detect the closing tag properly when there is nesting.
                            response_buffer = response_buffer.replace('|>', '|>\n')
                            break
                        else:
                            # keep waiting
                            return
                    else:
                        # keep waiting
                        return
                # Flush whatever survived command evaluation to the client.
                token = response_buffer
                response_buffer = ""
                if token.strip():
                    timing.log('first_visible_widget_response_token')
                new_token_handler(token)
                if '\n' in token:
                    # we have found a line-break in the response, switch to the terminal state to mask subsequent output
                    response_state = 2
            # NOTE(review): state 2 intentionally falls through doing nothing,
            # discarding further tokens — confirm the reconstructed nesting of
            # the flush above matches the original file's indentation.

        # Retrieve widget candidates; retried on TypeError with backoff
        # (presumably a transient index/embedding failure mode — TODO confirm).
        widgets = retry_on_exceptions_with_backoff(
            lambda: self.widget_index.similarity_search(userinput, k=self.top_k),
            [ErrorToRetry(TypeError)],
        )
        timing.log('widget_index_lookup_done')
        task_info = '\n'.join([f'Widget: {widget.page_content}' for widget in widgets])
        example = {
            "task_info": task_info,
            "chat_history": history_string,
            "question": userinput,
            "stop": ["Input", "User"],
        }

        chain = streaming.get_streaming_chain(self.widget_prompt, injection_handler)

        with context.with_request_context(history.wallet_address, message_id):
            # NOTE(review): `result` is never used — all output reaches the
            # client through injection_handler during streaming.
            result = chain.run(example).strip()
        timing.log('response_done')

        # Finalize the streamed message (strip + replace) only if anything
        # visible was ever sent.
        if bot_chat_message_id is not None:
            bot_flush(bot_response)

        response = f'Timings - {timing.report()}'
        system_chat_message_id = send(Response(response=response, actor='system'), before_message_id=before_message_id)
        history.add_system_message(response, message_id=system_chat_message_id, before_message_id=before_message_id)

0 commit comments

Comments
 (0)