diff --git a/frontend/src/pages/WorkspaceSettings/ChatSettings/QueryRewriteMode/index.jsx b/frontend/src/pages/WorkspaceSettings/ChatSettings/QueryRewriteMode/index.jsx new file mode 100644 index 00000000000..98f52b532e8 --- /dev/null +++ b/frontend/src/pages/WorkspaceSettings/ChatSettings/QueryRewriteMode/index.jsx @@ -0,0 +1,46 @@ +import { useState } from "react"; + +const hint = { + off: { + title: "Off", + description: + "Follow-up queries are sent to vector search as-is. This is the default behavior.", + }, + on: { + title: "On", + description: + "Follow-up queries are rewritten into standalone search queries using chat history, improving RAG results in multi-turn conversations.", + }, +}; + +export default function QueryRewriteMode({ workspace, setHasChanges }) { + const [selection, setSelection] = useState( + workspace?.queryRewriteMode ?? "off" + ); + + return ( +
+
+ +
+ +

+ {hint[selection]?.description} +

+
+ ); +} diff --git a/frontend/src/pages/WorkspaceSettings/ChatSettings/index.jsx b/frontend/src/pages/WorkspaceSettings/ChatSettings/index.jsx index e133dfacd62..5e743c09209 100644 --- a/frontend/src/pages/WorkspaceSettings/ChatSettings/index.jsx +++ b/frontend/src/pages/WorkspaceSettings/ChatSettings/index.jsx @@ -9,6 +9,7 @@ import ChatTemperatureSettings from "./ChatTemperatureSettings"; import ChatModeSelection from "./ChatModeSelection"; import WorkspaceLLMSelection from "./WorkspaceLLMSelection"; import ChatQueryRefusalResponse from "./ChatQueryRefusalResponse"; +import QueryRewriteMode from "./QueryRewriteMode"; import CTAButton from "@/components/lib/CTAButton"; export default function ChatSettings({ workspace }) { @@ -84,6 +85,10 @@ export default function ChatSettings({ workspace }) { workspace={workspace} setHasChanges={setHasChanges} /> + { + if ( + !value || + typeof value !== "string" || + !["on", "off"].includes(value) + ) + return process.env.ENABLE_QUERY_REWRITING === "true" ? "on" : "off"; + return value; + }, }, /** diff --git a/server/prisma/schema.prisma b/server/prisma/schema.prisma index db5c9c79c6f..f1395185bd4 100644 --- a/server/prisma/schema.prisma +++ b/server/prisma/schema.prisma @@ -145,6 +145,7 @@ model workspaces { agentModel String? queryRefusalResponse String? vectorSearchMode String? @default("default") + queryRewriteMode String? 
@default("off") workspace_users workspace_users[] documents workspace_documents[] workspace_suggested_messages workspace_suggested_messages[] diff --git a/server/utils/chats/apiChatHandler.js b/server/utils/chats/apiChatHandler.js index ff9ee101f70..3eb38198b95 100644 --- a/server/utils/chats/apiChatHandler.js +++ b/server/utils/chats/apiChatHandler.js @@ -13,6 +13,7 @@ const { EphemeralAgentHandler, EphemeralEventListener, } = require("../agents/ephemeral"); +const { rewriteQueryForSearch } = require("../helpers/chat/queryRewriter"); const { Telemetry } = require("../../models/telemetry"); const { CollectorApi } = require("../collectorApi"); const fs = require("fs"); @@ -292,11 +293,18 @@ async function chatSync({ } }); + const searchQuery = await rewriteQueryForSearch({ + userQuery: message, + chatHistory, + LLMConnector, + workspace, + }); + const vectorSearchResults = embeddingsCount !== 0 ? await VectorDb.performSimilaritySearch({ namespace: workspace.slug, - input: message, + input: searchQuery, LLMConnector, similarityThreshold: workspace?.similarityThreshold, topN: workspace?.topN, @@ -644,11 +652,18 @@ async function streamChat({ } }); + const searchQuery = await rewriteQueryForSearch({ + userQuery: message, + chatHistory, + LLMConnector, + workspace, + }); + const vectorSearchResults = embeddingsCount !== 0 ? 
await VectorDb.performSimilaritySearch({ namespace: workspace.slug, - input: message, + input: searchQuery, LLMConnector, similarityThreshold: workspace?.similarityThreshold, topN: workspace?.topN, diff --git a/server/utils/chats/embed.js b/server/utils/chats/embed.js index d9320241b35..18ab4b71300 100644 --- a/server/utils/chats/embed.js +++ b/server/utils/chats/embed.js @@ -7,6 +7,7 @@ const { writeResponseChunk, } = require("../helpers/chat/responses"); const { DocumentManager } = require("../DocumentManager"); +const { rewriteQueryForSearch } = require("../helpers/chat/queryRewriter"); async function streamChatWithForEmbed( response, @@ -84,11 +85,18 @@ async function streamChatWithForEmbed( }); }); + const searchQuery = await rewriteQueryForSearch({ + userQuery: message, + chatHistory, + LLMConnector, + workspace: embed.workspace, + }); + const vectorSearchResults = embeddingsCount !== 0 ? await VectorDb.performSimilaritySearch({ namespace: embed.workspace.slug, - input: message, + input: searchQuery, LLMConnector, similarityThreshold: embed.workspace?.similarityThreshold, topN: embed.workspace?.topN, diff --git a/server/utils/chats/openaiCompatible.js b/server/utils/chats/openaiCompatible.js index 0767dd44cf9..c449a8c4b69 100644 --- a/server/utils/chats/openaiCompatible.js +++ b/server/utils/chats/openaiCompatible.js @@ -4,6 +4,7 @@ const { WorkspaceChats } = require("../../models/workspaceChats"); const { getVectorDbClass, getLLMProvider } = require("../helpers"); const { writeResponseChunk } = require("../helpers/chat/responses"); const { chatPrompt, sourceIdentifier } = require("./index"); +const { rewriteQueryForSearch } = require("../helpers/chat/queryRewriter"); const { PassThrough } = require("stream"); @@ -82,11 +83,18 @@ async function chatSync({ }); }); + const searchQuery = await rewriteQueryForSearch({ + userQuery: String(prompt), + chatHistory: history, + LLMConnector, + workspace, + }); + const vectorSearchResults = embeddingsCount !== 0 ? 
await VectorDb.performSimilaritySearch({ namespace: workspace.slug, - input: String(prompt), + input: searchQuery, LLMConnector, similarityThreshold: workspace?.similarityThreshold, topN: workspace?.topN, @@ -308,11 +316,18 @@ async function streamChat({ }); }); + const searchQuery = await rewriteQueryForSearch({ + userQuery: String(prompt), + chatHistory: history, + LLMConnector, + workspace, + }); + const vectorSearchResults = embeddingsCount !== 0 ? await VectorDb.performSimilaritySearch({ namespace: workspace.slug, - input: String(prompt), + input: searchQuery, LLMConnector, similarityThreshold: workspace?.similarityThreshold, topN: workspace?.topN, diff --git a/server/utils/chats/stream.js b/server/utils/chats/stream.js index acb1e4a5c8a..05179290cd3 100644 --- a/server/utils/chats/stream.js +++ b/server/utils/chats/stream.js @@ -12,6 +12,7 @@ const { recentChatHistory, sourceIdentifier, } = require("./index"); +const { rewriteQueryForSearch } = require("../helpers/chat/queryRewriter"); const VALID_CHAT_MODE = ["chat", "query"]; @@ -147,11 +148,18 @@ async function streamChatWithWorkspace( }); }); + const searchQuery = await rewriteQueryForSearch({ + userQuery: updatedMessage, + chatHistory, + LLMConnector, + workspace, + }); + const vectorSearchResults = embeddingsCount !== 0 ? await VectorDb.performSimilaritySearch({ namespace: workspace.slug, - input: updatedMessage, + input: searchQuery, LLMConnector, similarityThreshold: workspace?.similarityThreshold, topN: workspace?.topN, diff --git a/server/utils/helpers/chat/queryRewriter.js b/server/utils/helpers/chat/queryRewriter.js new file mode 100644 index 00000000000..341dfbf3b6e --- /dev/null +++ b/server/utils/helpers/chat/queryRewriter.js @@ -0,0 +1,84 @@ +const REWRITE_PROMPT = `Given a chat history and the latest user question which might reference context in the chat history, determine if the question needs to be reformulated to be understood without the chat history. 
/**
 * Query rewriting for RAG vector search.
 *
 * In multi-turn conversations a follow-up like "what about the second one?"
 * is a poor vector-search query on its own. When the workspace has
 * queryRewriteMode === "on", the workspace LLM reformulates the follow-up
 * into a standalone query using recent chat history, and that rewritten text
 * is used as the similarity-search input. On any failure we fall back to the
 * original user query — rewriting is strictly best-effort.
 */

const REWRITE_PROMPT = `Given a chat history and the latest user question which might reference context in the chat history, determine if the question needs to be reformulated to be understood without the chat history.

If the question already contains its own subject or topic, return it EXACTLY as written — do not rephrase, expand, or modify it in any way.

Only reformulate when the question contains pronouns (it, that, they), demonstratives (this, these), or incomplete references (the first one, the second) that refer to the chat history.

When reformulating, respond with ONLY the reformulated question (max 15 words).
Include the key subject/topic from conversation history. Do NOT add information not present in the conversation.
Write in the same language as the user.`;

/**
 * Decide whether a rewrite should even be attempted.
 * @param {string} userQuery - Latest user message.
 * @param {{role: string, content: string}[]} chatHistory - Prior turns.
 * @param {object} workspace - Workspace settings; reads `queryRewriteMode`.
 * @returns {boolean} true when the query is a short follow-up in an "on" workspace.
 */
function shouldRewrite(userQuery, chatHistory, workspace) {
  // Robustness: non-string or blank queries can never be rewritten.
  if (typeof userQuery !== "string" || userQuery.trim().length === 0)
    return false;
  // No history → nothing to resolve references against.
  if (!chatHistory || chatHistory.length === 0) return false;
  if ((workspace?.queryRewriteMode ?? "off") !== "on") return false;
  // Only short queries tend to be context-dependent follow-ups; longer ones
  // usually carry their own subject. Threshold is tunable via env.
  // Explicit radix 10 — a NaN result falls through to the default via ||.
  const threshold =
    parseInt(process.env.QUERY_REWRITE_WORD_THRESHOLD, 10) || 12;
  const wordCount = userQuery.trim().split(/\s+/).length;
  return wordCount <= threshold;
}

/**
 * Rewrite a follow-up user query into a standalone vector-search query.
 * @param {object} args
 * @param {string} args.userQuery - Latest user message.
 * @param {{role: string, content: string}[]} args.chatHistory - Prior turns.
 * @param {object} args.LLMConnector - Provider exposing getChatCompletion().
 * @param {object} args.workspace - Workspace settings.
 * @returns {Promise<string>} Rewritten query, or userQuery unchanged on skip/failure.
 */
async function rewriteQueryForSearch({
  userQuery,
  chatHistory,
  LLMConnector,
  workspace,
}) {
  if (!shouldRewrite(userQuery, chatHistory, workspace)) return userQuery;

  try {
    const maxTurns = parseInt(process.env.QUERY_REWRITE_MAX_HISTORY, 10) || 2;
    const recentHistory = chatHistory.slice(-maxTurns * 2); // 2 msgs per turn
    const historyText = recentHistory
      .map((m) => {
        // Defensive: history content may not be a plain string (e.g. multimodal
        // message arrays) — coerce rather than throwing into the catch below.
        let content =
          typeof m.content === "string" ? m.content : String(m.content ?? "");
        // Truncate long assistant messages to keep the rewrite prompt compact.
        // Keep the beginning — it usually contains the most relevant topic context.
        if (m.role === "assistant") {
          const words = content.split(/\s+/);
          if (words.length > 150)
            content = words.slice(0, 150).join(" ") + "...";
        }
        return `${m.role === "user" ? "User" : "Assistant"}: ${content}`;
      })
      .join("\n");

    const start = Date.now();
    const result = await LLMConnector.getChatCompletion(
      [
        { role: "system", content: REWRITE_PROMPT },
        {
          role: "user",
          content: `Chat history:\n${historyText}\n\nLatest question: ${userQuery}`,
        },
      ],
      { temperature: 0.1 }
    );

    const rewritten = result?.textResponse
      ?.trim()
      // Strip <think>...</think> reasoning blocks from reasoning models.
      // Anchored on the opening tag so ordinary text that merely precedes a
      // stray closing tag is never discarded.
      ?.replace(/<think>[\s\S]*?<\/think>/g, "")
      ?.trim() // Re-trim so a stripped leading think-block leaves the answer, not "".
      ?.split("\n")[0] // Take first line only
      ?.trim();

    const elapsed = Date.now() - start;
    console.log(
      `\x1b[35m[QueryRewrite]\x1b[0m "${userQuery}" → "${rewritten}" (${elapsed}ms)`
    );

    if (!rewritten) return userQuery;

    // Verbatim check — LLM returned the query as-is (self-contained).
    if (rewritten.trim().toLowerCase() === userQuery.trim().toLowerCase())
      return userQuery;

    return rewritten;
  } catch (error) {
    console.error(
      "[QueryRewrite] Failed, using original query:",
      error.message
    );
    return userQuery;
  }
}

// Guarded so the file also loads cleanly under ESM tooling; under Node's
// CommonJS loader `module` is always defined, so this is the normal export.
if (typeof module !== "undefined" && module.exports) {
  module.exports = { rewriteQueryForSearch };
}