@@ -62,9 +62,8 @@


def v2_chat(
-    session_id: int, query: str, configuration: RagPredictConfiguration, user_name: str
+    session: Session, query: str, configuration: RagPredictConfiguration, user_name: str
) -> RagStudioChatMessage:
-    session = session_metadata_api.get_session(session_id)
    query_configuration = QueryConfiguration(
        top_k=session.response_chunks,
        model_name=session.inference_model,
@@ -80,7 +79,7 @@ def v2_chat(
        session, response_id, query, query_configuration, user_name
    )

-    ChatHistoryManager().append_to_history(session_id, [new_chat_message])
+    ChatHistoryManager().append_to_history(session.id, [new_chat_message])
    return new_chat_message


@@ -121,7 +120,7 @@ def _run_chat(
    relevance, faithfulness = evaluators.evaluate_response(
        query, response, session.inference_model
    )
-    response_source_nodes = format_source_nodes(response)
+    response_source_nodes = format_source_nodes(response, data_source_id)
    new_chat_message = RagStudioChatMessage(
        id=response_id,
        source_nodes=response_source_nodes,
@@ -159,7 +158,9 @@ def retrieve_chat_history(session_id: int) -> List[RagContext]:
    return history


-def format_source_nodes(response: AgentChatResponse) -> List[RagPredictSourceNode]:
+def format_source_nodes(
+    response: AgentChatResponse, data_source_id: int
+) -> List[RagPredictSourceNode]:
    response_source_nodes = []
    for source_node in response.source_nodes:
        doc_id = source_node.node.metadata.get("document_id", source_node.node.node_id)
@@ -169,6 +170,7 @@ def format_source_nodes(response: AgentChatResponse) -> List[RagPredictSourceNode]:
                    doc_id=doc_id,
                    source_file_name=source_node.node.metadata["file_name"],
                    score=source_node.score or 0.0,
+                    dataSourceId=data_source_id,
                )
            )
    response_source_nodes = sorted(
@@ -177,10 +179,32 @@ def format_source_nodes(response: AgentChatResponse) -> List[RagPredictSourceNode]:
    return response_source_nodes


+def generate_suggested_questions_direct_llm(session: Session) -> List[str]:
+    chat_history = retrieve_chat_history(session.id)
+    if not chat_history:
+        return []
+    query_str = (
+        " Give me a list of possible follow-up questions."
+        " Each question should be on a new line."
+        " There should be no more than four (4) questions."
+        " Each question should be no longer than fifteen (15) words."
+        " The response should be a bulleted list, using an asterisk (*) to denote the bullet item."
+        " Do not start like this - `Here are four questions that I can answer based on the context information`"
+        " Only return the list."
+    )
+    chat_response = llm_completion.completion(
+        session.id, query_str, session.inference_model
+    )
+    suggested_questions = process_response(chat_response.message.content)
+    return suggested_questions
+
+
def generate_suggested_questions(
    session_id: int,
) -> List[str]:
    session = session_metadata_api.get_session(session_id)
+    if len(session.data_source_ids) == 0:
+        return generate_suggested_questions_direct_llm(session)
    if len(session.data_source_ids) != 1:
        raise HTTPException(
            status_code=400,
@@ -256,14 +280,13 @@ def process_response(response: str | None) -> list[str]:


def direct_llm_chat(
-    session_id: int, query: str, user_name: str
+    session: Session, query: str, user_name: str
) -> RagStudioChatMessage:
-    session = session_metadata_api.get_session(session_id)
    response_id = str(uuid.uuid4())
    record_direct_llm_mlflow_run(response_id, session, user_name)

    chat_response = llm_completion.completion(
-        session_id, query, session.inference_model
+        session.id, query, session.inference_model
    )
    new_chat_message = RagStudioChatMessage(
        id=response_id,
@@ -277,5 +300,5 @@ def direct_llm_chat(
        timestamp=time.time(),
        condensed_question=None,
    )
-    ChatHistoryManager().append_to_history(session_id, [new_chat_message])
+    ChatHistoryManager().append_to_history(session.id, [new_chat_message])
    return new_chat_message
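Taken together, these hunks move session lookup out of the chat entry points: the caller resolves the Session once and passes the object down, instead of v2_chat and direct_llm_chat each re-fetching it by id. A minimal sketch of the resulting call-site pattern; session_metadata_api, v2_chat, and direct_llm_chat come from this diff, while the surrounding routing branch is hypothetical and shown only to illustrate the new signatures:

    # Hypothetical call site: resolve the Session once, then pass the
    # object (not the id) into the chat functions from this diff.
    session = session_metadata_api.get_session(session_id)
    if len(session.data_source_ids) == 0:
        # No data source attached: talk to the LLM directly.
        message = direct_llm_chat(session, query, user_name)
    else:
        message = v2_chat(session, query, configuration, user_name)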
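The prompt in generate_suggested_questions_direct_llm pins the model to an asterisk-bulleted list precisely so that process_response (whose signature, str | None -> list[str], is visible in the last hunk header) can split the completion back into individual questions. A rough sketch of that kind of parser, offered as an assumption; the real process_response is defined elsewhere in this file and may differ:

    # Sketch only: a line-based parser for the asterisk-bulleted format
    # the prompt requests. The actual process_response may differ.
    def parse_bulleted_list(response: str | None) -> list[str]:
        if response is None:
            return []
        return [
            line.strip().lstrip("*").strip()  # drop the "* " bullet marker
            for line in response.splitlines()
            if line.strip().startswith("*")   # keep only bullet lines
        ]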