
Commit 0900413

Add next questions suggestion to the user (#170)
--------- Co-authored-by: Marcus Schiesser <[email protected]>
1 parent 8dc6a2b commit 0900413

File tree: 13 files changed, +237 −21 lines


.changeset/tall-pans-bake.md

Lines changed: 5 additions & 0 deletions
@@ -0,0 +1,5 @@
+---
+"create-llama": patch
+---
+
+Add suggestions for next questions.

templates/components/llamaindex/typescript/streaming/stream.ts

Lines changed: 20 additions & 3 deletions
@@ -5,34 +5,51 @@ import {
   trimStartOfStreamHelper,
   type AIStreamCallbacksAndOptions,
 } from "ai";
-import { EngineResponse } from "llamaindex";
+import { ChatMessage, EngineResponse } from "llamaindex";
+import { generateNextQuestions } from "./suggestion";
 
 export function LlamaIndexStream(
   response: AsyncIterable<EngineResponse>,
   data: StreamData,
+  chatHistory: ChatMessage[],
   opts?: {
     callbacks?: AIStreamCallbacksAndOptions;
   },
 ): ReadableStream<Uint8Array> {
-  return createParser(response, data)
+  return createParser(response, data, chatHistory)
     .pipeThrough(createCallbacksTransformer(opts?.callbacks))
     .pipeThrough(createStreamDataTransformer());
 }
 
-function createParser(res: AsyncIterable<EngineResponse>, data: StreamData) {
+function createParser(
+  res: AsyncIterable<EngineResponse>,
+  data: StreamData,
+  chatHistory: ChatMessage[],
+) {
   const it = res[Symbol.asyncIterator]();
   const trimStartOfStream = trimStartOfStreamHelper();
+  let llmTextResponse = "";
 
   return new ReadableStream<string>({
     async pull(controller): Promise<void> {
       const { value, done } = await it.next();
       if (done) {
         controller.close();
+        // LLM stream is done, generate the next questions with a new LLM call
+        chatHistory.push({ role: "assistant", content: llmTextResponse });
+        const questions: string[] = await generateNextQuestions(chatHistory);
+        if (questions.length > 0) {
+          data.appendMessageAnnotation({
+            type: "suggested_questions",
+            data: questions,
+          });
+        }
         data.close();
         return;
       }
       const text = trimStartOfStream(value.delta ?? "");
       if (text) {
+        llmTextResponse += text;
         controller.enqueue(text);
       }
     },
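
For reference, a minimal sketch of the annotation payload this change appends to the Vercel/AI stream. The SuggestedQuestionsAnnotation type name and the sample questions are illustrative assumptions; only the "suggested_questions" type tag and the string-array payload come from the diff above (and from the FastAPI changes below).

// Illustrative only: shape of the message annotation appended via
// data.appendMessageAnnotation(...) and mirrored by the FastAPI template.
type SuggestedQuestionsAnnotation = {
  type: "suggested_questions";
  data: string[]; // the questions returned by generateNextQuestions
};

const exampleAnnotation: SuggestedQuestionsAnnotation = {
  type: "suggested_questions",
  data: ["What file types can I upload?", "How do I switch the LLM provider?"],
};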
templates/components/llamaindex/typescript/streaming/suggestion.ts

Lines changed: 54 additions & 0 deletions
@@ -0,0 +1,54 @@
+import { ChatMessage, Settings } from "llamaindex";
+
+const NEXT_QUESTION_PROMPT_TEMPLATE = `You're a helpful assistant! Your task is to suggest the next question that the user might ask.
+Here is the conversation history
+---------------------
+$conversation
+---------------------
+Given the conversation history, please give me $number_of_questions questions that you might ask next!
+Your answer should be wrapped in triple backticks and follow this format:
+\`\`\`
+<question 1>
+<question 2>\`\`\`
+`;
+const N_QUESTIONS_TO_GENERATE = 3;
+
+export async function generateNextQuestions(
+  conversation: ChatMessage[],
+  numberOfQuestions: number = N_QUESTIONS_TO_GENERATE,
+) {
+  const llm = Settings.llm;
+
+  // Format the conversation as "<role>: <content>" lines
+  const conversationText = conversation
+    .map((message) => `${message.role}: ${message.content}`)
+    .join("\n");
+  const message = NEXT_QUESTION_PROMPT_TEMPLATE.replace(
+    "$conversation",
+    conversationText,
+  ).replace("$number_of_questions", numberOfQuestions.toString());
+
+  try {
+    const response = await llm.complete({ prompt: message });
+    const questions = extractQuestions(response.text);
+    return questions;
+  } catch (error) {
+    console.error("Error: ", error);
+    throw error;
+  }
+}
+
+// TODO: instead of parsing the LLM's result, we can use structured predict once LITS supports it
+function extractQuestions(text: string): string[] {
+  // Extract the text inside the triple backticks
+  const contentMatch = text.match(/```(.*?)```/s);
+  const content = contentMatch ? contentMatch[1] : "";
+
+  // Split the content by newlines to get each question
+  const questions = content
+    .split("\n")
+    .map((question) => question.trim())
+    .filter((question) => question !== "");
+
+  return questions;
+}
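
A hedged usage sketch for generateNextQuestions (not part of the commit); it assumes Settings.llm has already been configured by the template's app setup, and the history content is made up.

// Illustrative usage only; ChatMessage and generateNextQuestions are the
// real exports used in this diff, the conversation content is hypothetical.
import { ChatMessage } from "llamaindex";
import { generateNextQuestions } from "./suggestion";

async function demo() {
  const chatHistory: ChatMessage[] = [
    { role: "user", content: "What does create-llama scaffold?" },
    { role: "assistant", content: "A LlamaIndex-powered chat app." },
  ];
  // Returns up to N_QUESTIONS_TO_GENERATE (3) parsed follow-up questions,
  // or [] if the model's reply contains no triple-backtick block.
  const questions = await generateNextQuestions(chatHistory);
  console.log(questions);
}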

templates/types/streaming/express/src/controllers/chat.controller.ts

Lines changed: 5 additions & 1 deletion
@@ -67,7 +67,11 @@ export const chat = async (req: Request, res: Response) => {
     });
 
     // Return a stream, which can be consumed by the Vercel/AI client
-    const stream = LlamaIndexStream(response, vercelStreamData);
+    const stream = LlamaIndexStream(
+      response,
+      vercelStreamData,
+      messages as ChatMessage[],
+    );
 
     return streamToResponse(stream, res, {}, vercelStreamData);
   } catch (error) {

templates/types/streaming/fastapi/app/api/routers/chat.py

Lines changed: 1 addition & 1 deletion
@@ -61,7 +61,7 @@ async def chat(
         response = await chat_engine.astream_chat(last_message_content, messages)
         process_response_nodes(response.source_nodes, background_tasks)
 
-        return VercelStreamResponse(request, event_handler, response)
+        return VercelStreamResponse(request, event_handler, response, data)
     except Exception as e:
         logger.exception("Error in chat engine", exc_info=True)
         raise HTTPException(

templates/types/streaming/fastapi/app/api/routers/models.py

Lines changed: 2 additions & 2 deletions
@@ -25,7 +25,7 @@ class File(BaseModel):
     filetype: str
 
 
-class AnnotationData(BaseModel):
+class AnnotationFileData(BaseModel):
     files: List[File] = Field(
         default=[],
         description="List of files",
@@ -50,7 +50,7 @@ class Config:
 
 class Annotation(BaseModel):
     type: str
-    data: AnnotationData
+    data: AnnotationFileData | List[str]
 
     def to_content(self) -> str | None:
         if self.type == "document_file":

templates/types/streaming/fastapi/app/api/routers/vercel_response.py

Lines changed: 34 additions & 11 deletions
@@ -6,7 +6,8 @@
 from llama_index.core.chat_engine.types import StreamingAgentChatResponse
 
 from app.api.routers.events import EventCallbackHandler
-from app.api.routers.models import SourceNodes
+from app.api.routers.models import ChatData, Message, SourceNodes
+from app.api.services.suggestion import NextQuestionSuggestion
 
 
 class VercelStreamResponse(StreamingResponse):
@@ -17,15 +18,6 @@ class VercelStreamResponse(StreamingResponse):
     TEXT_PREFIX = "0:"
     DATA_PREFIX = "8:"
 
-    def __init__(
-        self,
-        request: Request,
-        event_handler: EventCallbackHandler,
-        response: StreamingAgentChatResponse,
-    ):
-        content = self.content_generator(request, event_handler, response)
-        super().__init__(content=content)
-
     @classmethod
     def convert_text(cls, token: str):
         # Escape newlines and double quotes to avoid breaking the stream
@@ -37,17 +29,48 @@ def convert_data(cls, data: dict):
         data_str = json.dumps(data)
         return f"{cls.DATA_PREFIX}[{data_str}]\n"
 
+    def __init__(
+        self,
+        request: Request,
+        event_handler: EventCallbackHandler,
+        response: StreamingAgentChatResponse,
+        chat_data: ChatData,
+    ):
+        content = VercelStreamResponse.content_generator(
+            request, event_handler, response, chat_data
+        )
+        super().__init__(content=content)
+
     @classmethod
     async def content_generator(
         cls,
         request: Request,
         event_handler: EventCallbackHandler,
         response: StreamingAgentChatResponse,
+        chat_data: ChatData,
    ):
         # Yield the text response
         async def _chat_response_generator():
+            final_response = ""
             async for token in response.async_response_gen():
-                yield cls.convert_text(token)
+                final_response += token
+                yield VercelStreamResponse.convert_text(token)
+
+            # Generate questions that the user might be interested in
+            conversation = chat_data.messages + [
+                Message(role="assistant", content=final_response)
+            ]
+            questions = await NextQuestionSuggestion.suggest_next_questions(
+                conversation
+            )
+            if len(questions) > 0:
+                yield VercelStreamResponse.convert_data(
+                    {
+                        "type": "suggested_questions",
+                        "data": questions,
+                    }
+                )
+
             # the text_generator is the leading stream, once it's finished, also finish the event stream
             event_handler.is_done = True

templates/types/streaming/fastapi/app/api/services/suggestion.py

Lines changed: 48 additions & 0 deletions
@@ -0,0 +1,48 @@
+from typing import List
+
+from app.api.routers.models import Message
+from llama_index.core.prompts import PromptTemplate
+from llama_index.core.settings import Settings
+from pydantic import BaseModel
+
+NEXT_QUESTIONS_SUGGESTION_PROMPT = PromptTemplate(
+    "You're a helpful assistant! Your task is to suggest the next question that the user might ask. "
+    "\nHere is the conversation history"
+    "\n---------------------\n{conversation}\n---------------------"
+    "\nGiven the conversation history, please give me {number_of_questions} questions that you might ask next!"
+)
+N_QUESTION_TO_GENERATE = 3
+
+
+class NextQuestions(BaseModel):
+    """A list of questions that the user might ask next"""
+
+    questions: List[str]
+
+
+class NextQuestionSuggestion:
+    @staticmethod
+    async def suggest_next_questions(
+        messages: List[Message],
+        number_of_questions: int = N_QUESTION_TO_GENERATE,
+    ) -> List[str]:
+        # Reduce the cost by only using the last two messages
+        last_user_message = None
+        last_assistant_message = None
+        for message in reversed(messages):
+            if message.role == "user":
+                last_user_message = f"User: {message.content}"
+            elif message.role == "assistant":
+                last_assistant_message = f"Assistant: {message.content}"
+            if last_user_message and last_assistant_message:
+                break
+        conversation: str = f"{last_user_message}\n{last_assistant_message}"
+
+        output: NextQuestions = await Settings.llm.astructured_predict(
+            NextQuestions,
+            prompt=NEXT_QUESTIONS_SUGGESTION_PROMPT,
+            conversation=conversation,
+            number_of_questions=number_of_questions,
+        )
+
+        return output.questions

templates/types/streaming/nextjs/app/api/chat/route.ts

Lines changed: 5 additions & 1 deletion
@@ -80,7 +80,11 @@ export async function POST(request: NextRequest) {
     });
 
     // Transform LlamaIndex stream to Vercel/AI format
-    const stream = LlamaIndexStream(response, vercelStreamData);
+    const stream = LlamaIndexStream(
+      response,
+      vercelStreamData,
+      messages as ChatMessage[],
+    );
 
     // Return a StreamingTextResponse, which can be consumed by the Vercel/AI client
     return new StreamingTextResponse(stream, {}, vercelStreamData);
Lines changed: 32 additions & 0 deletions
@@ -0,0 +1,32 @@
+import { useState } from "react";
+import { ChatHandler, SuggestedQuestionsData } from "..";
+
+export function SuggestedQuestions({
+  questions,
+  append,
+}: {
+  questions: SuggestedQuestionsData;
+  append: Pick<ChatHandler, "append">["append"];
+}) {
+  const [showQuestions, setShowQuestions] = useState(questions.length > 0);
+
+  return (
+    showQuestions &&
+    append !== undefined && (
+      <div className="flex flex-col space-y-2">
+        {questions.map((question, index) => (
+          <a
+            key={index}
+            onClick={() => {
+              append({ role: "user", content: question });
+              setShowQuestions(false);
+            }}
+            className="text-sm italic hover:underline cursor-pointer"
+          >
+            {"->"} {question}
+          </a>
+        ))}
+      </div>
+    )
+  );
+}
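
A sketch of how a parent component might extract the questions this component renders; the getSuggestedQuestions helper is hypothetical (the wiring in the chat message components is not part of this excerpt), and only the "suggested_questions" annotation type and string-array payload are taken from the diff above.

// Hypothetical helper, for illustration only.
function getSuggestedQuestions(message: {
  annotations?: { type: string; data: unknown }[];
}): string[] {
  return (message.annotations ?? [])
    .filter((a) => a.type === "suggested_questions")
    .flatMap((a) => a.data as string[]);
}

// Possible usage in a message list component:
// <SuggestedQuestions questions={getSuggestedQuestions(lastMessage)} append={append} />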
