diff --git a/docs/docs/server/openapi.yaml b/docs/docs/server/openapi.yaml index 8c4ccf748..dcf446273 100644 --- a/docs/docs/server/openapi.yaml +++ b/docs/docs/server/openapi.yaml @@ -17,9 +17,104 @@ servers: security: - CustomHeaderAuth: [] paths: + /content/search: + post: + operationId: searchContent + tags: + - Content + summary: Search content + requestBody: + required: true + content: + application/json: + schema: + type: object + properties: + query: + type: string + description: The search query string. + dataSources: + type: array + items: + type: object + properties: + name: + type: string + type: + type: string + versionLabel: + type: string + required: [name] + description: An array of data sources to search. If not provided, latest version of all data sources will be searched. + limit: + type: integer + minimum: 1 + maximum: 100 + default: 5 + description: The maximum number of results to return. + required: + - query + responses: + 200: + description: OK + headers: + Content-Type: + schema: + type: string + example: application/json + content: + application/json: + schema: + $ref: "#/components/schemas/SearchResponse" + 400: + description: Bad Request + headers: + Content-Type: + schema: + type: string + example: application/json + content: + application/json: + schema: + $ref: "#/components/responses/BadRequest" + 500: + description: Internal Server Error + headers: + Content-Type: + schema: + type: string + example: application/json + content: + application/json: + schema: + $ref: "#/components/responses/InternalServerError" + + /content/sources: + get: + operationId: listDataSources + tags: + - Content + summary: List available data sources + description: Returns metadata about all available data sources. + responses: + 200: + description: OK + content: + application/json: + schema: + $ref: "#/components/schemas/ListDataSourcesResponse" + 500: + description: Internal Server Error + content: + application/json: + schema: + $ref: "#/components/schemas/ErrorResponse" + /conversations: post: operationId: createConversation + tags: + - Conversations summary: Start new conversation description: | Start a new conversation. @@ -43,6 +138,8 @@ paths: /conversations/{conversationId}/messages: post: operationId: addMessage + tags: + - Conversations summary: Add message to the conversation tags: - Conversations @@ -133,6 +230,8 @@ paths: # /conversations/{conversationId}: # get: # operationId: getConversation + # tags: + # - Conversations # summary: Get a conversation # parameters: # - $ref: "#/components/parameters/conversationId" @@ -152,6 +251,8 @@ paths: /conversations/{conversationId}/messages/{messageId}/rating: post: operationId: rateMessage + tags: + - Conversations summary: Rate message tags: - Conversations @@ -176,6 +277,8 @@ paths: /conversations/{conversationId}/messages/{messageId}/comment: post: operationId: commentMessage + tags: + - Conversations summary: Add comment to assistant message tags: - Conversations @@ -972,6 +1075,68 @@ components: required: [type, call_id, output, status] ErrorResponse: type: object + SearchResponse: + type: object + properties: + results: + type: array + items: + $ref: "#/components/schemas/Chunk" + Chunk: + type: object + properties: + url: + type: string + description: The URL of the search result. + title: + type: string + description: Title of the search result. + text: + type: string + description: Chunk text + metadata: + type: object + properties: + sourceName: + type: string + description: The name of the source. + sourceType: + type: string + tags: + type: array + items: + type: string + additionalProperties: true + ListDataSourcesResponse: + type: object + properties: + dataSources: + type: array + items: + $ref: "#/components/schemas/DataSourceMetadata" + DataSourceMetadata: + type: object + required: + - id + properties: + id: + type: string + description: The name of the data source. + versions: + type: array + items: + type: object + properties: + label: + type: string + description: Version label + isCurrent: + type: boolean + description: Whether this version is current active version. + description: List of versions for this data source. + type: + type: string + description: The type of the data source. parameters: conversationId: name: conversationId @@ -987,3 +1152,19 @@ components: schema: type: string description: The unique identifier for a message. + +tags: + - name: Content + x-displayName: Search Content + description: Search MongoDB content + - name: Conversations + x-displayName: Conversations + description: Interact with MongoDB Chatbot + +x-tagGroups: + - name: Content + tags: + - Content + - name: Conversations + tags: + - Conversations \ No newline at end of file diff --git a/packages/chatbot-server-mongodb-public/src/config.ts b/packages/chatbot-server-mongodb-public/src/config.ts index 1cf7b5df1..ac00ab970 100644 --- a/packages/chatbot-server-mongodb-public/src/config.ts +++ b/packages/chatbot-server-mongodb-public/src/config.ts @@ -17,9 +17,10 @@ import { AddCustomDataFunc, FilterPreviousMessages, makeDefaultFindVerifiedAnswer, - defaultCreateConversationCustomData, - defaultAddMessageToConversationCustomData, makeVerifiedAnswerGenerateResponse, + addDefaultCustomData, + ConversationsRouterLocals, + ContentRouterLocals, addMessageToConversationVerifiedAnswerStream, responsesVerifiedAnswerStream, type MakeVerifiedAnswerGenerateResponseParams, @@ -34,8 +35,18 @@ import { import { redactConnectionUri } from "./middleware/redactConnectionUri"; import path from "path"; import express from "express"; -import { makeMongoDbPageStore, logger } from "mongodb-rag-core"; -import { wrapOpenAI, wrapTraced } from "mongodb-rag-core/braintrust"; +import { + makeMongoDbPageStore, + makeMongoDbSearchResultsStore, + logger, +} from "mongodb-rag-core"; +import { createAzure, wrapLanguageModel } from "mongodb-rag-core/aiSdk"; +import { + makeBraintrustLogger, + BraintrustMiddleware, + wrapOpenAI, + wrapTraced, +} from "mongodb-rag-core/braintrust"; import { AzureOpenAI } from "mongodb-rag-core/openai"; import { MongoClient } from "mongodb-rag-core/mongodb"; import { @@ -57,13 +68,9 @@ import { responsesApiStream, addMessageToConversationStream, } from "./processors/generateResponseWithTools"; -import { - makeBraintrustLogger, - BraintrustMiddleware, -} from "mongodb-rag-core/braintrust"; import { makeMongoDbScrubbedMessageStore } from "./tracing/scrubbedMessages/MongoDbScrubbedMessageStore"; import { MessageAnalysis } from "./tracing/scrubbedMessages/analyzeMessage"; -import { createAzure, wrapLanguageModel } from "mongodb-rag-core/aiSdk"; +import { makeFindContentWithMongoDbMetadata } from "./processors/findContentWithMongoDbMetadata"; import { makeMongoDbAssistantSystemPrompt } from "./systemPrompt"; import { makeFetchPageTool } from "./tools/fetchPage"; import { makeCorsOptions } from "./corsOptions"; @@ -129,6 +136,11 @@ export const embeddedContentStore = makeMongoDbEmbeddedContentStore({ }, }); +export const searchResultsStore = makeMongoDbSearchResultsStore({ + connectionUri: MONGODB_CONNECTION_URI, + databaseName: MONGODB_DATABASE_NAME, +}); + export const verifiedAnswerConfig = { embeddingModel: OPENAI_VERIFIED_ANSWER_EMBEDDING_DEPLOYMENT, findNearestNeighborsOptions: { @@ -306,7 +318,7 @@ export const makeGenerateResponse = (args?: MakeGenerateResponseParams) => export const createConversationCustomDataWithAuthUser: AddCustomDataFunc = async (req, res) => { - const customData = await defaultCreateConversationCustomData(req, res); + const customData = await addDefaultCustomData(req, res); if (req.cookies.auth_user) { customData.authUser = req.cookies.auth_user; } @@ -350,11 +362,20 @@ export async function closeDbConnections() { logger.info(`Segment logging is ${segmentConfig ? "enabled" : "disabled"}`); export const config: AppConfig = { + contentRouterConfig: { + findContent: makeFindContentWithMongoDbMetadata({ + findContent, + classifierModel: languageModel, + }), + searchResultsStore, + embeddedContentStore, + middleware: [requireValidIpAddress(), requireRequestOrigin()], + }, conversationsRouterConfig: { middleware: [ blockGetRequests, - requireValidIpAddress(), - requireRequestOrigin(), + requireValidIpAddress(), + requireRequestOrigin(), useSegmentIds(), redactConnectionUri(), cookieParser(), @@ -363,10 +384,7 @@ export const config: AppConfig = { ? createConversationCustomDataWithAuthUser : undefined, addMessageToConversationCustomData: async (req, res) => { - const defaultCustomData = await defaultAddMessageToConversationCustomData( - req, - res - ); + const defaultCustomData = await addDefaultCustomData(req, res); const customData = { ...defaultCustomData, }; diff --git a/packages/chatbot-server-mongodb-public/src/processors/findContentWithMongoDbMetadata.test.ts b/packages/chatbot-server-mongodb-public/src/processors/findContentWithMongoDbMetadata.test.ts new file mode 100644 index 000000000..547254a6f --- /dev/null +++ b/packages/chatbot-server-mongodb-public/src/processors/findContentWithMongoDbMetadata.test.ts @@ -0,0 +1,81 @@ +// Mocks +jest.mock("mongodb-rag-core/mongoDbMetadata", () => { + const actual = jest.requireActual("mongodb-rag-core/mongoDbMetadata"); + return { + ...actual, + classifyMongoDbProgrammingLanguageAndProduct: jest.fn(), + }; +}); + +jest.mock("mongodb-rag-core", () => { + const actual = jest.requireActual("mongodb-rag-core"); + return { + ...actual, + updateFrontMatter: jest.fn(), + }; +}); + +import { FindContentFunc, updateFrontMatter } from "mongodb-rag-core"; +import { + makeFindContentWithMongoDbMetadata, +} from "./findContentWithMongoDbMetadata"; +import { classifyMongoDbProgrammingLanguageAndProduct } from "mongodb-rag-core/mongoDbMetadata"; + + +const mockedClassify = + classifyMongoDbProgrammingLanguageAndProduct as jest.Mock; +const mockedUpdateFrontMatter = updateFrontMatter as jest.Mock; + +function makeMockFindContent(result: string[]): FindContentFunc { + return jest.fn().mockResolvedValue(result); +} + +afterEach(() => { + jest.resetAllMocks(); +}); + +describe("makeFindContentWithMongoDbMetadata", () => { + test("enhances query with front matter and classification", async () => { + const inputQuery = "How do I use MongoDB with TypeScript?"; + const expectedQuery = `--- +product: driver +programmingLanguage: typescript +--- +How do I use MongoDB with TypeScript?`; + const fakeResult = ["doc1", "doc2"]; + + mockedClassify.mockResolvedValue({ + product: "driver", + programmingLanguage: "typescript", + }); + mockedUpdateFrontMatter.mockReturnValue(expectedQuery); + + const findContentMock = makeMockFindContent(fakeResult); + + const wrappedFindContent = makeFindContentWithMongoDbMetadata({ + findContent: findContentMock, + // eslint-disable-next-line @typescript-eslint/no-explicit-any + classifierModel: {} as any, + }); + + const result = await wrappedFindContent({ + query: inputQuery, + filters: { sourceName: ["docs"] }, + limit: 3, + }); + + expect(mockedClassify).toHaveBeenCalledWith(expect.anything(), inputQuery); + expect(mockedUpdateFrontMatter).toHaveBeenCalledWith(inputQuery, { + product: "driver", + programmingLanguage: "typescript", + }); + + expect(findContentMock).toHaveBeenCalledWith({ + query: expectedQuery, + filters: { sourceName: ["docs"] }, + limit: 3, + }); + + expect(result).toEqual(fakeResult); + }); +}); diff --git a/packages/chatbot-server-mongodb-public/src/processors/findContentWithMongoDbMetadata.ts b/packages/chatbot-server-mongodb-public/src/processors/findContentWithMongoDbMetadata.ts new file mode 100644 index 000000000..854038b57 --- /dev/null +++ b/packages/chatbot-server-mongodb-public/src/processors/findContentWithMongoDbMetadata.ts @@ -0,0 +1,38 @@ +import { FindContentFunc, updateFrontMatter } from "mongodb-rag-core"; +import { LanguageModel } from "mongodb-rag-core/aiSdk"; +import { wrapTraced } from "mongodb-rag-core/braintrust"; +import { classifyMongoDbProgrammingLanguageAndProduct } from "mongodb-rag-core/mongoDbMetadata"; + +export const makeFindContentWithMongoDbMetadata = ({ + findContent, + classifierModel, +}: { + findContent: FindContentFunc; + classifierModel: LanguageModel; +}) => { + const wrappedFindContent: FindContentFunc = wrapTraced( + async ({ query, filters, limit }) => { + const { product, programmingLanguage } = + await classifyMongoDbProgrammingLanguageAndProduct( + classifierModel, + query + ); + + const preProcessedQuery = updateFrontMatter(query, { + ...(product ? { product } : {}), + ...(programmingLanguage ? { programmingLanguage } : {}), + }); + + const res = await findContent({ + query: preProcessedQuery, + filters, + limit, + }); + return res; + }, + { + name: "findContentWithMongoDbMetadata", + } + ); + return wrappedFindContent; +}; diff --git a/packages/mongodb-artifact-generator/src/vectorSearch.ts b/packages/mongodb-artifact-generator/src/vectorSearch.ts index d8cfd3075..69e04a51f 100644 --- a/packages/mongodb-artifact-generator/src/vectorSearch.ts +++ b/packages/mongodb-artifact-generator/src/vectorSearch.ts @@ -24,7 +24,6 @@ export type MakeFindContentArgs = { embedder: Embedder; embeddedContentStore: EmbeddedContentStore; findNearestNeighborsOptions?: Partial; - // searchBoosters?: SearchBooster[]; }; export function makeFindContent({ diff --git a/packages/mongodb-chatbot-server/src/app.ts b/packages/mongodb-chatbot-server/src/app.ts index f2f65881e..3ee517791 100644 --- a/packages/mongodb-chatbot-server/src/app.ts +++ b/packages/mongodb-chatbot-server/src/app.ts @@ -18,6 +18,7 @@ import { logger } from "mongodb-rag-core"; import { ObjectId } from "mongodb-rag-core/mongodb"; import { getRequestId, logRequest, sendErrorResponse } from "./utils"; import cloneDeep from "lodash.clonedeep"; +import { makeContentRouter, MakeContentRouterParams } from "./routes"; /** Configuration for the server Express.js app. @@ -28,6 +29,11 @@ export interface AppConfig { */ conversationsRouterConfig: ConversationsRouterParams; + /** + Configuration for the content router. + */ + contentRouterConfig?: MakeContentRouterParams; + /** Configuration for the responses router. */ @@ -134,6 +140,7 @@ export const makeApp = async (config: AppConfig): Promise => { corsOptions, apiPrefix = DEFAULT_API_PREFIX, expressAppConfig, + contentRouterConfig, } = config; logger.info("Server has the following configuration:"); logger.info( @@ -157,6 +164,10 @@ export const makeApp = async (config: AppConfig): Promise => { ); app.use(`${apiPrefix}/responses`, makeResponsesRouter(responsesRouterConfig)); + if (contentRouterConfig) { + app.use(`${apiPrefix}/content`, makeContentRouter(contentRouterConfig)); + } + app.get("/health", (_req, res) => { const data = { uptime: process.uptime(), diff --git a/packages/mongodb-chatbot-server/src/middleware/requireRequestOrigin.ts b/packages/mongodb-chatbot-server/src/middleware/requireRequestOrigin.ts index 1f6567e04..61e6b68b7 100644 --- a/packages/mongodb-chatbot-server/src/middleware/requireRequestOrigin.ts +++ b/packages/mongodb-chatbot-server/src/middleware/requireRequestOrigin.ts @@ -1,9 +1,13 @@ +import { RequestHandler } from "express"; +import { ParamsDictionary } from "express-serve-static-core"; +import { ParsedQs } from "qs"; import { getRequestId, logRequest, sendErrorResponse } from "../utils"; -import { ConversationsMiddleware } from "../routes/conversations/conversationsRouter"; export const CUSTOM_REQUEST_ORIGIN_HEADER = "X-Request-Origin"; -export function requireRequestOrigin(): ConversationsMiddleware { +export function requireRequestOrigin< + Locals extends Record +>(): RequestHandler { return (req, res, next) => { const reqId = getRequestId(req); diff --git a/packages/mongodb-chatbot-server/src/middleware/requireValidIpAddress.test.ts b/packages/mongodb-chatbot-server/src/middleware/requireValidIpAddress.test.ts index 6f946c24a..8bb027e1e 100644 --- a/packages/mongodb-chatbot-server/src/middleware/requireValidIpAddress.test.ts +++ b/packages/mongodb-chatbot-server/src/middleware/requireValidIpAddress.test.ts @@ -3,6 +3,7 @@ import { createConversationsMiddlewareReq, createConversationsMiddlewareRes, } from "../test/middlewareTestHelpers"; +import { ConversationsRouterLocals } from "../routes"; const baseReq = { body: { message: "Hello, world!" }, @@ -18,7 +19,7 @@ describe("requireValidIpAddress", () => { const res = createConversationsMiddlewareRes(); const next = jest.fn(); - const middleware = requireValidIpAddress(); + const middleware = requireValidIpAddress(); req.body = baseReq.body; req.params = baseReq.params; req.query = baseReq.query; @@ -39,7 +40,7 @@ describe("requireValidIpAddress", () => { const next = jest.fn(); const invalidIpAddress = "not-an-ip-address"; - const middleware = requireValidIpAddress(); + const middleware = requireValidIpAddress(); req.body = baseReq.body; req.params = baseReq.params; req.query = baseReq.query; @@ -59,7 +60,7 @@ describe("requireValidIpAddress", () => { const res = createConversationsMiddlewareRes(); const next = jest.fn(); - const middleware = requireValidIpAddress(); + const middleware = requireValidIpAddress(); req.body = baseReq.body; req.params = baseReq.params; req.query = baseReq.query; diff --git a/packages/mongodb-chatbot-server/src/middleware/requireValidIpAddress.ts b/packages/mongodb-chatbot-server/src/middleware/requireValidIpAddress.ts index a4e3914dc..41dd77f76 100644 --- a/packages/mongodb-chatbot-server/src/middleware/requireValidIpAddress.ts +++ b/packages/mongodb-chatbot-server/src/middleware/requireValidIpAddress.ts @@ -1,8 +1,12 @@ +import { RequestHandler } from "express"; +import { ParamsDictionary } from "express-serve-static-core"; +import { ParsedQs } from "qs"; import { getRequestId, logRequest, sendErrorResponse } from "../utils"; -import { ConversationsMiddleware } from "../routes/conversations/conversationsRouter"; import { isValidIp } from "../routes/conversations/utils"; -export function requireValidIpAddress(): ConversationsMiddleware { +export function requireValidIpAddress< + Locals extends Record +>(): RequestHandler { return (req, res, next) => { const reqId = getRequestId(req); diff --git a/packages/mongodb-chatbot-server/src/middleware/validateRequestSchema.ts b/packages/mongodb-chatbot-server/src/middleware/validateRequestSchema.ts index 0fd952675..608cf6612 100644 --- a/packages/mongodb-chatbot-server/src/middleware/validateRequestSchema.ts +++ b/packages/mongodb-chatbot-server/src/middleware/validateRequestSchema.ts @@ -5,12 +5,12 @@ import { getRequestId, logRequest, sendErrorResponse } from "../utils"; export const SomeExpressRequest = z.object({ headers: z.object({}).optional(), - params: z.object({}).optional(), + params: z.object({}), query: z.object({}).optional(), body: z.object({}).optional(), }); -function generateZodErrorMessage(error: ZodError) { +export function generateZodErrorMessage(error: ZodError) { return generateErrorMessage(error.issues, { delimiter: { error: "\n", diff --git a/packages/mongodb-chatbot-server/src/processors/addCustomData.ts b/packages/mongodb-chatbot-server/src/processors/addCustomData.ts new file mode 100644 index 000000000..6b66f194b --- /dev/null +++ b/packages/mongodb-chatbot-server/src/processors/addCustomData.ts @@ -0,0 +1,94 @@ +import { Request, Response } from "express"; + +export type RequestCustomData = Record | undefined; + +/** + Function to add custom data to the {@link Conversation} or content search Request persisted to the database. + Has access to the Express.js request and response plus the values + from the {@link Response.locals} object. + */ +export type AddCustomDataFunc = ( + request: Request, + response: Response +) => Promise; + +const addIpToCustomData: AddCustomDataFunc = async (req) => + req.ip + ? { + ip: req.ip, + } + : undefined; + +const addOriginToCustomData: AddCustomDataFunc = async (_, res) => + res.locals.customData.origin + ? { + origin: res.locals.customData.origin, + } + : undefined; + +export const originCodes = [ + "LEARN", + "DEVELOPER", + "DOCS", + "DOTCOM", + "GEMINI_CODE_ASSIST", + "VSCODE", + "OTHER", +] as const; + +export type OriginCode = (typeof originCodes)[number]; + +interface OriginRule { + regex: RegExp; + code: OriginCode; +} + +const ORIGIN_RULES: OriginRule[] = [ + { regex: /learn\.mongodb\.com/, code: "LEARN" }, + { regex: /mongodb\.com\/developer/, code: "DEVELOPER" }, + { regex: /mongodb\.com\/docs/, code: "DOCS" }, + { regex: /mongodb\.com\//, code: "DOTCOM" }, + { regex: /google-gemini-code-assist/, code: "GEMINI_CODE_ASSIST" }, + { regex: /vscode-mongodb-copilot/, code: "VSCODE" }, +]; + +function getOriginCode(origin: string): OriginCode { + for (const rule of ORIGIN_RULES) { + if (rule.regex.test(origin)) { + return rule.code; + } + } + return "OTHER"; +} + +const addOriginCodeToCustomData: AddCustomDataFunc = async (_, res) => { + const origin = res.locals.customData.origin; + return typeof origin === "string" && origin.length > 0 + ? { + originCode: getOriginCode(origin), + } + : undefined; +}; + +const addUserAgentToCustomData: AddCustomDataFunc = async (req) => + req.headers["user-agent"] + ? { + userAgent: req.headers["user-agent"], + } + : undefined; + +export type AddDefinedCustomDataFunc = ( + ...args: Parameters +) => Promise>; + +export const addDefaultCustomData: AddDefinedCustomDataFunc = async ( + req, + res +) => { + return { + ...(await addIpToCustomData(req, res)), + ...(await addOriginToCustomData(req, res)), + ...(await addOriginCodeToCustomData(req, res)), + ...(await addUserAgentToCustomData(req, res)), + }; +}; diff --git a/packages/mongodb-chatbot-server/src/processors/index.ts b/packages/mongodb-chatbot-server/src/processors/index.ts index 1f7975e67..399d44aed 100644 --- a/packages/mongodb-chatbot-server/src/processors/index.ts +++ b/packages/mongodb-chatbot-server/src/processors/index.ts @@ -9,3 +9,4 @@ export * from "./InputGuardrail"; export * from "./makeVerifiedAnswerGenerateResponse"; export * from "./includeChunksForMaxTokensPossible"; export * from "./GenerateResponse"; +export * from "./addCustomData"; diff --git a/packages/mongodb-chatbot-server/src/routes/content/contentRouter.test.ts b/packages/mongodb-chatbot-server/src/routes/content/contentRouter.test.ts new file mode 100644 index 000000000..7e8afebb6 --- /dev/null +++ b/packages/mongodb-chatbot-server/src/routes/content/contentRouter.test.ts @@ -0,0 +1,96 @@ +import { Express } from "express"; +import request from "supertest"; +import type { + FindContentFunc, + MongoDbSearchResultsStore, +} from "mongodb-rag-core"; +import type { + MakeContentRouterParams, + SearchContentMiddleware, +} from "./contentRouter"; +import { makeTestApp } from "../../test/testHelpers"; +import { embeddedContentStore } from "../../test/testConfig"; + +// Minimal in-memory mock for SearchResultsStore for testing purposes +const mockSearchResultsStore: MongoDbSearchResultsStore = { + drop: jest.fn(), + close: jest.fn(), + metadata: { + databaseName: "mock", + collectionName: "mock", + }, + saveSearchResult: jest.fn(), + init: jest.fn(), +}; + +const findContentMock = jest.fn().mockResolvedValue({ + content: [], + queryEmbedding: [], +}) satisfies FindContentFunc; + +function makeMockContentRouterConfig( + overrides: Partial = {} +) { + return { + findContent: findContentMock, + searchResultsStore: mockSearchResultsStore, + embeddedContentStore, + ...overrides, + } satisfies MakeContentRouterParams; +} + +describe("contentRouter", () => { + const ipAddress = "127.0.0.1"; + const searchEndpoint = "/api/v1/content/search"; + + it("should call custom middleware if provided", async () => { + const mockMiddleware = jest.fn((_req, _res, next) => next()); + const { app, origin } = await makeTestApp({ + contentRouterConfig: makeMockContentRouterConfig({ + middleware: [mockMiddleware], + }), + }); + await createContentReq({ app, origin, query: "mongodb" }); + expect(mockMiddleware).toHaveBeenCalled(); + }); + + test("should use route middleware customData", async () => { + const middleware1: SearchContentMiddleware = (_, res, next) => { + res.locals.customData.middleware1 = true; + next(); + }; + let called = false; + const middleware2: SearchContentMiddleware = (_, res, next) => { + expect(res.locals.customData.middleware1).toBe(true); + called = true; + next(); + }; + const { app, origin } = await makeTestApp({ + contentRouterConfig: makeMockContentRouterConfig({ + middleware: [middleware1, middleware2], + }), + }); + await createContentReq({ app, origin, query: "What is aggregation?" }); + expect(called).toBe(true); + }); + + /** + Helper function to create a new content request + */ + async function createContentReq({ + app, + origin, + query, + }: { + app: Express; + origin: string; + query: string; + }) { + const createContentRes = await request(app) + .post(searchEndpoint) + .set("X-FORWARDED-FOR", ipAddress) + .set("Origin", origin) + .send({ query }); + return createContentRes; + } +}); diff --git a/packages/mongodb-chatbot-server/src/routes/content/contentRouter.ts b/packages/mongodb-chatbot-server/src/routes/content/contentRouter.ts new file mode 100644 index 000000000..cafacde4f --- /dev/null +++ b/packages/mongodb-chatbot-server/src/routes/content/contentRouter.ts @@ -0,0 +1,96 @@ +import { NextFunction, RequestHandler, Response, Router } from "express"; +import { ParamsDictionary } from "express-serve-static-core"; +import { FindContentFunc, MongoDbEmbeddedContentStore, MongoDbSearchResultsStore } from "mongodb-rag-core"; +import { ParsedQs } from "qs"; + +import validateRequestSchema from "../../middleware/validateRequestSchema"; +import { SearchContentRequest, makeSearchContentRoute } from "./searchContent"; +import { requireRequestOrigin, requireValidIpAddress } from "../../middleware"; +import { + AddCustomDataFunc, + addDefaultCustomData, + RequestCustomData, +} from "../../processors"; +import { GetDataSourcesRequest, makeListDataSourcesRoute } from "./listDataSources"; + +export type SearchContentCustomData = RequestCustomData; + +/** + Middleware to put in front of all the routes in the contentRouter. + Useful for authentication, data validation, logging, etc. + It exposes the app's {@link ContentRouterLocals} via {@link Response.locals} + ([docs](https://expressjs.com/en/api.html#res.locals)). + You can use or modify `res.locals.customData` in your middleware, and this data + will be available to subsequent middleware and route handlers. + */ +export type SearchContentMiddleware = RequestHandler< + ParamsDictionary, + unknown, + unknown, + ParsedQs, + ContentRouterLocals +>; + +/** + Local variables provided by Express.js for single request-response cycle + + Keeps track of data for authentication or dynamic data validation. + */ +export interface ContentRouterLocals { + customData: Record; +} + +/** + Express.js Response from the app's {@link ConversationsService}. + */ +export type SearchContentRouterResponse = Response< + // eslint-disable-next-line @typescript-eslint/no-explicit-any + any, + ContentRouterLocals +>; + +export interface MakeContentRouterParams { + findContent: FindContentFunc; + searchResultsStore: MongoDbSearchResultsStore; + embeddedContentStore: MongoDbEmbeddedContentStore; + addCustomData?: AddCustomDataFunc; + middleware?: SearchContentMiddleware[]; +} + +export function makeContentRouter({ + findContent, + searchResultsStore, + embeddedContentStore, + addCustomData = addDefaultCustomData, + middleware = [ + requireValidIpAddress(), + requireRequestOrigin(), + ], +}: MakeContentRouterParams) { + const contentRouter = Router(); + + // Set the customData and conversations on the response locals + // for use in subsequent middleware. + contentRouter.use(((_, res: Response, next: NextFunction) => { + res.locals.customData = {}; + next(); + }) satisfies RequestHandler); + + // Add middleware to the conversationsRouter. + middleware?.forEach((middleware) => contentRouter.use(middleware)); + + // Create new conversation. + contentRouter.post( + "/search", + validateRequestSchema(SearchContentRequest), + makeSearchContentRoute({ + findContent, + searchResultsStore, + addCustomData, + }) + ); + + contentRouter.get("/sources", validateRequestSchema(GetDataSourcesRequest), makeListDataSourcesRoute({ embeddedContentStore })); + + return contentRouter; +} diff --git a/packages/mongodb-chatbot-server/src/routes/content/index.ts b/packages/mongodb-chatbot-server/src/routes/content/index.ts new file mode 100644 index 000000000..692a707f5 --- /dev/null +++ b/packages/mongodb-chatbot-server/src/routes/content/index.ts @@ -0,0 +1,2 @@ +export * from "./contentRouter"; +export * from "./searchContent"; diff --git a/packages/mongodb-chatbot-server/src/routes/content/listDataSources.test.ts b/packages/mongodb-chatbot-server/src/routes/content/listDataSources.test.ts new file mode 100644 index 000000000..dfc138bb0 --- /dev/null +++ b/packages/mongodb-chatbot-server/src/routes/content/listDataSources.test.ts @@ -0,0 +1,86 @@ +import type { + MongoDbEmbeddedContentStore, + DataSourceMetadata, +} from "mongodb-rag-core"; +import { createRequest, createResponse } from "node-mocks-http"; +import { ERROR_MESSAGES, makeListDataSourcesRoute } from "./listDataSources"; + +function makeMockEmbeddedContentStore(dataSources: DataSourceMetadata[]) { + return { + listDataSources: jest.fn().mockResolvedValue(dataSources), + } as unknown as MongoDbEmbeddedContentStore; +} + +describe("makeListDataSourcesRoute", () => { + const mockDataSources: DataSourceMetadata[] = [ + { + id: "source1", + versions: [ + { label: "current", isCurrent: true }, + { label: "v6.0", isCurrent: false }, + ], + type: "docs", + }, + { + id: "source2", + versions: [{ label: "v2.11", isCurrent: false }], + type: "university-content", + }, + ]; + + it("should return data sources for a valid request", async () => { + const embeddedContentStore = makeMockEmbeddedContentStore(mockDataSources); + const handler = makeListDataSourcesRoute({ embeddedContentStore }); + + const req = createRequest({ + headers: { "req-id": "test-req-id" }, + }); + const res = createResponse(); + + // eslint-disable-next-line @typescript-eslint/no-explicit-any + await handler(req, res as any); + + const data = res._getJSONData(); + expect(data).toHaveProperty("dataSources"); + expect(Array.isArray(data.dataSources)).toBe(true); + expect(data.dataSources.length).toBe(2); + expect(data.dataSources[0].id).toBe("source1"); + }); + + it("should handle errors from embeddedContentStore and throw", async () => { + const embeddedContentStore = { + listDataSources: jest.fn().mockRejectedValue(new Error("fail")), + } as unknown as MongoDbEmbeddedContentStore; + const handler = makeListDataSourcesRoute({ embeddedContentStore }); + + const req = createRequest({ + headers: { "req-id": "test-req-id" }, + }); + const res = createResponse(); + + // eslint-disable-next-line @typescript-eslint/no-explicit-any + await expect(handler(req, res as any)).rejects.toMatchObject({ + message: ERROR_MESSAGES.UNABLE_TO_LIST_DATA_SOURCES, + httpStatus: 500, + name: "RequestError", + }); + }); + + it("should return an empty array if no data sources are found", async () => { + const embeddedContentStore = makeMockEmbeddedContentStore([]); + const handler = makeListDataSourcesRoute({ embeddedContentStore }); + + const req = createRequest({ + headers: { "req-id": "test-req-id" }, + }); + const res = createResponse(); + + // eslint-disable-next-line @typescript-eslint/no-explicit-any + await handler(req, res as any); + + const data = res._getJSONData(); + expect(data).toHaveProperty("dataSources"); + expect(Array.isArray(data.dataSources)).toBe(true); + expect(data.dataSources.length).toBe(0); + }); +}); diff --git a/packages/mongodb-chatbot-server/src/routes/content/listDataSources.ts b/packages/mongodb-chatbot-server/src/routes/content/listDataSources.ts new file mode 100644 index 000000000..c1035dc2d --- /dev/null +++ b/packages/mongodb-chatbot-server/src/routes/content/listDataSources.ts @@ -0,0 +1,57 @@ +import { + Request as ExpressRequest, + Response as ExpressResponse, +} from "express"; +import { + DataSourceMetadata, + MongoDbEmbeddedContentStore, +} from "mongodb-rag-core"; +import { z } from "zod"; + +import { SomeExpressRequest } from "../../middleware"; +import { makeRequestError } from "../conversations/utils"; +import { + ContentRouterLocals, +} from "./contentRouter"; + +export type GetDataSourcesRequest = z.infer; + +export const GetDataSourcesRequest = SomeExpressRequest.merge( + z.object({ + headers: z.object({ + "req-id": z.string(), + }), + }) +); + +export interface ListDataSourcesResponseBody { + dataSources: DataSourceMetadata[]; +} + +export interface MakeListDataSourcesRouteParams { + embeddedContentStore: MongoDbEmbeddedContentStore; +} + +export const ERROR_MESSAGES = { + UNABLE_TO_LIST_DATA_SOURCES: "Unable to list data sources", +}; + +export function makeListDataSourcesRoute({ + embeddedContentStore, +}: MakeListDataSourcesRouteParams) { + return async ( + _req: ExpressRequest, + res: ExpressResponse + ) => { + try { + // Fetch data sources from the store + const dataSources = await embeddedContentStore.listDataSources(); + res.json({ dataSources }); + } catch (error) { + throw makeRequestError({ + httpStatus: 500, + message: ERROR_MESSAGES.UNABLE_TO_LIST_DATA_SOURCES, + }); + } + }; +} diff --git a/packages/mongodb-chatbot-server/src/routes/content/searchContent.test.ts b/packages/mongodb-chatbot-server/src/routes/content/searchContent.test.ts new file mode 100644 index 000000000..b2ca79963 --- /dev/null +++ b/packages/mongodb-chatbot-server/src/routes/content/searchContent.test.ts @@ -0,0 +1,155 @@ +import { makeSearchContentRoute } from "./searchContent"; +import type { FindContentFunc, FindContentResult } from "mongodb-rag-core"; +import type { MongoDbSearchResultsStore } from "mongodb-rag-core"; +import { createRequest, createResponse } from "node-mocks-http"; + +// Helper to create a mock FindContentFunc +function makeMockFindContent(result: FindContentResult) { + return jest.fn().mockResolvedValue(result) satisfies FindContentFunc; +} + +// Helper to create a mock MongoDbSearchResultsStore +function makeMockMongoDbSearchResultsStore() { + return { + drop: jest.fn(), + close: jest.fn(), + metadata: { databaseName: "mock", collectionName: "mock" }, + saveSearchResult: jest.fn().mockResolvedValue(undefined), + init: jest.fn(), + } satisfies MongoDbSearchResultsStore; +} + +describe("makeSearchContentRoute", () => { + const baseReqBody = { + query: "What is aggregation?", + limit: 2, + dataSources: [{ name: "source1", type: "docs", versionLabel: "v1" }], + }; + // Add all required EmbeddedContent fields for the mock result + const baseFindContentResult: FindContentResult = { + queryEmbedding: [0.1, 0.2, 0.3], + content: [ + { + url: "https://www.mongodb.com/docs/manual/aggregation", + text: "Look at all this aggregation", + metadata: { pageTitle: "Aggregation Operations" }, + sourceName: "source1", + tokenCount: 8, + embeddings: { test: [0.1, 0.2, 0.3] }, + updated: new Date(), + score: 0.8, + }, + { + url: "https://mongodb.com/docs", + text: "MongoDB Docs", + metadata: { pageTitle: "MongoDB" }, + sourceName: "source1", + tokenCount: 10, + embeddings: { test: [0.1, 0.2, 0.3] }, + updated: new Date(), + score: 0.6, + }, + ], + }; + + it("should return search results for a valid request", async () => { + const findContent = makeMockFindContent(baseFindContentResult); + const searchResultsStore = makeMockMongoDbSearchResultsStore(); + const handler = makeSearchContentRoute({ findContent, searchResultsStore }); + + const req = createRequest({ + body: baseReqBody, + headers: { "req-id": "test-req-id" }, + }); + const res = createResponse(); + // eslint-disable-next-line @typescript-eslint/no-explicit-any + await handler(req, res as any); + + const data = res._getJSONData(); + expect(data).toHaveProperty("results"); + expect(Array.isArray(data.results)).toBe(true); + expect(data.results.length).toBe(2); + expect(data.results[0].url).toBe( + "https://www.mongodb.com/docs/manual/aggregation" + ); + }); + + it("should call findContent with correct arguments", async () => { + const findContent = jest.fn().mockResolvedValue(baseFindContentResult); + const searchResultsStore = makeMockMongoDbSearchResultsStore(); + const handler = makeSearchContentRoute({ findContent, searchResultsStore }); + const req = createRequest({ + body: baseReqBody, + headers: { "req-id": "test-req-id" }, + }); + const res = createResponse(); + // eslint-disable-next-line @typescript-eslint/no-explicit-any + await handler(req, res as any); + + expect(findContent).toHaveBeenCalledWith({ + query: baseReqBody.query, + filters: expect.any(Object), + limit: baseReqBody.limit, + }); + }); + + it("should call searchResultsStore.saveSearchResult", async () => { + const findContent = makeMockFindContent(baseFindContentResult); + const searchResultsStore = makeMockMongoDbSearchResultsStore(); + const handler = makeSearchContentRoute({ findContent, searchResultsStore }); + const req = createRequest({ + body: baseReqBody, + headers: { "req-id": "test-req-id" }, + }); + const res = createResponse(); + + // eslint-disable-next-line @typescript-eslint/no-explicit-any + await handler(req, res as any); + expect(searchResultsStore.saveSearchResult).toHaveBeenCalledWith( + expect.objectContaining({ + query: baseReqBody.query, + results: baseFindContentResult.content, + dataSources: baseReqBody.dataSources, + limit: baseReqBody.limit, + }) + ); + }); + + it("should handle errors from findContent and throw", async () => { + const findContent = jest.fn().mockRejectedValue(new Error("fail")); + const searchResultsStore = makeMockMongoDbSearchResultsStore(); + const handler = makeSearchContentRoute({ findContent, searchResultsStore }); + const req = createRequest({ + body: baseReqBody, + headers: { "req-id": "test-req-id" }, + }); + const res = createResponse(); + + // eslint-disable-next-line @typescript-eslint/no-explicit-any + await expect(handler(req, res as any)).rejects.toMatchObject({ + message: "Unable to query search database", + httpStatus: 500, + name: "RequestError", + }); + }); + + it("should respect `limit` and `dataSources` parameters", async () => { + const findContent = jest.fn().mockResolvedValue(baseFindContentResult); + const searchResultsStore = makeMockMongoDbSearchResultsStore(); + const handler = makeSearchContentRoute({ findContent, searchResultsStore }); + const req = createRequest({ + body: { ...baseReqBody, limit: 1, dataSources: [{ name: "source2" }] }, + headers: { "req-id": "test-req-id" }, + }); + const res = createResponse(); + + // eslint-disable-next-line @typescript-eslint/no-explicit-any + await handler(req, res as any); + expect(findContent).toHaveBeenCalledWith( + expect.objectContaining({ + limit: 1, + filters: expect.objectContaining({ sourceName: ["source2"] }), + }) + ); + }); +}); diff --git a/packages/mongodb-chatbot-server/src/routes/content/searchContent.ts b/packages/mongodb-chatbot-server/src/routes/content/searchContent.ts new file mode 100644 index 000000000..992068661 --- /dev/null +++ b/packages/mongodb-chatbot-server/src/routes/content/searchContent.ts @@ -0,0 +1,185 @@ +import { + Request as ExpressRequest, + Response as ExpressResponse, +} from "express"; +import { + FindContentFunc, + FindContentResult, + MongoDbSearchResultsStore, + QueryFilters, + SearchRecordDataSource, + SearchRecordDataSourceSchema, +} from "mongodb-rag-core"; +import { z } from "zod"; + +import { generateZodErrorMessage, SomeExpressRequest } from "../../middleware"; +import { makeRequestError } from "../conversations/utils"; +import { + SearchContentCustomData, + ContentRouterLocals, +} from "./contentRouter"; +import { AddCustomDataFunc } from "../../processors"; +import { wrapTraced } from "mongodb-rag-core/braintrust"; + +export const SearchContentRequestBody = z.object({ + query: z.string(), + dataSources: z.array(SearchRecordDataSourceSchema).optional(), + limit: z.number().int().min(1).max(500).optional().default(5), +}); + +export const SearchContentRequest = SomeExpressRequest.merge( + z.object({ + headers: z.object({ + "req-id": z.string(), + }), + body: SearchContentRequestBody, + }) +); + +export type SearchContentRequest = z.infer; +export type SearchContentRequestBody = z.infer; + +export interface MakeSearchContentRouteParams { + findContent: FindContentFunc; + searchResultsStore: MongoDbSearchResultsStore; + addCustomData?: AddCustomDataFunc; +} + +interface SearchContentResponseChunk { + url: string; + title: string; + text: string; + metadata?: { + sourceName?: string; + sourceType?: string; + sourceVersionLabel?: string; + tags?: string[]; + [k: string]: unknown; + }; +} +interface SearchContentResponseBody { + results: SearchContentResponseChunk[]; +} + +export function makeSearchContentRoute({ + findContent, + searchResultsStore, + addCustomData, +}: MakeSearchContentRouteParams) { + const tracedFindContent = wrapTraced(findContent, { name: "searchContent" }); + return async ( + req: ExpressRequest, + res: ExpressResponse + ) => { + try { + // --- INPUT VALIDATION --- + const { error, data } = SearchContentRequestBody.safeParse(req.body); + if (error) { + throw makeRequestError({ + httpStatus: 500, + message: generateZodErrorMessage(error), + }); + } + + const { query, dataSources, limit } = data; + const results = await tracedFindContent({ + query, + filters: mapDataSourcesToFilters(dataSources), + limit, + }); + res.json(mapFindContentResultToSearchContentResponseChunk(results)); + + const customData = await getCustomData(req, res, addCustomData); + await persistSearchResultsToDatabase({ + query, + results, + dataSources, + limit, + searchResultsStore, + customData, + }); + } catch (error) { + throw makeRequestError({ + httpStatus: 500, + message: "Unable to query search database", + }); + } + }; +} + +function mapFindContentResultToSearchContentResponseChunk( + result: FindContentResult +): SearchContentResponseBody { + return { + results: result.content.map(({ url, metadata, text }) => ({ + url, + title: metadata?.pageTitle ?? "", + text, + metadata, + })), + }; +} + +function mapDataSourcesToFilters( + dataSources?: SearchRecordDataSource[] +): QueryFilters { + if (!dataSources || dataSources.length === 0) { + return {}; + } + + const sourceNames = dataSources.map((ds) => ds.name); + const sourceTypes = dataSources + .map((ds) => ds.type) + .filter((t): t is string => !!t); + const versionLabels = dataSources + .map((ds) => ds.versionLabel) + .filter((v): v is string => !!v); + + return { + ...(sourceNames.length && { sourceName: sourceNames }), + ...(sourceTypes.length && { sourceType: sourceTypes }), + ...(versionLabels.length && { version: { label: versionLabels } }), + }; +} + +async function persistSearchResultsToDatabase({ + query, + results, + dataSources = [], + limit, + searchResultsStore, + customData, +}: { + query: string; + results: FindContentResult; + dataSources?: SearchRecordDataSource[]; + limit: number; + searchResultsStore: MongoDbSearchResultsStore; + customData?: { [k: string]: unknown }; +}) { + searchResultsStore.saveSearchResult({ + query, + results: results.content, + dataSources, + limit, + createdAt: new Date(), + ...(customData !== undefined && { customData }), + }); +} + +async function getCustomData( + req: ExpressRequest, + res: ExpressResponse, + addCustomData?: AddCustomDataFunc +): Promise { + try { + if (addCustomData) { + return await addCustomData(req, res); + } + } catch (error) { + throw makeRequestError({ + httpStatus: 500, + message: "Error parsing custom data from the request", + }); + } +} diff --git a/packages/mongodb-chatbot-server/src/routes/conversations/addMessageToConversation.ts b/packages/mongodb-chatbot-server/src/routes/conversations/addMessageToConversation.ts index 68b33afa9..6bb1291a6 100644 --- a/packages/mongodb-chatbot-server/src/routes/conversations/addMessageToConversation.ts +++ b/packages/mongodb-chatbot-server/src/routes/conversations/addMessageToConversation.ts @@ -27,7 +27,6 @@ import { import { z } from "zod"; import { SomeExpressRequest } from "../../middleware/validateRequestSchema"; import { - AddCustomDataFunc, ConversationsRouterLocals, } from "./conversationsRouter"; import { wrapTraced, Logger } from "mongodb-rag-core/braintrust"; @@ -36,6 +35,7 @@ import { GenerateResponse, GenerateResponseParams, } from "../../processors/GenerateResponse"; +import { AddCustomDataFunc } from "../../processors"; import { hasTooManyUserMessagesInConversation } from "../responses/createResponse"; import { creationInterface } from "./constants"; diff --git a/packages/mongodb-chatbot-server/src/routes/conversations/conversationsRouter.ts b/packages/mongodb-chatbot-server/src/routes/conversations/conversationsRouter.ts index d98f1e6ce..892432a79 100644 --- a/packages/mongodb-chatbot-server/src/routes/conversations/conversationsRouter.ts +++ b/packages/mongodb-chatbot-server/src/routes/conversations/conversationsRouter.ts @@ -32,6 +32,7 @@ import { import { UpdateTraceFunc } from "./UpdateTraceFunc"; import { GenerateResponse } from "../../processors/GenerateResponse"; import { Logger } from "mongodb-rag-core/braintrust"; +import { AddCustomDataFunc, addDefaultCustomData } from "../../processors"; /** Configuration for rate limiting on the /conversations/* routes. @@ -62,16 +63,6 @@ export interface ConversationsRateLimitConfig { addMessageSlowDownConfig?: SlowDownOptions; } -/** - Function to add custom data to the {@link Conversation} persisted to the database. - Has access to the Express.js request and response plus the {@link ConversationsRouterLocals} - from the {@link Response.locals} object. - */ -export type AddCustomDataFunc = ( - request: Request, - response: ConversationsRouterResponse -) => Promise; - /** Express.js Request that exposes the app's {@link ConversationsService}. @@ -193,95 +184,6 @@ export interface ConversationsRouterParams { braintrustLogger?: Logger; } -const addIpToCustomData: AddCustomDataFunc = async (req) => - req.ip - ? { - ip: req.ip, - } - : undefined; - -const addOriginToCustomData: AddCustomDataFunc = async (_, res) => - res.locals.customData.origin - ? { - origin: res.locals.customData.origin, - } - : undefined; - -export const originCodes = [ - "LEARN", - "DEVELOPER", - "DOCS", - "DOTCOM", - "GEMINI_CODE_ASSIST", - "VSCODE", - "OTHER", -] as const; - -export type OriginCode = (typeof originCodes)[number]; - -interface OriginRule { - regex: RegExp; - code: OriginCode; -} - -const ORIGIN_RULES: OriginRule[] = [ - { regex: /learn\.mongodb\.com/, code: "LEARN" }, - { regex: /mongodb\.com\/developer/, code: "DEVELOPER" }, - { regex: /mongodb\.com\/docs/, code: "DOCS" }, - { regex: /mongodb\.com\//, code: "DOTCOM" }, - { regex: /google-gemini-code-assist/, code: "GEMINI_CODE_ASSIST" }, - { regex: /vscode-mongodb-copilot/, code: "VSCODE" }, -]; - -function getOriginCode(origin: string): OriginCode { - for (const rule of ORIGIN_RULES) { - if (rule.regex.test(origin)) { - return rule.code; - } - } - return "OTHER"; -} - -const addOriginCodeToCustomData: AddCustomDataFunc = async (_, res) => { - const origin = res.locals.customData.origin; - return typeof origin === "string" && origin.length > 0 - ? { - originCode: getOriginCode(origin), - } - : undefined; -}; - -const addUserAgentToCustomData: AddCustomDataFunc = async (req) => - req.headers["user-agent"] - ? { - userAgent: req.headers["user-agent"], - } - : undefined; - -export type AddDefinedCustomDataFunc = ( - ...args: Parameters -) => Promise>; - -export const defaultCreateConversationCustomData: AddDefinedCustomDataFunc = - async (req, res) => { - return { - ...(await addIpToCustomData(req, res)), - ...(await addOriginToCustomData(req, res)), - ...(await addOriginCodeToCustomData(req, res)), - ...(await addUserAgentToCustomData(req, res)), - }; - }; - -export const defaultAddMessageToConversationCustomData: AddDefinedCustomDataFunc = - async (req, res) => { - return { - ...(await addIpToCustomData(req, res)), - ...(await addOriginToCustomData(req, res)), - ...(await addOriginCodeToCustomData(req, res)), - ...(await addUserAgentToCustomData(req, res)), - }; - }; - /** Constructor function to make the /conversations/* Express.js router. */ @@ -291,9 +193,9 @@ export function makeConversationsRouter({ maxInputLengthCharacters, maxUserMessagesInConversation, rateLimitConfig, - middleware = [requireValidIpAddress(), requireRequestOrigin()], - createConversationCustomData = defaultCreateConversationCustomData, - addMessageToConversationCustomData = defaultAddMessageToConversationCustomData, + middleware = [requireValidIpAddress(), requireRequestOrigin()], + createConversationCustomData = addDefaultCustomData, + addMessageToConversationCustomData = addDefaultCustomData, addMessageToConversationUpdateTrace, rateMessageUpdateTrace, commentMessageUpdateTrace, diff --git a/packages/mongodb-chatbot-server/src/routes/conversations/createConversation.ts b/packages/mongodb-chatbot-server/src/routes/conversations/createConversation.ts index 7c77e33f9..a9a06009e 100644 --- a/packages/mongodb-chatbot-server/src/routes/conversations/createConversation.ts +++ b/packages/mongodb-chatbot-server/src/routes/conversations/createConversation.ts @@ -17,9 +17,9 @@ import { import { getRequestId, logRequest, sendErrorResponse } from "../../utils"; import { SomeExpressRequest } from "../../middleware/validateRequestSchema"; import { - AddCustomDataFunc, ConversationsRouterLocals, } from "./conversationsRouter"; +import { AddCustomDataFunc } from "../../processors"; import { creationInterface } from "./constants"; export type CreateConversationRequest = z.infer< diff --git a/packages/mongodb-chatbot-server/src/routes/index.ts b/packages/mongodb-chatbot-server/src/routes/index.ts index d3e816609..38ddec7d4 100644 --- a/packages/mongodb-chatbot-server/src/routes/index.ts +++ b/packages/mongodb-chatbot-server/src/routes/index.ts @@ -1,2 +1,3 @@ export * from "./conversations"; +export * from "./content"; export * from "./responses"; diff --git a/packages/mongodb-chatbot-server/src/test/middlewareTestHelpers.ts b/packages/mongodb-chatbot-server/src/test/middlewareTestHelpers.ts index 44cedc8fb..7eb9d7b0e 100644 --- a/packages/mongodb-chatbot-server/src/test/middlewareTestHelpers.ts +++ b/packages/mongodb-chatbot-server/src/test/middlewareTestHelpers.ts @@ -2,6 +2,7 @@ import { Request } from "express"; import { ParamsDictionary } from "express-serve-static-core"; import { createRequest, createResponse } from "node-mocks-http"; import { ConversationsService } from "mongodb-rag-core"; +import { ParsedQs } from "qs"; import { ConversationsRouterLocals, ConversationsRouterResponse, @@ -13,7 +14,7 @@ export const createConversationsMiddlewareReq = () => ParamsDictionary, unknown, unknown, - unknown, + ParsedQs, ConversationsRouterLocals > >(); diff --git a/packages/mongodb-rag-core/src/VectorStore.ts b/packages/mongodb-rag-core/src/VectorStore.ts index 6a1033591..6b475f477 100644 --- a/packages/mongodb-rag-core/src/VectorStore.ts +++ b/packages/mongodb-rag-core/src/VectorStore.ts @@ -1,13 +1,16 @@ /** Generic vector store for vector-searchable data. */ -export type VectorStore = { +export type VectorStore< + T, + Filter extends Record = Record +> = { /** Find nearest neighbors to the given vector. */ findNearestNeighbors( vector: number[], - options?: Partial + options?: Partial> ): Promise[]>; close?(): Promise; @@ -18,7 +21,9 @@ export type WithScore = T & { score: number }; /** Options for performing a nearest-neighbor search. */ -export type FindNearestNeighborsOptions = { +export type FindNearestNeighborsOptions< + Filter extends Record = Record +> = { /** The name of the index to use. */ @@ -49,5 +54,5 @@ export type FindNearestNeighborsOptions = { /** Search filter expression. */ - filter: Record; + filter: Filter; }; diff --git a/packages/mongodb-rag-core/src/contentStore/EmbeddedContent.ts b/packages/mongodb-rag-core/src/contentStore/EmbeddedContent.ts index 8ae22d24d..5c0ff34dd 100644 --- a/packages/mongodb-rag-core/src/contentStore/EmbeddedContent.ts +++ b/packages/mongodb-rag-core/src/contentStore/EmbeddedContent.ts @@ -97,18 +97,27 @@ export interface GetSourcesMatchParams { */ export type QueryFilters = { url?: string; - sourceName?: string; - version?: { - current?: boolean; - label?: string; - }; - sourceType?: Page["sourceType"]; + sourceName?: string | string[]; + version?: { current?: boolean; label?: string | string[] }; + sourceType?: Page["sourceType"] | string[]; +}; + +/** + Metadata of data source +*/ +export type DataSourceMetadata = { + id: string; + versions?: { label: string; isCurrent: boolean }[]; + type?: string; }; /** Data store of the embedded content. */ -export type EmbeddedContentStore = VectorStore & { +export type EmbeddedContentStore = VectorStore< + EmbeddedContent, + QueryFilters +> & { /** Load the embedded content for the given page. */ @@ -145,6 +154,11 @@ export type EmbeddedContentStore = VectorStore & { */ init?: () => Promise; + /** + Lists all unique data sources + */ + listDataSources(): Promise; + /** Get the names of ingested data sources that match the given query. */ diff --git a/packages/mongodb-rag-core/src/contentStore/MongoDbEmbeddedContentStore.test.ts b/packages/mongodb-rag-core/src/contentStore/MongoDbEmbeddedContentStore.test.ts index 1093191ae..1847b474b 100644 --- a/packages/mongodb-rag-core/src/contentStore/MongoDbEmbeddedContentStore.test.ts +++ b/packages/mongodb-rag-core/src/contentStore/MongoDbEmbeddedContentStore.test.ts @@ -8,11 +8,12 @@ import "dotenv/config"; import { PersistedPage } from "."; import { MongoDbEmbeddedContentStore, + listDataSourcesCache, makeMongoDbEmbeddedContentStore, } from "./MongoDbEmbeddedContentStore"; import { MongoClient } from "mongodb"; import { EmbeddedContent } from "./EmbeddedContent"; -import { MONGO_MEMORY_REPLICA_SET_URI } from "../test/constants"; +import { MONGO_MEMORY_REPLICA_SET_URI, MONGO_MEMORY_SERVER_URI } from "../test/constants"; const { MONGODB_CONNECTION_URI, @@ -463,3 +464,174 @@ describe("initialized DB", () => { expect(filterPaths).toContain("sourceType"); }); }); + +describe("listDataSources", () => { + let store: MongoDbEmbeddedContentStore | undefined; + let mongoClient: MongoClient | undefined; + let dateNowSpy: jest.SpyInstance; + + beforeAll(() => { + dateNowSpy = jest.spyOn(Date, "now"); + }); + + afterAll(() => { + dateNowSpy.mockRestore(); + }); + + beforeEach(async () => { + store = makeMongoDbEmbeddedContentStore({ + connectionUri: MONGODB_CONNECTION_URI, + databaseName: MONGODB_DATABASE_NAME, + collectionName: "test-list-data-sources-collection", + searchIndex: { embeddingName: "test-list-data-sources" }, + }); + mongoClient = new MongoClient(MONGODB_CONNECTION_URI); + + listDataSourcesCache.data = null; + listDataSourcesCache.expiresAt = 0; + listDataSourcesCache.isRefreshing = false; + dateNowSpy.mockReset(); + }); + + afterEach(async () => { + assert(store); + assert(mongoClient); + await store.close(); + await mongoClient.close(); + }); + + it("returns grouped data sources with correct versions and type", async () => { + const docs: EmbeddedContent[] = [ + { + sourceName: "solutions", + url: "/foo", + text: "foo", + tokenCount: 1, + embeddings: { test: [0.1] }, + updated: new Date(), + sourceType: "docs", + metadata: { version: { label: "v1.0", isCurrent: true } }, + }, + { + sourceName: "solutions", + url: "/bar", + text: "bar", + tokenCount: 1, + embeddings: { test: [0.2] }, + updated: new Date(), + sourceType: "docs", + metadata: { version: { label: "v2.0", isCurrent: false } }, + }, + { + sourceName: "mongodb-university", + url: "/baz", + text: "baz", + tokenCount: 1, + embeddings: { test: [0.3] }, + updated: new Date(), + sourceType: "web", + metadata: { version: { label: "v1.0", isCurrent: false } }, + }, + { + sourceName: "mongoid", + url: "/boop", + text: "boop", + tokenCount: 1, + embeddings: { test: [0.4] }, + updated: new Date(), + sourceType: "blog", + metadata: {}, // no version + }, + ]; + + assert(store); + await store.init(); + + const coll = mongoClient + ?.db(store.metadata.databaseName) + .collection(store.metadata.collectionName); + await coll?.insertMany(docs); + + const result = await store!.listDataSources(); + expect(Array.isArray(result)).toBe(true); + + // solutions should have two versions + const sourceA = result.find((ds) => ds.id === "solutions"); + expect(sourceA).toBeDefined(); + expect(sourceA!.type).toBe("docs"); + expect(sourceA!.versions).toEqual( + expect.arrayContaining([ + { label: "v1.0", isCurrent: true }, + { label: "v2.0", isCurrent: false }, + ]) + ); + // mongodb-university should have one version + const sourceB = result.find((ds) => ds.id === "mongodb-university"); + expect(sourceB).toBeDefined(); + expect(sourceB!.type).toBe("web"); + expect(sourceB!.versions).toEqual([{ label: "v1.0", isCurrent: false }]); + // mongoid should have empty versions array + const sourceC = result.find((ds) => ds.id === "mongoid"); + expect(sourceC).toBeDefined(); + expect(sourceC!.type).toBe("blog"); + expect(sourceC!.versions).toEqual([]); + }); + + it("returns cached data if cache is fresh (<24hrs)", async () => { + const now = 1720000000000; + dateNowSpy.mockImplementation(() => now); + + const mockCachedData = [{ id: "name1" }]; + listDataSourcesCache.data = mockCachedData; + listDataSourcesCache.expiresAt = now + 1000 * 60 * 60 * 12; // 12 hrs later + + assert(store); + const result = await store.listDataSources(); + expect(result).toBe(mockCachedData); + }); + + it("returns stale data and triggers background refresh if cache is >24hrs but <7d", async () => { + const now = 1720000000000; + dateNowSpy.mockImplementation(() => now); + + const staleData = [{ id: "name2" }]; + listDataSourcesCache.data = staleData; + listDataSourcesCache.expiresAt = now - 1000; // 1 sec ago + + assert(store); + const result = await store.listDataSources(); + + expect(result).toBe(staleData); // Still returns stale + expect(listDataSourcesCache.isRefreshing).toBe(true); // Refresh triggered + }); + + it("blocks and fetches fresh data if cache is >7d", async () => { + const now = 1720000000000; + dateNowSpy.mockImplementation(() => now); + + listDataSourcesCache.data = null; + listDataSourcesCache.expiresAt = now - 1000 * 60 * 60 * 24 * 8; // 8 days ago + + assert(store); + + // Insert real data into the collection so it can be fetched fresh + const coll = mongoClient + ?.db(store.metadata.databaseName) + .collection(store.metadata.collectionName); + await coll?.deleteMany({}); + await coll?.insertOne({ + sourceName: "docs", + url: "/test", + text: "text", + tokenCount: 1, + embeddings: { test: [0.1] }, + updated: new Date(), + sourceType: "docs", + metadata: { version: { label: "v12.0", isCurrent: true } }, + }); + + const result = await store.listDataSources(); + expect(result.length).toBeGreaterThan(0); // Real fetch happened + expect(result).toStrictEqual([{ id: "docs", versions: [{ label: "v12.0", isCurrent: true }], type: "docs" }]) + }); +}); \ No newline at end of file diff --git a/packages/mongodb-rag-core/src/contentStore/MongoDbEmbeddedContentStore.ts b/packages/mongodb-rag-core/src/contentStore/MongoDbEmbeddedContentStore.ts index 2c856de88..152e5eed0 100644 --- a/packages/mongodb-rag-core/src/contentStore/MongoDbEmbeddedContentStore.ts +++ b/packages/mongodb-rag-core/src/contentStore/MongoDbEmbeddedContentStore.ts @@ -1,6 +1,7 @@ import { pageIdentity } from "."; import { DatabaseConnection } from "../DatabaseConnection"; import { + DataSourceMetadata, EmbeddedContent, EmbeddedContentStore, GetSourcesMatchParams, @@ -78,6 +79,21 @@ function makeMatchQuery({ sourceNames, chunkAlgoHash }: GetSourcesMatchParams) { }; } +/** + 24-hour cache of listDataSources aggregation as query is a full scan of all documents in collection + */ +export const listDataSourcesCache: { + data: DataSourceMetadata[] | null; + expiresAt: number; + isRefreshing: boolean; +} = { + data: null, + expiresAt: 0, + isRefreshing: false, +}; +const CACHE_STALE_AGE = 24 * 60 * 60 * 1000; // 24 hours +const CACHE_MAX_AGE = 1000 * 60 * 60 * 24 * 7; // 7 days + export function makeMongoDbEmbeddedContentStore({ connectionUri, databaseName, @@ -118,6 +134,71 @@ export function makeMongoDbEmbeddedContentStore({ db.collection(collectionName); const embeddingPath = `embeddings.${embeddingName}`; + async function fetchFreshListDataSources(): Promise { + const freshData = await embeddedContentCollection + .aggregate([ + { + $group: { + _id: "$sourceName", + versions: { + $addToSet: { + $cond: [ + { $ifNull: ["$metadata.version.label", false] }, + { + label: "$metadata.version.label", + isCurrent: "$metadata.version.isCurrent", + }, + "$$REMOVE", + ], + }, + }, + sourceType: { $addToSet: "$sourceType" }, + }, + }, + { + $project: { + _id: 0, + id: "$_id", + versions: { + $map: { + input: { + $filter: { + input: "$versions", + as: "v", + cond: { $ne: ["$$v.label", null] }, + }, + }, + as: "v", + in: { + label: "$$v.label", + isCurrent: { $ifNull: ["$$v.isCurrent", false] }, + }, + }, + }, + type: { + $arrayElemAt: [ + { + $filter: { + input: "$sourceType", + as: "t", + cond: { $ne: ["$$t", null] }, + }, + }, + 0, + ], + }, + }, + }, + ]) + .toArray(); + + listDataSourcesCache.data = freshData; + listDataSourcesCache.expiresAt = Date.now() + CACHE_STALE_AGE; + listDataSourcesCache.isRefreshing = false; + + return freshData; + } + return { drop, close, @@ -278,6 +359,35 @@ export function makeMongoDbEmbeddedContentStore({ } }, + async listDataSources(): Promise { + const now = Date.now(); + + // If cache is fresh (< 24h), return it immediately + if (listDataSourcesCache.data && now < listDataSourcesCache.expiresAt) { + return listDataSourcesCache.data; + } + + // If cache exists but is stale (< 7 days), return it and refresh in background + if ( + listDataSourcesCache.data && + now - listDataSourcesCache.expiresAt < CACHE_MAX_AGE + ) { + if (!listDataSourcesCache.isRefreshing) { + listDataSourcesCache.isRefreshing = true; + + void fetchFreshListDataSources().catch((err) => { + listDataSourcesCache.isRefreshing = false; + console.error("Error refreshing listDataSources cache:", err); + }); + } + + return listDataSourcesCache.data; + } + + // Cache is too old (>= 7 days) — fetch fresh and set cache + return await fetchFreshListDataSources(); + }, + async getDataSources(matchQuery: GetSourcesMatchParams): Promise { const result = await embeddedContentCollection .aggregate([ @@ -298,10 +408,10 @@ export function makeMongoDbEmbeddedContentStore({ } type MongoDbAtlasVectorSearchFilter = { - sourceName?: string; - "metadata.version.label"?: string; + sourceName?: string | { $in: string[] }; + "metadata.version.label"?: string | { $in: string[] }; "metadata.version.isCurrent"?: boolean | { $ne: boolean }; - sourceType?: string; + sourceType?: string | { $in: string[] }; }; const handleFilters = ( @@ -309,15 +419,21 @@ const handleFilters = ( ): MongoDbAtlasVectorSearchFilter => { const vectorSearchFilter: MongoDbAtlasVectorSearchFilter = {}; if (filter.sourceName) { - vectorSearchFilter["sourceName"] = filter.sourceName; + vectorSearchFilter["sourceName"] = Array.isArray(filter.sourceName) + ? { $in: filter.sourceName } + : filter.sourceName; } if (filter.sourceType) { - vectorSearchFilter["sourceType"] = filter.sourceType; + vectorSearchFilter["sourceType"] = Array.isArray(filter.sourceType) + ? { $in: filter.sourceType } + : filter.sourceType; } // Handle version filter. Note: unversioned embeddings (isCurrent: null) are treated as current const { current, label } = filter.version ?? {}; if (label) { - vectorSearchFilter["metadata.version.label"] = label; + vectorSearchFilter["metadata.version.label"] = Array.isArray(label) + ? { $in: label } + : label; } // Return current embeddings if either: // 1. current=true was explicitly requested, or diff --git a/packages/mongodb-rag-core/src/contentStore/MongoDbSearchResultsStore.test.ts b/packages/mongodb-rag-core/src/contentStore/MongoDbSearchResultsStore.test.ts new file mode 100644 index 000000000..525fd44ea --- /dev/null +++ b/packages/mongodb-rag-core/src/contentStore/MongoDbSearchResultsStore.test.ts @@ -0,0 +1,91 @@ +import { strict as assert } from "assert"; +import "dotenv/config"; +import { MongoClient } from "mongodb"; +import { MONGO_MEMORY_SERVER_URI } from "../test/constants"; +import { + makeMongoDbSearchResultsStore, + MongoDbSearchResultsStore, + SearchResultRecord, +} from "./MongoDbSearchResultsStore"; + +const searchResultRecord: SearchResultRecord = { + query: "What is MongoDB Atlas?", + results: [ + { + url: "foo", + title: "bar", + text: "baz", + metadata: { + sourceName: "source", + }, + }, + ], + dataSources: [{ name: "source1", type: "docs" }], + createdAt: new Date(), +}; +const uri = MONGO_MEMORY_SERVER_URI; + +describe("MongoDbSearchResultsStore", () => { + let store: MongoDbSearchResultsStore | undefined; + + beforeAll(async () => { + store = makeMongoDbSearchResultsStore({ + connectionUri: uri, + databaseName: "test-search-content-database", + }); + }); + + afterEach(async () => { + await store?.drop(); + }); + afterAll(async () => { + await store?.close(); + }); + + it("has an overridable default collection name", async () => { + assert(store); + + expect(store.metadata.collectionName).toBe("search_results"); + + const storeWithCustomCollectionName = makeMongoDbSearchResultsStore({ + connectionUri: uri, + databaseName: store.metadata.databaseName, + collectionName: "custom-search_results", + }); + + expect(storeWithCustomCollectionName.metadata.collectionName).toBe( + "custom-search_results" + ); + }); + + it("creates indexes", async () => { + assert(store); + await store.init(); + + const mongoClient = new MongoClient(uri); + const coll = mongoClient + ?.db(store.metadata.databaseName) + .collection(store.metadata.collectionName); + const indexes = await coll?.listIndexes().toArray(); + + expect(indexes?.some((el) => el.name === "createdAt_-1")).toBe(true); + await mongoClient.close(); + }); + + it("saves search result records to db", async () => { + assert(store); + await store.saveSearchResult(searchResultRecord); + + // Check for record in db + const client = new MongoClient(uri); + await client.connect(); + const db = client.db(store.metadata.databaseName); + const collection = db.collection("search_results"); + const found = await collection.findOne(searchResultRecord); + + expect(found).toBeTruthy(); + expect(found).toMatchObject(searchResultRecord); + + await client.close(); + }); +}); diff --git a/packages/mongodb-rag-core/src/contentStore/MongoDbSearchResultsStore.ts b/packages/mongodb-rag-core/src/contentStore/MongoDbSearchResultsStore.ts new file mode 100644 index 000000000..df588ab70 --- /dev/null +++ b/packages/mongodb-rag-core/src/contentStore/MongoDbSearchResultsStore.ts @@ -0,0 +1,112 @@ +import { z } from "zod"; +import { DatabaseConnection } from "../DatabaseConnection"; +import { + MakeMongoDbDatabaseConnectionParams, + makeMongoDbDatabaseConnection, +} from "../MongoDbDatabaseConnection"; +import { Document } from "mongodb"; + +export const SearchRecordDataSourceSchema = z.object({ + name: z.string(), + type: z.string().optional(), + versionLabel: z.string().optional(), +}); + +export type SearchRecordDataSource = z.infer< + typeof SearchRecordDataSourceSchema +>; + +export interface ResultChunk { + url: string; + title: string; + text: string; + metadata: { + sourceName: string; + sourceType?: string; + tags?: string[]; + [key: string]: unknown; // Accept additional unknown properties + }; +} + +export const ResultChunkSchema = z.object({ + url: z.string(), + title: z.string(), + text: z.string(), + metadata: z + .object({ + sourceName: z.string(), + sourceType: z.string().optional(), + tags: z.array(z.string()).optional(), + }) + .passthrough(), +}); + +export const SearchResultRecordSchema = z.object({ + query: z.string(), + results: z.array(ResultChunkSchema), + dataSources: z.array(SearchRecordDataSourceSchema).optional(), + limit: z.number().optional(), + createdAt: z.date(), + customData: z.object({}).passthrough().optional(), +}); + +export interface SearchResultRecord { + query: string; + results: Document[]; + dataSources?: SearchRecordDataSource[]; + limit?: number; + createdAt: Date; + customData?: Record; +} + +export type MongoDbSearchResultsStore = DatabaseConnection & { + metadata: { + databaseName: string; + collectionName: string; + }; + saveSearchResult(record: SearchResultRecord): Promise; + init(): Promise; +}; + +export type MakeMongoDbSearchResultsStoreParams = + MakeMongoDbDatabaseConnectionParams & { + collectionName?: string; + }; + +export type ContentCustomData = Record | undefined; + +export function makeMongoDbSearchResultsStore({ + connectionUri, + databaseName, + collectionName = "search_results", +}: MakeMongoDbSearchResultsStoreParams): MongoDbSearchResultsStore { + const { db, drop, close } = makeMongoDbDatabaseConnection({ + connectionUri, + databaseName, + }); + const searchResultsCollection = + db.collection(collectionName); + return { + drop, + close, + metadata: { + databaseName, + collectionName, + }, + async saveSearchResult(record: SearchResultRecord) { + const insertResult = await searchResultsCollection.insertOne(record); + + if (!insertResult.acknowledged) { + throw new Error("Insert was not acknowledged by MongoDB"); + } + if (!insertResult.insertedId) { + throw new Error( + "No insertedId returned from MongoDbSearchResultsStore.saveSearchResult insertOne" + ); + } + }, + async init() { + await searchResultsCollection.createIndex({ createdAt: -1 }); + }, + }; +} diff --git a/packages/mongodb-rag-core/src/contentStore/index.ts b/packages/mongodb-rag-core/src/contentStore/index.ts index a0be7b876..3de61850a 100644 --- a/packages/mongodb-rag-core/src/contentStore/index.ts +++ b/packages/mongodb-rag-core/src/contentStore/index.ts @@ -2,6 +2,7 @@ export * from "./EmbeddedContent"; export * from "./getChangedPages"; export * from "./MongoDbEmbeddedContentStore"; export * from "./MongoDbPageStore"; +export * from "./MongoDbSearchResultsStore"; export * from "./MongoDbTransformedContentStore"; export * from "./Page"; export * from "./PageFormat"; diff --git a/packages/mongodb-rag-core/src/contentStore/updateEmbeddedContent.test.ts b/packages/mongodb-rag-core/src/contentStore/updateEmbeddedContent.test.ts index 209d64ecf..3ff4b7401 100644 --- a/packages/mongodb-rag-core/src/contentStore/updateEmbeddedContent.test.ts +++ b/packages/mongodb-rag-core/src/contentStore/updateEmbeddedContent.test.ts @@ -13,6 +13,7 @@ import { import { makeMockPageStore } from "../test/MockPageStore"; import * as chunkPageModule from "../chunk/chunkPage"; import { + DataSourceMetadata, EmbeddedContentStore, EmbeddedContent, GetSourcesMatchParams, @@ -44,6 +45,9 @@ export const makeMockEmbeddedContentStore = (): EmbeddedContentStore => { metadata: { embeddingName: "test", }, + async listDataSources(): Promise { + return []; + }, async getDataSources(matchQuery: GetSourcesMatchParams): Promise { return []; }, diff --git a/packages/mongodb-rag-core/src/findContent/BoostOnAtlasSearchFilter.test.ts b/packages/mongodb-rag-core/src/findContent/BoostOnAtlasSearchFilter.test.ts deleted file mode 100644 index 24454bc6a..000000000 --- a/packages/mongodb-rag-core/src/findContent/BoostOnAtlasSearchFilter.test.ts +++ /dev/null @@ -1,133 +0,0 @@ -import { ObjectId } from "mongodb"; -import { makeBoostOnAtlasSearchFilter } from "./BoostOnAtlasSearchFilter"; -import { EmbeddedContentStore } from "../contentStore"; - -describe("makeBoostOnAtlasSearchFilter()", () => { - const boostManual = makeBoostOnAtlasSearchFilter({ - /** - Boosts results that have 3 words or less - */ - async shouldBoostFunc({ text }: { text: string }) { - return text.split(" ").filter((s) => s !== " ").length <= 3; - }, - findNearestNeighborsOptions: { - k: 2, - filter: { - text: { - path: "sourceName", - query: "snooty-docs", - }, - }, - minScore: 0.88, - }, - totalMaxK: 5, - }); - - describe("SearchBooster.shouldBoost()", () => { - test("Should boost MongoDB manual", async () => { - const text = "insert one"; - expect(await boostManual.shouldBoost({ text })).toBe(true); - }); - test("Should not boost MongoDB manual", async () => { - const text = "blah blah blah length > 3"; - expect(await boostManual.shouldBoost({ text })).toBe(false); - }); - }); - describe("SearchBooster.boost()", () => { - const embeddingName = "useless-embedding-model"; - const sharedResult = { - _id: new ObjectId(), - url: "https://mongodb.com/docs", - text: "blah blah blah", - tokenCount: 100, - embeddings: { - [embeddingName]: [0.1, 0.2, 0.3], - }, - updated: new Date(), - sourceName: "snooty-docs", // only important value - score: 0.98, - }; - const mockBoostedResults = [ - sharedResult, - { - _id: new ObjectId(), - url: "https://mongodb.com/docs", - text: "lorem ipsum", - tokenCount: 100, - embeddings: { - [embeddingName]: [0.1, 0.2, 0.3], - }, - updated: new Date(), - sourceName: "snooty-docs", // only important value - score: 0.91, - }, - ]; - const mockStore: EmbeddedContentStore = { - loadEmbeddedContent: jest.fn(), - deleteEmbeddedContent: jest.fn(), - updateEmbeddedContent: jest.fn(), - async findNearestNeighbors() { - return mockBoostedResults; - }, - async getDataSources(matchQuery) { - return []; - }, - metadata: { - embeddingName, - }, - }; - const existingResults = [ - { - _id: new ObjectId(), - url: "https://mongodb.com/docs/", - text: "foo bar baz", - tokenCount: 100, - embeddings: { - [embeddingName]: [0.1, 0.2, 0.3], - }, - updated: new Date(), - sourceName: "not-snooty-docs", // only important value - score: 0.99, - }, - sharedResult, - { - _id: new ObjectId(), - url: "https://mongodb.com/docs/", - text: "four score and seven years ago", - tokenCount: 100, - embeddings: { - [embeddingName]: [0.1, 0.2, 0.3], - }, - updated: new Date(), - sourceName: "not-snooty-docs", // only important value - score: 0.955, - }, - { - _id: new ObjectId(), - url: "https://mongodb.com/docs/", - text: "one small step for man, one giant leap for mankind", - tokenCount: 100, - embeddings: { - [embeddingName]: [0.1, 0.2, 0.3], - }, - updated: new Date(), - sourceName: "not-snooty-docs", // only important value - score: 0.95, - }, - ]; - const embedding = [0.1, 0.2, 0.3]; - test("Boosts manual results", async () => { - const results = await boostManual.boost({ - embedding, - existingResults, - store: mockStore, - }); - expect(results).toHaveLength(5); - expect(results[0]).toStrictEqual(existingResults[0]); - expect(results[1]).toStrictEqual(sharedResult); - expect(results[2]).toStrictEqual(existingResults[2]); - expect(results[3]).toStrictEqual(existingResults[3]); - expect(results[4]).toStrictEqual(mockBoostedResults[1]); - }); - }); -}); diff --git a/packages/mongodb-rag-core/src/findContent/BoostOnAtlasSearchFilter.ts b/packages/mongodb-rag-core/src/findContent/BoostOnAtlasSearchFilter.ts deleted file mode 100644 index d96d0d453..000000000 --- a/packages/mongodb-rag-core/src/findContent/BoostOnAtlasSearchFilter.ts +++ /dev/null @@ -1,69 +0,0 @@ -import { EmbeddedContentStore, EmbeddedContent } from "../contentStore"; -import { SearchBooster } from "./SearchBooster"; -import { FindNearestNeighborsOptions, WithScore } from "../VectorStore"; - -export type WithFilterAndK = T & { - filter: Record; - k: number; -}; -type FindNearestNeighborOptionsWithFilterAndK = WithFilterAndK< - Partial ->; - -interface MakeBoostOnAtlasSearchFilterArgs { - /** - Options for performing a nearest-neighbor search for results to boost. - */ - findNearestNeighborsOptions: FindNearestNeighborOptionsWithFilterAndK; - - /** - Max number of results to boost. - */ - totalMaxK: number; - - /** - Determines if the booster should be used, based on the user's input. - */ - shouldBoostFunc: ({ text }: { text: string }) => Promise; -} - -/** - Boost certain results in search results from Atlas Search. - */ -export function makeBoostOnAtlasSearchFilter({ - findNearestNeighborsOptions, - totalMaxK, - shouldBoostFunc, -}: MakeBoostOnAtlasSearchFilterArgs): SearchBooster { - if (findNearestNeighborsOptions.k > totalMaxK) { - throw new Error( - `findNearestNeighborsOptions.k (${findNearestNeighborsOptions.k}) must be less than or equal to totalMaxK (${totalMaxK})` - ); - } - return { - shouldBoost: shouldBoostFunc, - boost: async function ({ - embedding, - store, - existingResults, - }: { - embedding: number[]; - store: EmbeddedContentStore; - existingResults: WithScore[]; - }) { - const boostedResults = await store.findNearestNeighbors( - embedding, - findNearestNeighborsOptions - ); - - const newResults = existingResults.filter((result) => - boostedResults.every( - (manualResult) => manualResult.text !== result.text - ) - ); - return [...boostedResults, ...newResults] - .sort((a, b) => b.score - a.score) - .slice(0, totalMaxK); - }, - }; -} diff --git a/packages/mongodb-rag-core/src/findContent/DefaultFindContent.test.ts b/packages/mongodb-rag-core/src/findContent/DefaultFindContent.test.ts index 365b08560..8e66a4e15 100644 --- a/packages/mongodb-rag-core/src/findContent/DefaultFindContent.test.ts +++ b/packages/mongodb-rag-core/src/findContent/DefaultFindContent.test.ts @@ -95,4 +95,20 @@ describe("makeDefaultFindContent()", () => { expect(content.length).toBeGreaterThan(0); expect(embeddingModelName).toBe(OPENAI_RETRIEVAL_EMBEDDING_DEPLOYMENT); }); + test("should limit results", async () => { + const findContent = makeDefaultFindContent({ + embedder, + store: embeddedContentStore, + findNearestNeighborsOptions: { + minScore: 0.1, // low min, should return at least one result + }, + }); + const query = "MongoDB"; + const { content } = await findContent({ + query, + limit: 1, // limit to 1, should return 1 result + }); + expect(content).toBeDefined(); + expect(content.length).toBe(1); + }); }); diff --git a/packages/mongodb-rag-core/src/findContent/DefaultFindContent.ts b/packages/mongodb-rag-core/src/findContent/DefaultFindContent.ts index 43661187c..fcc87c02c 100644 --- a/packages/mongodb-rag-core/src/findContent/DefaultFindContent.ts +++ b/packages/mongodb-rag-core/src/findContent/DefaultFindContent.ts @@ -1,14 +1,14 @@ -import { EmbeddedContentStore } from "../contentStore"; +import { EmbeddedContentStore, QueryFilters } from "../contentStore"; import { Embedder } from "../embed"; import { FindContentFunc } from "./FindContentFunc"; -import { SearchBooster } from "./SearchBooster"; import { FindNearestNeighborsOptions } from "../VectorStore"; export type MakeDefaultFindContentFuncArgs = { embedder: Embedder; store: EmbeddedContentStore; - findNearestNeighborsOptions?: Partial; - searchBoosters?: SearchBooster[]; + findNearestNeighborsOptions?: Partial< + FindNearestNeighborsOptions + >; }; /** @@ -18,27 +18,18 @@ export const makeDefaultFindContent = ({ embedder, store, findNearestNeighborsOptions, - searchBoosters, }: MakeDefaultFindContentFuncArgs): FindContentFunc => { - return async ({ query, filters = {} }) => { + return async ({ query, filters = {}, limit }) => { const { embedding } = await embedder.embed({ text: query, }); - let content = await store.findNearestNeighbors(embedding, { + const content = await store.findNearestNeighbors(embedding, { ...findNearestNeighborsOptions, filter: filters, + ...(limit ? { k: limit } : {}), }); - for (const booster of searchBoosters ?? []) { - if (await booster.shouldBoost({ text: query })) { - content = await booster.boost({ - existingResults: content, - embedding, - store, - }); - } - } return { queryEmbedding: embedding, content, diff --git a/packages/mongodb-rag-core/src/findContent/FindContentFunc.ts b/packages/mongodb-rag-core/src/findContent/FindContentFunc.ts index da01c3d54..ddb4d0f21 100644 --- a/packages/mongodb-rag-core/src/findContent/FindContentFunc.ts +++ b/packages/mongodb-rag-core/src/findContent/FindContentFunc.ts @@ -4,6 +4,7 @@ import { WithScore } from "../VectorStore"; export type FindContentFuncArgs = { query: string; filters?: QueryFilters; + limit?: number; }; export type FindContentFunc = ( diff --git a/packages/mongodb-rag-core/src/findContent/SearchBooster.ts b/packages/mongodb-rag-core/src/findContent/SearchBooster.ts deleted file mode 100644 index ea2d5594c..000000000 --- a/packages/mongodb-rag-core/src/findContent/SearchBooster.ts +++ /dev/null @@ -1,19 +0,0 @@ -import { EmbeddedContent, EmbeddedContentStore } from "../contentStore"; -import { WithScore } from "../VectorStore"; - -/** - Modify the vector search results to add, elevate, or mutate search results - after the search has been performed. - */ -export interface SearchBooster { - shouldBoost: ({ text }: { text: string }) => Promise; - boost: ({ - existingResults, - embedding, - store, - }: { - embedding: number[]; - existingResults: WithScore[]; - store: EmbeddedContentStore; - }) => Promise[]>; -} diff --git a/packages/mongodb-rag-core/src/findContent/index.ts b/packages/mongodb-rag-core/src/findContent/index.ts index 7373d120f..77c236ae3 100644 --- a/packages/mongodb-rag-core/src/findContent/index.ts +++ b/packages/mongodb-rag-core/src/findContent/index.ts @@ -1,4 +1,2 @@ -export * from "./BoostOnAtlasSearchFilter"; export * from "./DefaultFindContent"; export * from "./FindContentFunc"; -export * from "./SearchBooster"; diff --git a/packages/mongodb-rag-core/src/mongoDbMetadata/classifyMetadata.ts b/packages/mongodb-rag-core/src/mongoDbMetadata/classifyMetadata.ts index 28d09ae26..0bf20d5a5 100644 --- a/packages/mongodb-rag-core/src/mongoDbMetadata/classifyMetadata.ts +++ b/packages/mongodb-rag-core/src/mongoDbMetadata/classifyMetadata.ts @@ -135,6 +135,20 @@ ${mongoDbTopics function nullOnErr() { return null; } + +export const classifyMongoDbProgrammingLanguageAndProduct = wrapTraced( + async (model: LanguageModel, data: string, maxRetries?: number) => { + const [programmingLanguage, product] = await Promise.all([ + classifyMongoDbProgrammingLanguage(model, data, maxRetries).catch( + nullOnErr + ), + classifyMongoDbProduct(model, data, maxRetries).catch(nullOnErr), + ]); + return { programmingLanguage, product }; + }, + { name: "classifyMongoDbProgrammingLanguageAndProduct" } +); + export const classifyMongoDbMetadata = wrapTraced( async (model: LanguageModel, data: string, maxRetries?: number) => { const [programmingLanguage, product, topic] = await Promise.all([