Skip to content

Commit 7d98ef9

Browse files
jkwatsonewilliams-clouderabaasitsharief
authored
Chat UI improvements (#270)
* move select a datasource * fix search * fix spacing in input * width * add inference model selection to the chat input * fix broken test * handle showing correct model before activeSession * handle tool calling enablement based on inference model * adjust width of options * fix up broken things, and provide models for projects to the chat initiator * adjust colors * capitalize * bug fix for retriever tool with no summaries * fix broken tests * split bedrock by model family --------- Co-authored-by: Elijah Williams <[email protected]> Co-authored-by: Baasit Sharief <[email protected]>
1 parent bd4ade5 commit 7d98ef9

File tree

14 files changed

+309
-112
lines changed

14 files changed

+309
-112
lines changed

llm-service/app/services/query/agents/tool_calling_querier.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -111,7 +111,8 @@ def should_use_retrieval(
111111
data_source_summary_indexer = SummaryIndexer.get_summary_indexer(data_source_id)
112112
if data_source_summary_indexer:
113113
data_source_summary = data_source_summary_indexer.get_full_summary()
114-
data_source_summaries[data_source_id] = data_source_summary
114+
if data_source_summary:
115+
data_source_summaries[data_source_id] = data_source_summary
115116
return len(data_source_ids) > 0, data_source_summaries
116117

117118

ui/src/api/modelsApi.ts

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -47,7 +47,7 @@ import {
4747
} from "src/api/utils.ts";
4848

4949
export interface Model {
50-
name?: string;
50+
name: string;
5151
model_id: string;
5252
available: boolean | null;
5353
replica_count?: number;

ui/src/pages/Analytics/AnalyticsPage.tsx

Lines changed: 9 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -45,7 +45,10 @@ import {
4545
Select,
4646
Typography,
4747
} from "antd";
48-
import { transformModelOptions } from "src/utils/modelUtils.ts";
48+
import {
49+
ModelSelectOptions,
50+
useTransformModelOptions,
51+
} from "src/utils/modelUtils.ts";
4952
import { useGetLlmModels, useGetRerankingModels } from "src/api/modelsApi.ts";
5053
import { MetricFilter } from "src/api/metricsApi.ts";
5154
import MetadataMetrics from "pages/Analytics/AppMetrics.tsx";
@@ -79,7 +82,7 @@ const SelectFilterOption = ({
7982
}: {
8083
name: string;
8184
label: string;
82-
options: { value: string; label: string | undefined }[];
85+
options: ModelSelectOptions;
8386
}) => {
8487
return (
8588
<Form.Item name={name} label={label} style={{ marginTop: 8 }}>
@@ -106,6 +109,8 @@ const MetricFilterOptions = ({
106109
const { data: llmModels } = useGetLlmModels();
107110
const { data: rerankingModels } = useGetRerankingModels();
108111
const { data: projects } = useGetProjects();
112+
const llmModelOptions = useTransformModelOptions(llmModels);
113+
const rerankModelOptions = useTransformModelOptions(rerankingModels);
109114

110115
return (
111116
<Form
@@ -117,15 +122,12 @@ const MetricFilterOptions = ({
117122
<SelectFilterOption
118123
name="inference_model"
119124
label="Response synthesizer model"
120-
options={transformModelOptions(llmModels)}
125+
options={llmModelOptions}
121126
/>
122127
<SelectFilterOption
123128
name="rerank_model"
124129
label="Reranking model"
125-
options={[
126-
...transformModelOptions(rerankingModels),
127-
{ value: "none", label: "None" },
128-
]}
130+
options={[...rerankModelOptions, { value: "none", label: "None" }]}
129131
/>
130132
<BooleanFilterOption
131133
name="use_summary_filter"

ui/src/pages/DataSources/DataSourcesManagement/DataSourcesForm.tsx

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -52,7 +52,7 @@ import {
5252
import { ConnectionType, DataSourceBaseType } from "src/api/dataSourceApi";
5353
import { useGetEmbeddingModels, useGetLlmModels } from "src/api/modelsApi.ts";
5454
import { useEffect } from "react";
55-
import { transformModelOptions } from "src/utils/modelUtils.ts";
55+
import { useTransformModelOptions } from "src/utils/modelUtils.ts";
5656
import { useNavigate } from "@tanstack/react-router";
5757
import messageQueue from "src/utils/messageQueue.ts";
5858

@@ -148,6 +148,8 @@ const DataSourcesForm = ({
148148
const embeddingsModels = useGetEmbeddingModels();
149149
const llmModels = useGetLlmModels();
150150
const navigate = useNavigate();
151+
const embeddingModelOptions = useTransformModelOptions(embeddingsModels.data);
152+
const llmModelOptions = useTransformModelOptions(llmModels.data);
151153

152154
useEffect(() => {
153155
if (initialValues.embeddingModel) {
@@ -226,7 +228,7 @@ const DataSourcesForm = ({
226228
initialValue={initialValues.embeddingModel}
227229
>
228230
<Select
229-
options={transformModelOptions(embeddingsModels.data)}
231+
options={embeddingModelOptions}
230232
disabled={updateMode}
231233
loading={embeddingsModels.isLoading}
232234
/>
@@ -246,7 +248,7 @@ const DataSourcesForm = ({
246248
initialValue={initialValues.summarizationModel}
247249
>
248250
<Select
249-
options={transformModelOptions(llmModels.data)}
251+
options={llmModelOptions}
250252
allowClear
251253
loading={llmModels.isLoading}
252254
/>

ui/src/pages/Projects/ProjectPage/NewChatSession/NewChatSession.tsx

Lines changed: 15 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -36,19 +36,18 @@
3636
* DATA.
3737
******************************************************************************/
3838

39-
import { Card, Select, Typography } from "antd";
40-
import RagChatQueryInput from "pages/RagChatTab/FooterComponents/RagChatQueryInput.tsx";
39+
import { Card, Typography } from "antd";
40+
import RagChatQueryInput, {
41+
NewSessionCallbackProps,
42+
} from "pages/RagChatTab/FooterComponents/RagChatQueryInput.tsx";
4143
import useCreateSessionAndRedirect from "pages/RagChatTab/ChatOutput/hooks/useCreateSessionAndRedirect.tsx";
4244
import { useGetDataSourcesForProject } from "src/api/projectsApi.ts";
4345
import { useProjectContext } from "pages/Projects/ProjectContext.tsx";
44-
import { DataSourceInputType, formatDataSource } from "src/utils/formatters.ts";
45-
import { useState } from "react";
4646

4747
export const NewChatSession = () => {
4848
const { project } = useProjectContext();
4949
const createSessionAndRedirect = useCreateSessionAndRedirect();
5050
const { data: dataSources } = useGetDataSourcesForProject(project.id);
51-
const [selectedDataSources, setSelectedDataSources] = useState<number[]>([]);
5251

5352
return (
5453
<Card
@@ -57,25 +56,19 @@ export const NewChatSession = () => {
5756
Start a new chat session
5857
</Typography.Title>
5958
}
60-
extra={
61-
<Select
62-
mode="multiple"
63-
placeholder="Select a data source (optional)"
64-
disabled={!dataSources || dataSources.length === 0}
65-
style={{ width: 300 }}
66-
allowClear={true}
67-
onChange={(ids: DataSourceInputType["value"][]) => {
68-
setSelectedDataSources(ids);
69-
}}
70-
options={dataSources?.map((value) => {
71-
return formatDataSource(value);
72-
})}
73-
/>
74-
}
7559
>
7660
<RagChatQueryInput
77-
newSessionCallback={(userInput: string) => {
78-
createSessionAndRedirect(selectedDataSources, userInput);
61+
validDataSources={dataSources}
62+
newSessionCallback={({
63+
userInput,
64+
selectedDataSourceIds,
65+
inferenceModel,
66+
}: NewSessionCallbackProps) => {
67+
createSessionAndRedirect(
68+
selectedDataSourceIds,
69+
userInput,
70+
inferenceModel,
71+
);
7972
}}
8073
/>
8174
</Card>

ui/src/pages/RagChatTab/ChatOutput/Placeholders/NoDataSourcesState.tsx

Lines changed: 3 additions & 48 deletions
Original file line numberDiff line numberDiff line change
@@ -37,12 +37,9 @@
3737
******************************************************************************/
3838

3939
import { useNavigate } from "@tanstack/react-router";
40-
import { Button, Flex, Form, Select, Typography } from "antd";
40+
import { Button, Flex, Typography } from "antd";
4141
import { useContext, ReactNode } from "react";
4242
import { RagChatContext } from "pages/RagChatTab/State/RagChatContext.tsx";
43-
import useCreateSessionAndRedirect from "pages/RagChatTab/ChatOutput/hooks/useCreateSessionAndRedirect.tsx";
44-
import { ArrowRightOutlined } from "@ant-design/icons";
45-
import { formatDataSource } from "src/utils/formatters.ts";
4643

4744
const PlaceholderContainer = ({
4845
message,
@@ -70,63 +67,21 @@ const PlaceholderContainer = ({
7067

7168
const NoDataSourcesState = () => {
7269
const navigate = useNavigate();
73-
const [form] = Form.useForm<{ dataSourceIds: number[] }>();
7470
const {
7571
activeSession,
7672
dataSourcesQuery: { dataSources, dataSourcesStatus },
7773
} = useContext(RagChatContext);
7874

79-
const createSessionAndRedirect = useCreateSessionAndRedirect();
80-
81-
const handleCreateSession = () => {
82-
form
83-
.validateFields()
84-
.catch(() => null)
85-
.then((values) => {
86-
if (values?.dataSourceIds.length) {
87-
createSessionAndRedirect(values.dataSourceIds);
88-
}
89-
})
90-
.catch(() => null);
91-
};
92-
9375
if (activeSession) {
9476
return null;
9577
}
9678

9779
if (dataSourcesStatus === "success" && dataSources.length > 0) {
98-
return (
99-
<PlaceholderContainer message="Start chatting with an existing Knowledge Base">
100-
<Form autoCorrect="off" form={form} clearOnDestroy={true}>
101-
<Flex gap={8}>
102-
<Form.Item
103-
name="dataSourceIds"
104-
rules={[
105-
{ required: true, message: "Please select Knowledge Base(s)" },
106-
]}
107-
>
108-
<Select
109-
mode="multiple"
110-
disabled={dataSources.length === 0}
111-
style={{ width: 300 }}
112-
options={dataSources.map((value) => {
113-
return formatDataSource(value);
114-
})}
115-
/>
116-
</Form.Item>
117-
<Button
118-
type="primary"
119-
icon={<ArrowRightOutlined />}
120-
onClick={handleCreateSession}
121-
/>
122-
</Flex>
123-
</Form>
124-
</PlaceholderContainer>
125-
);
80+
return null;
12681
}
12782

12883
return (
129-
<PlaceholderContainer message="Or create a knowledge base to chat with your documents.">
84+
<PlaceholderContainer message="Create a knowledge base to chat with your documents.">
13085
<Button
13186
type="default"
13287
style={{ width: 200 }}

ui/src/pages/RagChatTab/ChatOutput/hooks/useCreateSessionAndRedirect.tsx

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -32,17 +32,24 @@ const useCreateSessionAndRedirect = (
3232
},
3333
});
3434

35-
return (dataSourceIds: number[], question?: string) => {
35+
return (
36+
dataSourceIds: number[],
37+
question?: string,
38+
inferenceModel?: string,
39+
) => {
3640
if (models) {
41+
const supportsToolCalling = models.find(
42+
(model) => model.model_id === inferenceModel,
43+
)?.tool_calling_supported;
3744
const requestBody: CreateSessionRequest = {
3845
name: "",
3946
dataSourceIds: dataSourceIds,
40-
inferenceModel: models[0].model_id,
47+
inferenceModel: inferenceModel ?? models[0].model_id,
4148
responseChunks: 10,
4249
queryConfiguration: {
4350
enableHyde: false,
4451
enableSummaryFilter: true,
45-
enableToolCalling: false,
52+
enableToolCalling: supportsToolCalling ?? false,
4653
selectedTools: [],
4754
},
4855
embeddingModel: embeddingModels?.length

ui/src/pages/RagChatTab/FooterComponents/RagChatQueryInput.test.tsx

Lines changed: 40 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -58,6 +58,26 @@ vi.mock("src/api/chatApi.ts", () => ({
5858
})),
5959
}));
6060

61+
vi.mock("src/api/modelsApi.ts", () => ({
62+
getLlmModelsQueryOptions: {
63+
queryKey: ["llmModels"],
64+
queryFn: () =>
65+
Promise.resolve([
66+
{ model_id: "test-llm", tool_calling_supported: false },
67+
]),
68+
},
69+
useGetLlmModels: vi.fn(() => ({
70+
data: [{ model_id: "test-llm" }],
71+
isFetching: false,
72+
error: null,
73+
})),
74+
useGetModelSource: vi.fn(() => ({
75+
data: "OpenAI",
76+
isFetching: false,
77+
error: null,
78+
})),
79+
}));
80+
6181
vi.mock("src/api/ragQueryApi.ts", () => ({
6282
useSuggestQuestions: vi.fn(() => ({
6383
data: { suggested_questions: ["Sample question 1", "Sample question 2"] },
@@ -71,6 +91,16 @@ vi.mock("@tanstack/react-router", () => ({
7191
useSearch: vi.fn(() => ({ question: undefined })),
7292
}));
7393

94+
vi.mock("@tanstack/react-query", async (importOriginal) => {
95+
const actual = await importOriginal<typeof import("@tanstack/react-query")>();
96+
return {
97+
...actual,
98+
useSuspenseQuery: vi.fn(() => ({
99+
data: [{ model_id: "test-llm", tool_calling_supported: false }],
100+
})),
101+
};
102+
});
103+
74104
vi.mock("src/utils/useModal.ts", () => ({
75105
default: vi.fn(() => ({
76106
isModalOpen: false,
@@ -472,22 +502,22 @@ describe("RagChatQueryInput", () => {
472502
});
473503

474504
describe("New Session Callback", () => {
475-
it("calls newSessionCallback when no sessionId exists", async () => {
505+
it("calls newSessionCallback when no session exists", async () => {
476506
const user = userEvent.setup();
477-
const mockNewSessionCallback = vi.fn();
507+
const newSessionCallback = vi.fn();
478508
(useParams as Mock).mockReturnValue({ sessionId: undefined });
479-
480509
const mockContext = createMockContext();
481-
renderWithContext(mockContext, mockNewSessionCallback);
510+
renderWithContext(mockContext, newSessionCallback);
482511

483512
const textArea = screen.getByPlaceholderText("Ask a question");
484-
const sendButton = screen.getByRole("button", { name: /send/i });
485-
486-
await user.type(textArea, "New session query");
487-
await user.click(sendButton);
513+
await user.type(textArea, "Test query");
514+
await user.click(screen.getByRole("button", { name: /send/i }));
488515

489-
expect(mockNewSessionCallback).toHaveBeenCalledWith("New session query");
490-
expect(mockStreamingChatMutation.mutate).not.toHaveBeenCalled();
516+
expect(newSessionCallback).toHaveBeenCalledWith({
517+
userInput: "Test query",
518+
selectedDataSourceIds: [],
519+
inferenceModel: "test-llm",
520+
});
491521
});
492522
});
493523

0 commit comments

Comments
 (0)